Commit 1eceb58

generatedunixname89002005307016 authored and facebook-github-bot committed
suppress errors in torchrec (#3016)
Summary:
Pull Request resolved: #3016

Differential Revision: D75570790

fbshipit-source-id: c3200326178f2872a45dada80b7f4ec551321729
1 parent 6a28dca commit 1eceb58

33 files changed: +193 additions, -3 deletions
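
Every line this commit adds is a Pyre suppression comment rather than a behavioral change: a `# pyre-fixme[N]` comment silences Pyre error code N on the line directly below it, while leaving a greppable marker that the underlying type error still needs a real fix. A minimal sketch of the mechanism, using a hypothetical function that is not part of this diff:

from typing import Optional

def read_config(path: Optional[str]) -> str:
    # `open` rejects `None`, so Pyre reports an incompatible-parameter-type
    # error (code [6]) on the next line; the fixme suppresses exactly that
    # code on exactly that line, and nothing else.
    # pyre-fixme[6]
    return open(path).read()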

torchrec/datasets/utils.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "Batch":

     def record_stream(self, stream: torch.Stream) -> None:
         self.dense_features.record_stream(stream)
+        # pyre-fixme[6]: For 1st argument expected `Stream` but got `Stream`.
         self.sparse_features.record_stream(stream)
         self.labels.record_stream(stream)
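
The seemingly tautological message "expected `Stream` but got `Stream`" appears because two distinct classes share the short name `Stream` and Pyre prints types unqualified. A sketch of the shape of the mismatch, assuming the pair involved is `torch.Stream` versus `torch.cuda.streams.Stream` (the `embedding_types.py` hunk below points at the same two types):

import torch

def _expects_cuda_stream(stream: torch.cuda.streams.Stream) -> None:
    """Stand-in for a callee annotated with the CUDA-specific stream type."""

def _record(stream: torch.Stream) -> None:
    # Both classes render as `Stream` in Pyre output, so the error reads
    # "expected `Stream` but got `Stream`" even though the types differ.
    # pyre-fixme[6]: For 1st argument expected `Stream` but got `Stream`.
    _expects_cuda_stream(stream)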

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 10 additions & 0 deletions
@@ -840,6 +840,8 @@ def _gen_named_parameters_by_table_ssd_pmt(
     for table_config, pmt in zip(config.embedding_tables, pmts):
         table_name = table_config.name
         emb_table = pmt
+        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+        #  `Union[PartiallyMaterializedTensor, Tensor]`.
         weight: nn.Parameter = nn.Parameter(emb_table)
         # pyre-ignore
         weight._in_backward_optimizers = [EmptyFusedOptimizer()]
@@ -1229,6 +1231,10 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
     ]:
+        # pyre-fixme[7]: Expected `Tuple[List[PartiallyMaterializedTensor],
+        #  Optional[List[Tensor]], Optional[List[Tensor]]]` but got
+        #  `Tuple[Union[List[PartiallyMaterializedTensor], List[Tensor]],
+        #  Optional[List[Tensor]], Optional[List[Tensor]]]`.
         return self.emb_module.split_embedding_weights(no_snapshot)
@@ -2027,6 +2033,10 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
     ]:
+        # pyre-fixme[7]: Expected `Tuple[List[PartiallyMaterializedTensor],
+        #  Optional[List[Tensor]], Optional[List[Tensor]]]` but got
+        #  `Tuple[Union[List[PartiallyMaterializedTensor], List[Tensor]],
+        #  Optional[List[Tensor]], Optional[List[Tensor]]]`.
         return self.emb_module.split_embedding_weights(no_snapshot)
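
The two `pyre-fixme[7]` blocks suppress a return-type mismatch: the underlying `emb_module` may hand back either partially materialized tensors or plain `Tensor`s, while the wrapper promises only the former. A reduced sketch of that shape, with `PMT` as a hypothetical stand-in for `PartiallyMaterializedTensor`:

from typing import List, Optional, Tuple, Union

import torch

class PMT:
    """Hypothetical stand-in for fbgemm's PartiallyMaterializedTensor."""

def _backend_split() -> Tuple[
    Union[List[PMT], List[torch.Tensor]],
    Optional[List[torch.Tensor]],
]:
    return ([], None)

def split_embedding_weights() -> Tuple[List[PMT], Optional[List[torch.Tensor]]]:
    # Error [7]: the union in the backend's first tuple element does not
    # match the narrower `List[PMT]` promised here, hence the suppression.
    # pyre-fixme[7]
    return _backend_split()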

torchrec/distributed/composable/tests/test_fused_optim_nccl.py

Lines changed: 4 additions & 0 deletions
@@ -41,6 +41,8 @@ def _test_sharded_fused_optimizer_state_dict(
         ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))
         apply_optimizer_in_backward(
             RowWiseAdagrad,
+            # pyre-fixme[6]: For 2nd argument expected `Iterable[Parameter]` but
+            #  got `Iterable[Union[Tensor, Module]]`.
             [
                 ebc.embedding_bags["table_0"].weight,
                 ebc.embedding_bags["table_1"].weight,
@@ -49,6 +51,8 @@ def _test_sharded_fused_optimizer_state_dict(
         )
         apply_optimizer_in_backward(
             PartialRowWiseAdam,
+            # pyre-fixme[6]: For 2nd argument expected `Iterable[Parameter]` but
+            #  got `Iterable[Union[Tensor, Module]]`.
             [
                 ebc.embedding_bags["table_2"].weight,
                 ebc.embedding_bags["table_3"].weight,
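
Both `pyre-fixme[6]` comments here stem from the same pattern: `ebc.embedding_bags` is an `nn.ModuleDict`, and attribute access on the indexed module goes through `nn.Module.__getattr__`, which Pyre types as `Union[Tensor, Module]` rather than `Parameter`. A cast-based alternative to suppression might look like this sketch (the helper name is hypothetical):

from typing import List, cast

import torch.nn as nn

def bag_weights(bags: nn.ModuleDict, names: List[str]) -> List[nn.Parameter]:
    # Indexing an nn.ModuleDict yields `Module`, and `.weight` on a Module
    # resolves via __getattr__ to `Union[Tensor, Module]` under Pyre; the
    # cast records that each weight is in fact an nn.Parameter at runtime.
    return [cast(nn.Parameter, bags[name].weight) for name in names]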

torchrec/distributed/embedding.py

Lines changed: 7 additions & 0 deletions
@@ -315,6 +315,7 @@ def record_stream(self, stream: torch.Stream) -> None:
         for ctx in self.sharding_contexts:
             ctx.record_stream(stream)
         for f in self.input_features:
+            # pyre-fixme[6]: For 1st argument expected `Stream` but got `Stream`.
             f.record_stream(stream)
         for r in self.reverse_indices:
             r.record_stream(stream)
@@ -892,6 +893,10 @@ def _initialize_torch_state(self) -> None:  # noqa
                 _model_parallel_name_to_compute_kernel[table_name]
                 != EmbeddingComputeKernel.DENSE.value
             ):
+                # pyre-fixme[16]: `Module` has no attribute
+                #  `_in_backward_optimizers`.
+                # pyre-fixme[16]: `Tensor` has no attribute
+                #  `_in_backward_optimizers`.
                 self.embeddings[table_name].weight._in_backward_optimizers = [
                     EmptyFusedOptimizer()
                 ]
@@ -1110,6 +1115,8 @@ def reset_parameters(self) -> None:
             if sharding_type == ShardingType.DATA_PARALLEL.value:
                 pg = self._env.process_group
                 with torch.no_grad():
+                    # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     dist.broadcast(param.data, src=0, group=pg)

     def _generate_permute_indices_per_feature(
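
`_in_backward_optimizers` is attached to the weight dynamically, which is why Pyre reports `[16]` twice, once for each arm of the `Union[Module, Tensor]` it infers for `.weight`. The runtime pattern itself is ordinary dynamic attribute assignment, as in this sketch:

import torch
import torch.nn as nn

weight = nn.Parameter(torch.empty(4, 8))
# Legal at runtime (tensors accept arbitrary instance attributes) but
# undeclared in the stubs, so Pyre flags the access with error [16]
# unless it is suppressed.
# pyre-fixme[16]
weight._in_backward_optimizers = []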

torchrec/distributed/embedding_lookup.py

Lines changed: 21 additions & 0 deletions
@@ -134,6 +134,7 @@ def _load_state_dict(
             dst_param.detach().copy_(src_param)
             unexpected_keys.remove(key)
         else:
+            # pyre-fixme[22]: The cast is redundant.
             missing_keys.append(cast(str, key))
     return missing_keys, unexpected_keys

@@ -278,13 +279,17 @@ def prefetch(
                     SSDTableBatchedEmbeddingBags,
                 ),
             )
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
+            #  attribute `prefetch_pipeline`.
             and not emb_op.emb_module.prefetch_pipeline
         ):
             logging.error(
                 f"Invalid setting on {type(emb_op.emb_module)} modules. prefetch_pipeline must be set to True.\n"
                 "If you don’t turn on prefetch_pipeline, cache locations might be wrong in backward and can cause wrong results.\n"
             )
         if hasattr(emb_op.emb_module, "prefetch"):
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
+            #  attribute `prefetch`.
             emb_op.emb_module.prefetch(
                 indices=features.values(),
                 offsets=features.offsets(),
@@ -317,6 +322,7 @@ def state_dict(
             destination._metadata = OrderedDict()

         for emb_module in self._emb_modules:
+            # pyre-fixme[19]: Expected 0 positional arguments.
             emb_module.state_dict(destination, prefix, keep_vars)

         return destination
@@ -363,6 +369,7 @@ def named_parameters_by_table(
             for (
                 table_name,
                 tbe_slice,
+                # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             ) in embedding_kernel.named_parameters_by_table():
                 yield (table_name, tbe_slice)

@@ -389,10 +396,12 @@ def get_named_split_embedding_weights_snapshot(

     def flush(self) -> None:
         for emb_module in self._emb_modules:
+            # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             emb_module.flush()

     def purge(self) -> None:
         for emb_module in self._emb_modules:
+            # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             emb_module.purge()

@@ -521,6 +530,8 @@ def _need_prefetch(config: GroupedEmbeddingConfig) -> bool:
             self._feature_splits,
         )
         for emb_op, features in zip(self._emb_modules, features_by_group):
+            # pyre-fixme[6]: For 1st argument expected `GroupedEmbeddingConfig`
+            #  but got `Union[Module, Tensor]`.
             if not _need_prefetch(emb_op.config):
                 continue
             if (
@@ -531,13 +542,17 @@ def _need_prefetch(config: GroupedEmbeddingConfig) -> bool:
                     SSDTableBatchedEmbeddingBags,
                 ),
             )
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
+            #  attribute `prefetch_pipeline`.
             and not emb_op.emb_module.prefetch_pipeline
         ):
             logging.error(
                 f"Invalid setting on {type(emb_op.emb_module)} modules. prefetch_pipeline must be set to True.\n"
                 "If you don't turn on prefetch_pipeline, cache locations might be wrong in backward and can cause wrong results.\n"
             )
         if hasattr(emb_op.emb_module, "prefetch"):
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
+            #  attribute `prefetch`.
             emb_op.emb_module.prefetch(
                 indices=features.values(),
                 offsets=features.offsets(),
@@ -633,6 +648,7 @@ def state_dict(
             destination._metadata = OrderedDict()

         for emb_module in self._emb_modules:
+            # pyre-fixme[19]: Expected 0 positional arguments.
             emb_module.state_dict(destination, prefix, keep_vars)

         return destination
@@ -679,6 +695,7 @@ def named_parameters_by_table(
             for (
                 table_name,
                 tbe_slice,
+                # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             ) in embedding_kernel.named_parameters_by_table():
                 yield (table_name, tbe_slice)

@@ -703,10 +720,12 @@ def get_named_split_embedding_weights_snapshot(

     def flush(self) -> None:
         for emb_module in self._emb_modules:
+            # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             emb_module.flush()

     def purge(self) -> None:
         for emb_module in self._emb_modules:
+            # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             emb_module.purge()

@@ -813,6 +832,7 @@ def state_dict(
             destination._metadata = OrderedDict()

         for emb_module in self._emb_modules:
+            # pyre-fixme[19]: Expected 0 positional arguments.
             emb_module.state_dict(destination, prefix, keep_vars)

         return destination
@@ -980,6 +1000,7 @@ def state_dict(
             destination._metadata = OrderedDict()

         for emb_module in self._emb_modules:
+            # pyre-fixme[19]: Expected 0 positional arguments.
             emb_module.state_dict(destination, prefix, keep_vars)

         return destination
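
The recurring `[29]` and `[19]` suppressions in this file share one root cause: `self._emb_modules` and `embedding_kernel` are reached through `nn.Module` attribute lookup, so Pyre sees `Union[Module, Tensor]` and cannot prove that `.flush()`, `.purge()`, or a custom `state_dict(...)` signature is callable on them. An isinstance-narrowing sketch of the fixme-free alternative, with illustrative class names:

import torch.nn as nn

class Kernel(nn.Module):
    def flush(self) -> None:
        print("flushed")

class Lookup(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._emb_modules = nn.ModuleList([Kernel()])

    def flush(self) -> None:
        for emb_module in self._emb_modules:
            # Iterating an nn.ModuleList yields plain `Module`; narrowing
            # with isinstance lets the call type-check without a fixme.
            if isinstance(emb_module, Kernel):
                emb_module.flush()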

torchrec/distributed/embedding_tower_sharding.py

Lines changed: 21 additions & 0 deletions
@@ -501,6 +501,8 @@ def __init__(
         # create mapping of logical towers to physical towers
         tables_per_lt: List[Set[str]] = []
         for tower in module.towers:
+            # pyre-fixme[6]: For 1st argument expected `EmbeddingTower` but got
+            #  `Module`.
             lt_tables = set(tower_sharder.shardable_parameters(tower).keys())
             tables_per_lt.append(lt_tables)
         # check the tables in a logical tower are on same physical tower
@@ -546,6 +548,8 @@ def __init__(
             self._kjt_num_features_per_pt.append(len(kjt_names))
             self._wkjt_num_features_per_pt.append(len(wkjt_names))

+        # pyre-fixme[9]: local_towers has type `List[Tuple[str, EmbeddingTower]]`;
+        #  used as `List[Tuple[str, Module]]`.
         local_towers: List[Tuple[str, EmbeddingTower]] = [
             (str(i), tower)
             for i, tower in enumerate(module.towers)
@@ -603,6 +607,8 @@ def __init__(
         # Setup output dists for quantized comms
         output_dists = nn.ModuleList()
         for embedding in self.embeddings.values():
+            # pyre-fixme[6]: For 1st argument expected `Iterable[Module]` but got
+            #  `Union[Module, Tensor]`.
             output_dists.extend(embedding._output_dists)
         self._output_dists: nn.ModuleList = output_dists

@@ -780,10 +786,13 @@ def state_dict(
         # pyre-ignore [16]
         destination._metadata = OrderedDict()
         for i, embedding in self.embeddings.items():
+            # pyre-fixme[19]: Expected 0 positional arguments.
             embedding.state_dict(
                 destination, prefix + f"towers.{i}.embedding.", keep_vars
             )
         for i, interaction in self.interactions.items():
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
+            #  `state_dict`.
             interaction.module.state_dict(
                 destination, prefix + f"towers.{i}.interaction.", keep_vars
             )
@@ -792,6 +801,9 @@ def state_dict(
     @property
     def fused_optimizer(self) -> KeyedOptimizer:
         return CombinedOptimizer(
+            # pyre-fixme[6]: For 1st argument expected `List[Union[Tuple[str,
+            #  KeyedOptimizer], KeyedOptimizer]]` but got `List[Tuple[str,
+            #  Union[Module, Tensor]]]`.
             [
                 (f"towers.{tower_index}.embedding", embedding.fused_optimizer)
                 for tower_index, embedding in self.embeddings.items()
@@ -809,6 +821,8 @@ def named_parameters(
         )
         for i, interaction in self.interactions.items():
             yield from (
+                # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
+                #  attribute `named_parameters`.
                 interaction.module.named_parameters(
                     append_prefix(prefix, f"towers.{i}.interaction"), recurse
                 )
@@ -825,6 +839,8 @@ def named_buffers(
         )
         for i, interaction in self.interactions.items():
             yield from (
+                # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
+                #  attribute `named_buffers`.
                 interaction.module.named_buffers(
                     append_prefix(prefix, f"towers.{i}.interaction"), recurse
                 )
@@ -833,13 +849,16 @@ def named_buffers(
     def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
         for i, embedding in self.embeddings.items():
             yield from (
+                # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
                 embedding.sharded_parameter_names(
                     append_prefix(prefix, f"towers.{i}.embedding")
                 )
             )
         for i, interaction in self.interactions.items():
             yield from (
                 key
+                # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
+                #  attribute `named_parameters`.
                 for key, _ in interaction.module.named_parameters(
                     append_prefix(prefix, f"towers.{i}.interaction")
                 )
@@ -993,6 +1012,8 @@ def shardable_parameters(

         named_parameters: Dict[str, nn.Parameter] = {}
         for tower in module.towers:
+            # pyre-fixme[6]: For 1st argument expected `EmbeddingTower` but got
+            #  `Module`.
             named_parameters.update(self._tower_sharder.shardable_parameters(tower))
         return named_parameters
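
The `pyre-fixme[9]` above is the declaration-side variant of the same union problem: the comprehension over `module.towers` is inferred as `List[Tuple[str, Module]]`, which cannot be assigned to the more specific annotation. A reduced sketch with stand-in classes:

from typing import List, Tuple

import torch.nn as nn

class EmbeddingTower(nn.Module):
    pass

class TowerCollection(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.towers = nn.ModuleList([EmbeddingTower()])

def local_towers_of(module: TowerCollection) -> List[Tuple[str, EmbeddingTower]]:
    # Iterating an nn.ModuleList yields `Module`, so Pyre infers
    # List[Tuple[str, Module]] and reports error [9] against the narrower
    # annotation unless the elements are narrowed or the error suppressed.
    # pyre-fixme[9]
    towers: List[Tuple[str, EmbeddingTower]] = [
        (str(i), tower) for i, tower in enumerate(module.towers)
    ]
    return towers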

torchrec/distributed/embedding_types.py

Lines changed: 4 additions & 0 deletions
@@ -135,6 +135,8 @@ def __iter__(self) -> Iterator[KeyedJaggedTensor]:
         return iter(self.features)

     @torch.jit._drop
+    # pyre-fixme[14]: `record_stream` overrides method defined in `Multistreamable`
+    #  inconsistently.
     def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
         for feature in self.features:
             feature.record_stream(stream)
@@ -160,6 +162,7 @@ class InputDistOutputs(Multistreamable):

     def record_stream(self, stream: torch.Stream) -> None:
         for feature in self.features:
+            # pyre-fixme[6]: For 1st argument expected `Stream` but got `Stream`.
             feature.record_stream(stream)
         if self.unbucketize_permute_tensor is not None:
             self.unbucketize_permute_tensor.record_stream(stream)
@@ -189,6 +192,7 @@ def __iter__(self) -> Iterator[KJTList]:
     @torch.jit._drop
     def record_stream(self, stream: torch.Stream) -> None:
         for feature in self.features_list:
+            # pyre-fixme[6]: For 1st argument expected `Stream` but got `Stream`.
             feature.record_stream(stream)

     @torch.jit._drop
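
The `pyre-fixme[14]` differs from the rest: it flags an override whose parameter type (`torch.cuda.streams.Stream`) is narrower than the base method's (`torch.Stream`), which breaks contravariance. A sketch of that shape:

import abc

import torch

class Multistreamable(abc.ABC):
    @abc.abstractmethod
    def record_stream(self, stream: torch.Stream) -> None: ...

class Features(Multistreamable):
    # A caller holding a `Multistreamable` may pass any `torch.Stream`,
    # which this narrower override cannot accept, hence Pyre error [14].
    # pyre-fixme[14]
    def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
        pass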

torchrec/distributed/embeddingbag.py

Lines changed: 6 additions & 0 deletions
@@ -972,6 +972,10 @@ def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
                 _model_parallel_name_to_compute_kernel[table_name]
                 != EmbeddingComputeKernel.DENSE.value
             ):
+                # pyre-fixme[16]: `Module` has no attribute
+                #  `_in_backward_optimizers`.
+                # pyre-fixme[16]: `Tensor` has no attribute
+                #  `_in_backward_optimizers`.
                 self.embedding_bags[table_name].weight._in_backward_optimizers = [
                     EmptyFusedOptimizer()
                 ]
@@ -1137,6 +1141,8 @@ def reset_parameters(self) -> None:
             if sharding_type == ShardingType.DATA_PARALLEL.value:
                 pg = self._env.process_group
                 with torch.no_grad():
+                    # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     dist.broadcast(param.data, src=0, group=pg)

     def _create_input_dist(

torchrec/distributed/fp_embeddingbag.py

Lines changed: 2 additions & 0 deletions
@@ -83,6 +83,8 @@ def __init__(
             self._is_collection = True
         else:
             self._feature_processors = torch.nn.ModuleDict(
+                # pyre-fixme[29]: `Union[(self: ModuleDict) -> ItemsView[str,
+                #  Module], Module, Tensor]` is not a function.
                 module._feature_processors.items()
             )
             self._is_collection = False
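
Here the union is flagged at a call through an attribute: Pyre cannot tell whether `module._feature_processors` resolves to a `Tensor`, a `Module`, or the bound `items` method, so the `.items()` call is "not a function". One cast-based sketch of an alternative, assuming (as the diff implies) the attribute is an `nn.ModuleDict` at runtime:

from typing import cast

import torch.nn as nn

def clone_processors(module: nn.Module) -> nn.ModuleDict:
    # Attribute access on an nn.Module goes through __getattr__, typed
    # `Union[Tensor, Module]`; the cast narrows it to the ModuleDict it
    # actually is before .items() is called.
    processors = cast(nn.ModuleDict, module._feature_processors)
    return nn.ModuleDict(processors.items())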

torchrec/distributed/mc_embedding.py

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ def record_stream(self, stream: torch.Stream) -> None:
                 continue
             value.record_stream(stream)
         if self.remapped_kjt is not None:
+            # pyre-fixme[6]: For 1st argument expected `Stream` but got `Stream`.
             self.remapped_kjt.record_stream(stream)

torchrec/distributed/mc_embeddingbag.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ def record_stream(self, stream: torch.Stream) -> None:
                 continue
             value.record_stream(stream)
         if self.remapped_kjt is not None:
+            # pyre-fixme[6]: For 1st argument expected `Stream` but got `Stream`.
             self.remapped_kjt.record_stream(stream)
