
Commit 04e1af9

drbh and datavistics authored
Enable multiple LoRa adapters (#2010)
* feat: first draft load multiple lora
* feat: load weights within layer and refactor lora pass
* fix: refactor and reduce lora math
* feat: baseline impl single request multi lora support
* feat: prefer lorax implementation and port loading logic
* fix: prefer adapter_data and refactors
* feat: perfer loraxs custom punica kernels and add mlp loras
* fix: adjust batch for bgmv
* fix: adjust adapter_segments logic when in batch
* fix: refactor and move changes to v3 proto
* fix: pass model_id for all flash causal lms
* fix: pass model_id for all causal and seq2seq lms
* fix: add model_id to model test
* feat: add lora support to mistral and refactors
* feat: prefer model id in request
* fix: include rust code for adapter id
* feat: bump launcher and add new lora docs
* feat: support base model generation and refactors
* fix: rename doc to retry ci build
* feat: support if vlm models
* fix: add adapter_data param and avoid missing layers
* fix: add adapter_data param to phi and neox
* fix: update all models forwards to include adapter_data
* fix: add model_id to IdeficsCausalLM
* Update lora.md — Fixed a typo
* Update lora.md — Fixing spam image
* fix: add lora kernel to dockerfile, support running without kernels and refactors
* fix: avoid dockerfile conflict
* fix: refactors and adjust flash llama lora logic
* fix: skip llama test due to CI issue (temp)
* fix: skip llama test CI (temp) 2
* fix: revert skips and prefer updated ci token for tests
* fix: refactors and helpful comments
* fix: add noop in TensorParallelAdapterRowLinear too
* fix: refactor and move shard_lora_weights logic
* fix: exit early if no adapter_data

---------

Co-authored-by: Derek <[email protected]>
1 parent a2a97b0 commit 04e1af9


79 files changed: +2785 −76 lines

Dockerfile

Lines changed: 9 additions & 1 deletion
```diff
@@ -145,6 +145,13 @@ COPY server/marlin/ .
 # Build specific version of transformers
 RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
 
+# Build Lorax Punica kernels
+FROM kernel-builder as lorax-punica-builder
+WORKDIR /usr/src
+COPY server/Makefile-lorax-punica Makefile
+# Build specific version of transformers
+RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
+
 # Build Transformers CUDA kernels
 FROM kernel-builder as custom-kernels-builder
 WORKDIR /usr/src
@@ -215,6 +222,7 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86
 COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from marlin kernels builder
 COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
 # Copy builds artifacts from vllm builder
 COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
@@ -266,4 +274,4 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
 
 ENTRYPOINT ["/tgi-entrypoint.sh"]
-CMD ["--json-output"]
+# CMD ["--json-output"]
```
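
For a quick check outside the Docker build, the same kernels can likely be compiled by invoking the new Makefile directly. This is only a minimal sketch, assuming a working CUDA toolchain and reusing the arch list and `build-lorax-punica` target from the stage above:

```bash
# Hypothetical local build mirroring the lorax-punica-builder stage
# (assumes CUDA and the Python build dependencies are already installed)
cd server
TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make -f Makefile-lorax-punica build-lorax-punica
```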

benchmark/src/generation.rs

Lines changed: 1 addition & 0 deletions
```diff
@@ -157,6 +157,7 @@ async fn prefill(
                 top_n_tokens: top_n_tokens.unwrap_or(0),
                 blocks: vec![],
                 slots: vec![],
+                adapter_id: None,
             })
             .collect();
```

docs/source/_toctree.yml

Lines changed: 4 additions & 1 deletion
```diff
@@ -60,6 +60,9 @@
   - local: conceptual/speculation
     title: Speculation (Medusa, ngram)
   - local: conceptual/guidance
-    title: How Guidance Works (via outlines)
+    title: How Guidance Works (via outlines
+  - local: conceptual/lora
+    title: LoRA (Low-Rank Adaptation)
+
 
   title: Conceptual Guides
```

docs/source/basic_tutorials/launcher.md

Lines changed: 8 additions & 0 deletions
````diff
@@ -416,6 +416,14 @@ Options:
           [env: MAX_CLIENT_BATCH_SIZE=]
           [default: 4]
 
+```
+## LORA_ADAPTERS
+```shell
+      --lora-adapters <LORA_ADAPTERS>
+          Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during startup that will be available to callers via the `adapter_id` field in a request
+
+          [env: LORA_ADAPTERS=]
+
 ```
 ## HELP
 ```shell
````
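
To illustrate the new option, a launch command might look like the following sketch; the base model id is only a placeholder, and `--model-id` is the launcher's existing flag for selecting it:

```bash
# Sketch: preload two LoRA adapters at startup (base model id is a placeholder)
text-generation-launcher \
    --model-id mistralai/Mistral-7B-v0.1 \
    --lora-adapters predibase/customer_support,predibase/dbpedia

# Equivalent form using the environment variable documented above
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia \
    text-generation-launcher --model-id mistralai/Mistral-7B-v0.1
```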

docs/source/conceptual/lora.md

Lines changed: 65 additions & 0 deletions
````md
# LoRA (Low-Rank Adaptation)

## What is LoRA?

LoRA is a technique that allows for efficient fine-tuning of a model while only updating a small portion of the model's weights. This is useful when you have a large model that has been pre-trained on a large dataset, but you want to fine-tune it on a smaller dataset or for a specific task.

LoRA works by adding a small number of additional weights to the model, which are used to adapt the model to the new dataset or task. These additional weights are learned during the fine-tuning process, while the rest of the model's weights are kept fixed.

## How is it used?

LoRA can be used in many ways, and the community is always finding new ways to use it. Technically, LoRA fine-tunes a large language model on a small dataset, but the use cases span a wide range of applications, such as:

- fine-tuning a language model on a small dataset
- fine-tuning a language model on a domain-specific dataset
- fine-tuning a language model on a dataset with limited labels

## Optimizing Inference with LoRA

LoRAs can be used during inference by multiplying the adapter weights with the model weights at each specified layer. This process can be computationally expensive, but thanks to work by [punica-ai](https://github.com/punica-ai/punica) and the [lorax](https://github.com/predibase/lorax) team, optimized kernels and frameworks have been developed to make this process more efficient. TGI leverages these optimizations in order to provide fast and efficient inference with multiple LoRA models.

## Serving multiple LoRA adapters with TGI

Once a LoRA model has been trained, it can be used to generate text or perform other tasks just like a regular language model. However, because the model has been fine-tuned on a specific dataset, it may perform better on that dataset than a model that has not been fine-tuned.

In practice, it's often useful to have multiple LoRA models, each fine-tuned on a different dataset or for a different task. This allows you to use the model that is best suited for a particular task or dataset.

Text Generation Inference (TGI) now supports loading multiple LoRA models at startup that can be used in generation requests. This feature is available starting from version `~2.0.6` and is compatible with LoRA models trained using the `peft` library.

### Specifying LoRA models

To use LoRA in TGI, when starting the server, you can specify the list of LoRA models to load using the `LORA_ADAPTERS` environment variable. For example:

```bash
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
```

In the server logs, you will see the following message:

```txt
Loading adapter weights into model: predibase/customer_support
Loading adapter weights into model: predibase/dbpedia
```

## Generate text

You can then use these adapters in generation requests by specifying the `adapter_id` parameter in the request payload. For example:

```bash
curl 127.0.0.1:3000/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
        "inputs": "Hello who are you?",
        "parameters": {
            "max_new_tokens": 40,
            "adapter_id": "predibase/customer_support"
        }
    }'
```

> **Note:** The LoRA feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally, documentation and an improved client library will be published soon.

An updated tutorial with detailed examples will be published soon. Stay tuned!
````
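
Building on the curl example above, adapters are chosen per request, so a hedged sketch of switching between the two loaded adapters, or falling back to the base model by omitting `adapter_id` (the "support base model generation" change in this commit), could look like:

```bash
# Route the request to the second loaded adapter instead
curl 127.0.0.1:3000/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
        "inputs": "What is Deep Learning?",
        "parameters": { "max_new_tokens": 40, "adapter_id": "predibase/dbpedia" }
    }'

# Omit adapter_id to generate with the base model weights only
curl 127.0.0.1:3000/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
        "inputs": "What is Deep Learning?",
        "parameters": { "max_new_tokens": 40 }
    }'
```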

launcher/src/main.rs

Lines changed: 13 additions & 0 deletions
```diff
@@ -452,6 +452,11 @@ struct Args {
     /// Control the maximum number of inputs that a client can send in a single request
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
+
+    /// Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during
+    /// startup that will be available to callers via the `adapter_id` field in a request.
+    #[clap(long, env)]
+    lora_adapters: Option<String>,
 }
 
 #[derive(Debug)]
@@ -485,6 +490,7 @@ fn shard_manager(
     max_total_tokens: usize,
     max_batch_size: Option<usize>,
     max_input_tokens: usize,
+    lora_adapters: Option<String>,
     otlp_endpoint: Option<String>,
     otlp_service_name: String,
     log_level: LevelFilter,
@@ -620,6 +626,11 @@
         envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
     }
 
+    // Lora Adapters
+    if let Some(lora_adapters) = lora_adapters {
+        envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
+    }
+
     // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
     if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@@ -1060,6 +1071,7 @@ fn spawn_shards(
         let rope_scaling = args.rope_scaling;
         let rope_factor = args.rope_factor;
         let max_batch_size = args.max_batch_size;
+        let lora_adapters = args.lora_adapters.clone();
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -1085,6 +1097,7 @@
                 max_total_tokens,
                 max_batch_size,
                 max_input_tokens,
+                lora_adapters,
                 otlp_endpoint,
                 otlp_service_name,
                 max_log_level,
```

proto/v3/generate.proto

Lines changed: 2 additions & 0 deletions
```diff
@@ -134,6 +134,8 @@ message Request {
     repeated uint32 blocks = 9;
     /// Paged attention slots
     repeated uint32 slots = 10;
+    /// LORA adapter index
+    optional string adapter_id = 11;
 }
 
 message Batch {
```

router/client/src/v3/client.rs

Lines changed: 1 addition & 0 deletions
```diff
@@ -177,6 +177,7 @@ impl Client {
                 }),
                 prefill_logprobs: true,
                 top_n_tokens: 20,
+                adapter_id: None,
             });
             n_tokens += max_input_length;
```

router/client/src/v3/sharded_client.rs

Lines changed: 1 addition & 0 deletions
```diff
@@ -244,6 +244,7 @@ impl Health for ShardedClient {
             // Block 0 is reserved for health checks
             blocks: vec![0],
             slots: (0..16).collect(),
+            adapter_id: None,
         };
         let batch = Batch {
             id: u64::MAX,
```

router/src/infer/v2/queue.rs

Lines changed: 1 addition & 0 deletions
```diff
@@ -429,6 +429,7 @@ mod tests {
                     stop_sequences: vec![],
                 },
                 top_n_tokens: 0,
+                adapter_id: None,
             },
             response_tx,
             span: info_span!("entry"),
```
