14 changes: 9 additions & 5 deletions include/flexflow/config.h
@@ -88,13 +88,17 @@ struct CombinedBatchConfigMetaStruct {
 
 struct FFHandler {
 #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
-  cudnnHandle_t dnn, peft_dnn;
-  cublasHandle_t blas, peft_blas;
+  cudnnHandle_t dnn, peft_fwd_dnn, peft_bwd_dnn;
+  cublasHandle_t blas, peft_fwd_blas, peft_bwd_blas;
+  cudaStream_t peft_fwd_stream;
+  cudaEvent_t peft_fwd_can_start;
+  cudaEvent_t peft_fwd_done;
 #else
-  miopenHandle_t dnn, peft_dnn;
-  hipblasHandle_t blas, peft_blas;
+  miopenHandle_t dnn, peft_fwd_dnn, peft_bwd_dnn;
+  hipblasHandle_t blas, peft_fwd_blas, peft_bwd_blas;
+  hipStream_t peft_fwd_stream;
+  hipEvent_t peft_fwd_can_start;
+  hipEvent_t peft_fwd_done;
 #endif
   void *workSpace;
   size_t workSpaceSize;
@@ -178,7 +182,7 @@ class FFConfig {
   size_t offload_reserve_space_size;
   DataType quantization_type;
   // PEFT related fields
-  bool enable_peft, enable_peft_finetuning;
+  PeftSupportMode peft_support_mode;
   // Control parallelizable dimensions
   bool only_data_parallel;
   bool enable_sample_parallel;
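
The new `peft_fwd_stream` plus the `peft_fwd_can_start` / `peft_fwd_done` event pair indicate that the PEFT forward pass now runs on its own stream, fenced against the main stream by events. The synchronization code is not shown in this diff; below is a minimal sketch of how such an event fence is typically wired on the CUDA path (the surrounding task logic and function name are assumptions, not the repository's actual code):

```cpp
#include <cuda_runtime.h>
#include "flexflow/config.h"

// Hypothetical sketch: overlap the PEFT forward pass with inference work
// using the event pair added to FFHandler. Error checking omitted.
void run_peft_fwd_overlapped(FFHandler &handle, cudaStream_t main_stream) {
  // The inference stream signals that PEFT forward work may begin.
  cudaEventRecord(handle.peft_fwd_can_start, main_stream);

  // The PEFT stream waits for the signal, runs its kernels, then signals back.
  cudaStreamWaitEvent(handle.peft_fwd_stream, handle.peft_fwd_can_start, 0);
  // ... launch PEFT forward kernels on handle.peft_fwd_stream ...
  cudaEventRecord(handle.peft_fwd_done, handle.peft_fwd_stream);

  // The main stream blocks on the fence, not the host, so the CPU stays free.
  cudaStreamWaitEvent(main_stream, handle.peft_fwd_done, 0);
}
```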
18 changes: 18 additions & 0 deletions include/flexflow/ffconst.h
@@ -84,6 +84,24 @@ enum RequestType {
   REQ_FINETUNING = 4002,
 };
 
+enum PeftSupportMode {
+  PEFT_DISABLED = 5001,
+  // no finetuning supported
+  PEFT_INFERENCE_ONLY = 5002,
+  // finetuning fwd limited by max tokens per batch, bwd layers limited to 1
+  COSERVING = 5003,
+  // finetuning fwd/bwd unlimited, alternating inference and finetuning batches
+  TEMPORAL_SHARING = 5004,
+  // finetuning fwd/bwd unlimited, inference and finetuning work in the same batch (different kernels)
+  SPATIAL_SHARING = 5005,
+  // finetuning fwd limited by max tokens per batch, bwd layers limited to 1. Inference and finetuning work in the same batch (different kernels)
+  SPATIAL_SHARING_LIMITED = 5006,
+  // finetuning fwd limited by max tokens per batch, bwd layers limited to 1. Alternating inference and finetuning batches
+  TEMPORAL_SHARING_LIMITED = 5007,
+  // finetuning fwd/bwd unlimited, inference and finetuning work in separate Legion tasks
+  SPATIAL_SHARING_SEPARATE_TASKS = 5008,
+};
+
 // This is consistent with TASO's OpType
 // https://github.com/jiazhihao/TASO/blob/master/include/taso/ops.h#L75-L138
 enum OperatorType {
4 changes: 4 additions & 0 deletions include/flexflow/ffconst_utils.h
@@ -18,6 +18,10 @@ size_t get_quantization_to_byte_size(DataType type,
 
 std::ostream &operator<<(std::ostream &, OperatorType);
 
+const char* peftSupportModeToString(const PeftSupportMode mode);
+bool peft_finetuning_enabled(const PeftSupportMode peft_support_mode);
+bool peft_enabled(const PeftSupportMode peft_support_mode);
+
 }; // namespace FlexFlow
 
 #endif // _FLEXFLOW_FFCONST_UTILS_H
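
The bodies of these helpers are not included in the diff. A plausible implementation, inferred from the `PeftSupportMode` comments in `ffconst.h` (the exact mode grouping is an assumption):

```cpp
// Hypothetical sketch of the helpers declared above; which modes count as
// "finetuning-capable" is inferred from the PeftSupportMode comments.
const char* peftSupportModeToString(const PeftSupportMode mode) {
  switch (mode) {
    case PEFT_DISABLED: return "PEFT_DISABLED";
    case PEFT_INFERENCE_ONLY: return "PEFT_INFERENCE_ONLY";
    case COSERVING: return "COSERVING";
    case TEMPORAL_SHARING: return "TEMPORAL_SHARING";
    case SPATIAL_SHARING: return "SPATIAL_SHARING";
    case SPATIAL_SHARING_LIMITED: return "SPATIAL_SHARING_LIMITED";
    case TEMPORAL_SHARING_LIMITED: return "TEMPORAL_SHARING_LIMITED";
    case SPATIAL_SHARING_SEPARATE_TASKS: return "SPATIAL_SHARING_SEPARATE_TASKS";
    default: return "UNKNOWN";
  }
}

bool peft_enabled(const PeftSupportMode peft_support_mode) {
  return peft_support_mode != PEFT_DISABLED;
}

bool peft_finetuning_enabled(const PeftSupportMode peft_support_mode) {
  // Every mode except PEFT_DISABLED and PEFT_INFERENCE_ONLY runs finetuning.
  return peft_support_mode != PEFT_DISABLED &&
         peft_support_mode != PEFT_INFERENCE_ONLY;
}
```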
11 changes: 5 additions & 6 deletions include/flexflow/flexflow_c.h
@@ -92,11 +92,10 @@ int flexflow_config_get_tensor_parallelism_degree(flexflow_config_t handle_);
 
 int flexflow_config_get_pipeline_parallelism_degree(flexflow_config_t handle_);
 
-bool flexflow_config_get_enable_peft(flexflow_config_t handle_);
+void flexflow_config_set_peft_support_mode(flexflow_config_t handle_,
+                                           enum PeftSupportMode value);
 
-bool flexflow_config_get_enable_peft_finetuning(flexflow_config_t handle_);
-void flexflow_config_set_enable_peft_finetuning(flexflow_config_t handle_,
-                                                bool value);
+enum PeftSupportMode flexflow_config_get_peft_support_mode(flexflow_config_t handle_);
 
 void flexflow_config_set_data_parallelism_degree(flexflow_config_t handle_,
                                                  int value);
@@ -984,8 +983,8 @@ int flexflow_request_manager_get_max_sequence_length(
 void flexflow_request_manager_set_max_concurrent_adapters(
     flexflow_request_manager_t handle_, int max_concurrent_adapters);
 
-void flexflow_request_manager_set_enable_peft_finetuning(
-    flexflow_request_manager_t handle_, bool enable_peft_finetuning_);
+void flexflow_request_manager_set_peft_support_mode(
+    flexflow_request_manager_t handle_, enum PeftSupportMode peft_support_mode_);
 
 void flexflow_request_manager_set_num_transformers_layers(
     flexflow_request_manager_t handle_, int num_transformers_layers_);
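
Callers that previously toggled the two booleans now select a mode explicitly through the C API. A hedged migration sketch (how the handles are obtained is elided; the function name is hypothetical):

```cpp
#include "flexflow/flexflow_c.h"

// Hypothetical migration sketch: callers that previously toggled
// enable_peft / enable_peft_finetuning now pick one PeftSupportMode
// and apply it to both the config and the request manager.
void configure_peft(flexflow_config_t config, flexflow_request_manager_t rm) {
  // Before: flexflow_config_set_enable_peft_finetuning(config, true);
  flexflow_config_set_peft_support_mode(config, COSERVING);
  flexflow_request_manager_set_peft_support_mode(rm, COSERVING);
}
```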
1 change: 1 addition & 0 deletions include/flexflow/inference.h
@@ -14,6 +14,7 @@
  */
 
 #pragma once
+#include "flexflow/ffconst_utils.h"
 #include "flexflow/batch_config.h"
 #include <string>
 #include <vector>
2 changes: 1 addition & 1 deletion include/flexflow/op_meta.h
@@ -16,7 +16,7 @@ class OpMeta {
   FFHandler handle;
   bool profiling; // Measure the run time of the task
   bool inference_debugging;
-  bool enable_peft_finetuning;
+  PeftSupportMode peft_support_mode;
   int decoding_step;
   int bwd_step;
   char op_name[MAX_OPNAME];
4 changes: 2 additions & 2 deletions include/flexflow/operator.h
@@ -200,7 +200,7 @@ class Op {
   Op(int guid,
      bool profiling,
      bool inference_debugging,
-     bool enable_peft_finetuning,
+     PeftSupportMode peft_support_mode,
      OperatorType otype,
      DataType dtype,
      char const *name,
@@ -474,7 +474,7 @@ class Op {
   int numInputs, numWeights, numOutputs;
   bool profiling;
   bool inference_debugging;
-  bool enable_peft_finetuning;
+  PeftSupportMode peft_support_mode;
   bool add_bias_only_once;
 #ifdef FF_USE_NCCL
   ncclUniqueId ncclId;
8 changes: 3 additions & 5 deletions include/flexflow/ops/argmax.h
@@ -89,22 +89,20 @@ class ArgMax : public Op {
                              MachineView const &pc,
                              CostMetrics &cost_metrics) const override;
   template <typename DT>
-  static void forward_kernel(ArgMaxMeta const *m,
+  static void inference_kernel(ArgMaxMeta const *m,
                              BatchConfig const *bc,
                              DT const *input_ptr,
                              int *indices_ptr,
                              float *prob_ptr,
                              int *parent_ptr,
-                             int length,
-                             int batch_size,
+                             int num_classes,
                              float *loss,
                              ffStream_t stream);
-  static void forward_kernel_wrapper(ArgMaxMeta const *m,
+  static void inference_kernel_wrapper(ArgMaxMeta const *m,
                                      BatchConfig const *bc,
                                      GenericTensorAccessorR const &input,
                                      GenericTensorAccessorW const &indices,
                                      GenericTensorAccessorW const &parent,
-                                     int batch_size,
                                      float *loss);
   Params get_params() const;
 
29 changes: 19 additions & 10 deletions include/flexflow/ops/kernels/embedding_kernels.h
@@ -17,13 +17,13 @@ class EmbeddingMeta : public OpMeta {
 
 namespace Kernels {
 namespace Embedding {
-void forward_kernel_wrapper(EmbeddingMeta const *m,
+void inference_kernel_wrapper(EmbeddingMeta const *m,
+                              BatchConfig const *bc,
                               GenericTensorAccessorR const &input,
                               GenericTensorAccessorW const &output,
                               GenericTensorAccessorR const &weight,
                               int in_dim,
-                              int out_dim,
-                              int batch_size);
+                              int out_dim);
 void backward_kernel_wrapper(EmbeddingMeta const *m,
                              GenericTensorAccessorR const &input,
                              GenericTensorAccessorR const &output,
@@ -34,17 +34,26 @@ void backward_kernel_wrapper(EmbeddingMeta const *m,
 
 namespace Internal {
 template <typename TI, typename TD>
-void forward_kernel(TI const *input_ptr,
+void forward_kernel(EmbeddingMeta const *m,
+                    BatchConfig const *bc,
+                    TI const *input_ptr,
                     TD *output_ptr,
                     TD const *weight_ptr,
                     int in_dim,
                     int out_dim,
-                    int batch_size,
-                    AggrMode aggr,
-                    int outputSize,
-                    ffStream_t stream);
-
-;
+                    // int batch_size,
+                    // AggrMode aggr,
+                    // int outputSize,
+                    cudaStream_t stream);
+template <typename TI, typename TD>
+void forward_kernel_spatial_sharing(EmbeddingMeta const *m,
+                                    BatchConfig const *bc,
+                                    TI const *input_ptr,
+                                    TD *output_ptr,
+                                    TD const *weight_ptr,
+                                    int in_dim,
+                                    int out_dim,
+                                    cudaStream_t main_stream);
 } // namespace Internal
 } // namespace Embedding
 } // namespace Kernels
7 changes: 3 additions & 4 deletions include/flexflow/ops/kernels/linear_kernels.h
@@ -58,8 +58,7 @@ void inference_kernel_wrapper(LinearMeta *m,
                               void const *filter_ptr,
                               void const *bias_ptr,
                               int in_dim,
-                              int out_dim,
-                              int batch_size);
+                              int out_dim);
 void peft_bwd_kernel_wrapper(LinearMeta const *m,
                              BatchConfig const *bc,
                              void *input_grad_ptr,
@@ -83,13 +82,13 @@ bool use_activation(ActiMode mode);
 namespace Internal {
 template <typename DT>
 void inference_kernel(LinearMeta const *m,
+                      BatchConfig const *bc,
                       void const *input_ptr,
                       void *output_ptr,
-                      void const *filter_ptr,
+                      void const *weight_ptr,
                       void const *bias_ptr,
                       int in_dim,
                       int out_dim,
-                      int batch_size,
                       ffStream_t stream);
 template <typename DT>
 void store_peft_activations(LinearMeta const *m,
4 changes: 2 additions & 2 deletions include/flexflow/ops/kernels/softmax_kernels.h
@@ -19,10 +19,10 @@ class SoftmaxMeta : public OpMeta {
               MemoryAllocator &gpu_mem_allocator);
 #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
   cudnnTensorDescriptor_t inputTensor;
-  cudnnTensorDescriptor_t outputTensor;
+  cudnnTensorDescriptor_t outputTensor, outputTensorPeftFwd;
 #else
   miopenTensorDescriptor_t inputTensor;
-  miopenTensorDescriptor_t outputTensor;
+  miopenTensorDescriptor_t outputTensor, outputTensorPeftFwd;
 #endif
   int dim;
   // PEFT related fields
21 changes: 19 additions & 2 deletions include/flexflow/request_manager.h
@@ -168,6 +168,11 @@ class RequestManager {
     SERVING = 1002,
     TERMINATED = 1003,
   };
+  enum PeftTemporalSharingState {
+    INFERENCE = 0,
+    FINETUNING_FWD = 1,
+    FINETUNING_BWD = 2,
+  };
   using TokenId = BatchConfig::TokenId;
 
   RequestManager();
@@ -190,7 +195,8 @@ class RequestManager {
   void set_max_sequence_length(int max_seq_length);
 
   void push_spec_infer_tree_width(int tree_width);
-  void set_enable_peft_finetuning(bool enable_peft_finetuning_);
+  void set_peft_support_mode(PeftSupportMode peft_support_mode_);
+  void update_peft_temporal_sharing_state(void);
   void set_inference_finished(bool finished = true);
   int register_ssm_model(FFModel *model);
   void register_tokenizer(ModelType model_type,
@@ -209,6 +215,9 @@ class RequestManager {
   int get_num_transformer_layers();
   void set_num_layers_per_finetuning_step(int num_layers_per_finetuning_step);
   int get_num_layers_per_finetuning_step();
+  void set_temporal_sharing_frequency(int temporal_sharing_frequency);
+  int get_temporal_sharing_frequency();
+
   void initBitMask(BatchConfig::BitMask &bitmask, int initLength);
   void appendPendingRequest(BatchConfig::BitMask &bitmask, int initLength);
   void appendBitMask(BatchConfig::BitMask &bitmask,
@@ -289,6 +298,9 @@ class RequestManager {
   BatchConfig prepare_next_bwd_batch(BatchConfig &new_bc);
   BatchConfig prepare_next_fwd_batch(BatchConfig const &old_bc,
                                      InferenceResult const &result);
+  void add_inference_work_if_needed(BatchConfig &new_bc,
+                                    BatchConfig const &old_bc);
+  void check_new_bc(BatchConfig const &new_bc);
   BatchConfigFuture
       prepare_next_batch(BatchConfigFuture const &old_bc,
                          InferenceResultFuture const &result,
@@ -407,10 +419,15 @@ class RequestManager {
   int max_lora_rank = 32;
   int max_concurrent_adapters = 0;
   // peft benchmarking
-  bool enable_peft_finetuning = false;
+  PeftSupportMode peft_support_mode = PEFT_DISABLED;
+  PeftTemporalSharingState peft_temporal_sharing_state =
+      PeftTemporalSharingState::INFERENCE;
+  int peft_temporal_sharing_inf_step = 0;
+  BatchConfig ts_saved_old_batch;
   bool inference_finished = false;
   int num_transformer_layers = 0;
   int num_layers_per_finetuning_step = 0;
+  int temporal_sharing_frequency = 10;
 
   // tree width in each speculative step, if not specified 1
   std::vector<int> spec_infer_tree_width;
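
The new `PeftTemporalSharingState` enum, the `peft_temporal_sharing_inf_step` counter, and the default `temporal_sharing_frequency = 10` together suggest a round-robin scheduler: a window of inference batches, then one finetuning forward batch, then one finetuning backward batch. The body of `update_peft_temporal_sharing_state()` is not in this diff; a minimal sketch of such a rotation (the exact transition rule is an assumption):

```cpp
// Hypothetical sketch of the temporal-sharing rotation: after
// temporal_sharing_frequency inference batches, schedule one finetuning
// fwd batch and one finetuning bwd batch, then return to inference.
void RequestManager::update_peft_temporal_sharing_state(void) {
  switch (peft_temporal_sharing_state) {
    case INFERENCE:
      peft_temporal_sharing_inf_step++;
      if (peft_temporal_sharing_inf_step >= temporal_sharing_frequency) {
        peft_temporal_sharing_inf_step = 0;
        peft_temporal_sharing_state = FINETUNING_FWD;
      }
      break;
    case FINETUNING_FWD:
      peft_temporal_sharing_state = FINETUNING_BWD;
      break;
    case FINETUNING_BWD:
      peft_temporal_sharing_state = INFERENCE;
      break;
  }
}
```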
2 changes: 1 addition & 1 deletion inference/README.md
@@ -36,7 +36,7 @@ To run a PEFT model example in C++, call:
      -llm-model JackFram/llama-160m \
      -finetuning-dataset ../inference/prompt/peft_dataset.json \
      -peft-model goliaro/llama-160m-lora \
-     -enable-peft \
+     --peft-support-mode COSERVING \
      --use-full-precision \
      --inference-debugging
 ```
73 changes: 73 additions & 0 deletions inference/flexllm/CMakeLists.txt
@@ -71,3 +71,76 @@ set(BIN_DEST "bin")
 install(TARGETS ${project_target1} DESTINATION ${BIN_DEST})
 install(PROGRAMS "${CMAKE_CURRENT_BINARY_DIR}/peft_train" DESTINATION ${BIN_DEST})
 install(PROGRAMS "${CMAKE_CURRENT_BINARY_DIR}/gdb_peft_train" DESTINATION ${BIN_DEST})
+
+
+
+
+# Overhead test
+set(project_target2 overhead_test_unwrapped)
+set(CPU_SRC2
+  ${FLEXFLOW_CPP_DRV_SRC}
+  overhead_test.cc
+  ../models/llama.cc
+  ../models/opt.cc
+  ../models/falcon.cc
+  ../models/starcoder.cc
+  ../models/mpt.cc)
+
+if (FF_GPU_BACKEND STREQUAL "cuda" OR FF_GPU_BACKEND STREQUAL "hip_cuda")
+  cuda_add_executable(${project_target2} ${CPU_SRC2})
+  if (FF_GPU_BACKEND STREQUAL "hip_cuda")
+    target_compile_definitions(${project_target2} PRIVATE __HIP_PLATFORM_NVIDIA__)
+  endif()
+elseif(FF_GPU_BACKEND STREQUAL "hip_rocm")
+  set_source_files_properties(${CPU_SRC2} PROPERTIES LANGUAGE HIP)
+  hip_add_executable(${project_target2} ${CPU_SRC2})
+  if (FF_HIP_ARCH STREQUAL "")
+    message(FATAL_ERROR "FF_HIP_ARCH is empty!")
+  endif()
+  set_property(TARGET ${project_target2} PROPERTY HIP_ARCHITECTURES "${FF_HIP_ARCH}")
+  target_compile_definitions(${project_target2} PRIVATE __HIP_PLATFORM_AMD__)
+else()
+  message(FATAL_ERROR "Compilation of ${project_target2} for ${FF_GPU_BACKEND} backend not yet supported")
+endif()
+
+target_include_directories(${project_target2} PRIVATE ${FLEXFLOW_INCLUDE_DIRS} ${CMAKE_INSTALL_INCLUDEDIR})
+target_include_directories(${project_target2} PRIVATE ${CMAKE_SOURCE_DIR}/inference)
+target_link_libraries(${project_target2} -Wl,--whole-archive flexflow -Wl,--no-whole-archive ${FLEXFLOW_EXT_LIBRARIES})
+
+
+set(TARGET_PATH "${project_target2}")
+
+# Configure the normal execution wrapper.
+# Here, LAUNCHER is simply "exec" so that it runs the executable normally.
+set(LAUNCHER "exec")
+configure_file(
+  "${CMAKE_CURRENT_SOURCE_DIR}/../inference_wrapper.in"
+  "${CMAKE_CURRENT_BINARY_DIR}/overhead_test"
+  @ONLY
+)
+
+file(CHMOD "${CMAKE_CURRENT_BINARY_DIR}/overhead_test"
+  PERMISSIONS OWNER_READ OWNER_WRITE OWNER_EXECUTE
+              GROUP_READ GROUP_EXECUTE
+              WORLD_READ WORLD_EXECUTE
+)
+
+# Configure the debugging launcher wrapper.
+# Here, LAUNCHER is set to "gdb --args" so that it runs under gdb.
+set(LAUNCHER "gdb -ex run --args")
+configure_file(
+  "${CMAKE_CURRENT_SOURCE_DIR}/../inference_wrapper.in"
+  "${CMAKE_CURRENT_BINARY_DIR}/gdb_overhead_test"
+  @ONLY
+)
+file(CHMOD "${CMAKE_CURRENT_BINARY_DIR}/gdb_overhead_test"
+  PERMISSIONS OWNER_READ OWNER_WRITE OWNER_EXECUTE
+              GROUP_READ GROUP_EXECUTE
+              WORLD_READ WORLD_EXECUTE
+)
+
+
+set(BIN_DEST "bin")
+install(TARGETS ${project_target2} DESTINATION ${BIN_DEST})
+install(PROGRAMS "${CMAKE_CURRENT_BINARY_DIR}/overhead_test" DESTINATION ${BIN_DEST})
+install(PROGRAMS "${CMAKE_CURRENT_BINARY_DIR}/gdb_overhead_test" DESTINATION ${BIN_DEST})
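
The `inference_wrapper.in` template itself is not part of this diff, so its exact contents are unknown; a plausible shape for it, given that `configure_file(... @ONLY)` substitutes `@LAUNCHER@` and `@TARGET_PATH@` (this template body is a guess, not the repository's actual file):

```sh
#!/bin/bash
# Hypothetical inference_wrapper.in: @LAUNCHER@ becomes "exec" for the plain
# wrapper and "gdb -ex run --args" for the debug wrapper; @TARGET_PATH@ is the
# unwrapped executable name. Extra CLI arguments are forwarded unchanged.
@LAUNCHER@ "$(dirname "$0")/@TARGET_PATH@" "$@"
```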