Skip to content

Commit b291ad6

Browse files
authored
[DeepSeek] Enable checkpoint load from HF (#908)
Enable loading weights from HF checkpoint. 1. Download - Added `download.py` to allow user download a HF checkpoint into local disk cache. - Usage: `python download.py {model_id}` 2. Load weights - Added `checkpoint.py` to load tensors from HF cache dir into a model. - The model can be a model chunk or full model. 3. Various code refactor - Moved `ModelArgs` into a separate file, adding DeepSeek config registry 4. Added support for DeepSeek-V2 - Greedy routing - Softmax score function 5. Added example run.py based on [DeepSeek-V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/tree/main), a 16B toy MoE. `torchrun --standalone --nproc-per-node 4 run.py`
1 parent bfbff6f commit b291ad6

File tree

5 files changed

+477
-266
lines changed

5 files changed

+477
-266
lines changed
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import json
8+
import logging
9+
import os
10+
from typing import Dict, Optional, Set, Tuple
11+
12+
import torch
13+
from safetensors import safe_open
14+
15+
from transformers.utils import cached_file
16+
17+
18+
logger = logging.getLogger(__name__)
19+
20+
_DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
21+
22+
23+
def read_weights_from_json(file_path: str) -> Optional[Dict[str, str]]:
24+
try:
25+
with open(file_path, "r") as file:
26+
data = json.load(file)
27+
28+
if "weight_map" in data and isinstance(data["weight_map"], dict):
29+
return data["weight_map"]
30+
else:
31+
logger.info("No 'weight_map' dictionary found in the JSON file.")
32+
return None
33+
except (json.JSONDecodeError, Exception) as e:
34+
logger.info(f"An error occurred while reading the JSON file: {str(e)}")
35+
return None
36+
37+
38+
def get_hf_weight_map_and_path(
39+
model_id: str,
40+
) -> Tuple[Dict[str, str], str]:
41+
"""Get the weight map for a given HF model id and also the cache path for loading the weights"""
42+
try:
43+
index_file = cached_file(model_id, _DEFAULT_SAFETENSOR_FILE_NAME)
44+
except Exception as e:
45+
logger.error(
46+
f"Model `{model_id}` not found in HF cache. "
47+
f"You can download the model using `python download.py {model_id}"
48+
)
49+
raise e
50+
51+
weight_map = read_weights_from_json(index_file)
52+
weight_path = os.path.dirname(index_file)
53+
logger.info(f"Loading weights from: {weight_path}")
54+
return weight_map, weight_path
55+
56+
57+
def get_needed_files(
58+
state_dict: Dict[str, torch.Tensor], weight_map: Dict[str, str]
59+
) -> Set[str]:
60+
needed_files = set()
61+
for param in state_dict.keys():
62+
file = weight_map.get(param)
63+
if file:
64+
needed_files.add(file)
65+
elif param.endswith("weight"):
66+
raise ValueError(
67+
f"Parameter {param} not found in weight map, please check..."
68+
)
69+
logger.info(f"Needed files: {needed_files}")
70+
return needed_files
71+
72+
73+
def load_safetensor_file(
74+
full_path: str, device: torch.device
75+
) -> Dict[str, torch.Tensor]:
76+
tensors = {}
77+
with safe_open(full_path, framework="pt", device=device) as f:
78+
for k in f.keys():
79+
tensors[k] = f.get_tensor(k)
80+
logger.info(f"Loaded {len(tensors)} tensors from {full_path}")
81+
return tensors
82+
83+
84+
def load_safetensor_weights(
85+
model: torch.nn.Module,
86+
weight_map: Dict[str, str],
87+
file_location: str,
88+
device: torch.device,
89+
):
90+
"""
91+
Load safetensor weights into a `nn.Module`.
92+
93+
Args:
94+
model (Module): The PyTorch module to load weights into. It may be a
95+
model chunk or a full model.
96+
weight_map (Dict[str, str]): Mapping of model parameters to file names.
97+
file_location (str): Directory containing the weight files.
98+
device (torch.device): The device to load tensors onto.
99+
"""
100+
model_state_dict = model.state_dict()
101+
needed_files = get_needed_files(model_state_dict, weight_map)
102+
updated_states: Set[str] = set()
103+
104+
for file in needed_files:
105+
full_path = os.path.join(file_location, file)
106+
try:
107+
checkpoint = load_safetensor_file(full_path, "cpu")
108+
except FileNotFoundError:
109+
logger.error(f"File not found: {full_path}")
110+
except Exception as e:
111+
logger.error(f"Error during checkpoint processing of {full_path}: {str(e)}")
112+
113+
matched_keys = set(checkpoint.keys()) & set(model_state_dict.keys())
114+
for key in matched_keys:
115+
# Check shape
116+
if model_state_dict[key].shape != checkpoint[key].shape:
117+
raise ValueError(
118+
f"Shape mismatch for {key}: "
119+
f"model needs {model_state_dict[key].shape}, but "
120+
f"checkpoint has {checkpoint[key].shape}"
121+
)
122+
model_state_dict[key] = checkpoint[key].to(device)
123+
124+
updated_states.update(matched_keys)
125+
126+
missing_keys = set(model_state_dict.keys()) - updated_states
127+
if missing_keys:
128+
raise RuntimeError(
129+
f"Partially updated state dict. Missing parameters: {missing_keys}"
130+
)
131+
132+
model.load_state_dict(model_state_dict, strict=False, assign=True)
133+
logger.info(f"Successfully loaded {len(updated_states)} weights into model")
134+
135+
136+
def load_weights_from_hf(
137+
model: torch.nn.Module,
138+
distribution: str,
139+
device: torch.device,
140+
):
141+
"""
142+
Load the weights from Hugging Face format (index file + multiple safetensor
143+
files), and fill into `model`. Model config is needed b/c we permute
144+
wq and wk weights based on attn heads.
145+
"""
146+
147+
weight_map, weight_path = get_hf_weight_map_and_path(distribution)
148+
149+
load_safetensor_weights(
150+
model,
151+
weight_map,
152+
weight_path,
153+
device,
154+
)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Usage:
8+
# python download.py {model_id}
9+
# Example:
10+
# python download.py deepseek-ai/DeepSeek-V2-Lite
11+
12+
import sys
13+
14+
from transformers import AutoModelForCausalLM
15+
16+
model_id = sys.argv[1]
17+
18+
model = AutoModelForCausalLM.from_pretrained(
19+
model_id,
20+
device_map="auto",
21+
)

0 commit comments

Comments
 (0)