fix(nyz): fix multiple model wrappers reset bug (#846)
PaParaZz1 committed Dec 5, 2024
1 parent e93b5a6 commit 5615816
Showing 2 changed files with 44 additions and 1 deletion.
4 changes: 3 additions & 1 deletion ding/model/wrapper/model_wrappers.py
@@ -73,7 +73,9 @@ def reset(self, data_id: List[int] = None, **kwargs) -> None:
             model wrappers often need to maintain some stateful variables for each data trajectory, \
             so we leave this ``data_id`` argument to reset the stateful variables of the indicated data.
         """
-        pass
+        # This is necessary when multiple model wrappers are nested.
+        if hasattr(self._model, 'reset'):
+            return self._model.reset(data_id=data_id, **kwargs)
 
     def forward(self, *args, **kwargs) -> Any:
         """
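Why the one-line delegation matters: callers invoke reset() on the outermost wrapper, and the old bare pass in the base class silently dropped the call before it could reach a stateful inner wrapper such as the hidden-state one. A minimal sketch of the delegation pattern, using toy stand-ins rather than DI-engine's actual IModelWrapper/HiddenStateWrapper classes:

# Illustrative stand-ins, not the real DI-engine wrapper API.
class BaseWrapper:

    def __init__(self, model):
        self._model = model

    def reset(self, data_id=None, **kwargs):
        # The fixed behavior: forward reset() inward instead of `pass`,
        # so a stateless outer wrapper no longer swallows the call.
        if hasattr(self._model, 'reset'):
            return self._model.reset(data_id=data_id, **kwargs)

class StatefulWrapper(BaseWrapper):

    def __init__(self, model, state_num):
        super().__init__(model)
        self._state = {i: None for i in range(state_num)}

    def reset(self, data_id=None, **kwargs):
        # Clear either all trajectories or just the indicated ones.
        for i in (range(len(self._state)) if data_id is None else data_id):
            self._state[i] = None

inner = StatefulWrapper(model=object(), state_num=4)
inner._state = {i: {'h': i} for i in range(4)}
outer = BaseWrapper(inner)
outer.reset(data_id=[0, 2])  # now reaches the inner wrapper's state
assert inner._state[0] is None and inner._state[1] == {'h': 1}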
41 changes: 41 additions & 0 deletions ding/model/wrapper/test_model_wrappers.py
@@ -195,6 +195,20 @@ def forward(self, data):
         return {'output': output, 'next_state': next_state}
 
 
+class TempLSTMActor(torch.nn.Module):
+
+    def __init__(self):
+        super(TempLSTMActor, self).__init__()
+        self.model = get_lstm(lstm_type='pytorch', input_size=36, hidden_size=32, num_layers=2, norm_type=None)
+
+    def forward(self, data, tmp=0):
+        output, next_state = self.model(data['f'], data['prev_state'], list_next_state=True)
+        ret = {'logit': output, 'tmp': tmp, 'action': output + torch.rand_like(output), 'next_state': next_state}
+        if 'mask' in data:
+            ret['action_mask'] = data['mask']
+        return ret
+
+
 @pytest.fixture(scope='function')
 def setup_model():
     return torch.nn.Linear(3, 6)
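A quick sanity check on TempLSTMActor's shapes (input_size=36, hidden_size=32, num_layers=2, fed a seq_len=2, batch=4 input), using a plain torch.nn.LSTM as a stand-in for get_lstm. This is an assumption: the real helper additionally splits the per-sample states into {'h': ..., 'c': ...} dicts when list_next_state=True; here we only verify the raw tensor shapes.

import torch

lstm = torch.nn.LSTM(input_size=36, hidden_size=32, num_layers=2)
seq = torch.randn(2, 4, 36)      # (seq_len, batch, input_size)
out, (h, c) = lstm(seq)
assert out.shape == (2, 4, 32)   # matches output['logit'] in the test below
assert h.shape == (2, 4, 32)     # (num_layers, batch, hidden_size)
# Per-sample slices give the (2, 1, 32) 'h'/'c' shapes the test asserts.
assert h[:, 0:1].shape == (2, 1, 32)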
@@ -576,3 +590,30 @@ def test_combination_multinomial_sample_wrapper(self):
         output = model.forward(shot_number=shot_number, inputs=data)
         assert output['action'].shape == (4, shot_number)
         assert (output['action'] >= 0).all() and (output['action'] < 64).all()
+
+    def test_hidden_state_and_epsilon_greedy_wrapper(self):
+        model = model_wrap(TempLSTMActor(), wrapper_name='hidden_state', state_num=4, save_prev_state=True)
+        model = model_wrap(model, wrapper_name='eps_greedy_sample')
+        model.reset()
+        # Check that reset properly initializes all states to None
+        assert all([isinstance(s, type(None)) for s in model._state.values()])
+
+        data = {'f': torch.randn(2, 4, 36)}
+        output = model.forward(data, eps=0.8)
+        assert output['tmp'] == 0
+        assert 'logit' in output
+        assert output['logit'].shape == (2, 4, 32)
+        assert 'action' in output
+        assert output['action'].shape == (2, 4)
+        assert 'prev_state' in output
+        assert len(output['prev_state']) == 4
+        assert output['prev_state'][0]['h'].shape == (2, 1, 32)
+        assert output['prev_state'][0]['c'].shape == (2, 1, 32)
+
+        assert all([isinstance(s, dict) for s in model._state.values()])
+        # Check that reset with specific data_id works
+        model.reset(data_id=[0, 2])
+        assert isinstance(model._state[0], type(None))
+        assert isinstance(model._state[2], type(None))
+        assert isinstance(model._state[1], dict)
+        assert isinstance(model._state[3], dict)
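Where the (2, 4) action shape in the test comes from: the eps_greedy_sample wrapper reduces each 32-dim logit vector to a single action index per (sequence, batch) position. A rough stand-in for that sampling step, not the wrapper's exact implementation:

import torch

logit = torch.randn(2, 4, 32)
eps = 0.8
greedy = logit.argmax(dim=-1)             # (2, 4) greedy actions
random = torch.randint_like(greedy, 32)   # (2, 4) uniform random actions
action = torch.where(torch.rand(2, 4) < eps, random, greedy)
assert action.shape == (2, 4)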
