Skip to content

Commit b6156aa

Browse files
Rewrite AuraFlowPatchEmbed.pe_selection_index_based_on_dim to be torch.compile compatible (#11297)
* Update pe_selection_index_based_on_dim * Make pe_selection_index_based_on_dim work with torh.compile * Fix AuraFlowTransformer2DModel's dpcstring default values --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 7ecfe29 commit b6156aa

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

src/diffusers/models/transformers/auraflow_transformer_2d.py

+20-12
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,23 @@ def pe_selection_index_based_on_dim(self, h, w):
7474
# PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
7575
# because original input are in flattened format, we have to flatten this 2d grid as well.
7676
h_p, w_p = h // self.patch_size, w // self.patch_size
77-
original_pe_indexes = torch.arange(self.pos_embed.shape[1])
7877
h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
79-
original_pe_indexes = original_pe_indexes.view(h_max, w_max)
78+
79+
# Calculate the top-left corner indices for the centered patch grid
8080
starth = h_max // 2 - h_p // 2
81-
endh = starth + h_p
8281
startw = w_max // 2 - w_p // 2
83-
endw = startw + w_p
84-
original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
85-
return original_pe_indexes.flatten()
82+
83+
# Generate the row and column indices for the desired patch grid
84+
rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
85+
cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)
86+
87+
# Create a 2D grid of indices
88+
row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
89+
90+
# Convert the 2D grid indices to flattened 1D indices
91+
selected_indices = (row_indices * w_max + col_indices).flatten()
92+
93+
return selected_indices
8694

8795
def forward(self, latent):
8896
batch_size, num_channels, height, width = latent.size()
@@ -275,17 +283,17 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
275283
sample_size (`int`): The width of the latent images. This is fixed during training since
276284
it is used to learn a number of position embeddings.
277285
patch_size (`int`): Patch size to turn the input data into small patches.
278-
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
286+
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
279287
num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
280-
num_single_dit_layers (`int`, *optional*, defaults to 4):
288+
num_single_dit_layers (`int`, *optional*, defaults to 32):
281289
The number of layers of Transformer blocks to use. These blocks use concatenated image and text
282290
representations.
283-
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
284-
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
291+
attention_head_dim (`int`, *optional*, defaults to 256): The number of channels in each head.
292+
num_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for multi-head attention.
285293
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
286294
caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
287-
out_channels (`int`, defaults to 16): Number of output channels.
288-
pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
295+
out_channels (`int`, defaults to 4): Number of output channels.
296+
pos_embed_max_size (`int`, defaults to 1024): Maximum positions to embed from the image latents.
289297
"""
290298

291299
_no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]

0 commit comments

Comments
 (0)