---
title: "Let Tensors Fly — Accelerating Large Model Weight Loading with R-Fork"
author: "Ant Group DeepXPU Team, SGLang Team"
date: "December 10, 2025"
previewImg: /images/blog/rfork/preview.svg
---

## TL;DR

> We introduce **Tensor R-Fork** (short for Tensor Remote Fork), a novel weight-loading methodology that leverages an **efficient inter-node GPU-to-GPU data transfer path** to load tensors from a running SGLang instance into a new instance with **zero-copy**.

Our approach provides three key advantages:

1. Significantly accelerates weight-loading performance;
2. Eliminates redundant model weight storage on local disk and/or DRAM;
3. Ensures non-disturbing operation for inference services.

For instance, when applied to the DeepSeek-R1 model, loading time is reduced **from several minutes to mere seconds**, local disk and/or DRAM usage is **reduced by ~600GB**, and inference service quality is maintained during weight transfers.

## Background
As the scale of LLM services and the size of model weights continue to expand, the cold-start time of SGLang instances has become a critical bottleneck for production efficiency. Among the cold-start phases, weight loading remains the most time-consuming.

Taking DeepSeek-R1 as an example, loading weights from local disk typically takes several minutes, while loading from remote storage systems can take tens of minutes. As model sizes continue to grow, the time required for initialization and data transfer will only get worse.

How can we optimize weight-loading performance? The most straightforward approach is to maximize the bottleneck bandwidth in the weight data flow. The data flows of commonly used model-loading approaches in the industry, together with their bottlenecks, are as follows:

| Load weights from | Data Flow | Bottleneck |
|-----------------------|----------------------------------------------------------|------------|
| remote storage center | remote storage -> remote Ethernet NIC -> Ethernet -> local Ethernet NIC -> local DRAM -> local GPU memory | NVMe/Ethernet NIC |
| local disk | disk -> DRAM -> GPU memory | NVMe |
| local DRAM | DRAM -> GPU memory | PCIe |


Can we exploit higher-bandwidth data flows for transferring tensors? The answer is **YES** — InfiniBand offers hundreds of gigabytes per second of throughput. However, the critical question remains: How can we fully leverage InfiniBand's bandwidth for efficient weight loading in SGLang?
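
As a back-of-envelope illustration of why the bottleneck bandwidth dominates, the sketch below estimates the time to move roughly 600 GB of DeepSeek-R1 weights over each path. The effective-bandwidth figures are illustrative assumptions chosen only to match the rough timings quoted above; they are not measurements.

```python
# Rough estimate: load time ≈ weight size / effective bottleneck bandwidth.
# All effective-bandwidth figures below are illustrative assumptions; real
# deployments often see far less than a link's peak rate.
WEIGHT_SIZE_GB = 600  # approximate DeepSeek-R1 weight footprint

effective_bandwidth_gb_per_s = {
    "remote storage over Ethernet": 0.5,
    "local NVMe disk": 3,
    "local DRAM over PCIe": 25,
    "remote GPU via GPU-Direct RDMA over InfiniBand": 200,
}

for path, bw in effective_bandwidth_gb_per_s.items():
    minutes, seconds = divmod(WEIGHT_SIZE_GB / bw, 60)
    print(f"{path:48s} ~{int(minutes)} min {seconds:3.0f} s")
```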

To address this challenge, we have developed **Tensor R-Fork** (short for Tensor Remote Fork), a novel weight-loading framework that reduces DeepSeek-R1 loading time to mere seconds and is already production-ready.

## Design

The core concept of <a href="https://github.com/sgl-project/sglang/blob/main/docs/advanced_features/rfork.md">Tensor R-Fork</a> [0] is to **leverage GPU-Direct RDMA to construct a peer-to-peer (P2P) weight storage architecture.**

Data transfer with traditional methods is slow because every path contains a bottleneck whose bandwidth is far lower than InfiniBand's.
From the data-flow analysis, we observe that weight tensors already reside on each GPU and can be transmitted directly between nodes via GPU-Direct RDMA.

To maximize utilization of the InfiniBand NIC's bandwidth, we design a per-GPU-pair data transfer strategy: a local GPU directly transfers data to/from its paired remote GPU. This design bypasses the PCIe bottleneck between GPU and CPU, enabling high-throughput communication without relying on the CPU or host memory.
The data flow of loading weights from a remote SGLang instance is as follows:

<img src="/images/blog/rfork/design.svg" style="display:block; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 100%; image-orientation: none;" />


## Implementation

To make every running instance act as a source of model weights for any new instance that requires the same model, while minimizing (or ideally eliminating) disruption to the inference services of running instances, we implemented the framework with two backend options: NCCL and TransferEngine. Consider a running instance A (referred to as the source instance) and a new instance B to be booted (the destination instance). Below, we explain the weight-transfer mechanisms of these two backends in detail.

### NCCL backend

When using <a href="https://github.com/sgl-project/sglang/pull/8215">NCCL</a> as the backend [1], the process involves two stages:
1. Establishing communication groups between source and destination instances.
2. Transferring weights from the source instance to the destination instance via these groups.

During initialization, the destination instance sends an HTTP request to the designated source instance to initiate communication-group creation. Each TPWorker (tensor parallel worker) of the destination instance establishes an NCCL communication group with its corresponding TPWorker of the source instance (i.e., source rank 0 pairs with destination rank 0, and so on). Each communication group consists of exactly two members: the source TPWorker and the destination TPWorker.

Once the communication groups are established, each source TPWorker broadcasts its weight tensors, which reside in GPU memory, through its group using NCCL broadcast. The destination TPWorker receives the weights directly into its GPU memory without any intermediate copies.
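
The following is a minimal sketch of this pattern using plain `torch.distributed` rather than SGLang's internal group management; the function and parameter names are placeholders for illustration.

```python
import torch
import torch.distributed as dist

def transfer_weights(role: str, seed_ip: str, group_port: int,
                     gpu_id: int, weights: dict) -> None:
    """Form a two-member NCCL group for one TP-rank pair and broadcast weights.

    Simplified sketch, not SGLang's actual code: `role` is "source" on the
    running instance and "destination" on the new one; `weights` maps tensor
    names to GPU tensors (already populated on the source, pre-allocated and
    empty on the destination).
    """
    torch.cuda.set_device(gpu_id)
    rank = 0 if role == "source" else 1
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{seed_ip}:{group_port}",  # one port per TP-rank pair
        world_size=2,
        rank=rank,
    )
    for _, tensor in weights.items():
        # NCCL broadcast launches a CUDA kernel on BOTH sides -- this is the
        # source-side interference discussed below.
        dist.broadcast(tensor, src=0)
    dist.destroy_process_group()
```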

While NCCL serves as a Tensor R-Fork backend by leveraging GPU-Direct RDMA, it has a critical limitation: weight transfer disrupts the source instance's inference service, due to two key factors:
1. **Communication group establishment**: The source instance must actively participate in creating communication groups.
2. **CUDA kernel interference**: The NCCL broadcast mechanism triggers CUDA kernel execution, which competes for GPU resources and introduces latency spikes during generation tasks.

### TransferEngine backend

To achieve non-disturbing weight transfer, we introduce an alternative backend, <a href="https://github.com/sgl-project/sglang/pull/13125">TransferEngine</a>, which leverages GPU-Direct RDMA for efficient data movement [2]. TransferEngine (TE) is a lightweight RDMA-based transfer runtime that runs alongside each TPWorker on the source instance and exposes GPU-resident weight tensors to remote readers without invoking CUDA kernels on the source.

During source SGLang instance initialization:
1. Each TPWorker spawns a TransferEngine instance.
2. TransferEngine registers the GPU memory addresses of its weights with the RDMA channel.

When initializing the destination instance:
1. It sends an HTTP request to retrieve the source instance's TransferEngine metadata, including RDMA keys mapped to the corresponding GPU memory addresses.
2. Using these RDMA keys, the destination instance directly loads weights from the source's GPU memory without interrupting the source instance's ongoing services, as sketched below.
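
A conceptual sketch of the two sides is shown below. The handle `te` and its methods `register` and `rdma_read`, as well as the `/te_metadata` endpoint, are hypothetical placeholders that only illustrate the control flow; the actual implementation builds on the Mooncake TransferEngine APIs listed in [5].

```python
import requests

# --- Source instance: executed once per TPWorker during initialization ---
def publish_weight_metadata(te, weights: dict) -> dict:
    """Register every GPU-resident weight tensor with the RDMA channel.

    `te.register(addr, nbytes)` is a hypothetical call standing in for RDMA
    memory registration; it returns a remote key (rkey) for the region.
    """
    metadata = {}
    for name, tensor in weights.items():
        addr = tensor.data_ptr()
        nbytes = tensor.numel() * tensor.element_size()
        metadata[name] = {"addr": addr, "len": nbytes,
                          "rkey": te.register(addr, nbytes)}
    return metadata  # exposed to destination instances over HTTP

# --- Destination instance: executed during boot ---
def load_from_seed(te, seed_url: str, weights: dict) -> None:
    """Fetch the seed's metadata, then RDMA-READ each tensor into local GPU memory."""
    remote = requests.get(f"{seed_url}/te_metadata").json()  # hypothetical endpoint
    for name, tensor in weights.items():
        entry = remote[name]
        # NIC-driven one-sided read: no CUDA kernel runs on the source GPU.
        te.rdma_read(
            remote_addr=entry["addr"],
            rkey=entry["rkey"],
            local_addr=tensor.data_ptr(),
            nbytes=entry["len"],
        )
```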

*Want to learn more about TransferEngine? You are more than welcome to check the **TransferEngine** section in the Appendix 🚀*

### NCCL vs. TransferEngine

| | NCCL | TransferEngine |
|----------------------|--------------------------------|----------------|
| Deployment Complexity | ✅ No additional dependency. | ❌ The additional library `mooncake` is required. |
| Overhead of Transfer Setup | ✅ Building communication groups takes hundreds of milliseconds. | ➖ Registering memory regions with the RDMA channel may take several seconds, but this can be overlapped with other initialization phases. |
| Non-disturbing to GPU Workload | ❌ Tensor transfer launches CUDA kernels on the source GPU. | ✅ No CUDA kernels are launched for transferring weights. |

## How to Use

For detailed usage, please refer to the <a href="https://github.com/sgl-project/sglang/blob/main/docs/advanced_features/rfork.md">R-Fork documentation</a>.

### Use NCCL as the backend

```shell
python -m sglang.launch_server [args] \
--load-format remote_instance \
--remote-instance-weight-loader-seed-instance-ip [seed_instance_ip] \
--remote-instance-weight-loader-seed-instance-service-port [seed_instance_service_port] \
--remote-instance-weight-loader-send-weights-group-ports [send_weights_nccl_group_ports_list] \
--remote-instance-weight-loader-backend nccl # optional, default is "nccl"
```

### Use TransferEngine as the backend

```shell
python -m sglang.launch_server [args] \
--load-format remote_instance \
--remote-instance-weight-loader-seed-instance-ip [seed_instance_ip] \
--remote-instance-weight-loader-seed-instance-service-port [seed_instance_service_port] \
--remote-instance-weight-loader-backend transfer_engine
```

## Performance

We evaluated the performance of launching a new SGLang instance equipped with eight NVIDIA H20 GPUs, while loading the DeepSeek-R1 model from different sources.

<img src="/images/blog/rfork/performance.svg" style="display:block; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 100%; image-orientation: none;" />

Registering the memory region can be overlapped with other initialization phases to further optimize total boot-up time.
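
One way to realize this overlap is sketched below (not the actual SGLang code; `te.register` is the same hypothetical registration call as in the TransferEngine sketch above): run memory registration on a background thread while the rest of initialization proceeds, and only join before the first transfer is served.

```python
import threading

def boot_with_overlapped_registration(te, weights: dict, init_other_components) -> None:
    """Hide the multi-second RDMA memory registration behind other init work."""
    def register_all():
        for tensor in weights.values():
            te.register(tensor.data_ptr(), tensor.numel() * tensor.element_size())

    reg_thread = threading.Thread(target=register_all, daemon=True)
    reg_thread.start()        # slow register_mr calls run in the background
    init_other_components()   # e.g. tokenizer, scheduler, CUDA graph capture
    reg_thread.join()         # registration must complete before serving transfers
```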

## Industrial Practice

In the previous sections, we demonstrated how to manually configure seed instances for Tensor R-Fork within SGLang server arguments. However, this manual approach is impractical in real-world industrial deployment, where identifying available seed instances requires significant operational overhead.

To address this challenge, we propose the **<a href="https://github.com/sgl-project/sglang/issues/12910">Tensor R-Fork Planner</a>** [4], a cluster scheduler designed to orchestrate source-instance metadata. The Planner tracks critical information, including:
1. **Model compatibility**: Which model is currently running on the instance.
2. **Parallelism configuration**: The parallel strategy (e.g., tensor parallelism, pipeline parallelism) employed.
3. **Service health status**: Whether the instance is healthy and suitable as a seed instance.

Upon completion of its initialization, each instance registers itself with the Planner, providing its model metadata and parallelism configuration. When a new instance boots up, it first queries the Planner to identify an eligible seed instance that matches both its model and parallelism strategy. If a compatible seed instance is found, the new instance loads weights directly from the seed; otherwise, it falls back to the default load format.
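
A minimal sketch of such a registry is shown below; the data model and method names are hypothetical and only illustrate the register/query flow, while the actual Planner design is discussed in the RFC [4].

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SeedInstance:
    address: str          # ip:service_port of a running SGLang instance
    model: str            # e.g. "deepseek-ai/DeepSeek-R1"
    parallelism: str      # e.g. "tp8"; must match exactly to be a valid seed
    healthy: bool = True

class RForkPlanner:
    """Tracks which running instances can serve as weight seeds."""

    def __init__(self):
        self._instances: list[SeedInstance] = []

    def register(self, inst: SeedInstance) -> None:
        # Called by an instance once its own initialization has completed.
        self._instances.append(inst)

    def find_seed(self, model: str, parallelism: str) -> Optional[SeedInstance]:
        # A new instance queries this before choosing its load format.
        for inst in self._instances:
            if inst.healthy and inst.model == model and inst.parallelism == parallelism:
                return inst
        return None  # caller falls back to the default load format
```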

## Future Work

The practice of R-Fork opens up further possibilities: its key concept is that every SGLang instance can act as a data storage center for other instances. Starting from weight tensors, we will manage additional tensors through the Tensor R-Fork mechanism in the future, allowing GPU clusters to function not only as computing centers but also as storage centers.

## Acknowledgements

**Ant Group DeepXPU Team**: Anqi Shen, Tianyu Zhou, Zehuan Li, Tiwei Bie, Mingliang Gong, Jianfeng Tan

**SGLang Team**: Chenyang Zhao, Liangsheng Yin, Lianmin Zheng

**TransferEngine Team**: Teng Ma, Feng Ren, Shangming Cai

## Reference

[0] Tensor R-Fork documentation: <a href="https://github.com/sgl-project/sglang/blob/main/docs/advanced_features/rfork.md">Documentation</a>
[1] Tensor R-Fork with NCCL backend: <a href="https://github.com/sgl-project/sglang/pull/8215">PR#8215</a>
[2] Tensor R-Fork with TransferEngine backend: <a href="https://github.com/sgl-project/sglang/pull/13125">PR#13125</a>
[3] Concurrent weight loading from disk: <a href="https://github.com/sgl-project/sglang/pull/7943">PR#7943</a>
[4] Tensor R-Fork Planner SGLang RFC: <a href="https://github.com/sgl-project/sglang/issues/12910">Issue#12910</a>
[5] TransferEngine design: <a href="https://kvcache-ai.github.io/Mooncake/design/transfer-engine.html">https://kvcache-ai.github.io/Mooncake/design/transfer-engine.html</a>; TransferEngine APIs: <a href="https://kvcache-ai.github.io/Mooncake/python-api-reference/transfer-engine.html">https://kvcache-ai.github.io/Mooncake/python-api-reference/transfer-engine.html</a>

## Appendix

### TransferEngine

Key advantages provided by TransferEngine[5]:
* **Multi-backend support**: TE supports multiple backends, including RDMA (with GPUDirect), NVLink, GDS, and TCP. It intelligently selects the best backend per request so that the highest performance can be achieved.
* **Direct RDMA reads**: using the published addresses and rkeys, the destination performs RDMA operations (typically RDMA READ) directly into its own pre-allocated GPU buffers, leveraging GPU-Direct RDMA so that no host-device or device-host intermediate copies are required.
* **Non-disturbing**: TE performs pure NIC-driven transfers that avoid launching CUDA kernels on the source GPU.
* **Lifecycle & housekeeping**: TE maintains the lifetime of registrations until tensors are evicted or the process exits.
* **Concurrency & flow control**: TE coordinates concurrent reads (from one or many destinations) and can apply throttling or rate limits to avoid saturating the instance's NIC or impacting inference latency.

Known limitation in the current TransferEngine implementation:
* **Memory registration (register_mr) is slow**: <u>This is due to the RDMA driver</u>. If you have any insights or solutions to this issue, we would be truly grateful to hear from you. We value diverse perspectives and are keen to explore innovative approaches together.

