Stan Math Library  2.20.0
reverse mode automatic differentiation
trace_inv_quad_form_ldlt.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_MAT_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP
2 #define STAN_MATH_REV_MAT_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP
3 
4 #include <stan/math/rev/meta.hpp>
6 #include <stan/math/rev/core.hpp>
10 #include <type_traits>
11 
12 namespace stan {
13 namespace math {
14 
15 namespace internal {
16 template <typename T2, int R2, int C2, typename T3, int R3, int C3>
18  protected:
19  inline void initializeB(const Eigen::Matrix<var, R3, C3> &B, bool haveD) {
20  Eigen::Matrix<double, R3, C3> Bd(B.rows(), B.cols());
21  variB_.resize(B.rows(), B.cols());
22  for (int j = 0; j < B.cols(); j++) {
23  for (int i = 0; i < B.rows(); i++) {
24  variB_(i, j) = B(i, j).vi_;
25  Bd(i, j) = B(i, j).val();
26  }
27  }
28  AinvB_ = ldlt_.solve(Bd);
29  if (haveD)
30  C_.noalias() = Bd.transpose() * AinvB_;
31  else
32  value_ = (Bd.transpose() * AinvB_).trace();
33  }
34  inline void initializeB(const Eigen::Matrix<double, R3, C3> &B, bool haveD) {
35  AinvB_ = ldlt_.solve(B);
36  if (haveD)
37  C_.noalias() = B.transpose() * AinvB_;
38  else
39  value_ = (B.transpose() * AinvB_).trace();
40  }
41 
42  template <int R1, int C1>
43  inline void initializeD(const Eigen::Matrix<var, R1, C1> &D) {
44  D_.resize(D.rows(), D.cols());
45  variD_.resize(D.rows(), D.cols());
46  for (int j = 0; j < D.cols(); j++) {
47  for (int i = 0; i < D.rows(); i++) {
48  variD_(i, j) = D(i, j).vi_;
49  D_(i, j) = D(i, j).val();
50  }
51  }
52  }
53  template <int R1, int C1>
54  inline void initializeD(const Eigen::Matrix<double, R1, C1> &D) {
55  D_ = D;
56  }
57 
58  public:
59  template <typename T1, int R1, int C1>
60  trace_inv_quad_form_ldlt_impl(const Eigen::Matrix<T1, R1, C1> &D,
61  const LDLT_factor<T2, R2, C2> &A,
62  const Eigen::Matrix<T3, R3, C3> &B)
63  : Dtype_(stan::is_var<T1>::value), ldlt_(A) {
64  initializeB(B, true);
65  initializeD(D);
66 
67  value_ = (D_ * C_).trace();
68  }
69 
71  const Eigen::Matrix<T3, R3, C3> &B)
72  : Dtype_(2), ldlt_(A) {
73  initializeB(B, false);
74  }
75 
76  const int Dtype_; // 0 = double, 1 = var, 2 = missing
78  Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> D_;
79  Eigen::Matrix<vari *, Eigen::Dynamic, Eigen::Dynamic> variD_;
80  Eigen::Matrix<vari *, R3, C3> variB_;
81  Eigen::Matrix<double, R3, C3> AinvB_;
82  Eigen::Matrix<double, C3, C3> C_;
83  double value_;
84 };
85 
86 template <typename T2, int R2, int C2, typename T3, int R3, int C3>
88  protected:
89  static inline void chainA(
90  double adj,
92  static inline void chainB(
93  double adj,
95 
96  static inline void chainA(
97  double adj,
99  Eigen::Matrix<double, R2, C2> aA;
100 
101  if (impl->Dtype_ != 2)
102  aA.noalias()
103  = -adj
104  * (impl->AinvB_ * impl->D_.transpose() * impl->AinvB_.transpose());
105  else
106  aA.noalias() = -adj * (impl->AinvB_ * impl->AinvB_.transpose());
107 
108  for (int j = 0; j < aA.cols(); j++)
109  for (int i = 0; i < aA.rows(); i++)
110  impl->ldlt_.alloc_->variA_(i, j)->adj_ += aA(i, j);
111  }
112  static inline void chainB(
113  double adj,
115  Eigen::Matrix<double, R3, C3> aB;
116 
117  if (impl->Dtype_ != 2)
118  aB.noalias() = adj * impl->AinvB_ * (impl->D_ + impl->D_.transpose());
119  else
120  aB.noalias() = 2 * adj * impl->AinvB_;
121 
122  for (int j = 0; j < aB.cols(); j++)
123  for (int i = 0; i < aB.rows(); i++)
124  impl->variB_(i, j)->adj_ += aB(i, j);
125  }
126 
127  public:
130  : vari(impl->value_), impl_(impl) {}
131 
132  virtual void chain() {
133  // F = trace(D * B' * inv(A) * B)
134  // aA = -aF * inv(A') * B * D' * B' * inv(A')
135  // aB = aF*(inv(A) * B * D + inv(A') * B * D')
136  // aD = aF*(B' * inv(A) * B)
137  chainA(adj_, impl_);
138 
139  chainB(adj_, impl_);
140 
141  if (impl_->Dtype_ == 1) {
142  for (int j = 0; j < impl_->variD_.cols(); j++)
143  for (int i = 0; i < impl_->variD_.rows(); i++)
144  impl_->variD_(i, j)->adj_ += adj_ * impl_->C_(i, j);
145  }
146  }
147 
149 };
150 
151 } // namespace internal
152 
158 template <typename T2, int R2, int C2, typename T3, int R3, int C3>
159 inline
160  typename std::enable_if<stan::is_var<T2>::value || stan::is_var<T3>::value,
161  var>::type
163  const Eigen::Matrix<T3, R3, C3> &B) {
164  check_multiplicable("trace_inv_quad_form_ldlt", "A", A, "B", B);
165 
168  B);
169 
170  return var(
172  impl_));
173 }
174 
175 } // namespace math
176 } // namespace stan
177 #endif
void initializeB(const Eigen::Matrix< double, R3, C3 > &B, bool haveD)
Defines a public enum named value which is defined to be false as the primitive scalar types cannot be a stan::math::var type.
Definition: is_var.hpp:10
The variable implementation base class.
Definition: vari.hpp:30
const Eigen::internal::solve_retval< ldlt_t, Rhs > solve(const Eigen::MatrixBase< Rhs > &b) const
static void chainA(double adj, trace_inv_quad_form_ldlt_impl< var, R2, C2, T3, R3, C3 > *impl)
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:33
static void chainB(double adj, trace_inv_quad_form_ldlt_impl< T2, R2, C2, var, R3, C3 > *impl)
static void chainB(double adj, trace_inv_quad_form_ldlt_impl< T2, R2, C2, double, R3, C3 > *impl)
trace_inv_quad_form_ldlt_impl(const Eigen::Matrix< T1, R1, C1 > &D, const LDLT_factor< T2, R2, C2 > &A, const Eigen::Matrix< T3, R3, C3 > &B)
void initializeD(const Eigen::Matrix< var, R1, C1 > &D)
std::enable_if< !stan::is_var< T1 >::value &&!stan::is_var< T2 >::value, typename boost::math::tools::promote_args< T1, T2 >::type >::type trace_inv_quad_form_ldlt(const LDLT_factor< T1, R2, C2 > &A, const Eigen::Matrix< T2, R3, C3 > &B)
trace_inv_quad_form_ldlt_impl< T2, R2, C2, T3, R3, C3 > * impl_
Eigen::Matrix< vari *, Eigen::Dynamic, Eigen::Dynamic > variD_
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
trace_inv_quad_form_ldlt_vari(trace_inv_quad_form_ldlt_impl< T2, R2, C2, T3, R3, C3 > *impl)
virtual void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic > D_
void initializeB(const Eigen::Matrix< var, R3, C3 > &B, bool haveD)
T trace(const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > &m)
Returns the trace of the specified matrix.
Definition: trace.hpp:19
A chainable_alloc is an object which is constructed and destructed normally but the memory lifespan is managed along with the arena allocator for the gradient calculation.
trace_inv_quad_form_ldlt_impl(const LDLT_factor< T2, R2, C2 > &A, const Eigen::Matrix< T3, R3, C3 > &B)
static void chainA(double adj, trace_inv_quad_form_ldlt_impl< double, R2, C2, T3, R3, C3 > *impl)
void initializeD(const Eigen::Matrix< double, R1, C1 > &D)

     [ Stan Home Page ] © 2011–2018, Stan Development Team.