diff --git a/Cargo.toml b/Cargo.toml index 6431ecd..5d45190 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,3 +47,11 @@ crate-type = ["cdylib"] [[example]] name = "load_permanent" crate-type = ["cdylib"] + +[[example]] +name = "sum_int" +crate-type = ["cdylib"] + +[[example]] +name = "sum_int_aux" +crate-type = ["cdylib"] diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..b03a82e --- /dev/null +++ b/Dockerfile @@ -0,0 +1,8 @@ +FROM debian:bullseye-slim + +RUN apt-get update && apt-get install -y curl valgrind build-essential clang +# Install Rust +ENV RUST_VERSION=stable +RUN curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain=$RUST_VERSION +# Install cargo-valgrind +RUN /bin/bash -c "source /root/.cargo/env && cargo install cargo-valgrind" diff --git a/README.md b/README.md index f39c746..24e4b82 100644 --- a/README.md +++ b/README.md @@ -132,6 +132,14 @@ select * from xxx; Some real-world non-Rust examples of traditional virtual tables in SQLite include the [CSV virtual table](https://www.sqlite.org/csv.html), the full-text search [fts5 extension](https://www.sqlite.org/fts5.html#fts5_table_creation_and_initialization), and the [R-Tree extension](https://www.sqlite.org/rtree.html#creating_an_r_tree_index). + +### Window / Aggregate functions + +A window function can be defined using the `define_window_function`. The step and final function must be defined. See the [`sum_int.rs`](./examples/sum_int.rs) implementation and the [sqlite's own example](https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions) for a full solution. + +There is also a [`define_window_function_with_aux`](./src/window.rs), in case a mutable auxillary object is required in place of the context aggregate pointer provided by sqlite3. In this case, the object being passed is not required to implement the Copy trait. + + ## Examples The [`examples/`](./examples/) directory has a few bare-bones examples of extensions, which you can build with: @@ -241,7 +249,7 @@ A hello world extension in C is `17KB`, while one in Rust is `469k`. It's still - [ ] Stabilize scalar function interface - [ ] Stabilize virtual table interface -- [ ] Support [aggregate window functions](https://www.sqlite.org/windowfunctions.html#udfwinfunc) ([#1](https://github.com/asg017/sqlite-loadable-rs/issues/1)) +- [x] Support [aggregate window functions](https://www.sqlite.org/windowfunctions.html#udfwinfunc) ([#1](https://github.com/asg017/sqlite-loadable-rs/issues/1)) - [ ] Support [collating sequences](https://www.sqlite.org/c3ref/create_collation.html) ([#2](https://github.com/asg017/sqlite-loadable-rs/issues/2)) - [ ] Support [virtual file systems](sqlite.org/vfs.html) ([#3](https://github.com/asg017/sqlite-loadable-rs/issues/3)) diff --git a/examples/sum_int.rs b/examples/sum_int.rs new file mode 100644 index 0000000..e8183a2 --- /dev/null +++ b/examples/sum_int.rs @@ -0,0 +1,50 @@ +//! cargo build --example sum_int +//! sqlite3 :memory: '.read examples/test.sql' + +use libsqlite3_sys::sqlite3_int64; +use sqlite_loadable::prelude::*; +use sqlite_loadable::window::define_window_function; +use sqlite_loadable::{api, Result}; + +/// Example inspired by sqlite3's sumint +/// https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions +pub fn x_step(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> { + assert!(values.len() == 1); + let new_value = api::value_int64(values.get(0).expect("should be one")); + let previous_value = api::get_aggregate_context_value::(context)?; + api::set_aggregate_context_value::(context, previous_value + new_value)?; + Ok(()) +} + + +pub fn x_final(context: *mut sqlite3_context) -> Result<()> { + let value = api::get_aggregate_context_value::(context)?; + api::result_int64(context, value); + Ok(()) +} + + +pub fn x_value(context: *mut sqlite3_context) -> Result<()> { + let value = api::get_aggregate_context_value::(context)?; + api::result_int64(context, value); + Ok(()) +} + +pub fn x_inverse(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> { + assert!(values.len() == 1); + let new_value = api::value_int64(values.get(0).expect("should be one")); + let previous_value = api::get_aggregate_context_value::(context)?; + api::set_aggregate_context_value::(context, previous_value - new_value)?; + Ok(()) +} + +#[sqlite_entrypoint] +pub fn sqlite3_sumint_init(db: *mut sqlite3) -> Result<()> { + let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC; + define_window_function( + db, "sumint", -1, flags, + x_step, x_final, Some(x_value), Some(x_inverse), + )?; + Ok(()) +} + diff --git a/examples/sum_int_aux.rs b/examples/sum_int_aux.rs new file mode 100644 index 0000000..a1b417d --- /dev/null +++ b/examples/sum_int_aux.rs @@ -0,0 +1,44 @@ +//! cargo build --example sum_int_aux + +use sqlite_loadable::prelude::*; +use sqlite_loadable::window::define_window_function_with_aux; +use sqlite_loadable::{api, Result}; + +/// Example inspired by sqlite3's sumint +/// https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions +pub fn x_step(_context: *mut sqlite3_context, values: &[*mut sqlite3_value], i: &mut i64) -> Result<()> { + assert!(values.len() == 1); + let new_value = api::value_int64(values.get(0).expect("should be one")); + *i = *i + new_value; + Ok(()) +} + + +pub fn x_final(context: *mut sqlite3_context, i: &mut i64) -> Result<()> { + api::result_int64(context, *i); + Ok(()) +} + + +pub fn x_value(context: *mut sqlite3_context, i: &mut i64) -> Result<()> { + api::result_int64(context, *i); + Ok(()) +} + +pub fn x_inverse(_context: *mut sqlite3_context, values: &[*mut sqlite3_value], i: &mut i64) -> Result<()> { + assert!(values.len() == 1); + let new_value = api::value_int64(values.get(0).expect("should be one")); + *i = *i - new_value; + Ok(()) +} + +#[sqlite_entrypoint] +pub fn sqlite3_sumintaux_init(db: *mut sqlite3) -> Result<()> { + let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC; + define_window_function_with_aux::( + db, "sumint_aux", -1, flags, + x_step, x_final, Some(x_value), Some(x_inverse), + 0, + )?; + Ok(()) +} \ No newline at end of file diff --git a/run-docker.sh b/run-docker.sh new file mode 100644 index 0000000..9b80c37 --- /dev/null +++ b/run-docker.sh @@ -0,0 +1,8 @@ +#!/bin/sh +NAME="io_uring:1.0" +docker image inspect "$NAME" || docker build -t "$NAME" . +docker run -it -v $PWD:/tmp -w /tmp $NAME + +# see https://github.com/jfrimmel/cargo-valgrind/pull/58/commits/1c168f296e0b3daa50279c642dd37aecbd85c5ff#L59 +# scan for double frees and leaks +# VALGRINDFLAGS="--leak-check=yes --trace-children=yes" cargo valgrind test diff --git a/sqlite-loadable-macros/Cargo.lock b/sqlite-loadable-macros/Cargo.lock index 08347de..7dcbdfe 100644 --- a/sqlite-loadable-macros/Cargo.lock +++ b/sqlite-loadable-macros/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "proc-macro2" -version = "1.0.46" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b" +checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" dependencies = [ "unicode-ident", ] @@ -22,7 +22,7 @@ dependencies = [ [[package]] name = "sqlite-loadable-macros" -version = "0.0.1" +version = "0.0.3" dependencies = [ "proc-macro2", "quote", diff --git a/src/api.rs b/src/api.rs index 11e444f..c6e1e0f 100644 --- a/src/api.rs +++ b/src/api.rs @@ -18,6 +18,7 @@ use crate::ext::{ }; use crate::Error; use sqlite3ext_sys::{SQLITE_BLOB, SQLITE_FLOAT, SQLITE_INTEGER, SQLITE_NULL, SQLITE_TEXT}; +use sqlite3ext_sys::sqlite3_aggregate_context; use std::os::raw::c_int; use std::slice::from_raw_parts; use std::str::Utf8Error; @@ -568,3 +569,41 @@ impl ExtendedColumnAffinity { ExtendedColumnAffinity::Numeric } } + +// TODO write test +pub fn get_aggregate_context_value(context: *mut sqlite3_context) -> Result +where + T: Copy, +{ + let p_value: *mut T = unsafe { + sqlite3_aggregate_context(context, std::mem::size_of::() as i32) as *mut T + }; + + if p_value.is_null() { + return Err("sqlite3_aggregate_context returned a null pointer.".to_string()); + } + + let value: T = unsafe { *p_value }; + + Ok(value) +} + +// TODO write test +pub fn set_aggregate_context_value(context: *mut sqlite3_context, value: T) -> Result<(), String> +where + T: Copy, +{ + let p_value: *mut T = unsafe { + sqlite3_aggregate_context(context, std::mem::size_of::() as i32) as *mut T + }; + + if p_value.is_null() { + return Err("sqlite3_aggregate_context returned a null pointer.".to_string()); + } + + unsafe { + *p_value = value; + } + + Ok(()) +} \ No newline at end of file diff --git a/src/bit_flags.rs b/src/bit_flags.rs new file mode 100644 index 0000000..4542196 --- /dev/null +++ b/src/bit_flags.rs @@ -0,0 +1,27 @@ +use bitflags::bitflags; + +use sqlite3ext_sys::{ + SQLITE_DETERMINISTIC, SQLITE_DIRECTONLY, SQLITE_INNOCUOUS, SQLITE_SUBTYPE, SQLITE_UTF16, + SQLITE_UTF16BE, SQLITE_UTF16LE, SQLITE_UTF8, +}; + +bitflags! { + /// Represents the possible flag values that can be passed into sqlite3_create_function_v2 + /// or sqlite3_create_window_function, as the 4th "eTextRep" parameter. + /// Includes both the encoding options (utf8, utf16, etc.) and function-level parameters + /// (deterministion, innocuous, etc.). + pub struct FunctionFlags: i32 { + const UTF8 = SQLITE_UTF8 as i32; + const UTF16LE = SQLITE_UTF16LE as i32; + const UTF16BE = SQLITE_UTF16BE as i32; + const UTF16 = SQLITE_UTF16 as i32; + + /// "... to signal that the function will always return the same result given the same + /// inputs within a single SQL statement." + /// + const DETERMINISTIC = SQLITE_DETERMINISTIC as i32; + const DIRECTONLY = SQLITE_DIRECTONLY as i32; + const SUBTYPE = SQLITE_SUBTYPE as i32; + const INNOCUOUS = SQLITE_INNOCUOUS as i32; + } +} diff --git a/src/errors.rs b/src/errors.rs index 31135a3..c8a7be0 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -45,6 +45,7 @@ impl Error { ErrorKind::CStringUtf8Error(_) => "utf8 err".to_owned(), ErrorKind::Message(msg) => msg, ErrorKind::TableFunction(_) => "table func error".to_owned(), + ErrorKind::DefineWindowFunction(_) => "Error defining window function".to_owned(), } } } @@ -53,6 +54,7 @@ impl Error { #[derive(Debug, PartialEq, Eq)] pub enum ErrorKind { DefineScalarFunction(c_int), + DefineWindowFunction(c_int), CStringError(NulError), CStringUtf8Error(std::str::Utf8Error), TableFunction(c_int), diff --git a/src/ext.rs b/src/ext.rs index 7326cd5..9da12aa 100644 --- a/src/ext.rs +++ b/src/ext.rs @@ -29,7 +29,7 @@ pub use sqlite3ext_sys::{ sqlite3, sqlite3_api_routines, sqlite3_context, sqlite3_index_info, sqlite3_index_info_sqlite3_index_constraint, sqlite3_index_info_sqlite3_index_constraint_usage, sqlite3_index_info_sqlite3_index_orderby, sqlite3_module, sqlite3_stmt, sqlite3_value, - sqlite3_vtab, sqlite3_vtab_cursor, + sqlite3_vtab, sqlite3_vtab_cursor, sqlite3_create_window_function }; /// If creating a dynmically loadable extension, this MUST be redefined to point @@ -413,6 +413,42 @@ pub unsafe fn sqlite3ext_get_auxdata(context: *mut sqlite3_context, n: c_int) -> ((*SQLITE3_API).get_auxdata.expect(EXPECT_MESSAGE))(context, n) } +#[cfg(feature = "static")] +pub unsafe fn sqlite3ext_create_window_function( + db: *mut sqlite3, + s: *const c_char, + argc: i32, + text_rep: i32, + p_app: *mut c_void, + x_step: Option, + x_final: Option, + x_value: Option, + x_inverse: Option, + destroy: Option +) -> c_int { + sqlite3_create_window_function( + db, s, argc, text_rep, p_app, x_step, x_final, x_value, x_inverse, destroy, + ) +} + +#[cfg(not(feature = "static"))] +pub unsafe fn sqlite3ext_create_window_function( + db: *mut sqlite3, + s: *const c_char, + argc: i32, + text_rep: i32, + p_app: *mut c_void, + x_step: Option, + x_final: Option, + x_value: Option, + x_inverse: Option, + destroy: Option +) -> c_int { + ((*SQLITE3_API).create_window_function.expect(EXPECT_MESSAGE))( + db, s, argc, text_rep, p_app, x_step, x_final, x_value, x_inverse, destroy, + ) +} + #[cfg(feature = "static")] pub unsafe fn sqlite3ext_create_function_v2( db: *mut sqlite3, diff --git a/src/lib.rs b/src/lib.rs index 3df5c80..d405ff1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,12 +14,20 @@ pub mod prelude; pub mod scalar; pub mod table; pub mod vtab_argparse; +pub mod window; +pub mod bit_flags; #[doc(inline)] pub use errors::{Error, ErrorKind, Result}; #[doc(inline)] -pub use scalar::{define_scalar_function, define_scalar_function_with_aux, FunctionFlags}; +pub use bit_flags::FunctionFlags; + +#[doc(inline)] +pub use scalar::{define_scalar_function, define_scalar_function_with_aux}; + +#[doc(inline)] +pub use window::{define_window_function,define_window_function_with_aux}; #[doc(inline)] pub use collation::define_collation; diff --git a/src/scalar.rs b/src/scalar.rs index db2982e..530c71e 100644 --- a/src/scalar.rs +++ b/src/scalar.rs @@ -12,39 +12,12 @@ use crate::{ constants::{SQLITE_INTERNAL, SQLITE_OKAY}, errors::{Error, ErrorKind, Result}, ext::{ - sqlite3, sqlite3_context, sqlite3_value, sqlite3ext_create_function_v2, + sqlite3, sqlite3_context, sqlite3_value, sqlite3ext_user_data, - }, + sqlite3ext_create_function_v2, + }, FunctionFlags, }; -use bitflags::bitflags; - -use sqlite3ext_sys::{ - SQLITE_DETERMINISTIC, SQLITE_DIRECTONLY, SQLITE_INNOCUOUS, SQLITE_SUBTYPE, SQLITE_UTF16, - SQLITE_UTF16BE, SQLITE_UTF16LE, SQLITE_UTF8, -}; - -bitflags! { - /// Represents the possible flag values that can be passed into sqlite3_create_function_v2 - /// or sqlite3_create_window_function, as the 4th "eTextRep" parameter. - /// Includes both the encoding options (utf8, utf16, etc.) and function-level parameters - /// (deterministion, innocuous, etc.). - pub struct FunctionFlags: i32 { - const UTF8 = SQLITE_UTF8 as i32; - const UTF16LE = SQLITE_UTF16LE as i32; - const UTF16BE = SQLITE_UTF16BE as i32; - const UTF16 = SQLITE_UTF16 as i32; - - /// "... to signal that the function will always return the same result given the same - /// inputs within a single SQL statement." - /// - const DETERMINISTIC = SQLITE_DETERMINISTIC as i32; - const DIRECTONLY = SQLITE_DIRECTONLY as i32; - const SUBTYPE = SQLITE_SUBTYPE as i32; - const INNOCUOUS = SQLITE_INNOCUOUS as i32; - } -} - fn create_function_v2( db: *mut sqlite3, name: &str, @@ -55,14 +28,15 @@ fn create_function_v2( x_step: Option, x_final: Option, destroy: Option, -) -> Result<()> { +) -> Result<()> +{ let cname = CString::new(name)?; let result = unsafe { sqlite3ext_create_function_v2( db, cname.as_ptr(), num_args, - func_flags.bits, + func_flags.bits(), p_app, x_func, x_step, @@ -242,6 +216,7 @@ where x_func_wrapper:: } + pub fn scalar_function_raw_with_aux( x_func: F, aux: T, diff --git a/src/window.rs b/src/window.rs new file mode 100644 index 0000000..e7c9a0d --- /dev/null +++ b/src/window.rs @@ -0,0 +1,333 @@ +//! Define window functions on sqlite3 database connections. + +#![allow(clippy::not_unsafe_ptr_arg_deref)] +use std::{ + ffi::CString, + os::raw::{c_int, c_void}, + slice, +}; + +use crate::{ + api, + constants::{SQLITE_INTERNAL, SQLITE_OKAY}, + errors::{Error, ErrorKind, Result}, + ext::sqlite3ext_create_window_function, FunctionFlags, +}; +use sqlite3ext_sys::{sqlite3, sqlite3_context, sqlite3_user_data, sqlite3_value}; + +fn create_window_function( + db: *mut sqlite3, + name: &str, + num_args: c_int, + func_flags: FunctionFlags, + p_app: *mut c_void, + x_step: unsafe extern "C" fn(*mut sqlite3_context, i32, *mut *mut sqlite3_value), + x_final: unsafe extern "C" fn(*mut sqlite3_context), + x_value: Option, + x_inverse: Option, + destroy: unsafe extern "C" fn(*mut c_void), +) -> Result<()> { + + let cname = CString::new(name)?; + let result = unsafe { + sqlite3ext_create_window_function( + db, + cname.as_ptr(), + num_args, + func_flags.bits(), + p_app, + Some(x_step), + Some(x_final), + x_value, + x_inverse, + Some(destroy), + ) + }; + + if result != SQLITE_OKAY { + Err(Error::new(ErrorKind::DefineWindowFunction(result))) + } else { + Ok(()) + } +} + +type ValueCallback = fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()>; +type ContextCallback = fn(context: *mut sqlite3_context) -> Result<()>; + +struct WindowFunctionCallbacks +{ + x_step: ValueCallback, + x_final: ContextCallback, + x_value: Option, + x_inverse: Option, +} + +impl WindowFunctionCallbacks { + fn new( + x_step: ValueCallback, + x_final: ContextCallback, + x_value: Option, + x_inverse: Option + ) -> Self { + Self { + x_step, + x_final, + x_value, + x_inverse, + } + } +} + +pub fn define_window_function( + db: *mut sqlite3, + name: &str, + num_args: c_int, + func_flags: FunctionFlags, + x_step: ValueCallback, + x_final: ContextCallback, + x_value: Option, + x_inverse: Option +) -> Result<()> +{ + let app_pointer = Box::into_raw( + Box::new( + WindowFunctionCallbacks::new(x_step, x_final, x_value, x_inverse) + ) + ); + + unsafe extern "C" fn x_step_wrapper( + context: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, + ) + { + let x = sqlite3_user_data(context).cast::(); + let args = slice::from_raw_parts(argv, argc as usize); + match ((*x).x_step)(context, args) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + + unsafe extern "C" fn x_inverse_wrapper( + context: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, + ) + { + let x = sqlite3_user_data(context).cast::(); + if let Some(x_inverse) = (*x).x_inverse { + let args = slice::from_raw_parts(argv, argc as usize); + match x_inverse(context, args) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + } + + unsafe extern "C" fn x_final_wrapper( + context: *mut sqlite3_context, + ) + { + let x = sqlite3_user_data(context).cast::(); + match ((*x).x_final)(context) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + + unsafe extern "C" fn x_value_wrapper( + context: *mut sqlite3_context, + ) + { + let x = sqlite3_user_data(context).cast::(); + if let Some(x_value) = (*x).x_value { + match x_value(context) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + } + + unsafe extern "C" fn destroy( + p_app: *mut c_void, + ) + { + let callbacks = p_app.cast::(); + let _ = Box::from_raw(callbacks); // drop + } + + create_window_function( + db, + name, + num_args, + func_flags, + app_pointer.cast::(), + x_step_wrapper, + x_final_wrapper, + Some(x_value_wrapper), + Some(x_inverse_wrapper), + destroy, + ) + +} + +// Now with aux in case the aggregate type does not implement the Copy trait +// Implementing the Copy trait implies that the underlying bytes are copyable + +type ValueCallbackWithAux = fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value], aux: &mut T) -> Result<()>; +type ContextCallbackWithAux = fn(context: *mut sqlite3_context, aux: &mut T) -> Result<()>; + +struct WindowFunctionCallbacksWithAux +{ + x_step: ValueCallbackWithAux, + x_final: ContextCallbackWithAux, + x_value: Option>, + x_inverse: Option>, + aux: T, +} + +impl WindowFunctionCallbacksWithAux { + fn new( + x_step: ValueCallbackWithAux, + x_final: ContextCallbackWithAux, + x_value: Option>, + x_inverse: Option>, + aux: T + ) -> Self { + Self { + x_step, + x_final, + x_value, + x_inverse, + aux, + } + } +} + +pub fn define_window_function_with_aux( + db: *mut sqlite3, + name: &str, + num_args: c_int, + func_flags: FunctionFlags, + x_step: ValueCallbackWithAux, + x_final: ContextCallbackWithAux, + x_value: Option>, + x_inverse: Option>, + aux: T, +) -> Result<()> +{ + let app_pointer = Box::into_raw( + Box::new( + WindowFunctionCallbacksWithAux::new(x_step, x_final, x_value, x_inverse, aux) + ) + ); + + unsafe extern "C" fn x_step_wrapper( + context: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, + ) + { + let x = sqlite3_user_data(context).cast::>(); + let args = slice::from_raw_parts(argv, argc as usize); + match ((*x).x_step)(context, args, &mut (*x).aux) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + + unsafe extern "C" fn x_inverse_wrapper( + context: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, + ) + { + let x = sqlite3_user_data(context).cast::>(); + if let Some(x_inverse) = (*x).x_inverse { + let args = slice::from_raw_parts(argv, argc as usize); + match x_inverse(context, args, &mut (*x).aux) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + } + + unsafe extern "C" fn x_final_wrapper( + context: *mut sqlite3_context, + ) + { + let x = sqlite3_user_data(context).cast::>(); + match ((*x).x_final)(context, &mut (*x).aux) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + + unsafe extern "C" fn x_value_wrapper( + context: *mut sqlite3_context, + ) + { + let x = sqlite3_user_data(context).cast::>(); + if let Some(x_value) = (*x).x_value { + match x_value(context, &mut (*x).aux) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + } + + unsafe extern "C" fn destroy( + p_app: *mut c_void, + ) + { + let callbacks = p_app.cast::(); + let _ = Box::from_raw(callbacks); // drop + } + + create_window_function( + db, + name, + num_args, + func_flags, + app_pointer.cast::(), + x_step_wrapper::, + x_final_wrapper::, + Some(x_value_wrapper::), + Some(x_inverse_wrapper::), + destroy, + ) + +} \ No newline at end of file diff --git a/tests/test_sum_int.rs b/tests/test_sum_int.rs new file mode 100644 index 0000000..073eb53 --- /dev/null +++ b/tests/test_sum_int.rs @@ -0,0 +1,76 @@ +use libsqlite3_sys::sqlite3_int64; +use sqlite_loadable::prelude::*; +use sqlite_loadable::window::define_window_function; +use sqlite_loadable::{api, Result}; + +/// Example inspired by sqlite3's sumint +/// https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions +pub fn x_step(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> { + assert!(values.len() == 1); + let new_value = api::value_int64(values.get(0).expect("should be one")); + let previous_value = api::get_aggregate_context_value::(context)?; + api::set_aggregate_context_value::(context, previous_value + new_value)?; + Ok(()) +} + + +pub fn x_final(context: *mut sqlite3_context) -> Result<()> { + let value = api::get_aggregate_context_value::(context)?; + api::result_int64(context, value); + Ok(()) +} + + +pub fn x_value(context: *mut sqlite3_context) -> Result<()> { + let value = api::get_aggregate_context_value::(context)?; + api::result_int64(context, value); + Ok(()) +} + +pub fn x_inverse(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> { + assert!(values.len() == 1); + let new_value = api::value_int64(values.get(0).expect("should be one")); + let previous_value = api::get_aggregate_context_value::(context)?; + api::set_aggregate_context_value::(context, previous_value - new_value)?; + Ok(()) +} + +#[sqlite_entrypoint] +pub fn sqlite3_sumint_init(db: *mut sqlite3) -> Result<()> { + let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC; + define_window_function( + db, "sumint", -1, flags, + x_step, x_final, Some(x_value), Some(x_inverse), + )?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + use rusqlite::{ffi::sqlite3_auto_extension, Connection}; + + #[test] + fn test_rusqlite_auto_extension() { + unsafe { + sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_sumint_init as *const ()))); + } + + let conn = Connection::open_in_memory().unwrap(); + + let _ = conn + .execute("CREATE TABLE t3(x TEXT, y INTEGER)", ()); + + let _ = conn + .execute("INSERT INTO t3 VALUES ('a', 4), ('b', 5), ('c', 3), ('d', 8), ('e', 1)", ()); + + let result: sqlite3_int64 = conn.query_row("SELECT x, sumint(y) OVER ( + ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING + ) AS sum_y + FROM t3 ORDER BY x", (), |x| x.get(1)).unwrap(); + + + assert_eq!(result, 9); + } +} diff --git a/tests/test_sum_int_aux.rs b/tests/test_sum_int_aux.rs new file mode 100644 index 0000000..b9ec23b --- /dev/null +++ b/tests/test_sum_int_aux.rs @@ -0,0 +1,73 @@ +use libsqlite3_sys::sqlite3_int64; +use sqlite_loadable::prelude::*; +use sqlite_loadable::window::define_window_function_with_aux; +use sqlite_loadable::{api, Result}; + +/// Example inspired by sqlite3's sumint +/// https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions +pub fn x_step(_context: *mut sqlite3_context, values: &[*mut sqlite3_value], i: &mut i64) -> Result<()> { + assert!(values.len() == 1); + let new_value = api::value_int64(values.get(0).expect("should be one")); + *i = *i + new_value; + Ok(()) +} + + +pub fn x_final(context: *mut sqlite3_context, i: &mut i64) -> Result<()> { + api::result_int64(context, *i); + Ok(()) +} + + +pub fn x_value(context: *mut sqlite3_context, i: &mut i64) -> Result<()> { + api::result_int64(context, *i); + Ok(()) +} + +pub fn x_inverse(_context: *mut sqlite3_context, values: &[*mut sqlite3_value], i: &mut i64) -> Result<()> { + assert!(values.len() == 1); + let new_value = api::value_int64(values.get(0).expect("should be one")); + *i = *i - new_value; + Ok(()) +} + +#[sqlite_entrypoint] +pub fn sqlite3_sumintaux_init(db: *mut sqlite3) -> Result<()> { + let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC; + define_window_function_with_aux::( + db, "sumint_aux", -1, flags, + x_step, x_final, Some(x_value), Some(x_inverse), + 0, + )?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + use rusqlite::{ffi::sqlite3_auto_extension, Connection}; + + #[test] + fn test_rusqlite_auto_extension() { + unsafe { + sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_sumintaux_init as *const ()))); + } + + let conn = Connection::open_in_memory().unwrap(); + + let _ = conn + .execute("CREATE TABLE t3(x TEXT, y INTEGER)", ()); + + let _ = conn + .execute("INSERT INTO t3 VALUES ('a', 4), ('b', 5), ('c', 3), ('d', 8), ('e', 1)", ()); + + let result: sqlite3_int64 = conn.query_row("SELECT x, sumint_aux(y) OVER ( + ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING + ) AS sum_y + FROM t3 ORDER BY x", (), |x| x.get(1)).unwrap(); + + + assert_eq!(result, 9); + } +}