Commit

Cleaner version: drop the shared Arc/Mutex state and run the load balancer as a single OverloadHandler task driven over an mpsc channel, with a cloneable Communicator handle passed to the Axum handler.
Narsil committed Jan 31, 2025
1 parent 1932c5b commit 57fa04a
Showing 4 changed files with 156 additions and 77 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Diff for Cargo.lock not rendered (generated file).

1 change: 1 addition & 0 deletions kvrouter/Cargo.toml
@@ -12,6 +12,7 @@ futures = "0.3.31"
futures-util = "0.3.31"
hyper = { version = "1.5.2", features = ["full"] }
hyper-util = { version = "0.1.10", features = ["full"] }
log = "0.4.25"
rand = "0.9.0"
slotmap = "1.0.7"
tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] }
161 changes: 121 additions & 40 deletions kvrouter/src/lib.rs
@@ -5,10 +5,11 @@ use axum::{
response::{IntoResponse, Response},
};
use futures_util::stream::StreamExt;
use hyper::StatusCode;
use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
use rand::{rng, Rng};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use tokio::sync::{mpsc, oneshot};

mod trie;

@@ -17,9 +18,8 @@ use crate::trie::Trie;
const FACTOR_KEY: &str = "TGI_KVROUTER_FACTOR";
type Client = hyper_util::client::legacy::Client<HttpConnector, Body>;

#[derive(Clone)]
pub struct ContentAware {
trie: Arc<Mutex<Trie>>,
trie: Trie,
}

impl Default for ContentAware {
@@ -30,14 +30,14 @@ impl Default for ContentAware {

impl ContentAware {
pub fn new() -> Self {
let trie = Arc::new(Mutex::new(Trie::new()));
let trie = Trie::new();
Self { trie }
}
}

impl LoadBalancer for ContentAware {
fn next(&mut self, key: &[u8], n_backends: usize) -> usize {
let mut trie = self.trie.lock().unwrap();
let trie = &mut self.trie;
let (start, stop) = trie.insert(key);
let n = trie.count();
eprintln!(
@@ -60,9 +60,8 @@ impl LoadBalancer for ContentAware {
}
}

#[derive(Clone)]
pub struct RoundRobin {
current: Arc<AtomicUsize>,
current: AtomicUsize,
}

impl Default for RoundRobin {
@@ -73,7 +72,7 @@

impl RoundRobin {
pub fn new() -> Self {
let current = Arc::new(AtomicUsize::new(0));
let current = AtomicUsize::new(0);
Self { current }
}
}
@@ -84,38 +83,34 @@ impl LoadBalancer for RoundRobin {
}
}

#[derive(Clone)]
pub struct OverloadHandler<T: LoadBalancer> {
client: Client,
load_balancer: T,
backends: Arc<Vec<String>>,
inqueue: Arc<Vec<AtomicUsize>>,
inflight: Arc<Vec<AtomicUsize>>,
backends: Vec<String>,
inqueue: Vec<AtomicUsize>,
inflight: Vec<AtomicUsize>,
factor: f32,
rx: Rcv,
}

impl<T: LoadBalancer> OverloadHandler<T> {
pub fn new(load_balancer: T, backends: Vec<String>) -> Self {
let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
.build(HttpConnector::new());
let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
pub fn new(load_balancer: T, backends: Vec<String>, rx: Rcv) -> Self {
let inflight = backends.iter().map(|_| AtomicUsize::new(0)).collect();
let inqueue = backends.iter().map(|_| AtomicUsize::new(0)).collect();
let factor: f32 = std::env::var(FACTOR_KEY)
.unwrap_or("1.5".to_string())
.parse()
.unwrap_or(1.5);
let backends = Arc::new(backends);
Self {
load_balancer,
backends,
client,
factor,
inflight,
inqueue,
rx,
}
}

fn next(&mut self, key: &[u8]) -> usize {
fn next(&mut self, key: &[u8]) -> String {
// Get the backend URL
let index = self.load_balancer.next(key, self.backends.len());
let n = self.backends.len();
@@ -138,29 +133,117 @@ impl<T: LoadBalancer> OverloadHandler<T> {
inflight = self.inflight[index].load(Ordering::Relaxed);
inqueue = self.inqueue[index].load(Ordering::Relaxed);
}
index
let backend = &self.backends[index];
self.inflight[index].fetch_add(1, Ordering::Relaxed);
self.inqueue[index].fetch_add(1, Ordering::Relaxed);
backend.to_string()
}

pub async fn run(&mut self) {
while let Some(msg) = self.rx.recv().await {
eprintln!("Msg {msg:?}");
match msg {
Msg::Next(key, sx) => {
let backend: String = self.next(&key);
eprintln!("Sending back backend {backend}");
if let Err(err) = sx.send(backend) {
eprintln!("Cannot send back result: {err}");
}
}
Msg::Dequeue(backend) => {
let index = self.backends.iter().position(|b| b == &backend);
if let Some(index) = index {
self.inqueue[index].fetch_sub(1, Ordering::Relaxed);
}
}
Msg::Deflight(backend) => {
let index = self.backends.iter().position(|b| b == &backend);
if let Some(index) = index {
self.inflight[index].fetch_sub(1, Ordering::Relaxed);
}
}
Msg::AddBackend(backend) => {
self.backends.push(backend);
self.backends.sort();
}
Msg::RemoveBackend(backend) => {
self.backends.retain(|b| *b != backend);
self.backends.sort();
}
}
}
}
}

pub trait LoadBalancer {
fn next(&mut self, key: &[u8], n_backends: usize) -> usize;
}

pub async fn handler<T: LoadBalancer>(
State(mut state): State<OverloadHandler<T>>,
#[derive(Debug)]
pub enum Msg {
Next(Vec<u8>, oneshot::Sender<String>),
Dequeue(String),
Deflight(String),
AddBackend(String),
RemoveBackend(String),
}

type Snd = mpsc::Sender<Msg>;
type Rcv = mpsc::Receiver<Msg>;

#[derive(Clone)]
pub struct Communicator {
sender: Snd,
client: Client,
}

impl Communicator {
pub fn new(sender: Snd) -> Self {
let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
.build(HttpConnector::new());
Self { sender, client }
}

async fn dequeue(&self, backend: String) -> Result<(), mpsc::error::SendError<Msg>> {
self.sender.send(Msg::Dequeue(backend)).await
}

async fn deflight(&self, backend: String) -> Result<(), mpsc::error::SendError<Msg>> {
self.sender.send(Msg::Deflight(backend)).await
}

async fn next(&self, key: Vec<u8>) -> Result<String, mpsc::error::SendError<Msg>> {
let (sx, rx) = oneshot::channel();
self.sender.send(Msg::Next(key, sx)).await?;
let backend = rx.await.unwrap();
Ok(backend)
}
}

pub async fn handler(
State(state): State<Communicator>,
req: Request,
) -> Response<Body> {
) -> Result<Response<Body>, StatusCode> {
// Get the next backend index
let limit = 1024 * 1024;
let (parts, body) = req.into_parts();
// TODO
let bytes = axum::body::to_bytes(body, limit).await.unwrap();
let index = state.next(&bytes);
let backend = &state.backends[index];
state.inflight[index].fetch_add(1, Ordering::Relaxed);
state.inqueue[index].fetch_add(1, Ordering::Relaxed);

let body: Body = bytes.into();
let mut response_stream = body.into_data_stream();
let event = response_stream.next().await;
let key = if let Some(Ok(event)) = &event {
event.to_vec()
} else {
vec![]
};
let backend = state.next(key).await.map_err(|_| StatusCode::BAD_GATEWAY)?;
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
if let Some(event) = event{
yield event;
}
while let Some(raw_event) = response_stream.next().await {
yield raw_event;
}
};
let body = Body::from_stream(response_stream);
let mut req = Request::from_parts(parts, body);
let path = req.uri().path();
let path_query = req
@@ -177,9 +260,7 @@ pub async fn handler<T: LoadBalancer>(
.client
.request(req)
.await
// TODO
.unwrap();
//.map_err(|_| StatusCode::BAD_GATEWAY)?;
.map_err(|_| StatusCode::BAD_GATEWAY)?;
let response = response.into_response();
let (parts, body) = response.into_parts();
let response_stream = body.into_data_stream();
@@ -190,16 +271,16 @@ pub async fn handler<T: LoadBalancer>(
if start{
eprintln!("Not inqueue");

state.inqueue[index].fetch_sub(1, Ordering::Relaxed);
state.dequeue(backend.to_string()).await.unwrap();
start = false;
}
yield raw_event;
}
eprintln!("Not inflight");
state.inflight[index].fetch_sub(1, Ordering::Relaxed);
state.deflight(backend.to_string()).await.unwrap();
};

let body = Body::from_stream(response_stream);

Response::from_parts(parts, body)
Ok(Response::from_parts(parts, body))
}
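
For readers new to the actor-style API this commit introduces, here is a minimal sketch of driving the router directly over the channel, without the Axum handler. It is an illustration, not code from the commit: it assumes only the items shown above (Msg, OverloadHandler, RoundRobin) are exported by the kvrouter crate, and the request key and backend URLs are made up.

use kvrouter::{Msg, OverloadHandler, RoundRobin};
use tokio::sync::{mpsc, oneshot};

#[tokio::main]
async fn main() {
    let backends = vec![
        "http://localhost:8000".to_string(),
        "http://localhost:8001".to_string(),
    ];

    // The router owns all balancing state and runs as a single task.
    let (tx, rx) = mpsc::channel(100);
    tokio::spawn(async move {
        let mut router = OverloadHandler::new(RoundRobin::new(), backends, rx);
        router.run().await;
    });

    // Ask for a backend for a given request key.
    let (reply_tx, reply_rx) = oneshot::channel();
    tx.send(Msg::Next(b"some prompt".to_vec(), reply_tx))
        .await
        .expect("router task stopped");
    let backend = reply_rx.await.expect("router dropped the reply");
    println!("routing to {backend}");

    // When the proxied request finishes, release the bookkeeping counters
    // that OverloadHandler::next incremented for this backend.
    tx.send(Msg::Dequeue(backend.clone())).await.unwrap();
    tx.send(Msg::Deflight(backend)).await.unwrap();
}
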
66 changes: 31 additions & 35 deletions kvrouter/src/main.rs
@@ -2,50 +2,46 @@ use axum::{
routing::Router,
routing::{get, post},
};
use kvrouter::{handler, ContentAware, OverloadHandler, RoundRobin};
use kvrouter::{handler, Communicator, ContentAware, OverloadHandler, RoundRobin};

#[tokio::main]
async fn main() {
// List of backend servers
let backends = vec![
"http://localhost:8000".to_string(),
"http://localhost:8001".to_string(),
"http://localhost:8002".to_string(),
"http://localhost:8003".to_string(),
// "http://localhost:8001".to_string(),
// "http://localhost:8002".to_string(),
// "http://localhost:8003".to_string(),
];

// Create a new instance of the RoundRobinRouter
if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
println!("Using round robin");
let lb = RoundRobin::new();
// Create the Axum router
let router = OverloadHandler::new(lb, backends);
let app = Router::new()
.route("/{*key}", get(handler))
.route("/{*key}", post(handler))
.with_state(router);

// run it
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.unwrap();
println!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
} else {
println!("Using Content aware");
let lb = ContentAware::new();
// Create the Axum router
let router = OverloadHandler::new(lb, backends);
let app = Router::new()
.route("/{*key}", get(handler))
.route("/{*key}", post(handler))
.with_state(router);
println!("Using Content aware");
// Create the Axum router

// run it
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.unwrap();
println!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
};
let (sx, rx) = tokio::sync::mpsc::channel(100);
let communicator = Communicator::new(sx);
tokio::task::spawn(async move {
if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
println!("Using round robin");
let lb = RoundRobin::new();
let mut router = OverloadHandler::new(lb, backends, rx);
router.run().await;
} else {
let lb = ContentAware::new();
let mut router = OverloadHandler::new(lb, backends, rx);
router.run().await;
};
});
let app = Router::new()
.route("/{*key}", get(handler))
.route("/{*key}", post(handler))
.with_state(communicator);

// run it
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.unwrap();
println!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}
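
As a further sketch (again not part of the commit), the same channel also carries the new Msg::AddBackend and Msg::RemoveBackend variants handled in OverloadHandler::run, so any clone of the sender can adjust the backend list at runtime. The admin handle and URLs below are invented for the example.

use kvrouter::{Communicator, ContentAware, Msg, OverloadHandler};

#[tokio::main]
async fn main() {
    let backends = vec!["http://localhost:8000".to_string()];
    let (sx, rx) = tokio::sync::mpsc::channel(100);
    let admin = sx.clone(); // extra handle kept for management messages
    let _communicator = Communicator::new(sx); // would be handed to the Axum router as in main.rs

    tokio::spawn(async move {
        let mut router = OverloadHandler::new(ContentAware::new(), backends, rx);
        router.run().await;
    });

    // Register a backend that came up after startup, then retire another.
    // Note: in the code above, the inflight/inqueue counters are sized once
    // at construction and are not resized by these messages.
    admin
        .send(Msg::AddBackend("http://localhost:8001".to_string()))
        .await
        .unwrap();
    admin
        .send(Msg::RemoveBackend("http://localhost:8000".to_string()))
        .await
        .unwrap();
}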
