-
Notifications
You must be signed in to change notification settings - Fork 388
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature(pu): add pong and cartpole ddp config of dqn and onppo
- Loading branch information
1 parent
de9ada0
commit da88e02
Showing
6 changed files
with
280 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from easydict import EasyDict | ||
|
||
# Main config for DQN on Atari Pong. The __main__ guard of this file passes it
# to serial_pipeline inside a DDPContext.
pong_dqn_config = dict(
    exp_name='data_pong/pong_dqn_ddp_seed0',
    env=dict(
        collector_env_num=4,
        evaluator_env_num=4,
        n_evaluator_episode=8,
        stop_value=20,
        env_id='PongNoFrameskip-v4',
        # 'ALE/Pong-v5' is also available, but it needs special settings after gym.make.
        frame_stack=4,
    ),
    policy=dict(
        multi_gpu=True,
        cuda=True,
        priority=False,
        model=dict(
            obs_shape=[4, 84, 84],
            action_shape=6,
            encoder_hidden_size_list=[128, 128, 512],
        ),
        nstep=3,
        discount_factor=0.99,
        learn=dict(
            update_per_collect=10,
            batch_size=32,
            learning_rate=0.0001,
            target_update_freq=500,
        ),
        collect=dict(n_sample=96, ),
        eval=dict(evaluator=dict(eval_freq=4000, )),
        other=dict(
            # Exploration epsilon schedule: type 'exp', from `start` to `end`.
            eps=dict(
                type='exp',
                start=1.,
                end=0.05,
                decay=250000,
            ),
            replay_buffer=dict(replay_buffer_size=100000, ),
        ),
    ),
)
# Wrap the plain dict in EasyDict and expose it under the generic name
# `main_config`, which the __main__ guard of this file uses.
pong_dqn_config = EasyDict(pong_dqn_config)
main_config = pong_dqn_config
# Names the env, env manager and policy types, plus the module to import for the env.
pong_dqn_create_config = dict(
    env=dict(
        type='atari',
        import_names=['dizoo.atari.envs.atari_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(type='dqn'),
)
pong_dqn_create_config = EasyDict(pong_dqn_create_config)
create_config = pong_dqn_create_config
|
||
if __name__ == '__main__':
    """
    Overview:
        This script should be executed with <nproc_per_node> GPUs.
        Run the following command to launch the script:
        python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 ./dizoo/atari/config/serial/pong/pong_dqn_ddp_config.py
    """
    # Imported here so that importing this module for its configs does not
    # pull in the training entry points.
    from ding.utils import DDPContext
    from ding.entry import serial_pipeline
    # The whole pipeline runs inside the DDP context manager.
    with DDPContext():
        serial_pipeline((main_config, create_config), seed=0, max_env_step=int(3e6))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from easydict import EasyDict | ||
|
||
# Main config for on-policy PPO on Atari Pong. The __main__ guard of this file
# passes it to serial_pipeline_onpolicy inside a DDPContext.
pong_onppo_config = dict(
    exp_name='data_pong/pong_onppo_ddp_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=8,
        n_evaluator_episode=8,
        stop_value=20,
        env_id='PongNoFrameskip-v4',
        # 'ALE/Pong-v5' is also available, but it needs special settings after gym.make.
        frame_stack=4,
    ),
    policy=dict(
        multi_gpu=True,
        cuda=True,
        recompute_adv=True,
        action_space='discrete',
        model=dict(
            obs_shape=[4, 84, 84],
            action_shape=6,
            action_space='discrete',
            encoder_hidden_size_list=[64, 64, 128],
            actor_head_hidden_size=128,
            critic_head_hidden_size=128,
        ),
        learn=dict(
            epoch_per_collect=10,
            update_per_collect=1,
            batch_size=320,
            learning_rate=3e-4,
            value_weight=0.5,
            entropy_weight=0.001,
            clip_ratio=0.2,
            adv_norm=True,
            value_norm=True,
            # For on-policy PPO, when the advantage is recomputed, the `done` key in the data is needed
            # to split trajectories, so ignore_done must be False here. If a `traj_flag` key is added to
            # the data as a backup for `done`, ignore_done=True could be used instead.
            ignore_done=False,
            grad_clip_type='clip_norm',
            grad_clip_value=0.5,
        ),
        collect=dict(
            n_sample=3200,
            unroll_len=1,
            discount_factor=0.99,
            gae_lambda=0.95,
        ),
        eval=dict(evaluator=dict(eval_freq=1000, )),
    ),
)
# EasyDict view of the main config, under the generic name used by the __main__ guard.
main_config = EasyDict(pong_onppo_config)

# Names the env, env manager and policy types, plus the module to import for the env.
pong_onppo_create_config = dict(
    env=dict(
        type='atari',
        import_names=['dizoo.atari.envs.atari_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(type='ppo'),
)
create_config = EasyDict(pong_onppo_create_config)
|
||
if __name__ == "__main__":
    """
    Overview:
        This script should be executed with <nproc_per_node> GPUs.
        Run the following command to launch the script:
        python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 ./dizoo/atari/config/serial/pong/pong_onppo_ddp_config.py
    """
    # Imported here so that importing this module for its configs does not
    # pull in the training entry points.
    from ding.utils import DDPContext
    from ding.entry import serial_pipeline_onpolicy
    # The whole pipeline runs inside the DDP context manager.
    with DDPContext():
        serial_pipeline_onpolicy((main_config, create_config), seed=0, max_env_step=int(3e6))
66 changes: 66 additions & 0 deletions
66
dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from easydict import EasyDict | ||
|
||
# Main config for DQN on CartPole. The __main__ guard of this file passes it
# to serial_pipeline inside a DDPContext.
cartpole_dqn_config = dict(
    # The '_ddp_' infix follows the naming of the Pong DDP configs ('..._ddp_seed0')
    # and keeps this run distinct from the non-DDP 'cartpole_dqn_seed0' name.
    exp_name='cartpole_dqn_ddp_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=5,
        n_evaluator_episode=5,
        stop_value=195,
        # Kept under the same prefix as exp_name.
        replay_path='cartpole_dqn_ddp_seed0/video',
    ),
    policy=dict(
        multi_gpu=True,
        cuda=True,
        model=dict(
            obs_shape=4,
            action_shape=2,
            encoder_hidden_size_list=[128, 128, 64],
            dueling=True,
            # dropout=0.1,
        ),
        nstep=1,
        discount_factor=0.97,
        learn=dict(
            update_per_collect=5,
            batch_size=64,
            learning_rate=0.001,
        ),
        collect=dict(n_sample=8),
        eval=dict(evaluator=dict(eval_freq=40, )),
        other=dict(
            # Exploration epsilon schedule: type 'exp', from `start` to `end`.
            eps=dict(
                type='exp',
                start=0.95,
                end=0.1,
                decay=10000,
            ),
            replay_buffer=dict(replay_buffer_size=20000, ),
        ),
    ),
)
# Wrap the plain dict in EasyDict and expose it under the generic name
# `main_config`, which the __main__ guard of this file uses.
cartpole_dqn_config = EasyDict(cartpole_dqn_config)
main_config = cartpole_dqn_config
# Names the env, env manager and policy types, plus the module to import for the env.
cartpole_dqn_create_config = dict(
    env=dict(
        type='cartpole',
        import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(type='dqn'),
)
cartpole_dqn_create_config = EasyDict(cartpole_dqn_create_config)
create_config = cartpole_dqn_create_config
|
||
if __name__ == "__main__":
    """
    Overview:
        This script should be executed with <nproc_per_node> GPUs.
        Run the following command to launch the script:
        python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 ./dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py
    """
    # Imported here so that importing this module for its configs does not
    # pull in the training entry points.
    from ding.utils import DDPContext
    from ding.entry import serial_pipeline
    # The whole pipeline runs inside the DDP context manager.
    with DDPContext():
        serial_pipeline((main_config, create_config), seed=0)
|
64 changes: 64 additions & 0 deletions
64
dizoo/classic_control/cartpole/config/cartpole_ppo_ddp_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from easydict import EasyDict | ||
|
||
# Main config for on-policy PPO on CartPole. The __main__ guard of this file
# passes it to serial_pipeline_onpolicy inside a DDPContext.
cartpole_ppo_config = dict(
    # The '_ddp_' infix follows the naming of the Pong DDP configs ('..._ddp_seed0')
    # and keeps this run distinct from the non-DDP 'cartpole_ppo_seed0' name.
    exp_name='cartpole_ppo_ddp_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=5,
        n_evaluator_episode=5,
        stop_value=195,
    ),
    policy=dict(
        multi_gpu=True,
        cuda=True,
        action_space='discrete',
        model=dict(
            obs_shape=4,
            action_shape=2,
            action_space='discrete',
            encoder_hidden_size_list=[64, 64, 128],
            critic_head_hidden_size=128,
            actor_head_hidden_size=128,
        ),
        learn=dict(
            epoch_per_collect=2,
            batch_size=64,
            learning_rate=0.001,
            value_weight=0.5,
            entropy_weight=0.01,
            clip_ratio=0.2,
            learner=dict(hook=dict(save_ckpt_after_iter=100)),
        ),
        collect=dict(
            n_sample=256,
            unroll_len=1,
            discount_factor=0.9,
            gae_lambda=0.95,
        ),
        eval=dict(evaluator=dict(eval_freq=100, ), ),
    ),
)
# Wrap the plain dict in EasyDict and expose it under the generic name
# `main_config`, which the __main__ guard of this file uses.
cartpole_ppo_config = EasyDict(cartpole_ppo_config)
main_config = cartpole_ppo_config
# Names the env, env manager and policy types, plus the module to import for the env.
cartpole_ppo_create_config = dict(
    env=dict(
        type='cartpole',
        import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
    ),
    # NOTE(review): the other DDP configs in this commit use type='subprocess';
    # confirm that 'base' is intended here.
    env_manager=dict(type='base'),
    policy=dict(type='ppo'),
)
cartpole_ppo_create_config = EasyDict(cartpole_ppo_create_config)
create_config = cartpole_ppo_create_config
|
||
if __name__ == "__main__":
    """
    Overview:
        This script should be executed with <nproc_per_node> GPUs.
        Run the following command to launch the script:
        python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 ./dizoo/classic_control/cartpole/config/cartpole_ppo_ddp_config.py
    """
    # Imported here so that importing this module for its configs does not
    # pull in the training entry points.
    from ding.utils import DDPContext
    from ding.entry import serial_pipeline_onpolicy
    # The whole pipeline runs inside the DDP context manager.
    with DDPContext():
        serial_pipeline_onpolicy((main_config, create_config), seed=0)