@@ -1052,6 +1052,24 @@ def make_keys(seeds):
1052
1052
self .assertEqual (out .shape , input_shape )
1053
1053
out .unsafe_raw_array () # doesn't crash
1054
1054
1055
+ def test_with_sharding_constraint_is_compatible_error (self ):
1056
+ mesh = jtu .create_global_mesh ((1 , 1 , 2 ), ('replica' , 'data' , 'mdl' ))
1057
+
1058
+ with mesh :
1059
+ def f (x ):
1060
+ y = with_sharding_constraint (x , P (None , ('mdl' ,), None , None ))
1061
+ z = y + 2
1062
+ return z
1063
+ pjit_f = pjit (f , in_axis_resources = P (None ), out_axis_resources = P (None ))
1064
+
1065
+ with self .assertRaisesRegex (
1066
+ ValueError ,
1067
+ r"One of with_sharding_constraint.*Sharding "
1068
+ r"MeshPspecSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
1069
+ r"partition_spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
1070
+ "valid for values of rank at least 4, but was applied to a value of rank 1" ):
1071
+ pjit_f (jnp .array ([1 , 2 , 3 ]))
1072
+
1055
1073
1056
1074
class GDAPjitTest (jtu .JaxTestCase ):
1057
1075
@@ -2180,8 +2198,7 @@ def testRankTooLowConstraint(self):
2180
2198
x = jnp .arange (2 )
2181
2199
spec = P ('x' , 'y' )
2182
2200
error = re .compile (
2183
- r"One of with_sharding_constraint arguments" + r".*" + spec_regex (
2184
- pxla .array_mapping_to_axis_resources (pxla ._get_array_mapping (spec ))) +
2201
+ r"One of with_sharding_constraint arguments" + r".*" + spec_regex (spec ) +
2185
2202
r".*rank at least 2, but was applied to a value of rank 1" , re .M | re .S )
2186
2203
with self .assertRaisesRegex (ValueError , error ):
2187
2204
pjit (lambda x : with_sharding_constraint (x , spec ),
@@ -2625,6 +2642,22 @@ def test_simulated_training_cache_in_pjit(self):
2625
2642
self .assertEqual (id (next_op_sharding_sharding ._op_sharding ),
2626
2643
id (op_sharding_sharding ._op_sharding ))
2627
2644
2645
+ def test_get_partition_spec (self ):
2646
+ mesh = jtu .create_global_mesh ((4 , 2 ), ('x' , 'y' ))
2647
+ s = MeshPspecSharding (mesh , P ('x' , 'y' , None ))
2648
+
2649
+ self .assertEqual (s ._parsed_pspec .get_partition_spec (), P ('x' , 'y' , None ))
2650
+
2651
+ recovered_parsed_pspec = pjit_lib .parse_flatten_op_sharding (
2652
+ s ._to_xla_op_sharding (3 ), mesh )
2653
+ self .assertEqual (recovered_parsed_pspec [0 ].get_partition_spec (),
2654
+ P (('x' ,), ('y' ,)))
2655
+
2656
+ out_of_sync_parsed_pspec = pjit_lib .ParsedPartitionSpec (
2657
+ P ('x' , 'y' ), ('x' , 'y' ), pjit_lib .SpecSync .OUT_OF_SYNC )
2658
+ self .assertEqual (out_of_sync_parsed_pspec .get_partition_spec (),
2659
+ P (('x' ,), ('y' ,)))
2660
+
2628
2661
2629
2662
if __name__ == '__main__' :
2630
2663
absltest .main (testLoader = jtu .JaxTestLoader ())
0 commit comments