
feat: support tensor parallel & Data loader #3173

Merged: 4 commits into huggingface:main on Jan 29, 2025
Conversation

@kmehant (Contributor) commented Oct 16, 2024

What does this PR do?

  1. Implements TorchTensorParallelPlugin to support TP with PyTorch 2.0 (a sketch of the underlying PyTorch TP primitives follows below). This work should be seen alongside the PR "feat: add support for tensor parallel training workflow with accelerate" (transformers#34194).
  2. Relates to "Simplify Tensor Parallel implementation with PyTorch TP" (transformers#34184).
  3. Modifies the dataloader to pass the same samples across TP ranks.

Please review in conjunction with huggingface/transformers#34194
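
For context, the sketch below shows the PyTorch 2.x tensor-parallel primitives (a device mesh plus parallelize_module) that a plugin like TorchTensorParallelPlugin builds on. This is an illustration of the mechanism only, not code from this PR; the toy MLP, the module names, and the 4-GPU mesh shape are assumptions for the example.

```python
# Illustrative only: the PyTorch TP primitives this PR builds on, applied to a toy MLP.
# Run with `torchrun --nproc-per-node 4 example.py` (LOCAL_RANK is set by torchrun).
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
tp_mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("tp",))

# Toy 2-layer MLP standing in for a transformer FFN block.
mlp = nn.Sequential(nn.Linear(4096, 11008), nn.GELU(), nn.Linear(11008, 4096)).cuda()

# Shard the up-projection column-wise and the down-projection row-wise so the
# pair needs only a single all-reduce on the output.
parallelize_module(mlp, tp_mesh, {"0": ColwiseParallel(), "2": RowwiseParallel()})

torch.manual_seed(0)  # so every TP rank builds the same (replicated) input batch
x = torch.randn(1, 8192, 4096, device="cuda")  # (batch, seq, hidden)
out = mlp(x)
```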

Results

We see a significant improvement in both memory and throughput compared to single-GPU training and to FSDP, across different settings (gradient checkpointing on/off) and context lengths.

Benchmarks were run on two models:

  1. ibm-granite/granite-8b-code-base-128k
  2. codellama/CodeLlama-7b-hf

The tables below show the max CUDA memory and throughput for various configurations, demonstrating the potential of the TP support contributed in this PR. There are gains in both memory and throughput.

Note: the effective tokens/sec for FSDP is multiplicative in the parallel factor (the number of GPUs/devices engaged in distributed training), whereas this is not the case for TP, where every rank processes the same batch. For example, FSDP at 2256.896 tokens/sec/GPU on 4 GPUs yields roughly 9,000 effective tokens/sec, versus 5935.5 for TP. In terms of effective throughput, FSDP can therefore still come out ahead of TP; that may be compensated by increasing the batch size to exploit the memory gains, etc.

**ibm-granite/granite-8b-code-base-128k**

| Method | # of GPUs | Context Length | Batch Size | Grad Checkpointing | CUDA Max Mem (GiB) | Tokens/Sec/GPU |
| --- | --- | --- | --- | --- | --- | --- |
| Single GPU non-distributed | 1 | 8192 | 1 | FALSE | OOM | NA |
| FSDP | 4 | 8192 | 1 | FALSE | OOM | NA |
| TP (This PR) | 4 | 8192 | 1 | FALSE | 52.4 | 7675.4 |
| Single GPU non-distributed | 1 | 8192 | 1 | TRUE | OOM | NA |
| FSDP | 4 | 8192 | 1 | TRUE | 29.975586 | 2256.896 |
| TP (This PR) | 4 | 8192 | 1 | TRUE | 26.5 | 5935.5 |
| Single GPU non-distributed | 1 | 16384 | 1 | FALSE | OOM | NA |
| FSDP | 4 | 16384 | 1 | FALSE | OOM | NA |
| TP (This PR) | 4 | 16384 | 1 | FALSE | OOM | NA |
| Single GPU non-distributed | 1 | 16384 | 1 | TRUE | OOM | NA |
| FSDP | 4 | 16384 | 1 | TRUE | 36.8 | 2084.864 |
| TP (This PR) | 4 | 16384 | 1 | TRUE | 33.5 | 5692.5 |

**codellama/CodeLlama-7b-hf**

| Method | # of GPUs | Context Length | Batch Size | Grad Checkpointing | CUDA Max Mem (GiB) | Tokens/Sec/GPU |
| --- | --- | --- | --- | --- | --- | --- |
| Single GPU non-distributed | 1 | 8192 | 1 | FALSE | OOM | NA |
| FSDP | 4 | 8192 | 1 | FALSE | 70.7 | 3560 |
| TP (This PR) | 4 | 8192 | 1 | FALSE | 42.8 | 9216 |
| Single GPU non-distributed | 1 | 8192 | 1 | TRUE | 75.3 | 2849 |
| FSDP | 4 | 8192 | 1 | TRUE | 26.4 | 5957 |
| TP (This PR) | 4 | 8192 | 1 | TRUE | 21.4 | 7125 |
| Single GPU non-distributed | 1 | 16384 | 1 | FALSE | OOM | NA |
| FSDP | 4 | 16384 | 1 | FALSE | OOM | NA |
| TP (This PR) | 4 | 16384 | 1 | FALSE | OOM | NA |
| Single GPU non-distributed | 1 | 16384 | 1 | TRUE | 75.3 | 2599 |
| FSDP | 4 | 16384 | 1 | TRUE | 30.1 | 2433 |
| TP (This PR) | 4 | 16384 | 1 | TRUE | 26.6 | 6873 |

Fixes huggingface/transformers#32470

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

I have cycles to bring further improvements on top of this PR to land PyTorch TP support in HF. Looking forward to it. Thank you.

@kmehant kmehant changed the title feat: support tensor parallel using Pytorch 2.0 feat: support tensor parallel using Pytorch 2.0 & Data loader Oct 24, 2024
@muellerzr (Collaborator) left a comment


Thanks! This looks great to me. We do still need to update this to work with accelerate config, however, which happens in commands/config and commands/launch. Would you like to do so?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@muellerzr (Collaborator) commented:

@kmehant if you rebase from main this should fix the failures (tl;dr: Python 3.8 reached EOL).

@kmehant (Contributor, Author) commented Oct 29, 2024

@muellerzr Appreciate your response. I would like to bring the two points below to your notice.

  1. This dataloader is written for the paradigm (call it paradigm 1) in which the main process fetches the data and distributes it to all the worker processes (a rough sketch of this broadcast pattern is shown below). The more general paradigm (call it paradigm 2), where every process fetches its own data sample (which for TP has to be the same batch across processes), is not covered in this PR.
  2. This PR has a soft dependency on applying the TP plan to the model, since it really has two parts: the TP workflow through an accelerate plugin, plus the dataloader.
    1. The first part of the PR applies tensor parallelism to the model as shown here - https://github.com/huggingface/accelerate/pull/3173/files#diff-2d7515874eaecac2687c7fc1a9c720be53f802bf14b4c3dcebe14ad443d075dcR1467 - creating a soft dependency on "feat: add support for tensor parallel training workflow with accelerate" transformers#34194 (part of which would be superseded by "Simplify Tensor Parallel implementation with PyTorch TP" transformers#34184, which carries a different interface for applying the TP plan to the model).
    2. The second part is the dataloader.

For point (1), I can keep this PR simple, allow only paradigm 1, and address paradigm 2 in another PR.
For point (2), I can remove the TP application from this PR, keeping it simple and independent. The removed part can be added in a separate PR once point (2)(i) is completed.

WDYT?
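
A minimal sketch of the "paradigm 1" behaviour described above: rank 0 of the TP group fetches each batch and the other ranks receive the identical samples via a broadcast. This only illustrates the idea and is not the PR's actual dataloader code; the tp_group handle is assumed to have been created elsewhere.

```python
# Sketch of "paradigm 1": rank 0 of the TP group fetches each batch and broadcasts
# the same samples to the other TP ranks. Illustration only; `tp_group` is assumed
# to have been created with dist.new_group(...) beforehand.
import torch.distributed as dist

def same_batch_across_tp_ranks(dataloader, tp_group):
    src = dist.get_global_rank(tp_group, 0)  # global rank of the group's rank 0
    if dist.get_rank(tp_group) == 0:
        for batch in dataloader:
            payload = [batch]
            dist.broadcast_object_list(payload, src=src, group=tp_group)
            yield payload[0]
        # sentinel telling the other ranks that the loader is exhausted
        dist.broadcast_object_list([None], src=src, group=tp_group)
    else:
        while True:
            payload = [None]
            dist.broadcast_object_list(payload, src=src, group=tp_group)
            if payload[0] is None:
                break
            yield payload[0]
```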

@BenjaminBossan (Member) left a comment


Thanks for this PR, this looks nice. I have a few smaller comments, please take a look.

Also, please ensure that make quality passes.

@@ -1457,6 +1463,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.TP:
model.apply_tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
Member:

apply_tensor_parallel will be implemented in huggingface/transformers#34194 but only for select model architectures, right? Should we check this and if not present, raise an appropriate error?

Contributor (Author):

Hi @BenjaminBossan

The tensor_parallel() interface will be implemented here - https://github.com/huggingface/transformers/pull/34184/files#diff-6b72b98c4c2dcfc6cc606843917733f5d858374fbc22a735ff483bbc0c1e63eaR5017

I have raised a comment there about providing a way to know whether tensor_parallel succeeded or not. Once that PR is ready, we can handle it here. WDYT?

Member:

Okay, let's see what the final result will be. But we could also check hasattr(model, "apply_tensor_parallel") or would that not work?

Contributor (Author):

@BenjaminBossan
The tensor_parallel function is being added to the parent class PreTrainedModel, so all model classes would have this function regardless of whether TP is actually available for a given model (which is why a plain hasattr check would not work).

Member:

Ah I see, in that case it is crucial to add a method or attribute to check the support for TP.

Contributor (Author):

@BenjaminBossan

A has_tp_plan property has been added, so I updated the code here to fail when the model has no TP support. Thank you.

(Additional review threads on src/accelerate/data_loader.py and src/accelerate/utils/dataclasses.py were resolved.)
@kmehant kmehant force-pushed the tp branch 5 times, most recently from da67cba to c096d40 Compare November 4, 2024 09:57
@kmehant (Contributor, Author) commented Nov 4, 2024

@muellerzr can I work on this #3173 (review) in a separate PR?

I have fetched and rebased my PR and addressed all the review comments. Thank you.

@HoangCongDuc commented:
This feature is really useful, thank you @kmehant. I wonder if it is possible to combine tensor parallel with data parallel after this PR, say, TP for same-node parallelism and DP for multi-node parallelism.

@kmehant (Contributor, Author) commented Nov 17, 2024

> This feature is really useful, thank you @kmehant. I wonder if it is possible to combine tensor parallel with data parallel after this PR, say, TP for same-node parallelism and DP for multi-node parallelism.

Hi @HoangCongDuc, support for that is on my TODO list but not covered in this PR; it should come soon, after discussion with HF. Thank you.
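
For reference, combining the two would typically be expressed with a 2-D device mesh like the sketch below. This is plain PyTorch for illustration only, not something this PR wires into accelerate; the 2-node x 4-GPU shape is an assumption.

```python
# Illustration only: a 2-D device mesh of the kind that DP across nodes plus TP
# within a node would use. Assumes 2 nodes x 4 GPUs launched with torchrun.
from torch.distributed.device_mesh import init_device_mesh

mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
dp_mesh = mesh_2d["dp"]  # ranks in this sub-mesh see different batches
tp_mesh = mesh_2d["tp"]  # ranks in this sub-mesh shard the same layers and share a batch
```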

@@ -1461,6 +1467,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.TP:
if not model.has_tp_plan:
Member:

It appears that the attribute was renamed to supports_tp_plan? Maybe let's wait until that other PR is merged so that this one does not need to be adapted constantly.

Contributor (Author):

@BenjaminBossan
Yes, it was renamed. I have updated this PR again, and that transformers PR is now merged :)

@muellerzr (Collaborator) left a comment


Thanks! Overall the code looks sound; what I'd appreciate, however, is if we could bring this the last 10% of the way through:

  1. Actually implementing this in the CLI and setting the env variable up properly
  2. Writing some tests (src/accelerate/test_utils/scripts/test_tensor_parallel.py IMO)

@kmehant kmehant changed the title feat: support tensor parallel using Pytorch 2.0 & Data loader feat: support tensor parallel & Data loader Dec 13, 2024
@kmehant kmehant force-pushed the tp branch 7 times, most recently from 74bf3e2 to 974edd9 Compare December 13, 2024 18:07
@kmehant kmehant force-pushed the tp branch 8 times, most recently from d5cc290 to 6f016f5 Compare December 13, 2024 21:08
@kmehant (Contributor, Author) commented Dec 14, 2024

@muellerzr @BenjaminBossan

  1. Implemented the CLI part (a programmatic sketch of the plugin usage follows below):
    1. accelerate launch usage
    2. accelerate config usage
  2. Added a run test for TP through the CLI, using the existing scripts

Let me know if I have missed out something. Thank you.
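
Alongside the CLI path, a programmatic sketch of how the plugin could be used might look like the following. The TorchTensorParallelPlugin name, its tp_size argument, and the torch_tp_plugin keyword on Accelerator are inferred from this PR's diff hunks and should be treated as assumptions rather than settled API; the tiny dataset is a placeholder.

```python
# Hedged sketch: plugin/argument names (TorchTensorParallelPlugin, tp_size,
# torch_tp_plugin) are inferred from this PR's diffs and may differ in the final API.
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from accelerate.utils import TorchTensorParallelPlugin
from transformers import AutoModelForCausalLM

tp_plugin = TorchTensorParallelPlugin(tp_size=4)   # shard each supported layer across 4 GPUs
accelerator = Accelerator(torch_tp_plugin=tp_plugin)

model = AutoModelForCausalLM.from_pretrained(
    "codellama/CodeLlama-7b-hf", torch_dtype=torch.bfloat16
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Placeholder dataset: a handful of random token sequences just to feed the loader.
my_dataset = [{"input_ids": torch.randint(0, 32000, (8192,))} for _ in range(8)]
dataloader = DataLoader(my_dataset, batch_size=1)

# prepare() applies the model's TP plan over the plugin's device mesh and wraps the
# dataloader so every TP rank receives the same batch.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```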

github-actions (bot) commented Jan 7, 2025

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@BenjaminBossan (Member) left a comment


Thanks for the updates and sorry for the delays. My points have been addressed, so I'm approving the PR. But it would be great if @muellerzr could re-review it with the latest changes before merging.

@github-actions github-actions bot closed this Jan 17, 2025
@kmehant (Contributor, Author) commented Jan 17, 2025

@BenjaminBossan @muellerzr can we reopen this PR? Also requesting a sooner review and merge, since GitHub Actions tagged it stale. Thank you.

@SunMarc SunMarc reopened this Jan 23, 2025
@SunMarc (Member) commented Jan 23, 2025

Can you fix the conflicts and the CI? We will merge the PR soon after Zach re-reviews it!

@SunMarc (Member) left a comment


A couple of comments. Otherwise LGTM!

@@ -1471,6 +1492,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.TP:
if not model.supports_tp_plan:
Member:

Models don't necessarily have the supports_tp_plan attribute, so make sure to also check whether the attribute exists.

@kmehant (Contributor, Author) commented Jan 28, 2025:

Thank you for the review, I have added a check for this.

@@ -1471,6 +1492,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.TP:
if not model.supports_tp_plan:
raise NotImplementedError("Provided model does not support tensor parallelism")
Member:

Could you give the user some indication of how they can add the support? Also make sure to check the required version of transformers, as this was added quite recently.

@kmehant (Contributor, Author) commented Jan 28, 2025:

> Could you give the user some indication of how they can add the support?

I have extended the message to include these details.

> Also make sure to check the version of transformers needed as it was added quite recently.

Added a version-based check as well.

Thank you.
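
For readers following along, the kind of guard being discussed looks roughly like the sketch below. The minimum version string, the error types, and the messages are placeholders rather than the merged PR's exact values; compare_versions is accelerate's version-comparison helper, and `model` is assumed to be the transformers model being prepared.

```python
# Rough sketch of the attribute + version guard discussed above; "4.47.0" and the
# error messages are placeholders, not the merged PR's values.
from accelerate.utils import compare_versions

if not compare_versions("transformers", ">=", "4.47.0"):
    raise ImportError(
        "Tensor parallelism through the transformers integration requires a newer "
        "transformers release that ships TP plans for the model."
    )
if not getattr(model, "supports_tp_plan", False):
    raise NotImplementedError(
        "Provided model does not support tensor parallelism. Support is added per "
        "architecture in transformers via a `base_model_tp_plan` on the model config."
    )
```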

Comment on lines 1495 to 1496
elif self.distributed_type == DistributedType.TP:
if not model.supports_tp_plan:
Member:

For a first version it's fine if we rely on the transformers integration, but it would be a nice follow-up to make this compatible with any PyTorch model.

@SunMarc SunMarc requested a review from muellerzr January 27, 2025 14:43
@require_non_torch_xla
@require_tp
@require_multi_device
@slow
Member:

You need to add another decorator for the transformers lib.

Contributor (Author):

I have added the transformers lib check decorator.

Comment on lines 16 to 17
from transformers.trainer_utils import set_seed

Member:

Make sure to protect the import; the tests are failing because of that.

Contributor (Author):

I have moved the import into the test case itself, which only runs when the transformers lib is available. Thank you.
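
For illustration, one common way to protect such an import in a test module looks like the sketch below; the test body is a placeholder, and is_transformers_available is accelerate's import-check helper.

```python
# Sketch of protecting an optional transformers import: the module can be collected
# without transformers installed, and the import only happens inside the guarded test.
import unittest

from accelerate.utils import is_transformers_available


class TPDataLoaderTest(unittest.TestCase):
    @unittest.skipUnless(is_transformers_available(), "test requires transformers")
    def test_seeded_tp_run(self):
        from transformers.trainer_utils import set_seed  # only reached if available

        set_seed(42)
        # ... the rest of the TP test would go here (placeholder)
```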

Signed-off-by: Mehant Kammakomati <[email protected]>
@SunMarc (Member) left a comment


LGTM! Small nits:

Comment on lines 43 to 46
from transformers.trainer_utils import set_seed

set_seed(42)

Member:

You can use set_seed from accelerate:

Suggested change:
-from transformers.trainer_utils import set_seed
+from accelerate.utils import set_seed
 set_seed(42)

@SunMarc (Member) commented Jan 28, 2025

There are some failing tests in the CI; can you check whether they are related to this PR, @kmehant?

@kmehant (Contributor, Author) commented Jan 28, 2025

@SunMarc I have fixed the failing test and verified it locally; requesting another CI/CD workflow run on this PR as a finishing check. Thanks.

Signed-off-by: Mehant Kammakomati <[email protected]>
@muellerzr (Collaborator) left a comment


Nicely done! This looks great to me, thanks so much for all your work! I'll keep an eye on main and will ping you if we find test_tp fails at some point

@muellerzr muellerzr merged commit 0315365 into huggingface:main Jan 29, 2025
25 checks passed