
Commit 6a999f1

krishnakalyan3 authored
Fix mypy errors attributed to pytorch_lightning.demos.boring_classes (#14201)
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: otaj <[email protected]>
1 parent a01e016 commit 6a999f1

File tree: 2 files changed (+43, -35 lines)

pyproject.toml (-1 line)
@@ -51,7 +51,6 @@ warn_no_return = "False"
 module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.core.datamodule",
-    "pytorch_lightning.demos.boring_classes",
     "pytorch_lightning.demos.mnist_datamodule",
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",

src/pytorch_lightning/demos/boring_classes.py (+43, -34 lines)
@@ -11,22 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import cast, Dict, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset
 
 from pytorch_lightning import LightningDataModule, LightningModule
+from pytorch_lightning.core.optimizer import LightningOptimizer
+from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
 
 
 class RandomDictDataset(Dataset):
     def __init__(self, size: int, length: int):
         self.len = length
         self.data = torch.randn(length, size)
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: int) -> Dict[str, Tensor]:
         a = self.data[index]
         b = a + 2
         return {"a": a, "b": b}
@@ -40,7 +45,7 @@ def __init__(self, size: int, length: int):
         self.len = length
         self.data = torch.randn(length, size)
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: int) -> Tensor:
         return self.data[index]
 
     def __len__(self) -> int:
@@ -52,7 +57,7 @@ def __init__(self, size: int, count: int):
         self.count = count
         self.size = size
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Tensor]:
         for _ in range(self.count):
             yield torch.randn(self.size)
 
@@ -62,16 +67,16 @@ def __init__(self, size: int, count: int):
         self.count = count
         self.size = size
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Tensor]:
         for _ in range(len(self)):
             yield torch.randn(self.size)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.count
 
 
 class BoringModel(LightningModule):
-    def __init__(self):
+    def __init__(self) -> None:
         """Testing PL Module.
 
         Use as follows:
@@ -90,60 +95,63 @@ def training_step(...):
         super().__init__()
         self.layer = torch.nn.Linear(32, 2)
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
         return self.layer(x)
 
-    def loss(self, batch, preds):
+    def loss(self, batch: Tensor, preds: Tensor) -> Tensor:
         # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
         return torch.nn.functional.mse_loss(preds, torch.ones_like(preds))
 
-    def step(self, x):
+    def step(self, x: Tensor) -> Tensor:
         x = self(x)
         out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
         return out
 
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT:  # type: ignore[override]
         output = self(batch)
         loss = self.loss(batch, output)
         return {"loss": loss}
 
-    def training_step_end(self, training_step_outputs):
+    def training_step_end(self, training_step_outputs: STEP_OUTPUT) -> STEP_OUTPUT:
         return training_step_outputs
 
-    def training_epoch_end(self, outputs) -> None:
+    def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
+        outputs = cast(List[Dict[str, Tensor]], outputs)
         torch.stack([x["loss"] for x in outputs]).mean()
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]:  # type: ignore[override]
         output = self(batch)
         loss = self.loss(batch, output)
         return {"x": loss}
 
-    def validation_epoch_end(self, outputs) -> None:
+    def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
+        outputs = cast(List[Dict[str, Tensor]], outputs)
         torch.stack([x["x"] for x in outputs]).mean()
 
-    def test_step(self, batch, batch_idx):
+    def test_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]:  # type: ignore[override]
         output = self(batch)
         loss = self.loss(batch, output)
         return {"y": loss}
 
-    def test_epoch_end(self, outputs) -> None:
+    def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
+        outputs = cast(List[Dict[str, Tensor]], outputs)
         torch.stack([x["y"] for x in outputs]).mean()
 
-    def configure_optimizers(self):
+    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_LRScheduler]]:
         optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
         lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
         return [optimizer], [lr_scheduler]
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         return DataLoader(RandomDataset(32, 64))
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         return DataLoader(RandomDataset(32, 64))
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         return DataLoader(RandomDataset(32, 64))
 
-    def predict_dataloader(self):
+    def predict_dataloader(self) -> DataLoader:
         return DataLoader(RandomDataset(32, 64))
 
 
@@ -155,7 +163,7 @@ def __init__(self, data_dir: str = "./"):
         self.checkpoint_state: Optional[str] = None
         self.random_full = RandomDataset(32, 64 * 4)
 
-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: Optional[str] = None) -> None:
         if stage == "fit" or stage is None:
             self.random_train = Subset(self.random_full, indices=range(64))
 
@@ -168,26 +176,27 @@ def setup(self, stage: Optional[str] = None):
         if stage == "predict" or stage is None:
             self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4))
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         return DataLoader(self.random_train)
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         return DataLoader(self.random_val)
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         return DataLoader(self.random_test)
 
-    def predict_dataloader(self):
+    def predict_dataloader(self) -> DataLoader:
         return DataLoader(self.random_predict)
 
 
 class ManualOptimBoringModel(BoringModel):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.automatic_optimization = False
 
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT:  # type: ignore[override]
         opt = self.optimizers()
+        assert isinstance(opt, (Optimizer, LightningOptimizer))
         output = self(batch)
         loss = self.loss(batch, output)
         opt.zero_grad()
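`self.optimizers()` is typed to return either a single (possibly Lightning-wrapped) optimizer or a list of them, so mypy cannot assume `opt.zero_grad()` is valid; the added `assert isinstance(...)` narrows the union for this single-optimizer demo. The same narrowing works with an explicit branch, as in this standalone sketch:

from typing import List, Union

from torch.optim import Optimizer

def zero_grads(opt: Union[Optimizer, List[Optimizer]]) -> None:
    # isinstance() narrows the union, so each branch type-checks cleanly.
    if isinstance(opt, Optimizer):
        opt.zero_grad()
    else:
        for o in opt:
            o.zero_grad()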
@@ -202,21 +211,21 @@ def __init__(self, out_dim: int = 10, learning_rate: float = 0.02):
         self.l1 = torch.nn.Linear(32, out_dim)
         self.learning_rate = learning_rate
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
         return torch.relu(self.l1(x.view(x.size(0), -1)))
 
-    def training_step(self, batch, batch_nb):
+    def training_step(self, batch: Tensor, batch_nb: int) -> STEP_OUTPUT:  # type: ignore[override]
         x = batch
         x = self(x)
         loss = x.sum()
         return loss
 
-    def configure_optimizers(self):
+    def configure_optimizers(self) -> torch.optim.Optimizer:
         return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
 
 
 class Net(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.conv1 = nn.Conv2d(1, 32, 3, 1)
         self.conv2 = nn.Conv2d(32, 64, 3, 1)
@@ -225,7 +234,7 @@ def __init__(self):
         self.fc1 = nn.Linear(9216, 128)
         self.fc2 = nn.Linear(128, 10)
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.conv1(x)
         x = F.relu(x)
         x = self.conv2(x)
