fix rust and python unit-tests
OlivierDehaene committed Jun 11, 2024
1 parent 73c3903 commit 37266e2
Showing 12 changed files with 288 additions and 112 deletions.
1 change: 0 additions & 1 deletion .github/workflows/trufflehog.yml
@@ -16,4 +16,3 @@ jobs:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main

52 changes: 43 additions & 9 deletions router/src/infer/v3/block_allocator.rs
@@ -1,7 +1,8 @@
use std::sync::{Arc, Mutex};
use std::fmt::Formatter;
use std::sync::{Arc, Mutex, TryLockError};
use thiserror::Error;

#[derive(Debug, Clone)]
#[derive(Clone)]
pub(crate) struct BlockAllocation {
allocated_blocks: Vec<u32>,
allocated_slots: Vec<u32>,
@@ -53,7 +54,19 @@ impl Drop for BlockAllocation {
}
}

#[derive(Debug, Clone)]
impl std::fmt::Debug for BlockAllocation {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("BlockAllocation")
.field("allocated_blocks", &self.allocated_blocks.len())
.field("allocated_slots", &self.allocated_slots.len())
.field("required_blocks", &self.required_blocks)
.field("required_slots", &self.required_slots)
.field("block_allocator", &self.block_allocator)
.finish()
}
}

#[derive(Clone)]
pub(crate) struct BlockAllocator {
free_blocks: Arc<Mutex<Vec<u32>>>,
block_size: u32,
@@ -129,8 +142,7 @@ impl BlockAllocator {
Err(AllocationError::NotEnoughPages)
} else {
let n_free_blocks = free_blocks.len();
let allocated_blocks =
free_blocks.split_off(n_free_blocks - clipped_required_blocks);
let allocated_blocks = free_blocks.split_off(n_free_blocks - clipped_required_blocks);

let allocated_blocks = if repeats != 1 {
let mut allocated_blocks = allocated_blocks.repeat(repeats);
@@ -140,9 +152,8 @@ impl BlockAllocator {
allocated_blocks
};

let mut allocated_slots = Vec::with_capacity(
allocated_blocks.len() * self.block_size as usize * repeats,
);
let mut allocated_slots =
Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats);

let required_slots = (prompt_tokens + decode_tokens) as usize;

@@ -166,7 +177,30 @@ impl BlockAllocator {
}

pub(crate) fn free(&self, blocks: Vec<u32>) {
self.free_blocks.lock().expect("Lock could not be acquired. This is a bug.").extend(blocks)
self.free_blocks
.lock()
.expect("Lock could not be acquired. This is a bug.")
.extend(blocks)
}
}

impl std::fmt::Debug for BlockAllocator {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let mut d = f.debug_struct("BlockAllocator");
d.field("block_size", &self.block_size)
.field("window_size", &self.window_size);
match self.free_blocks.try_lock() {
Ok(guard) => {
d.field("free_blocks", &(*guard).len());
}
Err(TryLockError::Poisoned(err)) => {
d.field("free_blocks", &(**err.get_ref()).len());
}
Err(TryLockError::WouldBlock) => {
d.field("free_blocks", &format_args!("<locked>"));
}
};
d.finish()
}
}

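The new `Debug` implementations print only the lengths of the block and slot vectors and never block on the allocator mutex. Below is a minimal, self-contained sketch of that pattern; the `Allocator` type and field names here are illustrative stand-ins, not the crate's real `BlockAllocator`:

```rust
use std::fmt;
use std::sync::{Arc, Mutex, TryLockError};

// Hypothetical allocator used only to illustrate the Debug-over-Mutex pattern.
#[derive(Clone)]
struct Allocator {
    free_blocks: Arc<Mutex<Vec<u32>>>,
    block_size: u32,
}

impl fmt::Debug for Allocator {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut d = f.debug_struct("Allocator");
        d.field("block_size", &self.block_size);
        // try_lock never blocks: report the length if the lock is free,
        // recover the guard if it is poisoned, and fall back to a marker
        // if another thread currently holds the lock.
        match self.free_blocks.try_lock() {
            Ok(guard) => {
                d.field("free_blocks", &guard.len());
            }
            Err(TryLockError::Poisoned(err)) => {
                d.field("free_blocks", &err.get_ref().len());
            }
            Err(TryLockError::WouldBlock) => {
                d.field("free_blocks", &format_args!("<locked>"));
            }
        }
        d.finish()
    }
}

fn main() {
    let a = Allocator {
        free_blocks: Arc::new(Mutex::new(vec![1, 2, 3])),
        block_size: 16,
    };
    println!("{a:?}"); // Allocator { block_size: 16, free_blocks: 3 }
}
```

Using `try_lock` keeps formatting non-blocking, which matters because `Debug` output can be requested from error or tracing paths while the allocator lock is already held elsewhere.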
10 changes: 6 additions & 4 deletions router/src/infer/v3/queue.rs
@@ -275,7 +277,9 @@ impl State {
if prefill_tokens > prefill_token_budget {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
tracing::debug!(
"Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget}"
);
self.entries.push_front((id, entry));
break;
}
@@ -456,7 +458,7 @@ mod tests {
let entry = Entry {
request: ValidGenerateRequest {
inputs: vec![],
input_length: 0,
input_length: 1,
truncate: 0,
decoder_input_details: false,
parameters: ValidParameters {
@@ -567,7 +569,7 @@ mod tests {

#[tokio::test]
async fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, None, 0, 2);
let mut state = State::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@@ -689,7 +691,7 @@ mod tests {

#[tokio::test]
async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, None, 2, 16);
let queue = Queue::new(true, 1, None, 2, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
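For context on the queue changes: the first hunk is the over-budget check, where batching stops as soon as an entry would exceed the prefill-token budget and that entry is pushed back to the front of the queue, which is also why the test budgets above had to grow. A simplified sketch of that loop follows, with a hypothetical minimal `Entry` type rather than the router's real `Entry`/`State`:

```rust
use std::collections::VecDeque;

// Hypothetical entry type; the real queue tracks much more state per request.
struct Entry {
    id: u64,
    prefill_tokens: u32,
}

// Greedily drain entries into a batch until the prefill-token budget is hit.
// An entry that does not fit is pushed back to the front so it is retried in
// the next batch, mirroring the `push_front` + `break` in the hunk above.
fn next_batch(entries: &mut VecDeque<Entry>, prefill_token_budget: u32) -> Vec<Entry> {
    let mut batch = Vec::new();
    let mut prefill_tokens = 0;
    while let Some(entry) = entries.pop_front() {
        if prefill_tokens + entry.prefill_tokens > prefill_token_budget {
            entries.push_front(entry); // over budget: put it back and stop
            break;
        }
        prefill_tokens += entry.prefill_tokens;
        batch.push(entry);
    }
    batch
}

fn main() {
    let mut q: VecDeque<Entry> = (1..=4)
        .map(|id| Entry { id, prefill_tokens: 10 })
        .collect();
    let batch = next_batch(&mut q, 25);
    assert_eq!(batch.iter().map(|e| e.id).collect::<Vec<_>>(), vec![1, 2]);
    assert_eq!(q.len(), 2); // entries 3 and 4 stay queued for the next batch
}
```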
35 changes: 15 additions & 20 deletions router/src/infer/v3/scheduler.rs
@@ -256,11 +256,7 @@ async fn prefill(
.expect("ID not found in entries. This is a bug.");

// Send intermediate responses
if let Err(_) = send_stream_responses(stream_responses, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
}) {
if send_stream_responses(stream_responses, entry).is_err() {
// Sending failed, remove entry
entries
.remove(&id)
@@ -405,7 +401,7 @@ async fn filter_batch(
.filter_batch(
id,
updated_requests,
terminated_entries.keys().map(|v| *v).collect(),
terminated_entries.keys().copied().collect(),
)
.await
.unwrap()
@@ -460,11 +456,14 @@ fn send_terminated_generations(
};

// Send responses
if let Err(_) = entry.response_tx.send(Ok(response)).map_err(|err| {
let send_result = entry.response_tx.send(Ok(response)).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
}) {
});

if send_result.is_err() {
// The channel is dropped, skip the rest of the messages
continue 'terminated_generations;
}
}
@@ -504,11 +503,7 @@ fn filter_send_ended_generations(
// If the generation has ended for this request, we send the responses to the channel and
// remove the entry to drop it and free its blocks
if finished {
let _ = send_stream_responses(stream_responses, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
});
let _ = send_stream_responses(stream_responses, entry);
// Remove from entries and filter
entries.remove(&id).expect("ID not found in entries. This is a bug.");
return None;
@@ -525,7 +520,11 @@ fn send_stream_responses(
entry: &Entry,
) -> Result<(), Box<SendError<Result<InferStreamResponse, InferError>>>> {
for response in stream_responses {
entry.response_tx.send(Ok(response))?;
entry.response_tx.send(Ok(response)).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
})?;
}
Ok(())
}
@@ -541,7 +540,7 @@ fn filter_send_update_allocations(
) -> (bool, IntMap<u64, Entry>) {
let mut updated = false;

let ids: Vec<u64> = entries.keys().map(|v| *v).collect();
let ids: Vec<u64> = entries.keys().copied().collect();
let mut terminated_entries =
IntMap::with_capacity_and_hasher(entries.len(), BuildNoHashHasher::default());

@@ -581,11 +580,7 @@ fn filter_send_update_allocations(
.expect("ID not found in stream_responses. This is a bug.");

// Send intermediate responses
if let Err(_) = send_stream_responses(stream_response, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
}) {
if send_stream_responses(stream_response, entry).is_err() {
// Sending failed, remove entry
entries
.remove(id)
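The scheduler refactor moves the channel-error logging and the `tgi_request_failure` metric into `send_stream_responses` itself, so every call site shrinks to an `is_err()` check followed by dropping the entry. A rough, self-contained sketch of that shape using a plain tokio unbounded channel and a stand-in response type (not the crate's actual `InferStreamResponse`/`InferError`; assumes the `tokio` crate with the `macros` and `rt-multi-thread` features):

```rust
use tokio::sync::mpsc::{error::SendError, unbounded_channel, UnboundedSender};

// Stand-in for the router's InferStreamResponse; purely illustrative.
#[derive(Debug)]
struct StreamResponse(String);

// The helper owns the "channel closed" reporting, so callers only need is_err().
fn send_stream_responses(
    responses: Vec<StreamResponse>,
    tx: &UnboundedSender<StreamResponse>,
) -> Result<(), SendError<StreamResponse>> {
    for response in responses {
        tx.send(response).map_err(|err| {
            // The real code logs via `tracing` and bumps a request-failure counter here.
            eprintln!("Entry response channel error.");
            err
        })?;
    }
    Ok(())
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = unbounded_channel();
    // Call sites reduce to a plain is_err() check and can drop the entry on failure.
    if send_stream_responses(vec![StreamResponse("token".into())], &tx).is_err() {
        // remove the entry and free its blocks here
    }
    assert_eq!(rx.recv().await.unwrap().0, "token");
}
```

Centralizing the logging in the helper removes the repeated `map_err` closures at each call site and keeps the failure metric consistent wherever a receiver is dropped.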
26 changes: 14 additions & 12 deletions server/tests/models/test_bloom.py
@@ -197,8 +197,10 @@ def test_causal_lm_generate_token_completion_multi(
# Copy stopping_criterias before filtering
stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()

next_batch = next_batch.filter(
[generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
next_batch, _ = next_batch.filter(
default_bloom,
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
[],
)

for _ in range(
@@ -307,15 +309,13 @@ def test_batch_concatenate(
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)

next_batch = next_batch.filter(
next_batch, _ = next_batch.filter(
default_bloom,
[
generate_pb2.UpdatedRequest(
id=next_batch.requests[0].id, blocks=[], slots=[]
),
generate_pb2.UpdatedRequest(
id=next_batch.requests[1].id, blocks=[], slots=[]
),
]
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
],
[],
)

for _ in range(
@@ -339,8 +339,10 @@ def test_batch_concatenate(
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)

next_batch = next_batch.filter(
[generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])]
next_batch, _ = next_batch.filter(
default_bloom,
[generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])],
[],
)

for _ in range(
33 changes: 15 additions & 18 deletions server/tests/models/test_causal_lm.py
@@ -198,8 +198,10 @@ def test_causal_lm_generate_token_completion_multi(
default_multi_requests_causal_lm_batch.stopping_criterias.copy()
)

next_batch = next_batch.filter(
[generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
next_batch, _ = next_batch.filter(
default_causal_lm,
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
[],
)

for _ in range(
@@ -307,15 +309,13 @@ def test_batch_concatenate(
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)

next_batch = next_batch.filter(
next_batch, _ = next_batch.filter(
default_causal_lm,
[
generate_pb2.UpdatedRequest(
id=next_batch.requests[0].id, blocks=[], slots=[]
),
generate_pb2.UpdatedRequest(
id=next_batch.requests[1].id, blocks=[], slots=[]
),
]
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
],
[],
)

for _ in range(
Expand All @@ -337,15 +337,12 @@ def test_batch_concatenate(
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
)

next_batch = next_batch.filter(
next_batch, _ = next_batch.filter(
default_causal_lm,
[
generate_pb2.UpdatedRequest(
id=next_batch.requests[0].id, blocks=[], slots=[]
),
generate_pb2.UpdatedRequest(
id=next_batch.requests[1].id, blocks=[], slots=[]
),
]
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
],
[],
)

for _ in range(
26 changes: 14 additions & 12 deletions server/tests/models/test_seq2seq_lm.py
@@ -206,8 +206,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
)
assert generations[1].generated_text.generated_tokens == 5

next_batch = next_batch.filter(
[generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
next_batch, _ = next_batch.filter(
default_seq2seq_lm,
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
[],
)

generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
@@ -341,15 +343,13 @@ def test_batch_concatenate(
)
assert generations[2].generated_text.generated_tokens == 5

next_batch = next_batch.filter(
next_batch, _ = next_batch.filter(
default_seq2seq_lm,
[
generate_pb2.UpdatedRequest(
id=next_batch.requests[0].id, blocks=[], slots=[]
),
generate_pb2.UpdatedRequest(
id=next_batch.requests[1].id, blocks=[], slots=[]
),
]
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
],
[],
)

generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
@@ -360,8 +360,10 @@ def test_batch_concatenate(
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
assert generations[0].generated_text.generated_tokens == 7

next_batch = next_batch.filter(
[generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])]
next_batch, _ = next_batch.filter(
default_seq2seq_lm,
[generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])],
[],
)

generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
