BraxAutoResetWrapper discards terminal observation, breaking correct value bootstrapping on truncation
Hello everyone, I have searched the open and closed issues and found nothing about incorrect value bootstrapping on truncation, so here it is:
Summary
BraxAutoResetWrapper (in mujoco_playground/_src/wrapper.py) overwrites state.obs with the reset observation on done, without preserving the pre-reset observation anywhere in state.info. While this is fine for auto-resetting vectorized environments, it makes it impossible for downstream RL algorithms to implement correct value bootstrapping on truncation.
Why this matters for RL
Every actor-critic method needs to distinguish termination from truncation when computing TD targets / GAE (see the sketch after this list):
- Termination (e.g. the agent fell): the MDP ends. The target is r_t (no bootstrap).
- Truncation (e.g. the time limit is hit while the task is still alive): the episode is artificially cut. The target is r_t + γ·V(s_{t+1}), where s_{t+1} is the observation the agent would have seen next had the episode continued.
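For concreteness, here is a minimal sketch of that distinction in code (names are illustrative, not Playground API):

```python
import jax.numpy as jnp

def td_target(reward, done, truncation, v_next, gamma=0.99):
    # done == 1 on any episode boundary (termination OR truncation);
    # truncation == 1 only when the time limit cut the episode.
    # v_next must be V(s_{t+1}) of the *pre-reset* observation.
    # Bootstrap unless the episode genuinely terminated.
    bootstrap = 1.0 - done * (1.0 - truncation)
    return reward + gamma * bootstrap * v_next
```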
wrap_for_brax_training already exposes the termination/truncation distinction correctly via state.info['truncation'] (set by brax's EpisodeWrapper). But to actually compute the truncation bootstrap V(s_{t+1}), the learner needs access to s_{t+1} — the terminal observation of the truncated trajectory. After BraxAutoResetWrapper.step, that observation is gone: state.obs has been replaced by the reset state, and info only holds reset-related bookkeeping (_first_data, _first_obs, _rng, _done_count).
In practice this means every PPO/SAC implementation built on Playground is silently bootstrapping on V(s_reset) instead of V(s_{t+1}) at every truncation. The bias is small on short-horizon tasks with low γ but grows with episode length and γ, and is particularly pathological on locomotion tasks where s_reset is an upright pose very different from the late-episode state.
The same flaw exists in brax.envs.wrappers.training.AutoResetWrapper. Gymnasium solved this years ago by putting the terminal observation in info["final_observation"] on truncation, which is the convention most SB3/CleanRL/RLlib code expects.
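For comparison, the Gymnasium pattern looks roughly like this under the pre-1.0 vector autoreset convention (env id illustrative):

```python
import gymnasium as gym

envs = gym.vector.make("CartPole-v1", num_envs=2)
obs, _ = envs.reset(seed=0)
obs, reward, terminated, truncated, info = envs.step(envs.action_space.sample())
if "final_observation" in info:
    # Entries are populated only for the sub-envs that just auto-reset.
    final_obs = info["final_observation"]
```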
Proposed fix
In BraxAutoResetWrapper.step, before the where_done mask replaces obs and data, stash the pre-reset observation in info:
```python
# inside BraxAutoResetWrapper.step, after self.env.step(...) returns
next_info = state.info
next_info[f"{self._info_key}_final_obs"] = state.obs
# ...rest unchanged
```
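With that key in place, a downstream learner could bootstrap correctly. A sketch, assuming the key ends up named "final_obs" and given some value_fn (both placeholders, not existing Playground API):

```python
def bootstrap_target(state, value_fn, gamma=0.99):
    """Sketch: compute the TD target off the pre-reset observation.

    Assumes the wrapper stashed the pre-reset observation under a
    '..._final_obs' info key as proposed above (exact name illustrative).
    """
    final_obs = state.info["final_obs"]    # s_{t+1} before the auto-reset
    truncated = state.info["truncation"]   # from brax's EpisodeWrapper
    # Bootstrap on continuation or truncation, but not on true termination.
    mask = 1.0 - state.done * (1.0 - truncated)
    return state.reward + gamma * mask * value_fn(final_obs)
```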
This does not break any existing behavior, and it enables correct truncation handling for any downstream learner that reads the key.
Reproduction
```python
import jax
import jax.numpy as jnp
from mujoco_playground import registry, wrapper

env = wrapper.wrap_for_brax_training(registry.load("HopperHop"), episode_length=10)
state = jax.jit(env.reset)(jax.random.split(jax.random.PRNGKey(0), 2))
step = jax.jit(env.step)
for t in range(10):
    state = step(state, jnp.zeros((2, env.action_size)))
print("done:", state.done, "truncation:", state.info["truncation"])
print("obs (overwritten by reset):", state.obs[0, :3])
# no info key holds the pre-reset observation
print("info keys:", [k for k in state.info if "obs" in k.lower()])
```
Expected: some info['...final_obs...'] key carrying the real terminal observation.
Actual: no such key exists.
Maybe I am missing something; please tell me if that's the case :)
Otherwise, I can take care of the pull request if needed.
Yann