BraxAutoResetWrapper discards terminal observation, breaking correct value bootstrapping on truncation #305

@YannBerthelot

Description

Hello everyone, I have searched through the existing and closed issues but I have found nothing regarding incorrect value bootstrapping on truncation, so here it is:

Summary

BraxAutoResetWrapper (in mujoco_playground/_src/wrapper.py) overwrites state.obs with the reset observation on done, without preserving the pre-reset observation anywhere in state.info. While this is fine for automatic resets in vectorized environments, it makes it impossible for downstream RL algorithms to implement correct value bootstrapping on truncation.

Why this matters for RL

Every actor-critic method needs to distinguish termination from truncation when computing TD targets / GAE:

  • Termination (e.g. agent fell): the MDP ends. Target is r_t (no bootstrap).
  • Truncation (e.g. time limit hit, task still alive): the episode is artificially cut. Target is r_t + γ·V(s_{t+1}), where s_{t+1} is the observation the agent would have seen next had the episode continued.
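The two cases above can be sketched in a few lines (a minimal sketch, not Playground code; the function name td_targets and the argument layout are mine):

```python
import jax.numpy as jnp

def td_targets(rewards, values_next, terminated, truncated, gamma=0.99):
    """One-step TD targets that treat termination and truncation differently.

    rewards, values_next, terminated, truncated: arrays of shape (T,).
    values_next[t] must be V(s_{t+1}) -- on truncation, the value of the
    *terminal* observation, not of the post-reset observation.
    """
    # Bootstrap whenever the MDP did not actually terminate. Truncated steps
    # have terminated == 0, so they keep the gamma * V(s_{t+1}) term.
    bootstrap = 1.0 - terminated
    return rewards + gamma * bootstrap * values_next
```

The whole point of this issue is that after BraxAutoResetWrapper.step, the values_next entry required at a truncated index cannot be computed, because the terminal observation is gone.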

wrap_for_brax_training already exposes the termination/truncation distinction correctly via state.info['truncation'] (set by brax's EpisodeWrapper). But to actually compute the truncation bootstrap V(s_{t+1}), the learner needs access to s_{t+1} — the terminal observation of the truncated trajectory. After BraxAutoResetWrapper.step, that observation is gone: state.obs has been replaced by the reset state, and info only holds reset-related bookkeeping (_first_data, _first_obs, _rng, _done_count).

In practice this means every PPO/SAC implementation built on Playground is silently bootstrapping on V(s_reset) instead of V(s_{t+1}) at every truncation. The bias is small on short-horizon tasks with low γ but grows with episode length and γ, and is particularly pathological on locomotion tasks where s_reset is an upright pose very different from the late-episode state.
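To make the bias concrete, here is a toy calculation (all numbers hypothetical): the per-truncation target error is γ·(V(s_reset) − V(s_{t+1})), which late in a locomotion episode can dwarf the per-step reward:

```python
# Toy illustration of the truncation-bootstrap bias; all numbers hypothetical.
gamma = 0.99
reward = 1.0
v_terminal = 50.0  # critic value of the true next state, late in the episode
v_reset = 5.0      # critic value of the upright reset pose

correct_target = reward + gamma * v_terminal
wrong_target = reward + gamma * v_reset  # what the learner silently computes today
bias = wrong_target - correct_target     # = gamma * (v_reset - v_terminal)
print(bias)
```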

The same flaw exists in brax.envs.wrappers.training.AutoResetWrapper. Gymnasium solved this years ago by putting the terminal observation in info["final_observation"] on truncation, which is the convention most SB3/CleanRL/RLlib code expects.

Proposed fix

In BraxAutoResetWrapper.step, before the where_done mask replaces obs and data, stash the pre-reset observation in info:

# inside BraxAutoResetWrapper.step, after self.env.step(...) returns
# and before the where_done mask replaces obs and data:
info = dict(state.info)
info[f"{self._info_key}_final_obs"] = state.obs  # pre-reset observation
state = state.replace(info=info)
# ...rest unchanged

This does not break any existing behavior, and enables correct truncation handling for any downstream learner that reads it.

Reproduction

import jax
import jax.numpy as jnp
from mujoco_playground import registry, wrapper

env = wrapper.wrap_for_brax_training(registry.load("HopperHop"), episode_length=10)
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

state = reset_fn(jax.random.split(jax.random.PRNGKey(0), 2))
for _ in range(10):  # run up to the episode limit to force truncation
    state = step_fn(state, jnp.zeros((2, env.action_size)))

print("done:", state.done, "truncation:", state.info["truncation"])
print("obs (overwritten by reset):", state.obs[0, :3])
# no info key holds the pre-reset observation
print("info keys:", [k for k in state.info if "obs" in k.lower()])

Expected: some info['...final_obs...'] key carrying the real terminal observation.
Actual: no such key exists.

Maybe I am missing something; please tell me if that's the case :)

Otherwise, I am happy to take care of the pull request if needed.

Yann
