Stan Math Library  2.20.0
reverse mode automatic differentiation
pow.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_SCAL_FUN_POW_HPP
2 #define STAN_MATH_REV_SCAL_FUN_POW_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
12 #include <cmath>
13 #include <limits>
14 
15 namespace stan {
16 namespace math {
17 
18 namespace internal {
19 class pow_vv_vari : public op_vv_vari {
20  public:
21  pow_vv_vari(vari* avi, vari* bvi)
22  : op_vv_vari(std::pow(avi->val_, bvi->val_), avi, bvi) {}
23  void chain() {
24  if (unlikely(is_any_nan(avi_->val_, bvi_->val_))) {
25  avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
26  bvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
27  } else {
28  if (avi_->val_ == 0.0)
29  return; // partials zero, avoids 0 & log(0)
30  avi_->adj_ += adj_ * bvi_->val_ * val_ / avi_->val_;
31  bvi_->adj_ += adj_ * std::log(avi_->val_) * val_;
32  }
33  }
34 };
35 
36 class pow_vd_vari : public op_vd_vari {
37  public:
38  pow_vd_vari(vari* avi, double b)
39  : op_vd_vari(std::pow(avi->val_, b), avi, b) {}
40  void chain() {
41  if (unlikely(is_any_nan(avi_->val_, bd_))) {
42  avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
43  } else {
44  if (avi_->val_ == 0.0)
45  return; // partials zero, avoids 0 & log(0)
46  avi_->adj_ += adj_ * bd_ * val_ / avi_->val_;
47  }
48  }
49 };
50 
51 class pow_dv_vari : public op_dv_vari {
52  public:
53  pow_dv_vari(double a, vari* bvi)
54  : op_dv_vari(std::pow(a, bvi->val_), a, bvi) {}
55  void chain() {
56  if (unlikely(is_any_nan(bvi_->val_, ad_))) {
57  bvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
58  } else {
59  if (ad_ == 0.0)
60  return; // partials zero, avoids 0 & log(0)
61  bvi_->adj_ += adj_ * std::log(ad_) * val_;
62  }
63  }
64 };
65 } // namespace internal
66 
105 inline var pow(const var& base, const var& exponent) {
106  return var(new internal::pow_vv_vari(base.vi_, exponent.vi_));
107 }
108 
121 inline var pow(const var& base, double exponent) {
122  if (exponent == 0.5)
123  return sqrt(base);
124  if (exponent == 1.0)
125  return base;
126  if (exponent == 2.0)
127  return square(base);
128  if (exponent == -2.0)
129  return inv_square(base);
130  if (exponent == -1.0)
131  return inv(base);
132  if (exponent == -0.5)
133  return inv_sqrt(base);
134  return var(new internal::pow_vd_vari(base.vi_, exponent));
135 }
136 
149 inline var pow(double base, const var& exponent) {
150  return var(new internal::pow_dv_vari(base, exponent.vi_));
151 }
152 
153 } // namespace math
154 } // namespace stan
155 #endif
fvar< T > inv_sqrt(const fvar< T > &x)
Definition: inv_sqrt.hpp:11
fvar< T > sqrt(const fvar< T > &x)
Definition: sqrt.hpp:13
fvar< T > log(const fvar< T > &x)
Definition: log.hpp:12
void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...
Definition: pow.hpp:40
The variable implementation base class.
Definition: vari.hpp:30
bool is_any_nan(const T &x)
Returns true if the input is NaN and false otherwise.
Definition: is_any_nan.hpp:21
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:33
friend class var
Definition: vari.hpp:32
const double val_
The value of this variable.
Definition: vari.hpp:38
fvar< T > square(const fvar< T > &x)
Definition: square.hpp:12
void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...
Definition: pow.hpp:55
pow_dv_vari(double a, vari *bvi)
Definition: pow.hpp:53
#define unlikely(x)
Definition: likely.hpp:9
vari * vi_
Pointer to the implementation of this variable.
Definition: var.hpp:45
pow_vd_vari(vari *avi, double b)
Definition: pow.hpp:38
fvar< T > inv_square(const fvar< T > &x)
Definition: inv_square.hpp:12
double adj_
The adjoint of this variable, which is the partial derivative of this variable with respect to the root variable.
Definition: vari.hpp:44
fvar< T > pow(const fvar< T > &x1, const fvar< T > &x2)
Definition: pow.hpp:16
pow_vv_vari(vari *avi, vari *bvi)
Definition: pow.hpp:21
void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...
Definition: pow.hpp:23
fvar< T > inv(const fvar< T > &x)
Definition: inv.hpp:12

     [ Stan Home Page ] © 2011–2018, Stan Development Team.