diff --git a/BaseLib/MPI.h b/BaseLib/MPI.h index bdab781de4282082b4b630230edd7e50c282bb1d..de83e3fc4dd01a4459c253fcff2444c9763b13cf 100644 --- a/BaseLib/MPI.h +++ b/BaseLib/MPI.h @@ -51,5 +51,53 @@ struct Mpi int size; int rank; }; + +template <typename T> +constexpr MPI_Datatype mpiType() +{ + using U = std::remove_const_t<T>; + if constexpr (std::is_same_v<U, bool>) + { + return MPI_C_BOOL; + } + if constexpr (std::is_same_v<U, char>) + { + return MPI_CHAR; + } + if constexpr (std::is_same_v<U, double>) + { + return MPI_DOUBLE; + } + if constexpr (std::is_same_v<U, float>) + { + return MPI_FLOAT; + } + if constexpr (std::is_same_v<U, int>) + { + return MPI_INT; + } + if constexpr (std::is_same_v<U, std::size_t>) + { + return MPI_UNSIGNED_LONG; + } + if constexpr (std::is_same_v<U, unsigned int>) + { + return MPI_UNSIGNED; + } +} + +template <typename T> +static std::vector<T> allgather(T const& value, Mpi const& mpi) +{ + std::vector<T> result(mpi.size); + + result[mpi.rank] = value; + + MPI_Allgather(&result[mpi.rank], 1, mpiType<T>(), result.data(), 1, + mpiType<T>(), mpi.communicator); + + return result; +} + #endif } // namespace BaseLib::MPI