Skip to content

Commit a8b079c

Browse files
authored
Remove FloatReward (#829)
* Remove FloatReward. Fixes #794
* Bump SB3 version to ensure we have the bug-fix that makes the FloatReward unneeded.
1 parent d74e903 commit a8b079c

File tree

2 files changed

+1
-12
lines changed

2 files changed

+1
-12
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
203203
"rich",
204204
"scikit-learn>=0.21.2",
205205
"seals~=0.2.1",
206-
"stable-baselines3~=2.0",
206+
"stable-baselines3~=2.2.1",
207207
"sacred>=0.8.4",
208208
"tensorboard>=1.14",
209209
"huggingface_sb3~=3.0",

tests/algorithms/conftest.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Fixtures common across algorithm tests."""
22
from typing import Sequence
33

4-
import gymnasium as gym
54
import pytest
65
from stable_baselines3.common import envs
76
from stable_baselines3.common.policies import BasePolicy
@@ -113,20 +112,10 @@ def pendulum_single_venv(rng) -> VecEnv:
113112
)
114113

115114

116-
# TODO(GH#794): Remove after https://github.com/DLR-RM/stable-baselines3/pull/1676
117-
# merged and released.
118-
class FloatReward(gym.RewardWrapper):
119-
"""Typecasts reward to a float."""
120-
121-
def reward(self, reward):
122-
return float(reward)
123-
124-
125115
@pytest.fixture
126116
def multi_obs_venv() -> VecEnv:
127117
def make_env():
128118
env = envs.SimpleMultiObsEnv(channel_last=False)
129-
env = FloatReward(env)
130119
return RolloutInfoWrapper(env)
131120

132121
return DummyVecEnv([make_env, make_env])

0 commit comments

Comments
 (0)