Stan Math Library  2.20.0
reverse mode automatic differentiation
wiener_lpdf.hpp
Go to the documentation of this file.
1 // Original code from which Stan's code is derived:
2 // Copyright (c) 2013, Joachim Vandekerckhove.
3 // All rights reserved.
5 //
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted
8 // provided that the following conditions are met:
9 //
10 // * Redistributions of source code must retain the above copyright notice,
11 // * this list of conditions and the following disclaimer.
12 // * Redistributions in binary form must reproduce the above copyright notice,
13 // * this list of conditions and the following disclaimer in the
14 // * documentation and/or other materials provided with the distribution.
15 // * Neither the name of the University of California, Irvine nor the names
16 // * of its contributors may be used to endorse or promote products derived
17 // * from this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
29 // THE POSSIBILITY OF SUCH DAMAGE.
30 
31 #ifndef STAN_MATH_PRIM_MAT_PROB_WIENER_LPDF_HPP
32 #define STAN_MATH_PRIM_MAT_PROB_WIENER_LPDF_HPP
33 
43 #include <algorithm>
44 #include <cmath>
45 #include <string>
46 
47 namespace stan {
48 namespace math {
49 
68 template <bool propto, typename T_y, typename T_alpha, typename T_tau,
69  typename T_beta, typename T_delta>
71  const T_y& y, const T_alpha& alpha, const T_tau& tau, const T_beta& beta,
72  const T_delta& delta) {
73  static const char* function = "wiener_lpdf";
74 
75  using std::exp;
76  using std::log;
77  using std::pow;
78 
79  static const double WIENER_ERR = 0.000001;
80  static const double PI_TIMES_WIENER_ERR = pi() * WIENER_ERR;
81  static const double LOG_PI_LOG_WIENER_ERR = LOG_PI + log(WIENER_ERR);
82  static const double TWO_TIMES_SQRT_2_TIMES_SQRT_PI_TIMES_WIENER_ERR
83  = 2.0 * SQRT_2_TIMES_SQRT_PI * WIENER_ERR;
84  static const double LOG_TWO_OVER_TWO_PLUS_LOG_SQRT_PI
85  = LOG_2 / 2 + LOG_SQRT_PI;
86  static const double SQUARE_PI_OVER_TWO = square(pi()) * 0.5;
87  static const double TWO_TIMES_LOG_SQRT_PI = 2.0 * LOG_SQRT_PI;
88 
89  if (size_zero(y, alpha, beta, tau, delta))
90  return 0.0;
91 
93  T_return_type;
94  T_return_type lp(0.0);
95 
96  check_not_nan(function, "Random variable", y);
97  check_not_nan(function, "Boundary separation", alpha);
98  check_not_nan(function, "A-priori bias", beta);
99  check_not_nan(function, "Nondecision time", tau);
100  check_not_nan(function, "Drift rate", delta);
101  check_finite(function, "Boundary separation", alpha);
102  check_finite(function, "A-priori bias", beta);
103  check_finite(function, "Nondecision time", tau);
104  check_finite(function, "Drift rate", delta);
105  check_positive(function, "Random variable", y);
106  check_positive(function, "Boundary separation", alpha);
107  check_positive(function, "Nondecision time", tau);
108  check_bounded(function, "A-priori bias", beta, 0, 1);
109  check_consistent_sizes(function, "Random variable", y, "Boundary separation",
110  alpha, "A-priori bias", beta, "Nondecision time", tau,
111  "Drift rate", delta);
112 
113  size_t N = std::max(max_size(y, alpha, beta), max_size(tau, delta));
114  if (!N)
115  return 0.0;
116 
117  scalar_seq_view<T_y> y_vec(y);
118  scalar_seq_view<T_alpha> alpha_vec(alpha);
119  scalar_seq_view<T_beta> beta_vec(beta);
120  scalar_seq_view<T_tau> tau_vec(tau);
121  scalar_seq_view<T_delta> delta_vec(delta);
122 
123  size_t N_y_tau = max_size(y, tau);
124  for (size_t i = 0; i < N_y_tau; ++i) {
125  if (y_vec[i] <= tau_vec[i]) {
126  std::stringstream msg;
127  msg << ", but must be greater than nondecision time = " << tau_vec[i];
128  std::string msg_str(msg.str());
129  domain_error(function, "Random variable", y_vec[i], " = ",
130  msg_str.c_str());
131  }
132  }
133 
135  return 0;
136 
137  for (size_t i = 0; i < N; i++) {
138  typename scalar_type<T_beta>::type one_minus_beta = 1.0 - beta_vec[i];
139  typename scalar_type<T_alpha>::type alpha2 = square(alpha_vec[i]);
140  T_return_type x = (y_vec[i] - tau_vec[i]) / alpha2;
141  T_return_type kl, ks, tmp = 0;
142  T_return_type k, K;
143  T_return_type sqrt_x = sqrt(x);
144  T_return_type log_x = log(x);
145  T_return_type one_over_pi_times_sqrt_x = 1.0 / pi() * sqrt_x;
146 
147  // calculate number of terms needed for large t:
148  // if error threshold is set low enough
149  if (PI_TIMES_WIENER_ERR * x < 1) {
150  // compute bound
151  kl = sqrt(-2.0 * SQRT_PI * (LOG_PI_LOG_WIENER_ERR + log_x)) / sqrt_x;
152  // ensure boundary conditions met
153  kl = (kl > one_over_pi_times_sqrt_x) ? kl : one_over_pi_times_sqrt_x;
154  } else {
155  kl = one_over_pi_times_sqrt_x; // set to boundary condition
156  }
157  // calculate number of terms needed for small t:
158  // if error threshold is set low enough
159  T_return_type tmp_expr0
160  = TWO_TIMES_SQRT_2_TIMES_SQRT_PI_TIMES_WIENER_ERR * sqrt_x;
161  if (tmp_expr0 < 1) {
162  // compute bound
163  ks = 2.0 + sqrt_x * sqrt(-2 * log(tmp_expr0));
164  // ensure boundary conditions are met
165  T_return_type sqrt_x_plus_one = sqrt_x + 1.0;
166  ks = (ks > sqrt_x_plus_one) ? ks : sqrt_x_plus_one;
167  } else { // if error threshold was set too high
168  ks = 2.0; // minimal kappa for that case
169  }
170  if (ks < kl) { // small t
171  K = ceil(ks); // round to smallest integer meeting error
172  T_return_type tmp_expr1 = (K - 1.0) / 2.0;
173  T_return_type tmp_expr2 = ceil(tmp_expr1);
174  for (k = -floor(tmp_expr1); k <= tmp_expr2; k++)
175  tmp += (one_minus_beta + 2.0 * k)
176  * exp(-(square(one_minus_beta + 2.0 * k)) * 0.5 / x);
177  tmp = log(tmp) - LOG_TWO_OVER_TWO_PLUS_LOG_SQRT_PI - 1.5 * log_x;
178  } else { // if large t is better...
179  K = ceil(kl); // round to smallest integer meeting error
180  for (k = 1; k <= K; ++k)
181  tmp += k * exp(-(square(k)) * (SQUARE_PI_OVER_TWO * x))
182  * sin(k * pi() * one_minus_beta);
183  tmp = log(tmp) + TWO_TIMES_LOG_SQRT_PI;
184  }
185 
186  // convert to f(t|v,a,w) and return result
187  lp += delta_vec[i] * alpha_vec[i] * one_minus_beta
188  - square(delta_vec[i]) * x * alpha2 / 2.0 - log(alpha2) + tmp;
189  }
190  return lp;
191 }
192 
193 template <typename T_y, typename T_alpha, typename T_tau, typename T_beta,
194  typename T_delta>
196 wiener_lpdf(const T_y& y, const T_alpha& alpha, const T_tau& tau,
197  const T_beta& beta, const T_delta& delta) {
198  return wiener_lpdf<false>(y, alpha, tau, beta, delta);
199 }
200 
201 } // namespace math
202 } // namespace stan
203 #endif
const double LOG_2
The natural logarithm of 2, $\log 2$.
Definition: constants.hpp:37
void check_finite(const char *function, const char *name, const T_y &y)
Check if y is finite.
fvar< T > sqrt(const fvar< T > &x)
Definition: sqrt.hpp:13
void check_bounded(const char *function, const char *name, const T_y &y, const T_low &low, const T_high &high)
Check if the value is between the low and high values, inclusively.
const double LOG_PI
Definition: constants.hpp:144
fvar< T > log(const fvar< T > &x)
Definition: log.hpp:12
scalar_seq_view provides a uniform sequence-like wrapper around either a scalar or a sequence of scalars.
const double LOG_SQRT_PI
Definition: constants.hpp:148
bool size_zero(T &x)
Returns 1 if input is of length 0, returns 0 otherwise.
Definition: size_zero.hpp:18
Template metaprogram to calculate whether a summand needs to be included in a proportional (log) probability calculation.
fvar< T > square(const fvar< T > &x)
Definition: square.hpp:12
const double SQRT_2_TIMES_SQRT_PI
Definition: constants.hpp:134
fvar< T > beta(const fvar< T > &x1, const fvar< T > &x2)
Return fvar with the beta function applied to the specified arguments and its gradient.
Definition: beta.hpp:51
fvar< T > sin(const fvar< T > &x)
Definition: sin.hpp:11
boost::math::tools::promote_args< double, typename scalar_type< T >::type, typename return_type< Types_pack... >::type >::type type
Definition: return_type.hpp:36
fvar< T > exp(const fvar< T > &x)
Definition: exp.hpp:11
void check_not_nan(const char *function, const char *name, const T_y &y)
Check if y is not NaN.
return_type< T_y, T_alpha, T_tau, T_beta, T_delta >::type wiener_lpdf(const T_y &y, const T_alpha &alpha, const T_tau &tau, const T_beta &beta, const T_delta &delta)
The log of the first passage time density function for a (Wiener) drift diffusion model for the given...
Definition: wiener_lpdf.hpp:70
size_t max_size(const T1 &x1, const T2 &x2)
Definition: max_size.hpp:9
void domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
int max(const std::vector< int > &x)
Returns the maximum coefficient in the specified column vector.
Definition: max.hpp:21
fvar< T > floor(const fvar< T > &x)
Definition: floor.hpp:12
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
double pi()
Return the value of pi.
Definition: constants.hpp:80
fvar< T > pow(const fvar< T > &x1, const fvar< T > &x2)
Definition: pow.hpp:16
void check_consistent_sizes(const char *function, const char *name1, const T1 &x1, const char *name2, const T2 &x2)
Check if the dimension of x1 is consistent with x2.
const double SQRT_PI
Definition: constants.hpp:132
fvar< T > ceil(const fvar< T > &x)
Definition: ceil.hpp:12

     [ Stan Home Page ] © 2011–2018, Stan Development Team.