diff --git a/src/imitation/algorithms/pebble/entropy_reward.py b/src/imitation/algorithms/pebble/entropy_reward.py
index 3d9d76b00..e0d94c171 100644
--- a/src/imitation/algorithms/pebble/entropy_reward.py
+++ b/src/imitation/algorithms/pebble/entropy_reward.py
@@ -94,7 +94,8 @@ def _entropy_reward(self, state, action, next_state, done):
         all_observations = all_observations.reshape((-1, *self.obs_shape))
 
         if all_observations.shape[0] < self.nearest_neighbor_k:
-            # not enough observations to compare to, fall back to the learned function
+            # not enough observations to compare to, fall back to the learned function;
+            # (falling back to a constant may also be ok)
             return self.learned_reward_fn(state, action, next_state, done)
         else:
             # TODO #625: deal with the conversion back and forth between np and torch
@@ -104,7 +105,7 @@ def _entropy_reward(self, state, action, next_state, done):
                 self.nearest_neighbor_k,
             )
             normalized_entropies = self.entropy_stats.forward(entropies)
-        return normalized_entropies.numpy()
+            return normalized_entropies.numpy()
 
     def __getstate__(self):
         state = self.__dict__.copy()
diff --git a/src/imitation/util/util.py b/src/imitation/util/util.py
index 9e5815e0c..9bf1c1a40 100644
--- a/src/imitation/util/util.py
+++ b/src/imitation/util/util.py
@@ -389,5 +389,4 @@ def compute_state_entropy(
     # a point is itself, which we want to skip.
     assert distances_tensor.shape[-1] > k
     knn_dists = th.kthvalue(distances_tensor, k=k + 1, dim=1).values
-    state_entropy = knn_dists
-    return state_entropy.unsqueeze(1)
+    return knn_dists
diff --git a/tests/algorithms/pebble/test_entropy_reward.py b/tests/algorithms/pebble/test_entropy_reward.py
index c4f127b09..918222382 100644
--- a/tests/algorithms/pebble/test_entropy_reward.py
+++ b/tests/algorithms/pebble/test_entropy_reward.py
@@ -21,7 +21,7 @@ def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng):
-    all_observations = rng.random((BUFFER_SIZE, VENVS, *(OBS_SHAPE)))
+    all_observations = rng.random((BUFFER_SIZE, VENVS, *OBS_SHAPE))
 
     reward_fn = PebbleStateEntropyReward(Mock(), K)
     reward_fn.set_replay_buffer(
@@ -29,12 +29,12 @@ def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng):
     )
 
     # Act
-    observations = th.rand((BATCH_SIZE, *(OBS_SHAPE)))
+    observations = th.rand((BATCH_SIZE, *OBS_SHAPE))
     reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
 
     # Assert
-        observations, all_observations.reshape(-1, *(OBS_SHAPE)), K
+        observations, all_observations.reshape(-1, *OBS_SHAPE), K
     )
     expected_normalized = reward_fn.entropy_stats.normalize(
         th.as_tensor(expected)
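
For context (not part of the patch above): a minimal, self-contained sketch of the k-nearest-neighbor state-entropy proxy that the compute_state_entropy change touches, showing the new 1-D return shape. The helper name knn_state_entropy, the toy tensor shapes, and the use of th.cdist for the pairwise distances are illustrative assumptions; only the kthvalue / k + 1 step mirrors the code in the diff.

import torch as th


def knn_state_entropy(obs: th.Tensor, all_obs: th.Tensor, k: int) -> th.Tensor:
    # Pairwise Euclidean distances between each query observation and every
    # stored observation (th.cdist is an assumption; the real utility may
    # compute the distances differently).
    distances = th.cdist(obs.flatten(1), all_obs.flatten(1))
    # k + 1 skips the zero distance a point would have to itself if it is
    # already present in all_obs.
    assert distances.shape[-1] > k
    # After the util.py change above, the result has shape (batch,), not (batch, 1).
    return th.kthvalue(distances, k=k + 1, dim=1).values


obs = th.rand(8, 4)        # 8 query observations, 4-dimensional observation space
replay = th.rand(100, 4)   # 100 observations standing in for a replay buffer
print(knn_state_entropy(obs, replay, k=5).shape)  # torch.Size([8])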