From f7f4d04aeec0858e1f9104cb8ef55bd5efcb0d58 Mon Sep 17 00:00:00 2001
From: xueyingyi <3245896298@qq.com>
Date: Wed, 27 Nov 2024 21:46:48 +0800
Subject: [PATCH] feature(xyy): add HPT model to implement PolicyStem+DuelingHead

---
 ding/model/template/__init__.py              |   2 +-
 ding/model/template/hpt.py                   | 143 +++++++++++++++++-
 ding/model/template/policy_stem.py           | 133 ----------------
 .../entry/lunarlander_dqn_example.py         |  12 +-
 .../entry/lunarlander_hpt_example.py         |  21 +--
 5 files changed, 145 insertions(+), 166 deletions(-)
 delete mode 100644 ding/model/template/policy_stem.py

diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py
index d121d7f30e..de506123ba 100755
--- a/ding/model/template/__init__.py
+++ b/ding/model/template/__init__.py
@@ -31,4 +31,4 @@
 from .qgpo import QGPO
 from .ebm import EBM, AutoregressiveEBM
 from .havac import HAVAC
-from .policy_stem import PolicyStem
+
diff --git a/ding/model/template/hpt.py b/ding/model/template/hpt.py
index ba550a8549..e7a2960b2b 100644
--- a/ding/model/template/hpt.py
+++ b/ding/model/template/hpt.py
@@ -1,25 +1,152 @@
 from typing import Union, Optional, Dict, Callable, List
+from einops import rearrange, repeat
 import torch
 import torch.nn as nn
-
 from ding.model.common.head import DuelingHead
 from ding.utils.registry_factory import MODEL_REGISTRY
-from ding.model.template.policy_stem import PolicyStem
+
+
+INIT_CONST = 0.02
+
 @MODEL_REGISTRY.register('hpt')
 class HPT(nn.Module):
     def __init__(self, state_dim, action_dim):
         super(HPT, self).__init__()
-        # 初始化 Policy Stem
+        # Initialize the Policy Stem
         self.policy_stem = PolicyStem()
         self.policy_stem.init_cross_attn()
-        # Dueling Head,输入为 16*128,输出为动作维度
+        # Dueling Head: the input size is 16 * 128, the output size is the action dimension
         self.head = DuelingHead(hidden_size=16*128, output_size=action_dim)
 
     def forward(self, x):
-        # Policy Stem 输出 [B, 16, 128]
+        # The Policy Stem outputs latent tokens of shape [B, 16, 128]
         tokens = self.policy_stem.compute_latent(x)
-        # Flatten 操作
+        # Flatten the tokens
         tokens_flattened = tokens.view(tokens.size(0), -1)  # [B, 16*128]
-        # 输入到 Dueling Head
+        # Feed the flattened tokens into the Dueling Head
         q_values = self.head(tokens_flattened)
-        return q_values
\ No newline at end of file
+        return q_values
+
+
+class PolicyStem(nn.Module):
+    """
+    Overview:
+        Policy stem that encodes observations into a fixed set of latent tokens via cross-attention,
+        following the PolicyStem design of the HPT reference implementation.
+    """
+    def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs):
+        super().__init__()
+        # Initialize the feature extraction module
+        self.feature_extractor = nn.Linear(feature_dim, token_dim)
+        # Initialize the cross-attention module
+        self.init_cross_attn()
+
+    def init_cross_attn(self):
+        """Initialize cross attention module and learnable tokens."""
+        token_num = 16
+        self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST)
+        self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1)
+
+    def compute_latent(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Compute latent representations of input data using attention.
+
+        Args:
+            x (torch.Tensor): Input tensor with shape [B, T, ..., F].
+
+        Returns:
+            torch.Tensor: Latent tokens, shape [B, 16, 128].
+ """ + # Using the Feature Extractor + stem_feat = self.feature_extractor(x) + stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) + # Calculating latent tokens using CrossAttention + stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1) # (B, 16, 128) + stem_tokens = self.cross_attention(stem_tokens, stem_feat) # (B, 16, 128) + return stem_tokens + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass to compute latent tokens. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Latent tokens tensor. + """ + return self.compute_latent(x) + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + + def unfreeze(self): + for param in self.parameters(): + param.requires_grad = True + + def save(self, path : str): + torch.save(self.state_dict(), path) + + @property + def device(self): + return next(self.parameters()).device + +class CrossAttention(nn.Module): + """ + CrossAttention module used in the Perceiver IO model. + + Args: + query_dim (int): The dimension of the query input. + heads (int, optional): The number of attention heads. Defaults to 8. + dim_head (int, optional): The dimension of each attention head. Defaults to 64. + dropout (float, optional): The dropout probability. Defaults to 0.0. + """ + + def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = query_dim + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, query_dim) + + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor, context: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Forward pass of the CrossAttention module. + + Args: + x (torch.Tensor): The query input tensor. + context (torch.Tensor): The context input tensor. + mask (torch.Tensor, optional): The attention mask tensor. Defaults to None. + + Returns: + torch.Tensor: The output tensor. + """ + h = self.heads + q = self.to_q(x) + k, v = self.to_kv(context).chunk(2, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale + + if mask is not None: + # fill in the masks with negative values + mask = rearrange(mask, "b ... 
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, "b j -> (b h) () j", h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+
+        # attention, what we cannot get enough of
+        attn = sim.softmax(dim=-1)
+
+        # dropout
+        attn = self.dropout(attn)
+        out = torch.einsum("b i j, b j d -> b i d", attn, v)
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+        return self.to_out(out)
\ No newline at end of file
diff --git a/ding/model/template/policy_stem.py b/ding/model/template/policy_stem.py
deleted file mode 100644
index 81963cdbc4..0000000000
--- a/ding/model/template/policy_stem.py
+++ /dev/null
@@ -1,133 +0,0 @@
-# --------------------------------------------------------
-# Licensed under The MIT License [see LICENSE for details]
-# --------------------------------------------------------
-
-
-import torch
-from torch import nn
-from typing import List, Optional
-
-
-import torch
-import torch.nn as nn
-INIT_CONST = 0.02
-from einops import rearrange, repeat, reduce
-class CrossAttention(nn.Module):
-    """
-    CrossAttention module used in the Perceiver IO model.
-
-    Args:
-        query_dim (int): The dimension of the query input.
-        heads (int, optional): The number of attention heads. Defaults to 8.
-        dim_head (int, optional): The dimension of each attention head. Defaults to 64.
-        dropout (float, optional): The dropout probability. Defaults to 0.0.
-    """
-
-    def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0):
-        super().__init__()
-        inner_dim = dim_head * heads
-        context_dim = query_dim
-        self.scale = dim_head**-0.5
-        self.heads = heads
-
-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
-        self.to_out = nn.Linear(inner_dim, query_dim)
-
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x: torch.Tensor, context: torch.Tensor,
-                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """
-        Forward pass of the CrossAttention module.
-
-        Args:
-            x (torch.Tensor): The query input tensor.
-            context (torch.Tensor): The context input tensor.
-            mask (torch.Tensor, optional): The attention mask tensor. Defaults to None.
-
-        Returns:
-            torch.Tensor: The output tensor.
-        """
-        h = self.heads
-        q = self.to_q(x)
-        k, v = self.to_kv(context).chunk(2, dim=-1)
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
-        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
-
-        if mask is not None:
-            # fill in the masks with negative values
-            mask = rearrange(mask, "b ... -> b (...)")
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, "b j -> (b h) () j", h=h)
-            sim.masked_fill_(~mask, max_neg_value)
-
-        # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
-
-        # dropout
-        attn = self.dropout(attn)
-        out = torch.einsum("b i j, b j d -> b i d", attn, v)
-        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
-        return self.to_out(out)
-
-class PolicyStem(nn.Module):
-    """policy stem"""
-
-    def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs):
-        super().__init__()
-        # 初始化特征提取模块
-        self.feature_extractor = nn.Linear(feature_dim, token_dim)
-        # 初始化 CrossAttention
-        self.init_cross_attn()
-
-    def init_cross_attn(self):
-        """Initialize cross attention module and learnable tokens."""
-        token_num = 16
-        self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST)
-        self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1)
-
-    def compute_latent(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Compute latent representations of input data using attention.
-
-        Args:
-            x (torch.Tensor): Input tensor with shape [B, T, ..., F].
-
-        Returns:
-            torch.Tensor: Latent tokens, shape [B, 16, 128].
-        """
-        # 使用特征提取器而不是直接调用 self(x)
-        stem_feat = self.feature_extractor(x)
-        stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1])  # (B, N, 128)
-        # 使用 CrossAttention 计算 latent tokens
-        stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1)  # (B, 16, 128)
-        stem_tokens = self.cross_attention(stem_tokens, stem_feat)  # (B, 16, 128)
-        return stem_tokens
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass to compute latent tokens.
-
-        Args:
-            x (torch.Tensor): Input tensor.
-
-        Returns:
-            torch.Tensor: Latent tokens tensor.
-        """
-        return self.compute_latent(x)
-
-    def freeze(self):
-        for param in self.parameters():
-            param.requires_grad = False
-
-    def unfreeze(self):
-        for param in self.parameters():
-            param.requires_grad = True
-
-    def save(self, path : str):
-        torch.save(self.state_dict(), path)
-
-    @property
-    def device(self):
-        return next(self.parameters()).device
-
diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py
index bfcc1ab1d9..2ca2c06361 100644
--- a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py
+++ b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py
@@ -1,4 +1,5 @@
 import gym
+import torch
 from ditk import logging
 from ding.data.model_loader import FileModelLoader
 from ding.data.storage_loader import FileStorageLoader
@@ -15,7 +16,7 @@
 from ding.utils import set_pkg_seed
 from dizoo.box2d.lunarlander.config.lunarlander_dqn_config import main_config, create_config
-import torch
+
 
 
 def main():
@@ -35,19 +36,14 @@ def main():
     set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
 
-    # 迁移模型到 GPU
+    # Migrate the model to the GPU
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = DQN(**cfg.policy.model).to(device)
-    # 检查模型是否在 GPU
-    for param in model.parameters():
-        print("模型参数所在设备:", param.device)
-        break
     buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
-    # 将模型传入 Policy
+    # Pass the model to the policy
     policy = DQNPolicy(cfg.policy, model=model)
-    print("日志保存路径:", cfg.exp_name)
 
     # Consider the case with multiple processes
diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py
index b9dd6d8681..1410648dd3 100644
--- a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py
+++ b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py
@@ -1,5 +1,7 @@
 import gym
+import torch
+import torch.nn as nn
 from ditk import logging
 from ding.data.model_loader import FileModelLoader
 from ding.data.storage_loader import FileStorageLoader
@@ -16,8 +18,7 @@ nstep_reward_enhancer
 from ding.utils import set_pkg_seed
 from dizoo.box2d.lunarlander.config.lunarlander_hpt_config import main_config, create_config
-import torch
-import torch.nn as nn
+
 
 
@@ -38,24 +39,13 @@ def main():
     set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
 
-    # 迁移模型到 GPU
+    # Migrate the model to the GPU
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    # model = DQN(**cfg.policy.model).to(device)
     model = HPT(cfg.policy.model.obs_shape,cfg.policy.model.action_shape).to(device)
-
-
-
-    # 检查模型是否在 GPU
-    for param in model.parameters():
-        print("模型参数所在设备:", param.device)
-        break
     buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
-    # 将模型传入 Policy
+    # Pass the model to the policy
     policy = DQNPolicy(cfg.policy, model=model)
-    print("日志保存路径:", cfg.exp_name)
-
     # Consider the case with multiple processes
     if task.router.is_active:
@@ -74,7 +64,6 @@ def main():
 
     # Here is the part of single process pipeline.
-    # evaluator_env.enable_save_replay(replay_path='./video')
     task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
     task.use(eps_greedy_handler(cfg))
     task.use(StepCollector(cfg, policy.collect_mode, collector_env))
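
A minimal usage sketch of the new 'hpt' model, for reference. It assumes the LunarLander shapes used in the example entry above (an 8-dimensional observation and 4 discrete actions) and that DuelingHead returns its Q-values under the 'logit' key, as DI-engine's common heads do; adjust the shapes for other environments.

import torch
from ding.model.template.hpt import HPT

# Assumed LunarLander dimensions: obs_shape=8, action_shape=4.
model = HPT(state_dim=8, action_dim=4)

obs = torch.randn(32, 8)     # a batch of 32 observations with feature dim 8
out = model(obs)             # PolicyStem tokens [32, 16, 128] -> flatten -> DuelingHead
q_values = out['logit']      # expected shape [32, 4], one Q-value per action
print(q_values.shape)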