+"""Reward function for the PEBBLE training algorithm."""
+
 from enum import Enum, auto
-from typing import Tuple
+from typing import Dict, Optional, Tuple, Union

 import numpy as np
 import torch as th

 from imitation.policies.replay_buffer_wrapper import (
-    ReplayBufferView,
     ReplayBufferRewardWrapper,
+    ReplayBufferView,
 )
 from imitation.rewards.reward_function import ReplayBufferAwareRewardFn, RewardFn
 from imitation.util import util
 from imitation.util.networks import RunningNorm


 class PebbleRewardPhase(Enum):
-    """States representing different behaviors for PebbleStateEntropyReward"""
+    """States representing different behaviors for PebbleStateEntropyReward."""

     UNSUPERVISED_EXPLORATION = auto()  # Entropy based reward
     POLICY_AND_REWARD_LEARNING = auto()  # Learned reward


 class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
-    """
-    Reward function for implementation of the PEBBLE learning algorithm
-    (https://arxiv.org/pdf/2106.05091.pdf).
+    """Reward function for implementation of the PEBBLE learning algorithm.
+
+    See https://arxiv.org/pdf/2106.05091.pdf.

     The rewards returned by this function go through the three phases:
     1. Before enough samples are collected for entropy calculation, the
@@ -38,33 +40,38 @@ class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
     supplied with set_replay_buffer() or on_replay_buffer_initialized().
     To transition to the last phase, unsupervised_exploration_finish() needs
     to be called.
-
-    Args:
-        learned_reward_fn: The learned reward function used after unsupervised
-            exploration is finished
-        nearest_neighbor_k: Parameter for entropy computation (see
-            compute_state_entropy())
     """

-    # TODO #625: parametrize nearest_neighbor_k
     def __init__(
         self,
         learned_reward_fn: RewardFn,
         nearest_neighbor_k: int = 5,
     ):
+        """Builds this class.
+
+        Args:
+            learned_reward_fn: The learned reward function used after unsupervised
+                exploration is finished
+            nearest_neighbor_k: Parameter for entropy computation (see
+                compute_state_entropy())
+        """
         self.learned_reward_fn = learned_reward_fn
         self.nearest_neighbor_k = nearest_neighbor_k
         self.entropy_stats = RunningNorm(1)
         self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION

         # These two need to be set with set_replay_buffer():
-        self.replay_buffer_view = None
-        self.obs_shape = None
+        self.replay_buffer_view: Optional[ReplayBufferView] = None
+        self.obs_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]], None] = None

     def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
         self.set_replay_buffer(replay_buffer.buffer_view, replay_buffer.obs_shape)

-    def set_replay_buffer(self, replay_buffer: ReplayBufferView, obs_shape: Tuple):
+    def set_replay_buffer(
+        self,
+        replay_buffer: ReplayBufferView,
+        obs_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]],
+    ):
         self.replay_buffer_view = replay_buffer
         self.obs_shape = obs_shape

@@ -87,7 +94,7 @@ def __call__(
     def _entropy_reward(self, state, action, next_state, done):
         if self.replay_buffer_view is None:
             raise ValueError(
-                "Replay buffer must be supplied before entropy reward can be used"
+                "Replay buffer must be supplied before entropy reward can be used",
             )
         all_observations = self.replay_buffer_view.observations
         # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
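For context, a minimal usage sketch of the class changed above. The zero-valued stand-in reward function, the array shapes, and the commented-out buffer wiring are illustrative assumptions rather than part of this diff; unsupervised_exploration_finish() is assumed to be a method of the class, as its docstring implies.

import numpy as np

# Stand-in RewardFn: returns zero reward for every transition. In real PEBBLE
# training this would be the preference-based reward model.
def learned_reward_fn(state, action, next_state, done):
    return np.zeros(len(state))

pebble_reward = PebbleStateEntropyReward(learned_reward_fn, nearest_neighbor_k=5)

# During unsupervised exploration the entropy reward needs a replay buffer view;
# ReplayBufferRewardWrapper normally supplies it via on_replay_buffer_initialized(),
# or it can be set manually (hypothetical wiring):
# pebble_reward.set_replay_buffer(buffer_view, obs_shape)

# After exploration, switch to the learned reward (third phase in the docstring):
pebble_reward.unsupervised_exploration_finish()

obs = np.zeros((4, 3), dtype=np.float32)   # batch of 4 observations with 3 features
acts = np.zeros((4, 1), dtype=np.float32)
next_obs = np.zeros_like(obs)
dones = np.zeros(4, dtype=bool)
print(pebble_reward(obs, acts, next_obs, dones))  # learned reward, all zeros here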