
feat: add support for tensor parallel training workflow with accelerate #34194

Merged: 8 commits into huggingface:main on Feb 18, 2025

Conversation

@kmehant (Contributor) commented Oct 16, 2024

What does this PR do?

1. Adds an `apply_tensor_parallel` API to apply a TP plan to Llama and Granite models (already merged into HF/transformers).
2. Introduces a user-facing `tp_size` argument, consumed downstream by accelerate (see huggingface/accelerate#3173).
3. Allows for e2e TP training (see the usage sketch below).

Please review in conjunction with huggingface/accelerate#3173

Fixes #32470
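To make the workflow concrete, here is a minimal sketch of how training with `tp_size` is intended to be used, assuming a single node with 4 GPUs launched with something like `torchrun --nproc-per-node 4 train.py` (the launch command, dummy dataset, and output directory are assumptions for illustration, not prescribed by this PR):

```python
# Sketch of e2e TP training via the new tp_size argument
# (requires accelerate > 1.3.0; tp_size shards the model across the 4 GPUs).
from datasets import Dataset
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

model = AutoModelForCausalLM.from_pretrained("codellama/CodeLlama-7b-hf")

# Placeholder dataset: one pre-tokenized example, just to make the sketch runnable.
train_ds = Dataset.from_dict({"input_ids": [[1, 2, 3, 4]], "labels": [[1, 2, 3, 4]]})

args = TrainingArguments(
    output_dir="out",                 # placeholder
    per_device_train_batch_size=1,
    gradient_checkpointing=True,
    tp_size=4,                        # new user-facing argument introduced by this PR
)

Trainer(model=model, args=args, train_dataset=train_ds).train()
```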

Results

Significant improvements in both memory and throughput are seen compared against single-GPU training and against FSDP, across different settings (gradient checkpointing on/off) and context lengths.

Note: the effective tokens/sec for FSDP is the per-GPU figure multiplied by the parallel factor (the number of GPUs/devices engaged in distributed training), whereas with TP all GPUs cooperate on the same batch, so the per-GPU figure is already the effective throughput. For example, in the 8192-context run with checkpointing below, FSDP's 2256.896 tokens/sec/GPU corresponds to roughly 4 × 2257 ≈ 9028 effective tokens/sec versus 5935.5 for TP. Judged on effective throughput alone, FSDP can therefore come out ahead; however, this may be compensated by increasing the batch size to use the memory that TP frees up.

Done on two models

  1. ibm-granite/granite-8b-code-base-128k
  2. codellama/CodeLlama-7b-hf

The tables below show the max CUDA memory and throughput for various configurations, demonstrating the potential of the TP support contributed in this PR. There are gains in both memory and throughput.

**ibm-granite/granite-8b-code-base-128k**

| Method | # of GPUs | Context Length | Batch Size | Grad Checkpointing | CUDA Max Mem (GiB) | Tokens/Sec/GPU |
|---|---|---|---|---|---|---|
| Single GPU (non-distributed) | 1 | 8192 | 1 | FALSE | OOM | NA |
| FSDP | 4 | 8192 | 1 | FALSE | OOM | NA |
| TP (this PR) | 4 | 8192 | 1 | FALSE | 52.4 | 7675.4 |
| Single GPU (non-distributed) | 1 | 8192 | 1 | TRUE | OOM | NA |
| FSDP | 4 | 8192 | 1 | TRUE | 29.975586 | 2256.896 |
| TP (this PR) | 4 | 8192 | 1 | TRUE | 26.5 | 5935.5 |
| Single GPU (non-distributed) | 1 | 16384 | 1 | FALSE | OOM | NA |
| FSDP | 4 | 16384 | 1 | FALSE | OOM | NA |
| TP (this PR) | 4 | 16384 | 1 | FALSE | OOM | NA |
| Single GPU (non-distributed) | 1 | 16384 | 1 | TRUE | OOM | NA |
| FSDP | 4 | 16384 | 1 | TRUE | 36.8 | 2084.864 |
| TP (this PR) | 4 | 16384 | 1 | TRUE | 33.5 | 5692.5 |

**codellama/CodeLlama-7b-hf**

| Method | # of GPUs | Context Length | Batch Size | Grad Checkpointing | CUDA Max Mem (GiB) | Tokens/Sec/GPU |
|---|---|---|---|---|---|---|
| Single GPU (non-distributed) | 1 | 8192 | 1 | FALSE | OOM | NA |
| FSDP | 4 | 8192 | 1 | FALSE | 70.7 | 3560 |
| TP (this PR) | 4 | 8192 | 1 | FALSE | 42.8 | 9216 |
| Single GPU (non-distributed) | 1 | 8192 | 1 | TRUE | 75.3 | 2849 |
| FSDP | 4 | 8192 | 1 | TRUE | 26.4 | 5957 |
| TP (this PR) | 4 | 8192 | 1 | TRUE | 21.4 | 7125 |
| Single GPU (non-distributed) | 1 | 16384 | 1 | FALSE | OOM | NA |
| FSDP | 4 | 16384 | 1 | FALSE | OOM | NA |
| TP (this PR) | 4 | 16384 | 1 | FALSE | OOM | NA |
| Single GPU (non-distributed) | 1 | 16384 | 1 | TRUE | 75.3 | 2599 |
| FSDP | 4 | 16384 | 1 | TRUE | 30.1 | 2433 |
| TP (this PR) | 4 | 16384 | 1 | TRUE | 26.6 | 6873 |

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

I have cycles to bring further improvements on top of this PR toward full PyTorch TP support in HF. Looking forward to it. Thank you.


@kwen2501 (Contributor):

Such timing! I have a similar thought here. Shall we collaborate?

@kmehant (Contributor Author) commented Oct 17, 2024

@kwen2501 Absolutely. Please let me know how you would like to take this forward. Thank you.

@kmehant kmehant force-pushed the tp branch 3 times, most recently from cce95b6 to c3fe4bf Compare December 13, 2024 21:03
@kmehant kmehant changed the title feat: add support for tensor parallel using Pytorch 2.0 feat: add support for tensor parallel using Pytorch Dec 13, 2024
@kmehant kmehant force-pushed the tp branch 2 times, most recently from 0a4fe4e to e79f4d1 Compare January 23, 2025 08:55
@kmehant kmehant changed the title feat: add support for tensor parallel using Pytorch feat: add support for tensor parallel training workflow with accelerate Jan 29, 2025
@kmehant (Contributor Author) commented Jan 29, 2025

@ArthurZucker @muellerzr Since the accelerate PR (huggingface/accelerate#3173) is merged, requesting review and merge of this PR, which enables a complete e2e training workflow using tensor parallelism. Thank you.

@SunMarc (Member) left a comment:

SGTM! Could you also update the docs (Efficient Training on Multiple GPUs doc / Trainer doc)?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kmehant (Contributor Author) commented Feb 11, 2025

@SunMarc I have added some documentation. I will also add a documentation page in accelerate and reference it back here in a separate PR; for now I have kept the documentation in this PR self-contained for merge and use. Thanks.

cc: @muellerzr

@kmehant kmehant force-pushed the tp branch 2 times, most recently from 5c065ac to 7e8a2c2 Compare February 11, 2025 12:00
@kmehant (Contributor Author) commented Feb 11, 2025

@SunMarc The failing tests look unrelated to this PR and seem to have appeared after I rebased onto the main branch. Thank you.

@kmehant (Contributor Author) commented Feb 11, 2025

@SunMarc @muellerzr I have a follow-up PR as well: #36132, FYA. Thank you.

@kmehant (Contributor Author) commented Feb 11, 2025

@SunMarc For some reason, even recently merged commits are failing these two CI tests.

@kmehant kmehant force-pushed the tp branch 2 times, most recently from d073f05 to 177bd94 Compare February 12, 2025 12:20
@muellerzr (Contributor) left a comment:

This looks great! One small comment I'd like addressed so users aren't confused by things breaking (and so we can detect problems early), and then it's good to go on my end.

Comment on lines 1985 to 2000

```python
if self.tp_size > 1:
    # Signal accelerate to enable tensor parallelism and pass the TP degree along.
    os.environ["ACCELERATE_USE_TP"] = "true"
    os.environ["TP_SIZE"] = str(self.tp_size)
```
@muellerzr (Contributor):

Here we should also guard on the accelerate version required to use TP; can you add this please? :)

@kmehant (Contributor Author):

I have added this guard. Thank you for noticing it.

@kmehant (Contributor Author):

Can we push for a merge? Thank you.

```python
if self.args.tp_size > 1:
    self.is_tp_enabled = True
    # Guard on the accelerate version: TP support landed after accelerate 1.3.0.
    if version.parse(accelerate_version) > version.parse("1.3.0"):
        args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size)
```
@ArthurZucker (Collaborator):

I am not familiar with this API, so we need some documentation about what this uses under the hood!
Also, could we check whether the model supports TP? Or is that not even needed?

@kmehant (Contributor Author):

> Also we could check if model supports TP? or is it not even needed?

Hi @ArthurZucker, we check this in accelerate here: https://github.com/huggingface/accelerate/blob/526925b48c07d997cdd9bf5911f659ca45778915/src/accelerate/accelerator.py#L1511
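For readers following along, here is a hypothetical sketch of what such a support check amounts to (the real check lives at the accelerate link above; `assert_model_supports_tp` and its error message are illustrative, not actual API):

```python
def assert_model_supports_tp(model) -> None:
    # Models advertise TP support via a tensor parallel plan on their config;
    # if no plan is defined, TP cannot be applied to this architecture.
    plan = getattr(model.config, "base_model_tp_plan", None)
    if plan is None:
        raise NotImplementedError(
            f"{model.__class__.__name__} does not define a tensor parallel plan, "
            "so it cannot be trained with tp_size > 1."
        )
```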

@kmehant (Contributor Author):

@ArthurZucker

> I am not familiar with this API, so we need some documentation about what this uses under the hood!

Do you want me to add that as a comment above this piece of code, or could you point me to the place in the HF docs where you would like it added? Thank you.

@ArthurZucker (Collaborator):

It can be added to the documentation for the TP feature!

@ArthurZucker (Collaborator):

Interesting, though; I did not know accelerate already supported it!

@kmehant (Contributor Author):

@ArthurZucker We added the support to accelerate in this PR: huggingface/accelerate#3173

@ArthurZucker (Collaborator):

Very much welcome otherwise!!!

@ArthurZucker (Collaborator) left a comment:

LGTM otherwise. Some text in the TP documentation saying that training with it is supported would be nice as well! Documenting features is key!


```
tp_size (`int`, *optional*):
    Use tp_size to enable PyTorch tensor parallelism. Set a value greater than 1 to activate TP. The same is
    used to prepare device mesh internally. Requires accelerate>1.3.0.
```
@ArthurZucker (Collaborator):

We should explain that this uses the model's `base_tp_plan` and is not available for all models, etc.!
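For context, a TP plan is a mapping from module-name patterns to partitioning styles, defined on each model's config in transformers (referred to above as `base_tp_plan`). As a rough, abridged illustration of what such a plan looks like for a Llama-style model (consult the model configs for the authoritative version):

```python
# Abridged illustration of a Llama-style tensor parallel plan.
# "colwise"/"rowwise" map onto PyTorch's ColwiseParallel/RowwiseParallel styles.
base_model_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.k_proj": "colwise",
    "layers.*.self_attn.v_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.up_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}
```

Models without such a plan cannot be sharded this way, which is why the docs should call out that the feature is not available for all models.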

@kmehant (Contributor Author):

@ArthurZucker Added this fix, thank you.

@kmehant (Contributor Author) commented Feb 13, 2025

@ArthurZucker I have added and updated docs wherever needed for TP; requesting your review. Thank you.

@kmehant kmehant force-pushed the tp branch 2 times, most recently from 33f5ad2 to ace502d Compare February 13, 2025 17:01
@kmehant (Contributor Author) commented Feb 13, 2025

@ArthurZucker The failing tests do not seem to be related to this PR.

@kmehant (Contributor Author) commented Feb 18, 2025

@ArthurZucker @muellerzr @SunMarc Can we go ahead with the merge?

@SunMarc SunMarc merged commit c3ba533 into huggingface:main Feb 18, 2025
23 checks passed
@bursteratom (Contributor):

Hi @kmehant, thank you very much for the much-needed PR! Just to clarify: aside from setting `tp_size`, which is passed into the transformers Trainer, do you also need to set `tp_plan="auto"` in `AutoModel.from_pretrained(...)`?

@kmehant (Contributor Author) commented Feb 19, 2025

@bursteratom Setting `tp_plan="auto"` is not needed, since TP is applied to the model internally by accelerate when you use it through transformers with `tp_size`.
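To spell out the two entry points (a sketch based on my reading of the thread; the `tp_plan="auto"` path is transformers' load-time sharding and is independent of this PR):

```python
from transformers import AutoModelForCausalLM, TrainingArguments

# Training path (this PR): load the model normally and pass tp_size in the Trainer args;
# accelerate applies the model's TP plan internally during Trainer setup.
model = AutoModelForCausalLM.from_pretrained("codellama/CodeLlama-7b-hf")
train_args = TrainingArguments(output_dir="out", tp_size=2)  # "out" is a placeholder

# Inference-style path (separate feature, not needed here): shard while loading.
sharded_model = AutoModelForCausalLM.from_pretrained(
    "codellama/CodeLlama-7b-hf", tp_plan="auto"
)
```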

@bursteratom (Contributor) commented Feb 20, 2025

@kmehant Thank you for your response! I am asking because I ran into `RuntimeError: aten._foreach_norm.Scalar: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!` when attempting to do TP with accelerate. I only set `tp_size=2` in the Trainer kwargs and did not set `tp_plan="auto"` at all. I am wondering if there are some additional steps that I missed? (Testing with codellama/CodeLlama-7b-hf, btw.)

@kmehant (Contributor Author) commented Feb 20, 2025

@bursteratom Right; apologies for not covering limitations in this PR. Essentially, this is happening due to gradient clipping, which is not supported yet (a PR is in place to fix this: #36132). A workaround is to set `max_grad_norm` to -1 until that PR lands.
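Concretely, the workaround looks like this (a minimal sketch; `max_grad_norm=-1` simply disables the clipping step until #36132 lands):

```python
from transformers import TrainingArguments

train_args = TrainingArguments(
    output_dir="out",   # placeholder
    tp_size=2,
    max_grad_norm=-1,   # disable gradient clipping; not yet supported under TP (see #36132)
)
```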

cc: @kwen2501 @muellerzr @SunMarc

@bursteratom (Contributor):

@kmehant Thank you so much for the tip and your amazing work! Looking forward to your fix getting merged!

@iMountTai iMountTai mentioned this pull request Feb 20, 2025
@bursteratom (Contributor):

@kmehant Apologies for the repeated ping. I was able to train with tensor parallel successfully in accelerate, but ran into `RuntimeError: Attempted to access the data pointer on an invalid python storage.` when attempting to save the trained model. Any chance you know why this is happening?

(screenshot of the stack trace attached)

zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Feb 21, 2025
…te (huggingface#34194)

* feat: add support for tensor parallel flow using accelerate

Signed-off-by: Mehant Kammakomati <[email protected]>

* fix: add tp degree to env variable

Signed-off-by: Mehant Kammakomati <[email protected]>

* fix: add version check for accelerate to allow TP

Signed-off-by: Mehant Kammakomati <[email protected]>

* docs: tensor parallelism

Signed-off-by: Mehant Kammakomati <[email protected]>

* nit: rename plugin name

Signed-off-by: Mehant Kammakomati <[email protected]>

* fix: guard accelerate version before allow tp

Signed-off-by: Mehant Kammakomati <[email protected]>

* docs: add more docs and updates related to TP

Signed-off-by: Mehant Kammakomati <[email protected]>

---------

Signed-off-by: Mehant Kammakomati <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
Successfully merging this pull request may close these issues: Enhancing Hugging Face Models with Tensor Parallelism for Large-Scale Model Support 🚀