-
Notifications
You must be signed in to change notification settings - Fork 504
fix: prevent async RL dispatch crashes on uneven batches #1225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -452,11 +452,30 @@ def _custom_function_call( | |
| **kwargs, | ||
| ): | ||
| """Dispatch method call to workers via the appropriate path.""" | ||
| group_size = kwargs.get("group_size", 1) | ||
| orig_len = None | ||
| for arg in args: | ||
| if isinstance(arg, list) and arg and _is_tensor_like(arg): | ||
| orig_len = len(arg) | ||
| break | ||
|
|
||
| args, kwargs = self._pad_eval_dispatch_args( | ||
| args, kwargs, group_size=group_size | ||
| ) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is a potential |
||
| dp_args, dp_kwargs, group_indices = self._prepare_dispatch(*args, **kwargs) | ||
| results = run_async_task( | ||
| self._call_workers, method, dp_args, dp_kwargs, rpc_meta=rpc_meta | ||
| ) | ||
| return self._collect_results(results, group_indices) | ||
| merged_results = self._collect_results(results, group_indices) | ||
|
|
||
| if ( | ||
| orig_len is not None | ||
| and isinstance(merged_results, list) | ||
| and len(merged_results) > orig_len | ||
| ): | ||
| merged_results = merged_results[:orig_len] | ||
|
|
||
| return merged_results | ||
|
|
||
| async def _async_custom_function_call( | ||
| self, | ||
|
|
@@ -466,11 +485,30 @@ async def _async_custom_function_call( | |
| **kwargs, | ||
| ): | ||
| """Async version of _custom_function_call.""" | ||
| group_size = kwargs.get("group_size", 1) | ||
| orig_len = None | ||
| for arg in args: | ||
| if isinstance(arg, list) and arg and _is_tensor_like(arg): | ||
| orig_len = len(arg) | ||
| break | ||
|
|
||
| args, kwargs = self._pad_eval_dispatch_args( | ||
| args, kwargs, group_size=group_size | ||
| ) | ||
| dp_args, dp_kwargs, group_indices = self._prepare_dispatch(*args, **kwargs) | ||
| results = await self._call_workers( | ||
| method, dp_args, dp_kwargs, rpc_meta=rpc_meta | ||
| ) | ||
| return self._collect_results(results, group_indices) | ||
| merged_results = self._collect_results(results, group_indices) | ||
|
|
||
| if ( | ||
| orig_len is not None | ||
| and isinstance(merged_results, list) | ||
| and len(merged_results) > orig_len | ||
| ): | ||
| merged_results = merged_results[:orig_len] | ||
|
|
||
| return merged_results | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The padding and trimming logic in |
||
|
|
||
| def _pad_eval_dispatch_args( | ||
| self, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from types import SimpleNamespace | ||
| from unittest.mock import patch | ||
|
|
||
| import torch | ||
|
|
||
| from areal.infra.controller.train_controller import TrainController | ||
| from areal.utils.seqpack import balanced_greedy_partition | ||
|
|
||
|
|
||
| def test_balanced_greedy_partition_handles_remainders(): | ||
| groups = balanced_greedy_partition([10, 8, 6, 4, 2], K=3) | ||
|
|
||
| assert len(groups) == 3 | ||
| assert sorted(len(group) for group in groups) == [1, 2, 2] | ||
| assert sorted(idx for group in groups for idx in group) == [0, 1, 2, 3, 4] | ||
|
|
||
|
|
||
| def test_custom_function_call_trims_padded_eval_items(): | ||
| controller = TrainController.__new__(TrainController) | ||
| controller.train_alloc = SimpleNamespace(parallel=SimpleNamespace(dp_size=4)) | ||
|
|
||
| sample = {"attention_mask": torch.ones(1, dtype=torch.long)} | ||
| captured: dict[str, int] = {} | ||
|
|
||
| def fake_prepare_dispatch(*args, **kwargs): | ||
| captured["padded_len"] = len(args[0]) | ||
| return [], {}, None | ||
|
|
||
| controller._prepare_dispatch = fake_prepare_dispatch | ||
| controller._call_workers = lambda *args, **kwargs: [] | ||
| controller._collect_results = ( | ||
| lambda results, group_indices: list(range(captured["padded_len"])) | ||
| ) | ||
|
|
||
| with patch( | ||
| "areal.infra.controller.train_controller.make_dummy_eval_item", | ||
| lambda template: {"attention_mask": torch.zeros_like(template["attention_mask"])}, | ||
| ), patch( | ||
| "areal.infra.controller.train_controller.run_async_task", | ||
| lambda fn, *args, **kwargs: [], | ||
| ): | ||
| result = controller._custom_function_call("noop", [sample] * 5) | ||
|
|
||
| assert captured["padded_len"] == 8 | ||
| assert result == [0, 1, 2, 3, 4] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `orig_len` detection logic only inspects `args`, ignoring `kwargs`. If a tensor-like batch is passed via `kwargs`, it will not be padded (as `_pad_eval_dispatch_args` also only handles `args`), and the results will not be trimmed. This inconsistency could lead to crashes or incorrect behavior in engines that require even batches across data-parallel ranks. Consider extending this detection and the padding logic to include `kwargs`.