Skip to content

Commit 3a86f1e

Browse files
Fix ModelParallel sharded variable loading test
- Change Embedding input_dim from 100 to 96 so it is evenly divisible across 8 devices - Add a skip condition when fewer than 2 devices are available (sharding is not meaningful on a single device) - Prevents a ValueError when a sharded dimension is not divisible by the device count - The test now skips cleanly on single-device systems and passes on multi-device CI
1 parent 496baa4 commit 3a86f1e

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

keras/src/distribution/distribution_lib_test.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,11 @@ def test_model_parallel_sharded_variable_loading(self):
381381

382382
# Ensure we have JAX devices
383383
jax_devices = jax.devices()
384+
if len(jax_devices) < 2:
385+
pytest.skip(
386+
"Test requires at least 2 devices for meaningful sharding"
387+
)
388+
384389
logging.debug(f"JAX devices available: {len(jax_devices)}")
385390
for i, device in enumerate(jax_devices):
386391
logging.debug(f" Device {i}: {device}")
@@ -432,7 +437,7 @@ def test_model_parallel_sharded_variable_loading(self):
432437
),
433438
# Embedding layer (modified in commit)
434439
layers.Embedding(
435-
input_dim=100, output_dim=32, name="embedding"
440+
input_dim=96, output_dim=32, name="embedding"
436441
),
437442
layers.Flatten(),
438443
# Convolutional layer (modified in commit)
@@ -517,7 +522,7 @@ def test_model_parallel_sharded_variable_loading(self):
517522
"ab,bc->ac", output_shape=32, name="einsum_dense"
518523
),
519524
layers.Embedding(
520-
input_dim=100, output_dim=32, name="embedding"
525+
input_dim=96, output_dim=32, name="embedding"
521526
),
522527
layers.Flatten(),
523528
layers.Reshape(

0 commit comments

Comments
 (0)