[Feature] Support resuming ZeRO-1 with a new data parallelism size #263

Open · wants to merge 5 commits into base: main
1 change: 1 addition & 0 deletions src/nanotron/constants.py
@@ -10,3 +10,4 @@

CHECKPOINT_FILE_NAME = "checkpoint_metadata.json"
MODEL_CONFIG_FILE_NAME = "model_config.json"
OPTIMIZER_CONFIG_FILE_NAME = "optimizer_config.json"
13 changes: 12 additions & 1 deletion src/nanotron/optim/gradient_accumulator.py
@@ -68,7 +68,11 @@ def __init__(
grad_buckets_named_params: The parameters to accumulate gradients for. If None it defaults to `named_parameters`. In case of Zero 1, this should be all the parameters in the model.

Note: We use `grad_buckets_named_params` to keep grad buffers for all parameters even when Zero 1 is used. This is because we need to accumulate gradients for all parameters without having to reduce in every accumulation step.
Note: We make a fp32 copy of parameters during initialization. Therefore parameters need to be initialized or loaded from a checkpoint before constructing this gradient accumulator
Note: We make a fp32 copy of parameters during initialization. Therefore parameters need to be initialized or loaded from a checkpoint before constructing this gradient accumulator.

"self.parameters"
- .fp32: the pointer to the full precision weights
- .half: the pointer to the half precision weights
"""
if grad_buckets_named_params is None:
named_parameters = list(named_parameters)
@@ -108,6 +112,9 @@ def __init__(

# Check that fp32 weights have the same memory representation as half precision weights
assert fp32_param.stride() == half_param.stride()
assert (
fp32_param.numel() == half_param.numel()
), f"There is a size mismatch of {name}, fp32_param: {fp32_param.numel()}, half_param: {half_param.numel()}"

# Copy weights from half precision to full precision
fp32_param.copy_(half_param)
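For illustration, a minimal sketch of the `self.parameters` layout the docstring above describes (not the accumulator's actual construction code): each entry keeps both pointers, and the two buffers must agree in stride and numel before the copy.

```python
import torch

# Hypothetical standalone illustration (names are not nanotron APIs).
half_param = torch.randn(4, 8, dtype=torch.bfloat16)
fp32_param = torch.empty_like(half_param, dtype=torch.float32)

# Mirrors the checks added in this hunk: same memory layout, same element count.
assert fp32_param.stride() == half_param.stride()
assert fp32_param.numel() == half_param.numel()
fp32_param.copy_(half_param)  # fp32 master copy initialized from the half-precision weights

parameters = {"model.decoder.0.weight": {"fp32": fp32_param, "half": half_param}}
```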
@@ -289,6 +296,10 @@ def state_dict(self) -> Dict[str, torch.Tensor]:
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
assert set(state_dict.keys()) == set(self.parameters.keys())

# NOTE: if the dp size in the checkpoint differs from the
# current dp size, we merge the sharded states and then
# reshard them for the new topology

with torch.inference_mode():
for name, elt in self.parameters.items():
elt["fp32"].copy_(state_dict[name])
27 changes: 15 additions & 12 deletions src/nanotron/optim/zero.py
@@ -62,6 +62,7 @@ def __init__(
# partition model's params across DP ranks.
# `self.param_name_to_dp_rank_offsets` sets mapping between each param inside self.named_params and its rank
# NOTE: some param_groups may have no params in the current rank. we still keep them in self.optimizer.param_groups
# TODO: maybe don't shard layernorm params in zero-1, since they are small anyway
self.param_name_to_dp_rank_offsets = self._partition_parameters()

current_dp_rank = dist.get_rank(self.dp_pg)
@@ -171,6 +172,8 @@ def _partition_parameters(self) -> Dict[str, Dict[int, Tuple[int, int]]]:
for name, param in named_params:
# We assume parameter to be contiguous in order to have an easy way of sharding it.
assert param.is_contiguous(), f"Parameter {name} is not contiguous"
if name == "model.final_layer_norm.pp_block.weight":
assert 1 == 1

numel = param.numel()
padded_numel_per_dp = (numel - 1) // self.dp_pg.size() + 1
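`padded_numel_per_dp` is the ceiling of `numel / dp_size`. A small sketch of the contiguous offsets this implies, assuming each dp rank owns `[rank * padded, min((rank + 1) * padded, numel))` of the flattened parameter (an assumption drawn from the surrounding code, not a verbatim copy of `_partition_parameters`):

```python
def shard_offsets(numel: int, dp_size: int) -> dict:
    """Return {dp_rank: (start, end)} offsets into a flattened, contiguous parameter."""
    padded_numel_per_dp = (numel - 1) // dp_size + 1  # ceil(numel / dp_size)
    offsets = {}
    for rank in range(dp_size):
        start = min(rank * padded_numel_per_dp, numel)
        end = min(start + padded_numel_per_dp, numel)
        offsets[rank] = (start, end)  # the last ranks may get a shorter (or empty) slice
    return offsets

# A 10-element parameter across dp=4:
assert shard_offsets(10, 4) == {0: (0, 3), 1: (3, 6), 2: (6, 9), 3: (9, 10)}
```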
@@ -368,25 +371,25 @@ def extract_parallel_ranks_from_shard_path(
) -> Union[Tuple[int, int, int], Tuple[int, int]]:
"""Extract parallel ranks from shard path

For example, if the shard path is:
+ For ZeRO-1: /path/to/optimizer_pp-0-of-1_dp-0-of-2_tp-0-of-1.pt
then the function will return (0, 0, 0) (pp_rank, dp_rank, tp_rank)
For example:
- ZeRO-1: /path/to/optimizer_pp-0-of-1_dp-0-of-2_tp-0-of-2_exp-0-of-1.pt
Returns: (0, 0, 0) (pp_rank, dp_rank, tp_rank)

For ZeRO-0: /path/to/optimizer_pp-0-of-1_tp-0-of-1.pt
then the function will return (0, 0) (pp_rank, tp_rank)
- ZeRO-0: /path/to/optimizer_pp-0-of-1_tp-0-of-2_exp-0-of-1.pt
Returns: (0, 0) (pp_rank, tp_rank)
"""
if is_zero1 is True:
# TODO(xrsrke): use the same pattern as weight checkpoints
# in weight checkpoints, we do pp-rank-.... but here we only do pp-...
# TODO(xrsrke): don't hardcode this
pattern = r"optimizer_pp-(\d+)-of-\d+_dp-(\d+)-of-\d+_tp-(\d+)-of-\d+\.pt"
if is_zero1:
pattern = r"optimizer_pp-(\d+)-of-\d+_dp-(\d+)-of-\d+_tp-(\d+)-of-\d+_exp-\d+-of-\d+\.pt"
match = re.search(pattern, str(shard_path))
if not match:
raise ValueError(f"Invalid shard path format: {shard_path}")
pp_rank, dp_rank, tp_rank = match.groups()
return int(pp_rank), int(dp_rank), int(tp_rank)
else:
# NOTE: this is zero0 checkpoint
pattern = r"pp-(\d+)-of-\d+_tp-(\d+)-of-\d+"
pattern = r"optimizer_pp-(\d+)-of-\d+_tp-(\d+)-of-\d+_exp-\d+-of-\d+\.pt"
match = re.search(pattern, str(shard_path))
if not match:
raise ValueError(f"Invalid shard path format: {shard_path}")
pp_rank, tp_rank = match.groups()
return int(pp_rank), int(tp_rank)
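A quick standalone check of the updated patterns; the example paths follow the `_exp-` naming introduced in this diff:

```python
import re

ZERO1 = r"optimizer_pp-(\d+)-of-\d+_dp-(\d+)-of-\d+_tp-(\d+)-of-\d+_exp-\d+-of-\d+\.pt"
ZERO0 = r"optimizer_pp-(\d+)-of-\d+_tp-(\d+)-of-\d+_exp-\d+-of-\d+\.pt"

assert re.search(ZERO1, "/ckpt/optimizer_pp-0-of-1_dp-1-of-2_tp-0-of-2_exp-0-of-1.pt").groups() == ("0", "1", "0")
assert re.search(ZERO0, "/ckpt/optimizer_pp-0-of-1_tp-1-of-2_exp-0-of-1.pt").groups() == ("0", "1")
```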

166 changes: 148 additions & 18 deletions src/nanotron/serialize/optimizer.py
@@ -9,7 +9,9 @@
from tqdm import tqdm

from nanotron import distributed as dist
from nanotron import optim
from nanotron import logging, optim
from nanotron.constants import OPTIMIZER_CONFIG_FILE_NAME
from nanotron.logging import log_rank
from nanotron.optim.zero import (
ZeroDistributedOptimizer,
extract_parallel_ranks_from_shard_path,
@@ -22,13 +24,34 @@
from nanotron.serialize.metadata import TensorMetadata
from nanotron.serialize.utils import ObjectType, merge_and_shard_tp_tensors

logger = logging.get_logger(__name__)


def get_optimizer_filename(
tp_topology: Tuple[int, int],
pp_topology: Tuple[int, int],
dp_topology: Optional[Tuple[int, int]] = None,
exp_topology: Optional[Tuple[int, int]] = None,
is_zero: Optional[bool] = None,
):
"""
tp_topology: Tuple[int, int] = (rank, size)
pp_topology: Tuple[int, int] = (rank, size)
dp_topology: Tuple[int, int] = (rank, size)

NOTE: sometimes we get the checkpoint from a different topology (not the current parallel_context)
"""
assert exp_topology is not None, "exp_topology is required"
assert is_zero is not None, "is_zero is required"
pp_rank, pp_size = pp_topology
tp_rank, tp_size = tp_topology
exp_rank, exp_size = exp_topology

# TODO(xrsrke): take rank instead of parallel_context
def optimizer_filename(parallel_context: ParallelContext, is_zero: bool):
if is_zero is True:
return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
dp_rank, dp_size = dp_topology
return f"{ObjectType.OPTIMIZER.value}_pp-{pp_rank}-of-{pp_size}_dp-{dp_rank}-of-{dp_size}_tp-{tp_rank}-of-{tp_size}_exp-{exp_rank}-of-{exp_size}.pt"
else:
return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
return f"{ObjectType.OPTIMIZER.value}_pp-{pp_rank}-of-{pp_size}_tp-{tp_rank}-of-{tp_size}_exp-{exp_rank}-of-{exp_size}.pt"


def lr_scheduler_filename(parallel_context: ParallelContext, is_zero: bool):
@@ -57,7 +80,7 @@ def save_optimizer(
root_folder.mkdir(exist_ok=True, parents=True)

if dist.get_rank(parallel_context.world_pg) == 0:
with open(root_folder / "optimizer_config.json", "w") as fo:
with open(root_folder / OPTIMIZER_CONFIG_FILE_NAME, "w") as fo:
tp_size = parallel_context.tp_pg.size()
pp_size = parallel_context.pp_pg.size()
dp_size = parallel_context.dp_pg.size()
@@ -102,7 +125,13 @@ def convert_to_string(input_item):
torch.save(
optimizer.state_dict(),
root_folder
/ optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)),
/ get_optimizer_filename(
tp_topology=(dist.get_rank(parallel_context.tp_pg), parallel_context.tp_pg.size()),
pp_topology=(dist.get_rank(parallel_context.pp_pg), parallel_context.pp_pg.size()),
dp_topology=(dist.get_rank(parallel_context.dp_pg), parallel_context.dp_pg.size()),
exp_topology=(dist.get_rank(parallel_context.expert_pg), parallel_context.expert_parallel_size),
is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer),
),
)


@@ -139,6 +168,9 @@ def state_dict_to_device(state_dict: Dict, device: str) -> Dict:
for name, tensor in optim_state.items():
optim_state[name] = tensor.to(device)

for name, tensor in state_dict["gradient_accumulator"].items():
state_dict["gradient_accumulator"][name] = tensor.to(device)

assert (
state_dict["state"][0]["exp_avg"].device.type == "cuda"
), "Optimizer states should be on GPU because model is on GPU"
@@ -155,7 +187,7 @@ def load_optimizer(
model: Optional[nn.Module] = None,
):
root_folder = root_folder / "optimizer"
ckp_optimizer_config_path = root_folder / "optimizer_config.json"
ckp_optimizer_config_path = root_folder / OPTIMIZER_CONFIG_FILE_NAME
with open(ckp_optimizer_config_path, "r") as file:
ckp_optimizer_config = json.load(file)

@@ -164,6 +196,7 @@
ckp_dp_size = ckp_optimizer_config["parallelism"]["dp_size"]
ckpt_expert_parallel_size = ckp_optimizer_config["parallelism"]["expert_parallel_size"]

# NOTE: tensor parallel and pipeline parallel topology-agnostic optimizer state loading
if int(ckp_tp_size) != int(parallel_context.tp_pg.size()) or int(ckp_pp_size) != int(
parallel_context.pp_pg.size()
):
@@ -188,7 +221,7 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
# across data parallel dimension, before merging the shards across tensor parallel dimension
shard_paths = list(
root_folder.glob(
f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_dp-*-of-{ckp_dp_size}_tp-*-of-{ckp_tp_size}-exp-*-of-{ckpt_expert_parallel_size}.pt"
f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_dp-*-of-{ckp_dp_size}_tp-*-of-{ckp_tp_size}_exp-*-of-{ckpt_expert_parallel_size}.pt"
)
)
ckp_sharded_optim_states = merge_dp_shard_in_zero1_optimizer(
@@ -330,31 +363,128 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
new_optim_state_dict["names"] = new_optim_state_param_names
state_dict = new_optim_state_dict
else:
# NOTE: since we only load the optimizer states here,
# we shard them according to the current data parallel dimension
# TODO @thomasw21: Load optimizer type and check that it's compatible otherwise we might be be loading something else completely
state_dict = torch.load(
root_folder
/ optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)),
/ get_optimizer_filename(
tp_topology=(dist.get_rank(parallel_context.tp_pg), parallel_context.tp_pg.size()),
pp_topology=(dist.get_rank(parallel_context.pp_pg), parallel_context.pp_pg.size()),
# NOTE(xrsrke): suppose we initially have a dp world size of 4 and then
# resume with a dp world size of 8; to pick which checkpoint shard to load,
# we round-robin the new dp ranks onto the checkpoint's dp ranks
# dp=8's ranks: [0, 1, 2, 3, 4, 5, 6, 7]
# map to:       [0, 1, 2, 3, 0, 1, 2, 3]
dp_topology=(int(dist.get_rank(parallel_context.dp_pg)) % int(ckp_dp_size), ckp_dp_size),
exp_topology=(dist.get_rank(parallel_context.expert_pg), parallel_context.expert_parallel_size),
is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer),
),
map_location=map_location,
)
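A minimal sketch of the round-robin mapping described in the comment above (the helper name is illustrative, not a nanotron API):

```python
def map_to_checkpoint_dp_rank(current_dp_rank: int, ckp_dp_size: int) -> int:
    """Round-robin a rank from the new DP world onto the checkpoint's DP world."""
    return current_dp_rank % ckp_dp_size

# Resuming a dp=4 checkpoint with dp=8: ranks 0..7 read shards 0, 1, 2, 3, 0, 1, 2, 3.
assert [map_to_checkpoint_dp_rank(r, 4) for r in range(8)] == [0, 1, 2, 3, 0, 1, 2, 3]
```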

def create_merged_optim_states(param_shapes, map_location):
merged_states = {}
for name, p_shape in param_shapes.items():
p_shape = tuple(int(x) for x in p_shape)
merged_states[name] = {
"exp_avg": torch.zeros(p_shape).view(-1).to(map_location),
"exp_avg_sq": torch.zeros(p_shape).view(-1).to(map_location),
}
return merged_states

def create_merged_gradients(param_shapes, map_location):
merged_grads = {}
for name, p_shape in param_shapes.items():
p_shape = tuple(int(x) for x in p_shape)
merged_grads[name] = torch.zeros(p_shape).view(-1).to(map_location)
return merged_grads

def load_sharded_states(shard_paths, map_location, load_type="state"):
sharded_states = {}
for shard_path in shard_paths:
pp_rank, dp_rank, tp_rank = extract_parallel_ranks_from_shard_path(shard_path, is_zero1=True)
checkpoint = torch.load(shard_path, map_location=map_location)
sharded_states[(tp_rank, dp_rank)] = checkpoint[load_type]
return sharded_states

def get_key_by_value(d, target_value):
return next((key for key, value in d.items() if value == target_value), None)

def apply_offsets(merged_tensor, sharded_states, param_name, offsets, tp_rank, state_keys=None):
# NOTE: `dp_rank` and `state_dict` are read from the enclosing scope
# (the caller's loop variable and the checkpoint loaded above, respectively)
if state_keys:
for key in state_keys:
p_idx = get_key_by_value(state_dict["names"], param_name)
merged_tensor[param_name][key][int(offsets[0]) : int(offsets[1])] = sharded_states[
(int(tp_rank), int(dp_rank))
][p_idx][key]
else:
merged_tensor[param_name][int(offsets[0]) : int(offsets[1])] = sharded_states[
(int(tp_rank), int(dp_rank))
][param_name]

if isinstance(optimizer, ZeroDistributedOptimizer):
# NOTE: only reshard after merging tp shards,
# or when we get a new dp size
if int(ckp_tp_size) != parallel_context.tp_pg.size() or int(ckp_dp_size) != parallel_context.dp_pg.size():
# NOTE: if the optimizer is ZeRO-1, now we shard the optimizer states across data parallel dimension
shard_paths = list(
root_folder.glob(
f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_dp-*-of-{ckp_dp_size}_tp-*-of-{ckp_tp_size}_exp-*-of-{ckpt_expert_parallel_size}.pt"
)
)

if int(ckp_dp_size) != parallel_context.dp_pg.size():
log_rank(
f"[Optimizer Loading] Detect new data parallelism topology in ZeRO-1, resharding optimizer states and gradient accumulator's states", # noqa
logger=logger,
level=logging.INFO,
rank=0,
)

current_dp_rank = dist.get_rank(parallel_context.dp_pg)
tp_rank = dist.get_rank(parallel_context.tp_pg)
OPTIMIZER_STATE_NAMES = state_dict["state"][0].keys() - ["step"]
param_shapes = ckp_optimizer_config["configs"]["orig_param_shapes"]

# Handle optimizer states
ckp_sharded_optim_states = load_sharded_states(shard_paths, map_location, "state")
merged_optim_states = create_merged_optim_states(param_shapes, map_location)

for p_name, offsets in ckp_optimizer_config["configs"]["param_name_to_dp_rank_offsets"].items():
for dp_rank, offset in offsets.items():
apply_offsets(
merged_optim_states, ckp_sharded_optim_states, p_name, offset, tp_rank, OPTIMIZER_STATE_NAMES
)

# Update state dict with new sliced tensors
for param_index in state_dict["state"]:
param_name = [name for idx, name in state_dict["names"].items() if idx == param_index][0]
for state_name in OPTIMIZER_STATE_NAMES:
current_offsets = optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank]
sliced_tensor = get_sliced_tensor(
param=state_dict["state"][param_index][state_name],
start_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][0],
end_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][1],
param=merged_optim_states[param_name][state_name],
start_offset=current_offsets[0],
end_offset=current_offsets[1],
)
assert sliced_tensor.numel() > 0
state_dict["state"][param_index][state_name] = sliced_tensor

optimizer.load_state_dict(state_dict, map_location=map_location)
# Handle gradient accumulator if DP size changed
assert int(ckp_tp_size) == parallel_context.tp_pg.size(), "Don't support changing TP size for ZeRO-1"

ckp_sharded_grad_accum = load_sharded_states(shard_paths, map_location, "gradient_accumulator")
merged_grad_accumulator = create_merged_gradients(param_shapes, map_location)

for p_name, offsets in ckp_optimizer_config["configs"]["param_name_to_dp_rank_offsets"].items():
for dp_rank, offset in offsets.items():
apply_offsets(merged_grad_accumulator, ckp_sharded_grad_accum, p_name, offset, tp_rank)

# Update gradient accumulator with new slices
for p_name in state_dict["gradient_accumulator"].keys():
new_offset = optimizer.param_name_to_dp_rank_offsets[p_name][current_dp_rank]
assert state_dict["gradient_accumulator"][p_name].device == merged_grad_accumulator[p_name].device
state_dict["gradient_accumulator"][p_name] = merged_grad_accumulator[p_name][
int(new_offset[0]) : int(new_offset[1])
]

optimizer.load_state_dict(state_dict, map_location=map_location)
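To summarize the flow above, a self-contained sketch of merging a dp=2 checkpoint's flattened shards and re-slicing them for dp=4; the offset dictionaries stand in for the checkpoint's and the new optimizer's `param_name_to_dp_rank_offsets` entries:

```python
import torch

full = torch.arange(10, dtype=torch.float32)  # a flattened fp32 state for one parameter
ckp_offsets = {0: (0, 5), 1: (5, 10)}         # offsets stored in the dp=2 checkpoint
ckp_shards = {r: full[s:e].clone() for r, (s, e) in ckp_offsets.items()}

# Merge the checkpoint shards into one contiguous buffer (cf. create_merged_optim_states / apply_offsets).
merged = torch.zeros(10)
for r, (s, e) in ckp_offsets.items():
    merged[s:e] = ckp_shards[r]

# Re-slice with the offsets computed for the new dp=4 topology (cf. get_sliced_tensor above).
new_offsets = {0: (0, 3), 1: (3, 6), 2: (6, 9), 3: (9, 10)}
resharded = {r: merged[s:e] for r, (s, e) in new_offsets.items()}
assert torch.equal(torch.cat([resharded[r] for r in range(4)]), full)
```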


def load_lr_scheduler(
14 changes: 6 additions & 8 deletions src/nanotron/trainer.py
@@ -172,6 +172,7 @@ def __init__(
parallel_config=self.config.parallelism, tp_pg=self.parallel_context.tp_pg
)
self.model = self.init_model() # Defines self.model

self.unwrapped_model: NanotronModel = (
self.model.module if isinstance(self.model, DistributedDataParallel) else self.model
)
@@ -443,6 +444,11 @@ def train(
# free memory
gc.collect()
torch.cuda.empty_cache()

# Move optimizer states back to GPU before optimizer step
if self.init_checkpoint_path is not None and self.config.checkpoints.load_optimizer:
state_dict_to_device(self.optimizer.state_dict(), "cuda")

with prof:
for self.iteration_step in range(self.initial_iter_step, self.last_iter_step + 1):
if isinstance(prof, torch.profiler.profile):
@@ -552,14 +558,6 @@ def training_step(
loss_avg = None
handle = None

# Move optimizer states back to GPU before optimizer step
if (
self.init_checkpoint_path is not None
and self.config.checkpoints.load_optimizer
and self.iteration_step == self.initial_iter_step
):
state_dict_to_device(self.optimizer.state_dict(), "cuda")

before_optim_step_sanity_checks(
self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer
)