 
 _triton_modules_available = False
 if is_triton_module_available():
-    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
+    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
+
     _triton_modules_available = True
 
 _xformers_available = False
 if is_xformers_available():
-    import xformers.ops
     _xformers_available = True
 
+
 class SanaMSBlock(nn.Module):
     """
     A Sana block with global shared adaptive layer norm zero (adaLN-Zero) conditioning.
@@ -84,7 +85,9 @@ def __init__(
             self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
         elif attn_type == "triton_linear":
             if not _triton_modules_available:
-                raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.")
+                raise ValueError(
+                    f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
+                )
             # linear self attention with triton kernel fusion
             self_num_heads = hidden_size // linear_head_dim
             self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
@@ -120,7 +123,9 @@ def __init__(
             )
         elif ffn_type == "triton_mbconvpreglu":
             if not _triton_modules_available:
-                raise ValueError(f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}.")
+                raise ValueError(
+                    f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
+                )
             self.mlp = TritonMBConvPreGLU(
                 in_dim=hidden_size,
                 out_dim=hidden_size,
@@ -316,7 +321,7 @@ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
             y_lens = [y.shape[2]] * y.shape[0]
             y = y.squeeze(1).view(1, -1, x.shape[-1])
         else:
-            raise ValueError(f"{attn_type} type is not available due to _xformers_available={_xformers_available}.")
+            raise ValueError(f"Attention type is not available due to _xformers_available={_xformers_available}.")
 
         for block in self.blocks:
             x = auto_grad_checkpoint(
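
For context, the availability checks touched by this diff follow a common optional-dependency guard pattern: probe for the backend once at import time, record the result in a module-level flag, and raise a descriptive error as soon as a configuration requests a backend that is missing. The sketch below is not the repository's code; it is a minimal illustration using only the standard library, and the helper `_module_available`, the factory `build_attention`, and the backend names are hypothetical stand-ins for `is_triton_module_available()` / `is_xformers_available()` and the `attn_type` options in the diff above.

# Minimal sketch of the optional-backend guard pattern (assumption: standard
# library only; names here are illustrative, not the repository's API).
import importlib.util


def _module_available(name: str) -> bool:
    # find_spec returns None when the package cannot be imported.
    return importlib.util.find_spec(name) is not None


# Probe once at import time, mirroring _triton_modules_available /
# _xformers_available in the diff above.
_triton_available = _module_available("triton")
_xformers_available = _module_available("xformers")


def build_attention(attn_type: str):
    # Fail early with a message that names both the requested type and the
    # flag that blocked it, rather than hitting an ImportError later.
    if attn_type == "triton_linear" and not _triton_available:
        raise ValueError(
            f"{attn_type} is not available because triton is not installed "
            f"(_triton_available={_triton_available})."
        )
    if attn_type == "xformers" and not _xformers_available:
        raise ValueError(
            f"{attn_type} is not available because xformers is not installed "
            f"(_xformers_available={_xformers_available})."
        )
    # ...construct and return the requested attention module here.
    return attn_type

Probing at import time means the error message can report the flag's value directly, which is what the reformatted raise statements above do.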