Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen committed Dec 26, 2023
2 parents 177f44a + c1c5ffd commit 5cddcbe
Showing 24 changed files with 761 additions and 250 deletions.
9 changes: 6 additions & 3 deletions ding/framework/middleware/functional/data_processor.py
@@ -193,7 +193,10 @@ def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable:
def producer(queue, dataset, batch_size, device):
torch.set_num_threads(4)
nonlocal stream
- idx_iter = iter(range(len(dataset)))
+ idx_iter = iter(range(len(dataset) - batch_size))
+
+ if len(dataset) < batch_size:
+     logging.warning('batch_size is too large!!!!')
with torch.cuda.stream(stream):
while True:
if queue.full():
@@ -203,7 +206,7 @@ def producer(queue, dataset, batch_size, device):
start_idx = next(idx_iter)
except StopIteration:
del idx_iter
- idx_iter = iter(range(len(dataset)))
+ idx_iter = iter(range(len(dataset) - batch_size))
start_idx = next(idx_iter)
data = [dataset.__getitem__(idx) for idx in range(start_idx, start_idx + batch_size)]
data = [[i[j] for i in data] for j in range(len(data[0]))]
@@ -213,7 +216,7 @@ def producer(queue, dataset, batch_size, device):
queue = Queue(maxsize=50)
device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
producer_thread = Thread(
- target=producer, args=(queue, dataset, cfg.policy.batch_size, device), name='cuda_fetcher_producer'
+ target=producer, args=(queue, dataset, cfg.policy.learn.batch_size, device), name='cuda_fetcher_producer'
)

def _fetch(ctx: "OfflineRLContext"):
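The change above shrinks the producer's start-index range so that start_idx + batch_size can never run past the end of the dataset, and points the thread at cfg.policy.learn.batch_size. Below is a minimal, self-contained sketch of the same bounded-queue prefetch pattern (hypothetical ToyDataset; the CUDA stream and device handling are omitted):

from queue import Queue
from threading import Thread

class ToyDataset:
    # Stand-in for the offline dataset; only __len__ / __getitem__ matter here.
    def __init__(self, n: int):
        self.data = list(range(n))

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> int:
        return self.data[idx]

def producer(queue: Queue, dataset: ToyDataset, batch_size: int) -> None:
    # Stop one batch short of the end so every slice is full-sized.
    idx_iter = iter(range(len(dataset) - batch_size))
    while True:
        try:
            start_idx = next(idx_iter)
        except StopIteration:
            # Index range exhausted: restart from the beginning.
            idx_iter = iter(range(len(dataset) - batch_size))
            start_idx = next(idx_iter)
        # put() blocks once maxsize batches are waiting, which throttles the producer.
        queue.put([dataset[i] for i in range(start_idx, start_idx + batch_size)])

queue = Queue(maxsize=50)
Thread(target=producer, args=(queue, ToyDataset(100), 8), daemon=True).start()
print(queue.get())  # first prefetched batch: [0, 1, ..., 7]

As the new warning hints, if len(dataset) <= batch_size the index range is empty and the producer thread dies on an uncaught StopIteration.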
109 changes: 100 additions & 9 deletions ding/league/metric.py
@@ -5,6 +5,17 @@


class EloCalculator(object):
"""
Overview:
A class that calculates Elo ratings for players based on game results.
Attributes:
- score (:obj:`dict`): A dictionary that maps game results to scores.
Interfaces:
``__init__``, ``get_new_rating``, ``get_new_rating_array``.
"""

score = {
1: 1.0, # win
0: 0.5, # draw
@@ -18,6 +29,20 @@ def get_new_rating(cls,
result: int,
k_factor: int = 32,
beta: int = 200) -> Tuple[int, int]:
"""
Overview:
Calculates the new ratings for two players based on their current ratings and game result.
Arguments:
- rating_a (:obj:`int`): The current rating of player A.
- rating_b (:obj:`int`): The current rating of player B.
- result (:obj:`int`): The result of the game: 1 for player A win, 0 for draw, -1 for player B win.
- k_factor (:obj:`int`): The K-factor used in the Elo rating system. Defaults to 32.
- beta (:obj:`int`): The beta value used in the Elo rating system. Defaults to 200.
Returns:
- ret (:obj:`Tuple[int, int]`): The new ratings for player A and player B, respectively.
"""
assert result in [1, 0, -1]
expect_a = 1. / (1. + math.pow(10, (rating_b - rating_a) / (2. * beta)))
expect_b = 1. / (1. + math.pow(10, (rating_a - rating_b) / (2. * beta)))
@@ -35,10 +60,25 @@ def get_new_rating_array(
beta: int = 200
) -> np.ndarray:
"""
Overview:
Calculates the new ratings for multiple players based on their current ratings, game results, \
and game counts.
Arguments:
- rating (:obj:`np.ndarray`): An array of current ratings for each player.
- result (:obj:`np.ndarray`): An array of game results, where 1 represents a win, 0 represents a draw, \
and -1 represents a loss.
- game_count (:obj:`np.ndarray`): An array of game counts for each player.
- k_factor (:obj:`int`): The K-factor used in the Elo rating system. Defaults to 32.
- beta (:obj:`int`): The beta value used in the Elo rating system. Defaults to 200.
Returns:
- ret (:obj:`np.ndarray`): An array of new ratings for each player.
Shapes:
- rating: :math:`(N, )`, N is the number of player
- result: :math:`(N, N)`
- game_count: :math:`(N, N)`
+ - rating (:obj:`np.ndarray`): :math:`(N, )`, where N is the number of players.
+ - result (:obj:`np.ndarray`): :math:`(N, N)`
+ - game_count (:obj:`np.ndarray`): :math:`(N, N)`
"""
rating_diff = np.expand_dims(rating, 0) - np.expand_dims(rating, 1)
expect = 1. / (1. + np.power(10, rating_diff / (2. * beta))) * game_count
@@ -48,6 +88,13 @@ def get_new_rating_array(


class PlayerRating(Rating):
"""
Overview:
Represents the rating of a player.
Interfaces:
``__init__``, ``__repr__``.
"""

def __init__(self, mu: float = None, sigma: float = None, elo_init: int = None) -> None:
super(PlayerRating, self).__init__(mu, sigma)
@@ -62,14 +109,33 @@ def __repr__(self) -> str:
class LeagueMetricEnv(TrueSkill):
"""
Overview:
- TrueSkill rating system among game players, for more details pleas refer to ``https://trueskill.org/``
+ A class that represents a TrueSkill rating system for game players. Inherits from the TrueSkill class. \
+ For more details, please refer to https://trueskill.org/.
+ Interfaces:
+ ``__init__``, ``create_rating``, ``rate_1vs1``, ``rate_1vsC``.
"""

def __init__(self, *args, elo_init: int = 1200, **kwargs) -> None:
super(LeagueMetricEnv, self).__init__(*args, **kwargs)
self.elo_init = elo_init

def create_rating(self, mu: float = None, sigma: float = None, elo_init: int = None) -> PlayerRating:
"""
Overview:
Creates a new player rating object with the specified mean, standard deviation, and Elo rating.
Arguments:
- mu (:obj:`float`): The mean value of the player's skill rating. If not provided, the default \
TrueSkill mean is used.
- sigma (:obj:`float`): The standard deviation of the player's skill rating. If not provided, \
the default TrueSkill sigma is used.
- elo_init (:obj:`int`): The initial Elo rating value for the player. If not provided, the default \
elo_init value of the LeagueMetricEnv class is used.
Returns:
- ret (:obj:`PlayerRating`): A player rating object with the specified mean, standard deviation, and Elo rating.
"""
if mu is None:
mu = self.mu
if sigma is None:
@@ -91,11 +157,23 @@ def _rate_1vs1(t1, t2, **kwargs):
t2 = PlayerRating(t2.mu, t2.sigma, t2_elo)
return t1, t2

- def rate_1vs1(self,
-               team1: PlayerRating,
-               team2: PlayerRating,
-               result: List[str] = None,
-               **kwargs) -> Tuple[PlayerRating, PlayerRating]:
+ def rate_1vs1(self, team1: PlayerRating, team2: PlayerRating, result: List[str] = None, **kwargs) \
+     -> Tuple[PlayerRating, PlayerRating]:
"""
Overview:
Rates two teams of players against each other in a 1 vs 1 match and returns the updated ratings \
for both teams.
Arguments:
- team1 (:obj:`PlayerRating`): The rating object representing the first team of players.
- team2 (:obj:`PlayerRating`): The rating object representing the second team of players.
- result (:obj:`List[str]`): The result of the match. Can be 'wins', 'draws', or 'losses'. If \
not provided, the default behavior is to rate the match as a win for team1.
Returns:
- ret (:obj:`Tuple[PlayerRating, PlayerRating]`): A tuple containing the updated ratings for team1 \
and team2.
"""
if result is None:
return self._rate_1vs1(team1, team2, **kwargs)
else:
@@ -111,6 +189,19 @@ def rate_1vs1(self,
return team1, team2

def rate_1vsC(self, team1: PlayerRating, team2: PlayerRating, result: List[str]) -> PlayerRating:
"""
Overview:
Rates a team of players against a single player in a 1 vs C match and returns the updated rating \
for the team.
Arguments:
- team1 (:obj:`PlayerRating`): The rating object representing the team of players.
- team2 (:obj:`PlayerRating`): The rating object representing the single player.
- result (:obj:`List[str]`): The result of the match. Can be 'wins', 'draws', or 'losses'.
Returns:
- ret (:obj:`PlayerRating`): The updated rating for the team of players.
"""
for r in result:
if r == 'wins':
team1, _ = self._rate_1vs1(team1, team2)
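A short usage sketch of the two rating interfaces documented above (illustrative values; assumes the trueskill dependency that LeagueMetricEnv builds on is installed):

from ding.league.metric import EloCalculator, LeagueMetricEnv

# Elo: player A (1200) beats player B (1300); result=1 means an A win.
new_a, new_b = EloCalculator.get_new_rating(1200, 1300, result=1)
print(new_a, new_b)  # with k_factor=32 and beta=200, the underdog A gains about 20 points

# TrueSkill plus Elo, via LeagueMetricEnv.
env = LeagueMetricEnv(elo_init=1200)
p1 = env.create_rating()
p2 = env.create_rating()
# result is a list of outcomes from p1's perspective: two games here.
p1, p2 = env.rate_1vs1(p1, p2, result=['wins', 'draws'])
print(p1)  # PlayerRating with updated mu/sigma and Elo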
3 changes: 2 additions & 1 deletion ding/policy/common_utils.py
@@ -29,7 +29,8 @@ def default_preprocess_learn(
the following model forward and loss computation.
"""
# data preprocess
- if data[0]['action'].dtype in [np.int64, torch.int64]:
+ elem = data[0]
+ if isinstance(elem['action'], (np.ndarray, torch.Tensor)) and elem['action'].dtype in [np.int64, torch.int64]:
data = default_collate(data, cat_1dim=True) # for discrete action
else:
data = default_collate(data, cat_1dim=False) # for continuous action
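The extra isinstance guard matters because data[0]['action'] is not guaranteed to carry a .dtype attribute (it may be a plain Python scalar or list for some environments). A sketch of the dispatch predicate in isolation (hypothetical standalone helper, not part of the DI-engine API):

import numpy as np
import torch

def is_discrete_action(action) -> bool:
    # Only ndarray / Tensor carry .dtype; everything else falls through to the
    # continuous-action path (collate without 1-dim concatenation).
    return isinstance(action, (np.ndarray, torch.Tensor)) and action.dtype in [np.int64, torch.int64]

print(is_discrete_action(np.array([2], dtype=np.int64)))  # True  -> cat_1dim=True
print(is_discrete_action(torch.randn(3)))                 # False -> continuous branch
print(is_discrete_action(2))                              # False -> no .dtype, guard short-circuits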
6 changes: 4 additions & 2 deletions ding/policy/dt.py
@@ -136,7 +136,7 @@ def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]:
if self._basic_discrete_env:
actions = actions.to(torch.long)
actions = actions.squeeze(-1)
-     action_target = torch.clone(actions).detach().to(self._device)
+ action_target = torch.clone(actions).detach().to(self._device)

if self._atari_env:
state_preds, action_preds, return_preds = self._learn_model.forward(
@@ -291,7 +291,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
self.states[i, self.t[i]] = data[i]['obs'].to(self._device)
else:
self.states[i, self.t[i]] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std
- self.running_rtg[i] = self.running_rtg[i] - data[i]['reward'].to(self._device)
+ self.running_rtg[i] = self.running_rtg[i] - (data[i]['reward'] / self.rtg_scale).to(self._device)
self.rewards_to_go[i, self.t[i]] = self.running_rtg[i]

if self.t[i] <= self.context_len:
@@ -328,6 +328,8 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
act[i] = torch.multinomial(probs[i], num_samples=1)
else:
act = torch.argmax(logits, axis=1).unsqueeze(1)
+ else:
+     act = logits
for i in data_id:
self.actions[i, self.t[i]] = act[i] # TODO: self.actions[i] should be a queue when exceed max_t
self.t[i] += 1
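The second dt.py change divides each incoming reward by rtg_scale before subtracting it, so the running return-to-go stays in the same normalized units as the initial target return. A toy trace of that bookkeeping (hypothetical numbers, outside the policy class):

rtg_scale = 1000.0
running_rtg = 3600.0 / rtg_scale  # normalized target return-to-go: 3.6

for reward in [100.0, 50.0, 200.0]:
    # The old code subtracted the raw reward here, mixing scaled and unscaled units.
    running_rtg -= reward / rtg_scale
    print(round(running_rtg, 3))  # 3.5, then 3.45, then 3.25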
4 changes: 4 additions & 0 deletions ding/policy/tests/test_common_utils.py
@@ -25,6 +25,10 @@

def get_action(shape, dtype, class_type):
if class_type == "numpy":
if dtype == "int64":
dtype = np.int64
elif dtype == "float32":
dtype = np.float32
return np.random.randn(*shape).astype(dtype)
else:
if dtype == "int64":
55 changes: 38 additions & 17 deletions ding/torch_utils/loss/contrastive_loss.py
@@ -8,10 +8,11 @@

class ContrastiveLoss(nn.Module):
"""
- The class for contrastive learning losses.
- Only InfoNCE loss supported currently.
- Code Reference: https://github.com/rdevon/DIM.
- paper: https://arxiv.org/abs/1808.06670.
+ Overview:
+     The class for contrastive learning losses. Only InfoNCE loss is supported currently. \
+     Code Reference: https://github.com/rdevon/DIM. Paper Reference: https://arxiv.org/abs/1808.06670.
+ Interfaces:
+     ``__init__``, ``forward``.
"""

def __init__(
@@ -24,13 +25,18 @@ def __init__(
temperature: float = 1.0,
) -> None:
"""
Args:
x_size: input shape for x, both the obs shape and the encoding shape are supported.
y_size: input shape for y, both the obs shape and the encoding shape are supported.
heads: a list of 2 int elems, heads[0] for x and head[1] for y.
Overview:
Initialize the ContrastiveLoss object using the given arguments.
Arguments:
- x_size (:obj:`Union[int, SequenceType]`): input shape for x, both the obs shape and the encoding shape \
are supported.
- y_size (:obj:`Union[int, SequenceType]`): Input shape for y, both the obs shape and the encoding shape \
are supported.
- heads (:obj:`SequenceType`): A list of 2 int elems, ``heads[0]`` for x and ``head[1]`` for y. \
Used in multi-head, global-local, local-local MI maximization process.
loss_type: only the InfoNCE loss is available now.
temperature: the parameter to adjust the log_softmax.
- encoder_shape (:obj:`Union[int, SequenceType]`): The dimension of encoder hidden state.
- loss_type: Only the InfoNCE loss is available now.
- temperature: The parameter to adjust the ``log_softmax``.
"""
super(ContrastiveLoss, self).__init__()
assert len(heads) == 2, "Expected length of 2, but got: {}".format(len(heads))
@@ -43,7 +49,7 @@ def __init__(
self._y_encoder = self._get_encoder(y_size, heads[1])
self._temperature = temperature

- def _get_encoder(self, obs: Union[int, SequenceType], heads: int):
+ def _get_encoder(self, obs: Union[int, SequenceType], heads: int) -> nn.Module:
from ding.model import ConvEncoder, FCEncoder

if isinstance(obs, int):
@@ -61,14 +67,29 @@ def _get_encoder(self, obs: Union[int, SequenceType], heads: int):
encoder = ConvEncoder(obs, hidden_size_list, kernel_size=[4, 3, 2], stride=[2, 1, 1])
return encoder

- def forward(self, x: torch.Tensor, y: torch.Tensor):
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
- Computes the noise contrastive estimation-based loss, a.k.a. infoNCE.
- Args:
-     x: the input x, both raw obs and encoding are supported.
-     y: the input y, both raw obs and encoding are supported.
+ Overview:
+     Computes the noise contrastive estimation-based loss, a.k.a. infoNCE.
+ Arguments:
+     - x (:obj:`torch.Tensor`): The input x, both raw obs and encoding are supported.
+     - y (:obj:`torch.Tensor`): The input y, both raw obs and encoding are supported.
Returns:
- torch.Tensor: loss value.
+ - loss (:obj:`torch.Tensor`): The calculated loss value.
Examples:
    >>> x_dim = [3, 16]
    >>> encode_shape = 16
    >>> x = np.random.normal(0, 1, size=x_dim)
    >>> y = x ** 2 + 0.01 * np.random.normal(0, 1, size=x_dim)
    >>> estimator = ContrastiveLoss(x_dim[-1], x_dim[-1], encode_shape=encode_shape)
    >>> loss = estimator.forward(torch.FloatTensor(x), torch.FloatTensor(y))
Examples:
    >>> x_dim = [3, 1, 16, 16]
    >>> encode_shape = 16
    >>> x = np.random.normal(0, 1, size=x_dim)
    >>> y = x ** 2 + 0.01 * np.random.normal(0, 1, size=x_dim)
    >>> estimator = ContrastiveLoss(x_dim[1:], x_dim[1:], encode_shape=encode_shape)
    >>> loss = estimator.forward(torch.FloatTensor(x), torch.FloatTensor(y))
"""

N = x.size(0)
Expand Down
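The truncated forward above implements InfoNCE: a cross-entropy over similarity scores in which the matching (x, y) pair is the positive class. A minimal single-head sketch of the objective (hypothetical linear encoders, not DI-engine's multi-head version):

import torch
import torch.nn as nn
import torch.nn.functional as F

def info_nce(x_enc: torch.Tensor, y_enc: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # scores[i, j] scores x_i against y_j; diagonal entries are the positives.
    scores = x_enc @ y_enc.t() / temperature  # (N, N)
    labels = torch.arange(x_enc.size(0))      # row i's positive column is i
    return F.cross_entropy(scores, labels)    # mean of -log_softmax(scores)[i, i]

x_encoder, y_encoder = nn.Linear(16, 8), nn.Linear(16, 8)
x = torch.randn(32, 16)
y = x + 0.01 * torch.randn(32, 16)  # correlated views of the same samples
loss = info_nce(x_encoder(x), y_encoder(y))
print(loss.item())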
