From 47253c4d02e6d4a757347c1ca37e5c6a4975bc9e Mon Sep 17 00:00:00 2001
From: Sidney Tsang
Date: Thu, 13 Mar 2025 12:45:56 -0700
Subject: [PATCH] Fix symbolic tracing error during IEN lowering

Summary: Symbolic tracing passes in a Proxy object, which is incompatible
with the existing logic for caching the KJT feature order introduced in
D68991644. Skip the caching in that case to avoid trace-time errors.

Differential Revision: D71139922
---
 torchrec/quant/embedding_modules.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py
index 8ca76ac37..5c098c600 100644
--- a/torchrec/quant/embedding_modules.py
+++ b/torchrec/quant/embedding_modules.py
@@ -501,7 +501,9 @@ def forward(
         embeddings = []
         kjt_keys = _get_kjt_keys(features)
         # Cache the features order since the features will always have the same order of keys in inference.
-        if getattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, False):
+        if isinstance(features, KeyedJaggedTensor) and getattr(
+            self, MODULE_ATTR_CACHE_FEATURES_ORDER, False
+        ):
             if self._features_order == []:
                 for k in self._feature_names:
                     self._features_order.append(kjt_keys.index(k))
@@ -880,7 +882,9 @@ def forward(
         feature_embeddings: Dict[str, JaggedTensor] = {}
         kjt_keys = _get_kjt_keys(features)
         # Cache the features order since the features will always have the same order of keys in inference.
-        if getattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, False):
+        if isinstance(features, KeyedJaggedTensor) and getattr(
+            self, MODULE_ATTR_CACHE_FEATURES_ORDER, False
+        ):
             if self._features_order == []:
                 for k in self._feature_names:
                     self._features_order.append(kjt_keys.index(k))
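
Note (not part of the patch): a minimal sketch of why the isinstance guard
works. Under torch.fx symbolic tracing, forward() receives a torch.fx.Proxy
rather than a real KeyedJaggedTensor, so without the guard the data-dependent
call kjt_keys.index(k) would be traced on a Proxy and Proxy values would leak
into module state. With the guard, the caching branch is a no-op at trace
time. CachingModule below is a hypothetical stand-in for the torchrec module,
not the real code; only torch.fx.symbolic_trace and torch.nn.Module are
actual APIs used here.

# Hypothetical stand-in for the patched caching logic; not torchrec code.
import torch
import torch.fx


class CachingModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._feature_names = ["f1", "f2"]
        self._features_order: list = []

    def forward(self, features):
        # Under symbolic tracing, `features` is a torch.fx.Proxy, so this
        # isinstance check is False and caching is skipped, mirroring the
        # isinstance(features, KeyedJaggedTensor) guard in the patch.
        if isinstance(features, torch.Tensor):
            keys = list(self._feature_names)  # stand-in for _get_kjt_keys(features)
            if self._features_order == []:
                for k in self._feature_names:
                    self._features_order.append(keys.index(k))
        return features


# Eager call: a real tensor is passed, so the feature order is cached.
m = CachingModule()
m(torch.zeros(1))
assert m._features_order == [0, 1]

# Symbolic trace: `features` is a Proxy; the guarded branch is skipped and
# tracing succeeds instead of recording .index() calls against a Proxy.
traced = torch.fx.symbolic_trace(CachingModule())
print(traced.code)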