diff --git a/areal/infra/controller/train_controller.py b/areal/infra/controller/train_controller.py index a7c493d5a8..474808f5f6 100644 --- a/areal/infra/controller/train_controller.py +++ b/areal/infra/controller/train_controller.py @@ -123,32 +123,63 @@ def _dispatch_tensors( def _pad_eval_batch( - args: tuple[Any, ...], dp_size: int, group_size: int = 1 -) -> tuple[Any, ...]: - """Pad the first tensor-like arg to a multiple of ``dp_size * group_size``. - - Called before dispatch for explicit evaluation controller paths so that - ``balanced_greedy_partition`` always receives a divisible input. - Dummy items have zero attention/loss masks and contribute nothing - to metrics or loss. + args: tuple[Any, ...], + kwargs: dict[str, Any], + dp_size: int, + group_size: int = 1, +) -> tuple[tuple[Any, ...], dict[str, Any], int | None]: + """Pad all top-level tensor-like inputs to a multiple of ``dp_size * group_size``. + + Called before explicit evaluation dispatch so that tensor partitions stay + aligned across every top-level tensor-like argument/keyword argument. + Dummy items have zero attention/loss masks and contribute nothing to + metrics or loss. """ - result = list(args) - pad_target = dp_size * group_size - for i, arg in enumerate(result): + result_args = list(args) + result_kwargs = dict(kwargs) + tensor_inputs: list[tuple[str, Any, list[Any]]] = [] + + for i, arg in enumerate(result_args): if isinstance(arg, list) and arg and _is_tensor_like(arg): - n = len(arg) - pad_count = (-n) % pad_target - if pad_count > 0: - padded = list(arg) - template = arg[0] - padded.extend(make_dummy_eval_item(template) for _ in range(pad_count)) - result[i] = padded - logger.info( - f"Eval dispatch: padded {pad_count} dummy items " - f"(total {len(padded)}) for dp_size={dp_size}" - ) - break # only pad the first tensor-like arg - return tuple(result) + tensor_inputs.append(("arg", i, arg)) + for key, value in result_kwargs.items(): + if isinstance(value, list) and value and _is_tensor_like(value): + tensor_inputs.append(("kwarg", key, value)) + + if not tensor_inputs: + return tuple(result_args), result_kwargs, None + + lengths = {len(items) for _, _, items in tensor_inputs} + if len(lengths) != 1: + raise ValueError( + "All tensor-like eval dispatch inputs must have the same length, " + f"got lengths {sorted(lengths)}." + ) + orig_len = lengths.pop() + pad_target = dp_size * group_size + pad_count = (-orig_len) % pad_target + if pad_count == 0: + return tuple(result_args), result_kwargs, orig_len + + for location, key, items in tensor_inputs: + padded = list(items) + template = items[0] + padded.extend(make_dummy_eval_item(template) for _ in range(pad_count)) + if location == "arg": + result_args[key] = padded + else: + result_kwargs[key] = padded + + logger.info( + "Eval dispatch: padded %s dummy items for %s tensor-like inputs " + "(total %s) for dp_size=%s group_size=%s", + pad_count, + len(tensor_inputs), + orig_len + pad_count, + dp_size, + group_size, + ) + return tuple(result_args), result_kwargs, orig_len def _merge_tensors( @@ -452,11 +483,16 @@ def _custom_function_call( **kwargs, ): """Dispatch method call to workers via the appropriate path.""" + group_size = kwargs.get("group_size", 1) + args, kwargs, orig_len = self._pad_eval_dispatch_args( + args, kwargs, group_size=group_size + ) dp_args, dp_kwargs, group_indices = self._prepare_dispatch(*args, **kwargs) results = run_async_task( self._call_workers, method, dp_args, dp_kwargs, rpc_meta=rpc_meta ) - return self._collect_results(results, group_indices) + merged_results = self._collect_results(results, group_indices) + return self._trim_padded_eval_results(merged_results, orig_len) async def _async_custom_function_call( self, @@ -466,11 +502,16 @@ async def _async_custom_function_call( **kwargs, ): """Async version of _custom_function_call.""" + group_size = kwargs.get("group_size", 1) + args, kwargs, orig_len = self._pad_eval_dispatch_args( + args, kwargs, group_size=group_size + ) dp_args, dp_kwargs, group_indices = self._prepare_dispatch(*args, **kwargs) results = await self._call_workers( method, dp_args, dp_kwargs, rpc_meta=rpc_meta ) - return self._collect_results(results, group_indices) + merged_results = self._collect_results(results, group_indices) + return self._trim_padded_eval_results(merged_results, orig_len) def _pad_eval_dispatch_args( self, @@ -478,13 +519,20 @@ def _pad_eval_dispatch_args( kwargs: dict[str, Any], *, group_size: int, - ) -> tuple[tuple[Any, ...], dict[str, Any]]: + ) -> tuple[tuple[Any, ...], dict[str, Any], int | None]: """Pad eval batches for explicit algorithm-level evaluation dispatch.""" - kwargs = dict(kwargs) - args = _pad_eval_batch( - args, self.parallel_strategy.dp_size, group_size=group_size + return _pad_eval_batch( + args, + kwargs, + self.parallel_strategy.dp_size, + group_size=group_size, ) - return args, kwargs + + def _trim_padded_eval_results(self, results: Any, orig_len: int | None) -> Any: + """Drop dummy eval outputs that were added only for even DP partitioning.""" + if orig_len is None or not isinstance(results, list) or len(results) <= orig_len: + return results + return results[:orig_len] def _prepare_dispatch( self, *args, **kwargs diff --git a/areal/utils/seqpack.py b/areal/utils/seqpack.py index 175461771f..f1c19b09f0 100644 --- a/areal/utils/seqpack.py +++ b/areal/utils/seqpack.py @@ -553,14 +553,15 @@ def balanced_greedy_partition(nums: list[int], K: int) -> list[list[int]]: List of K lists, where each inner list contains the indices assigned to that group Raises: - ValueError: If len(nums) is not divisible by K or if len(nums) < K + ValueError: If K <= 0 """ + if K <= 0: + raise ValueError(f"K must be positive, got {K}.") + n = len(nums) - if n < K: - raise ValueError(f"Number of items ({n}) must be >= K ({K}).") - if n % K != 0: - raise ValueError("The length of nums must be divisible by K.") - m = n // K + base = n // K + remainder = n % K + capacities = [base + (1 if i < remainder else 0) for i in range(K)] # Sort indices by value in descending order sorted_indices = sorted(range(n), key=lambda i: -nums[i]) @@ -575,10 +576,14 @@ def balanced_greedy_partition(nums: list[int], K: int) -> list[list[int]]: chosen_group = -1 min_sum = float("inf") for i in range(K): - if counts[i] < m and sums[i] < min_sum: + if counts[i] < capacities[i] and sums[i] < min_sum: min_sum = sums[i] chosen_group = i + if chosen_group == -1: + raise RuntimeError( + f"Cannot assign item idx={idx} with capacities={capacities}, counts={counts}" + ) groups[chosen_group].append(idx) sums[chosen_group] += num counts[chosen_group] += 1 diff --git a/tests/test_async_rl_crashes.py b/tests/test_async_rl_crashes.py new file mode 100644 index 0000000000..185a13b757 --- /dev/null +++ b/tests/test_async_rl_crashes.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import patch + +import torch + +from areal.infra.controller.train_controller import TrainController +from areal.utils.seqpack import balanced_greedy_partition + + +def test_balanced_greedy_partition_handles_remainders(): + groups = balanced_greedy_partition([10, 8, 6, 4, 2], K=3) + + assert len(groups) == 3 + assert sorted(len(group) for group in groups) == [1, 2, 2] + assert sorted(idx for group in groups for idx in group) == [0, 1, 2, 3, 4] + + +def _make_controller() -> TrainController: + controller = TrainController.__new__(TrainController) + controller.train_alloc = SimpleNamespace(parallel=SimpleNamespace(dp_size=4)) + return controller + + +def _sample_item() -> dict[str, torch.Tensor]: + return {"attention_mask": torch.ones(1, dtype=torch.long)} + + +def test_custom_function_call_trims_padded_eval_items(): + controller = _make_controller() + sample = _sample_item() + captured: dict[str, int] = {} + + def fake_prepare_dispatch(*args, **kwargs): + captured["padded_len"] = len(args[0]) + return [], {}, None + + controller._prepare_dispatch = fake_prepare_dispatch + controller._call_workers = lambda *args, **kwargs: [] + controller._collect_results = ( + lambda results, group_indices: list(range(captured["padded_len"])) + ) + + with patch( + "areal.infra.controller.train_controller.make_dummy_eval_item", + lambda template: {"attention_mask": torch.zeros_like(template["attention_mask"])}, + ), patch( + "areal.infra.controller.train_controller.run_async_task", + lambda fn, *args, **kwargs: [], + ): + result = controller._custom_function_call("noop", [sample] * 5) + + assert captured["padded_len"] == 8 + assert result == [0, 1, 2, 3, 4] + + +def test_custom_function_call_pads_kwargs_and_parallel_tensor_lists(): + controller = _make_controller() + sample = _sample_item() + captured: dict[str, list[int] | int] = {} + + def fake_prepare_dispatch(*args, **kwargs): + captured["arg_lens"] = [len(args[0]), len(args[1])] + captured["kwarg_len"] = len(kwargs["batch_kw"]) + return [], {}, None + + controller._prepare_dispatch = fake_prepare_dispatch + controller._call_workers = lambda *args, **kwargs: [] + controller._collect_results = ( + lambda results, group_indices: list(range(captured["kwarg_len"])) + ) + + with patch( + "areal.infra.controller.train_controller.make_dummy_eval_item", + lambda template: {"attention_mask": torch.zeros_like(template["attention_mask"])}, + ), patch( + "areal.infra.controller.train_controller.run_async_task", + lambda fn, *args, **kwargs: [], + ): + result = controller._custom_function_call( + "noop", + [sample] * 5, + [sample] * 5, + batch_kw=[sample] * 5, + ) + + assert captured["arg_lens"] == [8, 8] + assert captured["kwarg_len"] == 8 + assert result == [0, 1, 2, 3, 4] + + +def test_custom_function_call_uses_kwargs_length_for_trimming(): + controller = _make_controller() + sample = _sample_item() + captured: dict[str, int] = {} + + def fake_prepare_dispatch(*args, **kwargs): + captured["padded_len"] = len(kwargs["batch_kw"]) + return [], {}, None + + controller._prepare_dispatch = fake_prepare_dispatch + controller._call_workers = lambda *args, **kwargs: [] + controller._collect_results = ( + lambda results, group_indices: list(range(captured["padded_len"])) + ) + + with patch( + "areal.infra.controller.train_controller.make_dummy_eval_item", + lambda template: {"attention_mask": torch.zeros_like(template["attention_mask"])}, + ), patch( + "areal.infra.controller.train_controller.run_async_task", + lambda fn, *args, **kwargs: [], + ): + result = controller._custom_function_call("noop", batch_kw=[sample] * 5) + + assert captured["padded_len"] == 8 + assert result == [0, 1, 2, 3, 4] + + +def test_custom_function_call_rejects_inconsistent_tensor_lengths(): + controller = _make_controller() + sample = _sample_item() + + try: + controller._custom_function_call( + "noop", + [sample] * 5, + [sample] * 4, + ) + except ValueError as exc: + assert "same length" in str(exc) + else: + raise AssertionError("expected inconsistent tensor-like inputs to fail")