Skip to content

Commit

Permalink
polish(pu): polish comments and resume_training option
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Nov 1, 2024
1 parent b789608 commit 0968723
Show file tree
Hide file tree
Showing 12 changed files with 25 additions and 26 deletions.
5 changes: 2 additions & 3 deletions ding/entry/serial_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ def serial_pipeline(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
resume_training = cfg.policy.learn.get('resume_training', False)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not resume_training)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not cfg.policy.learn.resume_training)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
Expand Down Expand Up @@ -87,7 +86,7 @@ def serial_pipeline(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if resume_training:
if cfg.policy.learn.resume_training:
collector.envstep = learner.collector_envstep

# Accumulate plenty of data at the beginning of training.
Expand Down
7 changes: 3 additions & 4 deletions ding/entry/serial_entry_mbrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def mbrl_entry_setup(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
resume_training = cfg.policy.learn.get('resume_training', False)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not resume_training)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not cfg.policy.learn.resume_training)

if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
Expand Down Expand Up @@ -133,7 +132,7 @@ def serial_pipeline_dyna(
img_buffer = create_img_buffer(cfg, input_cfg, world_model, tb_logger)

learner.call_hook('before_run')
if resume_training:
if cfg.policy.learn.resume_training:
collector.envstep = learner.collector_envstep

if cfg.policy.get('random_collect_size', 0) > 0:
Expand Down Expand Up @@ -206,7 +205,7 @@ def serial_pipeline_dream(
mbrl_entry_setup(input_cfg, seed, env_setting, model)

learner.call_hook('before_run')
if resume_training:
if cfg.policy.learn.resume_training:
collector.envstep = learner.collector_envstep

if cfg.policy.get('random_collect_size', 0) > 0:
Expand Down
5 changes: 2 additions & 3 deletions ding/entry/serial_entry_ngu.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ def serial_pipeline_ngu(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
resume_training = cfg.policy.learn.get('resume_training', False)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not resume_training)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not cfg.policy.learn.resume_training)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
Expand Down Expand Up @@ -90,7 +89,7 @@ def serial_pipeline_ngu(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if resume_training:
if cfg.policy.learn.resume_training:
collector.envstep = learner.collector_envstep

# Accumulate plenty of data at the beginning of training.
Expand Down
5 changes: 2 additions & 3 deletions ding/entry/serial_entry_onpolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def serial_pipeline_onpolicy(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
resume_training = cfg.policy.learn.get('resume_training', False)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not resume_training)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not cfg.policy.learn.resume_training)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
Expand Down Expand Up @@ -81,7 +80,7 @@ def serial_pipeline_onpolicy(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if resume_training:
if cfg.policy.learn.resume_training:
collector.envstep = learner.collector_envstep

while True:
Expand Down
5 changes: 2 additions & 3 deletions ding/entry/serial_entry_onpolicy_ppg.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def serial_pipeline_onpolicy_ppg(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
resume_training = cfg.policy.learn.get('resume_training', False)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not resume_training)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not cfg.policy.learn.resume_training)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
Expand Down Expand Up @@ -81,7 +80,7 @@ def serial_pipeline_onpolicy_ppg(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if resume_training:
if cfg.policy.learn.resume_training:
collector.envstep = learner.collector_envstep

while True:
Expand Down
4 changes: 4 additions & 0 deletions ding/policy/base_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def default_config(cls: type) -> EasyDict:
traj_len_inf=False,
# neural network model config
model=dict(),
# If resume_training is True, the environment step count (collector.envstep) and the training
# iteration (train_iter) will be restored from the saved checkpoint, so that training resumes
# seamlessly from where the checkpoint left off.
learn=dict(resume_training=False),
)

def __init__(
Expand Down
4 changes: 2 additions & 2 deletions ding/worker/collector/battle_episode_serial_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def envstep(self) -> int:
Overview:
Print the total envstep count.
Return:
- envstep (:obj:`int`): the total envstep count
- envstep (:obj:`int`): The total envstep count.
"""
return self._total_envstep_count

Expand All @@ -172,7 +172,7 @@ def envstep(self, value: int) -> None:
Overview:
Set the total envstep count.
Arguments:
- value (:obj:`int`): the total envstep count
- value (:obj:`int`): The total envstep count.
"""
self._total_envstep_count = value

Expand Down
4 changes: 2 additions & 2 deletions ding/worker/collector/battle_sample_serial_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def envstep(self) -> int:
Overview:
Print the total envstep count.
Return:
- envstep (:obj:`int`): the total envstep count
- envstep (:obj:`int`): The total envstep count.
"""
return self._total_envstep_count

Expand All @@ -185,7 +185,7 @@ def envstep(self, value: int) -> None:
Overview:
Set the total envstep count.
Arguments:
- value (:obj:`int`): the total envstep count
- value (:obj:`int`): The total envstep count.
"""
self._total_envstep_count = value

Expand Down
4 changes: 2 additions & 2 deletions ding/worker/collector/episode_serial_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def envstep(self) -> int:
Overview:
Print the total envstep count.
Return:
- envstep (:obj:`int`): the total envstep count
- envstep (:obj:`int`): The total envstep count.
"""
return self._total_envstep_count

Expand All @@ -167,7 +167,7 @@ def envstep(self, value: int) -> None:
Overview:
Set the total envstep count.
Arguments:
- value (:obj:`int`): the total envstep count
- value (:obj:`int`): The total envstep count.
"""
self._total_envstep_count = value

Expand Down
4 changes: 2 additions & 2 deletions ding/worker/collector/sample_serial_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def envstep(self) -> int:
Overview:
Print the total envstep count.
Return:
- envstep (:obj:`int`): the total envstep count
- envstep (:obj:`int`): The total envstep count.
"""
return self._total_envstep_count

Expand All @@ -195,7 +195,7 @@ def envstep(self, value: int) -> None:
Overview:
Set the total envstep count.
Arguments:
- value (:obj:`int`): the total envstep count
- value (:obj:`int`): The total envstep count.
"""
self._total_envstep_count = value

Expand Down
2 changes: 1 addition & 1 deletion ding/worker/learner/learner_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def __call__(self, engine: 'BaseLearner') -> None: # noqa
path = os.path.join(dirname, ckpt_name)
state_dict = engine.policy.state_dict()
state_dict.update({'last_iter': engine.last_iter.val})
state_dict.update({'last_step': engine._collector_envstep})
state_dict.update({'last_step': engine.collector_envstep})
save_file(path, state_dict)
engine.info('{} save ckpt in {}'.format(engine.instance_name, path))

Expand Down
2 changes: 1 addition & 1 deletion dizoo/league_demo/league_demo_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def envstep(self) -> int:
Overview:
Print the total envstep count.
Return:
- envstep (:obj:`int`): the total envstep count
- envstep (:obj:`int`): The total envstep count.
"""
return self._total_envstep_count

Expand Down

0 comments on commit 0968723

Please sign in to comment.