11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- from typing import Optional
14
+ from typing import cast , Dict , Iterator , List , Optional , Tuple , Union
15
15
16
16
import torch
17
17
import torch .nn as nn
18
18
import torch .nn .functional as F
19
+ from torch import Tensor
20
+ from torch .optim import Optimizer
21
+ from torch .optim .lr_scheduler import _LRScheduler
19
22
from torch .utils .data import DataLoader , Dataset , IterableDataset , Subset
20
23
21
24
from pytorch_lightning import LightningDataModule , LightningModule
25
+ from pytorch_lightning .core .optimizer import LightningOptimizer
26
+ from pytorch_lightning .utilities .types import EPOCH_OUTPUT , STEP_OUTPUT
22
27
23
28
24
29
class RandomDictDataset (Dataset ):
25
30
def __init__ (self , size : int , length : int ):
26
31
self .len = length
27
32
self .data = torch .randn (length , size )
28
33
29
- def __getitem__ (self , index ) :
34
+ def __getitem__ (self , index : int ) -> Dict [ str , Tensor ] :
30
35
a = self .data [index ]
31
36
b = a + 2
32
37
return {"a" : a , "b" : b }
@@ -40,7 +45,7 @@ def __init__(self, size: int, length: int):
40
45
self .len = length
41
46
self .data = torch .randn (length , size )
42
47
43
- def __getitem__ (self , index ) :
48
+ def __getitem__ (self , index : int ) -> Tensor :
44
49
return self .data [index ]
45
50
46
51
def __len__ (self ) -> int :
@@ -52,7 +57,7 @@ def __init__(self, size: int, count: int):
52
57
self .count = count
53
58
self .size = size
54
59
55
- def __iter__ (self ):
60
+ def __iter__ (self ) -> Iterator [ Tensor ] :
56
61
for _ in range (self .count ):
57
62
yield torch .randn (self .size )
58
63
@@ -62,16 +67,16 @@ def __init__(self, size: int, count: int):
62
67
self .count = count
63
68
self .size = size
64
69
65
- def __iter__ (self ):
70
+ def __iter__ (self ) -> Iterator [ Tensor ] :
66
71
for _ in range (len (self )):
67
72
yield torch .randn (self .size )
68
73
69
- def __len__ (self ):
74
+ def __len__ (self ) -> int :
70
75
return self .count
71
76
72
77
73
78
class BoringModel (LightningModule ):
74
- def __init__ (self ):
79
+ def __init__ (self ) -> None :
75
80
"""Testing PL Module.
76
81
77
82
Use as follows:
@@ -90,60 +95,63 @@ def training_step(...):
90
95
super ().__init__ ()
91
96
self .layer = torch .nn .Linear (32 , 2 )
92
97
93
- def forward (self , x ):
98
+ def forward (self , x : Tensor ) -> Tensor : # type: ignore[override]
94
99
return self .layer (x )
95
100
96
- def loss (self , batch , preds ) :
101
+ def loss (self , batch : Tensor , preds : Tensor ) -> Tensor :
97
102
# An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
98
103
return torch .nn .functional .mse_loss (preds , torch .ones_like (preds ))
99
104
100
- def step (self , x ) :
105
+ def step (self , x : Tensor ) -> Tensor :
101
106
x = self (x )
102
107
out = torch .nn .functional .mse_loss (x , torch .ones_like (x ))
103
108
return out
104
109
105
- def training_step (self , batch , batch_idx ):
110
+ def training_step (self , batch : Tensor , batch_idx : int ) -> STEP_OUTPUT : # type: ignore[override]
106
111
output = self (batch )
107
112
loss = self .loss (batch , output )
108
113
return {"loss" : loss }
109
114
110
- def training_step_end (self , training_step_outputs ) :
115
+ def training_step_end (self , training_step_outputs : STEP_OUTPUT ) -> STEP_OUTPUT :
111
116
return training_step_outputs
112
117
113
- def training_epoch_end (self , outputs ) -> None :
118
+ def training_epoch_end (self , outputs : EPOCH_OUTPUT ) -> None :
119
+ outputs = cast (List [Dict [str , Tensor ]], outputs )
114
120
torch .stack ([x ["loss" ] for x in outputs ]).mean ()
115
121
116
- def validation_step (self , batch , batch_idx ):
122
+ def validation_step (self , batch : Tensor , batch_idx : int ) -> Optional [ STEP_OUTPUT ]: # type: ignore[override]
117
123
output = self (batch )
118
124
loss = self .loss (batch , output )
119
125
return {"x" : loss }
120
126
121
- def validation_epoch_end (self , outputs ) -> None :
127
+ def validation_epoch_end (self , outputs : Union [EPOCH_OUTPUT , List [EPOCH_OUTPUT ]]) -> None :
128
+ outputs = cast (List [Dict [str , Tensor ]], outputs )
122
129
torch .stack ([x ["x" ] for x in outputs ]).mean ()
123
130
124
- def test_step (self , batch , batch_idx ):
131
+ def test_step (self , batch : Tensor , batch_idx : int ) -> Optional [ STEP_OUTPUT ]: # type: ignore[override]
125
132
output = self (batch )
126
133
loss = self .loss (batch , output )
127
134
return {"y" : loss }
128
135
129
- def test_epoch_end (self , outputs ) -> None :
136
+ def test_epoch_end (self , outputs : Union [EPOCH_OUTPUT , List [EPOCH_OUTPUT ]]) -> None :
137
+ outputs = cast (List [Dict [str , Tensor ]], outputs )
130
138
torch .stack ([x ["y" ] for x in outputs ]).mean ()
131
139
132
- def configure_optimizers (self ):
140
+ def configure_optimizers (self ) -> Tuple [ List [ torch . optim . Optimizer ], List [ _LRScheduler ]] :
133
141
optimizer = torch .optim .SGD (self .layer .parameters (), lr = 0.1 )
134
142
lr_scheduler = torch .optim .lr_scheduler .StepLR (optimizer , step_size = 1 )
135
143
return [optimizer ], [lr_scheduler ]
136
144
137
- def train_dataloader (self ):
145
+ def train_dataloader (self ) -> DataLoader :
138
146
return DataLoader (RandomDataset (32 , 64 ))
139
147
140
- def val_dataloader (self ):
148
+ def val_dataloader (self ) -> DataLoader :
141
149
return DataLoader (RandomDataset (32 , 64 ))
142
150
143
- def test_dataloader (self ):
151
+ def test_dataloader (self ) -> DataLoader :
144
152
return DataLoader (RandomDataset (32 , 64 ))
145
153
146
- def predict_dataloader (self ):
154
+ def predict_dataloader (self ) -> DataLoader :
147
155
return DataLoader (RandomDataset (32 , 64 ))
148
156
149
157
@@ -155,7 +163,7 @@ def __init__(self, data_dir: str = "./"):
155
163
self .checkpoint_state : Optional [str ] = None
156
164
self .random_full = RandomDataset (32 , 64 * 4 )
157
165
158
- def setup (self , stage : Optional [str ] = None ):
166
+ def setup (self , stage : Optional [str ] = None ) -> None :
159
167
if stage == "fit" or stage is None :
160
168
self .random_train = Subset (self .random_full , indices = range (64 ))
161
169
@@ -168,26 +176,27 @@ def setup(self, stage: Optional[str] = None):
168
176
if stage == "predict" or stage is None :
169
177
self .random_predict = Subset (self .random_full , indices = range (64 * 3 , 64 * 4 ))
170
178
171
- def train_dataloader (self ):
179
+ def train_dataloader (self ) -> DataLoader :
172
180
return DataLoader (self .random_train )
173
181
174
- def val_dataloader (self ):
182
+ def val_dataloader (self ) -> DataLoader :
175
183
return DataLoader (self .random_val )
176
184
177
- def test_dataloader (self ):
185
+ def test_dataloader (self ) -> DataLoader :
178
186
return DataLoader (self .random_test )
179
187
180
- def predict_dataloader (self ):
188
+ def predict_dataloader (self ) -> DataLoader :
181
189
return DataLoader (self .random_predict )
182
190
183
191
184
192
class ManualOptimBoringModel (BoringModel ):
185
- def __init__ (self ):
193
+ def __init__ (self ) -> None :
186
194
super ().__init__ ()
187
195
self .automatic_optimization = False
188
196
189
- def training_step (self , batch , batch_idx ):
197
+ def training_step (self , batch : Tensor , batch_idx : int ) -> STEP_OUTPUT : # type: ignore[override]
190
198
opt = self .optimizers ()
199
+ assert isinstance (opt , (Optimizer , LightningOptimizer ))
191
200
output = self (batch )
192
201
loss = self .loss (batch , output )
193
202
opt .zero_grad ()
@@ -202,21 +211,21 @@ def __init__(self, out_dim: int = 10, learning_rate: float = 0.02):
202
211
self .l1 = torch .nn .Linear (32 , out_dim )
203
212
self .learning_rate = learning_rate
204
213
205
- def forward (self , x ):
214
+ def forward (self , x : Tensor ) -> Tensor : # type: ignore[override]
206
215
return torch .relu (self .l1 (x .view (x .size (0 ), - 1 )))
207
216
208
- def training_step (self , batch , batch_nb ):
217
+ def training_step (self , batch : Tensor , batch_nb : int ) -> STEP_OUTPUT : # type: ignore[override]
209
218
x = batch
210
219
x = self (x )
211
220
loss = x .sum ()
212
221
return loss
213
222
214
- def configure_optimizers (self ):
223
+ def configure_optimizers (self ) -> torch . optim . Optimizer :
215
224
return torch .optim .Adam (self .parameters (), lr = self .learning_rate )
216
225
217
226
218
227
class Net (nn .Module ):
219
- def __init__ (self ):
228
+ def __init__ (self ) -> None :
220
229
super ().__init__ ()
221
230
self .conv1 = nn .Conv2d (1 , 32 , 3 , 1 )
222
231
self .conv2 = nn .Conv2d (32 , 64 , 3 , 1 )
@@ -225,7 +234,7 @@ def __init__(self):
225
234
self .fc1 = nn .Linear (9216 , 128 )
226
235
self .fc2 = nn .Linear (128 , 10 )
227
236
228
- def forward (self , x ) :
237
+ def forward (self , x : Tensor ) -> Tensor :
229
238
x = self .conv1 (x )
230
239
x = F .relu (x )
231
240
x = self .conv2 (x )
0 commit comments