@@ -34,19 +34,6 @@ def _permuted_values(
34
34
return torch .cat (values , dim = dim )
35
35
36
36
37
- @torch .fx .wrap
38
- def _build_dict (
39
- keys : List [str ],
40
- values : Union [torch .Tensor , List [torch .Tensor ]],
41
- splits : List [int ],
42
- dim : int ,
43
- ) -> Dict [str , torch .Tensor ]:
44
- if isinstance (values , torch .Tensor ):
45
- return dict (zip (keys , torch .split (values , splits , dim = dim )))
46
- else :
47
- return dict (zip (keys , values ))
48
-
49
-
50
37
@torch .fx .wrap
51
38
def module_init (module : "KTRegroupAsDict" , keyed_tensors : List [KeyedTensor ]) -> None :
52
39
assert len (keyed_tensors ) > 0 , "Empty list provided"
@@ -115,6 +102,12 @@ def forward(self, values: List[torch.Tensor]) -> List[torch.Tensor]:
115
102
)
116
103
117
104
105
+ def _to_tensor_dict (
106
+ keys : List [str ], values : Union [List [torch .Tensor ], Tuple [torch .Tensor , ...]]
107
+ ) -> Dict [str , torch .Tensor ]:
108
+ return {key : values [i ] for i , key in enumerate (keys )}
109
+
110
+
118
111
class KTRegroupAsDict (torch .nn .Module , CacheMixin ):
119
112
"""
120
113
KTRegroupAsDict is a nn.Module that mirrors beahvior of static method KeyedTensor.regroup_as_dict()
@@ -204,11 +197,13 @@ def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
204
197
if self ._use_fbgemm_regroup :
205
198
values = _get_kts_values (keyed_tensors )
206
199
permuted_values = self ._permute_pooled_embs_impl (values )
200
+ return _to_tensor_dict (self ._keys , permuted_values )
207
201
else :
208
202
permuted_values = _permuted_values (
209
203
keyed_tensors , self ._idx_key_pairs , self ._dim
210
204
)
211
- return _build_dict (self ._keys , permuted_values , self ._splits , self ._dim )
205
+ splitted_values = torch .split (permuted_values , self ._splits , dim = self ._dim )
206
+ return _to_tensor_dict (self ._keys , splitted_values )
212
207
213
208
def clear_cache (self ) -> None :
214
209
self ._is_inited = False
0 commit comments