From 3eeb165ae70f03e645dbbc642e2b8cf0b27a9336 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Tue, 23 Apr 2024 03:49:18 -0700 Subject: [PATCH 01/12] refactor: introduce `SqlSafeStr` API --- sqlx-core/src/executor.rs | 73 ------------- sqlx-core/src/lib.rs | 1 + sqlx-core/src/query.rs | 176 ++++++++++++------------------- sqlx-core/src/query_as.rs | 6 +- sqlx-core/src/query_builder.rs | 31 ++++-- sqlx-core/src/query_scalar.rs | 18 ++-- sqlx-core/src/sql_str.rs | 182 +++++++++++++++++++++++++++++++++ 7 files changed, 284 insertions(+), 203 deletions(-) create mode 100644 sqlx-core/src/sql_str.rs diff --git a/sqlx-core/src/executor.rs b/sqlx-core/src/executor.rs index 84b1a660d8..758cca7330 100644 --- a/sqlx-core/src/executor.rs +++ b/sqlx-core/src/executor.rs @@ -182,76 +182,3 @@ pub trait Executor<'c>: Send + Debug + Sized { where 'c: 'e; } - -/// A type that may be executed against a database connection. -/// -/// Implemented for the following: -/// -/// * [`&str`](std::str) -/// * [`Query`](super::query::Query) -/// -pub trait Execute<'q, DB: Database>: Send + Sized { - /// Gets the SQL that will be executed. - fn sql(&self) -> &'q str; - - /// Gets the previously cached statement, if available. - fn statement(&self) -> Option<&DB::Statement<'q>>; - - /// Returns the arguments to be bound against the query string. - /// - /// Returning `Ok(None)` for `Arguments` indicates to use a "simple" query protocol and to not - /// prepare the query. Returning `Ok(Some(Default::default()))` is an empty arguments object that - /// will be prepared (and cached) before execution. - /// - /// Returns `Err` if encoding any of the arguments failed. - fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError>; - - /// Returns `true` if the statement should be cached. - fn persistent(&self) -> bool; -} - -// NOTE: `Execute` is explicitly not implemented for String and &String to make it slightly more -// involved to write `conn.execute(format!("SELECT {val}"))` -impl<'q, DB: Database> Execute<'q, DB> for &'q str { - #[inline] - fn sql(&self) -> &'q str { - self - } - - #[inline] - fn statement(&self) -> Option<&DB::Statement<'q>> { - None - } - - #[inline] - fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { - Ok(None) - } - - #[inline] - fn persistent(&self) -> bool { - true - } -} - -impl<'q, DB: Database> Execute<'q, DB> for (&'q str, Option<::Arguments<'q>>) { - #[inline] - fn sql(&self) -> &'q str { - self.0 - } - - #[inline] - fn statement(&self) -> Option<&DB::Statement<'q>> { - None - } - - #[inline] - fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { - Ok(self.1.take()) - } - - #[inline] - fn persistent(&self) -> bool { - true - } -} diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index df4b2cc27d..4671c16708 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -72,6 +72,7 @@ pub mod net; pub mod query_as; pub mod query_builder; pub mod query_scalar; +pub mod sql_str; pub mod raw_sql; pub mod row; diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index 60f509c342..7bd424bf16 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -8,15 +8,16 @@ use crate::arguments::{Arguments, IntoArguments}; use crate::database::{Database, HasStatementCache}; use crate::encode::Encode; use crate::error::{BoxDynError, Error}; -use crate::executor::{Execute, Executor}; +use crate::executor::{Executor}; +use crate::sql_str::{SqlSafeStr, SqlStr}; use crate::statement::Statement; use crate::types::Type; /// A single SQL query as a prepared statement. Returned by [`query()`]. #[must_use = "query must be executed to affect database"] -pub struct Query<'q, DB: Database, A> { - pub(crate) statement: Either<&'q str, &'q DB::Statement<'q>>, - pub(crate) arguments: Option>, +pub struct Query<'q, 'a, DB: Database> { + pub(crate) statement: Either>, + pub(crate) arguments: Option, BoxDynError>>, pub(crate) database: PhantomData, pub(crate) persistent: bool, } @@ -33,46 +34,32 @@ pub struct Query<'q, DB: Database, A> { /// before `.try_map()`. This is also to prevent adding superfluous binds to the result of /// `query!()` et al. #[must_use = "query must be executed to affect database"] -pub struct Map<'q, DB: Database, F, A> { - inner: Query<'q, DB, A>, +pub struct Map<'q, 'a, DB: Database, F> { + inner: Query<'q, 'a, DB>, mapper: F, } -impl<'q, DB, A> Execute<'q, DB> for Query<'q, DB, A> +impl<'q, 'a, DB> Query<'q, 'a, DB> where - DB: Database, - A: Send + IntoArguments<'q, DB>, + DB: Database + HasStatementCache, { - #[inline] - fn sql(&self) -> &'q str { - match self.statement { - Either::Right(statement) => statement.sql(), - Either::Left(sql) => sql, - } - } - - fn statement(&self) -> Option<&DB::Statement<'q>> { - match self.statement { - Either::Right(statement) => Some(statement), - Either::Left(_) => None, - } - } - - #[inline] - fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { - self.arguments - .take() - .transpose() - .map(|option| option.map(IntoArguments::into_arguments)) - } - - #[inline] - fn persistent(&self) -> bool { - self.persistent + /// If `true`, the statement will get prepared once and cached to the + /// connection's statement cache. + /// + /// If queried once with the flag set to `true`, all subsequent queries + /// matching the one with the flag will use the cached statement until the + /// cache is cleared. + /// + /// If `false`, the prepared statement will be closed after execution. + /// + /// Default: `true`. + pub fn persistent(mut self, value: bool) -> Self { + self.persistent = value; + self } } -impl<'q, DB: Database> Query<'q, DB, ::Arguments<'q>> { +impl<'q, 'a, DB: Database> Query<'q, 'a, DB> { /// Bind a value for use with this SQL query. /// /// If the number of times this is called does not match the number of bind parameters that @@ -120,31 +107,10 @@ impl<'q, DB: Database> Query<'q, DB, ::Arguments<'q>> { } } -impl<'q, DB, A> Query<'q, DB, A> -where - DB: Database + HasStatementCache, -{ - /// If `true`, the statement will get prepared once and cached to the - /// connection's statement cache. - /// - /// If queried once with the flag set to `true`, all subsequent queries - /// matching the one with the flag will use the cached statement until the - /// cache is cleared. - /// - /// If `false`, the prepared statement will be closed after execution. - /// - /// Default: `true`. - pub fn persistent(mut self, value: bool) -> Self { - self.persistent = value; - self - } -} - -impl<'q, DB, A: Send> Query<'q, DB, A> -where - DB: Database, - A: 'q + IntoArguments<'q, DB>, -{ +impl<'q, 'a, DB> Query<'q, 'a, DB> + where + DB: Database, + { /// Map each row in the result to another type. /// /// See [`try_map`](Query::try_map) for a fallible version of this method. @@ -155,7 +121,7 @@ where pub fn map( self, mut f: F, - ) -> Map<'q, DB, impl FnMut(DB::Row) -> Result + Send, A> + ) -> Map<'q, 'a, DB, impl FnMut(DB::Row) -> Result + Send> where F: FnMut(DB::Row) -> O + Send, O: Unpin, @@ -168,7 +134,7 @@ where /// The [`query_as`](super::query_as::query_as) method will construct a mapped query using /// a [`FromRow`](super::from_row::FromRow) implementation. #[inline] - pub fn try_map(self, f: F) -> Map<'q, DB, F, A> + pub fn try_map(self, f: F) -> Map<'q, 'a, DB, F> where F: FnMut(DB::Row) -> Result + Send, O: Unpin, @@ -184,33 +150,17 @@ where pub async fn execute<'e, 'c: 'e, E>(self, executor: E) -> Result where 'q: 'e, - A: 'e, + 'a: 'e, E: Executor<'c, Database = DB>, { executor.execute(self).await } - /// Execute multiple queries and return the rows affected from each query, in a stream. - #[inline] - #[deprecated = "Only the SQLite driver supports multiple statements in one prepared statement and that behavior is deprecated. Use `sqlx::raw_sql()` instead. See https://github.com/launchbadge/sqlx/issues/3108 for discussion."] - pub async fn execute_many<'e, 'c: 'e, E>( - self, - executor: E, - ) -> BoxStream<'e, Result> - where - 'q: 'e, - A: 'e, - E: Executor<'c, Database = DB>, - { - executor.execute_many(self) - } - /// Execute the query and return the generated results as a stream. #[inline] pub fn fetch<'e, 'c: 'e, E>(self, executor: E) -> BoxStream<'e, Result> where 'q: 'e, - A: 'e, E: Executor<'c, Database = DB>, { executor.fetch(self) @@ -229,7 +179,6 @@ where ) -> BoxStream<'e, Result, Error>> where 'q: 'e, - A: 'e, E: Executor<'c, Database = DB>, { executor.fetch_many(self) @@ -246,7 +195,6 @@ where pub async fn fetch_all<'e, 'c: 'e, E>(self, executor: E) -> Result, Error> where 'q: 'e, - A: 'e, E: Executor<'c, Database = DB>, { executor.fetch_all(self).await @@ -268,7 +216,6 @@ where pub async fn fetch_one<'e, 'c: 'e, E>(self, executor: E) -> Result where 'q: 'e, - A: 'e, E: Executor<'c, Database = DB>, { executor.fetch_one(self).await @@ -290,45 +237,51 @@ where pub async fn fetch_optional<'e, 'c: 'e, E>(self, executor: E) -> Result, Error> where 'q: 'e, - A: 'e, E: Executor<'c, Database = DB>, { executor.fetch_optional(self).await } } -impl<'q, DB, F: Send, A: Send> Execute<'q, DB> for Map<'q, DB, F, A> +#[doc(hidden)] +impl<'q, 'a, DB> Query<'q, 'a, DB> where DB: Database, - A: IntoArguments<'q, DB>, { #[inline] fn sql(&self) -> &'q str { - self.inner.sql() + match &self.statement { + Either::Right(statement) => statement.sql(), + Either::Left(sql) => sql.as_str(), + } } - #[inline] fn statement(&self) -> Option<&DB::Statement<'q>> { - self.inner.statement() + match self.statement { + Either::Right(statement) => Some(statement), + Either::Left(_) => None, + } } #[inline] fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { - self.inner.take_arguments() + self.arguments + .take() + .transpose() + .map(|option| option.map(IntoArguments::into_arguments)) } #[inline] - fn persistent(&self) -> bool { - self.inner.arguments.is_some() + fn is_persistent(&self) -> bool { + self.persistent } } -impl<'q, DB, F, O, A> Map<'q, DB, F, A> +impl<'q, 'a, DB, F, O> Map<'q, 'a, DB, F> where DB: Database, F: FnMut(DB::Row) -> Result + Send, O: Send + Unpin, - A: 'q + Send + IntoArguments<'q, DB>, { /// Map each row in the result to another type. /// @@ -340,7 +293,7 @@ where pub fn map( self, mut g: G, - ) -> Map<'q, DB, impl FnMut(DB::Row) -> Result + Send, A> + ) -> Map<'q, 'a, DB, impl FnMut(DB::Row) -> Result + Send> where G: FnMut(O) -> P + Send, P: Unpin, @@ -356,7 +309,7 @@ where pub fn try_map( self, mut g: G, - ) -> Map<'q, DB, impl FnMut(DB::Row) -> Result + Send, A> + ) -> Map<'q, 'a, DB, impl FnMut(DB::Row) -> Result + Send> where G: FnMut(O) -> Result + Send, P: Unpin, @@ -497,9 +450,9 @@ where } /// Execute a single SQL query as a prepared statement (explicitly created). -pub fn query_statement<'q, DB>( +pub fn query_statement<'q, 'a, DB>( statement: &'q DB::Statement<'q>, -) -> Query<'q, DB, ::Arguments<'_>> +) -> Query<'q, 'a, DB> where DB: Database, { @@ -512,17 +465,17 @@ where } /// Execute a single SQL query as a prepared statement (explicitly created), with the given arguments. -pub fn query_statement_with<'q, DB, A>( +pub fn query_statement_with<'q, 'a, DB, A>( statement: &'q DB::Statement<'q>, arguments: A, -) -> Query<'q, DB, A> +) -> Query<'q, 'a, DB> where DB: Database, A: IntoArguments<'q, DB>, { Query { database: PhantomData, - arguments: Some(Ok(arguments)), + arguments: Some(Ok(arguments.into_arguments())), statement: Either::Right(statement), persistent: true, } @@ -652,14 +605,15 @@ where /// /// As an additional benefit, query parameters are usually sent in a compact binary encoding instead of a human-readable /// text encoding, which saves bandwidth. -pub fn query(sql: &str) -> Query<'_, DB, ::Arguments<'_>> +pub fn query<'a, DB, SQL>(sql: SQL) -> Query<'static, 'a, DB> where DB: Database, + SQL: SqlSafeStr, { Query { database: PhantomData, arguments: Some(Ok(Default::default())), - statement: Either::Left(sql), + statement: Either::Left(sql.into_sql_str()), persistent: true, } } @@ -667,27 +621,27 @@ where /// Execute a SQL query as a prepared statement (transparently cached), with the given arguments. /// /// See [`query()`][query] for details, such as supported syntax. -pub fn query_with<'q, DB, A>(sql: &'q str, arguments: A) -> Query<'q, DB, A> +pub fn query_with<'a, DB, SQL, A>(sql: SQL, arguments: A) -> Query<'static, 'a, DB> where DB: Database, - A: IntoArguments<'q, DB>, + A: IntoArguments<'a, DB>, { query_with_result(sql, Ok(arguments)) } /// Same as [`query_with`] but is initialized with a Result of arguments instead -pub fn query_with_result<'q, DB, A>( - sql: &'q str, - arguments: Result, -) -> Query<'q, DB, A> +pub fn query_with_result<'a, DB, SQL>( + sql: SQL, + arguments: Result, BoxDynError>, +) -> Query<'static, 'a, DB> where DB: Database, - A: IntoArguments<'q, DB>, + SQL: SqlSafeStr, { Query { database: PhantomData, arguments: Some(arguments), - statement: Either::Left(sql), + statement: Either::Left(sql.into_sql_str()), persistent: true, } } diff --git a/sqlx-core/src/query_as.rs b/sqlx-core/src/query_as.rs index 9f28fe41e9..257d28b75f 100644 --- a/sqlx-core/src/query_as.rs +++ b/sqlx-core/src/query_as.rs @@ -11,6 +11,7 @@ use crate::error::{BoxDynError, Error}; use crate::executor::{Execute, Executor}; use crate::from_row::FromRow; use crate::query::{query, query_statement, query_statement_with, query_with_result, Query}; +use crate::sql_str::SqlSafeStr; use crate::types::Type; /// A single SQL query as a prepared statement, mapping results using [`FromRow`]. @@ -339,7 +340,7 @@ where /// /// ``` #[inline] -pub fn query_as<'q, DB, O>(sql: &'q str) -> QueryAs<'q, DB, O, ::Arguments<'q>> +pub fn query_as<'q, DB, SQL, O>(sql: SQL) -> QueryAs<'q, DB, O, ::Arguments<'q>> where DB: Database, O: for<'r> FromRow<'r, DB::Row>, @@ -357,9 +358,10 @@ where /// /// For details about type mapping from [`FromRow`], see [`query_as()`]. #[inline] -pub fn query_as_with<'q, DB, O, A>(sql: &'q str, arguments: A) -> QueryAs<'q, DB, O, A> +pub fn query_as_with<'q, DB, SQL, O, A>(sql: SQL, arguments: A) -> QueryAs<'q, DB, O, A> where DB: Database, + SQL: SqlSafeStr<'q>, A: IntoArguments<'q, DB>, O: for<'r> FromRow<'r, DB::Row>, { diff --git a/sqlx-core/src/query_builder.rs b/sqlx-core/src/query_builder.rs index b242bf7b2a..d15d1a00cd 100644 --- a/sqlx-core/src/query_builder.rs +++ b/sqlx-core/src/query_builder.rs @@ -3,7 +3,7 @@ use std::fmt::Display; use std::fmt::Write; use std::marker::PhantomData; - +use std::sync::Arc; use crate::arguments::{Arguments, IntoArguments}; use crate::database::Database; use crate::encode::Encode; @@ -13,6 +13,7 @@ use crate::query_as::QueryAs; use crate::query_scalar::QueryScalar; use crate::types::Type; use crate::Either; +use crate::sql_str::AssertSqlSafe; /// A builder type for constructing queries at runtime. /// @@ -25,7 +26,9 @@ pub struct QueryBuilder<'args, DB> where DB: Database, { - query: String, + // Using `Arc` allows us to share the query string allocation with the database driver. + // It's only copied if the driver retains ownership after execution. + query: Arc, init_len: usize, arguments: Option<::Arguments<'args>>, } @@ -85,6 +88,16 @@ where "QueryBuilder must be reset before reuse after `.build()`" ); } + + fn query_mut(&mut self) -> &mut String { + assert!( + self.arguments.is_some(), + "QueryBuilder must be reset before reuse after `.build()`" + ); + + Arc::get_mut(&mut self.query) + .expect("BUG: query must not be shared at this point in time") + } /// Append a SQL fragment to the query. /// @@ -116,7 +129,7 @@ where pub fn push(&mut self, sql: impl Display) -> &mut Self { self.sanity_check(); - write!(self.query, "{sql}").expect("error formatting `sql`"); + write!(self.query_mut(), "{sql}").expect("error formatting `sql`"); self } @@ -158,7 +171,7 @@ where arguments.add(value).expect("Failed to add argument"); arguments - .format_placeholder(&mut self.query) + .format_placeholder(self.query_mut()) .expect("error in format_placeholder"); self @@ -453,12 +466,10 @@ where pub fn build(&mut self) -> Query<'_, DB, ::Arguments<'args>> { self.sanity_check(); - Query { - statement: Either::Left(&self.query), - arguments: self.arguments.take().map(Ok), - database: PhantomData, - persistent: true, - } + crate::query::query_with( + AssertSqlSafe(&self.query), + self.arguments.take().expect("BUG: just ran sanity_check") + ) } /// Produce an executable query from this builder. diff --git a/sqlx-core/src/query_scalar.rs b/sqlx-core/src/query_scalar.rs index c131adcca3..c097b9b18b 100644 --- a/sqlx-core/src/query_scalar.rs +++ b/sqlx-core/src/query_scalar.rs @@ -4,6 +4,7 @@ use futures_util::{StreamExt, TryFutureExt, TryStreamExt}; use crate::arguments::IntoArguments; use crate::database::{Database, HasStatementCache}; +use crate::decode::Decode; use crate::encode::Encode; use crate::error::{BoxDynError, Error}; use crate::executor::{Execute, Executor}; @@ -11,6 +12,7 @@ use crate::from_row::FromRow; use crate::query_as::{ query_as, query_as_with_result, query_statement_as, query_statement_as_with, QueryAs, }; +use crate::sql_str::SqlSafeStr; use crate::types::Type; /// A single SQL query as a prepared statement which extracts only the first column of each row. @@ -318,12 +320,13 @@ where /// # } /// ``` #[inline] -pub fn query_scalar<'q, DB, O>( - sql: &'q str, +pub fn query_scalar<'q, DB, SQL, O>( + sql: SQL, ) -> QueryScalar<'q, DB, O, ::Arguments<'q>> where DB: Database, - (O,): for<'r> FromRow<'r, DB::Row>, + SQL: SqlSafeStr<'q>, + O: Type + for<'r> Decode<'r, DB>, { QueryScalar { inner: query_as(sql), @@ -337,11 +340,12 @@ where /// /// For details about prepared statements and allowed SQL syntax, see [`query()`][crate::query::query]. #[inline] -pub fn query_scalar_with<'q, DB, O, A>(sql: &'q str, arguments: A) -> QueryScalar<'q, DB, O, A> +pub fn query_scalar_with<'q, DB, SQL, O, A>(sql: SQL, arguments: A) -> QueryScalar<'q, DB, O, A> where DB: Database, + SQL: SqlSafeStr<'q>, A: IntoArguments<'q, DB>, - (O,): for<'r> FromRow<'r, DB::Row>, + O: Type + for<'r> Decode<'r, DB>, { query_scalar_with_result(sql, Ok(arguments)) } @@ -368,7 +372,7 @@ pub fn query_statement_scalar<'q, DB, O>( ) -> QueryScalar<'q, DB, O, ::Arguments<'_>> where DB: Database, - (O,): for<'r> FromRow<'r, DB::Row>, + O: Type + for<'r> Decode<'r, DB>, { QueryScalar { inner: query_statement_as(statement), @@ -383,7 +387,7 @@ pub fn query_statement_scalar_with<'q, DB, O, A>( where DB: Database, A: IntoArguments<'q, DB>, - (O,): for<'r> FromRow<'r, DB::Row>, + O: Type + for<'r> Decode<'r, DB>, { QueryScalar { inner: query_statement_as_with(statement, arguments), diff --git a/sqlx-core/src/sql_str.rs b/sqlx-core/src/sql_str.rs new file mode 100644 index 0000000000..58ed23d6d3 --- /dev/null +++ b/sqlx-core/src/sql_str.rs @@ -0,0 +1,182 @@ +use std::borrow::Borrow; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +/// A SQL string that is safe to execute on a database connection. +/// +/// A "safe" SQL string is one that is unlikely to contain a [SQL injection vulnerability][injection]. +/// +/// In practice, this means a string type that is unlikely to contain dynamic data or user input. +/// +/// `&'static str` is the only string type that satisfies the requirements of this trait +/// (ignoring [`String::leak()`] which has niche use-cases) and so is the only string type that +/// natively implements this trait by default. +/// +/// For other string types, use [`AssertSqlSafe`] to assert this property. +/// This is the only intended way to pass an owned `String` to [`query()`] and its related functions +/// as well as [`raw_sql()`]. +/// +/// The maintainers of SQLx take no responsibility for any data leaks or loss resulting from misuse +/// of this API. +/// +/// ### Motivation +/// This is designed to act as a speed bump against naively using `format!()` to add dynamic data +/// or user input to a query, which is a classic vector for SQL injection as SQLx does not +/// provide any sort of escaping or sanitization (which would have to be specially implemented +/// for each database flavor/locale). +/// +/// The recommended way to incorporate dynamic data or user input in a query is to use +/// bind parameters, which requires the query to execute as a prepared statement. +/// See [`query()`] for details. +/// +/// This trait and [`AssertSqlSafe`] are intentionally analogous to +/// [`std::panic::UnwindSafe`] and [`std::panic::AssertUnwindSafe`], respectively. +/// +/// [injection]: https://en.wikipedia.org/wiki/SQL_injection +/// [`query()`]: crate::query::query +/// [`raw_sql()`]: crate::raw_sql::raw_sql +pub trait SqlSafeStr { + /// Convert `self` to a [`SqlStr`]. + fn into_sql_str(self) -> SqlStr; +} + +impl SqlSafeStr for &'static str { + #[inline] + + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Static(self)) + } +} + +/// Assert that a query string is safe to execute on a database connection. +/// +/// Using this API means that **you** have made sure that the string contents do not contain a +/// [SQL injection vulnerability][injection]. It means that, if the string was constructed +/// dynamically, and/or from user input, you have taken care to sanitize the input yourself. +/// SQLx does not provide any sort of sanitization; the design of SQLx prefers the use +/// of prepared statements for dynamic input. +/// +/// The maintainers of SQLx take no responsibility for any data leaks or loss resulting from misuse +/// of this API. **Use at your own risk.** +/// +/// Note that `&'static str` implements [`SqlSafeStr`] directly and so does not need to be wrapped +/// with this type. +/// +/// [injection]: https://en.wikipedia.org/wiki/SQL_injection +pub struct AssertSqlSafe(pub T); + +/// Note: copies the string. +/// +/// It is recommended to pass one of the supported owned string types instead. +impl<'a> SqlSafeStr for AssertSqlSafe<&'a str> { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Arced(self.0.into())) + } +} +impl SqlSafeStr for AssertSqlSafe { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Owned(self.0)) + } +} + +impl SqlSafeStr for AssertSqlSafe> { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Boxed(self.0)) + } +} + +// Note: this is not implemented for `Rc` because it would make `QueryString: !Send`. +impl SqlSafeStr for AssertSqlSafe> { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Arced(self.into())) + } +} + +/// A SQL string that is ready to execute on a database connection. +/// +/// This is essentially `Cow<'static, str>` but which can be constructed from additional types +/// without copying. +/// +/// See [`SqlSafeStr`] for details. +#[derive(Debug)] +pub struct SqlStr(Repr); + +#[derive(Debug)] +enum Repr { + /// We need a variant to memoize when we already have a static string, so we don't copy it. + Static(&'static str), + /// Thanks to the new niche in `String`, this doesn't increase the size beyond 3 words. + /// We essentially get all these variants for free. + Owned(String), + Boxed(Box), + Arced(Arc), + /// Allows for dynamic shared ownership with `query_builder`. + ArcString(Arc), +} + +impl Clone for SqlStr { + fn clone(&self) -> Self { + Self(match &self.0 { + Repr::Static(s) => Repr::Static(s), + Repr::Arced(s) => Repr::Arced(s.clone()), + _ => Repr::Arced(self.as_str().into()), + }) + } +} + +impl SqlSafeStr for SqlStr { + #[inline] + fn into_sql_str(self) -> SqlStr { + self + } +} + +impl SqlStr { + pub(crate) fn from_arc_string(arc: Arc) -> Self { + SqlStr(Repr::ArcString(arc)) + } + + /// Borrow the inner query string. + #[inline] + pub fn as_str(&self) -> &str { + match &self.0 { + Repr::Static(s) => s, + Repr::Owned(s) => s, + Repr::Boxed(s) => s, + Repr::Arced(s) => s, + Repr::ArcString(s) => s, + } + } +} + +impl AsRef for SqlStr { + #[inline] + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl Borrow for SqlStr { + #[inline] + fn borrow(&self) -> &str { + self.as_str() + } +} + +impl PartialEq for SqlStr where T: AsRef { + fn eq(&self, other: &T) -> bool { + self.as_str() == other.as_ref() + } +} + +impl Eq for SqlStr {} + +impl Hash for SqlStr { + fn hash(&self, state: &mut H) { + self.as_str().hash(state) + } +} From 35e368c4d72bf90276b15327a33bbfbf6665cdf5 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Mon, 17 Feb 2025 20:35:54 +0100 Subject: [PATCH 02/12] rebase main --- sqlx-core/src/executor.rs | 73 ++++++++++++++ sqlx-core/src/query.rs | 176 +++++++++++++++++++++------------ sqlx-core/src/query_as.rs | 6 +- sqlx-core/src/query_builder.rs | 31 ++---- sqlx-core/src/query_scalar.rs | 18 ++-- 5 files changed, 203 insertions(+), 101 deletions(-) diff --git a/sqlx-core/src/executor.rs b/sqlx-core/src/executor.rs index 758cca7330..84b1a660d8 100644 --- a/sqlx-core/src/executor.rs +++ b/sqlx-core/src/executor.rs @@ -182,3 +182,76 @@ pub trait Executor<'c>: Send + Debug + Sized { where 'c: 'e; } + +/// A type that may be executed against a database connection. +/// +/// Implemented for the following: +/// +/// * [`&str`](std::str) +/// * [`Query`](super::query::Query) +/// +pub trait Execute<'q, DB: Database>: Send + Sized { + /// Gets the SQL that will be executed. + fn sql(&self) -> &'q str; + + /// Gets the previously cached statement, if available. + fn statement(&self) -> Option<&DB::Statement<'q>>; + + /// Returns the arguments to be bound against the query string. + /// + /// Returning `Ok(None)` for `Arguments` indicates to use a "simple" query protocol and to not + /// prepare the query. Returning `Ok(Some(Default::default()))` is an empty arguments object that + /// will be prepared (and cached) before execution. + /// + /// Returns `Err` if encoding any of the arguments failed. + fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError>; + + /// Returns `true` if the statement should be cached. + fn persistent(&self) -> bool; +} + +// NOTE: `Execute` is explicitly not implemented for String and &String to make it slightly more +// involved to write `conn.execute(format!("SELECT {val}"))` +impl<'q, DB: Database> Execute<'q, DB> for &'q str { + #[inline] + fn sql(&self) -> &'q str { + self + } + + #[inline] + fn statement(&self) -> Option<&DB::Statement<'q>> { + None + } + + #[inline] + fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { + Ok(None) + } + + #[inline] + fn persistent(&self) -> bool { + true + } +} + +impl<'q, DB: Database> Execute<'q, DB> for (&'q str, Option<::Arguments<'q>>) { + #[inline] + fn sql(&self) -> &'q str { + self.0 + } + + #[inline] + fn statement(&self) -> Option<&DB::Statement<'q>> { + None + } + + #[inline] + fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { + Ok(self.1.take()) + } + + #[inline] + fn persistent(&self) -> bool { + true + } +} diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index 7bd424bf16..60f509c342 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -8,16 +8,15 @@ use crate::arguments::{Arguments, IntoArguments}; use crate::database::{Database, HasStatementCache}; use crate::encode::Encode; use crate::error::{BoxDynError, Error}; -use crate::executor::{Executor}; -use crate::sql_str::{SqlSafeStr, SqlStr}; +use crate::executor::{Execute, Executor}; use crate::statement::Statement; use crate::types::Type; /// A single SQL query as a prepared statement. Returned by [`query()`]. #[must_use = "query must be executed to affect database"] -pub struct Query<'q, 'a, DB: Database> { - pub(crate) statement: Either>, - pub(crate) arguments: Option, BoxDynError>>, +pub struct Query<'q, DB: Database, A> { + pub(crate) statement: Either<&'q str, &'q DB::Statement<'q>>, + pub(crate) arguments: Option>, pub(crate) database: PhantomData, pub(crate) persistent: bool, } @@ -34,32 +33,46 @@ pub struct Query<'q, 'a, DB: Database> { /// before `.try_map()`. This is also to prevent adding superfluous binds to the result of /// `query!()` et al. #[must_use = "query must be executed to affect database"] -pub struct Map<'q, 'a, DB: Database, F> { - inner: Query<'q, 'a, DB>, +pub struct Map<'q, DB: Database, F, A> { + inner: Query<'q, DB, A>, mapper: F, } -impl<'q, 'a, DB> Query<'q, 'a, DB> +impl<'q, DB, A> Execute<'q, DB> for Query<'q, DB, A> where - DB: Database + HasStatementCache, + DB: Database, + A: Send + IntoArguments<'q, DB>, { - /// If `true`, the statement will get prepared once and cached to the - /// connection's statement cache. - /// - /// If queried once with the flag set to `true`, all subsequent queries - /// matching the one with the flag will use the cached statement until the - /// cache is cleared. - /// - /// If `false`, the prepared statement will be closed after execution. - /// - /// Default: `true`. - pub fn persistent(mut self, value: bool) -> Self { - self.persistent = value; - self + #[inline] + fn sql(&self) -> &'q str { + match self.statement { + Either::Right(statement) => statement.sql(), + Either::Left(sql) => sql, + } + } + + fn statement(&self) -> Option<&DB::Statement<'q>> { + match self.statement { + Either::Right(statement) => Some(statement), + Either::Left(_) => None, + } + } + + #[inline] + fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { + self.arguments + .take() + .transpose() + .map(|option| option.map(IntoArguments::into_arguments)) + } + + #[inline] + fn persistent(&self) -> bool { + self.persistent } } -impl<'q, 'a, DB: Database> Query<'q, 'a, DB> { +impl<'q, DB: Database> Query<'q, DB, ::Arguments<'q>> { /// Bind a value for use with this SQL query. /// /// If the number of times this is called does not match the number of bind parameters that @@ -107,10 +120,31 @@ impl<'q, 'a, DB: Database> Query<'q, 'a, DB> { } } -impl<'q, 'a, DB> Query<'q, 'a, DB> - where - DB: Database, - { +impl<'q, DB, A> Query<'q, DB, A> +where + DB: Database + HasStatementCache, +{ + /// If `true`, the statement will get prepared once and cached to the + /// connection's statement cache. + /// + /// If queried once with the flag set to `true`, all subsequent queries + /// matching the one with the flag will use the cached statement until the + /// cache is cleared. + /// + /// If `false`, the prepared statement will be closed after execution. + /// + /// Default: `true`. + pub fn persistent(mut self, value: bool) -> Self { + self.persistent = value; + self + } +} + +impl<'q, DB, A: Send> Query<'q, DB, A> +where + DB: Database, + A: 'q + IntoArguments<'q, DB>, +{ /// Map each row in the result to another type. /// /// See [`try_map`](Query::try_map) for a fallible version of this method. @@ -121,7 +155,7 @@ impl<'q, 'a, DB> Query<'q, 'a, DB> pub fn map( self, mut f: F, - ) -> Map<'q, 'a, DB, impl FnMut(DB::Row) -> Result + Send> + ) -> Map<'q, DB, impl FnMut(DB::Row) -> Result + Send, A> where F: FnMut(DB::Row) -> O + Send, O: Unpin, @@ -134,7 +168,7 @@ impl<'q, 'a, DB> Query<'q, 'a, DB> /// The [`query_as`](super::query_as::query_as) method will construct a mapped query using /// a [`FromRow`](super::from_row::FromRow) implementation. #[inline] - pub fn try_map(self, f: F) -> Map<'q, 'a, DB, F> + pub fn try_map(self, f: F) -> Map<'q, DB, F, A> where F: FnMut(DB::Row) -> Result + Send, O: Unpin, @@ -150,17 +184,33 @@ impl<'q, 'a, DB> Query<'q, 'a, DB> pub async fn execute<'e, 'c: 'e, E>(self, executor: E) -> Result where 'q: 'e, - 'a: 'e, + A: 'e, E: Executor<'c, Database = DB>, { executor.execute(self).await } + /// Execute multiple queries and return the rows affected from each query, in a stream. + #[inline] + #[deprecated = "Only the SQLite driver supports multiple statements in one prepared statement and that behavior is deprecated. Use `sqlx::raw_sql()` instead. See https://github.com/launchbadge/sqlx/issues/3108 for discussion."] + pub async fn execute_many<'e, 'c: 'e, E>( + self, + executor: E, + ) -> BoxStream<'e, Result> + where + 'q: 'e, + A: 'e, + E: Executor<'c, Database = DB>, + { + executor.execute_many(self) + } + /// Execute the query and return the generated results as a stream. #[inline] pub fn fetch<'e, 'c: 'e, E>(self, executor: E) -> BoxStream<'e, Result> where 'q: 'e, + A: 'e, E: Executor<'c, Database = DB>, { executor.fetch(self) @@ -179,6 +229,7 @@ impl<'q, 'a, DB> Query<'q, 'a, DB> ) -> BoxStream<'e, Result, Error>> where 'q: 'e, + A: 'e, E: Executor<'c, Database = DB>, { executor.fetch_many(self) @@ -195,6 +246,7 @@ impl<'q, 'a, DB> Query<'q, 'a, DB> pub async fn fetch_all<'e, 'c: 'e, E>(self, executor: E) -> Result, Error> where 'q: 'e, + A: 'e, E: Executor<'c, Database = DB>, { executor.fetch_all(self).await @@ -216,6 +268,7 @@ impl<'q, 'a, DB> Query<'q, 'a, DB> pub async fn fetch_one<'e, 'c: 'e, E>(self, executor: E) -> Result where 'q: 'e, + A: 'e, E: Executor<'c, Database = DB>, { executor.fetch_one(self).await @@ -237,51 +290,45 @@ impl<'q, 'a, DB> Query<'q, 'a, DB> pub async fn fetch_optional<'e, 'c: 'e, E>(self, executor: E) -> Result, Error> where 'q: 'e, + A: 'e, E: Executor<'c, Database = DB>, { executor.fetch_optional(self).await } } -#[doc(hidden)] -impl<'q, 'a, DB> Query<'q, 'a, DB> +impl<'q, DB, F: Send, A: Send> Execute<'q, DB> for Map<'q, DB, F, A> where DB: Database, + A: IntoArguments<'q, DB>, { #[inline] fn sql(&self) -> &'q str { - match &self.statement { - Either::Right(statement) => statement.sql(), - Either::Left(sql) => sql.as_str(), - } + self.inner.sql() } + #[inline] fn statement(&self) -> Option<&DB::Statement<'q>> { - match self.statement { - Either::Right(statement) => Some(statement), - Either::Left(_) => None, - } + self.inner.statement() } #[inline] fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { - self.arguments - .take() - .transpose() - .map(|option| option.map(IntoArguments::into_arguments)) + self.inner.take_arguments() } #[inline] - fn is_persistent(&self) -> bool { - self.persistent + fn persistent(&self) -> bool { + self.inner.arguments.is_some() } } -impl<'q, 'a, DB, F, O> Map<'q, 'a, DB, F> +impl<'q, DB, F, O, A> Map<'q, DB, F, A> where DB: Database, F: FnMut(DB::Row) -> Result + Send, O: Send + Unpin, + A: 'q + Send + IntoArguments<'q, DB>, { /// Map each row in the result to another type. /// @@ -293,7 +340,7 @@ where pub fn map( self, mut g: G, - ) -> Map<'q, 'a, DB, impl FnMut(DB::Row) -> Result + Send> + ) -> Map<'q, DB, impl FnMut(DB::Row) -> Result + Send, A> where G: FnMut(O) -> P + Send, P: Unpin, @@ -309,7 +356,7 @@ where pub fn try_map( self, mut g: G, - ) -> Map<'q, 'a, DB, impl FnMut(DB::Row) -> Result + Send> + ) -> Map<'q, DB, impl FnMut(DB::Row) -> Result + Send, A> where G: FnMut(O) -> Result + Send, P: Unpin, @@ -450,9 +497,9 @@ where } /// Execute a single SQL query as a prepared statement (explicitly created). -pub fn query_statement<'q, 'a, DB>( +pub fn query_statement<'q, DB>( statement: &'q DB::Statement<'q>, -) -> Query<'q, 'a, DB> +) -> Query<'q, DB, ::Arguments<'_>> where DB: Database, { @@ -465,17 +512,17 @@ where } /// Execute a single SQL query as a prepared statement (explicitly created), with the given arguments. -pub fn query_statement_with<'q, 'a, DB, A>( +pub fn query_statement_with<'q, DB, A>( statement: &'q DB::Statement<'q>, arguments: A, -) -> Query<'q, 'a, DB> +) -> Query<'q, DB, A> where DB: Database, A: IntoArguments<'q, DB>, { Query { database: PhantomData, - arguments: Some(Ok(arguments.into_arguments())), + arguments: Some(Ok(arguments)), statement: Either::Right(statement), persistent: true, } @@ -605,15 +652,14 @@ where /// /// As an additional benefit, query parameters are usually sent in a compact binary encoding instead of a human-readable /// text encoding, which saves bandwidth. -pub fn query<'a, DB, SQL>(sql: SQL) -> Query<'static, 'a, DB> +pub fn query(sql: &str) -> Query<'_, DB, ::Arguments<'_>> where DB: Database, - SQL: SqlSafeStr, { Query { database: PhantomData, arguments: Some(Ok(Default::default())), - statement: Either::Left(sql.into_sql_str()), + statement: Either::Left(sql), persistent: true, } } @@ -621,27 +667,27 @@ where /// Execute a SQL query as a prepared statement (transparently cached), with the given arguments. /// /// See [`query()`][query] for details, such as supported syntax. -pub fn query_with<'a, DB, SQL, A>(sql: SQL, arguments: A) -> Query<'static, 'a, DB> +pub fn query_with<'q, DB, A>(sql: &'q str, arguments: A) -> Query<'q, DB, A> where DB: Database, - A: IntoArguments<'a, DB>, + A: IntoArguments<'q, DB>, { query_with_result(sql, Ok(arguments)) } /// Same as [`query_with`] but is initialized with a Result of arguments instead -pub fn query_with_result<'a, DB, SQL>( - sql: SQL, - arguments: Result, BoxDynError>, -) -> Query<'static, 'a, DB> +pub fn query_with_result<'q, DB, A>( + sql: &'q str, + arguments: Result, +) -> Query<'q, DB, A> where DB: Database, - SQL: SqlSafeStr, + A: IntoArguments<'q, DB>, { Query { database: PhantomData, arguments: Some(arguments), - statement: Either::Left(sql.into_sql_str()), + statement: Either::Left(sql), persistent: true, } } diff --git a/sqlx-core/src/query_as.rs b/sqlx-core/src/query_as.rs index 257d28b75f..9f28fe41e9 100644 --- a/sqlx-core/src/query_as.rs +++ b/sqlx-core/src/query_as.rs @@ -11,7 +11,6 @@ use crate::error::{BoxDynError, Error}; use crate::executor::{Execute, Executor}; use crate::from_row::FromRow; use crate::query::{query, query_statement, query_statement_with, query_with_result, Query}; -use crate::sql_str::SqlSafeStr; use crate::types::Type; /// A single SQL query as a prepared statement, mapping results using [`FromRow`]. @@ -340,7 +339,7 @@ where /// /// ``` #[inline] -pub fn query_as<'q, DB, SQL, O>(sql: SQL) -> QueryAs<'q, DB, O, ::Arguments<'q>> +pub fn query_as<'q, DB, O>(sql: &'q str) -> QueryAs<'q, DB, O, ::Arguments<'q>> where DB: Database, O: for<'r> FromRow<'r, DB::Row>, @@ -358,10 +357,9 @@ where /// /// For details about type mapping from [`FromRow`], see [`query_as()`]. #[inline] -pub fn query_as_with<'q, DB, SQL, O, A>(sql: SQL, arguments: A) -> QueryAs<'q, DB, O, A> +pub fn query_as_with<'q, DB, O, A>(sql: &'q str, arguments: A) -> QueryAs<'q, DB, O, A> where DB: Database, - SQL: SqlSafeStr<'q>, A: IntoArguments<'q, DB>, O: for<'r> FromRow<'r, DB::Row>, { diff --git a/sqlx-core/src/query_builder.rs b/sqlx-core/src/query_builder.rs index d15d1a00cd..b242bf7b2a 100644 --- a/sqlx-core/src/query_builder.rs +++ b/sqlx-core/src/query_builder.rs @@ -3,7 +3,7 @@ use std::fmt::Display; use std::fmt::Write; use std::marker::PhantomData; -use std::sync::Arc; + use crate::arguments::{Arguments, IntoArguments}; use crate::database::Database; use crate::encode::Encode; @@ -13,7 +13,6 @@ use crate::query_as::QueryAs; use crate::query_scalar::QueryScalar; use crate::types::Type; use crate::Either; -use crate::sql_str::AssertSqlSafe; /// A builder type for constructing queries at runtime. /// @@ -26,9 +25,7 @@ pub struct QueryBuilder<'args, DB> where DB: Database, { - // Using `Arc` allows us to share the query string allocation with the database driver. - // It's only copied if the driver retains ownership after execution. - query: Arc, + query: String, init_len: usize, arguments: Option<::Arguments<'args>>, } @@ -88,16 +85,6 @@ where "QueryBuilder must be reset before reuse after `.build()`" ); } - - fn query_mut(&mut self) -> &mut String { - assert!( - self.arguments.is_some(), - "QueryBuilder must be reset before reuse after `.build()`" - ); - - Arc::get_mut(&mut self.query) - .expect("BUG: query must not be shared at this point in time") - } /// Append a SQL fragment to the query. /// @@ -129,7 +116,7 @@ where pub fn push(&mut self, sql: impl Display) -> &mut Self { self.sanity_check(); - write!(self.query_mut(), "{sql}").expect("error formatting `sql`"); + write!(self.query, "{sql}").expect("error formatting `sql`"); self } @@ -171,7 +158,7 @@ where arguments.add(value).expect("Failed to add argument"); arguments - .format_placeholder(self.query_mut()) + .format_placeholder(&mut self.query) .expect("error in format_placeholder"); self @@ -466,10 +453,12 @@ where pub fn build(&mut self) -> Query<'_, DB, ::Arguments<'args>> { self.sanity_check(); - crate::query::query_with( - AssertSqlSafe(&self.query), - self.arguments.take().expect("BUG: just ran sanity_check") - ) + Query { + statement: Either::Left(&self.query), + arguments: self.arguments.take().map(Ok), + database: PhantomData, + persistent: true, + } } /// Produce an executable query from this builder. diff --git a/sqlx-core/src/query_scalar.rs b/sqlx-core/src/query_scalar.rs index c097b9b18b..c131adcca3 100644 --- a/sqlx-core/src/query_scalar.rs +++ b/sqlx-core/src/query_scalar.rs @@ -4,7 +4,6 @@ use futures_util::{StreamExt, TryFutureExt, TryStreamExt}; use crate::arguments::IntoArguments; use crate::database::{Database, HasStatementCache}; -use crate::decode::Decode; use crate::encode::Encode; use crate::error::{BoxDynError, Error}; use crate::executor::{Execute, Executor}; @@ -12,7 +11,6 @@ use crate::from_row::FromRow; use crate::query_as::{ query_as, query_as_with_result, query_statement_as, query_statement_as_with, QueryAs, }; -use crate::sql_str::SqlSafeStr; use crate::types::Type; /// A single SQL query as a prepared statement which extracts only the first column of each row. @@ -320,13 +318,12 @@ where /// # } /// ``` #[inline] -pub fn query_scalar<'q, DB, SQL, O>( - sql: SQL, +pub fn query_scalar<'q, DB, O>( + sql: &'q str, ) -> QueryScalar<'q, DB, O, ::Arguments<'q>> where DB: Database, - SQL: SqlSafeStr<'q>, - O: Type + for<'r> Decode<'r, DB>, + (O,): for<'r> FromRow<'r, DB::Row>, { QueryScalar { inner: query_as(sql), @@ -340,12 +337,11 @@ where /// /// For details about prepared statements and allowed SQL syntax, see [`query()`][crate::query::query]. #[inline] -pub fn query_scalar_with<'q, DB, SQL, O, A>(sql: SQL, arguments: A) -> QueryScalar<'q, DB, O, A> +pub fn query_scalar_with<'q, DB, O, A>(sql: &'q str, arguments: A) -> QueryScalar<'q, DB, O, A> where DB: Database, - SQL: SqlSafeStr<'q>, A: IntoArguments<'q, DB>, - O: Type + for<'r> Decode<'r, DB>, + (O,): for<'r> FromRow<'r, DB::Row>, { query_scalar_with_result(sql, Ok(arguments)) } @@ -372,7 +368,7 @@ pub fn query_statement_scalar<'q, DB, O>( ) -> QueryScalar<'q, DB, O, ::Arguments<'_>> where DB: Database, - O: Type + for<'r> Decode<'r, DB>, + (O,): for<'r> FromRow<'r, DB::Row>, { QueryScalar { inner: query_statement_as(statement), @@ -387,7 +383,7 @@ pub fn query_statement_scalar_with<'q, DB, O, A>( where DB: Database, A: IntoArguments<'q, DB>, - O: Type + for<'r> Decode<'r, DB>, + (O,): for<'r> FromRow<'r, DB::Row>, { QueryScalar { inner: query_statement_as_with(statement, arguments), From cb9b7d0078b84ff107f3f1e02eaf5b02737a8274 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Sat, 1 Feb 2025 12:04:26 +0100 Subject: [PATCH 03/12] Add SqlStr + remove Statement lifetime --- sqlx-core/src/any/connection/backend.rs | 2 +- sqlx-core/src/any/connection/executor.rs | 2 +- sqlx-core/src/any/database.rs | 2 +- sqlx-core/src/any/statement.rs | 26 ++++++++++++------------ sqlx-core/src/column.rs | 4 ++-- sqlx-core/src/database.rs | 2 +- sqlx-core/src/executor.rs | 10 ++++----- sqlx-core/src/pool/executor.rs | 2 +- sqlx-core/src/query.rs | 10 ++++----- sqlx-core/src/query_as.rs | 6 +++--- sqlx-core/src/query_scalar.rs | 6 +++--- sqlx-core/src/raw_sql.rs | 2 +- sqlx-core/src/sql_str.rs | 10 ++++----- sqlx-core/src/statement.rs | 4 ++-- sqlx-mysql/src/any.rs | 2 +- sqlx-mysql/src/connection/executor.rs | 7 ++++--- sqlx-mysql/src/database.rs | 2 +- sqlx-mysql/src/statement.rs | 20 +++++++++--------- sqlx-postgres/src/any.rs | 2 +- sqlx-postgres/src/connection/executor.rs | 7 ++++--- sqlx-postgres/src/database.rs | 2 +- sqlx-postgres/src/listener.rs | 2 +- sqlx-postgres/src/statement.rs | 20 +++++++++--------- sqlx-sqlite/src/any.rs | 2 +- sqlx-sqlite/src/connection/executor.rs | 5 +++-- sqlx-sqlite/src/connection/worker.rs | 10 ++++----- sqlx-sqlite/src/database.rs | 2 +- sqlx-sqlite/src/statement/mod.rs | 20 +++++++++--------- 28 files changed, 97 insertions(+), 94 deletions(-) diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index 6c84c1d8ce..643a32f8f7 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -112,7 +112,7 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { &'c mut self, sql: &'q str, parameters: &[AnyTypeInfo], - ) -> BoxFuture<'c, crate::Result>>; + ) -> BoxFuture<'c, crate::Result>; fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, crate::Result>>; } diff --git a/sqlx-core/src/any/connection/executor.rs b/sqlx-core/src/any/connection/executor.rs index ccf6dd7933..3a941194bb 100644 --- a/sqlx-core/src/any/connection/executor.rs +++ b/sqlx-core/src/any/connection/executor.rs @@ -47,7 +47,7 @@ impl<'c> Executor<'c> for &'c mut AnyConnection { self, sql: &'q str, parameters: &[AnyTypeInfo], - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result> where 'c: 'e, { diff --git a/sqlx-core/src/any/database.rs b/sqlx-core/src/any/database.rs index 9c3f15bb1f..6e8343e928 100644 --- a/sqlx-core/src/any/database.rs +++ b/sqlx-core/src/any/database.rs @@ -28,7 +28,7 @@ impl Database for Any { type Arguments<'q> = AnyArguments<'q>; type ArgumentBuffer<'q> = AnyArgumentBuffer<'q>; - type Statement<'q> = AnyStatement<'q>; + type Statement = AnyStatement; const NAME: &'static str = "Any"; diff --git a/sqlx-core/src/any/statement.rs b/sqlx-core/src/any/statement.rs index 1fbb11895c..f42d121e1b 100644 --- a/sqlx-core/src/any/statement.rs +++ b/sqlx-core/src/any/statement.rs @@ -3,15 +3,15 @@ use crate::column::ColumnIndex; use crate::database::Database; use crate::error::Error; use crate::ext::ustr::UStr; +use crate::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr}; use crate::statement::Statement; use crate::HashMap; use either::Either; -use std::borrow::Cow; use std::sync::Arc; -pub struct AnyStatement<'q> { +pub struct AnyStatement { #[doc(hidden)] - pub sql: Cow<'q, str>, + pub sql: SqlStr, #[doc(hidden)] pub parameters: Option, usize>>, #[doc(hidden)] @@ -20,12 +20,12 @@ pub struct AnyStatement<'q> { pub columns: Vec, } -impl<'q> Statement<'q> for AnyStatement<'q> { +impl Statement for AnyStatement { type Database = Any; - fn to_owned(&self) -> AnyStatement<'static> { - AnyStatement::<'static> { - sql: Cow::Owned(self.sql.clone().into_owned()), + fn to_owned(&self) -> AnyStatement { + AnyStatement { + sql: self.sql.clone(), column_names: self.column_names.clone(), parameters: self.parameters.clone(), columns: self.columns.clone(), @@ -33,7 +33,7 @@ impl<'q> Statement<'q> for AnyStatement<'q> { } fn sql(&self) -> &str { - &self.sql + &self.sql.as_str() } fn parameters(&self) -> Option> { @@ -51,8 +51,8 @@ impl<'q> Statement<'q> for AnyStatement<'q> { impl_statement_query!(AnyArguments<'_>); } -impl<'i> ColumnIndex> for &'i str { - fn index(&self, statement: &AnyStatement<'_>) -> Result { +impl<'i> ColumnIndex for &'i str { + fn index(&self, statement: &AnyStatement) -> Result { statement .column_names .get(*self) @@ -61,7 +61,7 @@ impl<'i> ColumnIndex> for &'i str { } } -impl<'q> AnyStatement<'q> { +impl<'q> AnyStatement { #[doc(hidden)] pub fn try_from_statement( query: &'q str, @@ -69,7 +69,7 @@ impl<'q> AnyStatement<'q> { column_names: Arc>, ) -> crate::Result where - S: Statement<'q>, + S: Statement, AnyTypeInfo: for<'a> TryFrom<&'a ::TypeInfo, Error = Error>, AnyColumn: for<'a> TryFrom<&'a ::Column, Error = Error>, { @@ -91,7 +91,7 @@ impl<'q> AnyStatement<'q> { .collect::, _>>()?; Ok(Self { - sql: query.into(), + sql: AssertSqlSafe(query).into_sql_str(), columns, column_names, parameters, diff --git a/sqlx-core/src/column.rs b/sqlx-core/src/column.rs index 9f45819ed6..ebcd2d6829 100644 --- a/sqlx-core/src/column.rs +++ b/sqlx-core/src/column.rs @@ -69,8 +69,8 @@ macro_rules! impl_column_index_for_row { #[macro_export] macro_rules! impl_column_index_for_statement { ($S:ident) => { - impl $crate::column::ColumnIndex<$S<'_>> for usize { - fn index(&self, statement: &$S<'_>) -> Result { + impl $crate::column::ColumnIndex<$S> for usize { + fn index(&self, statement: &$S) -> Result { let len = $crate::statement::Statement::columns(statement).len(); if *self >= len { diff --git a/sqlx-core/src/database.rs b/sqlx-core/src/database.rs index e44c3d88ac..02d7a1214e 100644 --- a/sqlx-core/src/database.rs +++ b/sqlx-core/src/database.rs @@ -101,7 +101,7 @@ pub trait Database: 'static + Sized + Send + Debug { type ArgumentBuffer<'q>; /// The concrete `Statement` implementation for this database. - type Statement<'q>: Statement<'q, Database = Self>; + type Statement: Statement; /// The display name for this database driver. const NAME: &'static str; diff --git a/sqlx-core/src/executor.rs b/sqlx-core/src/executor.rs index 84b1a660d8..dde7810932 100644 --- a/sqlx-core/src/executor.rs +++ b/sqlx-core/src/executor.rs @@ -149,7 +149,7 @@ pub trait Executor<'c>: Send + Debug + Sized { fn prepare<'e, 'q: 'e>( self, query: &'q str, - ) -> BoxFuture<'e, Result<::Statement<'q>, Error>> + ) -> BoxFuture<'e, Result<::Statement, Error>> where 'c: 'e, { @@ -165,7 +165,7 @@ pub trait Executor<'c>: Send + Debug + Sized { self, sql: &'q str, parameters: &'e [::TypeInfo], - ) -> BoxFuture<'e, Result<::Statement<'q>, Error>> + ) -> BoxFuture<'e, Result<::Statement, Error>> where 'c: 'e; @@ -195,7 +195,7 @@ pub trait Execute<'q, DB: Database>: Send + Sized { fn sql(&self) -> &'q str; /// Gets the previously cached statement, if available. - fn statement(&self) -> Option<&DB::Statement<'q>>; + fn statement(&self) -> Option<&DB::Statement>; /// Returns the arguments to be bound against the query string. /// @@ -219,7 +219,7 @@ impl<'q, DB: Database> Execute<'q, DB> for &'q str { } #[inline] - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { None } @@ -241,7 +241,7 @@ impl<'q, DB: Database> Execute<'q, DB> for (&'q str, Option<::Ar } #[inline] - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { None } diff --git a/sqlx-core/src/pool/executor.rs b/sqlx-core/src/pool/executor.rs index ba27b44316..56527f59cc 100644 --- a/sqlx-core/src/pool/executor.rs +++ b/sqlx-core/src/pool/executor.rs @@ -52,7 +52,7 @@ where self, sql: &'q str, parameters: &'e [::TypeInfo], - ) -> BoxFuture<'e, Result<::Statement<'q>, Error>> { + ) -> BoxFuture<'e, Result<::Statement, Error>> { let pool = self.clone(); Box::pin(async move { pool.acquire().await?.prepare_with(sql, parameters).await }) diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index 60f509c342..080e47e04b 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -15,7 +15,7 @@ use crate::types::Type; /// A single SQL query as a prepared statement. Returned by [`query()`]. #[must_use = "query must be executed to affect database"] pub struct Query<'q, DB: Database, A> { - pub(crate) statement: Either<&'q str, &'q DB::Statement<'q>>, + pub(crate) statement: Either<&'q str, &'q DB::Statement>, pub(crate) arguments: Option>, pub(crate) database: PhantomData, pub(crate) persistent: bool, @@ -51,7 +51,7 @@ where } } - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { match self.statement { Either::Right(statement) => Some(statement), Either::Left(_) => None, @@ -308,7 +308,7 @@ where } #[inline] - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { self.inner.statement() } @@ -498,7 +498,7 @@ where /// Execute a single SQL query as a prepared statement (explicitly created). pub fn query_statement<'q, DB>( - statement: &'q DB::Statement<'q>, + statement: &'q DB::Statement, ) -> Query<'q, DB, ::Arguments<'_>> where DB: Database, @@ -513,7 +513,7 @@ where /// Execute a single SQL query as a prepared statement (explicitly created), with the given arguments. pub fn query_statement_with<'q, DB, A>( - statement: &'q DB::Statement<'q>, + statement: &'q DB::Statement, arguments: A, ) -> Query<'q, DB, A> where diff --git a/sqlx-core/src/query_as.rs b/sqlx-core/src/query_as.rs index 9f28fe41e9..afe56293bc 100644 --- a/sqlx-core/src/query_as.rs +++ b/sqlx-core/src/query_as.rs @@ -32,7 +32,7 @@ where } #[inline] - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { self.inner.statement() } @@ -385,7 +385,7 @@ where // Make a SQL query from a statement, that is mapped to a concrete type. pub fn query_statement_as<'q, DB, O>( - statement: &'q DB::Statement<'q>, + statement: &'q DB::Statement, ) -> QueryAs<'q, DB, O, ::Arguments<'_>> where DB: Database, @@ -399,7 +399,7 @@ where // Make a SQL query from a statement, with the given arguments, that is mapped to a concrete type. pub fn query_statement_as_with<'q, DB, O, A>( - statement: &'q DB::Statement<'q>, + statement: &'q DB::Statement, arguments: A, ) -> QueryAs<'q, DB, O, A> where diff --git a/sqlx-core/src/query_scalar.rs b/sqlx-core/src/query_scalar.rs index c131adcca3..8179d18fee 100644 --- a/sqlx-core/src/query_scalar.rs +++ b/sqlx-core/src/query_scalar.rs @@ -29,7 +29,7 @@ where self.inner.sql() } - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { self.inner.statement() } @@ -364,7 +364,7 @@ where // Make a SQL query from a statement, that is mapped to a concrete value. pub fn query_statement_scalar<'q, DB, O>( - statement: &'q DB::Statement<'q>, + statement: &'q DB::Statement, ) -> QueryScalar<'q, DB, O, ::Arguments<'_>> where DB: Database, @@ -377,7 +377,7 @@ where // Make a SQL query from a statement, with the given arguments, that is mapped to a concrete value. pub fn query_statement_scalar_with<'q, DB, O, A>( - statement: &'q DB::Statement<'q>, + statement: &'q DB::Statement, arguments: A, ) -> QueryScalar<'q, DB, O, A> where diff --git a/sqlx-core/src/raw_sql.rs b/sqlx-core/src/raw_sql.rs index 37627d4453..6108e8742e 100644 --- a/sqlx-core/src/raw_sql.rs +++ b/sqlx-core/src/raw_sql.rs @@ -123,7 +123,7 @@ impl<'q, DB: Database> Execute<'q, DB> for RawSql<'q> { self.0 } - fn statement(&self) -> Option<&::Statement<'q>> { + fn statement(&self) -> Option<&::Statement> { None } diff --git a/sqlx-core/src/sql_str.rs b/sqlx-core/src/sql_str.rs index 58ed23d6d3..1308e6f049 100644 --- a/sqlx-core/src/sql_str.rs +++ b/sqlx-core/src/sql_str.rs @@ -28,8 +28,8 @@ use std::sync::Arc; /// The recommended way to incorporate dynamic data or user input in a query is to use /// bind parameters, which requires the query to execute as a prepared statement. /// See [`query()`] for details. -/// -/// This trait and [`AssertSqlSafe`] are intentionally analogous to +/// +/// This trait and [`AssertSqlSafe`] are intentionally analogous to /// [`std::panic::UnwindSafe`] and [`std::panic::AssertUnwindSafe`], respectively. /// /// [injection]: https://en.wikipedia.org/wiki/SQL_injection @@ -66,7 +66,7 @@ impl SqlSafeStr for &'static str { pub struct AssertSqlSafe(pub T); /// Note: copies the string. -/// +/// /// It is recommended to pass one of the supported owned string types instead. impl<'a> SqlSafeStr for AssertSqlSafe<&'a str> { #[inline] @@ -92,7 +92,7 @@ impl SqlSafeStr for AssertSqlSafe> { impl SqlSafeStr for AssertSqlSafe> { #[inline] fn into_sql_str(self) -> SqlStr { - SqlStr(Repr::Arced(self.into())) + SqlStr(Repr::Arced(self.0)) } } @@ -139,7 +139,7 @@ impl SqlStr { pub(crate) fn from_arc_string(arc: Arc) -> Self { SqlStr(Repr::ArcString(arc)) } - + /// Borrow the inner query string. #[inline] pub fn as_str(&self) -> &str { diff --git a/sqlx-core/src/statement.rs b/sqlx-core/src/statement.rs index 17dfd6428d..8940e3c6cd 100644 --- a/sqlx-core/src/statement.rs +++ b/sqlx-core/src/statement.rs @@ -16,12 +16,12 @@ use either::Either; /// /// Statements can be re-used with any connection and on first-use it will be re-prepared and /// cached within the connection. -pub trait Statement<'q>: Send + Sync { +pub trait Statement: Send + Sync { type Database: Database; /// Creates an owned statement from this statement reference. This copies /// the original SQL text. - fn to_owned(&self) -> ::Statement<'static>; + fn to_owned(&self) -> ::Statement; /// Get the original SQL text used to create this statement. fn sql(&self) -> &str; diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index 19b3a6f27c..ccb4d8c341 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -137,7 +137,7 @@ impl AnyConnectionBackend for MySqlConnection { &'c mut self, sql: &'q str, _parameters: &[AnyTypeInfo], - ) -> BoxFuture<'c, sqlx_core::Result>> { + ) -> BoxFuture<'c, sqlx_core::Result> { Box::pin(async move { let statement = Executor::prepare_with(self, sql, &[]).await?; AnyStatement::try_from_statement( diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index 4f5af4bf6d..3f44ab72d9 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -22,7 +22,8 @@ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::TryStreamExt; -use std::{borrow::Cow, pin::pin, sync::Arc}; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; +use std::{pin::pin, sync::Arc}; impl MySqlConnection { async fn prepare_statement<'c>( @@ -301,7 +302,7 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { self, sql: &'q str, _parameters: &'e [MySqlTypeInfo], - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result> where 'c: 'e, { @@ -322,7 +323,7 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { }; Ok(MySqlStatement { - sql: Cow::Borrowed(sql), + sql: AssertSqlSafe(sql).into_sql_str(), // metadata has internal Arcs for expensive data structures metadata: metadata.clone(), }) diff --git a/sqlx-mysql/src/database.rs b/sqlx-mysql/src/database.rs index d03a567284..0e3f51f532 100644 --- a/sqlx-mysql/src/database.rs +++ b/sqlx-mysql/src/database.rs @@ -28,7 +28,7 @@ impl Database for MySql { type Arguments<'q> = MySqlArguments; type ArgumentBuffer<'q> = Vec; - type Statement<'q> = MySqlStatement<'q>; + type Statement = MySqlStatement; const NAME: &'static str = "MySQL"; diff --git a/sqlx-mysql/src/statement.rs b/sqlx-mysql/src/statement.rs index e9578403e1..1f32299467 100644 --- a/sqlx-mysql/src/statement.rs +++ b/sqlx-mysql/src/statement.rs @@ -5,14 +5,14 @@ use crate::ext::ustr::UStr; use crate::HashMap; use crate::{MySql, MySqlArguments, MySqlTypeInfo}; use either::Either; -use std::borrow::Cow; +use sqlx_core::sql_str::SqlStr; use std::sync::Arc; pub(crate) use sqlx_core::statement::*; #[derive(Debug, Clone)] -pub struct MySqlStatement<'q> { - pub(crate) sql: Cow<'q, str>, +pub struct MySqlStatement { + pub(crate) sql: SqlStr, pub(crate) metadata: MySqlStatementMetadata, } @@ -23,18 +23,18 @@ pub(crate) struct MySqlStatementMetadata { pub(crate) parameters: usize, } -impl<'q> Statement<'q> for MySqlStatement<'q> { +impl Statement for MySqlStatement { type Database = MySql; - fn to_owned(&self) -> MySqlStatement<'static> { - MySqlStatement::<'static> { - sql: Cow::Owned(self.sql.clone().into_owned()), + fn to_owned(&self) -> MySqlStatement { + MySqlStatement { + sql: self.sql.clone(), metadata: self.metadata.clone(), } } fn sql(&self) -> &str { - &self.sql + &self.sql.as_str() } fn parameters(&self) -> Option> { @@ -48,8 +48,8 @@ impl<'q> Statement<'q> for MySqlStatement<'q> { impl_statement_query!(MySqlArguments); } -impl ColumnIndex> for &'_ str { - fn index(&self, statement: &MySqlStatement<'_>) -> Result { +impl ColumnIndex for &'_ str { + fn index(&self, statement: &MySqlStatement) -> Result { statement .metadata .column_names diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index 762f53e5df..b516d47d71 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -137,7 +137,7 @@ impl AnyConnectionBackend for PgConnection { &'c mut self, sql: &'q str, _parameters: &[AnyTypeInfo], - ) -> BoxFuture<'c, sqlx_core::Result>> { + ) -> BoxFuture<'c, sqlx_core::Result> { Box::pin(async move { let statement = Executor::prepare_with(self, sql, &[]).await?; AnyStatement::try_from_statement( diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs index 076c4209f6..b51445afff 100644 --- a/sqlx-postgres/src/connection/executor.rs +++ b/sqlx-postgres/src/connection/executor.rs @@ -17,8 +17,9 @@ use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::TryStreamExt; use sqlx_core::arguments::Arguments; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; use sqlx_core::Either; -use std::{borrow::Cow, pin::pin, sync::Arc}; +use std::{pin::pin, sync::Arc}; async fn prepare( conn: &mut PgConnection, @@ -440,7 +441,7 @@ impl<'c> Executor<'c> for &'c mut PgConnection { self, sql: &'q str, parameters: &'e [PgTypeInfo], - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result> where 'c: 'e, { @@ -450,7 +451,7 @@ impl<'c> Executor<'c> for &'c mut PgConnection { let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?; Ok(PgStatement { - sql: Cow::Borrowed(sql), + sql: AssertSqlSafe(sql).into_sql_str(), metadata, }) }) diff --git a/sqlx-postgres/src/database.rs b/sqlx-postgres/src/database.rs index 876e295899..fbc762615b 100644 --- a/sqlx-postgres/src/database.rs +++ b/sqlx-postgres/src/database.rs @@ -30,7 +30,7 @@ impl Database for Postgres { type Arguments<'q> = PgArguments; type ArgumentBuffer<'q> = PgArgumentBuffer; - type Statement<'q> = PgStatement<'q>; + type Statement = PgStatement; const NAME: &'static str = "PostgreSQL"; diff --git a/sqlx-postgres/src/listener.rs b/sqlx-postgres/src/listener.rs index 17a46a916f..c475e5bb0d 100644 --- a/sqlx-postgres/src/listener.rs +++ b/sqlx-postgres/src/listener.rs @@ -421,7 +421,7 @@ impl<'c> Executor<'c> for &'c mut PgListener { self, query: &'q str, parameters: &'e [PgTypeInfo], - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result> where 'c: 'e, { diff --git a/sqlx-postgres/src/statement.rs b/sqlx-postgres/src/statement.rs index abd553af30..c671ae3dfa 100644 --- a/sqlx-postgres/src/statement.rs +++ b/sqlx-postgres/src/statement.rs @@ -3,15 +3,15 @@ use crate::column::ColumnIndex; use crate::error::Error; use crate::ext::ustr::UStr; use crate::{PgArguments, Postgres}; -use std::borrow::Cow; use std::sync::Arc; +use sqlx_core::sql_str::SqlStr; pub(crate) use sqlx_core::statement::Statement; use sqlx_core::{Either, HashMap}; #[derive(Debug, Clone)] -pub struct PgStatement<'q> { - pub(crate) sql: Cow<'q, str>, +pub struct PgStatement { + pub(crate) sql: SqlStr, pub(crate) metadata: Arc, } @@ -24,18 +24,18 @@ pub(crate) struct PgStatementMetadata { pub(crate) parameters: Vec, } -impl<'q> Statement<'q> for PgStatement<'q> { +impl Statement for PgStatement { type Database = Postgres; - fn to_owned(&self) -> PgStatement<'static> { - PgStatement::<'static> { - sql: Cow::Owned(self.sql.clone().into_owned()), + fn to_owned(&self) -> PgStatement { + PgStatement { + sql: self.sql.clone(), metadata: self.metadata.clone(), } } fn sql(&self) -> &str { - &self.sql + &self.sql.as_str() } fn parameters(&self) -> Option> { @@ -49,8 +49,8 @@ impl<'q> Statement<'q> for PgStatement<'q> { impl_statement_query!(PgArguments); } -impl ColumnIndex> for &'_ str { - fn index(&self, statement: &PgStatement<'_>) -> Result { +impl ColumnIndex for &'_ str { + fn index(&self, statement: &PgStatement) -> Result { statement .metadata .column_names diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index c72370d0ff..5954444105 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -134,7 +134,7 @@ impl AnyConnectionBackend for SqliteConnection { &'c mut self, sql: &'q str, _parameters: &[AnyTypeInfo], - ) -> BoxFuture<'c, sqlx_core::Result>> { + ) -> BoxFuture<'c, sqlx_core::Result> { Box::pin(async move { let statement = Executor::prepare_with(self, sql, &[]).await?; AnyStatement::try_from_statement(sql, &statement, statement.column_names.clone()) diff --git a/sqlx-sqlite/src/connection/executor.rs b/sqlx-sqlite/src/connection/executor.rs index 1f6ce7726f..c0151d228f 100644 --- a/sqlx-sqlite/src/connection/executor.rs +++ b/sqlx-sqlite/src/connection/executor.rs @@ -7,6 +7,7 @@ use futures_util::{stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use sqlx_core::describe::Describe; use sqlx_core::error::Error; use sqlx_core::executor::{Execute, Executor}; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; use sqlx_core::Either; use std::{future, pin::pin}; @@ -76,7 +77,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { self, sql: &'q str, _parameters: &[SqliteTypeInfo], - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result> where 'c: 'e, { @@ -84,7 +85,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { let statement = self.worker.prepare(sql).await?; Ok(SqliteStatement { - sql: sql.into(), + sql: AssertSqlSafe(sql).into_sql_str(), ..statement }) }) diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index 00a4c2999c..9f8b6614af 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -1,4 +1,3 @@ -use std::borrow::Cow; use std::future::Future; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -6,6 +5,7 @@ use std::thread; use futures_channel::oneshot; use futures_intrusive::sync::{Mutex, MutexGuard}; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; use tracing::span::Span; use sqlx_core::describe::Describe; @@ -54,7 +54,7 @@ impl WorkerSharedState { enum Command { Prepare { query: Box, - tx: oneshot::Sender, Error>>, + tx: oneshot::Sender>, }, Describe { query: Box, @@ -335,7 +335,7 @@ impl ConnectionWorker { establish_rx.await.map_err(|_| Error::WorkerCrashed)? } - pub(crate) async fn prepare(&mut self, query: &str) -> Result, Error> { + pub(crate) async fn prepare(&mut self, query: &str) -> Result { self.oneshot_cmd(|tx| Command::Prepare { query: query.into(), tx, @@ -495,7 +495,7 @@ impl ConnectionWorker { } } -fn prepare(conn: &mut ConnectionState, query: &str) -> Result, Error> { +fn prepare(conn: &mut ConnectionState, query: &str) -> Result { // prepare statement object (or checkout from cache) let statement = conn.statements.get(query, true)?; @@ -514,7 +514,7 @@ fn prepare(conn: &mut ConnectionState, query: &str) -> Result = SqliteArguments<'q>; type ArgumentBuffer<'q> = Vec>; - type Statement<'q> = SqliteStatement<'q>; + type Statement = SqliteStatement; const NAME: &'static str = "SQLite"; diff --git a/sqlx-sqlite/src/statement/mod.rs b/sqlx-sqlite/src/statement/mod.rs index 179b8eeaf7..515baf19ba 100644 --- a/sqlx-sqlite/src/statement/mod.rs +++ b/sqlx-sqlite/src/statement/mod.rs @@ -2,8 +2,8 @@ use crate::column::ColumnIndex; use crate::error::Error; use crate::ext::ustr::UStr; use crate::{Sqlite, SqliteArguments, SqliteColumn, SqliteTypeInfo}; +use sqlx_core::sql_str::SqlStr; use sqlx_core::{Either, HashMap}; -use std::borrow::Cow; use std::sync::Arc; pub(crate) use sqlx_core::statement::*; @@ -17,19 +17,19 @@ pub(crate) use r#virtual::VirtualStatement; #[derive(Debug, Clone)] #[allow(clippy::rc_buffer)] -pub struct SqliteStatement<'q> { - pub(crate) sql: Cow<'q, str>, +pub struct SqliteStatement { + pub(crate) sql: SqlStr, pub(crate) parameters: usize, pub(crate) columns: Arc>, pub(crate) column_names: Arc>, } -impl<'q> Statement<'q> for SqliteStatement<'q> { +impl Statement for SqliteStatement { type Database = Sqlite; - fn to_owned(&self) -> SqliteStatement<'static> { - SqliteStatement::<'static> { - sql: Cow::Owned(self.sql.clone().into_owned()), + fn to_owned(&self) -> SqliteStatement { + SqliteStatement { + sql: self.sql.clone(), parameters: self.parameters, columns: Arc::clone(&self.columns), column_names: Arc::clone(&self.column_names), @@ -37,7 +37,7 @@ impl<'q> Statement<'q> for SqliteStatement<'q> { } fn sql(&self) -> &str { - &self.sql + &self.sql.as_str() } fn parameters(&self) -> Option> { @@ -51,8 +51,8 @@ impl<'q> Statement<'q> for SqliteStatement<'q> { impl_statement_query!(SqliteArguments<'_>); } -impl ColumnIndex> for &'_ str { - fn index(&self, statement: &SqliteStatement<'_>) -> Result { +impl ColumnIndex for &'_ str { + fn index(&self, statement: &SqliteStatement) -> Result { statement .column_names .get(*self) From 6dd8fd6c75ddd9fd35a0f1cc463ffbd0edcec599 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Sat, 1 Feb 2025 16:11:57 +0100 Subject: [PATCH 04/12] Update the definition of Executor and AnyConnectionBackend + update Postgres driver --- sqlx-core/src/any/connection/backend.rs | 9 +++--- sqlx-core/src/any/connection/executor.rs | 20 ++++++------ sqlx-core/src/any/statement.rs | 10 +++--- sqlx-core/src/executor.rs | 39 ++++++++++++---------- sqlx-core/src/logger.rs | 28 ++++++++-------- sqlx-core/src/migrate/migration.rs | 11 +++++-- sqlx-core/src/migrate/source.rs | 3 +- sqlx-core/src/pool/executor.rs | 21 ++++++++---- sqlx-core/src/query.rs | 17 +++++----- sqlx-core/src/query_as.rs | 11 ++++--- sqlx-core/src/query_builder.rs | 24 ++++++++------ sqlx-core/src/query_scalar.rs | 12 ++++--- sqlx-core/src/raw_sql.rs | 13 ++++---- sqlx-core/src/sql_str.rs | 21 +++++++++--- sqlx-core/src/statement.rs | 3 +- sqlx-core/src/transaction.rs | 21 ++++++------ sqlx-macros-core/src/database/mod.rs | 5 ++- sqlx-postgres/src/any.rs | 14 +++++--- sqlx-postgres/src/connection/describe.rs | 3 +- sqlx-postgres/src/connection/executor.rs | 41 ++++++++++++------------ sqlx-postgres/src/listener.rs | 22 ++++++++----- sqlx-postgres/src/migrate.rs | 17 +++++----- sqlx-postgres/src/statement.rs | 4 +-- sqlx-postgres/src/testing/mod.rs | 11 ++++--- sqlx-postgres/src/transaction.rs | 15 +++++---- sqlx-test/src/lib.rs | 9 ++++-- tests/any/any.rs | 5 +-- tests/any/pool.rs | 3 +- tests/postgres/derives.rs | 3 +- tests/postgres/postgres.rs | 22 +++++++++---- tests/postgres/query_builder.rs | 29 ++++++++++------- tests/postgres/rustsec.rs | 3 +- 32 files changed, 282 insertions(+), 187 deletions(-) diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index 643a32f8f7..5ee6b1f8fa 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -1,5 +1,6 @@ use crate::any::{Any, AnyArguments, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo}; use crate::describe::Describe; +use crate::sql_str::SqlStr; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; @@ -96,23 +97,23 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { fn fetch_many<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxStream<'q, crate::Result>>; fn fetch_optional<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxFuture<'q, crate::Result>>; fn prepare_with<'c, 'q: 'c>( &'c mut self, - sql: &'q str, + sql: SqlStr, parameters: &[AnyTypeInfo], ) -> BoxFuture<'c, crate::Result>; - fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, crate::Result>>; + fn describe<'q>(&'q mut self, sql: SqlStr) -> BoxFuture<'q, crate::Result>>; } diff --git a/sqlx-core/src/any/connection/executor.rs b/sqlx-core/src/any/connection/executor.rs index 3a941194bb..2f147f65bc 100644 --- a/sqlx-core/src/any/connection/executor.rs +++ b/sqlx-core/src/any/connection/executor.rs @@ -2,6 +2,7 @@ use crate::any::{Any, AnyConnection, AnyQueryResult, AnyRow, AnyStatement, AnyTy use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; +use crate::sql_str::SqlSafeStr; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; @@ -23,8 +24,8 @@ impl<'c> Executor<'c> for &'c mut AnyConnection { Ok(arguments) => arguments, Err(error) => return stream::once(future::ready(Err(error))).boxed(), }; - self.backend - .fetch_many(query.sql(), query.persistent(), arguments) + let persistent = query.persistent(); + self.backend.fetch_many(query.sql(), persistent, arguments) } fn fetch_optional<'e, 'q: 'e, E>( @@ -39,28 +40,29 @@ impl<'c> Executor<'c> for &'c mut AnyConnection { Ok(arguments) => arguments, Err(error) => return future::ready(Err(error)).boxed(), }; + let persistent = query.persistent(); self.backend - .fetch_optional(query.sql(), query.persistent(), arguments) + .fetch_optional(query.sql(), persistent, arguments) } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, parameters: &[AnyTypeInfo], ) -> BoxFuture<'e, Result> where 'c: 'e, { - self.backend.prepare_with(sql, parameters) + self.backend.prepare_with(sql.into_sql_str(), parameters) } - fn describe<'e, 'q: 'e>( + fn describe<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { - self.backend.describe(sql) + self.backend.describe(sql.into_sql_str()) } } diff --git a/sqlx-core/src/any/statement.rs b/sqlx-core/src/any/statement.rs index f42d121e1b..d92980a286 100644 --- a/sqlx-core/src/any/statement.rs +++ b/sqlx-core/src/any/statement.rs @@ -3,7 +3,7 @@ use crate::column::ColumnIndex; use crate::database::Database; use crate::error::Error; use crate::ext::ustr::UStr; -use crate::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr}; +use crate::sql_str::SqlStr; use crate::statement::Statement; use crate::HashMap; use either::Either; @@ -32,8 +32,8 @@ impl Statement for AnyStatement { } } - fn sql(&self) -> &str { - &self.sql.as_str() + fn sql(&self) -> SqlStr { + self.sql.clone() } fn parameters(&self) -> Option> { @@ -64,7 +64,7 @@ impl<'i> ColumnIndex for &'i str { impl<'q> AnyStatement { #[doc(hidden)] pub fn try_from_statement( - query: &'q str, + query: SqlStr, statement: &S, column_names: Arc>, ) -> crate::Result @@ -91,7 +91,7 @@ impl<'q> AnyStatement { .collect::, _>>()?; Ok(Self { - sql: AssertSqlSafe(query).into_sql_str(), + sql: query, columns, column_names, parameters, diff --git a/sqlx-core/src/executor.rs b/sqlx-core/src/executor.rs index dde7810932..a66c127789 100644 --- a/sqlx-core/src/executor.rs +++ b/sqlx-core/src/executor.rs @@ -1,6 +1,7 @@ use crate::database::Database; use crate::describe::Describe; use crate::error::{BoxDynError, Error}; +use crate::sql_str::{SqlSafeStr, SqlStr}; use either::Either; use futures_core::future::BoxFuture; @@ -146,9 +147,9 @@ pub trait Executor<'c>: Send + Debug + Sized { /// This explicit API is provided to allow access to the statement metadata available after /// it prepared but before the first row is returned. #[inline] - fn prepare<'e, 'q: 'e>( + fn prepare<'e>( self, - query: &'q str, + query: impl SqlSafeStr, ) -> BoxFuture<'e, Result<::Statement, Error>> where 'c: 'e, @@ -161,9 +162,9 @@ pub trait Executor<'c>: Send + Debug + Sized { /// /// Only some database drivers (PostgreSQL, MSSQL) can take advantage of /// this extra information to influence parameter type inference. - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, parameters: &'e [::TypeInfo], ) -> BoxFuture<'e, Result<::Statement, Error>> where @@ -175,9 +176,9 @@ pub trait Executor<'c>: Send + Debug + Sized { /// This is used by compile-time verification in the query macros to /// power their type inference. #[doc(hidden)] - fn describe<'e, 'q: 'e>( + fn describe<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, ) -> BoxFuture<'e, Result, Error>> where 'c: 'e; @@ -192,7 +193,7 @@ pub trait Executor<'c>: Send + Debug + Sized { /// pub trait Execute<'q, DB: Database>: Send + Sized { /// Gets the SQL that will be executed. - fn sql(&self) -> &'q str; + fn sql(self) -> SqlStr; /// Gets the previously cached statement, if available. fn statement(&self) -> Option<&DB::Statement>; @@ -210,12 +211,13 @@ pub trait Execute<'q, DB: Database>: Send + Sized { fn persistent(&self) -> bool; } -// NOTE: `Execute` is explicitly not implemented for String and &String to make it slightly more -// involved to write `conn.execute(format!("SELECT {val}"))` -impl<'q, DB: Database> Execute<'q, DB> for &'q str { +impl<'q, DB: Database, T> Execute<'q, DB> for (T, Option<::Arguments<'q>>) +where + T: SqlSafeStr + Send, +{ #[inline] - fn sql(&self) -> &'q str { - self + fn sql(self) -> SqlStr { + self.0.into_sql_str() } #[inline] @@ -225,7 +227,7 @@ impl<'q, DB: Database> Execute<'q, DB> for &'q str { #[inline] fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { - Ok(None) + Ok(self.1.take()) } #[inline] @@ -234,10 +236,13 @@ impl<'q, DB: Database> Execute<'q, DB> for &'q str { } } -impl<'q, DB: Database> Execute<'q, DB> for (&'q str, Option<::Arguments<'q>>) { +impl<'q, DB: Database, T> Execute<'q, DB> for T +where + T: SqlSafeStr + Send, +{ #[inline] - fn sql(&self) -> &'q str { - self.0 + fn sql(self) -> SqlStr { + self.into_sql_str() } #[inline] @@ -247,7 +252,7 @@ impl<'q, DB: Database> Execute<'q, DB> for (&'q str, Option<::Ar #[inline] fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { - Ok(self.1.take()) + Ok(None) } #[inline] diff --git a/sqlx-core/src/logger.rs b/sqlx-core/src/logger.rs index cf6dd533bd..7cbbff0077 100644 --- a/sqlx-core/src/logger.rs +++ b/sqlx-core/src/logger.rs @@ -1,4 +1,4 @@ -use crate::connection::LogSettings; +use crate::{connection::LogSettings, sql_str::SqlStr}; use std::time::Instant; // Yes these look silly. `tracing` doesn't currently support dynamic levels @@ -60,16 +60,16 @@ pub(crate) fn private_level_filter_to_trace_level( private_level_filter_to_levels(filter).map(|(level, _)| level) } -pub struct QueryLogger<'q> { - sql: &'q str, +pub struct QueryLogger { + sql: SqlStr, rows_returned: u64, rows_affected: u64, start: Instant, settings: LogSettings, } -impl<'q> QueryLogger<'q> { - pub fn new(sql: &'q str, settings: LogSettings) -> Self { +impl QueryLogger { + pub fn new(sql: SqlStr, settings: LogSettings) -> Self { Self { sql, rows_returned: 0, @@ -104,18 +104,18 @@ impl<'q> QueryLogger<'q> { let log_is_enabled = log::log_enabled!(target: "sqlx::query", log_level) || private_tracing_dynamic_enabled!(target: "sqlx::query", tracing_level); if log_is_enabled { - let mut summary = parse_query_summary(self.sql); + let mut summary = parse_query_summary(self.sql.as_str()); - let sql = if summary != self.sql { + let sql = if summary != self.sql.as_str() { summary.push_str(" …"); format!( "\n\n{}\n", - self.sql /* - sqlformat::format( - self.sql, - &sqlformat::QueryParams::None, - sqlformat::FormatOptions::default() - )*/ + self.sql.as_str() /* + sqlformat::format( + self.sql, + &sqlformat::QueryParams::None, + sqlformat::FormatOptions::default() + )*/ ) } else { String::new() @@ -158,7 +158,7 @@ impl<'q> QueryLogger<'q> { } } -impl<'q> Drop for QueryLogger<'q> { +impl Drop for QueryLogger { fn drop(&mut self) { self.finish(); } diff --git a/sqlx-core/src/migrate/migration.rs b/sqlx-core/src/migrate/migration.rs index 9bd7f569d8..4c57769a2a 100644 --- a/sqlx-core/src/migrate/migration.rs +++ b/sqlx-core/src/migrate/migration.rs @@ -2,6 +2,8 @@ use std::borrow::Cow; use sha2::{Digest, Sha384}; +use crate::sql_str::{SqlSafeStr, SqlStr}; + use super::MigrationType; #[derive(Debug, Clone)] @@ -9,7 +11,7 @@ pub struct Migration { pub version: i64, pub description: Cow<'static, str>, pub migration_type: MigrationType, - pub sql: Cow<'static, str>, + pub sql: SqlStr, pub checksum: Cow<'static, [u8]>, pub no_tx: bool, } @@ -19,10 +21,13 @@ impl Migration { version: i64, description: Cow<'static, str>, migration_type: MigrationType, - sql: Cow<'static, str>, + sql: impl SqlSafeStr, no_tx: bool, ) -> Self { - let checksum = Cow::Owned(Vec::from(Sha384::digest(sql.as_bytes()).as_slice())); + let sql = sql.into_sql_str(); + let checksum = Cow::Owned(Vec::from( + Sha384::digest(sql.as_str().as_bytes()).as_slice(), + )); Migration { version, diff --git a/sqlx-core/src/migrate/source.rs b/sqlx-core/src/migrate/source.rs index d0c23b43cd..7e1ef142c8 100644 --- a/sqlx-core/src/migrate/source.rs +++ b/sqlx-core/src/migrate/source.rs @@ -1,5 +1,6 @@ use crate::error::BoxDynError; use crate::migrate::{Migration, MigrationType}; +use crate::sql_str::AssertSqlSafe; use futures_core::future::BoxFuture; use std::borrow::Cow; @@ -131,7 +132,7 @@ pub fn resolve_blocking(path: &Path) -> Result, Resolv version, Cow::Owned(description), migration_type, - Cow::Owned(sql), + AssertSqlSafe(sql), no_tx, ), entry_path, diff --git a/sqlx-core/src/pool/executor.rs b/sqlx-core/src/pool/executor.rs index 56527f59cc..f0c1a31601 100644 --- a/sqlx-core/src/pool/executor.rs +++ b/sqlx-core/src/pool/executor.rs @@ -8,6 +8,7 @@ use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::pool::Pool; +use crate::sql_str::SqlSafeStr; impl<'p, DB: Database> Executor<'p> for &'_ Pool where @@ -48,22 +49,30 @@ where Box::pin(async move { pool.acquire().await?.fetch_optional(query).await }) } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, parameters: &'e [::TypeInfo], - ) -> BoxFuture<'e, Result<::Statement, Error>> { + ) -> BoxFuture<'e, Result<::Statement, Error>> + where + 'p: 'e, + { let pool = self.clone(); + let sql = sql.into_sql_str(); Box::pin(async move { pool.acquire().await?.prepare_with(sql, parameters).await }) } #[doc(hidden)] - fn describe<'e, 'q: 'e>( + fn describe<'e>( self, - sql: &'q str, - ) -> BoxFuture<'e, Result, Error>> { + sql: impl SqlSafeStr, + ) -> BoxFuture<'e, Result, Error>> + where + 'p: 'e, + { let pool = self.clone(); + let sql = sql.into_sql_str(); Box::pin(async move { pool.acquire().await?.describe(sql).await }) } diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index 080e47e04b..6a13babdfe 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -9,13 +9,14 @@ use crate::database::{Database, HasStatementCache}; use crate::encode::Encode; use crate::error::{BoxDynError, Error}; use crate::executor::{Execute, Executor}; +use crate::sql_str::{SqlSafeStr, SqlStr}; use crate::statement::Statement; use crate::types::Type; /// A single SQL query as a prepared statement. Returned by [`query()`]. #[must_use = "query must be executed to affect database"] pub struct Query<'q, DB: Database, A> { - pub(crate) statement: Either<&'q str, &'q DB::Statement>, + pub(crate) statement: Either, pub(crate) arguments: Option>, pub(crate) database: PhantomData, pub(crate) persistent: bool, @@ -44,7 +45,7 @@ where A: Send + IntoArguments<'q, DB>, { #[inline] - fn sql(&self) -> &'q str { + fn sql(self) -> SqlStr { match self.statement { Either::Right(statement) => statement.sql(), Either::Left(sql) => sql, @@ -303,7 +304,7 @@ where A: IntoArguments<'q, DB>, { #[inline] - fn sql(&self) -> &'q str { + fn sql(self) -> SqlStr { self.inner.sql() } @@ -652,14 +653,14 @@ where /// /// As an additional benefit, query parameters are usually sent in a compact binary encoding instead of a human-readable /// text encoding, which saves bandwidth. -pub fn query(sql: &str) -> Query<'_, DB, ::Arguments<'_>> +pub fn query<'a, DB>(sql: impl SqlSafeStr) -> Query<'a, DB, ::Arguments<'a>> where DB: Database, { Query { database: PhantomData, arguments: Some(Ok(Default::default())), - statement: Either::Left(sql), + statement: Either::Left(sql.into_sql_str()), persistent: true, } } @@ -667,7 +668,7 @@ where /// Execute a SQL query as a prepared statement (transparently cached), with the given arguments. /// /// See [`query()`][query] for details, such as supported syntax. -pub fn query_with<'q, DB, A>(sql: &'q str, arguments: A) -> Query<'q, DB, A> +pub fn query_with<'q, DB, A>(sql: impl SqlSafeStr, arguments: A) -> Query<'q, DB, A> where DB: Database, A: IntoArguments<'q, DB>, @@ -677,7 +678,7 @@ where /// Same as [`query_with`] but is initialized with a Result of arguments instead pub fn query_with_result<'q, DB, A>( - sql: &'q str, + sql: impl SqlSafeStr, arguments: Result, ) -> Query<'q, DB, A> where @@ -687,7 +688,7 @@ where Query { database: PhantomData, arguments: Some(arguments), - statement: Either::Left(sql), + statement: Either::Left(sql.into_sql_str()), persistent: true, } } diff --git a/sqlx-core/src/query_as.rs b/sqlx-core/src/query_as.rs index afe56293bc..2465472f44 100644 --- a/sqlx-core/src/query_as.rs +++ b/sqlx-core/src/query_as.rs @@ -11,6 +11,7 @@ use crate::error::{BoxDynError, Error}; use crate::executor::{Execute, Executor}; use crate::from_row::FromRow; use crate::query::{query, query_statement, query_statement_with, query_with_result, Query}; +use crate::sql_str::{SqlSafeStr, SqlStr}; use crate::types::Type; /// A single SQL query as a prepared statement, mapping results using [`FromRow`]. @@ -27,7 +28,7 @@ where A: 'q + IntoArguments<'q, DB>, { #[inline] - fn sql(&self) -> &'q str { + fn sql(self) -> SqlStr { self.inner.sql() } @@ -339,7 +340,9 @@ where /// /// ``` #[inline] -pub fn query_as<'q, DB, O>(sql: &'q str) -> QueryAs<'q, DB, O, ::Arguments<'q>> +pub fn query_as<'q, DB, O>( + sql: impl SqlSafeStr, +) -> QueryAs<'q, DB, O, ::Arguments<'q>> where DB: Database, O: for<'r> FromRow<'r, DB::Row>, @@ -357,7 +360,7 @@ where /// /// For details about type mapping from [`FromRow`], see [`query_as()`]. #[inline] -pub fn query_as_with<'q, DB, O, A>(sql: &'q str, arguments: A) -> QueryAs<'q, DB, O, A> +pub fn query_as_with<'q, DB, O, A>(sql: impl SqlSafeStr, arguments: A) -> QueryAs<'q, DB, O, A> where DB: Database, A: IntoArguments<'q, DB>, @@ -369,7 +372,7 @@ where /// Same as [`query_as_with`] but takes arguments as a Result #[inline] pub fn query_as_with_result<'q, DB, O, A>( - sql: &'q str, + sql: impl SqlSafeStr, arguments: Result, ) -> QueryAs<'q, DB, O, A> where diff --git a/sqlx-core/src/query_builder.rs b/sqlx-core/src/query_builder.rs index b242bf7b2a..2ff79e9f2e 100644 --- a/sqlx-core/src/query_builder.rs +++ b/sqlx-core/src/query_builder.rs @@ -3,6 +3,7 @@ use std::fmt::Display; use std::fmt::Write; use std::marker::PhantomData; +use std::sync::Arc; use crate::arguments::{Arguments, IntoArguments}; use crate::database::Database; @@ -11,6 +12,8 @@ use crate::from_row::FromRow; use crate::query::Query; use crate::query_as::QueryAs; use crate::query_scalar::QueryScalar; +use crate::sql_str::AssertSqlSafe; +use crate::sql_str::SqlSafeStr; use crate::types::Type; use crate::Either; @@ -25,7 +28,7 @@ pub struct QueryBuilder<'args, DB> where DB: Database, { - query: String, + query: Arc, init_len: usize, arguments: Option<::Arguments<'args>>, } @@ -34,7 +37,7 @@ impl<'args, DB: Database> Default for QueryBuilder<'args, DB> { fn default() -> Self { QueryBuilder { init_len: 0, - query: String::default(), + query: String::default().into(), arguments: Some(Default::default()), } } @@ -55,7 +58,7 @@ where QueryBuilder { init_len: init.len(), - query: init, + query: init.into(), arguments: Some(Default::default()), } } @@ -73,7 +76,7 @@ where QueryBuilder { init_len: init.len(), - query: init, + query: init.into(), arguments: Some(arguments.into_arguments()), } } @@ -115,8 +118,9 @@ where /// e.g. check that strings aren't too long, numbers are within expected ranges, etc. pub fn push(&mut self, sql: impl Display) -> &mut Self { self.sanity_check(); + let query: &mut String = Arc::get_mut(&mut self.query).expect(""); - write!(self.query, "{sql}").expect("error formatting `sql`"); + write!(query, "{sql}").expect("error formatting `sql`"); self } @@ -157,8 +161,9 @@ where .expect("BUG: Arguments taken already"); arguments.add(value).expect("Failed to add argument"); + let query: &mut String = Arc::get_mut(&mut self.query).expect(""); arguments - .format_placeholder(&mut self.query) + .format_placeholder(query) .expect("error in format_placeholder"); self @@ -454,7 +459,7 @@ where self.sanity_check(); Query { - statement: Either::Left(&self.query), + statement: Either::Left(AssertSqlSafe(self.query.clone()).into_sql_str()), arguments: self.arguments.take().map(Ok), database: PhantomData, persistent: true, @@ -511,7 +516,8 @@ where /// The query is truncated to the initial fragment provided to [`new()`][Self::new] and /// the bind arguments are reset. pub fn reset(&mut self) -> &mut Self { - self.query.truncate(self.init_len); + let query: &mut String = Arc::get_mut(&mut self.query).expect(""); + query.truncate(self.init_len); self.arguments = Some(Default::default()); self @@ -524,7 +530,7 @@ where /// Deconstruct this `QueryBuilder`, returning the built SQL. May not be syntactically correct. pub fn into_sql(self) -> String { - self.query + Arc::into_inner(self.query).unwrap() } } diff --git a/sqlx-core/src/query_scalar.rs b/sqlx-core/src/query_scalar.rs index 8179d18fee..24b418cdea 100644 --- a/sqlx-core/src/query_scalar.rs +++ b/sqlx-core/src/query_scalar.rs @@ -11,6 +11,7 @@ use crate::from_row::FromRow; use crate::query_as::{ query_as, query_as_with_result, query_statement_as, query_statement_as_with, QueryAs, }; +use crate::sql_str::{SqlSafeStr, SqlStr}; use crate::types::Type; /// A single SQL query as a prepared statement which extracts only the first column of each row. @@ -25,7 +26,7 @@ where A: 'q + IntoArguments<'q, DB>, { #[inline] - fn sql(&self) -> &'q str { + fn sql(self) -> SqlStr { self.inner.sql() } @@ -319,7 +320,7 @@ where /// ``` #[inline] pub fn query_scalar<'q, DB, O>( - sql: &'q str, + sql: impl SqlSafeStr, ) -> QueryScalar<'q, DB, O, ::Arguments<'q>> where DB: Database, @@ -337,7 +338,10 @@ where /// /// For details about prepared statements and allowed SQL syntax, see [`query()`][crate::query::query]. #[inline] -pub fn query_scalar_with<'q, DB, O, A>(sql: &'q str, arguments: A) -> QueryScalar<'q, DB, O, A> +pub fn query_scalar_with<'q, DB, O, A>( + sql: impl SqlSafeStr, + arguments: A, +) -> QueryScalar<'q, DB, O, A> where DB: Database, A: IntoArguments<'q, DB>, @@ -349,7 +353,7 @@ where /// Same as [`query_scalar_with`] but takes arguments as Result #[inline] pub fn query_scalar_with_result<'q, DB, O, A>( - sql: &'q str, + sql: impl SqlSafeStr, arguments: Result, ) -> QueryScalar<'q, DB, O, A> where diff --git a/sqlx-core/src/raw_sql.rs b/sqlx-core/src/raw_sql.rs index 6108e8742e..5fc12d7848 100644 --- a/sqlx-core/src/raw_sql.rs +++ b/sqlx-core/src/raw_sql.rs @@ -4,6 +4,7 @@ use futures_core::stream::BoxStream; use crate::database::Database; use crate::error::BoxDynError; use crate::executor::{Execute, Executor}; +use crate::sql_str::{SqlSafeStr, SqlStr}; use crate::Error; // AUTHOR'S NOTE: I was just going to call this API `sql()` and `Sql`, respectively, @@ -15,7 +16,7 @@ use crate::Error; /// One or more raw SQL statements, separated by semicolons (`;`). /// /// See [`raw_sql()`] for details. -pub struct RawSql<'q>(&'q str); +pub struct RawSql(SqlStr); /// Execute one or more statements as raw SQL, separated by semicolons (`;`). /// @@ -114,12 +115,12 @@ pub struct RawSql<'q>(&'q str); /// /// See [MySQL manual, section 13.3.3: Statements That Cause an Implicit Commit](https://dev.mysql.com/doc/refman/8.0/en/implicit-commit.html) for details. /// See also: [MariaDB manual: SQL statements That Cause an Implicit Commit](https://mariadb.com/kb/en/sql-statements-that-cause-an-implicit-commit/). -pub fn raw_sql(sql: &str) -> RawSql<'_> { - RawSql(sql) +pub fn raw_sql(sql: impl SqlSafeStr) -> RawSql { + RawSql(sql.into_sql_str()) } -impl<'q, DB: Database> Execute<'q, DB> for RawSql<'q> { - fn sql(&self) -> &'q str { +impl<'q, DB: Database> Execute<'q, DB> for RawSql { + fn sql(self) -> SqlStr { self.0 } @@ -136,7 +137,7 @@ impl<'q, DB: Database> Execute<'q, DB> for RawSql<'q> { } } -impl<'q> RawSql<'q> { +impl<'q> RawSql { /// Execute the SQL string and return the total number of rows affected. #[inline] pub async fn execute<'e, E>( diff --git a/sqlx-core/src/sql_str.rs b/sqlx-core/src/sql_str.rs index 1308e6f049..9e53ede4d0 100644 --- a/sqlx-core/src/sql_str.rs +++ b/sqlx-core/src/sql_str.rs @@ -1,6 +1,6 @@ use std::borrow::Borrow; use std::hash::{Hash, Hasher}; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; /// A SQL string that is safe to execute on a database connection. /// @@ -96,6 +96,13 @@ impl SqlSafeStr for AssertSqlSafe> { } } +impl SqlSafeStr for AssertSqlSafe> { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::ArcString(self.0)) + } +} + /// A SQL string that is ready to execute on a database connection. /// /// This is essentially `Cow<'static, str>` but which can be constructed from additional types @@ -118,8 +125,16 @@ enum Repr { ArcString(Arc), } +static COUNT_CLONES: Mutex = Mutex::new(0usize); + impl Clone for SqlStr { fn clone(&self) -> Self { + let mut lock = COUNT_CLONES.lock().unwrap(); + *lock += 1; + let clones: usize = *lock; + drop(lock); + + println!("------- Count clones: {clones} --------\n\n\n"); Self(match &self.0 { Repr::Static(s) => Repr::Static(s), Repr::Arced(s) => Repr::Arced(s.clone()), @@ -136,10 +151,6 @@ impl SqlSafeStr for SqlStr { } impl SqlStr { - pub(crate) fn from_arc_string(arc: Arc) -> Self { - SqlStr(Repr::ArcString(arc)) - } - /// Borrow the inner query string. #[inline] pub fn as_str(&self) -> &str { diff --git a/sqlx-core/src/statement.rs b/sqlx-core/src/statement.rs index 8940e3c6cd..6f8fd95962 100644 --- a/sqlx-core/src/statement.rs +++ b/sqlx-core/src/statement.rs @@ -6,6 +6,7 @@ use crate::from_row::FromRow; use crate::query::Query; use crate::query_as::QueryAs; use crate::query_scalar::QueryScalar; +use crate::sql_str::SqlStr; use either::Either; /// An explicitly prepared statement. @@ -24,7 +25,7 @@ pub trait Statement: Send + Sync { fn to_owned(&self) -> ::Statement; /// Get the original SQL text used to create this statement. - fn sql(&self) -> &str; + fn sql(&self) -> SqlStr; /// Get the expected parameters for this statement. /// diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index 2a84ff6555..c846d410d1 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -1,4 +1,3 @@ -use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use std::ops::{Deref, DerefMut}; @@ -7,6 +6,7 @@ use futures_core::future::BoxFuture; use crate::database::Database; use crate::error::Error; use crate::pool::MaybePoolConnection; +use crate::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr}; /// Generic management of database transactions. /// @@ -274,29 +274,30 @@ where } } -pub fn begin_ansi_transaction_sql(depth: usize) -> Cow<'static, str> { +pub fn begin_ansi_transaction_sql(depth: usize) -> SqlStr { if depth == 0 { - Cow::Borrowed("BEGIN") + "BEGIN".into_sql_str() } else { - Cow::Owned(format!("SAVEPOINT _sqlx_savepoint_{depth}")) + AssertSqlSafe(format!("SAVEPOINT _sqlx_savepoint_{depth}")).into_sql_str() } } -pub fn commit_ansi_transaction_sql(depth: usize) -> Cow<'static, str> { +pub fn commit_ansi_transaction_sql(depth: usize) -> SqlStr { if depth == 1 { - Cow::Borrowed("COMMIT") + "COMMIT".into_sql_str() } else { - Cow::Owned(format!("RELEASE SAVEPOINT _sqlx_savepoint_{}", depth - 1)) + AssertSqlSafe(format!("RELEASE SAVEPOINT _sqlx_savepoint_{}", depth - 1)).into_sql_str() } } -pub fn rollback_ansi_transaction_sql(depth: usize) -> Cow<'static, str> { +pub fn rollback_ansi_transaction_sql(depth: usize) -> SqlStr { if depth == 1 { - Cow::Borrowed("ROLLBACK") + "ROLLBACK".into_sql_str() } else { - Cow::Owned(format!( + AssertSqlSafe(format!( "ROLLBACK TO SAVEPOINT _sqlx_savepoint_{}", depth - 1 )) + .into_sql_str() } } diff --git a/sqlx-macros-core/src/database/mod.rs b/sqlx-macros-core/src/database/mod.rs index a2d0a1fa0d..90c2048386 100644 --- a/sqlx-macros-core/src/database/mod.rs +++ b/sqlx-macros-core/src/database/mod.rs @@ -8,6 +8,8 @@ use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; +use sqlx_core::sql_str::AssertSqlSafe; +use sqlx_core::sql_str::SqlSafeStr; use sqlx_core::type_checking::TypeChecking; #[cfg(any(feature = "postgres", feature = "mysql", feature = "_sqlite"))] @@ -58,7 +60,8 @@ impl CachingDescribeBlocking { } }; - conn.describe(query).await + conn.describe(AssertSqlSafe(query.to_string()).into_sql_str()) + .await }) } } diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index b516d47d71..353cbe6c16 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -7,6 +7,7 @@ use futures_core::stream::BoxStream; use futures_util::{stream, StreamExt, TryFutureExt, TryStreamExt}; use std::borrow::Cow; use std::{future, pin::pin}; +use sqlx_core::sql_str::SqlStr; use sqlx_core::any::{ Any, AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow, @@ -84,7 +85,7 @@ impl AnyConnectionBackend for PgConnection { fn fetch_many<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxStream<'q, sqlx_core::Result>> { @@ -110,7 +111,7 @@ impl AnyConnectionBackend for PgConnection { fn fetch_optional<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxFuture<'q, sqlx_core::Result>> { @@ -135,11 +136,11 @@ impl AnyConnectionBackend for PgConnection { fn prepare_with<'c, 'q: 'c>( &'c mut self, - sql: &'q str, + sql: SqlStr, _parameters: &[AnyTypeInfo], ) -> BoxFuture<'c, sqlx_core::Result> { Box::pin(async move { - let statement = Executor::prepare_with(self, sql, &[]).await?; + let statement = Executor::prepare_with(self, sql.clone(), &[]).await?; AnyStatement::try_from_statement( sql, &statement, @@ -148,7 +149,10 @@ impl AnyConnectionBackend for PgConnection { }) } - fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, sqlx_core::Result>> { + fn describe<'c, 'q>( + &'q mut self, + sql: SqlStr, + ) -> BoxFuture<'q, sqlx_core::Result>> { Box::pin(async move { let describe = Executor::describe(self, sql).await?; diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index a27578c56c..56c11cee94 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -12,6 +12,7 @@ use crate::HashMap; use crate::{PgColumn, PgConnection, PgTypeInfo}; use smallvec::SmallVec; use sqlx_core::query_builder::QueryBuilder; +use sqlx_core::sql_str::AssertSqlSafe; use std::sync::Arc; /// Describes the type of the `pg_type.typtype` column @@ -543,7 +544,7 @@ WHERE rngtypid = $1 } let (Json(explains),): (Json>,) = - query_as(&explain).fetch_one(self).await?; + query_as(AssertSqlSafe(explain)).fetch_one(self).await?; let mut nullables = Vec::new(); diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs index b51445afff..2b6e24cd34 100644 --- a/sqlx-postgres/src/connection/executor.rs +++ b/sqlx-postgres/src/connection/executor.rs @@ -15,9 +15,9 @@ use crate::{ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_core::Stream; -use futures_util::TryStreamExt; +use futures_util::{pin_mut, TryStreamExt}; use sqlx_core::arguments::Arguments; -use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; +use sqlx_core::sql_str::{SqlSafeStr, SqlStr}; use sqlx_core::Either; use std::{pin::pin, sync::Arc}; @@ -193,14 +193,12 @@ impl PgConnection { pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>( &'c mut self, - query: &'q str, + query: SqlStr, arguments: Option, limit: u8, persistent: bool, metadata_opt: Option>, ) -> Result, Error>> + 'e, Error> { - let mut logger = QueryLogger::new(query, self.inner.log_settings.clone()); - // before we continue, wait until we are "ready" to accept more queries self.wait_until_ready().await?; @@ -223,7 +221,7 @@ impl PgConnection { // prepare the statement if this our first time executing it // always return the statement ID here let (statement, metadata_) = self - .get_or_prepare(query, &arguments.types, persistent, metadata_opt) + .get_or_prepare(query.as_str(), &arguments.types, persistent, metadata_opt) .await?; metadata = metadata_; @@ -274,7 +272,7 @@ impl PgConnection { PgValueFormat::Binary } else { // Query will trigger a ReadyForQuery - self.inner.stream.write_msg(Query(query))?; + self.inner.stream.write_msg(Query(query.as_str()))?; self.inner.pending_ready_for_query_count += 1; // metadata starts out as "nothing" @@ -285,6 +283,7 @@ impl PgConnection { }; self.inner.stream.flush().await?; + let mut logger = QueryLogger::new(query, self.inner.log_settings.clone()); Ok(try_stream! { loop { @@ -385,7 +384,6 @@ impl<'c> Executor<'c> for &'c mut PgConnection { 'q: 'e, E: 'q, { - let sql = query.sql(); // False positive: https://github.com/rust-lang/rust-clippy/issues/12560 #[allow(clippy::map_clone)] let metadata = query.statement().map(|s| Arc::clone(&s.metadata)); @@ -394,7 +392,9 @@ impl<'c> Executor<'c> for &'c mut PgConnection { Box::pin(try_stream! { let arguments = arguments?; - let mut s = pin!(self.run(sql, arguments, 0, persistent, metadata).await?); + let sql = query.sql(); + let s = self.run(sql, arguments, 0, persistent, metadata).await?; + pin_mut!(s); while let Some(v) = s.try_next().await? { r#yield!(v); @@ -411,7 +411,6 @@ impl<'c> Executor<'c> for &'c mut PgConnection { 'q: 'e, E: 'q, { - let sql = query.sql(); // False positive: https://github.com/rust-lang/rust-clippy/issues/12560 #[allow(clippy::map_clone)] let metadata = query.statement().map(|s| Arc::clone(&s.metadata)); @@ -419,6 +418,7 @@ impl<'c> Executor<'c> for &'c mut PgConnection { let persistent = query.persistent(); Box::pin(async move { + let sql = query.sql(); let arguments = arguments?; let mut s = pin!(self.run(sql, arguments, 1, persistent, metadata).await?); @@ -437,37 +437,38 @@ impl<'c> Executor<'c> for &'c mut PgConnection { }) } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, parameters: &'e [PgTypeInfo], ) -> BoxFuture<'e, Result> where 'c: 'e, { + let sql = sql.into_sql_str(); Box::pin(async move { self.wait_until_ready().await?; - let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?; + let (_, metadata) = self + .get_or_prepare(sql.as_str(), parameters, true, None) + .await?; - Ok(PgStatement { - sql: AssertSqlSafe(sql).into_sql_str(), - metadata, - }) + Ok(PgStatement { sql, metadata }) }) } - fn describe<'e, 'q: 'e>( + fn describe<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { + let sql = sql.into_sql_str(); Box::pin(async move { self.wait_until_ready().await?; - let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?; + let (stmt_id, metadata) = self.get_or_prepare(sql.as_str(), &[], true, None).await?; let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?; diff --git a/sqlx-postgres/src/listener.rs b/sqlx-postgres/src/listener.rs index c475e5bb0d..0c088f9487 100644 --- a/sqlx-postgres/src/listener.rs +++ b/sqlx-postgres/src/listener.rs @@ -7,6 +7,7 @@ use futures_core::future::BoxFuture; use futures_core::stream::{BoxStream, Stream}; use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use sqlx_core::acquire::Acquire; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; use sqlx_core::transaction::Transaction; use sqlx_core::Either; use tracing::Instrument; @@ -116,7 +117,7 @@ impl PgListener { pub async fn listen(&mut self, channel: &str) -> Result<(), Error> { self.connection() .await? - .execute(&*format!(r#"LISTEN "{}""#, ident(channel))) + .execute(AssertSqlSafe(format!(r#"LISTEN "{}""#, ident(channel)))) .await?; self.channels.push(channel.to_owned()); @@ -133,7 +134,10 @@ impl PgListener { self.channels.extend(channels.into_iter().map(|s| s.into())); let query = build_listen_all_query(&self.channels[beg..]); - self.connection().await?.execute(&*query).await?; + self.connection() + .await? + .execute(AssertSqlSafe(query)) + .await?; Ok(()) } @@ -145,7 +149,7 @@ impl PgListener { // UNLISTEN (we've disconnected anyways) if let Some(connection) = self.connection.as_mut() { connection - .execute(&*format!(r#"UNLISTEN "{}""#, ident(channel))) + .execute(AssertSqlSafe(format!(r#"UNLISTEN "{}""#, ident(channel)))) .await?; } @@ -176,7 +180,7 @@ impl PgListener { connection.inner.stream.notifications = self.buffer_tx.take(); connection - .execute(&*build_listen_all_query(&self.channels)) + .execute(AssertSqlSafe(build_listen_all_query(&self.channels))) .await?; self.connection = Some(connection); @@ -417,14 +421,15 @@ impl<'c> Executor<'c> for &'c mut PgListener { async move { self.connection().await?.fetch_optional(query).await }.boxed() } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - query: &'q str, + query: impl SqlSafeStr, parameters: &'e [PgTypeInfo], ) -> BoxFuture<'e, Result> where 'c: 'e, { + let query = query.into_sql_str(); async move { self.connection() .await? @@ -435,13 +440,14 @@ impl<'c> Executor<'c> for &'c mut PgListener { } #[doc(hidden)] - fn describe<'e, 'q: 'e>( + fn describe<'e>( self, - query: &'q str, + query: impl SqlSafeStr, ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { + let query = query.into_sql_str(); async move { self.connection().await?.describe(query).await }.boxed() } } diff --git a/sqlx-postgres/src/migrate.rs b/sqlx-postgres/src/migrate.rs index c37e92f4d6..af87449e83 100644 --- a/sqlx-postgres/src/migrate.rs +++ b/sqlx-postgres/src/migrate.rs @@ -7,6 +7,7 @@ use futures_core::future::BoxFuture; pub(crate) use sqlx_core::migrate::MigrateError; pub(crate) use sqlx_core::migrate::{AppliedMigration, Migration}; pub(crate) use sqlx_core::migrate::{Migrate, MigrateDatabase}; +use sqlx_core::sql_str::AssertSqlSafe; use crate::connection::{ConnectOptions, Connection}; use crate::error::Error; @@ -45,10 +46,10 @@ impl MigrateDatabase for Postgres { let mut conn = options.connect().await?; let _ = conn - .execute(&*format!( + .execute(AssertSqlSafe(format!( "CREATE DATABASE \"{}\"", database.replace('"', "\"\"") - )) + ))) .await?; Ok(()) @@ -76,10 +77,10 @@ impl MigrateDatabase for Postgres { let mut conn = options.connect().await?; let _ = conn - .execute(&*format!( + .execute(AssertSqlSafe(format!( "DROP DATABASE IF EXISTS \"{}\"", database.replace('"', "\"\"") - )) + ))) .await?; Ok(()) @@ -99,10 +100,10 @@ impl MigrateDatabase for Postgres { let pid_type = if version >= 90200 { "pid" } else { "procpid" }; - conn.execute(&*format!( + conn.execute(AssertSqlSafe(format!( "SELECT pg_terminate_backend(pg_stat_activity.{pid_type}) FROM pg_stat_activity \ WHERE pg_stat_activity.datname = '{database}' AND {pid_type} <> pg_backend_pid()" - )) + ))) .await?; Self::drop_database(url).await @@ -277,7 +278,7 @@ async fn execute_migration( migration: &Migration, ) -> Result<(), MigrateError> { let _ = conn - .execute(&*migration.sql) + .execute(migration.sql.clone()) .await .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; @@ -302,7 +303,7 @@ async fn revert_migration( migration: &Migration, ) -> Result<(), MigrateError> { let _ = conn - .execute(&*migration.sql) + .execute(migration.sql.clone()) .await .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; diff --git a/sqlx-postgres/src/statement.rs b/sqlx-postgres/src/statement.rs index c671ae3dfa..864574781a 100644 --- a/sqlx-postgres/src/statement.rs +++ b/sqlx-postgres/src/statement.rs @@ -34,8 +34,8 @@ impl Statement for PgStatement { } } - fn sql(&self) -> &str { - &self.sql.as_str() + fn sql(&self) -> SqlStr { + self.sql.clone() } fn parameters(&self) -> Option> { diff --git a/sqlx-postgres/src/testing/mod.rs b/sqlx-postgres/src/testing/mod.rs index af20fe87ea..e127812a29 100644 --- a/sqlx-postgres/src/testing/mod.rs +++ b/sqlx-postgres/src/testing/mod.rs @@ -1,6 +1,7 @@ use std::fmt::Write; use std::ops::Deref; use std::str::FromStr; +use std::sync::Arc; use std::time::Duration; use futures_core::future::BoxFuture; @@ -8,6 +9,7 @@ use futures_core::future::BoxFuture; use once_cell::sync::OnceCell; use sqlx_core::connection::Connection; use sqlx_core::query_scalar::query_scalar; +use sqlx_core::sql_str::AssertSqlSafe; use crate::error::Error; use crate::executor::Executor; @@ -55,12 +57,13 @@ impl TestSupport for Postgres { let mut deleted_db_names = Vec::with_capacity(delete_db_names.len()); - let mut command = String::new(); + let mut command_arced = Arc::new(String::new()); for db_name in &delete_db_names { + let command = Arc::get_mut(&mut command_arced).unwrap(); command.clear(); writeln!(command, "drop database if exists {db_name:?};").ok(); - match conn.execute(&*command).await { + match conn.execute(AssertSqlSafe(command_arced.clone())).await { Ok(_deleted) => { deleted_db_names.push(db_name); } @@ -169,7 +172,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { let create_command = format!("create database {db_name:?}"); debug_assert!(create_command.starts_with("create database \"")); - conn.execute(&(create_command)[..]).await?; + conn.execute(AssertSqlSafe(create_command)).await?; Ok(TestContext { pool_opts: PoolOptions::new() @@ -191,7 +194,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { async fn do_cleanup(conn: &mut PgConnection, db_name: &str) -> Result<(), Error> { let delete_db_command = format!("drop database if exists {db_name:?};"); - conn.execute(&*delete_db_command).await?; + conn.execute(AssertSqlSafe(delete_db_command)).await?; query("delete from _sqlx_test.databases where db_name = $1::text") .bind(db_name) .execute(&mut *conn) diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index 23352a8dcf..b3f5c20f9e 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -30,6 +30,7 @@ impl TransactionManager for PgTransactionManager { }; let rollback = Rollback::new(conn); + rollback.conn.queue_simple_query(&statement)?; rollback.conn.wait_until_ready().await?; if !rollback.conn.in_transaction() { @@ -45,7 +46,7 @@ impl TransactionManager for PgTransactionManager { fn commit(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { if conn.inner.transaction_depth > 0 { - conn.execute(&*commit_ansi_transaction_sql(conn.inner.transaction_depth)) + conn.execute(commit_ansi_transaction_sql(conn.inner.transaction_depth)) .await?; conn.inner.transaction_depth -= 1; @@ -58,10 +59,8 @@ impl TransactionManager for PgTransactionManager { fn rollback(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { if conn.inner.transaction_depth > 0 { - conn.execute(&*rollback_ansi_transaction_sql( - conn.inner.transaction_depth, - )) - .await?; + conn.execute(rollback_ansi_transaction_sql(conn.inner.transaction_depth)) + .await?; conn.inner.transaction_depth -= 1; } @@ -72,8 +71,10 @@ impl TransactionManager for PgTransactionManager { fn start_rollback(conn: &mut PgConnection) { if conn.inner.transaction_depth > 0 { - conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.inner.transaction_depth)) - .expect("BUG: Rollback query somehow too large for protocol"); + conn.queue_simple_query( + rollback_ansi_transaction_sql(conn.inner.transaction_depth).as_str(), + ) + .expect("BUG: Rollback query somehow too large for protocol"); conn.inner.transaction_depth -= 1; } diff --git a/sqlx-test/src/lib.rs b/sqlx-test/src/lib.rs index cc77f38dba..29852a6437 100644 --- a/sqlx-test/src/lib.rs +++ b/sqlx-test/src/lib.rs @@ -109,12 +109,13 @@ macro_rules! test_unprepared_type { async fn [< test_unprepared_type_ $name >] () -> anyhow::Result<()> { use sqlx::prelude::*; use futures::TryStreamExt; + use sqlx_core::sql_str::AssertSqlSafe; let mut conn = sqlx_test::new::<$db>().await?; $( let query = format!("SELECT {}", $text); - let mut s = conn.fetch(&*query); + let mut s = conn.fetch(AssertSqlSafe(query)); let row = s.try_next().await?.unwrap(); let rec = row.try_get::<$ty, _>(0)?; @@ -137,13 +138,14 @@ macro_rules! __test_prepared_decode_type { #[sqlx_macros::test] async fn [< test_prepared_decode_type_ $name >] () -> anyhow::Result<()> { use sqlx::Row; + use sqlx_core::sql_str::AssertSqlSafe; let mut conn = sqlx_test::new::<$db>().await?; $( let query = format!("SELECT {}", $text); - let row = sqlx::query(&query) + let row = sqlx::query(AssertSqlSafe(query)) .fetch_one(&mut conn) .await?; @@ -166,6 +168,7 @@ macro_rules! __test_prepared_type { #[sqlx_macros::test] async fn [< test_prepared_type_ $name >] () -> anyhow::Result<()> { use sqlx::Row; + use sqlx_core::sql_str::AssertSqlSafe; let mut conn = sqlx_test::new::<$db>().await?; @@ -173,7 +176,7 @@ macro_rules! __test_prepared_type { let query = format!($sql, $text); println!("{query}"); - let row = sqlx::query(&query) + let row = sqlx::query(AssertSqlSafe(query)) .bind($value) .bind($value) .fetch_one(&mut conn) diff --git a/tests/any/any.rs b/tests/any/any.rs index 2c59ca5339..c84f8c818b 100644 --- a/tests/any/any.rs +++ b/tests/any/any.rs @@ -1,5 +1,6 @@ use sqlx::any::AnyRow; use sqlx::{Any, Connection, Executor, Row}; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_test::new; #[sqlx_macros::test] @@ -106,7 +107,7 @@ async fn it_can_fail_and_recover() -> anyhow::Result<()> { // now try and use the connection let val: i32 = conn - .fetch_one(&*format!("SELECT {i}")) + .fetch_one(AssertSqlSafe(format!("SELECT {i}"))) .await? .get_unchecked(0); @@ -132,7 +133,7 @@ async fn it_can_fail_and_recover_with_pool() -> anyhow::Result<()> { // now try and use the connection let val: i32 = pool - .fetch_one(&*format!("SELECT {i}")) + .fetch_one(AssertSqlSafe(format!("SELECT {i}"))) .await? .get_unchecked(0); diff --git a/tests/any/pool.rs b/tests/any/pool.rs index 3130b4f1c6..a4849940b8 100644 --- a/tests/any/pool.rs +++ b/tests/any/pool.rs @@ -1,5 +1,6 @@ use sqlx::any::{AnyConnectOptions, AnyPoolOptions}; use sqlx::Executor; +use sqlx_core::sql_str::AssertSqlSafe; use std::sync::{ atomic::{AtomicI32, AtomicUsize, Ordering}, Arc, Mutex, @@ -111,7 +112,7 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { id ); - conn.execute(&statement[..]).await?; + conn.execute(AssertSqlSafe(statement)).await?; Ok(()) }) }) diff --git a/tests/postgres/derives.rs b/tests/postgres/derives.rs index 13f9bf1d5d..88f9d026f4 100644 --- a/tests/postgres/derives.rs +++ b/tests/postgres/derives.rs @@ -1,6 +1,7 @@ use futures::TryStreamExt; use sqlx::postgres::types::PgRange; use sqlx::{Connection, Executor, FromRow, Postgres}; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_postgres::PgHasArrayType; use sqlx_test::{new, test_type}; use std::fmt::Debug; @@ -259,7 +260,7 @@ SELECT id, mood FROM people WHERE id = $1 let stmt = format!("SELECT id, mood FROM people WHERE id = {people_id}"); dbg!(&stmt); - let mut cursor = conn.fetch(&*stmt); + let mut cursor = conn.fetch(AssertSqlSafe(stmt)); let row = cursor.try_next().await?.unwrap(); let rec = PeopleRow::from_row(&row)?; diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index fc7108bf4f..b0285f53bf 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -6,6 +6,7 @@ use sqlx::postgres::{ PgPoolOptions, PgRow, PgSeverity, Postgres, PG_COPY_MAX_DATA_LEN, }; use sqlx::{Column, Connection, Executor, Row, Statement, TypeInfo}; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_core::{bytes::Bytes, error::BoxDynError}; use sqlx_test::{new, pool, setup_if_needed}; use std::env; @@ -309,7 +310,10 @@ async fn it_can_fail_and_recover() -> anyhow::Result<()> { assert!(res.is_err()); // now try and use the connection - let val: i32 = conn.fetch_one(&*format!("SELECT {i}::int4")).await?.get(0); + let val: i32 = conn + .fetch_one(AssertSqlSafe(format!("SELECT {i}::int4"))) + .await? + .get(0); assert_eq!(val, i); } @@ -330,7 +334,10 @@ async fn it_can_fail_and_recover_with_pool() -> anyhow::Result<()> { assert!(res.is_err()); // now try and use the connection - let val: i32 = pool.fetch_one(&*format!("SELECT {i}::int4")).await?.get(0); + let val: i32 = pool + .fetch_one(AssertSqlSafe(format!("SELECT {i}::int4"))) + .await? + .get(0); assert_eq!(val, i); } @@ -803,7 +810,7 @@ async fn it_closes_statement_from_cache_issue_470() -> anyhow::Result<()> { let mut conn = PgConnection::connect_with(&options).await?; for i in 0..5 { - let row = sqlx::query(&*format!("SELECT {i}::int4 AS val")) + let row = sqlx::query(AssertSqlSafe(format!("SELECT {i}::int4 AS val"))) .fetch_one(&mut conn) .await?; @@ -1099,8 +1106,10 @@ async fn test_listener_try_recv_buffered() -> anyhow::Result<()> { { let mut txn = notify_conn.begin().await?; for i in 0..5 { - txn.execute(format!("NOTIFY test_channel2, 'payload {i}'").as_str()) - .await?; + txn.execute(AssertSqlSafe(format!( + "NOTIFY test_channel2, 'payload {i}'" + ))) + .await?; } txn.commit().await?; } @@ -1951,7 +1960,8 @@ async fn test_postgres_bytea_hex_deserialization_errors() -> anyhow::Result<()> conn.execute("SET bytea_output = 'escape';").await?; for value in ["", "DEADBEEF"] { let query = format!("SELECT '\\x{value}'::bytea"); - let res: sqlx::Result> = conn.fetch_one(query.as_str()).await?.try_get(0usize); + let res: sqlx::Result> = + conn.fetch_one(AssertSqlSafe(query)).await?.try_get(0usize); // Deserialization only supports hex format so this should error and definitely not panic. res.unwrap_err(); } diff --git a/tests/postgres/query_builder.rs b/tests/postgres/query_builder.rs index 08ed7d11a3..b1e0659eff 100644 --- a/tests/postgres/query_builder.rs +++ b/tests/postgres/query_builder.rs @@ -3,6 +3,7 @@ use sqlx::query_builder::QueryBuilder; use sqlx::Executor; use sqlx::Type; use sqlx::{Either, Execute}; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_test::new; #[test] @@ -54,18 +55,20 @@ fn test_build() { qb.push(" WHERE id = ").push_bind(42i32); let query = qb.build(); - assert_eq!(query.sql(), "SELECT * FROM users WHERE id = $1"); assert_eq!(Execute::persistent(&query), true); + assert_eq!(query.sql(), "SELECT * FROM users WHERE id = $1"); } #[test] fn test_reset() { let mut qb: QueryBuilder<'_, Postgres> = QueryBuilder::new(""); - let _query = qb - .push("SELECT * FROM users WHERE id = ") - .push_bind(42i32) - .build(); + { + let _query = qb + .push("SELECT * FROM users WHERE id = ") + .push_bind(42i32) + .build(); + } qb.reset(); @@ -76,10 +79,12 @@ fn test_reset() { fn test_query_builder_reuse() { let mut qb: QueryBuilder<'_, Postgres> = QueryBuilder::new(""); - let _query = qb - .push("SELECT * FROM users WHERE id = ") - .push_bind(42i32) - .build(); + { + let _query = qb + .push("SELECT * FROM users WHERE id = ") + .push_bind(42i32) + .build(); + } qb.reset(); @@ -97,8 +102,10 @@ fn test_query_builder_with_args() { .push_bind(42i32) .build(); + let args = query.take_arguments().unwrap().unwrap(); + let mut qb: QueryBuilder<'_, Postgres> = - QueryBuilder::with_arguments(query.sql(), query.take_arguments().unwrap().unwrap()); + QueryBuilder::with_arguments(query.sql().as_str(), args); let query = qb.push(" OR membership_level = ").push_bind(3i32).build(); assert_eq!( @@ -129,7 +136,7 @@ async fn test_max_number_of_binds() -> anyhow::Result<()> { let mut conn = new::().await?; // Indirectly ensures the macros support this many binds since this is what they use. - let describe = conn.describe(qb.sql()).await?; + let describe = conn.describe(AssertSqlSafe(qb.sql().to_string())).await?; match describe .parameters diff --git a/tests/postgres/rustsec.rs b/tests/postgres/rustsec.rs index 45fd76b9db..a0692be4c6 100644 --- a/tests/postgres/rustsec.rs +++ b/tests/postgres/rustsec.rs @@ -1,4 +1,5 @@ use sqlx::{Error, PgPool}; +use sqlx_core::sql_str::AssertSqlSafe; use std::{cmp, str}; @@ -114,7 +115,7 @@ async fn rustsec_2024_0363(pool: PgPool) -> anyhow::Result<()> { assert_eq!(wrapped_len, fake_payload_len); - let res = sqlx::raw_sql(&query) + let res = sqlx::raw_sql(AssertSqlSafe(query)) // Note: the connection may hang afterward // because `pending_ready_for_query_count` will underflow. .execute(&pool) From d836f65a9c9b48d0c060f07bba9c4e5e539b2315 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Sat, 1 Feb 2025 17:43:02 +0100 Subject: [PATCH 05/12] Update MySql driver --- sqlx-core/src/migrate/migration.rs | 12 ++++------- sqlx-core/src/migrate/source.rs | 3 +-- sqlx-mysql/src/any.rs | 11 +++++----- sqlx-mysql/src/connection/executor.rs | 31 ++++++++++++++------------- sqlx-mysql/src/migrate.rs | 12 +++++++---- sqlx-mysql/src/options/connect.rs | 3 ++- sqlx-mysql/src/statement.rs | 4 ++-- sqlx-mysql/src/testing/mod.rs | 11 ++++++---- sqlx-mysql/src/transaction.rs | 7 +++--- sqlx-postgres/src/migrate.rs | 4 ++-- src/lib.rs | 1 + 11 files changed, 53 insertions(+), 46 deletions(-) diff --git a/sqlx-core/src/migrate/migration.rs b/sqlx-core/src/migrate/migration.rs index 4c57769a2a..394a968fd4 100644 --- a/sqlx-core/src/migrate/migration.rs +++ b/sqlx-core/src/migrate/migration.rs @@ -2,8 +2,6 @@ use std::borrow::Cow; use sha2::{Digest, Sha384}; -use crate::sql_str::{SqlSafeStr, SqlStr}; - use super::MigrationType; #[derive(Debug, Clone)] @@ -11,7 +9,8 @@ pub struct Migration { pub version: i64, pub description: Cow<'static, str>, pub migration_type: MigrationType, - pub sql: SqlStr, + // We can't use `SqlStr` here because it can't be used in a const context + pub sql: Cow<'static, str>, pub checksum: Cow<'static, [u8]>, pub no_tx: bool, } @@ -21,13 +20,10 @@ impl Migration { version: i64, description: Cow<'static, str>, migration_type: MigrationType, - sql: impl SqlSafeStr, + sql: Cow<'static, str>, no_tx: bool, ) -> Self { - let sql = sql.into_sql_str(); - let checksum = Cow::Owned(Vec::from( - Sha384::digest(sql.as_str().as_bytes()).as_slice(), - )); + let checksum = Cow::Owned(Vec::from(Sha384::digest(sql.as_bytes()).as_slice())); Migration { version, diff --git a/sqlx-core/src/migrate/source.rs b/sqlx-core/src/migrate/source.rs index 7e1ef142c8..d0c23b43cd 100644 --- a/sqlx-core/src/migrate/source.rs +++ b/sqlx-core/src/migrate/source.rs @@ -1,6 +1,5 @@ use crate::error::BoxDynError; use crate::migrate::{Migration, MigrationType}; -use crate::sql_str::AssertSqlSafe; use futures_core::future::BoxFuture; use std::borrow::Cow; @@ -132,7 +131,7 @@ pub fn resolve_blocking(path: &Path) -> Result, Resolv version, Cow::Owned(description), migration_type, - AssertSqlSafe(sql), + Cow::Owned(sql), no_tx, ), entry_path, diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index ccb4d8c341..44f6092352 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -15,6 +15,7 @@ use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; +use sqlx_core::sql_str::SqlStr; use sqlx_core::transaction::TransactionManager; use std::borrow::Cow; use std::{future, pin::pin}; @@ -82,7 +83,7 @@ impl AnyConnectionBackend for MySqlConnection { fn fetch_many<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxStream<'q, sqlx_core::Result>> { @@ -108,7 +109,7 @@ impl AnyConnectionBackend for MySqlConnection { fn fetch_optional<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxFuture<'q, sqlx_core::Result>> { @@ -135,11 +136,11 @@ impl AnyConnectionBackend for MySqlConnection { fn prepare_with<'c, 'q: 'c>( &'c mut self, - sql: &'q str, + sql: SqlStr, _parameters: &[AnyTypeInfo], ) -> BoxFuture<'c, sqlx_core::Result> { Box::pin(async move { - let statement = Executor::prepare_with(self, sql, &[]).await?; + let statement = Executor::prepare_with(self, sql.clone(), &[]).await?; AnyStatement::try_from_statement( sql, &statement, @@ -148,7 +149,7 @@ impl AnyConnectionBackend for MySqlConnection { }) } - fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, sqlx_core::Result>> { + fn describe<'q>(&'q mut self, sql: SqlStr) -> BoxFuture<'q, sqlx_core::Result>> { Box::pin(async move { let describe = Executor::describe(self, sql).await?; describe.try_into_any() diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index 3f44ab72d9..f74ec1143a 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -22,7 +22,7 @@ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::TryStreamExt; -use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; +use sqlx_core::sql_str::{SqlSafeStr, SqlStr}; use std::{pin::pin, sync::Arc}; impl MySqlConnection { @@ -102,13 +102,11 @@ impl MySqlConnection { #[allow(clippy::needless_lifetimes)] pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>( &'c mut self, - sql: &'q str, + sql: SqlStr, arguments: Option, persistent: bool, ) -> Result, Error>> + 'e, Error> { - let mut logger = QueryLogger::new(sql, self.inner.log_settings.clone()); - self.inner.stream.wait_until_ready().await?; self.inner.stream.waiting.push_back(Waiting::Result); @@ -121,7 +119,7 @@ impl MySqlConnection { let (mut column_names, format, mut needs_metadata) = if let Some(arguments) = arguments { if persistent && self.inner.cache_statement.is_enabled() { let (id, metadata) = self - .get_or_prepare_statement(sql) + .get_or_prepare_statement(sql.as_str()) .await?; // https://dev.mysql.com/doc/internals/en/com-stmt-execute.html @@ -135,7 +133,7 @@ impl MySqlConnection { (metadata.column_names, MySqlValueFormat::Binary, false) } else { let (id, metadata) = self - .prepare_statement(sql) + .prepare_statement(sql.as_str()) .await?; // https://dev.mysql.com/doc/internals/en/com-stmt-execute.html @@ -152,10 +150,11 @@ impl MySqlConnection { } } else { // https://dev.mysql.com/doc/internals/en/com-query.html - self.inner.stream.send_packet(Query(sql)).await?; + self.inner.stream.send_packet(Query(sql.as_str())).await?; (Arc::default(), MySqlValueFormat::Text, true) }; + let mut logger = QueryLogger::new(sql, self.inner.log_settings.clone()); loop { // query response is a meta-packet which may be one of: @@ -262,11 +261,11 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { 'q: 'e, E: 'q, { - let sql = query.sql(); let arguments = query.take_arguments().map_err(Error::Encode); let persistent = query.persistent(); Box::pin(try_stream! { + let sql = query.sql(); let arguments = arguments?; let mut s = pin!(self.run(sql, arguments, persistent).await?); @@ -298,21 +297,22 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { }) } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, _parameters: &'e [MySqlTypeInfo], ) -> BoxFuture<'e, Result> where 'c: 'e, { + let sql = sql.into_sql_str(); Box::pin(async move { self.inner.stream.wait_until_ready().await?; let metadata = if self.inner.cache_statement.is_enabled() { - self.get_or_prepare_statement(sql).await?.1 + self.get_or_prepare_statement(sql.as_str()).await?.1 } else { - let (id, metadata) = self.prepare_statement(sql).await?; + let (id, metadata) = self.prepare_statement(sql.as_str()).await?; self.inner .stream @@ -323,7 +323,7 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { }; Ok(MySqlStatement { - sql: AssertSqlSafe(sql).into_sql_str(), + sql, // metadata has internal Arcs for expensive data structures metadata: metadata.clone(), }) @@ -331,14 +331,15 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { } #[doc(hidden)] - fn describe<'e, 'q: 'e>(self, sql: &'q str) -> BoxFuture<'e, Result, Error>> + fn describe<'e>(self, sql: impl SqlSafeStr) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { + let sql = sql.into_sql_str(); Box::pin(async move { self.inner.stream.wait_until_ready().await?; - let (id, metadata) = self.prepare_statement(sql).await?; + let (id, metadata) = self.prepare_statement(sql.as_str()).await?; self.inner .stream diff --git a/sqlx-mysql/src/migrate.rs b/sqlx-mysql/src/migrate.rs index 79b55ace3c..40c83c9fa2 100644 --- a/sqlx-mysql/src/migrate.rs +++ b/sqlx-mysql/src/migrate.rs @@ -4,6 +4,7 @@ use std::time::Instant; use futures_core::future::BoxFuture; pub(crate) use sqlx_core::migrate::*; +use sqlx_core::sql_str::AssertSqlSafe; use crate::connection::{ConnectOptions, Connection}; use crate::error::Error; @@ -37,7 +38,7 @@ impl MigrateDatabase for MySql { let mut conn = options.connect().await?; let _ = conn - .execute(&*format!("CREATE DATABASE `{database}`")) + .execute(AssertSqlSafe(format!("CREATE DATABASE `{database}`"))) .await?; Ok(()) @@ -66,7 +67,9 @@ impl MigrateDatabase for MySql { let mut conn = options.connect().await?; let _ = conn - .execute(&*format!("DROP DATABASE IF EXISTS `{database}`")) + .execute(AssertSqlSafe(format!( + "DROP DATABASE IF EXISTS `{database}`" + ))) .await?; Ok(()) @@ -200,7 +203,8 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( .await?; let _ = tx - .execute(&*migration.sql) + // We can't use `SqlStr` in `Migration` because it can't be used in a const context + .execute(AssertSqlSafe(migration.sql.to_string())) .await .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; @@ -269,7 +273,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( .execute(&mut *tx) .await?; - tx.execute(&*migration.sql).await?; + tx.execute(AssertSqlSafe(migration.sql.to_string())).await?; // language=SQL let _ = query(r#"DELETE FROM _sqlx_migrations WHERE version = ?"#) diff --git a/sqlx-mysql/src/options/connect.rs b/sqlx-mysql/src/options/connect.rs index 116a49ccad..d3f14d64f5 100644 --- a/sqlx-mysql/src/options/connect.rs +++ b/sqlx-mysql/src/options/connect.rs @@ -4,6 +4,7 @@ use crate::executor::Executor; use crate::{MySqlConnectOptions, MySqlConnection}; use futures_core::future::BoxFuture; use log::LevelFilter; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_core::Url; use std::time::Duration; @@ -77,7 +78,7 @@ impl ConnectOptions for MySqlConnectOptions { } if !options.is_empty() { - conn.execute(&*format!(r#"SET {};"#, options.join(","))) + conn.execute(AssertSqlSafe(format!(r#"SET {};"#, options.join(",")))) .await?; } diff --git a/sqlx-mysql/src/statement.rs b/sqlx-mysql/src/statement.rs index 1f32299467..52044d2273 100644 --- a/sqlx-mysql/src/statement.rs +++ b/sqlx-mysql/src/statement.rs @@ -33,8 +33,8 @@ impl Statement for MySqlStatement { } } - fn sql(&self) -> &str { - &self.sql.as_str() + fn sql(&self) -> SqlStr { + self.sql.clone() } fn parameters(&self) -> Option> { diff --git a/sqlx-mysql/src/testing/mod.rs b/sqlx-mysql/src/testing/mod.rs index 1981cf73c5..d9a612a79a 100644 --- a/sqlx-mysql/src/testing/mod.rs +++ b/sqlx-mysql/src/testing/mod.rs @@ -1,5 +1,6 @@ use std::ops::Deref; use std::str::FromStr; +use std::sync::Arc; use std::time::Duration; use futures_core::future::BoxFuture; @@ -8,6 +9,7 @@ use once_cell::sync::OnceCell; use sqlx_core::connection::Connection; use sqlx_core::query_builder::QueryBuilder; use sqlx_core::query_scalar::query_scalar; +use sqlx_core::sql_str::AssertSqlSafe; use std::fmt::Write; use crate::error::Error; @@ -55,15 +57,16 @@ impl TestSupport for MySql { let mut deleted_db_names = Vec::with_capacity(delete_db_names.len()); - let mut command = String::new(); + let mut command_arced = Arc::new(String::new()); for db_name in &delete_db_names { + let command = Arc::get_mut(&mut command_arced).unwrap(); command.clear(); let db_name = format!("_sqlx_test_database_{db_name}"); writeln!(command, "drop database if exists {db_name:?};").ok(); - match conn.execute(&*command).await { + match conn.execute(AssertSqlSafe(command_arced.clone())).await { Ok(_deleted) => { deleted_db_names.push(db_name); } @@ -162,7 +165,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { .execute(&mut *conn) .await?; - conn.execute(&format!("create database {db_name:?}")[..]) + conn.execute(AssertSqlSafe(format!("create database {db_name:?}"))) .await?; eprintln!("created database {db_name}"); @@ -187,7 +190,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { async fn do_cleanup(conn: &mut MySqlConnection, db_name: &str) -> Result<(), Error> { let delete_db_command = format!("drop database if exists {db_name:?};"); - conn.execute(&*delete_db_command).await?; + conn.execute(AssertSqlSafe(delete_db_command)).await?; query("delete from _sqlx_test.databases where db_name = $1::text") .bind(db_name) .execute(&mut *conn) diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index 545cb5f4f2..85c7c56b48 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -22,6 +22,7 @@ impl TransactionManager for MySqlTransactionManager { ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { let depth = conn.inner.transaction_depth; + let statement = match statement { // custom `BEGIN` statements are not allowed if we're already in a transaction // (we need to issue a `SAVEPOINT` instead) @@ -44,7 +45,7 @@ impl TransactionManager for MySqlTransactionManager { let depth = conn.inner.transaction_depth; if depth > 0 { - conn.execute(&*commit_ansi_transaction_sql(depth)).await?; + conn.execute(commit_ansi_transaction_sql(depth)).await?; conn.inner.transaction_depth = depth - 1; } @@ -57,7 +58,7 @@ impl TransactionManager for MySqlTransactionManager { let depth = conn.inner.transaction_depth; if depth > 0 { - conn.execute(&*rollback_ansi_transaction_sql(depth)).await?; + conn.execute(rollback_ansi_transaction_sql(depth)).await?; conn.inner.transaction_depth = depth - 1; } @@ -73,7 +74,7 @@ impl TransactionManager for MySqlTransactionManager { conn.inner.stream.sequence_id = 0; conn.inner .stream - .write_packet(Query(&rollback_ansi_transaction_sql(depth))) + .write_packet(Query(rollback_ansi_transaction_sql(depth).as_str())) .expect("BUG: unexpected error queueing ROLLBACK"); conn.inner.transaction_depth = depth - 1; diff --git a/sqlx-postgres/src/migrate.rs b/sqlx-postgres/src/migrate.rs index af87449e83..504229c2e2 100644 --- a/sqlx-postgres/src/migrate.rs +++ b/sqlx-postgres/src/migrate.rs @@ -278,7 +278,7 @@ async fn execute_migration( migration: &Migration, ) -> Result<(), MigrateError> { let _ = conn - .execute(migration.sql.clone()) + .execute(AssertSqlSafe(migration.sql.to_string())) .await .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; @@ -303,7 +303,7 @@ async fn revert_migration( migration: &Migration, ) -> Result<(), MigrateError> { let _ = conn - .execute(migration.sql.clone()) + .execute(AssertSqlSafe(migration.sql.to_string())) .await .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; diff --git a/src/lib.rs b/src/lib.rs index ed76c5f5ee..0437ed0327 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,6 +29,7 @@ pub use sqlx_core::query_scalar::query_scalar_with_result as __query_scalar_with pub use sqlx_core::query_scalar::{query_scalar, query_scalar_with}; pub use sqlx_core::raw_sql::{raw_sql, RawSql}; pub use sqlx_core::row::Row; +pub use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr}; pub use sqlx_core::statement::Statement; pub use sqlx_core::transaction::{Transaction, TransactionManager}; pub use sqlx_core::type_info::TypeInfo; From 88d9f8f3229f1809b91ff041e9a0cc6b690ecf0e Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Sat, 1 Feb 2025 18:12:17 +0100 Subject: [PATCH 06/12] Update Sqlite driver --- sqlx-sqlite/src/any.rs | 11 +++--- sqlx-sqlite/src/connection/describe.rs | 8 +++-- sqlx-sqlite/src/connection/execute.rs | 8 +++-- sqlx-sqlite/src/connection/executor.rs | 21 ++++++----- sqlx-sqlite/src/connection/explain.rs | 3 +- sqlx-sqlite/src/connection/mod.rs | 3 +- sqlx-sqlite/src/connection/worker.rs | 50 ++++++++++++-------------- sqlx-sqlite/src/lib.rs | 6 ++-- sqlx-sqlite/src/migrate.rs | 5 +-- sqlx-sqlite/src/options/connect.rs | 3 +- sqlx-sqlite/src/statement/mod.rs | 4 +-- tests/sqlite/describe.rs | 10 +++--- tests/sqlite/rustsec.rs | 4 +-- 13 files changed, 71 insertions(+), 65 deletions(-) diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index 5954444105..f309cad496 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -12,6 +12,7 @@ use sqlx_core::any::{ Any, AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo, AnyTypeInfoKind, AnyValueKind, }; +use sqlx_core::sql_str::SqlStr; use crate::type_info::DataType; use sqlx_core::connection::{ConnectOptions, Connection}; @@ -84,7 +85,7 @@ impl AnyConnectionBackend for SqliteConnection { fn fetch_many<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxStream<'q, sqlx_core::Result>> { @@ -107,7 +108,7 @@ impl AnyConnectionBackend for SqliteConnection { fn fetch_optional<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxFuture<'q, sqlx_core::Result>> { @@ -132,16 +133,16 @@ impl AnyConnectionBackend for SqliteConnection { fn prepare_with<'c, 'q: 'c>( &'c mut self, - sql: &'q str, + sql: SqlStr, _parameters: &[AnyTypeInfo], ) -> BoxFuture<'c, sqlx_core::Result> { Box::pin(async move { - let statement = Executor::prepare_with(self, sql, &[]).await?; + let statement = Executor::prepare_with(self, sql.clone(), &[]).await?; AnyStatement::try_from_statement(sql, &statement, statement.column_names.clone()) }) } - fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, sqlx_core::Result>> { + fn describe<'q>(&'q mut self, sql: SqlStr) -> BoxFuture<'q, sqlx_core::Result>> { Box::pin(async move { Executor::describe(self, sql).await?.try_into_any() }) } } diff --git a/sqlx-sqlite/src/connection/describe.rs b/sqlx-sqlite/src/connection/describe.rs index 0f4da33ccc..13a25e96aa 100644 --- a/sqlx-sqlite/src/connection/describe.rs +++ b/sqlx-sqlite/src/connection/describe.rs @@ -5,14 +5,18 @@ use crate::error::Error; use crate::statement::VirtualStatement; use crate::type_info::DataType; use crate::{Sqlite, SqliteColumn}; +use sqlx_core::sql_str::SqlStr; use sqlx_core::Either; use std::convert::identity; -pub(crate) fn describe(conn: &mut ConnectionState, query: &str) -> Result, Error> { +pub(crate) fn describe( + conn: &mut ConnectionState, + query: SqlStr, +) -> Result, Error> { // describing a statement from SQLite can be involved // each SQLx statement is comprised of multiple SQL statements - let mut statement = VirtualStatement::new(query, false)?; + let mut statement = VirtualStatement::new(query.as_str(), false)?; let mut columns = Vec::new(); let mut nullable = Vec::new(); diff --git a/sqlx-sqlite/src/connection/execute.rs b/sqlx-sqlite/src/connection/execute.rs index 8a76236977..7acbc91ff8 100644 --- a/sqlx-sqlite/src/connection/execute.rs +++ b/sqlx-sqlite/src/connection/execute.rs @@ -3,12 +3,13 @@ use crate::error::Error; use crate::logger::QueryLogger; use crate::statement::{StatementHandle, VirtualStatement}; use crate::{SqliteArguments, SqliteQueryResult, SqliteRow}; +use sqlx_core::sql_str::SqlSafeStr; use sqlx_core::Either; pub struct ExecuteIter<'a> { handle: &'a mut ConnectionHandle, statement: &'a mut VirtualStatement, - logger: QueryLogger<'a>, + logger: QueryLogger, args: Option>, /// since a `VirtualStatement` can encompass multiple actual statements, @@ -20,12 +21,13 @@ pub struct ExecuteIter<'a> { pub(crate) fn iter<'a>( conn: &'a mut ConnectionState, - query: &'a str, + query: impl SqlSafeStr, args: Option>, persistent: bool, ) -> Result, Error> { + let query = query.into_sql_str(); // fetch the cached statement or allocate a new one - let statement = conn.statements.get(query, persistent)?; + let statement = conn.statements.get(query.as_str(), persistent)?; let logger = QueryLogger::new(query, conn.log_settings.clone()); diff --git a/sqlx-sqlite/src/connection/executor.rs b/sqlx-sqlite/src/connection/executor.rs index c0151d228f..3ad67c6267 100644 --- a/sqlx-sqlite/src/connection/executor.rs +++ b/sqlx-sqlite/src/connection/executor.rs @@ -7,7 +7,7 @@ use futures_util::{stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use sqlx_core::describe::Describe; use sqlx_core::error::Error; use sqlx_core::executor::{Execute, Executor}; -use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; +use sqlx_core::sql_str::SqlSafeStr; use sqlx_core::Either; use std::{future, pin::pin}; @@ -24,12 +24,12 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { 'q: 'e, E: 'q, { - let sql = query.sql(); let arguments = match query.take_arguments().map_err(Error::Encode) { Ok(arguments) => arguments, Err(error) => return stream::once(future::ready(Err(error))).boxed(), }; let persistent = query.persistent() && arguments.is_some(); + let sql = query.sql(); Box::pin( self.worker @@ -49,7 +49,6 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { 'q: 'e, E: 'q, { - let sql = query.sql(); let arguments = match query.take_arguments().map_err(Error::Encode) { Ok(arguments) => arguments, Err(error) => return future::ready(Err(error)).boxed(), @@ -57,6 +56,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { let persistent = query.persistent() && arguments.is_some(); Box::pin(async move { + let sql = query.sql(); let mut stream = pin!(self .worker .execute(sql, arguments, self.row_channel_size, persistent, Some(1)) @@ -73,29 +73,28 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { }) } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, _parameters: &[SqliteTypeInfo], ) -> BoxFuture<'e, Result> where 'c: 'e, { + let sql = sql.into_sql_str(); Box::pin(async move { let statement = self.worker.prepare(sql).await?; - Ok(SqliteStatement { - sql: AssertSqlSafe(sql).into_sql_str(), - ..statement - }) + Ok(statement) }) } #[doc(hidden)] - fn describe<'e, 'q: 'e>(self, sql: &'q str) -> BoxFuture<'e, Result, Error>> + fn describe<'e>(self, sql: impl SqlSafeStr) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { - Box::pin(self.worker.describe(sql)) + let sql = sql.into_sql_str(); + Box::pin(async move { self.worker.describe(sql).await }) } } diff --git a/sqlx-sqlite/src/connection/explain.rs b/sqlx-sqlite/src/connection/explain.rs index bfa66aa12f..edd65ece49 100644 --- a/sqlx-sqlite/src/connection/explain.rs +++ b/sqlx-sqlite/src/connection/explain.rs @@ -12,6 +12,7 @@ use crate::from_row::FromRow; use crate::logger::{BranchParent, BranchResult, DebugDiff}; use crate::type_info::DataType; use crate::SqliteTypeInfo; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_core::{hash_map, HashMap}; use std::fmt::Debug; use std::str::from_utf8; @@ -567,7 +568,7 @@ pub(super) fn explain( ) -> Result<(Vec, Vec>), Error> { let root_block_cols = root_block_columns(conn)?; let program: Vec<(i64, String, i64, i64, i64, Vec)> = - execute::iter(conn, &format!("EXPLAIN {query}"), None, false)? + execute::iter(conn, AssertSqlSafe(format!("EXPLAIN {query}")), None, false)? .filter_map(|res| res.map(|either| either.right()).transpose()) .map(|row| FromRow::from_row(&row?)) .collect::, Error>>()?; diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index b94ad91c4d..5823de6ec8 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -23,6 +23,7 @@ use sqlx_core::common::StatementCache; pub(crate) use sqlx_core::connection::*; use sqlx_core::error::Error; use sqlx_core::executor::Executor; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_core::transaction::Transaction; use crate::connection::establish::EstablishParams; @@ -224,7 +225,7 @@ impl Connection for SqliteConnection { write!(pragma_string, "PRAGMA analysis_limit = {limit}; ").ok(); } pragma_string.push_str("PRAGMA optimize;"); - self.execute(&*pragma_string).await?; + self.execute(AssertSqlSafe(pragma_string)).await?; } let shutdown = self.worker.shutdown(); // Drop the statement worker, which should diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index 9f8b6614af..24b8de2cd0 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -5,7 +5,7 @@ use std::thread; use futures_channel::oneshot; use futures_intrusive::sync::{Mutex, MutexGuard}; -use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; +use sqlx_core::sql_str::SqlStr; use tracing::span::Span; use sqlx_core::describe::Describe; @@ -53,15 +53,15 @@ impl WorkerSharedState { enum Command { Prepare { - query: Box, + query: SqlStr, tx: oneshot::Sender>, }, Describe { - query: Box, + query: SqlStr, tx: oneshot::Sender, Error>>, }, Execute { - query: Box, + query: SqlStr, arguments: Option>, persistent: bool, tx: flume::Sender, Error>>, @@ -145,7 +145,7 @@ impl ConnectionWorker { let _guard = span.enter(); match cmd { Command::Prepare { query, tx } => { - tx.send(prepare(&mut conn, &query).map(|prepared| { + tx.send(prepare(&mut conn, query).map(|prepared| { update_cached_statements_size( &conn, &shared.cached_statements_size, @@ -155,7 +155,7 @@ impl ConnectionWorker { .ok(); } Command::Describe { query, tx } => { - tx.send(describe(&mut conn, &query)).ok(); + tx.send(describe(&mut conn, query)).ok(); } Command::Execute { query, @@ -164,7 +164,7 @@ impl ConnectionWorker { tx, limit } => { - let iter = match execute::iter(&mut conn, &query, arguments, persistent) + let iter = match execute::iter(&mut conn, query, arguments, persistent) { Ok(iter) => iter, Err(e) => { @@ -225,7 +225,7 @@ impl ConnectionWorker { }; let res = conn.handle - .exec(statement) + .exec(begin_ansi_transaction_sql(depth).as_str()) .map(|_| { shared.transaction_depth.fetch_add(1, Ordering::Release); }); @@ -238,7 +238,7 @@ impl ConnectionWorker { // immediately otherwise it would remain started forever. if let Err(error) = conn .handle - .exec(rollback_ansi_transaction_sql(depth + 1)) + .exec(rollback_ansi_transaction_sql(depth + 1).as_str()) .map(|_| { shared.transaction_depth.fetch_sub(1, Ordering::Release); }) @@ -256,7 +256,7 @@ impl ConnectionWorker { let res = if depth > 0 { conn.handle - .exec(commit_ansi_transaction_sql(depth)) + .exec(commit_ansi_transaction_sql(depth).as_str()) .map(|_| { shared.transaction_depth.fetch_sub(1, Ordering::Release); }) @@ -282,7 +282,7 @@ impl ConnectionWorker { let res = if depth > 0 { conn.handle - .exec(rollback_ansi_transaction_sql(depth)) + .exec(rollback_ansi_transaction_sql(depth).as_str()) .map(|_| { shared.transaction_depth.fetch_sub(1, Ordering::Release); }) @@ -335,25 +335,19 @@ impl ConnectionWorker { establish_rx.await.map_err(|_| Error::WorkerCrashed)? } - pub(crate) async fn prepare(&mut self, query: &str) -> Result { - self.oneshot_cmd(|tx| Command::Prepare { - query: query.into(), - tx, - }) - .await? + pub(crate) async fn prepare(&mut self, query: SqlStr) -> Result { + self.oneshot_cmd(|tx| Command::Prepare { query, tx }) + .await? } - pub(crate) async fn describe(&mut self, query: &str) -> Result, Error> { - self.oneshot_cmd(|tx| Command::Describe { - query: query.into(), - tx, - }) - .await? + pub(crate) async fn describe(&mut self, query: SqlStr) -> Result, Error> { + self.oneshot_cmd(|tx| Command::Describe { query, tx }) + .await? } pub(crate) async fn execute( &mut self, - query: &str, + query: SqlStr, args: Option>, chan_size: usize, persistent: bool, @@ -364,7 +358,7 @@ impl ConnectionWorker { self.command_tx .send_async(( Command::Execute { - query: query.into(), + query, arguments: args.map(SqliteArguments::into_static), persistent, tx, @@ -495,9 +489,9 @@ impl ConnectionWorker { } } -fn prepare(conn: &mut ConnectionState, query: &str) -> Result { +fn prepare(conn: &mut ConnectionState, query: SqlStr) -> Result { // prepare statement object (or checkout from cache) - let statement = conn.statements.get(query, true)?; + let statement = conn.statements.get(query.as_str(), true)?; let mut parameters = 0; let mut columns = None; @@ -514,7 +508,7 @@ fn prepare(conn: &mut ConnectionState, query: &str) -> Result Result &str { - &self.sql.as_str() + fn sql(&self) -> SqlStr { + self.sql.clone() } fn parameters(&self) -> Option> { diff --git a/tests/sqlite/describe.rs b/tests/sqlite/describe.rs index 5458eaaa82..82d3445530 100644 --- a/tests/sqlite/describe.rs +++ b/tests/sqlite/describe.rs @@ -437,7 +437,7 @@ async fn it_describes_ungrouped_aggregate() -> anyhow::Result<()> { async fn it_describes_literal_subquery() -> anyhow::Result<()> { async fn assert_literal_described( conn: &mut sqlx::SqliteConnection, - query: &str, + query: &'static str, ) -> anyhow::Result<()> { let info = conn.describe(query).await?; @@ -473,7 +473,7 @@ async fn it_describes_literal_subquery() -> anyhow::Result<()> { async fn assert_tweet_described( conn: &mut sqlx::SqliteConnection, - query: &str, + query: &'static str, ) -> anyhow::Result<()> { let info = conn.describe(query).await?; let columns = info.columns(); @@ -533,7 +533,7 @@ async fn it_describes_table_order_by() -> anyhow::Result<()> { async fn assert_literal_order_by_described( conn: &mut sqlx::SqliteConnection, - query: &str, + query: &'static str, ) -> anyhow::Result<()> { let info = conn.describe(query).await?; @@ -571,7 +571,7 @@ async fn it_describes_table_order_by() -> anyhow::Result<()> { async fn it_describes_union() -> anyhow::Result<()> { async fn assert_union_described( conn: &mut sqlx::SqliteConnection, - query: &str, + query: &'static str, ) -> anyhow::Result<()> { let info = conn.describe(query).await?; @@ -653,7 +653,7 @@ async fn it_describes_having_group_by() -> anyhow::Result<()> { async fn it_describes_strange_queries() -> anyhow::Result<()> { async fn assert_single_column_described( conn: &mut sqlx::SqliteConnection, - query: &str, + query: &'static str, typename: &str, nullable: bool, ) -> anyhow::Result<()> { diff --git a/tests/sqlite/rustsec.rs b/tests/sqlite/rustsec.rs index 3ff9c524fa..08f88a3ad9 100644 --- a/tests/sqlite/rustsec.rs +++ b/tests/sqlite/rustsec.rs @@ -1,4 +1,4 @@ -use sqlx::{Connection, Error, SqliteConnection}; +use sqlx::{AssertSqlSafe, Connection, Error, SqliteConnection}; // https://rustsec.org/advisories/RUSTSEC-2024-0363.html // @@ -50,7 +50,7 @@ async fn rustsec_2024_0363() -> anyhow::Result<()> { .execute(&mut conn) .await?; - let res = sqlx::raw_sql(&query).execute(&mut conn).await; + let res = sqlx::raw_sql(AssertSqlSafe(query)).execute(&mut conn).await; if let Err(e) = res { // Connection rejected the query; we're happy. From db51756a83171b93402787e9dad1a20ed66b31e1 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Sat, 1 Feb 2025 19:29:05 +0100 Subject: [PATCH 07/12] remove debug clone count --- sqlx-core/src/sql_str.rs | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/sqlx-core/src/sql_str.rs b/sqlx-core/src/sql_str.rs index 9e53ede4d0..bb3a267262 100644 --- a/sqlx-core/src/sql_str.rs +++ b/sqlx-core/src/sql_str.rs @@ -1,6 +1,6 @@ use std::borrow::Borrow; use std::hash::{Hash, Hasher}; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; /// A SQL string that is safe to execute on a database connection. /// @@ -125,16 +125,8 @@ enum Repr { ArcString(Arc), } -static COUNT_CLONES: Mutex = Mutex::new(0usize); - impl Clone for SqlStr { fn clone(&self) -> Self { - let mut lock = COUNT_CLONES.lock().unwrap(); - *lock += 1; - let clones: usize = *lock; - drop(lock); - - println!("------- Count clones: {clones} --------\n\n\n"); Self(match &self.0 { Repr::Static(s) => Repr::Static(s), Repr::Arced(s) => Repr::Arced(s.clone()), From 694682de510c2c5920224a2bc2972aaf924413c9 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Sat, 1 Feb 2025 19:54:24 +0100 Subject: [PATCH 08/12] Reduce the amount of SqlStr clones --- sqlx-core/src/any/statement.rs | 11 +++++++---- sqlx-core/src/query.rs | 2 +- sqlx-core/src/statement.rs | 4 +++- sqlx-mysql/src/any.rs | 9 +++------ sqlx-mysql/src/statement.rs | 6 +++++- sqlx-postgres/src/any.rs | 9 +++------ sqlx-postgres/src/statement.rs | 6 +++++- sqlx-sqlite/src/any.rs | 5 +++-- sqlx-sqlite/src/statement/mod.rs | 6 +++++- 9 files changed, 35 insertions(+), 23 deletions(-) diff --git a/sqlx-core/src/any/statement.rs b/sqlx-core/src/any/statement.rs index d92980a286..782a3c0dcd 100644 --- a/sqlx-core/src/any/statement.rs +++ b/sqlx-core/src/any/statement.rs @@ -32,10 +32,14 @@ impl Statement for AnyStatement { } } - fn sql(&self) -> SqlStr { + fn sql_cloned(&self) -> SqlStr { self.sql.clone() } + fn into_sql(self) -> SqlStr { + self.sql + } + fn parameters(&self) -> Option> { match &self.parameters { Some(Either::Left(types)) => Some(Either::Left(types)), @@ -64,8 +68,7 @@ impl<'i> ColumnIndex for &'i str { impl<'q> AnyStatement { #[doc(hidden)] pub fn try_from_statement( - query: SqlStr, - statement: &S, + statement: S, column_names: Arc>, ) -> crate::Result where @@ -91,7 +94,7 @@ impl<'q> AnyStatement { .collect::, _>>()?; Ok(Self { - sql: query, + sql: statement.into_sql(), columns, column_names, parameters, diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index 6a13babdfe..882ed17630 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -47,7 +47,7 @@ where #[inline] fn sql(self) -> SqlStr { match self.statement { - Either::Right(statement) => statement.sql(), + Either::Right(statement) => statement.sql_cloned(), Either::Left(sql) => sql, } } diff --git a/sqlx-core/src/statement.rs b/sqlx-core/src/statement.rs index 6f8fd95962..5173d1b191 100644 --- a/sqlx-core/src/statement.rs +++ b/sqlx-core/src/statement.rs @@ -25,7 +25,9 @@ pub trait Statement: Send + Sync { fn to_owned(&self) -> ::Statement; /// Get the original SQL text used to create this statement. - fn sql(&self) -> SqlStr; + fn sql_cloned(&self) -> SqlStr; + + fn into_sql(self) -> SqlStr; /// Get the expected parameters for this statement. /// diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index 44f6092352..58dfc9b698 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -140,12 +140,9 @@ impl AnyConnectionBackend for MySqlConnection { _parameters: &[AnyTypeInfo], ) -> BoxFuture<'c, sqlx_core::Result> { Box::pin(async move { - let statement = Executor::prepare_with(self, sql.clone(), &[]).await?; - AnyStatement::try_from_statement( - sql, - &statement, - statement.metadata.column_names.clone(), - ) + let statement = Executor::prepare_with(self, sql, &[]).await?; + let column_names = statement.metadata.column_names.clone(); + AnyStatement::try_from_statement(statement, column_names) }) } diff --git a/sqlx-mysql/src/statement.rs b/sqlx-mysql/src/statement.rs index 52044d2273..e6b0065961 100644 --- a/sqlx-mysql/src/statement.rs +++ b/sqlx-mysql/src/statement.rs @@ -33,10 +33,14 @@ impl Statement for MySqlStatement { } } - fn sql(&self) -> SqlStr { + fn sql_cloned(&self) -> SqlStr { self.sql.clone() } + fn into_sql(self) -> SqlStr { + self.sql + } + fn parameters(&self) -> Option> { Some(Either::Right(self.metadata.parameters)) } diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index 353cbe6c16..cabeed75ef 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -140,12 +140,9 @@ impl AnyConnectionBackend for PgConnection { _parameters: &[AnyTypeInfo], ) -> BoxFuture<'c, sqlx_core::Result> { Box::pin(async move { - let statement = Executor::prepare_with(self, sql.clone(), &[]).await?; - AnyStatement::try_from_statement( - sql, - &statement, - statement.metadata.column_names.clone(), - ) + let statement = Executor::prepare_with(self, sql, &[]).await?; + let colunn_names = statement.metadata.column_names.clone(); + AnyStatement::try_from_statement(statement, colunn_names) }) } diff --git a/sqlx-postgres/src/statement.rs b/sqlx-postgres/src/statement.rs index 864574781a..b2f739d033 100644 --- a/sqlx-postgres/src/statement.rs +++ b/sqlx-postgres/src/statement.rs @@ -34,10 +34,14 @@ impl Statement for PgStatement { } } - fn sql(&self) -> SqlStr { + fn sql_cloned(&self) -> SqlStr { self.sql.clone() } + fn into_sql(self) -> SqlStr { + self.sql + } + fn parameters(&self) -> Option> { Some(Either::Left(&self.metadata.parameters)) } diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index f309cad496..2345451090 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -137,8 +137,9 @@ impl AnyConnectionBackend for SqliteConnection { _parameters: &[AnyTypeInfo], ) -> BoxFuture<'c, sqlx_core::Result> { Box::pin(async move { - let statement = Executor::prepare_with(self, sql.clone(), &[]).await?; - AnyStatement::try_from_statement(sql, &statement, statement.column_names.clone()) + let statement = Executor::prepare_with(self, sql, &[]).await?; + let column_names = statement.column_names.clone(); + AnyStatement::try_from_statement(statement, column_names) }) } diff --git a/sqlx-sqlite/src/statement/mod.rs b/sqlx-sqlite/src/statement/mod.rs index f67cc301e8..08f56d08eb 100644 --- a/sqlx-sqlite/src/statement/mod.rs +++ b/sqlx-sqlite/src/statement/mod.rs @@ -36,10 +36,14 @@ impl Statement for SqliteStatement { } } - fn sql(&self) -> SqlStr { + fn sql_cloned(&self) -> SqlStr { self.sql.clone() } + fn into_sql(self) -> SqlStr { + self.sql + } + fn parameters(&self) -> Option> { Some(Either::Right(self.parameters)) } From a6aa63d379a7cbf5a0f5342ab1f270f1bb309c16 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Sat, 1 Feb 2025 20:54:18 +0100 Subject: [PATCH 09/12] improve QueryBuilder error message --- sqlx-core/src/query_builder.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sqlx-core/src/query_builder.rs b/sqlx-core/src/query_builder.rs index 2ff79e9f2e..09a40f6ca6 100644 --- a/sqlx-core/src/query_builder.rs +++ b/sqlx-core/src/query_builder.rs @@ -43,6 +43,8 @@ impl<'args, DB: Database> Default for QueryBuilder<'args, DB> { } } +const ERROR: &str = "BUG: query must not be shared at this point in time"; + impl<'args, DB: Database> QueryBuilder<'args, DB> where DB: Database, @@ -118,7 +120,7 @@ where /// e.g. check that strings aren't too long, numbers are within expected ranges, etc. pub fn push(&mut self, sql: impl Display) -> &mut Self { self.sanity_check(); - let query: &mut String = Arc::get_mut(&mut self.query).expect(""); + let query: &mut String = Arc::get_mut(&mut self.query).expect(ERROR); write!(query, "{sql}").expect("error formatting `sql`"); @@ -161,7 +163,7 @@ where .expect("BUG: Arguments taken already"); arguments.add(value).expect("Failed to add argument"); - let query: &mut String = Arc::get_mut(&mut self.query).expect(""); + let query: &mut String = Arc::get_mut(&mut self.query).expect(ERROR); arguments .format_placeholder(query) .expect("error in format_placeholder"); @@ -516,7 +518,7 @@ where /// The query is truncated to the initial fragment provided to [`new()`][Self::new] and /// the bind arguments are reset. pub fn reset(&mut self) -> &mut Self { - let query: &mut String = Arc::get_mut(&mut self.query).expect(""); + let query: &mut String = Arc::get_mut(&mut self.query).expect(ERROR); query.truncate(self.init_len); self.arguments = Some(Default::default()); From 28bc3222a27b659d59207f1e13aa0e411e116a60 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Mon, 17 Feb 2025 20:51:32 +0100 Subject: [PATCH 10/12] cargo fmt --- sqlx-core/src/sql_str.rs | 5 ++++- sqlx-postgres/src/any.rs | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sqlx-core/src/sql_str.rs b/sqlx-core/src/sql_str.rs index bb3a267262..be9fbd92a6 100644 --- a/sqlx-core/src/sql_str.rs +++ b/sqlx-core/src/sql_str.rs @@ -170,7 +170,10 @@ impl Borrow for SqlStr { } } -impl PartialEq for SqlStr where T: AsRef { +impl PartialEq for SqlStr +where + T: AsRef, +{ fn eq(&self, other: &T) -> bool { self.as_str() == other.as_ref() } diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index cabeed75ef..73d4678aa6 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -8,6 +8,7 @@ use futures_util::{stream, StreamExt, TryFutureExt, TryStreamExt}; use std::borrow::Cow; use std::{future, pin::pin}; use sqlx_core::sql_str::SqlStr; +use std::{future, pin::pin}; use sqlx_core::any::{ Any, AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow, From a0f122b010fea43b900b4ee0c281a72640dbea72 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Mon, 17 Feb 2025 21:45:54 +0100 Subject: [PATCH 11/12] fix clippy warnings --- sqlx-cli/src/database.rs | 1 + sqlx-cli/src/opt.rs | 1 + sqlx-core/src/any/arguments.rs | 2 +- sqlx-core/src/any/connection/backend.rs | 2 +- sqlx-core/src/any/row.rs | 2 +- sqlx-core/src/any/statement.rs | 4 ++-- sqlx-core/src/ext/async_stream.rs | 2 +- sqlx-core/src/io/encode.rs | 2 +- sqlx-core/src/net/socket/mod.rs | 8 ++++---- sqlx-core/src/pool/inner.rs | 4 ++-- sqlx-core/src/pool/maybe.rs | 6 +++--- sqlx-core/src/query.rs | 8 ++++---- sqlx-core/src/query_as.rs | 8 ++++---- sqlx-core/src/query_builder.rs | 3 +-- sqlx-core/src/query_scalar.rs | 8 ++++---- sqlx-core/src/raw_sql.rs | 9 +-------- sqlx-core/src/rt/mod.rs | 2 ++ sqlx-core/src/sql_str.rs | 3 +-- sqlx-core/src/transaction.rs | 12 ++++++------ sqlx-core/src/type_checking.rs | 2 +- sqlx-mysql/src/any.rs | 2 +- sqlx-mysql/src/connection/establish.rs | 2 +- sqlx-mysql/src/connection/executor.rs | 4 ++-- sqlx-mysql/src/protocol/statement/execute.rs | 2 +- sqlx-mysql/src/types/text.rs | 2 +- sqlx-postgres/src/advisory_lock.rs | 14 ++++++-------- sqlx-postgres/src/any.rs | 5 +---- sqlx-postgres/src/connection/executor.rs | 2 +- sqlx-postgres/src/connection/tls.rs | 2 +- sqlx-postgres/src/message/response.rs | 2 +- sqlx-postgres/src/types/cube.rs | 2 +- sqlx-postgres/src/types/geometry/box.rs | 2 +- sqlx-postgres/src/types/geometry/line.rs | 2 +- sqlx-postgres/src/types/geometry/line_segment.rs | 2 +- sqlx-postgres/src/types/geometry/point.rs | 2 +- sqlx-postgres/src/types/json.rs | 2 +- sqlx-postgres/src/types/text.rs | 2 +- sqlx-sqlite/src/any.rs | 2 +- sqlx-sqlite/src/connection/intmap.rs | 2 +- sqlx-sqlite/src/connection/worker.rs | 3 +-- sqlx-sqlite/src/logger.rs | 2 +- sqlx-sqlite/src/value.rs | 8 ++++---- 42 files changed, 73 insertions(+), 84 deletions(-) diff --git a/sqlx-cli/src/database.rs b/sqlx-cli/src/database.rs index 7a9bc6bf2f..48dff16838 100644 --- a/sqlx-cli/src/database.rs +++ b/sqlx-cli/src/database.rs @@ -1,3 +1,4 @@ +#![allow(unexpected_cfgs)] use crate::migrate; use crate::opt::ConnectOpts; use console::{style, Term}; diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs index 07058aa147..dbc57f890e 100644 --- a/sqlx-cli/src/opt.rs +++ b/sqlx-cli/src/opt.rs @@ -1,3 +1,4 @@ +#![allow(unexpected_cfgs)] use std::ops::{Deref, Not}; use clap::{Args, Parser}; diff --git a/sqlx-core/src/any/arguments.rs b/sqlx-core/src/any/arguments.rs index 59a0c0f765..2c05e3fd5b 100644 --- a/sqlx-core/src/any/arguments.rs +++ b/sqlx-core/src/any/arguments.rs @@ -32,7 +32,7 @@ impl<'q> Arguments<'q> for AnyArguments<'q> { pub struct AnyArgumentBuffer<'q>(#[doc(hidden)] pub Vec>); -impl<'q> Default for AnyArguments<'q> { +impl Default for AnyArguments<'_> { fn default() -> Self { AnyArguments { values: AnyArgumentBuffer(vec![]), diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index 5ee6b1f8fa..7575219d38 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -115,5 +115,5 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { parameters: &[AnyTypeInfo], ) -> BoxFuture<'c, crate::Result>; - fn describe<'q>(&'q mut self, sql: SqlStr) -> BoxFuture<'q, crate::Result>>; + fn describe(&mut self, sql: SqlStr) -> BoxFuture<'_, crate::Result>>; } diff --git a/sqlx-core/src/any/row.rs b/sqlx-core/src/any/row.rs index 310881da14..57b8590b5f 100644 --- a/sqlx-core/src/any/row.rs +++ b/sqlx-core/src/any/row.rs @@ -63,7 +63,7 @@ impl Row for AnyRow { } } -impl<'i> ColumnIndex for &'i str { +impl ColumnIndex for &'_ str { fn index(&self, row: &AnyRow) -> Result { row.column_names .get(*self) diff --git a/sqlx-core/src/any/statement.rs b/sqlx-core/src/any/statement.rs index 782a3c0dcd..6fa979743e 100644 --- a/sqlx-core/src/any/statement.rs +++ b/sqlx-core/src/any/statement.rs @@ -55,7 +55,7 @@ impl Statement for AnyStatement { impl_statement_query!(AnyArguments<'_>); } -impl<'i> ColumnIndex for &'i str { +impl ColumnIndex for &'_ str { fn index(&self, statement: &AnyStatement) -> Result { statement .column_names @@ -65,7 +65,7 @@ impl<'i> ColumnIndex for &'i str { } } -impl<'q> AnyStatement { +impl AnyStatement { #[doc(hidden)] pub fn try_from_statement( statement: S, diff --git a/sqlx-core/src/ext/async_stream.rs b/sqlx-core/src/ext/async_stream.rs index 56777ca4db..c41d940981 100644 --- a/sqlx-core/src/ext/async_stream.rs +++ b/sqlx-core/src/ext/async_stream.rs @@ -95,7 +95,7 @@ impl Yielder { } } -impl<'a, T> Stream for TryAsyncStream<'a, T> { +impl Stream for TryAsyncStream<'_, T> { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/sqlx-core/src/io/encode.rs b/sqlx-core/src/io/encode.rs index a603ea9325..ba032d294d 100644 --- a/sqlx-core/src/io/encode.rs +++ b/sqlx-core/src/io/encode.rs @@ -9,7 +9,7 @@ pub trait ProtocolEncode<'en, Context = ()> { fn encode_with(&self, buf: &mut Vec, context: Context) -> Result<(), crate::Error>; } -impl<'en, C> ProtocolEncode<'en, C> for &'_ [u8] { +impl ProtocolEncode<'_, C> for &'_ [u8] { fn encode_with(&self, buf: &mut Vec, _context: C) -> Result<(), crate::Error> { buf.extend_from_slice(self); Ok(()) diff --git a/sqlx-core/src/net/socket/mod.rs b/sqlx-core/src/net/socket/mod.rs index d11f15884e..1f24da8c40 100644 --- a/sqlx-core/src/net/socket/mod.rs +++ b/sqlx-core/src/net/socket/mod.rs @@ -62,7 +62,7 @@ pub struct Read<'a, S: ?Sized, B> { buf: &'a mut B, } -impl<'a, S: ?Sized, B> Future for Read<'a, S, B> +impl Future for Read<'_, S, B> where S: Socket, B: ReadBuf, @@ -90,7 +90,7 @@ pub struct Write<'a, S: ?Sized> { buf: &'a [u8], } -impl<'a, S: ?Sized> Future for Write<'a, S> +impl Future for Write<'_, S> where S: Socket, { @@ -116,7 +116,7 @@ pub struct Flush<'a, S: ?Sized> { socket: &'a mut S, } -impl<'a, S: Socket + ?Sized> Future for Flush<'a, S> { +impl Future for Flush<'_, S> { type Output = io::Result<()>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -128,7 +128,7 @@ pub struct Shutdown<'a, S: ?Sized> { socket: &'a mut S, } -impl<'a, S: ?Sized> Future for Shutdown<'a, S> +impl Future for Shutdown<'_, S> where S: Socket, { diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index 2066364a8e..51dd4b0055 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -452,14 +452,14 @@ pub(super) fn is_beyond_max_lifetime( ) -> bool { options .max_lifetime - .map_or(false, |max| live.created_at.elapsed() > max) + .is_some_and(|max| live.created_at.elapsed() > max) } /// Returns `true` if the connection has exceeded `options.idle_timeout` if set, `false` otherwise. fn is_beyond_idle_timeout(idle: &Idle, options: &PoolOptions) -> bool { options .idle_timeout - .map_or(false, |timeout| idle.idle_since.elapsed() > timeout) + .is_some_and(|timeout| idle.idle_since.elapsed() > timeout) } async fn check_idle_conn( diff --git a/sqlx-core/src/pool/maybe.rs b/sqlx-core/src/pool/maybe.rs index f9f16c41a5..71a48728a2 100644 --- a/sqlx-core/src/pool/maybe.rs +++ b/sqlx-core/src/pool/maybe.rs @@ -8,7 +8,7 @@ pub enum MaybePoolConnection<'c, DB: Database> { PoolConnection(PoolConnection), } -impl<'c, DB: Database> Deref for MaybePoolConnection<'c, DB> { +impl Deref for MaybePoolConnection<'_, DB> { type Target = DB::Connection; #[inline] @@ -20,7 +20,7 @@ impl<'c, DB: Database> Deref for MaybePoolConnection<'c, DB> { } } -impl<'c, DB: Database> DerefMut for MaybePoolConnection<'c, DB> { +impl DerefMut for MaybePoolConnection<'_, DB> { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { match self { @@ -30,7 +30,7 @@ impl<'c, DB: Database> DerefMut for MaybePoolConnection<'c, DB> { } } -impl<'c, DB: Database> From> for MaybePoolConnection<'c, DB> { +impl From> for MaybePoolConnection<'_, DB> { fn from(v: PoolConnection) -> Self { MaybePoolConnection::PoolConnection(v) } diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index 882ed17630..fba07417b1 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -121,7 +121,7 @@ impl<'q, DB: Database> Query<'q, DB, ::Arguments<'q>> { } } -impl<'q, DB, A> Query<'q, DB, A> +impl Query<'_, DB, A> where DB: Database + HasStatementCache, { @@ -498,9 +498,9 @@ where } /// Execute a single SQL query as a prepared statement (explicitly created). -pub fn query_statement<'q, DB>( - statement: &'q DB::Statement, -) -> Query<'q, DB, ::Arguments<'_>> +pub fn query_statement( + statement: &DB::Statement, +) -> Query<'_, DB, ::Arguments<'_>> where DB: Database, { diff --git a/sqlx-core/src/query_as.rs b/sqlx-core/src/query_as.rs index 2465472f44..a9a035a82c 100644 --- a/sqlx-core/src/query_as.rs +++ b/sqlx-core/src/query_as.rs @@ -58,7 +58,7 @@ impl<'q, DB: Database, O> QueryAs<'q, DB, O, ::Arguments<'q>> { } } -impl<'q, DB, O, A> QueryAs<'q, DB, O, A> +impl QueryAs<'_, DB, O, A> where DB: Database + HasStatementCache, { @@ -387,9 +387,9 @@ where } // Make a SQL query from a statement, that is mapped to a concrete type. -pub fn query_statement_as<'q, DB, O>( - statement: &'q DB::Statement, -) -> QueryAs<'q, DB, O, ::Arguments<'_>> +pub fn query_statement_as( + statement: &DB::Statement, +) -> QueryAs<'_, DB, O, ::Arguments<'_>> where DB: Database, O: for<'r> FromRow<'r, DB::Row>, diff --git a/sqlx-core/src/query_builder.rs b/sqlx-core/src/query_builder.rs index 09a40f6ca6..b4eed071f8 100644 --- a/sqlx-core/src/query_builder.rs +++ b/sqlx-core/src/query_builder.rs @@ -33,7 +33,7 @@ where arguments: Option<::Arguments<'args>>, } -impl<'args, DB: Database> Default for QueryBuilder<'args, DB> { +impl Default for QueryBuilder<'_, DB> { fn default() -> Self { QueryBuilder { init_len: 0, @@ -198,7 +198,6 @@ where /// assert!(sql.ends_with("in (?, ?) ")); /// # } /// ``` - pub fn separated<'qb, Sep>(&'qb mut self, separator: Sep) -> Separated<'qb, 'args, DB, Sep> where 'args: 'qb, diff --git a/sqlx-core/src/query_scalar.rs b/sqlx-core/src/query_scalar.rs index 24b418cdea..1059463874 100644 --- a/sqlx-core/src/query_scalar.rs +++ b/sqlx-core/src/query_scalar.rs @@ -55,7 +55,7 @@ impl<'q, DB: Database, O> QueryScalar<'q, DB, O, ::Arguments<'q> } } -impl<'q, DB, O, A> QueryScalar<'q, DB, O, A> +impl QueryScalar<'_, DB, O, A> where DB: Database + HasStatementCache, { @@ -367,9 +367,9 @@ where } // Make a SQL query from a statement, that is mapped to a concrete value. -pub fn query_statement_scalar<'q, DB, O>( - statement: &'q DB::Statement, -) -> QueryScalar<'q, DB, O, ::Arguments<'_>> +pub fn query_statement_scalar( + statement: &DB::Statement, +) -> QueryScalar<'_, DB, O, ::Arguments<'_>> where DB: Database, (O,): for<'r> FromRow<'r, DB::Row>, diff --git a/sqlx-core/src/raw_sql.rs b/sqlx-core/src/raw_sql.rs index 5fc12d7848..43a7ec920a 100644 --- a/sqlx-core/src/raw_sql.rs +++ b/sqlx-core/src/raw_sql.rs @@ -137,7 +137,7 @@ impl<'q, DB: Database> Execute<'q, DB> for RawSql { } } -impl<'q> RawSql { +impl RawSql { /// Execute the SQL string and return the total number of rows affected. #[inline] pub async fn execute<'e, E>( @@ -145,7 +145,6 @@ impl<'q> RawSql { executor: E, ) -> crate::Result<::QueryResult> where - 'q: 'e, E: Executor<'e>, { executor.execute(self).await @@ -158,7 +157,6 @@ impl<'q> RawSql { executor: E, ) -> BoxStream<'e, crate::Result<::QueryResult>> where - 'q: 'e, E: Executor<'e>, { executor.execute_many(self) @@ -173,7 +171,6 @@ impl<'q> RawSql { executor: E, ) -> BoxStream<'e, Result<::Row, Error>> where - 'q: 'e, E: Executor<'e>, { executor.fetch(self) @@ -195,7 +192,6 @@ impl<'q> RawSql { >, > where - 'q: 'e, E: Executor<'e>, { executor.fetch_many(self) @@ -214,7 +210,6 @@ impl<'q> RawSql { executor: E, ) -> crate::Result::Row>> where - 'q: 'e, E: Executor<'e>, { executor.fetch_all(self).await @@ -238,7 +233,6 @@ impl<'q> RawSql { executor: E, ) -> crate::Result<::Row> where - 'q: 'e, E: Executor<'e>, { executor.fetch_one(self).await @@ -262,7 +256,6 @@ impl<'q> RawSql { executor: E, ) -> crate::Result<::Row> where - 'q: 'e, E: Executor<'e>, { executor.fetch_one(self).await diff --git a/sqlx-core/src/rt/mod.rs b/sqlx-core/src/rt/mod.rs index 43409073ab..210b296233 100644 --- a/sqlx-core/src/rt/mod.rs +++ b/sqlx-core/src/rt/mod.rs @@ -26,6 +26,7 @@ pub enum JoinHandle { pub async fn timeout(duration: Duration, f: F) -> Result { #[cfg(feature = "_rt-tokio")] if rt_tokio::available() { + #[allow(clippy::needless_return)] return tokio::time::timeout(duration, f) .await .map_err(|_| TimeoutError(())); @@ -116,6 +117,7 @@ pub async fn yield_now() { pub fn test_block_on(f: F) -> F::Output { #[cfg(feature = "_rt-tokio")] { + #[allow(clippy::needless_return)] return tokio::runtime::Builder::new_current_thread() .enable_all() .build() diff --git a/sqlx-core/src/sql_str.rs b/sqlx-core/src/sql_str.rs index be9fbd92a6..c5a870712d 100644 --- a/sqlx-core/src/sql_str.rs +++ b/sqlx-core/src/sql_str.rs @@ -42,7 +42,6 @@ pub trait SqlSafeStr { impl SqlSafeStr for &'static str { #[inline] - fn into_sql_str(self) -> SqlStr { SqlStr(Repr::Static(self)) } @@ -68,7 +67,7 @@ pub struct AssertSqlSafe(pub T); /// Note: copies the string. /// /// It is recommended to pass one of the supported owned string types instead. -impl<'a> SqlSafeStr for AssertSqlSafe<&'a str> { +impl SqlSafeStr for AssertSqlSafe<&str> { #[inline] fn into_sql_str(self) -> SqlStr { SqlStr(Repr::Arced(self.0.into())) diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index c846d410d1..3ad0ada9f6 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -199,7 +199,7 @@ where // } // } -impl<'c, DB> Debug for Transaction<'c, DB> +impl Debug for Transaction<'_, DB> where DB: Database, { @@ -209,7 +209,7 @@ where } } -impl<'c, DB> Deref for Transaction<'c, DB> +impl Deref for Transaction<'_, DB> where DB: Database, { @@ -221,7 +221,7 @@ where } } -impl<'c, DB> DerefMut for Transaction<'c, DB> +impl DerefMut for Transaction<'_, DB> where DB: Database, { @@ -235,13 +235,13 @@ where // `PgAdvisoryLockGuard`. // // See: https://github.com/launchbadge/sqlx/issues/2520 -impl<'c, DB: Database> AsMut for Transaction<'c, DB> { +impl AsMut for Transaction<'_, DB> { fn as_mut(&mut self) -> &mut DB::Connection { &mut self.connection } } -impl<'c, 't, DB: Database> crate::acquire::Acquire<'t> for &'t mut Transaction<'c, DB> { +impl<'t, DB: Database> crate::acquire::Acquire<'t> for &'t mut Transaction<'_, DB> { type Database = DB; type Connection = &'t mut ::Connection; @@ -257,7 +257,7 @@ impl<'c, 't, DB: Database> crate::acquire::Acquire<'t> for &'t mut Transaction<' } } -impl<'c, DB> Drop for Transaction<'c, DB> +impl Drop for Transaction<'_, DB> where DB: Database, { diff --git a/sqlx-core/src/type_checking.rs b/sqlx-core/src/type_checking.rs index 5766124530..1da6b7ab3f 100644 --- a/sqlx-core/src/type_checking.rs +++ b/sqlx-core/src/type_checking.rs @@ -112,7 +112,7 @@ where } } -impl<'v, DB> Debug for FmtValue<'v, DB> +impl Debug for FmtValue<'_, DB> where DB: Database, { diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index 58dfc9b698..89afdce474 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -146,7 +146,7 @@ impl AnyConnectionBackend for MySqlConnection { }) } - fn describe<'q>(&'q mut self, sql: SqlStr) -> BoxFuture<'q, sqlx_core::Result>> { + fn describe(&mut self, sql: SqlStr) -> BoxFuture<'_, sqlx_core::Result>> { Box::pin(async move { let describe = Executor::describe(self, sql).await?; describe.try_into_any() diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index 85a9d84f96..ec7d8e4c2c 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -186,7 +186,7 @@ impl<'a> DoHandshake<'a> { } } -impl<'a> WithSocket for DoHandshake<'a> { +impl WithSocket for DoHandshake<'_> { type Output = Result; async fn with_socket(self, socket: S) -> Self::Output { diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index f74ec1143a..bba63a3e07 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -26,7 +26,7 @@ use sqlx_core::sql_str::{SqlSafeStr, SqlStr}; use std::{pin::pin, sync::Arc}; impl MySqlConnection { - async fn prepare_statement<'c>( + async fn prepare_statement( &mut self, sql: &str, ) -> Result<(u32, MySqlStatementMetadata), Error> { @@ -73,7 +73,7 @@ impl MySqlConnection { Ok((id, metadata)) } - async fn get_or_prepare_statement<'c>( + async fn get_or_prepare_statement( &mut self, sql: &str, ) -> Result<(u32, MySqlStatementMetadata), Error> { diff --git a/sqlx-mysql/src/protocol/statement/execute.rs b/sqlx-mysql/src/protocol/statement/execute.rs index 6e51e7b564..89010315bb 100644 --- a/sqlx-mysql/src/protocol/statement/execute.rs +++ b/sqlx-mysql/src/protocol/statement/execute.rs @@ -11,7 +11,7 @@ pub struct Execute<'q> { pub arguments: &'q MySqlArguments, } -impl<'q> ProtocolEncode<'_, Capabilities> for Execute<'q> { +impl ProtocolEncode<'_, Capabilities> for Execute<'_> { fn encode_with(&self, buf: &mut Vec, _: Capabilities) -> Result<(), crate::Error> { buf.push(0x17); // COM_STMT_EXECUTE buf.extend(&self.statement.to_le_bytes()); diff --git a/sqlx-mysql/src/types/text.rs b/sqlx-mysql/src/types/text.rs index ad61c1bee8..363ec02439 100644 --- a/sqlx-mysql/src/types/text.rs +++ b/sqlx-mysql/src/types/text.rs @@ -16,7 +16,7 @@ impl Type for Text { } } -impl<'q, T> Encode<'q, MySql> for Text +impl Encode<'_, MySql> for Text where T: Display, { diff --git a/sqlx-postgres/src/advisory_lock.rs b/sqlx-postgres/src/advisory_lock.rs index d1aef176fb..047ede6be6 100644 --- a/sqlx-postgres/src/advisory_lock.rs +++ b/sqlx-postgres/src/advisory_lock.rs @@ -362,7 +362,7 @@ impl<'lock, C: AsMut> PgAdvisoryLockGuard<'lock, C> { } } -impl<'lock, C: AsMut + AsRef> Deref for PgAdvisoryLockGuard<'lock, C> { +impl + AsRef> Deref for PgAdvisoryLockGuard<'_, C> { type Target = PgConnection; fn deref(&self) -> &Self::Target { @@ -376,16 +376,14 @@ impl<'lock, C: AsMut + AsRef> Deref for PgAdvisoryLo /// However, replacing the connection with a different one using, e.g. [`std::mem::replace()`] /// is a logic error and will cause a warning to be logged by the PostgreSQL server when this /// guard attempts to release the lock. -impl<'lock, C: AsMut + AsRef> DerefMut - for PgAdvisoryLockGuard<'lock, C> -{ +impl + AsRef> DerefMut for PgAdvisoryLockGuard<'_, C> { fn deref_mut(&mut self) -> &mut Self::Target { self.conn.as_mut().expect(NONE_ERR).as_mut() } } -impl<'lock, C: AsMut + AsRef> AsRef - for PgAdvisoryLockGuard<'lock, C> +impl + AsRef> AsRef + for PgAdvisoryLockGuard<'_, C> { fn as_ref(&self) -> &PgConnection { self.conn.as_ref().expect(NONE_ERR).as_ref() @@ -398,7 +396,7 @@ impl<'lock, C: AsMut + AsRef> AsRef /// However, replacing the connection with a different one using, e.g. [`std::mem::replace()`] /// is a logic error and will cause a warning to be logged by the PostgreSQL server when this /// guard attempts to release the lock. -impl<'lock, C: AsMut> AsMut for PgAdvisoryLockGuard<'lock, C> { +impl> AsMut for PgAdvisoryLockGuard<'_, C> { fn as_mut(&mut self) -> &mut PgConnection { self.conn.as_mut().expect(NONE_ERR).as_mut() } @@ -407,7 +405,7 @@ impl<'lock, C: AsMut> AsMut for PgAdvisoryLockGuard< /// Queues a `pg_advisory_unlock()` call on the wrapped connection which will be flushed /// to the server the next time it is used, or when it is returned to [`PgPool`][crate::PgPool] /// in the case of [`PoolConnection`][crate::pool::PoolConnection]. -impl<'lock, C: AsMut> Drop for PgAdvisoryLockGuard<'lock, C> { +impl> Drop for PgAdvisoryLockGuard<'_, C> { fn drop(&mut self) { if let Some(mut conn) = self.conn.take() { // Queue a simple query message to execute next time the connection is used. diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index 73d4678aa6..ab76f4b1d1 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -147,10 +147,7 @@ impl AnyConnectionBackend for PgConnection { }) } - fn describe<'c, 'q>( - &'q mut self, - sql: SqlStr, - ) -> BoxFuture<'q, sqlx_core::Result>> { + fn describe<'c>(&mut self, sql: SqlStr) -> BoxFuture<'_, sqlx_core::Result>> { Box::pin(async move { let describe = Executor::describe(self, sql).await?; diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs index 2b6e24cd34..c1ee2a2307 100644 --- a/sqlx-postgres/src/connection/executor.rs +++ b/sqlx-postgres/src/connection/executor.rs @@ -160,7 +160,7 @@ impl PgConnection { self.inner.pending_ready_for_query_count += 1; } - async fn get_or_prepare<'a>( + async fn get_or_prepare( &mut self, sql: &str, parameters: &[PgTypeInfo], diff --git a/sqlx-postgres/src/connection/tls.rs b/sqlx-postgres/src/connection/tls.rs index 16b7333bf5..a49c9caa8c 100644 --- a/sqlx-postgres/src/connection/tls.rs +++ b/sqlx-postgres/src/connection/tls.rs @@ -7,7 +7,7 @@ use crate::{PgConnectOptions, PgSslMode}; pub struct MaybeUpgradeTls<'a>(pub &'a PgConnectOptions); -impl<'a> WithSocket for MaybeUpgradeTls<'a> { +impl WithSocket for MaybeUpgradeTls<'_> { type Output = crate::Result>; async fn with_socket(self, socket: S) -> Self::Output { diff --git a/sqlx-postgres/src/message/response.rs b/sqlx-postgres/src/message/response.rs index d6e43e0871..a7c09cfa34 100644 --- a/sqlx-postgres/src/message/response.rs +++ b/sqlx-postgres/src/message/response.rs @@ -195,7 +195,7 @@ struct Fields<'a> { offset: usize, } -impl<'a> Iterator for Fields<'a> { +impl Iterator for Fields<'_> { type Item = (u8, Range); fn next(&mut self) -> Option { diff --git a/sqlx-postgres/src/types/cube.rs b/sqlx-postgres/src/types/cube.rs index cc2a016090..d7ddbd1723 100644 --- a/sqlx-postgres/src/types/cube.rs +++ b/sqlx-postgres/src/types/cube.rs @@ -71,7 +71,7 @@ impl<'r> Decode<'r, Postgres> for PgCube { } } -impl<'q> Encode<'q, Postgres> for PgCube { +impl Encode<'_, Postgres> for PgCube { fn produces(&self) -> Option { Some(PgTypeInfo::with_name("cube")) } diff --git a/sqlx-postgres/src/types/geometry/box.rs b/sqlx-postgres/src/types/geometry/box.rs index 28016b2786..ad4fa39ef7 100644 --- a/sqlx-postgres/src/types/geometry/box.rs +++ b/sqlx-postgres/src/types/geometry/box.rs @@ -56,7 +56,7 @@ impl<'r> Decode<'r, Postgres> for PgBox { } } -impl<'q> Encode<'q, Postgres> for PgBox { +impl Encode<'_, Postgres> for PgBox { fn produces(&self) -> Option { Some(PgTypeInfo::with_name("box")) } diff --git a/sqlx-postgres/src/types/geometry/line.rs b/sqlx-postgres/src/types/geometry/line.rs index 8f08c949ef..6bc90676ed 100644 --- a/sqlx-postgres/src/types/geometry/line.rs +++ b/sqlx-postgres/src/types/geometry/line.rs @@ -47,7 +47,7 @@ impl<'r> Decode<'r, Postgres> for PgLine { } } -impl<'q> Encode<'q, Postgres> for PgLine { +impl Encode<'_, Postgres> for PgLine { fn produces(&self) -> Option { Some(PgTypeInfo::with_name("line")) } diff --git a/sqlx-postgres/src/types/geometry/line_segment.rs b/sqlx-postgres/src/types/geometry/line_segment.rs index cd08e4da4a..486d2ba07d 100644 --- a/sqlx-postgres/src/types/geometry/line_segment.rs +++ b/sqlx-postgres/src/types/geometry/line_segment.rs @@ -57,7 +57,7 @@ impl<'r> Decode<'r, Postgres> for PgLSeg { } } -impl<'q> Encode<'q, Postgres> for PgLSeg { +impl Encode<'_, Postgres> for PgLSeg { fn produces(&self) -> Option { Some(PgTypeInfo::with_name("lseg")) } diff --git a/sqlx-postgres/src/types/geometry/point.rs b/sqlx-postgres/src/types/geometry/point.rs index 83b7c24d0d..6f264e85ea 100644 --- a/sqlx-postgres/src/types/geometry/point.rs +++ b/sqlx-postgres/src/types/geometry/point.rs @@ -50,7 +50,7 @@ impl<'r> Decode<'r, Postgres> for PgPoint { } } -impl<'q> Encode<'q, Postgres> for PgPoint { +impl Encode<'_, Postgres> for PgPoint { fn produces(&self) -> Option { Some(PgTypeInfo::with_name("point")) } diff --git a/sqlx-postgres/src/types/json.rs b/sqlx-postgres/src/types/json.rs index 567e48015e..32f886c781 100644 --- a/sqlx-postgres/src/types/json.rs +++ b/sqlx-postgres/src/types/json.rs @@ -54,7 +54,7 @@ impl PgHasArrayType for JsonRawValue { } } -impl<'q, T> Encode<'q, Postgres> for Json +impl Encode<'_, Postgres> for Json where T: Serialize, { diff --git a/sqlx-postgres/src/types/text.rs b/sqlx-postgres/src/types/text.rs index b5b0a5ed7b..12d92d4b2a 100644 --- a/sqlx-postgres/src/types/text.rs +++ b/sqlx-postgres/src/types/text.rs @@ -18,7 +18,7 @@ impl Type for Text { } } -impl<'q, T> Encode<'q, Postgres> for Text +impl Encode<'_, Postgres> for Text where T: Display, { diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index 2345451090..0f3c64f4f5 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -143,7 +143,7 @@ impl AnyConnectionBackend for SqliteConnection { }) } - fn describe<'q>(&'q mut self, sql: SqlStr) -> BoxFuture<'q, sqlx_core::Result>> { + fn describe(&mut self, sql: SqlStr) -> BoxFuture<'_, sqlx_core::Result>> { Box::pin(async move { Executor::describe(self, sql).await?.try_into_any() }) } } diff --git a/sqlx-sqlite/src/connection/intmap.rs b/sqlx-sqlite/src/connection/intmap.rs index dc09162f64..105c0d2b0e 100644 --- a/sqlx-sqlite/src/connection/intmap.rs +++ b/sqlx-sqlite/src/connection/intmap.rs @@ -103,7 +103,7 @@ impl IntMap { *item = Some(V::default()); } - return self.0[idx].as_mut().unwrap(); + self.0[idx].as_mut().unwrap() } } diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index 24b8de2cd0..62c737ba03 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -145,12 +145,11 @@ impl ConnectionWorker { let _guard = span.enter(); match cmd { Command::Prepare { query, tx } => { - tx.send(prepare(&mut conn, query).map(|prepared| { + tx.send(prepare(&mut conn, query).inspect(|_| { update_cached_statements_size( &conn, &shared.cached_statements_size, ); - prepared })) .ok(); } diff --git a/sqlx-sqlite/src/logger.rs b/sqlx-sqlite/src/logger.rs index 3abed8cebc..1464a730c7 100644 --- a/sqlx-sqlite/src/logger.rs +++ b/sqlx-sqlite/src/logger.rs @@ -436,7 +436,7 @@ impl<'q, R: Debug, S: Debug + DebugDiff, P: Debug> QueryPlanLogger<'q, R, S, P> } } -impl<'q, R: Debug, S: Debug + DebugDiff, P: Debug> Drop for QueryPlanLogger<'q, R, S, P> { +impl Drop for QueryPlanLogger<'_, R, S, P> { fn drop(&mut self) { self.finish(); } diff --git a/sqlx-sqlite/src/value.rs b/sqlx-sqlite/src/value.rs index 469c4e70d5..dc40f29ccb 100644 --- a/sqlx-sqlite/src/value.rs +++ b/sqlx-sqlite/src/value.rs @@ -108,8 +108,8 @@ pub(crate) struct ValueHandle<'a> { } // SAFE: only protected value objects are stored in SqliteValue -unsafe impl<'a> Send for ValueHandle<'a> {} -unsafe impl<'a> Sync for ValueHandle<'a> {} +unsafe impl Send for ValueHandle<'_> {} +unsafe impl Sync for ValueHandle<'_> {} impl ValueHandle<'static> { fn new_owned(value: NonNull, type_info: SqliteTypeInfo) -> Self { @@ -122,7 +122,7 @@ impl ValueHandle<'static> { } } -impl<'a> ValueHandle<'a> { +impl ValueHandle<'_> { fn new_borrowed(value: NonNull, type_info: SqliteTypeInfo) -> Self { Self { value, @@ -185,7 +185,7 @@ impl<'a> ValueHandle<'a> { } } -impl<'a> Drop for ValueHandle<'a> { +impl Drop for ValueHandle<'_> { fn drop(&mut self) { if self.free_on_drop { unsafe { From a0b6739432ef9e430c163b6d2d3d8cb567b32ed4 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Mon, 17 Feb 2025 21:52:04 +0100 Subject: [PATCH 12/12] fix doc test --- sqlx-core/src/query.rs | 2 +- sqlx-core/src/sql_str.rs | 11 ++++++++++- sqlx-core/src/transaction.rs | 1 + sqlx-mysql/src/transaction.rs | 5 +++-- sqlx-postgres/src/any.rs | 3 +-- sqlx-postgres/src/transaction.rs | 5 +++-- sqlx-sqlite/src/connection/worker.rs | 7 ++++--- 7 files changed, 23 insertions(+), 11 deletions(-) diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index fba07417b1..fbdb09263f 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -558,7 +558,7 @@ where /// let query = format!("SELECT * FROM articles WHERE content LIKE '%{user_input}%'"); /// // where `conn` is `PgConnection` or `MySqlConnection` /// // or some other type that implements `Executor`. -/// let results = sqlx::query(&query).fetch_all(&mut conn).await?; +/// let results = sqlx::query(sqlx::AssertSqlSafe(query)).fetch_all(&mut conn).await?; /// # Ok(()) /// # } /// ``` diff --git a/sqlx-core/src/sql_str.rs b/sqlx-core/src/sql_str.rs index c5a870712d..fb43a07453 100644 --- a/sqlx-core/src/sql_str.rs +++ b/sqlx-core/src/sql_str.rs @@ -1,4 +1,4 @@ -use std::borrow::Borrow; +use std::borrow::{Borrow, Cow}; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -102,6 +102,15 @@ impl SqlSafeStr for AssertSqlSafe> { } } +impl SqlSafeStr for AssertSqlSafe> { + fn into_sql_str(self) -> SqlStr { + match self.0 { + Cow::Borrowed(str) => str.into_sql_str(), + Cow::Owned(str) => AssertSqlSafe(str).into_sql_str(), + } + } +} + /// A SQL string that is ready to execute on a database connection. /// /// This is essentially `Cow<'static, str>` but which can be constructed from additional types diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index 3ad0ada9f6..7d45dd2b98 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use std::ops::{Deref, DerefMut}; diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index 85c7c56b48..3bed4c030d 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -1,6 +1,7 @@ use std::borrow::Cow; use futures_core::future::BoxFuture; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; use crate::connection::Waiting; use crate::error::Error; @@ -27,10 +28,10 @@ impl TransactionManager for MySqlTransactionManager { // custom `BEGIN` statements are not allowed if we're already in a transaction // (we need to issue a `SAVEPOINT` instead) Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), - Some(statement) => statement, + Some(statement) => AssertSqlSafe(statement).into_sql_str(), None => begin_ansi_transaction_sql(depth), }; - conn.execute(&*statement).await?; + conn.execute(statement).await?; if !conn.in_transaction() { return Err(Error::BeginFailed); } diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index ab76f4b1d1..aaeff02ba1 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -5,9 +5,8 @@ use crate::{ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::{stream, StreamExt, TryFutureExt, TryStreamExt}; -use std::borrow::Cow; -use std::{future, pin::pin}; use sqlx_core::sql_str::SqlStr; +use std::borrow::Cow; use std::{future, pin::pin}; use sqlx_core::any::{ diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index b3f5c20f9e..69cc3c9a1a 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -1,5 +1,6 @@ use futures_core::future::BoxFuture; use sqlx_core::database::Database; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; use std::borrow::Cow; use crate::error::Error; @@ -25,13 +26,13 @@ impl TransactionManager for PgTransactionManager { // custom `BEGIN` statements are not allowed if we're already in // a transaction (we need to issue a `SAVEPOINT` instead) Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), - Some(statement) => statement, + Some(statement) => AssertSqlSafe(statement).into_sql_str(), None => begin_ansi_transaction_sql(depth), }; let rollback = Rollback::new(conn); - rollback.conn.queue_simple_query(&statement)?; + rollback.conn.queue_simple_query(statement.as_str())?; rollback.conn.wait_until_ready().await?; if !rollback.conn.in_transaction() { return Err(Error::BeginFailed); diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index 62c737ba03..1c534e8e4c 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::future::Future; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -5,7 +6,7 @@ use std::thread; use futures_channel::oneshot; use futures_intrusive::sync::{Mutex, MutexGuard}; -use sqlx_core::sql_str::SqlStr; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr}; use tracing::span::Span; use sqlx_core::describe::Describe; @@ -219,12 +220,12 @@ impl ConnectionWorker { } continue; }, - Some(statement) => statement, + Some(statement) => AssertSqlSafe(statement).into_sql_str(), None => begin_ansi_transaction_sql(depth), }; let res = conn.handle - .exec(begin_ansi_transaction_sql(depth).as_str()) + .exec(statement.as_str()) .map(|_| { shared.transaction_depth.fetch_add(1, Ordering::Release); });