Unverified Commit 0b4a7da6 authored by Steve Bronder's avatar Steve Bronder
Browse files

Merge remote-tracking branch 'upstream/develop' into gpu_lower_tri_inverse

parents dcc0800a 2072997d
No related merge requests found
Showing with 30 additions and 31 deletions
+30 -31
......@@ -101,19 +101,12 @@ struct coupled_ode_system<F, double, var> {
double t) const {
using std::vector;
vector<double> grad(N_ + M_);
try {
start_nested();
vector<var> z_vars;
z_vars.reserve(N_ + M_);
vector<var> y_vars(z.begin(), z.begin() + N_);
z_vars.insert(z_vars.end(), y_vars.begin(), y_vars.end());
vector<var> theta_vars(theta_dbl_.begin(), theta_dbl_.end());
z_vars.insert(z_vars.end(), theta_vars.begin(), theta_vars.end());
vector<var> dy_dt_vars = f_(t, y_vars, theta_vars, x_, x_int_, msgs_);
......@@ -122,19 +115,21 @@ struct coupled_ode_system<F, double, var> {
for (size_t i = 0; i < N_; i++) {
dz_dt[i] = dy_dt_vars[i].val();
set_zero_all_adjoints_nested();
dy_dt_vars[i].grad(z_vars, grad);
dy_dt_vars[i].grad();
for (size_t j = 0; j < M_; j++) {
// orders derivatives by equation (i.e. if there are 2 eqns
// (y1, y2) and 2 parameters (a, b), dy_dt will be ordered as:
// dy1_dt, dy2_dt, dy1_da, dy2_da, dy1_db, dy2_db
double temp_deriv = grad[N_ + j];
double temp_deriv = theta_vars[j].adj();
const size_t offset = N_ + N_ * j;
for (size_t k = 0; k < N_; k++)
temp_deriv += z[N_ + N_ * j + k] * grad[k];
temp_deriv += z[offset + k] * y_vars[k].adj();
dz_dt[N_ + i + j * N_] = temp_deriv;
dz_dt[offset + i] = temp_deriv;
}
set_zero_all_adjoints_nested();
}
} catch (const std::exception& e) {
recover_memory_nested();
......@@ -284,8 +279,6 @@ struct coupled_ode_system<F, var, double> {
double t) const {
using std::vector;
std::vector<double> grad(N_);
try {
start_nested();
......@@ -298,19 +291,21 @@ struct coupled_ode_system<F, var, double> {
for (size_t i = 0; i < N_; i++) {
dz_dt[i] = dy_dt_vars[i].val();
set_zero_all_adjoints_nested();
dy_dt_vars[i].grad(y_vars, grad);
dy_dt_vars[i].grad();
for (size_t j = 0; j < N_; j++) {
// orders derivatives by equation (i.e. if there are 2 eqns
// (y1, y2) and 2 parameters (a, b), dy_dt will be ordered as:
// dy1_dt, dy2_dt, dy1_da, dy2_da, dy1_db, dy2_db
double temp_deriv = 0;
const size_t offset = N_ + N_ * j;
for (size_t k = 0; k < N_; k++)
temp_deriv += z[N_ + N_ * j + k] * grad[k];
temp_deriv += z[offset + k] * y_vars[k].adj();
dz_dt[N_ + i + j * N_] = temp_deriv;
dz_dt[offset + i] = temp_deriv;
}
set_zero_all_adjoints_nested();
}
} catch (const std::exception& e) {
recover_memory_nested();
......@@ -478,19 +473,12 @@ struct coupled_ode_system<F, var, var> {
double t) const {
using std::vector;
vector<double> grad(N_ + M_);
try {
start_nested();
vector<var> z_vars;
z_vars.reserve(N_ + M_);
vector<var> y_vars(z.begin(), z.begin() + N_);
z_vars.insert(z_vars.end(), y_vars.begin(), y_vars.end());
vector<var> theta_vars(theta_dbl_.begin(), theta_dbl_.end());
z_vars.insert(z_vars.end(), theta_vars.begin(), theta_vars.end());
vector<var> dy_dt_vars = f_(t, y_vars, theta_vars, x_, x_int_, msgs_);
......@@ -499,19 +487,30 @@ struct coupled_ode_system<F, var, var> {
for (size_t i = 0; i < N_; i++) {
dz_dt[i] = dy_dt_vars[i].val();
set_zero_all_adjoints_nested();
dy_dt_vars[i].grad(z_vars, grad);
dy_dt_vars[i].grad();
for (size_t j = 0; j < N_ + M_; j++) {
for (size_t j = 0; j < N_; j++) {
// orders derivatives by equation (i.e. if there are 2 eqns
// (y1, y2) and 2 parameters (a, b), dy_dt will be ordered as:
// dy1_dt, dy2_dt, dy1_da, dy2_da, dy1_db, dy2_db
double temp_deriv = j < N_ ? 0 : grad[j];
double temp_deriv = 0;
const size_t offset = N_ + N_ * j;
for (size_t k = 0; k < N_; k++)
temp_deriv += z[offset + k] * y_vars[k].adj();
dz_dt[offset + i] = temp_deriv;
}
for (size_t j = 0; j < M_; j++) {
double temp_deriv = theta_vars[j].adj();
const size_t offset = N_ + N_ * N_ + N_ * j;
for (size_t k = 0; k < N_; k++)
temp_deriv += z[N_ + N_ * j + k] * grad[k];
temp_deriv += z[offset + k] * y_vars[k].adj();
dz_dt[N_ + i + j * N_] = temp_deriv;
dz_dt[offset + i] = temp_deriv;
}
set_zero_all_adjoints_nested();
}
} catch (const std::exception& e) {
recover_memory_nested();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment