diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 24390204e..0e78317f9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,7 +17,6 @@ jobs: strategy: matrix: python-version: [3.7, 3.8, 3.9] - steps: - uses: actions/checkout@v2 with: @@ -30,13 +29,13 @@ jobs: run: | python -m pip install --upgrade pip # cpu version of pytorch - faster to download - pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install pybullet==3.1.9 pip install -r requirements.txt # Use headless version pip install opencv-python-headless # install parking-env to test HER (pinned so it works with gym 0.21) - pip install git+https://github.com/eleurent/highway-env@1a04c6a98be64632cf9683625022023e70ff1ab1 + pip install highway-env==1.5.0 - name: Type check run: | make type diff --git a/.github/workflows/trained_agents.yml b/.github/workflows/trained_agents.yml index 39a970f85..a047e9861 100644 --- a/.github/workflows/trained_agents.yml +++ b/.github/workflows/trained_agents.yml @@ -29,13 +29,13 @@ jobs: run: | python -m pip install --upgrade pip # cpu version of pytorch - faster to download - pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install pybullet==3.1.9 pip install -r requirements.txt # Use headless version pip install opencv-python-headless # install parking-env to test HER (pinned so it works with gym 0.21) - pip install git+https://github.com/eleurent/highway-env@1a04c6a98be64632cf9683625022023e70ff1ab1 + pip install highway-env==1.5.0 # Add support for pickle5 protocol pip install pickle5 - name: Check trained agents diff --git a/.gitignore b/.gitignore index e44316a54..2032c6f01 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,7 @@ git_rewrite_commit_history.sh .vscode/ wandb runs +hub +*.mp4 +*.json +*.csv diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 6b6dfe1b3..512bbbd13 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -6,17 +6,25 @@ variables: type-check: script: + - pip install git+https://github.com/huggingface/huggingface_sb3 - make type pytest: script: # MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error + # tmp fix to have RecurrentPPO, will be fixed with new image + - pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib + - pip install git+https://github.com/huggingface/huggingface_sb3 - MKL_THREADING_LAYER=GNU make pytest + coverage: '/^TOTAL.+?(\d+\%)$/' check-trained-agents: script: # MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error - pip install pickle5 # Add support for pickle5 protocol + - pip install git+https://github.com/huggingface/huggingface_sb3 + # tmp fix to have RecurrentPPO, will be fixed with new image + - pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib - MKL_THREADING_LAYER=GNU make check-trained-agents lint: diff --git a/CHANGELOG.md b/CHANGELOG.md index 25c7dd5ab..4828362d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,18 +1,28 @@ -## Release 1.5.1a0 (WIP) +## Release 1.5.1a8 (WIP) ### Breaking Changes - Change default value for number of hyperparameter optimization trials from 10 to 500. (@ernestum) - Derive number of intermediate pruning evaluations from number of time steps (1 evaluation per 100k time steps.) 
(@ernestum) -- Updated default --eval-freq from 10k to 25k steps +- Updated default --eval-freq from 10k to 25k steps +- Update default horizon to 2 for the `HistoryWrapper` ### New Features - Support setting PyTorch's device with the `--device` flag (@gregwar) +- Add `--max-total-trials` parameter to help with distributed optimization. (@ernestum) +- Added `vec_env_wrapper` support in the config (works the same as `env_wrapper`) +- Added Huggingface hub integration +- Added `RecurrentPPO` support (aka `ppo_lstm`) +- Added auto-download for "official" sb3 models from the hub ### Bug fixes +- Fix `Reacher-v3` name in PPO hyperparameter file +- Pinned ale-py==0.7.4 until new SB3 version is released +- Fix enjoy / record videos with LSTM policy ### Documentation ### Other +- When pruner is set to `"none"`, use `NopPruner` instead of diverted `MedianPruner` (@qgallouedec) ## Release 1.5.0 (2022-03-25) diff --git a/README.md b/README.md index 9b760fa3a..dd03e49e0 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,7 @@ python scripts/plot_train.py -a her -e Fetch -y success -f rl-trained-agents/ -w Plot evaluation reward curve for TQC, SAC and TD3 on the HalfCheetah and Ant PyBullet environments: ``` -python scripts/all_plots.py -a sac td3 tqc --env HalfCheetah Ant -f rl-trained-agents/ +python3 scripts/all_plots.py -a sac td3 tqc --env HalfCheetahBullet AntBullet -f rl-trained-agents/ ``` ## Plot with the rliable library @@ -139,9 +139,46 @@ To load a checkpoint (here the checkpoint name is `rl_model_10000_steps.zip`): python enjoy.py --algo algo_name --env env_id -f logs/ --exp-id 1 --load-checkpoint 10000 ``` -To load the latest checkpoint: +## Offline Training with d3rlpy + +``` +python train.py --algo human --env donkey-generated-track-v0 --env-kwargs frame_skip:1 throttle_max:2.0 throttle_min:0.0 steering_min:-0.5 steering_max:0.5 level:6 max_cte:100000 test_mode:False --num-threads 2 --eval-freq -1 \ + -b logs/human/donkey-generated-track-v0_1/replay_buffer.pkl \ + --pretrain-params batch_size:256 n_eval_episodes:1 n_epochs:20 n_iterations:20 \ + --offline-algo bc +``` + +Parameters read from `--pretrain-params` (with their defaults): + +```python + +n_iterations = args.pretrain_params.get("n_iterations", 10) +n_epochs = args.pretrain_params.get("n_epochs", 1) +q_func_type = args.pretrain_params.get("q_func_type") +batch_size = args.pretrain_params.get("batch_size", 512) +# n_action_samples = args.pretrain_params.get("n_action_samples", 1) +n_eval_episodes = args.pretrain_params.get("n_eval_episodes", 5) +add_to_buffer = args.pretrain_params.get("add_to_buffer", False) +deterministic = args.pretrain_params.get("deterministic", True) +``` + +## Human driving + +``` +python train.py --algo human --env donkey-generated-track-v0 --env-kwargs frame_skip:1 throttle_max:1.0 throttle_min:-1.0 steering_min:-1 steering_max:1 level:6 max_cte:100000 --num-threads 2 --eval-freq -1 +``` + +## Huggingface Hub Integration + +Upload model to hub (same syntax as for `enjoy.py`): +``` +python -m utils.push_to_hub --algo ppo --env CartPole-v1 -f logs/ -orga sb3 -m "Initial commit" ``` -python enjoy.py --algo algo_name --env env_id -f logs/ --exp-id 1 --load-last-checkpoint +You can choose a custom `repo-name` (default: `{algo}-{env_id}`) by passing a `--repo-name` argument. + +Download model from hub: +``` +python -m utils.load_from_hub --algo ppo --env CartPole-v1 -f logs/ -orga sb3 ``` ## Hyperparameter yaml syntax @@ -243,6 +280,17 @@ env_wrapper: Note that you can easily specify parameters too.
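+Each `env_wrapper` entry is resolved at runtime: the class is imported from its dotted path and applied to the environment together with the parameters you listed. The sketch below shows that mechanism in simplified form; it is not the zoo's exact helper, and the function name is illustrative:
+
+```python
+import importlib
+from typing import Any, Dict
+
+import gym
+
+
+def apply_wrapper(env: gym.Env, dotted_path: str, kwargs: Dict[str, Any]) -> gym.Env:
+    # "utils.wrappers.HistoryWrapper" -> module "utils.wrappers", class "HistoryWrapper"
+    module_name, class_name = dotted_path.rsplit(".", 1)
+    wrapper_class = getattr(importlib.import_module(module_name), class_name)
+    return wrapper_class(env, **kwargs)
+
+
+# A YAML entry such as `utils.wrappers.HistoryWrapper: {horizon: 2}` then amounts to:
+# env = apply_wrapper(env, "utils.wrappers.HistoryWrapper", {"horizon": 2})
+```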
+## VecEnvWrapper + +You can specify which `VecEnvWrapper` to use in the config, the same way as for env wrappers (see above), using the `vec_env_wrapper` key: + +For instance: +```yaml +vec_env_wrapper: stable_baselines3.common.vec_env.VecMonitor +``` + +Note: `VecNormalize` is supported separately using `normalize` keyword, and `VecFrameStack` has a dedicated keyword `frame_stack`. + ## Callbacks Following the same syntax as env wrappers, you can also add custom callbacks to use during training. @@ -311,6 +359,8 @@ python -m utils.record_training --algo ppo --env CartPole-v1 -n 1000 -f logs --d Final performance of the trained agents can be found in [`benchmark.md`](./benchmark.md). To compute them, simply run `python -m utils.benchmark`. +List and videos of trained agents can be found on our Huggingface page: https://huggingface.co/sb3 + *NOTE: this is not a quantitative benchmark as it corresponds to only one run (cf [issue #38](https://github.com/araffin/rl-baselines-zoo/issues/38)). This benchmark is meant to check algorithm (maximal) performance, find potential bugs and also allow users to have access to pretrained agents.* ### Atari Games diff --git a/benchmark.md b/benchmark.md index d50f5c1d7..61c5239ac 100644 --- a/benchmark.md +++ b/benchmark.md @@ -8,6 +8,9 @@ during this evaluation. It uses the deterministic policy except for Atari games. +You can view each model card (it includes video and hyperparameters) +on our Huggingface page: https://huggingface.co/sb3 + *NOTE: this is not a quantitative benchmark as it corresponds to only one run (cf [issue #38](https://github.com/araffin/rl-baselines-zoo/issues/38)). This benchmark is meant to check algorithm (maximal) performance, find potential bugs @@ -15,181 +18,185 @@ and also allow users to have access to pretrained agents.* "M" stands for Million (1e6) -|algo | env_id |mean_reward|std_reward|n_timesteps|eval_timesteps|eval_episodes| -|-----|---------------------------|----------:|---------:|-----------|-------------:|------------:| -|a2c |Acrobot-v1 | -83.353| 17.213|500k | 149979| 1778| -|a2c |AntBulletEnv-v0 | 2497.147| 37.359|2M | 150000| 150| -|a2c |AsteroidsNoFrameskip-v4 | 1286.550| 423.750|10M | 614138| 258| -|a2c |BeamRiderNoFrameskip-v4 | 2890.298| 1379.137|10M | 591104| 47| -|a2c |BipedalWalker-v3 | 299.754| 23.459|5M | 149287| 208| -|a2c |BipedalWalkerHardcore-v3 | 96.171| 122.943|200M | 149704| 113| -|a2c |BreakoutNoFrameskip-v4 | 279.793| 122.177|10M | 604115| 82| -|a2c |CartPole-v1 | 500.000| 0.000|500k | 150000| 300| -|a2c |EnduroNoFrameskip-v4 | 0.000| 0.000|10M | 599040| 45| -|a2c |HalfCheetah-v3 | 3041.174| 157.265|1M | 150000| 150| -|a2c |HalfCheetahBulletEnv-v0 | 2107.384| 36.008|2M | 150000| 150| -|a2c |Hopper-v3 | 733.454| 376.574|1M | 149987| 580| -|a2c |HopperBulletEnv-v0 | 815.355| 313.798|2M | 149541| 254| -|a2c |LunarLander-v2 | 155.751| 80.419|200k | 149443| 297| -|a2c |LunarLanderContinuous-v2 | 84.225| 145.906|5M | 149305| 256| -|a2c |MountainCar-v0 | -111.263| 24.087|1M | 149982| 1348| -|a2c |MountainCarContinuous-v0 | 91.166| 0.255|100k | 149923| 1659| -|a2c |Pendulum-v1 | -162.965| 103.210|1M | 150000| 750| -|a2c |PongNoFrameskip-v4 | 17.292| 3.214|10M | 594910| 65| -|a2c |QbertNoFrameskip-v4 | 3882.345| 1223.327|10M | 610670| 194| -|a2c |ReacherBulletEnv-v0 | 14.968| 10.978|2M | 150000| 1000| -|a2c |RoadRunnerNoFrameskip-v4 | 31671.512| 6364.085|10M | 606710| 172| -|a2c |SeaquestNoFrameskip-v4 | 1721.493| 105.339|10M | 599691| 67| -|a2c |SpaceInvadersNoFrameskip-v4| 627.160| 201.974|10M | 
604848| 162| -|a2c |Swimmer-v3 | 200.627| 2.544|1M | 150000| 150| -|a2c |Walker2DBulletEnv-v0 | 858.209| 333.116|2M | 149156| 173| -|ars |Acrobot-v1 | -82.884| 23.825|500k | 149985| 1788| -|ars |Ant-v3 | 2333.773| 20.597|75M | 150000| 150| -|ars |CartPole-v1 | 500.000| 0.000|50k | 150000| 300| -|ars |HalfCheetah-v3 | 4815.192| 1340.752|12M | 150000| 150| -|ars |Hopper-v3 | 3343.919| 5.730|7M | 150000| 150| -|ars |LunarLanderContinuous-v2 | 167.959| 147.071|2M | 149883| 562| -|ars |MountainCar-v0 | -122.000| 33.456|500k | 149938| 1229| -|ars |MountainCarContinuous-v0 | 96.672| 0.784|500k | 149990| 621| -|ars |Pendulum-v1 | -212.540| 160.444|2M | 150000| 750| -|ars |Swimmer-v3 | 355.267| 12.796|2M | 150000| 150| -|ars |Walker2d-v3 | 2993.582| 166.289|75M | 149821| 152| -|ddpg |AntBulletEnv-v0 | 2399.147| 75.410|1M | 150000| 150| -|ddpg |BipedalWalker-v3 | 197.486| 141.580|1M | 149237| 227| -|ddpg |HalfCheetahBulletEnv-v0 | 2078.325| 208.379|1M | 150000| 150| -|ddpg |HopperBulletEnv-v0 | 1157.065| 448.695|1M | 149565| 346| -|ddpg |LunarLanderContinuous-v2 | 230.217| 92.372|300k | 149862| 556| -|ddpg |MountainCarContinuous-v0 | 93.512| 0.048|300k | 149965| 2260| -|ddpg |Pendulum-v1 | -152.099| 94.282|20k | 150000| 750| -|ddpg |ReacherBulletEnv-v0 | 15.582| 9.606|300k | 150000| 1000| -|ddpg |Walker2DBulletEnv-v0 | 1387.591| 736.955|1M | 149051| 208| -|dqn |Acrobot-v1 | -76.639| 11.752|100k | 149998| 1932| -|dqn |AsteroidsNoFrameskip-v4 | 782.687| 259.247|10M | 607962| 134| -|dqn |BeamRiderNoFrameskip-v4 | 4295.946| 1790.458|10M | 600832| 37| -|dqn |BreakoutNoFrameskip-v4 | 358.327| 61.981|10M | 601461| 55| -|dqn |CartPole-v1 | 500.000| 0.000|50k | 150000| 300| -|dqn |EnduroNoFrameskip-v4 | 830.929| 194.544|10M | 599040| 14| -|dqn |LunarLander-v2 | 154.382| 79.241|100k | 149373| 200| -|dqn |MountainCar-v0 | -100.849| 9.925|120k | 149962| 1487| -|dqn |PongNoFrameskip-v4 | 20.602| 0.613|10M | 598998| 88| -|dqn |QbertNoFrameskip-v4 | 9496.774| 5399.633|10M | 605844| 124| -|dqn |RoadRunnerNoFrameskip-v4 | 40396.350| 7069.131|10M | 603257| 137| -|dqn |SeaquestNoFrameskip-v4 | 2000.290| 606.644|10M | 599505| 69| -|dqn |SpaceInvadersNoFrameskip-v4| 622.742| 201.564|10M | 604311| 155| -|ppo |Acrobot-v1 | -73.506| 18.201|1M | 149979| 2013| -|ppo |Ant-v3 | 1327.158| 451.577|1M | 149572| 175| -|ppo |AntBulletEnv-v0 | 2865.922| 56.468|2M | 150000| 150| -|ppo |AsteroidsNoFrameskip-v4 | 2156.174| 744.640|10M | 602092| 149| -|ppo |BeamRiderNoFrameskip-v4 | 3397.000| 1662.368|10M | 598926| 46| -|ppo |BipedalWalker-v3 | 213.299| 129.490|5M | 149826| 233| -|ppo |BipedalWalkerHardcore-v3 | 122.374| 117.605|100M | 148036| 105| -|ppo |BreakoutNoFrameskip-v4 | 398.033| 33.328|10M | 600418| 60| -|ppo |CartPole-v1 | 500.000| 0.000|100k | 150000| 300| -|ppo |EnduroNoFrameskip-v4 | 996.364| 176.090|10M | 572416| 11| -|ppo |HalfCheetah-v3 | 5819.099| 663.530|1M | 150000| 150| -|ppo |HalfCheetahBulletEnv-v0 | 2924.721| 64.465|2M | 150000| 150| -|ppo |Hopper-v3 | 2410.435| 10.026|1M | 150000| 150| -|ppo |HopperBulletEnv-v0 | 2575.054| 223.301|2M | 149094| 152| -|ppo |LunarLander-v2 | 242.119| 31.823|1M | 149636| 369| -|ppo |LunarLanderContinuous-v2 | 270.863| 32.072|1M | 149956| 526| -|ppo |MountainCar-v0 | -110.423| 19.473|1M | 149954| 1358| -|ppo |MountainCarContinuous-v0 | 88.343| 2.572|20k | 149983| 633| -|ppo |Pendulum-v1 | -172.225| 104.159|100k | 150000| 750| -|ppo |PongNoFrameskip-v4 | 20.989| 0.105|10M | 599902| 90| -|ppo |QbertNoFrameskip-v4 | 15627.108| 3313.538|10M | 600248| 83| -|ppo |ReacherBulletEnv-v0 | 
17.091| 11.048|1M | 150000| 1000| -|ppo |RoadRunnerNoFrameskip-v4 | 40680.645| 6675.058|10M | 605786| 155| -|ppo |SeaquestNoFrameskip-v4 | 1783.636| 34.096|10M | 598243| 66| -|ppo |SpaceInvadersNoFrameskip-v4| 960.331| 425.355|10M | 603771| 136| -|ppo |Swimmer-v3 | 281.561| 9.671|1M | 150000| 150| -|ppo |Walker2DBulletEnv-v0 | 2109.992| 13.899|2M | 150000| 150| -|ppo |Walker2d-v3 | 3478.798| 821.708|1M | 149343| 171| -|qrdqn|Acrobot-v1 | -69.135| 9.967|100k | 149949| 2138| -|qrdqn|AsteroidsNoFrameskip-v4 | 2185.303| 1097.172|10M | 599784| 66| -|qrdqn|BeamRiderNoFrameskip-v4 | 17122.941| 10769.997|10M | 596483| 17| -|qrdqn|BreakoutNoFrameskip-v4 | 393.600| 79.828|10M | 579711| 40| -|qrdqn|CartPole-v1 | 500.000| 0.000|50k | 150000| 300| -|qrdqn|EnduroNoFrameskip-v4 | 3231.200| 1311.801|10M | 585728| 5| -|qrdqn|LunarLander-v2 | 70.236| 225.491|100k | 149957| 522| -|qrdqn|MountainCar-v0 | -106.042| 15.536|120k | 149943| 1414| -|qrdqn|PongNoFrameskip-v4 | 20.492| 0.687|10M | 597443| 63| -|qrdqn|QbertNoFrameskip-v4 | 14799.728| 2917.629|10M | 600773| 92| -|qrdqn|RoadRunnerNoFrameskip-v4 | 42325.424| 8361.161|10M | 591016| 59| -|qrdqn|SeaquestNoFrameskip-v4 | 2557.576| 76.951|10M | 596275| 66| -|qrdqn|SpaceInvadersNoFrameskip-v4| 1899.928| 823.488|10M | 597218| 69| -|sac |Ant-v3 | 4615.791| 1354.111|1M | 149074| 165| -|sac |AntBulletEnv-v0 | 3073.114| 175.148|1M | 150000| 150| -|sac |BipedalWalker-v3 | 297.668| 33.060|500k | 149530| 136| -|sac |BipedalWalkerHardcore-v3 | 4.423| 103.910|10M | 149794| 88| -|sac |HalfCheetah-v3 | 9535.451| 100.470|1M | 150000| 150| -|sac |HalfCheetahBulletEnv-v0 | 2792.170| 12.088|1M | 150000| 150| -|sac |Hopper-v3 | 2325.547| 1129.676|1M | 149841| 236| -|sac |HopperBulletEnv-v0 | 2603.494| 164.322|1M | 149724| 151| -|sac |Humanoid-v3 | 6232.287| 279.885|2M | 149460| 150| -|sac |LunarLanderContinuous-v2 | 260.390| 65.467|500k | 149634| 672| -|sac |MountainCarContinuous-v0 | 94.679| 1.134|50k | 149966| 1443| -|sac |Pendulum-v1 | -156.995| 88.714|20k | 150000| 750| -|sac |ReacherBulletEnv-v0 | 18.062| 9.729|300k | 150000| 1000| -|sac |Swimmer-v3 | 345.568| 3.084|1M | 150000| 150| -|sac |Walker2DBulletEnv-v0 | 2292.266| 13.970|1M | 149983| 150| -|sac |Walker2d-v3 | 3863.203| 254.347|1M | 149309| 150| -|td3 |Ant-v3 | 5813.274| 589.773|1M | 149393| 151| -|td3 |AntBulletEnv-v0 | 3300.026| 54.640|1M | 150000| 150| -|td3 |BipedalWalker-v3 | 305.990| 56.886|1M | 149999| 224| -|td3 |BipedalWalkerHardcore-v3 | -98.116| 16.087|10M | 150000| 75| -|td3 |HalfCheetah-v3 | 9655.666| 969.916|1M | 150000| 150| -|td3 |HalfCheetahBulletEnv-v0 | 2821.641| 19.722|1M | 150000| 150| -|td3 |Hopper-v3 | 3606.390| 4.027|1M | 150000| 150| -|td3 |HopperBulletEnv-v0 | 2681.609| 27.806|1M | 149486| 150| -|td3 |Humanoid-v3 | 5566.687| 14.544|2M | 150000| 150| -|td3 |LunarLanderContinuous-v2 | 207.451| 67.562|300k | 149488| 337| -|td3 |MountainCarContinuous-v0 | 93.483| 0.075|300k | 149976| 2275| -|td3 |Pendulum-v1 | -151.855| 90.227|20k | 150000| 750| -|td3 |ReacherBulletEnv-v0 | 17.114| 9.750|300k | 150000| 1000| -|td3 |Swimmer-v3 | 359.127| 1.244|1M | 150000| 150| -|td3 |Walker2DBulletEnv-v0 | 2213.672| 230.558|1M | 149800| 152| -|td3 |Walker2d-v3 | 4717.823| 46.303|1M | 150000| 150| -|tqc |Ant-v3 | 3339.362| 1969.906|1M | 149583| 202| -|tqc |AntBulletEnv-v0 | 3456.717| 248.733|1M | 150000| 150| -|tqc |BipedalWalker-v3 | 329.808| 45.083|500k | 149682| 254| -|tqc |BipedalWalkerHardcore-v3 | 235.226| 110.569|2M | 149032| 131| -|tqc |FetchPickAndPlace-v1 | -9.331| 6.850|1M | 150000| 3000| -|tqc 
|FetchPush-v1 | -8.799| 5.438|1M | 150000| 3000| -|tqc |FetchReach-v1 | -1.659| 0.873|20k | 150000| 3000| -|tqc |FetchSlide-v1 | -29.210| 11.387|3M | 150000| 3000| -|tqc |HalfCheetah-v3 | 12089.939| 127.440|1M | 150000| 150| -|tqc |HalfCheetahBulletEnv-v0 | 3675.299| 17.681|1M | 150000| 150| -|tqc |Hopper-v3 | 3754.199| 8.276|1M | 150000| 150| -|tqc |HopperBulletEnv-v0 | 2662.373| 206.210|1M | 149881| 151| -|tqc |Humanoid-v3 | 7239.320| 1647.498|2M | 149508| 165| -|tqc |LunarLanderContinuous-v2 | 277.956| 25.466|500k | 149928| 706| -|tqc |MountainCarContinuous-v0 | 63.641| 45.259|50k | 149796| 186| -|tqc |PandaPickAndPlace-v1 | -8.024| 6.674|1M | 150000| 3000| -|tqc |PandaPush-v1 | -6.405| 6.400|1M | 150000| 3000| -|tqc |PandaReach-v1 | -1.768| 0.858|20k | 150000| 3000| -|tqc |PandaSlide-v1 | -27.497| 9.868|3M | 150000| 3000| -|tqc |PandaStack-v1 | -96.915| 17.240|1M | 150000| 1500| -|tqc |Pendulum-v1 | -151.340| 87.893|20k | 150000| 750| -|tqc |ReacherBulletEnv-v0 | 18.255| 9.543|300k | 150000| 1000| -|tqc |Swimmer-v3 | 339.423| 1.486|1M | 150000| 150| -|tqc |Walker2DBulletEnv-v0 | 2508.934| 614.624|1M | 149572| 159| -|tqc |Walker2d-v3 | 4380.720| 500.489|1M | 149606| 152| -|tqc |parking-v0 | -6.762| 2.690|100k | 149983| 7528| -|trpo |Acrobot-v1 | -83.114| 18.648|100k | 149976| 1783| -|trpo |Ant-v3 | 4982.301| 663.761|1M | 149909| 153| -|trpo |AntBulletEnv-v0 | 2560.621| 52.064|2M | 150000| 150| -|trpo |BipedalWalker-v3 | 182.339| 145.570|1M | 148440| 148| -|trpo |CartPole-v1 | 500.000| 0.000|100k | 150000| 300| -|trpo |HalfCheetah-v3 | 1785.476| 68.672|1M | 150000| 150| -|trpo |HalfCheetahBulletEnv-v0 | 2758.752| 327.032|2M | 150000| 150| -|trpo |Hopper-v3 | 3618.386| 356.768|1M | 149575| 152| -|trpo |HopperBulletEnv-v0 | 2565.416| 410.298|1M | 149640| 154| -|trpo |LunarLander-v2 | 133.166| 112.173|200k | 149088| 230| -|trpo |LunarLanderContinuous-v2 | 262.387| 21.428|200k | 149925| 501| -|trpo |MountainCar-v0 | -107.278| 13.231|100k | 149974| 1398| -|trpo |MountainCarContinuous-v0 | 92.489| 0.355|50k | 149971| 1732| -|trpo |Pendulum-v1 | -174.631| 127.577|100k | 150000| 750| -|trpo |ReacherBulletEnv-v0 | 14.741| 11.559|300k | 150000| 1000| -|trpo |Swimmer-v3 | 365.663| 2.087|1M | 150000| 150| -|trpo |Walker2DBulletEnv-v0 | 1483.467| 823.468|2M | 149860| 197| -|trpo |Walker2d-v3 | 4933.148| 1452.538|1M | 149054| 163| +| algo | env_id |mean_reward|std_reward|n_timesteps|eval_timesteps|eval_episodes| +|--------|-----------------------------|----------:|---------:|-----------|-------------:|------------:| +|a2c |Acrobot-v1 | -83.353| 17.213|500k | 149979| 1778| +|a2c |AntBulletEnv-v0 | 2497.147| 37.359|2M | 150000| 150| +|a2c |AsteroidsNoFrameskip-v4 | 1286.550| 423.750|10M | 614138| 258| +|a2c |BeamRiderNoFrameskip-v4 | 2890.298| 1379.137|10M | 591104| 47| +|a2c |BipedalWalker-v3 | 299.754| 23.459|5M | 149287| 208| +|a2c |BipedalWalkerHardcore-v3 | 96.171| 122.943|200M | 149704| 113| +|a2c |BreakoutNoFrameskip-v4 | 279.793| 122.177|10M | 604115| 82| +|a2c |CartPole-v1 | 500.000| 0.000|500k | 150000| 300| +|a2c |EnduroNoFrameskip-v4 | 0.000| 0.000|10M | 599040| 45| +|a2c |HalfCheetah-v3 | 3041.174| 157.265|1M | 150000| 150| +|a2c |HalfCheetahBulletEnv-v0 | 2107.384| 36.008|2M | 150000| 150| +|a2c |Hopper-v3 | 733.454| 376.574|1M | 149987| 580| +|a2c |HopperBulletEnv-v0 | 815.355| 313.798|2M | 149541| 254| +|a2c |LunarLander-v2 | 155.751| 80.419|200k | 149443| 297| +|a2c |LunarLanderContinuous-v2 | 84.225| 145.906|5M | 149305| 256| +|a2c |MountainCar-v0 | -111.263| 24.087|1M | 149982| 1348| 
+|a2c |MountainCarContinuous-v0 | 91.166| 0.255|100k | 149923| 1659| +|a2c |Pendulum-v1 | -162.965| 103.210|1M | 150000| 750| +|a2c |PongNoFrameskip-v4 | 17.292| 3.214|10M | 594910| 65| +|a2c |QbertNoFrameskip-v4 | 3882.345| 1223.327|10M | 610670| 194| +|a2c |ReacherBulletEnv-v0 | 14.968| 10.978|2M | 150000| 1000| +|a2c |RoadRunnerNoFrameskip-v4 | 31671.512| 6364.085|10M | 606710| 172| +|a2c |SeaquestNoFrameskip-v4 | 1721.493| 105.339|10M | 599691| 67| +|a2c |SpaceInvadersNoFrameskip-v4 | 627.160| 201.974|10M | 604848| 162| +|a2c |Swimmer-v3 | 200.627| 2.544|1M | 150000| 150| +|a2c |Walker2DBulletEnv-v0 | 858.209| 333.116|2M | 149156| 173| +|ars |Acrobot-v1 | -82.884| 23.825|500k | 149985| 1788| +|ars |Ant-v3 | 2333.773| 20.597|75M | 150000| 150| +|ars |CartPole-v1 | 500.000| 0.000|50k | 150000| 300| +|ars |HalfCheetah-v3 | 4815.192| 1340.752|12M | 150000| 150| +|ars |Hopper-v3 | 3343.919| 5.730|7M | 150000| 150| +|ars |LunarLanderContinuous-v2 | 167.959| 147.071|2M | 149883| 562| +|ars |MountainCar-v0 | -122.000| 33.456|500k | 149938| 1229| +|ars |MountainCarContinuous-v0 | 96.672| 0.784|500k | 149990| 621| +|ars |Pendulum-v1 | -212.540| 160.444|2M | 150000| 750| +|ars |Swimmer-v3 | 355.267| 12.796|2M | 150000| 150| +|ars |Walker2d-v3 | 2993.582| 166.289|75M | 149821| 152| +|ddpg |AntBulletEnv-v0 | 2399.147| 75.410|1M | 150000| 150| +|ddpg |BipedalWalker-v3 | 197.486| 141.580|1M | 149237| 227| +|ddpg |HalfCheetahBulletEnv-v0 | 2078.325| 208.379|1M | 150000| 150| +|ddpg |HopperBulletEnv-v0 | 1157.065| 448.695|1M | 149565| 346| +|ddpg |LunarLanderContinuous-v2 | 230.217| 92.372|300k | 149862| 556| +|ddpg |MountainCarContinuous-v0 | 93.512| 0.048|300k | 149965| 2260| +|ddpg |Pendulum-v1 | -152.099| 94.282|20k | 150000| 750| +|ddpg |ReacherBulletEnv-v0 | 15.582| 9.606|300k | 150000| 1000| +|ddpg |Walker2DBulletEnv-v0 | 1387.591| 736.955|1M | 149051| 208| +|dqn |Acrobot-v1 | -76.639| 11.752|100k | 149998| 1932| +|dqn |AsteroidsNoFrameskip-v4 | 782.687| 259.247|10M | 607962| 134| +|dqn |BeamRiderNoFrameskip-v4 | 4295.946| 1790.458|10M | 600832| 37| +|dqn |BreakoutNoFrameskip-v4 | 358.327| 61.981|10M | 601461| 55| +|dqn |CartPole-v1 | 500.000| 0.000|50k | 150000| 300| +|dqn |EnduroNoFrameskip-v4 | 830.929| 194.544|10M | 599040| 14| +|dqn |LunarLander-v2 | 154.382| 79.241|100k | 149373| 200| +|dqn |MountainCar-v0 | -100.849| 9.925|120k | 149962| 1487| +|dqn |PongNoFrameskip-v4 | 20.602| 0.613|10M | 598998| 88| +|dqn |QbertNoFrameskip-v4 | 9496.774| 5399.633|10M | 605844| 124| +|dqn |RoadRunnerNoFrameskip-v4 | 40396.350| 7069.131|10M | 603257| 137| +|dqn |SeaquestNoFrameskip-v4 | 2000.290| 606.644|10M | 599505| 69| +|dqn |SpaceInvadersNoFrameskip-v4 | 622.742| 201.564|10M | 604311| 155| +|ppo |Acrobot-v1 | -73.506| 18.201|1M | 149979| 2013| +|ppo |Ant-v3 | 1327.158| 451.577|1M | 149572| 175| +|ppo |AntBulletEnv-v0 | 2865.922| 56.468|2M | 150000| 150| +|ppo |AsteroidsNoFrameskip-v4 | 2156.174| 744.640|10M | 602092| 149| +|ppo |BeamRiderNoFrameskip-v4 | 3397.000| 1662.368|10M | 598926| 46| +|ppo |BipedalWalker-v3 | 287.939| 2.448|5M | 149589| 123| +|ppo |BipedalWalkerHardcore-v3 | 122.374| 117.605|100M | 148036| 105| +|ppo |BreakoutNoFrameskip-v4 | 398.033| 33.328|10M | 600418| 60| +|ppo |CartPole-v1 | 500.000| 0.000|100k | 150000| 300| +|ppo |EnduroNoFrameskip-v4 | 996.364| 176.090|10M | 572416| 11| +|ppo |HalfCheetah-v3 | 5819.099| 663.530|1M | 150000| 150| +|ppo |HalfCheetahBulletEnv-v0 | 2924.721| 64.465|2M | 150000| 150| +|ppo |Hopper-v3 | 2410.435| 10.026|1M | 150000| 150| +|ppo 
|HopperBulletEnv-v0 | 2575.054| 223.301|2M | 149094| 152| +|ppo |LunarLander-v2 | 242.119| 31.823|1M | 149636| 369| +|ppo |LunarLanderContinuous-v2 | 270.863| 32.072|1M | 149956| 526| +|ppo |MountainCar-v0 | -110.423| 19.473|1M | 149954| 1358| +|ppo |MountainCarContinuous-v0 | 88.343| 2.572|20k | 149983| 633| +|ppo |Pendulum-v1 | -172.225| 104.159|100k | 150000| 750| +|ppo |PongNoFrameskip-v4 | 20.989| 0.105|10M | 599902| 90| +|ppo |QbertNoFrameskip-v4 | 15627.108| 3313.538|10M | 600248| 83| +|ppo |ReacherBulletEnv-v0 | 17.091| 11.048|1M | 150000| 1000| +|ppo |RoadRunnerNoFrameskip-v4 | 40680.645| 6675.058|10M | 605786| 155| +|ppo |SeaquestNoFrameskip-v4 | 1783.636| 34.096|10M | 598243| 66| +|ppo |SpaceInvadersNoFrameskip-v4 | 960.331| 425.355|10M | 603771| 136| +|ppo |Swimmer-v3 | 281.561| 9.671|1M | 150000| 150| +|ppo |Walker2DBulletEnv-v0 | 2109.992| 13.899|2M | 150000| 150| +|ppo |Walker2d-v3 | 3478.798| 821.708|1M | 149343| 171| +|ppo_lstm|CarRacing-v0 | 862.549| 97.342|4M | 149588| 156| +|ppo_lstm|CartPoleNoVel-v1 | 500.000| 0.000|100k | 150000| 300| +|ppo_lstm|MountainCarContinuousNoVel-v0| 91.469| 1.776|300k | 149882| 1340| +|ppo_lstm|PendulumNoVel-v1 | -217.933| 140.094|100k | 150000| 750| +|qrdqn |Acrobot-v1 | -69.135| 9.967|100k | 149949| 2138| +|qrdqn |AsteroidsNoFrameskip-v4 | 2185.303| 1097.172|10M | 599784| 66| +|qrdqn |BeamRiderNoFrameskip-v4 | 17122.941| 10769.997|10M | 596483| 17| +|qrdqn |BreakoutNoFrameskip-v4 | 393.600| 79.828|10M | 579711| 40| +|qrdqn |CartPole-v1 | 500.000| 0.000|50k | 150000| 300| +|qrdqn |EnduroNoFrameskip-v4 | 3231.200| 1311.801|10M | 585728| 5| +|qrdqn |LunarLander-v2 | 70.236| 225.491|100k | 149957| 522| +|qrdqn |MountainCar-v0 | -106.042| 15.536|120k | 149943| 1414| +|qrdqn |PongNoFrameskip-v4 | 20.492| 0.687|10M | 597443| 63| +|qrdqn |QbertNoFrameskip-v4 | 14799.728| 2917.629|10M | 600773| 92| +|qrdqn |RoadRunnerNoFrameskip-v4 | 42325.424| 8361.161|10M | 591016| 59| +|qrdqn |SeaquestNoFrameskip-v4 | 2557.576| 76.951|10M | 596275| 66| +|qrdqn |SpaceInvadersNoFrameskip-v4 | 1899.928| 823.488|10M | 597218| 69| +|sac |Ant-v3 | 4615.791| 1354.111|1M | 149074| 165| +|sac |AntBulletEnv-v0 | 3073.114| 175.148|1M | 150000| 150| +|sac |BipedalWalker-v3 | 297.668| 33.060|500k | 149530| 136| +|sac |BipedalWalkerHardcore-v3 | 4.423| 103.910|10M | 149794| 88| +|sac |HalfCheetah-v3 | 9535.451| 100.470|1M | 150000| 150| +|sac |HalfCheetahBulletEnv-v0 | 2792.170| 12.088|1M | 150000| 150| +|sac |Hopper-v3 | 2325.547| 1129.676|1M | 149841| 236| +|sac |HopperBulletEnv-v0 | 2603.494| 164.322|1M | 149724| 151| +|sac |Humanoid-v3 | 6232.287| 279.885|2M | 149460| 150| +|sac |LunarLanderContinuous-v2 | 260.390| 65.467|500k | 149634| 672| +|sac |MountainCarContinuous-v0 | 94.679| 1.134|50k | 149966| 1443| +|sac |Pendulum-v1 | -156.995| 88.714|20k | 150000| 750| +|sac |ReacherBulletEnv-v0 | 18.062| 9.729|300k | 150000| 1000| +|sac |Swimmer-v3 | 345.568| 3.084|1M | 150000| 150| +|sac |Walker2DBulletEnv-v0 | 2292.266| 13.970|1M | 149983| 150| +|sac |Walker2d-v3 | 3863.203| 254.347|1M | 149309| 150| +|td3 |Ant-v3 | 5813.274| 589.773|1M | 149393| 151| +|td3 |AntBulletEnv-v0 | 3300.026| 54.640|1M | 150000| 150| +|td3 |BipedalWalker-v3 | 305.990| 56.886|1M | 149999| 224| +|td3 |BipedalWalkerHardcore-v3 | -98.116| 16.087|10M | 150000| 75| +|td3 |HalfCheetah-v3 | 9655.666| 969.916|1M | 150000| 150| +|td3 |HalfCheetahBulletEnv-v0 | 2821.641| 19.722|1M | 150000| 150| +|td3 |Hopper-v3 | 3606.390| 4.027|1M | 150000| 150| +|td3 |HopperBulletEnv-v0 | 2681.609| 27.806|1M | 149486| 
150| +|td3 |Humanoid-v3 | 5566.687| 14.544|2M | 150000| 150| +|td3 |LunarLanderContinuous-v2 | 207.451| 67.562|300k | 149488| 337| +|td3 |MountainCarContinuous-v0 | 93.483| 0.075|300k | 149976| 2275| +|td3 |Pendulum-v1 | -151.855| 90.227|20k | 150000| 750| +|td3 |ReacherBulletEnv-v0 | 17.114| 9.750|300k | 150000| 1000| +|td3 |Swimmer-v3 | 359.127| 1.244|1M | 150000| 150| +|td3 |Walker2DBulletEnv-v0 | 2213.672| 230.558|1M | 149800| 152| +|td3 |Walker2d-v3 | 4717.823| 46.303|1M | 150000| 150| +|tqc |Ant-v3 | 3339.362| 1969.906|1M | 149583| 202| +|tqc |AntBulletEnv-v0 | 3456.717| 248.733|1M | 150000| 150| +|tqc |BipedalWalker-v3 | 329.808| 45.083|500k | 149682| 254| +|tqc |BipedalWalkerHardcore-v3 | 235.226| 110.569|2M | 149032| 131| +|tqc |FetchPickAndPlace-v1 | -9.331| 6.850|1M | 150000| 3000| +|tqc |FetchPush-v1 | -8.799| 5.438|1M | 150000| 3000| +|tqc |FetchReach-v1 | -1.659| 0.873|20k | 150000| 3000| +|tqc |FetchSlide-v1 | -29.210| 11.387|3M | 150000| 3000| +|tqc |HalfCheetah-v3 | 12089.939| 127.440|1M | 150000| 150| +|tqc |HalfCheetahBulletEnv-v0 | 3675.299| 17.681|1M | 150000| 150| +|tqc |Hopper-v3 | 3754.199| 8.276|1M | 150000| 150| +|tqc |HopperBulletEnv-v0 | 2662.373| 206.210|1M | 149881| 151| +|tqc |Humanoid-v3 | 7239.320| 1647.498|2M | 149508| 165| +|tqc |LunarLanderContinuous-v2 | 277.956| 25.466|500k | 149928| 706| +|tqc |MountainCarContinuous-v0 | 63.641| 45.259|50k | 149796| 186| +|tqc |PandaPickAndPlace-v1 | -8.024| 6.674|1M | 150000| 3000| +|tqc |PandaPush-v1 | -6.405| 6.400|1M | 150000| 3000| +|tqc |PandaReach-v1 | -1.768| 0.858|20k | 150000| 3000| +|tqc |PandaSlide-v1 | -27.497| 9.868|3M | 150000| 3000| +|tqc |PandaStack-v1 | -96.915| 17.240|1M | 150000| 1500| +|tqc |Pendulum-v1 | -151.340| 87.893|20k | 150000| 750| +|tqc |ReacherBulletEnv-v0 | 18.255| 9.543|300k | 150000| 1000| +|tqc |Swimmer-v3 | 339.423| 1.486|1M | 150000| 150| +|tqc |Walker2DBulletEnv-v0 | 2508.934| 614.624|1M | 149572| 159| +|tqc |Walker2d-v3 | 4380.720| 500.489|1M | 149606| 152| +|tqc |parking-v0 | -6.762| 2.690|100k | 149983| 7528| +|trpo |Acrobot-v1 | -83.114| 18.648|100k | 149976| 1783| +|trpo |Ant-v3 | 4982.301| 663.761|1M | 149909| 153| +|trpo |AntBulletEnv-v0 | 2560.621| 52.064|2M | 150000| 150| +|trpo |BipedalWalker-v3 | 182.339| 145.570|1M | 148440| 148| +|trpo |CartPole-v1 | 500.000| 0.000|100k | 150000| 300| +|trpo |HalfCheetah-v3 | 1785.476| 68.672|1M | 150000| 150| +|trpo |HalfCheetahBulletEnv-v0 | 2758.752| 327.032|2M | 150000| 150| +|trpo |Hopper-v3 | 3618.386| 356.768|1M | 149575| 152| +|trpo |HopperBulletEnv-v0 | 2565.416| 410.298|1M | 149640| 154| +|trpo |LunarLander-v2 | 133.166| 112.173|200k | 149088| 230| +|trpo |LunarLanderContinuous-v2 | 262.387| 21.428|200k | 149925| 501| +|trpo |MountainCar-v0 | -107.278| 13.231|100k | 149974| 1398| +|trpo |MountainCarContinuous-v0 | 92.489| 0.355|50k | 149971| 1732| +|trpo |Pendulum-v1 | -174.631| 127.577|100k | 150000| 750| +|trpo |ReacherBulletEnv-v0 | 14.741| 11.559|300k | 150000| 1000| +|trpo |Swimmer-v3 | 365.663| 2.087|1M | 150000| 150| +|trpo |Walker2DBulletEnv-v0 | 1483.467| 823.468|2M | 149860| 197| +|trpo |Walker2d-v3 | 4933.148| 1452.538|1M | 149054| 163| diff --git a/docker/Dockerfile b/docker/Dockerfile index a99e9a7b1..2baa348c1 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -21,7 +21,7 @@ RUN \ mkdir -p ${CODE_DIR}/rl_zoo && \ pip uninstall -y stable-baselines3 && \ pip install -r /tmp/requirements.txt && \ - pip install git+https://github.com/eleurent/highway-env@1a04c6a98be64632cf9683625022023e70ff1ab1 && \ + pip 
install highway-env==1.5.0 && \ rm -rf $HOME/.cache/pip ENV PATH=$VENV/bin:$PATH diff --git a/enjoy.py b/enjoy.py index c92304fb2..b2a1ecaee 100644 --- a/enjoy.py +++ b/enjoy.py @@ -1,4 +1,5 @@ import argparse +import asyncio import glob import importlib import os @@ -6,13 +7,34 @@ import numpy as np import torch as th +import websockets import yaml from stable_baselines3.common.utils import set_random_seed import utils.import_envs # noqa: F401 pylint: disable=unused-import -from utils import ALGOS, create_test_env, get_latest_run_id, get_saved_hyperparams +from utils import ALGOS, create_test_env, get_saved_hyperparams from utils.exp_manager import ExperimentManager -from utils.utils import StoreDict +from utils.load_from_hub import download_from_hub +from utils.utils import StoreDict, get_model_path + +EXIT = False +socket_port = int(os.environ.get("SOCKET_PORT", 8895)) + +# To unlock the start +# echo '{"angle":0,"throttle":0,"drive_mode":"local"}' | websocat "ws://127.0.0.1:8895/wsDrive" + + +async def handler(websocket): + _ = await websocket.recv() + global EXIT + EXIT = True + + +async def main_wait(): + async with websockets.serve(handler, "", socket_port): + while not EXIT: + await asyncio.sleep(0.05) + print("Exiting socket server") def main(): # noqa: C901 @@ -71,48 +93,43 @@ def main(): # noqa: C901 algo = args.algo folder = args.folder - if args.exp_id == 0: - args.exp_id = get_latest_run_id(os.path.join(folder, algo), env_id) - print(f"Loading latest experiment, id={args.exp_id}") - - # Sanity checks - if args.exp_id > 0: - log_path = os.path.join(folder, algo, f"{env_id}_{args.exp_id}") - else: - log_path = os.path.join(folder, algo) - - assert os.path.isdir(log_path), f"The {log_path} folder was not found" - - found = False - for ext in ["zip"]: - model_path = os.path.join(log_path, f"{env_id}.{ext}") - found = os.path.isfile(model_path) - if found: - break - - if args.load_best: - model_path = os.path.join(log_path, "best_model.zip") - found = os.path.isfile(model_path) - - if args.load_checkpoint is not None: - model_path = os.path.join(log_path, f"rl_model_{args.load_checkpoint}_steps.zip") - found = os.path.isfile(model_path) - - if args.load_last_checkpoint: - checkpoints = glob.glob(os.path.join(log_path, "rl_model_*_steps.zip")) - if len(checkpoints) == 0: - raise ValueError(f"No checkpoint found for {algo} on {env_id}, path: {log_path}") - - def step_count(checkpoint_path: str) -> int: - # path follow the pattern "rl_model_*_steps.zip", we count from the back to ignore any other _ in the path - return int(checkpoint_path.split("_")[-2]) - - checkpoints = sorted(checkpoints, key=step_count) - model_path = checkpoints[-1] - found = True - - if not found: - raise ValueError(f"No model found for {algo} on {env_id}, path: {model_path}") + try: + _, model_path, log_path = get_model_path( + args.exp_id, + folder, + algo, + env_id, + args.load_best, + args.load_checkpoint, + args.load_last_checkpoint, + ) + except (AssertionError, ValueError) as e: + # Special case for rl-trained agents + # auto-download from the hub + if "rl-trained-agents" not in folder: + raise e + else: + print("Pretrained model not found, trying to download it from sb3 Huggingface hub: https://huggingface.co/sb3") + # Auto-download + download_from_hub( + algo=algo, + env_id=env_id, + exp_id=args.exp_id, + folder=folder, + organization="sb3", + repo_name=None, + force=False, + ) + # Try again + _, model_path, log_path = get_model_path( + args.exp_id, + folder, + algo, + env_id, 
args.load_best, + args.load_checkpoint, + args.load_last_checkpoint, + ) print(f"Loading {model_path}") @@ -138,7 +155,7 @@ def step_count(checkpoint_path: str) -> int: env_kwargs = {} args_path = os.path.join(log_path, env_id, "args.yml") if os.path.isfile(args_path): - with open(args_path, "r") as f: + with open(args_path) as f: loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr if loaded_args["env_kwargs"] is not None: env_kwargs = loaded_args["env_kwargs"] @@ -180,20 +197,34 @@ def step_count(checkpoint_path: str) -> int: obs = env.reset() + # Wait for message from websocket + if bool(int(os.environ.get("WAIT_FOR_START", False))): + print(f"Waiting for socket message on port {socket_port}") + asyncio.run(main_wait()) + # Deterministic by default except for atari games stochastic = args.stochastic or is_atari and not args.deterministic deterministic = not stochastic - state = None episode_reward = 0.0 episode_rewards, episode_lengths = [], [] ep_len = 0 # For HER, monitor success rate successes = [] + lstm_states = None + episode_start = np.ones((env.num_envs,), dtype=bool) try: for _ in range(args.n_timesteps): - action, state = model.predict(obs, state=state, deterministic=deterministic) + action, lstm_states = model.predict( + obs, + state=lstm_states, + episode_start=episode_start, + deterministic=deterministic, + ) obs, reward, done, infos = env.step(action) + + episode_start = done + if not args.no_render: env.render("human") @@ -218,7 +249,6 @@ def step_count(checkpoint_path: str) -> int: episode_lengths.append(ep_len) episode_reward = 0.0 ep_len = 0 - state = None # Reset also when the goal is achieved when using HER if done and infos[0].get("is_success") is not None: @@ -230,6 +260,7 @@ def step_count(checkpoint_path: str) -> int: episode_reward, ep_len = 0.0, 0 except KeyboardInterrupt: + print("Cancelled by the user...") pass if args.verbose > 0 and len(successes) > 0: diff --git a/hyperparams/a2c.yml b/hyperparams/a2c.yml index f51b0fb09..ba9416e85 100644 --- a/hyperparams/a2c.yml +++ b/hyperparams/a2c.yml @@ -1,6 +1,10 @@ atari: env_wrapper: - stable_baselines3.common.atari_wrappers.AtariWrapper + # Equivalent to + # vec_env_wrapper: + # - stable_baselines3.common.vec_env.VecFrameStack: + # n_stack: 4 frame_stack: 4 policy: 'CnnPolicy' n_envs: 16 diff --git a/hyperparams/human.yml b/hyperparams/human.yml new file mode 100644 index 000000000..6852b93b7 --- /dev/null +++ b/hyperparams/human.yml @@ -0,0 +1,13 @@ +donkey-generated-track-v0: + env_wrapper: + # - gym.wrappers.time_limit.TimeLimit: + # max_episode_steps: 1200 + - utils.wrappers.HistoryWrapper: + horizon: 2 + # - utils.wrappers.TimeFeatureWrapper: + # test_mode: True + # - stable_baselines3.common.monitor.Monitor: + # filename: None + n_timesteps: !!float 1e6 + policy: 'MlpPolicy' + buffer_size: 100000 diff --git a/hyperparams/ppo.yml b/hyperparams/ppo.yml index 2a88c30d8..1532953d5 100644 --- a/hyperparams/ppo.yml +++ b/hyperparams/ppo.yml @@ -85,17 +85,17 @@ Acrobot-v1: BipedalWalker-v3: normalize: true - n_envs: 16 + n_envs: 32 n_timesteps: !!float 5e6 policy: 'MlpPolicy' n_steps: 2048 batch_size: 64 gae_lambda: 0.95 - gamma: 0.99 + gamma: 0.999 n_epochs: 10 - ent_coef: 0.001 - learning_rate: !!float 2.5e-4 - clip_range: 0.2 + ent_coef: 0.0 + learning_rate: !!float 3e-4 + clip_range: 0.18 BipedalWalkerHardcore-v3: normalize: true @@ -319,28 +319,33 @@ MiniGrid-FourRooms-v0: CarRacing-v0: env_wrapper: + - utils.wrappers.FrameSkip: + skip: 2 - 
gym.wrappers.resize_observation.ResizeObservation: shape: 64 - gym.wrappers.gray_scale_observation.GrayScaleObservation: keep_dim: true - frame_stack: 4 + frame_stack: 2 + normalize: "{'norm_obs': False, 'norm_reward': True}" n_envs: 8 - n_timesteps: !!float 1e6 + n_timesteps: !!float 4e6 policy: 'CnnPolicy' batch_size: 128 n_steps: 512 gamma: 0.99 - gae_lambda: 0.9 - n_epochs: 20 + gae_lambda: 0.95 + n_epochs: 10 ent_coef: 0.0 sde_sample_freq: 4 max_grad_norm: 0.5 vf_coef: 0.5 - learning_rate: !!float 3e-5 + learning_rate: lin_1e-4 use_sde: True - clip_range: 0.4 + clip_range: 0.2 policy_kwargs: "dict(log_std_init=-2, ortho_init=False, + activation_fn=nn.GELU, + net_arch=[dict(pi=[256], vf=[256])], )" @@ -508,7 +513,7 @@ InvertedPendulum-v3: max_grad_norm: 0.3 vf_coef: 0.19816 -Reacher-v3: +Reacher-v2: normalize: true n_envs: 1 policy: 'MlpPolicy' @@ -524,28 +529,6 @@ Reacher-v3: max_grad_norm: 0.9 vf_coef: 0.950368 -# Swimmer-v3: -# normalize: true -# n_envs: 1 -# policy: 'MlpPolicy' -# n_timesteps: !!float 1e6 -# batch_size: 32 -# n_steps: 512 -# gamma: 0.9999 -# learning_rate: 5.49717e-05 -# ent_coef: 0.0554757 -# clip_range: 0.3 -# n_epochs: 10 -# gae_lambda: 0.95 -# max_grad_norm: 0.6 -# vf_coef: 0.38782 -# policy_kwargs: "dict( -# log_std_init=-2, -# ortho_init=False, -# activation_fn=nn.ReLU, -# net_arch=[dict(pi=[256, 256], vf=[256, 256])] -# )" - Walker2d-v3: normalize: true n_envs: 1 diff --git a/hyperparams/ppo_lstm.yml b/hyperparams/ppo_lstm.yml new file mode 100644 index 000000000..52fffe32f --- /dev/null +++ b/hyperparams/ppo_lstm.yml @@ -0,0 +1,534 @@ +atari: + env_wrapper: + - stable_baselines3.common.atari_wrappers.AtariWrapper: + terminal_on_life_loss: False + frame_stack: 4 + policy: 'CnnLstmPolicy' + n_envs: 8 + n_steps: 128 + n_epochs: 4 + batch_size: 256 + n_timesteps: !!float 1e7 + learning_rate: lin_2.5e-4 + clip_range: lin_0.1 + vf_coef: 0.5 + ent_coef: 0.01 + policy_kwargs: "dict(enable_critic_lstm=False, + lstm_hidden_size=128, + )" + +# Tuned +PendulumNoVel-v1: + normalize: True + n_envs: 4 + n_timesteps: !!float 1e5 + policy: 'MlpLstmPolicy' + n_steps: 1024 + gae_lambda: 0.95 + gamma: 0.9 + n_epochs: 10 + ent_coef: 0.0 + learning_rate: !!float 1e-3 + clip_range: 0.2 + use_sde: True + sde_sample_freq: 4 + policy_kwargs: "dict( + ortho_init=False, + activation_fn=nn.ReLU, + lstm_hidden_size=64, + enable_critic_lstm=True, + net_arch=[dict(pi=[64], vf=[64])] + )" + +# Tuned +CartPoleNoVel-v1: + normalize: True + n_envs: 8 + n_timesteps: !!float 1e5 + policy: 'MlpLstmPolicy' + n_steps: 32 + batch_size: 256 + gae_lambda: 0.8 + gamma: 0.98 + n_epochs: 20 + ent_coef: 0.0 + learning_rate: lin_0.001 + clip_range: lin_0.2 + policy_kwargs: "dict( + ortho_init=False, + activation_fn=nn.ReLU, + lstm_hidden_size=64, + enable_critic_lstm=True, + net_arch=[dict(pi=[64], vf=[64])] + )" + +# TO BE TUNED +MountainCarNoVel-v0: + normalize: true + n_envs: 16 + n_timesteps: !!float 1e6 + policy: 'MlpLstmPolicy' + n_steps: 16 + gae_lambda: 0.98 + gamma: 0.99 + n_epochs: 4 + ent_coef: 0.0 + +# Tuned +MountainCarContinuousNoVel-v0: + normalize: true + n_envs: 8 + n_timesteps: !!float 3e5 + policy: 'MlpLstmPolicy' + batch_size: 256 + n_steps: 1024 + gamma: 0.9999 + learning_rate: !!float 7.77e-05 + ent_coef: 0.00429 + clip_range: 0.1 + n_epochs: 10 + gae_lambda: 0.9 + max_grad_norm: 5 + vf_coef: 0.19 + use_sde: True + sde_sample_freq: 8 + policy_kwargs: "dict(log_std_init=0.0, ortho_init=False, + lstm_hidden_size=32, + enable_critic_lstm=True, + net_arch=[dict(pi=[64], vf=[64])])" + 
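+# Note: the "NoVel" variants above (and the LunarLanderNoVel entries below) are the
+# standard tasks with the velocity terms masked out of the observation, which makes
+# them partially observable: the recurrent policy has to infer velocities from the
+# history of observations, which is precisely what these entries benchmark.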
+Acrobot-v1: + normalize: true + n_envs: 16 + n_timesteps: !!float 1e6 + policy: 'MlpLstmPolicy' + n_steps: 256 + gae_lambda: 0.94 + gamma: 0.99 + n_epochs: 4 + ent_coef: 0.0 + +BipedalWalker-v3: + normalize: true + n_envs: 32 + n_timesteps: !!float 5e6 + policy: 'MlpLstmPolicy' + n_steps: 256 + batch_size: 256 + gae_lambda: 0.95 + gamma: 0.999 + n_epochs: 10 + ent_coef: 0.0 + learning_rate: !!float 3e-4 + clip_range: 0.18 + policy_kwargs: "dict( + ortho_init=False, + activation_fn=nn.ReLU, + lstm_hidden_size=64, + enable_critic_lstm=True, + net_arch=[dict(pi=[64], vf=[64])] + )" + +# TO BE TUNED +BipedalWalkerHardcore-v3: + # env_wrapper: + # - utils.wrappers.FrameSkip: + # skip: 2 + normalize: true + n_envs: 32 + n_timesteps: !!float 10e7 + policy: 'MlpLstmPolicy' + n_steps: 256 + batch_size: 256 + gae_lambda: 0.95 + gamma: 0.999 + n_epochs: 10 + ent_coef: 0.001 + learning_rate: lin_3e-4 + clip_range: lin_0.2 + policy_kwargs: "dict( + ortho_init=False, + activation_fn=nn.ReLU, + lstm_hidden_size=64, + enable_critic_lstm=True, + net_arch=[dict(pi=[64], vf=[64])] + )" + +# Tuned +LunarLanderNoVel-v2: &lunar-defaults + normalize: True + n_envs: 32 + n_timesteps: !!float 5e6 + policy: 'MlpLstmPolicy' + n_steps: 512 + batch_size: 128 + gae_lambda: 0.98 + gamma: 0.999 + n_epochs: 4 + ent_coef: 0.01 + policy_kwargs: "dict( + ortho_init=False, + activation_fn=nn.ReLU, + lstm_hidden_size=64, + enable_critic_lstm=True, + net_arch=[dict(pi=[64], vf=[64])] + )" + +LunarLanderContinuousNoVel-v2: + <<: *lunar-defaults + + +HalfCheetahBulletEnv-v0: &pybullet-defaults + normalize: true + n_envs: 16 + n_timesteps: !!float 2e6 + policy: 'MlpLstmPolicy' + batch_size: 128 + n_steps: 256 + gamma: 0.99 + gae_lambda: 0.9 + n_epochs: 10 + policy_kwargs: "dict(ortho_init=False, + activation_fn=nn.ReLU, + net_arch=[dict(pi=[], vf=[])], + enable_critic_lstm=True, + lstm_hidden_size=128, + )" + +AntBulletEnv-v0: + <<: *pybullet-defaults + +Walker2DBulletEnv-v0: + <<: *pybullet-defaults + clip_range: lin_0.4 + +HopperBulletEnv-v0: + <<: *pybullet-defaults + clip_range: lin_0.4 + + +ReacherBulletEnv-v0: + <<: *pybullet-defaults + clip_range: lin_0.4 + +MinitaurBulletEnv-v0: + normalize: true + n_envs: 8 + n_timesteps: !!float 2e6 + policy: 'MlpLstmPolicy' + n_steps: 2048 + batch_size: 64 + gae_lambda: 0.95 + gamma: 0.99 + n_epochs: 10 + ent_coef: 0.0 + learning_rate: 2.5e-4 + clip_range: 0.2 + +MinitaurBulletDuckEnv-v0: + normalize: true + n_envs: 8 + n_timesteps: !!float 2e6 + policy: 'MlpLstmPolicy' + n_steps: 2048 + batch_size: 64 + gae_lambda: 0.95 + gamma: 0.99 + n_epochs: 10 + ent_coef: 0.0 + learning_rate: 2.5e-4 + clip_range: 0.2 + +# To be tuned +HumanoidBulletEnv-v0: + normalize: true + n_envs: 8 + n_timesteps: !!float 1e7 + policy: 'MlpLstmPolicy' + n_steps: 2048 + batch_size: 64 + gae_lambda: 0.95 + gamma: 0.99 + n_epochs: 10 + ent_coef: 0.0 + learning_rate: 2.5e-4 + clip_range: 0.2 + +InvertedDoublePendulumBulletEnv-v0: + normalize: true + n_envs: 8 + n_timesteps: !!float 2e6 + policy: 'MlpLstmPolicy' + n_steps: 2048 + batch_size: 64 + gae_lambda: 0.95 + gamma: 0.99 + n_epochs: 10 + ent_coef: 0.0 + learning_rate: 2.5e-4 + clip_range: 0.2 + +InvertedPendulumSwingupBulletEnv-v0: + normalize: true + n_envs: 8 + n_timesteps: !!float 2e6 + policy: 'MlpLstmPolicy' + n_steps: 2048 + batch_size: 64 + gae_lambda: 0.95 + gamma: 0.99 + n_epochs: 10 + ent_coef: 0.0 + learning_rate: 2.5e-4 + clip_range: 0.2 + + +CarRacing-v0: + env_wrapper: + # - utils.wrappers.FrameSkip: + # skip: 2 + - 
gym.wrappers.resize_observation.ResizeObservation: + shape: 64 + - gym.wrappers.gray_scale_observation.GrayScaleObservation: + keep_dim: true + frame_stack: 2 + normalize: "{'norm_obs': False, 'norm_reward': True}" + n_envs: 8 + n_timesteps: !!float 4e6 + policy: 'CnnLstmPolicy' + batch_size: 128 + n_steps: 512 + gamma: 0.99 + gae_lambda: 0.95 + n_epochs: 10 + ent_coef: 0.0 + sde_sample_freq: 4 + max_grad_norm: 0.5 + vf_coef: 0.5 + learning_rate: lin_1e-4 + use_sde: True + clip_range: 0.2 + policy_kwargs: "dict(log_std_init=-2, + ortho_init=False, + enable_critic_lstm=False, + activation_fn=nn.GELU, + lstm_hidden_size=128, + )" + +# === Mujoco Envs === +# HalfCheetah-v3: &mujoco-defaults +# normalize: true +# n_timesteps: !!float 1e6 +# policy: 'MlpLstmPolicy' + +Ant-v3: &mujoco-defaults + normalize: true + n_timesteps: !!float 1e6 + policy: 'MlpLstmPolicy' + +# Hopper-v3: +# <<: *mujoco-defaults +# +# Walker2d-v3: +# <<: *mujoco-defaults +# +# Humanoid-v3: +# <<: *mujoco-defaults +# n_timesteps: !!float 2e6 +# +Swimmer-v3: + <<: *mujoco-defaults + gamma: 0.9999 + + +# 10 mujoco envs + +HalfCheetah-v3: + normalize: true + n_envs: 1 + policy: 'MlpLstmPolicy' + n_timesteps: !!float 1e6 + batch_size: 64 + n_steps: 512 + gamma: 0.98 + learning_rate: 2.0633e-05 + ent_coef: 0.000401762 + clip_range: 0.1 + n_epochs: 20 + gae_lambda: 0.92 + max_grad_norm: 0.8 + vf_coef: 0.58096 + policy_kwargs: "dict( + log_std_init=-2, + ortho_init=False, + activation_fn=nn.ReLU, + net_arch=[dict(pi=[256, 256], vf=[256, 256])] + )" + +# Ant-v3: +# normalize: true +# n_envs: 1 +# policy: 'MlpLstmPolicy' +# n_timesteps: !!float 1e7 +# batch_size: 32 +# n_steps: 512 +# gamma: 0.98 +# learning_rate: 1.90609e-05 +# ent_coef: 4.9646e-07 +# clip_range: 0.1 +# n_epochs: 10 +# gae_lambda: 0.8 +# max_grad_norm: 0.6 +# vf_coef: 0.677239 + +Hopper-v3: + normalize: true + n_envs: 1 + policy: 'MlpLstmPolicy' + n_timesteps: !!float 1e6 + batch_size: 32 + n_steps: 512 + gamma: 0.999 + learning_rate: 9.80828e-05 + ent_coef: 0.00229519 + clip_range: 0.2 + n_epochs: 5 + gae_lambda: 0.99 + max_grad_norm: 0.7 + vf_coef: 0.835671 + policy_kwargs: "dict( + log_std_init=-2, + ortho_init=False, + activation_fn=nn.ReLU, + net_arch=[dict(pi=[256, 256], vf=[256, 256])] + )" + +HumanoidStandup-v3: + normalize: true + n_envs: 1 + policy: 'MlpLstmPolicy' + n_timesteps: !!float 1e7 + batch_size: 32 + n_steps: 512 + gamma: 0.99 + learning_rate: 2.55673e-05 + ent_coef: 3.62109e-06 + clip_range: 0.3 + n_epochs: 20 + gae_lambda: 0.9 + max_grad_norm: 0.7 + vf_coef: 0.430793 + policy_kwargs: "dict( + log_std_init=-2, + ortho_init=False, + activation_fn=nn.ReLU, + net_arch=[dict(pi=[256, 256], vf=[256, 256])] + )" + +Humanoid-v3: + normalize: true + n_envs: 1 + policy: 'MlpLstmPolicy' + n_timesteps: !!float 1e7 + batch_size: 256 + n_steps: 512 + gamma: 0.95 + learning_rate: 3.56987e-05 + ent_coef: 0.00238306 + clip_range: 0.3 + n_epochs: 5 + gae_lambda: 0.9 + max_grad_norm: 2 + vf_coef: 0.431892 + policy_kwargs: "dict( + log_std_init=-2, + ortho_init=False, + activation_fn=nn.ReLU, + net_arch=[dict(pi=[256, 256], vf=[256, 256])] + )" + +InvertedDoublePendulum-v3: + normalize: true + n_envs: 1 + policy: 'MlpLstmPolicy' + n_timesteps: !!float 1e6 + batch_size: 512 + n_steps: 128 + gamma: 0.98 + learning_rate: 0.000155454 + ent_coef: 1.05057e-06 + clip_range: 0.4 + n_epochs: 10 + gae_lambda: 0.8 + max_grad_norm: 0.5 + vf_coef: 0.695929 + +InvertedPendulum-v3: + normalize: true + n_envs: 1 + policy: 'MlpLstmPolicy' + n_timesteps: !!float 1e6 + batch_size: 
64 + n_steps: 32 + gamma: 0.999 + learning_rate: 0.000222425 + ent_coef: 1.37976e-07 + clip_range: 0.4 + n_epochs: 5 + gae_lambda: 0.9 + max_grad_norm: 0.3 + vf_coef: 0.19816 + +Reacher-v2: + normalize: true + n_envs: 1 + policy: 'MlpLstmPolicy' + n_timesteps: !!float 1e6 + batch_size: 32 + n_steps: 512 + gamma: 0.9 + learning_rate: 0.000104019 + ent_coef: 7.52585e-08 + clip_range: 0.3 + n_epochs: 5 + gae_lambda: 1.0 + max_grad_norm: 0.9 + vf_coef: 0.950368 + +# Swimmer-v3: +# normalize: true +# n_envs: 1 +# policy: 'MlpLstmPolicy' +# n_timesteps: !!float 1e6 +# batch_size: 32 +# n_steps: 512 +# gamma: 0.9999 +# learning_rate: 5.49717e-05 +# ent_coef: 0.0554757 +# clip_range: 0.3 +# n_epochs: 10 +# gae_lambda: 0.95 +# max_grad_norm: 0.6 +# vf_coef: 0.38782 +# policy_kwargs: "dict( +# log_std_init=-2, +# ortho_init=False, +# activation_fn=nn.ReLU, +# net_arch=[dict(pi=[256, 256], vf=[256, 256])] +# )" + +Walker2d-v3: + normalize: true + n_envs: 1 + policy: 'MlpLstmPolicy' + n_timesteps: !!float 1e6 + batch_size: 32 + n_steps: 512 + gamma: 0.99 + learning_rate: 5.05041e-05 + ent_coef: 0.000585045 + clip_range: 0.1 + n_epochs: 20 + gae_lambda: 0.95 + max_grad_norm: 1 + vf_coef: 0.871923 diff --git a/hyperparams/sac.yml b/hyperparams/sac.yml index 8aeaef616..d831b8791 100644 --- a/hyperparams/sac.yml +++ b/hyperparams/sac.yml @@ -1,281 +1,32 @@ -# Tuned -MountainCarContinuous-v0: - n_timesteps: !!float 50000 - policy: 'MlpPolicy' - learning_rate: !!float 3e-4 - buffer_size: 50000 - batch_size: 512 - ent_coef: 0.1 - train_freq: 32 - gradient_steps: 32 - gamma: 0.9999 - tau: 0.01 - learning_starts: 0 - use_sde: True - policy_kwargs: "dict(log_std_init=-3.67, net_arch=[64, 64])" - -Pendulum-v1: - # callback: - # - utils.callbacks.ParallelTrainCallback - n_timesteps: 20000 - policy: 'MlpPolicy' - learning_rate: !!float 1e-3 - - -LunarLanderContinuous-v2: - n_timesteps: !!float 5e5 - policy: 'MlpPolicy' - batch_size: 256 - learning_rate: lin_7.3e-4 - buffer_size: 1000000 - ent_coef: 'auto' - gamma: 0.99 - tau: 0.01 - train_freq: 1 - gradient_steps: 1 - learning_starts: 10000 - policy_kwargs: "dict(net_arch=[400, 300])" - - -# Tuned -BipedalWalker-v3: - n_timesteps: !!float 5e5 - policy: 'MlpPolicy' - learning_rate: !!float 7.3e-4 - buffer_size: 300000 - batch_size: 256 - ent_coef: 'auto' - gamma: 0.98 - tau: 0.02 - train_freq: 64 - gradient_steps: 64 - learning_starts: 10000 - use_sde: True - policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])" - -# Almost tuned -BipedalWalkerHardcore-v3: - n_timesteps: !!float 1e7 - policy: 'MlpPolicy' - learning_rate: lin_7.3e-4 - buffer_size: 1000000 - batch_size: 256 - ent_coef: 0.005 - gamma: 0.99 - tau: 0.01 - train_freq: 1 - gradient_steps: 1 - learning_starts: 10000 - policy_kwargs: "dict(net_arch=[400, 300])" - -# === Bullet envs === - -# Tuned -HalfCheetahBulletEnv-v0: &pybullet-defaults - # env_wrapper: - # - sb3_contrib.common.wrappers.TimeFeatureWrapper - # - utils.wrappers.DelayedRewardWrapper: - # delay: 10 - # - utils.wrappers.HistoryWrapper: - # horizon: 10 - n_timesteps: !!float 1e6 - policy: 'MlpPolicy' - learning_rate: !!float 7.3e-4 - buffer_size: 300000 - batch_size: 256 - ent_coef: 'auto' - gamma: 0.98 - tau: 0.02 - train_freq: 8 - gradient_steps: 8 - learning_starts: 10000 - # replay_buffer_kwargs: "dict(handle_timeout_termination=True)" - use_sde: True - policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])" - -# Tuned -AntBulletEnv-v0: - <<: *pybullet-defaults - -HopperBulletEnv-v0: - <<: *pybullet-defaults - learning_rate: 
lin_7.3e-4 - -Walker2DBulletEnv-v0: - <<: *pybullet-defaults - learning_rate: lin_7.3e-4 - - -# Tuned -ReacherBulletEnv-v0: - <<: *pybullet-defaults - n_timesteps: !!float 3e5 - -HumanoidBulletEnv-v0: +## Custom envs +donkey-mountain-track-v0: &defaults + # Normalize AE (+ the rest) normalize: "{'norm_obs': True, 'norm_reward': False}" - n_timesteps: !!float 2e7 - policy: 'MlpPolicy' - learning_rate: lin_3e-4 - buffer_size: 1000000 - batch_size: 64 - ent_coef: 'auto' - train_freq: 1 - gradient_steps: 1 - learning_starts: 1000 - -# Tuned -InvertedDoublePendulumBulletEnv-v0: - <<: *pybullet-defaults - n_timesteps: !!float 5e5 - - -# Tuned -InvertedPendulumSwingupBulletEnv-v0: - <<: *pybullet-defaults - n_timesteps: !!float 3e5 - -# To be tuned -MinitaurBulletEnv-v0: - # normalize: "{'norm_obs': True, 'norm_reward': False}" - n_timesteps: !!float 1e6 - policy: 'MlpPolicy' - learning_rate: !!float 3e-4 - buffer_size: 100000 - batch_size: 256 - ent_coef: 'auto' - train_freq: 1 - gradient_steps: 1 - learning_starts: 10000 - - -# To be tuned -MinitaurBulletDuckEnv-v0: - n_timesteps: !!float 1e6 - policy: 'MlpPolicy' - learning_rate: !!float 3e-4 - buffer_size: 1000000 - batch_size: 256 - ent_coef: 'auto' - train_freq: 1 - gradient_steps: 1 - learning_starts: 10000 - -# To be tuned -CarRacing-v0: - env_wrapper: - - gym.wrappers.resize_observation.ResizeObservation: - shape: 64 - - gym.wrappers.gray_scale_observation.GrayScaleObservation: - keep_dim: true - frame_stack: 4 - n_envs: 1 - n_timesteps: !!float 1e6 - # TODO: try with MlpPolicy and low res (MNIST: 28) pixels - # policy: 'MlpPolicy' - policy: 'CnnPolicy' - learning_rate: !!float 7.3e-4 - buffer_size: 50000 - batch_size: 256 - ent_coef: 'auto' - gamma: 0.98 - tau: 0.02 - train_freq: 64 - # train_freq: [1, "episode"] - gradient_steps: 64 - # sde_sample_freq: 64 - learning_starts: 1000 - use_sde: True - use_sde_at_warmup: True - policy_kwargs: "dict(log_std_init=-2, net_arch=[64, 64])" - -# === Mujoco Envs === - -HalfCheetah-v3: &mujoco-defaults - n_timesteps: !!float 1e6 - policy: 'MlpPolicy' - learning_starts: 10000 - -Ant-v3: - <<: *mujoco-defaults - -Hopper-v3: - <<: *mujoco-defaults - -Walker2d-v3: - <<: *mujoco-defaults - -Humanoid-v3: - <<: *mujoco-defaults - n_timesteps: !!float 2e6 - -Swimmer-v3: - <<: *mujoco-defaults - gamma: 0.9999 - -# === HER Robotics GoalEnvs === - -FetchReach-v1: - n_timesteps: !!float 20000 - policy: 'MultiInputPolicy' - buffer_size: 1000000 - ent_coef: 'auto' - batch_size: 256 - gamma: 0.95 - learning_rate: 0.001 - learning_starts: 1000 - normalize: True - replay_buffer_class: HerReplayBuffer - replay_buffer_kwargs: "dict( - online_sampling=True, - goal_selection_strategy='future', - n_sampled_goal=4 - )" - policy_kwargs: "dict(net_arch=[64, 64])" - - -# ==== Custom Envs === - -donkey-generated-track-v0: env_wrapper: - gym.wrappers.time_limit.TimeLimit: - max_episode_steps: 500 + max_episode_steps: 10000 + - ae.wrapper.AutoencoderWrapper - utils.wrappers.HistoryWrapper: - horizon: 5 - n_timesteps: !!float 1e6 + horizon: 2 + callback: + - utils.callbacks.ParallelTrainCallback: + gradient_steps: 200 + - utils.callbacks.LapTimeCallback + n_timesteps: !!float 2e6 policy: 'MlpPolicy' learning_rate: !!float 7.3e-4 - buffer_size: 300000 + buffer_size: 200000 batch_size: 256 ent_coef: 'auto' gamma: 0.99 tau: 0.02 - # train_freq: 64 - train_freq: [1, "episode"] - # gradient_steps: -1 - gradient_steps: 64 + train_freq: 200 + gradient_steps: 256 learning_starts: 500 use_sde_at_warmup: True use_sde: True - 
sde_sample_freq: 64 - policy_kwargs: "dict(log_std_init=-2, net_arch=[64, 64])" + sde_sample_freq: 16 + policy_kwargs: "dict(log_std_init=-3, net_arch=[256, 256], n_critics=2, use_expln=True)" -# === Real Robot envs -NeckEnvRelative-v2: - <<: *pybullet-defaults - env_wrapper: - - utils.wrappers.HistoryWrapper: - horizon: 2 - # - utils.wrappers.LowPassFilterWrapper: - # freq: 2.0 - # df: 25.0 - n_timesteps: !!float 1e6 - buffer_size: 100000 - gamma: 0.99 - train_freq: [1, "episode"] - gradient_steps: -1 - # 10 episodes of warm-up - learning_starts: 3000 - use_sde_at_warmup: True - use_sde: True - sde_sample_freq: 64 - policy_kwargs: "dict(log_std_init=-2, net_arch=[256, 256])" +donkey-circuit-launch-track-v0: + <<: *defaults diff --git a/hyperparams/teleop.yml b/hyperparams/teleop.yml new file mode 100644 index 000000000..c60e15ebc --- /dev/null +++ b/hyperparams/teleop.yml @@ -0,0 +1,13 @@ +RLRacingEnv-v0: + env_wrapper: + # - gym.wrappers.time_limit.TimeLimit: + # max_episode_steps: 1200 + - utils.wrappers.HistoryWrapper: + horizon: 2 + # - utils.wrappers.TimeFeatureWrapper: + # test_mode: True + # - stable_baselines3.common.monitor.Monitor: + # filename: None + n_timesteps: !!float 2e6 + policy: 'MlpPolicy' + buffer_size: 10000 diff --git a/hyperparams/tqc.yml b/hyperparams/tqc.yml index 54e2f3a4d..d831b8791 100644 --- a/hyperparams/tqc.yml +++ b/hyperparams/tqc.yml @@ -1,259 +1,32 @@ -# Tuned -MountainCarContinuous-v0: - n_timesteps: !!float 50000 - policy: 'MlpPolicy' - learning_rate: !!float 3e-4 - buffer_size: 50000 - batch_size: 512 - ent_coef: 0.1 - train_freq: 32 - gradient_steps: 32 - gamma: 0.9999 - tau: 0.01 - learning_starts: 0 - use_sde: True - policy_kwargs: "dict(log_std_init=-3.67, net_arch=[64, 64])" - -Pendulum-v1: - n_timesteps: 20000 - policy: 'MlpPolicy' - learning_rate: !!float 1e-3 - -LunarLanderContinuous-v2: - n_timesteps: !!float 5e5 - policy: 'MlpPolicy' - learning_rate: lin_7.3e-4 - buffer_size: 1000000 - batch_size: 256 - ent_coef: 'auto' - gamma: 0.99 - tau: 0.01 - train_freq: 1 - gradient_steps: 1 - learning_starts: 10000 - policy_kwargs: "dict(net_arch=[400, 300])" - -BipedalWalker-v3: - n_timesteps: !!float 5e5 - policy: 'MlpPolicy' - learning_rate: !!float 7.3e-4 - buffer_size: 300000 - batch_size: 256 - ent_coef: 'auto' - gamma: 0.98 - tau: 0.02 - train_freq: 64 - gradient_steps: 64 - learning_starts: 10000 - use_sde: True - policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])" - -# Almost tuned -# History wrapper of size 2 for better performances -BipedalWalkerHardcore-v3: +## Custom envs +donkey-mountain-track-v0: &defaults + # Normalize AE (+ the rest) + normalize: "{'norm_obs': True, 'norm_reward': False}" + env_wrapper: + - gym.wrappers.time_limit.TimeLimit: + max_episode_steps: 10000 + - ae.wrapper.AutoencoderWrapper + - utils.wrappers.HistoryWrapper: + horizon: 2 + callback: + - utils.callbacks.ParallelTrainCallback: + gradient_steps: 200 + - utils.callbacks.LapTimeCallback n_timesteps: !!float 2e6 policy: 'MlpPolicy' - learning_rate: lin_7.3e-4 - buffer_size: 1000000 - batch_size: 256 - ent_coef: 'auto' - gamma: 0.99 - tau: 0.01 - train_freq: 1 - gradient_steps: 1 - learning_starts: 10000 - policy_kwargs: "dict(net_arch=[400, 300])" - -# === Bullet envs === - -# Tuned -HalfCheetahBulletEnv-v0: &pybullet-defaults - n_timesteps: !!float 1e6 - policy: 'MlpPolicy' learning_rate: !!float 7.3e-4 - buffer_size: 300000 + buffer_size: 200000 batch_size: 256 ent_coef: 'auto' - gamma: 0.98 - tau: 0.02 - train_freq: 8 - gradient_steps: 8 - 
diff --git a/hyperparams/tqc.yml b/hyperparams/tqc.yml
index 54e2f3a4d..d831b8791 100644
--- a/hyperparams/tqc.yml
+++ b/hyperparams/tqc.yml
@@ -1,259 +1,32 @@
-# Tuned
-MountainCarContinuous-v0:
-  n_timesteps: !!float 50000
-  policy: 'MlpPolicy'
-  learning_rate: !!float 3e-4
-  buffer_size: 50000
-  batch_size: 512
-  ent_coef: 0.1
-  train_freq: 32
-  gradient_steps: 32
-  gamma: 0.9999
-  tau: 0.01
-  learning_starts: 0
-  use_sde: True
-  policy_kwargs: "dict(log_std_init=-3.67, net_arch=[64, 64])"
-
-Pendulum-v1:
-  n_timesteps: 20000
-  policy: 'MlpPolicy'
-  learning_rate: !!float 1e-3
-
-LunarLanderContinuous-v2:
-  n_timesteps: !!float 5e5
-  policy: 'MlpPolicy'
-  learning_rate: lin_7.3e-4
-  buffer_size: 1000000
-  batch_size: 256
-  ent_coef: 'auto'
-  gamma: 0.99
-  tau: 0.01
-  train_freq: 1
-  gradient_steps: 1
-  learning_starts: 10000
-  policy_kwargs: "dict(net_arch=[400, 300])"
-
-BipedalWalker-v3:
-  n_timesteps: !!float 5e5
-  policy: 'MlpPolicy'
-  learning_rate: !!float 7.3e-4
-  buffer_size: 300000
-  batch_size: 256
-  ent_coef: 'auto'
-  gamma: 0.98
-  tau: 0.02
-  train_freq: 64
-  gradient_steps: 64
-  learning_starts: 10000
-  use_sde: True
-  policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"
-
-# Almost tuned
-# History wrapper of size 2 for better performance
-BipedalWalkerHardcore-v3:
+## Custom envs
+donkey-mountain-track-v0: &defaults
+  # Normalize the AE features (and the rest of the observation)
+  normalize: "{'norm_obs': True, 'norm_reward': False}"
+  env_wrapper:
+    - gym.wrappers.time_limit.TimeLimit:
+        max_episode_steps: 10000
+    - ae.wrapper.AutoencoderWrapper
+    - utils.wrappers.HistoryWrapper:
+        horizon: 2
+  callback:
+    - utils.callbacks.ParallelTrainCallback:
+        gradient_steps: 200
+    - utils.callbacks.LapTimeCallback
   n_timesteps: !!float 2e6
   policy: 'MlpPolicy'
-  learning_rate: lin_7.3e-4
-  buffer_size: 1000000
-  batch_size: 256
-  ent_coef: 'auto'
-  gamma: 0.99
-  tau: 0.01
-  train_freq: 1
-  gradient_steps: 1
-  learning_starts: 10000
-  policy_kwargs: "dict(net_arch=[400, 300])"
-
-# === Bullet envs ===
-
-# Tuned
-HalfCheetahBulletEnv-v0: &pybullet-defaults
-  n_timesteps: !!float 1e6
-  policy: 'MlpPolicy'
   learning_rate: !!float 7.3e-4
-  buffer_size: 300000
+  buffer_size: 200000
   batch_size: 256
   ent_coef: 'auto'
-  gamma: 0.98
-  tau: 0.02
-  train_freq: 8
-  gradient_steps: 8
-  learning_starts: 10000
-  use_sde: True
-  policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"
-
-# Tuned
-AntBulletEnv-v0:
-  <<: *pybullet-defaults
-
-# Tuned
-HopperBulletEnv-v0:
-  <<: *pybullet-defaults
-  learning_rate: lin_7.3e-4
-  top_quantiles_to_drop_per_net: 5
-
-# Tuned
-Walker2DBulletEnv-v0:
-  <<: *pybullet-defaults
-  learning_rate: lin_7.3e-4
-
-ReacherBulletEnv-v0:
-  <<: *pybullet-defaults
-  n_timesteps: !!float 3e5
-
-# Almost tuned
-HumanoidBulletEnv-v0:
-  n_timesteps: !!float 1e7
-  policy: 'MlpPolicy'
-  learning_rate: lin_7.3e-4
-  buffer_size: 300000
-  batch_size: 256
-  ent_coef: 'auto'
-  gamma: 0.98
+  gamma: 0.99
   tau: 0.02
-  train_freq: 8
-  gradient_steps: 8
-  learning_starts: 10000
-  top_quantiles_to_drop_per_net: 5
+  train_freq: 200
+  gradient_steps: 256
+  learning_starts: 500
+  use_sde_at_warmup: True
   use_sde: True
-  policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"
-
-InvertedDoublePendulumBulletEnv-v0:
-  <<: *pybullet-defaults
-  n_timesteps: !!float 5e5
-
-InvertedPendulumSwingupBulletEnv-v0:
-  <<: *pybullet-defaults
-  n_timesteps: !!float 3e5
-
-MinitaurBulletEnv-v0:
-  n_timesteps: !!float 1e6
-  policy: 'MlpPolicy'
-  learning_rate: !!float 3e-4
-  buffer_size: 100000
-  batch_size: 256
-  ent_coef: 'auto'
-  train_freq: 1
-  gradient_steps: 1
-  learning_starts: 10000
-  # top_quantiles_to_drop_per_net: 5
-
-# === Mujoco Envs ===
-
-HalfCheetah-v3: &mujoco-defaults
-  n_timesteps: !!float 1e6
-  policy: 'MlpPolicy'
-  learning_starts: 10000
-
-Ant-v3:
-  <<: *mujoco-defaults
-
-Hopper-v3:
-  <<: *mujoco-defaults
-  top_quantiles_to_drop_per_net: 5
-
-Walker2d-v3:
-  <<: *mujoco-defaults
-
-Humanoid-v3:
-  <<: *mujoco-defaults
-  n_timesteps: !!float 2e6
-
-Swimmer-v3:
-  <<: *mujoco-defaults
-  gamma: 0.9999
-
-# === HER Robotics GoalEnvs ===
-FetchReach-v1:
-  env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
-  n_timesteps: !!float 20000
-  policy: 'MultiInputPolicy'
-  buffer_size: 1000000
-  ent_coef: 'auto'
-  batch_size: 256
-  gamma: 0.95
-  learning_rate: 0.001
-  learning_starts: 1000
-  normalize: True
-  replay_buffer_class: HerReplayBuffer
-  replay_buffer_kwargs: "dict(
-    online_sampling=True,
-    goal_selection_strategy='future',
-    n_sampled_goal=4
-  )"
-  policy_kwargs: "dict(net_arch=[64, 64], n_critics=1)"
-
-FetchPush-v1: &her-defaults
-  env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
-  n_timesteps: !!float 1e6
-  policy: 'MultiInputPolicy'
-  buffer_size: 1000000
-  batch_size: 2048
-  gamma: 0.95
-  learning_rate: !!float 1e-3
-  tau: 0.05
-  replay_buffer_class: HerReplayBuffer
-  replay_buffer_kwargs: "dict(
-    online_sampling=True,
-    goal_selection_strategy='future',
-    n_sampled_goal=4,
-  )"
-  policy_kwargs: "dict(net_arch=[512, 512, 512], n_critics=2)"
-
-FetchSlide-v1:
-  <<: *her-defaults
-  n_timesteps: !!float 3e6
-
-FetchPickAndPlace-v1:
-  <<: *her-defaults
-
-PandaReach-v1:
-  env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
-  n_timesteps: !!float 20000
-  policy: 'MultiInputPolicy'
-  buffer_size: 1000000
-  ent_coef: 'auto'
-  batch_size: 256
-  gamma: 0.95
-  learning_rate: 0.001
-  learning_starts: 1000
-  normalize: True
-  replay_buffer_class: HerReplayBuffer
-  replay_buffer_kwargs: "dict(
-    online_sampling=True,
-    goal_selection_strategy='future',
-    n_sampled_goal=4
-  )"
-  policy_kwargs: "dict(net_arch=[64, 64], n_critics=1)"
-
-PandaPush-v1:
-  <<: *her-defaults
-
-PandaSlide-v1:
-  <<: *her-defaults
-  n_timesteps: !!float 3e6
-
-PandaPickAndPlace-v1:
-  <<: *her-defaults
-
-PandaStack-v1:
-  <<: *her-defaults
-
-parking-v0:
-  <<: *her-defaults
-  n_timesteps: !!float 1e5
-  buffer_size: 1000000
-  batch_size: 512
-  gamma: 0.98
-  learning_rate: !!float 1.5e-3
-  tau: 0.005
-  replay_buffer_class: HerReplayBuffer
-  replay_buffer_kwargs: "dict(
-    online_sampling=False,
-    goal_selection_strategy='episode',
-    n_sampled_goal=4,
-    max_episode_length=100
-  )"
+  sde_sample_freq: 16
+  policy_kwargs: "dict(log_std_init=-3, net_arch=[256, 256], n_critics=2, use_expln=True)"
-
-A1Walking-v0:
-  <<: *pybullet-defaults
+
+donkey-circuit-launch-track-v0:
+  <<: *defaults
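For reference, here is a sketch of how the `donkey-mountain-track-v0` entry above would map onto the `sb3_contrib` TQC constructor if instantiated by hand. This is illustrative only: env creation, the `AutoencoderWrapper`/`HistoryWrapper` stack, normalization and the callbacks are all omitted, and `Pendulum-v1` stands in for the Donkey Car env.

```python
from sb3_contrib import TQC

model = TQC(
    "MlpPolicy",
    "Pendulum-v1",  # stand-in env; the real config targets the Donkey Car envs
    learning_rate=7.3e-4,
    buffer_size=200_000,
    batch_size=256,
    ent_coef="auto",
    gamma=0.99,
    tau=0.02,
    train_freq=200,      # collect 200 environment steps...
    gradient_steps=256,  # ...then run 256 gradient updates
    learning_starts=500,
    use_sde=True,
    use_sde_at_warmup=True,
    sde_sample_freq=16,
    policy_kwargs=dict(log_std_init=-3, net_arch=[256, 256], n_critics=2, use_expln=True),
    verbose=1,
)
model.learn(total_timesteps=5_000)  # the config trains for 2e6 steps
```

With `train_freq: 200` and `gradient_steps: 256`, collection and training alternate in long phases, which the `ParallelTrainCallback` above is presumably there to overlap in a separate thread on the real car.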
diff --git a/logs/benchmark/benchmark.md b/logs/benchmark/benchmark.md
index 0642269c3..61c5239ac 100644
--- a/logs/benchmark/benchmark.md
+++ b/logs/benchmark/benchmark.md
@@ -8,6 +8,9 @@ during this evaluation. It uses the deterministic policy
 except for Atari games.
 
+You can view each model card (which includes a video and hyperparameters)
+on our Huggingface page: https://huggingface.co/sb3
+
 *NOTE: this is not a quantitative benchmark as it corresponds to only one run
 (cf [issue #38](https://github.com/araffin/rl-baselines-zoo/issues/38)).
 This benchmark is meant to check algorithm (maximal) performance, find potential bugs
@@ -15,181 +18,185 @@ and also allow users to have access to pretrained agents.*
 
 "M" stands for Million (1e6)
 
-|algo | env_id |mean_reward|std_reward|n_timesteps|eval_timesteps|eval_episodes|
-|-----|---------------------------|----------:|---------:|-----------|-------------:|------------:|
-|a2c |Acrobot-v1 | -83.353| 17.213|500k | 149979| 1778|
-|a2c |AntBulletEnv-v0 | 2497.147| 37.359|2M | 150000| 150|
-|a2c |AsteroidsNoFrameskip-v4 | 1286.550| 423.750|10M | 614138| 258|
-|a2c |BeamRiderNoFrameskip-v4 | 2890.298| 1379.137|10M | 591104| 47|
-|a2c |BipedalWalker-v3 | 299.754| 23.459|5M | 149287| 208|
-|a2c |BipedalWalkerHardcore-v3 | 96.171| 122.943|200M | 149704| 113|
-|a2c |BreakoutNoFrameskip-v4 | 279.793| 122.177|10M | 604115| 82|
-|a2c |CartPole-v1 | 500.000| 0.000|500k | 150000| 300|
-|a2c |EnduroNoFrameskip-v4 | 0.000| 0.000|10M | 599040| 45|
-|a2c |HalfCheetah-v3 | 3041.174| 157.265|1M | 150000| 150|
-|a2c |HalfCheetahBulletEnv-v0 | 2107.384| 36.008|2M | 150000| 150|
-|a2c |Hopper-v3 | 733.454| 376.574|1M | 149987| 580|
-|a2c |HopperBulletEnv-v0 | 815.355| 313.798|2M | 149541| 254|
-|a2c |LunarLander-v2 | 155.751| 80.419|200k | 149443| 297|
-|a2c |LunarLanderContinuous-v2 | 84.225| 145.906|5M | 149305| 256|
-|a2c |MountainCar-v0 | -111.263| 24.087|1M | 149982| 1348|
-|a2c |MountainCarContinuous-v0 | 91.166| 0.255|100k | 149923| 1659|
-|a2c |Pendulum-v0 | -162.965| 103.210|1M | 150000| 750|
-|a2c |PongNoFrameskip-v4 | 17.292| 3.214|10M | 594910| 65|
-|a2c |QbertNoFrameskip-v4 | 3882.345| 1223.327|10M | 610670| 194|
-|a2c |ReacherBulletEnv-v0 | 14.968| 10.978|2M | 150000| 1000|
-|a2c |RoadRunnerNoFrameskip-v4 | 31671.512| 6364.085|10M | 606710| 172|
-|a2c |SeaquestNoFrameskip-v4 | 1721.493| 105.339|10M | 599691| 67|
-|a2c |SpaceInvadersNoFrameskip-v4| 627.160| 201.974|10M | 604848| 162|
-|a2c |Swimmer-v3 | 200.627| 2.544|1M | 150000| 150|
-|a2c |Walker2DBulletEnv-v0 | 858.209| 333.116|2M | 149156| 173|
-|ars |Acrobot-v1 | -82.884| 23.825|500k | 149985| 1788|
-|ars |Ant-v3 | 2333.773| 20.597|75M | 150000| 150|
-|ars |CartPole-v1 | 500.000| 0.000|50k | 150000| 300|
-|ars |HalfCheetah-v3 | 4815.192| 1340.752|12M | 150000| 150|
-|ars |Hopper-v3 | 3343.919| 5.730|7M | 150000| 150|
-|ars |LunarLanderContinuous-v2 | 167.959| 147.071|2M | 149883| 562|
-|ars |MountainCar-v0 | -122.000| 33.456|500k | 149938| 1229|
-|ars |MountainCarContinuous-v0 | 96.672| 
0.784|500k | 149990| 621| -|ars |Pendulum-v0 | -212.540| 160.444|2M | 150000| 750| -|ars |Swimmer-v3 | 355.267| 12.796|2M | 150000| 150| -|ars |Walker2d-v3 | 2993.582| 166.289|75M | 149821| 152| -|ddpg |AntBulletEnv-v0 | 2399.147| 75.410|1M | 150000| 150| -|ddpg |BipedalWalker-v3 | 197.486| 141.580|1M | 149237| 227| -|ddpg |HalfCheetahBulletEnv-v0 | 2078.325| 208.379|1M | 150000| 150| -|ddpg |HopperBulletEnv-v0 | 1157.065| 448.695|1M | 149565| 346| -|ddpg |LunarLanderContinuous-v2 | 230.217| 92.372|300k | 149862| 556| -|ddpg |MountainCarContinuous-v0 | 93.512| 0.048|300k | 149965| 2260| -|ddpg |Pendulum-v0 | -152.099| 94.282|20k | 150000| 750| -|ddpg |ReacherBulletEnv-v0 | 15.582| 9.606|300k | 150000| 1000| -|ddpg |Walker2DBulletEnv-v0 | 1387.591| 736.955|1M | 149051| 208| -|dqn |Acrobot-v1 | -76.639| 11.752|100k | 149998| 1932| -|dqn |AsteroidsNoFrameskip-v4 | 782.687| 259.247|10M | 607962| 134| -|dqn |BeamRiderNoFrameskip-v4 | 4295.946| 1790.458|10M | 600832| 37| -|dqn |BreakoutNoFrameskip-v4 | 358.327| 61.981|10M | 601461| 55| -|dqn |CartPole-v1 | 500.000| 0.000|50k | 150000| 300| -|dqn |EnduroNoFrameskip-v4 | 830.929| 194.544|10M | 599040| 14| -|dqn |LunarLander-v2 | 154.382| 79.241|100k | 149373| 200| -|dqn |MountainCar-v0 | -100.849| 9.925|120k | 149962| 1487| -|dqn |PongNoFrameskip-v4 | 20.602| 0.613|10M | 598998| 88| -|dqn |QbertNoFrameskip-v4 | 9496.774| 5399.633|10M | 605844| 124| -|dqn |RoadRunnerNoFrameskip-v4 | 40396.350| 7069.131|10M | 603257| 137| -|dqn |SeaquestNoFrameskip-v4 | 2000.290| 606.644|10M | 599505| 69| -|dqn |SpaceInvadersNoFrameskip-v4| 622.742| 201.564|10M | 604311| 155| -|ppo |Acrobot-v1 | -73.506| 18.201|1M | 149979| 2013| -|ppo |Ant-v3 | 1327.158| 451.577|1M | 149572| 175| -|ppo |AntBulletEnv-v0 | 2865.922| 56.468|2M | 150000| 150| -|ppo |AsteroidsNoFrameskip-v4 | 2156.174| 744.640|10M | 602092| 149| -|ppo |BeamRiderNoFrameskip-v4 | 3397.000| 1662.368|10M | 598926| 46| -|ppo |BipedalWalker-v3 | 213.299| 129.490|5M | 149826| 233| -|ppo |BipedalWalkerHardcore-v3 | 122.374| 117.605|100M | 148036| 105| -|ppo |BreakoutNoFrameskip-v4 | 398.033| 33.328|10M | 600418| 60| -|ppo |CartPole-v1 | 500.000| 0.000|100k | 150000| 300| -|ppo |EnduroNoFrameskip-v4 | 996.364| 176.090|10M | 572416| 11| -|ppo |HalfCheetah-v3 | 5819.099| 663.530|1M | 150000| 150| -|ppo |HalfCheetahBulletEnv-v0 | 2924.721| 64.465|2M | 150000| 150| -|ppo |Hopper-v3 | 2410.435| 10.026|1M | 150000| 150| -|ppo |HopperBulletEnv-v0 | 2575.054| 223.301|2M | 149094| 152| -|ppo |LunarLander-v2 | 242.119| 31.823|1M | 149636| 369| -|ppo |LunarLanderContinuous-v2 | 270.863| 32.072|1M | 149956| 526| -|ppo |MountainCar-v0 | -110.423| 19.473|1M | 149954| 1358| -|ppo |MountainCarContinuous-v0 | 88.343| 2.572|20k | 149983| 633| -|ppo |Pendulum-v0 | -172.225| 104.159|100k | 150000| 750| -|ppo |PongNoFrameskip-v4 | 20.989| 0.105|10M | 599902| 90| -|ppo |QbertNoFrameskip-v4 | 15627.108| 3313.538|10M | 600248| 83| -|ppo |ReacherBulletEnv-v0 | 17.091| 11.048|1M | 150000| 1000| -|ppo |RoadRunnerNoFrameskip-v4 | 40680.645| 6675.058|10M | 605786| 155| -|ppo |SeaquestNoFrameskip-v4 | 1783.636| 34.096|10M | 598243| 66| -|ppo |SpaceInvadersNoFrameskip-v4| 960.331| 425.355|10M | 603771| 136| -|ppo |Swimmer-v3 | 281.561| 9.671|1M | 150000| 150| -|ppo |Walker2DBulletEnv-v0 | 2109.992| 13.899|2M | 150000| 150| -|ppo |Walker2d-v3 | 3478.798| 821.708|1M | 149343| 171| -|qrdqn|Acrobot-v1 | -69.135| 9.967|100k | 149949| 2138| -|qrdqn|AsteroidsNoFrameskip-v4 | 2185.303| 1097.172|10M | 599784| 66| -|qrdqn|BeamRiderNoFrameskip-v4 | 
17122.941| 10769.997|10M | 596483| 17| -|qrdqn|BreakoutNoFrameskip-v4 | 393.600| 79.828|10M | 579711| 40| -|qrdqn|CartPole-v1 | 500.000| 0.000|50k | 150000| 300| -|qrdqn|EnduroNoFrameskip-v4 | 3231.200| 1311.801|10M | 585728| 5| -|qrdqn|LunarLander-v2 | 70.236| 225.491|100k | 149957| 522| -|qrdqn|MountainCar-v0 | -106.042| 15.536|120k | 149943| 1414| -|qrdqn|PongNoFrameskip-v4 | 20.492| 0.687|10M | 597443| 63| -|qrdqn|QbertNoFrameskip-v4 | 14799.728| 2917.629|10M | 600773| 92| -|qrdqn|RoadRunnerNoFrameskip-v4 | 42325.424| 8361.161|10M | 591016| 59| -|qrdqn|SeaquestNoFrameskip-v4 | 2557.576| 76.951|10M | 596275| 66| -|qrdqn|SpaceInvadersNoFrameskip-v4| 1899.928| 823.488|10M | 597218| 69| -|sac |Ant-v3 | 4615.791| 1354.111|1M | 149074| 165| -|sac |AntBulletEnv-v0 | 3073.114| 175.148|1M | 150000| 150| -|sac |BipedalWalker-v3 | 297.668| 33.060|500k | 149530| 136| -|sac |BipedalWalkerHardcore-v3 | 4.423| 103.910|10M | 149794| 88| -|sac |HalfCheetah-v3 | 9535.451| 100.470|1M | 150000| 150| -|sac |HalfCheetahBulletEnv-v0 | 2792.170| 12.088|1M | 150000| 150| -|sac |Hopper-v3 | 2325.547| 1129.676|1M | 149841| 236| -|sac |HopperBulletEnv-v0 | 2603.494| 164.322|1M | 149724| 151| -|sac |Humanoid-v3 | 6232.287| 279.885|2M | 149460| 150| -|sac |LunarLanderContinuous-v2 | 260.390| 65.467|500k | 149634| 672| -|sac |MountainCarContinuous-v0 | 94.679| 1.134|50k | 149966| 1443| -|sac |Pendulum-v0 | -156.995| 88.714|20k | 150000| 750| -|sac |ReacherBulletEnv-v0 | 18.062| 9.729|300k | 150000| 1000| -|sac |Swimmer-v3 | 345.568| 3.084|1M | 150000| 150| -|sac |Walker2DBulletEnv-v0 | 2292.266| 13.970|1M | 149983| 150| -|sac |Walker2d-v3 | 3863.203| 254.347|1M | 149309| 150| -|td3 |Ant-v3 | 5813.274| 589.773|1M | 149393| 151| -|td3 |AntBulletEnv-v0 | 3300.026| 54.640|1M | 150000| 150| -|td3 |BipedalWalker-v3 | 305.990| 56.886|1M | 149999| 224| -|td3 |BipedalWalkerHardcore-v3 | -98.116| 16.087|10M | 150000| 75| -|td3 |HalfCheetah-v3 | 9655.666| 969.916|1M | 150000| 150| -|td3 |HalfCheetahBulletEnv-v0 | 2821.641| 19.722|1M | 150000| 150| -|td3 |Hopper-v3 | 3606.390| 4.027|1M | 150000| 150| -|td3 |HopperBulletEnv-v0 | 2681.609| 27.806|1M | 149486| 150| -|td3 |Humanoid-v3 | 5566.687| 14.544|2M | 150000| 150| -|td3 |LunarLanderContinuous-v2 | 207.451| 67.562|300k | 149488| 337| -|td3 |MountainCarContinuous-v0 | 93.483| 0.075|300k | 149976| 2275| -|td3 |Pendulum-v0 | -151.855| 90.227|20k | 150000| 750| -|td3 |ReacherBulletEnv-v0 | 17.114| 9.750|300k | 150000| 1000| -|td3 |Swimmer-v3 | 359.127| 1.244|1M | 150000| 150| -|td3 |Walker2DBulletEnv-v0 | 2213.672| 230.558|1M | 149800| 152| -|td3 |Walker2d-v3 | 4717.823| 46.303|1M | 150000| 150| -|tqc |Ant-v3 | 3339.362| 1969.906|1M | 149583| 202| -|tqc |AntBulletEnv-v0 | 3456.717| 248.733|1M | 150000| 150| -|tqc |BipedalWalker-v3 | 329.808| 45.083|500k | 149682| 254| -|tqc |BipedalWalkerHardcore-v3 | 235.226| 110.569|2M | 149032| 131| -|tqc |FetchPickAndPlace-v1 | -9.331| 6.850|1M | 150000| 3000| -|tqc |FetchPush-v1 | -8.799| 5.438|1M | 150000| 3000| -|tqc |FetchReach-v1 | -1.659| 0.873|20k | 150000| 3000| -|tqc |FetchSlide-v1 | -29.210| 11.387|3M | 150000| 3000| -|tqc |HalfCheetah-v3 | 12089.939| 127.440|1M | 150000| 150| -|tqc |HalfCheetahBulletEnv-v0 | 3675.299| 17.681|1M | 150000| 150| -|tqc |Hopper-v3 | 3754.199| 8.276|1M | 150000| 150| -|tqc |HopperBulletEnv-v0 | 2662.373| 206.210|1M | 149881| 151| -|tqc |Humanoid-v3 | 7239.320| 1647.498|2M | 149508| 165| -|tqc |LunarLanderContinuous-v2 | 277.956| 25.466|500k | 149928| 706| -|tqc |MountainCarContinuous-v0 | 63.641| 
45.259|50k | 149796| 186| -|tqc |PandaPickAndPlace-v1 | -8.024| 6.674|1M | 150000| 3000| -|tqc |PandaPush-v1 | -6.405| 6.400|1M | 150000| 3000| -|tqc |PandaReach-v1 | -1.768| 0.858|20k | 150000| 3000| -|tqc |PandaSlide-v1 | -27.497| 9.868|3M | 150000| 3000| -|tqc |PandaStack-v1 | -96.915| 17.240|1M | 150000| 1500| -|tqc |Pendulum-v0 | -151.340| 87.893|20k | 150000| 750| -|tqc |ReacherBulletEnv-v0 | 18.255| 9.543|300k | 150000| 1000| -|tqc |Swimmer-v3 | 339.423| 1.486|1M | 150000| 150| -|tqc |Walker2DBulletEnv-v0 | 2508.934| 614.624|1M | 149572| 159| -|tqc |Walker2d-v3 | 4380.720| 500.489|1M | 149606| 152| -|tqc |parking-v0 | -6.762| 2.690|100k | 149983| 7528| -|trpo |Acrobot-v1 | -83.114| 18.648|100k | 149976| 1783| -|trpo |Ant-v3 | 4982.301| 663.761|1M | 149909| 153| -|trpo |AntBulletEnv-v0 | 2560.621| 52.064|2M | 150000| 150| -|trpo |BipedalWalker-v3 | 182.339| 145.570|1M | 148440| 148| -|trpo |CartPole-v1 | 500.000| 0.000|100k | 150000| 300| -|trpo |HalfCheetah-v3 | 1785.476| 68.672|1M | 150000| 150| -|trpo |HalfCheetahBulletEnv-v0 | 2758.752| 327.032|2M | 150000| 150| -|trpo |Hopper-v3 | 3618.386| 356.768|1M | 149575| 152| -|trpo |HopperBulletEnv-v0 | 2565.416| 410.298|1M | 149640| 154| -|trpo |LunarLander-v2 | 133.166| 112.173|200k | 149088| 230| -|trpo |LunarLanderContinuous-v2 | 262.387| 21.428|200k | 149925| 501| -|trpo |MountainCar-v0 | -107.278| 13.231|100k | 149974| 1398| -|trpo |MountainCarContinuous-v0 | 92.489| 0.355|50k | 149971| 1732| -|trpo |Pendulum-v0 | -174.631| 127.577|100k | 150000| 750| -|trpo |ReacherBulletEnv-v0 | 14.741| 11.559|300k | 150000| 1000| -|trpo |Swimmer-v3 | 365.663| 2.087|1M | 150000| 150| -|trpo |Walker2DBulletEnv-v0 | 1483.467| 823.468|2M | 149860| 197| -|trpo |Walker2d-v3 | 4933.148| 1452.538|1M | 149054| 163| +| algo | env_id |mean_reward|std_reward|n_timesteps|eval_timesteps|eval_episodes| +|--------|-----------------------------|----------:|---------:|-----------|-------------:|------------:| +|a2c |Acrobot-v1 | -83.353| 17.213|500k | 149979| 1778| +|a2c |AntBulletEnv-v0 | 2497.147| 37.359|2M | 150000| 150| +|a2c |AsteroidsNoFrameskip-v4 | 1286.550| 423.750|10M | 614138| 258| +|a2c |BeamRiderNoFrameskip-v4 | 2890.298| 1379.137|10M | 591104| 47| +|a2c |BipedalWalker-v3 | 299.754| 23.459|5M | 149287| 208| +|a2c |BipedalWalkerHardcore-v3 | 96.171| 122.943|200M | 149704| 113| +|a2c |BreakoutNoFrameskip-v4 | 279.793| 122.177|10M | 604115| 82| +|a2c |CartPole-v1 | 500.000| 0.000|500k | 150000| 300| +|a2c |EnduroNoFrameskip-v4 | 0.000| 0.000|10M | 599040| 45| +|a2c |HalfCheetah-v3 | 3041.174| 157.265|1M | 150000| 150| +|a2c |HalfCheetahBulletEnv-v0 | 2107.384| 36.008|2M | 150000| 150| +|a2c |Hopper-v3 | 733.454| 376.574|1M | 149987| 580| +|a2c |HopperBulletEnv-v0 | 815.355| 313.798|2M | 149541| 254| +|a2c |LunarLander-v2 | 155.751| 80.419|200k | 149443| 297| +|a2c |LunarLanderContinuous-v2 | 84.225| 145.906|5M | 149305| 256| +|a2c |MountainCar-v0 | -111.263| 24.087|1M | 149982| 1348| +|a2c |MountainCarContinuous-v0 | 91.166| 0.255|100k | 149923| 1659| +|a2c |Pendulum-v1 | -162.965| 103.210|1M | 150000| 750| +|a2c |PongNoFrameskip-v4 | 17.292| 3.214|10M | 594910| 65| +|a2c |QbertNoFrameskip-v4 | 3882.345| 1223.327|10M | 610670| 194| +|a2c |ReacherBulletEnv-v0 | 14.968| 10.978|2M | 150000| 1000| +|a2c |RoadRunnerNoFrameskip-v4 | 31671.512| 6364.085|10M | 606710| 172| +|a2c |SeaquestNoFrameskip-v4 | 1721.493| 105.339|10M | 599691| 67| +|a2c |SpaceInvadersNoFrameskip-v4 | 627.160| 201.974|10M | 604848| 162| +|a2c |Swimmer-v3 | 200.627| 2.544|1M | 150000| 
150| +|a2c |Walker2DBulletEnv-v0 | 858.209| 333.116|2M | 149156| 173| +|ars |Acrobot-v1 | -82.884| 23.825|500k | 149985| 1788| +|ars |Ant-v3 | 2333.773| 20.597|75M | 150000| 150| +|ars |CartPole-v1 | 500.000| 0.000|50k | 150000| 300| +|ars |HalfCheetah-v3 | 4815.192| 1340.752|12M | 150000| 150| +|ars |Hopper-v3 | 3343.919| 5.730|7M | 150000| 150| +|ars |LunarLanderContinuous-v2 | 167.959| 147.071|2M | 149883| 562| +|ars |MountainCar-v0 | -122.000| 33.456|500k | 149938| 1229| +|ars |MountainCarContinuous-v0 | 96.672| 0.784|500k | 149990| 621| +|ars |Pendulum-v1 | -212.540| 160.444|2M | 150000| 750| +|ars |Swimmer-v3 | 355.267| 12.796|2M | 150000| 150| +|ars |Walker2d-v3 | 2993.582| 166.289|75M | 149821| 152| +|ddpg |AntBulletEnv-v0 | 2399.147| 75.410|1M | 150000| 150| +|ddpg |BipedalWalker-v3 | 197.486| 141.580|1M | 149237| 227| +|ddpg |HalfCheetahBulletEnv-v0 | 2078.325| 208.379|1M | 150000| 150| +|ddpg |HopperBulletEnv-v0 | 1157.065| 448.695|1M | 149565| 346| +|ddpg |LunarLanderContinuous-v2 | 230.217| 92.372|300k | 149862| 556| +|ddpg |MountainCarContinuous-v0 | 93.512| 0.048|300k | 149965| 2260| +|ddpg |Pendulum-v1 | -152.099| 94.282|20k | 150000| 750| +|ddpg |ReacherBulletEnv-v0 | 15.582| 9.606|300k | 150000| 1000| +|ddpg |Walker2DBulletEnv-v0 | 1387.591| 736.955|1M | 149051| 208| +|dqn |Acrobot-v1 | -76.639| 11.752|100k | 149998| 1932| +|dqn |AsteroidsNoFrameskip-v4 | 782.687| 259.247|10M | 607962| 134| +|dqn |BeamRiderNoFrameskip-v4 | 4295.946| 1790.458|10M | 600832| 37| +|dqn |BreakoutNoFrameskip-v4 | 358.327| 61.981|10M | 601461| 55| +|dqn |CartPole-v1 | 500.000| 0.000|50k | 150000| 300| +|dqn |EnduroNoFrameskip-v4 | 830.929| 194.544|10M | 599040| 14| +|dqn |LunarLander-v2 | 154.382| 79.241|100k | 149373| 200| +|dqn |MountainCar-v0 | -100.849| 9.925|120k | 149962| 1487| +|dqn |PongNoFrameskip-v4 | 20.602| 0.613|10M | 598998| 88| +|dqn |QbertNoFrameskip-v4 | 9496.774| 5399.633|10M | 605844| 124| +|dqn |RoadRunnerNoFrameskip-v4 | 40396.350| 7069.131|10M | 603257| 137| +|dqn |SeaquestNoFrameskip-v4 | 2000.290| 606.644|10M | 599505| 69| +|dqn |SpaceInvadersNoFrameskip-v4 | 622.742| 201.564|10M | 604311| 155| +|ppo |Acrobot-v1 | -73.506| 18.201|1M | 149979| 2013| +|ppo |Ant-v3 | 1327.158| 451.577|1M | 149572| 175| +|ppo |AntBulletEnv-v0 | 2865.922| 56.468|2M | 150000| 150| +|ppo |AsteroidsNoFrameskip-v4 | 2156.174| 744.640|10M | 602092| 149| +|ppo |BeamRiderNoFrameskip-v4 | 3397.000| 1662.368|10M | 598926| 46| +|ppo |BipedalWalker-v3 | 287.939| 2.448|5M | 149589| 123| +|ppo |BipedalWalkerHardcore-v3 | 122.374| 117.605|100M | 148036| 105| +|ppo |BreakoutNoFrameskip-v4 | 398.033| 33.328|10M | 600418| 60| +|ppo |CartPole-v1 | 500.000| 0.000|100k | 150000| 300| +|ppo |EnduroNoFrameskip-v4 | 996.364| 176.090|10M | 572416| 11| +|ppo |HalfCheetah-v3 | 5819.099| 663.530|1M | 150000| 150| +|ppo |HalfCheetahBulletEnv-v0 | 2924.721| 64.465|2M | 150000| 150| +|ppo |Hopper-v3 | 2410.435| 10.026|1M | 150000| 150| +|ppo |HopperBulletEnv-v0 | 2575.054| 223.301|2M | 149094| 152| +|ppo |LunarLander-v2 | 242.119| 31.823|1M | 149636| 369| +|ppo |LunarLanderContinuous-v2 | 270.863| 32.072|1M | 149956| 526| +|ppo |MountainCar-v0 | -110.423| 19.473|1M | 149954| 1358| +|ppo |MountainCarContinuous-v0 | 88.343| 2.572|20k | 149983| 633| +|ppo |Pendulum-v1 | -172.225| 104.159|100k | 150000| 750| +|ppo |PongNoFrameskip-v4 | 20.989| 0.105|10M | 599902| 90| +|ppo |QbertNoFrameskip-v4 | 15627.108| 3313.538|10M | 600248| 83| +|ppo |ReacherBulletEnv-v0 | 17.091| 11.048|1M | 150000| 1000| +|ppo |RoadRunnerNoFrameskip-v4 
| 40680.645| 6675.058|10M | 605786| 155| +|ppo |SeaquestNoFrameskip-v4 | 1783.636| 34.096|10M | 598243| 66| +|ppo |SpaceInvadersNoFrameskip-v4 | 960.331| 425.355|10M | 603771| 136| +|ppo |Swimmer-v3 | 281.561| 9.671|1M | 150000| 150| +|ppo |Walker2DBulletEnv-v0 | 2109.992| 13.899|2M | 150000| 150| +|ppo |Walker2d-v3 | 3478.798| 821.708|1M | 149343| 171| +|ppo_lstm|CarRacing-v0 | 862.549| 97.342|4M | 149588| 156| +|ppo_lstm|CartPoleNoVel-v1 | 500.000| 0.000|100k | 150000| 300| +|ppo_lstm|MountainCarContinuousNoVel-v0| 91.469| 1.776|300k | 149882| 1340| +|ppo_lstm|PendulumNoVel-v1 | -217.933| 140.094|100k | 150000| 750| +|qrdqn |Acrobot-v1 | -69.135| 9.967|100k | 149949| 2138| +|qrdqn |AsteroidsNoFrameskip-v4 | 2185.303| 1097.172|10M | 599784| 66| +|qrdqn |BeamRiderNoFrameskip-v4 | 17122.941| 10769.997|10M | 596483| 17| +|qrdqn |BreakoutNoFrameskip-v4 | 393.600| 79.828|10M | 579711| 40| +|qrdqn |CartPole-v1 | 500.000| 0.000|50k | 150000| 300| +|qrdqn |EnduroNoFrameskip-v4 | 3231.200| 1311.801|10M | 585728| 5| +|qrdqn |LunarLander-v2 | 70.236| 225.491|100k | 149957| 522| +|qrdqn |MountainCar-v0 | -106.042| 15.536|120k | 149943| 1414| +|qrdqn |PongNoFrameskip-v4 | 20.492| 0.687|10M | 597443| 63| +|qrdqn |QbertNoFrameskip-v4 | 14799.728| 2917.629|10M | 600773| 92| +|qrdqn |RoadRunnerNoFrameskip-v4 | 42325.424| 8361.161|10M | 591016| 59| +|qrdqn |SeaquestNoFrameskip-v4 | 2557.576| 76.951|10M | 596275| 66| +|qrdqn |SpaceInvadersNoFrameskip-v4 | 1899.928| 823.488|10M | 597218| 69| +|sac |Ant-v3 | 4615.791| 1354.111|1M | 149074| 165| +|sac |AntBulletEnv-v0 | 3073.114| 175.148|1M | 150000| 150| +|sac |BipedalWalker-v3 | 297.668| 33.060|500k | 149530| 136| +|sac |BipedalWalkerHardcore-v3 | 4.423| 103.910|10M | 149794| 88| +|sac |HalfCheetah-v3 | 9535.451| 100.470|1M | 150000| 150| +|sac |HalfCheetahBulletEnv-v0 | 2792.170| 12.088|1M | 150000| 150| +|sac |Hopper-v3 | 2325.547| 1129.676|1M | 149841| 236| +|sac |HopperBulletEnv-v0 | 2603.494| 164.322|1M | 149724| 151| +|sac |Humanoid-v3 | 6232.287| 279.885|2M | 149460| 150| +|sac |LunarLanderContinuous-v2 | 260.390| 65.467|500k | 149634| 672| +|sac |MountainCarContinuous-v0 | 94.679| 1.134|50k | 149966| 1443| +|sac |Pendulum-v1 | -156.995| 88.714|20k | 150000| 750| +|sac |ReacherBulletEnv-v0 | 18.062| 9.729|300k | 150000| 1000| +|sac |Swimmer-v3 | 345.568| 3.084|1M | 150000| 150| +|sac |Walker2DBulletEnv-v0 | 2292.266| 13.970|1M | 149983| 150| +|sac |Walker2d-v3 | 3863.203| 254.347|1M | 149309| 150| +|td3 |Ant-v3 | 5813.274| 589.773|1M | 149393| 151| +|td3 |AntBulletEnv-v0 | 3300.026| 54.640|1M | 150000| 150| +|td3 |BipedalWalker-v3 | 305.990| 56.886|1M | 149999| 224| +|td3 |BipedalWalkerHardcore-v3 | -98.116| 16.087|10M | 150000| 75| +|td3 |HalfCheetah-v3 | 9655.666| 969.916|1M | 150000| 150| +|td3 |HalfCheetahBulletEnv-v0 | 2821.641| 19.722|1M | 150000| 150| +|td3 |Hopper-v3 | 3606.390| 4.027|1M | 150000| 150| +|td3 |HopperBulletEnv-v0 | 2681.609| 27.806|1M | 149486| 150| +|td3 |Humanoid-v3 | 5566.687| 14.544|2M | 150000| 150| +|td3 |LunarLanderContinuous-v2 | 207.451| 67.562|300k | 149488| 337| +|td3 |MountainCarContinuous-v0 | 93.483| 0.075|300k | 149976| 2275| +|td3 |Pendulum-v1 | -151.855| 90.227|20k | 150000| 750| +|td3 |ReacherBulletEnv-v0 | 17.114| 9.750|300k | 150000| 1000| +|td3 |Swimmer-v3 | 359.127| 1.244|1M | 150000| 150| +|td3 |Walker2DBulletEnv-v0 | 2213.672| 230.558|1M | 149800| 152| +|td3 |Walker2d-v3 | 4717.823| 46.303|1M | 150000| 150| +|tqc |Ant-v3 | 3339.362| 1969.906|1M | 149583| 202| +|tqc |AntBulletEnv-v0 | 3456.717| 248.733|1M 
| 150000| 150| +|tqc |BipedalWalker-v3 | 329.808| 45.083|500k | 149682| 254| +|tqc |BipedalWalkerHardcore-v3 | 235.226| 110.569|2M | 149032| 131| +|tqc |FetchPickAndPlace-v1 | -9.331| 6.850|1M | 150000| 3000| +|tqc |FetchPush-v1 | -8.799| 5.438|1M | 150000| 3000| +|tqc |FetchReach-v1 | -1.659| 0.873|20k | 150000| 3000| +|tqc |FetchSlide-v1 | -29.210| 11.387|3M | 150000| 3000| +|tqc |HalfCheetah-v3 | 12089.939| 127.440|1M | 150000| 150| +|tqc |HalfCheetahBulletEnv-v0 | 3675.299| 17.681|1M | 150000| 150| +|tqc |Hopper-v3 | 3754.199| 8.276|1M | 150000| 150| +|tqc |HopperBulletEnv-v0 | 2662.373| 206.210|1M | 149881| 151| +|tqc |Humanoid-v3 | 7239.320| 1647.498|2M | 149508| 165| +|tqc |LunarLanderContinuous-v2 | 277.956| 25.466|500k | 149928| 706| +|tqc |MountainCarContinuous-v0 | 63.641| 45.259|50k | 149796| 186| +|tqc |PandaPickAndPlace-v1 | -8.024| 6.674|1M | 150000| 3000| +|tqc |PandaPush-v1 | -6.405| 6.400|1M | 150000| 3000| +|tqc |PandaReach-v1 | -1.768| 0.858|20k | 150000| 3000| +|tqc |PandaSlide-v1 | -27.497| 9.868|3M | 150000| 3000| +|tqc |PandaStack-v1 | -96.915| 17.240|1M | 150000| 1500| +|tqc |Pendulum-v1 | -151.340| 87.893|20k | 150000| 750| +|tqc |ReacherBulletEnv-v0 | 18.255| 9.543|300k | 150000| 1000| +|tqc |Swimmer-v3 | 339.423| 1.486|1M | 150000| 150| +|tqc |Walker2DBulletEnv-v0 | 2508.934| 614.624|1M | 149572| 159| +|tqc |Walker2d-v3 | 4380.720| 500.489|1M | 149606| 152| +|tqc |parking-v0 | -6.762| 2.690|100k | 149983| 7528| +|trpo |Acrobot-v1 | -83.114| 18.648|100k | 149976| 1783| +|trpo |Ant-v3 | 4982.301| 663.761|1M | 149909| 153| +|trpo |AntBulletEnv-v0 | 2560.621| 52.064|2M | 150000| 150| +|trpo |BipedalWalker-v3 | 182.339| 145.570|1M | 148440| 148| +|trpo |CartPole-v1 | 500.000| 0.000|100k | 150000| 300| +|trpo |HalfCheetah-v3 | 1785.476| 68.672|1M | 150000| 150| +|trpo |HalfCheetahBulletEnv-v0 | 2758.752| 327.032|2M | 150000| 150| +|trpo |Hopper-v3 | 3618.386| 356.768|1M | 149575| 152| +|trpo |HopperBulletEnv-v0 | 2565.416| 410.298|1M | 149640| 154| +|trpo |LunarLander-v2 | 133.166| 112.173|200k | 149088| 230| +|trpo |LunarLanderContinuous-v2 | 262.387| 21.428|200k | 149925| 501| +|trpo |MountainCar-v0 | -107.278| 13.231|100k | 149974| 1398| +|trpo |MountainCarContinuous-v0 | 92.489| 0.355|50k | 149971| 1732| +|trpo |Pendulum-v1 | -174.631| 127.577|100k | 150000| 750| +|trpo |ReacherBulletEnv-v0 | 14.741| 11.559|300k | 150000| 1000| +|trpo |Swimmer-v3 | 365.663| 2.087|1M | 150000| 150| +|trpo |Walker2DBulletEnv-v0 | 1483.467| 823.468|2M | 149860| 197| +|trpo |Walker2d-v3 | 4933.148| 1452.538|1M | 149054| 163| diff --git a/logs/benchmark/ppo-BipedalWalker-v3/0.monitor.csv b/logs/benchmark/ppo-BipedalWalker-v3/0.monitor.csv index 65b44f3af..6a2bc3eca 100644 --- a/logs/benchmark/ppo-BipedalWalker-v3/0.monitor.csv +++ b/logs/benchmark/ppo-BipedalWalker-v3/0.monitor.csv @@ -1,235 +1,125 @@ -#{"t_start": 1614771346.3718872, "env_id": "BipedalWalker-v3"} +#{"t_start": 1654205553.3965414, "env_id": "BipedalWalker-v3"} r,l,t -302.352527,761,3.219014 -304.643791,747,3.818765 -118.656863,577,4.303263 -186.806605,791,4.924971 -302.92318,727,5.496392 -95.043446,526,5.913008 -304.071331,736,6.488454 -304.961145,742,7.067342 -305.404706,747,7.653983 -43.647549,413,7.977797 -305.786318,733,8.550616 --65.333118,166,8.681444 -300.446239,878,9.366114 -302.302039,756,9.955367 --63.398006,161,10.085125 -303.612975,709,10.647742 -301.85852,756,11.242622 -302.07473,754,11.829957 -302.562625,784,12.443302 -302.523438,773,13.046803 -304.049687,740,13.627153 -304.22878,791,14.251293 
-301.844681,798,14.872318 -304.990141,727,15.439051 -303.951432,747,16.025128 -301.880788,764,16.623663 -302.835429,744,17.20378 --63.166425,240,17.392857 --19.553167,255,17.598439 -71.883582,459,17.96465 -171.700376,662,18.479113 -303.257497,755,19.065179 --31.888369,226,19.241292 -301.43717,821,19.880159 -302.191556,772,20.484396 --45.01458,317,20.733858 -184.726106,752,21.320945 -302.56233,759,21.909227 -303.76186,792,22.526143 -301.663292,739,23.103 -303.368805,742,23.679868 -302.449036,748,24.260264 -303.924373,751,24.844364 -303.955462,754,25.4316 -305.325247,738,26.004931 -305.033799,736,26.575473 -182.856283,707,27.126774 -302.0432,770,27.727888 -304.139851,732,28.296654 -303.16995,764,28.890888 -301.428877,755,29.481789 -303.415154,745,30.066476 -306.591512,727,30.635611 -304.58744,762,31.230929 -303.367456,763,31.826358 -305.595314,753,32.413927 -4.373387,304,32.651849 -193.166831,746,33.231313 -125.378329,552,33.665584 -13.381573,340,33.931804 -96.788857,505,34.321924 -39.350058,450,34.66974 -305.335492,746,35.251793 --52.636007,202,35.409539 -304.796698,728,35.976646 -6.380472,301,36.210353 -304.459181,742,36.788108 -140.072415,754,37.373717 --12.509825,271,37.585901 -307.386125,708,38.136988 -303.447734,753,38.725262 -126.609499,627,39.212487 -302.364809,738,39.786229 -304.475175,725,40.354099 -302.940891,758,40.943596 -15.795655,325,41.197428 -305.404282,735,41.771184 -305.236132,723,42.330353 -305.94299,752,42.913521 -21.148425,353,43.189599 --52.330365,188,43.340062 -151.386882,659,43.851162 -90.424794,494,44.232565 -301.927897,770,44.826415 -304.757464,721,45.382777 -302.697478,760,45.96874 --22.39859,259,46.169842 -305.956405,723,46.73031 -305.664408,727,47.290006 -305.287772,717,47.841211 --22.082868,254,48.038846 -304.894636,746,48.616611 -303.331698,741,49.187315 -304.592178,747,49.760554 -305.563102,736,50.328383 -302.689926,746,50.902858 -300.186937,802,51.52006 --15.602549,276,51.733101 -94.133354,512,52.128198 -305.610114,726,52.68604 -298.369037,798,53.300444 -95.514182,510,53.69267 -301.644435,796,54.305062 -306.86572,713,54.854769 -304.306364,747,55.433634 -303.593145,771,56.037491 -39.926782,380,56.334806 --40.461412,217,56.50227 --71.978628,144,56.616035 -127.880317,642,57.109494 -96.525125,545,57.532723 -54.597902,435,57.869843 -57.3985,418,58.193319 -302.945615,740,58.763972 -73.387278,473,59.135621 -302.676462,753,59.71879 -304.187669,766,60.308922 -301.580132,810,60.937058 -300.510398,797,61.56038 -76.407432,559,62.000009 -304.560592,739,62.571082 -304.481442,771,63.166171 -303.622933,732,63.738984 -118.340882,581,64.186391 -304.182357,734,64.751831 -131.486683,644,65.245947 -302.13052,756,65.827673 -303.603819,753,66.405886 -304.842933,710,66.955831 -304.3627,772,67.550153 -106.366572,585,68.004118 -301.121862,819,68.633571 -303.438111,757,69.220238 -71.421583,675,69.735597 -303.017117,767,70.33188 -31.324711,356,70.612745 -301.624264,746,71.196275 -301.047988,801,71.811748 -303.461027,752,72.399489 -302.993697,748,72.975591 -304.899086,759,73.563025 -302.1779,776,74.158236 -305.593249,735,74.725231 -305.155352,741,75.293045 -302.828854,747,75.872387 --50.970516,283,76.090183 --39.507444,225,76.264404 -303.69021,749,76.843124 -300.996457,780,77.443655 -304.156677,746,78.018283 -303.679295,761,78.605159 --18.347728,291,78.831671 -156.79956,689,79.371623 -136.722458,660,79.888347 -64.885886,483,80.268309 -41.603007,411,80.592154 -304.288839,763,81.193667 -35.546111,373,81.487278 -304.284635,747,82.073154 --73.988576,149,82.193739 -142.886248,656,82.707909 
-302.361418,750,83.298816 -304.1507,729,83.875597 -304.413431,767,84.478215 -305.101579,743,85.058137 -199.662759,723,85.623957 -305.104542,726,86.195029 -301.071559,796,86.818468 -301.539925,798,87.450715 -301.531898,815,88.096411 -302.795926,761,88.697836 -303.416066,744,89.286814 -30.285208,399,89.600553 --40.258154,207,89.764034 -51.430234,396,90.076918 -40.241742,405,90.39917 -305.114193,733,90.973302 -306.063059,728,91.546085 --47.185539,228,91.725956 -300.573573,790,92.342886 --54.868415,200,92.501091 -303.220558,760,93.100096 -304.546298,752,93.692549 -82.932796,473,94.063656 -47.441718,427,94.402995 -303.839423,725,94.966027 -305.277608,763,95.567472 -23.496617,367,95.856493 -128.678039,569,96.299928 --32.592752,223,96.476208 --43.753721,211,96.642032 -305.424057,751,97.228528 -113.063822,584,97.688616 -303.492005,756,98.289925 -304.841375,758,98.880774 -304.577313,758,99.476772 -301.156213,802,100.103982 -304.010025,783,100.716021 -302.667361,740,101.299324 -303.918448,740,101.882527 -301.425668,778,102.492409 -305.446357,715,103.055928 -92.767958,589,103.520237 -301.549759,799,104.138262 -303.163771,739,104.712989 -300.678757,794,105.329364 --41.254046,225,105.503753 -303.030253,793,106.118409 -162.468708,678,106.652562 -51.832027,451,107.011397 -298.00122,862,107.677496 -303.813658,724,108.236765 -303.409989,750,108.822069 -302.429025,762,109.423768 -160.56586,682,109.950236 -303.36268,747,110.535994 -304.766964,733,111.114457 -303.314112,736,111.684765 -303.582416,751,112.266338 -302.967288,736,112.846985 --70.653028,183,112.992295 -304.122437,744,113.57416 -164.183535,692,114.111993 -303.447712,736,114.683947 -302.214696,756,115.280121 -158.744251,733,115.843334 -137.7137,724,116.403516 -297.088702,858,117.077439 -303.803565,749,117.659656 -307.109074,717,118.218738 -170.724354,709,118.780417 -303.059612,752,119.374358 --75.075435,127,119.475801 +286.498863,1232,3.148969 +290.03395,1193,3.942982 +285.871843,1236,4.766715 +291.850978,1198,5.564058 +288.768871,1210,6.368981 +288.157038,1212,7.176008 +289.361869,1202,7.97352 +287.842606,1225,8.786903 +287.280252,1190,9.578171 +284.324796,1247,10.404098 +288.956237,1213,11.208744 +288.669449,1191,11.998618 +285.979899,1247,12.825689 +289.639171,1190,13.618726 +286.947732,1244,14.446376 +287.606982,1188,15.235259 +286.782178,1223,16.048814 +288.687124,1216,16.856583 +285.721313,1232,17.67472 +287.000542,1216,18.483637 +288.605325,1213,19.286785 +291.318049,1207,20.087602 +289.566242,1193,20.877623 +292.464293,1179,21.661781 +290.672828,1192,22.457295 +290.689305,1182,23.244957 +287.682502,1206,24.045446 +291.957954,1163,24.819003 +282.784247,1263,25.659568 +287.454218,1216,26.470194 +285.451189,1228,27.283888 +287.095007,1272,28.128968 +287.733682,1228,28.945501 +290.772435,1195,29.742402 +288.365004,1236,30.56567 +285.35095,1225,31.380461 +289.413562,1208,32.184891 +288.015387,1211,32.991751 +288.731107,1205,33.791121 +290.727256,1189,34.581942 +289.43828,1207,35.3862 +287.729774,1227,36.204048 +285.253831,1260,37.041651 +289.768125,1185,37.826303 +285.958823,1233,38.646922 +287.200568,1222,39.46325 +290.043164,1185,40.250997 +289.101907,1211,41.05368 +286.462372,1243,41.87645 +285.64948,1221,42.688791 +287.519667,1187,43.476118 +284.971831,1267,44.316454 +288.742228,1214,45.120938 +285.628617,1255,45.954174 +284.757907,1250,46.783232 +284.347205,1245,47.609643 +289.301662,1211,48.409139 +284.519905,1255,49.241136 +283.713437,1242,50.06516 +289.766618,1189,50.851352 +282.590464,1273,51.692035 +289.451637,1175,52.468679 
+282.330197,1285,53.320945 +290.373129,1201,54.118149 +285.502483,1222,54.927324 +290.665951,1163,55.698771 +289.852728,1213,56.503432 +287.244561,1205,57.306046 +286.817512,1223,58.116854 +291.118836,1183,58.904925 +289.975692,1205,59.704293 +291.492401,1159,60.475756 +287.211862,1228,61.291739 +284.231949,1244,62.120625 +287.456086,1196,62.914544 +286.782568,1231,63.733352 +290.014788,1182,64.517536 +285.797936,1238,65.341084 +286.128281,1239,66.164153 +291.3156,1168,66.940698 +285.707421,1227,67.7551 +287.444993,1217,68.563281 +287.287142,1235,69.380708 +287.296313,1252,70.212887 +285.059782,1244,71.037529 +285.075845,1250,71.86698 +289.072542,1172,72.643633 +286.401297,1247,73.470394 +289.753727,1171,74.244725 +288.401924,1186,75.030774 +289.74776,1210,75.834194 +290.13916,1177,76.612051 +292.795935,1181,77.396483 +285.884796,1247,78.222464 +288.603669,1212,79.027307 +289.873579,1212,79.832399 +285.787775,1230,80.647568 +290.36039,1202,81.441948 +290.807051,1179,82.224544 +285.749472,1245,83.048675 +290.984798,1179,83.832279 +288.575334,1226,84.645881 +286.358977,1210,85.449784 +287.503688,1221,86.257498 +290.378601,1209,87.059999 +289.137484,1223,87.869361 +288.796009,1224,88.683386 +287.783809,1249,89.514697 +291.227006,1151,90.278997 +280.305489,1277,91.127263 +283.010979,1279,91.975935 +286.370462,1259,92.809959 +294.429234,1149,93.57178 +286.001682,1213,94.377335 +287.589849,1233,95.19614 +290.317775,1200,95.992744 +286.754871,1236,96.815568 +289.788937,1205,97.617576 +286.015599,1251,98.448357 +287.574138,1193,99.242593 +291.136559,1193,100.036958 +291.611669,1221,100.848697 +286.55367,1234,101.670848 diff --git a/logs/benchmark/ppo_lstm-CarRacing-v0/0.monitor.csv b/logs/benchmark/ppo_lstm-CarRacing-v0/0.monitor.csv new file mode 100644 index 000000000..147b6093c --- /dev/null +++ b/logs/benchmark/ppo_lstm-CarRacing-v0/0.monitor.csv @@ -0,0 +1,158 @@ +#{"t_start": 1654204830.0419693, "env_id": "CarRacing-v0"} +r,l,t +792.61745,1000,6.181207 +885.765125,1000,9.131669 +873.063973,1000,12.114203 +890.740741,1000,15.143978 +877.272727,1000,18.152218 +889.169675,1000,21.111168 +890.066225,1000,24.140795 +910.2,898,26.813054 +879.757085,1000,29.724176 +915.2,848,32.276649 +930.5,695,34.305526 +869.96997,1000,37.345273 +920.5,795,39.70204 +896.296296,1000,42.632593 +886.885246,1000,45.646431 +858.477509,1000,48.618872 +892.619926,1000,51.57379 +887.261146,1000,54.593082 +877.491961,1000,57.618382 +798.954704,1000,60.627211 +870.967742,1000,63.656299 +889.247312,1000,66.617474 +812.02346,1000,69.691475 +924.3,757,71.923429 +848.170732,1000,74.979877 +881.595092,1000,78.039151 +905.8,942,80.79825 +873.684211,1000,83.779894 +891.03139,1000,86.646081 +901.4,986,89.59265 +543.712575,1000,92.629697 +893.006993,1000,95.612338 +914.1,859,98.190614 +896.254682,1000,101.128834 +890.131579,1000,104.133495 +887.421384,1000,107.167422 +367.625899,1000,110.111629 +889.966555,1000,113.082222 +926.4,736,115.258553 +893.006993,1000,118.254707 +850.920245,1000,121.333169 +893.78882,1000,124.388279 +924.0,760,126.620384 +916.3,837,129.113128 +871.929825,1000,132.089347 +917.7,823,134.529402 +911.2,888,137.202349 +893.150685,1000,140.205283 +851.807229,1000,143.279442 +896.644295,1000,146.267079 +565.540541,1000,149.227145 +920.9,791,151.540307 +880.707395,1000,154.581006 +868.421053,1000,157.576324 +333.544304,1000,160.562149 +902.9,971,163.504401 +890.797546,1000,166.568459 +918.6,814,169.000762 +861.651917,1000,172.072073 +919.1,809,174.487601 +892.125984,1000,177.403417 +877.272727,1000,180.424361 
+888.847584,1000,183.378517 +896.551724,1000,186.378563 +911.4,886,189.030784 +896.296296,1000,191.977422 +850.769231,1000,195.032133 +870.873786,1000,198.056901 +880.327869,1000,201.057348 +730.508475,1000,204.041407 +893.485342,1000,207.073689 +875.460123,1000,210.124961 +918.5,815,212.544901 +856.375839,1000,215.530396 +875.0,1000,218.497411 +821.875,1000,221.545459 +868.75,1000,224.616827 +893.485342,1000,227.639646 +876.510067,1000,230.643212 +750.793651,1000,233.652184 +874.683544,1000,236.677366 +919.6,804,239.046962 +668.421053,1000,242.005007 +922.7,773,244.307434 +857.928803,1000,247.306081 +916.1,839,249.813316 +873.684211,1000,252.75659 +886.013986,1000,255.756162 +886.15917,1000,258.746746 +887.220447,1000,261.774839 +914.9,851,264.3091 +873.59736,1000,267.316416 +845.619335,1000,270.370811 +928.3,717,272.490756 +896.855346,1000,275.528629 +701.775148,1000,278.577352 +908.0,920,281.343297 +925.1,749,283.526727 +874.169742,1000,286.490014 +896.563574,1000,289.479618 +907.6,924,292.206284 +862.962963,1000,295.190832 +893.569132,1000,298.233784 +883.050847,1000,301.260211 +876.744186,1000,304.196508 +878.417266,1000,307.165412 +794.409938,1000,310.230357 +885.130112,1000,313.173896 +773.239437,1000,316.211473 +880.645161,1000,319.220745 +893.127148,1000,322.228811 +872.972973,1000,325.207085 +890.196078,1000,328.239602 +862.616822,1000,331.313493 +884.126984,1000,334.355265 +910.9,891,337.007747 +896.466431,1000,339.998146 +886.30137,1000,342.973861 +908.3,917,345.713885 +879.166667,1000,348.697348 +914.3,857,351.258304 +921.4,786,353.599769 +913.8,862,356.203853 +851.890034,1000,359.209628 +886.885246,1000,362.247912 +921.4,786,364.595427 +909.0,910,367.31875 +408.928571,1000,370.375511 +884.496124,1000,373.32294 +889.169675,1000,376.328308 +736.956522,1000,379.304595 +880.263158,1000,382.357421 +863.69637,1000,385.383473 +870.149254,1000,388.316006 +887.755102,1000,391.249082 +874.110032,1000,394.259789 +885.964912,1000,397.251099 +886.254296,1000,400.253898 +890.595611,1000,403.304099 +917.8,822,405.757406 +496.923077,1000,408.775885 +882.269504,1000,411.756959 +892.509363,1000,414.711423 +889.51049,1000,417.672504 +933.8,662,419.605968 +901.9,981,422.587127 +871.061093,1000,425.625754 +509.204196,867,428.206288 +883.443709,1000,431.226022 +896.551724,1000,434.219728 +893.355482,1000,437.238052 +896.527778,1000,440.214532 +908.9,911,442.970268 +896.632997,1000,445.964749 +915.1,849,448.522999 +880.314961,1000,451.533052 diff --git a/logs/benchmark/ppo_lstm-CartPoleNoVel-v1/0.monitor.csv b/logs/benchmark/ppo_lstm-CartPoleNoVel-v1/0.monitor.csv new file mode 100644 index 000000000..34af222c7 --- /dev/null +++ b/logs/benchmark/ppo_lstm-CartPoleNoVel-v1/0.monitor.csv @@ -0,0 +1,302 @@ +#{"t_start": 1654204464.909158, "env_id": "CartPoleNoVel-v1"} +r,l,t +500.0,500,2.692042 +500.0,500,3.061221 +500.0,500,3.429941 +500.0,500,3.798015 +500.0,500,4.16562 +500.0,500,4.535252 +500.0,500,4.903308 +500.0,500,5.270751 +500.0,500,5.639004 +500.0,500,6.007921 +500.0,500,6.376441 +500.0,500,6.744478 +500.0,500,7.112341 +500.0,500,7.482308 +500.0,500,7.849999 +500.0,500,8.217563 +500.0,500,8.585177 +500.0,500,8.953705 +500.0,500,9.321366 +500.0,500,9.688998 +500.0,500,10.056689 +500.0,500,10.425152 +500.0,500,10.792864 +500.0,500,11.160644 +500.0,500,11.528224 +500.0,500,11.897002 +500.0,500,12.264578 +500.0,500,12.632639 +500.0,500,13.000528 +500.0,500,13.369246 +500.0,500,13.737418 +500.0,500,14.104923 +500.0,500,14.472503 +500.0,500,14.840163 +500.0,500,15.209556 +500.0,500,15.577311 
+500.0,500,15.944919 +500.0,500,16.312535 +500.0,500,16.68169 +500.0,500,17.04989 +500.0,500,17.417994 +500.0,500,17.78551 +500.0,500,18.154947 +500.0,500,18.52279 +500.0,500,18.890535 +500.0,500,19.258383 +500.0,500,19.627402 +500.0,500,19.995366 +500.0,500,20.363095 +500.0,500,20.730902 +500.0,500,21.100272 +500.0,500,21.467908 +500.0,500,21.836105 +500.0,500,22.203988 +500.0,500,22.573255 +500.0,500,22.941118 +500.0,500,23.309029 +500.0,500,23.676801 +500.0,500,24.045863 +500.0,500,24.413585 +500.0,500,24.781154 +500.0,500,25.14879 +500.0,500,25.517311 +500.0,500,25.885126 +500.0,500,26.252733 +500.0,500,26.620603 +500.0,500,26.989465 +500.0,500,27.357436 +500.0,500,27.725199 +500.0,500,28.092894 +500.0,500,28.460764 +500.0,500,28.828977 +500.0,500,29.196474 +500.0,500,29.564007 +500.0,500,29.931772 +500.0,500,30.300713 +500.0,500,30.668463 +500.0,500,31.036354 +500.0,500,31.404369 +500.0,500,31.773325 +500.0,500,32.14113 +500.0,500,32.508817 +500.0,500,32.876434 +500.0,500,33.245192 +500.0,500,33.61313 +500.0,500,33.980896 +500.0,500,34.348483 +500.0,500,34.71754 +500.0,500,35.085228 +500.0,500,35.453714 +500.0,500,35.821541 +500.0,500,36.190614 +500.0,500,36.558201 +500.0,500,36.925709 +500.0,500,37.293244 +500.0,500,37.662188 +500.0,500,38.029895 +500.0,500,38.397722 +500.0,500,38.76532 +500.0,500,39.133835 +500.0,500,39.501667 +500.0,500,39.868861 +500.0,500,40.236549 +500.0,500,40.605014 +500.0,500,40.972768 +500.0,500,41.340191 +500.0,500,41.70803 +500.0,500,42.075676 +500.0,500,42.444287 +500.0,500,42.811715 +500.0,500,43.180297 +500.0,500,43.547704 +500.0,500,43.916625 +500.0,500,44.2841 +500.0,500,44.65163 +500.0,500,45.019142 +500.0,500,45.388199 +500.0,500,45.75553 +500.0,500,46.123094 +500.0,500,46.49043 +500.0,500,46.859583 +500.0,500,47.227015 +500.0,500,47.594541 +500.0,500,47.961887 +500.0,500,48.3308 +500.0,500,48.698397 +500.0,500,49.065922 +500.0,500,49.433188 +500.0,500,49.802063 +500.0,500,50.169557 +500.0,500,50.537384 +500.0,500,50.904823 +500.0,500,51.273586 +500.0,500,51.641172 +500.0,500,52.00872 +500.0,500,52.376147 +500.0,500,52.745009 +500.0,500,53.112735 +500.0,500,53.480421 +500.0,500,53.847858 +500.0,500,54.216092 +500.0,500,54.583584 +500.0,500,54.951037 +500.0,500,55.318777 +500.0,500,55.686959 +500.0,500,56.054983 +500.0,500,56.422587 +500.0,500,56.789983 +500.0,500,57.157527 +500.0,500,57.526113 +500.0,500,57.893817 +500.0,500,58.261189 +500.0,500,58.628896 +500.0,500,58.997658 +500.0,500,59.365289 +500.0,500,59.732727 +500.0,500,60.100073 +500.0,500,60.468939 +500.0,500,60.8367 +500.0,500,61.204306 +500.0,500,61.571862 +500.0,500,61.940647 +500.0,500,62.308141 +500.0,500,62.675627 +500.0,500,63.04311 +500.0,500,63.411911 +500.0,500,63.779265 +500.0,500,64.146678 +500.0,500,64.514226 +500.0,500,64.882792 +500.0,500,65.250464 +500.0,500,65.617751 +500.0,500,65.98526 +500.0,500,66.35398 +500.0,500,66.721609 +500.0,500,67.089066 +500.0,500,67.456589 +500.0,500,67.824868 +500.0,500,68.192576 +500.0,500,68.560127 +500.0,500,68.927803 +500.0,500,69.296072 +500.0,500,69.664061 +500.0,500,70.031497 +500.0,500,70.398919 +500.0,500,70.766376 +500.0,500,71.134893 +500.0,500,71.502915 +500.0,500,71.870472 +500.0,500,72.237988 +500.0,500,72.607162 +500.0,500,72.974757 +500.0,500,73.342227 +500.0,500,73.709692 +500.0,500,74.078555 +500.0,500,74.446039 +500.0,500,74.813464 +500.0,500,75.180903 +500.0,500,75.549648 +500.0,500,75.917047 +500.0,500,76.284521 +500.0,500,76.651951 +500.0,500,77.020492 +500.0,500,77.387824 +500.0,500,77.755272 +500.0,500,78.122754 
+500.0,500,78.491089 +500.0,500,78.858404 +500.0,500,79.225901 +500.0,500,79.593222 +500.0,500,79.961842 +500.0,500,80.329139 +500.0,500,80.696746 +500.0,500,81.064445 +500.0,500,81.432933 +500.0,500,81.800267 +500.0,500,82.16798 +500.0,500,82.535597 +500.0,500,82.904233 +500.0,500,83.272056 +500.0,500,83.6396 +500.0,500,84.007157 +500.0,500,84.37485 +500.0,500,84.743217 +500.0,500,85.110842 +500.0,500,85.478438 +500.0,500,85.845806 +500.0,500,86.214685 +500.0,500,86.582283 +500.0,500,86.949661 +500.0,500,87.317286 +500.0,500,87.686204 +500.0,500,88.053945 +500.0,500,88.421412 +500.0,500,88.788829 +500.0,500,89.157843 +500.0,500,89.52521 +500.0,500,89.892562 +500.0,500,90.260231 +500.0,500,90.629248 +500.0,500,90.997033 +500.0,500,91.364556 +500.0,500,91.73211 +500.0,500,92.100667 +500.0,500,92.468477 +500.0,500,92.836055 +500.0,500,93.203776 +500.0,500,93.572638 +500.0,500,93.940437 +500.0,500,94.307746 +500.0,500,94.675457 +500.0,500,95.044019 +500.0,500,95.412237 +500.0,500,95.779851 +500.0,500,96.147491 +500.0,500,96.516354 +500.0,500,96.884081 +500.0,500,97.251848 +500.0,500,97.619372 +500.0,500,97.987247 +500.0,500,98.355906 +500.0,500,98.723397 +500.0,500,99.091107 +500.0,500,99.458724 +500.0,500,99.827759 +500.0,500,100.195648 +500.0,500,100.563196 +500.0,500,100.930649 +500.0,500,101.299022 +500.0,500,101.6664 +500.0,500,102.034079 +500.0,500,102.401322 +500.0,500,102.770082 +500.0,500,103.137965 +500.0,500,103.505491 +500.0,500,103.872831 +500.0,500,104.24165 +500.0,500,104.609235 +500.0,500,104.976797 +500.0,500,105.344488 +500.0,500,105.713245 +500.0,500,106.080787 +500.0,500,106.448571 +500.0,500,106.815917 +500.0,500,107.185104 +500.0,500,107.552569 +500.0,500,107.919791 +500.0,500,108.287235 +500.0,500,108.655843 +500.0,500,109.023948 +500.0,500,109.391358 +500.0,500,109.758862 +500.0,500,110.127638 +500.0,500,110.495283 +500.0,500,110.862811 +500.0,500,111.23066 +500.0,500,111.598368 +500.0,500,111.966573 +500.0,500,112.334113 +500.0,500,112.701541 diff --git a/logs/benchmark/ppo_lstm-MountainCarContinuousNoVel-v0/0.monitor.csv b/logs/benchmark/ppo_lstm-MountainCarContinuousNoVel-v0/0.monitor.csv new file mode 100644 index 000000000..894fd1bf5 --- /dev/null +++ b/logs/benchmark/ppo_lstm-MountainCarContinuousNoVel-v0/0.monitor.csv @@ -0,0 +1,1342 @@ +#{"t_start": 1654204713.1905868, "env_id": "MountainCarContinuousNoVel-v0"} +r,l,t +88.870471,129,2.384638 +93.200782,96,2.452994 +93.1153,103,2.524679 +89.783776,122,2.609519 +89.084276,127,2.697699 +93.17103,103,2.769043 +90.272421,119,2.85192 +93.165781,103,2.923424 +88.882704,129,3.012742 +92.67766,106,3.086247 +88.822368,129,3.17575 +93.166806,96,3.242452 +91.849978,111,3.319841 +93.109522,103,3.391466 +93.15163,97,3.458845 +93.196885,103,3.530293 +91.575946,113,3.608723 +93.184759,98,3.676693 +88.82424,129,3.766255 +93.123139,98,3.834311 +93.250074,102,3.905042 +88.706054,130,3.997004 +90.768148,117,4.078424 +93.191438,103,4.149975 +93.216239,99,4.219114 +93.197325,102,4.290391 +93.191193,104,4.362637 +89.107044,127,4.450638 +92.474766,108,4.525669 +89.620468,123,4.610974 +89.775248,122,4.695605 +89.922732,121,4.779866 +92.583188,107,4.854172 +93.18132,98,4.922286 +90.477477,118,5.004337 +90.292669,119,5.086857 +93.172218,98,5.154969 +89.948811,121,5.239173 +92.363752,108,5.314215 +89.611613,123,5.399477 +89.917513,121,5.484584 +93.125296,97,5.551962 +88.623945,131,5.642749 +89.649255,123,5.728098 +93.184295,96,5.794691 +89.616299,123,5.88021 +93.133919,98,5.948249 +93.207498,99,6.016968 +92.764831,106,6.090605 
+93.144456,98,6.158695 +88.950705,128,6.247501 +89.123758,127,6.335794 +93.219048,98,6.40382 +89.781542,122,6.488462 +89.782177,122,6.573397 +89.777733,122,6.658286 +93.185252,103,6.73019 +93.182929,97,6.797729 +93.113943,96,6.864768 +89.533839,124,6.9512 +93.194117,99,7.020988 +89.177635,126,7.108326 +89.107466,127,7.196391 +93.085148,103,7.267827 +92.916511,105,7.340834 +88.98702,128,7.429554 +93.1925,102,7.50033 +89.899999,121,7.584371 +93.127303,98,7.652373 +93.208642,104,7.72451 +89.002738,128,7.813232 +93.189193,104,7.885385 +88.741387,130,7.97542 +89.93632,121,8.059277 +91.029743,115,8.138972 +92.21992,109,8.214536 +89.764298,122,8.299145 +93.113999,96,8.365746 +93.168898,103,8.437198 +93.169668,103,8.509602 +88.852565,129,8.599525 +93.19289,103,8.671314 +88.702688,130,8.761684 +90.096848,120,8.845084 +93.110618,103,8.916838 +93.025518,104,8.988879 +91.864208,111,9.065803 +93.194416,104,9.137887 +90.487003,118,9.219725 +92.309221,109,9.295581 +89.795107,122,9.380244 +89.095942,127,9.468391 +89.369037,125,9.55494 +90.125949,120,9.638249 +93.172738,103,9.709915 +93.145475,99,9.77876 +90.83815,116,9.859453 +93.21489,101,9.929499 +92.092596,110,10.007234 +89.644357,123,10.092857 +88.840014,129,10.182201 +89.791642,122,10.266649 +93.124155,103,10.33812 +93.176343,100,10.407616 +89.465783,124,10.493598 +89.934026,121,10.577511 +93.21387,104,10.649803 +89.529739,124,10.735824 +89.64373,123,10.821223 +93.175665,96,10.88802 +88.988982,128,10.976886 +92.601194,107,11.051019 +92.896729,105,11.123809 +93.18996,103,11.195364 +89.340725,125,11.282209 +93.219392,98,11.350248 +89.629466,123,11.435549 +92.396533,108,11.511843 +93.214856,102,11.582583 +93.228063,103,11.653995 +90.320025,119,11.736438 +90.09917,120,11.819752 +93.219587,102,11.890496 +89.350777,125,11.97716 +93.213126,103,12.048732 +93.169643,96,12.115289 +89.61667,123,12.200474 +93.107148,103,12.271923 +89.523964,124,12.357987 +88.620873,131,12.44872 +89.91964,121,12.532513 +88.721763,130,12.622581 +92.196586,109,12.698182 +88.636467,131,12.789118 +89.107866,127,12.877657 +93.201758,103,12.949092 +88.962481,128,13.03885 +89.350487,125,13.125727 +89.491968,124,13.211623 +92.608824,107,13.285851 +93.173034,103,13.357493 +93.159943,103,13.428843 +93.043704,104,13.501012 +89.452637,124,13.586959 +88.882335,129,13.676363 +88.848431,129,13.765942 +90.910626,116,13.846459 +93.225353,100,13.915782 +93.226799,101,13.985793 +93.180982,99,14.054471 +92.90149,105,14.127404 +89.75725,122,14.211995 +89.312771,125,14.298653 +89.477136,124,14.384778 +90.083796,120,14.468117 +89.922336,121,14.553396 +93.193835,104,14.625931 +92.894112,105,14.698986 +93.088068,103,14.770371 +89.601074,123,14.855972 +88.971711,128,14.944741 +93.201646,96,15.011277 +90.674297,117,15.092951 +93.153273,100,15.162291 +88.623838,131,15.253053 +93.190221,103,15.324453 +90.478143,118,15.406408 +93.207232,101,15.476441 +93.190314,103,15.547894 +89.612589,123,15.633731 +89.954368,121,15.717603 +89.129356,127,15.806335 +93.113853,96,15.872919 +93.207797,103,15.944339 +93.25996,103,16.015861 +92.066858,110,16.093505 +89.924772,121,16.177347 +93.158804,98,16.245262 +89.670423,123,16.330824 +89.930267,121,16.414818 +93.112147,103,16.486212 +93.162887,104,16.558339 +89.788241,122,16.643117 +89.057092,127,16.73131 +89.387116,125,16.817916 +90.110195,120,16.901234 +91.070569,115,16.980884 +89.111102,127,17.068882 +92.685835,106,17.142292 +93.19072,103,17.213887 +93.160701,97,17.2811 +89.642763,123,17.366416 +93.162671,103,17.438048 +89.523522,124,17.524099 +89.32233,125,17.611737 
+[... remaining records of this monitor file elided: "reward,episode length,time" triples, rewards between ~88.6 and ~93.3, episode lengths between 96 and 131 ...] diff --git a/logs/benchmark/ppo_lstm-PendulumNoVel-v1/0.monitor.csv
b/logs/benchmark/ppo_lstm-PendulumNoVel-v1/0.monitor.csv new file mode 100644 index 000000000..55fff6aaa --- /dev/null +++ b/logs/benchmark/ppo_lstm-PendulumNoVel-v1/0.monitor.csv @@ -0,0 +1,752 @@ +#{"t_start": 1654204588.0577092, "env_id": "PendulumNoVel-v1"} +r,l,t +-402.519199,200,2.480222 +-130.254367,200,2.630337 +[... remaining 748 episode records elided: every episode has length 200, rewards between roughly -590 and -0.6 ...]
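Aside: these `*.monitor.csv` files follow the standard SB3 `Monitor` format (one JSON metadata line, then `r,l,t` rows: episode reward, episode length, wall-clock time). They can be summarized with the same helpers that `utils/benchmark.py` imports further down; a minimal sketch, assuming the folder added in this diff:

```python
from stable_baselines3.common.results_plotter import load_results, ts2xy

# load_results() aggregates every *.monitor.csv in the folder
# (the leading JSON metadata line is skipped automatically);
# ts2xy() turns the resulting DataFrame into (x, y) arrays,
# here cumulative timesteps vs. per-episode rewards.
df = load_results("logs/benchmark/ppo_lstm-PendulumNoVel-v1")
x, y = ts2xy(df, "timesteps")
print(f"{len(y)} episodes, mean reward {y.mean():.2f} +/- {y.std():.2f}")
```

`ts2xy` also accepts `"episodes"` and `"walltime_hrs"` as the x-axis.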
+-395.870585,200,101.136974 +-377.701219,200,101.286847 +-260.948133,200,101.437155 +-0.816817,200,101.58697 +-262.958324,200,101.736865 +-134.742928,200,101.887607 +-454.942996,200,102.037928 +-130.919269,200,102.187853 +-264.041871,200,102.337847 +-130.423312,200,102.487939 +-127.226415,200,102.637887 +-522.461754,200,102.78776 +-1.148708,200,102.937798 +-6.536455,200,103.087711 +-257.6031,200,103.237552 +-128.720644,200,103.388964 +-131.406704,200,103.539749 +-129.394132,200,103.689775 +-132.589514,200,103.83981 +-124.933573,200,103.989935 +-478.735802,200,104.140508 +-3.082373,200,104.290401 +-131.676206,200,104.440384 +-534.309578,200,104.590317 +-135.326852,200,104.740203 +-132.253694,200,104.891677 +-255.012311,200,105.041959 +-255.012722,200,105.191854 +-258.503162,200,105.341737 +-0.895727,200,105.491628 +-0.899892,200,105.641625 +-1.728374,200,105.791913 +-514.48447,200,105.942069 +-399.104322,200,106.092014 +-533.905116,200,106.24193 +-124.653509,200,106.392929 +-1.529994,200,106.543303 +-565.123823,200,106.69336 +-382.928626,200,106.843282 +-393.138859,200,106.993197 +-133.039055,200,107.143152 +-1.254905,200,107.293069 +-253.218977,200,107.443086 +-377.926736,200,107.59297 +-384.992998,200,107.742795 +-132.20426,200,107.893014 +-2.265609,200,108.044569 +-129.277822,200,108.194533 +-516.770055,200,108.344465 +-123.844188,200,108.494859 +-488.233926,200,108.644754 +-259.32199,200,108.794864 +-417.74258,200,108.944967 +-126.061511,200,109.094862 +-130.831749,200,109.244875 +-255.263712,200,109.395019 +-501.52857,200,109.546746 +-1.092731,200,109.697082 +-124.415082,200,109.848345 +-398.725763,200,109.99847 +-265.103,200,110.148456 +-262.173707,200,110.298429 +-1.412234,200,110.448377 +-384.796566,200,110.598411 +-122.613973,200,110.748757 +-131.847815,200,110.898898 +-259.808992,200,111.0496 +-260.018786,200,111.199631 +-476.825245,200,111.349615 +-265.113998,200,111.49967 +-125.187484,200,111.649592 +-131.658743,200,111.799565 +-560.090144,200,111.949756 +-249.87696,200,112.09988 +-513.023542,200,112.249848 +-261.979488,200,112.399794 +-253.954332,200,112.55137 +-131.624941,200,112.701358 +-3.65569,200,112.851359 +-129.742158,200,113.001439 +-127.666264,200,113.151829 +-129.3067,200,113.301748 +-130.320689,200,113.452212 +-128.921845,200,113.602148 +-263.061197,200,113.752045 +-129.571114,200,113.902041 +-391.835508,200,114.053375 +-257.21494,200,114.203687 +-258.361078,200,114.353614 +-125.928314,200,114.503546 +-128.023682,200,114.654358 +-126.836506,200,114.804574 +-264.384871,200,114.955343 +-129.339465,200,115.105539 diff --git a/requirements.txt b/requirements.txt index dbd890eb1..4318fc8e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,16 +1,20 @@ gym==0.21 -stable-baselines3[extra,tests,docs]>=1.5.0 -sb3-contrib>=1.5.0 +stable-baselines3[extra,tests,docs]>=1.5.1a7 +sb3-contrib>=1.5.1a8 box2d-py==2.3.8 pybullet gym-minigrid scikit-optimize optuna -pytablewriter +pytablewriter~=0.64 seaborn pyyaml>=5.1 cloudpickle>=1.5.0 plotly +git+https://github.com/takuseno/d3rlpy panda-gym==1.1.1 # tmp fix: until compatibility with panda-gym v2 rliable>=1.0.5 wandb +ale-py==0.7.4 # tmp fix: until new SB3 version is released +# TODO: replace with release +git+https://github.com/huggingface/huggingface_sb3 diff --git a/scripts/all_plots.py b/scripts/all_plots.py index b456b9009..335b78829 100644 --- a/scripts/all_plots.py +++ b/scripts/all_plots.py @@ -205,7 +205,7 @@ # Markdown Table -writer = pytablewriter.MarkdownTableWriter() +writer = 
writer.table_name = "results_table" headers = ["Environments"] diff --git a/scripts/migrate_to_hub.py b/scripts/migrate_to_hub.py new file mode 100644 index 000000000..d2622a43a --- /dev/null +++ b/scripts/migrate_to_hub.py @@ -0,0 +1,23 @@ +import subprocess + +from utils.utils import get_hf_trained_models, get_trained_models + +folder = "rl-trained-agents" +orga = "sb3" +trained_models_local = get_trained_models(folder) +trained_models_hub = get_hf_trained_models(orga) +remaining_models = set(trained_models_local.keys()) - set(trained_models_hub.keys()) + +for trained_model in list(remaining_models): + algo, env_id = trained_models_local[trained_model] + args = ["-orga", orga, "-f", folder, "--algo", algo, "--env", env_id] + + # Since SB3 >= 1.1.0, HER is no longer an algorithm but a replay buffer class + if algo == "her": + continue + + # if model doesn't exist already + repo_name = f"{algo}-{env_id}" + repo_id = f"{orga}/{repo_name}" + + return_code = subprocess.call(["python", "-m", "utils.push_to_hub"] + args) diff --git a/scripts/plot_from_file.py b/scripts/plot_from_file.py index 4ec0043cf..e0f1c95e3 100644 --- a/scripts/plot_from_file.py +++ b/scripts/plot_from_file.py @@ -79,7 +79,7 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5 results = pickle.load(file_handler) # Plot table -writer = pytablewriter.MarkdownTableWriter() +writer = pytablewriter.MarkdownTableWriter(max_precision=3) writer.table_name = "results_table" writer.headers = results["results_table"]["headers"] writer.value_matrix = results["results_table"]["value_matrix"] @@ -193,7 +193,7 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5 warnings.warn(f"{env} not found for normalizing scores, you should update `env_key_to_env_id`") # Truncate to convert to matrix - min_runs = min([len(algo_score) for algo_score in algo_scores]) + min_runs = min(len(algo_score) for algo_score in algo_scores) if min_runs > 0: algo_scores = [algo_score[:min_runs] for algo_score in algo_scores] # shape: (n_envs, n_runs) -> (n_runs, n_envs) diff --git a/setup.cfg b/setup.cfg index c74b8320e..6841b0fcb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,6 +20,7 @@ per-file-ignores = ./scripts/all_plots.py:E501 ./scripts/plot_train.py:E501 ./scripts/plot_training_success.py:E501 + ./utils/teleop.py:F405 exclude = # No need to traverse our git directory diff --git a/tests/test_enjoy.py b/tests/test_enjoy.py index c55e702b2..194530556 100644 --- a/tests/test_enjoy.py +++ b/tests/test_enjoy.py @@ -3,7 +3,7 @@ import pytest -from utils import get_trained_models +from utils.utils import get_hf_trained_models, get_trained_models def _assert_eq(left, right): @@ -12,8 +12,10 @@ def _assert_eq(left, right): FOLDER = "rl-trained-agents/" N_STEPS = 100 - +# Use local models trained_models = get_trained_models(FOLDER) +# Use huggingface models too +trained_models.update(get_hf_trained_models()) @pytest.mark.parametrize("trained_model", trained_models.keys()) @@ -26,6 +28,10 @@ def test_trained_agents(trained_model): if algo == "her": return + # Skip car racing + if "CarRacing" in env_id: + return + # Skip mujoco envs if "Fetch" in trained_model or "-v3" in trained_model: return @@ -38,7 +44,7 @@ def test_trained_agents(trained_model): def test_benchmark(tmp_path): - args = ["-n", str(N_STEPS), "--benchmark-dir", tmp_path, "--test-mode"] + args = ["-n", str(N_STEPS), "--benchmark-dir", tmp_path, "--test-mode", "--no-hub"] return_code = subprocess.call(["python", "-m", "utils.benchmark"] + args) _assert_eq(return_code, 0)
diff --git a/tests/test_hyperparams_opt.py b/tests/test_hyperparams_opt.py index cd5444b70..89396a1f1 100644 --- a/tests/test_hyperparams_opt.py +++ b/tests/test_hyperparams_opt.py @@ -2,7 +2,9 @@ import os import subprocess +import optuna import pytest +from optuna.trial import TrialState def _assert_eq(left, right): @@ -120,3 +122,57 @@ def test_optimize_log_path(tmp_path): ] return_code = subprocess.call(["python", "scripts/parse_study.py"] + args) _assert_eq(return_code, 0) + + +def test_multiple_workers(tmp_path): + study_name = "test-study" + storage = f"sqlite:///{tmp_path}/optuna.db" + # n trials per worker + n_trials = 2 + # max total trials + max_trials = 3 + # 1st worker will do 2 trials + # 2nd worker will do 1 trial + # 3rd worker will do nothing + n_workers = 3 + args = [ + "-optimize", + "--no-optim-plots", + "--storage", + storage, + "--n-trials", + str(n_trials), + "--max-total-trials", + str(max_trials), + "--study-name", + study_name, + "--n-evaluations", + str(1), + "-n", + str(100), + "--algo", + "a2c", + "--env", + "Pendulum-v1", + "--log-folder", + tmp_path, + "-params", + "n_envs:1", + "--seed", + "12", + ] + + # Sequential execution to avoid race conditions + workers = [] + for _ in range(n_workers): + worker = subprocess.Popen( + ["python", "train.py"] + args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True + ) + worker.wait() + workers.append(worker) + + study = optuna.load_study(study_name=study_name, storage=storage) + assert len(study.get_trials(states=(TrialState.COMPLETE, TrialState.PRUNED))) == max_trials + + for worker in workers: + assert worker.returncode == 0, "STDOUT:\n{}\nSTDERR:\n{}\n".format(*worker.communicate()) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 5ec00c0f4..bfe0d7a1b 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -1,7 +1,9 @@ import gym import pybullet_envs # noqa: F401 import pytest +from stable_baselines3 import A2C from stable_baselines3.common.env_checker import check_env +from stable_baselines3.common.env_util import DummyVecEnv from utils.utils import get_wrapper_class from utils.wrappers import ActionNoiseWrapper, DelayedRewardWrapper, HistoryWrapper, TimeFeatureWrapper @@ -31,3 +33,20 @@ def test_get_wrapper(env_wrapper): if env_wrapper is not None: env = wrapper_class(env) check_env(env) + + +@pytest.mark.parametrize( + "vec_env_wrapper", + [ + None, + {"stable_baselines3.common.vec_env.VecFrameStack": dict(n_stack=2)}, + [{"stable_baselines3.common.vec_env.VecFrameStack": dict(n_stack=3)}, "stable_baselines3.common.vec_env.VecMonitor"], + ], +) +def test_get_vec_env_wrapper(vec_env_wrapper): + env = DummyVecEnv([lambda: gym.make("AntBulletEnv-v0")]) + hyperparams = {"vec_env_wrapper": vec_env_wrapper} + wrapper_class = get_wrapper_class(hyperparams, "vec_env_wrapper") + if wrapper_class is not None: + env = wrapper_class(env) + A2C("MlpPolicy", env).learn(16) diff --git a/train.py b/train.py index bbf85795e..1577297da 100644 --- a/train.py +++ b/train.py @@ -11,10 +11,28 @@ import torch as th from stable_baselines3.common.utils import set_random_seed +try: + from d3rlpy.algos import AWAC, AWR, BC, BCQ, BEAR, CQL, CRR, TD3PlusBC + from d3rlpy.models.encoders import VectorEncoderFactory + from d3rlpy.wrappers.sb3 import SB3Wrapper, to_mdp_dataset + + offline_algos = dict( + awr=AWR, + awac=AWAC, + bc=BC, + bcq=BCQ, + bear=BEAR, + cql=CQL, + crr=CRR, + td3bc=TD3PlusBC, + )
+except ImportError: + offline_algos = {} + # Register custom envs import utils.import_envs # noqa: F401 pytype: disable=import-error from utils.exp_manager import ExperimentManager -from utils.utils import ALGOS, StoreDict +from utils.utils import ALGOS, StoreDict, evaluate_policy_add_to_buffer seaborn.set() @@ -64,6 +82,13 @@ type=int, default=500, ) + parser.add_argument( + "--max-total-trials", + help="Number of (potentially pruned) trials for optimizing hyperparameters. " + "This applies to the entire optimization process and takes precedence over --n-trials if set.", + type=int, + default=None, + ) parser.add_argument( "-optimize", "--optimize-hyperparameters", action="store_true", default=False, help="Run hyperparameters search" ) @@ -117,6 +142,17 @@ help="Overwrite hyperparameter (e.g. learning_rate:0.01 train_freq:10)", ) parser.add_argument("-uuid", "--uuid", action="store_true", default=False, help="Ensure that the run has a unique ID") + parser.add_argument( + "--offline-algo", help="Offline RL Algorithm", type=str, required=False, choices=list(offline_algos.keys()) + ) + parser.add_argument("-b", "--pretrain-buffer", help="Path to a saved replay buffer for pretraining", type=str) + parser.add_argument( + "--pretrain-params", + type=str, + nargs="+", + action=StoreDict, + help="Optional arguments for pretraining with replay buffer", + ) parser.add_argument( "--track", action="store_true", @@ -201,6 +237,7 @@ args.storage, args.study_name, args.n_trials, + args.max_total_trials, args.n_jobs, args.sampler, args.pruner, @@ -228,9 +265,79 @@ args.saved_hyperparams = saved_hyperparams run.config.setdefaults(vars(args)) - # Normal training - if model is not None: - exp_manager.learn(model) - exp_manager.save_trained_model(model) + if args.pretrain_buffer is not None and model is not None: + model.load_replay_buffer(args.pretrain_buffer) + print(f"Buffer size = {model.replay_buffer.buffer_size}") + # Artificially reduce buffer size + # model.replay_buffer.full = False + # model.replay_buffer.pos = 5000 + + print(f"{model.replay_buffer.size()} transitions in the replay buffer") + + n_iterations = args.pretrain_params.get("n_iterations", 10) + n_epochs = args.pretrain_params.get("n_epochs", 1) + q_func_factory = args.pretrain_params.get("q_func_factory") + batch_size = args.pretrain_params.get("batch_size", 512) + # n_action_samples = args.pretrain_params.get("n_action_samples", 1) + n_eval_episodes = args.pretrain_params.get("n_eval_episodes", 5) + add_to_buffer = args.pretrain_params.get("add_to_buffer", False) + deterministic = args.pretrain_params.get("deterministic", True) + net_arch = args.pretrain_params.get("net_arch", [256, 256]) + scaler = args.pretrain_params.get("scaler", "standard") + encoder_factory = VectorEncoderFactory(hidden_units=net_arch) + for arg_name in { + "n_iterations", + "n_epochs", + "q_func_factory", + "batch_size", + "n_eval_episodes", + "add_to_buffer", + "deterministic", + "net_arch", + "scaler", + }: + if arg_name in args.pretrain_params: + del args.pretrain_params[arg_name] + try: + # d3rlpy must be installed and an offline algo selected + assert args.offline_algo is not None and len(offline_algos) > 0 + kwargs_ = {} if q_func_factory is None else dict(q_func_factory=q_func_factory) + kwargs_.update(dict(encoder_factory=encoder_factory)) + kwargs_.update(args.pretrain_params) + + offline_model = offline_algos[args.offline_algo]( + batch_size=batch_size, + **kwargs_, + ) + offline_model = SB3Wrapper(offline_model) + offline_model.use_sde = False + # break the logger...
+ # offline_model.replay_buffer = model.replay_buffer + + for i in range(n_iterations): + dataset = to_mdp_dataset(model.replay_buffer) + offline_model.fit(dataset.episodes, n_epochs=n_epochs, save_metrics=False) + + mean_reward, std_reward = evaluate_policy_add_to_buffer( + offline_model, + model.get_env(), + n_eval_episodes=n_eval_episodes, + replay_buffer=model.replay_buffer if add_to_buffer else None, + deterministic=deterministic, + ) + print(f"Iteration {i + 1} training, mean_reward={mean_reward:.2f} +/- {std_reward:.2f}") + except KeyboardInterrupt: + pass + finally: + print(f"Saving offline model to {exp_manager.save_path}/policy.pt") + offline_model.save_policy(f"{exp_manager.save_path}/policy.pt") + # print("Starting training") + # TODO: convert d3rlpy weights to SB3 + model.env.close() + exit() + + # Normal training + if model is not None: + exp_manager.learn(model) + exp_manager.save_trained_model(model) else: exp_manager.hyperparameters_optimization() diff --git a/utils/benchmark.py b/utils/benchmark.py index 9c7e8d597..b53430fb2 100644 --- a/utils/benchmark.py +++ b/utils/benchmark.py @@ -9,7 +9,7 @@ import pytablewriter from stable_baselines3.common.results_plotter import load_results, ts2xy -from utils.utils import get_latest_run_id, get_saved_hyperparams, get_trained_models +from utils.utils import get_hf_trained_models, get_latest_run_id, get_saved_hyperparams, get_trained_models parser = argparse.ArgumentParser() parser.add_argument("--log-dir", help="Root log folder", default="rl-trained-agents/", type=str) @@ -20,10 +20,15 @@ parser.add_argument("--seed", help="Random generator seed", type=int, default=0) parser.add_argument("--test-mode", action="store_true", default=False, help="Do only one experiment (useful for testing)") parser.add_argument("--with-mujoco", action="store_true", default=False, help="Run also MuJoCo envs") +parser.add_argument("--no-hub", action="store_true", default=False, help="Do not download models from hub") parser.add_argument("--num-threads", help="Number of threads for PyTorch", default=2, type=int) args = parser.parse_args() trained_models = get_trained_models(args.log_dir) + +if not args.no_hub: + trained_models.update(get_hf_trained_models()) + n_experiments = len(trained_models) results = { "algo": [], @@ -133,8 +138,12 @@ results_df = pd.DataFrame(results) # Sort results results_df = results_df.sort_values(by=["algo", "env_id"]) +# Create links to Huggingface hub +# links = [f"[{env_id}](https://huggingface.co/sb3/{algo}-{env_id})" +# for algo, env_id in zip(results_df["algo"], results_df["env_id"])] +# results_df["env_id"] = links -writer = pytablewriter.MarkdownTableWriter() +writer = pytablewriter.MarkdownTableWriter(max_precision=3) writer.from_dataframe(results_df) header = """ @@ -147,6 +156,9 @@ It uses the deterministic policy except for Atari games. +You can view each model card (it includes video and hyperparameters) +on our Huggingface page: https://huggingface.co/sb3 + *NOTE: this is not a quantitative benchmark as it corresponds to only one run (cf [issue #38](https://github.com/araffin/rl-baselines-zoo/issues/38)).
This benchmark is meant to check algorithm (maximal) performance, find potential bugs diff --git a/utils/callbacks.py b/utils/callbacks.py index 608c5a778..27d485104 100644 --- a/utils/callbacks.py +++ b/utils/callbacks.py @@ -1,4 +1,5 @@ import os +import pickle import tempfile import time from copy import deepcopy @@ -11,7 +12,7 @@ from stable_baselines3 import SAC from stable_baselines3.common.callbacks import BaseCallback, EvalCallback from stable_baselines3.common.logger import TensorBoardOutputFormat -from stable_baselines3.common.vec_env import VecEnv +from stable_baselines3.common.vec_env import VecEnv, sync_envs_normalization class TrialEvalCallback(EvalCallback): @@ -31,7 +32,7 @@ def __init__( log_path: Optional[str] = None, ): - super(TrialEvalCallback, self).__init__( + super().__init__( eval_env=eval_env, n_eval_episodes=n_eval_episodes, eval_freq=eval_freq, @@ -46,7 +47,7 @@ def __init__( def _on_step(self) -> bool: if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0: - super(TrialEvalCallback, self)._on_step() + super()._on_step() self.eval_idx += 1 # report best or report current ? # report num_timesteps or elapsed time ? @@ -69,7 +70,7 @@ class SaveVecNormalizeCallback(BaseCallback): """ def __init__(self, save_freq: int, save_path: str, name_prefix: Optional[str] = None, verbose: int = 0): - super(SaveVecNormalizeCallback, self).__init__(verbose) + super().__init__(verbose) self.save_freq = save_freq self.save_path = save_path self.name_prefix = name_prefix @@ -111,7 +112,7 @@ class ParallelTrainCallback(BaseCallback): """ def __init__(self, gradient_steps: int = 100, verbose: int = 0, sleep_time: float = 0.0): - super(ParallelTrainCallback, self).__init__(verbose) + super().__init__(verbose) self.batch_size = 0 self._model_ready = True self._model = None @@ -130,6 +131,12 @@ def _init_callback(self) -> None: self.model.save(temp_file) + if self.model.get_vec_normalize_env() is not None: + temp_file_norm = os.path.join("logs", "vec_normalize.pkl") + + with open(temp_file_norm, "wb") as file_handler: + pickle.dump(self.model.get_vec_normalize_env(), file_handler) + # TODO: add support for other algorithms for model_class in [SAC, TQC]: if isinstance(self.model, model_class): @@ -139,6 +146,11 @@ def _init_callback(self) -> None: assert self.model_class is not None, f"{self.model} is not supported for parallel training" self._model = self.model_class.load(temp_file) + if self.model.get_vec_normalize_env() is not None: + with open(temp_file_norm, "rb") as file_handler: + self._model._vec_normalize_env = pickle.load(file_handler) + self._model._vec_normalize_env.training = False + self.batch_size = self._model.batch_size # Disable train method @@ -183,6 +195,10 @@ def _on_rollout_end(self) -> None: self._model.replay_buffer = deepcopy(self.model.replay_buffer) self.model.set_parameters(deepcopy(self._model.get_parameters())) self.model.actor = self.model.policy.actor + # Sync VecNormalize + if self.model.get_vec_normalize_env() is not None: + sync_envs_normalization(self.model.get_vec_normalize_env(), self._model._vec_normalize_env) + if self.num_timesteps >= self._model.learning_starts: self.train() # Do not wait for the training loop to finish @@ -202,7 +218,7 @@ class RawStatisticsCallback(BaseCallback): """ def __init__(self, verbose=0): - super(RawStatisticsCallback, self).__init__(verbose) + super().__init__(verbose) # Custom counter to report stats # (and avoid reporting multiple values for the same step) self._timesteps_counter = 0 @@ -227,3 +243,25 @@ def
_on_step(self) -> bool: self._tensorboard_writer.write(logger_dict, exclude_dict, self._timesteps_counter) return True + + +class LapTimeCallback(BaseCallback): + def _on_training_start(self): + self.n_laps = 0 + output_formats = self.logger.output_formats + # Save reference to tensorboard formatter object + # note: the failure case (no formatter found) is not handled here; it should be done with a try/except. + self.tb_formatter = next(formatter for formatter in output_formats if isinstance(formatter, TensorBoardOutputFormat)) + + def _on_step(self) -> bool: + lap_count = self.locals["infos"][0]["lap_count"] + lap_time = self.locals["infos"][0]["last_lap_time"] + + if lap_count != self.n_laps and lap_time > 0: + self.n_laps = lap_count + self.tb_formatter.writer.add_scalar("time/lap_time", lap_time, self.num_timesteps) + if lap_count == 1: + self.tb_formatter.writer.add_scalar("time/first_lap_time", lap_time, self.num_timesteps) + else: + self.tb_formatter.writer.add_scalar("time/second_lap_time", lap_time, self.num_timesteps) + self.tb_formatter.writer.flush() diff --git a/utils/exp_manager.py b/utils/exp_manager.py index 1b6725831..3970fe58b 100644 --- a/utils/exp_manager.py +++ b/utils/exp_manager.py @@ -12,9 +12,12 @@ import optuna import torch as th import yaml +import zmq from optuna.integration.skopt import SkoptSampler -from optuna.pruners import BasePruner, MedianPruner, SuccessiveHalvingPruner +from optuna.pruners import BasePruner, MedianPruner, NopPruner, SuccessiveHalvingPruner from optuna.samplers import BaseSampler, RandomSampler, TPESampler +from optuna.study import MaxTrialsCallback +from optuna.trial import TrialState from optuna.visualization import plot_optimization_history, plot_param_importances from sb3_contrib.common.vec_env import AsyncEval @@ -47,7 +50,7 @@ from utils.utils import ALGOS, get_callback_list, get_latest_run_id, get_wrapper_class, linear_schedule -class ExperimentManager(object): +class ExperimentManager: """ Experiment manager: read the hyperparameters, preprocess them, create the environment and the RL model. 
@@ -73,6 +76,7 @@ def __init__( storage: Optional[str] = None, study_name: Optional[str] = None, n_trials: int = 1, + max_total_trials: Optional[int] = None, n_jobs: int = 1, sampler: str = "tpe", pruner: str = "median", @@ -90,7 +94,7 @@ def __init__( no_optim_plots: bool = False, device: Union[th.device, str] = "auto", ): - super(ExperimentManager, self).__init__() + super().__init__() self.algo = algo self.env_id = env_id # Custom params @@ -103,8 +107,9 @@ def __init__( self.frame_stack = None self.seed = seed self.optimization_log_path = optimization_log_path self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type] + self.vec_env_wrapper = None self.vec_env_kwargs = {} # self.vec_env_kwargs = {} if vec_env_type == "dummy" else {"start_method": "fork"} @@ -133,6 +139,7 @@ def __init__( self.no_optim_plots = no_optim_plots # maximum number of trials for finding the best hyperparams self.n_trials = n_trials + self.max_total_trials = max_total_trials # number of parallel jobs when doing hyperparameter search self.n_jobs = n_jobs self.sampler = sampler @@ -164,7 +171,7 @@ def setup_experiment(self) -> Optional[Tuple[BaseAlgorithm, Dict[str, Any]]]: :return: the initialized RL model """ hyperparams, saved_hyperparams = self.read_hyperparameters() - hyperparams, self.env_wrapper, self.callbacks = self._preprocess_hyperparams(hyperparams) + hyperparams, self.env_wrapper, self.callbacks, self.vec_env_wrapper = self._preprocess_hyperparams(hyperparams) self.create_log_folder() self.create_callbacks() @@ -212,7 +219,7 @@ def learn(self, model: BaseAlgorithm) -> None: try: model.learn(self.n_timesteps, **kwargs) - except KeyboardInterrupt: + except (KeyboardInterrupt, zmq.error.ZMQError): # this allows saving the model when interrupting training pass finally: @@ -260,7 +267,7 @@ def _save_config(self, saved_hyperparams: Dict[str, Any]) -> None: def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: # Load hyperparameters from yaml file - with open(f"hyperparams/{self.algo}.yml", "r") as f: + with open(f"hyperparams/{self.algo}.yml") as f: hyperparams_dict = yaml.safe_load(f) if self.env_id in list(hyperparams_dict.keys()): hyperparams = hyperparams_dict[self.env_id] @@ -322,7 +329,7 @@ def _preprocess_normalization(self, hyperparams: Dict[str, Any]) -> Dict[str, Any]: def _preprocess_hyperparams( self, hyperparams: Dict[str, Any] - ) -> Tuple[Dict[str, Any], Optional[Callable], List[BaseCallback]]: + ) -> Tuple[Dict[str, Any], Optional[Callable], List[BaseCallback], Optional[Callable]]: self.n_envs = hyperparams.get("n_envs", 1) if self.verbose > 0: @@ -344,7 +351,7 @@ def _preprocess_hyperparams( # Derive n_evaluations from number of timesteps if needed if self.n_evaluations is None and self.optimize_hyperparameters: - self.n_evaluations = self.n_timesteps // int(1e5) + self.n_evaluations = max(1, self.n_timesteps // int(1e5)) print( f"Doing {self.n_evaluations} intermediate evaluations for pruning based on the number of timesteps." 
" (1 evaluation every 100k timesteps)" @@ -374,12 +381,17 @@ def _preprocess_hyperparams( if "env_wrapper" in hyperparams.keys(): del hyperparams["env_wrapper"] + # Same for VecEnvWrapper + vec_env_wrapper = get_wrapper_class(hyperparams, "vec_env_wrapper") + if "vec_env_wrapper" in hyperparams.keys(): + del hyperparams["vec_env_wrapper"] + callbacks = get_callback_list(hyperparams) if "callback" in hyperparams.keys(): self.specified_callbacks = hyperparams["callback"] del hyperparams["callback"] - return hyperparams, env_wrapper, callbacks + return hyperparams, env_wrapper, callbacks, vec_env_wrapper def _preprocess_action_noise( self, hyperparams: Dict[str, Any], saved_hyperparams: Dict[str, Any], env: VecEnv @@ -453,17 +465,17 @@ def create_callbacks(self): @staticmethod def is_atari(env_id: str) -> bool: - entry_point = gym.envs.registry.env_specs[env_id].entry_point + entry_point = gym.envs.registry.env_specs[env_id].entry_point # pytype: disable=module-attr return "AtariEnv" in str(entry_point) @staticmethod def is_bullet(env_id: str) -> bool: - entry_point = gym.envs.registry.env_specs[env_id].entry_point + entry_point = gym.envs.registry.env_specs[env_id].entry_point # pytype: disable=module-attr return "pybullet_envs" in str(entry_point) @staticmethod def is_robotics_env(env_id: str) -> bool: - entry_point = gym.envs.registry.env_specs[env_id].entry_point + entry_point = gym.envs.registry.env_specs[env_id].entry_point # pytype: disable=module-attr return "gym.envs.robotics" in str(entry_point) or "panda_gym.envs" in str(entry_point) def _maybe_normalize(self, env: VecEnv, eval_env: bool) -> VecEnv: @@ -537,6 +549,9 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False) monitor_kwargs=monitor_kwargs, ) + if self.vec_env_wrapper is not None: + env = self.vec_env_wrapper(env) + # Wrap the env into a VecNormalize wrapper if needed # and load saved statistics when present env = self._maybe_normalize(env, eval_env) @@ -619,7 +634,7 @@ def _create_pruner(self, pruner_method: str) -> BasePruner: pruner = MedianPruner(n_startup_trials=self.n_startup_trials, n_warmup_steps=self.n_evaluations // 3) elif pruner_method == "none": # Do not prune - pruner = MedianPruner(n_startup_trials=self.n_trials, n_warmup_steps=self.n_evaluations) + pruner = NopPruner() else: raise ValueError(f"Unknown pruner: {pruner_method}") return pruner @@ -747,7 +762,28 @@ def hyperparameters_optimization(self) -> None: ) try: - study.optimize(self.objective, n_trials=self.n_trials, n_jobs=self.n_jobs) + if self.max_total_trials is not None: + # Note: we count already running trials here otherwise we get + # (max_total_trials + number of workers) trials in total. 
+ counted_states = [ + TrialState.COMPLETE, + TrialState.RUNNING, + TrialState.PRUNED, + ] + completed_trials = len(study.get_trials(states=counted_states)) + if completed_trials < self.max_total_trials: + study.optimize( + self.objective, + n_jobs=self.n_jobs, + callbacks=[ + MaxTrialsCallback( + self.max_total_trials, + states=counted_states, + ) + ], + ) + else: + study.optimize(self.objective, n_jobs=self.n_jobs, n_trials=self.n_trials) except KeyboardInterrupt: pass diff --git a/utils/import_envs.py b/utils/import_envs.py index fbe0370e3..1dddf2880 100644 --- a/utils/import_envs.py +++ b/utils/import_envs.py @@ -1,3 +1,8 @@ +import gym +from gym.envs.registration import register + +from utils.wrappers import MaskVelocityWrapper + try: import pybullet_envs # pytype: disable=import-error except ImportError: @@ -28,7 +33,35 @@ except ImportError: gym_donkeycar = None +try: + import rl_racing.envs # pytype: disable=import-error +except ImportError: + rl_racing = None + +try: + import gym_space_engineers # pytype: disable=import-error +except ImportError: + gym_space_engineers = None + try: import panda_gym # pytype: disable=import-error except ImportError: panda_gym = None + + +# Register no-velocity (NoVel) envs +def create_no_vel_env(env_id: str): + def make_env(): + env = gym.make(env_id) + env = MaskVelocityWrapper(env) + return env + + return make_env + + +for env_id in MaskVelocityWrapper.velocity_indices.keys(): + name, version = env_id.split("-v") + register( + id=f"{name}NoVel-v{version}", + entry_point=create_no_vel_env(env_id), + ) diff --git a/utils/load_from_hub.py b/utils/load_from_hub.py new file mode 100644 index 000000000..490ef1f8f --- /dev/null +++ b/utils/load_from_hub.py @@ -0,0 +1,123 @@ +import argparse +import os +import shutil +import zipfile +from pathlib import Path +from typing import Optional + +from huggingface_sb3 import load_from_hub +from requests.exceptions import HTTPError + +from utils import ALGOS, get_latest_run_id + + +def download_from_hub( + algo: str, + env_id: str, + exp_id: int, + folder: str, + organization: str, + repo_name: Optional[str] = None, + force: bool = False, +) -> None: + """ + Try to load a model from the Huggingface hub + and save it following the RL Zoo structure. + Default repo id is {organization}/{repo_name} + where repo_name = {algo}-{env_id} + + :param algo: Algorithm + :param env_id: Environment id + :param exp_id: Experiment id + :param folder: Log folder + :param organization: Huggingface organization + :param repo_name: Overwrite default repository name + :param force: Allow overwriting the folder + if it already exists. 
+ """ + + if repo_name is None: + repo_name = f"{algo}-{env_id}" + + repo_id = f"{organization}/{repo_name}" + print(f"Downloading from https://huggingface.co/{repo_id}") + + model_name = f"{algo}-{env_id}" + + checkpoint = load_from_hub(repo_id, f"{model_name}.zip") + config_path = load_from_hub(repo_id, "config.yml") + + # If VecNormalize, download + try: + vec_normalize_stats = load_from_hub(repo_id, "vec_normalize.pkl") + except HTTPError: + print("No normalization file") + vec_normalize_stats = None + + saved_args = load_from_hub(repo_id, "args.yml") + env_kwargs = load_from_hub(repo_id, "env_kwargs.yml") + train_eval_metrics = load_from_hub(repo_id, "train_eval_metrics.zip") + + if exp_id == 0: + exp_id = get_latest_run_id(os.path.join(folder, algo), env_id) + 1 + # Sanity checks + if exp_id > 0: + log_path = os.path.join(folder, algo, f"{env_id}_{exp_id}") + else: + log_path = os.path.join(folder, algo) + + # Check that the folder does not exist + log_folder = Path(log_path) + if log_folder.is_dir(): + if force: + print(f"The folder {log_path} already exists, overwritting") + # Delete the current one to avoid errors + shutil.rmtree(log_path) + else: + raise ValueError( + f"The folder {log_path} already exists, use --force to overwrite it, " + "or choose '--exp-id 0' to create a new folder" + ) + + print(f"Saving to {log_path}") + # Create folder structure + os.makedirs(log_path, exist_ok=True) + config_folder = os.path.join(log_path, env_id) + os.makedirs(config_folder, exist_ok=True) + + # Copy config files and saved stats + shutil.copy(checkpoint, os.path.join(log_path, f"{env_id}.zip")) + shutil.copy(saved_args, os.path.join(config_folder, "args.yml")) + shutil.copy(config_path, os.path.join(config_folder, "config.yml")) + shutil.copy(env_kwargs, os.path.join(config_folder, "env_kwargs.yml")) + if vec_normalize_stats is not None: + shutil.copy(vec_normalize_stats, os.path.join(config_folder, "vecnormalize.pkl")) + + # Extract monitor file and evaluation file + with zipfile.ZipFile(train_eval_metrics, "r") as zip_ref: + zip_ref.extractall(log_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--env", help="environment ID", type=str, required=True) + parser.add_argument("-f", "--folder", help="Log folder", type=str, required=True) + parser.add_argument("-orga", "--organization", help="Huggingface hub organization", default="sb3") + parser.add_argument("-name", "--repo-name", help="Huggingface hub repository name, by default 'algo-env_id'", type=str) + parser.add_argument("--algo", help="RL Algorithm", type=str, required=True, choices=list(ALGOS.keys())) + parser.add_argument("--exp-id", help="Experiment ID (default: 0: latest, -1: no exp folder)", default=0, type=int) + parser.add_argument("--verbose", help="Verbose mode (0: no output, 1: INFO)", default=1, type=int) + parser.add_argument( + "--force", action="store_true", default=False, help="Allow overwritting exp folder if it already exist" + ) + args = parser.parse_args() + + download_from_hub( + algo=args.algo, + env_id=args.env, + exp_id=args.exp_id, + folder=args.folder, + organization=args.organization, + repo_name=args.repo_name, + force=args.force, + ) diff --git a/utils/push_to_hub.py b/utils/push_to_hub.py new file mode 100644 index 000000000..b2a9ca878 --- /dev/null +++ b/utils/push_to_hub.py @@ -0,0 +1,403 @@ +import argparse +import glob +import os +import shutil +import zipfile +from copy import deepcopy +from pathlib import Path +from pprint import pformat +from typing 
import Any, Dict, Optional, Tuple + +import torch as th +import yaml +from huggingface_hub import HfApi, Repository +from huggingface_hub.repocard import metadata_save +from huggingface_sb3.push_to_hub import _evaluate_agent, _generate_replay, generate_metadata +from stable_baselines3.common.base_class import BaseAlgorithm +from stable_baselines3.common.utils import set_random_seed +from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize +from wasabi import Printer + +import utils.import_envs # noqa: F401 pylint: disable=unused-import +from utils import ALGOS, create_test_env, get_saved_hyperparams +from utils.exp_manager import ExperimentManager +from utils.utils import StoreDict, get_model_path + +msg = Printer() + + +def save_model_card(repo_dir: Path, generated_model_card: str, metadata: Dict[str, Any]) -> None: + """Saves a model card for the repository. + + :param repo_dir: repository directory + :param generated_model_card: model card generated by generate_model_card() + :param metadata: metadata + """ + readme_path = repo_dir / "README.md" + # Always overwrite README + with readme_path.open("w", encoding="utf-8") as f: + f.write(generated_model_card) + + # Save our metrics to Readme metadata + metadata_save(readme_path, metadata) + + +def generate_model_card( + algo_name: str, + algo_class_name: str, + organization: str, + env_id: str, + mean_reward: float, + std_reward: float, + hyperparams: Dict[str, Any], + env_kwargs: Dict[str, Any], +) -> Tuple[str, Dict[str, Any]]: + """ + Generate the model card for the Hub + + :param algo_name: alias used in the zoo for the algorithm + :param algo_class_name: name of the algorithm class + :param organization: Huggingface hub organization + :param env_id: name of the environment + :param mean_reward: mean reward of the agent + :param std_reward: standard deviation of the mean reward of the agent + :param hyperparams: hyperparameters used for training + :param env_kwargs: keyword arguments that were passed to the environment + :return: Model card (readme) and metadata (performance, algo/env id, tags) + """ + # Step 1: Select the tags + metadata = generate_metadata(algo_class_name, env_id, mean_reward, std_reward) + + # Step 2: Generate the model card + model_card = f""" +# **{algo_class_name}** Agent playing **{env_id}** +This is a trained model of a **{algo_class_name}** agent playing **{env_id}** +using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3) +and the [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo). + +The RL Zoo is a training framework for Stable Baselines3 +reinforcement learning agents, +with hyperparameter optimization and pre-trained agents included. +""" + + model_card += f""" +## Usage (with SB3 RL Zoo) + +RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo
+SB3: https://github.com/DLR-RM/stable-baselines3
+SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib + +``` +# Download model and save it into the logs/ folder +python -m utils.load_from_hub --algo {algo_name} --env {env_id} -orga {organization} -f logs/ +python enjoy.py --algo {algo_name} --env {env_id} -f logs/ +``` + +## Training (with the RL Zoo) +``` +python train.py --algo {algo_name} --env {env_id} -f logs/ +# Upload the model and generate video (when possible) +python -m utils.push_to_hub --algo {algo_name} --env {env_id} -f logs/ -orga {organization} +``` + +## Hyperparameters +```python +{pformat(hyperparams)} +``` +""" + if len(env_kwargs) > 0: + model_card += f""" +# Environment Arguments +```python +{pformat(env_kwargs)} +``` +""" + + return model_card, metadata + + +def package_to_hub( + model: BaseAlgorithm, + model_name: str, + algo_name: str, + algo_class_name: str, + log_path: Path, + hyperparams: Dict[str, Any], + env_kwargs: Dict[str, Any], + env_id: str, + eval_env: VecEnv, + repo_id: str, + commit_message: str, + is_deterministic: bool = True, + n_eval_episodes=10, + token: Optional[str] = None, + local_repo_path="hub", + video_length=1000, + generate_video: bool = False, +): + """ + Evaluate, generate a video and upload a model to the Hugging Face Hub. + This method does the complete pipeline: + - It evaluates the model + - It generates the model card + - It generates a replay video of the agent + - It pushes everything to the hub + + This function is a work in progress; if it does not work, + use the `push_to_hub` method. + + :param model: trained model + :param model_name: name of the model zip file + :param algo_name: alias used in the zoo for the algorithm, + usually lower case of the class (a2c, ars, ppo, ppo_lstm) + :param algo_class_name: name of the algorithm class + (DQN, PPO, A2C, SAC, RecurrentPPO, ...) + :param log_path: Path to where the model is saved in the zoo. + :param hyperparams: Hyperparameters used for training, + includes wrappers. + :param env_kwargs: Additional keyword arguments that were passed + to the environment. + :param env_id: name of the environment + :param eval_env: environment used to evaluate the agent + :param repo_id: id of the model repository from the Hugging Face Hub + :param commit_message: commit message + :param is_deterministic: use deterministic or stochastic actions (by default: True) + :param n_eval_episodes: number of evaluation episodes (by default: 10) + :param token: authentication token for the Hugging Face Hub + :param local_repo_path: local repository path + :param video_length: length of the video (in timesteps) + :param generate_video: whether to generate and upload a replay video + """ + + msg.info( + "This function will save, evaluate, generate a video of your agent, " + "create a model card and push everything to the hub. " + "It might take several minutes if video generation is activated. " + "This is a work in progress: if you encounter a bug, please open an issue." 
) + + organization, repo_name = repo_id.split("/") + + # Step 1: Clone or create the repo + # Create the repo (or clone its content if it's nonempty) + api = HfApi() + + repo_url = api.create_repo( + name=repo_name, + token=token, + organization=organization, + private=False, + exist_ok=True, + ) + + # Git pull + repo_local_path = Path(local_repo_path) / repo_name + repo = Repository(repo_local_path, clone_from=repo_url, use_auth_token=True) + repo.git_pull(rebase=True) + + repo.lfs_track(["*.mp4"]) + + # Step 2: Save the model + model.save(repo_local_path / model_name) + + # Retrieve VecNormalize wrapper if it exists + # we need to save the statistics + maybe_vec_normalize = unwrap_vec_normalize(eval_env) + + # Save the normalization + if maybe_vec_normalize is not None: + maybe_vec_normalize.save(repo_local_path / "vec_normalize.pkl") + # Do not update the stats at test time + maybe_vec_normalize.training = False + # Reward normalization is not needed at test time + maybe_vec_normalize.norm_reward = False + + # Unzip the model + with zipfile.ZipFile(repo_local_path / f"{model_name}.zip", "r") as zip_ref: + zip_ref.extractall(repo_local_path / model_name) + + # Step 3: Copy config files + args_path = log_path / env_id / "args.yml" + config_path = log_path / env_id / "config.yml" + + shutil.copy(args_path, repo_local_path / "args.yml") + shutil.copy(config_path, repo_local_path / "config.yml") + with open(repo_local_path / "env_kwargs.yml", "w") as outfile: + yaml.dump(env_kwargs, outfile) + + # Copy train/eval metrics into zip + with zipfile.ZipFile(repo_local_path / "train_eval_metrics.zip", "w") as archive: + if os.path.isfile(log_path / "evaluations.npz"): + archive.write(log_path / "evaluations.npz", arcname="evaluations.npz") + for monitor_file in glob.glob(f"{log_path}/*.csv"): + archive.write(monitor_file, arcname=monitor_file.split(os.sep)[-1]) + + # Step 4: Evaluate the agent + mean_reward, std_reward = _evaluate_agent(model, eval_env, n_eval_episodes, is_deterministic, repo_local_path) + + # Step 5: Generate a video + if generate_video: + _generate_replay(model, eval_env, video_length, is_deterministic, repo_local_path) + # Cleanup files after generation + # TODO: upstream to huggingface sb3 + video_path = Path("test.mp4") + if video_path.is_file(): + video_path.unlink() + json_path = list(glob.glob("*.meta.json")) + if len(json_path) > 0: + Path(json_path[0]).unlink() + + # Step 6: Generate the model card + generated_model_card, metadata = generate_model_card( + algo_name, + algo_class_name, + organization, + env_id, + mean_reward, + std_reward, + hyperparams, + env_kwargs, + ) + + save_model_card(repo_local_path, generated_model_card, metadata) + + msg.info(f"Pushing repo {repo_name} to the Hugging Face Hub") + repo.push_to_hub(commit_message=commit_message) + + msg.info(f"Your model is pushed to the hub. 
You can view your model here: {repo_url}") + return repo_url + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--env", help="environment ID", type=str, required=True) + parser.add_argument("-f", "--folder", help="Log folder", type=str, required=True) + parser.add_argument("--algo", help="RL Algorithm", type=str, required=True, choices=list(ALGOS.keys())) + parser.add_argument("-n", "--n-timesteps", help="number of timesteps", default=1000, type=int) + parser.add_argument("--num-threads", help="Number of threads for PyTorch (-1 to use default)", default=-1, type=int) + parser.add_argument("--n-envs", help="number of environments", default=1, type=int) + parser.add_argument("--exp-id", help="Experiment ID (default: 0: latest, -1: no exp folder)", default=0, type=int) + parser.add_argument("--verbose", help="Verbose mode (0: no output, 1: INFO)", default=1, type=int) + parser.add_argument( + "--no-render", action="store_true", default=False, help="Do not render the environment (useful for tests)" + ) + parser.add_argument("--deterministic", action="store_true", default=False, help="Use deterministic actions") + parser.add_argument("--device", help="PyTorch device to be used (ex: cpu, cuda, ...)", default="auto", type=str) + parser.add_argument( + "--load-best", action="store_true", default=False, help="Load best model instead of last model if available" + ) + parser.add_argument( + "--load-checkpoint", + type=int, + help="Load checkpoint instead of last model if available, " + "you must pass the number of timesteps corresponding to it", + ) + parser.add_argument( + "--load-last-checkpoint", + action="store_true", + default=False, + help="Load last checkpoint instead of last model if available", + ) + parser.add_argument("--stochastic", action="store_true", default=False, help="Use stochastic actions") + parser.add_argument("--seed", help="Random generator seed", type=int, default=0) + parser.add_argument( + "--env-kwargs", type=str, nargs="+", action=StoreDict, help="Optional keyword argument to pass to the env constructor" + ) + parser.add_argument("-orga", "--organization", help="Huggingface hub organization", type=str, required=True) + parser.add_argument("-name", "--repo-name", help="Huggingface hub repository name, by default 'algo-env_id'", type=str) + parser.add_argument("-m", "--commit-message", help="Commit message", default="Initial commit", type=str) + + args = parser.parse_args() + env_id = args.env + algo = args.algo + + _, model_path, log_path = get_model_path( + args.exp_id, + args.folder, + args.algo, + args.env, + args.load_best, + args.load_checkpoint, + args.load_last_checkpoint, + ) + + print(f"Loading {model_path}") + + # Off-policy algorithms only support one env for now + off_policy_algos = ["qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"] + + if algo in off_policy_algos: + args.n_envs = 1 + + set_random_seed(args.seed) + + if args.num_threads > 0: + if args.verbose > 1: + print(f"Setting torch.num_threads to {args.num_threads}") + th.set_num_threads(args.num_threads) + + is_atari = ExperimentManager.is_atari(env_id) + + stats_path = os.path.join(log_path, env_id) + hyperparams, stats_path = get_saved_hyperparams(stats_path, test_mode=True) + + # load env_kwargs if existing + env_kwargs = {} + args_path = os.path.join(log_path, env_id, "args.yml") + if os.path.isfile(args_path): + with open(args_path) as f: + loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr + if loaded_args["env_kwargs"] is not 
None: + env_kwargs = loaded_args["env_kwargs"] + # overwrite with command line arguments + if args.env_kwargs is not None: + env_kwargs.update(args.env_kwargs) + + eval_env = create_test_env( + env_id, + n_envs=args.n_envs, + stats_path=stats_path, + seed=args.seed, + log_dir=None, + should_render=not args.no_render, + hyperparams=deepcopy(hyperparams), + env_kwargs=env_kwargs, + ) + + kwargs = dict(seed=args.seed) + if algo in off_policy_algos: + # Dummy buffer size as we don't need memory to enjoy the trained agent + kwargs.update(dict(buffer_size=1)) + + # Note: we assume that we push models using the same machine (same python version) + # that trained them, if not, we would need to pass custom objects as in enjoy.py + custom_objects = {} + model = ALGOS[algo].load(model_path, env=eval_env, custom_objects=custom_objects, device=args.device, **kwargs) + + # Deterministic by default except for atari games + stochastic = args.stochastic or is_atari and not args.deterministic + deterministic = not stochastic + + # Default model name, the model will be saved under "{algo}-{env_id}.zip" + model_name = f"{algo}-{env_id}" + + if args.repo_name is None: + args.repo_name = model_name + + repo_id = f"{args.organization}/{args.repo_name}" + print(f"Uploading to {repo_id}, make sure to have the rights") + + package_to_hub( + model, + model_name, + algo, + ALGOS[algo].__name__, + Path(log_path), + hyperparams, + env_kwargs, + env_id, + eval_env, + repo_id=repo_id, + commit_message=args.commit_message, + is_deterministic=deterministic, + n_eval_episodes=10, + token=None, + local_repo_path="hub", + video_length=1000, + generate_video=not args.no_render, + ) diff --git a/utils/record_video.py b/utils/record_video.py index dd89c0220..0bb66b114 100644 --- a/utils/record_video.py +++ b/utils/record_video.py @@ -2,12 +2,13 @@ import os import sys +import numpy as np import yaml from stable_baselines3.common.utils import set_random_seed from stable_baselines3.common.vec_env import VecVideoRecorder from utils.exp_manager import ExperimentManager -from utils.utils import ALGOS, StoreDict, create_test_env, get_latest_run_id, get_saved_hyperparams +from utils.utils import ALGOS, StoreDict, create_test_env, get_model_path, get_saved_hyperparams if __name__ == "__main__": # noqa: C901 parser = argparse.ArgumentParser() @@ -32,6 +33,12 @@ type=int, help="Load checkpoint instead of last model if available, you must pass the number of timesteps corresponding to it", ) + parser.add_argument( + "--load-last-checkpoint", + action="store_true", + default=False, + help="Load last checkpoint instead of last model if available", + ) parser.add_argument( "--env-kwargs", type=str, nargs="+", action=StoreDict, help="Optional keyword argument to pass to the env constructor" ) @@ -44,35 +51,18 @@ seed = args.seed video_length = args.n_timesteps n_envs = args.n_envs - load_best = args.load_best - load_checkpoint = args.load_checkpoint - - if args.exp_id == 0: - args.exp_id = get_latest_run_id(os.path.join(folder, algo), env_id) - print(f"Loading latest experiment, id={args.exp_id}") - # Sanity checks - if args.exp_id > 0: - log_path = os.path.join(folder, algo, f"{env_id}_{args.exp_id}") - else: - log_path = os.path.join(folder, algo) - - assert os.path.isdir(log_path), f"The {log_path} folder was not found" - - if load_best: - model_path = os.path.join(log_path, "best_model.zip") - name_prefix = f"best-model-{algo}-{env_id}" - elif load_checkpoint is None: - # Default: load latest model - model_path = os.path.join(log_path, 
f"{env_id}.zip") - name_prefix = f"final-model-{algo}-{env_id}" - else: - model_path = os.path.join(log_path, f"rl_model_{args.load_checkpoint}_steps.zip") - name_prefix = f"checkpoint-{args.load_checkpoint}-{algo}-{env_id}" - - found = os.path.isfile(model_path) - if not found: - raise ValueError(f"No model found for {algo} on {env_id}, path: {model_path}") + name_prefix, model_path, log_path = get_model_path( + args.exp_id, + folder, + algo, + env_id, + args.load_best, + args.load_checkpoint, + args.load_last_checkpoint, + ) + + print(f"Loading {model_path}") off_policy_algos = ["qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"] set_random_seed(args.seed) @@ -86,7 +76,7 @@ env_kwargs = {} args_path = os.path.join(log_path, env_id, "args.yml") if os.path.isfile(args_path): - with open(args_path, "r") as f: + with open(args_path) as f: loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr if loaded_args["env_kwargs"] is not None: env_kwargs = loaded_args["env_kwargs"] @@ -126,8 +116,6 @@ model = ALGOS[algo].load(model_path, env=env, custom_objects=custom_objects, **kwargs) - obs = env.reset() - # Deterministic by default except for atari games stochastic = args.stochastic or is_atari and not args.deterministic deterministic = not stochastic @@ -144,11 +132,19 @@ name_prefix=name_prefix, ) - env.reset() + obs = env.reset() + lstm_states = None + episode_starts = np.ones((env.num_envs,), dtype=bool) try: for _ in range(video_length + 1): - action, _ = model.predict(obs, deterministic=deterministic) - obs, _, _, _ = env.step(action) + action, lstm_states = model.predict( + obs, + state=lstm_states, + episode_start=episode_starts, + deterministic=deterministic, + ) + obs, _, dones, _ = env.step(action) + episode_starts = dones if not args.no_render: env.render() except KeyboardInterrupt: diff --git a/utils/teleop.py b/utils/teleop.py new file mode 100644 index 000000000..fa20356a1 --- /dev/null +++ b/utils/teleop.py @@ -0,0 +1,380 @@ +import os +import time +from typing import List, Optional, Tuple + +import numpy as np +import pygame +from pygame.locals import * # noqa: F403 +from sb3_contrib import TQC +from stable_baselines3.common.base_class import BaseAlgorithm +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl + +# TELEOP_RATE = 1 / 60 + +UP = (1, 0) +LEFT = (0, 1) +RIGHT = (0, -1) +DOWN = (-1, 0) +STOP = (0, 0) +KEY_CODE_SPACE = 32 + +MAX_TURN = 1 +# Smoothing constants +STEP_THROTTLE = 0.8 +STEP_TURN = 0.8 + +GREEN = (72, 205, 40) +RED = (205, 39, 46) +GREY = (187, 179, 179) +BLACK = (36, 36, 36) +WHITE = (230, 230, 230) +ORANGE = (200, 110, 0) + +# pytype: disable=name-error +moveBindingsGame = {K_UP: UP, K_LEFT: LEFT, K_RIGHT: RIGHT, K_DOWN: DOWN} +# pytype: enable=name-error +pygame.font.init() +FONT = pygame.font.SysFont("Open Sans", 25) +SMALL_FONT = pygame.font.SysFont("Open Sans", 20) +KEY_MIN_DELAY = 0.4 + + +def control(x, theta, control_throttle, control_steering): + """ + Smooth control. 
+ + :param x: (float) + :param theta: (float) + :param control_throttle: (float) + :param control_steering: (float) + :return: (float, float) + """ + target_throttle = x + target_steering = MAX_TURN * theta + if target_throttle > control_throttle: + control_throttle = min(target_throttle, control_throttle + STEP_THROTTLE) + elif target_throttle < control_throttle: + control_throttle = max(target_throttle, control_throttle - STEP_THROTTLE) + else: + control_throttle = target_throttle + + if target_steering > control_steering: + control_steering = min(target_steering, control_steering + STEP_TURN) + elif target_steering < control_steering: + control_steering = max(target_steering, control_steering - STEP_TURN) + else: + control_steering = target_steering + return control_throttle, control_steering + + +class HumanTeleop(BaseAlgorithm): + def __init__( + self, + policy, + env, + buffer_size=50000, + tensorboard_log=None, + verbose=0, + seed=None, + device=None, + _init_setup_model=False, + scale_human=1.0, + scale_model=0.0, + model_path=os.environ.get("MODEL_PATH"), + deterministic=True, + ): + super().__init__( + policy=None, env=env, policy_base=None, learning_rate=0.0, verbose=verbose, seed=seed + ) + + # pytype: disable=name-error + # self.button_switch_mode = K_m + # self.button_toggle_train_mode = K_t + # pytype: enable=name-error + + # TODO: add to model buffer + allow training in separate thread + + # Used to prevent multiple successive key presses + self.last_time_pressed = {} + self.event_buttons = None + self.action = np.zeros((2,)) + self.exit_thread = False + self.process = None + self.window = None + self.buffer_size = buffer_size + self.replay_buffer = ReplayBuffer( + buffer_size, self.observation_space, self.action_space, self.device, optimize_memory_usage=False + ) + self.scale_human = scale_human + self.scale_model = scale_model + # self.start_process() + self.model = None + self.deterministic = deterministic + # Pretrained model + if model_path is not None: + # "logs/tqc/donkey-generated-track-v0_9/donkey-generated-track-v0.zip" + self.model = TQC.load(model_path) + + def _excluded_save_params(self) -> List[str]: + """ + Returns the names of the parameters that should be excluded by default + when saving the model. + + :return: (List[str]) List of parameters that should be excluded from save + """ + # Exclude aliases + return super()._excluded_save_params() + ["process", "window", "model", "exit_thread"] + + def _setup_model(self): + self.exit_thread = False + + def init_buttons(self): + """ + Initialize the last_time_pressed timers that prevent + successive key presses. + """ + self.event_buttons = [ + # self.button_switch_mode, + # self.button_toggle_train_mode, + ] + for key in self.event_buttons: + self.last_time_pressed[key] = 0 + + # def start_process(self): + # """Start main loop process.""" + # # Reset last time pressed + # self.init_buttons() + # self.process = Thread(target=self.main_loop) + # # Make it a daemon, so it will be deleted at the same time + # # as the main process + # self.process.daemon = True + # self.process.start() + + def check_key(self, keys, key): + """ + Check if a key was pressed and update the associated timer. 
+ + :param keys: (dict) + :param key: (any hashable type) + :return: (bool) Returns True when a given key was pressed, False otherwise + """ + if key is None: + return False + if keys[key] and (time.time() - self.last_time_pressed[key]) > KEY_MIN_DELAY: + # avoid multiple key presses + self.last_time_pressed[key] = time.time() + return True + return False + + def handle_keys_event(self, keys): + """ + Handle the events induced by key press: + e.g. change of mode, toggling recording, ... + """ + + # Switch from "MANUAL" to "AUTONOMOUS" mode + # if self.check_key(keys, self.button_switch_mode) or self.check_key(keys, self.button_pause): + # self.is_manual = not self.is_manual + + def _sample_action(self) -> Tuple[np.ndarray, np.ndarray]: + unscaled_action, _ = self.model.predict(self._last_obs, deterministic=self.deterministic) + + # Rescale the action from [low, high] to [-1, 1] + scaled_action = self.model.policy.scale_action(unscaled_action) + # We store the scaled action in the buffer + buffer_action = scaled_action + action = self.model.policy.unscale_action(scaled_action) + return action, buffer_action + + def main_loop(self, total_timesteps=-1): + """ + Pygame loop that listens to keyboard events. + """ + pygame.init() + # Create a pygame window + self.window = pygame.display.set_mode((800, 500), RESIZABLE) # pytype: disable=name-error + + # Init values and fill the screen + control_throttle, control_steering = 0, 0 + action = [control_steering, control_throttle] + + self.update_screen(action) + + n_steps = 0 + buffer_action = np.array([[0.0, 0.0]]) + + while not self.exit_thread: + x, theta = 0, 0 + # Record pressed keys + keys = pygame.key.get_pressed() + for keycode in moveBindingsGame.keys(): + if keys[keycode]: + x_tmp, th_tmp = moveBindingsGame[keycode] + x += x_tmp + theta += th_tmp + + self.handle_keys_event(keys) + + # Smooth control for teleoperation + control_throttle, control_steering = control(x, theta, control_throttle, control_steering) + scaled_action = np.array([[-control_steering, control_throttle]]).astype(np.float32) + + # Use trained RL model action + if self.model is not None: + _, buffer_action_model = self._sample_action() + else: + buffer_action_model = 0.0 + scaled_action = self.scale_human * scaled_action + self.scale_model * buffer_action_model + scaled_action = np.clip(scaled_action, -1.0, 1.0) + + # We store the scaled action in the buffer + buffer_action = scaled_action + + # No need to unscale the action normally, as the bounds are [-1, 1] + if self.model is not None: + action = self.model.policy.unscale_action(scaled_action) + else: + action = scaled_action.copy() + + self.action = action.copy()[0] + + if self.model is not None: + if ( + self.model.use_sde + and self.model.sde_sample_freq > 0 + and n_steps % self.model.sde_sample_freq == 0 + and not self.deterministic + ): + # Sample a new noise matrix + self.model.actor.reset_noise() + + new_obs, reward, done, infos = self.env.step(action) + + next_obs = new_obs + if done and infos[0].get("terminal_observation") is not None: + next_obs = infos[0]["terminal_observation"] + + self.replay_buffer.add(self._last_obs, next_obs, buffer_action, reward, done, infos) + + self._last_obs = new_obs + + self.update_screen(self.action) + + n_steps += 1 + if total_timesteps > 0: + self.exit_thread = n_steps >= total_timesteps + + if done: + print(f"{n_steps} steps") + # reset values + control_throttle, control_steering = 0, 0 + + for event in pygame.event.get(): + # Note: a QUIT event has no `key` attribute, so only check it for KEYDOWN + if event.type == QUIT or ( # pytype: disable=name-error + event.type == KEYDOWN # pytype: disable=name-error + and event.key 
in [ + K_ESCAPE, # pytype: disable=name-error + K_q, # pytype: disable=name-error + ] + ): + self.exit_thread = True + pygame.display.flip() + # Limit FPS + # pygame.time.Clock().tick(1 / TELEOP_RATE) + + def write_text(self, text, x, y, font, color=GREY): + """ + :param text: (str) + :param x: (int) + :param y: (int) + :param font: (pygame.font.Font) + :param color: (tuple) + """ + text = str(text) + text = font.render(text, True, color) + self.window.blit(text, (x, y)) + + def clear(self): + self.window.fill((0, 0, 0)) + + def update_screen(self, action): + """ + Update pygame window. + + :param action: ([float]) + """ + self.clear() + steering, throttle = action + self.write_text("Throttle: {:.2f}, Steering: {:.2f}".format(throttle, steering), 20, 0, FONT, WHITE) + + def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: + """ + Get the name of the torch variables that will be saved. + ``th.save`` and ``th.load`` will be used with the right device + instead of the default pickling strategy. + + :return: (Tuple[List[str], List[str]]) + name of the variables with state dicts to save, name of additional torch tensors. + """ + return [], [] + + def learn( + self, + total_timesteps, + callback=None, + log_interval=100, + tb_log_name="run", + eval_env=None, + eval_freq=-1, + n_eval_episodes=5, + eval_log_path=None, + reset_num_timesteps=True, + ) -> "HumanTeleop": + self._last_obs = self.env.reset() + # Wait for teleop process + # time.sleep(3) + self.main_loop(total_timesteps) + + # with threading: + # for _ in range(total_timesteps): + # print(np.array([self.action])) + # self.env.step(np.array([self.action])) + # if self.exit_thread: + # break + return self + + def predict( + self, + observation: np.ndarray, + state: Optional[np.ndarray] = None, + mask: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """ + Get the model's action(s) from an observation + + :param observation: (np.ndarray) the input observation + :param state: (Optional[np.ndarray]) The last states (can be None, used in recurrent policies) + :param mask: (Optional[np.ndarray]) The last masks (can be None, used in recurrent policies) + :param deterministic: (bool) Whether or not to return deterministic actions. + :return: (Tuple[np.ndarray, Optional[np.ndarray]]) the model's action and the next state + (used in recurrent policies) + """ + return self.action, None + + def save_replay_buffer(self, path) -> None: + """ + Save the replay buffer as a pickle file. + + :param path: (Union[str, pathlib.Path, io.BufferedIOBase]) Path to the file where the replay buffer should be saved. + if path is a str or pathlib.Path, the path is automatically created if necessary. + """ + assert self.replay_buffer is not None, "The replay buffer is not defined" + save_to_pkl(path, self.replay_buffer, self.verbose) + + def load_replay_buffer(self, path, truncate_last_traj=False) -> None: + """ + Load a replay buffer from a pickle file. + + :param path: (Union[str, pathlib.Path, io.BufferedIOBase]) Path to the pickled replay buffer. 
+ """ + self.replay_buffer = load_from_pkl(path, self.verbose) + assert isinstance(self.replay_buffer, ReplayBuffer), "The replay buffer must inherit from ReplayBuffer class" diff --git a/utils/utils.py b/utils/utils.py index 6072cc7cd..a53d2cc3e 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -2,19 +2,32 @@ import glob import importlib import os +import re from typing import Any, Callable, Dict, List, Optional, Tuple, Union import gym +import numpy as np import stable_baselines3 as sb3 # noqa: F401 import torch as th # noqa: F401 import yaml -from sb3_contrib import ARS, QRDQN, TQC, TRPO +from huggingface_hub import HfApi +from sb3_contrib import ARS, QRDQN, TQC, TRPO, RecurrentPPO from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike # noqa: F401 from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv, VecFrameStack, VecNormalize +try: + from utils.teleop import HumanTeleop +except ImportError: + HumanTeleop = None + +try: + from rl_racing.utils.teleop import HumanTeleop as Teleop +except ImportError: + Teleop = None + # For custom activation fn from torch import nn as nn # noqa: F401 pylint: disable=unused-import @@ -25,11 +38,14 @@ "ppo": PPO, "sac": SAC, "td3": TD3, + "human": HumanTeleop, + "teleop": Teleop, # SB3 Contrib, "ars": ARS, "qrdqn": QRDQN, "tqc": TQC, "trpo": TRPO, + "ppo_lstm": RecurrentPPO, } @@ -42,10 +58,12 @@ def flatten_dict_observations(env: gym.Env) -> gym.Env: return gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys)) -def get_wrapper_class(hyperparams: Dict[str, Any]) -> Optional[Callable[[gym.Env], gym.Env]]: +def get_wrapper_class(hyperparams: Dict[str, Any], key: str = "env_wrapper") -> Optional[Callable[[gym.Env], gym.Env]]: """ Get one or more Gym environment wrapper class specified as a hyper parameter "env_wrapper". + Works also for VecEnvWrapper with the key "vec_env_wrapper". + e.g. env_wrapper: gym_minigrid.wrappers.FlatObsWrapper @@ -67,8 +85,8 @@ def get_module_name(wrapper_name): def get_class_name(wrapper_name): return wrapper_name.split(".")[-1] - if "env_wrapper" in hyperparams.keys(): - wrapper_name = hyperparams.get("env_wrapper") + if key in hyperparams.keys(): + wrapper_name = hyperparams.get(key) if wrapper_name is None: return None @@ -204,6 +222,11 @@ def create_test_env( if "env_wrapper" in hyperparams.keys(): del hyperparams["env_wrapper"] + # Ignore for now + # TODO: handle it properly + if "vec_env_wrapper" in hyperparams.keys(): + del hyperparams["vec_env_wrapper"] + vec_env_kwargs = {} vec_env_cls = DummyVecEnv if n_envs > 1 or (ExperimentManager.is_bullet(env_id) and should_render): @@ -223,6 +246,12 @@ def create_test_env( vec_env_kwargs=vec_env_kwargs, ) + if "vec_env_wrapper" in hyperparams.keys(): + + vec_env_wrapper = get_wrapper_class(hyperparams, "vec_env_wrapper") + env = vec_env_wrapper(env) + del hyperparams["vec_env_wrapper"] + # Load saved stats for normalizing input and rewards # And optionally stack frames if stats_path is not None: @@ -282,6 +311,29 @@ def get_trained_models(log_folder: str) -> Dict[str, Tuple[str, str]]: return trained_models +def get_hf_trained_models(organization: str = "sb3") -> Dict[str, Tuple[str, str]]: + """ + Get pretrained models, + available on the Hugginface hub for a given organization. 
+ + :param organization: Huggingface hub organization + :return: Dict representing the trained agents + """ + api = HfApi() + models = api.list_models(author=organization) + regex = re.compile(r"^(?P<algo>[a-z_0-9]+)-(?P<env_id>[a-zA-Z0-9]+-v[0-9]+)$") + trained_models = {} + for model in models: + # Remove organization + repo_id = model.modelId.split(f"{organization}/")[1] + result = regex.match(repo_id) + # Skip demo repo that does not fit the pattern + if result is not None: + algo, env_id = result.group("algo"), result.group("env_id") + trained_models[f"{algo}-{env_id}"] = (algo, env_id) + return trained_models + + def get_latest_run_id(log_path: str, env_id: str) -> int: """ Returns the latest run number for the given log name and log path, @@ -318,7 +370,7 @@ def get_saved_hyperparams( config_file = os.path.join(stats_path, "config.yml") if os.path.isfile(config_file): # Load saved hyperparameters - with open(os.path.join(stats_path, "config.yml"), "r") as f: + with open(os.path.join(stats_path, "config.yml")) as f: hyperparams = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr hyperparams["normalize"] = hyperparams.get("normalize", False) else: @@ -347,7 +399,7 @@ class StoreDict(argparse.Action): def __init__(self, option_strings, dest, nargs=None, **kwargs): self._nargs = nargs - super(StoreDict, self).__init__(option_strings, dest, nargs=nargs, **kwargs) + super().__init__(option_strings, dest, nargs=nargs, **kwargs) def __call__(self, parser, namespace, values, option_string=None): arg_dict = {} @@ -357,3 +409,120 @@ def __call__(self, parser, namespace, values, option_string=None): # Evaluate the string as python code arg_dict[key] = eval(value) setattr(namespace, self.dest, arg_dict) + + +def evaluate_policy_add_to_buffer( + model, + env, + n_eval_episodes=10, + deterministic=True, + render=False, + callback=None, + reward_threshold=None, + return_episode_rewards=False, + replay_buffer=None, +): + """ + Runs policy for ``n_eval_episodes`` episodes and returns average reward. + This is made to work only with one env. + :param model: (BaseAlgorithm) The RL agent you want to evaluate. + :param env: (gym.Env or VecEnv) The gym environment. In the case of a ``VecEnv`` + this must contain only one environment. + :param n_eval_episodes: (int) Number of episodes to evaluate the agent + :param deterministic: (bool) Whether to use deterministic or stochastic actions + :param render: (bool) Whether to render the environment or not + :param callback: (callable) callback function to do additional checks, + called after each step. + :param reward_threshold: (float) Minimum expected reward per episode, + this will raise an error if the performance is not met + :param return_episode_rewards: (bool) If True, a list of reward per episode + will be returned instead of the mean. 
+ :return: (float, float) Mean reward per episode, std of reward per episode + returns ([float], [int]) when ``return_episode_rewards`` is True + """ + if isinstance(env, VecEnv): + assert env.num_envs == 1, "You must pass only one environment when using this function" + + episode_rewards, episode_lengths = [], [] + for _ in range(n_eval_episodes): + obs = env.reset() + done, state = False, None + episode_reward = 0.0 + episode_length = 0 + if model.use_sde: + model.actor.reset_noise() + + while not done: + action, state = model.predict(obs, state=state, deterministic=deterministic) + new_obs, reward, done, info = env.step(action) + episode_reward += reward + if callback is not None: + callback(locals(), globals()) + episode_length += 1 + if replay_buffer is not None: + # We assume actions are normalized but not observation/reward + buffer_action = action + replay_buffer.add(obs, new_obs, buffer_action, reward, done, info) + obs = new_obs + if render: + env.render() + episode_rewards.append(episode_reward) + episode_lengths.append(episode_length) + mean_reward = np.mean(episode_rewards) + std_reward = np.std(episode_rewards) + if reward_threshold is not None: + assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}" + if return_episode_rewards: + return episode_rewards, episode_lengths + return mean_reward, std_reward + + +def get_model_path( + exp_id: int, + folder: str, + algo: str, + env_id: str, + load_best: bool = False, + load_checkpoint: Optional[str] = None, + load_last_checkpoint: bool = False, +) -> Tuple[str, str, str]: + + if exp_id == 0: + exp_id = get_latest_run_id(os.path.join(folder, algo), env_id) + print(f"Loading latest experiment, id={exp_id}") + # Sanity checks + if exp_id > 0: + log_path = os.path.join(folder, algo, f"{env_id}_{exp_id}") + else: + log_path = os.path.join(folder, algo) + + assert os.path.isdir(log_path), f"The {log_path} folder was not found" + + if load_best: + model_path = os.path.join(log_path, "best_model.zip") + name_prefix = f"best-model-{algo}-{env_id}" + elif load_checkpoint is not None: + model_path = os.path.join(log_path, f"rl_model_{load_checkpoint}_steps.zip") + name_prefix = f"checkpoint-{load_checkpoint}-{algo}-{env_id}" + elif load_last_checkpoint: + checkpoints = glob.glob(os.path.join(log_path, "rl_model_*_steps.zip")) + if len(checkpoints) == 0: + raise ValueError(f"No checkpoint found for {algo} on {env_id}, path: {log_path}") + + def step_count(checkpoint_path: str) -> int: + # path follows the pattern "rl_model_*_steps.zip", we count from the back to ignore any other _ in the path + return int(checkpoint_path.split("_")[-2]) + + checkpoints = sorted(checkpoints, key=step_count) + model_path = checkpoints[-1] + name_prefix = f"checkpoint-{step_count(model_path)}-{algo}-{env_id}" + else: + # Default: load latest model + model_path = os.path.join(log_path, f"{env_id}.zip") + name_prefix = f"final-model-{algo}-{env_id}" + + found = os.path.isfile(model_path) + if not found: + raise ValueError(f"No model found for {algo} on {env_id}, path: {model_path}") + + return name_prefix, model_path, log_path diff --git a/utils/wrappers.py b/utils/wrappers.py index 9cdaf783f..9d177c7af 100644 --- a/utils/wrappers.py +++ b/utils/wrappers.py @@ -1,7 +1,52 @@ +import os +from copy import deepcopy +from typing import Optional + import gym import numpy as np +import torch as th from sb3_contrib.common.wrappers import TimeFeatureWrapper # noqa: F401 (backward compatibility) from scipy.signal 
import iirfilter, sosfilt, zpk2sos +from stable_baselines3 import SAC +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper + + +class VecForceResetWrapper(VecEnvWrapper): + """ + Force all environments to reset at once, + and tell the agent the trajectory was truncated. + + :param venv: The vectorized environment + """ + + def __init__(self, venv: VecEnv): + super().__init__(venv=venv) + + def reset(self) -> VecEnvObs: + return self.venv.reset() + + def step_wait(self) -> VecEnvStepReturn: + for env_idx in range(self.num_envs): + obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step( + self.actions[env_idx] + ) + self._save_obs(env_idx, obs) + + if self.buf_dones.any(): + for env_idx in range(self.num_envs): + self.buf_infos[env_idx]["terminal_observation"] = self.buf_obs[None][env_idx] + if not self.buf_dones[env_idx]: + self.buf_infos[env_idx]["TimeLimit.truncated"] = True + self.buf_dones[env_idx] = True + obs = self.envs[env_idx].reset() + self._save_obs(env_idx, obs) + + return ( + self._obs_from_buf(), + np.copy(self.buf_rews), + np.copy(self.buf_dones), + deepcopy(self.buf_infos), + ) class DoneOnSuccessWrapper(gym.Wrapper): @@ -11,7 +56,7 @@ class DoneOnSuccessWrapper(gym.Wrapper): """ def __init__(self, env: gym.Env, reward_offset: float = 0.0, n_successes: int = 1): - super(DoneOnSuccessWrapper, self).__init__(env) + super().__init__(env) self.reward_offset = reward_offset self.n_successes = n_successes self.current_successes = 0 @@ -41,12 +86,12 @@ class ActionNoiseWrapper(gym.Wrapper): Add gaussian noise to the action (without telling the agent), to test the robustness of the control. - :param env: (gym.Env) - :param noise_std: (float) Standard deviation of the noise + :param env: + :param noise_std: Standard deviation of the noise """ - def __init__(self, env, noise_std=0.1): - super(ActionNoiseWrapper, self).__init__(env) + def __init__(self, env: gym.Env, noise_std: float = 0.1): + super().__init__(env) self.noise_std = noise_std def step(self, action): @@ -95,13 +140,13 @@ class LowPassFilterWrapper(gym.Wrapper): """ Butterworth-Lowpass - :param env: (gym.Env) + :param env: :param freq: Filter corner frequency. :param df: Sampling rate in Hz. """ - def __init__(self, env, freq=5.0, df=25.0): - super(LowPassFilterWrapper, self).__init__(env) + def __init__(self, env: gym.Env, freq: float = 5.0, df: float = 25.0): + super().__init__(env) self.freq = freq self.df = df self.signal = [] @@ -123,12 +168,12 @@ class ActionSmoothingWrapper(gym.Wrapper): """ Smooth the action using exponential moving average. - :param env: (gym.Env) - :param smoothing_coef: (float) Smoothing coefficient (0 no smoothing, 1 very smooth) + :param env: + :param smoothing_coef: Smoothing coefficient (0 no smoothing, 1 very smooth) """ - def __init__(self, env, smoothing_coef: float = 0.0): - super(ActionSmoothingWrapper, self).__init__(env) + def __init__(self, env: gym.Env, smoothing_coef: float = 0.0): + super().__init__(env) self.smoothing_coef = smoothing_coef self.smoothed_action = None # from https://github.com/rail-berkeley/softlearning/issues/3 @@ -152,12 +197,12 @@ class DelayedRewardWrapper(gym.Wrapper): Delay the reward by `delay` steps, it makes the task harder but more realistic. The reward is accumulated during those steps. - :param env: (gym.Env) - :param delay: (int) Number of steps the reward should be delayed. 
+ :param env: + :param delay: Number of steps the reward should be delayed. """ - def __init__(self, env, delay=10): - super(DelayedRewardWrapper, self).__init__(env) + def __init__(self, env: gym.Env, delay: int = 10): + super().__init__(env) self.delay = delay self.current_step = 0 self.accumulated_reward = 0.0 @@ -185,11 +230,11 @@ class HistoryWrapper(gym.Wrapper): """ Stack past observations and actions to give a history to the agent. - :param env: (gym.Env) - :param horizon: (int) Number of steps to keep in the history. + :param env: + :param horizon: Number of steps to keep in the history. """ - def __init__(self, env: gym.Env, horizon: int = 5): + def __init__(self, env: gym.Env, horizon: int = 2): assert isinstance(env.observation_space, gym.spaces.Box) wrapped_obs_space = env.observation_space @@ -208,7 +253,7 @@ def __init__(self, env: gym.Env, horizon: int = 5): # Overwrite the observation space env.observation_space = gym.spaces.Box(low=low, high=high, dtype=wrapped_obs_space.dtype) - super(HistoryWrapper, self).__init__(env) + super().__init__(env) self.horizon = horizon self.low_action, self.high_action = low_action, high_action @@ -244,11 +289,11 @@ class HistoryWrapperObsDict(gym.Wrapper): """ History Wrapper for dict observation. - :param env: (gym.Env) - :param horizon: (int) Number of steps to keep in the history. + :param env: + :param horizon: Number of steps to keep in the history. """ - def __init__(self, env, horizon=5): + def __init__(self, env: gym.Env, horizon: int = 2): assert isinstance(env.observation_space.spaces["observation"], gym.spaces.Box) wrapped_obs_space = env.observation_space.spaces["observation"] @@ -267,7 +312,7 @@ def __init__(self, env, horizon=5): # Overwrite the observation space env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=wrapped_obs_space.dtype) - super(HistoryWrapperObsDict, self).__init__(env) + super().__init__(env) self.horizon = horizon self.low_action, self.high_action = low_action, high_action @@ -305,3 +350,175 @@ def step(self, action): obs_dict["observation"] = self._create_obs_from_history() return obs_dict, reward, done, info + + +class ResidualExpertWrapper(gym.Wrapper): + """ + :param env: + :param model_path: Path to the pretrained expert model (defaults to the MODEL_PATH env variable) + :param add_expert_to_obs: Whether to append the expert action to the observation + :param residual_scale: Scale of the residual (learned) action + :param expert_scale: Scale of the expert action + :param d3rlpy_model: Whether the model is a TorchScript model exported from d3rlpy + """ + + def __init__( + self, + env: gym.Env, + model_path: Optional[str] = os.environ.get("MODEL_PATH"), + add_expert_to_obs: bool = True, + residual_scale: float = 0.2, + expert_scale: float = 1.0, + d3rlpy_model: bool = False, + ): + assert isinstance(env.observation_space, gym.spaces.Box) + assert model_path is not None + + wrapped_obs_space = env.observation_space + + low = np.concatenate((wrapped_obs_space.low, np.finfo(np.float32).min * np.ones(2))) + high = np.concatenate((wrapped_obs_space.high, np.finfo(np.float32).max * np.ones(2))) + + # Overwrite the observation space + env.observation_space = gym.spaces.Box(low=low, high=high, dtype=wrapped_obs_space.dtype) + + super().__init__(env) + + print(f"Loading model from {model_path}") + if d3rlpy_model: + self.model = th.jit.load(model_path) + else: + self.model = SAC.load(model_path) + self.d3rlpy_model = d3rlpy_model + self._last_obs = None + self.residual_scale = residual_scale + self.expert_scale = expert_scale + self.add_expert_to_obs = add_expert_to_obs + + def _predict(self, obs): + # TODO: move to gpu when possible + if self.d3rlpy_model: + expert_action = self.model(th.tensor(obs).reshape(1, -1)).cpu().numpy()[0, :] + else: + 
@@ -305,3 +350,175 @@ def step(self, action):
         obs_dict["observation"] = self._create_obs_from_history()

         return obs_dict, reward, done, info
+
+
+class ResidualExpertWrapper(gym.Wrapper):
+    """
+    Combine the action of a pretrained expert model with the agent's action
+    and send the clipped sum to the env (residual learning).
+
+    :param env:
+    :param model_path: Path to the pretrained expert model
+        (defaults to the ``MODEL_PATH`` environment variable)
+    :param add_expert_to_obs: Whether to append the expert action to the observation
+    :param residual_scale: Scale factor for the agent (residual) action
+    :param expert_scale: Scale factor for the expert action
+    :param d3rlpy_model: Whether the model is a TorchScript model exported
+        by d3rlpy instead of a SAC model
+    """
+
+    def __init__(
+        self,
+        env: gym.Env,
+        model_path: Optional[str] = os.environ.get("MODEL_PATH"),
+        add_expert_to_obs: bool = True,
+        residual_scale: float = 0.2,
+        expert_scale: float = 1.0,
+        d3rlpy_model: bool = False,
+    ):
+        assert isinstance(env.observation_space, gym.spaces.Box)
+        assert model_path is not None
+
+        wrapped_obs_space = env.observation_space
+
+        low = np.concatenate((wrapped_obs_space.low, np.finfo(np.float32).min * np.ones(2)))
+        high = np.concatenate((wrapped_obs_space.high, np.finfo(np.float32).max * np.ones(2)))
+
+        # Overwrite the observation space
+        env.observation_space = gym.spaces.Box(low=low, high=high, dtype=wrapped_obs_space.dtype)
+
+        super().__init__(env)
+
+        print(f"Loading model from {model_path}")
+        if d3rlpy_model:
+            self.model = th.jit.load(model_path)
+        else:
+            self.model = SAC.load(model_path)
+        self.d3rlpy_model = d3rlpy_model
+        self._last_obs = None
+        self.residual_scale = residual_scale
+        self.expert_scale = expert_scale
+        self.add_expert_to_obs = add_expert_to_obs
+
+    def _predict(self, obs):
+        # TODO: move to gpu when possible
+        if self.d3rlpy_model:
+            expert_action = self.model(th.tensor(obs).reshape(1, -1)).cpu().numpy()[0, :]
+        else:
+            expert_action, _ = self.model.predict(obs, deterministic=True)
+        if self.add_expert_to_obs:
+            obs = np.concatenate((obs, expert_action), axis=-1)
+        return obs, expert_action
+
+    def reset(self):
+        obs = self.env.reset()
+        obs, self.expert_action = self._predict(obs)
+        return obs
+
+    def step(self, action):
+        action = np.clip(self.expert_scale * self.expert_action + self.residual_scale * action, -1.0, 1.0)
+        obs, reward, done, info = self.env.step(action)
+        obs, self.expert_action = self._predict(obs)
+
+        return obs, reward, done, info
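The core of `ResidualExpertWrapper.step` is the clipped residual rule above. A tiny worked example with the default scales (the numbers are made up for illustration):

```python
# Worked example of the residual action rule (values are illustrative).
import numpy as np

expert_action = np.array([0.9, -0.2])    # hypothetical expert output
residual_action = np.array([0.5, 0.5])   # hypothetical agent output
expert_scale, residual_scale = 1.0, 0.2  # defaults from the wrapper above

action = np.clip(expert_scale * expert_action + residual_scale * residual_action, -1.0, 1.0)
# [0.9 + 0.1, -0.2 + 0.1] -> [1.0, -0.1], already inside [-1, 1]
print(action)  # [ 1.  -0.1]
```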
+
+
+class ContinuityCostWrapper(gym.Wrapper):
+    """
+    Add continuity cost to the reward.
+    It assumes that the action space is normalized
+    and symmetric (actions in [-1, 1]).
+
+    :param env:
+    :param weight_continuity: Weight of the continuity cost
+    :param condition: If True, use the cost as a multiplicative factor,
+        ``reward * (1 - cost)``, instead of subtracting it
+    """
+
+    def __init__(self, env: gym.Env, weight_continuity: float = 0.0, condition: bool = False):
+        super().__init__(env)
+        self.last_action = None
+        self.weight_continuity = weight_continuity
+        self.condition = condition
+        if condition:
+            self.weight_continuity = 1.0
+
+    def reset(self):
+        self.last_action = None
+        return self.env.reset()
+
+    def step(self, action):
+        obs, reward, done, info = self.env.step(action)
+        # Continuity cost
+        if self.last_action is not None:
+            max_delta = 2.0  # for the action space: high - low = 1 - (-1) = 2
+            continuity_cost = np.mean((action - self.last_action) ** 2 / max_delta**2)
+            continuity_cost = self.weight_continuity * continuity_cost
+        else:
+            continuity_cost = 0.0
+        self.last_action = action.copy()
+
+        if self.condition:
+            reward = (1 - continuity_cost) * reward
+        else:
+            reward -= continuity_cost
+        return obs, reward, done, info
+
+
+class FrameSkip(gym.Wrapper):
+    """
+    Return only every ``skip``-th frame (frameskipping)
+
+    :param env: the environment
+    :param skip: number of frames to skip (the action is repeated ``skip`` times)
+    """
+
+    def __init__(self, env: gym.Env, skip: int = 4):
+        super().__init__(env)
+        self._skip = skip
+
+    def step(self, action: np.ndarray):
+        """
+        Step the environment with the given action.
+        Repeat the action and sum the reward.
+
+        :param action: the action
+        :return: observation, reward, done, information
+        """
+        total_reward = 0.0
+        done = None
+        for _ in range(self._skip):
+            obs, reward, done, info = self.env.step(action)
+            total_reward += reward
+            if done:
+                break
+
+        return obs, total_reward, done, info
+
+    def reset(self):
+        return self.env.reset()
+
+
+class MaskVelocityWrapper(gym.ObservationWrapper):
+    """
+    Gym environment observation wrapper used to mask velocity terms in
+    observations. The intention is to make the MDP partially observable.
+    Adapted from https://github.com/LiuWenlin595/FinalProject.
+
+    :param env: Gym environment
+    """
+
+    # Supported envs
+    velocity_indices = {
+        "CartPole-v1": np.array([1, 3]),
+        "MountainCar-v0": np.array([1]),
+        "MountainCarContinuous-v0": np.array([1]),
+        "Pendulum-v1": np.array([2]),
+        "LunarLander-v2": np.array([2, 3, 5]),
+        "LunarLanderContinuous-v2": np.array([2, 3, 5]),
+    }
+
+    def __init__(self, env: gym.Env):
+        super().__init__(env)
+
+        env_id: str = env.unwrapped.spec.id
+        # By default no masking
+        self.mask = np.ones_like(env.observation_space.sample())
+        try:
+            # Mask velocity
+            self.mask[self.velocity_indices[env_id]] = 0.0
+        except KeyError:
+            raise NotImplementedError(f"Velocity masking not implemented for {env_id}")
+
+    def observation(self, observation: np.ndarray) -> np.ndarray:
+        return observation * self.mask
diff --git a/version.txt b/version.txt
index 33271c4d0..511e75b2e 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-1.5.1a0
+1.5.1a8
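Finally, a hedged usage sketch for `MaskVelocityWrapper` (illustrative; it assumes this file is importable as `utils.wrappers` and uses one of the env ids listed in `velocity_indices`):

```python
import gym
import numpy as np

from utils.wrappers import MaskVelocityWrapper

env = MaskVelocityWrapper(gym.make("CartPole-v1"))
obs = env.reset()
# Cart velocity (index 1) and pole angular velocity (index 3) are zeroed out:
assert np.allclose(obs[[1, 3]], 0.0)

# Env ids missing from velocity_indices fail fast:
# MaskVelocityWrapper(gym.make("Acrobot-v1"))  # -> NotImplementedError
```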