Implement compare for contextual bandits
Summary: Implements `compare` for contextual bandits

Reviewed By: yiwan-rl

Differential Revision: D67707014

fbshipit-source-id: ca1fe430c1cbde9bbbbab81eedd0446e1c099b16
rodrigodesalvobraz authored and facebook-github-bot committed Dec 31, 2024
1 parent c29ed5f commit 043fdfa
Showing 11 changed files with 612 additions and 22 deletions.
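
For context, here is a minimal usage sketch of the new `compare` API introduced by this commit (the sketch itself is not part of the commit). It builds two learners that differ in one attribute and prints the human-readable report returned by `compare`. The constructor arguments (`feature_dim`, `exploration_module`, `alpha`) and the `UCBExploration` import path are illustrative assumptions and may not match the exact Pearl signatures.

# Hypothetical usage sketch; constructor arguments are assumptions.
from pearl.policy_learners.contextual_bandits.linear_bandit import LinearBandit
from pearl.policy_learners.exploration_modules.contextual_bandits.ucb_exploration import (
    UCBExploration,
)

# Two learners that differ only in feature dimension.
learner_a = LinearBandit(feature_dim=10, exploration_module=UCBExploration(alpha=0.1))
learner_b = LinearBandit(feature_dim=12, exploration_module=UCBExploration(alpha=0.1))

report = learner_a.compare(learner_b)
if report == "":
    print("learners are identical")
else:
    # Expected to mention something like "_feature_dim is different: 10 vs 12"
    print(report)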
@@ -8,7 +8,7 @@
# pyre-strict

from abc import abstractmethod
from typing import Any
from typing import Any, List

import torch
from pearl.action_representation_modules.action_representation_module import (
@@ -83,3 +83,30 @@ def get_scores(
Return scores trained by this contextual bandit algorithm
"""
pass

    def compare(self, other: PolicyLearner) -> str:
        """
        Compares two ContextualBanditBase instances for equality,
        checking attributes and exploration module.
        Args:
            other: The other PolicyLearner to compare with.
        Returns:
            str: A string describing the differences, or an empty string if they are identical.
        """

        differences: List[str] = []

        # Only record the superclass comparison if it reports any differences.
        if (difference := super().compare(other)) != "":
            differences.append(difference)

        if not isinstance(other, ContextualBanditBase):
            differences.append("other is not an instance of ContextualBanditBase")
        else:
            # Compare attributes
            if self._feature_dim != other._feature_dim:
                differences.append(
                    f"_feature_dim is different: {self._feature_dim} vs {other._feature_dim}"
                )

        return "\n".join(differences)
41 changes: 40 additions & 1 deletion pearl/policy_learners/contextual_bandits/disjoint_bandit.py
@@ -7,7 +7,7 @@

# pyre-strict

from typing import Any
from typing import Any, List

import torch

@@ -28,6 +28,7 @@
from pearl.policy_learners.exploration_modules.exploration_module import (
ExplorationModule,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.functional_utils.learning.action_utils import (
concatenate_actions_to_state,
@@ -238,3 +239,41 @@ def set_history_summarization_module(
# to the optimizer of the bandit, but disjoint bandits do not use a pytorch optimizer.
# Instead, the optimization uses Pearl's own linear regression module.
self._history_summarization_module = value

    def compare(self, other: PolicyLearner) -> str:
        """
        Compares two DisjointBanditContainer instances for equality,
        checking attributes, arm bandits, and exploration module.
        Args:
            other: The other PolicyLearner to compare with.
        Returns:
            str: A string describing the differences, or an empty string if they are identical.
        """
        differences: List[str] = []

        # `compare` returns a single (possibly empty) string, so append it rather
        # than extend, which would add one list element per character.
        if (difference := super().compare(other)) != "":
            differences.append(difference)

        if not isinstance(other, DisjointBanditContainer):
            differences.append("other is not an instance of DisjointBanditContainer")
        else:
            # Compare attributes
            if self._n_arms != other._n_arms:
                differences.append(
                    f"_n_arms is different: {self._n_arms} vs {other._n_arms}"
                )
            if self._state_features_only != other._state_features_only:
                differences.append(
                    f"_state_features_only is different: {self._state_features_only} "
                    + f"vs {other._state_features_only}"
                )

            # Compare arm bandits
            for i, (arm_bandit1, arm_bandit2) in enumerate(
                zip(self._arm_bandits, other._arm_bandits)
            ):
                if (reason := arm_bandit1.compare(arm_bandit2)) != "":
                    differences.append(f"Arm bandit {i} is different: {reason}")

        return "\n".join(differences)  # Join the differences with newlines
@@ -7,7 +7,7 @@

# pyre-strict

from typing import Any
from typing import Any, List

import torch

@@ -24,6 +24,7 @@
from pearl.policy_learners.exploration_modules.exploration_module import (
ExplorationModule,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.functional_utils.learning.action_utils import (
concatenate_actions_to_state,
@@ -154,3 +155,39 @@ def set_history_summarization_module(
# to the optimizer of the bandit, but disjoint bandits do not use a pytorch optimizer.
# Instead, the optimization uses Pearl's own linear regression module.
self._history_summarization_module = value

    def compare(self, other: PolicyLearner) -> str:
        """
        Compares two DisjointLinearBandit instances for equality,
        checking attributes, linear regressions, and exploration module.
        Args:
            other: The other PolicyLearner to compare with.
        Returns:
            str: A string describing the differences, or an empty string if they are identical.
        """
        differences: List[str] = []

        if (difference := super().compare(other)) != "":
            differences.append(difference)

        if not isinstance(other, DisjointLinearBandit):
            differences.append("other is not an instance of DisjointLinearBandit")
        else:
            # Compare attributes
            if self._state_features_only != other._state_features_only:
                differences.append(
                    f"_state_features_only is different: {self._state_features_only} "
                    + f"vs {other._state_features_only}"
                )

            # Compare linear regressions
            for i, (lr1, lr2) in enumerate(
                zip(self._linear_regressions_list, other._linear_regressions_list)
            ):
                assert isinstance(lr1, LinearRegression)
                assert isinstance(lr2, LinearRegression)
                if (reason := lr1.compare(lr2)) != "":
                    differences.append(f"Linear regression {i} is different: {reason}")

        return "\n".join(differences)
43 changes: 42 additions & 1 deletion pearl/policy_learners/contextual_bandits/linear_bandit.py
@@ -7,7 +7,7 @@

# pyre-strict

from typing import Any
from typing import Any, List

import torch
from pearl.action_representation_modules.action_representation_module import (
@@ -29,6 +29,7 @@
from pearl.policy_learners.exploration_modules.exploration_module import (
ExplorationModule,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.functional_utils.learning.action_utils import (
concatenate_actions_to_state,
@@ -223,3 +224,43 @@ def set_history_summarization_module(
# currently linear bandit algorithm does not update
# parameters of the history summarization module
self._history_summarization_module = value

    def compare(self, other: PolicyLearner) -> str:
        """
        Compares two LinearBandit instances for equality,
        checking attributes, model, and exploration module.
        Args:
            other: The other PolicyLearner to compare with.
        Returns:
            str: A string describing the differences, or an empty string if they are identical.
        """
        differences: List[str] = []

        if (difference := super().compare(other)) != "":
            differences.append(difference)

        if not isinstance(other, LinearBandit):
            differences.append("other is not an instance of LinearBandit")
        else:  # Type refinement with else block
            # Compare attributes
            if self.apply_discounting_interval != other.apply_discounting_interval:
                differences.append(
                    f"apply_discounting_interval is different: {self.apply_discounting_interval} "
                    + f"vs {other.apply_discounting_interval}"
                )
            if (
                self.last_sum_weight_when_discounted
                != other.last_sum_weight_when_discounted
            ):
                differences.append(
                    "last_sum_weight_when_discounted is different: "
                    + f"{self.last_sum_weight_when_discounted} "
                    + f"vs {other.last_sum_weight_when_discounted}"
                )

            # Compare models using their compare method
            if (reason := self.model.compare(other.model)) != "":
                differences.append(f"model is different: {reason}")

        return "\n".join(differences)
51 changes: 50 additions & 1 deletion pearl/policy_learners/contextual_bandits/neural_bandit.py
@@ -7,7 +7,7 @@

# pyre-strict

from typing import Any
from typing import Any, List

import torch
from pearl.action_representation_modules.action_representation_module import (
@@ -30,11 +30,16 @@
from pearl.policy_learners.exploration_modules.exploration_module import (
ExplorationModule,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.functional_utils.learning.action_utils import (
concatenate_actions_to_state,
)
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
from pearl.utils.module_utils import (
modules_have_similar_state_dict,
optimizers_have_similar_state_dict,
)
from torch import optim


@@ -181,3 +186,47 @@ def get_scores(
@property
def optimizer(self) -> torch.optim.Optimizer:
return self._optimizer

    def compare(self, other: PolicyLearner) -> str:
        """
        Compares two NeuralBandit instances for equality,
        checking attributes, model, and exploration module.
        Args:
            other: The other PolicyLearner to compare with.
        Returns:
            str: A string describing the differences, or an empty string if they are identical.
        """
        differences: List[str] = []

        if (difference := super().compare(other)) != "":
            differences.append(difference)

        if not isinstance(other, NeuralBandit):
            differences.append("other is not an instance of NeuralBandit")
        else:  # Type refinement with else block
            # Compare attributes
            if self._state_features_only != other._state_features_only:
                differences.append(
                    f"_state_features_only is different: {self._state_features_only} "
                    + f"vs {other._state_features_only}"
                )
            if self.loss_type != other.loss_type:
                differences.append(
                    f"loss_type is different: {self.loss_type} vs {other.loss_type}"
                )

            # Compare models using modules_have_similar_state_dict
            if (
                reason := modules_have_similar_state_dict(self.model, other.model)
            ) != "":
                differences.append(f"model is different: {reason}")

            if (
                reason := optimizers_have_similar_state_dict(
                    self._optimizer, other._optimizer
                )
            ) != "":
                differences.append(f"optimizer is different: {reason}")

        return "\n".join(differences)
65 changes: 64 additions & 1 deletion pearl/policy_learners/contextual_bandits/neural_linear_bandit.py
@@ -7,7 +7,7 @@

# pyre-strict

from typing import Any
from typing import Any, List

import torch
from pearl.action_representation_modules.action_representation_module import (
@@ -35,11 +35,13 @@
from pearl.policy_learners.exploration_modules.exploration_module import (
ExplorationModule,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.functional_utils.learning.action_utils import (
concatenate_actions_to_state,
)
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
from pearl.utils.module_utils import optimizers_have_similar_state_dict
from torch import optim


@@ -307,3 +309,64 @@ def get_scores(
representation=self.model._linear_regression_layer,
)
return scores.reshape(batch_size, -1).squeeze(-1)

    def compare(self, other: PolicyLearner) -> str:
        """
        Compares two NeuralLinearBandit instances for equality,
        checking attributes, model, and exploration module.
        Args:
            other: The other PolicyLearner to compare with.
        Returns:
            str: A string describing the differences, or an empty string if they are identical.
        """
        differences: List[str] = []

        if (difference := super().compare(other)) != "":
            differences.append(difference)

        if not isinstance(other, NeuralLinearBandit):
            differences.append("other is not an instance of NeuralLinearBandit")
        else:  # Type refinement with else block
            # Compare attributes
            if self._state_features_only != other._state_features_only:
                differences.append(
                    f"_state_features_only is different: {self._state_features_only} "
                    + f"vs {other._state_features_only}"
                )
            if self.loss_type != other.loss_type:
                differences.append(
                    f"loss_type is different: {self.loss_type} vs {other.loss_type}"
                )
            if self.apply_discounting_interval != other.apply_discounting_interval:
                differences.append(
                    f"apply_discounting_interval is different: {self.apply_discounting_interval} "
                    + f"vs {other.apply_discounting_interval}"
                )
            if (
                self.last_sum_weight_when_discounted
                != other.last_sum_weight_when_discounted
            ):
                differences.append(
                    "last_sum_weight_when_discounted is different: "
                    + f"{self.last_sum_weight_when_discounted} "
                    + f"vs {other.last_sum_weight_when_discounted}"
                )
            if self.separate_uncertainty != other.separate_uncertainty:
                differences.append(
                    f"separate_uncertainty is different: {self.separate_uncertainty} "
                    + f"vs {other.separate_uncertainty}"
                )

            # Compare models using their compare method
            if (reason := self.model.compare(other.model)) != "":
                differences.append(f"model is different: {reason}")

            if (
                reason := optimizers_have_similar_state_dict(
                    self._optimizer, other._optimizer
                )
            ) != "":
                differences.append(f"optimizer is different: {reason}")

        return "\n".join(differences)
@@ -162,7 +162,11 @@ def compare(self, other: ExplorationModule) -> str:
            differences.append(
                "other is not an instance of ThompsonSamplingExplorationLinearDisjoint"
            )

        # No additional attributes to compare
        else:
            if self._enable_efficient_sampling != other._enable_efficient_sampling:
                differences.append(
                    f"_enable_efficient_sampling is different: {self._enable_efficient_sampling} "
                    + f"vs {other._enable_efficient_sampling}"
                )

        return "\n".join(differences)