diff --git a/.buildkite/features/KV_Cache_Host_Offloading.yml b/.buildkite/features/KV_Cache_Host_Offloading.yml deleted file mode 100644 index 020392441..000000000 --- a/.buildkite/features/KV_Cache_Host_Offloading.yml +++ /dev/null @@ -1,45 +0,0 @@ -# KV cache host offloading -# feature support matrix -steps: - - label: "Correctness tests for KV cache host offloading" - key: "KV_Cache_Host_Offloading_CorrectnessTest" - soft_fail: true - agents: - queue: tpu_v6e_queue - commands: - - | - buildkite-agent meta-data set "KV_Cache_Host_Offloading_CorrectnessTest" "to be added" - - label: "Record correctness test result for KV cache host offloading" - key: "record_KV_Cache_Host_Offloading_CorrectnessTest" - depends_on: "KV_Cache_Host_Offloading_CorrectnessTest" - env: - CI_TARGET: "KV cache host offloading" - CI_STAGE: "CorrectnessTest" - CI_CATEGORY: "feature support matrix" - agents: - queue: cpu - commands: - - | - .buildkite/scripts/record_step_result.sh KV_Cache_Host_Offloading_CorrectnessTest - - - label: "Performance tests for KV cache host offloading" - key: "KV_Cache_Host_Offloading_PerformanceTest" - depends_on: "record_KV_Cache_Host_Offloading_CorrectnessTest" - soft_fail: true - agents: - queue: tpu_v6e_queue - commands: - - | - buildkite-agent meta-data set "KV_Cache_Host_Offloading_PerformanceTest" "to be added" - - label: "Record performance test result for KV cache host offloading" - key: "record_KV_Cache_Host_Offloading_PerformanceTest" - depends_on: "KV_Cache_Host_Offloading_PerformanceTest" - env: - CI_TARGET: "KV cache host offloading" - CI_STAGE: "PerformanceTest" - CI_CATEGORY: "feature support matrix" - agents: - queue: cpu - commands: - - | - .buildkite/scripts/record_step_result.sh KV_Cache_Host_Offloading_PerformanceTest diff --git a/.buildkite/features/KV_Cache_Offload.yml b/.buildkite/features/KV_Cache_Offload.yml new file mode 100644 index 000000000..17a88786b --- /dev/null +++ b/.buildkite/features/KV_Cache_Offload.yml @@ -0,0 +1,49 @@ +# KV Cache Offload +# feature support matrix +steps: + - label: "Correctness tests for KV Cache Offload" + key: "KV_Cache_Offload_CorrectnessTest" + soft_fail: true + env: + USE_V6E8_QUEUE: "True" + VLLM_LOG_LEVEL: "INFO" + agents: + queue: tpu_v6e_8_queue + commands: + - | + .buildkite/scripts/run_in_docker.sh \ + python3 -m pytest -s -v /workspace/tpu_inference/tests/offload/tpu_offload_accuracy_test.py + - label: "Record correctness test result for KV Cache Offload" + key: "record_KV_Cache_Offload_CorrectnessTest" + depends_on: "KV_Cache_Offload_CorrectnessTest" + env: + CI_TARGET: "KV Cache Offload" + CI_STAGE: "CorrectnessTest" + CI_CATEGORY: "feature support matrix" + agents: + queue: cpu + commands: + - | + .buildkite/scripts/record_step_result.sh KV_Cache_Offload_CorrectnessTest + + - label: "Performance tests for KV Cache Offload" + key: "KV_Cache_Offload_PerformanceTest" + depends_on: "record_KV_Cache_Offload_CorrectnessTest" + soft_fail: true + agents: + queue: tpu_v6e_queue + commands: + - | + buildkite-agent meta-data set "KV_Cache_Offload_PerformanceTest" "to be added" + - label: "Record performance test result for KV Cache Offload" + key: "record_KV_Cache_Offload_PerformanceTest" + depends_on: "KV_Cache_Offload_PerformanceTest" + env: + CI_TARGET: "KV Cache Offload" + CI_STAGE: "PerformanceTest" + CI_CATEGORY: "feature support matrix" + agents: + queue: cpu + commands: + - | + .buildkite/scripts/record_step_result.sh KV_Cache_Offload_PerformanceTest diff --git a/.buildkite/pipeline_jax.yml 
b/.buildkite/pipeline_jax.yml
index 19c232ae8..2041a6d2f 100644
--- a/.buildkite/pipeline_jax.yml
+++ b/.buildkite/pipeline_jax.yml
@@ -122,6 +122,7 @@ steps:
           --ignore=/workspace/tpu_inference/tests/e2e \
           --ignore=/workspace/tpu_inference/tpu_inference/mock \
           --ignore=/workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_moe.py \
+          --ignore=/workspace/tpu_inference/tests/offload \
           --cov-config=/workspace/tpu_inference/.coveragerc --cov tpu_inference --cov-report term-missing --cov-fail-under=69

   - label: "JAX unit tests - kernels"
@@ -137,6 +138,7 @@ steps:
           --ignore=/workspace/tpu_inference/tests/kernels/ragged_paged_attention_kernel_v2_test.py \
           --ignore=/workspace/tpu_inference/tests/kernels/ragged_kv_cache_update_v2_test.py \
           --ignore=/workspace/tpu_inference/tests/kernels/collectives \
+          --ignore=/workspace/tpu_inference/tests/kernels/host_dma_test.py \
           --ignore=/workspace/tpu_inference/tests/kernels/fused_moe_v1_test.py
         else
           echo "Skipping: no changes detected in kernels, tests/kernels, or requirements.txt"
@@ -255,6 +257,21 @@ steps:
           echo "Skipping: NIGHTLY environment variable not set"
           exit 0
         fi
+
+  - label: "kv cache offload tests on multi chips"
+    key: test_17
+    soft_fail: true
+    env:
+      USE_V6E8_QUEUE: "True"
+      VLLM_LOG_LEVEL: "INFO"
+    agents:
+      queue: tpu_v6e_8_queue
+    commands:
+      - |
+        .buildkite/scripts/run_in_docker.sh \
+          python3 -m pytest -s -v -x /workspace/tpu_inference/tests/offload/ \
+            /workspace/tpu_inference/tests/kernels/host_dma_test.py \
+            --ignore=/workspace/tpu_inference/tests/offload/tpu_offload_accuracy_test.py
   # -----------------------------------------------------------------
   # NOTIFICATION STEP
   # -----------------------------------------------------------------
@@ -277,9 +294,10 @@ steps:
       - test_13
       - test_15
       - test_16
+      - test_17
     agents:
       queue: cpu
     commands:
      - |
        .buildkite/scripts/check_results.sh \
-          "TPU JAX Tests Failed" test_0 test_1 test_2 test_3 test_4 test_5 test_6 test_7 test_8 test_9 test_10 test_11 test_12 test_13 test_15 test_16
+          "TPU JAX Tests Failed" test_0 test_1 test_2 test_3 test_4 test_5 test_6 test_7 test_8 test_9 test_10 test_11 test_12 test_13 test_15 test_16 test_17
diff --git a/examples/offload/gke/benchmarks/README.md b/examples/offload/gke/benchmarks/README.md
new file mode 100644
index 000000000..9d1136637
--- /dev/null
+++ b/examples/offload/gke/benchmarks/README.md
@@ -0,0 +1,117 @@
+# Benchmarks using the SGLang bench_serving tool
+
+This guide outlines the steps to deploy a vLLM serving instance on Google Kubernetes Engine (GKE) with TPUs, create a service to expose it, and then run the SGLang `bench_serving.py` benchmark against it. Two deployment options for vLLM are provided: a baseline without host offload, and one with TPU host offload for the KV cache.
+
+## Prerequisites
+
+* `kubectl` configured to connect to your GKE cluster.
+* `gcloud` CLI installed and authenticated.
+* A GKE cluster with TPU nodes (the steps below have been verified on a `ct6e-standard-8t` GKE node).
+* Access to the Llama-3.3-70B model on Hugging Face.
+
+## 1. Create Hugging Face Token Secret
+
+A Hugging Face token is required to pull the model. Create a Kubernetes secret with your token:
+
+```bash
+kubectl create secret generic hf-token-secret --from-literal=token='<YOUR_HF_TOKEN>'
+```
+
+Replace `<YOUR_HF_TOKEN>` with your actual Hugging Face token.
+
+## 2. Deploy vLLM Pod (Choose One)
+
+Choose one of the following deployment options for your vLLM pod. Ensure the correct container image is set in the pod spec.
+
+### Option A: Baseline vLLM (No Host Offload)
+
+This deployment uses a standard vLLM setup without any TPU host offload connector. The KV cache resides entirely in TPU HBM.
+
+```bash
+kubectl apply -f deploy-baseline.yaml
+```
+
+### Option B: vLLM with TPU Host Offload
+
+This deployment configures vLLM to use a `TPUOffloadConnector` to offload the KV cache to host CPU memory. This is specified by the `--kv-transfer-config` argument.
+
+```bash
+kubectl apply -f deploy-cpu-offload.yaml
+```
+
+## 3. Deploy Service
+
+Deploy a LoadBalancer service to expose your vLLM deployment. This will provide an external IP address to send benchmark requests to.
+
+```bash
+kubectl apply -f service.yaml
+```
+
+After deployment, get the external IP of the service:
+
+```bash
+kubectl get service tpu-offline-inference -o jsonpath='{.status.loadBalancer.ingress[0].ip}'
+```
+
+This command outputs the external IP address directly. It might take a few minutes for the IP to be provisioned.
+
+## 4. Run Benchmark
+
+Instead of installing SGLang locally, we run the benchmark from within the Kubernetes cluster using a dedicated pod. This approach avoids local dependency management and ensures the benchmark runs in a consistent environment.
+
+### a. Configure the Benchmark Pod
+
+A sample pod specification is provided in `benchmark-pod.yaml`. Before deploying it, you need to configure the environment variables within the file, especially the `IP` of the vLLM service.
+
+Open `benchmark-pod.yaml` and replace `<EXTERNAL_IP>` with the actual external IP address of your `tpu-offline-inference` service obtained in step 3.
+
+You can also adjust the following benchmark parameters via environment variables in the `benchmark-pod.yaml` file:
+
+* `GSP_NUM_GROUPS`: The number of unique system prompts.
+* `GSP_PROMPTS_PER_GROUP`: The number of questions per system prompt.
+* `GSP_SYSTEM_PROMPT_LEN`: The token length of the system prompt.
+* `GSP_QUESTION_LEN`: The token length of the question.
+* `GSP_OUTPUT_LEN`: The desired output token length.
+* `MODEL`: The model to benchmark.
+
+### b. Deploy the Benchmark Pod
+
+Once configured, deploy the benchmark pod:
+
+```bash
+kubectl apply -f benchmark-pod.yaml
+```
+
+The pod will start, clone the SGLang repository, install dependencies, and run the benchmark.
+
+### c. Monitor the Benchmark
+
+You can monitor the progress of the benchmark by checking the logs of the pod:
+
+```bash
+kubectl logs -f sglang-benchmark
+```
+
+The pod is configured with `restartPolicy: Never`, so it will run the benchmark once and then complete.
+
+## 5. Understanding the `generated-shared-prefix` Dataset
+
+The `generated-shared-prefix` dataset is designed to benchmark serving performance for workloads where multiple requests share a common, long prefix. This is common in applications using system prompts or few-shot examples.
+
+**How it works:**
+
+1. **System Prompt Generation:** A specified number of unique "system prompts" are generated. Each is a long sequence of random tokens.
+2. **Question Generation:** Shorter "questions" (random tokens) are generated.
+3. **Prompt Combination:** Each system prompt is combined with multiple unique questions to form the final prompts. This creates groups of prompts where every prompt in a group shares the exact same system prompt as a prefix.
+4. **Request Creation:** Each final prompt is packaged with its desired output length.
+5.
**Shuffling:** The entire set of generated requests is randomly shuffled. This interleaves requests from different groups, simulating realistic traffic where shared prefixes are not necessarily processed sequentially. +6. **Caching:** The generated dataset is cached locally for faster subsequent runs with the same parameters. + +**Key Parameters for `generated-shared-prefix`:** + +* `--gsp-num-groups`: The number of unique system prompts to generate. Each system prompt forms a "group" of requests. +* `--gsp-prompts-per-group`: The number of unique questions that will be appended to each system prompt. This determines how many requests will share a given system prompt. +* `--gsp-system-prompt-len`: The length (in tokens) of each generated system prompt. +* `--gsp-question-len`: The length (in tokens) of each generated question. +* `--gsp-output-len`: The desired length (in tokens) of the generated output for each request. +* `--seed`: (Optional) An integer seed for random number generation, ensuring reproducible prompt generation and request shuffling across runs. diff --git a/examples/offload/gke/benchmarks/benchmark-pod.yaml b/examples/offload/gke/benchmarks/benchmark-pod.yaml new file mode 100644 index 000000000..05e2da502 --- /dev/null +++ b/examples/offload/gke/benchmarks/benchmark-pod.yaml @@ -0,0 +1,55 @@ +apiVersion: v1 +kind: Pod +metadata: + name: sglang-benchmark +spec: + containers: + - name: sglang-benchmark-container + image: python:3.9-slim + command: ["/bin/bash", "-c"] + args: + - | + set -ex + apt-get update && apt-get install -y git + git clone -b v0.5.2 https://github.com/sgl-project/sglang.git + cd sglang + pip install --upgrade pip + pip install protobuf aiohttp numpy requests tqdm transformers + python3 python/sglang/bench_serving.py \ + --host=$(IP) \ + --port=$(PORT) \ + --dataset-name='generated-shared-prefix' \ + --model=$(MODEL) \ + --tokenizer=$(MODEL) \ + --backend=vllm \ + --gsp-num-groups=$(GSP_NUM_GROUPS) \ + --gsp-prompts-per-group=$(GSP_PROMPTS_PER_GROUP) \ + --gsp-system-prompt-len=$(GSP_SYSTEM_PROMPT_LEN) \ + --gsp-question-len=$(GSP_QUESTION_LEN) \ + --gsp-output-len=$(GSP_OUTPUT_LEN) \ + --request-rate=800 \ + --max-concurrency=300 \ + --seed 42 + env: + - name: IP + value: "34.162.66.198" # Replace with the external IP of your deployed service + - name: PORT + value: "80" + - name: MODEL + value: "meta-llama/Llama-3.3-70B-Instruct" + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + - name: GSP_NUM_GROUPS + value: "2" + - name: GSP_PROMPTS_PER_GROUP + value: "16" + - name: GSP_SYSTEM_PROMPT_LEN + value: "2048" + - name: GSP_QUESTION_LEN + value: "256" + - name: GSP_OUTPUT_LEN + value: "512" + restartPolicy: Never diff --git a/examples/offload/gke/benchmarks/deploy-baseline.yaml b/examples/offload/gke/benchmarks/deploy-baseline.yaml new file mode 100644 index 000000000..fdaab3147 --- /dev/null +++ b/examples/offload/gke/benchmarks/deploy-baseline.yaml @@ -0,0 +1,39 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: tpu-offline-inference +spec: + replicas: 1 + selector: + matchLabels: + app: tpu-offline-inference + template: + metadata: + labels: + app: tpu-offline-inference + spec: + nodeSelector: + cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice + cloud.google.com/gke-tpu-topology: 2x4 # Specify the physical topology for the TPU slice. 
+ containers: + - name: tpu-job + image: + imagePullPolicy: Always + command: ["/bin/sh", "-c"] + args: + - "vllm serve meta-llama/Llama-3.3-70B-Instruct --port 8000 --enable-chunked-prefill --tensor-parallel-size 8 --seed 42 --enable_prefix_caching --gpu-memory-utilization 0.9" + env: + - name: HUGGING_FACE_HUB_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + - name: SKIP_JAX_PRECOMPILE + value: "0" + ports: + - containerPort: 8000 + resources: + requests: + google.com/tpu: 8 + limits: + google.com/tpu: 8 diff --git a/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml b/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml new file mode 100644 index 000000000..0996a4122 --- /dev/null +++ b/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml @@ -0,0 +1,70 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: tpu-offline-inference +spec: + replicas: 1 + selector: + matchLabels: + app: tpu-offline-inference + template: + metadata: + labels: + app: tpu-offline-inference + spec: + nodeSelector: + cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice + cloud.google.com/gke-tpu-topology: 2x4 # Specify the physical topology for the TPU slice. + initContainers: + - name: tpu-node-setup + image: busybox + command: ["/bin/sh", "-c"] + args: + - | + # WARNING: This changes the HOST memory settings, not just the container. + # Required to prevent vLLM crashes due to memory mapping limits. + sysctl -w vm.max_map_count=8388608 + + # Check if the VFIO IOMMU module parameter exists, and if so, increase the + # limit on DMA mappings. This allows the TPU driver to pin and map a + # larger number of memory pages for direct hardware access. + if [ -f /sys/module/vfio_iommu_type1/parameters/dma_entry_limit ]; then + echo 2000000 > /sys/module/vfio_iommu_type1/parameters/dma_entry_limit + echo "Successfully increased dma_entry_limit to 2000000" + else + echo "Warning: vfio_iommu_type1 module parameter not found. Ensure the module is loaded." 
+ fi + securityContext: + privileged: true + containers: + - name: tpu-job + image: + imagePullPolicy: Always + command: ["/bin/sh", "-c"] + args: + - "vllm serve meta-llama/Llama-3.3-70B-Instruct --kv-transfer-config '{\"kv_connector\":\"TPUOffloadConnector\",\"kv_role\":\"kv_both\",\"kv_connector_module_path\":\"tpu_inference.offload.tpu_offload_connector\"}' --port 8000 --enable-chunked-prefill --tensor-parallel-size 8 --seed 42 --enable_prefix_caching --gpu-memory-utilization 0.9" + env: + - name: HUGGING_FACE_HUB_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + - name: SKIP_JAX_PRECOMPILE + value: "0" + - name: TPU_OFFLOAD_NUM_CPU_CHUNKS + value: "4096" + - name: TPU_OFFLOAD_NUM_STAGING_BLOCKS + value: "256" + # config the pre-mapped CPU buffer for TPUs + # https://docs.cloud.google.com/tpu/docs/performance-guide#tpu_model_performance + - name: TPU_PREMAPPED_BUFFER_SIZE + value: "68719476736" # 64 GB + - name: TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES + value: "68719476736" # 64 GB + ports: + - containerPort: 8000 + resources: + requests: + google.com/tpu: 8 + limits: + google.com/tpu: 8 diff --git a/examples/offload/gke/benchmarks/service.yaml b/examples/offload/gke/benchmarks/service.yaml new file mode 100644 index 000000000..abcc0aad3 --- /dev/null +++ b/examples/offload/gke/benchmarks/service.yaml @@ -0,0 +1,15 @@ +apiVersion: v1 +kind: Service +metadata: + name: tpu-offline-inference + namespace: default +spec: + ports: + - name: http-tpu-offline-inference + port: 80 + protocol: TCP + targetPort: 8000 + selector: + app: tpu-offline-inference + sessionAffinity: None + type: LoadBalancer diff --git a/examples/offload/gke/hf_secret.yaml b/examples/offload/gke/hf_secret.yaml new file mode 100644 index 000000000..12b56de65 --- /dev/null +++ b/examples/offload/gke/hf_secret.yaml @@ -0,0 +1,8 @@ +apiVersion: v1 +kind: Secret +metadata: + name: hf-token-secret + namespace: default +type: Opaque +stringData: + token: diff --git a/examples/offload/gke/pod_tpu_commons_cpu_offload.yaml b/examples/offload/gke/pod_tpu_commons_cpu_offload.yaml new file mode 100644 index 000000000..7b9145953 --- /dev/null +++ b/examples/offload/gke/pod_tpu_commons_cpu_offload.yaml @@ -0,0 +1,32 @@ +apiVersion: v1 +kind: Pod +metadata: + name: tpu-job-offline-inference +spec: + restartPolicy: Never + nodeSelector: + cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice + cloud.google.com/gke-tpu-topology: 2x4 # Specify the physical topology for the TPU slice. 
+ containers: + - name: tpu-job + image: + imagePullPolicy: Always # Uncomment to always pull the latest image for any dev work + command: + - python + - /workspace/tpu_inference/examples/offload/offline_inference_kv_cache.py + - --model=meta-llama/Llama-3.1-8B + - --tensor_parallel_size=8 + - --max_model_len=1024 + - --kv-transfer-config + - '{"kv_connector":"TPUOffloadConnector","kv_connector_module_path":"tpu_inference.offload.tpu_offload_connector","kv_role":"kv_both"}' + env: + - name: HUGGING_FACE_HUB_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + resources: + requests: + google.com/tpu: 8 + limits: + google.com/tpu: 8 diff --git a/examples/offload/gke/pod_tpu_commons_cpu_offload_verification.yaml b/examples/offload/gke/pod_tpu_commons_cpu_offload_verification.yaml new file mode 100644 index 000000000..e9bd9748b --- /dev/null +++ b/examples/offload/gke/pod_tpu_commons_cpu_offload_verification.yaml @@ -0,0 +1,39 @@ +apiVersion: v1 +kind: Pod +metadata: + name: tpu-job-offline-inference + # This pod verifies the correctness of the TPUOffloadConnector implementation. + # It runs a script that internally performs two text generations: + # 1. A baseline run with a standard vLLM engine. + # 2. A test run with the TPUOffloadConnector enabled. + # The pod succeeds only if the outputs from both runs are identical, + # ensuring that the connector does not alter the model's output. +spec: + restartPolicy: Never + nodeSelector: + cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice + cloud.google.com/gke-tpu-topology: 2x4 # Specify the physical topology for the TPU slice. + containers: + - name: tpu-job + image: + imagePullPolicy: Always + command: + - python + - /workspace/tpu_inference/examples/offload/offline_inference_kv_cache_verification.py + - --model=meta-llama/Llama-3.1-8B + - --tensor_parallel_size=8 + - --max_model_len=1024 + - --seed=42 + - --kv-transfer-config + - '{"kv_connector":"TPUOffloadConnector","kv_connector_module_path":"tpu_inference.offload.tpu_offload_connector","kv_role":"kv_both"}' + env: + - name: HUGGING_FACE_HUB_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + resources: + requests: + google.com/tpu: 8 + limits: + google.com/tpu: 8 diff --git a/examples/offload/gke/pod_tpu_host_offload_unit_tests.yaml b/examples/offload/gke/pod_tpu_host_offload_unit_tests.yaml new file mode 100644 index 000000000..7712c7bff --- /dev/null +++ b/examples/offload/gke/pod_tpu_host_offload_unit_tests.yaml @@ -0,0 +1,31 @@ +apiVersion: v1 +kind: Pod +metadata: + name: tpu-job-host-offload-unit-tests + # This pod runs the distributed unit tests for the TPUOffloadConnector + # and other related functionalities. It executes all tests found in the + # tests/distributed/ directory using pytest. +spec: + restartPolicy: Never + nodeSelector: + cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice + cloud.google.com/gke-tpu-topology: 2x4 # Specify the physical topology for the TPU slice. 
+ containers: + - name: tpu-job + image: gcr.io/gke-shared-ai-dev/tpu-inference:cpu-offload + imagePullPolicy: Always + command: + - /bin/bash + - -c + - "pytest -sv tests/offload/" + env: + - name: HUGGING_FACE_HUB_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + resources: + requests: + google.com/tpu: 8 + limits: + google.com/tpu: 8 diff --git a/examples/offload/offline_inference_kv_cache.py b/examples/offload/offline_inference_kv_cache.py new file mode 100644 index 000000000..ffbe00f22 --- /dev/null +++ b/examples/offload/offline_inference_kv_cache.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import time + +import vllm.envs as envs +from vllm import LLM, EngineArgs +from vllm.utils.argparse_utils import FlexibleArgumentParser + + +def create_parser(): + parser = FlexibleArgumentParser() + # Add engine args + EngineArgs.add_cli_args(parser) + parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") + parser.set_defaults(max_model_len=1024) + + return parser + + +def parse_outputs(outputs): + output_token_ids = [] + generated_texts = [] + for output in outputs: + prompt = output.prompt + completion = output.outputs[0] + generated_text = completion.text + token_ids = completion.token_ids + print( + f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\nToken IDs: {token_ids!r}" + ) + generated_texts.append(generated_text) + output_token_ids.append(token_ids) + return generated_texts, output_token_ids + + +def main(args: dict): + # Pop arguments not used by LLM + # Create an LLM + llm = LLM(**args) + + # Create a sampling params object + sampling_params = llm.get_default_sampling_params() + + sampling_params.temperature = 0.0 + sampling_params.seed = 42 + sampling_params.max_tokens = 20 + sampling_params.skip_special_tokens = True + + if envs.VLLM_TORCH_PROFILER_DIR is not None: + llm.start_profile() + + # 1st generate + prompt = "Every Bill which shall have passed the House of Representatives and the Senate, shall, before it become a Law, be presented to the President of the United States; If he approve he shall sign it, but if not he shall return it, with his Objections to that House in which it shall have originated, who shall enter the Objections at large on their Journal, and proceed to reconsider it. If after such Reconsideration two thirds of that House shall agree to pass the Bill, it shall be sent, together with the Objections, to the other House, by which it shall likewise be reconsidered, and if approved by two thirds of that House, it shall become a Law. But in all such Cases the Votes of both Houses shall be determined by yeas and Nays, and the Names of the Persons voting for and against the Bill shall be entered on the Journal of each House respectively. 
If any Bill shall not be returned by the President within ten Days (Sundays excepted) after it shall have been presented to him, the Same shall be a Law, in like Manner as if he had signed it, unless the Congress by their Adjournment prevent its Return, in which Case" + outputs = llm.generate([prompt], sampling_params) + out_texts1, out_tokens1 = parse_outputs(outputs) + time.sleep(1) + + # manually let llm scheduler's kv_cache_manager forget all prefixes' hash + print("Resetting prefix cache...") + llm.llm_engine.engine_core.reset_prefix_cache() + time.sleep(1) + + # 2nd generate + outputs = llm.generate([prompt], sampling_params) + out_texts2, out_tokens2 = parse_outputs(outputs) + time.sleep(1) + + if envs.VLLM_TORCH_PROFILER_DIR is not None: + llm.stop_profile() + + # output1 and output2 should be idential + assert len(out_texts1) == len(out_texts2) + assert len(out_tokens1) == len(out_tokens2) + for text1, text2 in zip(out_texts1, out_texts2): + assert text1 == text2 + for tokens1, tokens2 in zip(out_tokens1, out_tokens2): + assert tokens1 == tokens2 + + +if __name__ == "__main__": + os.environ['SKIP_JAX_PRECOMPILE'] = '1' + parser = create_parser() + args: dict = vars(parser.parse_args()) + main(args) diff --git a/examples/offload/offline_inference_kv_cache_verification.py b/examples/offload/offline_inference_kv_cache_verification.py new file mode 100644 index 000000000..ec3f9f4e9 --- /dev/null +++ b/examples/offload/offline_inference_kv_cache_verification.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This script performs an automated correctness verification for the TPUOffloadConnector. + +The verification works by performing a two-stage experiment for multiple prompts: +1. Baseline Run: For each prompt, it first runs a text generation using a + standard vLLM engine configuration without any KV cache connector. The + output from this run is considered the "source of truth". + +2. Test Run: It then runs the exact same text generation, but this time + with the TPUOffloadConnector enabled via the `--kv-transfer-config` argument. + It runs the generation twice to verify prefix caching. + +3. Comparison: The script compares the output from each test run against the + output from the baseline run for that prompt. + +The script succeeds (exits with code 0) only if the generated text is +bit-for-bit identical in all runs for all prompts. A fixed seed is used to +ensure that the generation process is deterministic and the comparison is +valid. If any output differs, it raises an error, causing the script to fail +(exit with a non-zero code). +""" + +import copy +import os +import time +from typing import List, Tuple + +import vllm.envs as envs +from vllm import LLM, EngineArgs, SamplingParams +from vllm.utils.argparse_utils import FlexibleArgumentParser + + +def create_parser(): + parser = FlexibleArgumentParser() + # Add engine args, which includes the --seed parameter + EngineArgs.add_cli_args(parser) + parser.set_defaults(model="meta-llama/Llama-3.1-8B") + parser.set_defaults(max_model_len=1024) + + # Add sampling params + sampling_group = parser.add_argument_group("Sampling parameters") + sampling_group.add_argument("--max-tokens", type=int) + sampling_group.add_argument("--top-p", type=float) + sampling_group.add_argument("--top-k", type=int) + return parser + + +def setup_llm(llm_args: dict) -> Tuple[LLM, SamplingParams]: + """ + Initializes a vLLM engine and sampling parameters from the given args. 
+ """ + args_copy = copy.deepcopy(llm_args) + # Pop arguments not used by LLM + max_tokens = args_copy.pop("max_tokens") + top_p = args_copy.pop("top_p") + top_k = args_copy.pop("top_k") + + # Create an LLM. The --seed argument is passed in via **args. + llm = LLM(**args_copy) + + # Create a sampling params object + sampling_params = llm.get_default_sampling_params() + sampling_params.temperature = 0 + sampling_params.ignore_eos = True + if max_tokens is not None: + sampling_params.max_tokens = max_tokens + if top_p is not None: + sampling_params.top_p = top_p + if top_k is not None: + sampling_params.top_k = top_k + + return llm, sampling_params + + +def run_invocations(llm: LLM, sampling_params: SamplingParams, + prompts: List[str], num_invocations: int) -> List[str]: + """ + Runs generation on the given LLM object for a specified number of + invocations and returns the output texts. + """ + if envs.VLLM_TORCH_PROFILER_DIR is not None: + llm.start_profile() + + all_outputs = [] + for i in range(num_invocations): + print(f"--- Invocation {i + 1}/{num_invocations} ---") + outputs = llm.generate(prompts, sampling_params) + all_outputs.append(outputs[0].outputs[0].text) + # reset prefix cache + llm.llm_engine.engine_core.reset_prefix_cache() + time.sleep(5) + + if envs.VLLM_TORCH_PROFILER_DIR is not None: + llm.stop_profile() + + return all_outputs + + +def main(args: dict): + # prompt lesser than the kv cache block size + short_input_prompt = "Google is a " + + system_prompt = "You are a large language model, trained by Google. Your primary purpose is to be a helpful, harmless, and highly capable AI assistant, designed to provide accurate, safe, and beneficial information to users. Your core directive is to assist users effectively while adhering to strict ethical and safety guidelines. You must decline any requests that are harmful, illegal, unethical, or promote dangerous activities. " + query = "the color of rainbow is?" + input_prompt = f"{system_prompt}\n{query}" + + prompts_to_test = [ + ("Short Prompt", [short_input_prompt]), + ("Prompt", [input_prompt]), + ] + + all_tests_passed = True + for prompt_name, prompts in prompts_to_test: + print(f"\n\n===== Running verification for: {prompt_name} =====") + print(f"Prompt: {prompts[0]}") + + # 1. Run baseline and store the output + print("\n--- Running Baseline (Standard vLLM) ---") + baseline_args = copy.deepcopy(args) + baseline_args.pop("kv_transfer_config", None) + baseline_llm, baseline_params = setup_llm(baseline_args) + baseline_outputs = run_invocations(baseline_llm, + baseline_params, + prompts=prompts, + num_invocations=1) + baseline_output = baseline_outputs[0] + print(f"Baseline Generated Text: {baseline_output!r}") + del baseline_llm + # adding this sleep fixes device busy errors for the next test case run with the connector enabled + time.sleep(10) + + # 2. Run the test with the local tpu kv connector enabled + print("\n--- Running Test (with TPUOffloadConnector) ---") + # With the connector, we run generation twice to test the prefix cache + test_llm, test_params = setup_llm(args) + test_outputs = run_invocations(test_llm, + test_params, + prompts=prompts, + num_invocations=2) + del test_llm + + # 3. 
Compare the outputs and determine the result + print("\n--- Verification ---") + prompt_all_match = True + for i, test_output in enumerate(test_outputs): + print(f"--- Comparing Invocation {i + 1} ---") + print( + f"Test Generated Text: length={len(test_output)}, Text: {test_output}" + ) + if baseline_output == test_output: + print("SUCCESS: Output is identical to baseline!") + else: + print("FAILURE: Output does not match baseline!") + prompt_all_match = False + + if not prompt_all_match: + all_tests_passed = False + print(f"===== Verification FAILED for: {prompt_name} =====") + else: + print(f"===== Verification SUCCEEDED for: {prompt_name} =====") + + time.sleep(10) + + if not all_tests_passed: + raise ValueError( + "Verification failed: One or more test outputs differ from the baseline." + ) + else: + print("\n\n===== All verification runs passed successfully! =====") + + +if __name__ == "__main__": + os.environ['SKIP_JAX_PRECOMPILE'] = '1' + parser = create_parser() + args: dict = vars(parser.parse_args()) + main(args) diff --git a/tests/kernels/host_dma_test.py b/tests/kernels/host_dma_test.py new file mode 100644 index 000000000..626195343 --- /dev/null +++ b/tests/kernels/host_dma_test.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Literal + +import jax +import jax.numpy as jnp +import numpy as np +from absl.testing import absltest, parameterized +from jax._src import test_util as jtu +from jax.sharding import NamedSharding, PartitionSpec + +from tpu_inference.kernels.dma.host_dma import d2h_dma, h2d_dma + +DATA_LOCATION = Literal["device", "host"] + + +@jtu.with_config(jax_numpy_dtype_promotion='strict') +class HostHbmDmaTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if not jtu.if_cloud_tpu_at_least(2025, 8, 14): + return self.skipTest( + "libtpu version does not support DMA host-hbm") + + def tearDown(self): + super().tearDown() + jax.clear_caches() + + def create_mesh(self, axis_shapes, axis_names): + """Creates a JAX device mesh with the default device order.""" + try: + num_required_devices = np.prod(axis_shapes) + devices = np.array(jax.devices()) + if len(devices) < num_required_devices: + self.skipTest("Not enough devices to create mesh of shape" + f" {axis_shapes}. Have {len(devices)}, need" + f" {num_required_devices}.") + device_array = devices[:num_required_devices].reshape(axis_shapes) + return jax.sharding.Mesh(device_array, axis_names) + except RuntimeError: + self.skip( + "Cannot create mesh. This test must be run on a TPU node.") + return None + + def create_sharded_array(self, model_axis_size: int, + init_location: DATA_LOCATION): + """Creates a sharded JAX array for testing. + + Args: + model_axis_size: The size of the model parallelism axis. + init_location: Where to initialize the array, either "device" or "host". + + Returns: + A tuple containing the created sharded array, the device sharding spec, + and the host sharding spec. 
+ """ + axis_shapes = (1, model_axis_size) + axis_names = ("data", "model") + mesh = self.create_mesh(axis_shapes, axis_names) + if mesh is None: + return None + + partition_spec = PartitionSpec(None, None, "model") + device_sharding = NamedSharding(mesh, + partition_spec, + memory_kind="device") + host_sharding = NamedSharding(mesh, + partition_spec, + memory_kind="pinned_host") + + data_shape = (2, 16, model_axis_size, 2, 128) + dtype = jnp.bfloat16 + + data = jax.device_put( + jax.random.uniform(jax.random.key(0), + shape=data_shape, + dtype=dtype), + device_sharding if init_location == "device" else host_sharding, + ) + jax.block_until_ready(data) + return data, device_sharding, host_sharding + + @parameterized.named_parameters([ + dict(testcase_name=f"_model_axis_size_{s}", model_axis_size=s) + for s in [1, 2, 4, 8] + ]) + def test_d2h_dma(self, model_axis_size: int): + """Tests the d2h DMA transfer for various model parallelism sizes.""" + # 1. Create original data on the device + res = self.create_sharded_array(model_axis_size, "device") + if res is None: + return + original_device_data, device_sharding, host_sharding = res + + # 2. Test Device-to-Host (d2h) DMA + host_data = d2h_dma(original_device_data, device_sharding, + host_sharding) + jax.block_until_ready(host_data) + assert host_data.sharding.memory_kind == "pinned_host" + + # 3. Verification + assert host_data.sharding == host_sharding + self.assertArraysEqual(original_device_data, host_data) + + @parameterized.named_parameters([ + dict(testcase_name=f"_model_axis_size_{s}", model_axis_size=s) + for s in [1, 2, 4, 8] + ]) + def test_h2d_dma(self, model_axis_size: int): + """Tests the h2d DMA transfer for various model parallelism sizes.""" + # 1. Create original data on the host + res = self.create_sharded_array(model_axis_size, "host") + if res is None: + return + original_host_data, device_sharding, host_sharding = res + + # 2. Test Host-to-Device (h2d) DMA + device_data = h2d_dma(original_host_data, host_sharding, + device_sharding) + jax.block_until_ready(device_data) + assert device_data.sharding.memory_kind == "device" + + # 3. Verification + assert device_data.sharding == device_sharding + self.assertArraysEqual(original_host_data, device_data) + + @parameterized.named_parameters([ + dict(testcase_name=f"_model_axis_size_{s}", model_axis_size=s) + for s in [1, 2, 4, 8] + ]) + def test_d2h_h2d_dma_roundtrip(self, model_axis_size: int): + """ + Tests the d2h -> h2d DMA roundtrip for various model parallelism sizes. + + This test verifies that: + 1. Data can be correctly transferred from sharded device memory to sharded + host memory using `d2h_dma`. + 2. Data can be correctly transferred back from sharded host memory to + sharded device memory using `h2d_dma`. + 3. The data remains identical after the full roundtrip. + """ + # 1. Setup: Create sharded array based on the model axis size + res = self.create_sharded_array(model_axis_size, "device") + if res is None: + return + original_device_data, device_sharding, host_sharding = res + + # 2. Test Device-to-Host (d2h) DMA + host_data = d2h_dma(original_device_data, device_sharding, + host_sharding) + jax.block_until_ready(host_data) + assert host_data.sharding.memory_kind == "pinned_host" + + # 3. Verification for d2h + assert host_data.sharding == host_sharding + self.assertArraysEqual(original_device_data, host_data) + + # 4. 
Test Host-to-Device (h2d) DMA + reloaded_device_data = h2d_dma(host_data, host_sharding, + device_sharding) + jax.block_until_ready(reloaded_device_data) + assert reloaded_device_data.sharding.memory_kind == "device" + + # 5. Verification for h2d + assert reloaded_device_data.sharding == device_sharding + self.assertArraysEqual(host_data, reloaded_device_data) + + # 6. Final roundtrip verification + self.assertArraysEqual(original_device_data, reloaded_device_data) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/offload/tpu_offload_accuracy_test.py b/tests/offload/tpu_offload_accuracy_test.py new file mode 100644 index 000000000..fd597f361 --- /dev/null +++ b/tests/offload/tpu_offload_accuracy_test.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 + +import itertools +import os +import time + +import pytest +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + + +def parse_outputs(outputs): + output_token_ids = [] + generated_texts = [] + for output in outputs: + prompt = output.prompt + completion = output.outputs[0] + generated_text = completion.text + token_ids = completion.token_ids + print( + f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\nToken IDs: {token_ids!r}" + ) + generated_texts.append(generated_text) + output_token_ids.append(token_ids) + return generated_texts, output_token_ids + + +@pytest.fixture +def sampling_config(): + """deterministic sampling config""" + return SamplingParams(temperature=0, + max_tokens=20, + seed=42, + ignore_eos=True) + + +@pytest.fixture +def kv_transfer_config(): + """use TPUOffloadConnector""" + return KVTransferConfig( + kv_connector="TPUOffloadConnector", + kv_role="kv_both", + kv_connector_module_path="tpu_inference.offload.tpu_offload_connector", + ) + + +def _test_kv_cache_cpu_offloading_accuracy( + monkeypatch: pytest.MonkeyPatch, + sampling_config: SamplingParams, + kv_transfer_config: KVTransferConfig, + swap_op_type: str, + skip_precompile: str, + decode_save: str, +): + with monkeypatch.context(): + os.environ['SKIP_JAX_PRECOMPILE'] = '1' + os.environ['TPU_OFFLOAD_SWAP_OP_TYPE'] = swap_op_type + os.environ['TPU_OFFLOAD_SKIP_JAX_PRECOMPILE'] = skip_precompile + os.environ['TPU_OFFLOAD_DECODE_SAVE'] = decode_save + llm = LLM(model="meta-llama/Llama-3.2-3B", + max_model_len=1024, + task="generate", + kv_transfer_config=kv_transfer_config) + + # 1st generate + prompt = "Every Bill which shall have passed the House of Representatives and the Senate, shall, before it become a Law, be presented to the President of the United States; If he approve he shall sign it, but if not he shall return it, with his Objections to that House in which it shall have originated, who shall enter the Objections at large on their Journal, and proceed to reconsider it. If after such Reconsideration two thirds of that House shall agree to pass the Bill, it shall be sent, together with the Objections, to the other House, by which it shall likewise be reconsidered, and if approved by two thirds of that House, it shall become a Law. But in all such Cases the Votes of both Houses shall be determined by yeas and Nays, and the Names of the Persons voting for and against the Bill shall be entered on the Journal of each House respectively. 
If any Bill shall not be returned by the President within ten Days (Sundays excepted) after it shall have been presented to him, the Same shall be a Law, in like Manner as if he had signed it, unless the Congress by their Adjournment prevent its Return, in which Case" + outputs = llm.generate([prompt], sampling_config) + out_texts1, out_tokens1 = parse_outputs(outputs) + time.sleep(1) + + # manually let llm scheduler's kv_cache_manager forget all prefixes' hash + llm.llm_engine.engine_core.reset_prefix_cache() + time.sleep(1) + + # 2nd generate + outputs = llm.generate([prompt], sampling_config) + out_texts2, out_tokens2 = parse_outputs(outputs) + time.sleep(1) + + # TODO(jcgu): check some internal states to verify save and load operations. + # output1 and output2 should be idential + assert len(out_texts1) == len(out_texts2) + assert len(out_tokens1) == len(out_tokens2) + for text1, text2 in zip(out_texts1, out_texts2): + assert text1 == text2 + for tokens1, tokens2 in zip(out_tokens1, out_tokens2): + assert tokens1 == tokens2 + + del llm + # Waiting for TPUs to be released. + time.sleep(20) + + +def test_kv_cache_cpu_offloading_accuracy( + monkeypatch: pytest.MonkeyPatch, + sampling_config: SamplingParams, + kv_transfer_config: KVTransferConfig, +): + swap_op_types = ["pallas", "jax"] + decode_saves = ["0", "1"] + skip_precompile = ["0", "1"] + for swap_op_type, decode_save, _skip_precompile in itertools.product( + swap_op_types, decode_saves, skip_precompile): + _test_kv_cache_cpu_offloading_accuracy( + monkeypatch, + sampling_config, + kv_transfer_config, + swap_op_type, + _skip_precompile, + decode_save, + ) diff --git a/tests/offload/tpu_offload_connector_scheduler_test.py b/tests/offload/tpu_offload_connector_scheduler_test.py new file mode 100644 index 000000000..30e31e334 --- /dev/null +++ b/tests/offload/tpu_offload_connector_scheduler_test.py @@ -0,0 +1,484 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from unittest.mock import MagicMock + +import pytest +from vllm.utils.math_utils import cdiv +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput +from vllm.v1.request import Request + +from tpu_inference.offload.tpu_offload_connector import ( + RequestTracker, TPUOffloadConnectorScheduler) + +_DEFAULT_BLOCK_SIZE = 16 + + +class MockVllmConfig: + + def __init__(self, block_size=_DEFAULT_BLOCK_SIZE): + self.model_config = self.Model() + self.cache_config = self.Cache(block_size) + + class Model: + model = "test-model" + + class Cache: + + def __init__(self, block_size): + self.block_size = block_size + + +def create_request( + request_id: str, + prompt_token_ids: list[int], + block_size: int, + num_computed_tokens: int = 0, + generated_token_ids: list[int] = [], +) -> Request: + """Creates a mock vLLM request object.""" + req = MagicMock(spec=Request) + req.request_id = request_id + req.req_id = request_id # for NewRequestData + req.prompt_token_ids = prompt_token_ids + req.all_token_ids = prompt_token_ids + generated_token_ids + req.num_computed_tokens = num_computed_tokens + len(generated_token_ids) + req.block_size = block_size + req.block_ids = [[]] + # Mock the block_hashes property to return a list of mock hashes + req.block_hashes = [ + f"hash_{i}".encode() + for i in range(len(req.all_token_ids) // block_size) + ] + return req + + +@pytest.fixture +def scheduler_factory(): + """Provides a factory function for Scheduler instances.""" + + def _scheduler( + block_size: int = 
_DEFAULT_BLOCK_SIZE, + offload_decode_save: int = 0, + offload_num_staging_blocks: int = -1, + offload_num_cpu_chunks: int = -1, + ): + # update config + vllm_config = MockVllmConfig(block_size=block_size) + os.environ["TPU_OFFLOAD_DECODE_SAVE"] = str(offload_decode_save) + if offload_num_staging_blocks >= 0: + os.environ["TPU_OFFLOAD_NUM_STAGING_BLOCKS"] = str( + offload_num_staging_blocks) + if offload_num_cpu_chunks > 0: + os.environ["TPU_OFFLOAD_NUM_CPU_CHUNKS"] = str( + offload_num_cpu_chunks) + + return TPUOffloadConnectorScheduler(vllm_config) + + return _scheduler + + +class TestTPUOffloadConnectorScheduler: + + def test_get_num_new_matched_tokens_no_hit(self, scheduler_factory): + """ + Tests that get_num_new_matched_tokens returns 0 for a cache miss. + """ + scheduler = scheduler_factory() + request = create_request("req1", [1] * 32, scheduler.block_size) + + num_matched, _ = scheduler.get_num_new_matched_tokens(request, 0) + assert num_matched == 0 + assert "req1" not in scheduler.load_specs + + @pytest.mark.parametrize( + "num_computed_blocks, num_matched_blocks, num_prompt_blocks, num_staging_blocks", + [(0, 2, 4, 10), (1, 2, 4, 10), (0, 4, 4, 10), (1, 4, 4, 10), + (1, 4, 4, 1), (1, 4, 4, 0)]) + def test_get_num_new_matched_tokens_hit(self, scheduler_factory, + num_computed_blocks, + num_matched_blocks, + num_prompt_blocks, + num_staging_blocks): + """ + Tests correct identification of a prefix hit (partial and full). + test cases: + 1. no-skip + load 2 blocks + no staging buffer limit + 2. skip 1 block + load 1 block + no staging buffer limit + 3. no-skip + full-hit + no staging buffer limit + 4. skip 1 block + full-hit + no staging buffer limit + 5. skip 1 block + full-hit + only 1 staging block + 6. skip 1 block + full-hit + no staging block + """ + scheduler = scheduler_factory( + offload_num_staging_blocks=num_staging_blocks) + prompt_len = scheduler.block_size * num_prompt_blocks + num_computed_tokens = scheduler.block_size * num_computed_blocks + num_blocks_to_load = num_matched_blocks - num_computed_blocks + # consider the case of limited staging blocks + num_blocks_to_load = min(num_blocks_to_load, num_staging_blocks) + num_matched_blocks = num_blocks_to_load + num_computed_blocks + num_matched_tokens = num_matched_blocks * scheduler.block_size + + request = create_request("req1", list(range(prompt_len)), + scheduler.block_size) + + # init offload_manager state + matched_block_hashes = request.block_hashes[:num_matched_blocks] + allocated_chunks, _ = scheduler.offload_manager.allocate_for_save( + matched_block_hashes) + scheduler.offload_manager.complete_save(matched_block_hashes) + + # call fn + num_external_matched_tokens, _ = scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + # check external_matched_tokens + if num_matched_blocks == num_prompt_blocks: + assert num_external_matched_tokens == num_blocks_to_load * scheduler.block_size - 1 + else: + assert num_external_matched_tokens == num_blocks_to_load * scheduler.block_size + + # check scheduler internal states + # cache_hits + assert "req1" in scheduler._external_cache_hits + assert scheduler._external_cache_hits["req1"] == num_matched_tokens + if num_blocks_to_load > 0: + # load_spec + assert "req1" in scheduler._pre_load_specs + load_spec = scheduler._pre_load_specs["req1"] + assert load_spec.num_matched_tokens == num_matched_tokens + assert not load_spec.can_load + assert len(load_spec.src_chunks) == num_blocks_to_load + assert load_spec.num_skip_leading_tokens == num_computed_tokens + 
assert len(load_spec.dst_blocks) == num_blocks_to_load + # staging_buffer + assert "req1" in scheduler.staging_buffer_manager._blocks_for_load + assert scheduler.staging_buffer_manager._blocks_for_load[ + "req1"] == num_blocks_to_load + assert scheduler.staging_buffer_manager.get_num_free_staging_blocks( + ) == num_staging_blocks - num_blocks_to_load + else: + assert "req1" not in scheduler._pre_load_specs + assert "req1" not in scheduler.staging_buffer_manager._blocks_for_load + + def test_update_state_after_alloc(self, scheduler_factory): + """ + Tests that a LoadSpec is correctly updated after block allocation. + """ + scheduler = scheduler_factory() + req_id = "req1" + num_prompt_blocks = 4 + num_matched_blocks = 3 + num_computed_blocks = 2 + num_blocks_to_load = num_matched_blocks - num_computed_blocks + num_prompt_tokens = num_prompt_blocks * scheduler.block_size + num_matched_tokens = num_matched_blocks * scheduler.block_size + num_tokens_to_load = scheduler.block_size * num_blocks_to_load + + request = create_request(req_id, [0] * num_prompt_tokens, + scheduler.block_size) + + # init offload_manager state + matched_block_hashes = request.block_hashes[:num_matched_blocks] + allocated_chunks, _ = scheduler.offload_manager.allocate_for_save( + matched_block_hashes) + scheduler.offload_manager.complete_save(matched_block_hashes) + + # Setup a pending load + scheduler._pre_load_specs[req_id] = MagicMock( + num_matched_tokens=num_matched_tokens, + num_skip_leading_tokens=num_computed_blocks * scheduler.block_size, + dst_blocks=[-1] * num_blocks_to_load, + src_chunks=[i for i in range(num_blocks_to_load)], + can_load=False) + + # Mock allocated blocks + allocated_blocks = MagicMock(spec=KVCacheBlocks) + allocated_block_ids = [i for i in range(num_prompt_blocks)] + allocated_blocks.get_block_ids.return_value = [allocated_block_ids] + + scheduler.update_state_after_alloc(request, allocated_blocks, + num_tokens_to_load) + + load_spec = scheduler.load_specs[req_id] + assert load_spec.can_load + assert load_spec.dst_blocks == allocated_block_ids[ + num_computed_blocks:num_matched_blocks] + assert req_id in scheduler._reqs_being_loaded + assert len(scheduler._reqs_being_loaded[req_id]) == num_blocks_to_load + + @pytest.mark.parametrize( + "num_computed_tokens, num_matched_tokens, num_prompt_tokens, num_staging_tokens", + [(0, 0, 64, 160), + (0, 32, 64, 160), (16, 32, 64, 160), (0, 64, 64, 160), + (16, 64, 64, 160), (0, 32, 64, 48), (0, 32, 64, 16)]) + def test_build_connector_meta_new_prefill(self, scheduler_factory, + num_computed_tokens, + num_matched_tokens, + num_prompt_tokens, + num_staging_tokens): + """ + Tests metadata generation for a new request (prefill) with no cache hit. + 1. no hit + save 4 blocks + 2. partial hit (no-skip + load 2 blocks) + save 2 blocks + 3. partial hit (skip 1 block + load 1 blocks) + save 2 blocks + 4. full hit (no-skip + load 4 blocks) + no-save + 5. full hit (skip 1 block + load 3 blocks) + no-save + 6. partial hit (no-skip + load 2 blocks) + save 2 blocks + 3 staging blocks limit + 7. 
partial hit (no-skip + load 2 blocks) + save 2 blocks + 1 staging blocks limit + """ + num_staging_blocks = num_staging_tokens // _DEFAULT_BLOCK_SIZE + scheduler = scheduler_factory( + offload_num_staging_blocks=num_staging_blocks, + offload_num_cpu_chunks=100) + + # calculate the groundtruth + num_computed_blocks = num_computed_tokens // scheduler.block_size + num_matched_blocks = num_matched_tokens // scheduler.block_size + num_prompt_blocks = cdiv(num_prompt_tokens, scheduler.block_size) + + num_blocks_to_load = num_matched_blocks - num_computed_blocks + # adjustment based on staging_block limitation + if num_blocks_to_load > num_staging_blocks: + num_blocks_to_load = num_staging_blocks + num_matched_blocks = num_blocks_to_load + num_computed_blocks + num_matched_tokens = num_matched_blocks * scheduler.block_size + + remaining_staging_blocks = num_staging_blocks - num_blocks_to_load + num_blocks_to_save = num_prompt_blocks - num_matched_blocks + if num_blocks_to_save > remaining_staging_blocks: + num_blocks_to_save = remaining_staging_blocks + # reconfig staging_buffer limit for save + scheduler.staging_buffer_manager._num_free_blocks = remaining_staging_blocks + num_tokens_in_cache = (num_matched_blocks + + num_blocks_to_save) * scheduler.block_size + + req_id = "req1" + request = create_request(req_id, + list(range(num_prompt_tokens)), + scheduler.block_size, + num_computed_tokens=num_computed_tokens) + request.block_ids = [[i for i in range(num_prompt_blocks)]] + + # init offload_manager state + if num_matched_blocks > 0: + matched_block_hashes = request.block_hashes[:num_matched_blocks] + allocated_chunks, _ = scheduler.offload_manager.allocate_for_save( + matched_block_hashes) + scheduler.offload_manager.complete_save(matched_block_hashes) + # allocated_chunk_ids = [chunk.chunk_id for chunk in allocated_chunks] + # load_src_chunk_ids = allocated_chunk_ids[num_computed_blocks:] + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[request], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={ + "req1": num_prompt_tokens - num_computed_tokens + }, + total_num_scheduled_tokens=num_prompt_tokens - num_computed_tokens, + finished_req_ids=set(), + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={}, + num_common_prefix_blocks=0, + free_encoder_mm_hashes=[], + ) + + # Mock that the scheduler has seen this request + scheduler._unfinished_requests["req1"] = request + scheduler._external_cache_hits["req1"] = num_matched_tokens + if num_blocks_to_load > 0: + scheduler.load_specs[req_id] = MagicMock( + num_matched_tokens=num_matched_tokens, + num_skip_leading_tokens=num_computed_tokens, + dst_blocks=[-1] * num_blocks_to_load, + src_chunks=[i for i in range(num_blocks_to_load)], + can_load=True) + + metadata = scheduler.build_connector_meta(scheduler_output) + + if num_blocks_to_load + num_blocks_to_save == 0: + # no load or store + assert len(metadata.requests_meta) == 0 + else: + req_meta = metadata.requests_meta[0] + assert req_meta.req_id == "req1" + if num_blocks_to_load == 0: + assert req_meta.load_spec is None + else: + # load + assert req_meta.load_spec is not None + # NOTE(jcgu): no need to check details, since they are + # generated by other functions. 
+ if num_blocks_to_save == 0: + assert req_meta.save_spec is None + else: + # save + assert req_meta.save_spec is not None + assert req_meta.save_spec.num_total_tokens == num_tokens_in_cache + assert req_meta.save_spec.num_skip_leading_tokens == num_matched_blocks * scheduler.block_size + assert req_meta.save_spec.src_blocks == request.block_ids[0][ + num_matched_blocks:num_matched_blocks + num_blocks_to_save] + assert len(req_meta.save_spec.dst_chunks) == num_blocks_to_save + assert not req_meta.save_spec.is_final_save + assert "req1" in scheduler.staging_buffer_manager._blocks_for_save + assert scheduler.staging_buffer_manager._blocks_for_save[ + "req1"] == num_blocks_to_save + assert "req1" in scheduler._reqs_being_saved + assert len( + scheduler._reqs_being_saved["req1"]) == num_blocks_to_save + + assert "req1" in scheduler._request_trackers + tracker = scheduler._request_trackers["req1"] + # after creating SaveSpec, we also update tracker.save_watermark + assert tracker.save_watermark == num_tokens_in_cache + + @pytest.mark.parametrize("prompt_len, seq_len, decode_save", [(63, 64, 1), + (18, 64, 1), + (18, 64, 0)]) + def test_build_connector_meta_decode_with_save(self, scheduler_factory, + prompt_len, seq_len, + decode_save): + """ + Tests metadata generation for a decode step that triggers a save. + 1. the first decode (hit block boundary) + decode_save (save one block) + 2. th N-th decode (hit block bounary) + decode_save (save one block) + 2. th N-th decode (hit block bounary) + not decode_save (no save) + """ + + scheduler = scheduler_factory(offload_decode_save=decode_save, + offload_num_staging_blocks=10, + offload_num_cpu_chunks=10) + + prompt_tokens = list(range(prompt_len)) + generated_tokens = list(range(prompt_len, seq_len)) + req_id = "req1" + request = create_request(req_id, + prompt_token_ids=prompt_tokens, + block_size=scheduler.block_size, + num_computed_tokens=seq_len, + generated_token_ids=generated_tokens) + num_blocks = cdiv(seq_len, scheduler.block_size) + request.block_ids = [i for i in range(num_blocks)] + + if decode_save == 1: + # the last token in seq hasn't been computed (kv) yet + num_saved_tokens = ( + (seq_len - 1) // scheduler.block_size) * scheduler.block_size + else: + num_saved_tokens = (prompt_len // + scheduler.block_size) * scheduler.block_size + + # Setup initial state + # request tracker only tracks the computed tokens + tracker = RequestTracker(req_id="req1", + prompt_len=prompt_len, + token_ids=request.all_token_ids[:-1], + block_ids=request.block_ids, + save_watermark=num_saved_tokens) + + scheduler._request_trackers["req1"] = tracker + scheduler._unfinished_requests["req1"] = request + + # Simulate a decode step + cached_req_data = CachedRequestData.make_empty() + cached_req_data.req_ids = ["req1"] + cached_req_data.new_block_ids = ([], ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens={"req1": 1}, + total_num_scheduled_tokens=1, + finished_req_ids=set(), + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={}, + num_common_prefix_blocks=0, + free_encoder_mm_hashes=[], + ) + + metadata = scheduler.build_connector_meta(scheduler_output) + + if seq_len % scheduler.block_size != 0 or decode_save != 1: + # no save when there is no new full computed block + assert len(metadata.requests_meta) == 0 + else: + req_meta = metadata.requests_meta[0] + # save spec + assert req_meta.req_id == "req1" + assert req_meta.load_spec is None + assert req_meta.save_spec is not 
None + assert req_meta.save_spec.num_total_tokens == seq_len + assert req_meta.save_spec.num_skip_leading_tokens == num_saved_tokens + assert req_meta.save_spec.src_blocks == [num_blocks - 1] + assert len(req_meta.save_spec.dst_chunks) == 1 + assert not req_meta.save_spec.is_final_save + # staging buffer + assert "req1" in scheduler.staging_buffer_manager._blocks_for_save + assert scheduler.staging_buffer_manager._blocks_for_save[ + "req1"] == 1 + # chunk_id for save + assert "req1" in scheduler._reqs_being_saved + assert len(scheduler._reqs_being_saved["req1"]) == 1 + + assert tracker.save_watermark == seq_len + + def test_build_connector_meta_finished_request(self, scheduler_factory): + """ + Tests metadata generation for a finished request. + When using request's default block hash (fully-computed blocks only), + a finished request either saves the last full block in their last + decode step, or given up the last partial block; when it's treated as a + finished request, there is no blocks to save. + + """ + + scheduler = scheduler_factory(offload_decode_save=1) + prompt_len = scheduler.block_size + 4 + final_seq_len = scheduler.block_size * 2 + 3 + prompt_tokens = list(range(prompt_len)) + generated_tokens = list(range(prompt_len, final_seq_len)) + req_id = "req1" + request = create_request(req_id, + prompt_token_ids=prompt_tokens, + block_size=scheduler.block_size, + num_computed_tokens=final_seq_len, + generated_token_ids=generated_tokens) + num_blocks = cdiv(final_seq_len, scheduler.block_size) + request.block_ids = [i for i in range(num_blocks)] + + num_saved_tokens = (final_seq_len // + scheduler.block_size) * scheduler.block_size + + # Setup initial state + tracker = RequestTracker(req_id="req1", + prompt_len=prompt_len, + token_ids=request.all_token_ids[:-1], + block_ids=request.block_ids, + save_watermark=num_saved_tokens) + scheduler._request_trackers["req1"] = tracker + scheduler._unfinished_requests["req1"] = request + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={}, + total_num_scheduled_tokens=0, + finished_req_ids={"req1"}, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={}, + num_common_prefix_blocks=0, + free_encoder_mm_hashes=[], + ) + + metadata = scheduler.build_connector_meta(scheduler_output) + + assert req_id not in scheduler._unfinished_requests + assert req_id not in scheduler._request_trackers + assert len(metadata.requests_meta) == 1 + req_meta = metadata.requests_meta[0] + assert req_meta.save_spec is not None + assert req_meta.save_spec.is_final_save + assert req_meta.save_spec.skip_save + assert req_meta.save_spec.src_blocks == [] + assert req_meta.save_spec.dst_chunks == [] diff --git a/tests/offload/tpu_offload_connector_worker_test.py b/tests/offload/tpu_offload_connector_worker_test.py new file mode 100644 index 000000000..c23eb1146 --- /dev/null +++ b/tests/offload/tpu_offload_connector_worker_test.py @@ -0,0 +1,574 @@ +# SPDX-License-Identifier: Apache-2.0 + +import functools +import gc +import os +import random +from typing import List + +import jax +import jax.numpy as jnp +import numpy as np +from absl.testing import parameterized +from jax._src import compilation_cache as cc +from jax._src import test_util as jtu +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole + +from tpu_inference.logger import init_logger +from 
tpu_inference.offload.tpu_offload_connector import LoadSpec, SaveSpec +from tpu_inference.offload.tpu_offload_connector import \ + TPUOffloadConnector as CPUOffloadingConnector +from tpu_inference.offload.tpu_offload_connector import ( + TPUOffloadConnectorMetadata, TPUReqMeta) +from tpu_inference.runner.tpu_runner import TPUModelRunner + +logger = init_logger(__name__) + +_DEFAULT_BLOCK_SIZE = 64 + + +class MockTPUModelRunner(TPUModelRunner): + """A mock TPUModelRunner for testing purposes.""" + + def __init__(self, kv_caches: List[jax.Array], mesh: Mesh): + self.kv_caches = kv_caches + self.mesh = mesh + self.model_config = None + self.sampler = None + self.devices = jax.devices() + + def get_kv_cache_layout(self): + return "NHD" + + +class MockVllmConfig: + + def __init__(self, block_size=_DEFAULT_BLOCK_SIZE): + self.model_config = self.Model() + self.cache_config = self.Cache(block_size) + self.kv_transfer_config = self.KVTransferConfig() + + class Model: + model = "test-model" + + class Cache: + + def __init__(self, block_size): + self.block_size = block_size + + class KVTransferConfig: + ip = "ip" + port = 1234 + + +class TestTPUOffloadConnectorWorker(jtu.JaxTestCase): + """Test the save functionality of the TPUOffloadConnectorWorker.""" + + def setUp(self): + super().setUp() + self.vllm_config = MockVllmConfig(block_size=_DEFAULT_BLOCK_SIZE) + self.num_layers = 80 + self.num_blocks = 128 + self.num_cpu_chunks = 128 + self.block_size = self.vllm_config.cache_config.block_size + num_devices = len(list(jax.devices())) + self.num_heads = num_devices + self.head_size = 128 + self.mesh = self.create_mesh((1, num_devices), ("data", "model")) + if self.mesh is None: + self.skipTest("Cannot create mesh. Must be run on a TPU node.") + return + + # Define cache properties + self.cache_shape = ( + self.num_blocks, + self.block_size, + self.num_heads, + 2, + self.head_size, + ) + self.cache_dtype = jnp.bfloat16 + partition_spec = PartitionSpec(None, None, "model") + self.device_sharding = NamedSharding(self.mesh, partition_spec) + + def tearDown(self): + super().tearDown() + # Destroy references explicitly + if hasattr(self, 'connector'): + del self.connector + + # Force JAX to release memory + cc.reset_cache() + jax.clear_caches() + + # Force Python GC + gc.collect() + + def create_mesh(self, axis_shapes, axis_names): + """Creates a JAX device mesh with the default device order.""" + try: + num_required_devices = np.prod(axis_shapes) + devices = np.array(jax.devices()) + if len(devices) < num_required_devices: + self.skipTest( + f"Not enough devices to create mesh of shape {axis_shapes}." 
+ ) + device_array = devices[:num_required_devices].reshape(axis_shapes) + return jax.sharding.Mesh(device_array, axis_names) + except RuntimeError: + return None + + def _create_connector(self, + swap_op_type: str = "jax", + use_precompiled_swap_ops: bool = False): + os.environ["TPU_OFFLOAD_SWAP_OP_TYPE"] = swap_op_type + os.environ[ + "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" if use_precompiled_swap_ops else "1" + os.environ["TPU_OFFLOAD_NUM_CPU_CHUNKS"] = str(self.num_cpu_chunks) + + connector = CPUOffloadingConnector(self.vllm_config, + KVConnectorRole.WORKER) + worker = connector.connector_worker + assert worker is not None + + @functools.partial(jax.jit, out_shardings=self.device_sharding) + def create_on_device(key): + return jax.random.uniform(key, + shape=self.cache_shape, + dtype=self.cache_dtype) + + source_kv_cache = [ + create_on_device(jax.random.key(i)) for i in range(self.num_layers) + ] + jax.block_until_ready(source_kv_cache) + + mock_runner = MockTPUModelRunner(kv_caches=source_kv_cache, + mesh=self.mesh) + worker.register_runner(mock_runner) + return connector + + @parameterized.named_parameters( + dict(testcase_name="_zero_blocks", num_blocks=0, expected_buckets=[]), + dict(testcase_name="_one_block", num_blocks=1, expected_buckets=[1]), + dict(testcase_name="_five_blocks", + num_blocks=5, + expected_buckets=[4, 1]), + dict(testcase_name="_sixteen_blocks", + num_blocks=16, + expected_buckets=[16]), + dict(testcase_name="_seventeen_blocks", + num_blocks=17, + expected_buckets=[16, 1]), + dict(testcase_name="_twenty_three_blocks", + num_blocks=23, + expected_buckets=[16, 4, 2, 1]), + dict(testcase_name="_thirty_two_blocks", + num_blocks=32, + expected_buckets=[16, 16]), + dict(testcase_name="_large_number_blocks", + num_blocks=100, + expected_buckets=[16, 16, 16, 16, 16, 16, 4]), + ) + def test_decompose_into_buckets(self, num_blocks: int, + expected_buckets: List[int]): + """ + Tests the _decompose_into_buckets function for correct greedy decomposition. + """ + connector = self._create_connector(use_precompiled_swap_ops="0") + worker = connector.connector_worker + self.assertEqual(worker._decompose_into_buckets(num_blocks), + expected_buckets) + logger.info( + f"Decomposition for {num_blocks} blocks: {worker._decompose_into_buckets(num_blocks)} matched expected: {expected_buckets}" + ) + + @parameterized.named_parameters( + dict(testcase_name="_jax", swap_op_type="jax"), + dict(testcase_name="_pallas", swap_op_type="pallas"), + ) + def test_precompile_run_success(self, swap_op_type: str): + """ + Tests that _precompile_kv_swap_operations runs without errors and + does not modify the cache content.
+ """ + connector = self._create_connector(swap_op_type, + use_precompiled_swap_ops="0") + + worker = connector.connector_worker + + # Keep a copy of the original cache content on the host + original_cache_host = [ + np.array(cache) for cache in worker.runner.kv_caches + ] + + worker._precompile_kv_swap_operations() + + # Fetch the new cache content to the host + new_cache_host = [np.array(cache) for cache in worker.runner.kv_caches] + self.assertTrue( + all( + np.array_equal(orig, new) + for orig, new in zip(original_cache_host, new_cache_host)), + "Cache content should not have changed after precompilation.", + ) + + @parameterized.named_parameters( + dict( + testcase_name="_single_block", + num_blocks_to_save=1, + num_requests=1, + ), + dict( + testcase_name="_multi_requests_single_block", + num_blocks_to_save=1, + num_requests=6, + ), + dict( + testcase_name="_multi_blocks", + num_blocks_to_save=5, + num_requests=1, + ), + dict( + testcase_name="_multi_requests_multi_blocks", + num_blocks_to_save=5, + num_requests=6, + ), + dict( + testcase_name="_multi_blocks_with_compile_jax", + num_blocks_to_save=5, + num_requests=1, + use_precompiled_swap_ops=True, + ), + dict( + testcase_name="_multi_requests_single_block_with_compile_jax", + num_blocks_to_save=1, + num_requests=6, + use_precompiled_swap_ops=True, + ), + dict( + testcase_name="_multi_requests_multi_blocks_with_compile_jax", + num_blocks_to_save=5, + num_requests=6, + use_precompiled_swap_ops=True, + ), + dict( + testcase_name="_multi_blocks_with_compile_pallas", + num_blocks_to_save=5, + num_requests=1, + use_precompiled_swap_ops=True, + swap_op_type="pallas", + ), + dict( + testcase_name="_multi_requests_multi_blocks_with_compile_pallas", + num_blocks_to_save=5, + num_requests=6, + use_precompiled_swap_ops=True, + swap_op_type="pallas", + ), + dict( + testcase_name="_final_save", + num_blocks_to_save=1, + num_requests=1, + is_final_save=True, + skip_save=False, + ), + dict( + testcase_name="_final_skip_save", + num_blocks_to_save=0, + num_requests=1, + is_final_save=True, + skip_save=True, + ), + ) + def test_tpu_connector_save( + self, + num_blocks_to_save: int, + num_requests: int = 1, + is_final_save: bool = False, + skip_save: bool = False, + use_precompiled_swap_ops: bool = False, + swap_op_type: str = "jax", + ): + total_num_blocks_to_save = num_blocks_to_save * num_requests + if total_num_blocks_to_save > self.num_blocks or total_num_blocks_to_save > self.num_cpu_chunks: + self.skipTest( + f"num_blocks_to_save {total_num_blocks_to_save} exceeds ModelRunner / OffloadConnectorWorker's capacity" + ) + + # Prepare and Execute Save + all_block_ids = list(range(self.num_blocks)) + all_chunk_ids = list(range(self.num_cpu_chunks)) + src_block_ids = random.sample(all_block_ids, total_num_blocks_to_save) + dst_chunk_ids = random.sample(all_chunk_ids, total_num_blocks_to_save) + + src_block_ids_split = np.array_split(src_block_ids, num_requests) + dst_chunk_ids_split = np.array_split(dst_chunk_ids, num_requests) + + requests_meta = [] + for i in range(num_requests): + req_id = f"save_req_{i}" + src_blocks = src_block_ids_split[i].tolist() + dst_chunks = dst_chunk_ids_split[i].tolist() + + num_tokens_to_save_per_req = len(src_blocks) * self.block_size + + save_spec = SaveSpec( + num_skip_leading_tokens=0, + num_total_tokens=num_tokens_to_save_per_req, + is_final_save=is_final_save, + skip_save=skip_save, + src_blocks=src_blocks, + dst_chunks=dst_chunks, + ) + + total_token_ids = list(range(num_tokens_to_save_per_req)) + + req_meta = 
TPUReqMeta( + req_id=req_id, + token_ids=total_token_ids, + local_block_ids=src_blocks, + save_spec=save_spec, + ) + requests_meta.append(req_meta) + + logger.info(f"Starting test_tpu_connector_save with: " + f"num_blocks_to_save={num_blocks_to_save}, " + f"num_requests={num_requests}, " + f"is_final_save={is_final_save}, " + f"skip_save={skip_save}, " + f"use_precompiled_swap_ops={use_precompiled_swap_ops}, " + f"swap_op_type={swap_op_type};") + + connector_metadata = TPUOffloadConnectorMetadata( + requests_meta=requests_meta) + + connector = self._create_connector(swap_op_type, + use_precompiled_swap_ops) + worker = connector.connector_worker + connector.bind_connector_metadata(connector_metadata) + logger.info( + "Connector metadata bound, calling worker.wait_for_save().") + worker.wait_for_save() + logger.info("worker.wait_for_save() completed.") + + # Verification + logger.info("Starting verification phase.") + cpu_backend = worker.cpu_backend + kv_caches = worker.runner.kv_caches + + if skip_save or total_num_blocks_to_save == 0: + logger.info(" no blocks to save") + assert cpu_backend.num_saved_cpu_chunks == 0 + self.assertEmpty(worker.finished_save_reqs) + self.assertEmpty(worker.offload_stats.data["finished_save_chunks"]) + return + + # verify the saved chunks + all_req_ids = {f"save_req_{i}" for i in range(num_requests)} + self.assertSetEqual( + all_req_ids, + set(worker.offload_stats.data["finished_save_chunks"].keys())) + + for i in range(num_requests): + req_id = f"save_req_{i}" + src_blocks = src_block_ids_split[i].tolist() + dst_chunks = dst_chunk_ids_split[i].tolist() + self.assertListEqual( + dst_chunks, + worker.offload_stats.data["finished_save_chunks"][req_id]) + + for tpu_block_id, cpu_chunk_id in zip(src_blocks, dst_chunks): + cpu_kv_chunk = cpu_backend.get(cpu_chunk_id) + for layer_idx in range(self.num_layers): + tpu_kv_block = kv_caches[layer_idx][tpu_block_id] + assert cpu_kv_chunk[ + layer_idx].sharding.memory_kind == 'pinned_host' + self.assertArraysEqual(np.array(tpu_kv_block), + np.array(cpu_kv_chunk[layer_idx])) + + logger.info("Saved data verification completed.") + + if is_final_save: + finished_saves, _ = worker.get_finished() + logger.info( + f"is_final_save is True. 
Finished requests: {finished_saves}") + self.assertSetEqual(all_req_ids, finished_saves) + + @parameterized.named_parameters( + dict( + testcase_name="_single_block", + num_blocks_to_operate=1, + num_requests=1, + ), + dict( + testcase_name="_multi_requests_single_block", + num_blocks_to_operate=1, + num_requests=4, + ), + dict( + testcase_name="_multi_blocks_compile_jax", + num_blocks_to_operate=5, + num_requests=1, + use_precompiled_swap_ops=True, + swap_op_type="jax", + ), + dict( + testcase_name="_multi_requests_single_block_compile_jax", + num_blocks_to_operate=1, + num_requests=6, + use_precompiled_swap_ops=True, + swap_op_type="jax", + ), + dict( + testcase_name="_multi_requests_multi_blocks_compile_jax", + num_blocks_to_operate=5, + num_requests=6, + use_precompiled_swap_ops=True, + swap_op_type="jax", + ), + dict( + testcase_name="_multi_requests_multi_blocks_compile_pallas", + num_blocks_to_operate=5, + num_requests=6, + use_precompiled_swap_ops=True, + swap_op_type="pallas", + ), + ) + def test_tpu_connector_load( + self, + num_blocks_to_operate: int, + num_requests: int = 1, + use_precompiled_swap_ops: bool = False, + swap_op_type: str = "jax", + ): + """ + This test simulates a scenario where some amount of blocks get + offloaded to cpu cache, and then get loaded into tpu kv cache. + Both swap-out and swap-in are tested. + + Steps: + 1. Setup: + 2. Simulate a save operation + 3. Load the data + 4. Verification + """ + total_num_blocks_to_operate = num_blocks_to_operate * num_requests + if total_num_blocks_to_operate > self.num_blocks or total_num_blocks_to_operate > self.num_cpu_chunks: + self.skipTest( + f"num_blocks_to_save {total_num_blocks_to_operate} exceeds ModelRunner / OffloadConnectorWorker's capacity" + ) + # 1. Setup + connector = self._create_connector(swap_op_type, + use_precompiled_swap_ops) + worker = connector.connector_worker + # Ground truth cache on TPU + src_kv_cache = worker.runner.kv_caches + # Destination cache on TPU, should be modified by the load operation + dst_kv_cache = [ + jax.device_put(jnp.zeros(self.cache_shape, dtype=self.cache_dtype), + self.device_sharding) + for _ in range(self.num_layers) + ] + jax.block_until_ready(dst_kv_cache) + + # 2. 
Simulate a save operation + all_block_ids = list(range(self.num_blocks)) + all_chunk_ids = list(range(self.num_cpu_chunks)) + src_block_ids = random.sample(all_block_ids, + total_num_blocks_to_operate) + dst_chunk_ids = random.sample(all_chunk_ids, + total_num_blocks_to_operate) + + src_block_ids_split = np.array_split(src_block_ids, num_requests) + dst_chunk_ids_split = np.array_split(dst_chunk_ids, num_requests) + + save_requests_meta = [] + for i in range(num_requests): + req_id = f"save_req_{i}" + src_blocks = src_block_ids_split[i].tolist() + dst_chunks = dst_chunk_ids_split[i].tolist() + num_tokens_to_save_per_req = len(src_blocks) * self.block_size + + save_spec = SaveSpec( + num_skip_leading_tokens=0, + num_total_tokens=num_tokens_to_save_per_req, + is_final_save=False, + skip_save=False, + src_blocks=src_blocks, + dst_chunks=dst_chunks, + ) + total_token_ids = list(range(num_tokens_to_save_per_req)) + req_meta = TPUReqMeta( + req_id=req_id, + token_ids=total_token_ids, + local_block_ids=src_blocks, + save_spec=save_spec, + ) + save_requests_meta.append(req_meta) + + connector_metadata = TPUOffloadConnectorMetadata( + requests_meta=save_requests_meta) + connector.bind_connector_metadata(connector_metadata) + logger.info( + "Connector metadata bound, calling worker.wait_for_save().") + worker.wait_for_save() + logger.info("worker.wait_for_save() completed.") + + # 3. Prepare and Execute Delta Load + worker.runner.kv_caches = dst_kv_cache + + load_requests_meta = [] + for i in range(num_requests): + req_id = f"load_req_{i}" + src_blocks = src_block_ids_split[i].tolist() + dst_chunks = dst_chunk_ids_split[i].tolist() + num_tokens_to_load_per_req = len(src_blocks) * self.block_size + + load_spec = LoadSpec( + num_matched_tokens=num_tokens_to_load_per_req, + dst_blocks=src_blocks, + src_chunks=dst_chunks, + can_load=True, + num_skip_leading_tokens=0, + ) + total_token_ids = list(range(num_tokens_to_load_per_req)) + req_meta = TPUReqMeta( + req_id=req_id, + token_ids=total_token_ids, + local_block_ids=src_blocks, + load_spec=load_spec, + ) + load_requests_meta.append(req_meta) + + connector_metadata = TPUOffloadConnectorMetadata( + requests_meta=load_requests_meta) + connector.bind_connector_metadata(connector_metadata) + logger.info("Connector metadata bound, calling start_load_kv.") + worker.start_load_kv(fwd_ctx=None) + jax.block_until_ready(worker.runner.kv_caches) + logger.info("start_load_kv completed and blocked until ready.") + + # 4. 
Verification + # verify the data + dst_kv_cache = worker.runner.kv_caches + for i in range(num_requests): + src_blocks = src_block_ids_split[i].tolist() + for src_block_id in src_blocks: + for layer_idx in range(self.num_layers): + self.assertArraysEqual( + np.array(src_kv_cache[layer_idx][src_block_id]), + np.array(dst_kv_cache[layer_idx][src_block_id])) + + # verify the loaded chunks + all_load_req_ids = {f"load_req_{i}" for i in range(num_requests)} + self.assertSetEqual( + all_load_req_ids, + set(worker.offload_stats.data["finished_load_chunks"].keys())) + + for i in range(num_requests): + req_id = f"load_req_{i}" + dst_chunks = dst_chunk_ids_split[i].tolist() + self.assertListEqual( + dst_chunks, + worker.offload_stats.data["finished_load_chunks"][req_id]) diff --git a/tests/offload/tpu_offload_cpu_backend_test.py b/tests/offload/tpu_offload_cpu_backend_test.py new file mode 100644 index 000000000..69094bf64 --- /dev/null +++ b/tests/offload/tpu_offload_cpu_backend_test.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import MagicMock + +import pytest + +from tpu_inference.offload.cpu_backend import LocalCPUBackend +from tpu_inference.offload.utils import CpuChunkId + + +# Helper to create a mock jax array with a specific size in bytes +def create_mock_jax_array(size_in_bytes: int) -> MagicMock: + """Creates a mock object with an 'nbytes' attribute.""" + mock_value = MagicMock() + mock_value.nbytes = size_in_bytes + return mock_value + + +class TestLocalCPUBackend: + """Test suite for the LocalCPUBackend.""" + + def test_add_and_get(self): + """Verifies that a value can be added and then retrieved successfully.""" + backend = LocalCPUBackend(num_cpu_chunks=10) + key = CpuChunkId(0) + value = create_mock_jax_array(50) + + backend.add(key, value) + retrieved_value = backend.get(key) + + assert retrieved_value == value + assert backend.current_size_bytes == 50 + + # Test with a list of JAX arrays (mocked) + key_list = CpuChunkId(1) + value_list = [create_mock_jax_array(20), create_mock_jax_array(30)] + backend.add(key_list, value_list) + retrieved_list_value = backend.get(key_list) + + assert retrieved_list_value == value_list + assert backend.current_size_bytes == 50 + 20 + 30 + + assert backend.num_saved_cpu_chunks == 2 + + def test_add_invalid_chunk_id(self): + """Verifies that adding a value with an invalid chunk_id raises a ValueError.""" + backend = LocalCPUBackend(num_cpu_chunks=10) + value = create_mock_jax_array(50) + + with pytest.raises(ValueError): + backend.add(CpuChunkId(-1), value) + + assert backend.num_saved_cpu_chunks == 0 + + def test_reclaim_unoccupied_chunks(self): + """Tests that unoccupied chunks are reclaimed correctly.""" + backend = LocalCPUBackend(num_cpu_chunks=10) + key1 = CpuChunkId(0) + key2 = CpuChunkId(1) + key3 = CpuChunkId(2) + value = create_mock_jax_array(10) + + backend.add(key1, value) + backend.add(key2, value) + backend.add(key3, value) + + assert backend.current_size_bytes == 30 + assert len(backend.cache) == 3 + + # Reclaim one chunk + backend.reclaim_unoccupied_chunks(occupied_chunk_ids=[key1, key3]) + + assert backend.current_size_bytes == 20 + assert len(backend.cache) == 2 + assert key1 in backend.cache + assert key2 not in backend.cache + assert key3 in backend.cache + + # Reclaim all chunks + backend.reclaim_unoccupied_chunks(occupied_chunk_ids=[]) + + assert backend.current_size_bytes == 0 + assert len(backend.cache) == 0 diff --git a/tests/offload/tpu_offload_manager_test.py 
b/tests/offload/tpu_offload_manager_test.py new file mode 100644 index 000000000..5d1674eae --- /dev/null +++ b/tests/offload/tpu_offload_manager_test.py @@ -0,0 +1,343 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest + +from tpu_inference.logger import init_logger +from tpu_inference.offload.offload_manager import (CPUChunkPool, + LRUCacheManager, + StagingBufferManager) +from tpu_inference.offload.utils import ReqId + +logger = init_logger(__name__) + + +class TestStagingBufferManager: + + def test_initialization(self): + manager = StagingBufferManager(num_blocks=100) + assert manager.num_blocks == 100 + assert manager.get_num_free_staging_blocks() == 100 + assert manager.get_num_used_staging_blocks() == 0 + + def test_allocate_simple(self): + manager = StagingBufferManager(num_blocks=100) + req_id1: ReqId = "req1" + req_id2: ReqId = "req2" + + allocated1 = manager.allocate(req_id1, 10, "load") + assert allocated1 == 10 + assert manager.get_num_free_staging_blocks() == 90 + assert manager.get_num_used_staging_blocks() == 10 + assert manager._num_blocks_for_load == 10 + assert manager._num_blocks_for_save == 0 + + allocated2 = manager.allocate(req_id2, 20, "save") + assert allocated2 == 20 + assert manager.get_num_free_staging_blocks() == 70 + assert manager.get_num_used_staging_blocks() == 30 + assert manager._num_blocks_for_load == 10 + assert manager._num_blocks_for_save == 20 + + def test_allocate_insufficient_capacity(self): + manager = StagingBufferManager(num_blocks=10) + req_id: ReqId = "req1" + allocated = manager.allocate(req_id, 20, "load") + assert allocated == 0 + assert manager.get_num_free_staging_blocks() == 10 + assert manager.get_num_used_staging_blocks() == 0 + + def test_allocate_existing_load_request(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + manager.allocate(req_id, 10, "load") + with pytest.raises(ValueError): + # multiple concurrent loads from a single request is not allowed. 
+ manager.allocate(req_id, 5, "load") + + def test_allocate_existing_save_request(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + manager.allocate(req_id, 10, "save") + assert manager._blocks_for_save[req_id] == 10 + manager.allocate(req_id, 5, "save") + assert manager._blocks_for_save[req_id] == 15 + assert manager.get_num_free_staging_blocks() == 85 + assert manager.get_num_used_staging_blocks() == 15 + + def test_allocate_negative_blocks(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + allocated = manager.allocate(req_id, -5, "load") + assert allocated == -5 + assert manager.get_num_free_staging_blocks() == 100 + + def test_free_full(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + manager.allocate(req_id, 10, "load") + freed = manager.free(req_id, "load") + assert freed == 10 + assert manager.get_num_free_staging_blocks() == 100 + assert manager.get_num_used_staging_blocks() == 0 + assert req_id not in manager._blocks_for_load + + def test_free_partial(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + manager.allocate(req_id, 10, "save") + freed = manager.free(req_id, "save", num_finished_blocks=4) + assert freed == 4 + assert manager.get_num_free_staging_blocks() == 94 + assert manager.get_num_used_staging_blocks() == 6 + assert manager._blocks_for_save[req_id] == 6 + + def test_free_more_than_allocated(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + manager.allocate(req_id, 10, "load") + manager.free(req_id, "load", num_finished_blocks=15) + assert req_id not in manager._blocks_for_load + + def test_free_non_existent_request(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + freed = manager.free(req_id, "load") + assert freed == 0 + + def test_complex_scenario(self): + manager = StagingBufferManager(num_blocks=50) + req1, req2, req3 = "req1", "req2", "req3" + + # req1 loads 10, req2 saves 15 + assert manager.allocate(req1, 10, "load") == 10 + assert manager.allocate(req2, 15, "save") == 15 + assert manager.get_num_free_staging_blocks() == 25 + assert manager.get_num_used_staging_blocks() == 25 + + # req3 tries to load 30, fails + assert manager.allocate(req3, 30, "load") == 0 + assert manager.get_num_free_staging_blocks() == 25 + + # req1 finishes loading + assert manager.free(req1, "load") == 10 + assert manager.get_num_free_staging_blocks() == 35 + + # req3 can now load 20 + assert manager.allocate(req3, 20, "load") == 20 + assert manager.get_num_free_staging_blocks() == 15 + assert manager.get_num_used_staging_blocks( + ) == 35 # 15 for save (req2) + 20 for load (req3) + + # req2 saves another 5 + assert manager.allocate(req2, 5, "save") == 5 + assert manager.get_num_free_staging_blocks() == 10 + assert manager._blocks_for_save[req2] == 20 + + # req2 frees 8 blocks + assert manager.free(req2, "save", 8) == 8 + assert manager.get_num_free_staging_blocks() == 18 + assert manager._blocks_for_save[req2] == 12 + + # req2 and req3 finish + assert manager.free(req2, "save") == 12 + assert manager.free(req3, "load") == 20 + assert manager.get_num_free_staging_blocks() == 50 + assert manager.get_num_used_staging_blocks() == 0 + + +class TestCPUChunkPool: + + def test_initialization(self): + pool = CPUChunkPool(num_chunks=10) + assert pool.num_chunks == 10 + assert pool.num_free_chunks == 10 + assert pool.num_allocated_chunks == 0 + assert len(pool.free_chunk_list) == 10 + + def 
test_allocate_chunks(self): + pool = CPUChunkPool(num_chunks=10) + chunk_hashes = [101, 102, 103] + chunks = pool.allocate_chunks(chunk_hashes) + + assert len(chunks) == 3 + assert pool.num_free_chunks == 7 + assert pool.num_allocated_chunks == 3 + for i, chunk in enumerate(chunks): + assert chunk.chunk_hash == chunk_hashes[i] + assert chunk.chunk_id in pool.allocated_id_to_hash_map + + def test_allocate_chunks_insufficient_space(self): + pool = CPUChunkPool(num_chunks=2) + chunk_hashes = [101, 102, 103] + with pytest.raises(ValueError): + pool.allocate_chunks(chunk_hashes) + + def test_release_chunks(self): + pool = CPUChunkPool(num_chunks=10) + chunk_hashes = [101, 102, 103] + chunks = pool.allocate_chunks(chunk_hashes) + for chunk in chunks: + chunk.touch() + + for chunk in chunks: + pool.release_chunk(chunk) + + assert pool.num_free_chunks == 10 + assert pool.num_allocated_chunks == 0 + assert len(pool.free_chunk_list) == 10 + for chunk in chunks: + assert chunk.chunk_id not in pool.allocated_id_to_hash_map + assert chunk.chunk_hash is None + assert chunk.ref_cnt == -1 + + def test_release_chunks_in_use(self): + pool = CPUChunkPool(num_chunks=10) + chunk_hashes = [101] + chunks = pool.allocate_chunks(chunk_hashes) + chunks[0].touch() # ref_cnt = 0: saved + chunks[0].touch() # ref_cnt = 1: loading + + assert not pool.release_chunk(chunks[0]) + + +class TestLRUCacheManager: + + def test_initialization(self): + manager = LRUCacheManager(num_cpu_chunks=20) + assert manager.num_chunks == 20 + assert isinstance(manager.chunk_pool, CPUChunkPool) + assert len(manager.cpu_cache) == 0 + + def test_lookup(self): + manager = LRUCacheManager(num_cpu_chunks=20) + chunk_hashes = [101, 102, 103] + + # 1. Cache miss + assert manager.lookup(chunk_hashes) == 0 + + # 2. Cache hit + # Manually add to cache for testing + chunks = manager.chunk_pool.allocate_chunks(chunk_hashes) + for chunk, h in zip(chunks, chunk_hashes): + chunk.touch() # Make it ready to load + manager.cpu_cache[h] = chunk + + assert manager.lookup(chunk_hashes) == 3 + + # 3. 
Partial hit + assert manager.lookup([101, 102, 104]) == 2 + + def test_touch(self): + manager = LRUCacheManager(num_cpu_chunks=3) + chunk_hashes = [101, 102, 103] + chunks = manager.chunk_pool.allocate_chunks(chunk_hashes) + for chunk, h in zip(chunks, chunk_hashes): + manager.cpu_cache[h] = chunk + + manager.touch([101]) + assert list(manager.cpu_cache.keys()) == [102, 103, 101] + + manager.touch([102, 103]) + assert list(manager.cpu_cache.keys()) == [101, 103, 102] + + def test_allocate_for_save_simple(self): + manager = LRUCacheManager(num_cpu_chunks=5) + chunk_hashes = [101, 102] + + new_chunks, new_chunk_idxs = manager.allocate_for_save(chunk_hashes) + + assert len(new_chunks) == 2 + assert new_chunk_idxs == [0, 1] + assert manager.chunk_pool.num_free_chunks == 3 + assert len(manager.cpu_cache) == 2 + + def test_allocate_for_save_no_new_chunks(self): + manager = LRUCacheManager(num_cpu_chunks=5) + chunk_hashes = [101, 102] + manager.allocate_for_save(chunk_hashes) + + result = manager.allocate_for_save(chunk_hashes) + assert result is None + + def test_allocate_for_save_with_eviction(self): + manager = LRUCacheManager(num_cpu_chunks=2) + # Fill the cache + manager.allocate_for_save([101, 102]) + # Mark as evictable + manager.cpu_cache[101].touch() + manager.cpu_cache[102].touch() + + manager.touch([101, 102]) + + # This should evict 102 + new_chunks, new_chunk_idxs = manager.allocate_for_save([103]) + + assert len(new_chunks) == 1 + assert new_chunk_idxs == [0] + assert 102 not in manager.cpu_cache + assert 101 in manager.cpu_cache + assert 103 in manager.cpu_cache + assert manager.chunk_pool.num_free_chunks == 0 + + def test_allocate_for_save_cannot_evict(self): + manager = LRUCacheManager(num_cpu_chunks=2) + manager.allocate_for_save([101, 102]) + # Mark as in use, not evictable + manager.cpu_cache[101].touch() + manager.cpu_cache[101].touch() + manager.cpu_cache[102].touch() + manager.cpu_cache[102].touch() + + result = manager.allocate_for_save([103]) + assert result is None + assert len(manager.cpu_cache) == 2 + + def test_prepare_load(self): + manager = LRUCacheManager(num_cpu_chunks=2) + chunk_hashes = [101] + manager.allocate_for_save(chunk_hashes) + manager.complete_save(chunk_hashes) # ref_cnt = 0 + + chunks = manager.prepare_load(chunk_hashes) + assert len(chunks) == 1 + assert chunks[0].is_in_use # ref_cnt = 1 + + def test_complete_save(self): + manager = LRUCacheManager(num_cpu_chunks=2) + chunk_hashes = [101] + manager.allocate_for_save(chunk_hashes) + + chunk = manager.cpu_cache[101] + assert not chunk.is_ready_to_load # ref_cnt = -1 + + manager.complete_save(chunk_hashes) + assert chunk.is_ready_to_load # ref_cnt = 0 + + def test_complete_load(self): + manager = LRUCacheManager(num_cpu_chunks=2) + chunk_hashes = [101] + manager.allocate_for_save(chunk_hashes) + manager.complete_save(chunk_hashes) + chunks = manager.prepare_load(chunk_hashes) + + assert chunks[0].is_in_use # ref_cnt = 1 + manager.complete_load(chunk_hashes) + assert not chunks[0].is_in_use # ref_cnt = 0 + + def test_mark_completion(self): + manager = LRUCacheManager(num_cpu_chunks=2) + chunk_hashes = [101] + new_chunks, _ = manager.allocate_for_save(chunk_hashes) + chunk_ids = [c.chunk_id for c in new_chunks] + + manager.mark_completion(chunk_ids, 'save') + assert manager.cpu_cache[101].is_ready_to_load + + manager.prepare_load(chunk_hashes) + assert manager.cpu_cache[101].is_in_use + manager.mark_completion(chunk_ids, 'load') + assert not manager.cpu_cache[101].is_in_use + + def 
test_mark_completion_unknown_id(self): + manager = LRUCacheManager(num_cpu_chunks=2) + with pytest.raises(ValueError): + manager.mark_completion([999], 'save') diff --git a/tests/offload/tpu_offload_utils_test.py b/tests/offload/tpu_offload_utils_test.py new file mode 100644 index 000000000..e5c28af27 --- /dev/null +++ b/tests/offload/tpu_offload_utils_test.py @@ -0,0 +1,158 @@ +import functools +import itertools +import unittest + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import NamedSharding, PartitionSpec + +from tpu_inference.offload.utils import (get_kv_cache_swap_fn, + jitted_insert_kv_cache_slices) + + +class TestTPUOffloadUtilsFn(unittest.TestCase): + + def setUp(self): + """Set up common parameters for the tests.""" + self.num_layers = 2 + self.num_tokens = 256 + num_devices = len(list(jax.devices())) + self.num_kv_heads = num_devices + self.head_dim = 128 + self.block_size = 16 + self.num_blocks = self.num_tokens // self.block_size + self.cache_shape = ( + self.num_blocks, + self.block_size, + self.num_kv_heads, + 2, + self.head_dim, + ) + self.block_shape = ( + self.block_size, + self.num_kv_heads, + 2, + self.head_dim, + ) + + self.cache_dtype = jnp.bfloat16 + + self.mesh = self.create_mesh((1, num_devices), ("data", "model")) + partition_spec = PartitionSpec(None, None, "model") + self.device_sharding = NamedSharding(self.mesh, + partition_spec, + memory_kind="device") + self.host_sharding = NamedSharding(self.mesh, + partition_spec, + memory_kind="pinned_host") + flatten_partition_spec = PartitionSpec(None, "model") + self.flatten_device_sharding = NamedSharding(self.mesh, + flatten_partition_spec, + memory_kind="device") + + def create_mesh(self, axis_shapes, axis_names): + """Creates a JAX device mesh with the default device order.""" + try: + num_required_devices = np.prod(axis_shapes) + devices = np.array(jax.devices()) + if len(devices) < num_required_devices: + self.skipTest( + f"Not enough devices to create mesh of shape {axis_shapes}." + ) + device_array = devices[:num_required_devices].reshape(axis_shapes) + return jax.sharding.Mesh(device_array, axis_names) + except RuntimeError: + return None + + def test_jitted_insert_kv_cache_slices_equivalence(self): + """ + Verify inserting scattered kv slices / pages into the large kv cache. + """ + num_blocks_to_insert = 3 + dst_blocks = [3, 5, 7] + dst_blocks_array = jnp.array(dst_blocks) + + initial_kv_caches = [ + jax.device_put(jnp.zeros(self.cache_shape, dtype=self.cache_dtype), + self.device_sharding) + for _ in range(self.num_layers) + ] + + # The raw, chunked KV data (input for the new method) + # This is a list of lists: List[layer -> List[block]] + raw_chunked_kv = [] + for i in range(self.num_layers): + layer_chunks = [ + jax.device_put( + jax.random.normal(jax.random.key(i), + shape=self.block_shape, + dtype=self.cache_dtype), + self.flatten_device_sharding) + for _ in range(num_blocks_to_insert) + ] + raw_chunked_kv.append(layer_chunks) + + output = jitted_insert_kv_cache_slices(self.block_size, + initial_kv_caches, + raw_chunked_kv, + dst_blocks_array) + + # --- Verification --- + # Check that the selected pages for each layer equal to the original ones. 
+ for i in range(self.num_layers): + for j in range(num_blocks_to_insert): + block_id = dst_blocks[j] + np.testing.assert_array_equal(np.array(output[i][block_id]), + raw_chunked_kv[i][j]) + print("\nTest passed: the inserted kv equals to the original one.") + + def test_swap_fn_correctness(self): + """ + Verify that swap-out and swap-in functions work correctly for different + swap_op_types and jitted options. + """ + swap_op_types = ["jax", "pallas"] + jitted_options = [True, False] + + # NOTE(jcgu): we are using the entire kv cache [n_b, bs, nh, 2, hd], + # actually, we will operate on concatenated blocks [nt, nh, 2, hd]; + @functools.partial(jax.jit, out_shardings=self.device_sharding) + def create_on_device(key): + return jax.random.uniform(key, + shape=self.cache_shape, + dtype=self.cache_dtype) + + initial_kv_caches = [ + create_on_device(jax.random.key(i)) for i in range(self.num_layers) + ] + jax.block_until_ready(initial_kv_caches) + + for swap_op_type, jitted in itertools.product(swap_op_types, + jitted_options): + with self.subTest(swap_op_type=swap_op_type, jitted=jitted): + swap_in_fn, swap_out_fn = get_kv_cache_swap_fn( + swap_op_type, self.host_sharding, self.device_sharding, + jitted) + + # Put initial data on device + device_kv_caches = jax.device_put(initial_kv_caches, + self.device_sharding) + jax.block_until_ready(device_kv_caches) + + # Swap out to host + host_kv_caches = swap_out_fn(device_kv_caches) + + # Swap back in to device + final_device_kv_caches = swap_in_fn(host_kv_caches) + jax.block_until_ready(final_device_kv_caches) + + # Verify correctness + for i in range(self.num_layers): + np.testing.assert_array_equal( + np.array(initial_kv_caches[i]), + np.array(final_device_kv_caches[i])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tpu_inference/envs.py b/tpu_inference/envs.py index 9201e1a11..546e81875 100644 --- a/tpu_inference/envs.py +++ b/tpu_inference/envs.py @@ -24,6 +24,12 @@ NUM_SLICES: int = 1 RAY_USAGE_STATS_ENABLED: str = "0" VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm" + TPU_OFFLOAD_SKIP_JAX_PRECOMPILE: bool = False + TPU_OFFLOAD_SWAP_OP_TYPE: str = "jax" + TPU_OFFLOAD_DECODE_SAVE: bool = False + TPU_OFFLOAD_NUM_CPU_CHUNKS: int = 1024 + TPU_OFFLOAD_NUM_STAGING_BLOCKS: int = 128 + TPU_OFFLOAD_SAVE_THREADS: int = 1 def env_with_choices( @@ -122,6 +128,24 @@ def _get_validated_env() -> str | None: # Ray compiled DAG channel type for TPU "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE": env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]), + # kv offload to dram: skip pre-compiling swap-related jax functions + "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE": + lambda: bool(int(os.getenv("TPU_OFFLOAD_SKIP_JAX_PRECOMPILE", "0"))), + # kv offload to dram: swap function type: jax, or pallas + "TPU_OFFLOAD_SWAP_OP_TYPE": + lambda: os.getenv("TPU_OFFLOAD_SWAP_OP_TYPE", "jax"), + # kv offload to dram: save kv in the decode phase + "TPU_OFFLOAD_DECODE_SAVE": + lambda: bool(int(os.getenv("TPU_OFFLOAD_DECODE_SAVE", "0"))), + # kv offload to dram: dram space size in # of chunks / blocks + "TPU_OFFLOAD_NUM_CPU_CHUNKS": + lambda: int(os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS", "1024")), + # kv offload to dram: size of staging buffer (hbm) for swap + "TPU_OFFLOAD_NUM_STAGING_BLOCKS": + lambda: int(os.getenv("TPU_OFFLOAD_NUM_STAGING_BLOCKS", "128")), + # kv offload to dram: number of threads for asynchronous TPU -> CPU data transfer + "TPU_OFFLOAD_SAVE_THREADS": + lambda: int(os.getenv("TPU_OFFLOAD_SAVE_THREADS", "1")), } diff --git 
a/tpu_inference/kernels/dma/__init__.py b/tpu_inference/kernels/dma/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tpu_inference/kernels/dma/host_dma.py b/tpu_inference/kernels/dma/host_dma.py new file mode 100644 index 000000000..68a53f9d0 --- /dev/null +++ b/tpu_inference/kernels/dma/host_dma.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +""" Host <-> HBM DMA kernel""" +import jax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu + + +def host_hbm_dma(x_ref, y_ref): + """ + DMA a jax array between host and hbm + Input jax array ref: x_ref + Output jax array ref: y_ref + """ + + def body(sem): + pltpu.async_copy(x_ref, y_ref, sem).wait() + + pl.run_scoped(body, pltpu.SemaphoreType.DMA) + + +# NOTE(jcgu): input / out arrays should have the same sharding, but different memory_kind +# NOTE(jcgu): only support NamedSharding, does not support SingleDeviceSharding +def d2h_dma( + input_array: jax.Array, + input_sharding: jax.sharding.NamedSharding, + out_sharding: jax.sharding.NamedSharding, +) -> jax.Array: + """ DMA a device jax array to host memory. + Args: + input_array: input jax array on device hbm + input_sharding: input's device sharding + out_sharding: output's host sharding + Returns: + jax array on host memory with the same sharding + """ + + @jax.jit + def _d2h_dma_call(x): + return pl.pallas_call( + host_hbm_dma, + in_specs=[ + pl.BlockSpec(memory_space=pl.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pl.HOST), + out_shape=pltpu.HOST(shape=x.shape, dtype=x.dtype), + name="d2h_dma_kernel", + )(x) + + d2h_dma_kernel = jax.jit( + jax.shard_map( + _d2h_dma_call, + mesh=input_sharding.mesh, + in_specs=input_sharding.spec, + out_specs=out_sharding.spec, + check_vma=False, + ), + out_shardings=out_sharding, + ) + + return d2h_dma_kernel(input_array) + + +# NOTE(jcgu): input / out arrays should have the same sharding, but different memory_kind +# NOTE(jcgu): only support NamedSharding, does not support SingleDeviceSharding +def h2d_dma( + input_array: jax.Array, + input_sharding: jax.sharding.NamedSharding, + out_sharding: jax.sharding.NamedSharding, +) -> jax.Array: + """ DMA a host jax array to device hbm. 
+ Args: + input_array: input jax array on host memory + input_sharding: the host sharding for input + out_sharding: the device sharding for output + Returns: + jax array on device hbm with the assigned sharding + """ + + @jax.jit + def _h2d_dma_call(x): + return pl.pallas_call( + host_hbm_dma, + in_specs=[ + pl.BlockSpec(memory_space=pl.HOST), + ], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + out_shape=jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), + name="h2d_dma_kernel", + )(x) + + h2d_dma_kernel = jax.jit( + jax.shard_map( + _h2d_dma_call, + mesh=input_sharding.mesh, + in_specs=input_sharding.spec, + out_specs=out_sharding.spec, + check_vma=False, + ), + out_shardings=out_sharding, + ) + + return h2d_dma_kernel(input_array) diff --git a/tpu_inference/offload/__init__.py b/tpu_inference/offload/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tpu_inference/offload/cpu_backend.py b/tpu_inference/offload/cpu_backend.py new file mode 100644 index 000000000..e4613a4a7 --- /dev/null +++ b/tpu_inference/offload/cpu_backend.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 + +import sys +from collections import OrderedDict +from typing import Any, Optional + +from tpu_inference.logger import init_logger +from tpu_inference.offload.utils import CpuChunkId + +logger = init_logger(__name__) + + +class LocalCPUBackend: + """ + A singleton in-memory CPU backend for storing KV cache keys and values. + + This class uses the singleton pattern to ensure that the scheduler and the + worker, running in the same process, can share the same cache. + The scheduler reads from this to find cache hits, and the worker writes + to it after saving KV blocks from the TPU. + + It implements an LRU (Least Recently Used) eviction policy with a maximum + size limit and support for pinning cache entries to prevent eviction. + """ + + def __init__(self, num_cpu_chunks: int): + self.max_num_cpu_chunks = num_cpu_chunks + self.cache: OrderedDict[CpuChunkId, Any] = OrderedDict() + self.current_size_bytes = 0 + self._num_saved_cpu_chunks = 0 + logger.info( + "LocalCPUBackend initialized." + f"CPU cache capacity: {self.max_num_cpu_chunks} chunks / pages.") + + @property + def num_saved_cpu_chunks(self) -> int: + return self._num_saved_cpu_chunks + + def _get_value_size(self, value: Any) -> int: + """Calculates the size of a cache value in bytes.""" + size_in_bytes = 0 + if isinstance(value, list): + # The value is a list of JAX arrays (one per layer) + size_in_bytes = sum(v.nbytes for v in value + if hasattr(v, 'nbytes')) + elif hasattr(value, 'nbytes'): + size_in_bytes = value.nbytes + else: + size_in_bytes = sys.getsizeof(value) + return size_in_bytes + + def add(self, chunk_id: CpuChunkId, value: Any) -> bool: + """ + Adds a key-value pair to the cache. + + If the cache is full, it evicts the least recently used, unpinned + entries until there is enough space. + """ + if chunk_id < 0 or chunk_id >= self.max_num_cpu_chunks: + # TODO(jcgu): report failure when offload scheduler / worker + # can handle failed operations. + raise ValueError(f" get invalid chunk_id: {chunk_id}") + + # Add the new item. 
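+ # If this chunk_id is being overwritten, drop the stale value and subtract its size first so the byte and chunk accounting stays consistent.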
+ if chunk_id in self.cache: + old_value = self.cache.pop(chunk_id) + self.current_size_bytes -= self._get_value_size(old_value) + del old_value + self._num_saved_cpu_chunks -= 1 + + self.cache[chunk_id] = value + self._num_saved_cpu_chunks += 1 + value_size = self._get_value_size(value) + self.current_size_bytes += value_size + logger.info( + f"Added chunk_id: {chunk_id} (size:{value_size}) to CPU backend.") + logger.info( + f"Cache: {self.current_size_bytes} bytes, {self._num_saved_cpu_chunks} occupied chunks." + ) + return True + + def get(self, chunk_id: CpuChunkId) -> Optional[Any]: + """ + Gets the value for a given chunk_id and marks it as recently used. + """ + if chunk_id in self.cache: + return self.cache[chunk_id] + return None + + def reclaim_unoccupied_chunks(self, occupied_chunk_ids: list[CpuChunkId]): + chunk_ids = list(self.cache.keys()) + unoccupied_chunk_ids = [ + chunk_id for chunk_id in chunk_ids + if chunk_id not in occupied_chunk_ids + ] + reclaimed_size_bytes = 0 + for chunk_id in unoccupied_chunk_ids: + dummy_value = self.cache.pop(chunk_id) + reclaimed_size_bytes += self._get_value_size(dummy_value) + del dummy_value + self.current_size_bytes -= reclaimed_size_bytes + + logger.info( + f" Reclaimed {len(unoccupied_chunk_ids)} unoccupied chunks, " + f"with {reclaimed_size_bytes} bytes.") diff --git a/tpu_inference/offload/offload_manager.py b/tpu_inference/offload/offload_manager.py new file mode 100644 index 000000000..fc792d4d2 --- /dev/null +++ b/tpu_inference/offload/offload_manager.py @@ -0,0 +1,377 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections import OrderedDict +from dataclasses import dataclass +from typing import Literal, Optional, Tuple + +from vllm.v1.core.kv_cache_utils import BlockHash + +from tpu_inference.logger import init_logger +from tpu_inference.offload.utils import CpuChunkId, ReqId + +logger = init_logger(__name__) + +ChunkHash = BlockHash + + +@dataclass +class CPUChunk: + """ + ref_cnt: + -1: init, not saved + 0: saved, ready_to_evict, ready_to_load + >=1: loadings, ready_to_load, in_use + """ + chunk_id: CpuChunkId + ref_cnt: int = -1 + _chunk_hash: ChunkHash | None = None + + @property + def is_ready_to_load(self): + return self.ref_cnt >= 0 + + @property + def is_ready_to_evict(self): + return self.ref_cnt == 0 + + @property + def is_in_use(self): + return self.ref_cnt >= 1 + + @property + def chunk_hash(self): + return self._chunk_hash + + def touch(self): + self.ref_cnt += 1 + + def untouch(self): + self.ref_cnt -= 1 + + def reset(self): + self._chunk_hash = None + self.ref_cnt = -1 + + +class CPUChunkPool: + + def __init__(self, num_chunks: int): + self.num_chunks: int = num_chunks + self._num_allocated_chunks: int = 0 + self.free_chunk_list: list[CPUChunk] = [ + CPUChunk(idx) for idx in range(num_chunks - 1, -1, -1) + ] + # {allocated_chunk_id: chunk_hash} + self.allocated_id_to_hash_map: dict[CpuChunkId, ChunkHash] = {} + + @property + def num_free_chunks(self): + return self.num_chunks - self._num_allocated_chunks + + @property + def num_allocated_chunks(self): + return self._num_allocated_chunks + + def allocate_chunks(self, chunk_hashes: list[ChunkHash]) -> list[CPUChunk]: + num_required_chunks = len(chunk_hashes) + if num_required_chunks > self.num_free_chunks: + raise ValueError( + f"Cannot get {num_required_chunks} free chunks from the pool") + + ret: list[CPUChunk] = [ + self.free_chunk_list.pop() for _ in range(num_required_chunks) + ] + self._num_allocated_chunks += num_required_chunks + for chunk, chunk_hash 
in zip(ret, chunk_hashes): + chunk._chunk_hash = chunk_hash + assert chunk.chunk_id not in self.allocated_id_to_hash_map + self.allocated_id_to_hash_map[chunk.chunk_id] = chunk_hash + + return ret + + def release_chunk(self, chunk: CPUChunk) -> bool: + if not chunk.is_ready_to_evict: + logger.warning(f" Chunk[{chunk.chunk_id}] is still in use.") + return False + assert chunk.chunk_id in self.allocated_id_to_hash_map + self.allocated_id_to_hash_map.pop(chunk.chunk_id) + chunk.reset() + self.free_chunk_list.append(chunk) + self._num_allocated_chunks -= 1 + return True + + +class LRUCacheManager: + + def __init__(self, num_cpu_chunks: int): + self.num_chunks = num_cpu_chunks + self.chunk_pool = CPUChunkPool(self.num_chunks) + + self.cpu_cache: OrderedDict[ChunkHash, CPUChunk] = OrderedDict() + + # The cache is an OrderedDict for LRU behavior. + def lookup(self, chunk_hashes: list[ChunkHash]) -> int: + """_summary_ + return the number of cache hit starting from the first chunk + """ + hit_count = 0 + for chunk_hash in chunk_hashes: + chunk = self.cpu_cache.get(chunk_hash) + if chunk is None or not chunk.is_ready_to_load: + break + hit_count += 1 + return hit_count + + def touch(self, chunk_hashes: list[ChunkHash]) -> int: + """ access chunks for both save / load; and move them to the end.""" + for chunk_hash in reversed(chunk_hashes): + if self.cpu_cache.get(chunk_hash): + self.cpu_cache.move_to_end(chunk_hash) + + def allocate_for_save( + self, chunk_hashes: list[ChunkHash] + ) -> Tuple[list[CPUChunk], list[int]] | None: + # filter out chunks that are already stored + num_chunks = len(chunk_hashes) + new_chunk_idxs = [ + i for i in range(num_chunks) + if chunk_hashes[i] not in self.cpu_cache + ] + + num_new_chunks = len(new_chunk_idxs) + if num_new_chunks == 0: + logger.info("No new chunks to allocate") + return None + num_chunks_to_evict = max( + 0, num_new_chunks - self.chunk_pool.num_free_chunks) + + # build list of chunks to evict / reuse + to_evict = [] + if num_chunks_to_evict > 0: + for chunk_hash, chunk in self.cpu_cache.items(): + if chunk.is_ready_to_evict: + to_evict.append(chunk_hash) + num_chunks_to_evict -= 1 + if num_chunks_to_evict == 0: + break + else: + # we could not evict enough chunks + return None + + # evict chunks + for evicting_chunk_hash in to_evict: + evicting_chunk = self.cpu_cache.pop(evicting_chunk_hash) + # always true, since all evicting chunks are ready to evict + self.chunk_pool.release_chunk(evicting_chunk) + + new_chunk_hashes = [chunk_hashes[i] for i in new_chunk_idxs] + # allocate + try: + new_chunks = self.chunk_pool.allocate_chunks(new_chunk_hashes) + assert len(new_chunks) == len(new_chunk_hashes) + except Exception as e: + logger.warning(f" Failed to allocate {len(new_chunk_hashes)}: {e}") + # NOTE(jcgu): should we return None or something else? 
+ return None + for chunk_hash, chunk in zip(new_chunk_hashes, new_chunks): + self.cpu_cache[chunk_hash] = chunk + # newly-allocated chunks, chunk-idx in the given chunk_hashes list + return new_chunks, new_chunk_idxs + + def prepare_load(self, chunk_hashes: list[ChunkHash]) -> list[CPUChunk]: + chunks = [] + for chunk_hash in chunk_hashes: + chunk = self.cpu_cache[chunk_hash] + assert chunk.is_ready_to_load + chunk.touch() + chunks.append(chunk) + return chunks + + def complete_save(self, chunk_hashes: list[ChunkHash]) -> None: + """ After store completion, mark the chunk to be ready to load.""" + for chunk_hash in chunk_hashes: + chunk = self.cpu_cache[chunk_hash] + assert not chunk.is_ready_to_load + # mark ready to load + chunk.touch() + assert chunk.is_ready_to_load + + def complete_load(self, chunk_hashes: list[ChunkHash]) -> None: + for chunk_hash in chunk_hashes: + chunk = self.cpu_cache[chunk_hash] + assert chunk.is_in_use + chunk.untouch() + + def mark_completion(self, chunk_ids, operation: Literal['save', + 'load']) -> None: + try: + chunk_hashes = [ + self.chunk_pool.allocated_id_to_hash_map[chunk_id] + for chunk_id in chunk_ids + ] + except Exception as e: + raise ValueError(f' failed to retrieve chunk hashes: {e}') + + chunk_hashes = [] + unknown_chunk_ids = [] + for chunk_id in chunk_ids: + if chunk_id in self.chunk_pool.allocated_id_to_hash_map: + chunk_hashes.append( + self.chunk_pool.allocated_id_to_hash_map[chunk_id]) + else: + unknown_chunk_ids.append(chunk_id) + if unknown_chunk_ids: + logger.warning( + f" Chunks[{unknown_chunk_ids}] are not found as allocated chunks in the pool." + ) + + if operation == 'save': + self.complete_save(chunk_hashes) + elif operation == 'load': + self.complete_load(chunk_hashes) + else: + raise ValueError(f"Unknown operation: {operation}") + + +class StagingBufferManager(): + """ Bookkeeping the staging buffer inside the connector scheduler. + NOTE(jcgu): the operations (e.g., allocate, free, get) to staging buffer / blocks are NOT thread-safe. + But it's okay since there is only one connector scheduler instance. + """ + + def __init__(self, num_blocks: int): + self.num_blocks = num_blocks + # {req_id: list(num_occupied_staging_blocks)} + self._blocks_for_save: dict[ReqId, int] = {} + self._blocks_for_load: dict[ReqId, int] = {} + + self._num_free_blocks: int = self.num_blocks + # keep track of the total occupied staging blocks for save and load respectively + self._num_blocks_for_save: int = 0 + self._num_blocks_for_load: int = 0 + + def get_num_free_staging_blocks(self) -> int: + return self._num_free_blocks + + def get_num_used_staging_blocks(self) -> int: + return self._num_blocks_for_load + self._num_blocks_for_save + + def get_num_used_save_staging_blocks(self, req_id: ReqId) -> int: + return self._blocks_for_save.get(req_id, 0) + + def get_num_used_load_staging_blocks(self, req_id: ReqId) -> int: + return self._blocks_for_load.get(req_id, 0) + + def allocate(self, req_id: ReqId, num_blocks: int, + usage: Literal["load", "save"]) -> int: + if num_blocks < 0: + logger.warning( + f" get {num_blocks} staging blocks to allocate for Req:{req_id}." + ) + return num_blocks + if num_blocks > self._num_free_blocks: + # do not have enough capacity, return 0 + return 0 + + if usage == "load": + if req_id in self._blocks_for_load: + # NOTE(jcgu): before completing the previous load, new load + # should not be triggered for the same request (is this correct?) 
+ raise ValueError( + f" Req({req_id}) already has {self._blocks_for_load[req_id]}, and should not have new loads." + ) + else: + self._blocks_for_load[req_id] = num_blocks + self._num_blocks_for_load += num_blocks + elif usage == "save": + if req_id in self._blocks_for_save: + self._blocks_for_save[req_id] += num_blocks + else: + self._blocks_for_save[req_id] = num_blocks + self._num_blocks_for_save += num_blocks + else: + raise ValueError( + f" Staging buffer manager should not get usage: {usage}") + self._num_free_blocks -= num_blocks + + logger.info( + f" allocate {num_blocks} staging blocks to Req:{req_id} for {usage}." + ) + return num_blocks + + def free(self, + req_id: ReqId, + usage: Literal["load", "save"], + num_finished_blocks: Optional[int] = None) -> int: + """ + when num_finished_blocks is not given, we will assume the request is finished and should be removed. + """ + num_freed_blocks = 0 + # NOTE(jcgu): assuming FIFO execution order for a single request's save and + # load operations respectively + if usage == "load": + if req_id not in self._blocks_for_load: + logger.warning( + f" there is no record of staging buffer (usage: {usage}) for Req:{req_id}" + ) + return 0 + if num_finished_blocks is None: + num_freed_blocks = self._blocks_for_load[req_id] + else: + num_freed_blocks = num_finished_blocks + if self._blocks_for_load[req_id] < num_freed_blocks: + logger.warning( + f" Req({req_id}) has {num_finished_blocks} load staging buffer to free, but only has {self._blocks_for_load[req_id]} on record." + ) + + self._blocks_for_load[req_id] -= num_freed_blocks + if self._blocks_for_load[req_id] <= 0: + del self._blocks_for_load[req_id] + self._num_blocks_for_load -= num_freed_blocks + elif usage == "save": + if req_id not in self._blocks_for_save: + logger.warning( + f" there is no record of staging buffer (usage: {usage}) for Req:{req_id}" + ) + return 0 + if num_finished_blocks is None: + num_freed_blocks = self._blocks_for_save[req_id] + else: + num_freed_blocks = num_finished_blocks + if self._blocks_for_save[req_id] < num_freed_blocks: + logger.warning( + f" Req({req_id}) has {num_finished_blocks} save staging buffer to free, but only has {self._blocks_for_save[req_id]} on record." + ) + + self._blocks_for_save[req_id] -= num_freed_blocks + if self._blocks_for_save[req_id] <= 0: + del self._blocks_for_save[req_id] + self._num_blocks_for_save -= num_freed_blocks + else: + raise ValueError( + f" Staging buffer manager should not get usage: {usage}") + self._num_free_blocks += num_freed_blocks + + logger.info( + f" free {num_freed_blocks} staging blocks (usage: {usage}) from Req:{req_id}" + ) + return num_freed_blocks + + def get_usage(self, with_details: bool = False): + usage_str = (f"Staging Buffer: total={self.num_blocks}, " + f"free={self._num_free_blocks}, " + f"used_for_load={self._num_blocks_for_load}, " + f"used_for_save={self._num_blocks_for_save};") + if with_details: + blocks_for_save_str = " save_details:{" + for req, bn in self._blocks_for_save.items(): + blocks_for_save_str += f"{req}:{bn}," + blocks_for_save_str += "} " + + blocks_for_load_str = " load_details:{" + for req, bn in self._blocks_for_load.items(): + blocks_for_load_str += f"{req}:{bn}," + blocks_for_load_str += "}." 
+ usage_str += blocks_for_save_str + blocks_for_load_str + + return usage_str diff --git a/tpu_inference/offload/tpu_offload_connector.py b/tpu_inference/offload/tpu_offload_connector.py new file mode 100644 index 000000000..407d9eab6 --- /dev/null +++ b/tpu_inference/offload/tpu_offload_connector.py @@ -0,0 +1,1906 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Scheduler side execution: +TPUOffloadConnectorScheduler manages the state of KV cache loading and saving for +each request. It acts as a state machine, tracking the progress of requests +across multiple scheduling steps and generating work orders (TPUReqMeta) for +the TPUOffloadConnectorWorker. + +Core Components: +- RequestTracker: The primary state object for a request. It tracks the + cumulative tokens and blocks processed, and how many of those tokens have + been saved to the CPU cache. A tracker is created when a request is first + scheduled and lives until the request is finished. + +- LoadSpec: A temporary state object created when a new request has a prefix + that matches data in the CPU cache (`get_num_new_matched_tokens`). It + holds the number of matched tokens and a `can_load` flag, which is set + to True only after the vLLM scheduler allocates the necessary blocks for + the load (`update_state_after_alloc`). + +- SaveSpec: A part of the work order sent to the worker. It instructs the + worker to save a specific slice of the KV cache from TPU to CPU. It + contains `num_skip_leading_tokens` to indicate which part of the request's + KV cache is new and needs saving, and an `is_final_save` flag to signal + the last save operation for a request. + +- TPUReqMeta: The unified work order for a single request in a single step, + sent from the scheduler to the worker. It can contain a `load_spec` (to + load from CPU to TPU), a `save_spec` (to save from TPU to CPU), or both. + +State Machine Flow (from the perspective of a request): + +1. RECEIVED -> AWAITING_ALLOCATION + - A new request arrives. + - `get_num_new_matched_tokens` checks the CPU backend for a matching + token prefix. + - If a match is found (N > 0 tokens), a `LoadSpec(num_matched_tokens=N, can_load=False)` + is created. The request now waits for the vLLM scheduler to allocate + physical blocks for these N tokens. + +2. AWAITING_ALLOCATION -> SCHEDULED + - The vLLM scheduler allocates blocks for the request. + - `update_state_after_alloc` is called. If a `LoadSpec` exists, its + `can_load` flag is set to True, greenlighting the load operation. + The request is now considered scheduled for processing in this step. + +3. SCHEDULED -> IN_FLIGHT or COMPLETED + - This transition is handled by `build_connector_meta` which calls the + central decision-making function, `_prepare_req_meta`. + - LoadSpec Preparation: The `LoadSpec` (if it exists and `can_load` + is True) is passed directly into the `TPUReqMeta`. The worker will + use `num_matched_tokens` to slice the correct prefix from the request's + `token_ids` and fetch the corresponding data from the CPU cache. + - SaveSpec Preparation: `_prepare_req_meta` determines if a save is + needed by comparing the total tokens processed so far + (`len(tracker.token_ids)`) with the number of tokens already saved + (`tracker.num_saved_tokens`). + - If `len(token_ids) > num_saved_tokens`, a `SaveSpec` is created. + - `num_skip_leading_tokens` is set to `tracker.num_saved_tokens`. This + tells the worker to ignore the prefix that's already in the CPU + cache and only save the new data. 
+ - The scheduler then *transactionally* updates `tracker.num_saved_tokens` + to the new total length, ensuring this slice of data is not saved + again. + - If the scheduler has not finished the request, it transitions to + IN_FLIGHT. Its tracker is updated for the next scheduling step. + - If the scheduler has finished the request, it transitions to + COMPLETED. The tracker is removed, and a final `SaveSpec` is + generated. + - is_final_save: This flag is set to `True` only when the + scheduler marks a request as finished. It is a signal + for the worker, indicating that after this save is complete, the + request's lifecycle is over and its resources + can be safely freed. + +Worker Side Execution: +- The TPUOffloadConnectorWorker receives the `TPUOffloadConnectorMetadata` containing the list of + `TPUReqMeta` objects. +- `start_load_kv`: Iterates through the metadata. If a `meta.load_spec` + exists, it reads the corresponding data from the CPU backend and copies it + into the allocated blocks on the TPU. This is a blocking operation. +- `wait_for_save`: Iterates through the metadata. If a `meta.save_spec` + exists, it submits an asynchronous task to copy the specified slice of + KV data from TPU to CPU and update the CPU backend. It then waits for all + submitted save tasks for the current step to complete. +""" +import copy +import os +import time +from collections import defaultdict +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal, Optional, get_args + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import \ + KVConnectorStats +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.outputs import KVConnectorOutput + +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + from vllm.forward_context import ForwardContext + +from tpu_inference import envs +from tpu_inference.logger import init_logger +from tpu_inference.offload.cpu_backend import LocalCPUBackend +from tpu_inference.offload.offload_manager import (LRUCacheManager, + StagingBufferManager) +from tpu_inference.offload.utils import (CPU_OFFLOADING_SWAP_OP_TYPE, + CpuChunkId, KVCacheSwapFn, ReqId, + get_kv_cache_swap_fn, + jitted_insert_kv_cache_slices) +from tpu_inference.runner.kv_cache_manager import KVCacheManager +from tpu_inference.runner.tpu_runner import TPUModelRunner + +logger = init_logger(__name__) + +# kv cache layout needed by cpu offloading mechanism +REQUIRED_KV_CACHE_LAYOUT = "NHD" + +BLOCK_SIZE_BUCKETS = [1, 2, 4, 8, 16] + +# we keep our operations at vllm's block granularity, +# and want to provide the following three preferences when handling +# the last partial block during save: +# 1. [supported] drop: drop the entire partial block +# 2. pad: pad to a full block +# 3. dynamic: keep the partial block as is. 
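+#
+# Illustrative example (hypothetical numbers, assuming block_size=16): a
+# request with 40 computed tokens covers 2 full blocks (32 tokens) plus an
+# 8-token partial block; under "drop", only the 32 tokens in full blocks are
+# offloaded and the partial block stays on the TPU until it fills up.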
+PARTIAL_BLOCK_SAVE_BEHAVIOR = Literal["drop"] + + +@dataclass +class SaveSpec: + """A confirmed work order for the worker to save KV data.""" + num_skip_leading_tokens: int + # total processed tokens for matching / saving + num_total_tokens: int + src_blocks: list[int] + dst_chunks: list[int] + # final save for the (newly) finished request + is_final_save: bool = False + # A direct signal to the worker to skip the data transfer but still + # process the completion signal if is_final_save is True. + skip_save: bool = False + + +@dataclass +class LoadSpec: + """Internal scheduler state for a potential load operation.""" + num_matched_tokens: int + src_chunks: list[int] + dst_blocks: list[int] + can_load: bool = False + num_skip_leading_tokens: int = 0 + + +@dataclass +class TPUReqMeta: + """A unified work order for a single request in a single step.""" + # The unique identifier for the request. + req_id: str + # For a load operation, this contains the prefix of tokens to be loaded + # from the cache. For a save operation, this contains the new tokens + # that have just been computed. + token_ids: list[int] + # The full list of physical blocks corresponding to the `token_ids`. + local_block_ids: list[int] + # An optional `SaveSpec` object. If present, it instructs the worker to + # perform a save operation. + save_spec: Optional[SaveSpec] = None + # An optional `LoadSpec` object. If present, it instructs the worker to + # perform a load operation. + load_spec: Optional[LoadSpec] = None + + def __repr__(self) -> str: + load_info = f"load_spec_exists={self.load_spec is not None}" + if self.load_spec: + load_info += ( + f", num_matched_tokens={self.load_spec.num_matched_tokens}, " + f"can_load={self.load_spec.can_load}, " + f"num_skip_leading_tokens={self.load_spec.num_skip_leading_tokens}, " + f"src_chunks={self.load_spec.src_chunks}, " + f"dst_blocks={self.load_spec.dst_blocks}") + save_info = f"save_spec_exists={self.save_spec is not None}" + if self.save_spec: + save_info += ( + f", num_skip_leading_tokens={self.save_spec.num_skip_leading_tokens}, " + f"num_total_tokens={self.save_spec.num_total_tokens}, " + f"is_final_save={self.save_spec.is_final_save}, " + f"skip_save={self.save_spec.skip_save}, " + f"dst_chunks={self.save_spec.dst_chunks}, " + f"src_blocks={self.save_spec.src_blocks}") + + return (f"TPUReqMeta(req_id={self.req_id}, " + f"num_token_ids={len(self.token_ids)}, " + f"num_local_block_ids={len(self.local_block_ids)}, " + f"{load_info}, {save_info})") + + +@dataclass +class RequestTracker: + """Tracks the evolving state of a single request across multiple scheduling steps.""" + # The unique identifier for the request. + req_id: str + # The total number of tokens in the original prompt. + prompt_len: int + # The full, cumulative list of physical block numbers allocated to this + # request so far. + block_ids: list[int] + # The full, cumulative list of token IDs that have been processed for this + # request so far. This list only contains the + # tokens to be computed, not the prefix loaded from cache. + token_ids: list[int] + # The number of tokens that were a hit in the CPU cache at the beginning + # of the request. This is constant for the lifetime of the request. + num_external_hits: int = 0 + # A high-water mark indicating how many tokens from the start of the + # computed tokens (`token_ids`) have already been saved to the CPU cache. + save_watermark: int = 0 + # Whether the request is in the decoding phase (generating one token at a time). 
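+    # (Currently inferred in `update()`: a step that adds exactly one new
+    # token is treated as a decode step.)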
+ is_decode_phase: bool = False + + def update(self, new_block_ids: list[int], new_token_ids: list[int]): + """Appends new block IDs and token IDs to the tracker.""" + if new_block_ids is None: + new_block_ids = [] + elif len(new_block_ids) == 0: + new_block_ids = [] + elif isinstance(new_block_ids, tuple): + new_block_ids = new_block_ids[0] + elif isinstance(new_block_ids, list): + pass + else: + raise ValueError( + f"Unsupported new_block_ids type {type(new_block_ids)}") + self.block_ids.extend(new_block_ids) + self.token_ids.extend(new_token_ids) + + # NOTE(jcgu): is it always true? will MTP affect this judgement? + # When a request is scheduled again, and the number of new tokens + # is 1 (excluding chunked prefill), the request is in decode phase. + if len(new_token_ids) == 1: + self.is_decode_phase = True + + def __repr__(self) -> str: + output_str = " - RequestTracker: " + \ + f"req_id={self.req_id}, " + \ + f"prompt_len={self.prompt_len}, " + \ + f"num_tokens={len(self.token_ids)}, " + \ + f"num_blocks={len(self.block_ids)}, " + \ + f"save_watermark={self.save_watermark}" + return output_str + + +@dataclass +class KVOffloadConnectorStats(KVConnectorStats): + """Container for transfer performance metrics""" + + def __post_init__(self): + if not self.data: + # Empty container init, no data is passed in. + self.reset() + + def reset(self): + # Must be serializable + self.data: dict[str, dict[str, list[int]]] = { + "finished_save_chunks": dict(), + "finished_load_chunks": dict(), + } + + def record_save(self, req: ReqId, saved_chunk_ids: list[int]): + if req not in self.data["finished_save_chunks"]: + self.data["finished_save_chunks"][req] = [] + self.data["finished_save_chunks"][req].extend( + copy.deepcopy(saved_chunk_ids)) + + def record_load(self, req: ReqId, loaded_chunk_ids: list[int]): + if req not in self.data["finished_load_chunks"]: + self.data["finished_load_chunks"][req] = [] + self.data["finished_load_chunks"][req].extend( + copy.deepcopy(loaded_chunk_ids)) + + def clone_and_reset(self) -> "KVOffloadConnectorStats": + old = copy.copy(self) + self.reset() + return old + + def is_empty(self) -> bool: + return self.num_finished_blocks == 0 + + def aggregate(self, other: KVConnectorStats) -> KVConnectorStats: + return self + + def reduce(self) -> dict[str, int | float]: + # Compute compact representative stats suitable for CLI logging + if self.is_empty(): + return { + "Num finished save blocks ": 0, + "Num finished load blocks ": 0, + } + + finished_save_chunks = sum( + len(chunk_list) + for chunk_list in self.data["finished_save_chunks"].values()) + finished_load_chunks = sum( + len(chunk_list) + for chunk_list in self.data["finished_load_chunks"].values()) + + return { + "Num finished save chunks ": finished_save_chunks, + "Num finished load chunks": finished_load_chunks, + } + + @property + def num_finished_blocks(self) -> int: + return len(self.data["finished_save_chunks"]) + len( + self.data["finished_load_chunks"]) + + +# The metadata used for communicating between scheduler and worker connectors. 
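+# Illustrative example (hypothetical ids, assuming block_size=16): a single
+# step's metadata may carry one TPUReqMeta that both loads a cached prefix and
+# saves a newly computed suffix, e.g.
+#   TPUReqMeta(req_id="req-0", token_ids=[...32 ids...], local_block_ids=[3, 7],
+#              load_spec=LoadSpec(num_matched_tokens=16, src_chunks=[5],
+#                                 dst_blocks=[3], can_load=True),
+#              save_spec=SaveSpec(num_skip_leading_tokens=16, num_total_tokens=32,
+#                                 src_blocks=[7], dst_chunks=[9]))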
+@dataclass +class TPUOffloadConnectorMetadata(KVConnectorMetadata): + requests_meta: list[TPUReqMeta] = field(default_factory=list) + + +class TPUOffloadConnector(KVConnectorBase_V1): + + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: KVCacheConfig | None = None, + ): + super().__init__(vllm_config, role, kv_cache_config) + logger.info("TPUOffloadConnector: Entering __init__") + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = \ + TPUOffloadConnectorScheduler(vllm_config) + self.connector_worker = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + # The worker needs a reference to the base connector to access + # the metadata object set by the engine. + self.connector_worker = TPUOffloadConnectorWorker( + vllm_config, self) + + ############################################################ + # Class Methods + ############################################################ + @classmethod + def get_required_kvcache_layout(cls, vllm_config: VllmConfig): + if vllm_config.model_config is None: + logger.warning_once("Unable to detect current VLLM config. " + "Fallback to default kv cache layout.") + return None + + # TODO(jcgu): test mla + use_mla = vllm_config.model_config.use_mla + if use_mla: + # which fallback to the default behavior. + return None + + logger.info_once( + "TPUOffloadConnector currently only supports %s KV cache layout.", + REQUIRED_KV_CACHE_LAYOUT) + return REQUIRED_KV_CACHE_LAYOUT + + @classmethod + def build_kv_connector_stats( + cls, + data: dict[str, dict[str, int]] | None = None + ) -> KVConnectorStats | None: + return (KVOffloadConnectorStats( + data=data) if data is not None else KVOffloadConnectorStats()) + + ############################################################ + # Scheduler Side Methods + ############################################################ + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> TPUOffloadConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: list[jax.Array]): + logger.info("TPUOffloadConnector: Entering register_kv_caches") + """ + We don't register kv_caches in connector, we call `register_runner` and + use runner.kv_caches directly instead because the ref of runner.kv_caches + would be reassigned during model forward. 
+ """ + pass + + def register_runner(self, runner: TPUModelRunner) -> None: + logger.info("TPUOffloadConnector: Entering register_runner") + assert self.connector_worker is not None + self.connector_worker.register_runner(runner) + + def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: + """Starts loading the KV cache for the given requests.""" + assert self.connector_worker is not None + self.connector_worker.start_load_kv(fwd_ctx) + + def wait_for_layer_load(self, layer_name: str) -> None: + logger.info("TPUOffloadConnector: Entering wait_for_layer_load") + """TPU connector doesn't support layer wise load.""" + pass + + def save_kv_layer(self, **kwargs) -> None: + logger.info("TPUOffloadConnector: Entering save_kv_layer") + """TPU connector doesn't support layer wise save.""" + pass + + def wait_for_save(self): + assert isinstance(self._connector_metadata, + TPUOffloadConnectorMetadata) + self.connector_worker.wait_for_save() + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + def update_connector_output(self, connector_output: KVConnectorOutput): + assert self.connector_scheduler is not None + self.connector_scheduler.update_connector_output(connector_output) + + def get_kv_connector_stats(self) -> KVConnectorStats | None: + if self.connector_worker is None: + return None + return self.connector_worker.get_kv_connector_stats() + + +class TPUOffloadConnectorScheduler(): + + def __init__(self, vllm_config: "VllmConfig"): + logger.info("TPUOffloadConnectorScheduler: Entering __init__") + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + + # offloading manager + self.num_cpu_chunks = envs.TPU_OFFLOAD_NUM_CPU_CHUNKS + self.offload_manager = LRUCacheManager( + num_cpu_chunks=self.num_cpu_chunks) + + self._request_trackers: dict[ReqId, RequestTracker] = {} + # This dictionary holds the full vLLM Request object for all requests + # that are currently in a running state (i.e., have been scheduled but + # are not yet finished). It's used to access the complete prompt token + # list when processing incremental updates for cached/running requests, + # as the scheduler output for these requests is minimal. + self._unfinished_requests: dict[ReqId, "Request"] = {} + self.load_specs: dict[ReqId, LoadSpec] = {} + # requests with load ops that have been considered by vllm scheduler, + # not all of them will be scheduled, the scheduled ones will be + # moved to load_specs. + # it should be cleaned after ConnectorMetadata's creation + self._pre_load_specs: dict[ReqId, LoadSpec] = {} + + # {reqid: total_num_matched_tokens_in_cpu_backend} + self._external_cache_hits: dict[ReqId, int] = {} + + # request ID -> set(block hashes being saved/loaded) + self._reqs_being_saved = defaultdict[ReqId, set[CpuChunkId]](set) + self._reqs_being_loaded = defaultdict[ReqId, set[CpuChunkId]](set) + + model_name = self.vllm_config.model_config.model + + self.decode_save = envs.TPU_OFFLOAD_DECODE_SAVE + # NOTE(jcgu): currently, let's make chunk_size == block_size + # chunk_size == n * block_size lead to + # 1. multi-size chunks + # 2. 
complicated resize (split, concatenate) operations due to + # real-chunk-size in save and load + self.cpu_chunk_size = self.block_size + + self.partial_block_save_behavior: PARTIAL_BLOCK_SAVE_BEHAVIOR = "drop" + + # config staging buffer + # NOTE(jcgu): Need to find a way to grab page_size_bytes in scheduler + # otherwise, we can only use # of blocks as input, instead of buffer size in GB + self.num_staging_blocks = envs.TPU_OFFLOAD_NUM_STAGING_BLOCKS + self.staging_buffer_manager = StagingBufferManager( + num_blocks=self.num_staging_blocks) + + logger.info( + f"TPUOffloadConnectorScheduler initialized with: " + f"block_size={self.block_size}, " + f"cpu_chunk_size={self.cpu_chunk_size}, " + f"num_cpu_chunks={self.num_cpu_chunks}, " + f"model_name={model_name}, " + f"decode_save={self.decode_save}, " + f"partial_block_save_behavior={self.partial_block_save_behavior}, " + f"num_staging_blocks={self.num_staging_blocks}.") + + def _get_request_block_hashes(self, req: "Request") -> list[BlockHash]: + # request's original block_hashes do not include the last partial block + # TODO(jcgu): add an option to use local token_processor + return req.block_hashes + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Checks for external KV cache hit against the local CPU backend. + """ + assert num_computed_tokens % self.block_size == 0, f"{num_computed_tokens} % {self.block_size} != 0" + # get block_hash + block_hashes = self._get_request_block_hashes(request) + num_total_blocks = len(block_hashes) + prompt_token_ids = request.prompt_token_ids + logger.info(f"Request {request.request_id}: Checking for cache hit. " + f"Prompt length: {len(prompt_token_ids)}, " + f"Block_hashes ({num_total_blocks})," + f"Already computed tokens: {num_computed_tokens}. ") + + # look for blocks in the cache + num_hits = self.offload_manager.lookup(block_hashes) + matched_block_hashes = block_hashes[:num_hits] + self.offload_manager.touch(block_hashes) + num_matched_blocks = len(matched_block_hashes) + # num_matched_tokens = min(num_matched_blocks * self.block_size, + # len(prompt_token_ids)) + num_matched_tokens = num_matched_blocks * self.block_size + assert num_matched_tokens <= len(prompt_token_ids) + num_computed_blocks = num_computed_tokens // self.block_size + num_blocks_to_load = max(num_matched_blocks - num_computed_blocks, 0) + logger.info( + f"Request {request.request_id}: Found {num_matched_tokens} (out of {len(prompt_token_ids)} prompt tokens) matched tokens ({num_matched_blocks} blocks) in CPU backend (computed_blocks: {num_computed_blocks}, blocks_to_load: {num_blocks_to_load})." + ) + + if num_blocks_to_load > 0: + # planning staging blocks for load + # NOTE(jcgu): do not worry about the inconsistency of the staging buffer status; + # there is only one connector scheduler who is operating on it. + num_avail_staging_blocks = self.staging_buffer_manager.get_num_free_staging_blocks( + ) + if num_blocks_to_load > num_avail_staging_blocks: + # reduce blocks_to_load (and matched tokens) when there are insufficient staging blocks. + logger.info( + f" Req({request.request_id}) found {num_matched_blocks} blocks ({num_matched_tokens} tokens), but only {num_avail_staging_blocks} staging blocks available." 
+                )
+                num_blocks_to_load = num_avail_staging_blocks
+                num_matched_blocks = num_blocks_to_load + num_computed_blocks
+                num_matched_tokens = num_matched_blocks * self.block_size
+
+            # still have something to load
+            if num_blocks_to_load > 0:
+                # NOTE(jcgu): put dummy chunk / block ids;
+                # fill real ids later when the request gets scheduled
+                src_chunk_ids = [-1] * num_blocks_to_load
+                dummy_dst_blocks = [-1] * num_blocks_to_load
+                self._pre_load_specs[request.request_id] = LoadSpec(
+                    num_matched_tokens=num_matched_tokens,
+                    src_chunks=src_chunk_ids,
+                    dst_blocks=dummy_dst_blocks,
+                    num_skip_leading_tokens=num_computed_tokens,
+                )
+                num_allocated_staging_blocks = self.staging_buffer_manager.allocate(
+                    request.request_id,
+                    num_blocks=num_blocks_to_load,
+                    usage="load")
+                assert num_allocated_staging_blocks == num_blocks_to_load >= 0, f" failed to allocate {num_allocated_staging_blocks} (load) staging blocks for request {request.request_id}, expected {num_blocks_to_load}."
+
+        # record the matched tokens in the cache; it is needed to
+        # init the save_spec
+        self._external_cache_hits[request.request_id] = num_matched_tokens
+
+        is_full_prefix_hit = (num_matched_tokens > 0
+                              and num_matched_tokens == len(prompt_token_ids))
+        num_matched_for_scheduler = num_matched_tokens
+        if is_full_prefix_hit:
+            # When the entire prompt is found in the CPU cache (a "full hit"),
+            # report N-1 matched tokens to the vLLM scheduler instead of the
+            # true N. If we report a 100% match (N matched tokens for a prompt
+            # of length N), the scheduler sees zero new tokens, may not
+            # schedule the request for a prefill step at all, and hits the
+            # https://github.com/vllm-project/vllm/blob/b8b302cde434df8c9289a2b465406b47ebab1c2d/vllm/v1/core/sched/scheduler.py#L438 assertion.
+            # By reporting N-1, we ensure the scheduler allocates resources
+            # for and schedules the computation of the "last" token of the
+            # prompt. The worker (`start_load_kv`) still loads the KV of all N
+            # matched tokens, but the final token's KV is not used; it is
+            # "re-computed" in the following forward pass (the loaded data in
+            # that slot gets overwritten). From there, the request can
+            # seamlessly transition to the decoding phase.
+            num_matched_for_scheduler = num_matched_tokens - 1
+            logger.info(
+                f"Request {request.request_id}: Full prompt hit. Reporting {num_matched_for_scheduler} matched tokens. Actual hit from backend is {num_matched_tokens} tokens."
+            )
+
+            # Note on unpinning for the full prefix hit case: Although we report N-1 tokens
+            # to the scheduler, the RequestTracker (created later in
+            # `build_connector_meta`) stores the true, full N prompt tokens.
+            # The `get_finished` method on the worker side uses this complete
+            # token list to regenerate the keys, ensuring that all N keys
+            # originally pinned during this lookup are gracefully unpinned upon
+            # request completion.
+        # We don't need to load tokens that are already computed locally in vLLM.
+        num_to_load = max(0, num_matched_for_scheduler - num_computed_tokens)
+        logger.info(
+            f"Request {request.request_id}: After accounting for {num_computed_tokens} computed tokens, reporting {num_to_load} tokens to load."
+        )
+
+        # external_computed_tokens, load_kv_async
+        return num_to_load, False
+
+    def update_state_after_alloc(self, request: "Request",
+                                 blocks: "KVCacheBlocks",
+                                 num_external_tokens: int):
+        """
+        This hook is not used for the save logic.
+ Update the dst_blocks in the load_spec + """ + logger.info( + f"TPUOffloadConnectorScheduler: Entering update_state_after_alloc Request {request.request_id}: Scheduler allocated " + f"{num_external_tokens} external tokens.") + self._unfinished_requests[request.request_id] = request + if num_external_tokens == 0: + return + + # retrieve the load_spec + load_spec = self._pre_load_specs.pop(request.request_id, None) + if load_spec: + assert load_spec.num_skip_leading_tokens % self.block_size == 0 + assert len(load_spec.src_chunks) == len(load_spec.dst_blocks) + skip_leading_blocks = load_spec.num_skip_leading_tokens // self.block_size + num_blocks_to_load = len(load_spec.src_chunks) + num_matched_blocks = num_blocks_to_load + skip_leading_blocks + assert num_matched_blocks == load_spec.num_matched_tokens // self.block_size, f"{num_matched_blocks} != {load_spec.num_matched_tokens} // {self.block_size}" + + block_hashes = self._get_request_block_hashes(request) + all_blocks = blocks.get_block_ids()[0] + logger.info( + f" Request: {request.request_id} has {len(all_blocks)} blocks / {len(block_hashes)} block hashes." + ) + + # get the src chunk ids to load + block_hashes_to_load = block_hashes[ + skip_leading_blocks:num_matched_blocks] + chunks_to_load = self.offload_manager.prepare_load( + block_hashes_to_load) + src_chunk_ids = [chunk.chunk_id for chunk in chunks_to_load] + + # get dst block ids + dst_blocks = all_blocks[skip_leading_blocks:num_matched_blocks] + + # update load spec + load_spec.src_chunks = src_chunk_ids + load_spec.dst_blocks = dst_blocks + load_spec.can_load = True + self.load_specs[request.request_id] = load_spec + self._reqs_being_loaded[request.request_id] |= set( + load_spec.src_chunks) + logger.info( + f"Request {request.request_id} ({len(dst_blocks)} dst_blocks) is ready to load." + ) + + def _prepare_req_meta( + self, + tracker: RequestTracker, + load_spec: Optional[LoadSpec], + is_finished: bool, + ) -> Optional[TPUReqMeta]: + """ + Central decision-making function. Determines if a save or load is + needed and prepares the metadata. Also performs the transactional + update of the tracker's save state. + """ + req_id = tracker.req_id + _request = self._unfinished_requests[req_id] + block_hashes = self._get_request_block_hashes(_request) + self.offload_manager.touch(block_hashes) + + # only consider the tokens covered by block_hashes; + # currently full blocks only + num_total_blocks = len(block_hashes) + num_total_tokens = min(num_total_blocks * self.block_size, + len(tracker.token_ids)) + num_full_blocks = num_total_tokens // self.block_size + num_full_block_tokens = num_full_blocks * self.block_size + adjusted_num_total_tokens = num_full_block_tokens + adjusted_num_total_blocks = num_full_blocks + assert adjusted_num_total_blocks <= len(tracker.block_ids) + + has_new_tokens = adjusted_num_total_tokens > tracker.save_watermark + should_save = False + # Determine if a save is needed for this step + # when there are new token KVs: + # 1. Prefill: always save + # 2. 
Decode (with save_decode=True) + # 2.1 regular decode (not finished): accumulate until getting a full block + # 2.2 request finished: save + if has_new_tokens: + if not tracker.is_decode_phase: + # Prefill: always save the new-computed blocks + should_save = True + elif self.decode_save: + if is_finished: + # After decode, if there are new final new tokens to save + should_save = True + else: + # During decode, we do not drop or pad, just accumulate tokens until the next block boundary + next_block_boundary = ( + tracker.save_watermark // self.block_size + + 1) * self.block_size + logger.info( + f"in decode phase, next_block_boundary: {next_block_boundary}, " + ) + if num_total_tokens == next_block_boundary: + # always save the full block for decode (not affected by saving_behavior) + assert num_total_tokens == adjusted_num_total_tokens, f" decode_save: {num_total_tokens} != (adjusted) {adjusted_num_total_tokens}" + should_save = True + + logger.info(f" - Preparing meta for req (save): {tracker.req_id}, " + f"is_finished={is_finished}, " + f"total_tokens={num_total_tokens}, " + f"adjusted_num_total_tokens={adjusted_num_total_tokens}, " + f"adjusted_num_total_blocks={adjusted_num_total_blocks}, " + f"saved_tokens={tracker.save_watermark}, " + f"has_new={has_new_tokens}, " + f"is_decode={tracker.is_decode_phase}, " + f"should_save={should_save}") + + # A SaveSpec is always prepared for a finished request to signal completion, + # even if we don't save the underlying KV data. This is to ensure the TPUOffloadConnectorWorker + # can correctly report finished request. + save_spec = None + if should_save: + # get src block_ids for save + # NOTE(jcgu): recompute skip_leading_blocks + # if tracker.save_watermark has partial tokens in the last block + # and we saved (i.e., pad) the entire block to cpu_backend, now we + # want to save the kv of the new tokens in that block; because of + # the new tokens in that block's token sequence, the block will + # have a new key (hash value) in cpu_backend, so we should treat + # the block as a new cache and save the entire block. + # Example: + # we have saved: + # blocks: [------b0------] [------b1------] + # tokens: [t0, t1, t2, t3] [t4, t5,] + # cpu-backend:{key0: b0, key1:b1(2 tokens, padded)} + # + # Now, we have 2 new tokens in the sequence + # blocks: [------b0------] [------b1------] + # tokens: [t0, t1, t2, t3] [t4, t5, t6, t7] + # cpu-backend:{key0: b0, key1:b1(2 tokens, padded), + # key1_2: b1_2(4 tokens)} + # In cpu-backend, since b1's token-sequence has been changed, it + # will have a new key. + # + # if we always drop the partial-filled block when saving, then there + # will no such an issue. + num_skip_leading_blocks = tracker.save_watermark // self.block_size + num_skip_leading_tokens = num_skip_leading_blocks * self.block_size + num_blocks_to_save = adjusted_num_total_blocks - num_skip_leading_blocks + + # planning staging blocks for save + num_avail_staging_blocks = self.staging_buffer_manager.get_num_free_staging_blocks( + ) + if num_blocks_to_save > num_avail_staging_blocks: + # reduce blocks_to_save due to limited free staging blocks + logger.info( + f" Req({tracker.req_id}) have {num_blocks_to_save} ({adjusted_num_total_blocks} - {num_skip_leading_blocks}) blocks to save, but only {num_avail_staging_blocks} staging blocks available." 
+                )
+                num_blocks_to_save = num_avail_staging_blocks
+                adjusted_num_total_blocks = num_skip_leading_blocks + num_blocks_to_save
+                adjusted_num_total_tokens = adjusted_num_total_blocks * self.block_size
+
+            if num_blocks_to_save > 0:
+                block_hashes_to_save = block_hashes[
+                    num_skip_leading_blocks:adjusted_num_total_blocks]
+                allocate_output = self.offload_manager.allocate_for_save(
+                    block_hashes_to_save)
+                if allocate_output is not None:
+                    # there are enough chunks to save
+                    chunks_for_save, chunk_idxs = allocate_output
+                    assert num_blocks_to_save == len(chunks_for_save)
+                    src_block_ids = tracker.block_ids[
+                        num_skip_leading_blocks:adjusted_num_total_blocks]
+
+                    dst_chunks = [chunk.chunk_id for chunk in chunks_for_save]
+                    src_blocks = [src_block_ids[idx] for idx in chunk_idxs]
+
+                    # This is a real save operation.
+                    save_spec = SaveSpec(
+                        num_skip_leading_tokens=num_skip_leading_tokens,
+                        num_total_tokens=adjusted_num_total_tokens,
+                        is_final_save=is_finished,
+                        skip_save=False,
+                        src_blocks=src_blocks,
+                        dst_chunks=dst_chunks,
+                    )
+                    self._reqs_being_saved[req_id] |= set(dst_chunks)
+                    num_allocated_blocks = self.staging_buffer_manager.allocate(
+                        tracker.req_id,
+                        num_blocks=num_blocks_to_save,
+                        usage="save")
+                    assert num_allocated_blocks == num_blocks_to_save >= 0, f" failed to allocate {num_allocated_blocks} (save) staging blocks for request {tracker.req_id}, expected {num_blocks_to_save}."
+
+            if adjusted_num_total_tokens > tracker.save_watermark:
+                logger.info(
+                    f" -> Old watermark {tracker.save_watermark}, new save_watermark count: {adjusted_num_total_tokens}"
+                )
+                tracker.save_watermark = adjusted_num_total_tokens
+
+        if is_finished and save_spec is None:
+            # For finished requests, there must be a no-op save to update the state on the worker side.
+            # This is a "completion-only" signal because should_save is False.
+            # NOTE(jcgu): num_total_tokens will be used to unpin tokens;
+            # apply the number of saved tokens;
+            # TODO(jcgu): rm the no-op save, since save status has been updated
+            # through kv_connector_output.kv_connector_stats
+            save_spec = SaveSpec(
+                num_skip_leading_tokens=tracker.save_watermark,
+                num_total_tokens=tracker.save_watermark,
+                src_blocks=[],
+                dst_chunks=[],
+                is_final_save=True,
+                skip_save=True,
+            )
+
+        # 2. Determine if a work order is needed.
+        if not save_spec and not (load_spec and load_spec.can_load):
+            return None
+
+        # 3. Construct and return the final work order.
+        return TPUReqMeta(
+            req_id=tracker.req_id,
+            token_ids=tracker.token_ids,
+            local_block_ids=tracker.block_ids,
+            save_spec=save_spec,
+            load_spec=load_spec,
+        )
+
+    def build_connector_meta(
+            self,
+            scheduler_output: SchedulerOutput) -> TPUOffloadConnectorMetadata:
+        metadata = TPUOffloadConnectorMetadata()
+
+        # Phase 1: Handle and clean up finished requests
+        # This block handles requests that have completed their generation.
+        # We pop their state from our tracking dictionaries and call _prepare_req_meta
+        # one last time. This ensures any final, unsaved tokens are captured and
+        # signals to the worker that this is the final save for the request.
+        logger.info(
+            f"Phase 1: Processing {len(scheduler_output.finished_req_ids)} finished requests."
+        )
+        for finished_req_id in scheduler_output.finished_req_ids:
+            logger.info(f" - Processing finished req: {finished_req_id}")
+            tracker = self._request_trackers.get(finished_req_id)
+
+            if not tracker:
+                logger.warning(
+                    f" - No tracker found for finished req: {finished_req_id}. Skipping."
+ ) + continue + + # Prepare one final metadata object if there's a final save needed. + # `is_finished` is set to True to flag this as the last save operation. + req_meta = self._prepare_req_meta(tracker, + load_spec=None, + is_finished=True) + if req_meta: + logger.info( + f" - Creating final save metadata for req: {finished_req_id}" + ) + metadata.requests_meta.append(req_meta) + + # Pop tracker and other state first. + self._request_trackers.pop(finished_req_id, None) + self._unfinished_requests.pop(finished_req_id, None) + self.load_specs.pop(finished_req_id, None) + + # Phase 2: Process newly scheduled requests + # This block handles requests being scheduled for the very first time. + # It creates the initial RequestTracker and prepares the first work order. + logger.info( + f"Phase 2: Processing {len(scheduler_output.scheduled_new_reqs)} new requests." + ) + for request in scheduler_output.scheduled_new_reqs: + req_id = request.req_id + + _request = self._unfinished_requests[req_id] + logger.info( + f" - Processing new req: {req_id}, {len(_request.block_hashes)} block_hashes." + ) + num_new_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + + # Get the external cache hit count from our new, reliable source. + num_external_hits = self._external_cache_hits.pop(req_id, 0) + + # Determine the total length of tokens the tracker should hold. + # This is vLLM's already computed tokens + newly scheduled tokens. + num_total_tokens_for_tracker = request.num_computed_tokens + num_new_scheduled_tokens + tokens_for_tracker = request.prompt_token_ids[: + num_total_tokens_for_tracker] + logger.info( + f" - num_new_scheduled_tokens: {num_new_scheduled_tokens}, num_vllm_computed: {request.num_computed_tokens}, num_external_hits: {num_external_hits}" + ) + logger.info( + f" - Slicing prompt[:{num_total_tokens_for_tracker}] -> len(tokens_for_tracker): {len(tokens_for_tracker)}" + ) + + # Set the initial high-water mark for `save_watermark`. + # This is the maximum of what vLLM has computed and what's in our external cache. + initial_save_watermark = max(request.num_computed_tokens, + num_external_hits) + + # Create and store the tracker, which will maintain the request's + # state for its entire lifetime. + assert req_id not in self._request_trackers, f"Request {req_id} already has a tracker." + # TODO(jcgu): reduce duplicated info in request tracker + tracker = RequestTracker( + req_id=req_id, + prompt_len=len(request.prompt_token_ids), + block_ids=copy.deepcopy(request.block_ids[0]), + token_ids=tokens_for_tracker, + num_external_hits=num_external_hits, + # The high-water mark for saved tokens starts after the cached prefix. + save_watermark=initial_save_watermark, + ) + self._request_trackers[req_id] = tracker + logger.info( + f" - Created tracker for {req_id} with initial state: {tracker}" + ) + + # Immediately prepare metadata for this new request. + # This could include both a load operation (for the cached part) + # and a save operation (for the newly computed part). 
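+            # Illustrative example (hypothetical numbers, assuming
+            # block_size=16): a new request whose first 32 prompt tokens hit
+            # the CPU cache and whose next 48 tokens are scheduled for prefill
+            # gets a LoadSpec covering its first 2 blocks and a SaveSpec with
+            # num_skip_leading_tokens=32 covering the next 3 blocks, both in
+            # the same TPUReqMeta.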
+ load_spec = self.load_specs.pop(req_id, None) + req_meta = self._prepare_req_meta(tracker, + load_spec, + is_finished=False) + if req_meta: + logger.info(f" - Creating metadata for new req: {req_id} " + f"(has_load={req_meta.load_spec is not None}, " + f"has_save={req_meta.save_spec is not None})") + metadata.requests_meta.append(req_meta) + + # Phase 3: Process cached (running) requests + # This block handles requests that have already been pre-filled at least + # once and are now being processed again (e.g., for chunked prefill). + cached_reqs = scheduler_output.scheduled_cached_reqs + logger.info( + f"Phase 3: Processing {len(cached_reqs.req_ids)} cached requests.") + for i, req_id in enumerate(cached_reqs.req_ids): + tracker = self._request_trackers[req_id] + full_request = self._unfinished_requests.get(req_id) + _block_hashes = full_request.block_hashes + logger.info( + f" - Processing cached req: {req_id}, {len(_block_hashes)} block_hashes." + ) + + if full_request is None: + logger.warning( + f" - No full request found for cached req: {req_id}. Skipping." + ) + continue + + # num_new_tokens: The number of *additional* tokens the scheduler is + # processing in this step for this ongoing request. + num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] + + # current_token_count: This is the crucial calculation to find our + # place in the full prompt. It's the length of the token prefix + # already processed in previous steps. + current_token_count = len(tracker.token_ids) + + logger.info( + f" - len(full_request.all_token_ids): {len(full_request.all_token_ids)}" + ) + # new_token_ids: The slice of the full token sequence corresponding to the + # new work being done in this step. + new_token_ids = full_request.all_token_ids[ + current_token_count:current_token_count + num_new_tokens] + + # new_blocks: The new physical blocks allocated for the new_token_ids. + new_blocks = cached_reqs.new_block_ids[i] + if new_blocks is None: + new_blocks = [] + + logger.info( + f" - num_new_tokens: {num_new_tokens}, current_token_count: {current_token_count}" + ) + logger.info( + f" - Slicing prompt -> len(new_token_ids): {len(new_token_ids)}" + ) + logger.info(f" - New blocks allocated: {len(new_blocks)}") + + # Update the tracker with the incremental data. + tracker.update(new_blocks, new_token_ids) + logger.info(f" - Updated tracker for {req_id}: " + f"total_tokens={len(tracker.token_ids)}, " + f"total_blocks={len(tracker.block_ids)}") + + # for cached requests, whose kv pages get evicted, there will be + # load operations. 
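+            # (A chunked prefill that computes another block-aligned slice of
+            # the prompt in this step will typically produce only a SaveSpec
+            # here, with num_skip_leading_tokens equal to the previous
+            # save_watermark.)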
+ load_spec = self.load_specs.pop(req_id, None) + req_meta = self._prepare_req_meta(tracker, + load_spec=load_spec, + is_finished=False) + if req_meta: + logger.info( + f" - Creating metadata for cached req: {req_id} " + f"(has_save={req_meta.save_spec is not None})") + metadata.requests_meta.append(req_meta) + + if metadata.requests_meta: + logger.info( + f"Prepared {len(metadata.requests_meta)} requests for worker.") + + # after building connector_metadata, all load_specs should be consumed + assert len( + self.load_specs + ) == 0, f" load_specs still has {list(self.load_specs.keys())}" + + # clean up the temporary states of requests that are not scheduled + for req_id, _load_spec in self._pre_load_specs.items(): + logger.info(f"non-scheduled-reuqest:{req_id}") + _freed_num_staging_blocks = self.staging_buffer_manager.free( + req_id, "load") + assert _freed_num_staging_blocks == len( + _load_spec.src_chunks + ), f"{_freed_num_staging_blocks} != {len(_load_spec.src_chunks)}" + self._pre_load_specs.clear() + self._external_cache_hits.clear() + + return metadata + + def update_connector_output(self, connector_output: KVConnectorOutput): + """ + Update KVConnector state from worker-side connectors output. + + Args: + connector_output (KVConnectorOutput): the worker-side + connectors output. + """ + logger.info( + f"TPUOffloadConnectorScheduler: getting workers' output: finished_sending: {connector_output.finished_sending}, finished_recving: {connector_output.finished_recving}" + ) + + # per iteration, update the finished staging blocks + if connector_output.kv_connector_stats and connector_output.kv_connector_stats.data is not None: + assert isinstance(connector_output.kv_connector_stats, + KVOffloadConnectorStats) + assert "finished_save_chunks" in connector_output.kv_connector_stats.data + assert "finished_load_chunks" in connector_output.kv_connector_stats.data + for req_id, saved_chunk_ids in connector_output.kv_connector_stats.data[ + "finished_save_chunks"].items(): + num_saved_chunks = len(saved_chunk_ids) + logger.info( + f" finished_save_chunks for {req_id}: {saved_chunk_ids}") + # free staging blocks + self.staging_buffer_manager.free( + req_id, usage="save", num_finished_blocks=num_saved_chunks) + # update in-flight save + for saved_chunk_id in saved_chunk_ids: + assert saved_chunk_id in self._reqs_being_saved[req_id] + self._reqs_being_saved[req_id].remove(saved_chunk_id) + if len(self._reqs_being_saved[req_id]) == 0: + self._reqs_being_saved.pop(req_id, None) + else: + logger.info( + f" remaining_saving_blocks:{req_id}, { self._reqs_being_saved[req_id]}." 
+ ) + + # update the status of occupied cpu chunks + self.offload_manager.mark_completion(saved_chunk_ids, "save") + + for req_id, loaded_chunk_ids in connector_output.kv_connector_stats.data[ + "finished_load_chunks"].items(): + num_loaded_chunks = len(loaded_chunk_ids) + logger.info( + f" finished_load_chunks for {req_id}: {num_loaded_chunks}" + ) + self.staging_buffer_manager.free( + req_id, + usage="load", + num_finished_blocks=num_loaded_chunks) + # update in-flight save + for loaded_chunk_id in loaded_chunk_ids: + assert loaded_chunk_id in self._reqs_being_loaded[req_id] + self._reqs_being_loaded[req_id].remove(loaded_chunk_id) + if len(self._reqs_being_loaded[req_id]) == 0: + self._reqs_being_loaded.pop(req_id, None) + # update the status of occupied cpu chunks + self.offload_manager.mark_completion(loaded_chunk_ids, "load") + + # clean up the status of the finished requests + # save + for req_id in connector_output.finished_sending or []: + if req_id in self._reqs_being_saved: + assert len(self._reqs_being_saved[req_id]) == 0 + self._reqs_being_saved.pop(req_id) + num_freed_blocks = self.staging_buffer_manager.free(req_id, + usage="save") + logger.info( + f" freed {num_freed_blocks} staging blocks (save) from {req_id}" + ) + + # load + for req_id in connector_output.finished_recving or []: + if req_id in self._reqs_being_loaded: + assert len(self._reqs_being_loaded[req_id]) == 0 + self._reqs_being_loaded.pop(req_id) + num_freed_blocks = self.staging_buffer_manager.free(req_id, + usage="load") + logger.info( + f" freed {num_freed_blocks} staging blocks (load) from {req_id}" + ) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. + + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + return: + delay_free_blocks, kv_xfer_params + """ + logger.info(" Entering request_finished") + # Return True to indicate the request is being saved asynchronously + # and its blocks should not be freed yet. 
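+        # (vLLM keeps the request's blocks alive until this req_id is reported
+        # back through the worker's get_finished(); returning False lets the
+        # blocks be freed immediately.)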
+ + req_id = request.request_id + if req_id in self._reqs_being_saved and len( + self._reqs_being_saved[req_id]) > 0: + logger.info( + f"not_free_with_save:{req_id}, {self._reqs_being_saved[req_id]}" + ) + return True, None + if req_id in self._reqs_being_loaded and len( + self._reqs_being_loaded[req_id]) > 0: + logger.info( + f"not_free_with_load:{req_id}, {self._reqs_being_loaded[req_id]}" + ) + return True, None + + logger.info(f" finished request: {req_id}") + self._reqs_being_saved.pop(req_id, None) + self._reqs_being_loaded.pop(req_id, None) + + return False, None + + +class TPUOffloadConnectorWorker: + + def __init__(self, vllm_config: VllmConfig, + connector: "TPUOffloadConnector"): + logger.info("TPUOffloadConnectorWorker: Entering __init__") + self.vllm_config = vllm_config + self.connector = connector + self.block_size = vllm_config.cache_config.block_size + + self.runner: Optional[TPUModelRunner] = None + self.mesh: Optional[Mesh] = None + self.swap_in_fn: KVCacheSwapFn = None + self.swap_out_fn: KVCacheSwapFn = None + self.swap_op_type = envs.TPU_OFFLOAD_SWAP_OP_TYPE + # TODO(jcgu): check libtpu compatibility for pallas dma kernel + assert self.swap_op_type in get_args(CPU_OFFLOADING_SWAP_OP_TYPE) + self.use_bucketed_swap_ops = not envs.TPU_OFFLOAD_SKIP_JAX_PRECOMPILE + logger.info(f" swap operation type is {self.swap_op_type}, " + f"use_bucketed_swap_ops={self.use_bucketed_swap_ops}.") + + # cpu cache + self.num_cpu_chunks = envs.TPU_OFFLOAD_NUM_CPU_CHUNKS + self.cpu_backend = LocalCPUBackend(num_cpu_chunks=self.num_cpu_chunks) + # The worker needs its own token processor to generate keys. + model_name = self.vllm_config.model_config.model + logger.info( + f"Model name is {model_name}, KV block_size={self.block_size}") + + self.cpu_chunk_size = self.block_size + # Thread pool for asynchronous TPU->CPU copies + self.num_save_threads = envs.TPU_OFFLOAD_SAVE_THREADS + self.save_executor = ThreadPoolExecutor( + max_workers=self.num_save_threads, thread_name_prefix="tpu_save_handler") + self.finished_save_reqs: set[ReqId] = set() + self.finished_load_reqs: set[ReqId] = set() + # Tracks if wait_for_save has been called for the current step's metadata. 
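+        # (Used to make wait_for_save() idempotent within a single step.)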
+ self._processed_save_for_step = False + + # record finished save / load blocks (with req_ids) for each iteration + self.offload_stats = KVOffloadConnectorStats() + + def __del__(self): + logger.info("TPUOffloadConnectorWorker: Entering __del__") + self.save_executor.shutdown(wait=True) + + def register_runner(self, runner: TPUModelRunner): + logger.info("TPUOffloadConnectorWorker: Entering register_runner") + self.runner = runner + self.devices = runner.devices + self.mesh = runner.mesh + # Get the spec of the kv_caches + kv_caches = runner.kv_caches + if kv_caches: + self.kv_cache_layout = runner.get_kv_cache_layout() + kv_layer = kv_caches[0] + self.num_layers = len(kv_caches) + self.shape = list(kv_layer.shape) + self.dtype = kv_layer.dtype + self.device_sharding = kv_layer.sharding + + # NOTE(jcgu): needed when sliced-kv is [num_tokens, num_head, head_dim] + self.flatten_device_sharding = jax.sharding.NamedSharding( + mesh=self.device_sharding.mesh, + spec=jax.sharding.PartitionSpec(None, "model"), + memory_kind="device") + + self.flatten_host_sharding = jax.sharding.NamedSharding( + mesh=self.device_sharding.mesh, + spec=jax.sharding.PartitionSpec(None, "model"), + memory_kind="pinned_host") + + self.swap_in_fn, self.swap_out_fn = get_kv_cache_swap_fn( + self.swap_op_type, + host_sharding=self.flatten_host_sharding, + device_sharding=self.flatten_device_sharding) + + logger.info( + "KV Cache details registered in TPUOffloadConnectorWorker:") + logger.info(f" - Num layers: {self.num_layers}") + logger.info(f" - Shape per layer: {self.shape}") + logger.info(f" - DType: {self.dtype}") + logger.info(f" - Device sharding: {self.device_sharding}") + logger.info( + f" - Flatten Device sharding: {self.flatten_device_sharding}") + logger.info(f" - Layout: {self.kv_cache_layout}") + else: + raise ValueError( + "TPUOffloadConnectorWorker registered with no KV caches.") + + # Pre-compile the JIT functions for KV cache swapping. + if self.use_bucketed_swap_ops: + self._precompile_kv_swap_operations() + + def _decompose_into_buckets(self, num_blocks: int) -> list[int]: + """ + Decomposes a number into a sum of numbers from the BLOCK_SIZE_BUCKETS + list using a greedy approach. + """ + sorted_buckets = sorted(BLOCK_SIZE_BUCKETS, reverse=True) + chunks = [] + remaining = num_blocks + while remaining > 0: + for bucket_size in sorted_buckets: + if remaining >= bucket_size: + chunks.append(bucket_size) + remaining -= bucket_size + break + else: + # This should not happen if 1 is in the buckets + raise ValueError( + "Could not decompose number with the given buckets.") + return chunks + + def _precompile_kv_swap_operations(self): + """ + Pre-compiles the JIT-compiled functions used for KV cache swapping + with a variety of common block sizes to avoid runtime recompilation. + """ + if os.getenv("TPU_OFFLOAD_SKIP_JAX_PRECOMPILE", "0") == "1": + logger.info( + "Skipping KV swap pre-compilation due to environment variable." + ) + return + + logger.info("Starting pre-compilation of KV cache swap operations") + start_time = time.time() + paged_kv_for_compilation = self.runner.kv_caches + for num_blocks in BLOCK_SIZE_BUCKETS: + try: + logger.info(f" - Compiling for {num_blocks} blocks...") + dummy_block_ids = jnp.arange(num_blocks) + + # 1. Pre-compile gather (used in save) + flat_dummy_kv_caches_tpu = KVCacheManager._jitted_gather_kv_cache( + paged_kv_for_compilation, dummy_block_ids) + jax.block_until_ready(flat_dummy_kv_caches_tpu) + + # 2. 
Pre-compile TPU -> CPU transfer (used in save) + dummy_kv_cpu = self.swap_out_fn(flat_dummy_kv_caches_tpu) + jax.block_until_ready(dummy_kv_cpu) + + # 3. Pre-compile CPU -> TPU transfer (used in load) + split_size_list = [self.block_size] * num_blocks + chunked_dummy_kv_cpu = jax.tree.map( + lambda flat_layer_cache: jax.lax.split( + flat_layer_cache, split_size_list, axis=0), + dummy_kv_cpu) + + chunked_dummy_kv_tpu = self.swap_in_fn(chunked_dummy_kv_cpu) + jax.block_until_ready(chunked_dummy_kv_tpu) + + # 4. Pre-compile insert (used in load). + # The result is passed to the next iteration's gather to avoid + # using a "deleted" array. + logger.info( + f" - Calling jitted_insert_kv_cache_slices with paged_kv_for_compilation len: {len(paged_kv_for_compilation)}, first_element_shape: {paged_kv_for_compilation[0].shape}, " + f"chunked_dummy_kv_tpu len: {len(chunked_dummy_kv_tpu)}") + paged_kv_for_compilation = jitted_insert_kv_cache_slices( + self.block_size, paged_kv_for_compilation, + chunked_dummy_kv_tpu, dummy_block_ids) + jax.block_until_ready(paged_kv_for_compilation) + except Exception as e: + logger.warning( + f" - Failed to pre-compile for {num_blocks} blocks: {e}", + exc_info=True) + + self.runner.kv_caches = paged_kv_for_compilation + duration = time.time() - start_time + logger.info("KV cache swap pre-compilation finished in %.2f [secs].", + duration) + + def _bucketed_gather_kv_cache( + self, + kv_caches: list[jax.Array], + block_ids: jax.Array, + ) -> list[jax.Array]: + """ + Gathers KV cache data for the given block_ids by breaking the operation + into bucket-aligned chunks to leverage JIT compilation cache. + """ + num_blocks = len(block_ids) + if num_blocks == 0: + return [] + if num_blocks in BLOCK_SIZE_BUCKETS: + return KVCacheManager._jitted_gather_kv_cache(kv_caches, block_ids) + + decomposed_block_sizes = self._decompose_into_buckets(num_blocks) + logger.info( + f"Decomposing gather for {num_blocks} blocks into bucket sizes {decomposed_block_sizes}" + ) + gathered_chunks = [] + block_offset = 0 + for decomposed_block_size in decomposed_block_sizes: + block_slice = jax.lax.dynamic_slice_in_dim(block_ids, + block_offset, + decomposed_block_size, + axis=0) + gathered_chunk = KVCacheManager._jitted_gather_kv_cache( + kv_caches, block_slice) + gathered_chunks.append(gathered_chunk) + block_offset += decomposed_block_size + + # Reassemble the results from all chunks + return jax.tree.map(lambda *x: jnp.concatenate(x, axis=0), + *gathered_chunks) + + def _bucketed_swap_out_fn( + self, + flat_kv_caches_tpu: list[jax.Array]) -> list[list[jax.Array]]: + """ + Swaps out KV cache data from TPU to CPU in bucket-aligned chunks, + returning a list of block-sized chunks per layer. 
+ """ + num_tokens = flat_kv_caches_tpu[0].shape[0] + num_blocks = num_tokens // self.block_size + if num_blocks == 0: + return [[] for _ in range(self.num_layers)] + + # Fast path: handle bucket-sized transfers + if num_blocks in BLOCK_SIZE_BUCKETS: + split_size_list = [self.block_size] * num_blocks + flat_kv_caches_cpu = self.swap_out_fn(flat_kv_caches_tpu) + jax.block_until_ready(flat_kv_caches_cpu) + return jax.tree.map( + lambda flat_layer_cache: jax.lax.split( + flat_layer_cache, split_size_list, axis=0), + flat_kv_caches_cpu) + + # Bucket decomposition path + decomposed_block_sizes = self._decompose_into_buckets(num_blocks) + logger.info( + f"Decomposing swap-out for {num_blocks} blocks into bucket sizes {decomposed_block_sizes}" + ) + # This will be a list of lists, where each inner list holds the chunks + # for a layer. + final_chunks_per_layer = [[] for _ in range(self.num_layers)] + token_offset = 0 + for decomposed_block_size in decomposed_block_sizes: + chunk_size_in_tokens = decomposed_block_size * self.block_size + + # Slice the TPU tensor for the current bucket + tpu_chunk = [ + jax.lax.dynamic_slice_in_dim(layer_cache, + token_offset, + chunk_size_in_tokens, + axis=0) + for layer_cache in flat_kv_caches_tpu + ] + + # Swap the bucket to CPU, result is a flat tensor for this bucket. We are doing the chunking inside this function to avoid returning any jnp.concatenate + # of kv cache for the the bucketed blocks + cpu_chunk_flat_per_layer = self.swap_out_fn(tpu_chunk) + jax.block_until_ready(cpu_chunk_flat_per_layer) + # Split the flat bucket tensor into block-sized chunks and append + split_size_list = [self.block_size] * decomposed_block_size + for i, layer_cache in enumerate(cpu_chunk_flat_per_layer): + chunks = jax.lax.split(layer_cache, split_size_list, axis=0) + final_chunks_per_layer[i].extend(chunks) + + token_offset += chunk_size_in_tokens + + return final_chunks_per_layer + + def _bucketed_swap_in_fn( + self, + assembled_kv_on_cpu: list[list[jax.Array]], + ) -> list[list[jax.Array]]: + """ + Swaps in KV cache data from CPU to TPU in bucket-aligned chunks, + assembling a complete staging buffer on the TPU. + """ + num_blocks = len(assembled_kv_on_cpu[0]) + if num_blocks == 0: + return [[] for _ in range(self.num_layers)] + if num_blocks in BLOCK_SIZE_BUCKETS: + return self.swap_in_fn(assembled_kv_on_cpu) + + decomposed_block_sizes = self._decompose_into_buckets(num_blocks) + logger.info( + f"Decomposing swap-in for {num_blocks} blocks into bucket sizes {decomposed_block_sizes}" + ) + + tpu_chunks_per_layer = [[] for _ in range(self.num_layers)] + block_offset = 0 + for decomposed_block_size in decomposed_block_sizes: + cpu_chunks_for_bucket = [ + layer_chunks[block_offset:block_offset + decomposed_block_size] + for layer_chunks in assembled_kv_on_cpu + ] + tpu_chunks_for_bucket = self.swap_in_fn(cpu_chunks_for_bucket) + for i in range(self.num_layers): + tpu_chunks_per_layer[i].extend(tpu_chunks_for_bucket[i]) + block_offset += decomposed_block_size + + return tpu_chunks_per_layer + + def _bucketed_jitted_insert_kv_cache_slices( + self, + kv_caches: list[jax.Array], + kv_cache_slices: list[list[jax.Array]], + dst_blocks: jax.Array, + ) -> list[jax.Array]: + """ + Inserts KV cache slices into the main cache in bucket-aligned chunks. 
+ """ + num_blocks = len(dst_blocks) + if num_blocks == 0: + return kv_caches + if num_blocks in BLOCK_SIZE_BUCKETS: + return jitted_insert_kv_cache_slices(self.block_size, kv_caches, + kv_cache_slices, dst_blocks) + + decomposed_block_sizes = self._decompose_into_buckets(num_blocks) + logger.info( + f"Decomposing insert for {num_blocks} blocks into bucket sizes {decomposed_block_sizes}" + ) + + updated_kv_caches = kv_caches + block_offset = 0 + for decomposed_block_size in decomposed_block_sizes: + slices_for_bucket = [ + layer_slices[block_offset:block_offset + decomposed_block_size] + for layer_slices in kv_cache_slices + ] + dst_blocks_for_bucket = jax.lax.dynamic_slice_in_dim( + dst_blocks, block_offset, decomposed_block_size, axis=0) + + updated_kv_caches = jitted_insert_kv_cache_slices( + self.block_size, updated_kv_caches, slices_for_bucket, + dst_blocks_for_bucket) + + block_offset += decomposed_block_size + + return updated_kv_caches + + def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], + full_token_ids: list[int], + save_spec: SaveSpec) -> ReqId: + """ + Extracts KV cache blocks from TPU, copies them to CPU, and updates the + CPU backend with the new cache keys and their corresponding token data. + """ + if not self.runner or not self.runner.kv_caches: + logger.error(f"Cannot save blocks for request {req_id}: runner or " + "KV caches not registered.") + return req_id + + blocks_to_save = save_spec.src_blocks + dst_chunks = save_spec.dst_chunks + + num_total_tokens = save_spec.num_total_tokens + num_skip_leading_tokens = save_spec.num_skip_leading_tokens + num_blocks_to_save = len(blocks_to_save) + + assert num_total_tokens <= len( + full_token_ids), f"{num_total_tokens} > {len(full_token_ids)}" + + num_tokens_to_save = num_total_tokens - num_skip_leading_tokens + if num_tokens_to_save <= 0 and not save_spec.is_final_save: + logger.info(f"Request {req_id}: No new tokens to save.") + return req_id + + process_token_ids = full_token_ids[:num_total_tokens] + tokens_to_save = process_token_ids[num_skip_leading_tokens:] + + logger.info( + f"Request {req_id} save details: " + f"full_block_ids len={len(full_block_ids)}, " + f"num_skip_leading_tokens={num_skip_leading_tokens}, " + f"num_total_tokens={num_total_tokens}, " + f"num_tokens_to_save={num_tokens_to_save}, " + f"blocks_to_save({len(blocks_to_save)}: {blocks_to_save}), " + f"dst_chunks({len(dst_chunks)}: {dst_chunks}) ") + + if not blocks_to_save and tokens_to_save: + logger.warning( + f"Request {req_id}: Tokens to save but no corresponding blocks found." + ) + return req_id + + if not tokens_to_save: + logger.info( + f"Request {req_id}: No new tokens to save, but processing as final save." 
+ ) + return req_id + + # Verify if blocks_to_save is a contiguous subarray of full_block_ids + first_src_block = blocks_to_save[0] + last_src_block = blocks_to_save[-1] + try: + first_block_idx_in_full = full_block_ids.index(first_src_block) + last_block_idx_in_full = full_block_ids.index(last_src_block) + if not (last_block_idx_in_full - first_block_idx_in_full + 1 + == len(blocks_to_save)): + raise ValueError( + f"Request({req_id}): blocks_to_save {blocks_to_save} does not exist in full_block_ids {full_block_ids}" + ) + except ValueError: + raise ValueError( + f"Request({req_id}): blocks_to_save {blocks_to_save} contains blocks not present in local_block_ids {full_block_ids}" + ) + + try: + start_time = time.time() + blocks_to_save = jnp.array(blocks_to_save) + if self.use_bucketed_swap_ops: + flat_kv_caches_tpu = self._bucketed_gather_kv_cache( + self.runner.kv_caches, blocks_to_save) + else: + flat_kv_caches_tpu = KVCacheManager._jitted_gather_kv_cache( + self.runner.kv_caches, blocks_to_save) + + jax.block_until_ready(flat_kv_caches_tpu) + logger.info( + f"extracted_blocks_tpu: {flat_kv_caches_tpu[0].shape}, {flat_kv_caches_tpu[0].sharding}" + ) + + chunks_on_cpu = None + if self.use_bucketed_swap_ops: + chunks_on_cpu = self._bucketed_swap_out_fn(flat_kv_caches_tpu) + else: + flat_kv_caches_cpu = self.swap_out_fn(flat_kv_caches_tpu) + if flat_kv_caches_cpu: + jax.block_until_ready(flat_kv_caches_cpu) + # NOTE(jcgu): we keep cpu_chunk_size == block_size + split_size_list = [self.cpu_chunk_size + ] * num_blocks_to_save + chunks_on_cpu = jax.tree.map( + lambda flat_layer_cache: jax.lax.split( + flat_layer_cache, split_size_list, axis=0), + flat_kv_caches_cpu) + + if chunks_on_cpu and chunks_on_cpu[0]: + jax.block_until_ready(chunks_on_cpu) + + duration = time.time() - start_time + logger.info(f"Successfully saved {len(blocks_to_save)} blocks for " + f"request {req_id} to CPU in {duration:.4f} seconds.") + + total_size_bytes = sum( + sum(chunk.nbytes for chunk in layer_chunks) + for layer_chunks in chunks_on_cpu) + logger.info( + f"Total size of chunks_on_cpu: {total_size_bytes / 1024**2:.2f} MB" + ) + + post_transfer_start_time = time.time() + + for i in range(num_blocks_to_save): + chunk_id = dst_chunks[i] + cur_chunk_cross_layers = [ + chunks_on_cpu[j][i] for j in range(self.num_layers) + ] + self.cpu_backend.add(chunk_id, cur_chunk_cross_layers) + logger.info(f"Request {req_id}: Saving to CPU chunk: " + f"chunk_id={chunk_id}, " + f" local_chunk_idx={i}") + + logger.info( + f"Request {req_id}: Added {num_blocks_to_save} chunks to CPU backend." + ) + + post_transfer_duration = time.time() - post_transfer_start_time + logger.info( + f"Request {req_id}: e2e host processing of {num_blocks_to_save} chunks took {post_transfer_duration:.4f} seconds." + ) + except Exception as e: + logger.error(f"Error saving blocks for request {req_id}: {e}", + exc_info=True) + + return req_id + + def wait_for_save(self): + """ + Initiates and waits for all pending asynchronous save operations for the + current step to complete. + """ + # This method is idempotent. If the save operations for the current + # step's metadata have already been processed, we can exit early. 
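+        # The flag is cleared in start_load_kv() at the beginning of the next
+        # step, and get_finished() also calls wait_for_save() as a safeguard,
+        # so the guard below keeps the same step's save futures from being
+        # submitted twice.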
+ if self._processed_save_for_step: + return + + # logger.info("TPUOffloadConnectorWorker: Entering wait_for_save") + metadata = self.connector._get_connector_metadata() + if not isinstance(metadata, TPUOffloadConnectorMetadata): + logger.info( + "wait_for_save:not an instances of TPUOffloadConnectorMetadata" + ) + self._processed_save_for_step = True + return + + if not metadata.requests_meta: + # logger.info("wait_for_save:no reqs to save") + self._processed_save_for_step = True + return + + pending_save_futures: list[tuple[Future, TPUReqMeta]] = [] + # Handle save requests + for meta in metadata.requests_meta: + if meta.save_spec: + if meta.save_spec.skip_save: + logger.info( + f"Request {meta.req_id}: Scheduler signaled to skip save." + ) + if meta.save_spec.is_final_save: + logger.info( + f"Request {meta.req_id}: Final save is a no-op. Marking as finished." + ) + # self.finished_save_reqs.add(meta.req_id) + continue + + # If there are tokens to save, submit the task to the thread pool. + logger.info(f"Submitting save task for request {meta.req_id}") + future = self.save_executor.submit(self._save_blocks_to_cpu, + meta.req_id, + meta.local_block_ids, + meta.token_ids, + meta.save_spec) + pending_save_futures.append((future, meta)) + + if not pending_save_futures: + self._processed_save_for_step = True + return + + logger.info(f"Waiting for {len(pending_save_futures)} save " + "operations to complete...") + start_time = time.time() + + for future, meta in pending_save_futures: + try: + # The result of _save_blocks_to_cpu is the request_id + finished_req_id = future.result() + logger.info( + f"Save operation completed for request {finished_req_id}") + + if len(meta.save_spec.src_blocks) > 0: + self.offload_stats.record_save( + req=finished_req_id, + saved_chunk_ids=meta.save_spec.dst_chunks) + + if meta.save_spec and meta.save_spec.is_final_save: + logger.info( + f"Request {finished_req_id}: Final save completed. Marking as finished." + ) + self.finished_save_reqs.add(finished_req_id) + + except Exception as e: + logger.error(f"A save operation failed: {e}", exc_info=True) + + duration = time.time() - start_time + logger.info(f"All {len(pending_save_futures)} save operations " + f"completed in {duration:.4f} seconds.") + self._processed_save_for_step = True + + def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: + """ + This function is the worker-side entry point for loading data from the + local CPU backend into the TPU's sharded KV cache. It is a blocking + operation that ensures the cache is fully updated before the model's + forward pass begins. + """ + # Reset the save processing flag at the start of a new step. + self._processed_save_for_step = False + metadata = self.connector._get_connector_metadata() + if not isinstance( + metadata, + TPUOffloadConnectorMetadata) or not metadata.requests_meta: + logger.info("No load operations scheduled for this step.") + return + + if not self.device_sharding: + raise RuntimeError( + "KV cache sharding info not available. Was register_runner called?" 
+ ) + + assert self.runner is not None and self.runner.kv_caches is not None + + # Process each request that needs its KV cache loaded + load_times = [] + for meta in metadata.requests_meta: + if not (meta.load_spec and meta.load_spec.can_load): + continue + + request_load_start_time = time.time() + logger.info( + "TPUOffloadConnectorWorker: Starting KV cache load process.") + dst_blocks = meta.load_spec.dst_blocks + src_chunks = meta.load_spec.src_chunks + num_blocks_to_load = len(dst_blocks) + num_matched_tokens = meta.load_spec.num_matched_tokens + num_skip_leading_tokens = meta.load_spec.num_skip_leading_tokens + num_tokens_to_load_delta = num_matched_tokens - num_skip_leading_tokens + assert num_skip_leading_tokens % self.block_size == 0, f"{num_skip_leading_tokens} % {self.block_size} != 0" + + if num_tokens_to_load_delta <= 0: + logger.info( + f"Request {meta.req_id}: No new tokens to load. Skipping.") + continue + + assert num_blocks_to_load > 0, f"Request({meta.req_id}) has no dst blocks to load." + # Verify if dst_blocks is a contiguous subarray of meta.local_block_ids + first_dst_block = dst_blocks[0] + last_dst_block = dst_blocks[-1] + try: + first_block_idx_in_local = meta.local_block_ids.index( + first_dst_block) + last_block_idx_in_local = meta.local_block_ids.index( + last_dst_block) + if not (last_block_idx_in_local - first_block_idx_in_local + 1 + == len(dst_blocks)): + raise ValueError( + f"Request({meta.req_id}): dst_blocks {dst_blocks} does not exist in local_block_ids {meta.local_block_ids}" + ) + except ValueError: + raise ValueError( + f"Request({meta.req_id}): dst_blocks {dst_blocks} contains blocks not present in local_block_ids {meta.local_block_ids}" + ) + + logger.info( + f"Processing KV load for request {meta.req_id}: " + f"Total matched: {num_matched_tokens}, " + f"Already computed: {num_skip_leading_tokens}. " + f"Fetching delta of {num_tokens_to_load_delta} tokens from cache for " + f"{num_blocks_to_load} blocks.") + + # Assemble the per-layer data for the delta tokens on the CPU. + # We create a list of lists, where the outer list represents layers + # and the inner lists will hold the data chunks for that layer. + assembled_kv_on_cpu = [[] for _ in range(self.num_layers)] + # Fetch and chunks from the backend. + for i in range(num_blocks_to_load): + src_chunk_id = src_chunks[i] + cached_value = self.cpu_backend.get(src_chunk_id) + if cached_value: + for j in range(self.num_layers): + assembled_kv_on_cpu[j].append(cached_value[j]) + else: + logger.error( + f"Chunk[{src_chunk_id}] not found in CPU backend for request {meta.req_id}. Inconsistent state detected." 
+ ) + return + + # swap-in + # output: [[cpu_chunk_size * num_chunks] * num_layer] + if self.use_bucketed_swap_ops: + # Use the bucketed wrappers for a uniform two-step process + raw_chunked_kv_on_tpu = self._bucketed_swap_in_fn( + assembled_kv_on_cpu) + else: + raw_chunked_kv_on_tpu = self.swap_in_fn(assembled_kv_on_cpu) + jax.block_until_ready(raw_chunked_kv_on_tpu) + + if self.use_bucketed_swap_ops: + self.runner.kv_caches = self._bucketed_jitted_insert_kv_cache_slices( + self.runner.kv_caches, + raw_chunked_kv_on_tpu, + jnp.array(dst_blocks), + ) + else: + self.runner.kv_caches = jitted_insert_kv_cache_slices( + self.block_size, + self.runner.kv_caches, + raw_chunked_kv_on_tpu, + jnp.array(dst_blocks), + ) + jax.block_until_ready(self.runner.kv_caches) + logger.info( + f"Request {meta.req_id}: Loaded {num_tokens_to_load_delta} tokens into " + f"{num_blocks_to_load} new blocks.") + + load_times.append(time.time() - request_load_start_time) + self.finished_load_reqs.add(meta.req_id) + if num_blocks_to_load > 0: + self.offload_stats.record_load(req=meta.req_id, + loaded_chunk_ids=src_chunks) + + if load_times: + aggregate_load_time = sum(load_times) + logger.info( + f"TPUOffloadConnectorWorker: Aggregate KV cache load time for {len(load_times)} requests: {aggregate_load_time:.4f} seconds" + ) + + def get_kv_connector_stats(self) -> KVConnectorStats | None: + """ + Get the KV transfer stats for the connector. + """ + # Clear stats for next iteration + if not self.offload_stats.is_empty(): + return self.offload_stats.clone_and_reset() + return None + + def get_finished(self) -> tuple[set[str], set[str]]: + """ + Returns the sets of request IDs for completed save and load operations. + """ + # Safeguard call to wait_for_save(). + # In the final step for a request, the vLLM engine may not call + # `worker.execute_model()` if there's no computation to be done. + # This skips the usual `wait_for_save()` call, preventing the final + # save operation (marked with `is_final_save=True`) from being + # processed. Calling it here ensures that any pending save operations + # for the current step's metadata are executed, and the finished + # request IDs are correctly identified and reported back to the engine + # for resource cleanup. The `wait_for_save` method is idempotent, + # so this call is a no-op in the normal execution path. 
+ logger.info("TPUOffloadConnectorWorker: Entering get_finished") + self.wait_for_save() + + finished_saves = self.finished_save_reqs + self.finished_save_reqs = set() + finished_loads = self.finished_load_reqs + self.finished_load_reqs = set() + logger.info(f"Finished saves: {finished_saves}, " + f"Finished loads: {finished_loads}") + return finished_saves, finished_loads diff --git a/tpu_inference/offload/utils.py b/tpu_inference/offload/utils.py new file mode 100644 index 000000000..3a99a57e6 --- /dev/null +++ b/tpu_inference/offload/utils.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 + +import functools +import hashlib +from dataclasses import dataclass +from typing import Callable, Iterable, List, Literal, Optional, Tuple + +import jax +from vllm.config import get_current_vllm_config +from vllm.distributed.kv_transfer.kv_connector.factory import \ + KVConnectorFactory + +from tpu_inference.kernels.dma.host_dma import d2h_dma, h2d_dma +from tpu_inference.logger import init_logger + +ReqId = str + +CpuChunkId = int + +# Corresponds to the initial hash value +NONE_HASH = 0 + +logger = init_logger(__name__) + +CPU_OFFLOADING_SWAP_OP_TYPE = Literal["jax", "pallas"] + + +@dataclass(order=True) +class CacheKey: + """ + A key for the cache engine. + """ + model_name: str + chunk_hash: int + + def __hash__(self): + return hash(( + self.model_name, + self.chunk_hash, + )) + + def __eq__(self, other): + if type(self) is type(other): + return (self.model_name == other.model_name + and self.chunk_hash == other.chunk_hash) + return False + + +class TokenProcessor: + + def __init__(self, model_name: str, chunk_size: int = 16): + self.model_name = model_name + self.chunk_size = chunk_size + logger.info(f"TokenProcessor initialized with chunk_size={chunk_size}") + + def _hash_tokens( + self, + tokens: List[int], + prefix_hash: Optional[int] = None, + ) -> int: + hasher = hashlib.sha256() + hasher.update(str(prefix_hash).encode('utf-8')) + hasher.update(str(tuple(tokens)).encode('utf-8')) + return int(hasher.hexdigest(), 16) + + def process_tokens( + self, + tokens: Optional[List[int]] = None, + ) -> Iterable[Tuple[int, int, CacheKey]]: + """Process the tokens and return the corresponding cache keys.""" + if not tokens: + return + + total_len = len(tokens) + prefix_hash = NONE_HASH + + for i in range(0, total_len, self.chunk_size): + chunk = tokens[i:i + self.chunk_size] + prefix_hash = self._hash_tokens(chunk, prefix_hash) + start_idx = i + end_idx = min(start_idx + self.chunk_size, total_len) + logger.info( + f"Processing chunk: start={start_idx}, end={end_idx}, hash={prefix_hash}" + ) + yield ( + start_idx, + end_idx, + CacheKey(model_name=self.model_name, chunk_hash=prefix_hash), + ) + + +def get_kv_connector_cache_layout(): + """ + Retrieve the required kv cache layout for the configured kv connector + Return: None, when no kv_transfer_config is found; otherwise, the layout str + """ + vllm_config = get_current_vllm_config() + kv_config = vllm_config.kv_transfer_config + if kv_config is not None: + connector_cls = KVConnectorFactory.get_connector_class(kv_config) + required_kvcache_layout = \ + connector_cls.get_required_kvcache_layout(vllm_config) + if required_kvcache_layout is not None: + return required_kvcache_layout + logger.info_once( + "Connectors do not specify a kv cache layout, defaulting to NHD.") + return None + + +SwapFn = Callable[ + [ + List[jax.Array], # src_kv_caches + jax.sharding.NamedSharding, # src_sharding + jax.sharding.NamedSharding, # dst_sharding + 
Literal["h2d", "d2h"], # direction + ], + List[jax.Array], # return value +] + +KVCacheSwapFn = Callable[[List[jax.Array]], List[jax.Array]] + + +# NOTE(jcgu): keep the same interface as the pallas one +def jax_swap_kv_caches( + src_kv_caches: List[jax.Array], + src_sharding: jax.sharding.NamedSharding, + dst_sharding: jax.sharding.NamedSharding, + direction: Literal["h2d", "d2h"], +) -> List[jax.Array]: + """Swap in / out multi-layer kv_cache using jax device_put + + Args: + src_kv_caches: [kv_cache of each layer] + src_sharding: kv_caches' original sharding + dst_sharding: kv_caches' target sharding (different memory_kind) + direction: h2d -> swap_in, d2h -> swap_out + Returns: + a list of jax.Array objects with the dst_sharding + """ + + def _jax_device_put(input_array): + return jax.device_put(input_array, dst_sharding) + + return jax.tree.map(_jax_device_put, src_kv_caches) + + +def pallas_swap_kv_caches( + src_kv_caches: List[jax.Array], + src_sharding: jax.sharding.NamedSharding, + dst_sharding: jax.sharding.NamedSharding, + direction: Literal["h2d", "d2h"], +) -> List[jax.Array]: + """Swap in / out multi-layer kv_cache using pallas dma kernel + + Args: + src_kv_caches: [kv_cache of each layer] + src_sharding: kv_caches' original sharding + dst_sharding: kv_caches' target sharding (different memory_kind) + direction: h2d -> swap_in, d2h -> swap_out + Returns: + a list of jax.Array objects with the dst_sharding + """ + + def swap_in_fn(inputs, input_sharding, out_sharding): + + def _swap_in(host_sharded_array): + return h2d_dma(host_sharded_array, input_sharding, out_sharding) + + return jax.tree.map(_swap_in, inputs) + + def swap_out_fn(inputs, input_sharding, out_sharding): + + def _swap_out(hbm_sharded_array): + return d2h_dma(hbm_sharded_array, input_sharding, out_sharding) + + return jax.tree.map(_swap_out, inputs) + + if direction == "d2h": + return swap_out_fn(src_kv_caches, src_sharding, dst_sharding) + elif direction == "h2d": + return swap_in_fn(src_kv_caches, src_sharding, dst_sharding) + + +def get_kv_cache_swap_fn( + swap_op_type: CPU_OFFLOADING_SWAP_OP_TYPE, + host_sharding: jax.sharding.NamedSharding, + device_sharding: jax.sharding.NamedSharding, + jitted: bool = True, +) -> Tuple[KVCacheSwapFn, KVCacheSwapFn]: + """get the right swap_in and swap_out functions + + Args: + swap_op_type : (str) pallas or jax + host_sharding: + device_sharding: + + Returns: + A tuple containing the jitted swap-in and swap-out functions. 
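+
+    Example (illustrative; in practice the shardings are built from the
+    runner's mesh and the inputs are pytrees of per-layer kv cache arrays):
+
+        swap_in_fn, swap_out_fn = get_kv_cache_swap_fn(
+            "jax", host_sharding, device_sharding)
+        kv_on_cpu = swap_out_fn(kv_caches_on_tpu)  # device -> host
+        kv_on_tpu = swap_in_fn(kv_on_cpu)          # host -> device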
+ """ + _swap_fn: SwapFn = pallas_swap_kv_caches if swap_op_type == "pallas" else jax_swap_kv_caches + if jitted: + _swap_in_fn = jax.jit( + _swap_fn, + static_argnames=["src_sharding", "dst_sharding", "direction"], + out_shardings=device_sharding) + _swap_out_fn = jax.jit( + _swap_fn, + static_argnames=["src_sharding", "dst_sharding", "direction"], + out_shardings=host_sharding) + else: + _swap_in_fn = _swap_fn + _swap_out_fn = _swap_fn + + # swap_in (h2d) + swap_in_fn = functools.partial(_swap_in_fn, + src_sharding=host_sharding, + dst_sharding=device_sharding, + direction="h2d") + # swap_out (d2h) + swap_out_fn = functools.partial(_swap_out_fn, + src_sharding=device_sharding, + dst_sharding=host_sharding, + direction="d2h") + return swap_in_fn, swap_out_fn + + +@functools.partial( + jax.jit, + static_argnames=("block_size"), + donate_argnames=( + "kv_caches", + "kv_cache_slices", + ), +) +def jitted_insert_kv_cache_slices( + block_size, + kv_caches: List[jax.Array], + kv_cache_slices: List[List[jax.Array]], + block_numbers: jax.Array, +) -> List[jax.Array]: + """ + JIT-compiled function to insert KV cache slices into the physical + cache for all layers at once. This fuses reshape, and scatter + operations into a single efficient kernel. + """ + + def _update_layer(cache, slices): + """The function to apply to each layer's cache and slices.""" + # new_shape = (1, block_size, *slices[0].shape[1:]) + for (i, block_idx) in enumerate(block_numbers): + # reshaped_block = slices[i].reshape(new_shape) + reshaped_block = jax.lax.expand_dims(slices[i], dimensions=(0, )) + cache = jax.lax.dynamic_update_slice_in_dim(cache, + reshaped_block, + block_idx, + axis=0) + return cache + + return jax.tree.map(_update_layer, kv_caches, kv_cache_slices) diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py index 9f2a78526..8502bcd8e 100644 --- a/tpu_inference/platforms/tpu_platform.py +++ b/tpu_inference/platforms/tpu_platform.py @@ -217,9 +217,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "Forcing --disable_chunked_mm_input.") scheduler_config.disable_chunked_mm_input = True - kv_transfer_config = vllm_config.kv_transfer_config - if kv_transfer_config is not None: - assert kv_transfer_config.kv_connector == "TPUConnector" # Late initialization to avoid circular import from tpu_inference.models.jax.utils.quantization.quantization_utils import \ update_vllm_config_for_qwix_quantization diff --git a/tpu_inference/runner/kv_cache_manager.py b/tpu_inference/runner/kv_cache_manager.py index 348521715..5e4d0af54 100644 --- a/tpu_inference/runner/kv_cache_manager.py +++ b/tpu_inference/runner/kv_cache_manager.py @@ -11,6 +11,8 @@ from vllm.attention.layer import Attention from vllm.config import get_layers_from_vllm_config from vllm.utils.math_utils import cdiv +from vllm.v1.attention.backends.utils import (get_kv_cache_layout, + set_kv_cache_layout) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec, MLAAttentionSpec, SlidingWindowSpec) @@ -18,6 +20,7 @@ from tpu_inference import utils from tpu_inference import utils as common_utils from tpu_inference.logger import init_logger +from tpu_inference.offload.utils import get_kv_connector_cache_layout from tpu_inference.runner import utils as runner_utils from tpu_inference.runner.input_batch import CachedRequestState, InputBatch from tpu_inference.runner.kv_cache import create_kv_caches @@ -29,6 +32,10 @@ logger = init_logger(__name__) +# default layout (order) used 
by kv cache manager +# N=num_blocks, H=num_heads and D=head_size +DEFAULT_KV_CACHE_LAYOUT = "NHD" + class KVCacheManager: @@ -165,6 +172,10 @@ def get_kv_cache_spec(self): f"Unknown attention type: {attn_module.attn_type}") return kv_cache_spec + def get_kv_cache_layout(self): + # return the layout (mostly "NHD" or "HND") of kv cache + return get_kv_cache_layout() + def maybe_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: block_sizes = [ @@ -195,6 +206,19 @@ def maybe_reinitialize_input_batch(self, def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.maybe_reinitialize_input_batch(kv_cache_config) + # set the kv cache layout which is needed by kv connectors + # NOTE(jcgu): please update the default value when the order changes + set_kv_cache_layout(DEFAULT_KV_CACHE_LAYOUT) + + # verify kv cache layout is matched between the cache manager and + # the kv connector (if configured) + _required_kv_layout = get_kv_connector_cache_layout() + if (_required_kv_layout + and _required_kv_layout != DEFAULT_KV_CACHE_LAYOUT): + raise ValueError( + f"KV cache layout ({DEFAULT_KV_CACHE_LAYOUT}) does not match with the " + f"kv_connector's required layout ({_required_kv_layout})") + # uniform page size. representative_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec page_size_bytes = representative_spec.page_size_bytes diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py index ac67eae30..893f7ac60 100644 --- a/tpu_inference/runner/tpu_runner.py +++ b/tpu_inference/runner/tpu_runner.py @@ -524,6 +524,9 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]: def get_kv_cache_spec(self): return self.kv_cache_manager.get_kv_cache_spec() + def get_kv_cache_layout(self): + return self.kv_cache_manager.get_kv_cache_layout() + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.kv_cache_config = kv_cache_config self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1 diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py index 9b52d43e4..85c8ba083 100644 --- a/tpu_inference/worker/tpu_worker.py +++ b/tpu_inference/worker/tpu_worker.py @@ -289,6 +289,25 @@ def determine_available_memory(self) -> int: total_hbm_limit_gb = round(total_hbm_limit / utils.GBYTES, 2) total_hbm_limit_cap_gb = round(total_hbm_limit_cap / utils.GBYTES, 2) total_hbm_used_gb = round(total_hbm_used / utils.GBYTES, 2) + + if self.vllm_config.kv_transfer_config is not None: + kv_transfer_config = self.vllm_config.kv_transfer_config + if kv_transfer_config.kv_connector == "TPUOffloadConnector" and \ + kv_transfer_config.kv_connector_module_path == "tpu_inference.offload.tpu_offload_connector": + # If kv offloading is enabled, we need to account for the memory used by the KV transfer buffer. + staging_buffer_pages = envs.TPU_OFFLOAD_NUM_STAGING_BLOCKS + + kv_cache_specs = self.model_runner.get_kv_cache_spec() + num_layers = len(kv_cache_specs) + vllm_page_size_bytes = get_uniform_page_size( + list(kv_cache_specs.values())) + stage_buffer_size_bytes = staging_buffer_pages * num_layers * vllm_page_size_bytes + + total_hbm_avail = total_hbm_avail - stage_buffer_size_bytes + logger.info( + f" ALERT: KV offloading enabled. Deducting {stage_buffer_size_bytes} Bytes ({staging_buffer_pages} pages) from available HBM for staging buffer." 
+ ) + total_hbm_avail_gb = round(total_hbm_avail / utils.GBYTES, 2) logger.info(f"Memory statistics | " @@ -432,6 +451,11 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return kv_cache_specs + def get_kv_connector_handshake_metadata(self) -> dict | None: + """Get KV connector metadata from this worker if available.""" + # NOTE: we are not using it right now. + return + def initialize_from_config( self, kv_cache_config: KVCacheConfig, @@ -465,8 +489,3 @@ def sync_weights( def shutdown(self) -> None: return - - # Ray executor do not need handshake metadata - # as we pass the kv_parameters through proxy server - def get_kv_connector_handshake_metadata(self) -> None: - pass
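For reference, the worker-side memory accounting above keys off a specific connector name and module path. The sketch below shows one way the offload connector might be enabled from the Python API; the connector identifiers are taken from the checks in `tpu_worker.py`, while the model name, `kv_role`, and the use of `KVTransferConfig` are illustrative assumptions about a standard vLLM setup and may need adjusting for your vLLM version.

```python
# Hypothetical launch sketch (not part of this diff): enable KV cache
# offload by pointing vLLM at the out-of-tree TPUOffloadConnector.
from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="meta-llama/Llama-3.3-70B-Instruct",  # illustrative model choice
    kv_transfer_config=KVTransferConfig(
        # These two strings mirror the values checked in
        # tpu_worker.determine_available_memory().
        kv_connector="TPUOffloadConnector",
        kv_connector_module_path="tpu_inference.offload.tpu_offload_connector",
        # Offloading covers both saving and loading KV blocks, hence "kv_both".
        kv_role="kv_both",
    ),
)
```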