Stan Math Library  2.20.0
reverse mode automatic differentiation
log_softmax.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_MAT_FUN_LOG_SOFTMAX_HPP
2 #define STAN_MATH_REV_MAT_FUN_LOG_SOFTMAX_HPP
3 
4 #include <stan/math/rev/meta.hpp>
9 #include <stan/math/rev/core.hpp>
10 #include <cmath>
11 #include <vector>
12 
13 namespace stan {
14 namespace math {
15 
16 namespace internal {
17 
18 class log_softmax_elt_vari : public vari {
19  private:
20  vari** alpha_;
21  const double* softmax_alpha_;
22  const int size_; // array sizes
23  const int idx_; // in in softmax output
24 
25  public:
26  log_softmax_elt_vari(double val, vari** alpha, const double* softmax_alpha,
27  int size, int idx)
28  : vari(val),
29  alpha_(alpha),
30  softmax_alpha_(softmax_alpha),
31  size_(size),
32  idx_(idx) {}
33  void chain() {
34  for (int m = 0; m < size_; ++m) {
35  if (m == idx_)
36  alpha_[m]->adj_ += adj_ * (1 - softmax_alpha_[m]);
37  else
38  alpha_[m]->adj_ -= adj_ * softmax_alpha_[m];
39  }
40  }
41 };
42 
43 } // namespace internal
44 
55 inline Eigen::Matrix<var, Eigen::Dynamic, 1> log_softmax(
56  const Eigen::Matrix<var, Eigen::Dynamic, 1>& alpha) {
57  using Eigen::Dynamic;
58  using Eigen::Matrix;
59 
60  check_nonzero_size("log_softmax", "alpha", alpha);
61 
62  // TODO(carpenter): replace with array alloc
63  vari** alpha_vi_array = reinterpret_cast<vari**>(
64  vari::operator new(sizeof(vari*) * alpha.size()));
65  for (int i = 0; i < alpha.size(); ++i)
66  alpha_vi_array[i] = alpha(i).vi_;
67 
68  Matrix<double, Dynamic, 1> alpha_d(alpha.size());
69  for (int i = 0; i < alpha_d.size(); ++i)
70  alpha_d(i) = alpha(i).val();
71 
72  // fold logic of math::softmax() and math::log_softmax()
73  // to save computations
74 
75  Matrix<double, Dynamic, 1> softmax_alpha_d(alpha_d.size());
76  Matrix<double, Dynamic, 1> log_softmax_alpha_d(alpha_d.size());
77 
78  double max_v = alpha_d.maxCoeff();
79 
80  double sum = 0.0;
81  for (int i = 0; i < alpha_d.size(); ++i) {
82  softmax_alpha_d(i) = std::exp(alpha_d(i) - max_v);
83  sum += softmax_alpha_d(i);
84  }
85 
86  for (int i = 0; i < alpha_d.size(); ++i)
87  softmax_alpha_d(i) /= sum;
88  double log_sum = std::log(sum);
89 
90  for (int i = 0; i < alpha_d.size(); ++i)
91  log_softmax_alpha_d(i) = (alpha_d(i) - max_v) - log_sum;
92 
93  // end fold
94  // TODO(carpenter): replace with array alloc
95  double* softmax_alpha_d_array = reinterpret_cast<double*>(
96  vari::operator new(sizeof(double) * alpha_d.size()));
97 
98  for (int i = 0; i < alpha_d.size(); ++i)
99  softmax_alpha_d_array[i] = softmax_alpha_d(i);
100 
101  Matrix<var, Dynamic, 1> log_softmax_alpha(alpha.size());
102  for (int k = 0; k < log_softmax_alpha.size(); ++k)
103  log_softmax_alpha(k) = var(new internal::log_softmax_elt_vari(
104  log_softmax_alpha_d[k], alpha_vi_array, softmax_alpha_d_array,
105  alpha.size(), k));
106  return log_softmax_alpha;
107 }
108 
109 } // namespace math
110 } // namespace stan
111 #endif
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition: sum.hpp:20
void check_nonzero_size(const char *function, const char *name, const T_y &y)
Check if the specified matrix/vector is of non-zero size.
fvar< T > log(const fvar< T > &x)
Definition: log.hpp:12
The variable implementation base class.
Definition: vari.hpp:30
Eigen::Matrix< fvar< T >, Eigen::Dynamic, 1 > log_softmax(const Eigen::Matrix< fvar< T >, Eigen::Dynamic, 1 > &alpha)
Definition: log_softmax.hpp:14
friend class var
Definition: vari.hpp:32
void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...
Definition: log_softmax.hpp:33
fvar< T > exp(const fvar< T > &x)
Definition: exp.hpp:11
int size(const std::vector< T > &x)
Return the size of the specified standard vector.
Definition: size.hpp:17
double adj_
The adjoint of this variable, which is the partial derivative of this variable with respect to the root variable.
Definition: vari.hpp:44
log_softmax_elt_vari(double val, vari **alpha, const double *softmax_alpha, int size, int idx)
Definition: log_softmax.hpp:26

     [ Stan Home Page ] © 2011–2018, Stan Development Team.