Stan Math Library  2.20.0
reverse mode automatic differentiation
trace_quad_form.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_MAT_FUN_TRACE_QUAD_FORM_HPP
2 #define STAN_MATH_REV_MAT_FUN_TRACE_QUAD_FORM_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
13 #include <type_traits>
14 
15 namespace stan {
16 namespace math {
17 namespace internal {
18 template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
20  public:
21  trace_quad_form_vari_alloc(const Eigen::Matrix<Ta, Ra, Ca>& A,
22  const Eigen::Matrix<Tb, Rb, Cb>& B)
23  : A_(A), B_(B) {}
24 
25  double compute() { return trace_quad_form(value_of(A_), value_of(B_)); }
26 
27  Eigen::Matrix<Ta, Ra, Ca> A_;
28  Eigen::Matrix<Tb, Rb, Cb> B_;
29 };
30 
31 template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
32 class trace_quad_form_vari : public vari {
33  protected:
34  static inline void chainA(Eigen::Matrix<double, Ra, Ca>& A,
35  const Eigen::Matrix<double, Rb, Cb>& Bd,
36  double adjC) {}
37  static inline void chainB(Eigen::Matrix<double, Rb, Cb>& B,
38  const Eigen::Matrix<double, Ra, Ca>& Ad,
39  const Eigen::Matrix<double, Rb, Cb>& Bd,
40  double adjC) {}
41 
42  static inline void chainA(Eigen::Matrix<var, Ra, Ca>& A,
43  const Eigen::Matrix<double, Rb, Cb>& Bd,
44  double adjC) {
45  Eigen::Matrix<double, Ra, Ca> adjA(adjC * Bd * Bd.transpose());
46  for (int j = 0; j < A.cols(); j++)
47  for (int i = 0; i < A.rows(); i++)
48  A(i, j).vi_->adj_ += adjA(i, j);
49  }
50  static inline void chainB(Eigen::Matrix<var, Rb, Cb>& B,
51  const Eigen::Matrix<double, Ra, Ca>& Ad,
52  const Eigen::Matrix<double, Rb, Cb>& Bd,
53  double adjC) {
54  Eigen::Matrix<double, Ra, Ca> adjB(adjC * (Ad + Ad.transpose()) * Bd);
55  for (int j = 0; j < B.cols(); j++)
56  for (int i = 0; i < B.rows(); i++)
57  B(i, j).vi_->adj_ += adjB(i, j);
58  }
59 
60  inline void chainAB(Eigen::Matrix<Ta, Ra, Ca>& A,
61  Eigen::Matrix<Tb, Rb, Cb>& B,
62  const Eigen::Matrix<double, Ra, Ca>& Ad,
63  const Eigen::Matrix<double, Rb, Cb>& Bd, double adjC) {
64  chainA(A, Bd, adjC);
65  chainB(B, Ad, Bd, adjC);
66  }
67 
68  public:
71  : vari(impl->compute()), impl_(impl) {}
72 
73  virtual void chain() {
74  chainAB(impl_->A_, impl_->B_, value_of(impl_->A_), value_of(impl_->B_),
75  adj_);
76  }
77 
79 };
80 } // namespace internal
81 
82 template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
83 inline typename std::enable_if<
84  std::is_same<Ta, var>::value || std::is_same<Tb, var>::value, var>::type
85 trace_quad_form(const Eigen::Matrix<Ta, Ra, Ca>& A,
86  const Eigen::Matrix<Tb, Rb, Cb>& B) {
87  check_square("trace_quad_form", "A", A);
88  check_multiplicable("trace_quad_form", "A", A, "B", B);
89 
92 
93  return var(
95 }
96 
97 } // namespace math
98 } // namespace stan
99 #endif
trace_quad_form_vari_alloc(const Eigen::Matrix< Ta, Ra, Ca > &A, const Eigen::Matrix< Tb, Rb, Cb > &B)
fvar< T > trace_quad_form(const Eigen::Matrix< fvar< T >, RA, CA > &A, const Eigen::Matrix< fvar< T >, RB, CB > &B)
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition: value_of.hpp:17
void check_square(const char *function, const char *name, const matrix_cl &y)
Check if the matrix_cl is square.
The variable implementation base class.
Definition: vari.hpp:30
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:33
void chainAB(Eigen::Matrix< Ta, Ra, Ca > &A, Eigen::Matrix< Tb, Rb, Cb > &B, const Eigen::Matrix< double, Ra, Ca > &Ad, const Eigen::Matrix< double, Rb, Cb > &Bd, double adjC)
virtual void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
trace_quad_form_vari_alloc< Ta, Ra, Ca, Tb, Rb, Cb > * impl_
static void chainA(Eigen::Matrix< double, Ra, Ca > &A, const Eigen::Matrix< double, Rb, Cb > &Bd, double adjC)
static void chainB(Eigen::Matrix< double, Rb, Cb > &B, const Eigen::Matrix< double, Ra, Ca > &Ad, const Eigen::Matrix< double, Rb, Cb > &Bd, double adjC)
static void chainB(Eigen::Matrix< var, Rb, Cb > &B, const Eigen::Matrix< double, Ra, Ca > &Ad, const Eigen::Matrix< double, Rb, Cb > &Bd, double adjC)
A chainable_alloc is an object which is constructed and destructed normally but the memory lifespan is managed along with the arena allocator for the gradient calculation.
static void chainA(Eigen::Matrix< var, Ra, Ca > &A, const Eigen::Matrix< double, Rb, Cb > &Bd, double adjC)
trace_quad_form_vari(trace_quad_form_vari_alloc< Ta, Ra, Ca, Tb, Rb, Cb > *impl)

     [ Stan Home Page ] © 2011–2018, Stan Development Team.