
Commit cc7f1d0

aporialiao authored and facebook-github-bot committed
Add padding in dynamic sharding for tensors before all2all (#2944)
Summary:
Pull Request resolved: #2944

Given we can't expect shards in an embedding module to have the same dimensions for both dim 0 and dim 1, we have to pad the tensors passed into the `all_to_all_single` collective to ensure we only call the expensive collective once.

This diff:
1. Adds the logic for padding tensors in both dimensions.
2. Adds logic to remove the padding when updating the state dict after resharding.
3. Removes the original implementation of concatenating input tensors along dim 1 (which assumed dim 0 could vary but dim 1 was consistent across all shards) and transposing. This ensures the existing CW unit test exercises the padding logic, as the CW unit test was the one that previously failed due to inconsistent dimensions.

Padding leverages `nn.functional.pad` and pads tensors with value 0 on the right and bottom, e.g.:

```
t = [[1, 2]
     [3, 4]]

max_dim_0 = 4
max_dim_1 = 3

t = pad_tensor_to_max_dims(t, max_dim_0, max_dim_1)
print(t)
>>> [[1, 2, 0]
     [3, 4, 0]
     [0, 0, 0]
     [0, 0, 0]]
```

Max dimensions for dim 0 and 1 are determined by going through all shard sizes that are being redistributed. This is needed to ensure the `output_tensor` passed into the all-to-all is large enough.

Reviewed By: iamzainhuda

Differential Revision: D74150894

fbshipit-source-id: 0d6f4ee4814d53de3785ea50396e5a6467bf308d
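For readers who want to see the padding rule in isolation, here is a minimal, self-contained sketch that reproduces the example above with `torch.nn.functional.pad`. The `pad_to` helper below is an illustrative stand-in for the `pad_tensor_to_max_dims` function added in this diff, not the committed implementation.

```python
import torch
import torch.nn.functional as F


def pad_to(t: torch.Tensor, max_dim_0: int, max_dim_1: int) -> torch.Tensor:
    """Illustrative stand-in for pad_tensor_to_max_dims: zero-pad right and bottom."""
    # F.pad takes pad amounts for the last dim first: (left, right, top, bottom).
    return F.pad(t, (0, max_dim_1 - t.size(1), 0, max_dim_0 - t.size(0)), value=0)


t = torch.tensor([[1, 2], [3, 4]])
print(pad_to(t, max_dim_0=4, max_dim_1=3))
# tensor([[1, 2, 0],
#         [3, 4, 0],
#         [0, 0, 0],
#         [0, 0, 0]])
```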
1 parent bdb3606 commit cc7f1d0

File tree

2 files changed (+176, -37 lines)


torchrec/distributed/embeddingbag.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -58,6 +58,7 @@
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
 from torchrec.distributed.sharding.dynamic_sharding import (
+    get_largest_dims_from_sharding_plan_updates,
     shards_all_to_all,
     update_module_sharding_plan,
     update_state_dict_post_resharding,
@@ -1545,13 +1546,20 @@ def update_shards(
         # Deleting all lookups
         self._lookups.clear()
 
+        # Get max dim size to enable padding for all_to_all
+        max_dim_0, max_dim_1 = get_largest_dims_from_sharding_plan_updates(
+            changed_sharding_params
+        )
+
         local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all(
             module=self,
             state_dict=current_state,
             device=device,  # pyre-ignore
             changed_sharding_params=changed_sharding_params,
             env=env,
             extend_shard_name=self.extend_shard_name,
+            max_dim_0=max_dim_0,
+            max_dim_1=max_dim_1,
         )
 
         current_state = update_state_dict_post_resharding(
@@ -1561,6 +1569,7 @@ def update_shards(
             new_sharding_params=changed_sharding_params,
             curr_rank=dist.get_rank(),
             extend_shard_name=self.extend_shard_name,
+            max_dim_0=max_dim_0,
         )
 
         for name, param in changed_sharding_params.items():
```
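For context on the call site above: `get_largest_dims_from_sharding_plan_updates` (added in the second file below) reduces every shard size in the updated plan to a single pair of maxima, which is then threaded through both the all-to-all and the post-resharding state-dict update. A minimal sketch of that reduction over hypothetical shard sizes:

```python
from typing import List, Tuple

# Hypothetical (dim 0, dim 1) shard sizes collected from a sharding-plan update.
shard_sizes: List[Tuple[int, int]] = [(512, 64), (512, 32), (256, 128)]

max_dim_0 = max(rows for rows, _ in shard_sizes)  # 512
max_dim_1 = max(cols for _, cols in shard_sizes)  # 128

# Every redistributed shard is padded to (max_dim_0, max_dim_1) so that a single
# all_to_all_single call can move them with a uniform per-shard row count.
assert (max_dim_0, max_dim_1) == (512, 128)
```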

torchrec/distributed/sharding/dynamic_sharding.py

Lines changed: 167 additions & 37 deletions
```diff
@@ -11,6 +11,7 @@
 
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 from torch.distributed._shard.sharded_tensor import Shard
 from torchrec.distributed.types import (
     ParameterSharding,
@@ -19,18 +20,73 @@
     ShardingEnv,
 )
 
+OrderedShardNamesWithSizes = List[Tuple[str, List[int]]]
+"""
+A type alias to represent an ordered shard name and the corresponding shard_size
+in dim 0 & 1 that were sent to the current rank.
+This is a flattened and pruned nested list, which orders the shards names and
+sizes in the following priority:
+1. Rank order
+2. Table order
+3. Shard order
+
+<table_x, shard_y> in below examples represent the 2d tensor correlated to a
+certain table `x`, allocated to rank `z`. The `y` here denotes the order of shards
+in the module attributes such as state_dict, sharding_plan, etc..
+
+`z` != `y` numerically, but the order of shards is based on the order of ranks allocated
+
+Example 1 NOTE: the ordering by rank:
+Rank 0 sends table_0, shard_0 to Rank 1.
+Rank 2 sends table_1, shard_0 to Rank 1.
+Rank 2 sends table_1, shard_1 to Rank 0
+Rank 3 sends table_0, shard_1 to Rank 0
+
+NOTE: table_1 comes first due to its source rank being 'first'
+On Rank 0: output_tensor = [
+    <table_1, shard_0>, # from rank 2
+    <table_0, shard_1> # from rank 3
+]
+
+On Rank 1: output_tensor = [
+    <table_0, shard_0>, # from rank 0
+    <table_1, shard_0> # from rank 2
+]
+
+Example 2: NOTE: ordered by table when ranks are the same
+Rank 0 sends table_1 to Rank 1
+Rank 0 sends table_2 to Rank 1
+
+output_tensor = [
+    <table_0, shard_y>,
+    <table_1, shard_y>
+]
+
+Example 3: NOTE: ordered by shard if table and rank are the same
+Rank 0 sends table_1, shard_0 to Rank 1
+Rank 0 sends table_1, shard_1 to Rank 1
+
+Rank 1: output_tensor = [
+    <table_0, shard_0>,
+    <table_1, shard_1>
+]
+"""
+
 
 def shards_all_to_all(
     module: ShardedModule[Any, Any, Any, Any],  # pyre-ignore
     state_dict: Dict[str, ShardedTensor],
     device: torch.device,
     changed_sharding_params: Dict[str, ParameterSharding],
     env: ShardingEnv,
+    max_dim_0: int,
+    max_dim_1: int,
     extend_shard_name: Callable[[str], str] = lambda x: x,
-) -> Tuple[List[Tuple[str, int]], torch.Tensor]:
+) -> Tuple[OrderedShardNamesWithSizes, torch.Tensor]:
     """
     Performs an all-to-all communication to redistribute shards across ranks based on new sharding parameters.
-    Assumes ranks are ordered in ParameterSharding.ranks.
+    Assumes ranks are ordered in ParameterSharding.ranks. Implements padding for concatenating, sending and
+    receiving tensors of different sizes in dim 0 or 1.
 
     Args:
         module (ShardedModule[Any, Any, Any, Any]): The module containing sharded tensors to be redistributed.
```
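To make the ordering contract above concrete, here is a hedged sketch of how the flattened `OrderedShardNamesWithSizes` value could be built from the per-source-rank lists that `shards_all_to_all` accumulates; the nested input below is illustrative, not taken from the module's actual state:

```python
from typing import List, Tuple

OrderedShardNamesWithSizes = List[Tuple[str, List[int]]]

# Per-source-rank accumulation on the receiving rank (index = source rank);
# within a rank, tables and then shards are appended in plan order.
shard_names_to_lengths_by_src_rank: List[List[Tuple[str, List[int]]]] = [
    [("table_0.shard_0", [256, 64])],                                   # from rank 0
    [],                                                                 # from rank 1
    [("table_1.shard_0", [128, 64]), ("table_1.shard_1", [128, 64])],   # from rank 2
]

# Flattening preserves the documented priority: source rank, then table, then shard.
ordered: OrderedShardNamesWithSizes = [
    entry for per_rank in shard_names_to_lengths_by_src_rank for entry in per_rank
]
print(ordered)
```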
```diff
@@ -46,10 +102,14 @@ def shards_all_to_all(
 
         extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict.
 
+        max_dim_0 (int): The maximum dimension size of dim 0 across all tables in the module.
+
+        max_dim_1 (int): The maximum dimension size of dim 1 across all tables in the module.
+
     Returns:
-        Tuple[List[Tuple[str, int]], torch.Tensor]: A tuple containing:
-            - A list of shard name and the corresponding shard_size in dim 1 that were sent to the current rank.
-                This is a flattened and pruned nested list, which orders the shards names and sizes by rank, then shard order.
+        Tuple[List[Tuple[str, List[int]]], torch.Tensor]: Two outputs containing:
+            - A list of shard name and the corresponding shard_size in dim 0 & 1 that were sent to the current rank.
+                This is a flattened and pruned nested list, which orders the shards names and sizes by source rank, then shard order.
             - The tensor containing all shards received by the current rank after the all-to-all operation.
     """
     if env.output_dtensor:
@@ -64,8 +124,6 @@ def shards_all_to_all(
     input_splits_per_rank = [[0] * world_size for _ in range(world_size)]
     output_splits_per_rank = [[0] * world_size for _ in range(world_size)]
 
-    # 0 by default, as current rank may be recieving 0 shards
-    num_embeddings_received = 0
     output_tensor_tensor_count = 0
     shard_names_to_lengths_by_src_rank = [[] for _ in range(world_size)]
     local_table_to_input_tensor_by_dst_rank = [[] for _ in range(world_size)]
@@ -86,29 +144,20 @@ def shards_all_to_all(
         src_rank = src_ranks[i]
 
         shard_size = sharded_t.metadata().shards_metadata[i].shard_sizes
-        shard_size_dim_1 = shard_size[1]
-        input_splits_per_rank[src_rank][dst_rank] += shard_size_dim_1
-        output_splits_per_rank[dst_rank][src_rank] += shard_size_dim_1
+        input_splits_per_rank[src_rank][dst_rank] += max_dim_0
+        output_splits_per_rank[dst_rank][src_rank] += max_dim_0
         if src_rank == rank:
             local_shards = sharded_t.local_shards()
             assert len(local_shards) == 1
-            local_table_to_input_tensor_by_dst_rank[dst_rank].append(
-                sharded_t.local_shards()[0].tensor
+            cur_t = pad_tensor_to_max_dims(
+                sharded_t.local_shards()[0].tensor, max_dim_0, max_dim_1
             )
+            local_table_to_input_tensor_by_dst_rank[dst_rank].append(cur_t)
         if dst_rank == rank:
             shard_names_to_lengths_by_src_rank[src_rank].append(
-                (shard_name, shard_size_dim_1)
+                (shard_name, shard_size)
             )
-            # NOTE: Only need to update num_embeddings_received to be the
-            # num_embeddings of shards if this rank is actually recieving
-            # any tensors
-            if num_embeddings_received == 0:
-                num_embeddings_received = shard_size[0]
-            else:
-                # TODO: for 2D and row-wise, shard_sizes in dim 0 may be variable
-                # For now, assume that shard_sizes in dim 0 are all the same
-                assert num_embeddings_received == shard_size[0]
-            output_tensor_tensor_count += shard_size[1]
+            output_tensor_tensor_count += max_dim_0
 
     local_input_splits = input_splits_per_rank[rank]
     local_output_splits = output_splits_per_rank[rank]
@@ -121,16 +170,13 @@ def shards_all_to_all(
                 local_input_tensor,
                 shard_info,
             ),
-            dim=1,
+            dim=0,
         )
 
-    # Transposing the Tensors - because we are concatenating them along dimension 1
-    # This is because dim 0 size may be different for different shards
-    # whereas dim 1 size is the same for all shards as dim 1 size = num_embeddings per table
+    max_embedding_size = max_dim_1
     local_output_tensor = torch.empty(
-        [output_tensor_tensor_count, num_embeddings_received], device=device
+        [output_tensor_tensor_count, max_embedding_size], device=device
    )
-    local_input_tensor = local_input_tensor.T.contiguous()
 
     assert sum(local_output_splits) == len(local_output_tensor)
     assert sum(local_input_splits) == len(local_input_tensor)
```
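A simplified, single-process sketch of the bookkeeping these hunks enable: once every shard is padded to `max_dim_0` rows and `max_dim_1` columns, the concatenated input and the preallocated output are both described by row splits that are multiples of `max_dim_0`, which is what allows a single `all_to_all_single` call to move everything. The shard shapes and split lists below are invented for illustration, and the collective is only noted in a comment so the sketch runs standalone:

```python
import torch
import torch.nn.functional as F

max_dim_0, max_dim_1 = 4, 3
device = torch.device("cpu")

# Two local shards of different sizes destined for other ranks.
shards = [torch.ones(2, 2), torch.ones(3, 3)]
padded = [
    F.pad(s, (0, max_dim_1 - s.size(1), 0, max_dim_0 - s.size(0))) for s in shards
]

# Concatenate along dim 0; every shard now contributes exactly max_dim_0 rows,
# so the send splits are simple multiples of max_dim_0.
local_input_tensor = torch.cat(padded, dim=0)
local_input_splits = [max_dim_0, max_dim_0]  # rows sent to each destination rank

# Receiving side: suppose one padded shard is expected, so max_dim_0 rows
# and max_dim_1 columns are preallocated.
output_tensor_tensor_count = max_dim_0
local_output_tensor = torch.empty(
    [output_tensor_tensor_count, max_dim_1], device=device
)

assert sum(local_input_splits) == len(local_input_tensor)
assert output_tensor_tensor_count == len(local_output_tensor)
# In the real path, a single dist.all_to_all_single(local_output_tensor,
# local_input_tensor, output_split_sizes=..., input_split_sizes=...) runs here.
```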
```diff
@@ -153,22 +199,23 @@
 
 def update_state_dict_post_resharding(
     state_dict: Dict[str, ShardedTensor],
-    ordered_shard_names_and_lengths: List[Tuple[str, int]],
+    ordered_shard_names_and_lengths: OrderedShardNamesWithSizes,
     output_tensor: torch.Tensor,
     new_sharding_params: Dict[str, ParameterSharding],
     curr_rank: int,
+    max_dim_0: int,
     extend_shard_name: Callable[[str], str] = lambda x: x,
 ) -> Dict[str, ShardedTensor]:
     """
     Updates and returns the given state_dict with new placements and
     local_shards based on the output tensor of the AllToAll collective.
+    Removes padding from the output tensor in dim 0 and 1 if necessary.
 
     Args:
         state_dict (Dict[str, Any]): The state dict to be updated with new shard placements and local shards.
 
-        shard_names_by_src_rank (List[Tuple[str, int]]): A list of shard name and the corresponding shard_size in dim 1
-            that were sent to the current rank. This is a flattened and pruned nested list, which orders the shards names and
-            sizes by rank, then shard order.
+        ordered_shard_names_and_lengths (List[Tuple[str, List[int]]]): A list of shard name and the corresponding shard_size.
+            This is a flattened and pruned nested list, which orders the shards names and sizes by rank, then shard order.
 
         output_tensor (torch.Tensor): The tensor containing the output data from the AllToAll operation.
@@ -177,6 +224,10 @@ def update_state_dict_post_resharding(
 
         curr_rank (int): The current rank of the process in the distributed environment.
 
+        max_dim_0 (int): The maximum dimension size of dim 0 across all tables in the module. Only dim 0
+            is needed here to slice the output tensor correctly, as removing the padding will only reference
+            the original shard sizes stored in ordered_shard_names_and_lengths.
+
         extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict.
 
     Returns:
@@ -187,10 +238,12 @@ def update_state_dict_post_resharding(
     shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {}
 
     for shard_name, shard_size in ordered_shard_names_and_lengths:
-        end_slice_index = slice_index + shard_size
-        shard_name_to_local_output_tensor[shard_name] = output_tensor[
-            slice_index:end_slice_index
-        ].T
+        end_slice_index = slice_index + max_dim_0
+        cur_t = output_tensor[slice_index:end_slice_index]
+        cur_t = pad_tensor_to_max_dims(
+            cur_t, shard_size[0], shard_size[1], remove_padding=True
+        )
+        shard_name_to_local_output_tensor[shard_name] = cur_t
         slice_index = end_slice_index
 
     for shard_name, param in new_sharding_params.items():
```
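A minimal sketch of the receive-side slicing above: the output tensor is walked in fixed chunks of `max_dim_0` rows, and each chunk is trimmed back to the shard size recorded during the all-to-all (trimmed here with a plain slice; the diff reuses `pad_tensor_to_max_dims` with negative padding to the same effect). The names mirror the loop above, but the sizes are invented:

```python
import torch

max_dim_0, max_dim_1 = 4, 3

# Padded data received for two shards whose true sizes were recorded earlier.
output_tensor = torch.zeros(2 * max_dim_0, max_dim_1)
ordered_shard_names_and_lengths = [("t0.shard_0", [2, 2]), ("t1.shard_0", [3, 3])]

shard_name_to_local_output_tensor = {}
slice_index = 0
for shard_name, shard_size in ordered_shard_names_and_lengths:
    end_slice_index = slice_index + max_dim_0
    cur_t = output_tensor[slice_index:end_slice_index]
    # Drop the zero padding: keep only the shard's original dim 0 and dim 1 extents.
    shard_name_to_local_output_tensor[shard_name] = cur_t[: shard_size[0], : shard_size[1]]
    slice_index = end_slice_index

for name, t in shard_name_to_local_output_tensor.items():
    print(name, tuple(t.shape))  # ('t0.shard_0', (2, 2)) then ('t1.shard_0', (3, 3))
```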
```diff
@@ -234,3 +287,80 @@ def update_module_sharding_plan(
     for table_name, param_sharding in changed_sharding_params.items():
         current_plan[table_name] = param_sharding
     return
+
+
+def get_largest_dims_from_state_dict(
+    state_dict: Dict[str, ShardedTensor],
+) -> Tuple[int, int]:
+    """
+    Returns the largest dimension size of dim 0 and 1 across all tables in a module.
+
+    Args:
+        state_dict (Dict[str, ShardedTensor]): The state dict containing the sharded tensors.
+
+    Returns:
+        Tuple[int, int]: The largest dim 0 and dim 1 sizes across all shards in the state_dict.
+    """
+    max_dim_0 = 0
+    max_dim_1 = 0
+    for sharded_t in state_dict.values():
+        for shard in sharded_t.metadata().shards_metadata:
+            max_dim_0 = max(max_dim_0, shard.shard_sizes[0])
+            max_dim_1 = max(max_dim_1, shard.shard_sizes[1])
+
+    return max_dim_0, max_dim_1
+
+
+def get_largest_dims_from_sharding_plan_updates(
+    sharding_plan_updates: Dict[str, ParameterSharding],
+) -> Tuple[int, int]:
+    """
+    Returns the largest dimension size of dim 0 and 1 across all tables in a module.
+
+    Args:
+        sharding_plan_updates (Dict[str, ParameterSharding]): The sharding plan updates containing the new shard metadata.
+
+    Returns:
+        Tuple[int, int]: The largest dim 0 and dim 1 sizes across all shards in the plan updates.
+    """
+    max_dim_0 = 0
+    max_dim_1 = 0
+    for _, param in sharding_plan_updates.items():
+        assert hasattr(param.sharding_spec, "shards")
+        for shard in param.sharding_spec.shards:  # pyre-ignore
+            max_dim_0 = max(max_dim_0, shard.shard_sizes[0])
+            max_dim_1 = max(max_dim_1, shard.shard_sizes[1])
+
+    return max_dim_0, max_dim_1
+
+
+def pad_tensor_to_max_dims(
+    t: torch.Tensor,
+    expected_dim_0: int,
+    expected_dim_1: int,
+    remove_padding: bool = False,
+) -> torch.Tensor:
+    """
+    Pads a tensor on the right and bottom with zeros.
+
+    Args:
+        t (torch.Tensor): The tensor to be padded.
+        expected_dim_0 (int): The target size of dim 0 after padding.
+        expected_dim_1 (int): The target size of dim 1 after padding.
+        remove_padding (bool): Flag indicating the expected dims are smaller than the
+            current dims, so the resulting negative padding removes rows and columns
+            instead of adding them.
+
+    Returns:
+        torch.Tensor: The padded (or cropped) tensor.
+    """
+    pad_right = expected_dim_1 - t.size(1)
+    pad_bottom = expected_dim_0 - t.size(0)
+    return F.pad(
+        input=t,
+        pad=(
+            0,
+            pad_right,
+            0,
+            pad_bottom,
+        ),  # right and bottom
+        mode="constant",
+        value=0,
+    )
```
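One usage note on `pad_tensor_to_max_dims`: because `F.pad` accepts negative padding amounts (which crop rather than pad), the same call shape both adds padding before the collective and removes it afterwards when given the original shard size. A small round-trip check of that property, using a local re-implementation of the padding call under the same assumption:

```python
import torch
import torch.nn.functional as F


def pad_or_crop(t: torch.Tensor, expected_dim_0: int, expected_dim_1: int) -> torch.Tensor:
    # Mirrors the F.pad call above; negative pad amounts remove rows/columns.
    return F.pad(t, (0, expected_dim_1 - t.size(1), 0, expected_dim_0 - t.size(0)))


original = torch.randn(2, 5)
padded = pad_or_crop(original, 4, 8)    # padded up to (4, 8)
restored = pad_or_crop(padded, 2, 5)    # cropped back to (2, 5)
assert padded.shape == (4, 8) and torch.equal(original, restored)
```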
