Stan Math Library  2.20.0
reverse mode automatic differentiation
trace_gen_quad_form.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_MAT_FUN_TRACE_GEN_QUAD_FORM_HPP
2 #define STAN_MATH_REV_MAT_FUN_TRACE_GEN_QUAD_FORM_HPP
3 
4 #include <stan/math/rev/meta.hpp>
6 #include <stan/math/rev/core.hpp>
13 #include <type_traits>
14 
15 namespace stan {
16 namespace math {
17 namespace internal {
18 template <typename Td, int Rd, int Cd, typename Ta, int Ra, int Ca, typename Tb,
19  int Rb, int Cb>
21  public:
22  trace_gen_quad_form_vari_alloc(const Eigen::Matrix<Td, Rd, Cd>& D,
23  const Eigen::Matrix<Ta, Ra, Ca>& A,
24  const Eigen::Matrix<Tb, Rb, Cb>& B)
25  : D_(D), A_(A), B_(B) {}
26 
27  double compute() {
29  }
30 
31  Eigen::Matrix<Td, Rd, Cd> D_;
32  Eigen::Matrix<Ta, Ra, Ca> A_;
33  Eigen::Matrix<Tb, Rb, Cb> B_;
34 };
35 
36 template <typename Td, int Rd, int Cd, typename Ta, int Ra, int Ca, typename Tb,
37  int Rb, int Cb>
39  protected:
40  static inline void computeAdjoints(double adj,
41  const Eigen::Matrix<double, Rd, Cd>& D,
42  const Eigen::Matrix<double, Ra, Ca>& A,
43  const Eigen::Matrix<double, Rb, Cb>& B,
44  Eigen::Matrix<var, Rd, Cd>* varD,
45  Eigen::Matrix<var, Ra, Ca>* varA,
46  Eigen::Matrix<var, Rb, Cb>* varB) {
47  Eigen::Matrix<double, Ca, Cb> AtB;
48  Eigen::Matrix<double, Ra, Cb> BD;
49  if (varB || varA)
50  BD.noalias() = B * D;
51  if (varB || varD)
52  AtB.noalias() = A.transpose() * B;
53 
54  if (varB) {
55  Eigen::Matrix<double, Rb, Cb> adjB(adj * (A * BD + AtB * D.transpose()));
56  for (int j = 0; j < B.cols(); j++)
57  for (int i = 0; i < B.rows(); i++)
58  (*varB)(i, j).vi_->adj_ += adjB(i, j);
59  }
60  if (varA) {
61  Eigen::Matrix<double, Ra, Ca> adjA(adj * (B * BD.transpose()));
62  for (int j = 0; j < A.cols(); j++)
63  for (int i = 0; i < A.rows(); i++)
64  (*varA)(i, j).vi_->adj_ += adjA(i, j);
65  }
66  if (varD) {
67  Eigen::Matrix<double, Rd, Cd> adjD(adj * (B.transpose() * AtB));
68  for (int j = 0; j < D.cols(); j++)
69  for (int i = 0; i < D.rows(); i++)
70  (*varD)(i, j).vi_->adj_ += adjD(i, j);
71  }
72  }
73 
74  public:
77  : vari(impl->compute()), impl_(impl) {}
78 
79  virtual void chain() {
80  computeAdjoints(adj_, value_of(impl_->D_), value_of(impl_->A_),
81  value_of(impl_->B_),
82  reinterpret_cast<Eigen::Matrix<var, Rd, Cd>*>(
83  std::is_same<Td, var>::value ? (&impl_->D_) : NULL),
84  reinterpret_cast<Eigen::Matrix<var, Ra, Ca>*>(
85  std::is_same<Ta, var>::value ? (&impl_->A_) : NULL),
86  reinterpret_cast<Eigen::Matrix<var, Rb, Cb>*>(
87  std::is_same<Tb, var>::value ? (&impl_->B_) : NULL));
88  }
89 
91 };
92 } // namespace internal
93 
94 template <typename Td, int Rd, int Cd, typename Ta, int Ra, int Ca, typename Tb,
95  int Rb, int Cb>
96 inline typename std::enable_if<std::is_same<Td, var>::value
97  || std::is_same<Ta, var>::value
98  || std::is_same<Tb, var>::value,
99  var>::type
100 trace_gen_quad_form(const Eigen::Matrix<Td, Rd, Cd>& D,
101  const Eigen::Matrix<Ta, Ra, Ca>& A,
102  const Eigen::Matrix<Tb, Rb, Cb>& B) {
103  check_square("trace_gen_quad_form", "A", A);
104  check_square("trace_gen_quad_form", "D", D);
105  check_multiplicable("trace_gen_quad_form", "A", A, "B", B);
106  check_multiplicable("trace_gen_quad_form", "B", B, "D", D);
107 
109  baseVari
110  = new internal::trace_gen_quad_form_vari_alloc<Td, Rd, Cd, Ta, Ra, Ca, Tb,
111  Rb, Cb>(D, A, B);
112 
113  return var(new internal::trace_gen_quad_form_vari<Td, Rd, Cd, Ta, Ra, Ca, Tb,
114  Rb, Cb>(baseVari));
115 }
116 
117 } // namespace math
118 } // namespace stan
119 #endif
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition: value_of.hpp:17
trace_gen_quad_form_vari(trace_gen_quad_form_vari_alloc< Td, Rd, Cd, Ta, Ra, Ca, Tb, Rb, Cb > *impl)
void check_square(const char *function, const char *name, const matrix_cl &y)
Check if the matrix_cl is square.
The variable implementation base class.
Definition: vari.hpp:30
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:33
fvar< T > trace_gen_quad_form(const Eigen::Matrix< fvar< T >, RD, CD > &D, const Eigen::Matrix< fvar< T >, RA, CA > &A, const Eigen::Matrix< fvar< T >, RB, CB > &B)
trace_gen_quad_form_vari_alloc(const Eigen::Matrix< Td, Rd, Cd > &D, const Eigen::Matrix< Ta, Ra, Ca > &A, const Eigen::Matrix< Tb, Rb, Cb > &B)
trace_gen_quad_form_vari_alloc< Td, Rd, Cd, Ta, Ra, Ca, Tb, Rb, Cb > * impl_
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
static void computeAdjoints(double adj, const Eigen::Matrix< double, Rd, Cd > &D, const Eigen::Matrix< double, Ra, Ca > &A, const Eigen::Matrix< double, Rb, Cb > &B, Eigen::Matrix< var, Rd, Cd > *varD, Eigen::Matrix< var, Ra, Ca > *varA, Eigen::Matrix< var, Rb, Cb > *varB)
A chainable_alloc is an object which is constructed and destructed normally but the memory lifespan is managed along with the arena allocator for the gradient calculation.
virtual void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...

     [ Stan Home Page ] © 2011–2018, Stan Development Team.