@@ -441,39 +441,21 @@ def _get_parameter_shard_indices_in_full_weight(
         where it is located in the shard if it exists, or -1 if it's not in the shard.
         Used to determine the location of each entry in a different distributed configuration.
         """
-
-        # Create an empty index for the global parameter.
-        index = torch.full(
-            parameter_meta.global_shape,
-            -1,
-            dtype=torch.int64,
-            device=device,
-        )
         # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard
         begin, end = self._get_parameter_range_in_shard(parameter_name)
 
-        buffer_index = parameter_meta.global_to_local(index, expand=True)
-        # Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible.
-        # In that case, we work with a separate tensor to be copied back into `buffer_index`.
-        try:
-            buffer_index_flat = buffer_index.view(-1)
-            is_view = True
-        except RuntimeError:
-            buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1)
-            is_view = False
-
-        # Copy the shard indices at their respective positions in the flat buffer index.
-        buffer_index_flat[
+        # Create an empty local index to hold the local shard indices.
+        buffer_index = torch.full_like(parameter_meta, -1, dtype=torch.int64, device=device)
+
+        # Copy the shard indices at their respective positions in the buffer index.
+        buffer_index.flatten()[
             self._index_buffer_to_param(
                 self._fsdp_dim.rank * self._shard_size, parameter_name
             ) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name)
         ].copy_(torch.arange(begin, end, dtype=torch.int64, device=device))
 
-        # If needed, copy the flat buffer index back into the index.
-        if not is_view:
-            buffer_index.copy_(buffer_index_flat.view_as(buffer_index))
-
-        return index
+        # Create a global index from the local one.
+        return parameter_meta.local_to_global_partial(buffer_index, -1)
 
     def copy_shard_overlaps(
         self,
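For reference, a minimal runnable sketch of what this function computes, under a toy layout where the flattened global weight is split contiguously across two shards. The sizes, the helper name, and the simplified stand-in for `local_to_global_partial` are all hypothetical; the real code delegates the layout mapping to `parameter_meta`, which can handle non-contiguous distributed configurations.

```python
import torch

# Toy setup (all sizes hypothetical): a 4x6 global weight, flattened and split
# contiguously across two shards of 12 elements each.
global_shape = (4, 6)
shard_size = 12

def shard_indices_in_full_weight(rank: int) -> torch.Tensor:
    # Local buffer index, initialized to -1 (entry not covered by this parameter).
    buffer_index = torch.full((shard_size,), -1, dtype=torch.int64)
    # Fill the parameter's slice of the shard with its indices within the shard
    # (in this toy case the whole shard belongs to one parameter, so the slice
    # runs from 0 to shard_size).
    buffer_index[:].copy_(torch.arange(0, shard_size, dtype=torch.int64))
    # Stand-in for `local_to_global_partial`: place the local indices at the
    # positions this rank owns in the global weight, leaving -1 everywhere else.
    index = torch.full((global_shape[0] * global_shape[1],), -1, dtype=torch.int64)
    index[rank * shard_size : (rank + 1) * shard_size] = buffer_index
    return index.view(global_shape)

# Rank 0 owns the first half of the flattened weight, rank 1 the second half.
print(shard_indices_in_full_weight(0))
print(shard_indices_in_full_weight(1))
```

Each entry of the returned tensor is either the entry's position in the given shard or -1, which is what the changed method produces for arbitrary distributed layouts.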