From 42cb2e68ddc428d90ad77ac67f09f5d7e47754f2 Mon Sep 17 00:00:00 2001 From: Fabien Hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Fri, 26 Sep 2025 12:40:51 -0700 Subject: [PATCH] Unwrap variables before passing to `jax.jit`. Passing `__jax_array__`-implementing objects directly to `jax.jit` will no longer be supported. --- keras_rs/src/layers/embedding/distributed_embedding_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_rs/src/layers/embedding/distributed_embedding_test.py b/keras_rs/src/layers/embedding/distributed_embedding_test.py index 94967c1..c2149e2 100644 --- a/keras_rs/src/layers/embedding/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/distributed_embedding_test.py @@ -598,8 +598,8 @@ def test_correctness( non_trainable_layouts, ), )( - layer.trainable_variables, - layer.non_trainable_variables, + [v.value for v in layer.trainable_variables], + [v.value for v in layer.non_trainable_variables], preprocessed, ) else: