From 6bd2dfe737b7ef5d0f5a152ff11a85b981f0b998 Mon Sep 17 00:00:00 2001
From: Junyu Chen
Date: Sun, 24 Nov 2024 20:25:09 +0800
Subject: [PATCH] remove triton dependency of DC-AE and fix bugs;

Signed-off-by: lawrence-cj
---
Review notes (text between "---" and the diffstat is ignored by `git am`):
- RMSNorm2d initializes weight/bias to ones/zeros instead of leaving
  torch.empty() storage undefined when no checkpoint is loaded.
- RMSNorm2d.forward no longer dereferences self.bias when the module was
  built with bias=False (bias is registered as None in that case).
- The norm.py hunk header and the diffstat were updated for the added lines,
  and the trailing blank context lines of the last hunk were dropped. The
  post-image blob id on the norm.py "index" line predates these edits.

 .../efficientvit/models/efficientvit/dc_ae.py |  8 ++--
 .../dc_ae/efficientvit/models/nn/norm.py      | 37 +++++++++++++++++++
 2 files changed, 41 insertions(+), 4 deletions(-)

diff --git a/diffusion/model/dc_ae/efficientvit/models/efficientvit/dc_ae.py b/diffusion/model/dc_ae/efficientvit/models/efficientvit/dc_ae.py
index 64f162c..90c24a3 100644
--- a/diffusion/model/dc_ae/efficientvit/models/efficientvit/dc_ae.py
+++ b/diffusion/model/dc_ae/efficientvit/models/efficientvit/dc_ae.py
@@ -48,7 +48,7 @@ class EncoderConfig:
     width_list: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024)
     depth_list: tuple[int, ...] = (2, 2, 2, 2, 2, 2)
     block_type: Any = "ResBlock"
-    norm: str = "trms2d"
+    norm: str = "rms2d"
     act: str = "silu"
     downsample_block_type: str = "ConvPixelUnshuffle"
     downsample_match_channel: bool = True
@@ -67,12 +67,12 @@ class DecoderConfig:
     width_list: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024)
     depth_list: tuple[int, ...] = (2, 2, 2, 2, 2, 2)
     block_type: Any = "ResBlock"
-    norm: Any = "trms2d"
+    norm: Any = "rms2d"
     act: Any = "silu"
     upsample_block_type: str = "ConvPixelShuffle"
     upsample_match_channel: bool = True
     upsample_shortcut: str = "duplicating"
-    out_norm: str = "trms2d"
+    out_norm: str = "rms2d"
     out_act: str = "relu"
 
 
@@ -470,7 +470,7 @@ def dc_ae_f32c32(name: str, pretrained_path: str) -> DCAEConfig:
             "decoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
             "decoder.width_list=[128,256,512,512,1024,1024] decoder.depth_list=[3,3,3,3,3,3] "
             "decoder.upsample_block_type=InterpolateConv "
-            "decoder.norm=trms2d decoder.act=silu "
+            "decoder.norm=rms2d decoder.act=silu "
             "scaling_factor=0.41407"
         )
     else:
diff --git a/diffusion/model/dc_ae/efficientvit/models/nn/norm.py b/diffusion/model/dc_ae/efficientvit/models/nn/norm.py
index 5e62beb..2b88672 100644
--- a/diffusion/model/dc_ae/efficientvit/models/nn/norm.py
+++ b/diffusion/model/dc_ae/efficientvit/models/nn/norm.py
@@ -40,10 +40,47 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return TritonRMSNorm2dFunc.apply(x, self.weight, self.bias, self.eps)
 
 
+class RMSNorm2d(nn.Module):
+    """RMS normalization over the channel dimension (dim 1), in plain PyTorch.
+
+    The input is divided by the root-mean-square of its channel values, computed
+    in float32, and then optionally scaled and shifted per channel.
+    """
+
+    def __init__(
+        self, num_features: int, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True
+    ) -> None:
+        super().__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+        if self.elementwise_affine:
+            # Identity init: torch.empty() would leave the values undefined unless a checkpoint is loaded.
+            self.weight = torch.nn.parameter.Parameter(torch.ones(self.num_features))
+            if bias:
+                self.bias = torch.nn.parameter.Parameter(torch.zeros(self.num_features))
+            else:
+                self.register_parameter("bias", None)
+        else:
+            self.register_parameter("weight", None)
+            self.register_parameter("bias", None)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # The mean square is taken in float32; the result is cast back to the input dtype.
+        x = (x / torch.sqrt(torch.square(x.float()).mean(dim=1, keepdim=True) + self.eps)).to(x.dtype)
+        if self.elementwise_affine:
+            x = x * self.weight.view(1, -1, 1, 1)
+            # bias is None when the module was built with bias=False.
+            if self.bias is not None:
+                x = x + self.bias.view(1, -1, 1, 1)
+        return x
+
+
 # register normalization function here
 REGISTERED_NORM_DICT: dict[str, type] = {
     "bn2d": nn.BatchNorm2d,
     "ln": nn.LayerNorm,
     "ln2d": LayerNorm2d,
     "trms2d": TritonRMSNorm2d,
+    "rms2d": RMSNorm2d,
 }