diff --git a/Cargo.lock b/Cargo.lock index ee52a2b429e..9b741b1f524 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2254,6 +2254,7 @@ dependencies = [ "futures-util", "hyper 1.5.2", "hyper-util", + "log", "rand 0.9.0", "slotmap", "tokio", @@ -2352,9 +2353,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" [[package]] name = "loop9" diff --git a/kvrouter/Cargo.toml b/kvrouter/Cargo.toml index 2e54f6e46f7..efb2aa2a019 100644 --- a/kvrouter/Cargo.toml +++ b/kvrouter/Cargo.toml @@ -12,6 +12,7 @@ futures = "0.3.31" futures-util = "0.3.31" hyper = { version = "1.5.2", features = ["full"] } hyper-util = { version = "0.1.10", features = ["full"] } +log = "0.4.25" rand = "0.9.0" slotmap = "1.0.7" tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] } diff --git a/kvrouter/src/lib.rs b/kvrouter/src/lib.rs index ba0bf3c8999..874ae192856 100644 --- a/kvrouter/src/lib.rs +++ b/kvrouter/src/lib.rs @@ -5,10 +5,11 @@ use axum::{ response::{IntoResponse, Response}, }; use futures_util::stream::StreamExt; +use hyper::StatusCode; use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; use rand::{rng, Rng}; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, Mutex}; +use tokio::sync::{mpsc, oneshot}; mod trie; @@ -17,9 +18,8 @@ use crate::trie::Trie; const FACTOR_KEY: &str = "TGI_KVROUTER_FACTOR"; type Client = hyper_util::client::legacy::Client; -#[derive(Clone)] pub struct ContentAware { - trie: Arc>, + trie: Trie, } impl Default for ContentAware { @@ -30,14 +30,14 @@ impl Default for ContentAware { impl ContentAware { pub fn new() -> Self { - let trie = Arc::new(Mutex::new(Trie::new())); + let trie = Trie::new(); Self { trie } } } impl LoadBalancer for ContentAware { fn next(&mut self, key: &[u8], n_backends: usize) -> usize { - let mut trie = self.trie.lock().unwrap(); + let trie = &mut self.trie; let (start, stop) = trie.insert(key); let n = trie.count(); eprintln!( @@ -60,9 +60,8 @@ impl LoadBalancer for ContentAware { } } -#[derive(Clone)] pub struct RoundRobin { - current: Arc, + current: AtomicUsize, } impl Default for RoundRobin { @@ -73,7 +72,7 @@ impl Default for RoundRobin { impl RoundRobin { pub fn new() -> Self { - let current = Arc::new(AtomicUsize::new(0)); + let current = AtomicUsize::new(0); Self { current } } } @@ -84,38 +83,34 @@ impl LoadBalancer for RoundRobin { } } -#[derive(Clone)] pub struct OverloadHandler { - client: Client, load_balancer: T, - backends: Arc>, - inqueue: Arc>, - inflight: Arc>, + backends: Vec, + inqueue: Vec, + inflight: Vec, factor: f32, + rx: Rcv, } impl OverloadHandler { - pub fn new(load_balancer: T, backends: Vec) -> Self { - let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) - .build(HttpConnector::new()); - let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect()); - let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect()); + pub fn new(load_balancer: T, backends: Vec, rx: Rcv) -> Self { + let inflight = backends.iter().map(|_| AtomicUsize::new(0)).collect(); + let inqueue = backends.iter().map(|_| AtomicUsize::new(0)).collect(); let factor: f32 = std::env::var(FACTOR_KEY) .unwrap_or("1.5".to_string()) .parse() .unwrap_or(1.5); - let backends = Arc::new(backends); Self { load_balancer, backends, - client, factor, inflight, inqueue, + rx, } } - fn next(&mut self, key: &[u8]) -> usize { + fn next(&mut self, key: &[u8]) -> String { // Get the backend URL let index = self.load_balancer.next(key, self.backends.len()); let n = self.backends.len(); @@ -138,7 +133,45 @@ impl OverloadHandler { inflight = self.inflight[index].load(Ordering::Relaxed); inqueue = self.inflight[index].load(Ordering::Relaxed); } - index + let backend = &self.backends[index]; + self.inflight[index].fetch_add(1, Ordering::Relaxed); + self.inqueue[index].fetch_add(1, Ordering::Relaxed); + backend.to_string() + } + + pub async fn run(&mut self) { + while let Some(msg) = self.rx.recv().await { + eprintln!("Msg {msg:?}"); + match msg { + Msg::Next(key, sx) => { + let backend: String = self.next(&key); + eprintln!("Sending back backend {backend}"); + if let Err(err) = sx.send(backend) { + eprintln!("Cannot send back result: {err}"); + } + } + Msg::Dequeue(backend) => { + let index = self.backends.iter().position(|b| b == &backend); + if let Some(index) = index { + self.inqueue[index].fetch_sub(1, Ordering::Relaxed); + } + } + Msg::Deflight(backend) => { + let index = self.backends.iter().position(|b| b == &backend); + if let Some(index) = index { + self.inflight[index].fetch_sub(1, Ordering::Relaxed); + } + } + Msg::AddBackend(backend) => { + self.backends.push(backend); + self.backends.sort(); + } + Msg::RemoveBackend(backend) => { + self.backends.retain(|b| *b == backend); + self.backends.sort(); + } + } + } } } @@ -146,21 +179,71 @@ pub trait LoadBalancer { fn next(&mut self, key: &[u8], n_backends: usize) -> usize; } -pub async fn handler( - State(mut state): State>, +#[derive(Debug)] +pub enum Msg { + Next(Vec, oneshot::Sender), + Dequeue(String), + Deflight(String), + AddBackend(String), + RemoveBackend(String), +} + +type Snd = mpsc::Sender; +type Rcv = mpsc::Receiver; + +#[derive(Clone)] +pub struct Communicator { + sender: Snd, + client: Client, +} + +impl Communicator { + pub fn new(sender: Snd) -> Self { + let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) + .build(HttpConnector::new()); + Self { sender, client } + } + + async fn dequeue(&self, backend: String) -> Result<(), mpsc::error::SendError> { + self.sender.send(Msg::Dequeue(backend)).await + } + + async fn deflight(&self, backend: String) -> Result<(), mpsc::error::SendError> { + self.sender.send(Msg::Deflight(backend)).await + } + + async fn next(&self, key: Vec) -> Result> { + let (sx, rx) = oneshot::channel(); + self.sender.send(Msg::Next(key, sx)).await?; + let backend = rx.await.unwrap(); + Ok(backend) + } +} + +pub async fn handler( + State(state): State, req: Request, -) -> Response { +) -> Result, StatusCode> { // Get the next backend index - let limit = 1024 * 1024; let (parts, body) = req.into_parts(); - // TODO - let bytes = axum::body::to_bytes(body, limit).await.unwrap(); - let index = state.next(&bytes); - let backend = &state.backends[index]; - state.inflight[index].fetch_add(1, Ordering::Relaxed); - state.inqueue[index].fetch_add(1, Ordering::Relaxed); - - let body: Body = bytes.into(); + let mut response_stream = body.into_data_stream(); + let event = response_stream.next().await; + let key = if let Some(Ok(event)) = &event { + event.to_vec() + } else { + vec![] + }; + let backend = state.next(key).await.map_err(|_| StatusCode::BAD_GATEWAY)?; + let response_stream = async_stream::stream! { + let mut response_stream = Box::pin(response_stream); + if let Some(event) = event{ + yield event; + } + while let Some(raw_event) = response_stream.next().await { + yield raw_event; + } + }; + let body = Body::from_stream(response_stream); let mut req = Request::from_parts(parts, body); let path = req.uri().path(); let path_query = req @@ -177,9 +260,7 @@ pub async fn handler( .client .request(req) .await - // TODO - .unwrap(); - //.map_err(|_| StatusCode::BAD_GATEWAY)?; + .map_err(|_| StatusCode::BAD_GATEWAY)?; let response = response.into_response(); let (parts, body) = response.into_parts(); let response_stream = body.into_data_stream(); @@ -190,16 +271,16 @@ pub async fn handler( if start{ eprintln!("Not inqueue"); - state.inqueue[index].fetch_sub(1, Ordering::Relaxed); + state.dequeue(backend.to_string()).await.unwrap(); start = false; } yield raw_event; } eprintln!("Not inflight"); - state.inflight[index].fetch_sub(1, Ordering::Relaxed); + state.deflight(backend.to_string()).await.unwrap(); }; let body = Body::from_stream(response_stream); - Response::from_parts(parts, body) + Ok(Response::from_parts(parts, body)) } diff --git a/kvrouter/src/main.rs b/kvrouter/src/main.rs index 58aaefca69b..b50815456c9 100644 --- a/kvrouter/src/main.rs +++ b/kvrouter/src/main.rs @@ -2,50 +2,46 @@ use axum::{ routing::Router, routing::{get, post}, }; -use kvrouter::{handler, ContentAware, OverloadHandler, RoundRobin}; +use kvrouter::{handler, Communicator, ContentAware, OverloadHandler, RoundRobin}; #[tokio::main] async fn main() { // List of backend servers let backends = vec![ "http://localhost:8000".to_string(), - "http://localhost:8001".to_string(), - "http://localhost:8002".to_string(), - "http://localhost:8003".to_string(), + // "http://localhost:8001".to_string(), + // "http://localhost:8002".to_string(), + // "http://localhost:8003".to_string(), ]; // Create a new instance of the RoundRobinRouter - if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" { - println!("Using round robin"); - let lb = RoundRobin::new(); - // Create the Axum router - let router = OverloadHandler::new(lb, backends); - let app = Router::new() - .route("/{*key}", get(handler)) - .route("/{*key}", post(handler)) - .with_state(router); - // run it - let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") - .await - .unwrap(); - println!("listening on {}", listener.local_addr().unwrap()); - axum::serve(listener, app).await.unwrap(); - } else { - println!("Using Content aware"); - let lb = ContentAware::new(); - // Create the Axum router - let router = OverloadHandler::new(lb, backends); - let app = Router::new() - .route("/{*key}", get(handler)) - .route("/{*key}", post(handler)) - .with_state(router); + println!("Using Content aware"); + // Create the Axum router - // run it - let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") - .await - .unwrap(); - println!("listening on {}", listener.local_addr().unwrap()); - axum::serve(listener, app).await.unwrap(); - }; + let (sx, rx) = tokio::sync::mpsc::channel(100); + let communicator = Communicator::new(sx); + tokio::task::spawn(async move { + if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" { + println!("Using round robin"); + let lb = RoundRobin::new(); + let mut router = OverloadHandler::new(lb, backends, rx); + router.run().await; + } else { + let lb = ContentAware::new(); + let mut router = OverloadHandler::new(lb, backends, rx); + router.run().await; + }; + }); + let app = Router::new() + .route("/{*key}", get(handler)) + .route("/{*key}", post(handler)) + .with_state(communicator); + + // run it + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + println!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); }