11 changes: 11 additions & 0 deletions areal/api/cli_args.py
@@ -998,6 +998,16 @@ class DatasetConfig:
drop_last: bool = field(
default=True, metadata={"help": "Drop the last incomplete batch"}
)
single_rank_load: bool = field(
default=False,
metadata={"help": "Use single rank rollout send/recive or not"},
)
balance_batch: bool = field(
default=False,
metadata={
"help": "Balance all rollouts across DP ranks by total tokens. Note: this only works when `single_rank_load` is set to True."
},
)
max_length: int | None = field(
default=None,
metadata={
@@ -1116,6 +1126,7 @@ class BaseExperimentConfig:
"For benchmarking purposes only. None indicates normal training."
},
)
weight_update_mode: str = field(default="disk")
total_train_n_seqs: int | None = field(
default=None,
metadata={
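The help text above notes that `balance_batch` only takes effect when `single_rank_load` is enabled. A minimal sketch of that constraint as a standalone check; the helper name and the check itself are illustrative, not part of this PR:

def validate_dataset_flags(single_rank_load: bool, balance_batch: bool) -> None:
    # Encodes the documented constraint: balancing happens on the single
    # loading rank, so it cannot be enabled on its own.
    if balance_batch and not single_rank_load:
        raise ValueError("balance_batch=True requires single_rank_load=True")


validate_dataset_flags(single_rank_load=True, balance_batch=True)  # passes
# validate_dataset_flags(single_rank_load=False, balance_batch=True)  # would raise
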
14 changes: 11 additions & 3 deletions areal/api/workflow_api.py
@@ -466,7 +466,9 @@ def submit(
except queue.Full:
raise RuntimeError("Input queue full. Please increase queue_size.")

def wait(self, count: int, timeout: float | None = None) -> Dict[str, Any]:
def wait(
self, count: int, timeout: float | None = None, single_rank_load: bool = False
) -> Dict[str, Any] | List[Dict[str, Any]]:
"""Wait for workflow results.

See :meth:`~areal.api.engine_api.InferenceEngine.wait` for detailed documentation.
@@ -500,7 +502,10 @@ def wait(self, count: int, timeout: float | None = None) -> Dict[str, Any]:
self.result_cache[count:],
)
random.shuffle(results)
return concat_padded_tensors([r.data for r in results])
if single_rank_load:
return [r.data for r in results]
else:
return concat_padded_tensors([r.data for r in results])

def rollout_batch(
self,
Expand Down Expand Up @@ -528,6 +533,7 @@ def prepare_batch(
workflow: "RolloutWorkflow" | None = None,
workflow_builder: Callable | None = None,
should_accept: Callable | None = None,
single_rank_load: bool = False,
):
"""Prepare a batch with controlled staleness.

@@ -552,7 +558,9 @@ def prepare_batch(
should_accept=should_accept,
)
try:
return self.wait(dataloader.batch_size, timeout=1)
return self.wait(
dataloader.batch_size, timeout=1, single_rank_load=single_rank_load
)
except TimeoutError:
pass

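With the change above, `wait()` (and therefore `prepare_batch()`) returns the raw list of per-rollout tensor dicts when `single_rank_load=True`, instead of one concatenated padded dict. A rough, self-contained illustration of the two shapes; the payloads below are fake and carry far fewer keys than real rollout data:

import torch

# What wait(..., single_rank_load=True) returns: one dict per rollout.
per_rollout = [
    {"input_ids": torch.tensor([[1, 2, 3]]), "attention_mask": torch.tensor([[1, 1, 1]])},
    {"input_ids": torch.tensor([[4, 5, 0]]), "attention_mask": torch.tensor([[1, 1, 0]])},
]

# What wait(..., single_rank_load=False) returns, roughly: one merged padded dict.
# Here the fake rollouts already share a padded length, so a plain concat suffices.
merged = {k: torch.cat([d[k] for d in per_rollout], dim=0) for k in per_rollout[0]}
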
21 changes: 21 additions & 0 deletions areal/engine/sglang_remote.py
@@ -411,12 +411,14 @@ def prepare_batch(
workflow: Optional[RolloutWorkflow] = None,
workflow_builder: Optional[Callable] = None,
should_accept: Callable | None = None,
single_rank_load: bool = False,
):
return self.workflow_executor.prepare_batch(
dataloader=dataloader,
workflow=workflow,
workflow_builder=workflow_builder,
should_accept=should_accept,
single_rank_load=single_rank_load,
)

def pause_generation(self):
@@ -438,6 +440,25 @@ def continue_generation(self):
res = requests.post(f"http://{addr}/continue_generation")
res.raise_for_status()

def pause_generation(self):
"""Pause the generation of inference engine.

Used during updating weights from distributed or disk.
"""
for addr in self.addresses:
res = requests.post(f"http://{addr}/pause_generation")
res.raise_for_status()

# The above HTTP requests may take some time to be scheduled and executed.
# The following line waits until all requests have actually been dropped.
time.sleep(self.config.pause_grace_period)

def continue_generation(self):
"""Continue the generation of inference engine."""
for addr in self.addresses:
res = requests.post(f"http://{addr}/continue_generation")
res.raise_for_status()

def pause(self):
"""Pause request submission for async rollout. Used during evaluation to prevent data over generation."""
return self.workflow_executor.pause()
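The added pause/continue methods bracket a weight update, and `pause_generation()` already sleeps for `pause_grace_period` before returning. A hedged sketch of the intended call order; the engine object and the update callable are placeholders, not APIs confirmed by this diff:

def update_weights_safely(engine, do_update):
    # pause_generation() drops in-flight requests on every server and waits
    # out the grace period before returning; resume afterwards no matter what.
    engine.pause_generation()
    try:
        do_update()
    finally:
        engine.continue_generation()
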
2 changes: 2 additions & 0 deletions areal/engine/vllm_remote.py
@@ -392,12 +392,14 @@ def prepare_batch(
workflow: Optional[RolloutWorkflow] = None,
workflow_builder: Optional[Callable] = None,
should_accept: Callable | None = None,
single_rank_load: bool = False,
):
return self.workflow_executor.prepare_batch(
dataloader=dataloader,
workflow=workflow,
workflow_builder=workflow_builder,
should_accept=should_accept,
single_rank_load=single_rank_load,
)

def pause_generation(self):
63 changes: 63 additions & 0 deletions areal/tests/test_utils.py
@@ -3,6 +3,7 @@

from areal.api.cli_args import MicroBatchSpec
from areal.utils.data import (
balance_batch,
pack_tensor_dict,
pad_and_stack_tensors_along_first_dim,
pad_sequences_to_tensors,
@@ -69,3 +70,65 @@ def test_micro_batch_split(mock_padded_data, n_mbs, max_tokens_per_mb):
assert torch.allclose(x, packed_data[key])
y = pad_and_stack_tensors_along_first_dim(xs)
assert torch.allclose(mock_padded_data[key], y)


def mock_rollout_data(bs, max_prompt_len, max_answer_len):
prompt_lens = torch.randint(1, max_prompt_len, size=(bs,))
answer_lens = torch.randint(1, max_answer_len, size=(bs,))
all_data = []
for prompt_len, ans_len in zip(prompt_lens, answer_lens):
prompt_len = int(prompt_len)
ans_len = int(ans_len)
seq = dict(
input_ids=torch.randint(0, VOCAB_SIZE, size=(prompt_len + ans_len,)),
loss_mask=torch.tensor([0] * prompt_len + [1] * ans_len),
logprobs=torch.randn(prompt_len + ans_len),
position_ids=torch.arange(prompt_len + ans_len),
)
all_data.append(seq)
return pad_sequences_to_tensors(all_data)


def similar_by_cv(numbers):
    # Coefficient of variation of `numbers`: lower means the per-rank totals are more similar.
    mean = sum(numbers) / len(numbers)
    std_dev = (sum((x - mean) ** 2 for x in numbers) / len(numbers)) ** 0.5
    cv = std_dev / mean
    return cv


@pytest.mark.parametrize(
"gbs,max_prompt_len,max_answer_len",
[(32, 512, 2048), (32, 512, 4096), (64, 512, 4096)],
)
@pytest.mark.parametrize("dp_num", [4, 8, 16])
def test_balance_batch(dp_num, gbs, max_prompt_len, max_answer_len):
mbs = gbs // dp_num
tensordict_list = []
for i in range(dp_num):
tensordict_list.append(mock_rollout_data(mbs, max_prompt_len, max_answer_len))
total_tokens_per_rank_original = []
for batch in tensordict_list:
batch_size = batch["attention_mask"].shape[0]
assert batch_size == mbs
total_tokens = sum(
batch["attention_mask"].view(batch_size, -1).sum(-1).tolist()
)
total_tokens_per_rank_original.append(total_tokens)
batches = balance_batch(tensordict_list, dp_num)
batch_per_rank = list(batches)
total_tokens_per_rank = []
for batch in batch_per_rank:
batch_size = batch["attention_mask"].shape[0]
assert batch_size == mbs
total_tokens = sum(
batch["attention_mask"].view(batch_size, -1).sum(-1).tolist()
)
total_tokens_per_rank.append(total_tokens)
cv_original = similar_by_cv(total_tokens_per_rank_original)
cv_balance = similar_by_cv(total_tokens_per_rank)
print(
f"-----cv_original {cv_original} total_tokens_per_rank_original {total_tokens_per_rank_original}"
)
print(f"-----cv_balance {cv_balance} total_tokens_per_rank {total_tokens_per_rank}")
print("*****************************************************")
Review comment on lines +129 to +133 (Contributor, severity: medium):
The print statements in test_balance_batch appear to be for debugging. They should be removed from the final test code to keep the test output clean. If this information is valuable for debugging in CI, consider using the logging module instead.

assert cv_original >= cv_balance
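
The `balance_batch` implementation under test lives in areal/utils/data.py and is not part of this diff. As a rough illustration of the kind of balancing the test expects, here is a hypothetical index-level sketch using sorted "snake" (boustrophedon) assignment, which keeps per-rank counts equal while roughly equalizing total tokens:

from typing import List


def balance_indices_by_tokens(token_counts: List[int], dp_num: int) -> List[List[int]]:
    # Assign sequence indices to dp_num equal-sized groups with similar token sums.
    assert len(token_counts) % dp_num == 0
    order = sorted(range(len(token_counts)), key=lambda i: token_counts[i], reverse=True)
    groups: List[List[int]] = [[] for _ in range(dp_num)]
    for pos, idx in enumerate(order):
        rnd, offset = divmod(pos, dp_num)
        rank = offset if rnd % 2 == 0 else dp_num - 1 - offset
        groups[rank].append(idx)
    return groups


# Usage with a padded batch like the one built by mock_rollout_data above:
#   token_counts = batch["attention_mask"].sum(-1).tolist()
#   groups = balance_indices_by_tokens(token_counts, dp_num=4)
#   per_rank = [{k: v[g] for k, v in batch.items()} for g in groups]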