- 
- from vllm.distributed.group_coordinator import GroupCoordinator
- from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
- from torch.distributed import Backend, ProcessGroup
+ from torch.distributed import Backend
  import torch
- from typing import Any, Dict, List, Optional, Tuple, Union
+ from typing import List, Optional, Union
  import threading
  from concurrent.futures import ThreadPoolExecutor
  import time
- import threading
- from collections import namedtuple
- from typing import Dict, Any, Tuple, List
- import pickle
  
  from vllm.logger import init_logger
  
- 
  logger = init_logger(__name__)
  
  
@@ -52,34 +44,32 @@ def __init__(self, message):
          self.message = message
          super().__init__(self.message)
  
- class TorchDistributedPipe(KVPipeBase):
- 
+ 
+ class TorchDistributedPipe:
+     METADATA_LENGTH = 16
+     MAX_TENSOR_DIMENSIONS = 14
+     METADATA_DTYPE = torch.int64
+ 
      def __init__(
          self,
          group_ranks: List[List[int]],
          local_rank: int,
-         torch_distributed_backend: Union[str, Backend]
+         torch_distributed_backend: Union[str, Backend],
      ):
- 
          self.rank = torch.distributed.get_rank()
          self.local_rank = local_rank
          self.device_group = None
-         self.cpu_group = None
  
          for ranks in group_ranks:
              device_group = torch.distributed.new_group(
-                 ranks, backend=torch_distributed_backend)
-             # a group with `gloo` backend, to allow direct coordination between
-             # processes through the CPU.
-             cpu_group = torch.distributed.new_group(ranks, backend="gloo")
+                 ranks, backend=torch_distributed_backend
+             )
              if self.rank in ranks:
                  self.ranks = ranks
                  self.world_size = len(ranks)
                  self.rank_in_group = ranks.index(self.rank)
                  self.device_group = device_group
-                 self.cpu_group = cpu_group
  
-         assert self.cpu_group is not None
          assert self.device_group is not None
          assert self.rank_in_group <= 1
  
@@ -88,120 +78,215 @@ def __init__(
          else:
              self.device = torch.device("cpu")
  
-         # if turned on, will use CPU-based communication to perform a series of sanity check.
-         # but it adds ~5ms delay, so please turn it off in performance-demanding usecases (e.g. disaggregated prefill)
-         self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) %
-                                                self.world_size]
-         self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) %
-                                                self.world_size]
+         self.target_rank_for_send = self.ranks[
+             (self.rank_in_group + 1) % self.world_size
+         ]
+         self.target_rank_for_recv = self.ranks[
+             (self.rank_in_group - 1) % self.world_size
+         ]
+ 
+         # FIXME: why do we need this?
          torch.set_default_device(self.device)
  
-         self.kv_sending_thread = None
+         self.transport_thread = None
          self.buffer_size = 0
          self.buffer_size_lock = threading.Lock()
  
-         self.none_tensor = torch.tensor([NONE_INT]).to(self.device)
-         self.broken = False
+         self.none_tensor = torch.tensor([NONE_INT], device=self.device)
+ 
+         # On-device tensor to be reused for recv
+         self.rcv_metadata_buffer = torch.zeros(
+             self.METADATA_LENGTH, dtype=self.METADATA_DTYPE, device=self.device
+         )
+ 
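The peer-rank arithmetic above always resolves to the single other rank for the pairwise (two-rank) groups this pipe is built around (note the `assert self.rank_in_group <= 1`). A standalone sketch; the rank values are illustrative only:

```python
# For a two-rank group, e.g. global ranks [3, 7], seen from rank 3:
ranks, rank_in_group = [3, 7], 0
world_size = len(ranks)

send_to = ranks[(rank_in_group + 1) % world_size]    # -> 7
recv_from = ranks[(rank_in_group - 1) % world_size]  # -> 7; Python's % is non-negative

assert send_to == recv_from == 7  # both point at the peer
```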
+     def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor:
+         """
+         Create the metadata based on the input tensor, and move it to GPU.
+         The metadata's length is `TorchDistributedPipe.METADATA_LENGTH`.
  
- 
-     def quick_send(self, tensor):
+         Currently, the metadata is an int64 tensor and it includes the dtype,
+         number of dimensions, and the shape information of the input tensor.
  
-         group = self.device_group
  
-         # NCCL is NOT fully duplex
-         # so CPU communication is ALWAYS necessary
-         torch.distributed.send_object_list(
-             [tensor.dtype, tensor.shape, str(tensor.device)],
-             dst=self.target_rank_for_send,
-             group=self.cpu_group
+         The information follows the layout below:
+         - metadata[0] -- dtype
+         - metadata[1] -- number of dimensions
+         - metadata[2 : 2+ndims] -- the shape of the input tensor
+ 
+         Parameters:
+         - tensor: the input tensor
+ 
+         Returns:
+         - metadata: the metadata tensor, on self.device
+         """
+         buffer = torch.empty(self.METADATA_LENGTH, dtype=self.METADATA_DTYPE)
+         buffer[0] = DTYPE2INT[tensor.dtype]
+         ndims = len(tensor.shape)
+         buffer[1] = ndims
+         buffer[2 : 2 + ndims] = torch.tensor(
+             tensor.shape, dtype=self.METADATA_DTYPE
          )
+         return buffer.to(self.device)
+ 
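A CPU-only sketch of the metadata round trip performed by `_make_metadata` and `_prepare_recv_buffer` below. The `DTYPE2INT`/`INT2DTYPE` mappings are referenced but not defined in this hunk, so single-entry stand-ins are used here:

```python
import torch

METADATA_LENGTH, METADATA_DTYPE = 16, torch.int64
DTYPE2INT = {torch.float16: 0}  # stand-in; the module defines the real mapping
INT2DTYPE = {0: torch.float16}

t = torch.randn(2, 3, 4, dtype=torch.float16)

# Encode: [dtype id, ndims, shape..., padding] as a fixed-length int64 vector.
meta = torch.zeros(METADATA_LENGTH, dtype=METADATA_DTYPE)
meta[0] = DTYPE2INT[t.dtype]
meta[1] = t.dim()
meta[2:2 + t.dim()] = torch.tensor(t.shape, dtype=METADATA_DTYPE)

# Decode: rebuild an empty receive buffer with the same dtype and shape.
h = meta.numpy()
buffer = torch.empty(tuple(h[2:2 + h[1]]), dtype=INT2DTYPE[h[0]])

assert buffer.shape == t.shape and buffer.dtype == t.dtype
```

A fixed-length metadata tensor is what lets the receiver post one reusable `recv` buffer instead of negotiating message sizes through a separate CPU group, as the removed `send_object_list`-based code did.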
+     def _prepare_recv_buffer(
+         self, d_metadata_buffer: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Create a buffer to receive the tensor based on the metadata.
+ 
+         Parameters:
+         - d_metadata_buffer: the metadata tensor on self.device
+ 
+         Returns:
+         - buffer: the buffer tensor to receive the tensor, on self.device
+         """
+         h_buffer = d_metadata_buffer.cpu().numpy()
+         dtype = INT2DTYPE[h_buffer[0]]
+         ndims = h_buffer[1]
+         shape = tuple(h_buffer[2 : 2 + ndims])
+         return torch.empty(shape, dtype=dtype, device=self.device)
  
+     def _send_metadata(self, d_metadata_buffer: torch.Tensor):
+         """
+         Send the metadata buffer to the target rank.
+         """
          torch.distributed.send(
-             tensor,
+             d_metadata_buffer,
              dst=self.target_rank_for_send,
-             group=self.device_group
+             group=self.device_group,
          )
  
+     def _recv_metadata(self) -> torch.Tensor:
+         """
+         Receive the metadata buffer from the target rank.
  
-     def quick_recv(self):
+         Returns:
+         - metadata_buffer: the metadata buffer tensor, on self.device
  
-         # NCCL is NOT fully duplex
-         # so CPU communication is necessary
-         metadata = [None, None, None]
-         torch.distributed.recv_object_list(
-             metadata,
+         Note:
+             The current implementation assumes that there are no race
+             conditions during sending/receiving, so the metadata buffer
+             can be reused.
+         """
+         torch.distributed.recv(
+             self.rcv_metadata_buffer,
              src=self.target_rank_for_recv,
-             group=self.cpu_group
+             group=self.device_group,
          )
- 
-         dtype, shape, device = metadata
-         if 'cuda' in device:
-             device = self.device
-         else:
-             device = 'cpu'
-         buffer = torch.zeros(shape, dtype=dtype).to(device)
- 
+         return self.rcv_metadata_buffer
+ 
+     def _send_impl(self, tensor):
+         """
+         The actual implementation of sending the tensor to the target rank.
+         This function will first send the metadata, and then send the tensor.
+ 
+         Parameters:
+         - tensor: the input tensor to be sent
+         """
+ 
+         metadata = self._make_metadata(tensor)
+         self._send_metadata(metadata)
+ 
+         torch.distributed.send(
+             tensor, dst=self.target_rank_for_send, group=self.device_group
+         )
+ 
+     def _recv_impl(self) -> torch.Tensor:
+         """
+         The actual implementation of receiving the tensor from the target rank.
+         This function will first receive the metadata, then receive the tensor.
+ 
+         This function will block if there is no tensor to receive.
+ 
+         Returns:
+         - buffer: the received tensor, on self.device
+         """
+         d_metadata = self._recv_metadata()
+         buffer = self._prepare_recv_buffer(d_metadata)
+ 
          torch.distributed.recv(
-             buffer,
-             src=self.target_rank_for_recv,
-             group=self.device_group
+             buffer, src=self.target_rank_for_recv, group=self.device_group
          )
-         return buffer
- 
  
- 
-     def send_tensor_wrapper(self, tensor) -> None:
+         return buffer
  
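Both `send_tensor` and `recv_tensor` below funnel their work through a single-worker `ThreadPoolExecutor`, so transfers execute strictly in submission order even though `send_tensor` returns immediately. A standalone sketch of that ordering guarantee:

```python
from concurrent.futures import ThreadPoolExecutor

order = []
pool = ThreadPoolExecutor(max_workers=1)  # one worker thread => FIFO execution
for i in range(4):
    pool.submit(order.append, i)          # returns immediately, like send_tensor
pool.shutdown(wait=True)                  # like close(); drains the queue

assert order == [0, 1, 2, 3]
```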
+     def send_tensor_wrapper(self, tensor):
+         """Wrapper around _send_impl that also updates the buffer size."""
          try:
              tensor_size = tensor.element_size() * tensor.numel()
-             self.quick_send(tensor)
- 
+             self._send_impl(tensor)
+ 
              with self.buffer_size_lock:
                  self.buffer_size = self.buffer_size - tensor_size
          except Exception as e:
              logger.error("Encountering exception in KV sending thread")
              logger.error("%s", e)
- 
+ 
      def block_if_full(self):
- 
+         """
+         Block the current thread if the buffer size is larger than 1e9.
+         """
+         # TODO: replace this 1e9 with a configurable parameter or a constant
      while self.buffer_size > 1e9:
              logger.debug("KV cache transfer pipe is full. Waiting...")
              time.sleep(0.05)
  
-     def send_tensor(self,
-                     tensor: Optional[torch.Tensor]) -> None:
+     def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
          """
          Sends a tensor to the destination rank in a non-blocking way.
          Flow: send metadata (dtype, ndims, shape) -- send tensor data
          """
- 
-         if self.kv_sending_thread is None:
-             self.kv_sending_thread = ThreadPoolExecutor(max_workers=1)
+ 
+         if self.transport_thread is None:
+             self.transport_thread = ThreadPoolExecutor(max_workers=1)
  
          if tensor is None:
              tensor = self.none_tensor
              tensor_size = 0
          else:
              tensor_size = tensor.element_size() * tensor.numel()
  
+         assert (
+             0 < len(tensor.shape) < self.MAX_TENSOR_DIMENSIONS
+         ), f"Only support dimensions within 1-{self.MAX_TENSOR_DIMENSIONS}"
+ 
          self.block_if_full()
  
          with self.buffer_size_lock:
              self.buffer_size = self.buffer_size + tensor_size
- 
+ 
          # prepare the metadata before sending the tensor.
-         self.kv_sending_thread.submit(
-             self.send_tensor_wrapper,
-             tensor)
- 
+         self.transport_thread.submit(
+             self.send_tensor_wrapper,
+             tensor,
+         )
+ 
      def recv_tensor(self) -> Optional[torch.Tensor]:
          """Receives a tensor from the src rank. Blocking."""
- 
-         tensor = self.quick_recv()
+ 
+         if self.transport_thread is None:
+             self.transport_thread = ThreadPoolExecutor(max_workers=1)
+ 
+         future = self.transport_thread.submit(self._recv_impl)
+ 
+         try:
+             tensor = future.result()
+         except Exception as e:
+             logger.error("Encountering exception in KV receiving thread")
+             logger.error("%s", e)
+             raise e
+ 
          if tensor.numel() == 1 and tensor.item() == NONE_INT:
              return None
          else:
              return tensor
- 
  
- 
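`send_tensor(None)` travels over the wire as the one-element `none_tensor`, and the check at the end of `recv_tensor` maps it back to `None`. A sketch of that decoding; the module defines its own reserved `NONE_INT`, so `-1` here is only a placeholder:

```python
import torch

NONE_INT = -1  # placeholder; the real module reserves its own sentinel value

def decode(t: torch.Tensor):
    # Mirrors the numel()/item() check in recv_tensor above.
    if t.numel() == 1 and t.item() == NONE_INT:
        return None
    return t

assert decode(torch.tensor([NONE_INT])) is None
assert decode(torch.ones(2, 2)).shape == (2, 2)
# Caveat: a genuine 1-element tensor equal to NONE_INT would also decode to None.
```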
+     def close(self):
+         """
+         Close the pipe and release the resources.
+         """
+         if (
+             hasattr(self, "transport_thread")
+             and self.transport_thread is not None
+         ):
+             self.transport_thread.shutdown()
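A hypothetical end-to-end use of the pipe, e.g. launched with `torchrun --nproc_per_node=2`; the backend and group layout are illustrative, not prescribed by this diff:

```python
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # "nccl" for GPU-to-GPU transfers

pipe = TorchDistributedPipe(
    group_ranks=[[0, 1]],                # one pairwise send/recv group
    local_rank=0,
    torch_distributed_backend="gloo",
)

if dist.get_rank() == 0:
    pipe.send_tensor(torch.arange(8, dtype=torch.float32))  # non-blocking
    pipe.send_tensor(None)                                  # sentinel
else:
    print(pipe.recv_tensor())  # tensor([0., 1., ..., 7.])
    print(pipe.recv_tensor())  # None

pipe.close()
```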