diff --git a/include/flexflow/attention_config.h b/include/flexflow/attention_config.h index 98992ff9a..a69b41ec6 100644 --- a/include/flexflow/attention_config.h +++ b/include/flexflow/attention_config.h @@ -61,17 +61,22 @@ class AttentionMetaData { num_q_heads_ = 0; num_kv_heads_ = 0; head_dim_ = 0; - q_indptr = nullptr; - kv_indptr = nullptr; - kv_indices = nullptr; - kv_last_page_len = nullptr; - qk_indptr = nullptr; - custom_mask = nullptr; - workspace = nullptr; + q_indptr_dec = nullptr; + kv_indptr_dec = nullptr; + kv_indices_dec = nullptr; + kv_last_page_len_dec = nullptr; + workspace_dec = nullptr; + float_workspace_dec = nullptr; + int_workspace_dec = nullptr; + q_indptr_pref = nullptr; + kv_indptr_pref = nullptr; + kv_indices_pref = nullptr; + kv_last_page_len_pref = nullptr; + workspace_pref = nullptr; + float_workspace_pref = nullptr; + int_workspace_pref = nullptr; workspace_size = 0; - float_workspace = nullptr; float_workspace_size = 0; - int_workspace = nullptr; int_workspace_size = 0; mem_size_ = 0; enabled_ = false; @@ -80,17 +85,22 @@ class AttentionMetaData { num_q_heads_ = rhs.num_q_heads_; num_kv_heads_ = rhs.num_kv_heads_; head_dim_ = rhs.head_dim_; - q_indptr = rhs.q_indptr; - kv_indptr = rhs.kv_indptr; - kv_indices = rhs.kv_indices; - kv_last_page_len = rhs.kv_last_page_len; - qk_indptr = rhs.qk_indptr; - custom_mask = rhs.custom_mask; - workspace = rhs.workspace; + q_indptr_dec = rhs.q_indptr_dec; + kv_indptr_dec= rhs.kv_indptr_dec; + kv_indices_dec = rhs.kv_indices_dec; + kv_last_page_len_dec = rhs.kv_last_page_len_dec; + workspace_dec = rhs.workspace_dec; + float_workspace_dec = rhs.float_workspace_dec; + int_workspace_dec = rhs.int_workspace_dec; + q_indptr_pref = rhs.q_indptr_pref; + kv_indptr_pref= rhs.kv_indptr_pref; + kv_indices_pref = rhs.kv_indices_pref; + kv_last_page_len_pref = rhs.kv_last_page_len_pref; + workspace_pref = rhs.workspace_pref; + float_workspace_pref = rhs.float_workspace_pref; + int_workspace_pref = rhs.int_workspace_pref; workspace_size = rhs.workspace_size; - float_workspace = rhs.float_workspace; float_workspace_size = rhs.float_workspace_size; - int_workspace = rhs.int_workspace; int_workspace_size = rhs.int_workspace_size; mem_size_ = rhs.mem_size_; enabled_ = rhs.enabled_; @@ -106,30 +116,33 @@ class AttentionMetaData { size_t max_num_pages = round_up_pages(BatchConfig::max_sequence_length()); size_t indices_size = std::max( (batch_size + 1) * 4 + max_num_pages * batch_size, 1ul * 1024 * 1024); - size_t custom_mask_size = 0; float_workspace_size = 128 * 1024 * 1024; // 128 MB int_workspace_size = 8 * 1024 * 1024; // 8 MB workspace_size = float_workspace_size + int_workspace_size; // float + int workspace - mem_size_ = alignTo(sizeof(int32_t) * indices_size + - sizeof(uint8_t) * custom_mask_size + workspace_size, + mem_size_ = alignTo(2*(sizeof(int32_t) * indices_size + workspace_size), 16); return mem_size_; } void assign_address(void *ptr, int size) { if (ptr == nullptr) { - q_indptr = nullptr; - kv_indptr = nullptr; - kv_indices = nullptr; - kv_last_page_len = nullptr; - qk_indptr = nullptr; - custom_mask = nullptr; - workspace = nullptr; - float_workspace = nullptr; - int_workspace = nullptr; + q_indptr_dec = nullptr; + kv_indptr_dec = nullptr; + kv_indices_dec = nullptr; + kv_last_page_len_dec = nullptr; + workspace_dec = nullptr; + float_workspace_dec = nullptr; + int_workspace_dec = nullptr; + q_indptr_pref = nullptr; + kv_indptr_pref = nullptr; + kv_indices_pref = nullptr; + kv_last_page_len_pref = nullptr; + workspace_pref 
= nullptr; + float_workspace_pref = nullptr; + int_workspace_pref = nullptr; return; } assert(size >= mem_size() && @@ -138,19 +151,26 @@ class AttentionMetaData { size_t max_num_pages = round_up_pages(BatchConfig::max_sequence_length()); size_t indices_size = std::max( (batch_size + 1) * 4 + max_num_pages * batch_size, 1ul * 1024 * 1024); - size_t custom_mask_size = 0; - - q_indptr = static_cast(ptr); - kv_indptr = q_indptr + batch_size + 1; - kv_indices = kv_indptr + batch_size + 1; - kv_last_page_len = kv_indices + max_num_pages * batch_size; - qk_indptr = kv_last_page_len + batch_size + 1; - custom_mask = static_cast(ptr) + sizeof(int32_t) * indices_size; - workspace = static_cast(static_cast(ptr) + - sizeof(int32_t) * indices_size + - sizeof(uint8_t) * custom_mask_size); - float_workspace = workspace; - int_workspace = static_cast(static_cast(workspace) + + + q_indptr_dec = static_cast(ptr); + kv_indptr_dec = q_indptr_dec + batch_size + 1; + kv_indices_dec = kv_indptr_dec + batch_size + 1; + kv_last_page_len_dec = kv_indices_dec + max_num_pages * batch_size; + q_indptr_pref = static_cast(ptr) + indices_size; + kv_indptr_pref = q_indptr_pref + batch_size + 1; + kv_indices_pref = kv_indptr_pref + batch_size + 1; + kv_last_page_len_pref = kv_indices_pref + max_num_pages * batch_size; + + workspace_dec = static_cast(static_cast(ptr) + + sizeof(int32_t) * indices_size * 2); + float_workspace_dec = workspace_dec; + int_workspace_dec = static_cast(static_cast(workspace_dec) + + float_workspace_size); + workspace_pref = static_cast(static_cast(ptr) + + sizeof(int32_t) * indices_size * 2 + + workspace_size); + float_workspace_pref = workspace_pref; + int_workspace_pref = static_cast(static_cast(workspace_pref) + float_workspace_size); } @@ -184,19 +204,26 @@ class AttentionMetaData { uint32_t num_kv_heads_; uint32_t head_dim_; - int32_t *q_indptr; - int32_t *kv_indptr; - int32_t *kv_indices; - int32_t *kv_last_page_len; - int32_t *qk_indptr; - uint8_t *custom_mask; - void *workspace; + int32_t *q_indptr_dec; + int32_t *kv_indptr_dec; + int32_t *kv_indices_dec; + int32_t *kv_last_page_len_dec; + uint8_t *custom_mask_dec; + void *workspace_dec; + void *float_workspace_dec; + void *int_workspace_dec; + int32_t *q_indptr_pref; + int32_t *kv_indptr_pref; + int32_t *kv_indices_pref; + int32_t *kv_last_page_len_pref; + uint8_t *custom_mask_pref; + void *workspace_pref; + void *float_workspace_pref; + void *int_workspace_pref; + size_t workspace_size; - void *float_workspace; size_t float_workspace_size; - void *int_workspace; size_t int_workspace_size; - size_t mem_size_; // batchsize -> handler diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index c94632fe1..0c7e7795d 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -67,9 +67,12 @@ class BatchConfig { // returns number of inference and finetuning FWD tokens int num_active_tokens() const; - // returns number of inference-only tokens + // returns number of inference-only tokens (prefill + decode) int num_inference_tokens() const; + // returns number of inference-only requests (prefill + decode) int num_inference_requests() const; + int num_prefill_requests() const; + int num_decoding_requests() const; // return the index where the finetuning request would be stored (i.e. 
last // slot of the batch) @@ -97,7 +100,7 @@ class BatchConfig { // These maximum values are used for copying BatchConfig // across workers static int const MAX_NUM_REQUESTS = 260; - static int const MAX_NUM_TOKENS = 3000; + static int const MAX_NUM_TOKENS = 8192; static int const MAX_SPEC_TREE_TOKEN_NUM = 64; static int const MAX_PEFT_CONFIG_SIZE = 1024; diff --git a/include/flexflow/config.h b/include/flexflow/config.h index d9ba03ae2..a4aa1d35c 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -90,9 +90,11 @@ struct FFHandler { #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) cudnnHandle_t dnn, peft_dnn; cublasHandle_t blas, peft_blas; + cudaStream_t extra_stream1, extra_stream2; #else miopenHandle_t dnn, peft_dnn; hipblasHandle_t blas, peft_blas; + hipStream_t extra_stream1, extra_stream2; #endif void *workSpace; size_t workSpaceSize; diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index 5cfb8e485..fc454b49a 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -106,9 +106,11 @@ struct Request { std::vector finetuning_losses; // bwd state int last_processed_bwd_layer = INT_MAX; + int max_samples = -1; // -1: no limit // how many gradient accumulation steps to do before updating the weights. // if left as -1, it will be set to the number of entries in the dataset int gradient_accumulation_steps = -1; + int num_logging_steps = 10; // std::vector finetuning_tokens_per_batch; }; RequestType req_type = REQ_INFERENCE; @@ -118,6 +120,7 @@ struct Request { int benchmarking_tokens = -1; bool add_special_tokens = true; bool warmup = false; + bool ignore_eos = false; Status status = PENDING; long long arrival_time_us = 0; @@ -401,7 +404,7 @@ class RequestManager { // peft std::unordered_map peft_configs; - int max_lora_rank = 32; + int max_lora_rank = 16; int max_concurrent_adapters = 0; // peft benchmarking bool enable_peft_finetuning = false; diff --git a/inference/flexllm/peft_train.cc b/inference/flexllm/peft_train.cc index 88dd01cec..9648e4376 100644 --- a/inference/flexllm/peft_train.cc +++ b/inference/flexllm/peft_train.cc @@ -53,7 +53,10 @@ void parse_input_args(char **argv, int &max_tokens_per_batch, int &max_sequence_length, int &num_kv_cache_slots, + int &max_finetuning_samples, int &max_training_epochs, + int &gradient_accumulation_steps, + int &num_logging_steps, int &num_layers_per_finetuning_step, bool &run_warmup) { for (int i = 1; i < argc; i++) { @@ -143,10 +146,22 @@ void parse_input_args(char **argv, num_kv_cache_slots = std::stoi(argv[++i]); continue; } - if (!strcmp(argv[i], "--max-training-steps")) { + if (!strcmp(argv[i], "--max-finetuning-samples")) { + max_finetuning_samples = std::stoi(argv[++i]); + continue; + } + if (!strcmp(argv[i], "--max-training-epochs")) { max_training_epochs = std::stoi(argv[++i]); continue; } + if (!strcmp(argv[i], "--gradient-accumulation-steps")) { + gradient_accumulation_steps = std::stoi(argv[++i]); + continue; + } + if (!strcmp(argv[i], "--num-logging-steps")) { + num_logging_steps = std::stoi(argv[++i]); + continue; + } if (!strcmp(argv[i], "--num-layers-per-finetuning-step")) { num_layers_per_finetuning_step = std::stoi(argv[++i]); continue; @@ -213,6 +228,7 @@ std::vector load_trace(nlohmann::ordered_json prompt_json, int prompt_length = entry["prompt_length"]; int response_length = entry["response_length"]; std::string text = entry["prompt"]; + double arrival_time_us = entry["arrival_time"].get() * 1000000; Request inference_req; if 
(benchmarking) { @@ -222,6 +238,8 @@ std::vector load_trace(nlohmann::ordered_json prompt_json, inference_req.prompt = text; } inference_req.max_new_tokens = response_length; + inference_req.ignore_eos = true; + inference_req.arrival_time_us = arrival_time_us; requests.push_back(inference_req); } return requests; @@ -283,11 +301,15 @@ void FlexFlow::top_level_task(Task const *task, int max_requests_per_batch = 1; int max_tokens_per_batch = 128; int max_sequence_length = 256; - int max_training_epochs = 2; + int max_finetuning_samples = -1; // -1: no limit + int max_training_epochs = 1; + int gradient_accumulation_steps = 8; + int num_logging_steps = 10; bool enable_peft_finetuning = true; int num_layers_per_finetuning_step = -1; bool run_warmup = false; int num_kv_cache_slots = -1; + int rank = 16; InputArgs const &command_args = HighLevelRuntime::get_input_args(); char **argv = command_args.argv; @@ -307,7 +329,10 @@ void FlexFlow::top_level_task(Task const *task, max_tokens_per_batch, max_sequence_length, num_kv_cache_slots, + max_finetuning_samples, max_training_epochs, + gradient_accumulation_steps, + num_logging_steps, num_layers_per_finetuning_step, run_warmup); enable_peft_finetuning = file_paths.dataset_file_path.empty() ? false : true; @@ -391,7 +416,6 @@ void FlexFlow::top_level_task(Task const *task, "Invalid LLM model type passed (or no type was passed)."); // load PEFT config - int rank = 16; LoraOptimizerConfig *optim_config = new LoraSGDOptimizerConfig(0.001f); std::vector target_modules = {"down_proj"}; LoraLinearConfig peft_config_finetuning(file_paths.cache_folder_path, @@ -418,6 +442,7 @@ void FlexFlow::top_level_task(Task const *task, model_type, bos_token_id, eos_token_ids, tokenizer_filepath); rm->register_output_filepath(file_paths.output_file_path); rm->set_enable_peft_finetuning(enable_peft_finetuning); + rm->set_max_lora_rank(rank); FFModel model(ffconfig, ffconfig.cpu_offload); model.set_num_kv_cache_pages( @@ -482,9 +507,10 @@ void FlexFlow::top_level_task(Task const *task, // Run workload { - std::vector requests = - load_requests(file_paths.prompt_file_path, 128); - + std::vector inference_requests; + if (!file_paths.prompt_file_path.empty()) { + inference_requests = load_requests(file_paths.prompt_file_path, 128); + } // Add fine-tuning request assert(!file_paths.dataset_file_path.empty() && "Dataset file path is required for fine-tuning."); @@ -495,12 +521,17 @@ void FlexFlow::top_level_task(Task const *task, fine_tuning_req.peft_model_id = *peft_model_id_finetuning; fine_tuning_req.peft_finetuning_info.dataset_filepath = file_paths.dataset_file_path; + fine_tuning_req.peft_finetuning_info.max_samples = max_finetuning_samples; fine_tuning_req.peft_finetuning_info.max_training_epochs = max_training_epochs; - requests.push_back(fine_tuning_req); + fine_tuning_req.peft_finetuning_info.gradient_accumulation_steps = + gradient_accumulation_steps; + fine_tuning_req.peft_finetuning_info.num_logging_steps = num_logging_steps; + std::vector finetuning_requests; + finetuning_requests.push_back(fine_tuning_req); std::cout << "----------inference started--------------" << std::endl; - std::vector result = model.generate(requests); + std::vector result = model.generate_online(inference_requests, finetuning_requests); std::cout << "----------inference finished--------------" << std::endl; } diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc index cd5c68253..09c79ce51 100644 --- a/inference/incr_decoding/incr_decoding.cc +++ 
b/inference/incr_decoding/incr_decoding.cc @@ -51,6 +51,7 @@ void parse_input_args(char **argv, int &max_tokens_per_batch, int &max_sequence_length, int &num_kv_cache_slots, + bool &ignore_eos, int &max_length, bool &run_warmup) { for (int i = 1; i < argc; i++) { @@ -131,6 +132,10 @@ void parse_input_args(char **argv, max_length = std::stoi(argv[++i]); continue; } + if (!strcmp(argv[i], "--ignore-eos")) { + ignore_eos = true; + continue; + } // whether to run warmup if (!strcmp(argv[i], "--warmup")) { run_warmup = true; @@ -273,6 +278,7 @@ void FlexFlow::top_level_task(Task const *task, int max_length = 128; int num_kv_cache_slots = -1; bool run_warmup = false; + bool ignore_eos = false; InputArgs const &command_args = HighLevelRuntime::get_input_args(); char **argv = command_args.argv; @@ -291,6 +297,7 @@ void FlexFlow::top_level_task(Task const *task, max_tokens_per_batch, max_sequence_length, num_kv_cache_slots, + ignore_eos, max_length, run_warmup); @@ -425,6 +432,11 @@ void FlexFlow::top_level_task(Task const *task, std::cout << "----------inference started--------------" << std::endl; std::vector requests = load_requests(file_paths.prompt_file_path, qps, max_length); + if (ignore_eos) { + for (auto &request : requests) { + request.ignore_eos = true; + } + } std::vector result = (qps > 0.0f) ? model.generate_online(requests, {}) diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 37d6f444b..25f08ae7e 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -40,7 +40,6 @@ void LLAMA::create_llama_model(FFModel &ff, << ff.config.enable_peft_finetuning << std::endl; assert(llama_config.hidden_size % llama_config.num_attention_heads == 0 && "Hidden size not divisible by number of attention heads"); - int head_dim = llama_config.hidden_size / llama_config.num_attention_heads; int tot_num_heads = llama_config.num_attention_heads + 2 * llama_config.num_key_value_heads; @@ -98,7 +97,7 @@ void LLAMA::create_llama_model(FFModel &ff, Tensor qkv_proj = ff.dense( att_norm, - head_dim * tot_num_heads, + llama_config.head_dim * tot_num_heads, AC_MODE_NONE, false, // seems like llama does not use bias DT_NONE, // what is this @@ -115,11 +114,11 @@ void LLAMA::create_llama_model(FFModel &ff, case BEAM_SEARCH_MODE: { mha = ff.spec_inc_multihead_self_attention( qkv_proj, - llama_config.hidden_size, + llama_config.head_dim*llama_config.num_attention_heads, llama_config.num_attention_heads, llama_config.num_key_value_heads, - head_dim, - head_dim, + llama_config.head_dim, + llama_config.head_dim, 0.0f, /*dropout*/ false, /*add_zero_attn*/ DT_NONE, /*data_type*/ @@ -137,11 +136,11 @@ void LLAMA::create_llama_model(FFModel &ff, case TREE_VERIFY_MODE: { mha = ff.inc_multihead_self_attention_verify( qkv_proj, - llama_config.hidden_size, + llama_config.head_dim*llama_config.num_attention_heads, llama_config.num_attention_heads, llama_config.num_key_value_heads, - head_dim, - head_dim, + llama_config.head_dim, + llama_config.head_dim, 0.0f, /*dropout*/ false, /*add_zero_attn*/ DT_NONE, /*data_type*/ @@ -159,11 +158,11 @@ void LLAMA::create_llama_model(FFModel &ff, case INC_DECODING_MODE: { mha = ff.inc_multihead_self_attention( qkv_proj, - llama_config.hidden_size, + llama_config.head_dim*llama_config.num_attention_heads, llama_config.num_attention_heads, llama_config.num_key_value_heads, - head_dim, - head_dim, + llama_config.head_dim, + llama_config.head_dim, 0.0f, /*dropout*/ false, /*add_zero_attn*/ DT_NONE, /*data_type*/ @@ -300,8 +299,7 @@ void 
LLAMA::create_llama_model(FFModel &ff, // If PEFT is enabled, add LoRA layers if (ff.config.enable_peft) { // todo: add attention projections - std::vector target_modules = { - "qkv_proj", "o_proj", "gate_proj", "down_proj", "up_proj"}; + std::vector target_modules = {"down_proj"}; ff.add_lora_layers(target_modules); } @@ -311,7 +309,7 @@ void LLAMA::create_llama_model(FFModel &ff, llama_config.num_attention_heads, llama_config.num_key_value_heads, llama_config.hidden_size, - head_dim, + llama_config.head_dim, ff.config.tensor_parallelism_degree, use_full_precision); diff --git a/inference/models/llama.h b/inference/models/llama.h index e74f8d52b..f6d400e72 100644 --- a/inference/models/llama.h +++ b/inference/models/llama.h @@ -44,6 +44,12 @@ class LLAMA { num_key_value_heads = num_attention_heads; } hidden_size = model_config["hidden_size"]; + if (model_config.find("head_dim") != model_config.end()) { + head_dim = model_config["head_dim"]; + } else { + assert(hidden_size % num_attention_heads == 0); + head_dim = hidden_size / num_attention_heads; + } rms_norm_eps = model_config["rms_norm_eps"]; intermediate_size = model_config["intermediate_size"]; rotary_embedding_meta.apply_rotary_embedding = true; @@ -89,6 +95,7 @@ class LLAMA { std::cout << "\tnum_key_value_heads: " << num_key_value_heads << std::endl; std::cout << "\thidden_size: " << hidden_size << std::endl; + std::cout << "\thead_dim: " << head_dim << std::endl; std::cout << "\trms_norm_eps: " << rms_norm_eps << std::endl; std::cout << "\tintermediate_size: " << intermediate_size << std::endl; std::cout << "\trotary_embedding_meta: " << rotary_embedding_meta @@ -99,7 +106,7 @@ class LLAMA { int max_beam_width, max_beam_depth; int num_hidden_layers, vocab_size, num_attention_heads, num_key_value_heads, - hidden_size, intermediate_size; + hidden_size, intermediate_size, head_dim; float rms_norm_eps; RotaryEmbeddingMeta rotary_embedding_meta; }; diff --git a/python/flexflow/serve/models/llama.py b/python/flexflow/serve/models/llama.py index 85df0deec..f9099e542 100644 --- a/python/flexflow/serve/models/llama.py +++ b/python/flexflow/serve/models/llama.py @@ -52,6 +52,7 @@ def __init__(self, hf_config): if hf_config.num_key_value_heads is None else hf_config.num_key_value_heads ) + self.head_dim = hf_config.head_dim if "head_dim" in hf_config.__dict__ else (self.hidden_size // self.num_attention_heads) class FlexFlowLLAMA(FlexFlowModel): @@ -92,9 +93,6 @@ def __init__( assert ( self.llama_config.hidden_size % self.llama_config.num_attention_heads == 0 ) - self.head_dim = ( - self.llama_config.hidden_size // self.llama_config.num_attention_heads - ) self.tot_num_heads = ( self.llama_config.num_attention_heads + 2 * self.llama_config.num_key_value_heads @@ -146,7 +144,7 @@ def build_model(self): qkv_proj = self.ffmodel.dense( attn_norm, - self.head_dim * self.tot_num_heads, + self.llama_config.head_dim * self.tot_num_heads, ActiMode.AC_MODE_NONE, False, name=f"layers.{i}.self_attn.qkv_proj", @@ -155,11 +153,11 @@ def build_model(self): if self.mode == InferenceMode.BEAM_SEARCH_MODE: mha = self.ffmodel.spec_inc_multihead_self_attention( qkv_proj, - self.llama_config.hidden_size, + self.llama_config.head_dim*self.llama_config.num_attention_heads, self.llama_config.num_attention_heads, self.llama_config.num_key_value_heads, - self.head_dim, - self.head_dim, + self.llama_config.head_dim, + self.llama_config.head_dim, 0.0, # dropout False, # add_zero_attn DataType.DT_NONE, # data_type @@ -170,11 +168,11 @@ def build_model(self): elif 
self.mode == InferenceMode.TREE_VERIFY_MODE: mha = self.ffmodel.inc_multihead_self_attention_verify( qkv_proj, - self.llama_config.hidden_size, + self.llama_config.head_dim*self.llama_config.num_attention_heads, self.llama_config.num_attention_heads, self.llama_config.num_key_value_heads, - self.head_dim, - self.head_dim, + self.llama_config.head_dim, + self.llama_config.head_dim, 0.0, # dropout False, # add_zero_attn DataType.DT_NONE, # data_type @@ -185,11 +183,11 @@ def build_model(self): elif self.mode == InferenceMode.INC_DECODING_MODE: mha = self.ffmodel.inc_multihead_self_attention( qkv_proj, - self.llama_config.hidden_size, + self.llama_config.head_dim*self.llama_config.num_attention_heads, self.llama_config.num_attention_heads, self.llama_config.num_key_value_heads, - self.head_dim, - self.head_dim, + self.llama_config.head_dim, + self.llama_config.head_dim, 0.0, # dropout False, # add_zero_attn DataType.DT_NONE, # data_type diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 343f6becd..b77b4d289 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -1545,10 +1545,11 @@ struct AttnGeneralParams { int num_kv_heads; int head_dim; int num_tokens; + int first_prefill_req_idx; int peft_req_idx; BatchConfig::PerTokenInfo const *tokenInfos; - int32_t *kv_indptr; - int32_t *kv_page_indices; + int32_t *kv_indptr_pref, *kv_indptr_dec; + int32_t *kv_page_indices_pref, *kv_page_indices_dec; bool const *request_completed; }; @@ -1568,22 +1569,16 @@ __global__ void attention_prep_kernel(AttnGeneralParams general_params, int global_head_idx = (i / half_proj) % tot_num_heads; int token_idx = i / (half_proj * tot_num_heads); bool is_q = global_head_idx < general_params.num_q_heads; - bool is_k = global_head_idx >= general_params.num_q_heads && - global_head_idx < - general_params.num_q_heads + general_params.num_kv_heads; - bool is_v = global_head_idx >= - general_params.num_q_heads + general_params.num_kv_heads; + bool is_k = global_head_idx >= general_params.num_q_heads && global_head_idx < general_params.num_q_heads + general_params.num_kv_heads; + bool is_v = global_head_idx >= general_params.num_q_heads + general_params.num_kv_heads; // one and only one of is_q, is_k, is_v is true - assert((is_q && !is_k && !is_v) || (!is_q && is_k && !is_v) || - (!is_q && !is_k && is_v)); + assert((is_q && !is_k && !is_v) || (!is_q && is_k && !is_v) || (!is_q && !is_k && is_v)); int head_idx = is_q ? global_head_idx : (is_k ? 
global_head_idx - general_params.num_q_heads : global_head_idx - general_params.num_q_heads - general_params.num_kv_heads); - bool is_peft_token = general_params.tokenInfos[token_idx].request_index == - general_params.peft_req_idx; - int token_abs_idx = - general_params.tokenInfos[token_idx].abs_depth_in_request; + bool is_peft_token = general_params.tokenInfos[token_idx].request_index == general_params.peft_req_idx; + int token_abs_idx = general_params.tokenInfos[token_idx].abs_depth_in_request; int src_a_offset = (token_idx * general_params.head_dim * tot_num_heads) + (global_head_idx * general_params.head_dim) + pair_idx; @@ -1648,19 +1643,34 @@ __global__ void attention_prep_kernel(AttnGeneralParams general_params, int const req_idx = general_params.tokenInfos[token_idx].request_index; assert(req_idx != general_params.peft_req_idx && "Attempting to use inference KV cache for PEFT tokens"); - int req_idx_compact = 0; - for (int j = 0; j < req_idx; j++) { - if (!general_params.request_completed[j]) { - req_idx_compact++; + bool is_decode_token = req_idx < general_params.first_prefill_req_idx; + int page_idx = 0; + if (is_decode_token) { + int req_idx_compact = 0; + for (int j = 0; j < req_idx; j++) { + assert(j < general_params.first_prefill_req_idx); + if (!general_params.request_completed[j]) { + req_idx_compact++; + } } + assert(req_idx_compact >= 0 && req_idx_compact <= req_idx && + "Invalid request index"); + int logical_page_idx = token_abs_idx / kPagesize; + page_idx = general_params.kv_page_indices_dec[general_params.kv_indptr_dec[req_idx_compact] + logical_page_idx]; + } else { + int req_idx_compact = 0; + for (int j = general_params.first_prefill_req_idx; j < req_idx; j++) { + assert(j >= 0); + if (!general_params.request_completed[j]) { + req_idx_compact++; + } + } + assert(req_idx_compact >= 0 && req_idx_compact <= req_idx && + "Invalid request index"); + int logical_page_idx = token_abs_idx / kPagesize; + page_idx = general_params.kv_page_indices_pref[general_params.kv_indptr_pref[req_idx_compact] + logical_page_idx]; } - assert(req_idx_compact >= 0 && req_idx_compact <= req_idx && - "Invalid request index"); - int logical_page_idx = token_abs_idx / kPagesize; - int page_idx = - general_params - .kv_page_indices[general_params.kv_indptr[req_idx_compact] + - logical_page_idx]; + int to_k_idx = get_k_entry_offset_verify(token_abs_idx, page_idx, general_params.num_kv_heads, @@ -1709,16 +1719,15 @@ void flashinfer_incr_attention(IncMultiHeadSelfAttentionMeta *m, BatchConfig const *bc, int shard_id, DT *output_ptr, - cudaStream_t stream) { + cudaStream_t main_stream, + cudaEvent_t prep_done) { // global constant parameters uint32_t const num_q_heads = m->num_q_heads; uint32_t const num_kv_heads = m->num_kv_heads; uint32_t const head_dim = m->qProjSize; - uint32_t const batch_size = bc->num_inference_requests(); float const sm_scale = (*m->qk_prod_scaling) ? 
1.0f / sqrt(m->qProjSize) : 1.0f; - assert(batch_size > 0); assert(num_q_heads > 0); assert(num_kv_heads > 0); assert(head_dim > 0); @@ -1729,15 +1738,6 @@ void flashinfer_incr_attention(IncMultiHeadSelfAttentionMeta *m, assert(q != nullptr && "q is null!"); assert(kv != nullptr && "kv is null!"); assert(o != nullptr && "o is null!"); - assert(m->handle.incr_attention_metadata->q_indptr != nullptr && - "q_indptr is null!"); - assert(m->handle.incr_attention_metadata->kv_indices != nullptr && - "kv_indices is null!"); - assert(m->handle.incr_attention_metadata->kv_indptr != nullptr && - "kv_indptr is null!"); - assert(m->handle.incr_attention_metadata->kv_last_page_len != nullptr && - "kv_last_page_len is null!"); - if (m->inference_debugging) { // qTmp: [qProjSize, num_q_heads, num_new_tokens] std::string fpath_q = get_fwd_dbg_folder(m, shard_id) + ".queryTmp.pt"; @@ -1752,90 +1752,204 @@ void flashinfer_incr_attention(IncMultiHeadSelfAttentionMeta *m, torch::save(tensor_kv, fpath_kv.c_str()); } - paged_kv_t paged_kv( + + int batch_size_pref = bc->num_prefill_requests(); + int batch_size_dec = bc->num_decoding_requests(); + assert(batch_size_pref + batch_size_dec == bc->num_inference_requests()); + // decoding + if (batch_size_dec > 0) { + assert(m->handle.incr_attention_metadata->q_indptr_dec != nullptr && + "q_indptr_dec is null!"); + assert(m->handle.incr_attention_metadata->kv_indices_dec != nullptr && + "kv_indices_dec is null!"); + assert(m->handle.incr_attention_metadata->kv_indptr_dec != nullptr && + "kv_indptr_dec is null!"); + assert(m->handle.incr_attention_metadata->kv_last_page_len_dec != nullptr && + "kv_last_page_len_dec is null!"); + paged_kv_t paged_kv( + num_kv_heads, + kPagesize, + head_dim, + batch_size_dec, + QKVLayout::kNHD, + kv, + m->handle.incr_attention_metadata->kv_indices_dec, + m->handle.incr_attention_metadata->kv_indptr_dec, + m->handle.incr_attention_metadata->kv_last_page_len_dec); + + if (m->inference_debugging && false) { + bc->save_to_file(get_fwd_dbg_folder(m, shard_id) + ".batch_config"); + std::string fpath = get_fwd_dbg_folder(m, shard_id) + ".q_indptr_dec"; + save_tensor( + static_cast(m->handle.incr_attention_metadata->q_indptr_dec), + batch_size_dec + 1, + fpath.c_str()); + fpath = get_fwd_dbg_folder(m, shard_id) + ".kv_indptr_dec"; + save_tensor( + static_cast(m->handle.incr_attention_metadata->kv_indptr_dec), + batch_size_dec + 1, + fpath.c_str()); + fpath = get_fwd_dbg_folder(m, shard_id) + ".kv_indices_dec"; + + int num_pages; + checkCUDA( + cudaMemcpy(&num_pages, + m->handle.incr_attention_metadata->kv_indptr_dec + batch_size_dec, + sizeof(int), + cudaMemcpyDeviceToHost)); + save_tensor( + static_cast(m->handle.incr_attention_metadata->kv_indices_dec), + num_pages, + fpath.c_str()); + fpath = get_fwd_dbg_folder(m, shard_id) + ".kv_last_page_len_dec"; + save_tensor(static_cast( + m->handle.incr_attention_metadata->kv_last_page_len_dec), + batch_size_dec, + fpath.c_str()); + } + + assert(m->handle.incr_attention_metadata->decode_handler_collections.count(batch_size_dec) != 0 && "Handler is not initialized"); + BatchDecodeHandler *handler = static_cast(m->handle.incr_attention_metadata->decode_handler_collections[batch_size_dec]); + handler->SetCUDAStream(main_stream); + // printf("obtained handler\n"); + assert(sizeof(DT) == 2 && "FlashInfer only supports half precision"); + // Note that num decoding tokens == num decoding requests + half *q_decode = q; + half *o_decode = o; + DISPATCH_HEADDIM(head_dim, HEAD_DIM, { + // printf("Launching 
BatchDecodeWithPagedKVCacheWrapperDispatched\n"); + cudaError_t result = + BatchDecodeWithPagedKVCacheWrapperDispatched( + handler, + q_decode, + /*q_offset=*/nullptr, + paged_kv, + o_decode, + /*lse=*/nullptr, + num_q_heads, + /*window_left=*/-1, + /*logits_soft_cap=*/0.f, + sm_scale, + /*rope_scale=*/1.f, + /*rope_theta=*/static_cast(1e4), + main_stream); + if (result != cudaSuccess) { + fprintf(stderr, "Failed to run BatchDecodeWithPagedKVCacheWrapperDispatched: %s\n", cudaGetErrorString(result)); + assert(false); + } + }); + } + + // prefilling + if (batch_size_pref > 0) { + cudaStreamWaitEvent(m->handle.extra_stream1, prep_done, 0); + assert(m->handle.incr_attention_metadata->q_indptr_pref != nullptr && + "q_indptr_pref is null!"); + assert(m->handle.incr_attention_metadata->kv_indices_pref != nullptr && + "kv_indices_pref is null!"); + assert(m->handle.incr_attention_metadata->kv_indptr_pref != nullptr && + "kv_indptr_pref is null!"); + assert(m->handle.incr_attention_metadata->kv_last_page_len_pref != nullptr && + "kv_last_page_len_pref is null!"); + paged_kv_t paged_kv( num_kv_heads, kPagesize, head_dim, - batch_size, + batch_size_pref, QKVLayout::kNHD, kv, - m->handle.incr_attention_metadata->kv_indices, - m->handle.incr_attention_metadata->kv_indptr, - m->handle.incr_attention_metadata->kv_last_page_len); - - if (m->inference_debugging && false) { - bc->save_to_file(get_fwd_dbg_folder(m, shard_id) + ".batch_config"); - std::string fpath = get_fwd_dbg_folder(m, shard_id) + ".q_indptr"; - save_tensor( - static_cast(m->handle.incr_attention_metadata->q_indptr), - batch_size + 1, - fpath.c_str()); - fpath = get_fwd_dbg_folder(m, shard_id) + ".kv_indptr"; - save_tensor( - static_cast(m->handle.incr_attention_metadata->kv_indptr), - batch_size + 1, - fpath.c_str()); - fpath = get_fwd_dbg_folder(m, shard_id) + ".kv_indices"; - - int num_pages; - checkCUDA( - cudaMemcpy(&num_pages, - m->handle.incr_attention_metadata->kv_indptr + batch_size, - sizeof(int), - cudaMemcpyDeviceToHost)); - save_tensor( - static_cast(m->handle.incr_attention_metadata->kv_indices), - num_pages, - fpath.c_str()); - fpath = get_fwd_dbg_folder(m, shard_id) + ".kv_last_page_len"; - save_tensor(static_cast( - m->handle.incr_attention_metadata->kv_last_page_len), - batch_size, - fpath.c_str()); - } - - assert(m->handle.incr_attention_metadata->prompt_handler_collections.count( - batch_size) != 0 && - "Handler is not initialized"); - void *handler = - m->handle.incr_attention_metadata->prompt_handler_collections[batch_size]; - // printf("obtained handler\n"); - assert(sizeof(DT) == 2 && "FlashInfer only supports half precision"); - DISPATCH_HEADDIM(head_dim, HEAD_DIM, { - // printf("Launching BatchPrefillWithPagedKVCacheWrapperDispatched\n"); - cudaError_t result = - BatchPrefillWithPagedKVCacheWrapperDispatched( - static_cast(handler), - q, - m->handle.incr_attention_metadata->q_indptr, - /*q_offset=*/nullptr, - paged_kv, - /*custom_mask=*/nullptr, - /*qk_indptr=*/nullptr, - o, - /*lse=*/nullptr, - num_q_heads, - /*window_left=*/-1, - /*logits_soft_cap=*/0.f, - sm_scale, - /*rope_scale=*/1.f, - /*rope_theta=*/static_cast(1e4), - stream); - if (result != cudaSuccess) { - throw std::runtime_error("Failed to run " - "IncrementalDecodingAttentionForwardKernel: " + - std::string(cudaGetErrorString(result))); + m->handle.incr_attention_metadata->kv_indices_pref, + m->handle.incr_attention_metadata->kv_indptr_pref, + m->handle.incr_attention_metadata->kv_last_page_len_pref); + + if (m->inference_debugging && false) { + 
bc->save_to_file(get_fwd_dbg_folder(m, shard_id) + ".batch_config"); + std::string fpath = get_fwd_dbg_folder(m, shard_id) + ".q_indptr_pref"; + save_tensor( + static_cast(m->handle.incr_attention_metadata->q_indptr_pref), + batch_size_pref + 1, + fpath.c_str()); + fpath = get_fwd_dbg_folder(m, shard_id) + ".kv_indptr_pref"; + save_tensor( + static_cast(m->handle.incr_attention_metadata->kv_indptr_pref), + batch_size_pref + 1, + fpath.c_str()); + fpath = get_fwd_dbg_folder(m, shard_id) + ".kv_indices_pref"; + + int num_pages; + checkCUDA( + cudaMemcpy(&num_pages, + m->handle.incr_attention_metadata->kv_indptr_pref + batch_size_pref, + sizeof(int), + cudaMemcpyDeviceToHost)); + save_tensor( + static_cast(m->handle.incr_attention_metadata->kv_indices_pref), + num_pages, + fpath.c_str()); + fpath = get_fwd_dbg_folder(m, shard_id) + ".kv_last_page_len_pref"; + save_tensor(static_cast( + m->handle.incr_attention_metadata->kv_last_page_len_pref), + batch_size_pref, + fpath.c_str()); } - }); + + assert(m->handle.incr_attention_metadata->prompt_handler_collections.count(batch_size_pref) != 0 && "Handler is not initialized"); + BatchPrefillHandler *handler = static_cast(m->handle.incr_attention_metadata->prompt_handler_collections[batch_size_pref]); + handler->SetCUDAStream(m->handle.extra_stream1); + // printf("obtained handler\n"); + assert(sizeof(DT) == 2 && "FlashInfer only supports half precision"); + // Note that num decoding tokens == num decoding requests + half *q_prefill = q + head_dim * num_q_heads * batch_size_dec; + half *o_prefill = o + head_dim * num_q_heads * batch_size_dec; + DISPATCH_HEADDIM(head_dim, HEAD_DIM, { + // printf("Launching BatchPrefillWithPagedKVCacheWrapperDispatched\n"); + cudaError_t result = + BatchPrefillWithPagedKVCacheWrapperDispatched( + handler, + q_prefill, + m->handle.incr_attention_metadata->q_indptr_pref, + /*q_offset=*/nullptr, + paged_kv, + /*custom_mask=*/nullptr, + /*qk_indptr=*/nullptr, + o_prefill, + /*lse=*/nullptr, + num_q_heads, + /*window_left=*/-1, + /*logits_soft_cap=*/0.f, + sm_scale, + /*rope_scale=*/1.f, + /*rope_theta=*/static_cast(1e4), + m->handle.extra_stream1); + if (result != cudaSuccess) { + fprintf(stderr, "Failed to run BatchPrefillWithPagedKVCacheWrapperDispatched: %s\n", cudaGetErrorString(result)); + assert(false); + } + }); + + // ensure the main stream waits until prefilling has finished + cudaEvent_t prefilling_done; + checkCUDA(cudaEventCreate(&prefilling_done)); + checkCUDA(cudaEventRecord(prefilling_done, m->handle.extra_stream1)); + checkCUDA(cudaStreamWaitEvent(main_stream, prefilling_done, 0)); + checkCUDA(cudaEventDestroy(prefilling_done)); + } } template @@ -1844,9 +1958,9 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, int shard_id, DT const *qkv_ptr, DT *output_ptr, - cudaStream_t inf_stream, - cudaStream_t peft_stream) { - + cudaStream_t stream) { + // Step 0: Preparation (shared by inference and finetuning) + // ========================================================================== // qkv_ptr: [qProjSize, tot_num_heads, num_new_tokens] assert(m->qProjSize == m->kProjSize && m->qProjSize == m->vProjSize); size_t tot_num_heads = m->num_q_heads + 2 * m->num_kv_heads; @@ -1862,15 +1976,28 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, if (bc->num_finetuning_fwd_requests() > 0) { peft_req_idx = bc->finetuning_request_index(); } + int first_prefill_req_idx = bc->max_requests_per_batch(); + for (int req_idx = 0; req_idx < bc->max_requests_per_batch(); req_idx++) { + if 
(bc->request_completed[req_idx] || bc->requestsInfo[req_idx].finetuning_request) { + continue; + } + if (bc->requestsInfo[req_idx].prompt_phase) { + first_prefill_req_idx = req_idx; + break; + } + } AttnGeneralParams general_params = { .num_q_heads = m->num_q_heads, .num_kv_heads = m->num_kv_heads, .head_dim = m->qProjSize, .num_tokens = bc->num_active_tokens(), + .first_prefill_req_idx = first_prefill_req_idx, .peft_req_idx = peft_req_idx, .tokenInfos = m->token_infos, - .kv_indptr = m->handle.incr_attention_metadata->kv_indptr, - .kv_page_indices = m->handle.incr_attention_metadata->kv_indices, + .kv_indptr_pref = m->handle.incr_attention_metadata->kv_indptr_pref, + .kv_indptr_dec = m->handle.incr_attention_metadata->kv_indptr_dec, + .kv_page_indices_pref = m->handle.incr_attention_metadata->kv_indices_pref, + .kv_page_indices_dec = m->handle.incr_attention_metadata->kv_indices_dec, .request_completed = m->request_completed}; RopeParams rope_params = { .apply_rotary_embedding = @@ -1896,18 +2023,21 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, attention_prep_kernel<<>>( + stream>>>( general_params, data_pointers, rope_params); + cudaEvent_t prep_done, finetuning_done; + checkCUDA(cudaEventCreate(&prep_done)); + checkCUDA(cudaEventCreate(&finetuning_done)); + checkCUDA(cudaEventRecord(prep_done, stream)); + + // Step 1: Run finetuning FWD on a separate stream + // ========================================================================== if (bc->num_finetuning_fwd_tokens() > 0) { - // wait until preparation has finished - cudaEvent_t prep_done; - cudaEventCreate(&prep_done); - cudaEventRecord(prep_done, inf_stream); - cudaStreamWaitEvent(peft_stream, prep_done, 0); + cudaStreamWaitEvent(m->handle.extra_stream2, prep_done, 0); flash_compute_attention_kernel_peft
( - m, bc, output_ptr, shard_id, peft_stream); + m, bc, output_ptr, shard_id, m->handle.extra_stream2); assert(m->peft_token_infos != nullptr); assert(m->peft_token_infos_size == sizeof(BatchConfig::PerTokenInfo) * @@ -1921,14 +2051,22 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, m->peft_token_infos[prev_steps_tokens + j] = bc->tokensInfo[tokens_previous_requests + j]; } + checkCUDA(cudaEventRecord(finetuning_done, m->handle.extra_stream2)); } - // flashinfer sdpa + // Step 2: Run inference + // ========================================================================== assert(bc->num_finetuning_fwd_tokens() >= 0 && bc->num_finetuning_bwd_tokens() >= 0); if (bc->num_inference_tokens() > 0) { - flashinfer_incr_attention
<DT>(m, bc, shard_id, output_ptr, inf_stream); + flashinfer_incr_attention<DT>
(m, bc, shard_id, output_ptr, stream, prep_done); } + + if (bc->num_finetuning_fwd_tokens() > 0) { + checkCUDA(cudaStreamWaitEvent(stream, finetuning_done, 0)); + } + checkCUDA(cudaEventDestroy(prep_done)); + checkCUDA(cudaEventDestroy(finetuning_done)); } // todo(yingyi): replace with flash-attn @@ -2114,7 +2252,7 @@ void flash_peft_bwd_kernel(IncMultiHeadSelfAttentionMeta *m, 0, peft_stream>>>( input_grad_ptr, - m->complex_input, + m->complex_input_bwd, m->peft_token_infos_device, m->rotary_embedding_meta->rope_theta, (m->rotary_embedding_meta->rope_type == "llama3"), @@ -2152,10 +2290,8 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( int shard_id, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - cudaStream_t inf_stream; - checkCUDA(get_legion_stream(&inf_stream)); - cudaStream_t peft_stream; - checkCUDA(get_legion_stream(&peft_stream)); + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); // cudaEvent_t t_start, t_end; // if (m->profiling) { @@ -2172,8 +2308,7 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( shard_id, input.get_half_ptr(), output.get_half_ptr(), - inf_stream, - peft_stream); + stream); } // else if (input.data_type == DT_BFLOAT16) { // Kernels::IncMultiHeadAttention::inference_kernel(m, @@ -2181,8 +2316,7 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( // shard_id, // input.get_bfloat16_ptr(), // output.get_bfloat16_ptr(), - // inf_stream, - // peft_stream); + // stream); // } else { assert(false && "Unspported data type"); diff --git a/src/runtime/batch_config.cc b/src/runtime/batch_config.cc index aa1718aaf..99d64e907 100644 --- a/src/runtime/batch_config.cc +++ b/src/runtime/batch_config.cc @@ -127,6 +127,28 @@ int BatchConfig::num_inference_requests() const { num_finetuning_bwd_requests(); } +int BatchConfig::num_prefill_requests() const { + int num_prefill_reqs = 0; + for (int i=0; i 0); return max_requests_per_batch() - 1; diff --git a/src/runtime/file_loader.cc b/src/runtime/file_loader.cc index e87594b00..a45ff45a8 100644 --- a/src/runtime/file_loader.cc +++ b/src/runtime/file_loader.cc @@ -232,7 +232,7 @@ void load_attention_weights_to_dense_v2(DT *ptr, weight_filenames.push_back(o_file); } - assert(head_dim == hidden_dim / num_q_heads); + // assert(head_dim == hidden_dim / num_q_heads); int total_num_heads = num_q_heads + 2 * num_kv_heads; int total_heads_per_shard = total_num_heads / tensor_parallelism_degree; diff --git a/src/runtime/model.cu b/src/runtime/model.cu index 4ad47a362..ea9176cbd 100644 --- a/src/runtime/model.cu +++ b/src/runtime/model.cu @@ -92,6 +92,10 @@ FFHandler handle.quantization_type = info->quantization_type; handle.allowTensorOpMathConversion = info->allowTensorOpMathConversion; + // create additional streams for tasks that require more than 1 + checkCUDA(cudaStreamCreate(&handle.extra_stream1)); + checkCUDA(cudaStreamCreate(&handle.extra_stream2)); + // flashinfer handle.incr_attention_metadata = new AttentionMetaData(); assert(handle.incr_attention_metadata != nullptr && @@ -183,19 +187,12 @@ FFHandler handle.batch_config_metadata = nullptr; } - // std::cout << "handle.batch_config_metadata_size: " - // << handle.batch_config_metadata_size << std::endl; - // std::cout << "handle.incr_attention_metadata->mem_size(): " - // << handle.incr_attention_metadata->mem_size() << std::endl; - if (handle.batch_config_metadata_size + - handle.incr_attention_metadata->mem_size()) { + if (handle.incr_attention_metadata->mem_size() > 0) { // allocate memory for offload 
reserve space Memory gpu_mem = get_proc_mem(Machine::get_machine(), task->target_proc); Realm::Rect<1, coord_t> bounds( Realm::Point<1, coord_t>(0), - Realm::Point<1, coord_t>(handle.batch_config_metadata_size + - handle.incr_attention_metadata->mem_size() - - 1)); + Realm::Point<1, coord_t>(handle.incr_attention_metadata->mem_size() - 1)); std::vector field_sizes; field_sizes.push_back(sizeof(char)); Realm::RegionInstance workspaceInst; @@ -207,14 +204,8 @@ FFHandler Realm::ProfilingRequestSet()) .wait(); void *ptr = workspaceInst.pointer_untyped(0, sizeof(char)); - handle.batch_config_metadata = - static_cast(ptr); - handle.incr_attention_metadata->assign_address( - static_cast(static_cast(ptr) + - handle.batch_config_metadata_size), - handle.incr_attention_metadata->mem_size()); + handle.incr_attention_metadata->assign_address(ptr, handle.incr_attention_metadata->mem_size()); } else { - handle.batch_config_metadata = nullptr; handle.incr_attention_metadata->assign_address(nullptr, 0); } diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index e639653bb..d5936bbb1 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -182,6 +182,10 @@ bool RequestManager::load_request_token_ids(Request &request) { /*ignore_comments */ true); for (auto &prompt : dataset_json) { + if (request.peft_finetuning_info.max_samples > 0 && + request.dataset.size() >= request.peft_finetuning_info.max_samples) { + break; + } std::string text = prompt.get(); std::vector input_tokens; input_tokens = this->tokenizer_->Encode(text); @@ -202,7 +206,9 @@ bool RequestManager::load_request_token_ids(Request &request) { } std::cout << "Creating dataset from json file: " << request.peft_finetuning_info.dataset_filepath - << ". Size of dataset: " << request.dataset.size() << std::endl; + << ". Size of dataset: " << request.dataset.size() + << ". 
Max samples: " << request.peft_finetuning_info.max_samples + << std::endl; } if (request.peft_finetuning_info.gradient_accumulation_steps == -1) { request.peft_finetuning_info.gradient_accumulation_steps = @@ -250,10 +256,13 @@ std::ostream &operator<<(std::ostream &os, Request const &req) { os << " status: " << req.peft_finetuning_info.status << "\n"; os << " dataset_filepath: " << req.peft_finetuning_info.dataset_filepath << "\n"; + os << " max_samples: " << req.peft_finetuning_info.max_samples << "\n"; os << " max_training_epochs: " << req.peft_finetuning_info.max_training_epochs << "\n"; os << " completed_training_steps: " << req.peft_finetuning_info.completed_training_steps << "\n"; + os << " num_logging_steps: " + << req.peft_finetuning_info.num_logging_steps << "\n"; os << " dataset_entry_processed_tokens: " << req.peft_finetuning_info.dataset_entry_processed_tokens << "\n"; os << " finetuning_losses: " @@ -791,7 +800,7 @@ bool RequestManager::inf_req_completed(BatchConfig const &old_bc, int i) { // printf("model_type = %d\n", this->model_type); if (request.tokens.size() >= old_bc.requestsInfo[i].max_length) { request_completed = true; - } else if (is_eos_token(request.tokens.back())) { + } else if (is_eos_token(request.tokens.back()) && !request.ignore_eos) { // Encounter EOS token id request_completed = true; } @@ -976,7 +985,7 @@ void RequestManager::handle_completed_inf_req(BatchConfig const &old_bc, GenerationResult &gr = request_generation_results[request.guid]; std::vector output_tokens = std::vector( request.tokens.begin() + gr.input_tokens.size(), request.tokens.end()); - if (is_eos_token(output_tokens.back())) { + if (is_eos_token(output_tokens.back()) && !request.ignore_eos) { // remove the EOS token output_tokens.pop_back(); } @@ -1600,6 +1609,11 @@ void RequestManager::process_finetuning_req_bwd_progress( ((int)request.dataset.size()), request.peft_finetuning_info.max_training_epochs); } + else if (request.peft_finetuning_info.completed_training_steps % request.peft_finetuning_info.num_logging_steps == 0) { + log_req_mgr.print("Completed finetuning step %i/%i", + request.peft_finetuning_info.completed_training_steps, + tot_steps); + } if (request.peft_finetuning_info.completed_training_steps == tot_steps || inference_finished) { handle_completed_finetuning_req(old_bc); diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu index 6644293fa..58305dc3b 100644 --- a/src/runtime/request_manager.cu +++ b/src/runtime/request_manager.cu @@ -90,6 +90,7 @@ void RequestManager::load_tokens_task( void prepare_inference_params_kernel_h( BatchConfig const *batch_config, + bool decoding, std::vector &q_indptr_h, std::vector &kv_indptr_h, std::vector &kv_page_indices_h, @@ -109,12 +110,16 @@ void prepare_inference_params_kernel_h( q_indptr_h.push_back(0); kv_indptr_h.push_back(0); + int batch_size = 0; + for (int req_idx = 0; req_idx < batch_config->max_requests_per_batch(); req_idx++) { if (batch_config->request_completed[req_idx] || - batch_config->requestsInfo[req_idx].finetuning_request) { + batch_config->requestsInfo[req_idx].finetuning_request || + batch_config->requestsInfo[req_idx].prompt_phase == decoding) { continue; } + batch_size++; // q_indptr: first token offset in batch, plus one token at the end // representing the total number of tokens in batch @@ -145,10 +150,17 @@ void prepare_inference_params_kernel_h( } // check sizes - int batch_size = batch_config->num_active_requests() - - batch_config->num_finetuning_fwd_requests() - - 
batch_config->num_finetuning_bwd_requests(); + // printf("num_prefill_reqs: %d\n", batch_config->num_prefill_requests()); + // printf("num_decoding_reqs: %d\n", batch_config->num_decoding_requests()); + // printf("batch_size: %d\n", batch_size); + // printf("decoding: %d\n", decoding); assert(batch_size > 0); + assert(batch_size <= batch_config->num_inference_requests()); + if (decoding) { + assert(batch_size == batch_config->num_decoding_requests()); + } else { + assert(batch_size == batch_config->num_prefill_requests()); + } // printf("q_indptr_h size: %lu\n", q_indptr_h.size()); // printf("kv_indptr_h size: %lu\n", kv_indptr_h.size()); // printf("kv_page_indices_h size: %lu\n", kv_page_indices_h.size()); @@ -241,103 +253,211 @@ void RequestManager::load_batch_config_task( } // load attention metadata - int batch_size = batch_config->num_active_requests() - - batch_config->num_finetuning_fwd_requests() - - batch_config->num_finetuning_bwd_requests(); - if (batch_config->get_mode() == INC_DECODING_MODE && batch_size > 0 && + if (batch_config->get_mode() == INC_DECODING_MODE && batch_config->num_inference_requests() > 0 && handle.incr_attention_metadata->enabled()) { - // assert(handle.incr_attention_metadata->enabled()); - // printf("Entering here, handler: %p\n", handle.incr_attention_metadata); - std::vector q_indptr_h; - std::vector kv_indptr_h; - std::vector kv_page_indices_h; - std::vector kv_last_page_len_h; - // calculate the attention meta data - prepare_inference_params_kernel_h(batch_config, - q_indptr_h, - kv_indptr_h, - kv_page_indices_h, - kv_last_page_len_h); - checkCUDA(cudaMemcpyAsync(handle.incr_attention_metadata->q_indptr, - q_indptr_h.data(), - sizeof(int32_t) * q_indptr_h.size(), - cudaMemcpyHostToDevice, - stream)); - checkCUDA(cudaMemcpyAsync(handle.incr_attention_metadata->kv_indptr, - kv_indptr_h.data(), - sizeof(int32_t) * kv_indptr_h.size(), - cudaMemcpyHostToDevice, - stream)); - checkCUDA(cudaMemcpyAsync(handle.incr_attention_metadata->kv_indices, - kv_page_indices_h.data(), - sizeof(int32_t) * kv_page_indices_h.size(), - cudaMemcpyHostToDevice, - stream)); - checkCUDA(cudaMemcpyAsync(handle.incr_attention_metadata->kv_last_page_len, - kv_last_page_len_h.data(), - sizeof(int32_t) * kv_last_page_len_h.size(), - cudaMemcpyHostToDevice, - stream)); - // prepare attention forward handler - if (handle.incr_attention_metadata->prompt_handler_collections.count( - batch_size) == 0) { - handle.incr_attention_metadata->prompt_handler_collections[batch_size] = - static_cast(new flashinfer::BatchPrefillHandler(true)); + + int num_prefill_reqs = batch_config->num_prefill_requests(); + int num_decoding_reqs = batch_config->num_decoding_requests(); + + // printf("num_prefill_reqs: %d\n", num_prefill_reqs); + // printf("num_decoding_reqs: %d\n", num_decoding_reqs); + + // 1. 
prepare the indptrs for decoding requests, which occupy the first section in the batch + // ================================================================ + if (num_decoding_reqs > 0) { + std::vector q_indptr_h; + std::vector kv_indptr_h; + std::vector kv_page_indices_h; + std::vector kv_last_page_len_h; + prepare_inference_params_kernel_h(batch_config, + /*decoding=*/true, + q_indptr_h, + kv_indptr_h, + kv_page_indices_h, + kv_last_page_len_h); + checkCUDA(cudaMemcpyAsync(handle.incr_attention_metadata->q_indptr_dec, + q_indptr_h.data(), + sizeof(int32_t) * q_indptr_h.size(), + cudaMemcpyHostToDevice, + stream)); + checkCUDA(cudaMemcpyAsync(handle.incr_attention_metadata->kv_indptr_dec, + kv_indptr_h.data(), + sizeof(int32_t) * kv_indptr_h.size(), + cudaMemcpyHostToDevice, + stream)); + checkCUDA(cudaMemcpyAsync(handle.incr_attention_metadata->kv_indices_dec, + kv_page_indices_h.data(), + sizeof(int32_t) * kv_page_indices_h.size(), + cudaMemcpyHostToDevice, + stream)); + checkCUDA(cudaMemcpyAsync(handle.incr_attention_metadata->kv_last_page_len_dec, + kv_last_page_len_h.data(), + sizeof(int32_t) * kv_last_page_len_h.size(), + cudaMemcpyHostToDevice, + stream)); + // prepare attention forward handler + if (handle.incr_attention_metadata->decode_handler_collections.count(num_decoding_reqs) == 0) { + handle.incr_attention_metadata->decode_handler_collections[num_decoding_reqs] = static_cast(new flashinfer::BatchDecodeHandler(true, num_decoding_reqs)); + // printf("BatchDecodeHandler %p\n", handle.incr_attention_metadata->decode_handler_collections[num_decoding_reqs]); + } + BatchDecodeHandler *handler = static_cast(handle.incr_attention_metadata->decode_handler_collections[num_decoding_reqs]); + assert(handler != nullptr && "BatchDecodeHandler is null"); + + handler->SetCUDAStream(stream); + // static int step=0; + PageManager *pm = PageManager::get_page_manager(); + // printf("BatchPrefillHandler %p\n", handler); + // std::cout << "STEP " << step << ": " << *pm << std::endl; + // step+=1; + // std::cout << "batch_config: " << *batch_config << std::endl; + // std::cout << "q_indptr_h: "; + // for (int i = 0; i < q_indptr_h.size(); i++) { + // std::cout << q_indptr_h[i] << " "; + // } + // std::cout << std::endl; + // std::cout << "kv_indptr_h: "; + // for (int i = 0; i < kv_indptr_h.size(); i++) { + // std::cout << kv_indptr_h[i] << " "; + // } + // std::cout << std::endl; + // std::cout << "kv_page_indices_h: "; + // for (int i = 0; i < kv_page_indices_h.size(); i++) { + // std::cout << kv_page_indices_h[i] << " "; + // } + // std::cout << std::endl; + // std::cout << "kv_last_page_len_h: "; + // for (int i = 0; i < kv_last_page_len_h.size(); i++) { + // std::cout << kv_last_page_len_h[i] << " "; + // } + // std::cout << std::endl; + // std::cout << "batch_size: " << batch_size << std::endl; + + // std::cout << "num_q_heads: " << + // handle.incr_attention_metadata->num_q_heads() << std::endl; std::cout << + // "num_kv_heads: " << handle.incr_attention_metadata->num_kv_heads() << + // std::endl; std::cout << "head_dim: " << + // handle.incr_attention_metadata->head_dim() << std::endl; std::cout << + // "tokens_per_page: " << pm->get_tokens_per_page() << std::endl; std::cout + // << "float_workspace_size: " << + // handle.incr_attention_metadata->float_workspace_size << std::endl; + // std::cout << "int_workspace_size: " << + // handle.incr_attention_metadata->int_workspace_size << std::endl; + DISPATCH_HEADDIM( handle.incr_attention_metadata->head_dim(), HEAD_DIM, { + 
handler->BeginForwardDispatched( + static_cast(handle.incr_attention_metadata->float_workspace_dec), + handle.incr_attention_metadata->float_workspace_size, + static_cast(handle.incr_attention_metadata->int_workspace_dec), + handle.incr_attention_metadata->int_workspace_size, + static_cast(kv_indptr_h.data()), + static_cast(kv_last_page_len_h.data()), + num_decoding_reqs, + handle.incr_attention_metadata->num_q_heads(), + handle.incr_attention_metadata->num_kv_heads(), + pm->get_tokens_per_page()); + }); + } + + // 2. prepare the indptrs for prefilling requests, which occupy the second section in the batch + // ================================================================ + if (num_prefill_reqs > 0) { + std::vector q_indptr_h; + std::vector kv_indptr_h; + std::vector kv_page_indices_h; + std::vector kv_last_page_len_h; + prepare_inference_params_kernel_h(batch_config, + /*decoding=*/false, + q_indptr_h, + kv_indptr_h, + kv_page_indices_h, + kv_last_page_len_h); + checkCUDA(cudaMemcpyAsync(handle.incr_attention_metadata->q_indptr_pref, + q_indptr_h.data(), + sizeof(int32_t) * q_indptr_h.size(), + cudaMemcpyHostToDevice, + stream)); + checkCUDA(cudaMemcpyAsync(handle.incr_attention_metadata->kv_indptr_pref, + kv_indptr_h.data(), + sizeof(int32_t) * kv_indptr_h.size(), + cudaMemcpyHostToDevice, + stream)); + checkCUDA(cudaMemcpyAsync(handle.incr_attention_metadata->kv_indices_pref, + kv_page_indices_h.data(), + sizeof(int32_t) * kv_page_indices_h.size(), + cudaMemcpyHostToDevice, + stream)); + checkCUDA(cudaMemcpyAsync(handle.incr_attention_metadata->kv_last_page_len_pref, + kv_last_page_len_h.data(), + sizeof(int32_t) * kv_last_page_len_h.size(), + cudaMemcpyHostToDevice, + stream)); + // prepare attention forward handler + if (handle.incr_attention_metadata->prompt_handler_collections.count(num_prefill_reqs) == 0) { + handle.incr_attention_metadata->prompt_handler_collections[num_prefill_reqs] = static_cast(new flashinfer::BatchPrefillHandler(true)); + // printf("BatchPrefillHandler %p\n", handle.incr_attention_metadata->prompt_handler_collections[num_prefill_reqs]); + } + BatchPrefillHandler *handler = static_cast(handle.incr_attention_metadata->prompt_handler_collections[num_prefill_reqs]); + assert(handler != nullptr && "BatchPrefillHandler is null"); + + handler->SetCUDAStream(stream); + // static int step=0; + PageManager *pm = PageManager::get_page_manager(); + // printf("BatchPrefillHandler %p\n", handler); + // std::cout << "STEP " << step << ": " << *pm << std::endl; + // step+=1; + // std::cout << "batch_config: " << *batch_config << std::endl; + // std::cout << "q_indptr_h: "; + // for (int i = 0; i < q_indptr_h.size(); i++) { + // std::cout << q_indptr_h[i] << " "; + // } + // std::cout << std::endl; + // std::cout << "kv_indptr_h: "; + // for (int i = 0; i < kv_indptr_h.size(); i++) { + // std::cout << kv_indptr_h[i] << " "; + // } + // std::cout << std::endl; + // std::cout << "kv_page_indices_h: "; + // for (int i = 0; i < kv_page_indices_h.size(); i++) { + // std::cout << kv_page_indices_h[i] << " "; + // } + // std::cout << std::endl; + // std::cout << "kv_last_page_len_h: "; + // for (int i = 0; i < kv_last_page_len_h.size(); i++) { + // std::cout << kv_last_page_len_h[i] << " "; + // } + // std::cout << std::endl; + // std::cout << "batch_size: " << batch_size << std::endl; + + // std::cout << "num_q_heads: " << + // handle.incr_attention_metadata->num_q_heads() << std::endl; std::cout << + // "num_kv_heads: " << handle.incr_attention_metadata->num_kv_heads() << + // 
std::endl; std::cout << "head_dim: " << + // handle.incr_attention_metadata->head_dim() << std::endl; std::cout << + // "tokens_per_page: " << pm->get_tokens_per_page() << std::endl; std::cout + // << "float_workspace_size: " << + // handle.incr_attention_metadata->float_workspace_size << std::endl; + // std::cout << "int_workspace_size: " << + // handle.incr_attention_metadata->int_workspace_size << std::endl; + + handler->BeginForward( + static_cast(handle.incr_attention_metadata->float_workspace_pref), + handle.incr_attention_metadata->float_workspace_size, + static_cast(handle.incr_attention_metadata->int_workspace_pref), + handle.incr_attention_metadata->int_workspace_size, + static_cast(q_indptr_h.data()), + static_cast(kv_indptr_h.data()), + num_prefill_reqs, + handle.incr_attention_metadata->num_q_heads(), + handle.incr_attention_metadata->num_kv_heads(), + handle.incr_attention_metadata->head_dim(), + pm->get_tokens_per_page()); } - BatchPrefillHandler *handler = static_cast( - handle.incr_attention_metadata->prompt_handler_collections[batch_size]); - handler->SetCUDAStream(stream); - // static int step=0; - PageManager *pm = PageManager::get_page_manager(); - // printf("BatchPrefillHandler %p\n", handler); - // std::cout << "STEP " << step << ": " << *pm << std::endl; - // step+=1; - // std::cout << "batch_config: " << *batch_config << std::endl; - // std::cout << "q_indptr_h: "; - // for (int i = 0; i < q_indptr_h.size(); i++) { - // std::cout << q_indptr_h[i] << " "; - // } - // std::cout << std::endl; - // std::cout << "kv_indptr_h: "; - // for (int i = 0; i < kv_indptr_h.size(); i++) { - // std::cout << kv_indptr_h[i] << " "; - // } - // std::cout << std::endl; - // std::cout << "kv_page_indices_h: "; - // for (int i = 0; i < kv_page_indices_h.size(); i++) { - // std::cout << kv_page_indices_h[i] << " "; - // } - // std::cout << std::endl; - // std::cout << "kv_last_page_len_h: "; - // for (int i = 0; i < kv_last_page_len_h.size(); i++) { - // std::cout << kv_last_page_len_h[i] << " "; - // } - // std::cout << std::endl; - // std::cout << "batch_size: " << batch_size << std::endl; - - // std::cout << "num_q_heads: " << - // handle.incr_attention_metadata->num_q_heads() << std::endl; std::cout << - // "num_kv_heads: " << handle.incr_attention_metadata->num_kv_heads() << - // std::endl; std::cout << "head_dim: " << - // handle.incr_attention_metadata->head_dim() << std::endl; std::cout << - // "tokens_per_page: " << pm->get_tokens_per_page() << std::endl; std::cout - // << "float_workspace_size: " << - // handle.incr_attention_metadata->float_workspace_size << std::endl; - // std::cout << "int_workspace_size: " << - // handle.incr_attention_metadata->int_workspace_size << std::endl; - - handler->BeginForward( - static_cast(handle.incr_attention_metadata->float_workspace), - handle.incr_attention_metadata->float_workspace_size, - static_cast(handle.incr_attention_metadata->int_workspace), - handle.incr_attention_metadata->int_workspace_size, - static_cast(q_indptr_h.data()), - static_cast(kv_indptr_h.data()), - batch_size, - handle.incr_attention_metadata->num_q_heads(), - handle.incr_attention_metadata->num_kv_heads(), - handle.incr_attention_metadata->head_dim(), - pm->get_tokens_per_page()); } }
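For readers of the attention_config.h hunks above, here is a small host-only layout check of the doubled buffer that get_mem_size() and assign_address() now produce: decode index arrays, then prefill index arrays, then the decode workspace, then the prefill workspace. All sizes below are illustrative placeholders, not the real 1 MB-minimum indices_size or the 128 MB / 8 MB FlashInfer workspaces, and the 16-byte alignment round-up is ignored, so this is a sketch of the arithmetic only, not FlexFlow code.

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      size_t const indices_size = 1024;           // int32 entries per section (placeholder)
      size_t const float_ws = 64 * 1024;          // float workspace bytes (placeholder)
      size_t const int_ws = 16 * 1024;            // int workspace bytes (placeholder)
      size_t const workspace_size = float_ws + int_ws;
      // Mirrors mem_size_ = 2 * (sizeof(int32_t) * indices_size + workspace_size).
      size_t const mem_size = 2 * (sizeof(int32_t) * indices_size + workspace_size);

      std::vector<char> arena(mem_size);
      char *ptr = arena.data();

      // Decode and prefill index arrays come first, back to back.
      int32_t *indices_dec = reinterpret_cast<int32_t *>(ptr);
      int32_t *indices_pref = indices_dec + indices_size;
      // Then the decode workspace (float part followed by int part), then the prefill one.
      void *workspace_dec = ptr + 2 * sizeof(int32_t) * indices_size;
      void *int_workspace_dec = static_cast<char *>(workspace_dec) + float_ws;
      void *workspace_pref = static_cast<char *>(workspace_dec) + workspace_size;
      void *int_workspace_pref = static_cast<char *>(workspace_pref) + float_ws;

      printf("arena bytes: %zu\n", mem_size);
      printf("prefill indices at byte offset %zu\n",
             static_cast<size_t>(reinterpret_cast<char *>(indices_pref) - ptr));
      printf("decode int workspace at byte offset %zu\n",
             static_cast<size_t>(static_cast<char *>(int_workspace_dec) - ptr));
      printf("prefill workspace ends at byte offset %zu\n",
             static_cast<size_t>(static_cast<char *>(int_workspace_pref) + int_ws - ptr));
      return 0;
    }

With these placeholder sizes the prefill workspace ends exactly at mem_size, which is the property the doubled formula is meant to guarantee.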
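The inc_multihead_self_attention.cu changes rely on a fork/join pattern across CUDA streams: the main stream records prep_done after the shared attention_prep_kernel, the prefill path (extra_stream1) and the finetuning path (extra_stream2) wait on that event, and the main stream later waits on their completion events before downstream kernels run. A minimal host-only sketch of that pattern follows; it uses cudaMemsetAsync as stand-in work and plain asserts instead of FlexFlow's checkCUDA, so it is illustrative only.

    #include <cuda_runtime.h>
    #include <cassert>
    #include <cstdio>

    int main() {
      size_t const bytes = 1 << 20;
      void *buf = nullptr;
      assert(cudaMalloc(&buf, bytes) == cudaSuccess);

      cudaStream_t main_stream, side_stream;
      assert(cudaStreamCreate(&main_stream) == cudaSuccess);
      assert(cudaStreamCreate(&side_stream) == cudaSuccess);

      cudaEvent_t prep_done, side_done;
      assert(cudaEventCreate(&prep_done) == cudaSuccess);
      assert(cudaEventCreate(&side_done) == cudaSuccess);

      // Step 0: shared preparation on the main stream (stand-in for attention_prep_kernel).
      assert(cudaMemsetAsync(buf, 0, bytes, main_stream) == cudaSuccess);
      assert(cudaEventRecord(prep_done, main_stream) == cudaSuccess);

      // Side work (stand-in for prefill or finetuning FWD) must not start before prep finishes.
      assert(cudaStreamWaitEvent(side_stream, prep_done, 0) == cudaSuccess);
      assert(cudaMemsetAsync(buf, 1, bytes, side_stream) == cudaSuccess);
      assert(cudaEventRecord(side_done, side_stream) == cudaSuccess);

      // The main stream may issue independent work here (e.g. decode attention),
      // then joins before anything that consumes the side stream's output.
      assert(cudaStreamWaitEvent(main_stream, side_done, 0) == cudaSuccess);
      assert(cudaMemsetAsync(buf, 2, bytes, main_stream) == cudaSuccess);

      assert(cudaStreamSynchronize(main_stream) == cudaSuccess);
      printf("fork/join across two streams completed\n");

      cudaEventDestroy(prep_done);
      cudaEventDestroy(side_done);
      cudaStreamDestroy(main_stream);
      cudaStreamDestroy(side_stream);
      cudaFree(buf);
      return 0;
    }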
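Finally, the four arrays that prepare_inference_params_kernel_h() now fills twice per batch (once for the decode section, once for the prefill section) follow the paged-KV convention consumed by FlashInfer's paged_kv_t: q_indptr is a prefix sum of query tokens per request, kv_indptr a prefix sum of KV pages per request, kv_page_indices the flattened physical page ids, and kv_last_page_len the number of valid tokens in each request's last page. The toy program below builds them for two hypothetical requests; RequestView, kPageSize, and the page ids are assumptions for illustration, not FlexFlow types.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct RequestView {          // hypothetical per-request view
      int num_query_tokens;       // 1 for a decode step, prompt length for prefill
      int seq_len;                // total tokens cached so far (including this step)
      std::vector<int32_t> pages; // physical page ids assigned to this request
    };

    constexpr int kPageSize = 64; // tokens per KV page (assumed)

    void build_paged_kv_metadata(std::vector<RequestView> const &reqs,
                                 std::vector<int32_t> &q_indptr,
                                 std::vector<int32_t> &kv_indptr,
                                 std::vector<int32_t> &kv_page_indices,
                                 std::vector<int32_t> &kv_last_page_len) {
      q_indptr = {0};
      kv_indptr = {0};
      for (auto const &r : reqs) {
        q_indptr.push_back(q_indptr.back() + r.num_query_tokens);
        int num_pages = (r.seq_len + kPageSize - 1) / kPageSize;
        kv_indptr.push_back(kv_indptr.back() + num_pages);
        for (int p = 0; p < num_pages; p++) {
          kv_page_indices.push_back(r.pages[p]);
        }
        int last = r.seq_len % kPageSize;
        kv_last_page_len.push_back(last == 0 ? kPageSize : last);
      }
    }

    int main() {
      // One decode request (70 cached tokens) and one prefill request (10-token prompt).
      std::vector<RequestView> reqs = {{1, 70, {3, 7}}, {10, 10, {5}}};
      std::vector<int32_t> q_indptr, kv_indptr, kv_page_indices, kv_last_page_len;
      build_paged_kv_metadata(reqs, q_indptr, kv_indptr, kv_page_indices, kv_last_page_len);
      printf("q_indptr: %d %d %d\n", (int)q_indptr[0], (int)q_indptr[1], (int)q_indptr[2]);
      printf("kv_indptr: %d %d %d\n", (int)kv_indptr[0], (int)kv_indptr[1], (int)kv_indptr[2]);
      printf("kv_last_page_len: %d %d\n", (int)kv_last_page_len[0], (int)kv_last_page_len[1]);
      return 0;
    }

With kPageSize = 64, the 70-token decode request maps to 2 pages with 6 valid tokens in its last page, and the 10-token prefill request maps to 1 page with 10; in the patch this construction is run once with decoding=true and once with decoding=false so the two sections of the batch get independent indptr sets.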