
import imitation.testing.reward_nets as testing_reward_nets
from imitation.algorithms import preference_comparisons
+from imitation.algorithms.preference_comparisons import (
+    PebbleAgentTrainer,
+    TrajectoryGenerator,
+)
from imitation.data import types
from imitation.data.types import TrajectoryWithRew
from imitation.policies.replay_buffer_wrapper import ReplayBufferView
from imitation.regularization import regularizers, updaters
from imitation.rewards import reward_nets
+from imitation.rewards.reward_function import RewardFn
from imitation.scripts.train_preference_comparisons import create_pebble_reward_fn
from imitation.util import networks, util

@@ -1120,3 +1125,26 @@ def test_that_trainer_improves(
    )

    assert np.mean(trained_agent_rewards) > np.mean(novice_agent_rewards)
+
+
+def test_trajectory_generator_raises_on_pretrain_if_not_implemented():
+    class TrajectoryGeneratorTestImpl(TrajectoryGenerator):
+        def sample(self, steps: int) -> Sequence[TrajectoryWithRew]:
+            return []
+
+    generator = TrajectoryGeneratorTestImpl()
+    assert generator.has_pretraining is False
+    with pytest.raises(ValueError, match="should not consume any timesteps"):
+        generator.unsupervised_pretrain(1)
+
+
+def test_pebble_agent_trainer_expects_pebble_reward(agent, venv, rng):
+    reward_fn: RewardFn = lambda state, action, next, done: state
+
+    with pytest.raises(ValueError, match="PebbleStateEntropyReward"):
+        PebbleAgentTrainer(
+            algorithm=agent,
+            reward_fn=reward_fn,  # type:ignore[arg-type]
+            venv=venv,
+            rng=rng,
+        )