Skip to content

Commit 73c3903

Browse files
FlashCausalLM implem
1 parent 6983ec9 commit 73c3903

File tree

8 files changed

+274
-129
lines changed

8 files changed

+274
-129
lines changed

proto/v3/generate.proto

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -224,11 +224,18 @@ message FilterBatchRequest {
224224
repeated uint64 terminated_request_ids = 3;
225225
}
226226

227+
message TerminatedGeneration {
228+
// Request ID
229+
uint64 id = 1;
230+
// Generated text
231+
GeneratedText generated_text = 2;
232+
}
233+
227234
message FilterBatchResponse {
228235
/// Filtered Batch (cached)
229236
CachedBatch batch = 1;
230237
/// Terminated generations
231-
repeated GeneratedText terminated_generations = 2;
238+
repeated TerminatedGeneration terminated_generations = 2;
232239
}
233240

234241

router/client/src/v3/client.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -92,15 +92,15 @@ impl Client {
9292
batch_id: u64,
9393
kept_requests: Vec<KeptRequest>,
9494
terminated_request_ids: Vec<u64>,
95-
) -> Result<Option<CachedBatch>> {
95+
) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
9696
let request = tonic::Request::new(FilterBatchRequest {
9797
batch_id,
9898
kept_requests,
9999
terminated_request_ids,
100100
})
101101
.inject_context();
102102
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
103-
Ok(filtered_batch.batch)
103+
Ok((filtered_batch.batch, filtered_batch.terminated_generations))
104104
}
105105

106106
/// Warmup on a max size batch

router/client/src/v3/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,6 @@ pub use client::Client;
88
pub use pb::generate::v3::{
99
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
1010
HealthResponse, Image, InfoResponse, Input, InputChunk, KeptRequest,
11-
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
11+
NextTokenChooserParameters, Request, StoppingCriteriaParameters, TerminatedGeneration, Tokens,
1212
};
1313
pub use sharded_client::ShardedClient;

router/client/src/v3/sharded_client.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
use crate::{v3, Health, ShardInfo};
33
use crate::{ClientError, Result};
44

5-
use crate::v3::{Chunk, InfoResponse, Input};
5+
use crate::v3::{Chunk, InfoResponse, Input, TerminatedGeneration};
66
use async_trait::async_trait;
77
use futures::future::join_all;
88
use tonic::transport::Uri;
@@ -86,7 +86,7 @@ impl ShardedClient {
8686
batch_id: u64,
8787
kept_requests: Vec<KeptRequest>,
8888
terminated_request_ids: Vec<u64>,
89-
) -> Result<Option<CachedBatch>> {
89+
) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
9090
let futures: Vec<_> = self
9191
.clients
9292
.iter_mut()

0 commit comments

Comments
 (0)