|
18 | 18 |
|
19 | 19 | import imitation.testing.reward_nets as testing_reward_nets
|
20 | 20 | from imitation.algorithms import preference_comparisons
|
| 21 | +from imitation.algorithms.preference_comparisons import ( |
| 22 | + PebbleAgentTrainer, |
| 23 | + TrajectoryGenerator, |
| 24 | +) |
21 | 25 | from imitation.data import types
|
22 | 26 | from imitation.data.types import TrajectoryWithRew
|
23 | 27 | from imitation.policies.replay_buffer_wrapper import ReplayBufferView
|
24 | 28 | from imitation.regularization import regularizers, updaters
|
25 | 29 | from imitation.rewards import reward_nets
|
| 30 | +from imitation.rewards.reward_function import RewardFn |
26 | 31 | from imitation.scripts.train_preference_comparisons import create_pebble_reward_fn
|
27 | 32 | from imitation.util import networks, util
|
28 | 33 |
|
@@ -1120,3 +1125,28 @@ def test_that_trainer_improves(
|
1120 | 1125 | )
|
1121 | 1126 |
|
1122 | 1127 | assert np.mean(trained_agent_rewards) > np.mean(novice_agent_rewards)
|
| 1128 | + |
| 1129 | + |
def test_trajectory_generator_raises_on_pretrain_if_not_implemented():
    """Check pretraining behaviour of a generator that only overrides `sample`.

    The test asserts that such a subclass of `TrajectoryGenerator` reports
    `has_pretraining` as `False`, and that calling `unsupervised_pretrain`
    on it raises a `ValueError` whose message matches
    "should not consume any timesteps".
    """

    class _SampleOnlyGenerator(TrajectoryGenerator):
        """Test double that overrides nothing but `sample`."""

        def sample(self, steps: int) -> Sequence[TrajectoryWithRew]:
            return []

    gen = _SampleOnlyGenerator()
    assert gen.has_pretraining is False

    expected_message = "should not consume any timesteps"
    with pytest.raises(ValueError, match=expected_message):
        gen.unsupervised_pretrain(1)

    # Call the stub once so it is not reported as uncovered code.
    gen.sample(1)
| 1141 | + |
| 1142 | + |
def test_pebble_agent_trainer_expects_pebble_reward(agent, venv, rng):
    """Check that `PebbleAgentTrainer` rejects an ordinary reward function.

    A plain callable is supplied as `reward_fn`; constructing the trainer is
    expected to raise a `ValueError` whose message matches
    "PebbleStateEntropyReward".

    Args:
        agent: fixture passed to the trainer as `algorithm`.
        venv: fixture passed to the trainer as `venv`.
        rng: fixture passed to the trainer as `rng`.
    """

    def plain_reward(state, action, next_state, done):
        # A named function replaces the assigned lambda (PEP 8), and
        # `next_state` avoids shadowing the builtin `next`.
        return state

    reward_fn: RewardFn = plain_reward

    with pytest.raises(ValueError, match="PebbleStateEntropyReward"):
        PebbleAgentTrainer(
            algorithm=agent,
            reward_fn=reward_fn,  # type: ignore[call-arg]
            venv=venv,
            rng=rng,
        )
0 commit comments