#include #include "mpi.h" #define CHECK(ret, const) if (ret != const) return 1 /* Base broadcast function */ int MYMPI_Bcast_(void *buf, int s, MPI_Datatype d_type, int rt, MPI_Comm com) { int rnk, n_task, src, dst, ret; MPI_Status stat; MPI_Comm_size(MPI_COMM_WORLD, &n_task); if (n_task < 2) return 1; MPI_Comm_rank(com, &rnk); src = rnk == 0 ? n_task - 1 : rnk - 1; dst = (rnk + 1) % n_task; if (rnk == rt) { ret = MPI_Send(buf, s, d_type, dst, 99, com); CHECK(ret, MPI_SUCCESS); } else { ret = MPI_Recv(buf, s, d_type, src, 99, com, &stat); CHECK(ret, MPI_SUCCESS); /* Pass message on if task is not to the left of the root */ if ((rnk + 1) % n_task != rt) { ret = MPI_Send(buf, s, d_type, dst, 99, com); CHECK(ret, MPI_SUCCESS); } } return 0; } /* Optimized broadcast function */ int MYMPI_Bcast(void *buf, int s, MPI_Datatype d_type, int rt, MPI_Comm com) { int rnk, n_task, bef, aft, ret; MPI_Status stat; MPI_Comm_size(MPI_COMM_WORLD, &n_task); if (n_task < 2) return 1; float half = n_task / 2; float n_half = rt + half > n_task ? (rt + half) - n_task : rt + half; MPI_Comm_rank(com, &rnk); bef = rnk == 0 ? n_task - 1 : rnk - 1; aft = (rnk + 1) % n_task; if (rnk == rt) { ret = MPI_Send(buf, s, d_type, aft, 99, com); CHECK(ret, MPI_SUCCESS); /* Send to 2nd neighbouring task if available */ if (n_task > 2) { ret = MPI_Send(buf, s, d_type, bef, 99, com); CHECK(ret, MPI_SUCCESS); } } else { int src = rnk <= n_half ? bef : aft; int dst = rnk <= n_half ? aft : bef; ret = MPI_Recv(buf, s, d_type, src, 99, com, &stat); CHECK(ret, MPI_SUCCESS); if ((rnk <= n_half && dst <= n_half) || (rnk > n_half && dst > n_half)) { ret = MPI_Send(buf, s, d_type, dst, 99, com); CHECK(ret, MPI_SUCCESS); } } return 0; }