@@ -53,6 +53,8 @@
KEY_MODEL = "model"
KEY_OPTIMIZER = "optimizer"

ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"

# Below are rewrites of the HF FSDP model-saving functions, adapted to handle
# state dicts whose parameters are now a mixture of regular tensors and DTensors.
# - the original functions live in accelerate.utils.fsdp_utils.py
@@ -110,16 +112,30 @@ def save_fsdp_optimizer(
# get the state dicts for the model and optimizer
(model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer)

# filter out the LoRA adapter weights (lora_A / lora_B) into a separate state dict
lora_state_dict = {
k: v for k, v in model_state_dict.items() if "lora_A" in k or "lora_B" in k
}
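# - a matching key would look like "model.layers.0.self_attn.q_proj.lora_A.default.weight"
#   (name illustrative); the frozen base weights are left out of the adapter checkpoint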

# - save model
ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
os.makedirs(ckpt_model, exist_ok=True)
logger.info(f"Saving model to {ckpt_model}")
dcp.save(
state_dict={KEY_MODEL: model_state_dict},
storage_writer=dcp.FileSystemWriter(ckpt_model),
planner=DefaultSavePlanner(),
)
logger.info(f"Model saved to {ckpt_model}")
if lora_state_dict:
ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
os.makedirs(ckpt_model, exist_ok=True)
logger.info(f"Saving lora model to {ckpt_model}")
dcp.save(
state_dict={KEY_MODEL: lora_state_dict},
storage_writer=dcp.FileSystemWriter(ckpt_model),
planner=DefaultSavePlanner(),
)
else:
ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
os.makedirs(ckpt_model, exist_ok=True)
logger.info(f"Saving ft model to {ckpt_model}")
dcp.save(
state_dict={KEY_MODEL: model_state_dict},
storage_writer=dcp.FileSystemWriter(ckpt_model),
planner=DefaultSavePlanner(),
)

# - save optimizer
ckpt_opt = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
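
# A minimal sketch (not part of this change, helper name hypothetical) of reading
# back a model checkpoint written by the dcp.save calls above; it assumes the same
# {KEY_MODEL: ...} layout and a target state dict with matching tensor shapes.
def _example_load_dcp_model_state(model_state_dict, ckpt_model):
    # dcp.load restores values in place into the passed-in state dict
    dcp.load(
        state_dict={KEY_MODEL: model_state_dict},
        storage_reader=dcp.FileSystemReader(ckpt_model),
    )
    return model_state_dict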
@@ -467,30 +483,54 @@ def save_sharded_safetensors(
save_directory: str,
metadata: Dict,
max_shard_size: Union[int, str] = "5GB",
lora: bool = False,
):
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
input_state_dict,
filename_pattern=filename_pattern,
max_shard_size=max_shard_size,
)
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
# Save the index
with open(
os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
) as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
if not lora:
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
input_state_dict,
filename_pattern=filename_pattern,
max_shard_size=max_shard_size,
)

filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors}
save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
# Save the index
with open(
os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
) as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)

filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {
tensor: input_state_dict[tensor].contiguous() for tensor in tensors
}
save_file(
shard, os.path.join(save_directory, shard_file), metadata=metadata
)
else:
filename_pattern = ADAPTER_SAFE_WEIGHTS_NAME.replace(
".bin", "{suffix}.bin"
).replace(".safetensors", "{suffix}.safetensors")
state_dict_split = split_torch_state_dict_into_shards(
input_state_dict,
filename_pattern=filename_pattern,
max_shard_size=max_shard_size,
)
filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {
tensor: input_state_dict[tensor].contiguous() for tensor in tensors
}
save_file(
shard, os.path.join(save_directory, shard_file), metadata=metadata
)
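
# A hypothetical usage sketch of save_sharded_safetensors on the LoRA path; the
# tensor names, shapes, and output directory are illustrative only.
def _example_save_lora_adapter(save_dir: str):
    import torch

    example_lora_sd = {
        "model.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(8, 64),
        "model.layers.0.self_attn.q_proj.lora_B.weight": torch.zeros(64, 8),
    }
    os.makedirs(save_dir, exist_ok=True)
    save_sharded_safetensors(
        example_lora_sd,
        save_dir,
        metadata={"format": "pt"},
        lora=True,  # shards follow the adapter_model{suffix}.safetensors pattern
    )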


# --------------------------- SCRIPT -------------------------
@@ -540,14 +580,28 @@ def recover_safetensors_from_dcp(
# get the state_dict
state_dict = loader(checkpoint_dir)

new_state_dict = {}
lora = False
for name, param in state_dict.items():
if "lora_A" in name or "lora_B" in name:
lora = True
if "base_model.model." in name:
name = name.replace("base_model.model.", "", 1)
if "default." in name:
name = name.replace("default.", "", 1)
new_state_dict[name] = param
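# e.g. (illustrative name) "base_model.model.model.layers.0.mlp.gate_proj.lora_A.default.weight"
#   -> strip "base_model.model." -> "model.layers.0.mlp.gate_proj.lora_A.default.weight"
#   -> strip "default."          -> "model.layers.0.mlp.gate_proj.lora_A.weight"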

# recover the original state dict
state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path)
state_dict = recover_original_state_dict_from_checkpoint(
new_state_dict, _name_or_path
)

# save it as a safetensors file
save_sharded_safetensors(
{k: v.contiguous() for k, v in state_dict.items()},
output_dir,
metadata={"format": "pt"},
lora=lora,
)


@@ -17,7 +17,6 @@

# Third Party
from peft import LoraConfig
from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
from torch.distributed._tensor import DTensor

# pylint: disable=import-error
@@ -237,10 +236,6 @@ def __init__(
assert (
lora_config.bias == "none"
), "ScatterMoE is currently unable to handle bias in the LoRA adapters"
assert (
lora_config.target_modules == INCLUDE_LINEAR_LAYERS_SHORTHAND
or INCLUDE_LINEAR_LAYERS_SHORTHAND in lora_config.target_modules
), "ScatterMoe currently only handles lora adapters on all linears."

assert lora_config.init_lora_weights in {
True,
@@ -278,28 +273,8 @@ def __init__(
# - w1: the up_projection.
# - w2: the down_projection.
# - w3 (optional): the gate projection.
self.w1 = ScatteredExperts(
in_features=self.hidden_size,
out_features=self.intermediate_size,
num_experts=self.num_experts,
fan_out=self.top_k if not self.all_to_all else 1,
grouped_out=True,
dtype=dtype,
device=device,
lora_config=lora_config,
)
self.w2 = ScatteredExperts(
in_features=self.intermediate_size,
out_features=self.hidden_size,
num_experts=self.num_experts,
fan_out=1,
grouped_in=True,
dtype=dtype,
device=device,
lora_config=lora_config,
)
if mlp_arch == SCATTERMOE_SPEC_HAS_GATE:
self.w3 = ScatteredExperts(
if not lora_config:
self.w1 = ScatteredExperts(
in_features=self.hidden_size,
out_features=self.intermediate_size,
num_experts=self.num_experts,
@@ -309,6 +284,29 @@
device=device,
lora_config=lora_config,
)
if not lora_config:
self.w2 = ScatteredExperts(
in_features=self.intermediate_size,
out_features=self.hidden_size,
num_experts=self.num_experts,
fan_out=1,
grouped_in=True,
dtype=dtype,
device=device,
lora_config=lora_config,
)
if not lora_config:
if mlp_arch == SCATTERMOE_SPEC_HAS_GATE:
self.w3 = ScatteredExperts(
in_features=self.hidden_size,
out_features=self.intermediate_size,
num_experts=self.num_experts,
fan_out=self.top_k if not self.all_to_all else 1,
grouped_out=True,
dtype=dtype,
device=device,
lora_config=lora_config,
)
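# For reference, a LoraConfig that passes the checks earlier in this constructor
# might look like (values purely illustrative):
#   LoraConfig(r=8, lora_alpha=16, bias="none", init_lora_weights=True,
#              target_modules=[...])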

# referenced from dolomite-engine
def _compute_routing_weights(self, hidden_states: torch.Tensor):
@@ -457,36 +455,39 @@ def forward(self, hidden_states: torch.Tensor):
)

# compute the up projection
out = self.w1(
hidden_states,
sorted_expert_idxs,
sorted_scattered_idxs,
padded_block_idxs,
expert_offsets,
)
out = self.activation(out)

# - if the arch has a separate gate projection
if self.w3:
out *= self.w3(
if hasattr(self, "w1"):
out = self.w1(
hidden_states,
sorted_expert_idxs,
sorted_scattered_idxs,
padded_block_idxs,
expert_offsets,
)
out = self.activation(out)

# - if the arch has a separate gate projection
if hasattr(self, "w3"):
if self.w3:
out *= self.w3(
hidden_states,
sorted_expert_idxs,
sorted_scattered_idxs,
padded_block_idxs,
expert_offsets,
)

# compute the down projection
# - if there is no all-to-all processing, rely on the
#   scattermoe kernel to perform the final scattering
hidden_states = self.w2(
out,
sorted_expert_idxs,
sorted_scattered_idxs,
padded_block_idxs,
expert_offsets,
gates=(None if self.all_to_all else routing_weights),
)
if hasattr(self, "w2"):
hidden_states = self.w2(
out,
sorted_expert_idxs,
sorted_scattered_idxs,
padded_block_idxs,
expert_offsets,
gates=(None if self.all_to_all else routing_weights),
)

# maybe scatter
hidden_states = self._maybe_scatter(