From 7ef8b89ee7f8dd5fa7fcdc8e84f8f5ded04d1fb6 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 28 Jan 2025 11:19:37 +0100 Subject: [PATCH 1/7] More logs in the allocator. --- backends/v3/src/radix.rs | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 532ec6ddcc8..223ac67be3f 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -57,13 +57,18 @@ impl RadixAllocator { // temporary, the trie needs to be able to report whether it can // allocate the requested amount. Just not implemented yet. tracing::debug!( - "Free blocks {} need {n_blocks_needed}", + "Free blocks {} need {n_blocks_needed}", self.free_blocks.len() ); - self.free_blocks.extend( - self.cache_blocks - .evict(n_blocks_needed - self.free_blocks.len()), + let free_blocks = self + .cache_blocks + .evict(n_blocks_needed - self.free_blocks.len()); + tracing::debug!( + "Freed {} blocks: Now having {} free blocks", + free_blocks.len(), + free_blocks.len() + self.free_blocks.len() ); + self.free_blocks.extend(free_blocks); } if self.free_blocks.len() >= n_blocks_needed { @@ -106,6 +111,9 @@ impl Allocator for RadixAllocator { let suffix_blocks = suffix_len.div_ceil(self.block_size); tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}"); + metrics::counter!("tgi_cache_hit", "allocator" => "radix") + .increment(prefix_len.try_into().expect("Can convert usize to u64")); + metrics::counter!("tgi_cache_total", "allocator" => "radix").increment(suffix_len.into()); match self.alloc_or_reclaim(suffix_blocks as usize) { Some(suffix_blocks) => blocks.extend(suffix_blocks), From 6a88063cc29ef2ed97120156ecae1be11ee8b425 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 28 Jan 2025 19:48:17 +0100 Subject: [PATCH 2/7] Adding Dummy kvrouter. --- Cargo.lock | 250 ++++++++++++++++++++++++++++++++++------- Cargo.toml | 4 +- kvrouter/Cargo.toml | 17 +++ kvrouter/src/lib.rs | 146 ++++++++++++++++++++++++ kvrouter/src/main.rs | 30 +++++ kvrouter/src/trie.rs | 257 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 662 insertions(+), 42 deletions(-) create mode 100644 kvrouter/Cargo.toml create mode 100644 kvrouter/src/lib.rs create mode 100644 kvrouter/src/main.rs create mode 100644 kvrouter/src/trie.rs diff --git a/Cargo.lock b/Cargo.lock index af3e19027fc..ee52a2b429e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,11 +24,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom", + "getrandom 0.2.15", "once_cell", "serde", "version_check", - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -291,7 +291,7 @@ dependencies = [ "http-body 0.4.6", "hyper 0.14.31", "itoa", - "matchit", + "matchit 0.7.3", "memchr", "mime", "percent-encoding", @@ -321,10 +321,10 @@ dependencies = [ "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.5.1", + "hyper 1.5.2", "hyper-util", "itoa", - "matchit", + "matchit 0.7.3", "memchr", "mime", "percent-encoding", @@ -336,7 +336,42 @@ dependencies = [ "serde_urlencoded", "sync_wrapper 1.0.2", "tokio", - "tower 0.5.1", + "tower 0.5.2", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" +dependencies = [ + "axum-core 0.5.0", + "axum-macros", + "bytes", + "form_urlencoded", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.5.2", + "hyper-util", + "itoa", + "matchit 0.8.4", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper 1.0.2", + "tokio", + "tower 0.5.2", "tower-layer", "tower-service", "tracing", @@ -380,6 +415,37 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-core" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" +dependencies = [ + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper 1.0.2", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "axum-tracing-opentelemetry" version = "0.16.0" @@ -1442,7 +1508,19 @@ checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets 0.52.6", ] [[package]] @@ -1596,7 +1674,7 @@ dependencies = [ "log", "native-tls", "num_cpus", - "rand", + "rand 0.8.5", "reqwest 0.11.27", "serde", "serde_json", @@ -1719,9 +1797,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.5.1" +version = "1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f" +checksum = "256fb8d4bd6413123cc9d91832d78325c48ff41677595be797d90f42969beae0" dependencies = [ "bytes", "futures-channel", @@ -1746,7 +1824,7 @@ checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.5.1", + "hyper 1.5.2", "hyper-util", "log", "rustls 0.23.17", @@ -1793,7 +1871,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.1", - "hyper 1.5.1", + "hyper 1.5.2", "pin-project-lite", "socket2", "tokio", @@ -2166,6 +2244,21 @@ dependencies = [ "uuid-simd", ] +[[package]] +name = "kvrouter" +version = "3.0.2-dev0" +dependencies = [ + "async-stream", + "axum 0.8.1", + "futures", + "futures-util", + "hyper 1.5.2", + "hyper-util", + "rand 0.9.0", + "slotmap", + "tokio", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -2318,6 +2411,12 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "maybe-rayon" version = "0.1.1" @@ -2361,7 +2460,7 @@ checksum = "b4f0c8427b39666bf970460908b213ec09b3b350f20c0c2eabcbba51704a08e6" dependencies = [ "base64 0.22.1", "http-body-util", - "hyper 1.5.1", + "hyper 1.5.2", "hyper-rustls", "hyper-util", "indexmap 2.6.0", @@ -2450,7 +2549,7 @@ dependencies = [ "hermit-abi 0.3.9", "libc", "log", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", ] @@ -2499,7 +2598,7 @@ dependencies = [ "bytes", "futures", "pin-project", - "rand", + "rand 0.8.5", "thiserror", "tokio", "tokio-util", @@ -2931,7 +3030,7 @@ dependencies = [ "opentelemetry_api", "ordered-float 3.9.2", "percent-encoding", - "rand", + "rand 0.8.5", "regex", "serde_json", "thiserror", @@ -2955,7 +3054,7 @@ dependencies = [ "opentelemetry 0.21.0", "ordered-float 4.5.0", "percent-encoding", - "rand", + "rand 0.8.5", "thiserror", ] @@ -3159,7 +3258,7 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -3392,7 +3491,7 @@ dependencies = [ "libc", "once_cell", "raw-cpuid", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "web-sys", "winapi", ] @@ -3419,8 +3518,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.0", + "zerocopy 0.8.14", ] [[package]] @@ -3430,7 +3540,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.0", ] [[package]] @@ -3439,7 +3559,17 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", +] + +[[package]] +name = "rand_core" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b08f3c9802962f7e1b25113931d94f43ed9725bebc59db9d0c3e9a23b67e15ff" +dependencies = [ + "getrandom 0.3.1", + "zerocopy 0.8.14", ] [[package]] @@ -3489,8 +3619,8 @@ dependencies = [ "once_cell", "paste", "profiling", - "rand", - "rand_chacha", + "rand 0.8.5", + "rand_chacha 0.3.1", "simd_helpers", "system-deps", "thiserror", @@ -3568,7 +3698,7 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ - "getrandom", + "getrandom 0.2.15", "libredox", "thiserror", ] @@ -3704,7 +3834,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.5.1", + "hyper 1.5.2", "hyper-util", "ipnet", "js-sys", @@ -3755,7 +3885,7 @@ checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" dependencies = [ "cc", "cfg-if", - "getrandom", + "getrandom 0.2.15", "libc", "spin 0.9.8", "untrusted 0.9.0", @@ -4531,7 +4661,7 @@ dependencies = [ "opentelemetry-otlp", "outlines-core", "pyo3", - "rand", + "rand 0.8.5", "regex", "reqwest 0.11.27", "serde", @@ -4579,7 +4709,7 @@ dependencies = [ "opentelemetry-otlp", "prost 0.12.6", "prost-build", - "rand", + "rand 0.8.5", "regex", "reqwest 0.11.27", "serde", @@ -4630,7 +4760,7 @@ dependencies = [ "opentelemetry-otlp", "prost 0.12.6", "prost-build", - "rand", + "rand 0.8.5", "regex", "reqwest 0.11.27", "serde", @@ -4764,7 +4894,7 @@ dependencies = [ "aho-corasick", "derive_builder", "esaxx-rs", - "getrandom", + "getrandom 0.2.15", "hf-hub", "indicatif", "itertools 0.12.1", @@ -4774,7 +4904,7 @@ dependencies = [ "monostate", "onig", "paste", - "rand", + "rand 0.8.5", "rayon", "rayon-cond", "regex", @@ -4844,7 +4974,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f" dependencies = [ "pin-project", - "rand", + "rand 0.8.5", "tokio", ] @@ -4997,7 +5127,7 @@ dependencies = [ "indexmap 1.9.3", "pin-project", "pin-project-lite", - "rand", + "rand 0.8.5", "slab", "tokio", "tokio-util", @@ -5008,14 +5138,14 @@ dependencies = [ [[package]] name = "tower" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", "pin-project-lite", - "sync_wrapper 0.1.2", + "sync_wrapper 1.0.2", "tokio", "tower-layer", "tower-service", @@ -5370,8 +5500,8 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ - "getrandom", - "rand", + "getrandom 0.2.15", + "rand 0.8.5", "uuid-macro-internal", ] @@ -5479,6 +5609,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasm-bindgen" version = "0.2.95" @@ -5926,6 +6065,15 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags 2.6.0", +] + [[package]] name = "write16" version = "1.0.0" @@ -5975,7 +6123,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a367f292d93d4eab890745e75a778da40909cab4d6ff8173693812f79c4a2468" +dependencies = [ + "zerocopy-derive 0.8.14", ] [[package]] @@ -5989,6 +6146,17 @@ dependencies = [ "syn 2.0.89", ] +[[package]] +name = "zerocopy-derive" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3931cb58c62c13adec22e38686b559c86a30565e16ad6e8510a337cedc611e1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "zerofrom" version = "0.1.4" diff --git a/Cargo.toml b/Cargo.toml index 9f49c9abe3f..3b6c20a1676 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ members = [ "backends/trtllm", "launcher", "router" -] +, "kvrouter"] default-members = [ "benchmark", "backends/v2", @@ -15,6 +15,7 @@ default-members = [ "backends/grpc-metadata", # "backends/trtllm", "launcher", + "kvrouter", "router" ] resolver = "2" @@ -34,6 +35,7 @@ metrics-exporter-prometheus = { version = "0.15.1", features = [] } minijinja = { version = "2.2.0", features = ["json"] } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } pyo3 = { version = "0.22.2", features = ["auto-initialize"] } +axum = { version = "0.7", features = ["json"] } [profile.release] incremental = true diff --git a/kvrouter/Cargo.toml b/kvrouter/Cargo.toml new file mode 100644 index 00000000000..2e54f6e46f7 --- /dev/null +++ b/kvrouter/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "kvrouter" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[dependencies] +async-stream = "0.3.6" +axum = { version = "0.8.1", features = ["macros"] } +futures = "0.3.31" +futures-util = "0.3.31" +hyper = { version = "1.5.2", features = ["full"] } +hyper-util = { version = "0.1.10", features = ["full"] } +rand = "0.9.0" +slotmap = "1.0.7" +tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] } diff --git a/kvrouter/src/lib.rs b/kvrouter/src/lib.rs new file mode 100644 index 00000000000..31d1d81fab6 --- /dev/null +++ b/kvrouter/src/lib.rs @@ -0,0 +1,146 @@ +use axum::{ + body::Body, + extract::{Request, State}, + http::uri::Uri, + response::{IntoResponse, Response}, +}; +use futures_util::stream::StreamExt; +use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; +use rand::{rng, Rng}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; + +mod trie; + +use crate::trie::Trie; + +const FACTOR_KEY: &str = "TGI_KVROUTER_FACTOR"; +type Client = hyper_util::client::legacy::Client; + +#[derive(Clone)] +pub struct RoundRobin { + client: Client, + trie: Arc>, + backends: Arc>, + inqueue: Arc>, + inflight: Arc>, + factor: f32, +} + +impl RoundRobin { + pub fn new(backends: Vec) -> Self { + let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) + .build(HttpConnector::new()); + let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect()); + let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect()); + let trie = Arc::new(Mutex::new(Trie::new())); + let factor: f32 = std::env::var(FACTOR_KEY) + .unwrap_or("1.5".to_string()) + .parse() + .unwrap_or(1.5); + Self { + inflight, + inqueue, + trie, + client, + factor, + backends: Arc::new(backends), + } + } + + pub fn next(&mut self, key: &[u8]) -> usize { + let mut trie = self.trie.lock().unwrap(); + let (start, stop) = trie.insert(key); + let n = trie.count(); + eprintln!( + "Start {start} Stop {stop} N {n} : Key {}", + String::from_utf8_lossy(key) + ); + let mut rng = rng(); + let x: f32 = rng.random(); + println!("Random number is {x:.2}"); + let start = (start as f32) / (n as f32); + let stop = (stop as f32) / (n as f32); + let rescaled_x = x * (stop - start) + start; + assert!(rescaled_x >= start); + assert!(rescaled_x <= stop); + assert!(rescaled_x >= 0.0); + assert!(rescaled_x <= 1.0); + println!("Start {start:.2} stop {stop:.2}: rescaled {rescaled_x:.2}"); + let n: usize = (rescaled_x * (self.backends.len() as f32)) as usize; + n + } +} + +pub async fn handler(State(mut state): State, req: Request) -> Response { + // Get the next backend index + let limit = 2048usize; + let (parts, body) = req.into_parts(); + // TODO + let bytes = axum::body::to_bytes(body, limit).await.unwrap(); + let index = state.next(&bytes); + // Get the backend URL + let n = state.backends.len(); + let mut index = index % n; + let backend = &state.backends[index]; + + let mut inflight = state.inflight[index].load(Ordering::Relaxed); + let mut inqueue = state.inqueue[index].load(Ordering::Relaxed); + + for i in 0..n { + if (inqueue as f32) <= state.factor * inflight as f32 { + break; + } + if i == 0 { + eprintln!("Backend overloaded (queue: {inqueue} inflight {inflight}), jumping ahead"); + } + index += 1; + index %= state.backends.len(); + inflight = state.inflight[index].load(Ordering::Relaxed); + inqueue = state.inflight[index].load(Ordering::Relaxed); + } + state.inflight[index].fetch_add(1, Ordering::Relaxed); + state.inqueue[index].fetch_add(1, Ordering::Relaxed); + + let body: Body = bytes.into(); + let mut req = Request::from_parts(parts, body); + let path = req.uri().path(); + let path_query = req + .uri() + .path_and_query() + .map(|v| v.as_str()) + .unwrap_or(path); + + let uri = format!("{backend}{path_query}"); + eprintln!("Inflight {uri}"); + *req.uri_mut() = Uri::try_from(uri).unwrap(); + + let response = state + .client + .request(req) + .await + // TODO + .unwrap(); + //.map_err(|_| StatusCode::BAD_GATEWAY)?; + let response = response.into_response(); + let (parts, body) = response.into_parts(); + let response_stream = body.into_data_stream(); + let response_stream = async_stream::stream! { + let mut response_stream = Box::pin(response_stream); + let mut start = true; + while let Some(raw_event) = response_stream.next().await { + if start{ + eprintln!("Not inqueue"); + state.inqueue[index].fetch_sub(1, Ordering::Relaxed); + start = false; + } + yield raw_event; + } + eprintln!("Not inflight"); + state.inflight[index].fetch_sub(1, Ordering::Relaxed); + }; + + let body = Body::from_stream(response_stream); + + Response::from_parts(parts, body) +} diff --git a/kvrouter/src/main.rs b/kvrouter/src/main.rs new file mode 100644 index 00000000000..7a213e45d5e --- /dev/null +++ b/kvrouter/src/main.rs @@ -0,0 +1,30 @@ +use axum::{ + routing::Router, + routing::{get, post}, +}; +use kvrouter::{handler, RoundRobin}; + +#[tokio::main] +async fn main() { + // List of backend servers + let backends = vec![ + "http://localhost:8000".to_string(), + "http://localhost:8001".to_string(), + ]; + + // Create a new instance of the RoundRobinRouter + let router = RoundRobin::new(backends); + + // Create the Axum router + let app = Router::new() + .route("/{*key}", get(handler)) + .route("/{*key}", post(handler)) + .with_state(router); + + // run it + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + println!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); +} diff --git a/kvrouter/src/trie.rs b/kvrouter/src/trie.rs new file mode 100644 index 00000000000..002778bc316 --- /dev/null +++ b/kvrouter/src/trie.rs @@ -0,0 +1,257 @@ +use std::collections::BTreeMap; + +// TODO +#[allow(dead_code)] +#[cfg_attr(test, derive(Debug, PartialEq))] +pub enum Error { + MissingEntry, +} + +#[derive(Clone)] +pub struct Trie { + root: Node, +} + +#[derive(Clone)] +#[cfg_attr(test, derive(Debug, PartialEq))] +pub struct Node { + content: Vec, + nelements: usize, + children: BTreeMap, +} + +pub fn mismatch(xs: &[u8], ys: &[u8]) -> usize { + // SIMD + mismatch_chunks::<128>(xs, ys) +} + +fn mismatch_chunks(xs: &[u8], ys: &[u8]) -> usize { + let off = xs + .chunks_exact(N) + .zip(ys.chunks_exact(N)) + .take_while(|(x, y)| x == y) + .count() + * N; + off + xs[off..] + .iter() + .zip(&ys[off..]) + .take_while(|(x, y)| x == y) + .count() +} + +impl Node { + fn new() -> Self { + Self { + content: vec![], + nelements: 0, + children: BTreeMap::new(), + } + } + + fn insert(&mut self, data: &[u8], left: usize) -> (usize, usize) { + let (start, stop) = if self.nelements == 0 { + self.content = data.to_vec(); + (left, left + 1) + } else { + let mismatch = mismatch(data, &self.content); + if mismatch == self.content.len() { + // Full prefix match, just dive deeper + let (start, stop) = if let Some(c) = data.get(mismatch) { + let left: usize = self + .children + .iter() + .take_while(|(&d, _)| d < *c) + .map(|(_, n)| n.nelements) + .sum(); + let next_node = self.children.entry(*c).or_insert(Node::new()); + next_node.insert(&data[mismatch..], left) + } else { + (0, self.nelements + 1) + }; + (left + start, left + stop) + } else { + // Partial match, split node + let left = self.content[mismatch..].to_vec(); + let right = data[mismatch..].to_vec(); + + let children = std::mem::take(&mut self.children); + let mut children_content = vec![ + (left, children, self.nelements), + (right, BTreeMap::new(), 1), + ]; + children_content.sort_by(|a, b| a.0.cmp(&b.0)); + self.content.truncate(mismatch); + self.children.clear(); + for (child_content, children, nelements) in children_content { + if !child_content.is_empty() { + let c = child_content[0]; + let child = Node { + content: child_content, + nelements, + children, + }; + self.children.insert(c, child); + } + } + let c = data[mismatch]; + let left: usize = self + .children + .iter() + .take_while(|(&d, _)| d < c) + .map(|(_, n)| n.nelements) + .sum(); + (left, left + 1) + } + }; + self.nelements += 1; + (start, stop) + } + + // TODO + #[allow(dead_code)] + fn remove(&mut self, data: &[u8]) -> Result<(), Error> { + let mismatch = mismatch(data, &self.content); + if mismatch != self.content.len() { + Err(Error::MissingEntry) + } else { + if let Some(c) = data.get(mismatch) { + if let Some(node) = self.children.get_mut(c) { + node.remove(&data[mismatch..])?; + } + } + self.nelements -= 1; + Ok(()) + } + } +} + +impl Trie { + pub fn new() -> Self { + let root = Node::new(); + Self { root } + } + + pub fn insert(&mut self, data: &[u8]) -> (usize, usize) { + self.root.insert(data, 0) + } + + // TODO + #[allow(dead_code)] + pub fn remove(&mut self, data: &[u8]) -> Result<(), Error> { + self.root.remove(data) + } + + pub fn count(&self) -> usize { + self.root.nelements + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn simple() { + let mut trie = Trie::new(); + assert_eq!(trie.insert(b"toto"), (0, 1)); + assert_eq!(trie.insert(b"tata"), (0, 1)); + + assert_eq!(trie.root.nelements, 2); + assert_eq!(trie.root.content, b"t"); + assert_eq!(trie.root.children.len(), 2); + assert_eq!( + trie.root.children, + BTreeMap::from_iter([ + ( + b'a', + Node { + nelements: 1, + content: b"ata".to_vec(), + children: BTreeMap::new() + } + ), + ( + b'o', + Node { + nelements: 1, + content: b"oto".to_vec(), + children: BTreeMap::new() + } + ) + ]) + ); + assert_eq!(trie.insert(b"coco"), (0, 1)); + assert_eq!(trie.insert(b"zaza"), (3, 4)); + assert_eq!(trie.root.nelements, 4); + assert_eq!(trie.root.content, b""); + assert_eq!(trie.root.children.len(), 3); + assert_eq!( + trie.root.children, + BTreeMap::from_iter([ + ( + b'c', + Node { + nelements: 1, + content: b"coco".to_vec(), + children: BTreeMap::new() + } + ), + ( + b't', + Node { + nelements: 2, + content: b"t".to_vec(), + children: BTreeMap::from_iter([ + ( + b'a', + Node { + nelements: 1, + content: b"ata".to_vec(), + children: BTreeMap::new() + } + ), + ( + b'o', + Node { + nelements: 1, + content: b"oto".to_vec(), + children: BTreeMap::new() + } + ) + ]) + } + ), + ( + b'z', + Node { + nelements: 1, + content: b"zaza".to_vec(), + children: BTreeMap::new() + } + ), + ]) + ); + } + + #[test] + fn delete() { + let mut trie = Trie::new(); + trie.insert(b"toto"); + trie.insert(b"tata"); + + assert_eq!(trie.root.nelements, 2); + assert_eq!(trie.remove(b"coco"), Err(Error::MissingEntry)); + assert_eq!(trie.remove(b"toto"), Ok(())); + assert_eq!(trie.root.nelements, 1); + } + + #[test] + fn duplicate() { + let mut trie = Trie::new(); + assert_eq!(trie.insert(b"toto"), (0, 1)); + assert_eq!(trie.insert(b"toto"), (0, 2)); + assert_eq!(trie.root.nelements, 2); + assert_eq!(trie.remove(b"toto"), Ok(())); + assert_eq!(trie.root.nelements, 1); + } +} From 0a495ad118b2a969871278d2ebf3b6480829b93c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 29 Jan 2025 12:40:26 +0100 Subject: [PATCH 3/7] Updating the kvrouter to support roundrobin for comparison, still withthe overloading checks --- backends/v3/src/radix.rs | 7 +- kvrouter/src/lib.rs | 153 +++++++++++++++++++++++++++------------ kvrouter/src/main.rs | 47 ++++++++---- load_tests/common.js | 2 +- 4 files changed, 147 insertions(+), 62 deletions(-) diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 223ac67be3f..cc617c041ec 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -110,10 +110,15 @@ impl Allocator for RadixAllocator { let suffix_blocks = suffix_len.div_ceil(self.block_size); + let prefix_len_uncached = prefill_tokens.as_ref().map(|p| p.len()).unwrap_or_default(); tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}"); metrics::counter!("tgi_cache_hit", "allocator" => "radix") .increment(prefix_len.try_into().expect("Can convert usize to u64")); - metrics::counter!("tgi_cache_total", "allocator" => "radix").increment(suffix_len.into()); + metrics::counter!("tgi_cache_total", "allocator" => "radix").increment( + prefix_len_uncached + .try_into() + .expect("Can convert usize to u64"), + ); match self.alloc_or_reclaim(suffix_blocks as usize) { Some(suffix_blocks) => blocks.extend(suffix_blocks), diff --git a/kvrouter/src/lib.rs b/kvrouter/src/lib.rs index 31d1d81fab6..ba0bf3c8999 100644 --- a/kvrouter/src/lib.rs +++ b/kvrouter/src/lib.rs @@ -18,37 +18,25 @@ const FACTOR_KEY: &str = "TGI_KVROUTER_FACTOR"; type Client = hyper_util::client::legacy::Client; #[derive(Clone)] -pub struct RoundRobin { - client: Client, +pub struct ContentAware { trie: Arc>, - backends: Arc>, - inqueue: Arc>, - inflight: Arc>, - factor: f32, } -impl RoundRobin { - pub fn new(backends: Vec) -> Self { - let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) - .build(HttpConnector::new()); - let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect()); - let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect()); +impl Default for ContentAware { + fn default() -> Self { + Self::new() + } +} + +impl ContentAware { + pub fn new() -> Self { let trie = Arc::new(Mutex::new(Trie::new())); - let factor: f32 = std::env::var(FACTOR_KEY) - .unwrap_or("1.5".to_string()) - .parse() - .unwrap_or(1.5); - Self { - inflight, - inqueue, - trie, - client, - factor, - backends: Arc::new(backends), - } + Self { trie } } +} - pub fn next(&mut self, key: &[u8]) -> usize { +impl LoadBalancer for ContentAware { + fn next(&mut self, key: &[u8], n_backends: usize) -> usize { let mut trie = self.trie.lock().unwrap(); let (start, stop) = trie.insert(key); let n = trie.count(); @@ -67,38 +55,108 @@ impl RoundRobin { assert!(rescaled_x >= 0.0); assert!(rescaled_x <= 1.0); println!("Start {start:.2} stop {stop:.2}: rescaled {rescaled_x:.2}"); - let n: usize = (rescaled_x * (self.backends.len() as f32)) as usize; + let n: usize = (rescaled_x * (n_backends as f32)) as usize; n } } -pub async fn handler(State(mut state): State, req: Request) -> Response { +#[derive(Clone)] +pub struct RoundRobin { + current: Arc, +} + +impl Default for RoundRobin { + fn default() -> Self { + Self::new() + } +} + +impl RoundRobin { + pub fn new() -> Self { + let current = Arc::new(AtomicUsize::new(0)); + Self { current } + } +} + +impl LoadBalancer for RoundRobin { + fn next(&mut self, _key: &[u8], _n_backends: usize) -> usize { + self.current.fetch_add(1, Ordering::Relaxed) + } +} + +#[derive(Clone)] +pub struct OverloadHandler { + client: Client, + load_balancer: T, + backends: Arc>, + inqueue: Arc>, + inflight: Arc>, + factor: f32, +} + +impl OverloadHandler { + pub fn new(load_balancer: T, backends: Vec) -> Self { + let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) + .build(HttpConnector::new()); + let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect()); + let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect()); + let factor: f32 = std::env::var(FACTOR_KEY) + .unwrap_or("1.5".to_string()) + .parse() + .unwrap_or(1.5); + let backends = Arc::new(backends); + Self { + load_balancer, + backends, + client, + factor, + inflight, + inqueue, + } + } + + fn next(&mut self, key: &[u8]) -> usize { + // Get the backend URL + let index = self.load_balancer.next(key, self.backends.len()); + let n = self.backends.len(); + let mut index = index % n; + + let mut inflight = self.inflight[index].load(Ordering::Relaxed); + let mut inqueue = self.inqueue[index].load(Ordering::Relaxed); + + for i in 0..n { + if (inqueue as f32) <= self.factor * inflight as f32 { + break; + } + if i == 0 { + eprintln!( + "Backend overloaded (queue: {inqueue} inflight {inflight}), jumping ahead" + ); + } + index += 1; + index %= self.backends.len(); + inflight = self.inflight[index].load(Ordering::Relaxed); + inqueue = self.inflight[index].load(Ordering::Relaxed); + } + index + } +} + +pub trait LoadBalancer { + fn next(&mut self, key: &[u8], n_backends: usize) -> usize; +} + +pub async fn handler( + State(mut state): State>, + req: Request, +) -> Response { // Get the next backend index - let limit = 2048usize; + let limit = 1024 * 1024; let (parts, body) = req.into_parts(); // TODO let bytes = axum::body::to_bytes(body, limit).await.unwrap(); let index = state.next(&bytes); - // Get the backend URL - let n = state.backends.len(); - let mut index = index % n; let backend = &state.backends[index]; - - let mut inflight = state.inflight[index].load(Ordering::Relaxed); - let mut inqueue = state.inqueue[index].load(Ordering::Relaxed); - - for i in 0..n { - if (inqueue as f32) <= state.factor * inflight as f32 { - break; - } - if i == 0 { - eprintln!("Backend overloaded (queue: {inqueue} inflight {inflight}), jumping ahead"); - } - index += 1; - index %= state.backends.len(); - inflight = state.inflight[index].load(Ordering::Relaxed); - inqueue = state.inflight[index].load(Ordering::Relaxed); - } state.inflight[index].fetch_add(1, Ordering::Relaxed); state.inqueue[index].fetch_add(1, Ordering::Relaxed); @@ -131,6 +189,7 @@ pub async fn handler(State(mut state): State, req: Request) -> Respo while let Some(raw_event) = response_stream.next().await { if start{ eprintln!("Not inqueue"); + state.inqueue[index].fetch_sub(1, Ordering::Relaxed); start = false; } diff --git a/kvrouter/src/main.rs b/kvrouter/src/main.rs index 7a213e45d5e..58aaefca69b 100644 --- a/kvrouter/src/main.rs +++ b/kvrouter/src/main.rs @@ -2,7 +2,7 @@ use axum::{ routing::Router, routing::{get, post}, }; -use kvrouter::{handler, RoundRobin}; +use kvrouter::{handler, ContentAware, OverloadHandler, RoundRobin}; #[tokio::main] async fn main() { @@ -10,21 +10,42 @@ async fn main() { let backends = vec![ "http://localhost:8000".to_string(), "http://localhost:8001".to_string(), + "http://localhost:8002".to_string(), + "http://localhost:8003".to_string(), ]; // Create a new instance of the RoundRobinRouter - let router = RoundRobin::new(backends); + if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" { + println!("Using round robin"); + let lb = RoundRobin::new(); + // Create the Axum router + let router = OverloadHandler::new(lb, backends); + let app = Router::new() + .route("/{*key}", get(handler)) + .route("/{*key}", post(handler)) + .with_state(router); - // Create the Axum router - let app = Router::new() - .route("/{*key}", get(handler)) - .route("/{*key}", post(handler)) - .with_state(router); + // run it + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + println!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); + } else { + println!("Using Content aware"); + let lb = ContentAware::new(); + // Create the Axum router + let router = OverloadHandler::new(lb, backends); + let app = Router::new() + .route("/{*key}", get(handler)) + .route("/{*key}", post(handler)) + .with_state(router); - // run it - let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") - .await - .unwrap(); - println!("listening on {}", listener.local_addr().unwrap()); - axum::serve(listener, app).await.unwrap(); + // run it + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + println!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); + }; } diff --git a/load_tests/common.js b/load_tests/common.js index d890bf6710d..4e3cee4af65 100644 --- a/load_tests/common.js +++ b/load_tests/common.js @@ -50,7 +50,7 @@ export function get_options() { throughput: { executor: 'shared-iterations', vus: 100, - iterations: 200, + iterations: 500, maxDuration: '40s', }, }, From 914b1637688b0d58de780c72ac587b8dc9675d6d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 29 Jan 2025 13:27:32 +0100 Subject: [PATCH 4/7] Remove kvrouter from default members. --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3b6c20a1676..ef91ab70dc2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,8 +6,9 @@ members = [ "backends/grpc-metadata", "backends/trtllm", "launcher", - "router" -, "kvrouter"] + "router", + "kvrouter" +] default-members = [ "benchmark", "backends/v2", @@ -15,7 +16,6 @@ default-members = [ "backends/grpc-metadata", # "backends/trtllm", "launcher", - "kvrouter", "router" ] resolver = "2" From 1932c5b9ed9ec46dbdf563c521c70f5e8937ec99 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 29 Jan 2025 13:48:06 +0100 Subject: [PATCH 5/7] Adding kvrouter to the workspace. --- Cargo.toml | 3 ++- Dockerfile | 1 + Dockerfile_amd | 1 + Dockerfile_intel | 1 + Dockerfile_trtllm | 1 + 5 files changed, 6 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index ef91ab70dc2..d526d94ae67 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ members = [ "backends/trtllm", "launcher", "router", - "kvrouter" + "kvrouter", ] default-members = [ "benchmark", @@ -16,6 +16,7 @@ default-members = [ "backends/grpc-metadata", # "backends/trtllm", "launcher", + "kvrouter", "router" ] resolver = "2" diff --git a/Dockerfile b/Dockerfile index 7200533309c..21e1ab0c5d8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,6 +10,7 @@ COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark +COPY kvrouter kvrouter COPY router router COPY backends backends COPY launcher launcher diff --git a/Dockerfile_amd b/Dockerfile_amd index 8b7808bea51..5fa3d632dcf 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -11,6 +11,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY kvrouter kvrouter COPY backends backends COPY launcher launcher RUN cargo chef prepare --recipe-path recipe.json diff --git a/Dockerfile_intel b/Dockerfile_intel index 0f0d4383595..6031176d49e 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -12,6 +12,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY kvrouter kvrouter COPY backends backends COPY launcher launcher RUN cargo chef prepare --recipe-path recipe.json diff --git a/Dockerfile_trtllm b/Dockerfile_trtllm index 999d63d742b..872b09108af 100644 --- a/Dockerfile_trtllm +++ b/Dockerfile_trtllm @@ -92,6 +92,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY router router COPY backends backends COPY benchmark benchmark +COPY kvrouter kvrouter COPY launcher launcher COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi From 57fa04adfd29a11ed16283fe2fda49eb4f788d2e Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 31 Jan 2025 09:07:31 +0100 Subject: [PATCH 6/7] Cleaner version. --- Cargo.lock | 5 +- kvrouter/Cargo.toml | 1 + kvrouter/src/lib.rs | 161 ++++++++++++++++++++++++++++++++----------- kvrouter/src/main.rs | 66 +++++++++--------- 4 files changed, 156 insertions(+), 77 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ee52a2b429e..9b741b1f524 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2254,6 +2254,7 @@ dependencies = [ "futures-util", "hyper 1.5.2", "hyper-util", + "log", "rand 0.9.0", "slotmap", "tokio", @@ -2352,9 +2353,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" [[package]] name = "loop9" diff --git a/kvrouter/Cargo.toml b/kvrouter/Cargo.toml index 2e54f6e46f7..efb2aa2a019 100644 --- a/kvrouter/Cargo.toml +++ b/kvrouter/Cargo.toml @@ -12,6 +12,7 @@ futures = "0.3.31" futures-util = "0.3.31" hyper = { version = "1.5.2", features = ["full"] } hyper-util = { version = "0.1.10", features = ["full"] } +log = "0.4.25" rand = "0.9.0" slotmap = "1.0.7" tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] } diff --git a/kvrouter/src/lib.rs b/kvrouter/src/lib.rs index ba0bf3c8999..874ae192856 100644 --- a/kvrouter/src/lib.rs +++ b/kvrouter/src/lib.rs @@ -5,10 +5,11 @@ use axum::{ response::{IntoResponse, Response}, }; use futures_util::stream::StreamExt; +use hyper::StatusCode; use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; use rand::{rng, Rng}; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, Mutex}; +use tokio::sync::{mpsc, oneshot}; mod trie; @@ -17,9 +18,8 @@ use crate::trie::Trie; const FACTOR_KEY: &str = "TGI_KVROUTER_FACTOR"; type Client = hyper_util::client::legacy::Client; -#[derive(Clone)] pub struct ContentAware { - trie: Arc>, + trie: Trie, } impl Default for ContentAware { @@ -30,14 +30,14 @@ impl Default for ContentAware { impl ContentAware { pub fn new() -> Self { - let trie = Arc::new(Mutex::new(Trie::new())); + let trie = Trie::new(); Self { trie } } } impl LoadBalancer for ContentAware { fn next(&mut self, key: &[u8], n_backends: usize) -> usize { - let mut trie = self.trie.lock().unwrap(); + let trie = &mut self.trie; let (start, stop) = trie.insert(key); let n = trie.count(); eprintln!( @@ -60,9 +60,8 @@ impl LoadBalancer for ContentAware { } } -#[derive(Clone)] pub struct RoundRobin { - current: Arc, + current: AtomicUsize, } impl Default for RoundRobin { @@ -73,7 +72,7 @@ impl Default for RoundRobin { impl RoundRobin { pub fn new() -> Self { - let current = Arc::new(AtomicUsize::new(0)); + let current = AtomicUsize::new(0); Self { current } } } @@ -84,38 +83,34 @@ impl LoadBalancer for RoundRobin { } } -#[derive(Clone)] pub struct OverloadHandler { - client: Client, load_balancer: T, - backends: Arc>, - inqueue: Arc>, - inflight: Arc>, + backends: Vec, + inqueue: Vec, + inflight: Vec, factor: f32, + rx: Rcv, } impl OverloadHandler { - pub fn new(load_balancer: T, backends: Vec) -> Self { - let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) - .build(HttpConnector::new()); - let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect()); - let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect()); + pub fn new(load_balancer: T, backends: Vec, rx: Rcv) -> Self { + let inflight = backends.iter().map(|_| AtomicUsize::new(0)).collect(); + let inqueue = backends.iter().map(|_| AtomicUsize::new(0)).collect(); let factor: f32 = std::env::var(FACTOR_KEY) .unwrap_or("1.5".to_string()) .parse() .unwrap_or(1.5); - let backends = Arc::new(backends); Self { load_balancer, backends, - client, factor, inflight, inqueue, + rx, } } - fn next(&mut self, key: &[u8]) -> usize { + fn next(&mut self, key: &[u8]) -> String { // Get the backend URL let index = self.load_balancer.next(key, self.backends.len()); let n = self.backends.len(); @@ -138,7 +133,45 @@ impl OverloadHandler { inflight = self.inflight[index].load(Ordering::Relaxed); inqueue = self.inflight[index].load(Ordering::Relaxed); } - index + let backend = &self.backends[index]; + self.inflight[index].fetch_add(1, Ordering::Relaxed); + self.inqueue[index].fetch_add(1, Ordering::Relaxed); + backend.to_string() + } + + pub async fn run(&mut self) { + while let Some(msg) = self.rx.recv().await { + eprintln!("Msg {msg:?}"); + match msg { + Msg::Next(key, sx) => { + let backend: String = self.next(&key); + eprintln!("Sending back backend {backend}"); + if let Err(err) = sx.send(backend) { + eprintln!("Cannot send back result: {err}"); + } + } + Msg::Dequeue(backend) => { + let index = self.backends.iter().position(|b| b == &backend); + if let Some(index) = index { + self.inqueue[index].fetch_sub(1, Ordering::Relaxed); + } + } + Msg::Deflight(backend) => { + let index = self.backends.iter().position(|b| b == &backend); + if let Some(index) = index { + self.inflight[index].fetch_sub(1, Ordering::Relaxed); + } + } + Msg::AddBackend(backend) => { + self.backends.push(backend); + self.backends.sort(); + } + Msg::RemoveBackend(backend) => { + self.backends.retain(|b| *b == backend); + self.backends.sort(); + } + } + } } } @@ -146,21 +179,71 @@ pub trait LoadBalancer { fn next(&mut self, key: &[u8], n_backends: usize) -> usize; } -pub async fn handler( - State(mut state): State>, +#[derive(Debug)] +pub enum Msg { + Next(Vec, oneshot::Sender), + Dequeue(String), + Deflight(String), + AddBackend(String), + RemoveBackend(String), +} + +type Snd = mpsc::Sender; +type Rcv = mpsc::Receiver; + +#[derive(Clone)] +pub struct Communicator { + sender: Snd, + client: Client, +} + +impl Communicator { + pub fn new(sender: Snd) -> Self { + let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) + .build(HttpConnector::new()); + Self { sender, client } + } + + async fn dequeue(&self, backend: String) -> Result<(), mpsc::error::SendError> { + self.sender.send(Msg::Dequeue(backend)).await + } + + async fn deflight(&self, backend: String) -> Result<(), mpsc::error::SendError> { + self.sender.send(Msg::Deflight(backend)).await + } + + async fn next(&self, key: Vec) -> Result> { + let (sx, rx) = oneshot::channel(); + self.sender.send(Msg::Next(key, sx)).await?; + let backend = rx.await.unwrap(); + Ok(backend) + } +} + +pub async fn handler( + State(state): State, req: Request, -) -> Response { +) -> Result, StatusCode> { // Get the next backend index - let limit = 1024 * 1024; let (parts, body) = req.into_parts(); - // TODO - let bytes = axum::body::to_bytes(body, limit).await.unwrap(); - let index = state.next(&bytes); - let backend = &state.backends[index]; - state.inflight[index].fetch_add(1, Ordering::Relaxed); - state.inqueue[index].fetch_add(1, Ordering::Relaxed); - - let body: Body = bytes.into(); + let mut response_stream = body.into_data_stream(); + let event = response_stream.next().await; + let key = if let Some(Ok(event)) = &event { + event.to_vec() + } else { + vec![] + }; + let backend = state.next(key).await.map_err(|_| StatusCode::BAD_GATEWAY)?; + let response_stream = async_stream::stream! { + let mut response_stream = Box::pin(response_stream); + if let Some(event) = event{ + yield event; + } + while let Some(raw_event) = response_stream.next().await { + yield raw_event; + } + }; + let body = Body::from_stream(response_stream); let mut req = Request::from_parts(parts, body); let path = req.uri().path(); let path_query = req @@ -177,9 +260,7 @@ pub async fn handler( .client .request(req) .await - // TODO - .unwrap(); - //.map_err(|_| StatusCode::BAD_GATEWAY)?; + .map_err(|_| StatusCode::BAD_GATEWAY)?; let response = response.into_response(); let (parts, body) = response.into_parts(); let response_stream = body.into_data_stream(); @@ -190,16 +271,16 @@ pub async fn handler( if start{ eprintln!("Not inqueue"); - state.inqueue[index].fetch_sub(1, Ordering::Relaxed); + state.dequeue(backend.to_string()).await.unwrap(); start = false; } yield raw_event; } eprintln!("Not inflight"); - state.inflight[index].fetch_sub(1, Ordering::Relaxed); + state.deflight(backend.to_string()).await.unwrap(); }; let body = Body::from_stream(response_stream); - Response::from_parts(parts, body) + Ok(Response::from_parts(parts, body)) } diff --git a/kvrouter/src/main.rs b/kvrouter/src/main.rs index 58aaefca69b..b50815456c9 100644 --- a/kvrouter/src/main.rs +++ b/kvrouter/src/main.rs @@ -2,50 +2,46 @@ use axum::{ routing::Router, routing::{get, post}, }; -use kvrouter::{handler, ContentAware, OverloadHandler, RoundRobin}; +use kvrouter::{handler, Communicator, ContentAware, OverloadHandler, RoundRobin}; #[tokio::main] async fn main() { // List of backend servers let backends = vec![ "http://localhost:8000".to_string(), - "http://localhost:8001".to_string(), - "http://localhost:8002".to_string(), - "http://localhost:8003".to_string(), + // "http://localhost:8001".to_string(), + // "http://localhost:8002".to_string(), + // "http://localhost:8003".to_string(), ]; // Create a new instance of the RoundRobinRouter - if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" { - println!("Using round robin"); - let lb = RoundRobin::new(); - // Create the Axum router - let router = OverloadHandler::new(lb, backends); - let app = Router::new() - .route("/{*key}", get(handler)) - .route("/{*key}", post(handler)) - .with_state(router); - // run it - let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") - .await - .unwrap(); - println!("listening on {}", listener.local_addr().unwrap()); - axum::serve(listener, app).await.unwrap(); - } else { - println!("Using Content aware"); - let lb = ContentAware::new(); - // Create the Axum router - let router = OverloadHandler::new(lb, backends); - let app = Router::new() - .route("/{*key}", get(handler)) - .route("/{*key}", post(handler)) - .with_state(router); + println!("Using Content aware"); + // Create the Axum router - // run it - let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") - .await - .unwrap(); - println!("listening on {}", listener.local_addr().unwrap()); - axum::serve(listener, app).await.unwrap(); - }; + let (sx, rx) = tokio::sync::mpsc::channel(100); + let communicator = Communicator::new(sx); + tokio::task::spawn(async move { + if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" { + println!("Using round robin"); + let lb = RoundRobin::new(); + let mut router = OverloadHandler::new(lb, backends, rx); + router.run().await; + } else { + let lb = ContentAware::new(); + let mut router = OverloadHandler::new(lb, backends, rx); + router.run().await; + }; + }); + let app = Router::new() + .route("/{*key}", get(handler)) + .route("/{*key}", post(handler)) + .with_state(communicator); + + // run it + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + println!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); } From 50c8ebdef0c98b3f2520a26e06ed4c5276db66cd Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 31 Jan 2025 13:16:29 +0100 Subject: [PATCH 7/7] CI must be green. --- .github/workflows/build.yaml | 10 +++++----- docs/source/backends/trtllm.md | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 2db466a84a6..c991ffcbe67 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -74,12 +74,12 @@ jobs: export runs_on="ubuntu-latest" export platform="" export extra_pytest="" - if [[ "${GITHUB_REF}" == "refs/tags/*" ]]; then - export build_type="release"; - export target=""; - else + if [[ "${GITHUB_REF}" == "refs/tags/*" ]]; then + export build_type="release"; + export target=""; + else export build_type="dev"; - export target="ci-runtime"; + export target="ci-runtime"; fi ;; rocm) diff --git a/docs/source/backends/trtllm.md b/docs/source/backends/trtllm.md index e89e8f5208b..10db4a5089d 100644 --- a/docs/source/backends/trtllm.md +++ b/docs/source/backends/trtllm.md @@ -179,4 +179,4 @@ ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt" ENV USE_LLD_LINKER=ON ENV CUDA_ARCH_LIST=${cuda_arch_list} -``` \ No newline at end of file +```