Skip to content

Commit d4ed967

Browse files
authored
1098 bind different integrators for python (#1103)
- Bind IntegratorCore - Expose 3 often used Types of IntegratorCores to python - Update Simulation and model specific simulate bindings
1 parent a7660d8 commit d4ed967

File tree

9 files changed

+211
-51
lines changed

9 files changed

+211
-51
lines changed

pycode/memilio-simulation/memilio/simulation/bindings/compartments/flow_simulation.h

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/*
22
* Copyright (C) 2020-2024 MEmilio
33
*
4-
* Authors: Henrik Zunker
4+
* Authors: Maximilian Betz, Henrik Zunker
55
*
66
* Contact: Martin J. Kuehn <[email protected]>
77
*
@@ -28,18 +28,27 @@
2828
namespace pymio
2929
{
3030

31-
template <class Model, EnablePickling F>
32-
void bind_Flow_Simulation(pybind11::module_& m)
31+
/*
32+
* @brief bind FlowSimulation for any model
33+
*/
34+
template <class FlowSimulation>
35+
void bind_Flow_Simulation(pybind11::module_& m, std::string const& name)
3336
{
34-
bind_class<mio::FlowSimulation<double, Model>, F>(m, "FlowSimulation")
35-
.def(pybind11::init<const Model&, double, double>(), pybind11::arg("model"), pybind11::arg("t0") = 0,
37+
bind_class<FlowSimulation, EnablePickling::IfAvailable>(m, name.c_str())
38+
.def(pybind11::init<const typename FlowSimulation::Model&, double, double>(), pybind11::arg("model"), pybind11::arg("t0") = 0,
3639
pybind11::arg("dt") = 0.1)
37-
.def_property_readonly(
38-
"result", pybind11::overload_cast<>(&mio::FlowSimulation<double, Model>::get_result, pybind11::const_),
39-
pybind11::return_value_policy::reference_internal)
40-
.def_property_readonly("flows", &mio::FlowSimulation<double, Model>::get_flows,
40+
.def_property_readonly("result", pybind11::overload_cast<>(&FlowSimulation::get_result, pybind11::const_),
4141
pybind11::return_value_policy::reference_internal)
42-
.def("advance", &mio::FlowSimulation<double, Model>::advance, pybind11::arg("tmax"));
42+
.def_property_readonly("flows", pybind11::overload_cast<>(&FlowSimulation::get_flows, pybind11::const_),
43+
pybind11::return_value_policy::reference_internal)
44+
.def_property_readonly("model", pybind11::overload_cast<>(&FlowSimulation::get_model, pybind11::const_),
45+
pybind11::return_value_policy::reference_internal)
46+
.def_property_readonly("dt", pybind11::overload_cast<>(&FlowSimulation::get_dt, pybind11::const_),
47+
pybind11::return_value_policy::reference_internal)
48+
.def_property("integrator", pybind11::overload_cast<>(&FlowSimulation::get_integrator, pybind11::const_),
49+
&FlowSimulation::set_integrator, pybind11::return_value_policy::reference_internal)
50+
.def("advance", &FlowSimulation::advance, pybind11::arg("tmax"))
51+
.doc() = "A class for the simulation of a flow model.";
4352
}
4453

4554
} // namespace pymio

pycode/memilio-simulation/memilio/simulation/bindings/compartments/simulation.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,20 @@ namespace pymio
3434
template <class Simulation>
3535
void bind_Simulation(pybind11::module_& m, std::string const& name)
3636
{
37+
3738
bind_class<Simulation, EnablePickling::IfAvailable>(m, name.c_str())
3839
.def(pybind11::init<const typename Simulation::Model&, double, double>(), pybind11::arg("model"),
3940
pybind11::arg("t0") = 0, pybind11::arg("dt") = 0.1)
4041
.def_property_readonly("result", pybind11::overload_cast<>(&Simulation::get_result, pybind11::const_),
4142
pybind11::return_value_policy::reference_internal)
4243
.def_property_readonly("model", pybind11::overload_cast<>(&Simulation::get_model, pybind11::const_),
4344
pybind11::return_value_policy::reference_internal)
44-
.def("advance", &Simulation::advance, pybind11::arg("tmax"));
45+
.def_property_readonly("dt", pybind11::overload_cast<>(&Simulation::get_dt, pybind11::const_),
46+
pybind11::return_value_policy::reference_internal)
47+
.def_property("integrator", pybind11::overload_cast<>(&Simulation::get_integrator, pybind11::const_),
48+
&Simulation::set_integrator, pybind11::return_value_policy::reference_internal)
49+
.def("advance", &Simulation::advance, pybind11::arg("tmax"))
50+
.doc() = "A class for the simulation of a compartment model.";
4551
}
4652

4753
} // namespace pymio
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Copyright (C) 2020-2024 MEmilio
3+
*
4+
* Authors: Maximilian Betz
5+
*
6+
* Contact: Martin J. Kuehn <[email protected]>
7+
*
8+
* Licensed under the Apache License, Version 2.0 (the "License");
9+
* you may not use this file except in compliance with the License.
10+
* You may obtain a copy of the License at
11+
*
12+
* http://www.apache.org/licenses/LICENSE-2.0
13+
*
14+
* Unless required by applicable law or agreed to in writing, software
15+
* distributed under the License is distributed on an "AS IS" BASIS,
16+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
* See the License for the specific language governing permissions and
18+
* limitations under the License.
19+
*/
20+
#ifndef PYMIO_INTEGRATOR_H
21+
#define PYMIO_INTEGRATOR_H
22+
23+
#include "memilio/math/integrator.h"
24+
#include "memilio/math/euler.h"
25+
#include "memilio/math/adapt_rk.h"
26+
#include "memilio/math/stepper_wrapper.h"
27+
#include <Eigen/Dense>
28+
#include "pybind_util.h"
29+
30+
#include "pybind11/pybind11.h"
31+
#include <pybind11/eigen.h>
32+
33+
namespace pymio
34+
{
35+
36+
void bind_Integrator_Core(pybind11::module_& m)
37+
{
38+
pymio::bind_class<mio::IntegratorCore<double>, pymio::EnablePickling::Never, std::shared_ptr<mio::IntegratorCore<double>>>(m, "IntegratorCore")
39+
.def_property(
40+
"dt_max", pybind11::overload_cast<>(&mio::IntegratorCore<double>::get_dt_max, pybind11::const_),
41+
[](mio::IntegratorCore<double>& self, double dt_max) {
42+
self.get_dt_max() = dt_max;
43+
})
44+
.def_property(
45+
"dt_min", pybind11::overload_cast<>(&mio::IntegratorCore<double>::get_dt_min, pybind11::const_),
46+
[](mio::IntegratorCore<double>& self, double dt_min) {
47+
self.get_dt_min() = dt_min;
48+
});
49+
50+
pymio::bind_class<mio::EulerIntegratorCore<double>, pymio::EnablePickling::Never, mio::IntegratorCore<double>, std::shared_ptr<mio::EulerIntegratorCore<double>>>(m, "EulerIntegratorCore")
51+
.def(pybind11::init<>())
52+
.def("step",
53+
[](const mio::EulerIntegratorCore<double>& self, pybind11::function f, Eigen::Ref<const Eigen::VectorXd> yt, double t, double dt, Eigen::Ref<Eigen::VectorXd> ytp1) {
54+
bool result = self.step(
55+
[f](Eigen::Ref<const Eigen::VectorXd> y, double t, Eigen::Ref<Eigen::VectorXd> dydt) {
56+
f(y, t, dydt);
57+
},
58+
yt, t, dt, ytp1
59+
);
60+
return result;
61+
}, pybind11::arg("f"), pybind11::arg("yt"), pybind11::arg("t"), pybind11::arg("dt"), pybind11::arg("ytp1")
62+
);
63+
64+
using RungeKuttaCashKarp54Integrator = mio::ControlledStepperWrapper<double, boost::numeric::odeint::runge_kutta_cash_karp54>;
65+
pymio::bind_class<RungeKuttaCashKarp54Integrator, pymio::EnablePickling::Never, mio::IntegratorCore<double>, std::shared_ptr<RungeKuttaCashKarp54Integrator>>(m, "RungeKuttaCashKarp54IntegratorCore")
66+
.def(pybind11::init<>())
67+
.def(pybind11::init<const double, const double, const double, const double>(), pybind11::arg("abs_tol"), pybind11::arg("rel_tol"), pybind11::arg("dt_min"), pybind11::arg("dt_max"))
68+
.def("set_abs_tolerance", &RungeKuttaCashKarp54Integrator::set_abs_tolerance, pybind11::arg("tol"))
69+
.def("set_rel_tolerance", &RungeKuttaCashKarp54Integrator::set_rel_tolerance, pybind11::arg("tol"));
70+
71+
pymio::bind_class<mio::RKIntegratorCore<double>, pymio::EnablePickling::Never, mio::IntegratorCore<double>, std::shared_ptr<mio::RKIntegratorCore<double>>>(m, "RKIntegratorCore")
72+
.def(pybind11::init<>())
73+
.def(pybind11::init<double, double, double, double>(),
74+
pybind11::arg("abs_tol") = 1e-10,
75+
pybind11::arg("rel_tol") = 1e-5,
76+
pybind11::arg("dt_min") = std::numeric_limits<double>::min(),
77+
pybind11::arg("dt_max") = std::numeric_limits<double>::max())
78+
.def("set_abs_tolerance", &mio::RKIntegratorCore<double>::set_abs_tolerance, pybind11::arg("tol"))
79+
.def("set_rel_tolerance", &mio::RKIntegratorCore<double>::set_rel_tolerance, pybind11::arg("tol"));
80+
81+
}
82+
} // namespace pymio
83+
84+
#endif //PYMIO_INTEGRATOR_H

pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -213,22 +213,15 @@ PYBIND11_MODULE(_simulation_osecir, m)
213213
.def(py::init<int>(), py::arg("num_agegroups"));
214214

215215
pymio::bind_Simulation<mio::osecir::Simulation<>>(m, "Simulation");
216+
pymio::bind_Flow_Simulation<mio::osecir::Simulation<double, mio::FlowSimulation<double, mio::osecir::Model<double>>>>(m, "FlowSimulation");
216217

217218
m.def(
218-
"simulate",
219-
[](double t0, double tmax, double dt, const mio::osecir::Model<double>& model) {
220-
return mio::osecir::simulate(t0, tmax, dt, model);
221-
},
222-
"Simulates an ODE SECIHURD model from t0 to tmax.", py::arg("t0"), py::arg("tmax"), py::arg("dt"),
223-
py::arg("model"));
219+
"simulate", &mio::osecir::simulate<double>, "Simulates an ODE SECIHURD model from t0 to tmax.",
220+
py::arg("t0"), py::arg("tmax"), py::arg("dt"), py::arg("model"), py::arg("integrator") = py::none());
224221

225222
m.def(
226-
"simulate_flows",
227-
[](double t0, double tmax, double dt, const mio::osecir::Model<double>& model) {
228-
return mio::osecir::simulate_flows<double>(t0, tmax, dt, model);
229-
},
230-
"Simulates an ODE SECIHURD model with flows from t0 to tmax.", py::arg("t0"), py::arg("tmax"), py::arg("dt"),
231-
py::arg("model"));
223+
"simulate_flows", &mio::osecir::simulate_flows<double>, "Simulates an ODE SECIHURD model with flows from t0 to tmax.",
224+
py::arg("t0"), py::arg("tmax"), py::arg("dt"), py::arg("model"), py::arg("integrator") = py::none());
232225

233226
pymio::bind_ModelNode<mio::osecir::Model<double>>(m, "ModelNode");
234227
pymio::bind_SimulationNode<mio::osecir::Simulation<>>(m, "SimulationNode");

pycode/memilio-simulation/memilio/simulation/bindings/models/osecirvvs.cpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "pybind_util.h"
2323
#include "utils/parameter_set.h"
2424
#include "compartments/simulation.h"
25+
#include "compartments/flow_simulation.h"
2526
#include "compartments/compartmentalmodel.h"
2627
#include "mobility/graph_simulation.h"
2728
#include "mobility/metapopulation_mobility_instant.h"
@@ -262,22 +263,16 @@ PYBIND11_MODULE(_simulation_osecirvvs, m)
262263
.def(py::init<int>(), py::arg("num_agegroups"));
263264

264265
pymio::bind_Simulation<mio::osecirvvs::Simulation<>>(m, "Simulation");
266+
pymio::bind_Flow_Simulation<mio::osecirvvs::Simulation<double, mio::FlowSimulation<double, mio::osecirvvs::Model<double>>>>(m, "FlowSimulation");
265267

266268
m.def(
267-
"simulate",
268-
[](double t0, double tmax, double dt, const mio::osecirvvs::Model<double>& model) {
269-
return mio::osecirvvs::simulate(t0, tmax, dt, model);
270-
},
271-
"Simulates an ODE SECIRVVS model from t0 to tmax.", py::arg("t0"), py::arg("tmax"), py::arg("dt"),
272-
py::arg("model"));
269+
"simulate", &mio::osecirvvs::simulate<double>, "Simulates an ODE SECIRVVS model from t0 to tmax.",
270+
py::arg("t0"), py::arg("tmax"), py::arg("dt"), py::arg("model"), py::arg("integrator") = py::none());
273271

274272
m.def(
275-
"simulate_flows",
276-
[](double t0, double tmax, double dt, const mio::osecirvvs::Model<double>& model) {
277-
return mio::osecirvvs::simulate_flows(t0, tmax, dt, model);
278-
},
279-
"Simulates an ODE SECIRVVS model with flows from t0 to tmax.", py::arg("t0"), py::arg("tmax"), py::arg("dt"),
280-
py::arg("model"));
273+
"simulate_flows", &mio::osecirvvs::simulate_flows<double>, "Simulates an ODE SECIRVVS model with flows from t0 to tmax.",
274+
py::arg("t0"), py::arg("tmax"), py::arg("dt"), py::arg("model"), py::arg("integrator") = py::none());
275+
281276

282277
pymio::bind_ModelNode<mio::osecirvvs::Model<double>>(m, "ModelNode");
283278
pymio::bind_SimulationNode<mio::osecirvvs::Simulation<>>(m, "SimulationNode");

pycode/memilio-simulation/memilio/simulation/bindings/models/oseir.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,16 @@ PYBIND11_MODULE(_simulation_oseir, m)
8787
m, "Model")
8888
.def(py::init<int>(), py::arg("num_agegroups"));
8989

90+
pymio::bind_Simulation<mio::Simulation<double, mio::oseir::Model<double>>>(m, "Simulation");
91+
pymio::bind_Flow_Simulation<mio::FlowSimulation<double, mio::oseir::Model<double>>>(m, "FlowSimulation");
92+
9093
m.def(
91-
"simulate",
92-
[](double t0, double tmax, double dt, const mio::oseir::Model<double>& model) {
93-
return mio::simulate(t0, tmax, dt, model);
94-
},
95-
"Simulates an ODE SEIR from t0 to tmax.", py::arg("t0"), py::arg("tmax"), py::arg("dt"), py::arg("model"));
94+
"simulate", &mio::simulate<double, mio::oseir::Model<double>>, "Simulates an ODE SEIR from t0 to tmax.",
95+
py::arg("t0"), py::arg("tmax"), py::arg("dt"), py::arg("model"), py::arg("integrator") = py::none());
9696

9797
m.def(
98-
"simulate_flows",
99-
[](double t0, double tmax, double dt, const mio::oseir::Model<double>& model) {
100-
return mio::simulate_flows(t0, tmax, dt, model);
101-
},
102-
"Simulates an ODE SEIR with flows from t0 to tmax.", py::arg("t0"), py::arg("tmax"), py::arg("dt"),
103-
py::arg("model"));
98+
"simulate_flows", &mio::simulate_flows<double, mio::oseir::Model<double>>, "Simulates an ODE SEIR with flows from t0 to tmax.",
99+
py::arg("t0"), py::arg("tmax"), py::arg("dt"), py::arg("model"), py::arg("integrator") = py::none());
104100

105101
m.attr("__version__") = "dev";
106102
}

pycode/memilio-simulation/memilio/simulation/bindings/models/osir.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,11 @@ PYBIND11_MODULE(_simulation_osir, m)
8585
m, "Model")
8686
.def(py::init<int>(), py::arg("num_agegroups"));
8787

88+
pymio::bind_Simulation<mio::Simulation<double, mio::osir::Model<double>>>(m, "Simulation");
89+
8890
m.def(
89-
"simulate",
90-
[](double t0, double tmax, double dt, const mio::osir::Model<double>& model) {
91-
return mio::simulate(t0, tmax, dt, model);
92-
},
93-
"Simulates an ODE SIR model from t0 to tmax.", py::arg("t0"), py::arg("tmax"), py::arg("dt"), py::arg("model"));
91+
"simulate", &mio::simulate<double, mio::osir::Model<double>>, "Simulates an ODE SIR model from t0 to tmax.",
92+
py::arg("t0"), py::arg("tmax"), py::arg("dt"), py::arg("model"), py::arg("integrator") = py::none());
9493

9594
m.attr("__version__") = "dev";
9695
}

pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "epidemiology/uncertain_matrix.h"
2828
#include "epidemiology/dynamic_npis.h"
2929
#include "epidemiology/simulation_day.h"
30+
#include "math/integrator.h"
3031
#include "mobility/metapopulation_mobility_instant.h"
3132
#include "utils/date.h"
3233
#include "utils/logging.h"
@@ -74,6 +75,8 @@ PYBIND11_MODULE(_simulation, m)
7475

7576
pymio::bind_time_series(m, "TimeSeries");
7677

78+
pymio::bind_Integrator_Core(m);
79+
7780
auto contact_matrix_class =
7881
pymio::bind_class<mio::ContactMatrix, pymio::EnablePickling::Required>(m, "ContactMatrix");
7982
pymio::bind_damping_expression_members(contact_matrix_class);
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#############################################################################
2+
# Copyright (C) 2020-2024 MEmilio
3+
#
4+
# Authors: Maximilian Betz
5+
#
6+
# Contact: Martin J. Kuehn <[email protected]>
7+
#
8+
# Licensed under the Apache License, Version 2.0 (the "License");
9+
# you may not use this file except in compliance with the License.
10+
# You may obtain a copy of the License at
11+
#
12+
# http://www.apache.org/licenses/LICENSE-2.0
13+
#
14+
# Unless required by applicable law or agreed to in writing, software
15+
# distributed under the License is distributed on an "AS IS" BASIS,
16+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
# See the License for the specific language governing permissions and
18+
# limitations under the License.
19+
#############################################################################
20+
import unittest
21+
22+
import numpy as np
23+
24+
import memilio.simulation as mio
25+
from memilio.simulation.osir import Model, simulate, InfectionState
26+
27+
28+
class Test_Integrators(unittest.TestCase):
29+
def test_euler_step(self):
30+
31+
# Define a simple deriv function for step
32+
def deriv_function(y, t, dydt):
33+
dydt[:] = -y
34+
35+
integrator = mio.EulerIntegratorCore()
36+
yt = np.array([1.0, 2.0])
37+
t = 0.0
38+
dt = 0.1
39+
ytp1 = np.zeros_like(yt)
40+
41+
result = integrator.step(deriv_function, yt, t, dt, ytp1)
42+
self.assertTrue(result)
43+
self.assertTrue((ytp1 == [0.9, 1.8]).all())
44+
45+
def test_model_integration(self):
46+
47+
model = Model(1)
48+
A0 = mio.AgeGroup(0)
49+
50+
# Compartment transition duration
51+
model.parameters.TimeInfected[A0] = 6.
52+
53+
# Compartment transition propabilities
54+
model.parameters.TransmissionProbabilityOnContact[A0] = 1.
55+
56+
# Initial number of people in each compartment
57+
model.populations[A0, InfectionState.Infected] = 50
58+
model.populations[A0, InfectionState.Recovered] = 10
59+
model.populations.set_difference_from_total(
60+
(A0, InfectionState.Susceptible), 8000)
61+
62+
integrator = mio.RKIntegratorCore()
63+
result1 = simulate(0, 5, 1, model, integrator)
64+
65+
dt_max = 0.5
66+
integrator.dt_max = dt_max
67+
result2 = simulate(0, 5, 0.5, model, integrator)
68+
69+
self.assertTrue((dt_max >= np.diff(result2.as_ndarray()[0, :])).all())
70+
self.assertFalse((result1.get_last_value() ==
71+
result2.get_last_value()).all())
72+
73+
74+
if __name__ == '__main__':
75+
unittest.main()

0 commit comments

Comments
 (0)