From f0eca511a4bad77876affbe7f043e8d2c32c733d Mon Sep 17 00:00:00 2001 From: Jasm Sison Date: Wed, 30 Aug 2023 02:15:07 +0200 Subject: [PATCH 01/13] two issues with sum_int: 1. double free issue in test_sum_int.rs 2. can't load the plugin: Error: dlsym(0x808c4010, sqlite3_sumint_init): symbol not found --- Cargo.toml | 4 + examples/sum_int.rs | 48 +++++ sqlite-loadable-macros/Cargo.lock | 6 +- src/api.rs | 40 +++- src/bit_flags.rs | 27 +++ src/errors.rs | 2 + src/ext.rs | 25 ++- src/lib.rs | 10 +- src/scalar.rs | 36 +--- src/window.rs | 321 ++++++++++++++++++++++++++++++ tests/test_sum_int.rs | 76 +++++++ 11 files changed, 558 insertions(+), 37 deletions(-) create mode 100644 examples/sum_int.rs create mode 100644 src/bit_flags.rs create mode 100644 src/window.rs create mode 100644 tests/test_sum_int.rs diff --git a/Cargo.toml b/Cargo.toml index 5cc1cdf..280e483 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,3 +42,7 @@ crate-type = ["cdylib"] [[example]] name = "load_permanent" crate-type = ["cdylib"] + +[[example]] +name = "sum_int" +crate-type = ["cdylib"] diff --git a/examples/sum_int.rs b/examples/sum_int.rs new file mode 100644 index 0000000..fc6db48 --- /dev/null +++ b/examples/sum_int.rs @@ -0,0 +1,48 @@ +//! cargo build --example sum_int +//! sqlite3 :memory: '.read examples/test.sql' + +use libsqlite3_sys::sqlite3_int64; +use sqlite_loadable::prelude::*; +use sqlite_loadable::window::{WindowFunctionCallbacks, define_window_function}; +use sqlite_loadable::{api, Result}; + +/// Example inspired by sqlite3's sumint +/// https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions +pub fn x_step(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> { + assert!(values.len() == 1); + let new_value = api::value_int64(values.get(0).expect("should be one")); + let previous_value = api::get_aggregate_context_value::(context)?; + api::set_aggregate_context_value::(context, previous_value + new_value)?; + Ok(()) +} + + +pub fn x_final(context: *mut sqlite3_context) -> Result<()> { + let value = api::get_aggregate_context_value::(context)?; + api::result_int64(context, value); + Ok(()) +} + + +pub fn x_value(context: *mut sqlite3_context) -> Result<()> { + let value = api::get_aggregate_context_value::(context)?; + api::result_int64(context, value); + Ok(()) +} + +pub fn x_inverse(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> { + assert!(values.len() == 1); + let new_value = api::value_int64(values.get(0).expect("should be one")); + let previous_value = api::get_aggregate_context_value::(context)?; + api::set_aggregate_context_value::(context, previous_value - new_value)?; + Ok(()) +} + +#[sqlite_entrypoint] +pub fn sqlite3_sum_int_init(db: *mut sqlite3) -> Result<()> { + let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC; + define_window_function(db, "sum_int", -1, flags, + WindowFunctionCallbacks::new(x_step, x_final, x_value, x_inverse))?; + Ok(()) +} + diff --git a/sqlite-loadable-macros/Cargo.lock b/sqlite-loadable-macros/Cargo.lock index 08347de..7dcbdfe 100644 --- a/sqlite-loadable-macros/Cargo.lock +++ b/sqlite-loadable-macros/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "proc-macro2" -version = "1.0.46" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b" +checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" dependencies = [ "unicode-ident", ] @@ -22,7 +22,7 @@ dependencies = [ [[package]] name = "sqlite-loadable-macros" -version = "0.0.1" +version = "0.0.3" dependencies = [ "proc-macro2", "quote", diff --git a/src/api.rs b/src/api.rs index 64a2e87..8b6ceb4 100644 --- a/src/api.rs +++ b/src/api.rs @@ -19,7 +19,7 @@ use crate::ext::{ use crate::Error; use sqlite3ext_sys::{ sqlite3, sqlite3_context, sqlite3_mprintf, sqlite3_value, SQLITE_BLOB, SQLITE_FLOAT, - SQLITE_INTEGER, SQLITE_NULL, SQLITE_TEXT, + SQLITE_INTEGER, SQLITE_NULL, SQLITE_TEXT, sqlite3_aggregate_context, }; use std::os::raw::c_int; use std::slice::from_raw_parts; @@ -571,3 +571,41 @@ impl ExtendedColumnAffinity { ExtendedColumnAffinity::Numeric } } + +// TODO write test +pub fn get_aggregate_context_value(context: *mut sqlite3_context) -> Result +where + T: Copy, +{ + let p_value: *mut T = unsafe { + sqlite3_aggregate_context(context, std::mem::size_of::() as i32) as *mut T + }; + + if p_value.is_null() { + return Err("sqlite3_aggregate_context returned a null pointer.".to_string()); + } + + let value: T = unsafe { *p_value }; + + Ok(value) +} + +// TODO write test +pub fn set_aggregate_context_value(context: *mut sqlite3_context, value: T) -> Result<(), String> +where + T: Copy, +{ + let p_value: *mut T = unsafe { + sqlite3_aggregate_context(context, std::mem::size_of::() as i32) as *mut T + }; + + if p_value.is_null() { + return Err("sqlite3_aggregate_context returned a null pointer.".to_string()); + } + + unsafe { + *p_value = value; + } + + Ok(()) +} \ No newline at end of file diff --git a/src/bit_flags.rs b/src/bit_flags.rs new file mode 100644 index 0000000..4542196 --- /dev/null +++ b/src/bit_flags.rs @@ -0,0 +1,27 @@ +use bitflags::bitflags; + +use sqlite3ext_sys::{ + SQLITE_DETERMINISTIC, SQLITE_DIRECTONLY, SQLITE_INNOCUOUS, SQLITE_SUBTYPE, SQLITE_UTF16, + SQLITE_UTF16BE, SQLITE_UTF16LE, SQLITE_UTF8, +}; + +bitflags! { + /// Represents the possible flag values that can be passed into sqlite3_create_function_v2 + /// or sqlite3_create_window_function, as the 4th "eTextRep" parameter. + /// Includes both the encoding options (utf8, utf16, etc.) and function-level parameters + /// (deterministion, innocuous, etc.). + pub struct FunctionFlags: i32 { + const UTF8 = SQLITE_UTF8 as i32; + const UTF16LE = SQLITE_UTF16LE as i32; + const UTF16BE = SQLITE_UTF16BE as i32; + const UTF16 = SQLITE_UTF16 as i32; + + /// "... to signal that the function will always return the same result given the same + /// inputs within a single SQL statement." + /// + const DETERMINISTIC = SQLITE_DETERMINISTIC as i32; + const DIRECTONLY = SQLITE_DIRECTONLY as i32; + const SUBTYPE = SQLITE_SUBTYPE as i32; + const INNOCUOUS = SQLITE_INNOCUOUS as i32; + } +} diff --git a/src/errors.rs b/src/errors.rs index 31135a3..c8a7be0 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -45,6 +45,7 @@ impl Error { ErrorKind::CStringUtf8Error(_) => "utf8 err".to_owned(), ErrorKind::Message(msg) => msg, ErrorKind::TableFunction(_) => "table func error".to_owned(), + ErrorKind::DefineWindowFunction(_) => "Error defining window function".to_owned(), } } } @@ -53,6 +54,7 @@ impl Error { #[derive(Debug, PartialEq, Eq)] pub enum ErrorKind { DefineScalarFunction(c_int), + DefineWindowFunction(c_int), CStringError(NulError), CStringUtf8Error(std::str::Utf8Error), TableFunction(c_int), diff --git a/src/ext.rs b/src/ext.rs index 86458ed..f644a60 100644 --- a/src/ext.rs +++ b/src/ext.rs @@ -26,7 +26,7 @@ use sqlite3ext_sys::{ sqlite3_result_pointer, sqlite3_result_text, sqlite3_set_auxdata, sqlite3_step, sqlite3_stmt, sqlite3_value, sqlite3_value_blob, sqlite3_value_bytes, sqlite3_value_double, sqlite3_value_int, sqlite3_value_int64, sqlite3_value_pointer, sqlite3_value_subtype, - sqlite3_value_text, sqlite3_value_type, + sqlite3_value_text, sqlite3_value_type, sqlite3_create_window_function, }; /// If creating a dynmically loadable extension, this MUST be redefined to point @@ -316,6 +316,29 @@ pub unsafe fn sqlite3ext_get_auxdata(context: *mut sqlite3_context, n: c_int) -> ((*SQLITE3_API).get_auxdata.expect(EXPECT_MESSAGE))(context, n) } +pub unsafe fn sqlite3ext_create_window_function( + db: *mut sqlite3, + s: *const c_char, + argc: i32, + text_rep: i32, + p_app: *mut c_void, + x_step: Option, + x_final: Option, + x_value: Option, + x_inverse: Option, + destroy: Option +) -> c_int { + if SQLITE3_API.is_null() { + sqlite3_create_window_function( + db, s, argc, text_rep, p_app, x_step, x_final, x_value, x_inverse, destroy, + ) + } else { + ((*SQLITE3_API).create_window_function.expect(EXPECT_MESSAGE))( + db, s, argc, text_rep, p_app, x_step, x_final, x_value, x_inverse, destroy, + ) + } +} + pub unsafe fn sqlite3ext_create_function_v2( db: *mut sqlite3, s: *const c_char, diff --git a/src/lib.rs b/src/lib.rs index d1ce9ba..19517d0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,12 +14,20 @@ pub mod prelude; pub mod scalar; pub mod table; pub mod vtab_argparse; +pub mod window; +pub mod bit_flags; #[doc(inline)] pub use errors::{Error, ErrorKind, Result}; #[doc(inline)] -pub use scalar::{define_scalar_function, define_scalar_function_with_aux, FunctionFlags}; +pub use bit_flags::FunctionFlags; + +#[doc(inline)] +pub use scalar::{define_scalar_function, define_scalar_function_with_aux}; + +#[doc(inline)] +pub use window::{WindowFunctionCallbacksWithAux, define_window_function_with_aux}; #[doc(inline)] pub use collation::define_collation; diff --git a/src/scalar.rs b/src/scalar.rs index 9d3281d..771ae9d 100644 --- a/src/scalar.rs +++ b/src/scalar.rs @@ -11,38 +11,10 @@ use crate::{ api, constants::{SQLITE_INTERNAL, SQLITE_OKAY}, errors::{Error, ErrorKind, Result}, - ext::sqlite3ext_create_function_v2, + ext::sqlite3ext_create_function_v2, FunctionFlags, }; use sqlite3ext_sys::{sqlite3, sqlite3_context, sqlite3_user_data, sqlite3_value}; -use bitflags::bitflags; - -use sqlite3ext_sys::{ - SQLITE_DETERMINISTIC, SQLITE_DIRECTONLY, SQLITE_INNOCUOUS, SQLITE_SUBTYPE, SQLITE_UTF16, - SQLITE_UTF16BE, SQLITE_UTF16LE, SQLITE_UTF8, -}; - -bitflags! { - /// Represents the possible flag values that can be passed into sqlite3_create_function_v2 - /// or sqlite3_create_window_function, as the 4th "eTextRep" parameter. - /// Includes both the encoding options (utf8, utf16, etc.) and function-level parameters - /// (deterministion, innocuous, etc.). - pub struct FunctionFlags: i32 { - const UTF8 = SQLITE_UTF8 as i32; - const UTF16LE = SQLITE_UTF16LE as i32; - const UTF16BE = SQLITE_UTF16BE as i32; - const UTF16 = SQLITE_UTF16 as i32; - - /// "... to signal that the function will always return the same result given the same - /// inputs within a single SQL statement." - /// - const DETERMINISTIC = SQLITE_DETERMINISTIC as i32; - const DIRECTONLY = SQLITE_DIRECTONLY as i32; - const SUBTYPE = SQLITE_SUBTYPE as i32; - const INNOCUOUS = SQLITE_INNOCUOUS as i32; - } -} - fn create_function_v2( db: *mut sqlite3, name: &str, @@ -53,14 +25,15 @@ fn create_function_v2( x_step: Option, x_final: Option, destroy: Option, -) -> Result<()> { +) -> Result<()> +{ let cname = CString::new(name)?; let result = unsafe { sqlite3ext_create_function_v2( db, cname.as_ptr(), num_args, - func_flags.bits, + func_flags.bits(), p_app, x_func, x_step, @@ -240,3 +213,4 @@ where x_func_wrapper:: } + diff --git a/src/window.rs b/src/window.rs new file mode 100644 index 0000000..c40a371 --- /dev/null +++ b/src/window.rs @@ -0,0 +1,321 @@ +//! Define window functions on sqlite3 database connections. + +#![allow(clippy::not_unsafe_ptr_arg_deref)] +use std::{ + ffi::CString, + os::raw::{c_int, c_void}, + slice, +}; + +use crate::{ + api, + constants::{SQLITE_INTERNAL, SQLITE_OKAY}, + errors::{Error, ErrorKind, Result}, + ext::sqlite3ext_create_window_function, FunctionFlags, +}; +use sqlite3ext_sys::{sqlite3, sqlite3_context, sqlite3_user_data, sqlite3_value}; + +// TODO typedef repeating parameter types, across multiple files + +fn create_window_function( + db: *mut sqlite3, + name: &str, + num_args: c_int, + func_flags: FunctionFlags, + p_app: *mut c_void, + x_step: Option, + x_final: Option, + x_value: Option, + x_inverse: Option, + destroy: Option, +) -> Result<()> { + + let cname = CString::new(name)?; + let result = unsafe { + sqlite3ext_create_window_function( + db, + cname.as_ptr(), + num_args, + func_flags.bits(), + p_app, + x_step, + x_final, + x_value, + x_inverse, + destroy, + ) + }; + + if result != SQLITE_OKAY { + Err(Error::new(ErrorKind::DefineWindowFunction(result))) + } else { + Ok(()) + } +} + +pub struct WindowFunctionCallbacks +{ + x_step: fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()>, + x_final: fn(context: *mut sqlite3_context) -> Result<()>, + x_value: fn(context: *mut sqlite3_context) -> Result<()>, + x_inverse: fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()>, +} + +impl WindowFunctionCallbacks { + pub fn new( + x_step: fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()>, + x_final: fn(context: *mut sqlite3_context) -> Result<()>, + x_value: fn(context: *mut sqlite3_context) -> Result<()>, + x_inverse: fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> + ) -> Self { + WindowFunctionCallbacks { + x_step, + x_final, + x_value, + x_inverse, + } + } +} + +pub struct WindowFunctionCallbacksWithAux +{ + x_step: fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value], aux: &T) -> Result<()>, + x_final: fn(context: *mut sqlite3_context, aux: &T) -> Result<()>, + x_value: fn(context: *mut sqlite3_context, aux: &T) -> Result<()>, + x_inverse: fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value], aux: &T) -> Result<()>, +} + +impl WindowFunctionCallbacksWithAux { + pub fn new( + x_step: fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value], aux: &T) -> Result<()>, + x_final: fn(context: *mut sqlite3_context, aux: &T) -> Result<()>, + x_value: fn(context: *mut sqlite3_context, aux: &T) -> Result<()>, + x_inverse: fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value], aux: &T) -> Result<()> + ) -> Self { + WindowFunctionCallbacksWithAux { + x_step, + x_final, + x_value, + x_inverse, + } + } +} + +// TODO add documentation +// TODO add new test with aux object +// TODO parentheses matching +/// The aux parameter can be used to pass another context object altogether +pub fn define_window_function_with_aux( + db: *mut sqlite3, + name: &str, + num_args: c_int, + func_flags: FunctionFlags, + callbacks: WindowFunctionCallbacksWithAux, + aux: T, +) -> Result<()> +{ + let callbacks_pointer = Box::into_raw(Box::new(callbacks)); + let aux_pointer: *mut T = Box::into_raw(Box::new(aux)); + let app_pointer = Box::into_raw(Box::new((callbacks_pointer, aux_pointer))); + + unsafe extern "C" fn x_step_wrapper( + context: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, + ) + { + let x = sqlite3_user_data(context).cast::<(*mut WindowFunctionCallbacksWithAux, *mut T)>(); + let boxed_function = Box::from_raw((*x).0).as_ref().x_step; + let aux = (*x).1; + // .collect slows things waaaay down, so stick with slice for now + let args = slice::from_raw_parts(argv, argc as usize); + let b = Box::from_raw(aux); + match boxed_function(context, args, &*b) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + Box::into_raw(b); + } + + unsafe extern "C" fn x_inverse_wrapper( + context: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, + ) + { + let x = sqlite3_user_data(context).cast::<(*mut WindowFunctionCallbacksWithAux, *mut T)>(); + let boxed_function = Box::from_raw((*x).0).as_ref().x_inverse; + let aux = (*x).1; + // .collect slows things waaaay down, so stick with slice for now + let args = slice::from_raw_parts(argv, argc as usize); + let b = Box::from_raw(aux); + match boxed_function(context, args, &*b) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + Box::into_raw(b); + } + + unsafe extern "C" fn x_final_wrapper( + context: *mut sqlite3_context, + ) + { + let x = sqlite3_user_data(context).cast::<(*mut WindowFunctionCallbacksWithAux, *mut T)>(); + let boxed_function = Box::from_raw((*x).0).as_ref().x_final; + let aux = (*x).1; + let b = Box::from_raw(aux); + match boxed_function(context, &*b) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + + unsafe extern "C" fn x_value_wrapper( + context: *mut sqlite3_context, + ) + { + let x = sqlite3_user_data(context).cast::<(*mut WindowFunctionCallbacksWithAux, *mut T)>(); + let boxed_function = Box::from_raw((*x).0).as_ref().x_value; + let aux = (*x).1; + let b = Box::from_raw(aux); + match boxed_function(context, &*b) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + + create_window_function( + db, + name, + num_args, + func_flags, + // app_pointer, + app_pointer.cast::(), + Some(x_step_wrapper::), + Some(x_final_wrapper::), + Some(x_value_wrapper::), + Some(x_inverse_wrapper::), + None, // Note: release resources in x_final if necessary + ) + + +} + +// TODO add documentation +// TODO parentheses matching +/// The aux parameter can be used to pass another context object altogether +pub fn define_window_function( + db: *mut sqlite3, + name: &str, + num_args: c_int, + func_flags: FunctionFlags, + callbacks: WindowFunctionCallbacks, +) -> Result<()> +{ + let callbacks_pointer = Box::into_raw(Box::new(callbacks)); + let app_pointer = Box::into_raw(Box::new(callbacks_pointer)); + + unsafe extern "C" fn x_step_wrapper( + context: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, + ) + { + let x = sqlite3_user_data(context).cast::<*mut WindowFunctionCallbacks>(); + let boxed_function = Box::from_raw(*x).as_ref().x_step; + // .collect slows things waaaay down, so stick with slice for now + let args = slice::from_raw_parts(argv, argc as usize); + match boxed_function(context, args) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + + unsafe extern "C" fn x_inverse_wrapper( + context: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, + ) + { + let x = sqlite3_user_data(context).cast::<*mut WindowFunctionCallbacks>(); + let boxed_function = Box::from_raw(*x).as_ref().x_inverse; + // .collect slows things waaaay down, so stick with slice for now + let args = slice::from_raw_parts(argv, argc as usize); + match boxed_function(context, args) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + + unsafe extern "C" fn x_final_wrapper( + context: *mut sqlite3_context, + ) + { + let x = sqlite3_user_data(context).cast::<*mut WindowFunctionCallbacks>(); + let boxed_function = Box::from_raw(*x).as_ref().x_final; + match boxed_function(context) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + + unsafe extern "C" fn x_value_wrapper( + context: *mut sqlite3_context, + ) + { + let x = sqlite3_user_data(context).cast::<*mut WindowFunctionCallbacks>(); + let boxed_function = Box::from_raw(*x).as_ref().x_value; + match boxed_function(context) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + + create_window_function( + db, + name, + num_args, + func_flags, + // app_pointer, + app_pointer.cast::(), + Some(x_step_wrapper), + Some(x_final_wrapper), + Some(x_value_wrapper), + Some(x_inverse_wrapper), + None, // Note: release resources in x_final if necessary + ) + + +} \ No newline at end of file diff --git a/tests/test_sum_int.rs b/tests/test_sum_int.rs new file mode 100644 index 0000000..21a3d50 --- /dev/null +++ b/tests/test_sum_int.rs @@ -0,0 +1,76 @@ +use libsqlite3_sys::sqlite3_int64; +use sqlite_loadable::prelude::*; +use sqlite_loadable::window::{WindowFunctionCallbacks, define_window_function}; +use sqlite_loadable::{api, Result}; + +/// Example inspired by sqlite3's sumint +/// https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions +pub fn x_step(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> { + assert!(values.len() == 1); + let new_value = api::value_int64(values.get(0).expect("should be one")); + let previous_value = api::get_aggregate_context_value::(context)?; + api::set_aggregate_context_value::(context, previous_value + new_value)?; + Ok(()) +} + + +pub fn x_final(context: *mut sqlite3_context) -> Result<()> { + let value = api::get_aggregate_context_value::(context)?; + api::result_int64(context, value); + Ok(()) +} + + +pub fn x_value(context: *mut sqlite3_context) -> Result<()> { + let value = api::get_aggregate_context_value::(context)?; + api::result_int64(context, value); + Ok(()) +} + +pub fn x_inverse(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> { + assert!(values.len() == 1); + let new_value = api::value_int64(values.get(0).expect("should be one")); + let previous_value = api::get_aggregate_context_value::(context)?; + api::set_aggregate_context_value::(context, previous_value - new_value)?; + Ok(()) +} + +#[sqlite_entrypoint] +pub fn sqlite3_sum_int_init(db: *mut sqlite3) -> Result<()> { + let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC; + define_window_function(db, "sum_int", -1, flags, + WindowFunctionCallbacks::new(x_step, x_final, x_value, x_inverse))?; + Ok(()) +} + +// include!("../examples/sum_int.rs"); + +#[cfg(test)] +mod tests { + use super::*; + + use rusqlite::{ffi::sqlite3_auto_extension, Connection}; + + #[test] + fn test_rusqlite_auto_extension() { + unsafe { + sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_sum_int_init as *const ()))); + } + + let conn = Connection::open_in_memory().unwrap(); + + let _ = conn + .execute("CREATE TABLE t3(x TEXT, y INTEGER)", ()); + + let _ = conn + .execute("INSERT INTO t3 VALUES ('a', 4), ('b', 5), ('c', 3), ('d', 8), ('e', 1)", ()); + + let result: sqlite3_int64 = conn.query_row("SELECT x, sum_int(y) OVER ( + ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING + ) AS sum_y + FROM t3 ORDER BY x", (), |x| x.get(1)).unwrap(); + + + assert_eq!(result, 9); + } +} From c06acf9adbb0b4a84c28304907079e13096fe628 Mon Sep 17 00:00:00 2001 From: Jasm Sison Date: Wed, 30 Aug 2023 16:44:48 +0200 Subject: [PATCH 02/13] 1. Resources managed by sqlite should be left alone with Box::into 2. sqlite3, has a naming convention: sqlite3_xxx_init, where xxx cannot contain underscores although the function itself may contain an underscore in sqlite3 itself --- examples/sum_int.rs | 4 ++-- src/window.rs | 4 +++- tests/test_sum_int.rs | 8 ++++---- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/sum_int.rs b/examples/sum_int.rs index fc6db48..fab48ab 100644 --- a/examples/sum_int.rs +++ b/examples/sum_int.rs @@ -39,9 +39,9 @@ pub fn x_inverse(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) - } #[sqlite_entrypoint] -pub fn sqlite3_sum_int_init(db: *mut sqlite3) -> Result<()> { +pub fn sqlite3_sumint_init(db: *mut sqlite3) -> Result<()> { let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC; - define_window_function(db, "sum_int", -1, flags, + define_window_function(db, "sumint", -1, flags, WindowFunctionCallbacks::new(x_step, x_final, x_value, x_inverse))?; Ok(()) } diff --git a/src/window.rs b/src/window.rs index c40a371..7e42f17 100644 --- a/src/window.rs +++ b/src/window.rs @@ -129,7 +129,7 @@ pub fn define_window_function_with_aux( let aux = (*x).1; // .collect slows things waaaay down, so stick with slice for now let args = slice::from_raw_parts(argv, argc as usize); - let b = Box::from_raw(aux); + let b: Box = Box::from_raw(aux); match boxed_function(context, args, &*b) { Ok(()) => (), Err(e) => { @@ -180,6 +180,7 @@ pub fn define_window_function_with_aux( } } } + Box::into_raw(b); } unsafe extern "C" fn x_value_wrapper( @@ -198,6 +199,7 @@ pub fn define_window_function_with_aux( } } } + Box::into_raw(b); } create_window_function( diff --git a/tests/test_sum_int.rs b/tests/test_sum_int.rs index 21a3d50..1b06894 100644 --- a/tests/test_sum_int.rs +++ b/tests/test_sum_int.rs @@ -36,9 +36,9 @@ pub fn x_inverse(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) - } #[sqlite_entrypoint] -pub fn sqlite3_sum_int_init(db: *mut sqlite3) -> Result<()> { +pub fn sqlite3_sumint_init(db: *mut sqlite3) -> Result<()> { let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC; - define_window_function(db, "sum_int", -1, flags, + define_window_function(db, "sumint", -1, flags, WindowFunctionCallbacks::new(x_step, x_final, x_value, x_inverse))?; Ok(()) } @@ -54,7 +54,7 @@ mod tests { #[test] fn test_rusqlite_auto_extension() { unsafe { - sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_sum_int_init as *const ()))); + sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_sumint_init as *const ()))); } let conn = Connection::open_in_memory().unwrap(); @@ -65,7 +65,7 @@ mod tests { let _ = conn .execute("INSERT INTO t3 VALUES ('a', 4), ('b', 5), ('c', 3), ('d', 8), ('e', 1)", ()); - let result: sqlite3_int64 = conn.query_row("SELECT x, sum_int(y) OVER ( + let result: sqlite3_int64 = conn.query_row("SELECT x, sumint(y) OVER ( ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING ) AS sum_y FROM t3 ORDER BY x", (), |x| x.get(1)).unwrap(); From fa5b5e0b1f8cea2756fa404bfceb8bb332dd14b5 Mon Sep 17 00:00:00 2001 From: Jasm Sison Date: Mon, 4 Sep 2023 15:34:04 +0200 Subject: [PATCH 03/13] correct mistake when casting, pointer to pointer was unnecessary --- examples/sum_int.rs | 8 +- src/lib.rs | 2 +- src/window.rs | 253 ++++++++++-------------------------------- tests/test_sum_int.rs | 10 +- 4 files changed, 72 insertions(+), 201 deletions(-) diff --git a/examples/sum_int.rs b/examples/sum_int.rs index fab48ab..e8183a2 100644 --- a/examples/sum_int.rs +++ b/examples/sum_int.rs @@ -3,7 +3,7 @@ use libsqlite3_sys::sqlite3_int64; use sqlite_loadable::prelude::*; -use sqlite_loadable::window::{WindowFunctionCallbacks, define_window_function}; +use sqlite_loadable::window::define_window_function; use sqlite_loadable::{api, Result}; /// Example inspired by sqlite3's sumint @@ -41,8 +41,10 @@ pub fn x_inverse(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) - #[sqlite_entrypoint] pub fn sqlite3_sumint_init(db: *mut sqlite3) -> Result<()> { let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC; - define_window_function(db, "sumint", -1, flags, - WindowFunctionCallbacks::new(x_step, x_final, x_value, x_inverse))?; + define_window_function( + db, "sumint", -1, flags, + x_step, x_final, Some(x_value), Some(x_inverse), + )?; Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 19517d0..e92f1e6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,7 @@ pub use bit_flags::FunctionFlags; pub use scalar::{define_scalar_function, define_scalar_function_with_aux}; #[doc(inline)] -pub use window::{WindowFunctionCallbacksWithAux, define_window_function_with_aux}; +pub use window::{WindowFunctionCallbacks, define_window_function}; #[doc(inline)] pub use collation::define_collation; diff --git a/src/window.rs b/src/window.rs index 7e42f17..faa954a 100644 --- a/src/window.rs +++ b/src/window.rs @@ -15,19 +15,17 @@ use crate::{ }; use sqlite3ext_sys::{sqlite3, sqlite3_context, sqlite3_user_data, sqlite3_value}; -// TODO typedef repeating parameter types, across multiple files - fn create_window_function( db: *mut sqlite3, name: &str, num_args: c_int, func_flags: FunctionFlags, p_app: *mut c_void, - x_step: Option, - x_final: Option, + x_step: unsafe extern "C" fn(*mut sqlite3_context, i32, *mut *mut sqlite3_value), + x_final: unsafe extern "C" fn(*mut sqlite3_context), x_value: Option, x_inverse: Option, - destroy: Option, + destroy: unsafe extern "C" fn(*mut c_void), ) -> Result<()> { let cname = CString::new(name)?; @@ -38,11 +36,11 @@ fn create_window_function( num_args, func_flags.bits(), p_app, - x_step, - x_final, + Some(x_step), + Some(x_final), x_value, x_inverse, - destroy, + Some(destroy), ) }; @@ -53,46 +51,25 @@ fn create_window_function( } } +type ValueCallback = fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()>; +type ContextCallback = fn(context: *mut sqlite3_context) -> Result<()>; + pub struct WindowFunctionCallbacks { - x_step: fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()>, - x_final: fn(context: *mut sqlite3_context) -> Result<()>, - x_value: fn(context: *mut sqlite3_context) -> Result<()>, - x_inverse: fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()>, + x_step: ValueCallback, + x_final: ContextCallback, + x_value: Option, + x_inverse: Option, } impl WindowFunctionCallbacks { pub fn new( - x_step: fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()>, - x_final: fn(context: *mut sqlite3_context) -> Result<()>, - x_value: fn(context: *mut sqlite3_context) -> Result<()>, - x_inverse: fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> + x_step: ValueCallback, + x_final: ContextCallback, + x_value: Option, + x_inverse: Option ) -> Self { - WindowFunctionCallbacks { - x_step, - x_final, - x_value, - x_inverse, - } - } -} - -pub struct WindowFunctionCallbacksWithAux -{ - x_step: fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value], aux: &T) -> Result<()>, - x_final: fn(context: *mut sqlite3_context, aux: &T) -> Result<()>, - x_value: fn(context: *mut sqlite3_context, aux: &T) -> Result<()>, - x_inverse: fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value], aux: &T) -> Result<()>, -} - -impl WindowFunctionCallbacksWithAux { - pub fn new( - x_step: fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value], aux: &T) -> Result<()>, - x_final: fn(context: *mut sqlite3_context, aux: &T) -> Result<()>, - x_value: fn(context: *mut sqlite3_context, aux: &T) -> Result<()>, - x_inverse: fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value], aux: &T) -> Result<()> - ) -> Self { - WindowFunctionCallbacksWithAux { + Self { x_step, x_final, x_value, @@ -102,136 +79,23 @@ impl WindowFunctionCallbacksWithAux { } // TODO add documentation -// TODO add new test with aux object -// TODO parentheses matching -/// The aux parameter can be used to pass another context object altogether -pub fn define_window_function_with_aux( - db: *mut sqlite3, - name: &str, - num_args: c_int, - func_flags: FunctionFlags, - callbacks: WindowFunctionCallbacksWithAux, - aux: T, -) -> Result<()> -{ - let callbacks_pointer = Box::into_raw(Box::new(callbacks)); - let aux_pointer: *mut T = Box::into_raw(Box::new(aux)); - let app_pointer = Box::into_raw(Box::new((callbacks_pointer, aux_pointer))); - - unsafe extern "C" fn x_step_wrapper( - context: *mut sqlite3_context, - argc: c_int, - argv: *mut *mut sqlite3_value, - ) - { - let x = sqlite3_user_data(context).cast::<(*mut WindowFunctionCallbacksWithAux, *mut T)>(); - let boxed_function = Box::from_raw((*x).0).as_ref().x_step; - let aux = (*x).1; - // .collect slows things waaaay down, so stick with slice for now - let args = slice::from_raw_parts(argv, argc as usize); - let b: Box = Box::from_raw(aux); - match boxed_function(context, args, &*b) { - Ok(()) => (), - Err(e) => { - if api::result_error(context, &e.result_error_message()).is_err() { - api::result_error_code(context, SQLITE_INTERNAL); - } - } - } - Box::into_raw(b); - } - - unsafe extern "C" fn x_inverse_wrapper( - context: *mut sqlite3_context, - argc: c_int, - argv: *mut *mut sqlite3_value, - ) - { - let x = sqlite3_user_data(context).cast::<(*mut WindowFunctionCallbacksWithAux, *mut T)>(); - let boxed_function = Box::from_raw((*x).0).as_ref().x_inverse; - let aux = (*x).1; - // .collect slows things waaaay down, so stick with slice for now - let args = slice::from_raw_parts(argv, argc as usize); - let b = Box::from_raw(aux); - match boxed_function(context, args, &*b) { - Ok(()) => (), - Err(e) => { - if api::result_error(context, &e.result_error_message()).is_err() { - api::result_error_code(context, SQLITE_INTERNAL); - } - } - } - Box::into_raw(b); - } - - unsafe extern "C" fn x_final_wrapper( - context: *mut sqlite3_context, - ) - { - let x = sqlite3_user_data(context).cast::<(*mut WindowFunctionCallbacksWithAux, *mut T)>(); - let boxed_function = Box::from_raw((*x).0).as_ref().x_final; - let aux = (*x).1; - let b = Box::from_raw(aux); - match boxed_function(context, &*b) { - Ok(()) => (), - Err(e) => { - if api::result_error(context, &e.result_error_message()).is_err() { - api::result_error_code(context, SQLITE_INTERNAL); - } - } - } - Box::into_raw(b); - } - - unsafe extern "C" fn x_value_wrapper( - context: *mut sqlite3_context, - ) - { - let x = sqlite3_user_data(context).cast::<(*mut WindowFunctionCallbacksWithAux, *mut T)>(); - let boxed_function = Box::from_raw((*x).0).as_ref().x_value; - let aux = (*x).1; - let b = Box::from_raw(aux); - match boxed_function(context, &*b) { - Ok(()) => (), - Err(e) => { - if api::result_error(context, &e.result_error_message()).is_err() { - api::result_error_code(context, SQLITE_INTERNAL); - } - } - } - Box::into_raw(b); - } - - create_window_function( - db, - name, - num_args, - func_flags, - // app_pointer, - app_pointer.cast::(), - Some(x_step_wrapper::), - Some(x_final_wrapper::), - Some(x_value_wrapper::), - Some(x_inverse_wrapper::), - None, // Note: release resources in x_final if necessary - ) - - -} - -// TODO add documentation -// TODO parentheses matching /// The aux parameter can be used to pass another context object altogether pub fn define_window_function( db: *mut sqlite3, name: &str, num_args: c_int, func_flags: FunctionFlags, - callbacks: WindowFunctionCallbacks, + x_step: ValueCallback, + x_final: ContextCallback, + x_value: Option, + x_inverse: Option ) -> Result<()> { - let callbacks_pointer = Box::into_raw(Box::new(callbacks)); - let app_pointer = Box::into_raw(Box::new(callbacks_pointer)); + let app_pointer = Box::into_raw( + Box::new( + WindowFunctionCallbacks::new(x_step, x_final, x_value, x_inverse) + ) + ); unsafe extern "C" fn x_step_wrapper( context: *mut sqlite3_context, @@ -239,11 +103,9 @@ pub fn define_window_function( argv: *mut *mut sqlite3_value, ) { - let x = sqlite3_user_data(context).cast::<*mut WindowFunctionCallbacks>(); - let boxed_function = Box::from_raw(*x).as_ref().x_step; - // .collect slows things waaaay down, so stick with slice for now + let x = sqlite3_user_data(context).cast::(); let args = slice::from_raw_parts(argv, argc as usize); - match boxed_function(context, args) { + match ((*x).x_step)(context, args) { Ok(()) => (), Err(e) => { if api::result_error(context, &e.result_error_message()).is_err() { @@ -259,17 +121,17 @@ pub fn define_window_function( argv: *mut *mut sqlite3_value, ) { - let x = sqlite3_user_data(context).cast::<*mut WindowFunctionCallbacks>(); - let boxed_function = Box::from_raw(*x).as_ref().x_inverse; - // .collect slows things waaaay down, so stick with slice for now - let args = slice::from_raw_parts(argv, argc as usize); - match boxed_function(context, args) { - Ok(()) => (), - Err(e) => { - if api::result_error(context, &e.result_error_message()).is_err() { - api::result_error_code(context, SQLITE_INTERNAL); + let x = sqlite3_user_data(context).cast::(); + if let Some(x_inverse) = (*x).x_inverse { + let args = slice::from_raw_parts(argv, argc as usize); + match x_inverse(context, args) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } } - } + } } } @@ -277,9 +139,8 @@ pub fn define_window_function( context: *mut sqlite3_context, ) { - let x = sqlite3_user_data(context).cast::<*mut WindowFunctionCallbacks>(); - let boxed_function = Box::from_raw(*x).as_ref().x_final; - match boxed_function(context) { + let x = sqlite3_user_data(context).cast::(); + match ((*x).x_final)(context) { Ok(()) => (), Err(e) => { if api::result_error(context, &e.result_error_message()).is_err() { @@ -293,30 +154,38 @@ pub fn define_window_function( context: *mut sqlite3_context, ) { - let x = sqlite3_user_data(context).cast::<*mut WindowFunctionCallbacks>(); - let boxed_function = Box::from_raw(*x).as_ref().x_value; - match boxed_function(context) { - Ok(()) => (), - Err(e) => { - if api::result_error(context, &e.result_error_message()).is_err() { - api::result_error_code(context, SQLITE_INTERNAL); + let x = sqlite3_user_data(context).cast::(); + if let Some(x_value) = (*x).x_value { + match x_value(context) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } } - } + } } } + unsafe extern "C" fn destroy( + p_app: *mut c_void, + ) + { + let callbacks = p_app.cast::(); + let _ = Box::from(callbacks); // drop + } + create_window_function( db, name, num_args, func_flags, - // app_pointer, app_pointer.cast::(), - Some(x_step_wrapper), - Some(x_final_wrapper), + x_step_wrapper, + x_final_wrapper, Some(x_value_wrapper), Some(x_inverse_wrapper), - None, // Note: release resources in x_final if necessary + destroy, ) diff --git a/tests/test_sum_int.rs b/tests/test_sum_int.rs index 1b06894..073eb53 100644 --- a/tests/test_sum_int.rs +++ b/tests/test_sum_int.rs @@ -1,6 +1,6 @@ use libsqlite3_sys::sqlite3_int64; use sqlite_loadable::prelude::*; -use sqlite_loadable::window::{WindowFunctionCallbacks, define_window_function}; +use sqlite_loadable::window::define_window_function; use sqlite_loadable::{api, Result}; /// Example inspired by sqlite3's sumint @@ -38,13 +38,13 @@ pub fn x_inverse(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) - #[sqlite_entrypoint] pub fn sqlite3_sumint_init(db: *mut sqlite3) -> Result<()> { let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC; - define_window_function(db, "sumint", -1, flags, - WindowFunctionCallbacks::new(x_step, x_final, x_value, x_inverse))?; + define_window_function( + db, "sumint", -1, flags, + x_step, x_final, Some(x_value), Some(x_inverse), + )?; Ok(()) } -// include!("../examples/sum_int.rs"); - #[cfg(test)] mod tests { use super::*; From f272ccd3647fcf711cca5d109487e81c0ebb36b7 Mon Sep 17 00:00:00 2001 From: Jasm Sison Date: Thu, 31 Aug 2023 19:58:34 +0200 Subject: [PATCH 04/13] run valgrind on mac M1 etc. --- Dockerfile | 8 ++++++++ run-docker.sh | 8 ++++++++ 2 files changed, 16 insertions(+) create mode 100644 Dockerfile create mode 100644 run-docker.sh diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..b03a82e --- /dev/null +++ b/Dockerfile @@ -0,0 +1,8 @@ +FROM debian:bullseye-slim + +RUN apt-get update && apt-get install -y curl valgrind build-essential clang +# Install Rust +ENV RUST_VERSION=stable +RUN curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain=$RUST_VERSION +# Install cargo-valgrind +RUN /bin/bash -c "source /root/.cargo/env && cargo install cargo-valgrind" diff --git a/run-docker.sh b/run-docker.sh new file mode 100644 index 0000000..0591db1 --- /dev/null +++ b/run-docker.sh @@ -0,0 +1,8 @@ +#!/bin/sh +NAME="valgrind:1.0" +docker image inspect "$NAME" || docker build -t "$NAME" . +docker run -it -v $PWD:/tmp -w /tmp valgrind:1.0 + +# see https://github.com/jfrimmel/cargo-valgrind/pull/58/commits/1c168f296e0b3daa50279c642dd37aecbd85c5ff#L59 +# scan for double frees and leaks +# VALGRINDFLAGS="--leak-check=yes --trace-children=yes" cargo valgrind test From 5935b4ca5012bf730d5e11402d21e98f1b634786 Mon Sep 17 00:00:00 2001 From: Jasm Sison Date: Mon, 4 Sep 2023 15:41:43 +0200 Subject: [PATCH 05/13] fix wrong function called to drop callbacks, verified no leaks with valgrind aux implementation might not be necessary, due to aggregate context --- src/window.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/window.rs b/src/window.rs index faa954a..a5ce4ec 100644 --- a/src/window.rs +++ b/src/window.rs @@ -78,8 +78,6 @@ impl WindowFunctionCallbacks { } } -// TODO add documentation -/// The aux parameter can be used to pass another context object altogether pub fn define_window_function( db: *mut sqlite3, name: &str, @@ -172,7 +170,7 @@ pub fn define_window_function( ) { let callbacks = p_app.cast::(); - let _ = Box::from(callbacks); // drop + let _ = Box::from_raw(callbacks); // drop } create_window_function( From c1447f1f645dbe335d8f976922aa9d167d3351ac Mon Sep 17 00:00:00 2001 From: Jasm Sison Date: Mon, 4 Sep 2023 16:15:59 +0200 Subject: [PATCH 06/13] add doc --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 19ec091..96bdbd4 100644 --- a/README.md +++ b/README.md @@ -132,6 +132,13 @@ select * from xxx; Some real-world non-Rust examples of traditional virtual tables in SQLite include the [CSV virtual table](https://www.sqlite.org/csv.html), the full-text search [fts5 extension](https://www.sqlite.org/fts5.html#fts5_table_creation_and_initialization), and the [R-Tree extension](https://www.sqlite.org/rtree.html#creating_an_r_tree_index). + +### Window (aggregate) functions + +A window function can be defined using the `define_window_function`. The step and final function must be defined. See the [`sum_int.rs`](./examples/sum_int.rs) implementation and the [sqlite's own example](https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions) for a full solution. + + + ## Examples The [`examples/`](./examples/) directory has a few bare-bones examples of extensions, which you can build with: From 5efb44a579c3acaf80461d0682c3299dbc7b80b7 Mon Sep 17 00:00:00 2001 From: Jasm Sison Date: Mon, 4 Sep 2023 16:17:53 +0200 Subject: [PATCH 07/13] clean up --- src/lib.rs | 2 +- src/window.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index e92f1e6..ef80605 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,7 @@ pub use bit_flags::FunctionFlags; pub use scalar::{define_scalar_function, define_scalar_function_with_aux}; #[doc(inline)] -pub use window::{WindowFunctionCallbacks, define_window_function}; +pub use window::define_window_function; #[doc(inline)] pub use collation::define_collation; diff --git a/src/window.rs b/src/window.rs index a5ce4ec..d8a244c 100644 --- a/src/window.rs +++ b/src/window.rs @@ -54,7 +54,7 @@ fn create_window_function( type ValueCallback = fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()>; type ContextCallback = fn(context: *mut sqlite3_context) -> Result<()>; -pub struct WindowFunctionCallbacks +struct WindowFunctionCallbacks { x_step: ValueCallback, x_final: ContextCallback, @@ -63,7 +63,7 @@ pub struct WindowFunctionCallbacks } impl WindowFunctionCallbacks { - pub fn new( + fn new( x_step: ValueCallback, x_final: ContextCallback, x_value: Option, From 98c1ce012853c1e2b8a9f4015014aeb1d93f9365 Mon Sep 17 00:00:00 2001 From: Jasm Sison Date: Mon, 4 Sep 2023 23:33:57 +0200 Subject: [PATCH 08/13] Add window function wih auxillary object --- Cargo.toml | 4 ++ README.md | 1 + examples/sum_int_aux.rs | 45 ++++++++++++ src/lib.rs | 2 +- src/window.rs | 143 ++++++++++++++++++++++++++++++++++++++ tests/test_sum_int_aux.rs | 73 +++++++++++++++++++ 6 files changed, 267 insertions(+), 1 deletion(-) create mode 100644 examples/sum_int_aux.rs create mode 100644 tests/test_sum_int_aux.rs diff --git a/Cargo.toml b/Cargo.toml index 280e483..1c20141 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,3 +46,7 @@ crate-type = ["cdylib"] [[example]] name = "sum_int" crate-type = ["cdylib"] + +[[example]] +name = "sum_int_aux" +crate-type = ["cdylib"] diff --git a/README.md b/README.md index 96bdbd4..d39a75b 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,7 @@ Some real-world non-Rust examples of traditional virtual tables in SQLite includ A window function can be defined using the `define_window_function`. The step and final function must be defined. See the [`sum_int.rs`](./examples/sum_int.rs) implementation and the [sqlite's own example](https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions) for a full solution. +There is also a [`define_window_function_with_aux`](./src/window.rs), in case a mutable auxillary object is required in place of the context aggregate pointer provided by sqlite3. ## Examples diff --git a/examples/sum_int_aux.rs b/examples/sum_int_aux.rs new file mode 100644 index 0000000..7c2f065 --- /dev/null +++ b/examples/sum_int_aux.rs @@ -0,0 +1,45 @@ +//! cargo build --example sum_int_aux + +use libsqlite3_sys::sqlite3_int64; +use sqlite_loadable::prelude::*; +use sqlite_loadable::window::define_window_function_with_aux; +use sqlite_loadable::{api, Result}; + +/// Example inspired by sqlite3's sumint +/// https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions +pub fn x_step(_context: *mut sqlite3_context, values: &[*mut sqlite3_value], i: &mut i64) -> Result<()> { + assert!(values.len() == 1); + let new_value = api::value_int64(values.get(0).expect("should be one")); + *i = *i + new_value; + Ok(()) +} + + +pub fn x_final(context: *mut sqlite3_context, i: &mut i64) -> Result<()> { + api::result_int64(context, *i); + Ok(()) +} + + +pub fn x_value(context: *mut sqlite3_context, i: &mut i64) -> Result<()> { + api::result_int64(context, *i); + Ok(()) +} + +pub fn x_inverse(context: *mut sqlite3_context, values: &[*mut sqlite3_value], i: &mut i64) -> Result<()> { + assert!(values.len() == 1); + let new_value = api::value_int64(values.get(0).expect("should be one")); + api::set_aggregate_context_value::(context, *i - new_value)?; + Ok(()) +} + +#[sqlite_entrypoint] +pub fn sqlite3_sumintaux_init(db: *mut sqlite3) -> Result<()> { + let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC; + define_window_function_with_aux::( + db, "sumint_aux", -1, flags, + x_step, x_final, Some(x_value), Some(x_inverse), + 0, + )?; + Ok(()) +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index ef80605..a9b110c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,7 @@ pub use bit_flags::FunctionFlags; pub use scalar::{define_scalar_function, define_scalar_function_with_aux}; #[doc(inline)] -pub use window::define_window_function; +pub use window::{define_window_function,define_window_function_with_aux}; #[doc(inline)] pub use collation::define_collation; diff --git a/src/window.rs b/src/window.rs index d8a244c..e7c9a0d 100644 --- a/src/window.rs +++ b/src/window.rs @@ -186,5 +186,148 @@ pub fn define_window_function( destroy, ) +} + +// Now with aux in case the aggregate type does not implement the Copy trait +// Implementing the Copy trait implies that the underlying bytes are copyable + +type ValueCallbackWithAux = fn(context: *mut sqlite3_context, values: &[*mut sqlite3_value], aux: &mut T) -> Result<()>; +type ContextCallbackWithAux = fn(context: *mut sqlite3_context, aux: &mut T) -> Result<()>; + +struct WindowFunctionCallbacksWithAux +{ + x_step: ValueCallbackWithAux, + x_final: ContextCallbackWithAux, + x_value: Option>, + x_inverse: Option>, + aux: T, +} + +impl WindowFunctionCallbacksWithAux { + fn new( + x_step: ValueCallbackWithAux, + x_final: ContextCallbackWithAux, + x_value: Option>, + x_inverse: Option>, + aux: T + ) -> Self { + Self { + x_step, + x_final, + x_value, + x_inverse, + aux, + } + } +} + +pub fn define_window_function_with_aux( + db: *mut sqlite3, + name: &str, + num_args: c_int, + func_flags: FunctionFlags, + x_step: ValueCallbackWithAux, + x_final: ContextCallbackWithAux, + x_value: Option>, + x_inverse: Option>, + aux: T, +) -> Result<()> +{ + let app_pointer = Box::into_raw( + Box::new( + WindowFunctionCallbacksWithAux::new(x_step, x_final, x_value, x_inverse, aux) + ) + ); + + unsafe extern "C" fn x_step_wrapper( + context: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, + ) + { + let x = sqlite3_user_data(context).cast::>(); + let args = slice::from_raw_parts(argv, argc as usize); + match ((*x).x_step)(context, args, &mut (*x).aux) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + + unsafe extern "C" fn x_inverse_wrapper( + context: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, + ) + { + let x = sqlite3_user_data(context).cast::>(); + if let Some(x_inverse) = (*x).x_inverse { + let args = slice::from_raw_parts(argv, argc as usize); + match x_inverse(context, args, &mut (*x).aux) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + } + + unsafe extern "C" fn x_final_wrapper( + context: *mut sqlite3_context, + ) + { + let x = sqlite3_user_data(context).cast::>(); + match ((*x).x_final)(context, &mut (*x).aux) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + + unsafe extern "C" fn x_value_wrapper( + context: *mut sqlite3_context, + ) + { + let x = sqlite3_user_data(context).cast::>(); + if let Some(x_value) = (*x).x_value { + match x_value(context, &mut (*x).aux) { + Ok(()) => (), + Err(e) => { + if api::result_error(context, &e.result_error_message()).is_err() { + api::result_error_code(context, SQLITE_INTERNAL); + } + } + } + } + } + + unsafe extern "C" fn destroy( + p_app: *mut c_void, + ) + { + let callbacks = p_app.cast::(); + let _ = Box::from_raw(callbacks); // drop + } + + create_window_function( + db, + name, + num_args, + func_flags, + app_pointer.cast::(), + x_step_wrapper::, + x_final_wrapper::, + Some(x_value_wrapper::), + Some(x_inverse_wrapper::), + destroy, + ) } \ No newline at end of file diff --git a/tests/test_sum_int_aux.rs b/tests/test_sum_int_aux.rs new file mode 100644 index 0000000..317fb07 --- /dev/null +++ b/tests/test_sum_int_aux.rs @@ -0,0 +1,73 @@ +use libsqlite3_sys::sqlite3_int64; +use sqlite_loadable::prelude::*; +use sqlite_loadable::window::define_window_function_with_aux; +use sqlite_loadable::{api, Result}; + +/// Example inspired by sqlite3's sumint +/// https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions +pub fn x_step(_context: *mut sqlite3_context, values: &[*mut sqlite3_value], i: &mut i64) -> Result<()> { + assert!(values.len() == 1); + let new_value = api::value_int64(values.get(0).expect("should be one")); + *i = *i + new_value; + Ok(()) +} + + +pub fn x_final(context: *mut sqlite3_context, i: &mut i64) -> Result<()> { + api::result_int64(context, *i); + Ok(()) +} + + +pub fn x_value(context: *mut sqlite3_context, i: &mut i64) -> Result<()> { + api::result_int64(context, *i); + Ok(()) +} + +pub fn x_inverse(context: *mut sqlite3_context, values: &[*mut sqlite3_value], i: &mut i64) -> Result<()> { + assert!(values.len() == 1); + let new_value = api::value_int64(values.get(0).expect("should be one")); + api::set_aggregate_context_value::(context, *i - new_value)?; + Ok(()) +} + +#[sqlite_entrypoint] +pub fn sqlite3_sumintaux_init(db: *mut sqlite3) -> Result<()> { + let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC; + define_window_function_with_aux::( + db, "sumint_aux", -1, flags, + x_step, x_final, Some(x_value), Some(x_inverse), + 0, + )?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + use rusqlite::{ffi::sqlite3_auto_extension, Connection}; + + #[test] + fn test_rusqlite_auto_extension() { + unsafe { + sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_sumintaux_init as *const ()))); + } + + let conn = Connection::open_in_memory().unwrap(); + + let _ = conn + .execute("CREATE TABLE t3(x TEXT, y INTEGER)", ()); + + let _ = conn + .execute("INSERT INTO t3 VALUES ('a', 4), ('b', 5), ('c', 3), ('d', 8), ('e', 1)", ()); + + let result: sqlite3_int64 = conn.query_row("SELECT x, sumint_aux(y) OVER ( + ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING + ) AS sum_y + FROM t3 ORDER BY x", (), |x| x.get(1)).unwrap(); + + + assert_eq!(result, 9); + } +} From 1ed01133acbe0f3a63d94bcfc982c53801dda7e4 Mon Sep 17 00:00:00 2001 From: Jasm Sison Date: Tue, 5 Sep 2023 05:24:06 +0200 Subject: [PATCH 09/13] whoops, forgot to remove one aggregate context call --- examples/sum_int_aux.rs | 5 ++--- tests/test_sum_int_aux.rs | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/sum_int_aux.rs b/examples/sum_int_aux.rs index 7c2f065..a1b417d 100644 --- a/examples/sum_int_aux.rs +++ b/examples/sum_int_aux.rs @@ -1,6 +1,5 @@ //! cargo build --example sum_int_aux -use libsqlite3_sys::sqlite3_int64; use sqlite_loadable::prelude::*; use sqlite_loadable::window::define_window_function_with_aux; use sqlite_loadable::{api, Result}; @@ -26,10 +25,10 @@ pub fn x_value(context: *mut sqlite3_context, i: &mut i64) -> Result<()> { Ok(()) } -pub fn x_inverse(context: *mut sqlite3_context, values: &[*mut sqlite3_value], i: &mut i64) -> Result<()> { +pub fn x_inverse(_context: *mut sqlite3_context, values: &[*mut sqlite3_value], i: &mut i64) -> Result<()> { assert!(values.len() == 1); let new_value = api::value_int64(values.get(0).expect("should be one")); - api::set_aggregate_context_value::(context, *i - new_value)?; + *i = *i - new_value; Ok(()) } diff --git a/tests/test_sum_int_aux.rs b/tests/test_sum_int_aux.rs index 317fb07..b9ec23b 100644 --- a/tests/test_sum_int_aux.rs +++ b/tests/test_sum_int_aux.rs @@ -24,10 +24,10 @@ pub fn x_value(context: *mut sqlite3_context, i: &mut i64) -> Result<()> { Ok(()) } -pub fn x_inverse(context: *mut sqlite3_context, values: &[*mut sqlite3_value], i: &mut i64) -> Result<()> { +pub fn x_inverse(_context: *mut sqlite3_context, values: &[*mut sqlite3_value], i: &mut i64) -> Result<()> { assert!(values.len() == 1); let new_value = api::value_int64(values.get(0).expect("should be one")); - api::set_aggregate_context_value::(context, *i - new_value)?; + *i = *i - new_value; Ok(()) } From f56fd766af5992f62dc04b0c5d82416b3fee6ddd Mon Sep 17 00:00:00 2001 From: Jasm Sison Date: Tue, 5 Sep 2023 20:09:39 +0200 Subject: [PATCH 10/13] update readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d39a75b..955ead3 100644 --- a/README.md +++ b/README.md @@ -133,7 +133,7 @@ select * from xxx; Some real-world non-Rust examples of traditional virtual tables in SQLite include the [CSV virtual table](https://www.sqlite.org/csv.html), the full-text search [fts5 extension](https://www.sqlite.org/fts5.html#fts5_table_creation_and_initialization), and the [R-Tree extension](https://www.sqlite.org/rtree.html#creating_an_r_tree_index). -### Window (aggregate) functions +### Window / Aggregate functions A window function can be defined using the `define_window_function`. The step and final function must be defined. See the [`sum_int.rs`](./examples/sum_int.rs) implementation and the [sqlite's own example](https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions) for a full solution. @@ -249,7 +249,7 @@ A hello world extension in C is `17KB`, while one in Rust is `469k`. It's still - [ ] Stabilize scalar function interface - [ ] Stabilize virtual table interface -- [ ] Support [aggregate window functions](https://www.sqlite.org/windowfunctions.html#udfwinfunc) ([#1](https://github.com/asg017/sqlite-loadable-rs/issues/1)) +- [x] Support [aggregate window functions](https://www.sqlite.org/windowfunctions.html#udfwinfunc) ([#1](https://github.com/asg017/sqlite-loadable-rs/issues/1)) - [ ] Support [collating sequences](https://www.sqlite.org/c3ref/create_collation.html) ([#2](https://github.com/asg017/sqlite-loadable-rs/issues/2)) - [ ] Support [virtual file systems](sqlite.org/vfs.html) ([#3](https://github.com/asg017/sqlite-loadable-rs/issues/3)) From 2888109ba014eae0cf41b663ff2f39e33eeee85e Mon Sep 17 00:00:00 2001 From: Jasm Sison Date: Tue, 5 Sep 2023 20:15:01 +0200 Subject: [PATCH 11/13] update read me some more --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 955ead3..b0161db 100644 --- a/README.md +++ b/README.md @@ -137,7 +137,7 @@ Some real-world non-Rust examples of traditional virtual tables in SQLite includ A window function can be defined using the `define_window_function`. The step and final function must be defined. See the [`sum_int.rs`](./examples/sum_int.rs) implementation and the [sqlite's own example](https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions) for a full solution. -There is also a [`define_window_function_with_aux`](./src/window.rs), in case a mutable auxillary object is required in place of the context aggregate pointer provided by sqlite3. +There is also a [`define_window_function_with_aux`](./src/window.rs), in case a mutable auxillary object is required in place of the context aggregate pointer provided by sqlite3. In this case, the object being passed is not required to implement the Copy trait. ## Examples From 4c1bdc7a96f9d1e0efe9ce8a3f5fba45543b973e Mon Sep 17 00:00:00 2001 From: Jasm Sison Date: Mon, 11 Sep 2023 16:09:54 +0200 Subject: [PATCH 12/13] update docker shell script --- run-docker.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/run-docker.sh b/run-docker.sh index 0591db1..9b80c37 100644 --- a/run-docker.sh +++ b/run-docker.sh @@ -1,7 +1,7 @@ #!/bin/sh -NAME="valgrind:1.0" +NAME="io_uring:1.0" docker image inspect "$NAME" || docker build -t "$NAME" . -docker run -it -v $PWD:/tmp -w /tmp valgrind:1.0 +docker run -it -v $PWD:/tmp -w /tmp $NAME # see https://github.com/jfrimmel/cargo-valgrind/pull/58/commits/1c168f296e0b3daa50279c642dd37aecbd85c5ff#L59 # scan for double frees and leaks From 1adedfc040f1a815503d40c286d236a47c812a05 Mon Sep 17 00:00:00 2001 From: Jasm Sison Date: Sat, 16 Sep 2023 00:46:36 +0200 Subject: [PATCH 13/13] update notes --- NOTES | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/NOTES b/NOTES index 0a16978..43788f1 100644 --- a/NOTES +++ b/NOTES @@ -2,10 +2,10 @@ TODO - aggregate functions - `scalar.rs` -> `functions.rs` - - `Aggregate` and `Window` structs? - - should it use `sqlite3_aggregate_context` or something else? - - `create_aggregate_function` - - `create_window_function` + + `Aggregate` and `Window` structs? -> same thing, just a matter of invocation, aggregate if no 'OVER' + + should it use `sqlite3_aggregate_context` or something else? -> No, not necessary + + `create_aggregate_function` -> current design is closer to the actual sqlite design + + `create_window_function` -> closer to the sqlite design - vtab_in - `IndexInfo.is_in_operator(i32 constraint_idx)` - `IndexInfo.process_in_operator(constraint_idx: i32, value: bool)`