diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml
index 0dfbcbdf26..d1b8ff4634 100644
--- a/.github/workflows/examples.yml
+++ b/.github/workflows/examples.yml
@@ -27,7 +27,7 @@ jobs:
           --bin sqlx
           --release
           --no-default-features
-          --features mysql,postgres,sqlite
+          --features mysql,postgres,sqlite,sqlx-toml

       - uses: actions/upload-artifact@v4
         with:
@@ -175,6 +175,49 @@ jobs:
           DATABASE_URL: postgres://postgres:password@localhost:5432/mockable-todos
         run: cargo run -p sqlx-example-postgres-mockable-todos

+      - name: Multi-Database (Setup)
+        working-directory: examples/postgres/multi-database
+        env:
+          DATABASE_URL: postgres://postgres:password@localhost:5432/multi-database
+          ACCOUNTS_DATABASE_URL: postgres://postgres:password@localhost:5432/multi-database-accounts
+          PAYMENTS_DATABASE_URL: postgres://postgres:password@localhost:5432/multi-database-payments
+        run: |
+          (cd accounts && sqlx db setup)
+          (cd payments && sqlx db setup)
+          sqlx db setup
+
+      - name: Multi-Database (Run)
+        env:
+          DATABASE_URL: postgres://postgres:password@localhost:5432/multi-database
+          ACCOUNTS_DATABASE_URL: postgres://postgres:password@localhost:5432/multi-database-accounts
+          PAYMENTS_DATABASE_URL: postgres://postgres:password@localhost:5432/multi-database-payments
+        run: cargo run -p sqlx-example-postgres-multi-database
+
+      - name: Multi-Tenant (Setup)
+        working-directory: examples/postgres/multi-tenant
+        env:
+          DATABASE_URL: postgres://postgres:password@localhost:5432/multi-tenant
+        run: |
+          (cd accounts && sqlx db setup)
+          (cd payments && sqlx migrate run)
+          sqlx migrate run
+
+      - name: Multi-Tenant (Run)
+        env:
+          DATABASE_URL: postgres://postgres:password@localhost:5432/multi-tenant
+        run: cargo run -p sqlx-example-postgres-multi-tenant
+
+      - name: Preferred-Crates (Setup)
+        working-directory: examples/postgres/preferred-crates
+        env:
+          DATABASE_URL: postgres://postgres:password@localhost:5432/preferred-crates
+        run: sqlx db setup
+
+      - name: Preferred-Crates (Run)
+        env:
+          DATABASE_URL: postgres://postgres:password@localhost:5432/preferred-crates
+        run: cargo run -p sqlx-example-postgres-preferred-crates
+
       - name: TODOs (Setup)
         working-directory: examples/postgres/todos
         env:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 65ac125096..9036a38d09 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,16 +13,42 @@ This section will be replaced in subsequent alpha releases. See the Git history

 ### Breaking

-* [[#3821]] Groundwork for 0.9.0-alpha.1
-  * Increased MSRV to 1.86 and set rust-version [@abonander]
+* [[#3821]]: Groundwork for 0.9.0-alpha.1 [[@abonander]]
+  * Increased MSRV to 1.86 and set rust-version
   * Deleted deprecated combination runtime+TLS features (e.g. `runtime-tokio-native-tls`)
   * Deleted re-export of unstable `TransactionManager` trait in `sqlx`.
     * Not technically a breaking change because it's `#[doc(hidden)]`, but [it _will_ break SeaORM][seaorm-2600]
       if not proactively fixed.
+* [[#3383]]: feat: create `sqlx.toml` format [[@abonander]]
+  * SQLx and `sqlx-cli` now support per-crate configuration files (`sqlx.toml`)
+  * New functionality includes, but is not limited to:
+    * Rename `DATABASE_URL` for a crate (for multi-database workspaces)
+    * Set global type overrides for the macros (supporting custom types)
+    * Rename or relocate the `_sqlx_migrations` table (for multiple crates using the same database)
+    * Set characters to ignore when hashing migrations (e.g. ignore whitespace)
+    * More to be implemented in future releases.
+  * Enable feature `sqlx-toml` to use.
+ * `sqlx-cli` has it enabled by default, but `sqlx` does **not**. + * Default features of library crates can be hard to completely turn off because of [feature unification], + so it's better to keep the default feature set as limited as possible. + [This is something we learned the hard way.][preferred-crates] + * Guide: see `sqlx::_config` module in documentation. + * Reference: [[Link](sqlx-core/src/config/reference.toml)] + * Examples (written for Postgres but can be adapted to other databases; PRs welcome!): + * Multiple databases using `DATABASE_URL` renaming and global type overrides: [[Link](examples/postgres/multi-database)] + * Multi-tenant database using `_sqlx_migrations` renaming and multiple schemas: [[Link](examples/postgres/multi-tenant)] + * Force use of `chrono` when `time` is enabled (e.g. when using `tower-sessions-sqlx-store`): [[Link][preferred-crates]] + * Forcing `bigdecimal` when `rust_decimal` is enabled is also shown, but problems with `chrono`/`time` are more common. + * **Breaking changes**: + * Significant changes to the `Migrate` trait + * `sqlx::migrate::resolve_blocking()` is now `#[doc(hidden)]` and thus SemVer-exempt. [seaorm-2600]: https://github.com/SeaQL/sea-orm/issues/2600 +[feature unification]: https://doc.rust-lang.org/cargo/reference/features.html#feature-unification +[preferred-crates]: examples/postgres/preferred-crates -[#3821]: https://github.com/launchbadge/sqlx/pull/3830 +[#3821]: https://github.com/launchbadge/sqlx/pull/3821 +[#3383]: https://github.com/launchbadge/sqlx/pull/3383 ## 0.8.6 - 2025-05-19 diff --git a/Cargo.lock b/Cargo.lock index bb4bf14198..71116dd15d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,18 +4,18 @@ version = 4 [[package]] name = "addr2line" -version = "0.24.2" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" dependencies = [ "gimli", ] [[package]] -name = "adler2" -version = "2.0.0" +name = "adler" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "ahash" @@ -127,7 +127,19 @@ checksum = "db4ce4441f99dbd377ca8a8f57b698c44d0d6e712d8329b5040da5a64aa1ce73" dependencies = [ "base64ct", "blake2", - "password-hash", + "password-hash 0.4.2", +] + +[[package]] +name = "argon2" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072" +dependencies = [ + "base64ct", + "blake2", + "cpufeatures", + "password-hash 0.5.0", ] [[package]] @@ -438,17 +450,17 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.74" +version = "0.3.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" dependencies = [ "addr2line", + "cc", "cfg-if", "libc", "miniz_oxide", "object", "rustc-demangle", - "windows-targets 0.52.6", ] [[package]] @@ -742,8 +754,10 @@ checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" dependencies = [ "android-tzdata", "iana-time-zone", + "js-sys", "num-traits", "serde", + "wasm-bindgen", "windows-targets 0.52.6", 
] @@ -844,6 +858,33 @@ dependencies = [ "cc", ] +[[package]] +name = "color-eyre" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55146f5e46f237f7423d74111267d4597b59b0dad0ffaf7303bce9945d843ad5" +dependencies = [ + "backtrace", + "color-spantrace", + "eyre", + "indenter", + "once_cell", + "owo-colors", + "tracing-error", +] + +[[package]] +name = "color-spantrace" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd6be1b2a7e382e2b98b43b2adcca6bb0e465af0bdd38123873ae61eb17a72c2" +dependencies = [ + "once_cell", + "owo-colors", + "tracing-core", + "tracing-error", +] + [[package]] name = "colorchoice" version = "1.0.3" @@ -1276,6 +1317,16 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "eyre" +version = "0.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cd915d99f24784cdc19fd37ef22b97e3ff0ae756c7e492e9fbfe897d61e2aec" +dependencies = [ + "indenter", + "once_cell", +] + [[package]] name = "fastrand" version = "1.9.0" @@ -1528,9 +1579,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.31.1" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "glob" @@ -1899,6 +1950,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb56e1aa765b4b4f3aadfab769793b7087bb03a4ea4920644a6d238e2df5b9ed" +[[package]] +name = "indenter" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" + [[package]] name = "indexmap" version = "1.9.3" @@ -2200,11 +2257,11 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.2" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" +checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" dependencies = [ - "adler2", + "adler", ] [[package]] @@ -2303,6 +2360,16 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -2368,9 +2435,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.7" +version = "0.32.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" dependencies = [ "memchr", ] @@ -2441,6 +2508,18 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + +[[package]] +name = "owo-colors" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" + [[package]] name = "parking" version = "2.2.1" @@ -2481,6 +2560,17 @@ dependencies = [ "subtle", ] +[[package]] +name = "password-hash" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" +dependencies = [ + "base64ct", + "rand_core", + "subtle", +] + [[package]] name = "paste" version = "1.0.15" @@ -3167,18 +3257,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.217" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" +checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.217" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" +checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" dependencies = [ "proc-macro2", "quote", @@ -3277,6 +3367,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shell-words" version = "1.1.0" @@ -3490,6 +3589,7 @@ dependencies = [ "time", "tokio", "tokio-stream", + "toml", "tracing", "url", "uuid", @@ -3512,7 +3612,7 @@ name = "sqlx-example-postgres-axum-social" version = "0.1.0" dependencies = [ "anyhow", - "argon2", + "argon2 0.4.1", "axum", "dotenvy", "once_cell", @@ -3590,6 +3690,123 @@ dependencies = [ "tokio", ] +[[package]] +name = "sqlx-example-postgres-multi-database" +version = "0.9.0-alpha.1" +dependencies = [ + "color-eyre", + "dotenvy", + "rand", + "rust_decimal", + "sqlx", + "sqlx-example-postgres-multi-database-accounts", + "sqlx-example-postgres-multi-database-payments", + "tokio", + "tracing-subscriber", +] + +[[package]] +name = "sqlx-example-postgres-multi-database-accounts" +version = "0.1.0" +dependencies = [ + "argon2 0.5.3", + "password-hash 0.5.0", + "rand", + "serde", + "sqlx", + "thiserror 1.0.69", + "time", + "tokio", + "uuid", +] + +[[package]] +name = "sqlx-example-postgres-multi-database-payments" +version = "0.1.0" +dependencies = [ + "rust_decimal", + "sqlx", + "sqlx-example-postgres-multi-database-accounts", + "time", + "uuid", +] + +[[package]] +name = "sqlx-example-postgres-multi-tenant" +version = "0.9.0-alpha.1" +dependencies = [ + "color-eyre", + "dotenvy", + "rand", + "rust_decimal", + "sqlx", + "sqlx-example-postgres-multi-tenant-accounts", + "sqlx-example-postgres-multi-tenant-payments", + "tokio", + "tracing-subscriber", +] + +[[package]] +name = "sqlx-example-postgres-multi-tenant-accounts" +version = "0.1.0" +dependencies = [ + "argon2 0.5.3", + "password-hash 0.5.0", + "rand", + "serde", + "sqlx", + "thiserror 1.0.69", + "time", + "tokio", + "uuid", +] + +[[package]] +name = "sqlx-example-postgres-multi-tenant-payments" +version = "0.1.0" +dependencies = [ + "rust_decimal", + "sqlx", + "sqlx-example-postgres-multi-tenant-accounts", + "time", + "uuid", +] + +[[package]] +name = "sqlx-example-postgres-preferred-crates" +version = "0.9.0-alpha.1" +dependencies = [ + "anyhow", + "chrono", + "dotenvy", + "serde", + "sqlx", + 
"sqlx-example-postgres-preferred-crates-uses-rust-decimal", + "sqlx-example-postgres-preferred-crates-uses-time", + "tokio", + "uuid", +] + +[[package]] +name = "sqlx-example-postgres-preferred-crates-uses-rust-decimal" +version = "0.9.0-alpha.1" +dependencies = [ + "chrono", + "rust_decimal", + "sqlx", + "uuid", +] + +[[package]] +name = "sqlx-example-postgres-preferred-crates-uses-time" +version = "0.9.0-alpha.1" +dependencies = [ + "serde", + "sqlx", + "time", + "uuid", +] + [[package]] name = "sqlx-example-postgres-todos" version = "0.1.0" @@ -4061,6 +4278,16 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if", + "once_cell", +] + [[package]] name = "time" version = "0.3.37" @@ -4278,6 +4505,42 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-error" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b1581020d7a273442f5b45074a6a57d5757ad0a47dac0e9f0bd57b81936f3db" +dependencies = [ + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +dependencies = [ + "nu-ansi-term", + "sharded-slab", + "smallvec", + "thread_local", + "tracing-core", + "tracing-log", ] [[package]] @@ -4406,9 +4669,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.11.1" +version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b913a3b5fe84142e269d63cc62b64319ccaf89b748fc31fe025177f767a756c4" +checksum = "e0f540e3240398cce6128b64ba83fdbdd86129c16a3aa1a3a252efd66eb3d587" dependencies = [ "serde", ] @@ -4455,6 +4718,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "value-bag" version = "1.10.0" diff --git a/Cargo.toml b/Cargo.toml index 111ee86f9c..8f8db9141c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,8 +16,11 @@ members = [ "examples/postgres/files", "examples/postgres/json", "examples/postgres/listen", - "examples/postgres/todos", "examples/postgres/mockable-todos", + "examples/postgres/multi-database", + "examples/postgres/multi-tenant", + "examples/postgres/preferred-crates", + "examples/postgres/todos", "examples/postgres/transaction", "examples/sqlite/todos", ] @@ -51,7 +54,7 @@ repository.workspace = true rust-version.workspace = true [package.metadata.docs.rs] -features = ["all-databases", "_unstable-all-types", "sqlite-preupdate-hook"] +features = ["all-databases", "_unstable-all-types", "_unstable-doc", "sqlite-preupdate-hook"] rustdoc-args = ["--cfg", "docsrs"] [features] @@ -61,6 +64,9 @@ derive = 
["sqlx-macros/derive"] macros = ["derive", "sqlx-macros/macros"] migrate = ["sqlx-core/migrate", "sqlx-macros?/migrate", "sqlx-mysql?/migrate", "sqlx-postgres?/migrate", "sqlx-sqlite?/migrate"] +# Enable parsing of `sqlx.toml` for configuring macros and migrations. +sqlx-toml = ["sqlx-core/sqlx-toml", "sqlx-macros?/sqlx-toml"] + # intended mainly for CI and docs all-databases = ["mysql", "sqlite", "postgres", "any"] _unstable-all-types = [ @@ -76,6 +82,8 @@ _unstable-all-types = [ "bit-vec", "bstr" ] +# Render documentation that wouldn't otherwise be shown (e.g. `sqlx_core::config`). +_unstable-doc = [] # Base runtime features without TLS runtime-async-std = ["_rt-async-std", "sqlx-core/_rt-async-std", "sqlx-macros?/_rt-async-std"] @@ -132,7 +140,7 @@ sqlx-postgres = { version = "=0.9.0-alpha.1", path = "sqlx-postgres" } sqlx-sqlite = { version = "=0.9.0-alpha.1", path = "sqlx-sqlite" } # Facade crate (for reference from sqlx-cli) -sqlx = { version = "=0.9.0-alpha.1", path = ".", default-features = false } +sqlx = { version = "=0.9.0-alpha.1", path = "." } # Common type integrations shared by multiple driver crates. # These are optional unless enabled in a workspace crate. diff --git a/examples/postgres/multi-database/Cargo.toml b/examples/postgres/multi-database/Cargo.toml new file mode 100644 index 0000000000..c5e01621b8 --- /dev/null +++ b/examples/postgres/multi-database/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "sqlx-example-postgres-multi-database" +version.workspace = true +license.workspace = true +edition.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +authors.workspace = true + +[dependencies] +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } + +color-eyre = "0.6.3" +dotenvy = "0.15.7" +tracing-subscriber = "0.3.19" + +rust_decimal = "1.36.0" + +rand = "0.8.5" + +[dependencies.sqlx] +# version = "0.9.0" +workspace = true +features = ["runtime-tokio", "postgres", "migrate", "sqlx-toml"] + +[dependencies.accounts] +path = "accounts" +package = "sqlx-example-postgres-multi-database-accounts" + +[dependencies.payments] +path = "payments" +package = "sqlx-example-postgres-multi-database-payments" + +[lints] +workspace = true diff --git a/examples/postgres/multi-database/README.md b/examples/postgres/multi-database/README.md new file mode 100644 index 0000000000..c7804f90d1 --- /dev/null +++ b/examples/postgres/multi-database/README.md @@ -0,0 +1,62 @@ +# Using Multiple Databases with `sqlx.toml` + +This example project involves three crates, each owning a different schema in one database, +with their own set of migrations. + +* The main crate, a simple binary simulating the action of a REST API. + * Owns the `public` schema (tables are referenced unqualified). + * Migrations are moved to `src/migrations` using config key `migrate.migrations-dir` + to visually separate them from the subcrate folders. +* `accounts`: a subcrate simulating a reusable account-management crate. + * Owns schema `accounts`. +* `payments`: a subcrate simulating a wrapper for a payments API. + * Owns schema `payments`. + +## Note: Schema-Qualified Names + +This example uses schema-qualified names everywhere for clarity. + +It can be tempting to change the `search_path` of the connection (MySQL, Postgres) to eliminate the need for schema +prefixes, but this can cause some really confusing issues when names conflict. 
+
+This example will generate a `_sqlx_migrations` table in each of the three databases; because every crate
+tracks its migrations in its own database, the three migration histories never conflict.
+
+# Setup
+
+This example requires running three different sets of migrations.
+
+Ensure `sqlx-cli` is installed with Postgres and `sqlx.toml` support:
+
+```
+cargo install sqlx-cli --features postgres,sqlx-toml
+```
+
+Start a Postgres server (shown here using Docker; the `run` command also works with `podman`):
+
+```
+docker run -d -e POSTGRES_PASSWORD=password -p 5432:5432 --name postgres postgres:latest
+```
+
+Create `.env` with the various database URLs, or set them in your shell environment:
+
+```
+DATABASE_URL=postgres://postgres:password@localhost/example-multi-database
+ACCOUNTS_DATABASE_URL=postgres://postgres:password@localhost/example-multi-database-accounts
+PAYMENTS_DATABASE_URL=postgres://postgres:password@localhost/example-multi-database-payments
+```
+
+Run the following commands:
+
+```
+(cd accounts && sqlx db setup)
+(cd payments && sqlx db setup)
+sqlx db setup
+```
+
+It is an open question how to make this more convenient; `sqlx-cli` could gain a `--recursive` flag that checks
+subdirectories for `sqlx.toml` files, but that would only work for crates within the same workspace. If the `accounts`
+and `payments` crates were instead crates.io dependencies, we would need Cargo's help to resolve that information.
+
+An issue has been opened for discussion:
diff --git a/examples/postgres/multi-database/accounts/Cargo.toml b/examples/postgres/multi-database/accounts/Cargo.toml
new file mode 100644
index 0000000000..f7c04ca8b4
--- /dev/null
+++ b/examples/postgres/multi-database/accounts/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "sqlx-example-postgres-multi-database-accounts"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+sqlx = { workspace = true, features = ["postgres", "time", "uuid", "macros", "sqlx-toml"] }
+tokio = { version = "1", features = ["rt", "sync"] }
+
+argon2 = { version = "0.5.3", features = ["password-hash"] }
+password-hash = { version = "0.5", features = ["std"] }
+
+uuid = { version = "1", features = ["serde"] }
+thiserror = "1"
+rand = "0.8"
+
+time = { version = "0.3.37", features = ["serde"] }
+
+serde = { version = "1.0.218", features = ["derive"] }
+
+[dev-dependencies]
+sqlx = { workspace = true, features = ["runtime-tokio"] }
diff --git a/examples/postgres/multi-database/accounts/migrations/01_setup.sql b/examples/postgres/multi-database/accounts/migrations/01_setup.sql
new file mode 100644
index 0000000000..0f275f7e89
--- /dev/null
+++ b/examples/postgres/multi-database/accounts/migrations/01_setup.sql
@@ -0,0 +1,30 @@
+-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging
+-- and auditing.
+--
+-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger, which
+-- is a lot of boilerplate. These two functions save us from writing that every time; instead, we can just do
+--
+--     select trigger_updated_at('<table name>');
+--
+-- after a `CREATE TABLE`.
+create or replace function set_updated_at()
+    returns trigger as
+$$
+begin
+    NEW.updated_at = now();
+    return NEW;
+end;
+$$ language plpgsql;
+
+create or replace function trigger_updated_at(tablename regclass)
+    returns void as
+$$
+begin
+    execute format('CREATE TRIGGER set_updated_at
+        BEFORE UPDATE
+        ON %s
+        FOR EACH ROW
+        WHEN (OLD is distinct from NEW)
+    EXECUTE FUNCTION set_updated_at();', tablename);
+end;
+$$ language plpgsql;
diff --git a/examples/postgres/multi-database/accounts/migrations/02_account.sql b/examples/postgres/multi-database/accounts/migrations/02_account.sql
new file mode 100644
index 0000000000..519eddb10b
--- /dev/null
+++ b/examples/postgres/multi-database/accounts/migrations/02_account.sql
@@ -0,0 +1,10 @@
+create table account
+(
+    account_id    uuid primary key default gen_random_uuid(),
+    email         text unique not null,
+    password_hash text        not null,
+    created_at    timestamptz not null default now(),
+    updated_at    timestamptz
+);
+
+select trigger_updated_at('account');
diff --git a/examples/postgres/multi-database/accounts/migrations/03_session.sql b/examples/postgres/multi-database/accounts/migrations/03_session.sql
new file mode 100644
index 0000000000..0a45de26b2
--- /dev/null
+++ b/examples/postgres/multi-database/accounts/migrations/03_session.sql
@@ -0,0 +1,6 @@
+create table session
+(
+    session_token text primary key, -- random alphanumeric string
+    account_id    uuid        not null references account (account_id),
+    created_at    timestamptz not null default now()
+);
diff --git a/examples/postgres/multi-database/accounts/sqlx.toml b/examples/postgres/multi-database/accounts/sqlx.toml
new file mode 100644
index 0000000000..0620c4686f
--- /dev/null
+++ b/examples/postgres/multi-database/accounts/sqlx.toml
@@ -0,0 +1,10 @@
+[common]
+database-url-var = "ACCOUNTS_DATABASE_URL"
+
+[macros.table-overrides.'account']
+'account_id' = "crate::AccountId"
+'password_hash' = "sqlx::types::Text<password_hash::PasswordHashString>"
+
+[macros.table-overrides.'session']
+'session_token' = "crate::SessionToken"
+'account_id' = "crate::AccountId"
diff --git a/examples/postgres/multi-database/accounts/src/lib.rs b/examples/postgres/multi-database/accounts/src/lib.rs
new file mode 100644
index 0000000000..a543d2fd45
--- /dev/null
+++ b/examples/postgres/multi-database/accounts/src/lib.rs
@@ -0,0 +1,293 @@
+use argon2::{password_hash, Argon2, PasswordHasher, PasswordVerifier};
+use password_hash::PasswordHashString;
+use rand::distributions::{Alphanumeric, DistString};
+use sqlx::PgPool;
+use std::sync::Arc;
+use uuid::Uuid;
+
+use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
+use tokio::sync::Semaphore;
+
+#[derive(sqlx::Type, Copy, Clone, Debug, serde::Deserialize, serde::Serialize)]
+#[sqlx(transparent)]
+pub struct AccountId(pub Uuid);
+
+#[derive(sqlx::Type, Clone, Debug, serde::Deserialize, serde::Serialize)]
+#[sqlx(transparent)]
+pub struct SessionToken(pub String);
+
+pub struct Session {
+    pub account_id: AccountId,
+    pub session_token: SessionToken,
+}
+
+#[derive(Clone)]
+pub struct AccountsManager {
+    /// To prevent confusion, each crate manages its own database connection pool.
+    pool: PgPool,
+
+    /// Controls how many blocking tasks are allowed to run concurrently for Argon2 hashing.
+    ///
+    /// ### Motivation
+    /// Tokio blocking tasks are generally not designed for CPU-bound work.
+    ///
+    /// If no threads are idle, Tokio will automatically spawn new ones to handle
+    /// new blocking tasks up to a very high limit--512 by default.
+    ///
+    /// This is because blocking tasks are expected to spend their time *blocked*, e.g. on
+    /// blocking I/O, and thus not consume CPU resources or require a lot of context switching.
+    ///
+    /// This strategy is not the most efficient way to use threads for CPU-bound work, which
+    /// should schedule work to a fixed number of threads to minimize context switching
+    /// and memory usage (each new thread needs significant space allocated for its stack).
+    ///
+    /// We can work around this by using a purpose-designed thread-pool, like Rayon,
+    /// but we still have the problem that those APIs usually are not designed to support `async`,
+    /// so we end up needing blocking tasks anyway, or implementing our own work queue using
+    /// channels. Rayon also does not shut down idle worker threads.
+    ///
+    /// `block_in_place` is not a silver bullet, either, as it simply uses `spawn_blocking`
+    /// internally to take over from the current thread while it is executing blocking work.
+    /// This also prevents futures from being polled concurrently in the current task.
+    ///
+    /// We can lower the limit for blocking threads when creating the runtime, but this risks
+    /// starving other blocking tasks that are being created by the application or the Tokio
+    /// runtime itself
+    /// (which are used for `tokio::fs`, stdio, resolving of hostnames by `ToSocketAddrs`, etc.).
+    ///
+    /// Instead, we can just use a Semaphore to limit how many blocking tasks are spawned at once,
+    /// emulating the behavior of a thread pool like Rayon without needing any additional crates.
+    hashing_semaphore: Arc<Semaphore>,
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum CreateAccountError {
+    #[error("error creating account: email in-use")]
+    EmailInUse,
+    #[error("error creating account")]
+    General(
+        #[source]
+        #[from]
+        GeneralError,
+    ),
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum CreateSessionError {
+    #[error("unknown email")]
+    UnknownEmail,
+    #[error("invalid password")]
+    InvalidPassword,
+    #[error("authentication error")]
+    General(
+        #[source]
+        #[from]
+        GeneralError,
+    ),
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum GeneralError {
+    #[error("database error")]
+    Sqlx(
+        #[source]
+        #[from]
+        sqlx::Error,
+    ),
+    #[error("error hashing password")]
+    PasswordHash(
+        #[source]
+        #[from]
+        password_hash::Error,
+    ),
+    #[error("task panicked")]
+    Task(
+        #[source]
+        #[from]
+        tokio::task::JoinError,
+    ),
+}
+
+impl AccountsManager {
+    pub async fn setup(
+        opts: PgConnectOptions,
+        max_hashing_threads: usize,
+    ) -> Result<Self, GeneralError> {
+        // This should be configurable by the caller, but for simplicity, it's not.
+        let pool = PgPoolOptions::new()
+            .max_connections(5)
+            .connect_with(opts)
+            .await?;
+
+        sqlx::migrate!()
+            .run(&pool)
+            .await
+            .map_err(sqlx::Error::from)?;
+
+        Ok(AccountsManager {
+            pool,
+            hashing_semaphore: Semaphore::new(max_hashing_threads).into(),
+        })
+    }
+
+    async fn hash_password(&self, password: String) -> Result<PasswordHashString, GeneralError> {
+        let guard = self
+            .hashing_semaphore
+            .clone()
+            .acquire_owned()
+            .await
+            .expect("BUG: this semaphore should not be closed");
+
+        // We transfer ownership to the blocking task and back to ensure Tokio doesn't spawn
+        // excess threads.
+        let (_guard, res) = tokio::task::spawn_blocking(move || {
+            let salt = password_hash::SaltString::generate(rand::thread_rng());
+            (
+                guard,
+                Argon2::default()
+                    .hash_password(password.as_bytes(), &salt)
+                    .map(|hash| hash.serialize()),
+            )
+        })
+        .await?;
+
+        Ok(res?)
+    }
+
+    async fn verify_password(
+        &self,
+        password: String,
+        hash: PasswordHashString,
+    ) -> Result<(), CreateSessionError> {
+        let guard = self
+            .hashing_semaphore
+            .clone()
+            .acquire_owned()
+            .await
+            .expect("BUG: this semaphore should not be closed");
+
+        let (_guard, res) = tokio::task::spawn_blocking(move || {
+            (
+                guard,
+                Argon2::default().verify_password(password.as_bytes(), &hash.password_hash()),
+            )
+        })
+        .await
+        .map_err(GeneralError::from)?;
+
+        if let Err(password_hash::Error::Password) = res {
+            return Err(CreateSessionError::InvalidPassword);
+        }
+
+        res.map_err(GeneralError::from)?;
+
+        Ok(())
+    }
+
+    pub async fn create(
+        &self,
+        email: &str,
+        password: String,
+    ) -> Result<AccountId, CreateAccountError> {
+        // Hash password whether the account exists or not to make it harder
+        // to tell the difference in the timing.
+        let hash = self.hash_password(password).await?;
+
+        // Thanks to `sqlx.toml`, `account_id` maps to `AccountId`
+        sqlx::query_scalar!(
+            // language=PostgreSQL
+            "insert into account(email, password_hash) \
+             values ($1, $2) \
+             returning account_id",
+            email,
+            hash.as_str(),
+        )
+        .fetch_one(&self.pool)
+        .await
+        .map_err(|e| {
+            // `account_email_key` is the default name of the unique constraint on `email`.
+            if e.as_database_error().and_then(|dbe| dbe.constraint()) == Some("account_email_key") {
+                CreateAccountError::EmailInUse
+            } else {
+                GeneralError::from(e).into()
+            }
+        })
+    }
+
+    pub async fn create_session(
+        &self,
+        email: &str,
+        password: String,
+    ) -> Result<Session, CreateSessionError> {
+        let mut txn = self.pool.begin().await.map_err(GeneralError::from)?;
+
+        // To save a round-trip to the database, we'll speculatively insert the session token
+        // at the same time as we're looking up the password hash.
+        //
+        // This does nothing until the transaction is actually committed.
+        let session_token = SessionToken::generate();
+
+        // Thanks to `sqlx.toml`:
+        // * `account_id` maps to `AccountId`
+        // * `password_hash` maps to `Text<PasswordHashString>`
+        // * `session_token` maps to `SessionToken`
+        let maybe_account = sqlx::query!(
+            // language=PostgreSQL
+            "with account as (
+                select account_id, password_hash \
+                from account \
+                where email = $1
+            ), session as (
+                insert into session(session_token, account_id)
+                select $2, account_id
+                from account
+            )
+            select account.account_id, account.password_hash from account",
+            email,
+            session_token.0
+        )
+        .fetch_optional(&mut *txn)
+        .await
+        .map_err(GeneralError::from)?;
+
+        let Some(account) = maybe_account else {
+            // Hash the password whether the account exists or not to hide the difference in timing.
+            self.hash_password(password)
+                .await
+                .map_err(GeneralError::from)?;
+            return Err(CreateSessionError::UnknownEmail);
+        };
+
+        self.verify_password(password, account.password_hash.into_inner())
+            .await?;
+
+        txn.commit().await.map_err(GeneralError::from)?;
+
+        Ok(Session {
+            account_id: account.account_id,
+            session_token,
+        })
+    }
+
+    pub async fn auth_session(
+        &self,
+        session_token: &str,
+    ) -> Result<Option<AccountId>, GeneralError> {
+        sqlx::query_scalar!(
+            "select account_id from session where session_token = $1",
+            session_token
+        )
+        .fetch_optional(&self.pool)
+        .await
+        .map_err(GeneralError::from)
+    }
+}
+
+impl SessionToken {
+    const LEN: usize = 32;
+
+    fn generate() -> Self {
+        SessionToken(Alphanumeric.sample_string(&mut rand::thread_rng(), Self::LEN))
+    }
+}
diff --git a/examples/postgres/multi-database/payments/Cargo.toml b/examples/postgres/multi-database/payments/Cargo.toml
new file mode 100644
index 0000000000..853b32f624
--- /dev/null
+++ b/examples/postgres/multi-database/payments/Cargo.toml
@@ -0,0 +1,20 @@
+[package]
+name = "sqlx-example-postgres-multi-database-payments"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+
+sqlx = { workspace = true, features = ["postgres", "time", "uuid", "rust_decimal", "sqlx-toml"] }
+
+rust_decimal = "1.36.0"
+
+time = "0.3.37"
+uuid = "1.12.1"
+
+[dependencies.accounts]
+path = "../accounts"
+package = "sqlx-example-postgres-multi-database-accounts"
+
+[dev-dependencies]
+sqlx = { workspace = true, features = ["runtime-tokio"] }
diff --git a/examples/postgres/multi-database/payments/migrations/01_setup.sql b/examples/postgres/multi-database/payments/migrations/01_setup.sql
new file mode 100644
index 0000000000..5feb67d0a3
--- /dev/null
+++ b/examples/postgres/multi-database/payments/migrations/01_setup.sql
@@ -0,0 +1,30 @@
+-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging
+-- and auditing.
+--
+-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger, which
+-- is a lot of boilerplate. These two functions save us from writing that every time; instead, we can just do
+--
+--     select trigger_updated_at('<table name>
');
+--
+-- after a `CREATE TABLE`.
+create or replace function set_updated_at()
+    returns trigger as
+$$
+begin
+    NEW.updated_at = now();
+    return NEW;
+end;
+$$ language plpgsql;
+
+create or replace function trigger_updated_at(tablename regclass)
+    returns void as
+$$
+begin
+    execute format('CREATE TRIGGER set_updated_at
+        BEFORE UPDATE
+        ON %s
+        FOR EACH ROW
+        WHEN (OLD is distinct from NEW)
+    EXECUTE FUNCTION set_updated_at();', tablename);
+end;
+$$ language plpgsql;
diff --git a/examples/postgres/multi-database/payments/migrations/02_payment.sql b/examples/postgres/multi-database/payments/migrations/02_payment.sql
new file mode 100644
index 0000000000..7175a4b807
--- /dev/null
+++ b/examples/postgres/multi-database/payments/migrations/02_payment.sql
@@ -0,0 +1,59 @@
+-- `payments::PaymentStatus`
+--
+-- Historically at LaunchBadge we preferred not to define enums on the database side because it can be annoying
+-- and error-prone to keep them in-sync with the application.
+-- Instead, we let the application define the enum and just have the database store a compact representation of it.
+-- This is mostly a matter of taste, however.
+--
+-- For the purposes of this example, we're using an in-database enum because this is a common use-case
+-- for needing type overrides.
+create type payment_status as enum (
+    'pending',
+    'created',
+    'success',
+    'failed'
+    );
+
+create table payment
+(
+    payment_id          uuid primary key default gen_random_uuid(),
+    -- Since `account` is in a separate database, we can't foreign-key to it.
+    account_id          uuid           not null,
+
+    status              payment_status not null,
+
+    -- ISO 4217 currency code (https://en.wikipedia.org/wiki/ISO_4217#List_of_ISO_4217_currency_codes)
+    --
+    -- This *could* be an ENUM of currency codes, but constraining this to a set of known values in the database
+    -- would be annoying to keep up to date as support for more currencies is added.
+    --
+    -- Consider also if support for cryptocurrencies is desired; those are not covered by ISO 4217.
+    --
+    -- Though ISO 4217 is a three-character code, `TEXT`, `VARCHAR` and `CHAR(N)`
+    -- all use the same storage format in Postgres. Any constraint against the length of this field
+    -- would purely be a sanity check.
+    currency            text           not null,
+    -- There's an endless debate about what type should be used to represent currency amounts.
+    --
+    -- Postgres has the `MONEY` type, but the fractional precision depends on a C locale setting and the type is mostly
+    -- optimized for storing USD, or other currencies with a minimum fraction of 1 cent.
+    --
+    -- NEVER use `FLOAT` or `DOUBLE`. IEEE-754 floating point has round-off and precision errors that make it wholly
+    -- unsuitable for representing real money amounts.
+    --
+    -- `NUMERIC`, being an arbitrary-precision decimal format, is a safe default choice that can support any currency,
+    -- and so is what we've chosen here.
+    amount              NUMERIC        not null,
+
+    -- Payments almost always take place through a third-party vendor (e.g. PayPal, Stripe, etc.),
+    -- so imagine this is an identifier string for this payment in such a vendor's systems.
+    --
+    -- For privacy and security reasons, payment and personally-identifying information
+    -- (e.g. credit card numbers, bank account numbers, billing addresses) should only be stored with the vendor
+    -- unless there is a good reason otherwise.
+    external_payment_id text,
+
+    created_at          timestamptz    not null default now(),
+    updated_at          timestamptz
+);
+
+select trigger_updated_at('payment');
diff --git a/examples/postgres/multi-database/payments/sqlx.toml b/examples/postgres/multi-database/payments/sqlx.toml
new file mode 100644
index 0000000000..9196cb2e14
--- /dev/null
+++ b/examples/postgres/multi-database/payments/sqlx.toml
@@ -0,0 +1,9 @@
+[common]
+database-url-var = "PAYMENTS_DATABASE_URL"
+
+[macros.table-overrides.'payment']
+'payment_id' = "crate::PaymentId"
+'account_id' = "accounts::AccountId"
+
+[macros.type-overrides]
+'payment_status' = "crate::PaymentStatus"
diff --git a/examples/postgres/multi-database/payments/src/lib.rs b/examples/postgres/multi-database/payments/src/lib.rs
new file mode 100644
index 0000000000..356d173a5f
--- /dev/null
+++ b/examples/postgres/multi-database/payments/src/lib.rs
@@ -0,0 +1,127 @@
+use accounts::AccountId;
+use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
+use sqlx::PgPool;
+use time::OffsetDateTime;
+use uuid::Uuid;
+
+#[derive(sqlx::Type, Copy, Clone, Debug)]
+#[sqlx(transparent)]
+pub struct PaymentId(pub Uuid);
+
+#[derive(sqlx::Type, Copy, Clone, Debug)]
+#[sqlx(type_name = "payment_status")]
+#[sqlx(rename_all = "snake_case")]
+pub enum PaymentStatus {
+    Pending,
+    Created,
+    Success,
+    Failed,
+}
+
+// Users often assume that they need `#[derive(FromRow)]` to use `query_as!()`,
+// then are surprised when the derive's control attributes have no effect.
+// The macros currently do *not* use the `FromRow` trait at all.
+// Support for `FromRow` is planned, but would require significant changes to the macros.
+// See https://github.com/launchbadge/sqlx/issues/514 for details.
+#[derive(Clone, Debug)]
+pub struct Payment {
+    pub payment_id: PaymentId,
+    pub account_id: AccountId,
+    pub status: PaymentStatus,
+    pub currency: String,
+    // `rust_decimal::Decimal` has more than enough precision for any real-world amount of money.
+    pub amount: rust_decimal::Decimal,
+    pub external_payment_id: Option<String>,
+    pub created_at: OffsetDateTime,
+    pub updated_at: Option<OffsetDateTime>,
+}
+
+pub struct PaymentsManager {
+    pool: PgPool,
+}
+
+impl PaymentsManager {
+    pub async fn setup(opts: PgConnectOptions) -> sqlx::Result<Self> {
+        let pool = PgPoolOptions::new()
+            .max_connections(5)
+            .connect_with(opts)
+            .await?;
+
+        sqlx::migrate!().run(&pool).await?;
+
+        Ok(Self { pool })
+    }
+
+    /// # Note
+    /// For simplicity, this does not ensure that `account_id` actually exists.
+    pub async fn create(
+        &self,
+        account_id: AccountId,
+        currency: &str,
+        amount: rust_decimal::Decimal,
+    ) -> sqlx::Result<Payment> {
+        // Check out a connection to avoid paying the overhead of acquiring one for each call.
+        let mut conn = self.pool.acquire().await?;
+
+        // Imagine this method does more than just create a record in the database;
+        // maybe it actually initiates the payment with a third-party vendor, like Stripe.
+        //
+        // We need to ensure that we can link the payment in the vendor's systems back to a record
+        // in ours, even if any of the following happens:
+        // * The application dies before storing the external payment ID in the database
+        // * We lose the connection to the database while trying to commit a transaction
+        // * The database server dies while committing the transaction
+        //
+        // Thus, we create the payment in three atomic phases:
+        // * We create the payment record in our system and commit it.
+        // * We create the payment in the vendor's system with our payment ID attached.
+        // * We update our payment record with the vendor's payment ID.
+        let payment_id = sqlx::query_scalar!(
+            "insert into payment(account_id, status, currency, amount) \
+             values ($1, $2, $3, $4) \
+             returning payment_id",
+            // The database doesn't give us enough information to correctly typecheck `AccountId` here.
+            // We have to insert the UUID directly.
+            account_id.0,
+            PaymentStatus::Pending,
+            currency,
+            amount,
+        )
+        .fetch_one(&mut *conn)
+        .await?;
+
+        // We then create the record with the payment vendor...
+        let external_payment_id = "foobar1234";
+
+        // Then we store the external payment ID and update the payment status.
+        //
+        // NOTE: use caution with `select *` or `returning *`;
+        // the order of columns gets baked into the binary, so if it changes between compile time and
+        // run-time, you may run into errors.
+        let payment = sqlx::query_as!(
+            Payment,
+            "update payment \
+             set status = $1, external_payment_id = $2 \
+             where payment_id = $3 \
+             returning *",
+            PaymentStatus::Created,
+            external_payment_id,
+            payment_id.0,
+        )
+        .fetch_one(&mut *conn)
+        .await?;
+
+        Ok(payment)
+    }
+
+    pub async fn get(&self, payment_id: PaymentId) -> sqlx::Result<Option<Payment>> {
+        sqlx::query_as!(
+            Payment,
+            // see note above about `select *`
+            "select * from payment where payment_id = $1",
+            payment_id.0
+        )
+        .fetch_optional(&self.pool)
+        .await
+    }
+}
diff --git a/examples/postgres/multi-database/sqlx.toml b/examples/postgres/multi-database/sqlx.toml
new file mode 100644
index 0000000000..7a557cf4ba
--- /dev/null
+++ b/examples/postgres/multi-database/sqlx.toml
@@ -0,0 +1,3 @@
+[migrate]
+# Move `migrations/` under `src/` to separate it from the subcrate folders.
+migrations-dir = "src/migrations"
\ No newline at end of file
diff --git a/examples/postgres/multi-database/src/main.rs b/examples/postgres/multi-database/src/main.rs
new file mode 100644
index 0000000000..263eff8e50
--- /dev/null
+++ b/examples/postgres/multi-database/src/main.rs
@@ -0,0 +1,120 @@
+use accounts::AccountsManager;
+use color_eyre::eyre;
+use color_eyre::eyre::{Context, OptionExt};
+use payments::PaymentsManager;
+use rand::distributions::{Alphanumeric, DistString};
+use sqlx::Connection;
+
+#[tokio::main]
+async fn main() -> eyre::Result<()> {
+    color_eyre::install()?;
+    let _ = dotenvy::dotenv();
+    tracing_subscriber::fmt::init();
+
+    let mut conn = sqlx::PgConnection::connect(
+        // `env::var()` doesn't include the variable name in the error.
+        &dotenvy::var("DATABASE_URL").wrap_err("DATABASE_URL must be set")?,
+    )
+    .await
+    .wrap_err("could not connect to database")?;
+
+    let accounts = AccountsManager::setup(
+        dotenvy::var("ACCOUNTS_DATABASE_URL")
+            .wrap_err("ACCOUNTS_DATABASE_URL must be set")?
+            .parse()
+            .wrap_err("error parsing ACCOUNTS_DATABASE_URL")?,
+        1,
+    )
+    .await
+    .wrap_err("error initializing AccountsManager")?;
+
+    let payments = PaymentsManager::setup(
+        dotenvy::var("PAYMENTS_DATABASE_URL")
+            .wrap_err("PAYMENTS_DATABASE_URL must be set")?
+            .parse()
+            .wrap_err("error parsing PAYMENTS_DATABASE_URL")?,
+    )
+    .await
+    .wrap_err("error initializing PaymentsManager")?;
+
+    // For simplicity's sake, imagine each of these might be invoked by different request routes
+    // in a web application.
+
+    // POST /account
+    let user_email = format!("user{}@example.com", rand::random::<u32>());
+    let user_password = Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
+
+    // Note: unlike the multi-tenant example, `accounts` manages its own connection pool
+    // and database, so there is no external transaction to manage here.
+    let account_id = accounts
+        // Takes ownership of the password string because it's sent to another thread for hashing.
+        .create(&user_email, user_password.clone())
+        .await
+        .wrap_err("error creating account")?;
+
+    println!(
+        "created account ID: {}, email: {user_email:?}, password: {user_password:?}",
+        account_id.0
+    );
+
+    // POST /session
+    // Log the user in.
+    let session = accounts
+        .create_session(&user_email, user_password.clone())
+        .await
+        .wrap_err("error creating session")?;
+
+    // After this, session.session_token should then be returned to the client,
+    // either in the response body or a `Set-Cookie` header.
+    println!("created session token: {}", session.session_token.0);
+
+    // POST /purchase
+    // The client would then pass the session token to authenticated routes.
+    // In this route, they're making some kind of purchase.
+
+    // First, we need to ensure the session is valid.
+    // `session.session_token` would be passed by the client in whatever way is appropriate.
+    //
+    // For a pure REST API, consider an `Authorization: Bearer` header instead of the request body.
+    // With Axum, you can create a reusable extractor that reads the header and validates the session
+    // by implementing `FromRequestParts`.
+    //
+    // For APIs where the browser is intended to be the primary client, using a session cookie
+    // may be easier for the frontend. By setting the cookie with `HttpOnly: true`,
+    // it's impossible for malicious JavaScript on the client to access and steal the session token.
+    let account_id = accounts
+        .auth_session(&session.session_token.0)
+        .await
+        .wrap_err("error authenticating session")?
+        .ok_or_eyre("session does not exist")?;
+
+    let purchase_amount: rust_decimal::Decimal = "12.34".parse().unwrap();
+
+    // Then, because the user is making a purchase, we record a payment.
+    let payment = payments
+        .create(account_id, "USD", purchase_amount)
+        .await
+        .wrap_err("error creating payment")?;
+
+    println!("created payment: {payment:?}");
+
+    let purchase_id = sqlx::query_scalar!(
+        "insert into purchase(account_id, payment_id, amount) values ($1, $2, $3) returning purchase_id",
+        account_id.0,
+        payment.payment_id.0,
+        purchase_amount
+    )
+    .fetch_one(&mut conn)
+    .await
+    .wrap_err("error creating purchase")?;
+
+    println!("created purchase: {purchase_id}");
+
+    conn.close().await?;
+
+    Ok(())
+}
diff --git a/examples/postgres/multi-database/src/migrations/01_setup.sql b/examples/postgres/multi-database/src/migrations/01_setup.sql
new file mode 100644
index 0000000000..0f275f7e89
--- /dev/null
+++ b/examples/postgres/multi-database/src/migrations/01_setup.sql
@@ -0,0 +1,30 @@
+-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging
+-- and auditing.
+--
+-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger, which
+-- is a lot of boilerplate. These two functions save us from writing that every time; instead, we can just do
+--
+--     select trigger_updated_at('<table name>
');
+--
+-- after a `CREATE TABLE`.
+create or replace function set_updated_at()
+    returns trigger as
+$$
+begin
+    NEW.updated_at = now();
+    return NEW;
+end;
+$$ language plpgsql;
+
+create or replace function trigger_updated_at(tablename regclass)
+    returns void as
+$$
+begin
+    execute format('CREATE TRIGGER set_updated_at
+        BEFORE UPDATE
+        ON %s
+        FOR EACH ROW
+        WHEN (OLD is distinct from NEW)
+    EXECUTE FUNCTION set_updated_at();', tablename);
+end;
+$$ language plpgsql;
diff --git a/examples/postgres/multi-database/src/migrations/02_purchase.sql b/examples/postgres/multi-database/src/migrations/02_purchase.sql
new file mode 100644
index 0000000000..dbd83fbf9a
--- /dev/null
+++ b/examples/postgres/multi-database/src/migrations/02_purchase.sql
@@ -0,0 +1,11 @@
+create table purchase
+(
+    purchase_id uuid primary key default gen_random_uuid(),
+    account_id  uuid        not null,
+    payment_id  uuid        not null,
+    amount      numeric     not null,
+    created_at  timestamptz not null default now(),
+    updated_at  timestamptz
+);
+
+select trigger_updated_at('purchase');
diff --git a/examples/postgres/multi-tenant/Cargo.toml b/examples/postgres/multi-tenant/Cargo.toml
new file mode 100644
index 0000000000..a219cce2b8
--- /dev/null
+++ b/examples/postgres/multi-tenant/Cargo.toml
@@ -0,0 +1,36 @@
+[package]
+name = "sqlx-example-postgres-multi-tenant"
+version.workspace = true
+license.workspace = true
+edition.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+authors.workspace = true
+
+[dependencies]
+tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
+
+color-eyre = "0.6.3"
+dotenvy = "0.15.7"
+tracing-subscriber = "0.3.19"
+
+rust_decimal = "1.36.0"
+
+rand = "0.8.5"
+
+[dependencies.sqlx]
+# version = "0.9.0"
+workspace = true
+features = ["runtime-tokio", "postgres", "migrate", "sqlx-toml"]
+
+[dependencies.accounts]
+path = "accounts"
+package = "sqlx-example-postgres-multi-tenant-accounts"
+
+[dependencies.payments]
+path = "payments"
+package = "sqlx-example-postgres-multi-tenant-payments"
+
+[lints]
+workspace = true
diff --git a/examples/postgres/multi-tenant/README.md b/examples/postgres/multi-tenant/README.md
new file mode 100644
index 0000000000..01848a3f83
--- /dev/null
+++ b/examples/postgres/multi-tenant/README.md
@@ -0,0 +1,60 @@
+# Multi-tenant Databases with `sqlx.toml`
+
+This example project involves three crates, each owning a different schema in one database,
+with their own set of migrations.
+
+* The main crate, a simple binary simulating the action of a REST API.
+    * Owns the `public` schema (tables are referenced unqualified).
+    * Migrations are moved to `src/migrations` using config key `migrate.migrations-dir`
+      to visually separate them from the subcrate folders.
+* `accounts`: a subcrate simulating a reusable account-management crate.
+    * Owns schema `accounts`.
+* `payments`: a subcrate simulating a wrapper for a payments API.
+    * Owns schema `payments`.
+
+## Note: Schema-Qualified Names
+
+This example uses schema-qualified names everywhere for clarity.
+
+It can be tempting to change the `search_path` of the connection (a Postgres setting) to eliminate the need
+for schema prefixes, but this can cause some really confusing issues when names conflict.
+
+This example will generate a `_sqlx_migrations` table in three different schemas; if `search_path` is set
+to `public,accounts,payments` and the migrator for the main application attempts to reference the table unqualified,
+it would throw an error.
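+
+For reference, here is how the `accounts` subcrate creates its schema and relocates its migrations
+table, via the `migrate` keys in `accounts/sqlx.toml`:
+
+```
+[migrate]
+create-schemas = ["accounts"]
+table-name = "accounts._sqlx_migrations"
+```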
+
+# Setup
+
+This example requires running three different sets of migrations.
+
+Ensure `sqlx-cli` is installed with Postgres and `sqlx.toml` support:
+
+```
+cargo install sqlx-cli --features postgres,sqlx-toml
+```
+
+Start a Postgres server (shown here using Docker; the `run` command also works with `podman`):
+
+```
+docker run -d -e POSTGRES_PASSWORD=password -p 5432:5432 --name postgres postgres:latest
+```
+
+Create `.env` with `DATABASE_URL`, or set the variable in your shell environment:
+
+```
+DATABASE_URL=postgres://postgres:password@localhost/example-multi-tenant
+```
+
+Run the following commands:
+
+```
+(cd accounts && sqlx db setup)
+(cd payments && sqlx migrate run)
+sqlx migrate run
+```
+
+It is an open question how to make this more convenient; `sqlx-cli` could gain a `--recursive` flag that checks
+subdirectories for `sqlx.toml` files, but that would only work for crates within the same workspace. If the `accounts`
+and `payments` crates were instead crates.io dependencies, we would need Cargo's help to resolve that information.
+
+An issue has been opened for discussion:
diff --git a/examples/postgres/multi-tenant/accounts/Cargo.toml b/examples/postgres/multi-tenant/accounts/Cargo.toml
new file mode 100644
index 0000000000..40c365c607
--- /dev/null
+++ b/examples/postgres/multi-tenant/accounts/Cargo.toml
@@ -0,0 +1,26 @@
+[package]
+name = "sqlx-example-postgres-multi-tenant-accounts"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+tokio = { version = "1", features = ["rt", "sync"] }
+
+argon2 = { version = "0.5.3", features = ["password-hash"] }
+password-hash = { version = "0.5", features = ["std"] }
+
+uuid = { version = "1", features = ["serde"] }
+thiserror = "1"
+rand = "0.8"
+
+time = { version = "0.3.37", features = ["serde"] }
+
+serde = { version = "1.0.218", features = ["derive"] }
+
+[dependencies.sqlx]
+# version = "0.9.0"
+workspace = true
+features = ["postgres", "time", "uuid", "macros", "sqlx-toml", "migrate"]
+
+[dev-dependencies]
+sqlx = { workspace = true, features = ["runtime-tokio"] }
diff --git a/examples/postgres/multi-tenant/accounts/migrations/01_setup.sql b/examples/postgres/multi-tenant/accounts/migrations/01_setup.sql
new file mode 100644
index 0000000000..007e202ec9
--- /dev/null
+++ b/examples/postgres/multi-tenant/accounts/migrations/01_setup.sql
@@ -0,0 +1,30 @@
+-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging
+-- and auditing.
+--
+-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger, which
+-- is a lot of boilerplate. These two functions save us from writing that every time; instead, we can just do
+--
+--     select accounts.trigger_updated_at('<table name>
');
+--
+-- after a `CREATE TABLE`.
+create or replace function accounts.set_updated_at()
+    returns trigger as
+$$
+begin
+    NEW.updated_at = now();
+    return NEW;
+end;
+$$ language plpgsql;
+
+create or replace function accounts.trigger_updated_at(tablename regclass)
+    returns void as
+$$
+begin
+    execute format('CREATE TRIGGER set_updated_at
+        BEFORE UPDATE
+        ON %s
+        FOR EACH ROW
+        WHEN (OLD is distinct from NEW)
+    EXECUTE FUNCTION accounts.set_updated_at();', tablename);
+end;
+$$ language plpgsql;
diff --git a/examples/postgres/multi-tenant/accounts/migrations/02_account.sql b/examples/postgres/multi-tenant/accounts/migrations/02_account.sql
new file mode 100644
index 0000000000..a75814bd09
--- /dev/null
+++ b/examples/postgres/multi-tenant/accounts/migrations/02_account.sql
@@ -0,0 +1,10 @@
+create table accounts.account
+(
+    account_id    uuid primary key default gen_random_uuid(),
+    email         text unique not null,
+    password_hash text        not null,
+    created_at    timestamptz not null default now(),
+    updated_at    timestamptz
+);
+
+select accounts.trigger_updated_at('accounts.account');
diff --git a/examples/postgres/multi-tenant/accounts/migrations/03_session.sql b/examples/postgres/multi-tenant/accounts/migrations/03_session.sql
new file mode 100644
index 0000000000..585f425874
--- /dev/null
+++ b/examples/postgres/multi-tenant/accounts/migrations/03_session.sql
@@ -0,0 +1,6 @@
+create table accounts.session
+(
+    session_token text primary key, -- random alphanumeric string
+    account_id    uuid        not null references accounts.account (account_id),
+    created_at    timestamptz not null default now()
+);
diff --git a/examples/postgres/multi-tenant/accounts/sqlx.toml b/examples/postgres/multi-tenant/accounts/sqlx.toml
new file mode 100644
index 0000000000..024f6395e5
--- /dev/null
+++ b/examples/postgres/multi-tenant/accounts/sqlx.toml
@@ -0,0 +1,11 @@
+[migrate]
+create-schemas = ["accounts"]
+table-name = "accounts._sqlx_migrations"
+
+[macros.table-overrides.'accounts.account']
+'account_id' = "crate::AccountId"
+'password_hash' = "sqlx::types::Text<password_hash::PasswordHashString>"
+
+[macros.table-overrides.'accounts.session']
+'session_token' = "crate::SessionToken"
+'account_id' = "crate::AccountId"
diff --git a/examples/postgres/multi-tenant/accounts/src/lib.rs b/examples/postgres/multi-tenant/accounts/src/lib.rs
new file mode 100644
index 0000000000..ad33735165
--- /dev/null
+++ b/examples/postgres/multi-tenant/accounts/src/lib.rs
@@ -0,0 +1,284 @@
+use argon2::{password_hash, Argon2, PasswordHasher, PasswordVerifier};
+use password_hash::PasswordHashString;
+use rand::distributions::{Alphanumeric, DistString};
+use sqlx::{Acquire, Executor, PgTransaction, Postgres};
+use std::sync::Arc;
+use uuid::Uuid;
+
+use tokio::sync::Semaphore;
+
+#[derive(sqlx::Type, Copy, Clone, Debug, serde::Deserialize, serde::Serialize)]
+#[sqlx(transparent)]
+pub struct AccountId(pub Uuid);
+
+#[derive(sqlx::Type, Clone, Debug, serde::Deserialize, serde::Serialize)]
+#[sqlx(transparent)]
+pub struct SessionToken(pub String);
+
+pub struct Session {
+    pub account_id: AccountId,
+    pub session_token: SessionToken,
+}
+
+pub struct AccountsManager {
+    /// Controls how many blocking tasks are allowed to run concurrently for Argon2 hashing.
+    ///
+    /// ### Motivation
+    /// Tokio blocking tasks are generally not designed for CPU-bound work.
+    ///
+    /// If no threads are idle, Tokio will automatically spawn new ones to handle
+    /// new blocking tasks up to a very high limit--512 by default.
+    ///
+    /// This is because blocking tasks are expected to spend their time *blocked*, e.g. on
+    /// blocking I/O, and thus not consume CPU resources or require a lot of context switching.
+    ///
+    /// This strategy is not the most efficient way to use threads for CPU-bound work; such work
+    /// should instead be scheduled onto a fixed number of threads to minimize context switching
+    /// and memory usage (each new thread needs significant space allocated for its stack).
+    ///
+    /// We can work around this by using a purpose-designed thread pool, like Rayon,
+    /// but we still have the problem that those APIs usually are not designed to support `async`,
+    /// so we end up needing blocking tasks anyway, or implementing our own work queue using
+    /// channels. Rayon also does not shut down idle worker threads.
+    ///
+    /// `block_in_place` is not a silver bullet, either, as it simply uses `spawn_blocking`
+    /// internally to take over from the current thread while it is executing blocking work.
+    /// This also prevents futures from being polled concurrently in the current task.
+    ///
+    /// We can lower the limit for blocking threads when creating the runtime, but this risks
+    /// starving other blocking tasks that are being created by the application or the Tokio
+    /// runtime itself
+    /// (which are used for `tokio::fs`, stdio, resolving of hostnames by `ToSocketAddrs`, etc.).
+    ///
+    /// Instead, we can just use a `Semaphore` to limit how many blocking tasks are spawned at once,
+    /// emulating the behavior of a thread pool like Rayon without needing any additional crates.
+    hashing_semaphore: Arc<Semaphore>,
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum CreateAccountError {
+    #[error("error creating account: email in-use")]
+    EmailInUse,
+    #[error("error creating account")]
+    General(
+        #[source]
+        #[from]
+        GeneralError,
+    ),
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum CreateSessionError {
+    #[error("unknown email")]
+    UnknownEmail,
+    #[error("invalid password")]
+    InvalidPassword,
+    #[error("authentication error")]
+    General(
+        #[source]
+        #[from]
+        GeneralError,
+    ),
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum GeneralError {
+    #[error("database error")]
+    Sqlx(
+        #[source]
+        #[from]
+        sqlx::Error,
+    ),
+    #[error("error hashing password")]
+    PasswordHash(
+        #[source]
+        #[from]
+        password_hash::Error,
+    ),
+    #[error("task panicked")]
+    Task(
+        #[source]
+        #[from]
+        tokio::task::JoinError,
+    ),
+}
+
+impl AccountsManager {
+    pub async fn setup(
+        pool: impl Acquire<'_, Database = Postgres>,
+        max_hashing_threads: usize,
+    ) -> Result<Self, GeneralError> {
+        sqlx::migrate!()
+            .run(pool)
+            .await
+            .map_err(sqlx::Error::from)?;
+
+        Ok(AccountsManager {
+            hashing_semaphore: Semaphore::new(max_hashing_threads).into(),
+        })
+    }
+
+    async fn hash_password(&self, password: String) -> Result<PasswordHashString, GeneralError> {
+        let guard = self
+            .hashing_semaphore
+            .clone()
+            .acquire_owned()
+            .await
+            .expect("BUG: this semaphore should not be closed");
+
+        // We transfer ownership to the blocking task and back to ensure Tokio doesn't spawn
+        // excess threads.
+        let (_guard, res) = tokio::task::spawn_blocking(move || {
+            let salt = password_hash::SaltString::generate(rand::thread_rng());
+            (
+                guard,
+                Argon2::default()
+                    .hash_password(password.as_bytes(), &salt)
+                    .map(|hash| hash.serialize()),
+            )
+        })
+        .await?;
+
+        Ok(res?)
+ } + + async fn verify_password( + &self, + password: String, + hash: PasswordHashString, + ) -> Result<(), CreateSessionError> { + let guard = self + .hashing_semaphore + .clone() + .acquire_owned() + .await + .expect("BUG: this semaphore should not be closed"); + + let (_guard, res) = tokio::task::spawn_blocking(move || { + ( + guard, + Argon2::default().verify_password(password.as_bytes(), &hash.password_hash()), + ) + }) + .await + .map_err(GeneralError::from)?; + + if let Err(password_hash::Error::Password) = res { + return Err(CreateSessionError::InvalidPassword); + } + + res.map_err(GeneralError::from)?; + + Ok(()) + } + + pub async fn create( + &self, + txn: &mut PgTransaction<'_>, + email: &str, + password: String, + ) -> Result { + // Hash password whether the account exists or not to make it harder + // to tell the difference in the timing. + let hash = self.hash_password(password).await?; + + // Thanks to `sqlx.toml`, `account_id` maps to `AccountId` + sqlx::query_scalar!( + // language=PostgreSQL + "insert into accounts.account(email, password_hash) \ + values ($1, $2) \ + returning account_id", + email, + hash.as_str(), + ) + .fetch_one(&mut **txn) + .await + .map_err(|e| { + if e.as_database_error().and_then(|dbe| dbe.constraint()) + == Some("account_account_id_key") + { + CreateAccountError::EmailInUse + } else { + GeneralError::from(e).into() + } + }) + } + + pub async fn create_session( + &self, + db: impl Acquire<'_, Database = Postgres>, + email: &str, + password: String, + ) -> Result { + let mut txn = db.begin().await.map_err(GeneralError::from)?; + + // To save a round-trip to the database, we'll speculatively insert the session token + // at the same time as we're looking up the password hash. + // + // This does nothing until the transaction is actually committed. + let session_token = SessionToken::generate(); + + // Thanks to `sqlx.toml`: + // * `account_id` maps to `AccountId` + // * `password_hash` maps to `Text` + // * `session_token` maps to `SessionToken` + let maybe_account = sqlx::query!( + // language=PostgreSQL + "with account as ( + select account_id, password_hash \ + from accounts.account \ + where email = $1 + ), session as ( + insert into accounts.session(session_token, account_id) + select $2, account_id + from account + ) + select account.account_id, account.password_hash from account", + email, + session_token.0 + ) + .fetch_optional(&mut *txn) + .await + .map_err(GeneralError::from)?; + + let Some(account) = maybe_account else { + // Hash the password whether the account exists or not to hide the difference in timing. 
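+            // (The computed hash is simply discarded; only the timing matters here.)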
+ self.hash_password(password) + .await + .map_err(GeneralError::from)?; + return Err(CreateSessionError::UnknownEmail); + }; + + self.verify_password(password, account.password_hash.into_inner()) + .await?; + + txn.commit().await.map_err(GeneralError::from)?; + + Ok(Session { + account_id: account.account_id, + session_token, + }) + } + + pub async fn auth_session( + &self, + db: impl Executor<'_, Database = Postgres>, + session_token: &str, + ) -> Result, GeneralError> { + sqlx::query_scalar!( + "select account_id from accounts.session where session_token = $1", + session_token + ) + .fetch_optional(db) + .await + .map_err(GeneralError::from) + } +} + +impl SessionToken { + const LEN: usize = 32; + + fn generate() -> Self { + SessionToken(Alphanumeric.sample_string(&mut rand::thread_rng(), Self::LEN)) + } +} diff --git a/examples/postgres/multi-tenant/payments/Cargo.toml b/examples/postgres/multi-tenant/payments/Cargo.toml new file mode 100644 index 0000000000..de15b21828 --- /dev/null +++ b/examples/postgres/multi-tenant/payments/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "sqlx-example-postgres-multi-tenant-payments" +version = "0.1.0" +edition = "2021" + +[dependencies] + +rust_decimal = "1.36.0" + +time = "0.3.37" +uuid = "1.12.1" + +[dependencies.sqlx] +# version = "0.9.0" +workspace = true +features = ["postgres", "time", "uuid", "rust_decimal", "sqlx-toml", "migrate"] + +[dependencies.accounts] +path = "../accounts" +package = "sqlx-example-postgres-multi-tenant-accounts" + +[dev-dependencies] +sqlx = { workspace = true, features = ["runtime-tokio"] } diff --git a/examples/postgres/multi-tenant/payments/migrations/01_setup.sql b/examples/postgres/multi-tenant/payments/migrations/01_setup.sql new file mode 100644 index 0000000000..4935a63705 --- /dev/null +++ b/examples/postgres/multi-tenant/payments/migrations/01_setup.sql @@ -0,0 +1,30 @@ +-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging +-- and auditing. +-- +-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger which +-- is a lot of boilerplate. These two functions save us from writing that every time as instead we can just do +-- +-- select payments.trigger_updated_at('
<table name>');
+
+-- after a `CREATE TABLE`.
+create or replace function payments.set_updated_at()
+    returns trigger as
+$$
+begin
+    NEW.updated_at = now();
+    return NEW;
+end;
+$$ language plpgsql;
+
+create or replace function payments.trigger_updated_at(tablename regclass)
+    returns void as
+$$
+begin
+    execute format('CREATE TRIGGER set_updated_at
+        BEFORE UPDATE
+        ON %s
+        FOR EACH ROW
+        WHEN (OLD is distinct from NEW)
+        EXECUTE FUNCTION payments.set_updated_at();', tablename);
+end;
+$$ language plpgsql;
diff --git a/examples/postgres/multi-tenant/payments/migrations/02_payment.sql b/examples/postgres/multi-tenant/payments/migrations/02_payment.sql
new file mode 100644
index 0000000000..ee88fa18c0
--- /dev/null
+++ b/examples/postgres/multi-tenant/payments/migrations/02_payment.sql
@@ -0,0 +1,59 @@
+-- `payments::PaymentStatus`
+--
+-- Historically at LaunchBadge we preferred not to define enums on the database side because it can be annoying
+-- and error-prone to keep them in-sync with the application.
+-- Instead, we let the application define the enum and just have the database store a compact representation of it.
+-- This is mostly a matter of taste, however.
+--
+-- For the purposes of this example, we're using an in-database enum because this is a common use-case
+-- for needing type overrides.
+create type payments.payment_status as enum (
+    'pending',
+    'created',
+    'success',
+    'failed'
+    );
+
+create table payments.payment
+(
+    payment_id uuid primary key default gen_random_uuid(),
+    -- This cross-schema reference means migrations for the `accounts` crate should be run first.
+    account_id uuid not null references accounts.account (account_id),
+
+    status payments.payment_status not null,
+
+    -- ISO 4217 currency code (https://en.wikipedia.org/wiki/ISO_4217#List_of_ISO_4217_currency_codes)
+    --
+    -- This *could* be an ENUM of currency codes, but constraining this to a set of known values in the database
+    -- would be annoying to keep up to date as support for more currencies is added.
+    --
+    -- Consider also if support for cryptocurrencies is desired; those are not covered by ISO 4217.
+    --
+    -- Though ISO 4217 is a three-character code, `TEXT`, `VARCHAR` and `CHAR(N)`
+    -- all use the same storage format in Postgres. Any constraint against the length of this field
+    -- would purely be a sanity check.
+    currency text not null,
+    -- There's an endless debate about what type should be used to represent currency amounts.
+    --
+    -- Postgres has the `MONEY` type, but the fractional precision depends on a C locale setting and the type is mostly
+    -- optimized for storing USD, or other currencies with a minimum fraction of 1 cent.
+    --
+    -- NEVER use `FLOAT` or `DOUBLE`. IEEE-754 floating point has round-off and precision errors that make it wholly
+    -- unsuitable for representing real money amounts.
+    --
+    -- `NUMERIC`, being an arbitrary-precision decimal format, is a safe default choice that can support any currency,
+    -- and so is what we've chosen here.
+    amount NUMERIC not null,
+
+    -- Payments almost always take place through a third-party vendor (e.g. PayPal, Stripe, etc.),
+    -- so imagine this is an identifier string for this payment in such a vendor's systems.
+    --
+    -- For privacy and security reasons, payment and personally-identifying information
+    -- (e.g. credit card numbers, bank account numbers, billing addresses) should only be stored with the vendor
+    -- unless there is a good reason otherwise.
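+    -- This column is nullable because the vendor's ID only becomes known after the payment
+    -- record is first created; see the three-phase flow in `src/lib.rs`.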
+ external_payment_id text, + created_at timestamptz not null default now(), + updated_at timestamptz +); + +select payments.trigger_updated_at('payments.payment'); diff --git a/examples/postgres/multi-tenant/payments/sqlx.toml b/examples/postgres/multi-tenant/payments/sqlx.toml new file mode 100644 index 0000000000..1a4a27dc6a --- /dev/null +++ b/examples/postgres/multi-tenant/payments/sqlx.toml @@ -0,0 +1,10 @@ +[migrate] +create-schemas = ["payments"] +table-name = "payments._sqlx_migrations" + +[macros.table-overrides.'payments.payment'] +'payment_id' = "crate::PaymentId" +'account_id' = "accounts::AccountId" + +[macros.type-overrides] +'payments.payment_status' = "crate::PaymentStatus" diff --git a/examples/postgres/multi-tenant/payments/src/lib.rs b/examples/postgres/multi-tenant/payments/src/lib.rs new file mode 100644 index 0000000000..6a1efe05ee --- /dev/null +++ b/examples/postgres/multi-tenant/payments/src/lib.rs @@ -0,0 +1,110 @@ +use accounts::AccountId; +use sqlx::{Acquire, PgConnection, Postgres}; +use time::OffsetDateTime; +use uuid::Uuid; + +#[derive(sqlx::Type, Copy, Clone, Debug)] +#[sqlx(transparent)] +pub struct PaymentId(pub Uuid); + +#[derive(sqlx::Type, Copy, Clone, Debug)] +#[sqlx(type_name = "payments.payment_status")] +#[sqlx(rename_all = "snake_case")] +pub enum PaymentStatus { + Pending, + Created, + Success, + Failed, +} + +// Users often assume that they need `#[derive(FromRow)]` to use `query_as!()`, +// then are surprised when the derive's control attributes have no effect. +// The macros currently do *not* use the `FromRow` trait at all. +// Support for `FromRow` is planned, but would require significant changes to the macros. +// See https://github.com/launchbadge/sqlx/issues/514 for details. +#[derive(Clone, Debug)] +pub struct Payment { + pub payment_id: PaymentId, + pub account_id: AccountId, + pub status: PaymentStatus, + pub currency: String, + // `rust_decimal::Decimal` has more than enough precision for any real-world amount of money. + pub amount: rust_decimal::Decimal, + pub external_payment_id: Option, + pub created_at: OffsetDateTime, + pub updated_at: Option, +} + +// Accepting `impl Acquire` allows this function to be generic over `Pool`, `Connection` and `Transaction`. +pub async fn migrate(db: impl Acquire<'_, Database = Postgres>) -> sqlx::Result<()> { + sqlx::migrate!().run(db).await?; + Ok(()) +} + +pub async fn create( + conn: &mut PgConnection, + account_id: AccountId, + currency: &str, + amount: rust_decimal::Decimal, +) -> sqlx::Result { + // Imagine this method does more than just create a record in the database; + // maybe it actually initiates the payment with a third-party vendor, like Stripe. + // + // We need to ensure that we can link the payment in the vendor's systems back to a record + // in ours, even if any of the following happens: + // * The application dies before storing the external payment ID in the database + // * We lose the connection to the database while trying to commit a transaction + // * The database server dies while committing the transaction + // + // Thus, we create the payment in three atomic phases: + // * We create the payment record in our system and commit it. + // * We create the payment in the vendor's system with our payment ID attached. + // * We update our payment record with the vendor's payment ID. 
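+    // Phase 1: insert the payment in `pending` status.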
+ let payment_id = sqlx::query_scalar!( + "insert into payments.payment(account_id, status, currency, amount) \ + values ($1, $2, $3, $4) \ + returning payment_id", + // The database doesn't give us enough information to correctly typecheck `AccountId` here. + // We have to insert the UUID directly. + account_id.0, + PaymentStatus::Pending, + currency, + amount, + ) + .fetch_one(&mut *conn) + .await?; + + // We then create the record with the payment vendor... + let external_payment_id = "foobar1234"; + + // Then we store the external payment ID and update the payment status. + // + // NOTE: use caution with `select *` or `returning *`; + // the order of columns gets baked into the binary, so if it changes between compile time and + // run-time, you may run into errors. + let payment = sqlx::query_as!( + Payment, + "update payments.payment \ + set status = $1, external_payment_id = $2 \ + where payment_id = $3 \ + returning *", + PaymentStatus::Created, + external_payment_id, + payment_id.0, + ) + .fetch_one(&mut *conn) + .await?; + + Ok(payment) +} + +pub async fn get(db: &mut PgConnection, payment_id: PaymentId) -> sqlx::Result> { + sqlx::query_as!( + Payment, + // see note above about `select *` + "select * from payments.payment where payment_id = $1", + payment_id.0 + ) + .fetch_optional(db) + .await +} diff --git a/examples/postgres/multi-tenant/sqlx.toml b/examples/postgres/multi-tenant/sqlx.toml new file mode 100644 index 0000000000..7a557cf4ba --- /dev/null +++ b/examples/postgres/multi-tenant/sqlx.toml @@ -0,0 +1,3 @@ +[migrate] +# Move `migrations/` to under `src/` to separate it from subcrates. +migrations-dir = "src/migrations" \ No newline at end of file diff --git a/examples/postgres/multi-tenant/src/main.rs b/examples/postgres/multi-tenant/src/main.rs new file mode 100644 index 0000000000..94a96fcf2b --- /dev/null +++ b/examples/postgres/multi-tenant/src/main.rs @@ -0,0 +1,108 @@ +use accounts::AccountsManager; +use color_eyre::eyre; +use color_eyre::eyre::{Context, OptionExt}; +use rand::distributions::{Alphanumeric, DistString}; +use sqlx::Connection; + +#[tokio::main] +async fn main() -> eyre::Result<()> { + color_eyre::install()?; + let _ = dotenvy::dotenv(); + tracing_subscriber::fmt::init(); + + let mut conn = sqlx::PgConnection::connect( + // `env::var()` doesn't include the variable name in the error. + &dotenvy::var("DATABASE_URL").wrap_err("DATABASE_URL must be set")?, + ) + .await + .wrap_err("could not connect to database")?; + + // Runs migration for `accounts` internally. + let accounts = AccountsManager::setup(&mut conn, 1) + .await + .wrap_err("error initializing AccountsManager")?; + + payments::migrate(&mut conn) + .await + .wrap_err("error running payments migrations")?; + + // For simplicity's sake, imagine each of these might be invoked by different request routes + // in a web application. + + // POST /account + let user_email = format!("user{}@example.com", rand::random::()); + let user_password = Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + // Requires an externally managed transaction in case any application-specific records + // should be created after the actual account record. + let mut txn = conn.begin().await?; + + let account_id = accounts + // Takes ownership of the password string because it's sent to another thread for hashing. 
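+        // (Hence the `.clone()` below: the same password is needed again to create a session.)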
+ .create(&mut txn, &user_email, user_password.clone()) + .await + .wrap_err("error creating account")?; + + txn.commit().await?; + + println!( + "created account ID: {}, email: {user_email:?}, password: {user_password:?}", + account_id.0 + ); + + // POST /session + // Log the user in. + let session = accounts + .create_session(&mut conn, &user_email, user_password.clone()) + .await + .wrap_err("error creating session")?; + + // After this, session.session_token should then be returned to the client, + // either in the response body or a `Set-Cookie` header. + println!("created session token: {}", session.session_token.0); + + // POST /purchase + // The client would then pass the session token to authenticated routes. + // In this route, they're making some kind of purchase. + + // First, we need to ensure the session is valid. + // `session.session_token` would be passed by the client in whatever way is appropriate. + // + // For a pure REST API, consider an `Authorization: Bearer` header instead of the request body. + // With Axum, you can create a reusable extractor that reads the header and validates the session + // by implementing `FromRequestParts`. + // + // For APIs where the browser is intended to be the primary client, using a session cookie + // may be easier for the frontend. By setting the cookie with `HttpOnly: true`, + // it's impossible for malicious Javascript on the client to access and steal the session token. + let account_id = accounts + .auth_session(&mut conn, &session.session_token.0) + .await + .wrap_err("error authenticating session")? + .ok_or_eyre("session does not exist")?; + + let purchase_amount: rust_decimal::Decimal = "12.34".parse().unwrap(); + + // Then, because the user is making a purchase, we record a payment. + let payment = payments::create(&mut conn, account_id, "USD", purchase_amount) + .await + .wrap_err("error creating payment")?; + + println!("created payment: {payment:?}"); + + let purchase_id = sqlx::query_scalar!( + "insert into purchase(account_id, payment_id, amount) values ($1, $2, $3) returning purchase_id", + account_id.0, + payment.payment_id.0, + purchase_amount + ) + .fetch_one(&mut conn) + .await + .wrap_err("error creating purchase")?; + + println!("created purchase: {purchase_id}"); + + conn.close().await?; + + Ok(()) +} diff --git a/examples/postgres/multi-tenant/src/migrations/01_setup.sql b/examples/postgres/multi-tenant/src/migrations/01_setup.sql new file mode 100644 index 0000000000..0f275f7e89 --- /dev/null +++ b/examples/postgres/multi-tenant/src/migrations/01_setup.sql @@ -0,0 +1,30 @@ +-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging +-- and auditing. +-- +-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger which +-- is a lot of boilerplate. These two functions save us from writing that every time as instead we can just do +-- +-- select trigger_updated_at('
<table name>');
+
+-- after a `CREATE TABLE`.
+create or replace function set_updated_at()
+    returns trigger as
+$$
+begin
+    NEW.updated_at = now();
+    return NEW;
+end;
+$$ language plpgsql;
+
+create or replace function trigger_updated_at(tablename regclass)
+    returns void as
+$$
+begin
+    execute format('CREATE TRIGGER set_updated_at
+        BEFORE UPDATE
+        ON %s
+        FOR EACH ROW
+        WHEN (OLD is distinct from NEW)
+        EXECUTE FUNCTION set_updated_at();', tablename);
+end;
+$$ language plpgsql;
diff --git a/examples/postgres/multi-tenant/src/migrations/02_purchase.sql b/examples/postgres/multi-tenant/src/migrations/02_purchase.sql
new file mode 100644
index 0000000000..3eebd64eb0
--- /dev/null
+++ b/examples/postgres/multi-tenant/src/migrations/02_purchase.sql
@@ -0,0 +1,11 @@
+create table purchase
+(
+    purchase_id uuid primary key default gen_random_uuid(),
+    account_id uuid not null references accounts.account (account_id),
+    payment_id uuid not null references payments.payment (payment_id),
+    amount numeric not null,
+    created_at timestamptz not null default now(),
+    updated_at timestamptz
+);
+
+select trigger_updated_at('purchase');
diff --git a/examples/postgres/preferred-crates/Cargo.toml b/examples/postgres/preferred-crates/Cargo.toml
new file mode 100644
index 0000000000..cf6b0aca1d
--- /dev/null
+++ b/examples/postgres/preferred-crates/Cargo.toml
@@ -0,0 +1,37 @@
+[package]
+name = "sqlx-example-postgres-preferred-crates"
+version.workspace = true
+license.workspace = true
+edition.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+authors.workspace = true
+
+[dependencies]
+dotenvy.workspace = true
+
+anyhow = "1"
+chrono = "0.4"
+serde = { version = "1", features = ["derive"] }
+uuid = { version = "1", features = ["serde"] }
+
+[dependencies.tokio]
+workspace = true
+features = ["rt-multi-thread", "macros"]
+
+[dependencies.sqlx]
+# version = "0.9.0"
+workspace = true
+features = ["runtime-tokio", "postgres", "bigdecimal", "chrono", "derive", "migrate", "sqlx-toml"]
+
+[dependencies.uses-rust-decimal]
+path = "uses-rust-decimal"
+package = "sqlx-example-postgres-preferred-crates-uses-rust-decimal"
+
+[dependencies.uses-time]
+path = "uses-time"
+package = "sqlx-example-postgres-preferred-crates-uses-time"
+
+[lints]
+workspace = true
diff --git a/examples/postgres/preferred-crates/README.md b/examples/postgres/preferred-crates/README.md
new file mode 100644
index 0000000000..83f6ae6a5d
--- /dev/null
+++ b/examples/postgres/preferred-crates/README.md
@@ -0,0 +1,55 @@
+# Usage of `macros.preferred-crates` in `sqlx.toml`
+
+## The Problem
+
+SQLx has many optional features that enable integrations for external crates to map from/to SQL types.
+
+In some cases, more than one optional feature applies to the same set of types:
+
+* The `chrono` and `time` features enable mapping SQL date/time types to those in these crates.
+* Similarly, `bigdecimal` and `rust_decimal` enable mapping for the SQL `NUMERIC` type.
+
+Throughout its existence, the `query!()` family of macros has inferred which crate to use based on which optional
+feature was enabled. If multiple features are enabled, one takes precedence over the other: `time` over `chrono`,
+`rust_decimal` over `bigdecimal`, etc. The ordering is purely the result of historical happenstance and
+does not indicate any specific preference for one crate over another. They each have their tradeoffs.
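+
+As a hedged illustration (a sketch, not part of the original example code; it assumes an open `PgConnection`
+named `conn` and both features enabled), the macros would historically type a `timestamptz` column using `time`:
+
+```rust
+// `time` wins the old precedence, so `created_at` is inferred as
+// `time::OffsetDateTime`, not `chrono::DateTime<Utc>`.
+// (The `!` suffix in the alias just asserts the column is non-null.)
+let row = sqlx::query!(r#"select now() as "created_at!""#)
+    .fetch_one(&mut conn)
+    .await?;
+```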
+
+This works fine when only one crate in the dependency graph depends on SQLx, but can break down if another crate
+in the dependency graph also depends on SQLx. Because of Cargo's [feature unification], any features enabled
+by this other crate are also forced on for all other crates that depend on the same version of SQLx in the same project.
+
+This is intentional design on Cargo's part; features are meant to be purely additive, so it can build each transitive
+dependency just once no matter how many crates depend on it. Otherwise, this could result in a combinatorial explosion.
+
+Unfortunately for us, this means that if your project depends on SQLx and enables the `chrono` feature, but also depends
+on another crate that enables the `time` feature, the `query!()` macros will end up thinking that _you_ want to use
+the `time` crate, because they don't know any better.
+
+Fixing this has historically required patching the dependency, which is annoying to maintain long-term.
+
+[feature unification]: https://doc.rust-lang.org/cargo/reference/features.html#feature-unification
+
+## The Solution
+
+However, as of 0.9.0, SQLx has gained the ability to configure the macros through a per-crate `sqlx.toml` file.
+
+This includes the ability to tell the macros which crate you prefer, overriding the inference.
+
+See the [`sqlx.toml`](./sqlx.toml) file in this directory for details.
+
+A full reference `sqlx.toml` is also available as `sqlx-core/src/config/reference.toml`.
+
+## This Example
+
+This example exists both to showcase the macro configuration and to serve as a test of the functionality.
+
+It consists of three crates:
+
+* The root crate, which depends on SQLx and enables the `chrono` and `bigdecimal` features,
+* `uses-rust-decimal`, a dependency which also depends on SQLx and enables the `rust_decimal` feature,
+* and `uses-time`, a dependency which also depends on SQLx and enables the `time` feature.
+    * This serves as a stand-in for `tower-sessions-sqlx-store`, which is [one of the culprits for this issue](https://github.com/launchbadge/sqlx/issues/3412#issuecomment-2277377597).
+
+Given that both dependencies enable features with higher precedence, they would historically have interfered
+with the usage in the root crate. (Pretend that they're published to crates.io and cannot be easily changed.)
+However, because the root crate uses a `sqlx.toml`, the macros know exactly which crates it wants to use, and everyone's happy.
diff --git a/examples/postgres/preferred-crates/sqlx.toml b/examples/postgres/preferred-crates/sqlx.toml
new file mode 100644
index 0000000000..c4d6394a9c
--- /dev/null
+++ b/examples/postgres/preferred-crates/sqlx.toml
@@ -0,0 +1,9 @@
+[migrate]
+# Move `migrations/` to under `src/` to separate it from subcrates.
+migrations-dir = "src/migrations"
+
+[macros.preferred-crates]
+# Keeps `time` from taking precedence even though it's enabled by a dependency.
+date-time = "chrono" +# Same thing with `rust_decimal` +numeric = "bigdecimal" diff --git a/examples/postgres/preferred-crates/src/main.rs b/examples/postgres/preferred-crates/src/main.rs new file mode 100644 index 0000000000..5d6e4dc9b8 --- /dev/null +++ b/examples/postgres/preferred-crates/src/main.rs @@ -0,0 +1,70 @@ +use anyhow::Context; +use chrono::{DateTime, Utc}; +use sqlx::{Connection, PgConnection}; +use std::time::Duration; +use uuid::Uuid; + +#[derive(serde::Serialize, serde::Deserialize, PartialEq, Eq, Debug)] +struct SessionData { + user_id: Uuid, +} + +#[derive(sqlx::FromRow, Debug)] +struct User { + id: Uuid, + username: String, + password_hash: String, + // Because `time` is enabled by a transitive dependency, we previously would have needed + // a type override in the query to get types from `chrono`. + created_at: DateTime, + updated_at: Option>, +} + +const SESSION_DURATION: Duration = Duration::from_secs(60 * 60); // 1 hour + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let mut conn = + PgConnection::connect(&dotenvy::var("DATABASE_URL").context("DATABASE_URL must be set")?) + .await + .context("failed to connect to DATABASE_URL")?; + + sqlx::migrate!("./src/migrations").run(&mut conn).await?; + + uses_rust_decimal::create_table(&mut conn).await?; + uses_time::create_table(&mut conn).await?; + + let user_id = sqlx::query_scalar!( + "insert into users(username, password_hash) values($1, $2) returning id", + "user_foo", + "", + ) + .fetch_one(&mut conn) + .await?; + + let user = sqlx::query_as!(User, "select * from users where id = $1", user_id) + .fetch_one(&mut conn) + .await?; + + println!("Created user: {user:?}"); + + let session = + uses_time::create_session(&mut conn, SessionData { user_id }, SESSION_DURATION).await?; + + let session_from_id = uses_time::get_session::(&mut conn, session.id) + .await? + .expect("expected session"); + + assert_eq!(session, session_from_id); + + let purchase_id = + uses_rust_decimal::create_purchase(&mut conn, user_id, 1234u32.into(), "Rent").await?; + + let purchase = uses_rust_decimal::get_purchase(&mut conn, purchase_id) + .await? + .expect("expected purchase"); + + println!("Created purchase: {purchase:?}"); + + Ok(()) +} diff --git a/examples/postgres/preferred-crates/src/migrations/01_setup.sql b/examples/postgres/preferred-crates/src/migrations/01_setup.sql new file mode 100644 index 0000000000..0f275f7e89 --- /dev/null +++ b/examples/postgres/preferred-crates/src/migrations/01_setup.sql @@ -0,0 +1,30 @@ +-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging +-- and auditing. +-- +-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger which +-- is a lot of boilerplate. These two functions save us from writing that every time as instead we can just do +-- +-- select trigger_updated_at('
'); +-- +-- after a `CREATE TABLE`. +create or replace function set_updated_at() + returns trigger as +$$ +begin + NEW.updated_at = now(); + return NEW; +end; +$$ language plpgsql; + +create or replace function trigger_updated_at(tablename regclass) + returns void as +$$ +begin + execute format('CREATE TRIGGER set_updated_at + BEFORE UPDATE + ON %s + FOR EACH ROW + WHEN (OLD is distinct from NEW) + EXECUTE FUNCTION set_updated_at();', tablename); +end; +$$ language plpgsql; diff --git a/examples/postgres/preferred-crates/src/migrations/02_users.sql b/examples/postgres/preferred-crates/src/migrations/02_users.sql new file mode 100644 index 0000000000..6ef4f25dfc --- /dev/null +++ b/examples/postgres/preferred-crates/src/migrations/02_users.sql @@ -0,0 +1,11 @@ +create table users( + id uuid primary key default gen_random_uuid(), + username text not null, + password_hash text not null, + created_at timestamptz not null default now(), + updated_at timestamptz +); + +create unique index users_username_unique on users(lower(username)); + +select trigger_updated_at('users'); diff --git a/examples/postgres/preferred-crates/uses-rust-decimal/Cargo.toml b/examples/postgres/preferred-crates/uses-rust-decimal/Cargo.toml new file mode 100644 index 0000000000..13c409ac84 --- /dev/null +++ b/examples/postgres/preferred-crates/uses-rust-decimal/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "sqlx-example-postgres-preferred-crates-uses-rust-decimal" +version.workspace = true +license.workspace = true +edition.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +authors.workspace = true + +[dependencies] +chrono = "0.4" +rust_decimal = "1" +uuid = "1" + +[dependencies.sqlx] +workspace = true +features = ["runtime-tokio", "postgres", "rust_decimal", "chrono", "uuid"] + +[lints] +workspace = true diff --git a/examples/postgres/preferred-crates/uses-rust-decimal/src/lib.rs b/examples/postgres/preferred-crates/uses-rust-decimal/src/lib.rs new file mode 100644 index 0000000000..f955b737d1 --- /dev/null +++ b/examples/postgres/preferred-crates/uses-rust-decimal/src/lib.rs @@ -0,0 +1,55 @@ +use chrono::{DateTime, Utc}; +use sqlx::PgExecutor; + +#[derive(sqlx::FromRow, Debug)] +pub struct Purchase { + pub id: Uuid, + pub user_id: Uuid, + pub amount: Decimal, + pub description: String, + pub created_at: DateTime, +} + +pub use rust_decimal::Decimal; +use uuid::Uuid; + +pub async fn create_table(e: impl PgExecutor<'_>) -> sqlx::Result<()> { + sqlx::raw_sql( + // language=PostgreSQL + "create table if not exists purchases( \ + id uuid primary key default gen_random_uuid(), \ + user_id uuid not null, \ + amount numeric not null check(amount > 0), \ + description text not null, \ + created_at timestamptz not null default now() \ + ); + ", + ) + .execute(e) + .await?; + + Ok(()) +} + +pub async fn create_purchase( + e: impl PgExecutor<'_>, + user_id: Uuid, + amount: Decimal, + description: &str, +) -> sqlx::Result { + sqlx::query_scalar( + "insert into purchases(user_id, amount, description) values ($1, $2, $3) returning id", + ) + .bind(user_id) + .bind(amount) + .bind(description) + .fetch_one(e) + .await +} + +pub async fn get_purchase(e: impl PgExecutor<'_>, id: Uuid) -> sqlx::Result> { + sqlx::query_as("select * from purchases where id = $1") + .bind(id) + .fetch_optional(e) + .await +} diff --git a/examples/postgres/preferred-crates/uses-time/Cargo.toml b/examples/postgres/preferred-crates/uses-time/Cargo.toml new file mode 100644 index 0000000000..1dfb1dab7f --- 
/dev/null
+++ b/examples/postgres/preferred-crates/uses-time/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "sqlx-example-postgres-preferred-crates-uses-time"
+version.workspace = true
+license.workspace = true
+edition.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+authors.workspace = true
+
+[dependencies]
+serde = "1"
+time = "0.3"
+uuid = "1"
+
+[dependencies.sqlx]
+workspace = true
+features = ["runtime-tokio", "postgres", "time", "json", "uuid"]
+
+[lints]
+workspace = true
diff --git a/examples/postgres/preferred-crates/uses-time/src/lib.rs b/examples/postgres/preferred-crates/uses-time/src/lib.rs
new file mode 100644
index 0000000000..4fb3377880
--- /dev/null
+++ b/examples/postgres/preferred-crates/uses-time/src/lib.rs
@@ -0,0 +1,75 @@
+use serde::de::DeserializeOwned;
+use serde::Serialize;
+use sqlx::PgExecutor;
+use std::time::Duration;
+use time::OffsetDateTime;
+
+use sqlx::types::Json;
+use uuid::Uuid;
+
+#[derive(sqlx::FromRow, PartialEq, Eq, Debug)]
+pub struct Session<D> {
+    pub id: Uuid,
+    #[sqlx(json)]
+    pub data: D,
+    pub created_at: OffsetDateTime,
+    pub expires_at: OffsetDateTime,
+}
+
+pub async fn create_table(e: impl PgExecutor<'_>) -> sqlx::Result<()> {
+    sqlx::raw_sql(
+        // language=PostgreSQL
+        "create table if not exists sessions( \
+            id uuid primary key default gen_random_uuid(), \
+            data jsonb not null, \
+            created_at timestamptz not null default now(), \
+            expires_at timestamptz not null \
+        )",
+    )
+    .execute(e)
+    .await?;
+
+    Ok(())
+}
+
+pub async fn create_session<D: Serialize>(
+    e: impl PgExecutor<'_>,
+    data: D,
+    valid_duration: Duration,
+) -> sqlx::Result<Session<D>> {
+    // Round down to the nearest second because
+    // Postgres doesn't support precision higher than 1 microsecond anyway.
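+    // (`OffsetDateTime::now_utc()` carries nanosecond precision; truncating the sub-second
+    // part lets the value round-trip through Postgres unchanged, which keeps the
+    // `assert_eq!` comparison in this example's `main.rs` honest.)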
+ let created_at = OffsetDateTime::now_utc() + .replace_nanosecond(0) + .expect("0 nanoseconds should be in range"); + + let expires_at = created_at + valid_duration; + + let id: Uuid = sqlx::query_scalar( + "insert into sessions(data, created_at, expires_at) \ + values ($1, $2, $3) \ + returning id", + ) + .bind(Json(&data)) + .bind(created_at) + .bind(expires_at) + .fetch_one(e) + .await?; + + Ok(Session { + id, + data, + created_at, + expires_at, + }) +} + +pub async fn get_session( + e: impl PgExecutor<'_>, + id: Uuid, +) -> sqlx::Result>> { + sqlx::query_as("select id, data, created_at, expires_at from sessions where id = $1") + .bind(id) + .fetch_optional(e) + .await +} diff --git a/sqlx-cli/Cargo.toml b/sqlx-cli/Cargo.toml index f8c821a8f8..1de4e9cbc9 100644 --- a/sqlx-cli/Cargo.toml +++ b/sqlx-cli/Cargo.toml @@ -28,11 +28,6 @@ path = "src/bin/cargo-sqlx.rs" [dependencies] dotenvy = "0.15.0" tokio = { version = "1.15.0", features = ["macros", "rt", "rt-multi-thread", "signal"] } -sqlx = { workspace = true, default-features = false, features = [ - "runtime-tokio", - "migrate", - "any", -] } futures = "0.3.19" clap = { version = "4.3.10", features = ["derive", "env", "wrap_help"] } clap_complete = { version = "4.3.1", optional = true } @@ -48,8 +43,18 @@ filetime = "0.2" backoff = { version = "0.4.0", features = ["futures", "tokio"] } +[dependencies.sqlx] +workspace = true +default-features = false +features = [ + "runtime-tokio", + "migrate", + "any", +] + [features] -default = ["postgres", "sqlite", "mysql", "native-tls", "completions"] +default = ["postgres", "sqlite", "mysql", "native-tls", "completions", "sqlx-toml"] + rustls = ["sqlx/tls-rustls"] native-tls = ["sqlx/tls-native-tls"] @@ -64,6 +69,8 @@ openssl-vendored = ["openssl/vendored"] completions = ["dep:clap_complete"] +sqlx-toml = ["sqlx/sqlx-toml"] + # Conditional compilation only _sqlite = [] diff --git a/sqlx-cli/src/database.rs b/sqlx-cli/src/database.rs index 7a9bc6bf2f..eaba46eed9 100644 --- a/sqlx-cli/src/database.rs +++ b/sqlx-cli/src/database.rs @@ -1,5 +1,5 @@ -use crate::migrate; -use crate::opt::ConnectOpts; +use crate::opt::{ConnectOpts, MigrationSourceOpt}; +use crate::{migrate, Config}; use console::{style, Term}; use dialoguer::Confirm; use sqlx::any::Any; @@ -19,14 +19,14 @@ pub async fn create(connect_opts: &ConnectOpts) -> anyhow::Result<()> { std::sync::atomic::Ordering::Release, ); - Any::create_database(connect_opts.required_db_url()?).await?; + Any::create_database(connect_opts.expect_db_url()?).await?; } Ok(()) } pub async fn drop(connect_opts: &ConnectOpts, confirm: bool, force: bool) -> anyhow::Result<()> { - if confirm && !ask_to_continue_drop(connect_opts.required_db_url()?.to_owned()).await { + if confirm && !ask_to_continue_drop(connect_opts.expect_db_url()?.to_owned()).await { return Ok(()); } @@ -36,9 +36,9 @@ pub async fn drop(connect_opts: &ConnectOpts, confirm: bool, force: bool) -> any if exists { if force { - Any::force_drop_database(connect_opts.required_db_url()?).await?; + Any::force_drop_database(connect_opts.expect_db_url()?).await?; } else { - Any::drop_database(connect_opts.required_db_url()?).await?; + Any::drop_database(connect_opts.expect_db_url()?).await?; } } @@ -46,18 +46,23 @@ pub async fn drop(connect_opts: &ConnectOpts, confirm: bool, force: bool) -> any } pub async fn reset( - migration_source: &str, + config: &Config, + migration_source: &MigrationSourceOpt, connect_opts: &ConnectOpts, confirm: bool, force: bool, ) -> anyhow::Result<()> { drop(connect_opts, confirm, 
force).await?; - setup(migration_source, connect_opts).await + setup(config, migration_source, connect_opts).await } -pub async fn setup(migration_source: &str, connect_opts: &ConnectOpts) -> anyhow::Result<()> { +pub async fn setup( + config: &Config, + migration_source: &MigrationSourceOpt, + connect_opts: &ConnectOpts, +) -> anyhow::Result<()> { create(connect_opts).await?; - migrate::run(migration_source, connect_opts, false, false, None).await + migrate::run(config, migration_source, connect_opts, false, false, None).await } async fn ask_to_continue_drop(db_url: String) -> bool { diff --git a/sqlx-cli/src/lib.rs b/sqlx-cli/src/lib.rs index cb31205b4f..bb9f46ccc4 100644 --- a/sqlx-cli/src/lib.rs +++ b/sqlx-cli/src/lib.rs @@ -1,7 +1,6 @@ use std::io; use std::time::Duration; -use anyhow::Result; use futures::{Future, TryFutureExt}; use sqlx::{AnyConnection, Connection}; @@ -21,6 +20,8 @@ mod prepare; pub use crate::opt::Opt; +pub use sqlx::_unstable::config::{self, Config}; + /// Check arguments for `--no-dotenv` _before_ Clap parsing, and apply `.env` if not set. pub fn maybe_apply_dotenv() { if std::env::args().any(|arg| arg == "--no-dotenv") { @@ -30,7 +31,7 @@ pub fn maybe_apply_dotenv() { dotenvy::dotenv().ok(); } -pub async fn run(opt: Opt) -> Result<()> { +pub async fn run(opt: Opt) -> anyhow::Result<()> { // This `select!` is here so that when the process receives a `SIGINT` (CTRL + C), // the futures currently running on this task get dropped before the program exits. // This is currently necessary for the consumers of the `dialoguer` crate to restore @@ -50,24 +51,24 @@ pub async fn run(opt: Opt) -> Result<()> { } } -async fn do_run(opt: Opt) -> Result<()> { +async fn do_run(opt: Opt) -> anyhow::Result<()> { match opt.command { Command::Migrate(migrate) => match migrate.command { - MigrateCommand::Add { - source, - description, - reversible, - sequential, - timestamp, - } => migrate::add(&source, &description, reversible, sequential, timestamp).await?, + MigrateCommand::Add(opts) => migrate::add(opts).await?, MigrateCommand::Run { source, + config, dry_run, ignore_missing, - connect_opts, + mut connect_opts, target_version, } => { + let config = config.load_config().await?; + + connect_opts.populate_db_url(config)?; + migrate::run( + config, &source, &connect_opts, dry_run, @@ -78,12 +79,18 @@ async fn do_run(opt: Opt) -> Result<()> { } MigrateCommand::Revert { source, + config, dry_run, ignore_missing, - connect_opts, + mut connect_opts, target_version, } => { + let config = config.load_config().await?; + + connect_opts.populate_db_url(config)?; + migrate::revert( + config, &source, &connect_opts, dry_run, @@ -94,37 +101,83 @@ async fn do_run(opt: Opt) -> Result<()> { } MigrateCommand::Info { source, - connect_opts, - } => migrate::info(&source, &connect_opts).await?, - MigrateCommand::BuildScript { source, force } => migrate::build_script(&source, force)?, + config, + mut connect_opts, + } => { + let config = config.load_config().await?; + + connect_opts.populate_db_url(config)?; + + migrate::info(config, &source, &connect_opts).await? + } + MigrateCommand::BuildScript { + source, + config, + force, + } => { + let config = config.load_config().await?; + + migrate::build_script(config, &source, force)? 
+ } }, Command::Database(database) => match database.command { - DatabaseCommand::Create { connect_opts } => database::create(&connect_opts).await?, + DatabaseCommand::Create { + config, + mut connect_opts, + } => { + let config = config.load_config().await?; + + connect_opts.populate_db_url(config)?; + database::create(&connect_opts).await? + } DatabaseCommand::Drop { confirmation, - connect_opts, + config, + mut connect_opts, force, - } => database::drop(&connect_opts, !confirmation.yes, force).await?, + } => { + let config = config.load_config().await?; + + connect_opts.populate_db_url(config)?; + database::drop(&connect_opts, !confirmation.yes, force).await? + } DatabaseCommand::Reset { confirmation, source, - connect_opts, + config, + mut connect_opts, force, - } => database::reset(&source, &connect_opts, !confirmation.yes, force).await?, + } => { + let config = config.load_config().await?; + + connect_opts.populate_db_url(config)?; + database::reset(config, &source, &connect_opts, !confirmation.yes, force).await? + } DatabaseCommand::Setup { source, - connect_opts, - } => database::setup(&source, &connect_opts).await?, + config, + mut connect_opts, + } => { + let config = config.load_config().await?; + + connect_opts.populate_db_url(config)?; + database::setup(config, &source, &connect_opts).await? + } }, Command::Prepare { check, all, workspace, - connect_opts, + mut connect_opts, args, - } => prepare::run(check, all, workspace, connect_opts, args).await?, + config, + } => { + let config = config.load_config().await?; + connect_opts.populate_db_url(config)?; + prepare::run(check, all, workspace, connect_opts, args).await? + } #[cfg(feature = "completions")] Command::Completions { shell } => completions::run(shell), @@ -152,7 +205,7 @@ where { sqlx::any::install_default_drivers(); - let db_url = opts.required_db_url()?; + let db_url = opts.expect_db_url()?; backoff::future::retry( backoff::ExponentialBackoffBuilder::new() diff --git a/sqlx-cli/src/migrate.rs b/sqlx-cli/src/migrate.rs index e00f6de651..45a38b202a 100644 --- a/sqlx-cli/src/migrate.rs +++ b/sqlx-cli/src/migrate.rs @@ -1,6 +1,6 @@ -use crate::opt::ConnectOpts; +use crate::config::Config; +use crate::opt::{AddMigrationOpts, ConnectOpts, MigrationSourceOpt}; use anyhow::{bail, Context}; -use chrono::Utc; use console::style; use sqlx::migrate::{AppliedMigration, Migrate, MigrateError, MigrationType, Migrator}; use sqlx::Connection; @@ -11,142 +11,47 @@ use std::fs::{self, File}; use std::path::Path; use std::time::Duration; -fn create_file( - migration_source: &str, - file_prefix: &str, - description: &str, - migration_type: MigrationType, -) -> anyhow::Result<()> { - use std::path::PathBuf; - - let mut file_name = file_prefix.to_string(); - file_name.push('_'); - file_name.push_str(&description.replace(' ', "_")); - file_name.push_str(migration_type.suffix()); - - let mut path = PathBuf::new(); - path.push(migration_source); - path.push(&file_name); - - println!("Creating {}", style(path.display()).cyan()); - - let mut file = File::create(&path).context("Failed to create migration file")?; - - std::io::Write::write_all(&mut file, migration_type.file_content().as_bytes())?; - - Ok(()) -} +pub async fn add(opts: AddMigrationOpts) -> anyhow::Result<()> { + let config = opts.config.load_config().await?; -enum MigrationOrdering { - Timestamp(String), - Sequential(String), -} + let source = opts.source.resolve_path(config); -impl MigrationOrdering { - fn timestamp() -> MigrationOrdering { - 
Self::Timestamp(Utc::now().format("%Y%m%d%H%M%S").to_string()) - } - - fn sequential(version: i64) -> MigrationOrdering { - Self::Sequential(format!("{version:04}")) - } - - fn file_prefix(&self) -> &str { - match self { - MigrationOrdering::Timestamp(prefix) => prefix, - MigrationOrdering::Sequential(prefix) => prefix, - } - } - - fn infer(sequential: bool, timestamp: bool, migrator: &Migrator) -> Self { - match (timestamp, sequential) { - (true, true) => panic!("Impossible to specify both timestamp and sequential mode"), - (true, false) => MigrationOrdering::timestamp(), - (false, true) => MigrationOrdering::sequential( - migrator - .iter() - .last() - .map_or(1, |last_migration| last_migration.version + 1), - ), - (false, false) => { - // inferring the naming scheme - let migrations = migrator - .iter() - .filter(|migration| migration.migration_type.is_up_migration()) - .rev() - .take(2) - .collect::>(); - if let [last, pre_last] = &migrations[..] { - // there are at least two migrations, compare the last twothere's only one existing migration - if last.version - pre_last.version == 1 { - // their version numbers differ by 1, infer sequential - MigrationOrdering::sequential(last.version + 1) - } else { - MigrationOrdering::timestamp() - } - } else if let [last] = &migrations[..] { - // there is only one existing migration - if last.version == 0 || last.version == 1 { - // infer sequential if the version number is 0 or 1 - MigrationOrdering::sequential(last.version + 1) - } else { - MigrationOrdering::timestamp() - } - } else { - MigrationOrdering::timestamp() - } - } - } - } -} - -pub async fn add( - migration_source: &str, - description: &str, - reversible: bool, - sequential: bool, - timestamp: bool, -) -> anyhow::Result<()> { - fs::create_dir_all(migration_source).context("Unable to create migrations directory")?; + fs::create_dir_all(source).context("Unable to create migrations directory")?; - let migrator = Migrator::new(Path::new(migration_source)).await?; - // Type of newly created migration will be the same as the first one - // or reversible flag if this is the first migration - let migration_type = MigrationType::infer(&migrator, reversible); + let migrator = opts.source.resolve(config).await?; - let ordering = MigrationOrdering::infer(sequential, timestamp, &migrator); - let file_prefix = ordering.file_prefix(); + let version_prefix = opts.version_prefix(config, &migrator); - if migration_type.is_reversible() { + if opts.reversible(config, &migrator) { create_file( - migration_source, - file_prefix, - description, + source, + &version_prefix, + &opts.description, MigrationType::ReversibleUp, )?; create_file( - migration_source, - file_prefix, - description, + source, + &version_prefix, + &opts.description, MigrationType::ReversibleDown, )?; } else { create_file( - migration_source, - file_prefix, - description, + source, + &version_prefix, + &opts.description, MigrationType::Simple, )?; } // if the migrations directory is empty - let has_existing_migrations = fs::read_dir(migration_source) + let has_existing_migrations = fs::read_dir(source) .map(|mut dir| dir.next().is_some()) .unwrap_or(false); if !has_existing_migrations { - let quoted_source = if migration_source != "migrations" { - format!("{migration_source:?}") + let quoted_source = if opts.source.source.is_some() { + format!("{source:?}") } else { "".to_string() }; @@ -184,6 +89,32 @@ See: https://docs.rs/sqlx/{version}/sqlx/macro.migrate.html Ok(()) } +fn create_file( + migration_source: &str, + file_prefix: &str, + 
description: &str, + migration_type: MigrationType, +) -> anyhow::Result<()> { + use std::path::PathBuf; + + let mut file_name = file_prefix.to_string(); + file_name.push('_'); + file_name.push_str(&description.replace(' ', "_")); + file_name.push_str(migration_type.suffix()); + + let mut path = PathBuf::new(); + path.push(migration_source); + path.push(&file_name); + + println!("Creating {}", style(path.display()).cyan()); + + let mut file = File::create(&path).context("Failed to create migration file")?; + + std::io::Write::write_all(&mut file, migration_type.file_content().as_bytes())?; + + Ok(()) +} + fn short_checksum(checksum: &[u8]) -> String { let mut s = String::with_capacity(checksum.len() * 2); for b in checksum { @@ -192,14 +123,25 @@ fn short_checksum(checksum: &[u8]) -> String { s } -pub async fn info(migration_source: &str, connect_opts: &ConnectOpts) -> anyhow::Result<()> { - let migrator = Migrator::new(Path::new(migration_source)).await?; +pub async fn info( + config: &Config, + migration_source: &MigrationSourceOpt, + connect_opts: &ConnectOpts, +) -> anyhow::Result<()> { + let migrator = migration_source.resolve(config).await?; + let mut conn = crate::connect(connect_opts).await?; - conn.ensure_migrations_table().await?; + // FIXME: we shouldn't actually be creating anything here + for schema_name in &config.migrate.create_schemas { + conn.create_schema_if_not_exists(schema_name).await?; + } + + conn.ensure_migrations_table(config.migrate.table_name()) + .await?; let applied_migrations: HashMap<_, _> = conn - .list_applied_migrations() + .list_applied_migrations(config.migrate.table_name()) .await? .into_iter() .map(|m| (m.version, m)) @@ -272,13 +214,15 @@ fn validate_applied_migrations( } pub async fn run( - migration_source: &str, + config: &Config, + migration_source: &MigrationSourceOpt, connect_opts: &ConnectOpts, dry_run: bool, ignore_missing: bool, target_version: Option, ) -> anyhow::Result<()> { - let migrator = Migrator::new(Path::new(migration_source)).await?; + let migrator = migration_source.resolve(config).await?; + if let Some(target_version) = target_version { if !migrator.version_exists(target_version) { bail!(MigrateError::VersionNotPresent(target_version)); @@ -287,14 +231,21 @@ pub async fn run( let mut conn = crate::connect(connect_opts).await?; - conn.ensure_migrations_table().await?; + for schema_name in &config.migrate.create_schemas { + conn.create_schema_if_not_exists(schema_name).await?; + } - let version = conn.dirty_version().await?; + conn.ensure_migrations_table(config.migrate.table_name()) + .await?; + + let version = conn.dirty_version(config.migrate.table_name()).await?; if let Some(version) = version { bail!(MigrateError::Dirty(version)); } - let applied_migrations = conn.list_applied_migrations().await?; + let applied_migrations = conn + .list_applied_migrations(config.migrate.table_name()) + .await?; validate_applied_migrations(&applied_migrations, &migrator, ignore_missing)?; let latest_version = applied_migrations @@ -332,7 +283,7 @@ pub async fn run( let elapsed = if dry_run || skip { Duration::new(0, 0) } else { - conn.apply(migration).await? + conn.apply(config.migrate.table_name(), migration).await? 
}; let text = if skip { "Skipped" @@ -365,13 +316,15 @@ pub async fn run( } pub async fn revert( - migration_source: &str, + config: &Config, + migration_source: &MigrationSourceOpt, connect_opts: &ConnectOpts, dry_run: bool, ignore_missing: bool, target_version: Option, ) -> anyhow::Result<()> { - let migrator = Migrator::new(Path::new(migration_source)).await?; + let migrator = migration_source.resolve(config).await?; + if let Some(target_version) = target_version { if target_version != 0 && !migrator.version_exists(target_version) { bail!(MigrateError::VersionNotPresent(target_version)); @@ -380,14 +333,22 @@ pub async fn revert( let mut conn = crate::connect(connect_opts).await?; - conn.ensure_migrations_table().await?; + // FIXME: we should not be creating anything here if it doesn't exist + for schema_name in &config.migrate.create_schemas { + conn.create_schema_if_not_exists(schema_name).await?; + } + + conn.ensure_migrations_table(config.migrate.table_name()) + .await?; - let version = conn.dirty_version().await?; + let version = conn.dirty_version(config.migrate.table_name()).await?; if let Some(version) = version { bail!(MigrateError::Dirty(version)); } - let applied_migrations = conn.list_applied_migrations().await?; + let applied_migrations = conn + .list_applied_migrations(config.migrate.table_name()) + .await?; validate_applied_migrations(&applied_migrations, &migrator, ignore_missing)?; let latest_version = applied_migrations @@ -421,7 +382,7 @@ pub async fn revert( let elapsed = if dry_run || skip { Duration::new(0, 0) } else { - conn.revert(migration).await? + conn.revert(config.migrate.table_name(), migration).await? }; let text = if skip { "Skipped" @@ -458,7 +419,13 @@ pub async fn revert( Ok(()) } -pub fn build_script(migration_source: &str, force: bool) -> anyhow::Result<()> { +pub fn build_script( + config: &Config, + migration_source: &MigrationSourceOpt, + force: bool, +) -> anyhow::Result<()> { + let source = migration_source.resolve_path(config); + anyhow::ensure!( Path::new("Cargo.toml").exists(), "must be run in a Cargo project root" @@ -473,7 +440,7 @@ pub fn build_script(migration_source: &str, force: bool) -> anyhow::Result<()> { r#"// generated by `sqlx migrate build-script` fn main() {{ // trigger recompilation when a new migration is added - println!("cargo:rerun-if-changed={migration_source}"); + println!("cargo:rerun-if-changed={source}"); }} "#, ); diff --git a/sqlx-cli/src/migration.rs b/sqlx-cli/src/migration.rs deleted file mode 100644 index 2ed8f94495..0000000000 --- a/sqlx-cli/src/migration.rs +++ /dev/null @@ -1,187 +0,0 @@ -use anyhow::{bail, Context}; -use console::style; -use std::fs::{self, File}; -use std::io::{Read, Write}; - -const MIGRATION_FOLDER: &str = "migrations"; - -pub struct Migration { - pub name: String, - pub sql: String, -} - -pub fn add_file(name: &str) -> anyhow::Result<()> { - use chrono::prelude::*; - use std::path::PathBuf; - - fs::create_dir_all(MIGRATION_FOLDER).context("Unable to create migrations directory")?; - - let dt = Utc::now(); - let mut file_name = dt.format("%Y-%m-%d_%H-%M-%S").to_string(); - file_name.push_str("_"); - file_name.push_str(name); - file_name.push_str(".sql"); - - let mut path = PathBuf::new(); - path.push(MIGRATION_FOLDER); - path.push(&file_name); - - let mut file = File::create(path).context("Failed to create file")?; - file.write_all(b"-- Add migration script here") - .context("Could not write to file")?; - - println!("Created migration: '{file_name}'"); - Ok(()) -} - -pub async fn run() -> 
anyhow::Result<()> { - let migrator = crate::migrator::get()?; - - if !migrator.can_migrate_database() { - bail!( - "Database migrations not supported for {}", - migrator.database_type() - ); - } - - migrator.create_migration_table().await?; - - let migrations = load_migrations()?; - - for mig in migrations.iter() { - let mut tx = migrator.begin_migration().await?; - - if tx.check_if_applied(&mig.name).await? { - println!("Already applied migration: '{}'", mig.name); - continue; - } - println!("Applying migration: '{}'", mig.name); - - tx.execute_migration(&mig.sql) - .await - .with_context(|| format!("Failed to run migration {:?}", &mig.name))?; - - tx.save_applied_migration(&mig.name) - .await - .context("Failed to insert migration")?; - - tx.commit().await.context("Failed")?; - } - - Ok(()) -} - -pub async fn list() -> anyhow::Result<()> { - let migrator = crate::migrator::get()?; - - if !migrator.can_migrate_database() { - bail!( - "Database migrations not supported for {}", - migrator.database_type() - ); - } - - let file_migrations = load_migrations()?; - - if migrator - .check_if_database_exists(&migrator.get_database_name()?) - .await? - { - let applied_migrations = migrator.get_migrations().await.unwrap_or_else(|_| { - println!("Could not retrieve data from migration table"); - Vec::new() - }); - - let mut width = 0; - for mig in file_migrations.iter() { - width = std::cmp::max(width, mig.name.len()); - } - for mig in file_migrations.iter() { - let status = if applied_migrations - .iter() - .find(|&m| mig.name == *m) - .is_some() - { - style("Applied").green() - } else { - style("Not Applied").yellow() - }; - - println!("{:width$}\t{}", mig.name, status, width = width); - } - - let orphans = check_for_orphans(file_migrations, applied_migrations); - - if let Some(orphans) = orphans { - println!("\nFound migrations applied in the database that does not have a corresponding migration file:"); - for name in orphans { - println!("{:width$}\t{}", name, style("Orphan").red(), width = width); - } - } - } else { - println!("No database found, listing migrations"); - - for mig in file_migrations { - println!("{}", mig.name); - } - } - - Ok(()) -} - -fn load_migrations() -> anyhow::Result> { - let entries = fs::read_dir(&MIGRATION_FOLDER).context("Could not find 'migrations' dir")?; - - let mut migrations = Vec::new(); - - for e in entries { - if let Ok(e) = e { - if let Ok(meta) = e.metadata() { - if !meta.is_file() { - continue; - } - - if let Some(ext) = e.path().extension() { - if ext != "sql" { - println!("Wrong ext: {ext:?}"); - continue; - } - } else { - continue; - } - - let mut file = File::open(e.path()) - .with_context(|| format!("Failed to open: '{:?}'", e.file_name()))?; - let mut contents = String::new(); - file.read_to_string(&mut contents) - .with_context(|| format!("Failed to read: '{:?}'", e.file_name()))?; - - migrations.push(Migration { - name: e.file_name().to_str().unwrap().to_string(), - sql: contents, - }); - } - } - } - - migrations.sort_by(|a, b| a.name.partial_cmp(&b.name).unwrap()); - - Ok(migrations) -} - -fn check_for_orphans( - file_migrations: Vec, - applied_migrations: Vec, -) -> Option> { - let orphans: Vec = applied_migrations - .iter() - .filter(|m| !file_migrations.iter().any(|fm| fm.name == **m)) - .cloned() - .collect(); - - if orphans.len() > 0 { - Some(orphans) - } else { - None - } -} diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs index 133ba084f2..1642e11c47 100644 --- a/sqlx-cli/src/opt.rs +++ b/sqlx-cli/src/opt.rs @@ -1,11 +1,17 @@ -use 
std::ops::{Deref, Not}; - +use crate::config::migrate::{DefaultMigrationType, DefaultVersioning}; +use crate::config::Config; +use anyhow::Context; +use chrono::Utc; use clap::{ builder::{styling::AnsiColor, Styles}, Args, Parser, }; #[cfg(feature = "completions")] use clap_complete::Shell; +use sqlx::migrate::{MigrateError, Migrator, ResolveWith}; +use std::env; +use std::ops::{Deref, Not}; +use std::path::PathBuf; const HELP_STYLES: Styles = Styles::styled() .header(AnsiColor::Blue.on_default().bold()) @@ -62,6 +68,9 @@ pub enum Command { #[clap(flatten)] connect_opts: ConnectOpts, + + #[clap(flatten)] + config: ConfigOpt, }, #[clap(alias = "mig")] @@ -85,6 +94,9 @@ pub enum DatabaseCommand { Create { #[clap(flatten)] connect_opts: ConnectOpts, + + #[clap(flatten)] + config: ConfigOpt, }, /// Drops the database specified in your DATABASE_URL. @@ -92,6 +104,9 @@ pub enum DatabaseCommand { #[clap(flatten)] confirmation: Confirmation, + #[clap(flatten)] + config: ConfigOpt, + #[clap(flatten)] connect_opts: ConnectOpts, @@ -106,7 +121,10 @@ pub enum DatabaseCommand { confirmation: Confirmation, #[clap(flatten)] - source: Source, + source: MigrationSourceOpt, + + #[clap(flatten)] + config: ConfigOpt, #[clap(flatten)] connect_opts: ConnectOpts, @@ -119,7 +137,10 @@ pub enum DatabaseCommand { /// Creates the database specified in your DATABASE_URL and runs any pending migrations. Setup { #[clap(flatten)] - source: Source, + source: MigrationSourceOpt, + + #[clap(flatten)] + config: ConfigOpt, #[clap(flatten)] connect_opts: ConnectOpts, @@ -137,8 +158,55 @@ pub struct MigrateOpt { pub enum MigrateCommand { /// Create a new migration with the given description. /// + /// -------------------------------- + /// + /// Migrations may either be simple, or reversible. + /// + /// Reversible migrations can be reverted with `sqlx migrate revert`, simple migrations cannot. + /// + /// Reversible migrations are created as a pair of two files with the same filename but + /// extensions `.up.sql` and `.down.sql` for the up-migration and down-migration, respectively. + /// + /// The up-migration should contain the commands to be used when applying the migration, + /// while the down-migration should contain the commands to reverse the changes made by the + /// up-migration. + /// + /// When writing down-migrations, care should be taken to ensure that they + /// do not leave the database in an inconsistent state. + /// + /// Simple migrations have just `.sql` for their extension and represent an up-migration only. + /// + /// Note that reverting a migration is **destructive** and will likely result in data loss. + /// Reverting a migration will not restore any data discarded by commands in the up-migration. + /// + /// It is recommended to always back up the database before running migrations. + /// + /// -------------------------------- + /// + /// For convenience, this command attempts to detect if reversible migrations are in-use. + /// + /// If the latest existing migration is reversible, the new migration will also be reversible. + /// + /// Otherwise, a simple migration is created. + /// + /// This behavior can be overridden by `--simple` or `--reversible`, respectively. + /// + /// The default type to use can also be set in `sqlx.toml`. + /// + /// -------------------------------- + /// /// A version number will be automatically assigned to the migration. /// + /// Migrations are applied in ascending order by version number. + /// Version numbers do not need to be strictly consecutive. 
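+ ///
+ /// For example (hypothetical versions): jumping from `0002` straight to `0005` is allowed; only the relative order of versions matters.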
+ /// + /// The migration process will abort if SQLx encounters a migration with a version number + /// less than _any_ previously applied migration. + /// + /// Migrations should only be created with increasing version number. + /// + /// -------------------------------- + /// /// For convenience, this command will attempt to detect if sequential versioning is in use, /// and if so, continue the sequence. /// @@ -148,33 +216,20 @@ pub enum MigrateCommand { /// /// * only one migration exists and its version number is either 0 or 1. /// - /// Otherwise timestamp versioning is assumed. + /// Otherwise, timestamp versioning (`YYYYMMDDHHMMSS`) is assumed. /// - /// This behavior can overridden by `--sequential` or `--timestamp`, respectively. - Add { - description: String, - - #[clap(flatten)] - source: Source, - - /// If true, creates a pair of up and down migration files with same version - /// else creates a single sql file - #[clap(short)] - reversible: bool, - - /// If set, use timestamp versioning for the new migration. Conflicts with `--sequential`. - #[clap(short, long)] - timestamp: bool, - - /// If set, use sequential versioning for the new migration. Conflicts with `--timestamp`. - #[clap(short, long, conflicts_with = "timestamp")] - sequential: bool, - }, + /// This behavior can be overridden by `--timestamp` or `--sequential`, respectively. + /// + /// The default versioning to use can also be set in `sqlx.toml`. + Add(AddMigrationOpts), /// Run all pending migrations. Run { #[clap(flatten)] - source: Source, + source: MigrationSourceOpt, + + #[clap(flatten)] + config: ConfigOpt, /// List all the migrations to be run without applying #[clap(long)] @@ -195,7 +250,10 @@ pub enum MigrateCommand { /// Revert the latest migration with a down file. Revert { #[clap(flatten)] - source: Source, + source: MigrationSourceOpt, + + #[clap(flatten)] + config: ConfigOpt, /// List the migration to be reverted without applying #[clap(long)] @@ -217,7 +275,10 @@ pub enum MigrateCommand { /// List all available migrations. Info { #[clap(flatten)] - source: Source, + source: MigrationSourceOpt, + + #[clap(flatten)] + config: ConfigOpt, #[clap(flatten)] connect_opts: ConnectOpts, @@ -228,7 +289,10 @@ pub enum MigrateCommand { /// Must be run in a Cargo project root. BuildScript { #[clap(flatten)] - source: Source, + source: MigrationSourceOpt, + + #[clap(flatten)] + config: ConfigOpt, /// Overwrite the build script if it already exists. #[clap(long)] @@ -236,19 +300,62 @@ pub enum MigrateCommand { }, } +#[derive(Args, Debug)] +pub struct AddMigrationOpts { + pub description: String, + + #[clap(flatten)] + pub source: MigrationSourceOpt, + + #[clap(flatten)] + pub config: ConfigOpt, + + /// If set, create an up-migration only. Conflicts with `--reversible`. + #[clap(long, conflicts_with = "reversible")] + simple: bool, + + /// If set, create a pair of up and down migration files with same version. + /// + /// Conflicts with `--simple`. + #[clap(short, long, conflicts_with = "simple")] + reversible: bool, + + /// If set, use timestamp versioning for the new migration. Conflicts with `--sequential`. + /// + /// Timestamp format: `YYYYMMDDHHMMSS` + #[clap(short, long, conflicts_with = "sequential")] + timestamp: bool, + + /// If set, use sequential versioning for the new migration. Conflicts with `--timestamp`. + #[clap(short, long, conflicts_with = "timestamp")] + sequential: bool, +} + /// Argument for the migration scripts source. 
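+///
+/// A sketch of the resolution order implemented by `resolve_path` below: `--source` on the command line wins, then the directory configured in `sqlx.toml`, then the built-in `migrations/` default. The key name shown is an assumption based on the kebab-case convention used elsewhere in the config:
+///
+/// ```toml
+/// [migrate]
+/// migrations-dir = "db/migrations"
+/// ```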
#[derive(Args, Debug)] -pub struct Source { +pub struct MigrationSourceOpt { /// Path to folder containing migrations. - #[clap(long, default_value = "migrations")] - source: String, + /// + /// Defaults to `migrations/` if not specified, but a different default may be set by `sqlx.toml`. + #[clap(long)] + pub source: Option<String>, } -impl Deref for Source { - type Target = String; +impl MigrationSourceOpt { + pub fn resolve_path<'a>(&'a self, config: &'a Config) -> &'a str { + if let Some(source) = &self.source { + return source; + } - fn deref(&self) -> &Self::Target { - &self.source + config.migrate.migrations_dir() + } + + pub async fn resolve(&self, config: &Config) -> Result<Migrator, MigrateError> { + Migrator::new(ResolveWith( + self.resolve_path(config), + config.migrate.to_resolve_config(), + )) + .await } } @@ -259,7 +366,7 @@ pub struct ConnectOpts { pub no_dotenv: NoDotenvOpt, /// Location of the DB, by default will be read from the DATABASE_URL env var or `.env` files. - #[clap(long, short = 'D', env)] + #[clap(long, short = 'D')] pub database_url: Option<String>, /// The maximum time, in seconds, to try connecting to the database server before @@ -290,15 +397,85 @@ pub struct NoDotenvOpt { pub no_dotenv: bool, } +#[derive(Args, Debug)] +pub struct ConfigOpt { + /// Override the path to the config file. + /// + /// Defaults to `sqlx.toml` in the current directory, if it exists. + /// + /// Configuration file loading may be bypassed with `--config=/dev/null` on Linux, + /// or `--config=NUL` on Windows. + /// + /// Config file loading is enabled by the `sqlx-toml` feature. + #[clap(long)] + pub config: Option<PathBuf>, +} + impl ConnectOpts { /// Require a database URL to be provided, otherwise /// return an error. - pub fn required_db_url(&self) -> anyhow::Result<&str> { - self.database_url.as_deref().ok_or_else( - || anyhow::anyhow!( - "the `--database-url` option or the `DATABASE_URL` environment variable must be provided" - ) - ) + pub fn expect_db_url(&self) -> anyhow::Result<&str> { + self.database_url + .as_deref() + .context("BUG: database_url not populated") + } + + /// Populate `database_url` from the environment, if not set. + pub fn populate_db_url(&mut self, config: &Config) -> anyhow::Result<()> { + if self.database_url.is_some() { + return Ok(()); + } + + let var = config.common.database_url_var(); + + let context = if var != "DATABASE_URL" { + " (`common.database-url-var` in `sqlx.toml`)" + } else { + "" + }; + + match env::var(var) { + Ok(url) => { + if !context.is_empty() { + eprintln!("Read database url from `{var}`{context}"); + } + + self.database_url = Some(url) + } + Err(env::VarError::NotPresent) => { + anyhow::bail!("`--database-url` or `{var}`{context} must be set") + } + Err(env::VarError::NotUnicode(_)) => { + anyhow::bail!("`{var}`{context} is not valid UTF-8"); + } + } + + Ok(()) + } +} + +impl ConfigOpt { + pub async fn load_config(&self) -> anyhow::Result<&'static Config> { + let path = self.config.clone(); + + // Tokio does file I/O on a background task anyway + tokio::task::spawn_blocking(|| { + if let Some(path) = path { + let err_str = format!("error reading config from {path:?}"); + Config::try_from_path(path).context(err_str) + } else { + let path = PathBuf::from("sqlx.toml"); + + if path.exists() { + eprintln!("Found `sqlx.toml` in current directory; reading..."); + Ok(Config::try_from_path(path)?) + } else { + Ok(Config::get_or_default()) + } + } + }) + .await + .context("unexpected error loading config")?
} } @@ -334,3 +511,67 @@ impl Not for IgnoreMissing { !self.ignore_missing } } + +impl AddMigrationOpts { + pub fn reversible(&self, config: &Config, migrator: &Migrator) -> bool { + if self.reversible { + return true; + } + if self.simple { + return false; + } + + match config.migrate.defaults.migration_type { + DefaultMigrationType::Inferred => migrator + .iter() + .last() + .is_some_and(|m| m.migration_type.is_reversible()), + DefaultMigrationType::Simple => false, + DefaultMigrationType::Reversible => true, + } + } + + pub fn version_prefix(&self, config: &Config, migrator: &Migrator) -> String { + let default_versioning = &config.migrate.defaults.migration_versioning; + + match (self.timestamp, self.sequential, default_versioning) { + (true, false, _) | (false, false, DefaultVersioning::Timestamp) => next_timestamp(), + (false, true, _) | (false, false, DefaultVersioning::Sequential) => fmt_sequential( + migrator + .migrations + .last() + .map_or(1, |migration| migration.version + 1), + ), + (false, false, DefaultVersioning::Inferred) => { + migrator + .migrations + .rchunks(2) + .next() + .and_then(|migrations| { + match migrations { + [previous, latest] => { + // If the latest two versions differ by 1, infer sequential. + (latest.version - previous.version == 1) + .then_some(latest.version + 1) + } + [latest] => { + // If only one migration exists and its version is 0 or 1, infer sequential + matches!(latest.version, 0 | 1).then_some(latest.version + 1) + } + _ => unreachable!(), + } + }) + .map_or_else(next_timestamp, fmt_sequential) + } + (true, true, _) => unreachable!("BUG: Clap should have rejected this case"), + } + } +} + +fn next_timestamp() -> String { + Utc::now().format("%Y%m%d%H%M%S").to_string() +} + +fn fmt_sequential(version: i64) -> String { + format!("{version:04}") +} diff --git a/sqlx-cli/tests/add.rs b/sqlx-cli/tests/add.rs index 1d5ed7c7dd..cebbb51d53 100644 --- a/sqlx-cli/tests/add.rs +++ b/sqlx-cli/tests/add.rs @@ -1,20 +1,11 @@ +use anyhow::Context; use assert_cmd::Command; use std::cmp::Ordering; use std::fs::read_dir; +use std::ops::Index; use std::path::{Path, PathBuf}; use tempfile::TempDir; -#[test] -fn add_migration_ambiguous() -> anyhow::Result<()> { - for reversible in [true, false] { - let files = AddMigrations::new()? - .run("hello world", reversible, true, true, false)? 
- .fs_output()?; - assert_eq!(files.0, Vec::<FileName>::new()); - } - Ok(()) -} - #[derive(Debug, PartialEq, Eq)] struct FileName { id: u64, @@ -34,11 +25,6 @@ impl PartialOrd for FileName { impl FileName { fn assert_is_timestamp(&self) { - //if the library is still used in 2050, this will need bumping ^^ - assert!( - self.id < 20500101000000, - "{self:?} is too high for a timestamp" - ); assert!( self.id > 20200101000000, "{self:?} is too low for a timestamp" @@ -59,6 +45,154 @@ impl From<PathBuf> for FileName { } } } + +struct AddMigrationsResult(Vec<FileName>); +impl AddMigrationsResult { + fn len(&self) -> usize { + self.0.len() + } + fn assert_is_reversible(&self) { + let mut up_cnt = 0; + let mut down_cnt = 0; + for file in self.0.iter() { + if file.suffix == "down.sql" { + down_cnt += 1; + } else if file.suffix == "up.sql" { + up_cnt += 1; + } else { + panic!("unknown suffix for {file:?}"); + } + assert!(file.description.starts_with("hello_world")); + } + assert_eq!(up_cnt, down_cnt); + } + fn assert_is_not_reversible(&self) { + for file in self.0.iter() { + assert_eq!(file.suffix, "sql"); + assert!(file.description.starts_with("hello_world")); + } + } +} + +impl Index<usize> for AddMigrationsResult { + type Output = FileName; + + fn index(&self, index: usize) -> &Self::Output { + &self.0[index] + } +} + +struct AddMigrations { + tempdir: TempDir, + config_arg: Option<String>, } + +impl AddMigrations { + fn new() -> anyhow::Result<Self> { + anyhow::Ok(Self { + tempdir: TempDir::new()?, + config_arg: None, + }) + } + + fn with_config(mut self, filename: &str) -> anyhow::Result<Self> { + let path = format!("./tests/assets/{filename}"); + + let path = std::fs::canonicalize(&path) + .with_context(|| format!("error canonicalizing path {path:?}"))?; + + let path = path + .to_str() + .with_context(|| format!("canonicalized version of path {path:?} is not UTF-8"))?; + + self.config_arg = Some(format!("--config={path}")); + Ok(self) + } + + fn run( + &self, + description: &str, + reversible: bool, + timestamp: bool, + sequential: bool, + expect_success: bool, + ) -> anyhow::Result<&'_ Self> { + let cmd_result = Command::cargo_bin("cargo-sqlx")? + .current_dir(&self.tempdir) + .args( + [ + vec!["sqlx", "migrate", "add", description], + self.config_arg.as_deref().map_or(vec![], |arg| vec![arg]), + match reversible { + true => vec!["-r"], + false => vec![], + }, + match timestamp { + true => vec!["--timestamp"], + false => vec![], + }, + match sequential { + true => vec!["--sequential"], + false => vec![], + }, + ] + .concat(), + ) + .env("RUST_BACKTRACE", "1") + .assert(); + if expect_success { + cmd_result.success(); + } else { + cmd_result.failure(); + } + anyhow::Ok(self) + } + fn fs_output(&self) -> anyhow::Result<AddMigrationsResult> { + let files = recurse_files(&self.tempdir)?; + let mut fs_paths = Vec::with_capacity(files.len()); + for path in files { + let relative_path = path.strip_prefix(self.tempdir.path())?.to_path_buf(); + fs_paths.push(FileName::from(relative_path)); + } + Ok(AddMigrationsResult(fs_paths)) + } +} + +fn recurse_files(path: impl AsRef<Path>) -> anyhow::Result<Vec<PathBuf>> { + let mut buf = vec![]; + let entries = read_dir(path)?; + + for entry in entries { + let entry = entry?; + let meta = entry.metadata()?; + + if meta.is_dir() { + let mut subdir = recurse_files(entry.path())?; + buf.append(&mut subdir); + } + + if meta.is_file() { + buf.push(entry.path()); + } + } + buf.sort(); + Ok(buf) +} + +#[test] +fn add_migration_error_ambiguous() -> anyhow::Result<()> { + for reversible in [true, false] { + let files = AddMigrations::new()?
+ // Passing both `--timestamp` and `--reversible` should result in an error. + .run("hello world", reversible, true, true, false)? + .fs_output()?; + + // Assert that no files are created + assert_eq!(files.0, []); + } + Ok(()) +} + #[test] fn add_migration_sequential() -> anyhow::Result<()> { { @@ -74,10 +208,12 @@ fn add_migration_sequential() -> anyhow::Result<()> { .run("hello world1", false, false, true, true)? .run("hello world2", true, false, true, true)? .fs_output()?; - assert_eq!(files.len(), 2); - files.assert_is_not_reversible(); + assert_eq!(files.len(), 3); assert_eq!(files.0[0].id, 1); assert_eq!(files.0[1].id, 2); + assert_eq!(files.0[1].suffix, "down.sql"); + assert_eq!(files.0[2].id, 2); + assert_eq!(files.0[2].suffix, "up.sql"); } Ok(()) } @@ -126,146 +262,145 @@ fn add_migration_timestamp() -> anyhow::Result<()> { .run("hello world1", false, true, false, true)? .run("hello world2", true, false, true, true)? .fs_output()?; - assert_eq!(files.len(), 2); - files.assert_is_not_reversible(); + assert_eq!(files.len(), 3); files.0[0].assert_is_timestamp(); // sequential -> timestamp is one way files.0[1].assert_is_timestamp(); + files.0[2].assert_is_timestamp(); } Ok(()) } + #[test] fn add_migration_timestamp_reversible() -> anyhow::Result<()> { { let files = AddMigrations::new()? .run("hello world", true, false, false, true)? .fs_output()?; + assert_eq!(files.len(), 2); files.assert_is_reversible(); - files.0[0].assert_is_timestamp(); - files.0[1].assert_is_timestamp(); + + // .up.sql and .down.sql + files[0].assert_is_timestamp(); + assert_eq!(files[1].id, files[0].id); } { let files = AddMigrations::new()? .run("hello world", true, true, false, true)? .fs_output()?; + assert_eq!(files.len(), 2); files.assert_is_reversible(); - files.0[0].assert_is_timestamp(); - files.0[1].assert_is_timestamp(); + + // .up.sql and .down.sql + files[0].assert_is_timestamp(); + assert_eq!(files[1].id, files[0].id); } { let files = AddMigrations::new()? .run("hello world1", true, true, false, true)? - .run("hello world2", true, false, true, true)? + // Reversible should be inferred, but sequential should be forced + .run("hello world2", false, false, true, true)? .fs_output()?; + assert_eq!(files.len(), 4); files.assert_is_reversible(); - files.0[0].assert_is_timestamp(); - files.0[1].assert_is_timestamp(); - files.0[2].assert_is_timestamp(); - files.0[3].assert_is_timestamp(); + + // First pair: .up.sql and .down.sql + files[0].assert_is_timestamp(); + assert_eq!(files[1].id, files[0].id); + + // Second pair; we set `--sequential` so this version should be one higher + assert_eq!(files[2].id, files[1].id + 1); + assert_eq!(files[3].id, files[1].id + 1); } Ok(()) } -struct AddMigrationsResult(Vec); -impl AddMigrationsResult { - fn len(&self) -> usize { - self.0.len() - } - fn assert_is_reversible(&self) { - let mut up_cnt = 0; - let mut down_cnt = 0; - for file in self.0.iter() { - if file.suffix == "down.sql" { - down_cnt += 1; - } else if file.suffix == "up.sql" { - up_cnt += 1; - } else { - panic!("unknown suffix for {file:?}"); - } - assert!(file.description.starts_with("hello_world")); - } - assert_eq!(up_cnt, down_cnt); - } - fn assert_is_not_reversible(&self) { - for file in self.0.iter() { - assert_eq!(file.suffix, "sql"); - assert!(file.description.starts_with("hello_world")); - } - } +#[test] +fn add_migration_config_default_type_reversible() -> anyhow::Result<()> { + let files = AddMigrations::new()? + .with_config("config_default_type_reversible.toml")? 
+ // Type should default to reversible without any flags + .run("hello world", false, false, false, true)? + .run("hello world2", false, false, false, true)? + .run("hello world3", false, false, false, true)? + .fs_output()?; + + assert_eq!(files.len(), 6); + files.assert_is_reversible(); + + files[0].assert_is_timestamp(); + assert_eq!(files[1].id, files[0].id); + + files[2].assert_is_timestamp(); + assert_eq!(files[3].id, files[2].id); + + files[4].assert_is_timestamp(); + assert_eq!(files[5].id, files[4].id); + + Ok(()) } -struct AddMigrations(TempDir); -impl AddMigrations { - fn new() -> anyhow::Result<Self> { - anyhow::Ok(Self(TempDir::new()?)) - } - fn run( - self, - description: &str, - revesible: bool, - timestamp: bool, - sequential: bool, - expect_success: bool, - ) -> anyhow::Result<Self> { - let cmd_result = Command::cargo_bin("cargo-sqlx")? - .current_dir(&self.0) - .args( - [ - vec!["sqlx", "migrate", "add", description], - match revesible { - true => vec!["-r"], - false => vec![], - }, - match timestamp { - true => vec!["--timestamp"], - false => vec![], - }, - match sequential { - true => vec!["--sequential"], - false => vec![], - }, - ] - .concat(), - ) - .assert(); - if expect_success { - cmd_result.success(); - } else { - cmd_result.failure(); - } - anyhow::Ok(self) - } - fn fs_output(&self) -> anyhow::Result<AddMigrationsResult> { - let files = recurse_files(&self.0)?; - let mut fs_paths = Vec::with_capacity(files.len()); - for path in files { - let relative_path = path.strip_prefix(self.0.path())?.to_path_buf(); - fs_paths.push(FileName::from(relative_path)); - } - Ok(AddMigrationsResult(fs_paths)) - } +#[test] +fn add_migration_config_default_versioning_sequential() -> anyhow::Result<()> { + let files = AddMigrations::new()? + .with_config("config_default_versioning_sequential.toml")? + // Versioning should default to sequential without any flags + .run("hello world", false, false, false, true)? + .run("hello world2", false, false, false, true)? + .run("hello world3", false, false, false, true)? + .fs_output()?; + + assert_eq!(files.len(), 3); + files.assert_is_not_reversible(); + + assert_eq!(files[0].id, 1); + assert_eq!(files[1].id, 2); + assert_eq!(files[2].id, 3); + + Ok(()) } -fn recurse_files(path: impl AsRef<Path>) -> anyhow::Result<Vec<PathBuf>> { - let mut buf = vec![]; - let entries = read_dir(path)?; +#[test] +fn add_migration_config_default_versioning_timestamp() -> anyhow::Result<()> { + let migrations = AddMigrations::new()?; - for entry in entries { - let entry = entry?; - let meta = entry.metadata()?; + migrations + .run("hello world", false, false, true, true)? + // Default config should infer sequential even without passing `--sequential` + .run("hello world2", false, false, false, true)? + .run("hello world3", false, false, false, true)?; - if meta.is_dir() { - let mut subdir = recurse_files(entry.path())?; - buf.append(&mut subdir); - } + let files = migrations.fs_output()?; - if meta.is_file() { - buf.push(entry.path()); - } - } - buf.sort(); - Ok(buf) + assert_eq!(files.len(), 3); + files.assert_is_not_reversible(); + + assert_eq!(files[0].id, 1); + assert_eq!(files[1].id, 2); + assert_eq!(files[2].id, 3); + + // Now set a config that uses `migration-versioning = "timestamp"` + let migrations = migrations.with_config("config_default_versioning_timestamp.toml")?; + + // Now the default should be a timestamp + migrations + .run("hello world4", false, false, false, true)?
+ .run("hello world5", false, false, false, true)?; + + let files = migrations.fs_output()?; + + assert_eq!(files.len(), 5); + files.assert_is_not_reversible(); + + assert_eq!(files[0].id, 1); + assert_eq!(files[1].id, 2); + assert_eq!(files[2].id, 3); + + files[3].assert_is_timestamp(); + files[4].assert_is_timestamp(); + + Ok(()) } diff --git a/sqlx-cli/tests/assets/config_default_type_reversible.toml b/sqlx-cli/tests/assets/config_default_type_reversible.toml new file mode 100644 index 0000000000..79d7de0b65 --- /dev/null +++ b/sqlx-cli/tests/assets/config_default_type_reversible.toml @@ -0,0 +1,2 @@ +[migrate.defaults] +migration-type = "reversible" \ No newline at end of file diff --git a/sqlx-cli/tests/assets/config_default_versioning_sequential.toml b/sqlx-cli/tests/assets/config_default_versioning_sequential.toml new file mode 100644 index 0000000000..8cf275c2e1 --- /dev/null +++ b/sqlx-cli/tests/assets/config_default_versioning_sequential.toml @@ -0,0 +1,2 @@ +[migrate.defaults] +migration-versioning = "sequential" \ No newline at end of file diff --git a/sqlx-cli/tests/assets/config_default_versioning_timestamp.toml b/sqlx-cli/tests/assets/config_default_versioning_timestamp.toml new file mode 100644 index 0000000000..15892dc1ca --- /dev/null +++ b/sqlx-cli/tests/assets/config_default_versioning_timestamp.toml @@ -0,0 +1,2 @@ +[migrate.defaults] +migration-versioning = "timestamp" \ No newline at end of file diff --git a/sqlx-cli/tests/common/mod.rs b/sqlx-cli/tests/common/mod.rs index 43c0dbc1e1..66e7924859 100644 --- a/sqlx-cli/tests/common/mod.rs +++ b/sqlx-cli/tests/common/mod.rs @@ -1,25 +1,41 @@ use assert_cmd::{assert::Assert, Command}; +use sqlx::_unstable::config::Config; use sqlx::{migrate::Migrate, Connection, SqliteConnection}; use std::{ - env::temp_dir, - fs::remove_file, + env, fs, path::{Path, PathBuf}, }; pub struct TestDatabase { file_path: PathBuf, - migrations: String, + migrations_path: PathBuf, + pub config_path: Option, } impl TestDatabase { pub fn new(name: &str, migrations: &str) -> Self { - let migrations_path = Path::new("tests").join(migrations); - let file_path = Path::new(&temp_dir()).join(format!("test-{}.db", name)); - let ret = Self { + // Note: only set when _building_ + let temp_dir = option_env!("CARGO_TARGET_TMPDIR").map_or_else(env::temp_dir, PathBuf::from); + + let test_dir = temp_dir.join("migrate"); + + fs::create_dir_all(&test_dir) + .unwrap_or_else(|e| panic!("error creating directory: {test_dir:?}: {e}")); + + let file_path = test_dir.join(format!("test-{name}.db")); + + if file_path.exists() { + fs::remove_file(&file_path) + .unwrap_or_else(|e| panic!("error deleting test database {file_path:?}: {e}")); + } + + let this = Self { file_path, - migrations: String::from(migrations_path.to_str().unwrap()), + migrations_path: Path::new("tests").join(migrations), + config_path: None, }; + Command::cargo_bin("cargo-sqlx") .unwrap() .args([ @@ -27,11 +43,15 @@ impl TestDatabase { "database", "create", "--database-url", - &ret.connection_string(), + &this.connection_string(), ]) .assert() .success(); - ret + this + } + + pub fn set_migrations(&mut self, migrations: &str) { + self.migrations_path = Path::new("tests").join(migrations); } pub fn connection_string(&self) -> String { @@ -39,55 +59,77 @@ impl TestDatabase { } pub fn run_migration(&self, revert: bool, version: Option, dry_run: bool) -> Assert { - let ver = match version { - Some(v) => v.to_string(), - None => String::from(""), - }; - Command::cargo_bin("cargo-sqlx") - .unwrap() - 
.args( - [ - vec![ - "sqlx", - "migrate", - match revert { - true => "revert", - false => "run", - }, - "--database-url", - &self.connection_string(), - "--source", - &self.migrations, - ], - match version { - Some(_) => vec!["--target-version", &ver], - None => vec![], - }, - match dry_run { - true => vec!["--dry-run"], - false => vec![], - }, - ] - .concat(), - ) - .assert() + let mut command = Command::cargo_bin("sqlx").unwrap(); + command + .args([ + "migrate", + match revert { + true => "revert", + false => "run", + }, + "--database-url", + &self.connection_string(), + "--source", + ]) + .arg(&self.migrations_path); + + if let Some(config_path) = &self.config_path { + command.arg("--config").arg(config_path); + } + + if let Some(version) = version { + command.arg("--target-version").arg(version.to_string()); + } + + if dry_run { + command.arg("--dry-run"); + } + + command.assert() } pub async fn applied_migrations(&self) -> Vec { let mut conn = SqliteConnection::connect(&self.connection_string()) .await .unwrap(); - conn.list_applied_migrations() + + let config = Config::default(); + + conn.list_applied_migrations(config.migrate.table_name()) .await .unwrap() .iter() .map(|m| m.version) .collect() } + + pub fn migrate_info(&self) -> Assert { + let mut command = Command::cargo_bin("sqlx").unwrap(); + command + .args([ + "migrate", + "info", + "--database-url", + &self.connection_string(), + "--source", + ]) + .arg(&self.migrations_path); + + if let Some(config_path) = &self.config_path { + command.arg("--config").arg(config_path); + } + + command.assert() + } } impl Drop for TestDatabase { fn drop(&mut self) { - remove_file(&self.file_path).unwrap(); + // Only remove the database if there isn't a failure. + if !std::thread::panicking() { + fs::remove_file(&self.file_path).unwrap_or_else(|e| { + panic!("error deleting test database {:?}: {e}", self.file_path) + }); + } } } diff --git a/sqlx-cli/tests/ignored-chars/BOM/.gitattributes b/sqlx-cli/tests/ignored-chars/BOM/.gitattributes new file mode 100644 index 0000000000..cc2d335b83 --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/BOM/.gitattributes @@ -0,0 +1 @@ +*.sql text eol=lf diff --git a/sqlx-cli/tests/ignored-chars/BOM/1_user.sql b/sqlx-cli/tests/ignored-chars/BOM/1_user.sql new file mode 100644 index 0000000000..166ce39ca8 --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/BOM/1_user.sql @@ -0,0 +1,6 @@ +create table user +( + -- integer primary keys are the most efficient in SQLite + user_id integer primary key, + username text unique not null +); diff --git a/sqlx-cli/tests/ignored-chars/BOM/2_post.sql b/sqlx-cli/tests/ignored-chars/BOM/2_post.sql new file mode 100644 index 0000000000..a65420a57d --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/BOM/2_post.sql @@ -0,0 +1,10 @@ +create table post +( + post_id integer primary key, + user_id integer not null references user (user_id), + content text not null, + -- Defaults have to be wrapped in parenthesis + created_at datetime default (datetime('now')) +); + +create index post_created_at on post (created_at desc); diff --git a/sqlx-cli/tests/ignored-chars/BOM/3_comment.sql b/sqlx-cli/tests/ignored-chars/BOM/3_comment.sql new file mode 100644 index 0000000000..cc02ae3f0b --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/BOM/3_comment.sql @@ -0,0 +1,10 @@ +create table comment +( + comment_id integer primary key, + post_id integer not null references post (post_id), + user_id integer not null references "user" (user_id), + content text not null, + created_at datetime default 
(datetime('now')) +); + +create index comment_created_at on comment (created_at desc); diff --git a/sqlx-cli/tests/ignored-chars/CRLF/.gitattributes b/sqlx-cli/tests/ignored-chars/CRLF/.gitattributes new file mode 100644 index 0000000000..5645bd9e1a --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/CRLF/.gitattributes @@ -0,0 +1 @@ +*.sql text eol=crlf diff --git a/sqlx-cli/tests/ignored-chars/CRLF/1_user.sql b/sqlx-cli/tests/ignored-chars/CRLF/1_user.sql new file mode 100644 index 0000000000..100b750f19 --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/CRLF/1_user.sql @@ -0,0 +1,6 @@ +create table user +( + -- integer primary keys are the most efficient in SQLite + user_id integer primary key, + username text unique not null +); diff --git a/sqlx-cli/tests/ignored-chars/CRLF/2_post.sql b/sqlx-cli/tests/ignored-chars/CRLF/2_post.sql new file mode 100644 index 0000000000..74d2460596 --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/CRLF/2_post.sql @@ -0,0 +1,10 @@ +create table post +( + post_id integer primary key, + user_id integer not null references user (user_id), + content text not null, + -- Defaults have to be wrapped in parenthesis + created_at datetime default (datetime('now')) +); + +create index post_created_at on post (created_at desc); diff --git a/sqlx-cli/tests/ignored-chars/CRLF/3_comment.sql b/sqlx-cli/tests/ignored-chars/CRLF/3_comment.sql new file mode 100644 index 0000000000..a98b2628fc --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/CRLF/3_comment.sql @@ -0,0 +1,10 @@ +create table comment +( + comment_id integer primary key, + post_id integer not null references post (post_id), + user_id integer not null references "user" (user_id), + content text not null, + created_at datetime default (datetime('now')) +); + +create index comment_created_at on comment (created_at desc); diff --git a/sqlx-cli/tests/ignored-chars/LF/.gitattributes b/sqlx-cli/tests/ignored-chars/LF/.gitattributes new file mode 100644 index 0000000000..cc2d335b83 --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/LF/.gitattributes @@ -0,0 +1 @@ +*.sql text eol=lf diff --git a/sqlx-cli/tests/ignored-chars/LF/1_user.sql b/sqlx-cli/tests/ignored-chars/LF/1_user.sql new file mode 100644 index 0000000000..100b750f19 --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/LF/1_user.sql @@ -0,0 +1,6 @@ +create table user +( + -- integer primary keys are the most efficient in SQLite + user_id integer primary key, + username text unique not null +); diff --git a/sqlx-cli/tests/ignored-chars/LF/2_post.sql b/sqlx-cli/tests/ignored-chars/LF/2_post.sql new file mode 100644 index 0000000000..74d2460596 --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/LF/2_post.sql @@ -0,0 +1,10 @@ +create table post +( + post_id integer primary key, + user_id integer not null references user (user_id), + content text not null, + -- Defaults have to be wrapped in parenthesis + created_at datetime default (datetime('now')) +); + +create index post_created_at on post (created_at desc); diff --git a/sqlx-cli/tests/ignored-chars/LF/3_comment.sql b/sqlx-cli/tests/ignored-chars/LF/3_comment.sql new file mode 100644 index 0000000000..a98b2628fc --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/LF/3_comment.sql @@ -0,0 +1,10 @@ +create table comment +( + comment_id integer primary key, + post_id integer not null references post (post_id), + user_id integer not null references "user" (user_id), + content text not null, + created_at datetime default (datetime('now')) +); + +create index comment_created_at on comment (created_at desc); diff --git 
a/sqlx-cli/tests/ignored-chars/oops-all-tabs/.gitattributes b/sqlx-cli/tests/ignored-chars/oops-all-tabs/.gitattributes new file mode 100644 index 0000000000..cc2d335b83 --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/oops-all-tabs/.gitattributes @@ -0,0 +1 @@ +*.sql text eol=lf diff --git a/sqlx-cli/tests/ignored-chars/oops-all-tabs/1_user.sql b/sqlx-cli/tests/ignored-chars/oops-all-tabs/1_user.sql new file mode 100644 index 0000000000..0120c304e6 --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/oops-all-tabs/1_user.sql @@ -0,0 +1,6 @@ +create table user +( + -- integer primary keys are the most efficient in SQLite + user_id integer primary key, + username text unique not null +); diff --git a/sqlx-cli/tests/ignored-chars/oops-all-tabs/2_post.sql b/sqlx-cli/tests/ignored-chars/oops-all-tabs/2_post.sql new file mode 100644 index 0000000000..436028bbc0 --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/oops-all-tabs/2_post.sql @@ -0,0 +1,10 @@ +create table post +( + post_id integer primary key, + user_id integer not null references user (user_id), + content text not null, + -- Defaults have to be wrapped in parenthesis + created_at datetime default (datetime('now')) +); + +create index post_created_at on post (created_at desc); diff --git a/sqlx-cli/tests/ignored-chars/oops-all-tabs/3_comment.sql b/sqlx-cli/tests/ignored-chars/oops-all-tabs/3_comment.sql new file mode 100644 index 0000000000..2cdf347472 --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/oops-all-tabs/3_comment.sql @@ -0,0 +1,10 @@ +create table comment +( + comment_id integer primary key, + post_id integer not null references post (post_id), + user_id integer not null references "user" (user_id), + content text not null, + created_at datetime default (datetime('now')) +); + +create index comment_created_at on comment (created_at desc); diff --git a/sqlx-cli/tests/ignored-chars/sqlx.toml b/sqlx-cli/tests/ignored-chars/sqlx.toml new file mode 100644 index 0000000000..e5278d283f --- /dev/null +++ b/sqlx-cli/tests/ignored-chars/sqlx.toml @@ -0,0 +1,7 @@ +[migrate] +# Ignore common whitespace characters (beware syntactically significant whitespace!) +# Space, tab, CR, LF, zero-width non-breaking space (U+FEFF) +# +# U+FEFF is added by some editors as a magic number at the beginning of a text file indicating it is UTF-8 encoded, +# where it is known as a byte-order mark (BOM): https://en.wikipedia.org/wiki/Byte_order_mark +ignored-chars = [" ", "\t", "\r", "\n", "\uFEFF"] diff --git a/sqlx-cli/tests/migrate.rs b/sqlx-cli/tests/migrate.rs index 0ea9d4620d..f33ee5eb0e 100644 --- a/sqlx-cli/tests/migrate.rs +++ b/sqlx-cli/tests/migrate.rs @@ -13,16 +13,13 @@ async fn run_reversible_migrations() { ]; // Without --target-version specified. { - let db = TestDatabase::new("migrate_run_reversible_latest", "migrations_reversible"); + let db = TestDatabase::new("run_reversible_latest", "migrations_reversible"); db.run_migration(false, None, false).success(); assert_eq!(db.applied_migrations().await, all_migrations); } // With --target-version specified. { - let db = TestDatabase::new( - "migrate_run_reversible_latest_explicit", - "migrations_reversible", - ); + let db = TestDatabase::new("run_reversible_latest_explicit", "migrations_reversible"); // Move to latest, explicitly specified. db.run_migration(false, Some(20230501000000), false) @@ -41,10 +38,7 @@ } // With --target-version, incrementally upgrade.
{ - let db = TestDatabase::new( - "migrate_run_reversible_incremental", - "migrations_reversible", - ); + let db = TestDatabase::new("run_reversible_incremental", "migrations_reversible"); // First version db.run_migration(false, Some(20230101000000), false) @@ -92,7 +86,7 @@ async fn revert_migrations() { // Without --target-version { - let db = TestDatabase::new("migrate_revert_incremental", "migrations_reversible"); + let db = TestDatabase::new("revert_incremental", "migrations_reversible"); db.run_migration(false, None, false).success(); // Dry-run @@ -109,7 +103,7 @@ } // With --target-version { - let db = TestDatabase::new("migrate_revert_incremental", "migrations_reversible"); + let db = TestDatabase::new("revert_incremental", "migrations_reversible"); db.run_migration(false, None, false).success(); // Dry-run downgrade to version 3. @@ -142,6 +136,32 @@ // Downgrade to zero. db.run_migration(true, Some(0), false).success(); - assert_eq!(db.applied_migrations().await, vec![] as Vec<i64>); + assert_eq!(db.applied_migrations().await, Vec::<i64>::new()); } } + +#[tokio::test] +async fn ignored_chars() { + let mut db = TestDatabase::new("ignored-chars", "ignored-chars/LF"); + db.config_path = Some("tests/ignored-chars/sqlx.toml".into()); + + db.run_migration(false, None, false).success(); + + db.set_migrations("ignored-chars/CRLF"); + + let expected_info = "1/installed user\n2/installed post\n3/installed comment\n"; + + // `ignored-chars` should produce the same migration checksum here + db.migrate_info().success().stdout(expected_info); + + // Running migration should be a no-op + db.run_migration(false, None, false).success().stdout(""); + + db.set_migrations("ignored-chars/BOM"); + db.migrate_info().success().stdout(expected_info); + db.run_migration(false, None, false).success().stdout(""); + + db.set_migrations("ignored-chars/oops-all-tabs"); + db.migrate_info().success().stdout(expected_info); + db.run_migration(false, None, false).success().stdout(""); +} diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 51b82fa68e..d64764ca0a 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -32,6 +32,14 @@ _tls-none = [] # support offline/decoupled building (enables serialization of `Describe`) offline = ["serde", "either/serde"] +# Enable parsing of `sqlx.toml`. +# For simplicity, the `config` module is always enabled, +# but disabling this disables the `serde` derives and the `toml` crate, +# which is a good bit less code to compile if the feature isn't being used.
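+#
+# Downstream crates opt in explicitly, e.g. in their own Cargo.toml
+# (illustrative snippet; the version and the runtime/driver features shown are assumptions):
+#
+#     sqlx = { version = "0.9.0-alpha.1", features = ["runtime-tokio", "postgres", "sqlx-toml"] }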
+sqlx-toml = ["serde", "toml/parse"] + +_unstable-doc = ["sqlx-toml"] + [dependencies] # Runtimes async-std = { workspace = true, optional = true } @@ -72,6 +80,7 @@ percent-encoding = "2.1.0" regex = { version = "1.5.5", optional = true } serde = { version = "1.0.132", features = ["derive", "rc"], optional = true } serde_json = { version = "1.0.73", features = ["raw_value"], optional = true } +toml = { version = "0.8.16", optional = true } sha2 = { version = "0.10.0", default-features = false, optional = true } #sqlformat = "0.2.0" thiserror = "2.0.0" diff --git a/sqlx-core/src/any/migrate.rs b/sqlx-core/src/any/migrate.rs index cb4f72c340..69b5bf6ab6 100644 --- a/sqlx-core/src/any/migrate.rs +++ b/sqlx-core/src/any/migrate.rs @@ -44,18 +44,44 @@ impl MigrateDatabase for Any { } impl Migrate for AnyConnection { - fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { - Box::pin(async { self.get_migrate()?.ensure_migrations_table().await }) + fn create_schema_if_not_exists<'e>( + &'e mut self, + schema_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async { + self.get_migrate()? + .create_schema_if_not_exists(schema_name) + .await + }) } - fn dirty_version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { - Box::pin(async { self.get_migrate()?.dirty_version().await }) + fn ensure_migrations_table<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async { + self.get_migrate()? + .ensure_migrations_table(table_name) + .await + }) } - fn list_applied_migrations( - &mut self, - ) -> BoxFuture<'_, Result, MigrateError>> { - Box::pin(async { self.get_migrate()?.list_applied_migrations().await }) + fn dirty_version<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { + Box::pin(async { self.get_migrate()?.dirty_version(table_name).await }) + } + + fn list_applied_migrations<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { + Box::pin(async { + self.get_migrate()? + .list_applied_migrations(table_name) + .await + }) } fn lock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { @@ -66,17 +92,19 @@ impl Migrate for AnyConnection { Box::pin(async { self.get_migrate()?.unlock().await }) } - fn apply<'e: 'm, 'm>( + fn apply<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result> { - Box::pin(async { self.get_migrate()?.apply(migration).await }) + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { + Box::pin(async { self.get_migrate()?.apply(table_name, migration).await }) } - fn revert<'e: 'm, 'm>( + fn revert<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result> { - Box::pin(async { self.get_migrate()?.revert(migration).await }) + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { + Box::pin(async { self.get_migrate()?.revert(table_name, migration).await }) } } diff --git a/sqlx-core/src/column.rs b/sqlx-core/src/column.rs index 9f45819ed6..fddc048c4b 100644 --- a/sqlx-core/src/column.rs +++ b/sqlx-core/src/column.rs @@ -2,6 +2,7 @@ use crate::database::Database; use crate::error::Error; use std::fmt::Debug; +use std::sync::Arc; pub trait Column: 'static + Send + Sync + Debug { type Database: Database; @@ -20,6 +21,61 @@ pub trait Column: 'static + Send + Sync + Debug { /// Gets the type information for the column. 
fn type_info(&self) -> &<Self::Database as Database>::TypeInfo; + + /// If this column comes from a table, return the table and original column name. + /// + /// Returns [`ColumnOrigin::Expression`] if the column is the result of an expression, + /// or if the source table could not be determined. + /// + /// Returns [`ColumnOrigin::Unknown`] if the database driver does not have that information, + /// or has not overridden this method. + // This method returns an owned value instead of a reference, + // to give the implementor more flexibility. + fn origin(&self) -> ColumnOrigin { + ColumnOrigin::Unknown + } +} + +/// A [`Column`] that originates from a table. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub struct TableColumn { + /// The name of the table (optionally schema-qualified) that the column comes from. + pub table: Arc<str>, + /// The original name of the column. + pub name: Arc<str>, +} + +/// The possible statuses for our knowledge of the origin of a [`Column`]. +#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub enum ColumnOrigin { + /// The column is known to originate from a table. + /// + /// Included is the table name and original column name. + Table(TableColumn), + /// The column originates from an expression, or else its origin could not be determined. + Expression, + /// The database driver does not know the column origin at this time. + /// + /// This may happen if: + /// * The connection is in the middle of executing a query, + /// and cannot query the catalog to fetch this information. + /// * The connection does not have access to the database catalog. + /// * The implementation of [`Column`] did not override [`Column::origin()`]. + #[default] + Unknown, +} + +impl ColumnOrigin { + /// Returns the true column origin, if known. + pub fn table_column(&self) -> Option<&TableColumn> { + if let Self::Table(table_column) = self { + Some(table_column) + } else { + None + } + } } /// A type that can be used to index into a [`Row`] or [`Statement`]. diff --git a/sqlx-core/src/config/common.rs b/sqlx-core/src/config/common.rs new file mode 100644 index 0000000000..2d5342d5b8 --- /dev/null +++ b/sqlx-core/src/config/common.rs @@ -0,0 +1,49 @@ +/// Configuration shared by multiple components. +#[derive(Debug, Default)] +#[cfg_attr( + feature = "sqlx-toml", + derive(serde::Deserialize), + serde(default, rename_all = "kebab-case", deny_unknown_fields) +)] +pub struct Config { + /// Override the database URL environment variable. + /// + /// This is used by both the macros and `sqlx-cli`. + /// + /// Case-sensitive. Defaults to `DATABASE_URL`. + /// + /// Example: Multi-Database Project + /// ------- + /// You can use multiple databases in the same project by breaking it up into multiple crates, + /// then using a different environment variable for each. + /// + /// For example, with two crates in the workspace named `foo` and `bar`: + /// + /// #### `foo/sqlx.toml` + /// ```toml + /// [common] + /// database-url-var = "FOO_DATABASE_URL" + /// ``` + /// + /// #### `bar/sqlx.toml` + /// ```toml + /// [common] + /// database-url-var = "BAR_DATABASE_URL" + /// ``` + /// + /// #### `.env` + /// ```text + /// FOO_DATABASE_URL=postgres://postgres@localhost:5432/foo + /// BAR_DATABASE_URL=postgres://postgres@localhost:5432/bar + /// ``` + /// + /// The query macros used in `foo` will use `FOO_DATABASE_URL`, + /// and the ones used in `bar` will use `BAR_DATABASE_URL`.
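+ ///
+ /// A sketch of how `sqlx-cli` resolves this at runtime (mirroring `ConnectOpts::populate_db_url` elsewhere in this diff; error handling elided):
+ ///
+ /// ```rust,ignore
+ /// let var = config.common.database_url_var(); // e.g. "FOO_DATABASE_URL"; "DATABASE_URL" if unset
+ /// let url = std::env::var(var)?;
+ /// ```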
+ pub database_url_var: Option<String>, +} + +impl Config { + pub fn database_url_var(&self) -> &str { + self.database_url_var.as_deref().unwrap_or("DATABASE_URL") + } +} diff --git a/sqlx-core/src/config/macros.rs b/sqlx-core/src/config/macros.rs new file mode 100644 index 0000000000..6d08aa3ec2 --- /dev/null +++ b/sqlx-core/src/config/macros.rs @@ -0,0 +1,418 @@ +use std::collections::BTreeMap; + +/// Configuration for the `query!()` family of macros. +/// +/// See also [`common::Config`][crate::config::common::Config] for renaming `DATABASE_URL`. +#[derive(Debug, Default)] +#[cfg_attr( + feature = "sqlx-toml", + derive(serde::Deserialize), + serde(default, rename_all = "kebab-case", deny_unknown_fields) +)] +pub struct Config { + /// Specify which crates' types to use when types from multiple crates apply. + /// + /// See [`PreferredCrates`] for details. + pub preferred_crates: PreferredCrates, + + /// Specify global overrides for mapping SQL type names to Rust type names. + /// + /// Default type mappings are defined by the database driver. + /// Refer to the `sqlx::types` module for details. + /// + /// ## Note: Case-Sensitive + /// Currently, the case of the type name MUST match the name SQLx knows it by. + /// Built-in types are spelled in all-uppercase to match SQL convention. + /// + /// However, user-created types in Postgres are all-lowercase unless quoted. + /// + /// ## Note: Orthogonal to Nullability + /// These overrides do not affect whether `query!()` decides to wrap a column in `Option<_>` + /// or not. They only override the inner type used. + /// + /// ## Note: Schema Qualification (Postgres) + /// Type names may be schema-qualified in Postgres. If so, the schema should be part + /// of the type string, e.g. `'foo.bar'` to reference type `bar` in schema `foo`. + /// + /// The schema and/or type name may additionally be quoted in the string + /// for a quoted identifier (see next section). + /// + /// Schema qualification should not be used for types in the search path. + /// + /// ## Note: Quoted Identifiers (Postgres) + /// Type names using [quoted identifiers in Postgres] must also be specified with quotes here. + /// + /// Note, however, that the TOML format parses away the outer pair of quotes, + /// so for quoted names in Postgres, double-quoting is necessary, + /// e.g. `'"Foo"'` for SQL type `"Foo"`. + /// + /// To reference a schema-qualified type with a quoted name, use double-quotes after the + /// dot, e.g. `'foo."Bar"'` to reference type `"Bar"` of schema `foo`, and vice versa for + /// quoted schema names. + /// + /// We recommend wrapping all type names in single quotes, as shown below, + /// to avoid confusion. + /// + /// MySQL/MariaDB and SQLite do not support custom types, so quoting type names should + /// never be necessary. + /// + /// [quoted identifiers in Postgres]: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS + // Note: we wanted to be able to handle this intelligently, + // but the `toml` crate authors weren't interested: https://github.com/toml-rs/toml/issues/761 + // + // We decided to just encourage always quoting type names instead. + /// Example: Custom Wrapper Types + /// ------- + /// Does SQLx not support a type that you need? Do you want additional semantics not + /// implemented on the built-in types? You can create a custom wrapper, + /// or use an external crate.
+ /// + /// #### `sqlx.toml` + /// ```toml + /// [macros.type-overrides] + /// # Override a built-in type + /// 'UUID' = "crate::types::MyUuid" + /// + /// # Support an external or custom wrapper type (e.g. from the `isn` Postgres extension) + /// # (NOTE: FOR DOCUMENTATION PURPOSES ONLY; THIS CRATE/TYPE DOES NOT EXIST AS OF WRITING) + /// 'isbn13' = "isn_rs::sqlx::ISBN13" + /// ``` + /// + /// Example: Custom Types in Postgres + /// ------- + /// If you have a custom type in Postgres that you want to map without needing to use + /// the type override syntax in `sqlx::query!()` every time, you can specify a global + /// override here. + /// + /// For example, a custom enum type `foo`: + /// + /// #### Migration or Setup SQL (e.g. `migrations/0_setup.sql`) + /// ```sql + /// CREATE TYPE foo AS ENUM ('Bar', 'Baz'); + /// ``` + /// + /// #### `src/types.rs` + /// ```rust,no_run + /// #[derive(sqlx::Type)] + /// pub enum Foo { + /// Bar, + /// Baz + /// } + /// ``` + /// + /// If you're not using `PascalCase` in your enum variants then you'll want to use + /// `#[sqlx(rename_all = "<strategy>")]` on your enum. + /// See [`Type`][crate::types::Type] for details. + /// + /// #### `sqlx.toml` + /// ```toml + /// [macros.type-overrides] + /// # Map SQL type `foo` to `crate::types::Foo` + /// 'foo' = "crate::types::Foo" + /// ``` + /// + /// Example: Schema-Qualified Types + /// ------- + /// (See `Note` section above for details.) + /// + /// ```toml + /// [macros.type-overrides] + /// # Map SQL type `foo.foo` to `crate::types::Foo` + /// 'foo.foo' = "crate::types::Foo" + /// ``` + /// + /// Example: Quoted Identifiers + /// ------- + /// If a type or schema uses quoted identifiers, + /// it must be wrapped in quotes _twice_ for SQLx to know the difference: + /// + /// ```toml + /// [macros.type-overrides] + /// # `"Foo"` in SQLx + /// '"Foo"' = "crate::types::Foo" + /// # **NOT** `"Foo"` in SQLx (parses as just `Foo`) + /// "Foo" = "crate::types::Foo" + /// + /// # Schema-qualified + /// '"foo".foo' = "crate::types::Foo" + /// 'foo."Foo"' = "crate::types::Foo" + /// '"foo"."Foo"' = "crate::types::Foo" + /// ``` + /// + /// (See `Note` section above for details.) + // TODO: allow specifying different types for input vs output + // e.g. to accept `&[T]` on input but output `Vec<T>` + pub type_overrides: BTreeMap<SqlType, RustType>, + + /// Specify per-table and per-column overrides for mapping SQL types to Rust types. + /// + /// Default type mappings are defined by the database driver. + /// Refer to the `sqlx::types` module for details. + /// + /// The supported syntax is similar to [`type_overrides`][Self::type_overrides], + /// (with the same caveat for quoted names!) but column names must be qualified + /// by a separately quoted table name, which may optionally be schema-qualified. + /// + /// Multiple columns for the same SQL table may be written in the same table in TOML + /// (see examples below). + /// + /// ## Note: Orthogonal to Nullability + /// These overrides do not affect whether `query!()` decides to wrap a column in `Option<_>` + /// or not. They only override the inner type used. + /// + /// ## Note: Schema Qualification + /// Table names may be schema-qualified. If so, the schema should be part + /// of the table name string, e.g. `'foo.bar'` to reference table `bar` in schema `foo`. + /// + /// The schema and/or type name may additionally be quoted in the string + /// for a quoted identifier (see next section). + /// + /// Postgres users: schema qualification should not be used for tables in the search path.
+ /// + /// ## Note: Quoted Identifiers + /// Schema, table, or column names using quoted identifiers ([MySQL], [Postgres], [SQLite]) + /// in SQL must also be specified with quotes here. + /// + /// Postgres and SQLite use double-quotes (`"Foo"`) while MySQL uses backticks (`` `Foo` ``). + /// + /// Note, however, that the TOML format parses away the outer pair of quotes, + /// so for quoted names in Postgres, double-quoting is necessary, + /// e.g. `'"Foo"'` for SQL name `"Foo"`. + /// + /// To reference a schema-qualified table with a quoted name, use the appropriate quotation + /// characters after the dot, e.g. `'foo."Bar"'` to reference table `"Bar"` of schema `foo`, + /// and vice versa for quoted schema names. + /// + /// We recommend wrapping all table and column names in single quotes, as shown below, + /// to avoid confusion. + /// + /// [MySQL]: https://dev.mysql.com/doc/refman/8.4/en/identifiers.html + /// [Postgres]: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS + /// [SQLite]: https://sqlite.org/lang_keywords.html + // Note: we wanted to be able to handle this intelligently, + // but the `toml` crate authors weren't interested: https://github.com/toml-rs/toml/issues/761 + // + // We decided to just encourage always quoting type names instead. + /// + /// Example + /// ------- + /// + /// #### `sqlx.toml` + /// ```toml + /// [macros.table-overrides.'foo'] + /// # Map column `bar` of table `foo` to Rust type `crate::types::Foo`: + /// 'bar' = "crate::types::Bar" + /// + /// # Quoted column name + /// # Note: same quoting requirements as `macros.type_overrides` + /// '"Bar"' = "crate::types::Bar" + /// + /// # Note: will NOT work (parses as `Bar`) + /// # "Bar" = "crate::types::Bar" + /// + /// # Table name may be quoted (note the wrapping single-quotes) + /// [macros.table-overrides.'"Foo"'] + /// 'bar' = "crate::types::Bar" + /// '"Bar"' = "crate::types::Bar" + /// + /// # Table name may also be schema-qualified. + /// # Note how the dot is inside the quotes. + /// [macros.table-overrides.'my_schema.my_table'] + /// 'my_column' = "crate::types::MyType" + /// + /// # Quoted schema, table, and column names + /// [macros.table-overrides.'"My Schema"."My Table"'] + /// '"My Column"' = "crate::types::MyType" + /// ``` + pub table_overrides: BTreeMap<TableName, BTreeMap<ColumnName, RustType>>, +} + +#[derive(Debug, Default)] +#[cfg_attr( + feature = "sqlx-toml", + derive(serde::Deserialize), + serde(default, rename_all = "kebab-case") +)] +pub struct PreferredCrates { + /// Specify the crate to use for mapping date/time types to Rust. + /// + /// The default behavior is to use whatever crate is enabled, + /// [`chrono`] or [`time`] (the latter takes precedence). + /// + /// [`chrono`]: crate::types::chrono + /// [`time`]: crate::types::time + /// + /// Example: Always Use Chrono + /// ------- + /// Thanks to Cargo's [feature unification], a crate in the dependency graph may enable + /// the `time` feature of SQLx which will force it on for all crates using SQLx, + /// which will result in problems if your crate wants to use types from [`chrono`]. + /// + /// You can use the type override syntax (see `sqlx::query!` for details), + /// or you can force an override globally by setting this option.
+    ///
+    /// #### `sqlx.toml`
+    /// ```toml
+    /// [macros.preferred-crates]
+    /// date-time = "chrono"
+    /// ```
+    ///
+    /// [feature unification]: https://doc.rust-lang.org/cargo/reference/features.html#feature-unification
+    pub date_time: DateTimeCrate,
+
+    /// Specify the crate to use for mapping `NUMERIC` types to Rust.
+    ///
+    /// The default behavior is to use whatever crate is enabled,
+    /// [`bigdecimal`] or [`rust_decimal`] (the latter takes precedence).
+    ///
+    /// [`bigdecimal`]: crate::types::bigdecimal
+    /// [`rust_decimal`]: crate::types::rust_decimal
+    ///
+    /// Example: Always Use `bigdecimal`
+    /// -------
+    /// Thanks to Cargo's [feature unification], a crate in the dependency graph may enable
+    /// the `rust_decimal` feature of SQLx, forcing it on for all crates using SQLx. This can
+    /// cause problems if your crate wants to use types from [`bigdecimal`].
+    ///
+    /// You can use the type override syntax (see `sqlx::query!` for details),
+    /// or you can force an override globally by setting this option.
+    ///
+    /// #### `sqlx.toml`
+    /// ```toml
+    /// [macros.preferred-crates]
+    /// numeric = "bigdecimal"
+    /// ```
+    ///
+    /// [feature unification]: https://doc.rust-lang.org/cargo/reference/features.html#feature-unification
+    pub numeric: NumericCrate,
+}
+
+/// The preferred crate to use for mapping date/time types to Rust.
+#[derive(Debug, Default, PartialEq, Eq)]
+#[cfg_attr(
+    feature = "sqlx-toml",
+    derive(serde::Deserialize),
+    serde(rename_all = "snake_case")
+)]
+pub enum DateTimeCrate {
+    /// Use whichever crate is enabled (`time` then `chrono`).
+    #[default]
+    Inferred,
+
+    /// Always use types from [`chrono`][crate::types::chrono].
+    ///
+    /// ```toml
+    /// [macros.preferred-crates]
+    /// date-time = "chrono"
+    /// ```
+    Chrono,
+
+    /// Always use types from [`time`][crate::types::time].
+    ///
+    /// ```toml
+    /// [macros.preferred-crates]
+    /// date-time = "time"
+    /// ```
+    Time,
+}
+
+/// The preferred crate to use for mapping `NUMERIC` types to Rust.
+#[derive(Debug, Default, PartialEq, Eq)]
+#[cfg_attr(
+    feature = "sqlx-toml",
+    derive(serde::Deserialize),
+    serde(rename_all = "snake_case")
+)]
+pub enum NumericCrate {
+    /// Use whichever crate is enabled (`rust_decimal` then `bigdecimal`).
+    #[default]
+    Inferred,
+
+    /// Always use types from [`bigdecimal`][crate::types::bigdecimal].
+    ///
+    /// ```toml
+    /// [macros.preferred-crates]
+    /// numeric = "bigdecimal"
+    /// ```
+    #[cfg_attr(feature = "sqlx-toml", serde(rename = "bigdecimal"))]
+    BigDecimal,
+
+    /// Always use types from [`rust_decimal`][crate::types::rust_decimal].
+    ///
+    /// ```toml
+    /// [macros.preferred-crates]
+    /// numeric = "rust_decimal"
+    /// ```
+    RustDecimal,
+}
+
+/// A SQL type name; may optionally be schema-qualified.
+///
+/// See [`macros.type-overrides`][Config::type_overrides] for usages.
+pub type SqlType = Box<str>;
+
+/// A SQL table name; may optionally be schema-qualified.
+///
+/// See [`macros.table-overrides`][Config::table_overrides] for usages.
+pub type TableName = Box<str>;
+
+/// A column in a SQL table.
+///
+/// See [`macros.table-overrides`][Config::table_overrides] for usages.
+pub type ColumnName = Box<str>;
+
+/// A Rust type name or path.
+///
+/// Should be a global path (not relative).
+pub type RustType = Box<str>;
+
+/// Internal getter methods.
+impl Config {
+    /// Get the override for a given type name (optionally schema-qualified).
+    pub fn type_override(&self, type_name: &str) -> Option<&str> {
+        // TODO: make this case-insensitive
+        self.type_overrides.get(type_name).map(|s| &**s)
+    }
+
+    /// Get the override for a given column and table name (optionally schema-qualified).
+    pub fn column_override(&self, table: &str, column: &str) -> Option<&str> {
+        self.table_overrides
+            .get(table)
+            .and_then(|by_column| by_column.get(column))
+            .map(|s| &**s)
+    }
+}
+
+impl DateTimeCrate {
+    /// Returns `self == Self::Inferred`.
+    #[inline(always)]
+    pub fn is_inferred(&self) -> bool {
+        *self == Self::Inferred
+    }
+
+    /// Returns the name of the preferred crate, or `None` if inferred.
+    #[inline(always)]
+    pub fn crate_name(&self) -> Option<&str> {
+        match self {
+            Self::Inferred => None,
+            Self::Chrono => Some("chrono"),
+            Self::Time => Some("time"),
+        }
+    }
+}
+
+impl NumericCrate {
+    /// Returns `self == Self::Inferred`.
+    #[inline(always)]
+    pub fn is_inferred(&self) -> bool {
+        *self == Self::Inferred
+    }
+
+    /// Returns the name of the preferred crate, or `None` if inferred.
+    #[inline(always)]
+    pub fn crate_name(&self) -> Option<&str> {
+        match self {
+            Self::Inferred => None,
+            Self::BigDecimal => Some("bigdecimal"),
+            Self::RustDecimal => Some("rust_decimal"),
+        }
+    }
+}
diff --git a/sqlx-core/src/config/migrate.rs b/sqlx-core/src/config/migrate.rs
new file mode 100644
index 0000000000..0dd6cc2257
--- /dev/null
+++ b/sqlx-core/src/config/migrate.rs
@@ -0,0 +1,212 @@
+use std::collections::BTreeSet;
+
+/// Configuration for migrations when executed using `sqlx::migrate!()` or through `sqlx-cli`.
+///
+/// ### Note
+/// A manually constructed [`Migrator`][crate::migrate::Migrator] will not be aware of these
+/// configuration options. We recommend using `sqlx::migrate!()` instead.
+///
+/// ### Warning: Potential Data Loss or Corruption!
+/// Many of these options, if changed after migrations are set up,
+/// can result in data loss or corruption of a production database
+/// if the proper precautions are not taken.
+///
+/// Be sure you know what you are doing and that you read all relevant documentation _thoroughly_.
+#[derive(Debug, Default)]
+#[cfg_attr(
+    feature = "sqlx-toml",
+    derive(serde::Deserialize),
+    serde(default, rename_all = "kebab-case", deny_unknown_fields)
+)]
+pub struct Config {
+    /// Specify the names of schemas to create if they don't already exist.
+    ///
+    /// This is done before checking the existence of the migrations table
+    /// (`_sqlx_migrations` or overridden `table_name` below) so that it may be placed in
+    /// one of these schemas.
+    ///
+    /// ### Example
+    /// `sqlx.toml`:
+    /// ```toml
+    /// [migrate]
+    /// create-schemas = ["foo"]
+    /// ```
+    pub create_schemas: BTreeSet<Box<str>>,
+
+    /// Override the name of the table used to track executed migrations.
+    ///
+    /// May be schema-qualified and/or contain quotes. Defaults to `_sqlx_migrations`.
+    ///
+    /// Potentially useful for multi-tenant databases.
+    ///
+    /// ### Warning: Potential Data Loss or Corruption!
+    /// Changing this option for a production database will likely result in data loss or corruption
+    /// as the migration machinery will no longer be aware of what migrations have been applied
+    /// and will attempt to re-run them.
+    ///
+    /// You should create the new table as a copy of the existing migrations table (with contents!),
+    /// and be sure all instances of your application have been migrated to the new
+    /// table before deleting the old one.
+    ///
+    /// ### Example
+    /// `sqlx.toml`:
+    /// ```toml
+    /// [migrate]
+    /// # Put `_sqlx_migrations` in schema `foo`
+    /// table-name = "foo._sqlx_migrations"
+    /// ```
+    pub table_name: Option<Box<str>>,
+
+    /// Override the directory used for migrations files.
+    ///
+    /// Relative to the crate root for `sqlx::migrate!()`, or the current directory for `sqlx-cli`.
+    pub migrations_dir: Option<Box<str>>,
+
+    /// Specify characters that should be ignored when hashing migrations.
+    ///
+    /// Any characters contained in the given array will be dropped when a migration is hashed.
+    ///
+    /// ### Warning: May Change Hashes for Existing Migrations
+    /// Changing the characters considered in hashing migrations will likely
+    /// change the output of the hash.
+    ///
+    /// This may require manual rectification for deployed databases.
+    ///
+    /// ### Example: Ignore Carriage Return (`<CR>` | `\r`)
+    /// Line ending differences between platforms can result in migrations having non-repeatable
+    /// hashes. The most common culprit is the carriage return (`<CR>` | `\r`), which Windows
+    /// uses in its line endings alongside line feed (`<LF>` | `\n`), often written `CRLF` or `\r\n`,
+    /// whereas Linux and macOS use only line feeds.
+    ///
+    /// `sqlx.toml`:
+    /// ```toml
+    /// [migrate]
+    /// ignored-chars = ["\r"]
+    /// ```
+    ///
+    /// For projects using Git, this can also be addressed using [`.gitattributes`]:
+    ///
+    /// ```text
+    /// # Force newlines in migrations to be line feeds on all platforms
+    /// migrations/*.sql text eol=lf
+    /// ```
+    ///
+    /// This may require resetting or re-checking out the migrations files to take effect.
+    ///
+    /// [`.gitattributes`]: https://git-scm.com/docs/gitattributes
+    ///
+    /// ### Example: Ignore all Whitespace Characters
+    /// To make your migrations amenable to reformatting, you may wish to tell SQLx to ignore
+    /// _all_ whitespace characters in migrations.
+    ///
+    /// ##### Warning: Beware Syntactically Significant Whitespace!
+    /// If your migrations use string literals or quoted identifiers which contain whitespace,
+    /// this configuration will cause the migration machinery to ignore some changes to these.
+    /// This may result in a mismatch between the development and production versions of
+    /// your database.
+    ///
+    /// `sqlx.toml`:
+    /// ```toml
+    /// [migrate]
+    /// # Ignore common whitespace characters when hashing
+    /// ignored-chars = [" ", "\t", "\r", "\n"] # Space, tab, CR, LF
+    /// ```
+    // Likely lower overhead for small sets than `HashSet`.
+    pub ignored_chars: BTreeSet<char>,
+
+    /// Specify default options for new migrations created with `sqlx migrate add`.
+    pub defaults: MigrationDefaults,
+}
+
+#[derive(Debug, Default)]
+#[cfg_attr(
+    feature = "sqlx-toml",
+    derive(serde::Deserialize),
+    serde(default, rename_all = "kebab-case")
+)]
+pub struct MigrationDefaults {
+    /// Specify the type of migration that `sqlx migrate add` should create by default.
+    ///
+    /// ### Example: Use Reversible Migrations by Default
+    /// `sqlx.toml`:
+    /// ```toml
+    /// [migrate.defaults]
+    /// migration-type = "reversible"
+    /// ```
+    pub migration_type: DefaultMigrationType,
+
+    /// Specify the default scheme that `sqlx migrate add` should use for version integers.
+    ///
+    /// ### Example: Use Sequential Versioning by Default
+    /// `sqlx.toml`:
+    /// ```toml
+    /// [migrate.defaults]
+    /// migration-versioning = "sequential"
+    /// ```
+    pub migration_versioning: DefaultVersioning,
+}
+
+/// The type of migration that `sqlx migrate add` should create by default.
+#[derive(Debug, Default, PartialEq, Eq)]
+#[cfg_attr(
+    feature = "sqlx-toml",
+    derive(serde::Deserialize),
+    serde(rename_all = "snake_case")
+)]
+pub enum DefaultMigrationType {
+    /// Create the same migration type as that of the latest existing migration,
+    /// or `Simple` otherwise.
+    #[default]
+    Inferred,
+
+    /// Create non-reversible migrations (`<VERSION>_<DESCRIPTION>.sql`) by default.
+    Simple,
+
+    /// Create reversible migrations (`<VERSION>_<DESCRIPTION>.up.sql` and `[...].down.sql`) by default.
+    Reversible,
+}
+
+/// The default scheme that `sqlx migrate add` should use for version integers.
+#[derive(Debug, Default, PartialEq, Eq)]
+#[cfg_attr(
+    feature = "sqlx-toml",
+    derive(serde::Deserialize),
+    serde(rename_all = "snake_case")
+)]
+pub enum DefaultVersioning {
+    /// Infer the versioning scheme from existing migrations:
+    ///
+    /// * If the versions of the last two migrations differ by `1`, infer `Sequential`.
+    /// * If only one migration exists and has version `1`, infer `Sequential`.
+    /// * Otherwise, infer `Timestamp`.
+    #[default]
+    Inferred,
+
+    /// Use UTC timestamps for migration versions.
+    ///
+    /// This is the recommended versioning format as it's less likely to collide when multiple
+    /// developers are creating migrations on different branches.
+    ///
+    /// The exact timestamp format is unspecified.
+    Timestamp,
+
+    /// Use sequential integers for migration versions.
+    Sequential,
+}
+
+#[cfg(feature = "migrate")]
+impl Config {
+    pub fn migrations_dir(&self) -> &str {
+        self.migrations_dir.as_deref().unwrap_or("migrations")
+    }
+
+    pub fn table_name(&self) -> &str {
+        self.table_name.as_deref().unwrap_or("_sqlx_migrations")
+    }
+
+    pub fn to_resolve_config(&self) -> crate::migrate::ResolveConfig {
+        let mut config = crate::migrate::ResolveConfig::new();
+        config.ignore_chars(self.ignored_chars.iter().copied());
+        config
+    }
+}
diff --git a/sqlx-core/src/config/mod.rs b/sqlx-core/src/config/mod.rs
new file mode 100644
index 0000000000..6c828e909c
--- /dev/null
+++ b/sqlx-core/src/config/mod.rs
@@ -0,0 +1,230 @@
+//! (Exported for documentation only) Guide and reference for `sqlx.toml` files.
+//!
+//! To use, create a `sqlx.toml` file in your crate root (the same directory as your `Cargo.toml`).
+//! The configuration in a `sqlx.toml` file configures SQLx *only* for the current crate.
+//!
+//! Requires the `sqlx-toml` feature (not enabled by default).
+//!
+//! `sqlx-cli` will also read `sqlx.toml` when running migrations.
+//!
+//! See the [`Config`] type and its fields for individual configuration options.
+//!
+//! See the [reference][`_reference`] for the full `sqlx.toml` file.
+
+use std::error::Error;
+use std::fmt::Debug;
+use std::io;
+use std::path::{Path, PathBuf};
+
+// `std::sync::OnceLock` doesn't have a stable `.get_or_try_init()`
+// because it's blocked on a stable `Try` trait.
+use once_cell::sync::OnceCell;
+
+/// Configuration shared by multiple components.
+///
+/// See [`common::Config`] for details.
+pub mod common;
+
+/// Configuration for the `query!()` family of macros.
+///
+/// See [`macros::Config`] for details.
+pub mod macros;
+
+/// Configuration for migrations when executed using `sqlx::migrate!()` or through `sqlx-cli`.
+///
+/// See [`migrate::Config`] for details.
+pub mod migrate;
+
+/// Reference for `sqlx.toml` files.
+///
+/// Source: `sqlx-core/src/config/reference.toml`
+///
+/// ```toml
+#[doc = include_str!("reference.toml")]
+/// ```
+pub mod _reference {}
+
+#[cfg(all(test, feature = "sqlx-toml"))]
+mod tests;
+
+/// The parsed structure of a `sqlx.toml` file.
+#[derive(Debug, Default)]
+#[cfg_attr(
+    feature = "sqlx-toml",
+    derive(serde::Deserialize),
+    serde(default, rename_all = "kebab-case", deny_unknown_fields)
+)]
+pub struct Config {
+    /// Configuration shared by multiple components.
+    ///
+    /// See [`common::Config`] for details.
+    pub common: common::Config,
+
+    /// Configuration for the `query!()` family of macros.
+    ///
+    /// See [`macros::Config`] for details.
+    pub macros: macros::Config,
+
+    /// Configuration for migrations when executed using `sqlx::migrate!()` or through `sqlx-cli`.
+    ///
+    /// See [`migrate::Config`] for details.
+    pub migrate: migrate::Config,
+}
+
+/// Error returned from various methods of [`Config`].
+#[derive(thiserror::Error, Debug)]
+pub enum ConfigError {
+    /// The loading method expected `CARGO_MANIFEST_DIR` to be set and it wasn't.
+    ///
+    /// This is necessary to locate the root of the crate currently being compiled.
+    ///
+    /// See [the "Environment Variables" page of the Cargo Book][cargo-env] for details.
+    ///
+    /// [cargo-env]: https://doc.rust-lang.org/cargo/reference/environment-variables.html#environment-variables-cargo-sets-for-crates
+    #[error("environment variable `CARGO_MANIFEST_DIR` must be set and valid")]
+    Env(
+        #[from]
+        #[source]
+        std::env::VarError,
+    ),
+
+    /// No configuration file was found. Not necessarily fatal.
+    #[error("config file {path:?} not found")]
+    NotFound { path: PathBuf },
+
+    /// An I/O error occurred while attempting to read the config file at `path`.
+    ///
+    /// If the error is [`io::ErrorKind::NotFound`], [`Self::NotFound`] is returned instead.
+    #[error("error reading config file {path:?}")]
+    Io {
+        path: PathBuf,
+        #[source]
+        error: io::Error,
+    },
+
+    /// An error in the TOML was encountered while parsing the config file at `path`.
+    ///
+    /// The error gives line numbers and context when printed with `Display`/`ToString`.
+    ///
+    /// Only returned if the `sqlx-toml` feature is enabled.
+    #[error("error parsing config file {path:?}")]
+    Parse {
+        path: PathBuf,
+        /// Type-erased [`toml::de::Error`].
+        #[source]
+        error: Box<dyn Error + Send + Sync + 'static>,
+    },
+
+    /// A `sqlx.toml` file was found or specified, but the `sqlx-toml` feature is not enabled.
+    #[error("SQLx found config file at {path:?} but the `sqlx-toml` feature was not enabled")]
+    ParseDisabled { path: PathBuf },
+}
+
+impl ConfigError {
+    /// Create a [`ConfigError`] from a [`std::io::Error`].
+    ///
+    /// Maps to either `NotFound` or `Io`.
+    pub fn from_io(path: PathBuf, error: io::Error) -> Self {
+        if error.kind() == io::ErrorKind::NotFound {
+            Self::NotFound { path }
+        } else {
+            Self::Io { path, error }
+        }
+    }
+
+    /// If this error means the file was not found, return the path that was attempted.
+    pub fn not_found_path(&self) -> Option<&Path> {
+        if let Self::NotFound { path } = self {
+            Some(path)
+        } else {
+            None
+        }
+    }
+}
+
+static CACHE: OnceCell<Config> = OnceCell::new();
+
+/// Internal methods for loading a `Config`.
+#[allow(clippy::result_large_err)]
+impl Config {
+    /// Get the cached config, or attempt to read `$CARGO_MANIFEST_DIR/sqlx.toml`.
+    ///
+    /// On success, the config is cached in a `static` and returned by future calls.
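+    ///
+    /// A minimal usage sketch (assuming a context where Cargo sets
+    /// `CARGO_MANIFEST_DIR`, such as a build script or proc macro):
+    ///
+    /// ```rust,no_run
+    /// use sqlx_core::config::Config;
+    ///
+    /// // Reads `$CARGO_MANIFEST_DIR/sqlx.toml` on the first call;
+    /// // later calls return the cached value.
+    /// let config: &'static Config = Config::try_from_crate()
+    ///     .expect("failed to read sqlx.toml");
+    /// ```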
+    ///
+    /// Errors if `CARGO_MANIFEST_DIR` is not set, or if the config file could not be read.
+    pub fn try_from_crate() -> Result<&'static Self, ConfigError> {
+        Self::try_read_with(get_crate_path)
+    }
+
+    /// Get the cached config, or attempt to read `sqlx.toml` from the current working directory.
+    ///
+    /// On success, the config is cached in a `static` and returned by future calls.
+    ///
+    /// Errors if the config file does not exist, or could not be read.
+    pub fn try_from_current_dir() -> Result<&'static Self, ConfigError> {
+        Self::try_read_with(|| Ok("sqlx.toml".into()))
+    }
+
+    /// Get the cached config, or attempt to read it from the path given.
+    ///
+    /// On success, the config is cached in a `static` and returned by future calls.
+    ///
+    /// Errors if the config file does not exist, or could not be read.
+    pub fn try_from_path(path: impl Into<PathBuf>) -> Result<&'static Self, ConfigError> {
+        Self::try_read_with(|| Ok(path.into()))
+    }
+
+    /// Get the cached config, or return the default.
+    pub fn get_or_default() -> &'static Self {
+        CACHE.get_or_init(Config::default)
+    }
+
+    /// Get the cached config, or attempt to read it from the path returned by the closure.
+    ///
+    /// On success, the config is cached in a `static` and returned by future calls.
+    ///
+    /// Errors if the config file does not exist, or could not be read.
+    fn try_read_with(
+        make_path: impl FnOnce() -> Result<PathBuf, ConfigError>,
+    ) -> Result<&'static Self, ConfigError> {
+        CACHE.get_or_try_init(|| {
+            let path = make_path()?;
+            Self::read_from(path)
+        })
+    }
+
+    #[cfg(feature = "sqlx-toml")]
+    fn read_from(path: PathBuf) -> Result<Self, ConfigError> {
+        // The `toml` crate doesn't provide an incremental reader.
+        let toml_s = match std::fs::read_to_string(&path) {
+            Ok(toml) => toml,
+            Err(error) => {
+                return Err(ConfigError::from_io(path, error));
+            }
+        };
+
+        // TODO: parse and lint TOML structure before deserializing
+        // Motivation: https://github.com/toml-rs/toml/issues/761
+        tracing::debug!("read config TOML from {path:?}:\n{toml_s}");
+
+        toml::from_str(&toml_s).map_err(|error| ConfigError::Parse {
+            path,
+            error: Box::new(error),
+        })
+    }
+
+    #[cfg(not(feature = "sqlx-toml"))]
+    fn read_from(path: PathBuf) -> Result<Self, ConfigError> {
+        match path.try_exists() {
+            Ok(true) => Err(ConfigError::ParseDisabled { path }),
+            Ok(false) => Err(ConfigError::NotFound { path }),
+            Err(e) => Err(ConfigError::from_io(path, e)),
+        }
+    }
+}
+
+fn get_crate_path() -> Result<PathBuf, ConfigError> {
+    let mut path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR")?);
+    path.push("sqlx.toml");
+    Ok(path)
+}
diff --git a/sqlx-core/src/config/reference.toml b/sqlx-core/src/config/reference.toml
new file mode 100644
index 0000000000..77833fb5a8
--- /dev/null
+++ b/sqlx-core/src/config/reference.toml
@@ -0,0 +1,194 @@
+# `sqlx.toml` reference.
+#
+# Note: shown values are *not* defaults.
+# They are explicitly set to non-default values to test parsing.
+# Refer to the comment for a given option for its default value.
+
+###############################################################################################
+
+# Configuration shared by multiple components.
+[common]
+# Change the environment variable used to get the database URL.
+#
+# This is used by both the macros and `sqlx-cli`.
+#
+# If not specified, defaults to `DATABASE_URL`.
+database-url-var = "FOO_DATABASE_URL"
+
+###############################################################################################
+
+# Configuration for the `query!()` family of macros.
+[macros]
+
+[macros.preferred-crates]
+# Force the macros to use the `chrono` crate for date/time types, even if `time` is enabled.
+#
+# Defaults to "inferred": use whichever crate is enabled (`time` takes precedence over `chrono`).
+date-time = "chrono"
+
+# Or, ensure the macros always prefer `time`
+# in case new date/time crates are added in the future:
+# date-time = "time"
+
+# Force the macros to use the `rust_decimal` crate for `NUMERIC`, even if `bigdecimal` is enabled.
+#
+# Defaults to "inferred": use whichever crate is enabled (`bigdecimal` takes precedence over `rust_decimal`).
+numeric = "rust_decimal"
+
+# Or, ensure the macros always prefer `bigdecimal`
+# in case new decimal crates are added in the future:
+# numeric = "bigdecimal"
+
+# Set global overrides for mapping SQL types to Rust types.
+#
+# Default type mappings are defined by the database driver.
+# Refer to the `sqlx::types` module for details.
+#
+# Postgres users: schema qualification should not be used for types in the search path.
+#
+# ### Note: Orthogonal to Nullability
+# These overrides do not affect whether `query!()` decides to wrap a column in `Option<_>`
+# or not. They only override the inner type used.
+[macros.type-overrides]
+# Override a built-in type (map all `UUID` columns to `crate::types::MyUuid`)
+# Note: currently, the case of the type name MUST match.
+# Built-in types are spelled in all-uppercase to match SQL convention.
+'UUID' = "crate::types::MyUuid"
+
+# Support an external or custom wrapper type (e.g. from the `isn` Postgres extension)
+# (NOTE: FOR DOCUMENTATION PURPOSES ONLY; THIS CRATE/TYPE DOES NOT EXIST AS OF WRITING)
+'isbn13' = "isn_rs::isbn::ISBN13"
+
+# SQL type `foo` to Rust type `crate::types::Foo`:
+'foo' = "crate::types::Foo"
+
+# SQL type `"Bar"` to Rust type `crate::types::Bar`; notice the extra pair of quotes:
+'"Bar"' = "crate::types::Bar"
+
+# Will NOT work (the first pair of quotes is parsed by TOML)
+# "Bar" = "crate::types::Bar"
+
+# Schema qualified
+'foo.bar' = "crate::types::Bar"
+
+# Schema qualified and quoted
+'foo."Bar"' = "crate::schema::foo::Bar"
+
+# Quoted schema name
+'"Foo".bar' = "crate::schema::foo::Bar"
+
+# Quoted schema and type name
+'"Foo"."Bar"' = "crate::schema::foo::Bar"
+
+# Set per-table and per-column overrides for mapping SQL types to Rust types.
+#
+# Note: table name is required in the header.
+#
+# Postgres users: schema qualification should not be used for types in the search path.
+#
+# ### Note: Orthogonal to Nullability
+# These overrides do not affect whether `query!()` decides to wrap a column in `Option<_>`
+# or not. They only override the inner type used.
+[macros.table-overrides.'foo']
+# Map column `bar` of table `foo` to Rust type `crate::types::Bar`:
+'bar' = "crate::types::Bar"
+
+# Quoted column name
+# Note: same quoting requirements as `macros.type_overrides`
+'"Bar"' = "crate::types::Bar"
+
+# Note: will NOT work (parses as `Bar`)
+# "Bar" = "crate::types::Bar"
+
+# Table name may be quoted (note the wrapping single-quotes)
+[macros.table-overrides.'"Foo"']
+'bar' = "crate::types::Bar"
+'"Bar"' = "crate::types::Bar"
+
+# Table name may also be schema-qualified.
+# Note how the dot is inside the quotes.
+[macros.table-overrides.'my_schema.my_table']
+'my_column' = "crate::types::MyType"
+
+# Quoted schema, table, and column names
+[macros.table-overrides.'"My Schema"."My Table"']
+'"My Column"' = "crate::types::MyType"
+
+###############################################################################################
+
+# Configuration for migrations when executed using `sqlx::migrate!()` or through `sqlx-cli`.
+#
+# ### Note
+# A manually constructed [`Migrator`][crate::migrate::Migrator] will not be aware of these
+# configuration options. We recommend using `sqlx::migrate!()` instead.
+#
+# ### Warning: Potential Data Loss or Corruption!
+# Many of these options, if changed after migrations are set up,
+# can result in data loss or corruption of a production database
+# if the proper precautions are not taken.
+#
+# Be sure you know what you are doing and that you read all relevant documentation _thoroughly_.
+[migrate]
+# Override the name of the table used to track executed migrations.
+#
+# May be schema-qualified and/or contain quotes. Defaults to `_sqlx_migrations`.
+#
+# Potentially useful for multi-tenant databases.
+#
+# ### Warning: Potential Data Loss or Corruption!
+# Changing this option for a production database will likely result in data loss or corruption
+# as the migration machinery will no longer be aware of what migrations have been applied
+# and will attempt to re-run them.
+#
+# You should create the new table as a copy of the existing migrations table (with contents!),
+# and be sure all instances of your application have been migrated to the new
+# table before deleting the old one.
+table-name = "foo._sqlx_migrations"
+
+# Override the directory used for migrations files.
+#
+# Relative to the crate root for `sqlx::migrate!()`, or the current directory for `sqlx-cli`.
+migrations-dir = "foo/migrations"
+
+# Specify characters that should be ignored when hashing migrations.
+#
+# Any characters contained in the given set will be dropped when a migration is hashed.
+#
+# Defaults to an empty array (don't drop any characters).
+#
+# ### Warning: May Change Hashes for Existing Migrations
+# Changing the characters considered in hashing migrations will likely
+# change the output of the hash.
+#
+# This may require manual rectification for deployed databases.
+# ignored-chars = []
+
+# Ignore Carriage Returns (`<CR>` | `\r`)
+# Note that the TOML format requires double-quoted strings to process escapes.
+# ignored-chars = ["\r"]
+
+# Ignore common whitespace characters (beware syntactically significant whitespace!)
+# Space, tab, CR, LF, zero-width non-breaking space (U+FEFF)
+#
+# U+FEFF is added by some editors as a magic number at the beginning of a text file indicating
+# it is UTF-8 encoded, where it is known as a byte-order mark (BOM):
+# https://en.wikipedia.org/wiki/Byte_order_mark
+ignored-chars = [" ", "\t", "\r", "\n", "\uFEFF"]
+
+# Set default options for new migrations.
+[migrate.defaults]
+# Specify reversible migrations by default (for `sqlx migrate add`).
+#
+# Defaults to "inferred": uses the type of the last migration, or "simple" otherwise.
+migration-type = "reversible"
+
+# Specify simple (non-reversible) migrations by default.
+# migration-type = "simple"
+
+# Specify sequential versioning by default (for `sqlx migrate add`).
+#
+# Defaults to "inferred": guesses the versioning scheme from the latest migrations,
+# or "timestamp" otherwise.
+migration-versioning = "sequential"
+
+# Specify timestamp versioning by default.
+# migration-versioning = "timestamp"
diff --git a/sqlx-core/src/config/tests.rs b/sqlx-core/src/config/tests.rs
new file mode 100644
index 0000000000..0b0b590919
--- /dev/null
+++ b/sqlx-core/src/config/tests.rs
@@ -0,0 +1,93 @@
+use crate::config::{self, Config};
+use std::collections::BTreeSet;
+
+#[test]
+fn reference_parses_as_config() {
+    let config: Config = toml::from_str(include_str!("reference.toml"))
+        // The `Display` impl of `toml::Error` is *actually* more useful than `Debug`
+        .unwrap_or_else(|e| panic!("expected reference.toml to parse as Config: {e}"));
+
+    assert_common_config(&config.common);
+    assert_macros_config(&config.macros);
+    assert_migrate_config(&config.migrate);
+}
+
+fn assert_common_config(config: &config::common::Config) {
+    assert_eq!(config.database_url_var.as_deref(), Some("FOO_DATABASE_URL"));
+}
+
+fn assert_macros_config(config: &config::macros::Config) {
+    use config::macros::*;
+
+    assert_eq!(config.preferred_crates.date_time, DateTimeCrate::Chrono);
+    assert_eq!(config.preferred_crates.numeric, NumericCrate::RustDecimal);
+
+    // Type overrides
+    // Don't need to cover everything, just some important canaries.
+    assert_eq!(config.type_override("UUID"), Some("crate::types::MyUuid"));
+
+    assert_eq!(config.type_override("foo"), Some("crate::types::Foo"));
+
+    assert_eq!(config.type_override(r#""Bar""#), Some("crate::types::Bar"));
+
+    assert_eq!(
+        config.type_override(r#""Foo".bar"#),
+        Some("crate::schema::foo::Bar"),
+    );
+
+    assert_eq!(
+        config.type_override(r#""Foo"."Bar""#),
+        Some("crate::schema::foo::Bar"),
+    );
+
+    // Column overrides
+    assert_eq!(
+        config.column_override("foo", "bar"),
+        Some("crate::types::Bar"),
+    );
+
+    assert_eq!(
+        config.column_override("foo", r#""Bar""#),
+        Some("crate::types::Bar"),
+    );
+
+    assert_eq!(
+        config.column_override(r#""Foo""#, "bar"),
+        Some("crate::types::Bar"),
+    );
+
+    assert_eq!(
+        config.column_override(r#""Foo""#, r#""Bar""#),
+        Some("crate::types::Bar"),
+    );
+
+    assert_eq!(
+        config.column_override("my_schema.my_table", "my_column"),
+        Some("crate::types::MyType"),
+    );
+
+    assert_eq!(
+        config.column_override(r#""My Schema"."My Table""#, r#""My Column""#),
+        Some("crate::types::MyType"),
+    );
+}
+
+fn assert_migrate_config(config: &config::migrate::Config) {
+    use config::migrate::*;
+
+    assert_eq!(config.table_name.as_deref(), Some("foo._sqlx_migrations"));
+    assert_eq!(config.migrations_dir.as_deref(), Some("foo/migrations"));
+
+    let ignored_chars = BTreeSet::from([' ', '\t', '\r', '\n', '\u{FEFF}']);
+
+    assert_eq!(config.ignored_chars, ignored_chars);
+
+    assert_eq!(
+        config.defaults.migration_type,
+        DefaultMigrationType::Reversible
+    );
+    assert_eq!(
+        config.defaults.migration_versioning,
+        DefaultVersioning::Sequential
+    );
+}
diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs
index df4b2cc27d..09f2900ba8 100644
--- a/sqlx-core/src/lib.rs
+++ b/sqlx-core/src/lib.rs
@@ -91,6 +91,8 @@ pub mod any;
 #[cfg(feature = "migrate")]
 pub mod testing;
 
+pub mod config;
+
 pub use error::{Error, Result};
 
 pub use either::Either;
diff --git a/sqlx-core/src/migrate/error.rs b/sqlx-core/src/migrate/error.rs
index 608d55b18d..a04243963a 100644
--- a/sqlx-core/src/migrate/error.rs
+++ b/sqlx-core/src/migrate/error.rs
@@ -39,4 +39,7 @@ pub enum MigrateError {
         "migration {0} is partially applied; fix and remove row from `_sqlx_migrations` table"
     )]
     Dirty(i64),
+
+    #[error("database driver does not support creation of schemas at migrate time: {0}")]
+    CreateSchemasNotSupported(String),
 }
diff --git a/sqlx-core/src/migrate/migrate.rs b/sqlx-core/src/migrate/migrate.rs
index 0e4448a9bd..841f775966 100644
--- a/sqlx-core/src/migrate/migrate.rs
+++ b/sqlx-core/src/migrate/migrate.rs
@@ -25,18 +25,31 @@ pub trait MigrateDatabase {
 // 'e = Executor
 pub trait Migrate {
+    /// Create a database schema with the given name if it does not already exist.
+    fn create_schema_if_not_exists<'e>(
+        &'e mut self,
+        schema_name: &'e str,
+    ) -> BoxFuture<'e, Result<(), MigrateError>>;
+
     // ensure migrations table exists
     // will create or migrate it if needed
-    fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>>;
+    fn ensure_migrations_table<'e>(
+        &'e mut self,
+        table_name: &'e str,
+    ) -> BoxFuture<'e, Result<(), MigrateError>>;
 
     // Return the version on which the database is dirty or None otherwise.
     // "dirty" means there is a partially applied migration that failed.
-    fn dirty_version(&mut self) -> BoxFuture<'_, Result<Option<i64>, MigrateError>>;
+    fn dirty_version<'e>(
+        &'e mut self,
+        table_name: &'e str,
+    ) -> BoxFuture<'e, Result<Option<i64>, MigrateError>>;
 
     // Return the ordered list of applied migrations
-    fn list_applied_migrations(
-        &mut self,
-    ) -> BoxFuture<'_, Result<Vec<AppliedMigration>, MigrateError>>;
+    fn list_applied_migrations<'e>(
+        &'e mut self,
+        table_name: &'e str,
+    ) -> BoxFuture<'e, Result<Vec<AppliedMigration>, MigrateError>>;
 
     // Should acquire a database lock so that only one migration process
     // can run at a time. [`Migrate`] will call this function before applying
@@ -50,16 +63,18 @@ pub trait Migrate {
     // run SQL from migration in a DDL transaction
     // insert new row to [_migrations] table on completion (success or failure)
     // returns the time taken to run the migration SQL
-    fn apply<'e: 'm, 'm>(
+    fn apply<'e>(
         &'e mut self,
-        migration: &'m Migration,
-    ) -> BoxFuture<'m, Result<Duration, MigrateError>>;
+        table_name: &'e str,
+        migration: &'e Migration,
+    ) -> BoxFuture<'e, Result<Duration, MigrateError>>;
 
     // run a revert SQL from migration in a DDL transaction
     // deletes the row in [_migrations] table with specified migration version on completion (success or failure)
     // returns the time taken to run the migration SQL
-    fn revert<'e: 'm, 'm>(
+    fn revert<'e>(
         &'e mut self,
-        migration: &'m Migration,
-    ) -> BoxFuture<'m, Result<Duration, MigrateError>>;
+        table_name: &'e str,
+        migration: &'e Migration,
+    ) -> BoxFuture<'e, Result<Duration, MigrateError>>;
 }
diff --git a/sqlx-core/src/migrate/migration.rs b/sqlx-core/src/migrate/migration.rs
index 9bd7f569d8..1f1175ce58 100644
--- a/sqlx-core/src/migrate/migration.rs
+++ b/sqlx-core/src/migrate/migration.rs
@@ -1,6 +1,5 @@
-use std::borrow::Cow;
-
 use sha2::{Digest, Sha384};
+use std::borrow::Cow;
 
 use super::MigrationType;
 
@@ -22,8 +21,26 @@ impl Migration {
         sql: Cow<'static, str>,
         no_tx: bool,
     ) -> Self {
-        let checksum = Cow::Owned(Vec::from(Sha384::digest(sql.as_bytes()).as_slice()));
+        let checksum = checksum(&sql);
+
+        Self::with_checksum(
+            version,
+            description,
+            migration_type,
+            sql,
+            checksum.into(),
+            no_tx,
+        )
+    }
 
+    pub(crate) fn with_checksum(
+        version: i64,
+        description: Cow<'static, str>,
+        migration_type: MigrationType,
+        sql: Cow<'static, str>,
+        checksum: Cow<'static, [u8]>,
+        no_tx: bool,
+    ) -> Self {
         Migration {
             version,
             description,
@@ -40,3 +57,39 @@ pub struct AppliedMigration {
     pub version: i64,
     pub checksum: Cow<'static, [u8]>,
 }
+
+pub fn checksum(sql: &str) -> Vec<u8> {
+    Vec::from(Sha384::digest(sql).as_slice())
+}
+
+pub fn checksum_fragments<'a>(fragments: impl Iterator<Item = &'a str>) -> Vec<u8> {
+    let mut digest = Sha384::new();
+
+    for fragment in fragments {
+        digest.update(fragment);
+    }
+
+    digest.finalize().to_vec()
+}
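+
+// Quick illustrative sketch (hypothetical SQL): splitting on an ignored
+// character and hashing the fragments is equivalent to hashing the SQL
+// with that character removed, which is what ignored-chars hashing relies on.
+#[test]
+fn checksum_fragments_skips_split_char() {
+    let sql = "SELECT\r\n1;";
+    assert_eq!(
+        checksum_fragments(sql.split('\r')),
+        checksum(&sql.replace('\r', "")),
+    );
+}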
+
+#[test]
+fn fragments_checksum_equals_full_checksum() {
+    // Copied from `examples/postgres/axum-social-with-tests/migrations/3_comment.sql`
+    let sql = "\
+        \u{FEFF}create table comment (\r\n\
+        \tcomment_id uuid primary key default gen_random_uuid(),\r\n\
+        \tpost_id uuid not null references post(post_id),\r\n\
+        \tuser_id uuid not null references \"user\"(user_id),\r\n\
+        \tcontent text not null,\r\n\
+        \tcreated_at timestamptz not null default now()\r\n\
+        );\r\n\
+        \r\n\
+        create index on comment(post_id, created_at);\r\n\
+    ";
+
+    // Should yield a string for each character
+    let fragments_checksum = checksum_fragments(sql.split(""));
+    let full_checksum = checksum(sql);
+
+    assert_eq!(fragments_checksum, full_checksum);
+}
diff --git a/sqlx-core/src/migrate/migration_type.rs b/sqlx-core/src/migrate/migration_type.rs
index de2b019307..350ddb3f27 100644
--- a/sqlx-core/src/migrate/migration_type.rs
+++ b/sqlx-core/src/migrate/migration_type.rs
@@ -74,8 +74,9 @@ impl MigrationType {
         }
     }
 
+    #[deprecated = "unused"]
     pub fn infer(migrator: &Migrator, reversible: bool) -> MigrationType {
-        match migrator.iter().next() {
+        match migrator.iter().last() {
             Some(first_migration) => first_migration.migration_type,
             None => {
                 if reversible {
diff --git a/sqlx-core/src/migrate/migrator.rs b/sqlx-core/src/migrate/migrator.rs
index 3209ba6e45..0f5cfb3fd7 100644
--- a/sqlx-core/src/migrate/migrator.rs
+++ b/sqlx-core/src/migrate/migrator.rs
@@ -23,25 +23,11 @@ pub struct Migrator {
     pub locking: bool,
     #[doc(hidden)]
     pub no_tx: bool,
-}
-
-fn validate_applied_migrations(
-    applied_migrations: &[AppliedMigration],
-    migrator: &Migrator,
-) -> Result<(), MigrateError> {
-    if migrator.ignore_missing {
-        return Ok(());
-    }
-
-    let migrations: HashSet<_> = migrator.iter().map(|m| m.version).collect();
-
-    for applied_migration in applied_migrations {
-        if !migrations.contains(&applied_migration.version) {
-            return Err(MigrateError::VersionMissing(applied_migration.version));
-        }
-    }
+    #[doc(hidden)]
+    pub table_name: Cow<'static, str>,
 
-    Ok(())
+    #[doc(hidden)]
+    pub create_schemas: Cow<'static, [Cow<'static, str>]>,
 }
 
 impl Migrator {
@@ -51,6 +37,8 @@ impl Migrator {
         ignore_missing: false,
         no_tx: false,
         locking: true,
+        table_name: Cow::Borrowed("_sqlx_migrations"),
+        create_schemas: Cow::Borrowed(&[]),
     };
 
     /// Creates a new instance with the given source.
@@ -81,6 +69,38 @@ impl Migrator {
         })
     }
 
+    /// Override the name of the table used to track executed migrations.
+    ///
+    /// May be schema-qualified and/or contain quotes. Defaults to `_sqlx_migrations`.
+    ///
+    /// Potentially useful for multi-tenant databases.
+    ///
+    /// ### Warning: Potential Data Loss or Corruption!
+    /// Changing this option for a production database will likely result in data loss or corruption
+    /// as the migration machinery will no longer be aware of what migrations have been applied
+    /// and will attempt to re-run them.
+    ///
+    /// You should create the new table as a copy of the existing migrations table (with contents!),
+    /// and be sure all instances of your application have been migrated to the new
+    /// table before deleting the old one.
+    pub fn dangerous_set_table_name(&mut self, table_name: impl Into<Cow<'static, str>>) -> &Self {
+        self.table_name = table_name.into();
+        self
+    }
+
+    /// Add a schema name to be created if it does not already exist.
+    ///
+    /// May be used with [`Self::dangerous_set_table_name()`] to place the migrations table
+    /// in a new schema without requiring it to exist first.
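+    ///
+    /// A rough usage sketch (directory, schema, and table names are placeholders):
+    ///
+    /// ```rust,no_run
+    /// # async fn example() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+    /// use std::path::Path;
+    /// use sqlx_core::migrate::Migrator;
+    ///
+    /// let mut migrator = Migrator::new(Path::new("./migrations")).await?;
+    /// // Ensure schema `foo` exists, then track applied migrations in `foo._sqlx_migrations`.
+    /// migrator.create_schema("foo");
+    /// migrator.dangerous_set_table_name("foo._sqlx_migrations");
+    /// # Ok(())
+    /// # }
+    /// ```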
+    ///
+    /// ### Note: Support Depends on Database
+    /// SQLite cannot create new schemas without attaching them to a database file,
+    /// the path of which must be specified separately in an
+    /// [`ATTACH DATABASE`](https://www.sqlite.org/lang_attach.html) command.
+    pub fn create_schema(&mut self, schema_name: impl Into<Cow<'static, str>>) -> &Self {
+        self.create_schemas.to_mut().push(schema_name.into());
+        self
+    }
+
     /// Specify whether applied migrations that are missing from the resolved migrations should be ignored.
     pub fn set_ignore_missing(&mut self, ignore_missing: bool) -> &Self {
         self.ignore_missing = ignore_missing;
@@ -134,12 +154,21 @@ impl Migrator {
         <A::Connection as Deref>::Target: Migrate,
     {
         let mut conn = migrator.acquire().await?;
-        self.run_direct(&mut *conn).await
+        self.run_direct(None, &mut *conn).await
+    }
+
+    /// Run all resolved migrations up to and including the given target version.
+    pub async fn run_to<'a, A>(&self, target: i64, migrator: A) -> Result<(), MigrateError>
+    where
+        A: Acquire<'a>,
+        <A::Connection as Deref>::Target: Migrate,
+    {
+        let mut conn = migrator.acquire().await?;
+        self.run_direct(Some(target), &mut *conn).await
     }
 
     // Getting around the annoying "implementation of `Acquire` is not general enough" error
     #[doc(hidden)]
-    pub async fn run_direct(&self, conn: &mut C) -> Result<(), MigrateError>
+    pub async fn run_direct(&self, target: Option<i64>, conn: &mut C) -> Result<(), MigrateError>
     where
         C: Migrate,
     {
@@ -148,16 +177,20 @@ impl Migrator {
             conn.lock().await?;
         }
 
+        for schema_name in self.create_schemas.iter() {
+            conn.create_schema_if_not_exists(schema_name).await?;
+        }
+
         // creates [_migrations] table only if needed
         // eventually this will likely migrate previous versions of the table
-        conn.ensure_migrations_table().await?;
+        conn.ensure_migrations_table(&self.table_name).await?;
 
-        let version = conn.dirty_version().await?;
+        let version = conn.dirty_version(&self.table_name).await?;
         if let Some(version) = version {
            return Err(MigrateError::Dirty(version));
         }
 
-        let applied_migrations = conn.list_applied_migrations().await?;
+        let applied_migrations = conn.list_applied_migrations(&self.table_name).await?;
         validate_applied_migrations(&applied_migrations, self)?;
 
         let applied_migrations: HashMap<_, _> = applied_migrations
@@ -166,6 +199,11 @@ impl Migrator {
             .collect();
 
         for migration in self.iter() {
+            if target.is_some_and(|target| target < migration.version) {
+                // Target version reached
+                break;
+            }
+
             if migration.migration_type.is_down_migration() {
                 continue;
             }
@@ -177,7 +215,7 @@ impl Migrator {
                 }
             }
             None => {
-                conn.apply(migration).await?;
+                conn.apply(&self.table_name, migration).await?;
             }
         }
     }
@@ -222,14 +260,14 @@ impl Migrator {
 
         // creates [_migrations] table only if needed
         // eventually this will likely migrate previous versions of the table
-        conn.ensure_migrations_table().await?;
+        conn.ensure_migrations_table(&self.table_name).await?;
 
-        let version = conn.dirty_version().await?;
+        let version = conn.dirty_version(&self.table_name).await?;
         if let Some(version) = version {
             return Err(MigrateError::Dirty(version));
         }
 
-        let applied_migrations = conn.list_applied_migrations().await?;
+        let applied_migrations = conn.list_applied_migrations(&self.table_name).await?;
         validate_applied_migrations(&applied_migrations, self)?;
 
         let applied_migrations: HashMap<_, _> = applied_migrations
@@ -244,7 +282,7 @@ impl Migrator {
             .filter(|m| applied_migrations.contains_key(&m.version))
             .filter(|m| m.version > target)
         {
-            conn.revert(migration).await?;
+            conn.revert(&self.table_name, migration).await?;
         }
 
         // unlock the migrator to allow other migrators to run
@@ -256,3 +294,22 @@ impl Migrator {
 
     Ok(())
     }
 }
+
+fn validate_applied_migrations(
+    applied_migrations: &[AppliedMigration],
+    migrator: &Migrator,
+) -> Result<(), MigrateError> {
+    if migrator.ignore_missing {
+        return Ok(());
+    }
+
+    let migrations: HashSet<_> = migrator.iter().map(|m| m.version).collect();
+
+    for applied_migration in applied_migrations {
+        if !migrations.contains(&applied_migration.version) {
+            return Err(MigrateError::VersionMissing(applied_migration.version));
+        }
+    }
+
+    Ok(())
+}
diff --git a/sqlx-core/src/migrate/mod.rs b/sqlx-core/src/migrate/mod.rs
index f035b8d3c1..39347cf421 100644
--- a/sqlx-core/src/migrate/mod.rs
+++ b/sqlx-core/src/migrate/mod.rs
@@ -11,7 +11,7 @@ pub use migrate::{Migrate, MigrateDatabase};
 pub use migration::{AppliedMigration, Migration};
 pub use migration_type::MigrationType;
 pub use migrator::Migrator;
-pub use source::MigrationSource;
+pub use source::{MigrationSource, ResolveConfig, ResolveWith};
 
 #[doc(hidden)]
-pub use source::resolve_blocking;
+pub use source::{resolve_blocking, resolve_blocking_with_config};
diff --git a/sqlx-core/src/migrate/source.rs b/sqlx-core/src/migrate/source.rs
index d0c23b43cd..9c2ef7719b 100644
--- a/sqlx-core/src/migrate/source.rs
+++ b/sqlx-core/src/migrate/source.rs
@@ -1,8 +1,9 @@
 use crate::error::BoxDynError;
-use crate::migrate::{Migration, MigrationType};
+use crate::migrate::{migration, Migration, MigrationType};
 use futures_core::future::BoxFuture;
 
 use std::borrow::Cow;
+use std::collections::BTreeSet;
 use std::fmt::Debug;
 use std::fs;
 use std::io;
@@ -28,19 +29,48 @@ pub trait MigrationSource<'s>: Debug {
 impl<'s> MigrationSource<'s> for &'s Path {
     fn resolve(self) -> BoxFuture<'s, Result<Vec<Migration>, BoxDynError>> {
+        // Behavior changed from previous because `canonicalize()` is potentially blocking
+        // since it might require going to disk to fetch filesystem data.
+        self.to_owned().resolve()
+    }
+}
+
+impl MigrationSource<'static> for PathBuf {
+    fn resolve(self) -> BoxFuture<'static, Result<Vec<Migration>, BoxDynError>> {
+        // Technically this could just be `Box::pin(spawn_blocking(...))`
+        // but that would actually be a breaking behavior change because it would call
+        // `spawn_blocking()` on the current thread
         Box::pin(async move {
-            let canonical = self.canonicalize()?;
-            let migrations_with_paths =
-                crate::rt::spawn_blocking(move || resolve_blocking(&canonical)).await?;
+            crate::rt::spawn_blocking(move || {
+                let migrations_with_paths = resolve_blocking(&self)?;
 
-            Ok(migrations_with_paths.into_iter().map(|(m, _p)| m).collect())
+                Ok(migrations_with_paths.into_iter().map(|(m, _p)| m).collect())
+            })
+            .await
         })
     }
 }
 
-impl MigrationSource<'static> for PathBuf {
-    fn resolve(self) -> BoxFuture<'static, Result<Vec<Migration>, BoxDynError>> {
-        Box::pin(async move { self.as_path().resolve().await })
+/// A [`MigrationSource`] implementation with configurable resolution.
+///
+/// `S` may be `PathBuf`, `&Path` or any type that implements `Into<PathBuf>`.
+///
+/// See [`ResolveConfig`] for details.
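+///
+/// A rough usage sketch (the migrations path is a placeholder):
+///
+/// ```rust,no_run
+/// # async fn example() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+/// use std::path::PathBuf;
+/// use sqlx_core::migrate::{Migrator, ResolveConfig, ResolveWith};
+///
+/// let mut config = ResolveConfig::new();
+/// // Drop carriage returns when hashing so CRLF and LF checkouts hash identically.
+/// config.ignore_char('\r');
+///
+/// let migrator = Migrator::new(ResolveWith(PathBuf::from("migrations"), config)).await?;
+/// # Ok(())
+/// # }
+/// ```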
+#[derive(Debug)]
+pub struct ResolveWith<S>(pub S, pub ResolveConfig);
+
+impl<'s, S: Debug + Into<PathBuf> + Send + 's> MigrationSource<'s> for ResolveWith<S> {
+    fn resolve(self) -> BoxFuture<'s, Result<Vec<Migration>, BoxDynError>> {
+        Box::pin(async move {
+            let path = self.0.into();
+            let config = self.1;
+
+            let migrations_with_paths =
+                crate::rt::spawn_blocking(move || resolve_blocking_with_config(&path, &config))
+                    .await?;
+
+            Ok(migrations_with_paths.into_iter().map(|(m, _p)| m).collect())
+        })
+    }
+}
@@ -52,11 +82,87 @@ pub struct ResolveError {
     source: Option<io::Error>,
 }
 
+/// Configuration for migration resolution using [`ResolveWith`].
+#[derive(Debug, Default)]
+pub struct ResolveConfig {
+    ignored_chars: BTreeSet<char>,
+}
+
+impl ResolveConfig {
+    /// Return a default, empty configuration.
+    pub fn new() -> Self {
+        ResolveConfig {
+            ignored_chars: BTreeSet::new(),
+        }
+    }
+
+    /// Ignore a character when hashing migrations.
+    ///
+    /// The migration SQL string itself will still contain the character,
+    /// but it will not be included when calculating the checksum.
+    ///
+    /// This can be used to ignore whitespace characters so changing formatting
+    /// does not change the checksum.
+    ///
+    /// Adding the same `char` more than once is a no-op.
+    ///
+    /// ### Note: Changes Migration Checksum
+    /// This will change the checksum of resolved migrations,
+    /// which may cause problems with existing deployments.
+    ///
+    /// **Use at your own risk.**
+    pub fn ignore_char(&mut self, c: char) -> &mut Self {
+        self.ignored_chars.insert(c);
+        self
+    }
+
+    /// Ignore one or more characters when hashing migrations.
+    ///
+    /// The migration SQL string itself will still contain these characters,
+    /// but they will not be included when calculating the checksum.
+    ///
+    /// This can be used to ignore whitespace characters so changing formatting
+    /// does not change the checksum.
+    ///
+    /// Adding the same `char` more than once is a no-op.
+    ///
+    /// ### Note: Changes Migration Checksum
+    /// This will change the checksum of resolved migrations,
+    /// which may cause problems with existing deployments.
+    ///
+    /// **Use at your own risk.**
+    pub fn ignore_chars(&mut self, chars: impl IntoIterator<Item = char>) -> &mut Self {
+        self.ignored_chars.extend(chars);
+        self
+    }
+
+    /// Iterate over the set of ignored characters.
+    ///
+    /// Duplicate `char`s are not included.
+    pub fn ignored_chars(&self) -> impl Iterator<Item = char> + '_ {
+        self.ignored_chars.iter().copied()
+    }
+}
+
 // FIXME: paths should just be part of `Migration` but we can't add a field backwards compatibly
 // since it's `#[non_exhaustive]`.
+#[doc(hidden)]
 pub fn resolve_blocking(path: &Path) -> Result<Vec<(Migration, PathBuf)>, ResolveError> {
-    let s = fs::read_dir(path).map_err(|e| ResolveError {
-        message: format!("error reading migration directory {}: {e}", path.display()),
+    resolve_blocking_with_config(path, &ResolveConfig::new())
+}
+
+#[doc(hidden)]
+pub fn resolve_blocking_with_config(
+    path: &Path,
+    config: &ResolveConfig,
+) -> Result<Vec<(Migration, PathBuf)>, ResolveError> {
+    let path = path.canonicalize().map_err(|e| ResolveError {
+        message: format!("error canonicalizing path {}", path.display()),
+        source: Some(e),
+    })?;
+
+    let s = fs::read_dir(&path).map_err(|e| ResolveError {
+        message: format!("error reading migration directory {}", path.display()),
         source: Some(e),
     })?;
 
@@ -65,7 +171,7 @@ pub fn resolve_blocking(path: &Path) -> Result<Vec<(Migration, PathBuf)>, Resolv
     for res in s {
         let entry = res.map_err(|e| ResolveError {
             message: format!(
-                "error reading contents of migration directory {}: {e}",
+                "error reading contents of migration directory {}",
                 path.display()
             ),
             source: Some(e),
@@ -126,12 +232,15 @@ pub fn resolve_blocking(path: &Path) -> Result<Vec<(Migration, PathBuf)>, Resolv
         // opt-out of migration transaction
         let no_tx = sql.starts_with("-- no-transaction");
 
+        let checksum = checksum_with(&sql, &config.ignored_chars);
+
         migrations.push((
-            Migration::new(
+            Migration::with_checksum(
                 version,
                 Cow::Owned(description),
                 migration_type,
                 Cow::Owned(sql),
+                checksum.into(),
                 no_tx,
             ),
             entry_path,
@@ -143,3 +252,47 @@ pub fn resolve_blocking(path: &Path) -> Result<Vec<(Migration, PathBuf)>, Resolv
 
     Ok(migrations)
 }
+
+fn checksum_with(sql: &str, ignored_chars: &BTreeSet<char>) -> Vec<u8> {
+    if ignored_chars.is_empty() {
+        // This is going to be much faster because it doesn't have to UTF-8 decode `sql`.
+        return migration::checksum(sql);
+    }
+
+    migration::checksum_fragments(sql.split(|c| ignored_chars.contains(&c)))
+}
+
+#[test]
+fn checksum_with_ignored_chars() {
+    // Ensure that `checksum_with` returns the same digest for a given set of ignored chars
+    // as the equivalent string with the characters removed.
+    let ignored_chars = [
+        ' ', '\t', '\r', '\n',
+        // Zero-width non-breaking space (ZWNBSP), often added as a magic-number at the beginning
+        // of UTF-8 encoded files as a byte-order mark (BOM):
+        // https://en.wikipedia.org/wiki/Byte_order_mark
+        '\u{FEFF}',
+    ];
+
+    // Copied from `examples/postgres/axum-social-with-tests/migrations/3_comment.sql`
+    let sql = "\
+        \u{FEFF}create table comment (\r\n\
+        \tcomment_id uuid primary key default gen_random_uuid(),\r\n\
+        \tpost_id uuid not null references post(post_id),\r\n\
+        \tuser_id uuid not null references \"user\"(user_id),\r\n\
+        \tcontent text not null,\r\n\
+        \tcreated_at timestamptz not null default now()\r\n\
+        );\r\n\
+        \r\n\
+        create index on comment(post_id, created_at);\r\n\
+    ";
+
+    let stripped_sql = sql.replace(&ignored_chars[..], "");
+
+    let ignored_chars = BTreeSet::from(ignored_chars);
+
+    let digest_ignored = checksum_with(sql, &ignored_chars);
+    let digest_stripped = migration::checksum(&stripped_sql);
+
+    assert_eq!(digest_ignored, digest_stripped);
+}
diff --git a/sqlx-core/src/testing/mod.rs b/sqlx-core/src/testing/mod.rs
index 051353383b..d683fdf874 100644
--- a/sqlx-core/src/testing/mod.rs
+++ b/sqlx-core/src/testing/mod.rs
@@ -256,7 +256,7 @@ async fn setup_test_db(
     if let Some(migrator) = args.migrator {
         migrator
-            .run_direct(&mut conn)
+            .run_direct(None, &mut conn)
             .await
             .expect("failed to apply migrations");
     }
diff --git a/sqlx-core/src/type_checking.rs b/sqlx-core/src/type_checking.rs
index 1da6b7ab3f..a3ded72abb 100644
--- a/sqlx-core/src/type_checking.rs
+++ b/sqlx-core/src/type_checking.rs
@@ -1,3 +1,4 @@
+use crate::config::macros::PreferredCrates;
 use crate::database::Database;
 use crate::decode::Decode;
 use crate::type_info::TypeInfo;
@@ -26,12 +27,18 @@ pub trait TypeChecking: Database {
     ///
     /// If the type has a borrowed equivalent suitable for query parameters,
     /// this is that borrowed type.
-    fn param_type_for_id(id: &Self::TypeInfo) -> Option<&'static str>;
+    fn param_type_for_id(
+        id: &Self::TypeInfo,
+        preferred_crates: &PreferredCrates,
+    ) -> Result<&'static str, Error>;
 
     /// Get the full path of the Rust type that corresponds to the given `TypeInfo`, if applicable.
     ///
     /// Always returns the owned version of the type, suitable for decoding from `Row`.
-    fn return_type_for_id(id: &Self::TypeInfo) -> Option<&'static str>;
+    fn return_type_for_id(
+        id: &Self::TypeInfo,
+        preferred_crates: &PreferredCrates,
+    ) -> Result<&'static str, Error>;
 
     /// Get the name of the Cargo feature gate that must be enabled to process the given `TypeInfo`,
     /// if applicable.
@@ -43,6 +50,22 @@ pub trait TypeChecking: Database {
     fn fmt_value_debug(value: &<Self as Database>::Value) -> FmtValue<'_, Self>;
 }
 
+pub type Result<T, E = Error> = std::result::Result<T, E>;
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+    #[error("no built-in mapping found for SQL type; a type override may be required")]
+    NoMappingFound,
+    #[error("Cargo feature for configured `macros.preferred-crates.date-time` not enabled")]
+    DateTimeCrateFeatureNotEnabled,
+    #[error("Cargo feature for configured `macros.preferred-crates.numeric` not enabled")]
+    NumericCrateFeatureNotEnabled,
+    #[error("multiple date-time types are possible; falling back to `{fallback}`")]
+    AmbiguousDateTimeType { fallback: &'static str },
+    #[error("multiple numeric types are possible; falling back to `{fallback}`")]
+    AmbiguousNumericType { fallback: &'static str },
+}
+
 /// An adapter for [`Value`] which attempts to decode the value and format it when printed using [`Debug`].
 pub struct FmtValue<'v, DB>
 where
@@ -140,36 +163,304 @@ macro_rules! impl_type_checking {
         },
         ParamChecking::$param_checking:ident,
         feature-types: $ty_info:ident => $get_gate:expr,
+        datetime-types: {
+            chrono: {
+                $($chrono_ty:ty $(| $chrono_input:ty)?),*$(,)?
+            },
+            time: {
+                $($time_ty:ty $(| $time_input:ty)?),*$(,)?
+            },
+        },
+        numeric-types: {
+            bigdecimal: {
+                $($bigdecimal_ty:ty $(| $bigdecimal_input:ty)?),*$(,)?
+            },
+            rust_decimal: {
+                $($rust_decimal_ty:ty $(| $rust_decimal_input:ty)?),*$(,)?
+            },
+        },
     ) => {
         impl $crate::type_checking::TypeChecking for $database {
             const PARAM_CHECKING: $crate::type_checking::ParamChecking = $crate::type_checking::ParamChecking::$param_checking;
 
-            fn param_type_for_id(info: &Self::TypeInfo) -> Option<&'static str> {
-                match () {
+            fn param_type_for_id(
+                info: &Self::TypeInfo,
+                preferred_crates: &$crate::config::macros::PreferredCrates,
+            ) -> Result<&'static str, $crate::type_checking::Error> {
+                use $crate::config::macros::{DateTimeCrate, NumericCrate};
+                use $crate::type_checking::Error;
+
+                // Check non-special types
+                // ---------------------
+                $(
+                    $(#[$meta])?
+                    if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info {
+                        return Ok($crate::select_input_type!($ty $(, $input)?));
+                    }
+                )*
+
+                $(
+                    $(#[$meta])?
+                    if <$ty as sqlx_core::types::Type<$database>>::compatible(info) {
+                        return Ok($crate::select_input_type!($ty $(, $input)?));
+                    }
+                )*
+
+                // Check `macros.preferred-crates.date-time`
+                //
+                // Due to legacy reasons, `time` takes precedence over `chrono` if both are enabled.
+                // Any crates added later should be _lower_ priority than `chrono` to avoid breakages.
+                // ----------------------------------------
+                #[cfg(feature = "time")]
+                if matches!(preferred_crates.date_time, DateTimeCrate::Time | DateTimeCrate::Inferred) {
+                    $(
+                        if <$time_ty as sqlx_core::types::Type<$database>>::type_info() == *info {
+                            if cfg!(feature = "chrono") {
+                                return Err($crate::type_checking::Error::AmbiguousDateTimeType {
+                                    fallback: $crate::select_input_type!($time_ty $(, $time_input)?),
+                                });
+                            }
+
+                            return Ok($crate::select_input_type!($time_ty $(, $time_input)?));
+                        }
+                    )*
+
+                    $(
+                        if <$time_ty as sqlx_core::types::Type<$database>>::compatible(info) {
+                            if cfg!(feature = "chrono") {
+                                return Err($crate::type_checking::Error::AmbiguousDateTimeType {
+                                    fallback: $crate::select_input_type!($time_ty $(, $time_input)?),
+                                });
+                            }
+
+                            return Ok($crate::select_input_type!($time_ty $(, $time_input)?));
+                        }
+                    )*
+                }
+
+                #[cfg(not(feature = "time"))]
+                if preferred_crates.date_time == DateTimeCrate::Time {
+                    return Err(Error::DateTimeCrateFeatureNotEnabled);
+                }
+
+                #[cfg(feature = "chrono")]
+                if matches!(preferred_crates.date_time, DateTimeCrate::Chrono | DateTimeCrate::Inferred) {
+                    $(
+                        if <$chrono_ty as sqlx_core::types::Type<$database>>::type_info() == *info {
+                            return Ok($crate::select_input_type!($chrono_ty $(, $chrono_input)?));
+                        }
+                    )*
+
                 $(
-                    $(#[$meta])?
-                    _ if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info => Some($crate::select_input_type!($ty $(, $input)?)),
+                        if <$chrono_ty as sqlx_core::types::Type<$database>>::compatible(info) {
+                            return Ok($crate::select_input_type!($chrono_ty $(, $chrono_input)?));
+                        }
+                    )*
+                }
+
+                #[cfg(not(feature = "chrono"))]
+                if preferred_crates.date_time == DateTimeCrate::Chrono {
+                    return Err(Error::DateTimeCrateFeatureNotEnabled);
+                }
+
+                // Check `macros.preferred-crates.numeric`
+                //
+                // Due to legacy reasons, `bigdecimal` takes precedence over `rust_decimal` if
+                // both are enabled.
+                // ----------------------------------------
+                #[cfg(feature = "bigdecimal")]
+                if matches!(preferred_crates.numeric, NumericCrate::BigDecimal | NumericCrate::Inferred) {
+                    $(
+                        if <$bigdecimal_ty as sqlx_core::types::Type<$database>>::type_info() == *info {
+                            if cfg!(feature = "rust_decimal") {
+                                return Err($crate::type_checking::Error::AmbiguousNumericType {
+                                    fallback: $crate::select_input_type!($bigdecimal_ty $(, $bigdecimal_input)?),
+                                });
+                            }
+
+                            return Ok($crate::select_input_type!($bigdecimal_ty $(, $bigdecimal_input)?));
+                        }
                    )*
+
                    $(
-                    $(#[$meta])?
-                    _ if <$ty as sqlx_core::types::Type<$database>>::compatible(info) => Some($crate::select_input_type!($ty $(, $input)?)),
+                        if <$bigdecimal_ty as sqlx_core::types::Type<$database>>::compatible(info) {
+                            if cfg!(feature = "rust_decimal") {
+                                return Err($crate::type_checking::Error::AmbiguousNumericType {
+                                    fallback: $crate::select_input_type!($bigdecimal_ty $(, $bigdecimal_input)?),
+                                });
+                            }
+
+                            return Ok($crate::select_input_type!($bigdecimal_ty $(, $bigdecimal_input)?));
+                        }
                    )*
-                    _ => None
                }
+
+                #[cfg(not(feature = "bigdecimal"))]
+                if preferred_crates.numeric == NumericCrate::BigDecimal {
+                    return Err(Error::NumericCrateFeatureNotEnabled);
+                }
+
+                #[cfg(feature = "rust_decimal")]
+                if matches!(preferred_crates.numeric, NumericCrate::RustDecimal | NumericCrate::Inferred) {
+                    $(
+                        if <$rust_decimal_ty as sqlx_core::types::Type<$database>>::type_info() == *info {
+                            return Ok($crate::select_input_type!($rust_decimal_ty $(, $rust_decimal_input)?));
+                        }
+                    )*
+
+                    $(
+                        if <$rust_decimal_ty as sqlx_core::types::Type<$database>>::compatible(info) {
+                            return Ok($crate::select_input_type!($rust_decimal_ty $(, $rust_decimal_input)?));
+                        }
+                    )*
+                }
+
+                #[cfg(not(feature = "rust_decimal"))]
+                if preferred_crates.numeric == NumericCrate::RustDecimal {
+                    return Err(Error::NumericCrateFeatureNotEnabled);
+                }
+
+                Err(Error::NoMappingFound)
             }
 
-            fn return_type_for_id(info: &Self::TypeInfo) -> Option<&'static str> {
-                match () {
+            fn return_type_for_id(
+                info: &Self::TypeInfo,
+                preferred_crates: &$crate::config::macros::PreferredCrates,
+            ) -> Result<&'static str, $crate::type_checking::Error> {
+                use $crate::config::macros::{DateTimeCrate, NumericCrate};
+                use $crate::type_checking::Error;
+
+                // Check non-special types
+                // ---------------------
+                $(
+                    $(#[$meta])?
+                    if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info {
+                        return Ok(stringify!($ty));
+                    }
+                )*
+
+                $(
+                    $(#[$meta])?
+                    if <$ty as sqlx_core::types::Type<$database>>::compatible(info) {
+                        return Ok(stringify!($ty));
+                    }
+                )*
+
+                // Check `macros.preferred-crates.date-time`
+                //
+                // Due to legacy reasons, `time` takes precedence over `chrono` if both are enabled.
+                // Any crates added later should be _lower_ priority than `chrono` to avoid breakages.
+                // ----------------------------------------
+                #[cfg(feature = "time")]
+                if matches!(preferred_crates.date_time, DateTimeCrate::Time | DateTimeCrate::Inferred) {
                    $(
- _ if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info => Some(stringify!($ty)), + if <$time_ty as sqlx_core::types::Type<$database>>::type_info() == *info { + if cfg!(feature = "chrono") { + return Err($crate::type_checking::Error::AmbiguousDateTimeType { + fallback: stringify!($time_ty), + }); + } + + return Ok(stringify!($time_ty)); + } + )* + + $( + if <$time_ty as sqlx_core::types::Type<$database>>::compatible(info) { + if cfg!(feature = "chrono") { + return Err($crate::type_checking::Error::AmbiguousDateTimeType { + fallback: stringify!($time_ty), + }); + } + + return Ok(stringify!($time_ty)); + } + )* + } + + #[cfg(not(feature = "time"))] + if preferred_crates.date_time == DateTimeCrate::Time { + return Err(Error::DateTimeCrateFeatureNotEnabled); + } + + #[cfg(feature = "chrono")] + if matches!(preferred_crates.date_time, DateTimeCrate::Chrono | DateTimeCrate::Inferred) { + $( + if <$chrono_ty as sqlx_core::types::Type<$database>>::type_info() == *info { + return Ok(stringify!($chrono_ty)); + } + )* + + $( + if <$chrono_ty as sqlx_core::types::Type<$database>>::compatible(info) { + return Ok(stringify!($chrono_ty)); + } + )* + } + + #[cfg(not(feature = "chrono"))] + if preferred_crates.date_time == DateTimeCrate::Chrono { + return Err(Error::DateTimeCrateFeatureNotEnabled); + } + + // Check `macros.preferred-crates.numeric` + // + // Due to legacy reasons, `bigdecimal` takes precedent over `rust_decimal` if + // both are enabled. + // ---------------------------------------- + #[cfg(feature = "bigdecimal")] + if matches!(preferred_crates.numeric, NumericCrate::BigDecimal | NumericCrate::Inferred) { + $( + if <$bigdecimal_ty as sqlx_core::types::Type<$database>>::type_info() == *info { + if cfg!(feature = "rust_decimal") { + return Err($crate::type_checking::Error::AmbiguousNumericType { + fallback: stringify!($bigdecimal_ty), + }); + } + + return Ok(stringify!($bigdecimal_ty)); + } + )* + + $( + if <$bigdecimal_ty as sqlx_core::types::Type<$database>>::compatible(info) { + if cfg!(feature = "rust_decimal") { + return Err($crate::type_checking::Error::AmbiguousNumericType { + fallback: stringify!($bigdecimal_ty), + }); + } + + return Ok(stringify!($bigdecimal_ty)); + } )* + } + + #[cfg(not(feature = "bigdecimal"))] + if preferred_crates.numeric == NumericCrate::BigDecimal { + return Err(Error::NumericCrateFeatureNotEnabled); + } + + #[cfg(feature = "rust_decimal")] + if matches!(preferred_crates.numeric, NumericCrate::RustDecimal | NumericCrate::Inferred) { + $( + if <$rust_decimal_ty as sqlx_core::types::Type<$database>>::type_info() == *info { + return Ok($crate::select_input_type!($rust_decimal_ty $(, $rust_decimal_input)?)); + } + )* + $( - $(#[$meta])? - _ if <$ty as sqlx_core::types::Type<$database>>::compatible(info) => Some(stringify!($ty)), + if <$rust_decimal_ty as sqlx_core::types::Type<$database>>::compatible(info) { + return Ok($crate::select_input_type!($rust_decimal_ty $(, $rust_decimal_input)?)); + } )* - _ => None } + + #[cfg(not(feature = "rust_decimal"))] + if preferred_crates.numeric == NumericCrate::RustDecimal { + return Err(Error::NumericCrateFeatureNotEnabled); + } + + Err(Error::NoMappingFound) } fn get_feature_gate($ty_info: &Self::TypeInfo) -> Option<&'static str> { @@ -181,13 +472,50 @@ macro_rules! impl_type_checking { let info = value.type_info(); - match () { + #[cfg(feature = "time")] + { $( - $(#[$meta])? 
- _ if <$ty as sqlx_core::types::Type<$database>>::compatible(&info) => $crate::type_checking::FmtValue::debug::<$ty>(value), + if <$time_ty as sqlx_core::types::Type<$database>>::compatible(&info) { + return $crate::type_checking::FmtValue::debug::<$time_ty>(value); + } )* - _ => $crate::type_checking::FmtValue::unknown(value), } + + #[cfg(feature = "chrono")] + { + $( + if <$chrono_ty as sqlx_core::types::Type<$database>>::compatible(&info) { + return $crate::type_checking::FmtValue::debug::<$chrono_ty>(value); + } + )* + } + + #[cfg(feature = "bigdecimal")] + { + $( + if <$bigdecimal_ty as sqlx_core::types::Type<$database>>::compatible(&info) { + return $crate::type_checking::FmtValue::debug::<$bigdecimal_ty>(value); + } + )* + } + + #[cfg(feature = "rust_decimal")] + { + $( + if <$rust_decimal_ty as sqlx_core::types::Type<$database>>::compatible(&info) { + return $crate::type_checking::FmtValue::debug::<$rust_decimal_ty>(value); + } + )* + } + + $( + $(#[$meta])? + if <$ty as sqlx_core::types::Type<$database>>::compatible(&info) { + return $crate::type_checking::FmtValue::debug::<$ty>(value); + } + )* + + $crate::type_checking::FmtValue::unknown(value) } } }; diff --git a/sqlx-macros-core/Cargo.toml b/sqlx-macros-core/Cargo.toml index d78cbe3d63..8263f762d8 100644 --- a/sqlx-macros-core/Cargo.toml +++ b/sqlx-macros-core/Cargo.toml @@ -27,6 +27,8 @@ derive = [] macros = [] migrate = ["sqlx-core/migrate"] +sqlx-toml = ["sqlx-core/sqlx-toml"] + # database mysql = ["sqlx-mysql"] postgres = ["sqlx-postgres"] diff --git a/sqlx-macros-core/src/migrate.rs b/sqlx-macros-core/src/migrate.rs index c9cf5b8eb1..aa0d56fb8c 100644 --- a/sqlx-macros-core/src/migrate.rs +++ b/sqlx-macros-core/src/migrate.rs @@ -3,11 +3,13 @@ extern crate proc_macro; use std::path::{Path, PathBuf}; -use proc_macro2::TokenStream; +use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens, TokenStreamExt}; +use sqlx_core::config::Config; +use sqlx_core::migrate::{Migration, MigrationType}; use syn::LitStr; -use sqlx_core::migrate::{Migration, MigrationType}; +pub const DEFAULT_PATH: &str = "./migrations"; pub struct QuoteMigrationType(MigrationType); @@ -81,20 +83,26 @@ impl ToTokens for QuoteMigration { } } -pub fn expand_migrator_from_lit_dir(dir: LitStr) -> crate::Result { - expand_migrator_from_dir(&dir.value(), dir.span()) +pub fn default_path(config: &Config) -> &str { + config + .migrate + .migrations_dir + .as_deref() + .unwrap_or(DEFAULT_PATH) } -pub(crate) fn expand_migrator_from_dir( - dir: &str, - err_span: proc_macro2::Span, -) -> crate::Result { - let path = crate::common::resolve_path(dir, err_span)?; +pub fn expand(path_arg: Option) -> crate::Result { + let config = Config::try_from_crate()?; + + let path = match path_arg { + Some(path_arg) => crate::common::resolve_path(path_arg.value(), path_arg.span())?, + None => { crate::common::resolve_path(default_path(config), Span::call_site()) }?, + }; - expand_migrator(&path) + expand_with_path(config, &path) } -pub(crate) fn expand_migrator(path: &Path) -> crate::Result { +pub fn expand_with_path(config: &Config, path: &Path) -> crate::Result { let path = path.canonicalize().map_err(|e| { format!( "error canonicalizing migration directory {}: {e}", @@ -102,11 +110,19 @@ pub(crate) fn expand_migrator(path: &Path) -> crate::Result { ) })?; + let resolve_config = config.migrate.to_resolve_config(); + // Use the same code path to resolve migrations at compile time and runtime. - let migrations = sqlx_core::migrate::resolve_blocking(&path)? 
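// (Call-site effect, for reference: with `expand()` taking `Option<LitStr>`,
// the path argument to `sqlx::migrate!()` is now optional. An invocation like
//
//     static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!();
//
// resolves through `default_path(config)` above: the directory configured in
// the crate's `sqlx.toml`, falling back to `DEFAULT_PATH` ("./migrations").)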
+    let migrations = sqlx_core::migrate::resolve_blocking_with_config(&path, &resolve_config)?
        .into_iter()
        .map(|(migration, path)| QuoteMigration { migration, path });

+    let table_name = config.migrate.table_name();
+
+    let create_schemas = config.migrate.create_schemas.iter().map(|schema_name| {
+        quote! { ::std::borrow::Cow::Borrowed(#schema_name) }
+    });
+
    #[cfg(any(sqlx_macros_unstable, procmacro2_semver_exempt))]
    {
        let path = path.to_str().ok_or_else(|| {
@@ -124,6 +140,8 @@ pub(crate) fn expand_migrator(path: &Path) -> crate::Result<TokenStream> {
            migrations: ::std::borrow::Cow::Borrowed(&[
                #(#migrations),*
            ]),
+            create_schemas: ::std::borrow::Cow::Borrowed(&[#(#create_schemas),*]),
+            table_name: ::std::borrow::Cow::Borrowed(#table_name),
            ..::sqlx::migrate::Migrator::DEFAULT
        }
    })
diff --git a/sqlx-macros-core/src/query/args.rs b/sqlx-macros-core/src/query/args.rs
index 788a9aadc5..1b338efa3e 100644
--- a/sqlx-macros-core/src/query/args.rs
+++ b/sqlx-macros-core/src/query/args.rs
@@ -1,9 +1,12 @@
 use crate::database::DatabaseExt;
-use crate::query::QueryMacroInput;
+use crate::query::{QueryMacroInput, Warnings};
 use either::Either;
 use proc_macro2::TokenStream;
 use quote::{format_ident, quote, quote_spanned};
+use sqlx_core::config::Config;
 use sqlx_core::describe::Describe;
+use sqlx_core::type_checking;
+use sqlx_core::type_info::TypeInfo;
 use syn::spanned::Spanned;
 use syn::{Expr, ExprCast, ExprGroup, Type};

@@ -11,6 +14,8 @@ use syn::{Expr, ExprCast, ExprGroup, Type};
 /// and binds them to `DB::Arguments` with the ident `query_args`.
 pub fn quote_args<DB: DatabaseExt>(
     input: &QueryMacroInput,
+    config: &Config,
+    warnings: &mut Warnings,
     info: &Describe<DB>,
 ) -> crate::Result<TokenStream> {
     let db_path = DB::db_path();
@@ -55,27 +60,7 @@ pub fn quote_args<DB: DatabaseExt>(
         return Ok(quote!());
     }

-    let param_ty =
-        DB::param_type_for_id(param_ty)
-            .ok_or_else(|| {
-                if let Some(feature_gate) = DB::get_feature_gate(param_ty) {
-                    format!(
-                        "optional sqlx feature `{}` required for type {} of param #{}",
-                        feature_gate,
-                        param_ty,
-                        i + 1,
-                    )
-                } else {
-                    format!(
-                        "no built in mapping found for type {} for param #{}; \
-                         a type override may be required, see documentation for details",
-                        param_ty,
-                        i + 1
-                    )
-                }
-            })?
- .parse::() - .map_err(|_| format!("Rust type mapping for {param_ty} not parsable"))?; + let param_ty = get_param_type::(param_ty, config, warnings, i)?; Ok(quote_spanned!(expr.span() => // this shouldn't actually run @@ -120,6 +105,77 @@ pub fn quote_args( }) } +fn get_param_type( + param_ty: &DB::TypeInfo, + config: &Config, + warnings: &mut Warnings, + i: usize, +) -> crate::Result { + if let Some(type_override) = config.macros.type_override(param_ty.name()) { + return Ok(type_override.parse()?); + } + + let err = match DB::param_type_for_id(param_ty, &config.macros.preferred_crates) { + Ok(t) => return Ok(t.parse()?), + Err(e) => e, + }; + + let param_num = i + 1; + + let message = match err { + type_checking::Error::NoMappingFound => { + if let Some(feature_gate) = DB::get_feature_gate(param_ty) { + format!( + "optional sqlx feature `{feature_gate}` required for type {param_ty} of param #{param_num}", + ) + } else { + format!( + "no built-in mapping for type {param_ty} of param #{param_num}; \ + a type override may be required, see documentation for details" + ) + } + } + type_checking::Error::DateTimeCrateFeatureNotEnabled => { + let feature_gate = config + .macros + .preferred_crates + .date_time + .crate_name() + .expect("BUG: got feature-not-enabled error for DateTimeCrate::Inferred"); + + format!( + "SQLx feature `{feature_gate}` required for type {param_ty} of param #{param_num} \ + (configured by `macros.preferred-crates.date-time` in sqlx.toml)", + ) + } + type_checking::Error::NumericCrateFeatureNotEnabled => { + let feature_gate = config + .macros + .preferred_crates + .numeric + .crate_name() + .expect("BUG: got feature-not-enabled error for NumericCrate::Inferred"); + + format!( + "SQLx feature `{feature_gate}` required for type {param_ty} of param #{param_num} \ + (configured by `macros.preferred-crates.numeric` in sqlx.toml)", + ) + } + + type_checking::Error::AmbiguousDateTimeType { fallback } => { + warnings.ambiguous_datetime = true; + return Ok(fallback.parse()?); + } + + type_checking::Error::AmbiguousNumericType { fallback } => { + warnings.ambiguous_numeric = true; + return Ok(fallback.parse()?); + } + }; + + Err(message.into()) +} + fn get_type_override(expr: &Expr) -> Option<&Type> { match expr { Expr::Group(group) => get_type_override(&group.expr), diff --git a/sqlx-macros-core/src/query/mod.rs b/sqlx-macros-core/src/query/mod.rs index a51137413e..0cd8771781 100644 --- a/sqlx-macros-core/src/query/mod.rs +++ b/sqlx-macros-core/src/query/mod.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{hash_map, HashMap}; use std::path::{Path, PathBuf}; use std::sync::{Arc, Mutex}; use std::{fs, io}; @@ -16,6 +16,7 @@ use crate::database::DatabaseExt; use crate::query::data::{hash_string, DynQueryData, QueryData}; use crate::query::input::RecordType; use either::Either; +use sqlx_core::config::Config; use url::Url; mod args; @@ -112,7 +113,7 @@ static METADATA: Lazy>> = Lazy::new(Default::def // If we are in a workspace, lookup `workspace_root` since `CARGO_MANIFEST_DIR` won't // reflect the workspace dir: https://github.com/rust-lang/cargo/issues/3946 -fn init_metadata(manifest_dir: &String) -> Metadata { +fn init_metadata(manifest_dir: &String) -> crate::Result { let manifest_dir: PathBuf = manifest_dir.into(); let (database_url, offline, offline_dir) = load_dot_env(&manifest_dir); @@ -123,15 +124,17 @@ fn init_metadata(manifest_dir: &String) -> Metadata { .map(|s| s.eq_ignore_ascii_case("true") || s == "1") .unwrap_or(false); - let database_url 
= env("DATABASE_URL").ok().or(database_url); + let var_name = Config::try_from_crate()?.common.database_url_var(); - Metadata { + let database_url = env(var_name).ok().or(database_url); + + Ok(Metadata { manifest_dir, offline, database_url, offline_dir, workspace_root: Arc::new(Mutex::new(None)), - } + }) } pub fn expand_input<'a>( @@ -149,9 +152,13 @@ pub fn expand_input<'a>( guard }); - let metadata = metadata_lock - .entry(manifest_dir) - .or_insert_with_key(init_metadata); + let metadata = match metadata_lock.entry(manifest_dir) { + hash_map::Entry::Occupied(occupied) => occupied.into_mut(), + hash_map::Entry::Vacant(vacant) => { + let metadata = init_metadata(vacant.key())?; + vacant.insert(metadata) + } + }; let data_source = match &metadata { Metadata { @@ -236,6 +243,12 @@ impl DescribeExt for Describe where { } +#[derive(Default)] +struct Warnings { + ambiguous_datetime: bool, + ambiguous_numeric: bool, +} + fn expand_with_data( input: QueryMacroInput, data: QueryData, @@ -244,6 +257,8 @@ fn expand_with_data( where Describe: DescribeExt, { + let config = Config::try_from_crate()?; + // validate at the minimum that our args match the query's input parameters let num_parameters = match data.describe.parameters() { Some(Either::Left(params)) => Some(params.len()), @@ -260,7 +275,9 @@ where } } - let args_tokens = args::quote_args(&input, &data.describe)?; + let mut warnings = Warnings::default(); + + let args_tokens = args::quote_args(&input, config, &mut warnings, &data.describe)?; let query_args = format_ident!("query_args"); @@ -279,7 +296,7 @@ where } else { match input.record_type { RecordType::Generated => { - let columns = output::columns_to_rust::(&data.describe)?; + let columns = output::columns_to_rust::(&data.describe, config, &mut warnings)?; let record_name: Type = syn::parse_str("Record").unwrap(); @@ -315,22 +332,44 @@ where record_tokens } RecordType::Given(ref out_ty) => { - let columns = output::columns_to_rust::(&data.describe)?; + let columns = output::columns_to_rust::(&data.describe, config, &mut warnings)?; output::quote_query_as::(&input, out_ty, &query_args, &columns) } - RecordType::Scalar => { - output::quote_query_scalar::(&input, &query_args, &data.describe)? - } + RecordType::Scalar => output::quote_query_scalar::( + &input, + config, + &mut warnings, + &query_args, + &data.describe, + )?, } }; + let mut warnings_out = TokenStream::new(); + + if warnings.ambiguous_datetime { + // Warns if the date-time crate is inferred but both `chrono` and `time` are enabled + warnings_out.extend(quote! { + ::sqlx::warn_on_ambiguous_inferred_date_time_crate(); + }); + } + + if warnings.ambiguous_numeric { + // Warns if the numeric crate is inferred but both `bigdecimal` and `rust_decimal` are enabled + warnings_out.extend(quote! { + ::sqlx::warn_on_ambiguous_inferred_numeric_crate(); + }); + } + let ret_tokens = quote! 
{ { #[allow(clippy::all)] { use ::sqlx::Arguments as _; + #warnings_out + #args_tokens #output diff --git a/sqlx-macros-core/src/query/output.rs b/sqlx-macros-core/src/query/output.rs index 3641e55db5..987dcaa3cb 100644 --- a/sqlx-macros-core/src/query/output.rs +++ b/sqlx-macros-core/src/query/output.rs @@ -2,13 +2,16 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens, TokenStreamExt}; use syn::Type; -use sqlx_core::column::Column; +use sqlx_core::column::{Column, ColumnOrigin}; use sqlx_core::describe::Describe; use crate::database::DatabaseExt; -use crate::query::QueryMacroInput; +use crate::query::{QueryMacroInput, Warnings}; +use sqlx_core::config::Config; +use sqlx_core::type_checking; use sqlx_core::type_checking::TypeChecking; +use sqlx_core::type_info::TypeInfo; use std::fmt::{self, Display, Formatter}; use syn::parse::{Parse, ParseStream}; use syn::Token; @@ -76,13 +79,22 @@ impl Display for DisplayColumn<'_> { } } -pub fn columns_to_rust(describe: &Describe) -> crate::Result> { +pub fn columns_to_rust( + describe: &Describe, + config: &Config, + warnings: &mut Warnings, +) -> crate::Result> { (0..describe.columns().len()) - .map(|i| column_to_rust(describe, i)) + .map(|i| column_to_rust(describe, config, warnings, i)) .collect::>>() } -fn column_to_rust(describe: &Describe, i: usize) -> crate::Result { +fn column_to_rust( + describe: &Describe, + config: &Config, + warnings: &mut Warnings, + i: usize, +) -> crate::Result { let column = &describe.columns()[i]; // add raw prefix to all identifiers @@ -106,7 +118,7 @@ fn column_to_rust(describe: &Describe, i: usize) -> crate:: (ColumnTypeOverride::Wildcard, true) => ColumnType::OptWildcard, (ColumnTypeOverride::None, _) => { - let type_ = get_column_type::(i, column); + let type_ = get_column_type::(config, warnings, i, column); if !nullable { ColumnType::Exact(type_) } else { @@ -193,6 +205,8 @@ pub fn quote_query_as( pub fn quote_query_scalar( input: &QueryMacroInput, + config: &Config, + warnings: &mut Warnings, bind_args: &Ident, describe: &Describe, ) -> crate::Result { @@ -207,10 +221,10 @@ pub fn quote_query_scalar( } // attempt to parse a column override, otherwise fall back to the inferred type of the column - let ty = if let Ok(rust_col) = column_to_rust(describe, 0) { + let ty = if let Ok(rust_col) = column_to_rust(describe, config, warnings, 0) { rust_col.type_.to_token_stream() } else if input.checked { - let ty = get_column_type::(0, &columns[0]); + let ty = get_column_type::(config, warnings, 0, &columns[0]); if describe.nullable(0).unwrap_or(true) { quote! 
{ ::std::option::Option<#ty> } } else { @@ -228,37 +242,105 @@ pub fn quote_query_scalar( }) } -fn get_column_type(i: usize, column: &DB::Column) -> TokenStream { +fn get_column_type( + config: &Config, + warnings: &mut Warnings, + i: usize, + column: &DB::Column, +) -> TokenStream { + if let ColumnOrigin::Table(origin) = column.origin() { + if let Some(column_override) = config.macros.column_override(&origin.table, &origin.name) { + return column_override.parse().unwrap(); + } + } + let type_info = column.type_info(); - ::return_type_for_id(type_info).map_or_else( - || { - let message = - if let Some(feature_gate) = ::get_feature_gate(type_info) { - format!( - "SQLx feature `{feat}` required for type {ty} of {col}", - ty = &type_info, - feat = feature_gate, - col = DisplayColumn { - idx: i, - name: column.name() - } - ) - } else { - format!( - "no built in mapping found for type {ty} of {col}; \ - a type override may be required, see documentation for details", - ty = type_info, - col = DisplayColumn { - idx: i, - name: column.name() - } - ) - }; - syn::Error::new(Span::call_site(), message).to_compile_error() - }, - |t| t.parse().unwrap(), - ) + if let Some(type_override) = config.macros.type_override(type_info.name()) { + return type_override.parse().unwrap(); + } + + let err = match ::return_type_for_id( + type_info, + &config.macros.preferred_crates, + ) { + Ok(t) => return t.parse().unwrap(), + Err(e) => e, + }; + + let message = match err { + type_checking::Error::NoMappingFound => { + if let Some(feature_gate) = ::get_feature_gate(type_info) { + format!( + "SQLx feature `{feat}` required for type {ty} of {col}", + ty = &type_info, + feat = feature_gate, + col = DisplayColumn { + idx: i, + name: column.name() + } + ) + } else { + format!( + "no built-in mapping found for type {ty} of {col}; \ + a type override may be required, see documentation for details", + ty = type_info, + col = DisplayColumn { + idx: i, + name: column.name() + } + ) + } + } + type_checking::Error::DateTimeCrateFeatureNotEnabled => { + let feature_gate = config + .macros + .preferred_crates + .date_time + .crate_name() + .expect("BUG: got feature-not-enabled error for DateTimeCrate::Inferred"); + + format!( + "SQLx feature `{feat}` required for type {ty} of {col} \ + (configured by `macros.preferred-crates.date-time` in sqlx.toml)", + ty = &type_info, + feat = feature_gate, + col = DisplayColumn { + idx: i, + name: column.name() + } + ) + } + type_checking::Error::NumericCrateFeatureNotEnabled => { + let feature_gate = config + .macros + .preferred_crates + .numeric + .crate_name() + .expect("BUG: got feature-not-enabled error for NumericCrate::Inferred"); + + format!( + "SQLx feature `{feat}` required for type {ty} of {col} \ + (configured by `macros.preferred-crates.numeric` in sqlx.toml)", + ty = &type_info, + feat = feature_gate, + col = DisplayColumn { + idx: i, + name: column.name() + } + ) + } + type_checking::Error::AmbiguousDateTimeType { fallback } => { + warnings.ambiguous_datetime = true; + return fallback.parse().unwrap(); + } + type_checking::Error::AmbiguousNumericType { fallback } => { + warnings.ambiguous_numeric = true; + return fallback.parse().unwrap(); + } + }; + + syn::Error::new(Span::call_site(), message).to_compile_error() } impl ColumnDecl { diff --git a/sqlx-macros-core/src/test_attr.rs b/sqlx-macros-core/src/test_attr.rs index 3104a0e743..9565b5f2f5 100644 --- a/sqlx-macros-core/src/test_attr.rs +++ b/sqlx-macros-core/src/test_attr.rs @@ -77,6 +77,8 @@ fn expand_simple(input: 
syn::ItemFn) -> TokenStream {

 #[cfg(feature = "migrate")]
 fn expand_advanced(args: AttributeArgs, input: syn::ItemFn) -> crate::Result<TokenStream> {
+    let config = sqlx_core::config::Config::try_from_crate()?;
+
     let ret = &input.sig.output;
     let name = &input.sig.ident;
     let inputs = &input.sig.inputs;
@@ -143,15 +145,16 @@ fn expand_advanced(args: AttributeArgs, input: syn::ItemFn) -> crate::Result<TokenStream> {
-            let migrator = crate::migrate::expand_migrator_from_lit_dir(path)?;
+            let migrator = crate::migrate::expand(Some(path))?;
             quote! { args.migrator(&#migrator); }
         }
         MigrationsOpt::InferredPath if !inputs.is_empty() => {
-            let migrations_path =
-                crate::common::resolve_path("./migrations", proc_macro2::Span::call_site())?;
+            let path = crate::migrate::default_path(config);
+
+            let resolved_path = crate::common::resolve_path(path, proc_macro2::Span::call_site())?;

-            if migrations_path.is_dir() {
-                let migrator = crate::migrate::expand_migrator(&migrations_path)?;
+            if resolved_path.is_dir() {
+                let migrator = crate::migrate::expand_with_path(config, &resolved_path)?;
                 quote! { args.migrator(&#migrator); }
             } else {
                 quote! {}
diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml
index 032a190dd1..23079a3810 100644
--- a/sqlx-macros/Cargo.toml
+++ b/sqlx-macros/Cargo.toml
@@ -28,6 +28,8 @@ derive = ["sqlx-macros-core/derive"]
 macros = ["sqlx-macros-core/macros"]
 migrate = ["sqlx-macros-core/migrate"]

+sqlx-toml = ["sqlx-macros-core/sqlx-toml"]
+
 # database
 mysql = ["sqlx-macros-core/mysql"]
 postgres = ["sqlx-macros-core/postgres"]
diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs
index 987794acbc..ccffc9bd2a 100644
--- a/sqlx-macros/src/lib.rs
+++ b/sqlx-macros/src/lib.rs
@@ -68,8 +68,8 @@ pub fn derive_from_row(input: TokenStream) -> TokenStream {
 pub fn migrate(input: TokenStream) -> TokenStream {
     use syn::LitStr;

-    let input = syn::parse_macro_input!(input as LitStr);
-    match migrate::expand_migrator_from_lit_dir(input) {
+    let input = syn::parse_macro_input!(input as Option<LitStr>);
+    match migrate::expand(input) {
         Ok(ts) => ts.into(),
         Err(e) => {
             if let Some(parse_err) = e.downcast_ref::<syn::Error>() {
diff --git a/sqlx-mysql/src/column.rs b/sqlx-mysql/src/column.rs
index 1bb841b9a1..457cf991d3 100644
--- a/sqlx-mysql/src/column.rs
+++ b/sqlx-mysql/src/column.rs
@@ -10,6 +10,9 @@ pub struct MySqlColumn {
     pub(crate) name: UStr,
     pub(crate) type_info: MySqlTypeInfo,

+    #[cfg_attr(feature = "offline", serde(default))]
+    pub(crate) origin: ColumnOrigin,
+
     #[cfg_attr(feature = "offline", serde(skip))]
     pub(crate) flags: Option<ColumnFlags>,
 }
@@ -28,4 +31,8 @@ impl Column for MySqlColumn {
     fn type_info(&self) -> &MySqlTypeInfo {
         &self.type_info
     }
+
+    fn origin(&self) -> ColumnOrigin {
+        self.origin.clone()
+    }
 }
diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs
index 44cb523f56..0b6a234f4e 100644
--- a/sqlx-mysql/src/connection/executor.rs
+++ b/sqlx-mysql/src/connection/executor.rs
@@ -22,6 +22,7 @@ use futures_core::future::BoxFuture;
 use futures_core::stream::BoxStream;
 use futures_core::Stream;
 use futures_util::TryStreamExt;
+use sqlx_core::column::{ColumnOrigin, TableColumn};
 use std::{borrow::Cow, pin::pin, sync::Arc};

 impl MySqlConnection {
@@ -385,11 +386,30 @@ async fn recv_result_columns(
 fn recv_next_result_column(def: &ColumnDefinition, ordinal: usize) -> Result<MySqlColumn, Error> {
     // prefer the alias if it is non-empty;
     // only fall back to the column name otherwise
+    let column_name = def.name()?;
+
     let name = match (def.name()?, def.alias()?)
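// (Note: `def.name()` is the column's origin name while `def.alias()` is the
// result-set label, so the `column_name` captured above keeps `ColumnOrigin`
// accurate even when the query renames the column with `AS`.)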
{ (_, alias) if !alias.is_empty() => UStr::new(alias), (name, _) => UStr::new(name), }; + let table = def.table()?; + + let origin = if table.is_empty() { + ColumnOrigin::Expression + } else { + let schema = def.schema()?; + + ColumnOrigin::Table(TableColumn { + table: if !schema.is_empty() { + format!("{schema}.{table}").into() + } else { + table.into() + }, + name: column_name.into(), + }) + }; + let type_info = MySqlTypeInfo::from_column(def); Ok(MySqlColumn { @@ -397,6 +417,7 @@ fn recv_next_result_column(def: &ColumnDefinition, ordinal: usize) -> Result Result<(MySqlConnectOptions, String), Error> { let mut options = MySqlConnectOptions::from_str(url)?; @@ -75,12 +74,28 @@ impl MigrateDatabase for MySql { } impl Migrate for MySqlConnection { - fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { + fn create_schema_if_not_exists<'e>( + &'e mut self, + schema_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async move { + // language=SQL + self.execute(&*format!(r#"CREATE SCHEMA IF NOT EXISTS {schema_name};"#)) + .await?; + + Ok(()) + }) + } + + fn ensure_migrations_table<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { // language=MySQL - self.execute( + self.execute(&*format!( r#" -CREATE TABLE IF NOT EXISTS _sqlx_migrations ( +CREATE TABLE IF NOT EXISTS {table_name} ( version BIGINT PRIMARY KEY, description TEXT NOT NULL, installed_on TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, @@ -88,20 +103,23 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( checksum BLOB NOT NULL, execution_time BIGINT NOT NULL ); - "#, - ) + "# + )) .await?; Ok(()) }) } - fn dirty_version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { + fn dirty_version<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL - let row: Option<(i64,)> = query_as( - "SELECT version FROM _sqlx_migrations WHERE success = false ORDER BY version LIMIT 1", - ) + let row: Option<(i64,)> = query_as(&format!( + "SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1" + )) .fetch_optional(self) .await?; @@ -109,15 +127,17 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn list_applied_migrations( - &mut self, - ) -> BoxFuture<'_, Result, MigrateError>> { + fn list_applied_migrations<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL - let rows: Vec<(i64, Vec)> = - query_as("SELECT version, checksum FROM _sqlx_migrations ORDER BY version") - .fetch_all(self) - .await?; + let rows: Vec<(i64, Vec)> = query_as(&format!( + "SELECT version, checksum FROM {table_name} ORDER BY version" + )) + .fetch_all(self) + .await?; let migrations = rows .into_iter() @@ -167,10 +187,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn apply<'e: 'm, 'm>( + fn apply<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result> { + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { Box::pin(async move { // Use a single transaction for the actual migration script and the essential bookeeping so we never // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966. @@ -187,12 +208,12 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( // `success=FALSE` and later modify the flag. 
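Every bookkeeping statement in this impl is now rendered with `format!` against the caller-supplied `table_name`, so renaming or schema-qualifying the migrations table (configurable in `sqlx.toml`, per the changelog) becomes purely a string-level concern. A minimal illustration; the helper function and the example table name are hypothetical:

```rust
// Hypothetical helper mirroring how `dirty_version` renders its query above.
fn dirty_version_sql(table_name: &str) -> String {
    format!("SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1")
}

fn main() {
    // e.g. a crate configured to keep its history in `payments._sqlx_migrations`
    assert_eq!(
        dirty_version_sql("payments._sqlx_migrations"),
        "SELECT version FROM payments._sqlx_migrations WHERE success = false ORDER BY version LIMIT 1",
    );
}
```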
// // language=MySQL - let _ = query( + let _ = query(&format!( r#" - INSERT INTO _sqlx_migrations ( version, description, success, checksum, execution_time ) + INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) VALUES ( ?, ?, FALSE, ?, -1 ) - "#, - ) + "# + )) .bind(migration.version) .bind(&*migration.description) .bind(&*migration.checksum) @@ -205,13 +226,13 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; // language=MySQL - let _ = query( + let _ = query(&format!( r#" - UPDATE _sqlx_migrations + UPDATE {table_name} SET success = TRUE WHERE version = ? - "#, - ) + "# + )) .bind(migration.version) .execute(&mut *tx) .await?; @@ -225,13 +246,13 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( let elapsed = start.elapsed(); #[allow(clippy::cast_possible_truncation)] - let _ = query( + let _ = query(&format!( r#" - UPDATE _sqlx_migrations + UPDATE {table_name} SET execution_time = ? WHERE version = ? - "#, - ) + "# + )) .bind(elapsed.as_nanos() as i64) .bind(migration.version) .execute(self) @@ -241,10 +262,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn revert<'e: 'm, 'm>( + fn revert<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result> { + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { Box::pin(async move { // Use a single transaction for the actual migration script and the essential bookeeping so we never // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966. @@ -258,13 +280,13 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( // `success=FALSE` and later remove the migration altogether. // // language=MySQL - let _ = query( + let _ = query(&format!( r#" - UPDATE _sqlx_migrations + UPDATE {table_name} SET success = FALSE WHERE version = ? 
- "#, - ) + "# + )) .bind(migration.version) .execute(&mut *tx) .await?; @@ -272,7 +294,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( tx.execute(&*migration.sql).await?; // language=SQL - let _ = query(r#"DELETE FROM _sqlx_migrations WHERE version = ?"#) + let _ = query(&format!(r#"DELETE FROM {table_name} WHERE version = ?"#)) .bind(migration.version) .execute(&mut *tx) .await?; diff --git a/sqlx-mysql/src/protocol/text/column.rs b/sqlx-mysql/src/protocol/text/column.rs index 425a5cdc47..a7d95f7166 100644 --- a/sqlx-mysql/src/protocol/text/column.rs +++ b/sqlx-mysql/src/protocol/text/column.rs @@ -1,4 +1,4 @@ -use std::str::from_utf8; +use std::str; use bitflags::bitflags; use bytes::{Buf, Bytes}; @@ -104,11 +104,9 @@ pub enum ColumnType { pub(crate) struct ColumnDefinition { #[allow(unused)] catalog: Bytes, - #[allow(unused)] schema: Bytes, #[allow(unused)] table_alias: Bytes, - #[allow(unused)] table: Bytes, alias: Bytes, name: Bytes, @@ -125,12 +123,20 @@ impl ColumnDefinition { // NOTE: strings in-protocol are transmitted according to the client character set // as this is UTF-8, all these strings should be UTF-8 + pub(crate) fn schema(&self) -> Result<&str, Error> { + str::from_utf8(&self.schema).map_err(Error::protocol) + } + + pub(crate) fn table(&self) -> Result<&str, Error> { + str::from_utf8(&self.table).map_err(Error::protocol) + } + pub(crate) fn name(&self) -> Result<&str, Error> { - from_utf8(&self.name).map_err(Error::protocol) + str::from_utf8(&self.name).map_err(Error::protocol) } pub(crate) fn alias(&self) -> Result<&str, Error> { - from_utf8(&self.alias).map_err(Error::protocol) + str::from_utf8(&self.alias).map_err(Error::protocol) } } diff --git a/sqlx-mysql/src/type_checking.rs b/sqlx-mysql/src/type_checking.rs index 3f3ce5833e..0bdc84d8c9 100644 --- a/sqlx-mysql/src/type_checking.rs +++ b/sqlx-mysql/src/type_checking.rs @@ -25,41 +25,39 @@ impl_type_checking!( // BINARY, VAR_BINARY, BLOB Vec, - // Types from third-party crates need to be referenced at a known path - // for the macros to work, but we don't want to require the user to add extra dependencies. 
- #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveTime, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveDate, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveDateTime, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::DateTime, - - #[cfg(feature = "time")] - sqlx::types::time::Time, + #[cfg(feature = "json")] + sqlx::types::JsonValue, + }, + ParamChecking::Weak, + feature-types: info => info.__type_feature_gate(), + // The expansion of the macro automatically applies the correct feature name + // and checks `[macros.preferred-crates]` + datetime-types: { + chrono: { + sqlx::types::chrono::NaiveTime, - #[cfg(feature = "time")] - sqlx::types::time::Date, + sqlx::types::chrono::NaiveDate, - #[cfg(feature = "time")] - sqlx::types::time::PrimitiveDateTime, + sqlx::types::chrono::NaiveDateTime, - #[cfg(feature = "time")] - sqlx::types::time::OffsetDateTime, + sqlx::types::chrono::DateTime, + }, + time: { + sqlx::types::time::Time, - #[cfg(feature = "bigdecimal")] - sqlx::types::BigDecimal, + sqlx::types::time::Date, - #[cfg(feature = "rust_decimal")] - sqlx::types::Decimal, + sqlx::types::time::PrimitiveDateTime, - #[cfg(feature = "json")] - sqlx::types::JsonValue, + sqlx::types::time::OffsetDateTime, + }, + }, + numeric-types: { + bigdecimal: { + sqlx::types::BigDecimal, + }, + rust_decimal: { + sqlx::types::Decimal, + }, }, - ParamChecking::Weak, - feature-types: info => info.__type_feature_gate(), ); diff --git a/sqlx-postgres/src/column.rs b/sqlx-postgres/src/column.rs index a838c27b75..4dd3a1cbd2 100644 --- a/sqlx-postgres/src/column.rs +++ b/sqlx-postgres/src/column.rs @@ -1,6 +1,7 @@ use crate::ext::ustr::UStr; use crate::{PgTypeInfo, Postgres}; +use sqlx_core::column::ColumnOrigin; pub(crate) use sqlx_core::column::{Column, ColumnIndex}; #[derive(Debug, Clone)] @@ -9,6 +10,10 @@ pub struct PgColumn { pub(crate) ordinal: usize, pub(crate) name: UStr, pub(crate) type_info: PgTypeInfo, + + #[cfg_attr(feature = "offline", serde(default))] + pub(crate) origin: ColumnOrigin, + #[cfg_attr(feature = "offline", serde(skip))] pub(crate) relation_id: Option, #[cfg_attr(feature = "offline", serde(skip))] @@ -51,4 +56,8 @@ impl Column for PgColumn { fn type_info(&self) -> &PgTypeInfo { &self.type_info } + + fn origin(&self) -> ColumnOrigin { + self.origin.clone() + } } diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index a27578c56c..0334357a6c 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -1,3 +1,4 @@ +use crate::connection::TableColumns; use crate::error::Error; use crate::ext::ustr::UStr; use crate::io::StatementId; @@ -11,6 +12,7 @@ use crate::types::Oid; use crate::HashMap; use crate::{PgColumn, PgConnection, PgTypeInfo}; use smallvec::SmallVec; +use sqlx_core::column::{ColumnOrigin, TableColumn}; use sqlx_core::query_builder::QueryBuilder; use std::sync::Arc; @@ -100,7 +102,8 @@ impl PgConnection { pub(super) async fn handle_row_description( &mut self, desc: Option, - should_fetch: bool, + fetch_type_info: bool, + fetch_column_description: bool, ) -> Result<(Vec, HashMap), Error> { let mut columns = Vec::new(); let mut column_names = HashMap::new(); @@ -119,15 +122,25 @@ impl PgConnection { let name = UStr::from(field.name); let type_info = self - .maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch) + 
.maybe_fetch_type_info_by_oid(field.data_type_id, fetch_type_info) .await?; + let origin = if let (Some(relation_oid), Some(attribute_no)) = + (field.relation_id, field.relation_attribute_no) + { + self.maybe_fetch_column_origin(relation_oid, attribute_no, fetch_column_description) + .await? + } else { + ColumnOrigin::Expression + }; + let column = PgColumn { ordinal: index, name: name.clone(), type_info, relation_id: field.relation_id, relation_attribute_no: field.relation_attribute_no, + origin, }; columns.push(column); @@ -190,6 +203,69 @@ impl PgConnection { } } + async fn maybe_fetch_column_origin( + &mut self, + relation_id: Oid, + attribute_no: i16, + should_fetch: bool, + ) -> Result { + if let Some(origin) = self + .inner + .cache_table_to_column_names + .get(&relation_id) + .and_then(|table_columns| { + let column_name = table_columns.columns.get(&attribute_no).cloned()?; + + Some(ColumnOrigin::Table(TableColumn { + table: table_columns.table_name.clone(), + name: column_name, + })) + }) + { + return Ok(origin); + } + + if !should_fetch { + return Ok(ColumnOrigin::Unknown); + } + + // Looking up the table name _may_ end up being redundant, + // but the round-trip to the server is by far the most expensive part anyway. + let Some((table_name, column_name)): Option<(String, String)> = query_as( + // language=PostgreSQL + "SELECT $1::oid::regclass::text, attname \ + FROM pg_catalog.pg_attribute \ + WHERE attrelid = $1 AND attnum = $2", + ) + .bind(relation_id) + .bind(attribute_no) + .fetch_optional(&mut *self) + .await? + else { + // The column/table doesn't exist anymore for whatever reason. + return Ok(ColumnOrigin::Unknown); + }; + + let table_columns = self + .inner + .cache_table_to_column_names + .entry(relation_id) + .or_insert_with(|| TableColumns { + table_name: table_name.into(), + columns: Default::default(), + }); + + let column_name = table_columns + .columns + .entry(attribute_no) + .or_insert(column_name.into()); + + Ok(ColumnOrigin::Table(TableColumn { + table: table_columns.table_name.clone(), + name: Arc::clone(column_name), + })) + } + async fn fetch_type_by_oid(&mut self, oid: Oid) -> Result { let (name, typ_type, category, relation_id, element, base_type): ( String, diff --git a/sqlx-postgres/src/connection/establish.rs b/sqlx-postgres/src/connection/establish.rs index 1bc4172fbd..634b71de4b 100644 --- a/sqlx-postgres/src/connection/establish.rs +++ b/sqlx-postgres/src/connection/establish.rs @@ -148,6 +148,7 @@ impl PgConnection { cache_type_oid: HashMap::new(), cache_type_info: HashMap::new(), cache_elem_type_to_array: HashMap::new(), + cache_table_to_column_names: HashMap::new(), log_settings: options.log_settings.clone(), }), }) diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs index d0596aacee..93cf4ec6bc 100644 --- a/sqlx-postgres/src/connection/executor.rs +++ b/sqlx-postgres/src/connection/executor.rs @@ -26,6 +26,7 @@ async fn prepare( parameters: &[PgTypeInfo], metadata: Option>, persistent: bool, + fetch_column_origin: bool, ) -> Result<(StatementId, Arc), Error> { let id = if persistent { let id = conn.inner.next_statement_id; @@ -85,7 +86,9 @@ async fn prepare( let parameters = conn.handle_parameter_description(parameters).await?; - let (columns, column_names) = conn.handle_row_description(rows, true).await?; + let (columns, column_names) = conn + .handle_row_description(rows, true, fetch_column_origin) + .await?; // ensure that if we did fetch custom data, we wait until we are fully ready before // 
continuing @@ -173,12 +176,21 @@ impl PgConnection { // optional metadata that was provided by the user, this means they are reusing // a statement object metadata: Option>, + fetch_column_origin: bool, ) -> Result<(StatementId, Arc), Error> { if let Some(statement) = self.inner.cache_statement.get_mut(sql) { return Ok((*statement).clone()); } - let statement = prepare(self, sql, parameters, metadata, persistent).await?; + let statement = prepare( + self, + sql, + parameters, + metadata, + persistent, + fetch_column_origin, + ) + .await?; if persistent && self.inner.cache_statement.is_enabled() { if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) { @@ -226,7 +238,7 @@ impl PgConnection { // prepare the statement if this our first time executing it // always return the statement ID here let (statement, metadata_) = self - .get_or_prepare(query, &arguments.types, persistent, metadata_opt) + .get_or_prepare(query, &arguments.types, persistent, metadata_opt, false) .await?; metadata = metadata_; @@ -333,7 +345,7 @@ impl PgConnection { BackendMessageFormat::RowDescription => { // indicates that a *new* set of rows are about to be returned let (columns, column_names) = self - .handle_row_description(Some(message.decode()?), false) + .handle_row_description(Some(message.decode()?), false, false) .await?; metadata = Arc::new(PgStatementMetadata { @@ -453,7 +465,9 @@ impl<'c> Executor<'c> for &'c mut PgConnection { Box::pin(async move { self.wait_until_ready().await?; - let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?; + let (_, metadata) = self + .get_or_prepare(sql, parameters, true, None, true) + .await?; Ok(PgStatement { sql: Cow::Borrowed(sql), @@ -472,7 +486,7 @@ impl<'c> Executor<'c> for &'c mut PgConnection { Box::pin(async move { self.wait_until_ready().await?; - let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?; + let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None, true).await?; let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?; diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index ce499ed744..74398d6a8b 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -1,4 +1,5 @@ use std::borrow::Cow; +use std::collections::BTreeMap; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -64,6 +65,7 @@ pub struct PgConnectionInner { cache_type_info: HashMap, cache_type_oid: HashMap, cache_elem_type_to_array: HashMap, + cache_table_to_column_names: HashMap, // number of ReadyForQuery messages that we are currently expecting pub(crate) pending_ready_for_query_count: usize, @@ -75,6 +77,12 @@ pub struct PgConnectionInner { log_settings: LogSettings, } +pub(crate) struct TableColumns { + table_name: Arc, + /// Attribute number -> name. 
+ columns: BTreeMap>, +} + impl PgConnection { /// the version number of the server in `libpq` format pub fn server_version_num(&self) -> Option { diff --git a/sqlx-postgres/src/migrate.rs b/sqlx-postgres/src/migrate.rs index c37e92f4d6..8275bda188 100644 --- a/sqlx-postgres/src/migrate.rs +++ b/sqlx-postgres/src/migrate.rs @@ -111,12 +111,28 @@ impl MigrateDatabase for Postgres { } impl Migrate for PgConnection { - fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { + fn create_schema_if_not_exists<'e>( + &'e mut self, + schema_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async move { + // language=SQL + self.execute(&*format!(r#"CREATE SCHEMA IF NOT EXISTS {schema_name};"#)) + .await?; + + Ok(()) + }) + } + + fn ensure_migrations_table<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { // language=SQL - self.execute( + self.execute(&*format!( r#" -CREATE TABLE IF NOT EXISTS _sqlx_migrations ( +CREATE TABLE IF NOT EXISTS {table_name} ( version BIGINT PRIMARY KEY, description TEXT NOT NULL, installed_on TIMESTAMPTZ NOT NULL DEFAULT now(), @@ -124,20 +140,23 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( checksum BYTEA NOT NULL, execution_time BIGINT NOT NULL ); - "#, - ) + "# + )) .await?; Ok(()) }) } - fn dirty_version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { + fn dirty_version<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL - let row: Option<(i64,)> = query_as( - "SELECT version FROM _sqlx_migrations WHERE success = false ORDER BY version LIMIT 1", - ) + let row: Option<(i64,)> = query_as(&format!( + "SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1" + )) .fetch_optional(self) .await?; @@ -145,15 +164,17 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn list_applied_migrations( - &mut self, - ) -> BoxFuture<'_, Result, MigrateError>> { + fn list_applied_migrations<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL - let rows: Vec<(i64, Vec)> = - query_as("SELECT version, checksum FROM _sqlx_migrations ORDER BY version") - .fetch_all(self) - .await?; + let rows: Vec<(i64, Vec)> = query_as(&format!( + "SELECT version, checksum FROM {table_name} ORDER BY version" + )) + .fetch_all(self) + .await?; let migrations = rows .into_iter() @@ -203,16 +224,17 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn apply<'e: 'm, 'm>( + fn apply<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result> { + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { Box::pin(async move { let start = Instant::now(); // execute migration queries if migration.no_tx { - execute_migration(self, migration).await?; + execute_migration(self, table_name, migration).await?; } else { // Use a single transaction for the actual migration script and the essential bookeeping so we never // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966. @@ -220,7 +242,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( // data lineage and debugging reasons, so it is not super important if it is lost. So we initialize it to -1 // and update it once the actual transaction completed. 
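// (The guarantee described above extends to a renamed migrations table: the
// `INSERT` performed by `execute_migration` below and the migration script
// itself commit or roll back together against `{table_name}`, and only the
// follow-up `execution_time` UPDATE runs outside the transaction, which is
// acceptable because that value exists for data lineage and debugging only.)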
let mut tx = self.begin().await?; - execute_migration(&mut tx, migration).await?; + execute_migration(&mut tx, table_name, migration).await?; tx.commit().await?; } @@ -231,13 +253,13 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( // language=SQL #[allow(clippy::cast_possible_truncation)] - let _ = query( + let _ = query(&format!( r#" - UPDATE _sqlx_migrations + UPDATE {table_name} SET execution_time = $1 WHERE version = $2 - "#, - ) + "# + )) .bind(elapsed.as_nanos() as i64) .bind(migration.version) .execute(self) @@ -247,21 +269,22 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn revert<'e: 'm, 'm>( + fn revert<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result> { + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { Box::pin(async move { let start = Instant::now(); // execute migration queries if migration.no_tx { - revert_migration(self, migration).await?; + revert_migration(self, table_name, migration).await?; } else { // Use a single transaction for the actual migration script and the essential bookeeping so we never // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966. let mut tx = self.begin().await?; - revert_migration(&mut tx, migration).await?; + revert_migration(&mut tx, table_name, migration).await?; tx.commit().await?; } @@ -274,6 +297,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( async fn execute_migration( conn: &mut PgConnection, + table_name: &str, migration: &Migration, ) -> Result<(), MigrateError> { let _ = conn @@ -282,12 +306,12 @@ async fn execute_migration( .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; // language=SQL - let _ = query( + let _ = query(&format!( r#" - INSERT INTO _sqlx_migrations ( version, description, success, checksum, execution_time ) + INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) VALUES ( $1, $2, TRUE, $3, -1 ) - "#, - ) + "# + )) .bind(migration.version) .bind(&*migration.description) .bind(&*migration.checksum) @@ -299,6 +323,7 @@ async fn execute_migration( async fn revert_migration( conn: &mut PgConnection, + table_name: &str, migration: &Migration, ) -> Result<(), MigrateError> { let _ = conn @@ -307,7 +332,7 @@ async fn revert_migration( .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; // language=SQL - let _ = query(r#"DELETE FROM _sqlx_migrations WHERE version = $1"#) + let _ = query(&format!(r#"DELETE FROM {table_name} WHERE version = $1"#)) .bind(migration.version) .execute(conn) .await?; diff --git a/sqlx-postgres/src/type_checking.rs b/sqlx-postgres/src/type_checking.rs index 672d9f73e6..8f63cf97fa 100644 --- a/sqlx-postgres/src/type_checking.rs +++ b/sqlx-postgres/src/type_checking.rs @@ -49,42 +49,6 @@ impl_type_checking!( #[cfg(feature = "uuid")] sqlx::types::Uuid, - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveTime, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveDate, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveDateTime, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::DateTime | sqlx::types::chrono::DateTime<_>, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::postgres::types::PgTimeTz, - - #[cfg(feature = "time")] - sqlx::types::time::Time, - - #[cfg(feature = "time")] - sqlx::types::time::Date, - - #[cfg(feature = "time")] - sqlx::types::time::PrimitiveDateTime, - - #[cfg(feature = 
"time")] - sqlx::types::time::OffsetDateTime, - - #[cfg(feature = "time")] - sqlx::postgres::types::PgTimeTz, - - #[cfg(feature = "bigdecimal")] - sqlx::types::BigDecimal, - - #[cfg(feature = "rust_decimal")] - sqlx::types::Decimal, - #[cfg(feature = "ipnetwork")] sqlx::types::ipnetwork::IpNetwork, @@ -119,36 +83,6 @@ impl_type_checking!( #[cfg(feature = "uuid")] Vec | &[sqlx::types::Uuid], - #[cfg(all(feature = "chrono", not(feature = "time")))] - Vec | &[sqlx::types::chrono::NaiveTime], - - #[cfg(all(feature = "chrono", not(feature = "time")))] - Vec | &[sqlx::types::chrono::NaiveDate], - - #[cfg(all(feature = "chrono", not(feature = "time")))] - Vec | &[sqlx::types::chrono::NaiveDateTime], - - #[cfg(all(feature = "chrono", not(feature = "time")))] - Vec> | &[sqlx::types::chrono::DateTime<_>], - - #[cfg(feature = "time")] - Vec | &[sqlx::types::time::Time], - - #[cfg(feature = "time")] - Vec | &[sqlx::types::time::Date], - - #[cfg(feature = "time")] - Vec | &[sqlx::types::time::PrimitiveDateTime], - - #[cfg(feature = "time")] - Vec | &[sqlx::types::time::OffsetDateTime], - - #[cfg(feature = "bigdecimal")] - Vec | &[sqlx::types::BigDecimal], - - #[cfg(feature = "rust_decimal")] - Vec | &[sqlx::types::Decimal], - #[cfg(feature = "ipnetwork")] Vec | &[sqlx::types::ipnetwork::IpNetwork], @@ -168,72 +102,114 @@ impl_type_checking!( sqlx::postgres::types::PgRange, sqlx::postgres::types::PgRange, - #[cfg(feature = "bigdecimal")] - sqlx::postgres::types::PgRange, + // Range arrays - #[cfg(feature = "rust_decimal")] - sqlx::postgres::types::PgRange, + Vec> | &[sqlx::postgres::types::PgRange], + Vec> | &[sqlx::postgres::types::PgRange], + }, + ParamChecking::Strong, + feature-types: info => info.__type_feature_gate(), + // The expansion of the macro automatically applies the correct feature name + // and checks `[macros.preferred-crates]` + datetime-types: { + chrono: { + // Scalar types + sqlx::types::chrono::NaiveTime, - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::postgres::types::PgRange, + sqlx::types::chrono::NaiveDate, - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::postgres::types::PgRange, + sqlx::types::chrono::NaiveDateTime, - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::postgres::types::PgRange> | - sqlx::postgres::types::PgRange>, + sqlx::types::chrono::DateTime | sqlx::types::chrono::DateTime<_>, - #[cfg(feature = "time")] - sqlx::postgres::types::PgRange, + sqlx::postgres::types::PgTimeTz, - #[cfg(feature = "time")] - sqlx::postgres::types::PgRange, + // Array types + Vec | &[sqlx::types::chrono::NaiveTime], - #[cfg(feature = "time")] - sqlx::postgres::types::PgRange, + Vec | &[sqlx::types::chrono::NaiveDate], - // Range arrays + Vec | &[sqlx::types::chrono::NaiveDateTime], - Vec> | &[sqlx::postgres::types::PgRange], - Vec> | &[sqlx::postgres::types::PgRange], + Vec> | &[sqlx::types::chrono::DateTime<_>], + + // Range types + sqlx::postgres::types::PgRange, + + sqlx::postgres::types::PgRange, + + sqlx::postgres::types::PgRange> | + sqlx::postgres::types::PgRange>, + + // Arrays of ranges + Vec> | + &[sqlx::postgres::types::PgRange], + + Vec> | + &[sqlx::postgres::types::PgRange], + + Vec>> | + &[sqlx::postgres::types::PgRange>], + }, + time: { + // Scalar types + sqlx::types::time::Time, + + sqlx::types::time::Date, + + sqlx::types::time::PrimitiveDateTime, - #[cfg(feature = "bigdecimal")] - Vec> | - &[sqlx::postgres::types::PgRange], + sqlx::types::time::OffsetDateTime, - #[cfg(feature = "rust_decimal")] - Vec> | - 
&[sqlx::postgres::types::PgRange], + sqlx::postgres::types::PgTimeTz, - #[cfg(all(feature = "chrono", not(feature = "time")))] - Vec> | - &[sqlx::postgres::types::PgRange], + // Array types + Vec | &[sqlx::types::time::Time], - #[cfg(all(feature = "chrono", not(feature = "time")))] - Vec> | - &[sqlx::postgres::types::PgRange], + Vec | &[sqlx::types::time::Date], - #[cfg(all(feature = "chrono", not(feature = "time")))] - Vec>> | - &[sqlx::postgres::types::PgRange>], + Vec | &[sqlx::types::time::PrimitiveDateTime], - #[cfg(all(feature = "chrono", not(feature = "time")))] - Vec>> | - &[sqlx::postgres::types::PgRange>], + Vec | &[sqlx::types::time::OffsetDateTime], - #[cfg(feature = "time")] - Vec> | - &[sqlx::postgres::types::PgRange], + // Range types + sqlx::postgres::types::PgRange, - #[cfg(feature = "time")] - Vec> | - &[sqlx::postgres::types::PgRange], + sqlx::postgres::types::PgRange, - #[cfg(feature = "time")] - Vec> | - &[sqlx::postgres::types::PgRange], + sqlx::postgres::types::PgRange, + + // Arrays of ranges + Vec> | + &[sqlx::postgres::types::PgRange], + + Vec> | + &[sqlx::postgres::types::PgRange], + + Vec> | + &[sqlx::postgres::types::PgRange], + }, + }, + numeric-types: { + bigdecimal: { + sqlx::types::BigDecimal, + + Vec | &[sqlx::types::BigDecimal], + + sqlx::postgres::types::PgRange, + + Vec> | + &[sqlx::postgres::types::PgRange], + }, + rust_decimal: { + sqlx::types::Decimal, + + Vec | &[sqlx::types::Decimal], + + sqlx::postgres::types::PgRange, + + Vec> | + &[sqlx::postgres::types::PgRange], + }, }, - ParamChecking::Strong, - feature-types: info => info.__type_feature_gate(), ); diff --git a/sqlx-sqlite/Cargo.toml b/sqlx-sqlite/Cargo.toml index 151283deda..db7fb63cb8 100644 --- a/sqlx-sqlite/Cargo.toml +++ b/sqlx-sqlite/Cargo.toml @@ -27,6 +27,10 @@ preupdate-hook = ["libsqlite3-sys/preupdate_hook"] bundled = ["libsqlite3-sys/bundled"] unbundled = ["libsqlite3-sys/buildtime_bindgen"] +# Note: currently unused, only to satisfy "unexpected `cfg` condition" lint +bigdecimal = [] +rust_decimal = [] + [dependencies] futures-core = { version = "0.3.19", default-features = false } futures-channel = { version = "0.3.19", default-features = false, features = ["sink", "alloc", "std"] } @@ -73,4 +77,4 @@ sqlx = { workspace = true, default-features = false, features = ["macros", "runt workspace = true [package.metadata.docs.rs] -features = ["bundled", "any", "json", "chrono", "time", "uuid"] \ No newline at end of file +features = ["bundled", "any", "json", "chrono", "time", "uuid"] diff --git a/sqlx-sqlite/src/column.rs b/sqlx-sqlite/src/column.rs index 00b3bc360c..d319bd46a8 100644 --- a/sqlx-sqlite/src/column.rs +++ b/sqlx-sqlite/src/column.rs @@ -9,6 +9,9 @@ pub struct SqliteColumn { pub(crate) name: UStr, pub(crate) ordinal: usize, pub(crate) type_info: SqliteTypeInfo, + + #[cfg_attr(feature = "offline", serde(default))] + pub(crate) origin: ColumnOrigin, } impl Column for SqliteColumn { @@ -25,4 +28,8 @@ impl Column for SqliteColumn { fn type_info(&self) -> &SqliteTypeInfo { &self.type_info } + + fn origin(&self) -> ColumnOrigin { + self.origin.clone() + } } diff --git a/sqlx-sqlite/src/connection/describe.rs b/sqlx-sqlite/src/connection/describe.rs index 0f4da33ccc..6db81374aa 100644 --- a/sqlx-sqlite/src/connection/describe.rs +++ b/sqlx-sqlite/src/connection/describe.rs @@ -50,6 +50,8 @@ pub(crate) fn describe(conn: &mut ConnectionState, query: &str) -> Result Result BoxFuture<'_, Result<(), Error>> { @@ -64,12 +65,36 @@ impl MigrateDatabase for Sqlite { } impl Migrate for 
SqliteConnection { - fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { + fn create_schema_if_not_exists<'e>( + &'e mut self, + schema_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async move { + // Check if the schema already exists; if so, don't error. + let schema_version: Option = + query_scalar(&format!("PRAGMA {schema_name}.schema_version")) + .fetch_optional(&mut *self) + .await?; + + if schema_version.is_some() { + return Ok(()); + } + + Err(MigrateError::CreateSchemasNotSupported( + format!("cannot create new schema {schema_name}; creation of additional schemas in SQLite requires attaching extra database files"), + )) + }) + } + + fn ensure_migrations_table<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { // language=SQLite - self.execute( + self.execute(&*format!( r#" -CREATE TABLE IF NOT EXISTS _sqlx_migrations ( +CREATE TABLE IF NOT EXISTS {table_name} ( version BIGINT PRIMARY KEY, description TEXT NOT NULL, installed_on TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, @@ -77,20 +102,23 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( checksum BLOB NOT NULL, execution_time BIGINT NOT NULL ); - "#, - ) + "# + )) .await?; Ok(()) }) } - fn dirty_version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { + fn dirty_version<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQLite - let row: Option<(i64,)> = query_as( - "SELECT version FROM _sqlx_migrations WHERE success = false ORDER BY version LIMIT 1", - ) + let row: Option<(i64,)> = query_as(&format!( + "SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1" + )) .fetch_optional(self) .await?; @@ -98,15 +126,17 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn list_applied_migrations( - &mut self, - ) -> BoxFuture<'_, Result, MigrateError>> { + fn list_applied_migrations<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQLite - let rows: Vec<(i64, Vec)> = - query_as("SELECT version, checksum FROM _sqlx_migrations ORDER BY version") - .fetch_all(self) - .await?; + let rows: Vec<(i64, Vec)> = query_as(&format!( + "SELECT version, checksum FROM {table_name} ORDER BY version" + )) + .fetch_all(self) + .await?; let migrations = rows .into_iter() @@ -128,10 +158,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( Box::pin(async move { Ok(()) }) } - fn apply<'e: 'm, 'm>( + fn apply<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result> { + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { Box::pin(async move { let mut tx = self.begin().await?; let start = Instant::now(); @@ -147,12 +178,12 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; // language=SQL - let _ = query( + let _ = query(&format!( r#" - INSERT INTO _sqlx_migrations ( version, description, success, checksum, execution_time ) + INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) VALUES ( ?1, ?2, TRUE, ?3, -1 ) - "#, - ) + "# + )) .bind(migration.version) .bind(&*migration.description) .bind(&*migration.checksum) @@ -169,13 +200,13 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( // language=SQL #[allow(clippy::cast_possible_truncation)] - let _ = query( + let _ = query(&format!( r#" - UPDATE _sqlx_migrations + UPDATE 
diff --git a/sqlx-sqlite/src/connection/describe.rs b/sqlx-sqlite/src/connection/describe.rs
index 0f4da33ccc..6db81374aa 100644
--- a/sqlx-sqlite/src/connection/describe.rs
+++ b/sqlx-sqlite/src/connection/describe.rs
@@ -50,6 +50,8 @@ pub(crate) fn describe(conn: &mut ConnectionState, query: &str) -> Result<Describe<Sqlite>, Error>
diff --git a/sqlx-sqlite/src/migrate.rs b/sqlx-sqlite/src/migrate.rs
--- a/sqlx-sqlite/src/migrate.rs
+++ b/sqlx-sqlite/src/migrate.rs
@@ ... @@ ... -> BoxFuture<'_, Result<(), Error>> {
@@ -64,12 +65,36 @@ impl MigrateDatabase for Sqlite {
 }

 impl Migrate for SqliteConnection {
-    fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> {
+    fn create_schema_if_not_exists<'e>(
+        &'e mut self,
+        schema_name: &'e str,
+    ) -> BoxFuture<'e, Result<(), MigrateError>> {
+        Box::pin(async move {
+            // Check if the schema already exists; if so, don't error.
+            let schema_version: Option<i64> =
+                query_scalar(&format!("PRAGMA {schema_name}.schema_version"))
+                    .fetch_optional(&mut *self)
+                    .await?;
+
+            if schema_version.is_some() {
+                return Ok(());
+            }
+
+            Err(MigrateError::CreateSchemasNotSupported(
+                format!("cannot create new schema {schema_name}; creation of additional schemas in SQLite requires attaching extra database files"),
+            ))
+        })
+    }
+
+    fn ensure_migrations_table<'e>(
+        &'e mut self,
+        table_name: &'e str,
+    ) -> BoxFuture<'e, Result<(), MigrateError>> {
         Box::pin(async move {
             // language=SQLite
-            self.execute(
+            self.execute(&*format!(
                 r#"
-CREATE TABLE IF NOT EXISTS _sqlx_migrations (
+CREATE TABLE IF NOT EXISTS {table_name} (
     version BIGINT PRIMARY KEY,
     description TEXT NOT NULL,
     installed_on TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
@@ -77,20 +102,23 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations (
     checksum BLOB NOT NULL,
     execution_time BIGINT NOT NULL
 );
-            "#,
-            )
+            "#
+            ))
             .await?;

             Ok(())
         })
     }

-    fn dirty_version(&mut self) -> BoxFuture<'_, Result<Option<i64>, MigrateError>> {
+    fn dirty_version<'e>(
+        &'e mut self,
+        table_name: &'e str,
+    ) -> BoxFuture<'e, Result<Option<i64>, MigrateError>> {
         Box::pin(async move {
             // language=SQLite
-            let row: Option<(i64,)> = query_as(
-                "SELECT version FROM _sqlx_migrations WHERE success = false ORDER BY version LIMIT 1",
-            )
+            let row: Option<(i64,)> = query_as(&format!(
+                "SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1"
+            ))
             .fetch_optional(self)
             .await?;

@@ -98,15 +126,17 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations (
         })
     }

-    fn list_applied_migrations(
-        &mut self,
-    ) -> BoxFuture<'_, Result<Vec<AppliedMigration>, MigrateError>> {
+    fn list_applied_migrations<'e>(
+        &'e mut self,
+        table_name: &'e str,
+    ) -> BoxFuture<'e, Result<Vec<AppliedMigration>, MigrateError>> {
         Box::pin(async move {
             // language=SQLite
-            let rows: Vec<(i64, Vec<u8>)> =
-                query_as("SELECT version, checksum FROM _sqlx_migrations ORDER BY version")
-                    .fetch_all(self)
-                    .await?;
+            let rows: Vec<(i64, Vec<u8>)> = query_as(&format!(
+                "SELECT version, checksum FROM {table_name} ORDER BY version"
+            ))
+            .fetch_all(self)
+            .await?;

             let migrations = rows
                 .into_iter()
@@ -128,10 +158,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations (
         Box::pin(async move { Ok(()) })
     }

-    fn apply<'e: 'm, 'm>(
+    fn apply<'e>(
         &'e mut self,
-        migration: &'m Migration,
-    ) -> BoxFuture<'m, Result<Duration, MigrateError>> {
+        table_name: &'e str,
+        migration: &'e Migration,
+    ) -> BoxFuture<'e, Result<Duration, MigrateError>> {
         Box::pin(async move {
             let mut tx = self.begin().await?;
             let start = Instant::now();
@@ -147,12 +178,12 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations (
                 .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?;

             // language=SQL
-            let _ = query(
+            let _ = query(&format!(
                 r#"
-    INSERT INTO _sqlx_migrations ( version, description, success, checksum, execution_time )
+    INSERT INTO {table_name} ( version, description, success, checksum, execution_time )
     VALUES ( ?1, ?2, TRUE, ?3, -1 )
-                "#,
-            )
+                "#
+            ))
             .bind(migration.version)
             .bind(&*migration.description)
             .bind(&*migration.checksum)
@@ -169,13 +200,13 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations (

             // language=SQL
             #[allow(clippy::cast_possible_truncation)]
-            let _ = query(
+            let _ = query(&format!(
                 r#"
-    UPDATE _sqlx_migrations
+    UPDATE {table_name}
     SET execution_time = ?1
     WHERE version = ?2
-                "#,
-            )
+                "#
+            ))
             .bind(elapsed.as_nanos() as i64)
             .bind(migration.version)
             .execute(self)
@@ -185,10 +216,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations (
         })
     }

-    fn revert<'e: 'm, 'm>(
+    fn revert<'e>(
         &'e mut self,
-        migration: &'m Migration,
-    ) -> BoxFuture<'m, Result<Duration, MigrateError>> {
+        table_name: &'e str,
+        migration: &'e Migration,
+    ) -> BoxFuture<'e, Result<Duration, MigrateError>> {
         Box::pin(async move {
             // Use a single transaction for the actual migration script and the essential bookkeeping so we never
             // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966.
@@ -197,8 +229,8 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations (

             let _ = tx.execute(&*migration.sql).await?;

-            // language=SQL
-            let _ = query(r#"DELETE FROM _sqlx_migrations WHERE version = ?1"#)
+            // language=SQLite
+            let _ = query(&format!(r#"DELETE FROM {table_name} WHERE version = ?1"#))
                 .bind(migration.version)
                 .execute(&mut *tx)
                 .await?;
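Threading `table_name` through every `Migrate` method is what makes the migrations table renameable per crate. A sketch of the corresponding configuration (the `table-name` key is an assumption based on the CHANGELOG's "rename or relocate the `_sqlx_migrations` table"; check the reference `sqlx.toml` for the authoritative spelling):

```toml
# sqlx.toml -- a sketch. Give this crate its own migration-history table so
# several crates can migrate the same database without clobbering each other.
[migrate]
table-name = "payments_sqlx_migrations"
```

Note the SQLite-specific wrinkle above: a schema-qualified table name only works if the schema already exists, since `create_schema_if_not_exists` refuses to create one without an attached database file.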
expect_ret_valid { ($fn_name:ident($($args:tt)*)) => {{ let val = $fn_name($($args)*); @@ -110,6 +113,65 @@ impl StatementHandle { } } + pub(crate) fn column_origin(&self, index: usize) -> ColumnOrigin { + if let Some((table, name)) = self + .column_table_name(index) + .zip(self.column_origin_name(index)) + { + let table: Arc = self + .column_db_name(index) + .filter(|&db| db != "main") + .map_or_else( + || table.into(), + // TODO: check that SQLite returns the names properly quoted if necessary + |db| format!("{db}.{table}").into(), + ); + + ColumnOrigin::Table(TableColumn { + table, + name: name.into(), + }) + } else { + ColumnOrigin::Expression + } + } + + fn column_db_name(&self, index: usize) -> Option<&str> { + unsafe { + let db_name = sqlite3_column_database_name(self.0.as_ptr(), check_col_idx!(index)); + + if !db_name.is_null() { + Some(from_utf8_unchecked(CStr::from_ptr(db_name).to_bytes())) + } else { + None + } + } + } + + fn column_table_name(&self, index: usize) -> Option<&str> { + unsafe { + let table_name = sqlite3_column_table_name(self.0.as_ptr(), check_col_idx!(index)); + + if !table_name.is_null() { + Some(from_utf8_unchecked(CStr::from_ptr(table_name).to_bytes())) + } else { + None + } + } + } + + fn column_origin_name(&self, index: usize) -> Option<&str> { + unsafe { + let origin_name = sqlite3_column_origin_name(self.0.as_ptr(), check_col_idx!(index)); + + if !origin_name.is_null() { + Some(from_utf8_unchecked(CStr::from_ptr(origin_name).to_bytes())) + } else { + None + } + } + } + pub(crate) fn column_type_info(&self, index: usize) -> SqliteTypeInfo { SqliteTypeInfo(DataType::from_code(self.column_type(index))) } diff --git a/sqlx-sqlite/src/statement/virtual.rs b/sqlx-sqlite/src/statement/virtual.rs index 2817146bc3..b25aa69e47 100644 --- a/sqlx-sqlite/src/statement/virtual.rs +++ b/sqlx-sqlite/src/statement/virtual.rs @@ -104,6 +104,7 @@ impl VirtualStatement { ordinal: i, name: name.clone(), type_info, + origin: statement.column_origin(i), }); column_names.insert(name, i); diff --git a/sqlx-sqlite/src/type_checking.rs b/sqlx-sqlite/src/type_checking.rs index e1ac3bc753..97af601c86 100644 --- a/sqlx-sqlite/src/type_checking.rs +++ b/sqlx-sqlite/src/type_checking.rs @@ -1,8 +1,7 @@ +use crate::Sqlite; #[allow(unused_imports)] use sqlx_core as sqlx; -use crate::Sqlite; - // f32 is not included below as REAL represents a floating point value // stored as an 8-byte IEEE floating point number (i.e. an f64) // For more info see: https://www.sqlite.org/datatype3.html#storage_classes_and_datatypes @@ -20,24 +19,6 @@ impl_type_checking!( String, Vec, - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveDate, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveDateTime, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::DateTime | sqlx::types::chrono::DateTime<_>, - - #[cfg(feature = "time")] - sqlx::types::time::OffsetDateTime, - - #[cfg(feature = "time")] - sqlx::types::time::PrimitiveDateTime, - - #[cfg(feature = "time")] - sqlx::types::time::Date, - #[cfg(feature = "uuid")] sqlx::types::Uuid, }, @@ -48,4 +29,28 @@ impl_type_checking!( // The type integrations simply allow the user to skip some intermediate representation, // which is usually TEXT. 
diff --git a/sqlx-sqlite/src/type_checking.rs b/sqlx-sqlite/src/type_checking.rs
index e1ac3bc753..97af601c86 100644
--- a/sqlx-sqlite/src/type_checking.rs
+++ b/sqlx-sqlite/src/type_checking.rs
@@ -1,8 +1,7 @@
+use crate::Sqlite;
 #[allow(unused_imports)]
 use sqlx_core as sqlx;

-use crate::Sqlite;
-
 // f32 is not included below as REAL represents a floating point value
 // stored as an 8-byte IEEE floating point number (i.e. an f64)
 // For more info see: https://www.sqlite.org/datatype3.html#storage_classes_and_datatypes
@@ -20,24 +19,6 @@ impl_type_checking!(
     String,
     Vec<u8>,

-    #[cfg(all(feature = "chrono", not(feature = "time")))]
-    sqlx::types::chrono::NaiveDate,
-
-    #[cfg(all(feature = "chrono", not(feature = "time")))]
-    sqlx::types::chrono::NaiveDateTime,
-
-    #[cfg(all(feature = "chrono", not(feature = "time")))]
-    sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc> | sqlx::types::chrono::DateTime<_>,
-
-    #[cfg(feature = "time")]
-    sqlx::types::time::OffsetDateTime,
-
-    #[cfg(feature = "time")]
-    sqlx::types::time::PrimitiveDateTime,
-
-    #[cfg(feature = "time")]
-    sqlx::types::time::Date,
-
     #[cfg(feature = "uuid")]
     sqlx::types::Uuid,
 },
@@ -48,4 +29,28 @@ impl_type_checking!(
     // The type integrations simply allow the user to skip some intermediate representation,
     // which is usually TEXT.
     feature-types: _info => None,
+
+    // The expansion of the macro automatically applies the correct feature name
+    // and checks `[macros.preferred-crates]`
+    datetime-types: {
+        chrono: {
+            sqlx::types::chrono::NaiveDate,
+
+            sqlx::types::chrono::NaiveDateTime,
+
+            sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc>
+                | sqlx::types::chrono::DateTime<_>,
+        },
+        time: {
+            sqlx::types::time::OffsetDateTime,
+
+            sqlx::types::time::PrimitiveDateTime,
+
+            sqlx::types::time::Date,
+        },
+    },
+    numeric-types: {
+        bigdecimal: { },
+        rust_decimal: { },
+    },
 );
diff --git a/src/lib.rs b/src/lib.rs
index e55dc26e36..c608e02aea 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -13,6 +13,7 @@ pub use sqlx_core::acquire::Acquire;
 pub use sqlx_core::arguments::{Arguments, IntoArguments};
 pub use sqlx_core::column::Column;
 pub use sqlx_core::column::ColumnIndex;
+pub use sqlx_core::column::ColumnOrigin;
 pub use sqlx_core::connection::{ConnectOptions, Connection};
 pub use sqlx_core::database::{self, Database};
 pub use sqlx_core::describe::Describe;
@@ -172,3 +173,37 @@ pub mod prelude {
     pub use super::Statement;
     pub use super::Type;
 }
+
+#[cfg(feature = "_unstable-doc")]
+#[cfg_attr(docsrs, doc(cfg(feature = "_unstable-doc")))]
+pub use sqlx_core::config as _config;
+
+// NOTE: APIs exported in this module are SemVer-exempt.
+#[doc(hidden)]
+pub mod _unstable {
+    pub use sqlx_core::config;
+}
+
+#[doc(hidden)]
+#[cfg_attr(
+    all(feature = "chrono", feature = "time"),
+    deprecated = "SQLx has both `chrono` and `time` features enabled, \
+        which presents an ambiguity when the `query!()` macros are mapping date/time types. \
+        The `query!()` macros prefer types from `time` by default, \
+        but this behavior should not be relied upon; \
+        to resolve the ambiguity, we recommend specifying the preferred crate in a `sqlx.toml` file: \
+        https://docs.rs/sqlx/latest/sqlx/config/macros/PreferredCrates.html#field.date_time"
+)]
+pub fn warn_on_ambiguous_inferred_date_time_crate() {}
+
+#[doc(hidden)]
+#[cfg_attr(
+    all(feature = "bigdecimal", feature = "rust_decimal"),
+    deprecated = "SQLx has both `bigdecimal` and `rust_decimal` features enabled, \
+        which presents an ambiguity when the `query!()` macros are mapping `NUMERIC`. \
+        The `query!()` macros prefer `bigdecimal::BigDecimal` by default, \
+        but this behavior should not be relied upon; \
+        to resolve the ambiguity, we recommend specifying the preferred crate in a `sqlx.toml` file: \
+        https://docs.rs/sqlx/latest/sqlx/config/macros/PreferredCrates.html#field.numeric"
+)]
+pub fn warn_on_ambiguous_inferred_numeric_crate() {}
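The `warn_on_ambiguous_*` functions rely on a small trick: the `query!()` macros expand to a call to them, and `cfg_attr` attaches `#[deprecated]` only when both competing features are enabled, so the warning surfaces at the generated call site. A minimal, self-contained sketch of the pattern (the names here are illustrative, not sqlx's):

```rust
// A deprecation attribute applied conditionally via `cfg_attr` turns an
// otherwise-silent function call into a compile-time warning that fires
// only when both features are enabled at once.
#[cfg_attr(
    all(feature = "chrono", feature = "time"),
    deprecated = "ambiguous date/time mapping; pick a preferred crate in sqlx.toml"
)]
pub fn warn_on_ambiguous_date_time() {}

pub fn generated_by_query_macro() {
    // Macro output simply calls the function. When both features are on,
    // rustc reports the deprecation at this call site; otherwise the call
    // compiles silently and optimizes away to nothing.
    warn_on_ambiguous_date_time();
}
```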
diff --git a/src/macros/mod.rs b/src/macros/mod.rs
index 9e81935876..0db6f0c2e7 100644
--- a/src/macros/mod.rs
+++ b/src/macros/mod.rs
@@ -74,6 +74,25 @@
 ///
 /// [dotenv]: https://crates.io/crates/dotenv
 /// [dotenvy]: https://crates.io/crates/dotenvy
+///
+/// ## Configuration with `sqlx.toml`
+/// Multiple crate-wide configuration options are now available, including:
+///
+/// * renaming the `DATABASE_URL` variable, for using multiple databases in the same workspace
+///   * In the initial implementation, a separate crate must be created for each database.
+///     Using multiple databases in the same crate may become possible in the future.
+/// * global type overrides (useful for custom types!)
+/// * per-column type overrides
+/// * forcing use of a specific crate (e.g. `chrono` when both it and `time` are enabled)
+///
+/// See the [configuration guide] and [reference `sqlx.toml`] for details.
+///
+/// See also `examples/postgres/multi-database` and `examples/postgres/preferred-crates`
+/// for example usage.
+///
+/// [configuration guide]: crate::_config::macros::Config
+/// [reference `sqlx.toml`]: crate::_config::_reference
+///
 /// ## Query Arguments
 /// Like `println!()` and the other formatting macros, you can add bind parameters to your SQL
 /// and this macro will typecheck passed arguments and error on missing ones:
@@ -728,6 +747,7 @@ macro_rules! query_file_scalar_unchecked (
 /// Embeds migrations into the binary by expanding to a static instance of [Migrator][crate::migrate::Migrator].
 ///
 /// ```rust,ignore
+/// // Consider instead setting
 /// sqlx::migrate!("db/migrations")
 ///     .run(&pool)
 ///     .await?;
@@ -745,6 +765,38 @@ macro_rules! query_file_scalar_unchecked (
 ///
 /// See [MigrationSource][crate::migrate::MigrationSource] for details on structure of the ./migrations directory.
 ///
+/// ## Note: Platform-specific Line Endings
+/// Different platforms use different bytes for line endings by default:
+/// * Linux and macOS use Line Feeds (LF: `\n`)
+/// * Windows uses Carriage Returns _and_ Line Feeds (CRLF: `\r\n`)
+///
+/// This may result in non-reproducible hashes across platforms unless it is taken into account.
+///
+/// One solution is to use a [`.gitattributes` file](https://git-scm.com/docs/gitattributes)
+/// and force `.sql` files to be checked out with Line Feeds:
+///
+/// ```gitattributes
+/// *.sql text eol=lf
+/// ```
+///
+/// Another option is to configure migrations to ignore whitespace.
+/// See the next section for details.
+///
+/// ## Configuration with `sqlx.toml`
+/// Multiple crate-wide configuration options are now available, including:
+///
+/// * creating schemas on database setup
+/// * renaming the `_sqlx_migrations` table or placing it into a new schema
+/// * relocating the migrations directory
+/// * ignoring certain characters when hashing migrations (such as whitespace and newlines)
+///
+/// See the [configuration guide] and [reference `sqlx.toml`] for details.
+///
+/// `sqlx-cli` can also read these options and use them when setting up or migrating databases.
+///
+/// [configuration guide]: crate::_config::migrate::Config
+/// [reference `sqlx.toml`]: crate::_config::_reference
+///
 /// ## Triggering Recompilation on Migration Changes
 /// In some cases when making changes to embedded migrations, such as adding a new migration without
 /// changing any Rust source files, you might find that `cargo build` doesn't actually do anything,
@@ -814,6 +866,6 @@ macro_rules! migrate {
     }};

     () => {{
-        $crate::sqlx_macros::migrate!("./migrations")
+        $crate::sqlx_macros::migrate!()
     }};
 }
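The whitespace-hashing option mentioned in both sections might look like this in `sqlx.toml` (the `ignored-chars` key is an assumption based on the CHANGELOG's "set characters to ignore when hashing migrations"; the reference `sqlx.toml` is authoritative):

```toml
# sqlx.toml -- a sketch. Ignore whitespace when hashing migration files so that
# LF and CRLF checkouts of the same .sql file produce identical checksums.
[migrate]
ignored-chars = [" ", "\t", "\r", "\n"]
```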
diff --git a/src/macros/test.md b/src/macros/test.md
index 30de8070f6..ec3cee90b0 100644
--- a/src/macros/test.md
+++ b/src/macros/test.md
@@ -1,6 +1,6 @@
 Mark an `async fn` as a test with SQLx support.

-The test will automatically be executed in the async runtime according to the chosen 
+The test will automatically be executed in the async runtime according to the chosen
 `runtime-{async-std, tokio}` feature. If more than one runtime feature is enabled, `runtime-tokio` is
 preferred. By default, this behaves identically to `#[tokio::test]`<sup>1</sup> or `#[async_std::test]`:
@@ -31,25 +31,24 @@ but are isolated from each other.
 This feature is activated by changing the signature of your test function. The following signatures are supported:

 * `async fn(Pool<DB>) -> Ret`
-    * the `Pool`s used by all running tests share a single connection limit to avoid exceeding the server's limit. 
+    * the `Pool`s used by all running tests share a single connection limit to avoid exceeding the server's limit.
 * `async fn(PoolConnection<DB>) -> Ret`
-    * `PoolConnection<Postgres>`, etc. 
+    * `PoolConnection<Postgres>`, etc.
 * `async fn(PoolOptions<DB>, impl ConnectOptions<DB>) -> Ret`
     * Where `impl ConnectOptions` is, e.g., `PgConnectOptions`, `MySqlConnectOptions`, etc.
-    * If your test wants to create its own `Pool` (for example, to set pool callbacks or to modify `ConnectOptions`), 
+    * If your test wants to create its own `Pool` (for example, to set pool callbacks or to modify `ConnectOptions`),
       you can use this signature.

 Where `DB` is a supported `Database` type and `Ret` is `()` or `Result<_, _>`.

 ##### Supported Databases

-Most of these will require you to set `DATABASE_URL` as an environment variable 
+Most of these will require you to set `DATABASE_URL` as an environment variable
 or in a `.env` file like `sqlx::query!()` _et al_, to give the test driver a superuser connection with which
 to manage test databases.

-
 | Database | Requires `DATABASE_URL` |
-| --- | --- |
+|----------|-------------------------|
 | Postgres | Yes                     |
 | MySQL    | Yes                     |
 | SQLite   | No<sup>2</sup>          |
@@ -58,7 +57,7 @@ Test databases are automatically cleaned up as tests succeed, but failed tests w
 to facilitate debugging. Note that to simplify the implementation, panics are _always_ considered to be failures,
 even for `#[should_panic]` tests.

-To limit disk space usage, any previously created test databases will be deleted the next time a test binary using 
+To limit disk space usage, any previously created test databases will be deleted the next time a test binary using
 `#[sqlx::test]` is run.

 ```rust,no_run
@@ -86,8 +85,8 @@ converted to a filesystem path (`::` replaced with `/`).

 ### Automatic Migrations (requires `migrate` feature)

-To ensure a straightforward test implementation against a fresh test database, migrations are automatically applied if a 
-`migrations` folder is found in the same directory as `CARGO_MANIFEST_DIR` (the directory where the current crate's 
+To ensure a straightforward test implementation against a fresh test database, migrations are automatically applied if a
+`migrations` folder is found in the same directory as `CARGO_MANIFEST_DIR` (the directory where the current crate's
 `Cargo.toml` resides).

 You can override the resolved path relative to `CARGO_MANIFEST_DIR` in the attribute (global overrides are not currently
@@ -116,11 +115,13 @@ async fn basic_test(pool: PgPool) -> sqlx::Result<()> {
 Or if you're already embedding migrations in your main crate, you can reference them directly:

 `foo_crate/lib.rs`
+
 ```rust,ignore
 pub static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("foo_migrations");
 ```

 `foo_crate/tests/foo_test.rs`
+
 ```rust,no_run
 # #[cfg(all(feature = "migrate", feature = "postgres"))]
 # mod example {
@@ -129,12 +130,7 @@ use sqlx::{PgPool, Row};

 # // This is standing in for the main crate since doc examples don't support multiple crates.
 # mod foo_crate {
 #     use std::borrow::Cow;
-#     static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate::Migrator {
-#         migrations: Cow::Borrowed(&[]),
-#         ignore_missing: false,
-#         locking: true,
-#         no_tx: false
-#     };
+#     static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate::Migrator::DEFAULT;
 # }

 // You could also do `use foo_crate::MIGRATOR` and just refer to it as `MIGRATOR` here.
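For completeness, a sketch of pointing `#[sqlx::test]` at that external migrator (reusing the hypothetical `foo_crate::MIGRATOR` static from the preceding example):

```rust
// A sketch: reference an embedded migrator from another crate by path.
#[sqlx::test(migrator = "foo_crate::MIGRATOR")]
async fn uses_external_migrator(pool: sqlx::PgPool) -> sqlx::Result<()> {
    // `foo_crate`'s migrations have already been applied to the fresh test database.
    let one: i32 = sqlx::query_scalar("SELECT 1").fetch_one(&pool).await?;
    assert_eq!(one, 1);
    Ok(())
}
```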
@@ -188,9 +184,12 @@ the database to already have users and posts in it so the comments tests don't h

 You can pass a list of fixtures to the `fixtures` attribute in three different operating modes:

-1) Pass a list of referenced files in `./fixtures` (resolved as `./fixtures/{name}.sql`, `.sql` added only if extension is missing);
-2) Pass a list of file paths (including associated extension), in which case they can either be absolute, or relative to the current file;
-3) Pass a `path = "<path>"` parameter and a `scripts("<filename_1>", "<filename_2>", ...)` parameter that are relative to the provided path (resolved as `{path}/{filename_x}.sql`, `.sql` added only if extension is missing).
+1) Pass a list of referenced files in `./fixtures` (resolved as `./fixtures/{name}.sql`, `.sql` added only if extension
+   is missing);
+2) Pass a list of file paths (including associated extension), in which case they can either be absolute, or relative to
+   the current file;
+3) Pass a `path = "<path>"` parameter and a `scripts("<filename_1>", "<filename_2>", ...)` parameter that are
+   relative to the provided path (resolved as `{path}/{filename_x}.sql`, `.sql` added only if extension is missing).

 In any case they will be applied in the given order<sup>3</sup>:
@@ -225,6 +224,6 @@ async fn test_create_comment(pool: PgPool) -> sqlx::Result<()> {
 Multiple `fixtures` attributes can be used to combine different operating modes.

 <sup>3</sup>Ordering for test fixtures is entirely up to the application, and each test may choose which fixtures to
-apply and which to omit. However, since each fixture is applied separately (sent as a single command string, so wrapped 
-in an implicit `BEGIN` and `COMMIT`), you will want to make sure to order the fixtures such that foreign key 
+apply and which to omit. However, since each fixture is applied separately (sent as a single command string, so wrapped
+in an implicit `BEGIN` and `COMMIT`), you will want to make sure to order the fixtures such that foreign key
 requirements are always satisfied, or else you might get errors.
diff --git a/tests/postgres/macros.rs b/tests/postgres/macros.rs
index 07ae962018..04c1fe9d1e 100644
--- a/tests/postgres/macros.rs
+++ b/tests/postgres/macros.rs
@@ -295,6 +295,7 @@ async fn query_by_bigdecimal() -> anyhow::Result<()> {
     let decimal = "1234".parse::<sqlx::types::BigDecimal>()?;
     let ref tuple = ("51245.121232".parse::<sqlx::types::BigDecimal>()?,);

+    #[cfg_attr(feature = "rust_decimal", allow(deprecated))] // TODO: upgrade to `expect`
     let result = sqlx::query!(
         "SELECT * from (VALUES(1234.0)) decimals(decimal)\
         where decimal in ($1, $2, $3, $4, $5, $6, $7)",