From 9e63ef12b59a4e82ed90e6486788db5a3e8f33e8 Mon Sep 17 00:00:00 2001 From: Ruoyu Gao Date: Thu, 4 May 2023 21:30:09 -0400 Subject: [PATCH] fix style for drex unittest --- ding/entry/tests/test_serial_entry_drex.py | 46 ++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 ding/entry/tests/test_serial_entry_drex.py diff --git a/ding/entry/tests/test_serial_entry_drex.py b/ding/entry/tests/test_serial_entry_drex.py new file mode 100644 index 0000000000..bded23a97f --- /dev/null +++ b/ding/entry/tests/test_serial_entry_drex.py @@ -0,0 +1,46 @@ +import pytest +import os +from easydict import EasyDict +from copy import deepcopy + +from dizoo.classic_control.cartpole.config.cartpole_dqn_config \ +import cartpole_dqn_config, cartpole_dqn_create_config +from dizoo.classic_control.cartpole.config.cartpole_drex_dqn_config \ +import cartpole_drex_dqn_config, cartpole_drex_dqn_create_config +from ding.entry import serial_pipeline, serial_pipeline_reward_model_offpolicy +from ding.entry.application_entry_drex_collect_data import drex_collecting_data + + +@pytest.mark.unittest +def test_drex(): + exp_name = 'test_serial_pipeline_drex_expert' + config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)] + config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100 + config[0].exp_name = exp_name + expert_policy = serial_pipeline(config, seed=0) + + exp_name = 'test_serial_pipeline_drex_collect' + config = [deepcopy(cartpole_drex_dqn_config), deepcopy(cartpole_drex_dqn_create_config)] + config[0].exp_name = exp_name + config[0].reward_model.exp_name = exp_name + config[0].reward_model.expert_model_path = 'test_serial_pipeline_drex_expert/ckpt/ckpt_best.pth.tar' + config[0].reward_model.reward_model_path = 'test_serial_pipeline_drex_collect/cartpole.params' + config[0].reward_model.offline_data_path = 'test_serial_pipeline_drex_collect' + config[0].reward_model.checkpoint_max = 100 + config[0].reward_model.checkpoint_step = 100 + config[0].reward_model.num_snippets = 100 + + args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'}) + args.cfg[0].policy.collect.n_episode = 8 + del args.cfg[0].policy.collect.n_sample + args.cfg[0].bc_iteration = 1000 # for unittest + args.cfg[1].policy.type = 'bc' + drex_collecting_data(args=args) + try: + serial_pipeline_reward_model_offpolicy( + config, seed=0, max_train_iter=1, pretrain_reward=True, cooptrain_reward=False + ) + except Exception: + assert False, "pipeline fail" + finally: + os.popen('rm -rf test_serial_pipeline_drex*')