diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 24390204e..0e78317f9 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -17,7 +17,6 @@ jobs:
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
-
steps:
- uses: actions/checkout@v2
with:
@@ -30,13 +29,13 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch - faster to download
- pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install pybullet==3.1.9
pip install -r requirements.txt
# Use headless version
pip install opencv-python-headless
# install parking-env to test HER (pinned so it works with gym 0.21)
- pip install git+https://github.com/eleurent/highway-env@1a04c6a98be64632cf9683625022023e70ff1ab1
+ pip install highway-env==1.5.0
- name: Type check
run: |
make type
diff --git a/.github/workflows/trained_agents.yml b/.github/workflows/trained_agents.yml
index 39a970f85..a047e9861 100644
--- a/.github/workflows/trained_agents.yml
+++ b/.github/workflows/trained_agents.yml
@@ -29,13 +29,13 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch - faster to download
- pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install pybullet==3.1.9
pip install -r requirements.txt
# Use headless version
pip install opencv-python-headless
# install parking-env to test HER (pinned so it works with gym 0.21)
- pip install git+https://github.com/eleurent/highway-env@1a04c6a98be64632cf9683625022023e70ff1ab1
+ pip install highway-env==1.5.0
# Add support for pickle5 protocol
pip install pickle5
- name: Check trained agents
diff --git a/.gitignore b/.gitignore
index e44316a54..2032c6f01 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,3 +15,7 @@ git_rewrite_commit_history.sh
.vscode/
wandb
runs
+hub
+*.mp4
+*.json
+*.csv
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 6b6dfe1b3..512bbbd13 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -6,17 +6,25 @@ variables:
type-check:
script:
+ - pip install git+https://github.com/huggingface/huggingface_sb3
- make type
pytest:
script:
# MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error
+ # tmp fix to have RecurrentPPO, will be fixed with new image
+ - pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
+ - pip install git+https://github.com/huggingface/huggingface_sb3
- MKL_THREADING_LAYER=GNU make pytest
+ coverage: '/^TOTAL.+?(\d+\%)$/'
check-trained-agents:
script:
# MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error
- pip install pickle5 # Add support for pickle5 protocol
+ - pip install git+https://github.com/huggingface/huggingface_sb3
+ # tmp fix to have RecurrentPPO, will be fixed with new image
+ - pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
- MKL_THREADING_LAYER=GNU make check-trained-agents
lint:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 25c7dd5ab..4828362d9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,18 +1,28 @@
-## Release 1.5.1a0 (WIP)
+## Release 1.5.1a8 (WIP)
### Breaking Changes
- Change the default number of hyperparameter optimization trials from 10 to 500 (@ernestum)
- Derive the number of intermediate pruning evaluations from the number of time steps (1 evaluation per 100k time steps) (@ernestum)
-- Updated default --eval-freq from 10k to 25k steps
+- Updated default --eval-freq from 10k to 25k steps
+- Update default horizon to 2 for the `HistoryWrapper`
### New Features
- Support setting PyTorch's device with the `--device` flag (@gregwar)
+- Add `--max-total-trials` parameter to help with distributed optimization. (@ernestum)
+- Added `vec_env_wrapper` support in the config (works the same as `env_wrapper`)
+- Added Huggingface hub integration
+- Added `RecurrentPPO` support (aka `ppo_lstm`)
+- Added auto-download of "official" sb3 models from the hub
### Bug fixes
+- Fix `Reacher-v3` name in PPO hyperparameter file
+- Pinned ale-py==0.7.4 until a new SB3 version is released
+- Fix enjoy / record videos with LSTM policy
### Documentation
### Other
+- When the pruner is set to `"none"`, use `NopPruner` instead of a diverted `MedianPruner` (@qgallouedec)
## Release 1.5.0 (2022-03-25)
diff --git a/README.md b/README.md
index 9b760fa3a..dd03e49e0 100644
--- a/README.md
+++ b/README.md
@@ -76,7 +76,7 @@ python scripts/plot_train.py -a her -e Fetch -y success -f rl-trained-agents/ -w
Plot evaluation reward curve for TQC, SAC and TD3 on the HalfCheetah and Ant PyBullet environments:
```
-python scripts/all_plots.py -a sac td3 tqc --env HalfCheetah Ant -f rl-trained-agents/
+python3 scripts/all_plots.py -a sac td3 tqc --env HalfCheetahBullet AntBullet -f rl-trained-agents/
```
## Plot with the rliable library
@@ -139,9 +139,46 @@ To load a checkpoint (here the checkpoint name is `rl_model_10000_steps.zip`):
python enjoy.py --algo algo_name --env env_id -f logs/ --exp-id 1 --load-checkpoint 10000
```
-To load the latest checkpoint:
+## Offline Training with d3rlpy
+
+Train an agent offline from a saved replay buffer (here with behavioral cloning, selected via `--offline-algo bc`):
+
+```
+python train.py --algo human --env donkey-generated-track-v0 --env-kwargs frame_skip:1 throttle_max:2.0 throttle_min:0.0 steering_min:-0.5 steering_max:0.5 level:6 max_cte:100000 test_mode:False --num-threads 2 --eval-freq -1 \
+ -b logs/human/donkey-generated-track-v0_1/replay_buffer.pkl \
+ --pretrain-params batch_size:256 n_eval_episodes:1 n_epochs:20 n_iterations:20 \
+ --offline-algo bc
+```
+
+Supported `--pretrain-params` keys (the second argument of `get` is the default used when a key is omitted):
+
+```python
+n_iterations = args.pretrain_params.get("n_iterations", 10)
+n_epochs = args.pretrain_params.get("n_epochs", 1)
+q_func_type = args.pretrain_params.get("q_func_type")
+batch_size = args.pretrain_params.get("batch_size", 512)
+# n_action_samples = args.pretrain_params.get("n_action_samples", 1)
+n_eval_episodes = args.pretrain_params.get("n_eval_episodes", 5)
+add_to_buffer = args.pretrain_params.get("add_to_buffer", False)
+deterministic = args.pretrain_params.get("deterministic", True)
+```
+
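+The saved replay buffer is a pickled SB3 `ReplayBuffer`. As a minimal sketch (not the exact
+`train.py` code path) of how such a buffer can be fit with behavioral cloning, assuming
+d3rlpy 1.x and a single-env buffer; the path below mirrors the `-b` argument above:
+
+```python
+import pickle
+
+from d3rlpy.algos import BC
+from d3rlpy.dataset import MDPDataset
+
+with open("logs/human/donkey-generated-track-v0_1/replay_buffer.pkl", "rb") as f:
+    replay_buffer = pickle.load(f)  # stable_baselines3.common.buffers.ReplayBuffer
+
+# Only the filled part of the buffer holds valid transitions
+n = replay_buffer.buffer_size if replay_buffer.full else replay_buffer.pos
+
+# Flatten the (buffer_size, n_envs, ...) storage into transition arrays (n_envs == 1 here)
+dataset = MDPDataset(
+    observations=replay_buffer.observations[:n, 0],
+    actions=replay_buffer.actions[:n, 0],
+    rewards=replay_buffer.rewards[:n, 0],
+    terminals=replay_buffer.dones[:n, 0],
+)
+
+bc = BC(batch_size=256)
+bc.fit(dataset, n_epochs=20)
+```
+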
+## Human driving
+
+Drive the car manually to collect data, for instance to record the replay buffer used for offline training (see above):
+
+```
+python train.py --algo human --env donkey-generated-track-v0 --env-kwargs frame_skip:1 throttle_max:1.0 throttle_min:-1.0 steering_min:-1 steering_max:1 level:6 max_cte:100000 --num-threads 2 --eval-freq -1
+```
+
+## Huggingface Hub Integration
+
+Upload a model to the hub (same syntax as for `enjoy.py`):
+```
+python -m utils.push_to_hub --algo ppo --env CartPole-v1 -f logs/ -orga sb3 -m "Initial commit"
```
-python enjoy.py --algo algo_name --env env_id -f logs/ --exp-id 1 --load-last-checkpoint
+You can choose a custom repo name (default: `{algo}-{env_id}`) by passing the `--repo-name` argument.
+
+Download a model from the hub:
+```
+python -m utils.load_from_hub --algo ppo --env CartPole-v1 -f logs/ -orga sb3
```
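+
+The downloaded model is a standard SB3 zip file, so it can also be loaded directly in
+Python. A minimal sketch, assuming the `sb3/ppo-CartPole-v1` repo follows the
+`{algo}-{env_id}` naming used above:
+
+```python
+from huggingface_sb3 import load_from_hub
+from stable_baselines3 import PPO
+
+# Download the checkpoint from the hub (cached locally) and return its path
+checkpoint = load_from_hub(repo_id="sb3/ppo-CartPole-v1", filename="ppo-CartPole-v1.zip")
+model = PPO.load(checkpoint)
+```
+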
## Hyperparameter yaml syntax
@@ -243,6 +280,17 @@ env_wrapper:
Note that you can easily specify parameters too.
+## VecEnvWrapper
+
+You can specify which `VecEnvWrapper` to use in the config, the same way as for env wrappers (see above), using the `vec_env_wrapper` key.
+
+For instance:
+```yaml
+vec_env_wrapper: stable_baselines3.common.vec_env.VecMonitor
+```
+
+Note: `VecNormalize` is handled separately via the `normalize` keyword, and `VecFrameStack` has its own dedicated keyword, `frame_stack`.
+
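+Since the syntax mirrors `env_wrapper`, keyword arguments should use the same nested form.
+A hedged sketch with `VecCheckNan` (which takes a `raise_exception` argument):
+
+```yaml
+vec_env_wrapper:
+  - stable_baselines3.common.vec_env.VecCheckNan:
+      raise_exception: True
+```
+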
## Callbacks
Following the same syntax as env wrappers, you can also add custom callbacks to use during training.
@@ -311,6 +359,8 @@ python -m utils.record_training --algo ppo --env CartPole-v1 -n 1000 -f logs --d
Final performance of the trained agents can be found in [`benchmark.md`](./benchmark.md). To compute them, simply run `python -m utils.benchmark`.
+A list of the trained agents, along with videos, can be found on our Huggingface page: https://huggingface.co/sb3
+
*NOTE: this is not a quantitative benchmark as it corresponds to only one run (cf [issue #38](https://github.com/araffin/rl-baselines-zoo/issues/38)). This benchmark is meant to check algorithm (maximal) performance, find potential bugs and also allow users to have access to pretrained agents.*
### Atari Games
diff --git a/benchmark.md b/benchmark.md
index d50f5c1d7..61c5239ac 100644
--- a/benchmark.md
+++ b/benchmark.md
@@ -8,6 +8,9 @@ during this evaluation.
It uses the deterministic policy except for Atari games.
+You can view each model card (which includes the video and hyperparameters)
+on our Huggingface page: https://huggingface.co/sb3
+
*NOTE: this is not a quantitative benchmark as it corresponds to only one run
(cf [issue #38](https://github.com/araffin/rl-baselines-zoo/issues/38)).
This benchmark is meant to check algorithm (maximal) performance, find potential bugs
@@ -15,181 +18,185 @@ and also allow users to have access to pretrained agents.*
"M" stands for Million (1e6)
-|algo | env_id |mean_reward|std_reward|n_timesteps|eval_timesteps|eval_episodes|
-|-----|---------------------------|----------:|---------:|-----------|-------------:|------------:|
-|a2c |Acrobot-v1 | -83.353| 17.213|500k | 149979| 1778|
-|a2c |AntBulletEnv-v0 | 2497.147| 37.359|2M | 150000| 150|
-|a2c |AsteroidsNoFrameskip-v4 | 1286.550| 423.750|10M | 614138| 258|
-|a2c |BeamRiderNoFrameskip-v4 | 2890.298| 1379.137|10M | 591104| 47|
-|a2c |BipedalWalker-v3 | 299.754| 23.459|5M | 149287| 208|
-|a2c |BipedalWalkerHardcore-v3 | 96.171| 122.943|200M | 149704| 113|
-|a2c |BreakoutNoFrameskip-v4 | 279.793| 122.177|10M | 604115| 82|
-|a2c |CartPole-v1 | 500.000| 0.000|500k | 150000| 300|
-|a2c |EnduroNoFrameskip-v4 | 0.000| 0.000|10M | 599040| 45|
-|a2c |HalfCheetah-v3 | 3041.174| 157.265|1M | 150000| 150|
-|a2c |HalfCheetahBulletEnv-v0 | 2107.384| 36.008|2M | 150000| 150|
-|a2c |Hopper-v3 | 733.454| 376.574|1M | 149987| 580|
-|a2c |HopperBulletEnv-v0 | 815.355| 313.798|2M | 149541| 254|
-|a2c |LunarLander-v2 | 155.751| 80.419|200k | 149443| 297|
-|a2c |LunarLanderContinuous-v2 | 84.225| 145.906|5M | 149305| 256|
-|a2c |MountainCar-v0 | -111.263| 24.087|1M | 149982| 1348|
-|a2c |MountainCarContinuous-v0 | 91.166| 0.255|100k | 149923| 1659|
-|a2c |Pendulum-v1 | -162.965| 103.210|1M | 150000| 750|
-|a2c |PongNoFrameskip-v4 | 17.292| 3.214|10M | 594910| 65|
-|a2c |QbertNoFrameskip-v4 | 3882.345| 1223.327|10M | 610670| 194|
-|a2c |ReacherBulletEnv-v0 | 14.968| 10.978|2M | 150000| 1000|
-|a2c |RoadRunnerNoFrameskip-v4 | 31671.512| 6364.085|10M | 606710| 172|
-|a2c |SeaquestNoFrameskip-v4 | 1721.493| 105.339|10M | 599691| 67|
-|a2c |SpaceInvadersNoFrameskip-v4| 627.160| 201.974|10M | 604848| 162|
-|a2c |Swimmer-v3 | 200.627| 2.544|1M | 150000| 150|
-|a2c |Walker2DBulletEnv-v0 | 858.209| 333.116|2M | 149156| 173|
-|ars |Acrobot-v1 | -82.884| 23.825|500k | 149985| 1788|
-|ars |Ant-v3 | 2333.773| 20.597|75M | 150000| 150|
-|ars |CartPole-v1 | 500.000| 0.000|50k | 150000| 300|
-|ars |HalfCheetah-v3 | 4815.192| 1340.752|12M | 150000| 150|
-|ars |Hopper-v3 | 3343.919| 5.730|7M | 150000| 150|
-|ars |LunarLanderContinuous-v2 | 167.959| 147.071|2M | 149883| 562|
-|ars |MountainCar-v0 | -122.000| 33.456|500k | 149938| 1229|
-|ars |MountainCarContinuous-v0 | 96.672| 0.784|500k | 149990| 621|
-|ars |Pendulum-v1 | -212.540| 160.444|2M | 150000| 750|
-|ars |Swimmer-v3 | 355.267| 12.796|2M | 150000| 150|
-|ars |Walker2d-v3 | 2993.582| 166.289|75M | 149821| 152|
-|ddpg |AntBulletEnv-v0 | 2399.147| 75.410|1M | 150000| 150|
-|ddpg |BipedalWalker-v3 | 197.486| 141.580|1M | 149237| 227|
-|ddpg |HalfCheetahBulletEnv-v0 | 2078.325| 208.379|1M | 150000| 150|
-|ddpg |HopperBulletEnv-v0 | 1157.065| 448.695|1M | 149565| 346|
-|ddpg |LunarLanderContinuous-v2 | 230.217| 92.372|300k | 149862| 556|
-|ddpg |MountainCarContinuous-v0 | 93.512| 0.048|300k | 149965| 2260|
-|ddpg |Pendulum-v1 | -152.099| 94.282|20k | 150000| 750|
-|ddpg |ReacherBulletEnv-v0 | 15.582| 9.606|300k | 150000| 1000|
-|ddpg |Walker2DBulletEnv-v0 | 1387.591| 736.955|1M | 149051| 208|
-|dqn |Acrobot-v1 | -76.639| 11.752|100k | 149998| 1932|
-|dqn |AsteroidsNoFrameskip-v4 | 782.687| 259.247|10M | 607962| 134|
-|dqn |BeamRiderNoFrameskip-v4 | 4295.946| 1790.458|10M | 600832| 37|
-|dqn |BreakoutNoFrameskip-v4 | 358.327| 61.981|10M | 601461| 55|
-|dqn |CartPole-v1 | 500.000| 0.000|50k | 150000| 300|
-|dqn |EnduroNoFrameskip-v4 | 830.929| 194.544|10M | 599040| 14|
-|dqn |LunarLander-v2 | 154.382| 79.241|100k | 149373| 200|
-|dqn |MountainCar-v0 | -100.849| 9.925|120k | 149962| 1487|
-|dqn |PongNoFrameskip-v4 | 20.602| 0.613|10M | 598998| 88|
-|dqn |QbertNoFrameskip-v4 | 9496.774| 5399.633|10M | 605844| 124|
-|dqn |RoadRunnerNoFrameskip-v4 | 40396.350| 7069.131|10M | 603257| 137|
-|dqn |SeaquestNoFrameskip-v4 | 2000.290| 606.644|10M | 599505| 69|
-|dqn |SpaceInvadersNoFrameskip-v4| 622.742| 201.564|10M | 604311| 155|
-|ppo |Acrobot-v1 | -73.506| 18.201|1M | 149979| 2013|
-|ppo |Ant-v3 | 1327.158| 451.577|1M | 149572| 175|
-|ppo |AntBulletEnv-v0 | 2865.922| 56.468|2M | 150000| 150|
-|ppo |AsteroidsNoFrameskip-v4 | 2156.174| 744.640|10M | 602092| 149|
-|ppo |BeamRiderNoFrameskip-v4 | 3397.000| 1662.368|10M | 598926| 46|
-|ppo |BipedalWalker-v3 | 213.299| 129.490|5M | 149826| 233|
-|ppo |BipedalWalkerHardcore-v3 | 122.374| 117.605|100M | 148036| 105|
-|ppo |BreakoutNoFrameskip-v4 | 398.033| 33.328|10M | 600418| 60|
-|ppo |CartPole-v1 | 500.000| 0.000|100k | 150000| 300|
-|ppo |EnduroNoFrameskip-v4 | 996.364| 176.090|10M | 572416| 11|
-|ppo |HalfCheetah-v3 | 5819.099| 663.530|1M | 150000| 150|
-|ppo |HalfCheetahBulletEnv-v0 | 2924.721| 64.465|2M | 150000| 150|
-|ppo |Hopper-v3 | 2410.435| 10.026|1M | 150000| 150|
-|ppo |HopperBulletEnv-v0 | 2575.054| 223.301|2M | 149094| 152|
-|ppo |LunarLander-v2 | 242.119| 31.823|1M | 149636| 369|
-|ppo |LunarLanderContinuous-v2 | 270.863| 32.072|1M | 149956| 526|
-|ppo |MountainCar-v0 | -110.423| 19.473|1M | 149954| 1358|
-|ppo |MountainCarContinuous-v0 | 88.343| 2.572|20k | 149983| 633|
-|ppo |Pendulum-v1 | -172.225| 104.159|100k | 150000| 750|
-|ppo |PongNoFrameskip-v4 | 20.989| 0.105|10M | 599902| 90|
-|ppo |QbertNoFrameskip-v4 | 15627.108| 3313.538|10M | 600248| 83|
-|ppo |ReacherBulletEnv-v0 | 17.091| 11.048|1M | 150000| 1000|
-|ppo |RoadRunnerNoFrameskip-v4 | 40680.645| 6675.058|10M | 605786| 155|
-|ppo |SeaquestNoFrameskip-v4 | 1783.636| 34.096|10M | 598243| 66|
-|ppo |SpaceInvadersNoFrameskip-v4| 960.331| 425.355|10M | 603771| 136|
-|ppo |Swimmer-v3 | 281.561| 9.671|1M | 150000| 150|
-|ppo |Walker2DBulletEnv-v0 | 2109.992| 13.899|2M | 150000| 150|
-|ppo |Walker2d-v3 | 3478.798| 821.708|1M | 149343| 171|
-|qrdqn|Acrobot-v1 | -69.135| 9.967|100k | 149949| 2138|
-|qrdqn|AsteroidsNoFrameskip-v4 | 2185.303| 1097.172|10M | 599784| 66|
-|qrdqn|BeamRiderNoFrameskip-v4 | 17122.941| 10769.997|10M | 596483| 17|
-|qrdqn|BreakoutNoFrameskip-v4 | 393.600| 79.828|10M | 579711| 40|
-|qrdqn|CartPole-v1 | 500.000| 0.000|50k | 150000| 300|
-|qrdqn|EnduroNoFrameskip-v4 | 3231.200| 1311.801|10M | 585728| 5|
-|qrdqn|LunarLander-v2 | 70.236| 225.491|100k | 149957| 522|
-|qrdqn|MountainCar-v0 | -106.042| 15.536|120k | 149943| 1414|
-|qrdqn|PongNoFrameskip-v4 | 20.492| 0.687|10M | 597443| 63|
-|qrdqn|QbertNoFrameskip-v4 | 14799.728| 2917.629|10M | 600773| 92|
-|qrdqn|RoadRunnerNoFrameskip-v4 | 42325.424| 8361.161|10M | 591016| 59|
-|qrdqn|SeaquestNoFrameskip-v4 | 2557.576| 76.951|10M | 596275| 66|
-|qrdqn|SpaceInvadersNoFrameskip-v4| 1899.928| 823.488|10M | 597218| 69|
-|sac |Ant-v3 | 4615.791| 1354.111|1M | 149074| 165|
-|sac |AntBulletEnv-v0 | 3073.114| 175.148|1M | 150000| 150|
-|sac |BipedalWalker-v3 | 297.668| 33.060|500k | 149530| 136|
-|sac |BipedalWalkerHardcore-v3 | 4.423| 103.910|10M | 149794| 88|
-|sac |HalfCheetah-v3 | 9535.451| 100.470|1M | 150000| 150|
-|sac |HalfCheetahBulletEnv-v0 | 2792.170| 12.088|1M | 150000| 150|
-|sac |Hopper-v3 | 2325.547| 1129.676|1M | 149841| 236|
-|sac |HopperBulletEnv-v0 | 2603.494| 164.322|1M | 149724| 151|
-|sac |Humanoid-v3 | 6232.287| 279.885|2M | 149460| 150|
-|sac |LunarLanderContinuous-v2 | 260.390| 65.467|500k | 149634| 672|
-|sac |MountainCarContinuous-v0 | 94.679| 1.134|50k | 149966| 1443|
-|sac |Pendulum-v1 | -156.995| 88.714|20k | 150000| 750|
-|sac |ReacherBulletEnv-v0 | 18.062| 9.729|300k | 150000| 1000|
-|sac |Swimmer-v3 | 345.568| 3.084|1M | 150000| 150|
-|sac |Walker2DBulletEnv-v0 | 2292.266| 13.970|1M | 149983| 150|
-|sac |Walker2d-v3 | 3863.203| 254.347|1M | 149309| 150|
-|td3 |Ant-v3 | 5813.274| 589.773|1M | 149393| 151|
-|td3 |AntBulletEnv-v0 | 3300.026| 54.640|1M | 150000| 150|
-|td3 |BipedalWalker-v3 | 305.990| 56.886|1M | 149999| 224|
-|td3 |BipedalWalkerHardcore-v3 | -98.116| 16.087|10M | 150000| 75|
-|td3 |HalfCheetah-v3 | 9655.666| 969.916|1M | 150000| 150|
-|td3 |HalfCheetahBulletEnv-v0 | 2821.641| 19.722|1M | 150000| 150|
-|td3 |Hopper-v3 | 3606.390| 4.027|1M | 150000| 150|
-|td3 |HopperBulletEnv-v0 | 2681.609| 27.806|1M | 149486| 150|
-|td3 |Humanoid-v3 | 5566.687| 14.544|2M | 150000| 150|
-|td3 |LunarLanderContinuous-v2 | 207.451| 67.562|300k | 149488| 337|
-|td3 |MountainCarContinuous-v0 | 93.483| 0.075|300k | 149976| 2275|
-|td3 |Pendulum-v1 | -151.855| 90.227|20k | 150000| 750|
-|td3 |ReacherBulletEnv-v0 | 17.114| 9.750|300k | 150000| 1000|
-|td3 |Swimmer-v3 | 359.127| 1.244|1M | 150000| 150|
-|td3 |Walker2DBulletEnv-v0 | 2213.672| 230.558|1M | 149800| 152|
-|td3 |Walker2d-v3 | 4717.823| 46.303|1M | 150000| 150|
-|tqc |Ant-v3 | 3339.362| 1969.906|1M | 149583| 202|
-|tqc |AntBulletEnv-v0 | 3456.717| 248.733|1M | 150000| 150|
-|tqc |BipedalWalker-v3 | 329.808| 45.083|500k | 149682| 254|
-|tqc |BipedalWalkerHardcore-v3 | 235.226| 110.569|2M | 149032| 131|
-|tqc |FetchPickAndPlace-v1 | -9.331| 6.850|1M | 150000| 3000|
-|tqc |FetchPush-v1 | -8.799| 5.438|1M | 150000| 3000|
-|tqc |FetchReach-v1 | -1.659| 0.873|20k | 150000| 3000|
-|tqc |FetchSlide-v1 | -29.210| 11.387|3M | 150000| 3000|
-|tqc |HalfCheetah-v3 | 12089.939| 127.440|1M | 150000| 150|
-|tqc |HalfCheetahBulletEnv-v0 | 3675.299| 17.681|1M | 150000| 150|
-|tqc |Hopper-v3 | 3754.199| 8.276|1M | 150000| 150|
-|tqc |HopperBulletEnv-v0 | 2662.373| 206.210|1M | 149881| 151|
-|tqc |Humanoid-v3 | 7239.320| 1647.498|2M | 149508| 165|
-|tqc |LunarLanderContinuous-v2 | 277.956| 25.466|500k | 149928| 706|
-|tqc |MountainCarContinuous-v0 | 63.641| 45.259|50k | 149796| 186|
-|tqc |PandaPickAndPlace-v1 | -8.024| 6.674|1M | 150000| 3000|
-|tqc |PandaPush-v1 | -6.405| 6.400|1M | 150000| 3000|
-|tqc |PandaReach-v1 | -1.768| 0.858|20k | 150000| 3000|
-|tqc |PandaSlide-v1 | -27.497| 9.868|3M | 150000| 3000|
-|tqc |PandaStack-v1 | -96.915| 17.240|1M | 150000| 1500|
-|tqc |Pendulum-v1 | -151.340| 87.893|20k | 150000| 750|
-|tqc |ReacherBulletEnv-v0 | 18.255| 9.543|300k | 150000| 1000|
-|tqc |Swimmer-v3 | 339.423| 1.486|1M | 150000| 150|
-|tqc |Walker2DBulletEnv-v0 | 2508.934| 614.624|1M | 149572| 159|
-|tqc |Walker2d-v3 | 4380.720| 500.489|1M | 149606| 152|
-|tqc |parking-v0 | -6.762| 2.690|100k | 149983| 7528|
-|trpo |Acrobot-v1 | -83.114| 18.648|100k | 149976| 1783|
-|trpo |Ant-v3 | 4982.301| 663.761|1M | 149909| 153|
-|trpo |AntBulletEnv-v0 | 2560.621| 52.064|2M | 150000| 150|
-|trpo |BipedalWalker-v3 | 182.339| 145.570|1M | 148440| 148|
-|trpo |CartPole-v1 | 500.000| 0.000|100k | 150000| 300|
-|trpo |HalfCheetah-v3 | 1785.476| 68.672|1M | 150000| 150|
-|trpo |HalfCheetahBulletEnv-v0 | 2758.752| 327.032|2M | 150000| 150|
-|trpo |Hopper-v3 | 3618.386| 356.768|1M | 149575| 152|
-|trpo |HopperBulletEnv-v0 | 2565.416| 410.298|1M | 149640| 154|
-|trpo |LunarLander-v2 | 133.166| 112.173|200k | 149088| 230|
-|trpo |LunarLanderContinuous-v2 | 262.387| 21.428|200k | 149925| 501|
-|trpo |MountainCar-v0 | -107.278| 13.231|100k | 149974| 1398|
-|trpo |MountainCarContinuous-v0 | 92.489| 0.355|50k | 149971| 1732|
-|trpo |Pendulum-v1 | -174.631| 127.577|100k | 150000| 750|
-|trpo |ReacherBulletEnv-v0 | 14.741| 11.559|300k | 150000| 1000|
-|trpo |Swimmer-v3 | 365.663| 2.087|1M | 150000| 150|
-|trpo |Walker2DBulletEnv-v0 | 1483.467| 823.468|2M | 149860| 197|
-|trpo |Walker2d-v3 | 4933.148| 1452.538|1M | 149054| 163|
+| algo | env_id |mean_reward|std_reward|n_timesteps|eval_timesteps|eval_episodes|
+|--------|-----------------------------|----------:|---------:|-----------|-------------:|------------:|
+|a2c |Acrobot-v1 | -83.353| 17.213|500k | 149979| 1778|
+|a2c |AntBulletEnv-v0 | 2497.147| 37.359|2M | 150000| 150|
+|a2c |AsteroidsNoFrameskip-v4 | 1286.550| 423.750|10M | 614138| 258|
+|a2c |BeamRiderNoFrameskip-v4 | 2890.298| 1379.137|10M | 591104| 47|
+|a2c |BipedalWalker-v3 | 299.754| 23.459|5M | 149287| 208|
+|a2c |BipedalWalkerHardcore-v3 | 96.171| 122.943|200M | 149704| 113|
+|a2c |BreakoutNoFrameskip-v4 | 279.793| 122.177|10M | 604115| 82|
+|a2c |CartPole-v1 | 500.000| 0.000|500k | 150000| 300|
+|a2c |EnduroNoFrameskip-v4 | 0.000| 0.000|10M | 599040| 45|
+|a2c |HalfCheetah-v3 | 3041.174| 157.265|1M | 150000| 150|
+|a2c |HalfCheetahBulletEnv-v0 | 2107.384| 36.008|2M | 150000| 150|
+|a2c |Hopper-v3 | 733.454| 376.574|1M | 149987| 580|
+|a2c |HopperBulletEnv-v0 | 815.355| 313.798|2M | 149541| 254|
+|a2c |LunarLander-v2 | 155.751| 80.419|200k | 149443| 297|
+|a2c |LunarLanderContinuous-v2 | 84.225| 145.906|5M | 149305| 256|
+|a2c |MountainCar-v0 | -111.263| 24.087|1M | 149982| 1348|
+|a2c |MountainCarContinuous-v0 | 91.166| 0.255|100k | 149923| 1659|
+|a2c |Pendulum-v1 | -162.965| 103.210|1M | 150000| 750|
+|a2c |PongNoFrameskip-v4 | 17.292| 3.214|10M | 594910| 65|
+|a2c |QbertNoFrameskip-v4 | 3882.345| 1223.327|10M | 610670| 194|
+|a2c |ReacherBulletEnv-v0 | 14.968| 10.978|2M | 150000| 1000|
+|a2c |RoadRunnerNoFrameskip-v4 | 31671.512| 6364.085|10M | 606710| 172|
+|a2c |SeaquestNoFrameskip-v4 | 1721.493| 105.339|10M | 599691| 67|
+|a2c |SpaceInvadersNoFrameskip-v4 | 627.160| 201.974|10M | 604848| 162|
+|a2c |Swimmer-v3 | 200.627| 2.544|1M | 150000| 150|
+|a2c |Walker2DBulletEnv-v0 | 858.209| 333.116|2M | 149156| 173|
+|ars |Acrobot-v1 | -82.884| 23.825|500k | 149985| 1788|
+|ars |Ant-v3 | 2333.773| 20.597|75M | 150000| 150|
+|ars |CartPole-v1 | 500.000| 0.000|50k | 150000| 300|
+|ars |HalfCheetah-v3 | 4815.192| 1340.752|12M | 150000| 150|
+|ars |Hopper-v3 | 3343.919| 5.730|7M | 150000| 150|
+|ars |LunarLanderContinuous-v2 | 167.959| 147.071|2M | 149883| 562|
+|ars |MountainCar-v0 | -122.000| 33.456|500k | 149938| 1229|
+|ars |MountainCarContinuous-v0 | 96.672| 0.784|500k | 149990| 621|
+|ars |Pendulum-v1 | -212.540| 160.444|2M | 150000| 750|
+|ars |Swimmer-v3 | 355.267| 12.796|2M | 150000| 150|
+|ars |Walker2d-v3 | 2993.582| 166.289|75M | 149821| 152|
+|ddpg |AntBulletEnv-v0 | 2399.147| 75.410|1M | 150000| 150|
+|ddpg |BipedalWalker-v3 | 197.486| 141.580|1M | 149237| 227|
+|ddpg |HalfCheetahBulletEnv-v0 | 2078.325| 208.379|1M | 150000| 150|
+|ddpg |HopperBulletEnv-v0 | 1157.065| 448.695|1M | 149565| 346|
+|ddpg |LunarLanderContinuous-v2 | 230.217| 92.372|300k | 149862| 556|
+|ddpg |MountainCarContinuous-v0 | 93.512| 0.048|300k | 149965| 2260|
+|ddpg |Pendulum-v1 | -152.099| 94.282|20k | 150000| 750|
+|ddpg |ReacherBulletEnv-v0 | 15.582| 9.606|300k | 150000| 1000|
+|ddpg |Walker2DBulletEnv-v0 | 1387.591| 736.955|1M | 149051| 208|
+|dqn |Acrobot-v1 | -76.639| 11.752|100k | 149998| 1932|
+|dqn |AsteroidsNoFrameskip-v4 | 782.687| 259.247|10M | 607962| 134|
+|dqn |BeamRiderNoFrameskip-v4 | 4295.946| 1790.458|10M | 600832| 37|
+|dqn |BreakoutNoFrameskip-v4 | 358.327| 61.981|10M | 601461| 55|
+|dqn |CartPole-v1 | 500.000| 0.000|50k | 150000| 300|
+|dqn |EnduroNoFrameskip-v4 | 830.929| 194.544|10M | 599040| 14|
+|dqn |LunarLander-v2 | 154.382| 79.241|100k | 149373| 200|
+|dqn |MountainCar-v0 | -100.849| 9.925|120k | 149962| 1487|
+|dqn |PongNoFrameskip-v4 | 20.602| 0.613|10M | 598998| 88|
+|dqn |QbertNoFrameskip-v4 | 9496.774| 5399.633|10M | 605844| 124|
+|dqn |RoadRunnerNoFrameskip-v4 | 40396.350| 7069.131|10M | 603257| 137|
+|dqn |SeaquestNoFrameskip-v4 | 2000.290| 606.644|10M | 599505| 69|
+|dqn |SpaceInvadersNoFrameskip-v4 | 622.742| 201.564|10M | 604311| 155|
+|ppo |Acrobot-v1 | -73.506| 18.201|1M | 149979| 2013|
+|ppo |Ant-v3 | 1327.158| 451.577|1M | 149572| 175|
+|ppo |AntBulletEnv-v0 | 2865.922| 56.468|2M | 150000| 150|
+|ppo |AsteroidsNoFrameskip-v4 | 2156.174| 744.640|10M | 602092| 149|
+|ppo |BeamRiderNoFrameskip-v4 | 3397.000| 1662.368|10M | 598926| 46|
+|ppo |BipedalWalker-v3 | 287.939| 2.448|5M | 149589| 123|
+|ppo |BipedalWalkerHardcore-v3 | 122.374| 117.605|100M | 148036| 105|
+|ppo |BreakoutNoFrameskip-v4 | 398.033| 33.328|10M | 600418| 60|
+|ppo |CartPole-v1 | 500.000| 0.000|100k | 150000| 300|
+|ppo |EnduroNoFrameskip-v4 | 996.364| 176.090|10M | 572416| 11|
+|ppo |HalfCheetah-v3 | 5819.099| 663.530|1M | 150000| 150|
+|ppo |HalfCheetahBulletEnv-v0 | 2924.721| 64.465|2M | 150000| 150|
+|ppo |Hopper-v3 | 2410.435| 10.026|1M | 150000| 150|
+|ppo |HopperBulletEnv-v0 | 2575.054| 223.301|2M | 149094| 152|
+|ppo |LunarLander-v2 | 242.119| 31.823|1M | 149636| 369|
+|ppo |LunarLanderContinuous-v2 | 270.863| 32.072|1M | 149956| 526|
+|ppo |MountainCar-v0 | -110.423| 19.473|1M | 149954| 1358|
+|ppo |MountainCarContinuous-v0 | 88.343| 2.572|20k | 149983| 633|
+|ppo |Pendulum-v1 | -172.225| 104.159|100k | 150000| 750|
+|ppo |PongNoFrameskip-v4 | 20.989| 0.105|10M | 599902| 90|
+|ppo |QbertNoFrameskip-v4 | 15627.108| 3313.538|10M | 600248| 83|
+|ppo |ReacherBulletEnv-v0 | 17.091| 11.048|1M | 150000| 1000|
+|ppo |RoadRunnerNoFrameskip-v4 | 40680.645| 6675.058|10M | 605786| 155|
+|ppo |SeaquestNoFrameskip-v4 | 1783.636| 34.096|10M | 598243| 66|
+|ppo |SpaceInvadersNoFrameskip-v4 | 960.331| 425.355|10M | 603771| 136|
+|ppo |Swimmer-v3 | 281.561| 9.671|1M | 150000| 150|
+|ppo |Walker2DBulletEnv-v0 | 2109.992| 13.899|2M | 150000| 150|
+|ppo |Walker2d-v3 | 3478.798| 821.708|1M | 149343| 171|
+|ppo_lstm|CarRacing-v0 | 862.549| 97.342|4M | 149588| 156|
+|ppo_lstm|CartPoleNoVel-v1 | 500.000| 0.000|100k | 150000| 300|
+|ppo_lstm|MountainCarContinuousNoVel-v0| 91.469| 1.776|300k | 149882| 1340|
+|ppo_lstm|PendulumNoVel-v1 | -217.933| 140.094|100k | 150000| 750|
+|qrdqn |Acrobot-v1 | -69.135| 9.967|100k | 149949| 2138|
+|qrdqn |AsteroidsNoFrameskip-v4 | 2185.303| 1097.172|10M | 599784| 66|
+|qrdqn |BeamRiderNoFrameskip-v4 | 17122.941| 10769.997|10M | 596483| 17|
+|qrdqn |BreakoutNoFrameskip-v4 | 393.600| 79.828|10M | 579711| 40|
+|qrdqn |CartPole-v1 | 500.000| 0.000|50k | 150000| 300|
+|qrdqn |EnduroNoFrameskip-v4 | 3231.200| 1311.801|10M | 585728| 5|
+|qrdqn |LunarLander-v2 | 70.236| 225.491|100k | 149957| 522|
+|qrdqn |MountainCar-v0 | -106.042| 15.536|120k | 149943| 1414|
+|qrdqn |PongNoFrameskip-v4 | 20.492| 0.687|10M | 597443| 63|
+|qrdqn |QbertNoFrameskip-v4 | 14799.728| 2917.629|10M | 600773| 92|
+|qrdqn |RoadRunnerNoFrameskip-v4 | 42325.424| 8361.161|10M | 591016| 59|
+|qrdqn |SeaquestNoFrameskip-v4 | 2557.576| 76.951|10M | 596275| 66|
+|qrdqn |SpaceInvadersNoFrameskip-v4 | 1899.928| 823.488|10M | 597218| 69|
+|sac |Ant-v3 | 4615.791| 1354.111|1M | 149074| 165|
+|sac |AntBulletEnv-v0 | 3073.114| 175.148|1M | 150000| 150|
+|sac |BipedalWalker-v3 | 297.668| 33.060|500k | 149530| 136|
+|sac |BipedalWalkerHardcore-v3 | 4.423| 103.910|10M | 149794| 88|
+|sac |HalfCheetah-v3 | 9535.451| 100.470|1M | 150000| 150|
+|sac |HalfCheetahBulletEnv-v0 | 2792.170| 12.088|1M | 150000| 150|
+|sac |Hopper-v3 | 2325.547| 1129.676|1M | 149841| 236|
+|sac |HopperBulletEnv-v0 | 2603.494| 164.322|1M | 149724| 151|
+|sac |Humanoid-v3 | 6232.287| 279.885|2M | 149460| 150|
+|sac |LunarLanderContinuous-v2 | 260.390| 65.467|500k | 149634| 672|
+|sac |MountainCarContinuous-v0 | 94.679| 1.134|50k | 149966| 1443|
+|sac |Pendulum-v1 | -156.995| 88.714|20k | 150000| 750|
+|sac |ReacherBulletEnv-v0 | 18.062| 9.729|300k | 150000| 1000|
+|sac |Swimmer-v3 | 345.568| 3.084|1M | 150000| 150|
+|sac |Walker2DBulletEnv-v0 | 2292.266| 13.970|1M | 149983| 150|
+|sac |Walker2d-v3 | 3863.203| 254.347|1M | 149309| 150|
+|td3 |Ant-v3 | 5813.274| 589.773|1M | 149393| 151|
+|td3 |AntBulletEnv-v0 | 3300.026| 54.640|1M | 150000| 150|
+|td3 |BipedalWalker-v3 | 305.990| 56.886|1M | 149999| 224|
+|td3 |BipedalWalkerHardcore-v3 | -98.116| 16.087|10M | 150000| 75|
+|td3 |HalfCheetah-v3 | 9655.666| 969.916|1M | 150000| 150|
+|td3 |HalfCheetahBulletEnv-v0 | 2821.641| 19.722|1M | 150000| 150|
+|td3 |Hopper-v3 | 3606.390| 4.027|1M | 150000| 150|
+|td3 |HopperBulletEnv-v0 | 2681.609| 27.806|1M | 149486| 150|
+|td3 |Humanoid-v3 | 5566.687| 14.544|2M | 150000| 150|
+|td3 |LunarLanderContinuous-v2 | 207.451| 67.562|300k | 149488| 337|
+|td3 |MountainCarContinuous-v0 | 93.483| 0.075|300k | 149976| 2275|
+|td3 |Pendulum-v1 | -151.855| 90.227|20k | 150000| 750|
+|td3 |ReacherBulletEnv-v0 | 17.114| 9.750|300k | 150000| 1000|
+|td3 |Swimmer-v3 | 359.127| 1.244|1M | 150000| 150|
+|td3 |Walker2DBulletEnv-v0 | 2213.672| 230.558|1M | 149800| 152|
+|td3 |Walker2d-v3 | 4717.823| 46.303|1M | 150000| 150|
+|tqc |Ant-v3 | 3339.362| 1969.906|1M | 149583| 202|
+|tqc |AntBulletEnv-v0 | 3456.717| 248.733|1M | 150000| 150|
+|tqc |BipedalWalker-v3 | 329.808| 45.083|500k | 149682| 254|
+|tqc |BipedalWalkerHardcore-v3 | 235.226| 110.569|2M | 149032| 131|
+|tqc |FetchPickAndPlace-v1 | -9.331| 6.850|1M | 150000| 3000|
+|tqc |FetchPush-v1 | -8.799| 5.438|1M | 150000| 3000|
+|tqc |FetchReach-v1 | -1.659| 0.873|20k | 150000| 3000|
+|tqc |FetchSlide-v1 | -29.210| 11.387|3M | 150000| 3000|
+|tqc |HalfCheetah-v3 | 12089.939| 127.440|1M | 150000| 150|
+|tqc |HalfCheetahBulletEnv-v0 | 3675.299| 17.681|1M | 150000| 150|
+|tqc |Hopper-v3 | 3754.199| 8.276|1M | 150000| 150|
+|tqc |HopperBulletEnv-v0 | 2662.373| 206.210|1M | 149881| 151|
+|tqc |Humanoid-v3 | 7239.320| 1647.498|2M | 149508| 165|
+|tqc |LunarLanderContinuous-v2 | 277.956| 25.466|500k | 149928| 706|
+|tqc |MountainCarContinuous-v0 | 63.641| 45.259|50k | 149796| 186|
+|tqc |PandaPickAndPlace-v1 | -8.024| 6.674|1M | 150000| 3000|
+|tqc |PandaPush-v1 | -6.405| 6.400|1M | 150000| 3000|
+|tqc |PandaReach-v1 | -1.768| 0.858|20k | 150000| 3000|
+|tqc |PandaSlide-v1 | -27.497| 9.868|3M | 150000| 3000|
+|tqc |PandaStack-v1 | -96.915| 17.240|1M | 150000| 1500|
+|tqc |Pendulum-v1 | -151.340| 87.893|20k | 150000| 750|
+|tqc |ReacherBulletEnv-v0 | 18.255| 9.543|300k | 150000| 1000|
+|tqc |Swimmer-v3 | 339.423| 1.486|1M | 150000| 150|
+|tqc |Walker2DBulletEnv-v0 | 2508.934| 614.624|1M | 149572| 159|
+|tqc |Walker2d-v3 | 4380.720| 500.489|1M | 149606| 152|
+|tqc |parking-v0 | -6.762| 2.690|100k | 149983| 7528|
+|trpo |Acrobot-v1 | -83.114| 18.648|100k | 149976| 1783|
+|trpo |Ant-v3 | 4982.301| 663.761|1M | 149909| 153|
+|trpo |AntBulletEnv-v0 | 2560.621| 52.064|2M | 150000| 150|
+|trpo |BipedalWalker-v3 | 182.339| 145.570|1M | 148440| 148|
+|trpo |CartPole-v1 | 500.000| 0.000|100k | 150000| 300|
+|trpo |HalfCheetah-v3 | 1785.476| 68.672|1M | 150000| 150|
+|trpo |HalfCheetahBulletEnv-v0 | 2758.752| 327.032|2M | 150000| 150|
+|trpo |Hopper-v3 | 3618.386| 356.768|1M | 149575| 152|
+|trpo |HopperBulletEnv-v0 | 2565.416| 410.298|1M | 149640| 154|
+|trpo |LunarLander-v2 | 133.166| 112.173|200k | 149088| 230|
+|trpo |LunarLanderContinuous-v2 | 262.387| 21.428|200k | 149925| 501|
+|trpo |MountainCar-v0 | -107.278| 13.231|100k | 149974| 1398|
+|trpo |MountainCarContinuous-v0 | 92.489| 0.355|50k | 149971| 1732|
+|trpo |Pendulum-v1 | -174.631| 127.577|100k | 150000| 750|
+|trpo |ReacherBulletEnv-v0 | 14.741| 11.559|300k | 150000| 1000|
+|trpo |Swimmer-v3 | 365.663| 2.087|1M | 150000| 150|
+|trpo |Walker2DBulletEnv-v0 | 1483.467| 823.468|2M | 149860| 197|
+|trpo |Walker2d-v3 | 4933.148| 1452.538|1M | 149054| 163|
diff --git a/docker/Dockerfile b/docker/Dockerfile
index a99e9a7b1..2baa348c1 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -21,7 +21,7 @@ RUN \
mkdir -p ${CODE_DIR}/rl_zoo && \
pip uninstall -y stable-baselines3 && \
pip install -r /tmp/requirements.txt && \
- pip install git+https://github.com/eleurent/highway-env@1a04c6a98be64632cf9683625022023e70ff1ab1 && \
+    pip install highway-env==1.5.0 && \
rm -rf $HOME/.cache/pip
ENV PATH=$VENV/bin:$PATH
diff --git a/enjoy.py b/enjoy.py
index c92304fb2..b2a1ecaee 100644
--- a/enjoy.py
+++ b/enjoy.py
@@ -1,4 +1,5 @@
import argparse
+import asyncio
import glob
import importlib
import os
@@ -6,13 +7,34 @@
import numpy as np
import torch as th
+import websockets
import yaml
from stable_baselines3.common.utils import set_random_seed
import utils.import_envs # noqa: F401 pylint: disable=unused-import
-from utils import ALGOS, create_test_env, get_latest_run_id, get_saved_hyperparams
+from utils import ALGOS, create_test_env, get_saved_hyperparams
from utils.exp_manager import ExperimentManager
-from utils.utils import StoreDict
+from utils.load_from_hub import download_from_hub
+from utils.utils import StoreDict, get_model_path
+
+EXIT = False
+socket_port = int(os.environ.get("SOCKET_PORT", 8895))
+
+# To unlock the start
+# echo '{"angle":0,"throttle":0,"drive_mode":"local"}' | websocat "ws://127.0.0.1:8895/wsDrive"
+
+
+async def handler(websocket):
+ _ = await websocket.recv()
+ global EXIT
+ EXIT = True
+
+
+async def main_wait():
+ async with websockets.serve(handler, "", socket_port):
+ while not EXIT:
+ await asyncio.sleep(0.05)
+ print("Exiting socket server")
def main(): # noqa: C901
@@ -71,48 +93,43 @@ def main(): # noqa: C901
algo = args.algo
folder = args.folder
- if args.exp_id == 0:
- args.exp_id = get_latest_run_id(os.path.join(folder, algo), env_id)
- print(f"Loading latest experiment, id={args.exp_id}")
-
- # Sanity checks
- if args.exp_id > 0:
- log_path = os.path.join(folder, algo, f"{env_id}_{args.exp_id}")
- else:
- log_path = os.path.join(folder, algo)
-
- assert os.path.isdir(log_path), f"The {log_path} folder was not found"
-
- found = False
- for ext in ["zip"]:
- model_path = os.path.join(log_path, f"{env_id}.{ext}")
- found = os.path.isfile(model_path)
- if found:
- break
-
- if args.load_best:
- model_path = os.path.join(log_path, "best_model.zip")
- found = os.path.isfile(model_path)
-
- if args.load_checkpoint is not None:
- model_path = os.path.join(log_path, f"rl_model_{args.load_checkpoint}_steps.zip")
- found = os.path.isfile(model_path)
-
- if args.load_last_checkpoint:
- checkpoints = glob.glob(os.path.join(log_path, "rl_model_*_steps.zip"))
- if len(checkpoints) == 0:
- raise ValueError(f"No checkpoint found for {algo} on {env_id}, path: {log_path}")
-
- def step_count(checkpoint_path: str) -> int:
- # path follow the pattern "rl_model_*_steps.zip", we count from the back to ignore any other _ in the path
- return int(checkpoint_path.split("_")[-2])
-
- checkpoints = sorted(checkpoints, key=step_count)
- model_path = checkpoints[-1]
- found = True
-
- if not found:
- raise ValueError(f"No model found for {algo} on {env_id}, path: {model_path}")
+ try:
+ _, model_path, log_path = get_model_path(
+ args.exp_id,
+ folder,
+ algo,
+ env_id,
+ args.load_best,
+ args.load_checkpoint,
+ args.load_last_checkpoint,
+ )
+ except (AssertionError, ValueError) as e:
+ # Special case for rl-trained agents
+ # auto-download from the hub
+ if "rl-trained-agents" not in folder:
+ raise e
+ else:
+ print("Pretrained model not found, trying to download it from sb3 Huggingface hub: https://huggingface.co/sb3")
+ # Auto-download
+ download_from_hub(
+ algo=algo,
+ env_id=env_id,
+ exp_id=args.exp_id,
+ folder=folder,
+ organization="sb3",
+ repo_name=None,
+ force=False,
+ )
+ # Try again
+ _, model_path, log_path = get_model_path(
+ args.exp_id,
+ folder,
+ algo,
+ env_id,
+ args.load_best,
+ args.load_checkpoint,
+ args.load_last_checkpoint,
+ )
print(f"Loading {model_path}")
@@ -138,7 +155,7 @@ def step_count(checkpoint_path: str) -> int:
env_kwargs = {}
args_path = os.path.join(log_path, env_id, "args.yml")
if os.path.isfile(args_path):
- with open(args_path, "r") as f:
+ with open(args_path) as f:
loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr
if loaded_args["env_kwargs"] is not None:
env_kwargs = loaded_args["env_kwargs"]
@@ -180,20 +197,34 @@ def step_count(checkpoint_path: str) -> int:
obs = env.reset()
+ # Wait for message from websocket
+    if bool(int(os.environ.get("WAIT_FOR_START", "0"))):
+ print(f"Waiting for socket message on port {socket_port}")
+ asyncio.run(main_wait())
+
# Deterministic by default except for atari games
stochastic = args.stochastic or is_atari and not args.deterministic
deterministic = not stochastic
- state = None
episode_reward = 0.0
episode_rewards, episode_lengths = [], []
ep_len = 0
# For HER, monitor success rate
successes = []
+ lstm_states = None
+ episode_start = np.ones((env.num_envs,), dtype=bool)
try:
for _ in range(args.n_timesteps):
- action, state = model.predict(obs, state=state, deterministic=deterministic)
+ action, lstm_states = model.predict(
+ obs,
+ state=lstm_states,
+ episode_start=episode_start,
+ deterministic=deterministic,
+ )
obs, reward, done, infos = env.step(action)
+
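+            # Tell the policy which episodes just ended so it resets their LSTM states on the next predict()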
+ episode_start = done
+
if not args.no_render:
env.render("human")
@@ -218,7 +249,6 @@ def step_count(checkpoint_path: str) -> int:
episode_lengths.append(ep_len)
episode_reward = 0.0
ep_len = 0
- state = None
# Reset also when the goal is achieved when using HER
if done and infos[0].get("is_success") is not None:
@@ -230,6 +260,7 @@ def step_count(checkpoint_path: str) -> int:
episode_reward, ep_len = 0.0, 0
except KeyboardInterrupt:
+ print("Cancelled by the user...")
pass
if args.verbose > 0 and len(successes) > 0:
diff --git a/hyperparams/a2c.yml b/hyperparams/a2c.yml
index f51b0fb09..ba9416e85 100644
--- a/hyperparams/a2c.yml
+++ b/hyperparams/a2c.yml
@@ -1,6 +1,10 @@
atari:
env_wrapper:
- stable_baselines3.common.atari_wrappers.AtariWrapper
+ # Equivalent to
+ # vec_env_wrapper:
+ # - stable_baselines3.common.vec_env.VecFrameStack:
+ # n_stack: 4
frame_stack: 4
policy: 'CnnPolicy'
n_envs: 16
diff --git a/hyperparams/human.yml b/hyperparams/human.yml
new file mode 100644
index 000000000..6852b93b7
--- /dev/null
+++ b/hyperparams/human.yml
@@ -0,0 +1,13 @@
+donkey-generated-track-v0:
+ env_wrapper:
+ # - gym.wrappers.time_limit.TimeLimit:
+ # max_episode_steps: 1200
+ - utils.wrappers.HistoryWrapper:
+ horizon: 2
+ # - utils.wrappers.TimeFeatureWrapper:
+ # test_mode: True
+ # - stable_baselines3.common.monitor.Monitor:
+ # filename: None
+ n_timesteps: !!float 1e6
+ policy: 'MlpPolicy'
+ buffer_size: 100000
diff --git a/hyperparams/ppo.yml b/hyperparams/ppo.yml
index 2a88c30d8..1532953d5 100644
--- a/hyperparams/ppo.yml
+++ b/hyperparams/ppo.yml
@@ -85,17 +85,17 @@ Acrobot-v1:
BipedalWalker-v3:
normalize: true
- n_envs: 16
+ n_envs: 32
n_timesteps: !!float 5e6
policy: 'MlpPolicy'
n_steps: 2048
batch_size: 64
gae_lambda: 0.95
- gamma: 0.99
+ gamma: 0.999
n_epochs: 10
- ent_coef: 0.001
- learning_rate: !!float 2.5e-4
- clip_range: 0.2
+ ent_coef: 0.0
+ learning_rate: !!float 3e-4
+ clip_range: 0.18
BipedalWalkerHardcore-v3:
normalize: true
@@ -319,28 +319,33 @@ MiniGrid-FourRooms-v0:
CarRacing-v0:
env_wrapper:
+ - utils.wrappers.FrameSkip:
+ skip: 2
- gym.wrappers.resize_observation.ResizeObservation:
shape: 64
- gym.wrappers.gray_scale_observation.GrayScaleObservation:
keep_dim: true
- frame_stack: 4
+ frame_stack: 2
+ normalize: "{'norm_obs': False, 'norm_reward': True}"
n_envs: 8
- n_timesteps: !!float 1e6
+ n_timesteps: !!float 4e6
policy: 'CnnPolicy'
batch_size: 128
n_steps: 512
gamma: 0.99
- gae_lambda: 0.9
- n_epochs: 20
+ gae_lambda: 0.95
+ n_epochs: 10
ent_coef: 0.0
sde_sample_freq: 4
max_grad_norm: 0.5
vf_coef: 0.5
- learning_rate: !!float 3e-5
+ learning_rate: lin_1e-4
use_sde: True
- clip_range: 0.4
+ clip_range: 0.2
policy_kwargs: "dict(log_std_init=-2,
ortho_init=False,
+ activation_fn=nn.GELU,
+ net_arch=[dict(pi=[256], vf=[256])],
)"
@@ -508,7 +513,7 @@ InvertedPendulum-v3:
max_grad_norm: 0.3
vf_coef: 0.19816
-Reacher-v3:
+Reacher-v2:
normalize: true
n_envs: 1
policy: 'MlpPolicy'
@@ -524,28 +529,6 @@ Reacher-v3:
max_grad_norm: 0.9
vf_coef: 0.950368
-# Swimmer-v3:
-# normalize: true
-# n_envs: 1
-# policy: 'MlpPolicy'
-# n_timesteps: !!float 1e6
-# batch_size: 32
-# n_steps: 512
-# gamma: 0.9999
-# learning_rate: 5.49717e-05
-# ent_coef: 0.0554757
-# clip_range: 0.3
-# n_epochs: 10
-# gae_lambda: 0.95
-# max_grad_norm: 0.6
-# vf_coef: 0.38782
-# policy_kwargs: "dict(
-# log_std_init=-2,
-# ortho_init=False,
-# activation_fn=nn.ReLU,
-# net_arch=[dict(pi=[256, 256], vf=[256, 256])]
-# )"
-
Walker2d-v3:
normalize: true
n_envs: 1
diff --git a/hyperparams/ppo_lstm.yml b/hyperparams/ppo_lstm.yml
new file mode 100644
index 000000000..52fffe32f
--- /dev/null
+++ b/hyperparams/ppo_lstm.yml
@@ -0,0 +1,534 @@
+atari:
+ env_wrapper:
+ - stable_baselines3.common.atari_wrappers.AtariWrapper:
+ terminal_on_life_loss: False
+ frame_stack: 4
+ policy: 'CnnLstmPolicy'
+ n_envs: 8
+ n_steps: 128
+ n_epochs: 4
+ batch_size: 256
+ n_timesteps: !!float 1e7
+ learning_rate: lin_2.5e-4
+ clip_range: lin_0.1
+ vf_coef: 0.5
+ ent_coef: 0.01
+ policy_kwargs: "dict(enable_critic_lstm=False,
+ lstm_hidden_size=128,
+ )"
+
+# Tuned
+PendulumNoVel-v1:
+ normalize: True
+ n_envs: 4
+ n_timesteps: !!float 1e5
+ policy: 'MlpLstmPolicy'
+ n_steps: 1024
+ gae_lambda: 0.95
+ gamma: 0.9
+ n_epochs: 10
+ ent_coef: 0.0
+ learning_rate: !!float 1e-3
+ clip_range: 0.2
+ use_sde: True
+ sde_sample_freq: 4
+ policy_kwargs: "dict(
+ ortho_init=False,
+ activation_fn=nn.ReLU,
+ lstm_hidden_size=64,
+ enable_critic_lstm=True,
+ net_arch=[dict(pi=[64], vf=[64])]
+ )"
+
+# Tuned
+CartPoleNoVel-v1:
+ normalize: True
+ n_envs: 8
+ n_timesteps: !!float 1e5
+ policy: 'MlpLstmPolicy'
+ n_steps: 32
+ batch_size: 256
+ gae_lambda: 0.8
+ gamma: 0.98
+ n_epochs: 20
+ ent_coef: 0.0
+ learning_rate: lin_0.001
+ clip_range: lin_0.2
+ policy_kwargs: "dict(
+ ortho_init=False,
+ activation_fn=nn.ReLU,
+ lstm_hidden_size=64,
+ enable_critic_lstm=True,
+ net_arch=[dict(pi=[64], vf=[64])]
+ )"
+
+# TO BE TUNED
+MountainCarNoVel-v0:
+ normalize: true
+ n_envs: 16
+ n_timesteps: !!float 1e6
+ policy: 'MlpLstmPolicy'
+ n_steps: 16
+ gae_lambda: 0.98
+ gamma: 0.99
+ n_epochs: 4
+ ent_coef: 0.0
+
+# Tuned
+MountainCarContinuousNoVel-v0:
+ normalize: true
+ n_envs: 8
+ n_timesteps: !!float 3e5
+ policy: 'MlpLstmPolicy'
+ batch_size: 256
+ n_steps: 1024
+ gamma: 0.9999
+ learning_rate: !!float 7.77e-05
+ ent_coef: 0.00429
+ clip_range: 0.1
+ n_epochs: 10
+ gae_lambda: 0.9
+ max_grad_norm: 5
+ vf_coef: 0.19
+ use_sde: True
+ sde_sample_freq: 8
+ policy_kwargs: "dict(log_std_init=0.0, ortho_init=False,
+ lstm_hidden_size=32,
+ enable_critic_lstm=True,
+ net_arch=[dict(pi=[64], vf=[64])])"
+
+Acrobot-v1:
+ normalize: true
+ n_envs: 16
+ n_timesteps: !!float 1e6
+ policy: 'MlpLstmPolicy'
+ n_steps: 256
+ gae_lambda: 0.94
+ gamma: 0.99
+ n_epochs: 4
+ ent_coef: 0.0
+
+BipedalWalker-v3:
+ normalize: true
+ n_envs: 32
+ n_timesteps: !!float 5e6
+ policy: 'MlpLstmPolicy'
+ n_steps: 256
+ batch_size: 256
+ gae_lambda: 0.95
+ gamma: 0.999
+ n_epochs: 10
+ ent_coef: 0.0
+ learning_rate: !!float 3e-4
+ clip_range: 0.18
+ policy_kwargs: "dict(
+ ortho_init=False,
+ activation_fn=nn.ReLU,
+ lstm_hidden_size=64,
+ enable_critic_lstm=True,
+ net_arch=[dict(pi=[64], vf=[64])]
+ )"
+
+# TO BE TUNED
+BipedalWalkerHardcore-v3:
+ # env_wrapper:
+ # - utils.wrappers.FrameSkip:
+ # skip: 2
+ normalize: true
+ n_envs: 32
+  n_timesteps: !!float 1e8
+ policy: 'MlpLstmPolicy'
+ n_steps: 256
+ batch_size: 256
+ gae_lambda: 0.95
+ gamma: 0.999
+ n_epochs: 10
+ ent_coef: 0.001
+ learning_rate: lin_3e-4
+ clip_range: lin_0.2
+ policy_kwargs: "dict(
+ ortho_init=False,
+ activation_fn=nn.ReLU,
+ lstm_hidden_size=64,
+ enable_critic_lstm=True,
+ net_arch=[dict(pi=[64], vf=[64])]
+ )"
+
+# Tuned
+LunarLanderNoVel-v2: &lunar-defaults
+ normalize: True
+ n_envs: 32
+ n_timesteps: !!float 5e6
+ policy: 'MlpLstmPolicy'
+ n_steps: 512
+ batch_size: 128
+ gae_lambda: 0.98
+ gamma: 0.999
+ n_epochs: 4
+ ent_coef: 0.01
+ policy_kwargs: "dict(
+ ortho_init=False,
+ activation_fn=nn.ReLU,
+ lstm_hidden_size=64,
+ enable_critic_lstm=True,
+ net_arch=[dict(pi=[64], vf=[64])]
+ )"
+
+LunarLanderContinuousNoVel-v2:
+ <<: *lunar-defaults
+
+
+HalfCheetahBulletEnv-v0: &pybullet-defaults
+ normalize: true
+ n_envs: 16
+ n_timesteps: !!float 2e6
+ policy: 'MlpLstmPolicy'
+ batch_size: 128
+ n_steps: 256
+ gamma: 0.99
+ gae_lambda: 0.9
+ n_epochs: 10
+ policy_kwargs: "dict(ortho_init=False,
+ activation_fn=nn.ReLU,
+ net_arch=[dict(pi=[], vf=[])],
+ enable_critic_lstm=True,
+ lstm_hidden_size=128,
+ )"
+
+AntBulletEnv-v0:
+ <<: *pybullet-defaults
+
+Walker2DBulletEnv-v0:
+ <<: *pybullet-defaults
+ clip_range: lin_0.4
+
+HopperBulletEnv-v0:
+ <<: *pybullet-defaults
+ clip_range: lin_0.4
+
+
+ReacherBulletEnv-v0:
+ <<: *pybullet-defaults
+ clip_range: lin_0.4
+
+MinitaurBulletEnv-v0:
+ normalize: true
+ n_envs: 8
+ n_timesteps: !!float 2e6
+ policy: 'MlpLstmPolicy'
+ n_steps: 2048
+ batch_size: 64
+ gae_lambda: 0.95
+ gamma: 0.99
+ n_epochs: 10
+ ent_coef: 0.0
+ learning_rate: 2.5e-4
+ clip_range: 0.2
+
+MinitaurBulletDuckEnv-v0:
+ normalize: true
+ n_envs: 8
+ n_timesteps: !!float 2e6
+ policy: 'MlpLstmPolicy'
+ n_steps: 2048
+ batch_size: 64
+ gae_lambda: 0.95
+ gamma: 0.99
+ n_epochs: 10
+ ent_coef: 0.0
+ learning_rate: 2.5e-4
+ clip_range: 0.2
+
+# To be tuned
+HumanoidBulletEnv-v0:
+ normalize: true
+ n_envs: 8
+ n_timesteps: !!float 1e7
+ policy: 'MlpLstmPolicy'
+ n_steps: 2048
+ batch_size: 64
+ gae_lambda: 0.95
+ gamma: 0.99
+ n_epochs: 10
+ ent_coef: 0.0
+ learning_rate: 2.5e-4
+ clip_range: 0.2
+
+InvertedDoublePendulumBulletEnv-v0:
+ normalize: true
+ n_envs: 8
+ n_timesteps: !!float 2e6
+ policy: 'MlpLstmPolicy'
+ n_steps: 2048
+ batch_size: 64
+ gae_lambda: 0.95
+ gamma: 0.99
+ n_epochs: 10
+ ent_coef: 0.0
+ learning_rate: 2.5e-4
+ clip_range: 0.2
+
+InvertedPendulumSwingupBulletEnv-v0:
+ normalize: true
+ n_envs: 8
+ n_timesteps: !!float 2e6
+ policy: 'MlpLstmPolicy'
+ n_steps: 2048
+ batch_size: 64
+ gae_lambda: 0.95
+ gamma: 0.99
+ n_epochs: 10
+ ent_coef: 0.0
+ learning_rate: 2.5e-4
+ clip_range: 0.2
+
+
+CarRacing-v0:
+ env_wrapper:
+ # - utils.wrappers.FrameSkip:
+ # skip: 2
+ - gym.wrappers.resize_observation.ResizeObservation:
+ shape: 64
+ - gym.wrappers.gray_scale_observation.GrayScaleObservation:
+ keep_dim: true
+ frame_stack: 2
+ normalize: "{'norm_obs': False, 'norm_reward': True}"
+ n_envs: 8
+ n_timesteps: !!float 4e6
+ policy: 'CnnLstmPolicy'
+ batch_size: 128
+ n_steps: 512
+ gamma: 0.99
+ gae_lambda: 0.95
+ n_epochs: 10
+ ent_coef: 0.0
+ sde_sample_freq: 4
+ max_grad_norm: 0.5
+ vf_coef: 0.5
+ learning_rate: lin_1e-4
+ use_sde: True
+ clip_range: 0.2
+ policy_kwargs: "dict(log_std_init=-2,
+ ortho_init=False,
+ enable_critic_lstm=False,
+ activation_fn=nn.GELU,
+ lstm_hidden_size=128,
+ )"
+
+# === Mujoco Envs ===
+# HalfCheetah-v3: &mujoco-defaults
+# normalize: true
+# n_timesteps: !!float 1e6
+# policy: 'MlpLstmPolicy'
+
+Ant-v3: &mujoco-defaults
+ normalize: true
+ n_timesteps: !!float 1e6
+ policy: 'MlpLstmPolicy'
+
+# Hopper-v3:
+# <<: *mujoco-defaults
+#
+# Walker2d-v3:
+# <<: *mujoco-defaults
+#
+# Humanoid-v3:
+# <<: *mujoco-defaults
+# n_timesteps: !!float 2e6
+#
+Swimmer-v3:
+ <<: *mujoco-defaults
+ gamma: 0.9999
+
+
+# 10 mujoco envs
+
+HalfCheetah-v3:
+ normalize: true
+ n_envs: 1
+ policy: 'MlpLstmPolicy'
+ n_timesteps: !!float 1e6
+ batch_size: 64
+ n_steps: 512
+ gamma: 0.98
+ learning_rate: 2.0633e-05
+ ent_coef: 0.000401762
+ clip_range: 0.1
+ n_epochs: 20
+ gae_lambda: 0.92
+ max_grad_norm: 0.8
+ vf_coef: 0.58096
+ policy_kwargs: "dict(
+ log_std_init=-2,
+ ortho_init=False,
+ activation_fn=nn.ReLU,
+ net_arch=[dict(pi=[256, 256], vf=[256, 256])]
+ )"
+
+# Ant-v3:
+# normalize: true
+# n_envs: 1
+# policy: 'MlpLstmPolicy'
+# n_timesteps: !!float 1e7
+# batch_size: 32
+# n_steps: 512
+# gamma: 0.98
+# learning_rate: 1.90609e-05
+# ent_coef: 4.9646e-07
+# clip_range: 0.1
+# n_epochs: 10
+# gae_lambda: 0.8
+# max_grad_norm: 0.6
+# vf_coef: 0.677239
+
+Hopper-v3:
+ normalize: true
+ n_envs: 1
+ policy: 'MlpLstmPolicy'
+ n_timesteps: !!float 1e6
+ batch_size: 32
+ n_steps: 512
+ gamma: 0.999
+ learning_rate: 9.80828e-05
+ ent_coef: 0.00229519
+ clip_range: 0.2
+ n_epochs: 5
+ gae_lambda: 0.99
+ max_grad_norm: 0.7
+ vf_coef: 0.835671
+ policy_kwargs: "dict(
+ log_std_init=-2,
+ ortho_init=False,
+ activation_fn=nn.ReLU,
+ net_arch=[dict(pi=[256, 256], vf=[256, 256])]
+ )"
+
+HumanoidStandup-v3:
+ normalize: true
+ n_envs: 1
+ policy: 'MlpLstmPolicy'
+ n_timesteps: !!float 1e7
+ batch_size: 32
+ n_steps: 512
+ gamma: 0.99
+ learning_rate: 2.55673e-05
+ ent_coef: 3.62109e-06
+ clip_range: 0.3
+ n_epochs: 20
+ gae_lambda: 0.9
+ max_grad_norm: 0.7
+ vf_coef: 0.430793
+ policy_kwargs: "dict(
+ log_std_init=-2,
+ ortho_init=False,
+ activation_fn=nn.ReLU,
+ net_arch=[dict(pi=[256, 256], vf=[256, 256])]
+ )"
+
+Humanoid-v3:
+ normalize: true
+ n_envs: 1
+ policy: 'MlpLstmPolicy'
+ n_timesteps: !!float 1e7
+ batch_size: 256
+ n_steps: 512
+ gamma: 0.95
+ learning_rate: 3.56987e-05
+ ent_coef: 0.00238306
+ clip_range: 0.3
+ n_epochs: 5
+ gae_lambda: 0.9
+ max_grad_norm: 2
+ vf_coef: 0.431892
+ policy_kwargs: "dict(
+ log_std_init=-2,
+ ortho_init=False,
+ activation_fn=nn.ReLU,
+ net_arch=[dict(pi=[256, 256], vf=[256, 256])]
+ )"
+
+InvertedDoublePendulum-v3:
+ normalize: true
+ n_envs: 1
+ policy: 'MlpLstmPolicy'
+ n_timesteps: !!float 1e6
+ batch_size: 512
+ n_steps: 128
+ gamma: 0.98
+ learning_rate: 0.000155454
+ ent_coef: 1.05057e-06
+ clip_range: 0.4
+ n_epochs: 10
+ gae_lambda: 0.8
+ max_grad_norm: 0.5
+ vf_coef: 0.695929
+
+InvertedPendulum-v3:
+ normalize: true
+ n_envs: 1
+ policy: 'MlpLstmPolicy'
+ n_timesteps: !!float 1e6
+ batch_size: 64
+ n_steps: 32
+ gamma: 0.999
+ learning_rate: 0.000222425
+ ent_coef: 1.37976e-07
+ clip_range: 0.4
+ n_epochs: 5
+ gae_lambda: 0.9
+ max_grad_norm: 0.3
+ vf_coef: 0.19816
+
+Reacher-v3:
+ normalize: true
+ n_envs: 1
+ policy: 'MlpLstmPolicy'
+ n_timesteps: !!float 1e6
+ batch_size: 32
+ n_steps: 512
+ gamma: 0.9
+ learning_rate: 0.000104019
+ ent_coef: 7.52585e-08
+ clip_range: 0.3
+ n_epochs: 5
+ gae_lambda: 1.0
+ max_grad_norm: 0.9
+ vf_coef: 0.950368
+
+# Swimmer-v3:
+# normalize: true
+# n_envs: 1
+# policy: 'MlpLstmPolicy'
+# n_timesteps: !!float 1e6
+# batch_size: 32
+# n_steps: 512
+# gamma: 0.9999
+# learning_rate: 5.49717e-05
+# ent_coef: 0.0554757
+# clip_range: 0.3
+# n_epochs: 10
+# gae_lambda: 0.95
+# max_grad_norm: 0.6
+# vf_coef: 0.38782
+# policy_kwargs: "dict(
+# log_std_init=-2,
+# ortho_init=False,
+# activation_fn=nn.ReLU,
+# net_arch=[dict(pi=[256, 256], vf=[256, 256])]
+# )"
+
+Walker2d-v3:
+ normalize: true
+ n_envs: 1
+ policy: 'MlpLstmPolicy'
+ n_timesteps: !!float 1e6
+ batch_size: 32
+ n_steps: 512
+ gamma: 0.99
+ learning_rate: 5.05041e-05
+ ent_coef: 0.000585045
+ clip_range: 0.1
+ n_epochs: 20
+ gae_lambda: 0.95
+ max_grad_norm: 1
+ vf_coef: 0.871923
diff --git a/hyperparams/sac.yml b/hyperparams/sac.yml
index 8aeaef616..d831b8791 100644
--- a/hyperparams/sac.yml
+++ b/hyperparams/sac.yml
@@ -1,281 +1,32 @@
-# Tuned
-MountainCarContinuous-v0:
- n_timesteps: !!float 50000
- policy: 'MlpPolicy'
- learning_rate: !!float 3e-4
- buffer_size: 50000
- batch_size: 512
- ent_coef: 0.1
- train_freq: 32
- gradient_steps: 32
- gamma: 0.9999
- tau: 0.01
- learning_starts: 0
- use_sde: True
- policy_kwargs: "dict(log_std_init=-3.67, net_arch=[64, 64])"
-
-Pendulum-v1:
- # callback:
- # - utils.callbacks.ParallelTrainCallback
- n_timesteps: 20000
- policy: 'MlpPolicy'
- learning_rate: !!float 1e-3
-
-
-LunarLanderContinuous-v2:
- n_timesteps: !!float 5e5
- policy: 'MlpPolicy'
- batch_size: 256
- learning_rate: lin_7.3e-4
- buffer_size: 1000000
- ent_coef: 'auto'
- gamma: 0.99
- tau: 0.01
- train_freq: 1
- gradient_steps: 1
- learning_starts: 10000
- policy_kwargs: "dict(net_arch=[400, 300])"
-
-
-# Tuned
-BipedalWalker-v3:
- n_timesteps: !!float 5e5
- policy: 'MlpPolicy'
- learning_rate: !!float 7.3e-4
- buffer_size: 300000
- batch_size: 256
- ent_coef: 'auto'
- gamma: 0.98
- tau: 0.02
- train_freq: 64
- gradient_steps: 64
- learning_starts: 10000
- use_sde: True
- policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"
-
-# Almost tuned
-BipedalWalkerHardcore-v3:
- n_timesteps: !!float 1e7
- policy: 'MlpPolicy'
- learning_rate: lin_7.3e-4
- buffer_size: 1000000
- batch_size: 256
- ent_coef: 0.005
- gamma: 0.99
- tau: 0.01
- train_freq: 1
- gradient_steps: 1
- learning_starts: 10000
- policy_kwargs: "dict(net_arch=[400, 300])"
-
-# === Bullet envs ===
-
-# Tuned
-HalfCheetahBulletEnv-v0: &pybullet-defaults
- # env_wrapper:
- # - sb3_contrib.common.wrappers.TimeFeatureWrapper
- # - utils.wrappers.DelayedRewardWrapper:
- # delay: 10
- # - utils.wrappers.HistoryWrapper:
- # horizon: 10
- n_timesteps: !!float 1e6
- policy: 'MlpPolicy'
- learning_rate: !!float 7.3e-4
- buffer_size: 300000
- batch_size: 256
- ent_coef: 'auto'
- gamma: 0.98
- tau: 0.02
- train_freq: 8
- gradient_steps: 8
- learning_starts: 10000
- # replay_buffer_kwargs: "dict(handle_timeout_termination=True)"
- use_sde: True
- policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"
-
-# Tuned
-AntBulletEnv-v0:
- <<: *pybullet-defaults
-
-HopperBulletEnv-v0:
- <<: *pybullet-defaults
- learning_rate: lin_7.3e-4
-
-Walker2DBulletEnv-v0:
- <<: *pybullet-defaults
- learning_rate: lin_7.3e-4
-
-
-# Tuned
-ReacherBulletEnv-v0:
- <<: *pybullet-defaults
- n_timesteps: !!float 3e5
-
-HumanoidBulletEnv-v0:
+## Custom envs
+donkey-mountain-track-v0: &defaults
+ # Normalize AE (+ the rest)
normalize: "{'norm_obs': True, 'norm_reward': False}"
- n_timesteps: !!float 2e7
- policy: 'MlpPolicy'
- learning_rate: lin_3e-4
- buffer_size: 1000000
- batch_size: 64
- ent_coef: 'auto'
- train_freq: 1
- gradient_steps: 1
- learning_starts: 1000
-
-# Tuned
-InvertedDoublePendulumBulletEnv-v0:
- <<: *pybullet-defaults
- n_timesteps: !!float 5e5
-
-
-# Tuned
-InvertedPendulumSwingupBulletEnv-v0:
- <<: *pybullet-defaults
- n_timesteps: !!float 3e5
-
-# To be tuned
-MinitaurBulletEnv-v0:
- # normalize: "{'norm_obs': True, 'norm_reward': False}"
- n_timesteps: !!float 1e6
- policy: 'MlpPolicy'
- learning_rate: !!float 3e-4
- buffer_size: 100000
- batch_size: 256
- ent_coef: 'auto'
- train_freq: 1
- gradient_steps: 1
- learning_starts: 10000
-
-
-# To be tuned
-MinitaurBulletDuckEnv-v0:
- n_timesteps: !!float 1e6
- policy: 'MlpPolicy'
- learning_rate: !!float 3e-4
- buffer_size: 1000000
- batch_size: 256
- ent_coef: 'auto'
- train_freq: 1
- gradient_steps: 1
- learning_starts: 10000
-
-# To be tuned
-CarRacing-v0:
- env_wrapper:
- - gym.wrappers.resize_observation.ResizeObservation:
- shape: 64
- - gym.wrappers.gray_scale_observation.GrayScaleObservation:
- keep_dim: true
- frame_stack: 4
- n_envs: 1
- n_timesteps: !!float 1e6
- # TODO: try with MlpPolicy and low res (MNIST: 28) pixels
- # policy: 'MlpPolicy'
- policy: 'CnnPolicy'
- learning_rate: !!float 7.3e-4
- buffer_size: 50000
- batch_size: 256
- ent_coef: 'auto'
- gamma: 0.98
- tau: 0.02
- train_freq: 64
- # train_freq: [1, "episode"]
- gradient_steps: 64
- # sde_sample_freq: 64
- learning_starts: 1000
- use_sde: True
- use_sde_at_warmup: True
- policy_kwargs: "dict(log_std_init=-2, net_arch=[64, 64])"
-
-# === Mujoco Envs ===
-
-HalfCheetah-v3: &mujoco-defaults
- n_timesteps: !!float 1e6
- policy: 'MlpPolicy'
- learning_starts: 10000
-
-Ant-v3:
- <<: *mujoco-defaults
-
-Hopper-v3:
- <<: *mujoco-defaults
-
-Walker2d-v3:
- <<: *mujoco-defaults
-
-Humanoid-v3:
- <<: *mujoco-defaults
- n_timesteps: !!float 2e6
-
-Swimmer-v3:
- <<: *mujoco-defaults
- gamma: 0.9999
-
-# === HER Robotics GoalEnvs ===
-
-FetchReach-v1:
- n_timesteps: !!float 20000
- policy: 'MultiInputPolicy'
- buffer_size: 1000000
- ent_coef: 'auto'
- batch_size: 256
- gamma: 0.95
- learning_rate: 0.001
- learning_starts: 1000
- normalize: True
- replay_buffer_class: HerReplayBuffer
- replay_buffer_kwargs: "dict(
- online_sampling=True,
- goal_selection_strategy='future',
- n_sampled_goal=4
- )"
- policy_kwargs: "dict(net_arch=[64, 64])"
-
-
-# ==== Custom Envs ===
-
-donkey-generated-track-v0:
env_wrapper:
- gym.wrappers.time_limit.TimeLimit:
- max_episode_steps: 500
+ max_episode_steps: 10000
+ - ae.wrapper.AutoencoderWrapper
- utils.wrappers.HistoryWrapper:
- horizon: 5
- n_timesteps: !!float 1e6
+ horizon: 2
+ callback:
+ - utils.callbacks.ParallelTrainCallback:
+ gradient_steps: 200
+ - utils.callbacks.LapTimeCallback
+ n_timesteps: !!float 2e6
policy: 'MlpPolicy'
learning_rate: !!float 7.3e-4
- buffer_size: 300000
+ buffer_size: 200000
batch_size: 256
ent_coef: 'auto'
gamma: 0.99
tau: 0.02
- # train_freq: 64
- train_freq: [1, "episode"]
- # gradient_steps: -1
- gradient_steps: 64
+ train_freq: 200
+ gradient_steps: 256
learning_starts: 500
use_sde_at_warmup: True
use_sde: True
- sde_sample_freq: 64
- policy_kwargs: "dict(log_std_init=-2, net_arch=[64, 64])"
+ sde_sample_freq: 16
+ policy_kwargs: "dict(log_std_init=-3, net_arch=[256, 256], n_critics=2, use_expln=True)"
-# === Real Robot envs
-NeckEnvRelative-v2:
- <<: *pybullet-defaults
- env_wrapper:
- - utils.wrappers.HistoryWrapper:
- horizon: 2
- # - utils.wrappers.LowPassFilterWrapper:
- # freq: 2.0
- # df: 25.0
- n_timesteps: !!float 1e6
- buffer_size: 100000
- gamma: 0.99
- train_freq: [1, "episode"]
- gradient_steps: -1
- # 10 episodes of warm-up
- learning_starts: 3000
- use_sde_at_warmup: True
- use_sde: True
- sde_sample_freq: 64
- policy_kwargs: "dict(log_std_init=-2, net_arch=[256, 256])"
+donkey-circuit-launch-track-v0:
+ <<: *defaults
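
# Editor's note: the `&defaults` anchor and `<<: *defaults` merge key above let
# `donkey-circuit-launch-track-v0` inherit every setting from the anchored entry
# without repeating them. A minimal sketch of how these zoo-style files expand at
# load time (assuming PyYAML, whose `safe_load` resolves `<<` merge keys and the
# `!!float` tag):
#
# ```python
# import yaml
#
# config = yaml.safe_load("""
# donkey-mountain-track-v0: &defaults
#   n_timesteps: !!float 2e6
#   policy: 'MlpPolicy'
#   buffer_size: 200000
#
# donkey-circuit-launch-track-v0:
#   <<: *defaults
# """)
#
# # The second env inherits all keys from the anchored mapping;
# # any key listed under it explicitly would override the default.
# assert config["donkey-circuit-launch-track-v0"]["n_timesteps"] == 2e6
# assert config["donkey-circuit-launch-track-v0"]["buffer_size"] == 200000
# ```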
diff --git a/hyperparams/teleop.yml b/hyperparams/teleop.yml
new file mode 100644
index 000000000..c60e15ebc
--- /dev/null
+++ b/hyperparams/teleop.yml
@@ -0,0 +1,13 @@
+RLRacingEnv-v0:
+ env_wrapper:
+ # - gym.wrappers.time_limit.TimeLimit:
+ # max_episode_steps: 1200
+ - utils.wrappers.HistoryWrapper:
+ horizon: 2
+ # - utils.wrappers.TimeFeatureWrapper:
+ # test_mode: True
+ # - stable_baselines3.common.monitor.Monitor:
+ # filename: None
+ n_timesteps: !!float 2e6
+ policy: 'MlpPolicy'
+ buffer_size: 10000
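
# Editor's note: the only active wrapper in teleop.yml is
# `utils.wrappers.HistoryWrapper` with `horizon: 2`: it feeds the policy a
# concatenation of the last `horizon` observations, so a memoryless `MlpPolicy`
# still sees some temporal context. A minimal sketch of the idea for Box
# observations under the gym 0.21 API (a hypothetical re-implementation, not
# the zoo's exact wrapper):
#
# ```python
# import gym
# import numpy as np
#
#
# class HistoryWrapper(gym.Wrapper):
#     """Stack the last `horizon` observations into one flat vector."""
#
#     def __init__(self, env: gym.Env, horizon: int = 2):
#         super().__init__(env)
#         self.horizon = horizon
#         low = np.tile(env.observation_space.low, horizon)
#         high = np.tile(env.observation_space.high, horizon)
#         self.observation_space = gym.spaces.Box(
#             low=low, high=high, dtype=env.observation_space.dtype
#         )
#         self.history = np.zeros_like(low)
#
#     def reset(self):
#         obs = self.env.reset()
#         # Fill the whole buffer with the first observation
#         self.history = np.tile(obs, self.horizon)
#         return self.history.copy()
#
#     def step(self, action):
#         obs, reward, done, info = self.env.step(action)
#         # Drop the oldest observation and append the newest one
#         self.history = np.roll(self.history, -obs.shape[0])
#         self.history[-obs.shape[0]:] = obs
#         return self.history.copy(), reward, done, info
# ```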
diff --git a/hyperparams/tqc.yml b/hyperparams/tqc.yml
index 54e2f3a4d..d831b8791 100644
--- a/hyperparams/tqc.yml
+++ b/hyperparams/tqc.yml
@@ -1,259 +1,32 @@
-# Tuned
-MountainCarContinuous-v0:
- n_timesteps: !!float 50000
- policy: 'MlpPolicy'
- learning_rate: !!float 3e-4
- buffer_size: 50000
- batch_size: 512
- ent_coef: 0.1
- train_freq: 32
- gradient_steps: 32
- gamma: 0.9999
- tau: 0.01
- learning_starts: 0
- use_sde: True
- policy_kwargs: "dict(log_std_init=-3.67, net_arch=[64, 64])"
-
-Pendulum-v1:
- n_timesteps: 20000
- policy: 'MlpPolicy'
- learning_rate: !!float 1e-3
-
-LunarLanderContinuous-v2:
- n_timesteps: !!float 5e5
- policy: 'MlpPolicy'
- learning_rate: lin_7.3e-4
- buffer_size: 1000000
- batch_size: 256
- ent_coef: 'auto'
- gamma: 0.99
- tau: 0.01
- train_freq: 1
- gradient_steps: 1
- learning_starts: 10000
- policy_kwargs: "dict(net_arch=[400, 300])"
-
-BipedalWalker-v3:
- n_timesteps: !!float 5e5
- policy: 'MlpPolicy'
- learning_rate: !!float 7.3e-4
- buffer_size: 300000
- batch_size: 256
- ent_coef: 'auto'
- gamma: 0.98
- tau: 0.02
- train_freq: 64
- gradient_steps: 64
- learning_starts: 10000
- use_sde: True
- policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"
-
-# Almost tuned
-# History wrapper of size 2 for better performances
-BipedalWalkerHardcore-v3:
+## Custom envs
+donkey-mountain-track-v0: &defaults
+  # Normalize the AE features (+ the rest of the observation)
+ normalize: "{'norm_obs': True, 'norm_reward': False}"
+ env_wrapper:
+ - gym.wrappers.time_limit.TimeLimit:
+ max_episode_steps: 10000
+ - ae.wrapper.AutoencoderWrapper
+ - utils.wrappers.HistoryWrapper:
+ horizon: 2
+ callback:
+ - utils.callbacks.ParallelTrainCallback:
+ gradient_steps: 200
+ - utils.callbacks.LapTimeCallback
n_timesteps: !!float 2e6
policy: 'MlpPolicy'
- learning_rate: lin_7.3e-4
- buffer_size: 1000000
- batch_size: 256
- ent_coef: 'auto'
- gamma: 0.99
- tau: 0.01
- train_freq: 1
- gradient_steps: 1
- learning_starts: 10000
- policy_kwargs: "dict(net_arch=[400, 300])"
-
-# === Bullet envs ===
-
-# Tuned
-HalfCheetahBulletEnv-v0: &pybullet-defaults
- n_timesteps: !!float 1e6
- policy: 'MlpPolicy'
learning_rate: !!float 7.3e-4
- buffer_size: 300000
+ buffer_size: 200000
batch_size: 256
ent_coef: 'auto'
- gamma: 0.98
- tau: 0.02
- train_freq: 8
- gradient_steps: 8
- learning_starts: 10000
- use_sde: True
- policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"
-
-# Tuned
-AntBulletEnv-v0:
- <<: *pybullet-defaults
-
-# Tuned
-HopperBulletEnv-v0:
- <<: *pybullet-defaults
- learning_rate: lin_7.3e-4
- top_quantiles_to_drop_per_net: 5
-
-# Tuned
-Walker2DBulletEnv-v0:
- <<: *pybullet-defaults
- learning_rate: lin_7.3e-4
-
-ReacherBulletEnv-v0:
- <<: *pybullet-defaults
- n_timesteps: !!float 3e5
-
-# Almost tuned
-HumanoidBulletEnv-v0:
- n_timesteps: !!float 1e7
- policy: 'MlpPolicy'
- learning_rate: lin_7.3e-4
- buffer_size: 300000
- batch_size: 256
- ent_coef: 'auto'
- gamma: 0.98
+ gamma: 0.99
tau: 0.02
- train_freq: 8
- gradient_steps: 8
- learning_starts: 10000
- top_quantiles_to_drop_per_net: 5
+ train_freq: 200
+ gradient_steps: 256
+ learning_starts: 500
+ use_sde_at_warmup: True
use_sde: True
- policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"
-
-InvertedDoublePendulumBulletEnv-v0:
- <<: *pybullet-defaults
- n_timesteps: !!float 5e5
-
-InvertedPendulumSwingupBulletEnv-v0:
- <<: *pybullet-defaults
- n_timesteps: !!float 3e5
-
-MinitaurBulletEnv-v0:
- n_timesteps: !!float 1e6
- policy: 'MlpPolicy'
- learning_rate: !!float 3e-4
- buffer_size: 100000
- batch_size: 256
- ent_coef: 'auto'
- train_freq: 1
- gradient_steps: 1
- learning_starts: 10000
- # top_quantiles_to_drop_per_net: 5
-
-# === Mujoco Envs ===
-
-HalfCheetah-v3: &mujoco-defaults
- n_timesteps: !!float 1e6
- policy: 'MlpPolicy'
- learning_starts: 10000
-
-Ant-v3:
- <<: *mujoco-defaults
-
-Hopper-v3:
- <<: *mujoco-defaults
- top_quantiles_to_drop_per_net: 5
-
-Walker2d-v3:
- <<: *mujoco-defaults
-
-Humanoid-v3:
- <<: *mujoco-defaults
- n_timesteps: !!float 2e6
-
-Swimmer-v3:
- <<: *mujoco-defaults
- gamma: 0.9999
-
-# === HER Robotics GoalEnvs ===
-FetchReach-v1:
- env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
- n_timesteps: !!float 20000
- policy: 'MultiInputPolicy'
- buffer_size: 1000000
- ent_coef: 'auto'
- batch_size: 256
- gamma: 0.95
- learning_rate: 0.001
- learning_starts: 1000
- normalize: True
- replay_buffer_class: HerReplayBuffer
- replay_buffer_kwargs: "dict(
- online_sampling=True,
- goal_selection_strategy='future',
- n_sampled_goal=4
- )"
- policy_kwargs: "dict(net_arch=[64, 64], n_critics=1)"
-
-FetchPush-v1: &her-defaults
- env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
- n_timesteps: !!float 1e6
- policy: 'MultiInputPolicy'
- buffer_size: 1000000
- batch_size: 2048
- gamma: 0.95
- learning_rate: !!float 1e-3
- tau: 0.05
- replay_buffer_class: HerReplayBuffer
- replay_buffer_kwargs: "dict(
- online_sampling=True,
- goal_selection_strategy='future',
- n_sampled_goal=4,
- )"
- policy_kwargs: "dict(net_arch=[512, 512, 512], n_critics=2)"
-
-FetchSlide-v1:
- <<: *her-defaults
- n_timesteps: !!float 3e6
-
-FetchPickAndPlace-v1:
- <<: *her-defaults
-
-PandaReach-v1:
- env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
- n_timesteps: !!float 20000
- policy: 'MultiInputPolicy'
- buffer_size: 1000000
- ent_coef: 'auto'
- batch_size: 256
- gamma: 0.95
- learning_rate: 0.001
- learning_starts: 1000
- normalize: True
- replay_buffer_class: HerReplayBuffer
- replay_buffer_kwargs: "dict(
- online_sampling=True,
- goal_selection_strategy='future',
- n_sampled_goal=4
- )"
- policy_kwargs: "dict(net_arch=[64, 64], n_critics=1)"
-
-PandaPush-v1:
- <<: *her-defaults
-
-PandaSlide-v1:
- <<: *her-defaults
- n_timesteps: !!float 3e6
-
-PandaPickAndPlace-v1:
- <<: *her-defaults
-
-PandaStack-v1:
- <<: *her-defaults
-
-parking-v0:
- <<: *her-defaults
- n_timesteps: !!float 1e5
- buffer_size: 1000000
- batch_size: 512
- gamma: 0.98
- learning_rate: !!float 1.5e-3
- tau: 0.005
- replay_buffer_class: HerReplayBuffer
- replay_buffer_kwargs: "dict(
- online_sampling=False,
- goal_selection_strategy='episode',
- n_sampled_goal=4,
- max_episode_length=100
- )"
+ sde_sample_freq: 16
+ policy_kwargs: "dict(log_std_init=-3, net_arch=[256, 256], n_critics=2, use_expln=True)"
-A1Walking-v0:
- <<: *pybullet-defaults
+donkey-circuit-launch-track-v0:
+ <<: *defaults
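
# Editor's note: both donkey configs combine `train_freq: 200` with
# `utils.callbacks.ParallelTrainCallback`: on a real car the environment cannot
# be paused for gradient updates, so training is pushed to a background thread
# while data collection continues. A much-simplified sketch of that pattern as
# an SB3 callback (hypothetical; the actual callback additionally works on a
# copy of the model and synchronizes weights to stay thread-safe):
#
# ```python
# import threading
#
# from stable_baselines3.common.callbacks import BaseCallback
#
#
# class SimpleParallelTrainCallback(BaseCallback):
#     """Run off-policy gradient updates in a background thread."""
#
#     def __init__(self, gradient_steps: int = 200, verbose: int = 0):
#         super().__init__(verbose)
#         self.gradient_steps = gradient_steps
#         self._train_thread = None
#
#     def _on_rollout_end(self) -> None:
#         # Wait for the previous training burst, then launch a new one
#         # so the main loop can keep interacting with the environment.
#         if self._train_thread is not None:
#             self._train_thread.join()
#         self._train_thread = threading.Thread(
#             target=self.model.train,
#             kwargs=dict(
#                 gradient_steps=self.gradient_steps,
#                 batch_size=self.model.batch_size,
#             ),
#         )
#         self._train_thread.start()
#
#     def _on_step(self) -> bool:
#         return True
# ```
#
# In such a setup the algorithm itself would typically be configured to skip
# its own in-loop updates (e.g. `gradient_steps=0`), leaving all training to
# the callback.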
diff --git a/logs/benchmark/benchmark.md b/logs/benchmark/benchmark.md
index 0642269c3..61c5239ac 100644
--- a/logs/benchmark/benchmark.md
+++ b/logs/benchmark/benchmark.md
@@ -8,6 +8,9 @@ during this evaluation.
It uses the deterministic policy except for Atari games.
+You can view each model card (each one includes a video and the hyperparameters used)
+on our Hugging Face page: https://huggingface.co/sb3
+
*NOTE: this is not a quantitative benchmark as it corresponds to only one run
(cf [issue #38](https://github.com/araffin/rl-baselines-zoo/issues/38)).
This benchmark is meant to check algorithm (maximal) performance, find potential bugs
@@ -15,181 +18,185 @@ and also allow users to have access to pretrained agents.*
"M" stands for Million (1e6)
-|algo | env_id |mean_reward|std_reward|n_timesteps|eval_timesteps|eval_episodes|
-|-----|---------------------------|----------:|---------:|-----------|-------------:|------------:|
-|a2c |Acrobot-v1 | -83.353| 17.213|500k | 149979| 1778|
-|a2c |AntBulletEnv-v0 | 2497.147| 37.359|2M | 150000| 150|
-|a2c |AsteroidsNoFrameskip-v4 | 1286.550| 423.750|10M | 614138| 258|
-|a2c |BeamRiderNoFrameskip-v4 | 2890.298| 1379.137|10M | 591104| 47|
-|a2c |BipedalWalker-v3 | 299.754| 23.459|5M | 149287| 208|
-|a2c |BipedalWalkerHardcore-v3 | 96.171| 122.943|200M | 149704| 113|
-|a2c |BreakoutNoFrameskip-v4 | 279.793| 122.177|10M | 604115| 82|
-|a2c |CartPole-v1 | 500.000| 0.000|500k | 150000| 300|
-|a2c |EnduroNoFrameskip-v4 | 0.000| 0.000|10M | 599040| 45|
-|a2c |HalfCheetah-v3 | 3041.174| 157.265|1M | 150000| 150|
-|a2c |HalfCheetahBulletEnv-v0 | 2107.384| 36.008|2M | 150000| 150|
-|a2c |Hopper-v3 | 733.454| 376.574|1M | 149987| 580|
-|a2c |HopperBulletEnv-v0 | 815.355| 313.798|2M | 149541| 254|
-|a2c |LunarLander-v2 | 155.751| 80.419|200k | 149443| 297|
-|a2c |LunarLanderContinuous-v2 | 84.225| 145.906|5M | 149305| 256|
-|a2c |MountainCar-v0 | -111.263| 24.087|1M | 149982| 1348|
-|a2c |MountainCarContinuous-v0 | 91.166| 0.255|100k | 149923| 1659|
-|a2c |Pendulum-v0 | -162.965| 103.210|1M | 150000| 750|
-|a2c |PongNoFrameskip-v4 | 17.292| 3.214|10M | 594910| 65|
-|a2c |QbertNoFrameskip-v4 | 3882.345| 1223.327|10M | 610670| 194|
-|a2c |ReacherBulletEnv-v0 | 14.968| 10.978|2M | 150000| 1000|
-|a2c |RoadRunnerNoFrameskip-v4 | 31671.512| 6364.085|10M | 606710| 172|
-|a2c |SeaquestNoFrameskip-v4 | 1721.493| 105.339|10M | 599691| 67|
-|a2c |SpaceInvadersNoFrameskip-v4| 627.160| 201.974|10M | 604848| 162|
-|a2c |Swimmer-v3 | 200.627| 2.544|1M | 150000| 150|
-|a2c |Walker2DBulletEnv-v0 | 858.209| 333.116|2M | 149156| 173|
-|ars |Acrobot-v1 | -82.884| 23.825|500k | 149985| 1788|
-|ars |Ant-v3 | 2333.773| 20.597|75M | 150000| 150|
-|ars |CartPole-v1 | 500.000| 0.000|50k | 150000| 300|
-|ars |HalfCheetah-v3 | 4815.192| 1340.752|12M | 150000| 150|
-|ars |Hopper-v3 | 3343.919| 5.730|7M | 150000| 150|
-|ars |LunarLanderContinuous-v2 | 167.959| 147.071|2M | 149883| 562|
-|ars |MountainCar-v0 | -122.000| 33.456|500k | 149938| 1229|
-|ars |MountainCarContinuous-v0 | 96.672| 0.784|500k | 149990| 621|
-|ars |Pendulum-v0 | -212.540| 160.444|2M | 150000| 750|
-|ars |Swimmer-v3 | 355.267| 12.796|2M | 150000| 150|
-|ars |Walker2d-v3 | 2993.582| 166.289|75M | 149821| 152|
-|ddpg |AntBulletEnv-v0 | 2399.147| 75.410|1M | 150000| 150|
-|ddpg |BipedalWalker-v3 | 197.486| 141.580|1M | 149237| 227|
-|ddpg |HalfCheetahBulletEnv-v0 | 2078.325| 208.379|1M | 150000| 150|
-|ddpg |HopperBulletEnv-v0 | 1157.065| 448.695|1M | 149565| 346|
-|ddpg |LunarLanderContinuous-v2 | 230.217| 92.372|300k | 149862| 556|
-|ddpg |MountainCarContinuous-v0 | 93.512| 0.048|300k | 149965| 2260|
-|ddpg |Pendulum-v0 | -152.099| 94.282|20k | 150000| 750|
-|ddpg |ReacherBulletEnv-v0 | 15.582| 9.606|300k | 150000| 1000|
-|ddpg |Walker2DBulletEnv-v0 | 1387.591| 736.955|1M | 149051| 208|
-|dqn |Acrobot-v1 | -76.639| 11.752|100k | 149998| 1932|
-|dqn |AsteroidsNoFrameskip-v4 | 782.687| 259.247|10M | 607962| 134|
-|dqn |BeamRiderNoFrameskip-v4 | 4295.946| 1790.458|10M | 600832| 37|
-|dqn |BreakoutNoFrameskip-v4 | 358.327| 61.981|10M | 601461| 55|
-|dqn |CartPole-v1 | 500.000| 0.000|50k | 150000| 300|
-|dqn |EnduroNoFrameskip-v4 | 830.929| 194.544|10M | 599040| 14|
-|dqn |LunarLander-v2 | 154.382| 79.241|100k | 149373| 200|
-|dqn |MountainCar-v0 | -100.849| 9.925|120k | 149962| 1487|
-|dqn |PongNoFrameskip-v4 | 20.602| 0.613|10M | 598998| 88|
-|dqn |QbertNoFrameskip-v4 | 9496.774| 5399.633|10M | 605844| 124|
-|dqn |RoadRunnerNoFrameskip-v4 | 40396.350| 7069.131|10M | 603257| 137|
-|dqn |SeaquestNoFrameskip-v4 | 2000.290| 606.644|10M | 599505| 69|
-|dqn |SpaceInvadersNoFrameskip-v4| 622.742| 201.564|10M | 604311| 155|
-|ppo |Acrobot-v1 | -73.506| 18.201|1M | 149979| 2013|
-|ppo |Ant-v3 | 1327.158| 451.577|1M | 149572| 175|
-|ppo |AntBulletEnv-v0 | 2865.922| 56.468|2M | 150000| 150|
-|ppo |AsteroidsNoFrameskip-v4 | 2156.174| 744.640|10M | 602092| 149|
-|ppo |BeamRiderNoFrameskip-v4 | 3397.000| 1662.368|10M | 598926| 46|
-|ppo |BipedalWalker-v3 | 213.299| 129.490|5M | 149826| 233|
-|ppo |BipedalWalkerHardcore-v3 | 122.374| 117.605|100M | 148036| 105|
-|ppo |BreakoutNoFrameskip-v4 | 398.033| 33.328|10M | 600418| 60|
-|ppo |CartPole-v1 | 500.000| 0.000|100k | 150000| 300|
-|ppo |EnduroNoFrameskip-v4 | 996.364| 176.090|10M | 572416| 11|
-|ppo |HalfCheetah-v3 | 5819.099| 663.530|1M | 150000| 150|
-|ppo |HalfCheetahBulletEnv-v0 | 2924.721| 64.465|2M | 150000| 150|
-|ppo |Hopper-v3 | 2410.435| 10.026|1M | 150000| 150|
-|ppo |HopperBulletEnv-v0 | 2575.054| 223.301|2M | 149094| 152|
-|ppo |LunarLander-v2 | 242.119| 31.823|1M | 149636| 369|
-|ppo |LunarLanderContinuous-v2 | 270.863| 32.072|1M | 149956| 526|
-|ppo |MountainCar-v0 | -110.423| 19.473|1M | 149954| 1358|
-|ppo |MountainCarContinuous-v0 | 88.343| 2.572|20k | 149983| 633|
-|ppo |Pendulum-v0 | -172.225| 104.159|100k | 150000| 750|
-|ppo |PongNoFrameskip-v4 | 20.989| 0.105|10M | 599902| 90|
-|ppo |QbertNoFrameskip-v4 | 15627.108| 3313.538|10M | 600248| 83|
-|ppo |ReacherBulletEnv-v0 | 17.091| 11.048|1M | 150000| 1000|
-|ppo |RoadRunnerNoFrameskip-v4 | 40680.645| 6675.058|10M | 605786| 155|
-|ppo |SeaquestNoFrameskip-v4 | 1783.636| 34.096|10M | 598243| 66|
-|ppo |SpaceInvadersNoFrameskip-v4| 960.331| 425.355|10M | 603771| 136|
-|ppo |Swimmer-v3 | 281.561| 9.671|1M | 150000| 150|
-|ppo |Walker2DBulletEnv-v0 | 2109.992| 13.899|2M | 150000| 150|
-|ppo |Walker2d-v3 | 3478.798| 821.708|1M | 149343| 171|
-|qrdqn|Acrobot-v1 | -69.135| 9.967|100k | 149949| 2138|
-|qrdqn|AsteroidsNoFrameskip-v4 | 2185.303| 1097.172|10M | 599784| 66|
-|qrdqn|BeamRiderNoFrameskip-v4 | 17122.941| 10769.997|10M | 596483| 17|
-|qrdqn|BreakoutNoFrameskip-v4 | 393.600| 79.828|10M | 579711| 40|
-|qrdqn|CartPole-v1 | 500.000| 0.000|50k | 150000| 300|
-|qrdqn|EnduroNoFrameskip-v4 | 3231.200| 1311.801|10M | 585728| 5|
-|qrdqn|LunarLander-v2 | 70.236| 225.491|100k | 149957| 522|
-|qrdqn|MountainCar-v0 | -106.042| 15.536|120k | 149943| 1414|
-|qrdqn|PongNoFrameskip-v4 | 20.492| 0.687|10M | 597443| 63|
-|qrdqn|QbertNoFrameskip-v4 | 14799.728| 2917.629|10M | 600773| 92|
-|qrdqn|RoadRunnerNoFrameskip-v4 | 42325.424| 8361.161|10M | 591016| 59|
-|qrdqn|SeaquestNoFrameskip-v4 | 2557.576| 76.951|10M | 596275| 66|
-|qrdqn|SpaceInvadersNoFrameskip-v4| 1899.928| 823.488|10M | 597218| 69|
-|sac |Ant-v3 | 4615.791| 1354.111|1M | 149074| 165|
-|sac |AntBulletEnv-v0 | 3073.114| 175.148|1M | 150000| 150|
-|sac |BipedalWalker-v3 | 297.668| 33.060|500k | 149530| 136|
-|sac |BipedalWalkerHardcore-v3 | 4.423| 103.910|10M | 149794| 88|
-|sac |HalfCheetah-v3 | 9535.451| 100.470|1M | 150000| 150|
-|sac |HalfCheetahBulletEnv-v0 | 2792.170| 12.088|1M | 150000| 150|
-|sac |Hopper-v3 | 2325.547| 1129.676|1M | 149841| 236|
-|sac |HopperBulletEnv-v0 | 2603.494| 164.322|1M | 149724| 151|
-|sac |Humanoid-v3 | 6232.287| 279.885|2M | 149460| 150|
-|sac |LunarLanderContinuous-v2 | 260.390| 65.467|500k | 149634| 672|
-|sac |MountainCarContinuous-v0 | 94.679| 1.134|50k | 149966| 1443|
-|sac |Pendulum-v0 | -156.995| 88.714|20k | 150000| 750|
-|sac |ReacherBulletEnv-v0 | 18.062| 9.729|300k | 150000| 1000|
-|sac |Swimmer-v3 | 345.568| 3.084|1M | 150000| 150|
-|sac |Walker2DBulletEnv-v0 | 2292.266| 13.970|1M | 149983| 150|
-|sac |Walker2d-v3 | 3863.203| 254.347|1M | 149309| 150|
-|td3 |Ant-v3 | 5813.274| 589.773|1M | 149393| 151|
-|td3 |AntBulletEnv-v0 | 3300.026| 54.640|1M | 150000| 150|
-|td3 |BipedalWalker-v3 | 305.990| 56.886|1M | 149999| 224|
-|td3 |BipedalWalkerHardcore-v3 | -98.116| 16.087|10M | 150000| 75|
-|td3 |HalfCheetah-v3 | 9655.666| 969.916|1M | 150000| 150|
-|td3 |HalfCheetahBulletEnv-v0 | 2821.641| 19.722|1M | 150000| 150|
-|td3 |Hopper-v3 | 3606.390| 4.027|1M | 150000| 150|
-|td3 |HopperBulletEnv-v0 | 2681.609| 27.806|1M | 149486| 150|
-|td3 |Humanoid-v3 | 5566.687| 14.544|2M | 150000| 150|
-|td3 |LunarLanderContinuous-v2 | 207.451| 67.562|300k | 149488| 337|
-|td3 |MountainCarContinuous-v0 | 93.483| 0.075|300k | 149976| 2275|
-|td3 |Pendulum-v0 | -151.855| 90.227|20k | 150000| 750|
-|td3 |ReacherBulletEnv-v0 | 17.114| 9.750|300k | 150000| 1000|
-|td3 |Swimmer-v3 | 359.127| 1.244|1M | 150000| 150|
-|td3 |Walker2DBulletEnv-v0 | 2213.672| 230.558|1M | 149800| 152|
-|td3 |Walker2d-v3 | 4717.823| 46.303|1M | 150000| 150|
-|tqc |Ant-v3 | 3339.362| 1969.906|1M | 149583| 202|
-|tqc |AntBulletEnv-v0 | 3456.717| 248.733|1M | 150000| 150|
-|tqc |BipedalWalker-v3 | 329.808| 45.083|500k | 149682| 254|
-|tqc |BipedalWalkerHardcore-v3 | 235.226| 110.569|2M | 149032| 131|
-|tqc |FetchPickAndPlace-v1 | -9.331| 6.850|1M | 150000| 3000|
-|tqc |FetchPush-v1 | -8.799| 5.438|1M | 150000| 3000|
-|tqc |FetchReach-v1 | -1.659| 0.873|20k | 150000| 3000|
-|tqc |FetchSlide-v1 | -29.210| 11.387|3M | 150000| 3000|
-|tqc |HalfCheetah-v3 | 12089.939| 127.440|1M | 150000| 150|
-|tqc |HalfCheetahBulletEnv-v0 | 3675.299| 17.681|1M | 150000| 150|
-|tqc |Hopper-v3 | 3754.199| 8.276|1M | 150000| 150|
-|tqc |HopperBulletEnv-v0 | 2662.373| 206.210|1M | 149881| 151|
-|tqc |Humanoid-v3 | 7239.320| 1647.498|2M | 149508| 165|
-|tqc |LunarLanderContinuous-v2 | 277.956| 25.466|500k | 149928| 706|
-|tqc |MountainCarContinuous-v0 | 63.641| 45.259|50k | 149796| 186|
-|tqc |PandaPickAndPlace-v1 | -8.024| 6.674|1M | 150000| 3000|
-|tqc |PandaPush-v1 | -6.405| 6.400|1M | 150000| 3000|
-|tqc |PandaReach-v1 | -1.768| 0.858|20k | 150000| 3000|
-|tqc |PandaSlide-v1 | -27.497| 9.868|3M | 150000| 3000|
-|tqc |PandaStack-v1 | -96.915| 17.240|1M | 150000| 1500|
-|tqc |Pendulum-v0 | -151.340| 87.893|20k | 150000| 750|
-|tqc |ReacherBulletEnv-v0 | 18.255| 9.543|300k | 150000| 1000|
-|tqc |Swimmer-v3 | 339.423| 1.486|1M | 150000| 150|
-|tqc |Walker2DBulletEnv-v0 | 2508.934| 614.624|1M | 149572| 159|
-|tqc |Walker2d-v3 | 4380.720| 500.489|1M | 149606| 152|
-|tqc |parking-v0 | -6.762| 2.690|100k | 149983| 7528|
-|trpo |Acrobot-v1 | -83.114| 18.648|100k | 149976| 1783|
-|trpo |Ant-v3 | 4982.301| 663.761|1M | 149909| 153|
-|trpo |AntBulletEnv-v0 | 2560.621| 52.064|2M | 150000| 150|
-|trpo |BipedalWalker-v3 | 182.339| 145.570|1M | 148440| 148|
-|trpo |CartPole-v1 | 500.000| 0.000|100k | 150000| 300|
-|trpo |HalfCheetah-v3 | 1785.476| 68.672|1M | 150000| 150|
-|trpo |HalfCheetahBulletEnv-v0 | 2758.752| 327.032|2M | 150000| 150|
-|trpo |Hopper-v3 | 3618.386| 356.768|1M | 149575| 152|
-|trpo |HopperBulletEnv-v0 | 2565.416| 410.298|1M | 149640| 154|
-|trpo |LunarLander-v2 | 133.166| 112.173|200k | 149088| 230|
-|trpo |LunarLanderContinuous-v2 | 262.387| 21.428|200k | 149925| 501|
-|trpo |MountainCar-v0 | -107.278| 13.231|100k | 149974| 1398|
-|trpo |MountainCarContinuous-v0 | 92.489| 0.355|50k | 149971| 1732|
-|trpo |Pendulum-v0 | -174.631| 127.577|100k | 150000| 750|
-|trpo |ReacherBulletEnv-v0 | 14.741| 11.559|300k | 150000| 1000|
-|trpo |Swimmer-v3 | 365.663| 2.087|1M | 150000| 150|
-|trpo |Walker2DBulletEnv-v0 | 1483.467| 823.468|2M | 149860| 197|
-|trpo |Walker2d-v3 | 4933.148| 1452.538|1M | 149054| 163|
+| algo | env_id |mean_reward|std_reward|n_timesteps|eval_timesteps|eval_episodes|
+|--------|-----------------------------|----------:|---------:|-----------|-------------:|------------:|
+|a2c |Acrobot-v1 | -83.353| 17.213|500k | 149979| 1778|
+|a2c |AntBulletEnv-v0 | 2497.147| 37.359|2M | 150000| 150|
+|a2c |AsteroidsNoFrameskip-v4 | 1286.550| 423.750|10M | 614138| 258|
+|a2c |BeamRiderNoFrameskip-v4 | 2890.298| 1379.137|10M | 591104| 47|
+|a2c |BipedalWalker-v3 | 299.754| 23.459|5M | 149287| 208|
+|a2c |BipedalWalkerHardcore-v3 | 96.171| 122.943|200M | 149704| 113|
+|a2c |BreakoutNoFrameskip-v4 | 279.793| 122.177|10M | 604115| 82|
+|a2c |CartPole-v1 | 500.000| 0.000|500k | 150000| 300|
+|a2c |EnduroNoFrameskip-v4 | 0.000| 0.000|10M | 599040| 45|
+|a2c |HalfCheetah-v3 | 3041.174| 157.265|1M | 150000| 150|
+|a2c |HalfCheetahBulletEnv-v0 | 2107.384| 36.008|2M | 150000| 150|
+|a2c |Hopper-v3 | 733.454| 376.574|1M | 149987| 580|
+|a2c |HopperBulletEnv-v0 | 815.355| 313.798|2M | 149541| 254|
+|a2c |LunarLander-v2 | 155.751| 80.419|200k | 149443| 297|
+|a2c |LunarLanderContinuous-v2 | 84.225| 145.906|5M | 149305| 256|
+|a2c |MountainCar-v0 | -111.263| 24.087|1M | 149982| 1348|
+|a2c |MountainCarContinuous-v0 | 91.166| 0.255|100k | 149923| 1659|
+|a2c |Pendulum-v1 | -162.965| 103.210|1M | 150000| 750|
+|a2c |PongNoFrameskip-v4 | 17.292| 3.214|10M | 594910| 65|
+|a2c |QbertNoFrameskip-v4 | 3882.345| 1223.327|10M | 610670| 194|
+|a2c |ReacherBulletEnv-v0 | 14.968| 10.978|2M | 150000| 1000|
+|a2c |RoadRunnerNoFrameskip-v4 | 31671.512| 6364.085|10M | 606710| 172|
+|a2c |SeaquestNoFrameskip-v4 | 1721.493| 105.339|10M | 599691| 67|
+|a2c |SpaceInvadersNoFrameskip-v4 | 627.160| 201.974|10M | 604848| 162|
+|a2c |Swimmer-v3 | 200.627| 2.544|1M | 150000| 150|
+|a2c |Walker2DBulletEnv-v0 | 858.209| 333.116|2M | 149156| 173|
+|ars |Acrobot-v1 | -82.884| 23.825|500k | 149985| 1788|
+|ars |Ant-v3 | 2333.773| 20.597|75M | 150000| 150|
+|ars |CartPole-v1 | 500.000| 0.000|50k | 150000| 300|
+|ars |HalfCheetah-v3 | 4815.192| 1340.752|12M | 150000| 150|
+|ars |Hopper-v3 | 3343.919| 5.730|7M | 150000| 150|
+|ars |LunarLanderContinuous-v2 | 167.959| 147.071|2M | 149883| 562|
+|ars |MountainCar-v0 | -122.000| 33.456|500k | 149938| 1229|
+|ars |MountainCarContinuous-v0 | 96.672| 0.784|500k | 149990| 621|
+|ars |Pendulum-v1 | -212.540| 160.444|2M | 150000| 750|
+|ars |Swimmer-v3 | 355.267| 12.796|2M | 150000| 150|
+|ars |Walker2d-v3 | 2993.582| 166.289|75M | 149821| 152|
+|ddpg |AntBulletEnv-v0 | 2399.147| 75.410|1M | 150000| 150|
+|ddpg |BipedalWalker-v3 | 197.486| 141.580|1M | 149237| 227|
+|ddpg |HalfCheetahBulletEnv-v0 | 2078.325| 208.379|1M | 150000| 150|
+|ddpg |HopperBulletEnv-v0 | 1157.065| 448.695|1M | 149565| 346|
+|ddpg |LunarLanderContinuous-v2 | 230.217| 92.372|300k | 149862| 556|
+|ddpg |MountainCarContinuous-v0 | 93.512| 0.048|300k | 149965| 2260|
+|ddpg |Pendulum-v1 | -152.099| 94.282|20k | 150000| 750|
+|ddpg |ReacherBulletEnv-v0 | 15.582| 9.606|300k | 150000| 1000|
+|ddpg |Walker2DBulletEnv-v0 | 1387.591| 736.955|1M | 149051| 208|
+|dqn |Acrobot-v1 | -76.639| 11.752|100k | 149998| 1932|
+|dqn |AsteroidsNoFrameskip-v4 | 782.687| 259.247|10M | 607962| 134|
+|dqn |BeamRiderNoFrameskip-v4 | 4295.946| 1790.458|10M | 600832| 37|
+|dqn |BreakoutNoFrameskip-v4 | 358.327| 61.981|10M | 601461| 55|
+|dqn |CartPole-v1 | 500.000| 0.000|50k | 150000| 300|
+|dqn |EnduroNoFrameskip-v4 | 830.929| 194.544|10M | 599040| 14|
+|dqn |LunarLander-v2 | 154.382| 79.241|100k | 149373| 200|
+|dqn |MountainCar-v0 | -100.849| 9.925|120k | 149962| 1487|
+|dqn |PongNoFrameskip-v4 | 20.602| 0.613|10M | 598998| 88|
+|dqn |QbertNoFrameskip-v4 | 9496.774| 5399.633|10M | 605844| 124|
+|dqn |RoadRunnerNoFrameskip-v4 | 40396.350| 7069.131|10M | 603257| 137|
+|dqn |SeaquestNoFrameskip-v4 | 2000.290| 606.644|10M | 599505| 69|
+|dqn |SpaceInvadersNoFrameskip-v4 | 622.742| 201.564|10M | 604311| 155|
+|ppo |Acrobot-v1 | -73.506| 18.201|1M | 149979| 2013|
+|ppo |Ant-v3 | 1327.158| 451.577|1M | 149572| 175|
+|ppo |AntBulletEnv-v0 | 2865.922| 56.468|2M | 150000| 150|
+|ppo |AsteroidsNoFrameskip-v4 | 2156.174| 744.640|10M | 602092| 149|
+|ppo |BeamRiderNoFrameskip-v4 | 3397.000| 1662.368|10M | 598926| 46|
+|ppo |BipedalWalker-v3 | 287.939| 2.448|5M | 149589| 123|
+|ppo |BipedalWalkerHardcore-v3 | 122.374| 117.605|100M | 148036| 105|
+|ppo |BreakoutNoFrameskip-v4 | 398.033| 33.328|10M | 600418| 60|
+|ppo |CartPole-v1 | 500.000| 0.000|100k | 150000| 300|
+|ppo |EnduroNoFrameskip-v4 | 996.364| 176.090|10M | 572416| 11|
+|ppo |HalfCheetah-v3 | 5819.099| 663.530|1M | 150000| 150|
+|ppo |HalfCheetahBulletEnv-v0 | 2924.721| 64.465|2M | 150000| 150|
+|ppo |Hopper-v3 | 2410.435| 10.026|1M | 150000| 150|
+|ppo |HopperBulletEnv-v0 | 2575.054| 223.301|2M | 149094| 152|
+|ppo |LunarLander-v2 | 242.119| 31.823|1M | 149636| 369|
+|ppo |LunarLanderContinuous-v2 | 270.863| 32.072|1M | 149956| 526|
+|ppo |MountainCar-v0 | -110.423| 19.473|1M | 149954| 1358|
+|ppo |MountainCarContinuous-v0 | 88.343| 2.572|20k | 149983| 633|
+|ppo |Pendulum-v1 | -172.225| 104.159|100k | 150000| 750|
+|ppo |PongNoFrameskip-v4 | 20.989| 0.105|10M | 599902| 90|
+|ppo |QbertNoFrameskip-v4 | 15627.108| 3313.538|10M | 600248| 83|
+|ppo |ReacherBulletEnv-v0 | 17.091| 11.048|1M | 150000| 1000|
+|ppo |RoadRunnerNoFrameskip-v4 | 40680.645| 6675.058|10M | 605786| 155|
+|ppo |SeaquestNoFrameskip-v4 | 1783.636| 34.096|10M | 598243| 66|
+|ppo |SpaceInvadersNoFrameskip-v4 | 960.331| 425.355|10M | 603771| 136|
+|ppo |Swimmer-v3 | 281.561| 9.671|1M | 150000| 150|
+|ppo |Walker2DBulletEnv-v0 | 2109.992| 13.899|2M | 150000| 150|
+|ppo |Walker2d-v3 | 3478.798| 821.708|1M | 149343| 171|
+|ppo_lstm|CarRacing-v0 | 862.549| 97.342|4M | 149588| 156|
+|ppo_lstm|CartPoleNoVel-v1 | 500.000| 0.000|100k | 150000| 300|
+|ppo_lstm|MountainCarContinuousNoVel-v0| 91.469| 1.776|300k | 149882| 1340|
+|ppo_lstm|PendulumNoVel-v1 | -217.933| 140.094|100k | 150000| 750|
+|qrdqn |Acrobot-v1 | -69.135| 9.967|100k | 149949| 2138|
+|qrdqn |AsteroidsNoFrameskip-v4 | 2185.303| 1097.172|10M | 599784| 66|
+|qrdqn |BeamRiderNoFrameskip-v4 | 17122.941| 10769.997|10M | 596483| 17|
+|qrdqn |BreakoutNoFrameskip-v4 | 393.600| 79.828|10M | 579711| 40|
+|qrdqn |CartPole-v1 | 500.000| 0.000|50k | 150000| 300|
+|qrdqn |EnduroNoFrameskip-v4 | 3231.200| 1311.801|10M | 585728| 5|
+|qrdqn |LunarLander-v2 | 70.236| 225.491|100k | 149957| 522|
+|qrdqn |MountainCar-v0 | -106.042| 15.536|120k | 149943| 1414|
+|qrdqn |PongNoFrameskip-v4 | 20.492| 0.687|10M | 597443| 63|
+|qrdqn |QbertNoFrameskip-v4 | 14799.728| 2917.629|10M | 600773| 92|
+|qrdqn |RoadRunnerNoFrameskip-v4 | 42325.424| 8361.161|10M | 591016| 59|
+|qrdqn |SeaquestNoFrameskip-v4 | 2557.576| 76.951|10M | 596275| 66|
+|qrdqn |SpaceInvadersNoFrameskip-v4 | 1899.928| 823.488|10M | 597218| 69|
+|sac |Ant-v3 | 4615.791| 1354.111|1M | 149074| 165|
+|sac |AntBulletEnv-v0 | 3073.114| 175.148|1M | 150000| 150|
+|sac |BipedalWalker-v3 | 297.668| 33.060|500k | 149530| 136|
+|sac |BipedalWalkerHardcore-v3 | 4.423| 103.910|10M | 149794| 88|
+|sac |HalfCheetah-v3 | 9535.451| 100.470|1M | 150000| 150|
+|sac |HalfCheetahBulletEnv-v0 | 2792.170| 12.088|1M | 150000| 150|
+|sac |Hopper-v3 | 2325.547| 1129.676|1M | 149841| 236|
+|sac |HopperBulletEnv-v0 | 2603.494| 164.322|1M | 149724| 151|
+|sac |Humanoid-v3 | 6232.287| 279.885|2M | 149460| 150|
+|sac |LunarLanderContinuous-v2 | 260.390| 65.467|500k | 149634| 672|
+|sac |MountainCarContinuous-v0 | 94.679| 1.134|50k | 149966| 1443|
+|sac |Pendulum-v1 | -156.995| 88.714|20k | 150000| 750|
+|sac |ReacherBulletEnv-v0 | 18.062| 9.729|300k | 150000| 1000|
+|sac |Swimmer-v3 | 345.568| 3.084|1M | 150000| 150|
+|sac |Walker2DBulletEnv-v0 | 2292.266| 13.970|1M | 149983| 150|
+|sac |Walker2d-v3 | 3863.203| 254.347|1M | 149309| 150|
+|td3 |Ant-v3 | 5813.274| 589.773|1M | 149393| 151|
+|td3 |AntBulletEnv-v0 | 3300.026| 54.640|1M | 150000| 150|
+|td3 |BipedalWalker-v3 | 305.990| 56.886|1M | 149999| 224|
+|td3 |BipedalWalkerHardcore-v3 | -98.116| 16.087|10M | 150000| 75|
+|td3 |HalfCheetah-v3 | 9655.666| 969.916|1M | 150000| 150|
+|td3 |HalfCheetahBulletEnv-v0 | 2821.641| 19.722|1M | 150000| 150|
+|td3 |Hopper-v3 | 3606.390| 4.027|1M | 150000| 150|
+|td3 |HopperBulletEnv-v0 | 2681.609| 27.806|1M | 149486| 150|
+|td3 |Humanoid-v3 | 5566.687| 14.544|2M | 150000| 150|
+|td3 |LunarLanderContinuous-v2 | 207.451| 67.562|300k | 149488| 337|
+|td3 |MountainCarContinuous-v0 | 93.483| 0.075|300k | 149976| 2275|
+|td3 |Pendulum-v1 | -151.855| 90.227|20k | 150000| 750|
+|td3 |ReacherBulletEnv-v0 | 17.114| 9.750|300k | 150000| 1000|
+|td3 |Swimmer-v3 | 359.127| 1.244|1M | 150000| 150|
+|td3 |Walker2DBulletEnv-v0 | 2213.672| 230.558|1M | 149800| 152|
+|td3 |Walker2d-v3 | 4717.823| 46.303|1M | 150000| 150|
+|tqc |Ant-v3 | 3339.362| 1969.906|1M | 149583| 202|
+|tqc |AntBulletEnv-v0 | 3456.717| 248.733|1M | 150000| 150|
+|tqc |BipedalWalker-v3 | 329.808| 45.083|500k | 149682| 254|
+|tqc |BipedalWalkerHardcore-v3 | 235.226| 110.569|2M | 149032| 131|
+|tqc |FetchPickAndPlace-v1 | -9.331| 6.850|1M | 150000| 3000|
+|tqc |FetchPush-v1 | -8.799| 5.438|1M | 150000| 3000|
+|tqc |FetchReach-v1 | -1.659| 0.873|20k | 150000| 3000|
+|tqc |FetchSlide-v1 | -29.210| 11.387|3M | 150000| 3000|
+|tqc |HalfCheetah-v3 | 12089.939| 127.440|1M | 150000| 150|
+|tqc |HalfCheetahBulletEnv-v0 | 3675.299| 17.681|1M | 150000| 150|
+|tqc |Hopper-v3 | 3754.199| 8.276|1M | 150000| 150|
+|tqc |HopperBulletEnv-v0 | 2662.373| 206.210|1M | 149881| 151|
+|tqc |Humanoid-v3 | 7239.320| 1647.498|2M | 149508| 165|
+|tqc |LunarLanderContinuous-v2 | 277.956| 25.466|500k | 149928| 706|
+|tqc |MountainCarContinuous-v0 | 63.641| 45.259|50k | 149796| 186|
+|tqc |PandaPickAndPlace-v1 | -8.024| 6.674|1M | 150000| 3000|
+|tqc |PandaPush-v1 | -6.405| 6.400|1M | 150000| 3000|
+|tqc |PandaReach-v1 | -1.768| 0.858|20k | 150000| 3000|
+|tqc |PandaSlide-v1 | -27.497| 9.868|3M | 150000| 3000|
+|tqc |PandaStack-v1 | -96.915| 17.240|1M | 150000| 1500|
+|tqc |Pendulum-v1 | -151.340| 87.893|20k | 150000| 750|
+|tqc |ReacherBulletEnv-v0 | 18.255| 9.543|300k | 150000| 1000|
+|tqc |Swimmer-v3 | 339.423| 1.486|1M | 150000| 150|
+|tqc |Walker2DBulletEnv-v0 | 2508.934| 614.624|1M | 149572| 159|
+|tqc |Walker2d-v3 | 4380.720| 500.489|1M | 149606| 152|
+|tqc |parking-v0 | -6.762| 2.690|100k | 149983| 7528|
+|trpo |Acrobot-v1 | -83.114| 18.648|100k | 149976| 1783|
+|trpo |Ant-v3 | 4982.301| 663.761|1M | 149909| 153|
+|trpo |AntBulletEnv-v0 | 2560.621| 52.064|2M | 150000| 150|
+|trpo |BipedalWalker-v3 | 182.339| 145.570|1M | 148440| 148|
+|trpo |CartPole-v1 | 500.000| 0.000|100k | 150000| 300|
+|trpo |HalfCheetah-v3 | 1785.476| 68.672|1M | 150000| 150|
+|trpo |HalfCheetahBulletEnv-v0 | 2758.752| 327.032|2M | 150000| 150|
+|trpo |Hopper-v3 | 3618.386| 356.768|1M | 149575| 152|
+|trpo |HopperBulletEnv-v0 | 2565.416| 410.298|1M | 149640| 154|
+|trpo |LunarLander-v2 | 133.166| 112.173|200k | 149088| 230|
+|trpo |LunarLanderContinuous-v2 | 262.387| 21.428|200k | 149925| 501|
+|trpo |MountainCar-v0 | -107.278| 13.231|100k | 149974| 1398|
+|trpo |MountainCarContinuous-v0 | 92.489| 0.355|50k | 149971| 1732|
+|trpo |Pendulum-v1 | -174.631| 127.577|100k | 150000| 750|
+|trpo |ReacherBulletEnv-v0 | 14.741| 11.559|300k | 150000| 1000|
+|trpo |Swimmer-v3 | 365.663| 2.087|1M | 150000| 150|
+|trpo |Walker2DBulletEnv-v0 | 1483.467| 823.468|2M | 149860| 197|
+|trpo |Walker2d-v3 | 4933.148| 1452.538|1M | 149054| 163|
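
# Editor's note: the prose added at the top of benchmark.md points readers to
# the model cards on https://huggingface.co/sb3; pulling one of those pretrained
# checkpoints takes only a few lines. A quick sketch, assuming the
# `huggingface_sb3` helper and the org's `sb3/<algo>-<env_id>` repo naming:
#
# ```python
# from huggingface_sb3 import load_from_hub
# from stable_baselines3 import PPO
#
# # Download the zipped checkpoint from the Hugging Face Hub (cached locally)
# checkpoint = load_from_hub(
#     repo_id="sb3/ppo-CartPole-v1",
#     filename="ppo-CartPole-v1.zip",
# )
# model = PPO.load(checkpoint)
# ```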
diff --git a/logs/benchmark/ppo-BipedalWalker-v3/0.monitor.csv b/logs/benchmark/ppo-BipedalWalker-v3/0.monitor.csv
index 65b44f3af..6a2bc3eca 100644
--- a/logs/benchmark/ppo-BipedalWalker-v3/0.monitor.csv
+++ b/logs/benchmark/ppo-BipedalWalker-v3/0.monitor.csv
@@ -1,235 +1,125 @@
-#{"t_start": 1614771346.3718872, "env_id": "BipedalWalker-v3"}
+#{"t_start": 1654205553.3965414, "env_id": "BipedalWalker-v3"}
r,l,t
-302.352527,761,3.219014
-304.643791,747,3.818765
-118.656863,577,4.303263
-186.806605,791,4.924971
-302.92318,727,5.496392
-95.043446,526,5.913008
-304.071331,736,6.488454
-304.961145,742,7.067342
-305.404706,747,7.653983
-43.647549,413,7.977797
-305.786318,733,8.550616
--65.333118,166,8.681444
-300.446239,878,9.366114
-302.302039,756,9.955367
--63.398006,161,10.085125
-303.612975,709,10.647742
-301.85852,756,11.242622
-302.07473,754,11.829957
-302.562625,784,12.443302
-302.523438,773,13.046803
-304.049687,740,13.627153
-304.22878,791,14.251293
-301.844681,798,14.872318
-304.990141,727,15.439051
-303.951432,747,16.025128
-301.880788,764,16.623663
-302.835429,744,17.20378
--63.166425,240,17.392857
--19.553167,255,17.598439
-71.883582,459,17.96465
-171.700376,662,18.479113
-303.257497,755,19.065179
--31.888369,226,19.241292
-301.43717,821,19.880159
-302.191556,772,20.484396
--45.01458,317,20.733858
-184.726106,752,21.320945
-302.56233,759,21.909227
-303.76186,792,22.526143
-301.663292,739,23.103
-303.368805,742,23.679868
-302.449036,748,24.260264
-303.924373,751,24.844364
-303.955462,754,25.4316
-305.325247,738,26.004931
-305.033799,736,26.575473
-182.856283,707,27.126774
-302.0432,770,27.727888
-304.139851,732,28.296654
-303.16995,764,28.890888
-301.428877,755,29.481789
-303.415154,745,30.066476
-306.591512,727,30.635611
-304.58744,762,31.230929
-303.367456,763,31.826358
-305.595314,753,32.413927
-4.373387,304,32.651849
-193.166831,746,33.231313
-125.378329,552,33.665584
-13.381573,340,33.931804
-96.788857,505,34.321924
-39.350058,450,34.66974
-305.335492,746,35.251793
--52.636007,202,35.409539
-304.796698,728,35.976646
-6.380472,301,36.210353
-304.459181,742,36.788108
-140.072415,754,37.373717
--12.509825,271,37.585901
-307.386125,708,38.136988
-303.447734,753,38.725262
-126.609499,627,39.212487
-302.364809,738,39.786229
-304.475175,725,40.354099
-302.940891,758,40.943596
-15.795655,325,41.197428
-305.404282,735,41.771184
-305.236132,723,42.330353
-305.94299,752,42.913521
-21.148425,353,43.189599
--52.330365,188,43.340062
-151.386882,659,43.851162
-90.424794,494,44.232565
-301.927897,770,44.826415
-304.757464,721,45.382777
-302.697478,760,45.96874
--22.39859,259,46.169842
-305.956405,723,46.73031
-305.664408,727,47.290006
-305.287772,717,47.841211
--22.082868,254,48.038846
-304.894636,746,48.616611
-303.331698,741,49.187315
-304.592178,747,49.760554
-305.563102,736,50.328383
-302.689926,746,50.902858
-300.186937,802,51.52006
--15.602549,276,51.733101
-94.133354,512,52.128198
-305.610114,726,52.68604
-298.369037,798,53.300444
-95.514182,510,53.69267
-301.644435,796,54.305062
-306.86572,713,54.854769
-304.306364,747,55.433634
-303.593145,771,56.037491
-39.926782,380,56.334806
--40.461412,217,56.50227
--71.978628,144,56.616035
-127.880317,642,57.109494
-96.525125,545,57.532723
-54.597902,435,57.869843
-57.3985,418,58.193319
-302.945615,740,58.763972
-73.387278,473,59.135621
-302.676462,753,59.71879
-304.187669,766,60.308922
-301.580132,810,60.937058
-300.510398,797,61.56038
-76.407432,559,62.000009
-304.560592,739,62.571082
-304.481442,771,63.166171
-303.622933,732,63.738984
-118.340882,581,64.186391
-304.182357,734,64.751831
-131.486683,644,65.245947
-302.13052,756,65.827673
-303.603819,753,66.405886
-304.842933,710,66.955831
-304.3627,772,67.550153
-106.366572,585,68.004118
-301.121862,819,68.633571
-303.438111,757,69.220238
-71.421583,675,69.735597
-303.017117,767,70.33188
-31.324711,356,70.612745
-301.624264,746,71.196275
-301.047988,801,71.811748
-303.461027,752,72.399489
-302.993697,748,72.975591
-304.899086,759,73.563025
-302.1779,776,74.158236
-305.593249,735,74.725231
-305.155352,741,75.293045
-302.828854,747,75.872387
--50.970516,283,76.090183
--39.507444,225,76.264404
-303.69021,749,76.843124
-300.996457,780,77.443655
-304.156677,746,78.018283
-303.679295,761,78.605159
--18.347728,291,78.831671
-156.79956,689,79.371623
-136.722458,660,79.888347
-64.885886,483,80.268309
-41.603007,411,80.592154
-304.288839,763,81.193667
-35.546111,373,81.487278
-304.284635,747,82.073154
--73.988576,149,82.193739
-142.886248,656,82.707909
-302.361418,750,83.298816
-304.1507,729,83.875597
-304.413431,767,84.478215
-305.101579,743,85.058137
-199.662759,723,85.623957
-305.104542,726,86.195029
-301.071559,796,86.818468
-301.539925,798,87.450715
-301.531898,815,88.096411
-302.795926,761,88.697836
-303.416066,744,89.286814
-30.285208,399,89.600553
--40.258154,207,89.764034
-51.430234,396,90.076918
-40.241742,405,90.39917
-305.114193,733,90.973302
-306.063059,728,91.546085
--47.185539,228,91.725956
-300.573573,790,92.342886
--54.868415,200,92.501091
-303.220558,760,93.100096
-304.546298,752,93.692549
-82.932796,473,94.063656
-47.441718,427,94.402995
-303.839423,725,94.966027
-305.277608,763,95.567472
-23.496617,367,95.856493
-128.678039,569,96.299928
--32.592752,223,96.476208
--43.753721,211,96.642032
-305.424057,751,97.228528
-113.063822,584,97.688616
-303.492005,756,98.289925
-304.841375,758,98.880774
-304.577313,758,99.476772
-301.156213,802,100.103982
-304.010025,783,100.716021
-302.667361,740,101.299324
-303.918448,740,101.882527
-301.425668,778,102.492409
-305.446357,715,103.055928
-92.767958,589,103.520237
-301.549759,799,104.138262
-303.163771,739,104.712989
-300.678757,794,105.329364
--41.254046,225,105.503753
-303.030253,793,106.118409
-162.468708,678,106.652562
-51.832027,451,107.011397
-298.00122,862,107.677496
-303.813658,724,108.236765
-303.409989,750,108.822069
-302.429025,762,109.423768
-160.56586,682,109.950236
-303.36268,747,110.535994
-304.766964,733,111.114457
-303.314112,736,111.684765
-303.582416,751,112.266338
-302.967288,736,112.846985
--70.653028,183,112.992295
-304.122437,744,113.57416
-164.183535,692,114.111993
-303.447712,736,114.683947
-302.214696,756,115.280121
-158.744251,733,115.843334
-137.7137,724,116.403516
-297.088702,858,117.077439
-303.803565,749,117.659656
-307.109074,717,118.218738
-170.724354,709,118.780417
-303.059612,752,119.374358
--75.075435,127,119.475801
+286.498863,1232,3.148969
+290.03395,1193,3.942982
+285.871843,1236,4.766715
+291.850978,1198,5.564058
+288.768871,1210,6.368981
+288.157038,1212,7.176008
+289.361869,1202,7.97352
+287.842606,1225,8.786903
+287.280252,1190,9.578171
+284.324796,1247,10.404098
+288.956237,1213,11.208744
+288.669449,1191,11.998618
+285.979899,1247,12.825689
+289.639171,1190,13.618726
+286.947732,1244,14.446376
+287.606982,1188,15.235259
+286.782178,1223,16.048814
+288.687124,1216,16.856583
+285.721313,1232,17.67472
+287.000542,1216,18.483637
+288.605325,1213,19.286785
+291.318049,1207,20.087602
+289.566242,1193,20.877623
+292.464293,1179,21.661781
+290.672828,1192,22.457295
+290.689305,1182,23.244957
+287.682502,1206,24.045446
+291.957954,1163,24.819003
+282.784247,1263,25.659568
+287.454218,1216,26.470194
+285.451189,1228,27.283888
+287.095007,1272,28.128968
+287.733682,1228,28.945501
+290.772435,1195,29.742402
+288.365004,1236,30.56567
+285.35095,1225,31.380461
+289.413562,1208,32.184891
+288.015387,1211,32.991751
+288.731107,1205,33.791121
+290.727256,1189,34.581942
+289.43828,1207,35.3862
+287.729774,1227,36.204048
+285.253831,1260,37.041651
+289.768125,1185,37.826303
+285.958823,1233,38.646922
+287.200568,1222,39.46325
+290.043164,1185,40.250997
+289.101907,1211,41.05368
+286.462372,1243,41.87645
+285.64948,1221,42.688791
+287.519667,1187,43.476118
+284.971831,1267,44.316454
+288.742228,1214,45.120938
+285.628617,1255,45.954174
+284.757907,1250,46.783232
+284.347205,1245,47.609643
+289.301662,1211,48.409139
+284.519905,1255,49.241136
+283.713437,1242,50.06516
+289.766618,1189,50.851352
+282.590464,1273,51.692035
+289.451637,1175,52.468679
+282.330197,1285,53.320945
+290.373129,1201,54.118149
+285.502483,1222,54.927324
+290.665951,1163,55.698771
+289.852728,1213,56.503432
+287.244561,1205,57.306046
+286.817512,1223,58.116854
+291.118836,1183,58.904925
+289.975692,1205,59.704293
+291.492401,1159,60.475756
+287.211862,1228,61.291739
+284.231949,1244,62.120625
+287.456086,1196,62.914544
+286.782568,1231,63.733352
+290.014788,1182,64.517536
+285.797936,1238,65.341084
+286.128281,1239,66.164153
+291.3156,1168,66.940698
+285.707421,1227,67.7551
+287.444993,1217,68.563281
+287.287142,1235,69.380708
+287.296313,1252,70.212887
+285.059782,1244,71.037529
+285.075845,1250,71.86698
+289.072542,1172,72.643633
+286.401297,1247,73.470394
+289.753727,1171,74.244725
+288.401924,1186,75.030774
+289.74776,1210,75.834194
+290.13916,1177,76.612051
+292.795935,1181,77.396483
+285.884796,1247,78.222464
+288.603669,1212,79.027307
+289.873579,1212,79.832399
+285.787775,1230,80.647568
+290.36039,1202,81.441948
+290.807051,1179,82.224544
+285.749472,1245,83.048675
+290.984798,1179,83.832279
+288.575334,1226,84.645881
+286.358977,1210,85.449784
+287.503688,1221,86.257498
+290.378601,1209,87.059999
+289.137484,1223,87.869361
+288.796009,1224,88.683386
+287.783809,1249,89.514697
+291.227006,1151,90.278997
+280.305489,1277,91.127263
+283.010979,1279,91.975935
+286.370462,1259,92.809959
+294.429234,1149,93.57178
+286.001682,1213,94.377335
+287.589849,1233,95.19614
+290.317775,1200,95.992744
+286.754871,1236,96.815568
+289.788937,1205,97.617576
+286.015599,1251,98.448357
+287.574138,1193,99.242593
+291.136559,1193,100.036958
+291.611669,1221,100.848697
+286.55367,1234,101.670848
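
# Editor's note: each monitor file in this diff follows the stable-baselines3
# `Monitor` format: one `#`-prefixed JSON header line (`t_start`, `env_id`),
# then CSV rows where `r` is the episode reward, `l` the episode length in
# steps, and `t` the wall-clock time since `t_start`. The mean/std columns of
# the benchmark table can be recomputed from these rows; a minimal sketch (the
# path is a placeholder):
#
# ```python
# import pandas as pd
#
# # Skip the '#{...}' JSON header, then read the r,l,t columns
# df = pd.read_csv("logs/benchmark/ppo-BipedalWalker-v3/0.monitor.csv", skiprows=1)
#
# print(f"mean_reward: {df['r'].mean():.3f}")
# print(f"std_reward:  {df['r'].std():.3f}")
# print(f"eval_timesteps: {df['l'].sum()}, eval_episodes: {len(df)}")
# ```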
diff --git a/logs/benchmark/ppo_lstm-CarRacing-v0/0.monitor.csv b/logs/benchmark/ppo_lstm-CarRacing-v0/0.monitor.csv
new file mode 100644
index 000000000..147b6093c
--- /dev/null
+++ b/logs/benchmark/ppo_lstm-CarRacing-v0/0.monitor.csv
@@ -0,0 +1,158 @@
+#{"t_start": 1654204830.0419693, "env_id": "CarRacing-v0"}
+r,l,t
+792.61745,1000,6.181207
+885.765125,1000,9.131669
+873.063973,1000,12.114203
+890.740741,1000,15.143978
+877.272727,1000,18.152218
+889.169675,1000,21.111168
+890.066225,1000,24.140795
+910.2,898,26.813054
+879.757085,1000,29.724176
+915.2,848,32.276649
+930.5,695,34.305526
+869.96997,1000,37.345273
+920.5,795,39.70204
+896.296296,1000,42.632593
+886.885246,1000,45.646431
+858.477509,1000,48.618872
+892.619926,1000,51.57379
+887.261146,1000,54.593082
+877.491961,1000,57.618382
+798.954704,1000,60.627211
+870.967742,1000,63.656299
+889.247312,1000,66.617474
+812.02346,1000,69.691475
+924.3,757,71.923429
+848.170732,1000,74.979877
+881.595092,1000,78.039151
+905.8,942,80.79825
+873.684211,1000,83.779894
+891.03139,1000,86.646081
+901.4,986,89.59265
+543.712575,1000,92.629697
+893.006993,1000,95.612338
+914.1,859,98.190614
+896.254682,1000,101.128834
+890.131579,1000,104.133495
+887.421384,1000,107.167422
+367.625899,1000,110.111629
+889.966555,1000,113.082222
+926.4,736,115.258553
+893.006993,1000,118.254707
+850.920245,1000,121.333169
+893.78882,1000,124.388279
+924.0,760,126.620384
+916.3,837,129.113128
+871.929825,1000,132.089347
+917.7,823,134.529402
+911.2,888,137.202349
+893.150685,1000,140.205283
+851.807229,1000,143.279442
+896.644295,1000,146.267079
+565.540541,1000,149.227145
+920.9,791,151.540307
+880.707395,1000,154.581006
+868.421053,1000,157.576324
+333.544304,1000,160.562149
+902.9,971,163.504401
+890.797546,1000,166.568459
+918.6,814,169.000762
+861.651917,1000,172.072073
+919.1,809,174.487601
+892.125984,1000,177.403417
+877.272727,1000,180.424361
+888.847584,1000,183.378517
+896.551724,1000,186.378563
+911.4,886,189.030784
+896.296296,1000,191.977422
+850.769231,1000,195.032133
+870.873786,1000,198.056901
+880.327869,1000,201.057348
+730.508475,1000,204.041407
+893.485342,1000,207.073689
+875.460123,1000,210.124961
+918.5,815,212.544901
+856.375839,1000,215.530396
+875.0,1000,218.497411
+821.875,1000,221.545459
+868.75,1000,224.616827
+893.485342,1000,227.639646
+876.510067,1000,230.643212
+750.793651,1000,233.652184
+874.683544,1000,236.677366
+919.6,804,239.046962
+668.421053,1000,242.005007
+922.7,773,244.307434
+857.928803,1000,247.306081
+916.1,839,249.813316
+873.684211,1000,252.75659
+886.013986,1000,255.756162
+886.15917,1000,258.746746
+887.220447,1000,261.774839
+914.9,851,264.3091
+873.59736,1000,267.316416
+845.619335,1000,270.370811
+928.3,717,272.490756
+896.855346,1000,275.528629
+701.775148,1000,278.577352
+908.0,920,281.343297
+925.1,749,283.526727
+874.169742,1000,286.490014
+896.563574,1000,289.479618
+907.6,924,292.206284
+862.962963,1000,295.190832
+893.569132,1000,298.233784
+883.050847,1000,301.260211
+876.744186,1000,304.196508
+878.417266,1000,307.165412
+794.409938,1000,310.230357
+885.130112,1000,313.173896
+773.239437,1000,316.211473
+880.645161,1000,319.220745
+893.127148,1000,322.228811
+872.972973,1000,325.207085
+890.196078,1000,328.239602
+862.616822,1000,331.313493
+884.126984,1000,334.355265
+910.9,891,337.007747
+896.466431,1000,339.998146
+886.30137,1000,342.973861
+908.3,917,345.713885
+879.166667,1000,348.697348
+914.3,857,351.258304
+921.4,786,353.599769
+913.8,862,356.203853
+851.890034,1000,359.209628
+886.885246,1000,362.247912
+921.4,786,364.595427
+909.0,910,367.31875
+408.928571,1000,370.375511
+884.496124,1000,373.32294
+889.169675,1000,376.328308
+736.956522,1000,379.304595
+880.263158,1000,382.357421
+863.69637,1000,385.383473
+870.149254,1000,388.316006
+887.755102,1000,391.249082
+874.110032,1000,394.259789
+885.964912,1000,397.251099
+886.254296,1000,400.253898
+890.595611,1000,403.304099
+917.8,822,405.757406
+496.923077,1000,408.775885
+882.269504,1000,411.756959
+892.509363,1000,414.711423
+889.51049,1000,417.672504
+933.8,662,419.605968
+901.9,981,422.587127
+871.061093,1000,425.625754
+509.204196,867,428.206288
+883.443709,1000,431.226022
+896.551724,1000,434.219728
+893.355482,1000,437.238052
+896.527778,1000,440.214532
+908.9,911,442.970268
+896.632997,1000,445.964749
+915.1,849,448.522999
+880.314961,1000,451.533052
diff --git a/logs/benchmark/ppo_lstm-CartPoleNoVel-v1/0.monitor.csv b/logs/benchmark/ppo_lstm-CartPoleNoVel-v1/0.monitor.csv
new file mode 100644
index 000000000..34af222c7
--- /dev/null
+++ b/logs/benchmark/ppo_lstm-CartPoleNoVel-v1/0.monitor.csv
@@ -0,0 +1,302 @@
+#{"t_start": 1654204464.909158, "env_id": "CartPoleNoVel-v1"}
+r,l,t
+500.0,500,2.692042
+500.0,500,3.061221
+500.0,500,3.429941
+500.0,500,3.798015
+500.0,500,4.16562
+500.0,500,4.535252
+500.0,500,4.903308
+500.0,500,5.270751
+500.0,500,5.639004
+500.0,500,6.007921
+500.0,500,6.376441
+500.0,500,6.744478
+500.0,500,7.112341
+500.0,500,7.482308
+500.0,500,7.849999
+500.0,500,8.217563
+500.0,500,8.585177
+500.0,500,8.953705
+500.0,500,9.321366
+500.0,500,9.688998
+500.0,500,10.056689
+500.0,500,10.425152
+500.0,500,10.792864
+500.0,500,11.160644
+500.0,500,11.528224
+500.0,500,11.897002
+500.0,500,12.264578
+500.0,500,12.632639
+500.0,500,13.000528
+500.0,500,13.369246
+500.0,500,13.737418
+500.0,500,14.104923
+500.0,500,14.472503
+500.0,500,14.840163
+500.0,500,15.209556
+500.0,500,15.577311
+500.0,500,15.944919
+500.0,500,16.312535
+500.0,500,16.68169
+500.0,500,17.04989
+500.0,500,17.417994
+500.0,500,17.78551
+500.0,500,18.154947
+500.0,500,18.52279
+500.0,500,18.890535
+500.0,500,19.258383
+500.0,500,19.627402
+500.0,500,19.995366
+500.0,500,20.363095
+500.0,500,20.730902
+500.0,500,21.100272
+500.0,500,21.467908
+500.0,500,21.836105
+500.0,500,22.203988
+500.0,500,22.573255
+500.0,500,22.941118
+500.0,500,23.309029
+500.0,500,23.676801
+500.0,500,24.045863
+500.0,500,24.413585
+500.0,500,24.781154
+500.0,500,25.14879
+500.0,500,25.517311
+500.0,500,25.885126
+500.0,500,26.252733
+500.0,500,26.620603
+500.0,500,26.989465
+500.0,500,27.357436
+500.0,500,27.725199
+500.0,500,28.092894
+500.0,500,28.460764
+500.0,500,28.828977
+500.0,500,29.196474
+500.0,500,29.564007
+500.0,500,29.931772
+500.0,500,30.300713
+500.0,500,30.668463
+500.0,500,31.036354
+500.0,500,31.404369
+500.0,500,31.773325
+500.0,500,32.14113
+500.0,500,32.508817
+500.0,500,32.876434
+500.0,500,33.245192
+500.0,500,33.61313
+500.0,500,33.980896
+500.0,500,34.348483
+500.0,500,34.71754
+500.0,500,35.085228
+500.0,500,35.453714
+500.0,500,35.821541
+500.0,500,36.190614
+500.0,500,36.558201
+500.0,500,36.925709
+500.0,500,37.293244
+500.0,500,37.662188
+500.0,500,38.029895
+500.0,500,38.397722
+500.0,500,38.76532
+500.0,500,39.133835
+500.0,500,39.501667
+500.0,500,39.868861
+500.0,500,40.236549
+500.0,500,40.605014
+500.0,500,40.972768
+500.0,500,41.340191
+500.0,500,41.70803
+500.0,500,42.075676
+500.0,500,42.444287
+500.0,500,42.811715
+500.0,500,43.180297
+500.0,500,43.547704
+500.0,500,43.916625
+500.0,500,44.2841
+500.0,500,44.65163
+500.0,500,45.019142
+500.0,500,45.388199
+500.0,500,45.75553
+500.0,500,46.123094
+500.0,500,46.49043
+500.0,500,46.859583
+500.0,500,47.227015
+500.0,500,47.594541
+500.0,500,47.961887
+500.0,500,48.3308
+500.0,500,48.698397
+500.0,500,49.065922
+500.0,500,49.433188
+500.0,500,49.802063
+500.0,500,50.169557
+500.0,500,50.537384
+500.0,500,50.904823
+500.0,500,51.273586
+500.0,500,51.641172
+500.0,500,52.00872
+500.0,500,52.376147
+500.0,500,52.745009
+500.0,500,53.112735
+500.0,500,53.480421
+500.0,500,53.847858
+500.0,500,54.216092
+500.0,500,54.583584
+500.0,500,54.951037
+500.0,500,55.318777
+500.0,500,55.686959
+500.0,500,56.054983
+500.0,500,56.422587
+500.0,500,56.789983
+500.0,500,57.157527
+500.0,500,57.526113
+500.0,500,57.893817
+500.0,500,58.261189
+500.0,500,58.628896
+500.0,500,58.997658
+500.0,500,59.365289
+500.0,500,59.732727
+500.0,500,60.100073
+500.0,500,60.468939
+500.0,500,60.8367
+500.0,500,61.204306
+500.0,500,61.571862
+500.0,500,61.940647
+500.0,500,62.308141
+500.0,500,62.675627
+500.0,500,63.04311
+500.0,500,63.411911
+500.0,500,63.779265
+500.0,500,64.146678
+500.0,500,64.514226
+500.0,500,64.882792
+500.0,500,65.250464
+500.0,500,65.617751
+500.0,500,65.98526
+500.0,500,66.35398
+500.0,500,66.721609
+500.0,500,67.089066
+500.0,500,67.456589
+500.0,500,67.824868
+500.0,500,68.192576
+500.0,500,68.560127
+500.0,500,68.927803
+500.0,500,69.296072
+500.0,500,69.664061
+500.0,500,70.031497
+500.0,500,70.398919
+500.0,500,70.766376
+500.0,500,71.134893
+500.0,500,71.502915
+500.0,500,71.870472
+500.0,500,72.237988
+500.0,500,72.607162
+500.0,500,72.974757
+500.0,500,73.342227
+500.0,500,73.709692
+500.0,500,74.078555
+500.0,500,74.446039
+500.0,500,74.813464
+500.0,500,75.180903
+500.0,500,75.549648
+500.0,500,75.917047
+500.0,500,76.284521
+500.0,500,76.651951
+500.0,500,77.020492
+500.0,500,77.387824
+500.0,500,77.755272
+500.0,500,78.122754
+500.0,500,78.491089
+500.0,500,78.858404
+500.0,500,79.225901
+500.0,500,79.593222
+500.0,500,79.961842
+500.0,500,80.329139
+500.0,500,80.696746
+500.0,500,81.064445
+500.0,500,81.432933
+500.0,500,81.800267
+500.0,500,82.16798
+500.0,500,82.535597
+500.0,500,82.904233
+500.0,500,83.272056
+500.0,500,83.6396
+500.0,500,84.007157
+500.0,500,84.37485
+500.0,500,84.743217
+500.0,500,85.110842
+500.0,500,85.478438
+500.0,500,85.845806
+500.0,500,86.214685
+500.0,500,86.582283
+500.0,500,86.949661
+500.0,500,87.317286
+500.0,500,87.686204
+500.0,500,88.053945
+500.0,500,88.421412
+500.0,500,88.788829
+500.0,500,89.157843
+500.0,500,89.52521
+500.0,500,89.892562
+500.0,500,90.260231
+500.0,500,90.629248
+500.0,500,90.997033
+500.0,500,91.364556
+500.0,500,91.73211
+500.0,500,92.100667
+500.0,500,92.468477
+500.0,500,92.836055
+500.0,500,93.203776
+500.0,500,93.572638
+500.0,500,93.940437
+500.0,500,94.307746
+500.0,500,94.675457
+500.0,500,95.044019
+500.0,500,95.412237
+500.0,500,95.779851
+500.0,500,96.147491
+500.0,500,96.516354
+500.0,500,96.884081
+500.0,500,97.251848
+500.0,500,97.619372
+500.0,500,97.987247
+500.0,500,98.355906
+500.0,500,98.723397
+500.0,500,99.091107
+500.0,500,99.458724
+500.0,500,99.827759
+500.0,500,100.195648
+500.0,500,100.563196
+500.0,500,100.930649
+500.0,500,101.299022
+500.0,500,101.6664
+500.0,500,102.034079
+500.0,500,102.401322
+500.0,500,102.770082
+500.0,500,103.137965
+500.0,500,103.505491
+500.0,500,103.872831
+500.0,500,104.24165
+500.0,500,104.609235
+500.0,500,104.976797
+500.0,500,105.344488
+500.0,500,105.713245
+500.0,500,106.080787
+500.0,500,106.448571
+500.0,500,106.815917
+500.0,500,107.185104
+500.0,500,107.552569
+500.0,500,107.919791
+500.0,500,108.287235
+500.0,500,108.655843
+500.0,500,109.023948
+500.0,500,109.391358
+500.0,500,109.758862
+500.0,500,110.127638
+500.0,500,110.495283
+500.0,500,110.862811
+500.0,500,111.23066
+500.0,500,111.598368
+500.0,500,111.966573
+500.0,500,112.334113
+500.0,500,112.701541
diff --git a/logs/benchmark/ppo_lstm-MountainCarContinuousNoVel-v0/0.monitor.csv b/logs/benchmark/ppo_lstm-MountainCarContinuousNoVel-v0/0.monitor.csv
new file mode 100644
index 000000000..894fd1bf5
--- /dev/null
+++ b/logs/benchmark/ppo_lstm-MountainCarContinuousNoVel-v0/0.monitor.csv
@@ -0,0 +1,1342 @@
+#{"t_start": 1654204713.1905868, "env_id": "MountainCarContinuousNoVel-v0"}
+r,l,t
+88.870471,129,2.384638
+93.200782,96,2.452994
+93.1153,103,2.524679
+89.783776,122,2.609519
+89.084276,127,2.697699
+93.17103,103,2.769043
+90.272421,119,2.85192
+93.165781,103,2.923424
+88.882704,129,3.012742
+92.67766,106,3.086247
+88.822368,129,3.17575
+93.166806,96,3.242452
+91.849978,111,3.319841
+93.109522,103,3.391466
+93.15163,97,3.458845
+93.196885,103,3.530293
+91.575946,113,3.608723
+93.184759,98,3.676693
+88.82424,129,3.766255
+93.123139,98,3.834311
+93.250074,102,3.905042
+88.706054,130,3.997004
+90.768148,117,4.078424
+93.191438,103,4.149975
+93.216239,99,4.219114
+93.197325,102,4.290391
+93.191193,104,4.362637
+89.107044,127,4.450638
+92.474766,108,4.525669
+89.620468,123,4.610974
+89.775248,122,4.695605
+89.922732,121,4.779866
+92.583188,107,4.854172
+93.18132,98,4.922286
+90.477477,118,5.004337
+90.292669,119,5.086857
+93.172218,98,5.154969
+89.948811,121,5.239173
+92.363752,108,5.314215
+89.611613,123,5.399477
+89.917513,121,5.484584
+93.125296,97,5.551962
+88.623945,131,5.642749
+89.649255,123,5.728098
+93.184295,96,5.794691
+89.616299,123,5.88021
+93.133919,98,5.948249
+93.207498,99,6.016968
+92.764831,106,6.090605
+93.144456,98,6.158695
+88.950705,128,6.247501
+89.123758,127,6.335794
+93.219048,98,6.40382
+89.781542,122,6.488462
+89.782177,122,6.573397
+89.777733,122,6.658286
+93.185252,103,6.73019
+93.182929,97,6.797729
+93.113943,96,6.864768
+89.533839,124,6.9512
+93.194117,99,7.020988
+89.177635,126,7.108326
+89.107466,127,7.196391
+93.085148,103,7.267827
+92.916511,105,7.340834
+88.98702,128,7.429554
+93.1925,102,7.50033
+89.899999,121,7.584371
+93.127303,98,7.652373
+93.208642,104,7.72451
+89.002738,128,7.813232
+93.189193,104,7.885385
+88.741387,130,7.97542
+89.93632,121,8.059277
+91.029743,115,8.138972
+92.21992,109,8.214536
+89.764298,122,8.299145
+93.113999,96,8.365746
+93.168898,103,8.437198
+93.169668,103,8.509602
+88.852565,129,8.599525
+93.19289,103,8.671314
+88.702688,130,8.761684
+90.096848,120,8.845084
+93.110618,103,8.916838
+93.025518,104,8.988879
+91.864208,111,9.065803
+93.194416,104,9.137887
+90.487003,118,9.219725
+92.309221,109,9.295581
+89.795107,122,9.380244
+89.095942,127,9.468391
+89.369037,125,9.55494
+90.125949,120,9.638249
+93.172738,103,9.709915
+93.145475,99,9.77876
+90.83815,116,9.859453
+93.21489,101,9.929499
+92.092596,110,10.007234
+89.644357,123,10.092857
+88.840014,129,10.182201
+89.791642,122,10.266649
+93.124155,103,10.33812
+93.176343,100,10.407616
+89.465783,124,10.493598
+89.934026,121,10.577511
+93.21387,104,10.649803
+89.529739,124,10.735824
+89.64373,123,10.821223
+93.175665,96,10.88802
+88.988982,128,10.976886
+92.601194,107,11.051019
+92.896729,105,11.123809
+93.18996,103,11.195364
+89.340725,125,11.282209
+93.219392,98,11.350248
+89.629466,123,11.435549
+92.396533,108,11.511843
+93.214856,102,11.582583
+93.228063,103,11.653995
+90.320025,119,11.736438
+90.09917,120,11.819752
+93.219587,102,11.890496
+89.350777,125,11.97716
+93.213126,103,12.048732
+93.169643,96,12.115289
+89.61667,123,12.200474
+93.107148,103,12.271923
+89.523964,124,12.357987
+88.620873,131,12.44872
+89.91964,121,12.532513
+88.721763,130,12.622581
+92.196586,109,12.698182
+88.636467,131,12.789118
+89.107866,127,12.877657
+93.201758,103,12.949092
+88.962481,128,13.03885
+89.350487,125,13.125727
+89.491968,124,13.211623
+92.608824,107,13.285851
+93.173034,103,13.357493
+93.159943,103,13.428843
+93.043704,104,13.501012
+89.452637,124,13.586959
+88.882335,129,13.676363
+88.848431,129,13.765942
+90.910626,116,13.846459
+93.225353,100,13.915782
+93.226799,101,13.985793
+93.180982,99,14.054471
+92.90149,105,14.127404
+89.75725,122,14.211995
+89.312771,125,14.298653
+89.477136,124,14.384778
+90.083796,120,14.468117
+89.922336,121,14.553396
+93.193835,104,14.625931
+92.894112,105,14.698986
+93.088068,103,14.770371
+89.601074,123,14.855972
+88.971711,128,14.944741
+93.201646,96,15.011277
+90.674297,117,15.092951
+93.153273,100,15.162291
+88.623838,131,15.253053
+93.190221,103,15.324453
+90.478143,118,15.406408
+93.207232,101,15.476441
+93.190314,103,15.547894
+89.612589,123,15.633731
+89.954368,121,15.717603
+89.129356,127,15.806335
+93.113853,96,15.872919
+93.207797,103,15.944339
+93.25996,103,16.015861
+92.066858,110,16.093505
+89.924772,121,16.177347
+93.158804,98,16.245262
+89.670423,123,16.330824
+89.930267,121,16.414818
+93.112147,103,16.486212
+93.162887,104,16.558339
+89.788241,122,16.643117
+89.057092,127,16.73131
+89.387116,125,16.817916
+90.110195,120,16.901234
+91.070569,115,16.980884
+89.111102,127,17.068882
+92.685835,106,17.142292
+93.19072,103,17.213887
+93.160701,97,17.2811
+89.642763,123,17.366416
+93.162671,103,17.438048
+89.523522,124,17.524099
+89.32233,125,17.611737
+92.276643,109,17.68733
+89.808876,122,17.772177
+88.634222,131,17.86334
+93.106995,97,17.930629
+93.204532,98,17.998586
+89.460042,124,18.084498
+93.206819,102,18.15522
+93.242963,101,18.225204
+92.985313,104,18.297466
+88.856539,129,18.387129
+93.11965,103,18.458685
+93.192493,96,18.525193
+92.689761,106,18.598905
+88.760828,130,18.689154
+92.129061,110,18.765424
+89.670486,123,18.850925
+93.198623,96,18.917644
+89.523987,124,19.003566
+93.151354,100,19.074666
+93.131613,99,19.143501
+93.008502,104,19.215617
+93.198418,103,19.287133
+93.225091,102,19.357944
+93.125307,97,19.425234
+92.75004,106,19.498718
+91.132038,115,19.578646
+89.92194,121,19.662881
+92.856527,105,19.735739
+93.194141,104,19.807772
+93.173787,99,19.87659
+93.202117,103,19.947971
+89.633954,123,20.033147
+92.538812,107,20.107291
+89.666701,123,20.192489
+88.841411,129,20.281841
+88.75991,130,20.372189
+88.748129,130,20.462423
+93.173209,99,20.531058
+89.058857,127,20.620422
+93.012135,104,20.692664
+92.86422,105,20.76547
+93.233717,101,20.835612
+92.290489,109,20.911277
+90.078759,120,20.994437
+92.964467,104,21.06681
+93.222286,99,21.135419
+89.402433,125,21.222019
+89.636444,123,21.307191
+92.278216,109,21.382867
+89.516076,124,21.468837
+93.16054,103,21.540472
+93.153414,98,21.608481
+89.114463,127,21.69649
+93.15217,103,21.768385
+92.064927,110,21.844795
+89.246397,126,21.932109
+93.09074,103,22.003629
+90.070595,120,22.088371
+93.253679,104,22.16074
+90.128886,120,22.244342
+92.918405,105,22.317087
+93.160509,103,22.388608
+93.125584,103,22.460013
+89.10934,127,22.547946
+88.755254,130,22.637953
+92.762401,106,22.711475
+88.991092,128,22.800161
+93.155405,100,22.869831
+89.802453,122,22.954444
+92.975223,104,23.026471
+93.168904,104,23.098666
+93.193803,104,23.17078
+93.162262,100,23.240085
+93.168439,97,23.307513
+89.916092,121,23.391437
+93.134174,97,23.458661
+89.360341,125,23.545413
+89.647433,123,23.632222
+89.619042,123,23.717508
+93.215812,97,23.784776
+93.253447,104,23.857316
+89.373429,125,23.944042
+93.223484,101,24.014015
+93.186859,99,24.082722
+92.272831,109,24.158974
+89.357262,125,24.245987
+92.85714,105,24.318738
+89.209741,126,24.406289
+89.629,123,24.491523
+93.121792,97,24.558794
+89.503047,124,24.644727
+89.747775,122,24.729302
+90.758879,117,24.810816
+89.786,122,24.89551
+88.975371,128,24.984213
+93.155645,103,25.055868
+93.206052,98,25.124781
+91.965353,111,25.201982
+89.350126,125,25.288644
+93.242769,102,25.359886
+89.063908,127,25.4481
+90.760517,117,25.529359
+93.198117,100,25.599029
+93.202919,96,25.665623
+93.199189,100,25.7349
+93.189083,96,25.801679
+92.917035,105,25.87483
+89.196708,126,25.962184
+93.167242,103,26.033517
+92.022104,110,26.10972
+93.168998,96,26.176264
+89.487654,124,26.262158
+88.748209,130,26.352381
+89.500553,124,26.438342
+93.110556,96,26.505068
+89.232873,126,26.592495
+93.221687,98,26.66114
+93.225378,103,26.732705
+93.154196,97,26.800051
+93.146876,103,26.872393
+89.263686,126,26.959697
+89.389665,125,27.046299
+90.113798,120,27.129411
+90.132993,120,27.212542
+88.884437,129,27.301854
+89.17863,126,27.389226
+88.935867,128,27.47791
+89.455394,124,27.5641
+93.223059,98,27.632195
+88.627508,131,27.722929
+93.186658,103,27.79449
+93.210087,103,27.866216
+92.467354,108,27.941114
+93.180657,103,28.012656
+89.910636,121,28.096602
+89.915205,121,28.181555
+93.13226,103,28.253097
+89.512654,124,28.339037
+93.152805,100,28.408373
+88.634751,131,28.499069
+89.786068,122,28.583945
+89.496703,124,28.669853
+93.129118,96,28.736562
+89.616638,123,28.821762
+93.167541,99,28.890456
+93.145435,96,28.957038
+93.208513,98,29.024989
+89.902364,121,29.109222
+93.124172,103,29.180632
+93.160502,103,29.252065
+89.198501,126,29.339431
+88.718295,130,29.429578
+93.144012,98,29.497785
+93.239379,101,29.567852
+90.29559,119,29.651159
+88.935028,128,29.740124
+89.369332,125,29.826837
+90.903542,116,29.907273
+92.867716,105,29.980012
+93.109769,97,30.047182
+89.916482,121,30.131198
+89.947361,121,30.215031
+92.384285,108,30.289849
+93.168302,100,30.359423
+93.195546,103,30.430832
+93.132245,96,30.497771
+93.11245,103,30.569238
+89.644505,123,30.654446
+93.212769,103,30.725943
+93.207232,101,30.795945
+92.592211,107,30.870552
+89.518665,124,30.956553
+93.208976,97,31.023806
+93.167488,98,31.091734
+93.15612,103,31.163891
+93.123346,103,31.235541
+90.96735,116,31.315909
+89.780526,122,31.400698
+92.107803,110,31.477023
+92.854429,105,31.549801
+88.828096,129,31.639183
+93.172403,100,31.708567
+93.213278,103,31.780102
+92.27694,109,31.855751
+89.632061,123,31.940945
+89.494009,124,32.026816
+89.252692,126,32.114596
+93.127196,97,32.181916
+93.165562,96,32.248488
+93.191529,104,32.320741
+93.118865,96,32.387803
+88.706759,130,32.477954
+93.163357,97,32.545473
+90.654345,117,32.626546
+92.824101,105,32.700063
+93.17086,100,32.769766
+89.926913,121,32.854058
+93.172318,101,32.930846
+89.248442,126,33.023205
+93.208802,97,33.091722
+93.163235,103,33.16315
+93.153762,99,33.232074
+89.499331,124,33.318104
+92.996035,104,33.390788
+89.925275,121,33.475083
+93.199561,97,33.542899
+89.773787,122,33.628055
+93.20984,103,33.699708
+93.171011,97,33.766988
+93.133729,103,33.838587
+88.755047,130,33.928791
+89.784374,122,34.013383
+88.964799,128,34.102243
+89.523777,124,34.189713
+93.16526,99,34.258537
+93.103153,103,34.330007
+90.11437,120,34.4134
+93.14627,103,34.484855
+88.993536,128,34.573855
+93.213609,103,34.645761
+93.027042,104,34.718464
+93.194417,104,34.790818
+93.207783,101,34.861243
+92.486259,108,34.936146
+93.168404,96,35.002705
+93.202997,103,35.074059
+89.3334,125,35.16073
+93.19208,104,35.232822
+93.1181,98,35.300769
+89.939804,121,35.384792
+88.937494,128,35.473511
+91.866324,111,35.550756
+93.166404,100,35.620043
+93.114378,96,35.687636
+93.127401,103,35.759539
+92.366841,108,35.834608
+88.994289,128,35.923328
+90.725429,117,36.004556
+93.140582,97,36.071855
+89.663881,123,36.157084
+89.74507,122,36.241658
+93.155232,97,36.308902
+92.30263,109,36.384618
+93.129231,98,36.452696
+88.968342,128,36.541399
+89.796895,122,36.626001
+89.761916,122,36.71057
+93.184929,103,36.781954
+89.233469,126,36.86993
+92.125202,110,36.946249
+93.215979,100,37.015587
+88.979871,128,37.104365
+92.689584,106,37.17804
+93.167541,99,37.247573
+92.750778,106,37.321238
+90.510861,118,37.403147
+93.209345,103,37.47462
+93.150284,96,37.541144
+89.927138,121,37.625348
+93.201393,102,37.696249
+89.25032,126,37.783566
+93.118024,97,37.850921
+90.676998,117,37.931979
+92.997761,104,38.004241
+88.967295,128,38.093081
+93.19607,101,38.163158
+93.143988,98,38.231305
+93.197709,100,38.300707
+89.929647,121,38.384711
+89.944135,121,38.468504
+93.142617,103,38.539954
+90.871517,116,38.62026
+93.052779,104,38.692498
+89.229239,126,38.780882
+93.203135,100,38.850405
+88.621118,131,38.941623
+93.168939,99,39.01034
+93.185535,103,39.081716
+93.231659,102,39.152423
+92.92264,105,39.226221
+93.225727,104,39.298454
+89.359009,125,39.385287
+92.902836,105,39.458488
+93.226793,101,39.528496
+93.144704,97,39.595832
+93.186947,103,39.667278
+89.253536,126,39.754577
+88.74445,130,39.844696
+93.197189,98,39.91293
+93.012592,104,39.985068
+90.112836,120,40.068262
+89.35922,125,40.154878
+93.234469,101,40.225098
+93.115151,96,40.292018
+90.75269,117,40.37327
+90.687021,117,40.454548
+93.140207,96,40.521316
+93.157896,97,40.588768
+89.941241,121,40.672827
+93.191973,96,40.739504
+93.194094,99,40.808182
+88.818852,129,40.897855
+93.176703,102,40.968657
+93.106451,103,41.040064
+89.777162,122,41.124847
+89.761014,122,41.209603
+93.013559,104,41.281681
+93.21945,99,41.350535
+92.753054,106,41.425454
+90.683749,117,41.507115
+89.488384,124,41.593059
+93.175812,97,41.660291
+89.188212,126,41.748793
+93.12555,103,41.82026
+90.268513,119,41.902782
+89.199857,126,41.990211
+88.993695,128,42.07894
+89.383569,125,42.166335
+89.785771,122,42.250844
+89.919912,121,42.33491
+92.95706,104,42.407031
+89.948384,121,42.491248
+88.629588,131,42.582094
+93.151478,100,42.65139
+93.054311,104,42.723538
+89.897591,121,42.807677
+88.621066,131,42.898912
+93.154143,97,42.966727
+93.178637,103,43.038454
+93.089263,103,43.109984
+90.718912,117,43.191035
+92.617755,107,43.2665
+89.606326,123,43.352129
+93.211167,103,43.423663
+89.763147,122,43.508456
+89.248391,126,43.595964
+89.629351,123,43.681384
+93.143575,99,43.750002
+93.085616,103,43.821568
+89.785497,122,43.906287
+89.947794,121,43.99036
+90.118474,120,44.073502
+89.92278,121,44.15729
+89.9201,121,44.241117
+93.183763,99,44.30986
+88.695989,130,44.400143
+89.628294,123,44.485366
+92.5964,107,44.559591
+88.842441,129,44.648939
+89.761056,122,44.733547
+93.165847,99,44.803622
+93.253447,104,44.875814
+93.170444,101,44.945834
+89.259561,126,45.033162
+93.127176,103,45.104685
+89.903958,121,45.188509
+92.918127,105,45.261828
+89.366676,125,45.348496
+89.772839,122,45.433375
+89.481283,124,45.519332
+93.175499,97,45.586553
+93.143193,96,45.653127
+92.631217,107,45.727495
+93.209786,103,45.798879
+92.975968,104,45.871031
+90.136063,120,45.954417
+88.846936,129,46.043762
+93.185796,104,46.115808
+88.981183,128,46.204436
+90.317074,119,46.288286
+88.998623,128,46.377084
+89.353752,125,46.463748
+89.658448,123,46.549034
+89.345372,125,46.635935
+93.11071,97,46.703308
+90.312375,119,46.786139
+93.222617,99,46.854967
+89.776182,122,46.939637
+88.843542,129,47.029286
+93.177785,96,47.095961
+92.852356,105,47.169558
+89.912827,121,47.254726
+93.186609,101,47.324791
+92.591294,107,47.399243
+93.121397,96,47.46606
+93.190033,103,47.537537
+93.199546,96,47.604137
+93.197168,98,47.672517
+93.135044,96,47.739124
+89.321169,125,47.828504
+88.827483,129,47.918107
+91.943017,111,47.995121
+92.76017,106,48.068854
+93.056614,104,48.140908
+89.601065,123,48.226178
+93.243733,101,48.298943
+93.152136,96,48.366921
+92.556485,107,48.442136
+89.520444,124,48.530309
+89.222777,126,48.618885
+93.125304,103,48.691462
+89.324305,125,48.780459
+92.280802,109,48.857474
+93.219473,102,48.92973
+93.153015,98,48.998859
+93.120789,103,49.071363
+93.138183,98,49.141435
+93.228597,98,49.210526
+93.19371,101,49.28223
+89.379544,125,49.371821
+89.497615,124,49.45918
+89.637992,123,49.545283
+93.223862,98,49.614473
+93.1656,98,49.68358
+88.979888,128,49.773801
+93.246854,103,49.846238
+89.102574,127,49.935801
+90.298708,119,50.019345
+88.85704,129,50.109801
+92.834106,105,50.183894
+89.319573,125,50.271451
+93.227107,103,50.344268
+89.472779,124,50.43175
+93.013849,104,50.505059
+89.951172,121,50.589867
+93.052085,104,50.662871
+93.141053,98,50.731702
+89.460239,124,50.819522
+93.205993,103,50.891895
+88.982632,128,50.982142
+91.564982,113,51.061148
+89.956091,121,51.146376
+89.315634,125,51.234005
+93.227301,99,51.303588
+92.289065,109,51.380502
+92.075522,110,51.458128
+92.763744,106,51.532413
+93.092027,103,51.60506
+93.148699,97,51.673021
+90.738986,117,51.755298
+89.786,122,51.841087
+90.916311,116,51.922384
+92.74814,106,51.996851
+93.122982,96,52.063997
+93.135399,97,52.132061
+89.928553,121,52.216866
+93.188208,97,52.285007
+89.340955,125,52.373165
+93.18967,103,52.445853
+89.913332,121,52.530457
+88.94619,128,52.620269
+89.224781,126,52.708595
+93.169615,101,52.779431
+89.600619,123,52.865946
+93.085288,103,52.937934
+93.044662,104,53.01129
+93.156381,98,53.079892
+88.639904,131,53.171839
+89.120099,127,53.260688
+92.127328,110,53.337501
+93.148752,98,53.406463
+90.940909,116,53.487679
+93.053845,104,53.560072
+92.680192,106,53.633967
+93.159634,96,53.700949
+89.636492,123,53.786931
+93.146944,103,53.859663
+93.124236,103,53.931946
+93.153876,101,54.002323
+89.249385,126,54.092658
+89.928525,121,54.177735
+89.610838,123,54.265309
+89.920205,121,54.349269
+93.162726,103,54.42082
+93.209508,103,54.492306
+93.149278,103,54.563904
+89.121806,127,54.652085
+93.238805,100,54.722352
+93.168054,96,54.788999
+89.621739,123,54.874501
+89.766015,122,54.959018
+89.3334,125,55.045774
+89.793834,122,55.130327
+93.205812,97,55.197545
+89.345921,125,55.284811
+89.107265,127,55.374524
+88.624113,131,55.465456
+90.302261,119,55.548009
+93.225227,101,55.618049
+93.129502,97,55.685413
+93.238881,103,55.756884
+93.174922,97,55.824229
+93.229078,102,55.895712
+93.223212,104,55.967892
+93.185609,101,56.038198
+90.465933,118,56.120048
+93.183188,100,56.189414
+93.138151,99,56.258034
+89.252814,126,56.345441
+93.117953,97,56.413848
+93.210636,99,56.483658
+89.759015,122,56.568497
+93.13718,103,56.639938
+93.089389,103,56.711398
+90.106928,120,56.794839
+92.278123,109,56.871873
+89.258735,126,56.959299
+93.184611,101,57.02932
+90.090583,120,57.113042
+89.518188,124,57.199065
+93.111143,96,57.265657
+93.24386,100,57.335033
+93.164215,103,57.406752
+93.153225,103,57.478143
+89.647018,123,57.563354
+89.212664,126,57.650669
+93.193619,104,57.723444
+88.630405,131,57.814351
+93.159632,99,57.883154
+92.969876,104,57.955209
+93.257775,104,58.027567
+93.204367,103,58.099015
+93.176598,103,58.170403
+88.980221,128,58.259057
+93.185682,103,58.331519
+93.153443,97,58.400393
+93.190813,100,58.469758
+93.141555,103,58.541677
+90.111326,120,58.624851
+92.891634,105,58.698087
+90.9731,116,58.778599
+93.178179,98,58.847127
+91.956352,111,58.925008
+92.581149,107,58.99934
+89.372052,125,59.086059
+93.041503,104,59.158349
+89.785216,122,59.243069
+89.488451,124,59.329066
+89.897591,121,59.41295
+92.828202,105,59.486279
+88.884087,129,59.575779
+90.849976,116,59.65658
+88.841235,129,59.746313
+88.873501,129,59.835773
+90.641497,117,59.91855
+93.165524,99,59.988173
+91.761282,112,60.065935
+89.351063,125,60.153204
+89.464234,124,60.239187
+88.697457,130,60.329204
+88.728205,130,60.419264
+89.927493,121,60.503114
+93.206606,97,60.570374
+93.185289,100,60.639809
+93.191592,103,60.712137
+89.20037,126,60.799514
+93.190872,101,60.869609
+89.903345,121,60.953606
+93.131978,97,61.021408
+93.118777,96,61.088077
+88.847335,129,61.177588
+93.033803,104,61.250071
+89.464365,124,61.336578
+93.180778,99,61.406906
+93.161971,98,61.474979
+89.109976,127,61.563004
+89.234333,126,61.650808
+93.172672,98,61.719018
+88.817255,129,61.808419
+89.255953,126,61.89651
+93.16292,98,61.964486
+89.926827,121,62.048369
+93.162855,103,62.119741
+89.928657,121,62.203618
+92.71856,106,62.277082
+92.894733,105,62.349883
+92.588499,107,62.424174
+89.505686,124,62.510524
+88.844126,129,62.600293
+88.616758,131,62.691345
+89.913655,121,62.775189
+89.324195,125,62.862217
+90.855537,116,62.94439
+89.898845,121,63.028467
+93.168941,103,63.100612
+93.170321,98,63.168791
+93.173278,99,63.237434
+92.277452,109,63.312944
+88.936667,128,63.401632
+93.153032,103,63.47322
+93.209119,103,63.544656
+89.224854,126,63.631994
+88.876408,129,63.722192
+89.620409,123,63.807993
+89.930247,121,63.891977
+93.220525,101,63.961983
+89.466729,124,64.048094
+93.208772,97,64.115474
+89.78016,122,64.201033
+89.124359,127,64.289404
+93.256853,104,64.362192
+93.240526,103,64.435313
+93.192691,96,64.50199
+93.14142,103,64.573438
+91.046669,115,64.653547
+93.026223,104,64.725852
+93.139495,99,64.794516
+90.093082,120,64.877918
+93.170354,104,64.950914
+93.174745,103,65.02263
+89.079809,127,65.110618
+92.302363,109,65.186149
+92.685547,106,65.259686
+88.753068,130,65.349779
+89.229446,126,65.437122
+93.201758,103,65.508638
+93.169835,103,65.580437
+93.21127,98,65.648366
+92.703945,106,65.722057
+92.052723,110,65.798472
+89.645036,123,65.884053
+93.242601,103,65.956467
+93.166618,96,66.023062
+93.174006,100,66.093184
+93.202306,99,66.161842
+89.796869,122,66.246539
+93.196911,97,66.313801
+93.167182,103,66.385406
+92.115416,110,66.461798
+89.31029,125,66.548394
+93.119352,103,66.619787
+93.216244,98,66.687828
+89.397224,125,66.774605
+93.119545,103,66.846382
+92.900316,105,66.919612
+89.624073,123,67.004971
+89.49597,124,67.090888
+89.936968,121,67.174706
+93.136101,97,67.241969
+93.145402,96,67.308531
+89.926837,121,67.392574
+89.323361,125,67.480177
+89.791346,122,67.564778
+89.003521,128,67.653583
+92.638118,107,67.727788
+93.165807,103,67.799167
+93.213154,103,67.870732
+93.185303,103,67.942335
+89.517466,124,68.028608
+93.03569,104,68.100771
+90.104021,120,68.184289
+93.220299,102,68.254989
+89.904513,121,68.338781
+89.250115,126,68.42628
+92.892057,105,68.499012
+92.310724,109,68.574544
+93.260065,104,68.646877
+93.1623,103,68.718292
+90.530494,118,68.800266
+89.784091,122,68.884984
+89.766282,122,68.970518
+90.974708,116,69.05096
+92.465063,108,69.125746
+93.088382,103,69.197949
+89.526902,124,69.284006
+93.119194,103,69.355394
+92.736793,106,69.428912
+93.175418,99,69.497591
+93.199603,103,69.569034
+92.729968,106,69.64257
+93.175916,103,69.714298
+89.06635,127,69.802424
+93.148365,100,69.87172
+89.63516,123,69.957006
+89.095529,127,70.044955
+93.153171,103,70.116961
+93.179808,103,70.188952
+93.170827,103,70.260363
+89.785335,122,70.344972
+93.19421,103,70.416641
+88.996297,128,70.507539
+93.173689,102,70.5786
+93.208868,103,70.650165
+88.934949,128,70.738872
+89.904938,121,70.822806
+89.386822,125,70.90975
+93.137042,97,70.976971
+90.32513,119,71.059464
+89.645893,123,71.145229
+89.192003,126,71.232438
+89.901771,121,71.316207
+89.783699,122,71.400936
+93.133994,96,71.467522
+89.217632,126,71.555593
+92.902842,105,71.628335
+89.622455,123,71.713611
+89.79816,122,71.798161
+93.196793,97,71.865431
+89.935996,121,71.949824
+89.671654,123,72.036268
+92.721186,106,72.109731
+89.111197,127,72.197746
+88.936655,128,72.286493
+89.525556,124,72.372501
+93.108443,97,72.440234
+93.168253,100,72.509655
+89.626806,123,72.595053
+90.316942,119,72.67776
+89.66335,123,72.763368
+93.043299,104,72.836236
+93.238793,100,72.905641
+91.669079,112,72.983637
+89.794558,122,73.068152
+89.799978,122,73.152812
+90.293333,119,73.235365
+89.765886,122,73.319812
+89.938376,121,73.403628
+93.108047,97,73.471665
+93.126032,103,73.543578
+89.097076,127,73.631557
+89.326869,125,73.718272
+90.527638,118,73.800032
+89.368039,125,73.886836
+91.521559,113,73.965156
+91.742946,112,74.04279
+89.372316,125,74.129733
+92.902882,105,74.202828
+93.037962,104,74.275032
+93.22387,100,74.344827
+92.818768,105,74.417674
+92.999126,104,74.489882
+89.325082,125,74.576533
+93.206483,103,74.64801
+93.201267,99,74.716963
+88.741867,130,74.807455
+93.125454,103,74.87907
+88.729311,130,74.971007
+90.30481,119,75.054453
+93.201946,96,75.121042
+92.625267,107,75.195196
+88.848425,129,75.28473
+93.156069,98,75.352745
+89.481972,124,75.438712
+93.153056,103,75.510172
+91.313341,114,75.589355
+93.125562,97,75.656726
+93.180688,103,75.728171
+89.357397,125,75.815309
+89.001509,128,75.904256
+92.850497,105,75.977141
+93.16963,100,76.046458
+93.15178,101,76.116592
+93.194917,96,76.183227
+93.167975,102,76.253978
+93.153141,97,76.321224
+91.768443,112,76.399175
+93.178956,103,76.470708
+88.714545,130,76.562491
+93.203786,103,76.63395
+90.084093,120,76.717079
+88.761874,130,76.80734
+93.181487,103,76.878703
+89.782177,122,76.96331
+88.981619,128,77.051973
+93.202422,102,77.122611
+93.181872,103,77.19395
+93.191335,97,77.261143
+90.438809,118,77.34304
+93.160755,103,77.414633
+93.125652,103,77.486283
+89.915219,121,77.570083
+89.059606,127,77.658573
+93.195899,103,77.730296
+88.756015,130,77.820449
+92.282541,109,77.896095
+89.670907,123,77.981625
+92.36919,108,78.057579
+90.130027,120,78.141079
+91.540973,113,78.21956
+93.171519,102,78.29026
+91.142865,115,78.369956
+93.148635,97,78.437427
+93.196236,103,78.509104
+93.207615,103,78.581518
+93.138783,96,78.64813
+89.534245,124,78.734275
+93.20022,97,78.801528
+88.874818,129,78.891303
+92.227428,109,78.967328
+90.515465,118,79.049043
+89.114129,127,79.136993
+89.652359,123,79.222402
+93.091442,103,79.293814
+93.157742,97,79.361088
+89.609755,123,79.446616
+89.608427,123,79.53272
+93.137938,96,79.599425
+90.117153,120,79.682692
+93.159527,97,79.749966
+90.870113,116,79.830463
+88.76285,130,79.92064
+93.20217,103,79.992041
+89.11846,127,80.080197
+92.110418,110,80.156546
+90.094376,120,80.239767
+93.154491,102,80.310447
+93.171697,97,80.377777
+93.174483,96,80.444614
+93.205037,102,80.515415
+93.178626,98,80.583456
+93.22334,103,80.655239
+93.17101,97,80.722522
+92.97526,104,80.794599
+93.131802,103,80.866057
+88.730404,130,80.956301
+93.193534,103,81.028696
+89.126083,127,81.11715
+90.120245,120,81.200412
+90.849409,116,81.280774
+89.90991,121,81.364822
+93.175579,98,81.432969
+93.113936,97,81.500236
+92.140818,110,81.576479
+90.265286,119,81.658954
+92.563074,107,81.733215
+90.439933,118,81.814929
+88.631962,131,81.905742
+93.054351,104,81.97786
+88.762079,130,82.068028
+88.879246,129,82.157396
+92.689315,106,82.230889
+93.19242,97,82.298214
+93.010691,104,82.370269
+90.865029,116,82.450739
+89.313815,125,82.538231
+93.114758,103,82.610257
+90.095302,120,82.693957
+91.091859,115,82.773816
+93.096671,103,82.845246
+90.669573,117,82.927001
+89.118716,127,83.014992
+89.186415,126,83.102641
+92.765807,106,83.176044
+93.161325,103,83.247456
+93.235664,101,83.317414
+93.12112,103,83.389264
+93.142453,96,83.455893
+93.252322,102,83.526634
+89.177865,126,83.613927
+92.557538,107,83.688504
+93.237011,101,83.758949
+89.126008,127,83.846957
+93.13948,99,83.915902
+92.86566,105,83.988813
+92.018233,110,84.066455
+89.595287,123,84.151766
+91.963171,111,84.228827
+89.638505,123,84.31399
+89.107044,127,84.402097
+93.18787,103,84.473541
+93.125346,97,84.540911
+93.202199,103,84.612561
+89.31941,125,84.699226
+93.203888,98,84.767159
+93.207763,97,84.83464
+88.708089,130,84.924937
+89.207923,126,85.012314
+93.213003,103,85.084081
+88.701055,130,85.17514
+93.141613,98,85.243126
+93.116478,96,85.309857
+89.124396,127,85.397951
+93.210305,103,85.469415
+89.807634,122,85.554104
+92.400846,108,85.630467
+88.636401,131,85.721369
+93.19395,99,85.790329
+93.12237,103,85.862094
+89.943097,121,85.946465
+90.118342,120,86.030022
+92.913675,105,86.102823
+90.654229,117,86.184106
+89.914921,121,86.267894
+91.098943,115,86.347583
+88.936027,128,86.436484
+93.187367,102,86.507133
+89.249832,126,86.594477
+89.108234,127,86.682587
+92.916876,105,86.75541
+93.172672,98,86.823661
+90.512931,118,86.906229
+93.139919,100,86.975729
+92.873565,105,87.048832
+93.176092,103,87.12138
+93.167602,98,87.189672
+89.799842,122,87.274525
+92.781443,106,87.348084
+93.16958,103,87.419699
+89.390136,125,87.506336
+93.19128,104,87.578401
+89.244997,126,87.6659
+89.631206,123,87.751139
+92.842291,105,87.824018
+91.768443,112,87.901861
+92.917788,105,87.974781
+93.190767,103,88.046171
+93.211766,97,88.113547
+88.702936,130,88.203895
+89.481083,124,88.290087
+90.636246,117,88.371137
+93.2357,99,88.439769
+93.225886,104,88.512164
+93.171707,104,88.585734
+93.153452,99,88.654665
+89.483705,124,88.740605
+89.329511,125,88.827997
+93.263651,104,88.9003
+89.916457,121,88.984288
+93.154281,97,89.051493
+88.998193,128,89.140115
+93.197809,97,89.207337
+88.695288,130,89.297436
+89.232276,126,89.384784
+92.684492,106,89.45827
+93.092551,103,89.529663
+93.158403,97,89.597071
+90.110284,120,89.680509
+92.417685,108,89.755529
+89.50296,124,89.841577
+93.2217,98,89.909513
+88.983626,128,89.998706
+93.153941,97,90.066056
+89.114313,127,90.155123
+89.601298,123,90.240512
+89.256132,126,90.327803
+92.143612,110,90.404044
+89.792557,122,90.48878
+92.889482,105,90.561783
+92.53248,107,90.635876
+89.902364,121,90.719736
+89.463908,124,90.805644
+88.991652,128,90.894686
+93.1882,96,90.961793
+89.939802,121,91.045545
+92.829138,105,91.11826
+88.755194,130,91.208295
+89.7717,122,91.292886
+93.148861,96,91.35976
+91.029147,115,91.439586
+92.038456,110,91.515853
+90.265573,119,91.598734
+93.178352,103,91.671412
+93.123294,103,91.742921
+90.68715,117,91.824028
+93.250189,101,91.894095
+89.355171,125,91.981058
+89.320319,125,92.068037
+89.947445,121,92.151978
+89.397556,125,92.238627
+93.137736,97,92.305911
+93.185176,103,92.377325
+90.246564,119,92.459783
+90.284355,119,92.542341
+89.79206,122,92.627175
+88.955518,128,92.715869
+88.975009,128,92.804979
+89.48926,124,92.891282
+88.951405,128,92.980399
+90.465366,118,93.062204
+93.109509,96,93.130365
+93.181941,102,93.201146
+89.186386,126,93.288865
+93.121383,98,93.356848
+93.190716,103,93.428388
+93.116138,97,93.495687
+92.271016,109,93.571246
+93.211711,103,93.643106
+89.932909,121,93.727079
+93.178976,103,93.798545
+92.376469,108,93.873444
+92.421163,108,93.948519
+88.831319,129,94.038138
+88.99594,128,94.126996
+91.142011,115,94.206826
+93.22592,99,94.275533
+93.149295,99,94.344221
+93.175543,97,94.41165
+89.350786,125,94.498443
+93.123879,103,94.569937
+91.970749,111,94.648223
+92.834187,105,94.721867
+93.10462,97,94.789438
+89.249036,126,94.877339
+93.20912,103,94.948895
+90.309984,119,95.031687
+89.509837,124,95.117581
+93.174356,100,95.18692
+93.086235,103,95.258347
+88.74182,130,95.348523
+90.263482,119,95.431169
+93.171946,99,95.500041
+93.199588,96,95.566647
+93.255415,102,95.637417
+93.136135,103,95.708988
+93.238499,103,95.780708
+92.597358,107,95.855417
+89.631791,123,95.941022
+93.172587,103,96.012683
+93.209297,103,96.084104
+88.760082,130,96.175583
+93.178501,98,96.243942
+89.101273,127,96.332009
+93.182366,99,96.400661
+90.515893,118,96.482551
+93.089371,103,96.553957
+91.89351,111,96.630932
+93.15571,100,96.700363
+93.153888,98,96.768378
+89.785497,122,96.853261
+92.99594,104,96.925722
+88.717487,130,97.016156
+89.939204,121,97.100046
+89.934537,121,97.183904
+89.191932,126,97.27123
+93.210774,102,97.341972
+88.967505,128,97.43071
+90.082126,120,97.513875
+92.894327,105,97.586647
+88.732035,130,97.678046
+93.091029,103,97.749545
+89.244402,126,97.837037
+93.131919,98,97.905
+93.228036,103,97.976631
+91.534719,113,98.0553
+90.331292,119,98.13774
+90.081655,120,98.220867
+93.222289,101,98.290964
+93.213757,103,98.36263
+93.189369,103,98.434448
+93.203435,103,98.506011
+92.37832,108,98.581209
+93.205048,97,98.648436
+93.188135,104,98.720777
+93.215186,99,98.789453
+93.089926,103,98.860995
+93.236442,100,98.930472
+88.82333,129,99.019885
+93.030509,104,99.092
+92.818484,105,99.166039
+90.076409,120,99.249606
+90.081701,120,99.332954
+90.087343,120,99.416191
+92.726764,106,99.489869
+89.113876,127,99.577914
+89.597351,123,99.663222
+93.206144,103,99.735183
+93.128215,103,99.806684
+92.292302,109,99.882263
+93.14296,96,99.948892
+89.114328,127,100.036944
+88.620912,131,100.127707
+90.690956,117,100.208924
+88.937332,128,100.297754
+93.181685,104,100.370342
+93.039621,104,100.4425
+92.678212,106,100.516052
+93.116304,103,100.587449
+89.348768,125,100.674873
+89.512717,124,100.761595
+90.946707,116,100.841982
+93.150607,99,100.910701
+92.603745,107,100.98494
+89.111125,127,101.073551
+93.00211,104,101.145632
+89.184911,126,101.23296
+89.910641,121,101.316875
+89.743,122,101.40173
+93.134007,103,101.473332
+93.008281,104,101.545467
+93.209158,103,101.616877
+93.206328,98,101.685006
+89.515352,124,101.771208
+93.150985,101,101.841605
+89.115515,127,101.929739
+93.168336,100,101.999038
+93.210424,103,102.070834
+93.171406,96,102.138115
+89.467108,124,102.225538
+92.981193,104,102.297666
+93.141328,97,102.365139
+93.22965,100,102.434683
+89.926522,121,102.518587
+90.497214,118,102.600381
+93.23257,100,102.669826
+92.396608,108,102.744774
+93.190648,103,102.816445
+93.210019,103,102.888306
+93.124251,103,102.959941
+93.209679,103,103.031445
+89.906826,121,103.115343
+92.779204,106,103.188808
+93.235075,102,103.259526
+93.150835,97,103.326885
+92.75709,106,103.400465
+90.110252,120,103.483732
+93.179731,99,103.552426
+92.865523,105,103.625242
+93.200697,96,103.69212
+93.178257,102,103.764354
+88.735652,130,103.854799
+93.230682,104,103.927037
+93.160335,103,103.999138
+93.202513,103,104.070577
+89.960259,121,104.154469
+90.134254,120,104.237638
+92.385617,108,104.312606
+93.150526,100,104.381945
+93.183856,99,104.450664
+88.759067,130,104.540955
+88.822921,129,104.630307
+89.05238,127,104.718322
+89.798559,122,104.802872
+89.782447,122,104.887698
+92.384231,108,104.962691
+89.069718,127,105.051152
+93.203272,98,105.119108
+92.207674,109,105.194847
+90.250796,119,105.278846
+93.201308,96,105.345448
+93.132576,97,105.412721
+93.190848,103,105.484147
+89.781804,122,105.56881
+93.116875,103,105.640399
+93.150437,100,105.710082
+89.369762,125,105.796787
+92.987398,104,105.868939
+92.580634,107,105.943214
+93.153979,102,106.014348
+89.894977,121,106.098143
+88.697905,130,106.188554
+93.243321,104,106.260686
+92.750388,106,106.334211
+92.404785,108,106.409182
+92.856459,105,106.482111
+92.989282,104,106.554266
diff --git a/logs/benchmark/ppo_lstm-PendulumNoVel-v1/0.monitor.csv b/logs/benchmark/ppo_lstm-PendulumNoVel-v1/0.monitor.csv
new file mode 100644
index 000000000..55fff6aaa
--- /dev/null
+++ b/logs/benchmark/ppo_lstm-PendulumNoVel-v1/0.monitor.csv
@@ -0,0 +1,752 @@
+#{"t_start": 1654204588.0577092, "env_id": "PendulumNoVel-v1"}
+r,l,t
+-402.519199,200,2.480222
+-130.254367,200,2.630337
+-253.835402,200,2.780386
+-132.850895,200,2.930846
+-389.727931,200,3.080838
+-536.415855,200,3.230758
+-0.923932,200,3.380872
+-418.149336,200,3.530867
+-1.562882,200,3.681829
+-402.953016,200,3.831828
+-132.458464,200,3.981687
+-134.372756,200,4.131569
+-386.913341,200,4.281405
+-138.890803,200,4.431183
+-0.650854,200,4.581138
+-129.375741,200,4.730886
+-130.248166,200,4.880838
+-131.582319,200,5.030658
+-267.739547,200,5.181282
+-0.580612,200,5.33136
+-136.334798,200,5.481234
+-495.272764,200,5.631139
+-484.102178,200,5.781667
+-390.524931,200,5.932026
+-130.325168,200,6.082158
+-249.204415,200,6.232241
+-264.347824,200,6.382366
+-129.584069,200,6.532424
+-126.615014,200,6.683844
+-492.988763,200,6.834109
+-393.692814,200,6.984055
+-251.805998,200,7.134065
+-2.07597,200,7.283871
+-126.728964,200,7.433803
+-392.746958,200,7.583708
+-257.422931,200,7.733629
+-507.003874,200,7.884361
+-133.24181,200,8.034442
+-263.233057,200,8.185913
+-124.297082,200,8.337559
+-398.661795,200,8.487754
+-504.372413,200,8.637957
+-128.990286,200,8.787918
+-1.080866,200,8.937824
+-132.306424,200,9.087769
+-129.747648,200,9.237722
+-256.345754,200,9.388183
+-127.788099,200,9.538201
+-135.844131,200,9.68924
+-1.512856,200,9.840115
+-383.408297,200,9.990105
+-128.013623,200,10.140147
+-257.021361,200,10.290052
+-132.973206,200,10.440223
+-136.038808,200,10.590082
+-255.61615,200,10.739967
+-130.378137,200,10.890256
+-262.911814,200,11.04023
+-129.472033,200,11.191393
+-123.746214,200,11.341515
+-131.366346,200,11.49152
+-128.73705,200,11.641451
+-121.811967,200,11.791374
+-131.371365,200,11.941349
+-256.614339,200,12.091842
+-258.178329,200,12.241704
+-2.546309,200,12.391757
+-253.503013,200,12.541677
+-249.908471,200,12.692321
+-259.233888,200,12.842389
+-129.391829,200,12.992356
+-131.588284,200,13.142358
+-377.182906,200,13.292416
+-2.808374,200,13.442636
+-124.696563,200,13.592626
+-2.283769,200,13.74265
+-252.46194,200,13.892748
+-260.613869,200,14.042629
+-127.715141,200,14.193771
+-130.87715,200,14.344025
+-266.165831,200,14.494216
+-136.758933,200,14.644304
+-521.274432,200,14.794543
+-134.91778,200,14.944999
+-132.464344,200,15.09512
+-259.136921,200,15.244979
+-523.922728,200,15.394954
+-126.623535,200,15.54527
+-134.4309,200,15.696621
+-128.162222,200,15.846931
+-128.474804,200,15.997117
+-128.321519,200,16.147037
+-256.767207,200,16.296895
+-131.03306,200,16.447109
+-1.755037,200,16.597155
+-412.540829,200,16.747057
+-131.401396,200,16.897024
+-256.84293,200,17.047659
+-128.618866,200,17.198305
+-549.44501,200,17.349118
+-262.431591,200,17.499175
+-125.635456,200,17.649138
+-393.707247,200,17.799148
+-506.713824,200,17.949326
+-516.686094,200,18.09973
+-135.451174,200,18.249698
+-136.502389,200,18.399829
+-260.013239,200,18.549896
+-127.30565,200,18.700019
+-455.226846,200,18.850885
+-132.617079,200,19.001096
+-130.89826,200,19.150959
+-258.652825,200,19.300925
+-129.663716,200,19.450789
+-125.848011,200,19.600976
+-502.032173,200,19.750824
+-394.359487,200,19.901413
+-129.928044,200,20.051476
+-126.740772,200,20.201448
+-133.688292,200,20.352506
+-393.770809,200,20.502646
+-129.374495,200,20.652508
+-262.399361,200,20.802411
+-271.81621,200,20.952535
+-129.929063,200,21.102612
+-255.784697,200,21.252395
+-134.037024,200,21.403085
+-132.72468,200,21.552927
+-129.344628,200,21.702864
+-242.856234,200,21.853705
+-1.196476,200,22.003769
+-390.246182,200,22.1539
+-129.311706,200,22.304043
+-123.947636,200,22.454089
+-409.102118,200,22.604069
+-424.501649,200,22.754085
+-128.488516,200,22.904248
+-394.249286,200,23.054295
+-262.855733,200,23.204241
+-389.65392,200,23.355616
+-258.634432,200,23.505915
+-247.377378,200,23.656505
+-437.781626,200,23.806754
+-129.533367,200,23.956941
+-127.797557,200,24.107264
+-127.703761,200,24.258819
+-134.206551,200,24.409025
+-130.898959,200,24.559296
+-136.946955,200,24.70922
+-540.497969,200,24.860151
+-510.421549,200,25.010261
+-252.524736,200,25.160258
+-132.670144,200,25.310229
+-253.607705,200,25.460401
+-266.376325,200,25.610475
+-258.740894,200,25.760541
+-121.987025,200,25.910583
+-131.414354,200,26.060641
+-256.732595,200,26.210585
+-259.629018,200,26.361656
+-243.632808,200,26.512006
+-263.768034,200,26.662218
+-520.935392,200,26.812245
+-130.199743,200,26.962329
+-131.462324,200,27.112274
+-261.533195,200,27.262167
+-264.952242,200,27.4122
+-499.004986,200,27.562604
+-265.635419,200,27.712666
+-134.417145,200,27.863875
+-406.186914,200,28.014071
+-261.730142,200,28.164416
+-131.56676,200,28.314623
+-530.232442,200,28.465021
+-127.203444,200,28.615008
+-380.986625,200,28.765011
+-133.808886,200,28.915337
+-468.679318,200,29.065327
+-260.520473,200,29.215312
+-385.850028,200,29.366877
+-386.112575,200,29.516918
+-133.899244,200,29.66715
+-127.112062,200,29.817209
+-266.173676,200,29.967375
+-392.787095,200,30.117482
+-130.038843,200,30.268342
+-133.178376,200,30.418429
+-1.518369,200,30.568381
+-385.404969,200,30.718568
+-0.661449,200,30.869841
+-255.147382,200,31.020269
+-590.537446,200,31.170284
+-123.028762,200,31.320235
+-532.179323,200,31.470438
+-132.791688,200,31.620412
+-390.66596,200,31.770427
+-129.960529,200,31.920456
+-396.281114,200,32.071321
+-395.102186,200,32.222039
+-131.642122,200,32.373405
+-404.060851,200,32.523429
+-124.060503,200,32.673403
+-131.717835,200,32.823598
+-127.718313,200,32.97415
+-137.818234,200,33.124208
+-129.95035,200,33.275639
+-131.477312,200,33.425819
+-127.797146,200,33.576052
+-124.371809,200,33.726285
+-130.104659,200,33.878643
+-133.817383,200,34.029474
+-123.440699,200,34.179405
+-257.553851,200,34.329363
+-1.697628,200,34.479425
+-390.571687,200,34.629377
+-0.897894,200,34.779369
+-2.533391,200,34.929563
+-128.027611,200,35.080093
+-414.369842,200,35.230119
+-274.964981,200,35.381662
+-128.183136,200,35.531727
+-132.552583,200,35.681757
+-4.302741,200,35.831682
+-257.129494,200,35.981714
+-260.231924,200,36.131918
+-134.285054,200,36.282022
+-362.860757,200,36.432488
+-128.915897,200,36.582807
+-456.543651,200,36.732877
+-3.358384,200,36.884256
+-255.450524,200,37.034254
+-261.3761,200,37.184193
+-260.793575,200,37.334163
+-133.796244,200,37.48479
+-257.517465,200,37.634975
+-531.879293,200,37.78501
+-128.615242,200,37.935118
+-130.197618,200,38.08517
+-258.885161,200,38.235106
+-123.943804,200,38.386478
+-132.935323,200,38.536453
+-530.041036,200,38.686333
+-130.818515,200,38.836363
+-262.447193,200,38.986387
+-505.777393,200,39.136355
+-2.975214,200,39.286256
+-412.703885,200,39.436566
+-538.712992,200,39.586547
+-528.167505,200,39.736509
+-130.346881,200,39.889076
+-131.897898,200,40.039305
+-388.192033,200,40.189297
+-131.470567,200,40.339184
+-410.195644,200,40.489591
+-130.107815,200,40.639575
+-244.712956,200,40.789485
+-255.326688,200,40.940122
+-261.541475,200,41.090018
+-258.160837,200,41.239906
+-387.741392,200,41.392228
+-257.322626,200,41.542432
+-415.767291,200,41.692436
+-129.723756,200,41.842434
+-0.794575,200,41.992673
+-128.213887,200,42.142858
+-251.497676,200,42.292704
+-259.17477,200,42.442664
+-128.269772,200,42.592744
+-131.744092,200,42.742721
+-262.842008,200,42.894565
+-498.548773,200,43.047081
+-0.774458,200,43.19861
+-262.596822,200,43.350205
+-124.026436,200,43.50212
+-257.435492,200,43.654454
+-260.68249,200,43.806756
+-255.574588,200,43.959209
+-264.727231,200,44.111727
+-377.587041,200,44.264087
+-128.587609,200,44.417627
+-129.70796,200,44.570029
+-269.060373,200,44.72219
+-131.165587,200,44.874727
+-254.387696,200,45.026945
+-263.9862,200,45.178953
+-134.287859,200,45.331221
+-125.80162,200,45.483443
+-508.039606,200,45.635566
+-263.145171,200,45.78762
+-130.273372,200,45.940681
+-125.9654,200,46.09278
+-257.22796,200,46.244693
+-250.627959,200,46.396623
+-0.826117,200,46.548539
+-131.003295,200,46.700485
+-122.997263,200,46.852642
+-0.734416,200,47.004569
+-253.738489,200,47.15634
+-256.666289,200,47.308181
+-122.696645,200,47.461555
+-127.847016,200,47.613417
+-256.762524,200,47.765217
+-254.441311,200,47.917135
+-133.397656,200,48.069028
+-271.438197,200,48.220624
+-134.179045,200,48.372397
+-262.332681,200,48.524182
+-1.337436,200,48.675769
+-130.565648,200,48.827398
+-131.142902,200,48.9806
+-127.975927,200,49.132221
+-262.25393,200,49.283701
+-254.449826,200,49.435412
+-258.731496,200,49.586887
+-389.092838,200,49.738442
+-0.721196,200,49.890626
+-127.779348,200,50.042051
+-131.439441,200,50.193412
+-133.441949,200,50.344803
+-537.846484,200,50.497638
+-135.778334,200,50.649079
+-261.050266,200,50.800883
+-258.139315,200,50.952236
+-256.596868,200,51.103744
+-131.703291,200,51.254972
+-128.114295,200,51.406292
+-491.009204,200,51.557491
+-131.508491,200,51.708709
+-133.760086,200,51.859934
+-130.221702,200,52.012615
+-133.389744,200,52.163825
+-128.091615,200,52.315926
+-252.439552,200,52.468925
+-128.490453,200,52.620389
+-128.985618,200,52.771552
+-131.770744,200,52.921945
+-253.997435,200,53.072034
+-127.733059,200,53.221994
+-254.225273,200,53.372034
+-405.849477,200,53.523106
+-242.3805,200,53.674289
+-134.21694,200,53.824225
+-489.829358,200,53.974366
+-403.942868,200,54.124361
+-123.547458,200,54.274329
+-135.275231,200,54.424801
+-396.79745,200,54.574818
+-563.813464,200,54.724805
+-262.983306,200,54.875445
+-130.550435,200,55.026812
+-0.924758,200,55.177573
+-128.991373,200,55.327608
+-268.824452,200,55.478076
+-257.067561,200,55.628124
+-130.902823,200,55.778219
+-254.152831,200,55.928341
+-531.076613,200,56.078641
+-129.705213,200,56.228625
+-127.101532,200,56.378672
+-260.557891,200,56.52994
+-417.469802,200,56.679893
+-130.238158,200,56.829859
+-2.974881,200,56.980036
+-265.916012,200,57.129957
+-0.747026,200,57.279986
+-130.729517,200,57.430612
+-257.108569,200,57.580575
+-1.379407,200,57.73075
+-131.131822,200,57.880719
+-387.538382,200,58.032043
+-262.503316,200,58.182208
+-253.155763,200,58.332651
+-490.879981,200,58.483264
+-263.13122,200,58.6331
+-386.007151,200,58.783066
+-254.606673,200,58.933168
+-264.764901,200,59.083106
+-483.63525,200,59.233039
+-131.832368,200,59.383191
+-254.521777,200,59.534242
+-244.609944,200,59.684812
+-271.196004,200,59.834779
+-262.006001,200,59.985074
+-261.02399,200,60.135115
+-260.254047,200,60.285059
+-132.289053,200,60.435072
+-254.908308,200,60.584996
+-497.978891,200,60.734942
+-255.888694,200,60.884999
+-261.993791,200,61.035861
+-388.388129,200,61.186234
+-1.24258,200,61.336648
+-128.899158,200,61.486665
+-246.394188,200,61.636727
+-128.094586,200,61.786773
+-128.658202,200,61.936914
+-402.595834,200,62.08699
+-253.379127,200,62.237787
+-121.864009,200,62.388124
+-130.144938,200,62.538343
+-128.310917,200,62.688994
+-133.84844,200,62.838951
+-132.475825,200,62.989049
+-1.026764,200,63.138992
+-502.208176,200,63.288834
+-128.871903,200,63.438949
+-261.875245,200,63.588862
+-1.759837,200,63.738794
+-121.79994,200,63.888904
+-383.090765,200,64.039152
+-263.741944,200,64.190158
+-125.44497,200,64.340029
+-257.646825,200,64.490144
+-267.011003,200,64.640196
+-129.206928,200,64.790463
+-127.065427,200,64.940374
+-263.431516,200,65.090271
+-364.139141,200,65.240197
+-254.247729,200,65.390134
+-253.04928,200,65.540108
+-128.762169,200,65.691237
+-132.1468,200,65.841159
+-130.959596,200,65.991195
+-126.657641,200,66.141119
+-255.893589,200,66.290903
+-132.205551,200,66.440856
+-130.498794,200,66.590754
+-127.877895,200,66.74071
+-133.294295,200,66.890872
+-131.329031,200,67.040868
+-129.291232,200,67.192048
+-258.77575,200,67.341976
+-1.553831,200,67.492198
+-257.745899,200,67.642157
+-126.255173,200,67.792123
+-128.979293,200,67.94228
+-131.461846,200,68.092414
+-133.893376,200,68.242526
+-124.267656,200,68.39249
+-385.636733,200,68.542359
+-264.672312,200,68.693232
+-131.000453,200,68.843172
+-247.950141,200,68.99308
+-263.850945,200,69.142999
+-252.46746,200,69.292894
+-127.940674,200,69.442874
+-403.64718,200,69.592826
+-130.133859,200,69.74283
+-243.464739,200,69.892962
+-138.400139,200,70.042819
+-263.079562,200,70.194396
+-130.953487,200,70.344238
+-128.385681,200,70.494611
+-133.014874,200,70.644564
+-259.780775,200,70.794422
+-129.153884,200,70.944657
+-521.664381,200,71.094606
+-245.090205,200,71.244525
+-130.89642,200,71.394762
+-133.35325,200,71.545349
+-258.936847,200,71.696949
+-128.330259,200,71.847137
+-1.801753,200,71.997452
+-259.93976,200,72.147381
+-396.024985,200,72.297277
+-131.870194,200,72.447377
+-136.297627,200,72.597364
+-0.823265,200,72.747237
+-270.715603,200,72.897243
+-130.113774,200,73.047105
+-479.652184,200,73.197917
+-251.999533,200,73.347957
+-130.148371,200,73.498791
+-253.696658,200,73.648846
+-123.047836,200,73.798869
+-1.156308,200,73.949221
+-520.705938,200,74.099172
+-259.495808,200,74.249089
+-131.728049,200,74.399434
+-261.29829,200,74.549264
+-122.16629,200,74.700153
+-133.926102,200,74.850294
+-130.137137,200,75.000444
+-254.939282,200,75.15031
+-513.074461,200,75.30023
+-130.096033,200,75.450384
+-261.31782,200,75.600297
+-126.510035,200,75.750345
+-421.608682,200,75.900493
+-124.13579,200,76.050721
+-136.36839,200,76.202114
+-389.560064,200,76.352335
+-131.083888,200,76.502524
+-130.690174,200,76.652491
+-413.572655,200,76.802619
+-135.291854,200,76.952681
+-134.195794,200,77.102591
+-131.057201,200,77.253225
+-134.03553,200,77.403305
+-0.820752,200,77.55325
+-123.267231,200,77.704308
+-398.19583,200,77.854314
+-390.690876,200,78.004445
+-407.120835,200,78.154418
+-131.894882,200,78.304872
+-133.474274,200,78.454864
+-132.188956,200,78.60489
+-263.746771,200,78.754839
+-475.524122,200,78.905009
+-132.041959,200,79.054985
+-129.917066,200,79.206104
+-133.146303,200,79.356185
+-390.49514,200,79.506301
+-399.885877,200,79.656353
+-131.290165,200,79.806514
+-130.199982,200,79.95666
+-133.556031,200,80.106842
+-130.931187,200,80.256788
+-250.647754,200,80.406693
+-128.200454,200,80.556664
+-127.721969,200,80.708273
+-130.520502,200,80.858291
+-130.858737,200,81.00835
+-255.487012,200,81.158605
+-2.253563,200,81.309322
+-130.487165,200,81.459565
+-126.719604,200,81.609811
+-126.431937,200,81.759861
+-133.653628,200,81.909883
+-256.780391,200,82.059874
+-495.09774,200,82.211046
+-136.353332,200,82.361443
+-387.457952,200,82.511901
+-253.279425,200,82.661841
+-259.324465,200,82.811814
+-530.177779,200,82.961904
+-130.197747,200,83.111833
+-263.103061,200,83.261714
+-131.738116,200,83.411888
+-131.201019,200,83.56198
+-128.945717,200,83.71337
+-251.079851,200,83.863603
+-267.302457,200,84.013948
+-388.325636,200,84.163871
+-122.342331,200,84.3138
+-259.707733,200,84.463821
+-131.781672,200,84.614248
+-135.124134,200,84.764165
+-130.321858,200,84.914449
+-1.313194,200,85.064358
+-126.503561,200,85.214566
+-496.524308,200,85.365945
+-6.160445,200,85.516003
+-124.383663,200,85.665968
+-394.493689,200,85.816112
+-260.649325,200,85.966306
+-135.382236,200,86.116416
+-260.879486,200,86.266237
+-492.730836,200,86.416513
+-1.308366,200,86.566504
+-412.233998,200,86.716506
+-0.567212,200,86.867225
+-398.629213,200,87.01755
+-415.362889,200,87.167497
+-126.896792,200,87.317436
+-0.846113,200,87.467558
+-133.211996,200,87.617578
+-259.124641,200,87.767718
+-262.516432,200,87.917925
+-259.693684,200,88.068021
+-507.08818,200,88.21801
+-370.794814,200,88.368912
+-1.087646,200,88.519423
+-129.548493,200,88.669913
+-135.406123,200,88.819872
+-252.677184,200,88.97002
+-258.317547,200,89.119923
+-450.705273,200,89.27018
+-133.641993,200,89.420418
+-126.235579,200,89.570389
+-250.197252,200,89.720362
+-385.796826,200,89.871599
+-527.451591,200,90.021584
+-255.994445,200,90.171479
+-125.820988,200,90.321563
+-3.355339,200,90.471581
+-261.828027,200,90.621686
+-1.082514,200,90.772223
+-377.37877,200,90.922379
+-1.683153,200,91.07271
+-254.005314,200,91.222569
+-259.746342,200,91.373994
+-0.76549,200,91.524039
+-456.105696,200,91.674047
+-132.898364,200,91.824019
+-264.393299,200,91.974616
+-133.124353,200,92.124937
+-134.04704,200,92.274887
+-532.695581,200,92.425009
+-128.479076,200,92.575025
+-0.791785,200,92.725048
+-123.897,200,92.876107
+-403.021303,200,93.026173
+-256.030124,200,93.176146
+-132.985599,200,93.326122
+-1.818002,200,93.476148
+-270.469002,200,93.626078
+-3.079538,200,93.776158
+-259.517448,200,93.92628
+-254.361579,200,94.076236
+-385.965485,200,94.226105
+-1.827868,200,94.377053
+-131.51911,200,94.527777
+-392.387264,200,94.677916
+-133.101115,200,94.827924
+-133.839194,200,94.977951
+-123.915708,200,95.127863
+-131.339995,200,95.278092
+-395.23312,200,95.428234
+-387.022453,200,95.578197
+-131.928348,200,95.728455
+-129.970639,200,95.87979
+-261.022001,200,96.03028
+-261.491149,200,96.180207
+-246.163852,200,96.330205
+-124.154232,200,96.480319
+-130.419012,200,96.630197
+-242.645725,200,96.78007
+-130.260847,200,96.930046
+-125.610515,200,97.080301
+-1.407112,200,97.230152
+-253.562121,200,97.381064
+-2.213094,200,97.531155
+-0.774488,200,97.681336
+-4.187768,200,97.831277
+-259.572981,200,97.982108
+-129.353383,200,98.13202
+-124.517025,200,98.282337
+-264.471257,200,98.432557
+-254.968328,200,98.582457
+-123.42731,200,98.732319
+-259.974884,200,98.883663
+-447.813526,200,99.033815
+-259.955894,200,99.183714
+-258.804813,200,99.333555
+-129.039698,200,99.483533
+-128.780577,200,99.63378
+-260.814353,200,99.784062
+-122.431627,200,99.934475
+-1.080245,200,100.084785
+-260.830759,200,100.235228
+-534.234211,200,100.386697
+-523.165229,200,100.53708
+-129.10214,200,100.687067
+-132.807356,200,100.836922
+-130.414431,200,100.986925
+-395.870585,200,101.136974
+-377.701219,200,101.286847
+-260.948133,200,101.437155
+-0.816817,200,101.58697
+-262.958324,200,101.736865
+-134.742928,200,101.887607
+-454.942996,200,102.037928
+-130.919269,200,102.187853
+-264.041871,200,102.337847
+-130.423312,200,102.487939
+-127.226415,200,102.637887
+-522.461754,200,102.78776
+-1.148708,200,102.937798
+-6.536455,200,103.087711
+-257.6031,200,103.237552
+-128.720644,200,103.388964
+-131.406704,200,103.539749
+-129.394132,200,103.689775
+-132.589514,200,103.83981
+-124.933573,200,103.989935
+-478.735802,200,104.140508
+-3.082373,200,104.290401
+-131.676206,200,104.440384
+-534.309578,200,104.590317
+-135.326852,200,104.740203
+-132.253694,200,104.891677
+-255.012311,200,105.041959
+-255.012722,200,105.191854
+-258.503162,200,105.341737
+-0.895727,200,105.491628
+-0.899892,200,105.641625
+-1.728374,200,105.791913
+-514.48447,200,105.942069
+-399.104322,200,106.092014
+-533.905116,200,106.24193
+-124.653509,200,106.392929
+-1.529994,200,106.543303
+-565.123823,200,106.69336
+-382.928626,200,106.843282
+-393.138859,200,106.993197
+-133.039055,200,107.143152
+-1.254905,200,107.293069
+-253.218977,200,107.443086
+-377.926736,200,107.59297
+-384.992998,200,107.742795
+-132.20426,200,107.893014
+-2.265609,200,108.044569
+-129.277822,200,108.194533
+-516.770055,200,108.344465
+-123.844188,200,108.494859
+-488.233926,200,108.644754
+-259.32199,200,108.794864
+-417.74258,200,108.944967
+-126.061511,200,109.094862
+-130.831749,200,109.244875
+-255.263712,200,109.395019
+-501.52857,200,109.546746
+-1.092731,200,109.697082
+-124.415082,200,109.848345
+-398.725763,200,109.99847
+-265.103,200,110.148456
+-262.173707,200,110.298429
+-1.412234,200,110.448377
+-384.796566,200,110.598411
+-122.613973,200,110.748757
+-131.847815,200,110.898898
+-259.808992,200,111.0496
+-260.018786,200,111.199631
+-476.825245,200,111.349615
+-265.113998,200,111.49967
+-125.187484,200,111.649592
+-131.658743,200,111.799565
+-560.090144,200,111.949756
+-249.87696,200,112.09988
+-513.023542,200,112.249848
+-261.979488,200,112.399794
+-253.954332,200,112.55137
+-131.624941,200,112.701358
+-3.65569,200,112.851359
+-129.742158,200,113.001439
+-127.666264,200,113.151829
+-129.3067,200,113.301748
+-130.320689,200,113.452212
+-128.921845,200,113.602148
+-263.061197,200,113.752045
+-129.571114,200,113.902041
+-391.835508,200,114.053375
+-257.21494,200,114.203687
+-258.361078,200,114.353614
+-125.928314,200,114.503546
+-128.023682,200,114.654358
+-126.836506,200,114.804574
+-264.384871,200,114.955343
+-129.339465,200,115.105539
diff --git a/requirements.txt b/requirements.txt
index dbd890eb1..4318fc8e3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,16 +1,20 @@
gym==0.21
-stable-baselines3[extra,tests,docs]>=1.5.0
-sb3-contrib>=1.5.0
+stable-baselines3[extra,tests,docs]>=1.5.1a7
+sb3-contrib>=1.5.1a8
box2d-py==2.3.8
pybullet
gym-minigrid
scikit-optimize
optuna
-pytablewriter
+pytablewriter~=0.64
seaborn
pyyaml>=5.1
cloudpickle>=1.5.0
plotly
+git+https://github.com/takuseno/d3rlpy
panda-gym==1.1.1 # tmp fix: until compatibility with panda-gym v2
rliable>=1.0.5
wandb
+ale-py==0.7.4 # tmp fix: until new SB3 version is released
+# TODO: replace with release
+git+https://github.com/huggingface/huggingface_sb3
diff --git a/scripts/all_plots.py b/scripts/all_plots.py
index b456b9009..335b78829 100644
--- a/scripts/all_plots.py
+++ b/scripts/all_plots.py
@@ -205,7 +205,7 @@
# Markdown Table
-writer = pytablewriter.MarkdownTableWriter()
+writer = pytablewriter.MarkdownTableWriter(max_precision=3)
writer.table_name = "results_table"
headers = ["Environments"]
diff --git a/scripts/migrate_to_hub.py b/scripts/migrate_to_hub.py
new file mode 100644
index 000000000..d2622a43a
--- /dev/null
+++ b/scripts/migrate_to_hub.py
@@ -0,0 +1,23 @@
+import subprocess
+
+from utils.utils import get_hf_trained_models, get_trained_models
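+
+# Push every locally trained agent that is not yet on the Hugging Face Hub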
+
+folder = "rl-trained-agents"
+orga = "sb3"
+trained_models_local = get_trained_models(folder)
+trained_models_hub = get_hf_trained_models(orga)
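+# Models that exist locally but are missing from the Hub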
+remaining_models = set(trained_models_local.keys()) - set(trained_models_hub.keys())
+
+for trained_model in list(remaining_models):
+    algo, env_id = trained_models_local[trained_model]
+    args = ["-orga", orga, "-f", folder, "--algo", algo, "--env", env_id]
+
+    # Since SB3 >= 1.1.0, HER is no longer an algorithm but a replay buffer class
+    if algo == "her":
+        continue
+
+    # Repo id of the model on the Hub, e.g. "sb3/ppo-CartPole-v1"
+    repo_name = f"{algo}-{env_id}"
+    repo_id = f"{orga}/{repo_name}"
+
+    return_code = subprocess.call(["python", "-m", "utils.push_to_hub"] + args)
+    assert return_code == 0, f"Failed to push {repo_id}"
diff --git a/scripts/plot_from_file.py b/scripts/plot_from_file.py
index 4ec0043cf..e0f1c95e3 100644
--- a/scripts/plot_from_file.py
+++ b/scripts/plot_from_file.py
@@ -79,7 +79,7 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5
results = pickle.load(file_handler)
# Plot table
-writer = pytablewriter.MarkdownTableWriter()
+writer = pytablewriter.MarkdownTableWriter(max_precision=3)
writer.table_name = "results_table"
writer.headers = results["results_table"]["headers"]
writer.value_matrix = results["results_table"]["value_matrix"]
@@ -193,7 +193,7 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5
warnings.warn(f"{env} not found for normalizing scores, you should update `env_key_to_env_id`")
# Truncate to convert to matrix
- min_runs = min([len(algo_score) for algo_score in algo_scores])
+ min_runs = min(len(algo_score) for algo_score in algo_scores)
if min_runs > 0:
algo_scores = [algo_score[:min_runs] for algo_score in algo_scores]
# shape: (n_envs, n_runs) -> (n_runs, n_envs)
diff --git a/setup.cfg b/setup.cfg
index c74b8320e..6841b0fcb 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -20,6 +20,7 @@ per-file-ignores =
     ./scripts/all_plots.py:E501
     ./scripts/plot_train.py:E501
     ./scripts/plot_training_success.py:E501
+    ./utils/teleop.py:F405
 exclude =
     # No need to traverse our git directory
diff --git a/tests/test_enjoy.py b/tests/test_enjoy.py
index c55e702b2..194530556 100644
--- a/tests/test_enjoy.py
+++ b/tests/test_enjoy.py
@@ -3,7 +3,7 @@
import pytest
-from utils import get_trained_models
+from utils.utils import get_hf_trained_models, get_trained_models
def _assert_eq(left, right):
@@ -12,8 +12,10 @@ def _assert_eq(left, right):
FOLDER = "rl-trained-agents/"
N_STEPS = 100
-
+# Use local models
trained_models = get_trained_models(FOLDER)
+# Use huggingface models too
+trained_models.update(get_hf_trained_models())
@pytest.mark.parametrize("trained_model", trained_models.keys())
@@ -26,6 +28,10 @@ def test_trained_agents(trained_model):
if algo == "her":
return
+ # skip car racing
+ if "CarRacing" in env_id:
+ return
+
# Skip mujoco envs
if "Fetch" in trained_model or "-v3" in trained_model:
return
@@ -38,7 +44,7 @@ def test_trained_agents(trained_model):
 def test_benchmark(tmp_path):
-    args = ["-n", str(N_STEPS), "--benchmark-dir", tmp_path, "--test-mode"]
+ args = ["-n", str(N_STEPS), "--benchmark-dir", tmp_path, "--test-mode", "--no-hub"]
return_code = subprocess.call(["python", "-m", "utils.benchmark"] + args)
_assert_eq(return_code, 0)
diff --git a/tests/test_hyperparams_opt.py b/tests/test_hyperparams_opt.py
index cd5444b70..89396a1f1 100644
--- a/tests/test_hyperparams_opt.py
+++ b/tests/test_hyperparams_opt.py
@@ -2,7 +2,9 @@
import os
import subprocess
+import optuna
import pytest
+from optuna.trial import TrialState
def _assert_eq(left, right):
@@ -120,3 +122,57 @@ def test_optimize_log_path(tmp_path):
     ]
     return_code = subprocess.call(["python", "scripts/parse_study.py"] + args)
     _assert_eq(return_code, 0)
+
+
+def test_multiple_workers(tmp_path):
+    study_name = "test-study"
+    storage = f"sqlite:///{tmp_path}/optuna.db"
+    # n trials per worker
+    n_trials = 2
+    # max total trials
+    max_trials = 3
+    # 1st worker will do 2 trials
+    # 2nd worker will do 1 trial
+    # 3rd worker will do nothing
+    n_workers = 3
+    args = [
+        "-optimize",
+        "--no-optim-plots",
+        "--storage",
+        storage,
+        "--n-trials",
+        str(n_trials),
+        "--max-total-trials",
+        str(max_trials),
+        "--study-name",
+        study_name,
+        "--n-evaluations",
+        str(1),
+        "-n",
+        str(100),
+        "--algo",
+        "a2c",
+        "--env",
+        "Pendulum-v1",
+        "--log-folder",
+        tmp_path,
+        "-params",
+        "n_envs:1",
+        "--seed",
+        "12",
+    ]
+
+    # Sequential execution to avoid race conditions
+    workers = []
+    for _ in range(n_workers):
+        worker = subprocess.Popen(
+            ["python", "train.py"] + args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
+        )
+        worker.wait()
+        workers.append(worker)
+
+    study = optuna.load_study(study_name=study_name, storage=storage)
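+    # Across all workers, the number of finished (complete or pruned)
+    # trials should match --max-total-trials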
+    assert len(study.get_trials(states=(TrialState.COMPLETE, TrialState.PRUNED))) == max_trials
+
+    for worker in workers:
+        assert worker.returncode == 0, "STDOUT:\n{}\nSTDERR:\n{}\n".format(*worker.communicate())
diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py
index 5ec00c0f4..bfe0d7a1b 100644
--- a/tests/test_wrappers.py
+++ b/tests/test_wrappers.py
@@ -1,7 +1,9 @@
import gym
import pybullet_envs # noqa: F401
import pytest
+from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env
+from stable_baselines3.common.vec_env import DummyVecEnv
from utils.utils import get_wrapper_class
from utils.wrappers import ActionNoiseWrapper, DelayedRewardWrapper, HistoryWrapper, TimeFeatureWrapper
@@ -31,3 +33,20 @@ def test_get_wrapper(env_wrapper):
     if env_wrapper is not None:
         env = wrapper_class(env)
     check_env(env)
+
+
+@pytest.mark.parametrize(
+    "vec_env_wrapper",
+    [
+        None,
+        {"stable_baselines3.common.vec_env.VecFrameStack": dict(n_stack=2)},
+        [{"stable_baselines3.common.vec_env.VecFrameStack": dict(n_stack=3)}, "stable_baselines3.common.vec_env.VecMonitor"],
+    ],
+)
+def test_get_vec_env_wrapper(vec_env_wrapper):
+    env = DummyVecEnv([lambda: gym.make("AntBulletEnv-v0")])
+    hyperparams = {"vec_env_wrapper": vec_env_wrapper}
+    wrapper_class = get_wrapper_class(hyperparams, "vec_env_wrapper")
+    if wrapper_class is not None:
+        env = wrapper_class(env)
+    A2C("MlpPolicy", env).learn(16)
diff --git a/train.py b/train.py
index bbf85795e..1577297da 100644
--- a/train.py
+++ b/train.py
@@ -11,10 +11,28 @@
import torch as th
from stable_baselines3.common.utils import set_random_seed
+try:
+    from d3rlpy.algos import AWAC, AWR, BC, BCQ, BEAR, CQL, CRR, TD3PlusBC
+    from d3rlpy.models.encoders import VectorEncoderFactory
+    from d3rlpy.wrappers.sb3 import SB3Wrapper, to_mdp_dataset
+
+    offline_algos = dict(
+        awr=AWR,
+        awac=AWAC,
+        bc=BC,
+        bcq=BCQ,
+        bear=BEAR,
+        cql=CQL,
+        crr=CRR,
+        td3bc=TD3PlusBC,
+    )
+except ImportError:
+    offline_algos = {}
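+    # d3rlpy is optional: without it, offline pretraining is simply disabled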
+
# Register custom envs
import utils.import_envs # noqa: F401 pytype: disable=import-error
from utils.exp_manager import ExperimentManager
-from utils.utils import ALGOS, StoreDict
+from utils.utils import ALGOS, StoreDict, evaluate_policy_add_to_buffer
seaborn.set()
@@ -64,6 +82,13 @@
         type=int,
         default=500,
     )
+    parser.add_argument(
+        "--max-total-trials",
+        help="Number of (potentially pruned) trials for optimizing hyperparameters. "
+        "This applies to the entire optimization process and takes precedence over --n-trials if set.",
+        type=int,
+        default=None,
+    )
     parser.add_argument(
         "-optimize", "--optimize-hyperparameters", action="store_true", default=False, help="Run hyperparameters search"
     )
@@ -117,6 +142,17 @@
help="Overwrite hyperparameter (e.g. learning_rate:0.01 train_freq:10)",
)
parser.add_argument("-uuid", "--uuid", action="store_true", default=False, help="Ensure that the run has a unique ID")
+ parser.add_argument(
+ "--offline-algo", help="Offline RL Algorithm", type=str, required=False, choices=list(offline_algos.keys())
+ )
+ parser.add_argument("-b", "--pretrain-buffer", help="Path to a saved replay buffer for pretraining", type=str)
+ parser.add_argument(
+ "--pretrain-params",
+ type=str,
+ nargs="+",
+ action=StoreDict,
+ help="Optional arguments for pretraining with replay buffer",
+ )
parser.add_argument(
"--track",
action="store_true",
@@ -201,6 +237,7 @@
args.storage,
args.study_name,
args.n_trials,
+ args.max_total_trials,
args.n_jobs,
args.sampler,
args.pruner,
@@ -228,9 +265,79 @@
args.saved_hyperparams = saved_hyperparams
run.config.setdefaults(vars(args))
- # Normal training
- if model is not None:
- exp_manager.learn(model)
- exp_manager.save_trained_model(model)
+ if args.pretrain_buffer is not None and model is not None:
+ model.load_replay_buffer(args.pretrain_buffer)
+ print(f"Buffer size = {model.replay_buffer.buffer_size}")
+ # Artificially reduce buffer size
+ # model.replay_buffer.full = False
+ # model.replay_buffer.pos = 5000
+
+ print(f"{model.replay_buffer.size()} transitions in the replay buffer")
+
+ # --pretrain-params is optional, fall back to an empty dict
+ args.pretrain_params = args.pretrain_params or {}
+ n_iterations = args.pretrain_params.get("n_iterations", 10)
+ n_epochs = args.pretrain_params.get("n_epochs", 1)
+ q_func_factory = args.pretrain_params.get("q_func_factory")
+ batch_size = args.pretrain_params.get("batch_size", 512)
+ # n_action_samples = args.pretrain_params.get("n_action_samples", 1)
+ n_eval_episodes = args.pretrain_params.get("n_eval_episodes", 5)
+ add_to_buffer = args.pretrain_params.get("add_to_buffer", False)
+ deterministic = args.pretrain_params.get("deterministic", True)
+ net_arch = args.pretrain_params.get("net_arch", [256, 256])
+ scaler = args.pretrain_params.get("scaler", "standard")
+ encoder_factory = VectorEncoderFactory(hidden_units=net_arch)
+ for arg_name in {
+ "n_iterations",
+ "n_epochs",
+ "q_func_factory",
+ "batch_size",
+ "n_eval_episodes",
+ "add_to_buffer",
+ "deterministic",
+ "net_arch",
+ "scaler",
+ }:
+ if arg_name in args.pretrain_params:
+ del args.pretrain_params[arg_name]
+ try:
+ assert args.offline_algo is not None and len(offline_algos) > 0
+ kwargs_ = {} if q_func_factory is None else dict(q_func_factory=q_func_factory)
+ kwargs_.update(dict(encoder_factory=encoder_factory))
+ kwargs_.update(args.pretrain_params)
+
+ offline_model = offline_algos[args.offline_algo](
+ batch_size=batch_size,
+ **kwargs_,
+ )
+ offline_model = SB3Wrapper(offline_model)
+ offline_model.use_sde = False
+ # break the logger...
+ # offline_model.replay_buffer = model.replay_buffer
+
+ for i in range(n_iterations):
+ dataset = to_mdp_dataset(model.replay_buffer)
+ offline_model.fit(dataset.episodes, n_epochs=n_epochs, save_metrics=False)
+
+ mean_reward, std_reward = evaluate_policy_add_to_buffer(
+ offline_model,
+ model.get_env(),
+ n_eval_episodes=n_eval_episodes,
+ replay_buffer=model.replay_buffer if add_to_buffer else None,
+ deterministic=deterministic,
+ )
+ print(f"Iteration {i + 1} training, mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
+ except KeyboardInterrupt:
+ pass
+ finally:
+ print(f"Saving offline model to {exp_manager.save_path}/policy.pt")
+ offline_model.save_policy(f"{exp_manager.save_path}/policy.pt")
+ # print("Starting training")
+ # TODO: convert d3rlpy weights to SB3
+ model.env.close()
+ exit()
+
+ # Normal training
+ if model is not None:
+ exp_manager.learn(model)
+ exp_manager.save_trained_model(model)
else:
exp_manager.hyperparameters_optimization()
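Hedged usage sketch for the pretraining branch above (the buffer path and parameter values are placeholders): the run loads the saved replay buffer, fits the selected d3rlpy algorithm on it for `n_iterations`, optionally feeds evaluation transitions back into the buffer, and saves the resulting policy as `policy.pt` before exiting.

```
# Requires d3rlpy; buffer path and pretrain parameters are illustrative
python train.py --algo sac --env Pendulum-v1 \
    --offline-algo cql -b logs/sac/Pendulum-v1_1/replay_buffer.pkl \
    --pretrain-params n_iterations:5 add_to_buffer:True
```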
diff --git a/utils/benchmark.py b/utils/benchmark.py
index 9c7e8d597..b53430fb2 100644
--- a/utils/benchmark.py
+++ b/utils/benchmark.py
@@ -9,7 +9,7 @@
import pytablewriter
from stable_baselines3.common.results_plotter import load_results, ts2xy
-from utils.utils import get_latest_run_id, get_saved_hyperparams, get_trained_models
+from utils.utils import get_hf_trained_models, get_latest_run_id, get_saved_hyperparams, get_trained_models
parser = argparse.ArgumentParser()
parser.add_argument("--log-dir", help="Root log folder", default="rl-trained-agents/", type=str)
@@ -20,10 +20,15 @@
parser.add_argument("--seed", help="Random generator seed", type=int, default=0)
parser.add_argument("--test-mode", action="store_true", default=False, help="Do only one experiment (useful for testing)")
parser.add_argument("--with-mujoco", action="store_true", default=False, help="Run also MuJoCo envs")
+parser.add_argument("--no-hub", action="store_true", default=False, help="Do not download models from hub")
parser.add_argument("--num-threads", help="Number of threads for PyTorch", default=2, type=int)
args = parser.parse_args()
trained_models = get_trained_models(args.log_dir)
+
+if not args.no_hub:
+ trained_models.update(get_hf_trained_models())
+
n_experiments = len(trained_models)
results = {
"algo": [],
@@ -133,8 +138,12 @@
results_df = pd.DataFrame(results)
# Sort results
results_df = results_df.sort_values(by=["algo", "env_id"])
+# Create links to Huggingface hub
+# links = [f"[{env_id}](https://huggingface.co/sb3/{algo}-{env_id})"
+# for algo, env_id in zip(results_df["algo"], results_df["env_id"])]
+# results_df["env_id"] = links
-writer = pytablewriter.MarkdownTableWriter()
+writer = pytablewriter.MarkdownTableWriter(max_precision=3)
writer.from_dataframe(results_df)
header = """
@@ -147,6 +156,9 @@
It uses the deterministic policy except for Atari games.
+You can view each model card (it includes video and hyperparameters)
+on our Huggingface page: https://huggingface.co/sb3
+
*NOTE: this is not a quantitative benchmark as it corresponds to only one run
(cf [issue #38](https://github.com/araffin/rl-baselines-zoo/issues/38)).
This benchmark is meant to check algorithm (maximal) performance, find potential bugs
diff --git a/utils/callbacks.py b/utils/callbacks.py
index 608c5a778..27d485104 100644
--- a/utils/callbacks.py
+++ b/utils/callbacks.py
@@ -1,4 +1,5 @@
import os
+import pickle
import tempfile
import time
from copy import deepcopy
@@ -11,7 +12,7 @@
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
from stable_baselines3.common.logger import TensorBoardOutputFormat
-from stable_baselines3.common.vec_env import VecEnv
+from stable_baselines3.common.vec_env import VecEnv, sync_envs_normalization
class TrialEvalCallback(EvalCallback):
@@ -31,7 +32,7 @@ def __init__(
log_path: Optional[str] = None,
):
- super(TrialEvalCallback, self).__init__(
+ super().__init__(
eval_env=eval_env,
n_eval_episodes=n_eval_episodes,
eval_freq=eval_freq,
@@ -46,7 +47,7 @@ def __init__(
def _on_step(self) -> bool:
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
- super(TrialEvalCallback, self)._on_step()
+ super()._on_step()
self.eval_idx += 1
# report best or report current ?
# report num_timesteps or elasped time ?
@@ -69,7 +70,7 @@ class SaveVecNormalizeCallback(BaseCallback):
"""
def __init__(self, save_freq: int, save_path: str, name_prefix: Optional[str] = None, verbose: int = 0):
- super(SaveVecNormalizeCallback, self).__init__(verbose)
+ super().__init__(verbose)
self.save_freq = save_freq
self.save_path = save_path
self.name_prefix = name_prefix
@@ -111,7 +112,7 @@ class ParallelTrainCallback(BaseCallback):
"""
def __init__(self, gradient_steps: int = 100, verbose: int = 0, sleep_time: float = 0.0):
- super(ParallelTrainCallback, self).__init__(verbose)
+ super().__init__(verbose)
self.batch_size = 0
self._model_ready = True
self._model = None
@@ -130,6 +131,12 @@ def _init_callback(self) -> None:
self.model.save(temp_file)
+ if self.model.get_vec_normalize_env() is not None:
+ temp_file_norm = os.path.join("logs", "vec_normalize.pkl")
+
+ with open(temp_file_norm, "wb") as file_handler:
+ pickle.dump(self.model.get_vec_normalize_env(), file_handler)
+
# TODO: add support for other algorithms
for model_class in [SAC, TQC]:
if isinstance(self.model, model_class):
@@ -139,6 +146,11 @@ def _init_callback(self) -> None:
assert self.model_class is not None, f"{self.model} is not supported for parallel training"
self._model = self.model_class.load(temp_file)
+ if self.model.get_vec_normalize_env() is not None:
+ with open(temp_file_norm, "rb") as file_handler:
+ self._model._vec_normalize_env = pickle.load(file_handler)
+ self._model._vec_normalize_env.training = False
+
self.batch_size = self._model.batch_size
# Disable train method
@@ -183,6 +195,10 @@ def _on_rollout_end(self) -> None:
self._model.replay_buffer = deepcopy(self.model.replay_buffer)
self.model.set_parameters(deepcopy(self._model.get_parameters()))
self.model.actor = self.model.policy.actor
+ # Sync VecNormalize
+ if self.model.get_vec_normalize_env() is not None:
+ sync_envs_normalization(self.model.get_vec_normalize_env(), self._model._vec_normalize_env)
+
if self.num_timesteps >= self._model.learning_starts:
self.train()
# Do not wait for the training loop to finish
@@ -202,7 +218,7 @@ class RawStatisticsCallback(BaseCallback):
"""
def __init__(self, verbose=0):
- super(RawStatisticsCallback, self).__init__(verbose)
+ super().__init__(verbose)
# Custom counter to reports stats
# (and avoid reporting multiple values for the same step)
self._timesteps_counter = 0
@@ -227,3 +243,25 @@ def _on_step(self) -> bool:
self._tensorboard_writer.write(logger_dict, exclude_dict, self._timesteps_counter)
return True
+
+
+class LapTimeCallback(BaseCallback):
+ def _on_training_start(self):
+ self.n_laps = 0
+ output_formats = self.logger.output_formats
+ # Save reference to tensorboard formatter object
+ # note: the failure case (no formatter found) is not handled here; it should be done with try/except.
+ self.tb_formatter = next(formatter for formatter in output_formats if isinstance(formatter, TensorBoardOutputFormat))
+
+ def _on_step(self) -> bool:
+ lap_count = self.locals["infos"][0]["lap_count"]
+ lap_time = self.locals["infos"][0]["last_lap_time"]
+
+ if lap_count != self.n_laps and lap_time > 0:
+ self.n_laps = lap_count
+ self.tb_formatter.writer.add_scalar("time/lap_time", lap_time, self.num_timesteps)
+ if lap_count == 1:
+ self.tb_formatter.writer.add_scalar("time/first_lap_time", lap_time, self.num_timesteps)
+ else:
+ self.tb_formatter.writer.add_scalar("time/second_lap_time", lap_time, self.num_timesteps)
+ self.tb_formatter.writer.flush()
diff --git a/utils/exp_manager.py b/utils/exp_manager.py
index 1b6725831..3970fe58b 100644
--- a/utils/exp_manager.py
+++ b/utils/exp_manager.py
@@ -12,9 +12,12 @@
import optuna
import torch as th
import yaml
+import zmq
from optuna.integration.skopt import SkoptSampler
-from optuna.pruners import BasePruner, MedianPruner, SuccessiveHalvingPruner
+from optuna.pruners import BasePruner, MedianPruner, NopPruner, SuccessiveHalvingPruner
from optuna.samplers import BaseSampler, RandomSampler, TPESampler
+from optuna.study import MaxTrialsCallback
+from optuna.trial import TrialState
from optuna.visualization import plot_optimization_history, plot_param_importances
from sb3_contrib.common.vec_env import AsyncEval
@@ -47,7 +50,7 @@
from utils.utils import ALGOS, get_callback_list, get_latest_run_id, get_wrapper_class, linear_schedule
-class ExperimentManager(object):
+class ExperimentManager:
"""
Experiment manager: read the hyperparameters,
preprocess them, create the environment and the RL model.
@@ -73,6 +76,7 @@ def __init__(
storage: Optional[str] = None,
study_name: Optional[str] = None,
n_trials: int = 1,
+ max_total_trials: Optional[int] = None,
n_jobs: int = 1,
sampler: str = "tpe",
pruner: str = "median",
@@ -90,7 +94,7 @@ def __init__(
no_optim_plots: bool = False,
device: Union[th.device, str] = "auto",
):
- super(ExperimentManager, self).__init__()
+ super().__init__()
self.algo = algo
self.env_id = env_id
# Custom params
@@ -103,8 +107,10 @@ def __init__(
self.frame_stack = None
self.seed = seed
self.optimization_log_path = optimization_log_path
+ self.vec_env_wrapper = None
self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type]
self.vec_env_kwargs = {}
# self.vec_env_kwargs = {} if vec_env_type == "dummy" else {"start_method": "fork"}
@@ -133,6 +139,7 @@ def __init__(
self.no_optim_plots = no_optim_plots
# maximum number of trials for finding the best hyperparams
self.n_trials = n_trials
+ self.max_total_trials = max_total_trials
# number of parallel jobs when doing hyperparameter search
self.n_jobs = n_jobs
self.sampler = sampler
@@ -164,7 +171,7 @@ def setup_experiment(self) -> Optional[Tuple[BaseAlgorithm, Dict[str, Any]]]:
:return: the initialized RL model
"""
hyperparams, saved_hyperparams = self.read_hyperparameters()
- hyperparams, self.env_wrapper, self.callbacks = self._preprocess_hyperparams(hyperparams)
+ hyperparams, self.env_wrapper, self.callbacks, self.vec_env_wrapper = self._preprocess_hyperparams(hyperparams)
self.create_log_folder()
self.create_callbacks()
@@ -212,7 +219,7 @@ def learn(self, model: BaseAlgorithm) -> None:
try:
model.learn(self.n_timesteps, **kwargs)
- except KeyboardInterrupt:
+ except (KeyboardInterrupt, zmq.error.ZMQError):
# this allows to save the model when interrupting training
pass
finally:
@@ -260,7 +267,7 @@ def _save_config(self, saved_hyperparams: Dict[str, Any]) -> None:
def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Load hyperparameters from yaml file
- with open(f"hyperparams/{self.algo}.yml", "r") as f:
+ with open(f"hyperparams/{self.algo}.yml") as f:
hyperparams_dict = yaml.safe_load(f)
if self.env_id in list(hyperparams_dict.keys()):
hyperparams = hyperparams_dict[self.env_id]
@@ -322,7 +329,7 @@ def _preprocess_normalization(self, hyperparams: Dict[str, Any]) -> Dict[str, An
def _preprocess_hyperparams(
self, hyperparams: Dict[str, Any]
- ) -> Tuple[Dict[str, Any], Optional[Callable], List[BaseCallback]]:
+ ) -> Tuple[Dict[str, Any], Optional[Callable], List[BaseCallback], Optional[Callable]]:
self.n_envs = hyperparams.get("n_envs", 1)
if self.verbose > 0:
@@ -344,7 +351,7 @@ def _preprocess_hyperparams(
# Derive n_evaluations from number of timesteps if needed
if self.n_evaluations is None and self.optimize_hyperparameters:
- self.n_evaluations = self.n_timesteps // int(1e5)
+ self.n_evaluations = max(1, self.n_timesteps // int(1e5))
print(
f"Doing {self.n_evaluations} intermediate evaluations for pruning based on the number of timesteps."
" (1 evaluation every 100k timesteps)"
@@ -374,12 +381,17 @@ def _preprocess_hyperparams(
if "env_wrapper" in hyperparams.keys():
del hyperparams["env_wrapper"]
+ # Same for VecEnvWrapper
+ vec_env_wrapper = get_wrapper_class(hyperparams, "vec_env_wrapper")
+ if "vec_env_wrapper" in hyperparams.keys():
+ del hyperparams["vec_env_wrapper"]
+
callbacks = get_callback_list(hyperparams)
if "callback" in hyperparams.keys():
self.specified_callbacks = hyperparams["callback"]
del hyperparams["callback"]
- return hyperparams, env_wrapper, callbacks
+ return hyperparams, env_wrapper, callbacks, vec_env_wrapper
def _preprocess_action_noise(
self, hyperparams: Dict[str, Any], saved_hyperparams: Dict[str, Any], env: VecEnv
@@ -453,17 +465,17 @@ def create_callbacks(self):
@staticmethod
def is_atari(env_id: str) -> bool:
- entry_point = gym.envs.registry.env_specs[env_id].entry_point
+ entry_point = gym.envs.registry.env_specs[env_id].entry_point # pytype: disable=module-attr
return "AtariEnv" in str(entry_point)
@staticmethod
def is_bullet(env_id: str) -> bool:
- entry_point = gym.envs.registry.env_specs[env_id].entry_point
+ entry_point = gym.envs.registry.env_specs[env_id].entry_point # pytype: disable=module-attr
return "pybullet_envs" in str(entry_point)
@staticmethod
def is_robotics_env(env_id: str) -> bool:
- entry_point = gym.envs.registry.env_specs[env_id].entry_point
+ entry_point = gym.envs.registry.env_specs[env_id].entry_point # pytype: disable=module-attr
return "gym.envs.robotics" in str(entry_point) or "panda_gym.envs" in str(entry_point)
def _maybe_normalize(self, env: VecEnv, eval_env: bool) -> VecEnv:
@@ -537,6 +549,9 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False)
monitor_kwargs=monitor_kwargs,
)
+ if self.vec_env_wrapper is not None:
+ env = self.vec_env_wrapper(env)
+
# Wrap the env into a VecNormalize wrapper if needed
# and load saved statistics when present
env = self._maybe_normalize(env, eval_env)
@@ -619,7 +634,7 @@ def _create_pruner(self, pruner_method: str) -> BasePruner:
pruner = MedianPruner(n_startup_trials=self.n_startup_trials, n_warmup_steps=self.n_evaluations // 3)
elif pruner_method == "none":
# Do not prune
- pruner = MedianPruner(n_startup_trials=self.n_trials, n_warmup_steps=self.n_evaluations)
+ pruner = NopPruner()
else:
raise ValueError(f"Unknown pruner: {pruner_method}")
return pruner
@@ -747,7 +762,28 @@ def hyperparameters_optimization(self) -> None:
)
try:
- study.optimize(self.objective, n_trials=self.n_trials, n_jobs=self.n_jobs)
+ if self.max_total_trials is not None:
+ # Note: we count already running trials here otherwise we get
+ # (max_total_trials + number of workers) trials in total.
+ counted_states = [
+ TrialState.COMPLETE,
+ TrialState.RUNNING,
+ TrialState.PRUNED,
+ ]
+ completed_trials = len(study.get_trials(states=counted_states))
+ if completed_trials < self.max_total_trials:
+ study.optimize(
+ self.objective,
+ n_jobs=self.n_jobs,
+ callbacks=[
+ MaxTrialsCallback(
+ self.max_total_trials,
+ states=counted_states,
+ )
+ ],
+ )
+ else:
+ study.optimize(self.objective, n_jobs=self.n_jobs, n_trials=self.n_trials)
except KeyboardInterrupt:
pass
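For reference, a self-contained sketch of the same pattern with plain Optuna (the study name and storage URL are illustrative): `MaxTrialsCallback` stops a worker as soon as the study contains the given number of trials in the counted states, which is what makes the trial budget global across processes.

```python
import optuna
from optuna.study import MaxTrialsCallback
from optuna.trial import TrialState


def objective(trial: optuna.Trial) -> float:
    x = trial.suggest_float("x", -10, 10)
    return (x - 2) ** 2


# Shared storage: every worker loads the same study
study = optuna.create_study(study_name="demo-study", storage="sqlite:///optuna.db", load_if_exists=True)
# Counting RUNNING trials avoids overshooting the budget with parallel workers
states = (TrialState.COMPLETE, TrialState.RUNNING, TrialState.PRUNED)
study.optimize(objective, callbacks=[MaxTrialsCallback(10, states=states)])
```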
diff --git a/utils/import_envs.py b/utils/import_envs.py
index fbe0370e3..1dddf2880 100644
--- a/utils/import_envs.py
+++ b/utils/import_envs.py
@@ -1,3 +1,8 @@
+import gym
+from gym.envs.registration import register
+
+from utils.wrappers import MaskVelocityWrapper
+
try:
import pybullet_envs # pytype: disable=import-error
except ImportError:
@@ -28,7 +33,35 @@
except ImportError:
gym_donkeycar = None
+try:
+ import rl_racing.envs # pytype: disable=import-error
+except ImportError:
+ rl_racing = None
+
+try:
+ import gym_space_engineers # pytype: disable=import-error
+except ImportError:
+ gym_space_engineers = None
+
try:
import panda_gym # pytype: disable=import-error
except ImportError:
panda_gym = None
+
+
+# Register no vel envs
+def create_no_vel_env(env_id: str):
+ def make_env():
+ env = gym.make(env_id)
+ env = MaskVelocityWrapper(env)
+ return env
+
+ return make_env
+
+
+for env_id in MaskVelocityWrapper.velocity_indices.keys():
+ name, version = env_id.split("-v")
+ register(
+ id=f"{name}NoVel-v{version}",
+ entry_point=create_no_vel_env(env_id),
+ )
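Assuming `MaskVelocityWrapper.velocity_indices` covers the classic-control ids (e.g. `CartPole-v1`), the loop above makes the masked variants available through the normal registry lookup:

```python
import gym

import utils.import_envs  # noqa: F401  # runs the registration loop above

# Hypothetical id: it only exists if CartPole-v1 is in velocity_indices
env = gym.make("CartPoleNoVel-v1")
obs = env.reset()  # observation with the velocity entries masked out
```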
diff --git a/utils/load_from_hub.py b/utils/load_from_hub.py
new file mode 100644
index 000000000..490ef1f8f
--- /dev/null
+++ b/utils/load_from_hub.py
@@ -0,0 +1,123 @@
+import argparse
+import os
+import shutil
+import zipfile
+from pathlib import Path
+from typing import Optional
+
+from huggingface_sb3 import load_from_hub
+from requests.exceptions import HTTPError
+
+from utils import ALGOS, get_latest_run_id
+
+
+def download_from_hub(
+ algo: str,
+ env_id: str,
+ exp_id: int,
+ folder: str,
+ organization: str,
+ repo_name: Optional[str] = None,
+ force: bool = False,
+) -> None:
+ """
+ Try to load a model from the Huggingface hub
+ and save it following the RL Zoo structure.
+ The default repo id is {organization}/{repo_name},
+ where repo_name = {algo}-{env_id} unless overwritten.
+
+ :param algo: Algorithm
+ :param env_id: Environment id
+ :param exp_id: Experiment id
+ :param folder: Log folder
+ :param organization: Huggingface organization
+ :param repo_name: Overwrite default repository name
+ :param force: Allow overwriting the folder
+ if it already exists.
+ """
+
+ if repo_name is None:
+ repo_name = f"{algo}-{env_id}"
+
+ repo_id = f"{organization}/{repo_name}"
+ print(f"Downloading from https://huggingface.co/{repo_id}")
+
+ model_name = f"{algo}-{env_id}"
+
+ checkpoint = load_from_hub(repo_id, f"{model_name}.zip")
+ config_path = load_from_hub(repo_id, "config.yml")
+
+ # If VecNormalize, download
+ try:
+ vec_normalize_stats = load_from_hub(repo_id, "vec_normalize.pkl")
+ except HTTPError:
+ print("No normalization file")
+ vec_normalize_stats = None
+
+ saved_args = load_from_hub(repo_id, "args.yml")
+ env_kwargs = load_from_hub(repo_id, "env_kwargs.yml")
+ train_eval_metrics = load_from_hub(repo_id, "train_eval_metrics.zip")
+
+ if exp_id == 0:
+ exp_id = get_latest_run_id(os.path.join(folder, algo), env_id) + 1
+ # Sanity checks
+ if exp_id > 0:
+ log_path = os.path.join(folder, algo, f"{env_id}_{exp_id}")
+ else:
+ log_path = os.path.join(folder, algo)
+
+ # Check that the folder does not exist
+ log_folder = Path(log_path)
+ if log_folder.is_dir():
+ if force:
+ print(f"The folder {log_path} already exists, overwritting")
+ # Delete the current one to avoid errors
+ shutil.rmtree(log_path)
+ else:
+ raise ValueError(
+ f"The folder {log_path} already exists, use --force to overwrite it, "
+ "or choose '--exp-id 0' to create a new folder"
+ )
+
+ print(f"Saving to {log_path}")
+ # Create folder structure
+ os.makedirs(log_path, exist_ok=True)
+ config_folder = os.path.join(log_path, env_id)
+ os.makedirs(config_folder, exist_ok=True)
+
+ # Copy config files and saved stats
+ shutil.copy(checkpoint, os.path.join(log_path, f"{env_id}.zip"))
+ shutil.copy(saved_args, os.path.join(config_folder, "args.yml"))
+ shutil.copy(config_path, os.path.join(config_folder, "config.yml"))
+ shutil.copy(env_kwargs, os.path.join(config_folder, "env_kwargs.yml"))
+ if vec_normalize_stats is not None:
+ shutil.copy(vec_normalize_stats, os.path.join(config_folder, "vecnormalize.pkl"))
+
+ # Extract monitor file and evaluation file
+ with zipfile.ZipFile(train_eval_metrics, "r") as zip_ref:
+ zip_ref.extractall(log_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--env", help="environment ID", type=str, required=True)
+ parser.add_argument("-f", "--folder", help="Log folder", type=str, required=True)
+ parser.add_argument("-orga", "--organization", help="Huggingface hub organization", default="sb3")
+ parser.add_argument("-name", "--repo-name", help="Huggingface hub repository name, by default 'algo-env_id'", type=str)
+ parser.add_argument("--algo", help="RL Algorithm", type=str, required=True, choices=list(ALGOS.keys()))
+ parser.add_argument("--exp-id", help="Experiment ID (default: 0: latest, -1: no exp folder)", default=0, type=int)
+ parser.add_argument("--verbose", help="Verbose mode (0: no output, 1: INFO)", default=1, type=int)
+ parser.add_argument(
+ "--force", action="store_true", default=False, help="Allow overwritting exp folder if it already exist"
+ )
+ args = parser.parse_args()
+
+ download_from_hub(
+ algo=args.algo,
+ env_id=args.env,
+ exp_id=args.exp_id,
+ folder=args.folder,
+ organization=args.organization,
+ repo_name=args.repo_name,
+ force=args.force,
+ )
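Example invocation, mirroring the instructions that `utils/push_to_hub.py` (below) writes into each model card (algorithm, environment and organization are illustrative):

```
# Download model and save it into the logs/ folder
python -m utils.load_from_hub --algo a2c --env CartPole-v1 -orga sb3 -f logs/
python enjoy.py --algo a2c --env CartPole-v1 -f logs/
```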
diff --git a/utils/push_to_hub.py b/utils/push_to_hub.py
new file mode 100644
index 000000000..b2a9ca878
--- /dev/null
+++ b/utils/push_to_hub.py
@@ -0,0 +1,403 @@
+import argparse
+import glob
+import os
+import shutil
+import zipfile
+from copy import deepcopy
+from pathlib import Path
+from pprint import pformat
+from typing import Any, Dict, Optional, Tuple
+
+import torch as th
+import yaml
+from huggingface_hub import HfApi, Repository
+from huggingface_hub.repocard import metadata_save
+from huggingface_sb3.push_to_hub import _evaluate_agent, _generate_replay, generate_metadata
+from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.utils import set_random_seed
+from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize
+from wasabi import Printer
+
+import utils.import_envs # noqa: F401 pylint: disable=unused-import
+from utils import ALGOS, create_test_env, get_saved_hyperparams
+from utils.exp_manager import ExperimentManager
+from utils.utils import StoreDict, get_model_path
+
+msg = Printer()
+
+
+def save_model_card(repo_dir: Path, generated_model_card: str, metadata: Dict[str, Any]) -> None:
+ """Saves a model card for the repository.
+
+ :param repo_dir: repository directory
+ :param generated_model_card: model card generated by _generate_model_card()
+ :param metadata: metadata
+ """
+ readme_path = repo_dir / "README.md"
+ # Always overwrite README
+ with readme_path.open("w", encoding="utf-8") as f:
+ f.write(generated_model_card)
+
+ # Save our metrics to Readme metadata
+ metadata_save(readme_path, metadata)
+
+
+def generate_model_card(
+ algo_name: str,
+ algo_class_name: str,
+ organization: str,
+ env_id: str,
+ mean_reward: float,
+ std_reward: float,
+ hyperparams: Dict[str, Any],
+ env_kwargs: Dict[str, Any],
+) -> Tuple[str, Dict[str, Any]]:
+ """
+ Generate the model card for the Hub
+
+ :param algo_name: alias of the algorithm used in the zoo
+ :param algo_class_name: name of the algorithm class
+ :param organization: Hugging Face organization hosting the model
+ :param env_id: name of the environment
+ :param mean_reward: mean reward of the agent
+ :param std_reward: standard deviation of the mean reward of the agent
+ :param hyperparams: hyperparameters used for training
+ :param env_kwargs: additional keyword arguments passed to the environment
+ :return: Model card (readme) and metadata (performance, algo/env id, tags)
+ """
+ # Step 1: Select the tags
+ metadata = generate_metadata(algo_class_name, env_id, mean_reward, std_reward)
+
+ # Step 2: Generate the model card
+ model_card = f"""
+# **{algo_class_name}** Agent playing **{env_id}**
+This is a trained model of a **{algo_class_name}** agent playing **{env_id}**
+using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3)
+and the [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo).
+
+The RL Zoo is a training framework for Stable Baselines3
+reinforcement learning agents,
+with hyperparameter optimization and pre-trained agents included.
+"""
+
+ model_card += f"""
+## Usage (with SB3 RL Zoo)
+
+RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo
+SB3: https://github.com/DLR-RM/stable-baselines3
+SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
+
+```
+# Download model and save it into the logs/ folder
+python -m utils.load_from_hub --algo {algo_name} --env {env_id} -orga {organization} -f logs/
+python enjoy.py --algo {algo_name} --env {env_id} -f logs/
+```
+
+## Training (with the RL Zoo)
+```
+python train.py --algo {algo_name} --env {env_id} -f logs/
+# Upload the model and generate video (when possible)
+python -m utils.push_to_hub --algo {algo_name} --env {env_id} -f logs/ -orga {organization}
+```
+
+## Hyperparameters
+```python
+{pformat(hyperparams)}
+```
+"""
+ if len(env_kwargs) > 0:
+ model_card += f"""
+# Environment Arguments
+```python
+{pformat(env_kwargs)}
+```
+"""
+
+ return model_card, metadata
+
+
+def package_to_hub(
+ model: BaseAlgorithm,
+ model_name: str,
+ algo_name: str,
+ algo_class_name: str,
+ log_path: Path,
+ hyperparams: Dict[str, Any],
+ env_kwargs: Dict[str, Any],
+ env_id: str,
+ eval_env: VecEnv,
+ repo_id: str,
+ commit_message: str,
+ is_deterministic: bool = True,
+ n_eval_episodes=10,
+ token: Optional[str] = None,
+ local_repo_path="hub",
+ video_length=1000,
+ generate_video: bool = False,
+):
+ """
+ Evaluate a model, generate a replay video and upload everything to the Hugging Face Hub.
+ This method does the complete pipeline:
+ - It evaluates the model
+ - It generates the model card
+ - It generates a replay video of the agent
+ - It pushes everything to the hub
+
+ This function is a work in progress: if it does not work,
+ use the `push_to_hub` method instead.
+
+ :param model: trained model
+ :param model_name: name of the model zip file
+ :param algo_name: alias used in the zoo for the algorithm,
+ usually lower case of the class (a2c, ars, ppo, ppo_lstm)
+ :param algo_class_name: name of the algorithm class
+ (DQN, PPO, A2C, SAC, RecurrentPPO, ...)
+ :param log_path: Path to where the model is saved in the zoo.
+ :param hyperparams: Hyperparameters used for training,
+ includes wrappers.
+ :param env_kwargs: Additional keyword arguments that were passed
+ to the environment.
+ :param env_id: name of the environment
+ :param eval_env: environment used to evaluate the agent
+ :param repo_id: id of the model repository from the Hugging Face Hub
+ :param commit_message: commit message
+ :param is_deterministic: use deterministic or stochastic actions (by default: True)
+ :param n_eval_episodes: number of evaluation episodes (by default: 10)
+ :param token: authentication token for the Hugging Face Hub
+ :param local_repo_path: local repository path
+ :param video_length: length of the video (in timesteps)
+ :param generate_video: whether to record and upload a replay video
+ """
+
+ msg.info(
+ "This function will save, evaluate, generate a video of your agent, "
+ "create a model card and push everything to the hub. "
+ "It might take up to some minutes if video generation is activated. "
+ "This is a work in progress: if you encounter a bug, please open an issue."
+ )
+
+ organization, repo_name = repo_id.split("/")
+
+ # Step 1: Clone or create the repo
+ # Create the repo (or clone its content if it's nonempty)
+ api = HfApi()
+
+ repo_url = api.create_repo(
+ name=repo_name,
+ token=token,
+ organization=organization,
+ private=False,
+ exist_ok=True,
+ )
+
+ # Git pull
+ repo_local_path = Path(local_repo_path) / repo_name
+ repo = Repository(repo_local_path, clone_from=repo_url, use_auth_token=True)
+ repo.git_pull(rebase=True)
+
+ repo.lfs_track(["*.mp4"])
+
+ # Step 2: Save the model
+ model.save(repo_local_path / model_name)
+
+ # Retrieve VecNormalize wrapper if it exists
+ # we need to save the statistics
+ maybe_vec_normalize = unwrap_vec_normalize(eval_env)
+
+ # Save the normalization
+ if maybe_vec_normalize is not None:
+ maybe_vec_normalize.save(repo_local_path / "vec_normalize.pkl")
+ # Do not update the stats at test time
+ maybe_vec_normalize.training = False
+ # Reward normalization is not needed at test time
+ maybe_vec_normalize.norm_reward = False
+
+ # Unzip the model
+ with zipfile.ZipFile(repo_local_path / f"{model_name}.zip", "r") as zip_ref:
+ zip_ref.extractall(repo_local_path / model_name)
+
+ # Step 3: Copy config files
+ args_path = log_path / env_id / "args.yml"
+ config_path = log_path / env_id / "config.yml"
+
+ shutil.copy(args_path, repo_local_path / "args.yml")
+ shutil.copy(config_path, repo_local_path / "config.yml")
+ with open(repo_local_path / "env_kwargs.yml", "w") as outfile:
+ yaml.dump(env_kwargs, outfile)
+
+ # Copy train/eval metrics into zip
+ with zipfile.ZipFile(repo_local_path / "train_eval_metrics.zip", "w") as archive:
+ if os.path.isfile(log_path / "evaluations.npz"):
+ archive.write(log_path / "evaluations.npz", arcname="evaluations.npz")
+ for monitor_file in glob.glob(f"{log_path}/*.csv"):
+ archive.write(monitor_file, arcname=monitor_file.split(os.sep)[-1])
+
+ # Step 4: Evaluate the agent
+ mean_reward, std_reward = _evaluate_agent(model, eval_env, n_eval_episodes, is_deterministic, repo_local_path)
+
+ # Step 5: Generate a video
+ if generate_video:
+ _generate_replay(model, eval_env, video_length, is_deterministic, repo_local_path)
+ # Cleanup files after generation
+ # TODO: upstream to huggingface sb3
+ video_path = Path("test.mp4")
+ if video_path.is_file():
+ video_path.unlink()
+ json_path = list(glob.glob("*.meta.json"))
+ if len(json_path) > 0:
+ Path(json_path[0]).unlink()
+
+ # Step 6: Generate the model card
+ generated_model_card, metadata = generate_model_card(
+ algo_name,
+ algo_class_name,
+ organization,
+ env_id,
+ mean_reward,
+ std_reward,
+ hyperparams,
+ env_kwargs,
+ )
+
+ save_model_card(repo_local_path, generated_model_card, metadata)
+
+ msg.info(f"Pushing repo {repo_name} to the Hugging Face Hub")
+ repo.push_to_hub(commit_message=commit_message)
+
+ msg.info(f"Your model is pushed to the hub. You can view your model here: {repo_url}")
+ return repo_url
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--env", help="environment ID", type=str, required=True)
+ parser.add_argument("-f", "--folder", help="Log folder", type=str, required=True)
+ parser.add_argument("--algo", help="RL Algorithm", type=str, required=True, choices=list(ALGOS.keys()))
+ parser.add_argument("-n", "--n-timesteps", help="number of timesteps", default=1000, type=int)
+ parser.add_argument("--num-threads", help="Number of threads for PyTorch (-1 to use default)", default=-1, type=int)
+ parser.add_argument("--n-envs", help="number of environments", default=1, type=int)
+ parser.add_argument("--exp-id", help="Experiment ID (default: 0: latest, -1: no exp folder)", default=0, type=int)
+ parser.add_argument("--verbose", help="Verbose mode (0: no output, 1: INFO)", default=1, type=int)
+ parser.add_argument(
+ "--no-render", action="store_true", default=False, help="Do not render the environment (useful for tests)"
+ )
+ parser.add_argument("--deterministic", action="store_true", default=False, help="Use deterministic actions")
+ parser.add_argument("--device", help="PyTorch device to be use (ex: cpu, cuda...)", default="auto", type=str)
+ parser.add_argument(
+ "--load-best", action="store_true", default=False, help="Load best model instead of last model if available"
+ )
+ parser.add_argument(
+ "--load-checkpoint",
+ type=int,
+ help="Load checkpoint instead of last model if available, "
+ "you must pass the number of timesteps corresponding to it",
+ )
+ parser.add_argument(
+ "--load-last-checkpoint",
+ action="store_true",
+ default=False,
+ help="Load last checkpoint instead of last model if available",
+ )
+ parser.add_argument("--stochastic", action="store_true", default=False, help="Use stochastic actions")
+ parser.add_argument("--seed", help="Random generator seed", type=int, default=0)
+ parser.add_argument(
+ "--env-kwargs", type=str, nargs="+", action=StoreDict, help="Optional keyword argument to pass to the env constructor"
+ )
+ parser.add_argument("-orga", "--organization", help="Huggingface hub organization", type=str, required=True)
+ parser.add_argument("-name", "--repo-name", help="Huggingface hub repository name, by default 'algo-env_id'", type=str)
+ parser.add_argument("-m", "--commit-message", help="Commit message", default="Initial commit", type=str)
+
+ args = parser.parse_args()
+ env_id = args.env
+ algo = args.algo
+
+ _, model_path, log_path = get_model_path(
+ args.exp_id,
+ args.folder,
+ args.algo,
+ args.env,
+ args.load_best,
+ args.load_checkpoint,
+ args.load_last_checkpoint,
+ )
+
+ print(f"Loading {model_path}")
+
+ # Off-policy algorithms only support one env for now
+ off_policy_algos = ["qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"]
+
+ if algo in off_policy_algos:
+ args.n_envs = 1
+
+ set_random_seed(args.seed)
+
+ if args.num_threads > 0:
+ if args.verbose > 1:
+ print(f"Setting torch.num_threads to {args.num_threads}")
+ th.set_num_threads(args.num_threads)
+
+ is_atari = ExperimentManager.is_atari(env_id)
+
+ stats_path = os.path.join(log_path, env_id)
+ hyperparams, stats_path = get_saved_hyperparams(stats_path, test_mode=True)
+
+ # load env_kwargs if existing
+ env_kwargs = {}
+ args_path = os.path.join(log_path, env_id, "args.yml")
+ if os.path.isfile(args_path):
+ with open(args_path) as f:
+ loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr
+ if loaded_args["env_kwargs"] is not None:
+ env_kwargs = loaded_args["env_kwargs"]
+ # overwrite with command line arguments
+ if args.env_kwargs is not None:
+ env_kwargs.update(args.env_kwargs)
+
+ eval_env = create_test_env(
+ env_id,
+ n_envs=args.n_envs,
+ stats_path=stats_path,
+ seed=args.seed,
+ log_dir=None,
+ should_render=not args.no_render,
+ hyperparams=deepcopy(hyperparams),
+ env_kwargs=env_kwargs,
+ )
+
+ kwargs = dict(seed=args.seed)
+ if algo in off_policy_algos:
+ # Dummy buffer size as we don't need memory to enjoy the trained agent
+ kwargs.update(dict(buffer_size=1))
+
+ # Note: we assume that we push models using the same machine (same python version)
+ # that trained them; if not, we would need to pass custom objects as in enjoy.py
+ custom_objects = {}
+ model = ALGOS[algo].load(model_path, env=eval_env, custom_objects=custom_objects, device=args.device, **kwargs)
+
+ # Deterministic by default except for atari games
+ stochastic = args.stochastic or is_atari and not args.deterministic
+ deterministic = not stochastic
+
+ # Default model name, the model will be saved under "{algo}-{env_id}.zip"
+ model_name = f"{algo}-{env_id}"
+
+ if args.repo_name is None:
+ args.repo_name = model_name
+
+ repo_id = f"{args.organization}/{args.repo_name}"
+ print(f"Uploading to {repo_id}, make sure to have the rights")
+
+ package_to_hub(
+ model,
+ model_name,
+ algo,
+ ALGOS[algo].__name__,
+ Path(log_path),
+ hyperparams,
+ env_kwargs,
+ env_id,
+ eval_env,
+ repo_id=repo_id,
+ commit_message=args.commit_message,
+ is_deterministic=deterministic,
+ n_eval_episodes=10,
+ token=None,
+ local_repo_path="hub",
+ video_length=1000,
+ generate_video=not args.no_render,
+ )
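Example invocation, with the same placeholder ids as above (requires being logged in to the Hugging Face Hub with write access to the target organization):

```
python -m utils.push_to_hub --algo a2c --env CartPole-v1 -f logs/ -orga sb3 -m "Initial commit"
```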
diff --git a/utils/record_video.py b/utils/record_video.py
index dd89c0220..0bb66b114 100644
--- a/utils/record_video.py
+++ b/utils/record_video.py
@@ -2,12 +2,13 @@
import os
import sys
+import numpy as np
import yaml
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import VecVideoRecorder
from utils.exp_manager import ExperimentManager
-from utils.utils import ALGOS, StoreDict, create_test_env, get_latest_run_id, get_saved_hyperparams
+from utils.utils import ALGOS, StoreDict, create_test_env, get_model_path, get_saved_hyperparams
if __name__ == "__main__": # noqa: C901
parser = argparse.ArgumentParser()
@@ -32,6 +33,12 @@
type=int,
help="Load checkpoint instead of last model if available, you must pass the number of timesteps corresponding to it",
)
+ parser.add_argument(
+ "--load-last-checkpoint",
+ action="store_true",
+ default=False,
+ help="Load last checkpoint instead of last model if available",
+ )
parser.add_argument(
"--env-kwargs", type=str, nargs="+", action=StoreDict, help="Optional keyword argument to pass to the env constructor"
)
@@ -44,35 +51,18 @@
seed = args.seed
video_length = args.n_timesteps
n_envs = args.n_envs
- load_best = args.load_best
- load_checkpoint = args.load_checkpoint
-
- if args.exp_id == 0:
- args.exp_id = get_latest_run_id(os.path.join(folder, algo), env_id)
- print(f"Loading latest experiment, id={args.exp_id}")
- # Sanity checks
- if args.exp_id > 0:
- log_path = os.path.join(folder, algo, f"{env_id}_{args.exp_id}")
- else:
- log_path = os.path.join(folder, algo)
-
- assert os.path.isdir(log_path), f"The {log_path} folder was not found"
-
- if load_best:
- model_path = os.path.join(log_path, "best_model.zip")
- name_prefix = f"best-model-{algo}-{env_id}"
- elif load_checkpoint is None:
- # Default: load latest model
- model_path = os.path.join(log_path, f"{env_id}.zip")
- name_prefix = f"final-model-{algo}-{env_id}"
- else:
- model_path = os.path.join(log_path, f"rl_model_{args.load_checkpoint}_steps.zip")
- name_prefix = f"checkpoint-{args.load_checkpoint}-{algo}-{env_id}"
-
- found = os.path.isfile(model_path)
- if not found:
- raise ValueError(f"No model found for {algo} on {env_id}, path: {model_path}")
+ name_prefix, model_path, log_path = get_model_path(
+ args.exp_id,
+ folder,
+ algo,
+ env_id,
+ args.load_best,
+ args.load_checkpoint,
+ args.load_last_checkpoint,
+ )
+
+ print(f"Loading {model_path}")
off_policy_algos = ["qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"]
set_random_seed(args.seed)
@@ -86,7 +76,7 @@
env_kwargs = {}
args_path = os.path.join(log_path, env_id, "args.yml")
if os.path.isfile(args_path):
- with open(args_path, "r") as f:
+ with open(args_path) as f:
loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr
if loaded_args["env_kwargs"] is not None:
env_kwargs = loaded_args["env_kwargs"]
@@ -126,8 +116,6 @@
model = ALGOS[algo].load(model_path, env=env, custom_objects=custom_objects, **kwargs)
- obs = env.reset()
-
# Deterministic by default except for atari games
stochastic = args.stochastic or is_atari and not args.deterministic
deterministic = not stochastic
@@ -144,11 +132,19 @@
name_prefix=name_prefix,
)
- env.reset()
+ obs = env.reset()
+ lstm_states = None
+ episode_starts = np.ones((env.num_envs,), dtype=bool)
try:
for _ in range(video_length + 1):
- action, _ = model.predict(obs, deterministic=deterministic)
- obs, _, _, _ = env.step(action)
+ action, lstm_states = model.predict(
+ obs,
+ state=lstm_states,
+ episode_start=episode_starts,
+ deterministic=deterministic,
+ )
+ obs, _, dones, _ = env.step(action)
+ episode_starts = dones
if not args.no_render:
env.render()
except KeyboardInterrupt:
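The rewritten loop is the standard recurrent-policy pattern: keep the returned `lstm_states` between calls and pass `episode_start` flags so the hidden states are reset at episode boundaries; non-recurrent policies simply ignore both arguments. A standalone sketch with placeholder model and env (a tiny `RecurrentPPO` run instead of a zoo checkpoint):

```python
import gym
import numpy as np
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.vec_env import DummyVecEnv

# Placeholder model: normally this would be loaded from a checkpoint
env = DummyVecEnv([lambda: gym.make("Pendulum-v1")])
model = RecurrentPPO("MlpLstmPolicy", env).learn(128)

obs = env.reset()
lstm_states = None
episode_starts = np.ones((env.num_envs,), dtype=bool)
for _ in range(100):
    action, lstm_states = model.predict(
        obs, state=lstm_states, episode_start=episode_starts, deterministic=True
    )
    obs, _, dones, _ = env.step(action)
    episode_starts = dones  # reset the hidden states where an episode ended
```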
diff --git a/utils/teleop.py b/utils/teleop.py
new file mode 100644
index 000000000..fa20356a1
--- /dev/null
+++ b/utils/teleop.py
@@ -0,0 +1,380 @@
+import os
+import time
+from typing import List, Optional, Tuple
+
+import numpy as np
+import pygame
+from pygame.locals import * # noqa: F403
+from sb3_contrib import TQC
+from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.buffers import ReplayBuffer
+from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
+
+# TELEOP_RATE = 1 / 60
+
+UP = (1, 0)
+LEFT = (0, 1)
+RIGHT = (0, -1)
+DOWN = (-1, 0)
+STOP = (0, 0)
+KEY_CODE_SPACE = 32
+
+MAX_TURN = 1
+# Smoothing constants
+STEP_THROTTLE = 0.8
+STEP_TURN = 0.8
+
+GREEN = (72, 205, 40)
+RED = (205, 39, 46)
+GREY = (187, 179, 179)
+BLACK = (36, 36, 36)
+WHITE = (230, 230, 230)
+ORANGE = (200, 110, 0)
+
+# pytype: disable=name-error
+moveBindingsGame = {K_UP: UP, K_LEFT: LEFT, K_RIGHT: RIGHT, K_DOWN: DOWN}
+# pytype: enable=name-error
+pygame.font.init()
+FONT = pygame.font.SysFont("Open Sans", 25)
+SMALL_FONT = pygame.font.SysFont("Open Sans", 20)
+KEY_MIN_DELAY = 0.4
+
+
+def control(x, theta, control_throttle, control_steering):
+ """
+ Smooth control.
+
+ :param x: (float)
+ :param theta: (float)
+ :param control_throttle: (float)
+ :param control_steering: (float)
+ :return: (float, float)
+ """
+ target_throttle = x
+ target_steering = MAX_TURN * theta
+ if target_throttle > control_throttle:
+ control_throttle = min(target_throttle, control_throttle + STEP_THROTTLE)
+ elif target_throttle < control_throttle:
+ control_throttle = max(target_throttle, control_throttle - STEP_THROTTLE)
+ else:
+ control_throttle = target_throttle
+
+ if target_steering > control_steering:
+ control_steering = min(target_steering, control_steering + STEP_TURN)
+ elif target_steering < control_steering:
+ control_steering = max(target_steering, control_steering - STEP_TURN)
+ else:
+ control_steering = target_steering
+ return control_throttle, control_steering
+
+
+class HumanTeleop(BaseAlgorithm):
+ def __init__(
+ self,
+ policy,
+ env,
+ buffer_size=50000,
+ tensorboard_log=None,
+ verbose=0,
+ seed=None,
+ device=None,
+ _init_setup_model=False,
+ scale_human=1.0,
+ scale_model=0.0,
+ model_path=os.environ.get("MODEL_PATH"),
+ deterministic=True,
+ ):
+ super().__init__(
+ policy=None, env=env, policy_base=None, learning_rate=0.0, verbose=verbose, seed=seed
+ )
+
+ # pytype: disable=name-error
+ # self.button_switch_mode = K_m
+ # self.button_toggle_train_mode = K_t
+ # pytype: enable=name-error
+
+ # TODO: add to model buffer + allow training in separate thread
+
+ # Used to prevent from multiple successive key press
+ self.last_time_pressed = {}
+ self.event_buttons = None
+ self.action = np.zeros((2,))
+ self.exit_thread = False
+ self.process = None
+ self.window = None
+ self.buffer_size = buffer_size
+ self.replay_buffer = ReplayBuffer(
+ buffer_size, self.observation_space, self.action_space, self.device, optimize_memory_usage=False
+ )
+ self.scale_human = scale_human
+ self.scale_model = scale_model
+ # self.start_process()
+ self.model = None
+ self.deterministic = deterministic
+ # Pretrained model
+ if model_path is not None:
+ # "logs/tqc/donkey-generated-track-v0_9/donkey-generated-track-v0.zip"
+ self.model = TQC.load(model_path)
+
+ def _excluded_save_params(self) -> List[str]:
+ """
+ Returns the names of the parameters that should be excluded by default
+ when saving the model.
+
+ :return: (List[str]) List of parameters that should be excluded from save
+ """
+ # Exclude aliases
+ return super()._excluded_save_params() + ["process", "window", "model", "exit_thread"]
+
+ def _setup_model(self):
+ self.exit_thread = False
+
+ def init_buttons(self):
+ """
+ Initialize the last_time_pressed timers that prevent
+ successive key press.
+ """
+ self.event_buttons = [
+ # self.button_switch_mode,
+ # self.button_toggle_train_mode,
+ ]
+ for key in self.event_buttons:
+ self.last_time_pressed[key] = 0
+
+ # def start_process(self):
+ # """Start main loop process."""
+ # # Reset last time pressed
+ # self.init_buttons()
+ # self.process = Thread(target=self.main_loop)
+ # # Make it a daemon, so it will be deleted at the same time
+ # # as the main process
+ # self.process.daemon = True
+ # self.process.start()
+
+ def check_key(self, keys, key):
+ """
+ Check if a key was pressed and update associated timer.
+
+ :param keys: (dict)
+ :param key: (any hashable type)
+ :return: (bool) Returns true when a given key was pressed, False otherwise
+ """
+ if key is None:
+ return False
+ if keys[key] and (time.time() - self.last_time_pressed[key]) > KEY_MIN_DELAY:
+ # avoid multiple key press
+ self.last_time_pressed[key] = time.time()
+ return True
+ return False
+
+ def handle_keys_event(self, keys):
+ """
+ Handle the events induced by key press:
+ e.g. change of mode, toggling recording, ...
+ """
+
+ # Switch from "MANUAL" to "AUTONOMOUS" mode
+ # if self.check_key(keys, self.button_switch_mode) or self.check_key(keys, self.button_pause):
+ # self.is_manual = not self.is_manual
+
+ def _sample_action(self) -> Tuple[np.ndarray, np.ndarray]:
+ unscaled_action, _ = self.model.predict(self._last_obs, deterministic=self.deterministic)
+
+ # Rescale the action from [low, high] to [-1, 1]
+ scaled_action = self.model.policy.scale_action(unscaled_action)
+ # We store the scaled action in the buffer
+ buffer_action = scaled_action
+ action = self.model.policy.unscale_action(scaled_action)
+ return action, buffer_action
+
+ def main_loop(self, total_timesteps=-1):
+ """
+ Pygame loop that listens to keyboard events.
+ """
+ pygame.init()
+ # Create a pygame window
+ self.window = pygame.display.set_mode((800, 500), RESIZABLE) # pytype: disable=name-error
+
+ # Init values and fill the screen
+ control_throttle, control_steering = 0, 0
+ action = [control_steering, control_throttle]
+
+ self.update_screen(action)
+
+ n_steps = 0
+ buffer_action = np.array([[0.0, 0.0]])
+
+ while not self.exit_thread:
+ x, theta = 0, 0
+ # Record pressed keys
+ keys = pygame.key.get_pressed()
+ for keycode in moveBindingsGame.keys():
+ if keys[keycode]:
+ x_tmp, th_tmp = moveBindingsGame[keycode]
+ x += x_tmp
+ theta += th_tmp
+
+ self.handle_keys_event(keys)
+
+ # Smooth control for teleoperation
+ control_throttle, control_steering = control(x, theta, control_throttle, control_steering)
+ scaled_action = np.array([[-control_steering, control_throttle]]).astype(np.float32)
+
+ # Use trained RL model action
+ if self.model is not None:
+ _, buffer_action_model = self._sample_action()
+ else:
+ buffer_action_model = 0.0
+ scaled_action = self.scale_human * scaled_action + self.scale_model * buffer_action_model
+ scaled_action = np.clip(scaled_action, -1.0, 1.0)
+
+ # We store the scaled action in the buffer
+ buffer_action = scaled_action
+
+ # Normally no need to unscale, as the bounds are [-1, 1]
+ if self.model is not None:
+ action = self.model.policy.unscale_action(scaled_action)
+ else:
+ action = scaled_action.copy()
+
+ self.action = action.copy()[0]
+
+ if self.model is not None:
+ if (
+ self.model.use_sde
+ and self.model.sde_sample_freq > 0
+ and n_steps % self.model.sde_sample_freq == 0
+ and not self.deterministic
+ ):
+ # Sample a new noise matrix
+ self.model.actor.reset_noise()
+
+ new_obs, reward, done, infos = self.env.step(action)
+
+ next_obs = new_obs
+ if done and infos[0].get("terminal_observation") is not None:
+ next_obs = infos[0]["terminal_observation"]
+
+ self.replay_buffer.add(self._last_obs, next_obs, buffer_action, reward, done, infos)
+
+ self._last_obs = new_obs
+
+ self.update_screen(self.action)
+
+ n_steps += 1
+ if total_timesteps > 0:
+ self.exit_thread = n_steps >= total_timesteps
+
+ if done:
+ print(f"{n_steps} steps")
+ # reset values
+ control_throttle, control_steering = 0, 0
+
+ for event in pygame.event.get():
+ # A QUIT event has no key attribute, so check the event type first
+ if event.type == QUIT or (  # pytype: disable=name-error
+ event.type == KEYDOWN and event.key in [K_ESCAPE, K_q]  # pytype: disable=name-error
+ ):
+ self.exit_thread = True
+ pygame.display.flip()
+ # Limit FPS
+ # pygame.time.Clock().tick(1 / TELEOP_RATE)
+
+ def write_text(self, text, x, y, font, color=GREY):
+ """
+ :param text: (str)
+ :param x: (int)
+ :param y: (int)
+ :param font: (str)
+ :param color: (tuple)
+ """
+ text = str(text)
+ text = font.render(text, True, color)
+ self.window.blit(text, (x, y))
+
+ def clear(self):
+ self.window.fill((0, 0, 0))
+
+ def update_screen(self, action):
+ """
+ Update pygame window.
+
+ :param action: ([float])
+ """
+ self.clear()
+ steering, throttle = action
+ self.write_text("Throttle: {:.2f}, Steering: {:.2f}".format(throttle, steering), 20, 0, FONT, WHITE)
+
+ def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
+ """
+ Get the name of the torch variables that will be saved.
+ ``th.save`` and ``th.load`` will be used with the right device
+ instead of the default pickling strategy.
+
+ :return: (Tuple[List[str], List[str]])
+ names of the variables with state dicts to save, names of additional torch tensors
+ """
+ return [], []
+
+ def learn(
+ self,
+ total_timesteps,
+ callback=None,
+ log_interval=100,
+ tb_log_name="run",
+ eval_env=None,
+ eval_freq=-1,
+ n_eval_episodes=5,
+ eval_log_path=None,
+ reset_num_timesteps=True,
+ ) -> "HumanTeleop":
+ self._last_obs = self.env.reset()
+ # Wait for teleop process
+ # time.sleep(3)
+ self.main_loop(total_timesteps)
+
+ # with threading:
+ # for _ in range(total_timesteps):
+ # print(np.array([self.action]))
+ # self.env.step(np.array([self.action]))
+ # if self.exit_thread:
+ # break
+ return self
+
+ def predict(
+ self,
+ observation: np.ndarray,
+ state: Optional[np.ndarray] = None,
+ mask: Optional[np.ndarray] = None,
+ deterministic: bool = False,
+ ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+ """
+ Get the model's action(s) from an observation
+
+ :param observation: (np.ndarray) the input observation
+ :param state: (Optional[np.ndarray]) The last states (can be None, used in recurrent policies)
+ :param mask: (Optional[np.ndarray]) The last masks (can be None, used in recurrent policies)
+ :param deterministic: (bool) Whether or not to return deterministic actions.
+ :return: (Tuple[np.ndarray, Optional[np.ndarray]]) the model's action and the next state
+ (used in recurrent policies)
+ """
+ return self.action, None
+
+ def save_replay_buffer(self, path) -> None:
+ """
+ Save the replay buffer as a pickle file.
+
+ :param path: (Union[str,pathlib.Path, io.BufferedIOBase]) Path to the file where the replay buffer should be saved.
+ if path is a str or pathlib.Path, the path is automatically created if necessary.
+ """
+ assert self.replay_buffer is not None, "The replay buffer is not defined"
+ save_to_pkl(path, self.replay_buffer, self.verbose)
+
+ def load_replay_buffer(self, path, truncate_last_traj=False) -> None:
+ """
+ Load a replay buffer from a pickle file.
+
+ :param path: (Union[str, pathlib.Path, io.BufferedIOBase]) Path to the pickled replay buffer.
+ """
+ self.replay_buffer = load_from_pkl(path, self.verbose)
+ assert isinstance(self.replay_buffer, ReplayBuffer), "The replay buffer must inherit from ReplayBuffer class"
diff --git a/utils/utils.py b/utils/utils.py
index 6072cc7cd..a53d2cc3e 100644
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -2,19 +2,32 @@
import glob
import importlib
import os
+import re
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import gym
+import numpy as np
import stable_baselines3 as sb3 # noqa: F401
import torch as th # noqa: F401
import yaml
-from sb3_contrib import ARS, QRDQN, TQC, TRPO
+from huggingface_hub import HfApi
+from sb3_contrib import ARS, QRDQN, TQC, TRPO, RecurrentPPO
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike # noqa: F401
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv, VecFrameStack, VecNormalize
+try:
+ from utils.teleop import HumanTeleop
+except ImportError:
+ HumanTeleop = None
+
+try:
+ from rl_racing.utils.teleop import HumanTeleop as Teleop
+except ImportError:
+ Teleop = None
+
# For custom activation fn
from torch import nn as nn # noqa: F401 pylint: disable=unused-import
@@ -25,11 +38,14 @@
"ppo": PPO,
"sac": SAC,
"td3": TD3,
+ "human": HumanTeleop,
+ "teleop": Teleop,
# SB3 Contrib,
"ars": ARS,
"qrdqn": QRDQN,
"tqc": TQC,
"trpo": TRPO,
+ "ppo_lstm": RecurrentPPO,
}
@@ -42,10 +58,12 @@ def flatten_dict_observations(env: gym.Env) -> gym.Env:
return gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))
-def get_wrapper_class(hyperparams: Dict[str, Any]) -> Optional[Callable[[gym.Env], gym.Env]]:
+def get_wrapper_class(hyperparams: Dict[str, Any], key: str = "env_wrapper") -> Optional[Callable[[gym.Env], gym.Env]]:
"""
Get one or more Gym environment wrapper class specified as a hyper parameter
"env_wrapper".
+ Works also for VecEnvWrapper with the key "vec_env_wrapper".
+
e.g.
env_wrapper: gym_minigrid.wrappers.FlatObsWrapper
@@ -67,8 +85,8 @@ def get_module_name(wrapper_name):
def get_class_name(wrapper_name):
return wrapper_name.split(".")[-1]
- if "env_wrapper" in hyperparams.keys():
- wrapper_name = hyperparams.get("env_wrapper")
+ if key in hyperparams.keys():
+ wrapper_name = hyperparams.get(key)
if wrapper_name is None:
return None
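With the new `key` argument, the same helper resolves both kinds of wrappers; a short sketch using shapes that already appear in this diff (`HistoryWrapper` from `utils.wrappers`, `VecFrameStack` from SB3):

```python
from utils.utils import get_wrapper_class

hyperparams = {
    "env_wrapper": "utils.wrappers.HistoryWrapper",
    "vec_env_wrapper": {"stable_baselines3.common.vec_env.VecFrameStack": dict(n_stack=2)},
}
env_wrapper = get_wrapper_class(hyperparams)  # default key: "env_wrapper"
vec_env_wrapper = get_wrapper_class(hyperparams, "vec_env_wrapper")
```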
@@ -204,6 +222,11 @@ def create_test_env(
if "env_wrapper" in hyperparams.keys():
del hyperparams["env_wrapper"]
+ # Ignore for now
+ # TODO: handle it properly
+ if "vec_env_wrapper" in hyperparams.keys():
+ del hyperparams["vec_env_wrapper"]
+
vec_env_kwargs = {}
vec_env_cls = DummyVecEnv
if n_envs > 1 or (ExperimentManager.is_bullet(env_id) and should_render):
@@ -223,6 +246,12 @@ def create_test_env(
vec_env_kwargs=vec_env_kwargs,
)
+ if "vec_env_wrapper" in hyperparams.keys():
+ vec_env_wrapper = get_wrapper_class(hyperparams, "vec_env_wrapper")
+ env = vec_env_wrapper(env)
+ del hyperparams["vec_env_wrapper"]
+
# Load saved stats for normalizing input and rewards
# And optionally stack frames
if stats_path is not None:
@@ -282,6 +311,29 @@ def get_trained_models(log_folder: str) -> Dict[str, Tuple[str, str]]:
return trained_models
+def get_hf_trained_models(organization: str = "sb3") -> Dict[str, Tuple[str, str]]:
+ """
+ Get the pretrained models
+ available on the Hugging Face hub for a given organization.
+
+ :param organization:
+ :return: Dict representing the trained agents
+ """
+ api = HfApi()
+ models = api.list_models(author=organization)
+ regex = re.compile(r"^(?P<algo>[a-z_0-9]+)-(?P<env_id>[a-zA-Z0-9]+-v[0-9]+)$")
+ trained_models = {}
+ for model in models:
+ # Remove organization
+ repo_id = model.modelId.split(f"{organization}/")[1]
+ result = regex.match(repo_id)
+ # Skip demo repo that does not fit the pattern
+ if result is not None:
+ algo, env_id = result.group("algo"), result.group("env_id")
+ trained_models[f"{algo}-{env_id}"] = (algo, env_id)
+ return trained_models
+
+
def get_latest_run_id(log_path: str, env_id: str) -> int:
"""
Returns the latest run number for the given log name and log path,
@@ -318,7 +370,7 @@ def get_saved_hyperparams(
config_file = os.path.join(stats_path, "config.yml")
if os.path.isfile(config_file):
# Load saved hyperparameters
- with open(os.path.join(stats_path, "config.yml"), "r") as f:
+ with open(os.path.join(stats_path, "config.yml")) as f:
hyperparams = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr
hyperparams["normalize"] = hyperparams.get("normalize", False)
else:
@@ -347,7 +399,7 @@ class StoreDict(argparse.Action):
def __init__(self, option_strings, dest, nargs=None, **kwargs):
self._nargs = nargs
- super(StoreDict, self).__init__(option_strings, dest, nargs=nargs, **kwargs)
+ super().__init__(option_strings, dest, nargs=nargs, **kwargs)
def __call__(self, parser, namespace, values, option_string=None):
arg_dict = {}
@@ -357,3 +409,120 @@ def __call__(self, parser, namespace, values, option_string=None):
# Evaluate the string as python code
arg_dict[key] = eval(value)
setattr(namespace, self.dest, arg_dict)
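For context, a small usage sketch of ``StoreDict`` (argument names are illustrative); each ``key:value`` string is split on the colon and the value is evaluated as Python code:

```python
import argparse

from utils.utils import StoreDict

parser = argparse.ArgumentParser()
parser.add_argument("--env-kwargs", type=str, nargs="+", action=StoreDict)
args = parser.parse_args(["--env-kwargs", "n_envs:2", "frame_skip:4"])
print(args.env_kwargs)  # {'n_envs': 2, 'frame_skip': 4}
```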
+
+
+def evaluate_policy_add_to_buffer(
+ model,
+ env,
+ n_eval_episodes=10,
+ deterministic=True,
+ render=False,
+ callback=None,
+ reward_threshold=None,
+ return_episode_rewards=False,
+ replay_buffer=None,
+):
+ """
+ Runs policy for ``n_eval_episodes`` episodes and returns average reward.
+ This is made to work only with one env.
+ :param model: (BaseAlgorithm) The RL agent you want to evaluate.
+ :param env: (gym.Env or VecEnv) The gym environment. In the case of a ``VecEnv``
+ this must contain only one environment.
+ :param n_eval_episodes: (int) Number of episodes to evaluate the agent
+ :param deterministic: (bool) Whether to use deterministic or stochastic actions
+ :param render: (bool) Whether to render the environment or not
+ :param callback: (callable) callback function to do additional checks,
+ called after each step.
+ :param reward_threshold: (float) Minimum expected reward per episode,
+ this will raise an error if the performance is not met
+ :param return_episode_rewards: (bool) If True, a list of rewards per episode
+ will be returned instead of the mean.
+ :param replay_buffer: (ReplayBuffer) Optional replay buffer that will be
+ filled with the transitions collected during evaluation.
+ :return: (float, float) Mean reward per episode, std of reward per episode
+ returns ([float], [int]) when ``return_episode_rewards`` is True
+ """
+ if isinstance(env, VecEnv):
+ assert env.num_envs == 1, "You must pass only one environment when using this function"
+
+ episode_rewards, episode_lengths = [], []
+ for _ in range(n_eval_episodes):
+ obs = env.reset()
+ done, state = False, None
+ episode_reward = 0.0
+ episode_length = 0
+ if model.use_sde:
+ model.actor.reset_noise()
+
+ while not done:
+ action, state = model.predict(obs, state=state, deterministic=deterministic)
+ new_obs, reward, done, info = env.step(action)
+ episode_reward += reward
+ if callback is not None:
+ callback(locals(), globals())
+ episode_length += 1
+ if replay_buffer is not None:
+ # We assume actions are normalized but not observation/reward
+ buffer_action = action
+ replay_buffer.add(obs, new_obs, buffer_action, reward, done, info)
+ obs = new_obs
+ if render:
+ env.render()
+ episode_rewards.append(episode_reward)
+ episode_lengths.append(episode_length)
+ mean_reward = np.mean(episode_rewards)
+ std_reward = np.std(episode_rewards)
+ if reward_threshold is not None:
+ assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}"
+ if return_episode_rewards:
+ return episode_rewards, episode_lengths
+ return mean_reward, std_reward
+
+
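A minimal usage sketch for the helper above, assuming a trained SAC expert and a single-env ``VecEnv`` (as the function requires); the model and env choices here are illustrative:

```python
import gym
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import DummyVecEnv

from utils.utils import evaluate_policy_add_to_buffer

# Single-env VecEnv, as required by the helper.
env = DummyVecEnv([lambda: gym.make("Pendulum-v1")])
expert = SAC("MlpPolicy", env)  # stands in for a trained expert model
student = SAC("MlpPolicy", env, buffer_size=10_000)

# Evaluate the expert and copy its transitions into the student's buffer,
# e.g. to warm-start off-policy training.
mean_reward, std_reward = evaluate_policy_add_to_buffer(
    expert,
    env,
    n_eval_episodes=2,
    replay_buffer=student.replay_buffer,
)
print(f"{mean_reward:.2f} +/- {std_reward:.2f}")
```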
+def get_model_path(
+ exp_id: int,
+ folder: str,
+ algo: str,
+ env_id: str,
+ load_best: bool = False,
+ load_checkpoint: Optional[str] = None,
+ load_last_checkpoint: bool = False,
+) -> Tuple[str, str, str]:
+
+ if exp_id == 0:
+ exp_id = get_latest_run_id(os.path.join(folder, algo), env_id)
+ print(f"Loading latest experiment, id={exp_id}")
+ # Sanity checks
+ if exp_id > 0:
+ log_path = os.path.join(folder, algo, f"{env_id}_{exp_id}")
+ else:
+ log_path = os.path.join(folder, algo)
+
+ assert os.path.isdir(log_path), f"The {log_path} folder was not found"
+
+ if load_best:
+ model_path = os.path.join(log_path, "best_model.zip")
+ name_prefix = f"best-model-{algo}-{env_id}"
+ elif load_checkpoint is not None:
+ model_path = os.path.join(log_path, f"rl_model_{load_checkpoint}_steps.zip")
+ name_prefix = f"checkpoint-{load_checkpoint}-{algo}-{env_id}"
+ elif load_last_checkpoint:
+ checkpoints = glob.glob(os.path.join(log_path, "rl_model_*_steps.zip"))
+ if len(checkpoints) == 0:
+ raise ValueError(f"No checkpoint found for {algo} on {env_id}, path: {log_path}")
+
+ def step_count(checkpoint_path: str) -> int:
+ # paths follow the pattern "rl_model_*_steps.zip"; we count from the back to ignore any other "_" in the path
+ return int(checkpoint_path.split("_")[-2])
+
+ checkpoints = sorted(checkpoints, key=step_count)
+ model_path = checkpoints[-1]
+ name_prefix = f"checkpoint-{step_count(model_path)}-{algo}-{env_id}"
+ else:
+ # Default: load latest model
+ model_path = os.path.join(log_path, f"{env_id}.zip")
+ name_prefix = f"final-model-{algo}-{env_id}"
+
+ found = os.path.isfile(model_path)
+ if not found:
+ raise ValueError(f"No model found for {algo} on {env_id}, path: {model_path}")
+
+ return name_prefix, model_path, log_path
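For reference, a hedged sketch of how the new helper resolves paths; the folder and ids below are illustrative, and the call raises ``ValueError`` when nothing matches:

```python
from utils.utils import get_model_path

# Illustrative: resolve the most recent checkpoint of a tqc run stored
# under logs/tqc/HalfCheetah-v3_<exp_id>/.
name_prefix, model_path, log_path = get_model_path(
    exp_id=0,  # 0 means "use the latest experiment id"
    folder="logs",
    algo="tqc",
    env_id="HalfCheetah-v3",
    load_last_checkpoint=True,  # pick rl_model_<steps>_steps.zip with the most steps
)
# e.g. name_prefix == "checkpoint-1000000-tqc-HalfCheetah-v3"
```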
diff --git a/utils/wrappers.py b/utils/wrappers.py
index 9cdaf783f..9d177c7af 100644
--- a/utils/wrappers.py
+++ b/utils/wrappers.py
@@ -1,7 +1,52 @@
+import os
+from copy import deepcopy
+from typing import Optional
+
import gym
import numpy as np
+import torch as th
from sb3_contrib.common.wrappers import TimeFeatureWrapper # noqa: F401 (backward compatibility)
from scipy.signal import iirfilter, sosfilt, zpk2sos
+from stable_baselines3 import SAC
+from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
+
+
+class VecForceResetWrapper(VecEnvWrapper):
+ """
+ Force all environments to reset at once
+ and tell the agent when a trajectory was truncated.
+
+ :param venv: The vectorized environment
+ """
+
+ def __init__(self, venv: VecEnv):
+ super().__init__(venv=venv)
+
+ def reset(self) -> VecEnvObs:
+ return self.venv.reset()
+
+ def step_wait(self) -> VecEnvStepReturn:
+ for env_idx in range(self.num_envs):
+ obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step(
+ self.actions[env_idx]
+ )
+ self._save_obs(env_idx, obs)
+
+ # At least one env is done: force every env to reset, and flag the
+ # artificially terminated ones as truncated (timeout) so that value
+ # bootstrapping remains correct.
+ if self.buf_dones.any():
+ for env_idx in range(self.num_envs):
+ self.buf_infos[env_idx]["terminal_observation"] = self.buf_obs[None][env_idx]
+ if not self.buf_dones[env_idx]:
+ self.buf_infos[env_idx]["TimeLimit.truncated"] = True
+ self.buf_dones[env_idx] = True
+ obs = self.envs[env_idx].reset()
+ self._save_obs(env_idx, obs)
+
+ return (
+ self._obs_from_buf(),
+ np.copy(self.buf_rews),
+ np.copy(self.buf_dones),
+ deepcopy(self.buf_infos),
+ )
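A brief sketch of the intended behavior, assuming a ``DummyVecEnv`` (the wrapper reaches ``envs`` and the ``buf_*`` buffers through attribute forwarding, so it relies on DummyVecEnv internals):

```python
import gym
import numpy as np
from stable_baselines3.common.vec_env import DummyVecEnv

from utils.wrappers import VecForceResetWrapper

venv = VecForceResetWrapper(DummyVecEnv([lambda: gym.make("Pendulum-v1")] * 4))
obs = venv.reset()
for _ in range(250):
    actions = np.stack([venv.action_space.sample() for _ in range(venv.num_envs)])
    obs, rewards, dones, infos = venv.step(actions)
    if dones.any():
        # As soon as one sub-env finishes, all of them are reset together;
        # the ones cut short report info["TimeLimit.truncated"] = True.
        assert dones.all()
        break
```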
class DoneOnSuccessWrapper(gym.Wrapper):
@@ -11,7 +56,7 @@ class DoneOnSuccessWrapper(gym.Wrapper):
"""
def __init__(self, env: gym.Env, reward_offset: float = 0.0, n_successes: int = 1):
- super(DoneOnSuccessWrapper, self).__init__(env)
+ super().__init__(env)
self.reward_offset = reward_offset
self.n_successes = n_successes
self.current_successes = 0
@@ -41,12 +86,12 @@ class ActionNoiseWrapper(gym.Wrapper):
Add gaussian noise to the action (without telling the agent),
to test the robustness of the control.
- :param env: (gym.Env)
- :param noise_std: (float) Standard deviation of the noise
+ :param env:
+ :param noise_std: Standard deviation of the noise
"""
- def __init__(self, env, noise_std=0.1):
- super(ActionNoiseWrapper, self).__init__(env)
+ def __init__(self, env: gym.Env, noise_std: float = 0.1):
+ super().__init__(env)
self.noise_std = noise_std
def step(self, action):
@@ -95,13 +140,13 @@ class LowPassFilterWrapper(gym.Wrapper):
"""
Butterworth-Lowpass
- :param env: (gym.Env)
+ :param env:
:param freq: Filter corner frequency.
:param df: Sampling rate in Hz.
"""
- def __init__(self, env, freq=5.0, df=25.0):
- super(LowPassFilterWrapper, self).__init__(env)
+ def __init__(self, env: gym.Env, freq: float = 5.0, df: float = 25.0):
+ super().__init__(env)
self.freq = freq
self.df = df
self.signal = []
@@ -123,12 +168,12 @@ class ActionSmoothingWrapper(gym.Wrapper):
"""
Smooth the action using exponential moving average.
- :param env: (gym.Env)
- :param smoothing_coef: (float) Smoothing coefficient (0 no smoothing, 1 very smooth)
+ :param env:
+ :param smoothing_coef: Smoothing coefficient (0 no smoothing, 1 very smooth)
"""
- def __init__(self, env, smoothing_coef: float = 0.0):
- super(ActionSmoothingWrapper, self).__init__(env)
+ def __init__(self, env: gym.Env, smoothing_coef: float = 0.0):
+ super().__init__(env)
self.smoothing_coef = smoothing_coef
self.smoothed_action = None
# from https://github.com/rail-berkeley/softlearning/issues/3
@@ -152,12 +197,12 @@ class DelayedRewardWrapper(gym.Wrapper):
Delay the reward by `delay` steps: it makes the task harder but more realistic.
The reward is accumulated during those steps.
- :param env: (gym.Env)
- :param delay: (int) Number of steps the reward should be delayed.
+ :param env:
+ :param delay: Number of steps the reward should be delayed.
"""
- def __init__(self, env, delay=10):
- super(DelayedRewardWrapper, self).__init__(env)
+ def __init__(self, env: gym.Env, delay: int = 10):
+ super().__init__(env)
self.delay = delay
self.current_step = 0
self.accumulated_reward = 0.0
@@ -185,11 +230,11 @@ class HistoryWrapper(gym.Wrapper):
"""
Stack past observations and actions to give a history to the agent.
- :param env: (gym.Env)
- :param horizon: (int) Number of steps to keep in the history.
+ :param env:
+ :param horizon: Number of steps to keep in the history.
"""
- def __init__(self, env: gym.Env, horizon: int = 5):
+ def __init__(self, env: gym.Env, horizon: int = 2):
assert isinstance(env.observation_space, gym.spaces.Box)
wrapped_obs_space = env.observation_space
@@ -208,7 +253,7 @@ def __init__(self, env: gym.Env, horizon: int = 5):
# Overwrite the observation space
env.observation_space = gym.spaces.Box(low=low, high=high, dtype=wrapped_obs_space.dtype)
- super(HistoryWrapper, self).__init__(env)
+ super().__init__(env)
self.horizon = horizon
self.low_action, self.high_action = low_action, high_action
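For intuition, a small sketch of the resulting observation space (Pendulum is illustrative): with the new default ``horizon=2``, the flattened history holds ``horizon`` past observations plus ``horizon`` past actions:

```python
import gym

from utils.wrappers import HistoryWrapper

env = HistoryWrapper(gym.make("Pendulum-v1"), horizon=2)
# Pendulum: obs dim 3, action dim 1 -> 2 * (3 + 1) = 8
print(env.observation_space.shape)  # (8,)
```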
@@ -244,11 +289,11 @@ class HistoryWrapperObsDict(gym.Wrapper):
"""
History Wrapper for dict observation.
- :param env: (gym.Env)
- :param horizon: (int) Number of steps to keep in the history.
+ :param env:
+ :param horizon: Number of steps to keep in the history.
"""
- def __init__(self, env, horizon=5):
+ def __init__(self, env: gym.Env, horizon: int = 2):
assert isinstance(env.observation_space.spaces["observation"], gym.spaces.Box)
wrapped_obs_space = env.observation_space.spaces["observation"]
@@ -267,7 +312,7 @@ def __init__(self, env, horizon=5):
# Overwrite the observation space
env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=wrapped_obs_space.dtype)
- super(HistoryWrapperObsDict, self).__init__(env)
+ super().__init__(env)
self.horizon = horizon
self.low_action, self.high_action = low_action, high_action
@@ -305,3 +350,175 @@ def step(self, action):
obs_dict["observation"] = self._create_obs_from_history()
return obs_dict, reward, done, info
+
+
+class ResidualExpertWrapper(gym.Wrapper):
+ """
+ Residual policy learning: the final action sent to the env is the expert
+ action plus a scaled residual action from the agent.
+
+ :param env:
+ :param model_path: Path to the expert model (defaults to the ``MODEL_PATH`` env variable)
+ :param add_expert_to_obs: Whether to append the expert action to the observation
+ :param residual_scale: Scale applied to the agent (residual) action
+ :param expert_scale: Scale applied to the expert action
+ :param d3rlpy_model: If True, the expert is a TorchScript model (e.g. exported by d3rlpy)
+ """
+
+ def __init__(
+ self,
+ env: gym.Env,
+ model_path: Optional[str] = os.environ.get("MODEL_PATH"),
+ add_expert_to_obs: bool = True,
+ residual_scale: float = 0.2,
+ expert_scale: float = 1.0,
+ d3rlpy_model: bool = False,
+ ):
+ assert isinstance(env.observation_space, gym.spaces.Box)
+ assert model_path is not None
+
+ wrapped_obs_space = env.observation_space
+
+ low = np.concatenate((wrapped_obs_space.low, np.finfo(np.float32).min * np.ones(2)))
+ high = np.concatenate((wrapped_obs_space.high, np.finfo(np.float32).max * np.ones(2)))
+
+ # Overwrite the observation space
+ env.observation_space = gym.spaces.Box(low=low, high=high, dtype=wrapped_obs_space.dtype)
+
+ super().__init__(env)
+
+ print(f"Loading model from {model_path}")
+ if d3rlpy_model:
+ self.model = th.jit.load(model_path)
+ else:
+ self.model = SAC.load(model_path)
+ self.d3rlpy_model = d3rlpy_model
+ self._last_obs = None
+ self.residual_scale = residual_scale
+ self.expert_scale = expert_scale
+ self.add_expert_to_obs = add_expert_to_obs
+
+ def _predict(self, obs):
+ # TODO: move to gpu when possible
+ if self.d3rlpy_model:
+ expert_action = self.model(th.tensor(obs).reshape(1, -1)).cpu().numpy()[0, :]
+ else:
+ expert_action, _ = self.model.predict(obs, deterministic=True)
+ if self.add_expert_to_obs:
+ obs = np.concatenate((obs, expert_action), axis=-1)
+ return obs, expert_action
+
+ def reset(self):
+ obs = self.env.reset()
+ obs, self.expert_action = self._predict(obs)
+ return obs
+
+ def step(self, action):
+ action = np.clip(self.expert_scale * self.expert_action + self.residual_scale * action, -1.0, 1.0)
+ obs, reward, done, info = self.env.step(action)
+ obs, self.expert_action = self._predict(obs)
+
+ return obs, reward, done, info
+
+
+class ContinuityCostWrapper(gym.Wrapper):
+ """
+ Add a continuity cost to the reward.
+ It assumes that the action space is normalized
+ and symmetric (actions in [-1, 1]).
+
+ :param env:
+ :param weight_continuity: Weight of the continuity cost
+ :param condition: If True, force the weight to 1.0 and multiply the reward
+ by (1 - cost) instead of subtracting the cost.
+ """
+
+ def __init__(self, env: gym.Env, weight_continuity: float = 0.0, condition: bool = False):
+ super().__init__(env)
+ self.last_action = None
+ self.weight_continuity = weight_continuity
+ self.condition = condition
+ if condition:
+ self.weight_continuity = 1.0
+
+ def reset(self):
+ self.last_action = None
+ return self.env.reset()
+
+ def step(self, action):
+ obs, reward, done, info = self.env.step(action)
+ # Continuity cost
+ if self.last_action is not None:
+ max_delta = 2.0 # for the action space: high - low = 1 - (-1) = 2
+ continuity_cost = np.mean((action - self.last_action) ** 2 / max_delta**2)
+ continuity_cost = self.weight_continuity * continuity_cost
+ else:
+ continuity_cost = 0.0
+ self.last_action = action.copy()
+
+ if self.condition:
+ reward = (1 - continuity_cost) * reward
+ else:
+ reward -= continuity_cost
+ return obs, reward, done, info
+
+
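A tiny worked example of the continuity cost computed in ``step`` (values chosen for illustration):

```python
import numpy as np

last_action, action = np.array([0.5]), np.array([-0.5])
max_delta = 2.0  # action space in [-1, 1], so high - low = 2
cost = np.mean((action - last_action) ** 2 / max_delta**2)  # 1.0 / 4.0 = 0.25

# condition=False: reward -= weight_continuity * cost
# condition=True (weight forced to 1.0): reward *= (1 - cost)
```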
+class FrameSkip(gym.Wrapper):
+ """
+ Return only every ``skip``-th frame (frameskipping)
+
+ :param env: the environment
+ :param skip: number of frames to skip (the same action is repeated ``skip`` times)
+ """
+
+ def __init__(self, env: gym.Env, skip: int = 4):
+ super().__init__(env)
+ self._skip = skip
+
+ def step(self, action: np.ndarray):
+ """
+ Step the environment with the given action
+ Repeat action, sum reward.
+
+ :param action: the action
+ :return: observation, reward, done, information
+ """
+ total_reward = 0.0
+ done = None
+ for _ in range(self._skip):
+ obs, reward, done, info = self.env.step(action)
+ total_reward += reward
+ if done:
+ break
+
+ return obs, total_reward, done, info
+
+ def reset(self):
+ return self.env.reset()
+
+
+class MaskVelocityWrapper(gym.ObservationWrapper):
+ """
+ Gym environment observation wrapper used to mask velocity terms in
+ observations. The intention is to make the MDP partially observable.
+ Adapted from https://github.com/LiuWenlin595/FinalProject.
+
+ :param env: Gym environment
+ """
+
+ # Supported envs
+ velocity_indices = {
+ "CartPole-v1": np.array([1, 3]),
+ "MountainCar-v0": np.array([1]),
+ "MountainCarContinuous-v0": np.array([1]),
+ "Pendulum-v1": np.array([2]),
+ "LunarLander-v2": np.array([2, 3, 5]),
+ "LunarLanderContinuous-v2": np.array([2, 3, 5]),
+ }
+
+ def __init__(self, env: gym.Env):
+ super().__init__(env)
+
+ env_id: str = env.unwrapped.spec.id
+ # By default no masking
+ self.mask = np.ones_like(env.observation_space.sample())
+ try:
+ # Mask velocity
+ self.mask[self.velocity_indices[env_id]] = 0.0
+ except KeyError:
+ raise NotImplementedError(f"Velocity masking not implemented for {env_id}")
+
+ def observation(self, observation: np.ndarray) -> np.ndarray:
+ return observation * self.mask
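As a usage sketch (environment and training budget are illustrative): masking velocities makes the task partially observable, which is exactly the setting the newly registered ``ppo_lstm`` entry (RecurrentPPO) targets:

```python
import gym
from sb3_contrib import RecurrentPPO

from utils.wrappers import MaskVelocityWrapper

# CartPole without velocity terms: the agent must infer them from the past.
env = MaskVelocityWrapper(gym.make("CartPole-v1"))
model = RecurrentPPO("MlpLstmPolicy", env, verbose=0)
model.learn(total_timesteps=10_000)
```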
diff --git a/version.txt b/version.txt
index 33271c4d0..511e75b2e 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-1.5.1a0
+1.5.1a8