diff --git a/torchrec/modules/regroup.py b/torchrec/modules/regroup.py index 3e91ce981..c704e9f90 100644 --- a/torchrec/modules/regroup.py +++ b/torchrec/modules/regroup.py @@ -34,19 +34,6 @@ def _permuted_values( return torch.cat(values, dim=dim) -@torch.fx.wrap -def _build_dict( - keys: List[str], - values: Union[torch.Tensor, List[torch.Tensor]], - splits: List[int], - dim: int, -) -> Dict[str, torch.Tensor]: - if isinstance(values, torch.Tensor): - return dict(zip(keys, torch.split(values, splits, dim=dim))) - else: - return dict(zip(keys, values)) - - @torch.fx.wrap def module_init(module: "KTRegroupAsDict", keyed_tensors: List[KeyedTensor]) -> None: assert len(keyed_tensors) > 0, "Empty list provided" @@ -115,6 +102,12 @@ def forward(self, values: List[torch.Tensor]) -> List[torch.Tensor]: ) +def _to_tensor_dict( + keys: List[str], values: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]] +) -> Dict[str, torch.Tensor]: + return {key: values[i] for i, key in enumerate(keys)} + + class KTRegroupAsDict(torch.nn.Module, CacheMixin): """ KTRegroupAsDict is a nn.Module that mirrors beahvior of static method KeyedTensor.regroup_as_dict() @@ -204,11 +197,13 @@ def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]: if self._use_fbgemm_regroup: values = _get_kts_values(keyed_tensors) permuted_values = self._permute_pooled_embs_impl(values) + return _to_tensor_dict(self._keys, permuted_values) else: permuted_values = _permuted_values( keyed_tensors, self._idx_key_pairs, self._dim ) - return _build_dict(self._keys, permuted_values, self._splits, self._dim) + splitted_values = torch.split(permuted_values, self._splits, dim=self._dim) + return _to_tensor_dict(self._keys, splitted_values) def clear_cache(self) -> None: self._is_inited = False