diff --git a/src/config/mod.rs b/src/config/mod.rs index 89cbf49..39a8cd1 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -79,11 +79,14 @@ pub struct ServerConfig { pub max_in_flight_requests_per_connection: usize, pub max_request_bytes: usize, pub max_statements_per_request: usize, + pub statement_timeout_ms: Option, pub max_memory_intensive_requests: usize, pub max_scan_rows: usize, pub max_sort_rows: usize, pub max_join_rows: usize, pub max_query_result_rows: usize, + pub max_query_result_bytes: usize, + pub max_concurrent_queries_per_identity: Option, } impl Default for ServerConfig { @@ -94,11 +97,14 @@ impl Default for ServerConfig { max_in_flight_requests_per_connection: defaults.max_in_flight_requests_per_connection, max_request_bytes: defaults.max_request_bytes, max_statements_per_request: defaults.max_statements_per_request, + statement_timeout_ms: defaults.statement_timeout_ms, max_memory_intensive_requests: defaults.max_memory_intensive_requests, max_scan_rows: defaults.max_scan_rows, max_sort_rows: defaults.max_sort_rows, max_join_rows: defaults.max_join_rows, max_query_result_rows: defaults.max_query_result_rows, + max_query_result_bytes: defaults.max_query_result_bytes, + max_concurrent_queries_per_identity: defaults.max_concurrent_queries_per_identity, } } } @@ -278,11 +284,14 @@ pub struct StartupDiagnostics { pub server_max_in_flight_requests_per_connection: usize, pub server_max_request_bytes: usize, pub server_max_statements_per_request: usize, + pub server_statement_timeout_ms: Option, pub server_max_memory_intensive_requests: usize, pub server_max_scan_rows: usize, pub server_max_sort_rows: usize, pub server_max_join_rows: usize, pub server_max_query_result_rows: usize, + pub server_max_query_result_bytes: usize, + pub server_max_concurrent_queries_per_identity: Option, pub wal_segment_size_bytes: u64, pub wal_sync_mode: SyncModeConfig, pub sstable_data_block_size_bytes: usize, @@ -309,6 +318,10 @@ impl StartupDiagnostics { ), format!("server.max_request_bytes={}", self.server_max_request_bytes), format!("server.max_statements_per_request={}", self.server_max_statements_per_request), + format!( + "server.statement_timeout_ms={}", + format_optional_u64(self.server_statement_timeout_ms) + ), format!( "server.max_memory_intensive_requests={}", self.server_max_memory_intensive_requests @@ -317,6 +330,11 @@ impl StartupDiagnostics { format!("server.max_sort_rows={}", self.server_max_sort_rows), format!("server.max_join_rows={}", self.server_max_join_rows), format!("server.max_query_result_rows={}", self.server_max_query_result_rows), + format!("server.max_query_result_bytes={}", self.server_max_query_result_bytes), + format!( + "server.max_concurrent_queries_per_identity={}", + format_optional_usize(self.server_max_concurrent_queries_per_identity) + ), format!("wal.segment_size_bytes={}", self.wal_segment_size_bytes), format!("wal.sync_mode={}", self.wal_sync_mode.as_str()), format!("sstable.data_block_size_bytes={}", self.sstable_data_block_size_bytes), @@ -411,6 +429,9 @@ impl LsmdbConfig { if self.server.max_statements_per_request == 0 { return Err(invalid("server.max_statements_per_request", "must be > 0")); } + if matches!(self.server.statement_timeout_ms, Some(0)) { + return Err(invalid("server.statement_timeout_ms", "must be > 0 when set")); + } if self.server.max_memory_intensive_requests == 0 { return Err(invalid("server.max_memory_intensive_requests", "must be > 0")); } @@ -426,6 +447,15 @@ impl LsmdbConfig { if self.server.max_query_result_rows == 0 { return Err(invalid("server.max_query_result_rows", "must be > 0")); } + if self.server.max_query_result_bytes == 0 { + return Err(invalid("server.max_query_result_bytes", "must be > 0")); + } + if matches!(self.server.max_concurrent_queries_per_identity, Some(0)) { + return Err(invalid( + "server.max_concurrent_queries_per_identity", + "must be > 0 when set", + )); + } if self.wal.segment_size_bytes < MIN_WAL_SEGMENT_SIZE_BYTES { return Err(invalid( "wal.segment_size_bytes", @@ -491,6 +521,7 @@ impl LsmdbConfig { .max_in_flight_requests_per_connection, server_max_request_bytes: runtime.server_limits.max_request_bytes, server_max_statements_per_request: runtime.server_limits.max_statements_per_request, + server_statement_timeout_ms: runtime.server_limits.statement_timeout_ms, server_max_memory_intensive_requests: runtime .server_limits .max_memory_intensive_requests, @@ -498,6 +529,10 @@ impl LsmdbConfig { server_max_sort_rows: runtime.server_limits.max_sort_rows, server_max_join_rows: runtime.server_limits.max_join_rows, server_max_query_result_rows: runtime.server_limits.max_query_result_rows, + server_max_query_result_bytes: runtime.server_limits.max_query_result_bytes, + server_max_concurrent_queries_per_identity: runtime + .server_limits + .max_concurrent_queries_per_identity, wal_segment_size_bytes: storage.wal_options.segment_size_bytes, wal_sync_mode: SyncModeConfig::from(storage.wal_options.sync_mode), sstable_data_block_size_bytes: storage.sstable_builder_options.data_block_size_bytes, @@ -579,11 +614,14 @@ impl LsmdbConfig { .max_in_flight_requests_per_connection, max_request_bytes: self.server.max_request_bytes, max_statements_per_request: self.server.max_statements_per_request, + statement_timeout_ms: self.server.statement_timeout_ms, max_memory_intensive_requests: self.server.max_memory_intensive_requests, max_scan_rows: self.server.max_scan_rows, max_sort_rows: self.server.max_sort_rows, max_join_rows: self.server.max_join_rows, max_query_result_rows: self.server.max_query_result_rows, + max_query_result_bytes: self.server.max_query_result_bytes, + max_concurrent_queries_per_identity: self.server.max_concurrent_queries_per_identity, } } } @@ -592,6 +630,14 @@ fn invalid(field: &'static str, message: impl Into) -> ConfigError { ConfigError::InvalidValue { field, message: message.into() } } +fn format_optional_u64(value: Option) -> String { + value.map(|value| value.to_string()).unwrap_or_else(|| "none".to_string()) +} + +fn format_optional_usize(value: Option) -> String { + value.map(|value| value.to_string()).unwrap_or_else(|| "none".to_string()) +} + fn bloom_params_for_fpr(fpr: f64) -> (usize, u8) { let ln2 = std::f64::consts::LN_2; let bits = (-(fpr.ln()) / (ln2 * ln2)).ceil().max(1.0) as usize; diff --git a/src/executor/delete.rs b/src/executor/delete.rs index acc28fc..55f3a41 100644 --- a/src/executor/delete.rs +++ b/src/executor/delete.rs @@ -2,31 +2,36 @@ use crate::catalog::Catalog; use crate::mvcc::Transaction; use crate::planner::DeleteNode; -use super::ExecutionError; use super::filter::evaluate_predicate; use super::scan::scan_table_rows; +use super::{ExecutionContext, ExecutionError, apply_staged_writes}; pub(crate) fn execute_delete( catalog: &Catalog, tx: &mut Transaction, node: &DeleteNode, + context: &ExecutionContext<'_>, ) -> Result { + context.checkpoint()?; let table = catalog .get_table(&node.table) .ok_or_else(|| ExecutionError::TableNotFound(node.table.clone()))?; - let (_, rows) = scan_table_rows(catalog, tx, &table.name, usize::MAX)?; + let (_, rows) = scan_table_rows(catalog, tx, &table.name, usize::MAX, context)?; let mut affected = 0_u64; + let mut staged = std::collections::BTreeMap::new(); for stored in rows { + context.checkpoint()?; if let Some(predicate) = &node.predicate { if !evaluate_predicate(predicate, &stored.values, Some(&table.name))? { continue; } } - tx.delete(&stored.key)?; + staged.insert(stored.key, None); affected = affected.saturating_add(1); } + apply_staged_writes(tx, staged)?; Ok(affected) } diff --git a/src/executor/filter.rs b/src/executor/filter.rs index 64d52e7..b5e4e90 100644 --- a/src/executor/filter.rs +++ b/src/executor/filter.rs @@ -1,18 +1,24 @@ use crate::sql::ast::{BinaryOp, Expr, UnaryOp}; -use super::{ExecutionError, Row, RowSet, ScalarValue, literal_to_scalar, scalar_type_name}; - -pub(crate) fn apply_filter(input: RowSet, predicate: &Expr) -> Result { +use super::{ + ExecutionContext, ExecutionError, Row, RowSet, ScalarValue, literal_to_scalar, scalar_type_name, +}; + +pub(crate) fn apply_filter( + input: RowSet, + predicate: &Expr, + context: &ExecutionContext<'_>, +) -> Result { let RowSet { columns, rows, table_name } = input; - let filtered_rows = rows - .into_iter() - .filter_map(|row| match evaluate_predicate(predicate, &row, table_name.as_deref()) { - Ok(true) => Some(Ok(row)), - Ok(false) => None, - Err(err) => Some(Err(err)), - }) - .collect::, _>>()?; + let mut filtered_rows = Vec::new(); + for row in rows { + context.checkpoint()?; + match evaluate_predicate(predicate, &row, table_name.as_deref())? { + true => filtered_rows.push(row), + false => {} + } + } Ok(RowSet { columns, rows: filtered_rows, table_name }) } @@ -312,8 +318,17 @@ fn as_numeric(value: &ScalarValue) -> Result { #[cfg(test)] mod tests { use super::*; + use crate::executor::ExecutionLimits; + use crate::executor::governance::ExecutionGovernance; use crate::sql::ast::LiteralValue; + fn test_context() -> ExecutionContext<'static> { + ExecutionContext { + limits: Box::leak(Box::new(ExecutionLimits::default())), + governance: Box::leak(Box::new(ExecutionGovernance::default())), + } + } + #[test] fn evaluates_arithmetic_expression() { let expr = Expr::Binary { @@ -353,7 +368,7 @@ mod tests { right: Box::new(Expr::Literal(LiteralValue::Integer(2))), }; - let filtered = apply_filter(row_set, &predicate).expect("filter"); + let filtered = apply_filter(row_set, &predicate, &test_context()).expect("filter"); assert_eq!(filtered.rows.len(), 1); } } diff --git a/src/executor/governance.rs b/src/executor/governance.rs new file mode 100644 index 0000000..5539ace --- /dev/null +++ b/src/executor/governance.rs @@ -0,0 +1,125 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::time::{Duration, Instant}; + +use super::ExecutionError; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CancellationReason { + UserRequested = 1, +} + +impl CancellationReason { + pub fn as_str(self) -> &'static str { + match self { + CancellationReason::UserRequested => "user requested cancellation", + } + } +} + +#[derive(Debug, Default)] +struct CancellationState { + reason: AtomicU8, +} + +#[derive(Debug, Clone)] +pub struct StatementCancellation { + state: Arc, +} + +impl Default for StatementCancellation { + fn default() -> Self { + Self::new() + } +} + +impl StatementCancellation { + pub fn new() -> Self { + Self { state: Arc::new(CancellationState::default()) } + } + + pub fn cancel(&self) -> bool { + self.state + .reason + .compare_exchange( + 0, + CancellationReason::UserRequested as u8, + Ordering::SeqCst, + Ordering::SeqCst, + ) + .is_ok() + } + + pub fn reason(&self) -> Option { + match self.state.reason.load(Ordering::SeqCst) { + 0 => None, + 1 => Some(CancellationReason::UserRequested), + other => { + debug_assert_eq!(other, CancellationReason::UserRequested as u8); + Some(CancellationReason::UserRequested) + } + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct StatementDeadline { + deadline: Instant, + timeout: Duration, +} + +impl StatementDeadline { + pub fn after(timeout: Duration) -> Self { + Self { deadline: Instant::now() + timeout, timeout } + } + + pub fn is_elapsed(self) -> bool { + Instant::now() >= self.deadline + } + + pub fn timeout_ms(self) -> u64 { + let millis = self.timeout.as_millis(); + u64::try_from(millis).unwrap_or(u64::MAX) + } +} + +#[derive(Debug, Clone, Default)] +pub struct ExecutionGovernance { + deadline: Option, + cancellation: Option, +} + +impl ExecutionGovernance { + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.deadline = Some(StatementDeadline::after(timeout)); + self + } + + pub fn with_deadline(mut self, deadline: StatementDeadline) -> Self { + self.deadline = Some(deadline); + self + } + + pub fn with_cancellation(mut self, cancellation: StatementCancellation) -> Self { + self.cancellation = Some(cancellation); + self + } + + pub fn checkpoint(&self) -> Result<(), ExecutionError> { + if let Some(cancellation) = &self.cancellation { + if let Some(reason) = cancellation.reason() { + return Err(ExecutionError::StatementCanceled { reason: reason.as_str() }); + } + } + + if let Some(deadline) = self.deadline { + if deadline.is_elapsed() { + return Err(ExecutionError::StatementTimedOut { + timeout_ms: deadline.timeout_ms(), + }); + } + } + + Ok(()) + } +} diff --git a/src/executor/insert.rs b/src/executor/insert.rs index af70907..047e99b 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -6,14 +6,17 @@ use crate::planner::InsertNode; use super::filter::evaluate_const_expr; use super::{ - ExecutionError, Row, build_row_key, coerce_row_for_table, coerce_scalar_for_column, encode_row, + ExecutionContext, ExecutionError, Row, apply_staged_writes, build_row_key, + coerce_row_for_table, coerce_scalar_for_column, encode_row, staged_value_for_key, }; pub(crate) fn execute_insert( catalog: &Catalog, tx: &mut Transaction, node: &InsertNode, + context: &ExecutionContext<'_>, ) -> Result { + context.checkpoint()?; let table = catalog .get_table(&node.table) .ok_or_else(|| ExecutionError::TableNotFound(node.table.clone()))?; @@ -29,6 +32,7 @@ pub(crate) fn execute_insert( let mut seen = HashSet::new(); for (column_name, expr) in node.columns.iter().zip(node.values.iter()) { + context.checkpoint()?; if !seen.insert(column_name.clone()) { return Err(ExecutionError::DuplicateColumn(column_name.clone())); } @@ -44,12 +48,14 @@ pub(crate) fn execute_insert( let normalized = coerce_row_for_table(&table, &row)?; let key = build_row_key(&table, &normalized)?; - if tx.get(&key)?.is_some() { + let mut staged = std::collections::BTreeMap::new(); + if staged_value_for_key(tx, &staged, &key)?.is_some() { return Err(ExecutionError::PrimaryKeyConflict { table: table.name.clone() }); } let payload = encode_row(&table, &normalized)?; - tx.put(&key, &payload)?; + staged.insert(key, Some(payload)); + apply_staged_writes(tx, staged)?; Ok(1) } @@ -59,9 +65,18 @@ mod tests { use crate::catalog::column::ColumnDescriptor; use crate::catalog::schema::{ColumnType, DefaultValue}; use crate::catalog::table::TableDescriptor; + use crate::executor::ExecutionLimits; + use crate::executor::governance::ExecutionGovernance; use crate::mvcc::MvccStore; use crate::sql::ast::{Expr, LiteralValue}; + fn test_context() -> ExecutionContext<'static> { + ExecutionContext { + limits: Box::leak(Box::new(ExecutionLimits::default())), + governance: Box::leak(Box::new(ExecutionGovernance::default())), + } + } + #[test] fn inserts_row_with_default_values() { let store = MvccStore::new(); @@ -94,7 +109,7 @@ mod tests { }; let mut tx = store.begin_transaction(); - let affected = execute_insert(&catalog, &mut tx, &node).expect("insert"); + let affected = execute_insert(&catalog, &mut tx, &node, &test_context()).expect("insert"); assert_eq!(affected, 1); } } diff --git a/src/executor/join.rs b/src/executor/join.rs index 01d8eb0..84e1c84 100644 --- a/src/executor/join.rs +++ b/src/executor/join.rs @@ -1,29 +1,31 @@ use crate::sql::ast::Expr; use super::filter::evaluate_predicate; -use super::{ExecutionError, ExecutionLimits, Row, RowSet}; +use super::{ExecutionContext, ExecutionError, Row, RowSet}; pub(crate) fn execute_join( left: RowSet, right: RowSet, predicate: &Expr, - limits: &ExecutionLimits, + context: &ExecutionContext<'_>, ) -> Result { let right_prefix = right.table_name.clone().unwrap_or_else(|| "right".to_string()); let output_columns = build_join_columns(&left.columns, &right.columns, &right_prefix); - limits.ensure_join_rows(left.rows.len())?; - limits.ensure_join_rows(right.rows.len())?; + context.limits.ensure_join_rows(left.rows.len())?; + context.limits.ensure_join_rows(right.rows.len())?; let candidate_pairs = left.rows.len().saturating_mul(right.rows.len()); - limits.ensure_join_rows(candidate_pairs)?; + context.limits.ensure_join_rows(candidate_pairs)?; let mut joined_rows = Vec::new(); for left_row in &left.rows { + context.checkpoint()?; for right_row in &right.rows { + context.checkpoint()?; let merged = merge_rows(left_row, right_row, &right_prefix); if evaluate_predicate(predicate, &merged, None)? { joined_rows.push(merged); - limits.ensure_join_rows(joined_rows.len())?; + context.limits.ensure_join_rows(joined_rows.len())?; } } } @@ -57,10 +59,21 @@ fn merge_rows(left: &Row, right: &Row, right_prefix: &str) -> Row { #[cfg(test)] mod tests { + use std::thread; + use std::time::Duration; + use super::*; - use crate::executor::ScalarValue; + use crate::executor::governance::{ExecutionGovernance, StatementCancellation}; + use crate::executor::{ExecutionLimits, ScalarValue}; use crate::sql::ast::{BinaryOp, LiteralValue}; + fn test_context(limits: ExecutionLimits) -> ExecutionContext<'static> { + ExecutionContext { + limits: Box::leak(Box::new(limits)), + governance: Box::leak(Box::new(ExecutionGovernance::default())), + } + } + #[test] fn joins_rows_with_predicate() { let mut left_row = Row::new(); @@ -89,7 +102,8 @@ mod tests { }; let joined = - execute_join(left, right, &predicate, &ExecutionLimits::default()).expect("join"); + execute_join(left, right, &predicate, &test_context(ExecutionLimits::default())) + .expect("join"); assert_eq!(joined.rows.len(), 1); assert!(joined.rows[0].contains_key("profiles.id")); } @@ -127,7 +141,7 @@ mod tests { left, right, &predicate, - &ExecutionLimits { max_join_rows: 3, ..ExecutionLimits::default() }, + &test_context(ExecutionLimits { max_join_rows: 3, ..ExecutionLimits::default() }), ) .expect_err("join limit should fail"); assert!(matches!( @@ -136,4 +150,47 @@ mod tests { if resource == "join rows" && limit == 3 )); } + + #[test] + fn cancels_long_running_join() { + let mut left_rows = Vec::new(); + let mut right_rows = Vec::new(); + for id in 0..500_i64 { + let mut row = Row::new(); + row.insert("id".to_string(), ScalarValue::BigInt(id)); + left_rows.push(row.clone()); + right_rows.push(row); + } + + let left = RowSet { + columns: vec!["id".to_string()], + rows: left_rows, + table_name: Some("users".to_string()), + }; + let right = RowSet { + columns: vec!["id".to_string()], + rows: right_rows, + table_name: Some("profiles".to_string()), + }; + + let cancellation = StatementCancellation::new(); + let worker_cancellation = cancellation.clone(); + let handle = thread::spawn(move || { + let limits = Box::leak(Box::new(ExecutionLimits::default())); + let governance = Box::leak(Box::new( + ExecutionGovernance::default().with_cancellation(worker_cancellation), + )); + execute_join( + left, + right, + &Expr::Literal(LiteralValue::Boolean(true)), + &ExecutionContext { limits, governance }, + ) + }); + + thread::sleep(Duration::from_millis(1)); + cancellation.cancel(); + let err = handle.join().expect("join thread").expect_err("join should be canceled"); + assert!(matches!(err, ExecutionError::StatementCanceled { .. })); + } } diff --git a/src/executor/mod.rs b/src/executor/mod.rs index 1cd06c5..cd83eba 100644 --- a/src/executor/mod.rs +++ b/src/executor/mod.rs @@ -1,5 +1,6 @@ pub mod delete; pub mod filter; +pub mod governance; pub mod insert; pub mod join; pub mod projection; @@ -18,6 +19,8 @@ use crate::mvcc::{MvccStore, Transaction, TransactionError}; use crate::planner::PhysicalPlan; use crate::sql::ast::LiteralValue; +use self::governance::ExecutionGovernance; + #[derive(Debug, Clone, PartialEq)] pub enum ScalarValue { Integer(i32), @@ -34,6 +37,19 @@ impl ScalarValue { pub fn is_null(&self) -> bool { matches!(self, ScalarValue::Null) } + + fn estimated_query_bytes(&self) -> usize { + match self { + ScalarValue::Integer(_) => std::mem::size_of::(), + ScalarValue::BigInt(_) => std::mem::size_of::(), + ScalarValue::Float(_) => std::mem::size_of::(), + ScalarValue::Text(value) => value.len(), + ScalarValue::Boolean(_) => 1, + ScalarValue::Blob(value) => value.len(), + ScalarValue::Timestamp(_) => std::mem::size_of::(), + ScalarValue::Null => 0, + } + } } pub type Row = BTreeMap; @@ -91,6 +107,10 @@ pub enum ExecutionError { DdlInTransactionUnsupported, #[error("resource limit exceeded for {resource}: actual={actual}, limit={limit}")] ResourceLimitExceeded { resource: &'static str, actual: usize, limit: usize }, + #[error("statement timed out after {timeout_ms} ms")] + StatementTimedOut { timeout_ms: u64 }, + #[error("statement canceled: {reason}")] + StatementCanceled { reason: &'static str }, #[error("unsupported plan: {0}")] UnsupportedPlan(&'static str), } @@ -101,6 +121,7 @@ pub struct ExecutionLimits { pub max_sort_rows: usize, pub max_join_rows: usize, pub max_query_result_rows: usize, + pub max_query_result_bytes: usize, } impl Default for ExecutionLimits { @@ -110,6 +131,7 @@ impl Default for ExecutionLimits { max_sort_rows: usize::MAX, max_join_rows: usize::MAX, max_query_result_rows: usize::MAX, + max_query_result_bytes: usize::MAX, } } } @@ -142,6 +164,10 @@ impl ExecutionLimits { fn ensure_query_result_rows(&self, actual: usize) -> Result<(), ExecutionError> { self.ensure_within("query result rows", actual, self.max_query_result_rows) } + + fn ensure_query_result_bytes(&self, actual: usize) -> Result<(), ExecutionError> { + self.ensure_within("query result bytes", actual, self.max_query_result_bytes) + } } #[derive(Debug, Clone, PartialEq)] @@ -152,18 +178,29 @@ pub(crate) struct RowSet { } impl RowSet { - pub(crate) fn into_query_result(self) -> QueryResult { + pub(crate) fn into_query_result( + self, + context: &ExecutionContext<'_>, + ) -> Result { let mut materialized_rows = Vec::with_capacity(self.rows.len()); + let mut materialized_bytes = 0_usize; for row in self.rows { + context.checkpoint()?; let values = self .columns .iter() .map(|column| row.get(column).cloned().unwrap_or(ScalarValue::Null)) .collect::>(); + for value in &values { + materialized_bytes = + materialized_bytes.saturating_add(value.estimated_query_bytes()); + context.limits.ensure_query_result_bytes(materialized_bytes)?; + } materialized_rows.push(values); + context.limits.ensure_query_result_rows(materialized_rows.len())?; } - QueryResult { columns: self.columns, rows: materialized_rows } + Ok(QueryResult { columns: self.columns, rows: materialized_rows }) } } @@ -198,6 +235,17 @@ impl<'a> ExecutionSession<'a> { } pub fn execute_plan(&mut self, plan: &PhysicalPlan) -> Result { + self.execute_plan_with_governance(plan, &ExecutionGovernance::default()) + } + + pub fn execute_plan_with_governance( + &mut self, + plan: &PhysicalPlan, + governance: &ExecutionGovernance, + ) -> Result { + let limits = self.limits; + let context = ExecutionContext { limits: &limits, governance }; + context.checkpoint()?; match plan { PhysicalPlan::Begin(_) => { if self.active_tx.is_some() { @@ -232,13 +280,13 @@ impl<'a> ExecutionSession<'a> { Ok(ExecutionResult::AffectedRows(0)) } PhysicalPlan::Insert(node) => self - .with_write_tx(|catalog, tx| insert::execute_insert(catalog, tx, node)) + .with_write_tx(|catalog, tx| insert::execute_insert(catalog, tx, node, &context)) .map(ExecutionResult::AffectedRows), PhysicalPlan::Update(node) => self - .with_write_tx(|catalog, tx| update::execute_update(catalog, tx, node)) + .with_write_tx(|catalog, tx| update::execute_update(catalog, tx, node, &context)) .map(ExecutionResult::AffectedRows), PhysicalPlan::Delete(node) => self - .with_write_tx(|catalog, tx| delete::execute_delete(catalog, tx, node)) + .with_write_tx(|catalog, tx| delete::execute_delete(catalog, tx, node, &context)) .map(ExecutionResult::AffectedRows), PhysicalPlan::SeqScan(_) | PhysicalPlan::PrimaryKeyScan(_) @@ -246,14 +294,18 @@ impl<'a> ExecutionSession<'a> { | PhysicalPlan::Project(_) | PhysicalPlan::Sort(_) | PhysicalPlan::Limit(_) - | PhysicalPlan::Join(_) => { - let limits = self.limits; - self.with_read_tx(|catalog, tx| execute_query_plan(catalog, tx, plan, &limits)) - .and_then(|rows| { - limits.ensure_query_result_rows(rows.rows.len())?; - Ok(ExecutionResult::Query(rows.into_query_result())) - }) - } + | PhysicalPlan::Join(_) => self + .with_read_tx(|catalog, tx| execute_query_plan(catalog, tx, plan, &context)) + .and_then(|rows| Ok(ExecutionResult::Query(rows.into_query_result(&context)?))), + } + } + + pub fn abort_active_transaction(&mut self) -> bool { + if let Some(mut tx) = self.active_tx.take() { + tx.rollback(); + true + } else { + false } } @@ -298,36 +350,75 @@ fn execute_query_plan( catalog: &Catalog, tx: &mut Transaction, plan: &PhysicalPlan, - limits: &ExecutionLimits, + context: &ExecutionContext<'_>, ) -> Result { + context.checkpoint()?; match plan { - PhysicalPlan::SeqScan(node) => scan::execute_seq_scan(catalog, tx, node, limits), - PhysicalPlan::PrimaryKeyScan(node) => scan::execute_primary_key_scan(catalog, tx, node), + PhysicalPlan::SeqScan(node) => scan::execute_seq_scan(catalog, tx, node, context), + PhysicalPlan::PrimaryKeyScan(node) => { + scan::execute_primary_key_scan(catalog, tx, node, context) + } PhysicalPlan::Filter(node) => { - let input = execute_query_plan(catalog, tx, &node.input, limits)?; - filter::apply_filter(input, &node.predicate) + let input = execute_query_plan(catalog, tx, &node.input, context)?; + filter::apply_filter(input, &node.predicate, context) } PhysicalPlan::Project(node) => { - let input = execute_query_plan(catalog, tx, &node.input, limits)?; - projection::apply_projection(input, &node.projection) + let input = execute_query_plan(catalog, tx, &node.input, context)?; + projection::apply_projection(input, &node.projection, context) } PhysicalPlan::Sort(node) => { - let input = execute_query_plan(catalog, tx, &node.input, limits)?; - projection::apply_sort(input, &node.order_by, limits) + let input = execute_query_plan(catalog, tx, &node.input, context)?; + projection::apply_sort(input, &node.order_by, context) } PhysicalPlan::Limit(node) => { - let input = execute_query_plan(catalog, tx, &node.input, limits)?; - projection::apply_limit(input, node.limit) + let input = execute_query_plan(catalog, tx, &node.input, context)?; + projection::apply_limit(input, node.limit, context) } PhysicalPlan::Join(node) => { - let left = execute_query_plan(catalog, tx, &node.left, limits)?; - let right = execute_query_plan(catalog, tx, &node.right, limits)?; - join::execute_join(left, right, &node.on, limits) + let left = execute_query_plan(catalog, tx, &node.left, context)?; + let right = execute_query_plan(catalog, tx, &node.right, context)?; + join::execute_join(left, right, &node.on, context) } _ => Err(ExecutionError::UnsupportedPlan("non-query node used in query execution path")), } } +#[derive(Clone, Copy)] +pub(crate) struct ExecutionContext<'a> { + pub(crate) limits: &'a ExecutionLimits, + pub(crate) governance: &'a ExecutionGovernance, +} + +impl<'a> ExecutionContext<'a> { + pub(crate) fn checkpoint(&self) -> Result<(), ExecutionError> { + self.governance.checkpoint() + } +} + +pub(crate) fn staged_value_for_key( + tx: &Transaction, + staged: &BTreeMap, Option>>, + key: &[u8], +) -> Result>, ExecutionError> { + if let Some(value) = staged.get(key) { + return Ok(value.clone()); + } + Ok(tx.get(key)?) +} + +pub(crate) fn apply_staged_writes( + tx: &mut Transaction, + staged: BTreeMap, Option>>, +) -> Result<(), ExecutionError> { + for (key, value) in staged { + match value { + Some(value) => tx.put(&key, &value)?, + None => tx.delete(&key)?, + } + } + Ok(()) +} + fn create_statement_to_descriptor( create: &crate::sql::ast::CreateTableStatement, ) -> Result { diff --git a/src/executor/projection.rs b/src/executor/projection.rs index 98fd8d4..57eb0eb 100644 --- a/src/executor/projection.rs +++ b/src/executor/projection.rs @@ -3,11 +3,12 @@ use std::cmp::Ordering; use crate::sql::ast::{Expr, OrderByExpr, SelectItem, SortDirection}; use super::filter::evaluate_expr; -use super::{ExecutionError, ExecutionLimits, RowSet, ScalarValue}; +use super::{ExecutionContext, ExecutionError, RowSet, ScalarValue}; pub(crate) fn apply_projection( input: RowSet, projection: &[SelectItem], + context: &ExecutionContext<'_>, ) -> Result { if projection.len() == 1 && matches!(projection[0], SelectItem::Wildcard) { return Ok(input); @@ -23,6 +24,7 @@ pub(crate) fn apply_projection( let mut projected_rows = Vec::with_capacity(rows.len()); for row in rows { + context.checkpoint()?; let source = row.clone(); let mut materialized = row; for (index, item) in projection.iter().enumerate() { @@ -42,18 +44,19 @@ pub(crate) fn apply_projection( pub(crate) fn apply_sort( input: RowSet, order_by: &[OrderByExpr], - limits: &ExecutionLimits, + context: &ExecutionContext<'_>, ) -> Result { if order_by.is_empty() { return Ok(input); } let RowSet { columns, rows, table_name } = input; - limits.ensure_sort_rows(rows.len())?; + context.limits.ensure_sort_rows(rows.len())?; let mut keyed = rows .into_iter() .map(|row| { + context.checkpoint()?; let sort_keys = order_by .iter() .map(|entry| evaluate_expr(&entry.expr, &row, table_name.as_deref())) @@ -70,7 +73,12 @@ pub(crate) fn apply_sort( Ok(RowSet { columns, rows, table_name }) } -pub(crate) fn apply_limit(mut input: RowSet, limit: u64) -> Result { +pub(crate) fn apply_limit( + mut input: RowSet, + limit: u64, + context: &ExecutionContext<'_>, +) -> Result { + context.checkpoint()?; let limit = usize::try_from(limit).unwrap_or(usize::MAX); input.rows.truncate(limit); Ok(input) @@ -148,9 +156,18 @@ fn scalar_sort_tag(value: &ScalarValue) -> &'static str { #[cfg(test)] mod tests { use super::*; + use crate::executor::ExecutionLimits; + use crate::executor::governance::ExecutionGovernance; use crate::executor::{Row, ScalarValue}; use crate::sql::ast::BinaryOp; + fn test_context() -> ExecutionContext<'static> { + ExecutionContext { + limits: Box::leak(Box::new(ExecutionLimits::default())), + governance: Box::leak(Box::new(ExecutionGovernance::default())), + } + } + #[test] fn projects_expression_columns() { let mut row = Row::new(); @@ -170,6 +187,7 @@ mod tests { op: BinaryOp::Add, right: Box::new(Expr::Identifier("b".to_string())), })], + &test_context(), ) .expect("projection"); @@ -196,7 +214,7 @@ mod tests { expr: Expr::Identifier("id".to_string()), direction: SortDirection::Desc, }], - &ExecutionLimits::default(), + &test_context(), ) .expect("sort"); diff --git a/src/executor/scan.rs b/src/executor/scan.rs index 9355d2f..095d65c 100644 --- a/src/executor/scan.rs +++ b/src/executor/scan.rs @@ -4,7 +4,7 @@ use crate::mvcc::Transaction; use crate::planner::{PrimaryKeyScanNode, SeqScanNode}; use super::{ - ExecutionError, ExecutionLimits, Row, RowSet, StoredRow, build_row_key, + ExecutionContext, ExecutionError, Row, RowSet, StoredRow, build_row_key, coerce_scalar_for_column, decode_row, literal_to_scalar, table_rows_prefix, }; @@ -12,11 +12,12 @@ pub(crate) fn execute_seq_scan( catalog: &Catalog, tx: &Transaction, node: &SeqScanNode, - limits: &ExecutionLimits, + context: &ExecutionContext<'_>, ) -> Result { - let (_, stored_rows) = scan_table_rows(catalog, tx, &node.table, limits.max_scan_rows)?; + let (_, stored_rows) = + scan_table_rows(catalog, tx, &node.table, context.limits.max_scan_rows, context)?; let rows = stored_rows.into_iter().map(|stored| stored.values).collect::>(); - limits.ensure_scan_rows(rows.len())?; + context.limits.ensure_scan_rows(rows.len())?; Ok(RowSet { columns: node.output_columns.clone(), rows, table_name: Some(node.table.clone()) }) } @@ -25,7 +26,9 @@ pub(crate) fn execute_primary_key_scan( catalog: &Catalog, tx: &Transaction, node: &PrimaryKeyScanNode, + context: &ExecutionContext<'_>, ) -> Result { + context.checkpoint()?; let table = get_table(catalog, &node.table)?; let mut key_row = Row::new(); @@ -56,14 +59,41 @@ pub(crate) fn scan_table_rows( tx: &Transaction, table_name: &str, max_rows: usize, + context: &ExecutionContext<'_>, ) -> Result<(TableDescriptor, Vec), ExecutionError> { let table = get_table(catalog, table_name)?; let prefix = table_rows_prefix(&table.name); + let mut governance_error = None; let rows = if max_rows == usize::MAX { - tx.scan_prefix(&prefix)? + tx.scan_prefix_with_observer(&prefix, |seen| { + if seen == 0 { + return true; + } + match context.checkpoint() { + Ok(()) => true, + Err(err) => { + governance_error = Some(err); + false + } + } + })? } else { - tx.scan_prefix_limited(&prefix, max_rows)? + tx.scan_prefix_limited_with_observer(&prefix, max_rows, |seen| { + if seen == 0 { + return true; + } + match context.checkpoint() { + Ok(()) => true, + Err(err) => { + governance_error = Some(err); + false + } + } + })? }; + if let Some(err) = governance_error { + return Err(err); + } if rows.len() > max_rows { return Err(ExecutionError::ResourceLimitExceeded { @@ -73,10 +103,12 @@ pub(crate) fn scan_table_rows( }); } - let decoded_rows = rows - .into_iter() - .map(|(key, payload)| decode_row(&table, &payload).map(|values| StoredRow { key, values })) - .collect::, _>>()?; + let mut decoded_rows = Vec::with_capacity(rows.len()); + for (key, payload) in rows { + context.checkpoint()?; + let values = decode_row(&table, &payload)?; + decoded_rows.push(StoredRow { key, values }); + } Ok((table, decoded_rows)) } diff --git a/src/executor/update.rs b/src/executor/update.rs index a088202..282f6b5 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -5,21 +5,26 @@ use crate::planner::UpdateNode; use super::filter::{evaluate_expr, evaluate_predicate}; use super::scan::scan_table_rows; use super::{ - ExecutionError, build_row_key, coerce_row_for_table, coerce_scalar_for_column, encode_row, + ExecutionContext, ExecutionError, apply_staged_writes, build_row_key, coerce_row_for_table, + coerce_scalar_for_column, encode_row, staged_value_for_key, }; pub(crate) fn execute_update( catalog: &Catalog, tx: &mut Transaction, node: &UpdateNode, + context: &ExecutionContext<'_>, ) -> Result { + context.checkpoint()?; let table = catalog .get_table(&node.table) .ok_or_else(|| ExecutionError::TableNotFound(node.table.clone()))?; - let (_, rows) = scan_table_rows(catalog, tx, &table.name, usize::MAX)?; + let (_, rows) = scan_table_rows(catalog, tx, &table.name, usize::MAX, context)?; let mut affected = 0_u64; + let mut staged = std::collections::BTreeMap::new(); for stored in rows { + context.checkpoint()?; if let Some(predicate) = &node.predicate { if !evaluate_predicate(predicate, &stored.values, Some(&table.name))? { continue; @@ -43,16 +48,17 @@ pub(crate) fn execute_update( let normalized = coerce_row_for_table(&table, &updated)?; let new_key = build_row_key(&table, &normalized)?; if new_key != stored.key { - if tx.get(&new_key)?.is_some() { + if staged_value_for_key(tx, &staged, &new_key)?.is_some() { return Err(ExecutionError::PrimaryKeyConflict { table: table.name.clone() }); } - tx.delete(&stored.key)?; + staged.insert(stored.key.clone(), None); } let payload = encode_row(&table, &normalized)?; - tx.put(&new_key, &payload)?; + staged.insert(new_key, Some(payload)); affected = affected.saturating_add(1); } + apply_staged_writes(tx, staged)?; Ok(affected) } diff --git a/src/mvcc/transaction.rs b/src/mvcc/transaction.rs index 560ae39..6780465 100644 --- a/src/mvcc/transaction.rs +++ b/src/mvcc/transaction.rs @@ -222,28 +222,7 @@ impl MvccStore { } pub fn scan_prefix_at(&self, prefix: &[u8], read_ts: u64) -> Vec<(Vec, Vec)> { - let data = self.inner.data.read(); - let mut rows = Vec::new(); - - for (key, versions) in &data.versions { - if !key.starts_with(prefix) { - continue; - } - - for version in versions.iter().rev() { - if version.commit_ts > read_ts { - continue; - } - - if let Some(value) = &version.value { - rows.push((key.clone(), value.clone())); - } - break; - } - } - - rows.sort_by(|a, b| a.0.cmp(&b.0)); - rows + self.scan_prefix_at_with_observer(prefix, read_ts, |_| true) } pub fn scan_prefix_at_limited( @@ -252,6 +231,31 @@ impl MvccStore { read_ts: u64, max_rows: usize, ) -> Vec<(Vec, Vec)> { + self.scan_prefix_at_limited_with_observer(prefix, read_ts, max_rows, |_| true) + } + + pub fn scan_prefix_at_with_observer( + &self, + prefix: &[u8], + read_ts: u64, + observer: F, + ) -> Vec<(Vec, Vec)> + where + F: FnMut(usize) -> bool, + { + self.scan_prefix_at_limited_with_observer(prefix, read_ts, usize::MAX, observer) + } + + pub fn scan_prefix_at_limited_with_observer( + &self, + prefix: &[u8], + read_ts: u64, + max_rows: usize, + mut observer: F, + ) -> Vec<(Vec, Vec)> + where + F: FnMut(usize) -> bool, + { let data = self.inner.data.read(); let mut rows = Vec::new(); @@ -267,6 +271,10 @@ impl MvccStore { if let Some(value) = &version.value { rows.push((key.clone(), value.clone())); + if !observer(rows.len()) { + rows.sort_by(|a, b| a.0.cmp(&b.0)); + return rows; + } if rows.len() >= max_rows { rows.sort_by(|a, b| a.0.cmp(&b.0)); return rows; @@ -490,11 +498,34 @@ impl Transaction { Ok(visible.into_iter().collect()) } + pub fn scan_prefix_with_observer( + &self, + prefix: &[u8], + observer: F, + ) -> Result, Vec)>, TransactionError> + where + F: FnMut(usize) -> bool, + { + self.scan_prefix_limited_with_observer(prefix, usize::MAX, observer) + } + pub fn scan_prefix_limited( &self, prefix: &[u8], max_rows: usize, ) -> Result, Vec)>, TransactionError> { + self.scan_prefix_limited_with_observer(prefix, max_rows, |_| true) + } + + pub fn scan_prefix_limited_with_observer( + &self, + prefix: &[u8], + max_rows: usize, + mut observer: F, + ) -> Result, Vec)>, TransactionError> + where + F: FnMut(usize) -> bool, + { if self.closed { return Err(TransactionError::Closed); } @@ -504,10 +535,13 @@ impl Transaction { let fetch_limit = max_rows.saturating_add(write_overlap).saturating_add(1); let mut visible = self .store - .scan_prefix_at_limited(prefix, read_ts, fetch_limit) + .scan_prefix_at_limited_with_observer(prefix, read_ts, fetch_limit, |seen| { + observer(seen) + }) .into_iter() .collect::>(); + let mut observed = visible.len(); for (key, value) in &self.writes { if !key.starts_with(prefix) { continue; @@ -521,6 +555,10 @@ impl Transaction { visible.remove(key); } } + observed = observed.saturating_add(1); + if !observer(observed) { + break; + } } Ok(visible.into_iter().collect()) diff --git a/src/server/mod.rs b/src/server/mod.rs index d377324..8217401 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -2,10 +2,10 @@ pub mod protocol; pub mod tcp; pub use protocol::{ - AdminStatusPayload, ErrorCode, ErrorPayload, HealthPayload, PROTOCOL_VERSION, ProtocolError, - QueryPayload, ReadinessPayload, RequestFrame, RequestType, ResponseFrame, ResponsePayload, - TransactionState, read_request, read_request_with_limit, read_response, write_request, - write_response, + ActiveStatementPayload, ActiveStatementsPayload, AdminStatusPayload, ErrorCode, ErrorPayload, + HealthPayload, PROTOCOL_VERSION, ProtocolError, QueryPayload, ReadinessPayload, RequestFrame, + RequestType, ResponseFrame, ResponsePayload, StatementCancellationPayload, TransactionState, + read_request, read_request_with_limit, read_response, write_request, write_response, }; pub use tcp::{ ServerError, ServerHandle, ServerLimits, ServerOptions, start_server, start_server_with_options, diff --git a/src/server/protocol.rs b/src/server/protocol.rs index db31aab..d4b6203 100644 --- a/src/server/protocol.rs +++ b/src/server/protocol.rs @@ -18,6 +18,8 @@ pub enum RequestType { Health = 6, Readiness = 7, AdminStatus = 8, + ActiveStatements = 9, + CancelStatement = 10, } impl TryFrom for RequestType { @@ -33,6 +35,8 @@ impl TryFrom for RequestType { 6 => Ok(RequestType::Health), 7 => Ok(RequestType::Readiness), 8 => Ok(RequestType::AdminStatus), + 9 => Ok(RequestType::ActiveStatements), + 10 => Ok(RequestType::CancelStatement), other => { Err(ProtocolError::InvalidFrame(format!("unknown request type byte: {other}"))) } @@ -62,6 +66,9 @@ pub enum ErrorCode { Execution = 5, Busy = 6, ResourceLimit = 7, + Timeout = 8, + Canceled = 9, + Quota = 10, } impl TryFrom for ErrorCode { @@ -76,6 +83,9 @@ impl TryFrom for ErrorCode { 5 => Ok(ErrorCode::Execution), 6 => Ok(ErrorCode::Busy), 7 => Ok(ErrorCode::ResourceLimit), + 8 => Ok(ErrorCode::Timeout), + 9 => Ok(ErrorCode::Canceled), + 10 => Ok(ErrorCode::Quota), other => Err(ProtocolError::InvalidFrame(format!("unknown error code byte: {other}"))), } } @@ -91,6 +101,9 @@ impl ErrorCode { ErrorCode::Execution => "EXECUTION", ErrorCode::Busy => "BUSY", ErrorCode::ResourceLimit => "RESOURCE_LIMIT", + ErrorCode::Timeout => "TIMEOUT", + ErrorCode::Canceled => "CANCELED", + ErrorCode::Quota => "QUOTA", } } } @@ -117,6 +130,8 @@ pub enum ResponsePayload { Health(HealthPayload), Readiness(ReadinessPayload), AdminStatus(AdminStatusPayload), + ActiveStatements(ActiveStatementsPayload), + StatementCancellation(StatementCancellationPayload), } #[derive(Debug, Clone, PartialEq)] @@ -148,6 +163,10 @@ pub struct AdminStatusPayload { pub rejected_connections: u64, pub busy_requests: u64, pub resource_limit_requests: u64, + pub quota_rejections: u64, + pub timed_out_requests: u64, + pub canceled_requests: u64, + pub active_statements: u64, pub active_memory_intensive_requests: u64, pub mvcc_started: u64, pub mvcc_committed: u64, @@ -156,6 +175,29 @@ pub struct AdminStatusPayload { pub mvcc_active_transactions: u64, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ActiveStatementsPayload { + pub statements: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ActiveStatementPayload { + pub statement_id: u64, + pub connection_id: u64, + pub identity: String, + pub request_type: String, + pub runtime_ms: u64, + pub cancel_requested: bool, + pub sql_preview: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StatementCancellationPayload { + pub statement_id: u64, + pub accepted: bool, + pub status: String, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum TransactionState { @@ -443,6 +485,10 @@ fn encode_payload(payload: &ResponsePayload, out: &mut Vec) -> Result<(), Pr out.extend(status.rejected_connections.to_be_bytes()); out.extend(status.busy_requests.to_be_bytes()); out.extend(status.resource_limit_requests.to_be_bytes()); + out.extend(status.quota_rejections.to_be_bytes()); + out.extend(status.timed_out_requests.to_be_bytes()); + out.extend(status.canceled_requests.to_be_bytes()); + out.extend(status.active_statements.to_be_bytes()); out.extend(status.active_memory_intensive_requests.to_be_bytes()); out.extend(status.mvcc_started.to_be_bytes()); out.extend(status.mvcc_committed.to_be_bytes()); @@ -450,6 +496,25 @@ fn encode_payload(payload: &ResponsePayload, out: &mut Vec) -> Result<(), Pr out.extend(status.mvcc_write_conflicts.to_be_bytes()); out.extend(status.mvcc_active_transactions.to_be_bytes()); } + ResponsePayload::ActiveStatements(payload) => { + out.push(8_u8); + write_u32(out, payload.statements.len())?; + for statement in &payload.statements { + out.extend(statement.statement_id.to_be_bytes()); + out.extend(statement.connection_id.to_be_bytes()); + write_len_prefixed_bytes(out, statement.identity.as_bytes())?; + write_len_prefixed_bytes(out, statement.request_type.as_bytes())?; + out.extend(statement.runtime_ms.to_be_bytes()); + out.push(u8::from(statement.cancel_requested)); + write_len_prefixed_bytes(out, statement.sql_preview.as_bytes())?; + } + } + ResponsePayload::StatementCancellation(payload) => { + out.push(9_u8); + out.extend(payload.statement_id.to_be_bytes()); + out.push(u8::from(payload.accepted)); + write_len_prefixed_bytes(out, payload.status.as_bytes())?; + } } Ok(()) } @@ -518,6 +583,10 @@ fn decode_payload(cursor: &mut Cursor<&[u8]>) -> Result) -> Result) -> Result { + let count = read_u32(cursor)? as usize; + let mut statements = Vec::with_capacity(count); + for _ in 0..count { + let statement_id = read_u64(cursor)?; + let connection_id = read_u64(cursor)?; + let identity = read_len_prefixed_string(cursor)?; + let request_type = read_len_prefixed_string(cursor)?; + let runtime_ms = read_u64(cursor)?; + let cancel_requested = read_bool(cursor)?; + let sql_preview = read_len_prefixed_string(cursor)?; + statements.push(ActiveStatementPayload { + statement_id, + connection_id, + identity, + request_type, + runtime_ms, + cancel_requested, + sql_preview, + }); + } + Ok(ResponsePayload::ActiveStatements(ActiveStatementsPayload { statements })) + } + 9 => { + let statement_id = read_u64(cursor)?; + let accepted = read_bool(cursor)?; + let status = read_len_prefixed_string(cursor)?; + Ok(ResponsePayload::StatementCancellation(StatementCancellationPayload { + statement_id, + accepted, + status, + })) + } other => { Err(ProtocolError::InvalidFrame(format!("unknown response payload type: {other}"))) } @@ -687,6 +793,10 @@ mod tests { rejected_connections: 2, busy_requests: 3, resource_limit_requests: 1, + quota_rejections: 4, + timed_out_requests: 5, + canceled_requests: 6, + active_statements: 7, active_memory_intensive_requests: 0, mvcc_started: 12, mvcc_committed: 9, @@ -700,6 +810,26 @@ mod tests { assert_eq!(decoded, response); } + #[tokio::test] + async fn active_statements_payload_round_trip() { + let response = + ResponseFrame::Ok(ResponsePayload::ActiveStatements(ActiveStatementsPayload { + statements: vec![ActiveStatementPayload { + statement_id: 11, + connection_id: 3, + identity: "127.0.0.1".to_string(), + request_type: "QUERY".to_string(), + runtime_ms: 27, + cancel_requested: false, + sql_preview: "SELECT * FROM users".to_string(), + }], + })); + let (mut client, mut server) = tokio::io::duplex(2048); + write_response(&mut client, &response).await.expect("write response"); + let decoded = read_response(&mut server).await.expect("read response").expect("response"); + assert_eq!(decoded, response); + } + #[tokio::test] async fn request_frame_limit_rejects_oversized_body() { let request = RequestFrame { request_type: RequestType::Query, sql: "SELECT 1".repeat(64) }; diff --git a/src/server/tcp.rs b/src/server/tcp.rs index 98bc824..664dd30 100644 --- a/src/server/tcp.rs +++ b/src/server/tcp.rs @@ -1,8 +1,10 @@ +use std::collections::{BTreeMap, HashMap}; use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::time::Instant; +use std::time::{Duration, Instant}; +use parking_lot::Mutex; use thiserror::Error; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{OwnedSemaphorePermit, Semaphore, oneshot}; @@ -10,6 +12,7 @@ use tokio::task::JoinHandle; use tracing::{Instrument, debug, error, info, info_span, warn}; use crate::catalog::Catalog; +use crate::executor::governance::{ExecutionGovernance, StatementCancellation}; use crate::executor::{ExecutionError, ExecutionLimits, ExecutionSession}; use crate::mvcc::MvccStore; use crate::planner::{PhysicalPlan, PlannerError, plan_statement}; @@ -17,9 +20,10 @@ use crate::sql::parser::{ParseError, parse_sql}; use crate::sql::validator::{ValidationError, validate_statement}; use super::protocol::{ - AdminStatusPayload, ErrorCode, ErrorPayload, HealthPayload, PROTOCOL_VERSION, ProtocolError, - ReadinessPayload, RequestFrame, RequestType, ResponseFrame, ResponsePayload, - payload_from_execution_result, read_request_with_limit, write_response, + ActiveStatementPayload, ActiveStatementsPayload, AdminStatusPayload, ErrorCode, ErrorPayload, + HealthPayload, PROTOCOL_VERSION, ProtocolError, ReadinessPayload, RequestFrame, RequestType, + ResponseFrame, ResponsePayload, StatementCancellationPayload, payload_from_execution_result, + read_request_with_limit, write_response, }; static NEXT_CONNECTION_ID: AtomicU64 = AtomicU64::new(1); @@ -30,11 +34,14 @@ pub struct ServerLimits { pub max_in_flight_requests_per_connection: usize, pub max_request_bytes: usize, pub max_statements_per_request: usize, + pub statement_timeout_ms: Option, pub max_memory_intensive_requests: usize, pub max_scan_rows: usize, pub max_sort_rows: usize, pub max_join_rows: usize, pub max_query_result_rows: usize, + pub max_query_result_bytes: usize, + pub max_concurrent_queries_per_identity: Option, } impl Default for ServerLimits { @@ -44,11 +51,14 @@ impl Default for ServerLimits { max_in_flight_requests_per_connection: 1, max_request_bytes: 256 * 1024, max_statements_per_request: 16, + statement_timeout_ms: None, max_memory_intensive_requests: 8, max_scan_rows: 10_000, max_sort_rows: 10_000, max_join_rows: 10_000, max_query_result_rows: 10_000, + max_query_result_bytes: 4 * 1024 * 1024, + max_concurrent_queries_per_identity: None, } } } @@ -65,6 +75,7 @@ impl ServerLimits { ("max_sort_rows", self.max_sort_rows), ("max_join_rows", self.max_join_rows), ("max_query_result_rows", self.max_query_result_rows), + ("max_query_result_bytes", self.max_query_result_bytes), ] { if value == 0 { return Err(ServerError::InvalidConfiguration(format!( @@ -73,6 +84,18 @@ impl ServerLimits { } } + if matches!(self.statement_timeout_ms, Some(0)) { + return Err(ServerError::InvalidConfiguration( + "server limit 'statement_timeout_ms' must be > 0 when set".to_string(), + )); + } + if matches!(self.max_concurrent_queries_per_identity, Some(0)) { + return Err(ServerError::InvalidConfiguration( + "server limit 'max_concurrent_queries_per_identity' must be > 0 when set" + .to_string(), + )); + } + Ok(()) } @@ -82,6 +105,7 @@ impl ServerLimits { max_sort_rows: self.max_sort_rows, max_join_rows: self.max_join_rows, max_query_result_rows: self.max_query_result_rows, + max_query_result_bytes: self.max_query_result_bytes, } } } @@ -100,7 +124,13 @@ struct ServerRuntimeState { rejected_connections: AtomicU64, busy_requests: AtomicU64, resource_limit_requests: AtomicU64, + quota_rejections: AtomicU64, + timed_out_requests: AtomicU64, + canceled_requests: AtomicU64, active_memory_intensive_requests: AtomicU64, + next_statement_id: AtomicU64, + active_statements: Mutex>, + identity_query_counts: Mutex>, limits: ServerLimits, connection_slots: Arc, memory_intensive_slots: Arc, @@ -116,7 +146,13 @@ impl ServerRuntimeState { rejected_connections: AtomicU64::new(0), busy_requests: AtomicU64::new(0), resource_limit_requests: AtomicU64::new(0), + quota_rejections: AtomicU64::new(0), + timed_out_requests: AtomicU64::new(0), + canceled_requests: AtomicU64::new(0), active_memory_intensive_requests: AtomicU64::new(0), + next_statement_id: AtomicU64::new(1), + active_statements: Mutex::new(BTreeMap::new()), + identity_query_counts: Mutex::new(HashMap::new()), limits: options.limits, connection_slots: Arc::new(Semaphore::new(options.limits.max_concurrent_connections)), memory_intensive_slots: Arc::new(Semaphore::new( @@ -144,6 +180,145 @@ impl ServerRuntimeState { fn record_resource_limit_request(&self) { self.resource_limit_requests.fetch_add(1, Ordering::Relaxed); } + + fn record_quota_rejection(&self) { + self.quota_rejections.fetch_add(1, Ordering::Relaxed); + } + + fn record_timed_out_request(&self) { + self.timed_out_requests.fetch_add(1, Ordering::Relaxed); + } + + fn record_canceled_request(&self) { + self.canceled_requests.fetch_add(1, Ordering::Relaxed); + } + + fn begin_statement( + self: &Arc, + connection_id: u64, + identity: &str, + request_type: RequestType, + sql: &str, + ) -> Result { + if let Some(limit) = self.limits.max_concurrent_queries_per_identity { + let mut counts = self.identity_query_counts.lock(); + let current = counts.get(identity).copied().unwrap_or(0); + if current >= limit { + drop(counts); + return Err(RequestError::Quota(format!( + "identity '{identity}' exceeded concurrent query quota ({limit})" + ))); + } + counts.insert(identity.to_string(), current + 1); + } + + let statement_id = self.next_statement_id.fetch_add(1, Ordering::Relaxed); + let cancellation = StatementCancellation::new(); + let entry = ActiveStatementEntry { + statement_id, + connection_id, + identity: identity.to_string(), + request_type, + sql_preview: sql_preview(sql), + started_at: Instant::now(), + cancellation: cancellation.clone(), + }; + self.active_statements.lock().insert(statement_id, entry); + + Ok(StatementExecutionGuard { + runtime_state: Arc::clone(self), + statement_id, + identity: identity.to_string(), + cancellation, + }) + } + + fn finish_statement(&self, statement_id: u64, identity: &str) { + self.active_statements.lock().remove(&statement_id); + if self.limits.max_concurrent_queries_per_identity.is_some() { + let mut counts = self.identity_query_counts.lock(); + if let Some(current) = counts.get_mut(identity) { + if *current <= 1 { + counts.remove(identity); + } else { + *current -= 1; + } + } + } + } + + fn active_statement_payloads(&self) -> Vec { + self.active_statements + .lock() + .values() + .map(|entry| ActiveStatementPayload { + statement_id: entry.statement_id, + connection_id: entry.connection_id, + identity: entry.identity.clone(), + request_type: request_type_name(entry.request_type).to_string(), + runtime_ms: duration_to_millis(entry.started_at.elapsed()), + cancel_requested: entry.cancellation.reason().is_some(), + sql_preview: entry.sql_preview.clone(), + }) + .collect() + } + + fn cancel_statement(&self, statement_id: u64) -> StatementCancellationPayload { + let Some(statement) = self.active_statements.lock().get(&statement_id).cloned() else { + return StatementCancellationPayload { + statement_id, + accepted: false, + status: "statement not found".to_string(), + }; + }; + + let accepted = statement.cancellation.cancel(); + StatementCancellationPayload { + statement_id, + accepted, + status: if accepted { + "cancellation signaled".to_string() + } else { + "cancellation was already requested".to_string() + }, + } + } +} + +#[derive(Debug, Clone)] +struct ActiveStatementEntry { + statement_id: u64, + connection_id: u64, + identity: String, + request_type: RequestType, + sql_preview: String, + started_at: Instant, + cancellation: StatementCancellation, +} + +#[derive(Debug)] +struct StatementExecutionGuard { + runtime_state: Arc, + statement_id: u64, + identity: String, + cancellation: StatementCancellation, +} + +impl StatementExecutionGuard { + fn governance(&self, statement_timeout_ms: Option) -> ExecutionGovernance { + let mut governance = + ExecutionGovernance::default().with_cancellation(self.cancellation.clone()); + if let Some(timeout_ms) = statement_timeout_ms { + governance = governance.with_timeout(Duration::from_millis(timeout_ms)); + } + governance + } +} + +impl Drop for StatementExecutionGuard { + fn drop(&mut self) { + self.runtime_state.finish_statement(self.statement_id, &self.identity); + } } #[derive(Debug)] @@ -192,11 +367,13 @@ impl Drop for MemoryIntensiveRequestGuard { #[derive(Debug)] struct ConnectionContext { request_slots: Arc, + connection_id: u64, + identity: String, } impl ConnectionContext { - fn new(limit: usize) -> Self { - Self { request_slots: Arc::new(Semaphore::new(limit)) } + fn new(limit: usize, connection_id: u64, identity: String) -> Self { + Self { request_slots: Arc::new(Semaphore::new(limit)), connection_id, identity } } fn try_acquire_request_slot(&self) -> Option { @@ -228,6 +405,8 @@ enum RequestError { Busy(String), #[error("resource limit exceeded: {0}")] ResourceLimit(String), + #[error("quota exceeded: {0}")] + Quota(String), #[error("parse error: {0}")] Parse(#[from] ParseError), #[error("validation error: {0}")] @@ -248,6 +427,7 @@ impl RequestError { RequestError::ResourceLimit(message) => { ErrorPayload::new(ErrorCode::ResourceLimit, message, false) } + RequestError::Quota(message) => ErrorPayload::new(ErrorCode::Quota, message, true), RequestError::Parse(error) => { ErrorPayload::new(ErrorCode::Parse, error.to_string(), false) } @@ -261,6 +441,12 @@ impl RequestError { ExecutionError::ResourceLimitExceeded { .. } => { ErrorPayload::new(ErrorCode::ResourceLimit, error.to_string(), false) } + ExecutionError::StatementTimedOut { .. } => { + ErrorPayload::new(ErrorCode::Timeout, error.to_string(), false) + } + ExecutionError::StatementCanceled { .. } => { + ErrorPayload::new(ErrorCode::Canceled, error.to_string(), false) + } _ => ErrorPayload::new(ErrorCode::Execution, error.to_string(), false), }, } @@ -340,6 +526,7 @@ async fn run_accept_loop( accept_result = listener.accept() => { let (mut stream, peer_addr) = accept_result?; let connection_id = NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed); + let identity = peer_addr.ip().to_string(); runtime_state.total_connections.fetch_add(1, Ordering::Relaxed); let Some(connection_permit) = runtime_state.try_acquire_connection() else { runtime_state.record_rejected_connection(); @@ -372,6 +559,8 @@ async fn run_accept_loop( store, runtime_state, connection_permit, + connection_id, + identity, ) .await { @@ -391,10 +580,15 @@ async fn handle_connection( store: Arc, runtime_state: Arc, connection_permit: OwnedSemaphorePermit, + connection_id: u64, + identity: String, ) -> Result<(), ServerError> { let _connection_guard = ConnectionGuard::new(Arc::clone(&runtime_state), connection_permit); - let connection_context = - ConnectionContext::new(runtime_state.limits.max_in_flight_requests_per_connection); + let connection_context = ConnectionContext::new( + runtime_state.limits.max_in_flight_requests_per_connection, + connection_id, + identity, + ); let mut session = ExecutionSession::with_limits( catalog.as_ref(), store.as_ref(), @@ -446,24 +640,39 @@ async fn handle_connection( let sql_len = request.sql.len(); debug!(request_type = ?request_type, sql_len, "received request frame"); - let response = - match execute_request(&mut session, &catalog, &store, &runtime_state, request) { - Ok(payload) => { - debug!(request_type = ?request_type, "request handled successfully"); - ResponseFrame::Ok(payload) + let response = match execute_request( + &mut session, + &catalog, + &store, + &runtime_state, + &connection_context, + request, + ) { + Ok(payload) => { + debug!(request_type = ?request_type, "request handled successfully"); + ResponseFrame::Ok(payload) + } + Err(err) => { + warn!(request_type = ?request_type, error = %err, "request failed"); + let payload = err.into_error_payload(); + if payload.code == ErrorCode::Busy { + runtime_state.record_busy_request(); } - Err(err) => { - warn!(request_type = ?request_type, error = %err, "request failed"); - let payload = err.into_error_payload(); - if payload.code == ErrorCode::Busy { - runtime_state.record_busy_request(); - } - if payload.code == ErrorCode::ResourceLimit { - runtime_state.record_resource_limit_request(); - } - ResponseFrame::Err(payload) + if payload.code == ErrorCode::ResourceLimit { + runtime_state.record_resource_limit_request(); } - }; + if payload.code == ErrorCode::Quota { + runtime_state.record_quota_rejection(); + } + if payload.code == ErrorCode::Timeout { + runtime_state.record_timed_out_request(); + } + if payload.code == ErrorCode::Canceled { + runtime_state.record_canceled_request(); + } + ResponseFrame::Err(payload) + } + }; if let Err(err) = write_response(&mut stream, &response).await { error!(error = %err, "failed to write response"); @@ -477,6 +686,7 @@ fn execute_request( catalog: &Catalog, store: &MvccStore, runtime_state: &Arc, + connection_context: &ConnectionContext, request: RequestFrame, ) -> Result { debug!(request_type = ?request.request_type, "executing request"); @@ -487,17 +697,60 @@ fn execute_request( "query request requires non-empty SQL payload".to_string(), )); } - execute_sql(session, catalog, runtime_state, &request.sql) + let statement = runtime_state.begin_statement( + connection_context.connection_id, + &connection_context.identity, + RequestType::Query, + &request.sql, + )?; + execute_sql( + session, + catalog, + runtime_state, + &request.sql, + &statement.governance(runtime_state.limits.statement_timeout_ms), + ) } - RequestType::Begin => { - execute_sql(session, catalog, runtime_state, "BEGIN ISOLATION LEVEL SNAPSHOT") + RequestType::Begin => execute_sql( + session, + catalog, + runtime_state, + "BEGIN ISOLATION LEVEL SNAPSHOT", + &ExecutionGovernance::default(), + ), + RequestType::Commit => { + execute_sql(session, catalog, runtime_state, "COMMIT", &ExecutionGovernance::default()) + } + RequestType::Rollback => execute_sql( + session, + catalog, + runtime_state, + "ROLLBACK", + &ExecutionGovernance::default(), + ), + RequestType::Explain => { + let statement = runtime_state.begin_statement( + connection_context.connection_id, + &connection_context.identity, + RequestType::Explain, + &request.sql, + )?; + explain_sql( + catalog, + runtime_state, + &request.sql, + &statement.governance(runtime_state.limits.statement_timeout_ms), + ) } - RequestType::Commit => execute_sql(session, catalog, runtime_state, "COMMIT"), - RequestType::Rollback => execute_sql(session, catalog, runtime_state, "ROLLBACK"), - RequestType::Explain => explain_sql(catalog, runtime_state, &request.sql), RequestType::Health => Ok(health_payload()), RequestType::Readiness => Ok(readiness_payload(runtime_state)), RequestType::AdminStatus => Ok(admin_status_payload(store, runtime_state.as_ref())), + RequestType::ActiveStatements => { + Ok(ResponsePayload::ActiveStatements(ActiveStatementsPayload { + statements: runtime_state.active_statement_payloads(), + })) + } + RequestType::CancelStatement => cancel_statement(runtime_state, &request.sql), } } @@ -506,7 +759,9 @@ fn execute_sql( catalog: &Catalog, runtime_state: &Arc, sql: &str, + governance: &ExecutionGovernance, ) -> Result { + governance.checkpoint()?; let statements = parse_sql(sql)?; if statements.len() > runtime_state.limits.max_statements_per_request { return Err(RequestError::ResourceLimit(format!( @@ -518,10 +773,24 @@ fn execute_sql( let mut last_result = None; for statement in statements { + governance.checkpoint()?; validate_statement(catalog, &statement)?; + governance.checkpoint()?; let plan = plan_statement(catalog, &statement)?; let _memory_guard = acquire_memory_intensive_guard(runtime_state, &plan)?; - let result = session.execute_plan(&plan)?; + let result = match session.execute_plan_with_governance(&plan, governance) { + Ok(result) => result, + Err(err) => { + if matches!( + err, + ExecutionError::StatementTimedOut { .. } + | ExecutionError::StatementCanceled { .. } + ) { + session.abort_active_transaction(); + } + return Err(RequestError::Execution(err)); + } + }; last_result = Some(result); } @@ -535,7 +804,9 @@ fn explain_sql( catalog: &Catalog, runtime_state: &Arc, sql: &str, + governance: &ExecutionGovernance, ) -> Result { + governance.checkpoint()?; if sql.trim().is_empty() { return Err(RequestError::InvalidRequest( "explain request requires non-empty SQL payload".to_string(), @@ -558,7 +829,9 @@ fn explain_sql( let mut rendered = Vec::new(); for (index, statement) in statements.into_iter().enumerate() { + governance.checkpoint()?; validate_statement(catalog, &statement)?; + governance.checkpoint()?; let plan = plan_statement(catalog, &statement)?; rendered.push(format!("Statement {}:\n{plan:#?}", index + 1)); } @@ -570,6 +843,18 @@ fn health_payload() -> ResponsePayload { ResponsePayload::Health(HealthPayload { ok: true, status: "ok".to_string() }) } +fn cancel_statement( + runtime_state: &Arc, + statement_id: &str, +) -> Result { + let statement_id = statement_id.trim().parse::().map_err(|_| { + RequestError::InvalidRequest( + "cancel request requires a numeric statement id in the SQL payload".to_string(), + ) + })?; + Ok(ResponsePayload::StatementCancellation(runtime_state.cancel_statement(statement_id))) +} + fn readiness_payload(runtime_state: &Arc) -> ResponsePayload { let ready = runtime_state.accepting_connections.load(Ordering::Relaxed); let status = if ready { "ready" } else { "draining" }; @@ -616,6 +901,11 @@ fn admin_status_payload(store: &MvccStore, runtime_state: &ServerRuntimeState) - rejected_connections: runtime_state.rejected_connections.load(Ordering::Relaxed), busy_requests: runtime_state.busy_requests.load(Ordering::Relaxed), resource_limit_requests: runtime_state.resource_limit_requests.load(Ordering::Relaxed), + quota_rejections: runtime_state.quota_rejections.load(Ordering::Relaxed), + timed_out_requests: runtime_state.timed_out_requests.load(Ordering::Relaxed), + canceled_requests: runtime_state.canceled_requests.load(Ordering::Relaxed), + active_statements: u64::try_from(runtime_state.active_statements.lock().len()) + .unwrap_or(u64::MAX), active_memory_intensive_requests: runtime_state .active_memory_intensive_requests .load(Ordering::Relaxed), @@ -626,3 +916,34 @@ fn admin_status_payload(store: &MvccStore, runtime_state: &ServerRuntimeState) - mvcc_active_transactions: u64::try_from(metrics.active_transactions).unwrap_or(u64::MAX), }) } + +fn request_type_name(request_type: RequestType) -> &'static str { + match request_type { + RequestType::Query => "QUERY", + RequestType::Begin => "BEGIN", + RequestType::Commit => "COMMIT", + RequestType::Rollback => "ROLLBACK", + RequestType::Explain => "EXPLAIN", + RequestType::Health => "HEALTH", + RequestType::Readiness => "READINESS", + RequestType::AdminStatus => "ADMIN_STATUS", + RequestType::ActiveStatements => "ACTIVE_STATEMENTS", + RequestType::CancelStatement => "CANCEL_STATEMENT", + } +} + +fn sql_preview(sql: &str) -> String { + const MAX_PREVIEW_CHARS: usize = 160; + + let mut preview = sql.split_whitespace().collect::>().join(" "); + if preview.chars().count() > MAX_PREVIEW_CHARS { + preview = preview.chars().take(MAX_PREVIEW_CHARS).collect::(); + preview.push_str("..."); + } + preview +} + +fn duration_to_millis(duration: Duration) -> u64 { + let millis = duration.as_millis(); + u64::try_from(millis).unwrap_or(u64::MAX) +} diff --git a/tests/integration/server.rs b/tests/integration/server.rs index 7b114ed..4875e11 100644 --- a/tests/integration/server.rs +++ b/tests/integration/server.rs @@ -1,16 +1,21 @@ use std::net::SocketAddr; use std::str::from_utf8; use std::sync::Arc; +use std::time::Duration; use lsmdb::catalog::Catalog; +use lsmdb::executor::{ExecutionResult, ExecutionSession}; use lsmdb::mvcc::MvccStore; +use lsmdb::planner::plan_statement; use lsmdb::server::{ - AdminStatusPayload, ErrorCode, ErrorPayload, HealthPayload, PROTOCOL_VERSION, QueryPayload, - ReadinessPayload, RequestFrame, RequestType, ResponseFrame, ResponsePayload, ServerLimits, - ServerOptions, TransactionState, read_response, start_server, start_server_with_options, - write_request, + ActiveStatementsPayload, AdminStatusPayload, ErrorCode, ErrorPayload, HealthPayload, + PROTOCOL_VERSION, QueryPayload, ReadinessPayload, RequestFrame, RequestType, ResponseFrame, + ResponsePayload, ServerLimits, ServerOptions, StatementCancellationPayload, TransactionState, + read_response, start_server, start_server_with_options, write_request, }; +use lsmdb::sql::{parse_statement, validate_statement}; use tokio::net::TcpStream; +use tokio::time::sleep; async fn send_request(stream: &mut TcpStream, request: RequestFrame) -> ResponseFrame { write_request(stream, &request).await.expect("write request"); @@ -59,6 +64,66 @@ fn response_to_error(response: ResponseFrame) -> ErrorPayload { } } +fn response_to_active_statements(response: ResponseFrame) -> ActiveStatementsPayload { + match response { + ResponseFrame::Ok(ResponsePayload::ActiveStatements(payload)) => payload, + other => panic!("expected active statements payload, got {other:?}"), + } +} + +fn response_to_statement_cancellation(response: ResponseFrame) -> StatementCancellationPayload { + match response { + ResponseFrame::Ok(ResponsePayload::StatementCancellation(payload)) => payload, + other => panic!("expected statement cancellation payload, got {other:?}"), + } +} + +fn execute_setup_sql(catalog: &Catalog, store: &MvccStore, sql: &str) -> ExecutionResult { + let statement = parse_statement(sql).expect("parse setup SQL"); + validate_statement(catalog, &statement).expect("validate setup SQL"); + let plan = plan_statement(catalog, &statement).expect("plan setup SQL"); + let mut session = ExecutionSession::new(catalog, store); + session.execute_plan(&plan).expect("execute setup SQL") +} + +fn populate_users(catalog: &Catalog, store: &MvccStore, rows: usize) { + execute_setup_sql( + catalog, + store, + "CREATE TABLE users (id BIGINT NOT NULL, email TEXT NOT NULL, PRIMARY KEY (id))", + ); + for id in 1..=rows { + execute_setup_sql( + catalog, + store, + &format!( + "INSERT INTO users (id, email) VALUES ({id}, '{}')", + format!("user{id:05}@example.com") + ), + ); + } +} + +async fn wait_for_active_statement_id(stream: &mut TcpStream, request_type: &str) -> u64 { + for _ in 0..500 { + let payload = response_to_active_statements( + send_request( + stream, + RequestFrame { request_type: RequestType::ActiveStatements, sql: String::new() }, + ) + .await, + ); + if let Some(statement) = + payload.statements.into_iter().find(|statement| statement.request_type == request_type) + { + return statement.statement_id; + } + sleep(Duration::from_millis(1)).await; + } + + panic!("timed out waiting for active statement"); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn server_executes_query_requests_end_to_end() { let store = Arc::new(MvccStore::new()); @@ -286,6 +351,10 @@ async fn server_exposes_health_readiness_and_admin_status() { assert_eq!(admin.rejected_connections, 0); assert_eq!(admin.busy_requests, 0); assert_eq!(admin.resource_limit_requests, 0); + assert_eq!(admin.quota_rejections, 0); + assert_eq!(admin.timed_out_requests, 0); + assert_eq!(admin.canceled_requests, 0); + assert_eq!(admin.active_statements, 0); assert_eq!(admin.active_memory_intensive_requests, 0); assert_eq!(admin.mvcc_started, 0); assert_eq!(admin.mvcc_committed, 0); @@ -479,3 +548,219 @@ async fn server_enforces_statement_count_limit() { server.shutdown().await.expect("shutdown server"); } + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn server_times_out_long_running_scan_and_sort_queries() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + populate_users(&catalog, &store, 25_000); + + let options = ServerOptions { + limits: ServerLimits { statement_timeout_ms: Some(1), ..ServerLimits::default() }, + }; + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start server"); + let server_addr = server.local_addr(); + + let mut client = TcpStream::connect(server_addr).await.expect("connect client"); + + let scan_error = response_to_error( + send_request( + &mut client, + RequestFrame { + request_type: RequestType::Query, + sql: "SELECT id, email FROM users".to_string(), + }, + ) + .await, + ); + assert_eq!(scan_error.code, ErrorCode::Timeout); + assert!(scan_error.message.contains("timed out")); + + let sort_error = response_to_error( + send_request( + &mut client, + RequestFrame { + request_type: RequestType::Query, + sql: "SELECT id FROM users ORDER BY email DESC".to_string(), + }, + ) + .await, + ); + assert_eq!(sort_error.code, ErrorCode::Timeout); + assert!(sort_error.message.contains("timed out")); + + let admin = response_to_admin_status( + send_request( + &mut client, + RequestFrame { request_type: RequestType::AdminStatus, sql: String::new() }, + ) + .await, + ); + assert!(admin.timed_out_requests >= 2); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn server_rejects_queries_when_identity_quota_is_reached() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + populate_users(&catalog, &store, 60_000); + + let options = ServerOptions { + limits: ServerLimits { + max_concurrent_queries_per_identity: Some(1), + statement_timeout_ms: Some(5_000), + ..ServerLimits::default() + }, + }; + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start server"); + let server_addr = server.local_addr(); + + let mut query_client = TcpStream::connect(server_addr).await.expect("connect query client"); + let query_task = tokio::spawn(async move { + send_request( + &mut query_client, + RequestFrame { + request_type: RequestType::Query, + sql: "UPDATE users SET email = 'quota@example.com'".to_string(), + }, + ) + .await + }); + + let mut admin_client = TcpStream::connect(server_addr).await.expect("connect admin client"); + let statement_id = wait_for_active_statement_id(&mut admin_client, "QUERY").await; + + let mut second_client = TcpStream::connect(server_addr).await.expect("connect second client"); + let quota_error = response_to_error( + send_request( + &mut second_client, + RequestFrame { + request_type: RequestType::Query, + sql: "SELECT id FROM users ORDER BY email ASC".to_string(), + }, + ) + .await, + ); + assert_eq!(quota_error.code, ErrorCode::Quota); + assert!(quota_error.retryable); + + let cancel = response_to_statement_cancellation( + send_request( + &mut admin_client, + RequestFrame { + request_type: RequestType::CancelStatement, + sql: statement_id.to_string(), + }, + ) + .await, + ); + assert_eq!(cancel.statement_id, statement_id); + assert!(cancel.accepted); + + let canceled = response_to_error(query_task.await.expect("query task")); + assert_eq!(canceled.code, ErrorCode::Canceled); + + let admin = response_to_admin_status( + send_request( + &mut admin_client, + RequestFrame { request_type: RequestType::AdminStatus, sql: String::new() }, + ) + .await, + ); + assert!(admin.quota_rejections >= 1); + assert!(admin.canceled_requests >= 1); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn server_cancellation_rolls_back_active_transaction_state() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + populate_users(&catalog, &store, 50_000); + + let options = ServerOptions { + limits: ServerLimits { statement_timeout_ms: Some(5_000), ..ServerLimits::default() }, + }; + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start server"); + let server_addr = server.local_addr(); + + let mut txn_client = TcpStream::connect(server_addr).await.expect("connect txn client"); + let begin = send_request( + &mut txn_client, + RequestFrame { request_type: RequestType::Begin, sql: String::new() }, + ) + .await; + assert!(matches!( + begin, + ResponseFrame::Ok(ResponsePayload::TransactionState(TransactionState::Begun)) + )); + + let txn_task = tokio::spawn(async move { + let update = send_request( + &mut txn_client, + RequestFrame { + request_type: RequestType::Query, + sql: "UPDATE users SET email = 'blocked@example.com'".to_string(), + }, + ) + .await; + let commit = send_request( + &mut txn_client, + RequestFrame { request_type: RequestType::Commit, sql: String::new() }, + ) + .await; + (update, commit) + }); + + let mut admin_client = TcpStream::connect(server_addr).await.expect("connect admin client"); + let statement_id = wait_for_active_statement_id(&mut admin_client, "QUERY").await; + let cancel = response_to_statement_cancellation( + send_request( + &mut admin_client, + RequestFrame { + request_type: RequestType::CancelStatement, + sql: statement_id.to_string(), + }, + ) + .await, + ); + assert!(cancel.accepted); + + let (update, commit) = txn_task.await.expect("txn task"); + let update_error = response_to_error(update); + assert_eq!(update_error.code, ErrorCode::Canceled); + + let commit_error = response_to_error(commit); + assert_eq!(commit_error.code, ErrorCode::Execution); + assert!(commit_error.message.contains("no active transaction")); + + let mut verify_client = TcpStream::connect(server_addr).await.expect("connect verify client"); + let result = response_to_query( + send_request( + &mut verify_client, + RequestFrame { + request_type: RequestType::Query, + sql: "SELECT email FROM users WHERE id = 1".to_string(), + }, + ) + .await, + ); + assert_eq!(from_utf8(&result.rows[0][0]).expect("utf8 cell"), "user00001@example.com"); + + server.shutdown().await.expect("shutdown server"); +} diff --git a/tools/lsmdb-cli/main.rs b/tools/lsmdb-cli/main.rs index 572d53a..a632aff 100644 --- a/tools/lsmdb-cli/main.rs +++ b/tools/lsmdb-cli/main.rs @@ -21,7 +21,7 @@ async fn main() -> Result<(), Box> { println!("Connected to lsmdb server at {addr}"); println!( - "Type SQL to execute. Meta commands: \\help, \\q, \\timing, \\explain , \\health, \\ready, \\status" + "Type SQL to execute. Meta commands: \\help, \\q, \\timing, \\explain , \\health, \\ready, \\status, \\queries, \\cancel " ); let mut timing_enabled = false; @@ -181,6 +181,45 @@ async fn handle_meta_command( return Ok(ControlFlow::Continue); } + if input == "\\queries" { + let start = Instant::now(); + let response = send_request( + stream, + RequestFrame { request_type: RequestType::ActiveStatements, sql: String::new() }, + ) + .await?; + let elapsed = start.elapsed(); + render_response(response); + if *timing_enabled { + println!("Time: {:.3} ms", elapsed.as_secs_f64() * 1000.0); + } + return Ok(ControlFlow::Continue); + } + + if let Some(statement_id) = input.strip_prefix("\\cancel") { + let statement_id = statement_id.trim(); + if statement_id.is_empty() { + println!("Usage: \\cancel "); + return Ok(ControlFlow::Continue); + } + + let start = Instant::now(); + let response = send_request( + stream, + RequestFrame { + request_type: RequestType::CancelStatement, + sql: statement_id.to_string(), + }, + ) + .await?; + let elapsed = start.elapsed(); + render_response(response); + if *timing_enabled { + println!("Time: {:.3} ms", elapsed.as_secs_f64() * 1000.0); + } + return Ok(ControlFlow::Continue); + } + println!("Unknown command: {input}. Use \\help for available commands."); Ok(ControlFlow::Continue) } @@ -245,6 +284,10 @@ fn render_response(response: ResponseFrame) { println!("rejected_connections: {}", status.rejected_connections); println!("busy_requests: {}", status.busy_requests); println!("resource_limit_requests: {}", status.resource_limit_requests); + println!("quota_rejections: {}", status.quota_rejections); + println!("timed_out_requests: {}", status.timed_out_requests); + println!("canceled_requests: {}", status.canceled_requests); + println!("active_statements: {}", status.active_statements); println!( "active_memory_intensive_requests: {}", status.active_memory_intensive_requests @@ -255,6 +298,27 @@ fn render_response(response: ResponseFrame) { println!("mvcc_write_conflicts: {}", status.mvcc_write_conflicts); println!("mvcc_active_transactions: {}", status.mvcc_active_transactions); } + ResponsePayload::ActiveStatements(payload) => { + if payload.statements.is_empty() { + println!("No active statements"); + } else { + for statement in payload.statements { + println!("statement_id: {}", statement.statement_id); + println!("connection_id: {}", statement.connection_id); + println!("identity: {}", statement.identity); + println!("request_type: {}", statement.request_type); + println!("runtime_ms: {}", statement.runtime_ms); + println!("cancel_requested: {}", statement.cancel_requested); + println!("sql_preview: {}", statement.sql_preview); + println!(); + } + } + } + ResponsePayload::StatementCancellation(payload) => { + println!("statement_id: {}", payload.statement_id); + println!("accepted: {}", payload.accepted); + println!("status: {}", payload.status); + } }, ResponseFrame::Err(error) => { eprintln!( @@ -350,4 +414,6 @@ fn print_help() { println!(" \\health Request liveness status"); println!(" \\ready Request readiness status"); println!(" \\status Request admin runtime diagnostics"); + println!(" \\queries List active statements"); + println!(" \\cancel Signal cancellation for an active statement"); }