Skip to content

Commit fcca2ca

Browse files
BoarQingYueqing Zhang
authored andcommitted
[VitisAI] fix deinit vitisai ep (#23725)
### Description <!-- Describe your changes. --> Removed the schema when unloading. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> This fix the crash where the onnxruntime reloads vitisai ep. Co-authored-by: Yueqing Zhang <[email protected]>
1 parent cf65f3b commit fcca2ca

File tree

5 files changed

+36
-7
lines changed

5 files changed

+36
-7
lines changed

onnxruntime/core/providers/shared_library/provider_interfaces.h

+1
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,7 @@ struct ProviderHost {
604604
virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0;
605605

606606
virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op) = 0;
607+
virtual void DeregisterSchema(const std::string& domain, const std::string& op_type, int version) = 0;
607608
virtual const ONNX_NAMESPACE::OpSchema* GetSchema(const std::string& name, const int maxInclusiveVersion, const std::string& domain) = 0;
608609
virtual const std::string& OpSchema__inputs__GetName(const ONNX_NAMESPACE::OpSchema* p, const size_t i) = 0;
609610
virtual const std::string& OpSchema__inputs__GetTypeStr(const ONNX_NAMESPACE::OpSchema* p, const size_t i) = 0;

onnxruntime/core/providers/vitisai/imp/global_api.cc

+16-6
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ struct OrtVitisAIEpAPI {
6363
void (*profiler_collect)(
6464
std::vector<EventInfo>& api_events,
6565
std::vector<EventInfo>& kernel_events);
66+
void (*deinitialize_onnxruntime_vitisai_ep)();
6667
void Ensure() {
6768
if (handle_)
6869
return;
@@ -91,6 +92,7 @@ struct OrtVitisAIEpAPI {
9192
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "create_ep_context_nodes", (void**)&create_ep_context_nodes));
9293
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_on_run_start", (void**)&vitisai_ep_on_run_start));
9394
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_set_ep_dynamic_options", (void**)&vitisai_ep_set_ep_dynamic_options));
95+
std::ignore = env.GetSymbolFromLibrary(handle_, "deinitialize_onnxruntime_vitisai_ep", (void**)&deinitialize_onnxruntime_vitisai_ep);
9496
}
9597
void Clear() {
9698
if (handle_) {
@@ -192,7 +194,7 @@ struct MyCustomOpKernel : OpKernel {
192194
void* op_kernel_;
193195
};
194196

195-
void create_kernel_registry(std::vector<OrtCustomOpDomain*> domains) {
197+
void create_kernel_registry(const std::vector<OrtCustomOpDomain*>& domains) {
196198
s_kernel_registry_vitisaiep = KernelRegistry::Create();
197199
for (const auto& domain : domains) {
198200
for (const auto* op : domain->custom_ops_) {
@@ -245,6 +247,7 @@ void create_kernel_registry(std::vector<OrtCustomOpDomain*> domains) {
245247
}
246248
}
247249
}
250+
248251
void initialize_vitisai_ep() {
249252
s_library_vitisaiep.Ensure();
250253
s_domains_vitisaiep.reserve(100);
@@ -253,6 +256,18 @@ void initialize_vitisai_ep() {
253256
create_kernel_registry(s_domains_vitisaiep);
254257
}
255258

259+
void deinitialize_vitisai_ep() {
260+
if (s_library_vitisaiep.deinitialize_onnxruntime_vitisai_ep) {
261+
s_library_vitisaiep.deinitialize_onnxruntime_vitisai_ep();
262+
}
263+
vaip::deregister_xir_ops(s_domains_vitisaiep);
264+
// kernel registry would be repopulated, no need to delete kernel registry
265+
s_domains_vitisaiep.clear();
266+
267+
s_library_vitisaiep.Clear();
268+
s_kernel_registry_vitisaiep.reset();
269+
}
270+
256271
static void set_version_info(vaip_core::OrtApiForVaip& api) {
257272
const char* magic = "VAIP";
258273
std::memcpy(reinterpret_cast<char*>(&api.magic), magic, sizeof(api.magic));
@@ -510,8 +525,3 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
510525
return &the_global_api;
511526
}
512527
}
513-
514-
void deinitialize_vitisai_ep() {
515-
s_library_vitisaiep.Clear();
516-
s_kernel_registry_vitisaiep.reset();
517-
}

onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc

+12
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,23 @@ namespace vaip {
1212
void register_xir_ops(const std::vector<OrtCustomOpDomain*>& domains) {
1313
for (auto domain : domains) {
1414
for (auto op : domain->custom_ops_) {
15+
// skip dequant quant schema, but register kernel
1516
if (Provider_GetHost()->GetSchema(op->GetName(op), op->GetStartVersion(op), domain->domain_) == nullptr) {
1617
Provider_GetHost()->RegisterSchema(domain->domain_, op);
1718
}
1819
}
1920
}
2021
}
2122

23+
void deregister_xir_ops(const std::vector<OrtCustomOpDomain*>& domains) {
24+
for (auto domain : domains) {
25+
if (domain->domain_ != "com.xilinx") continue; // skip dequant quant schema
26+
for (auto op : domain->custom_ops_) {
27+
if (Provider_GetHost()->GetSchema(op->GetName(op), op->GetStartVersion(op), domain->domain_) != nullptr) {
28+
Provider_GetHost()->DeregisterSchema(domain->domain_, op->GetName(op), op->GetStartVersion(op));
29+
}
30+
}
31+
}
32+
}
33+
2234
} // namespace vaip

onnxruntime/core/providers/vitisai/imp/register_xir_ops.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@
77

88
namespace vaip {
99
void register_xir_ops(const std::vector<OrtCustomOpDomain*>& domains);
10-
}
10+
void deregister_xir_ops(const std::vector<OrtCustomOpDomain*>& domains);
11+
} // namespace vaip

onnxruntime/core/session/provider_bridge_ort.cc

+5
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,11 @@ struct ProviderHostImpl : ProviderHost {
761761
auto schema = CreateSchema(domain, {op});
762762
ONNX_NAMESPACE::RegisterSchema(schema, ORT_API_VERSION);
763763
}
764+
765+
void DeregisterSchema(const std::string& domain, const std::string& op_type, int version) override {
766+
ONNX_NAMESPACE::DeregisterSchema(op_type, version, domain);
767+
}
768+
764769
const ONNX_NAMESPACE::OpSchema* GetSchema(const std::string& name, const int maxInclusiveVersion, const std::string& domain) override {
765770
return ONNX_NAMESPACE::OpSchemaRegistry::Instance()->GetSchema(name, maxInclusiveVersion, domain);
766771
}

0 commit comments

Comments
 (0)