Skip to content

Add on-the-fly bfloat16->float16 conversion pass #740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: ovep-develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion onnxruntime/core/providers/openvino/backend_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,19 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node,
DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node);
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
return model_proto;
} else {
}
else if (session_context_.enable_bfloat16_optimizer) {
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP bfloat16->float16 optimization pass is enabled";
std::unique_ptr<onnxruntime::Model> model;
Status status = bfloat16_fix::Transform(subgraph, logger, model);
auto model_proto = model->ToProto();
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
print_model_proto_duration();
DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node);
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
return model_proto;
}
else {
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP QDQ optimization pass is disabled";
auto model = subgraph.CreateModel(logger);
auto model_proto = model->ToProto();
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/openvino/contexts.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ struct ProviderInfo {
bool disable_dynamic_shapes{false}; // [disable_dynamic_shapes]: Rewrite dynamic shaped models to
// static shape at runtime and execute.
bool enable_qdq_optimizer{false}; // Enables QDQ pruning for efficient inference latency with NPU
bool enable_bfloat16_optimizer{false}; // Enables on-the-fly bfloat16->float16 conversion
bool enable_causallm{false}; // Enables Causal LM Compilation for ORT GenAI OVEP Pass
bool so_context_enable{false}; // ORT session option
bool so_disable_cpu_ep_fallback{false}; // ORT session option
Expand All @@ -110,7 +111,7 @@ struct ProviderInfo {
const ConfigOptions* config_options{NULL};
const std::unordered_set<std::string> valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision",
"load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer",
"enable_causallm", "disable_dynamic_shapes", "reshape_input"};
"enable_bfloat16_optimizer", "enable_causallm", "disable_dynamic_shapes", "reshape_input"};

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is a separate provider option needed? Is it possible to detect that the model has a bfloat16 datatype and intrinsically enable the optimization?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sfatimar, because at some point OpenVINO might enable the native execution of bfloat16 models. This is a workaround until this functionality is enabled. Let's discuss it with Mayuresh and act accordingly.
Regarding the graph optimizations link you've shared: strictly speaking, it's not the same kind of optimizations as are those in the list. It's more like the QDQ scales fix we implemented earlier.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot have an external provider option as a workaround because it impacts external users and apps, and it would need to be given a deprecation notice 2 releases in advance. I would prefer this to be handled internally.

};

// Holds context applicable to the entire EP instance.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer,
openvino_ep::GetCapability obj(ep_ctx_handle_,
graph_viewer,
session_context_.device_type,
session_context_.enable_qdq_optimizer);
session_context_.enable_qdq_optimizer,
session_context_.enable_bfloat16_optimizer);
result = obj.Execute();
session_context_.is_wholly_supported_graph = obj.IsWhollySupportedGraph();
session_context_.has_external_weights = obj.HasExternalWeights();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ static void ParseProviderInfo(const ProviderOptions& provider_options,

pi.enable_qdq_optimizer = ParseBooleanOption(provider_options, "enable_qdq_optimizer");

pi.enable_bfloat16_optimizer = ParseBooleanOption(provider_options, "enable_bfloat16_optimizer");

pi.enable_causallm = ParseBooleanOption(provider_options, "enable_causallm");

pi.disable_dynamic_shapes = ParseBooleanOption(provider_options, "disable_dynamic_shapes");
Expand Down
17 changes: 9 additions & 8 deletions onnxruntime/core/providers/openvino/ov_versions/capability.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ namespace openvino_ep {
GetCapability::GetCapability(const EPCtxHandler& ep_ctx_handler,
const GraphViewer& graph_viewer_param,
const std::string device_type_param,
const bool enable_qdq_optimizer) : ep_ctx_handler_(ep_ctx_handler),
graph_viewer_(graph_viewer_param),
device_type_(std::move(device_type_param)) {
const bool enable_qdq_optimizer,
bool enable_bfloat16_optimizer) : ep_ctx_handler_(ep_ctx_handler),
graph_viewer_(graph_viewer_param),
device_type_(std::move(device_type_param)) {
bool npu_qdq_optimizer_enabled = false;
if (device_type_.find("NPU") != std::string::npos) {
device_type_ = "CPU";
Expand All @@ -42,15 +43,15 @@ GetCapability::GetCapability(const EPCtxHandler& ep_ctx_handler,
}

#if OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 5
data_ops_ = new DataOps(graph_viewer_, V_2024_5, device_type_, npu_qdq_optimizer_enabled);
data_ops_ = new DataOps(graph_viewer_, V_2024_5, device_type_, npu_qdq_optimizer_enabled, enable_bfloat16_optimizer);
#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 6
data_ops_ = new DataOps(graph_viewer_, V_2024_6, device_type_, npu_qdq_optimizer_enabled);
data_ops_ = new DataOps(graph_viewer_, V_2024_6, device_type_, npu_qdq_optimizer_enabled, enable_bfloat16_optimizer);
#elif OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 0
data_ops_ = new DataOps(graph_viewer_, V_2025_0, device_type_, npu_qdq_optimizer_enabled);
data_ops_ = new DataOps(graph_viewer_, V_2025_0, device_type_, npu_qdq_optimizer_enabled, enable_bfloat16_optimizer);
#elif OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 1
data_ops_ = new DataOps(graph_viewer_, V_2025_1, device_type_, npu_qdq_optimizer_enabled);
data_ops_ = new DataOps(graph_viewer_, V_2025_1, device_type_, npu_qdq_optimizer_enabled, enable_bfloat16_optimizer);
#else
data_ops_ = new DataOps(graph_viewer_, V_2025_1, device_type_, npu_qdq_optimizer_enabled);
data_ops_ = new DataOps(graph_viewer_, V_2025_1, device_type_, npu_qdq_optimizer_enabled, enable_bfloat16_optimizer);
#endif
}

Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/openvino/ov_versions/capability.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class GetCapability {
GetCapability(const EPCtxHandler& ep_ctx_handler,
const GraphViewer& graph_viewer_param,
const std::string device_type_param,
const bool enable_qdq_optimizer);
const bool enable_qdq_optimizer,
bool enable_bfloat16_optimizer);
virtual std::vector<std::unique_ptr<ComputeCapability>> Execute();
bool IsWhollySupportedGraph() {
return is_wholly_supported_graph_;
Expand Down
7 changes: 4 additions & 3 deletions onnxruntime/core/providers/openvino/ov_versions/data_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -560,8 +560,11 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) {
return false;
}

auto dtype = type_proto->tensor_type().elem_type();
// Enable bfloat16 -> float16 on-the-fly conversion
if (bfloat16_optimizer_enabled_ && dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16)
return true;
if (is_initializer) {
auto dtype = type_proto->tensor_type().elem_type();
for (auto const& var : supported_types_initializer_) {
if ((var.first <= version_id_) &&
(var.second == dtype)) {
Expand All @@ -576,8 +579,6 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) {
#endif
return false;
} else {
auto dtype = type_proto->tensor_type().elem_type();

if (device_id_.find("HETERO") != std::string::npos ||
device_id_.find("MULTI") != std::string::npos || device_id_.find("AUTO") != std::string::npos) {
for (auto const& var : supported_types_npu_) {
Expand Down
7 changes: 5 additions & 2 deletions onnxruntime/core/providers/openvino/ov_versions/data_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class DataOps {
std::set<Pairs> supported_types_gpu_;
std::set<Pairs> supported_types_initializer_;
bool npu_qdq_optimizer_enabled_;
bool bfloat16_optimizer_enabled_;

protected:
void populate_op_mode_supported();
Expand All @@ -81,11 +82,13 @@ class DataOps {

public:
DataOps(const GraphViewer& graph_viewer_param, VersionNum ver,
const std::string dev_id, const bool npu_qdq_optimizer_enabled)
const std::string dev_id, const bool npu_qdq_optimizer_enabled,
bool bfloat16_optimizer_enabled)
: graph_viewer_(graph_viewer_param),
version_id_(ver),
device_id_(std::move(dev_id)),
npu_qdq_optimizer_enabled_(npu_qdq_optimizer_enabled) {
npu_qdq_optimizer_enabled_(npu_qdq_optimizer_enabled),
bfloat16_optimizer_enabled_(bfloat16_optimizer_enabled) {
populate_op_mode_supported();
populate_types_supported();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "qdq_scales_fix.h"
#include "core/providers/openvino/ov_protobuf_utils.h"
#include "core/framework/float16.h"

#include <cstdint>
#include <cstring>
#include <fstream>
#include <list>
Expand Down Expand Up @@ -605,8 +606,7 @@ float get_initializer_value(const Graph& graph, const std::string& initializer_n
auto size = get_initializer_size(graph, initializer_name);
ORT_ENFORCE(size == 1, "Expected an initializer to be of size 1");
return raw_data[0];
}
else
} else
return get_float_initializer_data(p_initializer);
}

Expand Down Expand Up @@ -775,7 +775,6 @@ bool scale_graph(CustomGraph& gen_graph,
return needs_second_run;
}


Status copy_model(const GraphViewer& src_graph_viewer,
const logging::Logger& logger, std::unique_ptr<onnxruntime::Model>& model) {
model = src_graph_viewer.CreateModel(logger);
Expand Down Expand Up @@ -942,5 +941,55 @@ Status Transform(const GraphViewer& src_graph_viewer,
return status;
}
} // namespace qdq_scales_fix

namespace bfloat16_fix {
// Rewrites every bfloat16 occurrence in the graph to float16 so OpenVINO can
// execute the model: Cast "to" attributes, node output element types, and
// initializer payloads (converted in place, element by element).
// NOTE(review): mutates graph internals through const_cast, mirroring the
// qdq_scales_fix pass — confirm the graph owner tolerates in-place mutation.
void replace_bf16_with_fp16(qdq_scales_fix::CustomGraph& gen_graph) {
  for (auto& const_node : gen_graph.original_graph.Nodes()) {
    auto node = const_cast<ONNX_NAMESPACE::Node*>(const_node);
    // Retarget Cast ops: to=BFLOAT16 becomes to=FLOAT16.
    if (node->OpType() == "Cast") {
      for (auto& [name, const_attribute] : node->GetAttributes()) {
        auto& attribute = const_cast<ONNX_NAMESPACE::AttributeProto&>(const_attribute);
        if (name == "to" && attribute.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT)
          if (attribute.i() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
            attribute.set_i(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
      }
    }
    // Patch the declared element type of every bfloat16-typed output.
    for (auto& output : node->OutputDefs()) {
      auto& output_proto = const_cast<ONNX_NAMESPACE::TypeProto&>(output->ToProto().type());
      if (output_proto.mutable_tensor_type()->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
        output_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
    }
  }

  // Convert initializer payloads. bfloat16 and float16 are both 16 bits wide,
  // so the raw buffer can be rewritten in place without resizing.
  // Initializers without raw_data (e.g. external data) are left untouched.
  const auto& init_set = gen_graph.original_graph.GetAllInitializedTensors();
  for (auto& [key, const_tensor_proto] : init_set) {
    auto tensor_proto = const_cast<ONNX_NAMESPACE::TensorProto*>(const_tensor_proto);
    auto dt = tensor_proto->data_type();
    if (dt == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
      auto raw_data = tensor_proto->has_raw_data() ? reinterpret_cast<std::uint16_t*>(tensor_proto->mutable_raw_data()->data()) : nullptr;
      if (raw_data) {
        tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
        std::int64_t size = 1;
        for (int i = 0; i < tensor_proto->dims_size(); ++i)
          size *= tensor_proto->dims()[i];
        for (std::int64_t i = 0; i < size; ++i) {
          // bfloat16 is the upper 16 bits of an IEEE-754 float32: widen to
          // float, then narrow to float16. Read the bits back through memcpy
          // rather than reinterpret_cast to avoid strict-aliasing UB.
          const std::uint32_t bits = static_cast<std::uint32_t>(raw_data[i]) << 16;
          float widened;
          std::memcpy(&widened, &bits, sizeof(widened));
          raw_data[i] = onnxruntime::MLFloat16(widened).val;
        }
      }
    }
  }
}

// Copies the fused subgraph into a standalone Model and rewrites all bfloat16
// usages (Cast targets, output types, initializers) to float16.
// @param src_graph_viewer  Source subgraph to convert.
// @param logger            Session logger forwarded to model creation.
// @param model             [out] Receives the converted model on success.
// @return Status of the underlying model copy; failure leaves `model` unusable.
Status Transform(const GraphViewer& src_graph_viewer,
                 const logging::Logger& logger,
                 /*out*/ std::unique_ptr<onnxruntime::Model>& model) {
  auto status = qdq_scales_fix::copy_model(src_graph_viewer, logger, model);
  // Bail out before dereferencing `model`: on failure it may be null/partial.
  if (!status.IsOK())
    return status;

  auto g = qdq_scales_fix::generate_graph_from_onnx(model->MainGraph());
  replace_bf16_with_fp16(g);
  return status;
}
} // namespace bfloat16_fix
} // namespace openvino_ep
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,10 @@ Status Transform(const GraphViewer& src_graph,
const logging::Logger& logger,
/*out*/ std::unique_ptr<onnxruntime::Model>& model);
}
namespace bfloat16_fix {
// Copies `src_graph` into `model` and rewrites all bfloat16 occurrences
// (Cast targets, output element types, initializer payloads) to float16.
// Per the PR discussion, this is a workaround until OpenVINO executes
// bfloat16 models natively.
Status Transform(const GraphViewer& src_graph,
                 const logging::Logger& logger,
                 /*out*/ std::unique_ptr<onnxruntime::Model>& model);
}
} // namespace openvino_ep
} // namespace onnxruntime
1 change: 1 addition & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2144,6 +2144,7 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O
ov_options_converted_map["load_config"] = "";
ov_options_converted_map["model_priority"] = "DEFAULT";
ov_options_converted_map["enable_qdq_optimizer"] = "false";
ov_options_converted_map["enable_bfloat16_optimizer"] = "false";
ov_options_converted_map["enable_causallm"] = "false";
return ov_options_converted_map;
}
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1182,7 +1182,8 @@ static std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory
#if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE)
ProviderOptions OV_provider_options_map;
const std::unordered_set<std::string> valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision",
"load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer",
"load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling",
"enable_qdq_optimizer", "enable_bfloat16_optimizer",
"enable_causallm", "disable_dynamic_shapes", "reshape_input"};
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/test/perftest/ort_test_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,13 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
} else {
ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_qdq_optimizer' should be a boolean i.e. true or false. Default value is false.\n");
}
} else if (key == "enable_bfloat16_optimizer") {
if (value == "true" || value == "True" ||
value == "false" || value == "False") {
ov_options[key] = value;
} else {
ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_bfloat16_optimizer' should be a boolean i.e. true or false. Default value is false.\n");
}
} else if (key == "enable_causallm") {
if (value == "true" || value == "True" ||
value == "false" || value == "False") {
Expand Down