Stan Math Library  2.20.0
reverse mode automatic differentiation
mpi_cluster.hpp
Go to the documentation of this file.
1 #ifdef STAN_MPI
2 
3 #ifndef STAN_MATH_PRIM_ARR_FUNCTOR_MPI_CLUSTER_HPP
4 #define STAN_MATH_PRIM_ARR_FUNCTOR_MPI_CLUSTER_HPP
5 
7 
8 #include <boost/mpi/allocator.hpp>
9 #include <boost/mpi/collectives.hpp>
10 #include <boost/mpi/communicator.hpp>
11 #include <boost/mpi/datatype.hpp>
12 #include <boost/mpi/environment.hpp>
13 #include <boost/mpi/nonblocking.hpp>
14 #include <boost/mpi/operations.hpp>
15 
16 #include <boost/serialization/access.hpp>
17 #include <boost/serialization/base_object.hpp>
18 #include <boost/serialization/export.hpp>
19 #include <boost/serialization/shared_ptr.hpp>
20 
21 #include <mutex>
22 #include <vector>
23 #include <memory>
24 
25 namespace stan {
26 namespace math {
27 
/**
 * Exception used as the signal to leave the MPI listening loop.
 *
 * It is thrown by mpi_stop_worker::run() and caught in
 * mpi_cluster::listen(), which then returns.
 */
class mpi_stop_listen : public std::exception {
 public:
  /**
   * @return fixed, statically allocated description of the exception
   */
  const char* what() const noexcept override {
    return "Stopping MPI listening mode.";
  }
};
37 
/**
 * Exception thrown by mpi_broadcast_command when the cluster mutex
 * (mpi_cluster::in_use()) cannot be acquired, i.e. the MPI resource
 * is already locked by someone else.
 */
class mpi_is_in_use : public std::exception {
 public:
  /**
   * @return fixed, statically allocated description of the exception
   */
  const char* what() const noexcept override {
    return "MPI resource is in use.";
  }
};
44 
49 struct mpi_stop_worker : public mpi_command {
50  friend class boost::serialization::access;
51  template <class Archive>
52  void serialize(Archive& ar, const unsigned int version) {
53  ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(mpi_command);
54  }
55  void run() const {
56  boost::mpi::communicator world;
57  throw mpi_stop_listen();
58  }
59 };
60 
84 inline std::vector<int> mpi_map_chunks(std::size_t num_jobs,
85  std::size_t chunk_size = 1) {
86  boost::mpi::communicator world;
87  const std::size_t world_size = world.size();
88 
89  std::vector<int> chunks(world_size, num_jobs / world_size);
90 
91  const std::size_t delta_r = chunks[0] == 0 ? 0 : 1;
92 
93  for (std::size_t r = 0; r != num_jobs % world_size; ++r)
94  ++chunks[r + delta_r];
95 
96  for (std::size_t i = 0; i != world_size; ++i)
97  chunks[i] *= chunk_size;
98 
99  return chunks;
100 }
101 
102 template <typename T>
103 std::unique_lock<std::mutex> mpi_broadcast_command();
104 
126 struct mpi_cluster {
127  boost::mpi::environment env;
128  boost::mpi::communicator world_;
129  std::size_t const rank_ = world_.rank();
130 
131  mpi_cluster() {}
132 
133  ~mpi_cluster() {
134  // the destructor will ensure that the childs are being
135  // shutdown
136  stop_listen();
137  }
138 
149  void listen() {
150  listening_status() = true;
151  if (rank_ == 0) {
152  return;
153  }
154 
155  try {
156  // lock on the workers the cluster as MPI commands must be
157  // initiated from the root and any attempt to do this on the
158  // workers must fail
159  std::unique_lock<std::mutex> worker_lock(in_use());
160  while (1) {
161  std::shared_ptr<mpi_command> work;
162 
163  boost::mpi::broadcast(world_, work, 0);
164 
165  work->run();
166  }
167  } catch (const mpi_stop_listen& e) {
168  }
169  }
170 
176  void stop_listen() {
177  if (rank_ == 0 && listening_status()) {
178  mpi_broadcast_command<mpi_stop_worker>();
179  }
180  listening_status() = false;
181  }
182 
186  static bool& listening_status() {
187  static bool listening_status = false;
188  return listening_status;
189  }
190 
194  static std::mutex& in_use() {
195  static std::mutex in_use_mutex;
196  return in_use_mutex;
197  }
198 };
199 
209 inline std::unique_lock<std::mutex> mpi_broadcast_command(
210  std::shared_ptr<mpi_command>& command) {
211  boost::mpi::communicator world;
212 
213  if (world.rank() != 0)
214  throw std::runtime_error("only root may broadcast commands.");
215 
216  if (!mpi_cluster::listening_status())
217  throw std::runtime_error("cluster is not listening to commands.");
218 
219  std::unique_lock<std::mutex> cluster_lock(mpi_cluster::in_use(),
220  std::try_to_lock);
221 
222  if (!cluster_lock.owns_lock())
223  throw mpi_is_in_use();
224 
225  boost::mpi::broadcast(world, command, 0);
226 
227  return cluster_lock;
228 }
229 
237 template <typename T>
238 std::unique_lock<std::mutex> mpi_broadcast_command() {
239  std::shared_ptr<mpi_command> command(new T);
240 
241  return mpi_broadcast_command(command);
242 }
243 
244 } // namespace math
245 } // namespace stan
246 
247 #endif
248 
249 #endif
double e()
Return the base of the natural logarithm.
Definition: constants.hpp:87

     [ Stan Home Page ] © 2011–2018, Stan Development Team.