Skip to content

Commit

Permalink
[MLA-1952] Add optional seed for gym action spaces (#5303) (#5315)
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Elion authored Apr 23, 2021
1 parent b629b49 commit 9e03966
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 9 deletions.
4 changes: 3 additions & 1 deletion com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to
[Semantic Versioning](http://semver.org/spec/v2.0.0.html).


## [2.0.0-exp.1] - 2021-04-22
### Major Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
Expand Down Expand Up @@ -77,6 +76,9 @@ or actuators on your system. (#5194)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Fixed a bug where --results-dir has no effect. (#5269)
- Fixed a bug where old `.pt` checkpoints were not deleted during training. (#5271)
- The `UnityToGymWrapper` initializer now accepts an optional `action_space_seed` argument. If it is specified, it will
be used to set the random seed on the resulting action space. (#5303)


## [1.9.1-preview] - 2021-04-13
### Major Changes
Expand Down
9 changes: 7 additions & 2 deletions gym-unity/gym_unity/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import itertools
import numpy as np
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import gym
from gym import error, spaces
Expand Down Expand Up @@ -35,6 +35,7 @@ def __init__(
uint8_visual: bool = False,
flatten_branched: bool = False,
allow_multiple_obs: bool = False,
action_space_seed: Optional[int] = None,
):
"""
Environment initialization
Expand All @@ -46,6 +47,7 @@ def __init__(
containing the visual observations and the last element containing the array of vector observations.
If False, returns a single np.ndarray containing either only a single visual observation or the array of
vector observations.
:param action_space_seed: If non-None, will be used to set the random seed on created gym.Space instances.
"""
self._env = unity_env

Expand Down Expand Up @@ -130,6 +132,9 @@ def __init__(
"and continuous actions."
)

if action_space_seed is not None:
self._action_space.seed(action_space_seed)

# Set observations space
list_spaces: List[gym.Space] = []
shapes = self._get_vis_obs_shape()
Expand Down Expand Up @@ -305,7 +310,7 @@ def reward_range(self) -> Tuple[float, float]:
return -float("inf"), float("inf")

@property
def action_space(self):
def action_space(self) -> gym.Space:
return self._action_space

@property
Expand Down
21 changes: 15 additions & 6 deletions gym-unity/gym_unity/tests/test_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def test_gym_wrapper():
mock_env, mock_spec, mock_decision_step, mock_terminal_step
)
env = UnityToGymWrapper(mock_env)
assert isinstance(env, UnityToGymWrapper)
assert isinstance(env.reset(), np.ndarray)
actions = env.action_space.sample()
assert actions.shape[0] == 2
Expand Down Expand Up @@ -78,6 +77,21 @@ def test_action_space():
assert env.action_space.n == 5


def test_action_space_seed():
    """
    Check that wrappers built with the same ``action_space_seed`` draw the
    same first sample from their action space.
    """
    mock_env = mock.MagicMock()
    mock_spec = create_mock_group_spec()
    mock_decision_step, mock_terminal_step = create_mock_vector_steps(mock_spec)
    setup_mock_unityenvironment(
        mock_env, mock_spec, mock_decision_step, mock_terminal_step
    )

    def first_sample(seed):
        # Build a fresh wrapper around the shared mock environment so that
        # each call starts from a newly seeded action space.
        wrapper = UnityToGymWrapper(mock_env, action_space_seed=seed)
        wrapper.reset()
        return wrapper.action_space.sample()

    sample_a = first_sample(1337)
    sample_b = first_sample(1337)
    assert (sample_a == sample_b).all()


@pytest.mark.parametrize("use_uint8", [True, False], ids=["float", "uint8"])
def test_gym_wrapper_visual(use_uint8):
mock_env = mock.MagicMock()
Expand All @@ -93,7 +107,6 @@ def test_gym_wrapper_visual(use_uint8):

env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8)
assert isinstance(env.observation_space, spaces.Box)
assert isinstance(env, UnityToGymWrapper)
assert isinstance(env.reset(), np.ndarray)
actions = env.action_space.sample()
assert actions.shape[0] == 2
Expand Down Expand Up @@ -121,7 +134,6 @@ def test_gym_wrapper_single_visual_and_vector(use_uint8):
)

env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=True)
assert isinstance(env, UnityToGymWrapper)
assert isinstance(env.observation_space, spaces.Tuple)
assert len(env.observation_space) == 2
reset_obs = env.reset()
Expand All @@ -143,7 +155,6 @@ def test_gym_wrapper_single_visual_and_vector(use_uint8):

# check behavior for allow_multiple_obs = False
env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=False)
assert isinstance(env, UnityToGymWrapper)
assert isinstance(env.observation_space, spaces.Box)
reset_obs = env.reset()
assert isinstance(reset_obs, np.ndarray)
Expand All @@ -170,7 +181,6 @@ def test_gym_wrapper_multi_visual_and_vector(use_uint8):
)

env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=True)
assert isinstance(env, UnityToGymWrapper)
assert isinstance(env.observation_space, spaces.Tuple)
assert len(env.observation_space) == 3
reset_obs = env.reset()
Expand All @@ -188,7 +198,6 @@ def test_gym_wrapper_multi_visual_and_vector(use_uint8):

# check behavior for allow_multiple_obs = False
env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=False)
assert isinstance(env, UnityToGymWrapper)
assert isinstance(env.observation_space, spaces.Box)
reset_obs = env.reset()
assert isinstance(reset_obs, np.ndarray)
Expand Down

0 comments on commit 9e03966

Please sign in to comment.