diff --git a/tests/test_rollouts.py b/tests/test_rollouts.py index 4445485..d134925 100644 --- a/tests/test_rollouts.py +++ b/tests/test_rollouts.py @@ -2,7 +2,7 @@ import jax.numpy as jnp import jax.random -from dopamax.environments import CartPole +from dopamax.environments import make_env from dopamax.rollouts import create_minibatches, rollout_episode, SampleBatch @@ -23,18 +23,3 @@ def test_create_minibatches(): for i in range(5): assert jnp.all(minibatches["a"][i] * 10 == jnp.expand_dims(minibatches["b"][i], 1)) - - -def test_rollout_episode_render(): - key = jax.random.PRNGKey(0) - env = CartPole() - - rollout_data = jax.jit(rollout_episode, static_argnums=(0, 1, 4))( - env, - lambda params, key, obs: (env.action_space.sample(key), {}), - {}, - key, - render=True, - ) - - chex.assert_shape(rollout_data[SampleBatch.RENDER], (env.max_episode_length, *env.render_shape))