Commit 4db6446

Merge pull request #5 from KuntaiDu/kuntai-disagg-refactor

Kuntai disagg refactor

2 parents 1377912 + c5b7232 commit 4db6446

6 files changed: +115 -48 lines changed

tests/kv_transfer/test_lookup_buffer.py

Lines changed: 25 additions & 21 deletions

@@ -7,8 +7,9 @@
 from tqdm import tqdm
 import time
 
+# TODO: the test depends on a lot of fields in the current implementation. We should have a standard interface instead of direct field access.
 
-def test_run(my_rank, buffer):
+def test_run(my_rank, buffer, device):
 
     # buffer should be empty in the beginning
     if my_rank == 0:
@@ -17,46 +18,49 @@ def test_run(my_rank, buffer):
 
 
     # insert
-    tokens = torch.tensor([1,2,3]).to(buffer.pipe.device)
+    tokens = torch.tensor([1,2,3]).to(device)
     roi = (tokens > 0)
     if my_rank == 0:
-        key = 2.0 * torch.ones([5, 6]).to(buffer.pipe.device)
-        value = 3.0 * torch.ones([5, 6]).to(buffer.pipe.device)
+        key = 2.0 * torch.ones([5, 6]).to(device)
+        value = 3.0 * torch.ones([5, 6]).to(device)
 
-        placeholder = torch.tensor([1]).to(buffer.pipe.device)
+        placeholder = torch.tensor([1]).to(device)
 
         buffer.insert(tokens, roi, key, value, placeholder)
+
+    #for i in range(2000):
+    #    print("Here:", i)
+    #    time.sleep(0.01)
     torch.distributed.barrier()
 
     # drop_select
     if my_rank == 1:
         tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi)
         assert torch.allclose(tokens, tok)
         assert torch.allclose(roi, roi_)
-        assert torch.allclose(key, 2.0 * torch.ones([5, 6]))
-        assert torch.allclose(value, 3.0 * torch.ones([5, 6]))
+        assert torch.allclose(key, 2.0 * torch.ones([5, 6], device = device))
+        assert torch.allclose(value, 3.0 * torch.ones([5, 6], device = device))
     torch.distributed.barrier()
 
     if my_rank == 0:
         assert buffer.buffer_size == 0
         assert len(buffer.buffer) == 0
+
+    print("Test run passed!")
 
-
-def stress_test(my_rank, buf):
+def stress_test(my_rank, buf, device):
 
     torch.distributed.barrier()
     torch.manual_seed(100)
 
-    device = buf.pipe.device
-
     reqs = [
         (
             torch.rand(100).to(device), # tokens
             torch.ones(100).bool().to(device), # roi
             torch.rand(100).to(device), # key
             torch.rand(100).to(device), # value
             torch.rand(100).to(device), # hidden
-        ) for i in range(200)]
+        ) for i in tqdm(range(200))]
 
     random.seed(my_rank)
     random.shuffle(reqs)
@@ -86,7 +90,7 @@ def stress_test(my_rank, buf):
         assert torch.allclose(k, k_)
         assert torch.allclose(v, v_)
         assert torch.allclose(h, h_)
-    print('Rand %d done' % my_rank)
+    print('Rank %d done' % my_rank)
    torch.distributed.barrier()
 
 
@@ -101,13 +105,9 @@ def stress_test(my_rank, buf):
    else:
        torch.distributed.send(torch.tensor([n]), 0)
 
+    print("Passed stress test!")
 
 
-
-
-
-
-
 
 if __name__ == "__main__":
 
@@ -123,10 +123,14 @@ def stress_test(my_rank, buf):
 
 
     pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "nccl")
-    buffer = sklb.SimpleKVLookupBuffer(pipe, 170000)
+    cpu_pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "gloo")
+    buffer = sklb.SimpleKVLookupBuffer(cpu_pipe, pipe, 170000)
 
-    test_run(my_rank, buffer)
+    test_run(my_rank, buffer, pipe.device)
 
-    stress_test(my_rank, buffer)
+    stress_test(my_rank, buffer, pipe.device)
 
+    buffer.close()
+    pipe.close()
+    cpu_pipe.close()
     print('Done')
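
Taken together, the test now drives two pipes instead of one: the NCCL pipe carries KV tensors on GPU, while a second gloo pipe carries CPU-side control signals. A minimal sketch of this wiring, assuming the `tdp` and `sklb` module aliases used above and a two-process `torch.distributed` launch (`build_buffer` is a name invented for this sketch, not part of the commit):

    # Assumed aliases, as in the test file above:
    #   tdp  -> vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe
    #   sklb -> vllm.distributed.kv_transfer.kv_lookup_buffer.simple_kv_lookup_buffer
    def build_buffer(my_rank):
        # NCCL pipe moves KV tensors on GPU; gloo pipe stays on CPU for signals.
        data_pipe = tdp.TorchDistributedPipe([[0, 1]], my_rank, "nccl")
        signal_pipe = tdp.TorchDistributedPipe([[0, 1]], my_rank, "gloo")
        # The signal pipe comes first in the new constructor order,
        # followed by the data pipe and the byte-size threshold.
        buffer = sklb.SimpleKVLookupBuffer(signal_pipe, data_pipe, 170000)
        return buffer, data_pipe, signal_pipe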

tests/kv_transfer/test_send_recv.py

Lines changed: 19 additions & 3 deletions

@@ -12,12 +12,28 @@ def test_run(my_rank, pipe):
     x = torch.tensor([1]).to(pipe.device)
     y = torch.tensor([[2., 3., 4., 8.]]).to(pipe.device)
     if my_rank == 0:
-
         pipe.send_tensor(x)
+        print("sent tensor x")
         pipe.send_tensor(y)
+        print("sent tensor y")
+        x2 = pipe.recv_tensor()
+        print("received x2 = ", x2)
+        y2 = pipe.recv_tensor()
+        print("received y2 = ", y2)
+
     else:
-        assert torch.allclose(x, pipe.recv_tensor())
-        assert torch.allclose(y, pipe.recv_tensor())
+        x2 = pipe.recv_tensor()
+        print("received x2 = ", x2)
+        y2 = pipe.recv_tensor()
+        print("received y2 = ", y2)
+        pipe.send_tensor(x)
+        print("sent tensor x")
+        pipe.send_tensor(y)
+        print("sent tensor y")
+
+    assert torch.allclose(x, x2)
+    assert torch.allclose(y, y2)
+
 
 
 def stress_test(my_rank, pipe):
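
The rewritten test makes the exchange bidirectional: rank 0 sends both tensors and then waits for them to come back, while rank 1 receives first and then echoes. With a blocking pipe, this mirrored ordering is what keeps the two ranks from both sitting in recv_tensor() at the same time. A condensed sketch of the pattern, using only the send_tensor()/recv_tensor() API shown above (`echo_round_trip` is a name invented for this sketch):

    def echo_round_trip(my_rank, pipe, tensors):
        # Rank 0: send everything, then collect the echoes.
        # Rank 1: receive everything, then send it all back.
        if my_rank == 0:
            for t in tensors:
                pipe.send_tensor(t)
            return [pipe.recv_tensor() for _ in tensors]
        else:
            got = [pipe.recv_tensor() for _ in tensors]
            for t in got:
                pipe.send_tensor(t)
            return got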

vllm/distributed/kv_transfer/kv_lookup_buffer/base.py

Lines changed: 7 additions & 1 deletion

@@ -15,4 +15,10 @@ def insert(self,
     @abstractmethod
     def drop_select(self, input_tokens, roi) -> Optional[torch.Tensor]:
         raise NotImplementedError
-
+
+    @abstractmethod
+    def close(self):
+        """
+        Close the buffer, release resources.
+        """
+        raise NotImplementedError
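
For orientation, the abstract interface after this hunk plausibly reads as below. The diff only shows the tail of the file, so the imports, the ABC base, and the full insert() parameter list (taken here from SimpleKVLookupBuffer.insert) are assumptions:

    from abc import ABC, abstractmethod
    from typing import Optional
    import torch

    class KVLookupBufferBase(ABC):

        @abstractmethod
        def insert(self, input_tokens, roi, key, value, hidden) -> None:
            # parameter list assumed from SimpleKVLookupBuffer.insert
            raise NotImplementedError

        @abstractmethod
        def drop_select(self, input_tokens, roi) -> Optional[torch.Tensor]:
            raise NotImplementedError

        @abstractmethod
        def close(self):
            """
            Close the buffer, release resources.
            """
            raise NotImplementedError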

vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py

Lines changed: 44 additions & 15 deletions

@@ -1,6 +1,7 @@
 
 from vllm.distributed.kv_transfer.kv_lookup_buffer.base import \
     KVLookupBufferBase
+from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
 from typing import Dict, Tuple, List, Optional
 import threading
 import torch
@@ -13,16 +14,24 @@
 
 class SimpleKVLookupBuffer(KVLookupBufferBase):
 
-    def __init__(self, pipe, buffer_size_thresh):
+    def __init__(self, signal_pipe, data_pipe, buffer_size_thresh):
+        """
+        signal_pipe: on CPU -- so that recv() does not block the Python interpreter
+        data_pipe: on GPU
+        """
 
         self.buffer = deque()
 
         self.buffer_size = 0
         self.buffer_size_threshold = buffer_size_thresh
         self.buffer_lock = threading.Lock()
-        self.pipe = pipe
+        self.signal_pipe = signal_pipe
+        self.data_pipe = data_pipe
         self.request_handling_thread = None
 
+        self.normal_signal = torch.tensor([0])
+        self.end_signal = None
+
 
     def _matches(self, tokens_roi_sender, tokens_roi_recver):
 
@@ -57,9 +66,9 @@ def _matches(self, tokens_roi_sender, tokens_roi_recver):
 
     def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor]) -> None:
 
-        assert tensor is not None, "Use self.pipe.send(None) instead"
+        assert tensor is not None, "Use self.data_pipe.send(None) instead"
         self.buffer_size -= tensor.element_size() * tensor.numel()
-        self.pipe.send_tensor(tensor)
+        self.data_pipe.send_tensor(tensor)
 
     def _get_element_size(self, data):
 
@@ -91,14 +100,22 @@ def _add_to_buffer(self, input_tokens, roi, key, value, hidden):
         self.buffer_size += self._get_element_size(data)
         self.buffer.append(buffer_item)
 
+    def _is_end_signal(self, signal):
+        return signal is None
 
     def drop_select_handler(self):
 
         try:
 
             while True:
-                input_tokens = self.pipe.recv_tensor()
-                roi = self.pipe.recv_tensor()
+                signal = self.signal_pipe.recv_tensor()
+                if self._is_end_signal(signal):
+                    logger.info("Received end signal!")
+                    break
+
+                input_tokens = self.data_pipe.recv_tensor()
+
+                roi = self.data_pipe.recv_tensor()
                 tokens_roi_recver = [input_tokens, roi]
 
                 matched_length = 0
@@ -125,10 +142,13 @@ def drop_select_handler(self):
                 else:
                     # no match, just send None
                     for _ in range(5):
-                        self.pipe.send_tensor(None)
+                        self.data_pipe.send_tensor(None)
+
         except RuntimeError as e:
             if 'Connection closed by peer' not in str(e):
                 raise e
+
+        logger.debug("closing drop_select_handler")
 
 
     def drop_select(self, input_tokens, roi):
@@ -142,14 +162,15 @@ def drop_select(self, input_tokens, roi):
         if isinstance(roi, torch.Tensor):
             roi = roi.clone()
 
-        self.pipe.send_tensor(input_tokens)
-        self.pipe.send_tensor(roi)
+        self.signal_pipe.send_tensor(self.normal_signal)
+        self.data_pipe.send_tensor(input_tokens)
+        self.data_pipe.send_tensor(roi)
 
-        input_tokens = self.pipe.recv_tensor()
-        roi = self.pipe.recv_tensor()
-        key = self.pipe.recv_tensor()
-        value = self.pipe.recv_tensor()
-        hidden = self.pipe.recv_tensor()
+        input_tokens = self.data_pipe.recv_tensor()
+        roi = self.data_pipe.recv_tensor()
+        key = self.data_pipe.recv_tensor()
+        value = self.data_pipe.recv_tensor()
+        hidden = self.data_pipe.recv_tensor()
 
         return [input_tokens, roi, key, value, hidden]
 
@@ -173,4 +194,12 @@ def insert(self, input_tokens, roi, key, value, hidden) -> None:
                 target=self.drop_select_handler)
             self.request_handling_thread.start()
 
-
+
+    def close(self):
+
+        if hasattr(self, "request_handling_thread") and self.request_handling_thread is not None:
+            self.request_handling_thread.join()
+
+        else:
+            # TODO: have an explicit close signal and an explicit way to check if it's the requester
+            self.signal_pipe.send_tensor(self.end_signal)
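
The net effect is a small control protocol on the signal pipe: every drop_select() is preceded by normal_signal (torch.tensor([0])), and close() on the requester side sends end_signal (None), which _is_end_signal() detects so drop_select_handler can break out of its loop. A hedged usage sketch, assuming one process only inserts and the other only looks up; the commit's close() distinguishes the two by whether a handler thread was ever started. The function names here are invented for illustration:

    def responder(buffer, tokens, roi, key, value, hidden):
        # insert() lazily spawns drop_select_handler; close() then joins
        # that thread, which exits once end_signal (None) arrives.
        buffer.insert(tokens, roi, key, value, hidden)
        buffer.close()

    def requester(buffer, tokens, roi):
        # drop_select() first sends normal_signal (torch.tensor([0])) on
        # the signal pipe, then tokens and roi on the data pipe; close()
        # sends end_signal (None) so the responder's handler loop breaks.
        out = buffer.drop_select(tokens, roi)
        buffer.close()
        return out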

vllm/distributed/kv_transfer/kv_pipe/base.py

Lines changed: 5 additions & 1 deletion

@@ -10,4 +10,8 @@ def send_tensor(self, tensor):
 
     @abstractmethod
     def recv_tensor(self):
-        raise NotImplementedError
+        raise NotImplementedError
+
+    @abstractmethod
+    def close(self):
+        raise NotImplementedError
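
With this hunk the pipe interface has three abstract methods. Reconstructed in full it plausibly reads as below; the class name comes from the import added in simple_kv_lookup_buffer.py, while the ABC base and the send_tensor() body are assumptions, since the diff only shows the tail of the file:

    from abc import ABC, abstractmethod

    class KVPipeBase(ABC):

        @abstractmethod
        def send_tensor(self, tensor):
            raise NotImplementedError

        @abstractmethod
        def recv_tensor(self):
            raise NotImplementedError

        @abstractmethod
        def close(self):
            raise NotImplementedError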

vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py

Lines changed: 15 additions & 7 deletions

@@ -50,6 +50,7 @@ class TorchDistributedPipe:
     MAX_TENSOR_DIMENSIONS = 14
     METADATA_DTYPE = torch.int64
 
+
     def __init__(
         self,
         group_ranks: List[List[int]],
@@ -73,10 +74,7 @@ def __init__(
         assert self.device_group is not None
         assert self.rank_in_group <= 1
 
-        if torch.cuda.is_available():
-            self.device = torch.device(f"cuda:{local_rank}")
-        else:
-            self.device = torch.device("cpu")
+        self.device = self._select_device(torch_distributed_backend)
 
         self.target_rank_for_send = self.ranks[
             (self.rank_in_group + 1) % self.world_size
@@ -99,6 +97,12 @@ def __init__(
             self.METADATA_LENGTH, dtype=self.METADATA_DTYPE, device=self.device
         )
 
+    def _select_device(self, backend: Union[str, Backend]):
+        if torch.cuda.is_available() and backend == Backend.NCCL:
+            return torch.device(f"cuda:{self.local_rank}")
+        else:
+            return "cpu"
+
     def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor:
         """
         Create the metadata based on the input tensor, and move it to GPU.
@@ -168,11 +172,12 @@ def _recv_metadata(self) -> torch.Tensor:
         race conditions during sending/receiving. Therefore, the metadata
         buffer can be reused
         """
-        torch.distributed.recv(
+        task = torch.distributed.recv(
            self.rcv_metadata_buffer,
            src=self.target_rank_for_recv,
            group=self.device_group,
        )
+
        return self.rcv_metadata_buffer
 
     def _send_impl(self, tensor):
@@ -256,15 +261,16 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
         # print("Remaining size:", self.buffer_size)
         self.buffer_size = self.buffer_size + tensor_size
 
-        # prepare the metadata before sending the tensor.
+
+        #self.send_tensor_wrapper(tensor)
         self.transport_thread.submit(
             self.send_tensor_wrapper,
             tensor,
         )
 
+
     def recv_tensor(self) -> Optional[torch.Tensor]:
         """Receives a tensor from the src rank. Blocking."""
-
         if self.transport_thread is None:
             self.transport_thread = ThreadPoolExecutor(max_workers=1)
 
@@ -276,6 +282,8 @@ def recv_tensor(self) -> Optional[torch.Tensor]:
             logger.error("Encountering exception in KV receiving thread")
             logger.error("%s", e)
 
+        #tensor = self._recv_impl()
+
         if tensor.numel() == 1 and tensor.item() == NONE_INT:
             return None
         else:
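
The device-selection change is the piece that makes the gloo signal pipe work on a GPU machine: the device now follows the backend rather than CUDA availability alone, so a gloo-backed pipe stays on CPU even when CUDA is present. A standalone restatement of the rule (illustrative; the commit's version reads local_rank from self and returns the string "cpu" rather than a torch.device):

    import torch
    from torch.distributed import Backend

    def select_device(backend, local_rank):
        # NCCL with CUDA available -> the local GPU; anything else -> CPU.
        if torch.cuda.is_available() and backend == Backend.NCCL:
            return torch.device(f"cuda:{local_rank}")
        return torch.device("cpu")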
