-
Notifications
You must be signed in to change notification settings - Fork 288
Open
Labels
bugSomething isn't workingSomething isn't working
Description
Bug description
Description of what the bug is.
AssertionError: Tuple observations are not supported.
Steps to reproduce
Code or a description of how to reproduce the bug.
import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy
from imitation.algorithms import bc
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
rng = np.random.default_rng(0)
env = gym.make("Pendulum-v1")
env = RolloutInfoWrapper(env)
def train_expert():
print("Training a expert.")
expert = PPO(
policy=MlpPolicy,
env=env,
seed=0,
batch_size=64,
ent_coef=0.0,
learning_rate=0.0003,
n_epochs=10,
n_steps=64,
)
expert.learn(10) # Note: change this to 100_000 to train a decent expert.
return expert
def sample_expert_transitions():
expert = train_expert()
print("Sampling expert transitions.")
rollouts = rollout.rollout(
expert,
env,
rollout.make_sample_until(min_timesteps=None, min_episodes=50),
rng=rng,
)
return rollout.flatten_trajectories(rollouts)
Environment
- Operating system and version:
- Python version:
- Output of
pip freeze --all
:
degen2
Metadata
Metadata
Assignees
Labels
bugSomething isn't workingSomething isn't working