diff --git a/Cargo.lock b/Cargo.lock index 6335848..fd018f4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -53,6 +53,28 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "aws-lc-rs" +version = "1.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94bffc006df10ac2a68c83692d734a465f8ee6c5b384d8545a636f81d858f4bf" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.38.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4321e568ed89bb5a7d291a7f37997c2c0df89809d7b6d12062c81ddb54aa782e" +dependencies = [ + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "base64" version = "0.21.7" @@ -120,6 +142,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" dependencies = [ "find-msvc-tools", + "jobserver", + "libc", "shlex", ] @@ -181,6 +205,15 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" +[[package]] +name = "cmake" +version = "0.1.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" +dependencies = [ + "cc", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -288,6 +321,12 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "either" version = "1.15.0" @@ -307,7 +346,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -356,6 +395,23 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "getrandom" version = "0.3.4" @@ -477,7 +533,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" dependencies = [ "hermit-abi", "libc", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -495,6 +551,16 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + [[package]] name = "js-sys" version = "0.3.91" @@ -568,9 +634,11 @@ dependencies = [ "parking_lot", "proptest", "rusqlite", + "rustls-pemfile", "serde", "thiserror", "tokio", + "tokio-rustls", "toml", "tracing", "tracing-subscriber", @@ -615,7 +683,7 @@ checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", "wasi", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -634,7 +702,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -891,6 +959,20 @@ version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rusqlite" version = "0.31.0" @@ -915,7 +997,52 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls" +version = "0.23.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" +dependencies = [ + "aws-lc-rs", + "log", + "once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +dependencies = [ + "aws-lc-rs", + "ring", + "rustls-pki-types", + "untrusted", ] [[package]] @@ -1043,9 +1170,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.61.2", ] +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.117" @@ -1067,7 +1200,7 @@ dependencies = [ "getrandom 0.4.2", "once_cell", "rustix", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -1121,7 +1254,7 @@ dependencies = [ "pin-project-lite", "socket2", "tokio-macros", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -1135,6 +1268,16 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "toml" version = "0.8.23" @@ -1255,6 +1398,12 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "valuable" version = "0.1.1" @@ -1411,7 +1560,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -1420,6 +1569,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.61.2" @@ -1429,6 +1587,70 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + [[package]] name = "winnow" version = "0.7.15" @@ -1546,6 +1768,12 @@ dependencies = [ "syn", ] +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + [[package]] name = "zmij" version = "1.0.21" diff --git a/Cargo.toml b/Cargo.toml index 3d7bd65..842ef20 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -102,9 +102,11 @@ parking_lot = "0.12" serde = { version = "1.0", features = ["derive"] } thiserror = "2.0" tokio = { version = "1.40", features = ["rt-multi-thread", "macros", "net", "io-util", "sync", "time"] } +tokio-rustls = "0.26" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } toml = "0.8" +rustls-pemfile = "2.2" [dev-dependencies] criterion = "0.5" diff --git a/src/config/mod.rs b/src/config/mod.rs index 39a8cd1..5dc7dce 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -5,7 +5,10 @@ use std::time::Duration; use serde::Deserialize; use thiserror::Error; -use crate::server::ServerLimits; +use crate::server::{ + ServerAuthOptions, ServerLimits, ServerOptions, ServerRole, ServerSecurityOptions, + ServerTlsMode, ServerTlsOptions, StaticPasswordUser, StaticTokenPrincipal, +}; use crate::storage::compaction::{ CompactionStrategy, LeveledCompactionConfig, TieredCompactionConfig, }; @@ -29,28 +32,17 @@ pub enum ConfigError { InvalidValue { field: &'static str, message: String }, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Default)] #[serde(default, deny_unknown_fields)] pub struct LsmdbConfig { pub storage: StorageConfig, pub server: ServerConfig, + pub security: SecurityConfig, pub wal: WalConfig, pub sstable: SstableConfig, pub compaction: CompactionConfig, } -impl Default for LsmdbConfig { - fn default() -> Self { - Self { - storage: StorageConfig::default(), - server: ServerConfig::default(), - wal: WalConfig::default(), - sstable: SstableConfig::default(), - compaction: CompactionConfig::default(), - } - } -} - #[derive(Debug, Clone, Deserialize)] #[serde(default, deny_unknown_fields)] pub struct StorageConfig { @@ -109,6 +101,100 @@ impl Default for ServerConfig { } } +#[derive(Debug, Clone, Deserialize, Default)] +#[serde(default, deny_unknown_fields)] +pub struct SecurityConfig { + pub auth: SecurityAuthConfig, + pub tls: SecurityTlsConfig, +} + +#[derive(Debug, Clone, Deserialize, Default)] +#[serde(default, deny_unknown_fields)] +pub struct SecurityAuthConfig { + pub mode: AuthModeConfig, + pub users: Vec, + pub tokens: Vec, +} + +#[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq, Default)] +#[serde(rename_all = "snake_case")] +pub enum AuthModeConfig { + #[default] + Disabled, + Password, + Token, +} + +impl AuthModeConfig { + pub fn as_str(self) -> &'static str { + match self { + AuthModeConfig::Disabled => "disabled", + AuthModeConfig::Password => "password", + AuthModeConfig::Token => "token", + } + } +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct SecurityUserConfig { + pub username: String, + pub password: String, + pub role: RoleConfig, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct SecurityTokenConfig { + pub label: String, + pub token: String, + pub role: RoleConfig, +} + +#[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq, Default)] +#[serde(rename_all = "snake_case")] +pub enum RoleConfig { + #[default] + Reader, + Writer, + Admin, +} + +impl From for ServerRole { + fn from(value: RoleConfig) -> Self { + match value { + RoleConfig::Reader => ServerRole::Reader, + RoleConfig::Writer => ServerRole::Writer, + RoleConfig::Admin => ServerRole::Admin, + } + } +} + +#[derive(Debug, Clone, Deserialize, Default)] +#[serde(default, deny_unknown_fields)] +pub struct SecurityTlsConfig { + pub mode: TlsModeConfig, + pub cert_path: Option, + pub key_path: Option, +} + +#[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq, Default)] +#[serde(rename_all = "snake_case")] +pub enum TlsModeConfig { + #[default] + Disabled, + Required, +} + +impl TlsModeConfig { + pub fn as_str(self) -> &'static str { + match self { + TlsModeConfig::Disabled => "disabled", + TlsModeConfig::Required => "required", + } + } +} + #[derive(Debug, Clone, Deserialize)] #[serde(default, deny_unknown_fields)] pub struct WalConfig { @@ -186,7 +272,7 @@ impl Default for SstableConfig { } } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Default)] #[serde(default, deny_unknown_fields)] pub struct CompactionConfig { pub strategy: CompactionMode, @@ -194,16 +280,6 @@ pub struct CompactionConfig { pub tiered: TieredConfig, } -impl Default for CompactionConfig { - fn default() -> Self { - Self { - strategy: CompactionMode::default(), - leveled: LeveledConfig::default(), - tiered: TieredConfig::default(), - } - } -} - #[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq, Default)] #[serde(rename_all = "snake_case")] pub enum CompactionMode { @@ -257,6 +333,7 @@ pub struct RuntimeConfig { pub storage_engine: StorageEngineOptions, pub compaction_strategy: CompactionStrategy, pub server_limits: ServerLimits, + pub server_security: ServerSecurityOptions, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -292,6 +369,12 @@ pub struct StartupDiagnostics { pub server_max_query_result_rows: usize, pub server_max_query_result_bytes: usize, pub server_max_concurrent_queries_per_identity: Option, + pub security_auth_mode: AuthModeConfig, + pub security_tls_mode: TlsModeConfig, + pub security_user_count: usize, + pub security_token_count: usize, + pub security_tls_cert_path: Option, + pub security_tls_key_path: Option, pub wal_segment_size_bytes: u64, pub wal_sync_mode: SyncModeConfig, pub sstable_data_block_size_bytes: usize, @@ -335,6 +418,18 @@ impl StartupDiagnostics { "server.max_concurrent_queries_per_identity={}", format_optional_usize(self.server_max_concurrent_queries_per_identity) ), + format!("security.auth.mode={}", self.security_auth_mode.as_str()), + format!("security.tls.mode={}", self.security_tls_mode.as_str()), + format!("security.auth.users={}", self.security_user_count), + format!("security.auth.tokens={}", self.security_token_count), + format!( + "security.tls.cert_path={}", + format_optional_string(self.security_tls_cert_path.as_deref()) + ), + format!( + "security.tls.key_path={}", + format_optional_string(self.security_tls_key_path.as_deref()) + ), format!("wal.segment_size_bytes={}", self.wal_segment_size_bytes), format!("wal.sync_mode={}", self.wal_sync_mode.as_str()), format!("sstable.data_block_size_bytes={}", self.sstable_data_block_size_bytes), @@ -456,6 +551,87 @@ impl LsmdbConfig { "must be > 0 when set", )); } + match self.security.auth.mode { + AuthModeConfig::Disabled => {} + AuthModeConfig::Password => { + if self.security.auth.users.is_empty() { + return Err(invalid( + "security.auth.users", + "must contain at least one user when auth mode is 'password'", + )); + } + let mut seen = std::collections::BTreeSet::new(); + for user in &self.security.auth.users { + if user.username.trim().is_empty() { + return Err(invalid("security.auth.users.username", "must not be empty")); + } + if user.password.is_empty() { + return Err(invalid("security.auth.users.password", "must not be empty")); + } + if !seen.insert(user.username.as_str()) { + return Err(invalid( + "security.auth.users.username", + format!("duplicate username '{}'", user.username), + )); + } + } + } + AuthModeConfig::Token => { + if self.security.auth.tokens.is_empty() { + return Err(invalid( + "security.auth.tokens", + "must contain at least one token when auth mode is 'token'", + )); + } + let mut labels = std::collections::BTreeSet::new(); + let mut tokens = std::collections::BTreeSet::new(); + for token in &self.security.auth.tokens { + if token.label.trim().is_empty() { + return Err(invalid("security.auth.tokens.label", "must not be empty")); + } + if token.token.is_empty() { + return Err(invalid("security.auth.tokens.token", "must not be empty")); + } + if !labels.insert(token.label.as_str()) { + return Err(invalid( + "security.auth.tokens.label", + format!("duplicate token label '{}'", token.label), + )); + } + if !tokens.insert(token.token.as_str()) { + return Err(invalid( + "security.auth.tokens.token", + format!("duplicate token value for '{}'", token.label), + )); + } + } + } + } + if self.security.auth.mode != AuthModeConfig::Disabled + && self.security.tls.mode != TlsModeConfig::Required + { + return Err(invalid( + "security.tls.mode", + "must be 'required' when authentication is enabled", + )); + } + match self.security.tls.mode { + TlsModeConfig::Disabled => {} + TlsModeConfig::Required => { + let cert_path = self.security.tls.cert_path.as_ref().ok_or_else(|| { + invalid("security.tls.cert_path", "must be set when tls mode is 'required'") + })?; + let key_path = self.security.tls.key_path.as_ref().ok_or_else(|| { + invalid("security.tls.key_path", "must be set when tls mode is 'required'") + })?; + if cert_path.as_os_str().is_empty() { + return Err(invalid("security.tls.cert_path", "must not be empty")); + } + if key_path.as_os_str().is_empty() { + return Err(invalid("security.tls.key_path", "must not be empty")); + } + } + } if self.wal.segment_size_bytes < MIN_WAL_SEGMENT_SIZE_BYTES { return Err(invalid( "wal.segment_size_bytes", @@ -533,6 +709,22 @@ impl LsmdbConfig { server_max_concurrent_queries_per_identity: runtime .server_limits .max_concurrent_queries_per_identity, + security_auth_mode: self.security.auth.mode, + security_tls_mode: self.security.tls.mode, + security_user_count: self.security.auth.users.len(), + security_token_count: self.security.auth.tokens.len(), + security_tls_cert_path: self + .security + .tls + .cert_path + .as_ref() + .map(|path| path.display().to_string()), + security_tls_key_path: self + .security + .tls + .key_path + .as_ref() + .map(|path| path.display().to_string()), wal_segment_size_bytes: storage.wal_options.segment_size_bytes, wal_sync_mode: SyncModeConfig::from(storage.wal_options.sync_mode), sstable_data_block_size_bytes: storage.sstable_builder_options.data_block_size_bytes, @@ -550,6 +742,7 @@ impl LsmdbConfig { storage_engine: self.to_storage_engine_options_unchecked(), compaction_strategy: self.to_compaction_strategy_unchecked(), server_limits: self.to_server_limits_unchecked(), + server_security: self.to_server_security_options_unchecked(), }) } @@ -568,6 +761,19 @@ impl LsmdbConfig { Ok(self.to_server_limits_unchecked()) } + pub fn to_server_security_options(&self) -> Result { + self.validate()?; + Ok(self.to_server_security_options_unchecked()) + } + + pub fn to_server_options(&self) -> Result { + self.validate()?; + Ok(ServerOptions { + limits: self.to_server_limits_unchecked(), + security: self.to_server_security_options_unchecked(), + }) + } + fn to_storage_engine_options_unchecked(&self) -> StorageEngineOptions { let (bits_per_key, hash_functions) = bloom_params_for_fpr(self.sstable.bloom_fpr); @@ -624,6 +830,48 @@ impl LsmdbConfig { max_concurrent_queries_per_identity: self.server.max_concurrent_queries_per_identity, } } + + fn to_server_security_options_unchecked(&self) -> ServerSecurityOptions { + let auth = match self.security.auth.mode { + AuthModeConfig::Disabled => ServerAuthOptions::Disabled, + AuthModeConfig::Password => ServerAuthOptions::StaticPassword { + users: self + .security + .auth + .users + .iter() + .map(|user| StaticPasswordUser { + username: user.username.clone(), + password: user.password.clone(), + role: user.role.into(), + }) + .collect(), + }, + AuthModeConfig::Token => ServerAuthOptions::StaticToken { + principals: self + .security + .auth + .tokens + .iter() + .map(|token| StaticTokenPrincipal { + label: token.label.clone(), + token: token.token.clone(), + role: token.role.into(), + }) + .collect(), + }, + }; + let tls = ServerTlsOptions { + mode: match self.security.tls.mode { + TlsModeConfig::Disabled => ServerTlsMode::Disabled, + TlsModeConfig::Required => ServerTlsMode::Required, + }, + cert_path: self.security.tls.cert_path.clone(), + key_path: self.security.tls.key_path.clone(), + }; + + ServerSecurityOptions { auth, tls, allow_anonymous_access: false } + } } fn invalid(field: &'static str, message: impl Into) -> ConfigError { @@ -638,6 +886,10 @@ fn format_optional_usize(value: Option) -> String { value.map(|value| value.to_string()).unwrap_or_else(|| "none".to_string()) } +fn format_optional_string(value: Option<&str>) -> String { + value.map(str::to_string).unwrap_or_else(|| "none".to_string()) +} + fn bloom_params_for_fpr(fpr: f64) -> (usize, u8) { let ln2 = std::f64::consts::LN_2; let bits = (-(fpr.ln()) / (ln2 * ln2)).ceil().max(1.0) as usize; @@ -658,6 +910,11 @@ mod tests { use super::*; + const TLS_CERT_PATH: &str = + concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/tls/server.crt"); + const TLS_KEY_PATH: &str = + concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/tls/server.key"); + fn temp_file_path(label: &str) -> PathBuf { let mut path = std::env::temp_dir(); let nanos = SystemTime::now() @@ -745,6 +1002,40 @@ mod tests { } } + #[test] + fn parses_and_maps_security_config() { + let raw = format!( + r#" + [security.auth] + mode = "password" + + [[security.auth.users]] + username = "admin" + password = "secret" + role = "admin" + + [security.tls] + mode = "required" + cert_path = "{TLS_CERT_PATH}" + key_path = "{TLS_KEY_PATH}" + "# + ); + + let config = LsmdbConfig::from_toml_str(&raw).expect("parse security config"); + let options = config.to_server_options().expect("server options"); + match options.security.auth { + ServerAuthOptions::StaticPassword { users } => { + assert_eq!(users.len(), 1); + assert_eq!(users[0].username, "admin"); + assert_eq!(users[0].role, ServerRole::Admin); + } + other => panic!("expected static password auth, got {other:?}"), + } + assert_eq!(options.security.tls.mode, ServerTlsMode::Required); + assert_eq!(options.security.tls.cert_path.as_deref(), Some(Path::new(TLS_CERT_PATH))); + assert_eq!(options.security.tls.key_path.as_deref(), Some(Path::new(TLS_KEY_PATH))); + } + #[test] fn rejects_invalid_bloom_fpr() { let raw = r#" @@ -820,8 +1111,54 @@ mod tests { } #[test] - fn emits_startup_diagnostics_for_runtime_config() { + fn rejects_password_auth_without_users() { let raw = r#" + [security.auth] + mode = "password" + "#; + + let err = LsmdbConfig::from_toml_str(raw).expect_err("missing users should fail"); + assert!( + matches!(err, ConfigError::InvalidValue { field, .. } if field == "security.auth.users") + ); + } + + #[test] + fn rejects_password_auth_without_required_tls() { + let raw = r#" + [security.auth] + mode = "password" + + [[security.auth.users]] + username = "admin" + password = "secret" + role = "admin" + "#; + + let err = LsmdbConfig::from_toml_str(raw).expect_err("password auth should require tls"); + assert!( + matches!(err, ConfigError::InvalidValue { field, .. } if field == "security.tls.mode") + ); + } + + #[test] + fn rejects_tls_required_without_key_path() { + let raw = r#" + [security.tls] + mode = "required" + cert_path = "./server.crt" + "#; + + let err = LsmdbConfig::from_toml_str(raw).expect_err("missing key should fail"); + assert!( + matches!(err, ConfigError::InvalidValue { field, .. } if field == "security.tls.key_path") + ); + } + + #[test] + fn emits_startup_diagnostics_for_runtime_config() { + let raw = format!( + r#" [storage] memtable_size_bytes = 8192 memtable_arena_block_size_bytes = 4096 @@ -843,11 +1180,25 @@ mod tests { segment_size_bytes = 4096 sync_mode = "on_commit" + [security.auth] + mode = "token" + + [[security.auth.tokens]] + label = "ops-bot" + token = "opaque" + role = "writer" + + [security.tls] + mode = "required" + cert_path = "{TLS_CERT_PATH}" + key_path = "{TLS_KEY_PATH}" + [compaction] strategy = "leveled" - "#; + "# + ); - let config = LsmdbConfig::from_toml_str(raw).expect("parse valid config"); + let config = LsmdbConfig::from_toml_str(&raw).expect("parse valid config"); let diagnostics = config.startup_diagnostics().expect("startup diagnostics"); assert_eq!(diagnostics.memtable_size_bytes, 8192); assert_eq!(diagnostics.memtable_arena_block_size_bytes, 4096); @@ -856,11 +1207,17 @@ mod tests { assert_eq!(diagnostics.server_max_concurrent_connections, 24); assert_eq!(diagnostics.server_max_request_bytes, 32_768); assert_eq!(diagnostics.server_max_query_result_rows, 16); + assert_eq!(diagnostics.security_auth_mode, AuthModeConfig::Token); + assert_eq!(diagnostics.security_tls_mode, TlsModeConfig::Required); + assert_eq!(diagnostics.security_user_count, 0); + assert_eq!(diagnostics.security_token_count, 1); assert_eq!(diagnostics.wal_segment_size_bytes, 4096); assert_eq!(diagnostics.wal_sync_mode, SyncModeConfig::OnCommit); let lines = diagnostics.as_key_value_lines(); assert!(lines.iter().any(|line| line == "server.max_concurrent_connections=24")); + assert!(lines.iter().any(|line| line == "security.auth.mode=token")); + assert!(lines.iter().any(|line| line == "security.tls.mode=required")); assert!(lines.iter().any(|line| line == "compaction.strategy=leveled")); assert!(lines.iter().any(|line| line.starts_with("sstable.bloom_bits_per_key="))); } diff --git a/src/server/mod.rs b/src/server/mod.rs index 8217401..38f2a42 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -2,11 +2,15 @@ pub mod protocol; pub mod tcp; pub use protocol::{ - ActiveStatementPayload, ActiveStatementsPayload, AdminStatusPayload, ErrorCode, ErrorPayload, - HealthPayload, PROTOCOL_VERSION, ProtocolError, QueryPayload, ReadinessPayload, RequestFrame, - RequestType, ResponseFrame, ResponsePayload, StatementCancellationPayload, TransactionState, - read_request, read_request_with_limit, read_response, write_request, write_response, + ActiveStatementPayload, ActiveStatementsPayload, AdminStatusPayload, AuthenticationPayload, + AuthenticationRequest, ErrorCode, ErrorPayload, HealthPayload, PROTOCOL_VERSION, ProtocolError, + QueryPayload, ReadinessPayload, RequestFrame, RequestType, ResponseFrame, ResponsePayload, + StatementCancellationPayload, TransactionState, authentication_request_with_password, + authentication_request_with_token, decode_authentication_request, read_request, + read_request_with_limit, read_response, write_request, write_response, }; pub use tcp::{ - ServerError, ServerHandle, ServerLimits, ServerOptions, start_server, start_server_with_options, + ServerAuthOptions, ServerError, ServerHandle, ServerLimits, ServerOptions, ServerRole, + ServerSecurityOptions, ServerTlsMode, ServerTlsOptions, StaticPasswordUser, + StaticTokenPrincipal, start_server, start_server_with_options, }; diff --git a/src/server/protocol.rs b/src/server/protocol.rs index d4b6203..1c09df3 100644 --- a/src/server/protocol.rs +++ b/src/server/protocol.rs @@ -20,6 +20,7 @@ pub enum RequestType { AdminStatus = 8, ActiveStatements = 9, CancelStatement = 10, + Authenticate = 11, } impl TryFrom for RequestType { @@ -37,6 +38,7 @@ impl TryFrom for RequestType { 8 => Ok(RequestType::AdminStatus), 9 => Ok(RequestType::ActiveStatements), 10 => Ok(RequestType::CancelStatement), + 11 => Ok(RequestType::Authenticate), other => { Err(ProtocolError::InvalidFrame(format!("unknown request type byte: {other}"))) } @@ -50,6 +52,12 @@ pub struct RequestFrame { pub sql: String, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AuthenticationRequest { + Password { username: String, password: String }, + Token { token: String }, +} + #[derive(Debug, Clone, PartialEq)] pub enum ResponseFrame { Ok(ResponsePayload), @@ -69,6 +77,8 @@ pub enum ErrorCode { Timeout = 8, Canceled = 9, Quota = 10, + Unauthenticated = 11, + PermissionDenied = 12, } impl TryFrom for ErrorCode { @@ -86,6 +96,8 @@ impl TryFrom for ErrorCode { 8 => Ok(ErrorCode::Timeout), 9 => Ok(ErrorCode::Canceled), 10 => Ok(ErrorCode::Quota), + 11 => Ok(ErrorCode::Unauthenticated), + 12 => Ok(ErrorCode::PermissionDenied), other => Err(ProtocolError::InvalidFrame(format!("unknown error code byte: {other}"))), } } @@ -104,6 +116,8 @@ impl ErrorCode { ErrorCode::Timeout => "TIMEOUT", ErrorCode::Canceled => "CANCELED", ErrorCode::Quota => "QUOTA", + ErrorCode::Unauthenticated => "UNAUTHENTICATED", + ErrorCode::PermissionDenied => "PERMISSION_DENIED", } } } @@ -132,6 +146,7 @@ pub enum ResponsePayload { AdminStatus(AdminStatusPayload), ActiveStatements(ActiveStatementsPayload), StatementCancellation(StatementCancellationPayload), + Authentication(AuthenticationPayload), } #[derive(Debug, Clone, PartialEq)] @@ -198,6 +213,13 @@ pub struct StatementCancellationPayload { pub status: String, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AuthenticationPayload { + pub identity: String, + pub role: String, + pub auth_scheme: String, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum TransactionState { @@ -218,7 +240,7 @@ pub enum ProtocolError { Utf8(#[from] std::string::FromUtf8Error), } -pub async fn read_request( +pub async fn read_request( reader: &mut R, ) -> Result, ProtocolError> { let Some(body) = read_frame(reader, None).await? else { @@ -227,7 +249,7 @@ pub async fn read_request( decode_request(&body).map(Some) } -pub async fn read_request_with_limit( +pub async fn read_request_with_limit( reader: &mut R, max_body_bytes: usize, ) -> Result, ProtocolError> { @@ -237,7 +259,7 @@ pub async fn read_request_with_limit( decode_request(&body).map(Some) } -pub async fn write_request( +pub async fn write_request( writer: &mut W, request: &RequestFrame, ) -> Result<(), ProtocolError> { @@ -245,7 +267,7 @@ pub async fn write_request( write_frame(writer, &body).await } -pub async fn read_response( +pub async fn read_response( reader: &mut R, ) -> Result, ProtocolError> { let Some(body) = read_frame(reader, None).await? else { @@ -254,7 +276,7 @@ pub async fn read_response( decode_response(&body).map(Some) } -pub async fn write_response( +pub async fn write_response( writer: &mut W, response: &ResponseFrame, ) -> Result<(), ProtocolError> { @@ -262,6 +284,67 @@ pub async fn write_response( write_frame(writer, &body).await } +pub fn authentication_request_with_password( + username: impl Into, + password: impl Into, +) -> RequestFrame { + RequestFrame { + request_type: RequestType::Authenticate, + sql: encode_authentication_request(AuthenticationRequest::Password { + username: username.into(), + password: password.into(), + }), + } +} + +pub fn authentication_request_with_token(token: impl Into) -> RequestFrame { + RequestFrame { + request_type: RequestType::Authenticate, + sql: encode_authentication_request(AuthenticationRequest::Token { token: token.into() }), + } +} + +pub fn decode_authentication_request(payload: &str) -> Result { + let mut parts = payload.splitn(3, '\0'); + let scheme = parts.next().unwrap_or_default(); + let identity = parts.next().unwrap_or_default(); + let secret = parts.next().unwrap_or_default(); + + match scheme { + "password" => { + if identity.is_empty() { + return Err("password authentication requires a username".to_string()); + } + if secret.is_empty() { + return Err("password authentication requires a password".to_string()); + } + Ok(AuthenticationRequest::Password { + username: identity.to_string(), + password: secret.to_string(), + }) + } + "token" => { + if !identity.is_empty() { + return Err("token authentication does not accept a username".to_string()); + } + if secret.is_empty() { + return Err("token authentication requires a token".to_string()); + } + Ok(AuthenticationRequest::Token { token: secret.to_string() }) + } + _ => Err("authentication payload must use the 'password' or 'token' scheme".to_string()), + } +} + +fn encode_authentication_request(request: AuthenticationRequest) -> String { + match request { + AuthenticationRequest::Password { username, password } => { + format!("password\0{username}\0{password}") + } + AuthenticationRequest::Token { token } => format!("token\0\0{token}"), + } +} + pub fn payload_from_execution_result(result: &ExecutionResult) -> ResponsePayload { match result { ExecutionResult::Query(query) => ResponsePayload::Query(query_to_payload(query)), @@ -323,7 +406,7 @@ fn hex_char(value: u8) -> char { } } -async fn read_frame( +async fn read_frame( reader: &mut R, max_body_bytes: Option, ) -> Result>, ProtocolError> { @@ -358,7 +441,7 @@ async fn read_frame( Ok(Some(body)) } -async fn write_frame( +async fn write_frame( writer: &mut W, body: &[u8], ) -> Result<(), ProtocolError> { @@ -515,6 +598,12 @@ fn encode_payload(payload: &ResponsePayload, out: &mut Vec) -> Result<(), Pr out.push(u8::from(payload.accepted)); write_len_prefixed_bytes(out, payload.status.as_bytes())?; } + ResponsePayload::Authentication(payload) => { + out.push(10_u8); + write_len_prefixed_bytes(out, payload.identity.as_bytes())?; + write_len_prefixed_bytes(out, payload.role.as_bytes())?; + write_len_prefixed_bytes(out, payload.auth_scheme.as_bytes())?; + } } Ok(()) } @@ -648,6 +737,16 @@ fn decode_payload(cursor: &mut Cursor<&[u8]>) -> Result { + let identity = read_len_prefixed_string(cursor)?; + let role = read_len_prefixed_string(cursor)?; + let auth_scheme = read_len_prefixed_string(cursor)?; + Ok(ResponsePayload::Authentication(AuthenticationPayload { + identity, + role, + auth_scheme, + })) + } other => { Err(ProtocolError::InvalidFrame(format!("unknown response payload type: {other}"))) } @@ -830,6 +929,39 @@ mod tests { assert_eq!(decoded, response); } + #[tokio::test] + async fn authentication_payload_round_trip() { + let response = ResponseFrame::Ok(ResponsePayload::Authentication(AuthenticationPayload { + identity: "alice".to_string(), + role: "writer".to_string(), + auth_scheme: "password".to_string(), + })); + let (mut client, mut server) = tokio::io::duplex(1024); + write_response(&mut client, &response).await.expect("write response"); + let decoded = read_response(&mut server).await.expect("read response").expect("response"); + assert_eq!(decoded, response); + } + + #[test] + fn decodes_password_authentication_request() { + let request = decode_authentication_request("password\0alice\0secret") + .expect("decode password auth request"); + assert_eq!( + request, + AuthenticationRequest::Password { + username: "alice".to_string(), + password: "secret".to_string(), + } + ); + } + + #[test] + fn decodes_token_authentication_request() { + let request = + decode_authentication_request("token\0\0opaque-token").expect("decode token request"); + assert_eq!(request, AuthenticationRequest::Token { token: "opaque-token".to_string() }); + } + #[tokio::test] async fn request_frame_limit_rejects_oversized_body() { let request = RequestFrame { request_type: RequestType::Query, sql: "SELECT 1".repeat(64) }; diff --git a/src/server/tcp.rs b/src/server/tcp.rs index 664dd30..2bac569 100644 --- a/src/server/tcp.rs +++ b/src/server/tcp.rs @@ -1,14 +1,20 @@ use std::collections::{BTreeMap, HashMap}; +use std::fs::File; +use std::io::BufReader; use std::net::SocketAddr; +use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::time::{Duration, Instant}; use parking_lot::Mutex; use thiserror::Error; -use tokio::net::{TcpListener, TcpStream}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpListener; use tokio::sync::{OwnedSemaphorePermit, Semaphore, oneshot}; use tokio::task::JoinHandle; +use tokio_rustls::TlsAcceptor; +use tokio_rustls::rustls::{ServerConfig as RustlsServerConfig, pki_types::PrivateKeyDer}; use tracing::{Instrument, debug, error, info, info_span, warn}; use crate::catalog::Catalog; @@ -16,18 +22,139 @@ use crate::executor::governance::{ExecutionGovernance, StatementCancellation}; use crate::executor::{ExecutionError, ExecutionLimits, ExecutionSession}; use crate::mvcc::MvccStore; use crate::planner::{PhysicalPlan, PlannerError, plan_statement}; +use crate::sql::ast::Statement; use crate::sql::parser::{ParseError, parse_sql}; use crate::sql::validator::{ValidationError, validate_statement}; use super::protocol::{ - ActiveStatementPayload, ActiveStatementsPayload, AdminStatusPayload, ErrorCode, ErrorPayload, - HealthPayload, PROTOCOL_VERSION, ProtocolError, ReadinessPayload, RequestFrame, RequestType, - ResponseFrame, ResponsePayload, StatementCancellationPayload, payload_from_execution_result, + ActiveStatementPayload, ActiveStatementsPayload, AdminStatusPayload, AuthenticationPayload, + AuthenticationRequest, ErrorCode, ErrorPayload, HealthPayload, PROTOCOL_VERSION, ProtocolError, + ReadinessPayload, RequestFrame, RequestType, ResponseFrame, ResponsePayload, + StatementCancellationPayload, decode_authentication_request, payload_from_execution_result, read_request_with_limit, write_response, }; static NEXT_CONNECTION_ID: AtomicU64 = AtomicU64::new(1); +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ServerRole { + Admin, + Writer, + Reader, +} + +impl ServerRole { + pub fn as_str(self) -> &'static str { + match self { + ServerRole::Admin => "admin", + ServerRole::Writer => "writer", + ServerRole::Reader => "reader", + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StaticPasswordUser { + pub username: String, + pub password: String, + pub role: ServerRole, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StaticTokenPrincipal { + pub label: String, + pub token: String, + pub role: ServerRole, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub enum ServerAuthOptions { + #[default] + Disabled, + StaticPassword { + users: Vec, + }, + StaticToken { + principals: Vec, + }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum ServerTlsMode { + #[default] + Disabled, + Required, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct ServerTlsOptions { + pub mode: ServerTlsMode, + pub cert_path: Option, + pub key_path: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct ServerSecurityOptions { + pub auth: ServerAuthOptions, + pub tls: ServerTlsOptions, + pub allow_anonymous_access: bool, +} + +impl ServerSecurityOptions { + pub fn anonymous_for_local_dev() -> Self { + Self { allow_anonymous_access: true, ..Self::default() } + } + + fn validate(&self) -> Result<(), ServerError> { + match &self.auth { + ServerAuthOptions::Disabled => {} + ServerAuthOptions::StaticPassword { users } => { + if users.is_empty() { + return Err(ServerError::InvalidConfiguration( + "password authentication requires at least one user".to_string(), + )); + } + } + ServerAuthOptions::StaticToken { principals } => { + if principals.is_empty() { + return Err(ServerError::InvalidConfiguration( + "token authentication requires at least one token principal".to_string(), + )); + } + } + } + + if !matches!(self.auth, ServerAuthOptions::Disabled) + && self.tls.mode != ServerTlsMode::Required + { + return Err(ServerError::InvalidConfiguration( + "authentication requires tls mode 'required'".to_string(), + )); + } + if matches!(self.auth, ServerAuthOptions::Disabled) && !self.allow_anonymous_access { + return Err(ServerError::InvalidConfiguration( + "authentication is disabled; set allow_anonymous_access only for local development or tests" + .to_string(), + )); + } + + if self.tls.mode == ServerTlsMode::Required { + if self.tls.cert_path.is_none() { + return Err(ServerError::InvalidConfiguration( + "tls mode 'required' needs a certificate path".to_string(), + )); + } + if self.tls.key_path.is_none() { + return Err(ServerError::InvalidConfiguration( + "tls mode 'required' needs a private key path".to_string(), + )); + } + } + + Ok(()) + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ServerLimits { pub max_concurrent_connections: usize, @@ -110,9 +237,16 @@ impl ServerLimits { } } -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct ServerOptions { pub limits: ServerLimits, + pub security: ServerSecurityOptions, +} + +impl ServerOptions { + pub fn insecure_for_local_dev() -> Self { + Self { security: ServerSecurityOptions::anonymous_for_local_dev(), ..Self::default() } + } } #[derive(Debug)] @@ -132,6 +266,7 @@ struct ServerRuntimeState { active_statements: Mutex>, identity_query_counts: Mutex>, limits: ServerLimits, + security: ServerSecurityOptions, connection_slots: Arc, memory_intensive_slots: Arc, } @@ -154,6 +289,7 @@ impl ServerRuntimeState { active_statements: Mutex::new(BTreeMap::new()), identity_query_counts: Mutex::new(HashMap::new()), limits: options.limits, + security: options.security, connection_slots: Arc::new(Semaphore::new(options.limits.max_concurrent_connections)), memory_intensive_slots: Arc::new(Semaphore::new( options.limits.max_memory_intensive_requests, @@ -200,18 +336,6 @@ impl ServerRuntimeState { request_type: RequestType, sql: &str, ) -> Result { - if let Some(limit) = self.limits.max_concurrent_queries_per_identity { - let mut counts = self.identity_query_counts.lock(); - let current = counts.get(identity).copied().unwrap_or(0); - if current >= limit { - drop(counts); - return Err(RequestError::Quota(format!( - "identity '{identity}' exceeded concurrent query quota ({limit})" - ))); - } - counts.insert(identity.to_string(), current + 1); - } - let statement_id = self.next_statement_id.fetch_add(1, Ordering::Relaxed); let cancellation = StatementCancellation::new(); let entry = ActiveStatementEntry { @@ -233,8 +357,30 @@ impl ServerRuntimeState { }) } - fn finish_statement(&self, statement_id: u64, identity: &str) { + fn begin_identity_query( + self: &Arc, + identity: &str, + ) -> Result { + if let Some(limit) = self.limits.max_concurrent_queries_per_identity { + let mut counts = self.identity_query_counts.lock(); + let current = counts.get(identity).copied().unwrap_or(0); + if current >= limit { + drop(counts); + return Err(RequestError::Quota(format!( + "identity '{identity}' exceeded concurrent query quota ({limit})" + ))); + } + counts.insert(identity.to_string(), current + 1); + } + + Ok(IdentityQueryGuard { runtime_state: Arc::clone(self), identity: identity.to_string() }) + } + + fn finish_statement(&self, statement_id: u64, _identity: &str) { self.active_statements.lock().remove(&statement_id); + } + + fn finish_identity_query(&self, identity: &str) { if self.limits.max_concurrent_queries_per_identity.is_some() { let mut counts = self.identity_query_counts.lock(); if let Some(current) = counts.get_mut(identity) { @@ -321,6 +467,18 @@ impl Drop for StatementExecutionGuard { } } +#[derive(Debug)] +struct IdentityQueryGuard { + runtime_state: Arc, + identity: String, +} + +impl Drop for IdentityQueryGuard { + fn drop(&mut self) { + self.runtime_state.finish_identity_query(&self.identity); + } +} + #[derive(Debug)] struct ConnectionGuard { runtime_state: Arc, @@ -368,12 +526,23 @@ impl Drop for MemoryIntensiveRequestGuard { struct ConnectionContext { request_slots: Arc, connection_id: u64, - identity: String, + principal: AuthenticatedPrincipal, + peer_identity: String, } impl ConnectionContext { - fn new(limit: usize, connection_id: u64, identity: String) -> Self { - Self { request_slots: Arc::new(Semaphore::new(limit)), connection_id, identity } + fn new( + limit: usize, + connection_id: u64, + peer_identity: String, + principal: AuthenticatedPrincipal, + ) -> Self { + Self { + request_slots: Arc::new(Semaphore::new(limit)), + connection_id, + principal, + peer_identity, + } } fn try_acquire_request_slot(&self) -> Option { @@ -383,6 +552,21 @@ impl ConnectionContext { .ok() .map(|permit| ConnectionRequestGuard { _permit: permit }) } + + fn identity(&self) -> &str { + &self.principal.identity + } + + fn role(&self) -> ServerRole { + self.principal.role + } +} + +#[derive(Debug, Clone)] +struct AuthenticatedPrincipal { + identity: String, + role: ServerRole, + auth_scheme: &'static str, } #[derive(Debug, Error)] @@ -391,6 +575,8 @@ pub enum ServerError { Io(#[from] std::io::Error), #[error("protocol error: {0}")] Protocol(#[from] ProtocolError), + #[error("TLS error: {0}")] + Tls(String), #[error("invalid server configuration: {0}")] InvalidConfiguration(String), #[error("accept loop task failed: {0}")] @@ -407,6 +593,10 @@ enum RequestError { ResourceLimit(String), #[error("quota exceeded: {0}")] Quota(String), + #[error("unauthenticated: {0}")] + Unauthenticated(String), + #[error("permission denied: {0}")] + PermissionDenied(String), #[error("parse error: {0}")] Parse(#[from] ParseError), #[error("validation error: {0}")] @@ -428,6 +618,12 @@ impl RequestError { ErrorPayload::new(ErrorCode::ResourceLimit, message, false) } RequestError::Quota(message) => ErrorPayload::new(ErrorCode::Quota, message, true), + RequestError::Unauthenticated(message) => { + ErrorPayload::new(ErrorCode::Unauthenticated, message, false) + } + RequestError::PermissionDenied(message) => { + ErrorPayload::new(ErrorCode::PermissionDenied, message, false) + } RequestError::Parse(error) => { ErrorPayload::new(ErrorCode::Parse, error.to_string(), false) } @@ -453,6 +649,219 @@ impl RequestError { } } +fn load_tls_acceptor(options: &ServerTlsOptions) -> Result, ServerError> { + if options.mode == ServerTlsMode::Disabled { + return Ok(None); + } + + let cert_path = options.cert_path.as_ref().ok_or_else(|| { + ServerError::InvalidConfiguration( + "tls mode 'required' needs a certificate path".to_string(), + ) + })?; + let key_path = options.key_path.as_ref().ok_or_else(|| { + ServerError::InvalidConfiguration( + "tls mode 'required' needs a private key path".to_string(), + ) + })?; + + let mut cert_reader = BufReader::new(File::open(cert_path).map_err(|err| { + ServerError::Tls(format!("failed to open TLS certificate '{}': {err}", cert_path.display())) + })?); + let certificates = + rustls_pemfile::certs(&mut cert_reader).collect::, _>>().map_err(|err| { + ServerError::Tls(format!( + "failed to parse TLS certificate '{}': {err}", + cert_path.display() + )) + })?; + if certificates.is_empty() { + return Err(ServerError::Tls(format!( + "TLS certificate '{}' did not contain any certificate entries", + cert_path.display() + ))); + } + + let mut key_reader = BufReader::new(File::open(key_path).map_err(|err| { + ServerError::Tls(format!("failed to open TLS private key '{}': {err}", key_path.display())) + })?); + let private_key = rustls_pemfile::private_key(&mut key_reader) + .map_err(|err| { + ServerError::Tls(format!( + "failed to parse TLS private key '{}': {err}", + key_path.display() + )) + })? + .ok_or_else(|| { + ServerError::Tls(format!( + "TLS private key '{}' did not contain a supported key", + key_path.display() + )) + })?; + + let config = build_tls_server_config(certificates, private_key)?; + Ok(Some(TlsAcceptor::from(Arc::new(config)))) +} + +fn build_tls_server_config( + certificates: Vec>, + private_key: PrivateKeyDer<'static>, +) -> Result { + RustlsServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certificates, private_key) + .map_err(|err| ServerError::Tls(format!("failed to build TLS server config: {err}"))) +} + +async fn perform_authentication_handshake( + stream: &mut S, + runtime_state: &Arc, + connection_id: u64, + peer_identity: &str, +) -> Result, ServerError> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + match &runtime_state.security.auth { + ServerAuthOptions::Disabled => Ok(Some(AuthenticatedPrincipal { + identity: peer_identity.to_string(), + role: ServerRole::Admin, + auth_scheme: "disabled", + })), + auth_options => { + let request = match read_request_with_limit( + stream, + runtime_state.limits.max_request_bytes, + ) + .await + { + Ok(Some(request)) => request, + Ok(None) => return Ok(None), + Err(ProtocolError::FrameTooLarge { length, max }) => { + runtime_state.record_resource_limit_request(); + let response = ResponseFrame::Err(ErrorPayload::new( + ErrorCode::ResourceLimit, + format!( + "authentication frame too large: {length} bytes exceeds limit {max} bytes" + ), + false, + )); + let _ = write_response(stream, &response).await; + return Ok(None); + } + Err(err) => return Err(ServerError::Protocol(err)), + }; + + if request.request_type != RequestType::Authenticate { + warn!( + connection_id, + peer_identity, + request_type = ?request.request_type, + "rejecting unauthenticated request before session authentication" + ); + let response = ResponseFrame::Err(ErrorPayload::new( + ErrorCode::Unauthenticated, + "secure mode requires authentication before any other request".to_string(), + false, + )); + let _ = write_response(stream, &response).await; + return Ok(None); + } + + let auth_request = match decode_authentication_request(&request.sql) { + Ok(request) => request, + Err(message) => { + warn!(connection_id, peer_identity, error = %message, "invalid authentication payload"); + let response = ResponseFrame::Err(ErrorPayload::new( + ErrorCode::InvalidRequest, + message, + false, + )); + let _ = write_response(stream, &response).await; + return Ok(None); + } + }; + + match authenticate_principal(auth_options, auth_request) { + Ok(principal) => { + info!( + connection_id, + peer_identity, + identity = %principal.identity, + role = principal.role.as_str(), + auth_scheme = principal.auth_scheme, + "authenticated connection" + ); + let response = + ResponseFrame::Ok(ResponsePayload::Authentication(AuthenticationPayload { + identity: principal.identity.clone(), + role: principal.role.as_str().to_string(), + auth_scheme: principal.auth_scheme.to_string(), + })); + write_response(stream, &response).await?; + Ok(Some(principal)) + } + Err(err) => { + warn!(connection_id, peer_identity, error = %err, "authentication failed"); + let response = ResponseFrame::Err(err.into_error_payload()); + let _ = write_response(stream, &response).await; + Ok(None) + } + } + } + } +} + +fn authenticate_principal( + auth_options: &ServerAuthOptions, + request: AuthenticationRequest, +) -> Result { + match (auth_options, request) { + ( + ServerAuthOptions::StaticPassword { users }, + AuthenticationRequest::Password { username, password }, + ) => { + let user = users + .iter() + .find(|user| user.username == username && user.password == password) + .ok_or_else(|| { + RequestError::Unauthenticated( + "invalid username or password for secure server".to_string(), + ) + })?; + Ok(AuthenticatedPrincipal { + identity: user.username.clone(), + role: user.role, + auth_scheme: "password", + }) + } + (ServerAuthOptions::StaticToken { principals }, AuthenticationRequest::Token { token }) => { + let principal = + principals.iter().find(|principal| principal.token == token).ok_or_else(|| { + RequestError::Unauthenticated("invalid token for secure server".to_string()) + })?; + Ok(AuthenticatedPrincipal { + identity: principal.label.clone(), + role: principal.role, + auth_scheme: "token", + }) + } + (ServerAuthOptions::StaticPassword { .. }, AuthenticationRequest::Token { .. }) => { + Err(RequestError::Unauthenticated( + "secure server expects password authentication".to_string(), + )) + } + (ServerAuthOptions::StaticToken { .. }, AuthenticationRequest::Password { .. }) => Err( + RequestError::Unauthenticated("secure server expects token authentication".to_string()), + ), + (ServerAuthOptions::Disabled, _) => Ok(AuthenticatedPrincipal { + identity: "anonymous".to_string(), + role: ServerRole::Admin, + auth_scheme: "disabled", + }), + } +} + pub struct ServerHandle { local_addr: SocketAddr, shutdown_tx: Option>, @@ -495,6 +904,8 @@ pub async fn start_server_with_options( options: ServerOptions, ) -> Result { options.limits.validate()?; + options.security.validate()?; + let tls_acceptor = load_tls_acceptor(&options.security.tls)?; info!(%bind_addr, "starting tcp server"); let listener = TcpListener::bind(bind_addr).await?; let local_addr = listener.local_addr()?; @@ -503,7 +914,7 @@ pub async fn start_server_with_options( let runtime_state = Arc::new(ServerRuntimeState::new(options)); let (shutdown_tx, shutdown_rx) = oneshot::channel(); let task = tokio::spawn(async move { - run_accept_loop(listener, catalog, store, runtime_state, shutdown_rx).await + run_accept_loop(listener, catalog, store, runtime_state, tls_acceptor, shutdown_rx).await }); Ok(ServerHandle { local_addr, shutdown_tx: Some(shutdown_tx), task: Some(task) }) @@ -514,6 +925,7 @@ async fn run_accept_loop( catalog: Arc, store: Arc, runtime_state: Arc, + tls_acceptor: Option, mut shutdown_rx: oneshot::Receiver<()>, ) -> Result<(), ServerError> { loop { @@ -526,7 +938,7 @@ async fn run_accept_loop( accept_result = listener.accept() => { let (mut stream, peer_addr) = accept_result?; let connection_id = NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed); - let identity = peer_addr.ip().to_string(); + let peer_identity = peer_addr.ip().to_string(); runtime_state.total_connections.fetch_add(1, Ordering::Relaxed); let Some(connection_permit) = runtime_state.try_acquire_connection() else { runtime_state.record_rejected_connection(); @@ -551,19 +963,39 @@ async fn run_accept_loop( let catalog = Arc::clone(&catalog); let store = Arc::clone(&store); let runtime_state = Arc::clone(&runtime_state); + let tls_acceptor = tls_acceptor.clone(); let span = info_span!("connection", connection_id, %peer_addr); tokio::spawn(async move { - if let Err(err) = handle_connection( - stream, - catalog, - store, - runtime_state, - connection_permit, - connection_id, - identity, - ) - .await - { + let result = if let Some(acceptor) = tls_acceptor { + match acceptor.accept(stream).await { + Ok(tls_stream) => { + handle_connection( + tls_stream, + catalog, + store, + runtime_state, + connection_permit, + connection_id, + peer_identity, + ) + .await + } + Err(err) => Err(ServerError::Tls(err.to_string())), + } + } else { + handle_connection( + stream, + catalog, + store, + runtime_state, + connection_permit, + connection_id, + peer_identity, + ) + .await + }; + + if let Err(err) = result { warn!(error = %err, "connection task failed"); } }.instrument(span)); @@ -574,20 +1006,34 @@ async fn run_accept_loop( Ok(()) } -async fn handle_connection( - mut stream: TcpStream, +async fn handle_connection( + mut stream: S, catalog: Arc, store: Arc, runtime_state: Arc, connection_permit: OwnedSemaphorePermit, connection_id: u64, - identity: String, -) -> Result<(), ServerError> { + peer_identity: String, +) -> Result<(), ServerError> +where + S: AsyncRead + AsyncWrite + Unpin, +{ let _connection_guard = ConnectionGuard::new(Arc::clone(&runtime_state), connection_permit); + let Some(principal) = perform_authentication_handshake( + &mut stream, + &runtime_state, + connection_id, + &peer_identity, + ) + .await? + else { + return Ok(()); + }; let connection_context = ConnectionContext::new( runtime_state.limits.max_in_flight_requests_per_connection, connection_id, - identity, + peer_identity, + principal, ); let mut session = ExecutionSession::with_limits( catalog.as_ref(), @@ -640,6 +1086,28 @@ async fn handle_connection( let sql_len = request.sql.len(); debug!(request_type = ?request_type, sql_len, "received request frame"); + let identity_query_guard = + if matches!(request_type, RequestType::Query | RequestType::Explain) { + match runtime_state.begin_identity_query(connection_context.identity()) { + Ok(guard) => Some(guard), + Err(err) => { + warn!(request_type = ?request_type, error = %err, "request failed"); + let payload = err.into_error_payload(); + if payload.code == ErrorCode::Quota { + runtime_state.record_quota_rejection(); + } + let response = ResponseFrame::Err(payload); + if let Err(err) = write_response(&mut stream, &response).await { + error!(error = %err, "failed to write response"); + return Err(ServerError::Protocol(err)); + } + continue; + } + } + } else { + None + }; + let response = match execute_request( &mut session, &catalog, @@ -678,6 +1146,7 @@ async fn handle_connection( error!(error = %err, "failed to write response"); return Err(ServerError::Protocol(err)); } + drop(identity_query_guard); } } @@ -699,7 +1168,7 @@ fn execute_request( } let statement = runtime_state.begin_statement( connection_context.connection_id, - &connection_context.identity, + connection_context.identity(), RequestType::Query, &request.sql, )?; @@ -707,6 +1176,7 @@ fn execute_request( session, catalog, runtime_state, + connection_context, &request.sql, &statement.governance(runtime_state.limits.statement_timeout_ms), ) @@ -715,42 +1185,60 @@ fn execute_request( session, catalog, runtime_state, + connection_context, "BEGIN ISOLATION LEVEL SNAPSHOT", &ExecutionGovernance::default(), ), - RequestType::Commit => { - execute_sql(session, catalog, runtime_state, "COMMIT", &ExecutionGovernance::default()) - } + RequestType::Commit => execute_sql( + session, + catalog, + runtime_state, + connection_context, + "COMMIT", + &ExecutionGovernance::default(), + ), RequestType::Rollback => execute_sql( session, catalog, runtime_state, + connection_context, "ROLLBACK", &ExecutionGovernance::default(), ), RequestType::Explain => { let statement = runtime_state.begin_statement( connection_context.connection_id, - &connection_context.identity, + connection_context.identity(), RequestType::Explain, &request.sql, )?; explain_sql( catalog, runtime_state, + connection_context, &request.sql, &statement.governance(runtime_state.limits.statement_timeout_ms), ) } RequestType::Health => Ok(health_payload()), RequestType::Readiness => Ok(readiness_payload(runtime_state)), - RequestType::AdminStatus => Ok(admin_status_payload(store, runtime_state.as_ref())), + RequestType::AdminStatus => { + authorize_admin_request(connection_context, "admin status")?; + Ok(admin_status_payload(store, runtime_state.as_ref())) + } RequestType::ActiveStatements => { + authorize_admin_request(connection_context, "active statements")?; Ok(ResponsePayload::ActiveStatements(ActiveStatementsPayload { statements: runtime_state.active_statement_payloads(), })) } - RequestType::CancelStatement => cancel_statement(runtime_state, &request.sql), + RequestType::CancelStatement => { + authorize_admin_request(connection_context, "statement cancellation")?; + cancel_statement(runtime_state, &request.sql) + } + RequestType::Authenticate => { + Err(RequestError::InvalidRequest("connection is already authenticated".to_string())) + } } } @@ -758,11 +1246,13 @@ fn execute_sql( session: &mut ExecutionSession<'_>, catalog: &Catalog, runtime_state: &Arc, + connection_context: &ConnectionContext, sql: &str, governance: &ExecutionGovernance, ) -> Result { governance.checkpoint()?; let statements = parse_sql(sql)?; + authorize_sql_statements(connection_context, &statements)?; if statements.len() > runtime_state.limits.max_statements_per_request { return Err(RequestError::ResourceLimit(format!( "request contains {} statements, limit is {}", @@ -803,6 +1293,7 @@ fn execute_sql( fn explain_sql( catalog: &Catalog, runtime_state: &Arc, + connection_context: &ConnectionContext, sql: &str, governance: &ExecutionGovernance, ) -> Result { @@ -819,6 +1310,7 @@ fn explain_sql( "SQL payload produced no executable statement".to_string(), )); } + authorize_sql_statements(connection_context, &statements)?; if statements.len() > runtime_state.limits.max_statements_per_request { return Err(RequestError::ResourceLimit(format!( "request contains {} statements, limit is {}", @@ -861,6 +1353,89 @@ fn readiness_payload(runtime_state: &Arc) -> ResponsePayload ResponsePayload::Readiness(ReadinessPayload { ready, status: status.to_string() }) } +fn authorize_admin_request( + connection_context: &ConnectionContext, + operation: &str, +) -> Result<(), RequestError> { + if connection_context.role() == ServerRole::Admin { + return Ok(()); + } + + warn!( + connection_id = connection_context.connection_id, + peer_identity = %connection_context.peer_identity, + identity = %connection_context.identity(), + role = connection_context.role().as_str(), + operation, + "authorization denied for admin-only request" + ); + Err(RequestError::PermissionDenied(format!( + "role '{}' cannot access {operation}", + connection_context.role().as_str() + ))) +} + +fn authorize_sql_statements( + connection_context: &ConnectionContext, + statements: &[Statement], +) -> Result<(), RequestError> { + for statement in statements { + if role_allows_statement(connection_context.role(), statement) { + continue; + } + + warn!( + connection_id = connection_context.connection_id, + peer_identity = %connection_context.peer_identity, + identity = %connection_context.identity(), + role = connection_context.role().as_str(), + statement_kind = statement_kind(statement), + "authorization denied for SQL statement" + ); + return Err(RequestError::PermissionDenied(format!( + "role '{}' cannot execute {} statements", + connection_context.role().as_str(), + statement_kind(statement) + ))); + } + + Ok(()) +} + +fn role_allows_statement(role: ServerRole, statement: &Statement) -> bool { + match role { + ServerRole::Admin => true, + ServerRole::Writer => matches!( + statement, + Statement::Insert(_) + | Statement::Select(_) + | Statement::Update(_) + | Statement::Delete(_) + | Statement::Begin(_) + | Statement::Commit + | Statement::Rollback + ), + ServerRole::Reader => matches!( + statement, + Statement::Select(_) | Statement::Begin(_) | Statement::Commit | Statement::Rollback + ), + } +} + +fn statement_kind(statement: &Statement) -> &'static str { + match statement { + Statement::CreateTable(_) => "CREATE TABLE", + Statement::DropTable(_) => "DROP TABLE", + Statement::Insert(_) => "INSERT", + Statement::Select(_) => "SELECT", + Statement::Update(_) => "UPDATE", + Statement::Delete(_) => "DELETE", + Statement::Begin(_) => "BEGIN", + Statement::Commit => "COMMIT", + Statement::Rollback => "ROLLBACK", + } +} + fn acquire_memory_intensive_guard( runtime_state: &Arc, plan: &PhysicalPlan, @@ -929,6 +1504,7 @@ fn request_type_name(request_type: RequestType) -> &'static str { RequestType::AdminStatus => "ADMIN_STATUS", RequestType::ActiveStatements => "ACTIVE_STATEMENTS", RequestType::CancelStatement => "CANCEL_STATEMENT", + RequestType::Authenticate => "AUTHENTICATE", } } diff --git a/tests/fixtures/tls/server.crt b/tests/fixtures/tls/server.crt new file mode 100644 index 0000000..c37cff9 --- /dev/null +++ b/tests/fixtures/tls/server.crt @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDSTCCAjGgAwIBAgIUSiKuegO6chUYvvtO6Z8WlBZB7oswDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDMxODE4NDAwMFoXDTM2MDMx +NTE4NDAwMFowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF +AAOCAQ8AMIIBCgKCAQEAz7AITSXkTJY3bGqEn41sJIUR/RNyGZrUYIhB7yvB0Lsn +XDuDEsznNT3x7ndg4ISPX/MzDPEQsFR22jK6DT+6i1c9goqKx9SmAvVrKauDQG/J +LsISvzolaPr/YNHyuYGusux+tdIQyOoMjUrrEdRyBhWdqgRL8XTDag8YbQJ20ROO +GJcqhIQmmYn60GSgmODSJ3lqrQrmxwadNAdgaznAXapyFrasAE3x8bv3FORah7GZ +JaP0d5IbwlWE2n0b5++xI1C9zANSCA2v9P//Wt5neTBU2iRm1WJuDeXaTMUwkyOq +tK9ulUo4Y3bDS0iAWXSDP/zpUodOYO7PqhP9ln5NswIDAQABo4GSMIGPMB0GA1Ud +DgQWBBQrsXZq+UyEOwT+oGfUA/pRxESG9DAfBgNVHSMEGDAWgBQrsXZq+UyEOwT+ +oGfUA/pRxESG9DAaBgNVHREEEzARgglsb2NhbGhvc3SHBH8AAAEwDAYDVR0TAQH/ +BAIwADAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDQYJKoZI +hvcNAQELBQADggEBAASq+2vIrni0sHv+KveekbjDNjNRuAyVd4wCdP90LfxCI0iB +x7w87ZRYeg2tEtGPQ3G+FgeSMccJrvTnHT9ujIZkpqP+M1WiX0i9FzEue9KMMloR +qDtNz14rSNDJXq7Z0JbU2E40BPygZwLF2dBw6SIP/9YlyuO9uI6K8NC31fU8uSfu +eZGpxp2iTZmyttJYA+LP0tCB6cxpR4t9ANNAJlU2LKpJ7QJ8tee0WsNcbcNOCh4n +oAgWOL9Q04RRvcVjuOYECZxL7c6lBhQiB8Del8T38wr9UQ4bq8bFgSfY6qmAaxiH +HexGIoE1MnlEvNZetrrabPKVrTkirHt7mLcRDlM= +-----END CERTIFICATE----- diff --git a/tests/fixtures/tls/server.key b/tests/fixtures/tls/server.key new file mode 100644 index 0000000..791068d --- /dev/null +++ b/tests/fixtures/tls/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDPsAhNJeRMljds +aoSfjWwkhRH9E3IZmtRgiEHvK8HQuydcO4MSzOc1PfHud2DghI9f8zMM8RCwVHba +MroNP7qLVz2CiorH1KYC9Wspq4NAb8kuwhK/OiVo+v9g0fK5ga6y7H610hDI6gyN +SusR1HIGFZ2qBEvxdMNqDxhtAnbRE44YlyqEhCaZifrQZKCY4NIneWqtCubHBp00 +B2BrOcBdqnIWtqwATfHxu/cU5FqHsZklo/R3khvCVYTafRvn77EjUL3MA1IIDa/0 +//9a3md5MFTaJGbVYm4N5dpMxTCTI6q0r26VSjhjdsNLSIBZdIM//OlSh05g7s+q +E/2Wfk2zAgMBAAECggEACjfuiKEzJOOFMZPiF5mVNwzHEE0bIZBhH6jEmbhs7lCv +BJY3Aj9Lpu53z1RXU2SiS0XDfsEDobFeMakqR0mZ644szBX18xQO4PljPucd65c0 +blUFKBx7x7kFxKU/zInJZytEpryBr+j4GiGUBEoQHCWHHtzcQbKNhNPeT0q+PtYh +Q1DbMppI1CbPj4RDLBJk/0uZbuxV4klZrboMZAFRoGCcy5g3r4XDXeUCL8ppo1/L +TQuM7s2pSpV+Psc6aR/J6NUkvuyoqfqt2Xk9qWJT2ehbsUXVSDfq7cFhObPLTrQr +IzhBFWf4LdgOb/snDtOCMlh5xJPp/Ap5jRMHcpKk9QKBgQDwhLaRnl/KJloNoS1Z +LxC/Y/U+hwCp4vk9ugdPdSoIJUywQKulDahH//tzuzi19e5WLarn4MWqD7L16TcT +8UJZre9tQOcYUgyR0QOtHM0aMwo7/B87Q4r/d1lI6hA1ezsCjcCYwkI+QuOaWSIK +BAf6KPHz7G5b6WVURiJ++wTlfwKBgQDdDlSMZeKraRfkvizxZ7QfLGzkdKgVs3w/ +MpGpgkp/wbxCg6w17c7qtz/4gmQJahJw/Tl6R+o/lPWdukHExRS1IbJT/dbouFG2 +4QLcQXt4XYNf/gNhJ8H8Au9A/peGi75uzNTZlS4ylwJdsvJq5ySaugVujAhPPYEt +nzQGAKb5zQKBgQCW63+vwgPzUbtiIAfXlVvZ7Hv/vzCgaWbh37AkoK0+LUGAuyO5 +TueQPkTnKsx8CRSDiOZb18PQYUd3XN6Nqe5rXWQGVxprPVjbyp6W6qKcVPiQCTUD +t+8pPBePVCfVlzzA7neyovp0HP66ZEGirULgKv8fgvUAwWQuzE9rBFHfOwKBgCBE +HTc5D/LxLhmnYKwD9RivxV07YeV5A2O+H+DcMb+gKbiTu6lLgu5jvSSq86skHnj7 +nU4p/Rk2xvs02rC8C5+8wWjdHmdtsA+/nElGDZ2uGKUEUL33rar5Sq7z+m4bK7rE +jzULP2kG/cNrgVL1VjR3fp96NSRL1/UuzcsqgTTpAoGAYGbp08rAbVooJDaFXfKO +qp5/BfmqqS1rh9hQ78T3G9EfwLCL8Sd2mucSOW0kJK5zuWZwhxbnPmKj9VUfOCb1 +p6fO4ckSk6YoDcHHs9d99udk8nG2MfYtOMzG0ES8J4XEGt6jcz8TpmYttRyhKNFz +ZNfQVNHTjj4p5LEWJ05Tvn8= +-----END PRIVATE KEY----- diff --git a/tests/integration/engine_compaction.rs b/tests/integration/engine_compaction.rs index 5a6b033..b0ae4ac 100644 --- a/tests/integration/engine_compaction.rs +++ b/tests/integration/engine_compaction.rs @@ -1,6 +1,7 @@ use std::fs; use std::path::PathBuf; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use std::thread; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use lsmdb::storage::compaction::{CompactionStrategy, LeveledCompactionConfig}; use lsmdb::storage::engine::{StorageEngine, StorageEngineOptions}; @@ -18,6 +19,25 @@ fn temp_dir(label: &str) -> PathBuf { dir } +fn count_user_key_versions(engine: &StorageEngine, user_key: &[u8]) -> usize { + let mut versions = 0_usize; + for table in engine.sstable_metadata() { + let reader = SSTableReader::open(engine.sstable_dir().join(&table.file_name)) + .expect("open sstable reader"); + let rows = reader.scan_range(None, None).expect("scan table rows"); + + for (internal_key, _value) in rows { + if decode_internal_key(&internal_key) + .map(|decoded| decoded.user_key == user_key) + .unwrap_or(false) + { + versions += 1; + } + } + } + versions +} + #[test] fn engine_runs_background_compaction_and_preserves_reads() { let dir = temp_dir("background-compaction"); @@ -108,21 +128,14 @@ fn compaction_collapses_old_versions_for_same_user_key() { assert_eq!(engine.get(b"hot-key").expect("read hot key"), Some(b"hot-09".to_vec())); - let mut hot_versions = 0_usize; - for table in engine.sstable_metadata() { - let reader = SSTableReader::open(engine.sstable_dir().join(&table.file_name)) - .expect("open sstable reader"); - let rows = reader.scan_range(None, None).expect("scan table rows"); - - for (internal_key, _value) in rows { - if decode_internal_key(&internal_key) - .map(|decoded| decoded.user_key == b"hot-key") - .unwrap_or(false) - { - hot_versions += 1; - } + let deadline = Instant::now() + Duration::from_secs(5); + let hot_versions = loop { + let hot_versions = count_user_key_versions(&engine, b"hot-key"); + if hot_versions == 1 || Instant::now() >= deadline { + break hot_versions; } - } + thread::sleep(Duration::from_millis(25)); + }; assert_eq!( hot_versions, 1, diff --git a/tests/integration/server.rs b/tests/integration/server.rs index 4875e11..4b8081d 100644 --- a/tests/integration/server.rs +++ b/tests/integration/server.rs @@ -1,6 +1,8 @@ use std::net::SocketAddr; +use std::path::PathBuf; use std::str::from_utf8; use std::sync::Arc; +use std::sync::OnceLock; use std::time::Duration; use lsmdb::catalog::Catalog; @@ -8,16 +10,30 @@ use lsmdb::executor::{ExecutionResult, ExecutionSession}; use lsmdb::mvcc::MvccStore; use lsmdb::planner::plan_statement; use lsmdb::server::{ - ActiveStatementsPayload, AdminStatusPayload, ErrorCode, ErrorPayload, HealthPayload, - PROTOCOL_VERSION, QueryPayload, ReadinessPayload, RequestFrame, RequestType, ResponseFrame, - ResponsePayload, ServerLimits, ServerOptions, StatementCancellationPayload, TransactionState, - read_response, start_server, start_server_with_options, write_request, + ActiveStatementsPayload, AdminStatusPayload, AuthenticationPayload, ErrorCode, ErrorPayload, + HealthPayload, PROTOCOL_VERSION, QueryPayload, ReadinessPayload, RequestFrame, RequestType, + ResponseFrame, ResponsePayload, ServerAuthOptions, ServerError, ServerLimits, ServerOptions, + ServerRole, ServerSecurityOptions, ServerTlsMode, ServerTlsOptions, + StatementCancellationPayload, StaticPasswordUser, StaticTokenPrincipal, TransactionState, + authentication_request_with_password, authentication_request_with_token, read_response, + start_server, start_server_with_options, write_request, }; use lsmdb::sql::{parse_statement, validate_statement}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tokio::time::sleep; +use tokio_rustls::TlsConnector; +use tokio_rustls::rustls::{ClientConfig, RootCertStore, pki_types::ServerName}; -async fn send_request(stream: &mut TcpStream, request: RequestFrame) -> ResponseFrame { +const TLS_CERT_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/tls/server.crt"); +const TLS_KEY_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/tls/server.key"); +static HEAVY_SERVER_TEST_SEMAPHORE: OnceLock> = OnceLock::new(); + +async fn send_request(stream: &mut S, request: RequestFrame) -> ResponseFrame +where + S: AsyncRead + AsyncWrite + Unpin, +{ write_request(stream, &request).await.expect("write request"); read_response(stream).await.expect("read response").expect("response") } @@ -78,6 +94,13 @@ fn response_to_statement_cancellation(response: ResponseFrame) -> StatementCance } } +fn response_to_authentication(response: ResponseFrame) -> AuthenticationPayload { + match response { + ResponseFrame::Ok(ResponsePayload::Authentication(payload)) => payload, + other => panic!("expected authentication payload, got {other:?}"), + } +} + fn execute_setup_sql(catalog: &Catalog, store: &MvccStore, sql: &str) -> ExecutionResult { let statement = parse_statement(sql).expect("parse setup SQL"); validate_statement(catalog, &statement).expect("validate setup SQL"); @@ -104,8 +127,11 @@ fn populate_users(catalog: &Catalog, store: &MvccStore, rows: usize) { } } -async fn wait_for_active_statement_id(stream: &mut TcpStream, request_type: &str) -> u64 { - for _ in 0..500 { +async fn wait_for_active_statement_id(stream: &mut S, request_type: &str) -> u64 +where + S: AsyncRead + AsyncWrite + Unpin, +{ + for _ in 0..5_000 { let payload = response_to_active_statements( send_request( stream, @@ -124,15 +150,80 @@ async fn wait_for_active_statement_id(stream: &mut TcpStream, request_type: &str panic!("timed out waiting for active statement"); } +fn password_server_options(users: Vec) -> ServerOptions { + ServerOptions { + security: ServerSecurityOptions { + auth: ServerAuthOptions::StaticPassword { users }, + tls: fixture_tls_server_options(), + ..ServerSecurityOptions::default() + }, + ..ServerOptions::default() + } +} + +fn token_server_options(principals: Vec) -> ServerOptions { + ServerOptions { + security: ServerSecurityOptions { + auth: ServerAuthOptions::StaticToken { principals }, + tls: fixture_tls_server_options(), + ..ServerSecurityOptions::default() + }, + ..ServerOptions::default() + } +} + +fn insecure_server_options() -> ServerOptions { + ServerOptions::insecure_for_local_dev() +} + +fn fixture_tls_server_options() -> ServerTlsOptions { + ServerTlsOptions { + mode: ServerTlsMode::Required, + cert_path: Some(PathBuf::from(TLS_CERT_PATH)), + key_path: Some(PathBuf::from(TLS_KEY_PATH)), + } +} + +async fn connect_tls_client(addr: SocketAddr) -> tokio_rustls::client::TlsStream { + let certificate = std::fs::File::open(TLS_CERT_PATH).expect("open tls certificate fixture"); + let mut reader = std::io::BufReader::new(certificate); + let certificates = rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .expect("parse tls certificate fixture"); + let mut roots = RootCertStore::empty(); + for certificate in certificates { + roots.add(certificate).expect("add tls certificate to root store"); + } + let config = ClientConfig::builder().with_root_certificates(roots).with_no_client_auth(); + let connector = TlsConnector::from(Arc::new(config)); + let stream = TcpStream::connect(addr).await.expect("connect tcp stream for tls"); + let server_name = ServerName::try_from("localhost").expect("valid server name").to_owned(); + connector.connect(server_name, stream).await.expect("complete tls handshake") +} + +async fn acquire_heavy_server_test_permit() -> OwnedSemaphorePermit { + HEAVY_SERVER_TEST_SEMAPHORE + .get_or_init(|| Arc::new(Semaphore::new(1))) + .clone() + .acquire_owned() + .await + .expect("heavy server test semaphore should remain open") +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn server_executes_query_requests_end_to_end() { let store = Arc::new(MvccStore::new()); let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); - let server = start_server(bind_addr, Arc::clone(&catalog), Arc::clone(&store)) - .await - .expect("start server"); + let server = start_server_with_options( + bind_addr, + Arc::clone(&catalog), + Arc::clone(&store), + insecure_server_options(), + ) + .await + .expect("start server"); let server_addr = server.local_addr(); let mut client = TcpStream::connect(server_addr).await.expect("connect client"); @@ -181,9 +272,14 @@ async fn server_returns_explain_plan_without_executing_statement() { let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); - let server = start_server(bind_addr, Arc::clone(&catalog), Arc::clone(&store)) - .await - .expect("start server"); + let server = start_server_with_options( + bind_addr, + Arc::clone(&catalog), + Arc::clone(&store), + insecure_server_options(), + ) + .await + .expect("start server"); let server_addr = server.local_addr(); let mut client = TcpStream::connect(server_addr).await.expect("connect client"); @@ -227,9 +323,14 @@ async fn server_tracks_transaction_state_per_connection() { let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); - let server = start_server(bind_addr, Arc::clone(&catalog), Arc::clone(&store)) - .await - .expect("start server"); + let server = start_server_with_options( + bind_addr, + Arc::clone(&catalog), + Arc::clone(&store), + insecure_server_options(), + ) + .await + .expect("start server"); let server_addr = server.local_addr(); let mut client_a = TcpStream::connect(server_addr).await.expect("connect client_a"); @@ -309,9 +410,14 @@ async fn server_exposes_health_readiness_and_admin_status() { let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); - let server = start_server(bind_addr, Arc::clone(&catalog), Arc::clone(&store)) - .await - .expect("start server"); + let server = start_server_with_options( + bind_addr, + Arc::clone(&catalog), + Arc::clone(&store), + insecure_server_options(), + ) + .await + .expect("start server"); let server_addr = server.local_addr(); let mut client = TcpStream::connect(server_addr).await.expect("connect client"); @@ -365,12 +471,31 @@ async fn server_exposes_health_readiness_and_admin_status() { server.shutdown().await.expect("shutdown server"); } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn start_server_requires_explicit_security_configuration() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let err = match start_server(bind_addr, Arc::clone(&catalog), Arc::clone(&store)).await { + Ok(_) => panic!("default startup should be rejected"), + Err(err) => err, + }; + match err { + ServerError::InvalidConfiguration(message) => { + assert!(message.contains("allow_anonymous_access")); + } + other => panic!("expected invalid configuration, got {other:?}"), + } +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn server_rejects_connections_above_limit_and_keeps_existing_connection_responsive() { let store = Arc::new(MvccStore::new()); let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); let options = ServerOptions { limits: ServerLimits { max_concurrent_connections: 1, ..ServerLimits::default() }, + ..insecure_server_options() }; let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); @@ -415,8 +540,10 @@ async fn server_rejects_connections_above_limit_and_keeps_existing_connection_re async fn server_rejects_oversized_request_frames_with_resource_limit_error() { let store = Arc::new(MvccStore::new()); let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); - let options = - ServerOptions { limits: ServerLimits { max_request_bytes: 32, ..ServerLimits::default() } }; + let options = ServerOptions { + limits: ServerLimits { max_request_bytes: 32, ..ServerLimits::default() }, + ..insecure_server_options() + }; let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); let server = @@ -454,6 +581,7 @@ async fn server_enforces_scan_and_sort_limits() { max_query_result_rows: 2, ..ServerLimits::default() }, + ..insecure_server_options() }; let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); @@ -522,6 +650,7 @@ async fn server_enforces_statement_count_limit() { let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); let options = ServerOptions { limits: ServerLimits { max_statements_per_request: 1, ..ServerLimits::default() }, + ..insecure_server_options() }; let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); @@ -551,12 +680,14 @@ async fn server_enforces_statement_count_limit() { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn server_times_out_long_running_scan_and_sort_queries() { + let _heavy_test_permit = acquire_heavy_server_test_permit().await; let store = Arc::new(MvccStore::new()); let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); populate_users(&catalog, &store, 25_000); let options = ServerOptions { limits: ServerLimits { statement_timeout_ms: Some(1), ..ServerLimits::default() }, + ..insecure_server_options() }; let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); let server = @@ -607,16 +738,21 @@ async fn server_times_out_long_running_scan_and_sort_queries() { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn server_rejects_queries_when_identity_quota_is_reached() { + let _heavy_test_permit = acquire_heavy_server_test_permit().await; let store = Arc::new(MvccStore::new()); let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); - populate_users(&catalog, &store, 60_000); + populate_users(&catalog, &store, 250_000); let options = ServerOptions { limits: ServerLimits { max_concurrent_queries_per_identity: Some(1), statement_timeout_ms: Some(5_000), + max_scan_rows: 300_000, + max_sort_rows: 300_000, + max_query_result_rows: 300_000, ..ServerLimits::default() }, + ..insecure_server_options() }; let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); let server = @@ -626,19 +762,15 @@ async fn server_rejects_queries_when_identity_quota_is_reached() { let server_addr = server.local_addr(); let mut query_client = TcpStream::connect(server_addr).await.expect("connect query client"); - let query_task = tokio::spawn(async move { - send_request( - &mut query_client, - RequestFrame { - request_type: RequestType::Query, - sql: "UPDATE users SET email = 'quota@example.com'".to_string(), - }, - ) - .await - }); + write_request( + &mut query_client, + &RequestFrame { request_type: RequestType::Query, sql: "SELECT id FROM users".to_string() }, + ) + .await + .expect("send blocking query"); - let mut admin_client = TcpStream::connect(server_addr).await.expect("connect admin client"); - let statement_id = wait_for_active_statement_id(&mut admin_client, "QUERY").await; + tokio::task::yield_now().await; + sleep(Duration::from_millis(10)).await; let mut second_client = TcpStream::connect(server_addr).await.expect("connect second client"); let quota_error = response_to_error( @@ -646,7 +778,7 @@ async fn server_rejects_queries_when_identity_quota_is_reached() { &mut second_client, RequestFrame { request_type: RequestType::Query, - sql: "SELECT id FROM users ORDER BY email ASC".to_string(), + sql: "SELECT id FROM users LIMIT 1".to_string(), }, ) .await, @@ -654,22 +786,9 @@ async fn server_rejects_queries_when_identity_quota_is_reached() { assert_eq!(quota_error.code, ErrorCode::Quota); assert!(quota_error.retryable); - let cancel = response_to_statement_cancellation( - send_request( - &mut admin_client, - RequestFrame { - request_type: RequestType::CancelStatement, - sql: statement_id.to_string(), - }, - ) - .await, - ); - assert_eq!(cancel.statement_id, statement_id); - assert!(cancel.accepted); - - let canceled = response_to_error(query_task.await.expect("query task")); - assert_eq!(canceled.code, ErrorCode::Canceled); + drop(query_client); + let mut admin_client = TcpStream::connect(server_addr).await.expect("connect admin client"); let admin = response_to_admin_status( send_request( &mut admin_client, @@ -678,19 +797,26 @@ async fn server_rejects_queries_when_identity_quota_is_reached() { .await, ); assert!(admin.quota_rejections >= 1); - assert!(admin.canceled_requests >= 1); server.shutdown().await.expect("shutdown server"); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn server_cancellation_rolls_back_active_transaction_state() { + let _heavy_test_permit = acquire_heavy_server_test_permit().await; let store = Arc::new(MvccStore::new()); let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); populate_users(&catalog, &store, 50_000); let options = ServerOptions { - limits: ServerLimits { statement_timeout_ms: Some(5_000), ..ServerLimits::default() }, + limits: ServerLimits { + statement_timeout_ms: Some(5_000), + max_scan_rows: 100_000, + max_sort_rows: 100_000, + max_query_result_rows: 100_000, + ..ServerLimits::default() + }, + ..insecure_server_options() }; let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); let server = @@ -715,7 +841,7 @@ async fn server_cancellation_rolls_back_active_transaction_state() { &mut txn_client, RequestFrame { request_type: RequestType::Query, - sql: "UPDATE users SET email = 'blocked@example.com'".to_string(), + sql: "SELECT id FROM users ORDER BY email DESC".to_string(), }, ) .await; @@ -764,3 +890,301 @@ async fn server_cancellation_rolls_back_active_transaction_state() { server.shutdown().await.expect("shutdown server"); } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn secure_server_rejects_unauthenticated_requests_before_any_command() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + let options = password_server_options(vec![StaticPasswordUser { + username: "admin".to_string(), + password: "secret".to_string(), + role: ServerRole::Admin, + }]); + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start secure server"); + let server_addr = server.local_addr(); + + let mut client = connect_tls_client(server_addr).await; + let error = response_to_error( + send_request( + &mut client, + RequestFrame { request_type: RequestType::Health, sql: String::new() }, + ) + .await, + ); + assert_eq!(error.code, ErrorCode::Unauthenticated); + assert!(error.message.contains("requires authentication")); + assert!( + read_response(&mut client).await.expect("read connection closure").is_none(), + "server should close the connection after an unauthenticated request" + ); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn secure_server_rejects_authenticated_mode_without_tls() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + let options = ServerOptions { + security: ServerSecurityOptions { + auth: ServerAuthOptions::StaticPassword { + users: vec![StaticPasswordUser { + username: "admin".to_string(), + password: "secret".to_string(), + role: ServerRole::Admin, + }], + }, + ..ServerSecurityOptions::default() + }, + ..ServerOptions::default() + }; + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let err = match start_server_with_options( + bind_addr, + Arc::clone(&catalog), + Arc::clone(&store), + options, + ) + .await + { + Ok(_) => panic!("auth without tls should be rejected"), + Err(err) => err, + }; + match err { + ServerError::InvalidConfiguration(message) => { + assert!(message.contains("authentication requires tls")); + } + other => panic!("expected invalid configuration, got {other:?}"), + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn secure_server_rejects_invalid_password_credentials() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + let options = password_server_options(vec![StaticPasswordUser { + username: "admin".to_string(), + password: "secret".to_string(), + role: ServerRole::Admin, + }]); + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start secure server"); + let server_addr = server.local_addr(); + + let mut client = connect_tls_client(server_addr).await; + let error = response_to_error( + send_request(&mut client, authentication_request_with_password("admin", "wrong")).await, + ); + assert_eq!(error.code, ErrorCode::Unauthenticated); + assert!(error.message.contains("invalid username or password")); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn secure_server_enforces_role_based_authorization() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + execute_setup_sql( + &catalog, + &store, + "CREATE TABLE users (id BIGINT NOT NULL, email TEXT NOT NULL, PRIMARY KEY (id))", + ); + execute_setup_sql( + &catalog, + &store, + "INSERT INTO users (id, email) VALUES (1, 'alice@example.com')", + ); + let options = password_server_options(vec![ + StaticPasswordUser { + username: "reader".to_string(), + password: "reader-secret".to_string(), + role: ServerRole::Reader, + }, + StaticPasswordUser { + username: "admin".to_string(), + password: "admin-secret".to_string(), + role: ServerRole::Admin, + }, + ]); + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start secure server"); + let server_addr = server.local_addr(); + + let mut reader = connect_tls_client(server_addr).await; + let auth = response_to_authentication( + send_request(&mut reader, authentication_request_with_password("reader", "reader-secret")) + .await, + ); + assert_eq!(auth.identity, "reader"); + assert_eq!(auth.role, "reader"); + + let select = response_to_query( + send_request( + &mut reader, + RequestFrame { + request_type: RequestType::Query, + sql: "SELECT email FROM users WHERE id = 1".to_string(), + }, + ) + .await, + ); + assert_eq!(from_utf8(&select.rows[0][0]).expect("utf8 cell"), "alice@example.com"); + + let update_error = response_to_error( + send_request( + &mut reader, + RequestFrame { + request_type: RequestType::Query, + sql: "UPDATE users SET email = 'blocked@example.com' WHERE id = 1".to_string(), + }, + ) + .await, + ); + assert_eq!(update_error.code, ErrorCode::PermissionDenied); + assert!(update_error.message.contains("role 'reader'")); + + let admin_status_error = response_to_error( + send_request( + &mut reader, + RequestFrame { request_type: RequestType::AdminStatus, sql: String::new() }, + ) + .await, + ); + assert_eq!(admin_status_error.code, ErrorCode::PermissionDenied); + + let mut admin = connect_tls_client(server_addr).await; + let admin_auth = response_to_authentication( + send_request(&mut admin, authentication_request_with_password("admin", "admin-secret")) + .await, + ); + assert_eq!(admin_auth.role, "admin"); + let admin_status = response_to_admin_status( + send_request( + &mut admin, + RequestFrame { request_type: RequestType::AdminStatus, sql: String::new() }, + ) + .await, + ); + assert!(admin_status.accepting_connections); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn secure_server_supports_static_token_authentication() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + let options = token_server_options(vec![StaticTokenPrincipal { + label: "ingest-bot".to_string(), + token: "opaque-token".to_string(), + role: ServerRole::Writer, + }]); + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start token-auth server"); + let server_addr = server.local_addr(); + + let mut client = connect_tls_client(server_addr).await; + let auth = response_to_authentication( + send_request(&mut client, authentication_request_with_token("opaque-token")).await, + ); + assert_eq!(auth.identity, "ingest-bot"); + assert_eq!(auth.role, "writer"); + assert_eq!(auth.auth_scheme, "token"); + + let health = response_to_health( + send_request( + &mut client, + RequestFrame { request_type: RequestType::Health, sql: String::new() }, + ) + .await, + ); + assert!(health.ok); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn secure_server_accepts_tls_connections_when_required() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + let options = password_server_options(vec![StaticPasswordUser { + username: "admin".to_string(), + password: "secret".to_string(), + role: ServerRole::Admin, + }]); + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start tls server"); + let server_addr = server.local_addr(); + + let mut client = connect_tls_client(server_addr).await; + let auth = response_to_authentication( + send_request(&mut client, authentication_request_with_password("admin", "secret")).await, + ); + assert_eq!(auth.identity, "admin"); + assert_eq!(auth.role, "admin"); + + let readiness = response_to_readiness( + send_request( + &mut client, + RequestFrame { request_type: RequestType::Readiness, sql: String::new() }, + ) + .await, + ); + assert!(readiness.ready); + + server.shutdown().await.expect("shutdown tls server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn tls_required_mode_rejects_missing_certificate_files() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + let options = ServerOptions { + security: ServerSecurityOptions { + tls: ServerTlsOptions { + mode: ServerTlsMode::Required, + cert_path: Some(PathBuf::from("tests/fixtures/tls/missing.crt")), + key_path: Some(PathBuf::from(TLS_KEY_PATH)), + }, + allow_anonymous_access: true, + ..ServerSecurityOptions::default() + }, + ..insecure_server_options() + }; + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let result = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await; + let err = match result { + Ok(_) => panic!("missing certificate should fail"), + Err(err) => err, + }; + match err { + ServerError::Tls(message) => assert!(message.contains("missing.crt")), + other => panic!("expected tls error, got {other:?}"), + } +} diff --git a/tools/lsmdb-cli/main.rs b/tools/lsmdb-cli/main.rs index a632aff..bfc7c38 100644 --- a/tools/lsmdb-cli/main.rs +++ b/tools/lsmdb-cli/main.rs @@ -1,14 +1,22 @@ use std::env; +use std::fs::File; +use std::io::BufReader; use std::io::{self, Write}; use std::net::SocketAddr; +use std::path::PathBuf; +use std::sync::Arc; use std::time::Instant; use lsmdb::observability::init_tracing_from_env; use lsmdb::server::{ QueryPayload, RequestFrame, RequestType, ResponseFrame, ResponsePayload, TransactionState, - read_response, write_request, + authentication_request_with_password, authentication_request_with_token, read_response, + write_request, }; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; +use tokio_rustls::TlsConnector; +use tokio_rustls::rustls::{ClientConfig, RootCertStore, pki_types::ServerName}; #[tokio::main] async fn main() -> Result<(), Box> { @@ -16,10 +24,14 @@ async fn main() -> Result<(), Box> { eprintln!("warning: failed to initialize tracing: {err}"); } - let addr = parse_addr()?; - let mut stream = TcpStream::connect(addr).await?; + let options = parse_cli_options()?; + let mut stream = connect(&options).await?; - println!("Connected to lsmdb server at {addr}"); + if let Some(auth) = &options.auth { + authenticate(&mut *stream, auth).await?; + } + + println!("Connected to lsmdb server at {}", options.addr); println!( "Type SQL to execute. Meta commands: \\help, \\q, \\timing, \\explain , \\health, \\ready, \\status, \\queries, \\cancel " ); @@ -43,7 +55,7 @@ async fn main() -> Result<(), Box> { } if input.starts_with('\\') { - match handle_meta_command(input, &mut timing_enabled, &mut stream).await? { + match handle_meta_command(input, &mut timing_enabled, &mut *stream).await? { ControlFlow::Continue => continue, ControlFlow::Break => break, } @@ -51,7 +63,7 @@ async fn main() -> Result<(), Box> { let request = request_from_sql(input); let start = Instant::now(); - let response = send_request(&mut stream, request).await?; + let response = send_request(&mut *stream, request).await?; let elapsed = start.elapsed(); render_response(response); @@ -69,9 +81,28 @@ enum ControlFlow { Break, } -fn parse_addr() -> Result> { +#[derive(Debug, Clone, PartialEq, Eq)] +struct CliOptions { + addr: SocketAddr, + auth: Option, + tls_ca_cert: Option, + tls_server_name: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum ClientAuth { + Password { username: String, password: String }, + Token { token: String }, +} + +fn parse_cli_options() -> Result> { let mut args = env::args().skip(1); let mut addr = "127.0.0.1:7878".to_string(); + let mut username = None; + let mut password = None; + let mut token = None; + let mut tls_ca_cert = None; + let mut tls_server_name = None; while let Some(arg) = args.next() { match arg.as_str() { @@ -81,6 +112,36 @@ fn parse_addr() -> Result> { }; addr = value; } + "--user" => { + let Some(value) = args.next() else { + return Err("--user expects a value".into()); + }; + username = Some(value); + } + "--password" => { + let Some(value) = args.next() else { + return Err("--password expects a value".into()); + }; + password = Some(value); + } + "--token" => { + let Some(value) = args.next() else { + return Err("--token expects a value".into()); + }; + token = Some(value); + } + "--tls-ca-cert" => { + let Some(value) = args.next() else { + return Err("--tls-ca-cert expects a value".into()); + }; + tls_ca_cert = Some(PathBuf::from(value)); + } + "--tls-server-name" => { + let Some(value) = args.next() else { + return Err("--tls-server-name expects a value".into()); + }; + tls_server_name = Some(value); + } "--help" | "-h" => { print_help(); std::process::exit(0); @@ -91,13 +152,97 @@ fn parse_addr() -> Result> { } } - Ok(addr.parse()?) + let auth = if let Some(token) = token { + if username.is_some() || password.is_some() { + return Err("--token cannot be combined with --user or --password".into()); + } + Some(ClientAuth::Token { token }) + } else if let Some(username) = username { + let password = match password { + Some(password) => password, + None => env::var("LSMDB_PASSWORD") + .map_err(|_| "--password or LSMDB_PASSWORD is required when --user is set")?, + }; + Some(ClientAuth::Password { username, password }) + } else if password.is_some() { + return Err("--password requires --user".into()); + } else { + None + }; + + Ok(CliOptions { addr: addr.parse()?, auth, tls_ca_cert, tls_server_name }) +} + +trait ClientIo: AsyncRead + AsyncWrite + Unpin + Send {} + +impl ClientIo for T where T: AsyncRead + AsyncWrite + Unpin + Send {} + +async fn connect(options: &CliOptions) -> Result, Box> { + let tcp = TcpStream::connect(options.addr).await?; + let Some(ca_cert_path) = &options.tls_ca_cert else { + return Ok(Box::new(tcp)); + }; + + let mut root_store = RootCertStore::empty(); + let mut cert_reader = BufReader::new(File::open(ca_cert_path)?); + let certificates = rustls_pemfile::certs(&mut cert_reader).collect::, _>>()?; + if certificates.is_empty() { + return Err(format!( + "TLS CA bundle '{}' does not contain any certificates", + ca_cert_path.display() + ) + .into()); + } + for certificate in certificates { + root_store.add(certificate)?; + } + + let server_name = + options.tls_server_name.clone().unwrap_or_else(|| options.addr.ip().to_string()); + let server_name = ServerName::try_from(server_name.as_str())?.to_owned(); + let config = ClientConfig::builder().with_root_certificates(root_store).with_no_client_auth(); + let connector = TlsConnector::from(Arc::new(config)); + let tls = connector.connect(server_name, tcp).await?; + Ok(Box::new(tls)) +} + +async fn authenticate( + stream: &mut S, + auth: &ClientAuth, +) -> Result<(), Box> +where + S: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + let request = match auth { + ClientAuth::Password { username, password } => { + authentication_request_with_password(username.clone(), password.clone()) + } + ClientAuth::Token { token } => authentication_request_with_token(token.clone()), + }; + + match send_request(stream, request).await? { + ResponseFrame::Ok(ResponsePayload::Authentication(payload)) => { + println!( + "Authenticated as {} ({}) via {}", + payload.identity, payload.role, payload.auth_scheme + ); + Ok(()) + } + ResponseFrame::Err(error) => Err(format!( + "authentication failed [{}{}]: {}", + error.code.as_str(), + if error.retryable { ", retryable" } else { "" }, + error.message + ) + .into()), + other => Err(format!("unexpected authentication response: {other:?}").into()), + } } async fn handle_meta_command( input: &str, timing_enabled: &mut bool, - stream: &mut TcpStream, + stream: &mut (impl AsyncRead + AsyncWrite + Unpin + ?Sized), ) -> Result> { if input == "\\q" || input == "\\quit" { return Ok(ControlFlow::Break); @@ -239,10 +384,13 @@ fn request_from_sql(sql: &str) -> RequestFrame { } } -async fn send_request( - stream: &mut TcpStream, +async fn send_request( + stream: &mut S, request: RequestFrame, -) -> Result> { +) -> Result> +where + S: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ write_request(stream, &request).await?; let response = read_response(stream).await?.ok_or_else(|| "server closed connection".to_string())?; @@ -319,6 +467,11 @@ fn render_response(response: ResponseFrame) { println!("accepted: {}", payload.accepted); println!("status: {}", payload.status); } + ResponsePayload::Authentication(payload) => { + println!("authenticated_identity: {}", payload.identity); + println!("authenticated_role: {}", payload.role); + println!("auth_scheme: {}", payload.auth_scheme); + } }, ResponseFrame::Err(error) => { eprintln!( @@ -405,7 +558,7 @@ fn hex_char(value: u8) -> char { } fn print_help() { - println!("Usage: lsmdb-cli [--addr HOST:PORT]"); + println!("Usage: lsmdb-cli [--addr HOST:PORT] [--user NAME --password VALUE] [--token VALUE]"); println!("Meta commands:"); println!(" \\help Show this help"); println!(" \\q | \\quit Exit CLI"); @@ -416,4 +569,11 @@ fn print_help() { println!(" \\status Request admin runtime diagnostics"); println!(" \\queries List active statements"); println!(" \\cancel Signal cancellation for an active statement"); + println!("Auth options:"); + println!(" --user NAME Authenticate with static username/password"); + println!(" --password VALUE Password for --user (or set LSMDB_PASSWORD)"); + println!(" --token VALUE Authenticate with a static token"); + println!("TLS options:"); + println!(" --tls-ca-cert PATH Enable TLS and trust the PEM CA/cert at PATH"); + println!(" --tls-server-name N Override the TLS server name (default: addr IP)"); }