diff --git a/.gitignore b/.gitignore index 87df5a1..6dca3b2 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,10 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST + +# Local development environments and test caches +.venv/ +.pytest_cache/ + +# Generated experiment outputs +experiments/kondo/artifacts/ diff --git a/egg/actors/dreamer_bps.py b/egg/actors/dreamer_bps.py index f93d1ce..41365f9 100644 --- a/egg/actors/dreamer_bps.py +++ b/egg/actors/dreamer_bps.py @@ -113,7 +113,11 @@ def sample_batch( self.sampler_network.apply if self.sampler_network else state.apply_fn ) seqs, logps_full = self.sampler( - sampler_apply_fn, state.params, prompts_flat, k_sample + sampler_apply_fn, + state.params, + prompts_flat, + k_sample, + model=self.sampler_network, ) # Answer length. diff --git a/egg/actors/fixed_bpc.py b/egg/actors/fixed_bpc.py index c3ae5db..95dc4af 100644 --- a/egg/actors/fixed_bpc.py +++ b/egg/actors/fixed_bpc.py @@ -121,6 +121,7 @@ def _sample(prompts: jax.Array, k: jax.Array) -> ar_sample.SampleResult: params=state.params, prompts=prompts, key=k, + model=self.sampler_network, ) seqs, logps = jax.vmap(_sample)(active_prompts, sample_keys) diff --git a/egg/actors/fixed_bps.py b/egg/actors/fixed_bps.py index a2783d7..9c480fd 100644 --- a/egg/actors/fixed_bps.py +++ b/egg/actors/fixed_bps.py @@ -85,7 +85,11 @@ def sample_batch( sampler_apply_fn = state.apply_fn seqs, logps = self.sampler( - sampler_apply_fn, state.params, prompts_flat, k_sample + sampler_apply_fn, + state.params, + prompts_flat, + k_sample, + model=self.sampler_network, ) answers = seqs[:, prompt_len:] # (B, A_len) diff --git a/egg/lib/ar_sample.py b/egg/lib/ar_sample.py index 15fd5bc..64805b4 100644 --- a/egg/lib/ar_sample.py +++ b/egg/lib/ar_sample.py @@ -30,6 +30,41 @@ class SampleResult(tp.NamedTuple): log_probs: jax.Array # same shape as `tokens` +def _extract_logits( + outs: jax.Array | dict[str, jax.Array], +) -> jax.Array: + if isinstance(outs, dict): + return outs["policy_logits"] + return outs + + +class _StepKeys(tp.NamedTuple): + next_key: jax.Array + sample_key: jax.Array + explore_key: jax.Array + model_key: jax.Array | None + + +def _split_step_keys( + key: jax.Array, + *, + fixed_model_key: bool, + fixed_model_rng: jax.Array | None, +) -> _StepKeys: + """Splits one sampling step's worth of RNG state (dense and cached paths).""" + key, sample_key, explore_key = jax.random.split(key, 3) + if fixed_model_key: + model_key = fixed_model_rng + else: + key, model_key = jax.random.split(key) + return _StepKeys( + next_key=key, + sample_key=sample_key, + explore_key=explore_key, + model_key=model_key, + ) + + @dataclasses.dataclass(frozen=True) class ARSampler: """Autoregressively sample *a batch* of sequences in parallel.""" @@ -39,6 +74,7 @@ class ARSampler: fixed_model_key: bool = False # If True, model key is fixed per episode. epsilon: float = 0.0 # Probability of taking a random action.
vocab_size: int | None = None # Must be provided if epsilon > 0 + use_decode_cache: bool = True def __post_init__(self): if self.epsilon > 0 and self.vocab_size is None: @@ -50,8 +86,26 @@ def __call__( params: tp.Any, prompts: jax.Array, # (B, T_p) or (T_p,) key: jax.Array, + model: tp.Any | None = None, ) -> SampleResult: """Autoregressively sample *a batch* of sequences in parallel.""" + if ( + self.use_decode_cache + and model is not None + and hasattr(model, "decode_step") + and hasattr(model, "init_decode_cache") + ): + return self._sample_cached(apply_fn, params, prompts, key, model) + return self._sample_dense(apply_fn, params, prompts, key) + + def _sample_dense( + self, + apply_fn: tp.Callable[..., jax.Array | dict[str, jax.Array]], + params: tp.Any, + prompts: jax.Array, + key: jax.Array, + ) -> SampleResult: + """Dense autoregressive sampling without decode cache.""" squeeze = prompts.ndim == 1 if squeeze: prompts = prompts[None, :] # (1, T_p) @@ -71,31 +125,31 @@ def __call__( init = _Carry(seq=seq, logp=logp, key=key) def step(t: int, carry: _Carry) -> _Carry: - seq, logp, k = carry - k, k_sample, k_explore = jax.random.split(k, 3) - - if self.fixed_model_key: - current_k_model = k_model - else: - k, current_k_model = jax.random.split(k) - - outs = apply_fn({"params": params}, seq, rngs={"noise": current_k_model}) - if isinstance(outs, dict): - logits = outs["policy_logits"] - else: - logits = outs + seq, logp, key = carry + step_keys = _split_step_keys( + key, + fixed_model_key=self.fixed_model_key, + fixed_model_rng=k_model, + ) + + outs = apply_fn({"params": params}, seq, rngs={"noise": step_keys.model_key}) + logits = _extract_logits(outs) logits_t = logits[:, t - 1] # (B, V) # Sample from model - model_tok = jax.random.categorical(k_sample, logits_t) # (B,) + model_tok = jax.random.categorical(step_keys.sample_key, logits_t) # (B,) if self.epsilon > 0: # Epsilon-greedy exploration explore_action = jax.random.randint( - k_explore, shape=(batch_size,), minval=0, maxval=self.vocab_size + step_keys.explore_key, + shape=(batch_size,), + minval=0, + maxval=self.vocab_size, ) explore_cond = ( - jax.random.uniform(k_explore, shape=(batch_size,)) < self.epsilon + jax.random.uniform(step_keys.explore_key, shape=(batch_size,)) + < self.epsilon ) tok = jnp.where(explore_cond, explore_action, model_tok) else: @@ -127,7 +181,7 @@ def step(t: int, carry: _Carry) -> _Carry: seq = seq.at[:, t].set(tok) logp = logp.at[:, t].set(lp_t) - return _Carry(seq, logp, k) + return _Carry(seq, logp, step_keys.next_key) final = jax.lax.fori_loop( lower=prompt_len, @@ -142,6 +196,205 @@ def step(t: int, carry: _Carry) -> _Carry: return SampleResult(tokens=tokens, log_probs=log_probs) + def _sample_cached( + self, + apply_fn: tp.Callable[..., jax.Array | dict[str, jax.Array]], + params: tp.Any, + prompts: jax.Array, + key: jax.Array, + model: tp.Any, + ) -> SampleResult: + """Autoregressive sampling with prompt prefill and cached decoding.""" + squeeze = prompts.ndim == 1 + if squeeze: + prompts = prompts[None, :] + batch_size, prompt_len = prompts.shape + if prompt_len == 0: + return self._sample_dense(apply_fn, params, prompts, key) + + seq = jnp.full( + (batch_size, self.sequence_length), self.pad_token, prompts.dtype + ) + seq = seq.at[:, :prompt_len].set(prompts) + logp = jnp.zeros((batch_size, self.sequence_length), jnp.float32) + + if self.fixed_model_key: + k_model, key = jax.random.split(key) + else: + k_model = None + + def decode_apply( + cache: dict[str, jax.Array] | None, + token: 
jax.Array, + model_key: jax.Array | None, + ) -> tuple[jax.Array, dict[str, jax.Array]]: + variables: dict[str, tp.Any] = {"params": params} + if cache is not None: + variables["cache"] = cache + if model_key is None: + outs, mutated = apply_fn( + variables, + token, + method=model.decode_step, + mutable=["cache"], + ) + else: + outs, mutated = apply_fn( + variables, + token, + rngs={"noise": model_key}, + method=model.decode_step, + mutable=["cache"], + ) + return _extract_logits(outs), mutated["cache"] + + def init_cache_apply(token: jax.Array) -> dict[str, jax.Array]: + _, mutated = apply_fn( + {"params": params}, + token, + method=model.init_decode_cache, + mutable=["cache"], + ) + return mutated["cache"] + + def prompt_token_at(i: int) -> jax.Array: + return jax.lax.dynamic_slice(prompts, (0, i), (batch_size, 1)) + + first_step_keys = _split_step_keys( + key, + fixed_model_key=self.fixed_model_key, + fixed_model_rng=k_model, + ) + cache = init_cache_apply(prompt_token_at(0)) + prefill_model_key = first_step_keys.model_key + logits_prev, cache = decode_apply( + cache, + prompt_token_at(0), + prefill_model_key, + ) + + def prefill_step( + i: int, + carry: tuple[dict[str, jax.Array], jax.Array], + ) -> tuple[dict[str, jax.Array], jax.Array]: + current_cache, current_logits = carry + current_logits, current_cache = decode_apply( + current_cache, + prompt_token_at(i), + prefill_model_key, + ) + return current_cache, current_logits + + cache, logits_prev = jax.lax.fori_loop( + 1, + prompt_len, + prefill_step, + (cache, logits_prev), + ) + + init = _DecodeCarry( + seq=seq, + logp=logp, + key=first_step_keys.next_key, + sample_key=first_step_keys.sample_key, + explore_key=first_step_keys.explore_key, + cache=cache, + logits_prev=logits_prev, + ) + + def step(t: int, carry: _DecodeCarry) -> _DecodeCarry: + seq, logp, key, sample_key, explore_key, cache, logits_prev = carry + + logits_t = logits_prev[:, 0] # (B, V) + model_tok = jax.random.categorical(sample_key, logits_t) + + if self.epsilon > 0: + explore_action = jax.random.randint( + explore_key, + shape=(batch_size,), + minval=0, + maxval=self.vocab_size, + ) + explore_cond = ( + jax.random.uniform(explore_key, shape=(batch_size,)) < self.epsilon + ) + tok = jnp.where(explore_cond, explore_action, model_tok) + else: + tok = model_tok + + log_probs_model = jax.nn.log_softmax(logits_t) + lp_t_model = log_probs_model[jnp.arange(batch_size), tok] + + if self.epsilon > 0: + prob_tok_model = jnp.exp(lp_t_model) + prob_tok_random = 1.0 / self.vocab_size + mixed_prob_tok = ( + 1.0 - self.epsilon + ) * prob_tok_model + self.epsilon * prob_tok_random + lp_t = jnp.log(mixed_prob_tok + 1e-9) + else: + lp_t = lp_t_model + + seq = seq.at[:, t].set(tok) + logp = logp.at[:, t].set(lp_t) + + def decode_next( + state: tuple[jax.Array, dict[str, jax.Array]], + ) -> tuple[jax.Array, jax.Array, jax.Array, dict[str, jax.Array], jax.Array]: + inner_key, inner_cache = state + next_step_keys = _split_step_keys( + inner_key, + fixed_model_key=self.fixed_model_key, + fixed_model_rng=k_model, + ) + next_logits, inner_cache = decode_apply( + inner_cache, + tok[:, None], + next_step_keys.model_key, + ) + return ( + next_step_keys.next_key, + next_step_keys.sample_key, + next_step_keys.explore_key, + inner_cache, + next_logits, + ) + + def keep_state( + state: tuple[jax.Array, dict[str, jax.Array]], + ) -> tuple[jax.Array, jax.Array, jax.Array, dict[str, jax.Array], jax.Array]: + inner_key, inner_cache = state + return inner_key, sample_key, explore_key, inner_cache, 
logits_prev + + key, sample_key, explore_key, cache, logits_prev = jax.lax.cond( + t + 1 < self.sequence_length, + decode_next, + keep_state, + (key, cache), + ) + return _DecodeCarry( + seq, + logp, + key, + sample_key, + explore_key, + cache, + logits_prev, + ) + + final = jax.lax.fori_loop( + lower=prompt_len, + upper=self.sequence_length, + body_fun=step, + init_val=init, + ) + + tokens, log_probs, _, _, _, _, _ = final + if squeeze: + tokens, log_probs = tokens[0], log_probs[0] + + return SampleResult(tokens=tokens, log_probs=log_probs) + def get_full_logprobs_b_l( apply_fn: tp.Callable[..., jax.Array], @@ -189,3 +442,13 @@ class _Carry(tp.NamedTuple): seq: jax.Array # (B, seq_length) logp: jax.Array # (B, seq_length) key: jax.Array + + +class _DecodeCarry(tp.NamedTuple): + seq: jax.Array # (B, seq_length) + logp: jax.Array # (B, seq_length) + key: jax.Array + sample_key: jax.Array + explore_key: jax.Array + cache: dict[str, jax.Array] + logits_prev: jax.Array # (B, 1, vocab_size) diff --git a/egg/losses/common.py b/egg/losses/common.py index 38cc02a..52910a2 100644 --- a/egg/losses/common.py +++ b/egg/losses/common.py @@ -34,6 +34,28 @@ class PolicyLogProbs: row_mask: jax.Array # (B,) — valid rows + +@dataclasses.dataclass(frozen=True) +class DelightSignals: + """Learner-side screening signals used by DG and proper Kondo.""" + + fwd: PolicyLogProbs + sampler_lp_answer: jax.Array # (B, A) + advantages: jax.Array # (B,) + surprisal_tok: jax.Array # (B, A) + delight_tok: jax.Array # (B, A) + priority_tok: jax.Array # (B, A) + priority_row: jax.Array # (B,) + + +def answer_token_mask(batch: base.Batch) -> tuple[jax.Array, jax.Array]: + """Returns row and answer-token masks for a batch.""" + prompts, answers = batch.prompts, batch.answers + row_mask = batch.aux.get("row_mask", jnp.ones(prompts.shape[0], jnp.float32)) + row_mask = row_mask.astype(jnp.float32) + token_mask = (answers >= 0).astype(jnp.float32) * row_mask[:, None] + return row_mask, token_mask + + def forward_pass( params: base.Params, state: base.StateT, @@ -55,14 +77,12 @@ def forward_pass( start = prompt_len - 1 end = prompt_len + answer_len - 1 - ans_tok_mask = (answers >= 0).astype(jnp.float32) - row_mask = batch.aux.get("row_mask", jnp.ones(prompts.shape[0], jnp.float32)) - row_mask = row_mask.astype(jnp.float32) + row_mask, token_mask = answer_token_mask(batch) return PolicyLogProbs( lp_all=lp_all[:, start:end, :], lp_answer=lp_pol[:, start:end], - token_mask=ans_tok_mask * row_mask[:, None], + token_mask=token_mask, row_mask=row_mask, ) @@ -92,3 +112,180 @@ def grouped_advantages( cnt_r = jnp.zeros(num_groups, jnp.float32).at[group_ids].add(row_mask) baseline_per_group = sum_r / (cnt_r + eps) return rewards - jax.lax.stop_gradient(baseline_per_group[group_ids]) + + +def compute_priority( + priority: str, + advantage: jax.Array, + surprisal_tok: jax.Array, + alpha: float = 0.5, +) -> jax.Array: + """Per-token Kondo priority score.""" + # Accept a plain string or a string-valued enum such as kondo.PriorityType. + priority = getattr(priority, "value", priority) + if priority == "delight": + return advantage[:, None] * surprisal_tok + elif priority == "advantage": + return jnp.broadcast_to(advantage[:, None], surprisal_tok.shape) + elif priority == "abs_advantage": + return jnp.broadcast_to(jnp.abs(advantage[:, None]), surprisal_tok.shape) + elif priority == "surprisal": + return surprisal_tok + elif priority == "uniform": + return jnp.ones_like(surprisal_tok) + elif priority == "additive": + return alpha * advantage[:, None] + (1.0 - alpha) * surprisal_tok + else: + raise ValueError(f"Unknown priority: {priority}") + + +def topk_token_gate(
priority_tok: jax.Array, + token_mask: jax.Array, + pct_learn: float, +) -> tuple[jax.Array, jax.Array, jax.Array]: + """Binary top-k token gate matching the dense Kondo masking rule.""" + total_valid_tokens = jnp.sum(token_mask) + k_target = jnp.maximum( + 1, jnp.round(jnp.asarray(pct_learn, jnp.float32) * total_valid_tokens).astype(jnp.int32) + ) + vals_flat = jnp.where(token_mask > 0.0, priority_tok, -jnp.inf).reshape(-1) + sorted_vals = jnp.sort(vals_flat) + threshold = sorted_vals[vals_flat.size - k_target] - 1e-6 + gate = (priority_tok >= threshold).astype(jnp.float32) * token_mask + return gate, threshold, k_target + + +def delight_signals( + params: base.Params, + state: base.StateT, + batch: base.Batch, + key: jax.Array, + *, + use_grouped_baseline: bool, + num_groups: int | None, + priority: str = "delight", + alpha_additive: float = 0.5, +) -> DelightSignals: + """Computes learner-side delight and row-level screening scores.""" + fwd = forward_pass(params, state, batch, key) + sampler_lp_answer = sampler_answer_logprobs(batch) + rewards = batch.rewards + + if use_grouped_baseline: + group_ids = batch.aux.get("group_ids") + advantages = grouped_advantages( + rewards, + group_ids, + num_groups, + fwd.row_mask, + ) + else: + advantages = grouped_advantages(rewards, None, None, fwd.row_mask) + + surprisal_tok = -fwd.lp_answer + delight_tok = advantages[:, None] * surprisal_tok + priority_tok = compute_priority( + priority, + advantages, + surprisal_tok, + alpha=alpha_additive, + ) + + tok_count_per_row = jnp.sum(fwd.token_mask, axis=1) + safe_tok_count = tok_count_per_row + 1e-8 + priority_row = ( + jnp.sum(priority_tok * fwd.token_mask, axis=1) / safe_tok_count + ) * fwd.row_mask + + return DelightSignals( + fwd=fwd, + sampler_lp_answer=sampler_lp_answer, + advantages=advantages, + surprisal_tok=surprisal_tok, + delight_tok=delight_tok, + priority_tok=priority_tok, + priority_row=priority_row, + ) + + +def delight_signals_from_sample_logprobs( + batch: base.Batch, + *, + use_grouped_baseline: bool, + num_groups: int | None, + priority: str = "delight", + alpha_additive: float = 0.5, +) -> DelightSignals: + """Computes screening signals directly from stored sampler log-probs. + + This is only exact when the actor policy matches the learner policy. 
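+ When they diverge (delayed or quantized sampler params, epsilon exploration, + or sampler logit noise), the stored log-probs are off-policy and the + resulting priorities are biased; kondo_async checks for this before reuse.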
+ """ + row_mask, token_mask = answer_token_mask(batch) + sampler_lp_answer = sampler_answer_logprobs(batch) + rewards = batch.rewards + + if use_grouped_baseline: + group_ids = batch.aux.get("group_ids") + advantages = grouped_advantages( + rewards, + group_ids, + num_groups, + row_mask, + ) + else: + advantages = grouped_advantages(rewards, None, None, row_mask) + + surprisal_tok = -sampler_lp_answer + delight_tok = advantages[:, None] * surprisal_tok + priority_tok = compute_priority( + priority, + advantages, + surprisal_tok, + alpha=alpha_additive, + ) + + tok_count_per_row = jnp.sum(token_mask, axis=1) + safe_tok_count = tok_count_per_row + 1e-8 + priority_row = ( + jnp.sum(priority_tok * token_mask, axis=1) / safe_tok_count + ) * row_mask + + dummy_lp_all = jnp.zeros(surprisal_tok.shape + (1,), dtype=sampler_lp_answer.dtype) + fwd = PolicyLogProbs( + lp_all=dummy_lp_all, + lp_answer=sampler_lp_answer, + token_mask=token_mask, + row_mask=row_mask, + ) + return DelightSignals( + fwd=fwd, + sampler_lp_answer=sampler_lp_answer, + advantages=advantages, + surprisal_tok=surprisal_tok, + delight_tok=delight_tok, + priority_tok=priority_tok, + priority_row=priority_row, + ) + + +def compact_batch_rows(batch: base.Batch, row_indices: jax.Array) -> base.Batch: + """Selects a fixed subset of rows from a batch and row-shaped aux fields.""" + prompts = jnp.take(batch.prompts, row_indices, axis=0) + answers = jnp.take(batch.answers, row_indices, axis=0) + rewards = jnp.take(batch.rewards, row_indices, axis=0) + sample_log_probs = jnp.take(batch.sample_log_probs, row_indices, axis=0) + + aux: dict[str, jax.Array] = {} + batch_rows = batch.prompts.shape[0] + for key, value in batch.aux.items(): + if hasattr(value, "shape") and value.shape and value.shape[0] == batch_rows: + aux[key] = jnp.take(value, row_indices, axis=0) + else: + aux[key] = value + + return base.Batch( + prompts=prompts, + answers=answers, + rewards=rewards, + sample_log_probs=sample_log_probs, + aux=aux, + ) diff --git a/egg/losses/kondo.py b/egg/losses/kondo.py index 06b2a88..e36f1ae 100644 --- a/egg/losses/kondo.py +++ b/egg/losses/kondo.py @@ -29,6 +29,7 @@ from absl import logging from egg import base from egg.lib import statistics +from egg.losses import common import jax import jax.numpy as jnp @@ -71,31 +72,6 @@ def make(self) -> 'KondoLoss': alpha_additive=float(self.alpha_additive), ) - -def _compute_priority( - priority: PriorityType, - advantage: jnp.ndarray, - surprisal_tok: jnp.ndarray, - token_mask: jnp.ndarray, # pylint: disable=unused-argument - alpha: float, -) -> jnp.ndarray: - """Per-token priority score, shape (B, A).""" - if priority == PriorityType.DELIGHT: - return advantage[:, None] * surprisal_tok - elif priority == PriorityType.ADVANTAGE: - return jnp.broadcast_to(advantage[:, None], surprisal_tok.shape) - elif priority == PriorityType.ABS_ADVANTAGE: - return jnp.broadcast_to(jnp.abs(advantage[:, None]), surprisal_tok.shape) - elif priority == PriorityType.SURPRISAL: - return surprisal_tok - elif priority == PriorityType.UNIFORM: - return jnp.ones_like(surprisal_tok) - elif priority == PriorityType.ADDITIVE: - return alpha * advantage[:, None] + (1.0 - alpha) * surprisal_tok - else: - raise ValueError(f'Unknown priority: {priority}') - - @dataclasses.dataclass(frozen=True) class KondoLoss(base.LossFn): """Token-level Kondo gate. 
@@ -175,8 +151,11 @@ def __call__( a_tok = (advantage[:, None]) * token_mask surprisal_tok = -lp_pol_answer - priority_score = _compute_priority( - self.priority, advantage, surprisal_tok, token_mask, self.alpha_additive + priority_score = common.compute_priority( + self.priority, + advantage, + surprisal_tok, + alpha=self.alpha_additive, ) k_target = jnp.maximum( diff --git a/egg/losses/screened_pg.py b/egg/losses/screened_pg.py new file mode 100644 index 0000000..7e03b5d --- /dev/null +++ b/egg/losses/screened_pg.py @@ -0,0 +1,94 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Policy-gradient loss for pre-screened compacted batches.""" + +from __future__ import annotations + +import dataclasses + +from egg import base +from egg.lib import statistics +from egg.losses import common +import jax +import jax.numpy as jnp + + +@dataclasses.dataclass +class LossConfig(base.MakeableConfig[base.LossFn]): + """Loss config for compacted Kondo batches with precomputed advantages.""" + + beta_kl: float = 0.0 + + def make(self) -> "ScreenedPolicyGradient": + return ScreenedPolicyGradient(beta_kl=float(self.beta_kl)) + + +@dataclasses.dataclass(frozen=True) +class ScreenedPolicyGradient(base.LossFn): + """Token-level PG on compacted rows with optional KL anchor.""" + + beta_kl: float = 0.0 + + def __call__( + self, + params: base.Params, + state: base.StateT, + batch: base.Batch, + key: jax.Array, + ) -> tuple[jax.Array, base.Metrics]: + fwd = common.forward_pass(params, state, batch, key) + lp_samp = common.sampler_answer_logprobs(batch) + advantages = batch.rewards + loss_token_mask = batch.aux.get("loss_token_mask_answer") + if loss_token_mask is None: + loss_token_mask = fwd.token_mask + loss_token_mask = loss_token_mask.astype(jnp.float32) * fwd.token_mask + loss_normalizer = batch.aux.get("loss_normalizer") + if loss_normalizer is None: + loss_normalizer = jnp.sum(loss_token_mask) + 1e-8 + else: + loss_normalizer = jnp.asarray(loss_normalizer, jnp.float32) + 1e-8 + + a_tok = advantages[:, None] * loss_token_mask + per_tok_pg = -a_tok * fwd.lp_answer + backward_tok_count = jnp.sum(fwd.token_mask) + 1e-8 + selected_tok_count = jnp.sum(loss_token_mask) + 1e-8 + loss_pg = jnp.sum(per_tok_pg) / loss_normalizer + + kl_tok = (lp_samp - fwd.lp_answer) * fwd.token_mask + loss_kl = jnp.sum(kl_tok) / backward_tok_count + loss = loss_pg + self.beta_kl * loss_kl + + valid_rows = jnp.sum(fwd.row_mask) + 1e-8 + metrics: base.Metrics = { + "loss": loss, + "loss_pg": loss_pg, + "loss_kl": loss_kl, + "beta_kl": jnp.asarray(self.beta_kl, jnp.float32), + "advantage_mean": jnp.sum(advantages * fwd.row_mask) / valid_rows, + "valid_row_count": valid_rows, + "valid_token_count": backward_tok_count, + "selected_token_count": selected_tok_count, + "loss_normalizer": loss_normalizer, + **statistics.scalar_stats( + statistics.entropy_from_logp(fwd.lp_all), "policy_entropy" + ), + **statistics.logp_stats( + 
learner_logp=fwd.lp_answer, + sampler_logp=lp_samp, + ), + } + return loss, metrics diff --git a/egg/networks/logit_noise.py b/egg/networks/logit_noise.py index 194c645..8b7c43a 100644 --- a/egg/networks/logit_noise.py +++ b/egg/networks/logit_noise.py @@ -65,7 +65,25 @@ def setup(self): def __call__(self, x: jax.Array, *args, **kwargs) -> jax.Array: """Calls the inner network and adds the logit noise.""" base_logits = self.inner_network(x, *args, **kwargs) + return self._add_noise(base_logits) + + def decode_step(self, x: jax.Array, *args, **kwargs) -> jax.Array: + """Single-token cached decode step, forwarded to the inner network.""" + if not hasattr(self.inner_network, "decode_step"): + raise ValueError("Inner network does not support cached decoding.") + base_logits = self.inner_network.decode_step(x, *args, **kwargs) + return self._add_noise(base_logits) + + def init_decode_cache(self, x: jax.Array, *args, **kwargs) -> jax.Array: + """Initializes the inner network's decode cache without sampling noise.""" + if not hasattr(self.inner_network, "init_decode_cache"): + raise ValueError("Inner network does not support cached decoding.") + return self.inner_network.init_decode_cache(x, *args, **kwargs) + + def _add_noise(self, base_logits: jax.Array) -> jax.Array: + """Applies the configured logit noise.""" if self.sigma == 0.0: return base_logits diff --git a/egg/networks/transformers.py b/egg/networks/transformers.py index 9b5a559..4e0e9bd 100644 --- a/egg/networks/transformers.py +++ b/egg/networks/transformers.py @@ -69,11 +69,73 @@ def __call__(self, tokens: jax.Array, **kwargs) -> jax.Array: return output_layer(x) # (B, T, V) + @nn.compact + def decode_step(self, tokens: jax.Array, **kwargs) -> jax.Array: + """Single-token cached decode step for autoregressive sampling.""" + del kwargs # Unused. + + if tokens.ndim != 2 or tokens.shape[1] != 1: + raise ValueError( + "decode_step expects tokens with shape (batch_size, 1)." + ) + + cache_index = self.variable( + "cache", + "position", + lambda: jnp.array(0, dtype=jnp.int32), + ) + x = _embed_decode(tokens, self.config, cache_index.value) + cache_index.value = cache_index.value + tokens.shape[1] + + for i in range(self.config.num_layers): + x = TransformerBlock( + self.config, + decode=True, + name=f"layer_{i}", + )(x, mask=None) + + output_layer = nn.Dense( + self.config.vocab_size, use_bias=self.config.bias, name="output" + ) + return output_layer(x) # (B, 1, V) + + @nn.compact + def init_decode_cache(self, tokens: jax.Array, **kwargs) -> jax.Array: + """Initializes decode-cache state without consuming a real token.""" + del kwargs # Unused. + + if tokens.ndim != 2 or tokens.shape[1] != 1: + raise ValueError( + "init_decode_cache expects tokens with shape (batch_size, 1)."
+ ) + + batch_size = tokens.shape[0] + _ = self.variable( + "cache", + "position", + lambda: jnp.array(0, dtype=jnp.int32), + ) + dummy_tokens = jnp.zeros( + (batch_size, self.config.sequence_length), + dtype=tokens.dtype, + ) + x = _embed_decode(dummy_tokens, self.config, jnp.array(0, dtype=jnp.int32)) + + for i in range(self.config.num_layers): + x = TransformerBlock( + self.config, + decode=True, + name=f"layer_{i}", + )(x, mask=None) + + return jnp.zeros((batch_size, 1, self.config.vocab_size), dtype=x.dtype) + class TransformerBlock(nn.Module): """Single transformer block: pre-norm attention + feed-forward.""" config: NetworkConfig + decode: bool = False @nn.compact def __call__(self, x: jax.Array, mask: jax.Array) -> jax.Array: @@ -82,6 +144,7 @@ def __call__(self, x: jax.Array, mask: jax.Array) -> jax.Array: h = nn.MultiHeadDotProductAttention( num_heads=self.config.num_heads, use_bias=self.config.bias, + decode=self.decode, name="self_attn", )(inputs_q=h, inputs_k=h, inputs_v=h, mask=mask) x = x + h # Residual @@ -127,3 +190,21 @@ def _embed( x = (tok_emb + pos_emb) * pad_mask # zero-out PAD rows return x, pad_mask + + +def _embed_decode( + tokens: jax.Array, + cfg: NetworkConfig, + start_position: jax.Array, +) -> jax.Array: + """Embed one decode chunk at its absolute position.""" + pad_mask = (tokens != base.PAD_TOKEN)[..., None] # (B, T, 1) + + safe_ids = jnp.where(tokens == base.PAD_TOKEN, 0, tokens) + tok_emb = nn.Embed(cfg.vocab_size, cfg.embed_dim, name="token_emb")(safe_ids) + + pos_idx = start_position + jnp.arange(tokens.shape[1], dtype=jnp.int32) + pos_emb = nn.Embed(cfg.sequence_length, cfg.embed_dim, name="pos_emb")(pos_idx) + pos_emb = jnp.expand_dims(pos_emb, 0) # (1, T, D) + + return (tok_emb + pos_emb) * pad_mask diff --git a/egg/trainers/kondo_async.py b/egg/trainers/kondo_async.py new file mode 100644 index 0000000..63b1d44 --- /dev/null +++ b/egg/trainers/kondo_async.py @@ -0,0 +1,373 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== + +"""Async trainer with proper Kondo screening and compacted backward passes.""" + +from __future__ import annotations + +import dataclasses +import json +import time +import typing as tp + +from egg import base +from egg.lib import logging +from egg.lib import quantization +from egg.losses import common +import jax +import jax.numpy as jnp +import pandas as pd + + +def _block_tree(tree): + return jax.tree_util.tree_map(jax.block_until_ready, tree) + + +@dataclasses.dataclass(frozen=True) +class TrainerConfig(base.MakeableConfig["KondoAsyncTrainer"]): + """Configuration for the proper Kondo async trainer.""" + + steps: int = 1000 + seed: int = 0 + log_freq: int | None = None + sampler_delay: int = 0 + uniform_delay: bool = False + sampler_bits: int = 32 + deterministic: bool = True + log_details: bool = False + log_learner_performance: bool = True + + pct_learn: float = 0.5 + priority: str = "delight" + alpha_additive: float = 0.5 + use_grouped_baseline: bool = True + num_groups: int | None = None + + def make(self) -> "KondoAsyncTrainer": + if not (0.0 < self.pct_learn <= 1.0): + raise ValueError("pct_learn must be in (0, 1].") + return KondoAsyncTrainer(config=self) + + +@dataclasses.dataclass(frozen=True) +class KondoAsyncTrainer(base.Trainer): + """Async trainer that screens a full batch and backprops only on kept rows.""" + + config: TrainerConfig + + def _should_log(self, step: int) -> bool: + if self.config.log_freq is None: + return logging.logarithmic_logging(step) + return step % self.config.log_freq == 0 + + def _can_reuse_sampler_logprobs_for_screen( + self, + actor_cfg: tp.Any, + ) -> bool: + """Returns True if sampler log-probs exactly match learner log-probs.""" + if self.config.sampler_delay != 0: + return False + if self.config.sampler_bits < 32 or not self.config.deterministic: + return False + if getattr(actor_cfg, "epsilon", 0.0) != 0.0: + return False + if getattr(actor_cfg, "bug_prob", 0.0) != 0.0: + return False + if getattr(actor_cfg, "correct_prob", 0.0) != 0.0: + return False + if getattr(actor_cfg, "random_prob", 0.0) != 0.0: + return False + if getattr(actor_cfg, "override_token_prob", None) is not None: + return False + sampler_net_cfg = getattr(actor_cfg, "sampler_network_config", None) + if sampler_net_cfg is None: + return True + return getattr(sampler_net_cfg, "sigma", 0.0) == 0.0 + + def __call__( + self, + actor: base.Actor[base.StateT], + learner: base.Learner[base.StateT], + ) -> pd.DataFrame: + actor_cfg = getattr(actor, "config", None) + if actor_cfg is None: + raise ValueError("KondoAsyncTrainer expects an actor with a config.") + batch_size = actor_cfg.prompts_per_batch * actor_cfg.samples_per_prompt + keep_count = max(1, int(round(self.config.pct_learn * batch_size))) + reuse_sampler_logprobs_for_screen = ( + self._can_reuse_sampler_logprobs_for_screen(actor_cfg) + ) + + quantizer = quantization.QuantizeConfig( + num_bits=self.config.sampler_bits, + deterministic=self.config.deterministic, + ).make() + + key = jax.random.PRNGKey(self.config.seed) + key, init_key = jax.random.split(key) + state = learner.init_state(init_key) + + hist_len = self.config.sampler_delay + 1 + params_history: tuple[tp.Any, ...] 
= tuple([state.params] * hist_len) + + @jax.jit + def sample_batch_step( + current_state: base.StateT, + actor_params: tp.Any, + key: jax.Array, + ) -> tuple[base.Batch, base.Metrics]: + actor_state = current_state.replace(params=actor_params) + batch, _, actor_metrics = actor.sample_batch(actor_state, key) + return batch, actor_metrics + + @jax.jit + def screen_batch_step( + current_state: base.StateT, + batch: base.Batch, + key: jax.Array, + ) -> tuple[jax.Array, jax.Array, jax.Array, base.Metrics]: + if reuse_sampler_logprobs_for_screen: + signals = common.delight_signals_from_sample_logprobs( + batch, + use_grouped_baseline=self.config.use_grouped_baseline, + num_groups=self.config.num_groups, + priority=self.config.priority, + alpha_additive=self.config.alpha_additive, + ) + else: + signals = common.delight_signals( + current_state.params, + current_state, + batch, + key, + use_grouped_baseline=self.config.use_grouped_baseline, + num_groups=self.config.num_groups, + priority=self.config.priority, + alpha_additive=self.config.alpha_additive, + ) + token_gate, gate_threshold, k_target = common.topk_token_gate( + signals.priority_tok, + signals.fwd.token_mask, + self.config.pct_learn, + ) + row_scores = jnp.sum(token_gate * signals.priority_tok, axis=1) + row_scores = jnp.where(signals.fwd.row_mask > 0.0, row_scores, -jnp.inf) + top_values, keep_idx = jax.lax.top_k(row_scores, keep_count) + valid_rows = jnp.sum(signals.fwd.row_mask) + 1e-8 + tok_count = jnp.sum(signals.fwd.token_mask) + 1e-8 + selected_tok_count = jnp.sum(token_gate) + 1e-8 + kept_selected_tok_count = jnp.sum(jnp.take(token_gate, keep_idx, axis=0)) + delight_row = ( + jnp.sum(signals.delight_tok * signals.fwd.token_mask, axis=1) + / (jnp.sum(signals.fwd.token_mask, axis=1) + 1e-8) + ) * signals.fwd.row_mask + + metrics: base.Metrics = { + "rows_total": valid_rows, + "rows_kept": jnp.asarray(keep_count, jnp.float32), + "keep_fraction": jnp.asarray(keep_count, jnp.float32) / valid_rows, + "token_keep_fraction": selected_tok_count / tok_count, + "gate_threshold": gate_threshold, + "row_gate_threshold": top_values[-1], + "token_k_target": jnp.asarray(k_target, jnp.float32), + # Zero out masked rows before summing; -inf * 0.0 would produce NaN. + "priority_row_mean": jnp.sum( + jnp.where(signals.fwd.row_mask > 0.0, row_scores, 0.0) + ) + / valid_rows, + "priority_row_kept_mean": jnp.mean(top_values), + "advantage_mean": jnp.sum(signals.advantages * signals.fwd.row_mask) + / valid_rows, + "surprisal_token_mean": jnp.sum( + signals.surprisal_tok * signals.fwd.token_mask + ) + / tok_count, + "delight_token_mean": jnp.sum( + signals.delight_tok * signals.fwd.token_mask + ) + / tok_count, + "delight_row_mean": jnp.sum(delight_row) / valid_rows, + "tokens_selected": selected_tok_count, + "tokens_selected_in_kept_rows": kept_selected_tok_count, + "selected_token_recall": kept_selected_tok_count / selected_tok_count, + "used_sampler_logprobs": jnp.asarray( + reuse_sampler_logprobs_for_screen, jnp.float32 + ), + } + return keep_idx, signals.advantages, token_gate, metrics + + @jax.jit + def compact_batch_step( + batch: base.Batch, + keep_idx: jax.Array, + advantages: jax.Array, + token_gate: jax.Array, + ) -> tuple[base.Batch, base.Metrics]: + kept_batch = common.compact_batch_rows(batch, keep_idx) + kept_rewards = jnp.take(advantages, keep_idx, axis=0) + kept_token_gate = jnp.take(token_gate, keep_idx, axis=0) + kept_aux = dict(kept_batch.aux) + kept_aux["row_mask"] = jnp.ones((keep_count,), dtype=jnp.float32) + kept_aux["parent_row_ids"] = keep_idx + kept_aux["loss_token_mask_answer"] = kept_token_gate
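+ # Normalizing by the full batch's answer tokens (not just the kept rows) + # keeps the screened PG loss on the same scale as the dense Kondo gate. +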
kept_aux["loss_normalizer"] = jnp.sum( + (batch.answers >= 0).astype(jnp.float32) + ) + kept_batch = kept_batch._replace(rewards=kept_rewards, aux=kept_aux) + + total_tokens = jnp.sum((batch.answers >= 0).astype(jnp.float32)) + kept_tokens = jnp.sum((kept_batch.answers >= 0).astype(jnp.float32)) + selected_tokens = jnp.sum(kept_token_gate) + metrics: base.Metrics = { + "tokens_total": total_tokens, + "tokens_kept_for_backward": kept_tokens, + "backward_token_fraction": kept_tokens / (total_tokens + 1e-8), + "tokens_selected_for_loss": selected_tokens, + "selected_token_fraction_in_kept_rows": selected_tokens + / (kept_tokens + 1e-8), + } + return kept_batch, metrics + + @jax.jit + def train_on_kept_step( + current_state: base.StateT, + kept_batch: base.Batch, + key: jax.Array, + ) -> tuple[base.StateT, base.Metrics]: + return learner.step(current_state, kept_batch, key) + + @jax.jit + def evaluate_learner_performance( + current_state: base.StateT, + key: jax.Array, + ) -> jax.Array: + batch, _, _ = actor.sample_batch(current_state, key) + return batch.rewards.mean() + + # Warm up compiled paths once so timing excludes first-compile cost. + key, warm_quant_key, warm_sample_key, warm_screen_key, warm_train_key = ( + jax.random.split(key, 5) + ) + warm_actor_params = quantizer(state.params, warm_quant_key) + warm_batch, _ = sample_batch_step(state, warm_actor_params, warm_sample_key) + warm_keep_idx, warm_advantages, warm_token_gate, _ = screen_batch_step( + state, warm_batch, warm_screen_key + ) + warm_kept_batch, _ = compact_batch_step( + warm_batch, warm_keep_idx, warm_advantages, warm_token_gate + ) + warm_state, _ = train_on_kept_step(state, warm_kept_batch, warm_train_key) + _block_tree(warm_batch.prompts) + _block_tree(warm_keep_idx) + _block_tree(warm_kept_batch.prompts) + _block_tree(warm_state.params) + + records: list[dict[str, tp.Any]] = [] + start = time.perf_counter() + cumulative_reward = 0.0 + logger = logging.RunningMeanLogger() + + for step in range(1, self.config.steps + 1): + key, quant_key, sample_key, screen_key, train_key, eval_key = ( + jax.random.split(key, 6) + ) + + if self.config.uniform_delay: + quant_key, draw_key = jax.random.split(quant_key) + idx = int( + jax.random.randint(draw_key, shape=(), minval=0, maxval=hist_len) + ) + chosen_params = params_history[idx] + else: + chosen_params = params_history[0] + actor_params = quantizer(chosen_params, quant_key) + + sample_t0 = time.perf_counter() + batch, actor_metrics = sample_batch_step(state, actor_params, sample_key) + _block_tree(batch.prompts) + sample_time_s = time.perf_counter() - sample_t0 + + screen_t0 = time.perf_counter() + keep_idx, advantages, token_gate, screen_metrics = screen_batch_step( + state, batch, screen_key + ) + _block_tree(keep_idx) + screen_time_s = time.perf_counter() - screen_t0 + + compact_t0 = time.perf_counter() + kept_batch, compact_metrics = compact_batch_step( + batch, keep_idx, advantages, token_gate + ) + _block_tree(kept_batch.prompts) + compact_time_s = time.perf_counter() - compact_t0 + + train_t0 = time.perf_counter() + state, learner_metrics = train_on_kept_step(state, kept_batch, train_key) + _block_tree(state.params) + train_time_s = time.perf_counter() - train_t0 + + params_history = params_history[1:] + (state.params,) + + metrics: base.Metrics = { + "reward": batch.rewards.mean(), + "sample_time_s": jnp.asarray(sample_time_s, jnp.float32), + "screen_time_s": jnp.asarray(screen_time_s, jnp.float32), + "compact_time_s": jnp.asarray(compact_time_s, jnp.float32), + 
"train_time_s": jnp.asarray(train_time_s, jnp.float32), + "total_step_time_s": jnp.asarray( + sample_time_s + screen_time_s + compact_time_s + train_time_s, + jnp.float32, + ), + } + metrics.update({f"screen/{k}": v for k, v in screen_metrics.items()}) + metrics.update({f"compact/{k}": v for k, v in compact_metrics.items()}) + if self.config.log_details: + metrics.update({ + **{f"actor/{k}": v for k, v in actor_metrics.items()}, + **{f"learner/{k}": v for k, v in learner_metrics.items()}, + }) + + logger.record(metrics) + cumulative_reward += float(metrics["reward"]) + + if self._should_log(step) or step == self.config.steps: + record = logger.write() + record["step"] = step + record["time"] = time.perf_counter() - start + record["cum_reward"] = cumulative_reward + + if self.config.log_learner_performance: + learner_reward = evaluate_learner_performance(state, eval_key) + record["learner_reward"] = float(learner_reward) + + records.append(record) + + summary: dict[str, tp.Any] = { + "step": step, + "time": _format_float(record["time"]), + "reward": _format_float(record.get("reward", 0.0)), + "cum_reward": _format_float(cumulative_reward), + "keep_fraction": _format_float( + record.get("screen/keep_fraction", 0.0) + ), + "train_time_s": _format_float(record.get("train_time_s", 0.0)), + } + if self.config.log_learner_performance: + summary["learner_reward"] = _format_float(record["learner_reward"]) + print(json.dumps(summary)) + + print("--- Training finished ---") + return pd.DataFrame(records) + + +def _format_float(f: float) -> float: + return float(format(f, ".3g")) diff --git a/egg/trainers/vanilla_async.py b/egg/trainers/vanilla_async.py index f7ea871..9bded72 100644 --- a/egg/trainers/vanilla_async.py +++ b/egg/trainers/vanilla_async.py @@ -32,6 +32,10 @@ import pandas as pd +def _block_tree(tree): + return jax.tree_util.tree_map(jax.block_until_ready, tree) + + @dataclasses.dataclass(frozen=True) class TrainerConfig(base.MakeableConfig["VanillaAsyncTrainer"]): """Configuration for the vanilla async trainer.""" @@ -83,30 +87,22 @@ def __call__( params_history: tuple[tp.Any, ...] 
= tuple([state.params] * hist_len) @jax.jit - def train_step( + def sample_batch_step( current_state: base.StateT, actor_params: tp.Any, key: jax.Array, - ) -> tuple[base.StateT, jax.Array, base.Metrics]: - # Actor: sample a batch using provided actor_params - key, sample_key = jax.random.split(key) + ) -> tuple[base.Batch, base.Metrics]: actor_state = current_state.replace(params=actor_params) - batch, _, actor_metrics = actor.sample_batch(actor_state, sample_key) + batch, _, actor_metrics = actor.sample_batch(actor_state, key) + return batch, actor_metrics - # Learner: compute grads + update state - key, grad_key = jax.random.split(key) - new_state, learner_metrics = learner.step(current_state, batch, grad_key) - - # Conditional metric logging based on log_details - metrics: base.Metrics = { - "reward": batch.rewards.mean(), - } - if self.config.log_details: - metrics.update({ - **{f"actor/{k}": v for k, v in actor_metrics.items()}, - **{f"learner/{k}": v for k, v in learner_metrics.items()}, - }) - return new_state, key, metrics + @jax.jit + def train_on_batch_step( + current_state: base.StateT, + batch: base.Batch, + key: jax.Array, + ) -> tuple[base.StateT, base.Metrics]: + return learner.step(current_state, batch, key) @jax.jit def evaluate_learner_performance( @@ -117,6 +113,15 @@ def evaluate_learner_performance( batch, _, _ = actor.sample_batch(current_state, key) return batch.rewards.mean() + key, warm_quant_key, warm_sample_key, warm_train_key = jax.random.split( + key, 4 + ) + warm_actor_params = quantizer(state.params, warm_quant_key) + warm_batch, _ = sample_batch_step(state, warm_actor_params, warm_sample_key) + warm_state, _ = train_on_batch_step(state, warm_batch, warm_train_key) + _block_tree(warm_batch.prompts) + _block_tree(warm_state.params) + records: list[dict[str, tp.Any]] = [] start = time.time() cumulative_reward = 0.0 @@ -124,7 +129,9 @@ def evaluate_learner_performance( for step in range(1, self.config.steps + 1): - # RNG for: (quantization, training step, evaluation) - key, quant_key, train_key, eval_key = jax.random.split(key, 4) + # RNG for: (quantization, sampling, training, evaluation) + key, quant_key, sample_key, train_key, eval_key = jax.random.split( + key, 5 + ) # Choose which snapshot the actor will see this step. - # Fixed worst-case delay 0=oldest. Uniform delay sample in [0, hist_len]. + # Fixed worst-case delay 0=oldest. Uniform delay samples in [0, hist_len). @@ -141,8 +148,27 @@ def evaluate_learner_performance( # Quantize the chosen params for the actor.
actor_params = quantizer(chosen_params, quant_key) - # Run one training step - state, key, metrics = train_step(state, actor_params, train_key) + sample_t0 = time.perf_counter() + batch, actor_metrics = sample_batch_step(state, actor_params, sample_key) + _block_tree(batch.prompts) + sample_time_s = time.perf_counter() - sample_t0 + + train_t0 = time.perf_counter() + state, learner_metrics = train_on_batch_step(state, batch, train_key) + _block_tree(state.params) + train_time_s = time.perf_counter() - train_t0 + + metrics: base.Metrics = { + "reward": batch.rewards.mean(), + "sample_time_s": sample_time_s, + "train_time_s": train_time_s, + "total_step_time_s": sample_time_s + train_time_s, + } + if self.config.log_details: + metrics.update({ + **{f"actor/{k}": v for k, v in actor_metrics.items()}, + **{f"learner/{k}": v for k, v in learner_metrics.items()}, + }) # Update history with the new parameters (drop oldest, append newest) params_history = params_history[1:] + (state.params,) diff --git a/experiments/kondo/run.py b/experiments/kondo/run.py index 0ab05cf..054f367 100644 --- a/experiments/kondo/run.py +++ b/experiments/kondo/run.py @@ -21,7 +21,8 @@ Losses: delightful: DG with sigmoid gate (no backward-pass savings). - kondo: DG-K with binary Kondo gate (skips backward passes). + kondo: base token-mask Kondo, or proper Kondo compute skipping when + `proper_kondo=True`. reinforce: PG baseline. ppo, pmpo: standard RL baselines. """ @@ -41,8 +42,10 @@ from egg.losses import catalog from egg.losses import dg from egg.losses import kondo +from egg.losses import screened_pg from egg.networks import logit_noise from egg.networks import transformers +from egg.trainers import kondo_async from egg.trainers import vanilla_async import fancyflags as ff import jax.numpy as jnp @@ -82,18 +85,23 @@ class SweepConfig: # 'delight', 'advantage', 'abs_advantage', 'surprisal', 'uniform', 'additive' priority: str = 'delight' alpha_additive: float = 0.5 # Only used if priority = additive + proper_kondo: bool = True # Use compacted-batch compute skipping. learning_rate: float = 3e-4 seed: int = 42 num_steps: int = 1000 + log_freq: int | None = None prompts_per_batch: int = 10 samples_per_prompt: int = 10 log_details: bool = False -SweepFlags = ff.DEFINE_from_instance( - 'sweep', SweepConfig(), 'Sweep configuration.' -) +if hasattr(ff, 'DEFINE_from_instance'): + SweepFlags = ff.DEFINE_from_instance( + 'sweep', SweepConfig(), 'Sweep configuration.' 
+ ) +else: + SweepFlags = ff.DEFINE_auto('sweep', SweepConfig, 'Sweep configuration.') def run_experiment(sweep_config: SweepConfig) -> pd.DataFrame: @@ -161,14 +169,19 @@ def run_experiment(sweep_config: SweepConfig) -> pd.DataFrame: num_groups=sweep_config.prompts_per_batch, ) elif sweep_config.loss == 'kondo': - loss_cfg = kondo.LossConfig( - pct_learn=sweep_config.loss_param_one, - priority=sweep_config.priority, - alpha_additive=sweep_config.alpha_additive, - beta_kl=sweep_config.loss_param_two, - use_grouped_baseline=True, - num_groups=sweep_config.prompts_per_batch, - ) + if sweep_config.proper_kondo: + loss_cfg = screened_pg.LossConfig( + beta_kl=sweep_config.loss_param_two, + ) + else: + loss_cfg = kondo.LossConfig( + pct_learn=sweep_config.loss_param_one, + priority=sweep_config.priority, + alpha_additive=sweep_config.alpha_additive, + beta_kl=sweep_config.loss_param_two, + use_grouped_baseline=True, + num_groups=sweep_config.prompts_per_batch, + ) elif sweep_config.loss == 'ppo': loss_cfg = catalog.Loss.PPO.config( clip_epsilon=sweep_config.loss_param_one, @@ -200,11 +213,25 @@ def run_experiment(sweep_config: SweepConfig) -> pd.DataFrame: learning_rate=sweep_config.learning_rate, ) - trainer_cfg = vanilla_async.TrainerConfig( - steps=sweep_config.num_steps, - seed=sweep_config.seed, - log_details=sweep_config.log_details, - ) + if sweep_config.loss == 'kondo' and sweep_config.proper_kondo: + trainer_cfg = kondo_async.TrainerConfig( + steps=sweep_config.num_steps, + seed=sweep_config.seed, + log_freq=sweep_config.log_freq, + log_details=sweep_config.log_details, + pct_learn=sweep_config.loss_param_one, + priority=sweep_config.priority, + alpha_additive=sweep_config.alpha_additive, + use_grouped_baseline=True, + num_groups=sweep_config.prompts_per_batch, + ) + else: + trainer_cfg = vanilla_async.TrainerConfig( + steps=sweep_config.num_steps, + seed=sweep_config.seed, + log_freq=sweep_config.log_freq, + log_details=sweep_config.log_details, + ) actor = actor_cfg.make() learner = learner_cfg.make()
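# A self-contained sketch of the `topk_token_gate` rule that the proper-Kondo
# path above applies before compaction. The numbers are made up; only the
# masking/threshold logic mirrors egg/losses/common.py.
import jax.numpy as jnp

priority_tok = jnp.array([[0.9, 0.1, 0.5], [0.2, 0.8, 0.3]])  # (B=2, A=3)
token_mask = jnp.array([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])  # 5 valid tokens
pct_learn = 0.4  # keep the top 40% of valid tokens -> k_target = 2

total_valid = jnp.sum(token_mask)
k_target = jnp.maximum(1, jnp.round(pct_learn * total_valid).astype(jnp.int32))
vals = jnp.where(token_mask > 0.0, priority_tok, -jnp.inf).reshape(-1)
threshold = jnp.sort(vals)[vals.size - k_target] - 1e-6  # 2nd largest = 0.8
gate = (priority_tok >= threshold).astype(jnp.float32) * token_mask
# gate -> [[1, 0, 0], [0, 1, 0]]: only the 0.9 and 0.8 tokens reach the loss;
# kondo_async then ranks rows by their gated priority mass to choose keep_idx.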