
Commit c3ba533

kmehant and SunMarc authored
feat: add support for tensor parallel training workflow with accelerate (#34194)
* feat: add support for tensor parallel flow using accelerate
* fix: add tp degree to env variable
* fix: add version check for accelerate to allow TP
* docs: tensor parallelism
* nit: rename plugin name
* fix: guard accelerate version before allowing TP
* docs: add more docs and updates related to TP

Signed-off-by: Mehant Kammakomati <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
1 parent e6cc410 commit c3ba533

File tree

8 files changed: +133 −6 lines changed


docs/source/ar/trainer.md

Lines changed: 23 additions & 0 deletions

@@ -673,6 +673,29 @@ tpu_use_sudo: false
 use_cpu: false
 ```
+</hfoption>
+<hfoption id="Tensor Parallelism with PyTorch 2">
+
+```yml
+compute_environment: LOCAL_MACHINE
+tp_config:
+  tp_size: 4
+distributed_type: TP
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+
+```
+
 </hfoption>
 </hfoptions>
 The [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) command is the recommended way to launch your training script on a distributed system with Accelerate and [`Trainer`], using the parameters specified in `config_file.yaml`. This file is saved to the Accelerate cache folder and is loaded automatically when you run `accelerate_launch`.
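As a usage sketch (the training script name `train.py` is illustrative, not part of this commit):

```bash
# Launch the training script with the saved Accelerate config shown above;
# accelerate spawns num_processes (here 4) workers with TP enabled.
accelerate launch --config_file config_file.yaml train.py
```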

docs/source/en/llm_tutorial_optimization.md

Lines changed: 1 addition & 1 deletion

@@ -55,7 +55,7 @@ To give some examples of how much VRAM it roughly takes to load a model in bfloat16
 
 As of writing this document, the largest GPU chip on the market is the A100 & H100, offering 80GB of VRAM. Most of the models listed before require more than 80GB just to be loaded and therefore necessarily require [tensor parallelism](https://huggingface.co/docs/transformers/perf_train_gpu_many#tensor-parallelism) and/or [pipeline parallelism](https://huggingface.co/docs/transformers/perf_train_gpu_many#naive-model-parallelism-vertical-and-pipeline-parallelism).
 
-🤗 Transformers does not support tensor parallelism out of the box as it requires the model architecture to be written in a specific way. If you're interested in writing models in a tensor-parallelism-friendly way, feel free to have a look at [the text-generation-inference library](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models/custom_modeling).
+🤗 Transformers now supports tensor parallelism for models that have a `base_tp_plan` in their respective config classes. Learn more about tensor parallelism [here](perf_train_gpu_many#tensor-parallelism). Furthermore, if you're interested in writing models in a tensor-parallelism-friendly way, feel free to have a look at [the text-generation-inference library](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models/custom_modeling).
 
 Naive pipeline parallelism is supported out of the box. For this, simply load the model with `device_map="auto"`, which will automatically place the different layers on the available GPUs as explained [here](https://huggingface.co/docs/accelerate/v0.22.0/en/concept_guides/big_model_inference).
 Note, however, that while very effective, this naive pipeline parallelism does not tackle the issues of GPU idling. For this, more advanced pipeline parallelism is required as explained [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many#naive-model-parallelism-vertical-and-pipeline-parallelism).
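A minimal sketch of the naive pipeline parallelism described above (the checkpoint name is illustrative; any checkpoint your GPUs can hold works):

```python
# Sketch: automatic layer placement across available GPUs ("naive" pipeline parallelism).
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom",          # illustrative large checkpoint
    device_map="auto",           # shard layers across the visible GPUs
    torch_dtype=torch.bfloat16,  # ~2 bytes per parameter, per the estimate above
)
```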

docs/source/en/perf_train_gpu_many.md

Lines changed: 4 additions & 3 deletions

@@ -450,12 +450,13 @@ Implementations:
 - [parallelformers](https://github.com/tunib-ai/parallelformers) (only inference at the moment)
 - [SageMaker](https://arxiv.org/abs/2111.05972) - this is a proprietary solution that can only be used on AWS.
 - [OSLO](https://github.com/tunib-ai/oslo) has a tensor parallelism implementation based on Transformers.
+- [`transformers` integration](main_classes/trainer) - tensor parallelism is available through the `tp_size` attribute for models that have a `base_tp_plan`. See the [example usage](perf_infer_gpu_multi).
 
 SageMaker combines TP with DP for more efficient processing.
 
 🤗 Transformers status:
-- core: not yet implemented in the core
-- but if you want inference [parallelformers](https://github.com/tunib-ai/parallelformers) provides this support for most of our models. So until this is implemented in the core you can use theirs. And hopefully training mode will be supported too.
+- core: uses PyTorch 2 APIs to support tensor parallelism for models that have a `base_tp_plan` in their respective config classes.
+- Alternatively, you can try [parallelformers](https://github.com/tunib-ai/parallelformers), which provides this support for most of our models. Training mode with TP is also supported natively in `transformers`.
 - Deepspeed-Inference also supports our BERT, GPT-2, and GPT-Neo models in their super-fast CUDA-kernel-based inference mode, see more [here](https://www.deepspeed.ai/tutorials/inference-tutorial/)
 
 🤗 Accelerate integrates with [TP from Megatron-LM](https://huggingface.co/docs/accelerate/v0.23.0/en/usage_guides/megatron_lm).
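To make the integration above concrete, a minimal training sketch (the checkpoint name, dataset, and `tp_size` value are illustrative placeholders; TP applies only to models whose configs define a TP plan):

```python
# Sketch: tensor-parallel training through Trainer via the new tp_size argument.
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Placeholder checkpoint, assumed to define a TP plan in its config.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")

args = TrainingArguments(
    output_dir="tp_out",
    per_device_train_batch_size=1,
    tp_size=4,  # shard each layer over 4 GPUs; requires accelerate > 1.3.0
)

# train_dataset is assumed to be an already-tokenized dataset.
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
```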
@@ -535,7 +536,7 @@ Important papers:
 - [Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model](https://arxiv.org/abs/2201.11990)
 
-🤗 Transformers status: not yet implemented, since we have no PP and TP.
+🤗 Transformers status: not yet implemented, since we have no PP.
 
 ## FlexFlow

docs/source/en/trainer.md

Lines changed: 23 additions & 0 deletions

@@ -799,6 +799,29 @@ tpu_use_sudo: false
 use_cpu: false
 ```
+</hfoption>
+<hfoption id="Tensor Parallelism with PyTorch 2">
+
+```yml
+compute_environment: LOCAL_MACHINE
+tp_config:
+  tp_size: 4
+distributed_type: TP
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+
+```
+
 </hfoption>
 </hfoptions>
docs/source/es/trainer.md

Lines changed: 24 additions & 0 deletions

@@ -361,6 +361,30 @@ use_cpu: false
 
 ```
+</hfoption>
+
+<hfoption id="Tensor Parallelism with PyTorch 2">
+
+```yml
+compute_environment: LOCAL_MACHINE
+tp_config:
+  tp_size: 4
+distributed_type: TP
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+
+```
+
 </hfoption>
 </hfoptions>

docs/source/ko/trainer.md

Lines changed: 23 additions & 0 deletions

@@ -548,6 +548,29 @@ tpu_use_sudo: false
 use_cpu: false
 ```
+</hfoption>
+<hfoption id="Tensor Parallelism with PyTorch 2">
+
+```yml
+compute_environment: LOCAL_MACHINE
+tp_config:
+  tp_size: 4
+distributed_type: TP
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+
+```
+
 </hfoption>
 </hfoptions>

src/transformers/trainer.py

Lines changed: 11 additions & 1 deletion

@@ -241,6 +241,8 @@
 )
 
 DATA_SAMPLERS = [RandomSampler]
+if version.parse(accelerate_version) > version.parse("1.3.0"):
+    from accelerate.utils import TorchTensorParallelPlugin
 if version.parse(accelerate_version) > version.parse("0.23.0"):
     from accelerate.data_loader import SeedableRandomSampler
 

@@ -5094,6 +5096,14 @@ def create_accelerator_and_postprocess(self):
             args["dataloader_config"] = dataloader_config
         else:
             args.update(accelerator_config)
+        # tp is initialized at Accelerator init phase, so
+        # args should be prepared here
+        if self.args.tp_size > 1:
+            self.is_tp_enabled = True
+            if version.parse(accelerate_version) > version.parse("1.3.0"):
+                args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size)
+            else:
+                raise ValueError("Requires accelerate>1.3.0 to use Tensor Parallelism.")
 
         # create accelerator object
         self.accelerator = Accelerator(**args)

@@ -5108,7 +5118,7 @@ def create_accelerator_and_postprocess(self):
         # deepspeed and accelerate flags covering both trainer args and accelerate launcher
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
-
+        self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None
         # post accelerator creation setup
         if self.is_fsdp_enabled:
             fsdp_plugin = self.accelerator.state.fsdp_plugin
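Outside of `Trainer`, the same wiring can be sketched directly against Accelerate (assuming `accelerate > 1.3.0`; this mirrors what `create_accelerator_and_postprocess` does above, with an illustrative `tp_size`):

```python
# Sketch: what the Trainer builds internally when tp_size > 1.
from accelerate import Accelerator
from accelerate.utils import TorchTensorParallelPlugin

tp_plugin = TorchTensorParallelPlugin(tp_size=4)      # prepares a device mesh of size 4
accelerator = Accelerator(torch_tp_plugin=tp_plugin)  # same kwarg the Trainer passes

# The Trainer then derives its flag from the accelerator state:
is_tp_enabled = getattr(accelerator.state, "torch_tp_plugin", None) is not None
```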

src/transformers/training_args.py

Lines changed: 24 additions & 1 deletion

@@ -569,7 +569,10 @@ class TrainingArguments:
             Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
             used when the xla flag is set to true, and an auto wrapping policy is specified through
             fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
-
+        tp_size (`int`, *optional*):
+            Use `tp_size` to enable PyTorch tensor parallelism. Tensor parallelism support is only available for models
+            that have a `base_tp_plan` in their respective config classes.
+            Set a value greater than 1 to activate TP; the same value is used to prepare the device mesh internally. Requires `accelerate>1.3.0`.
         deepspeed (`str` or `dict`, *optional*):
             Use [Deepspeed](https://github.com/deepspeedai/DeepSpeed). This is an experimental feature and its API may
             evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,

@@ -1250,6 +1253,18 @@
             )
         },
     )
+    tp_size: Optional[int] = field(
+        default=0,
+        metadata={
+            "help": (
+                "Use tp_size to enable PyTorch tensor parallelism. "
+                "Tensor parallelism support is only available for models that have a `base_tp_plan` in their respective config classes. "
+                "Set a value greater than 1 to activate TP. "
+                "The same value is used to prepare the device mesh internally. "
+                "Requires accelerate>1.3.0."
+            )
+        },
+    )
     fsdp_transformer_layer_cls_to_wrap: Optional[str] = field(
         default=None,
         metadata={

@@ -1975,6 +1990,14 @@ def __post_init__(self):
         if self.fsdp_config["xla_fsdp_grad_ckpt"]:
             warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
 
+        if self.tp_size > 1:
+            if not is_accelerate_available("1.3.1"):
+                raise NotImplementedError(
+                    "TP using PyTorch requires Accelerate version `accelerate` >= 1.3.1. "
+                    "Your current version does not support this; we recommend you update it."
+                )
+            os.environ["ACCELERATE_USE_TP"] = "true"
+            os.environ["TP_SIZE"] = str(self.tp_size)
         # accelerate integration for FSDP
         if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
             os.environ["ACCELERATE_USE_FSDP"] = "true"
