1 #ifndef STAN_MATH_REV_MAT_FUN_TRACE_GEN_QUAD_FORM_HPP 2 #define STAN_MATH_REV_MAT_FUN_TRACE_GEN_QUAD_FORM_HPP 13 #include <type_traits> 18 template <
typename Td,
int Rd,
int Cd,
typename Ta,
int Ra,
int Ca,
typename Tb,
23 const Eigen::Matrix<Ta, Ra, Ca>& A,
24 const Eigen::Matrix<Tb, Rb, Cb>& B)
31 Eigen::Matrix<Td, Rd, Cd>
D_;
32 Eigen::Matrix<Ta, Ra, Ca>
A_;
33 Eigen::Matrix<Tb, Rb, Cb>
B_;
36 template <
typename Td,
int Rd,
int Cd,
typename Ta,
int Ra,
int Ca,
typename Tb,
41 const Eigen::Matrix<double, Rd, Cd>& D,
42 const Eigen::Matrix<double, Ra, Ca>& A,
43 const Eigen::Matrix<double, Rb, Cb>& B,
44 Eigen::Matrix<var, Rd, Cd>* varD,
45 Eigen::Matrix<var, Ra, Ca>* varA,
46 Eigen::Matrix<var, Rb, Cb>* varB) {
47 Eigen::Matrix<double, Ca, Cb> AtB;
48 Eigen::Matrix<double, Ra, Cb> BD;
52 AtB.noalias() = A.transpose() * B;
55 Eigen::Matrix<double, Rb, Cb> adjB(adj * (A * BD + AtB * D.transpose()));
56 for (
int j = 0; j < B.cols(); j++)
57 for (
int i = 0; i < B.rows(); i++)
58 (*varB)(i, j).vi_->adj_ += adjB(i, j);
61 Eigen::Matrix<double, Ra, Ca> adjA(adj * (B * BD.transpose()));
62 for (
int j = 0; j < A.cols(); j++)
63 for (
int i = 0; i < A.rows(); i++)
64 (*varA)(i, j).vi_->adj_ += adjA(i, j);
67 Eigen::Matrix<double, Rd, Cd> adjD(adj * (B.transpose() * AtB));
68 for (
int j = 0; j < D.cols(); j++)
69 for (
int i = 0; i < D.rows(); i++)
70 (*varD)(i, j).vi_->adj_ += adjD(i, j);
82 reinterpret_cast<Eigen::Matrix<var, Rd, Cd>*>(
83 std::is_same<Td, var>::value ? (&impl_->D_) : NULL),
84 reinterpret_cast<Eigen::Matrix<var, Ra, Ca>*>(
85 std::is_same<Ta, var>::value ? (&impl_->A_) : NULL),
86 reinterpret_cast<Eigen::Matrix<var, Rb, Cb>*>(
87 std::is_same<Tb, var>::value ? (&impl_->B_) : NULL));
94 template <
typename Td,
int Rd,
int Cd,
typename Ta,
int Ra,
int Ca,
typename Tb,
96 inline typename std::enable_if<std::is_same<Td, var>::value
97 || std::is_same<Ta, var>::value
98 || std::is_same<Tb, var>::value,
101 const Eigen::Matrix<Ta, Ra, Ca>& A,
102 const Eigen::Matrix<Tb, Rb, Cb>& B) {
T value_of(const fvar< T > &v)
Return the value of the specified variable.
void check_square(const char *function, const char *name, const matrix_cl &y)
Check if the matrix_cl is square.
The variable implementation base class.
Independent (input) and dependent (output) variables for gradients.
fvar< T > trace_gen_quad_form(const Eigen::Matrix< fvar< T >, RD, CD > &D, const Eigen::Matrix< fvar< T >, RA, CA > &A, const Eigen::Matrix< fvar< T >, RB, CB > &B)
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
A chainable_alloc is an object which is constructed and destructed normally but the memory lifespan is managed along with the arena allocator for the gradient calculation.