133 changes: 80 additions & 53 deletions include/flexflow/attention_config.h
@@ -61,17 +61,22 @@ class AttentionMetaData {
num_q_heads_ = 0;
num_kv_heads_ = 0;
head_dim_ = 0;
q_indptr = nullptr;
kv_indptr = nullptr;
kv_indices = nullptr;
kv_last_page_len = nullptr;
qk_indptr = nullptr;
custom_mask = nullptr;
workspace = nullptr;
q_indptr_dec = nullptr;
kv_indptr_dec = nullptr;
kv_indices_dec = nullptr;
kv_last_page_len_dec = nullptr;
workspace_dec = nullptr;
float_workspace_dec = nullptr;
int_workspace_dec = nullptr;
q_indptr_pref = nullptr;
kv_indptr_pref = nullptr;
kv_indices_pref = nullptr;
kv_last_page_len_pref = nullptr;
workspace_pref = nullptr;
float_workspace_pref = nullptr;
int_workspace_pref = nullptr;
workspace_size = 0;
float_workspace = nullptr;
float_workspace_size = 0;
int_workspace = nullptr;
int_workspace_size = 0;
mem_size_ = 0;
enabled_ = false;
@@ -80,17 +85,22 @@ class AttentionMetaData {
num_q_heads_ = rhs.num_q_heads_;
num_kv_heads_ = rhs.num_kv_heads_;
head_dim_ = rhs.head_dim_;
q_indptr = rhs.q_indptr;
kv_indptr = rhs.kv_indptr;
kv_indices = rhs.kv_indices;
kv_last_page_len = rhs.kv_last_page_len;
qk_indptr = rhs.qk_indptr;
custom_mask = rhs.custom_mask;
workspace = rhs.workspace;
q_indptr_dec = rhs.q_indptr_dec;
kv_indptr_dec = rhs.kv_indptr_dec;
kv_indices_dec = rhs.kv_indices_dec;
kv_last_page_len_dec = rhs.kv_last_page_len_dec;
workspace_dec = rhs.workspace_dec;
float_workspace_dec = rhs.float_workspace_dec;
int_workspace_dec = rhs.int_workspace_dec;
q_indptr_pref = rhs.q_indptr_pref;
kv_indptr_pref = rhs.kv_indptr_pref;
kv_indices_pref = rhs.kv_indices_pref;
kv_last_page_len_pref = rhs.kv_last_page_len_pref;
workspace_pref = rhs.workspace_pref;
float_workspace_pref = rhs.float_workspace_pref;
int_workspace_pref = rhs.int_workspace_pref;
workspace_size = rhs.workspace_size;
float_workspace = rhs.float_workspace;
float_workspace_size = rhs.float_workspace_size;
int_workspace = rhs.int_workspace;
int_workspace_size = rhs.int_workspace_size;
mem_size_ = rhs.mem_size_;
enabled_ = rhs.enabled_;
@@ -106,30 +116,33 @@ class AttentionMetaData {
size_t max_num_pages = round_up_pages(BatchConfig::max_sequence_length());
size_t indices_size = std::max(
(batch_size + 1) * 4 + max_num_pages * batch_size, 1ul * 1024 * 1024);
size_t custom_mask_size = 0;

float_workspace_size = 128 * 1024 * 1024; // 128 MB
int_workspace_size = 8 * 1024 * 1024; // 8 MB
workspace_size =
float_workspace_size + int_workspace_size; // float + int workspace

mem_size_ = alignTo(sizeof(int32_t) * indices_size +
sizeof(uint8_t) * custom_mask_size + workspace_size,
mem_size_ = alignTo(2 * (sizeof(int32_t) * indices_size + workspace_size),
16);
return mem_size_;
}
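
For intuition, a rough worked example of the sizing above (the numbers are illustrative and assume a batch small enough that the 1 MiB floor dominates the indices term): with batch_size = 8 and max_num_pages = 64, (8 + 1) * 4 + 64 * 8 = 548 entries, so indices_size is clamped to 1,048,576 int32_t entries (4 MiB). The workspace is 128 MiB + 8 MiB = 136 MiB, and doubling both regions for the separate decode and prefill metadata gives mem_size_ = 2 * (4 MiB + 136 MiB) = 280 MiB, which is already 16-byte aligned.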

void assign_address(void *ptr, int size) {
if (ptr == nullptr) {
q_indptr = nullptr;
kv_indptr = nullptr;
kv_indices = nullptr;
kv_last_page_len = nullptr;
qk_indptr = nullptr;
custom_mask = nullptr;
workspace = nullptr;
float_workspace = nullptr;
int_workspace = nullptr;
q_indptr_dec = nullptr;
kv_indptr_dec = nullptr;
kv_indices_dec = nullptr;
kv_last_page_len_dec = nullptr;
workspace_dec = nullptr;
float_workspace_dec = nullptr;
int_workspace_dec = nullptr;
q_indptr_pref = nullptr;
kv_indptr_pref = nullptr;
kv_indices_pref = nullptr;
kv_last_page_len_pref = nullptr;
workspace_pref = nullptr;
float_workspace_pref = nullptr;
int_workspace_pref = nullptr;
return;
}
assert(size >= mem_size() &&
@@ -138,19 +151,26 @@ class AttentionMetaData {
size_t max_num_pages = round_up_pages(BatchConfig::max_sequence_length());
size_t indices_size = std::max(
(batch_size + 1) * 4 + max_num_pages * batch_size, 1ul * 1024 * 1024);
size_t custom_mask_size = 0;

q_indptr = static_cast<int32_t *>(ptr);
kv_indptr = q_indptr + batch_size + 1;
kv_indices = kv_indptr + batch_size + 1;
kv_last_page_len = kv_indices + max_num_pages * batch_size;
qk_indptr = kv_last_page_len + batch_size + 1;
custom_mask = static_cast<uint8_t *>(ptr) + sizeof(int32_t) * indices_size;
workspace = static_cast<void *>(static_cast<uint8_t *>(ptr) +
sizeof(int32_t) * indices_size +
sizeof(uint8_t) * custom_mask_size);
float_workspace = workspace;
int_workspace = static_cast<void *>(static_cast<uint8_t *>(workspace) +

q_indptr_dec = static_cast<int32_t *>(ptr);
kv_indptr_dec = q_indptr_dec + batch_size + 1;
kv_indices_dec = kv_indptr_dec + batch_size + 1;
kv_last_page_len_dec = kv_indices_dec + max_num_pages * batch_size;
q_indptr_pref = static_cast<int32_t *>(ptr) + indices_size;
kv_indptr_pref = q_indptr_pref + batch_size + 1;
kv_indices_pref = kv_indptr_pref + batch_size + 1;
kv_last_page_len_pref = kv_indices_pref + max_num_pages * batch_size;

workspace_dec = static_cast<void *>(static_cast<uint8_t *>(ptr) +
sizeof(int32_t) * indices_size * 2);
float_workspace_dec = workspace_dec;
int_workspace_dec = static_cast<void *>(static_cast<uint8_t *>(workspace_dec) +
float_workspace_size);
workspace_pref = static_cast<void *>(static_cast<uint8_t *>(ptr) +
sizeof(int32_t) * indices_size * 2 +
workspace_size);
float_workspace_pref = workspace_pref;
int_workspace_pref = static_cast<void *>(static_cast<uint8_t *>(workspace_pref) +
float_workspace_size);
}
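
To make the carving above easier to follow, here is a minimal layout map of the single backing buffer, derived directly from assign_address() (byte offsets; indices_size is counted in int32_t entries):

// [0, 4*indices_size)                        decode indices: q_indptr_dec, kv_indptr_dec,
//                                            kv_indices_dec, kv_last_page_len_dec
// [4*indices_size, 8*indices_size)           prefill indices: q_indptr_pref, kv_indptr_pref,
//                                            kv_indices_pref, kv_last_page_len_pref
// [8*indices_size, +float_workspace_size)    decode float workspace
// [...,            +int_workspace_size)      decode int workspace
// [...,            +float_workspace_size)    prefill float workspace
// [...,            +int_workspace_size)      prefill int workspace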

@@ -184,19 +204,26 @@ class AttentionMetaData {
uint32_t num_kv_heads_;
uint32_t head_dim_;

int32_t *q_indptr;
int32_t *kv_indptr;
int32_t *kv_indices;
int32_t *kv_last_page_len;
int32_t *qk_indptr;
uint8_t *custom_mask;
void *workspace;
int32_t *q_indptr_dec;
int32_t *kv_indptr_dec;
int32_t *kv_indices_dec;
int32_t *kv_last_page_len_dec;
uint8_t *custom_mask_dec;
void *workspace_dec;
void *float_workspace_dec;
void *int_workspace_dec;
int32_t *q_indptr_pref;
int32_t *kv_indptr_pref;
int32_t *kv_indices_pref;
int32_t *kv_last_page_len_pref;
uint8_t *custom_mask_pref;
void *workspace_pref;
void *float_workspace_pref;
void *int_workspace_pref;

size_t workspace_size;
void *float_workspace;
size_t float_workspace_size;
void *int_workspace;
size_t int_workspace_size;

size_t mem_size_;

// batchsize -> handler
7 changes: 5 additions & 2 deletions include/flexflow/batch_config.h
@@ -67,9 +67,12 @@ class BatchConfig {
// returns number of inference and finetuning FWD tokens
int num_active_tokens() const;

// returns number of inference-only tokens
// returns number of inference-only tokens (prefill + decode)
int num_inference_tokens() const;
// returns number of inference-only requests (prefill + decode)
int num_inference_requests() const;
int num_prefill_requests() const;
int num_decoding_requests() const;

// return the index where the finetuning request would be stored (i.e. last
// slot of the batch)
@@ -97,7 +100,7 @@ class BatchConfig {
// These maximum values are used for copying BatchConfig
// across workers
static int const MAX_NUM_REQUESTS = 260;
static int const MAX_NUM_TOKENS = 3000;
static int const MAX_NUM_TOKENS = 8192;
static int const MAX_SPEC_TREE_TOKEN_NUM = 64;
static int const MAX_PEFT_CONFIG_SIZE = 1024;

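A minimal usage sketch of the new accessors (hypothetical helper, not part of this diff; it assumes the natural invariant that every inference request is either prefilling or decoding):

void log_batch_mix(BatchConfig const &bc) {
  // inference requests split cleanly into the two phases
  assert(bc.num_inference_requests() ==
         bc.num_prefill_requests() + bc.num_decoding_requests());
  printf("tokens=%d prefill_reqs=%d decode_reqs=%d\n",
         bc.num_inference_tokens(),
         bc.num_prefill_requests(),
         bc.num_decoding_requests());
}
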
2 changes: 2 additions & 0 deletions include/flexflow/config.h
@@ -90,9 +90,11 @@ struct FFHandler {
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
cudnnHandle_t dnn, peft_dnn;
cublasHandle_t blas, peft_blas;
cudaStream_t extra_stream1, extra_stream2;
#else
miopenHandle_t dnn, peft_dnn;
hipblasHandle_t blas, peft_blas;
hipStream_t extra_stream1, extra_stream2;
#endif
void *workSpace;
size_t workSpaceSize;
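
A sketch of how the two new side streams might be set up during handler initialization (an assumption — the actual creation code lives outside this hunk; checkCUDA stands in for the usual error-checking wrapper):

#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
// non-blocking so work on them can overlap with the default stream
checkCUDA(cudaStreamCreateWithFlags(&handler.extra_stream1, cudaStreamNonBlocking));
checkCUDA(cudaStreamCreateWithFlags(&handler.extra_stream2, cudaStreamNonBlocking));
#else
checkCUDA(hipStreamCreateWithFlags(&handler.extra_stream1, hipStreamNonBlocking));
checkCUDA(hipStreamCreateWithFlags(&handler.extra_stream2, hipStreamNonBlocking));
#endif
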
5 changes: 4 additions & 1 deletion include/flexflow/request_manager.h
@@ -106,9 +106,11 @@ struct Request {
std::vector<float> finetuning_losses;
// bwd state
int last_processed_bwd_layer = INT_MAX;
int max_samples = -1; // -1: no limit
// how many gradient accumulation steps to do before updating the weights.
// if left as -1, it will be set to the number of entries in the dataset
int gradient_accumulation_steps = -1;
int num_logging_steps = 10;
// std::vector<int> finetuning_tokens_per_batch;
};
RequestType req_type = REQ_INFERENCE;
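
The schedule implied by the new gradient_accumulation_steps field, as a minimal sketch (hypothetical helper, not code from this PR):

// true on the steps where the optimizer should apply the accumulated gradients
bool should_update_weights(int step, int gradient_accumulation_steps, int dataset_size) {
  // -1 means "accumulate over the whole dataset", per the comment above
  int n = (gradient_accumulation_steps > 0) ? gradient_accumulation_steps : dataset_size;
  return (step + 1) % n == 0;
}
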
@@ -118,6 +120,7 @@ struct Request {
int benchmarking_tokens = -1;
bool add_special_tokens = true;
bool warmup = false;
bool ignore_eos = false;

Status status = PENDING;
long long arrival_time_us = 0;
@@ -401,7 +404,7 @@ class RequestManager {

// peft
std::unordered_map<PEFTModelID, LoraLinearConfig> peft_configs;
int max_lora_rank = 32;
int max_lora_rank = 16;
int max_concurrent_adapters = 0;
// peft benchmarking
bool enable_peft_finetuning = false;
47 changes: 39 additions & 8 deletions inference/flexllm/peft_train.cc
@@ -53,7 +53,10 @@ void parse_input_args(char **argv,
int &max_tokens_per_batch,
int &max_sequence_length,
int &num_kv_cache_slots,
int &max_finetuning_samples,
int &max_training_epochs,
int &gradient_accumulation_steps,
int &num_logging_steps,
int &num_layers_per_finetuning_step,
bool &run_warmup) {
for (int i = 1; i < argc; i++) {
@@ -143,10 +146,22 @@
num_kv_cache_slots = std::stoi(argv[++i]);
continue;
}
if (!strcmp(argv[i], "--max-training-steps")) {
if (!strcmp(argv[i], "--max-finetuning-samples")) {
max_finetuning_samples = std::stoi(argv[++i]);
continue;
}
if (!strcmp(argv[i], "--max-training-epochs")) {
max_training_epochs = std::stoi(argv[++i]);
continue;
}
if (!strcmp(argv[i], "--gradient-accumulation-steps")) {
gradient_accumulation_steps = std::stoi(argv[++i]);
continue;
}
if (!strcmp(argv[i], "--num-logging-steps")) {
num_logging_steps = std::stoi(argv[++i]);
continue;
}
if (!strcmp(argv[i], "--num-layers-per-finetuning-step")) {
num_layers_per_finetuning_step = std::stoi(argv[++i]);
continue;
@@ -213,6 +228,7 @@ std::vector<Request> load_trace(nlohmann::ordered_json prompt_json,
int prompt_length = entry["prompt_length"];
int response_length = entry["response_length"];
std::string text = entry["prompt"];
double arrival_time_us = entry["arrival_time"].get<double>() * 1000000;

Request inference_req;
if (benchmarking) {
@@ -222,6 +238,8 @@ std::vector<Request> load_trace(nlohmann::ordered_json prompt_json,
inference_req.prompt = text;
}
inference_req.max_new_tokens = response_length;
inference_req.ignore_eos = true;
inference_req.arrival_time_us = arrival_time_us;
requests.push_back(inference_req);
}
return requests;
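
For reference, a trace entry that load_trace() accepts, inferred purely from the fields read above (the exact schema is an assumption based on this hunk alone; arrival_time is given in seconds and converted to microseconds):

nlohmann::ordered_json entry = {
    {"prompt", "Summarize the following article ..."},
    {"prompt_length", 12},
    {"response_length", 64},
    {"arrival_time", 0.25} // seconds; stored on the Request as 250000 us
};
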
@@ -283,11 +301,15 @@ void FlexFlow::top_level_task(Task const *task,
int max_requests_per_batch = 1;
int max_tokens_per_batch = 128;
int max_sequence_length = 256;
int max_training_epochs = 2;
int max_finetuning_samples = -1; // -1: no limit
int max_training_epochs = 1;
int gradient_accumulation_steps = 8;
int num_logging_steps = 10;
bool enable_peft_finetuning = true;
int num_layers_per_finetuning_step = -1;
bool run_warmup = false;
int num_kv_cache_slots = -1;
int rank = 16;

InputArgs const &command_args = HighLevelRuntime::get_input_args();
char **argv = command_args.argv;
@@ -307,7 +329,10 @@ void FlexFlow::top_level_task(Task const *task,
max_tokens_per_batch,
max_sequence_length,
num_kv_cache_slots,
max_finetuning_samples,
max_training_epochs,
gradient_accumulation_steps,
num_logging_steps,
num_layers_per_finetuning_step,
run_warmup);
enable_peft_finetuning = file_paths.dataset_file_path.empty() ? false : true;
@@ -391,7 +416,6 @@ void FlexFlow::top_level_task(Task const *task,
"Invalid LLM model type passed (or no type was passed).");

// load PEFT config
int rank = 16;
LoraOptimizerConfig *optim_config = new LoraSGDOptimizerConfig(0.001f);
std::vector<std::string> target_modules = {"down_proj"};
LoraLinearConfig peft_config_finetuning(file_paths.cache_folder_path,
@@ -418,6 +442,7 @@ void FlexFlow::top_level_task(Task const *task,
model_type, bos_token_id, eos_token_ids, tokenizer_filepath);
rm->register_output_filepath(file_paths.output_file_path);
rm->set_enable_peft_finetuning(enable_peft_finetuning);
rm->set_max_lora_rank(rank);

FFModel model(ffconfig, ffconfig.cpu_offload);
model.set_num_kv_cache_pages(
@@ -482,9 +507,10 @@ void FlexFlow::top_level_task(Task const *task,

// Run workload
{
std::vector<Request> requests =
load_requests(file_paths.prompt_file_path, 128);

std::vector<Request> inference_requests;
if (!file_paths.prompt_file_path.empty()) {
inference_requests = load_requests(file_paths.prompt_file_path, 128);
}
// Add fine-tuning request
assert(!file_paths.dataset_file_path.empty() &&
"Dataset file path is required for fine-tuning.");
@@ -495,12 +521,17 @@ void FlexFlow::top_level_task(Task const *task,
fine_tuning_req.peft_model_id = *peft_model_id_finetuning;
fine_tuning_req.peft_finetuning_info.dataset_filepath =
file_paths.dataset_file_path;
fine_tuning_req.peft_finetuning_info.max_samples = max_finetuning_samples;
fine_tuning_req.peft_finetuning_info.max_training_epochs =
max_training_epochs;
requests.push_back(fine_tuning_req);
fine_tuning_req.peft_finetuning_info.gradient_accumulation_steps =
gradient_accumulation_steps;
fine_tuning_req.peft_finetuning_info.num_logging_steps = num_logging_steps;
std::vector<Request> finetuning_requests;
finetuning_requests.push_back(fine_tuning_req);

std::cout << "----------inference started--------------" << std::endl;
std::vector<GenerationResult> result = model.generate(requests);
std::vector<GenerationResult> result = model.generate_online(inference_requests, finetuning_requests);
std::cout << "----------inference finished--------------" << std::endl;
}
