7
7
from stable_baselines3 .common .buffers import ReplayBuffer
8
8
from stable_baselines3 .common .type_aliases import ReplayBufferSamples
9
9
10
- from imitation .rewards .reward_function import RewardFn
10
+ from imitation .rewards .reward_function import RewardFn , ReplayBufferAwareRewardFn
11
11
from imitation .util import util
12
12
13
13
@@ -37,13 +37,13 @@ def __init__(
37
37
observations_buffer : np .ndarray ,
38
38
buffer_slice_provider : Callable [[], slice ],
39
39
):
40
- self ._observations_buffer = observations_buffer .view ()
41
- self ._observations_buffer .flags .writeable = False
40
+ self ._observations_buffer_view = observations_buffer .view ()
41
+ self ._observations_buffer_view .flags .writeable = False
42
42
self ._buffer_slice_provider = buffer_slice_provider
43
43
44
44
@property
45
45
def observations (self ):
46
- return self ._observations_buffer [self ._buffer_slice_provider ()]
46
+ return self ._observations_buffer_view [self ._buffer_slice_provider ()]
47
47
48
48
49
49
class ReplayBufferRewardWrapper (ReplayBuffer ):
@@ -57,7 +57,6 @@ def __init__(
57
57
* ,
58
58
replay_buffer_class : Type [ReplayBuffer ],
59
59
reward_fn : RewardFn ,
60
- on_initialized_callback : Callable [["ReplayBufferRewardWrapper" ], None ] = None ,
61
60
** kwargs ,
62
61
):
63
62
"""Builds ReplayBufferRewardWrapper.
@@ -88,8 +87,8 @@ def __init__(
88
87
self .reward_fn = reward_fn
89
88
_base_kwargs = {k : v for k , v in kwargs .items () if k in ["device" , "n_envs" ]}
90
89
super ().__init__ (buffer_size , observation_space , action_space , ** _base_kwargs )
91
- if on_initialized_callback is not None :
92
- on_initialized_callback (self )
90
+ if isinstance ( reward_fn , ReplayBufferAwareRewardFn ) :
91
+ reward_fn . on_replay_buffer_initialized (self )
93
92
94
93
# TODO(juan) remove the type ignore once the merged PR
95
94
# https://github.com/python/mypy/pull/13475
0 commit comments