Stan Math Library  2.20.0
reverse mode automatic differentiation
fma.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_SCAL_FUN_FMA_HPP
2 #define STAN_MATH_REV_SCAL_FUN_FMA_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
8 #include <limits>
9 
10 namespace stan {
11 namespace math {
12 
13 namespace internal {
14 class fma_vvv_vari : public op_vvv_vari {
15  public:
16  fma_vvv_vari(vari* avi, vari* bvi, vari* cvi)
17  : op_vvv_vari(fma(avi->val_, bvi->val_, cvi->val_), avi, bvi, cvi) {}
18  void chain() {
20  avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
21  bvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
22  cvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
23  } else {
24  avi_->adj_ += adj_ * bvi_->val_;
25  bvi_->adj_ += adj_ * avi_->val_;
26  cvi_->adj_ += adj_;
27  }
28  }
29 };
30 
31 class fma_vvd_vari : public op_vvd_vari {
32  public:
33  fma_vvd_vari(vari* avi, vari* bvi, double c)
34  : op_vvd_vari(fma(avi->val_, bvi->val_, c), avi, bvi, c) {}
35  void chain() {
36  if (unlikely(is_any_nan(avi_->val_, bvi_->val_, cd_))) {
37  avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
38  bvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
39  } else {
40  avi_->adj_ += adj_ * bvi_->val_;
41  bvi_->adj_ += adj_ * avi_->val_;
42  }
43  }
44 };
45 
46 class fma_vdv_vari : public op_vdv_vari {
47  public:
48  fma_vdv_vari(vari* avi, double b, vari* cvi)
49  : op_vdv_vari(fma(avi->val_, b, cvi->val_), avi, b, cvi) {}
50  void chain() {
51  if (unlikely(is_any_nan(avi_->val_, cvi_->val_, bd_))) {
52  avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
53  cvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
54  } else {
55  avi_->adj_ += adj_ * bd_;
56  cvi_->adj_ += adj_;
57  }
58  }
59 };
60 
61 class fma_vdd_vari : public op_vdd_vari {
62  public:
63  fma_vdd_vari(vari* avi, double b, double c)
64  : op_vdd_vari(fma(avi->val_, b, c), avi, b, c) {}
65  void chain() {
66  if (unlikely(is_any_nan(avi_->val_, bd_, cd_)))
67  avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
68  else
69  avi_->adj_ += adj_ * bd_;
70  }
71 };
72 
73 class fma_ddv_vari : public op_ddv_vari {
74  public:
75  fma_ddv_vari(double a, double b, vari* cvi)
76  : op_ddv_vari(fma(a, b, cvi->val_), a, b, cvi) {}
77  void chain() {
78  if (unlikely(is_any_nan(cvi_->val_, ad_, bd_)))
79  cvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
80  else
81  cvi_->adj_ += adj_;
82  }
83 };
84 } // namespace internal
85 
104 inline var fma(const var& a, const var& b, const var& c) {
105  return var(new internal::fma_vvv_vari(a.vi_, b.vi_, c.vi_));
106 }
107 
124 inline var fma(const var& a, const var& b, double c) {
125  return var(new internal::fma_vvd_vari(a.vi_, b.vi_, c));
126 }
127 
144 inline var fma(const var& a, double b, const var& c) {
145  return var(new internal::fma_vdv_vari(a.vi_, b, c.vi_));
146 }
147 
166 inline var fma(const var& a, double b, double c) {
167  return var(new internal::fma_vdd_vari(a.vi_, b, c));
168 }
169 
184 inline var fma(double a, const var& b, double c) {
185  return var(new internal::fma_vdd_vari(b.vi_, a, c));
186 }
187 
202 inline var fma(double a, double b, const var& c) {
203  return var(new internal::fma_ddv_vari(a, b, c.vi_));
204 }
205 
222 inline var fma(double a, const var& b, const var& c) {
223  return var(new internal::fma_vdv_vari(b.vi_, a, c.vi_)); // a-b symmetry
224 }
225 
226 } // namespace math
227 } // namespace stan
228 #endif
fma_vvv_vari(vari *avi, vari *bvi, vari *cvi)
Definition: fma.hpp:16
void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...
Definition: fma.hpp:50
void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...
Definition: fma.hpp:35
fma_ddv_vari(double a, double b, vari *cvi)
Definition: fma.hpp:75
The variable implementation base class.
Definition: vari.hpp:30
fma_vvd_vari(vari *avi, vari *bvi, double c)
Definition: fma.hpp:33
bool is_any_nan(const T &x)
Returns true if the input is NaN and false otherwise.
Definition: is_any_nan.hpp:21
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:33
friend class var
Definition: vari.hpp:32
const double val_
The value of this variable.
Definition: vari.hpp:38
void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...
Definition: fma.hpp:77
void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...
Definition: fma.hpp:18
fma_vdv_vari(vari *avi, double b, vari *cvi)
Definition: fma.hpp:48
#define unlikely(x)
Definition: likely.hpp:9
fvar< typename stan::return_type< T1, T2, T3 >::type > fma(const fvar< T1 > &x1, const fvar< T2 > &x2, const fvar< T3 > &x3)
The fused multiply-add operation (C99).
Definition: fma.hpp:59
vari * vi_
Pointer to the implementation of this variable.
Definition: var.hpp:45
fma_vdd_vari(vari *avi, double b, double c)
Definition: fma.hpp:63
void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...
Definition: fma.hpp:65
double adj_
The adjoint of this variable, which is the partial derivative of this variable with respect to the root variable.
Definition: vari.hpp:44

     [ Stan Home Page ] © 2011–2018, Stan Development Team.