@@ -202,34 +202,24 @@ def _maybe_create_strong_reference(self, value):
202
202
# For sharded arrays, hold references to the shards' data.
203
203
shard_data = [shard .data for shard in value .addressable_shards ]
204
204
if hasattr (self , "_trace_state" ):
205
- # NNX variable - check for TraceContextError
205
+ # NNX variable - be careful with mutations during tracing
206
206
try :
207
207
self ._shard_references = [shard_data ]
208
- except Exception as e :
209
- if config .is_nnx_enabled ():
210
- from flax .errors import TraceContextError
211
-
212
- if not isinstance (e , TraceContextError ):
213
- raise
214
- else :
215
- raise
208
+ except Exception :
209
+ # During tracing, mutations might not be allowed
210
+ pass
216
211
else :
217
212
# Regular JaxVariable - always create reference
218
213
self ._shard_references = [shard_data ]
219
214
else :
220
215
# For non-sharded arrays, hold a ref to the array itself.
221
216
if hasattr (self , "_trace_state" ):
222
- # NNX variable - check for TraceContextError
217
+ # NNX variable - be careful with mutations during tracing
223
218
try :
224
219
self ._strong_reference = value
225
- except Exception as e :
226
- if config .is_nnx_enabled ():
227
- from flax .errors import TraceContextError
228
-
229
- if not isinstance (e , TraceContextError ):
230
- raise
231
- else :
232
- raise
220
+ except Exception :
221
+ # During tracing, mutations might not be allowed
222
+ pass
233
223
else :
234
224
# Regular JaxVariable - always create reference
235
225
self ._strong_reference = value
0 commit comments