
Commit ea728f6

Improve several typing issues for flex vit, can (almost) work with jit if we bash h,w key into an int or str
1 parent 97341fe commit ea728f6

2 files changed: +43, -35 lines


Diff for: timm/layers/patch_embed.py

+1, -1 lines changed

@@ -321,7 +321,7 @@ def resample_patch_embed(
         verbose: bool = False,
 ):
     """ Standalone function (computes matrix on each call). """
-    assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_c, in_c, h, w)"
+    assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_ch, in_ch, h, w)"
     assert len(new_size) == 2, "New shape should only be hw (height, width)"
 
     old_size_tuple: Tuple[int, int] = tuple(patch_embed.shape[-2:])

Diff for: timm/models/vision_transformer_flex.py

+42, -34 lines changed

@@ -42,7 +42,7 @@ def batch_patchify(
         pad: bool = True,
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     B, C, H, W = x.shape
-    ph, pw = to_2tuple(patch_size)
+    ph, pw = patch_size
 
     # Ensure the image is divisible by patch size
     if pad and (H % ph != 0 or W % pw != 0):
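Why to_2tuple went away here: torch.jit.script cannot infer a fixed Tuple[int, int] return type through that helper, so the caller is now expected to pass an already-normalized tuple that can be unpacked directly. A minimal sketch of the patchify step under that assumption (patchify_sketch is illustrative, not code from this commit; the padding that batch_patchify performs when pad=True is omitted):

import torch
from typing import List, Tuple

def patchify_sketch(x: torch.Tensor, patch_size: Tuple[int, int]) -> Tuple[torch.Tensor, List[int]]:
    # Direct unpacking keeps ph/pw typed as plain ints under torch.jit.script.
    B, C, H, W = x.shape
    ph, pw = patch_size
    nh, nw = H // ph, W // pw
    # [B, C, H, W] -> [B, N, ph*pw*C], the pre-patchified layout named in the forward docstring below.
    patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
    return patches, [nh, nw]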
@@ -202,21 +202,20 @@ def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
         else:
             return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
 
-    def forward(self, x, patch_coord=None, patch_valid=None):
+    def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None):
         """Forward pass for combined embedding
 
         Args:
             x: Input tensor [B, C, H, W] or pre-patchified [B, N, P*P*C]
             patch_coord: Optional patch coordinates [B, N, 2] for NaFlex
-            patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex
 
         Returns:
             Embedded tensor with position encoding and class/register tokens applied
             If patch_type is provided, also returns attention mask
         """
         # Apply patch embedding
         naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
-        grid_size: Optional[Tuple[int, int]] = None
+        grid_size: Optional[List[int]] = None
 
         B = x.shape[0]
         if self.is_linear:
@@ -227,7 +226,7 @@ def forward(self, x, patch_coord=None, patch_valid=None):
                 # Calculate the appropriate grid size from coords
                 max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1
                 max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1
-                naflex_grid_sizes = [(h.item(), w.item()) for h, w in zip(max_y, max_x)]
+                naflex_grid_sizes = [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)]
             else:
                 _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
                 x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
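A worked example of the grid-size recovery above, with made-up coordinates; the added int(...) casts matter because TorchScript types Tensor.item() as a generic number, which fails the List[Tuple[int, int]] annotation:

import torch

patch_coord = torch.tensor([
    [[0, 0], [0, 1], [1, 0], [1, 1]],   # patches laid out on a 2x2 grid
    [[0, 0], [0, 1], [0, 2], [0, 0]],   # 1x3 grid, last entry is padding at (0, 0)
])
max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1   # tensor([2, 1])
max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1   # tensor([2, 3])
naflex_grid_sizes = [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)]
print(naflex_grid_sizes)   # [(2, 2), (1, 3)]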
@@ -257,6 +256,7 @@ def forward(self, x, patch_coord=None, patch_valid=None):
             if naflex_grid_sizes is not None:
                 self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
             else:
+                assert grid_size is not None
                 self._apply_learned_pos_embed(x, grid_size=grid_size)
         elif self.pos_embed_type == 'rope':
             assert False, "ROPE not yet implemented"
@@ -287,15 +287,19 @@ def _apply_learned_naflex_pos_embed(
         orig_h, orig_w = self.pos_embed.shape[1:3]
 
         # Determine unique grid sizes
-        size_to_indices = {}
+        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
         for bi, (h, w) in enumerate(naflex_grid_sizes):
-            if not (h, w) in size_to_indices:
-                size_to_indices[(h, w)] = [bi]
+            # k = h << 16 | w  # FIXME can get jit compat with this
+            k = (h, w)
+            if not k in size_to_indices:
+                size_to_indices[k] = [bi]
             else:
-                size_to_indices[(h, w)].append(bi)
+                size_to_indices[k].append(bi)
 
         # Handle each batch element separately with its own grid size
-        for (h, w), batch_indices in size_to_indices.items():
+        for k, batch_indices in size_to_indices.items():
+            h, w = k
+            # h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
             # Interpolate only once for this (h, w)
             if (h == orig_h) and (w == orig_w):
                 pos_embed_flat = self.pos_embed.reshape(orig_h * orig_w, -1)
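The two FIXME comments sketch the workaround the commit message refers to ("bash h,w key into an int or str"): TorchScript is strict about dict key types, so the (h, w) tuple would be packed into a single int. A standalone demonstration, assuming both grid dims fit in 16 bits:

h, w = 12, 7
k = h << 16 | w                           # pack: high bits carry h, low 16 bits carry w
assert (k >> 16, k & 0xFFFF) == (h, w)    # unpack recovers the original pair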
@@ -315,7 +319,7 @@ def _apply_learned_naflex_pos_embed(
     def _apply_learned_pos_embed(
             self,
             x: torch.Tensor,
-            grid_size: Tuple[int, int],
+            grid_size: List[int],
     ):
         orig_h, orig_w = self.pos_embed.shape[1:3]
         if grid_size[0] != orig_h or grid_size[1] != orig_w:
@@ -340,7 +344,7 @@ def _apply_learned_pos_embed(
 
 @register_notrace_function
 def create_attention_mask(
-    patch_valid: Optional[torch.Tensor],
+    patch_valid: torch.Tensor,
     num_prefix_tokens: int = 0,
     dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
@@ -357,7 +361,7 @@ def create_attention_mask(
         Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
         or None if patch_type is None
     """
-    patch_valid = patch_valid.bool()
+    patch_valid = patch_valid.to(torch.bool)
    B = patch_valid.shape[0]
 
    if num_prefix_tokens > 0:
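The diff shows only the signature tightening and the dtype-safe bool cast for create_attention_mask; below is a plausible sketch of the full function, with the pairwise-validity logic an assumption on my part rather than code from this commit:

import torch

def create_attention_mask_sketch(
        patch_valid: torch.Tensor,
        num_prefix_tokens: int = 0,
        dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    # patch_valid: [B, N]; prefix (class/register) tokens are always valid.
    patch_valid = patch_valid.to(torch.bool)
    B = patch_valid.shape[0]
    if num_prefix_tokens > 0:
        prefix_valid = patch_valid.new_ones(B, num_prefix_tokens)
        patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
    # A (q, k) pair may attend only if both tokens are valid.
    mask_bool = patch_valid.unsqueeze(1) & patch_valid.unsqueeze(2)   # [B, S, S]
    mask = torch.zeros_like(mask_bool, dtype=dtype)
    return mask.masked_fill(~mask_bool, torch.finfo(dtype).min)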
@@ -373,7 +377,7 @@
 
 @register_notrace_function
 def create_attention_mask2(
-    patch_valid: Optional[torch.Tensor],
+    patch_valid: torch.Tensor,
     num_prefix_tokens: int = 0,
     q_len: Optional[int] = None,
     dtype: torch.dtype = torch.float32,
@@ -411,7 +415,7 @@
 
 @register_notrace_function
 def create_pool_mask(
-    patch_valid: Optional[torch.Tensor],
+    patch_valid: torch.Tensor,
     dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     patch_valid = patch_valid.bool()
@@ -773,8 +777,16 @@ def forward_features(
             patch_valid: Optional[torch.Tensor] = None,
             attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+
+        if attn_mask is None and patch_valid is not None:
+            attn_mask = create_attention_mask(
+                patch_valid,
+                num_prefix_tokens=self.num_prefix_tokens,
+                dtype=x.dtype
+            )
+
         # Pass through embedding module with patch coordinate/type support
-        x = self.embeds(x, patch_coord=patch_coord, patch_valid=patch_valid)
+        x = self.embeds(x, patch_coord=patch_coord)
 
         # Apply transformer blocks with masked attention if mask provided
         if attn_mask is not None:
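Net effect of this hunk: the embedding module no longer sees patch_valid, and forward_features derives the additive mask itself when one is not supplied. A hedged usage sketch (model, coords, and valid are placeholders, not names from this file):

# feats = model.forward_features(patches, patch_coord=coords, patch_valid=valid)
# is now equivalent to precomputing the mask yourself:
# mask = create_attention_mask(valid, num_prefix_tokens=model.num_prefix_tokens, dtype=patches.dtype)
# feats = model.forward_features(patches, patch_coord=coords, attn_mask=mask)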
@@ -827,7 +839,7 @@ def _pool(
 
         # For max pooling with mask
         masked_x = x.clone()
-        masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
+        masked_x[~patch_valid] = -1e4  # torch.finfo(masked_x.dtype).min
         masked_max = masked_x.max(dim=1)[0]
 
         # Combine average and max
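On the -1e4 sentinel: torch.finfo(masked_x.dtype).min depends on a runtime dtype, which is awkward under JIT, while a fixed -1e4 is comfortably representable in float16 and still loses every max comparison against real activations. A self-contained sketch of masked avg+max pooling (shapes assumed; the 0.5 combine is my guess at the "Combine average and max" step, which this diff does not show):

import torch

def masked_avgmax_pool_sketch(x: torch.Tensor, patch_valid: torch.Tensor) -> torch.Tensor:
    # x: [B, N, D] token features, patch_valid: [B, N] bool
    valid = patch_valid.unsqueeze(-1).to(x.dtype)
    masked_avg = (x * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1)
    masked_x = x.clone()
    masked_x[~patch_valid] = -1e4          # fixed sentinel, safe in float16
    masked_max = masked_x.max(dim=1)[0]
    return 0.5 * (masked_avg + masked_max)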
@@ -864,27 +876,23 @@ def forward(
         Returns:
             Model output tensor
         """
-        # Handle dictionary input from NaFlex collator
-        if isinstance(x, dict):
-            assert patch_coord is None
-            assert patch_valid is None
-            # Extract the required components from the dictionary
+        if isinstance(x, torch.Tensor):
+            patches = x
+        else:
+            # Handle dictionary input from NaFlex collator
             patch_coord = x['patch_coord']
             patch_valid = x['patch_valid']
             patches = x['patches']
 
-        if False:
-            # DEBUG, reconstruct patches
-            for i in range(len(patches)):
-                patch = patches[i][patch_valid[i]]
-                h = (patch_coord[i, :, 0].max() + 1).item()
-                w = (patch_coord[i, :, 1].max() + 1).item()
-                patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3)
-                patch = patch.reshape(3, h*16, w*16)
-                from torchvision.utils import save_image
-                save_image(patch, f'patch_{i}.jpg', normalize=True)
-        else:
-            patches = x
+            # DEBUG, reconstruct patches
+            # for i in range(len(patches)):
+            #     patch = patches[i][patch_valid[i]]
+            #     h = (patch_coord[i, :, 0].max() + 1).item()
+            #     w = (patch_coord[i, :, 1].max() + 1).item()
+            #     patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3)
+            #     patch = patch.reshape(3, h*16, w*16)
+            #     from torchvision.utils import save_image
+            #     save_image(patch, f'patch_{i}.jpg', normalize=True)
 
         # Create attention mask if patch_type is provided
         if patch_valid is not None:
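Both input paths after this refactor, as a usage sketch (shapes taken from the docstrings above, batch contents shaped like the NaFlex collator output; model is a placeholder for an instance of this flex ViT):

import torch

batch = {
    'patches': torch.randn(2, 64, 16 * 16 * 3),               # [B, N, P*P*C]
    'patch_coord': torch.zeros(2, 64, 2, dtype=torch.long),   # [B, N, 2] as (y, x)
    'patch_valid': torch.ones(2, 64, dtype=torch.bool),       # [B, N]
}
# out = model(batch)                          # dict path: coords/valid read from the dict
# out = model(torch.randn(2, 3, 224, 224))    # tensor path: dense image input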
