diff --git a/.github/ISSUE_TEMPLATE/release-signoff.md b/.github/ISSUE_TEMPLATE/release-signoff.md new file mode 100644 index 0000000..c16c9f8 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/release-signoff.md @@ -0,0 +1,46 @@ +--- +name: Release Sign-Off +about: Production release readiness checklist and approvals +title: "release(signoff): vX.Y.Z readiness" +labels: ["area/release", "area/ops"] +assignees: [] +--- + +## Release Metadata + +- target version: +- target date: +- release owner: +- rollback owner: + +## Automated Gate Evidence + +- [ ] `Release Gate` workflow passed on target commit +- [ ] `cargo check --locked` passed +- [ ] `cargo test --lib --locked` passed +- [ ] `./tools/ci/run_integration_tests.sh` passed +- [ ] `./tools/release/check_critical_blockers.sh` passed + +## Engineering Sign-Off + +- [ ] Schema/storage compatibility risk reviewed +- [ ] Known high-severity defects triaged and dispositioned +- [ ] Upgrade notes completed +- [ ] Rollback procedure validated +- approver: +- approval date: + +## Operations Sign-Off + +- [ ] Runbook updated for this release +- [ ] Monitoring/alerts reviewed +- [ ] Capacity/performance risk reviewed +- [ ] Backup/restore posture reviewed +- approver: +- approval date: + +## Decision + +- [ ] APPROVED FOR RELEASE +- [ ] BLOCKED +- blocking notes: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 31cee04..12c6415 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,19 +2,46 @@ name: CI on: push: - branches: [main] + branches: [main, staging] pull_request: +env: + CARGO_TERM_COLOR: always + jobs: - lint-test: + quality: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Setup Rust uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt, clippy + - name: Cache cargo artifacts + uses: Swatinem/rust-cache@v2 - name: Format check run: cargo fmt --all --check - name: Clippy - run: cargo clippy --all-targets --all-features -- -D warnings - - name: Tests - run: cargo test --all-targets --all-features + run: cargo clippy --all-targets --all-features -- -W clippy::all + + unit-tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup Rust + uses: dtolnay/rust-toolchain@stable + - name: Cache cargo artifacts + uses: Swatinem/rust-cache@v2 + - name: Library tests + run: cargo test --lib --locked + + integration-tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup Rust + uses: dtolnay/rust-toolchain@stable + - name: Cache cargo artifacts + uses: Swatinem/rust-cache@v2 + - name: Integration and test-target coverage gate + run: ./tools/ci/run_integration_tests.sh diff --git a/.github/workflows/release-gate.yml b/.github/workflows/release-gate.yml new file mode 100644 index 0000000..c86df3c --- /dev/null +++ b/.github/workflows/release-gate.yml @@ -0,0 +1,35 @@ +name: Release Gate + +on: + workflow_dispatch: + push: + tags: + - "v*" + +permissions: + contents: read + issues: read + +env: + CARGO_TERM_COLOR: always + +jobs: + release-gate: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup Rust + uses: dtolnay/rust-toolchain@stable + - name: Cache cargo artifacts + uses: Swatinem/rust-cache@v2 + - name: Cargo check + run: cargo check --locked + - name: Library tests + run: cargo test --lib --locked + - name: Integration tests + run: ./tools/ci/run_integration_tests.sh + - name: Critical blocker gate + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITHUB_REPOSITORY: ${{ github.repository }} + 
run: ./tools/release/check_critical_blockers.sh diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..fd018f4 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,1781 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "aws-lc-rs" +version = "1.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94bffc006df10ac2a68c83692d734a465f8ee6c5b384d8545a636f81d858f4bf" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.38.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4321e568ed89bb5a7d291a7f37997c2c0df89809d7b6d12062c81ddb54aa782e" +dependencies = [ + "cc", + "cmake", + "dunce", + "fs_extra", +] + +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + +[[package]] +name = "bitflags" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" + +[[package]] +name = "bumpalo" +version = "3.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" + 
+[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "cc" +version = "1.2.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" +dependencies = [ + "find-msvc-tools", + "jobserver", + "libc", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + +[[package]] +name = "clap" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" + +[[package]] +name = "cmake" +version = "0.1.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" +dependencies = [ + "cc", +] + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools", +] + +[[package]] +name = "crossbeam" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = 
"flate2" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi 5.3.0", + "wasip2", +] + +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] + +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if", + "crunchy", + "zerocopy", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + +[[package]] +name = "hdrhistogram" +version = "7.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d" +dependencies = [ + "base64", + "byteorder", + "crossbeam-channel", + "flate2", + "nom", + "num-traits", +] + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + +[[package]] +name = "id-arena" +version = "2.3.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "is-terminal" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + +[[package]] +name = "js-sys" +version = "0.3.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "libc" +version = "0.2.182" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" + +[[package]] +name = "libsqlite3-sys" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c10584274047cb335c23d3e61bcef8e323adae7c5c8c760540f73610177fc3f" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "lsmdb" +version = "0.1.0" +dependencies = [ + "bincode", + "bytes", + "crc32fast", + "criterion", + "crossbeam", + "hdrhistogram", + "parking_lot", + "proptest", + "rusqlite", + "rustls-pemfile", + "serde", + "thiserror", + "tokio", + "tokio-rustls", + "toml", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "matchers" +version = "0.2.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + +[[package]] +name = "mio" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "proptest" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37566cb3fdacef14c0737f9546df7cfeadbfbc9fef10991038bf5015d0c80532" +dependencies = [ + "bit-set", + "bit-vec", + "bitflags", + "num-traits", + "rand", + "rand_chacha", + "rand_xorshift", + "regex-syntax", + "rusty-fork", + "tempfile", + "unarray", +] + +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + +[[package]] +name = "rand_xorshift" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a" +dependencies = [ + "rand_core", +] + +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + 
+[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rusqlite" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b838eba278d213a8beaf485bd313fd580ca4505a00d5871caeb1457c55322cae" +dependencies = [ + "bitflags", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + +[[package]] +name = "rustix" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls" +version = "0.23.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" +dependencies = [ + "aws-lc-rs", + "log", + "once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +dependencies = [ + "aws-lc-rs", + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "rusty-fork" +version = 
"0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc6bf79ff24e648f6da1f8d1f011e9cac26491b619e6b9280f2b47f1774e6ee2" +dependencies = [ + "fnv", + "quick-error", + "tempfile", + "wait-timeout", +] + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "serde_spanned" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" +dependencies = [ + "serde", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "simd-adler32" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tempfile" +version = "3.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82a72c767771b47409d2345987fda8628641887d5466101319899796367354a0" +dependencies = [ + "fastrand", + "getrandom 0.4.2", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "tokio" +version = "1.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +dependencies = [ + "bytes", + "libc", + "mio", + "pin-project-lite", + "socket2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "toml" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "toml_write", + "winnow", +] + +[[package]] +name = "toml_write" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + 
"tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wait-timeout" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ac3b126d3914f9849036f826e054cbabdc8519970b8998ddaf3b5bd3c65f11" +dependencies = [ + "libc", +] + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + 
+[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "web-sys" +version = "0.3.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "854ba17bb104abfb26ba36da9729addc7ce7f06f5c0f90f3c391f8461cca21f9" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "winnow" +version = "0.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" +dependencies = [ + "memchr", +] + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "zerocopy" +version = "0.8.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a789c6e490b576db9f7e6b6d661bcc9799f7c0ac8352f56ea20193b2681532e5" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f65c489a7071a749c849713807783f70672b28094011623e200cb86dcb835953" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/Cargo.toml b/Cargo.toml index 13867e7..842ef20 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,50 @@ license = "MIT" name = "lsmdb-cli" path = "tools/lsmdb-cli/main.rs" +[[bin]] +name = "lsmdb-admin" +path = "tools/lsmdb-admin/main.rs" + +[[test]] +name = "integration_catalog" +path = "tests/integration/catalog.rs" + +[[test]] +name = "integration_compaction" +path = "tests/integration/compaction.rs" + +[[test]] +name = "integration_engine_compaction" +path = "tests/integration/engine_compaction.rs" + +[[test]] +name = "integration_engine_recovery" +path = "tests/integration/engine_recovery.rs" + +[[test]] +name = "integration_executor" +path = "tests/integration/executor.rs" + +[[test]] +name = "integration_mvcc_isolation" +path = "tests/integration/mvcc_isolation.rs" + +[[test]] +name = "integration_planner" +path = "tests/integration/planner.rs" + +[[test]] +name = "integration_server" +path = "tests/integration/server.rs" + +[[test]] +name = "integration_sql_e2e" +path = "tests/integration/sql_e2e.rs" + +[[test]] +name = "integration_wal_recovery" +path = "tests/integration/wal_recovery.rs" + [[bench]] name = "write_throughput" harness = false @@ -58,9 +102,11 @@ parking_lot = "0.12" serde = { version = "1.0", features = ["derive"] } thiserror = "2.0" tokio = { version = "1.40", features = ["rt-multi-thread", "macros", "net", "io-util", "sync", "time"] } +tokio-rustls = "0.26" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } toml = "0.8" +rustls-pemfile = "2.2" [dev-dependencies] criterion = "0.5" diff --git a/README.md b/README.md index c7f1e96..cc383b7 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,19 @@ An LSM-tree based relational database built in Rust. 
- `tools/lsmdb-cli/` CLI client - `docs/` architecture and component docs +## Testing + +- library tests: `cargo test --lib --locked` +- integration suite gate: `./tools/ci/run_integration_tests.sh` +- full test run: `cargo test --locked` +- details: `docs/testing.md` + +## Release Gate + +- release readiness criteria: `docs/release_gate.md` +- local critical blocker check: `./tools/release/check_critical_blockers.sh <owner>/<repo>` +- CI release gate workflow: `.github/workflows/release-gate.yml` + ## Collaborate You can collaborate on this repository to help build a production-capable database. diff --git a/docs/config_preflight.md b/docs/config_preflight.md new file mode 100644 index 0000000..a5b1439 --- /dev/null +++ b/docs/config_preflight.md @@ -0,0 +1,37 @@ +# Config Preflight + +Use `lsmdb-admin` to validate config files before startup and print resolved runtime diagnostics. + +## Command + +```bash +cargo run --bin lsmdb-admin -- config check --config ./lsmdb.toml +``` + +## Output format + +The command prints stable `key=value` lines to stdout: + +```text +config.check=ok +config.path=./lsmdb.toml +storage.memtable_size_bytes=... +wal.segment_size_bytes=... +compaction.strategy=... +``` + +This output is intended to be parsed by CI pipelines and deployment scripts. + +## Exit codes + +- `0`: validation passed. +- `2`: config load/parse/validation failed. +- `64`: command usage error. + +## Validation highlights + +- Unknown fields are rejected (`deny_unknown_fields`). +- Numeric bounds are enforced (memtable sizing, WAL segment size, bloom FPR). +- Cross-field constraints are enforced: + - `storage.memtable_arena_block_size_bytes <= storage.memtable_size_bytes` + - `storage.flush_timeout_ms >= storage.flush_poll_interval_ms` diff --git a/docs/diagnostics_bundle.md b/docs/diagnostics_bundle.md new file mode 100644 index 0000000..32e762d --- /dev/null +++ b/docs/diagnostics_bundle.md @@ -0,0 +1,43 @@ +# Diagnostics Bundle + +Use `lsmdb-admin` to generate a support bundle for incident triage. + +## Command + +```bash +cargo run --bin lsmdb-admin -- diagnostics bundle \ + --config ./lsmdb.toml \ + --engine-root ./data \ + --output-dir ./diagnostics \ + --log-dir ./logs +``` + +## Bundle contents + +A timestamped bundle directory is created under `--output-dir` with: + +- `build_info.kv`: build and environment metadata. +- `config.redacted.toml`: redacted config snapshot. +- `startup_diagnostics.kv`: resolved runtime config diagnostics (if config loads). +- `storage_snapshot.kv`: WAL/SST/manifest file counts and bytes. +- `logs/`: redacted log captures (optional, size-limited). +- `bundle_manifest.kv`: manifest metadata with version and warning list. +- `checksums.crc32`: integrity checksums for bundle files. + +## Redaction policy + +Lines with keys matching these tokens are redacted: + +- `password` +- `secret` +- `token` +- `credential` +- `api_key` +- `private_key` +- `access_key` + +## Exit codes + +- `0`: bundle generated. +- `2`: command failed. +- `64`: usage error. diff --git a/docs/mvcc.md b/docs/mvcc.md index 7580d60..e3b8b0c 100644 --- a/docs/mvcc.md +++ b/docs/mvcc.md @@ -21,7 +21,7 @@ Transactions read at a fixed `read_ts` and stage writes locally until commit.
- tracks pinned read timestamps - reports oldest active snapshot and active snapshot count - `transaction.rs` (`MvccStore`, `Transaction`) - - in-memory committed version store + - MVCC committed version store with optional durable backing - transaction API + conflict checks + metrics - `gc.rs` - prune obsolete versions behind snapshot watermark @@ -44,12 +44,33 @@ Transactions read at a fixed `read_ts` and stage writes locally until commit. - if write set empty, returns `read_ts` - detects write-write conflicts: - for each written key, if latest committed `commit_ts > read_ts`, abort with conflict error -- otherwise allocates `commit_ts` and appends versions +- otherwise allocates `commit_ts`, persists durable state (durable mode), then acknowledges commit 5. Rollback / drop - discard write buffer - release snapshot pin +## Commit durability contract + +Durable mode defines two explicit points: + +- Durability point: + - committed version state is written to `StorageEngine`. + - after this point, committed data must survive restart. +- Visibility / acknowledgment point: + - `commit()` returns success to caller. + - in-process metrics (`committed`) are incremented. + +Crash behavior: + +- Crash before durability point: + - transaction may be retried; data is not guaranteed persisted. +- Crash after durability point but before acknowledgment: + - data is recovered and visible after restart. + - client may treat commit outcome as unknown and handle idempotently. +- Rollback or uncommitted transaction: + - staged writes are never persisted and are absent after restart. + ## Conflict detection Current policy: @@ -78,11 +99,17 @@ Pruning rule: - rolled_back - write_conflicts - active_transactions +- recovered_keys +- recovered_versions -## Current limitation +## Durability modes -`MvccStore` is currently an in-memory `HashMap`-backed version store. +`MvccStore` supports two modes: -Implication: -- SQL/server transaction behavior is correct for in-process semantics -- committed SQL data is not yet durable across process restart via the LSM storage engine +- In-memory mode (`MvccStore::new()`) + - Uses only in-process `HashMap` state. + - Best for unit tests and fast local execution. +- Durable mode (`MvccStore::open_persistent*()` / `MvccStore::with_storage_engine()`) + - Persists committed MVCC state into `StorageEngine` under an internal key. + - Reloads committed versions and oracle position on restart. + - Allows SQL/catalog state to survive process restarts while preserving MVCC semantics. diff --git a/docs/release_gate.md b/docs/release_gate.md new file mode 100644 index 0000000..99d0e22 --- /dev/null +++ b/docs/release_gate.md @@ -0,0 +1,51 @@ +# Release Gate + +This document defines the production-readiness gate for `lsmdb`. A release is blocked if any +critical gate check fails. + +## Automated Gate + +The release gate workflow (`.github/workflows/release-gate.yml`) enforces the following checks: + +| Category | Check | Pass Condition | +| --- | --- | --- | +| Build | `cargo check --locked` | Succeeds | +| Reliability | `cargo test --lib --locked` | Succeeds | +| Reliability | `./tools/ci/run_integration_tests.sh` | Succeeds | +| Product Risk | `./tools/release/check_critical_blockers.sh` | No open critical `priority/high` issues | + +Critical blocking areas are currently: + +- `area/security` +- `area/recovery` +- `area/release` +- `area/ops` +- `area/server` +- `area/storage` +- `area/performance` + +Any open issue with `priority/high` and one of these area labels blocks release. 
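As a concrete illustration, the sketch below shows the kind of label query such a gate can run using the GitHub CLI (`gh`). This is an assumed shape, not the repository's actual `check_critical_blockers.sh`; the area list mirrors the table above, and any open issue carrying `priority/high` plus one of those labels fails the gate:

```bash
#!/usr/bin/env bash
# Hedged sketch (assumed behavior, not the real check_critical_blockers.sh):
# count open issues labeled priority/high plus any critical area label.
set -euo pipefail
repo="${1:?usage: check_critical_blockers.sh <owner>/<repo>}"
areas=(area/security area/recovery area/release area/ops area/server area/storage area/performance)
blockers=0
for area in "${areas[@]}"; do
  # Repeated --label flags are ANDed by `gh issue list`.
  count=$(gh issue list --repo "$repo" --state open \
    --label "priority/high" --label "$area" --json number --jq 'length')
  blockers=$((blockers + count))
done
if ((blockers > 0)); then
  echo "release blocked: ${blockers} open critical issue(s)"
  exit 1
fi
echo "no critical blockers"
```

An issue tagged with two critical areas is counted twice, which is harmless here: any nonzero count blocks the release.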
+ +## Manual Gate + +Manual sign-off is required in addition to automated checks. + +- Engineering sign-off is required. +- Operations sign-off is required. +- Release notes and upgrade notes must be updated. +- Rollback plan must be documented. + +Use `.github/ISSUE_TEMPLATE/release-signoff.md` for release sign-off records. + +## Readiness Status + +Quick local check: + +```bash +./tools/release/check_critical_blockers.sh ganeshwhere/lsmdb +``` + +CI check: + +- Run the `Release Gate` workflow manually (`workflow_dispatch`) before creating a release tag. +- Any failure in that workflow means release readiness is `NOT READY`. diff --git a/docs/sql_conformance.md b/docs/sql_conformance.md new file mode 100644 index 0000000..e0ea188 --- /dev/null +++ b/docs/sql_conformance.md @@ -0,0 +1,41 @@ +# SQL Conformance Suite + +The SQL conformance suite verifies parser, validator, planner, and executor behavior against the supported subset documented in `docs/sql_subset.md`. + +## Run + +```bash +cargo test --test sql_conformance +``` + +## Fixtures + +Fixture files live in: + +- `tests/conformance/sql/01-core-supported.toml` +- `tests/conformance/sql/02-errors-and-boundaries.toml` + +Each case includes: + +- `id`: stable case identifier. +- `category`: compatibility category. +- `setup_sql`: optional setup statements executed before the case statement. +- `sql`: statement under test. +- `expect`: expected outcome (`affected_rows`, `query`, `transaction_state`, `error_contains`). + +## Compatibility report artifact + +The test emits a machine-readable report at: + +- `target/sql-conformance/report.toml` + +Report fields include: + +- `schema_version` +- `generated_unix_seconds` +- `total_suites`, `total_cases`, `passed_cases`, `failed_cases` +- `covered_categories` +- `suite_summaries` +- `failures` (detailed per-case diagnostics) + +This artifact is intended for release checks and compatibility tracking. diff --git a/docs/sql_subset.md b/docs/sql_subset.md index c408e4d..d3174a2 100644 --- a/docs/sql_subset.md +++ b/docs/sql_subset.md @@ -90,3 +90,11 @@ Supported expression forms: - secondary indexes - prepared statements - SQL-backed persistence through `StorageEngine` (executor currently uses in-memory MVCC store) + +## Conformance coverage + +- Fixture-driven SQL conformance tests live under `tests/conformance/sql/`. +- Run with: `cargo test --test sql_conformance`. +- Machine-readable compatibility report artifact is emitted to: + - `target/sql-conformance/report.toml` +- Additional suite details and fixture schema are documented in `docs/sql_conformance.md`. diff --git a/docs/testing.md b/docs/testing.md new file mode 100644 index 0000000..a0b2f36 --- /dev/null +++ b/docs/testing.md @@ -0,0 +1,33 @@ +# Testing + +This project splits test coverage into: + +- library/unit tests in `src/**` +- integration tests in `tests/integration/*.rs` +- root-level conformance and persistence tests in `tests/*.rs` + +## Commands + +Run library tests: + +```bash +cargo test --lib --locked +``` + +Run integration suite with target-registration checks: + +```bash +./tools/ci/run_integration_tests.sh +``` + +Run full local test pass: + +```bash +cargo test --locked +``` + +## Integration target policy + +Files under `tests/integration/` are intentionally registered explicitly in `Cargo.toml` as `[[test]]` targets (`integration_*` naming). + +`./tools/ci/run_integration_tests.sh` enforces that every integration file is registered, so new integration tests cannot be silently skipped by CI. 
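+
+For illustration, registering a hypothetical `tests/integration/compaction.rs` would add the
+following target to `Cargo.toml` (name and path here are placeholders; keep the
+`integration_*` naming used by the existing targets):
+
+```toml
+# Explicit registration; files under tests/integration/ without a matching
+# [[test]] entry cause ./tools/ci/run_integration_tests.sh to fail.
+[[test]]
+name = "integration_compaction"
+path = "tests/integration/compaction.rs"
+```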
diff --git a/src/catalog/mod.rs b/src/catalog/mod.rs index 660caba..420a427 100644 --- a/src/catalog/mod.rs +++ b/src/catalog/mod.rs @@ -27,8 +27,8 @@ pub enum CatalogError { Transaction(#[from] TransactionError), #[error("failed to serialize table descriptor: {0}")] Serialization(String), - #[error("failed to deserialize table descriptor at key '{key}': {source}")] - Deserialization { key: String, source: String }, + #[error("failed to deserialize table descriptor at key '{key}': {message}")] + Deserialization { key: String, message: String }, } #[derive(Debug)] @@ -55,7 +55,7 @@ impl Catalog { let table_name = normalize_table_name(&descriptor.name)?.to_string(); descriptor.name = table_name.clone(); descriptor.validate()?; - let key = table_storage_key(table_name); + let key = table_storage_key(&table_name); let payload = bincode::serialize(&descriptor) .map_err(|err| CatalogError::Serialization(err.to_string()))?; @@ -72,7 +72,7 @@ impl Catalog { } Err(err @ TransactionError::WriteWriteConflict { .. }) => { self.refresh()?; - if self.tables.read().contains_key(table_name) { + if self.tables.read().contains_key(&table_name) { return Err(CatalogError::TableAlreadyExists(table_name.to_string())); } Err(CatalogError::Transaction(err)) @@ -127,19 +127,19 @@ impl Catalog { let table_name = table_name_from_key(&key) .ok_or_else(|| CatalogError::Deserialization { key: key_string.clone(), - source: "table key does not use catalog table prefix".to_string(), + message: "table key does not use catalog table prefix".to_string(), })? .to_string(); let descriptor: TableDescriptor = bincode::deserialize(&payload).map_err(|err| { - CatalogError::Deserialization { key: key_string.clone(), source: err.to_string() } + CatalogError::Deserialization { key: key_string.clone(), message: err.to_string() } })?; descriptor.validate()?; if descriptor.name != table_name { return Err(CatalogError::Deserialization { key: key_string, - source: format!( + message: format!( "descriptor name '{}' does not match key table '{}'", descriptor.name, table_name ), diff --git a/src/config/mod.rs b/src/config/mod.rs index f0c05d1..5dc7dce 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -5,6 +5,10 @@ use std::time::Duration; use serde::Deserialize; use thiserror::Error; +use crate::server::{ + ServerAuthOptions, ServerLimits, ServerOptions, ServerRole, ServerSecurityOptions, + ServerTlsMode, ServerTlsOptions, StaticPasswordUser, StaticTokenPrincipal, +}; use crate::storage::compaction::{ CompactionStrategy, LeveledCompactionConfig, TieredCompactionConfig, }; @@ -28,26 +32,17 @@ pub enum ConfigError { InvalidValue { field: &'static str, message: String }, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Default)] #[serde(default, deny_unknown_fields)] pub struct LsmdbConfig { pub storage: StorageConfig, + pub server: ServerConfig, + pub security: SecurityConfig, pub wal: WalConfig, pub sstable: SstableConfig, pub compaction: CompactionConfig, } -impl Default for LsmdbConfig { - fn default() -> Self { - Self { - storage: StorageConfig::default(), - wal: WalConfig::default(), - sstable: SstableConfig::default(), - compaction: CompactionConfig::default(), - } - } -} - #[derive(Debug, Clone, Deserialize)] #[serde(default, deny_unknown_fields)] pub struct StorageConfig { @@ -69,6 +64,137 @@ impl Default for StorageConfig { } } +#[derive(Debug, Clone, Deserialize)] +#[serde(default, deny_unknown_fields)] +pub struct ServerConfig { + pub max_concurrent_connections: usize, + pub 
max_in_flight_requests_per_connection: usize,
+    pub max_request_bytes: usize,
+    pub max_statements_per_request: usize,
+    pub statement_timeout_ms: Option<u64>,
+    pub max_memory_intensive_requests: usize,
+    pub max_scan_rows: usize,
+    pub max_sort_rows: usize,
+    pub max_join_rows: usize,
+    pub max_query_result_rows: usize,
+    pub max_query_result_bytes: usize,
+    pub max_concurrent_queries_per_identity: Option<usize>,
+}
+
+impl Default for ServerConfig {
+    fn default() -> Self {
+        let defaults = ServerLimits::default();
+        Self {
+            max_concurrent_connections: defaults.max_concurrent_connections,
+            max_in_flight_requests_per_connection: defaults.max_in_flight_requests_per_connection,
+            max_request_bytes: defaults.max_request_bytes,
+            max_statements_per_request: defaults.max_statements_per_request,
+            statement_timeout_ms: defaults.statement_timeout_ms,
+            max_memory_intensive_requests: defaults.max_memory_intensive_requests,
+            max_scan_rows: defaults.max_scan_rows,
+            max_sort_rows: defaults.max_sort_rows,
+            max_join_rows: defaults.max_join_rows,
+            max_query_result_rows: defaults.max_query_result_rows,
+            max_query_result_bytes: defaults.max_query_result_bytes,
+            max_concurrent_queries_per_identity: defaults.max_concurrent_queries_per_identity,
+        }
+    }
+}
+
+#[derive(Debug, Clone, Deserialize, Default)]
+#[serde(default, deny_unknown_fields)]
+pub struct SecurityConfig {
+    pub auth: SecurityAuthConfig,
+    pub tls: SecurityTlsConfig,
+}
+
+#[derive(Debug, Clone, Deserialize, Default)]
+#[serde(default, deny_unknown_fields)]
+pub struct SecurityAuthConfig {
+    pub mode: AuthModeConfig,
+    pub users: Vec<SecurityUserConfig>,
+    pub tokens: Vec<SecurityTokenConfig>,
+}
+
+#[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq, Default)]
+#[serde(rename_all = "snake_case")]
+pub enum AuthModeConfig {
+    #[default]
+    Disabled,
+    Password,
+    Token,
+}
+
+impl AuthModeConfig {
+    pub fn as_str(self) -> &'static str {
+        match self {
+            AuthModeConfig::Disabled => "disabled",
+            AuthModeConfig::Password => "password",
+            AuthModeConfig::Token => "token",
+        }
+    }
+}
+
+#[derive(Debug, Clone, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub struct SecurityUserConfig {
+    pub username: String,
+    pub password: String,
+    pub role: RoleConfig,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub struct SecurityTokenConfig {
+    pub label: String,
+    pub token: String,
+    pub role: RoleConfig,
+}
+
+#[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq, Default)]
+#[serde(rename_all = "snake_case")]
+pub enum RoleConfig {
+    #[default]
+    Reader,
+    Writer,
+    Admin,
+}
+
+impl From<RoleConfig> for ServerRole {
+    fn from(value: RoleConfig) -> Self {
+        match value {
+            RoleConfig::Reader => ServerRole::Reader,
+            RoleConfig::Writer => ServerRole::Writer,
+            RoleConfig::Admin => ServerRole::Admin,
+        }
+    }
+}
+
+#[derive(Debug, Clone, Deserialize, Default)]
+#[serde(default, deny_unknown_fields)]
+pub struct SecurityTlsConfig {
+    pub mode: TlsModeConfig,
+    pub cert_path: Option<PathBuf>,
+    pub key_path: Option<PathBuf>,
+}
+
+#[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq, Default)]
+#[serde(rename_all = "snake_case")]
+pub enum TlsModeConfig {
+    #[default]
+    Disabled,
+    Required,
+}
+
+impl TlsModeConfig {
+    pub fn as_str(self) -> &'static str {
+        match self {
+            TlsModeConfig::Disabled => "disabled",
+            TlsModeConfig::Required => "required",
+        }
+    }
+}
+
 #[derive(Debug, Clone, Deserialize)]
 #[serde(default, deny_unknown_fields)]
 pub struct WalConfig {
@@ -114,6 +240,16 @@ impl From<SyncMode> for SyncModeConfig {
     }
 }
 
+impl SyncModeConfig {
+    pub fn as_str(self) -> &'static str {
+        match self {
+            SyncModeConfig::Never => "never",
+            SyncModeConfig::OnCommit => "on_commit",
+            SyncModeConfig::Always => "always",
+        }
+    }
+}
+
 #[derive(Debug, Clone, Deserialize)]
 #[serde(default, deny_unknown_fields)]
 pub struct SstableConfig {
@@ -136,7 +272,7 @@ impl Default for SstableConfig {
     }
 }
 
-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, Deserialize, Default)]
 #[serde(default, deny_unknown_fields)]
 pub struct CompactionConfig {
     pub strategy: CompactionMode,
@@ -144,16 +280,6 @@ pub struct CompactionConfig {
     pub tiered: TieredConfig,
 }
 
-impl Default for CompactionConfig {
-    fn default() -> Self {
-        Self {
-            strategy: CompactionMode::default(),
-            leveled: LeveledConfig::default(),
-            tiered: TieredConfig::default(),
-        }
-    }
-}
-
 #[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq, Default)]
 #[serde(rename_all = "snake_case")]
 pub enum CompactionMode {
@@ -206,6 +332,145 @@ impl Default for TieredConfig {
 pub struct RuntimeConfig {
     pub storage_engine: StorageEngineOptions,
     pub compaction_strategy: CompactionStrategy,
+    pub server_limits: ServerLimits,
+    pub server_security: ServerSecurityOptions,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum CompactionDiagnostics {
+    Leveled {
+        level0_file_limit: usize,
+        level_size_base_bytes: u64,
+        level_size_multiplier: u64,
+        max_levels: u32,
+    },
+    Tiered {
+        max_components_per_tier: usize,
+        min_tier_size_bytes: u64,
+        output_level: u32,
+    },
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct StartupDiagnostics {
+    pub memtable_size_bytes: usize,
+    pub memtable_arena_block_size_bytes: usize,
+    pub flush_poll_interval_ms: u64,
+    pub flush_timeout_ms: u64,
+    pub server_max_concurrent_connections: usize,
+    pub server_max_in_flight_requests_per_connection: usize,
+    pub server_max_request_bytes: usize,
+    pub server_max_statements_per_request: usize,
+    pub server_statement_timeout_ms: Option<u64>,
+    pub server_max_memory_intensive_requests: usize,
+    pub server_max_scan_rows: usize,
+    pub server_max_sort_rows: usize,
+    pub server_max_join_rows: usize,
+    pub server_max_query_result_rows: usize,
+    pub server_max_query_result_bytes: usize,
+    pub server_max_concurrent_queries_per_identity: Option<usize>,
+    pub security_auth_mode: AuthModeConfig,
+    pub security_tls_mode: TlsModeConfig,
+    pub security_user_count: usize,
+    pub security_token_count: usize,
+    pub security_tls_cert_path: Option<String>,
+    pub security_tls_key_path: Option<String>,
+    pub wal_segment_size_bytes: u64,
+    pub wal_sync_mode: SyncModeConfig,
+    pub sstable_data_block_size_bytes: usize,
+    pub sstable_restart_interval: usize,
+    pub sstable_bloom_bits_per_key: usize,
+    pub sstable_bloom_hash_functions: u8,
+    pub compaction: CompactionDiagnostics,
+}
+
+impl StartupDiagnostics {
+    pub fn as_key_value_lines(&self) -> Vec<String> {
+        let mut lines = vec![
+            format!("storage.memtable_size_bytes={}", self.memtable_size_bytes),
+            format!(
+                "storage.memtable_arena_block_size_bytes={}",
+                self.memtable_arena_block_size_bytes
+            ),
+            format!("storage.flush_poll_interval_ms={}", self.flush_poll_interval_ms),
+            format!("storage.flush_timeout_ms={}", self.flush_timeout_ms),
+            format!("server.max_concurrent_connections={}", self.server_max_concurrent_connections),
+            format!(
+                "server.max_in_flight_requests_per_connection={}",
+                self.server_max_in_flight_requests_per_connection
+            ),
+            format!("server.max_request_bytes={}", self.server_max_request_bytes),
+            format!("server.max_statements_per_request={}", self.server_max_statements_per_request),
+            format!(
+                "server.statement_timeout_ms={}",
+
format_optional_u64(self.server_statement_timeout_ms) + ), + format!( + "server.max_memory_intensive_requests={}", + self.server_max_memory_intensive_requests + ), + format!("server.max_scan_rows={}", self.server_max_scan_rows), + format!("server.max_sort_rows={}", self.server_max_sort_rows), + format!("server.max_join_rows={}", self.server_max_join_rows), + format!("server.max_query_result_rows={}", self.server_max_query_result_rows), + format!("server.max_query_result_bytes={}", self.server_max_query_result_bytes), + format!( + "server.max_concurrent_queries_per_identity={}", + format_optional_usize(self.server_max_concurrent_queries_per_identity) + ), + format!("security.auth.mode={}", self.security_auth_mode.as_str()), + format!("security.tls.mode={}", self.security_tls_mode.as_str()), + format!("security.auth.users={}", self.security_user_count), + format!("security.auth.tokens={}", self.security_token_count), + format!( + "security.tls.cert_path={}", + format_optional_string(self.security_tls_cert_path.as_deref()) + ), + format!( + "security.tls.key_path={}", + format_optional_string(self.security_tls_key_path.as_deref()) + ), + format!("wal.segment_size_bytes={}", self.wal_segment_size_bytes), + format!("wal.sync_mode={}", self.wal_sync_mode.as_str()), + format!("sstable.data_block_size_bytes={}", self.sstable_data_block_size_bytes), + format!("sstable.restart_interval={}", self.sstable_restart_interval), + format!("sstable.bloom_bits_per_key={}", self.sstable_bloom_bits_per_key), + format!("sstable.bloom_hash_functions={}", self.sstable_bloom_hash_functions), + ]; + + match self.compaction { + CompactionDiagnostics::Leveled { + level0_file_limit, + level_size_base_bytes, + level_size_multiplier, + max_levels, + } => { + lines.push("compaction.strategy=leveled".to_string()); + lines.push(format!("compaction.leveled.level0_file_limit={level0_file_limit}")); + lines.push(format!( + "compaction.leveled.level_size_base_bytes={level_size_base_bytes}" + )); + lines.push(format!( + "compaction.leveled.level_size_multiplier={level_size_multiplier}" + )); + lines.push(format!("compaction.leveled.max_levels={max_levels}")); + } + CompactionDiagnostics::Tiered { + max_components_per_tier, + min_tier_size_bytes, + output_level, + } => { + lines.push("compaction.strategy=tiered".to_string()); + lines.push(format!( + "compaction.tiered.max_components_per_tier={max_components_per_tier}" + )); + lines.push(format!("compaction.tiered.min_tier_size_bytes={min_tier_size_bytes}")); + lines.push(format!("compaction.tiered.output_level={output_level}")); + } + } + + lines + } } impl LsmdbConfig { @@ -229,12 +494,144 @@ impl LsmdbConfig { if self.storage.memtable_arena_block_size_bytes == 0 { return Err(invalid("storage.memtable_arena_block_size_bytes", "must be > 0")); } + if self.storage.memtable_arena_block_size_bytes > self.storage.memtable_size_bytes { + return Err(invalid( + "storage.memtable_arena_block_size_bytes", + "must be <= storage.memtable_size_bytes", + )); + } if self.storage.flush_poll_interval_ms == 0 { return Err(invalid("storage.flush_poll_interval_ms", "must be > 0")); } if self.storage.flush_timeout_ms == 0 { return Err(invalid("storage.flush_timeout_ms", "must be > 0")); } + if self.storage.flush_timeout_ms < self.storage.flush_poll_interval_ms { + return Err(invalid( + "storage.flush_timeout_ms", + "must be >= storage.flush_poll_interval_ms", + )); + } + if self.server.max_concurrent_connections == 0 { + return Err(invalid("server.max_concurrent_connections", "must be > 0")); + } + 
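+        // The server limit checks below share one rule: a zero limit would
+        // reject all work on that resource, so zero is always a config error.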
if self.server.max_in_flight_requests_per_connection == 0 { + return Err(invalid("server.max_in_flight_requests_per_connection", "must be > 0")); + } + if self.server.max_request_bytes == 0 { + return Err(invalid("server.max_request_bytes", "must be > 0")); + } + if self.server.max_statements_per_request == 0 { + return Err(invalid("server.max_statements_per_request", "must be > 0")); + } + if matches!(self.server.statement_timeout_ms, Some(0)) { + return Err(invalid("server.statement_timeout_ms", "must be > 0 when set")); + } + if self.server.max_memory_intensive_requests == 0 { + return Err(invalid("server.max_memory_intensive_requests", "must be > 0")); + } + if self.server.max_scan_rows == 0 { + return Err(invalid("server.max_scan_rows", "must be > 0")); + } + if self.server.max_sort_rows == 0 { + return Err(invalid("server.max_sort_rows", "must be > 0")); + } + if self.server.max_join_rows == 0 { + return Err(invalid("server.max_join_rows", "must be > 0")); + } + if self.server.max_query_result_rows == 0 { + return Err(invalid("server.max_query_result_rows", "must be > 0")); + } + if self.server.max_query_result_bytes == 0 { + return Err(invalid("server.max_query_result_bytes", "must be > 0")); + } + if matches!(self.server.max_concurrent_queries_per_identity, Some(0)) { + return Err(invalid( + "server.max_concurrent_queries_per_identity", + "must be > 0 when set", + )); + } + match self.security.auth.mode { + AuthModeConfig::Disabled => {} + AuthModeConfig::Password => { + if self.security.auth.users.is_empty() { + return Err(invalid( + "security.auth.users", + "must contain at least one user when auth mode is 'password'", + )); + } + let mut seen = std::collections::BTreeSet::new(); + for user in &self.security.auth.users { + if user.username.trim().is_empty() { + return Err(invalid("security.auth.users.username", "must not be empty")); + } + if user.password.is_empty() { + return Err(invalid("security.auth.users.password", "must not be empty")); + } + if !seen.insert(user.username.as_str()) { + return Err(invalid( + "security.auth.users.username", + format!("duplicate username '{}'", user.username), + )); + } + } + } + AuthModeConfig::Token => { + if self.security.auth.tokens.is_empty() { + return Err(invalid( + "security.auth.tokens", + "must contain at least one token when auth mode is 'token'", + )); + } + let mut labels = std::collections::BTreeSet::new(); + let mut tokens = std::collections::BTreeSet::new(); + for token in &self.security.auth.tokens { + if token.label.trim().is_empty() { + return Err(invalid("security.auth.tokens.label", "must not be empty")); + } + if token.token.is_empty() { + return Err(invalid("security.auth.tokens.token", "must not be empty")); + } + if !labels.insert(token.label.as_str()) { + return Err(invalid( + "security.auth.tokens.label", + format!("duplicate token label '{}'", token.label), + )); + } + if !tokens.insert(token.token.as_str()) { + return Err(invalid( + "security.auth.tokens.token", + format!("duplicate token value for '{}'", token.label), + )); + } + } + } + } + if self.security.auth.mode != AuthModeConfig::Disabled + && self.security.tls.mode != TlsModeConfig::Required + { + return Err(invalid( + "security.tls.mode", + "must be 'required' when authentication is enabled", + )); + } + match self.security.tls.mode { + TlsModeConfig::Disabled => {} + TlsModeConfig::Required => { + let cert_path = self.security.tls.cert_path.as_ref().ok_or_else(|| { + invalid("security.tls.cert_path", "must be set when tls mode is 'required'") + })?; + 
let key_path = self.security.tls.key_path.as_ref().ok_or_else(|| {
+                    invalid("security.tls.key_path", "must be set when tls mode is 'required'")
+                })?;
+                if cert_path.as_os_str().is_empty() {
+                    return Err(invalid("security.tls.cert_path", "must not be empty"));
+                }
+                if key_path.as_os_str().is_empty() {
+                    return Err(invalid("security.tls.key_path", "must not be empty"));
+                }
+            }
+        }
         if self.wal.segment_size_bytes < MIN_WAL_SEGMENT_SIZE_BYTES {
             return Err(invalid(
                 "wal.segment_size_bytes",
@@ -272,12 +669,80 @@ impl LsmdbConfig {
         Ok(())
     }
 
+    pub fn startup_diagnostics(&self) -> Result<StartupDiagnostics, ConfigError> {
+        let runtime = self.to_runtime_config()?;
+        let storage = runtime.storage_engine;
+        let compaction = match runtime.compaction_strategy {
+            CompactionStrategy::Leveled(config) => CompactionDiagnostics::Leveled {
+                level0_file_limit: config.level0_file_limit,
+                level_size_base_bytes: config.level_size_base_bytes,
+                level_size_multiplier: config.level_size_multiplier,
+                max_levels: config.max_levels,
+            },
+            CompactionStrategy::Tiered(config) => CompactionDiagnostics::Tiered {
+                max_components_per_tier: config.max_components_per_tier,
+                min_tier_size_bytes: config.min_tier_size_bytes,
+                output_level: config.output_level,
+            },
+        };
+
+        Ok(StartupDiagnostics {
+            memtable_size_bytes: storage.memtable_size_bytes,
+            memtable_arena_block_size_bytes: storage.memtable_arena_block_size_bytes,
+            flush_poll_interval_ms: storage.flush_poll_interval.as_millis() as u64,
+            flush_timeout_ms: storage.flush_timeout.as_millis() as u64,
+            server_max_concurrent_connections: runtime.server_limits.max_concurrent_connections,
+            server_max_in_flight_requests_per_connection: runtime
+                .server_limits
+                .max_in_flight_requests_per_connection,
+            server_max_request_bytes: runtime.server_limits.max_request_bytes,
+            server_max_statements_per_request: runtime.server_limits.max_statements_per_request,
+            server_statement_timeout_ms: runtime.server_limits.statement_timeout_ms,
+            server_max_memory_intensive_requests: runtime
+                .server_limits
+                .max_memory_intensive_requests,
+            server_max_scan_rows: runtime.server_limits.max_scan_rows,
+            server_max_sort_rows: runtime.server_limits.max_sort_rows,
+            server_max_join_rows: runtime.server_limits.max_join_rows,
+            server_max_query_result_rows: runtime.server_limits.max_query_result_rows,
+            server_max_query_result_bytes: runtime.server_limits.max_query_result_bytes,
+            server_max_concurrent_queries_per_identity: runtime
+                .server_limits
+                .max_concurrent_queries_per_identity,
+            security_auth_mode: self.security.auth.mode,
+            security_tls_mode: self.security.tls.mode,
+            security_user_count: self.security.auth.users.len(),
+            security_token_count: self.security.auth.tokens.len(),
+            security_tls_cert_path: self
+                .security
+                .tls
+                .cert_path
+                .as_ref()
+                .map(|path| path.display().to_string()),
+            security_tls_key_path: self
+                .security
+                .tls
+                .key_path
+                .as_ref()
+                .map(|path| path.display().to_string()),
+            wal_segment_size_bytes: storage.wal_options.segment_size_bytes,
+            wal_sync_mode: SyncModeConfig::from(storage.wal_options.sync_mode),
+            sstable_data_block_size_bytes: storage.sstable_builder_options.data_block_size_bytes,
+            sstable_restart_interval: storage.sstable_builder_options.restart_interval,
+            sstable_bloom_bits_per_key: storage.sstable_builder_options.bloom_bits_per_key,
+            sstable_bloom_hash_functions: storage.sstable_builder_options.bloom_hash_functions,
+            compaction,
+        })
+    }
+
     pub fn to_runtime_config(&self) -> Result<RuntimeConfig, ConfigError> {
         self.validate()?;
 
         Ok(RuntimeConfig {
             storage_engine: self.to_storage_engine_options_unchecked(),
             compaction_strategy: self.to_compaction_strategy_unchecked(),
+            server_limits: self.to_server_limits_unchecked(),
+            server_security: self.to_server_security_options_unchecked(),
         })
     }
 
@@ -291,6 +756,24 @@ impl LsmdbConfig {
         Ok(self.to_compaction_strategy_unchecked())
     }
 
+    pub fn to_server_limits(&self) -> Result<ServerLimits, ConfigError> {
+        self.validate()?;
+        Ok(self.to_server_limits_unchecked())
+    }
+
+    pub fn to_server_security_options(&self) -> Result<ServerSecurityOptions, ConfigError> {
+        self.validate()?;
+        Ok(self.to_server_security_options_unchecked())
+    }
+
+    pub fn to_server_options(&self) -> Result<ServerOptions, ConfigError> {
+        self.validate()?;
+        Ok(ServerOptions {
+            limits: self.to_server_limits_unchecked(),
+            security: self.to_server_security_options_unchecked(),
+        })
+    }
+
     fn to_storage_engine_options_unchecked(&self) -> StorageEngineOptions {
         let (bits_per_key, hash_functions) = bloom_params_for_fpr(self.sstable.bloom_fpr);
 
@@ -328,12 +811,85 @@ impl LsmdbConfig {
             }),
         }
     }
+
+    fn to_server_limits_unchecked(&self) -> ServerLimits {
+        ServerLimits {
+            max_concurrent_connections: self.server.max_concurrent_connections,
+            max_in_flight_requests_per_connection: self
+                .server
+                .max_in_flight_requests_per_connection,
+            max_request_bytes: self.server.max_request_bytes,
+            max_statements_per_request: self.server.max_statements_per_request,
+            statement_timeout_ms: self.server.statement_timeout_ms,
+            max_memory_intensive_requests: self.server.max_memory_intensive_requests,
+            max_scan_rows: self.server.max_scan_rows,
+            max_sort_rows: self.server.max_sort_rows,
+            max_join_rows: self.server.max_join_rows,
+            max_query_result_rows: self.server.max_query_result_rows,
+            max_query_result_bytes: self.server.max_query_result_bytes,
+            max_concurrent_queries_per_identity: self.server.max_concurrent_queries_per_identity,
+        }
+    }
+
+    fn to_server_security_options_unchecked(&self) -> ServerSecurityOptions {
+        let auth = match self.security.auth.mode {
+            AuthModeConfig::Disabled => ServerAuthOptions::Disabled,
+            AuthModeConfig::Password => ServerAuthOptions::StaticPassword {
+                users: self
+                    .security
+                    .auth
+                    .users
+                    .iter()
+                    .map(|user| StaticPasswordUser {
+                        username: user.username.clone(),
+                        password: user.password.clone(),
+                        role: user.role.into(),
+                    })
+                    .collect(),
+            },
+            AuthModeConfig::Token => ServerAuthOptions::StaticToken {
+                principals: self
+                    .security
+                    .auth
+                    .tokens
+                    .iter()
+                    .map(|token| StaticTokenPrincipal {
+                        label: token.label.clone(),
+                        token: token.token.clone(),
+                        role: token.role.into(),
+                    })
+                    .collect(),
+            },
+        };
+        let tls = ServerTlsOptions {
+            mode: match self.security.tls.mode {
+                TlsModeConfig::Disabled => ServerTlsMode::Disabled,
+                TlsModeConfig::Required => ServerTlsMode::Required,
+            },
+            cert_path: self.security.tls.cert_path.clone(),
+            key_path: self.security.tls.key_path.clone(),
+        };
+
+        ServerSecurityOptions { auth, tls, allow_anonymous_access: false }
+    }
 }
 
 fn invalid(field: &'static str, message: impl Into<String>) -> ConfigError {
     ConfigError::InvalidValue { field, message: message.into() }
 }
 
+fn format_optional_u64(value: Option<u64>) -> String {
+    value.map(|value| value.to_string()).unwrap_or_else(|| "none".to_string())
+}
+
+fn format_optional_usize(value: Option<usize>) -> String {
+    value.map(|value| value.to_string()).unwrap_or_else(|| "none".to_string())
+}
+
+fn format_optional_string(value: Option<&str>) -> String {
+    value.map(str::to_string).unwrap_or_else(|| "none".to_string())
+}
+
 fn bloom_params_for_fpr(fpr: f64) -> (usize, u8) {
     let ln2 = std::f64::consts::LN_2;
     let bits = (-(fpr.ln()) / (ln2 *
ln2)).ceil().max(1.0) as usize; @@ -354,6 +910,11 @@ mod tests { use super::*; + const TLS_CERT_PATH: &str = + concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/tls/server.crt"); + const TLS_KEY_PATH: &str = + concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/tls/server.key"); + fn temp_file_path(label: &str) -> PathBuf { let mut path = std::env::temp_dir(); let nanos = SystemTime::now() @@ -371,6 +932,7 @@ mod tests { let runtime = config.to_runtime_config().expect("runtime config"); assert!(runtime.storage_engine.memtable_size_bytes > 0); + assert!(runtime.server_limits.max_concurrent_connections > 0); match runtime.compaction_strategy { CompactionStrategy::Leveled(_) => {} CompactionStrategy::Tiered(_) => panic!("expected leveled strategy by default"), @@ -386,6 +948,17 @@ mod tests { flush_poll_interval_ms = 15 flush_timeout_ms = 2000 + [server] + max_concurrent_connections = 32 + max_in_flight_requests_per_connection = 1 + max_request_bytes = 65536 + max_statements_per_request = 8 + max_memory_intensive_requests = 4 + max_scan_rows = 2048 + max_sort_rows = 1024 + max_join_rows = 512 + max_query_result_rows = 256 + [wal] segment_size_bytes = 16777216 sync_mode = "always" @@ -406,7 +979,11 @@ mod tests { let config = LsmdbConfig::from_toml_str(raw).expect("parse custom config"); let options = config.to_storage_engine_options().expect("storage options"); + let server_limits = config.to_server_limits().expect("server limits"); assert_eq!(options.memtable_size_bytes, 1_048_576); + assert_eq!(server_limits.max_concurrent_connections, 32); + assert_eq!(server_limits.max_request_bytes, 65_536); + assert_eq!(server_limits.max_join_rows, 512); assert_eq!(options.wal_options.segment_size_bytes, 16_777_216); assert_eq!(options.wal_options.sync_mode, SyncMode::Always); assert_eq!(options.sstable_builder_options.data_block_size_bytes, 8192); @@ -425,6 +1002,40 @@ mod tests { } } + #[test] + fn parses_and_maps_security_config() { + let raw = format!( + r#" + [security.auth] + mode = "password" + + [[security.auth.users]] + username = "admin" + password = "secret" + role = "admin" + + [security.tls] + mode = "required" + cert_path = "{TLS_CERT_PATH}" + key_path = "{TLS_KEY_PATH}" + "# + ); + + let config = LsmdbConfig::from_toml_str(&raw).expect("parse security config"); + let options = config.to_server_options().expect("server options"); + match options.security.auth { + ServerAuthOptions::StaticPassword { users } => { + assert_eq!(users.len(), 1); + assert_eq!(users[0].username, "admin"); + assert_eq!(users[0].role, ServerRole::Admin); + } + other => panic!("expected static password auth, got {other:?}"), + } + assert_eq!(options.security.tls.mode, ServerTlsMode::Required); + assert_eq!(options.security.tls.cert_path.as_deref(), Some(Path::new(TLS_CERT_PATH))); + assert_eq!(options.security.tls.key_path.as_deref(), Some(Path::new(TLS_KEY_PATH))); + } + #[test] fn rejects_invalid_bloom_fpr() { let raw = r#" @@ -455,4 +1066,159 @@ mod tests { fs::remove_file(path).expect("cleanup temp config"); } + + #[test] + fn rejects_arena_block_larger_than_memtable() { + let raw = r#" + [storage] + memtable_size_bytes = 4096 + memtable_arena_block_size_bytes = 8192 + "#; + + let err = LsmdbConfig::from_toml_str(raw).expect_err("invalid arena block size"); + assert!(matches!( + err, + ConfigError::InvalidValue { field, .. 
} + if field == "storage.memtable_arena_block_size_bytes" + )); + } + + #[test] + fn rejects_flush_timeout_smaller_than_poll_interval() { + let raw = r#" + [storage] + flush_poll_interval_ms = 100 + flush_timeout_ms = 50 + "#; + + let err = LsmdbConfig::from_toml_str(raw).expect_err("invalid flush timing"); + assert!( + matches!(err, ConfigError::InvalidValue { field, .. } if field == "storage.flush_timeout_ms") + ); + } + + #[test] + fn rejects_zero_server_connection_limit() { + let raw = r#" + [server] + max_concurrent_connections = 0 + "#; + + let err = LsmdbConfig::from_toml_str(raw).expect_err("invalid server limit"); + assert!( + matches!(err, ConfigError::InvalidValue { field, .. } if field == "server.max_concurrent_connections") + ); + } + + #[test] + fn rejects_password_auth_without_users() { + let raw = r#" + [security.auth] + mode = "password" + "#; + + let err = LsmdbConfig::from_toml_str(raw).expect_err("missing users should fail"); + assert!( + matches!(err, ConfigError::InvalidValue { field, .. } if field == "security.auth.users") + ); + } + + #[test] + fn rejects_password_auth_without_required_tls() { + let raw = r#" + [security.auth] + mode = "password" + + [[security.auth.users]] + username = "admin" + password = "secret" + role = "admin" + "#; + + let err = LsmdbConfig::from_toml_str(raw).expect_err("password auth should require tls"); + assert!( + matches!(err, ConfigError::InvalidValue { field, .. } if field == "security.tls.mode") + ); + } + + #[test] + fn rejects_tls_required_without_key_path() { + let raw = r#" + [security.tls] + mode = "required" + cert_path = "./server.crt" + "#; + + let err = LsmdbConfig::from_toml_str(raw).expect_err("missing key should fail"); + assert!( + matches!(err, ConfigError::InvalidValue { field, .. 
} if field == "security.tls.key_path")
+        );
+    }
+
+    #[test]
+    fn emits_startup_diagnostics_for_runtime_config() {
+        let raw = format!(
+            r#"
+            [storage]
+            memtable_size_bytes = 8192
+            memtable_arena_block_size_bytes = 4096
+            flush_poll_interval_ms = 25
+            flush_timeout_ms = 100
+
+            [server]
+            max_concurrent_connections = 24
+            max_in_flight_requests_per_connection = 1
+            max_request_bytes = 32768
+            max_statements_per_request = 4
+            max_memory_intensive_requests = 2
+            max_scan_rows = 128
+            max_sort_rows = 64
+            max_join_rows = 32
+            max_query_result_rows = 16
+
+            [wal]
+            segment_size_bytes = 4096
+            sync_mode = "on_commit"
+
+            [security.auth]
+            mode = "token"
+
+            [[security.auth.tokens]]
+            label = "ops-bot"
+            token = "opaque"
+            role = "writer"
+
+            [security.tls]
+            mode = "required"
+            cert_path = "{TLS_CERT_PATH}"
+            key_path = "{TLS_KEY_PATH}"
+
+            [compaction]
+            strategy = "leveled"
+            "#
+        );
+
+        let config = LsmdbConfig::from_toml_str(&raw).expect("parse valid config");
+        let diagnostics = config.startup_diagnostics().expect("startup diagnostics");
+        assert_eq!(diagnostics.memtable_size_bytes, 8192);
+        assert_eq!(diagnostics.memtable_arena_block_size_bytes, 4096);
+        assert_eq!(diagnostics.flush_poll_interval_ms, 25);
+        assert_eq!(diagnostics.flush_timeout_ms, 100);
+        assert_eq!(diagnostics.server_max_concurrent_connections, 24);
+        assert_eq!(diagnostics.server_max_request_bytes, 32_768);
+        assert_eq!(diagnostics.server_max_query_result_rows, 16);
+        assert_eq!(diagnostics.security_auth_mode, AuthModeConfig::Token);
+        assert_eq!(diagnostics.security_tls_mode, TlsModeConfig::Required);
+        assert_eq!(diagnostics.security_user_count, 0);
+        assert_eq!(diagnostics.security_token_count, 1);
+        assert_eq!(diagnostics.wal_segment_size_bytes, 4096);
+        assert_eq!(diagnostics.wal_sync_mode, SyncModeConfig::OnCommit);
+
+        let lines = diagnostics.as_key_value_lines();
+        assert!(lines.iter().any(|line| line == "server.max_concurrent_connections=24"));
+        assert!(lines.iter().any(|line| line == "security.auth.mode=token"));
+        assert!(lines.iter().any(|line| line == "security.tls.mode=required"));
+        assert!(lines.iter().any(|line| line == "compaction.strategy=leveled"));
+        assert!(lines.iter().any(|line| line.starts_with("sstable.bloom_bits_per_key=")));
+    }
 }
diff --git a/src/executor/delete.rs b/src/executor/delete.rs
index 60204fd..55f3a41 100644
--- a/src/executor/delete.rs
+++ b/src/executor/delete.rs
@@ -2,31 +2,36 @@ use crate::catalog::Catalog;
 use crate::mvcc::Transaction;
 use crate::planner::DeleteNode;
 
-use super::ExecutionError;
 use super::filter::evaluate_predicate;
 use super::scan::scan_table_rows;
+use super::{ExecutionContext, ExecutionError, apply_staged_writes};
 
 pub(crate) fn execute_delete(
     catalog: &Catalog,
     tx: &mut Transaction,
     node: &DeleteNode,
+    context: &ExecutionContext<'_>,
 ) -> Result<u64, ExecutionError> {
+    context.checkpoint()?;
     let table = catalog
         .get_table(&node.table)
         .ok_or_else(|| ExecutionError::TableNotFound(node.table.clone()))?;
 
-    let (_, rows) = scan_table_rows(catalog, tx, &table.name)?;
+    let (_, rows) = scan_table_rows(catalog, tx, &table.name, usize::MAX, context)?;
 
     let mut affected = 0_u64;
+    let mut staged = std::collections::BTreeMap::new();
     for stored in rows {
+        context.checkpoint()?;
         if let Some(predicate) = &node.predicate {
             if !evaluate_predicate(predicate, &stored.values, Some(&table.name))? {
                 continue;
             }
         }
 
-        tx.delete(&stored.key)?;
+        staged.insert(stored.key, None);
         affected = affected.saturating_add(1);
     }
 
+    apply_staged_writes(tx, staged)?;
+
     Ok(affected)
 }
diff --git a/src/executor/filter.rs b/src/executor/filter.rs
index 534d5f4..b5e4e90 100644
--- a/src/executor/filter.rs
+++ b/src/executor/filter.rs
@@ -1,18 +1,24 @@
 use crate::sql::ast::{BinaryOp, Expr, UnaryOp};
 
-use super::{literal_to_scalar, scalar_type_name, ExecutionError, Row, RowSet, ScalarValue};
-
-pub(crate) fn apply_filter(input: RowSet, predicate: &Expr) -> Result<RowSet, ExecutionError> {
+use super::{
+    ExecutionContext, ExecutionError, Row, RowSet, ScalarValue, literal_to_scalar, scalar_type_name,
+};
+
+pub(crate) fn apply_filter(
+    input: RowSet,
+    predicate: &Expr,
+    context: &ExecutionContext<'_>,
+) -> Result<RowSet, ExecutionError> {
     let RowSet { columns, rows, table_name } = input;
-    let filtered_rows = rows
-        .into_iter()
-        .filter_map(|row| match evaluate_predicate(predicate, &row, table_name.as_deref()) {
-            Ok(true) => Some(Ok(row)),
-            Ok(false) => None,
-            Err(err) => Some(Err(err)),
-        })
-        .collect::<Result<Vec<_>, _>>()?;
+    let mut filtered_rows = Vec::new();
+    for row in rows {
+        context.checkpoint()?;
+        match evaluate_predicate(predicate, &row, table_name.as_deref())? {
+            true => filtered_rows.push(row),
+            false => {}
+        }
+    }
 
     Ok(RowSet { columns, rows: filtered_rows, table_name })
 }
@@ -312,8 +318,17 @@ fn as_numeric(value: &ScalarValue) -> Result<f64, ExecutionError> {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::executor::ExecutionLimits;
+    use crate::executor::governance::ExecutionGovernance;
     use crate::sql::ast::LiteralValue;
 
+    fn test_context() -> ExecutionContext<'static> {
+        ExecutionContext {
+            limits: Box::leak(Box::new(ExecutionLimits::default())),
+            governance: Box::leak(Box::new(ExecutionGovernance::default())),
+        }
+    }
+
     #[test]
     fn evaluates_arithmetic_expression() {
         let expr = Expr::Binary {
@@ -353,7 +368,7 @@ mod tests {
             right: Box::new(Expr::Literal(LiteralValue::Integer(2))),
         };
 
-        let filtered = apply_filter(row_set, &predicate).expect("filter");
+        let filtered = apply_filter(row_set, &predicate, &test_context()).expect("filter");
         assert_eq!(filtered.rows.len(), 1);
     }
 }
diff --git a/src/executor/governance.rs b/src/executor/governance.rs
new file mode 100644
index 0000000..5539ace
--- /dev/null
+++ b/src/executor/governance.rs
@@ -0,0 +1,125 @@
+use std::sync::Arc;
+use std::sync::atomic::{AtomicU8, Ordering};
+use std::time::{Duration, Instant};
+
+use super::ExecutionError;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum CancellationReason {
+    UserRequested = 1,
+}
+
+impl CancellationReason {
+    pub fn as_str(self) -> &'static str {
+        match self {
+            CancellationReason::UserRequested => "user requested cancellation",
+        }
+    }
+}
+
+#[derive(Debug, Default)]
+struct CancellationState {
+    reason: AtomicU8,
+}
+
+#[derive(Debug, Clone)]
+pub struct StatementCancellation {
+    state: Arc<CancellationState>,
+}
+
+impl Default for StatementCancellation {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl StatementCancellation {
+    pub fn new() -> Self {
+        Self { state: Arc::new(CancellationState::default()) }
+    }
+
+    pub fn cancel(&self) -> bool {
+        self.state
+            .reason
+            .compare_exchange(
+                0,
+                CancellationReason::UserRequested as u8,
+                Ordering::SeqCst,
+                Ordering::SeqCst,
+            )
+            .is_ok()
+    }
+
+    pub fn reason(&self) -> Option<CancellationReason> {
+        match self.state.reason.load(Ordering::SeqCst) {
+            0 => None,
+            1 => Some(CancellationReason::UserRequested),
+            other => {
+                debug_assert_eq!(other, CancellationReason::UserRequested as u8);
+                Some(CancellationReason::UserRequested)
+            }
+        }
+    }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct StatementDeadline {
+    deadline: Instant,
+    timeout: Duration,
+}
+
+impl StatementDeadline {
+    pub fn after(timeout: Duration) -> Self {
+        Self { deadline: Instant::now() + timeout, timeout }
+    }
+
+    pub fn is_elapsed(self) -> bool {
+        Instant::now() >= self.deadline
+    }
+
+    pub fn timeout_ms(self) -> u64 {
+        let millis = self.timeout.as_millis();
+        u64::try_from(millis).unwrap_or(u64::MAX)
+    }
+}
+
+#[derive(Debug, Clone, Default)]
+pub struct ExecutionGovernance {
+    deadline: Option<StatementDeadline>,
+    cancellation: Option<StatementCancellation>,
+}
+
+impl ExecutionGovernance {
+    pub fn with_timeout(mut self, timeout: Duration) -> Self {
+        self.deadline = Some(StatementDeadline::after(timeout));
+        self
+    }
+
+    pub fn with_deadline(mut self, deadline: StatementDeadline) -> Self {
+        self.deadline = Some(deadline);
+        self
+    }
+
+    pub fn with_cancellation(mut self, cancellation: StatementCancellation) -> Self {
+        self.cancellation = Some(cancellation);
+        self
+    }
+
+    pub fn checkpoint(&self) -> Result<(), ExecutionError> {
+        if let Some(cancellation) = &self.cancellation {
+            if let Some(reason) = cancellation.reason() {
+                return Err(ExecutionError::StatementCanceled { reason: reason.as_str() });
+            }
+        }
+
+        if let Some(deadline) = self.deadline {
+            if deadline.is_elapsed() {
+                return Err(ExecutionError::StatementTimedOut {
+                    timeout_ms: deadline.timeout_ms(),
+                });
+            }
+        }
+
+        Ok(())
+    }
+}
diff --git a/src/executor/insert.rs b/src/executor/insert.rs
index af70907..047e99b 100644
--- a/src/executor/insert.rs
+++ b/src/executor/insert.rs
@@ -6,14 +6,17 @@ use crate::planner::InsertNode;
 
 use super::filter::evaluate_const_expr;
 use super::{
-    ExecutionError, Row, build_row_key, coerce_row_for_table, coerce_scalar_for_column, encode_row,
+    ExecutionContext, ExecutionError, Row, apply_staged_writes, build_row_key,
+    coerce_row_for_table, coerce_scalar_for_column, encode_row, staged_value_for_key,
 };
 
 pub(crate) fn execute_insert(
     catalog: &Catalog,
     tx: &mut Transaction,
     node: &InsertNode,
+    context: &ExecutionContext<'_>,
 ) -> Result<u64, ExecutionError> {
+    context.checkpoint()?;
     let table = catalog
         .get_table(&node.table)
         .ok_or_else(|| ExecutionError::TableNotFound(node.table.clone()))?;
@@ -29,6 +32,7 @@ pub(crate) fn execute_insert(
     let mut seen = HashSet::new();
     for (column_name, expr) in node.columns.iter().zip(node.values.iter()) {
+        context.checkpoint()?;
         if !seen.insert(column_name.clone()) {
             return Err(ExecutionError::DuplicateColumn(column_name.clone()));
         }
@@ -44,12 +48,14 @@ pub(crate) fn execute_insert(
     let normalized = coerce_row_for_table(&table, &row)?;
     let key = build_row_key(&table, &normalized)?;
 
-    if tx.get(&key)?.is_some() {
+    let mut staged = std::collections::BTreeMap::new();
+    if staged_value_for_key(tx, &staged, &key)?.is_some() {
         return Err(ExecutionError::PrimaryKeyConflict { table: table.name.clone() });
     }
 
     let payload = encode_row(&table, &normalized)?;
-    tx.put(&key, &payload)?;
+    staged.insert(key, Some(payload));
+    apply_staged_writes(tx, staged)?;
 
     Ok(1)
 }
@@ -59,9 +65,18 @@ mod tests {
     use crate::catalog::column::ColumnDescriptor;
     use crate::catalog::schema::{ColumnType, DefaultValue};
     use crate::catalog::table::TableDescriptor;
+    use crate::executor::ExecutionLimits;
+    use crate::executor::governance::ExecutionGovernance;
     use crate::mvcc::MvccStore;
     use crate::sql::ast::{Expr, LiteralValue};
 
+    fn test_context() -> ExecutionContext<'static> {
+        ExecutionContext {
+            limits: Box::leak(Box::new(ExecutionLimits::default())),
+            governance: Box::leak(Box::new(ExecutionGovernance::default())),
+        }
+    }
+
     #[test]
     fn inserts_row_with_default_values() {
         let store = MvccStore::new();
@@ -94,7 +109,7 @@ mod tests {
         };
 
         let mut tx = store.begin_transaction();
-        let affected = execute_insert(&catalog, &mut tx, &node).expect("insert");
+        let affected = execute_insert(&catalog, &mut tx, &node, &test_context()).expect("insert");
         assert_eq!(affected, 1);
     }
 }
diff --git a/src/executor/join.rs b/src/executor/join.rs
index ee56337..84e1c84 100644
--- a/src/executor/join.rs
+++ b/src/executor/join.rs
@@ -1,22 +1,31 @@
 use crate::sql::ast::Expr;
 
 use super::filter::evaluate_predicate;
-use super::{ExecutionError, Row, RowSet};
+use super::{ExecutionContext, ExecutionError, Row, RowSet};
 
 pub(crate) fn execute_join(
     left: RowSet,
     right: RowSet,
     predicate: &Expr,
+    context: &ExecutionContext<'_>,
 ) -> Result<RowSet, ExecutionError> {
     let right_prefix = right.table_name.clone().unwrap_or_else(|| "right".to_string());
     let output_columns = build_join_columns(&left.columns, &right.columns, &right_prefix);
+    context.limits.ensure_join_rows(left.rows.len())?;
+    context.limits.ensure_join_rows(right.rows.len())?;
+
+    let candidate_pairs = left.rows.len().saturating_mul(right.rows.len());
+    context.limits.ensure_join_rows(candidate_pairs)?;
 
     let mut joined_rows = Vec::new();
     for left_row in &left.rows {
+        context.checkpoint()?;
         for right_row in &right.rows {
+            context.checkpoint()?;
             let merged = merge_rows(left_row, right_row, &right_prefix);
             if evaluate_predicate(predicate, &merged, None)? {
                 joined_rows.push(merged);
+                context.limits.ensure_join_rows(joined_rows.len())?;
             }
         }
     }
@@ -50,10 +59,21 @@ fn merge_rows(left: &Row, right: &Row, right_prefix: &str) -> Row {
 
 #[cfg(test)]
 mod tests {
+    use std::thread;
+    use std::time::Duration;
+
     use super::*;
-    use crate::executor::ScalarValue;
+    use crate::executor::governance::{ExecutionGovernance, StatementCancellation};
+    use crate::executor::{ExecutionLimits, ScalarValue};
     use crate::sql::ast::{BinaryOp, LiteralValue};
 
+    fn test_context(limits: ExecutionLimits) -> ExecutionContext<'static> {
+        ExecutionContext {
+            limits: Box::leak(Box::new(limits)),
+            governance: Box::leak(Box::new(ExecutionGovernance::default())),
+        }
+    }
+
     #[test]
     fn joins_rows_with_predicate() {
         let mut left_row = Row::new();
@@ -81,8 +101,96 @@ mod tests {
             right: Box::new(Expr::Literal(LiteralValue::Integer(1))),
         };
 
-        let joined = execute_join(left, right, &predicate).expect("join");
+        let joined =
+            execute_join(left, right, &predicate, &test_context(ExecutionLimits::default()))
+                .expect("join");
         assert_eq!(joined.rows.len(), 1);
         assert!(joined.rows[0].contains_key("profiles.id"));
     }
+
+    #[test]
+    fn rejects_join_that_exceeds_row_limit() {
+        let mut left_row_a = Row::new();
+        left_row_a.insert("id".to_string(), ScalarValue::BigInt(1));
+        let mut left_row_b = Row::new();
+        left_row_b.insert("id".to_string(), ScalarValue::BigInt(2));
+
+        let mut right_row_a = Row::new();
+        right_row_a.insert("id".to_string(), ScalarValue::BigInt(1));
+        let mut right_row_b = Row::new();
+        right_row_b.insert("id".to_string(), ScalarValue::BigInt(2));
+
+        let left = RowSet {
+            columns: vec!["id".to_string()],
+            rows: vec![left_row_a, left_row_b],
+            table_name: Some("users".to_string()),
+        };
+        let right = RowSet {
+            columns: vec!["id".to_string()],
+            rows: vec![right_row_a, right_row_b],
+            table_name: Some("profiles".to_string()),
+        };
+
+        let predicate = Expr::Binary {
+            left: Box::new(Expr::Identifier("id".to_string())),
+            op: BinaryOp::Equal,
+            right: Box::new(Expr::Literal(LiteralValue::Integer(1))),
+        };
+
+        let err = execute_join(
+            left,
+            right,
+            &predicate,
+            &test_context(ExecutionLimits { max_join_rows: 3, ..ExecutionLimits::default() }),
+        )
+        .expect_err("join limit should fail");
+        assert!(matches!(
+            err,
+            ExecutionError::ResourceLimitExceeded { resource, limit, .. }
+                if resource == "join rows" && limit == 3
+        ));
+    }
+
+    #[test]
+    fn cancels_long_running_join() {
+        let mut left_rows = Vec::new();
+        let mut right_rows = Vec::new();
+        for id in 0..500_i64 {
+            let mut row = Row::new();
+            row.insert("id".to_string(), ScalarValue::BigInt(id));
+            left_rows.push(row.clone());
+            right_rows.push(row);
+        }
+
+        let left = RowSet {
+            columns: vec!["id".to_string()],
+            rows: left_rows,
+            table_name: Some("users".to_string()),
+        };
+        let right = RowSet {
+            columns: vec!["id".to_string()],
+            rows: right_rows,
+            table_name: Some("profiles".to_string()),
+        };
+
+        let cancellation = StatementCancellation::new();
+        let worker_cancellation = cancellation.clone();
+        let handle = thread::spawn(move || {
+            let limits = Box::leak(Box::new(ExecutionLimits::default()));
+            let governance = Box::leak(Box::new(
+                ExecutionGovernance::default().with_cancellation(worker_cancellation),
+            ));
+            execute_join(
+                left,
+                right,
+                &Expr::Literal(LiteralValue::Boolean(true)),
+                &ExecutionContext { limits, governance },
+            )
+        });
+
+        thread::sleep(Duration::from_millis(1));
+        cancellation.cancel();
+        let err = handle.join().expect("join thread").expect_err("join should be canceled");
+        assert!(matches!(err, ExecutionError::StatementCanceled { .. }));
+    }
 }
diff --git a/src/executor/mod.rs b/src/executor/mod.rs
index 54f14a3..cd83eba 100644
--- a/src/executor/mod.rs
+++ b/src/executor/mod.rs
@@ -1,5 +1,6 @@
 pub mod delete;
 pub mod filter;
+pub mod governance;
 pub mod insert;
 pub mod join;
 pub mod projection;
@@ -18,6 +19,8 @@ use crate::mvcc::{MvccStore, Transaction, TransactionError};
 use crate::planner::PhysicalPlan;
 use crate::sql::ast::LiteralValue;
 
+use self::governance::ExecutionGovernance;
+
 #[derive(Debug, Clone, PartialEq)]
 pub enum ScalarValue {
     Integer(i32),
@@ -34,6 +37,19 @@ impl ScalarValue {
     pub fn is_null(&self) -> bool {
         matches!(self, ScalarValue::Null)
     }
+
+    fn estimated_query_bytes(&self) -> usize {
+        match self {
+            ScalarValue::Integer(_) => std::mem::size_of::<i32>(),
+            ScalarValue::BigInt(_) => std::mem::size_of::<i64>(),
+            ScalarValue::Float(_) => std::mem::size_of::<f64>(),
+            ScalarValue::Text(value) => value.len(),
+            ScalarValue::Boolean(_) => 1,
+            ScalarValue::Blob(value) => value.len(),
+            ScalarValue::Timestamp(_) => std::mem::size_of::<i64>(),
+            ScalarValue::Null => 0,
+        }
+    }
 }
 
 pub type Row = BTreeMap<String, ScalarValue>;
@@ -89,10 +105,71 @@ pub enum ExecutionError {
     TransactionAlreadyActive,
     #[error("DDL in explicit transaction is not supported yet")]
     DdlInTransactionUnsupported,
+    #[error("resource limit exceeded for {resource}: actual={actual}, limit={limit}")]
+    ResourceLimitExceeded { resource: &'static str, actual: usize, limit: usize },
+    #[error("statement timed out after {timeout_ms} ms")]
+    StatementTimedOut { timeout_ms: u64 },
+    #[error("statement canceled: {reason}")]
+    StatementCanceled { reason: &'static str },
     #[error("unsupported plan: {0}")]
     UnsupportedPlan(&'static str),
 }
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct ExecutionLimits {
+    pub max_scan_rows: usize,
+    pub max_sort_rows: usize,
+    pub max_join_rows: usize,
+    pub max_query_result_rows: usize,
+    pub max_query_result_bytes: usize,
+}
+
+impl Default for ExecutionLimits {
+    fn default() -> Self {
+        Self {
+            max_scan_rows: usize::MAX,
+            max_sort_rows: usize::MAX,
+            max_join_rows: usize::MAX,
+            max_query_result_rows: usize::MAX,
+            max_query_result_bytes: usize::MAX,
+        }
+    }
+}
+
+impl ExecutionLimits {
+    fn ensure_within(
+        &self,
+        resource: &'static str,
+        actual: usize,
+        limit: usize,
+    ) -> Result<(), ExecutionError> {
+        if actual > limit {
+            return Err(ExecutionError::ResourceLimitExceeded { resource, actual, limit });
+        }
+        Ok(())
+    }
+
+    fn ensure_scan_rows(&self, actual: usize) -> Result<(), ExecutionError> {
+        self.ensure_within("scan rows", actual, self.max_scan_rows)
+    }
+
+    fn ensure_sort_rows(&self, actual: usize) -> Result<(), ExecutionError> {
+        self.ensure_within("sort rows", actual, self.max_sort_rows)
+    }
+
+    fn ensure_join_rows(&self, actual: usize) -> Result<(), ExecutionError> {
+        self.ensure_within("join rows", actual, self.max_join_rows)
+    }
+
+    fn ensure_query_result_rows(&self, actual: usize) -> Result<(), ExecutionError> {
+        self.ensure_within("query result rows", actual, self.max_query_result_rows)
+    }
+
+    fn ensure_query_result_bytes(&self, actual: usize) -> Result<(), ExecutionError> {
+        self.ensure_within("query result bytes", actual, self.max_query_result_bytes)
+    }
+}
+
 #[derive(Debug, Clone, PartialEq)]
 pub(crate) struct RowSet {
@@ -101,18 +178,29 @@ pub(crate) struct RowSet {
 }
 
 impl RowSet {
-    pub(crate) fn into_query_result(self) -> QueryResult {
+    pub(crate) fn into_query_result(
+        self,
+        context: &ExecutionContext<'_>,
+    ) -> Result<QueryResult, ExecutionError> {
         let mut materialized_rows = Vec::with_capacity(self.rows.len());
+        let mut materialized_bytes = 0_usize;
         for row in self.rows {
+            context.checkpoint()?;
             let values = self
                 .columns
                 .iter()
                 .map(|column| row.get(column).cloned().unwrap_or(ScalarValue::Null))
                 .collect::<Vec<_>>();
+            for value in &values {
+                materialized_bytes =
+                    materialized_bytes.saturating_add(value.estimated_query_bytes());
+                context.limits.ensure_query_result_bytes(materialized_bytes)?;
+            }
             materialized_rows.push(values);
+            context.limits.ensure_query_result_rows(materialized_rows.len())?;
         }
 
-        QueryResult { columns: self.columns, rows: materialized_rows }
+        Ok(QueryResult { columns: self.columns, rows: materialized_rows })
     }
 }
 
@@ -125,12 +213,21 @@ pub(crate) struct StoredRow {
 
 pub struct ExecutionSession<'a> {
     catalog: &'a Catalog,
     store: &'a MvccStore,
+    limits: ExecutionLimits,
     active_tx: Option<Transaction>,
 }
 
 impl<'a> ExecutionSession<'a> {
     pub fn new(catalog: &'a Catalog, store: &'a MvccStore) -> Self {
-        Self { catalog, store, active_tx: None }
+        Self::with_limits(catalog, store, ExecutionLimits::default())
+    }
+
+    pub fn with_limits(
+        catalog: &'a Catalog,
+        store: &'a MvccStore,
+        limits: ExecutionLimits,
+    ) -> Self {
+        Self { catalog, store, limits, active_tx: None }
     }
 
     pub fn has_active_transaction(&self) -> bool {
@@ -138,6 +235,17 @@ impl<'a> ExecutionSession<'a> {
     }
 
     pub fn execute_plan(&mut self, plan: &PhysicalPlan) -> Result<ExecutionResult, ExecutionError> {
+        self.execute_plan_with_governance(plan, &ExecutionGovernance::default())
+    }
+
+    pub fn execute_plan_with_governance(
+        &mut self,
+        plan: &PhysicalPlan,
+        governance: &ExecutionGovernance,
+    ) -> Result<ExecutionResult, ExecutionError> {
+        let limits = self.limits;
+        let context = ExecutionContext { limits: &limits, governance };
+        context.checkpoint()?;
         match plan {
             PhysicalPlan::Begin(_) => {
                 if self.active_tx.is_some() {
@@ -172,13 +280,13 @@ impl<'a> ExecutionSession<'a> {
                 Ok(ExecutionResult::AffectedRows(0))
             }
             PhysicalPlan::Insert(node) => self
-                .with_write_tx(|catalog, tx| insert::execute_insert(catalog, tx, node))
+                .with_write_tx(|catalog, tx| insert::execute_insert(catalog, tx, node, &context))
                 .map(ExecutionResult::AffectedRows),
             PhysicalPlan::Update(node) => self
-                .with_write_tx(|catalog, tx| update::execute_update(catalog, tx, node))
+                .with_write_tx(|catalog, tx| update::execute_update(catalog, tx, node, &context))
                 .map(ExecutionResult::AffectedRows),
             PhysicalPlan::Delete(node) => self
-                .with_write_tx(|catalog, tx| delete::execute_delete(catalog, tx, node))
+                .with_write_tx(|catalog, tx| delete::execute_delete(catalog, tx, node, &context))
                 .map(ExecutionResult::AffectedRows),
             PhysicalPlan::SeqScan(_)
             | PhysicalPlan::PrimaryKeyScan(_)
@@ -187,8 +295,17 @@
             | PhysicalPlan::Sort(_)
             | PhysicalPlan::Limit(_)
             | PhysicalPlan::Join(_) => self
-                .with_read_tx(|catalog, tx| execute_query_plan(catalog, tx, plan))
-                .map(|rows| ExecutionResult::Query(rows.into_query_result())),
+                .with_read_tx(|catalog, tx| execute_query_plan(catalog, tx, plan, &context))
+                .and_then(|rows| Ok(ExecutionResult::Query(rows.into_query_result(&context)?))),
+        }
+    }
+
+    pub fn abort_active_transaction(&mut self) -> bool {
+        if let Some(mut tx) = self.active_tx.take() {
+            tx.rollback();
+            true
+        } else {
+            false
         }
     }
 
@@ -233,35 +350,75 @@
 fn execute_query_plan(
     catalog: &Catalog,
     tx: &mut Transaction,
     plan: &PhysicalPlan,
+    context: &ExecutionContext<'_>,
 ) -> Result<RowSet, ExecutionError> {
+    context.checkpoint()?;
     match plan {
-        PhysicalPlan::SeqScan(node) => scan::execute_seq_scan(catalog, tx, node),
-        PhysicalPlan::PrimaryKeyScan(node) => scan::execute_primary_key_scan(catalog, tx, node),
+        PhysicalPlan::SeqScan(node) => scan::execute_seq_scan(catalog, tx, node, context),
+        PhysicalPlan::PrimaryKeyScan(node) => {
+            scan::execute_primary_key_scan(catalog, tx, node, context)
+        }
         PhysicalPlan::Filter(node) => {
-            let input = execute_query_plan(catalog, tx, &node.input)?;
-            filter::apply_filter(input, &node.predicate)
+            let input = execute_query_plan(catalog, tx, &node.input, context)?;
+            filter::apply_filter(input, &node.predicate, context)
        }
         PhysicalPlan::Project(node) => {
-            let input = execute_query_plan(catalog, tx, &node.input)?;
-            projection::apply_projection(input, &node.projection)
+            let input = execute_query_plan(catalog, tx, &node.input, context)?;
+            projection::apply_projection(input, &node.projection, context)
         }
         PhysicalPlan::Sort(node) => {
-            let input = execute_query_plan(catalog, tx, &node.input)?;
-            projection::apply_sort(input, &node.order_by)
+            let input = execute_query_plan(catalog, tx, &node.input, context)?;
+            projection::apply_sort(input, &node.order_by, context)
        }
         PhysicalPlan::Limit(node) => {
-            let input = execute_query_plan(catalog, tx, &node.input)?;
-            projection::apply_limit(input, node.limit)
+            let input = execute_query_plan(catalog, tx, &node.input, context)?;
+            projection::apply_limit(input, node.limit, context)
        }
         PhysicalPlan::Join(node) => {
-            let left = execute_query_plan(catalog, tx, &node.left)?;
-            let right = execute_query_plan(catalog, tx, &node.right)?;
-            join::execute_join(left, right, &node.on)
+            let left = execute_query_plan(catalog, tx, &node.left, context)?;
+            let right = execute_query_plan(catalog, tx, &node.right, context)?;
+            join::execute_join(left, right, &node.on, context)
        }
         _ => Err(ExecutionError::UnsupportedPlan("non-query node used in query execution path")),
     }
 }
 
+#[derive(Clone, Copy)]
+pub(crate) struct ExecutionContext<'a> {
+    pub(crate) limits: &'a ExecutionLimits,
+    pub(crate) governance: &'a ExecutionGovernance,
+}
+
+impl<'a> ExecutionContext<'a> {
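+    /// Single cancellation/deadline probe used at executor loop boundaries;
+    /// it simply delegates to the statement's governance policy.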
pub(crate) fn checkpoint(&self) -> Result<(), ExecutionError> { + self.governance.checkpoint() + } +} + +pub(crate) fn staged_value_for_key( + tx: &Transaction, + staged: &BTreeMap, Option>>, + key: &[u8], +) -> Result>, ExecutionError> { + if let Some(value) = staged.get(key) { + return Ok(value.clone()); + } + Ok(tx.get(key)?) +} + +pub(crate) fn apply_staged_writes( + tx: &mut Transaction, + staged: BTreeMap, Option>>, +) -> Result<(), ExecutionError> { + for (key, value) in staged { + match value { + Some(value) => tx.put(&key, &value)?, + None => tx.delete(&key)?, + } + } + Ok(()) +} + fn create_statement_to_descriptor( create: &crate::sql::ast::CreateTableStatement, ) -> Result { diff --git a/src/executor/projection.rs b/src/executor/projection.rs index e40cf20..57eb0eb 100644 --- a/src/executor/projection.rs +++ b/src/executor/projection.rs @@ -3,11 +3,12 @@ use std::cmp::Ordering; use crate::sql::ast::{Expr, OrderByExpr, SelectItem, SortDirection}; use super::filter::evaluate_expr; -use super::{ExecutionError, RowSet, ScalarValue}; +use super::{ExecutionContext, ExecutionError, RowSet, ScalarValue}; pub(crate) fn apply_projection( input: RowSet, projection: &[SelectItem], + context: &ExecutionContext<'_>, ) -> Result { if projection.len() == 1 && matches!(projection[0], SelectItem::Wildcard) { return Ok(input); @@ -23,6 +24,7 @@ pub(crate) fn apply_projection( let mut projected_rows = Vec::with_capacity(rows.len()); for row in rows { + context.checkpoint()?; let source = row.clone(); let mut materialized = row; for (index, item) in projection.iter().enumerate() { @@ -42,16 +44,19 @@ pub(crate) fn apply_projection( pub(crate) fn apply_sort( input: RowSet, order_by: &[OrderByExpr], + context: &ExecutionContext<'_>, ) -> Result { if order_by.is_empty() { return Ok(input); } let RowSet { columns, rows, table_name } = input; + context.limits.ensure_sort_rows(rows.len())?; let mut keyed = rows .into_iter() .map(|row| { + context.checkpoint()?; let sort_keys = order_by .iter() .map(|entry| evaluate_expr(&entry.expr, &row, table_name.as_deref())) @@ -68,7 +73,12 @@ pub(crate) fn apply_sort( Ok(RowSet { columns, rows, table_name }) } -pub(crate) fn apply_limit(mut input: RowSet, limit: u64) -> Result { +pub(crate) fn apply_limit( + mut input: RowSet, + limit: u64, + context: &ExecutionContext<'_>, +) -> Result { + context.checkpoint()?; let limit = usize::try_from(limit).unwrap_or(usize::MAX); input.rows.truncate(limit); Ok(input) @@ -102,9 +112,9 @@ fn compare_sort_keys( fn compare_scalar_for_sort(left: &ScalarValue, right: &ScalarValue) -> Ordering { match (left, right) { - (ScalarValue::Null, ScalarValue::Null) => Ordering::Equal, - (ScalarValue::Null, _) => Ordering::Greater, - (_, ScalarValue::Null) => Ordering::Less, + (ScalarValue::Null, ScalarValue::Null) => return Ordering::Equal, + (ScalarValue::Null, _) => return Ordering::Greater, + (_, ScalarValue::Null) => return Ordering::Less, _ => {} } @@ -146,9 +156,18 @@ fn scalar_sort_tag(value: &ScalarValue) -> &'static str { #[cfg(test)] mod tests { use super::*; + use crate::executor::ExecutionLimits; + use crate::executor::governance::ExecutionGovernance; use crate::executor::{Row, ScalarValue}; use crate::sql::ast::BinaryOp; + fn test_context() -> ExecutionContext<'static> { + ExecutionContext { + limits: Box::leak(Box::new(ExecutionLimits::default())), + governance: Box::leak(Box::new(ExecutionGovernance::default())), + } + } + #[test] fn projects_expression_columns() { let mut row = Row::new(); @@ -168,6 +187,7 @@ mod tests { op: 
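// --- Reviewer sketch (not part of the diff) ---
// Two notes on the code above: `test_context` leaks default limits and
// governance (Box::leak) purely to fabricate an ExecutionContext<'static> for
// unit tests, and `ensure_sort_rows` guards apply_sort because sorting must
// materialize every input row before the first comparison. A model of that
// guard, with assumed names, consistent with ResourceLimitExceeded elsewhere
// in this diff:
fn ensure_sort_rows_model(actual: usize, limit: usize) -> Result<(), String> {
    if actual > limit {
        return Err(format!("sort row budget exceeded: {actual} > {limit}"));
    }
    Ok(())
}
// --- end sketch ---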
BinaryOp::Add, right: Box::new(Expr::Identifier("b".to_string())), })], + &test_context(), ) .expect("projection"); @@ -194,6 +214,7 @@ mod tests { expr: Expr::Identifier("id".to_string()), direction: SortDirection::Desc, }], + &test_context(), ) .expect("sort"); diff --git a/src/executor/scan.rs b/src/executor/scan.rs index 57d3e77..095d65c 100644 --- a/src/executor/scan.rs +++ b/src/executor/scan.rs @@ -4,17 +4,20 @@ use crate::mvcc::Transaction; use crate::planner::{PrimaryKeyScanNode, SeqScanNode}; use super::{ - ExecutionError, Row, RowSet, StoredRow, build_row_key, coerce_scalar_for_column, decode_row, - literal_to_scalar, table_rows_prefix, + ExecutionContext, ExecutionError, Row, RowSet, StoredRow, build_row_key, + coerce_scalar_for_column, decode_row, literal_to_scalar, table_rows_prefix, }; pub(crate) fn execute_seq_scan( catalog: &Catalog, tx: &Transaction, node: &SeqScanNode, + context: &ExecutionContext<'_>, ) -> Result { - let (_, stored_rows) = scan_table_rows(catalog, tx, &node.table)?; + let (_, stored_rows) = + scan_table_rows(catalog, tx, &node.table, context.limits.max_scan_rows, context)?; let rows = stored_rows.into_iter().map(|stored| stored.values).collect::>(); + context.limits.ensure_scan_rows(rows.len())?; Ok(RowSet { columns: node.output_columns.clone(), rows, table_name: Some(node.table.clone()) }) } @@ -23,7 +26,9 @@ pub(crate) fn execute_primary_key_scan( catalog: &Catalog, tx: &Transaction, node: &PrimaryKeyScanNode, + context: &ExecutionContext<'_>, ) -> Result { + context.checkpoint()?; let table = get_table(catalog, &node.table)?; let mut key_row = Row::new(); @@ -53,15 +58,57 @@ pub(crate) fn scan_table_rows( catalog: &Catalog, tx: &Transaction, table_name: &str, + max_rows: usize, + context: &ExecutionContext<'_>, ) -> Result<(TableDescriptor, Vec), ExecutionError> { let table = get_table(catalog, table_name)?; let prefix = table_rows_prefix(&table.name); - let rows = tx.scan_prefix(&prefix)?; + let mut governance_error = None; + let rows = if max_rows == usize::MAX { + tx.scan_prefix_with_observer(&prefix, |seen| { + if seen == 0 { + return true; + } + match context.checkpoint() { + Ok(()) => true, + Err(err) => { + governance_error = Some(err); + false + } + } + })? + } else { + tx.scan_prefix_limited_with_observer(&prefix, max_rows, |seen| { + if seen == 0 { + return true; + } + match context.checkpoint() { + Ok(()) => true, + Err(err) => { + governance_error = Some(err); + false + } + } + })? 
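// --- Reviewer sketch (not part of the diff) ---
// The observer contract used above: storage invokes the callback with the
// running row count after each visible row, and a `false` return stops the
// scan early (the executor then surfaces the captured governance error). The
// `seen == 0` branch in the closures is just a defensive no-op. The shape in
// miniature:
fn scan_with_observer<F>(rows: &[u32], mut observer: F) -> Vec<u32>
where
    F: FnMut(usize) -> bool,
{
    let mut out = Vec::new();
    for row in rows {
        out.push(*row);
        if !observer(out.len()) {
            break; // early exit requested by the observer
        }
    }
    out
}
// --- end sketch ---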
+ }; + if let Some(err) = governance_error { + return Err(err); + } + + if rows.len() > max_rows { + return Err(ExecutionError::ResourceLimitExceeded { + resource: "scan rows", + actual: rows.len(), + limit: max_rows, + }); + } - let decoded_rows = rows - .into_iter() - .map(|(key, payload)| decode_row(&table, &payload).map(|values| StoredRow { key, values })) - .collect::, _>>()?; + let mut decoded_rows = Vec::with_capacity(rows.len()); + for (key, payload) in rows { + context.checkpoint()?; + let values = decode_row(&table, &payload)?; + decoded_rows.push(StoredRow { key, values }); + } Ok((table, decoded_rows)) } diff --git a/src/executor/update.rs b/src/executor/update.rs index 434057b..282f6b5 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -5,21 +5,26 @@ use crate::planner::UpdateNode; use super::filter::{evaluate_expr, evaluate_predicate}; use super::scan::scan_table_rows; use super::{ - ExecutionError, build_row_key, coerce_row_for_table, coerce_scalar_for_column, encode_row, + ExecutionContext, ExecutionError, apply_staged_writes, build_row_key, coerce_row_for_table, + coerce_scalar_for_column, encode_row, staged_value_for_key, }; pub(crate) fn execute_update( catalog: &Catalog, tx: &mut Transaction, node: &UpdateNode, + context: &ExecutionContext<'_>, ) -> Result { + context.checkpoint()?; let table = catalog .get_table(&node.table) .ok_or_else(|| ExecutionError::TableNotFound(node.table.clone()))?; - let (_, rows) = scan_table_rows(catalog, tx, &table.name)?; + let (_, rows) = scan_table_rows(catalog, tx, &table.name, usize::MAX, context)?; let mut affected = 0_u64; + let mut staged = std::collections::BTreeMap::new(); for stored in rows { + context.checkpoint()?; if let Some(predicate) = &node.predicate { if !evaluate_predicate(predicate, &stored.values, Some(&table.name))? 
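// --- Reviewer sketch (not part of the diff) ---
// The staging map introduced above buys statement-level atomicity: key
// collision checks run through `staged_value_for_key`, which layers the
// statement's own pending moves over the transaction, and nothing is written
// through until every row has validated. Previously a PrimaryKeyConflict
// discovered halfway through an UPDATE left earlier deletes/puts sitting in
// the transaction's write set. The overlay in miniature:
use std::collections::BTreeMap;

fn stage_move(
    staged: &mut BTreeMap<Vec<u8>, Option<Vec<u8>>>,
    old_key: &[u8],
    new_key: &[u8],
    payload: Vec<u8>,
) {
    staged.insert(old_key.to_vec(), None);          // tombstone the old key
    staged.insert(new_key.to_vec(), Some(payload)); // stage the rewritten row
}
// --- end sketch ---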
{ continue; @@ -43,16 +48,17 @@ pub(crate) fn execute_update( let normalized = coerce_row_for_table(&table, &updated)?; let new_key = build_row_key(&table, &normalized)?; if new_key != stored.key { - if tx.get(&new_key)?.is_some() { + if staged_value_for_key(tx, &staged, &new_key)?.is_some() { return Err(ExecutionError::PrimaryKeyConflict { table: table.name.clone() }); } - tx.delete(&stored.key)?; + staged.insert(stored.key.clone(), None); } let payload = encode_row(&table, &normalized)?; - tx.put(&new_key, &payload)?; + staged.insert(new_key, Some(payload)); affected = affected.saturating_add(1); } + apply_staged_writes(tx, staged)?; Ok(affected) } diff --git a/src/mvcc/gc.rs b/src/mvcc/gc.rs index 2cb86d3..5c9abe7 100644 --- a/src/mvcc/gc.rs +++ b/src/mvcc/gc.rs @@ -1,6 +1,6 @@ +use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::mpsc::{self, RecvTimeoutError, Sender}; -use std::sync::Arc; use std::thread::{self, JoinHandle}; use std::time::Duration; @@ -55,14 +55,16 @@ impl GcWorker { let handle = thread::Builder::new() .name("lsmdb-mvcc-gc".to_string()) - .spawn(move || loop { - match shutdown_rx.recv_timeout(config.interval) { - Ok(_) => break, - Err(RecvTimeoutError::Disconnected) => break, - Err(RecvTimeoutError::Timeout) => { - let stats = run_gc_once(&store); - last_removed_versions_worker - .store(stats.removed_versions, Ordering::Release); + .spawn(move || { + loop { + match shutdown_rx.recv_timeout(config.interval) { + Ok(_) => break, + Err(RecvTimeoutError::Disconnected) => break, + Err(RecvTimeoutError::Timeout) => { + let stats = run_gc_once(&store); + last_removed_versions_worker + .store(stats.removed_versions, Ordering::Release); + } } } }) diff --git a/src/mvcc/timestamp.rs b/src/mvcc/timestamp.rs index abd3852..67c64bf 100644 --- a/src/mvcc/timestamp.rs +++ b/src/mvcc/timestamp.rs @@ -1,5 +1,5 @@ -use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; #[derive(Debug, Clone)] pub struct TimestampOracle { diff --git a/src/mvcc/transaction.rs b/src/mvcc/transaction.rs index 1653da7..6780465 100644 --- a/src/mvcc/transaction.rs +++ b/src/mvcc/transaction.rs @@ -1,15 +1,23 @@ use std::collections::{BTreeMap, HashMap}; +use std::path::Path; use std::sync::Arc; +#[cfg(test)] +use std::sync::atomic::AtomicBool; use std::sync::atomic::{AtomicU64, Ordering}; use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; use thiserror::Error; use tracing::{debug, trace, warn}; +use crate::storage::engine::{StorageEngine, StorageEngineOptions}; + use super::snapshot::{Snapshot, SnapshotRegistry}; use super::timestamp::TimestampOracle; -#[derive(Debug, Clone, PartialEq, Eq)] +const MVCC_DURABLE_STATE_KEY: &[u8] = b"__mvcc__/state"; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct CommittedVersion { pub commit_ts: u64, pub value: Option>, @@ -23,6 +31,8 @@ pub enum TransactionError { WriteWriteConflict { key: String, read_ts: u64, conflicting_commit_ts: u64 }, #[error("transaction is no longer active")] Closed, + #[error("durable MVCC persistence error: {0}")] + Persistence(String), } #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] @@ -32,6 +42,8 @@ pub struct TransactionMetrics { pub rolled_back: u64, pub write_conflicts: u64, pub active_transactions: usize, + pub recovered_keys: u64, + pub recovered_versions: u64, } #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] @@ -50,6 +62,7 @@ struct MvccStoreInner { oracle: TimestampOracle, snapshots: 
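// --- Reviewer sketch (not part of the diff) ---
// CommittedVersion now derives Serialize/Deserialize so the whole version map
// can be persisted with bincode under MVCC_DURABLE_STATE_KEY. A standalone
// round-trip of that payload shape, using local mirror types:
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
struct VersionSketch {
    commit_ts: u64,
    value: Option<Vec<u8>>,
}

fn durable_round_trip(
    map: &HashMap<Vec<u8>, Vec<VersionSketch>>,
) -> HashMap<Vec<u8>, Vec<VersionSketch>> {
    let bytes = bincode::serialize(map).expect("encode durable state");
    bincode::deserialize(&bytes).expect("decode durable state")
}
// --- end sketch ---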
SnapshotRegistry, data: RwLock, + durable_engine: Option>, metrics: TransactionMetricsState, } @@ -59,6 +72,10 @@ struct TransactionMetricsState { committed: AtomicU64, rolled_back: AtomicU64, write_conflicts: AtomicU64, + recovered_keys: AtomicU64, + recovered_versions: AtomicU64, + #[cfg(test)] + crash_after_durable_commit: AtomicBool, } #[derive(Debug, Clone)] @@ -74,12 +91,60 @@ impl Default for MvccStore { impl MvccStore { pub fn new() -> Self { + Self::from_parts(MvccStoreData::default(), None, 0) + } + + pub fn with_storage_engine(engine: Arc) -> Result { + let data = load_durable_state(engine.as_ref())?; + let max_commit_ts = max_commit_ts_for_data(&data); + Ok(Self::from_parts(data, Some(engine), max_commit_ts)) + } + + pub fn open_persistent>(root_dir: P) -> Result { + Self::open_persistent_with_options(root_dir, StorageEngineOptions::default()) + } + + pub fn open_persistent_with_options>( + root_dir: P, + options: StorageEngineOptions, + ) -> Result { + let engine = StorageEngine::open_with_options(root_dir, options).map_err(|err| { + TransactionError::Persistence(format!("open storage engine failed: {err}")) + })?; + Self::with_storage_engine(Arc::new(engine)) + } + + pub fn is_durable(&self) -> bool { + self.inner.durable_engine.is_some() + } + + #[cfg(test)] + fn set_crash_after_durable_commit_for_test(&self, enabled: bool) { + self.inner.metrics.crash_after_durable_commit.store(enabled, Ordering::Relaxed); + } + + fn from_parts( + data: MvccStoreData, + durable_engine: Option>, + initial_timestamp: u64, + ) -> Self { + let recovered_keys = u64::try_from(data.versions.len()).unwrap_or(u64::MAX); + let recovered_versions = data + .versions + .values() + .map(|versions| u64::try_from(versions.len()).unwrap_or(u64::MAX)) + .fold(0_u64, u64::saturating_add); + let metrics = TransactionMetricsState::default(); + metrics.recovered_keys.store(recovered_keys, Ordering::Relaxed); + metrics.recovered_versions.store(recovered_versions, Ordering::Relaxed); + Self { inner: Arc::new(MvccStoreInner { - oracle: TimestampOracle::default(), + oracle: TimestampOracle::new(initial_timestamp), snapshots: SnapshotRegistry::default(), - data: RwLock::new(MvccStoreData::default()), - metrics: TransactionMetricsState::default(), + data: RwLock::new(data), + durable_engine, + metrics, }), } } @@ -116,6 +181,8 @@ impl MvccStore { rolled_back: self.inner.metrics.rolled_back.load(Ordering::Relaxed), write_conflicts: self.inner.metrics.write_conflicts.load(Ordering::Relaxed), active_transactions: self.active_snapshot_count(), + recovered_keys: self.inner.metrics.recovered_keys.load(Ordering::Relaxed), + recovered_versions: self.inner.metrics.recovered_versions.load(Ordering::Relaxed), } } @@ -155,6 +222,40 @@ impl MvccStore { } pub fn scan_prefix_at(&self, prefix: &[u8], read_ts: u64) -> Vec<(Vec, Vec)> { + self.scan_prefix_at_with_observer(prefix, read_ts, |_| true) + } + + pub fn scan_prefix_at_limited( + &self, + prefix: &[u8], + read_ts: u64, + max_rows: usize, + ) -> Vec<(Vec, Vec)> { + self.scan_prefix_at_limited_with_observer(prefix, read_ts, max_rows, |_| true) + } + + pub fn scan_prefix_at_with_observer( + &self, + prefix: &[u8], + read_ts: u64, + observer: F, + ) -> Vec<(Vec, Vec)> + where + F: FnMut(usize) -> bool, + { + self.scan_prefix_at_limited_with_observer(prefix, read_ts, usize::MAX, observer) + } + + pub fn scan_prefix_at_limited_with_observer( + &self, + prefix: &[u8], + read_ts: u64, + max_rows: usize, + mut observer: F, + ) -> Vec<(Vec, Vec)> + where + F: FnMut(usize) -> bool, + { 
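// --- Reviewer sketch (not part of the diff) ---
// How from_parts above derives the recovery metrics: key count plus a
// saturating sum of per-key version counts. For a recovered map
// { "m/a": [v@1, v@2], "m/b": [v@3] } this yields recovered_keys = 2 and
// recovered_versions = 3, which is exactly what the restart-metrics test
// later in this diff asserts.
fn recovery_counts(versions_per_key: &[usize]) -> (u64, u64) {
    let keys = versions_per_key.len() as u64;
    let versions = versions_per_key
        .iter()
        .map(|&count| count as u64)
        .fold(0_u64, u64::saturating_add); // saturate rather than overflow
    (keys, versions)
}
// --- end sketch ---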
let data = self.inner.data.read(); let mut rows = Vec::new(); @@ -170,6 +271,14 @@ impl MvccStore { if let Some(value) = &version.value { rows.push((key.clone(), value.clone())); + if !observer(rows.len()) { + rows.sort_by(|a, b| a.0.cmp(&b.0)); + return rows; + } + if rows.len() >= max_rows { + rows.sort_by(|a, b| a.0.cmp(&b.0)); + return rows; + } } break; } @@ -218,12 +327,37 @@ impl MvccStore { } let commit_ts = self.inner.oracle.next_timestamp(); + let mut previous_versions = Vec::with_capacity(writes.len()); for (key, value) in writes { + previous_versions.push((key.clone(), data.versions.get(key).cloned())); let entry = data.versions.entry(key.clone()).or_default(); entry.push(CommittedVersion { commit_ts, value: value.clone() }); } + if let Some(engine) = self.inner.durable_engine.as_ref() { + if let Err(err) = persist_durable_state(engine.as_ref(), &data.versions) { + for (key, previous) in previous_versions { + match previous { + Some(versions) => { + data.versions.insert(key, versions); + } + None => { + data.versions.remove(&key); + } + } + } + return Err(err); + } + } + + #[cfg(test)] + if self.inner.durable_engine.is_some() + && self.inner.metrics.crash_after_durable_commit.load(Ordering::Relaxed) + { + panic!("simulated crash after durable commit"); + } + trace!(commit_ts, write_count = writes.len(), "commit applied"); Ok(commit_ts) } @@ -231,8 +365,9 @@ impl MvccStore { pub(crate) fn prune_versions_older_than(&self, watermark_ts: u64) -> PruneStats { let mut data = self.inner.data.write(); let mut stats = PruneStats::default(); + let mut previous_versions = Vec::new(); - for versions in data.versions.values_mut() { + for (key, versions) in data.versions.iter_mut() { stats.scanned_keys = stats.scanned_keys.saturating_add(1); if versions.len() <= 1 { @@ -241,22 +376,74 @@ impl MvccStore { let split = versions .iter() - .position(|version| version.commit_ts >= watermark_ts) + .position(|version| version.commit_ts > watermark_ts) .unwrap_or(versions.len()); if split <= 1 { continue; } + previous_versions.push((key.clone(), versions.clone())); let remove_count = split - 1; versions.drain(0..remove_count); stats.removed_versions = stats.removed_versions.saturating_add(remove_count as u64); } + if stats.removed_versions > 0 { + if let Some(engine) = self.inner.durable_engine.as_ref() { + if let Err(err) = persist_durable_state(engine.as_ref(), &data.versions) { + warn!(error = %err, "failed to persist GC-pruned MVCC state; reverting prune"); + for (key, versions) in previous_versions { + data.versions.insert(key, versions); + } + stats.removed_versions = 0; + } + } + } + stats } } +fn load_durable_state(engine: &StorageEngine) -> Result { + let Some(raw) = engine.get(MVCC_DURABLE_STATE_KEY).map_err(|err| { + TransactionError::Persistence(format!("load durable MVCC state failed: {err}")) + })? 
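// --- Reviewer sketch (not part of the diff) ---
// The prune predicate above changed from `>=` to `>`, so the newest version
// at-or-below the watermark survives GC. Worked example for commit
// timestamps [2, 5, 9] with watermark 5: position(ts > 5) = 2, so split = 2
// and remove_count = 1; [2] is dropped, [5, 9] are kept, and a reader pinned
// at read_ts = 5 still resolves the version committed at 5.
fn prune_split(commit_timestamps: &[u64], watermark: u64) -> usize {
    commit_timestamps
        .iter()
        .position(|&ts| ts > watermark)
        .unwrap_or(commit_timestamps.len())
}
// --- end sketch ---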
+ else { + return Ok(MvccStoreData::default()); + }; + + let versions = + bincode::deserialize::, Vec>>(&raw).map_err(|err| { + TransactionError::Persistence(format!("decode durable MVCC state failed: {err}")) + })?; + + Ok(MvccStoreData { versions }) +} + +fn persist_durable_state( + engine: &StorageEngine, + versions: &HashMap, Vec>, +) -> Result<(), TransactionError> { + let payload = bincode::serialize(versions).map_err(|err| { + TransactionError::Persistence(format!("encode durable MVCC state failed: {err}")) + })?; + + engine.put(MVCC_DURABLE_STATE_KEY, &payload).map_err(|err| { + TransactionError::Persistence(format!("persist durable MVCC state failed: {err}")) + })?; + + Ok(()) +} + +fn max_commit_ts_for_data(data: &MvccStoreData) -> u64 { + data.versions + .values() + .filter_map(|versions| versions.last().map(|version| version.commit_ts)) + .max() + .unwrap_or(0) +} + #[derive(Debug)] pub struct Transaction { store: MvccStore, @@ -311,6 +498,72 @@ impl Transaction { Ok(visible.into_iter().collect()) } + pub fn scan_prefix_with_observer( + &self, + prefix: &[u8], + observer: F, + ) -> Result, Vec)>, TransactionError> + where + F: FnMut(usize) -> bool, + { + self.scan_prefix_limited_with_observer(prefix, usize::MAX, observer) + } + + pub fn scan_prefix_limited( + &self, + prefix: &[u8], + max_rows: usize, + ) -> Result, Vec)>, TransactionError> { + self.scan_prefix_limited_with_observer(prefix, max_rows, |_| true) + } + + pub fn scan_prefix_limited_with_observer( + &self, + prefix: &[u8], + max_rows: usize, + mut observer: F, + ) -> Result, Vec)>, TransactionError> + where + F: FnMut(usize) -> bool, + { + if self.closed { + return Err(TransactionError::Closed); + } + + let read_ts = self.read_ts()?; + let write_overlap = self.writes.keys().filter(|key| key.starts_with(prefix)).count(); + let fetch_limit = max_rows.saturating_add(write_overlap).saturating_add(1); + let mut visible = self + .store + .scan_prefix_at_limited_with_observer(prefix, read_ts, fetch_limit, |seen| { + observer(seen) + }) + .into_iter() + .collect::>(); + + let mut observed = visible.len(); + for (key, value) in &self.writes { + if !key.starts_with(prefix) { + continue; + } + + match value { + Some(value) => { + visible.insert(key.clone(), value.clone()); + } + None => { + visible.remove(key); + } + } + observed = observed.saturating_add(1); + if !observer(observed) { + break; + } + } + + Ok(visible.into_iter().collect()) + } + pub fn put(&mut self, key: &[u8], value: &[u8]) -> Result<(), TransactionError> { if self.closed { return Err(TransactionError::Closed); @@ -374,10 +627,25 @@ impl Drop for Transaction { #[cfg(test)] mod tests { + use std::fs; + use std::panic::{AssertUnwindSafe, catch_unwind}; + use std::path::PathBuf; use std::thread; + use std::time::{SystemTime, UNIX_EPOCH}; use super::*; + fn test_dir(label: &str) -> PathBuf { + let mut path = std::env::temp_dir(); + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time should be after epoch") + .as_nanos(); + path.push(format!("lsmdb-mvcc-{label}-{}-{nanos}", std::process::id())); + fs::create_dir_all(&path).expect("create temp dir"); + path + } + #[test] fn transaction_commit_and_visibility() { let store = MvccStore::new(); @@ -500,4 +768,92 @@ mod tests { vec![(b"k/b".to_vec(), b"v1".to_vec()), (b"k/c".to_vec(), b"v2".to_vec())] ); } + + #[test] + fn durable_store_recovers_versions_and_advances_timestamp_after_restart() { + let dir = test_dir("durable-recovery"); + + { + let store = 
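// --- Reviewer sketch (assumed usage, mirroring the recovery tests below) ---
// max_commit_ts_for_data seeds the timestamp oracle on reopen, so a restarted
// store can never hand out a commit timestamp that collides with persisted
// history:
fn reopen_demo(dir: &std::path::Path) -> Result<(), TransactionError> {
    let store = MvccStore::open_persistent(dir)?;
    let mut tx = store.begin_transaction();
    tx.put(b"k", b"v")?;
    let commit_ts = tx.commit()?; // strictly greater than any recovered commit_ts
    assert!(commit_ts > 0);
    Ok(())
}
// --- end sketch ---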
MvccStore::open_persistent(&dir).expect("open durable store"); + assert!(store.is_durable()); + + let mut tx = store.begin_transaction(); + tx.put(b"k", b"v1").expect("write v1"); + assert_eq!(tx.commit().expect("commit v1"), 1); + + let mut tx = store.begin_transaction(); + tx.put(b"k", b"v2").expect("write v2"); + assert_eq!(tx.commit().expect("commit v2"), 2); + assert_eq!(store.version_count_for_key(b"k"), 2); + } + + { + let store = MvccStore::open_persistent(&dir).expect("reopen durable store"); + assert_eq!(store.version_count_for_key(b"k"), 2); + let reader = store.begin_transaction(); + assert_eq!(reader.get(b"k").expect("read latest"), Some(b"v2".to_vec())); + + let mut tx = store.begin_transaction(); + tx.put(b"k", b"v3").expect("write v3"); + let commit_ts = tx.commit().expect("commit v3"); + assert!(commit_ts >= 3); + } + + fs::remove_dir_all(dir).expect("cleanup temp dir"); + } + + #[test] + fn durable_store_survives_crash_before_commit_acknowledgment() { + let dir = test_dir("crash-before-ack"); + + { + let store = MvccStore::open_persistent(&dir).expect("open durable store"); + store.set_crash_after_durable_commit_for_test(true); + + let mut tx = store.begin_transaction(); + tx.put(b"crash-key", b"persisted").expect("write"); + let crashed = catch_unwind(AssertUnwindSafe(|| { + let _ = tx.commit(); + })); + assert!(crashed.is_err()); + } + + { + let store = MvccStore::open_persistent(&dir).expect("reopen durable store"); + let reader = store.begin_transaction(); + assert_eq!( + reader.get(b"crash-key").expect("read after simulated crash"), + Some(b"persisted".to_vec()) + ); + } + + fs::remove_dir_all(dir).expect("cleanup temp dir"); + } + + #[test] + fn durable_store_reports_recovered_metrics_after_restart() { + let dir = test_dir("recovered-metrics"); + + { + let store = MvccStore::open_persistent(&dir).expect("open durable store"); + + let mut tx = store.begin_transaction(); + tx.put(b"m/a", b"1").expect("put a"); + tx.commit().expect("commit a"); + + let mut tx = store.begin_transaction(); + tx.put(b"m/a", b"2").expect("update a"); + tx.put(b"m/b", b"3").expect("put b"); + tx.commit().expect("commit second batch"); + } + + { + let store = MvccStore::open_persistent(&dir).expect("reopen durable store"); + let metrics = store.metrics(); + assert_eq!(metrics.recovered_keys, 2); + assert_eq!(metrics.recovered_versions, 3); + } + + fs::remove_dir_all(dir).expect("cleanup temp dir"); + } } diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 0b06fa6..b948aa0 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -83,8 +83,23 @@ mod tests { let PhysicalPlan::Sort(sort) = *limit.input else { panic!("expected sort"); }; - let PhysicalPlan::PrimaryKeyScan(_) = *sort.input else { - panic!("expected primary key scan under sort"); - }; + assert!( + contains_primary_key_scan(&sort.input), + "expected primary key scan in sort subtree" + ); + } + + fn contains_primary_key_scan(plan: &PhysicalPlan) -> bool { + match plan { + PhysicalPlan::PrimaryKeyScan(_) => true, + PhysicalPlan::Filter(filter) => contains_primary_key_scan(&filter.input), + PhysicalPlan::Project(project) => contains_primary_key_scan(&project.input), + PhysicalPlan::Sort(sort) => contains_primary_key_scan(&sort.input), + PhysicalPlan::Limit(limit) => contains_primary_key_scan(&limit.input), + PhysicalPlan::Join(join) => { + contains_primary_key_scan(&join.left) || contains_primary_key_scan(&join.right) + } + _ => false, + } } } diff --git a/src/planner/physical.rs b/src/planner/physical.rs index 
1b25e45..50183c3 100644 --- a/src/planner/physical.rs +++ b/src/planner/physical.rs @@ -5,7 +5,7 @@ use crate::sql::ast::{ LiteralValue, OrderByExpr, SelectItem, }; -use super::logical::{LogicalFilter, LogicalLimit, LogicalPlan, LogicalScan, LogicalSort}; +use super::logical::{LogicalFilter, LogicalPlan, LogicalScan}; #[derive(Debug, Clone, PartialEq)] pub enum PhysicalPlan { @@ -252,6 +252,7 @@ fn collect_conjuncts<'a>(expr: &'a Expr, out: &mut Vec<&'a Expr>) { #[cfg(test)] mod tests { use super::*; + use crate::planner::{LogicalLimit, LogicalSort}; use crate::sql::ast::SortDirection; fn scan_node() -> LogicalScan { diff --git a/src/server/mod.rs b/src/server/mod.rs index a22e7f3..38f2a42 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -2,7 +2,15 @@ pub mod protocol; pub mod tcp; pub use protocol::{ - ProtocolError, QueryPayload, RequestFrame, RequestType, ResponseFrame, ResponsePayload, - TransactionState, read_request, read_response, write_request, write_response, + ActiveStatementPayload, ActiveStatementsPayload, AdminStatusPayload, AuthenticationPayload, + AuthenticationRequest, ErrorCode, ErrorPayload, HealthPayload, PROTOCOL_VERSION, ProtocolError, + QueryPayload, ReadinessPayload, RequestFrame, RequestType, ResponseFrame, ResponsePayload, + StatementCancellationPayload, TransactionState, authentication_request_with_password, + authentication_request_with_token, decode_authentication_request, read_request, + read_request_with_limit, read_response, write_request, write_response, +}; +pub use tcp::{ + ServerAuthOptions, ServerError, ServerHandle, ServerLimits, ServerOptions, ServerRole, + ServerSecurityOptions, ServerTlsMode, ServerTlsOptions, StaticPasswordUser, + StaticTokenPrincipal, start_server, start_server_with_options, }; -pub use tcp::{ServerError, ServerHandle, start_server}; diff --git a/src/server/protocol.rs b/src/server/protocol.rs index 5c99474..1c09df3 100644 --- a/src/server/protocol.rs +++ b/src/server/protocol.rs @@ -5,6 +5,8 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use crate::executor::{ExecutionResult, QueryResult, ScalarValue}; +pub const PROTOCOL_VERSION: u16 = 1; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum RequestType { @@ -13,6 +15,12 @@ pub enum RequestType { Commit = 3, Rollback = 4, Explain = 5, + Health = 6, + Readiness = 7, + AdminStatus = 8, + ActiveStatements = 9, + CancelStatement = 10, + Authenticate = 11, } impl TryFrom for RequestType { @@ -25,6 +33,12 @@ impl TryFrom for RequestType { 3 => Ok(RequestType::Commit), 4 => Ok(RequestType::Rollback), 5 => Ok(RequestType::Explain), + 6 => Ok(RequestType::Health), + 7 => Ok(RequestType::Readiness), + 8 => Ok(RequestType::AdminStatus), + 9 => Ok(RequestType::ActiveStatements), + 10 => Ok(RequestType::CancelStatement), + 11 => Ok(RequestType::Authenticate), other => { Err(ProtocolError::InvalidFrame(format!("unknown request type byte: {other}"))) } @@ -38,10 +52,87 @@ pub struct RequestFrame { pub sql: String, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AuthenticationRequest { + Password { username: String, password: String }, + Token { token: String }, +} + #[derive(Debug, Clone, PartialEq)] pub enum ResponseFrame { Ok(ResponsePayload), - Err(String), + Err(ErrorPayload), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum ErrorCode { + InvalidRequest = 1, + Parse = 2, + Validation = 3, + Planner = 4, + Execution = 5, + Busy = 6, + ResourceLimit = 7, + Timeout = 8, + Canceled = 9, + Quota = 10, + Unauthenticated 
= 11, + PermissionDenied = 12, +} + +impl TryFrom for ErrorCode { + type Error = ProtocolError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(ErrorCode::InvalidRequest), + 2 => Ok(ErrorCode::Parse), + 3 => Ok(ErrorCode::Validation), + 4 => Ok(ErrorCode::Planner), + 5 => Ok(ErrorCode::Execution), + 6 => Ok(ErrorCode::Busy), + 7 => Ok(ErrorCode::ResourceLimit), + 8 => Ok(ErrorCode::Timeout), + 9 => Ok(ErrorCode::Canceled), + 10 => Ok(ErrorCode::Quota), + 11 => Ok(ErrorCode::Unauthenticated), + 12 => Ok(ErrorCode::PermissionDenied), + other => Err(ProtocolError::InvalidFrame(format!("unknown error code byte: {other}"))), + } + } +} + +impl ErrorCode { + pub fn as_str(self) -> &'static str { + match self { + ErrorCode::InvalidRequest => "INVALID_REQUEST", + ErrorCode::Parse => "PARSE", + ErrorCode::Validation => "VALIDATION", + ErrorCode::Planner => "PLANNER", + ErrorCode::Execution => "EXECUTION", + ErrorCode::Busy => "BUSY", + ErrorCode::ResourceLimit => "RESOURCE_LIMIT", + ErrorCode::Timeout => "TIMEOUT", + ErrorCode::Canceled => "CANCELED", + ErrorCode::Quota => "QUOTA", + ErrorCode::Unauthenticated => "UNAUTHENTICATED", + ErrorCode::PermissionDenied => "PERMISSION_DENIED", + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ErrorPayload { + pub code: ErrorCode, + pub message: String, + pub retryable: bool, +} + +impl ErrorPayload { + pub fn new(code: ErrorCode, message: impl Into, retryable: bool) -> Self { + Self { code, message: message.into(), retryable } + } } #[derive(Debug, Clone, PartialEq)] @@ -50,6 +141,12 @@ pub enum ResponsePayload { AffectedRows(u64), TransactionState(TransactionState), ExplainPlan(String), + Health(HealthPayload), + Readiness(ReadinessPayload), + AdminStatus(AdminStatusPayload), + ActiveStatements(ActiveStatementsPayload), + StatementCancellation(StatementCancellationPayload), + Authentication(AuthenticationPayload), } #[derive(Debug, Clone, PartialEq)] @@ -58,6 +155,71 @@ pub struct QueryPayload { pub rows: Vec>>, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HealthPayload { + pub ok: bool, + pub status: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ReadinessPayload { + pub ready: bool, + pub status: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AdminStatusPayload { + pub server_version: String, + pub protocol_version: u16, + pub uptime_seconds: u64, + pub accepting_connections: bool, + pub active_connections: u64, + pub total_connections: u64, + pub rejected_connections: u64, + pub busy_requests: u64, + pub resource_limit_requests: u64, + pub quota_rejections: u64, + pub timed_out_requests: u64, + pub canceled_requests: u64, + pub active_statements: u64, + pub active_memory_intensive_requests: u64, + pub mvcc_started: u64, + pub mvcc_committed: u64, + pub mvcc_rolled_back: u64, + pub mvcc_write_conflicts: u64, + pub mvcc_active_transactions: u64, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ActiveStatementsPayload { + pub statements: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ActiveStatementPayload { + pub statement_id: u64, + pub connection_id: u64, + pub identity: String, + pub request_type: String, + pub runtime_ms: u64, + pub cancel_requested: bool, + pub sql_preview: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StatementCancellationPayload { + pub statement_id: u64, + pub accepted: bool, + pub status: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AuthenticationPayload { + pub identity: String, + pub role: 
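// --- Reviewer sketch (assumed client-side policy, not part of the diff) ---
// The structured ErrorPayload above lets drivers branch on code/retryable
// instead of string-matching messages. One hypothetical retry policy:
fn should_retry(code: ErrorCode, retryable: bool) -> bool {
    retryable && matches!(code, ErrorCode::Busy | ErrorCode::Timeout)
}
// --- end sketch ---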
String, + pub auth_scheme: String, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum TransactionState { @@ -72,20 +234,32 @@ pub enum ProtocolError { Io(#[from] std::io::Error), #[error("invalid frame: {0}")] InvalidFrame(String), + #[error("frame too large: length={length}, max={max}")] + FrameTooLarge { length: usize, max: usize }, #[error("utf8 decode error: {0}")] Utf8(#[from] std::string::FromUtf8Error), } -pub async fn read_request( +pub async fn read_request( reader: &mut R, ) -> Result, ProtocolError> { - let Some(body) = read_frame(reader).await? else { + let Some(body) = read_frame(reader, None).await? else { return Ok(None); }; decode_request(&body).map(Some) } -pub async fn write_request( +pub async fn read_request_with_limit( + reader: &mut R, + max_body_bytes: usize, +) -> Result, ProtocolError> { + let Some(body) = read_frame(reader, Some(max_body_bytes)).await? else { + return Ok(None); + }; + decode_request(&body).map(Some) +} + +pub async fn write_request( writer: &mut W, request: &RequestFrame, ) -> Result<(), ProtocolError> { @@ -93,16 +267,16 @@ pub async fn write_request( write_frame(writer, &body).await } -pub async fn read_response( +pub async fn read_response( reader: &mut R, ) -> Result, ProtocolError> { - let Some(body) = read_frame(reader).await? else { + let Some(body) = read_frame(reader, None).await? else { return Ok(None); }; decode_response(&body).map(Some) } -pub async fn write_response( +pub async fn write_response( writer: &mut W, response: &ResponseFrame, ) -> Result<(), ProtocolError> { @@ -110,6 +284,67 @@ pub async fn write_response( write_frame(writer, &body).await } +pub fn authentication_request_with_password( + username: impl Into, + password: impl Into, +) -> RequestFrame { + RequestFrame { + request_type: RequestType::Authenticate, + sql: encode_authentication_request(AuthenticationRequest::Password { + username: username.into(), + password: password.into(), + }), + } +} + +pub fn authentication_request_with_token(token: impl Into) -> RequestFrame { + RequestFrame { + request_type: RequestType::Authenticate, + sql: encode_authentication_request(AuthenticationRequest::Token { token: token.into() }), + } +} + +pub fn decode_authentication_request(payload: &str) -> Result { + let mut parts = payload.splitn(3, '\0'); + let scheme = parts.next().unwrap_or_default(); + let identity = parts.next().unwrap_or_default(); + let secret = parts.next().unwrap_or_default(); + + match scheme { + "password" => { + if identity.is_empty() { + return Err("password authentication requires a username".to_string()); + } + if secret.is_empty() { + return Err("password authentication requires a password".to_string()); + } + Ok(AuthenticationRequest::Password { + username: identity.to_string(), + password: secret.to_string(), + }) + } + "token" => { + if !identity.is_empty() { + return Err("token authentication does not accept a username".to_string()); + } + if secret.is_empty() { + return Err("token authentication requires a token".to_string()); + } + Ok(AuthenticationRequest::Token { token: secret.to_string() }) + } + _ => Err("authentication payload must use the 'password' or 'token' scheme".to_string()), + } +} + +fn encode_authentication_request(request: AuthenticationRequest) -> String { + match request { + AuthenticationRequest::Password { username, password } => { + format!("password\0{username}\0{password}") + } + AuthenticationRequest::Token { token } => format!("token\0\0{token}"), + } +} + pub fn payload_from_execution_result(result: 
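// --- Reviewer sketch (not part of the diff) ---
// Worked examples of the NUL-delimited authentication payloads defined above.
// The empty middle slot is what keeps token requests distinguishable from a
// password request with an empty username:
//   password scheme: "password\0alice\0secret"
//   token scheme:    "token\0\0opaque-token"
fn auth_payload_examples() {
    assert_eq!(
        decode_authentication_request("password\0alice\0secret"),
        Ok(AuthenticationRequest::Password {
            username: "alice".to_string(),
            password: "secret".to_string(),
        })
    );
    // A token request carrying a username is rejected outright.
    assert!(decode_authentication_request("token\0alice\0opaque-token").is_err());
}
// --- end sketch ---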
&ExecutionResult) -> ResponsePayload { match result { ExecutionResult::Query(query) => ResponsePayload::Query(query_to_payload(query)), @@ -171,8 +406,9 @@ fn hex_char(value: u8) -> char { } } -async fn read_frame( +async fn read_frame( reader: &mut R, + max_body_bytes: Option, ) -> Result>, ProtocolError> { let mut len_buf = [0_u8; 4]; match reader.read_exact(&mut len_buf).await { @@ -187,13 +423,25 @@ async fn read_frame( "frame length must be greater than zero".to_string(), )); } + if let Some(max_body_bytes) = max_body_bytes { + if length > max_body_bytes { + let mut remaining = length; + let mut discard_buf = [0_u8; 4096]; + while remaining > 0 { + let chunk_len = remaining.min(discard_buf.len()); + reader.read_exact(&mut discard_buf[..chunk_len]).await?; + remaining -= chunk_len; + } + return Err(ProtocolError::FrameTooLarge { length, max: max_body_bytes }); + } + } let mut body = vec![0_u8; length]; reader.read_exact(&mut body).await?; Ok(Some(body)) } -async fn write_frame( +async fn write_frame( writer: &mut W, body: &[u8], ) -> Result<(), ProtocolError> { @@ -229,9 +477,11 @@ fn encode_response(response: &ResponseFrame) -> Result, ProtocolError> { body.push(0_u8); encode_payload(payload, &mut body)?; } - ResponseFrame::Err(message) => { + ResponseFrame::Err(error) => { body.push(1_u8); - write_len_prefixed_bytes(&mut body, message.as_bytes())?; + body.push(error.code as u8); + body.push(u8::from(error.retryable)); + write_len_prefixed_bytes(&mut body, error.message.as_bytes())?; } } Ok(body) @@ -255,13 +505,15 @@ fn decode_response(body: &[u8]) -> Result { } 1 => { let mut cursor = Cursor::new(payload); + let code = ErrorCode::try_from(read_u8(&mut cursor)?)?; + let retryable = read_bool(&mut cursor)?; let message = read_len_prefixed_string(&mut cursor)?; if (cursor.position() as usize) != payload.len() { return Err(ProtocolError::InvalidFrame( "error payload has trailing bytes".to_string(), )); } - Ok(ResponseFrame::Err(message)) + Ok(ResponseFrame::Err(ErrorPayload { code, message, retryable })) } other => Err(ProtocolError::InvalidFrame(format!("unknown response status byte: {other}"))), } @@ -295,6 +547,63 @@ fn encode_payload(payload: &ResponsePayload, out: &mut Vec) -> Result<(), Pr out.push(4_u8); write_len_prefixed_bytes(out, plan.as_bytes())?; } + ResponsePayload::Health(health) => { + out.push(5_u8); + out.push(u8::from(health.ok)); + write_len_prefixed_bytes(out, health.status.as_bytes())?; + } + ResponsePayload::Readiness(readiness) => { + out.push(6_u8); + out.push(u8::from(readiness.ready)); + write_len_prefixed_bytes(out, readiness.status.as_bytes())?; + } + ResponsePayload::AdminStatus(status) => { + out.push(7_u8); + write_len_prefixed_bytes(out, status.server_version.as_bytes())?; + out.extend(status.protocol_version.to_be_bytes()); + out.extend(status.uptime_seconds.to_be_bytes()); + out.push(u8::from(status.accepting_connections)); + out.extend(status.active_connections.to_be_bytes()); + out.extend(status.total_connections.to_be_bytes()); + out.extend(status.rejected_connections.to_be_bytes()); + out.extend(status.busy_requests.to_be_bytes()); + out.extend(status.resource_limit_requests.to_be_bytes()); + out.extend(status.quota_rejections.to_be_bytes()); + out.extend(status.timed_out_requests.to_be_bytes()); + out.extend(status.canceled_requests.to_be_bytes()); + out.extend(status.active_statements.to_be_bytes()); + out.extend(status.active_memory_intensive_requests.to_be_bytes()); + out.extend(status.mvcc_started.to_be_bytes()); + 
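// --- Reviewer sketch (not part of the diff) ---
// Wire format assumed by read_frame above: a 4-byte big-endian body length
// followed by exactly that many body bytes; write_frame is expected to emit
// the same shape. Draining an oversized body before returning FrameTooLarge
// keeps the stream aligned so the connection can carry a follow-up request.
// Hand-encoding a frame for illustration:
fn encode_frame(body: &[u8]) -> Vec<u8> {
    let length = u32::try_from(body.len()).expect("frame body fits in u32");
    let mut frame = Vec::with_capacity(4 + body.len());
    frame.extend(length.to_be_bytes()); // length prefix, big endian
    frame.extend_from_slice(body);      // frame body
    frame
}
// --- end sketch ---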
out.extend(status.mvcc_committed.to_be_bytes()); + out.extend(status.mvcc_rolled_back.to_be_bytes()); + out.extend(status.mvcc_write_conflicts.to_be_bytes()); + out.extend(status.mvcc_active_transactions.to_be_bytes()); + } + ResponsePayload::ActiveStatements(payload) => { + out.push(8_u8); + write_u32(out, payload.statements.len())?; + for statement in &payload.statements { + out.extend(statement.statement_id.to_be_bytes()); + out.extend(statement.connection_id.to_be_bytes()); + write_len_prefixed_bytes(out, statement.identity.as_bytes())?; + write_len_prefixed_bytes(out, statement.request_type.as_bytes())?; + out.extend(statement.runtime_ms.to_be_bytes()); + out.push(u8::from(statement.cancel_requested)); + write_len_prefixed_bytes(out, statement.sql_preview.as_bytes())?; + } + } + ResponsePayload::StatementCancellation(payload) => { + out.push(9_u8); + out.extend(payload.statement_id.to_be_bytes()); + out.push(u8::from(payload.accepted)); + write_len_prefixed_bytes(out, payload.status.as_bytes())?; + } + ResponsePayload::Authentication(payload) => { + out.push(10_u8); + write_len_prefixed_bytes(out, payload.identity.as_bytes())?; + write_len_prefixed_bytes(out, payload.role.as_bytes())?; + write_len_prefixed_bytes(out, payload.auth_scheme.as_bytes())?; + } } Ok(()) } @@ -343,6 +652,101 @@ fn decode_payload(cursor: &mut Cursor<&[u8]>) -> Result { + let ok = read_bool(cursor)?; + let status = read_len_prefixed_string(cursor)?; + Ok(ResponsePayload::Health(HealthPayload { ok, status })) + } + 6 => { + let ready = read_bool(cursor)?; + let status = read_len_prefixed_string(cursor)?; + Ok(ResponsePayload::Readiness(ReadinessPayload { ready, status })) + } + 7 => { + let server_version = read_len_prefixed_string(cursor)?; + let protocol_version = read_u16(cursor)?; + let uptime_seconds = read_u64(cursor)?; + let accepting_connections = read_bool(cursor)?; + let active_connections = read_u64(cursor)?; + let total_connections = read_u64(cursor)?; + let rejected_connections = read_u64(cursor)?; + let busy_requests = read_u64(cursor)?; + let resource_limit_requests = read_u64(cursor)?; + let quota_rejections = read_u64(cursor)?; + let timed_out_requests = read_u64(cursor)?; + let canceled_requests = read_u64(cursor)?; + let active_statements = read_u64(cursor)?; + let active_memory_intensive_requests = read_u64(cursor)?; + let mvcc_started = read_u64(cursor)?; + let mvcc_committed = read_u64(cursor)?; + let mvcc_rolled_back = read_u64(cursor)?; + let mvcc_write_conflicts = read_u64(cursor)?; + let mvcc_active_transactions = read_u64(cursor)?; + Ok(ResponsePayload::AdminStatus(AdminStatusPayload { + server_version, + protocol_version, + uptime_seconds, + accepting_connections, + active_connections, + total_connections, + rejected_connections, + busy_requests, + resource_limit_requests, + quota_rejections, + timed_out_requests, + canceled_requests, + active_statements, + active_memory_intensive_requests, + mvcc_started, + mvcc_committed, + mvcc_rolled_back, + mvcc_write_conflicts, + mvcc_active_transactions, + })) + } + 8 => { + let count = read_u32(cursor)? 
as usize; + let mut statements = Vec::with_capacity(count); + for _ in 0..count { + let statement_id = read_u64(cursor)?; + let connection_id = read_u64(cursor)?; + let identity = read_len_prefixed_string(cursor)?; + let request_type = read_len_prefixed_string(cursor)?; + let runtime_ms = read_u64(cursor)?; + let cancel_requested = read_bool(cursor)?; + let sql_preview = read_len_prefixed_string(cursor)?; + statements.push(ActiveStatementPayload { + statement_id, + connection_id, + identity, + request_type, + runtime_ms, + cancel_requested, + sql_preview, + }); + } + Ok(ResponsePayload::ActiveStatements(ActiveStatementsPayload { statements })) + } + 9 => { + let statement_id = read_u64(cursor)?; + let accepted = read_bool(cursor)?; + let status = read_len_prefixed_string(cursor)?; + Ok(ResponsePayload::StatementCancellation(StatementCancellationPayload { + statement_id, + accepted, + status, + })) + } + 10 => { + let identity = read_len_prefixed_string(cursor)?; + let role = read_len_prefixed_string(cursor)?; + let auth_scheme = read_len_prefixed_string(cursor)?; + Ok(ResponsePayload::Authentication(AuthenticationPayload { + identity, + role, + auth_scheme, + })) + } other => { Err(ProtocolError::InvalidFrame(format!("unknown response payload type: {other}"))) } @@ -381,6 +785,20 @@ fn read_u16(cursor: &mut Cursor<&[u8]>) -> Result { Ok(u16::from_be_bytes(raw)) } +fn read_u64(cursor: &mut Cursor<&[u8]>) -> Result { + let mut raw = [0_u8; 8]; + read_exact(cursor, &mut raw)?; + Ok(u64::from_be_bytes(raw)) +} + +fn read_bool(cursor: &mut Cursor<&[u8]>) -> Result { + match read_u8(cursor)? { + 0 => Ok(false), + 1 => Ok(true), + other => Err(ProtocolError::InvalidFrame(format!("invalid bool byte: {other}"))), + } +} + fn read_u32(cursor: &mut Cursor<&[u8]>) -> Result { let mut raw = [0_u8; 4]; read_exact(cursor, &mut raw)?; @@ -400,7 +818,7 @@ fn read_len_prefixed_bytes(cursor: &mut Cursor<&[u8]>) -> Result, Protoc } fn read_exact(cursor: &mut Cursor<&[u8]>, out: &mut [u8]) -> Result<(), ProtocolError> { - cursor.read_exact(out).map_err(|err| ProtocolError::InvalidFrame(err.to_string())) + Read::read_exact(cursor, out).map_err(|err| ProtocolError::InvalidFrame(err.to_string())) } #[cfg(test)] @@ -427,6 +845,19 @@ mod tests { assert_eq!(decoded, response); } + #[tokio::test] + async fn error_response_round_trip() { + let response = ResponseFrame::Err(ErrorPayload { + code: ErrorCode::Busy, + message: "server busy: retry later".to_string(), + retryable: true, + }); + let (mut client, mut server) = tokio::io::duplex(1024); + write_response(&mut client, &response).await.expect("write response"); + let decoded = read_response(&mut server).await.expect("read response").expect("response"); + assert_eq!(decoded, response); + } + #[tokio::test] async fn explain_payload_round_trip() { let response = @@ -436,4 +867,108 @@ mod tests { let decoded = read_response(&mut server).await.expect("read response").expect("response"); assert_eq!(decoded, response); } + + #[tokio::test] + async fn health_payload_round_trip() { + let response = ResponseFrame::Ok(ResponsePayload::Health(HealthPayload { + ok: true, + status: "ok".to_string(), + })); + let (mut client, mut server) = tokio::io::duplex(1024); + write_response(&mut client, &response).await.expect("write response"); + let decoded = read_response(&mut server).await.expect("read response").expect("response"); + assert_eq!(decoded, response); + } + + #[tokio::test] + async fn admin_status_payload_round_trip() { + let response = 
ResponseFrame::Ok(ResponsePayload::AdminStatus(AdminStatusPayload { + server_version: "0.1.0".to_string(), + protocol_version: PROTOCOL_VERSION, + uptime_seconds: 42, + accepting_connections: true, + active_connections: 1, + total_connections: 4, + rejected_connections: 2, + busy_requests: 3, + resource_limit_requests: 1, + quota_rejections: 4, + timed_out_requests: 5, + canceled_requests: 6, + active_statements: 7, + active_memory_intensive_requests: 0, + mvcc_started: 12, + mvcc_committed: 9, + mvcc_rolled_back: 2, + mvcc_write_conflicts: 1, + mvcc_active_transactions: 0, + })); + let (mut client, mut server) = tokio::io::duplex(2048); + write_response(&mut client, &response).await.expect("write response"); + let decoded = read_response(&mut server).await.expect("read response").expect("response"); + assert_eq!(decoded, response); + } + + #[tokio::test] + async fn active_statements_payload_round_trip() { + let response = + ResponseFrame::Ok(ResponsePayload::ActiveStatements(ActiveStatementsPayload { + statements: vec![ActiveStatementPayload { + statement_id: 11, + connection_id: 3, + identity: "127.0.0.1".to_string(), + request_type: "QUERY".to_string(), + runtime_ms: 27, + cancel_requested: false, + sql_preview: "SELECT * FROM users".to_string(), + }], + })); + let (mut client, mut server) = tokio::io::duplex(2048); + write_response(&mut client, &response).await.expect("write response"); + let decoded = read_response(&mut server).await.expect("read response").expect("response"); + assert_eq!(decoded, response); + } + + #[tokio::test] + async fn authentication_payload_round_trip() { + let response = ResponseFrame::Ok(ResponsePayload::Authentication(AuthenticationPayload { + identity: "alice".to_string(), + role: "writer".to_string(), + auth_scheme: "password".to_string(), + })); + let (mut client, mut server) = tokio::io::duplex(1024); + write_response(&mut client, &response).await.expect("write response"); + let decoded = read_response(&mut server).await.expect("read response").expect("response"); + assert_eq!(decoded, response); + } + + #[test] + fn decodes_password_authentication_request() { + let request = decode_authentication_request("password\0alice\0secret") + .expect("decode password auth request"); + assert_eq!( + request, + AuthenticationRequest::Password { + username: "alice".to_string(), + password: "secret".to_string(), + } + ); + } + + #[test] + fn decodes_token_authentication_request() { + let request = + decode_authentication_request("token\0\0opaque-token").expect("decode token request"); + assert_eq!(request, AuthenticationRequest::Token { token: "opaque-token".to_string() }); + } + + #[tokio::test] + async fn request_frame_limit_rejects_oversized_body() { + let request = RequestFrame { request_type: RequestType::Query, sql: "SELECT 1".repeat(64) }; + + let (mut client, mut server) = tokio::io::duplex(4096); + write_request(&mut client, &request).await.expect("write request"); + let err = read_request_with_limit(&mut server, 16).await.expect_err("frame too large"); + assert!(matches!(err, ProtocolError::FrameTooLarge { .. 
})); + } } diff --git a/src/server/tcp.rs b/src/server/tcp.rs index bb8c6d5..2bac569 100644 --- a/src/server/tcp.rs +++ b/src/server/tcp.rs @@ -1,33 +1,584 @@ +use std::collections::{BTreeMap, HashMap}; +use std::fs::File; +use std::io::BufReader; use std::net::SocketAddr; +use std::path::PathBuf; use std::sync::Arc; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::time::{Duration, Instant}; +use parking_lot::Mutex; use thiserror::Error; -use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::oneshot; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpListener; +use tokio::sync::{OwnedSemaphorePermit, Semaphore, oneshot}; use tokio::task::JoinHandle; +use tokio_rustls::TlsAcceptor; +use tokio_rustls::rustls::{ServerConfig as RustlsServerConfig, pki_types::PrivateKeyDer}; use tracing::{Instrument, debug, error, info, info_span, warn}; use crate::catalog::Catalog; -use crate::executor::{ExecutionError, ExecutionSession}; +use crate::executor::governance::{ExecutionGovernance, StatementCancellation}; +use crate::executor::{ExecutionError, ExecutionLimits, ExecutionSession}; use crate::mvcc::MvccStore; -use crate::planner::{PlannerError, plan_statement}; +use crate::planner::{PhysicalPlan, PlannerError, plan_statement}; +use crate::sql::ast::Statement; use crate::sql::parser::{ParseError, parse_sql}; use crate::sql::validator::{ValidationError, validate_statement}; use super::protocol::{ - ProtocolError, RequestFrame, RequestType, ResponseFrame, ResponsePayload, - payload_from_execution_result, read_request, write_response, + ActiveStatementPayload, ActiveStatementsPayload, AdminStatusPayload, AuthenticationPayload, + AuthenticationRequest, ErrorCode, ErrorPayload, HealthPayload, PROTOCOL_VERSION, ProtocolError, + ReadinessPayload, RequestFrame, RequestType, ResponseFrame, ResponsePayload, + StatementCancellationPayload, decode_authentication_request, payload_from_execution_result, + read_request_with_limit, write_response, }; static NEXT_CONNECTION_ID: AtomicU64 = AtomicU64::new(1); +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ServerRole { + Admin, + Writer, + Reader, +} + +impl ServerRole { + pub fn as_str(self) -> &'static str { + match self { + ServerRole::Admin => "admin", + ServerRole::Writer => "writer", + ServerRole::Reader => "reader", + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StaticPasswordUser { + pub username: String, + pub password: String, + pub role: ServerRole, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StaticTokenPrincipal { + pub label: String, + pub token: String, + pub role: ServerRole, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub enum ServerAuthOptions { + #[default] + Disabled, + StaticPassword { + users: Vec, + }, + StaticToken { + principals: Vec, + }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum ServerTlsMode { + #[default] + Disabled, + Required, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct ServerTlsOptions { + pub mode: ServerTlsMode, + pub cert_path: Option, + pub key_path: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct ServerSecurityOptions { + pub auth: ServerAuthOptions, + pub tls: ServerTlsOptions, + pub allow_anonymous_access: bool, +} + +impl ServerSecurityOptions { + pub fn anonymous_for_local_dev() -> Self { + Self { allow_anonymous_access: true, ..Self::default() } + } + + fn validate(&self) -> Result<(), ServerError> { + match &self.auth { + 
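// --- Reviewer sketch (assumed usage; paths and token are placeholders) ---
// A hardened configuration built from the security option types above
// (ServerOptions and ServerLimits are defined further down). validate(),
// continued below, rejects authentication without TLS and refuses anonymous
// access unless it is explicitly opted into:
fn hardened_options_sketch() -> ServerOptions {
    ServerOptions {
        limits: ServerLimits::default(),
        security: ServerSecurityOptions {
            auth: ServerAuthOptions::StaticToken {
                principals: vec![StaticTokenPrincipal {
                    label: "ingest".to_string(),
                    token: "replace-with-a-real-opaque-token".to_string(),
                    role: ServerRole::Writer,
                }],
            },
            tls: ServerTlsOptions {
                mode: ServerTlsMode::Required,
                cert_path: Some("certs/server.pem".into()),
                key_path: Some("certs/server.key".into()),
            },
            allow_anonymous_access: false,
        },
    }
}
// --- end sketch ---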
ServerAuthOptions::Disabled => {} + ServerAuthOptions::StaticPassword { users } => { + if users.is_empty() { + return Err(ServerError::InvalidConfiguration( + "password authentication requires at least one user".to_string(), + )); + } + } + ServerAuthOptions::StaticToken { principals } => { + if principals.is_empty() { + return Err(ServerError::InvalidConfiguration( + "token authentication requires at least one token principal".to_string(), + )); + } + } + } + + if !matches!(self.auth, ServerAuthOptions::Disabled) + && self.tls.mode != ServerTlsMode::Required + { + return Err(ServerError::InvalidConfiguration( + "authentication requires tls mode 'required'".to_string(), + )); + } + if matches!(self.auth, ServerAuthOptions::Disabled) && !self.allow_anonymous_access { + return Err(ServerError::InvalidConfiguration( + "authentication is disabled; set allow_anonymous_access only for local development or tests" + .to_string(), + )); + } + + if self.tls.mode == ServerTlsMode::Required { + if self.tls.cert_path.is_none() { + return Err(ServerError::InvalidConfiguration( + "tls mode 'required' needs a certificate path".to_string(), + )); + } + if self.tls.key_path.is_none() { + return Err(ServerError::InvalidConfiguration( + "tls mode 'required' needs a private key path".to_string(), + )); + } + } + + Ok(()) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ServerLimits { + pub max_concurrent_connections: usize, + pub max_in_flight_requests_per_connection: usize, + pub max_request_bytes: usize, + pub max_statements_per_request: usize, + pub statement_timeout_ms: Option, + pub max_memory_intensive_requests: usize, + pub max_scan_rows: usize, + pub max_sort_rows: usize, + pub max_join_rows: usize, + pub max_query_result_rows: usize, + pub max_query_result_bytes: usize, + pub max_concurrent_queries_per_identity: Option, +} + +impl Default for ServerLimits { + fn default() -> Self { + Self { + max_concurrent_connections: 128, + max_in_flight_requests_per_connection: 1, + max_request_bytes: 256 * 1024, + max_statements_per_request: 16, + statement_timeout_ms: None, + max_memory_intensive_requests: 8, + max_scan_rows: 10_000, + max_sort_rows: 10_000, + max_join_rows: 10_000, + max_query_result_rows: 10_000, + max_query_result_bytes: 4 * 1024 * 1024, + max_concurrent_queries_per_identity: None, + } + } +} + +impl ServerLimits { + fn validate(self) -> Result<(), ServerError> { + for (field, value) in [ + ("max_concurrent_connections", self.max_concurrent_connections), + ("max_in_flight_requests_per_connection", self.max_in_flight_requests_per_connection), + ("max_request_bytes", self.max_request_bytes), + ("max_statements_per_request", self.max_statements_per_request), + ("max_memory_intensive_requests", self.max_memory_intensive_requests), + ("max_scan_rows", self.max_scan_rows), + ("max_sort_rows", self.max_sort_rows), + ("max_join_rows", self.max_join_rows), + ("max_query_result_rows", self.max_query_result_rows), + ("max_query_result_bytes", self.max_query_result_bytes), + ] { + if value == 0 { + return Err(ServerError::InvalidConfiguration(format!( + "server limit '{field}' must be > 0" + ))); + } + } + + if matches!(self.statement_timeout_ms, Some(0)) { + return Err(ServerError::InvalidConfiguration( + "server limit 'statement_timeout_ms' must be > 0 when set".to_string(), + )); + } + if matches!(self.max_concurrent_queries_per_identity, Some(0)) { + return Err(ServerError::InvalidConfiguration( + "server limit 'max_concurrent_queries_per_identity' must be > 0 when set" + 
.to_string(), + )); + } + + Ok(()) + } + + fn execution_limits(self) -> ExecutionLimits { + ExecutionLimits { + max_scan_rows: self.max_scan_rows, + max_sort_rows: self.max_sort_rows, + max_join_rows: self.max_join_rows, + max_query_result_rows: self.max_query_result_rows, + max_query_result_bytes: self.max_query_result_bytes, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct ServerOptions { + pub limits: ServerLimits, + pub security: ServerSecurityOptions, +} + +impl ServerOptions { + pub fn insecure_for_local_dev() -> Self { + Self { security: ServerSecurityOptions::anonymous_for_local_dev(), ..Self::default() } + } +} + +#[derive(Debug)] +struct ServerRuntimeState { + started_at: Instant, + accepting_connections: AtomicBool, + active_connections: AtomicU64, + total_connections: AtomicU64, + rejected_connections: AtomicU64, + busy_requests: AtomicU64, + resource_limit_requests: AtomicU64, + quota_rejections: AtomicU64, + timed_out_requests: AtomicU64, + canceled_requests: AtomicU64, + active_memory_intensive_requests: AtomicU64, + next_statement_id: AtomicU64, + active_statements: Mutex>, + identity_query_counts: Mutex>, + limits: ServerLimits, + security: ServerSecurityOptions, + connection_slots: Arc, + memory_intensive_slots: Arc, +} + +impl ServerRuntimeState { + fn new(options: ServerOptions) -> Self { + Self { + started_at: Instant::now(), + accepting_connections: AtomicBool::new(true), + active_connections: AtomicU64::new(0), + total_connections: AtomicU64::new(0), + rejected_connections: AtomicU64::new(0), + busy_requests: AtomicU64::new(0), + resource_limit_requests: AtomicU64::new(0), + quota_rejections: AtomicU64::new(0), + timed_out_requests: AtomicU64::new(0), + canceled_requests: AtomicU64::new(0), + active_memory_intensive_requests: AtomicU64::new(0), + next_statement_id: AtomicU64::new(1), + active_statements: Mutex::new(BTreeMap::new()), + identity_query_counts: Mutex::new(HashMap::new()), + limits: options.limits, + security: options.security, + connection_slots: Arc::new(Semaphore::new(options.limits.max_concurrent_connections)), + memory_intensive_slots: Arc::new(Semaphore::new( + options.limits.max_memory_intensive_requests, + )), + } + } + + fn try_acquire_connection(self: &Arc) -> Option { + self.connection_slots.clone().try_acquire_owned().ok() + } + + fn try_acquire_memory_intensive_slot(self: &Arc) -> Option { + self.memory_intensive_slots.clone().try_acquire_owned().ok() + } + + fn record_rejected_connection(&self) { + self.rejected_connections.fetch_add(1, Ordering::Relaxed); + } + + fn record_busy_request(&self) { + self.busy_requests.fetch_add(1, Ordering::Relaxed); + } + + fn record_resource_limit_request(&self) { + self.resource_limit_requests.fetch_add(1, Ordering::Relaxed); + } + + fn record_quota_rejection(&self) { + self.quota_rejections.fetch_add(1, Ordering::Relaxed); + } + + fn record_timed_out_request(&self) { + self.timed_out_requests.fetch_add(1, Ordering::Relaxed); + } + + fn record_canceled_request(&self) { + self.canceled_requests.fetch_add(1, Ordering::Relaxed); + } + + fn begin_statement( + self: &Arc, + connection_id: u64, + identity: &str, + request_type: RequestType, + sql: &str, + ) -> Result { + let statement_id = self.next_statement_id.fetch_add(1, Ordering::Relaxed); + let cancellation = StatementCancellation::new(); + let entry = ActiveStatementEntry { + statement_id, + connection_id, + identity: identity.to_string(), + request_type, + sql_preview: sql_preview(sql), + started_at: Instant::now(), + 
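// --- Reviewer sketch (assumed accept-loop call site, outside this excerpt) ---
// Connection capacity is an owned semaphore permit, so a slot is returned
// when the guard drops, even if the connection task panics:
fn admit(runtime_state: &Arc<ServerRuntimeState>) -> bool {
    match runtime_state.try_acquire_connection() {
        // Real code would stash the permit inside ConnectionGuard; dropping it
        // here immediately frees the slot again.
        Some(_permit) => true,
        None => {
            runtime_state.record_rejected_connection();
            false
        }
    }
}
// --- end sketch ---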
+    fn begin_identity_query(
+        self: &Arc<Self>,
+        identity: &str,
+    ) -> Result<IdentityQueryGuard, RequestError> {
+        if let Some(limit) = self.limits.max_concurrent_queries_per_identity {
+            let mut counts = self.identity_query_counts.lock();
+            let current = counts.get(identity).copied().unwrap_or(0);
+            if current >= limit {
+                drop(counts);
+                return Err(RequestError::Quota(format!(
+                    "identity '{identity}' exceeded concurrent query quota ({limit})"
+                )));
+            }
+            counts.insert(identity.to_string(), current + 1);
+        }
+
+        Ok(IdentityQueryGuard { runtime_state: Arc::clone(self), identity: identity.to_string() })
+    }
+
+    fn finish_statement(&self, statement_id: u64, _identity: &str) {
+        self.active_statements.lock().remove(&statement_id);
+    }
+
+    fn finish_identity_query(&self, identity: &str) {
+        if self.limits.max_concurrent_queries_per_identity.is_some() {
+            let mut counts = self.identity_query_counts.lock();
+            if let Some(current) = counts.get_mut(identity) {
+                if *current <= 1 {
+                    counts.remove(identity);
+                } else {
+                    *current -= 1;
+                }
+            }
+        }
+    }
+
+    fn active_statement_payloads(&self) -> Vec<ActiveStatementPayload> {
+        self.active_statements
+            .lock()
+            .values()
+            .map(|entry| ActiveStatementPayload {
+                statement_id: entry.statement_id,
+                connection_id: entry.connection_id,
+                identity: entry.identity.clone(),
+                request_type: request_type_name(entry.request_type).to_string(),
+                runtime_ms: duration_to_millis(entry.started_at.elapsed()),
+                cancel_requested: entry.cancellation.reason().is_some(),
+                sql_preview: entry.sql_preview.clone(),
+            })
+            .collect()
+    }
+
+    fn cancel_statement(&self, statement_id: u64) -> StatementCancellationPayload {
+        let Some(statement) = self.active_statements.lock().get(&statement_id).cloned() else {
+            return StatementCancellationPayload {
+                statement_id,
+                accepted: false,
+                status: "statement not found".to_string(),
+            };
+        };
+
+        let accepted = statement.cancellation.cancel();
+        StatementCancellationPayload {
+            statement_id,
+            accepted,
+            status: if accepted {
+                "cancellation signaled".to_string()
+            } else {
+                "cancellation was already requested".to_string()
+            },
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+struct ActiveStatementEntry {
+    statement_id: u64,
+    connection_id: u64,
+    identity: String,
+    request_type: RequestType,
+    sql_preview: String,
+    started_at: Instant,
+    cancellation: StatementCancellation,
+}
+
+#[derive(Debug)]
+struct StatementExecutionGuard {
+    runtime_state: Arc<ServerRuntimeState>,
+    statement_id: u64,
+    identity: String,
+    cancellation: StatementCancellation,
+}
+
+impl StatementExecutionGuard {
+    fn governance(&self, statement_timeout_ms: Option<u64>) -> ExecutionGovernance {
+        let mut governance =
+            ExecutionGovernance::default().with_cancellation(self.cancellation.clone());
+        if let Some(timeout_ms) = statement_timeout_ms {
+            governance = governance.with_timeout(Duration::from_millis(timeout_ms));
+        }
+        governance
+    }
+}
+
+impl Drop for StatementExecutionGuard {
+    fn drop(&mut self) {
+        self.runtime_state.finish_statement(self.statement_id, &self.identity);
+    }
+}
+
+#[derive(Debug)]
+struct IdentityQueryGuard {
+    runtime_state: Arc<ServerRuntimeState>,
+    identity: String,
+}
+
+impl Drop for IdentityQueryGuard {
+    fn drop(&mut self) {
+        self.runtime_state.finish_identity_query(&self.identity);
+    }
+}
+
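+// The guards below tie resource accounting to permit lifetimes: each one owns
+// its semaphore permit, and ConnectionGuard / MemoryIntensiveRequestGuard also
+// bump a gauge on construction and reverse it on drop, so counts stay correct
+// on every exit path.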
+#[derive(Debug)]
+struct ConnectionGuard {
+    runtime_state: Arc<ServerRuntimeState>,
+    _permit: OwnedSemaphorePermit,
+}
+
+impl ConnectionGuard {
+    fn new(runtime_state: Arc<ServerRuntimeState>, permit: OwnedSemaphorePermit) -> Self {
+        runtime_state.active_connections.fetch_add(1, Ordering::Relaxed);
+        Self { runtime_state, _permit: permit }
+    }
+}
+
+impl Drop for ConnectionGuard {
+    fn drop(&mut self) {
+        self.runtime_state.active_connections.fetch_sub(1, Ordering::Relaxed);
+    }
+}
+
+#[derive(Debug)]
+struct ConnectionRequestGuard {
+    _permit: OwnedSemaphorePermit,
+}
+
+#[derive(Debug)]
+struct MemoryIntensiveRequestGuard {
+    runtime_state: Arc<ServerRuntimeState>,
+    _permit: OwnedSemaphorePermit,
+}
+
+impl MemoryIntensiveRequestGuard {
+    fn new(runtime_state: Arc<ServerRuntimeState>, permit: OwnedSemaphorePermit) -> Self {
+        runtime_state.active_memory_intensive_requests.fetch_add(1, Ordering::Relaxed);
+        Self { runtime_state, _permit: permit }
+    }
+}
+
+impl Drop for MemoryIntensiveRequestGuard {
+    fn drop(&mut self) {
+        self.runtime_state.active_memory_intensive_requests.fetch_sub(1, Ordering::Relaxed);
+    }
+}
+
+#[derive(Debug)]
+struct ConnectionContext {
+    request_slots: Arc<Semaphore>,
+    connection_id: u64,
+    principal: AuthenticatedPrincipal,
+    peer_identity: String,
+}
+
+impl ConnectionContext {
+    fn new(
+        limit: usize,
+        connection_id: u64,
+        peer_identity: String,
+        principal: AuthenticatedPrincipal,
+    ) -> Self {
+        Self {
+            request_slots: Arc::new(Semaphore::new(limit)),
+            connection_id,
+            principal,
+            peer_identity,
+        }
+    }
+
+    fn try_acquire_request_slot(&self) -> Option<ConnectionRequestGuard> {
+        self.request_slots
+            .clone()
+            .try_acquire_owned()
+            .ok()
+            .map(|permit| ConnectionRequestGuard { _permit: permit })
+    }
+
+    fn identity(&self) -> &str {
+        &self.principal.identity
+    }
+
+    fn role(&self) -> ServerRole {
+        self.principal.role
+    }
+}
+
+#[derive(Debug, Clone)]
+struct AuthenticatedPrincipal {
+    identity: String,
+    role: ServerRole,
+    auth_scheme: &'static str,
+}
+
 #[derive(Debug, Error)]
 pub enum ServerError {
     #[error("I/O error: {0}")]
     Io(#[from] std::io::Error),
     #[error("protocol error: {0}")]
     Protocol(#[from] ProtocolError),
+    #[error("TLS error: {0}")]
+    Tls(String),
+    #[error("invalid server configuration: {0}")]
+    InvalidConfiguration(String),
     #[error("accept loop task failed: {0}")]
     Join(String),
 }
@@ -36,6 +587,16 @@ pub enum ServerError {
 enum RequestError {
     #[error("invalid request: {0}")]
     InvalidRequest(String),
+    #[error("server busy: {0}")]
+    Busy(String),
+    #[error("resource limit exceeded: {0}")]
+    ResourceLimit(String),
+    #[error("quota exceeded: {0}")]
+    Quota(String),
+    #[error("unauthenticated: {0}")]
+    Unauthenticated(String),
+    #[error("permission denied: {0}")]
+    PermissionDenied(String),
     #[error("parse error: {0}")]
     Parse(#[from] ParseError),
     #[error("validation error: {0}")]
@@ -46,6 +607,261 @@ enum RequestError {
     Execution(#[from] ExecutionError),
 }
 
+impl RequestError {
+    fn into_error_payload(self) -> ErrorPayload {
+        match self {
+            RequestError::InvalidRequest(message) => {
+                ErrorPayload::new(ErrorCode::InvalidRequest, message, false)
+            }
+            RequestError::Busy(message) => ErrorPayload::new(ErrorCode::Busy, message, true),
+            RequestError::ResourceLimit(message) => {
+                ErrorPayload::new(ErrorCode::ResourceLimit, message, false)
+            }
+            RequestError::Quota(message) => ErrorPayload::new(ErrorCode::Quota, message, true),
+            RequestError::Unauthenticated(message) => {
+                ErrorPayload::new(ErrorCode::Unauthenticated, message, false)
+            }
+            RequestError::PermissionDenied(message) => {
+                ErrorPayload::new(ErrorCode::PermissionDenied, message, false)
+            }
+            RequestError::Parse(error) => {
+                ErrorPayload::new(ErrorCode::Parse, error.to_string(), false)
+            }
+            RequestError::Validation(error) => {
+                ErrorPayload::new(ErrorCode::Validation, error.to_string(), false)
+            }
+            RequestError::Planner(error) => {
+                ErrorPayload::new(ErrorCode::Planner, error.to_string(), false)
+            }
+            RequestError::Execution(error) => match error {
+                ExecutionError::ResourceLimitExceeded { .. } => {
+                    ErrorPayload::new(ErrorCode::ResourceLimit, error.to_string(), false)
+                }
+                ExecutionError::StatementTimedOut { .. } => {
+                    ErrorPayload::new(ErrorCode::Timeout, error.to_string(), false)
+                }
+                ExecutionError::StatementCanceled { .. } => {
+                    ErrorPayload::new(ErrorCode::Canceled, error.to_string(), false)
+                }
+                _ => ErrorPayload::new(ErrorCode::Execution, error.to_string(), false),
+            },
+        }
+    }
+}
+
+fn load_tls_acceptor(options: &ServerTlsOptions) -> Result<Option<TlsAcceptor>, ServerError> {
+    if options.mode == ServerTlsMode::Disabled {
+        return Ok(None);
+    }
+
+    let cert_path = options.cert_path.as_ref().ok_or_else(|| {
+        ServerError::InvalidConfiguration(
+            "tls mode 'required' needs a certificate path".to_string(),
+        )
+    })?;
+    let key_path = options.key_path.as_ref().ok_or_else(|| {
+        ServerError::InvalidConfiguration(
+            "tls mode 'required' needs a private key path".to_string(),
+        )
+    })?;
+
+    let mut cert_reader = BufReader::new(File::open(cert_path).map_err(|err| {
+        ServerError::Tls(format!("failed to open TLS certificate '{}': {err}", cert_path.display()))
+    })?);
+    let certificates =
+        rustls_pemfile::certs(&mut cert_reader).collect::<Result<Vec<_>, _>>().map_err(|err| {
+            ServerError::Tls(format!(
+                "failed to parse TLS certificate '{}': {err}",
+                cert_path.display()
+            ))
+        })?;
+    if certificates.is_empty() {
+        return Err(ServerError::Tls(format!(
+            "TLS certificate '{}' did not contain any certificate entries",
+            cert_path.display()
+        )));
+    }
+
+    let mut key_reader = BufReader::new(File::open(key_path).map_err(|err| {
+        ServerError::Tls(format!("failed to open TLS private key '{}': {err}", key_path.display()))
+    })?);
+    let private_key = rustls_pemfile::private_key(&mut key_reader)
+        .map_err(|err| {
+            ServerError::Tls(format!(
+                "failed to parse TLS private key '{}': {err}",
+                key_path.display()
+            ))
+        })?
+        .ok_or_else(|| {
+            ServerError::Tls(format!(
+                "TLS private key '{}' did not contain a supported key",
+                key_path.display()
+            ))
+        })?;
+
+    let config = build_tls_server_config(certificates, private_key)?;
+    Ok(Some(TlsAcceptor::from(Arc::new(config))))
+}
+
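+// Server-side TLS only: clients are never asked for certificates, so peer
+// identity comes from the authentication handshake rather than mutual TLS.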
+fn build_tls_server_config(
+    certificates: Vec<CertificateDer<'static>>,
+    private_key: PrivateKeyDer<'static>,
+) -> Result<RustlsServerConfig, ServerError> {
+    RustlsServerConfig::builder()
+        .with_no_client_auth()
+        .with_single_cert(certificates, private_key)
+        .map_err(|err| ServerError::Tls(format!("failed to build TLS server config: {err}")))
+}
+
+async fn perform_authentication_handshake<S>(
+    stream: &mut S,
+    runtime_state: &Arc<ServerRuntimeState>,
+    connection_id: u64,
+    peer_identity: &str,
+) -> Result<Option<AuthenticatedPrincipal>, ServerError>
+where
+    S: AsyncRead + AsyncWrite + Unpin,
+{
+    match &runtime_state.security.auth {
+        ServerAuthOptions::Disabled => Ok(Some(AuthenticatedPrincipal {
+            identity: peer_identity.to_string(),
+            role: ServerRole::Admin,
+            auth_scheme: "disabled",
+        })),
+        auth_options => {
+            let request = match read_request_with_limit(
+                stream,
+                runtime_state.limits.max_request_bytes,
+            )
+            .await
+            {
+                Ok(Some(request)) => request,
+                Ok(None) => return Ok(None),
+                Err(ProtocolError::FrameTooLarge { length, max }) => {
+                    runtime_state.record_resource_limit_request();
+                    let response = ResponseFrame::Err(ErrorPayload::new(
+                        ErrorCode::ResourceLimit,
+                        format!(
+                            "authentication frame too large: {length} bytes exceeds limit {max} bytes"
+                        ),
+                        false,
+                    ));
+                    let _ = write_response(stream, &response).await;
+                    return Ok(None);
+                }
+                Err(err) => return Err(ServerError::Protocol(err)),
+            };
+
+            if request.request_type != RequestType::Authenticate {
+                warn!(
+                    connection_id,
+                    peer_identity,
+                    request_type = ?request.request_type,
+                    "rejecting unauthenticated request before session authentication"
+                );
+                let response = ResponseFrame::Err(ErrorPayload::new(
+                    ErrorCode::Unauthenticated,
+                    "secure mode requires authentication before any other request".to_string(),
+                    false,
+                ));
+                let _ = write_response(stream, &response).await;
+                return Ok(None);
+            }
+
+            let auth_request = match decode_authentication_request(&request.sql) {
+                Ok(request) => request,
+                Err(message) => {
+                    warn!(connection_id, peer_identity, error = %message, "invalid authentication payload");
+                    let response = ResponseFrame::Err(ErrorPayload::new(
+                        ErrorCode::InvalidRequest,
+                        message,
+                        false,
+                    ));
+                    let _ = write_response(stream, &response).await;
+                    return Ok(None);
+                }
+            };
+
+            match authenticate_principal(auth_options, auth_request) {
+                Ok(principal) => {
+                    info!(
+                        connection_id,
+                        peer_identity,
+                        identity = %principal.identity,
+                        role = principal.role.as_str(),
+                        auth_scheme = principal.auth_scheme,
+                        "authenticated connection"
+                    );
+                    let response =
+                        ResponseFrame::Ok(ResponsePayload::Authentication(AuthenticationPayload {
+                            identity: principal.identity.clone(),
+                            role: principal.role.as_str().to_string(),
+                            auth_scheme: principal.auth_scheme.to_string(),
+                        }));
+                    write_response(stream, &response).await?;
+                    Ok(Some(principal))
+                }
+                Err(err) => {
+                    warn!(connection_id, peer_identity, error = %err, "authentication failed");
+                    let response = ResponseFrame::Err(err.into_error_payload());
+                    let _ = write_response(stream, &response).await;
+                    Ok(None)
+                }
+            }
+        }
+    }
+}
+
+fn authenticate_principal(
+    auth_options: &ServerAuthOptions,
+    request: AuthenticationRequest,
+) -> Result<AuthenticatedPrincipal, RequestError> {
+    match (auth_options, request) {
+        (
+            ServerAuthOptions::StaticPassword { users },
+            AuthenticationRequest::Password { username, password },
+        ) => {
+            let user = users
+                .iter()
+                .find(|user| user.username == username && user.password == password)
+                .ok_or_else(|| {
+                    RequestError::Unauthenticated(
+                        "invalid username or password for secure server".to_string(),
+                    )
+                })?;
+            Ok(AuthenticatedPrincipal {
+                identity: user.username.clone(),
+                role: user.role,
+                auth_scheme: "password",
+            })
+        }
+        (ServerAuthOptions::StaticToken { principals }, AuthenticationRequest::Token { token }) => {
+            let principal =
+                principals.iter().find(|principal| principal.token == token).ok_or_else(|| {
+                    RequestError::Unauthenticated("invalid token for secure server".to_string())
+                })?;
+            Ok(AuthenticatedPrincipal {
+                identity: principal.label.clone(),
+                role: principal.role,
+                auth_scheme: "token",
+            })
+        }
+        (ServerAuthOptions::StaticPassword { .. }, AuthenticationRequest::Token { .. }) => {
+            Err(RequestError::Unauthenticated(
+                "secure server expects password authentication".to_string(),
+            ))
+        }
+        (ServerAuthOptions::StaticToken { .. }, AuthenticationRequest::Password { .. }) => Err(
+            RequestError::Unauthenticated("secure server expects token authentication".to_string()),
+        ),
+        (ServerAuthOptions::Disabled, _) => Ok(AuthenticatedPrincipal {
+            identity: "anonymous".to_string(),
+            role: ServerRole::Admin,
+            auth_scheme: "disabled",
+        }),
+    }
+}
+
 pub struct ServerHandle {
     local_addr: SocketAddr,
     shutdown_tx: Option<oneshot::Sender<()>>,
@@ -78,14 +894,28 @@ pub async fn start_server(
     catalog: Arc<Catalog>,
     store: Arc<MvccStore>,
 ) -> Result<ServerHandle, ServerError> {
+    start_server_with_options(bind_addr, catalog, store, ServerOptions::default()).await
+}
+
+pub async fn start_server_with_options(
+    bind_addr: SocketAddr,
+    catalog: Arc<Catalog>,
+    store: Arc<MvccStore>,
+    options: ServerOptions,
+) -> Result<ServerHandle, ServerError> {
+    options.limits.validate()?;
+    options.security.validate()?;
+    let tls_acceptor = load_tls_acceptor(&options.security.tls)?;
     info!(%bind_addr, "starting tcp server");
     let listener = TcpListener::bind(bind_addr).await?;
     let local_addr = listener.local_addr()?;
     info!(%local_addr, "tcp server bound");
+    let runtime_state = Arc::new(ServerRuntimeState::new(options));
     let (shutdown_tx, shutdown_rx) = oneshot::channel();
-    let task =
-        tokio::spawn(async move { run_accept_loop(listener, catalog, store, shutdown_rx).await });
+    let task = tokio::spawn(async move {
+        run_accept_loop(listener, catalog, store, runtime_state, tls_acceptor, shutdown_rx).await
+    });
 
     Ok(ServerHandle { local_addr, shutdown_tx: Some(shutdown_tx), task: Some(task) })
 }
@@ -94,23 +924,78 @@ async fn run_accept_loop(
     listener: TcpListener,
     catalog: Arc<Catalog>,
     store: Arc<MvccStore>,
+    runtime_state: Arc<ServerRuntimeState>,
+    tls_acceptor: Option<TlsAcceptor>,
     mut shutdown_rx: oneshot::Receiver<()>,
 ) -> Result<(), ServerError> {
     loop {
         tokio::select! {
             _ = &mut shutdown_rx => {
+                runtime_state.accepting_connections.store(false, Ordering::Relaxed);
                 info!("tcp accept loop received shutdown signal");
                 break;
             }
             accept_result = listener.accept() => {
-                let (stream, peer_addr) = accept_result?;
+                let (mut stream, peer_addr) = accept_result?;
                 let connection_id = NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed);
+                let peer_identity = peer_addr.ip().to_string();
+                runtime_state.total_connections.fetch_add(1, Ordering::Relaxed);
+                let Some(connection_permit) = runtime_state.try_acquire_connection() else {
+                    runtime_state.record_rejected_connection();
+                    warn!(
+                        connection_id,
+                        %peer_addr,
+                        max_concurrent_connections = runtime_state.limits.max_concurrent_connections,
+                        "rejecting connection because the server is at capacity"
+                    );
+                    let response = ResponseFrame::Err(ErrorPayload::new(
+                        ErrorCode::Busy,
+                        format!(
+                            "server busy: max concurrent connections ({}) reached; retry later",
+                            runtime_state.limits.max_concurrent_connections
+                        ),
+                        true,
+                    ));
+                    let _ = write_response(&mut stream, &response).await;
+                    continue;
+                };
                 info!(connection_id, %peer_addr, "accepted tcp connection");
                 let catalog = Arc::clone(&catalog);
                 let store = Arc::clone(&store);
+                let runtime_state = Arc::clone(&runtime_state);
+                let tls_acceptor = tls_acceptor.clone();
                 let span = info_span!("connection", connection_id, %peer_addr);
                 tokio::spawn(async move {
-                    if let Err(err) = handle_connection(stream, catalog, store).await {
+                    let result = if let Some(acceptor) = tls_acceptor {
+                        match acceptor.accept(stream).await {
+                            Ok(tls_stream) => {
+                                handle_connection(
+                                    tls_stream,
+                                    catalog,
+                                    store,
+                                    runtime_state,
+                                    connection_permit,
+                                    connection_id,
+                                    peer_identity,
+                                )
+                                .await
+                            }
+                            Err(err) => Err(ServerError::Tls(err.to_string())),
+                        }
+                    } else {
+                        handle_connection(
+                            stream,
+                            catalog,
+                            store,
+                            runtime_state,
+                            connection_permit,
+                            connection_id,
+                            peer_identity,
+                        )
+                        .await
+                    };
+
+                    if let Err(err) = result {
                         warn!(error = %err, "connection task failed");
                     }
                 }.instrument(span));
@@ -121,32 +1006,139 @@ async fn run_accept_loop(
     Ok(())
 }
 
-async fn handle_connection(
-    mut stream: TcpStream,
+async fn handle_connection<S>(
+    mut stream: S,
     catalog: Arc<Catalog>,
     store: Arc<MvccStore>,
-) -> Result<(), ServerError> {
-    let mut session = ExecutionSession::new(catalog.as_ref(), store.as_ref());
+    runtime_state: Arc<ServerRuntimeState>,
+    connection_permit: OwnedSemaphorePermit,
+    connection_id: u64,
+    peer_identity: String,
+) -> Result<(), ServerError>
+where
+    S: AsyncRead + AsyncWrite + Unpin,
+{
+    let _connection_guard = ConnectionGuard::new(Arc::clone(&runtime_state), connection_permit);
+    let Some(principal) = perform_authentication_handshake(
+        &mut stream,
+        &runtime_state,
+        connection_id,
+        &peer_identity,
+    )
+    .await?
+    else {
+        return Ok(());
+    };
+    let connection_context = ConnectionContext::new(
+        runtime_state.limits.max_in_flight_requests_per_connection,
+        connection_id,
+        peer_identity,
+        principal,
+    );
+    let mut session = ExecutionSession::with_limits(
+        catalog.as_ref(),
+        store.as_ref(),
+        runtime_state.limits.execution_limits(),
+    );
     debug!("connection session created");
 
     loop {
-        let Some(request) = read_request(&mut stream).await? else {
-            debug!("client closed connection");
-            return Ok(());
+        let request = match read_request_with_limit(
+            &mut stream,
+            runtime_state.limits.max_request_bytes,
+        )
+        .await
+        {
+            Ok(Some(request)) => request,
+            Ok(None) => {
+                debug!("client closed connection");
+                return Ok(());
+            }
+            Err(ProtocolError::FrameTooLarge { length, max }) => {
+                runtime_state.record_resource_limit_request();
+                warn!(length, max, "rejecting oversized request frame");
+                let response = ResponseFrame::Err(ErrorPayload::new(
+                    ErrorCode::ResourceLimit,
+                    format!("request frame too large: {length} bytes exceeds limit {max} bytes"),
+                    false,
+                ));
+                let _ = write_response(&mut stream, &response).await;
+                return Ok(());
+            }
+            Err(err) => return Err(ServerError::Protocol(err)),
+        };
+
+        let Some(_request_guard) = connection_context.try_acquire_request_slot() else {
+            runtime_state.record_busy_request();
+            let response = ResponseFrame::Err(ErrorPayload::new(
+                ErrorCode::Busy,
+                format!(
+                    "server busy: max in-flight requests per connection ({}) exceeded; retry later",
+                    runtime_state.limits.max_in_flight_requests_per_connection
+                ),
+                true,
+            ));
+            write_response(&mut stream, &response).await?;
+            continue;
         };
 
         let request_type = request.request_type;
         let sql_len = request.sql.len();
         debug!(request_type = ?request_type, sql_len, "received request frame");
 
-        let response = match execute_request(&mut session, &catalog, request) {
+        let identity_query_guard =
+            if matches!(request_type, RequestType::Query | RequestType::Explain) {
+                match runtime_state.begin_identity_query(connection_context.identity()) {
+                    Ok(guard) => Some(guard),
+                    Err(err) => {
+                        warn!(request_type = ?request_type, error = %err, "request failed");
+                        let payload = err.into_error_payload();
+                        if payload.code == ErrorCode::Quota {
+                            runtime_state.record_quota_rejection();
+                        }
+                        let response = ResponseFrame::Err(payload);
+                        if let Err(err) = write_response(&mut stream, &response).await {
+                            error!(error = %err, "failed to write response");
+                            return Err(ServerError::Protocol(err));
+                        }
+                        continue;
+                    }
+                }
+            } else {
+                None
+            };
+
+        let response = match execute_request(
+            &mut session,
+            &catalog,
+            &store,
+            &runtime_state,
+            &connection_context,
+            request,
+        ) {
             Ok(payload) => {
                 debug!(request_type = ?request_type, "request handled successfully");
                 ResponseFrame::Ok(payload)
             }
             Err(err) => {
                 warn!(request_type = ?request_type, error = %err, "request failed");
-                ResponseFrame::Err(err.to_string())
+                let payload = err.into_error_payload();
+                if payload.code == ErrorCode::Busy {
+                    runtime_state.record_busy_request();
+                }
+                if payload.code == ErrorCode::ResourceLimit {
+                    runtime_state.record_resource_limit_request();
+                }
+                if payload.code == ErrorCode::Quota {
+                    runtime_state.record_quota_rejection();
+                }
+                if payload.code == ErrorCode::Timeout {
+                    runtime_state.record_timed_out_request();
+                }
+                if payload.code == ErrorCode::Canceled {
+                    runtime_state.record_canceled_request();
+                }
+                ResponseFrame::Err(payload)
             }
         };
 
@@ -154,12 +1146,16 @@ async fn handle_connection(
             error!(error = %err, "failed to write response");
             return Err(ServerError::Protocol(err));
         }
+        drop(identity_query_guard);
     }
 }
 
 fn execute_request(
     session: &mut ExecutionSession<'_>,
     catalog: &Catalog,
+    store: &MvccStore,
+    runtime_state: &Arc<ServerRuntimeState>,
+    connection_context: &ConnectionContext,
     request: RequestFrame,
 ) -> Result<ResponsePayload, RequestError> {
     debug!(request_type = ?request.request_type, "executing request");
@@ -170,27 +1166,121 @@ fn execute_request(
                 "query request requires non-empty SQL payload".to_string(),
             ));
         }
-            execute_sql(session, catalog, &request.sql)
+            let statement = runtime_state.begin_statement(
+                connection_context.connection_id,
+                connection_context.identity(),
+                RequestType::Query,
+                &request.sql,
+            )?;
+            execute_sql(
+                session,
+                catalog,
+                runtime_state,
+                connection_context,
+                &request.sql,
+                &statement.governance(runtime_state.limits.statement_timeout_ms),
+            )
+        }
+        RequestType::Begin => execute_sql(
+            session,
+            catalog,
+            runtime_state,
+            connection_context,
+            "BEGIN ISOLATION LEVEL SNAPSHOT",
+            &ExecutionGovernance::default(),
+        ),
+        RequestType::Commit => execute_sql(
+            session,
+            catalog,
+            runtime_state,
+            connection_context,
+            "COMMIT",
+            &ExecutionGovernance::default(),
+        ),
+        RequestType::Rollback => execute_sql(
+            session,
+            catalog,
+            runtime_state,
+            connection_context,
+            "ROLLBACK",
+            &ExecutionGovernance::default(),
+        ),
+        RequestType::Explain => {
+            let statement = runtime_state.begin_statement(
+                connection_context.connection_id,
+                connection_context.identity(),
+                RequestType::Explain,
+                &request.sql,
+            )?;
+            explain_sql(
+                catalog,
+                runtime_state,
+                connection_context,
+                &request.sql,
+                &statement.governance(runtime_state.limits.statement_timeout_ms),
+            )
+        }
+        RequestType::Health => Ok(health_payload()),
+        RequestType::Readiness => Ok(readiness_payload(runtime_state)),
+        RequestType::AdminStatus => {
+            authorize_admin_request(connection_context, "admin status")?;
+            Ok(admin_status_payload(store, runtime_state.as_ref()))
+        }
+        RequestType::ActiveStatements => {
+            authorize_admin_request(connection_context, "active statements")?;
+            Ok(ResponsePayload::ActiveStatements(ActiveStatementsPayload {
+                statements: runtime_state.active_statement_payloads(),
+            }))
+        }
+        RequestType::CancelStatement => {
+            authorize_admin_request(connection_context, "statement cancellation")?;
+            cancel_statement(runtime_state, &request.sql)
+        }
+        RequestType::Authenticate => {
+            Err(RequestError::InvalidRequest("connection is already authenticated".to_string()))
         }
-        RequestType::Begin => execute_sql(session, catalog, "BEGIN ISOLATION LEVEL SNAPSHOT"),
-        RequestType::Commit => execute_sql(session, catalog, "COMMIT"),
-        RequestType::Rollback => execute_sql(session, catalog, "ROLLBACK"),
-        RequestType::Explain => explain_sql(catalog, &request.sql),
     }
 }
 
 fn execute_sql(
     session: &mut ExecutionSession<'_>,
     catalog: &Catalog,
+    runtime_state: &Arc<ServerRuntimeState>,
+    connection_context: &ConnectionContext,
     sql: &str,
+    governance: &ExecutionGovernance,
 ) -> Result<ResponsePayload, RequestError> {
+    governance.checkpoint()?;
     let statements = parse_sql(sql)?;
+    authorize_sql_statements(connection_context, &statements)?;
+    if statements.len() > runtime_state.limits.max_statements_per_request {
+        return Err(RequestError::ResourceLimit(format!(
+            "request contains {} statements, limit is {}",
+            statements.len(),
+            runtime_state.limits.max_statements_per_request
+        )));
+    }
     let mut last_result = None;
     for statement in statements {
+        governance.checkpoint()?;
         validate_statement(catalog, &statement)?;
+        governance.checkpoint()?;
         let plan = plan_statement(catalog, &statement)?;
-        let result = session.execute_plan(&plan)?;
+        let _memory_guard = acquire_memory_intensive_guard(runtime_state, &plan)?;
+        let result = match session.execute_plan_with_governance(&plan, governance) {
+            Ok(result) => result,
+            Err(err) => {
+                if matches!(
+                    err,
+                    ExecutionError::StatementTimedOut { .. }
+                        | ExecutionError::StatementCanceled { .. }
+                ) {
+                    session.abort_active_transaction();
+                }
+                return Err(RequestError::Execution(err));
+            }
+        };
         last_result = Some(result);
     }
 
@@ -200,7 +1290,14 @@ fn execute_sql(
     Ok(payload_from_execution_result(&result))
 }
 
-fn explain_sql(catalog: &Catalog, sql: &str) -> Result<ResponsePayload, RequestError> {
+fn explain_sql(
+    catalog: &Catalog,
+    runtime_state: &Arc<ServerRuntimeState>,
+    connection_context: &ConnectionContext,
+    sql: &str,
+    governance: &ExecutionGovernance,
+) -> Result<ResponsePayload, RequestError> {
+    governance.checkpoint()?;
     if sql.trim().is_empty() {
         return Err(RequestError::InvalidRequest(
             "explain request requires non-empty SQL payload".to_string(),
@@ -213,13 +1310,216 @@ fn explain_sql(catalog: &Catalog, sql: &str) -> Result<ResponsePayload, RequestError> {
+    if statements.len() > runtime_state.limits.max_statements_per_request {
+        return Err(RequestError::ResourceLimit(format!(
+            "request contains {} statements, limit is {}",
+            statements.len(),
+            runtime_state.limits.max_statements_per_request
+        )));
+    }
     let mut rendered = Vec::new();
     for (index, statement) in statements.into_iter().enumerate() {
+        governance.checkpoint()?;
         validate_statement(catalog, &statement)?;
+        governance.checkpoint()?;
         let plan = plan_statement(catalog, &statement)?;
         rendered.push(format!("Statement {}:\n{plan:#?}", index + 1));
     }
 
     Ok(ResponsePayload::ExplainPlan(rendered.join("\n\n")))
 }
+
+fn health_payload() -> ResponsePayload {
+    ResponsePayload::Health(HealthPayload { ok: true, status: "ok".to_string() })
+}
+
+fn cancel_statement(
+    runtime_state: &Arc<ServerRuntimeState>,
+    statement_id: &str,
+) -> Result<ResponsePayload, RequestError> {
+    let statement_id = statement_id.trim().parse::<u64>().map_err(|_| {
+        RequestError::InvalidRequest(
+            "cancel request requires a numeric statement id in the SQL payload".to_string(),
+        )
+    })?;
+    Ok(ResponsePayload::StatementCancellation(runtime_state.cancel_statement(statement_id)))
+}
+
+fn readiness_payload(runtime_state: &Arc<ServerRuntimeState>) -> ResponsePayload {
+    let ready = runtime_state.accepting_connections.load(Ordering::Relaxed);
+    let status = if ready { "ready" } else { "draining" };
+    ResponsePayload::Readiness(ReadinessPayload { ready, status: status.to_string() })
+}
+
+fn authorize_admin_request(
+    connection_context: &ConnectionContext,
+    operation: &str,
+) -> Result<(), RequestError> {
+    if connection_context.role() == ServerRole::Admin {
+        return Ok(());
+    }
+
+    warn!(
+        connection_id = connection_context.connection_id,
+        peer_identity = %connection_context.peer_identity,
+        identity = %connection_context.identity(),
+        role = connection_context.role().as_str(),
+        operation,
+        "authorization denied for admin-only request"
+    );
+    Err(RequestError::PermissionDenied(format!(
+        "role '{}' cannot access {operation}",
+        connection_context.role().as_str()
+    )))
+}
+
+fn authorize_sql_statements(
+    connection_context: &ConnectionContext,
+    statements: &[Statement],
+) -> Result<(), RequestError> {
+    for statement in statements {
+        if role_allows_statement(connection_context.role(), statement) {
+            continue;
+        }
+
+        warn!(
+            connection_id = connection_context.connection_id,
+            peer_identity = %connection_context.peer_identity,
+            identity = %connection_context.identity(),
+            role = connection_context.role().as_str(),
+            statement_kind = statement_kind(statement),
+            "authorization denied for SQL statement"
+        );
+        return Err(RequestError::PermissionDenied(format!(
+            "role '{}' cannot execute {} statements",
+            connection_context.role().as_str(),
+            statement_kind(statement)
+        )));
+    }
+
+    Ok(())
+}
+
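+// Role matrix: Admin may run anything, Writer gets DML plus transaction
+// control, Reader gets SELECT plus transaction control. DDL is Admin-only.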
+fn role_allows_statement(role: ServerRole, statement: &Statement) -> bool {
+    match role {
+        ServerRole::Admin => true,
+        ServerRole::Writer => matches!(
+            statement,
+            Statement::Insert(_)
+                | Statement::Select(_)
+                | Statement::Update(_)
+                | Statement::Delete(_)
+                | Statement::Begin(_)
+                | Statement::Commit
+                | Statement::Rollback
+        ),
+        ServerRole::Reader => matches!(
+            statement,
+            Statement::Select(_) | Statement::Begin(_) | Statement::Commit | Statement::Rollback
+        ),
+    }
+}
+
+fn statement_kind(statement: &Statement) -> &'static str {
+    match statement {
+        Statement::CreateTable(_) => "CREATE TABLE",
+        Statement::DropTable(_) => "DROP TABLE",
+        Statement::Insert(_) => "INSERT",
+        Statement::Select(_) => "SELECT",
+        Statement::Update(_) => "UPDATE",
+        Statement::Delete(_) => "DELETE",
+        Statement::Begin(_) => "BEGIN",
+        Statement::Commit => "COMMIT",
+        Statement::Rollback => "ROLLBACK",
+    }
+}
+
+fn acquire_memory_intensive_guard(
+    runtime_state: &Arc<ServerRuntimeState>,
+    plan: &PhysicalPlan,
+) -> Result<Option<MemoryIntensiveRequestGuard>, RequestError> {
+    if !plan_is_memory_intensive(plan) {
+        return Ok(None);
+    }
+
+    let Some(permit) = runtime_state.try_acquire_memory_intensive_slot() else {
+        return Err(RequestError::Busy(format!(
+            "max memory-intensive requests ({}) reached; retry later",
+            runtime_state.limits.max_memory_intensive_requests
+        )));
+    };
+
+    Ok(Some(MemoryIntensiveRequestGuard::new(Arc::clone(runtime_state), permit)))
+}
+
+fn plan_is_memory_intensive(plan: &PhysicalPlan) -> bool {
+    match plan {
+        PhysicalPlan::SeqScan(_) | PhysicalPlan::Sort(_) | PhysicalPlan::Join(_) => true,
+        PhysicalPlan::Filter(node) => plan_is_memory_intensive(&node.input),
+        PhysicalPlan::Project(node) => plan_is_memory_intensive(&node.input),
+        PhysicalPlan::Limit(node) => plan_is_memory_intensive(&node.input),
+        _ => false,
+    }
+}
+
+fn admin_status_payload(store: &MvccStore, runtime_state: &ServerRuntimeState) -> ResponsePayload {
+    let metrics = store.metrics();
+    ResponsePayload::AdminStatus(AdminStatusPayload {
+        server_version: env!("CARGO_PKG_VERSION").to_string(),
+        protocol_version: PROTOCOL_VERSION,
+        uptime_seconds: runtime_state.started_at.elapsed().as_secs(),
+        accepting_connections: runtime_state.accepting_connections.load(Ordering::Relaxed),
+        active_connections: runtime_state.active_connections.load(Ordering::Relaxed),
+        total_connections: runtime_state.total_connections.load(Ordering::Relaxed),
+        rejected_connections: runtime_state.rejected_connections.load(Ordering::Relaxed),
+        busy_requests: runtime_state.busy_requests.load(Ordering::Relaxed),
+        resource_limit_requests: runtime_state.resource_limit_requests.load(Ordering::Relaxed),
+        quota_rejections: runtime_state.quota_rejections.load(Ordering::Relaxed),
+        timed_out_requests: runtime_state.timed_out_requests.load(Ordering::Relaxed),
+        canceled_requests: runtime_state.canceled_requests.load(Ordering::Relaxed),
+        active_statements: u64::try_from(runtime_state.active_statements.lock().len())
+            .unwrap_or(u64::MAX),
+        active_memory_intensive_requests: runtime_state
+            .active_memory_intensive_requests
+            .load(Ordering::Relaxed),
+        mvcc_started: metrics.started,
+        mvcc_committed: metrics.committed,
+        mvcc_rolled_back: metrics.rolled_back,
+        mvcc_write_conflicts: metrics.write_conflicts,
+        mvcc_active_transactions: u64::try_from(metrics.active_transactions).unwrap_or(u64::MAX),
+    })
+}
+
+fn request_type_name(request_type: RequestType) -> &'static str {
+    match request_type {
+        RequestType::Query => "QUERY",
+        RequestType::Begin => "BEGIN",
+        RequestType::Commit => "COMMIT",
+        RequestType::Rollback => "ROLLBACK",
+        RequestType::Explain => "EXPLAIN",
+        RequestType::Health => "HEALTH",
+        RequestType::Readiness => "READINESS",
+        RequestType::AdminStatus => "ADMIN_STATUS",
+        RequestType::ActiveStatements => "ACTIVE_STATEMENTS",
+        RequestType::CancelStatement => "CANCEL_STATEMENT",
+        RequestType::Authenticate => "AUTHENTICATE",
+    }
+}
+
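+// Previews are logged and returned by ACTIVE_STATEMENTS, so collapse runs of
+// whitespace and cap the text at a fixed character budget.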
+fn sql_preview(sql: &str) -> String {
+    const MAX_PREVIEW_CHARS: usize = 160;
+
+    let mut preview = sql.split_whitespace().collect::<Vec<_>>().join(" ");
+    if preview.chars().count() > MAX_PREVIEW_CHARS {
+        preview = preview.chars().take(MAX_PREVIEW_CHARS).collect::<String>();
+        preview.push_str("...");
+    }
+    preview
+}
+
+fn duration_to_millis(duration: Duration) -> u64 {
+    let millis = duration.as_millis();
+    u64::try_from(millis).unwrap_or(u64::MAX)
+}
diff --git a/src/storage/compaction/leveled.rs b/src/storage/compaction/leveled.rs
index 6e9ec61..97da5e7 100644
--- a/src/storage/compaction/leveled.rs
+++ b/src/storage/compaction/leveled.rs
@@ -68,15 +68,15 @@ pub fn pick_compaction(
     let mut best: Option<LeveledCompactionPlan> = None;
 
     let level0 = version_set.level_tables(0);
-    if level0.len() >= config.level0_file_limit {
+    if !level0.is_empty() {
         let target_level = 1;
         let target_tables = version_set.level_tables(target_level);
         let source = pick_smallest_overlap_source(level0, target_tables)?;
         let overlaps = overlapping_tables(source, target_tables);
-        let score = (level0.len() as f64 / config.level0_file_limit as f64)
-            + ((overlaps.len() as f64) * 0.01)
-            + 10.0;
+        let l0_pressure = level0.len() as f64 / config.level0_file_limit.max(1) as f64;
+        let overflow_bonus = if level0.len() >= config.level0_file_limit { 10.0 } else { 1.0 };
+        let score = l0_pressure + ((overlaps.len() as f64) * 0.01) + overflow_bonus;
 
         best = Some(LeveledCompactionPlan {
             trigger: LeveledTrigger::Level0Overflow,
diff --git a/src/storage/compaction/mod.rs b/src/storage/compaction/mod.rs
index 2a8a8ee..15eecac 100644
--- a/src/storage/compaction/mod.rs
+++ b/src/storage/compaction/mod.rs
@@ -3,14 +3,14 @@ pub mod scheduler;
 pub mod tiered;
 
 pub use leveled::{
+    LeveledCompactionConfig, LeveledCompactionPlan, LeveledTrigger,
     pick_compaction as pick_leveled_compaction, ranges_overlap as leveled_ranges_overlap,
-    target_size_bytes as leveled_target_size_bytes, LeveledCompactionConfig, LeveledCompactionPlan,
-    LeveledTrigger,
+    target_size_bytes as leveled_target_size_bytes,
 };
 pub use scheduler::{
     CompactionMetrics, CompactionPlan, CompactionScheduler, CompactionStrategy, ScheduledCompaction,
 };
 pub use tiered::{
-    group_tables_into_tiers, pick_compaction as pick_tiered_compaction, tier_id_for_size,
-    TieredCompactionConfig, TieredCompactionPlan,
+    TieredCompactionConfig, TieredCompactionPlan, group_tables_into_tiers,
+    pick_compaction as pick_tiered_compaction, tier_id_for_size,
 };
diff --git a/src/storage/compaction/tiered.rs b/src/storage/compaction/tiered.rs
index a9d4c5e..0c614b9 100644
--- a/src/storage/compaction/tiered.rs
+++ b/src/storage/compaction/tiered.rs
@@ -39,13 +39,13 @@ pub fn pick_compaction(
     version_set: &VersionSet,
     config: &TieredCompactionConfig,
 ) -> Option<TieredCompactionPlan> {
-    let mut tiers = group_tables_into_tiers(version_set, config);
+    let tiers = group_tables_into_tiers(version_set, config);
     if tiers.is_empty() {
         return None;
     }
 
     let mut best: Option<TieredCompactionPlan> = None;
-    for (tier_id, mut tables) in tiers.drain() {
+    for (tier_id, mut tables) in tiers.into_iter() {
         if tables.len() < config.max_components_per_tier {
             continue;
         }
diff --git a/src/storage/engine.rs b/src/storage/engine.rs
index b498403..c7e1899 100644
--- a/src/storage/engine.rs
+++ b/src/storage/engine.rs
@@ -1236,9 +1236,7 @@ fn compact_one_task(task: CompactionTask) -> CompactionResponse {
     }
 }
 
-fn merge_compaction_rows(
-    mut rows: Vec<(Vec<u8>, Vec<u8>, u64)>,
-) -> Vec<(Vec<u8>, Vec<u8>)> {
+fn merge_compaction_rows(mut rows: Vec<(Vec<u8>, Vec<u8>, u64)>) -> Vec<(Vec<u8>, Vec<u8>)> {
     rows.sort_by(|left, right| left.0.cmp(&right.0).then(left.2.cmp(&right.2)));
 
     // First collapse identical internal keys, preferring the newest table id.
@@ -1293,9 +1291,7 @@ fn merge_compaction_rows(
     collapsed
 }
 
-fn estimate_live_user_bytes_for_tables(
-    tables: &[SSTableRuntime],
-) -> Result {
+fn estimate_live_user_bytes_for_tables(tables: &[SSTableRuntime]) -> Result {
     let mut latest = HashMap::<Vec<u8>, (u64, ValueType, u64)>::new();
     let mut undecodable_live_bytes = 0_u64;
 
@@ -1313,13 +1309,11 @@ fn estimate_live_user_bytes_for_tables(
     let decoded_live_bytes = latest
         .values()
-        .map(|(_, value_type, logical_bytes)| {
-            if *value_type == ValueType::Put {
-                *logical_bytes
-            } else {
-                0
-            }
-        })
+        .map(
+            |(_, value_type, logical_bytes)| {
+                if *value_type == ValueType::Put { *logical_bytes } else { 0 }
+            },
+        )
         .sum::<u64>();
 
     Ok(decoded_live_bytes.saturating_add(undecodable_live_bytes))
@@ -1343,7 +1337,10 @@ fn accumulate_live_user_bytes(
         .map(|(sequence, _, _)| decoded.sequence > *sequence)
         .unwrap_or(true);
     if should_replace {
-        latest.insert(decoded.user_key.to_vec(), (decoded.sequence, decoded.value_type, logical_bytes));
+        latest.insert(
+            decoded.user_key.to_vec(),
+            (decoded.sequence, decoded.value_type, logical_bytes),
+        );
     }
 }
 
@@ -1515,7 +1512,9 @@ mod tests {
         let merged = merge_compaction_rows(rows);
         let dup_row = merged
             .iter()
-            .find(|(key, _)| decode_internal_key(key).map(|k| k.user_key == b"dup").unwrap_or(false))
+            .find(|(key, _)| {
+                decode_internal_key(key).map(|k| k.user_key == b"dup").unwrap_or(false)
+            })
            .expect("duplicate key row");
         assert_eq!(dup_row.1, b"new".to_vec());
     }
@@ -1546,13 +1545,11 @@
         let live = latest
             .values()
-            .map(|(_, value_type, logical_bytes)| {
-                if *value_type == ValueType::Put {
-                    *logical_bytes
-                } else {
-                    0
-                }
-            })
+            .map(
+                |(_, value_type, logical_bytes)| {
+                    if *value_type == ValueType::Put { *logical_bytes } else { 0 }
+                },
+            )
             .sum::<u64>()
             .saturating_add(undecodable);
diff --git a/src/storage/memtable/mod.rs b/src/storage/memtable/mod.rs
index 77125c6..7a16f12 100644
--- a/src/storage/memtable/mod.rs
+++ b/src/storage/memtable/mod.rs
@@ -1,8 +1,8 @@
 pub mod arena;
 pub mod skiplist;
 
-use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
+use std::sync::atomic::{AtomicUsize, Ordering};
 
 use self::skiplist::SkipList;
@@ -291,7 +291,7 @@ mod tests {
 
     #[test]
     fn manager_promotes_mutable_when_size_limit_is_reached() {
-        let mut manager = MemTableManager::new(20, 32);
+        let mut manager = MemTableManager::new(40, 32);
 
         let promoted = manager.put(b"alpha", 1, b"1234567890");
         assert!(!promoted);
diff --git a/src/storage/wal/mod.rs b/src/storage/wal/mod.rs
index a583b57..56e8699 100644
--- a/src/storage/wal/mod.rs
+++ b/src/storage/wal/mod.rs
@@ -4,8 +4,8 @@ pub mod writer;
 
 pub use reader::{WalReadError, WalReader, WalReplay};
 pub use record::{
-    decode_physical, encode_physical, parse_segment_id, segment_file_name, PhysicalRecord,
-    RecordDecodeError, RecordEncodeError, RecordType, BLOCK_SIZE_BYTES, DEFAULT_SEGMENT_SIZE_BYTES,
-    HEADER_LEN,
+    BLOCK_SIZE_BYTES, DEFAULT_SEGMENT_SIZE_BYTES, HEADER_LEN, PhysicalRecord, RecordDecodeError,
+    RecordEncodeError, RecordType, decode_physical, encode_physical, parse_segment_id,
+    segment_file_name,
 };
 pub use writer::{SyncMode, WalWriteError, WalWriter, WalWriterOptions};
diff --git a/src/storage/wal/reader.rs b/src/storage/wal/reader.rs
index 839298e..82bab0e 100644
--- a/src/storage/wal/reader.rs
+++ b/src/storage/wal/reader.rs
@@ -4,7 +4,7 @@ use std::path::{Path, PathBuf};
 
 use thiserror::Error;
 
-use super::record::{checksum, parse_segment_id, RecordType, BLOCK_SIZE_BYTES, HEADER_LEN};
+use super::record::{BLOCK_SIZE_BYTES, HEADER_LEN, RecordType, checksum, parse_segment_id};
 
 #[derive(Debug, Default, Clone, PartialEq, Eq)]
 pub struct WalReplay {
diff --git a/src/storage/wal/writer.rs b/src/storage/wal/writer.rs
index 6a2122c..68c1b3a 100644
--- a/src/storage/wal/writer.rs
+++ b/src/storage/wal/writer.rs
@@ -78,7 +78,10 @@ impl WalWriter {
     }
 
     pub fn append(&mut self, payload: &[u8]) -> Result<(), WalWriteError> {
-        if self.segment_len >= self.options.segment_size_bytes {
+        let estimated_len = estimate_logical_record_len(payload.len(), self.block_offset) as u64;
+        if self.segment_len > 0
+            && self.segment_len.saturating_add(estimated_len) > self.options.segment_size_bytes
+        {
             self.rotate_segment()?;
         }
 
@@ -217,6 +220,40 @@ impl WalWriter {
     }
 }
 
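+
+// Estimates the on-disk size of a logical record once it is fragmented into
+// fixed-size WAL blocks: every fragment pays a header, and a block tail too
+// small for a header is padded out and counted too. `append` uses this to
+// rotate the segment before writing instead of after overshooting it.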
+fn estimate_logical_record_len(payload_len: usize, mut block_offset: usize) -> usize {
+    if payload_len == 0 {
+        let space_left = BLOCK_SIZE_BYTES - block_offset;
+        if space_left < HEADER_LEN {
+            return space_left + HEADER_LEN;
+        }
+        return HEADER_LEN;
+    }
+
+    let mut remaining = payload_len;
+    let mut written = 0usize;
+
+    while remaining > 0 {
+        let space_left = BLOCK_SIZE_BYTES - block_offset;
+        if space_left <= HEADER_LEN {
+            written += space_left;
+            block_offset = 0;
+            continue;
+        }
+
+        let fragment = remaining.min(space_left - HEADER_LEN);
+        let encoded_len = HEADER_LEN + fragment;
+        written += encoded_len;
+        remaining -= fragment;
+
+        block_offset += encoded_len;
+        if block_offset == BLOCK_SIZE_BYTES {
+            block_offset = 0;
+        }
+    }
+
+    written
+}
+
 fn next_segment_id(dir: &Path) -> io::Result<u64> {
     let mut max_id = None;
diff --git a/tests/bench/mixed_read_write.rs b/tests/bench/mixed_read_write.rs
index 416a15f..d1eca35 100644
--- a/tests/bench/mixed_read_write.rs
+++ b/tests/bench/mixed_read_write.rs
@@ -325,11 +325,7 @@ fn summarize_window_throughput(values: &[f64]) -> (f64, f64, f64, f64) {
 }
 
 fn percentile_us(histogram: &Histogram<u64>, quantile: f64) -> f64 {
-    if histogram.len() == 0 {
-        0.0
-    } else {
-        histogram.value_at_quantile(quantile) as f64 / 1_000.0
-    }
+    if histogram.len() == 0 { 0.0 } else { histogram.value_at_quantile(quantile) as f64 / 1_000.0 }
 }
 
 fn build_keys(key_space: usize) -> Vec<Vec<u8>> {
diff --git a/tests/bench/read_latency.rs b/tests/bench/read_latency.rs
index 155444f..b8b0bad 100644
--- a/tests/bench/read_latency.rs
+++ b/tests/bench/read_latency.rs
@@ -4,7 +4,7 @@ use std::fs;
 use std::path::{Path, PathBuf};
 use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
 
-use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
+use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
 use hdrhistogram::Histogram;
 use lsmdb::storage::engine::{StorageEngine, StorageEngineOptions};
 use lsmdb::storage::wal::SyncMode;
diff --git a/tests/bench/wal_recovery.rs b/tests/bench/wal_recovery.rs
index dc7ce5b..c9fc5b7 100644
--- a/tests/bench/wal_recovery.rs
+++ b/tests/bench/wal_recovery.rs
@@ -5,8 +5,8 @@ use std::path::{Path, PathBuf};
 use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
 
 use lsmdb::storage::wal::{
-    SyncMode, WalReader, WalWriter, WalWriterOptions,
-    DEFAULT_SEGMENT_SIZE_BYTES as WAL_DEFAULT_SEGMENT_SIZE_BYTES,
+    DEFAULT_SEGMENT_SIZE_BYTES
as WAL_DEFAULT_SEGMENT_SIZE_BYTES, SyncMode, WalReader, WalWriter, + WalWriterOptions, }; const DEFAULT_RECORDS: usize = 250_000; diff --git a/tests/bench/write_throughput.rs b/tests/bench/write_throughput.rs index c2d3261..90e67c6 100644 --- a/tests/bench/write_throughput.rs +++ b/tests/bench/write_throughput.rs @@ -3,7 +3,7 @@ use std::fs; use std::path::{Path, PathBuf}; use std::time::{SystemTime, UNIX_EPOCH}; -use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput}; +use criterion::{BatchSize, BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; use lsmdb::storage::engine::{StorageEngine, StorageEngineOptions}; use lsmdb::storage::wal::SyncMode; diff --git a/tests/conformance/sql/01-core-supported.toml b/tests/conformance/sql/01-core-supported.toml new file mode 100644 index 0000000..4839512 --- /dev/null +++ b/tests/conformance/sql/01-core-supported.toml @@ -0,0 +1,116 @@ +suite_id = "core_supported" +description = "Positive conformance coverage for currently supported SQL statements." + +[[cases]] +id = "ddl_create_table" +category = "ddl.create_table" +sql = "CREATE TABLE users (id BIGINT NOT NULL, email TEXT NOT NULL, active BOOLEAN DEFAULT true, PRIMARY KEY (id))" + +[cases.expect] +kind = "affected_rows" +count = 0 + +[[cases]] +id = "dml_insert" +category = "dml.insert" +setup_sql = [ + "CREATE TABLE users (id BIGINT NOT NULL, email TEXT NOT NULL, active BOOLEAN DEFAULT true, PRIMARY KEY (id))" +] +sql = "INSERT INTO users (id, email, active) VALUES (1, 'a@x.com', true)" + +[cases.expect] +kind = "affected_rows" +count = 1 + +[[cases]] +id = "query_select_projection_order_limit" +category = "query.select" +setup_sql = [ + "CREATE TABLE users (id BIGINT NOT NULL, email TEXT NOT NULL, active BOOLEAN DEFAULT true, PRIMARY KEY (id))", + "INSERT INTO users (id, email, active) VALUES (1, 'a@x.com', true)", + "INSERT INTO users (id, email, active) VALUES (2, 'b@x.com', true)" +] +sql = "SELECT id, email FROM users WHERE active = true ORDER BY id DESC LIMIT 1" + +[cases.expect] +kind = "query" +columns = ["id", "email"] +rows = [["2", "b@x.com"]] + +[[cases]] +id = "dml_update" +category = "dml.update" +setup_sql = [ + "CREATE TABLE users (id BIGINT NOT NULL, email TEXT NOT NULL, active BOOLEAN DEFAULT true, PRIMARY KEY (id))", + "INSERT INTO users (id, email, active) VALUES (1, 'a@x.com', true)" +] +sql = "UPDATE users SET email = 'updated@x.com' WHERE id = 1" + +[cases.expect] +kind = "affected_rows" +count = 1 + +[[cases]] +id = "dml_delete" +category = "dml.delete" +setup_sql = [ + "CREATE TABLE users (id BIGINT NOT NULL, email TEXT NOT NULL, active BOOLEAN DEFAULT true, PRIMARY KEY (id))", + "INSERT INTO users (id, email, active) VALUES (1, 'a@x.com', true)" +] +sql = "DELETE FROM users WHERE id = 1" + +[cases.expect] +kind = "affected_rows" +count = 1 + +[[cases]] +id = "ddl_drop_table" +category = "ddl.drop_table" +setup_sql = [ + "CREATE TABLE users (id BIGINT NOT NULL, email TEXT NOT NULL, active BOOLEAN DEFAULT true, PRIMARY KEY (id))" +] +sql = "DROP TABLE users" + +[cases.expect] +kind = "affected_rows" +count = 0 + +[[cases]] +id = "txn_begin" +category = "txn.begin" +setup_sql = [ + "CREATE TABLE accounts (id BIGINT NOT NULL, email TEXT NOT NULL, PRIMARY KEY (id))" +] +sql = "BEGIN ISOLATION LEVEL SNAPSHOT" + +[cases.expect] +kind = "transaction_state" +state = "begun" + +[[cases]] +id = "txn_commit" +category = "txn.commit" +setup_sql = [ + "CREATE TABLE accounts (id BIGINT NOT NULL, email TEXT NOT NULL, PRIMARY KEY 
(id))", + "BEGIN ISOLATION LEVEL SNAPSHOT", + "INSERT INTO accounts (id, email) VALUES (1, 'pending@x.com')" +] +sql = "COMMIT" + +[cases.expect] +kind = "transaction_state" +state = "committed" + +[[cases]] +id = "txn_rollback" +category = "txn.rollback" +setup_sql = [ + "CREATE TABLE accounts (id BIGINT NOT NULL, email TEXT NOT NULL, PRIMARY KEY (id))", + "BEGIN ISOLATION LEVEL SNAPSHOT", + "INSERT INTO accounts (id, email) VALUES (2, 'rollback@x.com')" +] +sql = "ROLLBACK" + +[cases.expect] +kind = "transaction_state" +state = "rolled_back" diff --git a/tests/conformance/sql/02-errors-and-boundaries.toml b/tests/conformance/sql/02-errors-and-boundaries.toml new file mode 100644 index 0000000..8cbab8d --- /dev/null +++ b/tests/conformance/sql/02-errors-and-boundaries.toml @@ -0,0 +1,73 @@ +suite_id = "errors_and_boundaries" +description = "Negative conformance coverage for unsupported syntax and enforced SQL validation rules." + +[[cases]] +id = "error_invalid_where_type" +category = "errors.invalid_where_type" +setup_sql = [ + "CREATE TABLE users (id BIGINT NOT NULL, email TEXT NOT NULL, PRIMARY KEY (id))" +] +sql = "SELECT id FROM users WHERE id + 1" + +[cases.expect] +kind = "error_contains" +message = "WHERE clause must evaluate to BOOLEAN" + +[[cases]] +id = "error_unsupported_join_syntax" +category = "errors.unsupported_join" +setup_sql = [ + "CREATE TABLE users (id BIGINT NOT NULL, email TEXT NOT NULL, PRIMARY KEY (id))", + "CREATE TABLE accounts (id BIGINT NOT NULL, email TEXT NOT NULL, PRIMARY KEY (id))" +] +sql = "SELECT users.id FROM users JOIN accounts ON users.id = accounts.id" + +[cases.expect] +kind = "error_contains" +message = "expected a SQL statement" + +[[cases]] +id = "error_ddl_in_explicit_transaction" +category = "errors.ddl_in_explicit_txn" +setup_sql = [ + "CREATE TABLE users (id BIGINT NOT NULL, email TEXT NOT NULL, PRIMARY KEY (id))", + "BEGIN ISOLATION LEVEL SNAPSHOT" +] +sql = "CREATE TABLE audit (id BIGINT NOT NULL, PRIMARY KEY (id))" + +[cases.expect] +kind = "error_contains" +message = "DDL in explicit transaction is not supported yet" + +[[cases]] +id = "error_unknown_table" +category = "errors.unknown_table" +sql = "SELECT id FROM missing_users" + +[cases.expect] +kind = "error_contains" +message = "table 'missing_users' not found" + +[[cases]] +id = "error_duplicate_insert_column" +category = "errors.duplicate_column" +setup_sql = [ + "CREATE TABLE users (id BIGINT NOT NULL, email TEXT NOT NULL, PRIMARY KEY (id))" +] +sql = "INSERT INTO users (id, id) VALUES (1, 2)" + +[cases.expect] +kind = "error_contains" +message = "duplicated in statement" + +[[cases]] +id = "error_non_nullable_null" +category = "errors.non_nullable_null" +setup_sql = [ + "CREATE TABLE users (id BIGINT NOT NULL, email TEXT NOT NULL, PRIMARY KEY (id))" +] +sql = "INSERT INTO users (id, email) VALUES (1, NULL)" + +[cases.expect] +kind = "error_contains" +message = "NULL is not allowed" diff --git a/tests/fixtures/tls/server.crt b/tests/fixtures/tls/server.crt new file mode 100644 index 0000000..c37cff9 --- /dev/null +++ b/tests/fixtures/tls/server.crt @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDSTCCAjGgAwIBAgIUSiKuegO6chUYvvtO6Z8WlBZB7oswDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDMxODE4NDAwMFoXDTM2MDMx +NTE4NDAwMFowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF +AAOCAQ8AMIIBCgKCAQEAz7AITSXkTJY3bGqEn41sJIUR/RNyGZrUYIhB7yvB0Lsn +XDuDEsznNT3x7ndg4ISPX/MzDPEQsFR22jK6DT+6i1c9goqKx9SmAvVrKauDQG/J 
+LsISvzolaPr/YNHyuYGusux+tdIQyOoMjUrrEdRyBhWdqgRL8XTDag8YbQJ20ROO +GJcqhIQmmYn60GSgmODSJ3lqrQrmxwadNAdgaznAXapyFrasAE3x8bv3FORah7GZ +JaP0d5IbwlWE2n0b5++xI1C9zANSCA2v9P//Wt5neTBU2iRm1WJuDeXaTMUwkyOq +tK9ulUo4Y3bDS0iAWXSDP/zpUodOYO7PqhP9ln5NswIDAQABo4GSMIGPMB0GA1Ud +DgQWBBQrsXZq+UyEOwT+oGfUA/pRxESG9DAfBgNVHSMEGDAWgBQrsXZq+UyEOwT+ +oGfUA/pRxESG9DAaBgNVHREEEzARgglsb2NhbGhvc3SHBH8AAAEwDAYDVR0TAQH/ +BAIwADAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDQYJKoZI +hvcNAQELBQADggEBAASq+2vIrni0sHv+KveekbjDNjNRuAyVd4wCdP90LfxCI0iB +x7w87ZRYeg2tEtGPQ3G+FgeSMccJrvTnHT9ujIZkpqP+M1WiX0i9FzEue9KMMloR +qDtNz14rSNDJXq7Z0JbU2E40BPygZwLF2dBw6SIP/9YlyuO9uI6K8NC31fU8uSfu +eZGpxp2iTZmyttJYA+LP0tCB6cxpR4t9ANNAJlU2LKpJ7QJ8tee0WsNcbcNOCh4n +oAgWOL9Q04RRvcVjuOYECZxL7c6lBhQiB8Del8T38wr9UQ4bq8bFgSfY6qmAaxiH +HexGIoE1MnlEvNZetrrabPKVrTkirHt7mLcRDlM= +-----END CERTIFICATE----- diff --git a/tests/fixtures/tls/server.key b/tests/fixtures/tls/server.key new file mode 100644 index 0000000..791068d --- /dev/null +++ b/tests/fixtures/tls/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDPsAhNJeRMljds +aoSfjWwkhRH9E3IZmtRgiEHvK8HQuydcO4MSzOc1PfHud2DghI9f8zMM8RCwVHba +MroNP7qLVz2CiorH1KYC9Wspq4NAb8kuwhK/OiVo+v9g0fK5ga6y7H610hDI6gyN +SusR1HIGFZ2qBEvxdMNqDxhtAnbRE44YlyqEhCaZifrQZKCY4NIneWqtCubHBp00 +B2BrOcBdqnIWtqwATfHxu/cU5FqHsZklo/R3khvCVYTafRvn77EjUL3MA1IIDa/0 +//9a3md5MFTaJGbVYm4N5dpMxTCTI6q0r26VSjhjdsNLSIBZdIM//OlSh05g7s+q +E/2Wfk2zAgMBAAECggEACjfuiKEzJOOFMZPiF5mVNwzHEE0bIZBhH6jEmbhs7lCv +BJY3Aj9Lpu53z1RXU2SiS0XDfsEDobFeMakqR0mZ644szBX18xQO4PljPucd65c0 +blUFKBx7x7kFxKU/zInJZytEpryBr+j4GiGUBEoQHCWHHtzcQbKNhNPeT0q+PtYh +Q1DbMppI1CbPj4RDLBJk/0uZbuxV4klZrboMZAFRoGCcy5g3r4XDXeUCL8ppo1/L +TQuM7s2pSpV+Psc6aR/J6NUkvuyoqfqt2Xk9qWJT2ehbsUXVSDfq7cFhObPLTrQr +IzhBFWf4LdgOb/snDtOCMlh5xJPp/Ap5jRMHcpKk9QKBgQDwhLaRnl/KJloNoS1Z +LxC/Y/U+hwCp4vk9ugdPdSoIJUywQKulDahH//tzuzi19e5WLarn4MWqD7L16TcT +8UJZre9tQOcYUgyR0QOtHM0aMwo7/B87Q4r/d1lI6hA1ezsCjcCYwkI+QuOaWSIK +BAf6KPHz7G5b6WVURiJ++wTlfwKBgQDdDlSMZeKraRfkvizxZ7QfLGzkdKgVs3w/ +MpGpgkp/wbxCg6w17c7qtz/4gmQJahJw/Tl6R+o/lPWdukHExRS1IbJT/dbouFG2 +4QLcQXt4XYNf/gNhJ8H8Au9A/peGi75uzNTZlS4ylwJdsvJq5ySaugVujAhPPYEt +nzQGAKb5zQKBgQCW63+vwgPzUbtiIAfXlVvZ7Hv/vzCgaWbh37AkoK0+LUGAuyO5 +TueQPkTnKsx8CRSDiOZb18PQYUd3XN6Nqe5rXWQGVxprPVjbyp6W6qKcVPiQCTUD +t+8pPBePVCfVlzzA7neyovp0HP66ZEGirULgKv8fgvUAwWQuzE9rBFHfOwKBgCBE +HTc5D/LxLhmnYKwD9RivxV07YeV5A2O+H+DcMb+gKbiTu6lLgu5jvSSq86skHnj7 +nU4p/Rk2xvs02rC8C5+8wWjdHmdtsA+/nElGDZ2uGKUEUL33rar5Sq7z+m4bK7rE +jzULP2kG/cNrgVL1VjR3fp96NSRL1/UuzcsqgTTpAoGAYGbp08rAbVooJDaFXfKO +qp5/BfmqqS1rh9hQ78T3G9EfwLCL8Sd2mucSOW0kJK5zuWZwhxbnPmKj9VUfOCb1 +p6fO4ckSk6YoDcHHs9d99udk8nG2MfYtOMzG0ES8J4XEGt6jcz8TpmYttRyhKNFz +ZNfQVNHTjj4p5LEWJ05Tvn8= +-----END PRIVATE KEY----- diff --git a/tests/integration/compaction.rs b/tests/integration/compaction.rs index 16f2de4..21dd288 100644 --- a/tests/integration/compaction.rs +++ b/tests/integration/compaction.rs @@ -1,6 +1,6 @@ use lsmdb::storage::compaction::{ - pick_leveled_compaction, pick_tiered_compaction, CompactionScheduler, CompactionStrategy, - LeveledCompactionConfig, LeveledTrigger, TieredCompactionConfig, + CompactionScheduler, CompactionStrategy, LeveledCompactionConfig, LeveledTrigger, + TieredCompactionConfig, pick_leveled_compaction, pick_tiered_compaction, }; use lsmdb::storage::manifest::version::{SSTableMetadata, VersionSet}; diff --git a/tests/integration/engine_compaction.rs b/tests/integration/engine_compaction.rs index 320d548..b0ae4ac 100644 --- 
a/tests/integration/engine_compaction.rs +++ b/tests/integration/engine_compaction.rs @@ -1,6 +1,7 @@ use std::fs; use std::path::PathBuf; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use std::thread; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use lsmdb::storage::compaction::{CompactionStrategy, LeveledCompactionConfig}; use lsmdb::storage::engine::{StorageEngine, StorageEngineOptions}; @@ -18,6 +19,25 @@ fn temp_dir(label: &str) -> PathBuf { dir } +fn count_user_key_versions(engine: &StorageEngine, user_key: &[u8]) -> usize { + let mut versions = 0_usize; + for table in engine.sstable_metadata() { + let reader = SSTableReader::open(engine.sstable_dir().join(&table.file_name)) + .expect("open sstable reader"); + let rows = reader.scan_range(None, None).expect("scan table rows"); + + for (internal_key, _value) in rows { + if decode_internal_key(&internal_key) + .map(|decoded| decoded.user_key == user_key) + .unwrap_or(false) + { + versions += 1; + } + } + } + versions +} + #[test] fn engine_runs_background_compaction_and_preserves_reads() { let dir = temp_dir("background-compaction"); @@ -106,26 +126,16 @@ fn compaction_collapses_old_versions_for_same_user_key() { .wait_for_background_compaction(Duration::from_secs(5)) .expect("background compaction should complete"); - assert_eq!( - engine.get(b"hot-key").expect("read hot key"), - Some(b"hot-09".to_vec()) - ); + assert_eq!(engine.get(b"hot-key").expect("read hot key"), Some(b"hot-09".to_vec())); - let mut hot_versions = 0_usize; - for table in engine.sstable_metadata() { - let reader = SSTableReader::open(engine.sstable_dir().join(&table.file_name)) - .expect("open sstable reader"); - let rows = reader.scan_range(None, None).expect("scan table rows"); - - for (internal_key, _value) in rows { - if decode_internal_key(&internal_key) - .map(|decoded| decoded.user_key == b"hot-key") - .unwrap_or(false) - { - hot_versions += 1; - } + let deadline = Instant::now() + Duration::from_secs(5); + let hot_versions = loop { + let hot_versions = count_user_key_versions(&engine, b"hot-key"); + if hot_versions == 1 || Instant::now() >= deadline { + break hot_versions; } - } + thread::sleep(Duration::from_millis(25)); + }; assert_eq!( hot_versions, 1, diff --git a/tests/integration/planner.rs b/tests/integration/planner.rs index 196586e..09d97b3 100644 --- a/tests/integration/planner.rs +++ b/tests/integration/planner.rs @@ -38,14 +38,10 @@ fn planner_uses_primary_key_scan_for_pk_equality_filter() { let statement = parse_statement("SELECT id, email FROM users WHERE id = 7").expect("parse"); let plan = plan_statement(&catalog, &statement).expect("plan"); - match plan { - PhysicalPlan::PrimaryKeyScan(scan) => { - assert_eq!(scan.table, "users"); - assert_eq!(scan.key_values.len(), 1); - assert_eq!(scan.key_values[0].0, "id"); - } - other => panic!("expected primary key scan, got {other:?}"), - } + assert!( + contains_primary_key_scan(&plan), + "expected primary key scan in plan subtree, got {plan:?}" + ); } #[test] @@ -55,8 +51,38 @@ fn planner_keeps_seq_scan_for_non_pk_filter() { parse_statement("SELECT id, email FROM users WHERE email = 'a@b.com'").expect("parse"); let plan = plan_statement(&catalog, &statement).expect("plan"); - let PhysicalPlan::Filter(filter) = plan else { - panic!("expected filter"); - }; - assert!(matches!(*filter.input, PhysicalPlan::SeqScan(_))); + assert!( + contains_filter_over_seq_scan(&plan), + "expected filter over seq scan in plan subtree, got {plan:?}" + ); +} + +fn contains_primary_key_scan(plan: 
&PhysicalPlan) -> bool {
+    match plan {
+        PhysicalPlan::PrimaryKeyScan(_) => true,
+        PhysicalPlan::Filter(filter) => contains_primary_key_scan(&filter.input),
+        PhysicalPlan::Project(project) => contains_primary_key_scan(&project.input),
+        PhysicalPlan::Sort(sort) => contains_primary_key_scan(&sort.input),
+        PhysicalPlan::Limit(limit) => contains_primary_key_scan(&limit.input),
+        PhysicalPlan::Join(join) => {
+            contains_primary_key_scan(&join.left) || contains_primary_key_scan(&join.right)
+        }
+        _ => false,
+    }
+}
+
+fn contains_filter_over_seq_scan(plan: &PhysicalPlan) -> bool {
+    match plan {
+        PhysicalPlan::Filter(filter) => {
+            matches!(*filter.input, PhysicalPlan::SeqScan(_))
+                || contains_filter_over_seq_scan(&filter.input)
+        }
+        PhysicalPlan::Project(project) => contains_filter_over_seq_scan(&project.input),
+        PhysicalPlan::Sort(sort) => contains_filter_over_seq_scan(&sort.input),
+        PhysicalPlan::Limit(limit) => contains_filter_over_seq_scan(&limit.input),
+        PhysicalPlan::Join(join) => {
+            contains_filter_over_seq_scan(&join.left) || contains_filter_over_seq_scan(&join.right)
+        }
+        _ => false,
+    }
+}
diff --git a/tests/integration/server.rs b/tests/integration/server.rs
index 25a364e..4b8081d 100644
--- a/tests/integration/server.rs
+++ b/tests/integration/server.rs
@@ -1,16 +1,39 @@
 use std::net::SocketAddr;
+use std::path::PathBuf;
 use std::str::from_utf8;
 use std::sync::Arc;
+use std::sync::OnceLock;
+use std::time::Duration;
 
 use lsmdb::catalog::Catalog;
+use lsmdb::executor::{ExecutionResult, ExecutionSession};
 use lsmdb::mvcc::MvccStore;
+use lsmdb::planner::plan_statement;
 use lsmdb::server::{
-    QueryPayload, RequestFrame, RequestType, ResponseFrame, ResponsePayload, TransactionState,
-    read_response, start_server, write_request,
+    ActiveStatementsPayload, AdminStatusPayload, AuthenticationPayload, ErrorCode, ErrorPayload,
+    HealthPayload, PROTOCOL_VERSION, QueryPayload, ReadinessPayload, RequestFrame, RequestType,
+    ResponseFrame, ResponsePayload, ServerAuthOptions, ServerError, ServerLimits, ServerOptions,
+    ServerRole, ServerSecurityOptions, ServerTlsMode, ServerTlsOptions,
+    StatementCancellationPayload, StaticPasswordUser, StaticTokenPrincipal, TransactionState,
+    authentication_request_with_password, authentication_request_with_token, read_response,
+    start_server, start_server_with_options, write_request,
 };
+use lsmdb::sql::{parse_statement, validate_statement};
+use tokio::io::{AsyncRead, AsyncWrite};
 use tokio::net::TcpStream;
+use tokio::sync::{OwnedSemaphorePermit, Semaphore};
+use tokio::time::sleep;
+use tokio_rustls::TlsConnector;
+use tokio_rustls::rustls::{ClientConfig, RootCertStore, pki_types::ServerName};
 
-async fn send_request(stream: &mut TcpStream, request: RequestFrame) -> ResponseFrame {
+const TLS_CERT_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/tls/server.crt");
+const TLS_KEY_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/tls/server.key");
+static HEAVY_SERVER_TEST_SEMAPHORE: OnceLock<Arc<Semaphore>> = OnceLock::new();
+
+async fn send_request<S>(stream: &mut S, request: RequestFrame) -> ResponseFrame
+where
+    S: AsyncRead + AsyncWrite + Unpin,
+{
     write_request(stream, &request).await.expect("write request");
     read_response(stream).await.expect("read response").expect("response")
 }
@@ -29,15 +52,178 @@ fn response_to_explain(response: ResponseFrame) -> String {
     }
 }
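Making `send_request` generic over `AsyncRead + AsyncWrite + Unpin` is what lets the TLS tests later in this file reuse every helper unchanged: both `TcpStream` and `tokio_rustls::client::TlsStream<TcpStream>` satisfy the bound. A minimal, self-contained sketch of the same pattern; the `u32` length prefix here is illustrative only and is not lsmdb's actual wire framing:

```rust
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

// Works for plain TCP and TLS streams alike, since both implement the traits.
async fn roundtrip<S>(stream: &mut S, payload: &[u8]) -> std::io::Result<Vec<u8>>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // Write a u32 length prefix followed by the payload bytes.
    stream.write_u32(payload.len() as u32).await?;
    stream.write_all(payload).await?;
    stream.flush().await?;

    // Read the reply framed the same way.
    let len = stream.read_u32().await? as usize;
    let mut reply = vec![0_u8; len];
    stream.read_exact(&mut reply).await?;
    Ok(reply)
}
```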
panic!("expected health payload, got {other:?}"), + } +} + +fn response_to_readiness(response: ResponseFrame) -> ReadinessPayload { + match response { + ResponseFrame::Ok(ResponsePayload::Readiness(payload)) => payload, + other => panic!("expected readiness payload, got {other:?}"), + } +} + +fn response_to_admin_status(response: ResponseFrame) -> AdminStatusPayload { + match response { + ResponseFrame::Ok(ResponsePayload::AdminStatus(payload)) => payload, + other => panic!("expected admin status payload, got {other:?}"), + } +} + +fn response_to_error(response: ResponseFrame) -> ErrorPayload { + match response { + ResponseFrame::Err(payload) => payload, + other => panic!("expected error payload, got {other:?}"), + } +} + +fn response_to_active_statements(response: ResponseFrame) -> ActiveStatementsPayload { + match response { + ResponseFrame::Ok(ResponsePayload::ActiveStatements(payload)) => payload, + other => panic!("expected active statements payload, got {other:?}"), + } +} + +fn response_to_statement_cancellation(response: ResponseFrame) -> StatementCancellationPayload { + match response { + ResponseFrame::Ok(ResponsePayload::StatementCancellation(payload)) => payload, + other => panic!("expected statement cancellation payload, got {other:?}"), + } +} + +fn response_to_authentication(response: ResponseFrame) -> AuthenticationPayload { + match response { + ResponseFrame::Ok(ResponsePayload::Authentication(payload)) => payload, + other => panic!("expected authentication payload, got {other:?}"), + } +} + +fn execute_setup_sql(catalog: &Catalog, store: &MvccStore, sql: &str) -> ExecutionResult { + let statement = parse_statement(sql).expect("parse setup SQL"); + validate_statement(catalog, &statement).expect("validate setup SQL"); + let plan = plan_statement(catalog, &statement).expect("plan setup SQL"); + let mut session = ExecutionSession::new(catalog, store); + session.execute_plan(&plan).expect("execute setup SQL") +} + +fn populate_users(catalog: &Catalog, store: &MvccStore, rows: usize) { + execute_setup_sql( + catalog, + store, + "CREATE TABLE users (id BIGINT NOT NULL, email TEXT NOT NULL, PRIMARY KEY (id))", + ); + for id in 1..=rows { + execute_setup_sql( + catalog, + store, + &format!( + "INSERT INTO users (id, email) VALUES ({id}, '{}')", + format!("user{id:05}@example.com") + ), + ); + } +} + +async fn wait_for_active_statement_id(stream: &mut S, request_type: &str) -> u64 +where + S: AsyncRead + AsyncWrite + Unpin, +{ + for _ in 0..5_000 { + let payload = response_to_active_statements( + send_request( + stream, + RequestFrame { request_type: RequestType::ActiveStatements, sql: String::new() }, + ) + .await, + ); + if let Some(statement) = + payload.statements.into_iter().find(|statement| statement.request_type == request_type) + { + return statement.statement_id; + } + sleep(Duration::from_millis(1)).await; + } + + panic!("timed out waiting for active statement"); +} + +fn password_server_options(users: Vec) -> ServerOptions { + ServerOptions { + security: ServerSecurityOptions { + auth: ServerAuthOptions::StaticPassword { users }, + tls: fixture_tls_server_options(), + ..ServerSecurityOptions::default() + }, + ..ServerOptions::default() + } +} + +fn token_server_options(principals: Vec) -> ServerOptions { + ServerOptions { + security: ServerSecurityOptions { + auth: ServerAuthOptions::StaticToken { principals }, + tls: fixture_tls_server_options(), + ..ServerSecurityOptions::default() + }, + ..ServerOptions::default() + } +} + +fn insecure_server_options() -> 
+
+fn insecure_server_options() -> ServerOptions {
+    ServerOptions::insecure_for_local_dev()
+}
+
+fn fixture_tls_server_options() -> ServerTlsOptions {
+    ServerTlsOptions {
+        mode: ServerTlsMode::Required,
+        cert_path: Some(PathBuf::from(TLS_CERT_PATH)),
+        key_path: Some(PathBuf::from(TLS_KEY_PATH)),
+    }
+}
+
+async fn connect_tls_client(addr: SocketAddr) -> tokio_rustls::client::TlsStream<TcpStream> {
+    let certificate = std::fs::File::open(TLS_CERT_PATH).expect("open tls certificate fixture");
+    let mut reader = std::io::BufReader::new(certificate);
+    let certificates = rustls_pemfile::certs(&mut reader)
+        .collect::<Result<Vec<_>, _>>()
+        .expect("parse tls certificate fixture");
+    let mut roots = RootCertStore::empty();
+    for certificate in certificates {
+        roots.add(certificate).expect("add tls certificate to root store");
+    }
+    let config = ClientConfig::builder().with_root_certificates(roots).with_no_client_auth();
+    let connector = TlsConnector::from(Arc::new(config));
+    let stream = TcpStream::connect(addr).await.expect("connect tcp stream for tls");
+    let server_name = ServerName::try_from("localhost").expect("valid server name").to_owned();
+    connector.connect(server_name, stream).await.expect("complete tls handshake")
+}
+
+async fn acquire_heavy_server_test_permit() -> OwnedSemaphorePermit {
+    HEAVY_SERVER_TEST_SEMAPHORE
+        .get_or_init(|| Arc::new(Semaphore::new(1)))
+        .clone()
+        .acquire_owned()
+        .await
+        .expect("heavy server test semaphore should remain open")
+}
+
 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
 async fn server_executes_query_requests_end_to_end() {
     let store = Arc::new(MvccStore::new());
     let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog"));
 
     let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr");
-    let server = start_server(bind_addr, Arc::clone(&catalog), Arc::clone(&store))
-        .await
-        .expect("start server");
+    let server = start_server_with_options(
+        bind_addr,
+        Arc::clone(&catalog),
+        Arc::clone(&store),
+        insecure_server_options(),
+    )
+    .await
+    .expect("start server");
     let server_addr = server.local_addr();
 
     let mut client = TcpStream::connect(server_addr).await.expect("connect client");
@@ -86,9 +272,14 @@ async fn server_returns_explain_plan_without_executing_statement() {
     let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog"));
 
     let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr");
-    let server = start_server(bind_addr, Arc::clone(&catalog), Arc::clone(&store))
-        .await
-        .expect("start server");
+    let server = start_server_with_options(
+        bind_addr,
+        Arc::clone(&catalog),
+        Arc::clone(&store),
+        insecure_server_options(),
+    )
+    .await
+    .expect("start server");
     let server_addr = server.local_addr();
 
     let mut client = TcpStream::connect(server_addr).await.expect("connect client");
@@ -132,9 +323,14 @@ async fn server_tracks_transaction_state_per_connection() {
     let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog"));
 
     let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr");
-    let server = start_server(bind_addr, Arc::clone(&catalog), Arc::clone(&store))
-        .await
-        .expect("start server");
+    let server = start_server_with_options(
+        bind_addr,
+        Arc::clone(&catalog),
+        Arc::clone(&store),
+        insecure_server_options(),
+    )
+    .await
+    .expect("start server");
     let server_addr = server.local_addr();
 
     let mut client_a = TcpStream::connect(server_addr).await.expect("connect client_a");
@@ -207,3 +403,788 @@ async fn server_tracks_transaction_state_per_connection()
{ server.shutdown().await.expect("shutdown server"); } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn server_exposes_health_readiness_and_admin_status() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = start_server_with_options( + bind_addr, + Arc::clone(&catalog), + Arc::clone(&store), + insecure_server_options(), + ) + .await + .expect("start server"); + let server_addr = server.local_addr(); + + let mut client = TcpStream::connect(server_addr).await.expect("connect client"); + + let health = response_to_health( + send_request( + &mut client, + RequestFrame { request_type: RequestType::Health, sql: String::new() }, + ) + .await, + ); + assert!(health.ok); + assert_eq!(health.status, "ok"); + + let readiness = response_to_readiness( + send_request( + &mut client, + RequestFrame { request_type: RequestType::Readiness, sql: String::new() }, + ) + .await, + ); + assert!(readiness.ready); + assert_eq!(readiness.status, "ready"); + + let admin = response_to_admin_status( + send_request( + &mut client, + RequestFrame { request_type: RequestType::AdminStatus, sql: String::new() }, + ) + .await, + ); + assert_eq!(admin.protocol_version, PROTOCOL_VERSION); + assert_eq!(admin.server_version, env!("CARGO_PKG_VERSION")); + assert!(admin.accepting_connections); + assert!(admin.total_connections >= 1); + assert!(admin.active_connections >= 1); + assert_eq!(admin.rejected_connections, 0); + assert_eq!(admin.busy_requests, 0); + assert_eq!(admin.resource_limit_requests, 0); + assert_eq!(admin.quota_rejections, 0); + assert_eq!(admin.timed_out_requests, 0); + assert_eq!(admin.canceled_requests, 0); + assert_eq!(admin.active_statements, 0); + assert_eq!(admin.active_memory_intensive_requests, 0); + assert_eq!(admin.mvcc_started, 0); + assert_eq!(admin.mvcc_committed, 0); + assert_eq!(admin.mvcc_rolled_back, 0); + assert_eq!(admin.mvcc_write_conflicts, 0); + assert_eq!(admin.mvcc_active_transactions, 0); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn start_server_requires_explicit_security_configuration() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let err = match start_server(bind_addr, Arc::clone(&catalog), Arc::clone(&store)).await { + Ok(_) => panic!("default startup should be rejected"), + Err(err) => err, + }; + match err { + ServerError::InvalidConfiguration(message) => { + assert!(message.contains("allow_anonymous_access")); + } + other => panic!("expected invalid configuration, got {other:?}"), + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn server_rejects_connections_above_limit_and_keeps_existing_connection_responsive() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + let options = ServerOptions { + limits: ServerLimits { max_concurrent_connections: 1, ..ServerLimits::default() }, + ..insecure_server_options() + }; + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start server"); + let 
server_addr = server.local_addr(); + + let mut first = TcpStream::connect(server_addr).await.expect("connect first client"); + let mut second = TcpStream::connect(server_addr).await.expect("connect second client"); + + let rejected = + read_response(&mut second).await.expect("read busy response").expect("busy response"); + let error = response_to_error(rejected); + assert_eq!(error.code, ErrorCode::Busy); + assert!(error.retryable); + assert!(error.message.contains("max concurrent connections")); + + let health = response_to_health( + send_request( + &mut first, + RequestFrame { request_type: RequestType::Health, sql: String::new() }, + ) + .await, + ); + assert!(health.ok); + + let admin = response_to_admin_status( + send_request( + &mut first, + RequestFrame { request_type: RequestType::AdminStatus, sql: String::new() }, + ) + .await, + ); + assert_eq!(admin.rejected_connections, 1); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn server_rejects_oversized_request_frames_with_resource_limit_error() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + let options = ServerOptions { + limits: ServerLimits { max_request_bytes: 32, ..ServerLimits::default() }, + ..insecure_server_options() + }; + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start server"); + let server_addr = server.local_addr(); + + let mut client = TcpStream::connect(server_addr).await.expect("connect client"); + let error = response_to_error( + send_request( + &mut client, + RequestFrame { + request_type: RequestType::Query, + sql: "SELECT 1 FROM some_really_long_table_name_that_exceeds_the_limit".to_string(), + }, + ) + .await, + ); + assert_eq!(error.code, ErrorCode::ResourceLimit); + assert!(!error.retryable); + assert!(error.message.contains("request frame too large")); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn server_enforces_scan_and_sort_limits() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + let options = ServerOptions { + limits: ServerLimits { + max_scan_rows: 2, + max_sort_rows: 2, + max_query_result_rows: 2, + ..ServerLimits::default() + }, + ..insecure_server_options() + }; + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start server"); + let server_addr = server.local_addr(); + + let mut client = TcpStream::connect(server_addr).await.expect("connect client"); + let create_response = send_request( + &mut client, + RequestFrame { + request_type: RequestType::Query, + sql: "CREATE TABLE users (id BIGINT NOT NULL, email TEXT NOT NULL, PRIMARY KEY (id))" + .to_string(), + }, + ) + .await; + assert!(matches!(create_response, ResponseFrame::Ok(ResponsePayload::AffectedRows(0)))); + + for id in 1..=3 { + let insert = send_request( + &mut client, + RequestFrame { + request_type: RequestType::Query, + sql: format!("INSERT INTO users (id, email) VALUES ({id}, 'user{id}@x.com')"), + }, + ) + .await; + assert!(matches!(insert, ResponseFrame::Ok(ResponsePayload::AffectedRows(1)))); + } + 
+ let scan_error = response_to_error( + send_request( + &mut client, + RequestFrame { + request_type: RequestType::Query, + sql: "SELECT id, email FROM users".to_string(), + }, + ) + .await, + ); + assert_eq!(scan_error.code, ErrorCode::ResourceLimit); + assert!(scan_error.message.contains("scan rows")); + + let sort_error = response_to_error( + send_request( + &mut client, + RequestFrame { + request_type: RequestType::Query, + sql: "SELECT id FROM users ORDER BY id DESC".to_string(), + }, + ) + .await, + ); + assert_eq!(sort_error.code, ErrorCode::ResourceLimit); + assert!(sort_error.message.contains("scan rows") || sort_error.message.contains("sort rows")); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn server_enforces_statement_count_limit() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + let options = ServerOptions { + limits: ServerLimits { max_statements_per_request: 1, ..ServerLimits::default() }, + ..insecure_server_options() + }; + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start server"); + let server_addr = server.local_addr(); + + let mut client = TcpStream::connect(server_addr).await.expect("connect client"); + + let error = response_to_error( + send_request( + &mut client, + RequestFrame { + request_type: RequestType::Query, + sql: "BEGIN ISOLATION LEVEL SNAPSHOT; COMMIT".to_string(), + }, + ) + .await, + ); + assert_eq!(error.code, ErrorCode::ResourceLimit); + assert!(error.message.contains("statements")); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn server_times_out_long_running_scan_and_sort_queries() { + let _heavy_test_permit = acquire_heavy_server_test_permit().await; + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + populate_users(&catalog, &store, 25_000); + + let options = ServerOptions { + limits: ServerLimits { statement_timeout_ms: Some(1), ..ServerLimits::default() }, + ..insecure_server_options() + }; + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start server"); + let server_addr = server.local_addr(); + + let mut client = TcpStream::connect(server_addr).await.expect("connect client"); + + let scan_error = response_to_error( + send_request( + &mut client, + RequestFrame { + request_type: RequestType::Query, + sql: "SELECT id, email FROM users".to_string(), + }, + ) + .await, + ); + assert_eq!(scan_error.code, ErrorCode::Timeout); + assert!(scan_error.message.contains("timed out")); + + let sort_error = response_to_error( + send_request( + &mut client, + RequestFrame { + request_type: RequestType::Query, + sql: "SELECT id FROM users ORDER BY email DESC".to_string(), + }, + ) + .await, + ); + assert_eq!(sort_error.code, ErrorCode::Timeout); + assert!(sort_error.message.contains("timed out")); + + let admin = response_to_admin_status( + send_request( + &mut client, + RequestFrame { request_type: RequestType::AdminStatus, sql: String::new() }, + ) + .await, + ); + assert!(admin.timed_out_requests >= 2); + + 
server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn server_rejects_queries_when_identity_quota_is_reached() { + let _heavy_test_permit = acquire_heavy_server_test_permit().await; + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + populate_users(&catalog, &store, 250_000); + + let options = ServerOptions { + limits: ServerLimits { + max_concurrent_queries_per_identity: Some(1), + statement_timeout_ms: Some(5_000), + max_scan_rows: 300_000, + max_sort_rows: 300_000, + max_query_result_rows: 300_000, + ..ServerLimits::default() + }, + ..insecure_server_options() + }; + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start server"); + let server_addr = server.local_addr(); + + let mut query_client = TcpStream::connect(server_addr).await.expect("connect query client"); + write_request( + &mut query_client, + &RequestFrame { request_type: RequestType::Query, sql: "SELECT id FROM users".to_string() }, + ) + .await + .expect("send blocking query"); + + tokio::task::yield_now().await; + sleep(Duration::from_millis(10)).await; + + let mut second_client = TcpStream::connect(server_addr).await.expect("connect second client"); + let quota_error = response_to_error( + send_request( + &mut second_client, + RequestFrame { + request_type: RequestType::Query, + sql: "SELECT id FROM users LIMIT 1".to_string(), + }, + ) + .await, + ); + assert_eq!(quota_error.code, ErrorCode::Quota); + assert!(quota_error.retryable); + + drop(query_client); + + let mut admin_client = TcpStream::connect(server_addr).await.expect("connect admin client"); + let admin = response_to_admin_status( + send_request( + &mut admin_client, + RequestFrame { request_type: RequestType::AdminStatus, sql: String::new() }, + ) + .await, + ); + assert!(admin.quota_rejections >= 1); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn server_cancellation_rolls_back_active_transaction_state() { + let _heavy_test_permit = acquire_heavy_server_test_permit().await; + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + populate_users(&catalog, &store, 50_000); + + let options = ServerOptions { + limits: ServerLimits { + statement_timeout_ms: Some(5_000), + max_scan_rows: 100_000, + max_sort_rows: 100_000, + max_query_result_rows: 100_000, + ..ServerLimits::default() + }, + ..insecure_server_options() + }; + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start server"); + let server_addr = server.local_addr(); + + let mut txn_client = TcpStream::connect(server_addr).await.expect("connect txn client"); + let begin = send_request( + &mut txn_client, + RequestFrame { request_type: RequestType::Begin, sql: String::new() }, + ) + .await; + assert!(matches!( + begin, + ResponseFrame::Ok(ResponsePayload::TransactionState(TransactionState::Begun)) + )); + + let txn_task = tokio::spawn(async move { + let update = send_request( + &mut txn_client, + RequestFrame { + request_type: RequestType::Query, + sql: "SELECT id FROM users ORDER BY email DESC".to_string(), 
+ }, + ) + .await; + let commit = send_request( + &mut txn_client, + RequestFrame { request_type: RequestType::Commit, sql: String::new() }, + ) + .await; + (update, commit) + }); + + let mut admin_client = TcpStream::connect(server_addr).await.expect("connect admin client"); + let statement_id = wait_for_active_statement_id(&mut admin_client, "QUERY").await; + let cancel = response_to_statement_cancellation( + send_request( + &mut admin_client, + RequestFrame { + request_type: RequestType::CancelStatement, + sql: statement_id.to_string(), + }, + ) + .await, + ); + assert!(cancel.accepted); + + let (update, commit) = txn_task.await.expect("txn task"); + let update_error = response_to_error(update); + assert_eq!(update_error.code, ErrorCode::Canceled); + + let commit_error = response_to_error(commit); + assert_eq!(commit_error.code, ErrorCode::Execution); + assert!(commit_error.message.contains("no active transaction")); + + let mut verify_client = TcpStream::connect(server_addr).await.expect("connect verify client"); + let result = response_to_query( + send_request( + &mut verify_client, + RequestFrame { + request_type: RequestType::Query, + sql: "SELECT email FROM users WHERE id = 1".to_string(), + }, + ) + .await, + ); + assert_eq!(from_utf8(&result.rows[0][0]).expect("utf8 cell"), "user00001@example.com"); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn secure_server_rejects_unauthenticated_requests_before_any_command() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + let options = password_server_options(vec![StaticPasswordUser { + username: "admin".to_string(), + password: "secret".to_string(), + role: ServerRole::Admin, + }]); + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start secure server"); + let server_addr = server.local_addr(); + + let mut client = connect_tls_client(server_addr).await; + let error = response_to_error( + send_request( + &mut client, + RequestFrame { request_type: RequestType::Health, sql: String::new() }, + ) + .await, + ); + assert_eq!(error.code, ErrorCode::Unauthenticated); + assert!(error.message.contains("requires authentication")); + assert!( + read_response(&mut client).await.expect("read connection closure").is_none(), + "server should close the connection after an unauthenticated request" + ); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn secure_server_rejects_authenticated_mode_without_tls() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + let options = ServerOptions { + security: ServerSecurityOptions { + auth: ServerAuthOptions::StaticPassword { + users: vec![StaticPasswordUser { + username: "admin".to_string(), + password: "secret".to_string(), + role: ServerRole::Admin, + }], + }, + ..ServerSecurityOptions::default() + }, + ..ServerOptions::default() + }; + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let err = match start_server_with_options( + bind_addr, + Arc::clone(&catalog), + Arc::clone(&store), + options, + ) + .await + { + Ok(_) => panic!("auth without tls should be rejected"), + Err(err) => err, + }; + match err { + 
ServerError::InvalidConfiguration(message) => { + assert!(message.contains("authentication requires tls")); + } + other => panic!("expected invalid configuration, got {other:?}"), + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn secure_server_rejects_invalid_password_credentials() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + let options = password_server_options(vec![StaticPasswordUser { + username: "admin".to_string(), + password: "secret".to_string(), + role: ServerRole::Admin, + }]); + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start secure server"); + let server_addr = server.local_addr(); + + let mut client = connect_tls_client(server_addr).await; + let error = response_to_error( + send_request(&mut client, authentication_request_with_password("admin", "wrong")).await, + ); + assert_eq!(error.code, ErrorCode::Unauthenticated); + assert!(error.message.contains("invalid username or password")); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn secure_server_enforces_role_based_authorization() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + execute_setup_sql( + &catalog, + &store, + "CREATE TABLE users (id BIGINT NOT NULL, email TEXT NOT NULL, PRIMARY KEY (id))", + ); + execute_setup_sql( + &catalog, + &store, + "INSERT INTO users (id, email) VALUES (1, 'alice@example.com')", + ); + let options = password_server_options(vec![ + StaticPasswordUser { + username: "reader".to_string(), + password: "reader-secret".to_string(), + role: ServerRole::Reader, + }, + StaticPasswordUser { + username: "admin".to_string(), + password: "admin-secret".to_string(), + role: ServerRole::Admin, + }, + ]); + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start secure server"); + let server_addr = server.local_addr(); + + let mut reader = connect_tls_client(server_addr).await; + let auth = response_to_authentication( + send_request(&mut reader, authentication_request_with_password("reader", "reader-secret")) + .await, + ); + assert_eq!(auth.identity, "reader"); + assert_eq!(auth.role, "reader"); + + let select = response_to_query( + send_request( + &mut reader, + RequestFrame { + request_type: RequestType::Query, + sql: "SELECT email FROM users WHERE id = 1".to_string(), + }, + ) + .await, + ); + assert_eq!(from_utf8(&select.rows[0][0]).expect("utf8 cell"), "alice@example.com"); + + let update_error = response_to_error( + send_request( + &mut reader, + RequestFrame { + request_type: RequestType::Query, + sql: "UPDATE users SET email = 'blocked@example.com' WHERE id = 1".to_string(), + }, + ) + .await, + ); + assert_eq!(update_error.code, ErrorCode::PermissionDenied); + assert!(update_error.message.contains("role 'reader'")); + + let admin_status_error = response_to_error( + send_request( + &mut reader, + RequestFrame { request_type: RequestType::AdminStatus, sql: String::new() }, + ) + .await, + ); + assert_eq!(admin_status_error.code, ErrorCode::PermissionDenied); + + let mut admin = connect_tls_client(server_addr).await; 
+ let admin_auth = response_to_authentication( + send_request(&mut admin, authentication_request_with_password("admin", "admin-secret")) + .await, + ); + assert_eq!(admin_auth.role, "admin"); + let admin_status = response_to_admin_status( + send_request( + &mut admin, + RequestFrame { request_type: RequestType::AdminStatus, sql: String::new() }, + ) + .await, + ); + assert!(admin_status.accepting_connections); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn secure_server_supports_static_token_authentication() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + let options = token_server_options(vec![StaticTokenPrincipal { + label: "ingest-bot".to_string(), + token: "opaque-token".to_string(), + role: ServerRole::Writer, + }]); + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start token-auth server"); + let server_addr = server.local_addr(); + + let mut client = connect_tls_client(server_addr).await; + let auth = response_to_authentication( + send_request(&mut client, authentication_request_with_token("opaque-token")).await, + ); + assert_eq!(auth.identity, "ingest-bot"); + assert_eq!(auth.role, "writer"); + assert_eq!(auth.auth_scheme, "token"); + + let health = response_to_health( + send_request( + &mut client, + RequestFrame { request_type: RequestType::Health, sql: String::new() }, + ) + .await, + ); + assert!(health.ok); + + server.shutdown().await.expect("shutdown server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn secure_server_accepts_tls_connections_when_required() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + let options = password_server_options(vec![StaticPasswordUser { + username: "admin".to_string(), + password: "secret".to_string(), + role: ServerRole::Admin, + }]); + + let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr"); + let server = + start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options) + .await + .expect("start tls server"); + let server_addr = server.local_addr(); + + let mut client = connect_tls_client(server_addr).await; + let auth = response_to_authentication( + send_request(&mut client, authentication_request_with_password("admin", "secret")).await, + ); + assert_eq!(auth.identity, "admin"); + assert_eq!(auth.role, "admin"); + + let readiness = response_to_readiness( + send_request( + &mut client, + RequestFrame { request_type: RequestType::Readiness, sql: String::new() }, + ) + .await, + ); + assert!(readiness.ready); + + server.shutdown().await.expect("shutdown tls server"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn tls_required_mode_rejects_missing_certificate_files() { + let store = Arc::new(MvccStore::new()); + let catalog = Arc::new(Catalog::open((*store).clone()).expect("open catalog")); + let options = ServerOptions { + security: ServerSecurityOptions { + tls: ServerTlsOptions { + mode: ServerTlsMode::Required, + cert_path: Some(PathBuf::from("tests/fixtures/tls/missing.crt")), + key_path: Some(PathBuf::from(TLS_KEY_PATH)), + }, + allow_anonymous_access: true, + ..ServerSecurityOptions::default() + }, + ..insecure_server_options() + }; + 
+    let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse socket addr");
+    let result =
+        start_server_with_options(bind_addr, Arc::clone(&catalog), Arc::clone(&store), options)
+            .await;
+    let err = match result {
+        Ok(_) => panic!("missing certificate should fail"),
+        Err(err) => err,
+    };
+    match err {
+        ServerError::Tls(message) => assert!(message.contains("missing.crt")),
+        other => panic!("expected tls error, got {other:?}"),
+    }
+}
diff --git a/tests/integration/wal_recovery.rs b/tests/integration/wal_recovery.rs
index 2804a22..731cb9c 100644
--- a/tests/integration/wal_recovery.rs
+++ b/tests/integration/wal_recovery.rs
@@ -4,7 +4,7 @@ use std::path::PathBuf;
 use std::time::{SystemTime, UNIX_EPOCH};
 
 use lsmdb::storage::wal::{
-    SyncMode, WalReader, WalWriter, WalWriterOptions, BLOCK_SIZE_BYTES, DEFAULT_SEGMENT_SIZE_BYTES,
+    BLOCK_SIZE_BYTES, DEFAULT_SEGMENT_SIZE_BYTES, SyncMode, WalReader, WalWriter, WalWriterOptions,
 };
 
 fn test_dir(label: &str) -> PathBuf {
diff --git a/tests/sql_conformance.rs b/tests/sql_conformance.rs
new file mode 100644
index 0000000..5122bd1
--- /dev/null
+++ b/tests/sql_conformance.rs
@@ -0,0 +1,420 @@
+use std::collections::{BTreeMap, BTreeSet};
+use std::fs;
+use std::path::{Path, PathBuf};
+use std::time::{SystemTime, UNIX_EPOCH};
+
+use lsmdb::catalog::Catalog;
+use lsmdb::executor::{ExecutionResult, ExecutionSession, ScalarValue};
+use lsmdb::mvcc::MvccStore;
+use lsmdb::planner::plan_statement;
+use lsmdb::sql::{parse_statement, validate_statement};
+use serde::{Deserialize, Serialize};
+
+const FIXTURE_DIR: &str = "tests/conformance/sql";
+const REPORT_RELATIVE_PATH: &str = "sql-conformance/report.toml";
+const REPORT_SCHEMA_VERSION: u32 = 1;
+
+const REQUIRED_CATEGORIES: &[&str] = &[
+    "ddl.create_table",
+    "ddl.drop_table",
+    "dml.insert",
+    "query.select",
+    "dml.update",
+    "dml.delete",
+    "txn.begin",
+    "txn.commit",
+    "txn.rollback",
+    "errors.invalid_where_type",
+    "errors.unsupported_join",
+    "errors.ddl_in_explicit_txn",
+];
+
+#[derive(Debug, Clone, Deserialize)]
+#[serde(deny_unknown_fields)]
+struct ConformanceSuite {
+    suite_id: String,
+    description: String,
+    cases: Vec<ConformanceCase>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+#[serde(deny_unknown_fields)]
+struct ConformanceCase {
+    id: String,
+    category: String,
+    #[serde(default)]
+    tags: Vec<String>,
+    #[serde(default)]
+    setup_sql: Vec<String>,
+    sql: String,
+    expect: Expectation,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+#[serde(tag = "kind", rename_all = "snake_case", deny_unknown_fields)]
+enum Expectation {
+    AffectedRows { count: u64 },
+    Query { columns: Vec<String>, rows: Vec<Vec<String>> },
+    TransactionState { state: String },
+    ErrorContains { message: String },
+}
+
+#[derive(Debug, Clone)]
+struct CaseOutcome {
+    passed: bool,
+    details: String,
+}
+
+#[derive(Debug, Clone, Default)]
+struct SuiteStats {
+    description: String,
+    total_cases: usize,
+    passed_cases: usize,
+    failed_cases: usize,
+}
+
+#[derive(Debug, Clone, Serialize)]
+struct ConformanceReport {
+    schema_version: u32,
+    generated_unix_seconds: u64,
+    source_fixture_dir: String,
+    total_suites: usize,
+    total_cases: usize,
+    passed_cases: usize,
+    failed_cases: usize,
+    covered_categories: Vec<String>,
+    suite_summaries: Vec<SuiteSummary>,
+    failures: Vec<CaseFailure>,
+}
+
+#[derive(Debug, Clone, Serialize)]
+struct SuiteSummary {
+    suite_id: String,
+    description: String,
+    total_cases: usize,
+    passed_cases: usize,
+    failed_cases: usize,
+}
+
+#[derive(Debug, Clone, Serialize)]
+struct CaseFailure {
+    suite_id: String,
+    case_id: String,
+    category: String,
+    tags: Vec<String>,
+    sql: String,
+    details: String,
+}
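For reference, a fixture file under `tests/conformance/sql/` that satisfies this schema could look like the inline TOML in the sketch below; the suite id, case id, and SQL are invented for illustration. The internally tagged `kind` field selects the `Expectation` variant:

```rust
// Shape check only: parses an inline fixture with the structs above.
#[test]
fn example_fixture_matches_suite_schema() {
    let raw = r#"
suite_id = "dml.basic"
description = "insert a single row"

[[cases]]
id = "insert-single-row"
category = "dml.insert"
setup_sql = ["CREATE TABLE t (id BIGINT NOT NULL, PRIMARY KEY (id))"]
sql = "INSERT INTO t (id) VALUES (1)"

[cases.expect]
kind = "affected_rows"
count = 1
"#;
    let suite: ConformanceSuite = toml::from_str(raw).expect("fixture shape parses");
    assert_eq!(suite.cases[0].category, "dml.insert");
}
```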
+
+#[test]
+fn sql_conformance_suite_matches_documented_subset() {
+    let suites = load_suites(Path::new(FIXTURE_DIR)).expect("load SQL conformance fixtures");
+    assert!(!suites.is_empty(), "expected at least one SQL conformance fixture suite");
+
+    let mut covered_categories = BTreeSet::new();
+    let mut suite_stats = BTreeMap::<String, SuiteStats>::new();
+    let mut failures = Vec::<CaseFailure>::new();
+
+    let mut total_cases = 0_usize;
+    let mut passed_cases = 0_usize;
+
+    for suite in &suites {
+        let stats = suite_stats.entry(suite.suite_id.clone()).or_insert_with(|| SuiteStats {
+            description: suite.description.clone(),
+            ..SuiteStats::default()
+        });
+
+        for case in &suite.cases {
+            covered_categories.insert(case.category.clone());
+
+            total_cases = total_cases.saturating_add(1);
+            stats.total_cases = stats.total_cases.saturating_add(1);
+
+            let outcome = run_case(case);
+            if outcome.passed {
+                passed_cases = passed_cases.saturating_add(1);
+                stats.passed_cases = stats.passed_cases.saturating_add(1);
+            } else {
+                stats.failed_cases = stats.failed_cases.saturating_add(1);
+                failures.push(CaseFailure {
+                    suite_id: suite.suite_id.clone(),
+                    case_id: case.id.clone(),
+                    category: case.category.clone(),
+                    tags: case.tags.clone(),
+                    sql: case.sql.clone(),
+                    details: outcome.details,
+                });
+            }
+        }
+    }
+
+    let missing_categories = REQUIRED_CATEGORIES
+        .iter()
+        .filter(|required| !covered_categories.contains(**required))
+        .map(|value| (*value).to_string())
+        .collect::<Vec<_>>();
+
+    let report = ConformanceReport {
+        schema_version: REPORT_SCHEMA_VERSION,
+        generated_unix_seconds: now_unix_seconds(),
+        source_fixture_dir: FIXTURE_DIR.to_string(),
+        total_suites: suites.len(),
+        total_cases,
+        passed_cases,
+        failed_cases: total_cases.saturating_sub(passed_cases),
+        covered_categories: covered_categories.into_iter().collect(),
+        suite_summaries: suite_stats
+            .into_iter()
+            .map(|(suite_id, stats)| SuiteSummary {
+                suite_id,
+                description: stats.description,
+                total_cases: stats.total_cases,
+                passed_cases: stats.passed_cases,
+                failed_cases: stats.failed_cases,
+            })
+            .collect(),
+        failures,
+    };
+
+    write_report(&report).expect("write SQL conformance report artifact");
+
+    assert!(
+        missing_categories.is_empty(),
+        "SQL conformance fixture coverage is missing required categories: {}",
+        missing_categories.join(", ")
+    );
+
+    assert!(
+        report.failures.is_empty(),
+        "SQL conformance failures (see {}):\n{}",
+        report_path().display(),
+        report
+            .failures
+            .iter()
+            .map(|failure| {
+                format!(
+                    "- {}/{} [{}]: {}",
+                    failure.suite_id, failure.case_id, failure.category, failure.details
+                )
+            })
+            .collect::<Vec<_>>()
+            .join("\n")
+    );
+}
+
+fn run_case(case: &ConformanceCase) -> CaseOutcome {
+    let store = MvccStore::new();
+    let catalog = Catalog::open(store.clone()).expect("open catalog for conformance case");
+    let mut session = ExecutionSession::new(&catalog, &store);
+
+    for setup_sql in &case.setup_sql {
+        if let Err(err) = execute_sql(&mut session, &catalog, setup_sql) {
+            return CaseOutcome {
+                passed: false,
+                details: format!("setup statement failed (`{setup_sql}`): {err}"),
+            };
+        }
+    }
+
+    let execution = execute_sql(&mut session, &catalog, &case.sql);
+    evaluate_expectation(execution, &case.expect)
+}
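Each call to `run_case` builds a fresh `MvccStore` and `Catalog`, so cases cannot observe one another's tables or rows. A small sketch of the guarantee this buys, using the same helpers as above (the `AffectedRows(0)` result for DDL matches the server tests earlier in this diff):

```rust
// Illustration of per-case isolation, not a shipped conformance case.
#[test]
fn fresh_store_per_case_allows_table_name_reuse() {
    for _ in 0..2 {
        let store = MvccStore::new();
        let catalog = Catalog::open(store.clone()).expect("open catalog");
        let mut session = ExecutionSession::new(&catalog, &store);
        // Re-creating an identically named table succeeds each iteration
        // because nothing persists across stores.
        let created = execute_sql(
            &mut session,
            &catalog,
            "CREATE TABLE t (id BIGINT NOT NULL, PRIMARY KEY (id))",
        )
        .expect("create table in a fresh store");
        assert!(matches!(created, ExecutionResult::AffectedRows(0)));
    }
}
```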
+
+fn execute_sql(
+    session: &mut ExecutionSession<'_>,
+    catalog: &Catalog,
+    sql: &str,
+) -> Result<ExecutionResult, String> {
+    let statement = parse_statement(sql).map_err(|err| format!("parse error: {err}"))?;
+    validate_statement(catalog, &statement).map_err(|err| format!("validation error: {err}"))?;
+    let plan =
+        plan_statement(catalog, &statement).map_err(|err| format!("planner error: {err}"))?;
+    session.execute_plan(&plan).map_err(|err| format!("execution error: {err}"))
+}
+
+fn evaluate_expectation(
+    execution: Result<ExecutionResult, String>,
+    expect: &Expectation,
+) -> CaseOutcome {
+    match expect {
+        Expectation::AffectedRows { count } => match execution {
+            Ok(ExecutionResult::AffectedRows(actual)) if &actual == count => {
+                CaseOutcome { passed: true, details: "ok".to_string() }
+            }
+            Ok(other) => CaseOutcome {
+                passed: false,
+                details: format!("expected affected_rows={count}, got {other:?}"),
+            },
+            Err(err) => CaseOutcome {
+                passed: false,
+                details: format!("expected affected_rows={count}, got error: {err}"),
+            },
+        },
+        Expectation::Query { columns, rows } => match execution {
+            Ok(ExecutionResult::Query(query)) => {
+                let actual_rows = query
+                    .rows
+                    .iter()
+                    .map(|row| row.iter().map(scalar_to_string).collect::<Vec<_>>())
+                    .collect::<Vec<_>>();
+
+                if &query.columns == columns && &actual_rows == rows {
+                    CaseOutcome { passed: true, details: "ok".to_string() }
+                } else {
+                    CaseOutcome {
+                        passed: false,
+                        details: format!(
+                            "query mismatch: expected columns={columns:?} rows={rows:?}, got columns={:?} rows={actual_rows:?}",
+                            query.columns
+                        ),
+                    }
+                }
+            }
+            Ok(other) => CaseOutcome {
+                passed: false,
+                details: format!("expected query result, got {other:?}"),
+            },
+            Err(err) => CaseOutcome {
+                passed: false,
+                details: format!("expected query result, got error: {err}"),
+            },
+        },
+        Expectation::TransactionState { state } => match execution {
+            Ok(result) => {
+                if let Some(actual) = transaction_state_name(&result) {
+                    if state.eq_ignore_ascii_case(actual) {
+                        CaseOutcome { passed: true, details: "ok".to_string() }
+                    } else {
+                        CaseOutcome {
+                            passed: false,
+                            details: format!("expected transaction_state={state}, got {actual}"),
+                        }
+                    }
+                } else {
+                    CaseOutcome {
+                        passed: false,
+                        details: format!("expected transaction state, got {result:?}"),
+                    }
+                }
+            }
+            Err(err) => CaseOutcome {
+                passed: false,
+                details: format!("expected transaction state, got error: {err}"),
+            },
+        },
+        Expectation::ErrorContains { message } => match execution {
+            Ok(result) => CaseOutcome {
+                passed: false,
+                details: format!(
+                    "expected error containing '{message}', got successful result {result:?}"
+                ),
+            },
+            Err(err) => {
+                if err.to_ascii_lowercase().contains(&message.to_ascii_lowercase()) {
+                    CaseOutcome { passed: true, details: "ok".to_string() }
+                } else {
+                    CaseOutcome {
+                        passed: false,
+                        details: format!("expected error containing '{message}', got '{err}'"),
+                    }
+                }
+            }
+        },
+    }
+}
+
+fn transaction_state_name(result: &ExecutionResult) -> Option<&'static str> {
+    match result {
+        ExecutionResult::TransactionBegun => Some("begun"),
+        ExecutionResult::TransactionCommitted => Some("committed"),
+        ExecutionResult::TransactionRolledBack => Some("rolled_back"),
+        _ => None,
+    }
+}
+
+fn scalar_to_string(value: &ScalarValue) -> String {
+    match value {
+        ScalarValue::Integer(value) => value.to_string(),
+        ScalarValue::BigInt(value) => value.to_string(),
+        ScalarValue::Float(value) => value.to_string(),
+        ScalarValue::Text(value) => value.clone(),
+        ScalarValue::Boolean(value) => value.to_string(),
+        ScalarValue::Blob(bytes) => bytes_to_hex(bytes),
+        ScalarValue::Timestamp(value) => value.to_string(),
+        ScalarValue::Null => "NULL".to_string(),
+    }
+}
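The canonical cell encodings above are what fixture authors must write in `rows`. A few spot checks, assuming the helpers in this file (`0x2a0f` comes from `bytes_to_hex`, defined next):

```rust
// Expected string forms used for fixture comparison.
#[test]
fn scalar_canonicalization_spot_checks() {
    assert_eq!(scalar_to_string(&ScalarValue::BigInt(42)), "42");
    assert_eq!(scalar_to_string(&ScalarValue::Boolean(true)), "true");
    assert_eq!(scalar_to_string(&ScalarValue::Null), "NULL");
    assert_eq!(scalar_to_string(&ScalarValue::Blob(vec![0x2A, 0x0F])), "0x2a0f");
}
```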
+
+fn bytes_to_hex(bytes: &[u8]) -> String {
+    let mut out = String::with_capacity(bytes.len() * 2 + 2);
+    out.push_str("0x");
+    for byte in bytes {
+        out.push(hex_char(byte >> 4));
+        out.push(hex_char(byte & 0x0F));
+    }
+    out
+}
+
+fn hex_char(value: u8) -> char {
+    match value {
+        0..=9 => (b'0' + value) as char,
+        10..=15 => (b'a' + (value - 10)) as char,
+        _ => unreachable!(),
+    }
+}
+
+fn load_suites(path: &Path) -> Result<Vec<ConformanceSuite>, String> {
+    let mut files = fs::read_dir(path)
+        .map_err(|err| format!("failed to read fixture directory '{}': {err}", path.display()))?
+        .filter_map(Result::ok)
+        .filter(|entry| entry.path().extension().and_then(|ext| ext.to_str()) == Some("toml"))
+        .map(|entry| entry.path())
+        .collect::<Vec<_>>();
+    files.sort();
+
+    if files.is_empty() {
+        return Err(format!("no TOML fixture files found in {}", path.display()));
+    }
+
+    let mut suites = Vec::new();
+    for file in files {
+        let raw = fs::read_to_string(&file)
+            .map_err(|err| format!("failed reading fixture '{}': {err}", file.display()))?;
+        let suite: ConformanceSuite = toml::from_str(&raw)
+            .map_err(|err| format!("failed parsing fixture '{}': {err}", file.display()))?;
+        if suite.cases.is_empty() {
+            return Err(format!("fixture '{}' must include at least one case", file.display()));
+        }
+        suites.push(suite);
+    }
+
+    Ok(suites)
+}
+
+fn write_report(report: &ConformanceReport) -> Result<(), String> {
+    let report_path = report_path();
+    let parent = report_path.parent().ok_or_else(|| {
+        format!("failed to resolve parent directory for report path {}", report_path.display())
+    })?;
+
+    fs::create_dir_all(parent).map_err(|err| {
+        format!("failed to create report directory '{}': {err}", parent.display())
+    })?;
+
+    let serialized = toml::to_string_pretty(report)
+        .map_err(|err| format!("failed to serialize conformance report: {err}"))?;
+
+    fs::write(&report_path, serialized)
+        .map_err(|err| format!("failed to write report '{}': {err}", report_path.display()))
+}
+
+fn report_path() -> PathBuf {
+    let target_dir = std::env::var_os("CARGO_TARGET_DIR")
+        .map(PathBuf::from)
+        .unwrap_or_else(|| PathBuf::from("target"));
+    target_dir.join(REPORT_RELATIVE_PATH)
+}
+
+fn now_unix_seconds() -> u64 {
+    SystemTime::now().duration_since(UNIX_EPOCH).map(|d| d.as_secs()).unwrap_or(0)
+}
diff --git a/tests/sql_persistence.rs b/tests/sql_persistence.rs
new file mode 100644
index 0000000..8d38bdd
--- /dev/null
+++ b/tests/sql_persistence.rs
@@ -0,0 +1,135 @@
+use std::fs;
+use std::path::PathBuf;
+use std::time::{SystemTime, UNIX_EPOCH};
+
+use lsmdb::catalog::Catalog;
+use lsmdb::executor::{ExecutionResult, ExecutionSession, ScalarValue};
+use lsmdb::mvcc::MvccStore;
+use lsmdb::planner::plan_statement;
+use lsmdb::sql::{parse_statement, validate_statement};
+
+fn execute_sql(
+    session: &mut ExecutionSession<'_>,
+    catalog: &Catalog,
+    sql: &str,
+) -> ExecutionResult {
+    let statement = parse_statement(sql).expect("parse SQL");
+    validate_statement(catalog, &statement).expect("validate SQL");
+    let plan = plan_statement(catalog, &statement).expect("plan SQL");
+    session.execute_plan(&plan).expect("execute SQL")
+}
+
+fn test_dir(label: &str) -> PathBuf {
+    let mut path = std::env::temp_dir();
+    let nanos = SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .expect("system time should be after epoch")
+        .as_nanos();
+    path.push(format!("lsmdb-sql-persistence-{label}-{}-{nanos}", std::process::id()));
+    fs::create_dir_all(&path).expect("create temp dir");
+    path
+}
+
+#[test]
+fn sql_data_and_catalog_survive_restart_with_durable_mvcc_store() {
+    let dir = test_dir("restart");
+
+    {
+        let store = MvccStore::open_persistent(&dir).expect("open durable store");
+        let catalog =
Catalog::open(store.clone()).expect("open catalog"); + let mut session = ExecutionSession::new(&catalog, &store); + + execute_sql( + &mut session, + &catalog, + "CREATE TABLE users ( + id BIGINT NOT NULL, + email TEXT NOT NULL, + active BOOLEAN DEFAULT true, + PRIMARY KEY (id) + )", + ); + + assert!(matches!( + execute_sql( + &mut session, + &catalog, + "INSERT INTO users (id, email, active) VALUES (1, 'persisted@x.com', true)", + ), + ExecutionResult::AffectedRows(1) + )); + } + + { + let store = MvccStore::open_persistent(&dir).expect("reopen durable store"); + let catalog = Catalog::open(store.clone()).expect("reopen catalog"); + let mut session = ExecutionSession::new(&catalog, &store); + + let result = + execute_sql(&mut session, &catalog, "SELECT id, email FROM users WHERE id = 1"); + let ExecutionResult::Query(query) = result else { + panic!("expected query result"); + }; + assert_eq!( + query.rows, + vec![vec![ScalarValue::BigInt(1), ScalarValue::Text("persisted@x.com".to_string())]] + ); + } + + fs::remove_dir_all(dir).expect("cleanup temp dir"); +} + +#[test] +fn rolled_back_and_uncommitted_sql_writes_are_not_visible_after_restart() { + let dir = test_dir("rollback"); + + { + let store = MvccStore::open_persistent(&dir).expect("open durable store"); + let catalog = Catalog::open(store.clone()).expect("open catalog"); + let mut session = ExecutionSession::new(&catalog, &store); + + execute_sql( + &mut session, + &catalog, + "CREATE TABLE accounts ( + id BIGINT NOT NULL, + email TEXT NOT NULL, + PRIMARY KEY (id) + )", + ); + + execute_sql(&mut session, &catalog, "BEGIN ISOLATION LEVEL SNAPSHOT"); + execute_sql( + &mut session, + &catalog, + "INSERT INTO accounts (id, email) VALUES (10, 'rollback@x.com')", + ); + execute_sql(&mut session, &catalog, "ROLLBACK"); + + execute_sql(&mut session, &catalog, "BEGIN ISOLATION LEVEL SNAPSHOT"); + execute_sql( + &mut session, + &catalog, + "INSERT INTO accounts (id, email) VALUES (11, 'uncommitted@x.com')", + ); + // Intentionally drop the session without COMMIT to simulate crash before ack. + } + + { + let store = MvccStore::open_persistent(&dir).expect("reopen durable store"); + let catalog = Catalog::open(store.clone()).expect("reopen catalog"); + let mut session = ExecutionSession::new(&catalog, &store); + + let result = execute_sql( + &mut session, + &catalog, + "SELECT id FROM accounts WHERE id = 10 OR id = 11 ORDER BY id ASC", + ); + let ExecutionResult::Query(query) = result else { + panic!("expected query result"); + }; + assert!(query.rows.is_empty()); + } + + fs::remove_dir_all(dir).expect("cleanup temp dir"); +} diff --git a/tools/ci/run_integration_tests.sh b/tools/ci/run_integration_tests.sh new file mode 100755 index 0000000..32fdf76 --- /dev/null +++ b/tools/ci/run_integration_tests.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash +set -euo pipefail + +repo_root="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +cd "$repo_root" + +missing=0 +for file in tests/integration/*.rs; do + rel_path="${file#./}" + stem="$(basename "${file%.rs}")" + target="integration_${stem}" + + if ! grep -Fq "name = \"$target\"" Cargo.toml; then + echo "missing integration test target registration: $target ($rel_path)" + missing=1 + fi + + if ! 
grep -Fq "path = \"$rel_path\"" Cargo.toml; then
+    echo "missing integration test path registration: $rel_path"
+    missing=1
+  fi
+done
+
+if [[ "$missing" -ne 0 ]]; then
+  exit 1
+fi
+
+cargo test --tests --locked
diff --git a/tools/lsmdb-admin/main.rs b/tools/lsmdb-admin/main.rs
new file mode 100644
index 0000000..4b796b8
--- /dev/null
+++ b/tools/lsmdb-admin/main.rs
@@ -0,0 +1,733 @@
+use std::env;
+use std::fs;
+use std::io::{Read, Write};
+use std::path::{Path, PathBuf};
+use std::process::ExitCode;
+use std::time::{SystemTime, UNIX_EPOCH};
+
+use crc32fast::Hasher;
+use lsmdb::config::LsmdbConfig;
+
+const DEFAULT_CONFIG_PATH: &str = "lsmdb.toml";
+const DEFAULT_ENGINE_ROOT: &str = "data";
+const DEFAULT_OUTPUT_DIR: &str = "diagnostics";
+const DEFAULT_MAX_LOG_BYTES: u64 = 10 * 1024 * 1024;
+const BUNDLE_MANIFEST_VERSION: u32 = 1;
+
+const USAGE_EXIT_CODE: u8 = 64;
+const COMMAND_FAILED_EXIT_CODE: u8 = 2;
+
+fn main() -> ExitCode {
+    let args = env::args().skip(1).collect::<Vec<String>>();
+    match run(args) {
+        Ok(()) => ExitCode::SUCCESS,
+        Err(RunError::Usage(message)) => {
+            eprintln!("{message}");
+            ExitCode::from(USAGE_EXIT_CODE)
+        }
+        Err(RunError::CommandFailed { command, message }) => {
+            eprintln!("{command}.result=failed");
+            eprintln!("error={message}");
+            ExitCode::from(COMMAND_FAILED_EXIT_CODE)
+        }
+    }
+}
+
+#[derive(Debug)]
+enum RunError {
+    Usage(String),
+    CommandFailed { command: &'static str, message: String },
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+enum Command {
+    Help,
+    ConfigCheck { config_path: PathBuf },
+    DiagnosticsBundle(DiagnosticsBundleArgs),
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+struct DiagnosticsBundleArgs {
+    config_path: PathBuf,
+    engine_root: PathBuf,
+    output_dir: PathBuf,
+    log_dir: Option<PathBuf>,
+    max_log_bytes: u64,
+}
+
+#[derive(Debug, Clone, Copy, Default)]
+struct DirStats {
+    file_count: u64,
+    total_bytes: u64,
+}
+
+#[derive(Debug, Clone, Copy, Default)]
+struct StorageSnapshot {
+    wal: DirStats,
+    sst: DirStats,
+    manifest: DirStats,
+    total_files: u64,
+    total_bytes: u64,
+}
+
+#[derive(Debug, Clone, Copy, Default)]
+struct LogCaptureSummary {
+    copied_files: u64,
+    copied_bytes: u64,
+    truncated_files: u64,
+    skipped_files: u64,
+}
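Given `main`'s mapping above, usage mistakes exit with 64 (the sysexits-style `EX_USAGE` convention) while commands that run but fail exit with 2. A quick sketch of that contract against the parser defined below:

```rust
// Assumes the parser below: an unrecognized command is a usage error, which
// main() converts to USAGE_EXIT_CODE (64) rather than COMMAND_FAILED_EXIT_CODE.
#[test]
fn unknown_command_is_reported_as_usage_error() {
    let err = parse_command(&["frobnicate".to_string()])
        .expect_err("unknown commands must not parse");
    assert!(matches!(err, RunError::Usage(_)));
}
```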
+
+fn run(args: Vec<String>) -> Result<(), RunError> {
+    match parse_command(&args)? {
+        Command::Help => {
+            print_help();
+            Ok(())
+        }
+        Command::ConfigCheck { config_path } => run_config_check(&config_path),
+        Command::DiagnosticsBundle(args) => run_diagnostics_bundle(&args),
+    }
+}
+
+fn parse_command(args: &[String]) -> Result<Command, RunError> {
+    if args.is_empty() {
+        return Err(RunError::Usage(help_with_error("missing command")));
+    }
+
+    if args.len() == 1 && matches!(args[0].as_str(), "--help" | "-h" | "help") {
+        return Ok(Command::Help);
+    }
+
+    if args.len() >= 2 && args[0] == "config" && args[1] == "check" {
+        return parse_config_check_args(&args[2..]);
+    }
+
+    if args.len() >= 2 && args[0] == "diagnostics" && args[1] == "bundle" {
+        return parse_diagnostics_bundle_args(&args[2..]);
+    }
+
+    Err(RunError::Usage(help_with_error(&format!("unknown command: {}", args.join(" ")))))
+}
+
+fn parse_config_check_args(args: &[String]) -> Result<Command, RunError> {
+    let mut config_path = PathBuf::from(DEFAULT_CONFIG_PATH);
+    let mut index = 0;
+
+    while index < args.len() {
+        match args[index].as_str() {
+            "--config" => {
+                index += 1;
+                if index >= args.len() {
+                    return Err(RunError::Usage(help_with_error("missing value for --config")));
+                }
+                config_path = PathBuf::from(&args[index]);
+            }
+            "--help" | "-h" => {
+                return Ok(Command::Help);
+            }
+            unknown => {
+                return Err(RunError::Usage(help_with_error(&format!(
+                    "unknown option for config check: {unknown}"
+                ))));
+            }
+        }
+        index += 1;
+    }
+
+    Ok(Command::ConfigCheck { config_path })
+}
+
+fn parse_diagnostics_bundle_args(args: &[String]) -> Result<Command, RunError> {
+    let mut parsed = DiagnosticsBundleArgs {
+        config_path: PathBuf::from(DEFAULT_CONFIG_PATH),
+        engine_root: PathBuf::from(DEFAULT_ENGINE_ROOT),
+        output_dir: PathBuf::from(DEFAULT_OUTPUT_DIR),
+        log_dir: None,
+        max_log_bytes: DEFAULT_MAX_LOG_BYTES,
+    };
+
+    let mut index = 0;
+    while index < args.len() {
+        match args[index].as_str() {
+            "--config" => {
+                index += 1;
+                if index >= args.len() {
+                    return Err(RunError::Usage(help_with_error("missing value for --config")));
+                }
+                parsed.config_path = PathBuf::from(&args[index]);
+            }
+            "--engine-root" => {
+                index += 1;
+                if index >= args.len() {
+                    return Err(RunError::Usage(help_with_error(
+                        "missing value for --engine-root",
+                    )));
+                }
+                parsed.engine_root = PathBuf::from(&args[index]);
+            }
+            "--output-dir" => {
+                index += 1;
+                if index >= args.len() {
+                    return Err(RunError::Usage(help_with_error("missing value for --output-dir")));
+                }
+                parsed.output_dir = PathBuf::from(&args[index]);
+            }
+            "--log-dir" => {
+                index += 1;
+                if index >= args.len() {
+                    return Err(RunError::Usage(help_with_error("missing value for --log-dir")));
+                }
+                parsed.log_dir = Some(PathBuf::from(&args[index]));
+            }
+            "--max-log-bytes" => {
+                index += 1;
+                if index >= args.len() {
+                    return Err(RunError::Usage(help_with_error(
+                        "missing value for --max-log-bytes",
+                    )));
+                }
+                parsed.max_log_bytes = args[index].parse::<u64>().map_err(|err| {
+                    RunError::Usage(help_with_error(&format!(
+                        "invalid value for --max-log-bytes: {err}"
+                    )))
+                })?;
+            }
+            "--help" | "-h" => {
+                return Ok(Command::Help);
+            }
+            unknown => {
+                return Err(RunError::Usage(help_with_error(&format!(
+                    "unknown option for diagnostics bundle: {unknown}"
+                ))));
+            }
+        }
+        index += 1;
+    }
+
+    Ok(Command::DiagnosticsBundle(parsed))
+}
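A table-style check of the bundle flag parser above (paths and sizes invented; defaults come from the constants at the top of the file):

```rust
// Sketch: explicit overrides are honored and unset flags keep their defaults.
#[test]
fn diagnostics_bundle_flags_parse_with_defaults() {
    let args = vec![
        "--engine-root".to_string(),
        "/var/lib/lsmdb".to_string(),
        "--max-log-bytes".to_string(),
        "1048576".to_string(),
    ];
    match parse_diagnostics_bundle_args(&args).expect("flags parse") {
        Command::DiagnosticsBundle(parsed) => {
            assert_eq!(parsed.engine_root, PathBuf::from("/var/lib/lsmdb"));
            assert_eq!(parsed.max_log_bytes, 1_048_576);
            assert_eq!(parsed.config_path, PathBuf::from(DEFAULT_CONFIG_PATH));
            assert_eq!(parsed.log_dir, None);
        }
        other => panic!("expected diagnostics bundle command, got {other:?}"),
    }
}
```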
config.startup_diagnostics().map_err(|err| RunError::CommandFailed { + command: "config.check", + message: format!("{err} (path={})", config_path.display()), + })?; + + println!("config.check=ok"); + println!("config.path={}", config_path.display()); + for line in diagnostics.as_key_value_lines() { + println!("{line}"); + } + + Ok(()) +} + +fn run_diagnostics_bundle(args: &DiagnosticsBundleArgs) -> Result<(), RunError> { + let generated_at = now_unix_seconds(); + let bundle_id = format!("bundle-{generated_at}-{}", std::process::id()); + let bundle_dir = args.output_dir.join(&bundle_id); + + fs::create_dir_all(&bundle_dir).map_err(|err| RunError::CommandFailed { + command: "diagnostics.bundle", + message: format!("failed to create bundle directory '{}': {err}", bundle_dir.display()), + })?; + + let mut warnings = Vec::new(); + + write_lines( + &bundle_dir.join("build_info.kv"), + &[ + format!("bundle.generated_unix_seconds={generated_at}"), + format!("build.version={}", env!("CARGO_PKG_VERSION")), + format!("build.name={}", env!("CARGO_PKG_NAME")), + format!("build.target_os={}", env::consts::OS), + format!("build.target_arch={}", env::consts::ARCH), + ], + ) + .map_err(|err| RunError::CommandFailed { + command: "diagnostics.bundle", + message: format!("failed to write build info: {err}"), + })?; + + match fs::read_to_string(&args.config_path) { + Ok(raw) => { + let redacted = redact_sensitive_kv_lines(&raw); + write_string(&bundle_dir.join("config.redacted.toml"), &redacted).map_err(|err| { + RunError::CommandFailed { + command: "diagnostics.bundle", + message: format!("failed to write redacted config: {err}"), + } + })?; + } + Err(err) => warnings.push(format!( + "config.read_error=failed to read '{}': {err}", + args.config_path.display() + )), + } + + match LsmdbConfig::load_from_path(&args.config_path).and_then(|cfg| cfg.startup_diagnostics()) { + Ok(diag) => { + write_lines(&bundle_dir.join("startup_diagnostics.kv"), &diag.as_key_value_lines()) + .map_err(|err| RunError::CommandFailed { + command: "diagnostics.bundle", + message: format!("failed to write startup diagnostics: {err}"), + })?; + } + Err(err) => warnings.push(format!( + "config.diagnostics_error={} (path={})", + err, + args.config_path.display() + )), + } + + let snapshot = collect_storage_snapshot(&args.engine_root, &mut warnings); + write_lines( + &bundle_dir.join("storage_snapshot.kv"), + &[ + format!("engine_root.path={}", args.engine_root.display()), + format!("storage.wal.file_count={}", snapshot.wal.file_count), + format!("storage.wal.total_bytes={}", snapshot.wal.total_bytes), + format!("storage.sst.file_count={}", snapshot.sst.file_count), + format!("storage.sst.total_bytes={}", snapshot.sst.total_bytes), + format!("storage.manifest.file_count={}", snapshot.manifest.file_count), + format!("storage.manifest.total_bytes={}", snapshot.manifest.total_bytes), + format!("storage.total.file_count={}", snapshot.total_files), + format!("storage.total.total_bytes={}", snapshot.total_bytes), + ], + ) + .map_err(|err| RunError::CommandFailed { + command: "diagnostics.bundle", + message: format!("failed to write storage snapshot: {err}"), + })?; + + let log_summary = match &args.log_dir { + Some(log_dir) => { + capture_logs(log_dir, &bundle_dir.join("logs"), args.max_log_bytes, &mut warnings) + .map_err(|err| RunError::CommandFailed { + command: "diagnostics.bundle", + message: format!("failed while capturing logs: {err}"), + })? 
+ } + None => LogCaptureSummary::default(), + }; + + write_lines( + &bundle_dir.join("bundle_manifest.kv"), + &build_manifest_lines(generated_at, &bundle_id, args, snapshot, log_summary, &warnings), + ) + .map_err(|err| RunError::CommandFailed { + command: "diagnostics.bundle", + message: format!("failed to write bundle manifest: {err}"), + })?; + + let checksum_path = bundle_dir.join("checksums.crc32"); + write_checksums(&bundle_dir, &checksum_path).map_err(|err| RunError::CommandFailed { + command: "diagnostics.bundle", + message: format!("failed to write checksums: {err}"), + })?; + + println!("diagnostics.bundle=ok"); + println!("diagnostics.bundle_id={bundle_id}"); + println!("diagnostics.output_dir={}", bundle_dir.display()); + println!("diagnostics.warning_count={}", warnings.len()); + + Ok(()) +} + +fn collect_storage_snapshot(engine_root: &Path, warnings: &mut Vec<String>) -> StorageSnapshot { + let wal = collect_dir_stats(&engine_root.join("wal"), "wal", warnings); + let sst = collect_dir_stats(&engine_root.join("sst"), "sst", warnings); + let manifest = collect_dir_stats(&engine_root.join("manifest"), "manifest", warnings); + + StorageSnapshot { + wal, + sst, + manifest, + total_files: wal.file_count + sst.file_count + manifest.file_count, + total_bytes: wal.total_bytes + sst.total_bytes + manifest.total_bytes, + } +} + +fn collect_dir_stats(path: &Path, label: &str, warnings: &mut Vec<String>) -> DirStats { + if !path.exists() { + warnings.push(format!("storage.{label}.warning=directory not found: {}", path.display())); + return DirStats::default(); + } + + let files = match list_regular_files(path) { + Ok(files) => files, + Err(err) => { + warnings.push(format!( + "storage.{label}.warning=failed to list files at '{}': {err}", + path.display() + )); + return DirStats::default(); + } + }; + + let mut stats = DirStats::default(); + for file in files { + match fs::metadata(&file) { + Ok(metadata) => { + stats.file_count = stats.file_count.saturating_add(1); + stats.total_bytes = stats.total_bytes.saturating_add(metadata.len()); + } + Err(err) => warnings.push(format!( + "storage.{label}.warning=failed to stat '{}': {err}", + file.display() + )), + } + } + + stats +} + +fn capture_logs( + log_dir: &Path, + output_logs_dir: &Path, + max_log_bytes: u64, + warnings: &mut Vec<String>, +) -> std::io::Result<LogCaptureSummary> { + if !log_dir.exists() { + warnings.push(format!("logs.warning=directory not found: {}", log_dir.display())); + return Ok(LogCaptureSummary::default()); + } + + let mut files = list_regular_files(log_dir)?; + files.sort_by(|a, b| { + let a_time = fs::metadata(a).and_then(|m| m.modified()).ok(); + let b_time = fs::metadata(b).and_then(|m| m.modified()).ok(); + b_time.cmp(&a_time) + }); + + let mut summary = LogCaptureSummary::default(); + let mut remaining = max_log_bytes; + + for source_path in files { + if remaining == 0 { + break; + } + + let raw = match fs::read(&source_path) { + Ok(raw) => raw, + Err(_) => { + summary.skipped_files = summary.skipped_files.saturating_add(1); + continue; + } + }; + + let to_copy = remaining.min(raw.len() as u64) as usize; + if to_copy == 0 { + break; + } + + let clipped = &raw[..to_copy]; + let redacted = redact_sensitive_kv_lines(&String::from_utf8_lossy(clipped)); + + let relative = + source_path.strip_prefix(log_dir).unwrap_or(source_path.as_path()).to_path_buf(); + let destination = output_logs_dir.join(relative); + if let Some(parent) = destination.parent() { + fs::create_dir_all(parent)?; + } + fs::write(destination, redacted.as_bytes())?; +
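+ // Record the copy; a file counts as truncated when its source held more bytes than the remaining byte budget allowed.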
+ summary.copied_files = summary.copied_files.saturating_add(1); + summary.copied_bytes = summary.copied_bytes.saturating_add(to_copy as u64); + if raw.len() > to_copy { + summary.truncated_files = summary.truncated_files.saturating_add(1); + } + + remaining = remaining.saturating_sub(to_copy as u64); + } + + Ok(summary) +} + +fn build_manifest_lines( + generated_at: u64, + bundle_id: &str, + args: &DiagnosticsBundleArgs, + snapshot: StorageSnapshot, + log_summary: LogCaptureSummary, + warnings: &[String], +) -> Vec<String> { + let mut lines = vec![ + format!("manifest.version={BUNDLE_MANIFEST_VERSION}"), + format!("bundle.id={bundle_id}"), + format!("bundle.generated_unix_seconds={generated_at}"), + format!("bundle.tool=lsmdb-admin"), + format!("bundle.tool_version={}", env!("CARGO_PKG_VERSION")), + format!("bundle.config_path={}", args.config_path.display()), + format!("bundle.engine_root={}", args.engine_root.display()), + format!("bundle.output_dir={}", args.output_dir.display()), + format!("bundle.max_log_bytes={}", args.max_log_bytes), + format!( + "bundle.log_dir={}", + args.log_dir + .as_ref() + .map(|path| path.display().to_string()) + .unwrap_or_else(|| "".to_string()) + ), + format!("storage.total_files={}", snapshot.total_files), + format!("storage.total_bytes={}", snapshot.total_bytes), + format!("logs.copied_files={}", log_summary.copied_files), + format!("logs.copied_bytes={}", log_summary.copied_bytes), + format!("logs.truncated_files={}", log_summary.truncated_files), + format!("logs.skipped_files={}", log_summary.skipped_files), + format!("warnings.count={}", warnings.len()), + ]; + + for (index, warning) in warnings.iter().enumerate() { + lines.push(format!("warnings.{index}={warning}")); + } + + lines +} + +fn write_checksums(bundle_dir: &Path, checksum_path: &Path) -> std::io::Result<()> { + let mut files = list_regular_files(bundle_dir)?; + files.sort(); + + let mut output = String::new(); + for path in files { + if path == checksum_path { + continue; + } + let checksum = crc32_for_file(&path)?; + let relative = path.strip_prefix(bundle_dir).unwrap_or(path.as_path()); + output.push_str(&format!("{checksum:08x} {}\n", relative.display())); + } + + write_string(checksum_path, &output) +} + +fn crc32_for_file(path: &Path) -> std::io::Result<u32> { + let mut file = fs::File::open(path)?; + let mut hasher = Hasher::new(); + let mut buffer = [0_u8; 8192]; + + loop { + let read = file.read(&mut buffer)?; + if read == 0 { + break; + } + hasher.update(&buffer[..read]); + } + + Ok(hasher.finalize()) +} + +fn list_regular_files(root: &Path) -> std::io::Result<Vec<PathBuf>> { + let mut files = Vec::new(); + let mut stack = vec![root.to_path_buf()]; + + while let Some(path) = stack.pop() { + let metadata = fs::metadata(&path)?; + if metadata.is_file() { + files.push(path); + continue; + } +
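+ // Not a regular file: push the directory's entries so they are classified on later pops of the walk stack.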
+ for entry in fs::read_dir(&path)? { + let entry = entry?; + stack.push(entry.path()); + } + } + + Ok(files) +} + +fn write_lines(path: &Path, lines: &[String]) -> std::io::Result<()> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + + let mut file = fs::File::create(path)?; + for line in lines { + writeln!(file, "{line}")?; + } + + Ok(()) +} + +fn write_string(path: &Path, content: &str) -> std::io::Result<()> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + fs::write(path, content) +} + +fn now_unix_seconds() -> u64 { + SystemTime::now().duration_since(UNIX_EPOCH).map(|d| d.as_secs()).unwrap_or(0) +} + +fn redact_sensitive_kv_lines(raw: &str) -> String { + let mut output = String::new(); + for (index, line) in raw.lines().enumerate() { + if index > 0 { + output.push('\n'); + } + output.push_str(&redact_line(line)); + } + if raw.ends_with('\n') { + output.push('\n'); + } + output +} + +fn redact_line(line: &str) -> String { + if line.trim_start().starts_with('#') { + return line.to_string(); + } + + if let Some(index) = line.find('=') { + let key = line[..index].trim().to_ascii_lowercase(); + if is_sensitive_key(&key) { + return format!("{} <redacted>", &line[..=index]); + } + } + + if let Some(index) = line.find(':') { + let key = line[..index].trim().to_ascii_lowercase(); + if is_sensitive_key(&key) { + return format!("{} <redacted>", &line[..=index]); + } + } + + line.to_string() +} + +fn is_sensitive_key(key: &str) -> bool { + const KEYWORDS: [&str; 7] = + ["password", "secret", "token", "credential", "api_key", "private_key", "access_key"]; + + KEYWORDS.iter().any(|needle| key.contains(needle)) +} + +fn print_help() { + println!("lsmdb-admin: operational commands for lsmdb"); + println!(); + println!("Usage:"); + println!(" lsmdb-admin config check [--config PATH]"); + println!( + " lsmdb-admin diagnostics bundle [--config PATH] [--engine-root PATH] [--output-dir PATH] [--log-dir PATH] [--max-log-bytes BYTES]" + ); + println!(); + println!("Commands:"); + println!(" config check Validate config and print startup diagnostics"); + println!(" diagnostics bundle Build support bundle with redacted config and checksums"); + println!(); + println!("Options:"); + println!(" --config PATH Config file path (default: {DEFAULT_CONFIG_PATH})"); + println!(" --engine-root PATH Engine data root (default: {DEFAULT_ENGINE_ROOT})"); + println!(" --output-dir PATH Bundle output root (default: {DEFAULT_OUTPUT_DIR})"); + println!(" --log-dir PATH Optional log directory to include"); + println!( + " --max-log-bytes BYTES Max bytes copied from logs (default: {DEFAULT_MAX_LOG_BYTES})" + ); + println!(" --help, -h Show help"); + println!(); + println!("Exit codes:"); + println!(" 0 success"); + println!(" 2 command failed"); + println!(" 64 usage error"); +} + +fn help_with_error(message: &str) -> String { + format!( + "error: {message}\n\nUsage:\n lsmdb-admin config check [--config PATH]\n lsmdb-admin diagnostics bundle [--config PATH] [--engine-root PATH] [--output-dir PATH] [--log-dir PATH] [--max-log-bytes BYTES]\n lsmdb-admin --help" + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn parse(args: &[&str]) -> Result<Command, RunError> { + let values = args.iter().map(|value| (*value).to_string()).collect::<Vec<String>>(); + parse_command(&values) + } + + #[test] + fn parses_config_check_with_default_path() { + let command = parse(&["config", "check"]).expect("parse default config command"); + assert_eq!( + command, + Command::ConfigCheck { config_path: PathBuf::from(DEFAULT_CONFIG_PATH) } + ); + } + + #[test] + fn
parses_diagnostics_bundle_with_defaults() { + let command = parse(&["diagnostics", "bundle"]).expect("parse diagnostics bundle command"); + assert_eq!( + command, + Command::DiagnosticsBundle(DiagnosticsBundleArgs { + config_path: PathBuf::from(DEFAULT_CONFIG_PATH), + engine_root: PathBuf::from(DEFAULT_ENGINE_ROOT), + output_dir: PathBuf::from(DEFAULT_OUTPUT_DIR), + log_dir: None, + max_log_bytes: DEFAULT_MAX_LOG_BYTES, + }) + ); + } + + #[test] + fn parses_diagnostics_bundle_with_custom_args() { + let command = parse(&[ + "diagnostics", + "bundle", + "--config", + "./dev.toml", + "--engine-root", + "./data-dev", + "--output-dir", + "./diag", + "--log-dir", + "./logs", + "--max-log-bytes", + "2048", + ]) + .expect("parse diagnostics bundle with options"); + + assert_eq!( + command, + Command::DiagnosticsBundle(DiagnosticsBundleArgs { + config_path: PathBuf::from("./dev.toml"), + engine_root: PathBuf::from("./data-dev"), + output_dir: PathBuf::from("./diag"), + log_dir: Some(PathBuf::from("./logs")), + max_log_bytes: 2048, + }) + ); + } + + #[test] + fn rejects_unknown_option_for_diagnostics_bundle() { + let err = + parse(&["diagnostics", "bundle", "--bogus"]).expect_err("unknown option should fail"); + match err { + RunError::Usage(message) => { + assert!(message.contains("unknown option")); + } + other => panic!("expected usage error, got {other:?}"), + } + } + + #[test] + fn redacts_sensitive_key_value_lines() { + let input = "db_password = \"abc\"\nwal.segment_size_bytes = 4096\napi_key: secret\n"; + let output = redact_sensitive_kv_lines(input); + + assert!(output.contains("db_password = <redacted>")); + assert!(output.contains("api_key: <redacted>")); + assert!(output.contains("wal.segment_size_bytes = 4096")); + } +} diff --git a/tools/lsmdb-cli/main.rs b/tools/lsmdb-cli/main.rs index cc38b80..bfc7c38 100644 --- a/tools/lsmdb-cli/main.rs +++ b/tools/lsmdb-cli/main.rs @@ -1,14 +1,22 @@ use std::env; +use std::fs::File; +use std::io::BufReader; use std::io::{self, Write}; use std::net::SocketAddr; +use std::path::PathBuf; +use std::sync::Arc; use std::time::Instant; use lsmdb::observability::init_tracing_from_env; use lsmdb::server::{ QueryPayload, RequestFrame, RequestType, ResponseFrame, ResponsePayload, TransactionState, - read_response, write_request, + authentication_request_with_password, authentication_request_with_token, read_response, + write_request, }; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; +use tokio_rustls::TlsConnector; +use tokio_rustls::rustls::{ClientConfig, RootCertStore, pki_types::ServerName}; #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> { @@ -16,11 +24,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { eprintln!("warning: failed to initialize tracing: {err}"); } - let addr = parse_addr()?; - let mut stream = TcpStream::connect(addr).await?; + let options = parse_cli_options()?; + let mut stream = connect(&options).await?; - println!("Connected to lsmdb server at {addr}"); - println!("Type SQL to execute. Meta commands: \\help, \\q, \\timing, \\explain <sql>"); + if let Some(auth) = &options.auth { + authenticate(&mut *stream, auth).await?; + } + + println!("Connected to lsmdb server at {}", options.addr); + println!( + "Type SQL to execute. Meta commands: \\help, \\q, \\timing, \\explain <sql>, \\health, \\ready, \\status, \\queries, \\cancel <statement_id>" + ); let mut timing_enabled = false; let stdin = io::stdin(); @@ -41,16 +55,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { } if input.starts_with('\\') { - match handle_meta_command(input, &mut timing_enabled, &mut stream).await?
{ + match handle_meta_command(input, &mut timing_enabled, &mut *stream).await? { + ControlFlow::Continue => continue, + ControlFlow::Break => break, + } - continue; + } let request = request_from_sql(input); let start = Instant::now(); - let response = send_request(&mut stream, request).await?; + let response = send_request(&mut *stream, request).await?; let elapsed = start.elapsed(); render_response(response); @@ -68,9 +81,28 @@ enum ControlFlow { Break, } -fn parse_addr() -> Result<SocketAddr, Box<dyn std::error::Error>> { +#[derive(Debug, Clone, PartialEq, Eq)] +struct CliOptions { + addr: SocketAddr, + auth: Option<ClientAuth>, + tls_ca_cert: Option<PathBuf>, + tls_server_name: Option<String>, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum ClientAuth { + Password { username: String, password: String }, + Token { token: String }, +} + +fn parse_cli_options() -> Result<CliOptions, Box<dyn std::error::Error>> { let mut args = env::args().skip(1); let mut addr = "127.0.0.1:7878".to_string(); + let mut username = None; + let mut password = None; + let mut token = None; + let mut tls_ca_cert = None; + let mut tls_server_name = None; while let Some(arg) = args.next() { match arg.as_str() { @@ -80,6 +112,36 @@ fn parse_addr() -> Result<SocketAddr, Box<dyn std::error::Error>> { }; addr = value; } + "--user" => { + let Some(value) = args.next() else { + return Err("--user expects a value".into()); + }; + username = Some(value); + } + "--password" => { + let Some(value) = args.next() else { + return Err("--password expects a value".into()); + }; + password = Some(value); + } + "--token" => { + let Some(value) = args.next() else { + return Err("--token expects a value".into()); + }; + token = Some(value); + } + "--tls-ca-cert" => { + let Some(value) = args.next() else { + return Err("--tls-ca-cert expects a value".into()); + }; + tls_ca_cert = Some(PathBuf::from(value)); + } + "--tls-server-name" => { + let Some(value) = args.next() else { + return Err("--tls-server-name expects a value".into()); + }; + tls_server_name = Some(value); + } "--help" | "-h" => { print_help(); std::process::exit(0); @@ -90,13 +152,97 @@ fn parse_addr() -> Result<SocketAddr, Box<dyn std::error::Error>> { } } - Ok(addr.parse()?)
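+ // Resolve auth precedence: an explicit --token wins and excludes --user/--password, while --user falls back to the LSMDB_PASSWORD environment variable when --password is omitted.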
+ let auth = if let Some(token) = token { + if username.is_some() || password.is_some() { + return Err("--token cannot be combined with --user or --password".into()); + } + Some(ClientAuth::Token { token }) + } else if let Some(username) = username { + let password = match password { + Some(password) => password, + None => env::var("LSMDB_PASSWORD") + .map_err(|_| "--password or LSMDB_PASSWORD is required when --user is set")?, + }; + Some(ClientAuth::Password { username, password }) + } else if password.is_some() { + return Err("--password requires --user".into()); + } else { + None + }; + + Ok(CliOptions { addr: addr.parse()?, auth, tls_ca_cert, tls_server_name }) +} + +trait ClientIo: AsyncRead + AsyncWrite + Unpin + Send {} + +impl<T> ClientIo for T where T: AsyncRead + AsyncWrite + Unpin + Send {} + +async fn connect(options: &CliOptions) -> Result<Box<dyn ClientIo>, Box<dyn std::error::Error>> { + let tcp = TcpStream::connect(options.addr).await?; + let Some(ca_cert_path) = &options.tls_ca_cert else { + return Ok(Box::new(tcp)); + }; + + let mut root_store = RootCertStore::empty(); + let mut cert_reader = BufReader::new(File::open(ca_cert_path)?); + let certificates = rustls_pemfile::certs(&mut cert_reader).collect::<Result<Vec<_>, _>>()?; + if certificates.is_empty() { + return Err(format!( + "TLS CA bundle '{}' does not contain any certificates", + ca_cert_path.display() + ) + .into()); + } + for certificate in certificates { + root_store.add(certificate)?; + } + + let server_name = + options.tls_server_name.clone().unwrap_or_else(|| options.addr.ip().to_string()); + let server_name = ServerName::try_from(server_name.as_str())?.to_owned(); + let config = ClientConfig::builder().with_root_certificates(root_store).with_no_client_auth(); + let connector = TlsConnector::from(Arc::new(config)); + let tls = connector.connect(server_name, tcp).await?; + Ok(Box::new(tls)) +} + +async fn authenticate<S>( + stream: &mut S, + auth: &ClientAuth, +) -> Result<(), Box<dyn std::error::Error>> +where + S: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + let request = match auth { + ClientAuth::Password { username, password } => { + authentication_request_with_password(username.clone(), password.clone()) + } + ClientAuth::Token { token } => authentication_request_with_token(token.clone()), + }; +
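+ // Single authentication round-trip: only an Ok Authentication payload counts as success; error frames and unexpected payloads are surfaced as errors.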
+ match send_request(stream, request).await? { + ResponseFrame::Ok(ResponsePayload::Authentication(payload)) => { + println!( + "Authenticated as {} ({}) via {}", + payload.identity, payload.role, payload.auth_scheme + ); + Ok(()) + } + ResponseFrame::Err(error) => Err(format!( + "authentication failed [{}{}]: {}", + error.code.as_str(), + if error.retryable { ", retryable" } else { "" }, + error.message + ) + .into()), + other => Err(format!("unexpected authentication response: {other:?}").into()), + } } async fn handle_meta_command( input: &str, timing_enabled: &mut bool, - stream: &mut TcpStream, + stream: &mut (impl AsyncRead + AsyncWrite + Unpin + ?Sized), ) -> Result<ControlFlow, Box<dyn std::error::Error>> { if input == "\\q" || input == "\\quit" { return Ok(ControlFlow::Break); @@ -135,6 +281,90 @@ async fn handle_meta_command( return Ok(ControlFlow::Continue); } + if input == "\\health" { + let start = Instant::now(); + let response = send_request( + stream, + RequestFrame { request_type: RequestType::Health, sql: String::new() }, + ) + .await?; + let elapsed = start.elapsed(); + render_response(response); + if *timing_enabled { + println!("Time: {:.3} ms", elapsed.as_secs_f64() * 1000.0); + } + return Ok(ControlFlow::Continue); + } + + if input == "\\ready" { + let start = Instant::now(); + let response = send_request( + stream, + RequestFrame { request_type: RequestType::Readiness, sql: String::new() }, + ) + .await?; + let elapsed = start.elapsed(); + render_response(response); + if *timing_enabled { + println!("Time: {:.3} ms", elapsed.as_secs_f64() * 1000.0); + } + return Ok(ControlFlow::Continue); + } + + if input == "\\status" { + let start = Instant::now(); + let response = send_request( + stream, + RequestFrame { request_type: RequestType::AdminStatus, sql: String::new() }, + ) + .await?; + let elapsed = start.elapsed(); + render_response(response); + if *timing_enabled { + println!("Time: {:.3} ms", elapsed.as_secs_f64() * 1000.0); + } + return Ok(ControlFlow::Continue); + } + + if input == "\\queries" { + let start = Instant::now(); + let response = send_request( + stream, + RequestFrame { request_type: RequestType::ActiveStatements, sql: String::new() }, + ) + .await?; + let elapsed = start.elapsed(); + render_response(response); + if *timing_enabled { + println!("Time: {:.3} ms", elapsed.as_secs_f64() * 1000.0); + } + return Ok(ControlFlow::Continue); + } + + if let Some(statement_id) = input.strip_prefix("\\cancel") { + let statement_id = statement_id.trim(); + if statement_id.is_empty() { + println!("Usage: \\cancel <statement_id>"); + return Ok(ControlFlow::Continue); + } + + let start = Instant::now(); + let response = send_request( + stream, + RequestFrame { + request_type: RequestType::CancelStatement, + sql: statement_id.to_string(), + }, + ) + .await?; + let elapsed = start.elapsed(); + render_response(response); + if *timing_enabled { + println!("Time: {:.3} ms", elapsed.as_secs_f64() * 1000.0); + } + return Ok(ControlFlow::Continue); + }
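+ // No meta command matched: report it instead of forwarding a backslash command to the server as SQL.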
+ println!("Unknown command: {input}. Use \\help for available commands."); Ok(ControlFlow::Continue) } @@ -154,10 +384,13 @@ fn request_from_sql(sql: &str) -> RequestFrame { } } -async fn send_request( - stream: &mut TcpStream, +async fn send_request<S>( + stream: &mut S, request: RequestFrame, -) -> Result<ResponseFrame, Box<dyn std::error::Error>> { +) -> Result<ResponseFrame, Box<dyn std::error::Error>> +where + S: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ write_request(stream, &request).await?; let response = read_response(stream).await?.ok_or_else(|| "server closed connection".to_string())?; @@ -175,9 +408,78 @@ fn render_response(response: ResponseFrame) { TransactionState::RolledBack => println!("Transaction rolled back"), }, ResponsePayload::ExplainPlan(plan) => println!("{plan}"), + ResponsePayload::Health(health) => { + println!( + "health: {} ({})", + if health.ok { "ok" } else { "unhealthy" }, + health.status + ) + } + ResponsePayload::Readiness(readiness) => { + println!( + "readiness: {} ({})", + if readiness.ready { "ready" } else { "not-ready" }, + readiness.status + ) + } + ResponsePayload::AdminStatus(status) => { + println!("server_version: {}", status.server_version); + println!("protocol_version: {}", status.protocol_version); + println!("uptime_seconds: {}", status.uptime_seconds); + println!("accepting_connections: {}", status.accepting_connections); + println!("active_connections: {}", status.active_connections); + println!("total_connections: {}", status.total_connections); + println!("rejected_connections: {}", status.rejected_connections); + println!("busy_requests: {}", status.busy_requests); + println!("resource_limit_requests: {}", status.resource_limit_requests); + println!("quota_rejections: {}", status.quota_rejections); + println!("timed_out_requests: {}", status.timed_out_requests); + println!("canceled_requests: {}", status.canceled_requests); + println!("active_statements: {}", status.active_statements); + println!( + "active_memory_intensive_requests: {}", + status.active_memory_intensive_requests + ); + println!("mvcc_started: {}", status.mvcc_started); + println!("mvcc_committed: {}", status.mvcc_committed); + println!("mvcc_rolled_back: {}", status.mvcc_rolled_back); + println!("mvcc_write_conflicts: {}", status.mvcc_write_conflicts); + println!("mvcc_active_transactions: {}", status.mvcc_active_transactions); + } + ResponsePayload::ActiveStatements(payload) => { + if payload.statements.is_empty() { + println!("No active statements"); + } else { + for statement in payload.statements { + println!("statement_id: {}", statement.statement_id); + println!("connection_id: {}", statement.connection_id); + println!("identity: {}", statement.identity); + println!("request_type: {}", statement.request_type); + println!("runtime_ms: {}", statement.runtime_ms); + println!("cancel_requested: {}", statement.cancel_requested); + println!("sql_preview: {}", statement.sql_preview); + println!(); + } + } + } + ResponsePayload::StatementCancellation(payload) => { + println!("statement_id: {}", payload.statement_id); + println!("accepted: {}", payload.accepted); + println!("status: {}", payload.status); + } + ResponsePayload::Authentication(payload) => { + println!("authenticated_identity: {}", payload.identity); + println!("authenticated_role: {}", payload.role); + println!("auth_scheme: {}", payload.auth_scheme); + } }, - ResponseFrame::Err(message) => { - eprintln!("Error: {message}"); + ResponseFrame::Err(error) => { + eprintln!( + "Error [{}{}]: {}", + error.code.as_str(), + if error.retryable { ", retryable" } else { "" }, + error.message + ); } } } @@ -256,10 +558,22 @@ fn
hex_char(value: u8) -> char { } fn print_help() { - println!("Usage: lsmdb-cli [--addr HOST:PORT]"); + println!("Usage: lsmdb-cli [--addr HOST:PORT] [--user NAME --password VALUE] [--token VALUE]"); println!("Meta commands:"); println!(" \\help Show this help"); println!(" \\q | \\quit Exit CLI"); println!(" \\timing Toggle query timing display"); println!(" \\explain <sql> Print physical plan without executing SQL"); + println!(" \\health Request liveness status"); + println!(" \\ready Request readiness status"); + println!(" \\status Request admin runtime diagnostics"); + println!(" \\queries List active statements"); + println!(" \\cancel <statement_id> Signal cancellation for an active statement"); + println!("Auth options:"); + println!(" --user NAME Authenticate with static username/password"); + println!(" --password VALUE Password for --user (or set LSMDB_PASSWORD)"); + println!(" --token VALUE Authenticate with a static token"); + println!("TLS options:"); + println!(" --tls-ca-cert PATH Enable TLS and trust the PEM CA/cert at PATH"); + println!(" --tls-server-name N Override the TLS server name (default: addr IP)"); } diff --git a/tools/release/check_critical_blockers.sh b/tools/release/check_critical_blockers.sh new file mode 100755 index 0000000..794ad9e --- /dev/null +++ b/tools/release/check_critical_blockers.sh @@ -0,0 +1,93 @@ +#!/usr/bin/env bash +set -euo pipefail + +repo="${1:-${GITHUB_REPOSITORY:-}}" +if [[ -z "$repo" ]]; then + echo "error: repository not provided. Pass '<owner>/<repo>' or set GITHUB_REPOSITORY." + exit 2 +fi + +token="${GITHUB_TOKEN:-${GH_TOKEN:-}}" +if [[ -z "$token" ]] && command -v gh >/dev/null 2>&1; then + token="$(gh auth token 2>/dev/null || true)" +fi +if [[ -z "$token" ]]; then + echo "error: missing token. Set GITHUB_TOKEN/GH_TOKEN or authenticate with gh."
+ exit 2 +fi + +fetch_open_high_priority_issues() { + local page=1 + local all="[]" + while :; do + local url="https://api.github.com/repos/${repo}/issues?state=open&labels=priority/high&per_page=100&page=${page}" + local response + response="$( + curl -fsSL \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${token}" \ + "$url" + )" + local count + count="$(echo "$response" | jq 'length')" + if [[ "$count" -eq 0 ]]; then + break + fi + + all="$(jq -s 'add' <(echo "$all") <(echo "$response"))" + page=$((page + 1)) + done + echo "$all" +} + +all_open_high_priority="$(fetch_open_high_priority_issues)" + +critical_blockers="$( + echo "$all_open_high_priority" | jq '[ + .[] + | select(has("pull_request") | not) + | select( + ([.labels[].name] | index("area/security")) != null + or ([.labels[].name] | index("area/recovery")) != null + or ([.labels[].name] | index("area/release")) != null + or ([.labels[].name] | index("area/ops")) != null + or ([.labels[].name] | index("area/server")) != null + or ([.labels[].name] | index("area/storage")) != null + or ([.labels[].name] | index("area/performance")) != null + ) + | { + number: .number, + title: .title, + url: .html_url, + labels: [.labels[].name] + } + ]' +)" + +blocker_count="$(echo "$critical_blockers" | jq 'length')" + +if [[ -n "${GITHUB_STEP_SUMMARY:-}" ]]; then + { + echo "### Release Gate: Critical Open Issues" + echo + if [[ "$blocker_count" -eq 0 ]]; then + echo "- status: PASS" + echo "- critical open issues: 0" + else + echo "- status: FAIL" + echo "- critical open issues: $blocker_count" + echo + echo "| Issue | Title | Labels |" + echo "| --- | --- | --- |" + echo "$critical_blockers" | jq -r '.[] | "| [#\(.number)](\(.url)) | \(.title) | \(.labels | join(", ")) |"' + fi + } >>"$GITHUB_STEP_SUMMARY" +fi + +if [[ "$blocker_count" -ne 0 ]]; then + echo "release gate failed: found $blocker_count critical high-priority open issue(s)." + echo "$critical_blockers" | jq -r '.[] | "- #\(.number): \(.title) (\(.url))"' + exit 1 +fi + +echo "release gate passed: no critical high-priority open issues."