Commit bc9e0181 authored by Sebastian Weber
Browse files

Avoid duplicate evaluation of the ODE RHS in coupled_ode_system: take the state derivatives from the autodiff pass and write all results directly into the pre-allocated dz_dt.

parent f37b91d1
Showing 2 changed files with 25 additions and 35 deletions
+25 -35
......@@ -122,12 +122,6 @@ namespace stan {
double t) const {
using std::vector;
vector<double> y(z.begin(), z.begin() + N_);
dz_dt = f_(t, y, theta_dbl_, x_, x_int_, msgs_);
check_size_match("coupled_ode_system", "dz_dt", dz_dt.size(),
"states", N_);
vector<double> coupled_sys(N_ * M_);
vector<double> grad(N_ + M_);
try {
......@@ -136,7 +130,7 @@ namespace stan {
vector<var> z_vars;
z_vars.reserve(N_ + M_);
vector<var> y_vars(y.begin(), y.end());
vector<var> y_vars(z.begin(), z.begin() + N_);
z_vars.insert(z_vars.end(), y_vars.begin(), y_vars.end());
vector<var> theta_vars(theta_dbl_.begin(), theta_dbl_.end());
......@@ -144,7 +138,11 @@ namespace stan {
vector<var> dy_dt_vars = f_(t, y_vars, theta_vars, x_, x_int_, msgs_);
check_size_match("coupled_ode_system", "dz_dt", dy_dt_vars.size(),
"states", N_);
for (size_t i = 0; i < N_; i++) {
dz_dt[i] = dy_dt_vars[i].val();
set_zero_all_adjoints_nested();
dy_dt_vars[i].grad(z_vars, grad);
......@@ -156,7 +154,7 @@ namespace stan {
for (size_t k = 0; k < N_; k++)
temp_deriv += z[N_ + N_ * j + k] * grad[k];
coupled_sys[i + j * N_] = temp_deriv;
dz_dt[N_ + i + j * N_] = temp_deriv;
}
}
} catch (const std::exception& e) {
......@@ -164,8 +162,6 @@ namespace stan {
throw;
}
recover_memory_nested();
dz_dt.insert(dz_dt.end(), coupled_sys.begin(), coupled_sys.end());
}
/**
......@@ -320,11 +316,6 @@ namespace stan {
for (size_t n = 0; n < N_; n++)
y[n] += y0_dbl_[n];
dz_dt = f_(t, y, theta_dbl_, x_, x_int_, msgs_);
check_size_match("coupled_ode_system", "dz_dt", dz_dt.size(),
"states", N_);
std::vector<double> coupled_sys(N_ * N_);
std::vector<double> grad(N_);
try {
......@@ -338,7 +329,11 @@ namespace stan {
vector<var> dy_dt_vars = f_(t, y_vars, theta_dbl_, x_, x_int_, msgs_);
check_size_match("coupled_ode_system", "dz_dt", dy_dt_vars.size(),
"states", N_);
for (size_t i = 0; i < N_; i++) {
dz_dt[i] = dy_dt_vars[i].val();
set_zero_all_adjoints_nested();
dy_dt_vars[i].grad(z_vars, grad);
......@@ -350,7 +345,7 @@ namespace stan {
for (size_t k = 0; k < N_; k++)
temp_deriv += z[N_ + N_ * j + k] * grad[k];
coupled_sys[i + j * N_] = temp_deriv;
dz_dt[N_ + i + j * N_] = temp_deriv;
}
}
} catch (const std::exception& e) {
......@@ -358,8 +353,6 @@ namespace stan {
throw;
}
recover_memory_nested();
dz_dt.insert(dz_dt.end(), coupled_sys.begin(), coupled_sys.end());
}
/**
......@@ -528,11 +521,6 @@ namespace stan {
for (size_t n = 0; n < N_; n++)
y[n] += y0_dbl_[n];
dz_dt = f_(t, y, theta_dbl_, x_, x_int_, msgs_);
check_size_match("coupled_ode_system", "dz_dt", dz_dt.size(),
"states", N_);
vector<double> coupled_sys(N_ * (N_ + M_));
vector<double> grad(N_ + M_);
try {
......@@ -549,7 +537,11 @@ namespace stan {
vector<var> dy_dt_vars = f_(t, y_vars, theta_vars, x_, x_int_, msgs_);
check_size_match("coupled_ode_system", "dz_dt", dy_dt_vars.size(),
"states", N_);
for (size_t i = 0; i < N_; i++) {
dz_dt[i] = dy_dt_vars[i].val();
set_zero_all_adjoints_nested();
dy_dt_vars[i].grad(z_vars, grad);
......@@ -561,7 +553,7 @@ namespace stan {
for (size_t k = 0; k < N_; k++)
temp_deriv += z[N_ + N_ * j + k] * grad[k];
coupled_sys[i + j * N_] = temp_deriv;
dz_dt[N_ + i + j * N_] = temp_deriv;
}
}
} catch (const std::exception& e) {
......@@ -569,8 +561,6 @@ namespace stan {
throw;
}
recover_memory_nested();
dz_dt.insert(dz_dt.end(), coupled_sys.begin(), coupled_sys.end());
}
/**
......
......@@ -25,7 +25,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_dv) {
std::vector<double> coupled_y0;
std::vector<double> y0;
double t0;
std::vector<double> dy_dt;
std::vector<double> dy_dt(4, 0);
double gamma(0.15);
t0 = 0;
......@@ -146,7 +146,7 @@ TEST_F(StanAgradRevOde, memory_recovery_dv) {
coupled_system_dv(base_ode, y0_d, theta_v, x, x_int, &msgs);
std::vector<double> y(3,0);
std::vector<double> dy_dt(3,0);
std::vector<double> dy_dt(3 + N * M,0);
double t = 10;
EXPECT_TRUE(stan::math::empty_nested());
......@@ -174,7 +174,7 @@ TEST_F(StanAgradRevOde, memory_recovery_exception_dv) {
coupled_system_dv(throwing_ode, y0_d, theta_v, x, x_int, &msgs);
std::vector<double> y(3,0);
std::vector<double> dy_dt(3,0);
std::vector<double> dy_dt(3 + N * M,0);
double t = 10;
EXPECT_TRUE(stan::math::empty_nested());
......@@ -197,7 +197,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vd) {
std::vector<stan::math::var> y0_var;
std::vector<double> y0_adj;
double t0;
std::vector<double> dy_dt;
std::vector<double> dy_dt(6, 0);
double gamma(0.15);
t0 = 0;
......@@ -324,7 +324,7 @@ TEST_F(StanAgradRevOde, memory_recovery_vd) {
coupled_system_vd(base_ode, y0_v, theta_d, x, x_int, &msgs);
std::vector<double> y(3,0);
std::vector<double> dy_dt(3,0);
std::vector<double> dy_dt(3 + N * N,0);
double t = 10;
EXPECT_TRUE(stan::math::empty_nested());
......@@ -352,7 +352,7 @@ TEST_F(StanAgradRevOde, memory_recovery_exception_vd) {
coupled_system_vd(throwing_ode, y0_v, theta_d, x, x_int, &msgs);
std::vector<double> y(3,0);
std::vector<double> dy_dt(3,0);
std::vector<double> dy_dt(3 + N * N,0);
double t = 10;
EXPECT_TRUE(stan::math::empty_nested());
......@@ -385,7 +385,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vv) {
double t0;
t0 = 0;
std::vector<double> dy_dt;
std::vector<double> dy_dt(2 + 2 * 2 + 2 * 1);
system(coupled_y0, dy_dt, t0);
std::vector<double> y0_double(2);
......@@ -503,7 +503,7 @@ TEST_F(StanAgradRevOde, memory_recovery_vv) {
coupled_system_vv(base_ode, y0_v, theta_v, x, x_int, &msgs);
std::vector<double> y(3,0);
std::vector<double> dy_dt(3,0);
std::vector<double> dy_dt(3 + N * N + N * M,0);
double t = 10;
EXPECT_TRUE(stan::math::empty_nested());
......@@ -531,7 +531,7 @@ TEST_F(StanAgradRevOde, memory_recovery_exception_vv) {
coupled_system_vv(throwing_ode, y0_v, theta_v, x, x_int, &msgs);
std::vector<double> y(3,0);
std::vector<double> dy_dt(3,0);
std::vector<double> dy_dt(3 + N * N + N * M,0);
double t = 10;
EXPECT_TRUE(stan::math::empty_nested());
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment