Skip to content

Commit 4746a39

Browse files
yashk2810jax authors
authored and
jax authors
committed
Show the correct sharding in is_compatible_aval error in MeshPspecSharding when created via _from_parsed_pspec. Preserve the original PartitionSpec from ParsedPartitionSpec if it exists, else calculate it.
PiperOrigin-RevId: 473267905
1 parent 40c80d7 commit 4746a39

File tree

4 files changed

+60
-5
lines changed

4 files changed

+60
-5
lines changed

jax/experimental/pjit.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,15 @@ def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC):
637637
def user_spec(self):
638638
return self.unsynced_user_spec(SpecSync.IN_SYNC)
639639

640+
def get_partition_spec(self) -> PartitionSpec:
641+
if self.sync < SpecSync.IN_SYNC:
642+
return _get_single_pspec(self)
643+
else:
644+
if isinstance(self.unsafe_user_spec, PartitionSpec):
645+
return self.unsafe_user_spec
646+
else:
647+
return _get_single_pspec(self)
648+
640649
def unsynced_user_spec(self, min_sync):
641650
if self.sync < min_sync:
642651
raise AssertionError(f"Please open a bug report! ({self.sync} >= {min_sync})")
@@ -1283,7 +1292,7 @@ def with_sharding_constraint(x, axis_resources):
12831292
for s in sharding_flat
12841293
]
12851294
else:
1286-
sharding_flat = [MeshPspecSharding._from_parsed_pspec(mesh, a)
1295+
sharding_flat = [pxla._create_mesh_pspec_sharding(mesh, a.user_spec, a)
12871296
for a in axis_resources_flat]
12881297
# Calculate unconstrained_dims from MeshPspecSharding because that information
12891298
# is lost when converted to OpSharding. Bind unconstrained_dims to

jax/experimental/sharding.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,7 @@ def is_compatible_aval(self, aval_shape: Shape):
197197

198198
@classmethod
199199
def _from_parsed_pspec(cls, mesh, parsed_pspec):
200-
from jax.experimental import pjit
201-
return cls(mesh, pjit._get_single_pspec(parsed_pspec), parsed_pspec)
200+
return cls(mesh, parsed_pspec.get_partition_spec(), parsed_pspec)
202201

203202
@pxla.maybe_cached_property
204203
def device_set(self) -> Set[Device]:

tests/array_test.py

+14
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,20 @@ def test_pmap_sharding_hash_eq(self):
505505
self.assertGreater(cache_info2.hits, cache_info1.hits + 1)
506506
self.assertEqual(cache_info2.misses, cache_info1.misses)
507507

508+
def test_is_compatible_error(self):
509+
shape = (8, 2)
510+
mesh = jtu.create_global_mesh((1, 1, 2), ('replica', 'data', 'mdl'))
511+
mps = sharding.MeshPspecSharding(mesh, P(None, ('mdl',), None, None))
512+
new_mps = sharding.MeshPspecSharding._from_parsed_pspec(
513+
mps.mesh, mps._parsed_pspec)
514+
515+
with self.assertRaisesRegex(
516+
ValueError,
517+
r"Sharding MeshPspecSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
518+
r"partition_spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
519+
"valid for values of rank at least 4, but was applied to a value of rank 2"):
520+
new_mps.is_compatible_aval(shape)
521+
508522

509523
if __name__ == '__main__':
510524
absltest.main(testLoader=jtu.JaxTestLoader())

tests/pjit_test.py

+35-2
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,24 @@ def make_keys(seeds):
10521052
self.assertEqual(out.shape, input_shape)
10531053
out.unsafe_raw_array() # doesn't crash
10541054

1055+
def test_with_sharding_constraint_is_compatible_error(self):
1056+
mesh = jtu.create_global_mesh((1, 1, 2), ('replica', 'data', 'mdl'))
1057+
1058+
with mesh:
1059+
def f(x):
1060+
y = with_sharding_constraint(x, P(None, ('mdl',), None, None))
1061+
z = y + 2
1062+
return z
1063+
pjit_f = pjit(f, in_axis_resources=P(None), out_axis_resources=P(None))
1064+
1065+
with self.assertRaisesRegex(
1066+
ValueError,
1067+
r"One of with_sharding_constraint.*Sharding "
1068+
r"MeshPspecSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
1069+
r"partition_spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
1070+
"valid for values of rank at least 4, but was applied to a value of rank 1"):
1071+
pjit_f(jnp.array([1, 2, 3]))
1072+
10551073

10561074
class GDAPjitTest(jtu.JaxTestCase):
10571075

@@ -2180,8 +2198,7 @@ def testRankTooLowConstraint(self):
21802198
x = jnp.arange(2)
21812199
spec = P('x', 'y')
21822200
error = re.compile(
2183-
r"One of with_sharding_constraint arguments" + r".*" + spec_regex(
2184-
pxla.array_mapping_to_axis_resources(pxla._get_array_mapping(spec))) +
2201+
r"One of with_sharding_constraint arguments" + r".*" + spec_regex(spec) +
21852202
r".*rank at least 2, but was applied to a value of rank 1", re.M | re.S)
21862203
with self.assertRaisesRegex(ValueError, error):
21872204
pjit(lambda x: with_sharding_constraint(x, spec),
@@ -2625,6 +2642,22 @@ def test_simulated_training_cache_in_pjit(self):
26252642
self.assertEqual(id(next_op_sharding_sharding._op_sharding),
26262643
id(op_sharding_sharding._op_sharding))
26272644

2645+
def test_get_partition_spec(self):
2646+
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
2647+
s = MeshPspecSharding(mesh, P('x', 'y', None))
2648+
2649+
self.assertEqual(s._parsed_pspec.get_partition_spec(), P('x', 'y', None))
2650+
2651+
recovered_parsed_pspec = pjit_lib.parse_flatten_op_sharding(
2652+
s._to_xla_op_sharding(3), mesh)
2653+
self.assertEqual(recovered_parsed_pspec[0].get_partition_spec(),
2654+
P(('x',), ('y',)))
2655+
2656+
out_of_sync_parsed_pspec = pjit_lib.ParsedPartitionSpec(
2657+
P('x', 'y'), ('x', 'y'), pjit_lib.SpecSync.OUT_OF_SYNC)
2658+
self.assertEqual(out_of_sync_parsed_pspec.get_partition_spec(),
2659+
P(('x',), ('y',)))
2660+
26282661

26292662
if __name__ == '__main__':
26302663
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)