diff --git a/bsuite/baselines/jax/actor_critic/agent.py b/bsuite/baselines/jax/actor_critic/agent.py
index 7962b849..aded414f 100644
--- a/bsuite/baselines/jax/actor_critic/agent.py
+++ b/bsuite/baselines/jax/actor_critic/agent.py
@@ -56,7 +56,7 @@ def __init__(
     # Define loss function.
     def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
       """"Actor-critic loss."""
-      logits, values = network(trajectory.observations)
+      logits, values = network(trajectory.observations)  # pytype: disable=wrong-arg-types  # jax-ndarray
       td_errors = rlax.td_lambda(
           v_tm1=values[:-1],
           r_t=trajectory.rewards,
diff --git a/bsuite/baselines/utils/sequence_test.py b/bsuite/baselines/utils/sequence_test.py
index 985a487b..a4e25b38 100644
--- a/bsuite/baselines/utils/sequence_test.py
+++ b/bsuite/baselines/utils/sequence_test.py
@@ -31,8 +31,8 @@ def test_buffer(self):
     max_sequence_length = 10
     obs_shape = (3, 3)
     buffer = sequence.Buffer(
-        obs_spec=specs.Array(obs_shape, dtype=np.float),
-        action_spec=specs.Array((), dtype=np.int),
+        obs_spec=specs.Array(obs_shape, dtype=float),
+        action_spec=specs.Array((), dtype=int),
         max_sequence_length=max_sequence_length)

     dummy_step = dm_env.transition(observation=np.zeros(obs_shape), reward=0.)
diff --git a/bsuite/environments/cartpole.py b/bsuite/environments/cartpole.py
index ea6133db..f7be258b 100644
--- a/bsuite/environments/cartpole.py
+++ b/bsuite/environments/cartpole.py
@@ -159,10 +159,10 @@ def _reset(self) -> dm_env.TimeStep:
     raise NotImplementedError('This environment implements its own auto-reset.')

   def action_spec(self):
-    return specs.DiscreteArray(dtype=np.int, num_values=3, name='action')
+    return specs.DiscreteArray(dtype=int, num_values=3, name='action')

   def observation_spec(self):
-    return specs.Array(shape=(1, 6), dtype=np.float32, name='state')
+    return specs.Array(shape=(1, 6), dtype=np.float32, name='observation')

   @property
   def observation(self) -> np.ndarray:
diff --git a/bsuite/environments/catch.py b/bsuite/environments/catch.py
index 5f88f4f0..c72d8d24 100644
--- a/bsuite/environments/catch.py
+++ b/bsuite/environments/catch.py
@@ -99,12 +99,12 @@ def _step(self, action: int) -> dm_env.TimeStep:

   def observation_spec(self) -> specs.BoundedArray:
     """Returns the observation spec."""
     return specs.BoundedArray(shape=self._board.shape, dtype=self._board.dtype,
-                              name="board", minimum=0, maximum=1)
+                              name="observation", minimum=0, maximum=1)

   def action_spec(self) -> specs.DiscreteArray:
     """Returns the action spec."""
     return specs.DiscreteArray(
-        dtype=np.int, num_values=len(_ACTIONS), name="action")
+        dtype=int, num_values=len(_ACTIONS), name="action")

   def _observation(self) -> np.ndarray:
     self._board.fill(0.)
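Note on the dtype changes above: NumPy 1.20 deprecated the np.float and np.int aliases (they were removed entirely in NumPy 1.24), and since they were always plain aliases for the Python builtins, substituting float/int leaves the resulting specs unchanged. A minimal sketch of the equivalence, assuming only that dm_env is installed:

import numpy as np
from dm_env import specs

# np.float and np.int were aliases for the Python builtins, so the specs
# built from float/int are identical: dm_env canonicalizes the dtype via
# np.dtype(...) internally.
obs_spec = specs.Array(shape=(3, 3), dtype=float)  # np.dtype(float) -> float64
action_spec = specs.DiscreteArray(num_values=3, dtype=int, name='action')
print(obs_spec.dtype, action_spec.dtype)  # float64 int64 (int32 on Windows)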
diff --git a/bsuite/experiments/cartpole_swingup/cartpole_swingup.py b/bsuite/experiments/cartpole_swingup/cartpole_swingup.py
index 13c3b9e9..31d690d5 100644
--- a/bsuite/experiments/cartpole_swingup/cartpole_swingup.py
+++ b/bsuite/experiments/cartpole_swingup/cartpole_swingup.py
@@ -129,7 +129,7 @@ def _reset(self) -> dm_env.TimeStep:
     raise NotImplementedError('This environment implements its own auto-reset.')

   def action_spec(self):
-    return specs.DiscreteArray(dtype=np.int, num_values=3, name='action')
+    return specs.DiscreteArray(dtype=int, num_values=3, name='action')

   def observation_spec(self):
     return specs.Array(shape=(1, 8), dtype=np.float32, name='state')
diff --git a/bsuite/experiments/deep_sea/analysis.py b/bsuite/experiments/deep_sea/analysis.py
index fc336dc5..3616cf8d 100644
--- a/bsuite/experiments/deep_sea/analysis.py
+++ b/bsuite/experiments/deep_sea/analysis.py
@@ -37,7 +37,7 @@ def _check_data(df: pd.DataFrame) -> None:
 def find_solution(df_in: pd.DataFrame,
                   sweep_vars: Optional[Sequence[str]] = None,
                   merge: bool = True,
-                  thresh: float = 0.8,
+                  thresh: float = 0.9,
                   num_episodes: int = NUM_EPISODES) -> pd.DataFrame:
   """Find first episode that gets below thresh regret by sweep_vars."""
   # Check data has the necessary columns for deep sea
diff --git a/bsuite/logging/logging_utils.py b/bsuite/logging/logging_utils.py
index 9c801d87..652e5c30 100644
--- a/bsuite/logging/logging_utils.py
+++ b/bsuite/logging/logging_utils.py
@@ -15,7 +15,7 @@
 # ============================================================================
 """Read functionality for local csv-based experiments."""

-import collections
+from collections import abc
 import copy
 from typing import Any, Callable, List, Mapping, Sequence, Tuple, Union

@@ -72,7 +72,7 @@ def load_multiple_runs(
   # Convert any inputs to dictionary format.
   if isinstance(path_collection, six.string_types):
     path_collection = {path_collection: path_collection}
-  if not isinstance(path_collection, collections.Mapping):
+  if not isinstance(path_collection, abc.Mapping):
     path_collection = {path: path for path in path_collection}

   # Loop through multiple bsuite runs, and apply single_load_fn to each.
diff --git a/bsuite/utils/gym_wrapper.py b/bsuite/utils/gym_wrapper.py
index 84dde88b..a0cb67ac 100644
--- a/bsuite/utils/gym_wrapper.py
+++ b/bsuite/utils/gym_wrapper.py
@@ -97,8 +97,9 @@ def reward_range(self) -> Tuple[float, float]:

   def __getattr__(self, attr):
     """Delegate attribute access to underlying environment."""
-    return getattr(self._env, attr)
-
+    if "_env" in self.__dict__:
+      return getattr(self._env, attr)
+    return super().__getattribute__(attr)

 def space2spec(space: gym.Space, name: Optional[str] = None):
   """Converts an OpenAI Gym space to a dm_env spec or nested structure of specs.
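The logging_utils change is a Python 3.10 compatibility fix: the ABCs (Mapping, Sequence, ...) moved to collections.abc in Python 3.3, and the deprecated top-level aliases were removed in 3.10, where isinstance(x, collections.Mapping) raises AttributeError. A rough sketch of the patched dispatch, using a hypothetical normalize_paths helper that mirrors load_multiple_runs (the gym_wrapper.py __getattr__ change is discussed with the identical wrappers.py changes below):

from collections import abc

def normalize_paths(path_collection):
  # Mirrors load_multiple_runs: coerce a bare string or a plain sequence
  # of paths into the {name: path} mapping the loader expects.
  if isinstance(path_collection, str):
    return {path_collection: path_collection}
  if not isinstance(path_collection, abc.Mapping):  # works on Python 3.10+
    return {path: path for path in path_collection}
  return path_collection

print(normalize_paths('/tmp/bsuite'))         # {'/tmp/bsuite': '/tmp/bsuite'}
print(normalize_paths(['/tmp/a', '/tmp/b']))  # {'/tmp/a': '/tmp/a', '/tmp/b': '/tmp/b'}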
diff --git a/bsuite/utils/wrappers.py b/bsuite/utils/wrappers.py
index 7e064531..d25d3027 100644
--- a/bsuite/utils/wrappers.py
+++ b/bsuite/utils/wrappers.py
@@ -134,8 +134,9 @@ def raw_env(self):

   def __getattr__(self, attr):
     """Delegate attribute access to underlying environment."""
-    return getattr(self._env, attr)
-
+    if "_env" in self.__dict__:
+      return getattr(self._env, attr)
+    return super().__getattribute__(attr)

 def _logarithmic_logging(episode: int,
                          ratios: Optional[Sequence[float]] = None) -> bool:
@@ -173,8 +174,9 @@ def step(self, action):

   def __getattr__(self, attr):
     """Delegate attribute access to underlying environment."""
-    return getattr(self._env, attr)
-
+    if "_env" in self.__dict__:
+      return getattr(self._env, attr)
+    return super().__getattribute__(attr)

 def _small_state_to_image(shape: Sequence[int],
                           observation: np.ndarray) -> np.ndarray:
@@ -307,8 +309,9 @@ def bsuite_info(self) -> Dict[str, Any]:

   def __getattr__(self, attr):
     """Delegate attribute access to underlying environment."""
-    return getattr(self._env, attr)
-
+    if "_env" in self.__dict__:
+      return getattr(self._env, attr)
+    return super().__getattribute__(attr)

 class RewardScale(environments.Environment):
   """Reward Scale environment wrapper."""
@@ -370,4 +373,6 @@ def bsuite_info(self) -> Dict[str, Any]:

   def __getattr__(self, attr):
     """Delegate attribute access to underlying environment."""
-    return getattr(self._env, attr)
+    if "_env" in self.__dict__:
+      return getattr(self._env, attr)
+    return super().__getattribute__(attr)
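All five __getattr__ changes (one in gym_wrapper.py, four in wrappers.py) fix the same latent bug: Python calls __getattr__ only after normal attribute lookup fails, so if _env is not yet in the instance __dict__ (for example while copy or pickle reconstructs the object and probes attributes such as __setstate__ before __init__ has run), the old body's lookup of self._env re-enters __getattr__ and recurses until RecursionError. The guarded version delegates only once _env exists and otherwise falls back to the default lookup, which raises an ordinary AttributeError. A standalone sketch of the pattern, with hypothetical Env/Wrapper classes:

import copy

class Env:
  def reset(self):
    return 'timestep'

class Wrapper:
  """Delegating wrapper using the patched __getattr__."""

  def __init__(self, env):
    self._env = env

  def __getattr__(self, attr):
    # Called only when normal lookup fails. Without the guard, a missing
    # _env (e.g. on a half-constructed instance during copy/pickle) makes
    # getattr(self._env, ...) look up _env, re-entering __getattr__
    # forever. The guard falls back to a plain AttributeError instead.
    if '_env' in self.__dict__:
      return getattr(self._env, attr)
    return super().__getattribute__(attr)

wrapped = Wrapper(Env())
print(wrapped.reset())                 # delegation still works: 'timestep'
print(copy.deepcopy(wrapped).reset())  # no RecursionError while copying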