Skip to content

Commit fec6bf0

Browse files
phmbressanGui-FernandesBR
authored andcommitted
ENH: Allow for Alternative and Custom ODE Solvers.
TST: Add slow testing for different ode solvers. MNT: Move ode solver validation to separate method.
1 parent 39d47cf commit fec6bf0

File tree

4 files changed

+103
-13
lines changed

4 files changed

+103
-13
lines changed

.vscode/settings.json

+1
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@
245245
"pytest",
246246
"pytz",
247247
"quantile",
248+
"Radau",
248249
"Rdot",
249250
"referece",
250251
"relativetoground",

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Attention: The newest changes should be on top -->
3232

3333
### Added
3434

35-
-
35+
- ENH: Allow for Alternative and Custom ODE Solvers. [#748](https://github.com/RocketPy-Team/RocketPy/pull/748)
3636

3737
### Changed
3838

rocketpy/simulation/flight.py

+63-11
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import numpy as np
99
import simplekml
10-
from scipy import integrate
10+
from scipy.integrate import BDF, DOP853, LSODA, RK23, RK45, OdeSolver, Radau
1111

1212
from ..mathutils.function import Function, funcify_method
1313
from ..mathutils.vector_matrix import Matrix, Vector
@@ -24,8 +24,19 @@
2424
quaternions_to_spin,
2525
)
2626

27+
ODE_SOLVER_MAP = {
28+
'RK23': RK23,
29+
'RK45': RK45,
30+
'DOP853': DOP853,
31+
'Radau': Radau,
32+
'BDF': BDF,
33+
'LSODA': LSODA,
34+
}
2735

28-
class Flight: # pylint: disable=too-many-public-methods
36+
37+
# pylint: disable=too-many-public-methods
38+
# pylint: disable=too-many-instance-attributes
39+
class Flight:
2940
"""Keeps all flight information and has a method to simulate flight.
3041
3142
Attributes
@@ -506,6 +517,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
506517
verbose=False,
507518
name="Flight",
508519
equations_of_motion="standard",
520+
ode_solver="LSODA",
509521
):
510522
"""Run a trajectory simulation.
511523
@@ -581,10 +593,23 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
581593
more restricted set of equations of motion that only works for
582594
solid propulsion rockets. Such equations were used in RocketPy v0
583595
and are kept here for backwards compatibility.
596+
ode_solver : str, ``scipy.integrate.OdeSolver``, optional
597+
Integration method to use to solve the equations of motion ODE.
598+
Available options are: 'RK23', 'RK45', 'DOP853', 'Radau', 'BDF',
599+
'LSODA' from ``scipy.integrate.solve_ivp``.
600+
Default is 'LSODA', which is recommended for most flights.
601+
A custom ``scipy.integrate.OdeSolver`` can be passed as well.
602+
For more information on the integration methods, see the scipy
603+
documentation [1]_.
604+
584605
585606
Returns
586607
-------
587608
None
609+
610+
References
611+
----------
612+
.. [1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html
588613
"""
589614
# Save arguments
590615
self.env = environment
@@ -605,6 +630,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
605630
self.terminate_on_apogee = terminate_on_apogee
606631
self.name = name
607632
self.equations_of_motion = equations_of_motion
633+
self.ode_solver = ode_solver
608634

609635
# Controller initialization
610636
self.__init_controllers()
@@ -651,15 +677,16 @@ def __simulate(self, verbose):
651677

652678
# Create solver for this flight phase # TODO: allow different integrators
653679
self.function_evaluations.append(0)
654-
phase.solver = integrate.LSODA(
680+
681+
phase.solver = self._solver(
655682
phase.derivative,
656683
t0=phase.t,
657684
y0=self.y_sol,
658685
t_bound=phase.time_bound,
659-
min_step=self.min_time_step,
660-
max_step=self.max_time_step,
661686
rtol=self.rtol,
662687
atol=self.atol,
688+
max_step=self.max_time_step,
689+
min_step=self.min_time_step,
663690
)
664691

665692
# Initialize phase time nodes
@@ -691,13 +718,14 @@ def __simulate(self, verbose):
691718
for node_index, node in self.time_iterator(phase.time_nodes):
692719
# Determine time bound for this time node
693720
node.time_bound = phase.time_nodes[node_index + 1].t
694-
# NOTE: Setting the time bound and status for the phase solver,
695-
# and updating its internal state for the next integration step.
696721
phase.solver.t_bound = node.time_bound
697-
phase.solver._lsoda_solver._integrator.rwork[0] = phase.solver.t_bound
698-
phase.solver._lsoda_solver._integrator.call_args[4] = (
699-
phase.solver._lsoda_solver._integrator.rwork
700-
)
722+
if self.__is_lsoda:
723+
phase.solver._lsoda_solver._integrator.rwork[0] = (
724+
phase.solver.t_bound
725+
)
726+
phase.solver._lsoda_solver._integrator.call_args[4] = (
727+
phase.solver._lsoda_solver._integrator.rwork
728+
)
701729
phase.solver.status = "running"
702730

703731
# Feed required parachute and discrete controller triggers
@@ -1185,6 +1213,8 @@ def __init_solver_monitors(self):
11851213
self.t = self.solution[-1][0]
11861214
self.y_sol = self.solution[-1][1:]
11871215

1216+
self.__set_ode_solver(self.ode_solver)
1217+
11881218
def __init_equations_of_motion(self):
11891219
"""Initialize equations of motion."""
11901220
if self.equations_of_motion == "solid_propulsion":
@@ -1222,6 +1252,28 @@ def __cache_sensor_data(self):
12221252
sensor_data[sensor] = sensor.measured_data[:]
12231253
self.sensor_data = sensor_data
12241254

1255+
def __set_ode_solver(self, solver):
1256+
"""Sets the ODE solver to be used in the simulation.
1257+
1258+
Parameters
1259+
----------
1260+
solver : str, ``scipy.integrate.OdeSolver``
1261+
Integration method to use to solve the equations of motion ODE,
1262+
or a custom ``scipy.integrate.OdeSolver``.
1263+
"""
1264+
if isinstance(solver, OdeSolver):
1265+
self._solver = solver
1266+
else:
1267+
try:
1268+
self._solver = ODE_SOLVER_MAP[solver]
1269+
except KeyError as e:
1270+
raise ValueError(
1271+
f"Invalid ``ode_solver`` input: {solver}. "
1272+
f"Available options are: {', '.join(ODE_SOLVER_MAP.keys())}"
1273+
) from e
1274+
1275+
self.__is_lsoda = hasattr(self._solver, "_lsoda_solver")
1276+
12251277
@cached_property
12261278
def effective_1rl(self):
12271279
"""Original rail length minus the distance measured from nozzle exit

tests/integration/test_flight.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212

1313
@patch("matplotlib.pyplot.show")
14-
def test_all_info(mock_show, flight_calisto_robust): # pylint: disable=unused-argument
14+
# pylint: disable=unused-argument
15+
def test_all_info(mock_show, flight_calisto_robust):
1516
"""Test that the flight class is working as intended. This basically calls
1617
the all_info() method and checks if it returns None. It is not testing if
1718
the values are correct, but whether the method is working without errors.
@@ -27,6 +28,42 @@ def test_all_info(mock_show, flight_calisto_robust): # pylint: disable=unused-a
2728
assert flight_calisto_robust.all_info() is None
2829

2930

31+
@pytest.mark.slow
32+
@patch("matplotlib.pyplot.show")
33+
@pytest.mark.parametrize("solver_method", ["RK45", "DOP853", "Radau", "BDF"])
34+
# RK23 is unstable and requires a very low tolerance to work
35+
# pylint: disable=unused-argument
36+
def test_all_info_different_solvers(
37+
mock_show, calisto_robust, example_spaceport_env, solver_method
38+
):
39+
"""Test that the flight class is working as intended with different solver
40+
methods. This basically calls the all_info() method and checks if it returns
41+
None. It is not testing if the values are correct, but whether the method is
42+
working without errors.
43+
44+
Parameters
45+
----------
46+
mock_show : unittest.mock.MagicMock
47+
Mock object to replace matplotlib.pyplot.show
48+
calisto_robust : rocketpy.Rocket
49+
Rocket to be simulated. See the conftest.py file for more info.
50+
example_spaceport_env : rocketpy.Environment
51+
Environment to be simulated. See the conftest.py file for more info.
52+
solver_method : str
53+
The solver method to be used in the simulation.
54+
"""
55+
test_flight = Flight(
56+
environment=example_spaceport_env,
57+
rocket=calisto_robust,
58+
rail_length=5.2,
59+
inclination=85,
60+
heading=0,
61+
terminate_on_apogee=False,
62+
ode_solver=solver_method,
63+
)
64+
assert test_flight.all_info() is None
65+
66+
3067
class TestExportData:
3168
"""Tests the export_data method of the Flight class."""
3269

0 commit comments

Comments
 (0)