From 9ed228690116747c17262216b0b1e5319a5af93a Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Wed, 16 Apr 2025 13:42:44 -0400 Subject: [PATCH 1/3] add multi-stage guide --- docs/user_guide/multi-stage.md | 80 ++++++++++++++++++++++++++++++++++ docs/user_guide/parallelism.md | 7 +++ mkdocs.yaml | 2 + 3 files changed, 89 insertions(+) create mode 100644 docs/user_guide/multi-stage.md create mode 100644 docs/user_guide/parallelism.md diff --git a/docs/user_guide/multi-stage.md b/docs/user_guide/multi-stage.md new file mode 100644 index 00000000..87447440 --- /dev/null +++ b/docs/user_guide/multi-stage.md @@ -0,0 +1,80 @@ +# Multi-Stage Training in Fast-LLM + +Fast-LLM trains large models by splitting them into *stages*, each running on a separate GPU or node. It reduces memory usage by distributing (or *sharding*) model state (weights, gradients, or optimizer states) across devices. + +This guide explains how to configure multi-stage training for both common and advanced use cases. + +## ZeRO-Stage Sharding + +Fast-LLM uses ZeRO-style sharding to partition model state efficiently across GPUs. This differs from pipeline parallelism, which partitions model computation into sequential pipeline stages. + +The primary setting for ZeRO sharding is `zero_stage` in your configuration: + +```yaml +multi_stage: + zero_stage: ... +``` + +The following table summarizes the behavior of `zero_stage`: + +| `zero_stage` | Weights | Gradients | Optimizer States | Communication overhead | +| ------------- | ---------- | ---------- | ---------------- | ------------------------------------------------------------ | +| `1` (default) | Replicated | Replicated | Sharded | Lowest, default choice | +| `2` | Replicated | Sharded | Sharded | Moderate, saves more memory at additional communication cost | +| `3` | Sharded | Sharded | Sharded | Highest, maximum memory saving with increased communication | + +Optimizer states are always sharded by default. ZeRO Stage 0 (full replication) is not supported. + +In general, start with the default (`zero_stage: 1`) and verify if your model trains without memory errors. If you encounter out-of-memory issues, try increasing `zero_stage`: + +```yaml +multi_stage: + zero_stage: 2 +``` + +Increased sharding reduces memory consumption but adds communication overhead between GPUs or nodes. Before increasing `zero_stage`, you might first try lowering the micro batch size or sequence length, since this usually incurs less overhead. + +You'll likely iterate between adjusting `zero_stage`, micro batch size, and sequence length to find the optimal balance of memory usage and training throughput. If these adjustments don't resolve your issue, or you're unsatisfied with tradeoffs like sequence length versus throughput, you may need to reconsider your broader parallelism strategy. This includes adjusting tensor parallelism, pipeline parallelism, or sequence data parallelism. That topic is covered in greater depth in the [Parallelism Guide](parallelism.md). + +## Expert Options + +Beyond `zero_stage`, Fast-LLM offers additional multi-stage settings for fine-tuning. These advanced options typically don't need manual adjustment. Change them only if you're certain about your goals and tradeoffs. + +### Buffers + +When gradients or weights are sharded, Fast-LLM accumulates partial results in shared *buffers* during forward and backward passes. These buffers reduce communication overhead by batching gradient or weight updates across GPUs or nodes. + +By default, Fast-LLM automatically determines buffer counts based on your `zero_stage` setting: + +- `num_grad_buffers`: + - `2` if `zero_stage >= 2` + - `1` otherwise +- `num_weight_buffers`: + - `2` if `zero_stage == 3` + - `1` otherwise + +If you want explicit control, you can override these values: + +```yaml +multi_stage: + num_grad_buffers: 3 + num_weight_buffers: 2 +``` + +For example, increasing `num_grad_buffers` to `3` or `4` will decrease inter-GPU communication frequency, potentially improving throughput—provided sufficient GPU memory is available. + +### Stage Layout Control + +You can adjust how layers and pipeline stages map onto GPUs or nodes: + +```yaml +multi_stage: + layers_per_stage: 1.0 + stages_per_pipeline_stage: 1 +``` + +Defaults work well in most cases: + +- **`layers_per_stage`**: Determines the number of layers per stage. Defaults to `1.0` (one layer per stage). Increase it to reduce inter-stage communication or decrease it for better load balancing. Fractional values are allowed. + +- **`stages_per_pipeline_stage`**: Specifies how many stages run per pipeline worker. This setting is relevant only when pipeline parallelism is active. Default is `1`. Increase to assign multiple stages to the same pipeline worker, potentially simplifying communication patterns at the cost of flexibility in load distribution. diff --git a/docs/user_guide/parallelism.md b/docs/user_guide/parallelism.md new file mode 100644 index 00000000..406908cd --- /dev/null +++ b/docs/user_guide/parallelism.md @@ -0,0 +1,7 @@ +--- +title: Parallelism +--- + +!!! warning + + Looking for the parallelism guide? It's on its way, come back soon! diff --git a/mkdocs.yaml b/mkdocs.yaml index 2439d55a..a080bc83 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -176,6 +176,8 @@ nav: - Reference: - User Guide: - Configuration: user_guide/configuration.md + - Multi-Stage: user_guide/multi-stage.md + - Parallelism: user_guide/parallelism.md - Developer Guide: - Configuration: developer_guide/configuration.md - Model: From 120df7364358d8f2780a2de704ffcfdecb816443 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Wed, 16 Apr 2025 20:44:11 -0400 Subject: [PATCH 2/3] address comments --- docs/user_guide/multi-stage.md | 52 +++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/docs/user_guide/multi-stage.md b/docs/user_guide/multi-stage.md index 87447440..b0abc1b0 100644 --- a/docs/user_guide/multi-stage.md +++ b/docs/user_guide/multi-stage.md @@ -2,11 +2,13 @@ Fast-LLM trains large models by splitting them into *stages*, each running on a separate GPU or node. It reduces memory usage by distributing (or *sharding*) model state (weights, gradients, or optimizer states) across devices. +A *stage* refers to a logical partition of a model, typically containing a subset of layers or computational steps. Each stage runs independently on its own GPU or node. + This guide explains how to configure multi-stage training for both common and advanced use cases. ## ZeRO-Stage Sharding -Fast-LLM uses ZeRO-style sharding to partition model state efficiently across GPUs. This differs from pipeline parallelism, which partitions model computation into sequential pipeline stages. +Fast-LLM uses ZeRO-style sharding to reduce memory usage by partitioning model state (such as weights, gradients, and optimizer states) across GPUs that would otherwise maintain full replicas in data parallelism. This is compatible with and complementary to model-parallel techniques like pipeline and tensor parallelism. The primary setting for ZeRO sharding is `zero_stage` in your configuration: @@ -17,14 +19,25 @@ multi_stage: The following table summarizes the behavior of `zero_stage`: -| `zero_stage` | Weights | Gradients | Optimizer States | Communication overhead | -| ------------- | ---------- | ---------- | ---------------- | ------------------------------------------------------------ | -| `1` (default) | Replicated | Replicated | Sharded | Lowest, default choice | -| `2` | Replicated | Sharded | Sharded | Moderate, saves more memory at additional communication cost | -| `3` | Sharded | Sharded | Sharded | Highest, maximum memory saving with increased communication | +| `zero_stage` | Weights | Gradients | Optimizer States | Communication overhead | +| ------------- | ---------- | ---------- | ---------------- | ------------------------- | +| `1` (default) | Replicated | Replicated | Sharded | Moderate, default choice | +| `2` | Replicated | Sharded | Sharded | Moderate[^1] | +| `3` | Sharded | Sharded | Sharded | High[^2] | + +[^1]: Communication overhead for ZeRO Stage 2 is similar to Stage 1, except during (depth-first) gradient accumulation when additional all-reduce operations occur. +[^2]: Communication overhead for ZeRO Stage 3 is higher than Stage 2, especially during (depth-first) gradient accumulation. Optimizer states are always sharded by default. ZeRO Stage 0 (full replication) is not supported. +While ZeRO Stage 3 introduces the most communication overhead, the practical difference between Stages 1 and 2 is minimal except during gradient accumulation. + +**Recommendation:** + +- **ZeRO Stage 1 (default)**: Ideal for most training scenarios. +- **ZeRO Stage 2**: Useful if gradients cause memory pressure. +- **ZeRO Stage 3**: Useful for very large models exceeding GPU memory. + In general, start with the default (`zero_stage: 1`) and verify if your model trains without memory errors. If you encounter out-of-memory issues, try increasing `zero_stage`: ```yaml @@ -32,9 +45,9 @@ multi_stage: zero_stage: 2 ``` -Increased sharding reduces memory consumption but adds communication overhead between GPUs or nodes. Before increasing `zero_stage`, you might first try lowering the micro batch size or sequence length, since this usually incurs less overhead. +Increasing ZeRO-style sharding reduces memory consumption but may add communication overhead between GPUs or nodes, potentially slowing down training. Before increasing `zero_stage`, first try lowering the micro batch size or sequence length, as this typically incurs less overhead. -You'll likely iterate between adjusting `zero_stage`, micro batch size, and sequence length to find the optimal balance of memory usage and training throughput. If these adjustments don't resolve your issue, or you're unsatisfied with tradeoffs like sequence length versus throughput, you may need to reconsider your broader parallelism strategy. This includes adjusting tensor parallelism, pipeline parallelism, or sequence data parallelism. That topic is covered in greater depth in the [Parallelism Guide](parallelism.md). +You'll likely iterate between adjusting `zero_stage`, micro batch size, and sequence length to find the optimal balance of memory usage and training throughput. If these adjustments don't resolve your issue, or you're unsatisfied with tradeoffs like sequence length versus throughput, reconsider your broader parallelism strategy. This includes adjusting tensor parallelism, pipeline parallelism, or sequence data parallelism, covered in greater depth in the [Parallelism Guide](parallelism.md). ## Expert Options @@ -42,16 +55,11 @@ Beyond `zero_stage`, Fast-LLM offers additional multi-stage settings for fine-tu ### Buffers -When gradients or weights are sharded, Fast-LLM accumulates partial results in shared *buffers* during forward and backward passes. These buffers reduce communication overhead by batching gradient or weight updates across GPUs or nodes. +When gradients or weights are sharded, Fast-LLM accumulates partial results in shared *buffers* during forward and backward passes, separately for gradients and weights. These buffers reduce communication overhead by batching gradient or weight updates across GPUs or nodes. The options `num_grad_buffers` and `num_weight_buffers` control the number of buffers used for gradients and weights, respectively. -By default, Fast-LLM automatically determines buffer counts based on your `zero_stage` setting: +By default, Fast-LLM assigns one gradient and weight buffer per stage, where the number of stages equals the total number of logical partitions (stages) of the model. This enables overlapping communication (e.g., data transfers between GPUs or nodes) with computation (actual processing done by each GPU or node). Lower values (e.g., 1) reduce this overlap, potentially increasing communication waiting times. -- `num_grad_buffers`: - - `2` if `zero_stage >= 2` - - `1` otherwise -- `num_weight_buffers`: - - `2` if `zero_stage == 3` - - `1` otherwise +Increasing `num_grad_buffers` or `num_weight_buffers` provides more room for overlapping communication with compute. This can help in some setups, especially when stages are imbalanced, but generally isn't necessary. Note that this does not reduce total communication; it just shifts when it happens. If you want explicit control, you can override these values: @@ -61,7 +69,7 @@ multi_stage: num_weight_buffers: 2 ``` -For example, increasing `num_grad_buffers` to `3` or `4` will decrease inter-GPU communication frequency, potentially improving throughput—provided sufficient GPU memory is available. +Increasing `num_grad_buffers` to `3` or `4` decreases inter-GPU communication frequency, potentially improving throughput—provided sufficient GPU memory is available. ### Stage Layout Control @@ -75,6 +83,12 @@ multi_stage: Defaults work well in most cases: -- **`layers_per_stage`**: Determines the number of layers per stage. Defaults to `1.0` (one layer per stage). Increase it to reduce inter-stage communication or decrease it for better load balancing. Fractional values are allowed. +- **`layers_per_stage`**: Determines the number of layers per stage. Defaults to `1.0` (one layer per stage). Increase to reduce inter-stage communication or decrease for better load balancing. Fractional values are allowed. + + !!! warning + This setting is supported but hasn't been tested in recent versions. Use with caution. + +- **`stages_per_pipeline_stage`**: Intended to specify how many stages run per pipeline worker when pipeline parallelism is active. -- **`stages_per_pipeline_stage`**: Specifies how many stages run per pipeline worker. This setting is relevant only when pipeline parallelism is active. Default is `1`. Increase to assign multiple stages to the same pipeline worker, potentially simplifying communication patterns at the cost of flexibility in load distribution. + !!! warning + This feature is currently **not implemented**. Changing this value has no effect. From d0cb7529bcba58b09697488f200aaffd6980f571 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Fri, 2 May 2025 13:19:42 -0400 Subject: [PATCH 3/3] address comments --- docs/user_guide/multi-stage.md | 40 ++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/docs/user_guide/multi-stage.md b/docs/user_guide/multi-stage.md index b0abc1b0..13b3abd2 100644 --- a/docs/user_guide/multi-stage.md +++ b/docs/user_guide/multi-stage.md @@ -25,7 +25,7 @@ The following table summarizes the behavior of `zero_stage`: | `2` | Replicated | Sharded | Sharded | Moderate[^1] | | `3` | Sharded | Sharded | Sharded | High[^2] | -[^1]: Communication overhead for ZeRO Stage 2 is similar to Stage 1, except during (depth-first) gradient accumulation when additional all-reduce operations occur. +[^1]: Communication overhead for ZeRO Stage 2 is similar to Stage 1, except during (depth-first) gradient accumulation when additional reduce-scatter operations occur. [^2]: Communication overhead for ZeRO Stage 3 is higher than Stage 2, especially during (depth-first) gradient accumulation. Optimizer states are always sharded by default. ZeRO Stage 0 (full replication) is not supported. @@ -55,21 +55,43 @@ Beyond `zero_stage`, Fast-LLM offers additional multi-stage settings for fine-tu ### Buffers -When gradients or weights are sharded, Fast-LLM accumulates partial results in shared *buffers* during forward and backward passes, separately for gradients and weights. These buffers reduce communication overhead by batching gradient or weight updates across GPUs or nodes. The options `num_grad_buffers` and `num_weight_buffers` control the number of buffers used for gradients and weights, respectively. +Fast-LLM streams sharded tensors through communication buffers, allowing network transfers to overlap with GPU computation. These buffers temporarily store gradient or weight shards during forward and backward passes, improving training throughput by hiding communication latency. -By default, Fast-LLM assigns one gradient and weight buffer per stage, where the number of stages equals the total number of logical partitions (stages) of the model. This enables overlapping communication (e.g., data transfers between GPUs or nodes) with computation (actual processing done by each GPU or node). Lower values (e.g., 1) reduce this overlap, potentially increasing communication waiting times. +Buffers are only relevant when gradients or parameters are actually sharded, depending on your ZeRO stage: -Increasing `num_grad_buffers` or `num_weight_buffers` provides more room for overlapping communication with compute. This can help in some setups, especially when stages are imbalanced, but generally isn't necessary. Note that this does not reduce total communication; it just shifts when it happens. +| Buffer type | Active when | Config key | Default | +| ---------------- | ----------------- | -------------------- | ------- | +| Gradient buffers | ZeRO stage 2 or 3 | `num_grad_buffers` | `1` | +| Weight buffers | ZeRO stage 3 only | `num_weight_buffers` | `1` | -If you want explicit control, you can override these values: +- **Gradient buffers (`num_grad_buffers`)**: + + - Applies when gradients are sharded (ZeRO stages 2 and 3). + - Default (`1`) means no overlap (gradients are communicated layer-by-layer). + - Setting to `2` enables *double-buffering* (second buffer lets gradients transfer asynchronously while the GPU computes the next layer). Values of `3` or more add additional buffers, further increasing overlap at the cost of extra GPU memory per additional buffer. + +- **Weight buffers (`num_weight_buffers`)**: + + - Applies only at ZeRO stage 3 when parameters (weights) are sharded. + - Default (`1`) means no overlap (parameters communicated without asynchronous transfer). + - Setting to `2` enables *double-buffering* for weights (second buffer lets parameter transfers overlap with GPU computation). Higher values add more overlap, consuming additional GPU memory per buffer. + +These buffer settings have no effect when their respective tensors aren't sharded: + +- At ZeRO stage **1**, gradients and parameters are fully replicated, so both `num_grad_buffers` and `num_weight_buffers` are ignored. +- At ZeRO stage **2**, parameters remain replicated; thus, only `num_grad_buffers` is relevant. + +Buffers do not reduce the total amount of communication, Rather, they shift when communication occurs, improving throughput if your training is network-bound and you have spare GPU memory. + +If you want explicit control, you can override these values in your configuration: ```yaml multi_stage: - num_grad_buffers: 3 - num_weight_buffers: 2 + num_grad_buffers: 3 # ZeRO 2 or 3 + num_weight_buffers: 2 # ZeRO 3 only ``` -Increasing `num_grad_buffers` to `3` or `4` decreases inter-GPU communication frequency, potentially improving throughput—provided sufficient GPU memory is available. +Adjust buffers only if you observe GPU utilization drops due to frequent waiting for network transfers, and have GPU memory to spare. Start with defaults (`1`) and tune upward cautiously. ### Stage Layout Control @@ -91,4 +113,4 @@ Defaults work well in most cases: - **`stages_per_pipeline_stage`**: Intended to specify how many stages run per pipeline worker when pipeline parallelism is active. !!! warning - This feature is currently **not implemented**. Changing this value has no effect. + This feature is currently **not implemented**. Changing this value will currently cause a validation error.