GAIL and AIRL don't work #680

@mertalbaba

Bug description

Your adversarial algorithm implementations, including GAIL and AIRL, do not work well in MuJoCo environments. I tested them on Hopper, HalfCheetah, and Humanoid, and neither AIRL nor GAIL reached a meaningful score after either 100k or 1 million training steps.

Steps to reproduce

import numpy as np
import gym
from stable_baselines3 import SAC
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.sac import MlpPolicy  # SAC needs its own policy class, not PPO's

from imitation.algorithms.adversarial.gail import GAIL
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env

rng = np.random.default_rng(0)

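# Train the SAC expert that will generate the demonstrations.
env = gym.make("HalfCheetah-v3")
expert = SAC(policy=MlpPolicy, env=env)  # n_steps is a PPO argument; SAC does not accept it
expert.learn(100000)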

rollouts = rollout.rollout(
    expert,
    make_vec_env(
        "HalfCheetah-v3",
        n_envs=5,
        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
        rng=rng,
    ),
    rollout.make_sample_until(min_timesteps=100000, min_episodes=60),
    rng=rng,
)

venv = make_vec_env("HalfCheetah-v3", n_envs=8, rng=rng)
learner = SAC(env=venv, policy=MlpPolicy)
reward_net = BasicRewardNet(
    venv.observation_space,
    venv.action_space,
    normalize_input_layer=RunningNorm,
)
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=2048,
    n_disc_updates_per_round=4,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
)

gail_trainer.train(100000)
rewards, _ = evaluate_policy(learner, venv, 100, return_episode_rewards=True)
print("Rewards:", rewards)
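
To make "meaningful score" concrete, one option is to report the learner's return on a scale where 0 is a random policy and 1 is the expert. The sketch below is only an illustration: `normalized_return` is a hypothetical helper, not part of imitation or Stable-Baselines3, and the numbers are placeholders rather than measurements from these runs. The `rewards` list printed above could be passed as `learner_returns`, with the expert and random-policy returns collected the same way.

```python
from statistics import mean
from typing import Sequence


def normalized_return(
    learner_returns: Sequence[float],
    random_returns: Sequence[float],
    expert_returns: Sequence[float],
) -> float:
    """Scale the learner's mean return so 0.0 is random and 1.0 is the expert."""
    random_mean = mean(random_returns)
    expert_mean = mean(expert_returns)
    gap = expert_mean - random_mean
    if gap <= 0:
        # A weak expert would make any imitation result hard to interpret.
        raise ValueError("expert does not outperform the random baseline")
    return (mean(learner_returns) - random_mean) / gap


# Illustrative numbers only, not measurements from this issue.
score = normalized_return(
    learner_returns=[100.0, 300.0],
    random_returns=[-50.0, 50.0],
    expert_returns=[900.0, 1100.0],
)
print(f"normalized return: {score:.2f}")
```

The `ValueError` branch also acts as a sanity check on the expert itself: if the SAC expert trained for 100k steps is no better than random, the demonstrations may be the problem rather than GAIL or AIRL.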

Environment

  • Operating system and version: Ubuntu 20.04.5 LTS
  • Python version: 3.8.10
  • Output of pip freeze --all:
absl-py==1.4.0
ale-py==0.7.5
astunparse==1.6.3
baselines==0.1.5
cachetools==5.3.0
certifi==2022.12.7
cffi==1.15.1
chai-sacred==0.8.3
charset-normalizer==3.0.1
click==8.1.3
cloudpickle==2.2.1
colorama==0.4.6
contourpy==1.0.7
cycler==0.11.0
Cython==0.29.33
dill==0.3.6
docopt==0.6.2
fasteners==0.18
filelock==3.9.0
flatbuffers==23.1.21
fonttools==4.38.0
gast==0.4.0
gitdb==4.0.10
GitPython==3.1.30
glfw==2.5.5
google-auth==2.16.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.51.1
gym==0.21.0
h5py==3.8.0
huggingface-hub==0.12.0
huggingface-sb3==2.2.4
idna==3.4
imageio==2.25.0
imitation==0.3.2
importlib-metadata==4.13.0
importlib-resources==5.10.2
joblib==1.2.0
jsonpickle==3.0.1
keras==2.11.0
kiwisolver==1.4.4
libclang==15.0.6.1
lockfile==0.12.2
Markdown==3.4.1
MarkupSafe==2.1.2
matplotlib==3.6.3
mpi4py==3.1.4
mujoco-py==2.1.2.14
munch==2.5.0
numpy==1.24.1
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==23.0
pandas==1.5.3
Pillow==9.4.0
pip==20.0.2
pkg-resources==0.0.0
progressbar2==4.2.0
protobuf==3.19.6
py-cpuinfo==9.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pyglet==1.5.27
pyparsing==3.0.9
python-dateutil==2.8.2
python-utils==3.4.5
pytz==2022.7.1
PyYAML==6.0
pyzmq==25.0.0
requests==2.28.2
requests-oauthlib==1.3.1
rsa==4.9
scikit-learn==1.2.1
scipy==1.10.0
seals==0.1.5
setuptools==44.0.0
six==1.16.0
smmap==5.0.0
stable-baselines3==1.7.0
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-estimator==2.11.0
tensorflow-io-gcs-filesystem==0.30.0
termcolor==2.2.0
threadpoolctl==3.1.0
torch==1.13.1+cu116
torch-tb-profiler==0.4.1
torchaudio==0.13.1+cu116
torchvision==0.14.1+cu116
tqdm==4.64.1
typing-extensions==4.4.0
urllib3==1.26.14
wasabi==1.1.1
Werkzeug==2.2.2
wheel==0.34.2
wrapt==1.14.1
zipp==3.11.0
zmq==0.0.0

Labels: bug
