
Commit 531fa06

Migrate to pathlib (#537)
* Initial mypy configuration
* Initial change to get the PR up
* Initial review at replacing os.path
* Bug fixes from tests
* Fix types: test_envs.py
* Fix types: conftest.py
* Fix types: tests/util
* Fix types: tests/scripts
* Fix types: tests/rewards
* Fix types: tests/policies
* Incorrect decorator in update_stats method form networks.py::BaseNorm
* Fix types: tests/algorithms (adersarial and bc)
* Fix types: tests/algorithms (dagger and pc)
* Fix types: tests/data
* Linting
* Linting
* Fix types: algorithms/preference_comparisons.py
* Fix types: algorithms/mce_irl.py
* Formatting, fixed minor bug
* Clarify why types are ignored
* Started fixing types on algorithms/density.py
* Linting
* Linting (add back type ignore after reformatting)
* Fixed types: imitation/data/types.py
* Fixed types (started): imitation/data/
* Fixed types: imitation/data/buffer.py
* Fixed bug in buffer.py
* Fixed types: imitation/data/rollout.py
* Fixed types: imitation/data/wrappers.py
* Improve makefile to support automatic cache cleaning
* Fixed types: imitation/testing/
* Linting, fixed wrong return type in rewards.predict_processed_all
* Fixed types: imitation/policies/
* Formatting
* Fixed types: imitation/rewards/
* Fixed types: imitation/rewards/
* Fixed types: imitation/scripts/
* Fixed types: imitation/util/ and formatting
* Linting and formatting
* Bug fixes for test errors
* Linting and typing
* Improve typing in algorithms
* Formatting
* Bug fix
* Formatting
* Fixes suggested by Adam.
* Fix mypy version.
* Fix bugs
* Remove unused imports
* Formatting
* Added parse_path func and refactored code to use it
* Fix typing, linting
* Update TabularPolicy.predict to match base class
* Fix not checking for dones
* Change for loop to dict comprehension
* Remove is_ensemble to clear up type checking errors
* Reduce code duplication and general cleanup
* Fix type annotation of step_dict
* Change List to Sequence
* Fix density.py::DensityAlgorithm._set_demo_from_batch
* Fixed n_steps (OnPolicyAlgorithm)
* Fix errors in tests
* Include some suggestions into rollout.py and preference_comparisons.py
* Formatting
* Fix setter error as per python/mypy#5936
* add reason for assertion.
* Fix style guide violation: https://google.github.io/styleguide/pyguide.html#22-imports
* Update src/imitation/scripts/parallel.py Co-authored-by: Adam Gleave
* Move kwargs to the end.
* Swap order of expert_policy_type and expert_policy_path validation check
* Update src/imitation/util/util.py Co-authored-by: Adam Gleave
* Update tests/rewards/test_reward_fn.py Co-authored-by: Adam Gleave
* Explicit random state setting and fix corresponding tests (except notebooks, sacred config, scripts)
* Fix notebooks; add script to clean notebooks
* Fix all tests.
* Formattting.
* Additional fixes
* Linting
* Remove automatically generated `_api` docs files too on `make clean`
* Fix docstrings.
* Fix issue with next(iter(iterable))
* Formatting
* Remove whitespace
* Add TODO message to remove type ignore later
* Remove unnecessary assertion.
* Fixed types in density.py set_demonstrations
* Added type ignore to pytype bug
* Fix_get_first_iter_element and add tests
* Bugfix in BC and tests -- masked as previously iterator ran out too early!
* Remove makefile for now
* Added link to SB3 issue for future reference.
* Fix types of train_imitation Only return "expert_stats" if all trajectories have reward.
* Modify assert in test_bc to reflect correct type
* Add ci/clean_notebooks.py to CI checks
* Improve clean_notebooks.py by allowing checking only mode.
* Add ipynb notebook checks to CI
* Add support for explicit files for notebook cleaning
* Clean notebooks
* Small improvements in util.py
* Replace TransitionKind with TransitionsMinimal
* Delete unused statement in test
* Update src/imitation/util/util.py Co-authored-by: Adam Gleave
* Update src/imitation/util/util.py Co-authored-by: Adam Gleave
* Make type ignore specific to pytype
* Linting
* Migrate from RandomState (deprecated) to Generator
* Add backticks to error message
* Create "AnyNorm" alias
* Small fix
* Add additional checks to shapes in _set_demo_from_batch
* Fix RolloutStatsComputer type
* Improved logging/messages in clean_notebooks.py
* Fix issues resulting from merge
* Bug fix
* Bug fix (wasn't really fixed before)
* Fixed docs example of BC
* Fix bugs resulting from merge
* Fix docs (dagger.rst) caught by sphinx CI
* Add mypy to CI
* Continue fixing miscellaneous type errors
* Linting
* Fix issue with normalize_input_layer type
* Add support for checking presence of generic type ignores
* Allow subdirectories in notebook clean
* Add full typing support for TransitionsMinimal as a sequence
* Fix types for density.py
* Misc fixes
* Add support for prefix context manager in logger (from #529)
* Added back accidentally removed code
* Replaced preference comparisons prefix with ctx manager
* Fixed errors
* Bug fixes
* Docstring fixes
* Fix bug in serialize.py
* Fixed codecheck by pointing notebook checks to docs
* Add rng to mce_irl.rst (doctest)
* Add rng to density.rst (doctest)
* Fix remaining rst files
* Increase sample size to reduce flakiness
* Ignore files not passing mypy for now
* Comment in wrong line
* Comment in wrong line
* Move excluded files to argument
* Add quotes to mypy arg call
* Fix CI mypy call
* Fix CI yaml
* Break ignored files up into one line each
* Address PR comments
* Point SB3 to master to include bug fix
* Small bug fixes
* Small bug fixes
* Sort import
* Linting
* Do not follow imports for ignored files
* Fix tests for context managers
* Format / fix tests for context manager
* Switch to sb3 1.6.1
* Formatting
* Upgrade Python version in Windows CI
* Remove unused import
* Remove unused fixture
* Add coveragerc file
* Add utils test
* Add tests and asserts
* Add test to synthetic gatherer
* Add trajectory unwrap tests
* Formatting
* Remove bracket typo
* Fix .coveragerc instruction
* Improve density algo coverage and bug fixes
* Fix bug in test
* Add pragma no cover updates
* Minor coverage tweaks
* Fix iterator test
* Add test for parse_path
* Updates on sacred util
* Mark type ignore rule
* Mark type ignore rule
* Miscellaneous bug fixes and improvements
* Reformat hanging line
* Ignore parse path checks for windows
* Add trailing comma
* Minor changes
* No newline end of file
* Update src/imitation/data/types.py Co-authored-by: Adam Gleave
* Update src/imitation/data/types.py Co-authored-by: Adam Gleave
* Include suggestions from Adam Co-authored-by: Adam Gleave
1 parent 70c8cee commit 531fa06

33 files changed: 316 additions and 210 deletions

.circleci/config.yml

Lines changed: 0 additions & 1 deletion
@@ -41,7 +41,6 @@ executors:
 (?x)(
 src/imitation/algorithms/preference_comparisons.py$
 | src/imitation/rewards/reward_nets.py$
-| src/imitation/util/sacred.py$
 | src/imitation/algorithms/base.py$
 | src/imitation/scripts/train_preference_comparisons.py$
 | src/imitation/rewards/serialize.py$

ci/code_checks.sh

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@ SRC_FILES=(src/ tests/ experiments/ examples/ docs/conf.py setup.py ci/)
 EXCLUDE_MYPY="(?x)(
 src/imitation/algorithms/preference_comparisons.py$
 | src/imitation/rewards/reward_nets.py$
-| src/imitation/util/sacred.py$
 | src/imitation/algorithms/base.py$
 | src/imitation/scripts/train_preference_comparisons.py$
 | src/imitation/rewards/serialize.py$
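
The exclude pattern above opens with `(?x)`, the inline verbose flag, so the line breaks and spaces between alternatives are ignored and each `|` introduces one excluded file. The following is a minimal sketch of how such a pattern behaves once `sacred.py` is dropped, using Python's `re` module. It assumes the pattern is applied with search semantics and that the group is closed right after the last visible entry; the list is abridged to the entries shown in the hunk, and `is_excluded` is a hypothetical helper name.

```python
import re

# Abridged copy of the post-commit exclude list. With (?x), whitespace and
# newlines inside the pattern are ignored. The closing ")" is an assumption,
# since the end of the pattern lies outside the hunk.
EXCLUDE = re.compile(
    r"""(?x)(
    src/imitation/algorithms/preference_comparisons.py$
    | src/imitation/rewards/reward_nets.py$
    | src/imitation/algorithms/base.py$
    | src/imitation/scripts/train_preference_comparisons.py$
    | src/imitation/rewards/serialize.py$
    )"""
)


def is_excluded(path: str) -> bool:
    """Return True if the path matches any alternative in the exclude pattern."""
    return EXCLUDE.search(path) is not None


print(is_excluded("src/imitation/rewards/reward_nets.py"))  # True
print(is_excluded("src/imitation/util/sacred.py"))  # False
```

Removing the `sacred.py` alternative means that file is no longer skipped, so it is type-checked along with the rest of the tree.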

docs/tutorials/1_train_bc.ipynb

Lines changed: 1 addition & 1 deletion
@@ -200,4 +200,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 2
-}
+}
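
The removed and added lines in this hunk look identical, which usually means the only difference is invisible whitespace, most often the newline at the end of the file. The diff view does not show which side has the newline, so the sketch below only illustrates how to check a file for one; `ends_with_newline` is a hypothetical helper, not part of the repository.

```python
import pathlib
import tempfile


def ends_with_newline(path: pathlib.Path) -> bool:
    """Return True if the last byte of the file is a newline character."""
    return path.read_bytes().endswith(b"\n")


with tempfile.TemporaryDirectory() as tmp:
    notebook = pathlib.Path(tmp) / "example.ipynb"

    notebook.write_bytes(b'{"nbformat": 4}\n')
    print(ends_with_newline(notebook))  # True

    notebook.write_bytes(b'{"nbformat": 4}')
    print(ends_with_newline(notebook))  # False
```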

docs/tutorials/3_train_gail.ipynb

Lines changed: 1 addition & 1 deletion
@@ -187,4 +187,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 2
-}
+}

docs/tutorials/4_train_airl.ipynb

Lines changed: 1 addition & 1 deletion
@@ -181,4 +181,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 2
-}
+}

docs/tutorials/5_train_preference_comparisons.ipynb

Lines changed: 1 addition & 1 deletion
@@ -203,4 +203,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 2
-}
+}

docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb

Lines changed: 1 addition & 1 deletion
@@ -236,4 +236,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 5
-}
+}

docs/tutorials/7_train_density.ipynb

Lines changed: 1 addition & 1 deletion
@@ -158,4 +158,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 4
-}
+}

src/imitation/algorithms/adversarial/common.py

Lines changed: 6 additions & 7 deletions
@@ -3,7 +3,6 @@
 import collections
 import dataclasses
 import logging
-import os
 from typing import (
     Callable,
     Iterable,
@@ -127,7 +126,7 @@ def __init__(
         gen_algo: base_class.BaseAlgorithm,
         reward_net: reward_nets.RewardNet,
         n_disc_updates_per_round: int = 2,
-        log_dir: str = "output/",
+        log_dir: types.AnyPath = "output/",
         disc_opt_cls: Type[th.optim.Optimizer] = th.optim.Adam,
         disc_opt_kwargs: Optional[Mapping] = None,
         gen_train_timesteps: Optional[int] = None,
@@ -202,7 +201,7 @@ def __init__(
         self.venv = venv
         self.gen_algo = gen_algo
         self._reward_net = reward_net.to(gen_algo.device)
-        self._log_dir = log_dir
+        self._log_dir = types.parse_path(log_dir)
 
         # Create graph for optimising/recording stats on discriminator
         self._disc_opt_cls = disc_opt_cls
@@ -215,10 +214,10 @@ def __init__(
         )
 
         if self._init_tensorboard:
-            logging.info("building summary directory at " + self._log_dir)
-            summary_dir = os.path.join(self._log_dir, "summary")
-            os.makedirs(summary_dir, exist_ok=True)
-            self._summary_writer = thboard.SummaryWriter(summary_dir)
+            logging.info(f"building summary directory at {self._log_dir}")
+            summary_dir = self._log_dir / "summary"
+            summary_dir.mkdir(parents=True, exist_ok=True)
+            self._summary_writer = thboard.SummaryWriter(str(summary_dir))
 
         self.venv_buffering = wrappers.BufferingWrapper(self.venv)
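
The last hunk swaps `os.path.join` and `os.makedirs` for the `/` operator and `Path.mkdir`, and replaces string concatenation in the log call with an f-string, since `"text" + Path(...)` raises `TypeError`. Here is a small standalone sketch of the same pattern using only the standard library; `make_summary_dir` is a hypothetical helper introduced for illustration.

```python
import pathlib
import tempfile


def make_summary_dir(log_dir: pathlib.Path) -> pathlib.Path:
    """Create `<log_dir>/summary`, including missing parents, and return it."""
    summary_dir = log_dir / "summary"  # pathlib form of os.path.join(log_dir, "summary")
    summary_dir.mkdir(parents=True, exist_ok=True)  # pathlib form of os.makedirs(..., exist_ok=True)
    return summary_dir


with tempfile.TemporaryDirectory() as tmp:
    log_dir = pathlib.Path(tmp) / "output"
    summary_dir = make_summary_dir(log_dir)
    print(summary_dir.is_dir())  # True

    # A second call does not fail because exist_ok=True tolerates existing directories.
    make_summary_dir(log_dir)

    # Concatenating str and Path is a TypeError, so format the path instead.
    try:
        "building summary directory at " + log_dir
    except TypeError:
        print(f"building summary directory at {log_dir}")
```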

src/imitation/algorithms/bc.py

Lines changed: 1 addition & 1 deletion
@@ -472,4 +472,4 @@ def save_policy(self, policy_path: types.AnyPath) -> None:
         Args:
             policy_path: path to save policy to.
         """
-        th.save(self.policy, types.path_to_str(policy_path))
+        th.save(self.policy, types.parse_path(policy_path))

src/imitation/algorithms/dagger.py

Lines changed: 8 additions & 14 deletions
@@ -90,10 +90,8 @@ def reconstruct_trainer(
         A deserialized `DAggerTrainer`.
     """
     custom_logger = custom_logger or imit_logger.configure()
-    checkpoint_path = pathlib.Path(
-        types.path_to_str(scratch_dir),
-        "checkpoint-latest.pt",
-    )
+    scratch_dir = types.parse_path(scratch_dir)
+    checkpoint_path = scratch_dir / "checkpoint-latest.pt"
     trainer = th.load(checkpoint_path, map_location=utils.get_device(device))
     trainer.venv = venv
     trainer._logger = custom_logger
@@ -109,14 +107,14 @@ def _save_dagger_demo(
     # however that NPZ save here is likely more space efficient than
     # pickle from types.save(), and types.save only accepts
     # TrajectoryWithRew right now (subclass of Trajectory).
-    save_dir_obj = pathlib.Path(types.path_to_str(save_dir))
+    save_dir = types.parse_path(save_dir)
     assert isinstance(trajectory, types.Trajectory)
     actual_prefix = f"{prefix}-" if prefix else ""
     timestamp = util.make_unique_timestamp()
     filename = f"{actual_prefix}dagger-demo-{timestamp}.npz"
 
-    save_dir_obj.mkdir(parents=True, exist_ok=True)
-    npz_path = save_dir_obj / filename
+    save_dir.mkdir(parents=True, exist_ok=True)
+    npz_path = save_dir / filename
     np.savez_compressed(npz_path, **dataclasses.asdict(trajectory))
     logging.info(f"Saved demo at '{npz_path}'")
 
@@ -344,7 +342,7 @@ def __init__(
         if beta_schedule is None:
             beta_schedule = LinearBetaSchedule(15)
         self.beta_schedule = beta_schedule
-        self.scratch_dir = pathlib.Path(types.path_to_str(scratch_dir))
+        self.scratch_dir = types.parse_path(scratch_dir)
         self.venv = venv
         self.round_num = 0
         self._last_loaded_round = -1
@@ -397,11 +395,7 @@ def _load_all_demos(self):
         return demo_transitions, num_demos_by_round
 
     def _get_demo_paths(self, round_dir):
-        return [
-            os.path.join(round_dir, p)
-            for p in os.listdir(round_dir)
-            if p.endswith(".npz")
-        ]
+        return [round_dir / p for p in os.listdir(round_dir) if p.endswith(".npz")]
 
     def _demo_dir_path_for_round(self, round_num: Optional[int] = None) -> pathlib.Path:
         if round_num is None:
@@ -411,7 +405,7 @@ def _demo_dir_path_for_round(self, round_num: Optional[int] = None) -> pathlib.P
     def _try_load_demos(self) -> None:
         """Load the dataset for this round into self.bc_trainer as a DataLoader."""
         demo_dir = self._demo_dir_path_for_round()
-        demo_paths = self._get_demo_paths(demo_dir) if os.path.isdir(demo_dir) else []
+        demo_paths = self._get_demo_paths(demo_dir) if demo_dir.is_dir() else []
         if len(demo_paths) == 0:
             raise NeedsDemosException(
                 f"No demos found for round {self.round_num} in dir '{demo_dir}'. "
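
After this change `_get_demo_paths` still calls `os.listdir` and joins each name onto a `pathlib.Path`. A possible pure-`pathlib` variant is sketched below. It is not what the commit does: `get_demo_paths` is a hypothetical free function, and the sorting is an addition made here so that the result does not depend on directory listing order.

```python
import pathlib
import tempfile
from typing import List


def get_demo_paths(round_dir: pathlib.Path) -> List[pathlib.Path]:
    """Return the `.npz` entries directly inside `round_dir`, sorted by name."""
    # name.endswith keeps the same filter as the os.listdir version in the diff.
    return sorted(p for p in round_dir.iterdir() if p.name.endswith(".npz"))


with tempfile.TemporaryDirectory() as tmp:
    round_dir = pathlib.Path(tmp)
    for name in ("b.npz", "a.npz", "notes.txt"):
        (round_dir / name).touch()

    print([p.name for p in get_demo_paths(round_dir)])  # ['a.npz', 'b.npz']
```

Like the original, this raises if the directory does not exist, which is why the caller in the diff guards it with `demo_dir.is_dir()`.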

src/imitation/algorithms/mce_irl.py

Lines changed: 1 addition & 3 deletions
@@ -176,9 +176,7 @@ def set_pi(self, pi: np.ndarray) -> None:
         self.pi = pi
 
     def _predict(self, observation: th.Tensor, deterministic: bool = False):
-        raise NotImplementedError(
-            "Should never be called as predict overridden.",
-        )
+        raise NotImplementedError("Should never be called as predict overridden.")
 
     def forward(  # type: ignore[override]
         self,

src/imitation/data/types.py

Lines changed: 82 additions & 6 deletions
@@ -44,11 +44,87 @@ def dataclass_quick_asdict(obj) -> Dict[str, Any]:
     return d
 
 
-def path_to_str(path: AnyPath) -> str:
-    if isinstance(path, bytes):
-        return path.decode()
+def parse_path(
+    path: AnyPath,
+    allow_relative: bool = True,
+    base_directory: Optional[pathlib.Path] = None,
+) -> pathlib.Path:
+    """Parse a path to a `pathlib.Path` object.
+
+    All resulting paths are resolved, absolute paths. If `allow_relative` is True,
+    then relative paths are allowed as input, and are resolved relative to the
+    current working directory, or relative to `base_directory` if it is
+    specified.
+
+    Args:
+        path: The path to parse. Can be a string, bytes, or `os.PathLike`.
+        allow_relative: If True, then relative paths are allowed as input, and
+            are resolved relative to the current working directory. If False,
+            an error is raised if the path is not absolute.
+        base_directory: If specified, then relative paths are resolved relative
+            to this directory, instead of the current working directory.
+
+    Returns:
+        A `pathlib.Path` object.
+
+    Raises:
+        ValueError: If `allow_relative` is False and the path is not absolute.
+        ValueError: If `base_directory` is specified and `allow_relative` is
+            False.
+    """
+    if base_directory is not None and not allow_relative:
+        raise ValueError(
+            "If `base_directory` is specified, then `allow_relative` must be True.",
+        )
+
+    parsed_path: pathlib.Path
+    if isinstance(path, pathlib.Path):
+        parsed_path = path
+    elif isinstance(path, str):
+        parsed_path = pathlib.Path(path)
+    elif isinstance(path, bytes):
+        parsed_path = pathlib.Path(path.decode())
+    else:
+        parsed_path = pathlib.Path(str(path))
+
+    if parsed_path.is_absolute():
+        return parsed_path
+    else:
+        if allow_relative:
+            base_directory = base_directory or pathlib.Path.cwd()
+            # relative to current working directory
+            return base_directory / parsed_path
+        else:
+            raise ValueError(f"Path {str(parsed_path)} is not absolute")
+
+
+def parse_optional_path(
+    path: Optional[AnyPath],
+    allow_relative: bool = True,
+    base_directory: Optional[pathlib.Path] = None,
+) -> Optional[pathlib.Path]:
+    """Parse an optional path to a `pathlib.Path` object.
+
+    All resulting paths are resolved, absolute paths. If `allow_relative` is True,
+    then relative paths are allowed as input, and are resolved relative to the
+    current working directory, or relative to `base_directory` if it is
+    specified.
+
+    Args:
+        path: The path to parse. Can be a string, bytes, or `os.PathLike`.
+        allow_relative: If True, then relative paths are allowed as input, and
+            are resolved relative to the current working directory. If False,
+            an error is raised if the path is not absolute.
+        base_directory: If specified, then relative paths are resolved relative
+            to this directory, instead of the current working directory.
+
+    Returns:
+        A `pathlib.Path` object, or None if `path` is None.
+    """
+    if path is None:
+        return None
     else:
-        return str(path)
+        return parse_path(path, allow_relative, base_directory)
 
 
 @dataclasses.dataclass(frozen=True)
@@ -417,10 +493,10 @@ def save(path: AnyPath, trajectories: Sequence[Trajectory]):
         trajectories: The trajectories to save.
 
     Raises:
-        ValueError: If the trajectories are not all of the same type, i.e. some are
+        ValueError: If not all trajectories have the same type, i.e. some are
             `Trajectory` and others are `TrajectoryWithRew`.
     """
-    p = pathlib.Path(path_to_str(path))
+    p = parse_path(path)
     p.parent.mkdir(parents=True, exist_ok=True)
     tmp_path = f"{p}.tmp"
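
To see how the new `parse_path` treats different inputs, the following sketch re-implements its logic in condensed form so it can run without the library. It is for experimentation only: the `AnyPath` alias is a stand-in (its real definition is not shown in the diff) and the error message is shortened. One thing the sketch makes visible is that the function joins relative paths onto a base directory but never calls `Path.resolve()`, so `..` segments are left in place.

```python
import os
import pathlib
from typing import Optional, Union

# Assumption: stand-in for the library's `AnyPath` alias, which the diff does not show.
AnyPath = Union[str, bytes, os.PathLike]


def parse_path(
    path: AnyPath,
    allow_relative: bool = True,
    base_directory: Optional[pathlib.Path] = None,
) -> pathlib.Path:
    """Condensed copy of the logic added in the diff, for experimentation only."""
    if base_directory is not None and not allow_relative:
        raise ValueError("`base_directory` requires `allow_relative=True`.")

    if isinstance(path, pathlib.Path):
        parsed = path
    elif isinstance(path, bytes):
        parsed = pathlib.Path(path.decode())
    else:
        parsed = pathlib.Path(str(path))  # covers str and other os.PathLike objects

    if parsed.is_absolute():
        return parsed
    if not allow_relative:
        raise ValueError(f"Path {parsed} is not absolute")
    return (base_directory or pathlib.Path.cwd()) / parsed


cwd = pathlib.Path.cwd()
print(parse_path("logs/run1") == cwd / "logs" / "run1")  # True
print(parse_path(b"logs/run1") == cwd / "logs" / "run1")  # True
print(parse_path(cwd / "x") == cwd / "x")  # True: absolute input is returned as-is

# ".." is joined, not collapsed; call .resolve() on the result if that matters.
print(parse_path("../sibling").parts[-2:])  # ('..', 'sibling')
```

Passing a relative `base_directory` would likewise produce a relative result, so callers that need a fully resolved path may want to resolve the base first.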

src/imitation/policies/base.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def _choose_action(self, obs: np.ndarray) -> np.ndarray:
     def forward(self, *args):
         # technically BasePolicy is a Torch module, so this needs a forward()
         # method
-        raise NotImplementedError()
+        raise NotImplementedError()  # pragma: no cover
 
 
 class RandomPolicy(HardCodedPolicy):

src/imitation/policies/serialize.py

Lines changed: 8 additions & 8 deletions
@@ -4,13 +4,13 @@
 # torch.load() and torch.save() calls
 
 import logging
-import os
 import pathlib
 from typing import Callable, Type, TypeVar
 
 import huggingface_sb3 as hfsb3
 from stable_baselines3.common import base_class, callbacks, policies, vec_env
 
+from imitation.data import types
 from imitation.policies import base
 from imitation.util import registry
 
@@ -52,7 +52,7 @@ def load_stable_baselines_model(
         The deserialized RL algorithm.
     """
     logging.info(f"Loading Stable Baselines policy for '{cls}' from '{path}'")
-    path_obj = pathlib.Path(path)
+    path_obj = types.parse_path(path)
 
     if path_obj.is_dir():
         path_obj = path_obj / "model.zip"
@@ -181,7 +181,7 @@ def load_policy(
 
 
 def save_stable_model(
-    output_dir: str,
+    output_dir: pathlib.Path,
     model: base_class.BaseAlgorithm,
     filename: str = "model.zip",
 ) -> None:
@@ -197,9 +197,9 @@ def save_stable_model(
     # Save each model in new directory in case we want to add metadata or other
     # information in future. (E.g. we used to save `VecNormalize` statistics here,
     # although that is no longer necessary.)
-    os.makedirs(output_dir, exist_ok=True)
-    model.save(os.path.join(output_dir, filename))
-    logging.info("Saved policy to %s", output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    model.save(output_dir / filename)
+    logging.info(f"Saved policy to {output_dir}")
 
 
 class SavePolicyCallback(callbacks.EventCallback):
@@ -211,7 +211,7 @@ class SavePolicyCallback(callbacks.EventCallback):
 
     def __init__(
         self,
-        policy_dir: str,
+        policy_dir: pathlib.Path,
         *args,
         **kwargs,
     ):
@@ -227,6 +227,6 @@ def __init__(
 
     def _on_step(self) -> bool:
         assert self.model is not None
-        output_dir = os.path.join(self.policy_dir, f"{self.num_timesteps:012d}")
+        output_dir = self.policy_dir / f"{self.num_timesteps:012d}"
         save_stable_model(output_dir, self.model)
         return True
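
`_on_step` names each checkpoint directory with the timestep count zero-padded to 12 digits. The commit does not say why, but one practical effect is that sorting the directory names as strings gives the same order as sorting by timestep. The sketch below shows this with the standard library only; `checkpoint_dir` is a hypothetical helper that mirrors the expression in the diff.

```python
import pathlib


def checkpoint_dir(policy_dir: pathlib.Path, num_timesteps: int) -> pathlib.Path:
    """Build a per-checkpoint directory path with a 12-digit, zero-padded name."""
    return policy_dir / f"{num_timesteps:012d}"


policy_dir = pathlib.Path("policies")
dirs = [checkpoint_dir(policy_dir, n) for n in (2048, 100000, 512)]

# Lexicographic order of the padded names matches numeric order of the timesteps.
print(sorted(d.name for d in dirs))
# ['000000000512', '000000002048', '000000100000']
```

Because `save_stable_model` and `SavePolicyCallback` now annotate their directory arguments as `pathlib.Path` and call `.mkdir` and `/` on them, a caller that still passes a plain `str` would fail at those calls and would need to wrap the value in `pathlib.Path` first.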
