diff --git a/proto/generate.proto b/proto/generate.proto
index 366a54180a5..6351e37f2c9 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -107,8 +107,6 @@ message Request {
     bool prefill_logprobs = 6;
     /// Return most likely n tokens
     uint32 top_n_tokens = 7;
-    /// LORA adapter index
-    optional uint32 adapter_index = 8;
 }
 
 message Batch {
diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto
index 01cc43fd320..c7a8013bd4a 100644
--- a/proto/v3/generate.proto
+++ b/proto/v3/generate.proto
@@ -134,6 +134,8 @@ message Request {
     repeated uint32 blocks = 9;
     /// Paged attention slots
     repeated uint32 slots = 10;
+    /// LORA adapter index
+    optional uint32 adapter_index = 11;
 }
 
 message Batch {
diff --git a/router/client/src/v2/client.rs b/router/client/src/v2/client.rs
index ff1a70eb59b..9a2e6ac79f9 100644
--- a/router/client/src/v2/client.rs
+++ b/router/client/src/v2/client.rs
@@ -154,7 +154,6 @@ impl Client {
                 }),
                 prefill_logprobs: true,
                 top_n_tokens: 20,
-                adapter_index: None,
             });
             n_tokens += max_input_length;
 
diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs
index 9a3892fb186..5ced4056e23 100644
--- a/router/client/src/v3/client.rs
+++ b/router/client/src/v3/client.rs
@@ -177,6 +177,7 @@ impl Client {
                 }),
                 prefill_logprobs: true,
                 top_n_tokens: 20,
+                adapter_index: None,
             });
             n_tokens += max_input_length;
 
diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs
index 94002f55064..300deccaf5c 100644
--- a/router/client/src/v3/sharded_client.rs
+++ b/router/client/src/v3/sharded_client.rs
@@ -244,6 +244,7 @@ impl Health for ShardedClient {
             // Block 0 is reserved for health checks
             blocks: vec![0],
             slots: (0..16).collect(),
+            adapter_index: None,
         };
         let batch = Batch {
             id: u64::MAX,
diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs
index e284b251ce8..f02056973b9 100644
--- a/router/src/infer/v2/queue.rs
+++ b/router/src/infer/v2/queue.rs
@@ -290,7 +290,6 @@ impl State {
                     entry.request.stopping_parameters.clone(),
                 )),
                 top_n_tokens: entry.request.top_n_tokens,
-                adapter_index: entry.request.adapter_index,
             });
             // Set batch_time
             entry.batch_time = Some(Instant::now());
diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs
index 0b66142abb1..fbfdf715125 100644
--- a/router/src/infer/v3/queue.rs
+++ b/router/src/infer/v3/queue.rs
@@ -351,6 +351,7 @@ impl State {
                 top_n_tokens: entry.request.top_n_tokens,
                 blocks,
                 slots,
+                adapter_index: entry.request.adapter_index,
             });
             // Set batch_time
             entry.batch_time = Some(Instant::now());
@@ -491,6 +492,7 @@ mod tests {
                     stop_sequences: vec![],
                 },
                 top_n_tokens: 0,
+                adapter_index: None,
             },
             response_tx,
             span: info_span!("entry"),
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index f56e6f3d45d..704c8ef3b8b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -63,7 +63,7 @@
 DOWN_PROJ = "down_proj"
 
 
-def load_attention(config, prefix, weights):
+def load_attention(config, prefix, weights, layer_id):
     # Only defined in granite.
     bias = getattr(config, "attention_bias", False)
 
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 5aa2aecaad2..557656e786e 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -3,7 +3,6 @@
 from pathlib import Path
 from typing import List, Dict, Optional, Set, Tuple, Union
 from safetensors import safe_open, SafetensorError
-from safetensors.torch import load_file
 import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download