Skip to content

Commit f0eca51

Browse files
committed
two issues with sum_int:
1. double free issue in test_sum_int.rs 2. can't load the plugin: Error: dlsym(0x808c4010, sqlite3_sumint_init): symbol not found
1 parent 1345df6 commit f0eca51

File tree

11 files changed

+558
-37
lines changed

11 files changed

+558
-37
lines changed

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,7 @@ crate-type = ["cdylib"]
4242
[[example]]
4343
name = "load_permanent"
4444
crate-type = ["cdylib"]
45+
46+
[[example]]
47+
name = "sum_int"
48+
crate-type = ["cdylib"]

examples/sum_int.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//! cargo build --example sum_int
2+
//! sqlite3 :memory: '.read examples/test.sql'
3+
4+
use libsqlite3_sys::sqlite3_int64;
5+
use sqlite_loadable::prelude::*;
6+
use sqlite_loadable::window::{WindowFunctionCallbacks, define_window_function};
7+
use sqlite_loadable::{api, Result};
8+
9+
/// Example inspired by sqlite3's sumint
10+
/// https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions
11+
pub fn x_step(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> {
12+
assert!(values.len() == 1);
13+
let new_value = api::value_int64(values.get(0).expect("should be one"));
14+
let previous_value = api::get_aggregate_context_value::<sqlite3_int64>(context)?;
15+
api::set_aggregate_context_value::<sqlite3_int64>(context, previous_value + new_value)?;
16+
Ok(())
17+
}
18+
19+
20+
pub fn x_final(context: *mut sqlite3_context) -> Result<()> {
21+
let value = api::get_aggregate_context_value::<sqlite3_int64>(context)?;
22+
api::result_int64(context, value);
23+
Ok(())
24+
}
25+
26+
27+
pub fn x_value(context: *mut sqlite3_context) -> Result<()> {
28+
let value = api::get_aggregate_context_value::<sqlite3_int64>(context)?;
29+
api::result_int64(context, value);
30+
Ok(())
31+
}
32+
33+
pub fn x_inverse(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> {
34+
assert!(values.len() == 1);
35+
let new_value = api::value_int64(values.get(0).expect("should be one"));
36+
let previous_value = api::get_aggregate_context_value::<sqlite3_int64>(context)?;
37+
api::set_aggregate_context_value::<sqlite3_int64>(context, previous_value - new_value)?;
38+
Ok(())
39+
}
40+
41+
#[sqlite_entrypoint]
42+
pub fn sqlite3_sum_int_init(db: *mut sqlite3) -> Result<()> {
43+
let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC;
44+
define_window_function(db, "sum_int", -1, flags,
45+
WindowFunctionCallbacks::new(x_step, x_final, x_value, x_inverse))?;
46+
Ok(())
47+
}
48+

sqlite-loadable-macros/Cargo.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/api.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::ext::{
1919
use crate::Error;
2020
use sqlite3ext_sys::{
2121
sqlite3, sqlite3_context, sqlite3_mprintf, sqlite3_value, SQLITE_BLOB, SQLITE_FLOAT,
22-
SQLITE_INTEGER, SQLITE_NULL, SQLITE_TEXT,
22+
SQLITE_INTEGER, SQLITE_NULL, SQLITE_TEXT, sqlite3_aggregate_context,
2323
};
2424
use std::os::raw::c_int;
2525
use std::slice::from_raw_parts;
@@ -571,3 +571,41 @@ impl ExtendedColumnAffinity {
571571
ExtendedColumnAffinity::Numeric
572572
}
573573
}
574+
575+
// TODO write test
576+
pub fn get_aggregate_context_value<T>(context: *mut sqlite3_context) -> Result<T, String>
577+
where
578+
T: Copy,
579+
{
580+
let p_value: *mut T = unsafe {
581+
sqlite3_aggregate_context(context, std::mem::size_of::<T>() as i32) as *mut T
582+
};
583+
584+
if p_value.is_null() {
585+
return Err("sqlite3_aggregate_context returned a null pointer.".to_string());
586+
}
587+
588+
let value: T = unsafe { *p_value };
589+
590+
Ok(value)
591+
}
592+
593+
// TODO write test
594+
pub fn set_aggregate_context_value<T>(context: *mut sqlite3_context, value: T) -> Result<(), String>
595+
where
596+
T: Copy,
597+
{
598+
let p_value: *mut T = unsafe {
599+
sqlite3_aggregate_context(context, std::mem::size_of::<T>() as i32) as *mut T
600+
};
601+
602+
if p_value.is_null() {
603+
return Err("sqlite3_aggregate_context returned a null pointer.".to_string());
604+
}
605+
606+
unsafe {
607+
*p_value = value;
608+
}
609+
610+
Ok(())
611+
}

src/bit_flags.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
use bitflags::bitflags;
2+
3+
use sqlite3ext_sys::{
4+
SQLITE_DETERMINISTIC, SQLITE_DIRECTONLY, SQLITE_INNOCUOUS, SQLITE_SUBTYPE, SQLITE_UTF16,
5+
SQLITE_UTF16BE, SQLITE_UTF16LE, SQLITE_UTF8,
6+
};
7+
8+
bitflags! {
9+
/// Represents the possible flag values that can be passed into sqlite3_create_function_v2
10+
/// or sqlite3_create_window_function, as the 4th "eTextRep" parameter.
11+
/// Includes both the encoding options (utf8, utf16, etc.) and function-level parameters
12+
/// (deterministion, innocuous, etc.).
13+
pub struct FunctionFlags: i32 {
14+
const UTF8 = SQLITE_UTF8 as i32;
15+
const UTF16LE = SQLITE_UTF16LE as i32;
16+
const UTF16BE = SQLITE_UTF16BE as i32;
17+
const UTF16 = SQLITE_UTF16 as i32;
18+
19+
/// "... to signal that the function will always return the same result given the same
20+
/// inputs within a single SQL statement."
21+
/// <https://www.sqlite.org/c3ref/create_function.html#:~:text=ORed%20with%20SQLITE_DETERMINISTIC>
22+
const DETERMINISTIC = SQLITE_DETERMINISTIC as i32;
23+
const DIRECTONLY = SQLITE_DIRECTONLY as i32;
24+
const SUBTYPE = SQLITE_SUBTYPE as i32;
25+
const INNOCUOUS = SQLITE_INNOCUOUS as i32;
26+
}
27+
}

src/errors.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ impl Error {
4545
ErrorKind::CStringUtf8Error(_) => "utf8 err".to_owned(),
4646
ErrorKind::Message(msg) => msg,
4747
ErrorKind::TableFunction(_) => "table func error".to_owned(),
48+
ErrorKind::DefineWindowFunction(_) => "Error defining window function".to_owned(),
4849
}
4950
}
5051
}
@@ -53,6 +54,7 @@ impl Error {
5354
#[derive(Debug, PartialEq, Eq)]
5455
pub enum ErrorKind {
5556
DefineScalarFunction(c_int),
57+
DefineWindowFunction(c_int),
5658
CStringError(NulError),
5759
CStringUtf8Error(std::str::Utf8Error),
5860
TableFunction(c_int),

src/ext.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use sqlite3ext_sys::{
2626
sqlite3_result_pointer, sqlite3_result_text, sqlite3_set_auxdata, sqlite3_step, sqlite3_stmt,
2727
sqlite3_value, sqlite3_value_blob, sqlite3_value_bytes, sqlite3_value_double,
2828
sqlite3_value_int, sqlite3_value_int64, sqlite3_value_pointer, sqlite3_value_subtype,
29-
sqlite3_value_text, sqlite3_value_type,
29+
sqlite3_value_text, sqlite3_value_type, sqlite3_create_window_function,
3030
};
3131

3232
/// If creating a dynmically loadable extension, this MUST be redefined to point
@@ -316,6 +316,29 @@ pub unsafe fn sqlite3ext_get_auxdata(context: *mut sqlite3_context, n: c_int) ->
316316
((*SQLITE3_API).get_auxdata.expect(EXPECT_MESSAGE))(context, n)
317317
}
318318

319+
pub unsafe fn sqlite3ext_create_window_function(
320+
db: *mut sqlite3,
321+
s: *const c_char,
322+
argc: i32,
323+
text_rep: i32,
324+
p_app: *mut c_void,
325+
x_step: Option<unsafe extern "C" fn(*mut sqlite3_context, i32, *mut *mut sqlite3_value)>,
326+
x_final: Option<unsafe extern "C" fn(*mut sqlite3_context)>,
327+
x_value: Option<unsafe extern "C" fn(*mut sqlite3_context)>,
328+
x_inverse: Option<unsafe extern "C" fn(*mut sqlite3_context, i32, *mut *mut sqlite3_value)>,
329+
destroy: Option<unsafe extern "C" fn(*mut c_void)>
330+
) -> c_int {
331+
if SQLITE3_API.is_null() {
332+
sqlite3_create_window_function(
333+
db, s, argc, text_rep, p_app, x_step, x_final, x_value, x_inverse, destroy,
334+
)
335+
} else {
336+
((*SQLITE3_API).create_window_function.expect(EXPECT_MESSAGE))(
337+
db, s, argc, text_rep, p_app, x_step, x_final, x_value, x_inverse, destroy,
338+
)
339+
}
340+
}
341+
319342
pub unsafe fn sqlite3ext_create_function_v2(
320343
db: *mut sqlite3,
321344
s: *const c_char,

src/lib.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,20 @@ pub mod prelude;
1414
pub mod scalar;
1515
pub mod table;
1616
pub mod vtab_argparse;
17+
pub mod window;
18+
pub mod bit_flags;
1719

1820
#[doc(inline)]
1921
pub use errors::{Error, ErrorKind, Result};
2022

2123
#[doc(inline)]
22-
pub use scalar::{define_scalar_function, define_scalar_function_with_aux, FunctionFlags};
24+
pub use bit_flags::FunctionFlags;
25+
26+
#[doc(inline)]
27+
pub use scalar::{define_scalar_function, define_scalar_function_with_aux};
28+
29+
#[doc(inline)]
30+
pub use window::{WindowFunctionCallbacksWithAux, define_window_function_with_aux};
2331

2432
#[doc(inline)]
2533
pub use collation::define_collation;

src/scalar.rs

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,38 +11,10 @@ use crate::{
1111
api,
1212
constants::{SQLITE_INTERNAL, SQLITE_OKAY},
1313
errors::{Error, ErrorKind, Result},
14-
ext::sqlite3ext_create_function_v2,
14+
ext::sqlite3ext_create_function_v2, FunctionFlags,
1515
};
1616
use sqlite3ext_sys::{sqlite3, sqlite3_context, sqlite3_user_data, sqlite3_value};
1717

18-
use bitflags::bitflags;
19-
20-
use sqlite3ext_sys::{
21-
SQLITE_DETERMINISTIC, SQLITE_DIRECTONLY, SQLITE_INNOCUOUS, SQLITE_SUBTYPE, SQLITE_UTF16,
22-
SQLITE_UTF16BE, SQLITE_UTF16LE, SQLITE_UTF8,
23-
};
24-
25-
bitflags! {
26-
/// Represents the possible flag values that can be passed into sqlite3_create_function_v2
27-
/// or sqlite3_create_window_function, as the 4th "eTextRep" parameter.
28-
/// Includes both the encoding options (utf8, utf16, etc.) and function-level parameters
29-
/// (deterministion, innocuous, etc.).
30-
pub struct FunctionFlags: i32 {
31-
const UTF8 = SQLITE_UTF8 as i32;
32-
const UTF16LE = SQLITE_UTF16LE as i32;
33-
const UTF16BE = SQLITE_UTF16BE as i32;
34-
const UTF16 = SQLITE_UTF16 as i32;
35-
36-
/// "... to signal that the function will always return the same result given the same
37-
/// inputs within a single SQL statement."
38-
/// <https://www.sqlite.org/c3ref/create_function.html#:~:text=ORed%20with%20SQLITE_DETERMINISTIC>
39-
const DETERMINISTIC = SQLITE_DETERMINISTIC as i32;
40-
const DIRECTONLY = SQLITE_DIRECTONLY as i32;
41-
const SUBTYPE = SQLITE_SUBTYPE as i32;
42-
const INNOCUOUS = SQLITE_INNOCUOUS as i32;
43-
}
44-
}
45-
4618
fn create_function_v2(
4719
db: *mut sqlite3,
4820
name: &str,
@@ -53,14 +25,15 @@ fn create_function_v2(
5325
x_step: Option<unsafe extern "C" fn(*mut sqlite3_context, i32, *mut *mut sqlite3_value)>,
5426
x_final: Option<unsafe extern "C" fn(*mut sqlite3_context)>,
5527
destroy: Option<unsafe extern "C" fn(*mut c_void)>,
56-
) -> Result<()> {
28+
) -> Result<()>
29+
{
5730
let cname = CString::new(name)?;
5831
let result = unsafe {
5932
sqlite3ext_create_function_v2(
6033
db,
6134
cname.as_ptr(),
6235
num_args,
63-
func_flags.bits,
36+
func_flags.bits(),
6437
p_app,
6538
x_func,
6639
x_step,
@@ -240,3 +213,4 @@ where
240213

241214
x_func_wrapper::<F>
242215
}
216+

0 commit comments

Comments
 (0)