OpenVINO EP Weights Sharing Feature (#23553)
### Description
These changes ensure that weights can be shared between two models using the session context option ep_weight_sharing.

Key changes introduced in this feature:

- Creating a shared context between the two models.
- Extracting external constant initializers and relabelling them as model inputs, so the weights can be loaded directly from the blob.
- Creating EP Context nodes even while subgraph partitioning is happening.

### Motivation and Context
This change was required to ensure that LLM prefill and kvcache models can share the same weights. It was also required to ensure EP Context nodes can be formed even when the model is being subgraph partitioned.
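For illustration, a minimal C++ sketch of how two sessions could opt into this behavior. The configuration key "ep.share_ep_contexts", the "OpenVINO" provider-options spelling, and the model paths are assumptions for illustration, not part of this diff:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;

  Ort::SessionOptions so;
  // Assumed session config entry that enables cross-session sharing.
  so.AddConfigEntry("ep.share_ep_contexts", "1");
  // Assumed provider-options spelling for the OpenVINO EP.
  so.AppendExecutionProvider("OpenVINO", {{"device_type", "NPU"}});

  // Prefill and kvcache variants of the same LLM; paths are placeholders.
  Ort::Session prefill(env, ORT_TSTR("llm_prefill.onnx"), so);
  Ort::Session kvcache(env, ORT_TSTR("llm_kvcache.onnx"), so);

  // Both sessions can now resolve their external initializers against the
  // same shared weight tensors instead of loading two copies.
  return 0;
}
```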

---------

Co-authored-by: jatinwadhwa921 <[email protected]>
Co-authored-by: jatinwadhwa921 <[email protected]>
Co-authored-by: saurabh <[email protected]>
Co-authored-by: TejalKhade28 <[email protected]>
Co-authored-by: sfatimar <[email protected]>
Co-authored-by: Javier E. Martinez <[email protected]>
Co-authored-by: Preetha Veeramalai <[email protected]>
Co-authored-by: Eric Crawford <[email protected]>
9 people authored Feb 6, 2025
1 parent 2c2ff4a commit a6ea57b
Showing 26 changed files with 1,420 additions and 992 deletions.
8 changes: 4 additions & 4 deletions cmake/onnxruntime_providers_openvino.cmake
@@ -13,8 +13,8 @@

 # Header paths
 find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX)
-if(OpenVINO_VERSION VERSION_LESS 2024.4)
-  message(FATAL_ERROR "OpenVINO 2024.4 and newer are supported. Please, use latest OpenVINO release")
+if(OpenVINO_VERSION VERSION_LESS 2024.5)
+  message(FATAL_ERROR "OpenVINO 2024.5 and newer are supported. Please, use latest OpenVINO release")
 endif()

 if(OpenVINO_VERSION VERSION_GREATER_EQUAL 2024.4)
@@ -30,7 +30,7 @@
   endif()

   list(APPEND OPENVINO_LIB_LIST openvino::frontend::onnx openvino::runtime ${PYTHON_LIBRARIES})
-  if ((DEFINED ENV{OPENCL_LIBS}) AND (DEFINED ENV{OPENCL_INCS}))
+  if ((DEFINED ENV{OPENCL_LIBS}) AND (DEFINED ENV{OPENCL_INCS}) AND onnxruntime_USE_OPENVINO_GPU)
     add_definitions(-DIO_BUFFER_ENABLED=1)
     list(APPEND OPENVINO_LIB_LIST $ENV{OPENCL_LIBS})
   endif()
@@ -86,4 +86,4 @@
   set_target_properties(onnxruntime_providers_openvino PROPERTIES
     MAP_IMPORTED_CONFIG_RELEASE RelWithDebInfo
     MAP_IMPORTED_CONFIG_DEBUG RelWithDebInfo
-)
+)
260 changes: 141 additions & 119 deletions onnxruntime/core/providers/openvino/backend_manager.cc

Large diffs are not rendered by default.

15 changes: 7 additions & 8 deletions onnxruntime/core/providers/openvino/backend_manager.h
@@ -19,17 +19,16 @@ namespace openvino_ep {
 // Singleton class that manages all the backends
 class BackendManager {
  public:
-  BackendManager(const GlobalContext& global_context,
+  BackendManager(SessionContext& session_context,
+                 SharedContext& shared_context,
                  const onnxruntime::Node& fused_node,
                  const onnxruntime::GraphViewer& subgraph,
                  const logging::Logger& logger,
                  EPCtxHandler& ctx_handle);
   void Compute(OrtKernelContext* context);
   void ShutdownBackendManager();
-  void SetGlobalCotext(const GlobalContext& global_context);
-  GlobalContext& GetGlobalContext();
-  Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph,
-                                       const logging::Logger& logger);
+  SessionContext& GetSessionContext();
+  Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph);
   ov::CompiledModel& GetOVCompiledModel();

  private:
@@ -52,9 +51,9 @@ class BackendManager {
   std::shared_ptr<IBackend> concrete_backend_;
   std::map<std::string, std::shared_ptr<IBackend>> backend_map_;
   SubGraphContext subgraph_context_;
-  GlobalContext global_context_;
-  EPCtxHandler ep_ctx_handle_{};
-  std::string openvino_sdk_version_{};
+  EPCtxHandler& ep_ctx_handle_;
+  SessionContext& session_context_;
+  SharedContext& shared_context_;
 };

 } // namespace openvino_ep
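For orientation, a hedged sketch of the corresponding call site (the wrapper function and its name are hypothetical; the real wiring lives in the OpenVINO execution provider, not in this header). It only exercises the new constructor signature shown above:

```cpp
#include <memory>

// Hypothetical helper: the EP would hold one SharedContext across sessions
// and one SessionContext per session, handing both to every BackendManager.
std::unique_ptr<onnxruntime::openvino_ep::BackendManager> MakeBackendManager(
    onnxruntime::openvino_ep::SessionContext& session_context,
    onnxruntime::openvino_ep::SharedContext& shared_context,
    const onnxruntime::Node& fused_node,
    const onnxruntime::GraphViewer& subgraph,
    const onnxruntime::logging::Logger& logger,
    onnxruntime::openvino_ep::EPCtxHandler& ep_ctx_handle) {
  return std::make_unique<onnxruntime::openvino_ep::BackendManager>(
      session_context,  // per-session state, replacing the old GlobalContext
      shared_context,   // cross-session state for weight sharing
      fused_node, subgraph, logger, ep_ctx_handle);
}
```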
213 changes: 191 additions & 22 deletions onnxruntime/core/providers/openvino/backend_utils.cc
@@ -1,13 +1,16 @@
 // Copyright (C) Intel Corporation
 // Licensed under the MIT License

 #include <algorithm>
 #include <sstream>
 #include <fstream>
 #include <utility>

+#include <filesystem>
+#include <stdexcept>
+
 #include "openvino/pass/convert_fp32_to_fp16.hpp"
 #include "openvino/pass/constant_folding.hpp"
+#include "openvino/runtime/intel_npu/level_zero/level_zero.hpp"
 #include "core/providers/shared_library/provider_api.h"
 #include "core/providers/openvino/backend_utils.h"
 #include "core/providers/openvino/ov_interface.h"
@@ -16,6 +19,105 @@ using Exception = ov::Exception;

 namespace onnxruntime {
 namespace openvino_ep {
+
+SharedContext::SharedWeights::WeightsFile::WeightsFile(std::filesystem::path filename)
+    : file_(filename, std::ios::in | std::ios::binary) {
+  try {
+    file_.exceptions(std::ifstream::failbit | std::ifstream::badbit);
+    weights_size_ = file_.seekg(0, std::ios::end).tellg();
+  } catch (std::ifstream::failure& e) {
+    ORT_THROW("Error: Failed to open weight file at ", filename.string(), " ", e.what());
+  }
+}
+
+void SharedContext::SharedWeights::WeightsFile::load_weights(size_t file_offset, void* data, size_t size) {
+  ORT_ENFORCE(file_offset < weights_size_ && size <= weights_size_ && (file_offset <= weights_size_ - size),
+              "Error: File offset is out of bounds.");
+  file_.seekg(file_offset);
+  file_.read(reinterpret_cast<char*>(data), size);
+}
+
+std::ostream& operator<<(std::ostream& stream, const SharedContext::SharedWeights::Metadata::Map& metadata) {
+  try {
+    stream << metadata.size();
+
+    // Write each key-value pair
+    // Put elements in separate lines to facilitate reading
+    for (const auto& [key, value] : metadata) {
+      stream << std::endl
+             << key.name;
+      stream << std::endl
+             << value.location;
+      stream << std::endl
+             << value.data_offset;
+      stream << std::endl
+             << value.size;
+      stream << std::endl
+             << value.dimensions.size();
+      for (const auto& dim : value.dimensions) {
+        stream << std::endl
+               << dim;
+      }
+      stream << std::endl
+             << value.element_type;
+    }
+  } catch (const Exception& e) {
+    ORT_THROW("Error: Failed to write map data.", e.what());
+  } catch (...) {
+    ORT_THROW("Error: Failed to write map data.");
+  }
+
+  ORT_ENFORCE(stream.good(), "Error: Failed to write map data.");
+  return stream;
+}
+
+std::istream& operator>>(std::istream& stream, SharedContext::SharedWeights::Metadata::Map& metadata) {
+  size_t map_size{0};
+  try {
+    stream >> map_size;
+
+    while (!stream.eof()) {
+      SharedContext::SharedWeights::Metadata::Key key;
+      SharedContext::SharedWeights::Metadata::Value value;
+      stream >> key.name;
+      stream >> value.location;
+      stream >> value.data_offset;
+      stream >> value.size;
+      size_t num_dimensions;
+      stream >> num_dimensions;
+
+      if (stream.fail()) {
+        ORT_THROW("Error: Failed to read num_dimensions from stream.");
+      }
+
+      constexpr size_t MAX_SAFE_DIMENSIONS = 1024;
+
+      size_t safe_num_dimensions = num_dimensions;
+
+      if (num_dimensions == 0 || safe_num_dimensions > MAX_SAFE_DIMENSIONS) {
+        ORT_THROW("Invalid number of dimensions provided.");
+      }
+      try {
+        value.dimensions.resize(safe_num_dimensions);
+      } catch (const std::bad_alloc&) {
+        ORT_THROW("Error: Memory allocation failed while resizing dimensions.");
+      }
+
+      for (auto& dim : value.dimensions) {
+        stream >> dim;
+      }
+      stream >> value.element_type;
+      metadata.emplace(key, value);
+    }
+  } catch (const Exception& e) {
+    ORT_THROW("Error: Failed to read map data.", e.what());
+  } catch (...) {
+    ORT_THROW("Error: Failed to read map data.");
+  }
+
+  ORT_ENFORCE(metadata.size() == map_size, "Error: Inconsistent map data.");
+
+  return stream;
+}
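For a concrete picture of the format these two operators agree on, here is what a one-entry map would serialize to (the key name, file name, and values are hypothetical; the field order is exactly what `operator<<` above writes, one field per line):

```
1
model.layers.0.attn.qkv_proj.weight
model_weights.bin
0
524288
2
256
512
1
```

That is: the map size, then per entry the key name, the weights-file location, the byte offset, the byte size, the rank, each dimension, and finally the element type as its raw ONNX TensorProto_DataType integer (1 is FLOAT; 256 x 512 float32 elements gives the 524288-byte size shown).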

 namespace backend_utils {

 bool IsDebugEnabled() {
@@ -34,23 +136,18 @@ bool IsCILogEnabled() {
   return false;
 }

-struct static_cast_int64 {
-  template <typename T1>  // T1 models type statically convertible to T
-  int64_t operator()(const T1& x) const { return static_cast<int64_t>(x); }
-};
-
 std::shared_ptr<const OVNetwork>
-CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext& global_context,
+CreateOVModel(const std::string model,
+              const SessionContext& session_context,
               std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map) {
   if (IsCILogEnabled()) {
     std::cout << "CreateNgraphFunc" << std::endl;
   }
-  const std::string model = model_proto.SerializeAsString();
   try {
-    auto ov_model = global_context.ie_core.ReadModel(model, global_context.onnx_model_path_name);
+    auto ov_model = OVCore::ReadModel(model, session_context.onnx_model_path_name.string());

     // Check for Constant Folding
-    if ((global_context.device_type != "NPU") && !global_context.is_wholly_supported_graph) {
+    if ((session_context.device_type != "NPU") && !session_context.is_wholly_supported_graph) {
       ov::pass::ConstantFolding pass_const_obj;
       pass_const_obj.run_on_model(ov_model);
       auto& results = const_cast<ov::ResultVector&>(ov_model.get()->get_results());
@@ -82,7 +179,7 @@ Ort::UnownedValue
 GetOutputTensor(Ort::KernelContext& context, size_t batch_size,
                 OVInferRequestPtr infer_request,
                 std::string output_name,
-                std::unordered_map<std::string, int> output_names) {
+                const SubGraphContext::string_index_map_t& output_names) {
   auto graph_output_blob = infer_request->GetTensor(output_name);

   auto graph_output_dims = graph_output_blob->get_shape();
@@ -107,7 +204,7 @@ GetOutputTensor(Ort::KernelContext& context, size_t batch_size,
 Ort::UnownedValue
 GetOutputTensor(Ort::KernelContext& context,
                 std::string output_name,
-                std::unordered_map<std::string, int> output_names,
+                const SubGraphContext::string_index_map_t& output_names,
                 std::shared_ptr<ov::Node> node) {
   // Find position of '/' in the output_name
   auto pos = output_name.find("/");
@@ -129,13 +226,13 @@ GetOutputTensor(Ort::KernelContext& context,
   return context.GetOutput(index, output_shape.get(), num_dims);
 }

-int GetFirstAvailableDevice(GlobalContext& global_context) {
+int GetFirstAvailableDevice(SessionContext& session_context) {
   int i = 0;
   // Get the first available VAD-M device and set the device to busy
   while (i < 8) {
-    bool device = global_context.deviceAvailableList[i];
+    bool device = session_context.deviceAvailableList[i];
     if (device) {
-      global_context.deviceAvailableList[i] = false;
+      session_context.deviceAvailableList[i] = false;
       break;
     }
     i++;
@@ -144,9 +241,9 @@ int GetFirstAvailableDevice(GlobalContext& global_context) {
   // make all remaining devices free
   if (i == 8) {
     i = 0;
-    global_context.deviceAvailableList[i] = false;
+    session_context.deviceAvailableList[i] = false;
     for (int j = 1; j < 8; j++) {
-      global_context.deviceAvailableList[j] = true;
+      session_context.deviceAvailableList[j] = true;
     }
   }
   return i;
@@ -155,23 +252,23 @@ int GetFirstAvailableDevice(GlobalContext& global_context) {
 void FillOutputsWithConstantData(std::shared_ptr<ov::Node> node, Ort::UnownedValue& out_tensor) {
   switch (node->get_element_type()) {
     case ov::element::Type_t::f32: {
-      FillOutputHelper<float>(out_tensor, node);
+      FillOutputHelper<float>(out_tensor, std::move(node));
       break;
     }
     case ov::element::Type_t::boolean: {
-      FillOutputHelper<char>(out_tensor, node);
+      FillOutputHelper<char>(out_tensor, std::move(node));
       break;
     }
     case ov::element::Type_t::i32: {
-      FillOutputHelper<int32_t>(out_tensor, node);
+      FillOutputHelper<int32_t>(out_tensor, std::move(node));
       break;
     }
     case ov::element::Type_t::i64: {
-      FillOutputHelper<int64_t>(out_tensor, node);
+      FillOutputHelper<int64_t>(out_tensor, std::move(node));
       break;
     }
     case ov::element::Type_t::f16: {
-      FillOutputHelper<float>(out_tensor, node);
+      FillOutputHelper<float>(out_tensor, std::move(node));
       break;
     }
     default:
@@ -267,6 +364,78 @@ void printPerformanceCounts(OVInferRequestPtr request, std::ostream& stream, std::string deviceName) {
   printPerformanceCounts(performanceMap, stream, std::move(deviceName));
 }

+ov::element::Type GetOpenVINOElementType(ONNX_NAMESPACE::TensorProto_DataType dt) {
+  static std::unordered_map<ONNX_NAMESPACE::TensorProto_DataType, ov::element::Type> map{
+      {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, ov::element::f32},
+      {ONNX_NAMESPACE::TensorProto_DataType_UINT8, ov::element::u8},
+      {ONNX_NAMESPACE::TensorProto_DataType_INT8, ov::element::i8},
+      {ONNX_NAMESPACE::TensorProto_DataType_UINT16, ov::element::u16},
+      {ONNX_NAMESPACE::TensorProto_DataType_INT16, ov::element::i16},
+      {ONNX_NAMESPACE::TensorProto_DataType_INT32, ov::element::i32},
+      {ONNX_NAMESPACE::TensorProto_DataType_INT64, ov::element::i64},
+      {ONNX_NAMESPACE::TensorProto_DataType_STRING, ov::element::string},
+      {ONNX_NAMESPACE::TensorProto_DataType_BOOL, ov::element::boolean},
+      {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, ov::element::f16},
+      {ONNX_NAMESPACE::TensorProto_DataType_DOUBLE, ov::element::f64},
+      {ONNX_NAMESPACE::TensorProto_DataType_UINT32, ov::element::u32},
+      {ONNX_NAMESPACE::TensorProto_DataType_UINT64, ov::element::u64},
+      //{ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64, ov::element::undefined},
+      //{ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128, ov::element::undefined},
+      {ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16, ov::element::bf16},
+      //{ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN, ov::element::undefined},
+      //{ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ, ov::element::undefined},
+      {ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2, ov::element::f8e5m2},
+      //{ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ, ov::element::undefined},
+      {ONNX_NAMESPACE::TensorProto_DataType_UINT4, ov::element::u4},
+      {ONNX_NAMESPACE::TensorProto_DataType_INT4, ov::element::i4},
+  };
+
+  if (auto result = map.find(dt); result != map.end()) {
+    return result->second;
+  } else {
+    throw std::runtime_error("Unsupported ONNX data type: " + std::to_string(dt));
+  }
+}
+
+// Function to handle tensor creation from external data
+void CreateOVTensors(const std::string& device_name,
+                     SharedContext::SharedWeights::Metadata::Map& metadata_map,
+                     SharedContext::SharedWeights::WeightsFile& weights) {
+  for (auto& [key, value] : metadata_map) {
+    if (value.tensor) continue;
+
+    // Get element data type
+    auto onnx_element_type = (ONNX_NAMESPACE::TensorProto_DataType)value.element_type;
+
+    ov::element::Type ov_elementType = GetOpenVINOElementType(onnx_element_type);  // Map to OpenVINO data type
+
+    // Create OpenVINO Tensor
+    if (device_name == "NPU") {
+      // Use remote tensors
+      auto npu_context = OVCore::Get().get_default_context("NPU").as<ov::intel_npu::level_zero::ZeroContext>();
+      auto&& remote_tensor = npu_context.create_l0_host_tensor(ov_elementType, value.dimensions, ov::intel_npu::TensorType::INPUT);
+
+      // Copy data to remote tensor
+      weights.load_weights(value.data_offset, remote_tensor.get(), value.size);
+      value.tensor = std::make_shared<ov::Tensor>(remote_tensor);
+    } else {
+      // Use vanilla tensors
+      value.tensor = std::make_shared<ov::Tensor>(ov_elementType, value.dimensions);
+      weights.load_weights(value.data_offset, value.tensor->data(), value.size);
+    }
+    ORT_ENFORCE(value.tensor->get_byte_size() == value.size, "Unexpected tensor size mismatch");
+  }
+}
+
+void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map) {
+  for (auto& [key, value] : metadata_map) {
+    if (value.tensor) {
+      value.tensor.reset();
+    }
+  }
+  metadata_map.clear();
+}
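Putting the pieces together, a hedged usage sketch of these helpers (the file names and the driver function are assumptions for illustration; inside the EP this sequencing is handled through SharedContext):

```cpp
// Hypothetical driver, compiled within namespace openvino_ep: read the
// metadata, materialize the shared tensors once, run inference via the
// sessions that share them, then tear down.
void LoadSharedWeightsSketch() {
  SharedContext::SharedWeights::Metadata::Map metadata;
  std::ifstream meta_stream("model_weights.meta");  // assumed metadata file
  meta_stream >> metadata;                          // operator>> defined above

  SharedContext::SharedWeights::WeightsFile weights("model_weights.bin");
  backend_utils::CreateOVTensors("NPU", metadata, weights);  // or "CPU"/"GPU"

  // ... both the prefill and kvcache compiled models consume these tensors ...

  backend_utils::DestroyOVTensors(metadata);
}
```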

 } // namespace backend_utils
 } // namespace openvino_ep
 } // namespace onnxruntime
