Skip to content

Commit aeccbc5

Browse files
committed
Copy weights file to epctx output directory
1 parent a8527b9 commit aeccbc5

File tree

3 files changed

+45
-17
lines changed

3 files changed

+45
-17
lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

+12-9
Original file line numberDiff line numberDiff line change
@@ -84,18 +84,21 @@ BackendManager::BackendManager(SessionContext& session_context,
8484
std::string device_type = session_context_.device_type;
8585

8686
auto& sw = shared_context_.shared_weights;
87+
if (sw.external_weight_filename.empty() && !sw.metadata.empty()) {
88+
// Reasonable assumption that all metadata entries have the same external file location
89+
sw.external_weight_filename = sw.metadata.begin()->second.location;
90+
}
91+
8792
if (session_context_.so_share_ep_contexts) {
88-
std::filesystem::path weight_filename = session_context_.onnx_model_path_name.parent_path();
89-
if (sw.external_weight_filename.empty() && !sw.metadata.empty()) {
90-
// Reasonable assumption that all metadata entries have the same external file location
91-
sw.external_weight_filename = sw.metadata.begin()->second.location;
93+
auto weight_path = session_context_.GetNewWeightsFilePath(sw.external_weight_filename);
94+
if (!std::filesystem::exists(weight_path)) {
95+
weight_path = session_context_.GetModelDirectory() / sw.external_weight_filename;
9296
}
93-
weight_filename /= sw.external_weight_filename;
94-
std::ifstream weight_file(weight_filename);
9597

98+
std::ifstream weight_file(weight_path);
9699
if (weight_file) {
97100
if (!sw.mapped_weights) {
98-
sw.mapped_weights = std::make_unique<SharedContext::SharedWeights::WeightsFile>(weight_filename);
101+
sw.mapped_weights = std::make_unique<SharedContext::SharedWeights::WeightsFile>(weight_path);
99102
}
100103
backend_utils::CreateOVTensors(session_context_.device_type, sw.metadata, *sw.mapped_weights);
101104
}
@@ -241,7 +244,7 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie
241244
std::ofstream blob_file(blob_filename,
242245
std::ios::out | std::ios::trunc | std::ios::binary);
243246
if (!blob_file) {
244-
ORT_THROW("Unable to open file for epctx model dump.");
247+
ORT_THROW("Unable to open file for epctx model dump." + blob_filename.string());
245248
}
246249
compiled_model.export_model(blob_file);
247250
model_blob_str = blob_filename.filename().string();
@@ -324,7 +327,7 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) {
324327
static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name,
325328
[[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto,
326329
[[maybe_unused]] const onnxruntime::Node& fused_node) {
327-
#ifdef NOT_RELEASE
330+
#ifdef NOT_RELEASE
328331
if (openvino_ep::backend_utils::IsDebugEnabled()) {
329332
auto model_name = onnx_model_path_name.empty() ? "unknown.onnx" : onnx_model_path_name.filename();
330333

onnxruntime/core/providers/openvino/contexts.h

+14
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,20 @@ struct SessionContext : ProviderInfo {
118118
mutable bool has_external_weights = false; // Value is set to mutable to modify from capability
119119
const std::vector<uint32_t> OpenVINO_Version = {OPENVINO_VERSION_MAJOR, OPENVINO_VERSION_MINOR};
120120
const std::string openvino_sdk_version = std::to_string(OPENVINO_VERSION_MAJOR) + "." + std::to_string(OPENVINO_VERSION_MINOR);
121+
122+
fs::path GetModelDirectory() const {  // Parent directory of onnx_model_path_name.
123+
return onnx_model_path_name.parent_path();
124+
}
125+
126+
fs::path GetEpContextOutputDirectory() const {  // Parent of so_context_file_path; falls back to GetModelDirectory() when that path is empty.
127+
return so_context_file_path.empty() ? GetModelDirectory() : so_context_file_path.parent_path();
128+
}
129+
130+
fs::path GetNewWeightsFilePath(fs::path external_weights_filename) const {
131+
ORT_ENFORCE(!external_weights_filename.empty(), "External weights filename should not be empty.");
132+
// Result: GetEpContextOutputDirectory() / "<filename component of the argument>_weights.bin".
133+
return GetEpContextOutputDirectory() / fs::path(external_weights_filename.filename().string() + "_weights.bin");
134+
}
121135
};
122136

123137
// Holds context specific to subgraph.

onnxruntime/core/providers/openvino/openvino_execution_provider.cc

+19-8
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,13 @@ common::Status OpenVINOExecutionProvider::Compile(
102102
graph_body_viewer_0.DomainToVersionMap().at(kOnnxDomain);
103103
}
104104

105+
const auto metadata_path = session_context_.GetEpContextOutputDirectory() / "metadata.bin";
106+
105107
// Temporary code to read metadata before it moves to the .bin
106108
auto& metadata = shared_context_->shared_weights.metadata;
107109
if (session_context_.so_share_ep_contexts && metadata.empty()) {
108110
// Metadata is read from the EP context output directory (the model directory when no context file path is set); the model may be a source or an epctx model
109-
fs::path metadata_filename = session_context_.onnx_model_path_name.parent_path() / "metadata.bin";
110-
std::ifstream file(metadata_filename, std::ios::binary);
111+
std::ifstream file(metadata_path, std::ios::binary);
111112
if (file) {
112113
file >> metadata;
113114
}
@@ -174,20 +175,30 @@ common::Status OpenVINOExecutionProvider::Compile(
174175
}
175176

176177
if (session_context_.so_share_ep_contexts) {
177-
fs::path metadata_filename;
178-
if (session_context_.so_context_file_path.empty()) {
179-
metadata_filename = session_context_.onnx_model_path_name.parent_path() / "metadata.bin";
178+
const auto& sw_path_filename = shared_context_->shared_weights.external_weight_filename;
179+
fs::path new_weights_file_path = session_context_.GetNewWeightsFilePath(sw_path_filename);
180+
fs::path original_weights_path = session_context_.GetModelDirectory() / sw_path_filename;
181+
182+
if (!std::filesystem::exists(new_weights_file_path)) {
183+
try {
184+
std::filesystem::create_hard_link(original_weights_path, new_weights_file_path);
185+
} catch (const std::filesystem::filesystem_error& e) {
186+
LOGS_DEFAULT(WARNING) << "Failed to create hard link: " << e.what() << " Falling back to copy.";
187+
std::filesystem::copy_file(original_weights_path, new_weights_file_path);
188+
}
180189
} else {
181-
metadata_filename = session_context_.so_context_file_path.parent_path() / "metadata.bin";
190+
LOGS_DEFAULT(WARNING) << "Weights file already exists: " << new_weights_file_path.string() << " Link/Copy.";
182191
}
183192

184193
// Metadata is generated only for shared contexts
185-
// If saving metadata then save it to the provided path or ose the original model path
194+
// If saving metadata then save it to the provided path or use the original model path
186195
// Multiple calls to Compile() will update the metadata and for the last call
187196
// the resulting file will contain the aggregated content
188-
std::ofstream file(metadata_filename, std::ios::binary);
197+
std::ofstream file(metadata_path, std::ios::binary);
189198
if (file) {
190199
file << metadata;
200+
} else {
201+
LOGS_DEFAULT(WARNING) << "Failed to write metadata to file: " << metadata_path.string();
191202
}
192203
}
193204

0 commit comments

Comments
 (0)