Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions deepmd/pd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,8 @@ def eval_descriptor(
model = (
self.dp.model["Default"] if isinstance(self.dp, ModelWrapper) else self.dp
)
if self.auto_batch_size is not None:
self.auto_batch_size.set_oom_retry_mode(True)
model.set_eval_descriptor_hook(True)
self.eval(
coords,
Expand All @@ -835,6 +837,8 @@ def eval_descriptor(
)
descriptor = model.eval_descriptor()
model.set_eval_descriptor_hook(False)
if self.auto_batch_size is not None:
self.auto_batch_size.set_oom_retry_mode(False)
return to_numpy_array(descriptor)

def eval_fitting_last_layer(
Expand Down Expand Up @@ -878,6 +882,8 @@ def eval_fitting_last_layer(
Fitting output before last layer.
"""
model = self.dp.model["Default"]
if self.auto_batch_size is not None:
self.auto_batch_size.set_oom_retry_mode(True)
model.set_eval_fitting_last_layer_hook(True)
self.eval(
coords,
Expand All @@ -890,4 +896,6 @@ def eval_fitting_last_layer(
)
fitting_net = model.eval_fitting_last_layer()
model.set_eval_fitting_last_layer_hook(False)
if self.auto_batch_size is not None:
self.auto_batch_size.set_oom_retry_mode(False)
return to_numpy_array(fitting_net)
8 changes: 8 additions & 0 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,8 @@ def eval_descriptor(
Descriptors.
"""
model = self.dp.model["Default"]
if self.auto_batch_size is not None:
self.auto_batch_size.set_oom_retry_mode(True)
model.set_eval_descriptor_hook(True)
self.eval(
coords,
Expand All @@ -805,6 +807,8 @@ def eval_descriptor(
)
descriptor = model.eval_descriptor()
model.set_eval_descriptor_hook(False)
if self.auto_batch_size is not None:
self.auto_batch_size.set_oom_retry_mode(False)
return to_numpy_array(descriptor)

def eval_fitting_last_layer(
Expand Down Expand Up @@ -848,6 +852,8 @@ def eval_fitting_last_layer(
Fitting output before last layer.
"""
model = self.dp.model["Default"]
if self.auto_batch_size is not None:
self.auto_batch_size.set_oom_retry_mode(True)
model.set_eval_fitting_last_layer_hook(True)
self.eval(
coords,
Expand All @@ -860,4 +866,6 @@ def eval_fitting_last_layer(
)
fitting_net = model.eval_fitting_last_layer()
model.set_eval_fitting_last_layer_hook(False)
if self.auto_batch_size is not None:
self.auto_batch_size.set_oom_retry_mode(False)
return to_numpy_array(fitting_net)
48 changes: 48 additions & 0 deletions deepmd/utils/batch_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,38 @@
log = logging.getLogger(__name__)


class RetrySignal(Exception):
    """Control-flow exception requesting that an operation be run again.

    In this file it is raised from the OOM handling path and caught by the
    ``retry`` decorator, which then calls the wrapped function once more.
    """


# originally copied from dpdispatcher
# https://github.com/deepmodeling/dpdispatcher/blob/9a76542311a02e84c4ae62f15b7edcd30850a64e/dpdispatcher/utils/utils.py#L161-L213
# license: LGPL-3.0-or-later
def retry(func: Any) -> Callable:
    """Decorator that re-runs the function every time it raises ``RetrySignal``.

    There is no limit on the number of attempts: the loop ends only when
    ``func`` returns (its result is passed through unchanged) or raises an
    exception other than ``RetrySignal`` (which propagates to the caller).

    Parameters
    ----------
    func : Callable
        The function to wrap.

    Returns
    -------
    wrapper: Callable
        The wrapper. It keeps the metadata (name, docstring, ...) of ``func``.

    Examples
    --------
    >>> @retry
    ... def func():
    ...     raise RetrySignal("Failed")
    """
    # Local import: keeps this block self-contained without touching the
    # module-level import section.
    import functools

    # Preserve ``__name__``/``__doc__`` of the decorated callable so that
    # introspection and generated documentation still describe ``func``.
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        while True:
            try:
                return func(*args, **kwargs)
            except RetrySignal:
                # Start over from scratch with the same arguments.
                log.info("Retry the entire method")

    return wrapper


class AutoBatchSize(ABC):
"""This class allows DeePMD-kit to automatically decide the maximum
batch size that will not cause an OOM error.
Expand Down Expand Up @@ -75,6 +107,7 @@ def __init__(self, initial_batch_size: int = 1024, factor: float = 2.0) -> None:
)

self.factor = factor
self.oom_retry_mode = False

def execute(
self, callable: Callable, start_index: int, natoms: int
Expand Down Expand Up @@ -125,6 +158,8 @@ def execute(
) from e
# adjust the next batch size
self._adjust_batch_size(1.0 / self.factor)
if self.set_oom_retry_mode:
raise RetrySignal from e
return 0, None
else:
n_tot = n_batch * natoms
Expand All @@ -147,6 +182,7 @@ def _adjust_batch_size(self, factor: float) -> None:
f"Adjust batch size from {old_batch_size} to {self.current_batch_size}"
)

@retry
def execute_all(
self,
callable: Callable,
Expand Down Expand Up @@ -281,3 +317,15 @@ def is_oom_error(self, e: Exception) -> bool:
bool
True if the exception is an OOM error
"""

def set_oom_retry_mode(self, enable: bool) -> None:
    """Switch OOM retry mode on or off.

    The value is stored on the instance as ``oom_retry_mode``. While the
    mode is on, all data will be re-executed.

    Parameters
    ----------
    enable : bool
        True turns OOM retry mode on, False turns it off.
    """
    self.oom_retry_mode = enable
Loading