feature(xyy): add HPT model to implement PolicyStem + DuelingHead
luodi-7 committed Nov 27, 2024
1 parent 9667211 commit f7f4d04
Showing 5 changed files with 145 additions and 166 deletions.
2 changes: 1 addition & 1 deletion ding/model/template/__init__.py
@@ -31,4 +31,4 @@
from .qgpo import QGPO
from .ebm import EBM, AutoregressiveEBM
from .havac import HAVAC
from .policy_stem import PolicyStem

143 changes: 135 additions & 8 deletions ding/model/template/hpt.py
@@ -1,25 +1,152 @@
from typing import Optional
from einops import rearrange, repeat
import torch
import torch.nn as nn

from ding.model.common.head import DuelingHead
from ding.utils.registry_factory import MODEL_REGISTRY

INIT_CONST = 0.02


@MODEL_REGISTRY.register('hpt')
class HPT(nn.Module):

    def __init__(self, state_dim: int, action_dim: int):
        super(HPT, self).__init__()
        # Initialise the Policy Stem
        self.policy_stem = PolicyStem()
        self.policy_stem.init_cross_attn()

        # Dueling Head: input is 16 * 128, output is the action dimension
        self.head = DuelingHead(hidden_size=16 * 128, output_size=action_dim)

    def forward(self, x):
        # Policy Stem outputs [B, 16, 128]
        tokens = self.policy_stem.compute_latent(x)
        # Flatten to [B, 16 * 128]
        tokens_flattened = tokens.view(tokens.size(0), -1)
        # Feed the flattened tokens into the Dueling Head
        q_values = self.head(tokens_flattened)
        return q_values



class PolicyStem(nn.Module):
    """
    Overview:
        Policy stem module, adapted from PolicyStem in
        <https://github.com/liruiw/HPT/blob/main/hpt/models/policy_stem.py>.
    """
    def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs):
        super().__init__()
        # Initialise the feature extraction module
        self.feature_extractor = nn.Linear(feature_dim, token_dim)
        # Initialise CrossAttention
        self.init_cross_attn()

    def init_cross_attn(self):
        """Initialize cross attention module and learnable tokens."""
        token_num = 16
        self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST)
        self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1)

    def compute_latent(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute latent representations of input data using attention.
        Args:
            x (torch.Tensor): Input tensor with shape [B, T, ..., F].
        Returns:
            torch.Tensor: Latent tokens, shape [B, 16, 128].
        """
        # Apply the feature extractor
        stem_feat = self.feature_extractor(x)
        stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1])  # (B, N, 128)
        # Compute latent tokens with CrossAttention
        stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1)  # (B, 16, 128)
        stem_tokens = self.cross_attention(stem_tokens, stem_feat)  # (B, 16, 128)
        return stem_tokens
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass to compute latent tokens.
        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            torch.Tensor: Latent tokens tensor.
        """
        return self.compute_latent(x)

    def freeze(self):
        for param in self.parameters():
            param.requires_grad = False

    def unfreeze(self):
        for param in self.parameters():
            param.requires_grad = True

    def save(self, path: str):
        torch.save(self.state_dict(), path)

    @property
    def device(self):
        return next(self.parameters()).device

class CrossAttention(nn.Module):
    """
    CrossAttention module used in the Perceiver IO model.
    Args:
        query_dim (int): The dimension of the query input.
        heads (int, optional): The number of attention heads. Defaults to 8.
        dim_head (int, optional): The dimension of each attention head. Defaults to 64.
        dropout (float, optional): The dropout probability. Defaults to 0.0.
    """

    def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = query_dim
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, context: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass of the CrossAttention module.
        Args:
            x (torch.Tensor): The query input tensor.
            context (torch.Tensor): The context input tensor.
            mask (torch.Tensor, optional): The attention mask tensor. Defaults to None.
        Returns:
            torch.Tensor: The output tensor.
        """
        h = self.heads
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

        if mask is not None:
            # Fill masked positions with a large negative value so they
            # receive (near-)zero attention weight after the softmax.
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # Attention weights
        attn = sim.softmax(dim=-1)
        # Dropout on the attention weights
        attn = self.dropout(attn)
        out = torch.einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)
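
For reference, a minimal smoke test of the new forward path. This is a sketch, not part of the commit: it assumes DI-engine and einops are installed and this hpt.py is importable, that an 8-dim observation matches PolicyStem's default feature_dim, and that DuelingHead follows DI-engine's head convention of returning a dict with a 'logit' field.

import torch

from ding.model.template.hpt import HPT, CrossAttention

# HPT end to end: [B, 8] -> PolicyStem -> [B, 16, 128] -> flatten -> DuelingHead.
model = HPT(state_dim=8, action_dim=4)
obs = torch.randn(32, 8)  # [B, F] with F == feature_dim == 8
out = model(obs)
print(out['logit'].shape)  # expected: torch.Size([32, 4])

# CrossAttention also accepts an optional boolean mask over context positions
# (True = attend). A quick shape check:
attn = CrossAttention(128, heads=8, dim_head=64)
queries = torch.randn(2, 16, 128)
context = torch.randn(2, 10, 128)
mask = torch.ones(2, 10, dtype=torch.bool)
print(attn(queries, context, mask).shape)  # expected: torch.Size([2, 16, 128])
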
133 changes: 0 additions & 133 deletions ding/model/template/policy_stem.py

This file was deleted.

12 changes: 4 additions & 8 deletions dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py
@@ -1,4 +1,5 @@
import gym
import torch
from ditk import logging
from ding.data.model_loader import FileModelLoader
from ding.data.storage_loader import FileStorageLoader
@@ -15,7 +16,7 @@
from ding.utils import set_pkg_seed
from dizoo.box2d.lunarlander.config.lunarlander_dqn_config import main_config, create_config



def main():
@@ -35,19 +36,14 @@ def main():

    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

    # Migrate the model to the GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DQN(**cfg.policy.model).to(device)

    # Check whether the model is on the GPU
    for param in model.parameters():
        print("Model parameters are on device:", param.device)
        break
    buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)

    # Pass the model into the policy
    policy = DQNPolicy(cfg.policy, model=model)
    print("Log save path:", cfg.exp_name)


    # Consider the case with multiple processes
21 changes: 5 additions & 16 deletions dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py
@@ -1,5 +1,7 @@

import gym
import torch
import torch.nn as nn
from ditk import logging
from ding.data.model_loader import FileModelLoader
from ding.data.storage_loader import FileStorageLoader
@@ -16,8 +18,7 @@
nstep_reward_enhancer
from ding.utils import set_pkg_seed
from dizoo.box2d.lunarlander.config.lunarlander_hpt_config import main_config, create_config


@@ -38,24 +39,13 @@ def main():

    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

    # Migrate the model to the GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model = DQN(**cfg.policy.model).to(device)
    model = HPT(cfg.policy.model.obs_shape, cfg.policy.model.action_shape).to(device)

    # Check whether the model is on the GPU
    for param in model.parameters():
        print("Model parameters are on device:", param.device)
        break
    buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)

    # Pass the model into the policy
    policy = DQNPolicy(cfg.policy, model=model)
    print("Log save path:", cfg.exp_name)


    # Consider the case with multiple processes
    if task.router.is_active:
@@ -74,7 +64,6 @@


    # Here is the single-process part of the pipeline.
    # evaluator_env.enable_save_replay(replay_path='./video')
    task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
    task.use(eps_greedy_handler(cfg))
    task.use(StepCollector(cfg, policy.collect_mode, collector_env))
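
For context, the model construction this entry performs can be exercised standalone. A minimal sketch, assuming LunarLander's config uses obs_shape=8 and action_shape=4; the full training run additionally needs the DI-engine config, environments, and the middleware pipeline above.

import torch

from ding.model.template.hpt import HPT

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HPT(8, 4).to(device)  # obs_shape=8, action_shape=4 assumed for LunarLander
# Check device placement via a single parameter rather than looping over all of them.
print(next(model.parameters()).device)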
