Skip to content

Commit

Permalink
Remove old test
Browse files Browse the repository at this point in the history
  • Loading branch information
rystrauss committed Jan 12, 2025
1 parent 8e71846 commit c77f6be
Showing 1 changed file with 1 addition and 16 deletions.
17 changes: 1 addition & 16 deletions tests/test_rollouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp
import jax.random

from dopamax.environments import CartPole
from dopamax.environments import make_env
from dopamax.rollouts import create_minibatches, rollout_episode, SampleBatch


Expand All @@ -23,18 +23,3 @@ def test_create_minibatches():

for i in range(5):
assert jnp.all(minibatches["a"][i] * 10 == jnp.expand_dims(minibatches["b"][i], 1))


def test_rollout_episode_render():
key = jax.random.PRNGKey(0)
env = CartPole()

rollout_data = jax.jit(rollout_episode, static_argnums=(0, 1, 4))(
env,
lambda params, key, obs: (env.action_space.sample(key), {}),
{},
key,
render=True,
)

chex.assert_shape(rollout_data[SampleBatch.RENDER], (env.max_episode_length, *env.render_shape))

0 comments on commit c77f6be

Please sign in to comment.