diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 1b128164a22b..5b74f05c35f7 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -653,6 +653,23 @@ def test_send_cpu_data_to_device_with_sharding(self):
         torch_xla._XLAC._get_xla_sharding_spec(xt),
         torch_xla._XLAC._get_xla_sharding_spec(explicit_xt))
 
+  def test_send_cpu_data_to_device_with_multiple_sharding(self):
+    tensors = [torch.randn(16), torch.randn(16, 16), torch.randn(16, 16, 16)]
+    mesh = self._get_mesh((self.n_devices, 1))
+    specs = [
+        xs.ShardingSpec(mesh, spec) for spec in [(0, None), (0, None, None)]
+    ]
+    xtensors = xm.send_cpu_data_to_device(tensors, xm.xla_device(), specs)
+    str_specs = [torch_xla._XLAC._get_xla_sharding_spec(t) for t in xtensors]
+    self.assertEqual(str_specs[0], '{replicated}')
+    if self.n_devices > 1:
+      dev_fmt = (self.n_devices, ','.join(map(str, range(self.n_devices))))
+      self.assertEqual(str_specs[1], "{devices=[%d,1]%s}" % dev_fmt)
+      self.assertEqual(str_specs[2], "{devices=[%d,1,1]%s}" % dev_fmt)
+    else:
+      self.assertEqual(str_specs[1], '{replicated}')
+      self.assertEqual(str_specs[2], '{replicated}')
+
   def test_multiple_operations(self):
     t1 = torch.randn(2, 2)
     t2 = torch.randn(2, 2)
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index e85db1d20a6a..fbeface58b69 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -1012,7 +1012,16 @@ def convert_fn(tensors):
     devices = [str(device)] * len(tensors)
     shardings = None
     if input_sharding:
-      shardings = [input_sharding.xla_spec(t) for t in tensors]
+      if isinstance(input_sharding, list):
+        shardings = [None] * len(tensors)
+        # Apply the first matching ShardingSpec to each tensor.
+        for i, tensor in enumerate(tensors):
+          for sharding in input_sharding:
+            if sharding.can_apply(tensor):
+              shardings[i] = sharding.xla_spec(tensor)
+              break
+      else:
+        shardings = [input_sharding.xla_spec(t) for t in tensors]
     xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices,
                                                       shardings)
     return xtensors
diff --git a/torch_xla/distributed/parallel_loader.py b/torch_xla/distributed/parallel_loader.py
index 8af7196e95ca..64b884a30057 100644
--- a/torch_xla/distributed/parallel_loader.py
+++ b/torch_xla/distributed/parallel_loader.py
@@ -74,8 +74,9 @@ class ParallelLoader(object):
     host_to_device_transfer_threads (int, optional): The number of threads that
       work in parallel to transfer data from loader queue to device queue.
      Default: 1
-    input_sharding (ShardingSpec, optional): Sharding spec to apply to
-      compatible input tensors after loading.
+    input_sharding (Union[ShardingSpec, List[ShardingSpec]], optional): Sharding
+      specs to apply to compatible input tensors when loading. When a list is
+      provided, the first matching sharding spec will be applied.
       Default: None
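
Usage note (not part of the patch): after this change, `input_sharding` may be a single `ShardingSpec` or a list of them. Each input tensor is matched against the list in order, the first spec whose `can_apply` check passes is used, and tensors that match no spec stay replicated. The sketch below mirrors the new test under assumed conditions; the device count, mesh shape, tensor sizes, and loader setup are illustrative only.

# Sketch: sharding inputs of different ranks with a list of ShardingSpecs.
# Assumptions: an SPMD-enabled runtime; mesh and tensor shapes are illustrative.
import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.experimental.xla_sharding as xs

# Device count / mesh construction is illustrative; adapt to your SPMD setup.
num_devices = len(xm.get_xla_supported_devices())
mesh = xs.Mesh(list(range(num_devices)), (num_devices, 1))

# One spec per expected input rank; the 1-D tensor below matches neither
# spec, so its sharding stays None and it is replicated on device.
specs = [
    xs.ShardingSpec(mesh, (0, None)),        # batch-shard 2-D inputs
    xs.ShardingSpec(mesh, (0, None, None)),  # batch-shard 3-D inputs
]

tensors = [torch.randn(16), torch.randn(16, 16), torch.randn(16, 16, 16)]
xtensors = xm.send_cpu_data_to_device(tensors, xm.xla_device(), specs)

# The same list can be handed to the device-loader wrapper, which forwards
# it to ParallelLoader's input_sharding argument.
loader = DataLoader(TensorDataset(torch.randn(64, 16, 16)), batch_size=8)
device_loader = pl.MpDeviceLoader(loader, xm.xla_device(), input_sharding=specs)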