Stan Math Library  2.20.0
reverse mode automatic differentiation
cov_exp_quad.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_MAT_FUN_COV_EXP_QUAD_HPP
2 #define STAN_MATH_REV_MAT_FUN_COV_EXP_QUAD_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
11 #include <type_traits>
12 #include <vector>
13 #include <cmath>
14 
15 namespace stan {
16 namespace math {
17 
21 template <typename T_x, typename T_sigma, typename T_l>
22 class cov_exp_quad_vari : public vari {
23  public:
24  const size_t size_;
25  const size_t size_ltri_;
26  const double l_d_;
27  const double sigma_d_;
28  const double sigma_sq_d_;
29  double* dist_;
34 
38  cov_exp_quad_vari(const std::vector<T_x>& x, const T_sigma& sigma,
39  const T_l& l)
40  : vari(0.0),
41  size_(x.size()),
42  size_ltri_(size_ * (size_ - 1) / 2),
43  l_d_(value_of(l)),
44  sigma_d_(value_of(sigma)),
45  sigma_sq_d_(sigma_d_ * sigma_d_),
46  dist_(ChainableStack::instance_->memalloc_.alloc_array<double>(
47  size_ltri_)),
48  l_vari_(l.vi_),
49  sigma_vari_(sigma.vi_),
50  cov_lower_(ChainableStack::instance_->memalloc_.alloc_array<vari*>(
51  size_ltri_)),
52  cov_diag_(
53  ChainableStack::instance_->memalloc_.alloc_array<vari*>(size_)) {
54  double inv_half_sq_l_d = 0.5 / (l_d_ * l_d_);
55  size_t pos = 0;
56  for (size_t j = 0; j < size_ - 1; ++j) {
57  for (size_t i = j + 1; i < size_; ++i) {
58  double dist_sq = squared_distance(x[i], x[j]);
59  dist_[pos] = dist_sq;
60  cov_lower_[pos] = new vari(
61  sigma_sq_d_ * std::exp(-dist_sq * inv_half_sq_l_d), false);
62  ++pos;
63  }
64  }
65  for (size_t i = 0; i < size_; ++i)
66  cov_diag_[i] = new vari(sigma_sq_d_, false);
67  }
68 
69  virtual void chain() {
70  double adjl = 0;
71  double adjsigma = 0;
72 
73  for (size_t i = 0; i < size_ltri_; ++i) {
74  vari* el_low = cov_lower_[i];
75  double prod_add = el_low->adj_ * el_low->val_;
76  adjl += prod_add * dist_[i];
77  adjsigma += prod_add;
78  }
79  for (size_t i = 0; i < size_; ++i) {
80  vari* el = cov_diag_[i];
81  adjsigma += el->adj_ * el->val_;
82  }
83  l_vari_->adj_ += adjl / (l_d_ * l_d_ * l_d_);
84  sigma_vari_->adj_ += adjsigma * 2 / sigma_d_;
85  }
86 };
87 
91 template <typename T_x, typename T_l>
92 class cov_exp_quad_vari<T_x, double, T_l> : public vari {
93  public:
94  const size_t size_;
95  const size_t size_ltri_;
96  const double l_d_;
97  const double sigma_d_;
98  const double sigma_sq_d_;
99  double* dist_;
103 
107  cov_exp_quad_vari(const std::vector<T_x>& x, double sigma, const T_l& l)
108  : vari(0.0),
109  size_(x.size()),
110  size_ltri_(size_ * (size_ - 1) / 2),
111  l_d_(value_of(l)),
112  sigma_d_(value_of(sigma)),
113  sigma_sq_d_(sigma_d_ * sigma_d_),
114  dist_(ChainableStack::instance_->memalloc_.alloc_array<double>(
115  size_ltri_)),
116  l_vari_(l.vi_),
117  cov_lower_(ChainableStack::instance_->memalloc_.alloc_array<vari*>(
118  size_ltri_)),
119  cov_diag_(
120  ChainableStack::instance_->memalloc_.alloc_array<vari*>(size_)) {
121  double inv_half_sq_l_d = 0.5 / (l_d_ * l_d_);
122  size_t pos = 0;
123  for (size_t j = 0; j < size_ - 1; ++j) {
124  for (size_t i = j + 1; i < size_; ++i) {
125  double dist_sq = squared_distance(x[i], x[j]);
126  dist_[pos] = dist_sq;
127  cov_lower_[pos] = new vari(
128  sigma_sq_d_ * std::exp(-dist_sq * inv_half_sq_l_d), false);
129  ++pos;
130  }
131  }
132  for (size_t i = 0; i < size_; ++i)
133  cov_diag_[i] = new vari(sigma_sq_d_, false);
134  }
135 
136  virtual void chain() {
137  double adjl = 0;
138 
139  for (size_t i = 0; i < size_ltri_; ++i) {
140  vari* el_low = cov_lower_[i];
141  adjl += el_low->adj_ * el_low->val_ * dist_[i];
142  }
143  l_vari_->adj_ += adjl / (l_d_ * l_d_ * l_d_);
144  }
145 };
146 
150 template <typename T_x>
151 inline typename std::enable_if<
152  std::is_same<typename scalar_type<T_x>::type, double>::value,
153  Eigen::Matrix<var, -1, -1> >::type
154 cov_exp_quad(const std::vector<T_x>& x, const var& sigma, const var& l) {
155  return gp_exp_quad_cov(x, sigma, l);
156 }
157 
161 template <typename T_x>
162 inline typename std::enable_if<
163  std::is_same<typename scalar_type<T_x>::type, double>::value,
164  Eigen::Matrix<var, -1, -1> >::type
165 cov_exp_quad(const std::vector<T_x>& x, double sigma, const var& l) {
166  return gp_exp_quad_cov(x, sigma, l);
167 }
168 
169 } // namespace math
170 } // namespace stan
171 #endif
Eigen::Matrix< typename stan::return_type< T_x, T_sigma, T_l >::type, Eigen::Dynamic, Eigen::Dynamic > cov_exp_quad(const std::vector< T_x > &x, const T_sigma &sigma, const T_l &length_scale)
virtual void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition: value_of.hpp:17
Eigen::Matrix< typename stan::return_type< T_x, T_sigma, T_l >::type, Eigen::Dynamic, Eigen::Dynamic > gp_exp_quad_cov(const std::vector< T_x > &x, const T_sigma &sigma, const T_l &length_scale)
Returns a squared exponential kernel.
The variable implementation base class.
Definition: vari.hpp:30
friend class var
Definition: vari.hpp:32
const double val_
The value of this variable.
Definition: vari.hpp:38
fvar< T > squared_distance(const Eigen::Matrix< fvar< T >, R, C > &v1, const Eigen::Matrix< double, R, C > &v2)
Returns the squared distance between the specified vectors of the same dimensions.
fvar< T > exp(const fvar< T > &x)
Definition: exp.hpp:11
vari(double x)
Construct a variable implementation from a value.
Definition: vari.hpp:58
int size(const std::vector< T > &x)
Return the size of the specified standard vector.
Definition: size.hpp:17
cov_exp_quad_vari(const std::vector< T_x > &x, const T_sigma &sigma, const T_l &l)
virtual void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...
cov_exp_quad_vari(const std::vector< T_x > &x, double sigma, const T_l &l)
double adj_
The adjoint of this variable, which is the partial derivative of this variable with respect to the root variable.
Definition: vari.hpp:44
This struct always provides access to the autodiff stack using the singleton pattern.

     [ Stan Home Page ] © 2011–2018, Stan Development Team.