fix(nyz): fix mock and config bugs
PaParaZz1 committed Dec 12, 2024
1 parent 765b8fb commit aa86aa7
Showing 4 changed files with 8 additions and 8 deletions.
@@ -17,6 +17,8 @@
 def test_serial_pipeline_trex_onpolicy():
     exp_name = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_expert'
     config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
+    config[0].policy.learn.learner = dict()
+    config[0].policy.learn.learner.hook = dict()
     config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
     config[0].exp_name = exp_name
     expert_policy = serial_pipeline_onpolicy(config, seed=0)
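A plausible reading of the two added lines: EasyDict-style configs do not auto-create missing nested keys, so assigning to config[0].policy.learn.learner.hook.save_ckpt_after_iter raises AttributeError when the learner level is absent. A minimal sketch of the failure mode and the fix pattern, assuming (hypothetically) that cartpole_ppo_config omits policy.learn.learner:

from easydict import EasyDict

# Hypothetical stand-in for cartpole_ppo_config; the real config may differ.
config = EasyDict(dict(policy=dict(learn=dict())))

try:
    # Direct assignment breaks while resolving the missing 'learner' level.
    config.policy.learn.learner.hook.save_ckpt_after_iter = 100
except AttributeError as exc:
    print('direct assignment fails:', exc)

# The pattern used in the diff: create each missing level first.
# EasyDict converts assigned plain dicts to EasyDict, so the chain then works.
config.policy.learn.learner = dict()
config.policy.learn.learner.hook = dict()
config.policy.learn.learner.hook.save_ckpt_after_iter = 100
print(config.policy.learn.learner.hook.save_ckpt_after_iter)  # -> 100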
6 changes: 3 additions & 3 deletions ding/framework/middleware/collector.py
@@ -101,7 +101,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
"""
Overview:
An encapsulation of inference and rollout middleware. Stop when completing \
the target number of steps.
the target number of steps.
Input of ctx:
- env_step (:obj:`int`): The env steps which will increase during collection.
"""
@@ -143,7 +143,7 @@ class EpisodeCollector:
"""
Overview:
The class of the collector running by episodes, including model inference and transition \
process. Use the `__call__` method to execute the whole collection process.
process. Use the `__call__` method to execute the whole collection process.
"""

def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size: int = 0) -> None:
@@ -168,7 +168,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
"""
Overview:
An encapsulation of inference and rollout middleware. Stop when completing the \
target number of episodes.
target number of episodes.
Input of ctx:
- env_episode (:obj:`int`): The env env_episode which will increase during collection.
"""
6 changes: 2 additions & 4 deletions ding/framework/middleware/tests/mock_for_test.py
@@ -63,7 +63,7 @@ def process_transition(self, obs: Any, model_output: dict, timestep: namedtuple)
             'logit': 1.0,
             'value': 2.0,
             'reward': 0.1,
-            'done': True,
+            'done': timestep.done,
         }
         return transition
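The done fix is the substantive one: a transition hard-coded as terminal marks every step as the end of an episode, which breaks episode segmentation for anything consuming the mock's trajectories; reading the flag from timestep preserves real boundaries. An illustrative reduction (the namedtuple scaffolding is assumed, not copied from the repo):

from collections import namedtuple

Timestep = namedtuple('Timestep', ['obs', 'reward', 'done'])

def process_transition(obs, model_output, timestep):
    # 'done' now mirrors the env's actual termination signal
    # instead of the previous hard-coded True.
    return {'obs': obs, 'reward': timestep.reward, 'done': timestep.done}

rollout = [Timestep(obs=0, reward=0.1, done=d) for d in (False, False, True)]
print([process_transition(t.obs, {}, t)['done'] for t in rollout])
# [False, False, True] -- only the final transition closes the episode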

@@ -75,7 +75,6 @@ def __init__(self) -> None:
         self.env_num = env_num
         self.obs_dim = obs_dim
         self.closed = False
-        self._reward_grow_indicator = 1
         self._steps = [0 for _ in range(self.env_num)]
 
     @property
@@ -111,11 +110,10 @@ def step(self, actions: tnp.ndarray) -> List[tnp.ndarray]:
                 obs=torch.rand(self.obs_dim),
                 reward=1.0,
                 done=done,
-                info={'eval_episode_return': self._reward_grow_indicator * 1.0} if done else {},
+                info={'eval_episode_return': 10.0} if done else {},
                 env_id=i,
             )
             timesteps.append(tnp.array(timestep))
-        self._reward_grow_indicator += 1  # eval_episode_return will increase as step method is called
         return timesteps
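Dropping _reward_grow_indicator makes the mock's episode return independent of how many times step has been called, so test expectations no longer depend on invocation order. A contrast sketch (illustrative, not repo code):

class OldMock:
    """Return grows with every call: the expected value depends on history."""

    def __init__(self):
        self._reward_grow_indicator = 1

    def episode_return(self):
        value = self._reward_grow_indicator * 1.0
        self._reward_grow_indicator += 1
        return value

class NewMock:
    """Constant return: the same value no matter when it is queried."""

    def episode_return(self):
        return 10.0

old, new = OldMock(), NewMock()
print([old.episode_return() for _ in range(3)])  # [1.0, 2.0, 3.0]
print([new.episode_return() for _ in range(3)])  # [10.0, 10.0, 10.0]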


2 changes: 1 addition & 1 deletion ding/framework/middleware/tests/test_evaluator.py
@@ -24,4 +24,4 @@ def test_interaction_evaluator():
     # there are 2 env_num and 5 episodes in the test.
     # so when interaction_evaluator runs the first time, reward is [[1, 2, 3], [2, 3]] and the avg = 2.2
     # the second time, reward is [[4, 5, 6], [5, 6]] . . .
-    assert ctx.eval_value == 2.2 + i // 10 * 3.0
+    assert ctx.eval_value == 10.0
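With the mock now returning a constant, the evaluator's average is exact on every invocation, which is why the loop-dependent formula could be dropped (the surrounding comment still describes the old growing-reward mock). A back-of-the-envelope check, taking the 5-episode count from that comment:

n_episode = 5
returns = [10.0] * n_episode           # every mock episode now ends with 10.0
eval_value = sum(returns) / n_episode
assert eval_value == 10.0              # matches the updated assertion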
