diff --git a/deepmd/pd/infer/deep_eval.py b/deepmd/pd/infer/deep_eval.py index 6c0ffed7ec..2465a2b0de 100644 --- a/deepmd/pd/infer/deep_eval.py +++ b/deepmd/pd/infer/deep_eval.py @@ -64,6 +64,9 @@ to_numpy_array, to_paddle_tensor, ) +from deepmd.utils.batch_size import ( + RetrySignal, +) from deepmd.utils.econf_embd import ( sort_element_type, ) @@ -823,18 +826,33 @@ def eval_descriptor( model = ( self.dp.model["Default"] if isinstance(self.dp, ModelWrapper) else self.dp ) + if self.auto_batch_size is not None: + self.auto_batch_size.set_oom_retry_mode(True) model.set_eval_descriptor_hook(True) - self.eval( - coords, - cells, - atom_types, - atomic=False, - fparam=fparam, - aparam=aparam, - **kwargs, - ) - descriptor = model.eval_descriptor() - model.set_eval_descriptor_hook(False) + try: + self.eval( + coords, + cells, + atom_types, + atomic=False, + fparam=fparam, + aparam=aparam, + **kwargs, + ) + except RetrySignal: + return self.eval_descriptor( + coords, + cells, + atom_types, + fparam=fparam, + aparam=aparam, + **kwargs, + ) + finally: + descriptor = model.eval_descriptor() + model.set_eval_descriptor_hook(False) + if self.auto_batch_size is not None: + self.auto_batch_size.set_oom_retry_mode(False) return to_numpy_array(descriptor) def eval_fitting_last_layer( @@ -878,16 +896,31 @@ def eval_fitting_last_layer( Fitting output before last layer. """ model = self.dp.model["Default"] + if self.auto_batch_size is not None: + self.auto_batch_size.set_oom_retry_mode(True) model.set_eval_fitting_last_layer_hook(True) - self.eval( - coords, - cells, - atom_types, - atomic=False, - fparam=fparam, - aparam=aparam, - **kwargs, - ) - fitting_net = model.eval_fitting_last_layer() - model.set_eval_fitting_last_layer_hook(False) + try: + self.eval( + coords, + cells, + atom_types, + atomic=False, + fparam=fparam, + aparam=aparam, + **kwargs, + ) + except RetrySignal: + return self.eval_descriptor( + coords, + cells, + atom_types, + fparam=fparam, + aparam=aparam, + **kwargs, + ) + finally: + fitting_net = model.eval_fitting_last_layer() + model.set_eval_fitting_last_layer_hook(False) + if self.auto_batch_size is not None: + self.auto_batch_size.set_oom_retry_mode(False) return to_numpy_array(fitting_net) diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 6e63ecb2fc..be90fcea78 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -67,6 +67,9 @@ to_numpy_array, to_torch_tensor, ) +from deepmd.utils.batch_size import ( + RetrySignal, +) from deepmd.utils.econf_embd import ( sort_element_type, ) @@ -793,18 +796,33 @@ def eval_descriptor( Descriptors. """ model = self.dp.model["Default"] + if self.auto_batch_size is not None: + self.auto_batch_size.set_oom_retry_mode(True) model.set_eval_descriptor_hook(True) - self.eval( - coords, - cells, - atom_types, - atomic=False, - fparam=fparam, - aparam=aparam, - **kwargs, - ) - descriptor = model.eval_descriptor() - model.set_eval_descriptor_hook(False) + try: + self.eval( + coords, + cells, + atom_types, + atomic=False, + fparam=fparam, + aparam=aparam, + **kwargs, + ) + except RetrySignal: + return self.eval_descriptor( + coords, + cells, + atom_types, + fparam=fparam, + aparam=aparam, + **kwargs, + ) + finally: + descriptor = model.eval_descriptor() + model.set_eval_descriptor_hook(False) + if self.auto_batch_size is not None: + self.auto_batch_size.set_oom_retry_mode(False) return to_numpy_array(descriptor) def eval_fitting_last_layer( @@ -848,16 +866,31 @@ def eval_fitting_last_layer( Fitting output before last layer. """ model = self.dp.model["Default"] + if self.auto_batch_size is not None: + self.auto_batch_size.set_oom_retry_mode(True) model.set_eval_fitting_last_layer_hook(True) - self.eval( - coords, - cells, - atom_types, - atomic=False, - fparam=fparam, - aparam=aparam, - **kwargs, - ) - fitting_net = model.eval_fitting_last_layer() - model.set_eval_fitting_last_layer_hook(False) + try: + self.eval( + coords, + cells, + atom_types, + atomic=False, + fparam=fparam, + aparam=aparam, + **kwargs, + ) + except RetrySignal: + return self.eval_descriptor( + coords, + cells, + atom_types, + fparam=fparam, + aparam=aparam, + **kwargs, + ) + finally: + fitting_net = model.eval_fitting_last_layer() + model.set_eval_fitting_last_layer_hook(False) + if self.auto_batch_size is not None: + self.auto_batch_size.set_oom_retry_mode(False) return to_numpy_array(fitting_net) diff --git a/deepmd/utils/batch_size.py b/deepmd/utils/batch_size.py index e701e82ec6..82de03695c 100644 --- a/deepmd/utils/batch_size.py +++ b/deepmd/utils/batch_size.py @@ -22,6 +22,10 @@ log = logging.getLogger(__name__) +class RetrySignal(Exception): + """Signal to retry execution after OOM error.""" + + class AutoBatchSize(ABC): """This class allows DeePMD-kit to automatically decide the maximum batch size that will not cause an OOM error. @@ -75,6 +79,7 @@ def __init__(self, initial_batch_size: int = 1024, factor: float = 2.0) -> None: ) self.factor = factor + self.oom_retry_mode = False def execute( self, callable: Callable, start_index: int, natoms: int @@ -125,6 +130,8 @@ def execute( ) from e # adjust the next batch size self._adjust_batch_size(1.0 / self.factor) + if self.oom_retry_mode: + raise RetrySignal from e return 0, None else: n_tot = n_batch * natoms @@ -281,3 +288,15 @@ def is_oom_error(self, e: Exception) -> bool: bool True if the exception is an OOM error """ + + def set_oom_retry_mode(self, enable: bool) -> None: + """Set OOM retry mode. + + In OOM retry mode, all data will be re-executed. + + Parameters + ---------- + enable : bool + True to enable OOM retry mode + """ + self.oom_retry_mode = enable