fix(pu): fix qmix's mixer to support image obs
puyuan1996 committed Oct 28, 2024
1 parent 55dc254 commit f946e65
Showing 3 changed files with 35 additions and 19 deletions.
ding/model/template/qmix.py (5 changes: 4 additions & 1 deletion)
@@ -214,7 +214,10 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
         agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
         agent_q_act = agent_q_act.squeeze(-1)  # T, B, A
         if self.mixer:
-            global_state_embedding = self._global_state_encoder(global_state)
+            if len(global_state.shape) == 5:
+                global_state_embedding = self._global_state_encoder(global_state.reshape(-1, *global_state.shape[-3:])).reshape(global_state.shape[0], global_state.shape[1], -1)
+            else:
+                global_state_embedding = self._global_state_encoder(global_state)
             total_q = self._mixer(agent_q_act, global_state_embedding)
         else:
             total_q = agent_q_act.sum(-1)
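A minimal standalone sketch of the shape handling this hunk introduces, with a hypothetical ConvEncoder standing in for QMIX's actual global-state encoder: a 5-dimensional image global state of shape (T, B, C, H, W) is flattened to (T*B, C, H, W) before the 2D convolutions and reshaped back to (T, B, embed_dim) for the mixer.

# Sketch only: ConvEncoder and all shapes below are illustrative assumptions,
# not DI-engine's actual encoder or Pistonball's actual observation sizes.
import torch
import torch.nn as nn

class ConvEncoder(nn.Module):
    """Hypothetical stand-in for a global-state encoder on image inputs."""

    def __init__(self, in_channels: int = 3, embed_dim: int = 64) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(32, embed_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)

if __name__ == "__main__":
    T, B, C, H, W = 16, 4, 3, 84, 84
    encoder = ConvEncoder(in_channels=C, embed_dim=64)
    global_state = torch.rand(T, B, C, H, W)  # 5-dim image observation
    if len(global_state.shape) == 5:
        # Merge T and B so the 2D convs see a standard (N, C, H, W) batch,
        # then split T and B back out for the mixer.
        embedding = encoder(global_state.reshape(-1, *global_state.shape[-3:]))
        embedding = embedding.reshape(global_state.shape[0], global_state.shape[1], -1)
    else:
        embedding = encoder(global_state)
    print(embedding.shape)  # torch.Size([16, 4, 64])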
dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py (9 changes: 3 additions & 6 deletions)
@@ -17,10 +17,7 @@
         evaluator_env_num=evaluator_env_num,
         n_evaluator_episode=evaluator_env_num,
         stop_value=1e6,
-        manager=dict(
-            shared_memory=False,
-            reset_timeout=6000,
-        ),
+        manager=dict(shared_memory=False, ),
         max_env_step=3e6,
     ),
     policy=dict(
@@ -35,14 +32,14 @@
         ),
         learn=dict(
             update_per_collect=100,
-            batch_size=32,
+            batch_size=16,
             learning_rate=0.0005,
             target_update_theta=0.001,
             discount_factor=0.99,
             double_q=True,
         ),
         collect=dict(
-            n_sample=600,
+            n_sample=32,
             unroll_len=16,
             env_num=collector_env_num,
         ),
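One plausible motivation for the smaller batch_size and n_sample with image observations is memory: each stored step now carries per-agent float32 frames plus an agent-specific global state. A rough estimate under assumed shapes (the agent count and frame size below are illustrative, not read from this config):

# Back-of-envelope memory estimate for one unrolled training sample.
import numpy as np

n_agents = 20             # assumed number of pistons (illustrative)
obs_shape = (3, 84, 84)   # assumed per-agent image shape (illustrative)
unroll_len = 16           # from the config above

bytes_per_step = (
    n_agents * np.prod(obs_shape)     # agent_state
    + n_agents * np.prod(obs_shape)   # agent-specific global_state, repeated per agent
) * 4                                 # float32 after the /255.0 normalization

print(f"~{bytes_per_step * unroll_len / 2 ** 20:.1f} MiB per unrolled sample")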
dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py (40 changes: 28 additions & 12 deletions)
@@ -147,21 +147,37 @@ def _process_obs(self, obs: Dict[str, np.ndarray]) -> np.ndarray:
         """
         Processes the observations into the required format.
         """
-        if self._channel_first:
-            obs = np.array([np.transpose(obs[agent], (2, 0, 1)) for agent in self._agents]).astype(np.uint8)
-        else:
-            obs = np.array([obs[agent] for agent in self._agents]).astype(np.uint8)
+        # Process agent observations, transpose if channel_first is True
+        obs = np.array(
+            [np.transpose(obs[agent], (2, 0, 1)) if self._channel_first else obs[agent]
+             for agent in self._agents],
+            dtype=np.uint8
+        )
+
+        # Return only agent observations if configured to do so
         if self._cfg.get('agent_obs_only', False):
             return obs
-        ret = {}
-        ret['agent_state'] = obs
+
+        # Initialize return dictionary
+        ret = {
+            'agent_state': (obs / 255.0).astype(np.float32)
+        }
+
+        # Obtain global state, transpose if channel_first is True
+        global_state = self._env.state()
         if self._channel_first:
-            ret['global_state'] = self._env.state().transpose(2, 0, 1)
-        else:
-            ret['global_state'] = self._env.state()
-        if self._agent_specific_global_state:  # TODO: more elegant way to handle this
-            ret['global_state'] = np.repeat(np.expand_dims(ret['global_state'], axis=0), self._num_pistons, axis=0)
-        ret['action_mask'] = np.ones((self._num_pistons, *self._action_dim)).astype(np.float32)
+            global_state = global_state.transpose(2, 0, 1)
+        ret['global_state'] = (global_state / 255.0).astype(np.float32)
+
+        # Handle agent-specific global states by repeating the global state for each agent
+        if self._agent_specific_global_state:
+            ret['global_state'] = np.tile(
+                np.expand_dims(ret['global_state'], axis=0),
+                (self._num_pistons, 1, 1, 1)
+            )
+
+        # Set action mask for each agent
+        ret['action_mask'] = np.ones((self._num_pistons, *self._action_dim), dtype=np.float32)
+
         return ret
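A standalone sketch (not the env class itself) of what the rewritten _process_obs returns, using made-up frame sizes, agent count, and action dimension: per-agent uint8 frames become a float32 agent_state in [0, 1], the global frame becomes a float32 global_state that is optionally repeated per agent, and action_mask is all ones.

# Illustration only: shapes, num_agents, and action_dim are assumptions.
import numpy as np

def process_obs_sketch(agent_frames, global_frame, num_agents, action_dim=(3, ),
                       channel_first=True, agent_specific_global_state=True):
    # Stack per-agent frames, optionally moving channels first
    obs = np.array(
        [np.transpose(f, (2, 0, 1)) if channel_first else f for f in agent_frames],
        dtype=np.uint8
    )
    ret = {'agent_state': (obs / 255.0).astype(np.float32)}
    # Normalize the global frame the same way
    gs = global_frame.transpose(2, 0, 1) if channel_first else global_frame
    ret['global_state'] = (gs / 255.0).astype(np.float32)
    if agent_specific_global_state:
        # Repeat the global state once per agent
        ret['global_state'] = np.tile(
            np.expand_dims(ret['global_state'], axis=0), (num_agents, 1, 1, 1)
        )
    ret['action_mask'] = np.ones((num_agents, *action_dim), dtype=np.float32)
    return ret

if __name__ == "__main__":
    n = 4
    frames = [np.zeros((120, 120, 3), dtype=np.uint8) for _ in range(n)]
    state = np.zeros((240, 240, 3), dtype=np.uint8)
    out = process_obs_sketch(frames, state, num_agents=n)
    for k, v in out.items():
        print(k, v.shape, v.dtype)
    # agent_state (4, 3, 120, 120) float32
    # global_state (4, 3, 240, 240) float32
    # action_mask (4, 3) float32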
