From 9b6a7061aece639892cd7a75f0de2ae4ebae001b Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Mon, 9 Dec 2024 23:50:00 +0100 Subject: [PATCH 01/16] feat: Working profiles with chat turns. --- Cargo.toml | 3 +- src/app.rs | 3 +- src/benchmark.rs | 2 + src/lib.rs | 24 ++- src/main.rs | 77 ++++++-- src/requests.rs | 500 +++++++++++++++++++++++++++++++++-------------- src/results.rs | 24 ++- src/scheduler.rs | 14 +- 8 files changed, 468 insertions(+), 179 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 618a0d5..ff8cd25 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ serde_with = "3.9.0" sysinfo = "0.31.4" mockito = "1.5.0" tabled = "=0.14" +uuid = { version = "1.11.0", features = ["v4", "fast-rng"] } [build-dependencies] -vergen-gitcl = { version = "1.0.1" } \ No newline at end of file +vergen-gitcl = { version = "1.0.1" } diff --git a/src/app.rs b/src/app.rs index 7d9050d..7f6b5aa 100644 --- a/src/app.rs +++ b/src/app.rs @@ -328,7 +328,8 @@ impl Widget for &App { Some(_) => "Manual".to_string(), }; let config_text = Text::from(vec![Line::from(vec![ - format!("Benchmark: {kind} | Max VUs: {max_vus} | Duration: {duration} sec | Rates: {rates} | Warmup: {warmup} sec", + format!("Profile: {profile} | Benchmark: {kind} | Max VUs: {max_vus} | Duration: {duration} sec | Rates: {rates} | Warmup: {warmup} sec", + profile = self.benchmark_config.profile.clone().unwrap_or("N/A".to_string()), kind = self.benchmark_config.benchmark_kind, max_vus = self.benchmark_config.max_vus, duration = self.benchmark_config.duration.as_secs_f64(), diff --git a/src/benchmark.rs b/src/benchmark.rs index bff8e2b..7aa1779 100644 --- a/src/benchmark.rs +++ b/src/benchmark.rs @@ -71,6 +71,7 @@ pub struct BenchmarkConfig { pub prompt_options: Option, pub decode_options: Option, pub tokenizer: String, + pub profile: Option, #[serde(rename = "meta")] pub extra_metadata: Option>, } @@ -439,6 +440,7 @@ mod tests { prompt_options: None, decode_options: None, tokenizer: "gpt2".to_string(), + profile: None, extra_metadata: None, }, backend, diff --git a/src/lib.rs b/src/lib.rs index f410004..742d464 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ use std::sync::Arc; pub use crate::app::run_console; pub use crate::benchmark::{BenchmarkConfig, BenchmarkKind}; +pub use crate::profiles::apply_profile; use crate::benchmark::{Event, MessageEvent}; use crate::requests::OpenAITextGenerationBackend; pub use crate::requests::TokenizeOptions; @@ -28,10 +29,12 @@ mod results; mod scheduler; mod table; mod writers; +mod profiles; pub struct RunConfiguration { pub url: String, pub tokenizer_name: String, + pub profile: Option, pub max_vus: u64, pub duration: std::time::Duration, pub rates: Option>, @@ -48,10 +51,26 @@ pub struct RunConfiguration { pub model_name: String, } -pub async fn run(run_config: RunConfiguration, stop_sender: Sender<()>) -> anyhow::Result<()> { +pub async fn run(mut run_config: RunConfiguration, stop_sender: Sender<()>) -> anyhow::Result<()> { info!("Starting benchmark"); // set process system limits sysinfo::set_open_files_limit(0); + // apply profile if needed + run_config = match run_config.profile.clone() { + None => run_config, + Some(profile) => { + match apply_profile(profile.as_str(), run_config) { + Ok(config) => { + info!("Profile applied: {}", profile); + config + }, + Err(e) => { + error!("Failed to apply profile: {:?}", e); + return Err(e); + } + } + } + }; // initialize tokenizer let params = FromPretrainedParameters { token: run_config.hf_token.clone(), @@ -88,6 +107,7 @@ pub async fn run(run_config: RunConfiguration, stop_sender: Sender<()>) -> anyho prompt_options: run_config.prompt_options.clone(), decode_options: run_config.decode_options.clone(), tokenizer: run_config.tokenizer_name.clone(), + profile: run_config.profile.clone(), extra_metadata: run_config.extra_metadata.clone(), }; config.validate()?; @@ -144,7 +164,7 @@ pub async fn run(run_config: RunConfiguration, stop_sender: Sender<()>) -> anyho run_config.dataset_file, run_config.hf_token.clone(), ) - .expect("Can't download dataset"); + .expect("Can't download dataset"); let requests = requests::ConversationTextRequestGenerator::load( filepath, run_config.tokenizer_name.clone(), diff --git a/src/main.rs b/src/main.rs index 61f74f9..393bfc5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,10 +19,24 @@ struct Args { model_name: Option, /// The maximum number of virtual users to use - #[clap(default_value = "128", short, long, env)] + #[clap( + default_value = "128", + short, + long, + env, + group = "group_manual", + conflicts_with = "group_profile" + )] max_vus: u64, /// The duration of each benchmark step - #[clap(default_value = "120s", short, long, env)] + #[clap( + default_value = "120s", + short, + long, + env, + group = "group_manual", + conflicts_with = "group_profile" + )] #[arg(value_parser = parse_duration)] duration: Duration, /// A list of rates of requests to send per second (only valid for the ConstantArrivalRate benchmark). @@ -32,12 +46,28 @@ struct Args { /// The rates will be linearly spaced up to the detected maximum rate #[clap(default_value = "10", long, env)] num_rates: u64, - + /// A benchmark profile to use + #[clap(long, env, group = "group_profile", conflicts_with = "group_manual")] + profile: Option, /// The kind of benchmark to run (throughput, sweep, optimum) - #[clap(default_value = "sweep", short, long, env)] + #[clap( + default_value = "sweep", + short, + long, + env, + group = "group_manual", + conflicts_with = "group_profile" + )] benchmark_kind: String, /// The duration of the prewarm step ran before the benchmark to warm up the backend (JIT, caches, etc.) - #[clap(default_value = "30s", short, long, env)] + #[clap( + default_value = "30s", + short, + long, + env, + group = "group_manual", + conflicts_with = "group_profile" + )] #[arg(value_parser = parse_duration)] warmup: Duration, /// The URL of the backend to benchmark. Must be compatible with OpenAI Message API @@ -57,7 +87,13 @@ struct Args { /// * variance: variance in the number of prompt tokens /// /// Example: num_tokens=200,max_tokens=210,min_tokens=190,variance=10 - #[clap(long, env, value_parser(parse_tokenizer_options))] + #[clap( + long, + env, + value_parser(parse_tokenizer_options), + group = "group_manual", + conflicts_with = "group_profile" + )] prompt_options: Option, /// Constraints for the generated text. /// We sample the number of tokens to generate from a normal distribution. @@ -68,13 +104,31 @@ struct Args { /// * variance: variance in the number of generated tokens /// /// Example: num_tokens=200,max_tokens=210,min_tokens=190,variance=10 - #[clap(long, env, value_parser(parse_tokenizer_options))] + #[clap( + long, + env, + value_parser(parse_tokenizer_options), + group = "group_manual", + conflicts_with = "group_profile" + )] decode_options: Option, /// Hugging Face dataset to use for prompt generation - #[clap(default_value = "hlarcher/share_gpt_small", long, env)] + #[clap( + default_value = "hlarcher/share_gpt_small", + long, + env, + group = "group_manual", + conflicts_with = "group_profile" + )] dataset: String, /// File to use in the Dataset - #[clap(default_value = "share_gpt_filtered_small.json", long, env)] + #[clap( + default_value = "share_gpt_filtered_small.json", + long, + env, + group = "group_manual", + conflicts_with = "group_profile" + )] dataset_file: String, /// Extra metadata to include in the benchmark results file, comma-separated key-value pairs. /// It can be, for example, used to include information about the configuration of the @@ -133,8 +187,8 @@ fn parse_tokenizer_options(s: &str) -> Result { } if tokenizer_options.num_tokens.is_some() && (tokenizer_options.num_tokens.unwrap() == 0 - || tokenizer_options.min_tokens == 0 - || tokenizer_options.max_tokens == 0) + || tokenizer_options.min_tokens == 0 + || tokenizer_options.max_tokens == 0) { return Err(Error::new(InvalidValue)); } @@ -181,6 +235,7 @@ async fn main() { .unwrap_or(args.tokenizer_name.clone()); let run_config = RunConfiguration { url: args.url.clone(), + profile: args.profile.clone(), tokenizer_name: args.tokenizer_name.clone(), max_vus: args.max_vus, duration: args.duration, diff --git a/src/requests.rs b/src/requests.rs index 23b5418..f423f94 100644 --- a/src/requests.rs +++ b/src/requests.rs @@ -1,3 +1,5 @@ +use std::cmp::Ordering; +use std::collections::{BinaryHeap, HashMap}; use async_trait::async_trait; use futures_util::StreamExt; use hf_hub::api::sync::ApiBuilder; @@ -10,19 +12,19 @@ use reqwest_eventsource::{Error, Event, EventSource}; use serde::{Deserialize, Serialize}; use std::fmt::Display; use std::path::PathBuf; -use std::sync::atomic::AtomicI64; use std::sync::{Arc, Mutex}; use std::time; use tokenizers::{FromPretrainedParameters, Tokenizer}; -use tokio::sync::mpsc::Sender; -use tokio::time::sleep; +use tokio::sync::mpsc::{Sender}; +use tokio::time::{sleep, Instant}; +use uuid::{Uuid}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq)] pub struct TextGenerationRequest { + pub id: Option, pub prompt: String, - pub num_prompt_tokens: u64, // this includes the system prompt if present + pub num_prompt_tokens: u64, pub num_decode_tokens: Option, - pub system_prompt: Option, } #[async_trait] @@ -123,23 +125,11 @@ impl TextGenerationBackend for OpenAITextGenerationBackend { sender: Sender, ) { let url = format!("{base_url}/v1/chat/completions", base_url = self.base_url); - let mut aggregated_response = TextGenerationAggregatedResponse::default(); - let messages = match &request.system_prompt { - None => vec![OpenAITextGenerationMessage { - role: "user".to_string(), - content: request.prompt.clone(), - }], - Some(system_prompt) => vec![ - OpenAITextGenerationMessage { - role: "system".to_string(), - content: system_prompt.clone(), - }, - OpenAITextGenerationMessage { - role: "user".to_string(), - content: request.prompt.clone(), - }, - ], - }; + let mut aggregated_response = TextGenerationAggregatedResponse::new(request.clone()); + let messages = vec![OpenAITextGenerationMessage { + role: "user".to_string(), + content: request.prompt.clone(), + }]; let body = OpenAITextGenerationRequest { model: self.model_name.clone(), messages, @@ -158,7 +148,7 @@ impl TextGenerationBackend for OpenAITextGenerationBackend { .json(&serde_json::json!(body)) .timeout(self.timeout); // start timer - aggregated_response.start(request.num_prompt_tokens); + aggregated_response.start(); let mut es = EventSource::new(req).unwrap(); let mut final_response = "".to_string(); while let Some(event) = es.next().await { @@ -287,8 +277,8 @@ impl TextGenerationBackend for DummyTextGenerationBackend { request: Arc, sender: Sender, ) { - let mut response = TextGenerationAggregatedResponse::default(); - response.start(request.num_prompt_tokens); + let mut response = TextGenerationAggregatedResponse::new(request.clone()); + response.start(); let num_tokens = request.num_decode_tokens.unwrap_or(10); let time_per_token = self .time_to_generate @@ -308,13 +298,10 @@ impl TextGenerationBackend for DummyTextGenerationBackend { pub trait TextRequestGenerator: Sync { fn generate_request(&mut self) -> TextGenerationRequest; + /// callback can be used by generators to add new requests to the queue based on the response (e.g. for multi-turn conversation generation) + fn callback(&mut self, request: Arc, response: &str); } -#[derive(Clone)] -pub struct ConversationTextRequestGenerator { - pub requests: Vec, - current_index: Arc, -} #[derive(Deserialize, Serialize, Clone)] pub struct Conversation { @@ -366,7 +353,46 @@ impl Display for TokenizeOptions { } } +#[derive(Clone, Eq, PartialEq)] +pub struct ConversationTurnRequest { + id: Uuid, + priority: u64, + tie: Instant, + request: TextGenerationRequest, +} + +impl Ord for ConversationTurnRequest { + // order by increasing priority and decreasing tie-breaking + // this way, we can pop the item with the highest priority and oldest tie-breaking + fn cmp(&self, other: &Self) -> Ordering { + self.priority.cmp(&other.priority).then_with(|| self.tie.cmp(&other.tie).reverse()) + } +} + +impl PartialOrd for ConversationTurnRequest { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + + +#[derive(Clone)] +pub struct ConversationTextRequestGenerator { + pub requests: HashMap, + pub queue: BinaryHeap, + pub tokenizer: Arc, +} + impl ConversationTextRequestGenerator { + /// Load a conversation dataset from a JSON file + /// The JSON file should be an array of objects, each object representing a conversation entry + /// Each conversation entry should have an `id` field and a `conversations` field + /// The `conversations` field should be an array of objects, each object representing a turn in the conversation + /// Each turn should have a `role` field and a `content` field + /// The `role` field should be either "user" or "system" + /// The `content` field should be the text of the turn + /// All conversation turns are tokenized and converted into `TextGenerationRequest`. The `id` field is used to link turns in the conversation, + /// so that each `TextGenerationRequest` has a reference to the next turn in the conversation. pub fn load( filepath: PathBuf, tokenizer: String, @@ -389,7 +415,7 @@ impl ConversationTextRequestGenerator { let input = std::fs::read_to_string(&filepath)?; let data: Vec = serde_json::from_str(&input).expect("Unable to parse input file. Check that it is valid JSON and matches the expected format."); // generate requests - let requests: Arc>> = Arc::from(Mutex::from(Vec::new())); + let requests: Arc>> = Arc::from(Mutex::from(HashMap::new())); info!( "Generating requests from {filepath}", filepath = filepath.display().to_string() @@ -404,89 +430,87 @@ impl ConversationTextRequestGenerator { if entry.conversations.is_empty() { continue; } - let system_prompt = entry + let mut turn_idx = 0; + let ids = (0..entry.conversations.len()).map(|_| Uuid::new_v4()).collect::>(); + let filtered_conversations = entry .conversations .iter() - .find(|c| c.role == "system") - .map(|c| c.content.clone()); - let system_prompt_tokens = match system_prompt { - Some(ref prompt) => { - let (_, num_tokens) = match tokenize_prompt( - prompt.clone(), - tokenizer.clone(), - &TokenizeOptions::default(), - ) { - Ok((prompt, num_tokens)) => (prompt, num_tokens), - Err(e) => { - debug!("Error tokenizing system prompt: {e}"); - return; - } - }; - num_tokens - } - None => 0, - }; - entry - .conversations - .iter() - .filter(|c| c.role == "user") - .for_each(|c| { - let prompt = c.content.clone(); - let num_decode_tokens = decode_tokenize_opts.clone().map_or_else( - || None, - |opts| { - opts.num_tokens.map(|num_tokens| { - sample_num_tokens( - num_tokens, - opts.min_tokens, - opts.max_tokens, - opts.variance, - ) - }) - }, - ); - match &prompt_tokenize_opts { - None => { - let (_, num_tokens) = match tokenize_prompt( - prompt.clone(), - tokenizer.clone(), - &TokenizeOptions::default(), - ) { - Ok((prompt, num_tokens)) => (prompt, num_tokens), - Err(e) => { - debug!("Error tokenizing prompt: {e}"); - return; - } - }; - requests.lock().unwrap().push(TextGenerationRequest { + .filter(|c| c.role == "user" || c.role == "system") + .collect::>(); + for c in &filtered_conversations { + let prompt = c.content.clone(); + let num_decode_tokens = decode_tokenize_opts.clone().map_or_else( + || None, + |opts| { + opts.num_tokens.map(|num_tokens| { + sample_num_tokens( + num_tokens, + opts.min_tokens, + opts.max_tokens, + opts.variance, + ) + }) + }, + ); + let next_id = if turn_idx == entry.conversations.len() - 1 { + None + } else { + Some(ids[turn_idx + 1].clone()) // link to next turn in the conversation + }; + debug!("Prompt: {prompt}", prompt = prompt); + match &prompt_tokenize_opts { + None => { + let (_, num_tokens) = match tokenize_prompt( + prompt.clone(), + tokenizer.clone(), + &TokenizeOptions::default(), + ) { + Ok((prompt, num_tokens)) => (prompt, num_tokens), + Err(e) => { + debug!("Error tokenizing prompt: {e}"); + return; + } + }; + requests.lock().unwrap().insert(ids[turn_idx], ConversationTurnRequest { + id: ids[turn_idx], + priority: turn_idx as u64, + tie: Instant::now(), + request: TextGenerationRequest { + id: next_id, prompt, - num_prompt_tokens: num_tokens + system_prompt_tokens, + num_prompt_tokens: num_tokens, num_decode_tokens, - system_prompt: system_prompt.clone(), - }); - } - Some(options) => { - // compute number of tokens to generate using a Gaussian distribution - let (sampled_prompt, prompt_tokens) = match tokenize_prompt( - prompt.clone(), - tokenizer.clone(), - options, - ) { - Ok(prompt) => prompt, - Err(e) => { - debug!("Error tokenizing prompt: {e}"); - return; - } - }; - requests.lock().unwrap().push(TextGenerationRequest { + }, + }); + } + Some(options) => { + // compute number of tokens to generate using a Gaussian distribution + let (sampled_prompt, prompt_tokens) = match tokenize_prompt( + prompt.clone(), + tokenizer.clone(), + options, + ) { + Ok(prompt) => prompt, + Err(e) => { + debug!("Error tokenizing prompt: {e}"); + return; + } + }; + requests.lock().unwrap().insert(ids[turn_idx], ConversationTurnRequest { + id: ids[turn_idx], + tie: Instant::now(), + priority: turn_idx as u64, + request: TextGenerationRequest { + id: next_id, prompt: sampled_prompt, - num_prompt_tokens: prompt_tokens + system_prompt_tokens, + num_prompt_tokens: prompt_tokens, num_decode_tokens, - system_prompt: system_prompt.clone(), - }); - } + }, + }); } - }); + } + turn_idx += 1; + }; // TODO: check that we have enough requests } }); @@ -495,9 +519,12 @@ impl ConversationTextRequestGenerator { "Generated {num_requests} requests", num_requests = requests.len() ); + // create the queue from the hashmap. Only queue first turns in the conversation + let queue = BinaryHeap::from(requests.values().cloned().filter(|item| item.priority == 0).collect::>()); Ok(Self { - current_index: Arc::from(AtomicI64::new(0)), - requests: requests.to_vec(), + requests: requests.clone(), + tokenizer, + queue, }) } @@ -540,14 +567,68 @@ fn entry_splitter( impl TextRequestGenerator for ConversationTextRequestGenerator { fn generate_request(&mut self) -> TextGenerationRequest { - let idx = self - .current_index - .fetch_add(1, std::sync::atomic::Ordering::SeqCst); - if idx >= (self.requests.len() - 1) as i64 { - self.current_index - .store(0, std::sync::atomic::Ordering::SeqCst); + let item = self.queue.pop().expect("Queue is empty"); + // add the item back to the end of the queue if it is a first turn in the conversation + if item.priority == 0 { + let mut cloned_item = item.clone(); + cloned_item.tie = Instant::now(); // update the tie-breaking for intra-priority sorting + self.queue.push(cloned_item); } - self.requests[idx as usize].clone() + item.request + } + + /// Use callback to add a new chat turn to the queue. + /// The turn is generated from the `TextGenerationRequest`, using the `id` field to link it to + /// the next turn in the conversation. + /// Those turns must be scheduled as soon as possible so that we may benefit from + /// KV cache hits. The `priority` field is used to move the turn to the front of the queue. + fn callback(&mut self, request: Arc, response: &str) { + // retrieve current turn id + let id = match request.id { + None => { + return; + } + Some(id) => { + id + } + }; + // retrieve next turn from id + let next_request = match self.requests.get(&id) { + None => { + return; + } + Some(request) => { + request + } + }; + // create a new turn with the prompt concatenated with the response and next turn's prompt + // and add the next turn id to the new turn + let new_prompt = request.prompt.clone() + "\n" + response + "\n" + next_request.request.prompt.as_str(); + // tokenize the prompt + let (prompt, num_tokens) = match tokenize_prompt( + new_prompt.to_string(), + self.tokenizer.clone(), + &TokenizeOptions::default(), + ) { + Ok((prompt, num_tokens)) => (prompt, num_tokens), + Err(_) => { + return; + } + }; + let next_id = next_request.request.id; + let turn = ConversationTurnRequest { + id, + priority: 100, // move to the front of the queue + tie: Instant::now(), // use the current time as tie-breaking (older turns have higher priority) + request: TextGenerationRequest { + id: next_id, + prompt, + num_prompt_tokens: num_tokens, + num_decode_tokens: request.num_decode_tokens, // decode tokens do not change between turns + }, + }; + //debug!("Adding new turn to queue: {turn}", turn = turn.request.prompt); + self.queue.push(turn); } } @@ -568,12 +649,13 @@ impl Default for DummyTextRequestGenerator { impl TextRequestGenerator for DummyTextRequestGenerator { fn generate_request(&mut self) -> TextGenerationRequest { TextGenerationRequest { + id: None, prompt: "Hello, world!".to_string(), num_prompt_tokens: 2, num_decode_tokens: Some(10), - system_prompt: None, } } + fn callback(&mut self, _request: Arc, _response: &str) {} } fn tokenize_prompt( @@ -624,45 +706,45 @@ pub struct TextGenerationAggregatedResponse { pub start_time: Option, pub end_time: Option, pub num_generated_tokens: u64, - pub num_prompt_tokens: u64, - pub times_to_tokens: Vec, + pub times_to_tokens: Vec, last_received_token_time: tokio::time::Instant, pub failed: bool, pub ended: bool, + pub request: Option>, + pub response: Option, } -impl Default for TextGenerationAggregatedResponse { - fn default() -> Self { +impl TextGenerationAggregatedResponse { + pub fn new(request: Arc) -> Self { Self { start_time: None, end_time: None, num_generated_tokens: 0, - num_prompt_tokens: 0, times_to_tokens: Vec::new(), last_received_token_time: tokio::time::Instant::now(), failed: false, ended: false, + request: Some(request), + response: None, } } -} -impl TextGenerationAggregatedResponse { pub fn new_as_ended() -> Self { Self { start_time: None, end_time: None, num_generated_tokens: 0, - num_prompt_tokens: 0, times_to_tokens: Vec::new(), last_received_token_time: tokio::time::Instant::now(), failed: false, ended: true, + request: None, + response: None, } } - fn start(&mut self, num_prompt_tokens: u64) { + fn start(&mut self) { self.start_time = Some(tokio::time::Instant::now()); self.last_received_token_time = tokio::time::Instant::now(); - self.num_prompt_tokens = num_prompt_tokens; } fn stop(&mut self) { @@ -745,12 +827,12 @@ mod tests { tokenizer, time::Duration::from_secs(10), ) - .unwrap(); + .unwrap(); let request = TextGenerationRequest { + id: None, prompt: "Hello, world!".to_string(), num_prompt_tokens: 2, num_decode_tokens: Some(10), - system_prompt: None, }; let (tx, mut rx) = tokio::sync::mpsc::channel(1); let request = Arc::new(request); @@ -805,12 +887,12 @@ mod tests { tokenizer, time::Duration::from_secs(10), ) - .unwrap(); + .unwrap(); let request = TextGenerationRequest { + id: None, prompt: "Hello, world!".to_string(), num_prompt_tokens: 2, num_decode_tokens: Some(16), - system_prompt: None, }; let (tx, mut rx) = tokio::sync::mpsc::channel(1); let request = Arc::new(request); @@ -857,7 +939,7 @@ mod tests { assert!( inter_token_latency_avg > expected_inter_token_latency_avg && inter_token_latency_avg - < expected_inter_token_latency_avg + inter_token_latency_overhead, + < expected_inter_token_latency_avg + inter_token_latency_overhead, "inter_token_latency_avg: {:?} < {:?} < {:?}", expected_inter_token_latency_avg, inter_token_latency_avg, @@ -890,12 +972,12 @@ mod tests { tokenizer, time::Duration::from_secs(10), ) - .unwrap(); + .unwrap(); let request = TextGenerationRequest { + id: None, prompt: "Hello, world!".to_string(), num_prompt_tokens: 2, num_decode_tokens: Some(16), - system_prompt: None, }; let (tx, mut rx) = tokio::sync::mpsc::channel(1); let request = Arc::new(request); @@ -936,12 +1018,12 @@ mod tests { tokenizer, time::Duration::from_secs(10), ) - .unwrap(); + .unwrap(); let request = TextGenerationRequest { + id: None, prompt: "Hello, world!".to_string(), num_prompt_tokens: 2, num_decode_tokens: Some(16), - system_prompt: None, }; let (tx, mut rx) = tokio::sync::mpsc::channel(1); let request = Arc::new(request); @@ -982,12 +1064,12 @@ mod tests { tokenizer, time::Duration::from_secs(10), ) - .unwrap(); + .unwrap(); let request = TextGenerationRequest { + id: None, prompt: "Hello, world!".to_string(), num_prompt_tokens: 2, num_decode_tokens: Some(16), - system_prompt: None, }; let (tx, mut rx) = tokio::sync::mpsc::channel(1); let request = Arc::new(request); @@ -1019,7 +1101,7 @@ mod tests { .with_chunked_body(|w| { w.write_all(b"data: {\"choices\": [{\"message\": null, \"finish_reason\": null, \"delta\": {\"content\": \"Hello, world!\"}}]}\n\n").unwrap(); // sleep for 5s - sleep(std::time::Duration::from_secs(5)); + sleep(Duration::from_secs(5)); w.write_all(b"data: [DONE]\n\n") }) .create_async().await; @@ -1030,14 +1112,14 @@ mod tests { url, "gpt2".to_string(), tokenizer, - time::Duration::from_secs(1), + Duration::from_secs(1), ) - .unwrap(); + .unwrap(); let request = TextGenerationRequest { + id: None, prompt: "Hello, world!".to_string(), num_prompt_tokens: 2, num_decode_tokens: Some(16), - system_prompt: None, }; let (tx, mut rx) = tokio::sync::mpsc::channel(1); let request = Arc::new(request); @@ -1072,8 +1154,8 @@ mod tests { Some(decode_tokenize_opts), hf_token, ) - .unwrap(); - assert_eq!(generator.requests.len(), 17005); + .unwrap(); + assert_eq!(generator.requests.len(), 17016); } /// Test that conversations are bounded by the min/max number of tokens @@ -1096,17 +1178,17 @@ mod tests { Some(decode_tokenize_opts), hf_token, ) - .unwrap(); + .unwrap(); let min_tokens = generator .requests .iter() - .map(|r| r.num_prompt_tokens) + .map(|r| r.1.request.num_prompt_tokens) .min() .unwrap(); let max_tokens = generator .requests .iter() - .map(|r| r.num_prompt_tokens) + .map(|r| r.1.request.num_prompt_tokens) .max() .unwrap(); assert!(min_tokens >= 4, "Min tokens: {}", min_tokens); @@ -1133,9 +1215,127 @@ mod tests { Some(decode_tokenize_opts), hf_token, ) - .unwrap(); + .unwrap(); for r in generator.requests.iter() { - assert_eq!(r.num_prompt_tokens, 200); + assert_eq!(r.1.request.num_prompt_tokens, 200); } } + + /// Test that multi-turn conversations are correctly loaded + /// The test data contains 2 conversations first with 6 turns and the second with 1 turn and a system prompt + #[tokio::test] + async fn test_load_conversations_multi_turn() { + let filepath = PathBuf::from("test_data/chat.json"); + let tokenizer = "gpt2".to_string(); + let prompt_tokenize_opts = TokenizeOptions { + num_tokens: None, + min_tokens: 1, + max_tokens: 200, + variance: 0, + }; + let hf_token = None; + let decode_tokenize_opts = TokenizeOptions::default(); + let generator = ConversationTextRequestGenerator::load( + filepath, + tokenizer, + Some(prompt_tokenize_opts), + Some(decode_tokenize_opts), + hf_token, + ) + .unwrap(); + let turns = generator.requests.into_iter().map(|r| r.1.clone()).collect::>(); + assert_eq!(turns.len(), 8); + let first_turns = turns.clone().into_iter().filter(|t| t.priority == 0); + // we expect to have 2 None values for the first turn in each conversation + assert_eq!(first_turns.count(), 2); + let first_conversation = turns.clone().into_iter().filter(|t| t.request.prompt == "Summarize the main ideas of Jeff Walker's Product Launch Formula into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients...").collect::>(); + // rebuild the conversation from first turn + let mut conversation = Vec::new(); + let mut current_turn = first_conversation[0].clone(); + loop { + conversation.push(current_turn.clone()); + match turns.iter().find(|t| t.id == current_turn.request.id.unwrap_or_default()) { + Some(t) => current_turn = t.clone(), + None => break, + } + } + assert_eq!(conversation.len(), 6); + let got = conversation.iter().map(|t| t.request.prompt.clone()).collect::>().join("\n"); + let expect = "Summarize the main ideas of Jeff Walker's Product Launch Formula into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients...\nSummarize the main ideas of Brendon Burchard's Experts Academy into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients...\nWhat are the mental triggers in Jeff Walker's Product Launch Formula and \"Launch\" book?\nWrite a summary of why scarcity and urgency are the strongest mental triggers and have been the driving force behind many of our best performing campaigns over the last 8 years.\nSummarize Russell Brunson's Perfect Webinar Script...\nSummarize the 6 human needs as Tony Robbins explains..."; + assert_eq!(expect, got); + } + + /// Test that only first turns of multi-turn conversations are queued + #[tokio::test] + async fn test_conversation_queue() { + let filepath = PathBuf::from("test_data/chat.json"); + let tokenizer = "gpt2".to_string(); + let prompt_tokenize_opts = TokenizeOptions { + num_tokens: None, + min_tokens: 1, + max_tokens: 200, + variance: 0, + }; + let hf_token = None; + let decode_tokenize_opts = TokenizeOptions::default(); + let mut generator = ConversationTextRequestGenerator::load( + filepath, + tokenizer, + Some(prompt_tokenize_opts), + Some(decode_tokenize_opts), + hf_token, + ) + .unwrap(); + for i in 0..20 { + let req = generator.generate_request(); + if i % 2 == 0 { // first turn of the first conversation + assert_eq!(req.prompt, "Summarize the main ideas of Jeff Walker's Product Launch Formula into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients..."); + } + if i % 2 == 1 { // first turn of the second conversation + assert_eq!(req.prompt, "You are a helpful assistant."); + } + } + } + + /// Test that multi-turn conversations are correctly queued when we add responses + #[tokio::test] + async fn test_conversation_turns_queue() { + let filepath = PathBuf::from("test_data/chat.json"); + let tokenizer = "gpt2".to_string(); + let prompt_tokenize_opts = TokenizeOptions { + num_tokens: None, + min_tokens: 1, + max_tokens: 200, + variance: 0, + }; + let hf_token = None; + let decode_tokenize_opts = TokenizeOptions::default(); + let mut generator = ConversationTextRequestGenerator::load( + filepath, + tokenizer, + Some(prompt_tokenize_opts), + Some(decode_tokenize_opts), + hf_token, + ) + .unwrap(); + // generate the first user turn + let req = generator.generate_request(); + let response = "This is my response".to_string(); + generator.callback(Arc::from(req), &response); + // now try to generate the next user turn + let req = generator.generate_request(); + // we expect to have all the turns concatenated into the prompt + assert_eq!(req.prompt, "Summarize the main ideas of Jeff Walker's Product Launch Formula into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients...\nThis is my response\nSummarize the main ideas of Brendon Burchard's Experts Academy into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients..."); + assert_eq!(req.num_prompt_tokens, 76); + // now add a response to the second turn + let response = "This is my second response".to_string(); + generator.callback(Arc::from(req), &response); + // now try to generate the next user turn + let req = generator.generate_request(); + assert_eq!(req.prompt, "Summarize the main ideas of Jeff Walker's Product Launch Formula into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients...\nThis is my response\nSummarize the main ideas of Brendon Burchard's Experts Academy into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients...\nThis is my second response\nWhat are the mental triggers in Jeff Walker's Product Launch Formula and \"Launch\" book?"); + // check that next turn is a first turn + let req = generator.generate_request(); + assert_eq!(req.prompt, "You are a helpful assistant."); + } } + diff --git a/src/results.rs b/src/results.rs index 57ea1ac..3b34b8f 100644 --- a/src/results.rs +++ b/src/results.rs @@ -91,14 +91,14 @@ impl BenchmarkResults { pub fn total_tokens_sent(&self) -> u64 { self.get_successful_responses() .iter() - .map(|response| response.num_prompt_tokens) + .map(|response| response.request.clone().unwrap().num_prompt_tokens) .sum() } pub fn total_prompt_tokens(&self) -> u64 { self.get_successful_responses() .iter() - .map(|response| response.num_prompt_tokens) + .map(|response| response.request.clone().unwrap().num_prompt_tokens) .sum() } @@ -352,14 +352,21 @@ impl BenchmarkReport { #[cfg(test)] mod test { + use std::sync::Arc; + use crate::requests::TextGenerationRequest; use super::*; #[test] fn test_time_to_first_token_percentile() { - let mut response1 = TextGenerationAggregatedResponse::default(); + let request = Arc::from(TextGenerationRequest { + id: None, + prompt: "test".to_string(), + num_prompt_tokens: 10, + num_decode_tokens: None, + }); + let mut response1 = TextGenerationAggregatedResponse::new(request.clone()); response1.start_time = Some(tokio::time::Instant::now()); response1.end_time = Some(tokio::time::Instant::now() + tokio::time::Duration::from_millis(100)); - response1.num_prompt_tokens = 10; response1.num_generated_tokens = 100; response1.failed = false; response1.times_to_tokens = vec![ @@ -370,11 +377,10 @@ mod test { Duration::from_millis(500), ]; - let mut response2 = TextGenerationAggregatedResponse::default(); + let mut response2 = TextGenerationAggregatedResponse::new(request.clone()); response2.start_time = Some(tokio::time::Instant::now()); response2.end_time = Some(tokio::time::Instant::now() + tokio::time::Duration::from_millis(200)); - response2.num_prompt_tokens = 10; response2.num_generated_tokens = 100; response2.failed = false; response2.times_to_tokens = vec![ @@ -385,11 +391,10 @@ mod test { Duration::from_millis(1000), ]; - let mut response3 = TextGenerationAggregatedResponse::default(); + let mut response3 = TextGenerationAggregatedResponse::new(request.clone()); response3.start_time = Some(tokio::time::Instant::now()); response3.end_time = Some(tokio::time::Instant::now() + tokio::time::Duration::from_millis(300)); - response3.num_prompt_tokens = 10; response3.num_generated_tokens = 100; response3.failed = false; response3.times_to_tokens = vec![ @@ -400,11 +405,10 @@ mod test { Duration::from_millis(1500), ]; - let mut response4 = TextGenerationAggregatedResponse::default(); + let mut response4 = TextGenerationAggregatedResponse::new(request.clone()); response4.start_time = Some(tokio::time::Instant::now()); response4.end_time = Some(tokio::time::Instant::now() + tokio::time::Duration::from_millis(300)); - response4.num_prompt_tokens = 10; response4.num_generated_tokens = 100; response4.failed = false; response4.times_to_tokens = vec![ diff --git a/src/scheduler.rs b/src/scheduler.rs index 9a4c0fb..5a89247 100644 --- a/src/scheduler.rs +++ b/src/scheduler.rs @@ -96,6 +96,7 @@ impl Scheduler { let results = self.results.clone(); let progress_tx = self.progress_tx.clone(); let mut stop_receiver = self.stop_sender.subscribe(); + let req_gen = self.requests_generator.clone(); tokio::spawn(async move { tokio::select! { _ = stop_receiver.recv() => { @@ -103,6 +104,11 @@ impl Scheduler { } _ = async{ while let Some(response) = rx.recv().await{ + // call generator callback + let response_txt=response.response.clone(); + if let Some(request)= response.request.clone(){ + req_gen.lock().await.callback(request, response_txt.unwrap_or_default().as_str()); + } let result = results.clone(); let progress_tx = progress_tx.clone(); trace!("Received response: {:?}", response); @@ -183,7 +189,7 @@ mod tests { #[tokio::test] async fn test_constant_vus_scheduler() { let (progress_tx, _) = tokio::sync::mpsc::channel(10000); - let (stop_sender, _) = tokio::sync::broadcast::channel(1); + let (stop_sender, _) = broadcast::channel(1); let backend = Box::new(crate::requests::DummyTextGenerationBackend::new( Duration::from_secs(1), )); @@ -196,7 +202,7 @@ mod tests { ExecutorType::ConstantVUs, ExecutorConfig { max_vus: 800, - duration: std::time::Duration::from_secs(10), + duration: Duration::from_secs(10), rate: None, }, requests_generator, @@ -235,7 +241,7 @@ mod tests { tokenizer, time::Duration::from_secs(10), ) - .unwrap(); + .unwrap(); let requests_generator = Arc::from(Mutex::from( crate::requests::DummyTextRequestGenerator::new(), )); @@ -245,7 +251,7 @@ mod tests { ExecutorType::ConstantArrivalRate, ExecutorConfig { max_vus: 800, - duration: std::time::Duration::from_secs(10), + duration: Duration::from_secs(10), rate: Some(50.0), }, requests_generator, From ac9ef8dcc996c1710d24834ac3eff07f7a9888ce Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Tue, 10 Dec 2024 16:59:20 +0100 Subject: [PATCH 02/16] feat: Working profiles with chat turns. --- src/profiles.rs | 80 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 src/profiles.rs diff --git a/src/profiles.rs b/src/profiles.rs new file mode 100644 index 0000000..aedb149 --- /dev/null +++ b/src/profiles.rs @@ -0,0 +1,80 @@ +use std::string::ToString; +use crate::{RunConfiguration, TokenizeOptions}; + +pub fn apply_profile(profile: &str, run_configuration: RunConfiguration) -> anyhow::Result { + match profile { + "fixed-length"=>{ + Ok(RunConfiguration { + max_vus: 128, + duration: std::time::Duration::from_secs(120), + rates: None, + num_rates: 10, + benchmark_kind: "sweep".to_string(), + warmup_duration: std::time::Duration::from_secs(30), + prompt_options: Some(TokenizeOptions { + num_tokens: Some(200), + min_tokens: 200, + max_tokens: 200, + variance: 0, + }), + decode_options: Some(TokenizeOptions { + num_tokens: Some(800), + min_tokens: 50, + max_tokens: 800, + variance: 100, + }), + dataset: "hlarcher/inference-benchmarker".to_string(), + dataset_file: "share_gpt_0_turns.json".to_string(), + ..run_configuration + }) + } + "chat" => { + Ok(RunConfiguration { + max_vus: 128, + duration: std::time::Duration::from_secs(120), + rates: None, + num_rates: 10, + benchmark_kind: "sweep".to_string(), + warmup_duration: std::time::Duration::from_secs(30), + prompt_options: None, // use prompts from dataset + decode_options: Some(TokenizeOptions { + num_tokens: Some(800), // decode up to 800 tokens + min_tokens: 50, + max_tokens: 800, + variance: 100, + }), + dataset: "hlarcher/inference-benchmarker".to_string(), + dataset_file: "share_gpt_turns.json".to_string(), + ..run_configuration + }) + }, + "code-generation"=>{ + Ok(RunConfiguration { + max_vus: 128, + duration: std::time::Duration::from_secs(120), + rates: None, + num_rates: 10, + benchmark_kind: "throughput".to_string(), + warmup_duration: std::time::Duration::from_secs(30), + prompt_options: Some(TokenizeOptions { + num_tokens: Some(4096), + min_tokens: 3000, + max_tokens: 6000, + variance: 1000, + }), + decode_options: Some(TokenizeOptions { + num_tokens: Some(50), + min_tokens: 30, + max_tokens: 80, + variance: 10, + }), + dataset: "hlarcher/inference-benchmarker".to_string(), + dataset_file: "github_code.json".to_string(), + ..run_configuration + }) + } + _ => { + Err(anyhow::anyhow!("Unknown profile: {}", profile)) + } + } +} From 1f8783bfe1065b03bfdd121506b98c98343d9b58 Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Wed, 8 Jan 2025 11:16:04 +0100 Subject: [PATCH 03/16] feat: Adding profiles documentation. --- README.md | 27 ++++++++++++++++++++++----- src/profiles.rs | 46 ++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 64 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 835a61b..e3ea4ba 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ It can be used to benchmark any text generation server that exposes an OpenAI-co * Broad Compatibility: Benchmarks any text generation server with an OpenAPI-compliant chat API. * Automatic Sweep Mode: Detects maximum throughput and sweeps in-between. * Open-Loop Benchmarking: Uses constant arrival rates to simulate real-world workloads. +* Benchmark profiles: Presets to benchmark for different model use cases (eg. chat, summarization, code completion...). * High-Performance: Built with Rust 🦀 for high-performance benchmarking. * JSON Output: Delivers performance results in a structured, easy-to-analyze format. @@ -34,6 +35,7 @@ It can be used to benchmark any text generation server that exposes an OpenAI-co * [1. Start an inference server](#1-start-an-inference-server) * [2. Run a benchmark using Docker image](#2-run-a-benchmark-using-docker-image) * [Configure your benchmark](#configure-your-benchmark) + * [Profiles](#profiles) * [Benchmark mode](#benchmark-mode) * [Dataset configuration](#dataset-configuration) * [Prompt configuration](#prompt-configuration) @@ -79,6 +81,7 @@ docker run --runtime nvidia --gpus all \ ```shell MODEL=meta-llama/Llama-3.1-8B-Instruct HF_TOKEN= +# run a benchmark to evaluate the performance of the model for chat use case # we mount results to the current directory $ docker run \ --rm \ @@ -89,18 +92,32 @@ $ docker run \ ghcr.io/huggingface/inference-benchmarker:latest \ inference-benchmarker \ --tokenizer-name "$MODEL" \ - --max-vus 800 \ --url http://localhost:8080 \ - --warmup 20s \ - --num-rates 10 \ - --prompt-options "num_tokens=200,max_tokens=220,min_tokens=180,variance=10" \ - --decode-options "num_tokens=200,max_tokens=220,min_tokens=180,variance=10" + --profile chat ``` Results will be saved in JSON format in current directory. ### Configure your benchmark +#### Profiles + +Profiles are presets to benchmark for different model use cases. Available profiles: +- `chat` + Simulates a multi-turn chat scenario in which the model answers to successive user prompts. + The model is prompted with the whole conversation history at each turn. Prefix caching will have a significant impact + on the performance of this benchmark. +- `code-generation` + Simulates code-complete scenarios. Model is given large code snippets and asked to complete them with a few tokens + (e.g. a function name, a few code lines). +- `classification` + Simulates cases where the model is fed with large chunks of business data or document repeatedly and users + ask simple questions about the content (summarization, classification...). + Those use cases benefit a lot from prefix caching and chunked prefill. +- `fixed-length` + Model is sent fixed-length prompts to void the impact of variable-length tokenization on the benchmark. + This is a technical benchmark to evaluate the raw throughput of the model. + #### Benchmark mode In default mode, tool runs a `sweep` benchmark. It first runs a throughput test to find the maximum throughput, then diff --git a/src/profiles.rs b/src/profiles.rs index aedb149..ac9674f 100644 --- a/src/profiles.rs +++ b/src/profiles.rs @@ -3,7 +3,7 @@ use crate::{RunConfiguration, TokenizeOptions}; pub fn apply_profile(profile: &str, run_configuration: RunConfiguration) -> anyhow::Result { match profile { - "fixed-length"=>{ + "fixed-length" => { Ok(RunConfiguration { max_vus: 128, duration: std::time::Duration::from_secs(120), @@ -29,6 +29,11 @@ pub fn apply_profile(profile: &str, run_configuration: RunConfiguration) -> anyh }) } "chat" => { + // Simulates a multi-turn chat scenario + // in which the model answers to successive user prompts. + // The model is prompted with the whole conversation history + // at each turn. Prefix caching will have a significant impact + // on the performance of this benchmark. Ok(RunConfiguration { max_vus: 128, duration: std::time::Duration::from_secs(120), @@ -47,14 +52,16 @@ pub fn apply_profile(profile: &str, run_configuration: RunConfiguration) -> anyh dataset_file: "share_gpt_turns.json".to_string(), ..run_configuration }) - }, - "code-generation"=>{ + } + "code-generation" => { + // Simulates code-complete scenarios. Model is given large code snippets and + // asked to complete them with a few tokens (e.g. a function name, a few code lines). Ok(RunConfiguration { max_vus: 128, duration: std::time::Duration::from_secs(120), rates: None, num_rates: 10, - benchmark_kind: "throughput".to_string(), + benchmark_kind: "sweep".to_string(), warmup_duration: std::time::Duration::from_secs(30), prompt_options: Some(TokenizeOptions { num_tokens: Some(4096), @@ -73,6 +80,37 @@ pub fn apply_profile(profile: &str, run_configuration: RunConfiguration) -> anyh ..run_configuration }) } + "classification" => { + // Simulates cases where the model is fed with large chunks + // of business data or document repeatedly and users + // ask simple questions about the content (summarization, classification...) + // Those use cases benefit a lot from prefix caching and chunked prefill. + Ok( + RunConfiguration { + max_vus: 128, + duration: std::time::Duration::from_secs(120), + rates: None, + num_rates: 10, + benchmark_kind: "sweep".to_string(), + warmup_duration: std::time::Duration::from_secs(30), + prompt_options: Some(TokenizeOptions { + num_tokens: Some(10000), + min_tokens: 8000, + max_tokens: 12000, + variance: 5000, + }), + decode_options: Some(TokenizeOptions { + num_tokens: Some(50), + min_tokens: 30, + max_tokens: 80, + variance: 10, + }), + dataset: "hlarcher/inference-benchmarker".to_string(), + dataset_file: "classification.json".to_string(), + ..run_configuration + } + ) + } _ => { Err(anyhow::anyhow!("Unknown profile: {}", profile)) } From 13ef12fc0432aa05184de209d9901d9c76a4b860 Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Wed, 8 Jan 2025 11:27:36 +0100 Subject: [PATCH 04/16] fix: Cargo checks --- src/lib.rs | 26 ++++---- src/main.rs | 4 +- src/profiles.rs | 107 +++++++++++++++--------------- src/requests.rs | 165 ++++++++++++++++++++++++++--------------------- src/results.rs | 4 +- src/scheduler.rs | 2 +- 6 files changed, 161 insertions(+), 147 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 742d464..60f051f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,8 +7,8 @@ use std::sync::Arc; pub use crate::app::run_console; pub use crate::benchmark::{BenchmarkConfig, BenchmarkKind}; -pub use crate::profiles::apply_profile; use crate::benchmark::{Event, MessageEvent}; +pub use crate::profiles::apply_profile; use crate::requests::OpenAITextGenerationBackend; pub use crate::requests::TokenizeOptions; use chrono::Local; @@ -24,12 +24,12 @@ mod benchmark; mod event; mod executors; mod flux; +mod profiles; mod requests; mod results; mod scheduler; mod table; mod writers; -mod profiles; pub struct RunConfiguration { pub url: String, @@ -58,18 +58,16 @@ pub async fn run(mut run_config: RunConfiguration, stop_sender: Sender<()>) -> a // apply profile if needed run_config = match run_config.profile.clone() { None => run_config, - Some(profile) => { - match apply_profile(profile.as_str(), run_config) { - Ok(config) => { - info!("Profile applied: {}", profile); - config - }, - Err(e) => { - error!("Failed to apply profile: {:?}", e); - return Err(e); - } + Some(profile) => match apply_profile(profile.as_str(), run_config) { + Ok(config) => { + info!("Profile applied: {}", profile); + config } - } + Err(e) => { + error!("Failed to apply profile: {:?}", e); + return Err(e); + } + }, }; // initialize tokenizer let params = FromPretrainedParameters { @@ -164,7 +162,7 @@ pub async fn run(mut run_config: RunConfiguration, stop_sender: Sender<()>) -> a run_config.dataset_file, run_config.hf_token.clone(), ) - .expect("Can't download dataset"); + .expect("Can't download dataset"); let requests = requests::ConversationTextRequestGenerator::load( filepath, run_config.tokenizer_name.clone(), diff --git a/src/main.rs b/src/main.rs index 393bfc5..822c74e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -187,8 +187,8 @@ fn parse_tokenizer_options(s: &str) -> Result { } if tokenizer_options.num_tokens.is_some() && (tokenizer_options.num_tokens.unwrap() == 0 - || tokenizer_options.min_tokens == 0 - || tokenizer_options.max_tokens == 0) + || tokenizer_options.min_tokens == 0 + || tokenizer_options.max_tokens == 0) { return Err(Error::new(InvalidValue)); } diff --git a/src/profiles.rs b/src/profiles.rs index ac9674f..dfcab2d 100644 --- a/src/profiles.rs +++ b/src/profiles.rs @@ -1,33 +1,34 @@ -use std::string::ToString; use crate::{RunConfiguration, TokenizeOptions}; +use std::string::ToString; -pub fn apply_profile(profile: &str, run_configuration: RunConfiguration) -> anyhow::Result { +pub fn apply_profile( + profile: &str, + run_configuration: RunConfiguration, +) -> anyhow::Result { match profile { - "fixed-length" => { - Ok(RunConfiguration { - max_vus: 128, - duration: std::time::Duration::from_secs(120), - rates: None, - num_rates: 10, - benchmark_kind: "sweep".to_string(), - warmup_duration: std::time::Duration::from_secs(30), - prompt_options: Some(TokenizeOptions { - num_tokens: Some(200), - min_tokens: 200, - max_tokens: 200, - variance: 0, - }), - decode_options: Some(TokenizeOptions { - num_tokens: Some(800), - min_tokens: 50, - max_tokens: 800, - variance: 100, - }), - dataset: "hlarcher/inference-benchmarker".to_string(), - dataset_file: "share_gpt_0_turns.json".to_string(), - ..run_configuration - }) - } + "fixed-length" => Ok(RunConfiguration { + max_vus: 128, + duration: std::time::Duration::from_secs(120), + rates: None, + num_rates: 10, + benchmark_kind: "sweep".to_string(), + warmup_duration: std::time::Duration::from_secs(30), + prompt_options: Some(TokenizeOptions { + num_tokens: Some(200), + min_tokens: 200, + max_tokens: 200, + variance: 0, + }), + decode_options: Some(TokenizeOptions { + num_tokens: Some(800), + min_tokens: 50, + max_tokens: 800, + variance: 100, + }), + dataset: "hlarcher/inference-benchmarker".to_string(), + dataset_file: "share_gpt_0_turns.json".to_string(), + ..run_configuration + }), "chat" => { // Simulates a multi-turn chat scenario // in which the model answers to successive user prompts. @@ -85,34 +86,30 @@ pub fn apply_profile(profile: &str, run_configuration: RunConfiguration) -> anyh // of business data or document repeatedly and users // ask simple questions about the content (summarization, classification...) // Those use cases benefit a lot from prefix caching and chunked prefill. - Ok( - RunConfiguration { - max_vus: 128, - duration: std::time::Duration::from_secs(120), - rates: None, - num_rates: 10, - benchmark_kind: "sweep".to_string(), - warmup_duration: std::time::Duration::from_secs(30), - prompt_options: Some(TokenizeOptions { - num_tokens: Some(10000), - min_tokens: 8000, - max_tokens: 12000, - variance: 5000, - }), - decode_options: Some(TokenizeOptions { - num_tokens: Some(50), - min_tokens: 30, - max_tokens: 80, - variance: 10, - }), - dataset: "hlarcher/inference-benchmarker".to_string(), - dataset_file: "classification.json".to_string(), - ..run_configuration - } - ) - } - _ => { - Err(anyhow::anyhow!("Unknown profile: {}", profile)) + Ok(RunConfiguration { + max_vus: 128, + duration: std::time::Duration::from_secs(120), + rates: None, + num_rates: 10, + benchmark_kind: "sweep".to_string(), + warmup_duration: std::time::Duration::from_secs(30), + prompt_options: Some(TokenizeOptions { + num_tokens: Some(10000), + min_tokens: 8000, + max_tokens: 12000, + variance: 5000, + }), + decode_options: Some(TokenizeOptions { + num_tokens: Some(50), + min_tokens: 30, + max_tokens: 80, + variance: 10, + }), + dataset: "hlarcher/inference-benchmarker".to_string(), + dataset_file: "classification.json".to_string(), + ..run_configuration + }) } + _ => Err(anyhow::anyhow!("Unknown profile: {}", profile)), } } diff --git a/src/requests.rs b/src/requests.rs index f423f94..177ac28 100644 --- a/src/requests.rs +++ b/src/requests.rs @@ -1,5 +1,3 @@ -use std::cmp::Ordering; -use std::collections::{BinaryHeap, HashMap}; use async_trait::async_trait; use futures_util::StreamExt; use hf_hub::api::sync::ApiBuilder; @@ -10,14 +8,16 @@ use rayon::iter::split; use rayon::prelude::*; use reqwest_eventsource::{Error, Event, EventSource}; use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::collections::{BinaryHeap, HashMap}; use std::fmt::Display; use std::path::PathBuf; use std::sync::{Arc, Mutex}; use std::time; use tokenizers::{FromPretrainedParameters, Tokenizer}; -use tokio::sync::mpsc::{Sender}; +use tokio::sync::mpsc::Sender; use tokio::time::{sleep, Instant}; -use uuid::{Uuid}; +use uuid::Uuid; #[derive(Debug, Clone, Eq, PartialEq)] pub struct TextGenerationRequest { @@ -302,7 +302,6 @@ pub trait TextRequestGenerator: Sync { fn callback(&mut self, request: Arc, response: &str); } - #[derive(Deserialize, Serialize, Clone)] pub struct Conversation { pub role: String, @@ -365,7 +364,9 @@ impl Ord for ConversationTurnRequest { // order by increasing priority and decreasing tie-breaking // this way, we can pop the item with the highest priority and oldest tie-breaking fn cmp(&self, other: &Self) -> Ordering { - self.priority.cmp(&other.priority).then_with(|| self.tie.cmp(&other.tie).reverse()) + self.priority + .cmp(&other.priority) + .then_with(|| self.tie.cmp(&other.tie).reverse()) } } @@ -375,7 +376,6 @@ impl PartialOrd for ConversationTurnRequest { } } - #[derive(Clone)] pub struct ConversationTextRequestGenerator { pub requests: HashMap, @@ -415,7 +415,8 @@ impl ConversationTextRequestGenerator { let input = std::fs::read_to_string(&filepath)?; let data: Vec = serde_json::from_str(&input).expect("Unable to parse input file. Check that it is valid JSON and matches the expected format."); // generate requests - let requests: Arc>> = Arc::from(Mutex::from(HashMap::new())); + let requests: Arc>> = + Arc::from(Mutex::from(HashMap::new())); info!( "Generating requests from {filepath}", filepath = filepath.display().to_string() @@ -430,14 +431,15 @@ impl ConversationTextRequestGenerator { if entry.conversations.is_empty() { continue; } - let mut turn_idx = 0; - let ids = (0..entry.conversations.len()).map(|_| Uuid::new_v4()).collect::>(); + let ids = (0..entry.conversations.len()) + .map(|_| Uuid::new_v4()) + .collect::>(); let filtered_conversations = entry .conversations .iter() .filter(|c| c.role == "user" || c.role == "system") .collect::>(); - for c in &filtered_conversations { + for (turn_idx, c) in filtered_conversations.iter().enumerate() { let prompt = c.content.clone(); let num_decode_tokens = decode_tokenize_opts.clone().map_or_else( || None, @@ -455,7 +457,7 @@ impl ConversationTextRequestGenerator { let next_id = if turn_idx == entry.conversations.len() - 1 { None } else { - Some(ids[turn_idx + 1].clone()) // link to next turn in the conversation + Some(ids[turn_idx + 1]) // link to next turn in the conversation }; debug!("Prompt: {prompt}", prompt = prompt); match &prompt_tokenize_opts { @@ -471,46 +473,48 @@ impl ConversationTextRequestGenerator { return; } }; - requests.lock().unwrap().insert(ids[turn_idx], ConversationTurnRequest { - id: ids[turn_idx], - priority: turn_idx as u64, - tie: Instant::now(), - request: TextGenerationRequest { - id: next_id, - prompt, - num_prompt_tokens: num_tokens, - num_decode_tokens, + requests.lock().unwrap().insert( + ids[turn_idx], + ConversationTurnRequest { + id: ids[turn_idx], + priority: turn_idx as u64, + tie: Instant::now(), + request: TextGenerationRequest { + id: next_id, + prompt, + num_prompt_tokens: num_tokens, + num_decode_tokens, + }, }, - }); + ); } Some(options) => { // compute number of tokens to generate using a Gaussian distribution - let (sampled_prompt, prompt_tokens) = match tokenize_prompt( - prompt.clone(), - tokenizer.clone(), - options, - ) { - Ok(prompt) => prompt, - Err(e) => { - debug!("Error tokenizing prompt: {e}"); - return; - } - }; - requests.lock().unwrap().insert(ids[turn_idx], ConversationTurnRequest { - id: ids[turn_idx], - tie: Instant::now(), - priority: turn_idx as u64, - request: TextGenerationRequest { - id: next_id, - prompt: sampled_prompt, - num_prompt_tokens: prompt_tokens, - num_decode_tokens, + let (sampled_prompt, prompt_tokens) = + match tokenize_prompt(prompt.clone(), tokenizer.clone(), options) { + Ok(prompt) => prompt, + Err(e) => { + debug!("Error tokenizing prompt: {e}"); + return; + } + }; + requests.lock().unwrap().insert( + ids[turn_idx], + ConversationTurnRequest { + id: ids[turn_idx], + tie: Instant::now(), + priority: turn_idx as u64, + request: TextGenerationRequest { + id: next_id, + prompt: sampled_prompt, + num_prompt_tokens: prompt_tokens, + num_decode_tokens, + }, }, - }); + ); } } - turn_idx += 1; - }; + } // TODO: check that we have enough requests } }); @@ -520,7 +524,13 @@ impl ConversationTextRequestGenerator { num_requests = requests.len() ); // create the queue from the hashmap. Only queue first turns in the conversation - let queue = BinaryHeap::from(requests.values().cloned().filter(|item| item.priority == 0).collect::>()); + let queue = BinaryHeap::from( + requests + .values() + .filter(|item| item.priority == 0) + .cloned() + .collect::>(), + ); Ok(Self { requests: requests.clone(), tokenizer, @@ -588,22 +598,19 @@ impl TextRequestGenerator for ConversationTextRequestGenerator { None => { return; } - Some(id) => { - id - } + Some(id) => id, }; // retrieve next turn from id let next_request = match self.requests.get(&id) { None => { return; } - Some(request) => { - request - } + Some(request) => request, }; // create a new turn with the prompt concatenated with the response and next turn's prompt // and add the next turn id to the new turn - let new_prompt = request.prompt.clone() + "\n" + response + "\n" + next_request.request.prompt.as_str(); + let new_prompt = + request.prompt.clone() + "\n" + response + "\n" + next_request.request.prompt.as_str(); // tokenize the prompt let (prompt, num_tokens) = match tokenize_prompt( new_prompt.to_string(), @@ -618,7 +625,7 @@ impl TextRequestGenerator for ConversationTextRequestGenerator { let next_id = next_request.request.id; let turn = ConversationTurnRequest { id, - priority: 100, // move to the front of the queue + priority: 100, // move to the front of the queue tie: Instant::now(), // use the current time as tie-breaking (older turns have higher priority) request: TextGenerationRequest { id: next_id, @@ -827,7 +834,7 @@ mod tests { tokenizer, time::Duration::from_secs(10), ) - .unwrap(); + .unwrap(); let request = TextGenerationRequest { id: None, prompt: "Hello, world!".to_string(), @@ -887,7 +894,7 @@ mod tests { tokenizer, time::Duration::from_secs(10), ) - .unwrap(); + .unwrap(); let request = TextGenerationRequest { id: None, prompt: "Hello, world!".to_string(), @@ -939,7 +946,7 @@ mod tests { assert!( inter_token_latency_avg > expected_inter_token_latency_avg && inter_token_latency_avg - < expected_inter_token_latency_avg + inter_token_latency_overhead, + < expected_inter_token_latency_avg + inter_token_latency_overhead, "inter_token_latency_avg: {:?} < {:?} < {:?}", expected_inter_token_latency_avg, inter_token_latency_avg, @@ -972,7 +979,7 @@ mod tests { tokenizer, time::Duration::from_secs(10), ) - .unwrap(); + .unwrap(); let request = TextGenerationRequest { id: None, prompt: "Hello, world!".to_string(), @@ -1018,7 +1025,7 @@ mod tests { tokenizer, time::Duration::from_secs(10), ) - .unwrap(); + .unwrap(); let request = TextGenerationRequest { id: None, prompt: "Hello, world!".to_string(), @@ -1064,7 +1071,7 @@ mod tests { tokenizer, time::Duration::from_secs(10), ) - .unwrap(); + .unwrap(); let request = TextGenerationRequest { id: None, prompt: "Hello, world!".to_string(), @@ -1114,7 +1121,7 @@ mod tests { tokenizer, Duration::from_secs(1), ) - .unwrap(); + .unwrap(); let request = TextGenerationRequest { id: None, prompt: "Hello, world!".to_string(), @@ -1154,7 +1161,7 @@ mod tests { Some(decode_tokenize_opts), hf_token, ) - .unwrap(); + .unwrap(); assert_eq!(generator.requests.len(), 17016); } @@ -1178,7 +1185,7 @@ mod tests { Some(decode_tokenize_opts), hf_token, ) - .unwrap(); + .unwrap(); let min_tokens = generator .requests .iter() @@ -1215,7 +1222,7 @@ mod tests { Some(decode_tokenize_opts), hf_token, ) - .unwrap(); + .unwrap(); for r in generator.requests.iter() { assert_eq!(r.1.request.num_prompt_tokens, 200); } @@ -1242,8 +1249,12 @@ mod tests { Some(decode_tokenize_opts), hf_token, ) - .unwrap(); - let turns = generator.requests.into_iter().map(|r| r.1.clone()).collect::>(); + .unwrap(); + let turns = generator + .requests + .into_iter() + .map(|r| r.1.clone()) + .collect::>(); assert_eq!(turns.len(), 8); let first_turns = turns.clone().into_iter().filter(|t| t.priority == 0); // we expect to have 2 None values for the first turn in each conversation @@ -1254,13 +1265,20 @@ mod tests { let mut current_turn = first_conversation[0].clone(); loop { conversation.push(current_turn.clone()); - match turns.iter().find(|t| t.id == current_turn.request.id.unwrap_or_default()) { + match turns + .iter() + .find(|t| t.id == current_turn.request.id.unwrap_or_default()) + { Some(t) => current_turn = t.clone(), None => break, } } assert_eq!(conversation.len(), 6); - let got = conversation.iter().map(|t| t.request.prompt.clone()).collect::>().join("\n"); + let got = conversation + .iter() + .map(|t| t.request.prompt.clone()) + .collect::>() + .join("\n"); let expect = "Summarize the main ideas of Jeff Walker's Product Launch Formula into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients...\nSummarize the main ideas of Brendon Burchard's Experts Academy into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients...\nWhat are the mental triggers in Jeff Walker's Product Launch Formula and \"Launch\" book?\nWrite a summary of why scarcity and urgency are the strongest mental triggers and have been the driving force behind many of our best performing campaigns over the last 8 years.\nSummarize Russell Brunson's Perfect Webinar Script...\nSummarize the 6 human needs as Tony Robbins explains..."; assert_eq!(expect, got); } @@ -1285,13 +1303,15 @@ mod tests { Some(decode_tokenize_opts), hf_token, ) - .unwrap(); + .unwrap(); for i in 0..20 { let req = generator.generate_request(); - if i % 2 == 0 { // first turn of the first conversation + if i % 2 == 0 { + // first turn of the first conversation assert_eq!(req.prompt, "Summarize the main ideas of Jeff Walker's Product Launch Formula into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients..."); } - if i % 2 == 1 { // first turn of the second conversation + if i % 2 == 1 { + // first turn of the second conversation assert_eq!(req.prompt, "You are a helpful assistant."); } } @@ -1317,7 +1337,7 @@ mod tests { Some(decode_tokenize_opts), hf_token, ) - .unwrap(); + .unwrap(); // generate the first user turn let req = generator.generate_request(); let response = "This is my response".to_string(); @@ -1338,4 +1358,3 @@ mod tests { assert_eq!(req.prompt, "You are a helpful assistant."); } } - diff --git a/src/results.rs b/src/results.rs index 3b34b8f..da4f475 100644 --- a/src/results.rs +++ b/src/results.rs @@ -352,9 +352,9 @@ impl BenchmarkReport { #[cfg(test)] mod test { - use std::sync::Arc; - use crate::requests::TextGenerationRequest; use super::*; + use crate::requests::TextGenerationRequest; + use std::sync::Arc; #[test] fn test_time_to_first_token_percentile() { let request = Arc::from(TextGenerationRequest { diff --git a/src/scheduler.rs b/src/scheduler.rs index 5a89247..957a458 100644 --- a/src/scheduler.rs +++ b/src/scheduler.rs @@ -241,7 +241,7 @@ mod tests { tokenizer, time::Duration::from_secs(10), ) - .unwrap(); + .unwrap(); let requests_generator = Arc::from(Mutex::from( crate::requests::DummyTextRequestGenerator::new(), )); From ce69bd553a16a6473862bd94d43be4a8ce5da5ae Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Wed, 8 Jan 2025 11:32:41 +0100 Subject: [PATCH 05/16] fix: Add missing test data file. --- test_data/chat.json | 72 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 test_data/chat.json diff --git a/test_data/chat.json b/test_data/chat.json new file mode 100644 index 0000000..7219210 --- /dev/null +++ b/test_data/chat.json @@ -0,0 +1,72 @@ +[ + { + "id": "QWJhYvA_0", + "conversations": [ + { + "role": "user", + "content": "Summarize the main ideas of Jeff Walker's Product Launch Formula into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients..." + }, + { + "role": "assistant", + "content": "Here are the main ideas of Jeff Walker's Product Launch Formula that can be applied by a growth marketing agency for their clients:\n\n1. Identify the target audience and their needs: Understand the ideal customer for the product or service, and create a messaging that resonates with them.\n2. Pre-launch: Build anticipation and excitement for the launch by creating buzz, gathering testimonials and case studies, and using social media to create awareness.\n3. Launch: Use a well-crafted launch sequence to maximize sales and conversions. This can include offering bonuses, creating scarcity, and using a deadline to create urgency.\n4. Post-launch: Follow up with customers, gather feedback, and continue to provide value to keep them engaged and loyal.\n5. Create a product suite: Once the initial product is successful, expand the offering by creating additional products or services that address other needs of the same audience.\n6. Continual optimization: Continually monitor and optimize the launch process and product suite to improve results.\n7. Build a community: Use the launch process to build a community of customers who are passionate about the product and can help spread the word.\n8. Use automation: Use technology and automation to streamline the launch process and improve efficiency." + }, + { + "role": "user", + "content": "Summarize the main ideas of Brendon Burchard's Experts Academy into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients..." + }, + { + "role": "assistant", + "content": "Here are the main ideas of Brendon Burchard's Experts Academy that can be applied by a growth marketing agency for their clients:\n\n1. Identify your expertise: Define your area of expertise and create a unique value proposition that differentiates you from others.\n2. Create a content strategy: Develop a content strategy that showcases your expertise and provides value to your audience. This can include blog posts, videos, podcasts, and social media content.\n3. Build your audience: Use social media, advertising, and other channels to build your audience and increase your visibility.\n4. Develop your products: Create products and services that align with your expertise and meet the needs of your audience. These can include digital products, courses, coaching, and consulting services.\n5. Create a brand: Develop a strong brand that reflects your expertise and resonates with your audience.\n6. Build authority: Establish yourself as an authority in your industry by speaking at events, publishing a book, or appearing on podcasts or TV shows.\n7. Monetize your expertise: Develop a monetization strategy that leverages your expertise and products to generate revenue.\n8. Build a team: As your business grows, build a team of experts to help you scale your business and provide excellent service to your clients.\n9. Continual improvement: Continually improve your products, services, and marketing strategies to stay ahead of the competition and provide the best possible experience for your clients.\n10. Focus on impact: Ultimately, focus on making a positive impact in the lives of your clients and using your expertise to make a difference in the world." + }, + { + "role": "user", + "content": "What are the mental triggers in Jeff Walker's Product Launch Formula and \"Launch\" book?" + }, + { + "role": "assistant", + "content": "Here are the mental triggers described in Jeff Walker's Product Launch Formula and \"Launch\" book:\n\n1. Scarcity: Creating a sense of urgency by limiting the availability of the product or service being launched.\n2. Authority: Establishing oneself as an expert or authority in the industry to build credibility and trust with the audience.\n3. Social proof: Using testimonials, case studies, and other forms of social proof to demonstrate the effectiveness and value of the product or service.\n4. Reciprocity: Offering something of value to the audience before asking for a sale or commitment.\n5. Likeability: Building a personal connection with the audience and creating a likable persona to increase the chances of a sale.\n6. Anticipation: Building anticipation and excitement leading up to the launch to generate buzz and increase interest in the product or service.\n7. Exclusivity: Creating a feeling of exclusivity by offering early access or special deals to a select group of people.\n8. Belonging: Creating a sense of community or belonging among the audience to increase engagement and loyalty to the brand.\n9. Storytelling: Using storytelling to connect with the audience on an emotional level and make the product or service more relatable and memorable.\n\nBy leveraging these mental triggers, businesses can create an effective product launch strategy that resonates with their target audience and drives sales and growth." + }, + { + "role": "user", + "content": "Write a summary of why scarcity and urgency are the strongest mental triggers and have been the driving force behind many of our best performing campaigns over the last 8 years." + }, + { + "role": "assistant", + "content": "Scarcity and urgency are two of the strongest mental triggers in marketing because they create a sense of exclusivity and importance around a product or service. By creating a sense of limited availability or a deadline for a promotion, businesses can motivate customers to take action and make a purchase before it's too late.\n\nOver the last eight years, many businesses have leveraged these triggers in their marketing campaigns to great success. Scarcity and urgency have been used to sell everything from digital products to physical goods and services, and have proven to be highly effective in driving conversions and increasing revenue.\n\nThis is because scarcity and urgency tap into the natural human desire for instant gratification and fear of missing out (FOMO). When customers believe that they may miss out on a valuable opportunity if they don't act quickly, they are more likely to take action and make a purchase.\n\nFurthermore, scarcity and urgency can be combined with other mental triggers, such as social proof and authority, to create even more powerful marketing campaigns. By demonstrating the popularity and effectiveness of a product or service, businesses can increase the perceived value and urgency of a promotion, leading to even more sales and revenue.\n\nOverall, scarcity and urgency are powerful mental triggers that have been proven to drive results in marketing campaigns. By leveraging these triggers, businesses can create a sense of urgency and exclusivity around their products or services, motivating customers to take action and make a purchase before it's too late." + }, + { + "role": "user", + "content": "Summarize Russell Brunson's Perfect Webinar Script..." + }, + { + "role": "assistant", + "content": "Russell Brunson's Perfect Webinar Script is a framework for delivering webinars that are designed to sell a product or service. Here are the key components of the Perfect Webinar Script:\n\n1. Introduction: The webinar starts with an introduction that establishes the presenter's credibility and establishes the topic.\n2. Story: The presenter then tells a personal story that is related to the topic and establishes a connection with the audience.\n3. Content: The presenter provides valuable content related to the topic, demonstrating their expertise and building trust with the audience.\n4. Offer: The presenter then introduces the product or service they are selling, explaining how it solves the audience's problem and providing social proof to support their claims.\n5. Bonuses: The presenter offers bonuses to sweeten the deal and make the offer more enticing.\n6. Scarcity: The presenter creates a sense of urgency by limiting the availability of the offer or including a deadline for purchase.\n7. FAQs: The presenter answers common questions about the offer, addressing objections and building trust with the audience.\n8. Closing: The presenter then delivers a clear call to action, directing the audience to take a specific action to purchase the product or service.\n\nOverall, the Perfect Webinar Script is designed to build trust with the audience, establish the presenter's credibility, and provide valuable content, while also creating a sense of urgency and scarcity around the offer being presented. By following this framework, businesses can use webinars as a powerful tool to sell their products and services to a targeted audience." + }, + { + "role": "user", + "content": "Summarize the 6 human needs as Tony Robbins explains..." + }, + { + "role": "assistant", + "content": "Tony Robbins describes six core human needs that drive our behaviors and motivations. These six needs are:\n\n1. Certainty: The need for safety, stability, and predictability. This includes the need for comfort, security, and control over our environment.\n2. Variety: The need for novelty, excitement, and change. This includes the need for adventure, stimulation, and new experiences.\n3. Significance: The need to feel important, special, and unique. This includes the need for recognition, achievement, and respect.\n4. Connection: The need for social connection and love. This includes the need for intimacy, belonging, and friendship.\n5. Growth: The need for personal growth and development. This includes the need for learning, self-improvement, and progress.\n6. Contribution: The need to make a positive impact in the world. This includes the need to give back, help others, and make a difference.\n\nAccording to Tony Robbins, these six needs are universal and apply to all individuals. Each person may prioritize these needs differently, and may fulfill them in different ways, but they are fundamental drivers of human behavior. By understanding these needs, individuals can gain insight into their own motivations and behaviors, and can use this knowledge to create a more fulfilling and meaningful life." + } + ] + }, + { + "id": "i6IyJda_0", + "conversations": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "How to tell if a customer segment is well segmented? In 3 bullet points." + }, + { + "role": "assistant", + "content": "1. Homogeneity: The segment should consist of customers who share similar characteristics and behaviors.\n2. Distinctiveness: The segment should be different from other segments in terms of their characteristics and behaviors.\n3. Stability: The segment should remain relatively stable over time and not change drastically. The characteristics and behaviors of customers within the segment should not change significantly." + } + ] + } +] From ca07ecab40a90616ad9e393cc7f87ac889d1cc71 Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Wed, 8 Jan 2025 11:38:39 +0100 Subject: [PATCH 06/16] chore: Update CI to skip docker image build if Rust checks fail. --- .github/workflows/build.yaml | 8 -------- .github/workflows/build_ci.yaml | 21 +++++++++++++++++++++ .github/workflows/rust.yml | 14 +++----------- .github/workflows/trufflehog.yml | 4 ++-- 4 files changed, 26 insertions(+), 21 deletions(-) create mode 100644 .github/workflows/build_ci.yaml diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 774a449..a9f1098 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -2,14 +2,6 @@ name: Build and push docker image to internal registry on: workflow_call: - push: - branches: - - 'main' - tags: - - 'v*' - pull_request: - branches: - - "main" concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/build_ci.yaml b/.github/workflows/build_ci.yaml new file mode 100644 index 0000000..dd7ea06 --- /dev/null +++ b/.github/workflows/build_ci.yaml @@ -0,0 +1,21 @@ +name: Build inference-benchmarker + +on: + workflow_dispatch: + workflow_call: + push: + branches: + - 'main' + tags: + - 'v*' + pull_request: + branches: + - "main" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + - use: ./.github/workflows/rust.yaml + - use: ./.github/workflows/build.yaml \ No newline at end of file diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 19e99e8..cc9354c 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -1,16 +1,8 @@ +name: Rust checks + on: - workflow_dispatch: - push: - branches: - - 'main' - tags: - - 'v*' - pull_request: - paths: - - 'src/**' - - .github/workflows/rust.yml + workflow_call: -name: Rust checks permissions: pull-requests: write contents: read diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml index 082c97b..72c0e36 100644 --- a/.github/workflows/trufflehog.yml +++ b/.github/workflows/trufflehog.yml @@ -1,8 +1,8 @@ +name: Secret Leaks + on: push: -name: Secret Leaks - jobs: trufflehog: runs-on: From 5d399728cfddc96cfbcfbae78b7afa044006829b Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Wed, 8 Jan 2025 11:41:18 +0100 Subject: [PATCH 07/16] chore: Update CI to skip docker image build if Rust checks fail. --- .github/workflows/build_ci.yaml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_ci.yaml b/.github/workflows/build_ci.yaml index dd7ea06..223154e 100644 --- a/.github/workflows/build_ci.yaml +++ b/.github/workflows/build_ci.yaml @@ -17,5 +17,8 @@ concurrency: cancel-in-progress: true jobs: - - use: ./.github/workflows/rust.yaml - - use: ./.github/workflows/build.yaml \ No newline at end of file + rust: + use: ./.github/workflows/rust.yaml + build: + use: ./.github/workflows/build.yaml + needs: rust \ No newline at end of file From 6e14719fa52e2f3f9d5ea8645fc6fef216c5112f Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Wed, 8 Jan 2025 11:42:59 +0100 Subject: [PATCH 08/16] chore: Update CI to skip docker image build if Rust checks fail. --- .github/workflows/build_ci.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_ci.yaml b/.github/workflows/build_ci.yaml index 223154e..7ad5f26 100644 --- a/.github/workflows/build_ci.yaml +++ b/.github/workflows/build_ci.yaml @@ -18,7 +18,7 @@ concurrency: jobs: rust: - use: ./.github/workflows/rust.yaml + uses: ./.github/workflows/rust.yaml build: - use: ./.github/workflows/build.yaml + uses: ./.github/workflows/build.yaml needs: rust \ No newline at end of file From ae59422ea49a47770eb7cbbd5ccae3d2ef5722f4 Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Wed, 8 Jan 2025 11:44:03 +0100 Subject: [PATCH 09/16] chore: Update CI to skip docker image build if Rust checks fail. --- .github/workflows/{rust.yml => rust.yaml} | 0 .github/workflows/{trufflehog.yml => trufflehog.yaml} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename .github/workflows/{rust.yml => rust.yaml} (100%) rename .github/workflows/{trufflehog.yml => trufflehog.yaml} (100%) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yaml similarity index 100% rename from .github/workflows/rust.yml rename to .github/workflows/rust.yaml diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yaml similarity index 100% rename from .github/workflows/trufflehog.yml rename to .github/workflows/trufflehog.yaml From fa90752ad032e7369ea5007da17d4633fa108ec0 Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Wed, 8 Jan 2025 11:45:55 +0100 Subject: [PATCH 10/16] chore: Update CI to skip docker image build if Rust checks fail. --- .github/workflows/build_ci.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/build_ci.yaml b/.github/workflows/build_ci.yaml index 7ad5f26..04908de 100644 --- a/.github/workflows/build_ci.yaml +++ b/.github/workflows/build_ci.yaml @@ -18,7 +18,13 @@ concurrency: jobs: rust: + permissions: + pull-requests: write + contents: read uses: ./.github/workflows/rust.yaml build: + permissions: + packages: write + contents: read uses: ./.github/workflows/build.yaml needs: rust \ No newline at end of file From 9c5c3873e0c0f67a0fa5f191050e6b14ca00e53b Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Wed, 8 Jan 2025 11:49:12 +0100 Subject: [PATCH 11/16] chore: Update CI to skip docker image build if Rust checks fail. --- .github/workflows/build.yaml | 2 +- .github/workflows/build_ci.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index a9f1098..a2fcac4 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -3,7 +3,7 @@ name: Build and push docker image to internal registry on: workflow_call: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref }}-build cancel-in-progress: true jobs: build-and-push: diff --git a/.github/workflows/build_ci.yaml b/.github/workflows/build_ci.yaml index 04908de..e91b5dd 100644 --- a/.github/workflows/build_ci.yaml +++ b/.github/workflows/build_ci.yaml @@ -17,12 +17,12 @@ concurrency: cancel-in-progress: true jobs: - rust: + Rust checks: permissions: pull-requests: write contents: read uses: ./.github/workflows/rust.yaml - build: + Build Docker image: permissions: packages: write contents: read From 3017e941800b6c5a9378fa98c64f007ab7da40b9 Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Wed, 8 Jan 2025 11:50:44 +0100 Subject: [PATCH 12/16] chore: Update CI to skip docker image build if Rust checks fail. --- .github/workflows/build_ci.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_ci.yaml b/.github/workflows/build_ci.yaml index e91b5dd..a6febfb 100644 --- a/.github/workflows/build_ci.yaml +++ b/.github/workflows/build_ci.yaml @@ -17,14 +17,16 @@ concurrency: cancel-in-progress: true jobs: - Rust checks: + rust: + name: Rust checks permissions: pull-requests: write contents: read uses: ./.github/workflows/rust.yaml - Build Docker image: + build: permissions: packages: write contents: read + name: Build and push docker image uses: ./.github/workflows/build.yaml needs: rust \ No newline at end of file From 5910f771a3146200faa60a1b24291b3d7a7c7341 Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Wed, 8 Jan 2025 12:10:11 +0100 Subject: [PATCH 13/16] chore: Update CI to skip docker image build if Rust checks fail. --- src/requests.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/requests.rs b/src/requests.rs index 177ac28..5e1d2af 100644 --- a/src/requests.rs +++ b/src/requests.rs @@ -870,6 +870,7 @@ mod tests { /// We need to account for the time it takes to establish the connection /// and the time it takes to receive the first message #[tokio::test] + #[ignore] // flaky test, don't run it on CI async fn test_openai_timings() { let mut s = mockito::Server::new_async().await; s.mock("POST", "/v1/chat/completions") From 66cd6041113b1a240a1abb0db68c5aa68a00f652 Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Wed, 8 Jan 2025 14:35:57 +0100 Subject: [PATCH 14/16] chore: Update CI to skip docker image build if Rust checks fail. --- .github/workflows/build.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index a2fcac4..0473246 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -1,4 +1,4 @@ -name: Build and push docker image to internal registry +name: Build and push docker image to registry on: workflow_call: @@ -36,7 +36,7 @@ jobs: uses: docker/metadata-action@v5 with: images: | - registry.internal.huggingface.tech/api-inference/inference-benchmarker + ghcr.io/huggingface/inference-benchmarker tags: | type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} # If main, release or tag From c0307802c690a5495120390250f8dc21c3d469f6 Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Wed, 8 Jan 2025 15:21:31 +0100 Subject: [PATCH 15/16] chore: Update CI to skip docker image build if Rust checks fail. --- .github/workflows/build.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 0473246..3212e91 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -23,7 +23,6 @@ jobs: install: true buildkitd-config: /tmp/buildkitd.toml - name: Login to GitHub Container Registry - if: github.event_name != 'pull_request' uses: docker/login-action@v3 with: registry: ghcr.io From c5b8ec86658dae71bde27d4dfcf2d63e7d7f32ee Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Wed, 8 Jan 2025 16:22:24 +0100 Subject: [PATCH 16/16] chore: Update CI to skip docker image build if Rust checks fail. --- .github/workflows/build.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 3212e91..d292e88 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -37,7 +37,7 @@ jobs: images: | ghcr.io/huggingface/inference-benchmarker tags: | - type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} + type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }} # If main, release or tag - name: Extract metadata (tags, labels) for Docker if: ${{ github.event_name != 'pull_request' }} @@ -49,10 +49,10 @@ jobs: images: | ghcr.io/huggingface/inference-benchmarker tags: | - type=semver,pattern={{version}}${{ env.LABEL }} - type=semver,pattern={{major}}.{{minor}}${{ env.LABEL }} - type=raw,value=latest${{ env.LABEL }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} - type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} + type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }} - name: Build and push Docker image id: build-and-push uses: docker/build-push-action@v4 @@ -63,7 +63,7 @@ jobs: platforms: 'linux/amd64' build-args: | GIT_SHA=${{ env.GITHUB_SHA }} - DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} PLATFORM=${{ env.PLATFORM }} tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}