Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Francesco Brarda
stan-math-petsc
Commits
bc9e0181
Commit
bc9e0181
authored
8 years ago
by
Sebastian Weber
Browse files
Options
Download
Email Patches
Plain Diff
avoid duplicate evaluation of ODE RHS in coupled_ode_system; take advantage of pre-allocated dz_dt
parent
f37b91d1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
stan/math/rev/arr/functor/coupled_ode_system.hpp
+16
-26
stan/math/rev/arr/functor/coupled_ode_system.hpp
test/unit/math/rev/arr/functor/coupled_ode_system_test.cpp
+9
-9
test/unit/math/rev/arr/functor/coupled_ode_system_test.cpp
with
25 additions
and
35 deletions
+25
-35
stan/math/rev/arr/functor/coupled_ode_system.hpp
View file @
bc9e0181
...
...
@@ -122,12 +122,6 @@ namespace stan {
double
t
)
const
{
using
std
::
vector
;
vector
<
double
>
y
(
z
.
begin
(),
z
.
begin
()
+
N_
);
dz_dt
=
f_
(
t
,
y
,
theta_dbl_
,
x_
,
x_int_
,
msgs_
);
check_size_match
(
"coupled_ode_system"
,
"dz_dt"
,
dz_dt
.
size
(),
"states"
,
N_
);
vector
<
double
>
coupled_sys
(
N_
*
M_
);
vector
<
double
>
grad
(
N_
+
M_
);
try
{
...
...
@@ -136,7 +130,7 @@ namespace stan {
vector
<
var
>
z_vars
;
z_vars
.
reserve
(
N_
+
M_
);
vector
<
var
>
y_vars
(
y
.
begin
(),
y
.
end
()
);
vector
<
var
>
y_vars
(
z
.
begin
(),
z
.
begin
()
+
N_
);
z_vars
.
insert
(
z_vars
.
end
(),
y_vars
.
begin
(),
y_vars
.
end
());
vector
<
var
>
theta_vars
(
theta_dbl_
.
begin
(),
theta_dbl_
.
end
());
...
...
@@ -144,7 +138,11 @@ namespace stan {
vector
<
var
>
dy_dt_vars
=
f_
(
t
,
y_vars
,
theta_vars
,
x_
,
x_int_
,
msgs_
);
check_size_match
(
"coupled_ode_system"
,
"dz_dt"
,
dy_dt_vars
.
size
(),
"states"
,
N_
);
for
(
size_t
i
=
0
;
i
<
N_
;
i
++
)
{
dz_dt
[
i
]
=
dy_dt_vars
[
i
].
val
();
set_zero_all_adjoints_nested
();
dy_dt_vars
[
i
].
grad
(
z_vars
,
grad
);
...
...
@@ -156,7 +154,7 @@ namespace stan {
for
(
size_t
k
=
0
;
k
<
N_
;
k
++
)
temp_deriv
+=
z
[
N_
+
N_
*
j
+
k
]
*
grad
[
k
];
coupled_sys
[
i
+
j
*
N_
]
=
temp_deriv
;
dz_dt
[
N_
+
i
+
j
*
N_
]
=
temp_deriv
;
}
}
}
catch
(
const
std
::
exception
&
e
)
{
...
...
@@ -164,8 +162,6 @@ namespace stan {
throw
;
}
recover_memory_nested
();
dz_dt
.
insert
(
dz_dt
.
end
(),
coupled_sys
.
begin
(),
coupled_sys
.
end
());
}
/**
...
...
@@ -320,11 +316,6 @@ namespace stan {
for
(
size_t
n
=
0
;
n
<
N_
;
n
++
)
y
[
n
]
+=
y0_dbl_
[
n
];
dz_dt
=
f_
(
t
,
y
,
theta_dbl_
,
x_
,
x_int_
,
msgs_
);
check_size_match
(
"coupled_ode_system"
,
"dz_dt"
,
dz_dt
.
size
(),
"states"
,
N_
);
std
::
vector
<
double
>
coupled_sys
(
N_
*
N_
);
std
::
vector
<
double
>
grad
(
N_
);
try
{
...
...
@@ -338,7 +329,11 @@ namespace stan {
vector
<
var
>
dy_dt_vars
=
f_
(
t
,
y_vars
,
theta_dbl_
,
x_
,
x_int_
,
msgs_
);
check_size_match
(
"coupled_ode_system"
,
"dz_dt"
,
dy_dt_vars
.
size
(),
"states"
,
N_
);
for
(
size_t
i
=
0
;
i
<
N_
;
i
++
)
{
dz_dt
[
i
]
=
dy_dt_vars
[
i
].
val
();
set_zero_all_adjoints_nested
();
dy_dt_vars
[
i
].
grad
(
z_vars
,
grad
);
...
...
@@ -350,7 +345,7 @@ namespace stan {
for
(
size_t
k
=
0
;
k
<
N_
;
k
++
)
temp_deriv
+=
z
[
N_
+
N_
*
j
+
k
]
*
grad
[
k
];
coupled_sys
[
i
+
j
*
N_
]
=
temp_deriv
;
dz_dt
[
N_
+
i
+
j
*
N_
]
=
temp_deriv
;
}
}
}
catch
(
const
std
::
exception
&
e
)
{
...
...
@@ -358,8 +353,6 @@ namespace stan {
throw
;
}
recover_memory_nested
();
dz_dt
.
insert
(
dz_dt
.
end
(),
coupled_sys
.
begin
(),
coupled_sys
.
end
());
}
/**
...
...
@@ -528,11 +521,6 @@ namespace stan {
for
(
size_t
n
=
0
;
n
<
N_
;
n
++
)
y
[
n
]
+=
y0_dbl_
[
n
];
dz_dt
=
f_
(
t
,
y
,
theta_dbl_
,
x_
,
x_int_
,
msgs_
);
check_size_match
(
"coupled_ode_system"
,
"dz_dt"
,
dz_dt
.
size
(),
"states"
,
N_
);
vector
<
double
>
coupled_sys
(
N_
*
(
N_
+
M_
));
vector
<
double
>
grad
(
N_
+
M_
);
try
{
...
...
@@ -549,7 +537,11 @@ namespace stan {
vector
<
var
>
dy_dt_vars
=
f_
(
t
,
y_vars
,
theta_vars
,
x_
,
x_int_
,
msgs_
);
check_size_match
(
"coupled_ode_system"
,
"dz_dt"
,
dy_dt_vars
.
size
(),
"states"
,
N_
);
for
(
size_t
i
=
0
;
i
<
N_
;
i
++
)
{
dz_dt
[
i
]
=
dy_dt_vars
[
i
].
val
();
set_zero_all_adjoints_nested
();
dy_dt_vars
[
i
].
grad
(
z_vars
,
grad
);
...
...
@@ -561,7 +553,7 @@ namespace stan {
for
(
size_t
k
=
0
;
k
<
N_
;
k
++
)
temp_deriv
+=
z
[
N_
+
N_
*
j
+
k
]
*
grad
[
k
];
coupled_sys
[
i
+
j
*
N_
]
=
temp_deriv
;
dz_dt
[
N_
+
i
+
j
*
N_
]
=
temp_deriv
;
}
}
}
catch
(
const
std
::
exception
&
e
)
{
...
...
@@ -569,8 +561,6 @@ namespace stan {
throw
;
}
recover_memory_nested
();
dz_dt
.
insert
(
dz_dt
.
end
(),
coupled_sys
.
begin
(),
coupled_sys
.
end
());
}
/**
...
...
This diff is collapsed.
Click to expand it.
test/unit/math/rev/arr/functor/coupled_ode_system_test.cpp
View file @
bc9e0181
...
...
@@ -25,7 +25,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_dv) {
std
::
vector
<
double
>
coupled_y0
;
std
::
vector
<
double
>
y0
;
double
t0
;
std
::
vector
<
double
>
dy_dt
;
std
::
vector
<
double
>
dy_dt
(
4
,
0
)
;
double
gamma
(
0.15
);
t0
=
0
;
...
...
@@ -146,7 +146,7 @@ TEST_F(StanAgradRevOde, memory_recovery_dv) {
coupled_system_dv
(
base_ode
,
y0_d
,
theta_v
,
x
,
x_int
,
&
msgs
);
std
::
vector
<
double
>
y
(
3
,
0
);
std
::
vector
<
double
>
dy_dt
(
3
,
0
);
std
::
vector
<
double
>
dy_dt
(
3
+
N
*
M
,
0
);
double
t
=
10
;
EXPECT_TRUE
(
stan
::
math
::
empty_nested
());
...
...
@@ -174,7 +174,7 @@ TEST_F(StanAgradRevOde, memory_recovery_exception_dv) {
coupled_system_dv
(
throwing_ode
,
y0_d
,
theta_v
,
x
,
x_int
,
&
msgs
);
std
::
vector
<
double
>
y
(
3
,
0
);
std
::
vector
<
double
>
dy_dt
(
3
,
0
);
std
::
vector
<
double
>
dy_dt
(
3
+
N
*
M
,
0
);
double
t
=
10
;
EXPECT_TRUE
(
stan
::
math
::
empty_nested
());
...
...
@@ -197,7 +197,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vd) {
std
::
vector
<
stan
::
math
::
var
>
y0_var
;
std
::
vector
<
double
>
y0_adj
;
double
t0
;
std
::
vector
<
double
>
dy_dt
;
std
::
vector
<
double
>
dy_dt
(
6
,
0
)
;
double
gamma
(
0.15
);
t0
=
0
;
...
...
@@ -324,7 +324,7 @@ TEST_F(StanAgradRevOde, memory_recovery_vd) {
coupled_system_vd
(
base_ode
,
y0_v
,
theta_d
,
x
,
x_int
,
&
msgs
);
std
::
vector
<
double
>
y
(
3
,
0
);
std
::
vector
<
double
>
dy_dt
(
3
,
0
);
std
::
vector
<
double
>
dy_dt
(
3
+
N
*
N
,
0
);
double
t
=
10
;
EXPECT_TRUE
(
stan
::
math
::
empty_nested
());
...
...
@@ -352,7 +352,7 @@ TEST_F(StanAgradRevOde, memory_recovery_exception_vd) {
coupled_system_vd
(
throwing_ode
,
y0_v
,
theta_d
,
x
,
x_int
,
&
msgs
);
std
::
vector
<
double
>
y
(
3
,
0
);
std
::
vector
<
double
>
dy_dt
(
3
,
0
);
std
::
vector
<
double
>
dy_dt
(
3
+
N
*
N
,
0
);
double
t
=
10
;
EXPECT_TRUE
(
stan
::
math
::
empty_nested
());
...
...
@@ -385,7 +385,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vv) {
double
t0
;
t0
=
0
;
std
::
vector
<
double
>
dy_dt
;
std
::
vector
<
double
>
dy_dt
(
2
+
2
*
2
+
2
*
1
)
;
system
(
coupled_y0
,
dy_dt
,
t0
);
std
::
vector
<
double
>
y0_double
(
2
);
...
...
@@ -503,7 +503,7 @@ TEST_F(StanAgradRevOde, memory_recovery_vv) {
coupled_system_vv
(
base_ode
,
y0_v
,
theta_v
,
x
,
x_int
,
&
msgs
);
std
::
vector
<
double
>
y
(
3
,
0
);
std
::
vector
<
double
>
dy_dt
(
3
,
0
);
std
::
vector
<
double
>
dy_dt
(
3
+
N
*
N
+
N
*
M
,
0
);
double
t
=
10
;
EXPECT_TRUE
(
stan
::
math
::
empty_nested
());
...
...
@@ -531,7 +531,7 @@ TEST_F(StanAgradRevOde, memory_recovery_exception_vv) {
coupled_system_vv
(
throwing_ode
,
y0_v
,
theta_v
,
x
,
x_int
,
&
msgs
);
std
::
vector
<
double
>
y
(
3
,
0
);
std
::
vector
<
double
>
dy_dt
(
3
,
0
);
std
::
vector
<
double
>
dy_dt
(
3
+
N
*
N
+
N
*
M
,
0
);
double
t
=
10
;
EXPECT_TRUE
(
stan
::
math
::
empty_nested
());
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment