
Commit c3fb81e

update with viptr
1 parent: 9f6bc7c

File tree

4 files changed: 25 additions, 8 deletions

.pre-commit-config.yaml
doctr/models/classification/vip/layers/pytorch.py
doctr/models/classification/vip/pytorch.py
doctr/models/recognition/viptr/pytorch.py


.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ repos:
      - id: no-commit-to-branch
        args: ['--branch', 'main']
  - repo: https://github.com/astral-sh/ruff-pre-commit
-   rev: v0.11.1
+   rev: v0.11.5
    hooks:
      - id: ruff
        args: [ --fix ]

doctr/models/classification/vip/layers/pytorch.py

Lines changed: 1 addition & 1 deletion

@@ -246,7 +246,7 @@ def __init__(
                *conv_sequence_pt(dim, dim, kernel_size=1, groups=dim, bias=False, bn=True, relu=False),
            )
        else:
-           self.sr = nn.Identity() # type: ignore[assignment]
+           self.sr = nn.Identity()  # type: ignore[assignment]

        self.local_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)

doctr/models/classification/vip/pytorch.py

Lines changed: 10 additions & 1 deletion

@@ -230,6 +230,15 @@ def _init_weights(self, m):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+

def vip_tiny(pretrained: bool = False, **kwargs: Any) -> VIPNet:
    """

@@ -322,7 +331,7 @@ def _vip(
    # The number of classes is not the same as the number of classes in the pretrained model =>
    # remove the last layer weights
    _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
-   load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+   model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
    return model
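
Beyond the internal refactor in _vip, the new from_pretrained method gives callers a one-liner for loading a checkpoint onto an instantiated classification model without importing doctr.models.utils.load_pretrained_params themselves. A minimal sketch, assuming vip_tiny is re-exported from doctr.models.classification like the other classification backbones, with a placeholder checkpoint path:

from doctr.models.classification import vip_tiny

# Build the architecture without fetching the default pretrained weights
model = vip_tiny(pretrained=False)

# New in this commit: load a local or remote checkpoint directly on the model;
# extra kwargs are forwarded to doctr.models.utils.load_pretrained_params
model.from_pretrained("vip_tiny_checkpoint.pt")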

doctr/models/recognition/viptr/pytorch.py

Lines changed: 13 additions & 5 deletions

@@ -91,9 +91,9 @@ def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:


class VIPTR(RecognitionModel, nn.Module):
-   """Implements a VIPTR architecture as described in `"A Vision Permutable Extractor for Fast and Efficient
+   """Implements a VIPTR architecture as described in `"A Vision Permutable Extractor for Fast and Efficient
    Scene Text Recognition" <https://arxiv.org/abs/2401.10110>`_.
-
+
    Args:
        feature_extractor: the backbone serving as feature extractor
        vocab: vocabulary used for encoding

@@ -110,7 +110,6 @@ def __init__(
        exportable: bool = False,
        cfg: dict[str, Any] | None = None,
    ):
-
        super().__init__()
        self.vocab = vocab
        self.exportable = exportable

@@ -134,6 +133,15 @@ def __init__(
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
    def forward(
        self,
        x: torch.Tensor,

@@ -230,7 +238,7 @@ def _viptr(

    # Feature extractor
    feat_extractor = IntermediateLayerGetter(
-       backbone_fn(pretrained_backbone, input_shape=_cfg["input_shape"]), # type: ignore[call-arg]
+       backbone_fn(pretrained_backbone, input_shape=_cfg["input_shape"]),  # type: ignore[call-arg]
        {layer: "features"},
    )

@@ -244,7 +252,7 @@ def _viptr(
    # The number of classes is not the same as the number of classes in the pretrained model =>
    # remove the last layer weights
    _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-   load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+   model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

    return model
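
The recognition model gets the same method, and _viptr now routes the default checkpoint through it. For user code the pattern is sketched below; viptr_tiny is a hypothetical zoo constructor name (the public entry point is not shown in this diff), and the checkpoint path and ignore_keys entries are placeholders:

from doctr.models import recognition

# Hypothetical constructor name -- substitute the actual VIPTR zoo function
model = recognition.viptr_tiny(pretrained=False)

# Load fine-tuned weights through the new method. When the target vocab differs
# from the pretrained one, pass ignore_keys to drop the head weights, mirroring
# what _viptr does above (the key names here are purely illustrative).
model.from_pretrained(
    "viptr_finetuned.pt",
    ignore_keys=["head.weight", "head.bias"],
)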
