diff --git a/jetstream/core/lora/adapter_tensorstore.py b/jetstream/core/lora/adapter_tensorstore.py
index 1d303c1f..40dcf362 100644
--- a/jetstream/core/lora/adapter_tensorstore.py
+++ b/jetstream/core/lora/adapter_tensorstore.py
@@ -99,6 +99,7 @@ def __init__(
       adapters_dir_path: str,
       hbm_memory_budget: int,
       cpu_memory_budget: int,
+      total_slots: int,
   ):
     """Initializes the AdapterTensorStore."""
     self.engine = engine  # Possibly MaxEngine object
@@ -119,6 +120,8 @@ def __init__(
     self.running_requests: int = (
         0  # Number of async tasks which are in "loading" state
     )
+    self.decoding_adapters_cache: Dict[str, Any] = {}
+    self.total_slots = total_slots
     self.lock = asyncio.Lock()  # Use an asyncio Lock for thread safety

   # --- Unsafe Internal methods which assumes that lock is held ---
@@ -207,6 +210,70 @@ def _unsafe_unload_adapter(self, adapter_id: str):
     metadata.size_hbm = 0
     metadata.size_cpu = 0

+  def _initialize_decoding_adapters_cache(self, adapter_weights):
+    """
+    Create a new PyTree with zero tensors at the paths corresponding to
+    non-None leaves in the input PyTree. The zero tensors have an added
+    leading dimension of size `self.total_slots`.
+
+    Args:
+      adapter_weights: The input PyTree, whose structure will be mirrored.
+
+    Returns:
+      A new PyTree with zero tensors or None values, mirroring the structure
+      of the input PyTree.
+    """
+
+    def create_zero_leaf(leaf):
+      if leaf is not None:
+        original_shape = leaf.shape
+        if not original_shape:  # Handle scalar case
+          zero_tensor_shape = (self.total_slots,)
+        else:
+          # Prepend a new slot dimension
+          zero_tensor_shape = (self.total_slots,) + original_shape
+        return jnp.zeros(zero_tensor_shape, dtype=leaf.dtype)
+      else:
+        return None  # Maintain None structure for None leaves
+
+    return jax.tree_util.tree_map(create_zero_leaf, adapter_weights)
+
+  def insert_adapter_in_cache(self, adapter_id: str, slot_id: int):
+    """
+    Insert the specific adapter tensors into a slot in the
+    decoding_adapters_cache.
+
+    Args:
+      adapter_id: The id of the adapter whose tensors will be inserted.
+      slot_id: The id of the slot, which represents the index in the
+        decoding_adapters_cache where the adapter tensors will be inserted.
+    """
+
+    def insert_leaf(dest_leaf, source_leaf):
+      if dest_leaf is not None and source_leaf is not None:
+        # Insert at the specific index
+        return dest_leaf.at[slot_id].set(source_leaf)
+      elif dest_leaf is not None:
+        return dest_leaf  # If source_leaf is None, keep the zero leaf as is
+      elif source_leaf is not None:
+        # In this case the adapters have different target modules
+        original_shape = source_leaf.shape
+        if not original_shape:  # Handle scalar case
+          zero_tensor_shape = (self.total_slots,)
+        else:
+          zero_tensor_shape = (self.total_slots,) + original_shape
+        new_dest_leaf = jnp.zeros(zero_tensor_shape, dtype=source_leaf.dtype)
+        return new_dest_leaf.at[slot_id].set(source_leaf)
+      else:
+        return None  # If both are None, return None
+
+    if adapter_id == "":
+      logging.info(
+          "Empty adapter id. So no LoRA tensors inserted into the cache in"
+          " AdapterTensorStore."
+      )
+      return
+
+    asyncio.run(self.load_adapter(adapter_id, None, True))
+
+    adapter_weights = self.loaded_adapters_hbm[adapter_id]
+
+    if not self.decoding_adapters_cache:
+      self.decoding_adapters_cache = self._initialize_decoding_adapters_cache(
+          adapter_weights
+      )
+
+    self.decoding_adapters_cache = jax.tree_util.tree_map(
+        insert_leaf, self.decoding_adapters_cache, adapter_weights
+    )
+
   # --- Public Methods (Acquire lock, then call unsafe methods) ---

   async def register_adapter(
diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py
index ce5d7a4d..0be42b0d 100644
--- a/jetstream/core/orchestrator.py
+++ b/jetstream/core/orchestrator.py
@@ -1030,6 +1030,10 @@ def _insert_if_possible(
     # Check if there are any free my_slots. We don't want to block here since
     # we can still generate if we can't insert. We do this in a while loop to
     # insert as many sequences as possible.
+    adapter_tensorstore = None
+    if self._generate_adapterstore and idx < len(self._generate_adapterstore):
+      adapter_tensorstore = self._generate_adapterstore[idx]
+
     while True:
       my_slots_size = my_slots.qsize()
@@ -1098,8 +1102,12 @@
           new_request.prefill_result,
           decode_state,
           slot=slot,
-          # request_id=new_request.request_id,
       )
+
+      if adapter_tensorstore:
+        adapter_tensorstore.insert_adapter_in_cache(
+            new_request.adapter_id, slot)
+
       ThreadDebugLog(
           thread_name,
           f"Generate slice {idx} filled slot {slot} at step "
@@ -1239,6 +1247,10 @@ def _generate_thread(self, idx: int):
     my_generate_backlog = self._generate_backlogs[idx]
     my_detokenize_backlog = self._detokenize_backlogs[idx]
+    adapter_tensorstore = None
+    if self._generate_adapterstore and idx < len(self._generate_adapterstore):
+      adapter_tensorstore = self._generate_adapterstore[idx]
+
     # Keep track of what step tokens were generated at.
     generate_timestep = 0
     # State to store things like running kv cache in.
@@ -1304,6 +1316,12 @@
       assert (
           my_slots.qsize() < max_concurrent_decodes
       ), "At this point we must have some requests inserted into the slots."
+      decoding_adapters_cache = None
+
+      if adapter_tensorstore:
+        decoding_adapters_cache = adapter_tensorstore.decoding_adapters_cache
+        # decode_state["lora_adapter_cache"] = decoding_adapters_cache
+
       # Now we actually take a generate step on requests in the slots.
       decode_state, sampled_tokens = generate_engine.generate(
           generate_params, decode_state
diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py
index 733fbba9..28e7c16a 100644
--- a/jetstream/core/server_lib.py
+++ b/jetstream/core/server_lib.py
@@ -174,6 +174,7 @@ def create_driver(
   shared_adapterstore = []

   if lora_input_adapters_path:
+    # TODO: Make hbm_memory_budget and cpu_memory_budget configurable
     for pe in engines.prefill_engines:
       prefill_adapterstore.append(
           adapterstore.AdapterTensorStore(
@@ -181,9 +182,10 @@
              adapters_dir_path=lora_input_adapters_path,
              hbm_memory_budget=20 * (1024**3),  # 20 GB HBM
              cpu_memory_budget=100 * (1024**3),  # 100 GB RAM
+             total_slots=pe.max_concurrent_decodes,
          )
       )
-    # TODO: Make hbm_memory_budget and cpu_memory_budget configurable
+
     for ge in engines.generate_engines:
       generate_adapterstore.append(
           adapterstore.AdapterTensorStore(
@@ -191,6 +193,7 @@
              adapters_dir_path=lora_input_adapters_path,
              hbm_memory_budget=20 * (1024**3),  # 20 GB HBM
              cpu_memory_budget=100 * (1024**3),  # 100 GB RAM
+             total_slots=ge.max_concurrent_decodes,
          )
       )
@@ -201,6 +204,7 @@
              adapters_dir_path=lora_input_adapters_path,
              hbm_memory_budget=20 * (1024**3),  # 20 GB HBM
              cpu_memory_budget=100 * (1024**3),  # 100 GB RAM
+             total_slots=ie.max_concurrent_decodes,
          )
       )
@@ -315,6 +319,9 @@ def run(
         "Not starting Prometheus server: --prometheus_port flag not set"
     )

+  if multi_sampling and lora_input_adapters_path:
+    raise ValueError("LoRA adapters are not supported in multi_sampling mode.")
+
   driver = create_driver(
       config,
       devices,
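
For reference, the slot-indexed cache pattern introduced by _initialize_decoding_adapters_cache and insert_adapter_in_cache can be sketched standalone as below. This is a simplified illustration rather than the JetStream code: TOTAL_SLOTS, the adapter PyTree, and the module names (q_proj, lora_a, lora_b) are made up for the example, and the real methods additionally handle None leaves and first load the adapter weights from the tensorstore.

import jax
import jax.numpy as jnp

TOTAL_SLOTS = 4  # stands in for an engine's max_concurrent_decodes


def init_cache(adapter_weights):
  """Mirror the adapter PyTree with zeros, prepending a slot dimension."""
  return jax.tree_util.tree_map(
      lambda leaf: jnp.zeros((TOTAL_SLOTS,) + leaf.shape, dtype=leaf.dtype),
      adapter_weights,
  )


def insert_into_cache(cache, adapter_weights, slot_id):
  """Write one adapter's leaves into row `slot_id` of the cache."""
  return jax.tree_util.tree_map(
      lambda dest, src: dest.at[slot_id].set(src), cache, adapter_weights
  )


# Hypothetical single-module LoRA adapter; real adapters are nested dicts of
# lora_a/lora_b matrices keyed by target module.
adapter = {"q_proj": {"lora_a": jnp.ones((8, 2)), "lora_b": jnp.ones((2, 8))}}

cache = init_cache(adapter)
cache = insert_into_cache(cache, adapter, slot_id=1)
print(cache["q_proj"]["lora_a"].shape)  # (4, 8, 2); slot 1 now holds the adapter

Keeping one such batched PyTree per generate engine is what lets every decode step pick up the per-slot LoRA weights without re-gathering them from HBM on each request.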