From 5649860a9480d8534e93461a68aa2715c3b15f46 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Fri, 15 Aug 2025 23:37:50 -0500 Subject: [PATCH 1/4] The Webbening --- Cargo.toml | 1 + backends/web/.vscode/settings.json | 3 + backends/web/Cargo.toml | 44 ++ backends/web/_loader.js | 144 +++++++ backends/web/api.rs | 648 +++++++++++++++++++++++++++++ backends/web/binding/mod.rs | 35 ++ backends/web/binding/session.rs | 224 ++++++++++ backends/web/binding/tensor.rs | 235 +++++++++++ backends/web/env.rs | 17 + backends/web/lib.rs | 66 +++ backends/web/memory.rs | 71 ++++ backends/web/private.rs | 17 + backends/web/session.rs | 74 ++++ backends/web/tensor.rs | 248 +++++++++++ backends/web/util.rs | 18 + 15 files changed, 1845 insertions(+) create mode 100644 backends/web/.vscode/settings.json create mode 100644 backends/web/Cargo.toml create mode 100644 backends/web/_loader.js create mode 100644 backends/web/api.rs create mode 100644 backends/web/binding/mod.rs create mode 100644 backends/web/binding/session.rs create mode 100644 backends/web/binding/tensor.rs create mode 100644 backends/web/env.rs create mode 100644 backends/web/lib.rs create mode 100644 backends/web/memory.rs create mode 100644 backends/web/private.rs create mode 100644 backends/web/session.rs create mode 100644 backends/web/tensor.rs create mode 100644 backends/web/util.rs diff --git a/Cargo.toml b/Cargo.toml index 12ef5fde..932317db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ default-members = [ '.' ] exclude = [ 'backends/candle', 'backends/tract', + 'backends/web', 'examples/async-gpt2-api', 'examples/cudarc', 'examples/custom-ops', diff --git a/backends/web/.vscode/settings.json b/backends/web/.vscode/settings.json new file mode 100644 index 00000000..b3d20a3f --- /dev/null +++ b/backends/web/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "rust-analyzer.cargo.target": "wasm32-unknown-unknown" +} \ No newline at end of file diff --git a/backends/web/Cargo.toml b/backends/web/Cargo.toml new file mode 100644 index 00000000..c71dc450 --- /dev/null +++ b/backends/web/Cargo.toml @@ -0,0 +1,44 @@ +[workspace] +resolver = "2" + +[package] +name = "ort-web" +description = "ONNX Runtime on the web 🌐 - An alternative backend for ort" +version = "0.1.0+1.22" +edition = "2024" +rust-version = "1.88" +license = "MIT OR Apache-2.0" +repository = "https://github.com/pykeio/ort" +homepage = "https://ort.pyke.io/backends/web" +keywords = [ "machine-learning", "ai", "ml", "web", "wasm" ] +categories = [ "algorithms", "mathematics", "science", "web-programming", "wasm" ] +authors = [ + "pyke.io " +] + +[lib] +name = "ort_web" +path = "lib.rs" + +[dependencies] +js-sys = "0.3" +ort = { path = "../../", version = "=2.0.0-rc.10", default-features = false, features = [ "alternative-backend" ] } +ort-sys = { path = "../../ort-sys", version = "=2.0.0-rc.10", default-features = false, features = [ "disable-linking" ] } +serde = { version = "1.0", features = [ "derive" ] } +serde-wasm-bindgen = "0.6" +wasm-bindgen = "0.2" +wasm-bindgen-futures = "0.4" + +[dependencies.web-sys] +version = "0.3" +features = [ + "console", + "ImageData", + "HtmlImageElement", + "ImageBitmap", + "WebGlTexture", + "GpuBuffer" +] + +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = [ 'cfg(web_sys_unstable_apis)' ] } diff --git a/backends/web/_loader.js b/backends/web/_loader.js new file mode 100644 index 00000000..74a9f07e --- /dev/null +++ b/backends/web/_loader.js @@ -0,0 +1,144 @@ +const INIT_SYMBOL = Symbol('@ort-web.init'); + +const FEATURES_NONE = 0; +const FEATURES_WEBGL = 1 << 0; +const FEATURES_WEBGPU = 1 << 1; +const FEATURES_ALL = FEATURES_WEBGL | FEATURES_WEBGPU; + +/** + * @typedef {Object} Dist + * @property {string} baseUrl + * @property {string} scriptName + * @property {string} binaryName + * @property {string} [wrapperName] defaults to `binaryName` s/\.wasm$/.mjs + * @property {Record<'main' | 'wrapper' | 'binary', string>} integrities + */ + +const DIST_BASE = 'https://cdn.pyke.io/0/pyke:ort-rs/web@1.22.0/'; + +/** @type {Record} */ +const DIST = { + [FEATURES_NONE]: { + baseUrl: DIST_BASE, + scriptName: 'ort.wasm.min.js', + binaryName: 'ort-wasm-simd-threaded.wasm', + integrities: { + main: 'epp8GDQUoLKx5qHa6SoDHKv7fSHILsEYk4uEsHdPRBztXEIVWCe/lhhrQJUBMZcf', + wrapper: 'LXMGGJ76ujT3yGw+OWQZVB6vBmJ7lqTO957Fh6ov3385aw3EncleBNFfYFAl3vXW', + binary: 'Eu/XUdOA62yl+TueG792KtrQlAGAMW3g10sY4G3LBYyYZUtM126Z4Gr3ljTlXUGG' + } + }, + [FEATURES_WEBGL]: { + baseUrl: DIST_BASE, + scriptName: 'ort.webgl.min.js', + binaryName: 'ort-wasm-simd-threaded.wasm', + integrities: { + main: 'IbmlOTVtLFqdmXae30hOMw60GXx+uyALrXF1TomZTqfkz2eL2RL/Po/TzbsGe/yv', + wrapper: 'LXMGGJ76ujT3yGw+OWQZVB6vBmJ7lqTO957Fh6ov3385aw3EncleBNFfYFAl3vXW', + binary: 'Eu/XUdOA62yl+TueG792KtrQlAGAMW3g10sY4G3LBYyYZUtM126Z4Gr3ljTlXUGG' + } + }, + [FEATURES_WEBGPU]: { + baseUrl: DIST_BASE, + scriptName: 'ort.webgpu.min.js', + binaryName: 'ort-wasm-simd-threaded.jsep.wasm', + integrities: { + main: 'XM2cMlQFAUJFJ3s2424PSr/v9zkRT4aXfi1cUz2SunZatAOwTR5GfTcIKLJIf3Ns', + wrapper: 'fZi+E4spXPUbkMSScLJlEGqj5QdfSJK7VQ2AZC5HLLV8lZg1j+TZT0RK6aEakeeX', + binary: 'NNN1BawwGTHI+TPz2ivQSKo1AJHr/496DqG53T9IUQ6B9ruFNrov0DNJhqucIwZ1' + } + }, + [FEATURES_ALL]: { + baseUrl: DIST_BASE, + scriptName: 'ort.all.min.js', + binaryName: 'ort-wasm-simd-threaded.jsep.wasm', + integrities: { + main: 'YWRTN6ucI4mQ8JMXfTaXD+iM7ExBj4KSHo6k6W9UIgx1tG98UgXpekjyYvRQ6akx', + wrapper: 'fZi+E4spXPUbkMSScLJlEGqj5QdfSJK7VQ2AZC5HLLV8lZg1j+TZT0RK6aEakeeX', + binary: 'NNN1BawwGTHI+TPz2ivQSKo1AJHr/496DqG53T9IUQ6B9ruFNrov0DNJhqucIwZ1' + } + } +}; + +/** + * @param {string} url + * @param {'fetch' | 'script' | 'module'} as + * @param {string} [type] + * @param {string} [integrity] + */ +function preload(url, as, type, integrity) { + const el = document.createElement('link'); + el.href = url; + if (as !== 'module') { + el.rel = 'preload'; + el.setAttribute('as', as); + } else { + el.rel = 'modulepreload'; + } + if (type) { + el.setAttribute('type', type); + } + if (integrity) { + el.setAttribute('integrity', `sha384-${integrity}`); + } + document.head.appendChild(el); +} + +/** + * @param {number} features + * @returns {Promise} + */ +export function initRuntime(features) { + if ('ort' in window && /** @type {any} */(window).ort[INIT_SYMBOL]) { + return Promise.resolve(false); + } + + if (!(features in DIST)) { + return Promise.reject(new Error('Unsupported feature set')); + } + + const dist = DIST[features]; + /** @param {string} file */ + const relative = file => new URL(file, dist.baseUrl).toString(); + + return new Promise((resolve, reject) => { + // since the order is load main script -> imports wrapper script -> fetches wasm, now would be a good time to + // start fetching those + preload( + relative(dist.binaryName), + 'fetch', + 'application/wasm', + dist.integrities.binary + ); + preload( + relative(dist.wrapperName || dist.binaryName.replace(/\.wasm$/, '.mjs')), + 'module', + undefined, + dist.integrities.wrapper + ); + + const script = document.createElement('script'); + script.src = new URL(dist.binaryName, dist.baseUrl).toString(); + if (dist.integrities.main) { + script.setAttribute('integrity', `sha384-${dist.integrities.main}`); + } + script.addEventListener('load', () => { + if (!('ort' in window)) { + return reject(new Error('script loaded but ort not defined')); + } + + Object.defineProperty(window.ort, INIT_SYMBOL, { + value: true, + configurable: false, + enumerable: false, + writable: false + }); + + resolve(true); + }); + script.addEventListener('error', e => { + reject(e.error); + }); + document.head.appendChild(script); + }); +} diff --git a/backends/web/api.rs b/backends/web/api.rs new file mode 100644 index 00000000..b0263995 --- /dev/null +++ b/backends/web/api.rs @@ -0,0 +1,648 @@ +#![allow(non_snake_case)] + +use alloc::{ + boxed::Box, + ffi::CString, + format, + string::{String, ToString}, + vec::Vec +}; +use core::{ + ffi::{self, CStr}, + future::Future, + pin::Pin +}; +use std::collections::HashMap; + +use ort_sys::{stub::Error, *}; + +use crate::{ + binding, + env::Environment, + memory::{Allocator, MemoryInfo}, + session::{RunOptions, Session, SessionOptions}, + tensor::{SyncDirection, Tensor, TensorData, TypeInfo, create_buffer, onnx_to_dtype}, + util::value_to_string +}; + +unsafe extern "system" fn CreateEnv(_log_severity_level: OrtLoggingLevel, _logid: *const ffi::c_char, out: *mut *mut OrtEnv) -> OrtStatusPtr { + unsafe { out.write(Environment::new_sys()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn CreateEnvWithCustomLogger( + _logging_function: OrtLoggingFunction, + _logger_param: *mut ffi::c_void, + _log_severity_level: OrtLoggingLevel, + _logid: *const ffi::c_char, + out: *mut *mut OrtEnv +) -> OrtStatusPtr { + unsafe { out.write(Environment::new_sys()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn EnableTelemetryEvents(_env: *const OrtEnv) -> OrtStatusPtr { + OrtStatusPtr::default() +} + +unsafe extern "system" fn DisableTelemetryEvents(_env: *const OrtEnv) -> OrtStatusPtr { + OrtStatusPtr::default() +} + +unsafe fn CreateSession( + _env: *const OrtEnv, + model_path: &str, + options: *const OrtSessionOptions, + out: *mut *mut OrtSession +) -> Pin>> { + let options = unsafe { &*options.cast::() }; + + let fut = Box::pin(async move { + match Session::from_url(model_path, options).await { + Ok(session) => { + let ptr = (Box::leak(Box::new(session))) as *mut Session; + unsafe { out.write(ptr.cast()) }; + + OrtStatusPtr::default() + } + Err(e) => e.into_sys() + } + }) as Pin>>; + unsafe { core::mem::transmute(fut) } +} + +unsafe fn CreateSessionFromArray( + _env: *const OrtEnv, + model_data: &[u8], + options: *const OrtSessionOptions, + out: *mut *mut OrtSession +) -> Pin>> { + let options = unsafe { &*options.cast::() }; + + let fut = Box::pin(async move { + match Session::from_bytes(model_data, options).await { + Ok(session) => { + let ptr = (Box::leak(Box::new(session))) as *mut Session; + unsafe { out.write(ptr.cast()) }; + + OrtStatusPtr::default() + } + Err(e) => e.into_sys() + } + }) as Pin>>; + unsafe { core::mem::transmute(fut) } +} + +unsafe extern "system" fn Run( + _session: *mut OrtSession, + _run_options: *const OrtRunOptions, + _input_names: *const *const ::core::ffi::c_char, + _inputs: *const *const OrtValue, + _input_len: usize, + _output_names: *const *const ::core::ffi::c_char, + _output_names_len: usize, + _output_ptrs: *mut *mut OrtValue +) -> OrtStatusPtr { + Error::new_sys(OrtErrorCode::ORT_FAIL, "Synchronous `Session::run` is not supported in ort-web; use `run_async()`.") +} + +unsafe fn RunAsync( + session: *mut OrtSession, + _run_options: *const OrtRunOptions, + input_names: &[&str], + inputs: &[*const OrtValue], + output_names: &[&str], + output_ptrs: &mut [*mut OrtValue] +) -> Pin>> { + let session = unsafe { &*session.cast::() }; + + let fut = Box::pin(async move { + let inputs = input_names + .iter() + .zip(inputs) + .map(|(&name, &input)| (name, unsafe { &*input.cast::() })) + .collect::>(); + + match session.js.run(inputs.into_iter()).await { + Ok(outputs) => { + let output_names: Vec = output_names.iter().map(|&name| name.to_string()).collect(); + let output_view = unsafe { core::slice::from_raw_parts_mut(output_ptrs.as_mut_ptr().cast::<*mut Tensor>(), output_ptrs.len()) }; + + for (name, mut tensor) in outputs { + if let Some(index) = output_names + .iter() + .zip(output_view.iter_mut()) + .find_map(|(o_name, output)| if name == *o_name { Some(output) } else { None }) + { + if !session.disable_sync { + if let Err(e) = tensor.sync(SyncDirection::Rust).await { + return Error::new_sys(OrtErrorCode::ORT_FAIL, format!("Failed to synchronize output '{name}': {e}")); + } + } + + *index = Box::leak(Box::new(tensor)); + } + } + + OrtStatusPtr::default() + } + Err(e) => Error::new_sys(OrtErrorCode::ORT_FAIL, format!("Failed to run session: {}", value_to_string(&e))) + } + }) as Pin>>; + unsafe { core::mem::transmute(fut) } +} + +unsafe extern "system" fn CreateSessionOptions(options: *mut *mut OrtSessionOptions) -> OrtStatusPtr { + unsafe { options.write((Box::leak(Box::new(SessionOptions::new())) as *mut SessionOptions).cast()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn SessionOptionsAppendExecutionProvider( + options: *mut OrtSessionOptions, + provider_name: *const ::core::ffi::c_char, + provider_options_keys: *const *const ::core::ffi::c_char, + provider_options_values: *const *const ::core::ffi::c_char, + num_keys: usize +) -> OrtStatusPtr { + let options = unsafe { &mut *options.cast::() }; + let execution_providers = options.js.execution_providers.get_or_insert_default(); + + let Ok(options) = unsafe { core::slice::from_raw_parts(provider_options_keys, num_keys) } + .iter() + .zip(unsafe { core::slice::from_raw_parts(provider_options_values, num_keys) }.iter()) + .map(|(k, v)| Ok((unsafe { CStr::from_ptr(*k) }.to_str()?, unsafe { CStr::from_ptr(*v) }.to_str()?))) + .collect::, core::str::Utf8Error>>() + else { + return Error::new_sys(OrtErrorCode::ORT_FAIL, "EP options contains invalid UTF-8"); + }; + + let provider_name = unsafe { CStr::from_ptr(provider_name) }; + match provider_name.to_string_lossy().as_ref() { + "WASM" => { + execution_providers.push(binding::ExecutionProvider::WASM); + } + "WebGL" => { + execution_providers.push(binding::ExecutionProvider::WebGL); + } + "WebGPU" => { + execution_providers.push(binding::ExecutionProvider::WebGPU { + preferred_layout: match options.get("ep.webgpuexecutionprovider.preferredLayout") { + Some(&"NHWC") => Some(binding::WebGPUPreferredLayout::NHWC), + Some(&"NCHW") => Some(binding::WebGPUPreferredLayout::NCHW), + _ => None + } + }); + } + "WebNN" => { + execution_providers.push(binding::ExecutionProvider::WebNN { + power_preference: match options.get("powerPreference") { + Some(&"default") => Some(binding::WebNNPowerPreference::Default), + Some(&"high-performance") => Some(binding::WebNNPowerPreference::HighPerformance), + Some(&"low-power") => Some(binding::WebNNPowerPreference::LowPower), + _ => None + }, + device_type: match options.get("deviceType") { + Some(&"cpu") => Some(binding::WebNNDeviceType::CPU), + Some(&"npu") => Some(binding::WebNNDeviceType::NPU), + Some(&"gpu") => Some(binding::WebNNDeviceType::GPU), + _ => None + }, + num_threads: options.get("numThreads").and_then(|c| c.parse().ok()) + }); + } + x => return Error::new_sys(OrtErrorCode::ORT_NOT_IMPLEMENTED, format!("Provider '{x}' not supported")) + } + + OrtStatusPtr::default() +} + +unsafe extern "system" fn CloneSessionOptions(in_options: *const OrtSessionOptions, out_options: *mut *mut OrtSessionOptions) -> OrtStatusPtr { + let options = unsafe { &*in_options.cast::() }; + unsafe { out_options.write((Box::leak(Box::new(options.clone())) as *mut SessionOptions).cast()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn SessionGetInputCount(session: *const OrtSession, out: *mut usize) -> OrtStatusPtr { + let session = unsafe { &*session.cast::() }; + unsafe { out.write(session.js.input_len()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn SessionGetOutputCount(session: *const OrtSession, out: *mut usize) -> OrtStatusPtr { + let session = unsafe { &*session.cast::() }; + unsafe { out.write(session.js.output_len()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn SessionGetOverridableInitializerCount(_session: *const OrtSession, out: *mut usize) -> OrtStatusPtr { + unsafe { out.write(0) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn SessionGetInputTypeInfo(session: *const OrtSession, index: usize, type_info: *mut *mut OrtTypeInfo) -> OrtStatusPtr { + let session = unsafe { &*session.cast::() }; + let metadata = session.js.input_metadata().remove(index); + if !metadata.is_tensor { + return Error::new_sys(OrtErrorCode::ORT_FAIL, "non-tensor types are not currently supported"); + } + + unsafe { type_info.write(TypeInfo::new_sys_from_value_metadata(&metadata)) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn SessionGetOutputTypeInfo(session: *const OrtSession, index: usize, type_info: *mut *mut OrtTypeInfo) -> OrtStatusPtr { + let session = unsafe { &*session.cast::() }; + let metadata = session.js.output_metadata().remove(index); + if !metadata.is_tensor { + return Error::new_sys(OrtErrorCode::ORT_FAIL, "non-tensor types are not currently supported"); + } + + unsafe { type_info.write(TypeInfo::new_sys_from_value_metadata(&metadata)) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn SessionGetInputName( + session: *const OrtSession, + index: usize, + _allocator: *mut OrtAllocator, + value: *mut *mut ffi::c_char +) -> OrtStatusPtr { + let session = unsafe { &*session.cast::() }; + let name = CString::new(&*session.js.input_names().remove(index)).unwrap(); + unsafe { value.write(name.into_raw()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn SessionGetOutputName( + session: *const OrtSession, + index: usize, + _allocator: *mut OrtAllocator, + value: *mut *mut ffi::c_char +) -> OrtStatusPtr { + let session = unsafe { &*session.cast::() }; + let name = CString::new(&*session.js.output_names().remove(index)).unwrap(); + unsafe { value.write(name.into_raw()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn CreateRunOptions(out: *mut *mut OrtRunOptions) -> OrtStatusPtr { + unsafe { out.write((Box::leak(Box::new(RunOptions::new())) as *mut RunOptions).cast()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn CreateTensorAsOrtValue( + _allocator: *mut OrtAllocator, + shape: *const i64, + shape_len: usize, + type_: ONNXTensorElementDataType, + out: *mut *mut OrtValue +) -> OrtStatusPtr { + let shape = unsafe { core::slice::from_raw_parts(shape, shape_len) } + .iter() + .map(|c| *c as i32) + .collect::>(); + let Some(dtype) = onnx_to_dtype(type_) else { + return Error::new_sys(OrtErrorCode::ORT_FAIL, "unsupported dtype"); + }; + + match binding::Tensor::new_from_buffer(dtype, create_buffer(dtype, &shape), &shape) { + Ok(tensor) => { + unsafe { out.write((Box::leak(Box::new(Tensor::from_tensor(tensor))) as *mut Tensor).cast()) }; + OrtStatusPtr::default() + } + Err(e) => Error::new_sys(OrtErrorCode::ORT_FAIL, format!("Failed to create tensor: {}", value_to_string(&e))) + } +} + +unsafe extern "system" fn CreateTensorWithDataAsOrtValue( + _info: *const OrtMemoryInfo, + p_data: *mut ffi::c_void, + p_data_len: usize, + shape: *const i64, + shape_len: usize, + type_: ONNXTensorElementDataType, + out: *mut *mut OrtValue +) -> OrtStatusPtr { + let shape = unsafe { core::slice::from_raw_parts(shape, shape_len) } + .iter() + .map(|c| *c as i32) + .collect::>(); + let Some(dtype) = onnx_to_dtype(type_) else { + return Error::new_sys(OrtErrorCode::ORT_FAIL, "unsupported dtype"); + }; + + match unsafe { Tensor::from_ptr(dtype, p_data, p_data_len, &shape) } { + Ok(tensor) => { + unsafe { out.write((Box::leak(Box::new(tensor)) as *mut Tensor).cast()) }; + OrtStatusPtr::default() + } + Err(e) => Error::new_sys(OrtErrorCode::ORT_FAIL, format!("Failed to create tensor: {}", value_to_string(&e))) + } +} + +unsafe extern "system" fn IsTensor(_value: *const OrtValue, out: *mut ffi::c_int) -> OrtStatusPtr { + unsafe { out.write(1) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn GetTensorMutableData(value: *mut OrtValue, out: *mut *mut ffi::c_void) -> OrtStatusPtr { + let tensor = unsafe { &mut *value.cast::() }; + match &mut tensor.data { + TensorData::RustView { ptr, .. } => { + unsafe { out.write(*ptr) }; + OrtStatusPtr::default() + } + TensorData::External { buffer } => { + if let Some(buffer) = buffer { + unsafe { out.write(buffer.as_mut_ptr().cast()) }; + OrtStatusPtr::default() + } else { + Error::new_sys(OrtErrorCode::ORT_FAIL, "External data is not synchronized; you should call `ort_web::synchronize`.") + } + } + } +} + +unsafe extern "system" fn CastTypeInfoToTensorInfo(type_info: *const OrtTypeInfo, out: *mut *const OrtTensorTypeAndShapeInfo) -> OrtStatusPtr { + unsafe { out.write(type_info.cast()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn GetOnnxTypeFromTypeInfo(_type_info: *const OrtTypeInfo, out: *mut ONNXType) -> OrtStatusPtr { + unsafe { out.write(ONNXType::ONNX_TYPE_TENSOR) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn CreateTensorTypeAndShapeInfo(out: *mut *mut OrtTensorTypeAndShapeInfo) -> OrtStatusPtr { + unsafe { out.write(TypeInfo::new_sys(binding::DataType::Float32, Vec::new()).cast()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn SetTensorElementType(info: *mut OrtTensorTypeAndShapeInfo, type_: ONNXTensorElementDataType) -> OrtStatusPtr { + let info = unsafe { &mut *info.cast::() }; + match onnx_to_dtype(type_) { + Some(_) => { + info.dtype = type_; + OrtStatusPtr::default() + } + None => Error::new_sys(OrtErrorCode::ORT_FAIL, "Unsupported tensor data type") + } +} + +unsafe extern "system" fn SetDimensions(info: *mut OrtTensorTypeAndShapeInfo, dim_values: *const i64, dim_count: usize) -> OrtStatusPtr { + let info = unsafe { &mut *info.cast::() }; + info.shape = unsafe { core::slice::from_raw_parts(dim_values.cast(), dim_count) }.to_vec(); + OrtStatusPtr::default() +} + +unsafe extern "system" fn GetTensorElementType(info: *const OrtTensorTypeAndShapeInfo, out: *mut ONNXTensorElementDataType) -> OrtStatusPtr { + let info = unsafe { &*info.cast::() }; + unsafe { out.write(info.dtype) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn GetDimensionsCount(info: *const OrtTensorTypeAndShapeInfo, out: *mut usize) -> OrtStatusPtr { + let info = unsafe { &*info.cast::() }; + unsafe { out.write(info.shape.len()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn GetDimensions(info: *const OrtTensorTypeAndShapeInfo, dim_values: *mut i64, dim_values_length: usize) -> OrtStatusPtr { + let info = unsafe { &*info.cast::() }; + for (i, dim) in info.shape.iter().enumerate().take(dim_values_length) { + unsafe { dim_values.add(i).write(*dim as _) }; + } + OrtStatusPtr::default() +} + +unsafe extern "system" fn GetSymbolicDimensions( + _info: *const OrtTensorTypeAndShapeInfo, + dim_params: *mut *const ffi::c_char, + dim_params_length: usize +) -> OrtStatusPtr { + for i in 0..dim_params_length { + unsafe { dim_params.add(i).write(c"".as_ptr()) }; + } + OrtStatusPtr::default() +} + +unsafe extern "system" fn GetTensorShapeElementCount(info: *const OrtTensorTypeAndShapeInfo, out: *mut usize) -> OrtStatusPtr { + let info = unsafe { &*info.cast::() }; + let mut size = 1usize; + for dim in &info.shape { + size *= *dim as usize; + } + unsafe { out.write(size) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn GetTensorTypeAndShape(value: *const OrtValue, out: *mut *mut OrtTensorTypeAndShapeInfo) -> OrtStatusPtr { + let tensor = unsafe { &*value.cast::() }; + unsafe { out.write(TypeInfo::new_sys_from_tensor(tensor).cast()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn GetTypeInfo(value: *const OrtValue, out: *mut *mut OrtTypeInfo) -> OrtStatusPtr { + let tensor = unsafe { &*value.cast::() }; + unsafe { out.write(TypeInfo::new_sys_from_tensor(tensor)) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn GetValueType(_value: *const OrtValue, out: *mut ONNXType) -> OrtStatusPtr { + unsafe { out.write(ONNXType::ONNX_TYPE_TENSOR) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn CreateMemoryInfo( + name: *const ffi::c_char, + _type: OrtAllocatorType, + _id: ffi::c_int, + _mem_type: OrtMemType, + out: *mut *mut OrtMemoryInfo +) -> OrtStatusPtr { + let device_name = unsafe { CStr::from_ptr(name) }; + match MemoryInfo::from_location(&*device_name.to_string_lossy()) { + Some(inf) => { + unsafe { *out = (Box::leak(Box::new(inf)) as *mut MemoryInfo).cast() }; + OrtStatusPtr::default() + } + None => Error::new_sys( + OrtErrorCode::ORT_FAIL, + "Unsupported MemoryInfo type - only CPU tensors can be created this way. Tensors must be created from existing non-CPU buffers using `ort_web::TensorExt::from_*`." + ) + } +} + +unsafe extern "system" fn CreateCpuMemoryInfo(_type: OrtAllocatorType, _mem_type: OrtMemType, out: *mut *mut OrtMemoryInfo) -> OrtStatusPtr { + unsafe { *out = (Box::leak(Box::new(MemoryInfo { location: binding::DataLocation::Cpu })) as *mut MemoryInfo).cast() }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn CompareMemoryInfo(info1: *const OrtMemoryInfo, info2: *const OrtMemoryInfo, out: *mut ffi::c_int) -> OrtStatusPtr { + let info1 = unsafe { &*info1.cast::() }; + let info2 = unsafe { &*info2.cast::() }; + unsafe { out.write(if info1 == info2 { 0 } else { -1 }) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn MemoryInfoGetName(ptr: *const OrtMemoryInfo, out: *mut *const ffi::c_char) -> OrtStatusPtr { + let info = unsafe { &*ptr.cast::() }; + unsafe { out.write(info.location_exposed().unwrap_or(c"").as_ptr().cast()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn MemoryInfoGetId(_ptr: *const OrtMemoryInfo, out: *mut ffi::c_int) -> OrtStatusPtr { + unsafe { out.write(0) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn MemoryInfoGetMemType(_ptr: *const OrtMemoryInfo, out: *mut OrtMemType) -> OrtStatusPtr { + unsafe { out.write(OrtMemType::OrtMemTypeDefault) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn MemoryInfoGetType(_ptr: *const OrtMemoryInfo, out: *mut OrtAllocatorType) -> OrtStatusPtr { + unsafe { out.write(OrtAllocatorType::OrtDeviceAllocator) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn GetAllocatorWithDefaultOptions(out: *mut *mut OrtAllocator) -> OrtStatusPtr { + unsafe { out.write((&crate::memory::DEFAULT_CPU_ALLOCATOR as *const Allocator).cast_mut().cast()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn ReleaseEnv(input: *mut OrtEnv) { + drop(unsafe { Environment::consume_sys(input) }); +} + +unsafe extern "system" fn ReleaseStatus(input: *mut OrtStatus) { + drop(unsafe { Error::consume_sys(input) }); +} + +unsafe extern "system" fn ReleaseMemoryInfo(input: *mut OrtMemoryInfo) { + drop(unsafe { Box::::from_raw(input.cast()) }); +} + +unsafe extern "system" fn ReleaseSession(input: *mut OrtSession) { + drop(unsafe { Box::::from_raw(input.cast()) }); +} + +unsafe extern "system" fn ReleaseValue(input: *mut OrtValue) { + drop(unsafe { Box::::from_raw(input.cast()) }); +} + +unsafe extern "system" fn ReleaseRunOptions(input: *mut OrtRunOptions) { + drop(unsafe { Box::::from_raw(input.cast()) }); +} + +unsafe extern "system" fn ReleaseTypeInfo(input: *mut OrtTypeInfo) { + drop(unsafe { TypeInfo::consume_sys(input) }); +} + +unsafe extern "system" fn ReleaseTensorTypeAndShapeInfo(input: *mut OrtTensorTypeAndShapeInfo) { + drop(unsafe { TypeInfo::consume_sys(input.cast()) }); +} + +unsafe extern "system" fn ReleaseSessionOptions(input: *mut OrtSessionOptions) { + drop(unsafe { Box::from_raw(input.cast::()) }); +} + +unsafe extern "system" fn CreateAllocator(_session: *const OrtSession, mem_info: *const OrtMemoryInfo, out: *mut *mut OrtAllocator) -> OrtStatusPtr { + let mem_info = unsafe { &*mem_info.cast::() }; + if mem_info.location != binding::DataLocation::Cpu { + return Error::new_sys(OrtErrorCode::ORT_INVALID_ARGUMENT, "Only CPU allocators are supported."); + } + + unsafe { out.write((Box::leak(Box::new(Allocator::new())) as *mut Allocator).cast()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn ReleaseAllocator(input: *mut OrtAllocator) { + drop(unsafe { Box::from_raw(input.cast::()) }); +} + +unsafe extern "system" fn GetTensorMemoryInfo(value: *const OrtValue, mem_info: *mut *const OrtMemoryInfo) -> OrtStatusPtr { + let tensor = unsafe { &*value.cast::() }; + unsafe { mem_info.write((&tensor.memory_info as *const MemoryInfo).cast()) }; + OrtStatusPtr::default() +} + +unsafe extern "system" fn MemoryInfoGetDeviceType(ptr: *const OrtMemoryInfo, out: *mut OrtMemoryInfoDeviceType) { + let memory_info = unsafe { &*ptr.cast::() }; + unsafe { + out.write(match memory_info.location { + binding::DataLocation::Cpu | binding::DataLocation::CpuPinned => OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU, + _ => OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU + }) + }; +} + +unsafe extern "system" fn GetBuildInfoString() -> *const ffi::c_char { + concat!("ORT Build Info: backend=ort-web, version=", env!("CARGO_PKG_VERSION"), ", with <3\0") + .as_ptr() + .cast() +} + +pub const fn api() -> OrtApi { + OrtApi { + CreateEnv, + CreateEnvWithCustomLogger, + EnableTelemetryEvents, + DisableTelemetryEvents, + CreateSession, + CreateSessionFromArray, + Run, + RunAsync, + CreateSessionOptions, + CloneSessionOptions, + SessionGetInputCount, + SessionGetOutputCount, + SessionGetOverridableInitializerCount, + SessionGetInputTypeInfo, + SessionGetOutputTypeInfo, + SessionGetInputName, + SessionGetOutputName, + CreateTensorAsOrtValue, + CreateTensorWithDataAsOrtValue, + IsTensor, + GetTensorMutableData, + CastTypeInfoToTensorInfo, + GetOnnxTypeFromTypeInfo, + CreateTensorTypeAndShapeInfo, + SetTensorElementType, + SetDimensions, + GetTensorElementType, + GetDimensionsCount, + GetDimensions, + GetSymbolicDimensions, + GetTensorShapeElementCount, + GetTensorTypeAndShape, + GetTypeInfo, + GetValueType, + CreateMemoryInfo, + CreateCpuMemoryInfo, + CompareMemoryInfo, + MemoryInfoGetName, + MemoryInfoGetId, + MemoryInfoGetMemType, + MemoryInfoGetType, + GetAllocatorWithDefaultOptions, + ReleaseEnv, + ReleaseStatus, + ReleaseMemoryInfo, + ReleaseSession, + ReleaseValue, + ReleaseTypeInfo, + ReleaseTensorTypeAndShapeInfo, + ReleaseSessionOptions, + CreateAllocator, + ReleaseAllocator, + GetTensorMemoryInfo, + MemoryInfoGetDeviceType, + GetBuildInfoString, + CreateRunOptions, + ReleaseRunOptions, + SessionOptionsAppendExecutionProvider, + ..ort_sys::stub::api() + } +} diff --git a/backends/web/binding/mod.rs b/backends/web/binding/mod.rs new file mode 100644 index 00000000..73b645b5 --- /dev/null +++ b/backends/web/binding/mod.rs @@ -0,0 +1,35 @@ +use js_sys::Boolean; +use serde::{Deserialize, Serialize}; +use wasm_bindgen::prelude::*; + +mod session; +pub use self::session::*; +mod tensor; +pub use self::tensor::*; + +#[wasm_bindgen] +#[derive(Deserialize, Serialize, Debug, Clone, Copy)] +#[serde(rename_all = "lowercase")] +pub enum DataType { + Bool = "bool", + Float16 = "float16", + Float32 = "float32", + Float64 = "float64", + Int4 = "int4", + Int8 = "int8", + Int16 = "int16", + Int32 = "int32", + Int64 = "int64", + Uint4 = "uint4", + Uint8 = "uint8", + Uint16 = "uint16", + Uint32 = "uint32", + Uint64 = "uint64", + String = "string" +} + +#[wasm_bindgen(module = "/_loader.js")] +extern "C" { + #[wasm_bindgen(catch, js_name = "initRuntime")] + pub async fn init_runtime(features: u8) -> Result; +} diff --git a/backends/web/binding/session.rs b/backends/web/binding/session.rs new file mode 100644 index 00000000..eaf435be --- /dev/null +++ b/backends/web/binding/session.rs @@ -0,0 +1,224 @@ +use alloc::{string::String, vec::Vec}; +use std::collections::HashMap; + +use js_sys::{JsString, Object, Reflect, Uint8Array}; +use serde::{Deserialize, Serialize}; +use wasm_bindgen::prelude::*; + +use crate::{binding::DataType, tensor::Tensor}; + +#[derive(Serialize, Debug, Clone, Copy)] +#[serde(rename_all = "lowercase")] +pub enum ExecutionMode { + Sequential, + Parallel +} + +#[derive(Serialize, Debug, Clone, Copy)] +#[serde(rename_all = "lowercase")] +pub enum GraphOptimizationLevel { + Disabled, + Basic, + Layout, + Extended, + All +} + +#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum WebNNDeviceType { + CPU, + GPU, + NPU +} + +#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq)] +#[serde(rename_all = "kebab-case")] +pub enum WebNNPowerPreference { + Default, + HighPerformance, + LowPower +} + +#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq)] +#[serde(rename_all = "UPPERCASE")] +pub enum WebGPUPreferredLayout { + NHWC, + NCHW +} + +#[derive(Serialize, Debug, Clone)] +#[serde(tag = "name", rename_all = "lowercase")] +pub enum ExecutionProvider { + WASM, + WebGL, + #[serde(rename_all = "camelCase")] + WebNN { + device_type: Option, + num_threads: Option, + power_preference: Option + }, + #[serde(rename_all = "camelCase")] + WebGPU { + preferred_layout: Option + } +} + +#[derive(Serialize, Default, Clone)] +#[serde(rename_all = "camelCase")] +pub struct SessionOptions { + pub enable_cpu_mem_arena: Option, + pub enable_graph_capture: Option, + pub enable_mem_pattern: Option, + pub enable_profiling: Option, + pub execution_mode: Option, + pub execution_providers: Option>, + pub extra: Option>, + pub free_dimension_override: Option>, + pub graph_optimization_level: Option, + pub inter_op_num_threads: Option, + pub intra_op_num_threads: Option, + pub log_id: Option, + pub log_severity_level: Option, + pub log_verbosity_level: Option +} + +impl SessionOptions { + pub(crate) fn to_value(&self) -> Result { + self.serialize(&serde_wasm_bindgen::Serializer::new().serialize_maps_as_objects(true)) + } +} + +#[wasm_bindgen] +extern "C" { + #[wasm_bindgen(js_namespace = ort)] + pub type InferenceSession; + + #[wasm_bindgen(catch, js_namespace = ort, static_method_of = InferenceSession, js_name = create)] + async fn create_from_uri_raw(uri: &str, options: JsValue) -> Result; + #[wasm_bindgen(catch, js_namespace = ort, static_method_of = InferenceSession, js_name = create)] + async fn create_from_bytes_raw(buffer: &Uint8Array, options: JsValue) -> Result; + + #[wasm_bindgen(catch, structural, method, js_name = startProfiling)] + pub fn start_profiling(this: &InferenceSession) -> Result<(), JsValue>; + #[wasm_bindgen(catch, structural, method, js_name = endProfiling)] + pub fn end_profiling(this: &InferenceSession) -> Result<(), JsValue>; + #[wasm_bindgen(catch, structural, method, js_name = release)] + pub async fn release(this: &InferenceSession) -> Result<(), JsValue>; + + #[wasm_bindgen(structural, method, getter, js_name = inputMetadata)] + fn input_metadata_raw(this: &InferenceSession) -> Vec; + #[wasm_bindgen(structural, method, getter, js_name = outputMetadata)] + fn output_metadata_raw(this: &InferenceSession) -> Vec; + #[wasm_bindgen(structural, method, getter, js_name = inputNames)] + fn input_names_raw(this: &InferenceSession) -> Vec; + #[wasm_bindgen(structural, method, getter, js_name = outputNames)] + fn output_names_raw(this: &InferenceSession) -> Vec; + + #[wasm_bindgen(catch, structural, method, js_name = run)] + async fn run_raw(this: &InferenceSession, feeds: JsValue) -> Result; + #[wasm_bindgen(catch, structural, method, js_name = run)] + async fn run_with_fetches_raw(this: &InferenceSession, feeds: JsValue, fetches: JsValue) -> Result; +} + +impl InferenceSession { + pub async fn create_from_uri(uri: &str, options: &SessionOptions) -> Result { + InferenceSession::create_from_uri_raw(uri, options.to_value()?).await + } + pub async fn create_from_bytes(buffer: &Uint8Array, options: &SessionOptions) -> Result { + InferenceSession::create_from_bytes_raw(buffer, options.to_value()?).await + } + + pub fn input_names(&self) -> Vec { + self.input_names_raw().into_iter().map(String::from).collect() + } + pub fn output_names(&self) -> Vec { + self.output_names_raw().into_iter().map(String::from).collect() + } + + pub fn input_len(&self) -> usize { + self.input_names_raw().len() + } + pub fn output_len(&self) -> usize { + self.output_names_raw().len() + } + + pub fn input_metadata(&self) -> Vec { + self.input_metadata_raw() + .into_iter() + .map(|x| serde_wasm_bindgen::from_value(x)) + .collect::, serde_wasm_bindgen::Error>>() + .unwrap() + } + + pub fn output_metadata(&self) -> Vec { + self.output_metadata_raw() + .into_iter() + .map(|x| serde_wasm_bindgen::from_value(x)) + .collect::, serde_wasm_bindgen::Error>>() + .unwrap() + } + + pub async fn run(&self, feeds: impl Iterator) -> Result, JsValue> { + let feeds_value = Object::new(); + for (name, tensor) in feeds { + Reflect::set(&feeds_value, &JsValue::from_str(name), &tensor.js)?; + } + Self::to_outputs(self.run_raw(feeds_value.into()).await?) + } + + pub async fn run_with_fetches( + &self, + feeds: impl Iterator, + fetches: impl Iterator)> + ) -> Result, JsValue> { + let feeds_value = Object::new(); + for (name, tensor) in feeds { + Reflect::set(&feeds_value, &JsValue::from_str(name), &tensor.js)?; + } + let fetches_value = Object::new(); + for (name, tensor) in fetches { + let null = JsValue::null(); + Reflect::set( + &fetches_value, + &JsValue::from_str(name), + match tensor { + Some(tensor) => &tensor.js, + None => &null + } + )?; + } + Self::to_outputs(self.run_with_fetches_raw(feeds_value.into(), fetches_value.into()).await?) + } + + fn to_outputs(value: JsValue) -> Result, JsValue> { + Ok(Reflect::own_keys(&value)? + .to_vec() + .into_iter() + .filter_map(|c| { + c.dyn_ref::().map(String::from).and_then(|k| { + Reflect::get(&value, &c) + .map(super::Tensor::unchecked_from_js) + .ok() + .map(|v| (k, Tensor::from_tensor(v))) + }) + }) + .collect()) + } +} + +#[derive(Deserialize, Debug)] +#[serde(untagged)] +pub enum ShapeElement { + Named(String), + Value(i32) +} + +#[derive(Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct ValueMetadata { + pub is_tensor: bool, + pub name: String, + pub shape: Option>, + pub r#type: Option +} diff --git a/backends/web/binding/tensor.rs b/backends/web/binding/tensor.rs new file mode 100644 index 00000000..936b83bf --- /dev/null +++ b/backends/web/binding/tensor.rs @@ -0,0 +1,235 @@ +use alloc::{string::ToString, vec::Vec}; + +use js_sys::JsString; +use serde::{Deserialize, Serialize}; +use wasm_bindgen::prelude::*; +use web_sys::{HtmlImageElement, ImageBitmap, ImageData, WebGlTexture}; + +use crate::binding::DataType; + +#[derive(Serialize, Debug, Clone, Copy)] +#[serde(rename_all = "UPPERCASE")] +pub enum ImageFormat { + Rgb, + Rgba, + Bgr, + Rbg +} + +#[derive(Serialize, Debug, Clone, Copy)] +#[serde(rename_all = "UPPERCASE")] +pub enum ImageTensorLayout { + Nhwc, + Nchw +} + +#[derive(Serialize, Debug, Clone, Copy)] +#[serde(rename_all = "lowercase")] +pub enum ImageDataType { + Float32, + Uint8 +} + +impl Into for ImageDataType { + fn into(self) -> DataType { + match self { + Self::Float32 => DataType::Float32, + Self::Uint8 => DataType::Uint8 + } + } +} + +#[derive(Serialize)] +#[serde(untagged)] +pub enum ImageNormOption { + Splat(f32), + PerChannel([f32; 3]), + PerChannelWithAlpha([f32; 4]) +} + +#[derive(Serialize, Default)] +#[serde(rename_all = "camelCase")] +pub struct ImageNorm { + pub bias: Option, + pub mean: Option +} + +impl ImageNorm { + pub const fn imagenet(format: ImageFormat) -> ImageNorm { + const RGB_MEAN: [f32; 3] = [0.485, 0.456, 0.406]; + const RGB_STD: [f32; 3] = [0.229, 0.224, 0.225]; + ImageNorm { + mean: Some(match format { + ImageFormat::Rgb => ImageNormOption::PerChannel(RGB_MEAN), + ImageFormat::Rgba => ImageNormOption::PerChannelWithAlpha([RGB_MEAN[0], RGB_MEAN[1], RGB_MEAN[2], 0.5]), + ImageFormat::Bgr => ImageNormOption::PerChannel([RGB_MEAN[2], RGB_MEAN[1], RGB_MEAN[0]]), + ImageFormat::Rbg => ImageNormOption::PerChannel([RGB_MEAN[0], RGB_MEAN[2], RGB_MEAN[1]]) + }), + bias: Some(match format { + ImageFormat::Rgb => ImageNormOption::PerChannel(RGB_STD), + ImageFormat::Rgba => ImageNormOption::PerChannelWithAlpha([RGB_STD[0], RGB_STD[1], RGB_STD[2], 0.5]), + ImageFormat::Bgr => ImageNormOption::PerChannel([RGB_STD[2], RGB_STD[1], RGB_STD[0]]), + ImageFormat::Rbg => ImageNormOption::PerChannel([RGB_STD[0], RGB_STD[2], RGB_STD[1]]) + }) + } + } +} + +#[derive(Serialize, Default)] +#[serde(rename_all = "camelCase")] +pub struct TensorFromImageOptions { + pub data_type: Option, + pub norm: Option, + pub resized_height: Option, + pub resized_width: Option, + pub tensor_format: Option, + pub tensor_layout: Option +} + +impl TensorFromImageOptions { + pub(crate) fn to_value(&self) -> Result { + serde_wasm_bindgen::to_value(self) + } +} + +#[derive(Serialize, Default)] +#[serde(rename_all = "camelCase")] +pub struct TensorFromUrlOptions { + #[serde(flatten)] + base: TensorFromImageOptions, + pub width: Option, + pub height: Option +} + +impl TensorFromUrlOptions { + pub(crate) fn to_value(&self) -> Result { + serde_wasm_bindgen::to_value(self) + } +} + +#[derive(Serialize)] +#[serde(transparent)] +pub struct DisposeFunction(#[serde(with = "serde_wasm_bindgen::preserve")] JsValue); + +impl From for DisposeFunction +where + T: FnOnce() + 'static +{ + fn from(value: T) -> Self { + DisposeFunction(Closure::once_into_js(value)) + } +} + +#[derive(Serialize)] +#[serde(transparent)] +pub struct DownloadFunction(#[serde(with = "serde_wasm_bindgen::preserve")] JsValue); + +impl From for DownloadFunction +where + T: FnOnce() -> F + 'static, + F: Future> + 'static, + E: core::error::Error +{ + fn from(value: T) -> Self { + DownloadFunction(Closure::once_into_js(move || { + wasm_bindgen_futures::future_to_promise(async move { + match value().await { + Ok(value) => Ok(value), + Err(e) => Err(JsString::from(e.to_string()).into()) + } + }) + })) + } +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct TensorFromTextureOptions { + pub width: u32, + pub height: u32, + pub format: Option, + pub dispose: Option, + pub download: Option +} + +#[wasm_bindgen] +#[derive(Deserialize, Serialize, Debug, Clone, Copy, PartialEq, Eq)] +#[serde(rename_all = "kebab-case")] +pub enum DataLocation { + None = "none", // indicates tensor is disposed + Cpu = "cpu", + CpuPinned = "cpu-pinned", // what is *pinned* in WASM? + Texture = "texture", + GpuBuffer = "gpu-buffer", + MlTensor = "ml-tensor" +} + +#[wasm_bindgen] +extern "C" { + #[wasm_bindgen(js_namespace = ort)] + pub type Tensor; + + #[wasm_bindgen(catch, js_namespace = ort, static_method_of = Tensor, js_name = fromImage)] + async fn from_image_data_raw(image_data: &ImageData, options: JsValue) -> Result; + #[wasm_bindgen(catch, js_namespace = ort, static_method_of = Tensor, js_name = fromImage)] + async fn from_image_element_raw(element: &HtmlImageElement, options: JsValue) -> Result; + #[wasm_bindgen(catch, js_namespace = ort, static_method_of = Tensor, js_name = fromImage)] + async fn from_image_bitmap_raw(bitmap: &ImageBitmap, options: JsValue) -> Result; + #[wasm_bindgen(catch, js_namespace = ort, static_method_of = Tensor, js_name = fromImage)] + async fn from_image_url_raw(url: &str, options: JsValue) -> Result; + #[wasm_bindgen(catch, js_namespace = ort, static_method_of = Tensor, js_name = fromTexture)] + fn from_texture(texture: &WebGlTexture, options: JsValue) -> Result; + #[cfg(web_sys_unstable_apis)] + #[wasm_bindgen(catch, js_namespace = ort, static_method_of = Tensor, js_name = fromGpuBuffer)] + fn from_gpu_buffer(buffer: &web_sys::GpuBuffer, options: JsValue) -> Result; + #[wasm_bindgen(catch, js_namespace = ort, static_method_of = Tensor, js_name = fromPinnedBuffer)] + fn from_pinned_buffer(dtype: DataType, buffer: JsValue, dims: JsValue) -> Result; + + #[wasm_bindgen(constructor, catch, js_namespace = ort, js_class = Tensor)] + fn new_from_buffer_raw(dtype: DataType, buffer: JsValue, dims: JsValue) -> Result; + + #[wasm_bindgen(structural, catch, method, getter, js_name = data)] + pub fn data(this: &Tensor) -> Result; + #[wasm_bindgen(structural, method, getter, js_name = location)] + pub fn location(this: &Tensor) -> DataLocation; + #[wasm_bindgen(structural, method, getter, js_name = type)] + pub fn dtype(this: &Tensor) -> DataType; + #[wasm_bindgen(structural, method, getter, js_name = size)] + pub fn size(this: &Tensor) -> usize; + #[wasm_bindgen(structural, method, getter, js_name = dims)] + pub fn dims(this: &Tensor) -> Vec; + + #[wasm_bindgen(structural, catch, method, js_name = getData)] + pub async fn get_data(this: &Tensor) -> Result; + + #[wasm_bindgen(structural, catch, method, js_name = dispose)] + pub fn dispose(this: &Tensor) -> Result<(), JsValue>; + #[wasm_bindgen(structural, catch, method, js_name = reshape)] + fn reshape(this: &Tensor, dims: JsValue) -> Result; +} + +impl Tensor { + pub async fn from_image_data(image_data: &ImageData, options: &TensorFromImageOptions) -> Result { + Self::from_image_data_raw(image_data, options.to_value()?).await + } + + pub async fn from_image_element(element: &HtmlImageElement, options: &TensorFromImageOptions) -> Result { + Self::from_image_element_raw(element, options.to_value()?).await + } + + pub async fn from_image_bitmap(bitmap: &ImageBitmap, options: &TensorFromImageOptions) -> Result { + Self::from_image_bitmap_raw(bitmap, options.to_value()?).await + } + + pub async fn from_image_url(url: &str, options: &TensorFromUrlOptions) -> Result { + Self::from_image_url_raw(url, options.to_value()?).await + } + + pub fn new_from_buffer(dtype: DataType, buffer: JsValue, dims: &[i32]) -> Result { + Self::new_from_buffer_raw(dtype, buffer, convert_dims(dims)) + } +} + +fn convert_dims(dims: &[i32]) -> JsValue { + dims.iter().map(|d| js_sys::Number::from(*d)).collect::().into() +} diff --git a/backends/web/env.rs b/backends/web/env.rs new file mode 100644 index 00000000..22aa7924 --- /dev/null +++ b/backends/web/env.rs @@ -0,0 +1,17 @@ +use alloc::boxed::Box; + +pub(crate) struct Environment {} + +impl Environment { + pub fn new_sys() -> *mut ort_sys::OrtEnv { + (Box::leak(Box::new(Self {})) as *mut Environment).cast() + } + + pub unsafe fn cast_from_sys<'e>(ptr: *const ort_sys::OrtEnv) -> &'e Environment { + unsafe { &*ptr.cast::() } + } + + pub unsafe fn consume_sys(ptr: *mut ort_sys::OrtEnv) -> Box { + unsafe { Box::from_raw(ptr.cast::()) } + } +} diff --git a/backends/web/lib.rs b/backends/web/lib.rs new file mode 100644 index 00000000..08dcc090 --- /dev/null +++ b/backends/web/lib.rs @@ -0,0 +1,66 @@ +#![deny(clippy::panic, clippy::panicking_unwrap)] +#![warn(clippy::std_instead_of_alloc, clippy::std_instead_of_core)] + +extern crate alloc; +extern crate core; + +use alloc::string::String; +use core::fmt; + +use wasm_bindgen::prelude::*; + +use crate::util::value_to_string; + +mod api; +mod binding; +mod env; +mod memory; +mod session; +mod tensor; +mod util; +#[macro_use] +pub(crate) mod private; + +pub mod prelude { + pub use crate::{ + session::sync_outputs, + tensor::{SyncDirection, ValueExt} + }; +} + +pub type Result = core::result::Result; + +#[derive(Debug, Clone)] +pub struct Error { + msg: String +} + +impl Error { + pub(crate) fn new(msg: impl Into) -> Self { + Self { msg: msg.into() } + } +} + +impl From for Error { + fn from(value: JsValue) -> Self { + Self::new(value_to_string(&value)) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.msg.fmt(f) + } +} + +impl core::error::Error for Error {} + +pub const FEATURE_NONE: u8 = 0; +pub const FEATURE_WEBGL: u8 = 1 << 0; +pub const FEATURE_WEBGPU: u8 = 1 << 1; +pub const FEATURE_WEBNN: u8 = FEATURE_WEBGPU; + +pub async fn api(features: u8) -> Result { + binding::init_runtime(features).await?; + Ok(self::api::api()) +} diff --git a/backends/web/memory.rs b/backends/web/memory.rs new file mode 100644 index 00000000..686efbf1 --- /dev/null +++ b/backends/web/memory.rs @@ -0,0 +1,71 @@ +use alloc::ffi::CString; +use core::{ + ffi::{CStr, c_void}, + ptr +}; + +use crate::binding; + +#[repr(C)] +pub struct Allocator { + _sys_api: ort_sys::OrtAllocator +} + +impl Allocator { + pub const fn new() -> Self { + Self { + _sys_api: ort_sys::OrtAllocator { + version: ort_sys::ORT_API_VERSION, + Alloc: Some(sys_allocator_alloc), + Free: Some(sys_allocator_free), + Info: Some(sys_allocator_info), + Reserve: Some(sys_allocator_reserve) + } + } + } +} + +pub static DEFAULT_CPU_ALLOCATOR: Allocator = Allocator::new(); + +unsafe extern "system" fn sys_allocator_alloc(_this: *mut ort_sys::OrtAllocator, _size: usize) -> *mut c_void { + ptr::null_mut() +} + +unsafe extern "system" fn sys_allocator_free(_this: *mut ort_sys::OrtAllocator, p: *mut c_void) { + drop(unsafe { CString::from_raw(p.cast()) }); +} + +unsafe extern "system" fn sys_allocator_info(this_: *const ort_sys::OrtAllocator) -> *const ort_sys::OrtMemoryInfo { + let _allocator = unsafe { &*this_.cast::() }; + ptr::dangling() +} + +unsafe extern "system" fn sys_allocator_reserve(_this: *const ort_sys::OrtAllocator, _size: usize) -> *mut c_void { + ptr::null_mut() +} + +#[derive(Clone, PartialEq, Eq)] +pub struct MemoryInfo { + pub location: binding::DataLocation +} + +impl MemoryInfo { + pub fn location_exposed(&self) -> Option<&'static CStr> { + match self.location { + binding::DataLocation::Cpu | binding::DataLocation::CpuPinned => Some(c"Cpu"), + binding::DataLocation::Texture => Some(c"WebGL"), + binding::DataLocation::GpuBuffer => Some(c"WebGPU_Buffer"), + binding::DataLocation::MlTensor => Some(c"WebNN"), + _ => None + } + } + + pub fn from_location(location: &str) -> Option { + match location { + "Cpu" => Some(Self { + location: binding::DataLocation::CpuPinned + }), + _ => None + } + } +} diff --git a/backends/web/private.rs b/backends/web/private.rs new file mode 100644 index 00000000..9e023052 --- /dev/null +++ b/backends/web/private.rs @@ -0,0 +1,17 @@ +pub struct PrivateTraitMarker; + +#[macro_export] +macro_rules! private_trait { + () => { + #[doc(hidden)] + fn _private() -> crate::private::PrivateTraitMarker; + }; +} +#[macro_export] +macro_rules! private_impl { + () => { + fn _private() -> crate::private::PrivateTraitMarker { + crate::private::PrivateTraitMarker + } + }; +} diff --git a/backends/web/session.rs b/backends/web/session.rs new file mode 100644 index 00000000..cd1af3d2 --- /dev/null +++ b/backends/web/session.rs @@ -0,0 +1,74 @@ +use js_sys::Uint8Array; +use ort::session::SessionOutputs; +use ort_sys::{OrtErrorCode, stub::Error}; + +use crate::{ + binding, + tensor::{SyncDirection, ValueExt}, + util::value_to_string +}; + +pub const SESSION_SENTINEL: [u8; 4] = [0xFC, 0x86, 0xA5, 0x01]; + +#[repr(C)] +pub struct Session { + sentinel: [u8; 4], + pub js: binding::InferenceSession, + pub disable_sync: bool +} + +impl Session { + pub async fn from_url(uri: &str, options: &SessionOptions) -> Result { + Ok(Session { + sentinel: SESSION_SENTINEL, + js: binding::InferenceSession::create_from_uri(uri, &options.js) + .await + .map_err(|e| Error::new(OrtErrorCode::ORT_FAIL, value_to_string(&e)))?, + disable_sync: options.disable_sync + }) + } + + pub async fn from_bytes(bytes: &[u8], options: &SessionOptions) -> Result { + Ok(Session { + sentinel: SESSION_SENTINEL, + js: binding::InferenceSession::create_from_bytes( + // i'm fairly confident that the bytes are copied, at least when we're not using ONNX.js + &unsafe { Uint8Array::view(bytes) }, + &options.js + ) + .await + .map_err(|e| Error::new(OrtErrorCode::ORT_FAIL, value_to_string(&e)))?, + disable_sync: options.disable_sync + }) + } +} + +pub struct RunOptions {} + +impl RunOptions { + pub const fn new() -> Self { + RunOptions {} + } +} + +pub async fn sync_outputs(outputs: &mut SessionOutputs<'_>) -> crate::Result<()> { + for (_, mut value) in outputs.iter_mut() { + value.sync(SyncDirection::Rust).await?; + } + Ok(()) +} + +#[derive(Clone)] +pub struct SessionOptions { + pub js: binding::SessionOptions, + pub disable_sync: bool +} + +impl SessionOptions { + pub fn new() -> Self { + Self { + js: binding::SessionOptions::default(), + disable_sync: false + } + } +} diff --git a/backends/web/tensor.rs b/backends/web/tensor.rs new file mode 100644 index 00000000..ae85c7c2 --- /dev/null +++ b/backends/web/tensor.rs @@ -0,0 +1,248 @@ +use alloc::{boxed::Box, vec::Vec}; +use core::{ffi::c_void, slice}; + +use js_sys::Uint8Array; +use ort::{AsPointer, value::ValueTypeMarker}; +use wasm_bindgen::{JsCast, JsValue}; + +use crate::{ + Error, + binding::{self, DataType}, + memory::MemoryInfo, + util::num_elements +}; + +pub const TENSOR_SENTINEL: [u8; 4] = [0xFC, 0x86, 0xA5, 0x39]; + +pub enum TensorData { + /// Data is stored in WASM linear memory and can be immediately accessed. + RustView { ptr: *mut c_void, byte_len: usize }, + /// Data is stored outside of WASM linear memory (i.e. session output, or a tensor created from anything other than + /// a Rust slice) and would need to be retrieved if we try to extract this tensor. + External { buffer: Option> } +} + +#[repr(C)] +pub struct Tensor { + sentinel: [u8; 4], + pub js: binding::Tensor, + pub data: TensorData, + pub memory_info: MemoryInfo +} + +impl Tensor { + pub unsafe fn from_ptr(dtype: binding::DataType, ptr: *mut c_void, byte_len: usize, dims: &[i32]) -> Result { + let tensor = binding::Tensor::new_from_buffer(dtype, unsafe { buffer_from_ptr(dtype, ptr, byte_len) }, dims)?; + Ok(Self { + sentinel: TENSOR_SENTINEL, + memory_info: MemoryInfo { location: tensor.location() }, + js: tensor, + data: TensorData::RustView { ptr, byte_len } + }) + } + + pub fn from_tensor(tensor: binding::Tensor) -> Self { + Self { + sentinel: TENSOR_SENTINEL, + memory_info: MemoryInfo { location: tensor.location() }, + js: tensor, + data: TensorData::External { buffer: None } + } + } + + pub async fn sync(&mut self, direction: SyncDirection) -> crate::Result<()> { + match direction { + SyncDirection::Rust => { + let data = self.js.get_data().await?; + + // cast to some kind of typed array first, then convert to uint8array so we can properly copy + let generic_typed_array = Uint8Array::unchecked_from_js(data); + let bytes = Uint8Array::new_with_byte_offset_and_length( + &generic_typed_array.buffer(), + generic_typed_array.byte_offset(), + generic_typed_array.byte_length() + ); + match &mut self.data { + TensorData::RustView { ptr, byte_len } => { + bytes.copy_to(unsafe { core::slice::from_raw_parts_mut(ptr.cast(), *byte_len) }); + } + TensorData::External { buffer } => { + let buffer = match buffer { + Some(buffer) => buffer, + None => { + *buffer = Some(vec![0; generic_typed_array.byte_length() as usize].into_boxed_slice()); + unsafe { buffer.as_mut().unwrap_unchecked() } + } + }; + bytes.copy_to(buffer); + } + } + } + SyncDirection::Runtime => { + let Ok(generic_typed_array) = self.js.data().map(Uint8Array::unchecked_from_js) else { + // we have a download function, but no upload... + return Err(Error::new( + "Cannot synchronize Rust data to a runtime tensor that is not on the CPU; modify the WebGPU/WebGL buffer directly." + )); + }; + let bytes = Uint8Array::new_with_byte_offset_and_length( + &generic_typed_array.buffer(), + generic_typed_array.byte_offset(), + generic_typed_array.byte_length() + ); + bytes.copy_from(match &self.data { + TensorData::RustView { ptr, byte_len } => unsafe { core::slice::from_raw_parts(ptr.cast(), *byte_len) }, + TensorData::External { buffer } => { + let Some(buffer) = buffer else { + return Ok(()); + }; + &*buffer + } + }); + } + } + Ok(()) + } +} + +pub fn create_buffer(dtype: binding::DataType, shape: &[i32]) -> JsValue { + let numel = num_elements(shape) as u32; + match dtype { + binding::DataType::Bool | binding::DataType::Uint8 => js_sys::Uint8Array::new_with_length(numel).into(), + binding::DataType::Int8 => js_sys::Int8Array::new_with_length(numel).into(), + binding::DataType::Uint16 => js_sys::Uint16Array::new_with_length(numel).into(), + binding::DataType::Int16 => js_sys::Int16Array::new_with_length(numel).into(), + binding::DataType::Uint32 => js_sys::Uint32Array::new_with_length(numel).into(), + binding::DataType::Int32 => js_sys::Int32Array::new_with_length(numel).into(), + binding::DataType::Uint64 => js_sys::BigUint64Array::new_with_length(numel).into(), + binding::DataType::Int64 => js_sys::BigInt64Array::new_with_length(numel).into(), + binding::DataType::Float32 => js_sys::Float32Array::new_with_length(numel).into(), + binding::DataType::Float64 => js_sys::Float64Array::new_with_length(numel).into(), + binding::DataType::Int4 | binding::DataType::Uint4 | binding::DataType::Float16 | binding::DataType::String => unimplemented!(), + binding::DataType::__Invalid => unreachable!() + } +} + +pub unsafe fn buffer_from_ptr(dtype: binding::DataType, ptr: *mut c_void, byte_len: usize) -> JsValue { + match dtype { + binding::DataType::Bool | binding::DataType::Uint8 => unsafe { js_sys::Uint8Array::view(slice::from_raw_parts(ptr.cast(), byte_len)) }.into(), + binding::DataType::Int8 => unsafe { js_sys::Int8Array::view(slice::from_raw_parts(ptr.cast(), byte_len)) }.into(), + binding::DataType::Uint16 => unsafe { js_sys::Uint16Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 2)) }.into(), + binding::DataType::Int16 => unsafe { js_sys::Int16Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 2)) }.into(), + binding::DataType::Uint32 => unsafe { js_sys::Uint32Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 4)) }.into(), + binding::DataType::Int32 => unsafe { js_sys::Int32Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 4)) }.into(), + binding::DataType::Uint64 => unsafe { js_sys::BigUint64Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 8)) }.into(), + binding::DataType::Int64 => unsafe { js_sys::BigInt64Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 8)) }.into(), + binding::DataType::Float32 => unsafe { js_sys::Float32Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 4)) }.into(), + binding::DataType::Float64 => unsafe { js_sys::Float64Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 8)) }.into(), + binding::DataType::Int4 | binding::DataType::Uint4 | binding::DataType::Float16 | binding::DataType::String => unimplemented!(), + binding::DataType::__Invalid => unreachable!() + } +} + +pub fn dtype_to_onnx(dtype: binding::DataType) -> ort_sys::ONNXTensorElementDataType { + match dtype { + binding::DataType::String => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, + binding::DataType::Bool => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, + binding::DataType::Uint8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, + binding::DataType::Int8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, + binding::DataType::Uint16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, + binding::DataType::Int16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, + binding::DataType::Uint32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, + binding::DataType::Int32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, + binding::DataType::Uint64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, + binding::DataType::Int64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, + binding::DataType::Float16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, + binding::DataType::Float32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + binding::DataType::Float64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, + binding::DataType::Int4 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, + binding::DataType::Uint4 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, + binding::DataType::__Invalid => unreachable!() + } +} + +pub fn onnx_to_dtype(dtype: ort_sys::ONNXTensorElementDataType) -> Option { + match dtype { + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => Some(binding::DataType::String), + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => Some(binding::DataType::Bool), + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => Some(binding::DataType::Uint8), + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => Some(binding::DataType::Int8), + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 => Some(binding::DataType::Uint16), + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 => Some(binding::DataType::Int16), + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => Some(binding::DataType::Uint32), + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 => Some(binding::DataType::Int32), + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => Some(binding::DataType::Uint64), + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => Some(binding::DataType::Int64), + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => Some(binding::DataType::Float16), + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => Some(binding::DataType::Float32), + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => Some(binding::DataType::Float64), + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 => Some(binding::DataType::Int4), + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4 => Some(binding::DataType::Uint4), + _ => None + } +} + +pub struct TypeInfo { + pub dtype: ort_sys::ONNXTensorElementDataType, + pub shape: Vec +} + +impl TypeInfo { + pub fn new_sys_from_tensor(tensor: &Tensor) -> *mut ort_sys::OrtTypeInfo { + Self::new_sys(tensor.js.dtype(), tensor.js.dims()) + } + + pub fn new_sys_from_value_metadata(metadata: &binding::ValueMetadata) -> *mut ort_sys::OrtTypeInfo { + Self::new_sys( + metadata.r#type.unwrap(), + metadata + .shape + .as_ref() + .unwrap() + .iter() + .map(|el| match el { + binding::ShapeElement::Value(v) => *v as i32, + binding::ShapeElement::Named(_) => -1 + }) + .collect() + ) + } + + pub fn new_sys(dtype: DataType, shape: Vec) -> *mut ort_sys::OrtTypeInfo { + (Box::leak(Box::new(Self { dtype: dtype_to_onnx(dtype), shape })) as *mut TypeInfo).cast() + } + + pub unsafe fn consume_sys(ptr: *mut ort_sys::OrtTypeInfo) -> Box { + unsafe { Box::from_raw(ptr.cast::()) } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SyncDirection { + Rust, + Runtime +} + +pub trait ValueExt { + crate::private_trait!(); + + #[allow(async_fn_in_trait)] + async fn sync(&mut self, direction: SyncDirection) -> crate::Result<()>; +} + +impl ValueExt for ort::value::Value { + crate::private_impl!(); + + async fn sync(&mut self, direction: SyncDirection) -> crate::Result<()> { + let ptr = self.ptr_mut(); + // definitely safe regardless of what backend is used since it's highly improbable that a backend's tensor would be + // smaller than 4 bytes (which is pointer size on wasm32) + let sentinel: [u8; 4] = unsafe { core::ptr::read(ptr.cast()) }; + if sentinel != TENSOR_SENTINEL { + return Err(Error::new("Cannot synchronize Value that was not created by ort-web")); + } + + let tensor: &mut Tensor = unsafe { &mut *ptr.cast() }; + tensor.sync(direction).await + } +} diff --git a/backends/web/util.rs b/backends/web/util.rs new file mode 100644 index 00000000..4a41bc16 --- /dev/null +++ b/backends/web/util.rs @@ -0,0 +1,18 @@ +use alloc::string::String; + +use wasm_bindgen::{JsCast, JsValue}; + +pub fn value_to_string(value: &JsValue) -> String { + js_sys::Object::unchecked_from_js_ref(value).to_string().into() +} + +pub fn num_elements(dims: &[i32]) -> usize { + let mut size = 1usize; + for dim in dims { + if *dim < 0 { + return 0; + } + size *= *dim as usize; + } + size +} From f3214298632d645525b6d223f0f79e84266f5542 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Tue, 19 Aug 2025 16:51:27 -0500 Subject: [PATCH 2/4] telemetry --- backends/web/_telemetry.js | 52 +++++++++++++++++++++++++++++++++++++ backends/web/api.rs | 24 +++++++++++++---- backends/web/binding/mod.rs | 6 +++++ backends/web/env.rs | 27 +++++++++++++++++-- 4 files changed, 102 insertions(+), 7 deletions(-) create mode 100644 backends/web/_telemetry.js diff --git a/backends/web/_telemetry.js b/backends/web/_telemetry.js new file mode 100644 index 00000000..4e36fb40 --- /dev/null +++ b/backends/web/_telemetry.js @@ -0,0 +1,52 @@ +const EVENT_URL = 'https://signal.pyke.io/beacon/9f5be487-d137-455a-9938-2fc7ecaa9de3/vVOv73JqP3iYRqXMBNm'; + +const IS_LOCALHOST = /^localhost$|^127(\.[0-9]+){0,2}\.[0-9]+$|^\[::1?\]$/; + +/** @param {Uint8Array} payload */ +function track(payload) { + if (IS_LOCALHOST.test(location.hostname) || location.protocol === 'file:') { + return false; + } + if (navigator.webdriver || 'Cypress' in window) { + return false; + } + + return navigator.sendBeacon(EVENT_URL, payload.buffer); +} + +/** @param {Uint8Array[]} chunks */ +function concat(...chunks) { + const concatenated = new Uint8Array(chunks.reduce((a, b) => a + b.byteLength, 0)); + let offset = 0; + for (const chunk of chunks) { + concatenated.set(chunk, offset); + offset += chunk.byteLength; + } + return concatenated; +} + +/** @param {number} x */ +function asUint32(x) { + const view = new DataView(new ArrayBuffer(4)); + view.setUint32(0, x, true); + return new Uint8Array(view.buffer); +} + +const encoder = new TextEncoder(); + +let hasInitializedSession = false; +export function trackSessionInit() { + if (hasInitializedSession) { + return true; + } + + hasInitializedSession = true; + + const hostname = location.hostname; + return track(concat( + new Uint8Array([ 0x01 ]), + new Uint8Array([ 0x90, 0x63, 0x8A, 0xE7 ]), + asUint32(hostname.length), + encoder.encode(hostname) + )); +} diff --git a/backends/web/api.rs b/backends/web/api.rs index b0263995..d70be75e 100644 --- a/backends/web/api.rs +++ b/backends/web/api.rs @@ -18,7 +18,7 @@ use ort_sys::{stub::Error, *}; use crate::{ binding, - env::Environment, + env::{Environment, TelemetryEvent}, memory::{Allocator, MemoryInfo}, session::{RunOptions, Session, SessionOptions}, tensor::{SyncDirection, Tensor, TensorData, TypeInfo, create_buffer, onnx_to_dtype}, @@ -41,16 +41,20 @@ unsafe extern "system" fn CreateEnvWithCustomLogger( OrtStatusPtr::default() } -unsafe extern "system" fn EnableTelemetryEvents(_env: *const OrtEnv) -> OrtStatusPtr { +unsafe extern "system" fn EnableTelemetryEvents(env: *const OrtEnv) -> OrtStatusPtr { + let env = unsafe { Environment::cast_from_sys_mut(env.cast_mut()) }; + env.with_telemetry = true; OrtStatusPtr::default() } -unsafe extern "system" fn DisableTelemetryEvents(_env: *const OrtEnv) -> OrtStatusPtr { +unsafe extern "system" fn DisableTelemetryEvents(env: *const OrtEnv) -> OrtStatusPtr { + let env = unsafe { Environment::cast_from_sys_mut(env.cast_mut()) }; + env.with_telemetry = false; OrtStatusPtr::default() } unsafe fn CreateSession( - _env: *const OrtEnv, + env: *const OrtEnv, model_path: &str, options: *const OrtSessionOptions, out: *mut *mut OrtSession @@ -63,6 +67,11 @@ unsafe fn CreateSession( let ptr = (Box::leak(Box::new(session))) as *mut Session; unsafe { out.write(ptr.cast()) }; + { + let env = unsafe { Environment::cast_from_sys(env) }; + env.send_telemetry_event(TelemetryEvent::SessionInit); + } + OrtStatusPtr::default() } Err(e) => e.into_sys() @@ -72,7 +81,7 @@ unsafe fn CreateSession( } unsafe fn CreateSessionFromArray( - _env: *const OrtEnv, + env: *const OrtEnv, model_data: &[u8], options: *const OrtSessionOptions, out: *mut *mut OrtSession @@ -85,6 +94,11 @@ unsafe fn CreateSessionFromArray( let ptr = (Box::leak(Box::new(session))) as *mut Session; unsafe { out.write(ptr.cast()) }; + { + let env = unsafe { Environment::cast_from_sys(env) }; + env.send_telemetry_event(TelemetryEvent::SessionInit); + } + OrtStatusPtr::default() } Err(e) => e.into_sys() diff --git a/backends/web/binding/mod.rs b/backends/web/binding/mod.rs index 73b645b5..45d631db 100644 --- a/backends/web/binding/mod.rs +++ b/backends/web/binding/mod.rs @@ -33,3 +33,9 @@ extern "C" { #[wasm_bindgen(catch, js_name = "initRuntime")] pub async fn init_runtime(features: u8) -> Result; } + +#[wasm_bindgen(module = "/_telemetry.js")] +extern "C" { + #[wasm_bindgen(catch, js_name = "trackSessionInit")] + pub fn track_session_init() -> Result; +} diff --git a/backends/web/env.rs b/backends/web/env.rs index 22aa7924..86d79f50 100644 --- a/backends/web/env.rs +++ b/backends/web/env.rs @@ -1,17 +1,40 @@ use alloc::boxed::Box; -pub(crate) struct Environment {} +use crate::binding; + +pub(crate) struct Environment { + pub with_telemetry: bool +} impl Environment { pub fn new_sys() -> *mut ort_sys::OrtEnv { - (Box::leak(Box::new(Self {})) as *mut Environment).cast() + (Box::leak(Box::new(Self { with_telemetry: true })) as *mut Environment).cast() } pub unsafe fn cast_from_sys<'e>(ptr: *const ort_sys::OrtEnv) -> &'e Environment { unsafe { &*ptr.cast::() } } + pub unsafe fn cast_from_sys_mut<'e>(ptr: *mut ort_sys::OrtEnv) -> &'e mut Environment { + unsafe { &mut *ptr.cast::() } + } + pub unsafe fn consume_sys(ptr: *mut ort_sys::OrtEnv) -> Box { unsafe { Box::from_raw(ptr.cast::()) } } + + pub fn send_telemetry_event(&self, event: TelemetryEvent) { + if !self.with_telemetry { + return; + } + + let _ = match event { + TelemetryEvent::SessionInit => binding::track_session_init() + }; + } +} + +#[derive(Debug)] +pub enum TelemetryEvent { + SessionInit } From c16096eca865c18a7e0f3e49a443cfb3f2d13cbf Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Fri, 26 Sep 2025 02:02:07 -0500 Subject: [PATCH 3/4] onnxruntime 1.23.0 --- backends/web/_loader.js | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/backends/web/_loader.js b/backends/web/_loader.js index 74a9f07e..7a1092d8 100644 --- a/backends/web/_loader.js +++ b/backends/web/_loader.js @@ -23,9 +23,9 @@ const DIST = { scriptName: 'ort.wasm.min.js', binaryName: 'ort-wasm-simd-threaded.wasm', integrities: { - main: 'epp8GDQUoLKx5qHa6SoDHKv7fSHILsEYk4uEsHdPRBztXEIVWCe/lhhrQJUBMZcf', - wrapper: 'LXMGGJ76ujT3yGw+OWQZVB6vBmJ7lqTO957Fh6ov3385aw3EncleBNFfYFAl3vXW', - binary: 'Eu/XUdOA62yl+TueG792KtrQlAGAMW3g10sY4G3LBYyYZUtM126Z4Gr3ljTlXUGG' + main: 'Uvpo3KshAzID7bmsY+Pz2/tiNWwl6Y5XeDTPpktDx73e0o/1TdssZDScTVHxpLYv', + wrapper: 'Y/ZaWdP4FERyRvi+anEVDVDDhMJKldzf33TRb2MiCALo054swqCUe6aM/tD8XL6g', + binary: '9UMXJFWi2zyn9PbGgXmJjEYM4hu8T8zmqmgxX6zQ08ZmNBOso3IT0cTp3M3oU7DU' } }, [FEATURES_WEBGL]: { @@ -33,9 +33,9 @@ const DIST = { scriptName: 'ort.webgl.min.js', binaryName: 'ort-wasm-simd-threaded.wasm', integrities: { - main: 'IbmlOTVtLFqdmXae30hOMw60GXx+uyALrXF1TomZTqfkz2eL2RL/Po/TzbsGe/yv', - wrapper: 'LXMGGJ76ujT3yGw+OWQZVB6vBmJ7lqTO957Fh6ov3385aw3EncleBNFfYFAl3vXW', - binary: 'Eu/XUdOA62yl+TueG792KtrQlAGAMW3g10sY4G3LBYyYZUtM126Z4Gr3ljTlXUGG' + main: 'pD9jsAlDhP5yhHaVikKM6mXw/E4HPB+4kc/rf3lrMctGWwT0XpIxiTdH/XDHR7Pr', + wrapper: 'Y/ZaWdP4FERyRvi+anEVDVDDhMJKldzf33TRb2MiCALo054swqCUe6aM/tD8XL6g', + binary: '9UMXJFWi2zyn9PbGgXmJjEYM4hu8T8zmqmgxX6zQ08ZmNBOso3IT0cTp3M3oU7DU' } }, [FEATURES_WEBGPU]: { @@ -43,9 +43,9 @@ const DIST = { scriptName: 'ort.webgpu.min.js', binaryName: 'ort-wasm-simd-threaded.jsep.wasm', integrities: { - main: 'XM2cMlQFAUJFJ3s2424PSr/v9zkRT4aXfi1cUz2SunZatAOwTR5GfTcIKLJIf3Ns', - wrapper: 'fZi+E4spXPUbkMSScLJlEGqj5QdfSJK7VQ2AZC5HLLV8lZg1j+TZT0RK6aEakeeX', - binary: 'NNN1BawwGTHI+TPz2ivQSKo1AJHr/496DqG53T9IUQ6B9ruFNrov0DNJhqucIwZ1' + main: 'rY/SpyGuo298HuKPNCTIhlm3xc022++95XwJnuGVpKaW4yEzMTTDvgXoRQdiicvj', + wrapper: 'Liv6LVoHkWBuJEPAGGmpzPGesXdc9YN5Eu0UaA9a9qChwB0H21V86UFBLhnIBieb', + binary: 'jVPVL8reOtRz4+v3ZZAWg8bO5m7HGJr7tsMxmvNae28TztYbHZIk8JXHeZ/82yST' } }, [FEATURES_ALL]: { @@ -53,9 +53,9 @@ const DIST = { scriptName: 'ort.all.min.js', binaryName: 'ort-wasm-simd-threaded.jsep.wasm', integrities: { - main: 'YWRTN6ucI4mQ8JMXfTaXD+iM7ExBj4KSHo6k6W9UIgx1tG98UgXpekjyYvRQ6akx', - wrapper: 'fZi+E4spXPUbkMSScLJlEGqj5QdfSJK7VQ2AZC5HLLV8lZg1j+TZT0RK6aEakeeX', - binary: 'NNN1BawwGTHI+TPz2ivQSKo1AJHr/496DqG53T9IUQ6B9ruFNrov0DNJhqucIwZ1' + main: 'VVNyVdgdgHOM/8agRDy7rVx66N+/9T1vkYzwYtSS/u36YVzaln3cMtxt24ozySvr', + wrapper: 'Liv6LVoHkWBuJEPAGGmpzPGesXdc9YN5Eu0UaA9a9qChwB0H21V86UFBLhnIBieb', + binary: 'jVPVL8reOtRz4+v3ZZAWg8bO5m7HGJr7tsMxmvNae28TztYbHZIk8JXHeZ/82yST' } } }; From 85617cf668463f64c3795e9c1e977c271d3a6da4 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Fri, 26 Sep 2025 02:06:29 -0500 Subject: [PATCH 4/4] update base --- backends/web/_loader.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/web/_loader.js b/backends/web/_loader.js index 7a1092d8..39629cf4 100644 --- a/backends/web/_loader.js +++ b/backends/web/_loader.js @@ -14,7 +14,7 @@ const FEATURES_ALL = FEATURES_WEBGL | FEATURES_WEBGPU; * @property {Record<'main' | 'wrapper' | 'binary', string>} integrities */ -const DIST_BASE = 'https://cdn.pyke.io/0/pyke:ort-rs/web@1.22.0/'; +const DIST_BASE = 'https://cdn.pyke.io/0/pyke:ort-rs/web@1.23.0/'; /** @type {Record} */ const DIST = {