Skip to content

Commit 7816f0c

Browse files
Fix array deletion by ensuring strong references are created for JaxVariable during normal execution while handling NNX tracing contexts
1 parent b23cee7 commit 7816f0c

File tree

1 file changed

+8
-18
lines changed

1 file changed

+8
-18
lines changed

keras/src/backend/jax/core.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -202,34 +202,24 @@ def _maybe_create_strong_reference(self, value):
202202
# For sharded arrays, hold references to the shards' data.
203203
shard_data = [shard.data for shard in value.addressable_shards]
204204
if hasattr(self, "_trace_state"):
205-
# NNX variable - check for TraceContextError
205+
# NNX variable - be careful with mutations during tracing
206206
try:
207207
self._shard_references = [shard_data]
208-
except Exception as e:
209-
if config.is_nnx_enabled():
210-
from flax.errors import TraceContextError
211-
212-
if not isinstance(e, TraceContextError):
213-
raise
214-
else:
215-
raise
208+
except Exception:
209+
# During tracing, mutations might not be allowed
210+
pass
216211
else:
217212
# Regular JaxVariable - always create reference
218213
self._shard_references = [shard_data]
219214
else:
220215
# For non-sharded arrays, hold a ref to the array itself.
221216
if hasattr(self, "_trace_state"):
222-
# NNX variable - check for TraceContextError
217+
# NNX variable - be careful with mutations during tracing
223218
try:
224219
self._strong_reference = value
225-
except Exception as e:
226-
if config.is_nnx_enabled():
227-
from flax.errors import TraceContextError
228-
229-
if not isinstance(e, TraceContextError):
230-
raise
231-
else:
232-
raise
220+
except Exception:
221+
# During tracing, mutations might not be allowed
222+
pass
233223
else:
234224
# Regular JaxVariable - always create reference
235225
self._strong_reference = value

0 commit comments

Comments
 (0)