Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions areal/infra/controller/train_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,11 +452,30 @@ def _custom_function_call(
**kwargs,
):
"""Dispatch method call to workers via the appropriate path."""
group_size = kwargs.get("group_size", 1)
orig_len = None
for arg in args:
if isinstance(arg, list) and arg and _is_tensor_like(arg):
orig_len = len(arg)
break
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The `orig_len` detection logic only inspects `args`, ignoring `kwargs`. If a tensor-like batch is passed via `kwargs`, it will not be padded (as `_pad_eval_dispatch_args` also only handles `args`), and the results will not be trimmed. This inconsistency could lead to crashes or incorrect behavior in engines that require even batches across data-parallel ranks. Consider extending this detection and the padding logic to include `kwargs`.


args, kwargs = self._pad_eval_dispatch_args(
args, kwargs, group_size=group_size
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There is a potential `IndexError` when multiple tensor-like lists are passed as arguments. The underlying `_pad_eval_batch` function (called via `_pad_eval_dispatch_args`) only pads the first tensor-like list it encounters. If multiple parallel lists (e.g., states and actions) are provided, they will end up with different lengths after padding, causing `_partition_inputs` to fail when it attempts to index the unpadded lists using indices derived from the padded one. All tensor-like lists should be padded consistently to the same target length.

dp_args, dp_kwargs, group_indices = self._prepare_dispatch(*args, **kwargs)
results = run_async_task(
self._call_workers, method, dp_args, dp_kwargs, rpc_meta=rpc_meta
)
return self._collect_results(results, group_indices)
merged_results = self._collect_results(results, group_indices)

if (
orig_len is not None
and isinstance(merged_results, list)
and len(merged_results) > orig_len
):
merged_results = merged_results[:orig_len]

return merged_results

async def _async_custom_function_call(
self,
Expand All @@ -466,11 +485,30 @@ async def _async_custom_function_call(
**kwargs,
):
"""Async version of _custom_function_call."""
group_size = kwargs.get("group_size", 1)
orig_len = None
for arg in args:
if isinstance(arg, list) and arg and _is_tensor_like(arg):
orig_len = len(arg)
break

args, kwargs = self._pad_eval_dispatch_args(
args, kwargs, group_size=group_size
)
dp_args, dp_kwargs, group_indices = self._prepare_dispatch(*args, **kwargs)
results = await self._call_workers(
method, dp_args, dp_kwargs, rpc_meta=rpc_meta
)
return self._collect_results(results, group_indices)
merged_results = self._collect_results(results, group_indices)

if (
orig_len is not None
and isinstance(merged_results, list)
and len(merged_results) > orig_len
):
merged_results = merged_results[:orig_len]

return merged_results
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The padding and trimming logic in `_async_custom_function_call` is identical to that in `_custom_function_call`. This duplication should be refactored into a shared helper method to improve maintainability and ensure that future fixes (such as handling `kwargs` or multiple lists) are applied consistently to both execution paths.


def _pad_eval_dispatch_args(
self,
Expand Down
19 changes: 12 additions & 7 deletions areal/utils/seqpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,14 +553,15 @@ def balanced_greedy_partition(nums: list[int], K: int) -> list[list[int]]:
List of K lists, where each inner list contains the indices assigned to that group

Raises:
ValueError: If len(nums) is not divisible by K or if len(nums) < K
ValueError: If K <= 0
"""
if K <= 0:
raise ValueError(f"K must be positive, got {K}.")

n = len(nums)
if n < K:
raise ValueError(f"Number of items ({n}) must be >= K ({K}).")
if n % K != 0:
raise ValueError("The length of nums must be divisible by K.")
m = n // K
base = n // K
remainder = n % K
capacities = [base + (1 if i < remainder else 0) for i in range(K)]

# Sort indices by value in descending order
sorted_indices = sorted(range(n), key=lambda i: -nums[i])
Expand All @@ -575,10 +576,14 @@ def balanced_greedy_partition(nums: list[int], K: int) -> list[list[int]]:
chosen_group = -1
min_sum = float("inf")
for i in range(K):
if counts[i] < m and sums[i] < min_sum:
if counts[i] < capacities[i] and sums[i] < min_sum:
min_sum = sums[i]
chosen_group = i

if chosen_group == -1:
raise RuntimeError(
f"Cannot assign item idx={idx} with capacities={capacities}, counts={counts}"
)
groups[chosen_group].append(idx)
sums[chosen_group] += num
counts[chosen_group] += 1
Expand Down
47 changes: 47 additions & 0 deletions tests/test_async_rl_crashes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

from types import SimpleNamespace
from unittest.mock import patch

import torch

from areal.infra.controller.train_controller import TrainController
from areal.utils.seqpack import balanced_greedy_partition


def test_balanced_greedy_partition_handles_remainders():
    """Partition five items into three groups when 5 is not divisible by 3.

    The call must produce exactly three groups whose sizes are 1, 2 and 2
    (in any order), and every input index 0..4 must appear exactly once
    across the groups.
    """
    item_weights = [10, 8, 6, 4, 2]
    partition = balanced_greedy_partition(item_weights, K=3)

    # One group per requested bucket.
    assert len(partition) == 3

    # The remainder (5 % 3 == 2) is spread out: two groups get an extra item.
    group_sizes = sorted(map(len, partition))
    assert group_sizes == [1, 2, 2]

    # No index is dropped or duplicated.
    assigned_indices = sorted(idx for members in partition for idx in members)
    assert assigned_indices == [0, 1, 2, 3, 4]


def test_custom_function_call_trims_padded_eval_items():
    """Verify that padded eval items are trimmed from the returned results.

    A controller with ``dp_size=4`` is handed a batch of five items. The test
    expects eight items to reach ``_prepare_dispatch`` and only the first
    five results to come back from ``_custom_function_call``.
    """
    # Bypass __init__: only the attributes set below are provided.
    ctrl = TrainController.__new__(TrainController)
    parallel_cfg = SimpleNamespace(dp_size=4)
    ctrl.train_alloc = SimpleNamespace(parallel=parallel_cfg)

    template_item = {"attention_mask": torch.ones(1, dtype=torch.long)}
    seen: dict[str, int] = {}

    def record_dispatch(*args, **kwargs):
        # Remember how many items arrived after padding.
        seen["padded_len"] = len(args[0])
        return [], {}, None

    def no_workers(*args, **kwargs):
        return []

    def one_result_per_dispatched_item(results, group_indices):
        # Emit one result per dispatched item, padding included.
        return list(range(seen["padded_len"]))

    def zeroed_dummy(template):
        return {"attention_mask": torch.zeros_like(template["attention_mask"])}

    def skip_async_runner(fn, *args, **kwargs):
        return []

    ctrl._prepare_dispatch = record_dispatch
    ctrl._call_workers = no_workers
    ctrl._collect_results = one_result_per_dispatched_item

    dummy_item_patch = patch(
        "areal.infra.controller.train_controller.make_dummy_eval_item",
        zeroed_dummy,
    )
    async_runner_patch = patch(
        "areal.infra.controller.train_controller.run_async_task",
        skip_async_runner,
    )
    with dummy_item_patch, async_runner_patch:
        batch = [template_item for _ in range(5)]
        outcome = ctrl._custom_function_call("noop", batch)

    assert seen["padded_len"] == 8
    assert outcome == list(range(5))
Loading