Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 79 additions & 31 deletions areal/infra/controller/train_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,32 +123,63 @@ def _dispatch_tensors(


def _pad_eval_batch(
args: tuple[Any, ...], dp_size: int, group_size: int = 1
) -> tuple[Any, ...]:
"""Pad the first tensor-like arg to a multiple of ``dp_size * group_size``.

Called before dispatch for explicit evaluation controller paths so that
``balanced_greedy_partition`` always receives a divisible input.
Dummy items have zero attention/loss masks and contribute nothing
to metrics or loss.
args: tuple[Any, ...],
kwargs: dict[str, Any],
dp_size: int,
group_size: int = 1,
) -> tuple[tuple[Any, ...], dict[str, Any], int | None]:
"""Pad all top-level tensor-like inputs to a multiple of ``dp_size * group_size``.

Called before explicit evaluation dispatch so that tensor partitions stay
aligned across every top-level tensor-like argument/keyword argument.
Dummy items have zero attention/loss masks and contribute nothing to
metrics or loss.
"""
result = list(args)
pad_target = dp_size * group_size
for i, arg in enumerate(result):
result_args = list(args)
result_kwargs = dict(kwargs)
tensor_inputs: list[tuple[str, Any, list[Any]]] = []

for i, arg in enumerate(result_args):
if isinstance(arg, list) and arg and _is_tensor_like(arg):
n = len(arg)
pad_count = (-n) % pad_target
if pad_count > 0:
padded = list(arg)
template = arg[0]
padded.extend(make_dummy_eval_item(template) for _ in range(pad_count))
result[i] = padded
logger.info(
f"Eval dispatch: padded {pad_count} dummy items "
f"(total {len(padded)}) for dp_size={dp_size}"
)
break # only pad the first tensor-like arg
return tuple(result)
tensor_inputs.append(("arg", i, arg))
for key, value in result_kwargs.items():
if isinstance(value, list) and value and _is_tensor_like(value):
tensor_inputs.append(("kwarg", key, value))

if not tensor_inputs:
return tuple(result_args), result_kwargs, None

lengths = {len(items) for _, _, items in tensor_inputs}
if len(lengths) != 1:
raise ValueError(
"All tensor-like eval dispatch inputs must have the same length, "
f"got lengths {sorted(lengths)}."
)
orig_len = lengths.pop()
pad_target = dp_size * group_size
pad_count = (-orig_len) % pad_target
if pad_count == 0:
return tuple(result_args), result_kwargs, orig_len

for location, key, items in tensor_inputs:
padded = list(items)
template = items[0]
padded.extend(make_dummy_eval_item(template) for _ in range(pad_count))
if location == "arg":
result_args[key] = padded
else:
result_kwargs[key] = padded

logger.info(
"Eval dispatch: padded %s dummy items for %s tensor-like inputs "
"(total %s) for dp_size=%s group_size=%s",
pad_count,
len(tensor_inputs),
orig_len + pad_count,
dp_size,
group_size,
)
return tuple(result_args), result_kwargs, orig_len


def _merge_tensors(
Expand Down Expand Up @@ -452,11 +483,16 @@ def _custom_function_call(
**kwargs,
):
"""Dispatch method call to workers via the appropriate path."""
group_size = kwargs.get("group_size", 1)
args, kwargs, orig_len = self._pad_eval_dispatch_args(
args, kwargs, group_size=group_size
)
dp_args, dp_kwargs, group_indices = self._prepare_dispatch(*args, **kwargs)
results = run_async_task(
self._call_workers, method, dp_args, dp_kwargs, rpc_meta=rpc_meta
)
return self._collect_results(results, group_indices)
merged_results = self._collect_results(results, group_indices)
return self._trim_padded_eval_results(merged_results, orig_len)

async def _async_custom_function_call(
self,
Expand All @@ -466,25 +502,37 @@ async def _async_custom_function_call(
**kwargs,
):
"""Async version of _custom_function_call."""
group_size = kwargs.get("group_size", 1)
args, kwargs, orig_len = self._pad_eval_dispatch_args(
args, kwargs, group_size=group_size
)
dp_args, dp_kwargs, group_indices = self._prepare_dispatch(*args, **kwargs)
results = await self._call_workers(
method, dp_args, dp_kwargs, rpc_meta=rpc_meta
)
return self._collect_results(results, group_indices)
merged_results = self._collect_results(results, group_indices)
return self._trim_padded_eval_results(merged_results, orig_len)

def _pad_eval_dispatch_args(
    self,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    *,
    group_size: int,
) -> tuple[tuple[Any, ...], dict[str, Any], int | None]:
    """Pad eval batches for explicit algorithm-level evaluation dispatch.

    Delegates to ``_pad_eval_batch`` using this controller's data-parallel
    size (``self.parallel_strategy.dp_size``).

    Returns:
        ``(args, kwargs, orig_len)`` exactly as produced by
        ``_pad_eval_batch``; ``orig_len`` is ``None`` when nothing
        tensor-like was found to pad.
    """
    return _pad_eval_batch(
        args,
        kwargs,
        self.parallel_strategy.dp_size,
        group_size=group_size,
    )

def _trim_padded_eval_results(self, results: Any, orig_len: int | None) -> Any:
"""Drop dummy eval outputs that were added only for even DP partitioning."""
if orig_len is None or not isinstance(results, list) or len(results) <= orig_len:
return results
return results[:orig_len]

def _prepare_dispatch(
self, *args, **kwargs
Expand Down
19 changes: 12 additions & 7 deletions areal/utils/seqpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,14 +553,15 @@ def balanced_greedy_partition(nums: list[int], K: int) -> list[list[int]]:
List of K lists, where each inner list contains the indices assigned to that group

Raises:
ValueError: If len(nums) is not divisible by K or if len(nums) < K
ValueError: If K <= 0
"""
if K <= 0:
raise ValueError(f"K must be positive, got {K}.")

n = len(nums)
if n < K:
raise ValueError(f"Number of items ({n}) must be >= K ({K}).")
if n % K != 0:
raise ValueError("The length of nums must be divisible by K.")
m = n // K
base = n // K
remainder = n % K
capacities = [base + (1 if i < remainder else 0) for i in range(K)]

# Sort indices by value in descending order
sorted_indices = sorted(range(n), key=lambda i: -nums[i])
Expand All @@ -575,10 +576,14 @@ def balanced_greedy_partition(nums: list[int], K: int) -> list[list[int]]:
chosen_group = -1
min_sum = float("inf")
for i in range(K):
if counts[i] < m and sums[i] < min_sum:
if counts[i] < capacities[i] and sums[i] < min_sum:
min_sum = sums[i]
chosen_group = i

if chosen_group == -1:
raise RuntimeError(
f"Cannot assign item idx={idx} with capacities={capacities}, counts={counts}"
)
groups[chosen_group].append(idx)
sums[chosen_group] += num
counts[chosen_group] += 1
Expand Down
134 changes: 134 additions & 0 deletions tests/test_async_rl_crashes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from __future__ import annotations

from types import SimpleNamespace
from unittest.mock import patch

import torch

from areal.infra.controller.train_controller import TrainController
from areal.utils.seqpack import balanced_greedy_partition


def test_balanced_greedy_partition_handles_remainders():
    """Five items split into three groups: sizes differ by at most one."""
    groups = balanced_greedy_partition([10, 8, 6, 4, 2], K=3)

    assert len(groups) == 3
    # 5 = 2 + 2 + 1: the remainder is spread over the groups.
    assert sorted(len(group) for group in groups) == [1, 2, 2]
    # Every index is assigned exactly once.
    assert sorted(idx for group in groups for idx in group) == [0, 1, 2, 3, 4]


def _make_controller() -> TrainController:
    """Build a bare ``TrainController`` with a data-parallel size of 4.

    ``__new__`` is used so that ``__init__`` is not run; only the attribute
    set here exists on the returned instance.
    """
    controller = TrainController.__new__(TrainController)
    # NOTE(review): the controller reads ``parallel_strategy.dp_size``;
    # presumably that is derived from ``train_alloc.parallel`` — confirm.
    controller.train_alloc = SimpleNamespace(parallel=SimpleNamespace(dp_size=4))
    return controller


def _sample_item() -> dict[str, torch.Tensor]:
return {"attention_mask": torch.ones(1, dtype=torch.long)}


def test_custom_function_call_trims_padded_eval_items():
    """A 5-item positional batch is padded to 8 and the result trimmed to 5."""
    controller = _make_controller()
    sample = _sample_item()
    captured: dict[str, int] = {}

    def fake_prepare_dispatch(*args, **kwargs):
        # Record the batch length the controller hands to dispatch.
        captured["padded_len"] = len(args[0])
        return [], {}, None

    controller._prepare_dispatch = fake_prepare_dispatch
    controller._call_workers = lambda *args, **kwargs: []
    # Pretend the workers produced one output per (padded) input item.
    controller._collect_results = (
        lambda results, group_indices: list(range(captured["padded_len"]))
    )

    with patch(
        "areal.infra.controller.train_controller.make_dummy_eval_item",
        lambda template: {"attention_mask": torch.zeros_like(template["attention_mask"])},
    ), patch(
        "areal.infra.controller.train_controller.run_async_task",
        lambda fn, *args, **kwargs: [],
    ):
        result = controller._custom_function_call("noop", [sample] * 5)

    # 5 rounded up to the next multiple of dp_size (4) is 8.
    assert captured["padded_len"] == 8
    assert result == [0, 1, 2, 3, 4]


def test_custom_function_call_pads_kwargs_and_parallel_tensor_lists():
    """Two positional batches and one keyword batch are all padded alike."""
    controller = _make_controller()
    sample = _sample_item()
    captured: dict[str, list[int] | int] = {}

    def fake_prepare_dispatch(*args, **kwargs):
        # Record the length of every batch the controller hands to dispatch.
        captured["arg_lens"] = [len(args[0]), len(args[1])]
        captured["kwarg_len"] = len(kwargs["batch_kw"])
        return [], {}, None

    controller._prepare_dispatch = fake_prepare_dispatch
    controller._call_workers = lambda *args, **kwargs: []
    # Pretend the workers produced one output per (padded) input item.
    controller._collect_results = (
        lambda results, group_indices: list(range(captured["kwarg_len"]))
    )

    with patch(
        "areal.infra.controller.train_controller.make_dummy_eval_item",
        lambda template: {"attention_mask": torch.zeros_like(template["attention_mask"])},
    ), patch(
        "areal.infra.controller.train_controller.run_async_task",
        lambda fn, *args, **kwargs: [],
    ):
        result = controller._custom_function_call(
            "noop",
            [sample] * 5,
            [sample] * 5,
            batch_kw=[sample] * 5,
        )

    assert captured["arg_lens"] == [8, 8]
    assert captured["kwarg_len"] == 8
    assert result == [0, 1, 2, 3, 4]


def test_custom_function_call_uses_kwargs_length_for_trimming():
    """A batch passed only by keyword is still padded to 8 and trimmed to 5."""
    controller = _make_controller()
    sample = _sample_item()
    captured: dict[str, int] = {}

    def fake_prepare_dispatch(*args, **kwargs):
        # Record the keyword batch length the controller hands to dispatch.
        captured["padded_len"] = len(kwargs["batch_kw"])
        return [], {}, None

    controller._prepare_dispatch = fake_prepare_dispatch
    controller._call_workers = lambda *args, **kwargs: []
    # Pretend the workers produced one output per (padded) input item.
    controller._collect_results = (
        lambda results, group_indices: list(range(captured["padded_len"]))
    )

    with patch(
        "areal.infra.controller.train_controller.make_dummy_eval_item",
        lambda template: {"attention_mask": torch.zeros_like(template["attention_mask"])},
    ), patch(
        "areal.infra.controller.train_controller.run_async_task",
        lambda fn, *args, **kwargs: [],
    ):
        result = controller._custom_function_call("noop", batch_kw=[sample] * 5)

    assert captured["padded_len"] == 8
    assert result == [0, 1, 2, 3, 4]


def test_custom_function_call_rejects_inconsistent_tensor_lengths():
    """Positional batches of length 5 and 4 must be rejected with ValueError."""
    controller = _make_controller()
    sample = _sample_item()

    try:
        controller._custom_function_call(
            "noop",
            [sample] * 5,
            [sample] * 4,
        )
    except ValueError as exc:
        assert "same length" in str(exc)
    else:
        # Reaching here means no exception was raised at all.
        raise AssertionError("expected inconsistent tensor-like inputs to fail")
Loading