Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,10 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST

# Local development environments and test caches
.venv/
.pytest_cache/

# Generated experiment outputs
experiments/kondo/artifacts/
6 changes: 5 additions & 1 deletion egg/actors/dreamer_bps.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ def sample_batch(
self.sampler_network.apply if self.sampler_network else state.apply_fn
)
seqs, logps_full = self.sampler(
sampler_apply_fn, state.params, prompts_flat, k_sample
sampler_apply_fn,
state.params,
prompts_flat,
k_sample,
model=self.sampler_network,
)

# Answer length.
Expand Down
1 change: 1 addition & 0 deletions egg/actors/fixed_bpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def _sample(prompts: jax.Array, k: jax.Array) -> ar_sample.SampleResult:
params=state.params,
prompts=prompts,
key=k,
model=self.sampler_network,
)

seqs, logps = jax.vmap(_sample)(active_prompts, sample_keys)
Expand Down
6 changes: 5 additions & 1 deletion egg/actors/fixed_bps.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ def sample_batch(
sampler_apply_fn = state.apply_fn

seqs, logps = self.sampler(
sampler_apply_fn, state.params, prompts_flat, k_sample
sampler_apply_fn,
state.params,
prompts_flat,
k_sample,
model=self.sampler_network,
)
answers = seqs[:, prompt_len:] # (B, A_len)

Expand Down
297 changes: 280 additions & 17 deletions egg/lib/ar_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,41 @@ class SampleResult(tp.NamedTuple):
log_probs: jax.Array # same shape as `tokens`


def _extract_logits(
outs: jax.Array | dict[str, jax.Array],
) -> jax.Array:
if isinstance(outs, dict):
return outs["policy_logits"]
return outs


class _StepKeys(tp.NamedTuple):
    """RNG keys consumed by one autoregressive sampling step.

    Produced by `_split_step_keys`; `model_key` is None only when the
    caller neither fixes the model key nor splits a fresh one (i.e. a
    fixed episode key was requested but no fixed rng was supplied).
    """

    next_key: jax.Array  # carry key to thread into the following step
    sample_key: jax.Array  # key for categorical token sampling
    explore_key: jax.Array  # key for epsilon-greedy exploration draws
    model_key: jax.Array | None  # key for the model's "noise" rngs


def _split_step_keys(
    key: jax.Array,
    *,
    fixed_model_key: bool,
    fixed_model_rng: jax.Array | None,
) -> _StepKeys:
    """Derive the RNG keys consumed by a single sampling step.

    Always splits off one key for token sampling and one for exploration.
    When ``fixed_model_key`` is set, the episode-wide ``fixed_model_rng``
    is reused as the model key; otherwise a fresh model key is also split
    from the carry key.
    """
    carry, k_sample, k_explore = jax.random.split(key, 3)
    if not fixed_model_key:
        carry, k_model = jax.random.split(carry)
    else:
        k_model = fixed_model_rng
    return _StepKeys(
        next_key=carry,
        sample_key=k_sample,
        explore_key=k_explore,
        model_key=k_model,
    )


@dataclasses.dataclass(frozen=True)
class ARSampler:
"""Autoregressively sample *a batch* of sequences in parallel."""
Expand All @@ -39,6 +74,7 @@ class ARSampler:
fixed_model_key: bool = False # If True, model key is fixed per episode.
epsilon: float = 0.0 # Probability of taking a random action.
vocab_size: int | None = None # Must be provided if epsilon > 0
use_decode_cache: bool = True

def __post_init__(self):
if self.epsilon > 0 and self.vocab_size is None:
Expand All @@ -50,8 +86,26 @@ def __call__(
params: tp.Any,
prompts: jax.Array, # (B, T_p) or (T_p,)
key: jax.Array,
model: tp.Any | None = None,
) -> SampleResult:
"""Autoregressively sample *a batch* of sequences in parallel."""
if (
self.use_decode_cache
and model is not None
and hasattr(model, "decode_step")
and hasattr(model, "init_decode_cache")
):
return self._sample_cached(apply_fn, params, prompts, key, model)
return self._sample_dense(apply_fn, params, prompts, key)

def _sample_dense(
self,
apply_fn: tp.Callable[..., jax.Array | dict[str, jax.Array]],
params: tp.Any,
prompts: jax.Array,
key: jax.Array,
) -> SampleResult:
"""Dense autoregressive sampling without decode cache."""
squeeze = prompts.ndim == 1
if squeeze:
prompts = prompts[None, :] # (1, T_p)
Expand All @@ -71,31 +125,31 @@ def __call__(
init = _Carry(seq=seq, logp=logp, key=key)

def step(t: int, carry: _Carry) -> _Carry:
seq, logp, k = carry
k, k_sample, k_explore = jax.random.split(k, 3)

if self.fixed_model_key:
current_k_model = k_model
else:
k, current_k_model = jax.random.split(k)

outs = apply_fn({"params": params}, seq, rngs={"noise": current_k_model})
if isinstance(outs, dict):
logits = outs["policy_logits"]
else:
logits = outs
seq, logp, key = carry
step_keys = _split_step_keys(
key,
fixed_model_key=self.fixed_model_key,
fixed_model_rng=k_model,
)

outs = apply_fn({"params": params}, seq, rngs={"noise": step_keys.model_key})
logits = _extract_logits(outs)
logits_t = logits[:, t - 1] # (B, V)

# Sample from model
model_tok = jax.random.categorical(k_sample, logits_t) # (B,)
model_tok = jax.random.categorical(step_keys.sample_key, logits_t) # (B,)

if self.epsilon > 0:
# Epsilon-greedy exploration
explore_action = jax.random.randint(
k_explore, shape=(batch_size,), minval=0, maxval=self.vocab_size
step_keys.explore_key,
shape=(batch_size,),
minval=0,
maxval=self.vocab_size,
)
explore_cond = (
jax.random.uniform(k_explore, shape=(batch_size,)) < self.epsilon
jax.random.uniform(step_keys.explore_key, shape=(batch_size,))
< self.epsilon
)
tok = jnp.where(explore_cond, explore_action, model_tok)
else:
Expand Down Expand Up @@ -127,7 +181,7 @@ def step(t: int, carry: _Carry) -> _Carry:

seq = seq.at[:, t].set(tok)
logp = logp.at[:, t].set(lp_t)
return _Carry(seq, logp, k)
return _Carry(seq, logp, step_keys.next_key)

final = jax.lax.fori_loop(
lower=prompt_len,
Expand All @@ -142,6 +196,205 @@ def step(t: int, carry: _Carry) -> _Carry:

return SampleResult(tokens=tokens, log_probs=log_probs)

def _sample_cached(
    self,
    apply_fn: tp.Callable[..., jax.Array | dict[str, jax.Array]],
    params: tp.Any,
    prompts: jax.Array,
    key: jax.Array,
    model: tp.Any,
) -> SampleResult:
    """Autoregressive sampling with prompt prefill and cached decoding.

    The prompt is fed token-by-token through ``model.decode_step`` to
    warm a Flax mutable ``cache`` collection (KV-cache style decoding),
    then generation proceeds one token per step, reusing the cache so
    each step only runs the model on the newest token.

    Args:
        apply_fn: Flax-style apply callable; called with a variables dict,
            a ``(B, 1)`` token slice, ``method=model.decode_step`` /
            ``model.init_decode_cache`` and ``mutable=["cache"]``.
        params: Model parameters, passed under the ``"params"`` collection.
        prompts: ``(B, T_p)`` or ``(T_p,)`` prompt tokens; a 1-D prompt is
            temporarily promoted to a batch of one and squeezed on return.
        key: Base PRNG key for the whole sampling call.
        model: Module exposing ``decode_step`` and ``init_decode_cache``
            (the caller checks for these before dispatching here).

    Returns:
        SampleResult with ``tokens`` and per-position ``log_probs``, both
        shaped like the dense path's output.
    """
    squeeze = prompts.ndim == 1
    if squeeze:
        prompts = prompts[None, :]
    batch_size, prompt_len = prompts.shape
    # No prompt means nothing to prefill; the dense path handles it.
    if prompt_len == 0:
        return self._sample_dense(apply_fn, params, prompts, key)

    # Output buffers: pad-filled token sequence with the prompt written in,
    # and zero log-probs (prompt positions keep logp == 0).
    seq = jnp.full(
        (batch_size, self.sequence_length), self.pad_token, prompts.dtype
    )
    seq = seq.at[:, :prompt_len].set(prompts)
    logp = jnp.zeros((batch_size, self.sequence_length), jnp.float32)

    # With a fixed model key, one key is drawn up front and reused for the
    # model's "noise" rngs on every step of the episode.
    if self.fixed_model_key:
        k_model, key = jax.random.split(key)
    else:
        k_model = None

    def decode_apply(
        cache: dict[str, jax.Array] | None,
        token: jax.Array,
        model_key: jax.Array | None,
    ) -> tuple[jax.Array, dict[str, jax.Array]]:
        # One cached decode step: feed a (B, 1) token slice, return its
        # logits and the mutated cache collection.
        variables: dict[str, tp.Any] = {"params": params}
        if cache is not None:
            variables["cache"] = cache
        if model_key is None:
            outs, mutated = apply_fn(
                variables,
                token,
                method=model.decode_step,
                mutable=["cache"],
            )
        else:
            outs, mutated = apply_fn(
                variables,
                token,
                rngs={"noise": model_key},
                method=model.decode_step,
                mutable=["cache"],
            )
        return _extract_logits(outs), mutated["cache"]

    def init_cache_apply(token: jax.Array) -> dict[str, jax.Array]:
        # Build an empty decode cache; the token only supplies shapes/dtypes
        # (its logits output is discarded).
        _, mutated = apply_fn(
            {"params": params},
            token,
            method=model.init_decode_cache,
            mutable=["cache"],
        )
        return mutated["cache"]

    def prompt_token_at(i: int) -> jax.Array:
        # (B, 1) slice of the prompt at position i; dynamic_slice keeps the
        # index traceable inside fori_loop.
        return jax.lax.dynamic_slice(prompts, (0, i), (batch_size, 1))

    first_step_keys = _split_step_keys(
        key,
        fixed_model_key=self.fixed_model_key,
        fixed_model_rng=k_model,
    )
    cache = init_cache_apply(prompt_token_at(0))
    # NOTE(review): the same model key is reused for every prefill position.
    # The dense path draws a fresh model key per step when fixed_model_key
    # is False, so the prefill RNG stream intentionally-or-not differs —
    # confirm this is acceptable for stochastic models.
    prefill_model_key = first_step_keys.model_key
    logits_prev, cache = decode_apply(
        cache,
        prompt_token_at(0),
        prefill_model_key,
    )

    def prefill_step(
        i: int,
        carry: tuple[dict[str, jax.Array], jax.Array],
    ) -> tuple[dict[str, jax.Array], jax.Array]:
        # Push prompt token i through the cache; only the logits of the
        # final prompt token survive the loop (they seed generation).
        current_cache, current_logits = carry
        current_logits, current_cache = decode_apply(
            current_cache,
            prompt_token_at(i),
            prefill_model_key,
        )
        return current_cache, current_logits

    cache, logits_prev = jax.lax.fori_loop(
        1,
        prompt_len,
        prefill_step,
        (cache, logits_prev),
    )

    init = _DecodeCarry(
        seq=seq,
        logp=logp,
        key=first_step_keys.next_key,
        sample_key=first_step_keys.sample_key,
        explore_key=first_step_keys.explore_key,
        cache=cache,
        logits_prev=logits_prev,
    )

    def step(t: int, carry: _DecodeCarry) -> _DecodeCarry:
        # Generate the token at position t from the previous step's logits,
        # then (unless t is the last position) run one cached decode to
        # produce the logits for position t + 1.
        seq, logp, key, sample_key, explore_key, cache, logits_prev = carry

        logits_t = logits_prev[:, 0]  # (B, V)
        model_tok = jax.random.categorical(sample_key, logits_t)

        if self.epsilon > 0:
            # Epsilon-greedy: with prob epsilon replace the model's token
            # with a uniform-random one. Mirrors the dense path, including
            # reusing explore_key for both the action and the coin flip.
            explore_action = jax.random.randint(
                explore_key,
                shape=(batch_size,),
                minval=0,
                maxval=self.vocab_size,
            )
            explore_cond = (
                jax.random.uniform(explore_key, shape=(batch_size,)) < self.epsilon
            )
            tok = jnp.where(explore_cond, explore_action, model_tok)
        else:
            tok = model_tok

        log_probs_model = jax.nn.log_softmax(logits_t)
        lp_t_model = log_probs_model[jnp.arange(batch_size), tok]

        if self.epsilon > 0:
            # Log-prob under the epsilon-mixed policy, not the raw model;
            # 1e-9 guards the log against an exactly-zero mixture prob.
            prob_tok_model = jnp.exp(lp_t_model)
            prob_tok_random = 1.0 / self.vocab_size
            mixed_prob_tok = (
                1.0 - self.epsilon
            ) * prob_tok_model + self.epsilon * prob_tok_random
            lp_t = jnp.log(mixed_prob_tok + 1e-9)
        else:
            lp_t = lp_t_model

        seq = seq.at[:, t].set(tok)
        logp = logp.at[:, t].set(lp_t)

        def decode_next(
            state: tuple[jax.Array, dict[str, jax.Array]],
        ) -> tuple[jax.Array, jax.Array, jax.Array, dict[str, jax.Array], jax.Array]:
            # True branch: advance the RNG state and decode the token just
            # sampled to obtain logits for the next position.
            inner_key, inner_cache = state
            next_step_keys = _split_step_keys(
                inner_key,
                fixed_model_key=self.fixed_model_key,
                fixed_model_rng=k_model,
            )
            next_logits, inner_cache = decode_apply(
                inner_cache,
                tok[:, None],
                next_step_keys.model_key,
            )
            return (
                next_step_keys.next_key,
                next_step_keys.sample_key,
                next_step_keys.explore_key,
                inner_cache,
                next_logits,
            )

        def keep_state(
            state: tuple[jax.Array, dict[str, jax.Array]],
        ) -> tuple[jax.Array, jax.Array, jax.Array, dict[str, jax.Array], jax.Array]:
            # False branch (last position): skip the decode, keep the
            # current keys/cache/logits so both branches match in structure.
            inner_key, inner_cache = state
            return inner_key, sample_key, explore_key, inner_cache, logits_prev

        key, sample_key, explore_key, cache, logits_prev = jax.lax.cond(
            t + 1 < self.sequence_length,
            decode_next,
            keep_state,
            (key, cache),
        )
        return _DecodeCarry(
            seq,
            logp,
            key,
            sample_key,
            explore_key,
            cache,
            logits_prev,
        )

    final = jax.lax.fori_loop(
        lower=prompt_len,
        upper=self.sequence_length,
        body_fun=step,
        init_val=init,
    )

    tokens, log_probs, _, _, _, _, _ = final
    if squeeze:
        tokens, log_probs = tokens[0], log_probs[0]

    return SampleResult(tokens=tokens, log_probs=log_probs)


def get_full_logprobs_b_l(
apply_fn: tp.Callable[..., jax.Array],
Expand Down Expand Up @@ -189,3 +442,13 @@ class _Carry(tp.NamedTuple):
seq: jax.Array # (B, seq_length)
logp: jax.Array # (B, seq_length)
key: jax.Array


class _DecodeCarry(tp.NamedTuple):
    """Loop-carried state for the cached-decoding sampling loop.

    Unlike `_Carry`, the per-step sample/explore keys and the decode cache
    are threaded through the carry because the next step's keys are split
    inside a `jax.lax.cond` branch at the end of the previous step.
    """

    seq: jax.Array  # (B, seq_length) token buffer, filled in place
    logp: jax.Array  # (B, seq_length) per-position log-probs
    key: jax.Array  # carry RNG key for future splits
    sample_key: jax.Array  # key for this step's categorical sample
    explore_key: jax.Array  # key for this step's epsilon-greedy draws
    cache: dict[str, jax.Array]  # mutable Flax decode cache collection
    logits_prev: jax.Array  # (B, 1, vocab_size) logits from the last decode
Loading