Stan Math Library  2.20.0
reverse mode automatic differentiation
Eigen_NumTraits.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_MAT_FUN_EIGEN_NUMTRAITS_HPP
2 #define STAN_MATH_REV_MAT_FUN_EIGEN_NUMTRAITS_HPP
3 
4 #include <stan/math/rev/meta.hpp>
6 #include <stan/math/rev/core.hpp>
8 #include <limits>
9 
10 namespace Eigen {
11 
19 template <>
20 struct NumTraits<stan::math::var> : GenericNumTraits<stan::math::var> {
24 
31  static inline stan::math::var dummy_precision() {
32  return NumTraits<double>::dummy_precision();
33  }
34 
35  enum {
39  IsComplex = 0,
40 
44  IsInteger = 0,
45 
49  IsSigned = 1,
50 
54  RequireInitialization = 0,
55 
59  ReadCost = 2 * NumTraits<double>::ReadCost,
60 
65  AddCost = NumTraits<double>::AddCost,
66 
71  MulCost = NumTraits<double>::MulCost
72  };
73 
79  static int digits10() { return std::numeric_limits<double>::digits10; }
80 };
81 
82 namespace internal {
87 template <>
88 struct remove_all<stan::math::vari*> {
90 };
91 
92 #if EIGEN_VERSION_AT_LEAST(3, 3, 0)
93 
97 template <>
98 struct scalar_product_traits<stan::math::var, double> {
99  typedef stan::math::var ReturnType;
100 };
101 
106 template <>
107 struct scalar_product_traits<double, stan::math::var> {
108  typedef stan::math::var ReturnType;
109 };
110 
122 template <typename Index, typename LhsMapper, bool ConjugateLhs,
123  bool ConjugateRhs, typename RhsMapper, int Version>
124 struct general_matrix_vector_product<Index, stan::math::var, LhsMapper,
125  ColMajor, ConjugateLhs, stan::math::var,
126  RhsMapper, ConjugateRhs, Version> {
127  typedef stan::math::var LhsScalar;
128  typedef stan::math::var RhsScalar;
129  typedef stan::math::var ResScalar;
130  enum { LhsStorageOrder = ColMajor };
131 
132  EIGEN_DONT_INLINE static void run(Index rows, Index cols,
133  const LhsMapper& lhsMapper,
134  const RhsMapper& rhsMapper, ResScalar* res,
135  Index resIncr, const ResScalar& alpha) {
136  const LhsScalar* lhs = lhsMapper.data();
137  const Index lhsStride = lhsMapper.stride();
138  const RhsScalar* rhs = rhsMapper.data();
139  const Index rhsIncr = rhsMapper.stride();
140  run(rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha);
141  }
142 
143  EIGEN_DONT_INLINE static void run(Index rows, Index cols,
144  const LhsScalar* lhs, Index lhsStride,
145  const RhsScalar* rhs, Index rhsIncr,
146  ResScalar* res, Index resIncr,
147  const ResScalar& alpha) {
149  using stan::math::var;
150  for (Index i = 0; i < rows; ++i) {
151  res[i * resIncr] += var(
152  new gevv_vvv_vari(&alpha, &lhs[i], lhsStride, rhs, rhsIncr, cols));
153  }
154  }
155 };
156 
157 template <typename Index, typename LhsMapper, bool ConjugateLhs,
158  bool ConjugateRhs, typename RhsMapper, int Version>
159 struct general_matrix_vector_product<Index, stan::math::var, LhsMapper,
160  RowMajor, ConjugateLhs, stan::math::var,
161  RhsMapper, ConjugateRhs, Version> {
162  typedef stan::math::var LhsScalar;
163  typedef stan::math::var RhsScalar;
164  typedef stan::math::var ResScalar;
165  enum { LhsStorageOrder = RowMajor };
166 
167  EIGEN_DONT_INLINE static void run(Index rows, Index cols,
168  const LhsMapper& lhsMapper,
169  const RhsMapper& rhsMapper, ResScalar* res,
170  Index resIncr, const RhsScalar& alpha) {
171  const LhsScalar* lhs = lhsMapper.data();
172  const Index lhsStride = lhsMapper.stride();
173  const RhsScalar* rhs = rhsMapper.data();
174  const Index rhsIncr = rhsMapper.stride();
175  run(rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha);
176  }
177 
178  EIGEN_DONT_INLINE static void run(Index rows, Index cols,
179  const LhsScalar* lhs, Index lhsStride,
180  const RhsScalar* rhs, Index rhsIncr,
181  ResScalar* res, Index resIncr,
182  const RhsScalar& alpha) {
183  for (Index i = 0; i < rows; i++) {
184  res[i * resIncr] += stan::math::var(new stan::math::gevv_vvv_vari(
185  &alpha,
186  (static_cast<int>(LhsStorageOrder) == static_cast<int>(ColMajor))
187  ? (&lhs[i])
188  : (&lhs[i * lhsStride]),
189  (static_cast<int>(LhsStorageOrder) == static_cast<int>(ColMajor))
190  ? (lhsStride)
191  : (1),
192  rhs, rhsIncr, cols));
193  }
194  }
195 };
196 
197 template <typename Index, int LhsStorageOrder, bool ConjugateLhs,
198  int RhsStorageOrder, bool ConjugateRhs>
199 struct general_matrix_matrix_product<Index, stan::math::var, LhsStorageOrder,
200  ConjugateLhs, stan::math::var,
201  RhsStorageOrder, ConjugateRhs, ColMajor> {
202  typedef stan::math::var LhsScalar;
203  typedef stan::math::var RhsScalar;
204  typedef stan::math::var ResScalar;
205 
206  typedef gebp_traits<RhsScalar, LhsScalar> Traits;
207 
208  typedef const_blas_data_mapper<stan::math::var, Index, LhsStorageOrder>
209  LhsMapper;
210  typedef const_blas_data_mapper<stan::math::var, Index, RhsStorageOrder>
211  RhsMapper;
212 
213  EIGEN_DONT_INLINE
214  static void run(Index rows, Index cols, Index depth, const LhsScalar* lhs,
215  Index lhsStride, const RhsScalar* rhs, Index rhsStride,
216  ResScalar* res, Index resStride, const ResScalar& alpha,
217  level3_blocking<LhsScalar, RhsScalar>& /* blocking */,
218  GemmParallelInfo<Index>* /* info = 0 */) {
219  for (Index i = 0; i < cols; i++) {
220  general_matrix_vector_product<
221  Index, LhsScalar, LhsMapper, LhsStorageOrder, ConjugateLhs, RhsScalar,
222  RhsMapper,
223  ConjugateRhs>::run(rows, depth, lhs, lhsStride,
224  &rhs[static_cast<int>(RhsStorageOrder)
225  == static_cast<int>(ColMajor)
226  ? i * rhsStride
227  : i],
228  static_cast<int>(RhsStorageOrder)
229  == static_cast<int>(ColMajor)
230  ? 1
231  : rhsStride,
232  &res[i * resStride], 1, alpha);
233  }
234  }
235 
236  EIGEN_DONT_INLINE
237  static void run(Index rows, Index cols, Index depth,
238  const LhsMapper& lhsMapper, const RhsMapper& rhsMapper,
239  ResScalar* res, Index resStride, const ResScalar& alpha,
240  level3_blocking<LhsScalar, RhsScalar>& blocking,
241  GemmParallelInfo<Index>* info = 0) {
242  const LhsScalar* lhs = lhsMapper.data();
243  const Index lhsStride = lhsMapper.stride();
244  const RhsScalar* rhs = rhsMapper.data();
245  const Index rhsStride = rhsMapper.stride();
246 
247  run(rows, cols, depth, lhs, lhsStride, rhs, rhsStride, res, resStride,
248  alpha, blocking, info);
249  }
250 };
251 #else
252 
255 template <>
256 struct significant_decimals_default_impl<stan::math::var, false> {
257  static inline int run() {
258  using std::ceil;
259  using std::log;
260  return cast<double, int>(
261  ceil(-log(std::numeric_limits<double>::epsilon()) / log(10.0)));
262  }
263 };
264 
269 template <>
270 struct scalar_product_traits<stan::math::var, double> {
272 };
273 
278 template <>
279 struct scalar_product_traits<double, stan::math::var> {
281 };
282 
287 template <typename Index, bool ConjugateLhs, bool ConjugateRhs>
288 struct general_matrix_vector_product<Index, stan::math::var, ColMajor,
289  ConjugateLhs, stan::math::var,
290  ConjugateRhs> {
293  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType
295  enum { LhsStorageOrder = ColMajor };
296 
297  EIGEN_DONT_INLINE static void run(Index rows, Index cols,
298  const LhsScalar* lhs, Index lhsStride,
299  const RhsScalar* rhs, Index rhsIncr,
300  ResScalar* res, Index resIncr,
301  const ResScalar& alpha) {
302  for (Index i = 0; i < rows; i++) {
303  res[i * resIncr] += stan::math::var(new stan::math::gevv_vvv_vari(
304  &alpha,
305  (static_cast<int>(LhsStorageOrder) == static_cast<int>(ColMajor))
306  ? (&lhs[i])
307  : (&lhs[i * lhsStride]),
308  (static_cast<int>(LhsStorageOrder) == static_cast<int>(ColMajor))
309  ? (lhsStride)
310  : (1),
311  rhs, rhsIncr, cols));
312  }
313  }
314 };
315 template <typename Index, bool ConjugateLhs, bool ConjugateRhs>
316 struct general_matrix_vector_product<Index, stan::math::var, RowMajor,
317  ConjugateLhs, stan::math::var,
318  ConjugateRhs> {
321  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType
323  enum { LhsStorageOrder = RowMajor };
324 
325  EIGEN_DONT_INLINE static void run(Index rows, Index cols,
326  const LhsScalar* lhs, Index lhsStride,
327  const RhsScalar* rhs, Index rhsIncr,
328  ResScalar* res, Index resIncr,
329  const RhsScalar& alpha) {
330  for (Index i = 0; i < rows; i++) {
331  res[i * resIncr] += stan::math::var(new stan::math::gevv_vvv_vari(
332  &alpha,
333  (static_cast<int>(LhsStorageOrder) == static_cast<int>(ColMajor))
334  ? (&lhs[i])
335  : (&lhs[i * lhsStride]),
336  (static_cast<int>(LhsStorageOrder) == static_cast<int>(ColMajor))
337  ? (lhsStride)
338  : (1),
339  rhs, rhsIncr, cols));
340  }
341  }
342 };
343 template <typename Index, int LhsStorageOrder, bool ConjugateLhs,
344  int RhsStorageOrder, bool ConjugateRhs>
345 struct general_matrix_matrix_product<Index, stan::math::var, LhsStorageOrder,
346  ConjugateLhs, stan::math::var,
347  RhsStorageOrder, ConjugateRhs, ColMajor> {
350  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType
352  static void run(Index rows, Index cols, Index depth, const LhsScalar* lhs,
353  Index lhsStride, const RhsScalar* rhs, Index rhsStride,
354  ResScalar* res, Index resStride, const ResScalar& alpha,
355  level3_blocking<LhsScalar, RhsScalar>& /* blocking */,
356  GemmParallelInfo<Index>* /* info = 0 */) {
357  for (Index i = 0; i < cols; i++) {
358  general_matrix_vector_product<
359  Index, LhsScalar, LhsStorageOrder, ConjugateLhs, RhsScalar,
360  ConjugateRhs>::run(rows, depth, lhs, lhsStride,
361  &rhs[(static_cast<int>(RhsStorageOrder)
362  == static_cast<int>(ColMajor))
363  ? (i * rhsStride)
364  : (i)],
365  (static_cast<int>(RhsStorageOrder)
366  == static_cast<int>(ColMajor))
367  ? (1)
368  : (rhsStride),
369  &res[i * resStride], 1, alpha);
370  }
371  }
372 };
373 #endif
374 } // namespace internal
375 } // namespace Eigen
376 #endif
int rows(const Eigen::Matrix< T, R, C > &m)
Return the number of rows in the specified matrix, vector, or row vector.
Definition: rows.hpp:20
static int digits10()
Return the number of decimal digits that can be represented without change.
static EIGEN_DONT_INLINE void run(Index rows, Index cols, const LhsScalar *lhs, Index lhsStride, const RhsScalar *rhs, Index rhsIncr, ResScalar *res, Index resIncr, const ResScalar &alpha)
fvar< T > log(const fvar< T > &x)
Definition: log.hpp:12
The variable implementation base class.
Definition: vari.hpp:30
static stan::math::var dummy_precision()
Return the precision for stan::math::var; delegates to the precision for double.
static void run(Index rows, Index cols, Index depth, const LhsScalar *lhs, Index lhsStride, const RhsScalar *rhs, Index rhsStride, ResScalar *res, Index resStride, const ResScalar &alpha, level3_blocking< LhsScalar, RhsScalar > &, GemmParallelInfo< Index > *)
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:33
(Expert) Numerical traits for algorithmic differentiation variables.
int cols(const Eigen::Matrix< T, R, C > &m)
Return the number of columns in the specified matrix, vector, or row vector.
Definition: cols.hpp:20
fvar< T > ceil(const fvar< T > &x)
Definition: ceil.hpp:12
static EIGEN_DONT_INLINE void run(Index rows, Index cols, const LhsScalar *lhs, Index lhsStride, const RhsScalar *rhs, Index rhsIncr, ResScalar *res, Index resIncr, const RhsScalar &alpha)

     [ Stan Home Page ] © 2011–2018, Stan Development Team.