Skip to content

Commit 1377912

Browse files
authored
Merge pull request #3 from KuntaiDu/yihua-kv-pipe
Optimize the KV transfer pipe implementation
2 parents 9f81f41 + bb86588 commit 1377912

File tree

1 file changed

+166
-81
lines changed

1 file changed

+166
-81
lines changed
Lines changed: 166 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,12 @@
1-
2-
from vllm.distributed.group_coordinator import GroupCoordinator
3-
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
4-
from torch.distributed import Backend, ProcessGroup
1+
from torch.distributed import Backend
52
import torch
6-
from typing import Any, Dict, List, Optional, Tuple, Union
3+
from typing import List, Optional, Union
74
import threading
85
from concurrent.futures import ThreadPoolExecutor
96
import time
10-
import threading
11-
from collections import namedtuple
12-
from typing import Dict, Any, Tuple, List
13-
import pickle
147

158
from vllm.logger import init_logger
169

17-
1810
logger = init_logger(__name__)
1911

2012

@@ -52,34 +44,32 @@ def __init__(self, message):
5244
self.message = message
5345
super().__init__(self.message)
5446

55-
class TorchDistributedPipe(KVPipeBase):
56-
47+
48+
class TorchDistributedPipe:
49+
METADATA_LENGTH = 16
50+
MAX_TENSOR_DIMENSIONS = 14
51+
METADATA_DTYPE = torch.int64
52+
5753
def __init__(
5854
self,
5955
group_ranks: List[List[int]],
6056
local_rank: int,
61-
torch_distributed_backend: Union[str, Backend]
57+
torch_distributed_backend: Union[str, Backend],
6258
):
63-
6459
self.rank = torch.distributed.get_rank()
6560
self.local_rank = local_rank
6661
self.device_group = None
67-
self.cpu_group = None
6862

6963
for ranks in group_ranks:
7064
device_group = torch.distributed.new_group(
71-
ranks, backend=torch_distributed_backend)
72-
# a group with `gloo` backend, to allow direct coordination between
73-
# processes through the CPU.
74-
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
65+
ranks, backend=torch_distributed_backend
66+
)
7567
if self.rank in ranks:
7668
self.ranks = ranks
7769
self.world_size = len(ranks)
7870
self.rank_in_group = ranks.index(self.rank)
7971
self.device_group = device_group
80-
self.cpu_group = cpu_group
8172

82-
assert self.cpu_group is not None
8373
assert self.device_group is not None
8474
assert self.rank_in_group <= 1
8575

@@ -88,120 +78,215 @@ def __init__(
8878
else:
8979
self.device = torch.device("cpu")
9080

91-
# if turned on, will use CPU-based communication to perform a series of sanity check.
92-
# but it adds ~5ms delay, so please turn it off in performance-demanding usecases (e.g. disaggregated prefill)
93-
self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) %
94-
self.world_size]
95-
self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) %
96-
self.world_size]
81+
self.target_rank_for_send = self.ranks[
82+
(self.rank_in_group + 1) % self.world_size
83+
]
84+
self.target_rank_for_recv = self.ranks[
85+
(self.rank_in_group - 1) % self.world_size
86+
]
87+
88+
# FIXME: why we need this?
9789
torch.set_default_device(self.device)
9890

99-
self.kv_sending_thread = None
91+
self.transport_thread = None
10092
self.buffer_size = 0
10193
self.buffer_size_lock = threading.Lock()
10294

103-
self.none_tensor = torch.tensor([NONE_INT]).to(self.device)
104-
self.broken = False
95+
self.none_tensor = torch.tensor([NONE_INT], device=self.device)
96+
97+
# On-device tensors to be reused for recv
98+
self.rcv_metadata_buffer = torch.zeros(
99+
self.METADATA_LENGTH, dtype=self.METADATA_DTYPE, device=self.device
100+
)
101+
102+
def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor:
103+
"""
104+
Create the metadata on based on the input tensor, and move it to GPU.
105+
The metadata's length is `TorchDistributedPipe.METADATA_LENGTH`.
105106
106-
107-
def quick_send(self, tensor):
107+
Currently, the metadata is a int64 tensor and it includes dtype, number
108+
of dimensions, and the shape information of the input tensor.
108109
109-
group = self.device_group
110110
111-
# NCCL is NOT fully duplex
112-
# so CPU communication is ALWAYS necessary
113-
torch.distributed.send_object_list(
114-
[tensor.dtype, tensor.shape, str(tensor.device)],
115-
dst=self.target_rank_for_send,
116-
group=self.cpu_group
111+
The information follows the layout below:
112+
- metadata[0] -- dtype
113+
- metadata[1] -- number of dimensions
114+
- metadata[2 : 2+ndims] -- the shape of the input tensor
115+
116+
Parameters:
117+
- tensor: the input tensor
118+
119+
Returns:
120+
- metadata: the metadata tensor, on self.device
121+
"""
122+
buffer = torch.empty(self.METADATA_LENGTH, dtype=self.METADATA_DTYPE)
123+
buffer[0] = DTYPE2INT[tensor.dtype]
124+
ndims = len(tensor.shape)
125+
buffer[1] = len(tensor.shape)
126+
buffer[2 : 2 + ndims] = torch.tensor(
127+
tensor.shape, dtype=self.METADATA_DTYPE
117128
)
129+
return buffer.to(self.device)
130+
131+
def _prepare_recv_buffer(
132+
self, d_metadata_buffer: torch.Tensor
133+
) -> torch.Tensor:
134+
"""
135+
Create a buffer to receive the tensor based on the metadata.
136+
137+
Parameters:
138+
- d_metadata_buffer: the metadata tensor on self.device
139+
140+
Returns:
141+
- buffer: the buffer tensor to receive the tensor, on self.device
142+
"""
143+
h_buffer = d_metadata_buffer.cpu().numpy()
144+
dtype = INT2DTYPE[h_buffer[0]]
145+
ndims = h_buffer[1]
146+
shape = tuple(h_buffer[2 : 2 + ndims])
147+
return torch.empty(shape, dtype=dtype, device=self.device)
118148

149+
def _send_metadata(self, d_metadata_buffer: torch.Tensor):
150+
"""
151+
Send the metadata buffer to the target rank.
152+
"""
119153
torch.distributed.send(
120-
tensor,
154+
d_metadata_buffer,
121155
dst=self.target_rank_for_send,
122-
group=self.device_group
156+
group=self.device_group,
123157
)
124158

159+
def _recv_metadata(self) -> torch.Tensor:
160+
"""
161+
Receive the metadata buffer from the target rank.
125162
126-
def quick_recv(self):
163+
Returns:
164+
- metadata_buffer: the metadata buffer tensor, on self.device
127165
128-
# NCCL is NOT fully duplex
129-
# so CPU communication is necessary
130-
metadata = [None, None, None]
131-
torch.distributed.recv_object_list(
132-
metadata,
166+
Note:
167+
The current implementation uses the assumption that there is no
168+
race conditions during sending/receiving. Therefore, the metadata
169+
buffer can be reused
170+
"""
171+
torch.distributed.recv(
172+
self.rcv_metadata_buffer,
133173
src=self.target_rank_for_recv,
134-
group=self.cpu_group
174+
group=self.device_group,
135175
)
136-
137-
dtype, shape, device = metadata
138-
if 'cuda' in device:
139-
device = self.device
140-
else:
141-
device = 'cpu'
142-
buffer = torch.zeros(shape, dtype=dtype).to(device)
143-
176+
return self.rcv_metadata_buffer
177+
178+
def _send_impl(self, tensor):
179+
"""
180+
The actual implementation of sending the tensor to the target rank.
181+
This function will first send the metadata, and then send the tensor.
182+
183+
Parameters:
184+
- tensor: the input tensor to be sent
185+
"""
186+
187+
metadata = self._make_metadata(tensor)
188+
self._send_metadata(metadata)
189+
190+
torch.distributed.send(
191+
tensor, dst=self.target_rank_for_send, group=self.device_group
192+
)
193+
194+
def _recv_impl(self) -> torch.Tensor:
195+
"""
196+
The actual implementation of receiving the tensor from the target rank.
197+
This function will first receive the metadata, then receive the tensor.
198+
199+
This function will block if there is no tensor to receive.
200+
201+
Returns:
202+
- buffer: the received tensor, on self.device
203+
"""
204+
d_metadata = self._recv_metadata()
205+
buffer = self._prepare_recv_buffer(d_metadata)
206+
144207
torch.distributed.recv(
145-
buffer,
146-
src=self.target_rank_for_recv,
147-
group=self.device_group
208+
buffer, src=self.target_rank_for_recv, group=self.device_group
148209
)
149-
return buffer
150-
151210

152-
153-
def send_tensor_wrapper(self, tensor) -> None:
211+
return buffer
154212

213+
def send_tensor_wrapper(self, tensor):
155214
try:
215+
"""Wrapper for send_tensor_dict"""
156216
tensor_size = tensor.element_size() * tensor.numel()
157-
self.quick_send(tensor)
158-
217+
self._send_impl(tensor)
218+
159219
with self.buffer_size_lock:
160220
self.buffer_size = self.buffer_size - tensor_size
161221
except Exception as e:
162222
logger.error("Encountering exception in KV sending thread")
163223
logger.error("%s", e)
164-
224+
165225
def block_if_full(self):
166-
226+
"""
227+
Block the current thread if the buffer size is larger than 1e9.
228+
"""
229+
# TODO: replace this 1e9 with a configurable parameter or a constant
167230
while self.buffer_size > 1e9:
168231
logger.debug("KV cache transfer pipe is full. Waiting...")
169232
time.sleep(0.05)
170233

171-
def send_tensor(self,
172-
tensor: Optional[torch.Tensor]) -> None:
234+
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
173235
"""
174236
Sends a tensor to the destination rank in a non-blocking way.
175237
Flow: send tensor dim -- send tensor shape -- send tensor data
176238
"""
177-
178-
if self.kv_sending_thread is None:
179-
self.kv_sending_thread = ThreadPoolExecutor(max_workers=1)
239+
240+
if self.transport_thread is None:
241+
self.transport_thread = ThreadPoolExecutor(max_workers=1)
180242

181243
if tensor is None:
182244
tensor = self.none_tensor
183245
tensor_size = 0
184246
else:
185247
tensor_size = tensor.element_size() * tensor.numel()
186248

249+
assert (
250+
0 < len(tensor.shape) < self.MAX_TENSOR_DIMENSIONS
251+
), f"Only support dimensions within 1-{self.MAX_TENSOR_DIMENSIONS}"
252+
187253
self.block_if_full()
188254

189255
with self.buffer_size_lock:
256+
# print("Remaining size:", self.buffer_size)
190257
self.buffer_size = self.buffer_size + tensor_size
191-
258+
192259
# prepare the metadata before sending the tensor.
193-
self.kv_sending_thread.submit(
194-
self.send_tensor_wrapper,
195-
tensor)
196-
260+
self.transport_thread.submit(
261+
self.send_tensor_wrapper,
262+
tensor,
263+
)
264+
197265
def recv_tensor(self) -> Optional[torch.Tensor]:
198266
"""Receives a tensor from the src rank. Blocking."""
199-
200-
tensor = self.quick_recv()
267+
268+
if self.transport_thread is None:
269+
self.transport_thread = ThreadPoolExecutor(max_workers=1)
270+
271+
future = self.transport_thread.submit(self._recv_impl)
272+
273+
try:
274+
tensor = future.result()
275+
except Exception as e:
276+
logger.error("Encountering exception in KV receiving thread")
277+
logger.error("%s", e)
278+
201279
if tensor.numel() == 1 and tensor.item() == NONE_INT:
202280
return None
203281
else:
204282
return tensor
205-
206283

207-
284+
def close(self):
285+
"""
286+
Close the pipe and release the resources.
287+
"""
288+
if (
289+
hasattr(self, "transport_thread")
290+
and self.transport_thread is not None
291+
):
292+
self.transport_thread.shutdown()

0 commit comments

Comments
 (0)