add autotune to discrete sac
Summary: While our continuous SAC supports entropy autotuning, discrete SAC currently does not. Add it.

Reviewed By: rodrigodesalvobraz

Differential Revision: D66208558

fbshipit-source-id: 326f43d4cfefd99554fa7e0df26b854696e4771b
yiwan-rl authored and facebook-github-bot committed Dec 13, 2024
1 parent cafb382 commit 5d9316b
Showing 2 changed files with 73 additions and 10 deletions.
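
For orientation before the diff: the scheme mirrors the entropy autotuning already used by the continuous SAC learner. A learnable log-coefficient starts at zero (so the coefficient itself starts at 1), the target entropy is a fraction `target_entropy_scale` of the maximum entropy log(|A|) of a discrete policy, and each learn_batch step nudges the coefficient up when the policy's entropy falls below that target and down when it rises above it. Below is a minimal, self-contained sketch of that update; the toy policy, batch size, and learning rate are illustrative and not part of the Pearl API.

import torch
from torch import optim

num_actions = 4
target_entropy_scale = 0.89

# Target entropy: -scale * log(1/|A|) = scale * log(|A|); with 4 actions and
# scale 0.89 this is about 0.89 * 1.386 ≈ 1.23 nats.
target_entropy = -target_entropy_scale * torch.log(torch.tensor(1.0 / num_actions))

# Learnable log-coefficient, initialized to 0 so the coefficient starts at 1.
log_entropy = torch.nn.Parameter(torch.zeros(1, requires_grad=True))
entropy_optimizer = optim.Adam([log_entropy], lr=1e-3, eps=1e-4)

# Toy batch of action probabilities standing in for the actor's output.
action_probs = torch.softmax(torch.randn(32, num_actions), dim=-1)
action_log_probs = torch.log(action_probs + 1e-8)

# Policy entropy averaged over the batch.
entropy = -(action_probs * action_log_probs).sum(dim=1).mean()

# The gradient pushes exp(log_entropy) up when entropy < target, down when entropy > target.
entropy_loss = torch.exp(log_entropy) * (entropy - target_entropy).detach()
entropy_optimizer.zero_grad()
entropy_loss.backward()
entropy_optimizer.step()

# Detached coefficient used to weight the entropy bonus elsewhere in the losses.
entropy_coef = torch.exp(log_entropy).detach()

The diff below registers `_log_entropy`, `_target_entropy`, and `_entropy_coef` on the module and performs the same update inside learn_batch, using the critic's learning rate for the coefficient optimizer.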
@@ -7,7 +7,7 @@

# pyre-strict

from typing import List, Optional, Type, Union
from typing import Any, Dict, Optional

import torch
from pearl.action_representation_modules.action_representation_module import (
@@ -38,6 +38,8 @@
from pearl.utils.functional_utils.learning.critic_utils import (
twin_critic_action_value_loss,
)
from pearl.utils.instantiations.spaces.discrete import DiscreteSpace

from torch import nn, optim


@@ -68,12 +70,14 @@ def __init__(
training_rounds: int = 100,
batch_size: int = 128,
entropy_coef: float = 0.2,
entropy_autotune: bool = True,
action_representation_module: ActionRepresentationModule | None = None,
actor_network_instance: ActorNetwork | None = None,
critic_network_instance: QValueNetwork | nn.Module | None = None,
actor_optimizer: Optional[optim.Optimizer] = None,
critic_optimizer: Optional[optim.Optimizer] = None,
history_summarization_optimizer: Optional[optim.Optimizer] = None,
target_entropy_scale: float = 0.89,
) -> None:
super().__init__(
state_dim=state_dim,
@@ -115,14 +119,69 @@ def __init__(
)

# TODO: implement learnable entropy coefficient
self._entropy_coef = entropy_coef
self._entropy_autotune = entropy_autotune
if entropy_autotune:
# initialize the entropy coefficient to 0
self.register_parameter(
"_log_entropy",
torch.nn.Parameter(torch.zeros(1, requires_grad=True)),
)
self._entropy_optimizer: torch.optim.Optimizer = optim.Adam(
# pyre-fixme[6]: In call `optim.adam.Adam.__init__`, for 1st positional argument,
# expected `Union[Iterable[Dict[str, typing.Any]], Iterable[Tuple[str, Tensor]],
# Iterable[Tensor]]` but got `List[Union[Module, Tensor]]`.
[self._log_entropy],
lr=self._critic_learning_rate,
eps=1e-4,
)
# pyre-fixme[6]: In call `optim.adam.Adam.__init__`, for 1st positional argument,
# expected `Union[Iterable[Dict[str, typing.Any]], Iterable[Tuple[str, Tensor]],
# Iterable[Tensor]]` but got `List[Union[Module, Tensor]]`.
self.register_buffer("_entropy_coef", torch.exp(self._log_entropy).detach())
assert isinstance(action_space, DiscreteSpace)
self.register_buffer(
"_target_entropy",
-target_entropy_scale * torch.log(1.0 / torch.tensor(action_space.n)),
)
else:
self.register_buffer("_entropy_coef", torch.tensor(entropy_coef))

# sac uses a learning rate scheduler specifically
def reset(self, action_space: ActionSpace) -> None:
# pyre-fixme[16]: `SoftActorCritic` has no attribute `_action_space`.
self._action_space = action_space
self.scheduler.step()

def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
actor_critic_loss = super().learn_batch(batch)

if self._entropy_autotune:
entropy = (
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__mul__)
# [[Named(self, torch._C.TensorBase), Named(other, Union[bool, complex, float,
# int, torch._tensor.Tensor])], torch._tensor.Tensor], torch._tensor.Tensor],
# nn.modules.module.Module, torch._tensor.Tensor]` is not a function.
-(self._action_probs_cache * self._action_log_probs_cache).sum(1).mean()
)
entropy_optimizer_loss = (
# pyre-fixme[6]: In call `torch._C._VariableFunctions.exp`,
# for 1st positional argument, expected `Tensor` but got `Union[Module, Tensor]`.
torch.exp(self._log_entropy) * (entropy - self._target_entropy).detach()
)

self._entropy_optimizer.zero_grad()
entropy_optimizer_loss.backward()
self._entropy_optimizer.step()
# pyre-fixme[6]: In call `torch._C._VariableFunctions.exp`,
# for 1st positional argument, expected `Tensor` but got `Union[Module, Tensor]`.
self._entropy_coef = torch.exp(self._log_entropy).detach()
actor_critic_loss = {
**actor_critic_loss,
**{"entropy_coef": entropy_optimizer_loss},
}

return actor_critic_loss

def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
reward_batch = batch.reward # (batch_size)
terminated_batch = batch.terminated # (batch_size)
@@ -216,16 +275,17 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
available_actions=available_actions,
unavailable_actions_mask=unavailable_actions_mask,
) # (batch_size x action_space_size)
# pyre-fixme[16]: `SoftActorCritic` has no attribute `_action_probs_cache`.
self._action_probs_cache = new_policy_dist
# pyre-fixme[16]: `SoftActorCritic` has no attribute `_action_log_probs_cache`.
self._action_log_probs_cache = torch.log(new_policy_dist + 1e-8)
if unavailable_actions_mask is not None:
q[unavailable_actions_mask] = 0.0

loss = (
(
new_policy_dist
* (self._entropy_coef * torch.log(new_policy_dist + 1e-8) - q)
)
.sum(dim=1)
.mean()
)
# pyre-fixme[58]: Unsupported operand: `*` is not supported for operand types
# `torch._tensor.Tensor` and `Union[nn.modules.module.Module, torch._tensor.Tensor]`.
new_policy_dist * (self._entropy_coef * self._action_log_probs_cache - q)
).mean()

return loss
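
The actor-loss rewrite above reuses the probabilities and log-probabilities that are now cached on the learner (`_action_probs_cache`, `_action_log_probs_cache`), so learn_batch can compute the policy entropy for the coefficient update from cached values. A standalone sketch of the reworked loss follows, with toy tensors standing in for the actor output and the twin-critic values (not the Pearl API):

import torch

batch_size, num_actions = 32, 4
entropy_coef = torch.tensor(0.5)  # fixed here; learned when autotuning is enabled
new_policy_dist = torch.softmax(torch.randn(batch_size, num_actions), dim=-1)
q = torch.randn(batch_size, num_actions)  # stand-in for the twin-critic values

# Cached in the real learner and reused by learn_batch for the entropy estimate.
action_log_probs = torch.log(new_policy_dist + 1e-8)

# E_s[ sum_a pi(a|s) * (entropy_coef * log pi(a|s) - Q(s, a)) ], estimated over the batch.
loss = (new_policy_dist * (entropy_coef * action_log_probs - q)).mean()

Note that the new form averages over both the batch and the action dimension, so it differs from the previous `.sum(dim=1).mean()` only by the constant factor 1/|A|, which rescales the gradient but not the minimizer.

The second changed file (per the summary, presumably the continuous SAC learner that already supports autotuning) gets a small companion fix: its learn_batch previously built the merged loss dictionary as a bare expression and discarded it, and now assigns it back to actor_critic_loss before returning.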
@@ -146,7 +146,10 @@ def learn_batch(self, batch: TransitionBatch) -> dict[str, Any]:
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `Union[Module, Tensor]`.
self._entropy_coef = torch.exp(self._log_entropy).detach()
{**actor_critic_loss, **{"entropy_coef": entropy_optimizer_loss}}
actor_critic_loss = {
**actor_critic_loss,
**{"entropy_coef": entropy_optimizer_loss},
}

return actor_critic_loss
