
Commit 9b52654

annotate a few torch.nn.modules.* modules (pytorch#45772)
Summary: Fixes pytorch#45771

Pull Request resolved: pytorch#45772
Reviewed By: mruberry
Differential Revision: D24682013
Pulled By: albanD
fbshipit-source-id: e32bc4fe9c586c079f7070924a874c70f3d127fa
1 parent 7178790 commit 9b52654

File tree

mypy.ini
torch/nn/functional.pyi.in
torch/nn/modules/batchnorm.py
torch/nn/modules/instancenorm.py
torch/nn/modules/linear.py
torch/nn/modules/sparse.py
torch/nn/parameter.pyi

7 files changed: +25, -26 lines


mypy.ini

Lines changed: 0 additions & 12 deletions
@@ -77,9 +77,6 @@ ignore_errors = True
 [mypy-torch._tensor_str]
 ignore_errors = True
 
-[mypy-torch.nn.modules.batchnorm]
-ignore_errors = True
-
 [mypy-torch.nn.modules.container]
 ignore_errors = True
 
@@ -89,12 +86,6 @@ ignore_errors = True
 [mypy-torch.nn.modules.fold]
 ignore_errors = True
 
-[mypy-torch.nn.modules.instancenorm]
-ignore_errors = True
-
-[mypy-torch.nn.modules.linear]
-ignore_errors = True
-
 [mypy-torch.nn.modules.loss]
 ignore_errors = True
 
@@ -113,9 +104,6 @@ ignore_errors = True
 [mypy-torch.nn.modules.rnn]
 ignore_errors = True
 
-[mypy-torch.nn.modules.sparse]
-ignore_errors = True
-
 [mypy-torch.nn.parallel._functions]
 ignore_errors = True
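
With these overrides removed, mypy actually type-checks torch.nn.modules.batchnorm, instancenorm, linear and sparse instead of skipping them. As an illustrative sketch (not code from this commit), this is the kind of Optional misuse that an ignore_errors override would have hidden:

from typing import Optional

max_norm: Optional[float] = None
doubled = max_norm * 2        # mypy error: operand may be None

if max_norm is not None:      # narrowing the Optional satisfies the checker
    doubled = max_norm * 2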

torch/nn/functional.pyi.in

Lines changed: 2 additions & 1 deletion
@@ -189,7 +189,8 @@ def embedding(input: Tensor, weight: Tensor, padding_idx: Optional[int] = ..., m
 
 def embedding_bag(input: Tensor, weight: Tensor, offsets: Optional[Tensor] = ..., max_norm: Optional[float] = ...,
                   norm_type: float = ..., scale_grad_by_freq: bool = ..., mode: str = ...,
-                  sparse: bool = ...) -> Tensor: ...
+                  sparse: bool = ..., per_sample_weights: Optional[Tensor] = ...,
+                  include_last_offset: bool = ...) -> Tensor: ...
 
 def batch_norm(input: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor],
                weight: Optional[Tensor] = ..., bias: Optional[Tensor] = ..., training: bool = ...,
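
per_sample_weights and include_last_offset already exist on the runtime torch.nn.functional.embedding_bag; the stub update lets calls that use them type-check. A rough usage sketch (tensor values are arbitrary):

import torch
import torch.nn.functional as F

weight = torch.randn(10, 3)
input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])

# per_sample_weights is only supported with mode='sum'
out = F.embedding_bag(input, weight, offsets, mode='sum',
                      per_sample_weights=torch.rand(8),
                      include_last_offset=False)
print(out.shape)  # torch.Size([2, 3])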

torch/nn/modules/batchnorm.py

Lines changed: 12 additions & 5 deletions
@@ -54,9 +54,11 @@ def __init__(
 
     def reset_running_stats(self) -> None:
         if self.track_running_stats:
-            self.running_mean.zero_()
-            self.running_var.fill_(1)
-            self.num_batches_tracked.zero_()
+            # running_mean/running_var/num_batches... are registerd at runtime depending
+            # if self.track_running_stats is on
+            self.running_mean.zero_()  # type: ignore[operator]
+            self.running_var.fill_(1)  # type: ignore[operator]
+            self.num_batches_tracked.zero_()  # type: ignore[operator]
 
     def reset_parameters(self) -> None:
         self.reset_running_stats()
@@ -107,8 +109,8 @@ def forward(self, input: Tensor) -> Tensor:
 
         if self.training and self.track_running_stats:
             # TODO: if statement only here to tell the jit to skip emitting this when it is None
-            if self.num_batches_tracked is not None:
-                self.num_batches_tracked = self.num_batches_tracked + 1
+            if self.num_batches_tracked is not None:  # type: ignore
+                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore
                 if self.momentum is None:  # use cumulative moving average
                     exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                 else:  # use exponential moving average
@@ -128,6 +130,8 @@ def forward(self, input: Tensor) -> Tensor:
         passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
         used for normalization (i.e. in eval mode when buffers are not None).
         """
+        assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
+        assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
         return F.batch_norm(
             input,
             # If buffers are not to be tracked, ensure that they won't be updated
@@ -487,6 +491,7 @@ def forward(self, input: Tensor) -> Tensor:
             exponential_average_factor = self.momentum
 
         if self.training and self.track_running_stats:
+            assert self.num_batches_tracked is not None
            self.num_batches_tracked = self.num_batches_tracked + 1
             if self.momentum is None:  # use cumulative moving average
                 exponential_average_factor = 1.0 / self.num_batches_tracked.item()
@@ -508,6 +513,8 @@ def forward(self, input: Tensor) -> Tensor:
         used for normalization (i.e. in eval mode when buffers are not None).
         """
         # If buffers are not to be tracked, ensure that they won't be updated
+        assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
+        assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
         running_mean = self.running_mean if not self.training or self.track_running_stats else None
         running_var = self.running_var if not self.training or self.track_running_stats else None
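
The type: ignore[operator] comments and the new asserts are needed because the running-stat buffers are Optional: with track_running_stats=False they are registered as None. A small sketch of the runtime behaviour the annotations have to cover:

import torch
from torch import nn

bn = nn.BatchNorm2d(3, track_running_stats=False)
print(bn.running_mean)               # None, so .zero_() would fail here

bn = nn.BatchNorm2d(3)               # track_running_stats=True by default
assert bn.running_mean is not None   # narrows Optional[Tensor] to Tensor for mypy
bn.running_mean.zero_()
print(bn.num_batches_tracked)        # tensor(0)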

torch/nn/modules/instancenorm.py

Lines changed: 2 additions & 0 deletions
@@ -52,6 +52,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
     def forward(self, input: Tensor) -> Tensor:
         self._check_input_dim(input)
 
+        assert self.running_mean is None or isinstance(self.running_mean, Tensor)
+        assert self.running_var is None or isinstance(self.running_var, Tensor)
         return F.instance_norm(
             input, self.running_mean, self.running_var, self.weight, self.bias,
             self.training or not self.track_running_stats, self.momentum, self.eps)
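
Same pattern as batchnorm: the instance-norm modules default to track_running_stats=False, so running_mean and running_var are None at runtime, and the asserts tell mypy that both None and Tensor are expected before the F.instance_norm call. A quick sketch:

import torch
from torch import nn

m = nn.InstanceNorm2d(4)     # track_running_stats=False by default
print(m.running_mean)        # None, matching the Optional annotation
out = m(torch.randn(2, 4, 8, 8))
print(out.shape)             # torch.Size([2, 4, 8, 8])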

torch/nn/modules/linear.py

Lines changed: 5 additions & 4 deletions
@@ -102,10 +102,10 @@ def extra_repr(self) -> str:
 # This class exists solely for Transformer; it has an annotation stating
 # that bias is never None, which appeases TorchScript
 class _LinearWithBias(Linear):
-    bias: Tensor
+    bias: Tensor  # type: ignore
 
     def __init__(self, in_features: int, out_features: int) -> None:
-        super().__init__(in_features, out_features, bias=True)
+        super().__init__(in_features, out_features, bias=True)  # type: ignore
 
 
 class Bilinear(Module):
@@ -208,7 +208,8 @@ class LazyLinear(LazyModuleMixin, Linear):
 
     """
 
-    cls_to_become = Linear
+    cls_to_become = Linear  # type: ignore[assignment]
+    weight: UninitializedParameter
 
     def __init__(self, out_features: int, bias: bool = True) -> None:
         super().__init__(0, out_features, bias)
@@ -218,7 +219,7 @@ def reset_parameters(self) -> None:
         if not self.has_uninitialized_params() and self.in_features != 0:
             super().reset_parameters()
 
-    def initialize_parameters(self, input) -> None:
+    def initialize_parameters(self, input) -> None:  # type: ignore
         if self.has_uninitialized_params():
             with torch.no_grad():
                 self.in_features = input.shape[-1]
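
The new weight: UninitializedParameter annotation matches how LazyLinear behaves: the weight starts uninitialized and is materialized on the first forward call, once in_features is known. A brief sketch, assuming a PyTorch build that ships nn.LazyLinear:

import torch
from torch import nn

layer = nn.LazyLinear(out_features=8)
print(type(layer.weight).__name__)   # UninitializedParameter
out = layer(torch.randn(2, 16))      # in_features is inferred from the input: 16
print(layer.weight.shape)            # torch.Size([8, 16])
print(type(layer).__name__)          # Linear, because cls_to_become swaps the class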

torch/nn/modules/sparse.py

Lines changed: 3 additions & 3 deletions
@@ -99,8 +99,8 @@ class Embedding(Module):
 
     num_embeddings: int
     embedding_dim: int
-    padding_idx: int
-    max_norm: float
+    padding_idx: Optional[int]
+    max_norm: Optional[float]
     norm_type: float
     scale_grad_by_freq: bool
     weight: Tensor
@@ -284,7 +284,7 @@ class EmbeddingBag(Module):
 
     num_embeddings: int
     embedding_dim: int
-    max_norm: float
+    max_norm: Optional[float]
     norm_type: float
     scale_grad_by_freq: bool
     weight: Tensor
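
padding_idx and max_norm default to None on nn.Embedding, which is why the class-level annotations become Optional. A small sketch:

import torch
from torch import nn

emb = nn.Embedding(10, 3)
print(emb.padding_idx, emb.max_norm)   # None None, hence Optional[int] / Optional[float]

emb = nn.Embedding(10, 3, padding_idx=0, max_norm=1.0)
out = emb(torch.tensor([[0, 2, 5]]))
print(out.shape)                       # torch.Size([1, 3, 3])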

torch/nn/parameter.pyi

Lines changed: 1 addition & 1 deletion
@@ -11,5 +11,5 @@ class Parameter(Tensor):
 class UninitializedParameter(Tensor):
     def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ...
 
-    def materialize(self, shape: Tuple[int], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ...
+    def materialize(self, shape: Tuple[int, ...], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ...
     ...
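
The stub fix matters because Tuple[int] means a tuple of exactly one int, while Tuple[int, ...] is a tuple of any length, which is what a shape argument actually is. A tiny typing sketch:

from typing import Tuple

one_int: Tuple[int] = (8,)             # exactly one element
any_shape: Tuple[int, ...] = (8, 16)   # any number of ints, e.g. a weight shape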
