diff --git a/sqlx-cli/src/database.rs b/sqlx-cli/src/database.rs index 7a9bc6bf2f..48dff16838 100644 --- a/sqlx-cli/src/database.rs +++ b/sqlx-cli/src/database.rs @@ -1,3 +1,4 @@ +#![allow(unexpected_cfgs)] use crate::migrate; use crate::opt::ConnectOpts; use console::{style, Term}; diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs index 07058aa147..dbc57f890e 100644 --- a/sqlx-cli/src/opt.rs +++ b/sqlx-cli/src/opt.rs @@ -1,3 +1,4 @@ +#![allow(unexpected_cfgs)] use std::ops::{Deref, Not}; use clap::{Args, Parser}; diff --git a/sqlx-core/src/any/arguments.rs b/sqlx-core/src/any/arguments.rs index 59a0c0f765..2c05e3fd5b 100644 --- a/sqlx-core/src/any/arguments.rs +++ b/sqlx-core/src/any/arguments.rs @@ -32,7 +32,7 @@ impl<'q> Arguments<'q> for AnyArguments<'q> { pub struct AnyArgumentBuffer<'q>(#[doc(hidden)] pub Vec>); -impl<'q> Default for AnyArguments<'q> { +impl Default for AnyArguments<'_> { fn default() -> Self { AnyArguments { values: AnyArgumentBuffer(vec![]), diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index 6c84c1d8ce..7575219d38 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -1,5 +1,6 @@ use crate::any::{Any, AnyArguments, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo}; use crate::describe::Describe; +use crate::sql_str::SqlStr; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; @@ -96,23 +97,23 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { fn fetch_many<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxStream<'q, crate::Result>>; fn fetch_optional<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxFuture<'q, crate::Result>>; fn prepare_with<'c, 'q: 'c>( &'c mut self, - sql: &'q str, + sql: SqlStr, parameters: &[AnyTypeInfo], - ) -> BoxFuture<'c, crate::Result>>; + ) -> BoxFuture<'c, crate::Result>; - fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, crate::Result>>; + fn describe(&mut self, sql: SqlStr) -> BoxFuture<'_, crate::Result>>; } diff --git a/sqlx-core/src/any/connection/executor.rs b/sqlx-core/src/any/connection/executor.rs index ccf6dd7933..2f147f65bc 100644 --- a/sqlx-core/src/any/connection/executor.rs +++ b/sqlx-core/src/any/connection/executor.rs @@ -2,6 +2,7 @@ use crate::any::{Any, AnyConnection, AnyQueryResult, AnyRow, AnyStatement, AnyTy use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; +use crate::sql_str::SqlSafeStr; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; @@ -23,8 +24,8 @@ impl<'c> Executor<'c> for &'c mut AnyConnection { Ok(arguments) => arguments, Err(error) => return stream::once(future::ready(Err(error))).boxed(), }; - self.backend - .fetch_many(query.sql(), query.persistent(), arguments) + let persistent = query.persistent(); + self.backend.fetch_many(query.sql(), persistent, arguments) } fn fetch_optional<'e, 'q: 'e, E>( @@ -39,28 +40,29 @@ impl<'c> Executor<'c> for &'c mut AnyConnection { Ok(arguments) => arguments, Err(error) => return future::ready(Err(error)).boxed(), }; + let persistent = query.persistent(); self.backend - .fetch_optional(query.sql(), query.persistent(), arguments) + .fetch_optional(query.sql(), persistent, arguments) } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, parameters: &[AnyTypeInfo], - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result> where 'c: 'e, { - self.backend.prepare_with(sql, parameters) + self.backend.prepare_with(sql.into_sql_str(), parameters) } - fn describe<'e, 'q: 'e>( + fn describe<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { - self.backend.describe(sql) + self.backend.describe(sql.into_sql_str()) } } diff --git a/sqlx-core/src/any/database.rs b/sqlx-core/src/any/database.rs index 9c3f15bb1f..6e8343e928 100644 --- a/sqlx-core/src/any/database.rs +++ b/sqlx-core/src/any/database.rs @@ -28,7 +28,7 @@ impl Database for Any { type Arguments<'q> = AnyArguments<'q>; type ArgumentBuffer<'q> = AnyArgumentBuffer<'q>; - type Statement<'q> = AnyStatement<'q>; + type Statement = AnyStatement; const NAME: &'static str = "Any"; diff --git a/sqlx-core/src/any/row.rs b/sqlx-core/src/any/row.rs index 310881da14..57b8590b5f 100644 --- a/sqlx-core/src/any/row.rs +++ b/sqlx-core/src/any/row.rs @@ -63,7 +63,7 @@ impl Row for AnyRow { } } -impl<'i> ColumnIndex for &'i str { +impl ColumnIndex for &'_ str { fn index(&self, row: &AnyRow) -> Result { row.column_names .get(*self) diff --git a/sqlx-core/src/any/statement.rs b/sqlx-core/src/any/statement.rs index 1fbb11895c..6fa979743e 100644 --- a/sqlx-core/src/any/statement.rs +++ b/sqlx-core/src/any/statement.rs @@ -3,15 +3,15 @@ use crate::column::ColumnIndex; use crate::database::Database; use crate::error::Error; use crate::ext::ustr::UStr; +use crate::sql_str::SqlStr; use crate::statement::Statement; use crate::HashMap; use either::Either; -use std::borrow::Cow; use std::sync::Arc; -pub struct AnyStatement<'q> { +pub struct AnyStatement { #[doc(hidden)] - pub sql: Cow<'q, str>, + pub sql: SqlStr, #[doc(hidden)] pub parameters: Option, usize>>, #[doc(hidden)] @@ -20,20 +20,24 @@ pub struct AnyStatement<'q> { pub columns: Vec, } -impl<'q> Statement<'q> for AnyStatement<'q> { +impl Statement for AnyStatement { type Database = Any; - fn to_owned(&self) -> AnyStatement<'static> { - AnyStatement::<'static> { - sql: Cow::Owned(self.sql.clone().into_owned()), + fn to_owned(&self) -> AnyStatement { + AnyStatement { + sql: self.sql.clone(), column_names: self.column_names.clone(), parameters: self.parameters.clone(), columns: self.columns.clone(), } } - fn sql(&self) -> &str { - &self.sql + fn sql_cloned(&self) -> SqlStr { + self.sql.clone() + } + + fn into_sql(self) -> SqlStr { + self.sql } fn parameters(&self) -> Option> { @@ -51,8 +55,8 @@ impl<'q> Statement<'q> for AnyStatement<'q> { impl_statement_query!(AnyArguments<'_>); } -impl<'i> ColumnIndex> for &'i str { - fn index(&self, statement: &AnyStatement<'_>) -> Result { +impl ColumnIndex for &'_ str { + fn index(&self, statement: &AnyStatement) -> Result { statement .column_names .get(*self) @@ -61,15 +65,14 @@ impl<'i> ColumnIndex> for &'i str { } } -impl<'q> AnyStatement<'q> { +impl AnyStatement { #[doc(hidden)] pub fn try_from_statement( - query: &'q str, - statement: &S, + statement: S, column_names: Arc>, ) -> crate::Result where - S: Statement<'q>, + S: Statement, AnyTypeInfo: for<'a> TryFrom<&'a ::TypeInfo, Error = Error>, AnyColumn: for<'a> TryFrom<&'a ::Column, Error = Error>, { @@ -91,7 +94,7 @@ impl<'q> AnyStatement<'q> { .collect::, _>>()?; Ok(Self { - sql: query.into(), + sql: statement.into_sql(), columns, column_names, parameters, diff --git a/sqlx-core/src/column.rs b/sqlx-core/src/column.rs index 9f45819ed6..ebcd2d6829 100644 --- a/sqlx-core/src/column.rs +++ b/sqlx-core/src/column.rs @@ -69,8 +69,8 @@ macro_rules! impl_column_index_for_row { #[macro_export] macro_rules! impl_column_index_for_statement { ($S:ident) => { - impl $crate::column::ColumnIndex<$S<'_>> for usize { - fn index(&self, statement: &$S<'_>) -> Result { + impl $crate::column::ColumnIndex<$S> for usize { + fn index(&self, statement: &$S) -> Result { let len = $crate::statement::Statement::columns(statement).len(); if *self >= len { diff --git a/sqlx-core/src/database.rs b/sqlx-core/src/database.rs index e44c3d88ac..02d7a1214e 100644 --- a/sqlx-core/src/database.rs +++ b/sqlx-core/src/database.rs @@ -101,7 +101,7 @@ pub trait Database: 'static + Sized + Send + Debug { type ArgumentBuffer<'q>; /// The concrete `Statement` implementation for this database. - type Statement<'q>: Statement<'q, Database = Self>; + type Statement: Statement; /// The display name for this database driver. const NAME: &'static str; diff --git a/sqlx-core/src/executor.rs b/sqlx-core/src/executor.rs index 84b1a660d8..a66c127789 100644 --- a/sqlx-core/src/executor.rs +++ b/sqlx-core/src/executor.rs @@ -1,6 +1,7 @@ use crate::database::Database; use crate::describe::Describe; use crate::error::{BoxDynError, Error}; +use crate::sql_str::{SqlSafeStr, SqlStr}; use either::Either; use futures_core::future::BoxFuture; @@ -146,10 +147,10 @@ pub trait Executor<'c>: Send + Debug + Sized { /// This explicit API is provided to allow access to the statement metadata available after /// it prepared but before the first row is returned. #[inline] - fn prepare<'e, 'q: 'e>( + fn prepare<'e>( self, - query: &'q str, - ) -> BoxFuture<'e, Result<::Statement<'q>, Error>> + query: impl SqlSafeStr, + ) -> BoxFuture<'e, Result<::Statement, Error>> where 'c: 'e, { @@ -161,11 +162,11 @@ pub trait Executor<'c>: Send + Debug + Sized { /// /// Only some database drivers (PostgreSQL, MSSQL) can take advantage of /// this extra information to influence parameter type inference. - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, parameters: &'e [::TypeInfo], - ) -> BoxFuture<'e, Result<::Statement<'q>, Error>> + ) -> BoxFuture<'e, Result<::Statement, Error>> where 'c: 'e; @@ -175,9 +176,9 @@ pub trait Executor<'c>: Send + Debug + Sized { /// This is used by compile-time verification in the query macros to /// power their type inference. #[doc(hidden)] - fn describe<'e, 'q: 'e>( + fn describe<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, ) -> BoxFuture<'e, Result, Error>> where 'c: 'e; @@ -192,10 +193,10 @@ pub trait Executor<'c>: Send + Debug + Sized { /// pub trait Execute<'q, DB: Database>: Send + Sized { /// Gets the SQL that will be executed. - fn sql(&self) -> &'q str; + fn sql(self) -> SqlStr; /// Gets the previously cached statement, if available. - fn statement(&self) -> Option<&DB::Statement<'q>>; + fn statement(&self) -> Option<&DB::Statement>; /// Returns the arguments to be bound against the query string. /// @@ -210,22 +211,23 @@ pub trait Execute<'q, DB: Database>: Send + Sized { fn persistent(&self) -> bool; } -// NOTE: `Execute` is explicitly not implemented for String and &String to make it slightly more -// involved to write `conn.execute(format!("SELECT {val}"))` -impl<'q, DB: Database> Execute<'q, DB> for &'q str { +impl<'q, DB: Database, T> Execute<'q, DB> for (T, Option<::Arguments<'q>>) +where + T: SqlSafeStr + Send, +{ #[inline] - fn sql(&self) -> &'q str { - self + fn sql(self) -> SqlStr { + self.0.into_sql_str() } #[inline] - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { None } #[inline] fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { - Ok(None) + Ok(self.1.take()) } #[inline] @@ -234,20 +236,23 @@ impl<'q, DB: Database> Execute<'q, DB> for &'q str { } } -impl<'q, DB: Database> Execute<'q, DB> for (&'q str, Option<::Arguments<'q>>) { +impl<'q, DB: Database, T> Execute<'q, DB> for T +where + T: SqlSafeStr + Send, +{ #[inline] - fn sql(&self) -> &'q str { - self.0 + fn sql(self) -> SqlStr { + self.into_sql_str() } #[inline] - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { None } #[inline] fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { - Ok(self.1.take()) + Ok(None) } #[inline] diff --git a/sqlx-core/src/ext/async_stream.rs b/sqlx-core/src/ext/async_stream.rs index 56777ca4db..c41d940981 100644 --- a/sqlx-core/src/ext/async_stream.rs +++ b/sqlx-core/src/ext/async_stream.rs @@ -95,7 +95,7 @@ impl Yielder { } } -impl<'a, T> Stream for TryAsyncStream<'a, T> { +impl Stream for TryAsyncStream<'_, T> { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/sqlx-core/src/io/encode.rs b/sqlx-core/src/io/encode.rs index a603ea9325..ba032d294d 100644 --- a/sqlx-core/src/io/encode.rs +++ b/sqlx-core/src/io/encode.rs @@ -9,7 +9,7 @@ pub trait ProtocolEncode<'en, Context = ()> { fn encode_with(&self, buf: &mut Vec, context: Context) -> Result<(), crate::Error>; } -impl<'en, C> ProtocolEncode<'en, C> for &'_ [u8] { +impl ProtocolEncode<'_, C> for &'_ [u8] { fn encode_with(&self, buf: &mut Vec, _context: C) -> Result<(), crate::Error> { buf.extend_from_slice(self); Ok(()) diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index df4b2cc27d..4671c16708 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -72,6 +72,7 @@ pub mod net; pub mod query_as; pub mod query_builder; pub mod query_scalar; +pub mod sql_str; pub mod raw_sql; pub mod row; diff --git a/sqlx-core/src/logger.rs b/sqlx-core/src/logger.rs index cf6dd533bd..7cbbff0077 100644 --- a/sqlx-core/src/logger.rs +++ b/sqlx-core/src/logger.rs @@ -1,4 +1,4 @@ -use crate::connection::LogSettings; +use crate::{connection::LogSettings, sql_str::SqlStr}; use std::time::Instant; // Yes these look silly. `tracing` doesn't currently support dynamic levels @@ -60,16 +60,16 @@ pub(crate) fn private_level_filter_to_trace_level( private_level_filter_to_levels(filter).map(|(level, _)| level) } -pub struct QueryLogger<'q> { - sql: &'q str, +pub struct QueryLogger { + sql: SqlStr, rows_returned: u64, rows_affected: u64, start: Instant, settings: LogSettings, } -impl<'q> QueryLogger<'q> { - pub fn new(sql: &'q str, settings: LogSettings) -> Self { +impl QueryLogger { + pub fn new(sql: SqlStr, settings: LogSettings) -> Self { Self { sql, rows_returned: 0, @@ -104,18 +104,18 @@ impl<'q> QueryLogger<'q> { let log_is_enabled = log::log_enabled!(target: "sqlx::query", log_level) || private_tracing_dynamic_enabled!(target: "sqlx::query", tracing_level); if log_is_enabled { - let mut summary = parse_query_summary(self.sql); + let mut summary = parse_query_summary(self.sql.as_str()); - let sql = if summary != self.sql { + let sql = if summary != self.sql.as_str() { summary.push_str(" …"); format!( "\n\n{}\n", - self.sql /* - sqlformat::format( - self.sql, - &sqlformat::QueryParams::None, - sqlformat::FormatOptions::default() - )*/ + self.sql.as_str() /* + sqlformat::format( + self.sql, + &sqlformat::QueryParams::None, + sqlformat::FormatOptions::default() + )*/ ) } else { String::new() @@ -158,7 +158,7 @@ impl<'q> QueryLogger<'q> { } } -impl<'q> Drop for QueryLogger<'q> { +impl Drop for QueryLogger { fn drop(&mut self) { self.finish(); } diff --git a/sqlx-core/src/migrate/migration.rs b/sqlx-core/src/migrate/migration.rs index 9bd7f569d8..394a968fd4 100644 --- a/sqlx-core/src/migrate/migration.rs +++ b/sqlx-core/src/migrate/migration.rs @@ -9,6 +9,7 @@ pub struct Migration { pub version: i64, pub description: Cow<'static, str>, pub migration_type: MigrationType, + // We can't use `SqlStr` here because it can't be used in a const context pub sql: Cow<'static, str>, pub checksum: Cow<'static, [u8]>, pub no_tx: bool, diff --git a/sqlx-core/src/net/socket/mod.rs b/sqlx-core/src/net/socket/mod.rs index d11f15884e..1f24da8c40 100644 --- a/sqlx-core/src/net/socket/mod.rs +++ b/sqlx-core/src/net/socket/mod.rs @@ -62,7 +62,7 @@ pub struct Read<'a, S: ?Sized, B> { buf: &'a mut B, } -impl<'a, S: ?Sized, B> Future for Read<'a, S, B> +impl Future for Read<'_, S, B> where S: Socket, B: ReadBuf, @@ -90,7 +90,7 @@ pub struct Write<'a, S: ?Sized> { buf: &'a [u8], } -impl<'a, S: ?Sized> Future for Write<'a, S> +impl Future for Write<'_, S> where S: Socket, { @@ -116,7 +116,7 @@ pub struct Flush<'a, S: ?Sized> { socket: &'a mut S, } -impl<'a, S: Socket + ?Sized> Future for Flush<'a, S> { +impl Future for Flush<'_, S> { type Output = io::Result<()>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -128,7 +128,7 @@ pub struct Shutdown<'a, S: ?Sized> { socket: &'a mut S, } -impl<'a, S: ?Sized> Future for Shutdown<'a, S> +impl Future for Shutdown<'_, S> where S: Socket, { diff --git a/sqlx-core/src/pool/executor.rs b/sqlx-core/src/pool/executor.rs index ba27b44316..f0c1a31601 100644 --- a/sqlx-core/src/pool/executor.rs +++ b/sqlx-core/src/pool/executor.rs @@ -8,6 +8,7 @@ use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::pool::Pool; +use crate::sql_str::SqlSafeStr; impl<'p, DB: Database> Executor<'p> for &'_ Pool where @@ -48,22 +49,30 @@ where Box::pin(async move { pool.acquire().await?.fetch_optional(query).await }) } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, parameters: &'e [::TypeInfo], - ) -> BoxFuture<'e, Result<::Statement<'q>, Error>> { + ) -> BoxFuture<'e, Result<::Statement, Error>> + where + 'p: 'e, + { let pool = self.clone(); + let sql = sql.into_sql_str(); Box::pin(async move { pool.acquire().await?.prepare_with(sql, parameters).await }) } #[doc(hidden)] - fn describe<'e, 'q: 'e>( + fn describe<'e>( self, - sql: &'q str, - ) -> BoxFuture<'e, Result, Error>> { + sql: impl SqlSafeStr, + ) -> BoxFuture<'e, Result, Error>> + where + 'p: 'e, + { let pool = self.clone(); + let sql = sql.into_sql_str(); Box::pin(async move { pool.acquire().await?.describe(sql).await }) } diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index 2066364a8e..51dd4b0055 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -452,14 +452,14 @@ pub(super) fn is_beyond_max_lifetime( ) -> bool { options .max_lifetime - .map_or(false, |max| live.created_at.elapsed() > max) + .is_some_and(|max| live.created_at.elapsed() > max) } /// Returns `true` if the connection has exceeded `options.idle_timeout` if set, `false` otherwise. fn is_beyond_idle_timeout(idle: &Idle, options: &PoolOptions) -> bool { options .idle_timeout - .map_or(false, |timeout| idle.idle_since.elapsed() > timeout) + .is_some_and(|timeout| idle.idle_since.elapsed() > timeout) } async fn check_idle_conn( diff --git a/sqlx-core/src/pool/maybe.rs b/sqlx-core/src/pool/maybe.rs index f9f16c41a5..71a48728a2 100644 --- a/sqlx-core/src/pool/maybe.rs +++ b/sqlx-core/src/pool/maybe.rs @@ -8,7 +8,7 @@ pub enum MaybePoolConnection<'c, DB: Database> { PoolConnection(PoolConnection), } -impl<'c, DB: Database> Deref for MaybePoolConnection<'c, DB> { +impl Deref for MaybePoolConnection<'_, DB> { type Target = DB::Connection; #[inline] @@ -20,7 +20,7 @@ impl<'c, DB: Database> Deref for MaybePoolConnection<'c, DB> { } } -impl<'c, DB: Database> DerefMut for MaybePoolConnection<'c, DB> { +impl DerefMut for MaybePoolConnection<'_, DB> { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { match self { @@ -30,7 +30,7 @@ impl<'c, DB: Database> DerefMut for MaybePoolConnection<'c, DB> { } } -impl<'c, DB: Database> From> for MaybePoolConnection<'c, DB> { +impl From> for MaybePoolConnection<'_, DB> { fn from(v: PoolConnection) -> Self { MaybePoolConnection::PoolConnection(v) } diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index 60f509c342..fbdb09263f 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -9,13 +9,14 @@ use crate::database::{Database, HasStatementCache}; use crate::encode::Encode; use crate::error::{BoxDynError, Error}; use crate::executor::{Execute, Executor}; +use crate::sql_str::{SqlSafeStr, SqlStr}; use crate::statement::Statement; use crate::types::Type; /// A single SQL query as a prepared statement. Returned by [`query()`]. #[must_use = "query must be executed to affect database"] pub struct Query<'q, DB: Database, A> { - pub(crate) statement: Either<&'q str, &'q DB::Statement<'q>>, + pub(crate) statement: Either, pub(crate) arguments: Option>, pub(crate) database: PhantomData, pub(crate) persistent: bool, @@ -44,14 +45,14 @@ where A: Send + IntoArguments<'q, DB>, { #[inline] - fn sql(&self) -> &'q str { + fn sql(self) -> SqlStr { match self.statement { - Either::Right(statement) => statement.sql(), + Either::Right(statement) => statement.sql_cloned(), Either::Left(sql) => sql, } } - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { match self.statement { Either::Right(statement) => Some(statement), Either::Left(_) => None, @@ -120,7 +121,7 @@ impl<'q, DB: Database> Query<'q, DB, ::Arguments<'q>> { } } -impl<'q, DB, A> Query<'q, DB, A> +impl Query<'_, DB, A> where DB: Database + HasStatementCache, { @@ -303,12 +304,12 @@ where A: IntoArguments<'q, DB>, { #[inline] - fn sql(&self) -> &'q str { + fn sql(self) -> SqlStr { self.inner.sql() } #[inline] - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { self.inner.statement() } @@ -497,9 +498,9 @@ where } /// Execute a single SQL query as a prepared statement (explicitly created). -pub fn query_statement<'q, DB>( - statement: &'q DB::Statement<'q>, -) -> Query<'q, DB, ::Arguments<'_>> +pub fn query_statement( + statement: &DB::Statement, +) -> Query<'_, DB, ::Arguments<'_>> where DB: Database, { @@ -513,7 +514,7 @@ where /// Execute a single SQL query as a prepared statement (explicitly created), with the given arguments. pub fn query_statement_with<'q, DB, A>( - statement: &'q DB::Statement<'q>, + statement: &'q DB::Statement, arguments: A, ) -> Query<'q, DB, A> where @@ -557,7 +558,7 @@ where /// let query = format!("SELECT * FROM articles WHERE content LIKE '%{user_input}%'"); /// // where `conn` is `PgConnection` or `MySqlConnection` /// // or some other type that implements `Executor`. -/// let results = sqlx::query(&query).fetch_all(&mut conn).await?; +/// let results = sqlx::query(sqlx::AssertSqlSafe(query)).fetch_all(&mut conn).await?; /// # Ok(()) /// # } /// ``` @@ -652,14 +653,14 @@ where /// /// As an additional benefit, query parameters are usually sent in a compact binary encoding instead of a human-readable /// text encoding, which saves bandwidth. -pub fn query(sql: &str) -> Query<'_, DB, ::Arguments<'_>> +pub fn query<'a, DB>(sql: impl SqlSafeStr) -> Query<'a, DB, ::Arguments<'a>> where DB: Database, { Query { database: PhantomData, arguments: Some(Ok(Default::default())), - statement: Either::Left(sql), + statement: Either::Left(sql.into_sql_str()), persistent: true, } } @@ -667,7 +668,7 @@ where /// Execute a SQL query as a prepared statement (transparently cached), with the given arguments. /// /// See [`query()`][query] for details, such as supported syntax. -pub fn query_with<'q, DB, A>(sql: &'q str, arguments: A) -> Query<'q, DB, A> +pub fn query_with<'q, DB, A>(sql: impl SqlSafeStr, arguments: A) -> Query<'q, DB, A> where DB: Database, A: IntoArguments<'q, DB>, @@ -677,7 +678,7 @@ where /// Same as [`query_with`] but is initialized with a Result of arguments instead pub fn query_with_result<'q, DB, A>( - sql: &'q str, + sql: impl SqlSafeStr, arguments: Result, ) -> Query<'q, DB, A> where @@ -687,7 +688,7 @@ where Query { database: PhantomData, arguments: Some(arguments), - statement: Either::Left(sql), + statement: Either::Left(sql.into_sql_str()), persistent: true, } } diff --git a/sqlx-core/src/query_as.rs b/sqlx-core/src/query_as.rs index 9f28fe41e9..a9a035a82c 100644 --- a/sqlx-core/src/query_as.rs +++ b/sqlx-core/src/query_as.rs @@ -11,6 +11,7 @@ use crate::error::{BoxDynError, Error}; use crate::executor::{Execute, Executor}; use crate::from_row::FromRow; use crate::query::{query, query_statement, query_statement_with, query_with_result, Query}; +use crate::sql_str::{SqlSafeStr, SqlStr}; use crate::types::Type; /// A single SQL query as a prepared statement, mapping results using [`FromRow`]. @@ -27,12 +28,12 @@ where A: 'q + IntoArguments<'q, DB>, { #[inline] - fn sql(&self) -> &'q str { + fn sql(self) -> SqlStr { self.inner.sql() } #[inline] - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { self.inner.statement() } @@ -57,7 +58,7 @@ impl<'q, DB: Database, O> QueryAs<'q, DB, O, ::Arguments<'q>> { } } -impl<'q, DB, O, A> QueryAs<'q, DB, O, A> +impl QueryAs<'_, DB, O, A> where DB: Database + HasStatementCache, { @@ -339,7 +340,9 @@ where /// /// ``` #[inline] -pub fn query_as<'q, DB, O>(sql: &'q str) -> QueryAs<'q, DB, O, ::Arguments<'q>> +pub fn query_as<'q, DB, O>( + sql: impl SqlSafeStr, +) -> QueryAs<'q, DB, O, ::Arguments<'q>> where DB: Database, O: for<'r> FromRow<'r, DB::Row>, @@ -357,7 +360,7 @@ where /// /// For details about type mapping from [`FromRow`], see [`query_as()`]. #[inline] -pub fn query_as_with<'q, DB, O, A>(sql: &'q str, arguments: A) -> QueryAs<'q, DB, O, A> +pub fn query_as_with<'q, DB, O, A>(sql: impl SqlSafeStr, arguments: A) -> QueryAs<'q, DB, O, A> where DB: Database, A: IntoArguments<'q, DB>, @@ -369,7 +372,7 @@ where /// Same as [`query_as_with`] but takes arguments as a Result #[inline] pub fn query_as_with_result<'q, DB, O, A>( - sql: &'q str, + sql: impl SqlSafeStr, arguments: Result, ) -> QueryAs<'q, DB, O, A> where @@ -384,9 +387,9 @@ where } // Make a SQL query from a statement, that is mapped to a concrete type. -pub fn query_statement_as<'q, DB, O>( - statement: &'q DB::Statement<'q>, -) -> QueryAs<'q, DB, O, ::Arguments<'_>> +pub fn query_statement_as( + statement: &DB::Statement, +) -> QueryAs<'_, DB, O, ::Arguments<'_>> where DB: Database, O: for<'r> FromRow<'r, DB::Row>, @@ -399,7 +402,7 @@ where // Make a SQL query from a statement, with the given arguments, that is mapped to a concrete type. pub fn query_statement_as_with<'q, DB, O, A>( - statement: &'q DB::Statement<'q>, + statement: &'q DB::Statement, arguments: A, ) -> QueryAs<'q, DB, O, A> where diff --git a/sqlx-core/src/query_builder.rs b/sqlx-core/src/query_builder.rs index b242bf7b2a..b4eed071f8 100644 --- a/sqlx-core/src/query_builder.rs +++ b/sqlx-core/src/query_builder.rs @@ -3,6 +3,7 @@ use std::fmt::Display; use std::fmt::Write; use std::marker::PhantomData; +use std::sync::Arc; use crate::arguments::{Arguments, IntoArguments}; use crate::database::Database; @@ -11,6 +12,8 @@ use crate::from_row::FromRow; use crate::query::Query; use crate::query_as::QueryAs; use crate::query_scalar::QueryScalar; +use crate::sql_str::AssertSqlSafe; +use crate::sql_str::SqlSafeStr; use crate::types::Type; use crate::Either; @@ -25,21 +28,23 @@ pub struct QueryBuilder<'args, DB> where DB: Database, { - query: String, + query: Arc, init_len: usize, arguments: Option<::Arguments<'args>>, } -impl<'args, DB: Database> Default for QueryBuilder<'args, DB> { +impl Default for QueryBuilder<'_, DB> { fn default() -> Self { QueryBuilder { init_len: 0, - query: String::default(), + query: String::default().into(), arguments: Some(Default::default()), } } } +const ERROR: &str = "BUG: query must not be shared at this point in time"; + impl<'args, DB: Database> QueryBuilder<'args, DB> where DB: Database, @@ -55,7 +60,7 @@ where QueryBuilder { init_len: init.len(), - query: init, + query: init.into(), arguments: Some(Default::default()), } } @@ -73,7 +78,7 @@ where QueryBuilder { init_len: init.len(), - query: init, + query: init.into(), arguments: Some(arguments.into_arguments()), } } @@ -115,8 +120,9 @@ where /// e.g. check that strings aren't too long, numbers are within expected ranges, etc. pub fn push(&mut self, sql: impl Display) -> &mut Self { self.sanity_check(); + let query: &mut String = Arc::get_mut(&mut self.query).expect(ERROR); - write!(self.query, "{sql}").expect("error formatting `sql`"); + write!(query, "{sql}").expect("error formatting `sql`"); self } @@ -157,8 +163,9 @@ where .expect("BUG: Arguments taken already"); arguments.add(value).expect("Failed to add argument"); + let query: &mut String = Arc::get_mut(&mut self.query).expect(ERROR); arguments - .format_placeholder(&mut self.query) + .format_placeholder(query) .expect("error in format_placeholder"); self @@ -191,7 +198,6 @@ where /// assert!(sql.ends_with("in (?, ?) ")); /// # } /// ``` - pub fn separated<'qb, Sep>(&'qb mut self, separator: Sep) -> Separated<'qb, 'args, DB, Sep> where 'args: 'qb, @@ -454,7 +460,7 @@ where self.sanity_check(); Query { - statement: Either::Left(&self.query), + statement: Either::Left(AssertSqlSafe(self.query.clone()).into_sql_str()), arguments: self.arguments.take().map(Ok), database: PhantomData, persistent: true, @@ -511,7 +517,8 @@ where /// The query is truncated to the initial fragment provided to [`new()`][Self::new] and /// the bind arguments are reset. pub fn reset(&mut self) -> &mut Self { - self.query.truncate(self.init_len); + let query: &mut String = Arc::get_mut(&mut self.query).expect(ERROR); + query.truncate(self.init_len); self.arguments = Some(Default::default()); self @@ -524,7 +531,7 @@ where /// Deconstruct this `QueryBuilder`, returning the built SQL. May not be syntactically correct. pub fn into_sql(self) -> String { - self.query + Arc::into_inner(self.query).unwrap() } } diff --git a/sqlx-core/src/query_scalar.rs b/sqlx-core/src/query_scalar.rs index c131adcca3..1059463874 100644 --- a/sqlx-core/src/query_scalar.rs +++ b/sqlx-core/src/query_scalar.rs @@ -11,6 +11,7 @@ use crate::from_row::FromRow; use crate::query_as::{ query_as, query_as_with_result, query_statement_as, query_statement_as_with, QueryAs, }; +use crate::sql_str::{SqlSafeStr, SqlStr}; use crate::types::Type; /// A single SQL query as a prepared statement which extracts only the first column of each row. @@ -25,11 +26,11 @@ where A: 'q + IntoArguments<'q, DB>, { #[inline] - fn sql(&self) -> &'q str { + fn sql(self) -> SqlStr { self.inner.sql() } - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { self.inner.statement() } @@ -54,7 +55,7 @@ impl<'q, DB: Database, O> QueryScalar<'q, DB, O, ::Arguments<'q> } } -impl<'q, DB, O, A> QueryScalar<'q, DB, O, A> +impl QueryScalar<'_, DB, O, A> where DB: Database + HasStatementCache, { @@ -319,7 +320,7 @@ where /// ``` #[inline] pub fn query_scalar<'q, DB, O>( - sql: &'q str, + sql: impl SqlSafeStr, ) -> QueryScalar<'q, DB, O, ::Arguments<'q>> where DB: Database, @@ -337,7 +338,10 @@ where /// /// For details about prepared statements and allowed SQL syntax, see [`query()`][crate::query::query]. #[inline] -pub fn query_scalar_with<'q, DB, O, A>(sql: &'q str, arguments: A) -> QueryScalar<'q, DB, O, A> +pub fn query_scalar_with<'q, DB, O, A>( + sql: impl SqlSafeStr, + arguments: A, +) -> QueryScalar<'q, DB, O, A> where DB: Database, A: IntoArguments<'q, DB>, @@ -349,7 +353,7 @@ where /// Same as [`query_scalar_with`] but takes arguments as Result #[inline] pub fn query_scalar_with_result<'q, DB, O, A>( - sql: &'q str, + sql: impl SqlSafeStr, arguments: Result, ) -> QueryScalar<'q, DB, O, A> where @@ -363,9 +367,9 @@ where } // Make a SQL query from a statement, that is mapped to a concrete value. -pub fn query_statement_scalar<'q, DB, O>( - statement: &'q DB::Statement<'q>, -) -> QueryScalar<'q, DB, O, ::Arguments<'_>> +pub fn query_statement_scalar( + statement: &DB::Statement, +) -> QueryScalar<'_, DB, O, ::Arguments<'_>> where DB: Database, (O,): for<'r> FromRow<'r, DB::Row>, @@ -377,7 +381,7 @@ where // Make a SQL query from a statement, with the given arguments, that is mapped to a concrete value. pub fn query_statement_scalar_with<'q, DB, O, A>( - statement: &'q DB::Statement<'q>, + statement: &'q DB::Statement, arguments: A, ) -> QueryScalar<'q, DB, O, A> where diff --git a/sqlx-core/src/raw_sql.rs b/sqlx-core/src/raw_sql.rs index 37627d4453..43a7ec920a 100644 --- a/sqlx-core/src/raw_sql.rs +++ b/sqlx-core/src/raw_sql.rs @@ -4,6 +4,7 @@ use futures_core::stream::BoxStream; use crate::database::Database; use crate::error::BoxDynError; use crate::executor::{Execute, Executor}; +use crate::sql_str::{SqlSafeStr, SqlStr}; use crate::Error; // AUTHOR'S NOTE: I was just going to call this API `sql()` and `Sql`, respectively, @@ -15,7 +16,7 @@ use crate::Error; /// One or more raw SQL statements, separated by semicolons (`;`). /// /// See [`raw_sql()`] for details. -pub struct RawSql<'q>(&'q str); +pub struct RawSql(SqlStr); /// Execute one or more statements as raw SQL, separated by semicolons (`;`). /// @@ -114,16 +115,16 @@ pub struct RawSql<'q>(&'q str); /// /// See [MySQL manual, section 13.3.3: Statements That Cause an Implicit Commit](https://dev.mysql.com/doc/refman/8.0/en/implicit-commit.html) for details. /// See also: [MariaDB manual: SQL statements That Cause an Implicit Commit](https://mariadb.com/kb/en/sql-statements-that-cause-an-implicit-commit/). -pub fn raw_sql(sql: &str) -> RawSql<'_> { - RawSql(sql) +pub fn raw_sql(sql: impl SqlSafeStr) -> RawSql { + RawSql(sql.into_sql_str()) } -impl<'q, DB: Database> Execute<'q, DB> for RawSql<'q> { - fn sql(&self) -> &'q str { +impl<'q, DB: Database> Execute<'q, DB> for RawSql { + fn sql(self) -> SqlStr { self.0 } - fn statement(&self) -> Option<&::Statement<'q>> { + fn statement(&self) -> Option<&::Statement> { None } @@ -136,7 +137,7 @@ impl<'q, DB: Database> Execute<'q, DB> for RawSql<'q> { } } -impl<'q> RawSql<'q> { +impl RawSql { /// Execute the SQL string and return the total number of rows affected. #[inline] pub async fn execute<'e, E>( @@ -144,7 +145,6 @@ impl<'q> RawSql<'q> { executor: E, ) -> crate::Result<::QueryResult> where - 'q: 'e, E: Executor<'e>, { executor.execute(self).await @@ -157,7 +157,6 @@ impl<'q> RawSql<'q> { executor: E, ) -> BoxStream<'e, crate::Result<::QueryResult>> where - 'q: 'e, E: Executor<'e>, { executor.execute_many(self) @@ -172,7 +171,6 @@ impl<'q> RawSql<'q> { executor: E, ) -> BoxStream<'e, Result<::Row, Error>> where - 'q: 'e, E: Executor<'e>, { executor.fetch(self) @@ -194,7 +192,6 @@ impl<'q> RawSql<'q> { >, > where - 'q: 'e, E: Executor<'e>, { executor.fetch_many(self) @@ -213,7 +210,6 @@ impl<'q> RawSql<'q> { executor: E, ) -> crate::Result::Row>> where - 'q: 'e, E: Executor<'e>, { executor.fetch_all(self).await @@ -237,7 +233,6 @@ impl<'q> RawSql<'q> { executor: E, ) -> crate::Result<::Row> where - 'q: 'e, E: Executor<'e>, { executor.fetch_one(self).await @@ -261,7 +256,6 @@ impl<'q> RawSql<'q> { executor: E, ) -> crate::Result<::Row> where - 'q: 'e, E: Executor<'e>, { executor.fetch_one(self).await diff --git a/sqlx-core/src/rt/mod.rs b/sqlx-core/src/rt/mod.rs index 43409073ab..210b296233 100644 --- a/sqlx-core/src/rt/mod.rs +++ b/sqlx-core/src/rt/mod.rs @@ -26,6 +26,7 @@ pub enum JoinHandle { pub async fn timeout(duration: Duration, f: F) -> Result { #[cfg(feature = "_rt-tokio")] if rt_tokio::available() { + #[allow(clippy::needless_return)] return tokio::time::timeout(duration, f) .await .map_err(|_| TimeoutError(())); @@ -116,6 +117,7 @@ pub async fn yield_now() { pub fn test_block_on(f: F) -> F::Output { #[cfg(feature = "_rt-tokio")] { + #[allow(clippy::needless_return)] return tokio::runtime::Builder::new_current_thread() .enable_all() .build() diff --git a/sqlx-core/src/sql_str.rs b/sqlx-core/src/sql_str.rs new file mode 100644 index 0000000000..fb43a07453 --- /dev/null +++ b/sqlx-core/src/sql_str.rs @@ -0,0 +1,196 @@ +use std::borrow::{Borrow, Cow}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +/// A SQL string that is safe to execute on a database connection. +/// +/// A "safe" SQL string is one that is unlikely to contain a [SQL injection vulnerability][injection]. +/// +/// In practice, this means a string type that is unlikely to contain dynamic data or user input. +/// +/// `&'static str` is the only string type that satisfies the requirements of this trait +/// (ignoring [`String::leak()`] which has niche use-cases) and so is the only string type that +/// natively implements this trait by default. +/// +/// For other string types, use [`AssertSqlSafe`] to assert this property. +/// This is the only intended way to pass an owned `String` to [`query()`] and its related functions +/// as well as [`raw_sql()`]. +/// +/// The maintainers of SQLx take no responsibility for any data leaks or loss resulting from misuse +/// of this API. +/// +/// ### Motivation +/// This is designed to act as a speed bump against naively using `format!()` to add dynamic data +/// or user input to a query, which is a classic vector for SQL injection as SQLx does not +/// provide any sort of escaping or sanitization (which would have to be specially implemented +/// for each database flavor/locale). +/// +/// The recommended way to incorporate dynamic data or user input in a query is to use +/// bind parameters, which requires the query to execute as a prepared statement. +/// See [`query()`] for details. +/// +/// This trait and [`AssertSqlSafe`] are intentionally analogous to +/// [`std::panic::UnwindSafe`] and [`std::panic::AssertUnwindSafe`], respectively. +/// +/// [injection]: https://en.wikipedia.org/wiki/SQL_injection +/// [`query()`]: crate::query::query +/// [`raw_sql()`]: crate::raw_sql::raw_sql +pub trait SqlSafeStr { + /// Convert `self` to a [`SqlStr`]. + fn into_sql_str(self) -> SqlStr; +} + +impl SqlSafeStr for &'static str { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Static(self)) + } +} + +/// Assert that a query string is safe to execute on a database connection. +/// +/// Using this API means that **you** have made sure that the string contents do not contain a +/// [SQL injection vulnerability][injection]. It means that, if the string was constructed +/// dynamically, and/or from user input, you have taken care to sanitize the input yourself. +/// SQLx does not provide any sort of sanitization; the design of SQLx prefers the use +/// of prepared statements for dynamic input. +/// +/// The maintainers of SQLx take no responsibility for any data leaks or loss resulting from misuse +/// of this API. **Use at your own risk.** +/// +/// Note that `&'static str` implements [`SqlSafeStr`] directly and so does not need to be wrapped +/// with this type. +/// +/// [injection]: https://en.wikipedia.org/wiki/SQL_injection +pub struct AssertSqlSafe(pub T); + +/// Note: copies the string. +/// +/// It is recommended to pass one of the supported owned string types instead. +impl SqlSafeStr for AssertSqlSafe<&str> { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Arced(self.0.into())) + } +} +impl SqlSafeStr for AssertSqlSafe { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Owned(self.0)) + } +} + +impl SqlSafeStr for AssertSqlSafe> { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Boxed(self.0)) + } +} + +// Note: this is not implemented for `Rc` because it would make `QueryString: !Send`. +impl SqlSafeStr for AssertSqlSafe> { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Arced(self.0)) + } +} + +impl SqlSafeStr for AssertSqlSafe> { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::ArcString(self.0)) + } +} + +impl SqlSafeStr for AssertSqlSafe> { + fn into_sql_str(self) -> SqlStr { + match self.0 { + Cow::Borrowed(str) => str.into_sql_str(), + Cow::Owned(str) => AssertSqlSafe(str).into_sql_str(), + } + } +} + +/// A SQL string that is ready to execute on a database connection. +/// +/// This is essentially `Cow<'static, str>` but which can be constructed from additional types +/// without copying. +/// +/// See [`SqlSafeStr`] for details. +#[derive(Debug)] +pub struct SqlStr(Repr); + +#[derive(Debug)] +enum Repr { + /// We need a variant to memoize when we already have a static string, so we don't copy it. + Static(&'static str), + /// Thanks to the new niche in `String`, this doesn't increase the size beyond 3 words. + /// We essentially get all these variants for free. + Owned(String), + Boxed(Box), + Arced(Arc), + /// Allows for dynamic shared ownership with `query_builder`. + ArcString(Arc), +} + +impl Clone for SqlStr { + fn clone(&self) -> Self { + Self(match &self.0 { + Repr::Static(s) => Repr::Static(s), + Repr::Arced(s) => Repr::Arced(s.clone()), + _ => Repr::Arced(self.as_str().into()), + }) + } +} + +impl SqlSafeStr for SqlStr { + #[inline] + fn into_sql_str(self) -> SqlStr { + self + } +} + +impl SqlStr { + /// Borrow the inner query string. + #[inline] + pub fn as_str(&self) -> &str { + match &self.0 { + Repr::Static(s) => s, + Repr::Owned(s) => s, + Repr::Boxed(s) => s, + Repr::Arced(s) => s, + Repr::ArcString(s) => s, + } + } +} + +impl AsRef for SqlStr { + #[inline] + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl Borrow for SqlStr { + #[inline] + fn borrow(&self) -> &str { + self.as_str() + } +} + +impl PartialEq for SqlStr +where + T: AsRef, +{ + fn eq(&self, other: &T) -> bool { + self.as_str() == other.as_ref() + } +} + +impl Eq for SqlStr {} + +impl Hash for SqlStr { + fn hash(&self, state: &mut H) { + self.as_str().hash(state) + } +} diff --git a/sqlx-core/src/statement.rs b/sqlx-core/src/statement.rs index 17dfd6428d..5173d1b191 100644 --- a/sqlx-core/src/statement.rs +++ b/sqlx-core/src/statement.rs @@ -6,6 +6,7 @@ use crate::from_row::FromRow; use crate::query::Query; use crate::query_as::QueryAs; use crate::query_scalar::QueryScalar; +use crate::sql_str::SqlStr; use either::Either; /// An explicitly prepared statement. @@ -16,15 +17,17 @@ use either::Either; /// /// Statements can be re-used with any connection and on first-use it will be re-prepared and /// cached within the connection. -pub trait Statement<'q>: Send + Sync { +pub trait Statement: Send + Sync { type Database: Database; /// Creates an owned statement from this statement reference. This copies /// the original SQL text. - fn to_owned(&self) -> ::Statement<'static>; + fn to_owned(&self) -> ::Statement; /// Get the original SQL text used to create this statement. - fn sql(&self) -> &str; + fn sql_cloned(&self) -> SqlStr; + + fn into_sql(self) -> SqlStr; /// Get the expected parameters for this statement. /// diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index 2a84ff6555..7d45dd2b98 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -7,6 +7,7 @@ use futures_core::future::BoxFuture; use crate::database::Database; use crate::error::Error; use crate::pool::MaybePoolConnection; +use crate::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr}; /// Generic management of database transactions. /// @@ -199,7 +200,7 @@ where // } // } -impl<'c, DB> Debug for Transaction<'c, DB> +impl Debug for Transaction<'_, DB> where DB: Database, { @@ -209,7 +210,7 @@ where } } -impl<'c, DB> Deref for Transaction<'c, DB> +impl Deref for Transaction<'_, DB> where DB: Database, { @@ -221,7 +222,7 @@ where } } -impl<'c, DB> DerefMut for Transaction<'c, DB> +impl DerefMut for Transaction<'_, DB> where DB: Database, { @@ -235,13 +236,13 @@ where // `PgAdvisoryLockGuard`. // // See: https://github.com/launchbadge/sqlx/issues/2520 -impl<'c, DB: Database> AsMut for Transaction<'c, DB> { +impl AsMut for Transaction<'_, DB> { fn as_mut(&mut self) -> &mut DB::Connection { &mut self.connection } } -impl<'c, 't, DB: Database> crate::acquire::Acquire<'t> for &'t mut Transaction<'c, DB> { +impl<'t, DB: Database> crate::acquire::Acquire<'t> for &'t mut Transaction<'_, DB> { type Database = DB; type Connection = &'t mut ::Connection; @@ -257,7 +258,7 @@ impl<'c, 't, DB: Database> crate::acquire::Acquire<'t> for &'t mut Transaction<' } } -impl<'c, DB> Drop for Transaction<'c, DB> +impl Drop for Transaction<'_, DB> where DB: Database, { @@ -274,29 +275,30 @@ where } } -pub fn begin_ansi_transaction_sql(depth: usize) -> Cow<'static, str> { +pub fn begin_ansi_transaction_sql(depth: usize) -> SqlStr { if depth == 0 { - Cow::Borrowed("BEGIN") + "BEGIN".into_sql_str() } else { - Cow::Owned(format!("SAVEPOINT _sqlx_savepoint_{depth}")) + AssertSqlSafe(format!("SAVEPOINT _sqlx_savepoint_{depth}")).into_sql_str() } } -pub fn commit_ansi_transaction_sql(depth: usize) -> Cow<'static, str> { +pub fn commit_ansi_transaction_sql(depth: usize) -> SqlStr { if depth == 1 { - Cow::Borrowed("COMMIT") + "COMMIT".into_sql_str() } else { - Cow::Owned(format!("RELEASE SAVEPOINT _sqlx_savepoint_{}", depth - 1)) + AssertSqlSafe(format!("RELEASE SAVEPOINT _sqlx_savepoint_{}", depth - 1)).into_sql_str() } } -pub fn rollback_ansi_transaction_sql(depth: usize) -> Cow<'static, str> { +pub fn rollback_ansi_transaction_sql(depth: usize) -> SqlStr { if depth == 1 { - Cow::Borrowed("ROLLBACK") + "ROLLBACK".into_sql_str() } else { - Cow::Owned(format!( + AssertSqlSafe(format!( "ROLLBACK TO SAVEPOINT _sqlx_savepoint_{}", depth - 1 )) + .into_sql_str() } } diff --git a/sqlx-core/src/type_checking.rs b/sqlx-core/src/type_checking.rs index 5766124530..1da6b7ab3f 100644 --- a/sqlx-core/src/type_checking.rs +++ b/sqlx-core/src/type_checking.rs @@ -112,7 +112,7 @@ where } } -impl<'v, DB> Debug for FmtValue<'v, DB> +impl Debug for FmtValue<'_, DB> where DB: Database, { diff --git a/sqlx-macros-core/src/database/mod.rs b/sqlx-macros-core/src/database/mod.rs index a2d0a1fa0d..90c2048386 100644 --- a/sqlx-macros-core/src/database/mod.rs +++ b/sqlx-macros-core/src/database/mod.rs @@ -8,6 +8,8 @@ use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; +use sqlx_core::sql_str::AssertSqlSafe; +use sqlx_core::sql_str::SqlSafeStr; use sqlx_core::type_checking::TypeChecking; #[cfg(any(feature = "postgres", feature = "mysql", feature = "_sqlite"))] @@ -58,7 +60,8 @@ impl CachingDescribeBlocking { } }; - conn.describe(query).await + conn.describe(AssertSqlSafe(query.to_string()).into_sql_str()) + .await }) } } diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index 19b3a6f27c..89afdce474 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -15,6 +15,7 @@ use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; +use sqlx_core::sql_str::SqlStr; use sqlx_core::transaction::TransactionManager; use std::borrow::Cow; use std::{future, pin::pin}; @@ -82,7 +83,7 @@ impl AnyConnectionBackend for MySqlConnection { fn fetch_many<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxStream<'q, sqlx_core::Result>> { @@ -108,7 +109,7 @@ impl AnyConnectionBackend for MySqlConnection { fn fetch_optional<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxFuture<'q, sqlx_core::Result>> { @@ -135,20 +136,17 @@ impl AnyConnectionBackend for MySqlConnection { fn prepare_with<'c, 'q: 'c>( &'c mut self, - sql: &'q str, + sql: SqlStr, _parameters: &[AnyTypeInfo], - ) -> BoxFuture<'c, sqlx_core::Result>> { + ) -> BoxFuture<'c, sqlx_core::Result> { Box::pin(async move { let statement = Executor::prepare_with(self, sql, &[]).await?; - AnyStatement::try_from_statement( - sql, - &statement, - statement.metadata.column_names.clone(), - ) + let column_names = statement.metadata.column_names.clone(); + AnyStatement::try_from_statement(statement, column_names) }) } - fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, sqlx_core::Result>> { + fn describe(&mut self, sql: SqlStr) -> BoxFuture<'_, sqlx_core::Result>> { Box::pin(async move { let describe = Executor::describe(self, sql).await?; describe.try_into_any() diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index 85a9d84f96..ec7d8e4c2c 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -186,7 +186,7 @@ impl<'a> DoHandshake<'a> { } } -impl<'a> WithSocket for DoHandshake<'a> { +impl WithSocket for DoHandshake<'_> { type Output = Result; async fn with_socket(self, socket: S) -> Self::Output { diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index 4f5af4bf6d..bba63a3e07 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -22,10 +22,11 @@ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::TryStreamExt; -use std::{borrow::Cow, pin::pin, sync::Arc}; +use sqlx_core::sql_str::{SqlSafeStr, SqlStr}; +use std::{pin::pin, sync::Arc}; impl MySqlConnection { - async fn prepare_statement<'c>( + async fn prepare_statement( &mut self, sql: &str, ) -> Result<(u32, MySqlStatementMetadata), Error> { @@ -72,7 +73,7 @@ impl MySqlConnection { Ok((id, metadata)) } - async fn get_or_prepare_statement<'c>( + async fn get_or_prepare_statement( &mut self, sql: &str, ) -> Result<(u32, MySqlStatementMetadata), Error> { @@ -101,13 +102,11 @@ impl MySqlConnection { #[allow(clippy::needless_lifetimes)] pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>( &'c mut self, - sql: &'q str, + sql: SqlStr, arguments: Option, persistent: bool, ) -> Result, Error>> + 'e, Error> { - let mut logger = QueryLogger::new(sql, self.inner.log_settings.clone()); - self.inner.stream.wait_until_ready().await?; self.inner.stream.waiting.push_back(Waiting::Result); @@ -120,7 +119,7 @@ impl MySqlConnection { let (mut column_names, format, mut needs_metadata) = if let Some(arguments) = arguments { if persistent && self.inner.cache_statement.is_enabled() { let (id, metadata) = self - .get_or_prepare_statement(sql) + .get_or_prepare_statement(sql.as_str()) .await?; // https://dev.mysql.com/doc/internals/en/com-stmt-execute.html @@ -134,7 +133,7 @@ impl MySqlConnection { (metadata.column_names, MySqlValueFormat::Binary, false) } else { let (id, metadata) = self - .prepare_statement(sql) + .prepare_statement(sql.as_str()) .await?; // https://dev.mysql.com/doc/internals/en/com-stmt-execute.html @@ -151,10 +150,11 @@ impl MySqlConnection { } } else { // https://dev.mysql.com/doc/internals/en/com-query.html - self.inner.stream.send_packet(Query(sql)).await?; + self.inner.stream.send_packet(Query(sql.as_str())).await?; (Arc::default(), MySqlValueFormat::Text, true) }; + let mut logger = QueryLogger::new(sql, self.inner.log_settings.clone()); loop { // query response is a meta-packet which may be one of: @@ -261,11 +261,11 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { 'q: 'e, E: 'q, { - let sql = query.sql(); let arguments = query.take_arguments().map_err(Error::Encode); let persistent = query.persistent(); Box::pin(try_stream! { + let sql = query.sql(); let arguments = arguments?; let mut s = pin!(self.run(sql, arguments, persistent).await?); @@ -297,21 +297,22 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { }) } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, _parameters: &'e [MySqlTypeInfo], - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result> where 'c: 'e, { + let sql = sql.into_sql_str(); Box::pin(async move { self.inner.stream.wait_until_ready().await?; let metadata = if self.inner.cache_statement.is_enabled() { - self.get_or_prepare_statement(sql).await?.1 + self.get_or_prepare_statement(sql.as_str()).await?.1 } else { - let (id, metadata) = self.prepare_statement(sql).await?; + let (id, metadata) = self.prepare_statement(sql.as_str()).await?; self.inner .stream @@ -322,7 +323,7 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { }; Ok(MySqlStatement { - sql: Cow::Borrowed(sql), + sql, // metadata has internal Arcs for expensive data structures metadata: metadata.clone(), }) @@ -330,14 +331,15 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { } #[doc(hidden)] - fn describe<'e, 'q: 'e>(self, sql: &'q str) -> BoxFuture<'e, Result, Error>> + fn describe<'e>(self, sql: impl SqlSafeStr) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { + let sql = sql.into_sql_str(); Box::pin(async move { self.inner.stream.wait_until_ready().await?; - let (id, metadata) = self.prepare_statement(sql).await?; + let (id, metadata) = self.prepare_statement(sql.as_str()).await?; self.inner .stream diff --git a/sqlx-mysql/src/database.rs b/sqlx-mysql/src/database.rs index d03a567284..0e3f51f532 100644 --- a/sqlx-mysql/src/database.rs +++ b/sqlx-mysql/src/database.rs @@ -28,7 +28,7 @@ impl Database for MySql { type Arguments<'q> = MySqlArguments; type ArgumentBuffer<'q> = Vec; - type Statement<'q> = MySqlStatement<'q>; + type Statement = MySqlStatement; const NAME: &'static str = "MySQL"; diff --git a/sqlx-mysql/src/migrate.rs b/sqlx-mysql/src/migrate.rs index 79b55ace3c..40c83c9fa2 100644 --- a/sqlx-mysql/src/migrate.rs +++ b/sqlx-mysql/src/migrate.rs @@ -4,6 +4,7 @@ use std::time::Instant; use futures_core::future::BoxFuture; pub(crate) use sqlx_core::migrate::*; +use sqlx_core::sql_str::AssertSqlSafe; use crate::connection::{ConnectOptions, Connection}; use crate::error::Error; @@ -37,7 +38,7 @@ impl MigrateDatabase for MySql { let mut conn = options.connect().await?; let _ = conn - .execute(&*format!("CREATE DATABASE `{database}`")) + .execute(AssertSqlSafe(format!("CREATE DATABASE `{database}`"))) .await?; Ok(()) @@ -66,7 +67,9 @@ impl MigrateDatabase for MySql { let mut conn = options.connect().await?; let _ = conn - .execute(&*format!("DROP DATABASE IF EXISTS `{database}`")) + .execute(AssertSqlSafe(format!( + "DROP DATABASE IF EXISTS `{database}`" + ))) .await?; Ok(()) @@ -200,7 +203,8 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( .await?; let _ = tx - .execute(&*migration.sql) + // We can't use `SqlStr` in `Migration` because it can't be used in a const context + .execute(AssertSqlSafe(migration.sql.to_string())) .await .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; @@ -269,7 +273,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( .execute(&mut *tx) .await?; - tx.execute(&*migration.sql).await?; + tx.execute(AssertSqlSafe(migration.sql.to_string())).await?; // language=SQL let _ = query(r#"DELETE FROM _sqlx_migrations WHERE version = ?"#) diff --git a/sqlx-mysql/src/options/connect.rs b/sqlx-mysql/src/options/connect.rs index 116a49ccad..d3f14d64f5 100644 --- a/sqlx-mysql/src/options/connect.rs +++ b/sqlx-mysql/src/options/connect.rs @@ -4,6 +4,7 @@ use crate::executor::Executor; use crate::{MySqlConnectOptions, MySqlConnection}; use futures_core::future::BoxFuture; use log::LevelFilter; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_core::Url; use std::time::Duration; @@ -77,7 +78,7 @@ impl ConnectOptions for MySqlConnectOptions { } if !options.is_empty() { - conn.execute(&*format!(r#"SET {};"#, options.join(","))) + conn.execute(AssertSqlSafe(format!(r#"SET {};"#, options.join(",")))) .await?; } diff --git a/sqlx-mysql/src/protocol/statement/execute.rs b/sqlx-mysql/src/protocol/statement/execute.rs index 6e51e7b564..89010315bb 100644 --- a/sqlx-mysql/src/protocol/statement/execute.rs +++ b/sqlx-mysql/src/protocol/statement/execute.rs @@ -11,7 +11,7 @@ pub struct Execute<'q> { pub arguments: &'q MySqlArguments, } -impl<'q> ProtocolEncode<'_, Capabilities> for Execute<'q> { +impl ProtocolEncode<'_, Capabilities> for Execute<'_> { fn encode_with(&self, buf: &mut Vec, _: Capabilities) -> Result<(), crate::Error> { buf.push(0x17); // COM_STMT_EXECUTE buf.extend(&self.statement.to_le_bytes()); diff --git a/sqlx-mysql/src/statement.rs b/sqlx-mysql/src/statement.rs index e9578403e1..e6b0065961 100644 --- a/sqlx-mysql/src/statement.rs +++ b/sqlx-mysql/src/statement.rs @@ -5,14 +5,14 @@ use crate::ext::ustr::UStr; use crate::HashMap; use crate::{MySql, MySqlArguments, MySqlTypeInfo}; use either::Either; -use std::borrow::Cow; +use sqlx_core::sql_str::SqlStr; use std::sync::Arc; pub(crate) use sqlx_core::statement::*; #[derive(Debug, Clone)] -pub struct MySqlStatement<'q> { - pub(crate) sql: Cow<'q, str>, +pub struct MySqlStatement { + pub(crate) sql: SqlStr, pub(crate) metadata: MySqlStatementMetadata, } @@ -23,18 +23,22 @@ pub(crate) struct MySqlStatementMetadata { pub(crate) parameters: usize, } -impl<'q> Statement<'q> for MySqlStatement<'q> { +impl Statement for MySqlStatement { type Database = MySql; - fn to_owned(&self) -> MySqlStatement<'static> { - MySqlStatement::<'static> { - sql: Cow::Owned(self.sql.clone().into_owned()), + fn to_owned(&self) -> MySqlStatement { + MySqlStatement { + sql: self.sql.clone(), metadata: self.metadata.clone(), } } - fn sql(&self) -> &str { - &self.sql + fn sql_cloned(&self) -> SqlStr { + self.sql.clone() + } + + fn into_sql(self) -> SqlStr { + self.sql } fn parameters(&self) -> Option> { @@ -48,8 +52,8 @@ impl<'q> Statement<'q> for MySqlStatement<'q> { impl_statement_query!(MySqlArguments); } -impl ColumnIndex> for &'_ str { - fn index(&self, statement: &MySqlStatement<'_>) -> Result { +impl ColumnIndex for &'_ str { + fn index(&self, statement: &MySqlStatement) -> Result { statement .metadata .column_names diff --git a/sqlx-mysql/src/testing/mod.rs b/sqlx-mysql/src/testing/mod.rs index 1981cf73c5..d9a612a79a 100644 --- a/sqlx-mysql/src/testing/mod.rs +++ b/sqlx-mysql/src/testing/mod.rs @@ -1,5 +1,6 @@ use std::ops::Deref; use std::str::FromStr; +use std::sync::Arc; use std::time::Duration; use futures_core::future::BoxFuture; @@ -8,6 +9,7 @@ use once_cell::sync::OnceCell; use sqlx_core::connection::Connection; use sqlx_core::query_builder::QueryBuilder; use sqlx_core::query_scalar::query_scalar; +use sqlx_core::sql_str::AssertSqlSafe; use std::fmt::Write; use crate::error::Error; @@ -55,15 +57,16 @@ impl TestSupport for MySql { let mut deleted_db_names = Vec::with_capacity(delete_db_names.len()); - let mut command = String::new(); + let mut command_arced = Arc::new(String::new()); for db_name in &delete_db_names { + let command = Arc::get_mut(&mut command_arced).unwrap(); command.clear(); let db_name = format!("_sqlx_test_database_{db_name}"); writeln!(command, "drop database if exists {db_name:?};").ok(); - match conn.execute(&*command).await { + match conn.execute(AssertSqlSafe(command_arced.clone())).await { Ok(_deleted) => { deleted_db_names.push(db_name); } @@ -162,7 +165,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { .execute(&mut *conn) .await?; - conn.execute(&format!("create database {db_name:?}")[..]) + conn.execute(AssertSqlSafe(format!("create database {db_name:?}"))) .await?; eprintln!("created database {db_name}"); @@ -187,7 +190,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { async fn do_cleanup(conn: &mut MySqlConnection, db_name: &str) -> Result<(), Error> { let delete_db_command = format!("drop database if exists {db_name:?};"); - conn.execute(&*delete_db_command).await?; + conn.execute(AssertSqlSafe(delete_db_command)).await?; query("delete from _sqlx_test.databases where db_name = $1::text") .bind(db_name) .execute(&mut *conn) diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index 545cb5f4f2..3bed4c030d 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -1,6 +1,7 @@ use std::borrow::Cow; use futures_core::future::BoxFuture; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; use crate::connection::Waiting; use crate::error::Error; @@ -22,14 +23,15 @@ impl TransactionManager for MySqlTransactionManager { ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { let depth = conn.inner.transaction_depth; + let statement = match statement { // custom `BEGIN` statements are not allowed if we're already in a transaction // (we need to issue a `SAVEPOINT` instead) Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), - Some(statement) => statement, + Some(statement) => AssertSqlSafe(statement).into_sql_str(), None => begin_ansi_transaction_sql(depth), }; - conn.execute(&*statement).await?; + conn.execute(statement).await?; if !conn.in_transaction() { return Err(Error::BeginFailed); } @@ -44,7 +46,7 @@ impl TransactionManager for MySqlTransactionManager { let depth = conn.inner.transaction_depth; if depth > 0 { - conn.execute(&*commit_ansi_transaction_sql(depth)).await?; + conn.execute(commit_ansi_transaction_sql(depth)).await?; conn.inner.transaction_depth = depth - 1; } @@ -57,7 +59,7 @@ impl TransactionManager for MySqlTransactionManager { let depth = conn.inner.transaction_depth; if depth > 0 { - conn.execute(&*rollback_ansi_transaction_sql(depth)).await?; + conn.execute(rollback_ansi_transaction_sql(depth)).await?; conn.inner.transaction_depth = depth - 1; } @@ -73,7 +75,7 @@ impl TransactionManager for MySqlTransactionManager { conn.inner.stream.sequence_id = 0; conn.inner .stream - .write_packet(Query(&rollback_ansi_transaction_sql(depth))) + .write_packet(Query(rollback_ansi_transaction_sql(depth).as_str())) .expect("BUG: unexpected error queueing ROLLBACK"); conn.inner.transaction_depth = depth - 1; diff --git a/sqlx-mysql/src/types/text.rs b/sqlx-mysql/src/types/text.rs index ad61c1bee8..363ec02439 100644 --- a/sqlx-mysql/src/types/text.rs +++ b/sqlx-mysql/src/types/text.rs @@ -16,7 +16,7 @@ impl Type for Text { } } -impl<'q, T> Encode<'q, MySql> for Text +impl Encode<'_, MySql> for Text where T: Display, { diff --git a/sqlx-postgres/src/advisory_lock.rs b/sqlx-postgres/src/advisory_lock.rs index d1aef176fb..047ede6be6 100644 --- a/sqlx-postgres/src/advisory_lock.rs +++ b/sqlx-postgres/src/advisory_lock.rs @@ -362,7 +362,7 @@ impl<'lock, C: AsMut> PgAdvisoryLockGuard<'lock, C> { } } -impl<'lock, C: AsMut + AsRef> Deref for PgAdvisoryLockGuard<'lock, C> { +impl + AsRef> Deref for PgAdvisoryLockGuard<'_, C> { type Target = PgConnection; fn deref(&self) -> &Self::Target { @@ -376,16 +376,14 @@ impl<'lock, C: AsMut + AsRef> Deref for PgAdvisoryLo /// However, replacing the connection with a different one using, e.g. [`std::mem::replace()`] /// is a logic error and will cause a warning to be logged by the PostgreSQL server when this /// guard attempts to release the lock. -impl<'lock, C: AsMut + AsRef> DerefMut - for PgAdvisoryLockGuard<'lock, C> -{ +impl + AsRef> DerefMut for PgAdvisoryLockGuard<'_, C> { fn deref_mut(&mut self) -> &mut Self::Target { self.conn.as_mut().expect(NONE_ERR).as_mut() } } -impl<'lock, C: AsMut + AsRef> AsRef - for PgAdvisoryLockGuard<'lock, C> +impl + AsRef> AsRef + for PgAdvisoryLockGuard<'_, C> { fn as_ref(&self) -> &PgConnection { self.conn.as_ref().expect(NONE_ERR).as_ref() @@ -398,7 +396,7 @@ impl<'lock, C: AsMut + AsRef> AsRef /// However, replacing the connection with a different one using, e.g. [`std::mem::replace()`] /// is a logic error and will cause a warning to be logged by the PostgreSQL server when this /// guard attempts to release the lock. -impl<'lock, C: AsMut> AsMut for PgAdvisoryLockGuard<'lock, C> { +impl> AsMut for PgAdvisoryLockGuard<'_, C> { fn as_mut(&mut self) -> &mut PgConnection { self.conn.as_mut().expect(NONE_ERR).as_mut() } @@ -407,7 +405,7 @@ impl<'lock, C: AsMut> AsMut for PgAdvisoryLockGuard< /// Queues a `pg_advisory_unlock()` call on the wrapped connection which will be flushed /// to the server the next time it is used, or when it is returned to [`PgPool`][crate::PgPool] /// in the case of [`PoolConnection`][crate::pool::PoolConnection]. -impl<'lock, C: AsMut> Drop for PgAdvisoryLockGuard<'lock, C> { +impl> Drop for PgAdvisoryLockGuard<'_, C> { fn drop(&mut self) { if let Some(mut conn) = self.conn.take() { // Queue a simple query message to execute next time the connection is used. diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index 762f53e5df..aaeff02ba1 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -5,6 +5,7 @@ use crate::{ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::{stream, StreamExt, TryFutureExt, TryStreamExt}; +use sqlx_core::sql_str::SqlStr; use std::borrow::Cow; use std::{future, pin::pin}; @@ -84,7 +85,7 @@ impl AnyConnectionBackend for PgConnection { fn fetch_many<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxStream<'q, sqlx_core::Result>> { @@ -110,7 +111,7 @@ impl AnyConnectionBackend for PgConnection { fn fetch_optional<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxFuture<'q, sqlx_core::Result>> { @@ -135,20 +136,17 @@ impl AnyConnectionBackend for PgConnection { fn prepare_with<'c, 'q: 'c>( &'c mut self, - sql: &'q str, + sql: SqlStr, _parameters: &[AnyTypeInfo], - ) -> BoxFuture<'c, sqlx_core::Result>> { + ) -> BoxFuture<'c, sqlx_core::Result> { Box::pin(async move { let statement = Executor::prepare_with(self, sql, &[]).await?; - AnyStatement::try_from_statement( - sql, - &statement, - statement.metadata.column_names.clone(), - ) + let colunn_names = statement.metadata.column_names.clone(); + AnyStatement::try_from_statement(statement, colunn_names) }) } - fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, sqlx_core::Result>> { + fn describe<'c>(&mut self, sql: SqlStr) -> BoxFuture<'_, sqlx_core::Result>> { Box::pin(async move { let describe = Executor::describe(self, sql).await?; diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index a27578c56c..56c11cee94 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -12,6 +12,7 @@ use crate::HashMap; use crate::{PgColumn, PgConnection, PgTypeInfo}; use smallvec::SmallVec; use sqlx_core::query_builder::QueryBuilder; +use sqlx_core::sql_str::AssertSqlSafe; use std::sync::Arc; /// Describes the type of the `pg_type.typtype` column @@ -543,7 +544,7 @@ WHERE rngtypid = $1 } let (Json(explains),): (Json>,) = - query_as(&explain).fetch_one(self).await?; + query_as(AssertSqlSafe(explain)).fetch_one(self).await?; let mut nullables = Vec::new(); diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs index 076c4209f6..c1ee2a2307 100644 --- a/sqlx-postgres/src/connection/executor.rs +++ b/sqlx-postgres/src/connection/executor.rs @@ -15,10 +15,11 @@ use crate::{ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_core::Stream; -use futures_util::TryStreamExt; +use futures_util::{pin_mut, TryStreamExt}; use sqlx_core::arguments::Arguments; +use sqlx_core::sql_str::{SqlSafeStr, SqlStr}; use sqlx_core::Either; -use std::{borrow::Cow, pin::pin, sync::Arc}; +use std::{pin::pin, sync::Arc}; async fn prepare( conn: &mut PgConnection, @@ -159,7 +160,7 @@ impl PgConnection { self.inner.pending_ready_for_query_count += 1; } - async fn get_or_prepare<'a>( + async fn get_or_prepare( &mut self, sql: &str, parameters: &[PgTypeInfo], @@ -192,14 +193,12 @@ impl PgConnection { pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>( &'c mut self, - query: &'q str, + query: SqlStr, arguments: Option, limit: u8, persistent: bool, metadata_opt: Option>, ) -> Result, Error>> + 'e, Error> { - let mut logger = QueryLogger::new(query, self.inner.log_settings.clone()); - // before we continue, wait until we are "ready" to accept more queries self.wait_until_ready().await?; @@ -222,7 +221,7 @@ impl PgConnection { // prepare the statement if this our first time executing it // always return the statement ID here let (statement, metadata_) = self - .get_or_prepare(query, &arguments.types, persistent, metadata_opt) + .get_or_prepare(query.as_str(), &arguments.types, persistent, metadata_opt) .await?; metadata = metadata_; @@ -273,7 +272,7 @@ impl PgConnection { PgValueFormat::Binary } else { // Query will trigger a ReadyForQuery - self.inner.stream.write_msg(Query(query))?; + self.inner.stream.write_msg(Query(query.as_str()))?; self.inner.pending_ready_for_query_count += 1; // metadata starts out as "nothing" @@ -284,6 +283,7 @@ impl PgConnection { }; self.inner.stream.flush().await?; + let mut logger = QueryLogger::new(query, self.inner.log_settings.clone()); Ok(try_stream! { loop { @@ -384,7 +384,6 @@ impl<'c> Executor<'c> for &'c mut PgConnection { 'q: 'e, E: 'q, { - let sql = query.sql(); // False positive: https://github.com/rust-lang/rust-clippy/issues/12560 #[allow(clippy::map_clone)] let metadata = query.statement().map(|s| Arc::clone(&s.metadata)); @@ -393,7 +392,9 @@ impl<'c> Executor<'c> for &'c mut PgConnection { Box::pin(try_stream! { let arguments = arguments?; - let mut s = pin!(self.run(sql, arguments, 0, persistent, metadata).await?); + let sql = query.sql(); + let s = self.run(sql, arguments, 0, persistent, metadata).await?; + pin_mut!(s); while let Some(v) = s.try_next().await? { r#yield!(v); @@ -410,7 +411,6 @@ impl<'c> Executor<'c> for &'c mut PgConnection { 'q: 'e, E: 'q, { - let sql = query.sql(); // False positive: https://github.com/rust-lang/rust-clippy/issues/12560 #[allow(clippy::map_clone)] let metadata = query.statement().map(|s| Arc::clone(&s.metadata)); @@ -418,6 +418,7 @@ impl<'c> Executor<'c> for &'c mut PgConnection { let persistent = query.persistent(); Box::pin(async move { + let sql = query.sql(); let arguments = arguments?; let mut s = pin!(self.run(sql, arguments, 1, persistent, metadata).await?); @@ -436,37 +437,38 @@ impl<'c> Executor<'c> for &'c mut PgConnection { }) } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, parameters: &'e [PgTypeInfo], - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result> where 'c: 'e, { + let sql = sql.into_sql_str(); Box::pin(async move { self.wait_until_ready().await?; - let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?; + let (_, metadata) = self + .get_or_prepare(sql.as_str(), parameters, true, None) + .await?; - Ok(PgStatement { - sql: Cow::Borrowed(sql), - metadata, - }) + Ok(PgStatement { sql, metadata }) }) } - fn describe<'e, 'q: 'e>( + fn describe<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { + let sql = sql.into_sql_str(); Box::pin(async move { self.wait_until_ready().await?; - let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?; + let (stmt_id, metadata) = self.get_or_prepare(sql.as_str(), &[], true, None).await?; let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?; diff --git a/sqlx-postgres/src/connection/tls.rs b/sqlx-postgres/src/connection/tls.rs index 16b7333bf5..a49c9caa8c 100644 --- a/sqlx-postgres/src/connection/tls.rs +++ b/sqlx-postgres/src/connection/tls.rs @@ -7,7 +7,7 @@ use crate::{PgConnectOptions, PgSslMode}; pub struct MaybeUpgradeTls<'a>(pub &'a PgConnectOptions); -impl<'a> WithSocket for MaybeUpgradeTls<'a> { +impl WithSocket for MaybeUpgradeTls<'_> { type Output = crate::Result>; async fn with_socket(self, socket: S) -> Self::Output { diff --git a/sqlx-postgres/src/database.rs b/sqlx-postgres/src/database.rs index 876e295899..fbc762615b 100644 --- a/sqlx-postgres/src/database.rs +++ b/sqlx-postgres/src/database.rs @@ -30,7 +30,7 @@ impl Database for Postgres { type Arguments<'q> = PgArguments; type ArgumentBuffer<'q> = PgArgumentBuffer; - type Statement<'q> = PgStatement<'q>; + type Statement = PgStatement; const NAME: &'static str = "PostgreSQL"; diff --git a/sqlx-postgres/src/listener.rs b/sqlx-postgres/src/listener.rs index 17a46a916f..0c088f9487 100644 --- a/sqlx-postgres/src/listener.rs +++ b/sqlx-postgres/src/listener.rs @@ -7,6 +7,7 @@ use futures_core::future::BoxFuture; use futures_core::stream::{BoxStream, Stream}; use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use sqlx_core::acquire::Acquire; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; use sqlx_core::transaction::Transaction; use sqlx_core::Either; use tracing::Instrument; @@ -116,7 +117,7 @@ impl PgListener { pub async fn listen(&mut self, channel: &str) -> Result<(), Error> { self.connection() .await? - .execute(&*format!(r#"LISTEN "{}""#, ident(channel))) + .execute(AssertSqlSafe(format!(r#"LISTEN "{}""#, ident(channel)))) .await?; self.channels.push(channel.to_owned()); @@ -133,7 +134,10 @@ impl PgListener { self.channels.extend(channels.into_iter().map(|s| s.into())); let query = build_listen_all_query(&self.channels[beg..]); - self.connection().await?.execute(&*query).await?; + self.connection() + .await? + .execute(AssertSqlSafe(query)) + .await?; Ok(()) } @@ -145,7 +149,7 @@ impl PgListener { // UNLISTEN (we've disconnected anyways) if let Some(connection) = self.connection.as_mut() { connection - .execute(&*format!(r#"UNLISTEN "{}""#, ident(channel))) + .execute(AssertSqlSafe(format!(r#"UNLISTEN "{}""#, ident(channel)))) .await?; } @@ -176,7 +180,7 @@ impl PgListener { connection.inner.stream.notifications = self.buffer_tx.take(); connection - .execute(&*build_listen_all_query(&self.channels)) + .execute(AssertSqlSafe(build_listen_all_query(&self.channels))) .await?; self.connection = Some(connection); @@ -417,14 +421,15 @@ impl<'c> Executor<'c> for &'c mut PgListener { async move { self.connection().await?.fetch_optional(query).await }.boxed() } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - query: &'q str, + query: impl SqlSafeStr, parameters: &'e [PgTypeInfo], - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result> where 'c: 'e, { + let query = query.into_sql_str(); async move { self.connection() .await? @@ -435,13 +440,14 @@ impl<'c> Executor<'c> for &'c mut PgListener { } #[doc(hidden)] - fn describe<'e, 'q: 'e>( + fn describe<'e>( self, - query: &'q str, + query: impl SqlSafeStr, ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { + let query = query.into_sql_str(); async move { self.connection().await?.describe(query).await }.boxed() } } diff --git a/sqlx-postgres/src/message/response.rs b/sqlx-postgres/src/message/response.rs index d6e43e0871..a7c09cfa34 100644 --- a/sqlx-postgres/src/message/response.rs +++ b/sqlx-postgres/src/message/response.rs @@ -195,7 +195,7 @@ struct Fields<'a> { offset: usize, } -impl<'a> Iterator for Fields<'a> { +impl Iterator for Fields<'_> { type Item = (u8, Range); fn next(&mut self) -> Option { diff --git a/sqlx-postgres/src/migrate.rs b/sqlx-postgres/src/migrate.rs index c37e92f4d6..504229c2e2 100644 --- a/sqlx-postgres/src/migrate.rs +++ b/sqlx-postgres/src/migrate.rs @@ -7,6 +7,7 @@ use futures_core::future::BoxFuture; pub(crate) use sqlx_core::migrate::MigrateError; pub(crate) use sqlx_core::migrate::{AppliedMigration, Migration}; pub(crate) use sqlx_core::migrate::{Migrate, MigrateDatabase}; +use sqlx_core::sql_str::AssertSqlSafe; use crate::connection::{ConnectOptions, Connection}; use crate::error::Error; @@ -45,10 +46,10 @@ impl MigrateDatabase for Postgres { let mut conn = options.connect().await?; let _ = conn - .execute(&*format!( + .execute(AssertSqlSafe(format!( "CREATE DATABASE \"{}\"", database.replace('"', "\"\"") - )) + ))) .await?; Ok(()) @@ -76,10 +77,10 @@ impl MigrateDatabase for Postgres { let mut conn = options.connect().await?; let _ = conn - .execute(&*format!( + .execute(AssertSqlSafe(format!( "DROP DATABASE IF EXISTS \"{}\"", database.replace('"', "\"\"") - )) + ))) .await?; Ok(()) @@ -99,10 +100,10 @@ impl MigrateDatabase for Postgres { let pid_type = if version >= 90200 { "pid" } else { "procpid" }; - conn.execute(&*format!( + conn.execute(AssertSqlSafe(format!( "SELECT pg_terminate_backend(pg_stat_activity.{pid_type}) FROM pg_stat_activity \ WHERE pg_stat_activity.datname = '{database}' AND {pid_type} <> pg_backend_pid()" - )) + ))) .await?; Self::drop_database(url).await @@ -277,7 +278,7 @@ async fn execute_migration( migration: &Migration, ) -> Result<(), MigrateError> { let _ = conn - .execute(&*migration.sql) + .execute(AssertSqlSafe(migration.sql.to_string())) .await .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; @@ -302,7 +303,7 @@ async fn revert_migration( migration: &Migration, ) -> Result<(), MigrateError> { let _ = conn - .execute(&*migration.sql) + .execute(AssertSqlSafe(migration.sql.to_string())) .await .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; diff --git a/sqlx-postgres/src/statement.rs b/sqlx-postgres/src/statement.rs index abd553af30..b2f739d033 100644 --- a/sqlx-postgres/src/statement.rs +++ b/sqlx-postgres/src/statement.rs @@ -3,15 +3,15 @@ use crate::column::ColumnIndex; use crate::error::Error; use crate::ext::ustr::UStr; use crate::{PgArguments, Postgres}; -use std::borrow::Cow; use std::sync::Arc; +use sqlx_core::sql_str::SqlStr; pub(crate) use sqlx_core::statement::Statement; use sqlx_core::{Either, HashMap}; #[derive(Debug, Clone)] -pub struct PgStatement<'q> { - pub(crate) sql: Cow<'q, str>, +pub struct PgStatement { + pub(crate) sql: SqlStr, pub(crate) metadata: Arc, } @@ -24,18 +24,22 @@ pub(crate) struct PgStatementMetadata { pub(crate) parameters: Vec, } -impl<'q> Statement<'q> for PgStatement<'q> { +impl Statement for PgStatement { type Database = Postgres; - fn to_owned(&self) -> PgStatement<'static> { - PgStatement::<'static> { - sql: Cow::Owned(self.sql.clone().into_owned()), + fn to_owned(&self) -> PgStatement { + PgStatement { + sql: self.sql.clone(), metadata: self.metadata.clone(), } } - fn sql(&self) -> &str { - &self.sql + fn sql_cloned(&self) -> SqlStr { + self.sql.clone() + } + + fn into_sql(self) -> SqlStr { + self.sql } fn parameters(&self) -> Option> { @@ -49,8 +53,8 @@ impl<'q> Statement<'q> for PgStatement<'q> { impl_statement_query!(PgArguments); } -impl ColumnIndex> for &'_ str { - fn index(&self, statement: &PgStatement<'_>) -> Result { +impl ColumnIndex for &'_ str { + fn index(&self, statement: &PgStatement) -> Result { statement .metadata .column_names diff --git a/sqlx-postgres/src/testing/mod.rs b/sqlx-postgres/src/testing/mod.rs index af20fe87ea..e127812a29 100644 --- a/sqlx-postgres/src/testing/mod.rs +++ b/sqlx-postgres/src/testing/mod.rs @@ -1,6 +1,7 @@ use std::fmt::Write; use std::ops::Deref; use std::str::FromStr; +use std::sync::Arc; use std::time::Duration; use futures_core::future::BoxFuture; @@ -8,6 +9,7 @@ use futures_core::future::BoxFuture; use once_cell::sync::OnceCell; use sqlx_core::connection::Connection; use sqlx_core::query_scalar::query_scalar; +use sqlx_core::sql_str::AssertSqlSafe; use crate::error::Error; use crate::executor::Executor; @@ -55,12 +57,13 @@ impl TestSupport for Postgres { let mut deleted_db_names = Vec::with_capacity(delete_db_names.len()); - let mut command = String::new(); + let mut command_arced = Arc::new(String::new()); for db_name in &delete_db_names { + let command = Arc::get_mut(&mut command_arced).unwrap(); command.clear(); writeln!(command, "drop database if exists {db_name:?};").ok(); - match conn.execute(&*command).await { + match conn.execute(AssertSqlSafe(command_arced.clone())).await { Ok(_deleted) => { deleted_db_names.push(db_name); } @@ -169,7 +172,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { let create_command = format!("create database {db_name:?}"); debug_assert!(create_command.starts_with("create database \"")); - conn.execute(&(create_command)[..]).await?; + conn.execute(AssertSqlSafe(create_command)).await?; Ok(TestContext { pool_opts: PoolOptions::new() @@ -191,7 +194,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { async fn do_cleanup(conn: &mut PgConnection, db_name: &str) -> Result<(), Error> { let delete_db_command = format!("drop database if exists {db_name:?};"); - conn.execute(&*delete_db_command).await?; + conn.execute(AssertSqlSafe(delete_db_command)).await?; query("delete from _sqlx_test.databases where db_name = $1::text") .bind(db_name) .execute(&mut *conn) diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index 23352a8dcf..69cc3c9a1a 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -1,5 +1,6 @@ use futures_core::future::BoxFuture; use sqlx_core::database::Database; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; use std::borrow::Cow; use crate::error::Error; @@ -25,12 +26,13 @@ impl TransactionManager for PgTransactionManager { // custom `BEGIN` statements are not allowed if we're already in // a transaction (we need to issue a `SAVEPOINT` instead) Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), - Some(statement) => statement, + Some(statement) => AssertSqlSafe(statement).into_sql_str(), None => begin_ansi_transaction_sql(depth), }; let rollback = Rollback::new(conn); - rollback.conn.queue_simple_query(&statement)?; + + rollback.conn.queue_simple_query(statement.as_str())?; rollback.conn.wait_until_ready().await?; if !rollback.conn.in_transaction() { return Err(Error::BeginFailed); @@ -45,7 +47,7 @@ impl TransactionManager for PgTransactionManager { fn commit(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { if conn.inner.transaction_depth > 0 { - conn.execute(&*commit_ansi_transaction_sql(conn.inner.transaction_depth)) + conn.execute(commit_ansi_transaction_sql(conn.inner.transaction_depth)) .await?; conn.inner.transaction_depth -= 1; @@ -58,10 +60,8 @@ impl TransactionManager for PgTransactionManager { fn rollback(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { if conn.inner.transaction_depth > 0 { - conn.execute(&*rollback_ansi_transaction_sql( - conn.inner.transaction_depth, - )) - .await?; + conn.execute(rollback_ansi_transaction_sql(conn.inner.transaction_depth)) + .await?; conn.inner.transaction_depth -= 1; } @@ -72,8 +72,10 @@ impl TransactionManager for PgTransactionManager { fn start_rollback(conn: &mut PgConnection) { if conn.inner.transaction_depth > 0 { - conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.inner.transaction_depth)) - .expect("BUG: Rollback query somehow too large for protocol"); + conn.queue_simple_query( + rollback_ansi_transaction_sql(conn.inner.transaction_depth).as_str(), + ) + .expect("BUG: Rollback query somehow too large for protocol"); conn.inner.transaction_depth -= 1; } diff --git a/sqlx-postgres/src/types/cube.rs b/sqlx-postgres/src/types/cube.rs index cc2a016090..d7ddbd1723 100644 --- a/sqlx-postgres/src/types/cube.rs +++ b/sqlx-postgres/src/types/cube.rs @@ -71,7 +71,7 @@ impl<'r> Decode<'r, Postgres> for PgCube { } } -impl<'q> Encode<'q, Postgres> for PgCube { +impl Encode<'_, Postgres> for PgCube { fn produces(&self) -> Option { Some(PgTypeInfo::with_name("cube")) } diff --git a/sqlx-postgres/src/types/geometry/box.rs b/sqlx-postgres/src/types/geometry/box.rs index 28016b2786..ad4fa39ef7 100644 --- a/sqlx-postgres/src/types/geometry/box.rs +++ b/sqlx-postgres/src/types/geometry/box.rs @@ -56,7 +56,7 @@ impl<'r> Decode<'r, Postgres> for PgBox { } } -impl<'q> Encode<'q, Postgres> for PgBox { +impl Encode<'_, Postgres> for PgBox { fn produces(&self) -> Option { Some(PgTypeInfo::with_name("box")) } diff --git a/sqlx-postgres/src/types/geometry/line.rs b/sqlx-postgres/src/types/geometry/line.rs index 8f08c949ef..6bc90676ed 100644 --- a/sqlx-postgres/src/types/geometry/line.rs +++ b/sqlx-postgres/src/types/geometry/line.rs @@ -47,7 +47,7 @@ impl<'r> Decode<'r, Postgres> for PgLine { } } -impl<'q> Encode<'q, Postgres> for PgLine { +impl Encode<'_, Postgres> for PgLine { fn produces(&self) -> Option { Some(PgTypeInfo::with_name("line")) } diff --git a/sqlx-postgres/src/types/geometry/line_segment.rs b/sqlx-postgres/src/types/geometry/line_segment.rs index cd08e4da4a..486d2ba07d 100644 --- a/sqlx-postgres/src/types/geometry/line_segment.rs +++ b/sqlx-postgres/src/types/geometry/line_segment.rs @@ -57,7 +57,7 @@ impl<'r> Decode<'r, Postgres> for PgLSeg { } } -impl<'q> Encode<'q, Postgres> for PgLSeg { +impl Encode<'_, Postgres> for PgLSeg { fn produces(&self) -> Option { Some(PgTypeInfo::with_name("lseg")) } diff --git a/sqlx-postgres/src/types/geometry/point.rs b/sqlx-postgres/src/types/geometry/point.rs index 83b7c24d0d..6f264e85ea 100644 --- a/sqlx-postgres/src/types/geometry/point.rs +++ b/sqlx-postgres/src/types/geometry/point.rs @@ -50,7 +50,7 @@ impl<'r> Decode<'r, Postgres> for PgPoint { } } -impl<'q> Encode<'q, Postgres> for PgPoint { +impl Encode<'_, Postgres> for PgPoint { fn produces(&self) -> Option { Some(PgTypeInfo::with_name("point")) } diff --git a/sqlx-postgres/src/types/json.rs b/sqlx-postgres/src/types/json.rs index 567e48015e..32f886c781 100644 --- a/sqlx-postgres/src/types/json.rs +++ b/sqlx-postgres/src/types/json.rs @@ -54,7 +54,7 @@ impl PgHasArrayType for JsonRawValue { } } -impl<'q, T> Encode<'q, Postgres> for Json +impl Encode<'_, Postgres> for Json where T: Serialize, { diff --git a/sqlx-postgres/src/types/text.rs b/sqlx-postgres/src/types/text.rs index b5b0a5ed7b..12d92d4b2a 100644 --- a/sqlx-postgres/src/types/text.rs +++ b/sqlx-postgres/src/types/text.rs @@ -18,7 +18,7 @@ impl Type for Text { } } -impl<'q, T> Encode<'q, Postgres> for Text +impl Encode<'_, Postgres> for Text where T: Display, { diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index c72370d0ff..0f3c64f4f5 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -12,6 +12,7 @@ use sqlx_core::any::{ Any, AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo, AnyTypeInfoKind, AnyValueKind, }; +use sqlx_core::sql_str::SqlStr; use crate::type_info::DataType; use sqlx_core::connection::{ConnectOptions, Connection}; @@ -84,7 +85,7 @@ impl AnyConnectionBackend for SqliteConnection { fn fetch_many<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxStream<'q, sqlx_core::Result>> { @@ -107,7 +108,7 @@ impl AnyConnectionBackend for SqliteConnection { fn fetch_optional<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxFuture<'q, sqlx_core::Result>> { @@ -132,16 +133,17 @@ impl AnyConnectionBackend for SqliteConnection { fn prepare_with<'c, 'q: 'c>( &'c mut self, - sql: &'q str, + sql: SqlStr, _parameters: &[AnyTypeInfo], - ) -> BoxFuture<'c, sqlx_core::Result>> { + ) -> BoxFuture<'c, sqlx_core::Result> { Box::pin(async move { let statement = Executor::prepare_with(self, sql, &[]).await?; - AnyStatement::try_from_statement(sql, &statement, statement.column_names.clone()) + let column_names = statement.column_names.clone(); + AnyStatement::try_from_statement(statement, column_names) }) } - fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, sqlx_core::Result>> { + fn describe(&mut self, sql: SqlStr) -> BoxFuture<'_, sqlx_core::Result>> { Box::pin(async move { Executor::describe(self, sql).await?.try_into_any() }) } } diff --git a/sqlx-sqlite/src/connection/describe.rs b/sqlx-sqlite/src/connection/describe.rs index 0f4da33ccc..13a25e96aa 100644 --- a/sqlx-sqlite/src/connection/describe.rs +++ b/sqlx-sqlite/src/connection/describe.rs @@ -5,14 +5,18 @@ use crate::error::Error; use crate::statement::VirtualStatement; use crate::type_info::DataType; use crate::{Sqlite, SqliteColumn}; +use sqlx_core::sql_str::SqlStr; use sqlx_core::Either; use std::convert::identity; -pub(crate) fn describe(conn: &mut ConnectionState, query: &str) -> Result, Error> { +pub(crate) fn describe( + conn: &mut ConnectionState, + query: SqlStr, +) -> Result, Error> { // describing a statement from SQLite can be involved // each SQLx statement is comprised of multiple SQL statements - let mut statement = VirtualStatement::new(query, false)?; + let mut statement = VirtualStatement::new(query.as_str(), false)?; let mut columns = Vec::new(); let mut nullable = Vec::new(); diff --git a/sqlx-sqlite/src/connection/execute.rs b/sqlx-sqlite/src/connection/execute.rs index 8a76236977..7acbc91ff8 100644 --- a/sqlx-sqlite/src/connection/execute.rs +++ b/sqlx-sqlite/src/connection/execute.rs @@ -3,12 +3,13 @@ use crate::error::Error; use crate::logger::QueryLogger; use crate::statement::{StatementHandle, VirtualStatement}; use crate::{SqliteArguments, SqliteQueryResult, SqliteRow}; +use sqlx_core::sql_str::SqlSafeStr; use sqlx_core::Either; pub struct ExecuteIter<'a> { handle: &'a mut ConnectionHandle, statement: &'a mut VirtualStatement, - logger: QueryLogger<'a>, + logger: QueryLogger, args: Option>, /// since a `VirtualStatement` can encompass multiple actual statements, @@ -20,12 +21,13 @@ pub struct ExecuteIter<'a> { pub(crate) fn iter<'a>( conn: &'a mut ConnectionState, - query: &'a str, + query: impl SqlSafeStr, args: Option>, persistent: bool, ) -> Result, Error> { + let query = query.into_sql_str(); // fetch the cached statement or allocate a new one - let statement = conn.statements.get(query, persistent)?; + let statement = conn.statements.get(query.as_str(), persistent)?; let logger = QueryLogger::new(query, conn.log_settings.clone()); diff --git a/sqlx-sqlite/src/connection/executor.rs b/sqlx-sqlite/src/connection/executor.rs index 1f6ce7726f..3ad67c6267 100644 --- a/sqlx-sqlite/src/connection/executor.rs +++ b/sqlx-sqlite/src/connection/executor.rs @@ -7,6 +7,7 @@ use futures_util::{stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use sqlx_core::describe::Describe; use sqlx_core::error::Error; use sqlx_core::executor::{Execute, Executor}; +use sqlx_core::sql_str::SqlSafeStr; use sqlx_core::Either; use std::{future, pin::pin}; @@ -23,12 +24,12 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { 'q: 'e, E: 'q, { - let sql = query.sql(); let arguments = match query.take_arguments().map_err(Error::Encode) { Ok(arguments) => arguments, Err(error) => return stream::once(future::ready(Err(error))).boxed(), }; let persistent = query.persistent() && arguments.is_some(); + let sql = query.sql(); Box::pin( self.worker @@ -48,7 +49,6 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { 'q: 'e, E: 'q, { - let sql = query.sql(); let arguments = match query.take_arguments().map_err(Error::Encode) { Ok(arguments) => arguments, Err(error) => return future::ready(Err(error)).boxed(), @@ -56,6 +56,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { let persistent = query.persistent() && arguments.is_some(); Box::pin(async move { + let sql = query.sql(); let mut stream = pin!(self .worker .execute(sql, arguments, self.row_channel_size, persistent, Some(1)) @@ -72,29 +73,28 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { }) } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: impl SqlSafeStr, _parameters: &[SqliteTypeInfo], - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result> where 'c: 'e, { + let sql = sql.into_sql_str(); Box::pin(async move { let statement = self.worker.prepare(sql).await?; - Ok(SqliteStatement { - sql: sql.into(), - ..statement - }) + Ok(statement) }) } #[doc(hidden)] - fn describe<'e, 'q: 'e>(self, sql: &'q str) -> BoxFuture<'e, Result, Error>> + fn describe<'e>(self, sql: impl SqlSafeStr) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { - Box::pin(self.worker.describe(sql)) + let sql = sql.into_sql_str(); + Box::pin(async move { self.worker.describe(sql).await }) } } diff --git a/sqlx-sqlite/src/connection/explain.rs b/sqlx-sqlite/src/connection/explain.rs index bfa66aa12f..edd65ece49 100644 --- a/sqlx-sqlite/src/connection/explain.rs +++ b/sqlx-sqlite/src/connection/explain.rs @@ -12,6 +12,7 @@ use crate::from_row::FromRow; use crate::logger::{BranchParent, BranchResult, DebugDiff}; use crate::type_info::DataType; use crate::SqliteTypeInfo; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_core::{hash_map, HashMap}; use std::fmt::Debug; use std::str::from_utf8; @@ -567,7 +568,7 @@ pub(super) fn explain( ) -> Result<(Vec, Vec>), Error> { let root_block_cols = root_block_columns(conn)?; let program: Vec<(i64, String, i64, i64, i64, Vec)> = - execute::iter(conn, &format!("EXPLAIN {query}"), None, false)? + execute::iter(conn, AssertSqlSafe(format!("EXPLAIN {query}")), None, false)? .filter_map(|res| res.map(|either| either.right()).transpose()) .map(|row| FromRow::from_row(&row?)) .collect::, Error>>()?; diff --git a/sqlx-sqlite/src/connection/intmap.rs b/sqlx-sqlite/src/connection/intmap.rs index dc09162f64..105c0d2b0e 100644 --- a/sqlx-sqlite/src/connection/intmap.rs +++ b/sqlx-sqlite/src/connection/intmap.rs @@ -103,7 +103,7 @@ impl IntMap { *item = Some(V::default()); } - return self.0[idx].as_mut().unwrap(); + self.0[idx].as_mut().unwrap() } } diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index b94ad91c4d..5823de6ec8 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -23,6 +23,7 @@ use sqlx_core::common::StatementCache; pub(crate) use sqlx_core::connection::*; use sqlx_core::error::Error; use sqlx_core::executor::Executor; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_core::transaction::Transaction; use crate::connection::establish::EstablishParams; @@ -224,7 +225,7 @@ impl Connection for SqliteConnection { write!(pragma_string, "PRAGMA analysis_limit = {limit}; ").ok(); } pragma_string.push_str("PRAGMA optimize;"); - self.execute(&*pragma_string).await?; + self.execute(AssertSqlSafe(pragma_string)).await?; } let shutdown = self.worker.shutdown(); // Drop the statement worker, which should diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index 00a4c2999c..1c534e8e4c 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -6,6 +6,7 @@ use std::thread; use futures_channel::oneshot; use futures_intrusive::sync::{Mutex, MutexGuard}; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr}; use tracing::span::Span; use sqlx_core::describe::Describe; @@ -53,15 +54,15 @@ impl WorkerSharedState { enum Command { Prepare { - query: Box, - tx: oneshot::Sender, Error>>, + query: SqlStr, + tx: oneshot::Sender>, }, Describe { - query: Box, + query: SqlStr, tx: oneshot::Sender, Error>>, }, Execute { - query: Box, + query: SqlStr, arguments: Option>, persistent: bool, tx: flume::Sender, Error>>, @@ -145,17 +146,16 @@ impl ConnectionWorker { let _guard = span.enter(); match cmd { Command::Prepare { query, tx } => { - tx.send(prepare(&mut conn, &query).map(|prepared| { + tx.send(prepare(&mut conn, query).inspect(|_| { update_cached_statements_size( &conn, &shared.cached_statements_size, ); - prepared })) .ok(); } Command::Describe { query, tx } => { - tx.send(describe(&mut conn, &query)).ok(); + tx.send(describe(&mut conn, query)).ok(); } Command::Execute { query, @@ -164,7 +164,7 @@ impl ConnectionWorker { tx, limit } => { - let iter = match execute::iter(&mut conn, &query, arguments, persistent) + let iter = match execute::iter(&mut conn, query, arguments, persistent) { Ok(iter) => iter, Err(e) => { @@ -220,12 +220,12 @@ impl ConnectionWorker { } continue; }, - Some(statement) => statement, + Some(statement) => AssertSqlSafe(statement).into_sql_str(), None => begin_ansi_transaction_sql(depth), }; let res = conn.handle - .exec(statement) + .exec(statement.as_str()) .map(|_| { shared.transaction_depth.fetch_add(1, Ordering::Release); }); @@ -238,7 +238,7 @@ impl ConnectionWorker { // immediately otherwise it would remain started forever. if let Err(error) = conn .handle - .exec(rollback_ansi_transaction_sql(depth + 1)) + .exec(rollback_ansi_transaction_sql(depth + 1).as_str()) .map(|_| { shared.transaction_depth.fetch_sub(1, Ordering::Release); }) @@ -256,7 +256,7 @@ impl ConnectionWorker { let res = if depth > 0 { conn.handle - .exec(commit_ansi_transaction_sql(depth)) + .exec(commit_ansi_transaction_sql(depth).as_str()) .map(|_| { shared.transaction_depth.fetch_sub(1, Ordering::Release); }) @@ -282,7 +282,7 @@ impl ConnectionWorker { let res = if depth > 0 { conn.handle - .exec(rollback_ansi_transaction_sql(depth)) + .exec(rollback_ansi_transaction_sql(depth).as_str()) .map(|_| { shared.transaction_depth.fetch_sub(1, Ordering::Release); }) @@ -335,25 +335,19 @@ impl ConnectionWorker { establish_rx.await.map_err(|_| Error::WorkerCrashed)? } - pub(crate) async fn prepare(&mut self, query: &str) -> Result, Error> { - self.oneshot_cmd(|tx| Command::Prepare { - query: query.into(), - tx, - }) - .await? + pub(crate) async fn prepare(&mut self, query: SqlStr) -> Result { + self.oneshot_cmd(|tx| Command::Prepare { query, tx }) + .await? } - pub(crate) async fn describe(&mut self, query: &str) -> Result, Error> { - self.oneshot_cmd(|tx| Command::Describe { - query: query.into(), - tx, - }) - .await? + pub(crate) async fn describe(&mut self, query: SqlStr) -> Result, Error> { + self.oneshot_cmd(|tx| Command::Describe { query, tx }) + .await? } pub(crate) async fn execute( &mut self, - query: &str, + query: SqlStr, args: Option>, chan_size: usize, persistent: bool, @@ -364,7 +358,7 @@ impl ConnectionWorker { self.command_tx .send_async(( Command::Execute { - query: query.into(), + query, arguments: args.map(SqliteArguments::into_static), persistent, tx, @@ -495,9 +489,9 @@ impl ConnectionWorker { } } -fn prepare(conn: &mut ConnectionState, query: &str) -> Result, Error> { +fn prepare(conn: &mut ConnectionState, query: SqlStr) -> Result { // prepare statement object (or checkout from cache) - let statement = conn.statements.get(query, true)?; + let statement = conn.statements.get(query.as_str(), true)?; let mut parameters = 0; let mut columns = None; @@ -514,7 +508,7 @@ fn prepare(conn: &mut ConnectionState, query: &str) -> Result = SqliteArguments<'q>; type ArgumentBuffer<'q> = Vec>; - type Statement<'q> = SqliteStatement<'q>; + type Statement = SqliteStatement; const NAME: &'static str = "SQLite"; diff --git a/sqlx-sqlite/src/lib.rs b/sqlx-sqlite/src/lib.rs index e4a122b6bd..8429468be2 100644 --- a/sqlx-sqlite/src/lib.rs +++ b/sqlx-sqlite/src/lib.rs @@ -57,6 +57,7 @@ pub use options::{ }; pub use query_result::SqliteQueryResult; pub use row::SqliteRow; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; pub use statement::SqliteStatement; pub use transaction::SqliteTransactionManager; pub use type_info::SqliteTypeInfo; @@ -132,9 +133,10 @@ pub fn describe_blocking(query: &str, database_url: &str) -> Result QueryPlanLogger<'q, R, S, P> } } -impl<'q, R: Debug, S: Debug + DebugDiff, P: Debug> Drop for QueryPlanLogger<'q, R, S, P> { +impl Drop for QueryPlanLogger<'_, R, S, P> { fn drop(&mut self) { self.finish(); } diff --git a/sqlx-sqlite/src/migrate.rs b/sqlx-sqlite/src/migrate.rs index b9ce22dccd..b184b2305a 100644 --- a/sqlx-sqlite/src/migrate.rs +++ b/sqlx-sqlite/src/migrate.rs @@ -9,6 +9,7 @@ use crate::query::query; use crate::query_as::query_as; use crate::{Sqlite, SqliteConnectOptions, SqliteConnection, SqliteJournalMode}; use futures_core::future::BoxFuture; +use sqlx_core::sql_str::AssertSqlSafe; use std::str::FromStr; use std::sync::atomic::Ordering; use std::time::Duration; @@ -142,7 +143,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( // data lineage and debugging reasons, so it is not super important if it is lost. So we initialize it to -1 // and update it once the actual transaction completed. let _ = tx - .execute(&*migration.sql) + .execute(AssertSqlSafe(migration.sql.to_string())) .await .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; @@ -195,7 +196,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( let mut tx = self.begin().await?; let start = Instant::now(); - let _ = tx.execute(&*migration.sql).await?; + let _ = tx.execute(AssertSqlSafe(migration.sql.to_string())).await?; // language=SQL let _ = query(r#"DELETE FROM _sqlx_migrations WHERE version = ?1"#) diff --git a/sqlx-sqlite/src/options/connect.rs b/sqlx-sqlite/src/options/connect.rs index 309f2430e0..5afb9a46f2 100644 --- a/sqlx-sqlite/src/options/connect.rs +++ b/sqlx-sqlite/src/options/connect.rs @@ -4,6 +4,7 @@ use log::LevelFilter; use sqlx_core::connection::ConnectOptions; use sqlx_core::error::Error; use sqlx_core::executor::Executor; +use sqlx_core::sql_str::AssertSqlSafe; use std::fmt::Write; use std::str::FromStr; use std::time::Duration; @@ -36,7 +37,7 @@ impl ConnectOptions for SqliteConnectOptions { let mut conn = SqliteConnection::establish(self).await?; // Execute PRAGMAs - conn.execute(&*self.pragma_string()).await?; + conn.execute(AssertSqlSafe(self.pragma_string())).await?; if !self.collations.is_empty() { let mut locked = conn.lock_handle().await?; diff --git a/sqlx-sqlite/src/statement/mod.rs b/sqlx-sqlite/src/statement/mod.rs index 179b8eeaf7..08f56d08eb 100644 --- a/sqlx-sqlite/src/statement/mod.rs +++ b/sqlx-sqlite/src/statement/mod.rs @@ -2,8 +2,8 @@ use crate::column::ColumnIndex; use crate::error::Error; use crate::ext::ustr::UStr; use crate::{Sqlite, SqliteArguments, SqliteColumn, SqliteTypeInfo}; +use sqlx_core::sql_str::SqlStr; use sqlx_core::{Either, HashMap}; -use std::borrow::Cow; use std::sync::Arc; pub(crate) use sqlx_core::statement::*; @@ -17,27 +17,31 @@ pub(crate) use r#virtual::VirtualStatement; #[derive(Debug, Clone)] #[allow(clippy::rc_buffer)] -pub struct SqliteStatement<'q> { - pub(crate) sql: Cow<'q, str>, +pub struct SqliteStatement { + pub(crate) sql: SqlStr, pub(crate) parameters: usize, pub(crate) columns: Arc>, pub(crate) column_names: Arc>, } -impl<'q> Statement<'q> for SqliteStatement<'q> { +impl Statement for SqliteStatement { type Database = Sqlite; - fn to_owned(&self) -> SqliteStatement<'static> { - SqliteStatement::<'static> { - sql: Cow::Owned(self.sql.clone().into_owned()), + fn to_owned(&self) -> SqliteStatement { + SqliteStatement { + sql: self.sql.clone(), parameters: self.parameters, columns: Arc::clone(&self.columns), column_names: Arc::clone(&self.column_names), } } - fn sql(&self) -> &str { - &self.sql + fn sql_cloned(&self) -> SqlStr { + self.sql.clone() + } + + fn into_sql(self) -> SqlStr { + self.sql } fn parameters(&self) -> Option> { @@ -51,8 +55,8 @@ impl<'q> Statement<'q> for SqliteStatement<'q> { impl_statement_query!(SqliteArguments<'_>); } -impl ColumnIndex> for &'_ str { - fn index(&self, statement: &SqliteStatement<'_>) -> Result { +impl ColumnIndex for &'_ str { + fn index(&self, statement: &SqliteStatement) -> Result { statement .column_names .get(*self) diff --git a/sqlx-sqlite/src/value.rs b/sqlx-sqlite/src/value.rs index 469c4e70d5..dc40f29ccb 100644 --- a/sqlx-sqlite/src/value.rs +++ b/sqlx-sqlite/src/value.rs @@ -108,8 +108,8 @@ pub(crate) struct ValueHandle<'a> { } // SAFE: only protected value objects are stored in SqliteValue -unsafe impl<'a> Send for ValueHandle<'a> {} -unsafe impl<'a> Sync for ValueHandle<'a> {} +unsafe impl Send for ValueHandle<'_> {} +unsafe impl Sync for ValueHandle<'_> {} impl ValueHandle<'static> { fn new_owned(value: NonNull, type_info: SqliteTypeInfo) -> Self { @@ -122,7 +122,7 @@ impl ValueHandle<'static> { } } -impl<'a> ValueHandle<'a> { +impl ValueHandle<'_> { fn new_borrowed(value: NonNull, type_info: SqliteTypeInfo) -> Self { Self { value, @@ -185,7 +185,7 @@ impl<'a> ValueHandle<'a> { } } -impl<'a> Drop for ValueHandle<'a> { +impl Drop for ValueHandle<'_> { fn drop(&mut self) { if self.free_on_drop { unsafe { diff --git a/sqlx-test/src/lib.rs b/sqlx-test/src/lib.rs index cc77f38dba..29852a6437 100644 --- a/sqlx-test/src/lib.rs +++ b/sqlx-test/src/lib.rs @@ -109,12 +109,13 @@ macro_rules! test_unprepared_type { async fn [< test_unprepared_type_ $name >] () -> anyhow::Result<()> { use sqlx::prelude::*; use futures::TryStreamExt; + use sqlx_core::sql_str::AssertSqlSafe; let mut conn = sqlx_test::new::<$db>().await?; $( let query = format!("SELECT {}", $text); - let mut s = conn.fetch(&*query); + let mut s = conn.fetch(AssertSqlSafe(query)); let row = s.try_next().await?.unwrap(); let rec = row.try_get::<$ty, _>(0)?; @@ -137,13 +138,14 @@ macro_rules! __test_prepared_decode_type { #[sqlx_macros::test] async fn [< test_prepared_decode_type_ $name >] () -> anyhow::Result<()> { use sqlx::Row; + use sqlx_core::sql_str::AssertSqlSafe; let mut conn = sqlx_test::new::<$db>().await?; $( let query = format!("SELECT {}", $text); - let row = sqlx::query(&query) + let row = sqlx::query(AssertSqlSafe(query)) .fetch_one(&mut conn) .await?; @@ -166,6 +168,7 @@ macro_rules! __test_prepared_type { #[sqlx_macros::test] async fn [< test_prepared_type_ $name >] () -> anyhow::Result<()> { use sqlx::Row; + use sqlx_core::sql_str::AssertSqlSafe; let mut conn = sqlx_test::new::<$db>().await?; @@ -173,7 +176,7 @@ macro_rules! __test_prepared_type { let query = format!($sql, $text); println!("{query}"); - let row = sqlx::query(&query) + let row = sqlx::query(AssertSqlSafe(query)) .bind($value) .bind($value) .fetch_one(&mut conn) diff --git a/src/lib.rs b/src/lib.rs index ed76c5f5ee..0437ed0327 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,6 +29,7 @@ pub use sqlx_core::query_scalar::query_scalar_with_result as __query_scalar_with pub use sqlx_core::query_scalar::{query_scalar, query_scalar_with}; pub use sqlx_core::raw_sql::{raw_sql, RawSql}; pub use sqlx_core::row::Row; +pub use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr}; pub use sqlx_core::statement::Statement; pub use sqlx_core::transaction::{Transaction, TransactionManager}; pub use sqlx_core::type_info::TypeInfo; diff --git a/tests/any/any.rs b/tests/any/any.rs index 2c59ca5339..c84f8c818b 100644 --- a/tests/any/any.rs +++ b/tests/any/any.rs @@ -1,5 +1,6 @@ use sqlx::any::AnyRow; use sqlx::{Any, Connection, Executor, Row}; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_test::new; #[sqlx_macros::test] @@ -106,7 +107,7 @@ async fn it_can_fail_and_recover() -> anyhow::Result<()> { // now try and use the connection let val: i32 = conn - .fetch_one(&*format!("SELECT {i}")) + .fetch_one(AssertSqlSafe(format!("SELECT {i}"))) .await? .get_unchecked(0); @@ -132,7 +133,7 @@ async fn it_can_fail_and_recover_with_pool() -> anyhow::Result<()> { // now try and use the connection let val: i32 = pool - .fetch_one(&*format!("SELECT {i}")) + .fetch_one(AssertSqlSafe(format!("SELECT {i}"))) .await? .get_unchecked(0); diff --git a/tests/any/pool.rs b/tests/any/pool.rs index 3130b4f1c6..a4849940b8 100644 --- a/tests/any/pool.rs +++ b/tests/any/pool.rs @@ -1,5 +1,6 @@ use sqlx::any::{AnyConnectOptions, AnyPoolOptions}; use sqlx::Executor; +use sqlx_core::sql_str::AssertSqlSafe; use std::sync::{ atomic::{AtomicI32, AtomicUsize, Ordering}, Arc, Mutex, @@ -111,7 +112,7 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { id ); - conn.execute(&statement[..]).await?; + conn.execute(AssertSqlSafe(statement)).await?; Ok(()) }) }) diff --git a/tests/postgres/derives.rs b/tests/postgres/derives.rs index 13f9bf1d5d..88f9d026f4 100644 --- a/tests/postgres/derives.rs +++ b/tests/postgres/derives.rs @@ -1,6 +1,7 @@ use futures::TryStreamExt; use sqlx::postgres::types::PgRange; use sqlx::{Connection, Executor, FromRow, Postgres}; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_postgres::PgHasArrayType; use sqlx_test::{new, test_type}; use std::fmt::Debug; @@ -259,7 +260,7 @@ SELECT id, mood FROM people WHERE id = $1 let stmt = format!("SELECT id, mood FROM people WHERE id = {people_id}"); dbg!(&stmt); - let mut cursor = conn.fetch(&*stmt); + let mut cursor = conn.fetch(AssertSqlSafe(stmt)); let row = cursor.try_next().await?.unwrap(); let rec = PeopleRow::from_row(&row)?; diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index fc7108bf4f..b0285f53bf 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -6,6 +6,7 @@ use sqlx::postgres::{ PgPoolOptions, PgRow, PgSeverity, Postgres, PG_COPY_MAX_DATA_LEN, }; use sqlx::{Column, Connection, Executor, Row, Statement, TypeInfo}; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_core::{bytes::Bytes, error::BoxDynError}; use sqlx_test::{new, pool, setup_if_needed}; use std::env; @@ -309,7 +310,10 @@ async fn it_can_fail_and_recover() -> anyhow::Result<()> { assert!(res.is_err()); // now try and use the connection - let val: i32 = conn.fetch_one(&*format!("SELECT {i}::int4")).await?.get(0); + let val: i32 = conn + .fetch_one(AssertSqlSafe(format!("SELECT {i}::int4"))) + .await? + .get(0); assert_eq!(val, i); } @@ -330,7 +334,10 @@ async fn it_can_fail_and_recover_with_pool() -> anyhow::Result<()> { assert!(res.is_err()); // now try and use the connection - let val: i32 = pool.fetch_one(&*format!("SELECT {i}::int4")).await?.get(0); + let val: i32 = pool + .fetch_one(AssertSqlSafe(format!("SELECT {i}::int4"))) + .await? + .get(0); assert_eq!(val, i); } @@ -803,7 +810,7 @@ async fn it_closes_statement_from_cache_issue_470() -> anyhow::Result<()> { let mut conn = PgConnection::connect_with(&options).await?; for i in 0..5 { - let row = sqlx::query(&*format!("SELECT {i}::int4 AS val")) + let row = sqlx::query(AssertSqlSafe(format!("SELECT {i}::int4 AS val"))) .fetch_one(&mut conn) .await?; @@ -1099,8 +1106,10 @@ async fn test_listener_try_recv_buffered() -> anyhow::Result<()> { { let mut txn = notify_conn.begin().await?; for i in 0..5 { - txn.execute(format!("NOTIFY test_channel2, 'payload {i}'").as_str()) - .await?; + txn.execute(AssertSqlSafe(format!( + "NOTIFY test_channel2, 'payload {i}'" + ))) + .await?; } txn.commit().await?; } @@ -1951,7 +1960,8 @@ async fn test_postgres_bytea_hex_deserialization_errors() -> anyhow::Result<()> conn.execute("SET bytea_output = 'escape';").await?; for value in ["", "DEADBEEF"] { let query = format!("SELECT '\\x{value}'::bytea"); - let res: sqlx::Result> = conn.fetch_one(query.as_str()).await?.try_get(0usize); + let res: sqlx::Result> = + conn.fetch_one(AssertSqlSafe(query)).await?.try_get(0usize); // Deserialization only supports hex format so this should error and definitely not panic. res.unwrap_err(); } diff --git a/tests/postgres/query_builder.rs b/tests/postgres/query_builder.rs index 08ed7d11a3..b1e0659eff 100644 --- a/tests/postgres/query_builder.rs +++ b/tests/postgres/query_builder.rs @@ -3,6 +3,7 @@ use sqlx::query_builder::QueryBuilder; use sqlx::Executor; use sqlx::Type; use sqlx::{Either, Execute}; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_test::new; #[test] @@ -54,18 +55,20 @@ fn test_build() { qb.push(" WHERE id = ").push_bind(42i32); let query = qb.build(); - assert_eq!(query.sql(), "SELECT * FROM users WHERE id = $1"); assert_eq!(Execute::persistent(&query), true); + assert_eq!(query.sql(), "SELECT * FROM users WHERE id = $1"); } #[test] fn test_reset() { let mut qb: QueryBuilder<'_, Postgres> = QueryBuilder::new(""); - let _query = qb - .push("SELECT * FROM users WHERE id = ") - .push_bind(42i32) - .build(); + { + let _query = qb + .push("SELECT * FROM users WHERE id = ") + .push_bind(42i32) + .build(); + } qb.reset(); @@ -76,10 +79,12 @@ fn test_reset() { fn test_query_builder_reuse() { let mut qb: QueryBuilder<'_, Postgres> = QueryBuilder::new(""); - let _query = qb - .push("SELECT * FROM users WHERE id = ") - .push_bind(42i32) - .build(); + { + let _query = qb + .push("SELECT * FROM users WHERE id = ") + .push_bind(42i32) + .build(); + } qb.reset(); @@ -97,8 +102,10 @@ fn test_query_builder_with_args() { .push_bind(42i32) .build(); + let args = query.take_arguments().unwrap().unwrap(); + let mut qb: QueryBuilder<'_, Postgres> = - QueryBuilder::with_arguments(query.sql(), query.take_arguments().unwrap().unwrap()); + QueryBuilder::with_arguments(query.sql().as_str(), args); let query = qb.push(" OR membership_level = ").push_bind(3i32).build(); assert_eq!( @@ -129,7 +136,7 @@ async fn test_max_number_of_binds() -> anyhow::Result<()> { let mut conn = new::().await?; // Indirectly ensures the macros support this many binds since this is what they use. - let describe = conn.describe(qb.sql()).await?; + let describe = conn.describe(AssertSqlSafe(qb.sql().to_string())).await?; match describe .parameters diff --git a/tests/postgres/rustsec.rs b/tests/postgres/rustsec.rs index 45fd76b9db..a0692be4c6 100644 --- a/tests/postgres/rustsec.rs +++ b/tests/postgres/rustsec.rs @@ -1,4 +1,5 @@ use sqlx::{Error, PgPool}; +use sqlx_core::sql_str::AssertSqlSafe; use std::{cmp, str}; @@ -114,7 +115,7 @@ async fn rustsec_2024_0363(pool: PgPool) -> anyhow::Result<()> { assert_eq!(wrapped_len, fake_payload_len); - let res = sqlx::raw_sql(&query) + let res = sqlx::raw_sql(AssertSqlSafe(query)) // Note: the connection may hang afterward // because `pending_ready_for_query_count` will underflow. .execute(&pool) diff --git a/tests/sqlite/describe.rs b/tests/sqlite/describe.rs index 5458eaaa82..82d3445530 100644 --- a/tests/sqlite/describe.rs +++ b/tests/sqlite/describe.rs @@ -437,7 +437,7 @@ async fn it_describes_ungrouped_aggregate() -> anyhow::Result<()> { async fn it_describes_literal_subquery() -> anyhow::Result<()> { async fn assert_literal_described( conn: &mut sqlx::SqliteConnection, - query: &str, + query: &'static str, ) -> anyhow::Result<()> { let info = conn.describe(query).await?; @@ -473,7 +473,7 @@ async fn it_describes_literal_subquery() -> anyhow::Result<()> { async fn assert_tweet_described( conn: &mut sqlx::SqliteConnection, - query: &str, + query: &'static str, ) -> anyhow::Result<()> { let info = conn.describe(query).await?; let columns = info.columns(); @@ -533,7 +533,7 @@ async fn it_describes_table_order_by() -> anyhow::Result<()> { async fn assert_literal_order_by_described( conn: &mut sqlx::SqliteConnection, - query: &str, + query: &'static str, ) -> anyhow::Result<()> { let info = conn.describe(query).await?; @@ -571,7 +571,7 @@ async fn it_describes_table_order_by() -> anyhow::Result<()> { async fn it_describes_union() -> anyhow::Result<()> { async fn assert_union_described( conn: &mut sqlx::SqliteConnection, - query: &str, + query: &'static str, ) -> anyhow::Result<()> { let info = conn.describe(query).await?; @@ -653,7 +653,7 @@ async fn it_describes_having_group_by() -> anyhow::Result<()> { async fn it_describes_strange_queries() -> anyhow::Result<()> { async fn assert_single_column_described( conn: &mut sqlx::SqliteConnection, - query: &str, + query: &'static str, typename: &str, nullable: bool, ) -> anyhow::Result<()> { diff --git a/tests/sqlite/rustsec.rs b/tests/sqlite/rustsec.rs index 3ff9c524fa..08f88a3ad9 100644 --- a/tests/sqlite/rustsec.rs +++ b/tests/sqlite/rustsec.rs @@ -1,4 +1,4 @@ -use sqlx::{Connection, Error, SqliteConnection}; +use sqlx::{AssertSqlSafe, Connection, Error, SqliteConnection}; // https://rustsec.org/advisories/RUSTSEC-2024-0363.html // @@ -50,7 +50,7 @@ async fn rustsec_2024_0363() -> anyhow::Result<()> { .execute(&mut conn) .await?; - let res = sqlx::raw_sql(&query).execute(&mut conn).await; + let res = sqlx::raw_sql(AssertSqlSafe(query)).execute(&mut conn).await; if let Err(e) = res { // Connection rejected the query; we're happy.