93 changes: 79 additions & 14 deletions tunix/generate/sampler.py
@@ -29,17 +29,35 @@
from flax.nnx import statelib
import jax
import jax.numpy as jnp
from jax.experimental import multihost_utils as mhu
from jax.lax import with_sharding_constraint as wsc
from jax.sharding import PartitionSpec as P
import jaxtyping
import numpy as np
from tunix.generate import base_sampler
from tunix.generate import utils
import tunix.generate.beam_search as beam_search_lib
import tunix.generate.tokenizer_adapter as tok_adapter
from tunix.sft import sharding_utils as shd_utils
from typing import Any as _TypingAny # for helper typing isolation

LayerCache = dict[str, jaxtyping.Array]
Cache = dict[str, LayerCache]


def _assert_tp_replicated(arr: _TypingAny) -> None:
"""Assert that a global array is replicated across the tensor-parallel axis.

This checks that the sharding spec does not include the 'tp' axis.
If the array does not have a NamedSharding/spec, the check is skipped.
"""
sh = getattr(arr, 'sharding', None)
spec = getattr(sh, 'spec', None)
if spec is None:
return
assert 'tp' not in spec, f'Expected TP-replicated array; got spec={spec}'
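
Editorial aside: this helper, together with the with_sharding_constraint calls added below, relies on the convention that a PartitionSpec which omits the 'tp' axis leaves the array replicated across the tensor-parallel devices. A minimal sketch of that property (mesh shape, array sizes, and the single-host device layout are illustrative assumptions, not part of the PR):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 1x2 mesh (assumes at least two local devices):
# one fsdp group, two tensor-parallel devices.
mesh = Mesh(np.array(jax.devices()[:2]).reshape(1, 2), ('fsdp', 'tp'))

# Batch dim sharded over 'fsdp'; seq and vocab dims left unsharded, so the
# array is replicated over 'tp', which is the property the assert above checks.
logits = jax.device_put(
    jnp.zeros((4, 8, 32)), NamedSharding(mesh, P('fsdp', None, None))
)
assert 'tp' not in logits.sharding.spec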


@flax.struct.dataclass
class _SamplingState:
"""Internal sampling state."""
@@ -508,6 +526,8 @@ def _prefill_fn(
sampler_state.cache,
attention_mask,
)
# Ensure full-vocab logits are replicated across tp for sampling/beam.
logits = wsc(logits, P('fsdp', None, None))
token_buffer = sampler_state.token_buffer
done = sampler_state.done
positions = sampler_state.positions
@@ -608,6 +628,8 @@ def _sample_step(
sampler_state.cache,
attention_mask,
)
# Ensure full-vocab logits are replicated across tp for sampling.
logits = wsc(logits, P('fsdp', None, None))
updated_sampler_state = self._sample(
logits=logits,
cache=cache,
@@ -696,8 +718,11 @@ def __call__(

tokens = [self.tokenize(x) for x in input_strings]
max_tokens_length = max(len(x) for x in tokens)
if max_prompt_length is None or max_prompt_length < max_tokens_length:
max_prompt_length = utils.next_power_of_2(max_tokens_length)
# Compute a global max across processes and choose a global prompt length.
local_max = np.array([max_tokens_length], dtype=np.int32)
max_prompt_len_global = int(mhu.process_allgather(local_max).max())
if max_prompt_length is None or max_prompt_length < max_prompt_len_global:
max_prompt_length = utils.next_power_of_2(max_prompt_len_global)

all_input_ids = jnp.array([
utils.pad_to_length(
@@ -720,6 +745,8 @@ def __call__(
seed = jax.random.PRNGKey(0)
elif isinstance(seed, int):
seed = jax.random.PRNGKey(seed)
# Make the per-process RNG unique across processes (the fsdp axis).
seed = jax.random.fold_in(seed, jax.process_index())
sampling_state = self.init_sample_state(
all_input_ids,
include_logits=return_logits,
@@ -731,6 +758,19 @@
seed=seed,
beam_size=beam_size,
)
# Convert batch-shaped fields and cache to fsdp-sharded global Arrays.
sampling_state = dataclasses.replace(
sampling_state,
token_buffer=shd_utils.shard_input(sampling_state.token_buffer, ("fsdp",)),
positions=shd_utils.shard_input(sampling_state.positions, ("fsdp",)),
done=shd_utils.shard_input(sampling_state.done, ("fsdp",)),
logits_buffer=(
None
if sampling_state.logits_buffer is None
else shd_utils.shard_input(sampling_state.logits_buffer, ("fsdp",))
),
cache=shd_utils.shard_input(sampling_state.cache, ("fsdp",)),
)
sampling_state = self._compiled_prefill_fn(
self._flattened_transformer_state, sampling_state
)
@@ -740,7 +780,6 @@
)
token_buffers = sampling_state.token_buffer
logits_buffers = sampling_state.logits_buffer

if sampling_state.sampling_mode == 'beam_search':
updated_args = beam_search_lib.finalize_beam_search_state(
sampling_state.beam_search_sampling_state,
@@ -765,32 +804,58 @@
max_prompt_length,
max_len,
)
out_tokens, lengths = jax.device_get(out_tokens), jax.device_get(lengths)
# Decode from local shards only (multi-controller safe)
if not out_tokens.is_fully_addressable:
_assert_tp_replicated(out_tokens)
_assert_tp_replicated(lengths)
tokens_host = out_tokens.addressable_shards[0].data
lengths_host = lengths.addressable_shards[0].data
else:
tokens_host = jax.device_get(out_tokens)
lengths_host = jax.device_get(lengths)

decoded_outputs = [
self.tokenizer.decode(tokens[:length].tolist())
for tokens, length in zip(out_tokens, lengths)
self.tokenizer.decode(tokens_host[i][:int(lengths_host[i])].tolist())
for i in range(len(tokens_host))
]
else:
# Gather a single local device shard for processing; avoid iterating global arrays.
if hasattr(token_buffers, 'is_fully_addressable') and not token_buffers.is_fully_addressable:
_assert_tp_replicated(token_buffers)
token_buffers_host = token_buffers.addressable_shards[0].data
if logits_buffers is None:
logits_buffers_host = None
else:
_assert_tp_replicated(logits_buffers)
logits_buffers_host = logits_buffers.addressable_shards[0].data
else:
token_buffers_host = jax.device_get(token_buffers)
logits_buffers_host = (
None if logits_buffers is None else jax.device_get(logits_buffers)
)
out_tokens = []
out_logits = []
lengths = []
for i, token_buffer in enumerate(token_buffers):
for i in range(token_buffers_host.shape[0]):
token_buffer_i = token_buffers_host[i]
start_idx = (
utils.find_first_non_pad_idx(token_buffer, self.tokenizer.pad_id())
utils.find_first_non_pad_idx(
jnp.array(token_buffer_i), self.tokenizer.pad_id()
)
if echo
else max_prompt_length
)
end_idx = (
utils.find_first_eos_idx(
token_buffer[max_prompt_length:], self.eos_ids
jnp.array(token_buffer_i[max_prompt_length:]), self.eos_ids
)
+ max_prompt_length
)
out_tokens.append(token_buffer[start_idx:end_idx])
if return_logits:
out_logits.append(logits_buffers[i][start_idx:end_idx])
out_tokens.append(token_buffer_i[start_idx:end_idx])
if return_logits and logits_buffers_host is not None:
out_logits.append(logits_buffers_host[i][start_idx:end_idx])
lengths.append(end_idx - start_idx)

lengths = np.array(lengths, dtype=np.int32)
decoded_outputs = [
self.tokenizer.decode(tokens.tolist()) for tokens in out_tokens
]
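
Stepping back from the diff for a moment: the prompt-length change near the top of __call__ works because every process must agree on max_prompt_length, otherwise the padded shapes (and the compiled prefill) diverge across hosts. A minimal sketch of that synchronization, assuming the rounding behavior of utils.next_power_of_2 (the helper name and shapes below are illustrative):

import numpy as np
from jax.experimental import multihost_utils as mhu

def global_prompt_length(local_token_lists):
  # Per-process maximum prompt length.
  local_max = np.array([max(len(t) for t in local_token_lists)], dtype=np.int32)
  # process_allgather stacks one value per process; the max is identical on all hosts.
  global_max = int(mhu.process_allgather(local_max).max())
  # Round up to the next power of two (assumed behavior of utils.next_power_of_2).
  return 1 << max(0, global_max - 1).bit_length()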
@@ -799,7 +864,7 @@
text=decoded_outputs,
logits=out_logits if return_logits else [],
tokens=out_tokens,
padded_prompt_tokens=all_input_ids,
padded_prompt_tokens=shd_utils.shard_input(all_input_ids, ("fsdp",)),
logprobs=None,
)
return result
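
One more note on __call__: folding the process index into the PRNG key gives every host a distinct but reproducible sampling stream from the same base seed. A minimal sketch (the seed value and logits shape are illustrative):

import jax
import jax.numpy as jnp

base = jax.random.PRNGKey(0)                               # same seed on every host
key = jax.random.fold_in(base, jax.process_index())        # distinct stream per process
sample = jax.random.categorical(key, jnp.zeros((1, 32)))   # e.g. sample token ids from logits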
50 changes: 31 additions & 19 deletions tunix/rl/grpo/grpo_learner.py
@@ -22,12 +22,14 @@
import flax
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec
import numpy as np
from tunix.rl import algorithm_config as algo_config_lib
from tunix.rl import common
from tunix.rl import function_registry
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl import rl_learner
from tunix.sft import sharding_utils

TrainingInputT = rl_learner.TrainingInputT
RewardFn = rl_learner.RewardFn
@@ -206,14 +208,19 @@ def _generate_and_compute_advantage(
rollout_output = self.rl_cluster.generate(
prompts=training_input["prompts"],
mode=mode,
micro_batch_size=(
self._rollout_micro_batch_size * self.algo_config.num_generations
),
micro_batch_size=(self.rollout_micro_batch_size * self.algo_config.num_generations),
)
completion_ids = rollout_output.tokens
prompt_ids = rollout_output.left_padded_prompt_tokens
completion_text = rollout_output.text

# Ensure local completions match local prompts.
assert len(completion_text) == len(training_input["prompts"]), (
"Mismatch between local completions and prompts; expected per-process "
"batching to align counts."
)
local_training_input = training_input

# Assemble masks
prompt_mask = prompt_ids != pad_value
completion_padding_mask = jnp.not_equal(completion_ids, pad_value)
@@ -232,10 +239,7 @@
completion_tokens=completion_ids,
pad_id=pad_value,
eos_id=eos_value,
micro_batch_size=(
self._compute_logps_micro_batch_size
* self.algo_config.num_generations
),
micro_batch_size=(self.compute_logps_micro_batch_size * self.algo_config.num_generations),
)
interval.device_end([ref_per_token_logps])
else:
@@ -248,23 +252,27 @@
old_per_token_logps = self.rl_cluster.get_old_per_token_logps(
prompt_tokens=prompt_ids,
completion_tokens=completion_ids,
micro_batch_size=(
self._compute_logps_micro_batch_size
* self.algo_config.num_generations
),
micro_batch_size=(self.compute_logps_micro_batch_size * self.algo_config.num_generations),
)
interval.device_end([old_per_token_logps])
else:
old_per_token_logps = None

with self.rl_cluster.perf.span("advantage_computation"):
# Compute rewards and advantages
rewards = self._compute_rewards(
prompts=training_input["prompts"],
# Compute rewards locally on each process
rewards_local = self._compute_rewards(
prompts=local_training_input["prompts"],
completions=completion_text,
mode=mode,
**{k: v for k, v in training_input.items() if k != "prompts"},
**{k: v for k, v in local_training_input.items() if k != "prompts"},
)

# Create a globally sharded array from the process-local rewards.
mesh = self.rl_cluster.r2m[rl_cluster_lib.Role.ACTOR]
pspec = PartitionSpec(*self.rl_cluster.cluster_config.training_config.data_sharding_axis)
reward_sharding = sharding_utils.get_sharding(rewards_local, mesh, pspec)
rewards = jax.make_array_from_process_local_data(reward_sharding, rewards_local)

advantage_estimator = function_registry.get_advantage_estimator(
self.algo_config.advantage_estimator
)
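
For reference, a minimal sketch of the process-local-to-global conversion used for the rewards above (mesh layout, axis names, and the per-process reward shape are illustrative assumptions; one fsdp shard per process):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative mesh: 'fsdp' spans processes, 'tp' spans each host's local devices.
devices = np.array(jax.devices()).reshape(jax.process_count(), -1)
mesh = Mesh(devices, ('fsdp', 'tp'))

rewards_local = np.ones((8,), dtype=np.float32)        # this process's rewards
sharding = NamedSharding(mesh, P('fsdp'))
rewards = jax.make_array_from_process_local_data(sharding, rewards_local)
# rewards is a global jax.Array of shape (8 * jax.process_count(),)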
Expand Down Expand Up @@ -293,22 +301,26 @@ def _generate_and_compute_advantage(
)
for m_fn in self.metric_fns:
user_defined_metric = m_fn(
prompts=training_input["prompts"],
prompts=local_training_input["prompts"],
completions=completion_text,
advances=advantages,
advantages=advantages,
rewards=rewards,
**{k: v for k, v in training_input.items() if k != "prompts"},
**{k: v for k, v in local_training_input.items() if k != "prompts"},
)
self.rl_cluster.buffer_metrics(user_defined_metric, mode=mode)

return TrainExample(
# Shard TrainExample leaves along fsdp so the train batch stays data-parallel across processes.
return sharding_utils.shard_input(
TrainExample(
prompt_ids=prompt_ids,
prompt_mask=prompt_mask,
completion_ids=completion_ids,
completion_mask=completion_mask,
ref_per_token_logps=ref_per_token_logps,
advantages=advantages,
old_per_token_logps=old_per_token_logps,
),
self.rl_cluster.cluster_config.training_config.data_sharding_axis,
)

def _compute_trajectory_ids(
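
Finally, on the sharded TrainExample return just above: sharding_utils.shard_input is applied to the whole dataclass so every array leaf follows the data-parallel layout. A rough single-controller stand-in for that behavior (the real helper and its data_sharding_axis handling live in tunix.sft.sharding_utils; the names and helper below are illustrative, not the repo's API):

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def shard_leaves_along(tree, mesh, axis_name='fsdp'):
  """Shard the leading (batch) dim of every array leaf; keep other dims replicated.

  Assumes every leaf has a leading batch dimension; None leaves (e.g. a missing
  old_per_token_logps) are skipped by tree_map automatically.
  """
  def put(x):
    spec = P(axis_name, *([None] * (x.ndim - 1)))
    return jax.device_put(x, NamedSharding(mesh, spec))
  return jax.tree_util.tree_map(put, tree)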