
import torch
import torch.distributed as dist
+ import torch.nn.functional as F
from torch.distributed._shard.sharded_tensor import Shard
from torchrec.distributed.types import (
    ParameterSharding,
    ShardingEnv,
)

+ OrderedShardNamesWithSizes = List[Tuple[str, List[int]]]
+ """
+ A type alias for an ordered list of shard names and the corresponding shard sizes
+ in dims 0 and 1 that were sent to the current rank.
+ This is a flattened and pruned nested list, which orders the shard names and
+ sizes with the following priority:
+     1. Rank order
+     2. Table order
+     3. Shard order
+
+ <table_x, shard_y> in the examples below represents the 2D tensor for a given
+ table `x` and shard `y`, where `y` denotes the order of shards in module
+ attributes such as state_dict, sharding_plan, etc. Each such shard is allocated
+ to some rank `z`; `z` and `y` need not be equal numerically, but the ordering of
+ shards follows the order of the ranks they are allocated to.
+
+ Example 1, NOTE the ordering by rank:
+     Rank 0 sends table_0, shard_0 to Rank 1.
+     Rank 2 sends table_1, shard_0 to Rank 1.
+     Rank 2 sends table_1, shard_1 to Rank 0.
+     Rank 3 sends table_0, shard_1 to Rank 0.
+
+     NOTE: table_1 comes first on Rank 0 because its source rank (2) comes first.
+     On Rank 0: output_tensor = [
+         <table_1, shard_0>,  # from rank 2
+         <table_0, shard_1>   # from rank 3
+     ]
+
+     On Rank 1: output_tensor = [
+         <table_0, shard_0>,  # from rank 0
+         <table_1, shard_0>   # from rank 2
+     ]
+
+ Example 2, NOTE the ordering by table when the source rank is the same:
+     Rank 0 sends table_1 to Rank 1.
+     Rank 0 sends table_2 to Rank 1.
+
+     On Rank 1: output_tensor = [
+         <table_1, shard_y>,
+         <table_2, shard_y>
+     ]
+
+ Example 3, NOTE the ordering by shard when the source rank and table are the same:
+     Rank 0 sends table_1, shard_0 to Rank 1.
+     Rank 0 sends table_1, shard_1 to Rank 1.
+
+     On Rank 1: output_tensor = [
+         <table_1, shard_0>,
+         <table_1, shard_1>
+     ]
+ """
+
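
For concreteness, here is a minimal sketch (not part of the diff; the table names and sizes are hypothetical) of the value an `OrderedShardNamesWithSizes` would take on Rank 0 in Example 1, assuming `<table_1, shard_0>` is 512 x 64 and `<table_0, shard_1>` is 256 x 128:

```python
from typing import List, Tuple

# Hypothetical OrderedShardNamesWithSizes value on Rank 0 for Example 1:
# entries are ordered by source rank (rank 2 before rank 3), and each entry
# records the original (dim 0, dim 1) sizes so padding can be stripped later.
ordered_shards: List[Tuple[str, List[int]]] = [
    ("table_1", [512, 64]),   # <table_1, shard_0> received from rank 2
    ("table_0", [256, 128]),  # <table_0, shard_1> received from rank 3
]
```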


def shards_all_to_all(
    module: ShardedModule[Any, Any, Any, Any],  # pyre-ignore
    state_dict: Dict[str, ShardedTensor],
    device: torch.device,
    changed_sharding_params: Dict[str, ParameterSharding],
    env: ShardingEnv,
+     max_dim_0: int,
+     max_dim_1: int,
    extend_shard_name: Callable[[str], str] = lambda x: x,
- ) -> Tuple[List[Tuple[str, int]], torch.Tensor]:
+ ) -> Tuple[OrderedShardNamesWithSizes, torch.Tensor]:
    """
    Performs an all-to-all communication to redistribute shards across ranks based on new sharding parameters.
-     Assumes ranks are ordered in ParameterSharding.ranks.
+     Assumes ranks are ordered in ParameterSharding.ranks. Implements padding for concatenating, sending and
+     receiving tensors of different sizes in dim 0 or 1.

    Args:
        module (ShardedModule[Any, Any, Any, Any]): The module containing sharded tensors to be redistributed.
@@ -46,10 +102,14 @@ def shards_all_to_all(

        extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict.

+         max_dim_0 (int): The maximum dimension size of dim 0 across all tables in the module.
+
+         max_dim_1 (int): The maximum dimension size of dim 1 across all tables in the module.
+
    Returns:
-         Tuple[List[Tuple[str, int]], torch.Tensor]: A tuple containing:
-             - A list of shard name and the corresponding shard_size in dim 1 that were sent to the current rank.
-             This is a flattened and pruned nested list, which orders the shards names and sizes by rank, then shard order.
+         Tuple[List[Tuple[str, List[int]]], torch.Tensor]: A tuple containing:
+             - A list of shard names and the corresponding shard sizes in dims 0 and 1 that were sent to the current rank.
+               This is a flattened and pruned nested list, which orders the shard names and sizes by source rank, then shard order.
            - The tensor containing all shards received by the current rank after the all-to-all operation.
    """
    if env.output_dtensor:
@@ -64,8 +124,6 @@ def shards_all_to_all(
    input_splits_per_rank = [[0] * world_size for _ in range(world_size)]
    output_splits_per_rank = [[0] * world_size for _ in range(world_size)]

-     # 0 by default, as current rank may be recieving 0 shards
-     num_embeddings_received = 0
    output_tensor_tensor_count = 0
    shard_names_to_lengths_by_src_rank = [[] for _ in range(world_size)]
    local_table_to_input_tensor_by_dst_rank = [[] for _ in range(world_size)]
@@ -86,29 +144,20 @@ def shards_all_to_all(
            src_rank = src_ranks[i]

            shard_size = sharded_t.metadata().shards_metadata[i].shard_sizes
-             shard_size_dim_1 = shard_size[1]
-             input_splits_per_rank[src_rank][dst_rank] += shard_size_dim_1
-             output_splits_per_rank[dst_rank][src_rank] += shard_size_dim_1
+             input_splits_per_rank[src_rank][dst_rank] += max_dim_0
+             output_splits_per_rank[dst_rank][src_rank] += max_dim_0
            if src_rank == rank:
                local_shards = sharded_t.local_shards()
                assert len(local_shards) == 1
-                 local_table_to_input_tensor_by_dst_rank[dst_rank].append(
-                     sharded_t.local_shards()[0].tensor
+                 cur_t = pad_tensor_to_max_dims(
+                     sharded_t.local_shards()[0].tensor, max_dim_0, max_dim_1
                )
+                 local_table_to_input_tensor_by_dst_rank[dst_rank].append(cur_t)
            if dst_rank == rank:
                shard_names_to_lengths_by_src_rank[src_rank].append(
-                     (shard_name, shard_size_dim_1)
+                     (shard_name, shard_size)
                )
-                 # NOTE: Only need to update num_embeddings_received to be the
-                 # num_embeddings of shards if this rank is actually recieving
-                 # any tensors
-                 if num_embeddings_received == 0:
-                     num_embeddings_received = shard_size[0]
-                 else:
-                     # TODO: for 2D and row-wise, shard_sizes in dim 0 may be variable
-                     # For now, assume that shard_sizes in dim 0 are all the same
-                     assert num_embeddings_received == shard_size[0]
-                 output_tensor_tensor_count += shard_size[1]
+                 output_tensor_tensor_count += max_dim_0

    local_input_splits = input_splits_per_rank[rank]
    local_output_splits = output_splits_per_rank[rank]
@@ -121,16 +170,13 @@ def shards_all_to_all(
                    local_input_tensor,
                    shard_info,
                ),
-                 dim=1,
+                 dim=0,
            )

-     # Transposing the Tensors - because we are concatenating them along dimension 1
-     # This is because dim 0 size may be different for different shards
-     # whereas dim 1 size is the same for all shards as dim 1 size = num_embeddings per table
+     max_embedding_size = max_dim_1
    local_output_tensor = torch.empty(
-         [output_tensor_tensor_count, num_embeddings_received], device=device
+         [output_tensor_tensor_count, max_embedding_size], device=device
    )
-     local_input_tensor = local_input_tensor.T.contiguous()

    assert sum(local_output_splits) == len(local_output_tensor)
    assert sum(local_input_splits) == len(local_input_tensor)
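
To see why the input splits become simple multiples of max_dim_0 under this padding scheme, here is a self-contained sketch with toy shard sizes and no process group; `pad_to` is a local stand-in for `pad_tensor_to_max_dims`:

```python
import torch
import torch.nn.functional as F


def pad_to(t: torch.Tensor, dim_0: int, dim_1: int) -> torch.Tensor:
    # Zero-pad on the right and bottom, mirroring pad_tensor_to_max_dims.
    return F.pad(t, (0, dim_1 - t.size(1), 0, dim_0 - t.size(0)))


# Two local shards of different sizes, each destined for a different rank.
shard_a = torch.randn(6, 16)   # e.g. table_0's shard for dst rank 0
shard_b = torch.randn(10, 32)  # e.g. table_1's shard for dst rank 1
max_dim_0, max_dim_1 = 10, 32

padded = [pad_to(s, max_dim_0, max_dim_1) for s in (shard_a, shard_b)]
local_input_tensor = torch.cat(padded, dim=0)

# Every shard now contributes exactly max_dim_0 rows, so the all_to_all_single
# input splits are just max_dim_0 times the number of shards per destination.
local_input_splits = [max_dim_0, max_dim_0]
assert sum(local_input_splits) == local_input_tensor.size(0)
assert local_input_tensor.size(1) == max_dim_1
```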
@@ -153,22 +199,23 @@ def shards_all_to_all(

def update_state_dict_post_resharding(
    state_dict: Dict[str, ShardedTensor],
-     ordered_shard_names_and_lengths: List[Tuple[str, int]],
+     ordered_shard_names_and_lengths: OrderedShardNamesWithSizes,
    output_tensor: torch.Tensor,
    new_sharding_params: Dict[str, ParameterSharding],
    curr_rank: int,
+     max_dim_0: int,
    extend_shard_name: Callable[[str], str] = lambda x: x,
) -> Dict[str, ShardedTensor]:
    """
    Updates and returns the given state_dict with new placements and
    local_shards based on the output tensor of the AllToAll collective.
+     Removes padding from the output tensor in dims 0 and 1 if necessary.

    Args:
        state_dict (Dict[str, Any]): The state dict to be updated with new shard placements and local shards.

-         shard_names_by_src_rank (List[Tuple[str, int]]): A list of shard name and the corresponding shard_size in dim 1
-         that were sent to the current rank. This is a flattened and pruned nested list, which orders the shards names and
-         sizes by rank, then shard order.
+         ordered_shard_names_and_lengths (List[Tuple[str, List[int]]]): A list of shard names and the corresponding shard sizes
+         in dims 0 and 1 that were sent to the current rank. This is a flattened and pruned nested list, which orders the
+         shard names and sizes by source rank, then shard order.

        output_tensor (torch.Tensor): The tensor containing the output data from the AllToAll operation.

@@ -177,6 +224,10 @@ def update_state_dict_post_resharding(

        curr_rank (int): The current rank of the process in the distributed environment.

+         max_dim_0 (int): The maximum dimension size of dim 0 across all tables in the module. Only dim 0
+         is needed here, to slice the output tensor into fixed-size blocks; removing the padding then relies on
+         the original shard sizes stored in ordered_shard_names_and_lengths.
+
        extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict.

    Returns:
@@ -187,10 +238,12 @@ def update_state_dict_post_resharding(
    shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {}

    for shard_name, shard_size in ordered_shard_names_and_lengths:
-         end_slice_index = slice_index + shard_size
-         shard_name_to_local_output_tensor[shard_name] = output_tensor[
-             slice_index:end_slice_index
-         ].T
+         end_slice_index = slice_index + max_dim_0
+         cur_t = output_tensor[slice_index:end_slice_index]
+         cur_t = pad_tensor_to_max_dims(
+             cur_t, shard_size[0], shard_size[1], remove_padding=True
+         )
+         shard_name_to_local_output_tensor[shard_name] = cur_t
        slice_index = end_slice_index

    for shard_name, param in new_sharding_params.items():
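
The receiving side mirrors the sender's bookkeeping: the output tensor is walked in fixed blocks of max_dim_0 rows, and each block is trimmed back to the original shard size recorded in ordered_shard_names_and_lengths. A toy sketch follows (names and sizes hypothetical; plain slicing is used in place of pad_tensor_to_max_dims with remove_padding=True, which has the same effect):

```python
import torch

max_dim_0, max_dim_1 = 10, 32
# Two padded shards stacked along dim 0, as the all-to-all would produce.
output_tensor = torch.randn(2 * max_dim_0, max_dim_1)
ordered_shard_names_and_lengths = [("table_1", [6, 16]), ("table_0", [10, 32])]

shard_name_to_local_output_tensor = {}
slice_index = 0
for shard_name, shard_size in ordered_shard_names_and_lengths:
    block = output_tensor[slice_index : slice_index + max_dim_0]
    # Trim the zero padding back off using the recorded original sizes.
    shard_name_to_local_output_tensor[shard_name] = block[: shard_size[0], : shard_size[1]]
    slice_index += max_dim_0

assert shard_name_to_local_output_tensor["table_1"].shape == (6, 16)
assert shard_name_to_local_output_tensor["table_0"].shape == (10, 32)
```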
@@ -234,3 +287,80 @@ def update_module_sharding_plan(
    for table_name, param_sharding in changed_sharding_params.items():
        current_plan[table_name] = param_sharding
    return
+
+
+ def get_largest_dims_from_state_dict(
+     state_dict: Dict[str, ShardedTensor],
+ ) -> Tuple[int, int]:
+     """
+     Returns the largest dim 0 and dim 1 sizes across all tables in a module.
+
+     Args:
+         state_dict (Dict[str, ShardedTensor]): The state dict containing the sharded tensors.
+
+     Returns:
+         Tuple[int, int]: The largest dim 0 size and the largest dim 1 size across all shards in the state_dict.
+     """
+     max_dim_0 = 0
+     max_dim_1 = 0
+     for sharded_t in state_dict.values():
+         for shard in sharded_t.metadata().shards_metadata:
+             max_dim_0 = max(max_dim_0, shard.shard_sizes[0])
+             max_dim_1 = max(max_dim_1, shard.shard_sizes[1])
+
+     return max_dim_0, max_dim_1
+
+
+ def get_largest_dims_from_sharding_plan_updates(
+     sharding_plan_updates: Dict[str, ParameterSharding],
+ ) -> Tuple[int, int]:
+     """
+     Returns the largest dim 0 and dim 1 sizes across all tables in the given sharding plan updates.
+
+     Args:
+         sharding_plan_updates (Dict[str, ParameterSharding]): The new parameter sharding per table, containing the planned shards.
+
+     Returns:
+         Tuple[int, int]: The largest dim 0 size and the largest dim 1 size across all planned shards.
+     """
+     max_dim_0 = 0
+     max_dim_1 = 0
+     for _, param in sharding_plan_updates.items():
+         assert hasattr(param.sharding_spec, "shards")
+         for shard in param.sharding_spec.shards:  # pyre-ignore
+             max_dim_0 = max(max_dim_0, shard.shard_sizes[0])
+             max_dim_1 = max(max_dim_1, shard.shard_sizes[1])
+
+     return max_dim_0, max_dim_1
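
A caller that reshards presumably wants padding dims large enough for both the current shards and the planned ones; one plausible way to combine the two helpers is an element-wise max. The wrapper below (its name and the idea of combining both sources) is purely illustrative, not part of this diff, and assumes the two helpers above are in scope:

```python
from typing import Dict, Tuple

from torch.distributed._shard.sharded_tensor import ShardedTensor
from torchrec.distributed.types import ParameterSharding


def get_largest_dims_for_reshard(  # hypothetical helper, not part of this diff
    state_dict: Dict[str, ShardedTensor],
    sharding_plan_updates: Dict[str, ParameterSharding],
) -> Tuple[int, int]:
    # Element-wise max so padding covers both the source shards (current
    # state_dict) and the destination shards (new sharding plan).
    src_dim_0, src_dim_1 = get_largest_dims_from_state_dict(state_dict)
    dst_dim_0, dst_dim_1 = get_largest_dims_from_sharding_plan_updates(
        sharding_plan_updates
    )
    return max(src_dim_0, dst_dim_0), max(src_dim_1, dst_dim_1)
```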
+
+
+ def pad_tensor_to_max_dims(
+     t: torch.Tensor,
+     expected_dim_0: int,
+     expected_dim_1: int,
+     remove_padding: bool = False,
+ ) -> torch.Tensor:
+     """
+     Pads a tensor on the right and bottom with zeros up to the expected dims. When the tensor is
+     larger than the expected dims, the resulting negative pad values crop it down instead, which is
+     how previously added padding is removed.
+
+     Args:
+         t (torch.Tensor): The tensor to be padded or cropped.
+         expected_dim_0 (int): The expected size of dim 0 after padding or cropping.
+         expected_dim_1 (int): The expected size of dim 1 after padding or cropping.
+         remove_padding (bool): Indicates the call is removing previously added padding. Cropping happens
+             automatically via negative pad values, so this flag does not change the computation.
+
+     Returns:
+         torch.Tensor: The padded (or cropped) tensor.
+     """
+     pad_right = expected_dim_1 - t.size(1)
+     pad_bottom = expected_dim_0 - t.size(0)
+     return F.pad(
+         input=t,
+         pad=(
+             0,
+             pad_right,
+             0,
+             pad_bottom,
+         ),  # right and bottom
+         mode="constant",
+         value=0,
+     )
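
Because F.pad accepts negative padding amounts, the same function both adds and removes padding: when the expected dims are smaller than the tensor, the negative pad values crop it. A quick round-trip check with toy sizes, assuming pad_tensor_to_max_dims as defined above:

```python
import torch

t = torch.randn(6, 16)

# Pad up to hypothetical module-wide maxima (10, 32)...
padded = pad_tensor_to_max_dims(t, expected_dim_0=10, expected_dim_1=32)
assert padded.shape == (10, 32)

# ...then "pad" back down to the original sizes; the negative pads crop the zeros off.
restored = pad_tensor_to_max_dims(
    padded, expected_dim_0=6, expected_dim_1=16, remove_padding=True
)
assert torch.equal(restored, t)
```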