
Commit 20366b0

taufeeque9, ernestum, and ZiyueWang25 authored
Add scripts and configs for hyperparameter tuning (#675)
* Merge py file changes from benchmark-algs
* Clean parallel script
* Undo the changes from #653 to the dagger benchmark config files. This change just made some error messages go away indicating the missing imitation.algorithms.dagger.ExponentialBetaSchedule but it did not fix the root cause.
* Improve readability and interpretability of benchmarking tests.
* Add exponential beta scheduler for dagger
* Ignore coverage for unknown algorithms.
* Clean up and extend tests for beta schedules in dagger.
* Add optuna to dependencies
* Fix test case
* Clean up the scripts
* Remove reporter(done) since mean_return is reported by the runs
* Add beta_schedule parameter to dagger script
* Update config policy kwargs
* Changes from review
* Fix errors with some configs
* Updates based on review
* Change metric everywhere
* Separate tuning code from parallel.py
* Fix docstring
* Remove resume option as it is getting tricky to correctly implement
* Minor fixes
* Updates from review
* Fix lint error
* Add documentation for using the tuning script
* Fix lint error
* Updates from the review
* Fix file name test errors
* Add tune_run_kwargs in parallel script
* Fix test errors
* Fix test
* Fix lint
* Updates from review
* Simplify a few lines of code
* Updates from review
* Fix test
* Revert "Fix test" (this reverts commit 8b55134)
* Fix test
* Convert Dict to Mapping in input argument
* Ignore coverage in script configurations.
* Pin huggingface_sb3 version.
* Update to the newest seals environment versions.
* Push gymnasium dependency to 0.29 to ensure mujoco envs work.
* Incorporate review comments
* Fix test errors
* Move benchmarking/ to scripts/ and add named configs for tuned hyperparams
* Bump cache version & remove unnecessary files
* Include tuned hyperparam json files in package data
* Update storage hash
* Update search space of bc
* Update benchmark and hyperparameter tuning readme
* Update README.md
* Incorporate reviewer's comments in benchmarking readme
* Update gymnasium version and render mode in eval policy
* Fix error
* Update commands.py hex strings

Co-authored-by: Maximilian Ernestus <[email protected]>
Co-authored-by: ZiyueWang25 <[email protected]>
1 parent f099c33 commit 20366b0


43 files changed (+1023, -264 lines)

benchmarking/README.md

Lines changed: 28 additions & 5 deletions
@@ -1,19 +1,42 @@
 # Benchmarking imitation
 
-This directory contains sacred configuration files for benchmarking imitation's algorithms. For v0.3.2, these correspond to the hyperparameters used in the paper [imitation: Clean Imitation Learning Implementations](https://www.rocamonde.com/publication/gleave-imitation-2022/).
+The `src/imitation/scripts/config/tuned_hps` directory provides the tuned hyperparameter configs for benchmarking imitation. For v0.4.0, these correspond to the hyperparameters used in the paper [imitation: Clean Imitation Learning Implementations](https://arxiv.org/abs/2211.11972).
 
-Configuration files can be loaded either from the CLI or from the Python API. The examples below assume that your current working directory is the root of the `imitation` repository. This is not necessarily the case and you should adjust your paths accordingly.
+Configuration files can be loaded either from the CLI or from the Python API.
 
 ## CLI
 
 ```bash
-python -m imitation.scripts.<train_script> <algo> with benchmarking/<config_name>.json
+python -m imitation.scripts.<train_script> <algo> with <algo>_<env>
 ```
-`train_script` can be either 1) `train_imitation` with `algo` as `bc` or `dagger` or 2) `train_adversarial` with `algo` as `gail` or `airl`.
+`train_script` can be either 1) `train_imitation` with `algo` as `bc` or `dagger`, or 2) `train_adversarial` with `algo` as `gail` or `airl`. The `env` can be one of `seals_ant`, `seals_half_cheetah`, `seals_hopper`, `seals_swimmer`, or `seals_walker`. The hyperparameters for other environments have not been tuned yet. You may be able to get reasonable performance by using hyperparameters tuned for a similar environment; alternatively, you can tune the hyperparameters using the `tuning` script.
 
 ## Python
 
 ```python
 ...
-ex.add_config('benchmarking/<config_name>.json')
+from imitation.scripts.<train_script> import <train_ex>
+<train_ex>.run(command_name="<algo>", named_configs=["<algo>_<env>"])
 ```
+
+# Tuning Hyperparameters
+
+The hyperparameters of any algorithm in imitation can be tuned using `src/imitation/scripts/tuning.py`.
+The benchmarking hyperparameter configs were generated by tuning the hyperparameters using
+the search space defined in `scripts/config/tuning.py`.
+
+The tuning script proceeds in two phases:
+1. Tune the hyperparameters using the provided search space.
+2. Re-evaluate the best hyperparameter config found in the first phase (selected by maximum mean return) on a separate set of seeds, and report the mean and standard deviation of these trials.
+
+To use it with the default search space:
+```bash
+python -m imitation.scripts.tuning with <algo> 'parallel_run_config.base_named_configs=["<env>"]'
+```
+
+In this command:
+- `<algo>` provides the default search space and settings for the specific algorithm, defined in `scripts/config/tuning.py`.
+- `<env>` sets the environment to tune the algorithm in. Environments are defined in the algo-specific `scripts/config/train_[adversarial|imitation|preference_comparisons|rl].py` files. For the already tuned environments, use the `<algo>_<env>` named configs here.
+
+See the documentation of `scripts/tuning.py` and `scripts/parallel.py` for the many other arguments that can be
+provided through the command line to change the tuning behavior.
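To make the new Python route concrete, here is a minimal sketch of the `<train_ex>.run(...)` pattern from the README above. It assumes the Sacred experiment object is exported as `train_adversarial_ex` and that the tuned `gail_seals_half_cheetah` named config added by this commit is installed with the package:

```python
# A minimal sketch, assuming the experiment object is `train_adversarial_ex`
# and the tuned `gail_seals_half_cheetah` named config ships with the package.
from imitation.scripts.train_adversarial import train_adversarial_ex

run = train_adversarial_ex.run(
    command_name="gail",  # <algo>: `gail` or `airl` for train_adversarial
    named_configs=["gail_seals_half_cheetah"],  # the <algo>_<env> tuned config
)
print(run.result)
```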

benchmarking/util.py

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ def clean_config_file(file: pathlib.Path, write_path: pathlib.Path, /) -> None:
 
     remove_empty_dicts(config)
     # files are of the format
-    # /path/to/file/example_<algo>_<env>_best_hp_eval/<other_info>/sacred/1/config.json
+    # /path/to/file/<algo>_<env>_best_hp_eval/<other_info>/sacred/1/config.json
     # we want to write to /<write_path>/<algo>_<env>.json
     with open(write_path / f"{file.parents[3].name}.json", "w") as f:
         json.dump(config, f, indent=4)
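The `file.parents[3].name` lookup above depends on the fixed depth of the Sacred run layout. A self-contained sketch (the concrete path is hypothetical) of how the config name is recovered:

```python
# Illustrative only: recover "<algo>_<env>_best_hp_eval" from a run layout of
# the form <algo>_<env>_best_hp_eval/<other_info>/sacred/1/config.json.
import pathlib

file = pathlib.Path("/tmp/bc_seals_hopper_best_hp_eval/run0/sacred/1/config.json")
print(file.parents[3].name)  # -> bc_seals_hopper_best_hp_eval
```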

experiments/commands.py

Lines changed: 17 additions & 14 deletions
@@ -12,23 +12,24 @@
 
 For example, we can run:
 
+TUNED_HPS_DIR=../src/imitation/scripts/config/tuned_hps
 python commands.py \
     --name=run0 \
-    --cfg_pattern=../benchmarking/*ai*_seals_walker_*.json \
+    --cfg_pattern=$TUNED_HPS_DIR/*ai*_seals_walker_*.json \
     --output_dir=output
 
 And get the following commands printed out:
 
 python -m imitation.scripts.train_adversarial airl \
     --capture=sys --name=run0 \
     --file_storage=output/sacred/$USER-cmd-run0-airl-0-a3531726 \
-    with ../benchmarking/example_airl_seals_walker_best_hp_eval.json \
+    with ../src/imitation/scripts/config/tuned_hps/airl_seals_walker_best_hp_eval.json \
     seed=0 logging.log_root=output
 
 python -m imitation.scripts.train_adversarial gail \
     --capture=sys --name=run0 \
     --file_storage=output/sacred/$USER-cmd-run0-gail-0-a1ec171b \
-    with ../benchmarking/example_gail_seals_walker_best_hp_eval.json \
+    with $TUNED_HPS_DIR/gail_seals_walker_best_hp_eval.json \
     seed=0 logging.log_root=output
 
 We can execute commands in parallel by piping them to GNU parallel:

@@ -40,9 +41,10 @@
 
 For example, we can run:
 
+TUNED_HPS_DIR=../src/imitation/scripts/config/tuned_hps
 python commands.py \
     --name=run0 \
-    --cfg_pattern=../benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json \
+    --cfg_pattern=$TUNED_HPS_DIR/bc_seals_half_cheetah_best_hp_eval.json \
     --output_dir=/data/output \
     --remote

@@ -51,8 +53,9 @@
 ctl job run --name $USER-cmd-run0-bc-0-72cb1df3 \
     --command "python -m imitation.scripts.train_imitation bc \
     --capture=sys --name=run0 \
-    --file_storage=/data/output/sacred/$USER-cmd-run0-bc-0-72cb1df3 \
-    with /data/imitation/benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json \
+    --file_storage=/data/output/sacred/$USER-cmd-run0-bc-0-72cb1df3 with \
+    /data/imitation/src/imitation/scripts/config/tuned_hps/
+    bc_seals_half_cheetah_best_hp_eval.json \
     seed=0 logging.log_root=/data/output" \
     --container hacobe/devbox:imitation \
     --login --force-pull --never-restart --gpu 0 --shared-host-dir-mount /data

@@ -85,7 +88,7 @@ def _get_algo_name(cfg_file: str) -> str:
     """Get the algorithm name from the given config filename."""
     algo_names = set()
     for key in _ALGO_NAME_TO_SCRIPT_NAME:
-        if cfg_file.find("_" + key + "_") != -1:
+        if cfg_file.find(key + "_") != -1:
             algo_names.add(key)
 
     if len(algo_names) == 0:
@@ -121,7 +124,7 @@ def main(args: argparse.Namespace) -> None:
     else:
         cfg_path = os.path.join(args.remote_cfg_dir, cfg_file)
 
-    cfg_id = _get_cfg_id(cfg_path)
+    cfg_id = _get_cfg_id(cfg_file)
 
     for seed in args.seeds:
         cmd_id = _CMD_ID_TEMPLATE.format(

@@ -177,19 +180,19 @@ def parse() -> argparse.Namespace:
     parser.add_argument(
         "--cfg_pattern",
         type=str,
-        default="example_bc_seals_half_cheetah_best_hp_eval.json",
+        default="bc_seals_half_cheetah_best_hp_eval.json",
         help="""Generate a command for every file that matches this glob pattern. \
 Each matching file should be a config file that has its algorithm name \
 (bc, dagger, airl or gail) bookended by underscores in the filename. \
 If the --remote flag is enabled, then generate a command for every file in the \
 --remote_cfg_dir directory that has the same filename as a file that matches this \
 glob pattern. E.g., suppose the current, local working directory is 'foo' and \
-the subdirectory 'foo/bar' contains the config files 'example_bc_best.json' and \
-'example_dagger_best.json'. If the pattern 'bar/*.json' is supplied, then globbing \
-will return ['bar/example_bc_best.json', 'bar/example_dagger_best.json']. \
+the subdirectory 'foo/bar' contains the config files 'bc_best.json' and \
+'dagger_best.json'. If the pattern 'bar/*.json' is supplied, then globbing \
+will return ['bar/bc_best.json', 'bar/dagger_best.json']. \
 If the --remote flag is enabled, 'bar' will be replaced with `remote_cfg_dir` and \
 commands will be created for the following configs: \
-[`remote_cfg_dir`/example_bc_best.json, `remote_cfg_dir`/example_dagger_best.json] \
+[`remote_cfg_dir`/bc_best.json, `remote_cfg_dir`/dagger_best.json] \
 Why not just supply the pattern '`remote_cfg_dir`/*.json' directly? \
 Because the `remote_cfg_dir` directory may not exist on the local machine.""",
     )

@@ -220,7 +223,7 @@ def parse() -> argparse.Namespace:
     parser.add_argument(
         "--remote_cfg_dir",
         type=str,
-        default="/data/imitation/benchmarking",
+        default="/data/imitation/src/imitation/scripts/config/tuned_hps",
         help="""Path to a directory storing config files \
 accessible from each container. """,
     )
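The `_get_algo_name` change above switches from matching `"_" + key + "_"` to `key + "_"`, because the `example_` prefix was dropped from the config filenames and the algorithm name now starts the filename. A standalone sketch of the resulting behavior (the helper name and key list are illustrative, not the module's actual API):

```python
# Illustrative sketch of the new prefix-style matching in _get_algo_name.
_ALGO_NAMES = ["bc", "dagger", "airl", "gail"]  # assumed key set

def algo_from_filename(cfg_file: str) -> str:
    matches = [algo for algo in _ALGO_NAMES if cfg_file.find(algo + "_") != -1]
    if len(matches) != 1:
        raise ValueError(f"expected exactly one algo in {cfg_file!r}: {matches}")
    return matches[0]

assert algo_from_filename("bc_seals_half_cheetah_best_hp_eval.json") == "bc"
assert algo_from_filename("airl_seals_walker_best_hp_eval.json") == "airl"
```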

setup.cfg

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ source = imitation
 include=
     src/*
     tests/*
+omit =
+    src/imitation/scripts/config/*
 
 [coverage:report]
 exclude_lines =

setup.py

Lines changed: 2 additions & 1 deletion
@@ -182,7 +182,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
     python_requires=">=3.8.0",
     packages=find_packages("src"),
     package_dir={"": "src"},
-    package_data={"imitation": ["py.typed", "envs/examples/airl_envs/assets/*.xml"]},
+    package_data={"imitation": ["py.typed", "scripts/config/tuned_hps/*.json"]},
     # Note: while we are strict with our test and doc requirement versions, we try to
     # impose as little restrictions on the install requirements as possible. Try to
     # encode only known incompatibilities here. This prevents nasty dependency issues

@@ -200,6 +200,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
         "sacred>=0.8.4",
         "tensorboard>=1.14",
         "huggingface_sb3~=3.0",
+        "optuna>=3.0.1",
         "datasets>=2.8.0",
     ],
     tests_require=TESTS_REQUIRE,

src/imitation/scripts/analyze.py

Lines changed: 23 additions & 14 deletions
@@ -262,38 +262,47 @@ def analyze_imitation(
         csv_output_path: If provided, then save a CSV output file to this path.
         tex_output_path: If provided, then save a LaTeX-format table to this path.
         print_table: If True, then print the dataframe to stdout.
-        table_verbosity: Increasing levels of verbosity, from 0 to 2, increase the
-            number of columns in the table.
+        table_verbosity: Increasing levels of verbosity, from 0 to 3, increase the
+            number of columns in the table. Level 3 prints all of the columns available.
 
     Returns:
         The DataFrame generated from the Sacred logs.
     """
-    table_entry_fns_subset = _get_table_entry_fns_subset(table_verbosity)
+    # Get the column names whose values are computed via make_entry_fn.
+    # These are the same for levels 2 and 3; at level 3 we additionally
+    # add the remaining config columns.
+    table_entry_fns_subset = _get_table_entry_fns_subset(min(table_verbosity, 2))
 
-    rows = []
+    output_table = pd.DataFrame()
     for sd in _gather_sacred_dicts():
-        row = {}
+        if table_verbosity == 3:
+            # gets all config columns
+            row = pd.json_normalize(sd.config)
+        else:
+            # create an empty dataframe with a single row
+            row = pd.DataFrame(index=[0])
+
         for col_name, make_entry_fn in table_entry_fns_subset.items():
             row[col_name] = make_entry_fn(sd)
-        rows.append(row)
 
-    df = pd.DataFrame(rows)
-    if len(df) > 0:
-        df.sort_values(by=["algo", "env_name"], inplace=True)
+        output_table = pd.concat([output_table, row])
+
+    if len(output_table) > 0:
+        output_table.sort_values(by=["algo", "env_name"], inplace=True)
 
-    display_options = dict(index=False)
+    display_options: Mapping[str, Any] = dict(index=False)
     if csv_output_path is not None:
-        df.to_csv(csv_output_path, **display_options)
+        output_table.to_csv(csv_output_path, **display_options)
         print(f"Wrote CSV file to {csv_output_path}")
     if tex_output_path is not None:
-        s: str = df.to_latex(**display_options)
+        s: str = output_table.to_latex(**display_options)
         with open(tex_output_path, "w") as f:
             f.write(s)
         print(f"Wrote TeX file to {tex_output_path}")
 
     if print_table:
-        print(df.to_string(**display_options))
-    return df
+        print(output_table.to_string(**display_options))
+    return output_table
 
 
 def _make_return_summary(stats: dict, prefix="") -> str:
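To see what the new `table_verbosity == 3` branch produces per run: `pd.json_normalize` flattens a nested config dict into a single-row DataFrame with dot-separated column names. A minimal sketch with a made-up config:

```python
# Minimal sketch of the table_verbosity=3 row construction (config is made up).
import pandas as pd

config = {
    "algo": "bc",
    "env_name": "seals/HalfCheetah-v1",
    "rl": {"rl_kwargs": {"learning_rate": 3e-4}},
}
row = pd.json_normalize(config)  # nested keys become dotted column names
print(list(row.columns))  # ['algo', 'env_name', 'rl.rl_kwargs.learning_rate']
```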

src/imitation/scripts/config/analyze.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ def config():
     tex_output_path = None  # Write LaTex output to this path
     print_table = True  # Set to True to print analysis to stdout
     split_str = ","  # str used to split source_dir_str into multiple source dirs
-    table_verbosity = 1  # Choose from 0, 1, or 2
+    table_verbosity = 1  # Choose from 0, 1, 2 or 3
     source_dirs = None
src/imitation/scripts/config/parallel.py

Lines changed: 13 additions & 66 deletions
@@ -5,7 +5,10 @@
 `@parallel_ex.named_config` to define a new parallel experiment.
 
 Adding custom named configs is necessary because the CLI interface can't add
-search spaces to the config like `"seed": tune.grid_search([0, 1, 2, 3])`.
+search spaces to the config like `"seed": tune.choice([0, 1, 2, 3])`.
+
+For tuning hyperparameters of an algorithm on a given environment,
+check out the imitation/scripts/tuning.py script.
 """
 
 import numpy as np

@@ -31,19 +34,10 @@ def config():
         "config_updates": {},
     }  # `config` argument to `ray.tune.run(trainable, config)`
 
-    local_dir = None  # `local_dir` arg for `ray.tune.run`
-    upload_dir = None  # `upload_dir` arg for `ray.tune.run`
-    n_seeds = 3  # Number of seeds to search over by default
-
-
-@parallel_ex.config
-def seeds(n_seeds):
-    search_space = {"config_updates": {"seed": tune.grid_search(list(range(n_seeds)))}}
-
-
-@parallel_ex.named_config
-def s3():
-    upload_dir = "s3://shwang-chai/private"
+    num_samples = 1  # Number of samples per grid search configuration
+    repeat = 1  # Number of times to repeat a sampled configuration
+    experiment_checkpoint_path = ""  # Path to checkpoint of experiment
+    tune_run_kwargs = {}  # Additional kwargs to pass to `tune.run`
 
 
 # Debug named configs

@@ -58,12 +52,12 @@ def generate_test_data():
     """
     sacred_ex_name = "train_rl"
     run_name = "TEST"
-    n_seeds = 1
+    repeat = 1
     search_space = {
         "config_updates": {
             "rl": {
                 "rl_kwargs": {
-                    "learning_rate": tune.grid_search(
+                    "learning_rate": tune.choice(
                         [3e-4 * x for x in (1 / 3, 1 / 2)],
                     ),
                 },

@@ -86,63 +80,16 @@
 def example_cartpole_rl():
     sacred_ex_name = "train_rl"
     run_name = "example-cartpole"
-    n_seeds = 2
+    repeat = 2
     search_space = {
         "config_updates": {
             "rl": {
                 "rl_kwargs": {
-                    "learning_rate": tune.grid_search(np.logspace(3e-6, 1e-1, num=3)),
-                    "nminibatches": tune.grid_search([16, 32, 64]),
+                    "learning_rate": tune.choice(np.logspace(3e-6, 1e-1, num=3)),
+                    "nminibatches": tune.choice([16, 32, 64]),
                 },
             },
         },
     }
     base_named_configs = ["cartpole"]
     resources_per_trial = dict(cpu=4)
-
-
-EASY_ENVS = ["cartpole", "pendulum", "mountain_car"]
-
-
-@parallel_ex.named_config
-def example_rl_easy():
-    sacred_ex_name = "train_rl"
-    run_name = "example-rl-easy"
-    n_seeds = 2
-    search_space = {
-        "named_configs": tune.grid_search([[env] for env in EASY_ENVS]),
-        "config_updates": {
-            "rl": {
-                "rl_kwargs": {
-                    "learning_rate": tune.grid_search(np.logspace(3e-6, 1e-1, num=3)),
-                    "nminibatches": tune.grid_search([16, 32, 64]),
-                },
-            },
-        },
-    }
-    resources_per_trial = dict(cpu=4)
-
-
-@parallel_ex.named_config
-def example_gail_easy():
-    sacred_ex_name = "train_adversarial"
-    run_name = "example-gail-easy"
-    n_seeds = 1
-    search_space = {
-        "named_configs": tune.grid_search([[env] for env in EASY_ENVS]),
-        "config_updates": {
-            "init_trainer_kwargs": {
-                "rl": {
-                    "rl_kwargs": {
-                        "learning_rate": tune.grid_search(
-                            np.logspace(3e-6, 1e-1, num=3),
-                        ),
-                        "nminibatches": tune.grid_search([16, 32, 64]),
-                    },
-                },
-            },
-        },
-    }
-    search_space = {
-        "command_name": "gail",
-    }
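Under the new scheme, a custom search-space named config would look roughly like the following sketch (assuming `parallel_ex` is importable from `imitation.scripts.config.parallel`; the config name and values are illustrative):

```python
# Illustrative named config in the new style: tune.choice samples values,
# num_samples controls how many configurations are drawn, and repeat
# controls how many times each sampled configuration is re-run.
from ray import tune

from imitation.scripts.config.parallel import parallel_ex


@parallel_ex.named_config
def my_cartpole_search():
    sacred_ex_name = "train_rl"
    run_name = "my-cartpole-search"
    num_samples = 10
    repeat = 2
    search_space = {
        "config_updates": {
            "rl": {"rl_kwargs": {"learning_rate": tune.choice([1e-4, 3e-4, 1e-3])}},
        },
    }
    base_named_configs = ["cartpole"]
```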
