// Reverse-mode (var) implementation of quad_form, which computes B^T * A * B.
// NOTE(review): this text is an extraction of the header in which many lines
// (class heads, access specifiers, closing braces, branch conditions) are
// missing; the comments below describe only what is visible.
1 #ifndef STAN_MATH_REV_MAT_FUN_QUAD_FORM_HPP 2 #define STAN_MATH_REV_MAT_FUN_QUAD_FORM_HPP 12 #include <type_traits> 18 template <
typename Ta,
int Ra,
int Ca,
typename Tb,
int Rb,
int Cb>
// compute: evaluates Cd = B^T * A * B on plain double values, then stores
// each entry of the Cb x Cb result C_ as a newly allocated vari wrapped in a
// var.
21 inline void compute(
const Eigen::Matrix<double, Ra, Ca>& A,
22 const Eigen::Matrix<double, Rb, Cb>& B) {
23 Eigen::Matrix<double, Cb, Cb> Cd(B.transpose() * A * B);
24 for (
int j = 0; j <
C_.cols(); j++) {
25 for (
int i = 0; i <
C_.rows(); i++) {
// Two alternative assignments follow: a symmetrized entry
// 0.5 * (Cd(i, j) + Cd(j, i)) and the raw entry Cd(i, j). The condition
// selecting between them is not visible here -- presumably the
// constructor's `symmetric` flag; confirm against the full source.
// NOTE(review): the trailing `false` is presumably vari's "stacked" flag,
// keeping these entries off the chain stack because their adjoints are
// read back and propagated separately -- confirm against vari's ctor.
27 C_(i, j) =
var(
new vari(0.5 * (Cd(i, j) + Cd(j, i)),
false));
29 C_(i, j) =
var(
new vari(Cd(i, j),
false));
// Trailing parameters of a constructor. Its name, first parameter, and body
// are not visible here; presumably it belongs to the class holding the
// A_, B_ and C_ members. The `symmetric` flag defaults to false.
37 const Eigen::Matrix<Tb, Rb, Cb>& B,
38 bool symmetric =
false)
// Data members: the operands A_ and B_ keep their original scalar types
// (Ta and Tb, each either double or var), and C_ is the Cb x Cb matrix of
// var entries that compute() fills with B^T * A * B.
43 Eigen::Matrix<Ta, Ra, Ca>
A_;
44 Eigen::Matrix<Tb, Rb, Cb>
B_;
45 Eigen::Matrix<var, Cb, Cb>
C_;
49 template <
typename Ta,
int Ra,
int Ca,
typename Tb,
int Rb,
int Cb>
52 inline void chainA(Eigen::Matrix<double, Ra, Ca>& A,
53 const Eigen::Matrix<double, Rb, Cb>& Bd,
54 const Eigen::Matrix<double, Cb, Cb>& adjC) {}
55 inline void chainB(Eigen::Matrix<double, Rb, Cb>& B,
56 const Eigen::Matrix<double, Ra, Ca>& Ad,
57 const Eigen::Matrix<double, Rb, Cb>& Bd,
58 const Eigen::Matrix<double, Cb, Cb>& adjC) {}
60 inline void chainA(Eigen::Matrix<var, Ra, Ca>& A,
61 const Eigen::Matrix<double, Rb, Cb>& Bd,
62 const Eigen::Matrix<double, Cb, Cb>& adjC) {
63 Eigen::Matrix<double, Ra, Ca> adjA(Bd * adjC * Bd.transpose());
64 for (
int j = 0; j < A.cols(); j++) {
65 for (
int i = 0; i < A.rows(); i++) {
66 A(i, j).vi_->adj_ += adjA(i, j);
70 inline void chainB(Eigen::Matrix<var, Rb, Cb>& B,
71 const Eigen::Matrix<double, Ra, Ca>& Ad,
72 const Eigen::Matrix<double, Rb, Cb>& Bd,
73 const Eigen::Matrix<double, Cb, Cb>& adjC) {
74 Eigen::Matrix<double, Ra, Ca> adjB(Ad * Bd * adjC.transpose()
75 + Ad.transpose() * Bd * adjC);
76 for (
int j = 0; j < B.cols(); j++)
77 for (
int i = 0; i < B.rows(); i++)
78 B(i, j).vi_->adj_ += adjB(i, j);
81 inline void chainAB(Eigen::Matrix<Ta, Ra, Ca>& A,
82 Eigen::Matrix<Tb, Rb, Cb>& B,
83 const Eigen::Matrix<double, Ra, Ca>& Ad,
84 const Eigen::Matrix<double, Rb, Cb>& Bd,
85 const Eigen::Matrix<double, Cb, Cb>& adjC) {
87 chainB(B, Ad, Bd, adjC);
// Trailing parameters of a second constructor. Presumably it belongs to the
// class that defines chainA/chainB/chainAB; its name, first parameter, and
// body are not visible here. It takes B plus an optional `symmetric` flag
// that defaults to false.
92 const Eigen::Matrix<Tb, Rb, Cb>& B,
bool symmetric =
false)
// Gradient propagation step (the enclosing function's signature is not
// visible): copies the adjoint of every entry of the result impl_->C_ into
// the double matrix adjC, column by column. adjC has the Cb x Cb shape that
// chainA, chainB and chainAB take as `adjC`. The call that consumes it is not
// visible here; presumably it is chainAB -- confirm against the full source.
98 Eigen::Matrix<double, Cb, Cb> adjC(impl_->C_.rows(), impl_->C_.cols());
100 for (
int j = 0; j < impl_->C_.cols(); j++)
101 for (
int i = 0; i < impl_->C_.rows(); i++)
102 adjC(i, j) = impl_->C_(i, j).vi_->adj_;
// Matrix overload of quad_form for reverse-mode autodiff. It participates in
// overload resolution only when A or B has scalar type var (std::enable_if).
// It returns the Cb x Cb var matrix C_ held by the vari's impl_. The function
// name, the A parameter, argument checks, and vari construction are missing
// from this listing. Shape checks (check_square / check_multiplicable) are
// presumably applied there -- confirm against the full source.
112 template <
typename Ta,
int Ra,
int Ca,
typename Tb,
int Rb,
int Cb>
113 inline typename std::enable_if<std::is_same<Ta, var>::value
114 || std::is_same<Tb, var>::value,
115 Eigen::Matrix<var, Cb, Cb> >::type
117 const Eigen::Matrix<Tb, Rb, Cb>& B) {
124 return baseVari->
impl_->C_;
// Vector overload of quad_form for reverse-mode autodiff. B is an Rb x 1
// column vector, so B^T * A * B is 1 x 1, and the function returns its single
// entry C_(0, 0) as a var. It is enabled only when A or B has scalar type
// var. The function name line, the A parameter, argument checks, and vari
// construction are not visible in this listing.
127 template <
typename Ta,
int Ra,
int Ca,
typename Tb,
int Rb>
128 inline typename std::enable_if<
129 std::is_same<Ta, var>::value || std::is_same<Tb, var>::value,
var>::type
131 const Eigen::Matrix<Tb, Rb, 1>& B) {
138 return baseVari->
impl_->C_(0, 0);
T value_of(const fvar< T > &v)
Return the value of the specified variable.
void check_square(const char *function, const char *name, const matrix_cl &y)
Check if the matrix_cl is square.
The variable implementation base class.
Independent (input) and dependent (output) variables for gradients.
int cols(const Eigen::Matrix< T, R, C > &m)
Return the number of columns in the specified matrix, vector, or row vector.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
A chainable_alloc is an object which is constructed and destructed normally but the memory lifespan is managed along with the arena allocator for the gradient calculation.
Eigen::Matrix< T, CB, CB > quad_form(const Eigen::Matrix< T, RA, CA > &A, const Eigen::Matrix< T, RB, CB > &B)
Compute B^T A B.