Open
Description
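In `convert_seq_to_patch_view`, the return value of `select_segments` is discarded. Unless `select_segments` modifies `mask` in place, the scores therefore have no effect on the returned patch mask, which suggests the scoring module may not be doing anything.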
Samay/src/samay/models/lptm/model/masktrain.py
Lines 19 to 38 in 6a549ae
```python
@staticmethod
def convert_seq_to_patch_view(
    mask: torch.Tensor,
    scores: torch.Tensor,
    patch_len: int = 8,
    stride: Optional[int] = None,
):
    """
    Input:
        mask : torch.Tensor of shape [batch_size x seq_len]
    Output
        mask : torch.Tensor of shape [batch_size x n_patches]
    """
    stride = patch_len if stride is None else stride
    # sm.forward(mask)
    if hasattr(scores, "shape"):
        select_segments(scores, patch_len, mask=mask)
    mask = mask.unfold(dimension=-1, size=patch_len, step=stride)
    # mask : [batch_size x n_patches x patch_len]
    return (mask.sum(dim=-1) == patch_len).long()
```
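To illustrate the concern, here is a minimal pure-Python sketch (no torch) of the same patch-view logic. The issue does not show `select_segments`, so `select_segments_stub` is a hypothetical stand-in that assumes one score per patch and returns a *new* mask with the highest-scoring patch zeroed, leaving its input untouched. Under that assumption, discarding the return value (as the current code does) yields exactly the no-score result, while assigning it back changes the output. If the real `select_segments` instead mutates `mask` in place, the discarded return would be harmless.

```python
from typing import List, Optional

Mask = List[List[int]]  # [batch_size x seq_len], 1 = observed, 0 = masked


def patch_view(mask: Mask, patch_len: int = 8, stride: Optional[int] = None) -> Mask:
    """Mirror of mask.unfold(-1, patch_len, stride) followed by (sum == patch_len):
    a patch is 1 only if every position inside it is 1."""
    stride = patch_len if stride is None else stride
    out = []
    for row in mask:
        n_patches = (len(row) - patch_len) // stride + 1
        out.append([
            int(sum(row[i * stride : i * stride + patch_len]) == patch_len)
            for i in range(n_patches)
        ])
    return out


def select_segments_stub(scores: List[List[float]], patch_len: int, mask: Mask) -> Mask:
    """HYPOTHETICAL stand-in for select_segments: assumes one score per patch and
    returns a NEW mask with the highest-scoring patch zeroed; `mask` is not modified."""
    new_mask = [list(row) for row in mask]
    for b, row_scores in enumerate(scores):
        best = max(range(len(row_scores)), key=lambda i: row_scores[i])
        for t in range(best * patch_len, (best + 1) * patch_len):
            new_mask[b][t] = 0
    return new_mask


def convert_as_written(mask: Mask, scores, patch_len: int = 8, stride=None) -> Mask:
    select_segments_stub(scores, patch_len, mask=mask)  # return value discarded
    return patch_view(mask, patch_len, stride)


def convert_using_result(mask: Mask, scores, patch_len: int = 8, stride=None) -> Mask:
    mask = select_segments_stub(scores, patch_len, mask=mask)  # keep the selected mask
    return patch_view(mask, patch_len, stride)


mask = [[1] * 8]
scores = [[0.1, 0.9]]
print(convert_as_written(mask, scores, patch_len=4))    # [[1, 1]]: scores ignored
print(convert_using_result(mask, scores, patch_len=4))  # [[1, 0]]: patch 1 masked
```

With the stub, only the second variant lets the scores influence which patches end up masked, which is the behavior the scoring module presumably intends.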