
Commit 2cc1abe

add & config mypy on common package

Signed-off-by: wiseaidev <[email protected]>
1 parent 78c3235 commit 2cc1abe

File tree: 17 files changed, +218, -84 lines

common/batch.py

Lines changed: 32 additions & 14 deletions
@@ -1,22 +1,40 @@
 """Extension of torchrec.dataset.utils.Batch to cover any dataset.
 """
 # flake8: noqa
-from __future__ import annotations
-from typing import Dict
+from __future__ import (
+  annotations,
+)
+
 import abc
-from dataclasses import dataclass
 import dataclasses
+from collections import (
+  UserDict,
+)
+from dataclasses import (
+  dataclass,
+)
+from typing import (
+  Any,
+  Dict,
+  List,
+  TypeVar,
+)
 
 import torch
-from torchrec.streamable import Pipelineable
+from torchrec.streamable import (
+  Pipelineable,
+)
+
+_KT = TypeVar("_KT")  # key type
+_VT = TypeVar("_VT")  # value type
 
 
 class BatchBase(Pipelineable, abc.ABC):
   @abc.abstractmethod
-  def as_dict(self) -> Dict:
+  def as_dict(self) -> Dict[str, Any]:
     raise NotImplementedError
 
-  def to(self, device: torch.device, non_blocking: bool = False):
+  def to(self, device: torch.device, non_blocking: bool = False) -> BatchBase:
     args = {}
     for feature_name, feature_value in self.as_dict().items():
       args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
@@ -26,14 +44,14 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
     for feature_value in self.as_dict().values():
       feature_value.record_stream(stream)
 
-  def pin_memory(self):
+  def pin_memory(self) -> BatchBase:
     args = {}
     for feature_name, feature_value in self.as_dict().items():
       args[feature_name] = feature_value.pin_memory()
     return self.__class__(**args)
 
   def __repr__(self) -> str:
-    def obj2str(v):
+    def obj2str(v: Any) -> str:
       return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
 
     return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])
@@ -52,18 +70,18 @@ def batch_size(self) -> int:
 @dataclass
 class DataclassBatch(BatchBase):
   @classmethod
-  def feature_names(cls):
+  def feature_names(cls) -> List[str]:
     return list(cls.__dataclass_fields__.keys())
 
-  def as_dict(self):
+  def as_dict(self) -> Dict[str, Any]:
     return {
       feature_name: getattr(self, feature_name)
       for feature_name in self.feature_names()
       if hasattr(self, feature_name)
     }
 
   @staticmethod
-  def from_schema(name: str, schema):
+  def from_schema(name: str, schema: Any) -> type:
     """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
     return dataclasses.make_dataclass(
       cls_name=name,
@@ -72,14 +90,14 @@ def from_schema(name: str, schema):
     )
 
   @staticmethod
-  def from_fields(name: str, fields: dict):
+  def from_fields(name: str, fields: Dict[str, Any]) -> type:
     return dataclasses.make_dataclass(
       cls_name=name,
       fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
       bases=(DataclassBatch,),
     )
 
 
-class DictionaryBatch(BatchBase, dict):
-  def as_dict(self) -> Dict:
+class DictionaryBatch(BatchBase, UserDict[_KT, _VT]):
+  def as_dict(self) -> Dict[str, Any]:
     return self
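As a quick illustration of what the newly typed batch API looks like from the caller's side, here is a minimal sketch. It is not part of the commit: the class and field names (ExampleBatch, continuous, labels) are made up.

import torch

from tml.common.batch import DataclassBatch

# Hypothetical batch class built with the typed factory; field names are invented.
ExampleBatch = DataclassBatch.from_fields(
  "ExampleBatch", {"continuous": torch.Tensor, "labels": torch.Tensor}
)

batch = ExampleBatch(continuous=torch.rand(8, 4), labels=torch.zeros(8))
print(batch.as_dict().keys())           # Dict[str, Any] under the new annotation
batch = batch.to(torch.device("cpu"))   # returns BatchBase under the new annotation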

common/checkpointing/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -1 +1,4 @@
-from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot
+from tml.common.checkpointing.snapshot import (
+  Snapshot,
+  get_checkpoint,
+)

common/checkpointing/snapshot.py

Lines changed: 28 additions & 14 deletions
@@ -1,12 +1,24 @@
 import os
 import time
-from typing import Any, Dict, List, Optional
-
-from tml.ml_logging.torch_logging import logging
-from tml.common.filesystem import infer_fs, is_gcs_fs
+from typing import (
+  Any,
+  Dict,
+  Generator,
+  List,
+  Optional,
+)
 
 import torchsnapshot
-
+from tml.common.filesystem import (
+  infer_fs,
+  is_gcs_fs,
+)
+from tml.ml_logging.torch_logging import (
+  logging,
+)
+from torch import (
+  FloatTensor,
+)
 
 DONE_EVAL_SUBDIR = "evaled_by"
 GCS_PREFIX = "gs://"
@@ -25,22 +37,22 @@ def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
     self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)
 
   @property
-  def step(self):
+  def step(self) -> int:
     return self.state["extra_state"]["step"]
 
   @step.setter
   def step(self, step: int) -> None:
     self.state["extra_state"]["step"] = step
 
   @property
-  def walltime(self):
+  def walltime(self) -> float:
     return self.state["extra_state"]["walltime"]
 
   @walltime.setter
   def walltime(self, walltime: float) -> None:
     self.state["extra_state"]["walltime"] = walltime
 
-  def save(self, global_step: int) -> "PendingSnapshot":
+  def save(self, global_step: int) -> "PendingSnapshot":  # type: ignore
     """Saves checkpoint with given global_step."""
     path = os.path.join(self.save_dir, str(global_step))
     logging.info(f"Saving snapshot global_step {global_step} to {path}.")
@@ -98,7 +110,7 @@ def load_snapshot_to_weight(
     cls,
     embedding_snapshot: torchsnapshot.Snapshot,
     snapshot_emb_name: str,
-    weight_tensor,
+    weight_tensor: FloatTensor,
   ) -> None:
     """Loads pretrained embedding from the snapshot to the model.
     Utilise partial lodaing meachanism from torchsnapshot.
@@ -128,19 +140,21 @@ def _eval_done_path(checkpoint_path: str, eval_partition: str) -> str:
   return os.path.join(_eval_subdir(checkpoint_path), f"{eval_partition}_DONE")
 
 
-def is_done_eval(checkpoint_path: str, eval_partition: str):
-  return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))
+def is_done_eval(checkpoint_path: str, eval_partition: str) -> bool:
+  return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))  # type: ignore[attr-defined]
 
 
-def mark_done_eval(checkpoint_path: str, eval_partition: str):
+def mark_done_eval(checkpoint_path: str, eval_partition: str) -> Any:
   infer_fs(checkpoint_path).touch(_eval_done_path(checkpoint_path, eval_partition))
 
 
 def step_from_checkpoint(checkpoint: str) -> int:
   return int(os.path.basename(checkpoint))
 
 
-def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800):
+def checkpoints_iterator(
+  save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800
+) -> Generator[str, None, None]:
   """Simplified equivalent of tf.train.checkpoints_iterator.
 
   Args:
@@ -149,7 +163,7 @@ def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int
 
 
   """
-  def _poll(last_checkpoint: Optional[str] = None):
+  def _poll(last_checkpoint: Optional[str] = None) -> Optional[str]:
     stop_time = time.time() + timeout
     while True:
       _checkpoint_path = get_checkpoint(save_dir, missing_ok=True)
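A rough sketch of what the new Generator[str, None, None] annotation means for callers follows; the directory path and polling values below are placeholders, not values from the commit.

from tml.common.checkpointing.snapshot import checkpoints_iterator

# Poll a (hypothetical) save directory; each yielded value is a checkpoint path string.
for checkpoint_path in checkpoints_iterator("/tmp/ckpts", seconds_to_sleep=5, timeout=60):
  print(checkpoint_path)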

common/device.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 import torch.distributed as dist
 
 
-def maybe_setup_tensorflow():
+def maybe_setup_tensorflow() -> None:
   try:
     import tensorflow as tf
   except ImportError:

common/filesystem/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -1 +1,5 @@
-from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs
+from tml.common.filesystem.util import (
+  infer_fs,
+  is_gcs_fs,
+  is_local_fs,
+)

common/filesystem/test_infer_fs.py

Lines changed: 3 additions & 1 deletion
@@ -2,7 +2,9 @@
 
 Mostly a test that it returns an object
 """
-from tml.common.filesystem import infer_fs
+from tml.common.filesystem import (
+  infer_fs,
+)
 
 
 def test_infer_fs():

common/filesystem/util.py

Lines changed: 10 additions & 5 deletions
@@ -1,13 +1,18 @@
 """Utilities for interacting with the file systems."""
-from fsspec.implementations.local import LocalFileSystem
-import gcsfs
+from typing import (
+  Union,
+)
 
+import gcsfs
+from fsspec.implementations.local import (
+  LocalFileSystem,
+)
 
 GCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1)
 LOCAL_FS = LocalFileSystem()
 
 
-def infer_fs(path: str):
+def infer_fs(path: str) -> Union[LocalFileSystem, gcsfs.core.GCSFileSystem, NotImplementedError]:
   if path.startswith("gs://"):
     return GCS_FS
   elif path.startswith("hdfs://"):
@@ -17,9 +22,9 @@ def infer_fs(path: str):
     return LOCAL_FS
 
 
-def is_local_fs(fs):
+def is_local_fs(fs: LocalFileSystem) -> bool:
   return fs == LOCAL_FS
 
 
-def is_gcs_fs(fs):
+def is_gcs_fs(fs: gcsfs.core.GCSFileSystem) -> bool:
  return fs == GCS_FS
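A small usage sketch of the newly annotated helpers, assuming the package re-exports shown earlier; the paths are placeholders, and how infer_fs handles hdfs:// paths is outside this hunk.

from tml.common.filesystem import infer_fs, is_gcs_fs, is_local_fs

fs = infer_fs("gs://example-bucket/data")  # GCSFileSystem for gs:// paths
assert is_gcs_fs(fs)

local = infer_fs("/tmp/data")              # LocalFileSystem for plain local paths
assert is_local_fs(local)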

common/log_weights.py

Lines changed: 20 additions & 9 deletions
@@ -1,17 +1,28 @@
 """For logging model weights."""
 import itertools
-from typing import Callable, Dict, List, Optional, Union
+from typing import (
+  Any,
+  Callable,
+  Dict,
+  List,
+  Optional,
+  Union,
+)
 
-from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]
 import torch
 import torch.distributed as dist
-from torchrec.distributed.model_parallel import DistributedModelParallel
+from tml.ml_logging.torch_logging import (
+  logging,
+)
+from torchrec.distributed.model_parallel import (
+  DistributedModelParallel,
+)
 
 
 def weights_to_log(
   model: torch.nn.Module,
-  how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,
-):
+  how_to_log: Optional[Union[Callable[[Any], Any], Dict[str, Callable[[Any], Any]]]] = None,
+) -> Optional[Dict[str, Any]]:
   """Creates dict of reduced weights to log to give sense of training.
 
   Args:
@@ -21,7 +32,7 @@ def weights_to_log(
 
   """
   if not how_to_log:
-    return
+    return None
 
   to_log = dict()
   named_parameters = model.named_parameters()
@@ -38,14 +49,14 @@ def weights_to_log(
       how = how_to_log
     else:
       how = how_to_log.get(param_name)  # type: ignore[assignment]
-      if not how:
-        continue  # type: ignore
+      if how is None:
+        continue
    to_log[f"model/{how.__name__}/{param_name}"] = how(params.detach()).cpu().numpy()
   return to_log
 
 
 def log_ebc_norms(
-  model_state_dict,
+  model_state_dict: Dict[str, Any],
   ebc_keys: List[str],
   sample_size: int = 4_000_000,
 ) -> Dict[str, torch.Tensor]:
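A minimal sketch of a call that matches the tightened how_to_log annotation, assuming a plain (non-DMP) module; the toy model below is not from the commit.

import torch

from tml.common.log_weights import weights_to_log

model = torch.nn.Linear(4, 2)  # toy stand-in model
# A single callable is applied to every parameter; torch.norm satisfies Callable[[Any], Any].
logged = weights_to_log(model, how_to_log=torch.norm)
# Keys look like "model/norm/weight"; with how_to_log unset the function now returns None explicitly.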

common/modules/embedding/config.py

Lines changed: 7 additions & 4 deletions
@@ -1,10 +1,13 @@
-from typing import List
 from enum import Enum
-
-import tml.core.config as base_config
-from tml.optimizers.config import OptimizerConfig
+from typing import (
+  List,
+)
 
 import pydantic
+import tml.core.config as base_config
+from tml.optimizers.config import (
+  OptimizerConfig,
+)
 
 
 class DataType(str, Enum):

common/modules/embedding/embedding.py

Lines changed: 20 additions & 8 deletions
@@ -1,13 +1,25 @@
-from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType
-from tml.ml_logging.torch_logging import logging
-
+import numpy as np
 import torch
-from torch import nn
 import torchrec
-from torchrec.modules import embedding_configs
-from torchrec import EmbeddingBagConfig, EmbeddingBagCollection
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
-import numpy as np
+from tml.common.modules.embedding.config import (
+  DataType,
+  LargeEmbeddingsConfig,
+)
+from tml.ml_logging.torch_logging import (
+  logging,
+)
+from torch import nn
+from torchrec import (
+  EmbeddingBagCollection,
+  EmbeddingBagConfig,
+)
+from torchrec.modules import (
+  embedding_configs,
+)
+from torchrec.sparse.jagged_tensor import (
+  KeyedJaggedTensor,
+  KeyedTensor,
+)
 
 
 class LargeEmbeddings(nn.Module):
