From f0f295ae221f8959a809625fd05aa4c1e5c036a1 Mon Sep 17 00:00:00 2001 From: aman2930 Date: Mon, 6 Jan 2025 16:33:32 +0000 Subject: [PATCH 01/22] Extra logging for understanding the workflow --- jetstream/core/orchestrator.py | 7 +++++++ jetstream/core/server_lib.py | 2 ++ jetstream/tools/maxtext/model_ckpt_conversion.sh | 5 ++++- jetstream/tools/requester.py | 3 ++- 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 15fc36dd..19e94f98 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -581,6 +581,7 @@ def _prefill_thread(self, idx: int): def _jax_transfer_prefill_result( self, new_request: ActiveRequest, target_idx: int ): + logging.info("AMANGU: In _jax_transfer_prefill_result") new_request.prefill_result = jax.device_put( new_request.prefill_result, self._generate_engines[target_idx].get_prefix_destination_sharding(), @@ -596,6 +597,7 @@ def _ray_transfer_prefill_result( def _transfer_prefill_result( self, new_request: ActiveRequest, target_idx: int ): + logging.info("AMANGU: In _transfer_prefill_result") if self._is_ray_backend: self._ray_transfer_prefill_result(new_request, target_idx) else: @@ -605,6 +607,7 @@ def _transfer_thread(self, idx: int): """Transfers the kv cache on an active request to the least full generate backlog.""" transfer_backlog = self._transfer_backlogs[idx] + logging.info("AMANGU: In _transfer_thread") while self.live: # The transfer thread can just sleep until it has work to do. @@ -641,6 +644,7 @@ def _transfer_thread(self, idx: int): def _generate_thread(self, idx: int): """Step token generation and insert prefills from backlog.""" logging.info("---------Spinning up generate thread %d.---------", idx) + logging.info("AMANGU: In _generate_thread") generate_engine = self._generate_engines[idx] my_slots = self._generate_slots[idx] my_generate_backlog = self._generate_backlogs[idx] @@ -780,6 +784,7 @@ def _detokenize_thread(self, is_prefill: bool, idx: int): # For all filled my_slots, pop the sampled token onto the relevant # requests return channel. If it done, place it back onto free slots. 
+ logging.info("AMANGU: In _detokenize_thread") if is_prefill: my_detokenize_backlog = self._prefill_detokenize_backlogs[idx] else: @@ -908,6 +913,7 @@ def __init__(self, driver: Driver): def _get_prefill_content( self, request: jetstream_pb2.DecodeRequest ) -> Tuple[str | list[int], bool]: + logging.info("AMANGU: In LLMOrchestrator::_get_prefill_content") which_content = request.WhichOneof("content") content = getattr(request, which_content) if which_content == "text_content": @@ -979,6 +985,7 @@ async def Decode( # pylint: disable=invalid-overridden-method request: jetstream_pb2.DecodeRequest, context: Optional[grpc.aio.ServicerContext] = None, ) -> AsyncIterator[jetstream_pb2.DecodeResponse]: + logging.info("AMANGU: In LLMOrchestrator::Decode") """Decode.""" if context is None: logging.warning( diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index b323286a..f9b95aef 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -154,6 +154,8 @@ def create_driver( traceback.print_exc() os.kill(os.getpid(), signal.SIGKILL) + logging.info("AMANGU: Going to create the drivers.") + return orchestrator.Driver( prefill_engines=prefill_engines, generate_engines=generate_engines, diff --git a/jetstream/tools/maxtext/model_ckpt_conversion.sh b/jetstream/tools/maxtext/model_ckpt_conversion.sh index 0340dbfe..ff272dbe 100644 --- a/jetstream/tools/maxtext/model_ckpt_conversion.sh +++ b/jetstream/tools/maxtext/model_ckpt_conversion.sh @@ -38,6 +38,8 @@ export MODEL_BUCKET=$4 # Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created, this bucket will store all the files generated by MaxText during a run, specifically the unscanned checkpoint. export BASE_OUTPUT_DIRECTORY=$5 +export LORA_LOCAL_PATH=$6 + export BUCKET_LOCATION=US # Create three GCS buckets for the demo. @@ -63,7 +65,8 @@ else JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \ --base-model-path ${tmp_ckpt_path}${directory_substring} \ --maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \ - --model-size ${MODEL_NAME} + --model-size ${MODEL_NAME} \ + --lora-path ${LORA_LOCAL_PATH} fi echo "Written MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}" diff --git a/jetstream/tools/requester.py b/jetstream/tools/requester.py index 7ac0d55a..54bbb78a 100644 --- a/jetstream/tools/requester.py +++ b/jetstream/tools/requester.py @@ -26,7 +26,8 @@ _SERVER = flags.DEFINE_string("server", "0.0.0.0", "server address") _PORT = flags.DEFINE_string("port", "9000", "port to ping") -_TEXT = flags.DEFINE_string("text", "My dog is cute", "The message") +#_TEXT = flags.DEFINE_string("text", "My dog is cute", "The message") +_TEXT = flags.DEFINE_string("text", "22 year old", "The message") _MAX_TOKENS = flags.DEFINE_integer( "max_tokens", 3, "Maximum number of output/decode tokens of a sequence" ) From 610fcea5c654e26c4bff3b086936ee5f6ecdc14f Mon Sep 17 00:00:00 2001 From: aman2930 Date: Wed, 8 Jan 2025 18:12:17 +0000 Subject: [PATCH 02/22] Updating checkpoint conversion script to support LoRA weights conversion to Orbax format. 
--- .../tools/maxtext/model_ckpt_conversion.sh | 46 ++++++++++++++----- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/jetstream/tools/maxtext/model_ckpt_conversion.sh b/jetstream/tools/maxtext/model_ckpt_conversion.sh index ff272dbe..f50aece3 100644 --- a/jetstream/tools/maxtext/model_ckpt_conversion.sh +++ b/jetstream/tools/maxtext/model_ckpt_conversion.sh @@ -58,7 +58,7 @@ else pip install torch --index-url https://download.pytorch.org/whl/cpu # llama_or_mistral_ckpt.py requires local path, so we need to copy the checkpoint from CHKPT_BUCKET to local. tmp_ckpt_path="/tmp/" - gcloud storage cp -r ${CHKPT_BUCKET} ${tmp_ckpt_path} + #gcloud storage cp -r ${CHKPT_BUCKET} ${tmp_ckpt_path} path_parts=(${CHKPT_BUCKET//\// }) directory_substring=${path_parts[-1]} CONVERT_CKPT_SCRIPT="llama_or_mistral_ckpt.py" @@ -66,25 +66,49 @@ else --base-model-path ${tmp_ckpt_path}${directory_substring} \ --maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \ --model-size ${MODEL_NAME} \ - --lora-path ${LORA_LOCAL_PATH} + --lora-config-path ${LORA_LOCAL_PATH}/adapter_config.json \ + --lora-model-path ${LORA_LOCAL_PATH}/adapter_model.bin fi echo "Written MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}" # We define `SCANNED_CKPT_PATH` to refer to the checkpoint subdirectory. -export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items +# export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items +export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} # Convert MaxText compatible checkpoints to unscanned checkpoints. # Note that the `SCANNED_CKPT_PATH` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. 
export RUN_NAME=${MODEL_NAME}_unscanned_chkpt_${idx} -JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \ -MaxText/configs/base.yml \ -base_output_directory=${BASE_OUTPUT_DIRECTORY} \ -load_parameters_path=${SCANNED_CKPT_PATH} \ -run_name=${RUN_NAME} \ -model_name=${MODEL_NAME} \ -force_unroll=true -echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints" +#JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \ +#MaxText/configs/base.yml \ +#base_output_directory=${BASE_OUTPUT_DIRECTORY} \ +#load_parameters_path=${SCANNED_CKPT_PATH}/base_weights/0/items \ +#run_name=${RUN_NAME} \ +#model_name=${MODEL_NAME} \ +#force_unroll=true +#echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints" + +if [[ -x "${LORA_LOCAL_PATH}" ]]; then + JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \ + MaxText/configs/base.yml \ + base_output_directory=${BASE_OUTPUT_DIRECTORY} \ + load_parameters_path=${SCANNED_CKPT_PATH}/lora_A/0/items \ + run_name=${RUN_NAME}/lora_A \ + model_name=${MODEL_NAME} \ + force_unroll=true + echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/lora_A/checkpoints" + + JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \ + MaxText/configs/base.yml \ + base_output_directory=${BASE_OUTPUT_DIRECTORY} \ + load_parameters_path=${SCANNED_CKPT_PATH}/lora_B/0/items \ + run_name=${RUN_NAME}/lora_B \ + model_name=${MODEL_NAME} \ + force_unroll=true + echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/lora_B/checkpoints" +fi + + # We will use the unscanned checkpoints by passing `UNSCANNED_CKPT_PATH` into `LOAD_PARAMETERS_PATH` in the following sections. export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items From 50deb3e38f07b57d3aa397067762bd3923f235a7 Mon Sep 17 00:00:00 2001 From: aman2930 Date: Wed, 22 Jan 2025 18:28:13 +0000 Subject: [PATCH 03/22] Cleaning up of loggings and some refactoring to make the script work e2e with LoRA paths. --- .../tools/maxtext/model_ckpt_conversion.sh | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/jetstream/tools/maxtext/model_ckpt_conversion.sh b/jetstream/tools/maxtext/model_ckpt_conversion.sh index f50aece3..a18691fa 100644 --- a/jetstream/tools/maxtext/model_ckpt_conversion.sh +++ b/jetstream/tools/maxtext/model_ckpt_conversion.sh @@ -58,16 +58,23 @@ else pip install torch --index-url https://download.pytorch.org/whl/cpu # llama_or_mistral_ckpt.py requires local path, so we need to copy the checkpoint from CHKPT_BUCKET to local. 
tmp_ckpt_path="/tmp/" - #gcloud storage cp -r ${CHKPT_BUCKET} ${tmp_ckpt_path} + gcloud storage cp -r ${CHKPT_BUCKET} ${tmp_ckpt_path} path_parts=(${CHKPT_BUCKET//\// }) directory_substring=${path_parts[-1]} CONVERT_CKPT_SCRIPT="llama_or_mistral_ckpt.py" - JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \ - --base-model-path ${tmp_ckpt_path}${directory_substring} \ - --maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \ - --model-size ${MODEL_NAME} \ - --lora-config-path ${LORA_LOCAL_PATH}/adapter_config.json \ - --lora-model-path ${LORA_LOCAL_PATH}/adapter_model.bin + if [[ -x "${LORA_LOCAL_PATH}" ]]; then + JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \ + --base-model-path ${tmp_ckpt_path}${directory_substring} \ + --maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \ + --model-size ${MODEL_NAME} \ + --lora-config-path ${LORA_LOCAL_PATH}/adapter_config.json \ + --lora-model-path ${LORA_LOCAL_PATH}/adapter_model.bin + else + JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \ + --base-model-path ${tmp_ckpt_path}${directory_substring} \ + --maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \ + --model-size ${MODEL_NAME} + fi fi echo "Written MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}" @@ -79,33 +86,26 @@ export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} # Note that the `SCANNED_CKPT_PATH` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. export RUN_NAME=${MODEL_NAME}_unscanned_chkpt_${idx} -#JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \ -#MaxText/configs/base.yml \ -#base_output_directory=${BASE_OUTPUT_DIRECTORY} \ -#load_parameters_path=${SCANNED_CKPT_PATH}/base_weights/0/items \ -#run_name=${RUN_NAME} \ -#model_name=${MODEL_NAME} \ -#force_unroll=true -#echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints" - if [[ -x "${LORA_LOCAL_PATH}" ]]; then JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \ MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ - load_parameters_path=${SCANNED_CKPT_PATH}/lora_A/0/items \ - run_name=${RUN_NAME}/lora_A \ + load_parameters_path=${SCANNED_CKPT_PATH}/base_weights/0/items \ + lora_parameters_base_path=${SCANNED_CKPT_PATH}/lora_weights/0/items \ + lora_config_path=${LORA_LOCAL_PATH}/adapter_config.json \ + run_name=${RUN_NAME} \ model_name=${MODEL_NAME} \ force_unroll=true - echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/lora_A/checkpoints" - + echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints" +else JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \ MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ - load_parameters_path=${SCANNED_CKPT_PATH}/lora_B/0/items \ - run_name=${RUN_NAME}/lora_B \ + load_parameters_path=${SCANNED_CKPT_PATH}/base_weights/0/items \ + run_name=${RUN_NAME} \ model_name=${MODEL_NAME} \ force_unroll=true - echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/lora_B/checkpoints" + echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints" fi From 7426ea70065eb5afd0a3d6d22e9f0eda79dfb479 Mon Sep 17 00:00:00 2001 From: aman2930 Date: Mon, 27 Jan 2025 18:20:02 +0000 Subject: [PATCH 04/22] 1) Added 
 MultiAdapterManager service proto with the methods ListAdapters, LoadAdapter
 and UnloadAdapter. 2) The Driver, which holds the list of loaded base
 parameters, now also stores the LoRA-updated parameters for each loaded LoRA
 adapter; methods for loading, unloading and listing LoRA adapters were
 implemented on the Driver object. The original base model params remain
 intact, stored in the params dictionary under the key "base_params".
 3) Created a proxy client to make MultiAdapterManager service requests to the
 JetStream server.
---
 jetstream/core/adapter_manager.py           | 117 ++++++++++++++++
 jetstream/core/orchestrator.py              |  99 ++++++++++++-
 jetstream/core/proto/jetstream.proto        |  52 ++++++-
 jetstream/core/proto/jetstream_pb2.py       |  79 ++++++-----
 jetstream/core/proto/jetstream_pb2_grpc.py  |  72 ++++++++++
 jetstream/core/server_lib.py                |  12 +-
 jetstream/tools/llm_gateway_proxy_client.py | 145 ++++++++++++++++++++
 7 files changed, 535 insertions(+), 41 deletions(-)
 create mode 100644 jetstream/core/adapter_manager.py
 create mode 100644 jetstream/tools/llm_gateway_proxy_client.py

diff --git a/jetstream/core/adapter_manager.py b/jetstream/core/adapter_manager.py
new file mode 100644
index 00000000..358fc0e2
--- /dev/null
+++ b/jetstream/core/adapter_manager.py
@@ -0,0 +1,117 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Manages the list of fine-tuned adapters loaded on top of the base model for serving.
+""" + +import logging +import grpc + +from typing import Optional + +from jetstream.core.proto import jetstream_pb2_grpc +from jetstream.core.proto import jetstream_pb2 +from jetstream.core import orchestrator + + +def calculate_loading_cost(adapter_path: str): + return 1 + + +class MultiLoraManager(jetstream_pb2_grpc.MultiAdapterManagerServicer): + """Manages the parameters of multiple lora requests and their lifelines.""" + + _driver: orchestrator.Driver + + def __init__(self, driver: orchestrator.Driver): + self._driver = driver + self.loaded_adapters = {} # Dictionary to track loaded adapters + + def ListAdapters( + self, + request: jetstream_pb2.ListAdaptersRequest, + context: Optional[grpc.aio.ServicerContext] = None, + ) -> jetstream_pb2.ListAdaptersResponse: + """ListAdapters all loaded LoRA adapters.""" + + try: + logging.info("AMANGU LOG: Before making call to mayBeListLoadedAdapters.") + self._driver.mayBeListLoadedAdapters() + logging.info("AMANGU LOG: After making call to mayBeListLoadedAdapters.") + + adapter_infos = [] + for adapter_id, adapter_data in self.loaded_adapters.items(): + adapter_info = jetstream_pb2.AdapterInfo( + adapter_id=adapter_id, + loading_cost=adapter_data["loading_cost"] + ) + adapter_infos.append(adapter_info) + + # logging.info("AMANGU Log (adapter_manager.py): ListAdapters is still under implementation") + logging.info("AMANGU LOG: List adapters --> Before returning success.") + logging.info(f"AMANGU LOG: List of adapters --> {adapter_infos}.") + + return jetstream_pb2.ListAdaptersResponse(success=True, adapter_infos=adapter_infos) + except Exception as e: + logging.info("AMANGU LOG: List adapters --> Before returning failure.") + return jetstream_pb2.ListAdaptersResponse(success=False, error_message=str(e)) + + + def LoadAdapter( + self, + request: jetstream_pb2.LoadAdapterRequest, + context: Optional[grpc.aio.ServicerContext] = None, + ) -> jetstream_pb2.LoadAdapterResponse: + """Load a LoRA adapter as mentioned in the request.""" + + try: + # Load the adapter using MaxEngine in the Driver + # Implmentation to load adatper using MaxEnbine and request.adapter_path + + # Store adapter info (e.g. 
loading cost + self._driver.loadAndApplyAdapter(request.adapter_id, + request.adapter_config_path, + request.adapter_weights_path) + + self.loaded_adapters[request.adapter_id] = { + "adapter_path": request.adapter_weights_path, + "loading_cost": calculate_loading_cost(request.adapter_weights_path) + } + + return jetstream_pb2.LoadAdapterResponse(success=True) + except Exception as e: + return jetstream_pb2.LoadAdapterResponse(success=False, error_message=str(e)) + + + def UnloadAdapter( + self, + request: jetstream_pb2.UnloadAdapterRequest, + context: Optional[grpc.aio.ServicerContext] = None, + ) -> jetstream_pb2.UnloadAdapterResponse: + """Unload a LoRA adapter as mentioned in the request.""" + + # logging.info("AMANGU Log (adapter_manager.py): UnloadAdapter is still under implementation") + try: + # Unload the adapter + # Implementation to unload adapter from MaxEngine + self._driver.unloadAdapter(request.adapter_id) + + del self.loaded_adapters[request.adapter_id] + return jetstream_pb2.UnloadAdapterResponse(success=True) + except Exception as e: + logging.info(f"AMANGU Log(adapter_manager.py): UnloadAdapter failed with error {str(e)}") + return jetstream_pb2.UnloadAdapterResponse(success=False, error_message=str(e)) + + + diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 19e94f98..425861ef 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -74,6 +74,7 @@ to debug hangs due to bugs in threads (it is easier to debug with live logs). """ +import copy import dataclasses import functools import itertools @@ -149,6 +150,8 @@ class ActiveRequest: is_client_side_tokenization: Optional[bool] = False ################## Information relevant for metrics ################### metadata: ActiveRequestMetadata = ActiveRequestMetadata() + ################## Id of the adapter ################### + adapter_id: str = "" def enqueue_samples(self, generated_samples: list[ReturnSample]): """Adds the generated sample(s) to return channel for current step. @@ -512,7 +515,6 @@ def _prefill_thread(self, idx: int): """Thread which runs in the background performing prefills.""" logging.info("---------Spinning up prefill thread %d.---------", idx) prefill_engine = self._prefill_engines[idx] - prefill_params = self._prefill_params[idx] metadata = prefill_engine.get_tokenizer() tokenizer = prefill_engine.build_tokenizer(metadata) logging.info("---------Prefill params %d loaded.---------", idx) @@ -524,6 +526,20 @@ def _prefill_thread(self, idx: int): if request is None: break + + start_time = time.perf_counter() + prefill_params = self._prefill_params[idx] + end_time = time.perf_counter() + + elapsed_time = (end_time - start_time) * 1e6 + + logging.info(f"AMANGU Log (orchestrator.py): Time taken to set prefill_params=self._prefill_params is {elapsed_time} Micro-seconds.") + + adapter_id = request.adapter_id + if adapter_id != "" and adapter_id not in prefill_params: + logging.info(f"The adapter is not loaded into prefill_params, so bypassing the processing of the request.") + continue + request.metadata.prefill_dequeue_time = time.perf_counter() is_bos = True logging.info( @@ -540,7 +556,7 @@ def _prefill_thread(self, idx: int): # Compute new kv cache for the prefill_content. prefill_result, first_token = prefill_engine.prefill( - params=prefill_params, + params=prefill_params[adapter_id], padded_tokens=padded_tokens, true_length=true_length, ) @@ -655,7 +671,7 @@ def _generate_thread(self, idx: int): # State to store things like running kv cache in. 
decode_state = generate_engine.init_decode_state() - generate_params = self._generate_params[idx] + # generate_params = self._generate_params[idx] logging.info("---------Generate params %d loaded.---------", idx) time_of_last_generate = time.time() time_of_last_print = time.time() @@ -704,8 +720,10 @@ def _generate_thread(self, idx: int): block |= not self._transfer_backlogs[idx].empty() try: new_request = my_generate_backlog.get(block=block, timeout=1.0) + if new_request is None: break + new_request.metadata.generate_dequeue_time = time.perf_counter() if ( self._metrics_collector @@ -760,9 +778,24 @@ def _generate_thread(self, idx: int): my_slots.qsize() < max_concurrent_decodes ), "At this point we must have some requests inserted into the slots." + start_time = time.perf_counter() + generate_params = self._generate_params[idx] + end_time = time.perf_counter() + + elapsed_time = (end_time - start_time) * 1e6 + + logging.info(f"AMANGU Log (orchestrator.py): Time taken to set generate_params=self._generate_params is {elapsed_time} Micro-seconds.") + + adapter_id = "base_params" + if new_request != None: + adapter_id = new_request.adapter_id + if adapter_id != "" and adapter_id not in generate_params: + logging.info(f"The adapter is not loaded into generate_params, so bypassing the processing of the request.") + continue + # Now we actually take a generate step on requests in the slots. decode_state, sampled_tokens = generate_engine.generate( - generate_params, decode_state + generate_params[adapter_id], decode_state ) sampled_tokens.copy_to_host_async() # Respond to detokenization backpressure. @@ -901,6 +934,63 @@ def _detokenize_thread(self, is_prefill: bool, idx: int): slot, active_request = data my_live_requests[slot] = active_request + def loadAndApplyAdapter( + self, + adapter_id, + adapter_config_path, + adapter_weights_path): + logging.info(f"Loading and applying fine-tuning adapter to base weights") + + for index, params in enumerate(self._prefill_params): + if adapter_id not in params: + params[adapter_id] = copy.deepcopy(params["base_params"]) + self._prefill_engines[index].load_and_apply_adapter(params[adapter_id], + adapter_config_path, + adapter_weights_path) + else: + logging.info(f"Adapter={adapter_id} is already present in the prefill_params.") + + for index, params in enumerate(self._generate_params): + if adapter_id not in params: + params[adapter_id] = copy.deepcopy(params["base_params"]) + self._generate_engines[index].load_and_apply_adapter(params[adapter_id], + adapter_config_path, + adapter_weights_path) + else: + logging.info(f"Adapter={adapter_id} is already present in the generate_params.") + + def unloadAdapter( + self, + adapter_id): + logging.info(f"Unloading the adapter with adapter_id={adapter_id}") + + for params in self._prefill_params: + if adapter_id in params: + del params[adapter_id] + logging.info(f"Successfully deleted Adapter={adapter_id} from the prefill_params.") + else: + logging.info(f"Adapter={adapter_id} is not there in the prefill_params.") + + for params in self._generate_params: + if adapter_id in params: + del params[adapter_id] + logging.info(f"Successfully deleted Adapter={adapter_id} from the generate_params.") + else: + logging.info(f"Adapter={adapter_id} is not there in the generate_params.") + + def mayBeListLoadedAdapters(self): + logging.info(f"Listing loaded adapters:") + + loaded_adapters_in_prefill = [] + for params in self._prefill_params: + loaded_adapters_in_prefill.extend(list(params.keys())) + logging.info(f"In 
prefill_params: {loaded_adapters_in_prefill}") + + loaded_adapters_in_generate = [] + for params in self._generate_params: + loaded_adapters_in_generate.extend(list(params.keys())) + logging.info(f"In generate_params: {loaded_adapters_in_generate}") + class LLMOrchestrator(jetstream_pb2_grpc.OrchestratorServicer): """Coordinates a set of prefill and generate slices for LLM decoding.""" @@ -1005,6 +1095,7 @@ async def Decode( # pylint: disable=invalid-overridden-method prefill_content=prefill_content, is_client_side_tokenization=is_client_side_tokenization, return_channel=return_channel, + adapter_id=request.adapter_id, metadata=ActiveRequestMetadata( start_time=request.metadata.start_time, prefill_enqueue_time=time.perf_counter(), diff --git a/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto index f06d89d5..70f501df 100644 --- a/jetstream/core/proto/jetstream.proto +++ b/jetstream/core/proto/jetstream.proto @@ -59,8 +59,10 @@ message DecodeRequest { Metadata metadata = 7; } + string adapter_id = 8; + reserved 1, 2, 3; - // Next ID: 8 + // Next ID: 9 } message DecodeResponse { @@ -91,4 +93,50 @@ message HealthCheckRequest {} message HealthCheckResponse { // Denotes whether the model server is live bool is_live = 1; -} \ No newline at end of file +} + +service MultiAdapterManager { + // Lists all the currently loaded LoRA adapters + rpc ListAdapters (ListAdaptersRequest) returns (ListAdaptersResponse) {} + + // Loads a new LoRA adapter. + rpc LoadAdapter (LoadAdapterRequest) returns (LoadAdapterResponse) {} + + // Unloads a LoRA adapter + rpc UnloadAdapter (UnloadAdapterRequest) returns (UnloadAdapterResponse) {} +} + +message ListAdaptersRequest {} + +message ListAdaptersResponse { + bool success = 1; // True if successful, False otherwise + string error_message = 2; // Error message if listing the adapters + repeated AdapterInfo adapter_infos = 3; // List of information about loaded adapters. +} + +// Information about a single loaded LoRA adapter +message AdapterInfo { + string adapter_id = 1; + int64 loading_cost = 2; +} + +message LoadAdapterRequest { + string adapter_id = 1; // Unique ID for the adapter + string adapter_config_path = 2; // Path to the LoRA adapter config + string adapter_weights_path = 3; // Path to the LoRA adapter weights (a directory with weights & config) +} + +message LoadAdapterResponse { + bool success = 1; // True if successful, false otherwise + string error_message = 2; // Error message if loading failed +} + +message UnloadAdapterRequest { + string adapter_id = 1; // ID of the adapter to unload +} + +message UnloadAdapterResponse { + bool success = 1; // True if successful, false otherwise + string error_message = 2; // Error message if unloading failed +} + diff --git a/jetstream/core/proto/jetstream_pb2.py b/jetstream/core/proto/jetstream_pb2.py index 0b146032..da19b97f 100644 --- a/jetstream/core/proto/jetstream_pb2.py +++ b/jetstream/core/proto/jetstream_pb2.py @@ -13,8 +13,9 @@ # limitations under the License. # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: jetstream/core/proto/jetstream.proto -# Protobuf Python Version: 4.25.1 +# NO CHECKED-IN PROTOBUF GENCODE +# source: jetstream.proto +# Protobuf Python Version: 5.29.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -25,37 +26,51 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xfc\x02\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 \x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3' -) + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fjetstream.proto\x12\x0fjetstream_proto\"\x90\x03\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 \x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x12\x12\n\nadapter_id\x18\x08 \x01(\t\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04\"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02\"\x14\n\x12HealthCheckRequest\"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\"\x15\n\x13ListAdaptersRequest\"s\n\x14ListAdaptersResponse\x12\x0f\n\x07success\x18\x01 
\x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\x12\x33\n\radapter_infos\x18\x03 \x03(\x0b\x32\x1c.jetstream_proto.AdapterInfo\"7\n\x0b\x41\x64\x61pterInfo\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0cloading_cost\x18\x02 \x01(\x03\"c\n\x12LoadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x1b\n\x13\x61\x64\x61pter_config_path\x18\x02 \x01(\t\x12\x1c\n\x14\x61\x64\x61pter_weights_path\x18\x03 \x01(\t\"=\n\x13LoadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\"*\n\x14UnloadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"?\n\x15UnloadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t2\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse\"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse\"\x00\x32\xb2\x02\n\x13MultiAdapterManager\x12]\n\x0cListAdapters\x12$.jetstream_proto.ListAdaptersRequest\x1a%.jetstream_proto.ListAdaptersResponse\"\x00\x12Z\n\x0bLoadAdapter\x12#.jetstream_proto.LoadAdapterRequest\x1a$.jetstream_proto.LoadAdapterResponse\"\x00\x12`\n\rUnloadAdapter\x12%.jetstream_proto.UnloadAdapterRequest\x1a&.jetstream_proto.UnloadAdapterResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages( - DESCRIPTOR, "jetstream.core.proto.jetstream_pb2", _globals -) -if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _globals["_DECODEREQUEST"]._serialized_start = 58 - _globals["_DECODEREQUEST"]._serialized_end = 438 - _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 294 - _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 321 - _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 323 - _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 356 - _globals["_DECODEREQUEST_METADATA"]._serialized_start = 358 - _globals["_DECODEREQUEST_METADATA"]._serialized_end = 388 - _globals["_DECODERESPONSE"]._serialized_start = 441 - _globals["_DECODERESPONSE"]._serialized_end = 772 - _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 607 - _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 623 - _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 626 - _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 755 - _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 714 - _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 755 - _globals["_HEALTHCHECKREQUEST"]._serialized_start = 774 - _globals["_HEALTHCHECKREQUEST"]._serialized_end = 794 - _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 796 - _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 834 - _globals["_ORCHESTRATOR"]._serialized_start = 837 - _globals["_ORCHESTRATOR"]._serialized_end = 1022 +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'jetstream_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_DECODEREQUEST']._serialized_start=37 + _globals['_DECODEREQUEST']._serialized_end=437 + _globals['_DECODEREQUEST_TEXTCONTENT']._serialized_start=293 + _globals['_DECODEREQUEST_TEXTCONTENT']._serialized_end=320 + _globals['_DECODEREQUEST_TOKENCONTENT']._serialized_start=322 + _globals['_DECODEREQUEST_TOKENCONTENT']._serialized_end=355 + _globals['_DECODEREQUEST_METADATA']._serialized_start=357 + 
_globals['_DECODEREQUEST_METADATA']._serialized_end=387 + _globals['_DECODERESPONSE']._serialized_start=440 + _globals['_DECODERESPONSE']._serialized_end=771 + _globals['_DECODERESPONSE_INITIALCONTENT']._serialized_start=606 + _globals['_DECODERESPONSE_INITIALCONTENT']._serialized_end=622 + _globals['_DECODERESPONSE_STREAMCONTENT']._serialized_start=625 + _globals['_DECODERESPONSE_STREAMCONTENT']._serialized_end=754 + _globals['_DECODERESPONSE_STREAMCONTENT_SAMPLE']._serialized_start=713 + _globals['_DECODERESPONSE_STREAMCONTENT_SAMPLE']._serialized_end=754 + _globals['_HEALTHCHECKREQUEST']._serialized_start=773 + _globals['_HEALTHCHECKREQUEST']._serialized_end=793 + _globals['_HEALTHCHECKRESPONSE']._serialized_start=795 + _globals['_HEALTHCHECKRESPONSE']._serialized_end=833 + _globals['_LISTADAPTERSREQUEST']._serialized_start=835 + _globals['_LISTADAPTERSREQUEST']._serialized_end=856 + _globals['_LISTADAPTERSRESPONSE']._serialized_start=858 + _globals['_LISTADAPTERSRESPONSE']._serialized_end=973 + _globals['_ADAPTERINFO']._serialized_start=975 + _globals['_ADAPTERINFO']._serialized_end=1030 + _globals['_LOADADAPTERREQUEST']._serialized_start=1032 + _globals['_LOADADAPTERREQUEST']._serialized_end=1131 + _globals['_LOADADAPTERRESPONSE']._serialized_start=1133 + _globals['_LOADADAPTERRESPONSE']._serialized_end=1194 + _globals['_UNLOADADAPTERREQUEST']._serialized_start=1196 + _globals['_UNLOADADAPTERREQUEST']._serialized_end=1238 + _globals['_UNLOADADAPTERRESPONSE']._serialized_start=1240 + _globals['_UNLOADADAPTERRESPONSE']._serialized_end=1303 + _globals['_ORCHESTRATOR']._serialized_start=1306 + _globals['_ORCHESTRATOR']._serialized_end=1491 + _globals['_MULTIADAPTERMANAGER']._serialized_start=1494 + _globals['_MULTIADAPTERMANAGER']._serialized_end=1800 # @@protoc_insertion_point(module_scope) diff --git a/jetstream/core/proto/jetstream_pb2_grpc.py b/jetstream/core/proto/jetstream_pb2_grpc.py index d571ade8..5de13a1e 100644 --- a/jetstream/core/proto/jetstream_pb2_grpc.py +++ b/jetstream/core/proto/jetstream_pb2_grpc.py @@ -135,3 +135,75 @@ def HealthCheck( timeout, metadata, ) + + +class MultiAdapterManagerStub(object): + """MultiAdapterManagerStub.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.ListAdapters = channel.unary_unary( + '/jetstream_proto.MultiAdapterManager/ListAdapters', + request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersRequest.SerializeToString, + response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersResponse.FromString, + _registered_method=True) + self.LoadAdapter = channel.unary_unary( + '/jetstream_proto.MultiAdapterManager/LoadAdapter', + request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterRequest.SerializeToString, + response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterResponse.FromString, + _registered_method=True) + self.UnloadAdapter = channel.unary_unary( + '/jetstream_proto.MultiAdapterManager/UnloadAdapter', + request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterRequest.SerializeToString, + response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterResponse.FromString, + _registered_method=True) + + +class MultiAdapterManagerServicer(object): + """TODO: Merge this with main JetStream core once we settle on an API.""" + + def ListAdapters(self, request, context): + """Lists all the currently loaded LoRA adapters.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def LoadAdapter(self, request, context): + """Check the feasibility and load the new LoRA adapter.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def UnloadAdapter(self, request, context): + """Unload a LoRA adapter.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_MultiAdapterManagerServicer_to_server(servicer, server): + rpc_method_handlers = { + "ListAdapters": grpc.unary_unary_rpc_method_handler( + servicer.ListAdapters, + request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersRequest.FromString, + response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersResponse.SerializeToString, + ), + "LoadAdapter": grpc.unary_unary_rpc_method_handler( + servicer.LoadAdapter, + request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterRequest.FromString, + response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterResponse.SerializeToString, + ), + "UnloadAdapter": grpc.unary_unary_rpc_method_handler( + servicer.UnloadAdapter, + request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterRequest.FromString, + response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + "jetstream_proto.MultiAdapterManager", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index f9b95aef..6196bc83 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -32,6 +32,7 @@ import jax from jetstream.core import config_lib from jetstream.core import orchestrator +from jetstream.core import adapter_manager from jetstream.core.metrics.prometheus import JetstreamMetricsCollector from jetstream.core.proto import jetstream_pb2_grpc from jetstream.engine import warmup_utils, 
engine_api @@ -64,6 +65,11 @@ async def do_init(): jetstream_pb2_grpc.add_OrchestratorServicer_to_server( orchestrator.LLMOrchestrator(driver=self._driver), self._grpc_server ) + + jetstream_pb2_grpc.add_MultiAdapterManagerServicer_to_server( + adapter_manager.MultiLoraManager(driver=self._driver), self._grpc_server + ) + self._grpc_server.add_secure_port(f"{_HOST}:{port}", credentials) async def _async_start(self) -> None: @@ -113,9 +119,9 @@ def create_driver( An orchestrator driver. """ engines = config_lib.get_engines(config, devices=devices) - prefill_params = [pe.load_params() for pe in engines.prefill_engines] - generate_params = [ge.load_params() for ge in engines.generate_engines] - shared_params = [ie.load_params() for ie in engines.interleaved_engines] + prefill_params = [{"base_params": pe.load_params()} for pe in engines.prefill_engines] + generate_params = [{"base_params": ge.load_params()} for ge in engines.generate_engines] + shared_params = [{"base_params": ie.load_params()} for ie in engines.interleaved_engines] logging.info("Loaded all weights.") interleaved_mode = ( len(config.prefill_slices) + len(config.generate_slices) == 0 diff --git a/jetstream/tools/llm_gateway_proxy_client.py b/jetstream/tools/llm_gateway_proxy_client.py new file mode 100644 index 00000000..33a6e379 --- /dev/null +++ b/jetstream/tools/llm_gateway_proxy_client.py @@ -0,0 +1,145 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A test request.""" + +from typing import Sequence + +from absl import app +from absl import flags +import grpc +from jetstream.core.proto import jetstream_pb2 +from jetstream.core.proto import jetstream_pb2_grpc +from jetstream.engine.token_utils import load_vocab + + +_SERVER = flags.DEFINE_string("server", "0.0.0.0", "server address") +_PORT = flags.DEFINE_string("port", "9000", "port to ping") +#_TEXT = flags.DEFINE_string("text", "My dog is cute", "The message") +_TEXT = flags.DEFINE_string("text", "22 year old", "The message") +_MAX_TOKENS = flags.DEFINE_integer( + "max_tokens", 3, "Maximum number of output/decode tokens of a sequence" +) + +_ADAPTER_ID = flags.DEFINE_string( + "adapter_id", + None, + "Id of the fine-tuned adapter to be loaded on top of the base model.", + required=False, +) + +_ADAPTER_CONFIG_PATH = flags.DEFINE_string( + "adapter_config_path", + None, + "Path of the fine-tuned adapter to be loaded from.", + required=False, +) + +_ADAPTER_WEIGHTS_PATH = flags.DEFINE_string( + "adapter_weights_path", + None, + "Path of the fine-tuned adapter to be loaded from.", + required=False, +) + +_TEST_API_NAME = flags.DEFINE_string( + "test_api_name", + None, + "Name of the JetStream API to call.", + required=True, +) + + +def main(argv: Sequence[str]) -> None: + del argv + # Note: Uses insecure_channel only for local testing. Please add grpc + # credentials for Production. 
+ address = f"{_SERVER.value}:{_PORT.value}" + with grpc.insecure_channel(address) as channel: + grpc.channel_ready_future(channel).result() + stub = jetstream_pb2_grpc.MultiAdapterManagerStub(channel) + print(f"Sending request to: {address}") + + if _TEST_API_NAME.value == "load_adapter": + print(f"Calling the JetStream/MultiAdapterManager/LoadAdapter.") + + adapter_id=_ADAPTER_ID.value + adapter_config_path=_ADAPTER_CONFIG_PATH.value + adapter_weights_path=_ADAPTER_WEIGHTS_PATH.value + + if adapter_id == None or adapter_weights_path == None or adapter_config_path == None: + print(f"For `load_adapter` API call, `adapter_id`, `adapter_config_path` and `adapter_weights_path` must be passed.") + return + + request = jetstream_pb2.LoadAdapterRequest( + adapter_id=adapter_id, + adapter_config_path=adapter_config_path, + adapter_weights_path=adapter_weights_path + ) + + response = stub.LoadAdapter(request) + + if response.success is True: + print(f"Adapter={adapter_id} is loaded successfully.") + else: + print(f"Adapter={adapter_id} loading failed with error={response.error_message}") + + elif _TEST_API_NAME.value == "unload_adapter": + print(f"Calling the JetStream/MultiAdapterManager/UnloadAdapter.") + + adapter_id=_ADAPTER_ID.value + + if adapter_id == None: + print(f"For `unload_adapter` API call, `adapter_id` must be passed.") + return + + request = jetstream_pb2.UnloadAdapterRequest( + adapter_id=adapter_id, + ) + + response = stub.UnloadAdapter(request) + + if response.success is True: + print(f"Adapter={adapter_id} is unloaded successfully.") + else: + print(f"Adapter={adapter_id} unloading failed with error={response.error_message}") + + elif _TEST_API_NAME.value == "list_adapters": + print(f"Calling the JetStream/MultiAdapterManager/ListAdapters.") + + request = jetstream_pb2.ListAdaptersRequest() + + response = stub.ListAdapters(request) + + if response.success is True: + print(f"`ListAdapter` call responded successfully. Here is the list of adapters loaded on server:") + for adapter_info in response.adapter_infos: + print(f"adapter_id={adapter_info.adapter_id}, loading_cost={adapter_info.loading_cost}.") + else: + print(f"`ListAdapter` call failed with error={error_message}") + + elif _TEST_API_NAME.value == None: + print(f"`test_api_name` flag is not set. So exiting.") + return + + else: + print(f"API={_TEST_API_NAME.value} is not implemented yet. So exiting.") + return + + + print(f"API calls ended.") + + +if __name__ == "__main__": + app.run(main) From fb88eca09152e35026bd981e4e725d51a256fdad Mon Sep 17 00:00:00 2001 From: aman2930 Date: Tue, 18 Feb 2025 18:15:22 +0000 Subject: [PATCH 05/22] 1) Implemented adapter_tensorstore module to store and manage the adapters. Its functionality includes loading, unloading of adapters between CPU RAM and HBM. It also follows LRU policy to evict the adapter if a new load_adapter request comes up. Currently it is only storing the adapter as separate tensors (lora_a and lora_b). Calculation of lora_b x lora_a is being done in prefill() and generate() during decode request. Adapter_tensorstore can be configured with a max_limit on HBM and RAM. 2) Functionality to load from a catalog file at the start of the server is added. If no file is given, it will just load the base params. Loading from the catalog file is done on CPU RAM. After that based on incoming requests, those params are moved/evicted to/from HBM. 
3) Some proto updates to get only single path for each adapter, and that path is expected to have an adapter_config.json and Orbax format weights in 0/items folder. --- jetstream/core/adapter_manager.py | 59 +-- jetstream/core/adapter_tensorstore.py | 429 ++++++++++++++++++ jetstream/core/orchestrator.py | 200 ++++++-- jetstream/core/proto/jetstream.proto | 7 +- jetstream/core/proto/jetstream_pb2.py | 31 +- .../core/proto/jetstream_pb2_grpc_original.py | 209 +++++++++ jetstream/core/server_lib.py | 3 + jetstream/tools/decode_multi_requester.py | 337 ++++++++++++++ jetstream/tools/llm_gateway_proxy_client.py | 25 +- jetstream/tools/requester.py | 6 + 10 files changed, 1208 insertions(+), 98 deletions(-) create mode 100644 jetstream/core/adapter_tensorstore.py create mode 100644 jetstream/core/proto/jetstream_pb2_grpc_original.py create mode 100644 jetstream/tools/decode_multi_requester.py diff --git a/jetstream/core/adapter_manager.py b/jetstream/core/adapter_manager.py index 358fc0e2..0ca76dfe 100644 --- a/jetstream/core/adapter_manager.py +++ b/jetstream/core/adapter_manager.py @@ -19,10 +19,10 @@ import grpc from typing import Optional - +from jetstream.core import adapter_tensorstore +from jetstream.core import orchestrator from jetstream.core.proto import jetstream_pb2_grpc from jetstream.core.proto import jetstream_pb2 -from jetstream.core import orchestrator def calculate_loading_cost(adapter_path: str): @@ -36,7 +36,6 @@ class MultiLoraManager(jetstream_pb2_grpc.MultiAdapterManagerServicer): def __init__(self, driver: orchestrator.Driver): self._driver = driver - self.loaded_adapters = {} # Dictionary to track loaded adapters def ListAdapters( self, @@ -46,25 +45,32 @@ def ListAdapters( """ListAdapters all loaded LoRA adapters.""" try: - logging.info("AMANGU LOG: Before making call to mayBeListLoadedAdapters.") - self._driver.mayBeListLoadedAdapters() - logging.info("AMANGU LOG: After making call to mayBeListLoadedAdapters.") + adapters = self._driver.listAdaptersFromTensorstore() adapter_infos = [] - for adapter_id, adapter_data in self.loaded_adapters.items(): + for adapter_id, adapter_data in adapters.items(): + if adapter_data.status == "loaded_hbm": + loading_cost = 0 + elif adapter_data.status == "loaded_cpu": + loading_cost = 1 + elif adapter_data.status == "unloaded": + loading_cost = 2 + else: + loading_cost = -1 + adapter_info = jetstream_pb2.AdapterInfo( - adapter_id=adapter_id, - loading_cost=adapter_data["loading_cost"] - ) - adapter_infos.append(adapter_info) + adapter_id=adapter_id, + loading_cost=loading_cost, + size_hbm=adapter_data.size_hbm, + size_cpu=adapter_data.size_cpu, + last_accessed=adapter_data.last_accessed, + status=adapter_data.status) - # logging.info("AMANGU Log (adapter_manager.py): ListAdapters is still under implementation") - logging.info("AMANGU LOG: List adapters --> Before returning success.") - logging.info(f"AMANGU LOG: List of adapters --> {adapter_infos}.") + adapter_infos.append(adapter_info) return jetstream_pb2.ListAdaptersResponse(success=True, adapter_infos=adapter_infos) except Exception as e: - logging.info("AMANGU LOG: List adapters --> Before returning failure.") + logging.info(f"Listing of adapters failed with error: {str(e)}") return jetstream_pb2.ListAdaptersResponse(success=False, error_message=str(e)) @@ -76,21 +82,11 @@ def LoadAdapter( """Load a LoRA adapter as mentioned in the request.""" try: - # Load the adapter using MaxEngine in the Driver - # Implmentation to load adatper using MaxEnbine and request.adapter_path - - # 
Store adapter info (e.g. loading cost - self._driver.loadAndApplyAdapter(request.adapter_id, - request.adapter_config_path, - request.adapter_weights_path) - - self.loaded_adapters[request.adapter_id] = { - "adapter_path": request.adapter_weights_path, - "loading_cost": calculate_loading_cost(request.adapter_weights_path) - } + self._driver.loadAdapterToTensorstore(request.adapter_id, request.adapter_path) return jetstream_pb2.LoadAdapterResponse(success=True) except Exception as e: + logging.info(f"Loading of adapter_id={request.adapter_id} failed with error: {str(e)}") return jetstream_pb2.LoadAdapterResponse(success=False, error_message=str(e)) @@ -101,16 +97,11 @@ def UnloadAdapter( ) -> jetstream_pb2.UnloadAdapterResponse: """Unload a LoRA adapter as mentioned in the request.""" - # logging.info("AMANGU Log (adapter_manager.py): UnloadAdapter is still under implementation") try: - # Unload the adapter - # Implementation to unload adapter from MaxEngine - self._driver.unloadAdapter(request.adapter_id) - - del self.loaded_adapters[request.adapter_id] + self._driver.unloadAdapterFromTensorstore(request.adapter_id) return jetstream_pb2.UnloadAdapterResponse(success=True) except Exception as e: - logging.info(f"AMANGU Log(adapter_manager.py): UnloadAdapter failed with error {str(e)}") + logging.info(f"Loading of adapter_id={request.adapter_id} failed with error: {str(e)}") return jetstream_pb2.UnloadAdapterResponse(success=False, error_message=str(e)) diff --git a/jetstream/core/adapter_tensorstore.py b/jetstream/core/adapter_tensorstore.py new file mode 100644 index 00000000..169a02b7 --- /dev/null +++ b/jetstream/core/adapter_tensorstore.py @@ -0,0 +1,429 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Manages the list of fine-tuned adapters loaded on top of the base model for serving. 
+""" + +import logging +import dataclasses + +import jax +import jax.numpy as jnp +from flax import struct +import time +import asyncio +import functools +from typing import Dict, Optional, Any +import numpy as np + + +@dataclasses.dataclass +class AdapterMetadata: + adapter_id: str + adapter_path: str + status: str = "unloaded" # "loaded_hbm", "loaded_cpu", "loading", "unloading" + size_hbm: int = 0 # Size in HBM (bytes) + size_cpu: int = 0 # Size in CPU RAM (bytes) + last_accessed: float = 0.0 # timestamp + # rank: int = 8 + config: Dict[str, Any] = None + + +class AdapterTensorStore: + def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int): + self.hbm_memory_budget = hbm_memory_budget + self.cpu_memory_budget = cpu_memory_budget + self.adapter_registry: Dict[str, AdapterMetadata] = {} # All known adapters + self.loaded_adapters_hbm: Dict[str, jnp.ndarray] = {} # adapter_id -> Unified LoRA params (in HBM) + self.loaded_adapters_cpu: Dict[str, np.ndarray] = {} # adapter_id -> Unified LoRA params (in CPU RAM) + self.current_hbm_usage: int = 0 + self.current_cpu_usage: int = 0 + self.running_requests: int = 0 # Number of async tasks which are in "loading" state + self.lock = asyncio.Lock() # Use an asyncio Lock for thread safety + + + def register_adapter(self, adapter_id: str, adapter_path: str, config: Dict[str, Any]): + """Registers a new LoRA adatper.""" + if adapter_id in self.adapter_registry: + raise ValueError(f"Adapter with ID '{adapter_id}' already registered.") + self.adapter_registry[adapter_id] = AdapterMetadata( + adapter_id=adapter_id, + adapter_path=adapter_path, + config=config) + + + def _get_size(self, arr: jnp.ndarray | np.ndarray) -> int: + """Calculates the size of a JAX or NumPy array in bytes.""" + # Use asarray to handle both JAX and NumPy arrays consistently + return np.asarray(arr).nbytes + + async def _transfer_to_hbm(self, adapter_id: str): + """Transfers an adapter from CPU RAM to HBM.""" + if adapter_id not in self.loaded_adapters_cpu: + raise ValueError(f"Adapter '{adapter_id}' not loaded in CPU RAM.") + + async with self.lock: #Acquire lock + metadata = self.adapter_registry[adapter_id] + + # Check if we have enough space in HBM; evict if necessary + while (self.current_hbm_usage + metadata.size_hbm) > self.hbm_memory_budget: + if not self._evict(from_hbm=True): + raise RuntimeError("Not enough HBM to transfer adapter, and eviction failed.") + + # Move from CPU to HBM + self.loaded_adapters_hbm[adapter_id] = jnp.array(self.loaded_adapters_cpu[adapter_id]) # Convert to JAX array + del self.loaded_adapters_cpu[adapter_id] + + self.current_cpu_usage -= metadata.size_cpu + self.current_hbm_usage += metadata.size_hbm + + metadata.status = "loaded_hbm" + metadata.last_accessed = time.time() + + + async def _transfer_to_cpu(self, adapter_id: str): + """Transfers an adapter from HBM to CPU RAM.""" + + if adapter_id not in self.loaded_adapters_hbm: + raise ValueError(f"Adapter '{adapter_id}' not loaded in HBM.") + + async with self.lock: + metadata = self. adapter_registry[adapter_id] + + # Check if we have enough space in CPU; evict if necessary. 
+ while (self.current_cpu_usage + metadata.size_cpu) > self.cpu_memory_budget: + if not self._evict(from_hbm=False): + raise RuntimeError("Not enough CPU RAM to transfer adapter, and eviction failed.") + + # Move from HBM to CPU + self.loaded_adapters_cpu[adapter_id] = np.array(self.loaded_adapters_hbm[adapter_id]) + del self.loaded_adapters_hbm[adapter_id] + + self.current_hbm_usage -= metadata.size_hbm + self.current_cpu_usage += metadata.size_cpu + + metadata.status = "loaded_cpu" + metadata.last_accessed = time.time() + + + def _get_size_of_pytree(self, params): + params_bytes = jax.tree_util.tree_map(lambda x: x.nbytes, params) + total_bytes = jax.tree_util.tree_reduce(lambda x, y: x + y, params_bytes) + return total_bytes + + + def _as_np_array(self, params): + + def convert_if_jnp(leaf): + return np.array(leaf) + + return jax.tree_util.tree_map(convert_if_jnp, params) + + + def _as_jnp_array(self, params): + + def convert_if_np(leaf): + return jnp.array(leaf) + + return jax.tree_util.tree_map(convert_if_np, params) + + + async def load_adapter(self, adapter_id: str, adapter_weights = None, to_hbm: bool = True): + """Loads a LoRA adapter's weights, managing HBM and CPU memory.""" + if adapter_id not in self.adapter_registry: + raise ValueError(f"Adapter with ID '{adapter_id}' not registered.") + + metadata = self.adapter_registry[adapter_id] + + async with self.lock: # Acquire lock for thread safety + #logging.info(f"AMANGU Logs: Lock aquired by loading section of coroutine {asyncio.current_task().get_name()}.") + if metadata.status in ("loaded_hbm", "loaded_cpu"): + metadata.last_accessed = time.time() + + # if already loaded in HBM and we want HBM, or + # already loaded in CPU and we want CPU, we're done. + if ((to_hbm and metadata.status == "loaded_hbm") or + not to_hbm and metadata.status == "loaded_cpu"): + return + elif to_hbm and metadata.status == "loaded_cpu": + # Transfer from cpu to hbm + self._transfer_to_hbm(adapter_id) + return + elif not to_hbm and metadata.status == "loaded_hbm": + # Transfer from hbm to cpu + self._transfer_to_cpu(adapter_id) + return + + if metadata.status == "loading": + # Wait untill loading is done. + while metadata.status == "loading": + await asyncio.sleep(0.1) # Short sleep to avoid busy-waiting + + # Make recursive call to load_adapter to copy to device + await self.load_adapter(adapter_id, adapter_weights, to_hbm) + return + + metadata.status = "loading" + self.running_requests += 1 + #logging.info(f"AMANGU Logs: Lock released by loading section of coroutine {asyncio.current_task().get_name()}.") + + # Load the adapter (asynchronous) + loop = asyncio.get_running_loop() + try: + + # TODO(amangu): Placeholder for the loading logic. Replace with code to load + # the LoRA weights from the specific path. + + # --- ASYNCHRONOUS LOADING (CRITICAL!) --- + # Use asyncio.to_thread or similar to avoid blocking + + # TODO(amangu): Assumed that load_lora_weights is defined elsewhere + # which returns a dictionary: {"lora_A": ..., "lora_B": ...}. Adapt this part + # based on the actual structure of the loaded LoRA weights. + + if adapter_weights is None: + adapter_weights = await loop.run_in_executor( + None, + functools.partial(load_lora_weights, metadata.adapter_path)) + + async with self.lock: # Critical section for memory management + # Combine lora_a and lora_b to form a unified parameter. + # TODO(amangu): Check if combining and storing is having any optimization. 
+ # unified_lora_params = self._combine_lora_params(lora_weights, metadata.rank) + #logging.info(f"AMANGU Logs: Lock aquired by saving section of coroutine {asyncio.current_task().get_name()}.") + + unified_lora_params = adapter_weights + unified_lora_params_as_jnp_array = self._as_jnp_array(unified_lora_params) + unified_lora_params_as_np_array = self._as_np_array(unified_lora_params) + del unified_lora_params + + # Get size of unified_lora_params when they are saved in HBM as JAX array + adapter_size_hbm = self._get_size_of_pytree(unified_lora_params_as_jnp_array) + + # Get size of unified_lora_params when they are saved in CPU RAM as NumPy array + adapter_size_cpu = self._get_size_of_pytree(unified_lora_params_as_np_array) + + metadata.size_hbm = adapter_size_hbm + metadata.size_cpu = adapter_size_cpu + + # --- EVICTION (if needed) --- + # Evict if necessary *before* loading into the target memory + if to_hbm: + while (self.current_hbm_usage + adapter_size_hbm) > self.hbm_memory_budget: + if not self._evict(from_hbm=True): + raise RuntimeError("Not enough HBM to load adapter, and eviction failed.") + else: #to_cpu + while (self.current_cpu_usage + adapter_size_cpu) > self.cpu_memory_budget: + if not self._evict(from_hbm=False): + raise RuntimeError("Not enough CPU RAM to load adapter, and eviction failed.") + + # Now that we have space (potentially), do the actual loading + if to_hbm: + self.loaded_adapters_hbm[adapter_id] = unified_lora_params_as_jnp_array # Convert the PyTree to Jax Array + self.current_hbm_usage += adapter_size_hbm + metadata.status = "loaded_hbm" + + else: #to cpu + self.loaded_adapters_cpu[adapter_id] = unified_lora_params_as_np_array # Convert the PyTree to NumPy Array + self.current_cpu_usage += adapter_size_cpu + metadata.status = "loaded_cpu" + + metadata.last_accessed = time.time() + #logging.info(f"AMANGU Logs: Lock released by saving section of coroutine {asyncio.current_task().get_name()}.") + + except Exception as e: + async with self.lock: + metadata.status = "unloaded" # Mark as unloaded on error + raise e # Re-Raise the exception + finally: + async with self.lock: + self.running_requests -= 1 + + + def _combine_lora_params(self, lora_weights, rank): + # Create a list to hold the combined LoRA parameters + combined_lora_params = [] + + for i in range(0, len(lora_weights), 2): + lora_a = lora_weights[i] + lora_b = lora_weights[i+1] + + # Reshape and concatenate lora_a and lora_b + # Assuming 'br,rnd->bnd' einsum configuration, where 'b' is batch, + # 'r' is rank, 'n' is num_heads, and 'd' is head_dim + num_heads = lora_a.shape[1] # Get number of heads from lora_a + head_dim = lora_a.shape[2] # Get head dimension from lora_a + + lora_a = jnp.transpose(lora_a, (1, 2, 0)) # (r, n, d) -> (n, d, r) + lora_b_reshaped = jnp.reshape(lora_b, (num_heads, head_dim, rank)) # (n * d, r) -> (n, d, r) + + combined_lora_param = jnp.einsum("ndr,ndr->ndr", lora_a, lora_b_reshaped) + combined_lora_params.append(combined_lora_param) + + # Concatenate the parameters for all layers to form a single unified parameter + unified_lora_params = jnp.stack(combined_lora_params, axis=0) + return unified_lora_params + + + def get_stacked_lora_weights(self, lora_ids: jnp.ndarray, to_hbm: bool = True): + """Retrieves the unified LoRA parameters for the given adapter IDs. + Handles HBM/CPU placement. + """ + + # The logic here is crucial. We have `lora_ids`, an array of shape + #(batch_size,), where each element is the ID of the LoRA adapter + # to use for that request in the batch. 
+    # You need to use this to select the appropriate slices from the
+    # *unified* LoRA parameters.
+
+    # 1. Get the unified LoRA parameters for the requested IDs. This
+    # might involve waiting if some adapters are still loading.
+
+    required_adapters = set(lora_ids.tolist())  # Get unique adapter IDs
+    for adapter_id in required_adapters:
+      metadata = self.adapter_registry.get(adapter_id)
+
+      if metadata is None:
+        raise ValueError(f"Adapter with ID '{adapter_id}' not registered.")
+
+      if metadata.status != "loaded_hbm" and metadata.status != "loaded_cpu":
+        asyncio.run(self.load_adapter(adapter_id, to_hbm)) # Start loading (async)
+      elif to_hbm and metadata.status == "loaded_cpu":
+        self._transfer_to_hbm(adapter_id)
+      elif not to_hbm and metadata.status == "loaded_hbm":
+        self._transfer_to_cpu(adapter_id)
+
+    # Wait till all the running requests are completed
+    while self.running_requests > 0:
+      time.sleep(0.1)
+
+    # Now all required adapters should be loaded in the correct memory (HBM or CPU); fetch them
+    if to_hbm:
+      required_adapters_params = [self.loaded_adapters_hbm[adapter_id] for adapter_id in required_adapters]
+    else:
+      required_adapters_params = [self.loaded_adapters_cpu[adapter_id] for adapter_id in required_adapters]
+
+    # Stack the parameters for the required adapters
+    stacked_params = jax.tree_util.tree_map(lambda *arrs: jnp.stack(arrs), *required_adapters_params)
+
+    # Extract parameters for the lora_ids using jnp.take().
+    retrieved_lora_params = jax.tree_util.tree_map(
+        lambda arr: jnp.take(arr, lora_ids, axis=0, fill_value=0),
+        stacked_params)
+
+    return retrieved_lora_params
+
+
+  def get_lora_config(self, adapter_id):
+    metadata = self.adapter_registry.get(adapter_id)
+    return metadata.config
+
+
+  def get_lora_weights(self, adapter_id, to_hbm: bool = True):
+    """Retrieves the unified LoRA parameters for the given adapter ID.
+    Handles HBM/CPU placement.
+    """
+
+    # Get the unified LoRA parameters for the requested adapter ID. This may
+    # involve waiting if the adapter is still loading, or transferring it
+    # between HBM and CPU RAM so it ends up in the requested memory tier.
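+    # Typical call site (see Driver._prefill_thread in orchestrator.py):
+    #   lora_params = self._adapter_tensorstore.get_lora_weights(adapter_id)
+    #   lora_config = self._adapter_tensorstore.get_lora_config(adapter_id)
+    #   engine.apply_adapter(final_params, lora_config, lora_params)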
+
+    metadata = self.adapter_registry.get(adapter_id)
+
+    if metadata is None:
+      raise ValueError(f"Adapter with ID '{adapter_id}' not registered.")
+
+    if metadata.status != "loaded_hbm" and metadata.status != "loaded_cpu":
+      asyncio.run(self.load_adapter(adapter_id, None, to_hbm)) # Start loading (async)
+    elif to_hbm and metadata.status == "loaded_cpu":
+      self._transfer_to_hbm(adapter_id)
+    elif not to_hbm and metadata.status == "loaded_hbm":
+      self._transfer_to_cpu(adapter_id)
+
+    # Wait till all the running requests are completed
+    while self.running_requests > 0:
+      time.sleep(0.1)
+
+    # Now the adapter should be loaded in the correct memory (HBM or CPU); fetch it
+    adapter_params = None
+    if to_hbm:
+      adapter_params = self.loaded_adapters_hbm[adapter_id]
+    else:
+      adapter_params = self.loaded_adapters_cpu[adapter_id]
+
+    return adapter_params
+
+
+  async def unload_adapter(self, adapter_id: str):
+    """Unloads a LoRA adapter's weights and removes it from the TensorStore."""
+    if adapter_id not in self.adapter_registry:
+      raise ValueError(f"Adapter with ID '{adapter_id}' not found.")
+
+    metadata = self.adapter_registry[adapter_id]
+
+    async with self.lock:
+      if metadata.status == "unloaded":
+        return  # Already unloaded
+      if metadata.status == "loading":
+        # Wait for the loading to complete.
+        while metadata.status == "loading":
+          await asyncio.sleep(0.1)
+      if metadata.status == "loaded_hbm":
+        del self.loaded_adapters_hbm[adapter_id]
+        self.current_hbm_usage -= metadata.size_hbm
+        metadata.status = "unloaded"
+      elif metadata.status == "loaded_cpu":
+        del self.loaded_adapters_cpu[adapter_id]
+        self.current_cpu_usage -= metadata.size_cpu
+        metadata.status = "unloaded"
+
+      metadata.last_accessed = 0.0  # Reset last accessed time
+      metadata.size_hbm = 0
+      metadata.size_cpu = 0
+
+
+  def list_adapters(self) -> Dict[str, AdapterMetadata]:
+    """Lists all registered adapters and their metadata."""
+    return self.adapter_registry
+
+
+  def _evict(self, from_hbm: bool = True) -> bool:
+    """Evicts the least recently used adapter from memory (HBM or CPU)."""
+
+    # Find the least recently used adapter that is currently loaded.
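+    # The scan below walks the whole registry and picks the adapter with the
+    # oldest last_accessed timestamp among those resident in the requested
+    # tier (HBM when from_hbm is True, CPU RAM otherwise).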
+ lru_adapter_id = None + lru_time = float('inf') + + if from_hbm: + adapters_dict = self.loaded_adapters_hbm + else: + adapters_dict = self.loaded_adapters_cpu + + for adapter_id, metadata in self.adapter_registry.items(): + if metadata.status == "loaded_hbm" if from_hbm else metadata.status == "loaded_cpu": + if metadata.last_accessed < lru_time: + lru_time = metadata.last_accessed + lru_adapter_id = adapter_id + + # If no adapter found to evict, return False + if lru_adapter_id is None: + return False + + # Unload the LRU adapter + self.unload_adapter(lru_adapter_id) # This is not synchronous, but ONLY within the lock + return True + diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 425861ef..5e70885b 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -86,10 +86,13 @@ import threading import time import traceback +import asyncio from typing import Any, AsyncIterator, Optional, Tuple, cast import grpc import jax +import jax.numpy as jnp +from jetstream.core import adapter_tensorstore from jetstream.core.proto import jetstream_pb2 from jetstream.core.proto import jetstream_pb2_grpc from jetstream.core.utils import async_multifuture @@ -228,6 +231,8 @@ class Driver: # All metrics we want to monitor should be collected with this _metrics_collector: JetstreamMetricsCollector | None = None + _adapter_tensorstore: adapter_tensorstore.AdapterTensorStore | None = None + def __init__( self, prefill_engines: Optional[list[engine_api.Engine]] = None, @@ -248,6 +253,10 @@ def __init__( if generate_params is None: generate_params = [] + self._adapter_tensorstore = adapter_tensorstore.AdapterTensorStore( + hbm_memory_budget=(20 * (1024 ** 3)), + cpu_memory_budget=(100 * (1024 ** 3))) + logging.info( "Initialising driver with %d prefill engines and %d generate engines.", len(prefill_engines), @@ -423,6 +432,7 @@ def __init__( ) self.live = True self._is_ray_backend = is_ray_backend + # Start all threads for t in self._all_threads: t.start() @@ -442,6 +452,7 @@ def stop(self): ) ) + while any(t.is_alive() for t in self._all_threads): # Empty all backlogs and mark any remaining requests as cancelled. for q in all_backlogs: @@ -527,18 +538,7 @@ def _prefill_thread(self, idx: int): if request is None: break - start_time = time.perf_counter() prefill_params = self._prefill_params[idx] - end_time = time.perf_counter() - - elapsed_time = (end_time - start_time) * 1e6 - - logging.info(f"AMANGU Log (orchestrator.py): Time taken to set prefill_params=self._prefill_params is {elapsed_time} Micro-seconds.") - - adapter_id = request.adapter_id - if adapter_id != "" and adapter_id not in prefill_params: - logging.info(f"The adapter is not loaded into prefill_params, so bypassing the processing of the request.") - continue request.metadata.prefill_dequeue_time = time.perf_counter() is_bos = True @@ -549,17 +549,45 @@ def _prefill_thread(self, idx: int): self._prefill_backlog.qsize(), is_bos, ) + # Tokenize and padding the text or token input. 
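+      # (_process_prefill_content returns the padded token sequence together
+      # with the true, unpadded token length.)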
padded_tokens, true_length = self._process_prefill_content( request, tokenizer, is_bos, prefill_engine.max_prefill_length ) + start_time = time.perf_counter() + logging.info(f"AMANGU Log (orchestrator.py): Starting timer for Driver._prefill_thread -> prefill_engine.prefill().") + + adapter_id = request.adapter_id + + if adapter_id == "": + adapter_id = "base_params" + + final_params = None + if adapter_id == "base_params": + final_params = prefill_params[adapter_id] + else: + final_params = copy.deepcopy(prefill_params["base_params"]) + lora_params = self._adapter_tensorstore.get_lora_weights(adapter_id) + lora_config = self._adapter_tensorstore.get_lora_config(adapter_id) + self._prefill_engines[idx].apply_adapter( + final_params, + lora_config, + lora_params) + # Compute new kv cache for the prefill_content. prefill_result, first_token = prefill_engine.prefill( - params=prefill_params[adapter_id], + params=final_params, padded_tokens=padded_tokens, true_length=true_length, ) + + del final_params + end_time = time.perf_counter() + elapsed_time = (end_time - start_time) * 1e6 + + logging.info(f"AMANGU Log (orchestrator.py): Time taken for Driver._prefill_thread -> prefill_engine.prefill() is {elapsed_time} Micro-seconds.") + request.prefill_result = prefill_result # put first token to detokenize queue @@ -571,6 +599,9 @@ def _prefill_thread(self, idx: int): block=True, ) + elapsed_time = request.metadata.transfer_enqueue_time - request.metadata.prefill_dequeue_time + logging.info(f"AMANGU Log (orchestrator.py): Time taken in whole prefill_thread is {elapsed_time} Seconds.") + # Once prefill is complete, place it on the generation queue and block if # full. my_transfer_backlog.put(request, block=True) @@ -597,7 +628,6 @@ def _prefill_thread(self, idx: int): def _jax_transfer_prefill_result( self, new_request: ActiveRequest, target_idx: int ): - logging.info("AMANGU: In _jax_transfer_prefill_result") new_request.prefill_result = jax.device_put( new_request.prefill_result, self._generate_engines[target_idx].get_prefix_destination_sharding(), @@ -613,7 +643,6 @@ def _ray_transfer_prefill_result( def _transfer_prefill_result( self, new_request: ActiveRequest, target_idx: int ): - logging.info("AMANGU: In _transfer_prefill_result") if self._is_ray_backend: self._ray_transfer_prefill_result(new_request, target_idx) else: @@ -623,7 +652,6 @@ def _transfer_thread(self, idx: int): """Transfers the kv cache on an active request to the least full generate backlog.""" transfer_backlog = self._transfer_backlogs[idx] - logging.info("AMANGU: In _transfer_thread") while self.live: # The transfer thread can just sleep until it has work to do. @@ -648,6 +676,8 @@ def _transfer_thread(self, idx: int): # Place the request on the correct generate backlog and block if full. 
new_request.metadata.generate_enqueue_time = time.perf_counter() self._generate_backlogs[target_idx].put(new_request, block=True) + + elapsed_time = (new_request.metadata.generate_enqueue_time - new_request.metadata.transfer_dequeue_time) * 1e6 logging.info( "Successfully transferred prefill " "from prefill engine %d to generate engine %d " @@ -660,18 +690,22 @@ def _transfer_thread(self, idx: int): def _generate_thread(self, idx: int): """Step token generation and insert prefills from backlog.""" logging.info("---------Spinning up generate thread %d.---------", idx) - logging.info("AMANGU: In _generate_thread") generate_engine = self._generate_engines[idx] my_slots = self._generate_slots[idx] + logging.info(f"AMANGU: In _generate_thread: my_slots size = {my_slots.qsize()}") + logging.info(f"AMANGU: In _generate_thread: max_concurrent_decodes = {generate_engine.max_concurrent_decodes}") my_generate_backlog = self._generate_backlogs[idx] my_detokenize_backlog = self._generate_detokenize_backlogs[idx] # Keep track of what step tokens were generated at. generate_timestep = 0 + generate_engine.print_stats("Pre-start Generate Thread: Before init_decode_state") # State to store things like running kv cache in. decode_state = generate_engine.init_decode_state() + generate_engine.print_stats("Pre-start Generate Thread: After init_decode_state") # generate_params = self._generate_params[idx] + logging.info("---------Generate params %d loaded.---------", idx) time_of_last_generate = time.time() time_of_last_print = time.time() @@ -765,6 +799,7 @@ def _generate_thread(self, idx: int): decode_state = generate_engine.insert( new_request.prefill_result, decode_state, slot=slot ) + del new_request.prefill_result new_request.generate_timestep_added = generate_timestep new_request.complete = np.zeros( @@ -778,26 +813,24 @@ def _generate_thread(self, idx: int): my_slots.qsize() < max_concurrent_decodes ), "At this point we must have some requests inserted into the slots." - start_time = time.perf_counter() generate_params = self._generate_params[idx] - end_time = time.perf_counter() - - elapsed_time = (end_time - start_time) * 1e6 - - logging.info(f"AMANGU Log (orchestrator.py): Time taken to set generate_params=self._generate_params is {elapsed_time} Micro-seconds.") adapter_id = "base_params" - if new_request != None: - adapter_id = new_request.adapter_id - if adapter_id != "" and adapter_id not in generate_params: - logging.info(f"The adapter is not loaded into generate_params, so bypassing the processing of the request.") - continue + + start_time = time.perf_counter() # Now we actually take a generate step on requests in the slots. decode_state, sampled_tokens = generate_engine.generate( generate_params[adapter_id], decode_state ) sampled_tokens.copy_to_host_async() + + end_time = time.perf_counter() + + elapsed_time = (end_time - start_time) * 1e6 + + logging.info(f"AMANGU Log (orchestrator.py): Time taken to execute Decode.generate_thread -> generate_engine.generate is {elapsed_time} Micro-seconds.") + # Respond to detokenization backpressure. my_detokenize_backlog.put((generate_timestep, sampled_tokens), block=True) generate_timestep += 1 @@ -817,7 +850,6 @@ def _detokenize_thread(self, is_prefill: bool, idx: int): # For all filled my_slots, pop the sampled token onto the relevant # requests return channel. If it done, place it back onto free slots. 
-    logging.info("AMANGU: In _detokenize_thread")
     if is_prefill:
       my_detokenize_backlog = self._prefill_detokenize_backlogs[idx]
     else:
@@ -934,6 +966,113 @@ def _detokenize_thread(self, is_prefill: bool, idx: int):
           slot, active_request = data
           my_live_requests[slot] = active_request
 
+
+  async def loadAdaptersFromCatalogToTensorStore(self):
+    logging.info(f"Loading adapters from the catalog file at the start of the server.")
+
+    if not self._prefill_engines and not self._generate_engines:
+      logging.info(f"There is no MaxEngine object defined. So could not load any adapter.")
+
+    engine = None
+
+    if self._prefill_engines:
+      engine = self._prefill_engines[0]
+    else:
+      engine = self._generate_engines[0]
+
+    adapter_params_and_config = engine.load_adapters_from_catalog_file()
+
+    if not adapter_params_and_config:
+      logging.info("There is no adapter loaded from the catalog file.")
+
+    tasks = []
+    for key, value in adapter_params_and_config.items():
+      adapter_id = key
+      adapter_config = value["config"]
+      adapter_params_pytree = value["params"]
+
+      try:
+        self._adapter_tensorstore.register_adapter(
+            adapter_id,
+            adapter_config["adapter_path"],
+            adapter_config)
+
+      except ValueError as e:
+        logging.info(f"Registration failed with error: {str(e)}")
+
+      task = asyncio.create_task(self._adapter_tensorstore.load_adapter(adapter_id, adapter_params_pytree, False))
+      task.set_name(f"Task:loading-adapter-{adapter_id}")
+      tasks.append(task)
+
+    await asyncio.gather(*tasks)
+
+    logging.info(f"All adapters from catalog file loaded successfully.")
+
+    engine.print_stats("After loading all adapters from catalog.")
+
+
+  def loadAdapterToTensorstore(
+      self,
+      adapter_id,
+      adapter_path):
+    logging.info(f"Loading adapter_id={adapter_id} from adapter_path={adapter_path}.")
+
+    if not self._prefill_engines and not self._generate_engines:
+      logging.info(f"There is no MaxEngine object defined. So could not load any adapter.")
+
+    engine = None
+
+    if self._prefill_engines:
+      engine = self._prefill_engines[0]
+    else:
+      engine = self._generate_engines[0]
+
+    adapter_params, adapter_config = engine.load_single_adapter(adapter_path)
+
+    if not adapter_params or not adapter_config:
+      logging.info("Either the adapter params or the adapter config failed to load.")
+
+    try:
+      self._adapter_tensorstore.register_adapter(
+          adapter_id,
+          adapter_path,
+          adapter_config)
+    except ValueError as e:
+      logging.info(f"Registration failed with error: {e}")
+
+    asyncio.run(self._adapter_tensorstore.load_adapter(adapter_id, adapter_params, True))
+
+    logging.info(f"Successfully loaded adapter_id={adapter_id}.")
+    engine.print_stats(f"After loading adapter_id={adapter_id}")
+
+
+  def unloadAdapterFromTensorstore(
+      self,
+      adapter_id):
+    logging.info(f"Unloading adapter_id={adapter_id}.")
+
+    try:
+      asyncio.run(self._adapter_tensorstore.unload_adapter(adapter_id))
+    except ValueError as e:
+      logging.info(f"Unloading failed with error: {e}")
+
+    engine = None
+
+    if self._prefill_engines:
+      engine = self._prefill_engines[0]
+    else:
+      engine = self._generate_engines[0]
+
+    logging.info(f"Successfully unloaded adapter_id={adapter_id}.")
+    engine.print_stats(f"After unloading adapter_id={adapter_id}")
+
+
+  def listAdaptersFromTensorstore(self):
+    logging.info(f"Listing loaded adapters.")
+
+    return self._adapter_tensorstore.adapter_registry
+
+
   def loadAndApplyAdapter(
       self,
       adapter_id,
@@ -1003,7 +1142,6 @@ def __init__(self, driver: Driver):
   def _get_prefill_content(
       self, request: jetstream_pb2.DecodeRequest
   ) -> Tuple[str | list[int], bool]:
-    logging.info("AMANGU: In LLMOrchestrator::_get_prefill_content")
     which_content = request.WhichOneof("content")
     content = getattr(request, which_content)
     if which_content == "text_content":
@@ -1075,7 +1213,7 @@ async def Decode(  # pylint: disable=invalid-overridden-method
       request: jetstream_pb2.DecodeRequest,
      context: Optional[grpc.aio.ServicerContext] = None,
   ) -> AsyncIterator[jetstream_pb2.DecodeResponse]:
-    logging.info("AMANGU: In LLMOrchestrator::Decode")
+
     """Decode."""
     if context is None:
       logging.warning(
@@ -1086,9 +1224,11 @@ async def Decode(  # pylint: disable=invalid-overridden-method
     return_channel = async_multifuture.AsyncMultifuture()
     if context:
       context.add_done_callback(return_channel.cancel)
+
     prefill_content, is_client_side_tokenization = self._get_prefill_content(
         request
     )
+
     # Wrap request as an ActiveRequest.
active_request = ActiveRequest( max_tokens=request.max_tokens, diff --git a/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto index 70f501df..d3427be6 100644 --- a/jetstream/core/proto/jetstream.proto +++ b/jetstream/core/proto/jetstream.proto @@ -118,12 +118,15 @@ message ListAdaptersResponse { message AdapterInfo { string adapter_id = 1; int64 loading_cost = 2; + int64 size_hbm = 3; + int64 size_cpu = 4; + float last_accessed = 5; + string status = 6; } message LoadAdapterRequest { string adapter_id = 1; // Unique ID for the adapter - string adapter_config_path = 2; // Path to the LoRA adapter config - string adapter_weights_path = 3; // Path to the LoRA adapter weights (a directory with weights & config) + string adapter_path = 2; // Path to the LoRA adapter (config & weights) } message LoadAdapterResponse { diff --git a/jetstream/core/proto/jetstream_pb2.py b/jetstream/core/proto/jetstream_pb2.py index da19b97f..71d4af40 100644 --- a/jetstream/core/proto/jetstream_pb2.py +++ b/jetstream/core/proto/jetstream_pb2.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # NO CHECKED-IN PROTOBUF GENCODE @@ -28,7 +29,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fjetstream.proto\x12\x0fjetstream_proto\"\x90\x03\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 \x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x12\x12\n\nadapter_id\x18\x08 \x01(\t\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04\"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02\"\x14\n\x12HealthCheckRequest\"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\"\x15\n\x13ListAdaptersRequest\"s\n\x14ListAdaptersResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\x12\x33\n\radapter_infos\x18\x03 \x03(\x0b\x32\x1c.jetstream_proto.AdapterInfo\"7\n\x0b\x41\x64\x61pterInfo\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0cloading_cost\x18\x02 \x01(\x03\"c\n\x12LoadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x1b\n\x13\x61\x64\x61pter_config_path\x18\x02 \x01(\t\x12\x1c\n\x14\x61\x64\x61pter_weights_path\x18\x03 \x01(\t\"=\n\x13LoadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\"*\n\x14UnloadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"?\n\x15UnloadAdapterResponse\x12\x0f\n\x07success\x18\x01 
\x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t2\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse\"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse\"\x00\x32\xb2\x02\n\x13MultiAdapterManager\x12]\n\x0cListAdapters\x12$.jetstream_proto.ListAdaptersRequest\x1a%.jetstream_proto.ListAdaptersResponse\"\x00\x12Z\n\x0bLoadAdapter\x12#.jetstream_proto.LoadAdapterRequest\x1a$.jetstream_proto.LoadAdapterResponse\"\x00\x12`\n\rUnloadAdapter\x12%.jetstream_proto.UnloadAdapterRequest\x1a&.jetstream_proto.UnloadAdapterResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fjetstream.proto\x12\x0fjetstream_proto\"\x90\x03\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 \x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x12\x12\n\nadapter_id\x18\x08 \x01(\t\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04\"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02\"\x14\n\x12HealthCheckRequest\"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\"\x15\n\x13ListAdaptersRequest\"s\n\x14ListAdaptersResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\x12\x33\n\radapter_infos\x18\x03 \x03(\x0b\x32\x1c.jetstream_proto.AdapterInfo\"\x82\x01\n\x0b\x41\x64\x61pterInfo\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0cloading_cost\x18\x02 \x01(\x03\x12\x10\n\x08size_hbm\x18\x03 \x01(\x03\x12\x10\n\x08size_cpu\x18\x04 \x01(\x03\x12\x15\n\rlast_accessed\x18\x05 \x01(\x02\x12\x0e\n\x06status\x18\x06 \x01(\t\">\n\x12LoadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\"=\n\x13LoadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\"*\n\x14UnloadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"?\n\x15UnloadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 
\x01(\t2\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse\"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse\"\x00\x32\xb2\x02\n\x13MultiAdapterManager\x12]\n\x0cListAdapters\x12$.jetstream_proto.ListAdaptersRequest\x1a%.jetstream_proto.ListAdaptersResponse\"\x00\x12Z\n\x0bLoadAdapter\x12#.jetstream_proto.LoadAdapterRequest\x1a$.jetstream_proto.LoadAdapterResponse\"\x00\x12`\n\rUnloadAdapter\x12%.jetstream_proto.UnloadAdapterRequest\x1a&.jetstream_proto.UnloadAdapterResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -59,18 +60,18 @@ _globals['_LISTADAPTERSREQUEST']._serialized_end=856 _globals['_LISTADAPTERSRESPONSE']._serialized_start=858 _globals['_LISTADAPTERSRESPONSE']._serialized_end=973 - _globals['_ADAPTERINFO']._serialized_start=975 - _globals['_ADAPTERINFO']._serialized_end=1030 - _globals['_LOADADAPTERREQUEST']._serialized_start=1032 - _globals['_LOADADAPTERREQUEST']._serialized_end=1131 - _globals['_LOADADAPTERRESPONSE']._serialized_start=1133 - _globals['_LOADADAPTERRESPONSE']._serialized_end=1194 - _globals['_UNLOADADAPTERREQUEST']._serialized_start=1196 - _globals['_UNLOADADAPTERREQUEST']._serialized_end=1238 - _globals['_UNLOADADAPTERRESPONSE']._serialized_start=1240 - _globals['_UNLOADADAPTERRESPONSE']._serialized_end=1303 - _globals['_ORCHESTRATOR']._serialized_start=1306 - _globals['_ORCHESTRATOR']._serialized_end=1491 - _globals['_MULTIADAPTERMANAGER']._serialized_start=1494 - _globals['_MULTIADAPTERMANAGER']._serialized_end=1800 + _globals['_ADAPTERINFO']._serialized_start=976 + _globals['_ADAPTERINFO']._serialized_end=1106 + _globals['_LOADADAPTERREQUEST']._serialized_start=1108 + _globals['_LOADADAPTERREQUEST']._serialized_end=1170 + _globals['_LOADADAPTERRESPONSE']._serialized_start=1172 + _globals['_LOADADAPTERRESPONSE']._serialized_end=1233 + _globals['_UNLOADADAPTERREQUEST']._serialized_start=1235 + _globals['_UNLOADADAPTERREQUEST']._serialized_end=1277 + _globals['_UNLOADADAPTERRESPONSE']._serialized_start=1279 + _globals['_UNLOADADAPTERRESPONSE']._serialized_end=1342 + _globals['_ORCHESTRATOR']._serialized_start=1345 + _globals['_ORCHESTRATOR']._serialized_end=1530 + _globals['_MULTIADAPTERMANAGER']._serialized_start=1533 + _globals['_MULTIADAPTERMANAGER']._serialized_end=1839 # @@protoc_insertion_point(module_scope) diff --git a/jetstream/core/proto/jetstream_pb2_grpc_original.py b/jetstream/core/proto/jetstream_pb2_grpc_original.py new file mode 100644 index 00000000..5de13a1e --- /dev/null +++ b/jetstream/core/proto/jetstream_pb2_grpc_original.py @@ -0,0 +1,209 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from jetstream.core.proto import jetstream_pb2 as jetstream_dot_core_dot_proto_dot_jetstream__pb2 + + +class OrchestratorStub(object): + """TODO: Merge this with main JetStream core once we settle on an API.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Decode = channel.unary_stream( + "/jetstream_proto.Orchestrator/Decode", + request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString, + response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString, + ) + self.HealthCheck = channel.unary_unary( + "/jetstream_proto.Orchestrator/HealthCheck", + request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString, + response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString, + ) + + +class OrchestratorServicer(object): + """TODO: Merge this with main JetStream core once we settle on an API.""" + + def Decode(self, request, context): + """Query LLM to generate text or tokens.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def HealthCheck(self, request, context): + """Checks if the model server is live.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_OrchestratorServicer_to_server(servicer, server): + rpc_method_handlers = { + "Decode": grpc.unary_stream_rpc_method_handler( + servicer.Decode, + request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.FromString, + response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.SerializeToString, + ), + "HealthCheck": grpc.unary_unary_rpc_method_handler( + servicer.HealthCheck, + request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.FromString, + response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + "jetstream_proto.Orchestrator", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) + + +# This class is part of an EXPERIMENTAL API. 
+class Orchestrator(object): + """TODO: Merge this with main JetStream core once we settle on an API.""" + + @staticmethod + def Decode( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_stream( + request, + target, + "/jetstream_proto.Orchestrator/Decode", + jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString, + jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def HealthCheck( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/jetstream_proto.Orchestrator/HealthCheck", + jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString, + jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + +class MultiAdapterManagerStub(object): + """MultiAdapterManagerStub.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.ListAdapters = channel.unary_unary( + '/jetstream_proto.MultiAdapterManager/ListAdapters', + request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersRequest.SerializeToString, + response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersResponse.FromString, + _registered_method=True) + self.LoadAdapter = channel.unary_unary( + '/jetstream_proto.MultiAdapterManager/LoadAdapter', + request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterRequest.SerializeToString, + response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterResponse.FromString, + _registered_method=True) + self.UnloadAdapter = channel.unary_unary( + '/jetstream_proto.MultiAdapterManager/UnloadAdapter', + request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterRequest.SerializeToString, + response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterResponse.FromString, + _registered_method=True) + + +class MultiAdapterManagerServicer(object): + """TODO: Merge this with main JetStream core once we settle on an API.""" + + def ListAdapters(self, request, context): + """Lists all the currently loaded LoRA adapters.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def LoadAdapter(self, request, context): + """Check the feasibility and load the new LoRA adapter.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def UnloadAdapter(self, request, context): + """Unload a LoRA adapter.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_MultiAdapterManagerServicer_to_server(servicer, server): + rpc_method_handlers = { + "ListAdapters": grpc.unary_unary_rpc_method_handler( + 
servicer.ListAdapters, + request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersRequest.FromString, + response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersResponse.SerializeToString, + ), + "LoadAdapter": grpc.unary_unary_rpc_method_handler( + servicer.LoadAdapter, + request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterRequest.FromString, + response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterResponse.SerializeToString, + ), + "UnloadAdapter": grpc.unary_unary_rpc_method_handler( + servicer.UnloadAdapter, + request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterRequest.FromString, + response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + "jetstream_proto.MultiAdapterManager", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 6196bc83..4485e8eb 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -66,6 +66,8 @@ async def do_init(): orchestrator.LLMOrchestrator(driver=self._driver), self._grpc_server ) + asyncio.run(self._driver.loadAdaptersFromCatalogToTensorStore()) + jetstream_pb2_grpc.add_MultiAdapterManagerServicer_to_server( adapter_manager.MultiLoraManager(driver=self._driver), self._grpc_server ) @@ -123,6 +125,7 @@ def create_driver( generate_params = [{"base_params": ge.load_params()} for ge in engines.generate_engines] shared_params = [{"base_params": ie.load_params()} for ie in engines.interleaved_engines] logging.info("Loaded all weights.") + interleaved_mode = ( len(config.prefill_slices) + len(config.generate_slices) == 0 ) diff --git a/jetstream/tools/decode_multi_requester.py b/jetstream/tools/decode_multi_requester.py new file mode 100644 index 00000000..3337c87a --- /dev/null +++ b/jetstream/tools/decode_multi_requester.py @@ -0,0 +1,337 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmark JetStream online serving. + +On the server side, run one of the following commands: + * For real server, you need to pass correct server config (include the + model config that being passed into your engine impl) to the command + below. Refer to config_lib.py and implementations/mock/config.py for + config impl detail. + + (run with real server) + python -m jetstream.core.implementations..server \ + --config + + (run with mock server) + python -m jetstream.core.implementations.mock.server + +On the client side, run: + * For real server and shareGPT dataset, you need to pass the tokenizer, + server config, and dataset flags to the command below, and make some + changes to the tokenizer logic in the benchmark script (get_tokenizer + and sample_requests func) to use your tokenizer correctly. 
+ * Add `--save-result` flag to save the benchmark result to a json file in + current folder. + * You can also add `--run_eval true` if you want to calculate ROUGE score + on the predicted outputs. + + (run with real model and engines) + python -m benchmarks.benchmark_serving \ + --tokenizer \ + --dataset \ + --dataset-path \ + --request-rate + + (run with mock) + python -m benchmarks.benchmark_serving \ + --request-rate 1 + +e2e example: +python3 benchmark_serving.py \ + --tokenizer /home/{username}/maxtext/assets/tokenizer \ + --num-prompts 100 \ + --dataset sharegpt \ + --dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json + +""" + + +import argparse +import asyncio +from dataclasses import dataclass, field +from datetime import datetime +import json +import random +import time +from typing import Any, AsyncGenerator, Optional +import os + + +import grpc +from jetstream.core.proto import jetstream_pb2 +from jetstream.core.proto import jetstream_pb2_grpc +from jetstream.engine.token_utils import load_vocab +from jetstream.external_tokenizers.llama3 import llama3_tokenizer +import numpy as np + + +def str2bool(v: str) -> bool: + """Convert a string of truth to True or False. + + Args: + - v (str): + - True values are 'y', 'yes', 't', 'true', and '1'; + - False values are 'n', 'no', 'f', 'false', and '0'. + + Returns: + bool: True or False + + Raises: + ValueError if v is anything else. + """ + v = v.lower() + true_values = ["y", "yes", "t", "true", "1"] + false_values = ["n", "no", "f", "false", "0"] + if v in true_values: + return True + elif v in false_values: + return False + else: + raise ValueError(f"Invalid value '{v}'!") + + +@dataclass +class BenchmarkMetrics: + """Data class to store benchmark metrics.""" + + completed: int + total_input: int + total_output: int + request_throughput: float + input_throughput: float + output_throughput: float + mean_ttft_ms: float + median_ttft_ms: float + p99_ttft_ms: float + mean_tpot_ms: float + median_tpot_ms: float + p99_tpot_ms: float + + +@dataclass +class InputRequest: + prompt: str = "" + output: str = "" + output_len: int = 0 + sample_idx: int = -1 + + +@dataclass +class RequestFuncOutput: + input_request: Optional[InputRequest] = None + generated_token_list: list[str] = field(default_factory=list) + generated_text: str = "" + success: bool = False + latency: float = 0 + ttft: float = 0 + + # Flatten the structure and return only the necessary results + def to_dict(self): + return { + "prompt": self.input_request.prompt, + "original_output": self.input_request.output, + "generated_text": self.generated_text, + "success": self.success, + "latency": self.latency, + "sample_idx": self.input_request.sample_idx, + } + + +def get_tokenizer( + model_id: str, + tokenizer_name: str, +) -> Any: + """Return a tokenizer or a tokenizer placholder.""" + if tokenizer_name == "test": + print("Using test tokenizer") + return "test" + elif model_id == "llama-3": + # Llama 3 uses a tiktoken tokenizer. + print(f"Using llama-3 tokenizer: {tokenizer_name}") + return llama3_tokenizer.Tokenizer(tokenizer_name) + else: + # Use JetStream tokenizer util. It's using the sentencepiece wrapper in + # seqio library. 
+ print(f"Using tokenizer: {tokenizer_name}") + vocab = load_vocab(tokenizer_name) + return vocab.tokenizer + + +async def grpc_async_request( + api_url: str, request: Any +) -> tuple[list[str], float, float]: + """Send grpc synchronous request since the current grpc server is sync.""" + options = [("grpc.keepalive_timeout_ms", 10000)] + async with grpc.aio.insecure_channel(api_url, options=options) as channel: + stub = jetstream_pb2_grpc.OrchestratorStub(channel) + print("Making request") + ttft = 0 + token_list = [] + request_start_time = time.perf_counter() + response = stub.Decode(request) + async for resp in response: + if ttft == 0: + ttft = time.perf_counter() - request_start_time + token_list.extend(resp.stream_content.samples[0].token_ids) + latency = time.perf_counter() - request_start_time + return token_list, ttft, latency + + +async def send_request( + api_url: str, + tokenizer: Any, + input_request: InputRequest, +) -> RequestFuncOutput: + """Send the request to JetStream server.""" + # Tokenization on client side following MLPerf standard. + token_ids = tokenizer.encode(input_request.prompt) + request = jetstream_pb2.DecodeRequest( + token_content=jetstream_pb2.DecodeRequest.TokenContent( + token_ids=token_ids + ), + max_tokens=input_request.output_len, + adapter_id=input_request.adapter_id, + ) + output = RequestFuncOutput() + output.input_request = input_request + generated_token_list, ttft, latency = await grpc_async_request( + api_url, request + ) + output.ttft = ttft + output.latency = latency + output.generated_token_list = generated_token_list + # generated_token_list is a list of token ids, decode it to generated_text. + output.generated_text = tokenizer.decode(generated_token_list) + output.success = True + return output + + +async def get_request( + input_requests: list[InputRequest], +) -> AsyncGenerator[InputRequest, None]: + input_requests = iter(input_requests) + + for request in input_requests: + yield request + + +async def send_multi_request( + api_url: str, + tokenizer: Any, + input_requests: list[InputRequest], +): + """Send multiple LoRA adapter requests.""" + tasks = [] + async for request in get_request(input_requests): + tasks.append( + asyncio.create_task( + send_request( + api_url=api_url, + tokenizer=tokenizer, + input_request=request, + ) + ) + ) + outputs = await asyncio.gather(*tasks) + + return outputs + + +def mock_adapter_requests(total_mock_requests: int): + """Generates a list of mock requests containing mock data.""" + data = [] + for index in range(total_mock_requests): + request = InputRequest() + request.prompt = f"22 year old" + if index == 0: + request.adapter_id = "" + else: + request.adapter_id = f"test_lora_{index}" + request.output_len = 3 + data.append(request) + return data + + +def main(args: argparse.Namespace): + print(args) + + model_id = args.model + tokenizer_id = args.tokenizer + + api_url = f"{args.server}:{args.port}" + + tokenizer = get_tokenizer(model_id, tokenizer_id) + input_requests = mock_adapter_requests( + args.total_mock_requests + ) # e.g. 
[("AB", 2, "AB", 3)] + + request_outputs = asyncio.run( + send_multi_request( + api_url=api_url, + tokenizer=tokenizer, + input_requests=input_requests, + ) + ) + + output = [output.to_dict() for output in request_outputs] + + # Process output + for index, output in enumerate(output): + print(f"Prompt: {input_requests[index].prompt}") + print(f"AdapterId: {input_requests[index].adapter_id}") + print(f"Output: {output}") + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="Sending multiple serving requests to JetStream Server" + ) + parser.add_argument( + "--server", + type=str, + default="0.0.0.0", + help="Server address.", + ) + parser.add_argument("--port", type=str, default=9000) + parser.add_argument( + "--model", + type=str, + default="no_model", + help=( + "Name of the model like llama-2, llama-3, gemma. (it's just used to" + " label the benchmark, pick the tokenizer, the model config is" + " defined in config_lib, and passed as the server config flag when" + " we run the JetStream server)" + ), + ) + parser.add_argument( + "--total-mock-requests", + type=int, + default=3, + help="The maximum number of mock requests to send for benchmark testing.", + ) + parser.add_argument( + "--tokenizer", + type=str, + default="test", + help=( + "Name or path of the tokenizer. (For mock model testing, use the" + " default value)" + ), + ) + + parsed_args = parser.parse_args() + main(parsed_args) diff --git a/jetstream/tools/llm_gateway_proxy_client.py b/jetstream/tools/llm_gateway_proxy_client.py index 33a6e379..d607c544 100644 --- a/jetstream/tools/llm_gateway_proxy_client.py +++ b/jetstream/tools/llm_gateway_proxy_client.py @@ -39,15 +39,8 @@ required=False, ) -_ADAPTER_CONFIG_PATH = flags.DEFINE_string( - "adapter_config_path", - None, - "Path of the fine-tuned adapter to be loaded from.", - required=False, -) - -_ADAPTER_WEIGHTS_PATH = flags.DEFINE_string( - "adapter_weights_path", +_ADAPTER_PATH = flags.DEFINE_string( + "adapter_path", None, "Path of the fine-tuned adapter to be loaded from.", required=False, @@ -75,17 +68,15 @@ def main(argv: Sequence[str]) -> None: print(f"Calling the JetStream/MultiAdapterManager/LoadAdapter.") adapter_id=_ADAPTER_ID.value - adapter_config_path=_ADAPTER_CONFIG_PATH.value - adapter_weights_path=_ADAPTER_WEIGHTS_PATH.value + adapter_path=_ADAPTER_PATH.value - if adapter_id == None or adapter_weights_path == None or adapter_config_path == None: - print(f"For `load_adapter` API call, `adapter_id`, `adapter_config_path` and `adapter_weights_path` must be passed.") + if adapter_id == None or adapter_path == None: + print(f"For `load_adapter` API call, `adapter_id` and `adapter_path` must be passed.") return request = jetstream_pb2.LoadAdapterRequest( adapter_id=adapter_id, - adapter_config_path=adapter_config_path, - adapter_weights_path=adapter_weights_path + adapter_path=adapter_path ) response = stub.LoadAdapter(request) @@ -125,9 +116,9 @@ def main(argv: Sequence[str]) -> None: if response.success is True: print(f"`ListAdapter` call responded successfully. 
Here is the list of adapters loaded on server:") for adapter_info in response.adapter_infos: - print(f"adapter_id={adapter_info.adapter_id}, loading_cost={adapter_info.loading_cost}.") + print(f"adapter_id={adapter_info.adapter_id}, loading_cost={adapter_info.loading_cost}, size_hbm={adapter_info.size_hbm} bytes, size_cpu={adapter_info.size_cpu} Bytes, last_accessed={adapter_info.last_accessed}, status={adapter_info.status}") else: - print(f"`ListAdapter` call failed with error={error_message}") + print(f"`ListAdapter` call failed with error={response.error_message}") elif _TEST_API_NAME.value == None: print(f"`test_api_name` flag is not set. So exiting.") diff --git a/jetstream/tools/requester.py b/jetstream/tools/requester.py index 54bbb78a..d81cdfa9 100644 --- a/jetstream/tools/requester.py +++ b/jetstream/tools/requester.py @@ -42,6 +42,11 @@ False, "Enable client side tokenization with tokenizer.", ) +_ADAPTER_ID = flags.DEFINE_string( + "adapter_id", + "", + "ID of the adapter for this decode request.", + required=False) def _GetResponseAsync( @@ -90,6 +95,7 @@ def main(argv: Sequence[str]) -> None: text=_TEXT.value ), max_tokens=_MAX_TOKENS.value, + adapter_id=_ADAPTER_ID.value, ) return _GetResponseAsync(stub, request) From 3c6fcbd035e83a62c562ff348166c15e26c34c4a Mon Sep 17 00:00:00 2001 From: aman2930 Date: Mon, 24 Feb 2025 20:15:27 +0000 Subject: [PATCH 06/22] 1) Implemented a new Service API proto to align with OpenAI completion API (https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/docs/proposals/003-model-server-protocol/README.md#inference-api-protocol), & . 2) Added a flag to explicitly run the JetStream server with these APIs when . Else only expose older Decode() & HealthCheck() APIs of the JetStream Server. 3) Fixed a bug in the adapter_tensorstore while converting jnp_array and np_array. 
4) Added a which made requests to the new APIs (v1/load_lora_adapter, v1/unload_lora_adapter, v1/models, v1/completions) --- jetstream/core/adapter_tensorstore.py | 12 +- jetstream/core/llm_inference_pool_api.py | 257 ++++++++++++++++++ .../core/proto/multi_lora_decoding.proto | 131 +++++++++ .../core/proto/multi_lora_decoding_pb2.py | 57 ++++ .../proto/multi_lora_decoding_pb2_grpc.py | 230 ++++++++++++++++ jetstream/core/server_lib.py | 38 ++- .../tools/llm_gateway_proxy_client_v2.py | 161 +++++++++++ 7 files changed, 869 insertions(+), 17 deletions(-) create mode 100644 jetstream/core/llm_inference_pool_api.py create mode 100644 jetstream/core/proto/multi_lora_decoding.proto create mode 100644 jetstream/core/proto/multi_lora_decoding_pb2.py create mode 100644 jetstream/core/proto/multi_lora_decoding_pb2_grpc.py create mode 100644 jetstream/tools/llm_gateway_proxy_client_v2.py diff --git a/jetstream/core/adapter_tensorstore.py b/jetstream/core/adapter_tensorstore.py index 169a02b7..fd8e6e39 100644 --- a/jetstream/core/adapter_tensorstore.py +++ b/jetstream/core/adapter_tensorstore.py @@ -82,7 +82,7 @@ async def _transfer_to_hbm(self, adapter_id: str): raise RuntimeError("Not enough HBM to transfer adapter, and eviction failed.") # Move from CPU to HBM - self.loaded_adapters_hbm[adapter_id] = jnp.array(self.loaded_adapters_cpu[adapter_id]) # Convert to JAX array + self.loaded_adapters_hbm[adapter_id] = self._as_jnp_array(self.loaded_adapters_cpu[adapter_id]) # Convert to JAX array del self.loaded_adapters_cpu[adapter_id] self.current_cpu_usage -= metadata.size_cpu @@ -107,7 +107,7 @@ async def _transfer_to_cpu(self, adapter_id: str): raise RuntimeError("Not enough CPU RAM to transfer adapter, and eviction failed.") # Move from HBM to CPU - self.loaded_adapters_cpu[adapter_id] = np.array(self.loaded_adapters_hbm[adapter_id]) + self.loaded_adapters_cpu[adapter_id] = self._as_np_array(self.loaded_adapters_hbm[adapter_id]) del self.loaded_adapters_hbm[adapter_id] self.current_hbm_usage -= metadata.size_hbm @@ -299,9 +299,9 @@ def get_stacked_lora_weights(self, lora_ids: jnp.ndarray, to_hbm: bool = True): if metadata.status != "loaded_hbm" and metadata.status != "loaded_cpu": asyncio.run(self.load_adapter(adapter_id, to_hbm)) # Start loading (async) elif to_hbm and metadata.status == "loaded_cpu": - self._transfer_to_hbm(adapter_id) + asyncio.run(self._transfer_to_hbm(adapter_id)) elif not to_hbm and metadata.status == "loaded_hbm": - self._transfer_to_cpu(adapter_id) + asyncio.run(self._transfer_to_cpu(adapter_id)) # Wait till all the running requests are completed while self.running_requests > 0: @@ -350,9 +350,9 @@ def get_lora_weights(self, adapter_id, to_hbm: bool = True): if metadata.status != "loaded_hbm" and metadata.status != "loaded_cpu": asyncio.run(self.load_adapter(adapter_id, None, to_hbm)) # Start loading (async) elif to_hbm and metadata.status == "loaded_cpu": - self._transfer_to_hbm(adapter_id) + asyncio.run(self._transfer_to_hbm(adapter_id)) elif not to_hbm and metadata.status == "loaded_hbm": - self._transfer_to_cpu(adapter_id) + asyncio.run(self._transfer_to_cpu(adapter_id)) # Wait till all the running requests are completed while self.running_requests > 0: diff --git a/jetstream/core/llm_inference_pool_api.py b/jetstream/core/llm_inference_pool_api.py new file mode 100644 index 00000000..f88ac5c6 --- /dev/null +++ b/jetstream/core/llm_inference_pool_api.py @@ -0,0 +1,257 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# 
you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Manages the list of fine-tuned adapters loaded on top of the base model for serving. +""" + +import logging +import grpc +import time + +from typing import Any, AsyncIterator, Optional, Tuple, cast +from jetstream.core import adapter_tensorstore +from jetstream.core import orchestrator +from jetstream.core.proto import multi_lora_decoding_pb2_grpc +from jetstream.core.proto import multi_lora_decoding_pb2 +from jetstream.core.utils import async_multifuture +from jetstream.core.utils.return_sample import ReturnSample +from jetstream.engine import engine_api, tokenizer_api, token_utils + + +class MultiLoraManager(multi_lora_decoding_pb2_grpc.v1Servicer): + """Manages the parameters of multiple lora requests and their lifelines.""" + + _driver: orchestrator.Driver + + def __init__(self, driver: orchestrator.Driver): + self._driver = driver + + def models( + self, + request: multi_lora_decoding_pb2.ListAdaptersRequest, + context: Optional[grpc.aio.ServicerContext] = None, + ) -> multi_lora_decoding_pb2.ListAdaptersResponse: + """ListAdapters all loaded LoRA adapters.""" + + try: + adapters = self._driver.listAdaptersFromTensorstore() + + adapter_infos = [] + for adapter_id, adapter_data in adapters.items(): + if adapter_data.status == "loaded_hbm": + loading_cost = 0 + elif adapter_data.status == "loaded_cpu": + loading_cost = 1 + elif adapter_data.status == "unloaded": + loading_cost = 2 + else: + loading_cost = -1 + + adapter_info = multi_lora_decoding_pb2.AdapterInfo( + adapter_id=adapter_id, + loading_cost=loading_cost, + size_hbm=adapter_data.size_hbm, + size_cpu=adapter_data.size_cpu, + last_accessed=adapter_data.last_accessed, + status=adapter_data.status) + + adapter_infos.append(adapter_info) + + return multi_lora_decoding_pb2.ListAdaptersResponse(success=True, adapter_infos=adapter_infos) + except Exception as e: + logging.info(f"Listing of adapters failed with error: {str(e)}") + return multi_lora_decoding_pb2.ListAdaptersResponse(success=False, error_message=str(e)) + + + def load_lora_adapter( + self, + request: multi_lora_decoding_pb2.LoadAdapterRequest, + context: Optional[grpc.aio.ServicerContext] = None, + ) -> multi_lora_decoding_pb2.LoadAdapterResponse: + """Load a LoRA adapter as mentioned in the request.""" + + try: + self._driver.loadAdapterToTensorstore(request.adapter_id, request.adapter_path) + + return multi_lora_decoding_pb2.LoadAdapterResponse(success=True) + except Exception as e: + logging.info(f"Loading of adapter_id={request.adapter_id} failed with error: {str(e)}") + return multi_lora_decoding_pb2.LoadAdapterResponse(success=False, error_message=str(e)) + + + def unload_lora_adapter( + self, + request: multi_lora_decoding_pb2.UnloadAdapterRequest, + context: Optional[grpc.aio.ServicerContext] = None, + ) -> multi_lora_decoding_pb2.UnloadAdapterResponse: + """Unload a LoRA adapter as mentioned in the request.""" + + try: + self._driver.unloadAdapterFromTensorstore(request.adapter_id) + return multi_lora_decoding_pb2.UnloadAdapterResponse(success=True) + except Exception as e: + 
logging.info(f"Loading of adapter_id={request.adapter_id} failed with error: {str(e)}") + return multi_lora_decoding_pb2.UnloadAdapterResponse(success=False, error_message=str(e)) + + + def _get_prefill_content( + self, request: multi_lora_decoding_pb2.CompletionRequest + ) -> Tuple[str | list[int], bool]: + which_content = request.WhichOneof("content") + content = getattr(request, which_content) + if which_content == "text_content": + return cast(multi_lora_decoding_pb2.CompletionRequest.TextContent, content).text, False + else: + return ( + list( + cast(multi_lora_decoding_pb2.CompletionRequest.TokenContent, content).token_ids + ), + True, + ) + + def process_client_side_tokenization_response(self, response: Any): + samples = [] + for sample in response: + samples.append( + multi_lora_decoding_pb2.CompletionResponse.StreamContent.Sample( + token_ids=sample.token_ids, + ) + ) + return multi_lora_decoding_pb2.CompletionResponse( + stream_content=multi_lora_decoding_pb2.CompletionResponse.StreamContent( + samples=samples + ) + ) + + def should_buffer_response(self, response: Any) -> bool: + for item in response: + if item.text and token_utils.is_byte_token(item.text[-1]): + # If any sample ends in bytes, this means we might still need to + # decode more bytes to compose the string. + return True + + def process_server_side_tokenization_response( + self, response: Any, buffered_response_list + ): + # Flush the buffered responses to each sample of current response. + current_response_with_flushed_buffer = list( + zip(*buffered_response_list, response) + ) + # Empty buffer: [[s0_cur], [s1_cur], ...] + # Has buffer: + # [[s0_b0, s0_b1, ..., s0_cur], [s1_b0, s1_b1, ..., s1_cur], ...] + current_response_with_flushed_buffer = cast( + list[list[ReturnSample]], current_response_with_flushed_buffer + ) + # Form correct sample(s) and return as StreamContent for this iteration. + samples = [] + for sample in current_response_with_flushed_buffer: + text = [] + token_ids = [] + for resp in sample: + text.extend(resp.text) + token_ids.extend(resp.token_ids) + samples.append( + multi_lora_decoding_pb2.CompletionResponse.StreamContent.Sample( + text=token_utils.text_tokens_to_str(text), + token_ids=token_ids, + ) + ) + return multi_lora_decoding_pb2.CompletionResponse( + stream_content=multi_lora_decoding_pb2.CompletionResponse.StreamContent( + samples=samples + ) + ) + + async def completions( # pylint: disable=invalid-overridden-method + self, + request: multi_lora_decoding_pb2.CompletionRequest, + context: Optional[grpc.aio.ServicerContext] = None, + ) -> AsyncIterator[multi_lora_decoding_pb2.CompletionResponse]: + + """Decode.""" + if context is None: + logging.warning( + "LLM orchestrator is being used in offline test mode, and will not" + " respond to gRPC queries - only direct function calls." + ) + is_client_side_tokenization = False + return_channel = async_multifuture.AsyncMultifuture() + if context: + context.add_done_callback(return_channel.cancel) + + prefill_content, is_client_side_tokenization = self._get_prefill_content( + request + ) + + # Wrap request as an ActiveRequest. 
+ active_request = orchestrator.ActiveRequest( + max_tokens=request.max_tokens, + prefill_content=prefill_content, + is_client_side_tokenization=is_client_side_tokenization, + return_channel=return_channel, + adapter_id=request.adapter_id, + metadata=orchestrator.ActiveRequestMetadata( + start_time=request.metadata.start_time, + prefill_enqueue_time=time.perf_counter(), + ), + ) + # The first stage is being prefilled, all other stages are handled + # inside the driver (transfer, generate*N, detokenize). + try: + self._driver.place_request_on_prefill_queue(active_request) + except queue.Full: + # Safely abort the gRPC server thread with a retriable error. + await _abort_or_raise( + context=context, + code=grpc.StatusCode.RESOURCE_EXHAUSTED, + details=( + "The driver prefill queue is full and more requests cannot be" + " handled. You may retry this request." + ), + ) + logging.info( + "Placed request on the prefill queue.", + ) + # When an active request is created a queue is instantiated. New tokens + # are placed there during the decoding loop, we pop from that queue by + # using the .next method on the active request. + # Yielding allows for the response to be a streaming grpc call - which + # can be called via iterating over a for loop on the client side. + # The DecodeResponse stream should consume all generated tokens in + # return_channel when complete signal is received (AsyncMultifuture + # promises this). + buffered_response_list = [] + async for response in active_request.return_channel: + response = cast(list[ReturnSample], response) + if is_client_side_tokenization: + # If is_client_side_tokenization, the client should request with token + # ids, and the JetStream server will return token ids as response. + # The client should take care of tokenization and detokenization. + yield self.process_client_side_tokenization_response(response) + else: + # Buffer response mechanism is used to handle streaming + # detokenization with special character (For some edge cases with + # SentencePiece tokenizer, it requires to decode a complete sequence + # instead of a single token). + if self.should_buffer_response(response): + buffered_response_list.append(response) + continue + yield self.process_server_side_tokenization_response( + response, buffered_response_list + ) + # Reset buffer after flushed. + buffered_response_list = [] + + diff --git a/jetstream/core/proto/multi_lora_decoding.proto b/jetstream/core/proto/multi_lora_decoding.proto new file mode 100644 index 00000000..77f190a6 --- /dev/null +++ b/jetstream/core/proto/multi_lora_decoding.proto @@ -0,0 +1,131 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// NOTICE: run `make generate-protos` if making changes to this file + +syntax = "proto3"; + +service v1 { + // Generate text based on a prompt. Supports streaming responses. 
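+  // Exposed to clients as the /v1/completions gRPC method path (see the
+  // generated stubs in multi_lora_decoding_pb2_grpc.py).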
+ rpc completions (CompletionRequest) returns (stream CompletionResponse) {} + + // Lists all the currently loaded LoRA adapters + rpc models (ListAdaptersRequest) returns (ListAdaptersResponse) {} + + // Loads a new LoRA adapter. + rpc load_lora_adapter (LoadAdapterRequest) returns (LoadAdapterResponse) {} + + // Unloads a LoRA adapter + rpc unload_lora_adapter (UnloadAdapterRequest) returns (UnloadAdapterResponse) {} +} + + +message CompletionRequest { + // The maximum output length of a sequence. It's used in JetStream to control + // the output/decode length of a sequence. It would not be used in the engine. + // We should always set max_tokens <= (max_target_length - + // max_prefill_predict_length). max_target_length is the maximum length of a + // sequence; max_prefill_predict_length is the maximum length of the + // input/prefill of a sequence. + int32 max_tokens = 4; + + message TextContent { + string text = 1; + } + message TokenContent { + repeated int32 token_ids = 1; + } + + // The client can pass the inputs either as a string, in which case the server will + // tokenize it, or as tokens, in which case it's the client's responsibility to + // ensure they tokenize its input strings with the correct tokenizer. + oneof content { + TextContent text_content = 5; + TokenContent token_content = 6; + } + + message Metadata { + float start_time = 1; + } + + oneof metadata_optional { + Metadata metadata = 7; + } + + string adapter_id = 8; + + reserved 1, 2, 3; + // Next ID: 9 +} + +message CompletionResponse { + // InitialContent supports returning initial one-off response data from the + // stream. It's a placeholder for future features such as history cache. + message InitialContent {} + message StreamContent { + message Sample { + // The text string decoded from token id(s). + string text = 1; + // List of token ids, one list per sample. When speculative decoding is disabled, the list size should be 1; When speculative decoding is enabled, the list size should be >= 1. + repeated int32 token_ids = 2; + } + // Supports multiple samples in the StreamContent. The Sample list size depends on text generation strategy the engine used. + repeated Sample samples = 1; + } + + oneof content { + InitialContent initial_content = 2; + StreamContent stream_content = 3; + } + reserved 1; + // Next ID: 4 +} + +message ListAdaptersRequest {} + +message ListAdaptersResponse { + bool success = 1; // True if successful, False otherwise + string error_message = 2; // Error message if listing the adapters + repeated AdapterInfo adapter_infos = 3; // List of information about loaded adapters. 
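+  // loading_cost in each AdapterInfo encodes the adapter's current placement as
+  // reported by the server: 0 = resident in HBM, 1 = in CPU RAM, 2 = unloaded,
+  // -1 = unknown status.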
+} + +// Information about a single loaded LoRA adapter +message AdapterInfo { + string adapter_id = 1; + int64 loading_cost = 2; + int64 size_hbm = 3; + int64 size_cpu = 4; + float last_accessed = 5; + string status = 6; +} + +message LoadAdapterRequest { + string adapter_id = 1; // Unique ID/name for the adapter + string adapter_path = 2; // Path to the LoRA adapter (config & weights) +} + +message LoadAdapterResponse { + bool success = 1; // True if successful, false otherwise + string error_message = 2; // Error message if loading failed +} + +message UnloadAdapterRequest { + string adapter_id = 1; // ID/Name of the adapter to unload +} + +message UnloadAdapterResponse { + bool success = 1; // True if successful, false otherwise + string error_message = 2; // Error message if unloading failed +} + diff --git a/jetstream/core/proto/multi_lora_decoding_pb2.py b/jetstream/core/proto/multi_lora_decoding_pb2.py new file mode 100644 index 00000000..d53e10aa --- /dev/null +++ b/jetstream/core/proto/multi_lora_decoding_pb2.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: multi_lora_decoding.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19multi_lora_decoding.proto\"\xf0\x02\n\x11\x43ompletionRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x36\n\x0ctext_content\x18\x05 \x01(\x0b\x32\x1e.CompletionRequest.TextContentH\x00\x12\x38\n\rtoken_content\x18\x06 \x01(\x0b\x32\x1f.CompletionRequest.TokenContentH\x00\x12/\n\x08metadata\x18\x07 \x01(\x0b\x32\x1b.CompletionRequest.MetadataH\x01\x12\x12\n\nadapter_id\x18\x08 \x01(\t\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04\"\xaa\x02\n\x12\x43ompletionResponse\x12=\n\x0finitial_content\x18\x02 \x01(\x0b\x32\".CompletionResponse.InitialContentH\x00\x12;\n\x0estream_content\x18\x03 \x01(\x0b\x32!.CompletionResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1au\n\rStreamContent\x12\x39\n\x07samples\x18\x01 \x03(\x0b\x32(.CompletionResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02\"\x15\n\x13ListAdaptersRequest\"c\n\x14ListAdaptersResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\x12#\n\radapter_infos\x18\x03 \x03(\x0b\x32\x0c.AdapterInfo\"\x82\x01\n\x0b\x41\x64\x61pterInfo\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0cloading_cost\x18\x02 \x01(\x03\x12\x10\n\x08size_hbm\x18\x03 \x01(\x03\x12\x10\n\x08size_cpu\x18\x04 \x01(\x03\x12\x15\n\rlast_accessed\x18\x05 \x01(\x02\x12\x0e\n\x06status\x18\x06 \x01(\t\">\n\x12LoadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\"=\n\x13LoadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 
\x01(\t\"*\n\x14UnloadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"?\n\x15UnloadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t2\x83\x02\n\x02v1\x12:\n\x0b\x63ompletions\x12\x12.CompletionRequest\x1a\x13.CompletionResponse\"\x00\x30\x01\x12\x37\n\x06models\x12\x14.ListAdaptersRequest\x1a\x15.ListAdaptersResponse\"\x00\x12@\n\x11load_lora_adapter\x12\x13.LoadAdapterRequest\x1a\x14.LoadAdapterResponse\"\x00\x12\x46\n\x13unload_lora_adapter\x12\x15.UnloadAdapterRequest\x1a\x16.UnloadAdapterResponse\"\x00\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'multi_lora_decoding_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_COMPLETIONREQUEST']._serialized_start=30 + _globals['_COMPLETIONREQUEST']._serialized_end=398 + _globals['_COMPLETIONREQUEST_TEXTCONTENT']._serialized_start=254 + _globals['_COMPLETIONREQUEST_TEXTCONTENT']._serialized_end=281 + _globals['_COMPLETIONREQUEST_TOKENCONTENT']._serialized_start=283 + _globals['_COMPLETIONREQUEST_TOKENCONTENT']._serialized_end=316 + _globals['_COMPLETIONREQUEST_METADATA']._serialized_start=318 + _globals['_COMPLETIONREQUEST_METADATA']._serialized_end=348 + _globals['_COMPLETIONRESPONSE']._serialized_start=401 + _globals['_COMPLETIONRESPONSE']._serialized_end=699 + _globals['_COMPLETIONRESPONSE_INITIALCONTENT']._serialized_start=547 + _globals['_COMPLETIONRESPONSE_INITIALCONTENT']._serialized_end=563 + _globals['_COMPLETIONRESPONSE_STREAMCONTENT']._serialized_start=565 + _globals['_COMPLETIONRESPONSE_STREAMCONTENT']._serialized_end=682 + _globals['_COMPLETIONRESPONSE_STREAMCONTENT_SAMPLE']._serialized_start=641 + _globals['_COMPLETIONRESPONSE_STREAMCONTENT_SAMPLE']._serialized_end=682 + _globals['_LISTADAPTERSREQUEST']._serialized_start=701 + _globals['_LISTADAPTERSREQUEST']._serialized_end=722 + _globals['_LISTADAPTERSRESPONSE']._serialized_start=724 + _globals['_LISTADAPTERSRESPONSE']._serialized_end=823 + _globals['_ADAPTERINFO']._serialized_start=826 + _globals['_ADAPTERINFO']._serialized_end=956 + _globals['_LOADADAPTERREQUEST']._serialized_start=958 + _globals['_LOADADAPTERREQUEST']._serialized_end=1020 + _globals['_LOADADAPTERRESPONSE']._serialized_start=1022 + _globals['_LOADADAPTERRESPONSE']._serialized_end=1083 + _globals['_UNLOADADAPTERREQUEST']._serialized_start=1085 + _globals['_UNLOADADAPTERREQUEST']._serialized_end=1127 + _globals['_UNLOADADAPTERRESPONSE']._serialized_start=1129 + _globals['_UNLOADADAPTERRESPONSE']._serialized_end=1192 + _globals['_V1']._serialized_start=1195 + _globals['_V1']._serialized_end=1454 +# @@protoc_insertion_point(module_scope) diff --git a/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py b/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py new file mode 100644 index 00000000..d172d151 --- /dev/null +++ b/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py @@ -0,0 +1,230 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +from jetstream.core.proto import multi_lora_decoding_pb2 as multi__lora__decoding__pb2 + +GRPC_GENERATED_VERSION = '1.70.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in multi_lora_decoding_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class v1Stub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.completions = channel.unary_stream( + '/v1/completions', + request_serializer=multi__lora__decoding__pb2.CompletionRequest.SerializeToString, + response_deserializer=multi__lora__decoding__pb2.CompletionResponse.FromString, + _registered_method=True) + self.models = channel.unary_unary( + '/v1/models', + request_serializer=multi__lora__decoding__pb2.ListAdaptersRequest.SerializeToString, + response_deserializer=multi__lora__decoding__pb2.ListAdaptersResponse.FromString, + _registered_method=True) + self.load_lora_adapter = channel.unary_unary( + '/v1/load_lora_adapter', + request_serializer=multi__lora__decoding__pb2.LoadAdapterRequest.SerializeToString, + response_deserializer=multi__lora__decoding__pb2.LoadAdapterResponse.FromString, + _registered_method=True) + self.unload_lora_adapter = channel.unary_unary( + '/v1/unload_lora_adapter', + request_serializer=multi__lora__decoding__pb2.UnloadAdapterRequest.SerializeToString, + response_deserializer=multi__lora__decoding__pb2.UnloadAdapterResponse.FromString, + _registered_method=True) + + +class v1Servicer(object): + """Missing associated documentation comment in .proto file.""" + + def completions(self, request, context): + """Generate text based on a prompt. Supports streaming responses. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def models(self, request, context): + """Lists all the currently loaded LoRA adapters + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def load_lora_adapter(self, request, context): + """Loads a new LoRA adapter. 
+ """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def unload_lora_adapter(self, request, context): + """Unloads a LoRA adapter + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_v1Servicer_to_server(servicer, server): + rpc_method_handlers = { + 'completions': grpc.unary_stream_rpc_method_handler( + servicer.completions, + request_deserializer=multi__lora__decoding__pb2.CompletionRequest.FromString, + response_serializer=multi__lora__decoding__pb2.CompletionResponse.SerializeToString, + ), + 'models': grpc.unary_unary_rpc_method_handler( + servicer.models, + request_deserializer=multi__lora__decoding__pb2.ListAdaptersRequest.FromString, + response_serializer=multi__lora__decoding__pb2.ListAdaptersResponse.SerializeToString, + ), + 'load_lora_adapter': grpc.unary_unary_rpc_method_handler( + servicer.load_lora_adapter, + request_deserializer=multi__lora__decoding__pb2.LoadAdapterRequest.FromString, + response_serializer=multi__lora__decoding__pb2.LoadAdapterResponse.SerializeToString, + ), + 'unload_lora_adapter': grpc.unary_unary_rpc_method_handler( + servicer.unload_lora_adapter, + request_deserializer=multi__lora__decoding__pb2.UnloadAdapterRequest.FromString, + response_serializer=multi__lora__decoding__pb2.UnloadAdapterResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'v1', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('v1', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class v1(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def completions(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/v1/completions', + multi__lora__decoding__pb2.CompletionRequest.SerializeToString, + multi__lora__decoding__pb2.CompletionResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def models(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/v1/models', + multi__lora__decoding__pb2.ListAdaptersRequest.SerializeToString, + multi__lora__decoding__pb2.ListAdaptersResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def load_lora_adapter(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/v1/load_lora_adapter', + multi__lora__decoding__pb2.LoadAdapterRequest.SerializeToString, + multi__lora__decoding__pb2.LoadAdapterResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + 
metadata, + _registered_method=True) + + @staticmethod + def unload_lora_adapter(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/v1/unload_lora_adapter', + multi__lora__decoding__pb2.UnloadAdapterRequest.SerializeToString, + multi__lora__decoding__pb2.UnloadAdapterResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 4485e8eb..18dae982 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -25,6 +25,7 @@ import threading import time import traceback +import importlib from typing import Any, Type @@ -32,9 +33,7 @@ import jax from jetstream.core import config_lib from jetstream.core import orchestrator -from jetstream.core import adapter_manager from jetstream.core.metrics.prometheus import JetstreamMetricsCollector -from jetstream.core.proto import jetstream_pb2_grpc from jetstream.engine import warmup_utils, engine_api from prometheus_client import start_http_server @@ -46,7 +45,12 @@ class JetStreamServer: """JetStream grpc server.""" def __init__( - self, driver: orchestrator.Driver, threads: int, port, credentials + self, + driver: orchestrator.Driver, + threads: int, + port, + credentials, + enable_llm_inference_pool = False ): self._executor = futures.ThreadPoolExecutor(max_workers=threads) @@ -62,15 +66,26 @@ async def do_init(): asyncio.run_coroutine_threadsafe(do_init(), loop=self._loop).result() self._driver = driver - jetstream_pb2_grpc.add_OrchestratorServicer_to_server( - orchestrator.LLMOrchestrator(driver=self._driver), self._grpc_server - ) - asyncio.run(self._driver.loadAdaptersFromCatalogToTensorStore()) + if enable_llm_inference_pool: + module_name = "jetstream.core.llm_inference_pool_api" + llm_inference_pool = importlib.import_module(module_name) - jetstream_pb2_grpc.add_MultiAdapterManagerServicer_to_server( - adapter_manager.MultiLoraManager(driver=self._driver), self._grpc_server - ) + module_name = "jetstream.core.proto.multi_lora_decoding_pb2_grpc" + multi_lora_decoding_pb2_grpc = importlib.import_module(module_name) + + asyncio.run(self._driver.loadAdaptersFromCatalogToTensorStore()) + + multi_lora_decoding_pb2_grpc.add_v1Servicer_to_server( + llm_inference_pool.MultiLoraManager(driver=self._driver), self._grpc_server + ) + else: + module_name = "jetstream.core.proto.jetstream_pb2_grpc" + jetstream_pb2_grpc = importlib.import_module(module_name) + + jetstream_pb2_grpc.add_OrchestratorServicer_to_server( + orchestrator.LLMOrchestrator(driver=self._driver), self._grpc_server + ) self._grpc_server.add_secure_port(f"{_HOST}:{port}", credentials) @@ -188,6 +203,7 @@ def run( enable_jax_profiler: bool = False, jax_profiler_port: int = 9999, enable_model_warmup: bool = False, + enable_llm_inference_pool: bool = False, ) -> JetStreamServer: """Runs a server with a specified config. @@ -228,7 +244,7 @@ def run( # We default threads to the total number of concurrent allowed decodes, # to make sure we can fully saturate the model. Set default minimum to 64. 
threads = threads or max(driver.get_total_concurrent_requests(), 64) - jetstream_server = JetStreamServer(driver, threads, port, credentials) + jetstream_server = JetStreamServer(driver, threads, port, credentials, enable_llm_inference_pool) logging.info("Starting server on port %d with %d threads", port, threads) jetstream_server.start() diff --git a/jetstream/tools/llm_gateway_proxy_client_v2.py b/jetstream/tools/llm_gateway_proxy_client_v2.py new file mode 100644 index 00000000..d14617c9 --- /dev/null +++ b/jetstream/tools/llm_gateway_proxy_client_v2.py @@ -0,0 +1,161 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A test request.""" + +from typing import Sequence + +from absl import app +from absl import flags +import grpc +from jetstream.core.proto import multi_lora_decoding_pb2 +from jetstream.core.proto import multi_lora_decoding_pb2_grpc +from jetstream.engine.token_utils import load_vocab + + +_SERVER = flags.DEFINE_string("server", "0.0.0.0", "server address") +_PORT = flags.DEFINE_string("port", "9000", "port to ping") +#_TEXT = flags.DEFINE_string("text", "My dog is cute", "The message") +_TEXT = flags.DEFINE_string("text", "22 year old", "The message") +_MAX_TOKENS = flags.DEFINE_integer( + "max_tokens", 3, "Maximum number of output/decode tokens of a sequence" +) + +_ADAPTER_ID = flags.DEFINE_string( + "adapter_id", + None, + "Id of the fine-tuned adapter to be loaded on top of the base model.", + required=False, +) + +_ADAPTER_PATH = flags.DEFINE_string( + "adapter_path", + None, + "Path of the fine-tuned adapter to be loaded from.", + required=False, +) + +_TEST_API_NAME = flags.DEFINE_string( + "test_api_name", + None, + "Name of the JetStream API to call.", + required=True, +) + + +def main(argv: Sequence[str]) -> None: + del argv + # Note: Uses insecure_channel only for local testing. Please add grpc + # credentials for Production. 
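+  # For a secured deployment, a TLS channel could be used instead, e.g.:
+  #   credentials = grpc.ssl_channel_credentials()
+  #   channel = grpc.secure_channel(address, credentials)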
+ address = f"{_SERVER.value}:{_PORT.value}" + with grpc.insecure_channel(address) as channel: + grpc.channel_ready_future(channel).result() + stub = multi_lora_decoding_pb2_grpc.v1Stub(channel) + print(f"Sending request to: {address}") + + if _TEST_API_NAME.value == "load_lora_adapter": + print(f"Calling the /v1/load_lora_adapter.") + + adapter_id=_ADAPTER_ID.value + adapter_path=_ADAPTER_PATH.value + + if adapter_id == None or adapter_path == None: + print(f"For `load_lora_adapter` API call, `adapter_id` and `adapter_path` must be passed.") + return + + request = multi_lora_decoding_pb2.LoadAdapterRequest( + adapter_id=adapter_id, + adapter_path=adapter_path + ) + + response = stub.load_lora_adapter(request) + + if response.success is True: + print(f"Adapter={adapter_id} is loaded successfully.") + else: + print(f"Adapter={adapter_id} loading failed with error={response.error_message}") + + elif _TEST_API_NAME.value == "unload_lora_adapter": + print(f"Calling the /v1/unload_lora_adapter.") + + adapter_id=_ADAPTER_ID.value + + if adapter_id == None: + print(f"For `unload_lora_adapter` API call, `adapter_id` must be passed.") + return + + request = multi_lora_decoding_pb2.UnloadAdapterRequest( + adapter_id=adapter_id, + ) + + response = stub.unload_lora_adapter(request) + + if response.success is True: + print(f"Adapter={adapter_id} is unloaded successfully.") + else: + print(f"Adapter={adapter_id} unloading failed with error={response.error_message}") + + elif _TEST_API_NAME.value == "models": + print(f"Calling the /v1/models.") + + request = multi_lora_decoding_pb2.ListAdaptersRequest() + + response = stub.models(request) + + if response.success is True: + print(f"`models` call responded successfully.") + if response.adapter_infos: + print(f"Here is the list of adapters loaded on server:") + else: + print(f"No adapters are loaded on the server.") + + for adapter_info in response.adapter_infos: + print(f"adapter_id={adapter_info.adapter_id}, loading_cost={adapter_info.loading_cost}, size_hbm={adapter_info.size_hbm} bytes, size_cpu={adapter_info.size_cpu} Bytes, last_accessed={adapter_info.last_accessed}, status={adapter_info.status}") + else: + print(f"`models` call failed with error={response.error_message}") + + elif _TEST_API_NAME.value == "completions": + print(f"Calling the /v1/completions.") + + request = multi_lora_decoding_pb2.CompletionRequest( + text_content=multi_lora_decoding_pb2.CompletionRequest.TextContent( + text=_TEXT.value, + ), + max_tokens=_MAX_TOKENS.value, + adapter_id=_ADAPTER_ID.value, + ) + + response = stub.completions(request) + + output = [] + for resp in response: + output.extend(resp.stream_content.samples[0].text) + + text_output = "".join(output) + + print(f"Prompt: {_TEXT.value}") + print(f"Response: {text_output}") + + + elif _TEST_API_NAME.value == None: + print(f"`test_api_name` flag is not set. So exiting.") + return + + else: + print(f"API={_TEST_API_NAME.value} is not implemented yet. So exiting.") + return + + +if __name__ == "__main__": + app.run(main) From 316c4905dcfe818c2f110b1f11d5a92e5a1be75a Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 26 Feb 2025 16:37:14 +0000 Subject: [PATCH 07/22] Adding following metrics into JetStream server: 1) kv_cache_utilization: This refers to percentage of memory in the allocated kv-cache on TPU HBM, that is actually used during decode. It is based on the percentage of slots used. 2) num_requests_waiting: Total number of requests which are waiting to be decoded. 
3) lora_requests_info: List of LoRA adapters that are loaded into the TPU HBM for serving the requests. --- jetstream/core/adapter_tensorstore.py | 11 + jetstream/core/metrics/prometheus.py | 34 +++ jetstream/core/orchestrator.py | 37 +++ jetstream/tools/decode_multi_requester_v2.py | 295 +++++++++++++++++++ 4 files changed, 377 insertions(+) create mode 100644 jetstream/tools/decode_multi_requester_v2.py diff --git a/jetstream/core/adapter_tensorstore.py b/jetstream/core/adapter_tensorstore.py index fd8e6e39..5cc602a5 100644 --- a/jetstream/core/adapter_tensorstore.py +++ b/jetstream/core/adapter_tensorstore.py @@ -139,6 +139,17 @@ def convert_if_np(leaf): return jax.tree_util.tree_map(convert_if_np, params) + async def get_hbm_loaded_adapters(self): + hbm_loaded_adapters = [] + + async with self.lock: + for adapter_id, metadata in self.adapter_registry.items(): + if metadata.status == "loaded_hbm": + hbm_loaded_adapters.append(adapter_id) + + return ", ".join(hbm_loaded_adapters) + + async def load_adapter(self, adapter_id: str, adapter_weights = None, to_hbm: bool = True): """Loads a LoRA adapter's weights, managing HBM and CPU memory.""" if adapter_id not in self.adapter_registry: diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py index dc8a00e9..ae68559b 100644 --- a/jetstream/core/metrics/prometheus.py +++ b/jetstream/core/metrics/prometheus.py @@ -214,6 +214,31 @@ def __new__(cls): ], ) + _num_requests_waiting = Gauge( + name="num_requests_waiting", + documentation="Number of requests waiting to be processed for inference.", + labelnames=["id"], + multiprocess_mode="sum", + ) + + _kv_cache_utilization = Gauge( + name="kv_cache_utilization_perc", + documentation="Percentage of kv-cache utilized by the requests under processing.", + labelnames=["id"], + multiprocess_mode="sum", + ) + + _lora_request_info = Gauge( + name="lora_request_info", + documentation="Information about LoRA adapters loaded into TPU Memory for serving current requests.", + labelnames=[ + "id", + "max_lora", + "running_lora_adapters", + ], + multiprocess_mode="livemostrecent", + ) + def get_prefill_backlog_metric(self): return self._prefill_backlog.labels(id=self._id) @@ -255,3 +280,12 @@ def get_request_output_length(self): def get_request_success_count_metric(self): return self._request_success_count.labels(id=self._id) + + def get_num_requests_waiting_metric(self): + return self._num_requests_waiting.labels(id=self._id) + + def get_kv_cache_utilization_metric(self): + return self._kv_cache_utilization.labels(id=self._id) + + def get_lora_request_info_metric(self, max_lora: int, loaded_adapters: str): + return self._lora_request_info.labels(id=self._id, max_lora=max_lora, running_lora_adapters=loaded_adapters) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 5e70885b..130467d9 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -314,6 +314,7 @@ def __init__( self._metrics_collector.get_generate_backlog_metric(idx).set_function( functools.partial(float, backlog.qsize()) ) + # Stage 4 # After prefill and generation, ActiveRequests are placed on the # detokenization backlog for tokens to be sent into each ActiveRequest's @@ -433,6 +434,12 @@ def __init__( self.live = True self._is_ray_backend = is_ray_backend + if self._metrics_collector: + self._metrics_collector.get_num_requests_waiting_metric().set_function( + self._get_total_requests_waiting_decode) + 
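+      # kv_cache_utilization is derived from occupied generate slots:
+      # (total_slots - empty_slots) * 100 / total_slots across all generate engines.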
self._metrics_collector.get_kv_cache_utilization_metric().set_function( + self._get_kv_cache_utilization) + # Start all threads for t in self._all_threads: t.start() @@ -481,6 +488,28 @@ def stop(self): for t in self._all_threads: t.join() + def _get_kv_cache_utilization(self): + """Calculated the kv_cache utilization in percentage based on requests being decoded.""" + total_slots = 0 + empty_slots = 0 + for idx, engine in enumerate(self._generate_engines): + total_slots += engine.max_concurrent_decodes + empty_slots += self._generate_slots[idx].qsize() + + return ((total_slots - empty_slots) * 100 / total_slots) + + def _get_total_requests_waiting_decode(self): + """Calculate the total size of all relevant queues.""" + total_size = self._prefill_backlog.qsize() + + for transfer_queue in self._transfer_backlogs: + total_size += transfer_queue.qsize() + + for gen_queue in self._generate_backlogs.values(): + total_size += gen_queue.qsize() + + return float(total_size) + def get_total_concurrent_requests(self) -> int: """Gets the total number of concurrent requests the driver can handle.""" # We don't support filling all backlogs at once because it can cause GIL @@ -819,6 +848,14 @@ def _generate_thread(self, idx: int): start_time = time.perf_counter() + if self._metrics_collector: + adapters_list_str = asyncio.run(self._adapter_tensorstore.get_hbm_loaded_adapters()) + + max_loras = max_concurrent_decodes + + self._metrics_collector.get_lora_request_info_metric(max_loras, + adapters_list_str).set_to_current_time() + # Now we actually take a generate step on requests in the slots. decode_state, sampled_tokens = generate_engine.generate( generate_params[adapter_id], decode_state diff --git a/jetstream/tools/decode_multi_requester_v2.py b/jetstream/tools/decode_multi_requester_v2.py new file mode 100644 index 00000000..1479807c --- /dev/null +++ b/jetstream/tools/decode_multi_requester_v2.py @@ -0,0 +1,295 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmark JetStream online serving. + +On the server side, run one of the following commands: + * For real server, you need to pass correct server config (include the + model config that being passed into your engine impl) to the command + below. Refer to config_lib.py and implementations/mock/config.py for + config impl detail. + + (run with real server) + python -m jetstream.core.implementations..server \ + --config + + (run with mock server) + python -m jetstream.core.implementations.mock.server + +On the client side, run: + * For real server and shareGPT dataset, you need to pass the tokenizer, + server config, and dataset flags to the command below, and make some + changes to the tokenizer logic in the benchmark script (get_tokenizer + and sample_requests func) to use your tokenizer correctly. + * Add `--save-result` flag to save the benchmark result to a json file in + current folder. + * You can also add `--run_eval true` if you want to calculate ROUGE score + on the predicted outputs. 
+ + (run with real model and engines) + python -m benchmarks.benchmark_serving \ + --tokenizer \ + --dataset \ + --dataset-path \ + --request-rate + + (run with mock) + python -m benchmarks.benchmark_serving \ + --request-rate 1 + +e2e example: +python3 benchmark_serving.py \ + --tokenizer /home/{username}/maxtext/assets/tokenizer \ + --num-prompts 100 \ + --dataset sharegpt \ + --dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json + +""" + + +import argparse +import asyncio +from dataclasses import dataclass, field +from datetime import datetime +import json +import random +import time +from typing import Any, AsyncGenerator, Optional +import os + + +import grpc +from jetstream.core.proto import multi_lora_decoding_pb2 +from jetstream.core.proto import multi_lora_decoding_pb2_grpc +from jetstream.engine.token_utils import load_vocab +from jetstream.external_tokenizers.llama3 import llama3_tokenizer +import numpy as np + + +@dataclass +class InputRequest: + prompt: str = "" + output: str = "" + output_len: int = 0 + sample_idx: int = -1 + + +@dataclass +class RequestFuncOutput: + input_request: Optional[InputRequest] = None + generated_token_list: list[str] = field(default_factory=list) + generated_text: str = "" + success: bool = False + latency: float = 0 + ttft: float = 0 + + # Flatten the structure and return only the necessary results + def to_dict(self): + return { + "prompt": self.input_request.prompt, + "original_output": self.input_request.output, + "generated_text": self.generated_text, + "success": self.success, + "latency": self.latency, + "sample_idx": self.input_request.sample_idx, + } + + +def get_tokenizer( + model_id: str, + tokenizer_name: str, +) -> Any: + """Return a tokenizer or a tokenizer placholder.""" + if tokenizer_name == "test": + print("Using test tokenizer") + return "test" + elif model_id == "llama-3": + # Llama 3 uses a tiktoken tokenizer. + print(f"Using llama-3 tokenizer: {tokenizer_name}") + return llama3_tokenizer.Tokenizer(tokenizer_name) + else: + # Use JetStream tokenizer util. It's using the sentencepiece wrapper in + # seqio library. + print(f"Using tokenizer: {tokenizer_name}") + vocab = load_vocab(tokenizer_name) + return vocab.tokenizer + + +async def grpc_async_request( + api_url: str, request: Any +) -> tuple[list[str], float, float]: + """Send grpc synchronous request since the current grpc server is sync.""" + options = [("grpc.keepalive_timeout_ms", 10000)] + async with grpc.aio.insecure_channel(api_url, options=options) as channel: + stub = multi_lora_decoding_pb2_grpc.v1Stub(channel) + print("Making request") + ttft = 0 + token_list = [] + request_start_time = time.perf_counter() + response = stub.completions(request) + async for resp in response: + if ttft == 0: + ttft = time.perf_counter() - request_start_time + token_list.extend(resp.stream_content.samples[0].token_ids) + latency = time.perf_counter() - request_start_time + return token_list, ttft, latency + + +async def send_request( + api_url: str, + tokenizer: Any, + input_request: InputRequest, +) -> RequestFuncOutput: + """Send the request to JetStream server.""" + # Tokenization on client side following MLPerf standard. 
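+  # Because the request uses token_content, the server skips tokenization and
+  # streams token ids back, which this client decodes with the same tokenizer.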
+ token_ids = tokenizer.encode(input_request.prompt) + request = multi_lora_decoding_pb2.CompletionRequest( + token_content=multi_lora_decoding_pb2.CompletionRequest.TokenContent( + token_ids=token_ids + ), + max_tokens=input_request.output_len, + adapter_id=input_request.adapter_id, + ) + output = RequestFuncOutput() + output.input_request = input_request + generated_token_list, ttft, latency = await grpc_async_request( + api_url, request + ) + output.ttft = ttft + output.latency = latency + output.generated_token_list = generated_token_list + # generated_token_list is a list of token ids, decode it to generated_text. + output.generated_text = tokenizer.decode(generated_token_list) + output.success = True + return output + + +async def get_request( + input_requests: list[InputRequest], +) -> AsyncGenerator[InputRequest, None]: + input_requests = iter(input_requests) + + for request in input_requests: + yield request + + +async def send_multi_request( + api_url: str, + tokenizer: Any, + input_requests: list[InputRequest], +): + """Send multiple LoRA adapter requests.""" + tasks = [] + async for request in get_request(input_requests): + tasks.append( + asyncio.create_task( + send_request( + api_url=api_url, + tokenizer=tokenizer, + input_request=request, + ) + ) + ) + outputs = await asyncio.gather(*tasks) + + return outputs + + +def mock_adapter_requests(total_mock_requests: int): + """Generates a list of mock requests containing mock data.""" + data = [] + for index in range(total_mock_requests): + request = InputRequest() + request.prompt = f"22 year old" + if index == 0: + request.adapter_id = "" + else: + i = (index % 10) +1 + request.adapter_id = f"test_lora_{i}" + request.output_len = 200 + data.append(request) + return data + + +def main(args: argparse.Namespace): + print(args) + + model_id = args.model + tokenizer_id = args.tokenizer + + api_url = f"{args.server}:{args.port}" + + tokenizer = get_tokenizer(model_id, tokenizer_id) + input_requests = mock_adapter_requests( + args.total_mock_requests + ) # e.g. [("AB", 2, "AB", 3)] + + request_outputs = asyncio.run( + send_multi_request( + api_url=api_url, + tokenizer=tokenizer, + input_requests=input_requests, + ) + ) + + output = [output.to_dict() for output in request_outputs] + + # Process output + for index, output in enumerate(output): + print(f"Prompt: {input_requests[index].prompt}") + print(f"AdapterId: {input_requests[index].adapter_id}") + print(f"Output: {output}") + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="Sending multiple serving requests to JetStream Server" + ) + parser.add_argument( + "--server", + type=str, + default="0.0.0.0", + help="Server address.", + ) + parser.add_argument("--port", type=str, default=9000) + parser.add_argument( + "--model", + type=str, + default="no_model", + help=( + "Name of the model like llama-2, llama-3, gemma. (it's just used to" + " label the benchmark, pick the tokenizer, the model config is" + " defined in config_lib, and passed as the server config flag when" + " we run the JetStream server)" + ), + ) + parser.add_argument( + "--total-mock-requests", + type=int, + default=3, + help="The maximum number of mock requests to send for benchmark testing.", + ) + parser.add_argument( + "--tokenizer", + type=str, + default="test", + help=( + "Name or path of the tokenizer. 
(For mock model testing, use the" + " default value)" + ), + ) + + parsed_args = parser.parse_args() + main(parsed_args) From e4d875aea4ef686600ea8eaa0532a6fd45e9d409 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 6 Mar 2025 04:55:05 +0000 Subject: [PATCH 08/22] Refactoring and cleaning of the JetStream server code. --- jetstream/core/adapter_manager.py | 108 ------ .../core/{ => lora}/adapter_tensorstore.py | 231 ++++-------- .../multi_lora_inference_api.py} | 6 +- jetstream/core/orchestrator.py | 177 ++------- jetstream/core/proto/jetstream.proto | 48 --- .../core/proto/jetstream_pb2_grpc_original.py | 209 ----------- jetstream/core/server_lib.py | 27 +- jetstream/tools/decode_multi_requester.py | 337 ------------------ jetstream/tools/llm_gateway_proxy_client.py | 136 ------- ..._v2.py => multi_adapter_service_client.py} | 0 ...r_v2.py => multi_lora_decode_requester.py} | 44 +-- jetstream/tools/requester.py | 4 +- 12 files changed, 114 insertions(+), 1213 deletions(-) delete mode 100644 jetstream/core/adapter_manager.py rename jetstream/core/{ => lora}/adapter_tensorstore.py (57%) rename jetstream/core/{llm_inference_pool_api.py => lora/multi_lora_inference_api.py} (98%) delete mode 100644 jetstream/core/proto/jetstream_pb2_grpc_original.py delete mode 100644 jetstream/tools/decode_multi_requester.py delete mode 100644 jetstream/tools/llm_gateway_proxy_client.py rename jetstream/tools/{llm_gateway_proxy_client_v2.py => multi_adapter_service_client.py} (100%) rename jetstream/tools/{decode_multi_requester_v2.py => multi_lora_decode_requester.py} (81%) diff --git a/jetstream/core/adapter_manager.py b/jetstream/core/adapter_manager.py deleted file mode 100644 index 0ca76dfe..00000000 --- a/jetstream/core/adapter_manager.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Manages the list of fine-tuned adapters loaded on top of the base model for serving. 
-""" - -import logging -import grpc - -from typing import Optional -from jetstream.core import adapter_tensorstore -from jetstream.core import orchestrator -from jetstream.core.proto import jetstream_pb2_grpc -from jetstream.core.proto import jetstream_pb2 - - -def calculate_loading_cost(adapter_path: str): - return 1 - - -class MultiLoraManager(jetstream_pb2_grpc.MultiAdapterManagerServicer): - """Manages the parameters of multiple lora requests and their lifelines.""" - - _driver: orchestrator.Driver - - def __init__(self, driver: orchestrator.Driver): - self._driver = driver - - def ListAdapters( - self, - request: jetstream_pb2.ListAdaptersRequest, - context: Optional[grpc.aio.ServicerContext] = None, - ) -> jetstream_pb2.ListAdaptersResponse: - """ListAdapters all loaded LoRA adapters.""" - - try: - adapters = self._driver.listAdaptersFromTensorstore() - - adapter_infos = [] - for adapter_id, adapter_data in adapters.items(): - if adapter_data.status == "loaded_hbm": - loading_cost = 0 - elif adapter_data.status == "loaded_cpu": - loading_cost = 1 - elif adapter_data.status == "unloaded": - loading_cost = 2 - else: - loading_cost = -1 - - adapter_info = jetstream_pb2.AdapterInfo( - adapter_id=adapter_id, - loading_cost=loading_cost, - size_hbm=adapter_data.size_hbm, - size_cpu=adapter_data.size_cpu, - last_accessed=adapter_data.last_accessed, - status=adapter_data.status) - - adapter_infos.append(adapter_info) - - return jetstream_pb2.ListAdaptersResponse(success=True, adapter_infos=adapter_infos) - except Exception as e: - logging.info(f"Listing of adapters failed with error: {str(e)}") - return jetstream_pb2.ListAdaptersResponse(success=False, error_message=str(e)) - - - def LoadAdapter( - self, - request: jetstream_pb2.LoadAdapterRequest, - context: Optional[grpc.aio.ServicerContext] = None, - ) -> jetstream_pb2.LoadAdapterResponse: - """Load a LoRA adapter as mentioned in the request.""" - - try: - self._driver.loadAdapterToTensorstore(request.adapter_id, request.adapter_path) - - return jetstream_pb2.LoadAdapterResponse(success=True) - except Exception as e: - logging.info(f"Loading of adapter_id={request.adapter_id} failed with error: {str(e)}") - return jetstream_pb2.LoadAdapterResponse(success=False, error_message=str(e)) - - - def UnloadAdapter( - self, - request: jetstream_pb2.UnloadAdapterRequest, - context: Optional[grpc.aio.ServicerContext] = None, - ) -> jetstream_pb2.UnloadAdapterResponse: - """Unload a LoRA adapter as mentioned in the request.""" - - try: - self._driver.unloadAdapterFromTensorstore(request.adapter_id) - return jetstream_pb2.UnloadAdapterResponse(success=True) - except Exception as e: - logging.info(f"Loading of adapter_id={request.adapter_id} failed with error: {str(e)}") - return jetstream_pb2.UnloadAdapterResponse(success=False, error_message=str(e)) - - - diff --git a/jetstream/core/adapter_tensorstore.py b/jetstream/core/lora/adapter_tensorstore.py similarity index 57% rename from jetstream/core/adapter_tensorstore.py rename to jetstream/core/lora/adapter_tensorstore.py index 5cc602a5..306d13d2 100644 --- a/jetstream/core/adapter_tensorstore.py +++ b/jetstream/core/lora/adapter_tensorstore.py @@ -28,15 +28,40 @@ import numpy as np +def _get_size_of_pytree(params): + """Get the size of the PyTree.""" + + params_bytes = jax.tree_util.tree_map(lambda x: x.nbytes, params) + total_bytes = jax.tree_util.tree_reduce(lambda x, y: x + y, params_bytes) + return total_bytes + + +def _as_np_array(params): + """Create a new PyTree with Tensors as np.array.""" + 
+ def convert_if_jnp(leaf): + return np.array(leaf) + + return jax.tree_util.tree_map(convert_if_jnp, params) + + +def _as_jnp_array(params): + """Create a new PyTree with Tensors as jnp.array.""" + + def convert_if_np(leaf): + return jnp.array(leaf) + + return jax.tree_util.tree_map(convert_if_np, params) + + @dataclasses.dataclass class AdapterMetadata: adapter_id: str adapter_path: str status: str = "unloaded" # "loaded_hbm", "loaded_cpu", "loading", "unloading" - size_hbm: int = 0 # Size in HBM (bytes) - size_cpu: int = 0 # Size in CPU RAM (bytes) + size_hbm: int = 0 # Size in HBM (bytes) + size_cpu: int = 0 # Size in CPU RAM (bytes) last_accessed: float = 0.0 # timestamp - # rank: int = 8 config: Dict[str, Any] = None @@ -44,9 +69,9 @@ class AdapterTensorStore: def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int): self.hbm_memory_budget = hbm_memory_budget self.cpu_memory_budget = cpu_memory_budget - self.adapter_registry: Dict[str, AdapterMetadata] = {} # All known adapters + self.adapter_registry: Dict[str, AdapterMetadata] = {} # All known adapters self.loaded_adapters_hbm: Dict[str, jnp.ndarray] = {} # adapter_id -> Unified LoRA params (in HBM) - self.loaded_adapters_cpu: Dict[str, np.ndarray] = {} # adapter_id -> Unified LoRA params (in CPU RAM) + self.loaded_adapters_cpu: Dict[str, np.ndarray] = {} # adapter_id -> Unified LoRA params (in CPU RAM) self.current_hbm_usage: int = 0 self.current_cpu_usage: int = 0 self.running_requests: int = 0 # Number of async tasks which are in "loading" state @@ -63,11 +88,6 @@ def register_adapter(self, adapter_id: str, adapter_path: str, config: Dict[str, config=config) - def _get_size(self, arr: jnp.ndarray | np.ndarray) -> int: - """Calculates the size of a JAX or NumPy array in bytes.""" - # Use asarray to handle both JAX and NumPy arrays consistently - return np.asarray(arr).nbytes - async def _transfer_to_hbm(self, adapter_id: str): """Transfers an adapter from CPU RAM to HBM.""" if adapter_id not in self.loaded_adapters_cpu: @@ -76,13 +96,18 @@ async def _transfer_to_hbm(self, adapter_id: str): async with self.lock: #Acquire lock metadata = self.adapter_registry[adapter_id] + if metadata.status == "loaded_hbm": + return + # Check if we have enough space in HBM; evict if necessary while (self.current_hbm_usage + metadata.size_hbm) > self.hbm_memory_budget: if not self._evict(from_hbm=True): raise RuntimeError("Not enough HBM to transfer adapter, and eviction failed.") - # Move from CPU to HBM - self.loaded_adapters_hbm[adapter_id] = self._as_jnp_array(self.loaded_adapters_cpu[adapter_id]) # Convert to JAX array + # Move from CPU RAM to HBM + self.loaded_adapters_hbm[adapter_id] = _as_jnp_array(self.loaded_adapters_cpu[adapter_id]) # Convert to JAX array + + # TODO(amangu): We can avoid deleting cpu_loaded adapters if RAM is not a concern del self.loaded_adapters_cpu[adapter_id] self.current_cpu_usage -= metadata.size_cpu @@ -101,13 +126,16 @@ async def _transfer_to_cpu(self, adapter_id: str): async with self.lock: metadata = self. adapter_registry[adapter_id] + if metadata.status == "loaded_cpu": + return + # Check if we have enough space in CPU; evict if necessary. 
while (self.current_cpu_usage + metadata.size_cpu) > self.cpu_memory_budget: if not self._evict(from_hbm=False): raise RuntimeError("Not enough CPU RAM to transfer adapter, and eviction failed.") - # Move from HBM to CPU - self.loaded_adapters_cpu[adapter_id] = self._as_np_array(self.loaded_adapters_hbm[adapter_id]) + # Move from HBM to CPU RAM + self.loaded_adapters_cpu[adapter_id] = _as_np_array(self.loaded_adapters_hbm[adapter_id]) del self.loaded_adapters_hbm[adapter_id] self.current_hbm_usage -= metadata.size_hbm @@ -117,29 +145,9 @@ async def _transfer_to_cpu(self, adapter_id: str): metadata.last_accessed = time.time() - def _get_size_of_pytree(self, params): - params_bytes = jax.tree_util.tree_map(lambda x: x.nbytes, params) - total_bytes = jax.tree_util.tree_reduce(lambda x, y: x + y, params_bytes) - return total_bytes - - - def _as_np_array(self, params): - - def convert_if_jnp(leaf): - return np.array(leaf) - - return jax.tree_util.tree_map(convert_if_jnp, params) - - - def _as_jnp_array(self, params): - - def convert_if_np(leaf): - return jnp.array(leaf) - - return jax.tree_util.tree_map(convert_if_np, params) - - async def get_hbm_loaded_adapters(self): + """Returns a comma separated list of adapters loaded into HBM.""" + hbm_loaded_adapters = [] async with self.lock: @@ -150,16 +158,21 @@ async def get_hbm_loaded_adapters(self): return ", ".join(hbm_loaded_adapters) - async def load_adapter(self, adapter_id: str, adapter_weights = None, to_hbm: bool = True): + async def load_adapter( + self, + adapter_id: str, + adapter_weights = None, + to_hbm: bool = True, + force_load: bool = False): """Loads a LoRA adapter's weights, managing HBM and CPU memory.""" + if adapter_id not in self.adapter_registry: raise ValueError(f"Adapter with ID '{adapter_id}' not registered.") metadata = self.adapter_registry[adapter_id] async with self.lock: # Acquire lock for thread safety - #logging.info(f"AMANGU Logs: Lock aquired by loading section of coroutine {asyncio.current_task().get_name()}.") - if metadata.status in ("loaded_hbm", "loaded_cpu"): + if not force_load and metadata.status in ("loaded_hbm", "loaded_cpu"): metadata.last_accessed = time.time() # if already loaded in HBM and we want HBM, or @@ -182,48 +195,29 @@ async def load_adapter(self, adapter_id: str, adapter_weights = None, to_hbm: bo await asyncio.sleep(0.1) # Short sleep to avoid busy-waiting # Make recursive call to load_adapter to copy to device - await self.load_adapter(adapter_id, adapter_weights, to_hbm) + await self.load_adapter(adapter_id, adapter_weights, to_hbm, force_load) return metadata.status = "loading" self.running_requests += 1 - #logging.info(f"AMANGU Logs: Lock released by loading section of coroutine {asyncio.current_task().get_name()}.") # Load the adapter (asynchronous) loop = asyncio.get_running_loop() - try: - - # TODO(amangu): Placeholder for the loading logic. Replace with code to load - # the LoRA weights from the specific path. - - # --- ASYNCHRONOUS LOADING (CRITICAL!) --- - # Use asyncio.to_thread or similar to avoid blocking - - # TODO(amangu): Assumed that load_lora_weights is defined elsewhere - # which returns a dictionary: {"lora_A": ..., "lora_B": ...}. Adapt this part - # based on the actual structure of the loaded LoRA weights. 
+ try: if adapter_weights is None: - adapter_weights = await loop.run_in_executor( - None, - functools.partial(load_lora_weights, metadata.adapter_path)) + raise ValueError("Adapter weights for adapter_id={adapter_id} is None.") async with self.lock: # Critical section for memory management - # Combine lora_a and lora_b to form a unified parameter. - # TODO(amangu): Check if combining and storing is having any optimization. - # unified_lora_params = self._combine_lora_params(lora_weights, metadata.rank) - #logging.info(f"AMANGU Logs: Lock aquired by saving section of coroutine {asyncio.current_task().get_name()}.") - - unified_lora_params = adapter_weights - unified_lora_params_as_jnp_array = self._as_jnp_array(unified_lora_params) - unified_lora_params_as_np_array = self._as_np_array(unified_lora_params) - del unified_lora_params + adapter_weights_as_jnp_array = _as_jnp_array(adapter_weights) + adapter_weights_as_np_array = _as_np_array(adapter_weights) + del adapter_weights # Get size of unified_lora_params when they are saved in HBM as JAX array - adapter_size_hbm = self._get_size_of_pytree(unified_lora_params_as_jnp_array) + adapter_size_hbm = _get_size_of_pytree(adapter_weights_as_jnp_array) # Get size of unified_lora_params when they are saved in CPU RAM as NumPy array - adapter_size_cpu = self._get_size_of_pytree(unified_lora_params_as_np_array) + adapter_size_cpu = _get_size_of_pytree(adapter_weights_as_np_array) metadata.size_hbm = adapter_size_hbm metadata.size_cpu = adapter_size_cpu @@ -241,17 +235,16 @@ async def load_adapter(self, adapter_id: str, adapter_weights = None, to_hbm: bo # Now that we have space (potentially), do the actual loading if to_hbm: - self.loaded_adapters_hbm[adapter_id] = unified_lora_params_as_jnp_array # Convert the PyTree to Jax Array + self.loaded_adapters_hbm[adapter_id] = adapter_weights_as_jnp_array # Convert the PyTree to Jax Array self.current_hbm_usage += adapter_size_hbm metadata.status = "loaded_hbm" else: #to cpu - self.loaded_adapters_cpu[adapter_id] = unified_lora_params_as_np_array # Convert the PyTree to NumPy Array + self.loaded_adapters_cpu[adapter_id] = adapter_weights_as_np_array # Convert the PyTree to NumPy Array self.current_cpu_usage += adapter_size_cpu metadata.status = "loaded_cpu" metadata.last_accessed = time.time() - #logging.info(f"AMANGU Logs: Lock released by saving section of coroutine {asyncio.current_task().get_name()}.") except Exception as e: async with self.lock: @@ -262,80 +255,8 @@ async def load_adapter(self, adapter_id: str, adapter_weights = None, to_hbm: bo self.running_requests -= 1 - def _combine_lora_params(self, lora_weights, rank): - # Create a list to hold the combined LoRA parameters - combined_lora_params = [] - - for i in range(0, len(lora_weights), 2): - lora_a = lora_weights[i] - lora_b = lora_weights[i+1] - - # Reshape and concatenate lora_a and lora_b - # Assuming 'br,rnd->bnd' einsum configuration, where 'b' is batch, - # 'r' is rank, 'n' is num_heads, and 'd' is head_dim - num_heads = lora_a.shape[1] # Get number of heads from lora_a - head_dim = lora_a.shape[2] # Get head dimension from lora_a - - lora_a = jnp.transpose(lora_a, (1, 2, 0)) # (r, n, d) -> (n, d, r) - lora_b_reshaped = jnp.reshape(lora_b, (num_heads, head_dim, rank)) # (n * d, r) -> (n, d, r) - - combined_lora_param = jnp.einsum("ndr,ndr->ndr", lora_a, lora_b_reshaped) - combined_lora_params.append(combined_lora_param) - - # Concatenate the parameters for all layers to form a single unified parameter - unified_lora_params = 
jnp.stack(combined_lora_params, axis=0) - return unified_lora_params - - - def get_stacked_lora_weights(self, lora_ids: jnp.ndarray, to_hbm: bool = True): - """Retrieves the unified LoRA parameters for the given adapter IDs. - Handles HBM/CPU placement. - """ - - # The logic here is crucial. We have `lora_ids`, an array of shape - #(batch_size,), where each element is the ID of the LoRA adapter - # to use for that request in the batch. You need to use this to - # select the appropriate slices from the *unified* LoRA paramters. - - # 1. Get the unified LoRA paramters for the requested IDs. This - # might involve waiting if some adapters are still loading. - - required_adapters = set(lora_ids.tolist()) # Get unique adapter IDs - for adapter_id in required_adapters: - metadata = self.adapter_registry.get(adapter_id) - - if metadata is None: - raise ValueError(f"Adapter with ID '{adapter_id}' not registered.") - - if metadata.status != "loaded_hbm" and metadata.status != "loaded_cpu": - asyncio.run(self.load_adapter(adapter_id, to_hbm)) # Start loading (async) - elif to_hbm and metadata.status == "loaded_cpu": - asyncio.run(self._transfer_to_hbm(adapter_id)) - elif not to_hbm and metadata.status == "loaded_hbm": - asyncio.run(self._transfer_to_cpu(adapter_id)) - - # Wait till all the running requests are completed - while self.running_requests > 0: - time.sleep(0.1) - - # Now all required adapters should be loaded in correct memory (HBM or CPU), get them - if to_hbm: - required_adapters_params = [self.loaded_adapters_hbm[adapter_id] for adapter_id in required_adapters] - else: - required_adapters_params = [self.loaded_adapters_cpu[adapter_id] for adapter_id in required_adapters] - - # Stack the parameters for the required adapters - stacked_params = jax.tree_util.tree_map(lambda *arrs: jnp.stack(arrs), *required_adapters_params) - - # Extract paramters using jnp.take() function for the lora_ids. - retrieved_lora_params = jax.tree_util.tree_map( - lambda arr: jnp.take(arr, lora_ids, axis=0, fill_value=0), - stacked_params) - - return retrieved_lora_params - - def get_lora_config(self, adapter_id): + """Getter for the LoRA adapter config.""" metadata = self.adapter_registry.get(adapter_id) return metadata.config @@ -345,14 +266,6 @@ def get_lora_weights(self, adapter_id, to_hbm: bool = True): Handles HBM/CPU placement. """ - # The logic here is crucial. We have `lora_ids`, an array of shape - #(batch_size,), where each element is the ID of the LoRA adapter - # to use for that request in the batch. You need to use this to - # select the appropriate slices from the *unified* LoRA paramters. - - # 1. Get the unified LoRA paramters for the requested IDs. This - # might involve waiting if some adapters are still loading. - metadata = self.adapter_registry.get(adapter_id) if metadata is None: @@ -393,6 +306,7 @@ async def unload_adapter(self, adapter_id: str): # Wait for the loading to get complete. 
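The removed get_stacked_lora_weights above relied on a leaf-wise stack-and-gather pattern. The following toy sketch (shapes and values invented) documents only what that deleted path computed: stack each adapter's PyTree along a new leading axis, then select one adapter per request with jnp.take.

import jax
import jax.numpy as jnp

# Two toy adapters with identical PyTree structure.
adapters = [
    {"lora_a": jnp.full((4, 2), 1.0), "lora_b": jnp.full((2, 4), 1.0)},
    {"lora_a": jnp.full((4, 2), 2.0), "lora_b": jnp.full((2, 4), 2.0)},
]

# Stack leaf-wise: every leaf gains a leading "adapter" axis of length 2.
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *adapters)

# One adapter index per request in the batch; gather rows per request.
lora_ids = jnp.array([1, 0, 1])
per_request = jax.tree_util.tree_map(
    lambda arr: jnp.take(arr, lora_ids, axis=0), stacked)

print(per_request["lora_a"].shape)  # (3, 4, 2): batch, then original leaf shape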
while metadata.status == "loading": await asyncio.sleep(0.1) + if metadata.status == "loaded_hbm": del self.loaded_adapters_hbm[adapter_id] self.current_hbm_usage -= metadata.size_hbm @@ -402,7 +316,7 @@ async def unload_adapter(self, adapter_id: str): self.current_cpu_usage -= metadata.size_cpu metadata.status = "unloaded" - metadata.last_accessed = 0.0 # Reset last accessed time + metadata.last_accessed = time.time() # Unload time metadata.size_hbm = 0 metadata.size_cpu = 0 @@ -419,11 +333,6 @@ def _evict(self, from_hbm: bool = True) -> bool: lru_adapter_id = None lru_time = float('inf') - if from_hbm: - adapters_dict = self.loaded_adapters_hbm - else: - adapters_dict = self.loaded_adapters_cpu - for adapter_id, metadata in self.adapter_registry.items(): if metadata.status == "loaded_hbm" if from_hbm else metadata.status == "loaded_cpu": if metadata.last_accessed < lru_time: @@ -434,7 +343,13 @@ def _evict(self, from_hbm: bool = True) -> bool: if lru_adapter_id is None: return False - # Unload the LRU adapter - self.unload_adapter(lru_adapter_id) # This is not synchronous, but ONLY within the lock + if from_hbm: + # Instead of completely unloading it, kept it in CPU RAM. + # It can be loaded to HBM if any request demanded it, or + # it will be evicted from CPU when cpu memory budget reached. + self._transfer_to_cpu(lru_adapter_id) + else: + # Unload the LRU adapter + self.unload_adapter(lru_adapter_id) # This is not synchronous, but ONLY within the lock return True diff --git a/jetstream/core/llm_inference_pool_api.py b/jetstream/core/lora/multi_lora_inference_api.py similarity index 98% rename from jetstream/core/llm_inference_pool_api.py rename to jetstream/core/lora/multi_lora_inference_api.py index f88ac5c6..998617f2 100644 --- a/jetstream/core/llm_inference_pool_api.py +++ b/jetstream/core/lora/multi_lora_inference_api.py @@ -20,8 +20,8 @@ import time from typing import Any, AsyncIterator, Optional, Tuple, cast -from jetstream.core import adapter_tensorstore from jetstream.core import orchestrator +from jetstream.core.lora import adapter_tensorstore from jetstream.core.proto import multi_lora_decoding_pb2_grpc from jetstream.core.proto import multi_lora_decoding_pb2 from jetstream.core.utils import async_multifuture @@ -39,8 +39,8 @@ def __init__(self, driver: orchestrator.Driver): def models( self, - request: multi_lora_decoding_pb2.ListAdaptersRequest, - context: Optional[grpc.aio.ServicerContext] = None, + request: multi_lora_decoding_pb2.ListAdaptersRequest, + context: Optional[grpc.aio.ServicerContext] = None, ) -> multi_lora_decoding_pb2.ListAdaptersResponse: """ListAdapters all loaded LoRA adapters.""" diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 130467d9..76adf399 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -92,7 +92,7 @@ import grpc import jax import jax.numpy as jnp -from jetstream.core import adapter_tensorstore +from jetstream.core.lora import adapter_tensorstore from jetstream.core.proto import jetstream_pb2 from jetstream.core.proto import jetstream_pb2_grpc from jetstream.core.utils import async_multifuture @@ -231,6 +231,7 @@ class Driver: # All metrics we want to monitor should be collected with this _metrics_collector: JetstreamMetricsCollector | None = None + # An object to store and manage the adapters _adapter_tensorstore: adapter_tensorstore.AdapterTensorStore | None = None def __init__( @@ -254,8 +255,8 @@ def __init__( generate_params = [] self._adapter_tensorstore = 
adapter_tensorstore.AdapterTensorStore( - hbm_memory_budget=(20 * (1024 ** 3)), - cpu_memory_budget=(100 * (1024 ** 3))) + hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM + cpu_memory_budget=(100 * (1024 ** 3))) # 100 GB RAM logging.info( "Initialising driver with %d prefill engines and %d generate engines.", @@ -314,7 +315,6 @@ def __init__( self._metrics_collector.get_generate_backlog_metric(idx).set_function( functools.partial(float, backlog.qsize()) ) - # Stage 4 # After prefill and generation, ActiveRequests are placed on the # detokenization backlog for tokens to be sent into each ActiveRequest's @@ -459,7 +459,6 @@ def stop(self): ) ) - while any(t.is_alive() for t in self._all_threads): # Empty all backlogs and mark any remaining requests as cancelled. for q in all_backlogs: @@ -510,6 +509,19 @@ def _get_total_requests_waiting_decode(self): return float(total_size) + def _export_lora_request_info(self): + """Export the metric named `lora_request_info`.""" + + adapters_list_str = "" + max_loras = 0 + if self._metrics_collector: + for idx, engine in enumerate(self._generate_engines): + adapters_list_str += asyncio.run(self._adapter_tensorstore.get_hbm_loaded_adapters()) + max_loras += engine.max_concurrent_decodes + + self._metrics_collector.get_lora_request_info_metric(max_loras, + adapters_list_str).set_to_current_time() + def get_total_concurrent_requests(self) -> int: """Gets the total number of concurrent requests the driver can handle.""" # We don't support filling all backlogs at once because it can cause GIL @@ -555,6 +567,7 @@ def _prefill_thread(self, idx: int): """Thread which runs in the background performing prefills.""" logging.info("---------Spinning up prefill thread %d.---------", idx) prefill_engine = self._prefill_engines[idx] + prefill_params = self._prefill_params[idx] metadata = prefill_engine.get_tokenizer() tokenizer = prefill_engine.build_tokenizer(metadata) logging.info("---------Prefill params %d loaded.---------", idx) @@ -567,7 +580,6 @@ def _prefill_thread(self, idx: int): if request is None: break - prefill_params = self._prefill_params[idx] request.metadata.prefill_dequeue_time = time.perf_counter() is_bos = True @@ -584,19 +596,13 @@ def _prefill_thread(self, idx: int): request, tokenizer, is_bos, prefill_engine.max_prefill_length ) - start_time = time.perf_counter() - logging.info(f"AMANGU Log (orchestrator.py): Starting timer for Driver._prefill_thread -> prefill_engine.prefill().") - adapter_id = request.adapter_id - if adapter_id == "": - adapter_id = "base_params" - final_params = None - if adapter_id == "base_params": - final_params = prefill_params[adapter_id] + if adapter_id == "": + final_params = prefill_params else: - final_params = copy.deepcopy(prefill_params["base_params"]) + final_params = copy.deepcopy(prefill_params) lora_params = self._adapter_tensorstore.get_lora_weights(adapter_id) lora_config = self._adapter_tensorstore.get_lora_config(adapter_id) self._prefill_engines[idx].apply_adapter( @@ -612,10 +618,6 @@ def _prefill_thread(self, idx: int): ) del final_params - end_time = time.perf_counter() - elapsed_time = (end_time - start_time) * 1e6 - - logging.info(f"AMANGU Log (orchestrator.py): Time taken for Driver._prefill_thread -> prefill_engine.prefill() is {elapsed_time} Micro-seconds.") request.prefill_result = prefill_result @@ -628,9 +630,6 @@ def _prefill_thread(self, idx: int): block=True, ) - elapsed_time = request.metadata.transfer_enqueue_time - request.metadata.prefill_dequeue_time - logging.info(f"AMANGU Log 
(orchestrator.py): Time taken in whole prefill_thread is {elapsed_time} Seconds.") - # Once prefill is complete, place it on the generation queue and block if # full. my_transfer_backlog.put(request, block=True) @@ -721,19 +720,15 @@ def _generate_thread(self, idx: int): logging.info("---------Spinning up generate thread %d.---------", idx) generate_engine = self._generate_engines[idx] my_slots = self._generate_slots[idx] - logging.info(f"AMANGU: In _generate_thread: my_slots size = {my_slots.qsize()}") - logging.info(f"AMANGU: In _generate_thread: max_concurrent_decodes = {generate_engine.max_concurrent_decodes}") my_generate_backlog = self._generate_backlogs[idx] my_detokenize_backlog = self._generate_detokenize_backlogs[idx] # Keep track of what step tokens were generated at. generate_timestep = 0 - generate_engine.print_stats("Pre-start Generate Thread: Before init_decode_state") # State to store things like running kv cache in. decode_state = generate_engine.init_decode_state() - generate_engine.print_stats("Pre-start Generate Thread: After init_decode_state") - # generate_params = self._generate_params[idx] + generate_params = self._generate_params[idx] logging.info("---------Generate params %d loaded.---------", idx) time_of_last_generate = time.time() @@ -783,10 +778,8 @@ def _generate_thread(self, idx: int): block |= not self._transfer_backlogs[idx].empty() try: new_request = my_generate_backlog.get(block=block, timeout=1.0) - if new_request is None: break - new_request.metadata.generate_dequeue_time = time.perf_counter() if ( self._metrics_collector @@ -828,6 +821,8 @@ def _generate_thread(self, idx: int): decode_state = generate_engine.insert( new_request.prefill_result, decode_state, slot=slot ) + + self._export_lora_request_info() del new_request.prefill_result new_request.generate_timestep_added = generate_timestep @@ -842,32 +837,12 @@ def _generate_thread(self, idx: int): my_slots.qsize() < max_concurrent_decodes ), "At this point we must have some requests inserted into the slots." - generate_params = self._generate_params[idx] - - adapter_id = "base_params" - - start_time = time.perf_counter() - - if self._metrics_collector: - adapters_list_str = asyncio.run(self._adapter_tensorstore.get_hbm_loaded_adapters()) - - max_loras = max_concurrent_decodes - - self._metrics_collector.get_lora_request_info_metric(max_loras, - adapters_list_str).set_to_current_time() - # Now we actually take a generate step on requests in the slots. decode_state, sampled_tokens = generate_engine.generate( - generate_params[adapter_id], decode_state + generate_params, decode_state ) sampled_tokens.copy_to_host_async() - end_time = time.perf_counter() - - elapsed_time = (end_time - start_time) * 1e6 - - logging.info(f"AMANGU Log (orchestrator.py): Time taken to execute Decode.generate_thread -> generate_engine.generate is {elapsed_time} Micro-seconds.") - # Respond to detokenization backpressure. my_detokenize_backlog.put((generate_timestep, sampled_tokens), block=True) generate_timestep += 1 @@ -1004,50 +979,6 @@ def _detokenize_thread(self, is_prefill: bool, idx: int): my_live_requests[slot] = active_request - async def loadAdaptersFromCatalogToTensorStore(self): - logging.info(f"Loading adapters from the catalog file at the start of the server.") - - if not self._prefill_engines and not self._generate_engines: - logging.info(f"There is no MaxEngine object defined. 
So could not load any adapter.") - - engine = None - - if self._prefill_engines: - engine = self._prefill_engines[0] - else: - engine = self._generate_engines[0] - - adapter_params_and_config = engine.load_adapters_from_catalog_file() - - if not adapter_params_and_config: - logging.info("There is no adapter loaded from the catelog file.") - - tasks = [] - for key, value in adapter_params_and_config.items(): - adapter_id = key - adapter_config = value["config"] - adapter_params_pytree =value["params"] - - try: - self._adapter_tensorstore.register_adapter( - adapter_id, - adapter_config["adapter_path"], - adapter_config) - - except ValueError as e: - logging.info(f"Registration failed with error: {str(e)}") - - task = asyncio.create_task(self._adapter_tensorstore.load_adapter(adapter_id, adapter_params_pytree, False)) - task.set_name(f"Task:loading-adapter-{adapter_id}") - tasks.append(task) - - await asyncio.gather(*tasks) - - logging.info(f"All adapters from catalog file loaded successfully.") - - engine.print_stats("After loading all adapters from catelog.") - - def loadAdapterToTensorstore( self, adapter_id, @@ -1110,64 +1041,6 @@ def listAdaptersFromTensorstore(self): return self._adapter_tensorstore.adapter_registry - def loadAndApplyAdapter( - self, - adapter_id, - adapter_config_path, - adapter_weights_path): - logging.info(f"Loading and applying fine-tuning adapter to base weights") - - for index, params in enumerate(self._prefill_params): - if adapter_id not in params: - params[adapter_id] = copy.deepcopy(params["base_params"]) - self._prefill_engines[index].load_and_apply_adapter(params[adapter_id], - adapter_config_path, - adapter_weights_path) - else: - logging.info(f"Adapter={adapter_id} is already present in the prefill_params.") - - for index, params in enumerate(self._generate_params): - if adapter_id not in params: - params[adapter_id] = copy.deepcopy(params["base_params"]) - self._generate_engines[index].load_and_apply_adapter(params[adapter_id], - adapter_config_path, - adapter_weights_path) - else: - logging.info(f"Adapter={adapter_id} is already present in the generate_params.") - - def unloadAdapter( - self, - adapter_id): - logging.info(f"Unloading the adapter with adapter_id={adapter_id}") - - for params in self._prefill_params: - if adapter_id in params: - del params[adapter_id] - logging.info(f"Successfully deleted Adapter={adapter_id} from the prefill_params.") - else: - logging.info(f"Adapter={adapter_id} is not there in the prefill_params.") - - for params in self._generate_params: - if adapter_id in params: - del params[adapter_id] - logging.info(f"Successfully deleted Adapter={adapter_id} from the generate_params.") - else: - logging.info(f"Adapter={adapter_id} is not there in the generate_params.") - - def mayBeListLoadedAdapters(self): - logging.info(f"Listing loaded adapters:") - - loaded_adapters_in_prefill = [] - for params in self._prefill_params: - loaded_adapters_in_prefill.extend(list(params.keys())) - logging.info(f"In prefill_params: {loaded_adapters_in_prefill}") - - loaded_adapters_in_generate = [] - for params in self._generate_params: - loaded_adapters_in_generate.extend(list(params.keys())) - logging.info(f"In generate_params: {loaded_adapters_in_generate}") - - class LLMOrchestrator(jetstream_pb2_grpc.OrchestratorServicer): """Coordinates a set of prefill and generate slices for LLM decoding.""" diff --git a/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto index d3427be6..9516b1c3 100644 --- 
a/jetstream/core/proto/jetstream.proto +++ b/jetstream/core/proto/jetstream.proto @@ -93,53 +93,5 @@ message HealthCheckRequest {} message HealthCheckResponse { // Denotes whether the model server is live bool is_live = 1; -} - -service MultiAdapterManager { - // Lists all the currently loaded LoRA adapters - rpc ListAdapters (ListAdaptersRequest) returns (ListAdaptersResponse) {} - - // Loads a new LoRA adapter. - rpc LoadAdapter (LoadAdapterRequest) returns (LoadAdapterResponse) {} - // Unloads a LoRA adapter - rpc UnloadAdapter (UnloadAdapterRequest) returns (UnloadAdapterResponse) {} } - -message ListAdaptersRequest {} - -message ListAdaptersResponse { - bool success = 1; // True if successful, False otherwise - string error_message = 2; // Error message if listing the adapters - repeated AdapterInfo adapter_infos = 3; // List of information about loaded adapters. -} - -// Information about a single loaded LoRA adapter -message AdapterInfo { - string adapter_id = 1; - int64 loading_cost = 2; - int64 size_hbm = 3; - int64 size_cpu = 4; - float last_accessed = 5; - string status = 6; -} - -message LoadAdapterRequest { - string adapter_id = 1; // Unique ID for the adapter - string adapter_path = 2; // Path to the LoRA adapter (config & weights) -} - -message LoadAdapterResponse { - bool success = 1; // True if successful, false otherwise - string error_message = 2; // Error message if loading failed -} - -message UnloadAdapterRequest { - string adapter_id = 1; // ID of the adapter to unload -} - -message UnloadAdapterResponse { - bool success = 1; // True if successful, false otherwise - string error_message = 2; // Error message if unloading failed -} - diff --git a/jetstream/core/proto/jetstream_pb2_grpc_original.py b/jetstream/core/proto/jetstream_pb2_grpc_original.py deleted file mode 100644 index 5de13a1e..00000000 --- a/jetstream/core/proto/jetstream_pb2_grpc_original.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - -from jetstream.core.proto import jetstream_pb2 as jetstream_dot_core_dot_proto_dot_jetstream__pb2 - - -class OrchestratorStub(object): - """TODO: Merge this with main JetStream core once we settle on an API.""" - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. 
- """ - self.Decode = channel.unary_stream( - "/jetstream_proto.Orchestrator/Decode", - request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString, - response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString, - ) - self.HealthCheck = channel.unary_unary( - "/jetstream_proto.Orchestrator/HealthCheck", - request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString, - response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString, - ) - - -class OrchestratorServicer(object): - """TODO: Merge this with main JetStream core once we settle on an API.""" - - def Decode(self, request, context): - """Query LLM to generate text or tokens.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - def HealthCheck(self, request, context): - """Checks if the model server is live.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - -def add_OrchestratorServicer_to_server(servicer, server): - rpc_method_handlers = { - "Decode": grpc.unary_stream_rpc_method_handler( - servicer.Decode, - request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.FromString, - response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.SerializeToString, - ), - "HealthCheck": grpc.unary_unary_rpc_method_handler( - servicer.HealthCheck, - request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.FromString, - response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - "jetstream_proto.Orchestrator", rpc_method_handlers - ) - server.add_generic_rpc_handlers((generic_handler,)) - - -# This class is part of an EXPERIMENTAL API. -class Orchestrator(object): - """TODO: Merge this with main JetStream core once we settle on an API.""" - - @staticmethod - def Decode( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_stream( - request, - target, - "/jetstream_proto.Orchestrator/Decode", - jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString, - jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) - - @staticmethod - def HealthCheck( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, - target, - "/jetstream_proto.Orchestrator/HealthCheck", - jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString, - jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) - - -class MultiAdapterManagerStub(object): - """MultiAdapterManagerStub.""" - - def __init__(self, channel): - """Constructor. 
- - Args: - channel: A grpc.Channel. - """ - self.ListAdapters = channel.unary_unary( - '/jetstream_proto.MultiAdapterManager/ListAdapters', - request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersRequest.SerializeToString, - response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersResponse.FromString, - _registered_method=True) - self.LoadAdapter = channel.unary_unary( - '/jetstream_proto.MultiAdapterManager/LoadAdapter', - request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterRequest.SerializeToString, - response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterResponse.FromString, - _registered_method=True) - self.UnloadAdapter = channel.unary_unary( - '/jetstream_proto.MultiAdapterManager/UnloadAdapter', - request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterRequest.SerializeToString, - response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterResponse.FromString, - _registered_method=True) - - -class MultiAdapterManagerServicer(object): - """TODO: Merge this with main JetStream core once we settle on an API.""" - - def ListAdapters(self, request, context): - """Lists all the currently loaded LoRA adapters.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - def LoadAdapter(self, request, context): - """Check the feasibility and load the new LoRA adapter.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - def UnloadAdapter(self, request, context): - """Unload a LoRA adapter.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - -def add_MultiAdapterManagerServicer_to_server(servicer, server): - rpc_method_handlers = { - "ListAdapters": grpc.unary_unary_rpc_method_handler( - servicer.ListAdapters, - request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersRequest.FromString, - response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersResponse.SerializeToString, - ), - "LoadAdapter": grpc.unary_unary_rpc_method_handler( - servicer.LoadAdapter, - request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterRequest.FromString, - response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterResponse.SerializeToString, - ), - "UnloadAdapter": grpc.unary_unary_rpc_method_handler( - servicer.UnloadAdapter, - request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterRequest.FromString, - response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - "jetstream_proto.MultiAdapterManager", rpc_method_handlers - ) - server.add_generic_rpc_handlers((generic_handler,)) diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 18dae982..1cf91ea2 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -34,6 +34,7 @@ from jetstream.core import config_lib from jetstream.core import orchestrator from jetstream.core.metrics.prometheus import JetstreamMetricsCollector +from jetstream.core.proto import jetstream_pb2_grpc from jetstream.engine import warmup_utils, engine_api from 
prometheus_client import start_http_server @@ -66,25 +67,19 @@ async def do_init(): asyncio.run_coroutine_threadsafe(do_init(), loop=self._loop).result() self._driver = driver + jetstream_pb2_grpc.add_OrchestratorServicer_to_server( + orchestrator.LLMOrchestrator(driver=self._driver), self._grpc_server + ) if enable_llm_inference_pool: - module_name = "jetstream.core.llm_inference_pool_api" - llm_inference_pool = importlib.import_module(module_name) + module_name = "jetstream.core.lora.multi_lora_inference_api" + multi_lora_inference = importlib.import_module(module_name) module_name = "jetstream.core.proto.multi_lora_decoding_pb2_grpc" multi_lora_decoding_pb2_grpc = importlib.import_module(module_name) - asyncio.run(self._driver.loadAdaptersFromCatalogToTensorStore()) - multi_lora_decoding_pb2_grpc.add_v1Servicer_to_server( - llm_inference_pool.MultiLoraManager(driver=self._driver), self._grpc_server - ) - else: - module_name = "jetstream.core.proto.jetstream_pb2_grpc" - jetstream_pb2_grpc = importlib.import_module(module_name) - - jetstream_pb2_grpc.add_OrchestratorServicer_to_server( - orchestrator.LLMOrchestrator(driver=self._driver), self._grpc_server + multi_lora_inference.MultiLoraManager(driver=self._driver), self._grpc_server ) self._grpc_server.add_secure_port(f"{_HOST}:{port}", credentials) @@ -136,9 +131,9 @@ def create_driver( An orchestrator driver. """ engines = config_lib.get_engines(config, devices=devices) - prefill_params = [{"base_params": pe.load_params()} for pe in engines.prefill_engines] - generate_params = [{"base_params": ge.load_params()} for ge in engines.generate_engines] - shared_params = [{"base_params": ie.load_params()} for ie in engines.interleaved_engines] + prefill_params = [pe.load_params() for pe in engines.prefill_engines] + generate_params = [ge.load_params() for ge in engines.generate_engines] + shared_params = [ie.load_params() for ie in engines.interleaved_engines] logging.info("Loaded all weights.") interleaved_mode = ( @@ -178,8 +173,6 @@ def create_driver( traceback.print_exc() os.kill(os.getpid(), signal.SIGKILL) - logging.info("AMANGU: Going to create the drivers.") - return orchestrator.Driver( prefill_engines=prefill_engines, generate_engines=generate_engines, diff --git a/jetstream/tools/decode_multi_requester.py b/jetstream/tools/decode_multi_requester.py deleted file mode 100644 index 3337c87a..00000000 --- a/jetstream/tools/decode_multi_requester.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Benchmark JetStream online serving. - -On the server side, run one of the following commands: - * For real server, you need to pass correct server config (include the - model config that being passed into your engine impl) to the command - below. Refer to config_lib.py and implementations/mock/config.py for - config impl detail. 
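To summarize the server_lib change above in one place, here is a condensed sketch (driver construction and credentials omitted, and the real code's lazy importlib imports shown as plain imports) of how the servicers are now registered: the core Orchestrator service unconditionally, and the multi-LoRA service only when the inference pool is enabled.

import grpc
from jetstream.core import orchestrator
from jetstream.core.proto import jetstream_pb2_grpc


def register_services(server: grpc.aio.Server, driver, enable_llm_inference_pool: bool):
  """Attach gRPC servicers to an existing grpc.aio server."""
  jetstream_pb2_grpc.add_OrchestratorServicer_to_server(
      orchestrator.LLMOrchestrator(driver=driver), server
  )
  if enable_llm_inference_pool:
    from jetstream.core.lora import multi_lora_inference_api
    from jetstream.core.proto import multi_lora_decoding_pb2_grpc

    multi_lora_decoding_pb2_grpc.add_v1Servicer_to_server(
        multi_lora_inference_api.MultiLoraManager(driver=driver), server
    )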
- - (run with real server) - python -m jetstream.core.implementations..server \ - --config - - (run with mock server) - python -m jetstream.core.implementations.mock.server - -On the client side, run: - * For real server and shareGPT dataset, you need to pass the tokenizer, - server config, and dataset flags to the command below, and make some - changes to the tokenizer logic in the benchmark script (get_tokenizer - and sample_requests func) to use your tokenizer correctly. - * Add `--save-result` flag to save the benchmark result to a json file in - current folder. - * You can also add `--run_eval true` if you want to calculate ROUGE score - on the predicted outputs. - - (run with real model and engines) - python -m benchmarks.benchmark_serving \ - --tokenizer \ - --dataset \ - --dataset-path \ - --request-rate - - (run with mock) - python -m benchmarks.benchmark_serving \ - --request-rate 1 - -e2e example: -python3 benchmark_serving.py \ - --tokenizer /home/{username}/maxtext/assets/tokenizer \ - --num-prompts 100 \ - --dataset sharegpt \ - --dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json - -""" - - -import argparse -import asyncio -from dataclasses import dataclass, field -from datetime import datetime -import json -import random -import time -from typing import Any, AsyncGenerator, Optional -import os - - -import grpc -from jetstream.core.proto import jetstream_pb2 -from jetstream.core.proto import jetstream_pb2_grpc -from jetstream.engine.token_utils import load_vocab -from jetstream.external_tokenizers.llama3 import llama3_tokenizer -import numpy as np - - -def str2bool(v: str) -> bool: - """Convert a string of truth to True or False. - - Args: - - v (str): - - True values are 'y', 'yes', 't', 'true', and '1'; - - False values are 'n', 'no', 'f', 'false', and '0'. - - Returns: - bool: True or False - - Raises: - ValueError if v is anything else. - """ - v = v.lower() - true_values = ["y", "yes", "t", "true", "1"] - false_values = ["n", "no", "f", "false", "0"] - if v in true_values: - return True - elif v in false_values: - return False - else: - raise ValueError(f"Invalid value '{v}'!") - - -@dataclass -class BenchmarkMetrics: - """Data class to store benchmark metrics.""" - - completed: int - total_input: int - total_output: int - request_throughput: float - input_throughput: float - output_throughput: float - mean_ttft_ms: float - median_ttft_ms: float - p99_ttft_ms: float - mean_tpot_ms: float - median_tpot_ms: float - p99_tpot_ms: float - - -@dataclass -class InputRequest: - prompt: str = "" - output: str = "" - output_len: int = 0 - sample_idx: int = -1 - - -@dataclass -class RequestFuncOutput: - input_request: Optional[InputRequest] = None - generated_token_list: list[str] = field(default_factory=list) - generated_text: str = "" - success: bool = False - latency: float = 0 - ttft: float = 0 - - # Flatten the structure and return only the necessary results - def to_dict(self): - return { - "prompt": self.input_request.prompt, - "original_output": self.input_request.output, - "generated_text": self.generated_text, - "success": self.success, - "latency": self.latency, - "sample_idx": self.input_request.sample_idx, - } - - -def get_tokenizer( - model_id: str, - tokenizer_name: str, -) -> Any: - """Return a tokenizer or a tokenizer placholder.""" - if tokenizer_name == "test": - print("Using test tokenizer") - return "test" - elif model_id == "llama-3": - # Llama 3 uses a tiktoken tokenizer. 
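The request helper in this deleted tool (continued below) timed time-to-first-token as the gap between sending the request and receiving the first streamed chunk, and total latency as the gap to stream exhaustion. A toy, dependency-free sketch of that bookkeeping, with a fake async generator standing in for the gRPC stub's Decode stream:

import asyncio
import time


async def consume_stream(stream):
  """Collect streamed token chunks and report (tokens, ttft_s, latency_s)."""
  ttft = 0.0
  tokens = []
  start = time.perf_counter()
  async for chunk in stream:
    if ttft == 0.0:
      ttft = time.perf_counter() - start  # first chunk marks time-to-first-token
    tokens.extend(chunk)
  return tokens, ttft, time.perf_counter() - start


async def fake_stream():
  for chunk in ([1, 2], [3], [4]):
    await asyncio.sleep(0.01)  # stands in for network + decode time
    yield chunk


print(asyncio.run(consume_stream(fake_stream())))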
- print(f"Using llama-3 tokenizer: {tokenizer_name}") - return llama3_tokenizer.Tokenizer(tokenizer_name) - else: - # Use JetStream tokenizer util. It's using the sentencepiece wrapper in - # seqio library. - print(f"Using tokenizer: {tokenizer_name}") - vocab = load_vocab(tokenizer_name) - return vocab.tokenizer - - -async def grpc_async_request( - api_url: str, request: Any -) -> tuple[list[str], float, float]: - """Send grpc synchronous request since the current grpc server is sync.""" - options = [("grpc.keepalive_timeout_ms", 10000)] - async with grpc.aio.insecure_channel(api_url, options=options) as channel: - stub = jetstream_pb2_grpc.OrchestratorStub(channel) - print("Making request") - ttft = 0 - token_list = [] - request_start_time = time.perf_counter() - response = stub.Decode(request) - async for resp in response: - if ttft == 0: - ttft = time.perf_counter() - request_start_time - token_list.extend(resp.stream_content.samples[0].token_ids) - latency = time.perf_counter() - request_start_time - return token_list, ttft, latency - - -async def send_request( - api_url: str, - tokenizer: Any, - input_request: InputRequest, -) -> RequestFuncOutput: - """Send the request to JetStream server.""" - # Tokenization on client side following MLPerf standard. - token_ids = tokenizer.encode(input_request.prompt) - request = jetstream_pb2.DecodeRequest( - token_content=jetstream_pb2.DecodeRequest.TokenContent( - token_ids=token_ids - ), - max_tokens=input_request.output_len, - adapter_id=input_request.adapter_id, - ) - output = RequestFuncOutput() - output.input_request = input_request - generated_token_list, ttft, latency = await grpc_async_request( - api_url, request - ) - output.ttft = ttft - output.latency = latency - output.generated_token_list = generated_token_list - # generated_token_list is a list of token ids, decode it to generated_text. - output.generated_text = tokenizer.decode(generated_token_list) - output.success = True - return output - - -async def get_request( - input_requests: list[InputRequest], -) -> AsyncGenerator[InputRequest, None]: - input_requests = iter(input_requests) - - for request in input_requests: - yield request - - -async def send_multi_request( - api_url: str, - tokenizer: Any, - input_requests: list[InputRequest], -): - """Send multiple LoRA adapter requests.""" - tasks = [] - async for request in get_request(input_requests): - tasks.append( - asyncio.create_task( - send_request( - api_url=api_url, - tokenizer=tokenizer, - input_request=request, - ) - ) - ) - outputs = await asyncio.gather(*tasks) - - return outputs - - -def mock_adapter_requests(total_mock_requests: int): - """Generates a list of mock requests containing mock data.""" - data = [] - for index in range(total_mock_requests): - request = InputRequest() - request.prompt = f"22 year old" - if index == 0: - request.adapter_id = "" - else: - request.adapter_id = f"test_lora_{index}" - request.output_len = 3 - data.append(request) - return data - - -def main(args: argparse.Namespace): - print(args) - - model_id = args.model - tokenizer_id = args.tokenizer - - api_url = f"{args.server}:{args.port}" - - tokenizer = get_tokenizer(model_id, tokenizer_id) - input_requests = mock_adapter_requests( - args.total_mock_requests - ) # e.g. 
[("AB", 2, "AB", 3)] - - request_outputs = asyncio.run( - send_multi_request( - api_url=api_url, - tokenizer=tokenizer, - input_requests=input_requests, - ) - ) - - output = [output.to_dict() for output in request_outputs] - - # Process output - for index, output in enumerate(output): - print(f"Prompt: {input_requests[index].prompt}") - print(f"AdapterId: {input_requests[index].adapter_id}") - print(f"Output: {output}") - - -if __name__ == "__main__": - - parser = argparse.ArgumentParser( - description="Sending multiple serving requests to JetStream Server" - ) - parser.add_argument( - "--server", - type=str, - default="0.0.0.0", - help="Server address.", - ) - parser.add_argument("--port", type=str, default=9000) - parser.add_argument( - "--model", - type=str, - default="no_model", - help=( - "Name of the model like llama-2, llama-3, gemma. (it's just used to" - " label the benchmark, pick the tokenizer, the model config is" - " defined in config_lib, and passed as the server config flag when" - " we run the JetStream server)" - ), - ) - parser.add_argument( - "--total-mock-requests", - type=int, - default=3, - help="The maximum number of mock requests to send for benchmark testing.", - ) - parser.add_argument( - "--tokenizer", - type=str, - default="test", - help=( - "Name or path of the tokenizer. (For mock model testing, use the" - " default value)" - ), - ) - - parsed_args = parser.parse_args() - main(parsed_args) diff --git a/jetstream/tools/llm_gateway_proxy_client.py b/jetstream/tools/llm_gateway_proxy_client.py deleted file mode 100644 index d607c544..00000000 --- a/jetstream/tools/llm_gateway_proxy_client.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A test request.""" - -from typing import Sequence - -from absl import app -from absl import flags -import grpc -from jetstream.core.proto import jetstream_pb2 -from jetstream.core.proto import jetstream_pb2_grpc -from jetstream.engine.token_utils import load_vocab - - -_SERVER = flags.DEFINE_string("server", "0.0.0.0", "server address") -_PORT = flags.DEFINE_string("port", "9000", "port to ping") -#_TEXT = flags.DEFINE_string("text", "My dog is cute", "The message") -_TEXT = flags.DEFINE_string("text", "22 year old", "The message") -_MAX_TOKENS = flags.DEFINE_integer( - "max_tokens", 3, "Maximum number of output/decode tokens of a sequence" -) - -_ADAPTER_ID = flags.DEFINE_string( - "adapter_id", - None, - "Id of the fine-tuned adapter to be loaded on top of the base model.", - required=False, -) - -_ADAPTER_PATH = flags.DEFINE_string( - "adapter_path", - None, - "Path of the fine-tuned adapter to be loaded from.", - required=False, -) - -_TEST_API_NAME = flags.DEFINE_string( - "test_api_name", - None, - "Name of the JetStream API to call.", - required=True, -) - - -def main(argv: Sequence[str]) -> None: - del argv - # Note: Uses insecure_channel only for local testing. Please add grpc - # credentials for Production. 
- address = f"{_SERVER.value}:{_PORT.value}" - with grpc.insecure_channel(address) as channel: - grpc.channel_ready_future(channel).result() - stub = jetstream_pb2_grpc.MultiAdapterManagerStub(channel) - print(f"Sending request to: {address}") - - if _TEST_API_NAME.value == "load_adapter": - print(f"Calling the JetStream/MultiAdapterManager/LoadAdapter.") - - adapter_id=_ADAPTER_ID.value - adapter_path=_ADAPTER_PATH.value - - if adapter_id == None or adapter_path == None: - print(f"For `load_adapter` API call, `adapter_id` and `adapter_path` must be passed.") - return - - request = jetstream_pb2.LoadAdapterRequest( - adapter_id=adapter_id, - adapter_path=adapter_path - ) - - response = stub.LoadAdapter(request) - - if response.success is True: - print(f"Adapter={adapter_id} is loaded successfully.") - else: - print(f"Adapter={adapter_id} loading failed with error={response.error_message}") - - elif _TEST_API_NAME.value == "unload_adapter": - print(f"Calling the JetStream/MultiAdapterManager/UnloadAdapter.") - - adapter_id=_ADAPTER_ID.value - - if adapter_id == None: - print(f"For `unload_adapter` API call, `adapter_id` must be passed.") - return - - request = jetstream_pb2.UnloadAdapterRequest( - adapter_id=adapter_id, - ) - - response = stub.UnloadAdapter(request) - - if response.success is True: - print(f"Adapter={adapter_id} is unloaded successfully.") - else: - print(f"Adapter={adapter_id} unloading failed with error={response.error_message}") - - elif _TEST_API_NAME.value == "list_adapters": - print(f"Calling the JetStream/MultiAdapterManager/ListAdapters.") - - request = jetstream_pb2.ListAdaptersRequest() - - response = stub.ListAdapters(request) - - if response.success is True: - print(f"`ListAdapter` call responded successfully. Here is the list of adapters loaded on server:") - for adapter_info in response.adapter_infos: - print(f"adapter_id={adapter_info.adapter_id}, loading_cost={adapter_info.loading_cost}, size_hbm={adapter_info.size_hbm} bytes, size_cpu={adapter_info.size_cpu} Bytes, last_accessed={adapter_info.last_accessed}, status={adapter_info.status}") - else: - print(f"`ListAdapter` call failed with error={response.error_message}") - - elif _TEST_API_NAME.value == None: - print(f"`test_api_name` flag is not set. So exiting.") - return - - else: - print(f"API={_TEST_API_NAME.value} is not implemented yet. So exiting.") - return - - - print(f"API calls ended.") - - -if __name__ == "__main__": - app.run(main) diff --git a/jetstream/tools/llm_gateway_proxy_client_v2.py b/jetstream/tools/multi_adapter_service_client.py similarity index 100% rename from jetstream/tools/llm_gateway_proxy_client_v2.py rename to jetstream/tools/multi_adapter_service_client.py diff --git a/jetstream/tools/decode_multi_requester_v2.py b/jetstream/tools/multi_lora_decode_requester.py similarity index 81% rename from jetstream/tools/decode_multi_requester_v2.py rename to jetstream/tools/multi_lora_decode_requester.py index 1479807c..6e859204 100644 --- a/jetstream/tools/decode_multi_requester_v2.py +++ b/jetstream/tools/multi_lora_decode_requester.py @@ -12,49 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Benchmark JetStream online serving. - -On the server side, run one of the following commands: - * For real server, you need to pass correct server config (include the - model config that being passed into your engine impl) to the command - below. Refer to config_lib.py and implementations/mock/config.py for - config impl detail. 
- - (run with real server) - python -m jetstream.core.implementations..server \ - --config - - (run with mock server) - python -m jetstream.core.implementations.mock.server - -On the client side, run: - * For real server and shareGPT dataset, you need to pass the tokenizer, - server config, and dataset flags to the command below, and make some - changes to the tokenizer logic in the benchmark script (get_tokenizer - and sample_requests func) to use your tokenizer correctly. - * Add `--save-result` flag to save the benchmark result to a json file in - current folder. - * You can also add `--run_eval true` if you want to calculate ROUGE score - on the predicted outputs. - - (run with real model and engines) - python -m benchmarks.benchmark_serving \ - --tokenizer \ - --dataset \ - --dataset-path \ - --request-rate - - (run with mock) - python -m benchmarks.benchmark_serving \ - --request-rate 1 - -e2e example: -python3 benchmark_serving.py \ - --tokenizer /home/{username}/maxtext/assets/tokenizer \ - --num-prompts 100 \ - --dataset sharegpt \ - --dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json - +"""Decoding multiple LoRA requests via JetStream online serving. """ diff --git a/jetstream/tools/requester.py b/jetstream/tools/requester.py index d81cdfa9..e4263f38 100644 --- a/jetstream/tools/requester.py +++ b/jetstream/tools/requester.py @@ -26,8 +26,7 @@ _SERVER = flags.DEFINE_string("server", "0.0.0.0", "server address") _PORT = flags.DEFINE_string("port", "9000", "port to ping") -#_TEXT = flags.DEFINE_string("text", "My dog is cute", "The message") -_TEXT = flags.DEFINE_string("text", "22 year old", "The message") +_TEXT = flags.DEFINE_string("text", "My dog is cute", "The message") _MAX_TOKENS = flags.DEFINE_integer( "max_tokens", 3, "Maximum number of output/decode tokens of a sequence" ) @@ -88,6 +87,7 @@ def main(argv: Sequence[str]) -> None: token_ids=token_ids ), max_tokens=_MAX_TOKENS.value, + adapter_id=_ADAPTER_ID.value, ) else: request = jetstream_pb2.DecodeRequest( From eb74d860c4ff09b4bf8116380c68d6fa692f52c2 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 6 Mar 2025 06:00:51 +0000 Subject: [PATCH 09/22] Refactoring part-2. 
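For reference, after the requester change in the previous patch a client selects an adapter per request through the adapter_id field on DecodeRequest, with an empty string falling back to the base model. A small usage sketch (the adapter name is invented):

from jetstream.core.proto import jetstream_pb2

request = jetstream_pb2.DecodeRequest(
    text_content=jetstream_pb2.DecodeRequest.TextContent(text="My dog is cute"),
    max_tokens=3,
    adapter_id="test_lora_1",  # "" would route the request to the base model
)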
--- jetstream/core/orchestrator.py | 11 ++- jetstream/core/proto/jetstream.proto | 1 - jetstream/core/proto/jetstream_pb2.py | 22 +----- jetstream/core/proto/jetstream_pb2_grpc.py | 71 ------------------- .../proto/multi_lora_decoding_pb2_grpc.py | 2 +- jetstream/core/server_lib.py | 1 - 6 files changed, 8 insertions(+), 100 deletions(-) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 76adf399..ed513d53 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -516,9 +516,10 @@ def _export_lora_request_info(self): max_loras = 0 if self._metrics_collector: for idx, engine in enumerate(self._generate_engines): - adapters_list_str += asyncio.run(self._adapter_tensorstore.get_hbm_loaded_adapters()) max_loras += engine.max_concurrent_decodes + adapters_list_str += asyncio.run(self._adapter_tensorstore.get_hbm_loaded_adapters()) + self._metrics_collector.get_lora_request_info_metric(max_loras, adapters_list_str).set_to_current_time() @@ -580,7 +581,6 @@ def _prefill_thread(self, idx: int): if request is None: break - request.metadata.prefill_dequeue_time = time.perf_counter() is_bos = True logging.info( @@ -616,7 +616,6 @@ def _prefill_thread(self, idx: int): padded_tokens=padded_tokens, true_length=true_length, ) - del final_params request.prefill_result = prefill_result @@ -705,7 +704,6 @@ def _transfer_thread(self, idx: int): new_request.metadata.generate_enqueue_time = time.perf_counter() self._generate_backlogs[target_idx].put(new_request, block=True) - elapsed_time = (new_request.metadata.generate_enqueue_time - new_request.metadata.transfer_dequeue_time) * 1e6 logging.info( "Successfully transferred prefill " "from prefill engine %d to generate engine %d " @@ -821,7 +819,8 @@ def _generate_thread(self, idx: int): decode_state = generate_engine.insert( new_request.prefill_result, decode_state, slot=slot ) - + + # Export the lora_request_info metric self._export_lora_request_info() del new_request.prefill_result @@ -1123,7 +1122,6 @@ async def Decode( # pylint: disable=invalid-overridden-method request: jetstream_pb2.DecodeRequest, context: Optional[grpc.aio.ServicerContext] = None, ) -> AsyncIterator[jetstream_pb2.DecodeResponse]: - """Decode.""" if context is None: logging.warning( @@ -1134,7 +1132,6 @@ async def Decode( # pylint: disable=invalid-overridden-method return_channel = async_multifuture.AsyncMultifuture() if context: context.add_done_callback(return_channel.cancel) - prefill_content, is_client_side_tokenization = self._get_prefill_content( request ) diff --git a/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto index 9516b1c3..1f85e8e6 100644 --- a/jetstream/core/proto/jetstream.proto +++ b/jetstream/core/proto/jetstream.proto @@ -93,5 +93,4 @@ message HealthCheckRequest {} message HealthCheckResponse { // Denotes whether the model server is live bool is_live = 1; - } diff --git a/jetstream/core/proto/jetstream_pb2.py b/jetstream/core/proto/jetstream_pb2.py index 71d4af40..4fdb3dd6 100644 --- a/jetstream/core/proto/jetstream_pb2.py +++ b/jetstream/core/proto/jetstream_pb2.py @@ -29,7 +29,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fjetstream.proto\x12\x0fjetstream_proto\"\x90\x03\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 
\x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x12\x12\n\nadapter_id\x18\x08 \x01(\t\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04\"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02\"\x14\n\x12HealthCheckRequest\"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\"\x15\n\x13ListAdaptersRequest\"s\n\x14ListAdaptersResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\x12\x33\n\radapter_infos\x18\x03 \x03(\x0b\x32\x1c.jetstream_proto.AdapterInfo\"\x82\x01\n\x0b\x41\x64\x61pterInfo\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0cloading_cost\x18\x02 \x01(\x03\x12\x10\n\x08size_hbm\x18\x03 \x01(\x03\x12\x10\n\x08size_cpu\x18\x04 \x01(\x03\x12\x15\n\rlast_accessed\x18\x05 \x01(\x02\x12\x0e\n\x06status\x18\x06 \x01(\t\">\n\x12LoadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\"=\n\x13LoadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\"*\n\x14UnloadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"?\n\x15UnloadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t2\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse\"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse\"\x00\x32\xb2\x02\n\x13MultiAdapterManager\x12]\n\x0cListAdapters\x12$.jetstream_proto.ListAdaptersRequest\x1a%.jetstream_proto.ListAdaptersResponse\"\x00\x12Z\n\x0bLoadAdapter\x12#.jetstream_proto.LoadAdapterRequest\x1a$.jetstream_proto.LoadAdapterResponse\"\x00\x12`\n\rUnloadAdapter\x12%.jetstream_proto.UnloadAdapterRequest\x1a&.jetstream_proto.UnloadAdapterResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fjetstream.proto\x12\x0fjetstream_proto\"\x90\x03\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 \x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x12\x12\n\nadapter_id\x18\x08 \x01(\t\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04\"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 
\x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02\"\x14\n\x12HealthCheckRequest\"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse\"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -56,22 +56,6 @@ _globals['_HEALTHCHECKREQUEST']._serialized_end=793 _globals['_HEALTHCHECKRESPONSE']._serialized_start=795 _globals['_HEALTHCHECKRESPONSE']._serialized_end=833 - _globals['_LISTADAPTERSREQUEST']._serialized_start=835 - _globals['_LISTADAPTERSREQUEST']._serialized_end=856 - _globals['_LISTADAPTERSRESPONSE']._serialized_start=858 - _globals['_LISTADAPTERSRESPONSE']._serialized_end=973 - _globals['_ADAPTERINFO']._serialized_start=976 - _globals['_ADAPTERINFO']._serialized_end=1106 - _globals['_LOADADAPTERREQUEST']._serialized_start=1108 - _globals['_LOADADAPTERREQUEST']._serialized_end=1170 - _globals['_LOADADAPTERRESPONSE']._serialized_start=1172 - _globals['_LOADADAPTERRESPONSE']._serialized_end=1233 - _globals['_UNLOADADAPTERREQUEST']._serialized_start=1235 - _globals['_UNLOADADAPTERREQUEST']._serialized_end=1277 - _globals['_UNLOADADAPTERRESPONSE']._serialized_start=1279 - _globals['_UNLOADADAPTERRESPONSE']._serialized_end=1342 - _globals['_ORCHESTRATOR']._serialized_start=1345 - _globals['_ORCHESTRATOR']._serialized_end=1530 - _globals['_MULTIADAPTERMANAGER']._serialized_start=1533 - _globals['_MULTIADAPTERMANAGER']._serialized_end=1839 + _globals['_ORCHESTRATOR']._serialized_start=836 + _globals['_ORCHESTRATOR']._serialized_end=1021 # @@protoc_insertion_point(module_scope) diff --git a/jetstream/core/proto/jetstream_pb2_grpc.py b/jetstream/core/proto/jetstream_pb2_grpc.py index 5de13a1e..9d98a982 100644 --- a/jetstream/core/proto/jetstream_pb2_grpc.py +++ b/jetstream/core/proto/jetstream_pb2_grpc.py @@ -136,74 +136,3 @@ def HealthCheck( metadata, ) - -class MultiAdapterManagerStub(object): - """MultiAdapterManagerStub.""" - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. 
- """ - self.ListAdapters = channel.unary_unary( - '/jetstream_proto.MultiAdapterManager/ListAdapters', - request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersRequest.SerializeToString, - response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersResponse.FromString, - _registered_method=True) - self.LoadAdapter = channel.unary_unary( - '/jetstream_proto.MultiAdapterManager/LoadAdapter', - request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterRequest.SerializeToString, - response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterResponse.FromString, - _registered_method=True) - self.UnloadAdapter = channel.unary_unary( - '/jetstream_proto.MultiAdapterManager/UnloadAdapter', - request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterRequest.SerializeToString, - response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterResponse.FromString, - _registered_method=True) - - -class MultiAdapterManagerServicer(object): - """TODO: Merge this with main JetStream core once we settle on an API.""" - - def ListAdapters(self, request, context): - """Lists all the currently loaded LoRA adapters.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - def LoadAdapter(self, request, context): - """Check the feasibility and load the new LoRA adapter.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - def UnloadAdapter(self, request, context): - """Unload a LoRA adapter.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - -def add_MultiAdapterManagerServicer_to_server(servicer, server): - rpc_method_handlers = { - "ListAdapters": grpc.unary_unary_rpc_method_handler( - servicer.ListAdapters, - request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersRequest.FromString, - response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersResponse.SerializeToString, - ), - "LoadAdapter": grpc.unary_unary_rpc_method_handler( - servicer.LoadAdapter, - request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterRequest.FromString, - response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterResponse.SerializeToString, - ), - "UnloadAdapter": grpc.unary_unary_rpc_method_handler( - servicer.UnloadAdapter, - request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterRequest.FromString, - response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - "jetstream_proto.MultiAdapterManager", rpc_method_handlers - ) - server.add_generic_rpc_handlers((generic_handler,)) diff --git a/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py b/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py index d172d151..c6071fbf 100644 --- a/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py +++ b/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py @@ -3,7 +3,7 @@ import grpc import warnings -from jetstream.core.proto import multi_lora_decoding_pb2 as multi__lora__decoding__pb2 +import multi_lora_decoding_pb2 as multi__lora__decoding__pb2 GRPC_GENERATED_VERSION = 
'1.70.0' GRPC_VERSION = grpc.__version__ diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 1cf91ea2..6fa364af 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -135,7 +135,6 @@ def create_driver( generate_params = [ge.load_params() for ge in engines.generate_engines] shared_params = [ie.load_params() for ie in engines.interleaved_engines] logging.info("Loaded all weights.") - interleaved_mode = ( len(config.prefill_slices) + len(config.generate_slices) == 0 ) From a41e4cd8ab36112aa9eb7d9280b5458883bec1f6 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 6 Mar 2025 07:34:39 +0000 Subject: [PATCH 10/22] Refactor part-3. --- jetstream/core/orchestrator.py | 6 - jetstream/core/proto/jetstream_pb2.py | 8 +- jetstream/core/proto/jetstream_pb2_grpc.py | 194 ++++++++---------- .../core/proto/multi_lora_decoding_pb2.py | 7 +- .../proto/multi_lora_decoding_pb2_grpc.py | 95 ++------- 5 files changed, 111 insertions(+), 199 deletions(-) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index ed513d53..0de15320 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -580,7 +580,6 @@ def _prefill_thread(self, idx: int): if request is None: break - request.metadata.prefill_dequeue_time = time.perf_counter() is_bos = True logging.info( @@ -590,7 +589,6 @@ def _prefill_thread(self, idx: int): self._prefill_backlog.qsize(), is_bos, ) - # Tokenize and padding the text or token input. padded_tokens, true_length = self._process_prefill_content( request, tokenizer, is_bos, prefill_engine.max_prefill_length @@ -703,7 +701,6 @@ def _transfer_thread(self, idx: int): # Place the request on the correct generate backlog and block if full. new_request.metadata.generate_enqueue_time = time.perf_counter() self._generate_backlogs[target_idx].put(new_request, block=True) - logging.info( "Successfully transferred prefill " "from prefill engine %d to generate engine %d " @@ -727,7 +724,6 @@ def _generate_thread(self, idx: int): decode_state = generate_engine.init_decode_state() generate_params = self._generate_params[idx] - logging.info("---------Generate params %d loaded.---------", idx) time_of_last_generate = time.time() time_of_last_print = time.time() @@ -841,7 +837,6 @@ def _generate_thread(self, idx: int): generate_params, decode_state ) sampled_tokens.copy_to_host_async() - # Respond to detokenization backpressure. my_detokenize_backlog.put((generate_timestep, sampled_tokens), block=True) generate_timestep += 1 @@ -1135,7 +1130,6 @@ async def Decode( # pylint: disable=invalid-overridden-method prefill_content, is_client_side_tokenization = self._get_prefill_content( request ) - # Wrap request as an ActiveRequest. active_request = ActiveRequest( max_tokens=request.max_tokens, diff --git a/jetstream/core/proto/jetstream_pb2.py b/jetstream/core/proto/jetstream_pb2.py index 4fdb3dd6..a26eb2b9 100644 --- a/jetstream/core/proto/jetstream_pb2.py +++ b/jetstream/core/proto/jetstream_pb2.py @@ -11,12 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
-# NO CHECKED-IN PROTOBUF GENCODE # source: jetstream.proto -# Protobuf Python Version: 5.29.0 +# Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -34,8 +32,8 @@ _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'jetstream_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None _globals['_DECODEREQUEST']._serialized_start=37 _globals['_DECODEREQUEST']._serialized_end=437 _globals['_DECODEREQUEST_TEXTCONTENT']._serialized_start=293 diff --git a/jetstream/core/proto/jetstream_pb2_grpc.py b/jetstream/core/proto/jetstream_pb2_grpc.py index 9d98a982..9e158fdb 100644 --- a/jetstream/core/proto/jetstream_pb2_grpc.py +++ b/jetstream/core/proto/jetstream_pb2_grpc.py @@ -15,124 +15,106 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc -from jetstream.core.proto import jetstream_pb2 as jetstream_dot_core_dot_proto_dot_jetstream__pb2 +from jetstream.core.proto import jetstream_pb2 as jetstream__pb2 class OrchestratorStub(object): - """TODO: Merge this with main JetStream core once we settle on an API.""" + """TODO: Merge this with main JetStream core once we settle on an API. - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. """ - self.Decode = channel.unary_stream( - "/jetstream_proto.Orchestrator/Decode", - request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString, - response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString, - ) - self.HealthCheck = channel.unary_unary( - "/jetstream_proto.Orchestrator/HealthCheck", - request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString, - response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString, - ) + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Decode = channel.unary_stream( + '/jetstream_proto.Orchestrator/Decode', + request_serializer=jetstream__pb2.DecodeRequest.SerializeToString, + response_deserializer=jetstream__pb2.DecodeResponse.FromString, + ) + self.HealthCheck = channel.unary_unary( + '/jetstream_proto.Orchestrator/HealthCheck', + request_serializer=jetstream__pb2.HealthCheckRequest.SerializeToString, + response_deserializer=jetstream__pb2.HealthCheckResponse.FromString, + ) class OrchestratorServicer(object): - """TODO: Merge this with main JetStream core once we settle on an API.""" + """TODO: Merge this with main JetStream core once we settle on an API. - def Decode(self, request, context): - """Query LLM to generate text or tokens.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + """ - def HealthCheck(self, request, context): - """Checks if the model server is live.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + def Decode(self, request, context): + """Query LLM to generate text or tokens. 
+ """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def HealthCheck(self, request, context): + """Checks if the model server is live. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def add_OrchestratorServicer_to_server(servicer, server): - rpc_method_handlers = { - "Decode": grpc.unary_stream_rpc_method_handler( - servicer.Decode, - request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.FromString, - response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.SerializeToString, - ), - "HealthCheck": grpc.unary_unary_rpc_method_handler( - servicer.HealthCheck, - request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.FromString, - response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - "jetstream_proto.Orchestrator", rpc_method_handlers - ) - server.add_generic_rpc_handlers((generic_handler,)) - - -# This class is part of an EXPERIMENTAL API. + rpc_method_handlers = { + 'Decode': grpc.unary_stream_rpc_method_handler( + servicer.Decode, + request_deserializer=jetstream__pb2.DecodeRequest.FromString, + response_serializer=jetstream__pb2.DecodeResponse.SerializeToString, + ), + 'HealthCheck': grpc.unary_unary_rpc_method_handler( + servicer.HealthCheck, + request_deserializer=jetstream__pb2.HealthCheckRequest.FromString, + response_serializer=jetstream__pb2.HealthCheckResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'jetstream_proto.Orchestrator', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. class Orchestrator(object): - """TODO: Merge this with main JetStream core once we settle on an API.""" - - @staticmethod - def Decode( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_stream( - request, - target, - "/jetstream_proto.Orchestrator/Decode", - jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString, - jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) - - @staticmethod - def HealthCheck( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, - target, - "/jetstream_proto.Orchestrator/HealthCheck", - jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString, - jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) + """TODO: Merge this with main JetStream core once we settle on an API. 
+ + """ + @staticmethod + def Decode(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/jetstream_proto.Orchestrator/Decode', + jetstream__pb2.DecodeRequest.SerializeToString, + jetstream__pb2.DecodeResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def HealthCheck(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/jetstream_proto.Orchestrator/HealthCheck', + jetstream__pb2.HealthCheckRequest.SerializeToString, + jetstream__pb2.HealthCheckResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/jetstream/core/proto/multi_lora_decoding_pb2.py b/jetstream/core/proto/multi_lora_decoding_pb2.py index d53e10aa..4ad06c83 100644 --- a/jetstream/core/proto/multi_lora_decoding_pb2.py +++ b/jetstream/core/proto/multi_lora_decoding_pb2.py @@ -1,8 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! -# NO CHECKED-IN PROTOBUF GENCODE # source: multi_lora_decoding.proto -# Protobuf Python Version: 5.29.0 +# Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -20,8 +19,8 @@ _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'multi_lora_decoding_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None _globals['_COMPLETIONREQUEST']._serialized_start=30 _globals['_COMPLETIONREQUEST']._serialized_end=398 _globals['_COMPLETIONREQUEST_TEXTCONTENT']._serialized_start=254 diff --git a/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py b/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py index c6071fbf..495714de 100644 --- a/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py +++ b/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py @@ -1,28 +1,8 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc -import warnings -import multi_lora_decoding_pb2 as multi__lora__decoding__pb2 - -GRPC_GENERATED_VERSION = '1.70.0' -GRPC_VERSION = grpc.__version__ -_version_not_supported = False - -try: - from grpc._utilities import first_version_is_lower - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) -except ImportError: - _version_not_supported = True - -if _version_not_supported: - raise RuntimeError( - f'The grpc package installed is at version {GRPC_VERSION},' - + f' but the generated code in multi_lora_decoding_pb2_grpc.py depends on' - + f' grpcio>={GRPC_GENERATED_VERSION}.' - + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' - + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' 
- ) +from jetstream.core.proto import multi_lora_decoding_pb2 as multi__lora__decoding__pb2 class v1Stub(object): @@ -38,22 +18,22 @@ def __init__(self, channel): '/v1/completions', request_serializer=multi__lora__decoding__pb2.CompletionRequest.SerializeToString, response_deserializer=multi__lora__decoding__pb2.CompletionResponse.FromString, - _registered_method=True) + ) self.models = channel.unary_unary( '/v1/models', request_serializer=multi__lora__decoding__pb2.ListAdaptersRequest.SerializeToString, response_deserializer=multi__lora__decoding__pb2.ListAdaptersResponse.FromString, - _registered_method=True) + ) self.load_lora_adapter = channel.unary_unary( '/v1/load_lora_adapter', request_serializer=multi__lora__decoding__pb2.LoadAdapterRequest.SerializeToString, response_deserializer=multi__lora__decoding__pb2.LoadAdapterResponse.FromString, - _registered_method=True) + ) self.unload_lora_adapter = channel.unary_unary( '/v1/unload_lora_adapter', request_serializer=multi__lora__decoding__pb2.UnloadAdapterRequest.SerializeToString, response_deserializer=multi__lora__decoding__pb2.UnloadAdapterResponse.FromString, - _registered_method=True) + ) class v1Servicer(object): @@ -114,7 +94,6 @@ def add_v1Servicer_to_server(servicer, server): generic_handler = grpc.method_handlers_generic_handler( 'v1', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - server.add_registered_method_handlers('v1', rpc_method_handlers) # This class is part of an EXPERIMENTAL API. @@ -132,21 +111,11 @@ def completions(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_stream( - request, - target, - '/v1/completions', + return grpc.experimental.unary_stream(request, target, '/v1/completions', multi__lora__decoding__pb2.CompletionRequest.SerializeToString, multi__lora__decoding__pb2.CompletionResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def models(request, @@ -159,21 +128,11 @@ def models(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/v1/models', + return grpc.experimental.unary_unary(request, target, '/v1/models', multi__lora__decoding__pb2.ListAdaptersRequest.SerializeToString, multi__lora__decoding__pb2.ListAdaptersResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def load_lora_adapter(request, @@ -186,21 +145,11 @@ def load_lora_adapter(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/v1/load_lora_adapter', + return grpc.experimental.unary_unary(request, target, '/v1/load_lora_adapter', multi__lora__decoding__pb2.LoadAdapterRequest.SerializeToString, multi__lora__decoding__pb2.LoadAdapterResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def unload_lora_adapter(request, @@ 
-213,18 +162,8 @@ def unload_lora_adapter(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/v1/unload_lora_adapter', + return grpc.experimental.unary_unary(request, target, '/v1/unload_lora_adapter', multi__lora__decoding__pb2.UnloadAdapterRequest.SerializeToString, multi__lora__decoding__pb2.UnloadAdapterResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) From febaed17a56acd9deae35f159ee0226179f692b7 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 6 Mar 2025 22:15:05 +0000 Subject: [PATCH 11/22] 1) Adding more comments at applying LoRA on Prefill params path. 2) Fixing model_ckpt_conversion.sh after refactoring and merging from main. --- jetstream/core/orchestrator.py | 12 +++++ .../tools/maxtext/model_ckpt_conversion.sh | 45 ++++++++++++++----- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 36190f69..dad41aa5 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -666,6 +666,18 @@ def _prefill_thread(self, idx: int): adapter_id = request.adapter_id + # As prefill is happening one prompt at a time, for each prefill, we are + # applying the LoRA params on base to create a copy of params (equivalent + # to the size of base params) and use that for generating kv-cache. This + # copy is called the final_prefill_params, which is deleted soon after the + # generation of kv-cache. + # We can have memory-optimizations by updating the original copy of + # base params at the cost of extra computations to revert it back to original + # base params after kv-cache computation of each prompt, so that it can + # be used by the next prompt. But this optimization could also be tricky + # because as of now same params are being shared by prefill and generate, + # where generate always expect the base_params. So some race conditions need + # to be avoided. final_prefill_params = None if adapter_id == "": final_prefill_params = prefill_params diff --git a/jetstream/tools/maxtext/model_ckpt_conversion.sh b/jetstream/tools/maxtext/model_ckpt_conversion.sh index a18691fa..adc7c00c 100644 --- a/jetstream/tools/maxtext/model_ckpt_conversion.sh +++ b/jetstream/tools/maxtext/model_ckpt_conversion.sh @@ -38,10 +38,17 @@ export MODEL_BUCKET=$4 # Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created, this bucket will store all the files generated by MaxText during a run, specifically the unscanned checkpoint. export BASE_OUTPUT_DIRECTORY=$5 -export LORA_LOCAL_PATH=$6 +export HUGGING_FACE_CHECKPOINT=$6 + +export LORA_INPUT_ADAPTERS_PATH=$7 export BUCKET_LOCATION=US +if [[ -z "HUGGING_FACE_CHECKPOINT" ]]; then + echo "HUGGING_FACE_CHECKPOINT is required." + exit 1 +fi + # Create three GCS buckets for the demo. gcloud storage buckets create ${MODEL_BUCKET} --location=${BUCKET_LOCATION} || true gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY} --location=${BUCKET_LOCATION} || true @@ -59,40 +66,56 @@ else # llama_or_mistral_ckpt.py requires local path, so we need to copy the checkpoint from CHKPT_BUCKET to local. 
tmp_ckpt_path="/tmp/" gcloud storage cp -r ${CHKPT_BUCKET} ${tmp_ckpt_path} + path_parts=(${CHKPT_BUCKET//\// }) directory_substring=${path_parts[-1]} CONVERT_CKPT_SCRIPT="llama_or_mistral_ckpt.py" - if [[ -x "${LORA_LOCAL_PATH}" ]]; then + + if [[ ! -z "${LORA_INPUT_ADAPTERS_PATH}" ]]; then + lora_local_path="/tmp/" + + if [[ "${LORA_INPUT_ADAPTERS_PATH}" =~ ^gs:// ]]; then + path_parts=(${LORA_INPUT_ADAPTERS_PATH//\// }) + lora_dir_substring=${path_parts[-1]} + + lora_local_path="${tmp_ckpt_path}${lora_dir_substring}" + if [[ ! -d ${lora_local_path} ]]; then + mkdir ${lora_local_path} + fi + gcloud storage cp -r ${LORA_INPUT_ADAPTERS_PATH} ${tmp_ckpt_path} + else + lora_local_path=${LORA_INPUT_ADAPTERS_PATH} + fi + JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \ --base-model-path ${tmp_ckpt_path}${directory_substring} \ --maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \ --model-size ${MODEL_NAME} \ - --lora-config-path ${LORA_LOCAL_PATH}/adapter_config.json \ - --lora-model-path ${LORA_LOCAL_PATH}/adapter_model.bin + --lora-input-adapters-path ${lora_local_path} \ + --huggingface-checkpoint ${HUGGING_FACE_CHECKPOINT} else JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \ --base-model-path ${tmp_ckpt_path}${directory_substring} \ --maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \ - --model-size ${MODEL_NAME} + --model-size ${MODEL_NAME} \ + --huggingface-checkpoint ${HUGGING_FACE_CHECKPOINT} fi fi echo "Written MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}" # We define `SCANNED_CKPT_PATH` to refer to the checkpoint subdirectory. -# export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} # Convert MaxText compatible checkpoints to unscanned checkpoints. # Note that the `SCANNED_CKPT_PATH` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. export RUN_NAME=${MODEL_NAME}_unscanned_chkpt_${idx} -if [[ -x "${LORA_LOCAL_PATH}" ]]; then +if [[ ! -z "${LORA_INPUT_ADAPTERS_PATH}" ]]; then JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \ MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ - load_parameters_path=${SCANNED_CKPT_PATH}/base_weights/0/items \ - lora_parameters_base_path=${SCANNED_CKPT_PATH}/lora_weights/0/items \ - lora_config_path=${LORA_LOCAL_PATH}/adapter_config.json \ + load_parameters_path=${SCANNED_CKPT_PATH}/base/0/items \ + lora_input_adapters_path=${SCANNED_CKPT_PATH}/LoRAs \ run_name=${RUN_NAME} \ model_name=${MODEL_NAME} \ force_unroll=true @@ -101,7 +124,7 @@ else JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \ MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ - load_parameters_path=${SCANNED_CKPT_PATH}/base_weights/0/items \ + load_parameters_path=${SCANNED_CKPT_PATH}/0/items \ run_name=${RUN_NAME} \ model_name=${MODEL_NAME} \ force_unroll=true From ed66fdf06cb2064d118d0d8ebb76bf700616b3c2 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 6 Mar 2025 22:34:10 +0000 Subject: [PATCH 12/22] Fixing TypeCheck errors. 
--- jetstream/core/lora/adapter_tensorstore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream/core/lora/adapter_tensorstore.py b/jetstream/core/lora/adapter_tensorstore.py index 306d13d2..445fe8c9 100644 --- a/jetstream/core/lora/adapter_tensorstore.py +++ b/jetstream/core/lora/adapter_tensorstore.py @@ -62,7 +62,7 @@ class AdapterMetadata: size_hbm: int = 0 # Size in HBM (bytes) size_cpu: int = 0 # Size in CPU RAM (bytes) last_accessed: float = 0.0 # timestamp - config: Dict[str, Any] = None + config: Dict[str, Any] = {} class AdapterTensorStore: From a6a5cd1f02bdaf0040f6cecf393c46e504dabbf2 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 7 Mar 2025 06:15:39 +0000 Subject: [PATCH 13/22] Fixing linting error. --- jetstream/core/lora/adapter_tensorstore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream/core/lora/adapter_tensorstore.py b/jetstream/core/lora/adapter_tensorstore.py index 445fe8c9..ecbbeb93 100644 --- a/jetstream/core/lora/adapter_tensorstore.py +++ b/jetstream/core/lora/adapter_tensorstore.py @@ -62,7 +62,7 @@ class AdapterMetadata: size_hbm: int = 0 # Size in HBM (bytes) size_cpu: int = 0 # Size in CPU RAM (bytes) last_accessed: float = 0.0 # timestamp - config: Dict[str, Any] = {} + config: Dict[str, Any] = dataclasses.field(default_factory=dict) class AdapterTensorStore: From e4d22bf4193e4b9ae1dcfce7222cea34535f7f97 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 7 Mar 2025 16:46:20 +0000 Subject: [PATCH 14/22] JetStream changes for Jax based implementation of unified_lora_params for decoding batch of multiple different lora adapters. --- jetstream/core/lora/adapter_tensorstore.py | 74 +++++++++++++++++++++- jetstream/core/orchestrator.py | 18 ++++-- jetstream/engine/engine_api.py | 1 + 3 files changed, 87 insertions(+), 6 deletions(-) diff --git a/jetstream/core/lora/adapter_tensorstore.py b/jetstream/core/lora/adapter_tensorstore.py index ecbbeb93..75930aea 100644 --- a/jetstream/core/lora/adapter_tensorstore.py +++ b/jetstream/core/lora/adapter_tensorstore.py @@ -66,7 +66,7 @@ class AdapterMetadata: class AdapterTensorStore: - def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int): + def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int, total_slots: int): self.hbm_memory_budget = hbm_memory_budget self.cpu_memory_budget = cpu_memory_budget self.adapter_registry: Dict[str, AdapterMetadata] = {} # All known adapters @@ -75,6 +75,8 @@ def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int): self.current_hbm_usage: int = 0 self.current_cpu_usage: int = 0 self.running_requests: int = 0 # Number of async tasks which are in "loading" state + self.decoding_adapters_cache: Dict[str, Any] = {} + self.total_slots = total_slots self.lock = asyncio.Lock() # Use an asyncio Lock for thread safety @@ -145,6 +147,76 @@ async def _transfer_to_cpu(self, adapter_id: str): metadata.last_accessed = time.time() + def _initialize_decoding_adapters_cache(self, adapter_weights): + """ + Create a new PyTree with zero tensors at the paths corresponding to non-None leaves + in the input PyTree. The zero tensors have an added dimension of size `self.totol_slots`. + + Args: + adatper_weights: The input PyTree, whose structure will be mirrored. + + Returns: + A new PyTree with zero Tensors or None values, mirroring the structure of the input PyTree. 
+ """ + def create_zero_leaf(leaf): + if leaf is not None: + original_shape = leaf.shape + if not original_shape: # handle scalar case + zero_tensor_shape = (self.total_slots,) + else: + zero_tensor_shape = (self.total_slots,) + original_shape # Prepend a new dimension + + return jnp.zeros(zero_tensor_shape, dtype=leaf.dtype) + else: + return None # Maintain None structure for None leaves + + return jax.tree_util.tree_map(create_zero_leaf, adapter_weights) + + + def insert_adapter_in_cache(self, adapter_id: str, slot_id: int): + """ + Insert the specific adapter tensors into a slot in the serving_adapters_cache. + + Args: + adapter_id: The id of the adapter, whose tensors will be inserted + slot_id: The id of slot, which represents the index in the serving_adapter_cache + where the adapter tensors will be inserted. + """ + + def insert_leaf(dest_leaf, source_leaf): + if dest_leaf is not None and source_leaf is not None: + return dest_leaf.at[slot_id].set(source_leaf) # Insert at the specific index + elif dest_leaf is not None: + return dest_leaf # If source_leaf is None, keep the zero_leaf as is + elif source_leaf is not None: # In this case the adapters have different target modules + original_shape = source_leaf.shape + if not original_shape: # Handle scalar case + zero_tensor_shape = (self.total_slots,) + else: + zero_tensor_shape = (self.total_slots,) + original_shape + new_dest_leaf = jnp.zeros(zero_tensor_shape, dtype=source_leaf.dtype) + return new_dest_leaf.at[slot_id].set(source_leaf) + else: + return None # If both are None, return None + + if adapter_id == "": + logging.info("Empty adapter id. So no LoRA tensors inserted into the cache in adapter_tensorStore.") + return + + metadata = self.adapter_registry[adapter_id] + + asyncio.run(self.load_adapter(adapter_id, True)) + + adapter_weights = self.loaded_adapters_hbm[adapter_id] + + if not self.decoding_adapters_cache: + self.decoding_adapters_cache = self._initialize_decoding_adapters_cache(adapter_weights) + + self.decoding_adapters_cache = jax.tree_util.tree_map(insert_leaf, + self.decoding_adapters_cache, + adapter_weights) + + async def get_hbm_loaded_adapters(self): """Returns a comma separated list of adapters loaded into HBM.""" diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index dad41aa5..fa5d9460 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -281,10 +281,6 @@ def __init__( if generate_params is None: raise ValueError("No generate parameter provided.") - self._adapter_tensorstore = adapter_tensorstore.AdapterTensorStore( - hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM - cpu_memory_budget=(100 * (1024 ** 3))) # 100 GB RAM - logger.info( "Initializing the driver with %d prefill engines and %d " "generate engines in %s mode", @@ -301,6 +297,15 @@ def __init__( self._metrics_collector = metrics_collector self._multi_sampling = multi_sampling + total_slots = 0 + for engine in self._generate_engines: + total_slots += engine.max_concurrent_decodes + + self._adapter_tensorstore = adapter_tensorstore.AdapterTensorStore( + hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM + cpu_memory_budget=(100 * (1024 ** 3)), # 100 GB RAM + total_slots=total_slots) + # Stages 1-4 represent the life cycle of a request. # Stage 1 # At first, a request is placed here in order to get prefilled. 
@@ -930,6 +935,9 @@ def _insert_if_possible( slot=slot, #request_id=new_request.request_id, ) + + self._adapter_tensorstore.insert_adapter_in_cache(new_request.adapter_id, slot) + ThreadDebugLog( thread_name, f"Generate slice {idx} filled slot {slot} at step " @@ -1136,7 +1144,7 @@ def _generate_thread(self, idx: int): # Now we actually take a generate step on requests in the slots. decode_state, sampled_tokens = generate_engine.generate( - generate_params, decode_state + generate_params, decode_state, self._adapter_tensorstore.decoding_adapters_cache, ) sampled_tokens.copy_to_host_async() # Respond to detokenization backpressure. diff --git a/jetstream/engine/engine_api.py b/jetstream/engine/engine_api.py index 42d1e172..ea64a99d 100644 --- a/jetstream/engine/engine_api.py +++ b/jetstream/engine/engine_api.py @@ -197,6 +197,7 @@ def generate( params: Params, decode_state: DecodeState, sampler: Optional[Callable[[Any], Any]] = None, + lora_params: Params = None, ) -> Tuple[DecodeState, ResultTokens]: """Generates tokens for each sequence being decoded in parallel. From bd67171a559c646cbad2c7ac87523ba22b711447 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 7 Mar 2025 23:18:02 +0000 Subject: [PATCH 15/22] Adding documentations. --- jetstream/core/lora/adapter_tensorstore.py | 62 ++++++++++++++++--- .../core/lora/multi_lora_inference_api.py | 2 +- .../tools/multi_adapter_service_client.py | 25 +++++++- 3 files changed, 79 insertions(+), 10 deletions(-) diff --git a/jetstream/core/lora/adapter_tensorstore.py b/jetstream/core/lora/adapter_tensorstore.py index ecbbeb93..571bf0e4 100644 --- a/jetstream/core/lora/adapter_tensorstore.py +++ b/jetstream/core/lora/adapter_tensorstore.py @@ -66,12 +66,29 @@ class AdapterMetadata: class AdapterTensorStore: + """ + Manages the storage and retrieval of LoRA adapter weights, handling + placement in either HBM (High Bandwidth Memory, on the TPU/GPU) or CPU RAM. + + This class implements an LRU (Least Recently Used) eviction policy + to manage memory usage. It supports asynchronous loading and unloading + of adapters to avoid blocking the main inference thread. + + Args: + hbm_memory_budget (int): The maximum amount of HBM (in bytes) to use for + storing LoRA adapter weights. + cpu_memory_budget (int): The maximum amount of CPU RAM (in bytes) to use + for storing LoRA adapter weights. + """ + + def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int): + """Initializes the AdapterTensorStore.""" self.hbm_memory_budget = hbm_memory_budget self.cpu_memory_budget = cpu_memory_budget self.adapter_registry: Dict[str, AdapterMetadata] = {} # All known adapters - self.loaded_adapters_hbm: Dict[str, jnp.ndarray] = {} # adapter_id -> Unified LoRA params (in HBM) - self.loaded_adapters_cpu: Dict[str, np.ndarray] = {} # adapter_id -> Unified LoRA params (in CPU RAM) + self.loaded_adapters_hbm: Dict[str, jnp.ndarray] = {} # adapter_id -> LoRA params (in HBM) + self.loaded_adapters_cpu: Dict[str, np.ndarray] = {} # adapter_id -> LoRA params (in CPU RAM) self.current_hbm_usage: int = 0 self.current_cpu_usage: int = 0 self.running_requests: int = 0 # Number of async tasks which are in "loading" state @@ -80,6 +97,18 @@ def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int): def register_adapter(self, adapter_id: str, adapter_path: str, config: Dict[str, Any]): """Registers a new LoRA adatper.""" + """ + Registers a LoRA adapter with the TensorStore. This does *not* load + the adapter; it simply adds metadata about the adapter to the registry. 
+ + Args: + adapter_id (str): A unique identifier for the adapter. + adapter_path (str): The path to the adapter weights (file or directory). + config (dict): Config of the loRA adapter. + + Raises: + ValueError: If an adapter with the same ID is already registered. + """ if adapter_id in self.adapter_registry: raise ValueError(f"Adapter with ID '{adapter_id}' already registered.") self.adapter_registry[adapter_id] = AdapterMetadata( @@ -162,9 +191,28 @@ async def load_adapter( self, adapter_id: str, adapter_weights = None, - to_hbm: bool = True, - force_load: bool = False): - """Loads a LoRA adapter's weights, managing HBM and CPU memory.""" + to_hbm: bool = True): + """ + Loads a LoRA adapter's weights into memory (either HBM or CPU RAM). + + This method is asynchronous to avoid blocking the main thread during + potentially slow I/O operations. It handles: + - Checking if the adapter is already loaded. + - Checking if there's enough memory (and evicting if necessary). + - Loading the weights (in a separate thread). + - Updating the adapter's status and metadata. + + Args: + adapter_id (str): The ID of the adapter to load. + adapter_weights: In the form of a PyTree. + to_hbm (bool): Whether to load the adapter into HBM (True) or + CPU RAM (False). Defaults to True (HBM). + + Raises: + ValueError: If the adapter ID is not registered. + RuntimeError: If there is not enough memory to load the adapter, + and eviction fails to free up enough space. + """ if adapter_id not in self.adapter_registry: raise ValueError(f"Adapter with ID '{adapter_id}' not registered.") @@ -172,7 +220,7 @@ async def load_adapter( metadata = self.adapter_registry[adapter_id] async with self.lock: # Acquire lock for thread safety - if not force_load and metadata.status in ("loaded_hbm", "loaded_cpu"): + if metadata.status in ("loaded_hbm", "loaded_cpu"): metadata.last_accessed = time.time() # if already loaded in HBM and we want HBM, or @@ -195,7 +243,7 @@ async def load_adapter( await asyncio.sleep(0.1) # Short sleep to avoid busy-waiting # Make recursive call to load_adapter to copy to device - await self.load_adapter(adapter_id, adapter_weights, to_hbm, force_load) + await self.load_adapter(adapter_id, adapter_weights, to_hbm) return metadata.status = "loading" diff --git a/jetstream/core/lora/multi_lora_inference_api.py b/jetstream/core/lora/multi_lora_inference_api.py index 9edc1ad1..60dfa74a 100644 --- a/jetstream/core/lora/multi_lora_inference_api.py +++ b/jetstream/core/lora/multi_lora_inference_api.py @@ -31,7 +31,7 @@ class MultiLoraManager(multi_lora_decoding_pb2_grpc.v1Servicer): - """Manages the parameters of multiple lora requests and their lifelines.""" + """Manages the parameters of multiple lora requests and their status/lifetimes.""" _driver: orchestrator.Driver diff --git a/jetstream/tools/multi_adapter_service_client.py b/jetstream/tools/multi_adapter_service_client.py index d14617c9..547078ee 100644 --- a/jetstream/tools/multi_adapter_service_client.py +++ b/jetstream/tools/multi_adapter_service_client.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""A test request.""" +"""A gRPC client to interact with JetStream Server.""" from typing import Sequence @@ -55,7 +55,28 @@ def main(argv: Sequence[str]) -> None: - del argv + """ + Main function for a gRPC client that interacts with a JetStream server. + + This client can: + - Load a LoRA adapter. + - Unload a LoRA adapter. + - List loaded adapters and their metadata. 
+ - Generate text completions (using LoRA adapters if specified). + + The client uses command-line flags to specify the server address, port, + text input, maximum number of tokens, adapter ID, adapter path, and the + API to call. It uses insecure gRPC channels (suitable for local testing). + + Args: + argv: Command-line arguments (not used directly, flags are used instead). + + Raises: + ValueError: For invalid configurations, like missing required parameters + for specific API calls. + """ + + del argv # Unused # Note: Uses insecure_channel only for local testing. Please add grpc # credentials for Production. address = f"{_SERVER.value}:{_PORT.value}" From 5f679a9e41558cfdff3a9f3297ce4ea37d2a091a Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 13 Mar 2025 18:53:20 +0000 Subject: [PATCH 16/22] - Created separate adapter_tensorstore for each engine. - Implemented unapply lora from base_params - Fixed some comments from the PR --- jetstream/core/lora/adapter_tensorstore.py | 70 +++++- .../core/lora/multi_lora_inference_api.py | 160 +------------- jetstream/core/orchestrator.py | 207 ++++++++++++------ jetstream/core/proto/jetstream.proto | 2 +- jetstream/core/proto/jetstream_pb2.py | 44 ++-- .../core/proto/multi_lora_decoding.proto | 64 ------ .../core/proto/multi_lora_decoding_pb2.py | 50 ++--- .../proto/multi_lora_decoding_pb2_grpc.py | 34 --- jetstream/core/server_lib.py | 42 +++- .../tools/maxtext/model_ckpt_conversion.sh | 4 +- jetstream/tools/requester.py | 8 +- 11 files changed, 292 insertions(+), 393 deletions(-) diff --git a/jetstream/core/lora/adapter_tensorstore.py b/jetstream/core/lora/adapter_tensorstore.py index 571bf0e4..0f86237f 100644 --- a/jetstream/core/lora/adapter_tensorstore.py +++ b/jetstream/core/lora/adapter_tensorstore.py @@ -26,6 +26,7 @@ import functools from typing import Dict, Optional, Any import numpy as np +from jetstream.engine import engine_api def _get_size_of_pytree(params): @@ -82,8 +83,14 @@ class AdapterTensorStore: """ - def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int): + def __init__(self, + engine: engine_api.Engine, + adapters_dir_path: str, + hbm_memory_budget: int, + cpu_memory_budget: int): """Initializes the AdapterTensorStore.""" + self.engine = engine # Possibly MaxEngine object + self.adapters_dir_path = adapters_dir_path.rstrip("/") # All Adapters path without trailing `/` self.hbm_memory_budget = hbm_memory_budget self.cpu_memory_budget = cpu_memory_budget self.adapter_registry: Dict[str, AdapterMetadata] = {} # All known adapters @@ -95,26 +102,49 @@ def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int): self.lock = asyncio.Lock() # Use an asyncio Lock for thread safety - def register_adapter(self, adapter_id: str, adapter_path: str, config: Dict[str, Any]): + def register_adapter(self, + adapter_id: str, + adapter_path: str = None, + adapter_config: Dict[str, Any] = None): """Registers a new LoRA adatper.""" """ - Registers a LoRA adapter with the TensorStore. This does *not* load - the adapter; it simply adds metadata about the adapter to the registry. + Registers a LoRA adapter with the TensorStore. This also loads the adapter; + IF called without adapter_config. Because in this case, it needs + to get adapter_config from the engine's load_single_adapter() call, which + also provides the adapter_params. So in that case it is beneficial to load + the adapter to HBM. This call path is expected only from the direct inference + request. 
+ OTHERWISE, it simply adds metadata about the adapter to the registry. Args: adapter_id (str): A unique identifier for the adapter. adapter_path (str): The path to the adapter weights (file or directory). - config (dict): Config of the loRA adapter. + adapter_config (dict): Config of the loRA adapter. Raises: ValueError: If an adapter with the same ID is already registered. """ if adapter_id in self.adapter_registry: - raise ValueError(f"Adapter with ID '{adapter_id}' already registered.") + logging.warning(f"Adapter with ID '{adapter_id}' already registered.") + return + + if adapter_path is None: + adapter_path = f"{self.adapters_dir_path}/{adapter_id}" + + adapter_params = None + if adapter_config is None: + adapter_params, adapter_config = self.engine.load_single_adapter(adapter_path) + + if adapter_config is None: + raise ValueError(f"Failed to read adapter_config from {adapter_path}") + self.adapter_registry[adapter_id] = AdapterMetadata( adapter_id=adapter_id, adapter_path=adapter_path, - config=config) + config=adapter_config) + + if adapter_params is not None: + asyncio.run(self.load_adapter(adapter_id, adapter_params, True)) async def _transfer_to_hbm(self, adapter_id: str): @@ -254,7 +284,10 @@ async def load_adapter( try: if adapter_weights is None: - raise ValueError("Adapter weights for adapter_id={adapter_id} is None.") + adapter_weights, adapter_config = self.engine.load_single_adapter(adapter_path) + + if adapter_weights is None: + raise ValueError("Failed to load adapter_weights from {adapter_path}.") async with self.lock: # Critical section for memory management adapter_weights_as_jnp_array = _as_jnp_array(adapter_weights) @@ -303,21 +336,36 @@ async def load_adapter( self.running_requests -= 1 - def get_lora_config(self, adapter_id): + def get_lora_config(self, adapter_id: str, load_if_not_loaded: bool = False): """Getter for the LoRA adapter config.""" metadata = self.adapter_registry.get(adapter_id) + + if load_if_not_loaded and metadata is None: + self.register_adapter(adapter_id) + metadata = self.adapter_registry.get(adapter_id) + + if metadata is None: + raise ValueError(f"LoRA adapter with id={adapter_id} is not loaded.") + return metadata.config - def get_lora_weights(self, adapter_id, to_hbm: bool = True): + def get_lora_weights(self, + adapter_id, + to_hbm: bool = True, + load_if_not_loaded: bool = False): """Retrieves the unified LoRA parameters for the given adapter IDs. Handles HBM/CPU placement. 
""" metadata = self.adapter_registry.get(adapter_id) + if load_if_not_loaded and metadata is None: + self.register_adapter(adapter_id) + metadata = self.adapter_registry.get(adapter_id) + if metadata is None: - raise ValueError(f"Adapter with ID '{adapter_id}' not registered.") + raise ValueError(f"LoRA adapter with id={adapter_id} is not loaded.") if metadata.status != "loaded_hbm" and metadata.status != "loaded_cpu": asyncio.run(self.load_adapter(adapter_id, None, to_hbm)) # Start loading (async) diff --git a/jetstream/core/lora/multi_lora_inference_api.py b/jetstream/core/lora/multi_lora_inference_api.py index 60dfa74a..f483c12a 100644 --- a/jetstream/core/lora/multi_lora_inference_api.py +++ b/jetstream/core/lora/multi_lora_inference_api.py @@ -17,17 +17,11 @@ import logging import grpc -import time -import uuid -from typing import Any, AsyncIterator, Optional, Tuple, cast +from typing import Optional from jetstream.core import orchestrator -from jetstream.core.lora import adapter_tensorstore from jetstream.core.proto import multi_lora_decoding_pb2_grpc from jetstream.core.proto import multi_lora_decoding_pb2 -from jetstream.core.utils import async_multifuture -from jetstream.core.utils.return_sample import ReturnSample -from jetstream.engine import engine_api, tokenizer_api, token_utils class MultiLoraManager(multi_lora_decoding_pb2_grpc.v1Servicer): @@ -105,155 +99,3 @@ def unload_lora_adapter( logging.info(f"Loading of adapter_id={request.adapter_id} failed with error: {str(e)}") return multi_lora_decoding_pb2.UnloadAdapterResponse(success=False, error_message=str(e)) - - def _get_prefill_content( - self, request: multi_lora_decoding_pb2.CompletionRequest - ) -> Tuple[str | list[int], bool]: - which_content = request.WhichOneof("content") - content = getattr(request, which_content) - if which_content == "text_content": - return cast(multi_lora_decoding_pb2.CompletionRequest.TextContent, content).text, False - else: - return ( - list( - cast(multi_lora_decoding_pb2.CompletionRequest.TokenContent, content).token_ids - ), - True, - ) - - def process_client_side_tokenization_response(self, response: Any): - samples = [] - for sample in response: - samples.append( - multi_lora_decoding_pb2.CompletionResponse.StreamContent.Sample( - token_ids=sample.token_ids, - ) - ) - return multi_lora_decoding_pb2.CompletionResponse( - stream_content=multi_lora_decoding_pb2.CompletionResponse.StreamContent( - samples=samples - ) - ) - - def should_buffer_response(self, response: Any) -> bool: - for item in response: - if item.text and token_utils.is_byte_token(item.text[-1]): - # If any sample ends in bytes, this means we might still need to - # decode more bytes to compose the string. - return True - - def process_server_side_tokenization_response( - self, response: Any, buffered_response_list - ): - # Flush the buffered responses to each sample of current response. - current_response_with_flushed_buffer = list( - zip(*buffered_response_list, response) - ) - # Empty buffer: [[s0_cur], [s1_cur], ...] - # Has buffer: - # [[s0_b0, s0_b1, ..., s0_cur], [s1_b0, s1_b1, ..., s1_cur], ...] - current_response_with_flushed_buffer = cast( - list[list[ReturnSample]], current_response_with_flushed_buffer - ) - # Form correct sample(s) and return as StreamContent for this iteration. 
- samples = [] - for sample in current_response_with_flushed_buffer: - text = [] - token_ids = [] - for resp in sample: - text.extend(resp.text) - token_ids.extend(resp.token_ids) - samples.append( - multi_lora_decoding_pb2.CompletionResponse.StreamContent.Sample( - text=token_utils.text_tokens_to_str(text), - token_ids=token_ids, - ) - ) - return multi_lora_decoding_pb2.CompletionResponse( - stream_content=multi_lora_decoding_pb2.CompletionResponse.StreamContent( - samples=samples - ) - ) - - async def completions( # pylint: disable=invalid-overridden-method - self, - request: multi_lora_decoding_pb2.CompletionRequest, - context: Optional[grpc.aio.ServicerContext] = None, - ) -> AsyncIterator[multi_lora_decoding_pb2.CompletionResponse]: - - """Decode.""" - if context is None: - logging.warning( - "LLM orchestrator is being used in offline test mode, and will not" - " respond to gRPC queries - only direct function calls." - ) - is_client_side_tokenization = False - return_channel = async_multifuture.AsyncMultifuture() - if context: - context.add_done_callback(return_channel.cancel) - - prefill_content, is_client_side_tokenization = self._get_prefill_content( - request - ) - - # Wrap request as an ActiveRequest. - active_request = orchestrator.ActiveRequest( - request_id=uuid.uuid4(), - max_tokens=request.max_tokens, - prefill_content=prefill_content, - is_client_side_tokenization=is_client_side_tokenization, - return_channel=return_channel, - adapter_id=request.adapter_id, - metadata=orchestrator.ActiveRequestMetadata( - start_time=request.metadata.start_time, - prefill_enqueue_time=time.perf_counter(), - ), - ) - # The first stage is being prefilled, all other stages are handled - # inside the driver (transfer, generate*N, detokenize). - try: - self._driver.place_request_on_prefill_queue(active_request) - except queue.Full: - # Safely abort the gRPC server thread with a retriable error. - await _abort_or_raise( - context=context, - code=grpc.StatusCode.RESOURCE_EXHAUSTED, - details=( - "The driver prefill queue is full and more requests cannot be" - " handled. You may retry this request." - ), - ) - logging.info( - "Placed request on the prefill queue.", - ) - # When an active request is created a queue is instantiated. New tokens - # are placed there during the decoding loop, we pop from that queue by - # using the .next method on the active request. - # Yielding allows for the response to be a streaming grpc call - which - # can be called via iterating over a for loop on the client side. - # The DecodeResponse stream should consume all generated tokens in - # return_channel when complete signal is received (AsyncMultifuture - # promises this). - buffered_response_list = [] - async for response in active_request.return_channel: - response = cast(list[ReturnSample], response) - if is_client_side_tokenization: - # If is_client_side_tokenization, the client should request with token - # ids, and the JetStream server will return token ids as response. - # The client should take care of tokenization and detokenization. - yield self.process_client_side_tokenization_response(response) - else: - # Buffer response mechanism is used to handle streaming - # detokenization with special character (For some edge cases with - # SentencePiece tokenizer, it requires to decode a complete sequence - # instead of a single token). 
- if self.should_buffer_response(response): - buffered_response_list.append(response) - continue - yield self.process_server_side_tokenization_response( - response, buffered_response_list - ) - # Reset buffer after flushed. - buffered_response_list = [] - - diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index dad41aa5..29b656be 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -257,8 +257,9 @@ class Driver: # All metrics we want to monitor should be collected with this _metrics_collector: JetstreamMetricsCollector | None = None - # An object to store and manage the adapters - _adapter_tensorstore: adapter_tensorstore.AdapterTensorStore | None = None + # An object to store and manage the adapters for each prefill and generate Engine + _prefill_adapter_tensorstore: list[adapter_tensorstore.AdapterTensorStore] + _generate_adapter_tensorstore: list[adapter_tensorstore.AdapterTensorStore] def __init__( self, @@ -266,6 +267,8 @@ def __init__( generate_engines: Optional[list[engine_api.Engine]] = None, prefill_params: Optional[list[Any]] = None, generate_params: Optional[list[Any]] = None, + prefill_adapter_tensorstore: Optional[list[adapter_tensorstore.AdapterTensorStore]] = None, + generate_adapter_tensorstore: Optional[list[adapter_tensorstore.AdapterTensorStore]] = None, interleaved_mode: bool = False, jax_padding: bool = True, metrics_collector: JetstreamMetricsCollector | None = None, @@ -281,9 +284,13 @@ def __init__( if generate_params is None: raise ValueError("No generate parameter provided.") - self._adapter_tensorstore = adapter_tensorstore.AdapterTensorStore( - hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM - cpu_memory_budget=(100 * (1024 ** 3))) # 100 GB RAM + self._prefill_adapter_tensorstore = prefill_adapter_tensorstore + self._generate_adapter_tensorstore = generate_adapter_tensorstore + +# # TODO: Make `hbm_memory_budget` & `cpu_memory_budget` configurable. +# self._adapter_tensorstore = adapter_tensorstore.AdapterTensorStore( +# hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM +# cpu_memory_budget=(100 * (1024 ** 3))) # 100 GB RAM logger.info( "Initializing the driver with %d prefill engines and %d " @@ -547,8 +554,9 @@ def _export_lora_request_info(self): if self._metrics_collector: for idx, engine in enumerate(self._generate_engines): max_loras += engine.max_concurrent_decodes - - adapters_list_str += asyncio.run(self._adapter_tensorstore.get_hbm_loaded_adapters()) + if idx < len(self._generate_adapter_tensorstore): + adapters_list_str += asyncio.run( + self._generate_adapter_tensorstore[idx].get_hbm_loaded_adapters()) self._metrics_collector.get_lora_request_info_metric(max_loras, adapters_list_str).set_to_current_time() @@ -635,6 +643,9 @@ def _prefill_thread(self, idx: int): logger.info("Spinning up prefill thread %d.", idx) prefill_engine = self._prefill_engines[idx] prefill_params = self._prefill_params[idx] + _adapter_tensorstore = None + if idx < len(self._prefill_adapter_tensorstore): + _adapter_tensorstore = self._prefill_adapter_tensorstore[idx] metadata = prefill_engine.get_tokenizer() tokenizer = prefill_engine.build_tokenizer(metadata) thread_name = f"Prefill thread {idx}" @@ -678,17 +689,32 @@ def _prefill_thread(self, idx: int): # because as of now same params are being shared by prefill and generate, # where generate always expect the base_params. So some race conditions need # to be avoided. 
- final_prefill_params = None - if adapter_id == "": - final_prefill_params = prefill_params - else: - final_prefill_params = copy.deepcopy(prefill_params) - lora_params = self._adapter_tensorstore.get_lora_weights(adapter_id) - lora_config = self._adapter_tensorstore.get_lora_config(adapter_id) - self._prefill_engines[idx].apply_adapter( - final_prefill_params, - lora_config, - lora_params) + final_prefill_params = prefill_params + if adapter_id: + #final_prefill_params = copy.deepcopy(prefill_params) + try: + if _adapter_tensorstore is None: + raise ValueError( + f"_adapter_tensorstore is not initialized for prefill_engine_id={idx}") + + lora_params = _adapter_tensorstore.get_lora_weights( + adapter_id=adapter_id, load_if_not_loaded=True) + lora_config = _adapter_tensorstore.get_lora_config( + adapter_id=adapter_id, load_if_not_loaded=True) + prefill_engine.apply_adapter( + final_prefill_params, + lora_config, + lora_params) + except Exception as e: + request.num_samples = 1 + request.complete = np.zeros((request.num_samples,), np.bool_) + error_message = f"An error occurred: {type(e).__name__} - {str(e)}" + err_message_token_list = error_message.split() + error_result = ReturnSample(text=[error_message], + token_ids=[]) + request.enqueue_samples([error_result]) + request.return_channel.close() + continue # Compute new kv cache for the prefill_content. if self._multi_sampling: @@ -743,6 +769,29 @@ def _prefill_thread(self, idx: int): padded_tokens=padded_tokens, true_length=true_length, ) + + if adapter_id: + try: + if _adapter_tensorstore is None: + raise ValueError( + f"_adapter_tensorstore is not initialized for prefill_engine_id={idx}") + + lora_params = _adapter_tensorstore.get_lora_weights(adapter_id) + lora_config = _adapter_tensorstore.get_lora_config(adapter_id) + prefill_engine.unapply_adapter( + final_prefill_params, + lora_config, + lora_params) + except Exception as e: + request.num_samples = 1 + request.complete = np.zeros((request.num_samples,), np.bool_) + error_message = f"An error occurred: {type(e).__name__} - {str(e)}" + err_message_token_list = error_message.split() + error_result = ReturnSample(text=[error_message], + token_ids=[]) + request.enqueue_samples([error_result]) + request.return_channel.close() + continue del final_prefill_params @@ -1333,62 +1382,96 @@ def loadAdapterToTensorstore( self, adapter_id, adapter_path): - logging.info(f"Loading adapter_id={adapter_id} from adapter_path={adapter_path}.") - - if not self._prefill_engines and not self._generate_engines: - logging.info(f"There is no MaxEngine object defined. 
So could not load any adapter.") - - engine = None - - if self._prefill_engines: - engine = self._prefill_engines[0] - else: - engine = self._generate_engines[0] - - adapter_params, adapter_config = engine.load_single_adapter(adapter_path) - - if not adapter_params or not adapter_config: - logging.info("Either params or adapter config is not loaded successfully.") - - try: - self._adapter_tensorstore.register_adapter( - adapter_id, - adapter_path, - adapter_config) - except ValueError as e: - logging.info(f"Registration failed with error: {e}") - - asyncio.run(self._adapter_tensorstore.load_adapter(adapter_id, adapter_params, True)) - - logging.info(f"Successfully loaded adapter_id={adapter_id}.") - engine.print_stats("After loading adapter_id={adapter_id}") + """Load the adapter to adapter_tensorstore for each engine.""" + logger.info(f"Loading adapter_id={adapter_id} from adapter_path={adapter_path}.") + for idx, tensorstore in enumerate(self._prefill_adapter_tensorstore): + try: + engine = self._prefill_engines[idx] + adapter_params, adapter_config = engine.load_single_adapter(adapter_path) + + if not adapter_params or not adapter_config: + raise ValueError( + f"Failed to load adapter with id={adapter_id} from path={adapter_path}.") + + tensorstore.register_adapter( + adapter_id, + adapter_path, + adapter_config) + + asyncio.run(tensorstore.load_adapter(adapter_id, adapter_params, True)) + + logger.info(f"Successfully loaded adapter_id={adapter_id} in engine_{idx}.") + engine.print_stats("After loading adapter_id={adapter_id} in engine_{idx}") + + except Exception as e: + logger.info("Adapter loading failed with error: {str(e)}") + raise e + + for idx, tensorstore in enumerate(self._generate_adapter_tensorstore): + try: + engine = self._generate_engines[idx] + adapter_params, adapter_config = engine.load_single_adapter(adapter_path) + + if not adapter_params or not adapter_config: + raise ValueError( + f"Failed to load adapter with id={adapter_id} from path={adapter_path}.") + + tensorstore.register_adapter( + adapter_id, + adapter_path, + adapter_config) + + asyncio.run(tensorstore.load_adapter(adapter_id, adapter_params, True)) + + logger.info(f"Successfully loaded adapter_id={adapter_id} in engine_{idx}.") + engine.print_stats("After loading adapter_id={adapter_id} in engine_{idx}") + + except Exception as e: + logger.info("Adapter loading failed with error: {str(e)}") + raise e + def unloadAdapterFromTensorstore( self, adapter_id): - logging.info(f"Unoading adapter_id={adapter_id}.") + """Unload the adapter from adapter_tensorstore of each engine.""" + logger.info(f"Unloading adapter_id={adapter_id}.") - try: - asyncio.run(self._adapter_tensorstore.unload_adapter(adapter_id)) - except ValueError as e: - logging.info(f"Registration failed with error: {e}") + for idx, tensorstore in enumerate(self._prefill_adapter_tensorstore): + try: + engine = self._prefill_engines[idx] + asyncio.run(tensorstore.unload_adapter(adapter_id)) - engine = None + logger.info(f"Successfully unloaded adapter_id={adapter_id} from engine_{idx}.") + engine.print_stats("After loading adapter_id={adapter_id} from engine_{idx}") - if self._prefill_engines: - engine = self._prefill_engines[0] - else: - engine = self._generate_engines[0] + except Exception as e: + logger.info("Adapter unloading failed with error: {str(e)}") + raise e - logging.info(f"Successfully unloaded adapter_id={adapter_id}.") - engine.print_stats("After unloading adapter_id={adapter_id}") + for idx, tensorstore in 
enumerate(self._generate_adapter_tensorstore): + try: + engine = self._generate_engines[idx] + asyncio.run(tensorstore.unload_adapter(adapter_id)) + + logger.info(f"Successfully unloaded adapter_id={adapter_id} from engine_{idx}.") + engine.print_stats("After loading adapter_id={adapter_id} from engine_{idx}") + + except Exception as e: + logger.info("Adapter unloading failed with error: {str(e)}") + raise e def listAdaptersFromTensorstore(self): - logging.info(f"Listing loaded adapters.") + """List all the adapters from the adapter_tensorstore of each engine.""" + logger.info(f"Listing loaded adapters.") + + listed_adapters = {} + for idx, tensorstore in enumerate(self._generate_adapter_tensorstore): + listed_adapters.update(tensorstore.adapter_registry) - return self._adapter_tensorstore.adapter_registry + return listed_adapters class LLMOrchestrator(jetstream_pb2_grpc.OrchestratorServicer): @@ -1497,7 +1580,7 @@ async def Decode( # pylint: disable=invalid-overridden-method prefill_content=prefill_content, is_client_side_tokenization=is_client_side_tokenization, return_channel=return_channel, - adapter_id=request.adapter_id, + adapter_id=request.lora_adapter_id, metadata=ActiveRequestMetadata( start_time=request.metadata.start_time, prefill_enqueue_time=time.perf_counter(), diff --git a/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto index c6734858..624313c7 100644 --- a/jetstream/core/proto/jetstream.proto +++ b/jetstream/core/proto/jetstream.proto @@ -61,7 +61,7 @@ message DecodeRequest { int32 num_samples = 8; - string adapter_id = 9; + string lora_adapter_id = 9; reserved 1, 2, 3; // Next ID: 10 diff --git a/jetstream/core/proto/jetstream_pb2.py b/jetstream/core/proto/jetstream_pb2.py index 81a3759a..c0ae8169 100644 --- a/jetstream/core/proto/jetstream_pb2.py +++ b/jetstream/core/proto/jetstream_pb2.py @@ -27,7 +27,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fjetstream.proto\x12\x0fjetstream_proto\"\xa5\x03\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 \x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x12\x13\n\x0bnum_samples\x18\x08 \x01(\x05\x12\x12\n\nadapter_id\x18\t \x01(\t\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04\"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02\"\x14\n\x12HealthCheckRequest\"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 
\x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse\"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fjetstream.proto\x12\x0fjetstream_proto\"\xaa\x03\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 \x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x12\x13\n\x0bnum_samples\x18\x08 \x01(\x05\x12\x17\n\x0flora_adapter_id\x18\t \x01(\t\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04\"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02\"\x14\n\x12HealthCheckRequest\"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse\"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -35,25 +35,25 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None _globals['_DECODEREQUEST']._serialized_start=37 - _globals['_DECODEREQUEST']._serialized_end=458 - _globals['_DECODEREQUEST_TEXTCONTENT']._serialized_start=314 - _globals['_DECODEREQUEST_TEXTCONTENT']._serialized_end=341 - _globals['_DECODEREQUEST_TOKENCONTENT']._serialized_start=343 - _globals['_DECODEREQUEST_TOKENCONTENT']._serialized_end=376 - _globals['_DECODEREQUEST_METADATA']._serialized_start=378 - _globals['_DECODEREQUEST_METADATA']._serialized_end=408 - _globals['_DECODERESPONSE']._serialized_start=461 - _globals['_DECODERESPONSE']._serialized_end=792 - _globals['_DECODERESPONSE_INITIALCONTENT']._serialized_start=627 - _globals['_DECODERESPONSE_INITIALCONTENT']._serialized_end=643 - _globals['_DECODERESPONSE_STREAMCONTENT']._serialized_start=646 - _globals['_DECODERESPONSE_STREAMCONTENT']._serialized_end=775 - _globals['_DECODERESPONSE_STREAMCONTENT_SAMPLE']._serialized_start=734 - _globals['_DECODERESPONSE_STREAMCONTENT_SAMPLE']._serialized_end=775 - _globals['_HEALTHCHECKREQUEST']._serialized_start=794 - _globals['_HEALTHCHECKREQUEST']._serialized_end=814 - _globals['_HEALTHCHECKRESPONSE']._serialized_start=816 - _globals['_HEALTHCHECKRESPONSE']._serialized_end=854 - _globals['_ORCHESTRATOR']._serialized_start=857 - _globals['_ORCHESTRATOR']._serialized_end=1042 + _globals['_DECODEREQUEST']._serialized_end=463 + 
_globals['_DECODEREQUEST_TEXTCONTENT']._serialized_start=319 + _globals['_DECODEREQUEST_TEXTCONTENT']._serialized_end=346 + _globals['_DECODEREQUEST_TOKENCONTENT']._serialized_start=348 + _globals['_DECODEREQUEST_TOKENCONTENT']._serialized_end=381 + _globals['_DECODEREQUEST_METADATA']._serialized_start=383 + _globals['_DECODEREQUEST_METADATA']._serialized_end=413 + _globals['_DECODERESPONSE']._serialized_start=466 + _globals['_DECODERESPONSE']._serialized_end=797 + _globals['_DECODERESPONSE_INITIALCONTENT']._serialized_start=632 + _globals['_DECODERESPONSE_INITIALCONTENT']._serialized_end=648 + _globals['_DECODERESPONSE_STREAMCONTENT']._serialized_start=651 + _globals['_DECODERESPONSE_STREAMCONTENT']._serialized_end=780 + _globals['_DECODERESPONSE_STREAMCONTENT_SAMPLE']._serialized_start=739 + _globals['_DECODERESPONSE_STREAMCONTENT_SAMPLE']._serialized_end=780 + _globals['_HEALTHCHECKREQUEST']._serialized_start=799 + _globals['_HEALTHCHECKREQUEST']._serialized_end=819 + _globals['_HEALTHCHECKRESPONSE']._serialized_start=821 + _globals['_HEALTHCHECKRESPONSE']._serialized_end=859 + _globals['_ORCHESTRATOR']._serialized_start=862 + _globals['_ORCHESTRATOR']._serialized_end=1047 # @@protoc_insertion_point(module_scope) diff --git a/jetstream/core/proto/multi_lora_decoding.proto b/jetstream/core/proto/multi_lora_decoding.proto index 77f190a6..30df270e 100644 --- a/jetstream/core/proto/multi_lora_decoding.proto +++ b/jetstream/core/proto/multi_lora_decoding.proto @@ -17,9 +17,6 @@ syntax = "proto3"; service v1 { - // Generate text based on a prompt. Supports streaming responses. - rpc completions (CompletionRequest) returns (stream CompletionResponse) {} - // Lists all the currently loaded LoRA adapters rpc models (ListAdaptersRequest) returns (ListAdaptersResponse) {} @@ -31,67 +28,6 @@ service v1 { } -message CompletionRequest { - // The maximum output length of a sequence. It's used in JetStream to control - // the output/decode length of a sequence. It would not be used in the engine. - // We should always set max_tokens <= (max_target_length - - // max_prefill_predict_length). max_target_length is the maximum length of a - // sequence; max_prefill_predict_length is the maximum length of the - // input/prefill of a sequence. - int32 max_tokens = 4; - - message TextContent { - string text = 1; - } - message TokenContent { - repeated int32 token_ids = 1; - } - - // The client can pass the inputs either as a string, in which case the server will - // tokenize it, or as tokens, in which case it's the client's responsibility to - // ensure they tokenize its input strings with the correct tokenizer. - oneof content { - TextContent text_content = 5; - TokenContent token_content = 6; - } - - message Metadata { - float start_time = 1; - } - - oneof metadata_optional { - Metadata metadata = 7; - } - - string adapter_id = 8; - - reserved 1, 2, 3; - // Next ID: 9 -} - -message CompletionResponse { - // InitialContent supports returning initial one-off response data from the - // stream. It's a placeholder for future features such as history cache. - message InitialContent {} - message StreamContent { - message Sample { - // The text string decoded from token id(s). - string text = 1; - // List of token ids, one list per sample. When speculative decoding is disabled, the list size should be 1; When speculative decoding is enabled, the list size should be >= 1. - repeated int32 token_ids = 2; - } - // Supports multiple samples in the StreamContent. 
The Sample list size depends on text generation strategy the engine used. - repeated Sample samples = 1; - } - - oneof content { - InitialContent initial_content = 2; - StreamContent stream_content = 3; - } - reserved 1; - // Next ID: 4 -} - message ListAdaptersRequest {} message ListAdaptersResponse { diff --git a/jetstream/core/proto/multi_lora_decoding_pb2.py b/jetstream/core/proto/multi_lora_decoding_pb2.py index 4ad06c83..e7fefbae 100644 --- a/jetstream/core/proto/multi_lora_decoding_pb2.py +++ b/jetstream/core/proto/multi_lora_decoding_pb2.py @@ -14,43 +14,27 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19multi_lora_decoding.proto\"\xf0\x02\n\x11\x43ompletionRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x36\n\x0ctext_content\x18\x05 \x01(\x0b\x32\x1e.CompletionRequest.TextContentH\x00\x12\x38\n\rtoken_content\x18\x06 \x01(\x0b\x32\x1f.CompletionRequest.TokenContentH\x00\x12/\n\x08metadata\x18\x07 \x01(\x0b\x32\x1b.CompletionRequest.MetadataH\x01\x12\x12\n\nadapter_id\x18\x08 \x01(\t\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04\"\xaa\x02\n\x12\x43ompletionResponse\x12=\n\x0finitial_content\x18\x02 \x01(\x0b\x32\".CompletionResponse.InitialContentH\x00\x12;\n\x0estream_content\x18\x03 \x01(\x0b\x32!.CompletionResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1au\n\rStreamContent\x12\x39\n\x07samples\x18\x01 \x03(\x0b\x32(.CompletionResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02\"\x15\n\x13ListAdaptersRequest\"c\n\x14ListAdaptersResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\x12#\n\radapter_infos\x18\x03 \x03(\x0b\x32\x0c.AdapterInfo\"\x82\x01\n\x0b\x41\x64\x61pterInfo\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0cloading_cost\x18\x02 \x01(\x03\x12\x10\n\x08size_hbm\x18\x03 \x01(\x03\x12\x10\n\x08size_cpu\x18\x04 \x01(\x03\x12\x15\n\rlast_accessed\x18\x05 \x01(\x02\x12\x0e\n\x06status\x18\x06 \x01(\t\">\n\x12LoadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\"=\n\x13LoadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\"*\n\x14UnloadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"?\n\x15UnloadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t2\x83\x02\n\x02v1\x12:\n\x0b\x63ompletions\x12\x12.CompletionRequest\x1a\x13.CompletionResponse\"\x00\x30\x01\x12\x37\n\x06models\x12\x14.ListAdaptersRequest\x1a\x15.ListAdaptersResponse\"\x00\x12@\n\x11load_lora_adapter\x12\x13.LoadAdapterRequest\x1a\x14.LoadAdapterResponse\"\x00\x12\x46\n\x13unload_lora_adapter\x12\x15.UnloadAdapterRequest\x1a\x16.UnloadAdapterResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19multi_lora_decoding.proto\"\x15\n\x13ListAdaptersRequest\"c\n\x14ListAdaptersResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\x12#\n\radapter_infos\x18\x03 \x03(\x0b\x32\x0c.AdapterInfo\"\x82\x01\n\x0b\x41\x64\x61pterInfo\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0cloading_cost\x18\x02 \x01(\x03\x12\x10\n\x08size_hbm\x18\x03 
\x01(\x03\x12\x10\n\x08size_cpu\x18\x04 \x01(\x03\x12\x15\n\rlast_accessed\x18\x05 \x01(\x02\x12\x0e\n\x06status\x18\x06 \x01(\t\">\n\x12LoadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\"=\n\x13LoadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\"*\n\x14UnloadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"?\n\x15UnloadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t2\xc7\x01\n\x02v1\x12\x37\n\x06models\x12\x14.ListAdaptersRequest\x1a\x15.ListAdaptersResponse\"\x00\x12@\n\x11load_lora_adapter\x12\x13.LoadAdapterRequest\x1a\x14.LoadAdapterResponse\"\x00\x12\x46\n\x13unload_lora_adapter\x12\x15.UnloadAdapterRequest\x1a\x16.UnloadAdapterResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'multi_lora_decoding_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_COMPLETIONREQUEST']._serialized_start=30 - _globals['_COMPLETIONREQUEST']._serialized_end=398 - _globals['_COMPLETIONREQUEST_TEXTCONTENT']._serialized_start=254 - _globals['_COMPLETIONREQUEST_TEXTCONTENT']._serialized_end=281 - _globals['_COMPLETIONREQUEST_TOKENCONTENT']._serialized_start=283 - _globals['_COMPLETIONREQUEST_TOKENCONTENT']._serialized_end=316 - _globals['_COMPLETIONREQUEST_METADATA']._serialized_start=318 - _globals['_COMPLETIONREQUEST_METADATA']._serialized_end=348 - _globals['_COMPLETIONRESPONSE']._serialized_start=401 - _globals['_COMPLETIONRESPONSE']._serialized_end=699 - _globals['_COMPLETIONRESPONSE_INITIALCONTENT']._serialized_start=547 - _globals['_COMPLETIONRESPONSE_INITIALCONTENT']._serialized_end=563 - _globals['_COMPLETIONRESPONSE_STREAMCONTENT']._serialized_start=565 - _globals['_COMPLETIONRESPONSE_STREAMCONTENT']._serialized_end=682 - _globals['_COMPLETIONRESPONSE_STREAMCONTENT_SAMPLE']._serialized_start=641 - _globals['_COMPLETIONRESPONSE_STREAMCONTENT_SAMPLE']._serialized_end=682 - _globals['_LISTADAPTERSREQUEST']._serialized_start=701 - _globals['_LISTADAPTERSREQUEST']._serialized_end=722 - _globals['_LISTADAPTERSRESPONSE']._serialized_start=724 - _globals['_LISTADAPTERSRESPONSE']._serialized_end=823 - _globals['_ADAPTERINFO']._serialized_start=826 - _globals['_ADAPTERINFO']._serialized_end=956 - _globals['_LOADADAPTERREQUEST']._serialized_start=958 - _globals['_LOADADAPTERREQUEST']._serialized_end=1020 - _globals['_LOADADAPTERRESPONSE']._serialized_start=1022 - _globals['_LOADADAPTERRESPONSE']._serialized_end=1083 - _globals['_UNLOADADAPTERREQUEST']._serialized_start=1085 - _globals['_UNLOADADAPTERREQUEST']._serialized_end=1127 - _globals['_UNLOADADAPTERRESPONSE']._serialized_start=1129 - _globals['_UNLOADADAPTERRESPONSE']._serialized_end=1192 - _globals['_V1']._serialized_start=1195 - _globals['_V1']._serialized_end=1454 + _globals['_LISTADAPTERSREQUEST']._serialized_start=29 + _globals['_LISTADAPTERSREQUEST']._serialized_end=50 + _globals['_LISTADAPTERSRESPONSE']._serialized_start=52 + _globals['_LISTADAPTERSRESPONSE']._serialized_end=151 + _globals['_ADAPTERINFO']._serialized_start=154 + _globals['_ADAPTERINFO']._serialized_end=284 + _globals['_LOADADAPTERREQUEST']._serialized_start=286 + _globals['_LOADADAPTERREQUEST']._serialized_end=348 + _globals['_LOADADAPTERRESPONSE']._serialized_start=350 + _globals['_LOADADAPTERRESPONSE']._serialized_end=411 + 
_globals['_UNLOADADAPTERREQUEST']._serialized_start=413 + _globals['_UNLOADADAPTERREQUEST']._serialized_end=455 + _globals['_UNLOADADAPTERRESPONSE']._serialized_start=457 + _globals['_UNLOADADAPTERRESPONSE']._serialized_end=520 + _globals['_V1']._serialized_start=523 + _globals['_V1']._serialized_end=722 # @@protoc_insertion_point(module_scope) diff --git a/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py b/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py index 495714de..274ceb33 100644 --- a/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py +++ b/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py @@ -14,11 +14,6 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ - self.completions = channel.unary_stream( - '/v1/completions', - request_serializer=multi__lora__decoding__pb2.CompletionRequest.SerializeToString, - response_deserializer=multi__lora__decoding__pb2.CompletionResponse.FromString, - ) self.models = channel.unary_unary( '/v1/models', request_serializer=multi__lora__decoding__pb2.ListAdaptersRequest.SerializeToString, @@ -39,13 +34,6 @@ def __init__(self, channel): class v1Servicer(object): """Missing associated documentation comment in .proto file.""" - def completions(self, request, context): - """Generate text based on a prompt. Supports streaming responses. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - def models(self, request, context): """Lists all the currently loaded LoRA adapters """ @@ -70,11 +58,6 @@ def unload_lora_adapter(self, request, context): def add_v1Servicer_to_server(servicer, server): rpc_method_handlers = { - 'completions': grpc.unary_stream_rpc_method_handler( - servicer.completions, - request_deserializer=multi__lora__decoding__pb2.CompletionRequest.FromString, - response_serializer=multi__lora__decoding__pb2.CompletionResponse.SerializeToString, - ), 'models': grpc.unary_unary_rpc_method_handler( servicer.models, request_deserializer=multi__lora__decoding__pb2.ListAdaptersRequest.FromString, @@ -100,23 +83,6 @@ def add_v1Servicer_to_server(servicer, server): class v1(object): """Missing associated documentation comment in .proto file.""" - @staticmethod - def completions(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_stream(request, target, '/v1/completions', - multi__lora__decoding__pb2.CompletionRequest.SerializeToString, - multi__lora__decoding__pb2.CompletionResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - @staticmethod def models(request, target, diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 10312ba3..b4fbf29a 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -34,6 +34,7 @@ import jax from jetstream.core import config_lib from jetstream.core import orchestrator +from jetstream.core.lora import adapter_tensorstore from jetstream.core.metrics.prometheus import JetstreamMetricsCollector from jetstream.core.proto import jetstream_pb2_grpc from jetstream.engine import warmup_utils, engine_api @@ -119,6 +120,7 @@ def create_driver( metrics_collector: JetstreamMetricsCollector | None = None, enable_model_warmup: bool = False, multi_sampling: bool = False, + lora_input_adapters_path: str = None ): """Creates a 
driver with a specified config. @@ -147,10 +149,41 @@ def create_driver( len(config.prefill_slices) + len(config.generate_slices) == 0 ) + prefill_adapter_tensorstore = [] + generate_adapter_tensorstore = [] + shared_adapter_tensorstore = [] + + if lora_input_adapters_path: + for pe in engines.prefill_engines: + prefill_adapter_tensorstore.append(adapter_tensorstore.AdapterTensorStore( + engine=pe, + adapters_dir_path=lora_input_adapters_path, + hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM + cpu_memory_budget=(100 * (1024 ** 3)) # 100 GB RAM + )) + + for ge in engines.generate_engines: + generate_adapter_tensorstore.append(adapter_tensorstore.AdapterTensorStore( + engine=ge, + adapters_dir_path=lora_input_adapters_path, + hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM + cpu_memory_budget=(100 * (1024 ** 3)) # 100 GB RAM + )) + + for ie in engines.interleaved_engines: + shared_adapter_tensorstore.append(adapter_tensorstore.AdapterTensorStore( + engine=ie, + adapters_dir_path=lora_input_adapters_path, + hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM + cpu_memory_budget=(100 * (1024 ** 3)) # 100 GB RAM + )) + prefill_engines = engines.prefill_engines + engines.interleaved_engines generate_engines = engines.generate_engines + engines.interleaved_engines prefill_params = prefill_params + shared_params generate_params = generate_params + shared_params + prefill_adapter_tensorstore += shared_adapter_tensorstore + generate_adapter_tensorstore += shared_adapter_tensorstore if prefill_engines is None: prefill_engines = [] @@ -185,6 +218,8 @@ def create_driver( generate_engines=generate_engines, prefill_params=prefill_params, generate_params=generate_params, + prefill_adapter_tensorstore=prefill_adapter_tensorstore, + generate_adapter_tensorstore=generate_adapter_tensorstore, interleaved_mode=interleaved_mode, jax_padding=jax_padding, metrics_collector=metrics_collector, @@ -205,7 +240,7 @@ def run( jax_profiler_port: int = 9999, enable_model_warmup: bool = False, multi_sampling: bool = False, - enable_llm_inference_pool: bool = False, + lora_input_adapters_path: str = None, ) -> JetStreamServer: """Runs a server with a specified config. @@ -222,6 +257,7 @@ def run( jax_profiler_port: The port JAX profiler server (default to 9999). enable_model_warmup: The flag to enable model server warmup. multi_sampling: The flag to enable multi-sampling. + lora_input_adapters_path: Input path for all lora adapters. Returns: JetStreamServer that wraps the grpc server and orchestrator driver. @@ -250,10 +286,14 @@ def run( metrics_collector, enable_model_warmup, multi_sampling, + lora_input_adapters_path ) # We default threads to the total number of concurrent allowed decodes, # to make sure we can fully saturate the model. Set default minimum to 64. 
threads = threads or max(driver.get_total_concurrent_requests(), 64) + enable_llm_inference_pool = False + if lora_input_adapters_path: + enable_llm_inference_pool = True jetstream_server = JetStreamServer(driver, threads, port, credentials, enable_llm_inference_pool) logging.info("Starting server on port %d with %d threads", port, threads) diff --git a/jetstream/tools/maxtext/model_ckpt_conversion.sh b/jetstream/tools/maxtext/model_ckpt_conversion.sh index adc7c00c..79187dfc 100644 --- a/jetstream/tools/maxtext/model_ckpt_conversion.sh +++ b/jetstream/tools/maxtext/model_ckpt_conversion.sh @@ -65,7 +65,7 @@ else pip install torch --index-url https://download.pytorch.org/whl/cpu # llama_or_mistral_ckpt.py requires local path, so we need to copy the checkpoint from CHKPT_BUCKET to local. tmp_ckpt_path="/tmp/" - gcloud storage cp -r ${CHKPT_BUCKET} ${tmp_ckpt_path} + #gcloud storage cp -r ${CHKPT_BUCKET} ${tmp_ckpt_path} path_parts=(${CHKPT_BUCKET//\// }) directory_substring=${path_parts[-1]} @@ -115,7 +115,7 @@ if [[ ! -z "${LORA_INPUT_ADAPTERS_PATH}" ]]; then MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ load_parameters_path=${SCANNED_CKPT_PATH}/base/0/items \ - lora_input_adapters_path=${SCANNED_CKPT_PATH}/LoRAs \ + lora_input_adapters_path=${SCANNED_CKPT_PATH}/loras \ run_name=${RUN_NAME} \ model_name=${MODEL_NAME} \ force_unroll=true diff --git a/jetstream/tools/requester.py b/jetstream/tools/requester.py index e8fe9026..5a60fe44 100644 --- a/jetstream/tools/requester.py +++ b/jetstream/tools/requester.py @@ -44,8 +44,8 @@ False, "Enable client side tokenization with tokenizer.", ) -_ADAPTER_ID = flags.DEFINE_string( - "adapter_id", +_LORA_ADAPTER_ID = flags.DEFINE_string( + "lora_adapter_id", "", "ID of the adapter for this decode request.", required=False) @@ -94,7 +94,7 @@ def main(argv: Sequence[str]) -> None: ), max_tokens=_MAX_TOKENS.value, num_samples=_NUM_SAMPLES.value, - adapter_id=_ADAPTER_ID.value, + lora_adapter_id=_LORA_ADAPTER_ID.value, ) else: request = jetstream_pb2.DecodeRequest( @@ -103,7 +103,7 @@ def main(argv: Sequence[str]) -> None: ), max_tokens=_MAX_TOKENS.value, num_samples=_NUM_SAMPLES.value, - adapter_id=_ADAPTER_ID.value, + lora_adapter_id=_LORA_ADAPTER_ID.value, ) return _GetResponseAsync(stub, request) From 5bda29b5326f2c01ade6b5796069ea890c66bc4b Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 13 Mar 2025 22:22:35 +0000 Subject: [PATCH 17/22] Refactoring and fixing lint errors. --- jetstream/core/lora/adapter_tensorstore.py | 5 +- .../core/lora/multi_lora_inference_api.py | 6 +- jetstream/core/metrics/prometheus.py | 7 +- jetstream/core/orchestrator.py | 169 +++++++++--------- jetstream/core/server_lib.py | 65 +++---- .../tools/multi_adapter_service_client.py | 25 +-- .../tools/multi_lora_decode_requester.py | 15 +- 7 files changed, 150 insertions(+), 142 deletions(-) diff --git a/jetstream/core/lora/adapter_tensorstore.py b/jetstream/core/lora/adapter_tensorstore.py index 0f86237f..13702ea5 100644 --- a/jetstream/core/lora/adapter_tensorstore.py +++ b/jetstream/core/lora/adapter_tensorstore.py @@ -104,8 +104,8 @@ def __init__(self, def register_adapter(self, adapter_id: str, - adapter_path: str = None, - adapter_config: Dict[str, Any] = None): + adapter_path: str | None = None, + adapter_config: Dict[str, Any] | None = None): """Registers a new LoRA adatper.""" """ Registers a LoRA adapter with the TensorStore. 
This also loads the adapter; @@ -284,6 +284,7 @@ async def load_adapter( try: if adapter_weights is None: + adapter_path = f"{self.adapters_dir_path}/{adapter_id}" adapter_weights, adapter_config = self.engine.load_single_adapter(adapter_path) if adapter_weights is None: diff --git a/jetstream/core/lora/multi_lora_inference_api.py b/jetstream/core/lora/multi_lora_inference_api.py index f483c12a..e3a97243 100644 --- a/jetstream/core/lora/multi_lora_inference_api.py +++ b/jetstream/core/lora/multi_lora_inference_api.py @@ -40,7 +40,7 @@ def models( """ListAdapters all loaded LoRA adapters.""" try: - adapters = self._driver.listAdaptersFromTensorstore() + adapters = self._driver.list_adapters_from_tensorstore() adapter_infos = [] for adapter_id, adapter_data in adapters.items(): @@ -77,7 +77,7 @@ def load_lora_adapter( """Load a LoRA adapter as mentioned in the request.""" try: - self._driver.loadAdapterToTensorstore(request.adapter_id, request.adapter_path) + self._driver.load_adapter_to_tensorstore(request.adapter_id, request.adapter_path) return multi_lora_decoding_pb2.LoadAdapterResponse(success=True) except Exception as e: @@ -93,7 +93,7 @@ def unload_lora_adapter( """Unload a LoRA adapter as mentioned in the request.""" try: - self._driver.unloadAdapterFromTensorstore(request.adapter_id) + self._driver.unload_adapter_from_tensorstore(request.adapter_id) return multi_lora_decoding_pb2.UnloadAdapterResponse(success=True) except Exception as e: logging.info(f"Loading of adapter_id={request.adapter_id} failed with error: {str(e)}") diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py index 6d851f81..f4e50ce0 100644 --- a/jetstream/core/metrics/prometheus.py +++ b/jetstream/core/metrics/prometheus.py @@ -254,14 +254,14 @@ def __init__(self, model_name: Optional[str] = None): _kv_cache_utilization = Gauge( name="kv_cache_utilization_perc", - documentation="Percentage of kv-cache utilized by the requests under processing.", + documentation="kv-cache utilization % by the requests under processing.", labelnames=["id"], multiprocess_mode="sum", ) _lora_request_info = Gauge( name="lora_request_info", - documentation="Information about LoRA adapters loaded into TPU Memory for serving current requests.", + documentation="LoRA adapters loaded into HBM for processing requests.", labelnames=[ "id", "max_lora", @@ -322,4 +322,5 @@ def get_kv_cache_utilization_metric(self): return self._kv_cache_utilization.labels(**self.universal_labels) def get_lora_request_info_metric(self, max_lora: int, loaded_adapters: str): - return self._lora_request_info.labels(**self.universal_labels, max_lora=max_lora, running_lora_adapters=loaded_adapters) + return self._lora_request_info.labels(**self.universal_labels, + max_lora=max_lora, running_lora_adapters=loaded_adapters) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 29b656be..614d74e6 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -74,7 +74,6 @@ to debug hangs due to bugs in threads (it is easier to debug with live logs). 
""" -import copy from datetime import datetime import dataclasses import functools @@ -94,7 +93,7 @@ import grpc import jax import jax.numpy as jnp -from jetstream.core.lora import adapter_tensorstore +from jetstream.core.lora import adapter_tensorstore as adapterstore from jetstream.core.proto import jetstream_pb2 from jetstream.core.proto import jetstream_pb2_grpc @@ -257,9 +256,9 @@ class Driver: # All metrics we want to monitor should be collected with this _metrics_collector: JetstreamMetricsCollector | None = None - # An object to store and manage the adapters for each prefill and generate Engine - _prefill_adapter_tensorstore: list[adapter_tensorstore.AdapterTensorStore] - _generate_adapter_tensorstore: list[adapter_tensorstore.AdapterTensorStore] + # Store and manage the adapters for each prefill & generate Engine + _prefill_adapterstore: list[adapterstore.AdapterTensorStore] | None = None + _generate_adapterstore: list[adapterstore.AdapterTensorStore] | None = None def __init__( self, @@ -267,8 +266,10 @@ def __init__( generate_engines: Optional[list[engine_api.Engine]] = None, prefill_params: Optional[list[Any]] = None, generate_params: Optional[list[Any]] = None, - prefill_adapter_tensorstore: Optional[list[adapter_tensorstore.AdapterTensorStore]] = None, - generate_adapter_tensorstore: Optional[list[adapter_tensorstore.AdapterTensorStore]] = None, + prefill_adapterstore: Optional[ + list[adapterstore.AdapterTensorStore]] = None, + generate_adapterstore: Optional[ + list[adapterstore.AdapterTensorStore]] = None, interleaved_mode: bool = False, jax_padding: bool = True, metrics_collector: JetstreamMetricsCollector | None = None, @@ -284,13 +285,8 @@ def __init__( if generate_params is None: raise ValueError("No generate parameter provided.") - self._prefill_adapter_tensorstore = prefill_adapter_tensorstore - self._generate_adapter_tensorstore = generate_adapter_tensorstore - -# # TODO: Make `hbm_memory_budget` & `cpu_memory_budget` configurable. -# self._adapter_tensorstore = adapter_tensorstore.AdapterTensorStore( -# hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM -# cpu_memory_budget=(100 * (1024 ** 3))) # 100 GB RAM + self._prefill_adapterstore = prefill_adapterstore + self._generate_adapterstore = generate_adapterstore logger.info( "Initializing the driver with %d prefill engines and %d " @@ -525,14 +521,17 @@ def stop(self): logger.info("Driver stopped.") def _get_kv_cache_utilization(self): - """Calculated the kv_cache utilization in percentage based on requests being decoded.""" + """ + Calculated the kv_cache utilization in percentage based on requests + being decoded. 
+ """ total_slots = 0 empty_slots = 0 for idx, engine in enumerate(self._generate_engines): total_slots += engine.max_concurrent_decodes empty_slots += self._generate_slots[idx].qsize() - return ((total_slots - empty_slots) * 100 / total_slots) + return (total_slots - empty_slots) * 100 / total_slots def _get_total_requests_waiting_decode(self): """Calculate the total size of all relevant queues.""" @@ -554,9 +553,9 @@ def _export_lora_request_info(self): if self._metrics_collector: for idx, engine in enumerate(self._generate_engines): max_loras += engine.max_concurrent_decodes - if idx < len(self._generate_adapter_tensorstore): + if idx < len(self._generate_adapterstore): adapters_list_str += asyncio.run( - self._generate_adapter_tensorstore[idx].get_hbm_loaded_adapters()) + self._generate_adapterstore[idx].get_hbm_loaded_adapters()) self._metrics_collector.get_lora_request_info_metric(max_loras, adapters_list_str).set_to_current_time() @@ -643,9 +642,9 @@ def _prefill_thread(self, idx: int): logger.info("Spinning up prefill thread %d.", idx) prefill_engine = self._prefill_engines[idx] prefill_params = self._prefill_params[idx] - _adapter_tensorstore = None - if idx < len(self._prefill_adapter_tensorstore): - _adapter_tensorstore = self._prefill_adapter_tensorstore[idx] + adapter_tensorstore = None + if idx < len(self._prefill_adapterstore): + adapter_tensorstore = self._prefill_adapterstore[idx] metadata = prefill_engine.get_tokenizer() tokenizer = prefill_engine.build_tokenizer(metadata) thread_name = f"Prefill thread {idx}" @@ -677,39 +676,31 @@ def _prefill_thread(self, idx: int): adapter_id = request.adapter_id - # As prefill is happening one prompt at a time, for each prefill, we are - # applying the LoRA params on base to create a copy of params (equivalent - # to the size of base params) and use that for generating kv-cache. This - # copy is called the final_prefill_params, which is deleted soon after the - # generation of kv-cache. - # We can have memory-optimizations by updating the original copy of - # base params at the cost of extra computations to revert it back to original - # base params after kv-cache computation of each prompt, so that it can - # be used by the next prompt. But this optimization could also be tricky - # because as of now same params are being shared by prefill and generate, - # where generate always expect the base_params. So some race conditions need - # to be avoided. + # Here we are applying the LoRA adapter params to the base params and + # them. In the interleaved mode, the prefill and generate shares the + # same params. But as long as prefill and decode happens sequentially, + # there is no issues. Issue will arrise if prefill and decode is running + # in parallel and sharing the same params. Issue arrise because prefill + # uses pre-merged weights and generate uses only base weights. 
final_prefill_params = prefill_params if adapter_id: - #final_prefill_params = copy.deepcopy(prefill_params) try: - if _adapter_tensorstore is None: - raise ValueError( - f"_adapter_tensorstore is not initialized for prefill_engine_id={idx}") - - lora_params = _adapter_tensorstore.get_lora_weights( + if adapter_tensorstore is None: + raise ValueError( + f"adapter_tensorstore is None for prefill_engine_id={idx}") + + lora_params = adapter_tensorstore.get_lora_weights( adapter_id=adapter_id, load_if_not_loaded=True) - lora_config = _adapter_tensorstore.get_lora_config( + lora_config = adapter_tensorstore.get_lora_config( adapter_id=adapter_id, load_if_not_loaded=True) - prefill_engine.apply_adapter( + prefill_engine.apply_adapter( final_prefill_params, lora_config, lora_params) - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught request.num_samples = 1 request.complete = np.zeros((request.num_samples,), np.bool_) error_message = f"An error occurred: {type(e).__name__} - {str(e)}" - err_message_token_list = error_message.split() error_result = ReturnSample(text=[error_message], token_ids=[]) request.enqueue_samples([error_result]) @@ -769,24 +760,23 @@ def _prefill_thread(self, idx: int): padded_tokens=padded_tokens, true_length=true_length, ) - + if adapter_id: try: - if _adapter_tensorstore is None: - raise ValueError( - f"_adapter_tensorstore is not initialized for prefill_engine_id={idx}") - - lora_params = _adapter_tensorstore.get_lora_weights(adapter_id) - lora_config = _adapter_tensorstore.get_lora_config(adapter_id) - prefill_engine.unapply_adapter( + if adapter_tensorstore is None: + raise ValueError( + f"adapter_tensorstore is None for prefill_engine_id={idx}") + + lora_params = adapter_tensorstore.get_lora_weights(adapter_id) + lora_config = adapter_tensorstore.get_lora_config(adapter_id) + prefill_engine.unapply_adapter( final_prefill_params, lora_config, lora_params) - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught request.num_samples = 1 request.complete = np.zeros((request.num_samples,), np.bool_) error_message = f"An error occurred: {type(e).__name__} - {str(e)}" - err_message_token_list = error_message.split() error_result = ReturnSample(text=[error_message], token_ids=[]) request.enqueue_samples([error_result]) @@ -1174,7 +1164,7 @@ def _generate_thread(self, idx: int): if decode_state is None: break - + # Export the lora_request_info metric self._export_lora_request_info() @@ -1378,97 +1368,104 @@ def _detokenize_thread(self, idx: int): logger.info("Detokenize thread %d stopped.", idx) - def loadAdapterToTensorstore( + def load_adapter_to_tensorstore( self, adapter_id, adapter_path): """Load the adapter to adapter_tensorstore for each engine.""" - logger.info(f"Loading adapter_id={adapter_id} from adapter_path={adapter_path}.") + logger.info("Loading adapter_id=%s from %s.", + adapter_id, adapter_path) - for idx, tensorstore in enumerate(self._prefill_adapter_tensorstore): + for idx, tensorstore in enumerate(self._prefill_adapterstore): try: engine = self._prefill_engines[idx] - adapter_params, adapter_config = engine.load_single_adapter(adapter_path) + adapter_params, adapter_config = engine.load_single_adapter( + adapter_path) if not adapter_params or not adapter_config: raise ValueError( - f"Failed to load adapter with id={adapter_id} from path={adapter_path}.") - + f"Failed to load adapter={adapter_id} from {adapter_path}.") + tensorstore.register_adapter( adapter_id, adapter_path, 
adapter_config) - + asyncio.run(tensorstore.load_adapter(adapter_id, adapter_params, True)) - logger.info(f"Successfully loaded adapter_id={adapter_id} in engine_{idx}.") - engine.print_stats("After loading adapter_id={adapter_id} in engine_{idx}") + logger.info("Successfully loaded '%s' in engine_%d.", + adapter_id, idx) + engine.print_stats(f"After loading '{adapter_id}' in engine_{idx}") except Exception as e: - logger.info("Adapter loading failed with error: {str(e)}") + logger.info("Adapter loading failed with error: %s", str(e)) raise e - for idx, tensorstore in enumerate(self._generate_adapter_tensorstore): + for idx, tensorstore in enumerate(self._generate_adapterstore): try: engine = self._generate_engines[idx] - adapter_params, adapter_config = engine.load_single_adapter(adapter_path) + adapter_params, adapter_config = engine.load_single_adapter( + adapter_path) if not adapter_params or not adapter_config: raise ValueError( - f"Failed to load adapter with id={adapter_id} from path={adapter_path}.") - + f"Failed to load adapter={adapter_id} from {adapter_path}.") + tensorstore.register_adapter( adapter_id, adapter_path, adapter_config) - + asyncio.run(tensorstore.load_adapter(adapter_id, adapter_params, True)) - logger.info(f"Successfully loaded adapter_id={adapter_id} in engine_{idx}.") - engine.print_stats("After loading adapter_id={adapter_id} in engine_{idx}") + logger.info("Successfully loaded '%s' in engine_%d.", + adapter_id, idx) + engine.print_stats(f"After loading '{adapter_id}' in engine_{idx}") except Exception as e: - logger.info("Adapter loading failed with error: {str(e)}") + logger.info("Adapter loading failed with error: %s", str(e)) raise e - - def unloadAdapterFromTensorstore( + + def unload_adapter_from_tensorstore( self, adapter_id): """Unload the adapter from adapter_tensorstore of each engine.""" - logger.info(f"Unloading adapter_id={adapter_id}.") + logger.info("Unloading adapter_id=%s", adapter_id) - for idx, tensorstore in enumerate(self._prefill_adapter_tensorstore): + for idx, tensorstore in enumerate(self._prefill_adapterstore): try: engine = self._prefill_engines[idx] asyncio.run(tensorstore.unload_adapter(adapter_id)) - logger.info(f"Successfully unloaded adapter_id={adapter_id} from engine_{idx}.") - engine.print_stats("After loading adapter_id={adapter_id} from engine_{idx}") + logger.info("Successfully unloaded '%s' in engine_%d.", + adapter_id, idx) + engine.print_stats(f"After unloading '{adapter_id}' in engine_{idx}") except Exception as e: - logger.info("Adapter unloading failed with error: {str(e)}") + logger.info("Adapter unloading failed with error: %s", str(e)) raise e - for idx, tensorstore in enumerate(self._generate_adapter_tensorstore): + for idx, tensorstore in enumerate(self._generate_adapterstore): try: engine = self._generate_engines[idx] asyncio.run(tensorstore.unload_adapter(adapter_id)) - logger.info(f"Successfully unloaded adapter_id={adapter_id} from engine_{idx}.") - engine.print_stats("After loading adapter_id={adapter_id} from engine_{idx}") + logger.info("Successfully unloaded '%s' in engine_%d.", + adapter_id, idx) + engine.print_stats(f"After unloading '{adapter_id}' in engine_{idx}") except Exception as e: - logger.info("Adapter unloading failed with error: {str(e)}") + logger.info("Adapter unloading failed with error: %s", str(e)) raise e - def listAdaptersFromTensorstore(self): + def list_adapters_from_tensorstore(self): """List all the adapters from the adapter_tensorstore of each engine.""" - logger.info(f"Listing 
loaded adapters.") + logger.info("Listing loaded adapters.") listed_adapters = {} - for idx, tensorstore in enumerate(self._generate_adapter_tensorstore): + for tensorstore in self._generate_adapterstore: listed_adapters.update(tensorstore.adapter_registry) return listed_adapters diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index b4fbf29a..2024a8b2 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -34,7 +34,7 @@ import jax from jetstream.core import config_lib from jetstream.core import orchestrator -from jetstream.core.lora import adapter_tensorstore +from jetstream.core.lora import adapter_tensorstore as adapterstore from jetstream.core.metrics.prometheus import JetstreamMetricsCollector from jetstream.core.proto import jetstream_pb2_grpc from jetstream.engine import warmup_utils, engine_api @@ -81,7 +81,8 @@ async def do_init(): multi_lora_decoding_pb2_grpc = importlib.import_module(module_name) multi_lora_decoding_pb2_grpc.add_v1Servicer_to_server( - multi_lora_inference.MultiLoraManager(driver=self._driver), self._grpc_server + multi_lora_inference.MultiLoraManager(driver=self._driver), + self._grpc_server ) self._grpc_server.add_secure_port(f"{_HOST}:{port}", credentials) @@ -120,7 +121,7 @@ def create_driver( metrics_collector: JetstreamMetricsCollector | None = None, enable_model_warmup: bool = False, multi_sampling: bool = False, - lora_input_adapters_path: str = None + lora_input_adapters_path: str | None = None ): """Creates a driver with a specified config. @@ -149,41 +150,44 @@ def create_driver( len(config.prefill_slices) + len(config.generate_slices) == 0 ) - prefill_adapter_tensorstore = [] - generate_adapter_tensorstore = [] - shared_adapter_tensorstore = [] + prefill_adapterstore = [] + generate_adapterstore = [] + shared_adapterstore = [] if lora_input_adapters_path: for pe in engines.prefill_engines: - prefill_adapter_tensorstore.append(adapter_tensorstore.AdapterTensorStore( - engine=pe, - adapters_dir_path=lora_input_adapters_path, - hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM - cpu_memory_budget=(100 * (1024 ** 3)) # 100 GB RAM - )) + prefill_adapterstore.append( + adapterstore.AdapterTensorStore( + engine=pe, + adapters_dir_path=lora_input_adapters_path, + hbm_memory_budget=20 * (1024 ** 3), # 20 GB HBM + cpu_memory_budget=100 * (1024 ** 3) # 100 GB RAM + )) for ge in engines.generate_engines: - generate_adapter_tensorstore.append(adapter_tensorstore.AdapterTensorStore( - engine=ge, - adapters_dir_path=lora_input_adapters_path, - hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM - cpu_memory_budget=(100 * (1024 ** 3)) # 100 GB RAM - )) + generate_adapterstore.append( + adapterstore.AdapterTensorStore( + engine=ge, + adapters_dir_path=lora_input_adapters_path, + hbm_memory_budget=20 * (1024 ** 3), # 20 GB HBM + cpu_memory_budget=100 * (1024 ** 3) # 100 GB RAM + )) for ie in engines.interleaved_engines: - shared_adapter_tensorstore.append(adapter_tensorstore.AdapterTensorStore( - engine=ie, - adapters_dir_path=lora_input_adapters_path, - hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM - cpu_memory_budget=(100 * (1024 ** 3)) # 100 GB RAM - )) + shared_adapterstore.append( + adapterstore.AdapterTensorStore( + engine=ie, + adapters_dir_path=lora_input_adapters_path, + hbm_memory_budget=20 * (1024 ** 3), # 20 GB HBM + cpu_memory_budget=100 * (1024 ** 3) # 100 GB RAM + )) prefill_engines = engines.prefill_engines + engines.interleaved_engines generate_engines = engines.generate_engines + 
engines.interleaved_engines prefill_params = prefill_params + shared_params generate_params = generate_params + shared_params - prefill_adapter_tensorstore += shared_adapter_tensorstore - generate_adapter_tensorstore += shared_adapter_tensorstore + prefill_adapterstore += shared_adapterstore + generate_adapterstore += shared_adapterstore if prefill_engines is None: prefill_engines = [] @@ -218,8 +222,8 @@ def create_driver( generate_engines=generate_engines, prefill_params=prefill_params, generate_params=generate_params, - prefill_adapter_tensorstore=prefill_adapter_tensorstore, - generate_adapter_tensorstore=generate_adapter_tensorstore, + prefill_adapterstore=prefill_adapterstore, + generate_adapterstore=generate_adapterstore, interleaved_mode=interleaved_mode, jax_padding=jax_padding, metrics_collector=metrics_collector, @@ -240,7 +244,7 @@ def run( jax_profiler_port: int = 9999, enable_model_warmup: bool = False, multi_sampling: bool = False, - lora_input_adapters_path: str = None, + lora_input_adapters_path: str | None = None, ) -> JetStreamServer: """Runs a server with a specified config. @@ -294,7 +298,8 @@ def run( enable_llm_inference_pool = False if lora_input_adapters_path: enable_llm_inference_pool = True - jetstream_server = JetStreamServer(driver, threads, port, credentials, enable_llm_inference_pool) + jetstream_server = JetStreamServer(driver, threads, port, + credentials, enable_llm_inference_pool) logging.info("Starting server on port %d with %d threads", port, threads) # Tweak gc config. diff --git a/jetstream/tools/multi_adapter_service_client.py b/jetstream/tools/multi_adapter_service_client.py index 547078ee..2e65ea97 100644 --- a/jetstream/tools/multi_adapter_service_client.py +++ b/jetstream/tools/multi_adapter_service_client.py @@ -19,6 +19,8 @@ from absl import app from absl import flags import grpc +from jetstream.core.proto import jetstream_pb2 +from jetstream.core.proto import jetstream_pb2_grpc from jetstream.core.proto import multi_lora_decoding_pb2 from jetstream.core.proto import multi_lora_decoding_pb2_grpc from jetstream.engine.token_utils import load_vocab @@ -32,15 +34,15 @@ "max_tokens", 3, "Maximum number of output/decode tokens of a sequence" ) -_ADAPTER_ID = flags.DEFINE_string( - "adapter_id", +_LORA_ADAPTER_ID = flags.DEFINE_string( + "lora_adapter_id", None, "Id of the fine-tuned adapter to be loaded on top of the base model.", required=False, ) -_ADAPTER_PATH = flags.DEFINE_string( - "adapter_path", +_LORA_ADAPTER_PATH = flags.DEFINE_string( + "lora_adapter_path", None, "Path of the fine-tuned adapter to be loaded from.", required=False, @@ -88,8 +90,8 @@ def main(argv: Sequence[str]) -> None: if _TEST_API_NAME.value == "load_lora_adapter": print(f"Calling the /v1/load_lora_adapter.") - adapter_id=_ADAPTER_ID.value - adapter_path=_ADAPTER_PATH.value + adapter_id=_LORA_ADAPTER_ID.value + adapter_path=_LORA_ADAPTER_PATH.value if adapter_id == None or adapter_path == None: print(f"For `load_lora_adapter` API call, `adapter_id` and `adapter_path` must be passed.") @@ -110,7 +112,7 @@ def main(argv: Sequence[str]) -> None: elif _TEST_API_NAME.value == "unload_lora_adapter": print(f"Calling the /v1/unload_lora_adapter.") - adapter_id=_ADAPTER_ID.value + adapter_id=_LORA_ADAPTER_ID.value if adapter_id == None: print(f"For `unload_lora_adapter` API call, `adapter_id` must be passed.") @@ -149,15 +151,16 @@ def main(argv: Sequence[str]) -> None: elif _TEST_API_NAME.value == "completions": print(f"Calling the /v1/completions.") - request = 
multi_lora_decoding_pb2.CompletionRequest( - text_content=multi_lora_decoding_pb2.CompletionRequest.TextContent( + request = jetstream_pb2.DecodeRequest( + text_content=jetstream_pb2.DecodeRequest.TextContent( text=_TEXT.value, ), max_tokens=_MAX_TOKENS.value, - adapter_id=_ADAPTER_ID.value, + lora_adapter_id=_LORA_ADAPTER_ID.value, ) + stub = jetstream_pb2_grpc.OrchestratorStub(channel) - response = stub.completions(request) + response = stub.Decode(request) output = [] for resp in response: diff --git a/jetstream/tools/multi_lora_decode_requester.py b/jetstream/tools/multi_lora_decode_requester.py index 6e859204..f69b2003 100644 --- a/jetstream/tools/multi_lora_decode_requester.py +++ b/jetstream/tools/multi_lora_decode_requester.py @@ -28,8 +28,8 @@ import grpc -from jetstream.core.proto import multi_lora_decoding_pb2 -from jetstream.core.proto import multi_lora_decoding_pb2_grpc +from jetstream.core.proto import jetstream_pb2 +from jetstream.core.proto import jetstream_pb2_grpc from jetstream.engine.token_utils import load_vocab from jetstream.external_tokenizers.llama3 import llama3_tokenizer import numpy as np @@ -41,6 +41,7 @@ class InputRequest: output: str = "" output_len: int = 0 sample_idx: int = -1 + adapter_id: str = "" @dataclass @@ -90,12 +91,12 @@ async def grpc_async_request( """Send grpc synchronous request since the current grpc server is sync.""" options = [("grpc.keepalive_timeout_ms", 10000)] async with grpc.aio.insecure_channel(api_url, options=options) as channel: - stub = multi_lora_decoding_pb2_grpc.v1Stub(channel) + stub = jetstream_pb2_grpc.OrchestratorStub(channel) print("Making request") ttft = 0 token_list = [] request_start_time = time.perf_counter() - response = stub.completions(request) + response = stub.Decode(request) async for resp in response: if ttft == 0: ttft = time.perf_counter() - request_start_time @@ -112,12 +113,12 @@ async def send_request( """Send the request to JetStream server.""" # Tokenization on client side following MLPerf standard. token_ids = tokenizer.encode(input_request.prompt) - request = multi_lora_decoding_pb2.CompletionRequest( - token_content=multi_lora_decoding_pb2.CompletionRequest.TokenContent( + request = jetstream_pb2.DecodeRequest( + token_content=jetstream_pb2.DecodeRequest.TokenContent( token_ids=token_ids ), max_tokens=input_request.output_len, - adapter_id=input_request.adapter_id, + lora_adapter_id=input_request.adapter_id, ) output = RequestFuncOutput() output.input_request = input_request From fe1511c79eb0d49ec62aba22a7f39a86b633329c Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 13 Mar 2025 23:27:37 +0000 Subject: [PATCH 18/22] Changes to resolve comments on the PR. 
--- jetstream/core/lora/adapter_tensorstore.py | 63 ++++++++++++---------- jetstream/core/orchestrator.py | 6 +-- jetstream/core/server_lib.py | 5 -- 3 files changed, 38 insertions(+), 36 deletions(-) diff --git a/jetstream/core/lora/adapter_tensorstore.py b/jetstream/core/lora/adapter_tensorstore.py index 13702ea5..224ff0e3 100644 --- a/jetstream/core/lora/adapter_tensorstore.py +++ b/jetstream/core/lora/adapter_tensorstore.py @@ -27,6 +27,7 @@ from typing import Dict, Optional, Any import numpy as np from jetstream.engine import engine_api +from enum import Enum def _get_size_of_pytree(params): @@ -54,12 +55,18 @@ def convert_if_np(leaf): return jax.tree_util.tree_map(convert_if_np, params) +class AdapterStatus(str, Enum): + UNLOADED = "unloaded" + LOADING = "loading" + LOADED_HBM = "loaded_hbm" + LOADED_CPU = "loaded_cpu" + @dataclasses.dataclass class AdapterMetadata: adapter_id: str adapter_path: str - status: str = "unloaded" # "loaded_hbm", "loaded_cpu", "loading", "unloading" + status: AdapterStatus = AdapterStatus.UNLOADED size_hbm: int = 0 # Size in HBM (bytes) size_cpu: int = 0 # Size in CPU RAM (bytes) last_accessed: float = 0.0 # timestamp @@ -155,7 +162,7 @@ async def _transfer_to_hbm(self, adapter_id: str): async with self.lock: #Acquire lock metadata = self.adapter_registry[adapter_id] - if metadata.status == "loaded_hbm": + if metadata.status == AdapterStatus.LOADED_HBM: return # Check if we have enough space in HBM; evict if necessary @@ -172,7 +179,7 @@ async def _transfer_to_hbm(self, adapter_id: str): self.current_cpu_usage -= metadata.size_cpu self.current_hbm_usage += metadata.size_hbm - metadata.status = "loaded_hbm" + metadata.status = AdapterStatus.LOADED_HBM metadata.last_accessed = time.time() @@ -185,7 +192,7 @@ async def _transfer_to_cpu(self, adapter_id: str): async with self.lock: metadata = self. adapter_registry[adapter_id] - if metadata.status == "loaded_cpu": + if metadata.status == AdapterStatus.LOADED_CPU: return # Check if we have enough space in CPU; evict if necessary. @@ -200,7 +207,7 @@ async def _transfer_to_cpu(self, adapter_id: str): self.current_hbm_usage -= metadata.size_hbm self.current_cpu_usage += metadata.size_cpu - metadata.status = "loaded_cpu" + metadata.status = AdapterStatus.LOADED_CPU metadata.last_accessed = time.time() @@ -211,7 +218,7 @@ async def get_hbm_loaded_adapters(self): async with self.lock: for adapter_id, metadata in self.adapter_registry.items(): - if metadata.status == "loaded_hbm": + if metadata.status == AdapterStatus.LOADED_HBM: hbm_loaded_adapters.append(adapter_id) return ", ".join(hbm_loaded_adapters) @@ -250,33 +257,33 @@ async def load_adapter( metadata = self.adapter_registry[adapter_id] async with self.lock: # Acquire lock for thread safety - if metadata.status in ("loaded_hbm", "loaded_cpu"): + if metadata.status in (AdapterStatus.LOADED_HBM, AdapterStatus.LOADED_CPU): metadata.last_accessed = time.time() # if already loaded in HBM and we want HBM, or # already loaded in CPU and we want CPU, we're done. 
- if ((to_hbm and metadata.status == "loaded_hbm") or - not to_hbm and metadata.status == "loaded_cpu"): + if ((to_hbm and metadata.status == AdapterStatus.LOADED_HBM) or + not to_hbm and metadata.status == AdapterStatus.LOADED_CPU): return - elif to_hbm and metadata.status == "loaded_cpu": + elif to_hbm and metadata.status == AdapterStatus.LOADED_CPU: # Transfer from cpu to hbm self._transfer_to_hbm(adapter_id) return - elif not to_hbm and metadata.status == "loaded_hbm": + elif not to_hbm and metadata.status == AdapterStatus.LOADED_HBM: # Transfer from hbm to cpu self._transfer_to_cpu(adapter_id) return - if metadata.status == "loading": + if metadata.status == AdapterStatus.LOADING: # Wait untill loading is done. - while metadata.status == "loading": + while metadata.status == AdapterStatus.LOADING: await asyncio.sleep(0.1) # Short sleep to avoid busy-waiting # Make recursive call to load_adapter to copy to device await self.load_adapter(adapter_id, adapter_weights, to_hbm) return - metadata.status = "loading" + metadata.status = AdapterStatus.LOADING self.running_requests += 1 # Load the adapter (asynchronous) @@ -319,18 +326,18 @@ async def load_adapter( if to_hbm: self.loaded_adapters_hbm[adapter_id] = adapter_weights_as_jnp_array # Convert the PyTree to Jax Array self.current_hbm_usage += adapter_size_hbm - metadata.status = "loaded_hbm" + metadata.status = AdapterStatus.LOADED_HBM else: #to cpu self.loaded_adapters_cpu[adapter_id] = adapter_weights_as_np_array # Convert the PyTree to NumPy Array self.current_cpu_usage += adapter_size_cpu - metadata.status = "loaded_cpu" + metadata.status = AdapterStatus.LOADED_CPU metadata.last_accessed = time.time() except Exception as e: async with self.lock: - metadata.status = "unloaded" # Mark as unloaded on error + metadata.status = AdapterStatus.UNLOADED # Mark as unloaded on error raise e # Re-Raise the exception finally: async with self.lock: @@ -368,11 +375,11 @@ def get_lora_weights(self, if metadata is None: raise ValueError(f"LoRA adapter with id={adapter_id} is not loaded.") - if metadata.status != "loaded_hbm" and metadata.status != "loaded_cpu": + if metadata.status != AdapterStatus.LOADED_HBM and metadata.status != AdapterStatus.LOADED_CPU: asyncio.run(self.load_adapter(adapter_id, None, to_hbm)) # Start loading (async) - elif to_hbm and metadata.status == "loaded_cpu": + elif to_hbm and metadata.status == AdapterStatus.LOADED_CPU: asyncio.run(self._transfer_to_hbm(adapter_id)) - elif not to_hbm and metadata.status == "loaded_hbm": + elif not to_hbm and metadata.status == AdapterStatus.LOADED_HBM: asyncio.run(self._transfer_to_cpu(adapter_id)) # Wait till all the running requests are completed @@ -397,21 +404,21 @@ async def unload_adapter(self, adapter_id: str): metadata = self.adapter_registry[adapter_id] async with self.lock: - if metadata.status == "unloaded": + if metadata.status == AdapterStatus.UNLOADED: return # Already unloaded - if metadata.status == "loading": + if metadata.status == AdapterStatus.LOADING: # Wait for the loading to get complete. 
- while metadata.status == "loading": + while metadata.status == AdapterStatus.LOADING: await asyncio.sleep(0.1) - if metadata.status == "loaded_hbm": + if metadata.status == AdapterStatus.LOADED_HBM: del self.loaded_adapters_hbm[adapter_id] self.current_hbm_usage -= metadata.size_hbm - metadata.status = "unloaded" - elif metadata.status == "loaded_cpu": + metadata.status = AdapterStatus.UNLOADED + elif metadata.status == AdapterStatus.LOADED_CPU: del self.loaded_adapters_cpu[adapter_id] self.current_cpu_usage -= metadata.size_cpu - metadata.status = "unloaded" + metadata.status = AdapterStatus.UNLOADED metadata.last_accessed = time.time() # Unload time metadata.size_hbm = 0 @@ -431,7 +438,7 @@ def _evict(self, from_hbm: bool = True) -> bool: lru_time = float('inf') for adapter_id, metadata in self.adapter_registry.items(): - if metadata.status == "loaded_hbm" if from_hbm else metadata.status == "loaded_cpu": + if metadata.status == AdapterStatus.LOADED_HBM if from_hbm else metadata.status == AdapterStatus.LOADED_CPU: if metadata.last_accessed < lru_time: lru_time = metadata.last_accessed lru_adapter_id = adapter_id diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 614d74e6..bed6c23e 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -1370,8 +1370,8 @@ def _detokenize_thread(self, idx: int): def load_adapter_to_tensorstore( self, - adapter_id, - adapter_path): + adapter_id: str, + adapter_path: str): """Load the adapter to adapter_tensorstore for each engine.""" logger.info("Loading adapter_id=%s from %s.", adapter_id, adapter_path) @@ -1429,7 +1429,7 @@ def load_adapter_to_tensorstore( def unload_adapter_from_tensorstore( self, - adapter_id): + adapter_id: str): """Unload the adapter from adapter_tensorstore of each engine.""" logger.info("Unloading adapter_id=%s", adapter_id) diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 4b0ca6c9..2024a8b2 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -266,11 +266,6 @@ def run( Returns: JetStreamServer that wraps the grpc server and orchestrator driver. """ - # TODO: Deleting the lora_input_adapters_path for now. - # Planning to use it in next big PR. Currently accomodating it - # to fix the params mismatch between maxText and JetStream - del lora_input_adapters_path - server_start_time = time.time() logging.info("Kicking off gRPC server.") # Setup Prometheus server From f0da2b97c621158762e5a733e41295e2b24b2fbc Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 13 Mar 2025 23:39:12 +0000 Subject: [PATCH 19/22] Fixing Unit test errors. 
--- jetstream/core/orchestrator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index bed6c23e..7873d475 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -553,7 +553,7 @@ def _export_lora_request_info(self): if self._metrics_collector: for idx, engine in enumerate(self._generate_engines): max_loras += engine.max_concurrent_decodes - if idx < len(self._generate_adapterstore): + if self._generate_adapterstore and idx < len(self._generate_adapterstore): adapters_list_str += asyncio.run( self._generate_adapterstore[idx].get_hbm_loaded_adapters()) @@ -643,7 +643,7 @@ def _prefill_thread(self, idx: int): prefill_engine = self._prefill_engines[idx] prefill_params = self._prefill_params[idx] adapter_tensorstore = None - if idx < len(self._prefill_adapterstore): + if self._prefill_adapterstore and idx < len(self._prefill_adapterstore): adapter_tensorstore = self._prefill_adapterstore[idx] metadata = prefill_engine.get_tokenizer() tokenizer = prefill_engine.build_tokenizer(metadata) From 453db816e20fa2592b302a0a79cdac7b7bb2d760 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 15 Mar 2025 22:20:53 +0000 Subject: [PATCH 20/22] Fixing some failures due to merge conflicts. --- jetstream/core/lora/adapter_tensorstore.py | 5 ++-- jetstream/core/orchestrator.py | 31 +++++++++++++--------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/jetstream/core/lora/adapter_tensorstore.py b/jetstream/core/lora/adapter_tensorstore.py index 276ca2e5..4cd1f82d 100644 --- a/jetstream/core/lora/adapter_tensorstore.py +++ b/jetstream/core/lora/adapter_tensorstore.py @@ -87,11 +87,12 @@ class AdapterTensorStore: model to server multiple different LoRA adapters in a single batch. Args: + engine: Engine corresponding to the adapter tensorstore + adapters_dir_path: GCS path storing all the adapters hbm_memory_budget (int): The maximum amount of HBM (in bytes) to use for storing LoRA adapter weights. cpu_memory_budget (int): The maximum amount of CPU RAM (in bytes) to use for storing LoRA adapter weights. - total_slots: Number of generate slots. This is also equals to max_concurrent_decodes. """ def __init__(self, @@ -111,7 +112,7 @@ def __init__(self, self.current_cpu_usage: int = 0 self.running_requests: int = 0 # Number of async tasks which are in "loading" state self.decoding_adapters_cache: Dict[str, Any] = {} - self.total_slots = total_slots + self.total_slots = engine.max_concurrent_decodes self.lock = asyncio.Lock() # Use an asyncio Lock for thread safety diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 61bbfc34..58c4741d 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -304,15 +304,6 @@ def __init__( self._metrics_collector = metrics_collector self._multi_sampling = multi_sampling - total_slots = 0 - for engine in self._generate_engines: - total_slots += engine.max_concurrent_decodes - - self._adapter_tensorstore = adapter_tensorstore.AdapterTensorStore( - hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM - cpu_memory_budget=(100 * (1024 ** 3)), # 100 GB RAM - total_slots=total_slots) - # Stages 1-4 represent the life cycle of a request. # Stage 1 # At first, a request is placed here in order to get prefilled. 
@@ -562,7 +553,8 @@ def _export_lora_request_info(self): if self._metrics_collector: for idx, engine in enumerate(self._generate_engines): max_loras += engine.max_concurrent_decodes - if self._generate_adapterstore and idx < len(self._generate_adapterstore): + if (self._generate_adapterstore and + idx < len(self._generate_adapterstore)): adapters_list_str += asyncio.run( self._generate_adapterstore[idx].get_hbm_loaded_adapters()) @@ -908,6 +900,10 @@ def _insert_if_possible( # Check if there are any free my_slots. We don't want to block here since # we can still generate if we can't insert. We do this in a while loop to # insert as many sequences as possible. + adapter_tensorstore = None + if self._generate_adapterstore: + adapter_tensorstore = self._generate_adapterstore[idx] + while True: my_slots_size = my_slots.qsize() @@ -979,7 +975,9 @@ def _insert_if_possible( #request_id=new_request.request_id, ) - self._adapter_tensorstore.insert_adapter_in_cache(new_request.adapter_id, slot) + if adapter_tensorstore: + adapter_tensorstore.insert_adapter_in_cache( + new_request.adapter_id, slot) ThreadDebugLog( thread_name, @@ -1120,6 +1118,10 @@ def _generate_thread(self, idx: int): my_generate_backlog = self._generate_backlogs[idx] my_detokenize_backlog = self._detokenize_backlogs[idx] + adapter_tensorstore = None + if self._generate_adapterstore and idx < len(self._generate_adapterstore): + adapter_tensorstore = self._generate_adapterstore[idx] + # Keep track of what step tokens were generated at. generate_timestep = 0 # State to store things like running kv cache in. @@ -1185,9 +1187,14 @@ def _generate_thread(self, idx: int): my_slots.qsize() < max_concurrent_decodes ), "At this point we must have some requests inserted into the slots." + decoding_adapters_cache = None + + if adapter_tensorstore: + decoding_adapters_cache = adapter_tensorstore.decoding_adapters_cache + # Now we actually take a generate step on requests in the slots. decode_state, sampled_tokens = generate_engine.generate( - generate_params, decode_state, self._adapter_tensorstore.decoding_adapters_cache, + generate_params, decode_state, decoding_adapters_cache ) sampled_tokens.copy_to_host_async() # Respond to detokenization backpressure. From a38b686a151f11c19257b5ef799f75060754eddc Mon Sep 17 00:00:00 2001 From: aman2930 Date: Sat, 3 May 2025 00:39:21 +0000 Subject: [PATCH 21/22] Fixing some missed changes in last merge with main. 
--- jetstream/core/orchestrator.py | 103 --------------------- jetstream/core/proto/jetstream_pb2_grpc.py | 37 ++++---- 2 files changed, 19 insertions(+), 121 deletions(-) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 46c8f803..0611e54f 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -1611,109 +1611,6 @@ def list_adapters_from_tensorstore(self): return listed_adapters - def load_adapter_to_tensorstore( - self, - adapter_id: str, - adapter_path: str): - """Load the adapter to adapter_tensorstore for each engine.""" - logger.info("Loading adapter_id=%s from %s.", - adapter_id, adapter_path) - - for idx, tensorstore in enumerate(self._prefill_adapterstore): - try: - engine = self._prefill_engines[idx] - adapter_params, adapter_config = engine.load_single_adapter( - adapter_path) - - if not adapter_params or not adapter_config: - raise ValueError( - f"Failed to load adapter={adapter_id} from {adapter_path}.") - - tensorstore.register_adapter( - adapter_id, - adapter_path, - adapter_config) - - asyncio.run(tensorstore.load_adapter(adapter_id, adapter_params, True)) - - logger.info("Successfully loaded '%s' in engine_%d.", - adapter_id, idx) - engine.print_stats(f"After loading '{adapter_id}' in engine_{idx}") - - except Exception as e: - logger.info("Adapter loading failed with error: %s", str(e)) - raise e - - for idx, tensorstore in enumerate(self._generate_adapterstore): - try: - engine = self._generate_engines[idx] - adapter_params, adapter_config = engine.load_single_adapter( - adapter_path) - - if not adapter_params or not adapter_config: - raise ValueError( - f"Failed to load adapter={adapter_id} from {adapter_path}.") - - tensorstore.register_adapter( - adapter_id, - adapter_path, - adapter_config) - - asyncio.run(tensorstore.load_adapter(adapter_id, adapter_params, True)) - - logger.info("Successfully loaded '%s' in engine_%d.", - adapter_id, idx) - engine.print_stats(f"After loading '{adapter_id}' in engine_{idx}") - - except Exception as e: - logger.info("Adapter loading failed with error: %s", str(e)) - raise e - - - def unload_adapter_from_tensorstore( - self, - adapter_id: str): - """Unload the adapter from adapter_tensorstore of each engine.""" - logger.info("Unloading adapter_id=%s", adapter_id) - - for idx, tensorstore in enumerate(self._prefill_adapterstore): - try: - engine = self._prefill_engines[idx] - asyncio.run(tensorstore.unload_adapter(adapter_id)) - - logger.info("Successfully unloaded '%s' in engine_%d.", - adapter_id, idx) - engine.print_stats(f"After unloading '{adapter_id}' in engine_{idx}") - - except Exception as e: - logger.info("Adapter unloading failed with error: %s", str(e)) - raise e - - for idx, tensorstore in enumerate(self._generate_adapterstore): - try: - engine = self._generate_engines[idx] - asyncio.run(tensorstore.unload_adapter(adapter_id)) - - logger.info("Successfully unloaded '%s' in engine_%d.", - adapter_id, idx) - engine.print_stats(f"After unloading '{adapter_id}' in engine_{idx}") - - except Exception as e: - logger.info("Adapter unloading failed with error: %s", str(e)) - raise e - - - def list_adapters_from_tensorstore(self): - """List all the adapters from the adapter_tensorstore of each engine.""" - logger.info("Listing loaded adapters.") - - listed_adapters = {} - for tensorstore in self._generate_adapterstore: - listed_adapters.update(tensorstore.adapter_registry) - - return listed_adapters - - class LLMOrchestrator(jetstream_pb2_grpc.OrchestratorServicer): 
"""Coordinates a set of prefill and generate slices for LLM decoding.""" diff --git a/jetstream/core/proto/jetstream_pb2_grpc.py b/jetstream/core/proto/jetstream_pb2_grpc.py index 8b0456d6..3302ab76 100644 --- a/jetstream/core/proto/jetstream_pb2_grpc.py +++ b/jetstream/core/proto/jetstream_pb2_grpc.py @@ -19,8 +19,13 @@ class OrchestratorStub(object): - """TODO: Merge this with main JetStream core once we settle on an API. + """TODO: Merge this with main JetStream core once we settle on an API.""" + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. """ self.Decode = channel.unary_stream( "/jetstream_proto.Orchestrator/Decode", @@ -35,23 +40,19 @@ class OrchestratorStub(object): class OrchestratorServicer(object): - """TODO: Merge this with main JetStream core once we settle on an API. - - """ + """TODO: Merge this with main JetStream core once we settle on an API.""" - def Decode(self, request, context): - """Query LLM to generate text or tokens. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + def Decode(self, request, context): + """Query LLM to generate text or tokens.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") - def HealthCheck(self, request, context): - """Checks if the model server is live. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + def HealthCheck(self, request, context): + """Checks if the model server is live.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def add_OrchestratorServicer_to_server(servicer, server): @@ -73,9 +74,9 @@ def add_OrchestratorServicer_to_server(servicer, server): server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. +# This class is part of an EXPERIMENTAL API. class Orchestrator(object): - """TODO: Merge this with main JetStream core once we settle on an API. + """TODO: Merge this with main JetStream core once we settle on an API.""" @staticmethod def Decode( From 3ff6383227f33a518ee2e6136c366b7508d7e67b Mon Sep 17 00:00:00 2001 From: aman2930 Date: Fri, 9 May 2025 22:16:02 +0000 Subject: [PATCH 22/22] After merging with main, some of the code gets overridden. So adding back the code to support multi-LoRA adapters into same batch. 
--- jetstream/core/lora/adapter_tensorstore.py | 67 ++++++++++++++++++++++ jetstream/core/orchestrator.py | 6 +- jetstream/core/server_lib.py | 9 ++- jetstream/engine/engine_api.py | 1 - 4 files changed, 78 insertions(+), 5 deletions(-) diff --git a/jetstream/core/lora/adapter_tensorstore.py b/jetstream/core/lora/adapter_tensorstore.py index 1d303c1f..40dcf362 100644 --- a/jetstream/core/lora/adapter_tensorstore.py +++ b/jetstream/core/lora/adapter_tensorstore.py @@ -99,6 +99,7 @@ def __init__( adapters_dir_path: str, hbm_memory_budget: int, cpu_memory_budget: int, + total_slots: int, ): """Initializes the AdapterTensorStore.""" self.engine = engine # Possibly MaxEngine object @@ -119,6 +120,8 @@ def __init__( self.running_requests: int = ( 0 # Number of async tasks which are in "loading" state ) + self.decoding_adapters_cache: Dict[str, Any] = {} + self.total_slots = total_slots self.lock = asyncio.Lock() # Use an asyncio Lock for thread safety # --- Unsafe Internal methods which assumes that lock is held --- @@ -207,6 +210,70 @@ def _unsafe_unload_adapter(self, adapter_id: str): metadata.size_hbm = 0 metadata.size_cpu = 0 + def _initialize_decoding_adapters_cache(self, adapter_weights): + """ + Create a new PyTree with zero tensors at the paths corresponding to non-None leaves + in the input PyTree. The zero tensors have an added dimension of size `self.totol_slots`. + Args: + adatper_weights: The input PyTree, whose structure will be mirrored. + Returns: + A new PyTree with zero Tensors or None values, mirroring the structure of the input PyTree. + """ + def create_zero_leaf(leaf): + if leaf is not None: + original_shape = leaf.shape + if not original_shape: # handle scalar case + zero_tensor_shape = (self.total_slots,) + else: + zero_tensor_shape = (self.total_slots,) + original_shape # Prepend a new dimension + + return jnp.zeros(zero_tensor_shape, dtype=leaf.dtype) + else: + return None # Maintain None structure for None leaves + + return jax.tree_util.tree_map(create_zero_leaf, adapter_weights) + + + def insert_adapter_in_cache(self, adapter_id: str, slot_id: int): + """ + Insert the specific adapter tensors into a slot in the serving_adapters_cache. + Args: + adapter_id: The id of the adapter, whose tensors will be inserted + slot_id: The id of slot, which represents the index in the serving_adapter_cache + where the adapter tensors will be inserted. + """ + + def insert_leaf(dest_leaf, source_leaf): + if dest_leaf is not None and source_leaf is not None: + return dest_leaf.at[slot_id].set(source_leaf) # Insert at the specific index + elif dest_leaf is not None: + return dest_leaf # If source_leaf is None, keep the zero_leaf as is + elif source_leaf is not None: # In this case the adapters have different target modules + original_shape = source_leaf.shape + if not original_shape: # Handle scalar case + zero_tensor_shape = (self.total_slots,) + else: + zero_tensor_shape = (self.total_slots,) + original_shape + new_dest_leaf = jnp.zeros(zero_tensor_shape, dtype=source_leaf.dtype) + return new_dest_leaf.at[slot_id].set(source_leaf) + else: + return None # If both are None, return None + + if adapter_id == "": + logging.info("Empty adapter id. 
So no LoRA tensors inserted into the cache in adapter_tensorStore.") + return + + asyncio.run(self.load_adapter(adapter_id, None, True)) + + adapter_weights = self.loaded_adapters_hbm[adapter_id] + + if not self.decoding_adapters_cache: + self.decoding_adapters_cache = self._initialize_decoding_adapters_cache(adapter_weights) + + self.decoding_adapters_cache = jax.tree_util.tree_map(insert_leaf, + self.decoding_adapters_cache, + adapter_weights) + # --- Public Methods (Acquire lock, then call unsafe methods) --- async def register_adapter( diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 0611e54f..0be42b0d 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -1031,7 +1031,7 @@ def _insert_if_possible( # we can still generate if we can't insert. We do this in a while loop to # insert as many sequences as possible. adapter_tensorstore = None - if self._generate_adapterstore: + if self._generate_adapterstore and idx < len(self._generate_adapterstore): adapter_tensorstore = self._generate_adapterstore[idx] while True: @@ -1102,7 +1102,6 @@ def _insert_if_possible( new_request.prefill_result, decode_state, slot=slot, - # request_id=new_request.request_id, ) if adapter_tensorstore: @@ -1321,10 +1320,11 @@ def _generate_thread(self, idx: int): if adapter_tensorstore: decoding_adapters_cache = adapter_tensorstore.decoding_adapters_cache + #decode_state["lora_adapter_cache"] = decoding_adapters_cache # Now we actually take a generate step on requests in the slots. decode_state, sampled_tokens = generate_engine.generate( - generate_params, decode_state, decoding_adapters_cache + generate_params, decode_state ) sampled_tokens.copy_to_host_async() # Respond to detokenization backpressure. diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 733fbba9..28e7c16a 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -174,6 +174,7 @@ def create_driver( shared_adapterstore = [] if lora_input_adapters_path: + # TODO: Make hbm_memory_budget and cpu_memory_budget configurable for pe in engines.prefill_engines: prefill_adapterstore.append( adapterstore.AdapterTensorStore( @@ -181,9 +182,10 @@ def create_driver( adapters_dir_path=lora_input_adapters_path, hbm_memory_budget=20 * (1024**3), # 20 GB HBM cpu_memory_budget=100 * (1024**3), # 100 GB RAM + total_slots=pe.max_concurrent_decodes, ) ) - # TODO: Make hbm_memory_budget and cpu_memory_budget configurable + for ge in engines.generate_engines: generate_adapterstore.append( adapterstore.AdapterTensorStore( @@ -191,6 +193,7 @@ def create_driver( adapters_dir_path=lora_input_adapters_path, hbm_memory_budget=20 * (1024**3), # 20 GB HBM cpu_memory_budget=100 * (1024**3), # 100 GB RAM + total_slots=ge.max_concurrent_decodes, ) ) @@ -201,6 +204,7 @@ def create_driver( adapters_dir_path=lora_input_adapters_path, hbm_memory_budget=20 * (1024**3), # 20 GB HBM cpu_memory_budget=100 * (1024**3), # 100 GB RAM + total_slots=ie.max_concurrent_decodes, ) ) @@ -315,6 +319,9 @@ def run( "Not starting Prometheus server: --prometheus_port flag not set" ) + if multi_sampling and lora_input_adapters_path: + raise ValueError("LoRA adapters is not enabled for multi_sampling mode.") + driver = create_driver( config, devices, diff --git a/jetstream/engine/engine_api.py b/jetstream/engine/engine_api.py index 22d8f137..9279e216 100644 --- a/jetstream/engine/engine_api.py +++ b/jetstream/engine/engine_api.py @@ -211,7 +211,6 @@ def generate( params: Params, decode_state: 
DecodeState, sampler: Optional[Callable[[Any], Any]] = None, - lora_params: Params = None, ) -> Tuple[DecodeState, ResultTokens]: """Generates tokens for each sequence being decoded in parallel.