
Commit 90d3727 (parent cc5991b), committed by lawrence-cj

pre-commit

Signed-off-by: lawrence-cj <[email protected]>

9 files changed: +62 -49 lines

diffusion/model/nets/__init__.py (+10 -10)

@@ -1,43 +1,43 @@
 from .sana import (
-    Sana,
-    SanaBlock,
-    get_2d_sincos_pos_embed,
-    get_2d_sincos_pos_embed_from_grid,
+    Sana,
+    SanaBlock,
     get_1d_sincos_pos_embed_from_grid,
+    get_2d_sincos_pos_embed,
+    get_2d_sincos_pos_embed_from_grid,
 )
 from .sana_multi_scale import (
-    SanaMSBlock,
-    SanaMS,
-    SanaMS_600M_P1_D28,
+    SanaMS,
+    SanaMS_600M_P1_D28,
     SanaMS_600M_P2_D28,
     SanaMS_600M_P4_D28,
     SanaMS_1600M_P1_D20,
     SanaMS_1600M_P2_D20,
+    SanaMSBlock,
 )
 from .sana_multi_scale_adaln import (
-    SanaMSAdaLNBlock,
     SanaMSAdaLN,
     SanaMSAdaLN_600M_P1_D28,
     SanaMSAdaLN_600M_P2_D28,
     SanaMSAdaLN_600M_P4_D28,
     SanaMSAdaLN_1600M_P1_D20,
     SanaMSAdaLN_1600M_P2_D20,
+    SanaMSAdaLNBlock,
 )
 from .sana_U_shape import (
-    SanaUBlock,
     SanaU,
     SanaU_600M_P1_D28,
     SanaU_600M_P2_D28,
     SanaU_600M_P4_D28,
     SanaU_1600M_P1_D20,
     SanaU_1600M_P2_D20,
+    SanaUBlock,
 )
 from .sana_U_shape_multi_scale import (
-    SanaUMSBlock,
     SanaUMS,
     SanaUMS_600M_P1_D28,
     SanaUMS_600M_P2_D28,
     SanaUMS_600M_P4_D28,
     SanaUMS_1600M_P1_D20,
     SanaUMS_1600M_P2_D20,
+    SanaUMSBlock,
 )

diffusion/model/nets/sana.py (+8 -3)

@@ -43,7 +43,8 @@
 
 _triton_modules_available = False
 if is_triton_module_available():
-    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
+    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
+
     _triton_modules_available = True
 
 
@@ -84,7 +85,9 @@ def __init__(
             self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
         elif attn_type == "triton_linear":
            if not _triton_modules_available:
-                raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.")
+                raise ValueError(
+                    f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
+                )
             # linear self attention with triton kernel fusion
             # TODO: Here the num_heads set to 36 for tmp used
             self_num_heads = hidden_size // linear_head_dim
@@ -131,7 +134,9 @@ def __init__(
             )
         elif ffn_type == "triton_mbconvpreglu":
             if not _triton_modules_available:
-                raise ValueError(f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}.")
+                raise ValueError(
+                    f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
+                )
             self.mlp = TritonMBConvPreGLU(
                 in_dim=hidden_size,
                 out_dim=hidden_size,

diffusion/model/nets/sana_U_shape.py (+7 -9)

@@ -23,11 +23,6 @@
 
 from diffusion.model.builder import MODELS
 from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
-try:
-    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
-except ImportError:
-    import warnings
-    warnings.warn("TritonLiteMLA with `triton` is not available on your platform.")
 from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
 from diffusion.model.nets.sana_blocks import (
     Attention,
@@ -41,13 +36,14 @@
     t2i_modulate,
 )
 from diffusion.model.norms import RMSNorm
-from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
-from diffusion.utils.logger import get_root_logger
+from diffusion.model.utils import auto_grad_checkpoint
 from diffusion.utils.import_utils import is_triton_module_available
+from diffusion.utils.logger import get_root_logger
 
 _triton_modules_available = False
 if is_triton_module_available():
-    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
+    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
+
     _triton_modules_available = True
 
 
@@ -88,7 +84,9 @@ def __init__(
             self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
         elif attn_type == "triton_linear":
             if not _triton_modules_available:
-                raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.")
+                raise ValueError(
+                    f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
+                )
             # linear self attention with triton kernel fusion
             # TODO: Here the num_heads set to 36 for tmp used
             self_num_heads = hidden_size // 32

diffusion/model/nets/sana_U_shape_multi_scale.py (+7 -4)

@@ -29,18 +29,19 @@
     LiteLA,
     MultiHeadCrossAttention,
     PatchEmbedMS,
-    SizeEmbedder,
     T2IFinalLayer,
     t2i_modulate,
 )
-from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
+from diffusion.model.utils import auto_grad_checkpoint
 from diffusion.utils.import_utils import is_triton_module_available
 
 _triton_modules_available = False
 if is_triton_module_available():
-    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
+    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
+
     _triton_modules_available = True
 
+
 class SanaUMSBlock(nn.Module):
     """
     A SanaU block with global shared adaptive layer norm (adaLN-single) conditioning and U-shaped model.
@@ -79,7 +80,9 @@ def __init__(
             self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
         elif attn_type == "triton_linear":
             if not _triton_modules_available:
-                raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.")
+                raise ValueError(
+                    f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
+                )
             # linear self attention with triton kernel fusion
             self_num_heads = hidden_size // 32
             self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)

diffusion/model/nets/sana_blocks.py (+2 -2)

@@ -22,19 +22,19 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from diffusion.utils.import_utils import is_xformers_available
 from einops import rearrange
 from timm.models.vision_transformer import Attention as Attention_
 from timm.models.vision_transformer import Mlp
 from transformers import AutoModelForCausalLM
 
 from diffusion.model.norms import RMSNorm
 from diffusion.model.utils import get_same_padding, to_2tuple
-
+from diffusion.utils.import_utils import is_xformers_available
 
 _xformers_available = False
 if is_xformers_available():
     import xformers.ops
+
     _xformers_available = True
 
 

diffusion/model/nets/sana_multi_scale.py (+10 -5)

@@ -37,14 +37,15 @@
 
 _triton_modules_available = False
 if is_triton_module_available():
-    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
+    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
+
     _triton_modules_available = True
 
 _xformers_available = False
 if is_xformers_available():
-    import xformers.ops
     _xformers_available = True
 
+
 class SanaMSBlock(nn.Module):
     """
     A Sana block with global shared adaptive layer norm zero (adaLN-Zero) conditioning.
@@ -84,7 +85,9 @@ def __init__(
             self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
         elif attn_type == "triton_linear":
             if not _triton_modules_available:
-                raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.")
+                raise ValueError(
+                    f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
+                )
             # linear self attention with triton kernel fusion
             self_num_heads = hidden_size // linear_head_dim
             self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
@@ -120,7 +123,9 @@ def __init__(
             )
         elif ffn_type == "triton_mbconvpreglu":
             if not _triton_modules_available:
-                raise ValueError(f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}.")
+                raise ValueError(
+                    f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
+                )
             self.mlp = TritonMBConvPreGLU(
                 in_dim=hidden_size,
                 out_dim=hidden_size,
@@ -316,7 +321,7 @@ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
             y_lens = [y.shape[2]] * y.shape[0]
             y = y.squeeze(1).view(1, -1, x.shape[-1])
         else:
-            raise ValueError(f"{attn_type} type is not available due to _xformers_available={_xformers_available}.")
+            raise ValueError(f"Attention type is not available due to _xformers_available={_xformers_available}.")
 
         for block in self.blocks:
             x = auto_grad_checkpoint(

diffusion/model/nets/sana_multi_scale_adaln.py (+6 -8)

@@ -21,11 +21,6 @@
 
 from diffusion.model.builder import MODELS
 from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
-try:
-    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
-except ImportError:
-    import warnings
-    warnings.warn("TritonLiteMLA with `triton` is not available on your platform.")
 from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
 from diffusion.model.nets.sana_blocks import (
     Attention,
@@ -38,12 +33,13 @@
     T2IFinalLayer,
     modulate,
 )
-from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
+from diffusion.model.utils import auto_grad_checkpoint
 from diffusion.utils.import_utils import is_triton_module_available
 
 _triton_modules_available = False
 if is_triton_module_available():
-    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
+    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
+
     _triton_modules_available = True
 
 
@@ -84,7 +80,9 @@ def __init__(
             self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
         elif attn_type == "triton_linear":
             if not _triton_modules_available:
-                raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.")
+                raise ValueError(
+                    f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
+                )
             # linear self attention with triton kernel fusion
             self_num_heads = hidden_size // 32
             self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)

diffusion/model/nets/sana_others.py (+7 -6)

@@ -20,13 +20,14 @@
 from timm.models.layers import DropPath
 
 from diffusion.model.nets.basic_modules import DWMlp, MBConvPreGLU, Mlp
-from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
-try:
-    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
-except ImportError:
-    import warnings
-    warnings.warn("TritonLiteMLA with `triton` is not available on your platform.")
 from diffusion.model.nets.sana_blocks import Attention, FlashAttention, MultiHeadCrossAttention, t2i_modulate
+from diffusion.utils.import_utils import is_triton_module_available
+
+_triton_modules_available = False
+if is_triton_module_available():
+    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
+
+    _triton_modules_available = True
 
 
 class SanaMSPABlock(nn.Module):
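This file previously carried both an unconditional TritonLiteMLA import and a duplicate try/except fallback; the hunk replaces both with the shared `is_triton_module_available()` probe. As a hedged aside on why this style is often preferred: wrapping the import in `except ImportError` also swallows ImportErrors raised inside a present-but-broken optional package, whereas a find_spec-style probe followed by a plain import keeps such failures visible. A small contrast with a hypothetical package name `optkern`:

# Contrast of the two styles, using a hypothetical optional package `optkern`.
import importlib.util
import warnings

# Old style: a broken installation of `optkern` is indistinguishable from a
# missing one, because the except clause swallows any ImportError it raises.
try:
    from optkern.modules import FastLinearAttention  # hypothetical import
except ImportError:
    warnings.warn("FastLinearAttention is not available on your platform.")

# New style: probe first, then import inside the guard, so an ImportError from
# inside an installed `optkern` still surfaces as a real traceback.
_optkern_available = importlib.util.find_spec("optkern") is not None
if _optkern_available:
    from optkern.modules import FastLinearAttention  # noqa: F811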

diffusion/utils/import_utils.py (+5 -2)

@@ -1,9 +1,10 @@
 import importlib.util
-import importlib_metadata
-from packaging import version
 import logging
 import warnings
 
+import importlib_metadata
+from packaging import version
+
 logger = logging.getLogger(__name__)
 
 _xformers_available = importlib.util.find_spec("xformers") is not None
@@ -28,8 +29,10 @@
     _triton_modules_available = False
     warnings.warn("TritonLiteMLA and TritonMBConvPreGLU with `triton` is not available on your platform.")
 
+
 def is_xformers_available():
     return _xformers_available
 
+
 def is_triton_module_available():
     return _triton_modules_available
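For reference, the two helpers exposed here are what every nets file in this commit now calls at import time. A short usage example follows, assuming only that the repository root is on PYTHONPATH; the downstream gating shown in the comment mirrors the hunks above rather than adding anything new.

# Quick check of what the probes report in the current environment.
from diffusion.utils.import_utils import is_triton_module_available, is_xformers_available

print("xformers available:", is_xformers_available())
print("triton modules available:", is_triton_module_available())

# Downstream modules gate their optional imports on the same probe, e.g.:
_triton_modules_available = False
if is_triton_module_available():
    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA  # optional kernels

    _triton_modules_available = True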
