diff --git a/torchdiffeq/__init__.py b/torchdiffeq/__init__.py
index 4eff75dbf..74712bacf 100644
--- a/torchdiffeq/__init__.py
+++ b/torchdiffeq/__init__.py
@@ -1,4 +1,5 @@
 from ._impl import odeint
 from ._impl import odeint_adjoint
 from ._impl import odeint_event
+from ._impl import RejectStepError
 __version__ = "0.2.3"
diff --git a/torchdiffeq/_impl/__init__.py b/torchdiffeq/_impl/__init__.py
index 05b671e9c..6cbcbace2 100644
--- a/torchdiffeq/_impl/__init__.py
+++ b/torchdiffeq/_impl/__init__.py
@@ -1,2 +1,3 @@
 from .odeint import odeint, odeint_event
 from .adjoint import odeint_adjoint
+from .rk_common import RejectStepError
\ No newline at end of file
diff --git a/torchdiffeq/_impl/rk_common.py b/torchdiffeq/_impl/rk_common.py
index 4d85afb18..98d064a6c 100644
--- a/torchdiffeq/_impl/rk_common.py
+++ b/torchdiffeq/_impl/rk_common.py
@@ -38,6 +38,10 @@ def backward(ctx, grad_scratch):
         return grad_scratch, grad_scratch[ctx.index], None
 
 
+class RejectStepError(Exception):
+    pass
+
+
 def _runge_kutta_step(func, y0, f0, t0, dt, t1, tableau):
     """Take an arbitrary Runge-Kutta step and estimate error.
     Args:
@@ -262,28 +266,35 @@ def _adaptive_step(self, rk_state):
         # Must be arranged as doing all the step_t handling, then all the jump_t handling, in case we
         # trigger both. (i.e. interleaving them would be wrong.)
-
-        y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, t1, tableau=self.tableau)
-        # dtypes:
-        # y1.dtype == self.y0.dtype
-        # f1.dtype == self.y0.dtype
-        # y1_error.dtype == self.dtype
-        # k.dtype == self.y0.dtype
-
-        ########################################################
-        #                     Error Ratio                      #
-        ########################################################
-        error_ratio = _compute_error_ratio(y1_error, self.rtol, self.atol, y0, y1, self.norm)
-        accept_step = error_ratio <= 1
-
-        # Handle min max stepping
-        if dt > self.max_step:
+        try:
+            y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, t1, tableau=self.tableau)
+            # dtypes:
+            # y1.dtype == self.y0.dtype
+            # f1.dtype == self.y0.dtype
+            # y1_error.dtype == self.dtype
+            # k.dtype == self.y0.dtype
+        except RejectStepError as ex:
+            # self.func requested the step be rejected
+            # If already at minimum step size, stop integration as can't proceed
+            if dt <= self.min_step:
+                raise(ex)
+            error_ratio = torch.tensor(10.0, dtype=self.dtype, device=self.y0.device)
             accept_step = False
-        if dt <= self.min_step:
-            accept_step = True
-
-        # dtypes:
-        # error_ratio.dtype == self.dtype
+        else:
+            ########################################################
+            #                     Error Ratio                      #
+            ########################################################
+            error_ratio = _compute_error_ratio(y1_error, self.rtol, self.atol, y0, y1, self.norm)
+            accept_step = error_ratio <= 1
+
+            # Handle min max stepping
+            if dt > self.max_step:
+                accept_step = False
+            if dt <= self.min_step:
+                accept_step = True
+
+            # dtypes:
+            # error_ratio.dtype == self.dtype
 
         ########################################################
         #                   Update RK State                    #
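
A minimal usage sketch (not part of the diff): with this change, an ODE function can raise the newly exported `RejectStepError` during an adaptive step, and `_adaptive_step` treats the attempt like a failed error test and retries with a smaller `dt` (or re-raises if the step size is already at its minimum). The dynamics class, the threshold, and the tolerances below are illustrative assumptions; only `torchdiffeq.odeint` and the new `RejectStepError` come from the patched library.

```python
import torch
from torchdiffeq import odeint, RejectStepError


class GuardedDynamics(torch.nn.Module):
    """Toy dynamics that ask the solver to reject steps leaving a trusted region."""

    def forward(self, t, y):
        # Hypothetical guard: if the solver probes a state we consider invalid,
        # raise RejectStepError so the adaptive stepper shrinks dt and retries
        # instead of integrating through a bad evaluation.
        if torch.any(torch.abs(y) > 1e3):
            raise RejectStepError("state left the trusted region")
        return -y  # simple exponential decay


y0 = torch.tensor([1.0])
t = torch.linspace(0.0, 5.0, 10)

# Default method is the adaptive dopri5 solver, so the rejection path above
# is exercised whenever the guard fires during a trial step.
ys = odeint(GuardedDynamics(), y0, t, rtol=1e-6, atol=1e-8)
```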