Status: Closed
Labels: bug (Something isn't working)
Bug description
Your adversarial imitation algorithms (GAIL and AIRL) do not work well in MuJoCo environments. I tested both on Hopper, HalfCheetah, and Humanoid, and neither reached a meaningful score after either 100k or 1M training steps.
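One common way to make "meaningful score" concrete (an assumption here; neither the issue nor `imitation` prescribes it) is to place the learner's mean return on a scale from a random policy (0.0) to the expert (1.0). The following is a minimal stdlib sketch. The helper name `normalized_score` is hypothetical.

```python
from statistics import mean
from typing import Sequence


def normalized_score(
    learner_returns: Sequence[float],
    expert_returns: Sequence[float],
    random_returns: Sequence[float],
) -> float:
    """Hypothetical helper: map the learner's mean return onto a random-to-expert scale.

    0.0 means random-policy level and 1.0 means expert level. Values outside
    [0, 1] are possible (worse than random, or better than the expert).
    """
    low = mean(random_returns)
    high = mean(expert_returns)
    if high == low:
        raise ValueError("expert and random policies have the same mean return")
    return (mean(learner_returns) - low) / (high - low)


# Illustrative numbers only, not results from this issue.
print(normalized_score([1500.0, 1700.0], [3000.0, 3200.0], [-300.0, -100.0]))
```

Reporting this number per environment would make it easier to compare "failed" runs across Hopper, HalfCheetah, and Humanoid, whose raw returns live on very different scales.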
Steps to reproduce
```python
import numpy as np
import gym
from stable_baselines3 import SAC
from stable_baselines3.common.evaluation import evaluate_policy
# SAC needs its own policy class; the PPO MlpPolicy (ActorCriticPolicy) is not compatible with SAC.
from stable_baselines3.sac import MlpPolicy
from imitation.algorithms.adversarial.gail import GAIL
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env

rng = np.random.default_rng(0)

# Train a SAC expert on the true environment reward.
# `n_steps` is a PPO (on-policy) argument and is not accepted by SAC, so it is omitted.
env = gym.make("HalfCheetah-v3")
expert = SAC(policy=MlpPolicy, env=env)
expert.learn(100000)

# Collect expert demonstrations; RolloutInfoWrapper lets rollout.rollout recover complete episodes.
rollouts = rollout.rollout(
    expert,
    make_vec_env(
        "HalfCheetah-v3",
        n_envs=5,
        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
        rng=rng,
    ),
    rollout.make_sample_until(min_timesteps=100000, min_episodes=60),
    rng=rng,
)

# Train a SAC learner with GAIL, using a discriminator-based learned reward.
venv = make_vec_env("HalfCheetah-v3", n_envs=8, rng=rng)
learner = SAC(env=venv, policy=MlpPolicy)
reward_net = BasicRewardNet(
    venv.observation_space,
    venv.action_space,
    normalize_input_layer=RunningNorm,
)
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=2048,
    n_disc_updates_per_round=4,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
)
gail_trainer.train(100000)

# Evaluate the learner on the original (unwrapped) environment reward over 100 episodes.
rewards, _ = evaluate_policy(learner, venv, 100, return_episode_rewards=True)
print("Rewards:", rewards)
```
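With `return_episode_rewards=True`, `evaluate_policy` returns per-episode lists (returns and episode lengths), so the script prints 100 raw numbers. The following sketch reduces them to mean ± sample standard deviation, which is easier to compare across environments, seeds, and training budgets. The helper name `summarize_returns` is hypothetical.

```python
from statistics import mean, stdev
from typing import Sequence, Tuple


def summarize_returns(episode_returns: Sequence[float]) -> Tuple[float, float]:
    """Hypothetical helper: mean and sample standard deviation of episode returns."""
    if not episode_returns:
        raise ValueError("need at least one episode")
    # stdev needs at least two data points; report zero spread for a single episode.
    spread = stdev(episode_returns) if len(episode_returns) > 1 else 0.0
    return mean(episode_returns), spread


# Made-up returns for illustration; in the script above, pass `rewards` instead.
avg, spread = summarize_returns([100.0, 120.0, 80.0])
print(f"Return: {avg:.1f} +/- {spread:.1f}")
```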
Environment
- Operating system and version: Ubuntu 20.04.5 LTS
- Python version: 3.8.10
- Output of `pip freeze --all`:
absl-py==1.4.0
ale-py==0.7.5
astunparse==1.6.3
baselines==0.1.5
cachetools==5.3.0
certifi==2022.12.7
cffi==1.15.1
chai-sacred==0.8.3
charset-normalizer==3.0.1
click==8.1.3
cloudpickle==2.2.1
colorama==0.4.6
contourpy==1.0.7
cycler==0.11.0
Cython==0.29.33
dill==0.3.6
docopt==0.6.2
fasteners==0.18
filelock==3.9.0
flatbuffers==23.1.21
fonttools==4.38.0
gast==0.4.0
gitdb==4.0.10
GitPython==3.1.30
glfw==2.5.5
google-auth==2.16.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.51.1
gym==0.21.0
h5py==3.8.0
huggingface-hub==0.12.0
huggingface-sb3==2.2.4
idna==3.4
imageio==2.25.0
imitation==0.3.2
importlib-metadata==4.13.0
importlib-resources==5.10.2
joblib==1.2.0
jsonpickle==3.0.1
keras==2.11.0
kiwisolver==1.4.4
libclang==15.0.6.1
lockfile==0.12.2
Markdown==3.4.1
MarkupSafe==2.1.2
matplotlib==3.6.3
mpi4py==3.1.4
mujoco-py==2.1.2.14
munch==2.5.0
numpy==1.24.1
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==23.0
pandas==1.5.3
Pillow==9.4.0
pip==20.0.2
pkg-resources==0.0.0
progressbar2==4.2.0
protobuf==3.19.6
py-cpuinfo==9.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pyglet==1.5.27
pyparsing==3.0.9
python-dateutil==2.8.2
python-utils==3.4.5
pytz==2022.7.1
PyYAML==6.0
pyzmq==25.0.0
requests==2.28.2
requests-oauthlib==1.3.1
rsa==4.9
scikit-learn==1.2.1
scipy==1.10.0
seals==0.1.5
setuptools==44.0.0
six==1.16.0
smmap==5.0.0
stable-baselines3==1.7.0
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-estimator==2.11.0
tensorflow-io-gcs-filesystem==0.30.0
termcolor==2.2.0
threadpoolctl==3.1.0
torch==1.13.1+cu116
torch-tb-profiler==0.4.1
torchaudio==0.13.1+cu116
torchvision==0.14.1+cu116
tqdm==4.64.1
typing-extensions==4.4.0
urllib3==1.26.14
wasabi==1.1.1
Werkzeug==2.2.2
wheel==0.34.2
wrapt==1.14.1
zipp==3.11.0
zmq==0.0.0
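Of the packages above, the ones most likely to matter for reproducing this result are `gym`, `mujoco-py`, `stable-baselines3`, `imitation`, and `torch` (a judgment call, not something the issue states). The following is a short sketch for printing just those, assuming Python 3.8+ where `importlib.metadata` is in the standard library. The helper name `package_versions` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def package_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Hypothetical helper: installed version per distribution name, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for pkg, ver in package_versions(
        ["gym", "mujoco-py", "stable-baselines3", "imitation", "torch"]
    ).items():
        print(f"{pkg}=={ver}" if ver else f"{pkg}: not installed")
```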
Rowing0914