Skip to content

Commit

Permalink
fix(pu):fix atari_ppo_ddp.py
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Dec 10, 2024
1 parent da88e02 commit 6cc337d
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 11 deletions.
2 changes: 1 addition & 1 deletion ding/worker/collector/sample_serial_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ding.envs import BaseEnvManager
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning, get_rank, get_world_size, \
broadcast_object_list, allreduce_data
allreduce_data
from ding.torch_utils import to_tensor, to_ndarray
from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from easydict import EasyDict

pong_onppo_config = dict(
pong_ppo_config = dict(
env=dict(
collector_env_num=8,
evaluator_env_num=8,
Expand Down Expand Up @@ -49,19 +49,19 @@
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
main_config = EasyDict(pong_onppo_config)
main_config = EasyDict(pong_ppo_config)

pong_onppo_create_config = dict(
pong_ppo_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='ppo'),
)
create_config = EasyDict(pong_onppo_create_config)
create_config = EasyDict(pong_ppo_create_config)

if __name__ == "__main__":
# or you can enter `ding -m serial_onpolicy -c pong_onppo_config.py -s 0`
# or you can enter `ding -m serial_onpolicy -c pong_ppo_config.py -s 0`
from ding.entry import serial_pipeline_onpolicy
serial_pipeline_onpolicy((main_config, create_config), seed=0)
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 ./dizoo/atari/config/serial/pong/pong_onppo_ddp_config.py
python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 ./dizoo/atari/config/serial/pong/pong_ppo_ddp_config.py
"""
from ding.utils import DDPContext
from ding.entry import serial_pipeline_onpolicy
Expand Down
2 changes: 1 addition & 1 deletion dizoo/atari/example/atari_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
gae_estimator, termination_checker
from ding.utils import set_pkg_seed
from dizoo.atari.envs.atari_env import AtariEnv
from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config
from dizoo.atari.config.serial.pong.pong_ppo_config import main_config, create_config


def main():
Expand Down
12 changes: 9 additions & 3 deletions dizoo/atari/example/atari_ppo_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
gae_estimator, ddp_termination_checker, online_logger
from ding.utils import set_pkg_seed, DistContext, get_rank, get_world_size
from ding.utils import set_pkg_seed, DDPContext, get_rank, get_world_size
from dizoo.atari.envs.atari_env import AtariEnv
from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config
from dizoo.atari.config.serial.pong.pong_ppo_config import main_config, create_config


def main():
logging.getLogger().setLevel(logging.INFO)
with DistContext():
with DDPContext():
rank, world_size = get_rank(), get_world_size()
main_config.example = 'pong_ppo_seed0_ddp_avgsplit'
main_config.policy.multi_gpu = True
Expand Down Expand Up @@ -53,4 +53,10 @@ def main():


if __name__ == "__main__":
"""
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 ./dizoo/atari/example/atari_ppo_ddp.py
"""
main()

0 comments on commit 6cc337d

Please sign in to comment.