diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index b0bcc4c75d..2b416edd23 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1241,6 +1241,7 @@ class RejectionSamplingConfig: For KL metrics, aggregation is arithmetic. upper: Upper bound for filtering. lower: Lower bound for filtering (optional). + token_action: Enables two-stage rejection sampling and importance sampling mode ('mask' or 'clamp')(optional). """ level: str = field( @@ -1309,6 +1310,27 @@ class RejectionSamplingConfig: "For 'kl_k1' metric: can be used to filter negative KL estimates." }, ) + token_action: str | None = field( + default=None, + metadata={ + "help": ( + "Enables two-stage Geo-RS + Token-MIS/TIS mode. " + "Only valid when level='sequence' and metric='ratio'. " + "Stage 1 (Geo-RS): sequences whose geometric-mean importance ratio " + "exceeds `upper` are fully rejected (loss_mask zeroed for all tokens). " + "Stage 2: on accepted sequences, apply per-token correction using " + "this action — 'mask' (Token-MIS: zero tokens where per-token " + "ratio > upper) or 'clamp' (Token-TIS: clamp per-token ratio to " + "[lower, upper]). " + "None disables Stage 2 (pure sequence-level Geo-RS only). " + "Experimental results (PR #1084) show that neither Geo-RS alone " + "nor Token-MIS/TIS alone is stable under severe off-policy drift; " + "the two-stage combination is necessary for both grad_norm and " + "approx_kl stability." + ), + "choices": ["mask", "clamp"], + }, + ) def __post_init__(self): """Validate configuration.""" @@ -1377,6 +1399,30 @@ def __post_init__(self): UserWarning, stacklevel=2, ) + # Validate two-stage (Geo-RS + Token-MIS/TIS) constraints. + if self.token_action is not None: + _VALID_TOKEN_ACTIONS = ("mask", "clamp") + if self.token_action not in _VALID_TOKEN_ACTIONS: + raise ValueError( + f"token_action must be one of {_VALID_TOKEN_ACTIONS} or None, " + f"got '{self.token_action}'" + ) + if self.level != "sequence": + raise ValueError( + "token_action (two-stage Geo-RS + Token-MIS/TIS) requires " + f"level='sequence'. Got level='{self.level}'." + ) + if self.metric != "ratio": + raise ValueError( + "token_action (two-stage Geo-RS + Token-MIS/TIS) requires " + f"metric='ratio'. Got metric='{self.metric}'." + ) + if self.action != "mask": + raise ValueError( + "token_action (two-stage mode) requires action='mask' for " + "the sequence-level stage (hard Geo-RS rejection). " + f"Got action='{self.action}'." + ) @dataclass diff --git a/areal/utils/functional/functional.py b/areal/utils/functional/functional.py index a79aab6024..a52f5fcc53 100644 --- a/areal/utils/functional/functional.py +++ b/areal/utils/functional/functional.py @@ -250,6 +250,7 @@ def apply_rejection_sampling( behave_imp_weight = torch.exp(log_ratio) # Save original weight before any clamping, to compute clamped fraction later. original_weight = behave_imp_weight + per_token_ratio = behave_imp_weight.clone() # Step 4: Aggregate and filter # @@ -342,6 +343,25 @@ def apply_rejection_sampling( ), behave_imp_weight, ) + # ── Stage 2: Token-MIS/TIS (1D packed) ────────────────────────────── + if config.token_action is not None: + # behave_imp_weight holds per-token ratios. + # Shape: [total_tokens] in 1D packed format. + token_ratio = per_token_ratio + if config.token_action == "mask": + token_oor = token_ratio > config.upper + if config.lower is not None: + token_oor = token_oor | (token_ratio < config.lower) + loss_mask = loss_mask * (~token_oor).to(loss_mask.dtype) + behave_imp_weight = token_ratio * (~token_oor).to( + behave_imp_weight.dtype + ) + elif config.token_action == "clamp": + clamp_lower = config.lower if config.lower is not None else 0.0 + behave_imp_weight = token_ratio.clamp( + min=clamp_lower, max=config.upper + ) + # ── End Stage 2 ────────────────────────────────────────────────────── else: # 2D padded format agg_values = log_ratio if _use_log_agg else metric @@ -386,6 +406,38 @@ def apply_rejection_sampling( ), behave_imp_weight, ) + + # ── Stage 2: Token-MIS/TIS on Geo-RS-accepted sequences ───────────── + # Runs only in two-stage mode (config.token_action is not None). + # At this point, loss_mask already has Stage-1-rejected sequences + # zeroed out. We now apply per-token filtering on surviving tokens. + if config.token_action is not None: + # behave_imp_weight holds per-token ratios π_prox / π_behave. + # Shape: [batch, seq_len] in 2D padded format. + token_ratio = per_token_ratio + + if config.token_action == "mask": + # Token-MIS: zero out tokens where the per-token ratio exceeds + # upper, and optionally where it falls below lower. + # This suppresses tokens where the current policy has drifted + # far from the behavior policy at the token level. + token_oor = token_ratio > config.upper + if config.lower is not None: + token_oor = token_oor | (token_ratio < config.lower) + loss_mask = loss_mask * (~token_oor).to(loss_mask.dtype) + behave_imp_weight = token_ratio * (~token_oor).to( + behave_imp_weight.dtype + ) + + elif config.token_action == "clamp": + # Token-TIS: clamp the per-token importance ratio to + # [lower, upper]. All tokens remain in the gradient but + # their contribution is bounded. + clamp_lower = config.lower if config.lower is not None else 0.0 + behave_imp_weight = token_ratio.clamp( + min=clamp_lower, max=config.upper + ) + # ── End Stage 2 ────────────────────────────────────────────────────── else: # Token level if config.action == "mask": diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index 778b04972a..97ca19ef5e 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -1034,16 +1034,18 @@ Attributes: For KL metrics, aggregation is arithmetic. upper: Upper bound for filtering. lower: Lower bound for filtering (optional). + token_action: Enables two-stage rejection sampling and importance sampling mode ('mask' or 'clamp')(optional). ``` -| Parameter | Type | Default | Description | -| --------- | ------------- | --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `level` | string | `"token"` | Filtering granularity. 'token': per-token filtering (each token judged independently). 'sequence': per-sequence filtering (all tokens in a sequence share the same fate). When metric='ratio', both the filtering decision and the correction weight (behave_imp_weight) operate at sequence level using the geometric mean. **Choices:** `token`, `sequence` | -| `action` | string | `"mask"` | Action to take when metric exceeds threshold. 'mask': zero out loss_mask for filtered tokens/sequences (rejection, completely excludes from gradient computation). 'clamp': clamp importance weight to \[lower, upper\] bounds (truncation, tokens still participate in gradient but with bounded weight). **Choices:** `mask`, `clamp` | -| `metric` | string | `"ratio"` | Divergence metric for filtering. 'ratio': direct importance ratio π_proximal/π_behave. 'kl_k1': KL estimator k1 = log(r), forward KL unbiased estimator (can be negative). 'kl_k2': KL estimator k2 = 0.5 * (log r)^2, non-negative quadratic approximation. 'kl_k3': KL estimator k3 = r - log(r) - 1, non-negative exact forward KL estimator. **Choices:** `ratio`, `kl_k1`, `kl_k2`, `kl_k3` | -| `agg` | string | `"mean"` | Aggregation method for sequence-level filtering. Only used when level='sequence'. For 'ratio' metric, aggregation is in log space: 'sum' = exp(sum(log(r_i))), 'mean' = exp(mean(log(r_i))) = geometric mean (length-invariant, consistent with GSPO). For KL metrics, aggregation is arithmetic: 'sum' = sum(kl_i), 'mean' = mean(kl_i). 'max': max of per-token metric values (most conservative). **Choices:** `sum`, `mean`, `max` | -| `upper` | float | `5.0` | Upper bound for filtering. Tokens/sequences with metric > upper are filtered out (loss_mask zeroed). For 'ratio' metric: must be > 1.0, typical values are 2.0 or 5.0. For 'kl_k2'/'kl_k3' metrics: typical values are 0.5-2.0. | -| `lower` | float \| None | `None` | Lower bound for filtering (optional). None means no lower bound. For 'ratio' metric: typical value is 0.5 (filter out tokens where policy probability dropped significantly). Must be > 0. For 'kl_k1' metric: can be used to filter negative KL estimates. | +| Parameter | Type | Default | Description | +| -------------- | -------------- | --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `level` | string | `"token"` | Filtering granularity. 'token': per-token filtering (each token judged independently). 'sequence': per-sequence filtering (all tokens in a sequence share the same fate). When metric='ratio', both the filtering decision and the correction weight (behave_imp_weight) operate at sequence level using the geometric mean. **Choices:** `token`, `sequence` | +| `action` | string | `"mask"` | Action to take when metric exceeds threshold. 'mask': zero out loss_mask for filtered tokens/sequences (rejection, completely excludes from gradient computation). 'clamp': clamp importance weight to \[lower, upper\] bounds (truncation, tokens still participate in gradient but with bounded weight). **Choices:** `mask`, `clamp` | +| `metric` | string | `"ratio"` | Divergence metric for filtering. 'ratio': direct importance ratio π_proximal/π_behave. 'kl_k1': KL estimator k1 = log(r), forward KL unbiased estimator (can be negative). 'kl_k2': KL estimator k2 = 0.5 * (log r)^2, non-negative quadratic approximation. 'kl_k3': KL estimator k3 = r - log(r) - 1, non-negative exact forward KL estimator. **Choices:** `ratio`, `kl_k1`, `kl_k2`, `kl_k3` | +| `agg` | string | `"mean"` | Aggregation method for sequence-level filtering. Only used when level='sequence'. For 'ratio' metric, aggregation is in log space: 'sum' = exp(sum(log(r_i))), 'mean' = exp(mean(log(r_i))) = geometric mean (length-invariant, consistent with GSPO). For KL metrics, aggregation is arithmetic: 'sum' = sum(kl_i), 'mean' = mean(kl_i). 'max': max of per-token metric values (most conservative). **Choices:** `sum`, `mean`, `max` | +| `upper` | float | `5.0` | Upper bound for filtering. Tokens/sequences with metric > upper are filtered out (loss_mask zeroed). For 'ratio' metric: must be > 1.0, typical values are 2.0 or 5.0. For 'kl_k2'/'kl_k3' metrics: typical values are 0.5-2.0. | +| `lower` | float \| None | `None` | Lower bound for filtering (optional). None means no lower bound. For 'ratio' metric: typical value is 0.5 (filter out tokens where policy probability dropped significantly). Must be > 0. For 'kl_k1' metric: can be used to filter negative KL estimates. | +| `token_action` | string \| None | `None` | Enables two-stage Geo-RS + Token-MIS/TIS mode. Only valid when level='sequence' and metric='ratio'. Stage 1 (Geo-RS): sequences whose geometric-mean importance ratio exceeds `upper` are fully rejected (loss_mask zeroed for all tokens). Stage 2: on accepted sequences, apply per-token correction using this action — 'mask' (Token-MIS: zero tokens where per-token ratio > upper) or 'clamp' (Token-TIS: clamp per-token ratio to \[lower, upper\]). None disables Stage 2 (pure sequence-level Geo-RS only). Experimental results (PR #1084) show that neither Geo-RS alone nor Token-MIS/TIS alone is stable under severe off-policy drift; the two-stage combination is necessary for both grad_norm and approx_kl stability. **Choices:** `mask`, `clamp` | (section-scheduler)= diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index 0a71712686..17439e7a43 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -1032,16 +1032,18 @@ Attributes: For KL metrics, aggregation is arithmetic. upper: Upper bound for filtering. lower: Lower bound for filtering (optional). + token_action: Enables two-stage rejection sampling and importance sampling mode ('mask' or 'clamp')(optional). ``` -| Parameter | Type | Default | Description | -| --------- | ------------- | --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `level` | string | `"token"` | Filtering granularity. 'token': per-token filtering (each token judged independently). 'sequence': per-sequence filtering (all tokens in a sequence share the same fate). When metric='ratio', both the filtering decision and the correction weight (behave_imp_weight) operate at sequence level using the geometric mean. **Choices:** `token`, `sequence` | -| `action` | string | `"mask"` | Action to take when metric exceeds threshold. 'mask': zero out loss_mask for filtered tokens/sequences (rejection, completely excludes from gradient computation). 'clamp': clamp importance weight to \[lower, upper\] bounds (truncation, tokens still participate in gradient but with bounded weight). **Choices:** `mask`, `clamp` | -| `metric` | string | `"ratio"` | Divergence metric for filtering. 'ratio': direct importance ratio π_proximal/π_behave. 'kl_k1': KL estimator k1 = log(r), forward KL unbiased estimator (can be negative). 'kl_k2': KL estimator k2 = 0.5 * (log r)^2, non-negative quadratic approximation. 'kl_k3': KL estimator k3 = r - log(r) - 1, non-negative exact forward KL estimator. **Choices:** `ratio`, `kl_k1`, `kl_k2`, `kl_k3` | -| `agg` | string | `"mean"` | Aggregation method for sequence-level filtering. Only used when level='sequence'. For 'ratio' metric, aggregation is in log space: 'sum' = exp(sum(log(r_i))), 'mean' = exp(mean(log(r_i))) = geometric mean (length-invariant, consistent with GSPO). For KL metrics, aggregation is arithmetic: 'sum' = sum(kl_i), 'mean' = mean(kl_i). 'max': max of per-token metric values (most conservative). **Choices:** `sum`, `mean`, `max` | -| `upper` | float | `5.0` | Upper bound for filtering. Tokens/sequences with metric > upper are filtered out (loss_mask zeroed). For 'ratio' metric: must be > 1.0, typical values are 2.0 or 5.0. For 'kl_k2'/'kl_k3' metrics: typical values are 0.5-2.0. | -| `lower` | float \| None | `None` | Lower bound for filtering (optional). None means no lower bound. For 'ratio' metric: typical value is 0.5 (filter out tokens where policy probability dropped significantly). Must be > 0. For 'kl_k1' metric: can be used to filter negative KL estimates. | +| Parameter | Type | Default | Description | +| -------------- | -------------- | --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `level` | string | `"token"` | Filtering granularity. 'token': per-token filtering (each token judged independently). 'sequence': per-sequence filtering (all tokens in a sequence share the same fate). When metric='ratio', both the filtering decision and the correction weight (behave_imp_weight) operate at sequence level using the geometric mean. **Choices:** `token`, `sequence` | +| `action` | string | `"mask"` | Action to take when metric exceeds threshold. 'mask': zero out loss_mask for filtered tokens/sequences (rejection, completely excludes from gradient computation). 'clamp': clamp importance weight to \[lower, upper\] bounds (truncation, tokens still participate in gradient but with bounded weight). **Choices:** `mask`, `clamp` | +| `metric` | string | `"ratio"` | Divergence metric for filtering. 'ratio': direct importance ratio π_proximal/π_behave. 'kl_k1': KL estimator k1 = log(r), forward KL unbiased estimator (can be negative). 'kl_k2': KL estimator k2 = 0.5 * (log r)^2, non-negative quadratic approximation. 'kl_k3': KL estimator k3 = r - log(r) - 1, non-negative exact forward KL estimator. **Choices:** `ratio`, `kl_k1`, `kl_k2`, `kl_k3` | +| `agg` | string | `"mean"` | Aggregation method for sequence-level filtering. Only used when level='sequence'. For 'ratio' metric, aggregation is in log space: 'sum' = exp(sum(log(r_i))), 'mean' = exp(mean(log(r_i))) = geometric mean (length-invariant, consistent with GSPO). For KL metrics, aggregation is arithmetic: 'sum' = sum(kl_i), 'mean' = mean(kl_i). 'max': max of per-token metric values (most conservative). **Choices:** `sum`, `mean`, `max` | +| `upper` | float | `5.0` | Upper bound for filtering. Tokens/sequences with metric > upper are filtered out (loss_mask zeroed). For 'ratio' metric: must be > 1.0, typical values are 2.0 or 5.0. For 'kl_k2'/'kl_k3' metrics: typical values are 0.5-2.0. | +| `lower` | float \| None | `None` | Lower bound for filtering (optional). None means no lower bound. For 'ratio' metric: typical value is 0.5 (filter out tokens where policy probability dropped significantly). Must be > 0. For 'kl_k1' metric: can be used to filter negative KL estimates. | +| `token_action` | string \| None | `None` | Enables two-stage Geo-RS + Token-MIS/TIS mode. Only valid when level='sequence' and metric='ratio'. Stage 1 (Geo-RS): sequences whose geometric-mean importance ratio exceeds `upper` are fully rejected (loss_mask zeroed for all tokens). Stage 2: on accepted sequences, apply per-token correction using this action — 'mask' (Token-MIS: zero tokens where per-token ratio > upper) or 'clamp' (Token-TIS: clamp per-token ratio to \[lower, upper\]). None disables Stage 2 (pure sequence-level Geo-RS only). Experimental results (PR #1084) show that neither Geo-RS alone nor Token-MIS/TIS alone is stable under severe off-policy drift; the two-stage combination is necessary for both grad_norm and approx_kl stability. **Choices:** `mask`, `clamp` | (section-scheduler)= diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index 4c1dbb8c40..9a081c6b07 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -833,6 +833,333 @@ def test_invalid_metric_raises(self): with pytest.raises(ValueError, match="metric must be one of"): RejectionSamplingConfig(metric="invalid") + +class TestTwoStageRejectionSampling: + """Tests for two-stage Geo-RS + Token-MIS/TIS mode (from closed PR #1084). + + The two-stage pipeline: + Stage 1 — Geo-RS: reject sequences whose geometric-mean ratio > upper. + Stage 2 — Token-MIS/TIS: on accepted sequences, filter/clamp per-token. + """ + + # ── Config validation ───────────────────────────────────────────────────── + + def test_token_action_requires_sequence_level(self): + """token_action must be combined with level='sequence'.""" + with pytest.raises(ValueError, match="level='sequence'"): + RejectionSamplingConfig( + level="token", + action="mask", + metric="ratio", + upper=2.0, + token_action="mask", + ) + + def test_token_action_requires_ratio_metric(self): + """token_action is only defined for metric='ratio'.""" + with pytest.raises(ValueError, match="metric='ratio'"): + RejectionSamplingConfig( + level="sequence", + action="mask", + metric="kl_k2", + upper=1.0, + token_action="mask", + ) + + def test_token_action_requires_action_mask_at_sequence_level(self): + """Sequence-level stage must use action='mask' (hard rejection only).""" + with pytest.raises(ValueError, match="action='mask'"): + RejectionSamplingConfig( + level="sequence", + action="clamp", # invalid for two-stage + metric="ratio", + upper=2.0, + token_action="mask", + ) + + def test_token_action_invalid_string(self): + """token_action must be 'mask', 'clamp', or None.""" + with pytest.raises(ValueError, match="token_action must be one of"): + RejectionSamplingConfig( + level="sequence", + action="mask", + metric="ratio", + upper=2.0, + token_action="truncate", # typo / invalid choice + ) + + def test_valid_two_stage_mis_config(self): + """Geo-RS + Token-MIS config constructs without error.""" + config = RejectionSamplingConfig( + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=2.0, + token_action="mask", + ) + assert config.token_action == "mask" + assert config.level == "sequence" + + def test_valid_two_stage_tis_config(self): + """Geo-RS + Token-TIS config constructs without error.""" + config = RejectionSamplingConfig( + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=2.0, + lower=0.5, + token_action="clamp", + ) + assert config.token_action == "clamp" + assert config.lower == 0.5 + + # ── Functional tests — 2D padded format ────────────────────────────────── + + @staticmethod + def _batch_inputs(): + """ + Return a 2D padded batch with three sequences of length 4. + + Sequence 0: per-token ratio = 1.5 → geo-mean = 1.5 (accepted, upper=2.0) + Sequence 1: per-token ratio = 3.0 → geo-mean = 3.0 (rejected, > upper) + Sequence 2: per-token ratio = 0.8 → geo-mean = 0.8 (accepted) + """ + ratios = torch.tensor( + [ + [1.5, 1.5, 1.5, 1.5], + [3.0, 3.0, 3.0, 3.0], + [0.8, 0.8, 0.8, 0.8], + ] + ) + loss_mask = torch.ones(3, 4) + proximal_logprobs = torch.log(ratios) + old_logprobs = torch.zeros_like(proximal_logprobs) + return loss_mask, ratios, proximal_logprobs, old_logprobs + + def test_stage1_rejects_divergent_sequence(self): + """Stage 1 (Geo-RS) must fully zero-out the rejected sequence.""" + config = RejectionSamplingConfig( + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=2.0, + token_action="mask", + ) + loss_mask, ratios, proximal_logprobs, old_logprobs = self._batch_inputs() + result = apply_rejection_sampling( + config=config, + loss_mask=loss_mask, + cu_seqlens=None, + # behave_imp_weight=ratios, + proximal_logprobs=proximal_logprobs, + old_logprobs=old_logprobs, + ) + new_mask = result.loss_mask + # Sequence 1 (geo-mean 3.0 > 2.0) must be fully masked. + assert new_mask[1].sum() == 0, "Rejected sequence must be fully zeroed" + # Sequences 0 and 2 are accepted and their token ratios ≤ upper → kept. + assert new_mask[0].sum() == 4 + assert new_mask[2].sum() == 4 + + def test_stage2_mis_filters_high_token_within_accepted_seq(self): + """ + Stage 2 (Token-MIS) filters individual high-ratio tokens inside + a sequence that was accepted by Geo-RS. + """ + config = RejectionSamplingConfig( + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=2.0, + token_action="mask", + ) + # Seq 0: geo-mean ≈ exp(mean([0, 0, log(2.5), 0])) ≈ 1.26 → accepted by Geo-RS + # but token[2] = 2.5 > upper → masked by Token-MIS + # Seq 1: all ratios = 1.0 → accepted, all tokens kept + ratios = torch.tensor( + [ + [1.0, 1.0, 2.5, 1.0], + [1.0, 1.0, 1.0, 1.0], + ] + ) + loss_mask = torch.ones(2, 4) + proximal_logprobs = torch.log(ratios) + old_logprobs = torch.zeros_like(proximal_logprobs) + + result = apply_rejection_sampling( + config=config, + loss_mask=loss_mask, + cu_seqlens=None, + # behave_imp_weight=ratios, + proximal_logprobs=proximal_logprobs, + old_logprobs=old_logprobs, + ) + new_mask = result.loss_mask + assert new_mask[0, 0] == 1 + assert new_mask[0, 1] == 1 + assert new_mask[0, 2] == 0, "Token-MIS must mask the 2.5-ratio token" + assert new_mask[0, 3] == 1 + assert new_mask[1].sum() == 4, "Clean sequence must be fully kept" + + def test_stage2_tis_clamps_token_weights_not_mask(self): + """ + Stage 2 (Token-TIS) clamps per-token weights but must NOT zero loss_mask. + All tokens continue to contribute to the gradient. + """ + config = RejectionSamplingConfig( + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=2.0, + lower=0.5, + token_action="clamp", + ) + # Both sequences accepted by Geo-RS (geo-means ≤ 2.0). + ratios = torch.tensor( + [ + [0.2, 1.0, 1.8, 3.5], # tokens 0 and 3 out of [0.5, 2.0] + [0.8, 1.2, 1.5, 0.9], # all in range + ] + ) + loss_mask = torch.ones(2, 4) + proximal_logprobs = torch.log(ratios.clamp(min=1e-6)) + old_logprobs = torch.zeros_like(proximal_logprobs) + + result = apply_rejection_sampling( + config=config, + loss_mask=loss_mask, + cu_seqlens=None, + # behave_imp_weight=ratios, + proximal_logprobs=proximal_logprobs, + old_logprobs=old_logprobs, + ) + new_mask = result.loss_mask + new_weight = result.behave_imp_weight + # loss_mask must be entirely unchanged — TIS never zeros tokens. + assert new_mask.sum() == 8, "Token-TIS must not zero any loss_mask tokens" + # Weights clamped to [0.5, 2.0]. + assert new_weight[0, 0] == pytest.approx(0.5), "0.2 clamped to lower=0.5" + assert new_weight[0, 1] == pytest.approx(1.0), "1.0 unchanged" + assert new_weight[0, 2] == pytest.approx(1.8), "1.8 unchanged" + assert new_weight[0, 3] == pytest.approx(2.0), "3.5 clamped to upper=2.0" + assert new_weight[1].allclose(ratios[1]), "Seq 1 weights unchanged" + + def test_stage1_dominates_even_if_stage2_would_pass(self): + """ + Tokens in a Stage-1-rejected sequence must stay masked even if their + individual token ratio would have passed the Token-MIS threshold. + """ + config = RejectionSamplingConfig( + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=2.0, + token_action="mask", + ) + loss_mask = torch.ones(1, 4) + # geo-mean = 4.0 > 2.0 → Stage 1 rejects this sequence entirely. + ratios = torch.full((1, 4), 4.0) + proximal_logprobs = torch.log(ratios) + old_logprobs = torch.zeros_like(proximal_logprobs) + + result = apply_rejection_sampling( + config=config, + loss_mask=loss_mask, + cu_seqlens=None, + # behave_imp_weight=ratios, + proximal_logprobs=proximal_logprobs, + old_logprobs=old_logprobs, + ) + new_mask = result.loss_mask + assert new_mask.sum() == 0, "Stage 1 rejection must dominate Stage 2" + + def test_none_token_action_identical_to_pure_sequence_geo_rs(self): + """ + token_action=None must produce results identical to the existing + level='sequence', action='mask' mode — no Stage 2 runs. + """ + ratios = torch.tensor( + [ + [1.5, 1.5, 1.5, 1.5], + [3.0, 3.0, 3.0, 3.0], + [0.8, 0.8, 0.8, 0.8], + ] + ) + loss_mask = torch.ones(3, 4) + proximal_logprobs = torch.log(ratios) + old_logprobs = torch.zeros_like(proximal_logprobs) + + cfg_two_stage_off = RejectionSamplingConfig( + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=2.0, + token_action=None, + ) + cfg_original = RejectionSamplingConfig( + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=2.0, + ) + + result = apply_rejection_sampling( + proximal_logprobs, old_logprobs, loss_mask.clone(), None, cfg_two_stage_off + ) + mask_off = result.loss_mask + w_off = result.behave_imp_weight + result = apply_rejection_sampling( + proximal_logprobs, old_logprobs, loss_mask.clone(), None, cfg_original + ) + mask_orig = result.loss_mask + w_orig = result.behave_imp_weight + + torch.testing.assert_close(mask_off, mask_orig) + torch.testing.assert_close(w_off, w_orig) + + def test_lower_bound_also_applied_in_token_mis(self): + """ + Token-MIS with a `lower` bound must also mask tokens whose ratio + falls below `lower` (policy has dropped sharply at that token). + """ + config = RejectionSamplingConfig( + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=3.0, + lower=0.5, + token_action="mask", + ) + loss_mask = torch.ones(1, 4) + # Seq geo-mean ≈ exp(mean(log([0.3, 1.0, 1.0, 1.0]))) ≈ 0.84 → accepted + # but token[0] = 0.3 < lower=0.5 → masked by Token-MIS + ratios = torch.tensor([[0.3, 1.0, 1.0, 1.0]]) + proximal_logprobs = torch.log(ratios) + old_logprobs = torch.zeros_like(proximal_logprobs) + + result = apply_rejection_sampling( + config=config, + loss_mask=loss_mask, + cu_seqlens=None, + # behave_imp_weight=ratios, + proximal_logprobs=proximal_logprobs, + old_logprobs=old_logprobs, + ) + new_mask = result.loss_mask + assert new_mask[0, 0] == 0, "Token below lower bound must be masked" + assert new_mask[0, 1] == 1 + assert new_mask[0, 2] == 1 + assert new_mask[0, 3] == 1 + def test_invalid_agg_raises(self): """Invalid agg should raise ValueError.""" with pytest.raises(ValueError, match="agg must be one of"):