#include #include "stdio.h" #include "mpi.h" #include "stdlib.h" #include "pthread.h" #include "semaphore.h" #ifndef num_threads #define num_threads 4 #endif #ifndef num_reps #define num_reps 10 #endif #ifndef sem_value #define sem_value 1 #endif int rank, num_nodes; int ids[num_threads]; pthread_t threads[num_threads]; MPI_Comm comms[num_threads]; sem_t sem; int mpi_init(int argc, char *argv[]) { int provided; MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided); if (provided != MPI_THREAD_MULTIPLE) { printf("NO MULTIPLE THREAD SUPPORT\n"); fflush(stdout); MPI_Finalize(); return -1; } MPI_Comm_rank(MPI_COMM_WORLD, &rank); MPI_Comm_size(MPI_COMM_WORLD, &num_nodes); return 0; } void mpi_finalize() { MPI_Finalize(); } void *thread_fcn(void *arg) { int id = *(int*)arg; int rep; for (rep = 0; rep < num_reps; rep++) { if (rank == 0) sem_wait(&sem); int sum = 0; MPI_Comm comm_i; MPI_Comm_dup(comms[id], &comm_i); MPI_Allreduce(&id, &sum, 1, MPI_INT, MPI_SUM, comm_i); assert(sum == 2 * id); MPI_Comm_free(&comm_i); if (rank == 0) sem_post(&sem); } } int main(int argc, char* argv[]) { if (mpi_init(argc, argv) != 0) { exit(1); } assert(num_nodes == 2); if (rank == 0) sem_init(&sem, 0, sem_value); int i; for (i = 0; i < num_threads; i++) { ids[i] = i; MPI_Comm_dup(MPI_COMM_WORLD, &comms[i]); } for (i = 0; i < num_threads; i++) { pthread_create(&threads[i], NULL, thread_fcn, &ids[i]); } for (i = 0; i < num_threads; i++) { pthread_join(threads[i], NULL); } printf("Done.\n"); if (rank == 0) sem_destroy(&sem); mpi_finalize(); return 0; }