@@ -42,7 +42,7 @@ def batch_patchify(
         pad: bool = True,
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     B, C, H, W = x.shape
-    ph, pw = to_2tuple(patch_size)
+    ph, pw = patch_size
 
     # Ensure the image is divisible by patch size
     if pad and (H % ph != 0 or W % pw != 0):
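
With `to_2tuple` dropped from the body, `patch_size` must already be a concrete `(ph, pw)` tuple by the time it reaches `batch_patchify`; plain tuple unpacking is also simpler for TorchScript than the helper call. A minimal sketch of the caller-side contract this change assumes:

import torch
from timm.layers import to_2tuple

# Normalization now belongs to the caller: an int patch size gets expanded
# before the call, so batch_patchify only ever sees a 2-tuple.
patch_size = to_2tuple(16)   # (16, 16)
ph, pw = patch_size          # simple unpacking, matching the new code
x = torch.randn(2, 3, 224, 224)
assert x.shape[-2] % ph == 0 and x.shape[-1] % pw == 0   # divisible, no pad branch
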
@@ -202,21 +202,20 @@ def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
         else:
             return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
 
-    def forward(self, x, patch_coord=None, patch_valid=None):
+    def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None):
         """Forward pass for combined embedding
 
         Args:
             x: Input tensor [B, C, H, W] or pre-patchified [B, N, P*P*C]
             patch_coord: Optional patch coordinates [B, N, 2] for NaFlex
-            patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex
 
         Returns:
             Embedded tensor with position encoding and class/register tokens applied
             If patch_type is provided, also returns attention mask
         """
         # Apply patch embedding
         naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
-        grid_size: Optional[Tuple[int, int]] = None
+        grid_size: Optional[List[int]] = None
 
         B = x.shape[0]
         if self.is_linear:
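
The tightened signature (explicit `torch.Tensor` / `Optional[torch.Tensor]` annotations, `patch_valid` dropped, `grid_size` retyped as `Optional[List[int]]`) reads like TorchScript groundwork: `torch.jit.script` needs an explicit `Optional[...]` annotation on any argument that defaults to `None`. A small illustration of that constraint, using a toy module that is not part of this diff:

import torch
from typing import Optional

class Embeds(torch.nn.Module):
    # Without the Optional annotation, TorchScript infers `patch_coord: Tensor`
    # and rejects the `= None` default at script time.
    def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None) -> torch.Tensor:
        if patch_coord is not None:
            x = x + patch_coord.to(x.dtype).mean()
        return x

scripted = torch.jit.script(Embeds())  # compiles with the annotation in place
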
@@ -227,7 +226,7 @@ def forward(self, x, patch_coord=None, patch_valid=None):
                 # Calculate the appropriate grid size from coords
                 max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1
                 max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1
-                naflex_grid_sizes = [(h.item(), w.item()) for h, w in zip(max_y, max_x)]
+                naflex_grid_sizes = [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)]
             else:
                 _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
                 x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
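
The `int(...)` wrap is a small but real fix: `.item()` is typed as returning a numeric union, while `naflex_grid_sizes` is annotated `List[Tuple[int, int]]`, so the cast pins the element type down for the checker. A toy run of the same coordinate math, with made-up values:

import torch

# One batch element whose patches tile a 2-row x 3-col grid, coords as (y, x).
patch_coord = torch.tensor([[[0, 0], [0, 1], [0, 2],
                             [1, 0], [1, 1], [1, 2]]])   # [B=1, N=6, 2]
max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1           # tensor([2])
max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1           # tensor([3])
sizes = [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)]
assert sizes == [(2, 3)]                                 # plain Python ints
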
@@ -257,6 +256,7 @@ def forward(self, x, patch_coord=None, patch_valid=None):
             if naflex_grid_sizes is not None:
                 self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
             else:
+                assert grid_size is not None
                 self._apply_learned_pos_embed(x, grid_size=grid_size)
         elif self.pos_embed_type == 'rope':
             assert False, "ROPE not yet implemented"
@@ -287,15 +287,19 @@ def _apply_learned_naflex_pos_embed(
         orig_h, orig_w = self.pos_embed.shape[1:3]
 
         # Determine unique grid sizes
-        size_to_indices = {}
+        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
         for bi, (h, w) in enumerate(naflex_grid_sizes):
-            if not (h, w) in size_to_indices:
-                size_to_indices[(h, w)] = [bi]
+            #k = h << 16 | w  # FIXME can get jit compat with this
+            k = (h, w)
+            if not k in size_to_indices:
+                size_to_indices[k] = [bi]
             else:
-                size_to_indices[(h, w)].append(bi)
+                size_to_indices[k].append(bi)
 
         # Handle each batch element separately with its own grid size
-        for (h, w), batch_indices in size_to_indices.items():
+        for k, batch_indices in size_to_indices.items():
+            h, w = k
+            #h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
             # Interpolate only once for this (h, w)
             if (h == orig_h) and (w == orig_w):
                 pos_embed_flat = self.pos_embed.reshape(orig_h * orig_w, -1)
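
The FIXME comments sketch the TorchScript escape hatch: scripted dict keys are restricted to plain types such as `int`, `float`, and `str`, so the `Tuple[int, int]` key used here won't script. Packing the pair into a single int round-trips losslessly as long as `h` and `w` stay below 2**16:

h, w = 14, 20
k = h << 16 | w                          # one int key: 14 * 65536 + 20 = 917524
assert (k >> 16, k & 0xFFFF) == (h, w)   # exact unpack, as the second FIXME shows
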
@@ -315,7 +319,7 @@ def _apply_learned_naflex_pos_embed(
     def _apply_learned_pos_embed(
             self,
             x: torch.Tensor,
-            grid_size: Tuple[int, int],
+            grid_size: List[int],
     ):
         orig_h, orig_w = self.pos_embed.shape[1:3]
         if grid_size[0] != orig_h or grid_size[1] != orig_w:
@@ -340,7 +344,7 @@ def _apply_learned_pos_embed(
 
 @register_notrace_function
 def create_attention_mask(
-        patch_valid: Optional[torch.Tensor],
+        patch_valid: torch.Tensor,
         num_prefix_tokens: int = 0,
         dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
@@ -357,7 +361,7 @@ def create_attention_mask(
         Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
         or None if patch_type is None
     """
-    patch_valid = patch_valid.bool()
+    patch_valid = patch_valid.to(torch.bool)
    B = patch_valid.shape[0]
 
     if num_prefix_tokens > 0:
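
`.to(torch.bool)` is functionally the same as `.bool()`; the explicit-dtype spelling is a common choice when chasing scripting/tracing quirks. The rest of the function body isn't in this hunk, but a mask of the documented `[B, seq_len, seq_len]` shape would typically be assembled along these lines (a sketch under that assumption, not the actual implementation):

import torch

patch_valid = torch.tensor([[1, 1, 1, 0]])            # [B, N], last patch is padding
valid = patch_valid.to(torch.bool)
pair_valid = valid.unsqueeze(1) & valid.unsqueeze(2)  # [B, N, N] pairwise validity
attn_mask = torch.zeros(pair_valid.shape, dtype=torch.float32)
attn_mask.masked_fill_(~pair_valid, torch.finfo(torch.float32).min)
# padded rows/columns now carry a large negative bias ahead of softmax
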
@@ -373,7 +377,7 @@ def create_attention_mask(
 
 @register_notrace_function
 def create_attention_mask2(
-        patch_valid: Optional[torch.Tensor],
+        patch_valid: torch.Tensor,
         num_prefix_tokens: int = 0,
         q_len: Optional[int] = None,
         dtype: torch.dtype = torch.float32,
@@ -411,7 +415,7 @@ def create_attention_mask2(
 
 @register_notrace_function
 def create_pool_mask(
-        patch_valid: Optional[torch.Tensor],
+        patch_valid: torch.Tensor,
         dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     patch_valid = patch_valid.bool()
@@ -773,8 +777,16 @@ def forward_features(
             patch_valid: Optional[torch.Tensor] = None,
             attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+
+        if attn_mask is None and patch_valid is not None:
+            attn_mask = create_attention_mask(
+                patch_valid,
+                num_prefix_tokens=self.num_prefix_tokens,
+                dtype=x.dtype
+            )
+
         # Pass through embedding module with patch coordinate/type support
-        x = self.embeds(x, patch_coord=patch_coord, patch_valid=patch_valid)
+        x = self.embeds(x, patch_coord=patch_coord)
 
         # Apply transformer blocks with masked attention if mask provided
         if attn_mask is not None:
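
With the mask now built inside `forward_features`, callers can pass `patch_valid` alone instead of precomputing `attn_mask`. Roughly, assuming a constructed `model` and a NaFlex-style batch (names here are illustrative, not from the diff):

# These two calls should now be equivalent:
feats = model.forward_features(patches, patch_coord=coords, patch_valid=valid)

mask = create_attention_mask(valid, num_prefix_tokens=model.num_prefix_tokens, dtype=patches.dtype)
feats = model.forward_features(patches, patch_coord=coords, patch_valid=valid, attn_mask=mask)
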
@@ -827,7 +839,7 @@ def _pool(
 
         # For max pooling with mask
         masked_x = x.clone()
-        masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
+        masked_x[~patch_valid] = -1e4  # torch.finfo(masked_x.dtype).min
         masked_max = masked_x.max(dim=1)[0]
 
         # Combine average and max
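
Replacing `torch.finfo(dtype).min` with a finite -1e4 is the usual low-precision guard: in float16 the dtype minimum is about -65504, so any further arithmetic on it overflows to -inf and can poison downstream ops with NaNs, while a large finite sentinel still loses every max comparison against real activations. The effect, with illustrative shapes:

import torch

x = torch.randn(1, 4, 8)                                 # [B, N, D]
patch_valid = torch.tensor([[True, True, True, False]])  # last token is padding
masked_x = x.clone()
masked_x[~patch_valid] = -1e4                            # finite fill, fp16-safe
masked_max = masked_x.max(dim=1)[0]                      # padding never wins the max
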
@@ -864,27 +876,23 @@ def forward(
         Returns:
             Model output tensor
         """
-        # Handle dictionary input from NaFlex collator
-        if isinstance(x, dict):
-            assert patch_coord is None
-            assert patch_valid is None
-            # Extract the required components from the dictionary
+        if isinstance(x, torch.Tensor):
+            patches = x
+        else:
+            # Handle dictionary input from NaFlex collator
             patch_coord = x['patch_coord']
             patch_valid = x['patch_valid']
             patches = x['patches']
 
-            if False:
-                # DEBUG, reconstruct patches
-                for i in range(len(patches)):
-                    patch = patches[i][patch_valid[i]]
-                    h = (patch_coord[i, :, 0].max() + 1).item()
-                    w = (patch_coord[i, :, 1].max() + 1).item()
-                    patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3)
-                    patch = patch.reshape(3, h * 16, w * 16)
-                    from torchvision.utils import save_image
-                    save_image(patch, f'patch_{i}.jpg', normalize=True)
-        else:
-            patches = x
+            # DEBUG, reconstruct patches
+            # for i in range(len(patches)):
+            #     patch = patches[i][patch_valid[i]]
+            #     h = (patch_coord[i, :, 0].max() + 1).item()
+            #     w = (patch_coord[i, :, 1].max() + 1).item()
+            #     patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3)
+            #     patch = patch.reshape(3, h*16, w*16)
+            #     from torchvision.utils import save_image
+            #     save_image(patch, f'patch_{i}.jpg', normalize=True)
 
         # Create attention mask if patch_type is provided
         if patch_valid is not None:
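
The rewritten branch makes the plain tensor the simple path and treats anything else as the NaFlex collator dict; checking `isinstance(x, torch.Tensor)` first is also the form TorchScript can refine on a union-typed input. Both call shapes, sketched with an assumed `model` and illustrative sizes (the dict keys are the ones read above):

import torch

# Standard image path
out = model(torch.randn(2, 3, 224, 224))

# NaFlex collator path
batch = {
    'patches': torch.randn(2, 196, 16 * 16 * 3),              # [B, N, P*P*C]
    'patch_coord': torch.zeros(2, 196, 2, dtype=torch.long),  # (y, x) per patch
    'patch_valid': torch.ones(2, 196, dtype=torch.bool),      # 1=patch, 0=padding
}
out = model(batch)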