Kuntai disagg refactor #6

Merged
merged 6 commits into from Sep 14, 2024
2 changes: 1 addition & 1 deletion tests/kv_transfer/test_lookup_buffer.py
@@ -70,7 +70,7 @@ def stress_test(my_rank, buf, device):
n = 0

# the buffer size can only store 100 reqs
# so the sender will occasionally block.needs to wait for the receiver.
# so the sender will occasionally block to wait for the receiver.
for req in tqdm(reqs):
if my_rank == 0:
buf.insert(*req)
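For context, the comment above describes back-pressure from a bounded lookup buffer: once roughly 100 requests are outstanding, the sender's insert blocks until the receiver drains an entry. A minimal, hedged sketch of that pattern follows; the `BoundedBuffer` class and its capacity are illustrative assumptions, not the actual `SimpleKVLookupBuffer`.

import queue

# Illustrative stand-in for the bounded lookup buffer used by the stress test:
# insert() blocks while the buffer is full and resumes once the receiver
# drains an entry, which is the behavior the comment describes.
class BoundedBuffer:
    def __init__(self, capacity: int = 100):
        self._q = queue.Queue(maxsize=capacity)

    def insert(self, item) -> None:
        self._q.put(item)      # blocks when `capacity` items are outstanding

    def drop_select(self):
        return self._q.get()   # consuming an item unblocks a waiting sender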
@@ -50,15 +50,14 @@ def _matches(self, tokens_roi_sender, tokens_roi_recver):
return True


# I am assuming that roi is a mask on tokens
# Assuming that roi is a mask on tokens
tokens_sender = tokens_sender[roi_sender]
tokens_recver = tokens_recver[roi_recver]


# simple common prefix matching
min_length = min(len(tokens_sender), len(tokens_recver))
if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]):
# drastically simplified
# common prefix matching
return min_length

return 0
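Taken together, the matching logic above amounts to the following standalone sketch, assuming (as the comment says) that `roi` is a boolean mask over the tokens; the function name is illustrative.

import torch

def common_prefix_length(tokens_sender, roi_sender, tokens_recver, roi_recver):
    # Keep only the tokens selected by each side's region-of-interest mask.
    tokens_sender = tokens_sender[roi_sender]
    tokens_recver = tokens_recver[roi_recver]

    # Simple common-prefix matching: if the shorter sequence is a prefix of
    # the longer one, report the matched length; otherwise report no match.
    min_length = min(len(tokens_sender), len(tokens_recver))
    if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]):
        return min_length
    return 0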
@@ -148,7 +147,7 @@ def drop_select_handler(self):
if 'Connection closed by peer' not in str(e):
raise e

logger.debug("closing drop_select_handler")
logger.debug("Closing drop_select_handler")


def drop_select(self, input_tokens, roi):
@@ -182,7 +181,7 @@ def full_handler(self):
def insert(self, input_tokens, roi, key, value, hidden) -> None:

while self.buffer_size > self.buffer_size_threshold:
logger.debug("KV transfer buffer is full. Handling...")
# logger.debug("KV transfer buffer is full. Handling...")
self.full_handler()

self._add_to_buffer(input_tokens, roi, key, value, hidden)
@@ -10,9 +10,11 @@
logger = init_logger(__name__)


# if the tensor is only one-element and only contains this number
# if the tensor is only one-element and only contains NONE_INT
# this means that the sent object is None.
NONE_INT = -150886311

# Mapping tensor dtype to an int, used for tensor metadata transmission
FLOAT16_INT = -543205003776624
INT64_INT = -375623078607432
BOOL_INT = -28035262008646
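To illustrate how these sentinels could be used when exchanging tensor metadata, here is a hedged sketch; the helper names (`dtype_to_int`, `int_to_dtype`, `is_none_marker`) are hypothetical and not part of this diff.

import torch

# Hypothetical helpers built on the constants above.
_DTYPE_TO_INT = {
    torch.float16: FLOAT16_INT,
    torch.int64: INT64_INT,
    torch.bool: BOOL_INT,
}
_INT_TO_DTYPE = {v: k for k, v in _DTYPE_TO_INT.items()}

def dtype_to_int(dtype: torch.dtype) -> int:
    # Encode a dtype as an integer for transmission in a metadata tensor.
    return _DTYPE_TO_INT[dtype]

def int_to_dtype(code: int) -> torch.dtype:
    # Decode the integer back into a dtype on the receiving side.
    return _INT_TO_DTYPE[code]

def is_none_marker(t: torch.Tensor) -> bool:
    # A one-element tensor holding NONE_INT stands for a transmitted None.
    return t.numel() == 1 and t.item() == NONE_INT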
@@ -258,11 +260,9 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
self.block_if_full()

with self.buffer_size_lock:
# print("Remaining size:", self.buffer_size)
self.buffer_size = self.buffer_size + tensor_size


#self.send_tensor_wrapper(tensor)
self.transport_thread.submit(
self.send_tensor_wrapper,
tensor,
27 changes: 21 additions & 6 deletions vllm/distributed/kv_transfer/vllm_adapter.py
@@ -1,14 +1,14 @@
"""vLLM distributed KV cache transfer API.
These APIs are used in `vllm/worker/model_runner.py`.

Currently supporting TP and PP.
Currently supporting TP and PP, but the TP and PP configuration must be the same between the prefill and decode instances.

Workflow:
- In prefill instance, KV cache sender *buffers* the KV cache send requests
- In prefill instance, vLLM calls `insert`, which buffers the KV cache into the lookup buffer.
- In decode instance
- KV cache receiver sends the hash of input tokens to sender
- KV cache sender executes send request
- KV cache receiver receives the KV cache
- vLLM first runs `drop_select` to send the input tokens and a mask on the input tokens to the sender
- The prefill instance sends back the matching KV caches
- vLLM then stores the KV cache into paged memory.
"""
from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from collections import defaultdict, deque
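A hedged sketch of the interaction described in the docstring, using the `insert`/`drop_select` lookup-buffer API that appears elsewhere in this PR; the surrounding glue, the assumed return shape of `drop_select`, and the `write_kv_to_paged_cache` helper are illustrative only.

# Prefill instance: after running the model, buffer the computed KV cache so
# the decode instance can later fetch it by content.
buffer.insert(input_tokens, roi, key_cache, value_cache, hidden_states)

# Decode instance: send the input tokens plus a mask over them; the prefill
# instance replies with the matching KV cache (or nothing on a miss).
result = buffer.drop_select(input_tokens, roi)
if result is not None:
    _, _, key_cache, value_cache, hidden_states = result  # assumed return shape
    write_kv_to_paged_cache(key_cache, value_cache)        # hypothetical helper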
@@ -68,6 +68,19 @@ def __init__(
torch_distributed_backend: Union[str, Backend],
):


# init pipes: one on the device backend and one on the CPU (gloo) backend
self.device_pipe = TorchDistributedPipe(
group_ranks,
local_rank,
torch_distributed_backend,
)
self.cpu_pipe = TorchDistributedPipe(
group_ranks,
local_rank,
"gloo"
)

# init two pipes: one for send and one for recv
if IS_KV_PREFILL_INSTANCE or IS_LMCACHE_INSTANCE:
self.send_pipe = TorchDistributedPipe(
@@ -95,8 +108,10 @@ def __init__(

# FIXME(Jiayi): buffer initialization should be adapted accordingly
# Signal pipe needs to be initialized on both vllm and lmc side

# init lookup buffer
self.buffer = SimpleKVLookupBuffer(self.pipe, 1000**3 * 10)
# TODO: replace this 1e9 with a configurable parameter or a constant
self.buffer = SimpleKVLookupBuffer(self.cpu_pipe, self.device_pipe, 1e9 * 10)
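One way to address the TODO above, sketched under the assumption that a plain module-level constant is acceptable; the constant name is hypothetical, and the 10 GB budget is taken from the current magic number, not from this PR.

# Hypothetical named constant for the lookup-buffer threshold (10 GB, matching
# the current 1e9 * 10); a configurable parameter could replace it later.
KV_LOOKUP_BUFFER_MAX_BYTES = int(1e9 * 10)

self.buffer = SimpleKVLookupBuffer(self.cpu_pipe, self.device_pipe,
                                   KV_LOOKUP_BUFFER_MAX_BYTES)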

def send_kv_caches_and_hidden_states(
self,