diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 2282534b0f..09dc235912 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -491,20 +491,50 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder, goto fail; } - res = ensure_backend(instance, encoding, wasi_nn_ctx); - if (res != success) - goto fail; + if (encoding == autodetect) { + for (graph_encoding e = openvino; e <= unknown_backend; e++) { + if (wasi_nn_ctx->is_backend_ctx_initialized) { + call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res, + wasi_nn_ctx->backend_ctx); + } + + res = ensure_backend(instance, e, wasi_nn_ctx); + if (res != success) { + NN_ERR_PRINTF("continue trying the next"); + continue; + } + + call_wasi_nn_func(wasi_nn_ctx->backend, load, res, + wasi_nn_ctx->backend_ctx, &builder_native, e, + target, g); + if (res != success) { + NN_ERR_PRINTF("continue trying the next"); + continue; + } + + break; + } + } + else { + res = ensure_backend(instance, encoding, wasi_nn_ctx); + if (res != success) + goto fail; - call_wasi_nn_func(wasi_nn_ctx->backend, load, res, wasi_nn_ctx->backend_ctx, - &builder_native, encoding, target, g); - if (res != success) - goto fail; + call_wasi_nn_func(wasi_nn_ctx->backend, load, res, + wasi_nn_ctx->backend_ctx, &builder_native, encoding, + target, g); + if (res != success) + goto fail; + } fail: // XXX: Free intermediate structure pointers - if (builder_native.buf) + if (builder_native.buf) { wasm_runtime_free(builder_native.buf); - unlock_ctx(wasi_nn_ctx); + } + if (wasi_nn_ctx) { + unlock_ctx(wasi_nn_ctx); + } return res; } @@ -565,17 +595,29 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, goto fail; } - res = ensure_backend(instance, autodetect, wasi_nn_ctx); - if (res != success) - goto fail; + for (graph_encoding e = openvino; e <= unknown_backend; e++) { + if (wasi_nn_ctx->is_backend_ctx_initialized) { + call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res, + wasi_nn_ctx->backend_ctx); + } - call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, - wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len, - g); - if (res != success) - goto fail; + res = ensure_backend(instance, e, wasi_nn_ctx); + if (res != success) { + NN_ERR_PRINTF("continue trying the next"); + continue; + } + + call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, + wasi_nn_ctx->backend_ctx, nul_terminated_name, + name_len, g); + if (res != success) { + NN_ERR_PRINTF("continue trying the next"); + continue; + } + + break; + } - res = success; fail: if (nul_terminated_name != NULL) { wasm_runtime_free(nul_terminated_name); @@ -627,18 +669,29 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name, goto fail; } - res = ensure_backend(instance, autodetect, wasi_nn_ctx); - if (res != success) - goto fail; - ; + for (graph_encoding e = openvino; e <= unknown_backend; e++) { + if (wasi_nn_ctx->is_backend_ctx_initialized) { + call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res, + wasi_nn_ctx->backend_ctx); + } - call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res, - wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len, - nul_terminated_config, config_len, g); - if (res != success) - goto fail; + res = ensure_backend(instance, e, wasi_nn_ctx); + if (res != success) { + NN_ERR_PRINTF("continue trying the next"); + continue; + } + + call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res, + wasi_nn_ctx->backend_ctx, nul_terminated_name, + name_len, nul_terminated_config, config_len, g); + if (res != success) { + NN_ERR_PRINTF("continue trying the next"); + continue; + } + + break; + } - res = success; fail: if (nul_terminated_name != NULL) { wasm_runtime_free(nul_terminated_name);