Refactor TestAgentWithPyTorch so its models can be reused
Summary:
`TestAgentWithPyTorch` creates and trains small models to make sure their construction hasn't been broken.

Other tests can also benefit from creating and training the same small models. Instead of piling all such tests into the same class, we refactor `TestAgentWithPyTorch` so that it creates and trains these models in separate methods. New test classes can then subclass it and reuse the same models, without overloading the original class.
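As an illustration only, here is a hedged sketch of how a new test class might reuse these helpers; the import path, class name, and assertions below are assumptions for this example and are not part of this commit:

# Hypothetical sketch, not part of this commit: a new test class subclasses
# TestAgentWithPyTorch and reuses its model-building helpers instead of
# adding more test methods to the original class.
import unittest

from test.unit.with_pytorch.test_agent import TestAgentWithPyTorch  # assumed import path


class TestReusedModels(TestAgentWithPyTorch):
    def test_trained_dqn_agent_can_be_reused(self) -> None:
        # Builds and trains the same small DQN agent used by the sanity-check test.
        agent, env = self.get_trained_dqn_agent_and_environment()
        self.assertIsNotNone(agent)
        self.assertIsNotNone(env)

    def test_new_tabular_q_learning_agent_can_be_reused(self) -> None:
        # Helpers returning new (untrained) agents are available as well.
        agent, env = self.get_new_tabular_q_learning_agent_and_environment()
        self.assertIsNotNone(agent)
        self.assertIsNotNone(env)


if __name__ == "__main__":
    unittest.main()

Note that, as with any unittest subclass, the inherited test_* methods would also be collected and run again in the subclass.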

Reviewed By: yiwan-rl

Differential Revision: D67707013

fbshipit-source-id: 42e39a517a37cc826fb4974744620bf46c1e8758
rodrigodesalvobraz authored and facebook-github-bot committed Dec 30, 2024
1 parent 6066854 commit 4b8e774
155 changes: 127 additions & 28 deletions test/unit/with_pytorch/test_agent.py
@@ -5,9 +5,12 @@
# LICENSE file in the root directory of this source tree.
#


# pyre-strict


import unittest
from typing import Tuple

import torch
from pearl.action_representation_modules.one_hot_action_representation_module import (
@@ -55,13 +58,7 @@


class TestAgentWithPyTorch(unittest.TestCase):
"""
A collection of Agent tests using PyTorch (this saves around 100 secs in test loading).
For tests not involving PyTorch, see test/without_pytorch.
"""

def test_deep_td_learning_online_rl_sanity_check(self) -> None:
# make sure E2E is fine
def get_new_dqn_agent_and_environment(self) -> Tuple[PearlAgent, GymEnvironment]:
env = GymEnvironment("CartPole-v1")

assert isinstance(env.action_space, DiscreteActionSpace)
@@ -79,12 +76,20 @@ def test_conservative_deep_td_learning_online_rl_sanity_check(self) -> None:
),
replay_buffer=BasicReplayBuffer(10000),
)
return agent, env

def get_trained_dqn_agent_and_environment(
self,
) -> Tuple[PearlAgent, GymEnvironment]:
agent, env = self.get_new_dqn_agent_and_environment()
online_learning_to_png_graph(
agent, env, number_of_episodes=10, learn_after_episode=True
)
return agent, env

def test_conservative_deep_td_learning_online_rl_sanity_check(self) -> None:
# make sure E2E is fine for cql loss
def get_new_conservative_dqn_agent_and_environment(
self,
) -> Tuple[PearlAgent, GymEnvironment]:
env = GymEnvironment("CartPole-v1")

assert isinstance(env.action_space, DiscreteActionSpace)
@@ -102,16 +107,20 @@ def test_conservative_deep_td_learning_online_rl_sanity_check(self) -> None:
),
replay_buffer=BasicReplayBuffer(10000),
)
return agent, env

def get_trained_conservative_dqn_agent_and_environment(
self,
) -> Tuple[PearlAgent, GymEnvironment]:
agent, env = self.get_new_conservative_dqn_agent_and_environment()
online_learning_to_png_graph(
agent, env, number_of_episodes=10, learn_after_episode=True
)
return agent, env

def test_deep_td_learning_online_rl_sanity_check_dueling(
def get_new_dqn_dueling_agent_and_environment(
self,
number_of_episodes: int = 10,
batch_size: int = 128,
) -> None:
# make sure E2E is fine
) -> Tuple[PearlAgent, GymEnvironment]:
env = GymEnvironment("CartPole-v1")
assert isinstance(env.action_space, DiscreteActionSpace)
num_actions = env.action_space.n
@@ -122,19 +131,27 @@ def test_deep_td_learning_online_rl_sanity_check_dueling(
action_space=env.action_space,
training_rounds=20,
network_type=DuelingQValueNetwork,
batch_size=batch_size,
batch_size=128,
action_representation_module=OneHotActionTensorRepresentationModule(
max_number_actions=num_actions
),
),
replay_buffer=BasicReplayBuffer(10000),
)
return agent, env

def get_trained_dqn_dueling_agent_and_environment(
self,
) -> Tuple[PearlAgent, GymEnvironment]:
agent, env = self.get_new_dqn_dueling_agent_and_environment()
online_learning_to_png_graph(
agent, env, number_of_episodes=number_of_episodes, learn_after_episode=True
agent, env, number_of_episodes=10, learn_after_episode=True
)
return agent, env

def test_deep_td_learning_online_rl_two_tower_network(self) -> None:
# make sure E2E is fine
def get_new_dqn_two_tower_agent_and_environment(
self,
) -> Tuple[PearlAgent, GymEnvironment]:
env = GymEnvironment("CartPole-v1")

assert isinstance(env.action_space, DiscreteActionSpace)
@@ -157,11 +174,20 @@ def test_deep_td_learning_online_rl_two_tower_network(self) -> None:
),
replay_buffer=BasicReplayBuffer(10000),
)
return agent, env

def get_trained_dqn_two_tower_agent_and_environment(
self,
) -> Tuple[PearlAgent, GymEnvironment]:
agent, env = self.get_new_dqn_two_tower_agent_and_environment()
online_learning_to_png_graph(
agent, env, number_of_episodes=10, learn_after_episode=True
)
return agent, env

def test_with_linear_contextual(self) -> None:
def get_new_linear_contextual_agent_and_environment(
self,
) -> Tuple[PearlAgent, ContextualBanditLinearSyntheticEnvironment]:
"""
This is an integration test for ContextualBandit with
ContextualBanditLinearSyntheticEnvironment.
@@ -184,6 +210,12 @@
action_space=action_space,
observation_dim=observation_dim,
)
return agent, env

def get_trained_linear_contextual_agent_and_environment(
self,
) -> Tuple[PearlAgent, ContextualBanditLinearSyntheticEnvironment]:
agent, env = self.get_new_linear_contextual_agent_and_environment()

regrets = []
for _ in range(100):
@@ -201,14 +233,33 @@
# that the regret is decreasing over learning steps
self.assertTrue(sum(regrets[10:]) >= sum(regrets[-10:]))

def test_online_rl(self) -> None:
return agent, env

def get_new_online_rl_agent_and_environment(
self,
) -> Tuple[PearlAgent, FixedNumberOfStepsEnvironment]:
env = FixedNumberOfStepsEnvironment(max_number_of_steps=100)
agent = PearlAgent(TabularQLearning())
return agent, env

def get_trained_online_rl_agent_and_environment(
self,
) -> Tuple[PearlAgent, FixedNumberOfStepsEnvironment]:
agent, env = self.get_new_online_rl_agent_and_environment()
online_learning(agent, env, number_of_episodes=1000)
return agent, env

def test_tabular_q_learning_online_rl(self) -> None:
def get_new_tabular_q_learning_agent_and_environment(
self,
) -> Tuple[PearlAgent, GymEnvironment]:
env = GymEnvironment("FrozenLake-v1", is_slippery=False)
agent = PearlAgent(policy_learner=TabularQLearning(exploration_rate=0.7))
return agent, env

def get_trained_tabular_q_learning_agent_and_environment(
self,
) -> Tuple[PearlAgent, GymEnvironment]:
agent, env = self.get_new_tabular_q_learning_agent_and_environment()
# We use a large exploration rate because the exploitation action
# is always the first one among those with the highest value
# (so that the agent is deterministic in the absence of exploration).
@@ -227,15 +278,12 @@

online_learning(agent, env, number_of_episodes=6000)

for _ in range(100): # Should always reach the goal
episode_info, total_steps = run_episode(
agent, env, learn=False, exploit=True
)
assert episode_info["return"] == 1.0
return agent, env

def test_contextual_bandit_with_tabular_q_learning_online_rl(self) -> None:
def get_new_contextual_bandit_with_tabular_q_learning_agent_and_environment(
self,
) -> Tuple[PearlAgent, RewardIsEqualToTenTimesActionMultiArmBanditEnvironment]:
num_actions = 5
max_action = num_actions - 1
env = RewardIsEqualToTenTimesActionMultiArmBanditEnvironment(
action_space=DiscreteActionSpace(
actions=list(torch.arange(num_actions).view(-1, 1))
@@ -249,8 +297,59 @@ def test_contextual_bandit_with_tabular_q_learning_online_rl(self) -> None:
agent = PearlAgent(
policy_learner=TabularQLearning(exploration_rate=0.1, learning_rate=0.1)
)
return agent, env

def get_trained_contextual_bandit_with_tabular_q_learning_agent_and_environment(
self,
) -> Tuple[PearlAgent, RewardIsEqualToTenTimesActionMultiArmBanditEnvironment]:
agent, env = (
self.get_new_contextual_bandit_with_tabular_q_learning_agent_and_environment() # noqa E501
)

online_learning(agent, env, number_of_episodes=10000)
return agent, env

"""
A collection of Agent tests using PyTorch (this saves around 100 secs in test loading).
For tests not involving PyTorch, see test/without_pytorch.
"""

def test_deep_td_learning_online_rl_sanity_check(self) -> None:
self.get_trained_dqn_agent_and_environment()

def test_conservative_deep_td_learning_online_rl_sanity_check(self) -> None:
self.get_trained_conservative_dqn_agent_and_environment()

def test_deep_td_learning_online_rl_sanity_check_dueling(
self,
) -> None:
self.get_trained_dqn_dueling_agent_and_environment()

def test_deep_td_learning_online_rl_two_tower_network(self) -> None:
self.get_trained_dqn_two_tower_agent_and_environment()

def test_with_linear_contextual(self) -> None:
self.get_trained_linear_contextual_agent_and_environment()

def test_online_rl(self) -> None:
self.get_trained_online_rl_agent_and_environment()

def test_tabular_q_learning_online_rl(self) -> None:
agent, env = self.get_trained_tabular_q_learning_agent_and_environment()

for _ in range(100): # Should always reach the goal
episode_info, total_steps = run_episode(
agent, env, learn=False, exploit=True
)
assert episode_info["return"] == 1.0

def test_contextual_bandit_with_tabular_q_learning_online_rl(self) -> None:
agent, env = (
self.get_trained_contextual_bandit_with_tabular_q_learning_agent_and_environment()
)
assert isinstance(env.action_space, DiscreteActionSpace)
num_actions = env.action_space.n
max_action = num_actions - 1

# Should have learned to use action max_action with reward equal to max_action * 10
for _ in range(100):
