
Commit ab617f8

Support list of ShardingSpec in MpDeviceLoader

1 parent bd29e79

3 files changed: +28 -3 lines changed

test/spmd/test_xla_sharding.py

Lines changed: 17 additions & 0 deletions

@@ -653,6 +653,23 @@ def test_send_cpu_data_to_device_with_sharding(self):
         torch_xla._XLAC._get_xla_sharding_spec(xt),
         torch_xla._XLAC._get_xla_sharding_spec(explicit_xt))
 
+  def test_send_cpu_data_to_device_with_multiple_sharding(self):
+    tensors = [torch.randn(16), torch.randn(16, 16), torch.randn(16, 16, 16)]
+    mesh = self._get_mesh((self.n_devices, 1))
+    specs = [
+        xs.ShardingSpec(mesh, spec) for spec in [(0, None), (0, None, None)]
+    ]
+    xtensors = xm.send_cpu_data_to_device(tensors, xm.xla_device(), specs)
+    str_specs = [torch_xla._XLAC._get_xla_sharding_spec(t) for t in xtensors]
+    self.assertEqual(str_specs[0], '{replicated}')
+    if self.n_devices > 1:
+      dev_fmt = (self.n_devices, ','.join(map(str, range(self.n_devices))))
+      self.assertEqual(str_specs[1], "{devices=[%d,1]%s}" % dev_fmt)
+      self.assertEqual(str_specs[2], "{devices=[%d,1,1]%s}" % dev_fmt)
+    else:
+      self.assertEqual(str_specs[1], '{replicated}')
+      self.assertEqual(str_specs[2], '{replicated}')
+
   def test_multiple_operations(self):
     t1 = torch.randn(2, 2)
     t2 = torch.randn(2, 2)
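The test relies on the harness helper self._get_mesh; outside the test suite, an equivalent setup might look like the sketch below. The import paths and device-count call are assumptions (older torch_xla releases expose the sharding API as torch_xla.experimental.xla_sharding), so treat this as a sketch rather than the commit's own code:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr  # assumption: runtime module with device count
import torch_xla.distributed.spmd as xs  # assumption: modern import path

# A (n_devices, 1) mesh over all devices, mirroring self._get_mesh((n, 1)).
n_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(n_devices)), (n_devices, 1))

# (0, None) shards dim 0 of 2-D tensors across the mesh's first axis;
# (0, None, None) does the same for 3-D tensors. No spec covers 1-D
# tensors, so those are loaded replicated.
specs = [xs.ShardingSpec(mesh, spec) for spec in [(0, None), (0, None, None)]]
tensors = [torch.randn(16), torch.randn(16, 16), torch.randn(16, 16, 16)]
xtensors = xm.send_cpu_data_to_device(tensors, xm.xla_device(), specs)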

torch_xla/core/xla_model.py

Lines changed: 9 additions & 1 deletion

@@ -1012,7 +1012,15 @@ def convert_fn(tensors):
     devices = [str(device)] * len(tensors)
     shardings = None
     if input_sharding:
-      shardings = [input_sharding.xla_spec(t) for t in tensors]
+      if isinstance(input_sharding, list):
+        shardings = [None] * len(tensors)
+        for i, tensor in enumerate(tensors):
+          for sharding in input_sharding:
+            if sharding.can_apply(tensor):
+              shardings[i] = sharding.xla_spec(tensor)
+              break
+      else:
+        shardings = [input_sharding.xla_spec(t) for t in tensors]
     xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices,
                                                       shardings)
     return xtensors
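With a list, each tensor takes the first spec whose can_apply accepts it; tensors matching no spec keep a None sharding and are loaded replicated. A pure-Python sketch of that first-match rule, using a rank check as a stand-in for can_apply (an assumption, but consistent with the test's expectations above):

import torch

def pick_spec(tensor, partition_specs):
  # Return the first partition spec whose length matches the tensor's rank,
  # else None (stand-in for ShardingSpec.can_apply).
  for spec in partition_specs:
    if len(spec) == len(tensor.shape):
      return spec
  return None

tensors = [torch.randn(16), torch.randn(16, 16), torch.randn(16, 16, 16)]
specs = [(0, None), (0, None, None)]
print([pick_spec(t, specs) for t in tensors])
# [None, (0, None), (0, None, None)] -- the 1-D tensor stays replicated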

torch_xla/distributed/parallel_loader.py

Lines changed: 2 additions & 2 deletions

@@ -74,8 +74,8 @@ class ParallelLoader(object):
     host_to_device_transfer_threads (int, optional): The number of threads that
       work in parallel to transfer data from loader queue to device queue.
       Default: 1
-    input_sharding (ShardingSpec, optional): Sharding spec to apply to
-      compatible input tensors after loading.
+    input_sharding (Union[ShardingSpec, List[ShardingSpec]], optional): Sharding
+      specs to apply to compatible input tensors when loading.
       Default: None
   """
8181
