Allow twin critic to take network instances and test SAC in Atari games.
Summary: The current TwinCritic does not allow passing in pre-built network instances. Add support for that. Also, add SAC Atari experiments in benchmark_config.py.

Reviewed By: rodrigodesalvobraz

Differential Revision: D66281503

fbshipit-source-id: e6e6138179593048266d88e087957a499ea30660
yiwan-rl authored and facebook-github-bot committed Dec 13, 2024
1 parent a92fdd9 commit 5b9138f
Showing 3 changed files with 144 additions and 24 deletions.
58 changes: 34 additions & 24 deletions pearl/neural_networks/sequential_decision_making/twin_critic.py
@@ -9,10 +9,11 @@
 
 import inspect
 from collections.abc import Callable, Iterable
-from typing import Tuple, Type
+from typing import Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
+from pearl.neural_networks.common.utils import init_weights
 from pearl.neural_networks.sequential_decision_making.q_value_networks import (
     QValueNetwork,
     VanillaQValueNetwork,
@@ -30,35 +31,44 @@ class TwinCritic(torch.nn.Module):
 
     def __init__(
         self,
-        state_dim: int,
-        action_dim: int,
-        hidden_dims: Iterable[int],
-        init_fn: Callable[[nn.Module], None],
+        state_dim: Optional[int] = None,
+        action_dim: Optional[int] = None,
+        hidden_dims: Optional[Iterable[int]] = None,
+        init_fn: Callable[[nn.Module], None] = init_weights,
         network_type: type[QValueNetwork] = VanillaQValueNetwork,
         output_dim: int = 1,
+        network_instance_1: Optional[QValueNetwork] = None,
+        network_instance_2: Optional[QValueNetwork] = None,
     ) -> None:
         super().__init__()
-
-        if inspect.isabstract(network_type):
-            raise ValueError("network_type must not be abstract")
+        if network_instance_1 is not None and network_instance_2 is not None:
+            self._critic_1: QValueNetwork = network_instance_1
+            self._critic_2: QValueNetwork = network_instance_2
+        else:
+            assert state_dim is not None
+            assert action_dim is not None
+            assert hidden_dims is not None
+            assert network_type is not None
+            if inspect.isabstract(network_type):
+                raise ValueError("network_type must not be abstract")
 
-        # pyre-ignore[45]:
-        # Pyre does not know that `network_type` is asserted to be concrete
-        self._critic_1: QValueNetwork = network_type(
-            state_dim=state_dim,
-            action_dim=action_dim,
-            hidden_dims=hidden_dims,
-            output_dim=output_dim,
-        )
+            # pyre-ignore[45]:
+            # Pyre does not know that `network_type` is asserted to be concrete
+            self._critic_1: QValueNetwork = network_type(
+                state_dim=state_dim,
+                action_dim=action_dim,
+                hidden_dims=hidden_dims,
+                output_dim=output_dim,
+            )
 
-        # pyre-ignore[45]:
-        # Pyre does not know that `network_type` is asserted to be concrete
-        self._critic_2: QValueNetwork = network_type(
-            state_dim=state_dim,
-            action_dim=action_dim,
-            hidden_dims=hidden_dims,
-            output_dim=output_dim,
-        )
+            # pyre-ignore[45]:
+            # Pyre does not know that `network_type` is asserted to be concrete
+            self._critic_2: QValueNetwork = network_type(
+                state_dim=state_dim,
+                action_dim=action_dim,
+                hidden_dims=hidden_dims,
+                output_dim=output_dim,
+            )
 
         # nn.ModuleList helps manage the networks
         # (initialization, parameter update etc.) efficiently
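For reference, a minimal usage sketch (not part of the diff) of the two construction paths this change supports; the state/action dimensions and hidden sizes below are illustrative:

# Sketch only: illustrative dimensions; VanillaQValueNetwork is the default
# network_type already imported by twin_critic.py.
from pearl.neural_networks.sequential_decision_making.q_value_networks import (
    VanillaQValueNetwork,
)
from pearl.neural_networks.sequential_decision_making.twin_critic import TwinCritic

# Path 1: build both critics from dimensions (previous behavior, now with defaults).
twin_from_dims = TwinCritic(
    state_dim=8,
    action_dim=4,
    hidden_dims=[64, 64],
    output_dim=1,
)

# Path 2 (new in this commit): pass two pre-built QValueNetwork instances.
critic_1 = VanillaQValueNetwork(
    state_dim=8, action_dim=4, hidden_dims=[64, 64], output_dim=1
)
critic_2 = VanillaQValueNetwork(
    state_dim=8, action_dim=4, hidden_dims=[64, 64], output_dim=1
)
twin_from_instances = TwinCritic(
    network_instance_1=critic_1,
    network_instance_2=critic_2,
)

When both instances are provided they are used directly and the dimension arguments are ignored; otherwise the asserts in the else branch require the dimensions to be given.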
41 changes: 41 additions & 0 deletions pearl/utils/scripts/benchmark.py
@@ -34,6 +34,7 @@
CNNQValueMultiHeadNetwork,
CNNQValueNetwork,
)
from pearl.neural_networks.sequential_decision_making.twin_critic import TwinCritic
from pearl.pearl_agent import PearlAgent
from pearl.utils.functional_utils.experimentation.set_seed import set_seed

@@ -244,6 +245,46 @@ def evaluate_single(
),
**method["network_args"],
)
if "critic_network_module" in method and method["critic_network_module"] in [
CNNQValueNetwork,
CNNQValueMultiHeadNetwork,
]:
action_dim = policy_learner_args[
"action_representation_module"
].representation_dim
output_dim = (
1 if method["critic_network_module"] is CNNQValueNetwork else action_dim
)
if "use_twin_critic" in method and method["use_twin_critic"]:
policy_learner_args["critic_network_instance"] = TwinCritic(
network_instance_1=method["critic_network_module"](
input_width=env.observation_space.shape[2],
input_height=env.observation_space.shape[1],
input_channels_count=env.observation_space.shape[0],
action_dim=action_dim,
output_dim=output_dim,
**method["critic_network_args"],
),
network_instance_2=method["critic_network_module"](
input_width=env.observation_space.shape[2],
input_height=env.observation_space.shape[1],
input_channels_count=env.observation_space.shape[0],
action_dim=action_dim,
output_dim=output_dim,
**method["critic_network_args"],
),
)
else:
policy_learner_args["critic_network_instance"] = method[
"critic_network_module"
](
input_width=env.observation_space.shape[2],
input_height=env.observation_space.shape[1],
input_channels_count=env.observation_space.shape[0],
action_dim=action_dim,
output_dim=output_dim,
**method["critic_network_args"],
)

if (
"critic_network_module" in method
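As a side note, here is a simplified sketch of what the branch above builds for an Atari-style environment. The observation shape (4, 84, 84), the 6-dimensional one-hot action representation, and the import path for CNNQValueNetwork are illustrative assumptions; the keyword arguments mirror the constructor calls in evaluate_single and the critic_network_args in benchmark_config.py:

from pearl.neural_networks.sequential_decision_making.q_value_networks import (
    CNNQValueNetwork,  # assumed import path, matching the names used above
)
from pearl.neural_networks.sequential_decision_making.twin_critic import TwinCritic

obs_shape = (4, 84, 84)  # assumed channel-first (channels, height, width) frame stack
action_dim = 6  # e.g. one-hot representation_dim for a 6-action Atari game

def make_cnn_critic() -> CNNQValueNetwork:
    # Mirrors the constructor call in evaluate_single above; output_dim is 1
    # for CNNQValueNetwork per the dispatch logic in the diff.
    return CNNQValueNetwork(
        input_width=obs_shape[2],
        input_height=obs_shape[1],
        input_channels_count=obs_shape[0],
        action_dim=action_dim,
        output_dim=1,
        hidden_dims_fully_connected=[512],
        kernel_sizes=[8, 4, 3],
        output_channels_list=[32, 64, 64],
        strides=[4, 2, 1],
        paddings=[0, 0, 0],
    )

# With "use_twin_critic": True, two independent CNN critics are wrapped in a TwinCritic.
critic_network_instance = TwinCritic(
    network_instance_1=make_cnn_critic(),
    network_instance_2=make_cnn_critic(),
)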
69 changes: 69 additions & 0 deletions pearl/utils/scripts/benchmark_config.py
@@ -33,6 +33,7 @@
EnsembleQValueNetwork,
VanillaQValueNetwork,
)
from pearl.neural_networks.sequential_decision_making.twin_critic import TwinCritic
from pearl.policy_learners.exploration_modules.common.epsilon_greedy_exploration import ( # noqa E501
EGreedyExploration,
)
@@ -532,6 +533,72 @@
"action_representation_module": OneHotActionTensorRepresentationModule,
"action_representation_module_args": {},
}
SAC_Atari_method = {
"name": "SAC",
"policy_learner": SoftActorCritic,
"policy_learner_args": {
"actor_hidden_dims": [64, 64],
"critic_hidden_dims": [64, 64],
"training_rounds": 50,
"batch_size": 32,
"entropy_coef": 0.1,
},
"actor_network_module": CNNActorNetwork,
"actor_network_args": {
"hidden_dims_fully_connected": [512],
"kernel_sizes": [8, 4, 3],
"output_channels_list": [32, 64, 64],
"strides": [4, 2, 1],
"paddings": [0, 0, 0],
},
"use_twin_critic": True,
"critic_network_module": CNNQValueNetwork,
"critic_network_args": {
"hidden_dims_fully_connected": [512],
"kernel_sizes": [8, 4, 3],
"output_channels_list": [32, 64, 64],
"strides": [4, 2, 1],
"paddings": [0, 0, 0],
},
"replay_buffer": BasicReplayBuffer,
"replay_buffer_args": {"capacity": 50000},
"action_representation_module": OneHotActionTensorRepresentationModule,
"action_representation_module_args": {},
"learn_every_k_steps": 200,
}
SAC_multi_head_Atari_method = {
"name": "SAC",
"policy_learner": SoftActorCritic,
"policy_learner_args": {
"actor_hidden_dims": [64, 64],
"critic_hidden_dims": [64, 64],
"training_rounds": 50,
"batch_size": 32,
"entropy_coef": 0.1,
},
"actor_network_module": CNNActorNetwork,
"actor_network_args": {
"hidden_dims_fully_connected": [512],
"kernel_sizes": [8, 4, 3],
"output_channels_list": [32, 64, 64],
"strides": [4, 2, 1],
"paddings": [0, 0, 0],
},
"use_twin_critic": True,
"critic_network_module": CNNQValueMultiHeadNetwork,
"critic_network_args": {
"hidden_dims_fully_connected": [512],
"kernel_sizes": [8, 4, 3],
"output_channels_list": [32, 64, 64],
"strides": [4, 2, 1],
"paddings": [0, 0, 0],
},
"replay_buffer": BasicReplayBuffer,
"replay_buffer_args": {"capacity": 50000},
"action_representation_module": OneHotActionTensorRepresentationModule,
"action_representation_module_args": {},
"learn_every_k_steps": 200,
}
IQL_online_method = {
"name": "IQL",
"policy_learner": ImplicitQLearning,
@@ -1408,6 +1475,8 @@
DQN_Atari_method,
DQN_multi_head_Atari_method,
PPO_Atari_method,
SAC_Atari_method,
SAC_multi_head_Atari_method,
],
"device_id": 0,
}
