Stan Math Library  2.20.0
reverse mode automatic differentiation
quad_form.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_MAT_FUN_QUAD_FORM_HPP
2 #define STAN_MATH_REV_MAT_FUN_QUAD_FORM_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
12 #include <type_traits>
13 
14 namespace stan {
15 namespace math {
16 
17 namespace internal {
18 template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
20  private:
21  inline void compute(const Eigen::Matrix<double, Ra, Ca>& A,
22  const Eigen::Matrix<double, Rb, Cb>& B) {
23  Eigen::Matrix<double, Cb, Cb> Cd(B.transpose() * A * B);
24  for (int j = 0; j < C_.cols(); j++) {
25  for (int i = 0; i < C_.rows(); i++) {
26  if (sym_) {
27  C_(i, j) = var(new vari(0.5 * (Cd(i, j) + Cd(j, i)), false));
28  } else {
29  C_(i, j) = var(new vari(Cd(i, j), false));
30  }
31  }
32  }
33  }
34 
35  public:
36  quad_form_vari_alloc(const Eigen::Matrix<Ta, Ra, Ca>& A,
37  const Eigen::Matrix<Tb, Rb, Cb>& B,
38  bool symmetric = false)
39  : A_(A), B_(B), C_(B_.cols(), B_.cols()), sym_(symmetric) {
40  compute(value_of(A), value_of(B));
41  }
42 
43  Eigen::Matrix<Ta, Ra, Ca> A_;
44  Eigen::Matrix<Tb, Rb, Cb> B_;
45  Eigen::Matrix<var, Cb, Cb> C_;
46  bool sym_;
47 };
48 
49 template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
50 class quad_form_vari : public vari {
51  protected:
52  inline void chainA(Eigen::Matrix<double, Ra, Ca>& A,
53  const Eigen::Matrix<double, Rb, Cb>& Bd,
54  const Eigen::Matrix<double, Cb, Cb>& adjC) {}
55  inline void chainB(Eigen::Matrix<double, Rb, Cb>& B,
56  const Eigen::Matrix<double, Ra, Ca>& Ad,
57  const Eigen::Matrix<double, Rb, Cb>& Bd,
58  const Eigen::Matrix<double, Cb, Cb>& adjC) {}
59 
60  inline void chainA(Eigen::Matrix<var, Ra, Ca>& A,
61  const Eigen::Matrix<double, Rb, Cb>& Bd,
62  const Eigen::Matrix<double, Cb, Cb>& adjC) {
63  Eigen::Matrix<double, Ra, Ca> adjA(Bd * adjC * Bd.transpose());
64  for (int j = 0; j < A.cols(); j++) {
65  for (int i = 0; i < A.rows(); i++) {
66  A(i, j).vi_->adj_ += adjA(i, j);
67  }
68  }
69  }
70  inline void chainB(Eigen::Matrix<var, Rb, Cb>& B,
71  const Eigen::Matrix<double, Ra, Ca>& Ad,
72  const Eigen::Matrix<double, Rb, Cb>& Bd,
73  const Eigen::Matrix<double, Cb, Cb>& adjC) {
74  Eigen::Matrix<double, Ra, Ca> adjB(Ad * Bd * adjC.transpose()
75  + Ad.transpose() * Bd * adjC);
76  for (int j = 0; j < B.cols(); j++)
77  for (int i = 0; i < B.rows(); i++)
78  B(i, j).vi_->adj_ += adjB(i, j);
79  }
80 
81  inline void chainAB(Eigen::Matrix<Ta, Ra, Ca>& A,
82  Eigen::Matrix<Tb, Rb, Cb>& B,
83  const Eigen::Matrix<double, Ra, Ca>& Ad,
84  const Eigen::Matrix<double, Rb, Cb>& Bd,
85  const Eigen::Matrix<double, Cb, Cb>& adjC) {
86  chainA(A, Bd, adjC);
87  chainB(B, Ad, Bd, adjC);
88  }
89 
90  public:
91  quad_form_vari(const Eigen::Matrix<Ta, Ra, Ca>& A,
92  const Eigen::Matrix<Tb, Rb, Cb>& B, bool symmetric = false)
93  : vari(0.0) {
94  impl_ = new quad_form_vari_alloc<Ta, Ra, Ca, Tb, Rb, Cb>(A, B, symmetric);
95  }
96 
97  virtual void chain() {
98  Eigen::Matrix<double, Cb, Cb> adjC(impl_->C_.rows(), impl_->C_.cols());
99 
100  for (int j = 0; j < impl_->C_.cols(); j++)
101  for (int i = 0; i < impl_->C_.rows(); i++)
102  adjC(i, j) = impl_->C_(i, j).vi_->adj_;
103 
104  chainAB(impl_->A_, impl_->B_, value_of(impl_->A_), value_of(impl_->B_),
105  adjC);
106  }
107 
109 };
110 } // namespace internal
111 
112 template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
113 inline typename std::enable_if<std::is_same<Ta, var>::value
114  || std::is_same<Tb, var>::value,
115  Eigen::Matrix<var, Cb, Cb> >::type
116 quad_form(const Eigen::Matrix<Ta, Ra, Ca>& A,
117  const Eigen::Matrix<Tb, Rb, Cb>& B) {
118  check_square("quad_form", "A", A);
119  check_multiplicable("quad_form", "A", A, "B", B);
120 
123 
124  return baseVari->impl_->C_;
125 }
126 
127 template <typename Ta, int Ra, int Ca, typename Tb, int Rb>
128 inline typename std::enable_if<
129  std::is_same<Ta, var>::value || std::is_same<Tb, var>::value, var>::type
130 quad_form(const Eigen::Matrix<Ta, Ra, Ca>& A,
131  const Eigen::Matrix<Tb, Rb, 1>& B) {
132  check_square("quad_form", "A", A);
133  check_multiplicable("quad_form", "A", A, "B", B);
134 
137 
138  return baseVari->impl_->C_(0, 0);
139 }
140 
141 } // namespace math
142 } // namespace stan
143 #endif
void chainB(Eigen::Matrix< var, Rb, Cb > &B, const Eigen::Matrix< double, Ra, Ca > &Ad, const Eigen::Matrix< double, Rb, Cb > &Bd, const Eigen::Matrix< double, Cb, Cb > &adjC)
Definition: quad_form.hpp:70
void chainA(Eigen::Matrix< double, Ra, Ca > &A, const Eigen::Matrix< double, Rb, Cb > &Bd, const Eigen::Matrix< double, Cb, Cb > &adjC)
Definition: quad_form.hpp:52
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition: value_of.hpp:17
quad_form_vari_alloc(const Eigen::Matrix< Ta, Ra, Ca > &A, const Eigen::Matrix< Tb, Rb, Cb > &B, bool symmetric=false)
Definition: quad_form.hpp:36
void chainB(Eigen::Matrix< double, Rb, Cb > &B, const Eigen::Matrix< double, Ra, Ca > &Ad, const Eigen::Matrix< double, Rb, Cb > &Bd, const Eigen::Matrix< double, Cb, Cb > &adjC)
Definition: quad_form.hpp:55
void check_square(const char *function, const char *name, const matrix_cl &y)
Check if the matrix_cl is square.
The variable implementation base class.
Definition: vari.hpp:30
Eigen::Matrix< Ta, Ra, Ca > A_
Definition: quad_form.hpp:43
quad_form_vari_alloc< Ta, Ra, Ca, Tb, Rb, Cb > * impl_
Definition: quad_form.hpp:108
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:33
Eigen::Matrix< Tb, Rb, Cb > B_
Definition: quad_form.hpp:44
int cols(const Eigen::Matrix< T, R, C > &m)
Return the number of columns in the specified matrix, vector, or row vector.
Definition: cols.hpp:20
virtual void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...
Definition: quad_form.hpp:97
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
void chainA(Eigen::Matrix< var, Ra, Ca > &A, const Eigen::Matrix< double, Rb, Cb > &Bd, const Eigen::Matrix< double, Cb, Cb > &adjC)
Definition: quad_form.hpp:60
quad_form_vari(const Eigen::Matrix< Ta, Ra, Ca > &A, const Eigen::Matrix< Tb, Rb, Cb > &B, bool symmetric=false)
Definition: quad_form.hpp:91
void chainAB(Eigen::Matrix< Ta, Ra, Ca > &A, Eigen::Matrix< Tb, Rb, Cb > &B, const Eigen::Matrix< double, Ra, Ca > &Ad, const Eigen::Matrix< double, Rb, Cb > &Bd, const Eigen::Matrix< double, Cb, Cb > &adjC)
Definition: quad_form.hpp:81
A chainable_alloc is an object which is constructed and destructed normally but the memory lifespan i...
Eigen::Matrix< var, Cb, Cb > C_
Definition: quad_form.hpp:45
Eigen::Matrix< T, CB, CB > quad_form(const Eigen::Matrix< T, RA, CA > &A, const Eigen::Matrix< T, RB, CB > &B)
Compute B^T A B.
Definition: quad_form.hpp:14

     [ Stan Home Page ] © 2011–2018, Stan Development Team.