Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene committed Jun 26, 2024
1 parent 93e0a7d commit 230f2a4
Show file tree
Hide file tree
Showing 14 changed files with 429 additions and 423 deletions.
28 changes: 10 additions & 18 deletions backends/v3/src/backend.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
/// Batching and inference logic
use crate::queue::{Entry, Queue};
use text_generation_router::infer::{
GeneratedText, InferError, InferStreamResponse, Backend,
};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{FinishReason, PrefillToken, Token};
use nohash_hasher::IntMap;
use std::sync::{
Arc,
};
use crate::client::{Batch, CachedBatch, Generation, ShardedClient, ClientError, Health};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span};
use async_trait::async_trait;

pub struct BackendV3 {
/// Request queue
Expand Down Expand Up @@ -94,9 +90,7 @@ impl Backend for BackendV3 {
self.batching_task_notifier.notify_one();

// Return stream
Ok(
UnboundedReceiverStream::new(response_rx),
)
Ok(UnboundedReceiverStream::new(response_rx))
}

async fn health(&self, current_health: bool) -> bool {
Expand Down Expand Up @@ -193,10 +187,9 @@ pub(crate) async fn batching_task(
});

// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
prefill(&mut client, new_batch, &mut new_entries)
.instrument(span)
.await;
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
Expand Down Expand Up @@ -480,8 +473,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {

impl From<crate::client::GeneratedText> for GeneratedText {
fn from(value: crate::client::GeneratedText) -> Self {
let v3_finish_reason =
crate::client::FinishReason::try_from(value.finish_reason).unwrap();
let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
let finish_reason = match v3_finish_reason {
crate::client::FinishReason::Length => FinishReason::Length,
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
Expand Down
2 changes: 1 addition & 1 deletion backends/v3/src/client/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
/// Single shard Client
use crate::client::{pb, Chunk};
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
use base64::engine::general_purpose::STANDARD;
Expand All @@ -20,6 +19,7 @@ pub struct Client {

impl Client {
/// Returns a client connected to the given url
#[allow(dead_code)]
pub async fn connect(uri: Uri) -> Result<Self> {
let channel = Channel::builder(uri).connect().await?;

Expand Down
1 change: 0 additions & 1 deletion backends/v3/src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! Text Generation gRPC client library

use async_trait::async_trait;
use thiserror::Error;
use tonic::transport;
Expand Down
13 changes: 7 additions & 6 deletions backends/v3/src/client/sharded_client.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use crate::client::{ClientError, Result};
/// Multi shard Client
use crate::client::{Health, ShardInfo};
use crate::client::{ClientError, Result};

use crate::client::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
use crate::client::client::{DecodeTimings, PrefillTimings};
use crate::client::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use crate::client::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;

#[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client
Expand All @@ -35,6 +35,7 @@ impl ShardedClient {
}

/// Returns a client connected to the given uri
#[allow(dead_code)]
pub async fn connect(uri: Uri) -> Result<Self> {
let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await
Expand Down
13 changes: 7 additions & 6 deletions backends/v3/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
mod block_allocator;
mod queue;
mod backend;
mod block_allocator;
mod client;
mod queue;

use crate::client::{ClientError, ShardedClient};
pub(crate) use backend::BackendV3;
use serde::Serialize;
use thiserror::Error;
use utoipa::ToSchema;
pub(crate) use backend::BackendV3;
use crate::client::{ShardedClient, ClientError};

#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct BackendInfo {
Expand All @@ -31,7 +31,8 @@ pub struct BackendInfo {
}

pub async fn connect_backend(
max_input_tokens: usize, max_total_tokens: usize,
max_input_tokens: usize,
max_total_tokens: usize,
master_shard_uds_path: String,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
Expand Down Expand Up @@ -137,4 +138,4 @@ pub enum V3Error {
Warmup(ClientError),
#[error("Not enough memory to handle `max_total_tokens={0}`")]
NotEnoughMemory(usize),
}
}
Loading

0 comments on commit 230f2a4

Please sign in to comment.