Skip to content

Commit

Permalink
feature(pu): add pong and cartpole ddp config of dqn and onppo
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Nov 28, 2024
1 parent de9ada0 commit da88e02
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 3 deletions.
4 changes: 2 additions & 2 deletions ding/entry/serial_entry_onpolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
from ding.reward_model import create_reward_model
from ding.utils import set_pkg_seed
from ding.utils import set_pkg_seed, get_rank


def serial_pipeline_onpolicy(
Expand Down Expand Up @@ -68,7 +68,7 @@ def serial_pipeline_onpolicy(
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])

# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = create_serial_collector(
cfg.policy.collect.collector,
Expand Down
6 changes: 5 additions & 1 deletion ding/worker/collector/interaction_serial_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ def eval(
'''
# evaluator only work on rank0
stop_flag = False
episode_info = None # Initialize to ensure it's defined in all ranks

if get_rank() == 0:
if n_episode is None:
n_episode = self._default_n_episode
Expand Down Expand Up @@ -317,5 +319,7 @@ def eval(
broadcast_object_list(objects, src=0)
stop_flag, episode_info = objects

episode_info = to_item(episode_info)
# Ensure episode_info is converted to the correct format
episode_info = to_item(episode_info) if episode_info is not None else {}

return stop_flag, episode_info
67 changes: 67 additions & 0 deletions dizoo/atari/config/serial/pong/pong_dqn_ddp_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from easydict import EasyDict

# Main (hyper-parameter) config for DQN on Atari Pong with multi-GPU training switched on
# (``policy.multi_gpu=True``). ``main_config`` and ``pong_dqn_config`` name the same EasyDict.
main_config = pong_dqn_config = EasyDict({
    'exp_name': 'data_pong/pong_dqn_ddp_seed0',
    'env': {
        'collector_env_num': 4,
        'evaluator_env_num': 4,
        'n_evaluator_episode': 8,
        'stop_value': 20,
        # 'ALE/Pong-v5' is also available, but it needs special settings after gym make.
        'env_id': 'PongNoFrameskip-v4',
        'frame_stack': 4,
    },
    'policy': {
        'multi_gpu': True,
        'cuda': True,
        'priority': False,
        'model': {
            # Four stacked frames (see ``env.frame_stack``) of size 84x84.
            'obs_shape': [4, 84, 84],
            'action_shape': 6,
            'encoder_hidden_size_list': [128, 128, 512],
        },
        'nstep': 3,
        'discount_factor': 0.99,
        'learn': {
            'update_per_collect': 10,
            'batch_size': 32,
            'learning_rate': 0.0001,
            'target_update_freq': 500,
        },
        'collect': {'n_sample': 96},
        'eval': {'evaluator': {'eval_freq': 4000}},
        'other': {
            # Exploration epsilon schedule of type 'exp', going from ``start`` to ``end``.
            'eps': {
                'type': 'exp',
                'start': 1.,
                'end': 0.05,
                'decay': 250000,
            },
            'replay_buffer': {'replay_buffer_size': 100000},
        },
    },
})
# Creation config: names the env type (plus the module to import for it), the env manager
# type and the policy type. ``create_config`` and ``pong_dqn_create_config`` are the same object.
create_config = pong_dqn_create_config = EasyDict({
    'env': {
        'type': 'atari',
        'import_names': ['dizoo.atari.envs.atari_env'],
    },
    'env_manager': {'type': 'subprocess'},
    'policy': {'type': 'dqn'},
})

if __name__ == '__main__':
    # Overview:
    #     This script should be executed with <nproc_per_node> GPUs.
    #     Run the following command to launch the script:
    #     python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 ./dizoo/atari/config/serial/pong/pong_dqn_ddp_config.py
    from ding.utils import DDPContext
    from ding.entry import serial_pipeline

    # The whole training pipeline runs inside the DDP context manager.
    with DDPContext():
        serial_pipeline(
            (main_config, create_config),
            seed=0,
            max_env_step=int(3e6),
        )
76 changes: 76 additions & 0 deletions dizoo/atari/config/serial/pong/pong_onppo_ddp_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from easydict import EasyDict

# Main (hyper-parameter) config for on-policy PPO on Atari Pong with multi-GPU training
# switched on (``policy.multi_gpu=True``). ``pong_onppo_config`` stays a plain dict;
# ``main_config`` is the EasyDict built from it.
pong_onppo_config = {
    'exp_name': 'data_pong/pong_onppo_ddp_seed0',
    'env': {
        'collector_env_num': 8,
        'evaluator_env_num': 8,
        'n_evaluator_episode': 8,
        'stop_value': 20,
        # 'ALE/Pong-v5' is also available, but it needs special settings after gym make.
        'env_id': 'PongNoFrameskip-v4',
        'frame_stack': 4,
    },
    'policy': {
        'multi_gpu': True,
        'cuda': True,
        'recompute_adv': True,
        'action_space': 'discrete',
        'model': {
            'obs_shape': [4, 84, 84],
            'action_shape': 6,
            'action_space': 'discrete',
            'encoder_hidden_size_list': [64, 64, 128],
            'actor_head_hidden_size': 128,
            'critic_head_hidden_size': 128,
        },
        'learn': {
            'epoch_per_collect': 10,
            'update_per_collect': 1,
            'batch_size': 320,
            'learning_rate': 3e-4,
            'value_weight': 0.5,
            'entropy_weight': 0.001,
            'clip_ratio': 0.2,
            'adv_norm': True,
            'value_norm': True,
            # For on-policy PPO with recomputed advantages, the ``done`` key in the data is
            # needed to split trajectories, so ignore_done must be False here. Only if a
            # ``traj_flag`` key is added to the data as a backup for ``done`` could
            # ignore_done=True be used.
            'ignore_done': False,
            'grad_clip_type': 'clip_norm',
            'grad_clip_value': 0.5,
        },
        'collect': {
            'n_sample': 3200,
            'unroll_len': 1,
            'discount_factor': 0.99,
            'gae_lambda': 0.95,
        },
        'eval': {'evaluator': {'eval_freq': 1000}},
    },
}
main_config = EasyDict(pong_onppo_config)

# Creation config: names the env type (plus the module to import for it), the env manager
# type and the policy type. The raw dict is kept under ``pong_onppo_create_config``.
pong_onppo_create_config = {
    'env': {
        'type': 'atari',
        'import_names': ['dizoo.atari.envs.atari_env'],
    },
    'env_manager': {'type': 'subprocess'},
    'policy': {'type': 'ppo'},
}
create_config = EasyDict(pong_onppo_create_config)

if __name__ == "__main__":
    # Overview:
    #     This script should be executed with <nproc_per_node> GPUs.
    #     Run the following command to launch the script:
    #     python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 ./dizoo/atari/config/serial/pong/pong_onppo_ddp_config.py
    from ding.utils import DDPContext
    from ding.entry import serial_pipeline_onpolicy

    # The whole on-policy training pipeline runs inside the DDP context manager.
    with DDPContext():
        serial_pipeline_onpolicy(
            (main_config, create_config),
            seed=0,
            max_env_step=int(3e6),
        )
66 changes: 66 additions & 0 deletions dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from easydict import EasyDict

# Main (hyper-parameter) config for DQN on CartPole with multi-GPU training switched on
# (``policy.multi_gpu=True``).
# NOTE: ``exp_name`` and the replay video path under it carry a ``_ddp`` tag, following the
# naming used by the Pong DDP configs, so that a DDP run gets its own experiment directory
# instead of sharing ``cartpole_dqn_seed0``.
cartpole_dqn_config = dict(
    exp_name='cartpole_dqn_ddp_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=5,
        n_evaluator_episode=5,
        stop_value=195,
        # Kept under the experiment directory named by ``exp_name``.
        replay_path='cartpole_dqn_ddp_seed0/video',
    ),
    policy=dict(
        multi_gpu=True,
        cuda=True,
        model=dict(
            obs_shape=4,
            action_shape=2,
            encoder_hidden_size_list=[128, 128, 64],
            dueling=True,
            # dropout is left unset here (a ``dropout=0.1`` entry was previously commented out).
        ),
        nstep=1,
        discount_factor=0.97,
        learn=dict(
            update_per_collect=5,
            batch_size=64,
            learning_rate=0.001,
        ),
        collect=dict(n_sample=8),
        eval=dict(evaluator=dict(eval_freq=40, )),
        other=dict(
            # Exploration epsilon schedule of type 'exp', going from ``start`` to ``end``.
            eps=dict(
                type='exp',
                start=0.95,
                end=0.1,
                decay=10000,
            ),
            replay_buffer=dict(replay_buffer_size=20000, ),
        ),
    ),
)
cartpole_dqn_config = EasyDict(cartpole_dqn_config)
main_config = cartpole_dqn_config
# Creation config: names the env type (plus the module to import for it), the env manager
# type and the policy type. ``create_config`` and ``cartpole_dqn_create_config`` are the
# same object.
create_config = cartpole_dqn_create_config = EasyDict({
    'env': {
        'type': 'cartpole',
        'import_names': ['dizoo.classic_control.cartpole.envs.cartpole_env'],
    },
    'env_manager': {'type': 'subprocess'},
    'policy': {'type': 'dqn'},
})

if __name__ == "__main__":
    # Overview:
    #     This script should be executed with <nproc_per_node> GPUs.
    #     Run the following command to launch the script:
    #     python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 ./dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py
    from ding.utils import DDPContext
    from ding.entry import serial_pipeline

    # The whole training pipeline runs inside the DDP context manager.
    with DDPContext():
        serial_pipeline(
            (main_config, create_config),
            seed=0,
        )

64 changes: 64 additions & 0 deletions dizoo/classic_control/cartpole/config/cartpole_ppo_ddp_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from easydict import EasyDict

# Main (hyper-parameter) config for on-policy PPO on CartPole with multi-GPU training
# switched on (``policy.multi_gpu=True``).
# NOTE: ``exp_name`` carries a ``_ddp`` tag, following the naming used by the Pong DDP
# configs, so that a DDP run gets its own experiment directory instead of sharing
# ``cartpole_ppo_seed0``.
cartpole_ppo_config = dict(
    exp_name='cartpole_ppo_ddp_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=5,
        n_evaluator_episode=5,
        stop_value=195,
    ),
    policy=dict(
        multi_gpu=True,
        cuda=True,
        action_space='discrete',
        model=dict(
            obs_shape=4,
            action_shape=2,
            action_space='discrete',
            encoder_hidden_size_list=[64, 64, 128],
            critic_head_hidden_size=128,
            actor_head_hidden_size=128,
        ),
        learn=dict(
            epoch_per_collect=2,
            batch_size=64,
            learning_rate=0.001,
            value_weight=0.5,
            entropy_weight=0.01,
            clip_ratio=0.2,
            # Learner hook setting for checkpoint saving.
            learner=dict(hook=dict(save_ckpt_after_iter=100)),
        ),
        collect=dict(
            n_sample=256,
            unroll_len=1,
            discount_factor=0.9,
            gae_lambda=0.95,
        ),
        eval=dict(evaluator=dict(eval_freq=100, ), ),
    ),
)
cartpole_ppo_config = EasyDict(cartpole_ppo_config)
main_config = cartpole_ppo_config
# Creation config: names the env type (plus the module to import for it), the env manager
# type ('base' here) and the policy type. ``create_config`` and
# ``cartpole_ppo_create_config`` are the same object.
create_config = cartpole_ppo_create_config = EasyDict({
    'env': {
        'type': 'cartpole',
        'import_names': ['dizoo.classic_control.cartpole.envs.cartpole_env'],
    },
    'env_manager': {'type': 'base'},
    'policy': {'type': 'ppo'},
})

if __name__ == "__main__":
    # Overview:
    #     This script should be executed with <nproc_per_node> GPUs.
    #     Run the following command to launch the script:
    #     python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 ./dizoo/classic_control/cartpole/config/cartpole_ppo_ddp_config.py
    from ding.utils import DDPContext
    from ding.entry import serial_pipeline_onpolicy

    # The whole on-policy training pipeline runs inside the DDP context manager.
    with DDPContext():
        serial_pipeline_onpolicy(
            (main_config, create_config),
            seed=0,
        )

0 comments on commit da88e02

Please sign in to comment.