@@ -74,15 +74,23 @@ def pe_selection_index_based_on_dim(self, h, w):
         # PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
         # because original input are in flattened format, we have to flatten this 2d grid as well.
         h_p, w_p = h // self.patch_size, w // self.patch_size
-        original_pe_indexes = torch.arange(self.pos_embed.shape[1])
         h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
-        original_pe_indexes = original_pe_indexes.view(h_max, w_max)
+
+        # Calculate the top-left corner indices for the centered patch grid
         starth = h_max // 2 - h_p // 2
-        endh = starth + h_p
         startw = w_max // 2 - w_p // 2
-        endw = startw + w_p
-        original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
-        return original_pe_indexes.flatten()
+
+        # Generate the row and column indices for the desired patch grid
+        rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
+        cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)
+
+        # Create a 2D grid of indices
+        row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
+
+        # Convert the 2D grid indices to flattened 1D indices
+        selected_indices = (row_indices * w_max + col_indices).flatten()
+
+        return selected_indices

     def forward(self, latent):
         batch_size, num_channels, height, width = latent.size()
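A minimal standalone sketch (not part of the diff) of what this hunk changes: the new arithmetic index computation selects the same centered indices as the removed view-and-slice code, without materializing the full index grid. The names `pos_embed_max_size`, `h_p`, and `w_p` below are stand-ins for the module attributes used above.

import torch

pos_embed_max_size = 1024  # a 32 x 32 grid of position embeddings
h_max = w_max = int(pos_embed_max_size**0.5)
h_p, w_p = 8, 8  # patch-grid size of the actual input

starth = h_max // 2 - h_p // 2
startw = w_max // 2 - w_p // 2

# Old approach: materialize every index, reshape to 2D, slice out the center.
old = torch.arange(pos_embed_max_size).view(h_max, w_max)[
    starth : starth + h_p, startw : startw + w_p
].flatten()

# New approach: compute the same flattened indices directly from row/col grids.
rows = torch.arange(starth, starth + h_p)
cols = torch.arange(startw, startw + w_p)
row_idx, col_idx = torch.meshgrid(rows, cols, indexing="ij")
new = (row_idx * w_max + col_idx).flatten()

assert torch.equal(old, new)

Besides skipping the `pos_embed_max_size`-long intermediate tensor, the new version builds the indices directly on `self.pos_embed.device`, so no host-to-device copy is needed when the embedding lives on an accelerator.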
@@ -275,17 +283,17 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
         sample_size (`int`): The width of the latent images. This is fixed during training since
             it is used to learn a number of position embeddings.
         patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
         num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
-        num_single_dit_layers (`int`, *optional*, defaults to 4):
+        num_single_dit_layers (`int`, *optional*, defaults to 32):
             The number of layers of Transformer blocks to use. These blocks use concatenated image and text
             representations.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 256): The number of channels in each head.
+        num_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for multi-head attention.
         joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
         caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
-        out_channels (`int`, defaults to 16): Number of output channels.
-        pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
+        out_channels (`int`, defaults to 4): Number of output channels.
+        pos_embed_max_size (`int`, defaults to 1024): Maximum positions to embed from the image latents.
         """

     _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
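A hedged usage sketch for the corrected docstring defaults: this assumes `AuraFlowTransformer2DModel` is importable from the top-level `diffusers` namespace and that its constructor accepts these keyword arguments; the values simply spell out the defaults documented above, not a verified checkpoint config.

from diffusers import AuraFlowTransformer2DModel

# Instantiate with the documented defaults made explicit. Note this builds a
# large model (inner dim 256 * 12 = 3072 across 4 MMDiT + 32 single DiT layers).
model = AuraFlowTransformer2DModel(
    in_channels=4,
    num_mmdit_layers=4,
    num_single_dit_layers=32,
    attention_head_dim=256,
    num_attention_heads=12,
    out_channels=4,
    pos_embed_max_size=1024,
)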