Stan Math Library  2.20.0
reverse mode automatic differentiation
integrate_1d.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_ARR_FUNCTOR_integrate_1d_HPP
2 #define STAN_MATH_REV_ARR_FUNCTOR_integrate_1d_HPP
3 
4 #include <stan/math/rev/meta.hpp>
11 #include <type_traits>
12 #include <string>
13 #include <vector>
14 #include <functional>
15 #include <ostream>
16 #include <limits>
17 
18 namespace stan {
19 namespace math {
20 
29 template <typename F>
30 inline double gradient_of_f(const F &f, const double &x, const double &xc,
31  const std::vector<double> &theta_vals,
32  const std::vector<double> &x_r,
33  const std::vector<int> &x_i, size_t n,
34  std::ostream &msgs) {
35  double gradient = 0.0;
36  start_nested();
37  std::vector<var> theta_var(theta_vals.size());
38  try {
39  for (size_t i = 0; i < theta_vals.size(); i++)
40  theta_var[i] = theta_vals[i];
41  var fx = f(x, xc, theta_var, x_r, x_i, &msgs);
42  fx.grad();
43  gradient = theta_var[n].adj();
44  if (is_nan(gradient)) {
45  if (fx.val() == 0) {
46  gradient = 0;
47  } else {
48  domain_error("gradient_of_f", "The gradient of f", n,
49  "is nan for parameter ", "");
50  }
51  }
52  } catch (const std::exception &e) {
54  throw;
55  }
57 
58  return gradient;
59 }
60 
115 template <typename F, typename T_a, typename T_b, typename T_theta>
116 inline typename std::enable_if<std::is_same<T_a, var>::value
117  || std::is_same<T_b, var>::value
118  || std::is_same<T_theta, var>::value,
119  var>::type
120 integrate_1d(const F &f, const T_a &a, const T_b &b,
121  const std::vector<T_theta> &theta, const std::vector<double> &x_r,
122  const std::vector<int> &x_i, std::ostream &msgs,
123  const double relative_tolerance
124  = std::sqrt(std::numeric_limits<double>::epsilon())) {
125  static const char *function = "integrate_1d";
126  check_less_or_equal(function, "lower limit", a, b);
127 
128  if (value_of(a) == value_of(b)) {
129  if (is_inf(a))
130  domain_error(function, "Integration endpoints are both", value_of(a), "",
131  "");
132  return var(0.0);
133  } else {
134  double integral = integrate(
135  std::bind<double>(f, std::placeholders::_1, std::placeholders::_2,
136  value_of(theta), x_r, x_i, &msgs),
137  value_of(a), value_of(b), relative_tolerance);
138 
139  size_t N_theta_vars = is_var<T_theta>::value ? theta.size() : 0;
140  std::vector<double> dintegral_dtheta(N_theta_vars);
141  std::vector<var> theta_concat(N_theta_vars);
142 
143  if (N_theta_vars > 0) {
144  std::vector<double> theta_vals = value_of(theta);
145 
146  for (size_t n = 0; n < N_theta_vars; ++n) {
147  dintegral_dtheta[n] = integrate(
148  std::bind<double>(gradient_of_f<F>, f, std::placeholders::_1,
149  std::placeholders::_2, theta_vals, x_r, x_i, n,
150  std::ref(msgs)),
151  value_of(a), value_of(b), relative_tolerance);
152  theta_concat[n] = theta[n];
153  }
154  }
155 
156  if (!is_inf(a) && is_var<T_a>::value) {
157  theta_concat.push_back(a);
158  dintegral_dtheta.push_back(
159  -value_of(f(value_of(a), 0.0, theta, x_r, x_i, &msgs)));
160  }
161 
162  if (!is_inf(b) && is_var<T_b>::value) {
163  theta_concat.push_back(b);
164  dintegral_dtheta.push_back(
165  value_of(f(value_of(b), 0.0, theta, x_r, x_i, &msgs)));
166  }
167 
168  return precomputed_gradients(integral, theta_concat, dintegral_dtheta);
169  }
170 }
171 
172 } // namespace math
173 } // namespace stan
174 
175 #endif
void check_less_or_equal(const char *function, const char *name, const T_y &y, const T_high &high)
Check if y is less or equal to high.
double gradient_of_f(const F &f, const double &x, const double &xc, const std::vector< double > &theta_vals, const std::vector< double > &x_r, const std::vector< int > &x_i, size_t n, std::ostream &msgs)
Calculate the first derivative of f(x, xc, theta, x_r, x_i, msgs) with respect to the nth parameter.
fvar< T > sqrt(const fvar< T > &x)
Definition: sqrt.hpp:13
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition: value_of.hpp:17
Defines a public enum named value which is defined to be false as the primitive scalar types cannot b...
Definition: is_var.hpp:10
double integrate(const F &f, double a, double b, double relative_tolerance)
Integrate a single variable function f from a to b to within a specified relative tolerance...
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:33
var precomputed_gradients(double value, const std::vector< var > &operands, const std::vector< double > &gradients)
This function returns a var for an expression that has the specified value, vector of operands...
void domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
void grad(std::vector< var > &x, std::vector< double > &g)
Compute the gradient of this (dependent) variable with respect to the specified vector of (independent) variables.
Definition: var.hpp:318
int is_inf(const fvar< T > &x)
Returns 1 if the input's value is infinite and 0 otherwise.
Definition: is_inf.hpp:20
double e()
Return the base of the natural logarithm.
Definition: constants.hpp:87
double integrate_1d(const F &f, const double a, const double b, const std::vector< double > &theta, const std::vector< double > &x_r, const std::vector< int > &x_i, std::ostream &msgs, const double relative_tolerance=std::sqrt(std::numeric_limits< double >::epsilon()))
Compute the integral of the single variable function f from a to b to within a specified relative tolerance.
static void recover_memory_nested()
Recover only the memory used for the top nested call.
int is_nan(const fvar< T > &x)
Returns 1 if the input's value is NaN and 0 otherwise.
Definition: is_nan.hpp:20
void gradient(const F &f, const Eigen::Matrix< T, Eigen::Dynamic, 1 > &x, T &fx, Eigen::Matrix< T, Eigen::Dynamic, 1 > &grad_fx)
Calculate the value and the gradient of the specified function at the specified argument.
Definition: gradient.hpp:39
static void start_nested()
Record the current position so that recover_memory_nested() can find it.
double val() const
Return the value of this variable.
Definition: var.hpp:294

     [ Stan Home Page ] © 2011–2018, Stan Development Team.