7
7
8
8
import numpy as np
9
9
import simplekml
10
- from scipy import integrate
10
+ from scipy . integrate import BDF , DOP853 , LSODA , RK23 , RK45 , OdeSolver , Radau
11
11
12
12
from ..mathutils .function import Function , funcify_method
13
13
from ..mathutils .vector_matrix import Matrix , Vector
24
24
quaternions_to_spin ,
25
25
)
26
26
27
+ ODE_SOLVER_MAP = {
28
+ 'RK23' : RK23 ,
29
+ 'RK45' : RK45 ,
30
+ 'DOP853' : DOP853 ,
31
+ 'Radau' : Radau ,
32
+ 'BDF' : BDF ,
33
+ 'LSODA' : LSODA ,
34
+ }
27
35
28
- class Flight : # pylint: disable=too-many-public-methods
36
+
37
+ # pylint: disable=too-many-public-methods
38
+ # pylint: disable=too-many-instance-attributes
39
+ class Flight :
29
40
"""Keeps all flight information and has a method to simulate flight.
30
41
31
42
Attributes
@@ -506,6 +517,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
506
517
verbose = False ,
507
518
name = "Flight" ,
508
519
equations_of_motion = "standard" ,
520
+ ode_solver = "LSODA" ,
509
521
):
510
522
"""Run a trajectory simulation.
511
523
@@ -581,10 +593,23 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
581
593
more restricted set of equations of motion that only works for
582
594
solid propulsion rockets. Such equations were used in RocketPy v0
583
595
and are kept here for backwards compatibility.
596
+ ode_solver : str, ``scipy.integrate.OdeSolver``, optional
597
+ Integration method to use to solve the equations of motion ODE.
598
+ Available options are: 'RK23', 'RK45', 'DOP853', 'Radau', 'BDF',
599
+ 'LSODA' from ``scipy.integrate.solve_ivp``.
600
+ Default is 'LSODA', which is recommended for most flights.
601
+ A custom ``scipy.integrate.OdeSolver`` can be passed as well.
602
+ For more information on the integration methods, see the scipy
603
+ documentation [1]_.
604
+
584
605
585
606
Returns
586
607
-------
587
608
None
609
+
610
+ References
611
+ ----------
612
+ .. [1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html
588
613
"""
589
614
# Save arguments
590
615
self .env = environment
@@ -605,6 +630,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
605
630
self .terminate_on_apogee = terminate_on_apogee
606
631
self .name = name
607
632
self .equations_of_motion = equations_of_motion
633
+ self .ode_solver = ode_solver
608
634
609
635
# Controller initialization
610
636
self .__init_controllers ()
@@ -651,15 +677,16 @@ def __simulate(self, verbose):
651
677
652
678
# Create solver for this flight phase # TODO: allow different integrators
653
679
self .function_evaluations .append (0 )
654
- phase .solver = integrate .LSODA (
680
+
681
+ phase .solver = self ._solver (
655
682
phase .derivative ,
656
683
t0 = phase .t ,
657
684
y0 = self .y_sol ,
658
685
t_bound = phase .time_bound ,
659
- min_step = self .min_time_step ,
660
- max_step = self .max_time_step ,
661
686
rtol = self .rtol ,
662
687
atol = self .atol ,
688
+ max_step = self .max_time_step ,
689
+ min_step = self .min_time_step ,
663
690
)
664
691
665
692
# Initialize phase time nodes
@@ -691,13 +718,14 @@ def __simulate(self, verbose):
691
718
for node_index , node in self .time_iterator (phase .time_nodes ):
692
719
# Determine time bound for this time node
693
720
node .time_bound = phase .time_nodes [node_index + 1 ].t
694
- # NOTE: Setting the time bound and status for the phase solver,
695
- # and updating its internal state for the next integration step.
696
721
phase .solver .t_bound = node .time_bound
697
- phase .solver ._lsoda_solver ._integrator .rwork [0 ] = phase .solver .t_bound
698
- phase .solver ._lsoda_solver ._integrator .call_args [4 ] = (
699
- phase .solver ._lsoda_solver ._integrator .rwork
700
- )
722
+ if self .__is_lsoda :
723
+ phase .solver ._lsoda_solver ._integrator .rwork [0 ] = (
724
+ phase .solver .t_bound
725
+ )
726
+ phase .solver ._lsoda_solver ._integrator .call_args [4 ] = (
727
+ phase .solver ._lsoda_solver ._integrator .rwork
728
+ )
701
729
phase .solver .status = "running"
702
730
703
731
# Feed required parachute and discrete controller triggers
@@ -1185,6 +1213,8 @@ def __init_solver_monitors(self):
1185
1213
self .t = self .solution [- 1 ][0 ]
1186
1214
self .y_sol = self .solution [- 1 ][1 :]
1187
1215
1216
+ self .__set_ode_solver (self .ode_solver )
1217
+
1188
1218
def __init_equations_of_motion (self ):
1189
1219
"""Initialize equations of motion."""
1190
1220
if self .equations_of_motion == "solid_propulsion" :
@@ -1222,6 +1252,28 @@ def __cache_sensor_data(self):
1222
1252
sensor_data [sensor ] = sensor .measured_data [:]
1223
1253
self .sensor_data = sensor_data
1224
1254
1255
+ def __set_ode_solver (self , solver ):
1256
+ """Sets the ODE solver to be used in the simulation.
1257
+
1258
+ Parameters
1259
+ ----------
1260
+ solver : str, ``scipy.integrate.OdeSolver``
1261
+ Integration method to use to solve the equations of motion ODE,
1262
+ or a custom ``scipy.integrate.OdeSolver``.
1263
+ """
1264
+ if isinstance (solver , OdeSolver ):
1265
+ self ._solver = solver
1266
+ else :
1267
+ try :
1268
+ self ._solver = ODE_SOLVER_MAP [solver ]
1269
+ except KeyError as e :
1270
+ raise ValueError (
1271
+ f"Invalid ``ode_solver`` input: { solver } . "
1272
+ f"Available options are: { ', ' .join (ODE_SOLVER_MAP .keys ())} "
1273
+ ) from e
1274
+
1275
+ self .__is_lsoda = hasattr (self ._solver , "_lsoda_solver" )
1276
+
1225
1277
@cached_property
1226
1278
def effective_1rl (self ):
1227
1279
"""Original rail length minus the distance measured from nozzle exit
0 commit comments