Skip to content

Commit 2a50a7f

Browse files
lawrence-cjchenjy2003
andauthoredNov 24, 2024··
remove triton dependency of DC-AE and fix bugs; (#38)
Signed-off-by: lawrence-cj <cjs1020440147@icloud.com> Co-authored-by: Junyu Chen <chenjydl2003@gmail.com>
1 parent fa267d5 commit 2a50a7f

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed
 

‎diffusion/model/dc_ae/efficientvit/models/efficientvit/dc_ae.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class EncoderConfig:
4848
width_list: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024)
4949
depth_list: tuple[int, ...] = (2, 2, 2, 2, 2, 2)
5050
block_type: Any = "ResBlock"
51-
norm: str = "trms2d"
51+
norm: str = "rms2d"
5252
act: str = "silu"
5353
downsample_block_type: str = "ConvPixelUnshuffle"
5454
downsample_match_channel: bool = True
@@ -67,12 +67,12 @@ class DecoderConfig:
6767
width_list: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024)
6868
depth_list: tuple[int, ...] = (2, 2, 2, 2, 2, 2)
6969
block_type: Any = "ResBlock"
70-
norm: Any = "trms2d"
70+
norm: Any = "rms2d"
7171
act: Any = "silu"
7272
upsample_block_type: str = "ConvPixelShuffle"
7373
upsample_match_channel: bool = True
7474
upsample_shortcut: str = "duplicating"
75-
out_norm: str = "trms2d"
75+
out_norm: str = "rms2d"
7676
out_act: str = "relu"
7777

7878

@@ -470,7 +470,7 @@ def dc_ae_f32c32(name: str, pretrained_path: str) -> DCAEConfig:
470470
"decoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
471471
"decoder.width_list=[128,256,512,512,1024,1024] decoder.depth_list=[3,3,3,3,3,3] "
472472
"decoder.upsample_block_type=InterpolateConv "
473-
"decoder.norm=trms2d decoder.act=silu "
473+
"decoder.norm=rms2d decoder.act=silu "
474474
"scaling_factor=0.41407"
475475
)
476476
else:

‎diffusion/model/dc_ae/efficientvit/models/nn/norm.py

+26
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,38 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4040
return TritonRMSNorm2dFunc.apply(x, self.weight, self.bias, self.eps)
4141

4242

43+
class RMSNorm2d(nn.Module):
44+
def __init__(
45+
self, num_features: int, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True
46+
) -> None:
47+
super().__init__()
48+
self.num_features = num_features
49+
self.eps = eps
50+
self.elementwise_affine = elementwise_affine
51+
if self.elementwise_affine:
52+
self.weight = torch.nn.parameter.Parameter(torch.empty(self.num_features))
53+
if bias:
54+
self.bias = torch.nn.parameter.Parameter(torch.empty(self.num_features))
55+
else:
56+
self.register_parameter("bias", None)
57+
else:
58+
self.register_parameter("weight", None)
59+
self.register_parameter("bias", None)
60+
61+
def forward(self, x: torch.Tensor) -> torch.Tensor:
62+
x = (x / torch.sqrt(torch.square(x.float()).mean(dim=1, keepdim=True) + self.eps)).to(x.dtype)
63+
if self.elementwise_affine:
64+
x = x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
65+
return x
66+
67+
4368
# register normalization function here
4469
REGISTERED_NORM_DICT: dict[str, type] = {
4570
"bn2d": nn.BatchNorm2d,
4671
"ln": nn.LayerNorm,
4772
"ln2d": LayerNorm2d,
4873
"trms2d": TritonRMSNorm2d,
74+
"rms2d": RMSNorm2d,
4975
}
5076

5177

0 commit comments

Comments
 (0)
Please sign in to comment.