From e91684197cfcd6732427150277f74de95cbf3cd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Tue, 22 Oct 2024 17:15:05 +0800 Subject: [PATCH] polish(pu): pistonball reuse PTZRecordVideo --- .../envs/petting_zoo_pistonball_env.py | 53 +++---------------- .../envs/petting_zoo_simple_spread_env.py | 10 ++-- 2 files changed, 9 insertions(+), 54 deletions(-) diff --git a/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py b/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py index 4e456db710..2238b85c53 100644 --- a/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py +++ b/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py @@ -1,57 +1,16 @@ -from typing import Any, List, Union, Optional, Dict -import gymnasium as gym -import numpy as np from functools import reduce +from typing import List, Optional, Dict -from ding.envs import BaseEnv, BaseEnvTimestep, FrameStackWrapper -from ding.torch_utils import to_ndarray, to_list +import gymnasium as gym +import numpy as np +from ding.envs import BaseEnv, BaseEnvTimestep from ding.envs.common.common_function import affine_transform +from ding.torch_utils import to_ndarray from ding.utils import ENV_REGISTRY -from pettingzoo.utils.conversions import parallel_wrapper_fn +from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PTZRecordVideo from pettingzoo.butterfly import pistonball_v6 -# Custom wrapper for recording videos in PettingZoo environments -class PTZRecordVideo(gym.wrappers.RecordVideo): - def step(self, action): - """ - Custom step function for handling PettingZoo environments - with gymnasium's RecordVideo wrapper. - """ - observations, rewards, terminateds, truncateds, infos = self.env.step(action) - - # Check if any agent has terminated or truncated - if not (self.terminated is True or self.truncated is True): - self.step_id += 1 - if not self.is_vector_env: - if terminateds or truncateds: - self.episode_id += 1 - self.terminated = terminateds - self.truncated = truncateds - elif terminateds[0] or truncateds[0]: - self.episode_id += 1 - self.terminated = terminateds[0] - self.truncated = truncateds[0] - - # Capture the video frame if recording - if self.recording: - assert self.video_recorder is not None - self.video_recorder.capture_frame() - self.recorded_frames += 1 - if self.video_length > 0 and self.recorded_frames > self.video_length: - self.close_video_recorder() - elif not self.is_vector_env: - if terminateds is True or truncateds is True: - self.close_video_recorder() - elif terminateds[0] or truncateds[0]: - self.close_video_recorder() - - elif self._video_enabled(): - self.start_video_recorder() - - return observations, rewards, terminateds, truncateds, infos - - @ENV_REGISTRY.register('petting_zoo_pistonball') class PettingZooPistonballEnv(BaseEnv): """ diff --git a/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py index 10c642026d..186f41e9e2 100644 --- a/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py +++ b/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -13,17 +13,12 @@ from pettingzoo.mpe.simple_spread.simple_spread import Scenario +# Custom wrapper for recording videos in PettingZoo environments class PTZRecordVideo(gym.wrappers.RecordVideo): def step(self, action): """Steps through the environment using action, recording observations if :attr:`self.recording`.""" # gymnasium==0.27.1 - ( - observations, - rewards, - terminateds, - truncateds, - infos, - ) = self.env.step(action) + observations, rewards, terminateds, truncateds, infos = self.env.step(action) # Because pettingzoo returns a dict of terminated and truncated, we need to check if any of the values are True if not (self.terminated is True or self.truncated is True): # the first location for modifications @@ -39,6 +34,7 @@ def step(self, action): self.terminated = terminateds[0] self.truncated = truncateds[0] + # Capture the video frame if recording if self.recording: assert self.video_recorder is not None self.video_recorder.capture_frame()