
[Experimental Feature] Huggingface model training #919


Open
wants to merge 9 commits into base: main

Conversation


@junjzhang junjzhang commented Mar 3, 2025

Hi, as discussed in #903, this PR adds support for training a Llama model from HF directly using “AutoModelForCausalLM”, and for loading safetensors (HF weights) in an online-sharding manner.

  • Test loading safetensors: pytest test_loading_hf_weights.py
    Here are my results:
    (screenshot: Cursor 2025-03-03 21 30 40)
  • Training: LOG_RANK=7 bash run_train.sh (FSDP 2 - PP 2 - TP 2)
    Here are my results:
    (screenshot: Cursor 2025-03-03 22 10 08)
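For context, here is a minimal sketch of the general idea behind loading safetensors in an online-sharding manner. This is illustrative only, not the PR's actual API; it omits the sharded (DTensor) handling and any key remapping, and simply fills parameters shard file by shard file instead of materializing the whole checkpoint at once.

# Illustrative sketch only (not this PR's code).
import json
from pathlib import Path

import torch
from safetensors import safe_open
from transformers import AutoConfig, AutoModelForCausalLM


def load_hf_weights_by_shard(model_path: str):
    config = AutoConfig.from_pretrained(model_path)
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(config)
    model.to_empty(device="cpu")  # real but uninitialized storage

    # Group checkpoint keys by the safetensors shard file that stores them.
    index = json.loads((Path(model_path) / "model.safetensors.index.json").read_text())
    shard_to_keys = {}
    for key, shard in index["weight_map"].items():
        shard_to_keys.setdefault(shard, []).append(key)

    params = dict(model.named_parameters())
    with torch.no_grad():
        for shard, keys in shard_to_keys.items():
            with safe_open(Path(model_path) / shard, framework="pt") as f:
                for key in keys:
                    if key in params:
                        params[key].copy_(f.get_tensor(key))
    # NOTE: non-persistent buffers (e.g. RoPE inv_freq) are not covered by the
    # checkpoint and still need to be re-initialized after to_empty().
    return model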

@facebook-github-bot

Hi @junjzhang!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@junjzhang junjzhang force-pushed the feature_train_hf_models branch from f111334 to 3689b58 Compare March 3, 2025 14:19
@facebook-github-bot facebook-github-bot added the CLA Signed label Mar 3, 2025
@junjzhang junjzhang force-pushed the feature_train_hf_models branch from 3689b58 to c412663 Compare March 3, 2025 14:25
@casper-hansen
Contributor

This PR is super interesting to me! Instead of supporting a ton of individual models, a quick way to enable TorchTitan is to use models loaded from Huggingface and apply the various optimization techniques.

@junjzhang
Author

This PR is super interesting to me! Instead of supporting a ton of individual models, a quick way to enable TorchTitan is to use models loaded from Huggingface and apply the various optimization techniques.

Yes! The weight-loading function in this PR is general enough; the only cost of adapting a new model is implementing the parallelism-applying function and maybe patching some forward functions to enable PP.
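To illustrate what such a parallelism-applying function could look like for an HF Llama, here is a rough sketch using the public torch.distributed.tensor.parallel API. The module names follow the HF Llama implementation; this is not the PR's actual code, and it covers TP only (no FSDP or PP).

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


def apply_tp_to_hf_llama(model, tp_mesh):
    # Shard each transformer block's attention and MLP projections over tp_mesh.
    for block in model.model.layers:
        layer_plan = {
            "self_attn.q_proj": ColwiseParallel(),
            "self_attn.k_proj": ColwiseParallel(),
            "self_attn.v_proj": ColwiseParallel(),
            "self_attn.o_proj": RowwiseParallel(),
            "mlp.gate_proj": ColwiseParallel(),
            "mlp.up_proj": ColwiseParallel(),
            "mlp.down_proj": RowwiseParallel(),
        }
        parallelize_module(block, tp_mesh, layer_plan)
    return model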

Contributor

@tianyu-l tianyu-l left a comment


I see you are almost creating a copy of the entire torchtitan under this folder.
Instead, could you reuse existing files and functions (via import) as much as possible?
E.g., I can tell that hf_weights_utils.py, the parallelize/pipeline fns, the toml config, and the test files cannot be directly reused; for the other parts, can we reuse them, including train.py?

Even for the parallelize/pipeline fns, we can have standalone files, but we should depend on the functions in torchtitan llama, e.g. https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/parallelize_llama.py, as much as possible.

@tianyu-l tianyu-l linked an issue Mar 4, 2025 that may be closed by this pull request
@junjzhang
Author

junjzhang commented Mar 4, 2025

I see you are almost creating a copy of the entire torchtitan under this folder. Instead, could you reuse existing files and functions (via import) as much as possible? E.g., I can tell that hf_weights_utils.py, the parallelize/pipeline fns, the toml config, and the test files cannot be directly reused; for the other parts, can we reuse them, including train.py?

Even for the parallelize/pipeline fns, we can have standalone files, but we should depend on the functions in torchtitan llama, e.g. https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/parallelize_llama.py, as much as possible.

Reusing code is something I always keep in mind, and I've tried my best.

  • As for train.py: since the original torchtitan version writes everything in a single main function, it's hard to reuse because I have to change some lines inside that function.
  • As for the parallelize/pipeline fns: it's also hard to reuse the original code for the same reason; I have to change the parallel plan, but it is hard-coded in titan's function.
  • As for the Dataset: for the same reason, I have to return extra position_ids for HF's Llama (see the sketch after this comment). Maybe I could reuse more of the dataset code by monkey-patching.
  • As for the loss: same reason.

I think it would be more graceful if titan refactored these functions and extracted some common patterns; then I could reuse as much code as possible.
Anyway, I'll take a look to see if I can reuse more code here.
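For illustration, here is a minimal sketch (hypothetical names, not the PR's dataset code) of what returning the extra position_ids expected by the HF Llama forward signature could look like:

import torch


def collate_with_position_ids(token_ids):
    # token_ids: LongTensor of shape (batch, seq_len + 1)
    input_ids = token_ids[:, :-1]
    labels = token_ids[:, 1:]
    # HF Llama takes explicit position_ids; build them per sample.
    position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).expand_as(input_ids)
    return {"input_ids": input_ids, "position_ids": position_ids}, labels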

@junjzhang
Author

junjzhang commented Mar 4, 2025


@tianyu-l Hi, as stated before, I've pushed a new commit that tries to reuse Titan's code in dataset.py and parallelize_llama.py. If further reuse is required, I suppose these Titan functions should be refactored to take more arguments or be decomposed into multiple reusable methods.

I also ran training (FSDP2 PP2 TP2) to check the correctness of this commit:
(screenshot: Cursor 2025-03-04 16 55 01)

@junjzhang junjzhang requested a review from tianyu-l March 5, 2025 07:59
Contributor

@tianyu-l tianyu-l left a comment


I do think we can reuse more, e.g. pretty much the entire files dataset.py and run_train.sh, and most of parallelize_llama.py.

I agree we should adjust train.py to support better reuse. This is actually a good exercise for us to improve it. @fegin

Also, please point out any remaining places that differ from current torchtitan that I may have missed.

@junjzhang
Author

@tianyu-l Hi, I've updated the readme in 91838de and replied to all your comments. Please have a look.

@junjzhang junjzhang requested a review from tianyu-l March 5, 2025 11:46
@tianyu-l tianyu-l requested a review from fegin March 6, 2025 01:44
@wconstab
Contributor

wconstab commented Mar 6, 2025

Can I ask a general question? What is the motivation for this PR? It definitely makes sense to use Llama pretrained weights, load them into torchtitan, and train with its parallelisms.

However, in addition to loading the Hugging Face weights, this PR also uses the Hugging Face model definition, which requires patching (of forward) and modifications to torchtitan's parallelization code. I have a couple of specific questions:

  1. Is there a reason you prefer to use the HF llama3 model definition code instead of the torchtitan llama3 code? (If we could load HF weights into torchtitan llama3, would that be just as good?)
  2. Since you are using HF model code and HF model weights, why is there a need for customized save/load state-dict features?

Thanks!

dataset_path = "../../../tests/assets/c4_test"

[experimental]
context_parallel_degree = 1
Contributor


I wonder if you have tested whether CP works out of the box? Although I'm not sure which attention ops the HF Llama model uses.

train_state = TrainState()

# load initial checkpoint
checkpoint = CheckpointManager(
Contributor


Is the torchtitan CheckpointManager still compatible with the HF model?

@junjzhang
Author

@junjzhang junjzhang requested a review from wconstab March 6, 2025 05:53
@junjzhang
Author

junjzhang commented Mar 6, 2025

@tianyu-l CP tested with FSDP2 TP2 CP2:
(screenshot: CleanShot 2025-03-06 at 14 39 14@2x)

@junjzhang junjzhang requested a review from tianyu-l March 6, 2025 06:34
@fegin
Contributor

fegin commented Mar 7, 2025

Is there a plan to deduplicate the code from the main TorchTitan? What's the motivation for duplicating main.py or train()? Is it because of state_dict loading? If so, we can discuss how to make the checkpointer support this feature.

@junjzhang
Author

Is there a plan to deduplicate the code from the main TorchTitan? What's the motivation for duplicating main.py or train()? Is it because of state_dict loading? If so, we can discuss how to make the checkpointer support this feature.

As mentioned before, I need to revise some lines of code in main() to use AutoModelForCausalLM, make the model compilable, adapt the PP input, load the state dict, etc. Since main is a huge function, I cannot reuse it. I guess a better way would be to decompose the main function into separate common methods, like build_pp_models and run_forward_and_backward; refactoring is needed here to make it general enough to reuse. I'd like to discuss with you how to refactor titan's methods so that the HF model training feature can be added with minimal new code.
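Purely as an illustration of the kind of decomposition meant here (all function names are hypothetical, not torchtitan's API), a toy training script split into small, overridable steps might look like this:

import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Stand-in for e.g. AutoModelForCausalLM.from_config(...) in an HF variant.
    return nn.Linear(16, 16)


def build_optimizer(model: nn.Module) -> torch.optim.Optimizer:
    return torch.optim.AdamW(model.parameters(), lr=1e-3)


def train_step(model: nn.Module, optimizer: torch.optim.Optimizer, batch: torch.Tensor) -> float:
    loss = model(batch).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()


def main(num_steps: int = 3) -> None:
    # Each step is its own function, so a variant (HF model, different loss,
    # PP-specific forward) can override one piece instead of copying main().
    model = build_model()
    optimizer = build_optimizer(model)
    for _ in range(num_steps):
        print(train_step(model, optimizer, torch.randn(4, 16)))


if __name__ == "__main__":
    main()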

@fegin
Contributor

fegin commented Mar 10, 2025

@junjzhang Sounds good. I'm refactoring train.py. Let's discuss how we can make it more general. I don't expect HF models or other models to just adopt the original train.py, even after the refactor, but I hope we can at least reuse as much as we can. My next step is to land the metrics refactor PR, #945, and publish the next refactor PR on Mar 10.

if weight_map is None:
    partition_path = pretrained_model_path / "model.safetensors"
else:
    partition_path = pretrained_model_path / weight_map[ckpt_state_dict_key]

A KeyError will be raised if a checkpoint key is missing from weight_map. For example, when tie_word_embeddings = True, model.lm_head.weight is not present in weight_map. So, how should we initialize or load the weight in this case? Currently, a KeyError occurs. If we ignore it, the loaded model will be incorrect—specifically, model.lm_head.weight will be set to zero.

Llama-3.2-3B can be used as a model for verification.
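One possible workaround, sketched below (not this PR's code; key names follow the standard HF LlamaForCausalLM checkpoint layout): when the config ties word embeddings and the lm_head key is absent from weight_map, fall back to the embedding weight instead of raising.

def resolve_ckpt_key(key, weight_map, config):
    # Map a model state-dict key to the checkpoint key to load it from.
    if key in weight_map:
        return key
    if getattr(config, "tie_word_embeddings", False) and key.endswith("lm_head.weight"):
        # With tied embeddings, lm_head shares its weight with embed_tokens,
        # which is the tensor actually stored in the checkpoint.
        return "model.embed_tokens.weight"
    raise KeyError(f"{key} not found in weight_map")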

train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
model.to_empty(device=init_device)
with torch.no_grad():
    load_sharded_state_dict_for_model_from_hf(job_config.model.flavor, model)

It also doesn’t properly initialize rotary_emb’s inv_freq, so I had to hack it for now:

inv_freq, attention_scaling = m.model.rotary_emb.rope_init_fn(m.model.rotary_emb.config, device=init_device)  
m.model.rotary_emb.inv_freq, m.model.rotary_emb.attention_scaling = inv_freq, attention_scaling  

Is there a more elegant approach?

@airlsyn airlsyn Mar 12, 2025


The comparison code:

import os.path

import torch
from torch.nn.utils.rnn import pad_sequence

from transformers import AutoConfig, AutoTokenizer, LlamaForCausalLM

# load_sharded_state_dict_for_model_from_hf is assumed to come from this PR's
# hf_weights_utils and to be importable here.


def load_hf(name_or_path):
    model = LlamaForCausalLM.from_pretrained(name_or_path, use_flash_attention_2=False, trust_remote_code=True)

    return model


def load_titan_from_hf(name_or_path):
    model_config = AutoConfig.from_pretrained(name_or_path, trust_remote_code=True)

    with torch.device("meta"):
        if hasattr(LlamaForCausalLM, "from_config"):
            model = LlamaForCausalLM.from_config(
                model_config, use_flash_attention_2=False, trust_remote_code=True
            )
        else:
            model = LlamaForCausalLM(model_config)

        model.apply(lambda m: setattr(m, "_is_hf_initialized", False))

    model.to_empty(device="cpu")  # allocates real but uninitialized storage; params are NOT zeroed or meaningfully initialized
    with torch.no_grad():
        load_sharded_state_dict_for_model_from_hf(name_or_path,  model)

    return model


if __name__ == "__main__":
    path_model = ....

    name_or_path = os.path.join(path_model, "Llama-3.1-8B")

    tokenizer = AutoTokenizer.from_pretrained(name_or_path)

    default_dtype = torch.bfloat16
    torch.set_default_dtype(default_dtype)

    model_hf = load_hf(name_or_path)
    model_tt = load_titan_from_hf(name_or_path)

    messages = [{"role": "user", "content": "Create a JavaScript fragment that includes computing the percentage equivalent of an integer value that's conferred as an input. As an example, consider the quantity of 50."}, {"role": "assistant", "content": "Here is a simple implementation in JavaScript using a function that takes an integer and the maximum possible value of that integer as inputs. Here, we are considering the maximum value to be 100.\n\n```javascript\nfunction toPercentage(value, max_value = 100) {\n  return (value / max_value) * 100;\n}\n\nconsole.log(toPercentage(50)); //Output: 50\n```\n\nIn the example, the function `toPercentage()` takes `value` and `max_value` as input. It returns the percentage equivalent of the `value` input based on the `max_value` assumed to be 100.\n\nYou can use this function in your webpage, JavaScript app, or even in node.js. You can also easily alter it to handle different scenarios.\n\nJust replace the (50) in `console.log(toPercentage(50));` with any integer you would like to convert into percentage. The function will print the answer to the console.\n\nThis script assumes that the maximum value is 100. If you want to calculate the percentage with a different maximum value, you can use the function as `toPercentage(value, max_value)`. For example, toPercentage(50, 200) would return 25, because 50 is 25% of 200."}]

    input_ids = torch.tensor(tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True), dtype=torch.long)
    input_ids = pad_sequence([input_ids], batch_first=True, padding_value=-100).long()
    attention_mask = input_ids.ne(-100).long()

    with torch.no_grad():
        inputs_embeds_hf = model_hf.model.embed_tokens(input_ids)
        inputs_embeds_tt = model_tt.model.embed_tokens(input_ids)

        print(f"embed same: {torch.allclose(inputs_embeds_hf, inputs_embeds_tt, atol=1e-5)}")

        cache_position = torch.arange(0, 0 + inputs_embeds_hf.shape[1], device=inputs_embeds_hf.device)
        position_ids = cache_position.unsqueeze(0)

        print(f"same inv_freq: {torch.allclose(model_hf.model.rotary_emb.inv_freq, model_tt.model.rotary_emb.inv_freq)}")
        position_embeddings_hf = model_hf.model.rotary_emb(inputs_embeds_hf, position_ids)
        position_embeddings_tt = model_tt.model.rotary_emb(inputs_embeds_tt, position_ids)

        print(f"pos embed-cos same: {torch.allclose(position_embeddings_hf[0], position_embeddings_tt[0], atol=1e-5)}")
        print(f"pos embed-sin same: {torch.allclose(position_embeddings_hf[1], position_embeddings_tt[1], atol=1e-5)}")

        model_output_hf = model_hf(input_ids=input_ids, attention_mask=attention_mask)
        model_output_tt = model_tt(input_ids=input_ids, attention_mask=attention_mask)
        print(f"logits same: {torch.allclose(model_output_hf.logits, model_output_tt.logits)}")

        print(f"----- re-init rotary_emb inv_freq --------")
        model_tt.model.rotary_emb.inv_freq = model_tt.model.rotary_emb.rope_init_fn(model_tt.model.rotary_emb.config, "cpu")[0]
        print(f"same inv_freq: {torch.allclose(model_hf.model.rotary_emb.inv_freq, model_tt.model.rotary_emb.inv_freq)}")
        position_embeddings_hf = model_hf.model.rotary_emb(inputs_embeds_hf, position_ids)
        position_embeddings_tt = model_tt.model.rotary_emb(inputs_embeds_tt, position_ids)

        print(f"pos embed-cos same: {torch.allclose(position_embeddings_hf[0], position_embeddings_tt[0], atol=1e-5)}")
        print(f"pos embed-sin same: {torch.allclose(position_embeddings_hf[1], position_embeddings_tt[1], atol=1e-5)}")

        model_output_hf = model_hf(input_ids=input_ids, attention_mask=attention_mask)
        model_output_tt = model_tt(input_ids=input_ids, attention_mask=attention_mask)
        print(f"logits same: {torch.allclose(model_output_hf.logits, model_output_tt.logits)}")


Comparison result:

embed same: True
same inv_freq: False
pos embed-cos same: False
pos embed-sin same: False
logits same: False
----- re-init rotary_emb inv_freq --------
same inv_freq: True
pos embed-cos same: True
pos embed-sin same: True
logits same: True

Contributor

@xingchensong xingchensong Mar 20, 2025


Hi @ericxsun, could you explain why rotary_emb's inv_freq is not initialized correctly? I think inv_freq gets the correct initialization when the LlamaRotaryEmbedding is created, and because persistent=False, it won't be affected by load_state_dict.

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, config: LlamaConfig, device=None):
        super().__init__()
        ...
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)  # init here
        self.register_buffer("inv_freq", inv_freq, persistent=False)  # registered as non-persistent buffer
        self.original_inv_freq = self.inv_freq
        ...


@airlsyn airlsyn Mar 25, 2025


m.to_empty(device=init_device) is called before loading, and to_empty allocates uninitialized storage; since inv_freq is a non-persistent buffer, load_state_dict does not restore it either, so it ends up with arbitrary values. It looks like this:

model.model.rotary_emb.inv_freq
Out[9]: 
tensor([ 9.7134e-31,  0.0000e+00,  2.6889e-31,  0.0000e+00,  1.1210e-43,
         0.0000e+00,  8.9683e-44,  0.0000e+00,  1.2224e-30,  0.0000e+00,
        -5.7851e+13,  4.5783e-41,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  9.1084e-44,  0.0000e+00,  1.5715e-30,
         0.0000e+00,  1.4957e-30,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         4.6243e-44,  0.0000e+00,  6.7608e-30,  0.0000e+00, -6.4284e+13,
         4.5783e-41,  2.9147e-43,  0.0000e+00,  1.1210e-43,  0.0000e+00,
         1.5482e-30,  0.0000e+00,  1.4013e-45,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  4.2039e-45,  0.0000e+00,  8.9129e+04,  1.1600e+00,
         4.2039e-45,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  1.3815e-30,  0.0000e+00,  2.2561e-43,  0.0000e+00,
         2.6974e-31,  0.0000e+00,  1.6658e-30,  0.0000e+00])
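A minimal sketch of one way to avoid this (assuming an HF LlamaForCausalLM; not this PR's code): recompute the non-persistent RoPE buffers after to_empty() and weight loading, since load_state_dict does not cover them.

import torch


@torch.no_grad()
def reinit_rope_buffers(model, device="cpu"):
    # inv_freq is registered with persistent=False, so it is neither stored in
    # the checkpoint nor restored by load_state_dict; recompute it explicitly.
    rotary_emb = model.model.rotary_emb
    inv_freq, attention_scaling = rotary_emb.rope_init_fn(rotary_emb.config, device)
    rotary_emb.inv_freq = inv_freq
    rotary_emb.attention_scaling = attention_scaling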

Labels
CLA Signed
Development

Successfully merging this pull request may close these issues.

[Possible PR discuss] Will a PR of training HF model be welcomed?
9 participants