Skip to content

Commit 36323d1

Browse files
lukebaumann, copybara-github
authored and committed
Treat additional error types as potential slice down issues.
This change separates the set of `JaxRuntimeError` types that are considered indicative of a slice being down into `DATA_LOSS` and a list of additional types. `DEADLINE_EXCEEDED`, `NOT_FOUND`, and `INTERNAL` are now treated as errors that may or may not be related to a slice being down; the check still returns true for them, but logs a warning. PiperOrigin-RevId: 802202196
1 parent 315f578 commit 36323d1

File tree

1 file changed

+33
-9
lines changed

1 file changed

+33
-9
lines changed

pathwaysutils/elastic/manager.py

Lines changed: 33 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,9 @@ class Manager:
7070
_SIMPLE_EXECUTION_TEST_VALUE = 100
7171
_ELASTIC_DOWN_ERROR_TYPES = [
7272
"DATA_LOSS",
73+
]
74+
_ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES = [
75+
"DEADLINE_EXCEEDED",
7376
"NOT_FOUND",
7477
"INTERNAL",
7578
]
@@ -168,24 +171,45 @@ def is_error_due_to_slice_down(self, error: Exception) -> bool:
168171
The error types that are considered due to slice down are
169172
jax.errors.JaxRuntimeError with the following error kind in the message:
170173
- DATA_LOSS
174+
- DEADLINE_EXCEEDED
171175
- NOT_FOUND
172176
- INTERNAL
173177
174178
Args:
175179
error: The error to check.
176180
"""
177-
return_value = isinstance(error, jax.errors.JaxRuntimeError) and any(
178-
error_type in str(error)
179-
for error_type in self._ELASTIC_DOWN_ERROR_TYPES
180-
)
181-
if return_value:
182-
_logger.info("Caught an error due to slice down")
183-
else:
181+
error_due_to_slice_down = False
182+
traceback_logging_level = logging.DEBUG
183+
184+
if isinstance(error, jax.errors.JaxRuntimeError):
185+
if any(
186+
error_type in str(error)
187+
for error_type in self._ELASTIC_DOWN_ERROR_TYPES
188+
):
189+
_logger.info("Caught an error due to slice down")
190+
191+
error_due_to_slice_down = True
192+
193+
elif any(
194+
error_type in str(error)
195+
for error_type in self._ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES
196+
):
197+
_logger.warning(
198+
"Caught an error due that may or may not be due to slice down. This"
199+
" error will be treated as due to slice down."
200+
)
201+
traceback_logging_level = logging.WARNING
202+
203+
error_due_to_slice_down = True
204+
205+
if not error_due_to_slice_down:
184206
_logger.info("Caught an error not due to slice down")
185207

186-
_logger.debug("\n".join(traceback.format_exception(error)))
208+
_logger.log(
209+
traceback_logging_level, "\n".join(traceback.format_exception(error))
210+
)
187211

188-
return return_value
212+
return error_due_to_slice_down
189213

190214
def _simple_execution(self, devices: Sequence[jax.Device]) -> jax.Array:
191215
"""Simple execution to test if a slice is available.

0 commit comments

Comments
 (0)