#625 fix entropy shape
Jan Michelfeit committed Dec 1, 2022
1 parent ddd7b2f commit 15c682a
Showing 3 changed files with 7 additions and 7 deletions.
src/imitation/algorithms/pebble/entropy_reward.py (3 additions, 2 deletions)
@@ -94,7 +94,8 @@ def _entropy_reward(self, state, action, next_state, done):
         all_observations = all_observations.reshape((-1, *self.obs_shape))
 
         if all_observations.shape[0] < self.nearest_neighbor_k:
-            # not enough observations to compare to, fall back to the learned function
+            # not enough observations to compare to, fall back to the learned function;
+            # (falling back to a constant may also be ok)
             return self.learned_reward_fn(state, action, next_state, done)
         else:
             # TODO #625: deal with the conversion back and forth between np and torch
@@ -104,7 +105,7 @@ def _entropy_reward(self, state, action, next_state, done):
                 self.nearest_neighbor_k,
             )
             normalized_entropies = self.entropy_stats.forward(entropies)
-            return normalized_entropies.numpy()
+            return normalized_entropies.numpy()
 
     def __getstate__(self):
         state = self.__dict__.copy()
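The comment added above notes that falling back to a constant could also be acceptable when the buffer holds fewer observations than nearest_neighbor_k. A minimal, hypothetical sketch of that alternative, as a standalone toy function (illustrative names only; the commit itself keeps the learned-function fallback):

    import numpy as np

    def constant_fallback_reward(state: np.ndarray) -> np.ndarray:
        # hypothetical constant fallback: one zero reward per observation in the batch
        return np.zeros(len(state), dtype=np.float32)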
src/imitation/util/util.py (1 addition, 2 deletions)
@@ -389,5 +389,4 @@ def compute_state_entropy(
     # a point is itself, which we want to skip.
     assert distances_tensor.shape[-1] > k
     knn_dists = th.kthvalue(distances_tensor, k=k + 1, dim=1).values
-    state_entropy = knn_dists
-    return state_entropy.unsqueeze(1)
+    return knn_dists
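To make the shape fix concrete: a minimal, self-contained sketch (toy numbers, not from the repository) of what th.kthvalue(..., dim=1).values produces, and the extra axis that the removed .unsqueeze(1) used to add:

    import torch as th

    # toy (batch=2, n_obs=4) distance matrix; column 0 is each point's distance to itself
    distances = th.tensor([[0.0, 1.0, 2.0, 3.0],
                           [0.0, 0.5, 1.5, 2.5]])
    k = 2

    # k-th nearest neighbor distance per row, skipping the self-distance via k + 1
    knn_dists = th.kthvalue(distances, k=k + 1, dim=1).values
    print(knn_dists.shape)               # torch.Size([2])    -> shape after this commit
    print(knn_dists.unsqueeze(1).shape)  # torch.Size([2, 1]) -> shape before this commit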
tests/algorithms/pebble/test_entropy_reward.py (3 additions, 3 deletions)
@@ -21,20 +21,20 @@
 
 
 def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng):
-    all_observations = rng.random((BUFFER_SIZE, VENVS, *(OBS_SHAPE)))
+    all_observations = rng.random((BUFFER_SIZE, VENVS, *OBS_SHAPE))
 
     reward_fn = PebbleStateEntropyReward(Mock(), K)
     reward_fn.set_replay_buffer(
         ReplayBufferView(all_observations, lambda: slice(None)), OBS_SHAPE
     )
 
     # Act
-    observations = th.rand((BATCH_SIZE, *(OBS_SHAPE)))
+    observations = th.rand((BATCH_SIZE, *OBS_SHAPE))
     reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
 
     # Assert
     expected = util.compute_state_entropy(
-        observations, all_observations.reshape(-1, *(OBS_SHAPE)), K
+        observations, all_observations.reshape(-1, *OBS_SHAPE), K
     )
     expected_normalized = reward_fn.entropy_stats.normalize(
         th.as_tensor(expected)
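With the fix, both the expected entropies and the reward returned by reward_fn should be flat, with one entry per observation. A hedged pair of assertions one might append to the test above (hypothetical, not part of this commit; reward, expected_normalized, and BATCH_SIZE are the names already used in the test):

    # hypothetical extra checks: the reward and its expected value are now 1-D
    assert reward.shape == (BATCH_SIZE,)
    assert expected_normalized.shape == (BATCH_SIZE,)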
