Stan Math Library  2.20.0
reverse mode automatic differentiation
cvodes_integrator.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_MAT_FUNCTOR_INTEGRATE_ODE_CVODES_HPP
2 #define STAN_MATH_REV_MAT_FUNCTOR_INTEGRATE_ODE_CVODES_HPP
3 
4 #include <stan/math/rev/meta.hpp>
15 #include <cvodes/cvodes.h>
16 #include <sunlinsol/sunlinsol_dense.h>
17 #include <algorithm>
18 #include <ostream>
19 #include <vector>
20 
21 namespace stan {
22 namespace math {
23 
29 template <int Lmm>
31  public:
33 
73  template <typename F, typename T_initial, typename T_param, typename T_t0,
74  typename T_ts>
75  std::vector<std::vector<
77  integrate(const F& f, const std::vector<T_initial>& y0, const T_t0& t0,
78  const std::vector<T_ts>& ts, const std::vector<T_param>& theta,
79  const std::vector<double>& x, const std::vector<int>& x_int,
80  std::ostream* msgs, double relative_tolerance,
81  double absolute_tolerance,
82  long int max_num_steps) { // NOLINT(runtime/int)
83  typedef stan::is_var<T_initial> initial_var;
84  typedef stan::is_var<T_param> param_var;
85 
86  const char* fun = "integrate_ode_cvodes";
87 
88  const double t0_dbl = value_of(t0);
89  const std::vector<double> ts_dbl = value_of(ts);
90 
91  check_finite(fun, "initial state", y0);
92  check_finite(fun, "initial time", t0_dbl);
93  check_finite(fun, "times", ts_dbl);
94  check_finite(fun, "parameter vector", theta);
95  check_finite(fun, "continuous data", x);
96  check_nonzero_size(fun, "times", ts);
97  check_nonzero_size(fun, "initial state", y0);
98  check_ordered(fun, "times", ts_dbl);
99  check_less(fun, "initial time", t0_dbl, ts_dbl[0]);
100  if (relative_tolerance <= 0)
101  invalid_argument("integrate_ode_cvodes", "relative_tolerance,",
102  relative_tolerance, "", ", must be greater than 0");
103  if (absolute_tolerance <= 0)
104  invalid_argument("integrate_ode_cvodes", "absolute_tolerance,",
105  absolute_tolerance, "", ", must be greater than 0");
106  if (max_num_steps <= 0)
107  invalid_argument("integrate_ode_cvodes", "max_num_steps,", max_num_steps,
108  "", ", must be greater than 0");
109 
110  const size_t N = y0.size();
111  const size_t M = theta.size();
112  const size_t S = (initial_var::value ? N : 0) + (param_var::value ? M : 0);
113 
115  ode_data cvodes_data(f, y0, theta, x, x_int, msgs);
116 
117  void* cvodes_mem = CVodeCreate(Lmm);
118  if (cvodes_mem == nullptr)
119  throw std::runtime_error("CVodeCreate failed to allocate memory");
120 
121  const size_t coupled_size = cvodes_data.coupled_ode_.size();
122 
123  std::vector<std::vector<
125  y;
127  f, y0, theta, t0, ts, x, x_int, msgs, y);
128 
129  try {
130  cvodes_check_flag(CVodeInit(cvodes_mem, &ode_data::cv_rhs, t0_dbl,
131  cvodes_data.nv_state_),
132  "CVodeInit");
133 
134  // Assign pointer to this as user data
136  CVodeSetUserData(cvodes_mem, reinterpret_cast<void*>(&cvodes_data)),
137  "CVodeSetUserData");
138 
139  cvodes_set_options(cvodes_mem, relative_tolerance, absolute_tolerance,
140  max_num_steps);
141 
142  // for the stiff solvers we need to reserve additional memory
143  // and provide a Jacobian function call. new API since 3.0.0:
144  // create matrix object and linear solver object; resource
145  // (de-)allocation is handled in the cvodes_ode_data
147  CVodeSetLinearSolver(cvodes_mem, cvodes_data.LS_, cvodes_data.A_),
148  "CVodeSetLinearSolver");
150  CVodeSetJacFn(cvodes_mem, &ode_data::cv_jacobian_states),
151  "CVodeSetJacFn");
152 
153  // initialize forward sensitivity system of CVODES as needed
154  if (S > 0) {
156  CVodeSensInit(cvodes_mem, static_cast<int>(S), CV_STAGGERED,
157  &ode_data::cv_rhs_sens, cvodes_data.nv_state_sens_),
158  "CVodeSensInit");
159 
160  cvodes_check_flag(CVodeSensEEtolerances(cvodes_mem),
161  "CVodeSensEEtolerances");
162  }
163 
164  double t_init = t0_dbl;
165  for (size_t n = 0; n < ts.size(); ++n) {
166  double t_final = ts_dbl[n];
167  if (t_final != t_init)
168  cvodes_check_flag(CVode(cvodes_mem, t_final, cvodes_data.nv_state_,
169  &t_init, CV_NORMAL),
170  "CVode");
171  if (S > 0) {
173  CVodeGetSens(cvodes_mem, &t_init, cvodes_data.nv_state_sens_),
174  "CVodeGetSens");
175  }
176  observer(cvodes_data.coupled_state_, t_final);
177  t_init = t_final;
178  }
179  } catch (const std::exception& e) {
180  CVodeFree(&cvodes_mem);
181  throw;
182  }
183 
184  CVodeFree(&cvodes_mem);
185 
186  return y;
187  }
188 }; // cvodes integrator
189 } // namespace math
190 } // namespace stan
191 #endif
void check_finite(const char *function, const char *name, const T_y &y)
Check if y is finite.
void check_nonzero_size(const char *function, const char *name, const T_y &y)
Check if the specified matrix/vector is of non-zero size.
void cvodes_set_options(void *cvodes_mem, double rel_tol, double abs_tol, long int max_num_steps)
void check_ordered(const char *function, const char *name, const std::vector< T_y > &y)
Check if the specified vector is sorted into strictly increasing order.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition: value_of.hpp:17
Defines a public enum named value which is defined to be false as the primitive scalar types cannot b...
Definition: is_var.hpp:10
Observer for the coupled states.
CVODES ode data holder object which is used during CVODES integration for CVODES callbacks.
boost::math::tools::promote_args< double, typename scalar_type< T >::type, typename return_type< Types_pack... >::type >::type type
Definition: return_type.hpp:36
void cvodes_check_flag(int flag, const char *func_name)
void invalid_argument(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw an invalid_argument exception with a consistently formatted message.
double e()
Return the base of the natural logarithm.
Definition: constants.hpp:87
std::vector< std::vector< typename stan::return_type< T_initial, T_param, T_t0, T_ts >::type > > integrate(const F &f, const std::vector< T_initial > &y0, const T_t0 &t0, const std::vector< T_ts > &ts, const std::vector< T_param > &theta, const std::vector< double > &x, const std::vector< int > &x_int, std::ostream *msgs, double relative_tolerance, double absolute_tolerance, long int max_num_steps)
Return the solutions for the specified system of ordinary differential equations given the specified ...
void check_less(const char *function, const char *name, const T_y &y, const T_high &high)
Check if y is strictly less than high.
Definition: check_less.hpp:63
Integrator interface for CVODES' ODE solvers (Adams & BDF methods).

     [ Stan Home Page ] © 2011–2018, Stan Development Team.