Overlap KV cache update for WindowedKeyValueCache in DecoderOnlyPipelineState (#1222)

# Description

Add support for overlapping a pipelined model part's KV cache update with the
graph execution of the other pipelined model parts. This applies only to
`DecoderOnlyPipelineState` with `WindowedKeyValueCache`.

For example, consider a model with two parts (graph[1] and graph[2])
that have KV caches.

This is the approach in this PR:
```
iter 1 graph[1] run | -
iter 1 graph[2] run | iter 1 graph[1] KV cache update
iter 2 graph[1] run | iter 1 graph[2] KV cache update
iter 2 graph[2] run | iter 2 graph[1] KV cache update
iter 3 graph[1] run | iter 2 graph[2] KV cache update
```

For comparison, this is the existing approach:
```
iter 1 graph[1] run
iter 1 graph[2] run
iter 1 graph[1] KV cache update
iter 1 graph[2] KV cache update
iter 2 graph[1] run
iter 2 graph[2] run
iter 2 graph[1] KV cache update
iter 2 graph[2] KV cache update
```
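
For illustration, the dependency structure of the overlapped schedule (the first listing above) can be sketched as follows. This is a minimal, self-contained sketch: `RunGraph` and `UpdateKvCache` are hypothetical stand-ins for the ORT session run and the windowed KV cache update, and `std::async` stands in for the dedicated worker thread used in the actual change.

```
#include <chrono>
#include <cstdio>
#include <future>
#include <thread>
#include <vector>

// Hypothetical stand-ins for an ORT session run and a windowed KV cache update.
void RunGraph(int part, int iter) { std::printf("iter %d graph[%d] run\n", iter, part); }
void UpdateKvCache(int part, int iter) {
  std::this_thread::sleep_for(std::chrono::milliseconds(1));  // simulate update work
  std::printf("iter %d graph[%d] KV cache update\n", iter, part);
}

int main() {
  constexpr int kParts = 2;
  constexpr int kIters = 3;
  // One in-flight update per model part, mirroring the per-pipeline update records in this PR.
  std::vector<std::future<void>> outstanding(kParts + 1);
  for (int iter = 1; iter <= kIters; ++iter) {
    for (int part = 1; part <= kParts; ++part) {
      // A part may not run until its own previous KV cache update has completed.
      if (outstanding[part].valid()) outstanding[part].get();
      RunGraph(part, iter);
      // Launch this part's update; it overlaps with the next part's run.
      outstanding[part] = std::async(std::launch::async, UpdateKvCache, part, iter);
    }
  }
  for (auto& f : outstanding) {
    if (f.valid()) f.get();  // drain remaining updates
  }
}
```

Each part waits only on its own previous update, so graph[k]'s update overlaps with graph[k+1]'s run (and with graph[1]'s run of the next iteration), which is the interleaving shown in the first schedule.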

# Measurements

Token generation rate with QNN EP, 3-part Llama3.2 3B:
Baseline: 15.5328 tokens/sec
Updated: 18.2981 tokens/sec (~1.18x, roughly 18% faster)

Prompt processing logic is unchanged.
edgchen1 authored Feb 6, 2025
1 parent 4d702e2 commit 759333f
Showing 20 changed files with 718 additions and 350 deletions.
9 changes: 6 additions & 3 deletions build.py
@@ -77,7 +77,7 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescript
parser.add_argument("--build", action="store_true", help="Build.")
parser.add_argument("--test", action="store_true", help="Run tests.")
parser.add_argument(
"--clean", action="store_true", help="Run 'cmake --build --target clean' for the selected config/s."
"--clean", action="store_true", help="Run 'cmake --build --target clean' for the selected config."
)

parser.add_argument("--skip_tests", action="store_true", help="Skip all tests. Overrides --test.")
@@ -320,7 +320,7 @@ def _validate_cmake_args(args: argparse.Namespace):

def _validate_args(args: argparse.Namespace):
# default to all 3 stages
if not args.update and not args.build and not args.test:
if not any((args.update, args.clean, args.build, args.test)):
args.update = True
args.build = True
args.test = True
@@ -639,7 +639,7 @@ def clean(args: argparse.Namespace, env: dict[str, str]):
Clean the build output.
"""
log.info("Cleaning targets")
cmd_args = [str(args.cmake), "--build", str(args.build_dir), "--config", args.config, "--target", "clean"]
cmd_args = [str(args.cmake_path), "--build", str(args.build_dir), "--config", args.config, "--target", "clean"]
util.run(cmd_args, env=env)


@@ -655,6 +655,9 @@ def clean(args: argparse.Namespace, env: dict[str, str]):
if arguments.update:
update(arguments, environment)

if arguments.clean:
clean(arguments, environment)

if arguments.build:
build(arguments, environment)

5 changes: 0 additions & 5 deletions cmake/check_cuda.cmake
@@ -46,11 +46,6 @@ if(USE_CUDA AND CMAKE_CUDA_COMPILER)
"${GENERATORS_ROOT}/cuda/*.cuh"
)

file(GLOB test_cuda_srcs CONFIGURE_DEPENDS
"${TESTS_ROOT}/*.cu"
"${TESTS_ROOT}/*.cuh"
)
list(APPEND test_srcs ${test_cuda_srcs})
add_compile_definitions(USE_CUDA=1)
include_directories("${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}")
elseif(USE_CUDA)
1 change: 1 addition & 0 deletions src/generators.h
@@ -31,6 +31,7 @@ using cudaStream_t = void*;
#endif

#include "leakcheck.h"
#include "make_string.h"
#include "smartptrs.h"
#include "models/onnxruntime_api.h"
#include "models/debugging.h"
18 changes: 18 additions & 0 deletions src/make_string.h
@@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <string>
#include <sstream>

namespace Generators {

template <typename... Args>
inline std::string MakeString(Args&&... args) {
std::ostringstream s;
(s << ... << std::forward<Args>(args));
return s.str();
}

} // namespace Generators
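
As a usage note, `MakeString` streams each argument into an `std::ostringstream` via a C++17 fold expression and returns the resulting string. A hypothetical call site (the values are illustrative, not taken from this diff) might look like:

```
#include <string>

#include "make_string.h"  // the new header added above

int main() {
  const int layer_idx = 3;
  const std::string msg =
      Generators::MakeString("KV cache update failed for layer ", layer_idx, ".");
  // msg == "KV cache update failed for layer 3."
}
```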
141 changes: 115 additions & 26 deletions src/models/decoder_only_pipeline.cpp
@@ -2,7 +2,9 @@
// Licensed under the MIT License.

#include "../generators.h"
#include "../logging.h"
#include "decoder_only_pipeline.h"
#include "windowed_kv_cache.h"

namespace Generators {

@@ -82,17 +84,42 @@ bool IntermediatePipelineState::SupportsPrimaryDevice() const {
DeviceSpan<float> IntermediatePipelineState::Run(int total_length, DeviceSpan<int32_t>& next_tokens,
DeviceSpan<int32_t> next_indices) {
State::Run(*model_.sessions_[id_], params_->BatchBeamSize());

return {};
}

using NameToLayerIdxMap = std::unordered_map<std::string, size_t>;

static NameToLayerIdxMap GeneratePastKeyNameToLayerIdxMap(const Config& config) {
const size_t num_layers = config.model.decoder.num_hidden_layers;
const std::string& past_key_name_template = config.model.decoder.inputs.past_key_names;
NameToLayerIdxMap m{};
for (size_t i = 0; i < num_layers; ++i) {
m.emplace(ComposeKeyValueName(past_key_name_template, static_cast<int>(i)), i);
}
return m;
}

static std::vector<size_t> DetectLayerIndicesFromPastKeyNameInputs(
const NameToLayerIdxMap& past_key_name_to_layer_idx, std::span<const std::string> inputs) {
std::vector<size_t> detected_layer_indices{};
for (const auto& input_name : inputs) {
const auto it = past_key_name_to_layer_idx.find(input_name);
if (it != past_key_name_to_layer_idx.end()) {
detected_layer_indices.push_back(it->second);
}
}
return detected_layer_indices;
}

DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineModel& model,
DeviceSpan<int32_t> sequence_lengths,
const GeneratorParams& params)
: State{params, model},
model_{model},
input_ids_{CreateInputIDs(*this)},
key_value_cache_{CreateKeyValueCache(*this)},
do_key_value_cache_partial_token_generation_update_{
key_value_cache_ && key_value_cache_->IsPartialTokenGenerationUpdateSupported()},
position_inputs_{CreatePositionInputs(*this, sequence_lengths)} {
input_ids_->Add();
position_inputs_->Add();
@@ -102,8 +129,41 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode
}
extra_inputs_.Add();

for ([[maybe_unused]] const auto& pipeline_model : model_.config_->model.decoder.pipeline) {
pipeline_states_.emplace_back(std::make_unique<IntermediatePipelineState>(model_, params, pipeline_states_.size()));
const auto past_key_name_to_layer_idx = [&]() -> std::optional<NameToLayerIdxMap> {
if (do_key_value_cache_partial_token_generation_update_) {
return GeneratePastKeyNameToLayerIdxMap(*model_.config_);
}
return std::nullopt;
}();

for (const auto& pipeline_model : model_.config_->model.decoder.pipeline) {
auto pipeline_model_state = std::make_unique<IntermediatePipelineState>(model_, params, pipeline_states_.size());

auto overlapped_kv_cache_update_record = [&]() -> std::optional<OverlappedKeyValueCacheUpdateRecord> {
if (do_key_value_cache_partial_token_generation_update_) {
const bool token_gen_only = !pipeline_model.run_on_prompt && pipeline_model.run_on_token_gen;
if (token_gen_only) {
auto layer_indices = DetectLayerIndicesFromPastKeyNameInputs(*past_key_name_to_layer_idx,
pipeline_model.inputs);
if (!layer_indices.empty()) {
// token generation model with KV cache tensors - we should overlap KV cache update
auto record = OverlappedKeyValueCacheUpdateRecord{};
record.layer_indices = std::move(layer_indices);
return record;
}
}
}
return std::nullopt;
}();

pipeline_states_.emplace_back(std::move(pipeline_model_state));
pipeline_overlapped_kv_cache_update_records_.emplace_back(std::move(overlapped_kv_cache_update_record));
}

if (std::any_of(pipeline_overlapped_kv_cache_update_records_.begin(),
pipeline_overlapped_kv_cache_update_records_.end(),
[](const auto& record) { return record.has_value(); })) {
key_value_cache_update_worker_thread_.emplace();
}
}

@@ -132,13 +192,11 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
for (const auto& input_name : input_names_) {
if (pipeline_state->HasInput(input_name)) {
if (!pipeline_state->SupportsPrimaryDevice()) {
std::ostringstream oss;
oss << "Managed input " << input_name << " resides on the primary device type ("
<< to_string(model_.device_type_) << "). "
<< "But the pipeline model "
<< model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id
<< " is expecting it to reside elsewhere.";
throw std::runtime_error(oss.str());
throw std::runtime_error(
MakeString("Managed input ", input_name, " resides on the primary device type (",
static_cast<int>(model_.device_type_), "). But the pipeline model ",
model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id,
" is expecting it to reside elsewhere."));
}
pipeline_state->input_names_.push_back(input_name);
pipeline_state->inputs_.push_back(State::GetInput(input_name));
@@ -157,13 +215,11 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
for (const auto& output_name : output_names_) {
if (pipeline_state->HasOutput(output_name)) {
if (!pipeline_state->SupportsPrimaryDevice()) {
std::ostringstream oss;
oss << "Managed output " << output_name << " resides on the primary device type ("
<< to_string(model_.device_type_) << "). "
<< "But the pipeline model "
<< model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id
<< " is expecting it to reside elsewhere.";
throw std::runtime_error(oss.str());
throw std::runtime_error(
MakeString("Managed output ", output_name, " resides on the primary device type (",
static_cast<int>(model_.device_type_), "). But the pipeline model ",
model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id,
" is expecting it to reside elsewhere."));
}
pipeline_state->output_names_.push_back(output_name);
pipeline_state->outputs_.push_back(State::GetOutput(output_name));
@@ -176,13 +232,11 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
for (const auto& input_name : input_names_) {
if (pipeline_state->HasOutput(input_name)) {
if (!pipeline_state->SupportsPrimaryDevice()) {
std::ostringstream oss;
oss << "Managed input " << input_name << " resides on the primary device type ("
<< to_string(model_.device_type_) << "). "
<< "But the pipeline model "
<< model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id
<< " is expecting it to reside elsewhere.";
throw std::runtime_error(oss.str());
throw std::runtime_error(
MakeString("Managed input ", input_name, " resides on the primary device type (",
static_cast<int>(model_.device_type_), "). But the pipeline model ",
model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id,
" is expecting it to reside elsewhere."));
}
pipeline_state->output_names_.push_back(input_name);
pipeline_state->outputs_.push_back(State::GetInput(input_name));
@@ -198,9 +252,28 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
}
}

auto& overlapped_kv_update_record = pipeline_overlapped_kv_cache_update_records_[pipeline_state->id_];
if (overlapped_kv_update_record.has_value()) {
// wait for any outstanding KV cache update to finish
if (overlapped_kv_update_record->outstanding_update.valid()) {
overlapped_kv_update_record->outstanding_update.get();
}
}

// Run the intermediate pipeline state
pipeline_state->Run(total_length, next_tokens, next_indices);

if (overlapped_kv_update_record.has_value()) {
assert(key_value_cache_update_worker_thread_.has_value());
// enqueue the next KV cache update
auto update_fn = [&key_value_cache = *key_value_cache_.get(),
layer_indices = overlapped_kv_update_record->layer_indices,
next_indices, total_length]() {
key_value_cache.PartialTokenGenerationUpdate(next_indices, total_length, layer_indices);
};
overlapped_kv_update_record->outstanding_update = key_value_cache_update_worker_thread_->Enqueue(update_fn);
}

// Transfer ownership of all the non-managed outputs from the current pipeline state to the ortvalue store.
// All non managed outputs are assumed to be on CPU
for (size_t i = 0; i < pipeline_state->output_names_.size(); ++i) {
@@ -235,7 +308,7 @@ DeviceSpan<float> DecoderOnlyPipelineState::Run(int total_length, DeviceSpan<int
if (model_.config_->model.decoder.sliding_window.has_value() && i < num_chunks - 1) {
// Sliding the window over the input_ids, key_cache, and value_cache, position_ids, and attention_mask
input_ids_->Update(next_tokens);
key_value_cache_->Update(next_indices, total_length);
if (key_value_cache_) key_value_cache_->Update(next_indices, total_length);
position_inputs_->Update(next_tokens, total_length, static_cast<int>(input_ids_->GetShape()[1]));
}
}
@@ -263,7 +336,23 @@ void DecoderOnlyPipelineState::UpdateInputsOutputs(DeviceSpan<int32_t>& next_tok
input_ids_->Update(next_tokens);
size_t new_length = input_ids_->GetShape()[1];
position_inputs_->Update(next_tokens, total_length, static_cast<int>(new_length));
if (key_value_cache_) key_value_cache_->Update(beam_indices, total_length);

if (key_value_cache_) {
const bool outstanding_key_value_cache_partial_token_generation_update =
do_key_value_cache_partial_token_generation_update_ &&
std::any_of(pipeline_overlapped_kv_cache_update_records_.rbegin(),
pipeline_overlapped_kv_cache_update_records_.rend(),
[](const std::optional<OverlappedKeyValueCacheUpdateRecord>& record) {
return record.has_value() && record->outstanding_update.valid();
});

if (outstanding_key_value_cache_partial_token_generation_update) {
// If there is any outstanding partial KV cache update, don't update the KV cache here.
} else {
key_value_cache_->Update(beam_indices, total_length);
}
}

logits_.Update(next_tokens, new_length);
}

16 changes: 16 additions & 0 deletions src/models/decoder_only_pipeline.h
@@ -3,10 +3,15 @@

#pragma once

#include <future>
#include <optional>

#include "../worker_thread.h"
#include "model.h"
#include "input_ids.h"
#include "logits.h"
#include "kv_cache.h"
#include "windowed_kv_cache.h"
#include "position_inputs.h"
#include "extra_inputs.h"

@@ -68,12 +73,23 @@ struct DecoderOnlyPipelineState : State {
const DecoderOnlyPipelineModel& model_;
std::vector<std::unique_ptr<IntermediatePipelineState>> pipeline_states_;

struct OverlappedKeyValueCacheUpdateRecord {
std::vector<size_t> layer_indices{}; // indicates which layers of the KV cache are to be updated
std::future<void> outstanding_update{}; // future for an outstanding update task
};

std::vector<std::optional<OverlappedKeyValueCacheUpdateRecord>> pipeline_overlapped_kv_cache_update_records_;

// Stores all the outputs from the previous pipeline state(s)
std::unordered_map<std::string, std::unique_ptr<OrtValue>> ortvalue_store_;

std::unique_ptr<InputIDs> input_ids_;
Logits logits_{*this};

std::unique_ptr<KeyValueCache> key_value_cache_;
const bool do_key_value_cache_partial_token_generation_update_;
std::optional<WorkerThread> key_value_cache_update_worker_thread_{};

std::unique_ptr<PositionInputs> position_inputs_;
ExtraInputs extra_inputs_{*this};
};
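
The header relies on a `WorkerThread` type from `../worker_thread.h` whose `Enqueue` returns a `std::future`; its implementation is not part of this excerpt. The following is a minimal sketch, assuming only the shape used above (default construction via `std::optional::emplace` and `Enqueue` returning a `std::future<void>`), of how such a single-thread task queue could look. It is illustrative only, not the repository's actual `src/worker_thread.h`.

```
#include <condition_variable>
#include <future>
#include <mutex>
#include <queue>
#include <thread>

class WorkerThread {
 public:
  WorkerThread() : thread_([this] { Loop(); }) {}

  ~WorkerThread() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      done_ = true;
    }
    cv_.notify_one();
    thread_.join();
  }

  template <typename Fn>
  std::future<void> Enqueue(Fn&& fn) {
    std::packaged_task<void()> task(std::forward<Fn>(fn));
    std::future<void> result = task.get_future();
    {
      std::lock_guard<std::mutex> lock(mutex_);
      tasks_.push(std::move(task));
    }
    cv_.notify_one();
    return result;
  }

 private:
  void Loop() {
    for (;;) {
      std::packaged_task<void()> task;
      {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [this] { return done_ || !tasks_.empty(); });
        if (done_ && tasks_.empty()) return;
        task = std::move(tasks_.front());
        tasks_.pop();
      }
      task();  // runs on the worker thread; the associated future becomes ready
    }
  }

  std::mutex mutex_;
  std::condition_variable cv_;
  std::queue<std::packaged_task<void()>> tasks_;
  bool done_ = false;
  std::thread thread_;  // declared last so the other members exist before Loop() starts
};

int main() {
  WorkerThread worker;
  std::future<void> pending = worker.Enqueue([] { /* e.g. a partial KV cache update */ });
  pending.get();  // a pipeline part would wait on this before its next run
}
```

Because a single thread drains the queue, the enqueued updates stay serialized with respect to each other, matching the one-update-at-a-time right-hand column of the schedule in the description.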