diff --git a/tests/test_rollouts.py b/tests/test_rollouts.py
index 4445485..d134925 100644
--- a/tests/test_rollouts.py
+++ b/tests/test_rollouts.py
@@ -2,7 +2,7 @@
 import jax.numpy as jnp
 import jax.random
 
-from dopamax.environments import CartPole
+from dopamax.environments import make_env
 from dopamax.rollouts import create_minibatches, rollout_episode, SampleBatch
 
 
@@ -23,18 +23,3 @@ def test_create_minibatches():
 
     for i in range(5):
         assert jnp.all(minibatches["a"][i] * 10 == jnp.expand_dims(minibatches["b"][i], 1))
-
-
-def test_rollout_episode_render():
-    key = jax.random.PRNGKey(0)
-    env = CartPole()
-
-    rollout_data = jax.jit(rollout_episode, static_argnums=(0, 1, 4))(
-        env,
-        lambda params, key, obs: (env.action_space.sample(key), {}),
-        {},
-        key,
-        render=True,
-    )
-
-    chex.assert_shape(rollout_data[SampleBatch.RENDER], (env.max_episode_length, *env.render_shape))