diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index 8ca76ac37..5c098c600 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -501,7 +501,9 @@ def forward( embeddings = [] kjt_keys = _get_kjt_keys(features) # Cache the features order since the features will always have the same order of keys in inference. - if getattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, False): + if isinstance(features, KeyedJaggedTensor) and getattr( + self, MODULE_ATTR_CACHE_FEATURES_ORDER, False + ): if self._features_order == []: for k in self._feature_names: self._features_order.append(kjt_keys.index(k)) @@ -880,7 +882,9 @@ def forward( feature_embeddings: Dict[str, JaggedTensor] = {} kjt_keys = _get_kjt_keys(features) # Cache the features order since the features will always have the same order of keys in inference. - if getattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, False): + if isinstance(features, KeyedJaggedTensor) and getattr( + self, MODULE_ATTR_CACHE_FEATURES_ORDER, False + ): if self._features_order == []: for k in self._feature_names: self._features_order.append(kjt_keys.index(k))