Commit ed89ef1

distributedconfig api
1 parent 46b2931 commit ed89ef1

4 files changed

Lines changed: 156 additions & 0 deletions

.claude/skills

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
../.ai/skills

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -219,6 +219,8 @@
   title: Accelerator selection
 - local: accelerate
   title: Accelerate
+- local: distributed_config
+  title: DistributedConfig
 - local: fsdp
   title: FullyShardedDataParallel
 - local: deepspeed
Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
<!--Copyright 2026 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# DistributedConfig

[`DistributedConfig`] shards a model across GPUs directly through [`~PreTrainedModel.from_pretrained`]. It supports [FSDP2](./fsdp), [tensor parallelism](./tensor_parallelism), and [N-D parallelism](./perf_train_gpu_many).

Pass a [`DistributedConfig`] to [`~PreTrainedModel.from_pretrained`] and Transformers builds the device mesh and shards the supported layers for you.

The fields below control how the model is sharded.

| field | description |
|---|---|
| `tp_size` | Number of devices for tensor parallelism. Defaults to 1 when only `fsdp_size` is set. |
| `tp_plan` | Tensor parallel sharding plan. Leave as `None` to use the model's default plan. |
| `fsdp_size` | Number of devices for FSDP2. Defaults to 1 when only `tp_size` is set. |
| `fsdp_cpu_offload` | Offload parameters and gradients to CPU to save GPU memory. Defaults to `False`. |
| `fsdp_mixed_precision` | Compute in `bfloat16` and reduce gradients in `float32`. Defaults to `False`. |
| `enable_expert_parallel` | Shard mixture-of-experts layers across devices. See [Expert parallelism](./expert_parallelism). |

The product of `tp_size` and `fsdp_size` must equal the number of devices you launch with.
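
To make the product rule concrete, the sketch below lists every `(tp_size, fsdp_size)` pair that fits a given device count. The `valid_splits` helper is hypothetical and not part of Transformers; it only illustrates the arithmetic.

```python
def valid_splits(num_devices: int) -> list[tuple[int, int]]:
    """Return every (tp_size, fsdp_size) pair whose product equals num_devices."""
    return [
        (tp_size, num_devices // tp_size)
        for tp_size in range(1, num_devices + 1)
        if num_devices % tp_size == 0  # tp_size must divide the device count evenly
    ]


# Candidate layouts for a 4-GPU and an 8-GPU launch.
print(valid_splits(4))
print(valid_splits(8))
```

Any pair outside this list leaves devices unused or asks for more devices than were launched.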

## FSDP2

[FSDP2](./fsdp) shards parameters, gradients, and optimizer states across GPUs. Set `fsdp_size` to the number of devices to shard across.

```py
import torch
from transformers import AutoModelForCausalLM
from transformers.distributed.configuration_utils import DistributedConfig

distributed_config = DistributedConfig(fsdp_size=4)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    distributed_config=distributed_config,
)
```
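
For a rough sense of what sharding saves, the sketch below estimates the per-GPU memory held by parameters, gradients, and optimizer states. It rests on several assumptions that are not from this page: training with an Adam-style optimizer in `float32` (about 16 bytes per parameter: 4 for the weight, 4 for the gradient, 8 for the two optimizer moments), perfectly even sharding, and a nominal 0.6 billion parameters taken from the model name. Activations and temporary all-gather buffers are ignored, and `sharded_state_gib` is a hypothetical helper.

```python
def sharded_state_gib(num_params: float, fsdp_size: int, bytes_per_param: int = 16) -> float:
    """Estimate per-GPU memory (GiB) for parameters, gradients, and optimizer states.

    Assumes the state is split evenly across fsdp_size devices.
    """
    total_bytes = num_params * bytes_per_param
    return total_bytes / fsdp_size / 2**30


# Nominal parameter count for a 0.6B model (an approximation, not a measured value).
for fsdp_size in (1, 2, 4):
    estimate = sharded_state_gib(0.6e9, fsdp_size)
    print(f"fsdp_size={fsdp_size}: ~{estimate:.2f} GiB per GPU")
```

Treat the numbers as a lower bound; real usage depends on batch size, sequence length, and precision settings.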

Transformers wraps each layer according to the model's `base_model_fsdp_plan`. Check whether a model declares one before sharding.

```py
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B")
print(config.base_model_fsdp_plan)
```

The plan maps modules to a sharding strategy. `free_full_weight` reshards a module after the forward pass to save memory, and `keep_full_weight` keeps it gathered to avoid a second all-gather during the backward pass.

```py
{
    "embed_tokens": "free_full_weight",
    "layers.*": "free_full_weight",
    "norm": "keep_full_weight",
}
```

Set `fsdp_mixed_precision=True` to compute in `bfloat16` while reducing gradients in `float32`, and set `fsdp_cpu_offload=True` to move parameters and gradients to CPU when they aren't in use.

```py
distributed_config = DistributedConfig(
    fsdp_size=4,
    fsdp_mixed_precision=True,
    fsdp_cpu_offload=True,
)
```

## Tensor parallelism

[Tensor parallelism](./tensor_parallelism) splits weight matrices across GPUs. Set `tp_size` to shard the model's supported layers.

```py
import torch
from transformers import AutoModelForCausalLM
from transformers.distributed.configuration_utils import DistributedConfig

distributed_config = DistributedConfig(tp_size=4)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    distributed_config=distributed_config,
)
```

Transformers shards according to the model's `base_model_tp_plan`. Pass `tp_plan` to override the layout, for example `{"model.layers.*.self_attn.q_proj": "colwise"}`.
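
To illustrate what a `colwise` split means, the pure-Python sketch below divides a small weight matrix into column blocks, computes each block's output separately, and checks that the concatenated result matches the unsharded product. It uses the `y = xW` convention, where columns correspond to output features, and stands in for real devices with plain lists. The helper names are hypothetical.

```python
def matmul(x: list[float], w: list[list[float]]) -> list[float]:
    """Multiply a vector x (length n) by a matrix w (n rows, m columns)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def split_columns(w: list[list[float]], tp_size: int) -> list[list[list[float]]]:
    """Split w into tp_size equal column blocks, one per simulated device."""
    cols = len(w[0]) // tp_size
    return [[row[r * cols:(r + 1) * cols] for row in w] for r in range(tp_size)]


x = [1.0, 2.0]
w = [
    [1.0, 2.0, 3.0, 4.0],
    [5.0, 6.0, 7.0, 8.0],
]

# Each simulated device holds one column block and computes part of the output.
shards = split_columns(w, tp_size=2)
partial_outputs = [matmul(x, shard) for shard in shards]

# Concatenating the partial outputs recovers the full result.
gathered = [value for part in partial_outputs for value in part]
assert gathered == matmul(x, w)
print(gathered)
```

Each device needs the full input but only its slice of the weights, which is why column-wise sharding reduces per-device weight memory.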

## N-D parallelism

Combine FSDP2 and tensor parallelism by setting both sizes. The example below runs on 4 GPUs arranged as a 2×2 device mesh: tensor parallelism splits weights across the 2 GPUs in each group, and FSDP2 shards across the 2 groups.

```py
import torch
from transformers import AutoModelForCausalLM
from transformers.distributed.configuration_utils import DistributedConfig

distributed_config = DistributedConfig(tp_size=2, fsdp_size=2)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    dtype=torch.bfloat16,
    distributed_config=distributed_config,
)
```
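
The sketch below shows how 4 ranks could map onto a mesh with `fsdp_size=2` and `tp_size=2`. It assumes row-major ordering with tensor parallelism as the innermost dimension, a common convention; the ordering Transformers uses when it builds the mesh is not stated here and may differ. `mesh_coords` is a hypothetical helper.

```python
def mesh_coords(rank: int, tp_size: int) -> tuple[int, int]:
    """Map a global rank to (fsdp_index, tp_index), with tensor parallel innermost."""
    return divmod(rank, tp_size)


tp_size, fsdp_size = 2, 2
for rank in range(tp_size * fsdp_size):
    fsdp_index, tp_index = mesh_coords(rank, tp_size)
    print(f"rank {rank}: fsdp_index={fsdp_index}, tp_index={tp_index}")
```

Under this assumption, ranks that share an `fsdp_index` form a tensor-parallel group, and ranks that share a `tp_index` form an FSDP2 group.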

## Launch

Launch your script with [torchrun](https://pytorch.org/docs/stable/elastic/run.html) and set `--nproc-per-node` to the total number of devices, equal to `tp_size * fsdp_size`.

```shell
torchrun --nproc-per-node 4 train.py
```
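
A script can check this before loading the model. The sketch below reads the `WORLD_SIZE` environment variable that torchrun sets for each process and compares it against the configured sizes. `check_world_size` is a hypothetical helper, not a Transformers API, and the demo call passes a plain dictionary so it also runs outside torchrun.

```python
import os


def check_world_size(tp_size: int, fsdp_size: int, env=os.environ) -> int:
    """Raise if the launched process count does not equal tp_size * fsdp_size."""
    world_size = int(env.get("WORLD_SIZE", "1"))  # torchrun sets WORLD_SIZE
    expected = tp_size * fsdp_size
    if world_size != expected:
        raise ValueError(
            f"Launched {world_size} processes, but tp_size * fsdp_size = {expected}."
        )
    return world_size


# Simulate a 4-process launch; inside a real script, omit the env argument.
print(check_world_size(tp_size=2, fsdp_size=2, env={"WORLD_SIZE": "4"}))
```

Failing early with a clear message is easier to debug than a mismatch that surfaces while the device mesh is being built.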

## Next steps

- See [FSDP2](./fsdp) for sharded training.
- See [Tensor parallelism](./tensor_parallelism) for more details on partitioning strategies and manual plans.
- See [Expert parallelism](./expert_parallelism) for sharding mixture-of-experts models.
- See [N-D parallelism](./perf_train_gpu_many) for combining parallelism strategies.
- Read [The Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook) for a deeper look at how these strategies work.

utils/.checkers_cache.json

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
{
  "add_dates": "2c0ef6a3ec2eb2a3cce73c6a74c0564b219e698b6f5946270661116133c9a07c",
  "auto_mappings": "7e19518867242e07284c2f62b21f1195b762d47e1693cee727842b880493aceb",
  "config_attributes": "b050cab4ee3f179c4d19dc8856da37a1f09385fb2b29df5c58e045ee0ec29b40",
  "config_docstrings": "fbc006ec716f7421d51c9b245fe496184b1e7ce50627ce4e75a2ac0836d95346",
  "copies": "28b52d9a0d557147c611907c7c5177ca3219708aa6be6f4fdbb28fd2a93f0ea3",
  "deps_table": "dd2c3dd9c20aba4869ced10b5dbfa9dcc443b0981aace7b2a0fbf4b5e5cec2c1",
  "docstrings": "5ba8326a194c9606de1424f1f5c1e20077545c53a9ed6b768a6f8a2e1870f7e6",
  "doctest_list": "98897e42dabaed5c666734f12a5049f4327fc89bad4621819adf55ce3e9c2a66",
  "dummies": "8b9eb0f2047c2e692adba8e01f4207370ffe3b4de8b83482b98cea4630b3e2ef",
  "imports": "4e8c8768fc924f3f530debaf287bad4bb9d267e7c86728450cb63e9b7c201376",
  "init_isort": "1d049dc690b05fad7209f1e3ccb49ebce51db3fef94b63e140ce5b69c1ab24af",
  "inits": "13852b590793c350372c94fdedb7f16b0e081bf61ec4ed83fae13304b19e837f",
  "modular_conversion": "8e778ff2f66849bb611c594bcdcb2be8125b467e32a2537ebbf37467f1943422",
  "pipeline_typing": "3cb9d37a9d033222ad798914141cd056e264f5158754fb590580e6ac85128f72",
  "ruff_check": "0bacd4bcbd205e1611d816882ed10a719f77761c3950fd6d831899c267055a23",
  "ruff_format": "0bacd4bcbd205e1611d816882ed10a719f77761c3950fd6d831899c267055a23",
  "sort_auto_mappings": "3d98987835c97d17679c4732a38fce3bd46edd3dc5e9f09dc659d74cc4fca3c9",
  "update_metadata": "10a0fc570ecb47b9be79a682a831a2a67ab0cb7067cec849d1985493f969e371"
}
