Lookup buffer implementation #4

Merged 2 commits on Sep 13, 2024
46 changes: 25 additions & 21 deletions tests/kv_transfer/test_lookup_buffer.py
@@ -7,8 +7,9 @@
 from tqdm import tqdm
 import time

+# TODO: the test depends on a lot of fields in the current implementation. We should have a standard interface instead of direct field access.

-def test_run(my_rank, buffer):
+def test_run(my_rank, buffer, device):

     # buffer should be empty in the beginning
     if my_rank == 0:
@@ -17,46 +18,49 @@ def test_run(my_rank, buffer):


     # insert
-    tokens = torch.tensor([1,2,3]).to(buffer.pipe.device)
+    tokens = torch.tensor([1,2,3]).to(device)
     roi = (tokens > 0)
     if my_rank == 0:
-        key = 2.0 * torch.ones([5, 6]).to(buffer.pipe.device)
-        value = 3.0 * torch.ones([5, 6]).to(buffer.pipe.device)
+        key = 2.0 * torch.ones([5, 6]).to(device)
+        value = 3.0 * torch.ones([5, 6]).to(device)

-        placeholder = torch.tensor([1]).to(buffer.pipe.device)
+        placeholder = torch.tensor([1]).to(device)

         buffer.insert(tokens, roi, key, value, placeholder)

+    #for i in range(2000):
+    #    print("Here:", i)
+    #    time.sleep(0.01)
     torch.distributed.barrier()

     # drop_select
     if my_rank == 1:
         tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi)
         assert torch.allclose(tokens, tok)
         assert torch.allclose(roi, roi_)
-        assert torch.allclose(key, 2.0 * torch.ones([5, 6]))
-        assert torch.allclose(value, 3.0 * torch.ones([5, 6]))
+        assert torch.allclose(key, 2.0 * torch.ones([5, 6], device = device))
+        assert torch.allclose(value, 3.0 * torch.ones([5, 6], device = device))
     torch.distributed.barrier()

     if my_rank == 0:
         assert buffer.buffer_size == 0
         assert len(buffer.buffer) == 0

     print("Test run passed!")


-def stress_test(my_rank, buf):
+def stress_test(my_rank, buf, device):

     torch.distributed.barrier()
     torch.manual_seed(100)

-    device = buf.pipe.device
-
     reqs = [
         (
             torch.rand(100).to(device), # tokens
             torch.ones(100).bool().to(device), # roi
             torch.rand(100).to(device), # key
             torch.rand(100).to(device), # value
             torch.rand(100).to(device), # hidden
-        ) for i in range(200)]
+        ) for i in tqdm(range(200))]

@@ -86,7 +90,7 @@ def stress_test(my_rank, buf):
     assert torch.allclose(k, k_)
     assert torch.allclose(v, v_)
     assert torch.allclose(h, h_)
-    print('Rand %d done' % my_rank)
+    print('Rank %d done' % my_rank)
     torch.distributed.barrier()

@@ -101,13 +105,9 @@ def stress_test(my_rank, buf):
     else:
         torch.distributed.send(torch.tensor([n]), 0)

     print("Passed stress test!")
-
-
-
-


 if __name__ == "__main__":

@@ -123,10 +123,14 @@ def stress_test(my_rank, buf):


     pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "nccl")
-    buffer = sklb.SimpleKVLookupBuffer(pipe, 170000)
+    cpu_pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "gloo")
+    buffer = sklb.SimpleKVLookupBuffer(cpu_pipe, pipe, 170000)

-    test_run(my_rank, buffer)
+    test_run(my_rank, buffer, pipe.device)

-    stress_test(my_rank, buffer)
+    stress_test(my_rank, buffer, pipe.device)

+    buffer.close()
+    pipe.close()
+    cpu_pipe.close()
     print('Done')
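The buffer threshold passed above (170000) is in bytes: each stored tensor contributes element_size() * numel(), the same accounting used by _send_tensor_and_dec_size in the buffer implementation below. A minimal sketch of that arithmetic for the test's tensors (dtypes are the PyTorch defaults):

import torch

def nbytes(t: torch.Tensor) -> int:
    # Mirrors the buffer's size accounting: bytes per element * element count.
    return t.element_size() * t.numel()

key = 2.0 * torch.ones([5, 6])    # float32: 4 bytes * 30 elements = 120
tokens = torch.tensor([1, 2, 3])  # int64: 8 bytes * 3 elements = 24

print(nbytes(key), nbytes(tokens))  # 120 24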
22 changes: 19 additions & 3 deletions tests/kv_transfer/test_send_recv.py
@@ -12,12 +12,28 @@ def test_run(my_rank, pipe):
     x = torch.tensor([1]).to(pipe.device)
     y = torch.tensor([[2., 3., 4., 8.]]).to(pipe.device)
     if my_rank == 0:
+
         pipe.send_tensor(x)
+        print("sent tensor x")
         pipe.send_tensor(y)
+        print("sent tensor y")
+        x2 = pipe.recv_tensor()
+        print("received x2 = ", x2)
+        y2 = pipe.recv_tensor()
+        print("received y2 = ", y2)
+
     else:
-        assert torch.allclose(x, pipe.recv_tensor())
-        assert torch.allclose(y, pipe.recv_tensor())
+        x2 = pipe.recv_tensor()
+        print("received x2 = ", x2)
+        y2 = pipe.recv_tensor()
+        print("received y2 = ", y2)
+        pipe.send_tensor(x)
+        print("sent tensor x")
+        pipe.send_tensor(y)
+        print("sent tensor y")

+    assert torch.allclose(x, x2)
+    assert torch.allclose(y, y2)


 def stress_test(my_rank, pipe):
8 changes: 7 additions & 1 deletion vllm/distributed/kv_transfer/kv_lookup_buffer/base.py
@@ -15,4 +15,10 @@ def insert(self,
     @abstractmethod
     def drop_select(self, input_tokens, roi) -> Optional[torch.Tensor]:
         raise NotImplementedError
-
+
+    @abstractmethod
+    def close(self):
+        """
+        Close the buffer, release resources.
+        """
+        raise NotImplementedError
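For orientation, the abstract interface after this change reads roughly as follows. This is a sketch assembled from the hunks in this PR, not the full file: the insert signature is taken from SimpleKVLookupBuffer below, and deriving from ABC is an assumption suggested by the abstractmethod decorators.

from abc import ABC, abstractmethod
from typing import Optional

import torch

class KVLookupBufferBase(ABC):

    @abstractmethod
    def insert(self, input_tokens, roi, key, value, hidden) -> None:
        raise NotImplementedError

    @abstractmethod
    def drop_select(self, input_tokens, roi) -> Optional[torch.Tensor]:
        raise NotImplementedError

    @abstractmethod
    def close(self):
        """Close the buffer, release resources."""
        raise NotImplementedError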
@@ -1,6 +1,7 @@

 from vllm.distributed.kv_transfer.kv_lookup_buffer.base import \
     KVLookupBufferBase
+from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
 from typing import Dict, Tuple, List, Optional
 import threading
 import torch
@@ -13,16 +14,24 @@

 class SimpleKVLookupBuffer(KVLookupBufferBase):

-    def __init__(self, pipe, buffer_size_thresh):
+    def __init__(self, signal_pipe, data_pipe, buffer_size_thresh):
+        """
+        signal_pipe: on CPU -- so that a blocking recv() does not stall the Python interpreter
+        data_pipe: on GPU
+        """

         self.buffer = deque()

         self.buffer_size = 0
         self.buffer_size_threshold = buffer_size_thresh
         self.buffer_lock = threading.Lock()
-        self.pipe = pipe
+        self.signal_pipe = signal_pipe
+        self.data_pipe = data_pipe
         self.request_handling_thread = None

+        self.normal_signal = torch.tensor([0])
+        self.end_signal = None
+

     def _matches(self, tokens_roi_sender, tokens_roi_recver):

@@ -57,9 +66,9 @@ def _matches(self, tokens_roi_sender, tokens_roi_recver):

     def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor]) -> None:

-        assert tensor is not None, "Use self.pipe.send(None) instead"
+        assert tensor is not None, "Use self.data_pipe.send(None) instead"
         self.buffer_size -= tensor.element_size() * tensor.numel()
-        self.pipe.send_tensor(tensor)
+        self.data_pipe.send_tensor(tensor)

     def _get_element_size(self, data):

@@ -91,14 +100,22 @@ def _add_to_buffer(self, input_tokens, roi, key, value, hidden):
         self.buffer_size += self._get_element_size(data)
         self.buffer.append(buffer_item)

+    def _is_end_signal(self, signal):
+        return signal is None
+
     def drop_select_handler(self):

         try:

             while True:
-                input_tokens = self.pipe.recv_tensor()
-                roi = self.pipe.recv_tensor()
+                signal = self.signal_pipe.recv_tensor()
+                if self._is_end_signal(signal):
+                    logger.info("Received end signal!")
+                    break
+
+                input_tokens = self.data_pipe.recv_tensor()
+
+                roi = self.data_pipe.recv_tensor()
                 tokens_roi_recver = [input_tokens, roi]

                 matched_length = 0
@@ -125,10 +142,13 @@ def drop_select_handler(self):
                 else:
                     # no match, just send None
                     for _ in range(5):
-                        self.pipe.send_tensor(None)
+                        self.data_pipe.send_tensor(None)

         except RuntimeError as e:
             if 'Connection closed by peer' not in str(e):
                 raise e

+        logger.debug("closing drop_select_handler")
+

     def drop_select(self, input_tokens, roi):
@@ -142,14 +162,15 @@ def drop_select(self, input_tokens, roi):
         if isinstance(roi, torch.Tensor):
             roi = roi.clone()

-        self.pipe.send_tensor(input_tokens)
-        self.pipe.send_tensor(roi)
+        self.signal_pipe.send_tensor(self.normal_signal)
+        self.data_pipe.send_tensor(input_tokens)
+        self.data_pipe.send_tensor(roi)

-        input_tokens = self.pipe.recv_tensor()
-        roi = self.pipe.recv_tensor()
-        key = self.pipe.recv_tensor()
-        value = self.pipe.recv_tensor()
-        hidden = self.pipe.recv_tensor()
+        input_tokens = self.data_pipe.recv_tensor()
+        roi = self.data_pipe.recv_tensor()
+        key = self.data_pipe.recv_tensor()
+        value = self.data_pipe.recv_tensor()
+        hidden = self.data_pipe.recv_tensor()

         return [input_tokens, roi, key, value, hidden]

@@ -173,4 +194,12 @@ def insert(self, input_tokens, roi, key, value, hidden) -> None:
                 target=self.drop_select_handler)
             self.request_handling_thread.start()
-
-
+
+
+    def close(self):
+
+        if hasattr(self, "request_handling_thread") and self.request_handling_thread is not None:
+            self.request_handling_thread.join()
+
+        else:
+            # TODO: have an explicit close signal and an explicit way to check if it's the requester
+            self.signal_pipe.send_tensor(self.end_signal)
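The close() logic above encodes a small shutdown handshake: the side that never called insert() (and therefore never started the handler thread) sends the end signal, while the inserter side joins its handler thread, which exits once that signal arrives. A condensed sketch of the protocol, assuming pipes with the send_tensor/recv_tensor interface from this PR:

def handler_loop(signal_pipe, data_pipe):
    # Inserter side: wait on the cheap CPU signal pipe first, so a blocked
    # recv never occupies the GPU data path. None serves as the end signal.
    while True:
        if signal_pipe.recv_tensor() is None:
            break  # peer closed; exit the loop so close() can join this thread
        tokens = data_pipe.recv_tensor()
        roi = data_pipe.recv_tensor()
        # ... match against the buffer and reply on data_pipe ...

def requester_close(signal_pipe):
    # Requester side: no handler thread was started locally, so unblock
    # the peer's loop by sending the end signal.
    signal_pipe.send_tensor(None)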
6 changes: 5 additions & 1 deletion vllm/distributed/kv_transfer/kv_pipe/base.py
@@ -10,4 +10,8 @@ def send_tensor(self, tensor):

     @abstractmethod
     def recv_tensor(self):
-        raise NotImplementedError
+        raise NotImplementedError
+
+    @abstractmethod
+    def close(self):
+        raise NotImplementedError
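Assembled the same way, the pipe interface after this change would read roughly as follows; again the ABC base is an assumption suggested by the abstractmethod decorators, not shown in the diff.

from abc import ABC, abstractmethod

class KVPipeBase(ABC):

    @abstractmethod
    def send_tensor(self, tensor):
        raise NotImplementedError

    @abstractmethod
    def recv_tensor(self):
        raise NotImplementedError

    @abstractmethod
    def close(self):
        raise NotImplementedError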
22 changes: 15 additions & 7 deletions vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py
@@ -50,6 +50,7 @@ class TorchDistributedPipe:
     MAX_TENSOR_DIMENSIONS = 14
     METADATA_DTYPE = torch.int64

+
     def __init__(
         self,
         group_ranks: List[List[int]],
@@ -73,10 +74,7 @@ def __init__(
         assert self.device_group is not None
         assert self.rank_in_group <= 1

-        if torch.cuda.is_available():
-            self.device = torch.device(f"cuda:{local_rank}")
-        else:
-            self.device = torch.device("cpu")
+        self.device = self._select_device(torch_distributed_backend)

         self.target_rank_for_send = self.ranks[
             (self.rank_in_group + 1) % self.world_size
@@ -99,6 +97,12 @@ def __init__(
             self.METADATA_LENGTH, dtype=self.METADATA_DTYPE, device=self.device
         )

+    def _select_device(self, backend: Union[str, Backend]):
+        if torch.cuda.is_available() and backend == Backend.NCCL:
+            return torch.device(f"cuda:{self.local_rank}")
+        else:
+            return "cpu"
+
     def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor:
         """
         Create the metadata based on the input tensor, and move it to GPU.
@@ -168,11 +172,12 @@ def _recv_metadata(self) -> torch.Tensor:
         race conditions during sending/receiving. Therefore, the metadata
         buffer can be reused
         """
-        torch.distributed.recv(
+        task = torch.distributed.recv(
             self.rcv_metadata_buffer,
             src=self.target_rank_for_recv,
             group=self.device_group,
         )
+
         return self.rcv_metadata_buffer

     def _send_impl(self, tensor):
@@ -256,15 +261,16 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
             # print("Remaining size:", self.buffer_size)
             self.buffer_size = self.buffer_size + tensor_size

-        # prepare the metadata before sending the tensor.
-
+        #self.send_tensor_wrapper(tensor)
         self.transport_thread.submit(
             self.send_tensor_wrapper,
             tensor,
         )

+
     def recv_tensor(self) -> Optional[torch.Tensor]:
         """Receives a tensor from the src rank. Blocking."""
+
         if self.transport_thread is None:
             self.transport_thread = ThreadPoolExecutor(max_workers=1)

@@ -276,6 +282,8 @@ def recv_tensor(self) -> Optional[torch.Tensor]:
             logger.error("Encountering exception in KV receiving thread")
             logger.error("%s", e)

+        #tensor = self._recv_impl()
+
         if tensor.numel() == 1 and tensor.item() == NONE_INT:
             return None
         else:
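_select_device ties the pipe's device to its backend: only a NCCL-backed pipe lands on the local GPU, which is why the gloo signal pipe built in the test stays on CPU. A standalone sketch of the rule follows; note that Backend subclasses str, so the comparison also accepts a plain "nccl" string, and this version normalizes the CPU branch to a torch.device, unlike the diff, which returns the bare string "cpu".

import torch
from torch.distributed import Backend

def select_device(backend, local_rank: int) -> torch.device:
    # NCCL transports require a CUDA device; anything else (e.g. gloo)
    # stays on CPU so a blocked recv never ties up a CUDA context.
    if torch.cuda.is_available() and backend == Backend.NCCL:
        return torch.device(f"cuda:{local_rank}")
    return torch.device("cpu")

print(select_device("gloo", 0))  # cpu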