@@ -332,6 +332,7 @@ def __init__(
         offload_config: Optional[OffloadConfig] = None,
         state_dict_on_rank_0_only: bool = False,
         gradient_predivide_factor: Optional[float] = None,
+        zero2_process_group: Optional[ProcessGroup] = None,
     ):
         try:
             import torch._C
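
The new zero2_process_group argument lets the wrapper reshard parameters into a smaller, ZeRO-2-style group after the forward pass instead of freeing them across all ranks. A minimal sketch of how it might be passed in, assuming the usual torch.distributed initialization and the fairscale.nn import path; the 8-rank split into two groups of 4 is purely illustrative:

    import torch.distributed as dist
    from fairscale.nn import FullyShardedDataParallel as FSDP

    # Every rank must create every group, in the same order.
    group_a = dist.new_group(ranks=[0, 1, 2, 3])
    group_b = dist.new_group(ranks=[4, 5, 6, 7])
    zero2_pg = group_a if dist.get_rank() < 4 else group_b

    # my_module is a placeholder nn.Module; the remaining FSDP arguments are unchanged.
    model = FSDP(my_module, reshard_after_forward=True, zero2_process_group=zero2_pg)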
@@ -380,6 +381,9 @@ def __init__(
                 "parameter uses all the available ranks for the optimal performance."
             )
         self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
+
+        self.zero2_process_group = zero2_process_group
+
         self.disable_reshard_on_root = disable_reshard_on_root
         self.mixed_precision = mixed_precision
         self.fp32_reduce_scatter = fp32_reduce_scatter
@@ -518,6 +522,9 @@ def __init__(
             if isinstance(m, FullyShardedDataParallel):
                 m._free_ssd_offload()

+        if self.zero2_process_group is not None:
+            assert not self.move_params_to_cpu
+
     def _get_gradient_predivide_factor(self, world_size: int) -> float:
         factor: int = 1
         while world_size % factor == 0 and world_size / factor > factor:
@@ -1419,7 +1426,10 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         outputs = self.module(*args, **kwargs)

         if self.reshard_after_forward:
-            self._free_full_params()
+            if self.zero2_process_group is not None:
+                self._zero2_shard_to_smaller_group()
+            else:
+                self._free_full_params()
         if self.mixed_precision or self.move_params_to_cpu:
             self._free_fp16_param_shard()

@@ -1499,7 +1509,10 @@ def _pre_backward_hook(*unused: Any) -> None:
             # idempotent. So in case they are called unnecessarily, they don't incur much
             # overhead.
             if self.reshard_after_forward:
-                self._rebuild_full_params()
+                if self.zero2_process_group is not None:
+                    self._zero2_rebuild_full_params()
+                else:
+                    self._rebuild_full_params()
             if (
                 self.reshard_after_forward
                 and self._fsdp_forward_ordering is not None
@@ -2006,6 +2019,126 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
             torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
         return output_tensors

+
+    @torch.no_grad()
+    def _zero2_rebuild_full_params(self, force_full_precision: bool = False, wait_for_all_gather=True) -> Optional[List[Tuple[torch.Tensor, bool]]]:
+        """
+        Gather all shards of params.
+
+        Note, this is idempotent if full params are already gathered. Callers
+        assume the idempotency. So please keep it that way.
+
+        Args:
+            force_full_precision (bool, Optional): by default params will be gathered
+                in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
+                ``True``, in which case they will be gathered in full precision
+                (e.g., FP32), possibly in fresh storage. The parameter that's being
+                rebuilt will end up in full precision as well.
+
+        Returns:
+            A list of tuples, where the first element is the full-sized param
+            and the second element is a bool indicating if it's safe for the
+            caller to free the full-sized param. This will be ``None`` if
+            ``force_full_precision=False`` and the full params are already gathered.
+        """
+        output_tensors: List[Tuple[torch.Tensor, bool]] = []
+
+        def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
+            """
+            Helper function to update p.data pointer.
+
+            Args:
+                custom_output_tensor (torch.Tensor, Optional): if not None, this
+                    tensor contains the data we just gathered.
+            """
+            if custom_output_tensor is not None:
+                assert p._is_sharded
+                p.data = custom_output_tensor
+                output_tensors.append((p.data, True))
+            elif not p._is_sharded:
+                if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
+                    assert p._fp16_shard is not None
+                    p.data = p._fp16_shard
+                    output_tensors.append((p.data, True))
+                else:
+                    # Here p.data == p._fp32_shard, so it's not safe to free.
+                    output_tensors.append((p.data, False))
+            else:
+                p.data = p._full_param_padded
+                output_tensors.append((p.data, True))
+            # Trim any padding and reshape to match original size.
+            p.data = p.data[: p._orig_size.numel()].view(p._orig_size)
+
+        if self._has_shared_params:
+            # self.has_full_params flag can be out of sync if a shared param is
+            # sharded by another FSDP instance. An example is that in eval case
+            # with reshard_after_forward=False but the sharing instance has
+            # reshard_after_forward=True. Then, on the second forward, the
+            # other instance can shard the shared param, but this instance
+            # can mistakenly think the full param is already gathered from the
+            # has_full_params flag.
+            #
+            # Therefore, we update the flag accordingly here.
+            self.has_full_params = not any(p._full_param_padded.storage().size() == 0 for p in self.params)
+
+        # Early exit if we already have full params and don't need full precision.
+        if self.has_full_params and not force_full_precision:
+            if wait_for_all_gather:
+                torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
+            for p in self.params:
+                update_p_data()
+            return output_tensors
+
+        self.has_full_params = True
+
+        with torch.cuda.stream(self._streams["all_gather"]):
+
+            for p in self.params:
+                if not p._is_sharded:  # e.g., when world_size == 1
+                    update_p_data()
+                else:
+                    # Skip if already built. Only shared param can be rebuilt multiple times.
+                    # A corner case is p._orig_size = (1,), which means the shape equality is
+                    # not a perfect check. But we assume we don't share a param with shape (1,).
+                    if p.data.shape == p._orig_size and hasattr(p, "_is_shared") and p._is_shared:
+                        continue
+                    # If self.move_params_to_cpu and force_full_precision, we need to cast
+                    # the FP32 CPU param to CUDA for the all-gather.
+                    p_data = p.data.to(p._full_param_padded.device, non_blocking=True)

+                    p_size = p._full_param_padded.size()
+                    assert p_size.numel() % self.world_size == 0
+                    if self.mixed_precision and force_full_precision:
+                        # Allocate fresh tensor in full precision since we are in
+                        # mixed precision and full precision rebuild is asked.
+                        output_tensor = p_data.new_zeros(p_size)
+                    else:
+                        if p._full_param_padded.storage().size() != p_size.numel():
+                            # Allocate based on full size from all shards.
+                            alloc_storage_(p._full_param_padded, size=p_size)
+                        output_tensor = p._full_param_padded
+
+                    # Fill output_tensor with (p.data for each shard in self.world_size)
+                    if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
+                        # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
+                        dist._all_gather_base(output_tensor, p._zero2_fp16_shard, group=self.zero2_process_group)
+                    else:
+                        chunks = list(output_tensor.chunk(self.world_size))
+                        dist.all_gather(chunks, p._zero2_fp16_shard, group=self.zero2_process_group)
+
+                    # Set p.data = output_tensor (with padding trimmed)
+                    update_p_data(output_tensor)
+
+                    if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
+                        self._free_zero2_param_shard([p])
+
+                    if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype):
+                        self._free_zero2_param_shard([p])
+        if wait_for_all_gather:
+            torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
+        return output_tensors
+
+
     @torch.no_grad()
     def _use_full_params(self) -> None:
         """
@@ -2074,6 +2207,38 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
             free_storage_(p._full_param_padded)
         torch.cuda.current_stream().synchronize()

+
+    def _zero2_shard_to_smaller_group(self, params: Optional[List[Parameter]] = None):
+        if params is None:
+            params = self.params
+        self.has_full_params = False
+        current_stream = torch.cuda.current_stream()
+        for p in params:
+            if not p._is_sharded:  # e.g., world_size == 1
+                if self.mixed_precision or self.move_params_to_cpu:
+                    self._free_fp16_param_shard([p])
+                continue
+            # Cases for when zero2 world size > 1 but less than zero3 size
+            zero2_world_size = dist.get_world_size(self.zero2_process_group)
+            zero2_rank = dist.get_rank(self.zero2_process_group)
+            chunks = p._full_param_padded.chunk(zero2_world_size)
+
+            p._zero2_fp16_shard = torch.empty_like(chunks[zero2_rank])
+            p._zero2_fp16_shard.copy_(chunks[zero2_rank])
+
+            # Don't let PyTorch reuse this memory until all work in the current
+            # stream is complete.
+            p._full_param_padded.record_stream(current_stream)
+            # There may be external references to the Tensor Storage that we
+            # can't modify, such as references that are created by
+            # ctx.save_for_backward in the forward pass. Thus when we
+            # unshard parameters, we should reuse the original Tensor
+            # Storage object and unshard it in-place. For now, just resize
+            # the Storage to 0 to save memory.
+            free_storage_(p._full_param_padded)
+        torch.cuda.current_stream().synchronize()
+
+
     def local_metadata_dict(self) -> Dict[str, Any]:
         """
         Get the information needed to reconstruct the model from shards offline.
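
Where _free_full_params simply drops the padded full tensor (the ZeRO-3 behavior), _zero2_shard_to_smaller_group first copies this rank's slice, sized by the zero-2 group, into p._zero2_fp16_shard, so the pre-backward rebuild only needs an all-gather over that smaller group. A toy, FSDP-free sketch of the reshard step, assuming full_padded is a CUDA tensor whose length divides evenly by the group size (names are hypothetical; storage().resize_(0) plays the role of free_storage_):

    import torch
    import torch.distributed as dist

    def shard_to_group(full_padded: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
        ws = dist.get_world_size(group)
        rank = dist.get_rank(group)
        # Keep an owned copy of the local slice so it survives freeing the full tensor.
        local_shard = full_padded.chunk(ws)[rank].clone()
        # Don't let the caching allocator recycle the memory while queued work may still use it.
        full_padded.record_stream(torch.cuda.current_stream())
        # Release the backing storage in place, which is what free_storage_ does.
        full_padded.storage().resize_(0)
        return local_shard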
@@ -2238,6 +2403,19 @@ def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
                 p._fp16_shard.record_stream(current_stream)
                 free_storage_(p._fp16_shard)

+    @torch.no_grad()
+    def _free_zero2_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
+        """Free storage for the zero-2 FP16 shards for a list of params."""
+        if params is None:
+            params = self.params
+        current_stream = torch.cuda.current_stream()
+        for p in params:
+            if p._zero2_fp16_shard is not None:
+                # _zero2_fp16_shard may still be in use by work queued on the current
+                # stream, so don't let its memory be reused until that work completes.
+                p._zero2_fp16_shard.record_stream(current_stream)
+                free_storage_(p._zero2_fp16_shard)
+
     def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
         """Assert we are in the given state."""
         # Since assert can be turned off and this error checking
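
The free helper relies on the usual CUDA caching-allocator contract: record_stream marks the shard as in use on the current stream, so its memory is not handed out again until the work already queued on that stream has finished, and only then is the backing storage dropped. A self-contained sketch of that pattern (the side stream and tensor size are illustrative):

    import torch

    side = torch.cuda.Stream()
    with torch.cuda.stream(side):
        shard = torch.empty(1 << 20, device="cuda", dtype=torch.float16)  # produced on a side stream

    # Mark the tensor as used on the current stream before freeing it, so the caching
    # allocator waits for that stream's pending work before reusing the memory.
    shard.record_stream(torch.cuda.current_stream())
    shard.storage().resize_(0)  # drop the backing storage, as free_storage_ does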