diff --git a/crates/core/src/crud_vtab.rs b/crates/core/src/crud_vtab.rs index 2eb7225..12ade1f 100644 --- a/crates/core/src/crud_vtab.rs +++ b/crates/core/src/crud_vtab.rs @@ -2,8 +2,10 @@ extern crate alloc; use alloc::boxed::Box; use alloc::string::String; +use alloc::sync::Arc; use const_format::formatcp; use core::ffi::{c_char, c_int, c_void}; +use core::sync::atomic::Ordering; use sqlite::{Connection, ResultCode, Value}; use sqlite_nostd as sqlite; @@ -13,6 +15,7 @@ use sqlite_nostd::ResultCode::NULL; use crate::error::SQLiteError; use crate::ext::SafeManagedStmt; use crate::schema::TableInfoFlags; +use crate::state::DatabaseState; use crate::vtab_util::*; // Structure: @@ -31,11 +34,12 @@ struct VirtualTable { db: *mut sqlite::sqlite3, current_tx: Option, insert_statement: Option, + state: Arc, } extern "C" fn connect( db: *mut sqlite::sqlite3, - _aux: *mut c_void, + aux: *mut c_void, _argc: c_int, _argv: *const *const c_char, vtab: *mut *mut sqlite::vtab, @@ -58,6 +62,14 @@ extern "C" fn connect( db, current_tx: None, insert_statement: None, + state: { + // Increase refcount - we can't use from_raw alone because we don't own the aux + // data (connect could be called multiple times). + let state = Arc::from_raw(aux as *mut DatabaseState); + let clone = state.clone(); + core::mem::forget(state); + clone + }, })); *vtab = tab.cast::(); let _ = sqlite::vtab_config(db, 0); @@ -127,13 +139,20 @@ fn insert_operation( flags: TableInfoFlags, ) -> Result<(), SQLiteError> { let tab = unsafe { &mut *(vtab.cast::()) }; - if tab.current_tx.is_none() { + if tab.state.is_in_sync_local.load(Ordering::Relaxed) { return Err(SQLiteError( ResultCode::MISUSE, - Some(String::from("No tx_id")), + Some(String::from("Using ps_crud during sync operation")), )); } - let current_tx = tab.current_tx.unwrap(); + + let Some(current_tx) = tab.current_tx else { + return Err(SQLiteError( + ResultCode::MISUSE, + Some(String::from("No tx_id")), + )); + }; + // language=SQLite let statement = tab .insert_statement @@ -206,8 +225,13 @@ static MODULE: sqlite_nostd::module = sqlite_nostd::module { xIntegrity: None, }; -pub fn register(db: *mut sqlite::sqlite3) -> Result<(), ResultCode> { - db.create_module_v2("powersync_crud_", &MODULE, None, None)?; +pub fn register(db: *mut sqlite::sqlite3, state: Arc) -> Result<(), ResultCode> { + db.create_module_v2( + "powersync_crud_", + &MODULE, + Some(Arc::into_raw(state) as *mut c_void), + Some(DatabaseState::destroy_arc), + )?; Ok(()) } diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 76edd45..da0438c 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -9,9 +9,12 @@ extern crate alloc; use core::ffi::{c_char, c_int}; +use alloc::sync::Arc; use sqlite::ResultCode; use sqlite_nostd as sqlite; +use crate::state::DatabaseState; + mod bson; mod checkpoint; mod crud_vtab; @@ -26,6 +29,7 @@ mod migrations; mod operations; mod operations_vtab; mod schema; +mod state; mod sync; mod sync_local; mod util; @@ -53,6 +57,8 @@ pub extern "C" fn sqlite3_powersync_init( } fn init_extension(db: *mut sqlite::sqlite3) -> Result<(), ResultCode> { + let state = Arc::new(DatabaseState::new()); + crate::version::register(db)?; crate::views::register(db)?; crate::uuid::register(db)?; @@ -62,11 +68,12 @@ fn init_extension(db: *mut sqlite::sqlite3) -> Result<(), ResultCode> { crate::view_admin::register(db)?; crate::checkpoint::register(db)?; crate::kv::register(db)?; - sync::register(db)?; + crate::state::register(db, state.clone())?; + sync::register(db, state.clone())?; crate::schema::register(db)?; - crate::operations_vtab::register(db)?; - crate::crud_vtab::register(db)?; + crate::operations_vtab::register(db, state.clone())?; + crate::crud_vtab::register(db, state)?; Ok(()) } diff --git a/crates/core/src/operations_vtab.rs b/crates/core/src/operations_vtab.rs index 96b5506..bb60308 100644 --- a/crates/core/src/operations_vtab.rs +++ b/crates/core/src/operations_vtab.rs @@ -1,6 +1,7 @@ extern crate alloc; use alloc::boxed::Box; +use alloc::sync::Arc; use core::ffi::{c_char, c_int, c_void}; use sqlite::{Connection, ResultCode, Value}; @@ -9,6 +10,7 @@ use sqlite_nostd as sqlite; use crate::operations::{ clear_remove_ops, delete_bucket, delete_pending_buckets, insert_operation, }; +use crate::state::DatabaseState; use crate::sync_local::sync_local; use crate::vtab_util::*; @@ -16,6 +18,7 @@ use crate::vtab_util::*; struct VirtualTable { base: sqlite::vtab, db: *mut sqlite::sqlite3, + state: Arc, target_applied: bool, target_validated: bool, @@ -23,7 +26,7 @@ struct VirtualTable { extern "C" fn connect( db: *mut sqlite::sqlite3, - _aux: *mut c_void, + aux: *mut c_void, _argc: c_int, _argv: *const *const c_char, vtab: *mut *mut sqlite::vtab, @@ -43,6 +46,14 @@ extern "C" fn connect( zErrMsg: core::ptr::null_mut(), }, db, + state: { + // Increase refcount - we can't use from_raw alone because we don't own the aux + // data (connect could be called multiple times). + let state = Arc::from_raw(aux as *mut DatabaseState); + let clone = state.clone(); + core::mem::forget(state); + clone + }, target_validated: false, target_applied: false, })); @@ -83,7 +94,7 @@ extern "C" fn update( let result = insert_operation(db, args[3].text()); vtab_result(vtab, result) } else if op == "sync_local" { - let result = sync_local(db, &args[3]); + let result = sync_local(&tab.state, db, &args[3]); if let Ok(result_row) = result { unsafe { *p_row_id = result_row; @@ -139,8 +150,13 @@ static MODULE: sqlite_nostd::module = sqlite_nostd::module { xIntegrity: None, }; -pub fn register(db: *mut sqlite::sqlite3) -> Result<(), ResultCode> { - db.create_module_v2("powersync_operations", &MODULE, None, None)?; +pub fn register(db: *mut sqlite::sqlite3, state: Arc) -> Result<(), ResultCode> { + db.create_module_v2( + "powersync_operations", + &MODULE, + Some(Arc::into_raw(state) as *mut c_void), + Some(DatabaseState::destroy_arc), + )?; Ok(()) } diff --git a/crates/core/src/schema/mod.rs b/crates/core/src/schema/mod.rs index 96fb732..cab6c0b 100644 --- a/crates/core/src/schema/mod.rs +++ b/crates/core/src/schema/mod.rs @@ -5,11 +5,15 @@ use alloc::vec::Vec; use serde::Deserialize; use sqlite::ResultCode; use sqlite_nostd as sqlite; -pub use table_info::{DiffIncludeOld, Table, TableInfoFlags}; +pub use table_info::{ + DiffIncludeOld, PendingStatement, PendingStatementValue, RawTable, Table, TableInfoFlags, +}; -#[derive(Deserialize)] +#[derive(Deserialize, Default)] pub struct Schema { - tables: Vec, + pub tables: Vec, + #[serde(default)] + pub raw_tables: Vec, } pub fn register(db: *mut sqlite::sqlite3) -> Result<(), ResultCode> { diff --git a/crates/core/src/schema/table_info.rs b/crates/core/src/schema/table_info.rs index 4224221..a225686 100644 --- a/crates/core/src/schema/table_info.rs +++ b/crates/core/src/schema/table_info.rs @@ -19,6 +19,13 @@ pub struct Table { pub flags: TableInfoFlags, } +#[derive(Deserialize)] +pub struct RawTable { + pub name: String, + pub put: PendingStatement, + pub delete: PendingStatement, +} + impl Table { pub fn from_json(text: &str) -> Result { serde_json::from_str(text) @@ -225,3 +232,17 @@ impl<'de> Deserialize<'de> for TableInfoFlags { ) } } + +#[derive(Deserialize)] +pub struct PendingStatement { + pub sql: String, + /// This vec should contain an entry for each parameter in [sql]. + pub params: Vec, +} + +#[derive(Deserialize)] +pub enum PendingStatementValue { + Id, + Column(String), + // TODO: Stuff like a raw object of put data? +} diff --git a/crates/core/src/state.rs b/crates/core/src/state.rs new file mode 100644 index 0000000..79bbb4a --- /dev/null +++ b/crates/core/src/state.rs @@ -0,0 +1,74 @@ +use core::{ + ffi::{c_int, c_void}, + sync::atomic::{AtomicBool, Ordering}, +}; + +use alloc::sync::Arc; +use sqlite::{Connection, ResultCode}; +use sqlite_nostd::{self as sqlite, Context}; + +/// State that is shared for a SQLite database connection after the core extension has been +/// registered on it. +/// +/// `init_extension` allocates an instance of this in an `Arc` that is shared as user-data for +/// functions/vtabs that need access to it. +pub struct DatabaseState { + pub is_in_sync_local: AtomicBool, +} + +impl DatabaseState { + pub fn new() -> Self { + DatabaseState { + is_in_sync_local: AtomicBool::new(false), + } + } + + pub fn sync_local_guard<'a>(&'a self) -> impl Drop + use<'a> { + self.is_in_sync_local + .compare_exchange(false, true, Ordering::Acquire, Ordering::Acquire) + .expect("should not be syncing already"); + + struct ClearOnDrop<'a>(&'a DatabaseState); + + impl Drop for ClearOnDrop<'_> { + fn drop(&mut self) { + self.0.is_in_sync_local.store(false, Ordering::Release); + } + } + + ClearOnDrop(self) + } + + pub unsafe extern "C" fn destroy_arc(ptr: *mut c_void) { + drop(Arc::from_raw(ptr.cast::())); + } +} + +pub fn register(db: *mut sqlite::sqlite3, state: Arc) -> Result<(), ResultCode> { + unsafe extern "C" fn func( + ctx: *mut sqlite::context, + _argc: c_int, + _argv: *mut *mut sqlite::value, + ) { + let data = ctx.user_data().cast::(); + let data = unsafe { data.as_ref() }.unwrap(); + + ctx.result_int(if data.is_in_sync_local.load(Ordering::Relaxed) { + 1 + } else { + 0 + }); + } + + db.create_function_v2( + "powersync_in_sync_operation", + 0, + 0, + Some(Arc::into_raw(state) as *mut c_void), + Some(func), + None, + None, + Some(DatabaseState::destroy_arc), + )?; + Ok(()) +} diff --git a/crates/core/src/sync/interface.rs b/crates/core/src/sync/interface.rs index aca5eb9..9afc2d5 100644 --- a/crates/core/src/sync/interface.rs +++ b/crates/core/src/sync/interface.rs @@ -5,6 +5,7 @@ use alloc::borrow::Cow; use alloc::boxed::Box; use alloc::rc::Rc; use alloc::string::ToString; +use alloc::sync::Arc; use alloc::{string::String, vec::Vec}; use serde::{Deserialize, Serialize}; use sqlite::{ResultCode, Value}; @@ -12,6 +13,8 @@ use sqlite_nostd::{self as sqlite, ColumnType}; use sqlite_nostd::{Connection, Context}; use crate::error::SQLiteError; +use crate::schema::Schema; +use crate::state::DatabaseState; use super::streaming_sync::SyncClient; use super::sync_status::DownloadSyncStatus; @@ -22,6 +25,8 @@ pub struct StartSyncStream { /// Bucket parameters to include in the request when opening a sync stream. #[serde(default)] pub parameters: Option>, + #[serde(default)] + pub schema: Schema, } /// A request sent from a client SDK to the [SyncClient] with a `powersync_control` invocation. @@ -118,7 +123,7 @@ struct SqlController { client: SyncClient, } -pub fn register(db: *mut sqlite::sqlite3) -> Result<(), ResultCode> { +pub fn register(db: *mut sqlite::sqlite3, state: Arc) -> Result<(), ResultCode> { extern "C" fn control( ctx: *mut sqlite::context, argc: c_int, @@ -199,7 +204,7 @@ pub fn register(db: *mut sqlite::sqlite3) -> Result<(), ResultCode> { } let controller = Box::new(SqlController { - client: SyncClient::new(db), + client: SyncClient::new(db, state), }); db.create_function_v2( diff --git a/crates/core/src/sync/mod.rs b/crates/core/src/sync/mod.rs index fb4f02c..2a28044 100644 --- a/crates/core/src/sync/mod.rs +++ b/crates/core/src/sync/mod.rs @@ -1,3 +1,4 @@ +use alloc::sync::Arc; use sqlite_nostd::{self as sqlite, ResultCode}; mod bucket_priority; @@ -13,6 +14,8 @@ mod sync_status; pub use bucket_priority::BucketPriority; pub use checksum::Checksum; -pub fn register(db: *mut sqlite::sqlite3) -> Result<(), ResultCode> { - interface::register(db) +use crate::state::DatabaseState; + +pub fn register(db: *mut sqlite::sqlite3, state: Arc) -> Result<(), ResultCode> { + interface::register(db, state) } diff --git a/crates/core/src/sync/storage_adapter.rs b/crates/core/src/sync/storage_adapter.rs index ed71b79..2b91118 100644 --- a/crates/core/src/sync/storage_adapter.rs +++ b/crates/core/src/sync/storage_adapter.rs @@ -9,6 +9,8 @@ use crate::{ error::SQLiteError, ext::SafeManagedStmt, operations::delete_bucket, + schema::Schema, + state::DatabaseState, sync::checkpoint::{validate_checkpoint, ChecksumMismatch}, sync_local::{PartialSyncOperation, SyncOperation}, }; @@ -143,8 +145,10 @@ impl StorageAdapter { pub fn sync_local( &self, + state: &DatabaseState, checkpoint: &OwnedCheckpoint, priority: Option, + schema: &Schema, ) -> Result { let mismatched_checksums = validate_checkpoint(checkpoint.buckets.values(), priority, self.db)?; @@ -182,7 +186,11 @@ impl StorageAdapter { } let sync_result = match priority { - None => SyncOperation::new(self.db, None).apply(), + None => { + let mut sync = SyncOperation::new(state, self.db, None); + sync.use_schema(schema); + sync.apply() + } Some(priority) => { let args = PartialArgs { priority, @@ -201,14 +209,16 @@ impl StorageAdapter { // TODO: Avoid this serialization, it's currently used to bind JSON SQL parameters. let serialized_args = serde_json::to_string(&args)?; - SyncOperation::new( + let mut sync = SyncOperation::new( + state, self.db, Some(PartialSyncOperation { priority, args: &serialized_args, }), - ) - .apply() + ); + sync.use_schema(schema); + sync.apply() } }?; diff --git a/crates/core/src/sync/streaming_sync.rs b/crates/core/src/sync/streaming_sync.rs index d5f2f51..0fac873 100644 --- a/crates/core/src/sync/streaming_sync.rs +++ b/crates/core/src/sync/streaming_sync.rs @@ -10,11 +10,18 @@ use alloc::{ collections::{btree_map::BTreeMap, btree_set::BTreeSet}, format, string::{String, ToString}, + sync::Arc, vec::Vec, }; use futures_lite::FutureExt; -use crate::{bson, error::SQLiteError, kv::client_id, sync::checkpoint::OwnedBucketChecksum}; +use crate::{ + bson, + error::SQLiteError, + kv::client_id, + state::DatabaseState, + sync::{checkpoint::OwnedBucketChecksum, interface::StartSyncStream}, +}; use sqlite_nostd::{self as sqlite, ResultCode}; use super::{ @@ -32,14 +39,16 @@ use super::{ /// initialized. pub struct SyncClient { db: *mut sqlite::sqlite3, + db_state: Arc, /// The current [ClientState] (essentially an optional [StreamingSyncIteration]). state: ClientState, } impl SyncClient { - pub fn new(db: *mut sqlite::sqlite3) -> Self { + pub fn new(db: *mut sqlite::sqlite3, state: Arc) -> Self { Self { db, + db_state: state, state: ClientState::Idle, } } @@ -52,7 +61,7 @@ impl SyncClient { SyncControlRequest::StartSyncStream(options) => { self.state.tear_down()?; - let mut handle = SyncIterationHandle::new(self.db, options.parameters)?; + let mut handle = SyncIterationHandle::new(self.db, options, self.db_state.clone())?; let instructions = handle.initialize()?; self.state = ClientState::IterationActive(handle); @@ -122,11 +131,13 @@ impl SyncIterationHandle { /// [StorageAdapter] and setting up the initial downloading state for [StorageAdapter] . fn new( db: *mut sqlite::sqlite3, - parameters: Option>, + options: StartSyncStream, + state: Arc, ) -> Result { let runner = StreamingSyncIteration { db, - parameters, + options, + state, adapter: StorageAdapter::new(db)?, status: SyncStatusContainer::new(), }; @@ -190,8 +201,9 @@ impl<'a> ActiveEvent<'a> { struct StreamingSyncIteration { db: *mut sqlite::sqlite3, + state: Arc, adapter: StorageAdapter, - parameters: Option>, + options: StartSyncStream, status: SyncStatusContainer, } @@ -244,7 +256,12 @@ impl StreamingSyncIteration { SyncEvent::BinaryLine { data } => bson::from_bytes(data)?, SyncEvent::UploadFinished => { if let Some(checkpoint) = validated_but_not_applied.take() { - let result = self.adapter.sync_local(&checkpoint, None)?; + let result = self.adapter.sync_local( + &self.state, + &checkpoint, + None, + &self.options.schema, + )?; match result { SyncLocalResult::ChangesApplied => { @@ -320,7 +337,9 @@ impl StreamingSyncIteration { ), )); }; - let result = self.adapter.sync_local(target, None)?; + let result = + self.adapter + .sync_local(&self.state, target, None, &self.options.schema)?; match result { SyncLocalResult::ChecksumFailure(checkpoint_result) => { @@ -363,7 +382,12 @@ impl StreamingSyncIteration { ), )); }; - let result = self.adapter.sync_local(target, Some(priority))?; + let result = self.adapter.sync_local( + &self.state, + target, + Some(priority), + &self.options.schema, + )?; match result { SyncLocalResult::ChecksumFailure(checkpoint_result) => { @@ -459,7 +483,7 @@ impl StreamingSyncIteration { raw_data: true, binary_data: true, client_id: client_id(self.db)?, - parameters: self.parameters.take(), + parameters: self.options.parameters.take(), }; event diff --git a/crates/core/src/sync_local.rs b/crates/core/src/sync_local.rs index f884e88..3d91a68 100644 --- a/crates/core/src/sync_local.rs +++ b/crates/core/src/sync_local.rs @@ -1,19 +1,25 @@ -use alloc::collections::BTreeSet; +use alloc::collections::btree_map::BTreeMap; use alloc::format; -use alloc::string::String; +use alloc::string::{String, ToString}; use alloc::vec::Vec; use serde::Deserialize; use crate::error::{PSResult, SQLiteError}; +use crate::schema::{PendingStatement, PendingStatementValue, RawTable, Schema}; +use crate::state::DatabaseState; use crate::sync::BucketPriority; use sqlite_nostd::{self as sqlite, Destructor, ManagedStmt, Value}; use sqlite_nostd::{ColumnType, Connection, ResultCode}; use crate::ext::SafeManagedStmt; -use crate::util::{internal_table_name, quote_internal_name}; +use crate::util::quote_internal_name; -pub fn sync_local(db: *mut sqlite::sqlite3, data: &V) -> Result { - let mut operation = SyncOperation::from_args(db, data)?; +pub fn sync_local( + state: &DatabaseState, + db: *mut sqlite::sqlite3, + data: &V, +) -> Result { + let mut operation: SyncOperation<'_> = SyncOperation::from_args(state, db, data)?; operation.apply() } @@ -26,14 +32,20 @@ pub struct PartialSyncOperation<'a> { } pub struct SyncOperation<'a> { + state: &'a DatabaseState, db: *mut sqlite::sqlite3, - data_tables: BTreeSet, + schema: ParsedDatabaseSchema<'a>, partial: Option>, } impl<'a> SyncOperation<'a> { - fn from_args(db: *mut sqlite::sqlite3, data: &'a V) -> Result { + fn from_args( + state: &'a DatabaseState, + db: *mut sqlite::sqlite3, + data: &'a V, + ) -> Result { Ok(Self::new( + state, db, match data.value_type() { ColumnType::Text => { @@ -60,14 +72,23 @@ impl<'a> SyncOperation<'a> { )) } - pub fn new(db: *mut sqlite::sqlite3, partial: Option>) -> Self { + pub fn new( + state: &'a DatabaseState, + db: *mut sqlite::sqlite3, + partial: Option>, + ) -> Self { Self { + state, db, - data_tables: BTreeSet::new(), + schema: ParsedDatabaseSchema::new(), partial, } } + pub fn use_schema(&mut self, schema: &'a Schema) { + self.schema.add_from_schema(schema); + } + fn can_apply_sync_changes(&self) -> Result { // Don't publish downloaded data until the upload queue is empty (except for downloaded data // in priority 0, which is published earlier). @@ -104,6 +125,8 @@ impl<'a> SyncOperation<'a> { } pub fn apply(&mut self) -> Result { + let guard = self.state.sync_local_guard(); + if !self.can_apply_sync_changes()? { return Ok(0); } @@ -126,48 +149,62 @@ impl<'a> SyncOperation<'a> { let id = statement.column_text(1)?; let data = statement.column_text(2); - let table_name = internal_table_name(type_name); - - if self.data_tables.contains(&table_name) { - let quoted = quote_internal_name(type_name, false); - - // is_err() is essentially a NULL check here. - // NULL data means no PUT operations found, so we delete the row. - if data.is_err() { - // DELETE - if last_delete_table.as_deref() != Some("ed) { - // Prepare statement when the table changed - last_delete_statement = Some( - self.db - .prepare_v2(&format!("DELETE FROM {} WHERE id = ?", quoted)) - .into_db_result(self.db)?, - ); - last_delete_table = Some(quoted.clone()); + if let Some(known) = self.schema.tables.get_mut(type_name) { + if let Some(raw) = &mut known.raw { + match data { + Ok(data) => { + let stmt = raw.put_statement(self.db)?; + let parsed: serde_json::Value = serde_json::from_str(data)?; + stmt.bind_for_put(id, &parsed)?; + stmt.stmt.exec()?; + } + Err(_) => { + let stmt = raw.delete_statement(self.db)?; + stmt.bind_for_delete(id)?; + stmt.stmt.exec()?; + } } - let delete_statement = last_delete_statement.as_mut().unwrap(); - - delete_statement.reset()?; - delete_statement.bind_text(1, id, sqlite::Destructor::STATIC)?; - delete_statement.exec()?; } else { - // INSERT/UPDATE - if last_insert_table.as_deref() != Some("ed) { - // Prepare statement when the table changed - last_insert_statement = Some( - self.db - .prepare_v2(&format!( - "REPLACE INTO {}(id, data) VALUES(?, ?)", - quoted - )) - .into_db_result(self.db)?, - ); - last_insert_table = Some(quoted.clone()); + let quoted = quote_internal_name(type_name, false); + + // is_err() is essentially a NULL check here. + // NULL data means no PUT operations found, so we delete the row. + if data.is_err() { + // DELETE + if last_delete_table.as_deref() != Some("ed) { + // Prepare statement when the table changed + last_delete_statement = Some( + self.db + .prepare_v2(&format!("DELETE FROM {} WHERE id = ?", quoted)) + .into_db_result(self.db)?, + ); + last_delete_table = Some(quoted.clone()); + } + let delete_statement = last_delete_statement.as_mut().unwrap(); + + delete_statement.reset()?; + delete_statement.bind_text(1, id, sqlite::Destructor::STATIC)?; + delete_statement.exec()?; + } else { + // INSERT/UPDATE + if last_insert_table.as_deref() != Some("ed) { + // Prepare statement when the table changed + last_insert_statement = Some( + self.db + .prepare_v2(&format!( + "REPLACE INTO {}(id, data) VALUES(?, ?)", + quoted + )) + .into_db_result(self.db)?, + ); + last_insert_table = Some(quoted.clone()); + } + let insert_statement = last_insert_statement.as_mut().unwrap(); + insert_statement.reset()?; + insert_statement.bind_text(1, id, sqlite::Destructor::STATIC)?; + insert_statement.bind_text(2, data?, sqlite::Destructor::STATIC)?; + insert_statement.exec()?; } - let insert_statement = last_insert_statement.as_mut().unwrap(); - insert_statement.reset()?; - insert_statement.bind_text(1, id, sqlite::Destructor::STATIC)?; - insert_statement.bind_text(2, data?, sqlite::Destructor::STATIC)?; - insert_statement.exec()?; } } else { if data.is_err() { @@ -210,23 +247,12 @@ impl<'a> SyncOperation<'a> { self.set_last_applied_op()?; self.mark_completed()?; + drop(guard); Ok(1) } fn collect_tables(&mut self) -> Result<(), SQLiteError> { - // language=SQLite - let statement = self - .db - .prepare_v2( - "SELECT name FROM sqlite_master WHERE type='table' AND name GLOB 'ps_data_*'", - ) - .into_db_result(self.db)?; - - while statement.step()? == ResultCode::ROW { - let name = statement.column_text(0)?; - self.data_tables.insert(String::from(name)); - } - Ok(()) + self.schema.add_from_db(self.db) } fn collect_full_operations(&self) -> Result { @@ -372,3 +398,175 @@ SELECT Ok(()) } } + +struct ParsedDatabaseSchema<'a> { + tables: BTreeMap>, +} + +impl<'a> ParsedDatabaseSchema<'a> { + fn new() -> Self { + Self { + tables: BTreeMap::new(), + } + } + + fn add_from_schema(&mut self, schema: &'a Schema) { + for raw in &schema.raw_tables { + self.tables + .insert(raw.name.clone(), ParsedSchemaTable::raw(raw)); + } + } + + fn add_from_db(&mut self, db: *mut sqlite::sqlite3) -> Result<(), SQLiteError> { + // language=SQLite + let statement = db + .prepare_v2( + "SELECT name FROM sqlite_master WHERE type='table' AND name GLOB 'ps_data_*'", + ) + .into_db_result(db)?; + + while statement.step()? == ResultCode::ROW { + let name = statement.column_text(0)?; + // Strip the ps_data__ prefix so that we can lookup tables by their sync protocol name. + let visible_name = name.get(9..).unwrap_or(name); + + // Tables which haven't been passed explicitly are assumed to not be raw tables. + self.tables + .insert(String::from(visible_name), ParsedSchemaTable::json_table()); + } + Ok(()) + } +} + +struct ParsedSchemaTable<'a> { + raw: Option>, +} + +struct RawTableWithCachedStatements<'a> { + definition: &'a RawTable, + cached_put: Option>, + cached_delete: Option>, +} + +impl<'a> RawTableWithCachedStatements<'a> { + fn put_statement( + &mut self, + db: *mut sqlite::sqlite3, + ) -> Result<&PreparedPendingStatement, SQLiteError> { + let cache_slot = &mut self.cached_put; + if let None = cache_slot { + let stmt = PreparedPendingStatement::prepare(db, &self.definition.put)?; + *cache_slot = Some(stmt); + } + + return Ok(cache_slot.as_ref().unwrap()); + } + + fn delete_statement( + &mut self, + db: *mut sqlite::sqlite3, + ) -> Result<&PreparedPendingStatement, SQLiteError> { + let cache_slot = &mut self.cached_delete; + if let None = cache_slot { + let stmt = PreparedPendingStatement::prepare(db, &self.definition.delete)?; + *cache_slot = Some(stmt); + } + + return Ok(cache_slot.as_ref().unwrap()); + } +} + +impl<'a> ParsedSchemaTable<'a> { + pub const fn json_table() -> Self { + Self { raw: None } + } + + pub fn raw(definition: &'a RawTable) -> Self { + Self { + raw: Some(RawTableWithCachedStatements { + definition, + cached_put: None, + cached_delete: None, + }), + } + } +} + +struct PreparedPendingStatement<'a> { + stmt: ManagedStmt, + params: &'a [PendingStatementValue], +} + +impl<'a> PreparedPendingStatement<'a> { + pub fn prepare( + db: *mut sqlite::sqlite3, + pending: &'a PendingStatement, + ) -> Result { + let stmt = db.prepare_v2(&pending.sql)?; + // TODO: Compare number of variables / other validity checks? + + Ok(Self { + stmt, + params: &pending.params, + }) + } + + pub fn bind_for_put(&self, id: &str, json_data: &serde_json::Value) -> Result<(), SQLiteError> { + use serde_json::Value; + for (i, source) in self.params.iter().enumerate() { + let i = (i + 1) as i32; + + match source { + PendingStatementValue::Id => { + self.stmt.bind_text(i, id, Destructor::STATIC)?; + } + PendingStatementValue::Column(column) => { + let parsed = json_data.as_object().ok_or_else(|| { + SQLiteError( + ResultCode::CONSTRAINT_DATATYPE, + Some("expected oplog data to be an object".to_string()), + ) + })?; + + match parsed.get(column) { + Some(Value::Bool(value)) => { + self.stmt.bind_int(i, if *value { 1 } else { 0 }) + } + Some(Value::Number(value)) => { + if let Some(value) = value.as_f64() { + // ??? there's no bind_double??? + self.stmt.bind_int64(i, value as i64) + } else if let Some(value) = value.as_u64() { + self.stmt.bind_int64(i, value as i64) + } else { + self.stmt.bind_int64(i, value.as_i64().unwrap()) + } + } + Some(Value::String(source)) => { + self.stmt.bind_text(i, &source, Destructor::STATIC) + } + _ => self.stmt.bind_null(i), + }?; + } + } + } + + Ok(()) + } + + pub fn bind_for_delete(&self, id: &str) -> Result<(), SQLiteError> { + for (i, source) in self.params.iter().enumerate() { + if let PendingStatementValue::Id = source { + self.stmt + .bind_text((i + 1) as i32, id, Destructor::STATIC)?; + } else { + return Err(SQLiteError( + ResultCode::MISUSE, + Some("Raw delete statement parameters must only reference id".to_string()), + )); + } + } + + Ok(()) + } +} diff --git a/crates/core/src/util.rs b/crates/core/src/util.rs index a9e0842..5e77768 100644 --- a/crates/core/src/util.rs +++ b/crates/core/src/util.rs @@ -32,10 +32,6 @@ pub fn quote_internal_name(name: &str, local_only: bool) -> String { } } -pub fn internal_table_name(name: &str) -> String { - return format!("ps_data__{}", name); -} - pub fn quote_identifier_prefixed(prefix: &str, name: &str) -> String { return format!("\"{:}{:}\"", prefix, name.replace("\"", "\"\"")); } diff --git a/dart/test/schema_test.dart b/dart/test/schema_test.dart index 18a4515..8d1646b 100644 --- a/dart/test/schema_test.dart +++ b/dart/test/schema_test.dart @@ -120,6 +120,36 @@ void main() { ); }); }); + + test('raw tables', () { + db.execute('SELECT powersync_replace_schema(?)', [ + json.encode({ + 'raw_tables': [ + { + 'name': 'users', + 'put': { + 'sql': 'INSERT OR REPLACE INTO users (id, name) VALUES (?, ?);', + 'params': [ + 'Id', + {'Column': 'name'} + ], + }, + 'delete': { + 'sql': 'DELETE FROM users WHERE id = ?', + 'params': ['Id'], + }, + } + ], + 'tables': [], + }) + ]); + + expect( + db.select( + "SELECT * FROM sqlite_schema WHERE type = 'table' AND name LIKE 'ps_data%'"), + isEmpty, + ); + }); }); } diff --git a/dart/test/sync_test.dart b/dart/test/sync_test.dart index fe78666..bc0c6fb 100644 --- a/dart/test/sync_test.dart +++ b/dart/test/sync_test.dart @@ -92,7 +92,7 @@ void _syncTests({ List pushSyncData( String bucket, String opId, String rowId, Object op, Object? data, - {int checksum = 0}) { + {int checksum = 0, String objectType = 'items'}) { return syncLine({ 'data': { 'bucket': bucket, @@ -103,7 +103,7 @@ void _syncTests({ { 'op_id': opId, 'op': op, - 'object_type': 'items', + 'object_type': objectType, 'object_id': rowId, 'checksum': checksum, 'data': json.encode(data), @@ -676,6 +676,155 @@ void _syncTests({ expect(db.select('SELECT * FROM ps_buckets'), isEmpty); }); }); + + syncTest('sets powersync_in_sync_operation', (_) { + var [row] = db.select('SELECT powersync_in_sync_operation() as r'); + expect(row, {'r': 0}); + + var testInSyncInvocations = []; + + db.createFunction( + functionName: 'test_in_sync', + function: (args) { + testInSyncInvocations.add((args[0] as int) != 0); + return null; + }, + argumentCount: const AllowedArgumentCount(1), + directOnly: false, + ); + + db.execute(''' +CREATE TRIGGER foo AFTER INSERT ON ps_data__items BEGIN + SELECT test_in_sync(powersync_in_sync_operation()); +END; +'''); + + // Run an insert sync iteration to start the trigger + invokeControl('start', null); + pushCheckpoint(buckets: [bucketDescription('a')]); + pushSyncData( + 'a', + '1', + '1', + 'PUT', + {'col': 'foo'}, + objectType: 'items', + ); + pushCheckpointComplete(); + + expect(testInSyncInvocations, [true]); + + [row] = db.select('SELECT powersync_in_sync_operation() as r'); + expect(row, {'r': 0}); + }); + + group('raw tables', () { + syncTest('smoke test', (_) { + db.execute( + 'CREATE TABLE users (id TEXT NOT NULL PRIMARY KEY, name TEXT NOT NULL) STRICT;'); + + invokeControl( + 'start', + json.encode({ + 'schema': { + 'raw_tables': [ + { + 'name': 'users', + 'put': { + 'sql': + 'INSERT OR REPLACE INTO users (id, name) VALUES (?, ?);', + 'params': [ + 'Id', + {'Column': 'name'} + ], + }, + 'delete': { + 'sql': 'DELETE FROM users WHERE id = ?', + 'params': ['Id'], + }, + } + ], + 'tables': [], + }, + }), + ); + + // Insert + pushCheckpoint(buckets: [bucketDescription('a')]); + pushSyncData( + 'a', + '1', + 'my_user', + 'PUT', + {'name': 'First user'}, + objectType: 'users', + ); + pushCheckpointComplete(); + + final users = db.select('SELECT * FROM users;'); + expect(users, [ + {'id': 'my_user', 'name': 'First user'} + ]); + + // Delete + pushCheckpoint(buckets: [bucketDescription('a')]); + pushSyncData( + 'a', + '1', + 'my_user', + 'REMOVE', + null, + objectType: 'users', + ); + pushCheckpointComplete(); + + expect(db.select('SELECT * FROM users'), isEmpty); + }); + + test("can't use crud vtab during sync", () { + db.execute( + 'CREATE TABLE users (id TEXT NOT NULL PRIMARY KEY, name TEXT NOT NULL) STRICT;'); + + invokeControl( + 'start', + json.encode({ + 'schema': { + 'raw_tables': [ + { + 'name': 'users', + 'put': { + // Inserting into powersync_crud_ during a sync operation is + // forbidden, that vtab should only collect local writes. + 'sql': "INSERT INTO powersync_crud_(data) VALUES (?);", + 'params': [ + {'Column': 'name'} + ], + }, + 'delete': { + 'sql': 'DELETE FROM users WHERE id = ?', + 'params': ['Id'], + }, + } + ], + 'tables': [], + }, + }), + ); + + // Insert + pushCheckpoint(buckets: [bucketDescription('a')]); + pushSyncData( + 'a', + '1', + 'my_user', + 'PUT', + {'name': 'First user'}, + objectType: 'users', + ); + + expect(pushCheckpointComplete, throwsA(isA())); + }); + }); } const _schema = {