[Experimental Feature] Huggingface model training #919
base: main
Conversation
Hi @junjzhang! Thank you for your pull request and welcome to our community.
Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.
Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged accordingly.
If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Force-pushed from f111334 to 3689b58.
Force-pushed from 3689b58 to c412663.
This PR is super interesting to me! Instead of supporting a ton of individual models, a quick way to enable TorchTitan is to use models loaded from Huggingface and apply the various optimization techniques.
Yes! The weight-loading function in this PR is general enough; the only cost of adapting a new model is implementing the parallelism-applying function and maybe patching some forward functions to enable PP.
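For a rough picture of what that per-model cost looks like, here is a minimal sketch of such a parallelism-applying function. It is not code from this PR: the name parallelize_hf_model is hypothetical, only the signature mirrors the train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) call shown later in this thread, and the FSDP2 import path, the dp_enabled flag, and the "dp" mesh-dim name are assumptions that may differ across torch/torchtitan versions.

# FSDP2 entry point; the import path may differ across torch versions (assumption)
from torch.distributed._composable.fsdp import fully_shard


def parallelize_hf_model(model, world_mesh, parallel_dims, job_config):
    """Hypothetical per-model hook: apply data-parallel sharding to an HF LlamaForCausalLM."""
    if parallel_dims.dp_enabled:  # assumed flag name on torchtitan's ParallelDims
        dp_mesh = world_mesh["dp"]  # assumed mesh dim name
        # Shard each transformer block first, then the root module (FSDP2 style).
        for block in model.model.layers:
            fully_shard(block, mesh=dp_mesh)
        fully_shard(model, mesh=dp_mesh)
    return model

PP support would additionally need the forward-signature patching mentioned above, so that intermediate stages exchange hidden states instead of token ids.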
I see you are almost creating a copy of the entire torchtitan under this folder. Instead, could you reuse existing files and functions (via import) as much as possible?
E.g. I can tell that hf_weights_utils.py, the parallelize/pipeline fns, the toml config, and the test files cannot be directly reused; for the other parts, can we reuse them, including train.py?
Even for the parallelize/pipeline fns, we can have standalone files, but we should depend on functions in torchtitan llama, e.g. https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/parallelize_llama.py, as much as possible.
Reusing code is something I have always kept in mind, and I've tried my best.
I think it would be more graceful if Titan could refactor these functions and extract some common patterns; then I could reuse code as much as possible.
@tianyu-l Hi, as stated before, I've pushed a new commit that tries to reuse Titan's code in dataset.py and parallelize_llama.py. If further reuse is required, I suppose these Titan functions should be refactored to take more arguments or be decomposed into multiple reusable methods. I also ran training (FSDP2 PP2 TP2) to verify the correctness of this commit:
I do think we can reuse more, e.g. pretty much the entire files of dataset.py and run_train.sh, and most of parallelize_llama.py.
I agree we should adjust train.py to support better reuse. This is actually a good exercise for us to improve it. @fegin
Also, please help point out if there are remaining places which differ from current torchtitan that I missed.
Can I ask a general question? What is the motivation for this PR? It definitely makes sense to use Llama pretrained weights, load them into torchtitan, and train with the parallelisms. However, in addition to loading the Huggingface weights, this PR also uses the Huggingface model definition, which requires patching (of forward) and modifications to torchtitan's parallelization code. I have a couple of specific questions:
Thanks!
dataset_path = "../../../tests/assets/c4_test"

[experimental]
context_parallel_degree = 1
I wonder if you have tested whether CP works out of the box? Although I'm not sure what attention ops the HF Llama model uses.
train_state = TrainState()

# load initial checkpoint
checkpoint = CheckpointManager(
Is the torchtitan CheckpointManager still compatible with the HF model?
@wconstab PP issue fixed and tested with accfa1f#diff-f3ae151a2c757861e79894d650697214e593188396134823440371457ed71ed3.
@tianyu-l CP tested with FSDP2 TP2 CP2.
Is there a plan to deduplicate the code from the main TorchTitan? What's the motivation for duplicating
As mentioned before, I need to revise some lines in main() to use AutoModelForCausalLM, make the model compilable, adapt the PP input, load the state dict, etc. Since main() is a very large function, I cannot reuse it. I guess a better way is to decompose the main function into separate common methods, like
@junjzhang Sounds good. I'm refactoring
if weight_map is None:
    partition_path = pretrained_model_path / "model.safetensors"
else:
    partition_path = pretrained_model_path / weight_map[ckpt_state_dict_key]
A KeyError will be raised if a checkpoint key is missing from weight_map. For example, when tie_word_embeddings = True, model.lm_head.weight is not present in weight_map. So, how should we initialize or load the weight in this case? Currently, a KeyError occurs. If we ignore it, the loaded model will be incorrect; specifically, model.lm_head.weight will be set to zero.
Llama-3.2-3B can be used as a model for verification.
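One possible way to handle the tied-embedding case (a minimal sketch, not code from this PR: resolve_partition_path and TIED_KEY_FALLBACKS are hypothetical names, and the exact checkpoint key strings depend on the model) is to fall back to the weight the missing key is tied to, instead of raising or silently leaving it at zero:

from pathlib import Path

# Hypothetical mapping from a missing (tied) key to the key that actually holds its weights.
TIED_KEY_FALLBACKS = {"lm_head.weight": "model.embed_tokens.weight"}


def resolve_partition_path(pretrained_model_path: Path, weight_map, ckpt_state_dict_key):
    if weight_map is None:
        return pretrained_model_path / "model.safetensors", ckpt_state_dict_key
    if ckpt_state_dict_key in weight_map:
        return pretrained_model_path / weight_map[ckpt_state_dict_key], ckpt_state_dict_key
    # Key absent (e.g. tie_word_embeddings = True): load the tied source weight instead.
    fallback_key = TIED_KEY_FALLBACKS.get(ckpt_state_dict_key)
    if fallback_key is not None and fallback_key in weight_map:
        return pretrained_model_path / weight_map[fallback_key], fallback_key
    raise KeyError(f"{ckpt_state_dict_key} not found in weight_map and has no known fallback")

The caller would then copy the tensor stored under the returned fallback key into the missing parameter, so the tied weights stay identical after loading.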
train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
model.to_empty(device=init_device)
with torch.no_grad():
    load_sharded_state_dict_for_model_from_hf(job_config.model.flavor, model)
It also doesn’t properly initialize rotary_emb’s inv_freq, so I had to hack it for now:
inv_freq, attention_scaling = m.model.rotary_emb.rope_init_fn(m.model.rotary_emb.config, device=init_device)
m.model.rotary_emb.inv_freq, m.model.rotary_emb.attention_scaling = inv_freq, attention_scaling
Is there any more elegant approach?
The comparison code:
import os.path

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoConfig, AutoTokenizer, LlamaForCausalLM


def load_hf(name_or_path):
    model = LlamaForCausalLM.from_pretrained(name_or_path, use_flash_attention_2=False, trust_remote_code=True)
    return model


def load_titan_from_hf(name_or_path):
    model_config = AutoConfig.from_pretrained(name_or_path, trust_remote_code=True)
    with torch.device("meta"):
        if hasattr(LlamaForCausalLM, "from_config"):
            model = LlamaForCausalLM.from_config(
                model_config, use_flash_attention_2=False, trust_remote_code=True
            )
        else:
            model = LlamaForCausalLM(model_config)
    model.apply(lambda m: setattr(m, "_is_hf_initialized", False))
    model.to_empty(device="cpu")  # will set all params to zero
    with torch.no_grad():
        # weight-loading utility introduced by this PR (its import is omitted in the original snippet)
        load_sharded_state_dict_for_model_from_hf(name_or_path, model)
    return model


if __name__ == "__main__":
    path_model = ....
    name_or_path = os.path.join(path_model, "Llama-3.1-8B")
    tokenizer = AutoTokenizer.from_pretrained(name_or_path)

    default_dtype = torch.bfloat16
    torch.set_default_dtype(default_dtype)

    model_hf = load_hf(name_or_path)
    model_tt = load_titan_from_hf(name_or_path)

    messages = [{"role": "user", "content": "Create a JavaScript fragment that includes computing the percentage equivalent of an integer value that's conferred as an input. As an example, consider the quantity of 50."}, {"role": "assistant", "content": "Here is a simple implementation in JavaScript using a function that takes an integer and the maximum possible value of that integer as inputs. Here, we are considering the maximum value to be 100.\n\n```javascript\nfunction toPercentage(value, max_value = 100) {\n return (value / max_value) * 100;\n}\n\nconsole.log(toPercentage(50)); //Output: 50\n```\n\nIn the example, the function `toPercentage()` takes `value` and `max_value` as input. It returns the percentage equivalent of the `value` input based on the `max_value` assumed to be 100.\n\nYou can use this function in your webpage, JavaScript app, or even in node.js. You can also easily alter it to handle different scenarios.\n\nJust replace the (50) in `console.log(toPercentage(50));` with any integer you would like to convert into percentage. The function will print the answer to the console.\n\nThis script assumes that the maximum value is 100. If you want to calculate the percentage with a different maximum value, you can use the function as `toPercentage(value, max_value)`. For example, toPercentage(50, 200) would return 25, because 50 is 25% of 200."}]

    input_ids = torch.tensor(tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True), dtype=torch.long)
    input_ids = pad_sequence([input_ids], batch_first=True, padding_value=-100).long()
    attention_mask = input_ids.ne(-100).long()

    with torch.no_grad():
        inputs_embeds_hf = model_hf.model.embed_tokens(input_ids)
        inputs_embeds_tt = model_tt.model.embed_tokens(input_ids)
        print(f"embed same: {torch.allclose(inputs_embeds_hf, inputs_embeds_tt, atol=1e-5)}")

        cache_position = torch.arange(0, 0 + inputs_embeds_hf.shape[1], device=inputs_embeds_hf.device)
        position_ids = cache_position.unsqueeze(0)

        print(f"same inv_freq: {torch.allclose(model_hf.model.rotary_emb.inv_freq, model_tt.model.rotary_emb.inv_freq)}")
        position_embeddings_hf = model_hf.model.rotary_emb(inputs_embeds_hf, position_ids)
        position_embeddings_tt = model_tt.model.rotary_emb(inputs_embeds_tt, position_ids)
        print(f"pos embed-cos same: {torch.allclose(position_embeddings_hf[0], position_embeddings_tt[0], atol=1e-5)}")
        print(f"pos embed-sin same: {torch.allclose(position_embeddings_hf[1], position_embeddings_tt[1], atol=1e-5)}")

        model_output_hf = model_hf(input_ids=input_ids, attention_mask=attention_mask)
        model_output_tt = model_tt(input_ids=input_ids, attention_mask=attention_mask)
        print(f"logits same: {torch.allclose(model_output_hf.logits, model_output_tt.logits)}")

        print("----- re-init rotary_emb inv_freq --------")
        model_tt.model.rotary_emb.inv_freq = model_tt.model.rotary_emb.rope_init_fn(model_tt.model.rotary_emb.config, "cpu")[0]
        print(f"same inv_freq: {torch.allclose(model_hf.model.rotary_emb.inv_freq, model_tt.model.rotary_emb.inv_freq)}")
        position_embeddings_hf = model_hf.model.rotary_emb(inputs_embeds_hf, position_ids)
        position_embeddings_tt = model_tt.model.rotary_emb(inputs_embeds_tt, position_ids)
        print(f"pos embed-cos same: {torch.allclose(position_embeddings_hf[0], position_embeddings_tt[0], atol=1e-5)}")
        print(f"pos embed-sin same: {torch.allclose(position_embeddings_hf[1], position_embeddings_tt[1], atol=1e-5)}")
        model_output_hf = model_hf(input_ids=input_ids, attention_mask=attention_mask)
        model_output_tt = model_tt(input_ids=input_ids, attention_mask=attention_mask)
        print(f"logits same: {torch.allclose(model_output_hf.logits, model_output_tt.logits)}")
Comparison result:
embed same: True
same inv_freq: False
pos embed-cos same: False
pos embed-sin same: False
logits same: False
----- re-init rotary_emb inv_freq --------
same inv_freq: True
pos embed-cos same: True
pos embed-sin same: True
logits same: True
Hi @ericxsun, could you explain why rotary_emb's inv_freq is not initialized correctly? I think inv_freq got the correct initialization when creating the LlamaRotaryEmbedding, and because persistent=False, it won't be affected by load_state_dict.
class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, config: LlamaConfig, device=None):
        super().__init__()
        ...
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)  # init here
        self.register_buffer("inv_freq", inv_freq, persistent=False)  # registered as non-persistent buffer
        self.original_inv_freq = self.inv_freq
        ...
Calling m.to_empty(device=init_device) before loading caused inv_freq to hold random values. It looks like this:
model.model.rotary_emb.inv_freq
Out[9]:
tensor([ 9.7134e-31, 0.0000e+00, 2.6889e-31, 0.0000e+00, 1.1210e-43,
0.0000e+00, 8.9683e-44, 0.0000e+00, 1.2224e-30, 0.0000e+00,
-5.7851e+13, 4.5783e-41, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 9.1084e-44, 0.0000e+00, 1.5715e-30,
0.0000e+00, 1.4957e-30, 0.0000e+00, 0.0000e+00, 0.0000e+00,
4.6243e-44, 0.0000e+00, 6.7608e-30, 0.0000e+00, -6.4284e+13,
4.5783e-41, 2.9147e-43, 0.0000e+00, 1.1210e-43, 0.0000e+00,
1.5482e-30, 0.0000e+00, 1.4013e-45, 0.0000e+00, 0.0000e+00,
0.0000e+00, 4.2039e-45, 0.0000e+00, 8.9129e+04, 1.1600e+00,
4.2039e-45, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 1.3815e-30, 0.0000e+00, 2.2561e-43, 0.0000e+00,
2.6974e-31, 0.0000e+00, 1.6658e-30, 0.0000e+00])
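A slightly more general form of the re-init hack shown above (a sketch only; the helper name reinit_rope_buffers is hypothetical, and it relies on the rope_init_fn/config attributes of LlamaRotaryEmbedding quoted earlier in this thread) is to recompute inv_freq for every rotary-embedding module after to_empty(), since the non-persistent buffer is neither preserved by to_empty() nor restored by load_state_dict:

from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding


def reinit_rope_buffers(model, device):
    # to_empty() leaves buffers as uninitialized memory, and inv_freq is a
    # non-persistent buffer, so load_state_dict will not bring it back; recompute it.
    for module in model.modules():
        if isinstance(module, LlamaRotaryEmbedding):
            inv_freq, attention_scaling = module.rope_init_fn(module.config, device=device)
            module.inv_freq = inv_freq
            module.attention_scaling = attention_scaling
            module.original_inv_freq = module.inv_freq

Called right after model.to_empty(device=init_device), this keeps the RoPE buffers consistent with the model config no matter how many rotary-embedding modules the model contains.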
Hi, as discussed in #903, this PR adds support for training a Llama model from HF directly using AutoModelForCausalLM and for loading safetensors (HF weights) in an online sharding manner.
pytest test_loading_hf_weights.py
Here are my results:
LOG_RANK=7 bash run_train.sh (FSDP 2 - PP 2 - TP 2)
Here are my results: