diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index 7f573a6349..653eee718b 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -21,7 +21,7 @@ jobs: runs-on: ubuntu-24.04 strategy: matrix: - runtime: [ async-std, tokio ] + runtime: [ async-global-executor, smol, tokio ] tls: [ native-tls, rustls, none ] steps: - uses: actions/checkout@v4 @@ -118,7 +118,7 @@ jobs: runs-on: ubuntu-24.04 strategy: matrix: - runtime: [ async-std, tokio ] + runtime: [ async-global-executor, smol, tokio ] linking: [ sqlite, sqlite-unbundled ] needs: check steps: @@ -187,7 +187,7 @@ jobs: strategy: matrix: postgres: [ 17, 13 ] - runtime: [ async-std, tokio ] + runtime: [ async-global-executor, smol, tokio ] tls: [ native-tls, rustls-aws-lc-rs, rustls-ring, none ] needs: check steps: @@ -288,7 +288,7 @@ jobs: strategy: matrix: mysql: [ 8 ] - runtime: [ async-std, tokio ] + runtime: [ async-global-executor, smol, tokio ] tls: [ native-tls, rustls-aws-lc-rs, rustls-ring, none ] needs: check steps: @@ -377,7 +377,7 @@ jobs: strategy: matrix: mariadb: [ verylatest, 11_4, 10_11, 10_4 ] - runtime: [ async-std, tokio ] + runtime: [ async-global-executor, smol, tokio ] tls: [ native-tls, rustls-aws-lc-rs, rustls-ring, none ] needs: check steps: diff --git a/Cargo.lock b/Cargo.lock index 07754e7c22..fe6b4cf9b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -198,6 +198,17 @@ dependencies = [ "slab", ] +[[package]] +name = "async-fs" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebcd09b382f40fcd159c2d695175b2ae620ffa5f3bd6f664131efff4e8b9e04a" +dependencies = [ + "async-lock 3.4.0", + "blocking", + "futures-lite 2.5.0", +] + [[package]] name = "async-global-executor" version = "2.4.1" @@ -213,6 +224,20 @@ dependencies = [ "once_cell", ] +[[package]] +name = "async-global-executor" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13f937e26114b93193065fd44f507aa2e9169ad0cdabbb996920b1fe1ddea7ba" +dependencies = [ + "async-channel 2.3.1", + "async-executor", + "async-io 2.4.0", + "async-lock 3.4.0", + "blocking", + "futures-lite 2.5.0", +] + [[package]] name = "async-io" version = "1.13.0" @@ -272,6 +297,54 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-net" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b948000fad4873c1c9339d60f2623323a0cfd3816e5181033c6a5cb68b2accf7" +dependencies = [ + "async-io 2.4.0", + "blocking", + "futures-lite 2.5.0", +] + +[[package]] +name = "async-process" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63255f1dc2381611000436537bbedfe83183faa303a5a0edaf191edef06526bb" +dependencies = [ + "async-channel 2.3.1", + "async-io 2.4.0", + "async-lock 3.4.0", + "async-signal", + "async-task", + "blocking", + "cfg-if", + "event-listener 5.4.0", + "futures-lite 2.5.0", + "rustix 0.38.43", + "tracing", +] + +[[package]] +name = "async-signal" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "637e00349800c0bdf8bfc21ebbc0b6524abea702b0da4168ac00d070d0c0b9f3" +dependencies = [ + "async-io 2.4.0", + "async-lock 3.4.0", + "atomic-waker", + "cfg-if", + "futures-core", + "futures-io", + "rustix 0.38.43", + "signal-hook-registry", + "slab", + "windows-sys 0.59.0", +] + [[package]] name = "async-std" version = "1.13.0" @@ -280,7 +353,7 @@ checksum = "c634475f29802fde2b8f0b505b1bd00dfe4df7d4a000f0b36f7671197d5c3615" dependencies = [ "async-attributes", "async-channel 1.9.0", - "async-global-executor", + "async-global-executor 2.4.1", "async-io 2.4.0", "async-lock 3.4.0", "crossbeam-utils", @@ -3327,6 +3400,23 @@ dependencies = [ "serde", ] +[[package]] +name = "smol" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a33bd3e260892199c3ccfc487c88b2da2265080acb316cd920da72fdfd7c599f" +dependencies = [ + "async-channel 2.3.1", + "async-executor", + "async-fs", + "async-io 2.4.0", + "async-lock 3.4.0", + "async-net", + "async-process", + "blocking", + "futures-lite 2.5.0", +] + [[package]] name = "socket2" version = "0.4.10" @@ -3424,13 +3514,18 @@ dependencies = [ name = "sqlx-core" version = "0.8.3" dependencies = [ + "async-global-executor 3.1.0", "async-io 1.13.0", + "async-io 2.4.0", + "async-lock 3.4.0", + "async-net", "async-std", "base64 0.22.1", "bigdecimal", "bit-vec", "bstr", "bytes", + "cfg-if", "chrono", "crc", "crossbeam-queue", @@ -3458,6 +3553,7 @@ dependencies = [ "serde_json", "sha2", "smallvec", + "smol", "sqlx", "thiserror 2.0.11", "time", @@ -3610,7 +3706,9 @@ dependencies = [ name = "sqlx-macros-core" version = "0.8.3" dependencies = [ + "async-global-executor 3.1.0", "async-std", + "cfg-if", "dotenvy", "either", "heck 0.5.0", @@ -3621,6 +3719,7 @@ dependencies = [ "serde", "serde_json", "sha2", + "smol", "sqlx-core", "sqlx-mysql", "sqlx-postgres", diff --git a/Cargo.toml b/Cargo.toml index f31d715b26..48c60c6bfd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,8 +35,7 @@ authors = [ "Chloe Ross ", "Daniel Akhterov ", ] -# TODO: enable this for 0.9.0 -# rust-version = "1.80.0" +rust-version = "1.80.0" [package] name = "sqlx" @@ -76,7 +75,9 @@ _unstable-all-types = [ ] # Base runtime features without TLS +runtime-async-global-executor = ["_rt-async-global-executor", "sqlx-core/_rt-async-global-executor", "sqlx-macros?/_rt-async-global-executor"] runtime-async-std = ["_rt-async-std", "sqlx-core/_rt-async-std", "sqlx-macros?/_rt-async-std"] +runtime-smol = ["_rt-smol", "sqlx-core/_rt-smol", "sqlx-macros?/_rt-smol"] runtime-tokio = ["_rt-tokio", "sqlx-core/_rt-tokio", "sqlx-macros?/_rt-tokio"] # TLS features @@ -92,14 +93,22 @@ tls-none = [] # Legacy Runtime + TLS features +runtime-async-global-executor-native-tls = ["runtime-async-global-executor", "tls-native-tls"] +runtime-async-global-executor-rustls = ["runtime-async-global-executor", "tls-rustls-ring"] + runtime-async-std-native-tls = ["runtime-async-std", "tls-native-tls"] runtime-async-std-rustls = ["runtime-async-std", "tls-rustls-ring"] +runtime-smol-native-tls = ["runtime-smol", "tls-native-tls"] +runtime-smol-rustls = ["runtime-smol", "tls-rustls-ring"] + runtime-tokio-native-tls = ["runtime-tokio", "tls-native-tls"] runtime-tokio-rustls = ["runtime-tokio", "tls-rustls-ring"] # for conditional compilation +_rt-async-global-executor = [] _rt-async-std = [] +_rt-smol = [] _rt-tokio = [] _sqlite = [] @@ -151,12 +160,22 @@ time = { version = "0.3.36", features = ["formatting", "parsing", "macros"] } uuid = "1.1.2" # Common utility crates +cfg-if = "1.0.0" dotenvy = { version = "0.15.0", default-features = false } # Runtimes +[workspace.dependencies.async-global-executor] +version = "3.1" +default-features = false +features = ["async-io"] + [workspace.dependencies.async-std] version = "1.12" +[workspace.dependencies.smol] +version = "2.0" +default-features = false + [workspace.dependencies.tokio] version = "1" features = ["time", "net", "sync", "fs", "io-util", "rt"] diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 29f0b09695..16beae9c11 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,4 +1,4 @@ # Note: should NOT increase during a minor/patch release cycle [toolchain] -channel = "1.78" +channel = "1.80" profile = "minimal" diff --git a/sqlx-cli/src/database.rs b/sqlx-cli/src/database.rs index 7a9bc6bf2f..7dd87253af 100644 --- a/sqlx-cli/src/database.rs +++ b/sqlx-cli/src/database.rs @@ -13,7 +13,7 @@ pub async fn create(connect_opts: &ConnectOpts) -> anyhow::Result<()> { let exists = crate::retry_connect_errors(connect_opts, Any::database_exists).await?; if !exists { - #[cfg(feature = "_sqlite")] + #[cfg(any(feature = "sqlite", feature = "sqlite-unbundled"))] sqlx::sqlite::CREATE_DB_WAL.store( connect_opts.sqlite_create_db_wal, std::sync::atomic::Ordering::Release, diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs index 07058aa147..0b771921b9 100644 --- a/sqlx-cli/src/opt.rs +++ b/sqlx-cli/src/opt.rs @@ -262,7 +262,7 @@ pub struct ConnectOpts { /// However, if your application sets a `journal_mode` on `SqliteConnectOptions` to something /// other than `Wal`, then it will have to take the database file out of WAL mode on connecting, /// which requires an exclusive lock and may return a `database is locked` (`SQLITE_BUSY`) error. - #[cfg(feature = "_sqlite")] + #[cfg(any(feature = "sqlite", feature = "sqlite-unbundled"))] #[clap(long, action = clap::ArgAction::Set, default_value = "true")] pub sqlite_create_db_wal: bool, } diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index dcd8083023..cf1da37052 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -19,7 +19,14 @@ any = [] json = ["serde", "serde_json"] # for conditional compilation -_rt-async-std = ["async-std", "async-io"] +_rt-async-global-executor = [ + "async-global-executor", + "async-io-global-executor", + "async-lock", + "async-net", +] +_rt-async-std = ["async-std", "async-io-std"] +_rt-smol = ["smol"] _rt-tokio = ["tokio", "tokio-stream"] _tls-native-tls = ["native-tls"] _tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"] @@ -33,7 +40,9 @@ offline = ["serde", "either/serde"] [dependencies] # Runtimes +async-global-executor = { workspace = true, optional = true } async-std = { workspace = true, optional = true } +smol = { workspace = true, optional = true } tokio = { workspace = true, optional = true } # TLS @@ -52,9 +61,13 @@ ipnetwork = { workspace = true, optional = true } mac_address = { workspace = true, optional = true } uuid = { workspace = true, optional = true } -async-io = { version = "1.9.0", optional = true } +async-io-global-executor = { package = "async-io", version = "2.2", optional = true } +async-io-std = { package = "async-io", version = "1.9.0", optional = true } +async-lock = { version = "3.4.0", optional = true } +async-net = { package = "async-net", version = "2.0.0", optional = true } base64 = { version = "0.22.0", default-features = false, features = ["std"] } bytes = "1.1.0" +cfg-if = { workspace = true } chrono = { version = "0.4.34", default-features = false, features = ["clock"], optional = true } crc = { version = "3", optional = true } crossbeam-queue = "0.3.2" diff --git a/sqlx-core/src/net/socket/mod.rs b/sqlx-core/src/net/socket/mod.rs index d11f15884e..8d3d71905d 100644 --- a/sqlx-core/src/net/socket/mod.rs +++ b/sqlx-core/src/net/socket/mod.rs @@ -5,6 +5,7 @@ use std::pin::Pin; use std::task::{ready, Context, Poll}; use bytes::BufMut; +use cfg_if::cfg_if; pub use buffered::{BufferedSocket, WriteBuffer}; @@ -202,43 +203,103 @@ pub async fn connect_tcp( return Ok(with_socket.with_socket(stream).await); } - #[cfg(feature = "_rt-async-std")] - { - use async_io::Async; - use async_std::net::ToSocketAddrs; - use std::net::TcpStream; - - let mut last_err = None; - - // Loop through all the Socket Addresses that the hostname resolves to - for socket_addr in (host, port).to_socket_addrs().await? { - let stream = Async::::connect(socket_addr) - .await - .and_then(|s| { - s.get_ref().set_nodelay(true)?; - Ok(s) - }); - match stream { - Ok(stream) => return Ok(with_socket.with_socket(stream).await), - Err(e) => last_err = Some(e), + cfg_if! { + if #[cfg(feature = "_rt-async-global-executor")] { + use async_io_global_executor::Async; + use async_net::resolve; + use std::net::TcpStream; + + let mut last_err = None; + + // Loop through all the Socket Addresses that the hostname resolves to + for socket_addr in resolve((host, port)).await? { + let stream = Async::::connect(socket_addr) + .await + .and_then(|s| { + s.get_ref().set_nodelay(true)?; + Ok(s) + }); + match stream { + Ok(stream) => return Ok(with_socket.with_socket(stream).await), + Err(e) => last_err = Some(e), + } } - } - // If we reach this point, it means we failed to connect to any of the addresses. - // Return the last error we encountered, or a custom error if the hostname didn't resolve to any address. - match last_err { - Some(err) => Err(err.into()), - None => Err(io::Error::new( - io::ErrorKind::AddrNotAvailable, - "Hostname did not resolve to any addresses", - ) - .into()), - } - } + // If we reach this point, it means we failed to connect to any of the addresses. + // Return the last error we encountered, or a custom error if the hostname didn't resolve to any address. + Err(match last_err { + Some(err) => err, + None => io::Error::new( + io::ErrorKind::AddrNotAvailable, + "Hostname did not resolve to any addresses", + ), + } + .into()) + } else if #[cfg(feature = "_rt-async-std")] { + use async_io_std::Async; + use async_std::net::ToSocketAddrs; + use std::net::TcpStream; + + let mut last_err = None; + + // Loop through all the Socket Addresses that the hostname resolves to + for socket_addr in (host, port).to_socket_addrs().await? { + let stream = Async::::connect(socket_addr) + .await + .and_then(|s| { + s.get_ref().set_nodelay(true)?; + Ok(s) + }); + match stream { + Ok(stream) => return Ok(with_socket.with_socket(stream).await), + Err(e) => last_err = Some(e), + } + } - #[cfg(not(feature = "_rt-async-std"))] - { - crate::rt::missing_rt((host, port, with_socket)) + // If we reach this point, it means we failed to connect to any of the addresses. + // Return the last error we encountered, or a custom error if the hostname didn't resolve to any address. + Err(match last_err { + Some(err) => err, + None => io::Error::new( + io::ErrorKind::AddrNotAvailable, + "Hostname did not resolve to any addresses", + ), + } + .into()) + } else if #[cfg(feature = "_rt-smol")] { + use smol::net::resolve; + use smol::Async; + use std::net::TcpStream; + + let mut last_err = None; + + // Loop through all the Socket Addresses that the hostname resolves to + for socket_addr in resolve((host, port)).await? { + let stream = Async::::connect(socket_addr) + .await + .and_then(|s| { + s.get_ref().set_nodelay(true)?; + Ok(s) + }); + match stream { + Ok(stream) => return Ok(with_socket.with_socket(stream).await), + Err(e) => last_err = Some(e), + } + } + + // If we reach this point, it means we failed to connect to any of the addresses. + // Return the last error we encountered, or a custom error if the hostname didn't resolve to any address. + Err(match last_err { + Some(err) => err, + None => io::Error::new( + io::ErrorKind::AddrNotAvailable, + "Hostname did not resolve to any addresses", + ), + } + .into()) + } else { + crate::rt::missing_rt((host, port, with_socket)) + } } } @@ -260,19 +321,31 @@ pub async fn connect_uds, Ws: WithSocket>( return Ok(with_socket.with_socket(stream).await); } - #[cfg(feature = "_rt-async-std")] - { - use async_io::Async; - use std::os::unix::net::UnixStream; + cfg_if! { + if #[cfg(feature = "_rt-async-global-executor")] { + use async_io_global_executor::Async; + use std::os::unix::net::UnixStream; - let stream = Async::::connect(path).await?; + let stream = Async::::connect(path).await?; - Ok(with_socket.with_socket(stream).await) - } + Ok(with_socket.with_socket(stream).await) + } else if #[cfg(feature = "_rt-async-std")] { + use async_io_std::Async; + use std::os::unix::net::UnixStream; + + let stream = Async::::connect(path).await?; + + Ok(with_socket.with_socket(stream).await) + } else if #[cfg(feature = "_rt-smol")] { + use smol::Async; + use std::os::unix::net::UnixStream; - #[cfg(not(feature = "_rt-async-std"))] - { - crate::rt::missing_rt((path, with_socket)) + let stream = Async::::connect(path).await?; + + Ok(with_socket.with_socket(stream).await) + } else { + crate::rt::missing_rt((path, with_socket)) + } } } diff --git a/sqlx-core/src/rt/mod.rs b/sqlx-core/src/rt/mod.rs index 43409073ab..0451092691 100644 --- a/sqlx-core/src/rt/mod.rs +++ b/sqlx-core/src/rt/mod.rs @@ -4,19 +4,35 @@ use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -#[cfg(feature = "_rt-async-std")] -pub mod rt_async_std; +use cfg_if::cfg_if; + +#[cfg(any( + feature = "_rt-async-global-executor", + feature = "_rt-async-std", + feature = "_rt-smol" +))] +pub mod rt_async_io; + +#[cfg(feature = "_rt-async-global-executor")] +pub mod rt_async_global_executor; + +#[cfg(feature = "_rt-smol")] +pub mod rt_smol; #[cfg(feature = "_rt-tokio")] pub mod rt_tokio; #[derive(Debug, thiserror::Error)] #[error("operation timed out")] -pub struct TimeoutError(()); +pub struct TimeoutError; pub enum JoinHandle { + #[cfg(feature = "_rt-async-global-executor")] + AsyncGlobalExecutor(rt_async_global_executor::JoinHandle), #[cfg(feature = "_rt-async-std")] AsyncStd(async_std::task::JoinHandle), + #[cfg(feature = "_rt-smol")] + Smol(rt_smol::JoinHandle), #[cfg(feature = "_rt-tokio")] Tokio(tokio::task::JoinHandle), // `PhantomData` requires `T: Unpin` @@ -28,18 +44,22 @@ pub async fn timeout(duration: Duration, f: F) -> Result(f: F) -> F::Output { - #[cfg(feature = "_rt-tokio")] - { - return tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("failed to start Tokio runtime") - .block_on(f); - } - - #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] - { - async_std::task::block_on(f) - } - - #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))] - { - missing_rt(f) + cfg_if! { + if #[cfg(feature = "_rt-async-global-executor")] { + async_io_global_executor::block_on(f) + } else if #[cfg(feature = "_rt-async-std")] { + async_std::task::block_on(f) + } else if #[cfg(feature = "_rt-smol")] { + smol::block_on(f) + } else if #[cfg(feature = "_rt-tokio")] { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("failed to start Tokio runtime") + .block_on(f) + } else { + missing_rt(f) + } } } @@ -140,7 +183,7 @@ pub fn missing_rt(_unused: T) -> ! { panic!("this functionality requires a Tokio context") } - panic!("either the `runtime-async-std` or `runtime-tokio` feature must be enabled") + panic!("one of the `runtime-async-global-executor`, `runtime-async-std`, `runtime-smol`, or `runtime-tokio` feature must be enabled") } impl Future for JoinHandle { @@ -149,8 +192,12 @@ impl Future for JoinHandle { #[track_caller] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match &mut *self { + #[cfg(feature = "_rt-async-global-executor")] + Self::AsyncGlobalExecutor(handle) => Pin::new(handle).poll(cx), #[cfg(feature = "_rt-async-std")] Self::AsyncStd(handle) => Pin::new(handle).poll(cx), + #[cfg(feature = "_rt-smol")] + Self::Smol(handle) => Pin::new(handle).poll(cx), #[cfg(feature = "_rt-tokio")] Self::Tokio(handle) => Pin::new(handle) .poll(cx) diff --git a/sqlx-core/src/rt/rt_async_global_executor/join_handle.rs b/sqlx-core/src/rt/rt_async_global_executor/join_handle.rs new file mode 100644 index 0000000000..580883e21f --- /dev/null +++ b/sqlx-core/src/rt/rt_async_global_executor/join_handle.rs @@ -0,0 +1,30 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use async_global_executor::Task; + +pub struct JoinHandle { + pub task: Option>, +} + +impl Drop for JoinHandle { + fn drop(&mut self) { + if let Some(task) = self.task.take() { + task.detach(); + } + } +} + +impl Future for JoinHandle { + type Output = T; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.task.as_mut() { + Some(task) => Future::poll(Pin::new(task), cx), + None => unreachable!("JoinHandle polled after dropping"), + } + } +} diff --git a/sqlx-core/src/rt/rt_async_global_executor/mod.rs b/sqlx-core/src/rt/rt_async_global_executor/mod.rs new file mode 100644 index 0000000000..282eb5dbdf --- /dev/null +++ b/sqlx-core/src/rt/rt_async_global_executor/mod.rs @@ -0,0 +1,8 @@ +mod join_handle; +pub use join_handle::*; + +mod timeout; +pub use timeout::*; + +pub mod yield_now; +pub use yield_now::*; diff --git a/sqlx-core/src/rt/rt_async_global_executor/timeout.rs b/sqlx-core/src/rt/rt_async_global_executor/timeout.rs new file mode 100644 index 0000000000..8188758d93 --- /dev/null +++ b/sqlx-core/src/rt/rt_async_global_executor/timeout.rs @@ -0,0 +1,20 @@ +use std::{future::Future, pin::pin, time::Duration}; + +use futures_util::future::{select, Either}; + +use crate::rt::TimeoutError; + +pub async fn sleep(duration: Duration) { + timeout_future(duration).await; +} + +pub async fn timeout(duration: Duration, future: F) -> Result { + match select(pin!(future), timeout_future(duration)).await { + Either::Left((result, _)) => Ok(result), + Either::Right(_) => Err(TimeoutError), + } +} + +fn timeout_future(duration: Duration) -> impl Future { + async_io_global_executor::Timer::after(duration) +} diff --git a/sqlx-core/src/rt/rt_async_global_executor/yield_now.rs b/sqlx-core/src/rt/rt_async_global_executor/yield_now.rs new file mode 100644 index 0000000000..1adb55e0f4 --- /dev/null +++ b/sqlx-core/src/rt/rt_async_global_executor/yield_now.rs @@ -0,0 +1,28 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +pub fn yield_now() -> impl Future { + YieldNow(false) +} + +struct YieldNow(bool); + +impl Future for YieldNow { + type Output = (); + + // The futures executor is implemented as a FIFO queue, so all this future + // does is re-schedule the future back to the end of the queue, giving room + // for other futures to progress. + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if !self.0 { + self.0 = true; + cx.waker().wake_by_ref(); + Poll::Pending + } else { + Poll::Ready(()) + } + } +} diff --git a/sqlx-core/src/rt/rt_async_std/mod.rs b/sqlx-core/src/rt/rt_async_io/mod.rs similarity index 100% rename from sqlx-core/src/rt/rt_async_std/mod.rs rename to sqlx-core/src/rt/rt_async_io/mod.rs diff --git a/sqlx-core/src/rt/rt_async_std/socket.rs b/sqlx-core/src/rt/rt_async_io/socket.rs similarity index 68% rename from sqlx-core/src/rt/rt_async_std/socket.rs rename to sqlx-core/src/rt/rt_async_io/socket.rs index 2d66d70c76..6c30f1a181 100644 --- a/sqlx-core/src/rt/rt_async_std/socket.rs +++ b/sqlx-core/src/rt/rt_async_io/socket.rs @@ -3,19 +3,29 @@ use crate::net::Socket; use std::io; use std::io::{Read, Write}; use std::net::{Shutdown, TcpStream}; - use std::task::{Context, Poll}; +use cfg_if::cfg_if; + use crate::io::ReadBuf; -use async_io::Async; + +cfg_if! { + if #[cfg(feature = "_rt-async-global-executor")] { + use async_io_global_executor::Async; + } else if #[cfg(feature = "_rt-async-std")] { + use async_io_std::Async; + } else if #[cfg(feature = "_rt-smol")] { + use smol::Async; + } +} impl Socket for Async { fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result { - self.get_mut().read(buf.init_mut()) + self.get_ref().read(buf.init_mut()) } fn try_write(&mut self, buf: &[u8]) -> io::Result { - self.get_mut().write(buf) + self.get_ref().write(buf) } fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> { @@ -27,18 +37,18 @@ impl Socket for Async { } fn poll_shutdown(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(self.get_mut().shutdown(Shutdown::Both)) + Poll::Ready(self.get_ref().shutdown(Shutdown::Both)) } } #[cfg(unix)] impl Socket for Async { fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result { - self.get_mut().read(buf.init_mut()) + self.get_ref().read(buf.init_mut()) } fn try_write(&mut self, buf: &[u8]) -> io::Result { - self.get_mut().write(buf) + self.get_ref().write(buf) } fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> { @@ -50,6 +60,6 @@ impl Socket for Async { } fn poll_shutdown(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(self.get_mut().shutdown(Shutdown::Both)) + Poll::Ready(self.get_ref().shutdown(Shutdown::Both)) } } diff --git a/sqlx-core/src/rt/rt_smol/join_handle.rs b/sqlx-core/src/rt/rt_smol/join_handle.rs new file mode 100644 index 0000000000..6702733c4a --- /dev/null +++ b/sqlx-core/src/rt/rt_smol/join_handle.rs @@ -0,0 +1,30 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use smol::Task; + +pub struct JoinHandle { + pub task: Option>, +} + +impl Drop for JoinHandle { + fn drop(&mut self) { + if let Some(task) = self.task.take() { + task.detach(); + } + } +} + +impl Future for JoinHandle { + type Output = T; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.task.as_mut() { + Some(task) => Future::poll(Pin::new(task), cx), + None => unreachable!("JoinHandle polled after dropping"), + } + } +} diff --git a/sqlx-core/src/rt/rt_smol/mod.rs b/sqlx-core/src/rt/rt_smol/mod.rs new file mode 100644 index 0000000000..86f2b8c964 --- /dev/null +++ b/sqlx-core/src/rt/rt_smol/mod.rs @@ -0,0 +1,5 @@ +mod join_handle; +pub use join_handle::*; + +mod timeout; +pub use timeout::*; diff --git a/sqlx-core/src/rt/rt_smol/timeout.rs b/sqlx-core/src/rt/rt_smol/timeout.rs new file mode 100644 index 0000000000..75afff58ba --- /dev/null +++ b/sqlx-core/src/rt/rt_smol/timeout.rs @@ -0,0 +1,20 @@ +use std::{future::Future, pin::pin, time::Duration}; + +use futures_util::future::{select, Either}; + +use crate::rt::TimeoutError; + +pub async fn sleep(duration: Duration) { + timeout_future(duration).await; +} + +pub async fn timeout(duration: Duration, future: F) -> Result { + match select(pin!(future), timeout_future(duration)).await { + Either::Left((result, _)) => Ok(result), + Either::Right(_) => Err(TimeoutError), + } +} + +fn timeout_future(duration: Duration) -> impl Future { + smol::Timer::after(duration) +} diff --git a/sqlx-core/src/sync.rs b/sqlx-core/src/sync.rs index 27ad29c33e..ed082f752c 100644 --- a/sqlx-core/src/sync.rs +++ b/sqlx-core/src/sync.rs @@ -1,17 +1,13 @@ +use cfg_if::cfg_if; + // For types with identical signatures that don't require runtime support, // we can just arbitrarily pick one to use based on what's enabled. // // We'll generally lean towards Tokio's types as those are more featureful // (including `tokio-console` support) and more widely deployed. -#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] -pub use async_std::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard}; - -#[cfg(feature = "_rt-tokio")] -pub use tokio::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard}; - pub struct AsyncSemaphore { - // We use the semaphore from futures-intrusive as the one from async-std + // We use the semaphore from futures-intrusive as the one from async-lock // is missing the ability to add arbitrary permits, and is not guaranteed to be fair: // * https://github.com/smol-rs/async-lock/issues/22 // * https://github.com/smol-rs/async-lock/issues/23 @@ -20,7 +16,14 @@ pub struct AsyncSemaphore { // and there are some soundness concerns (although it turns out any intrusive future is unsound // in MIRI due to the necessitated mutable aliasing): // https://github.com/launchbadge/sqlx/issues/1668 - #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] + #[cfg(all( + any( + feature = "_rt-async-global-executor", + feature = "_rt-async-std", + feature = "_rt-smol" + ), + not(feature = "_rt-tokio") + ))] inner: futures_intrusive::sync::Semaphore, #[cfg(feature = "_rt-tokio")] @@ -30,12 +33,24 @@ pub struct AsyncSemaphore { impl AsyncSemaphore { #[track_caller] pub fn new(fair: bool, permits: usize) -> Self { - if cfg!(not(any(feature = "_rt-async-std", feature = "_rt-tokio"))) { + if cfg!(not(any( + feature = "_rt-async-global-executor", + feature = "_rt-async-std", + feature = "_rt-smol", + feature = "_rt-tokio" + ))) { crate::rt::missing_rt((fair, permits)); } AsyncSemaphore { - #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] + #[cfg(all( + any( + feature = "_rt-async-global-executor", + feature = "_rt-async-std", + feature = "_rt-smol" + ), + not(feature = "_rt-tokio") + ))] inner: futures_intrusive::sync::Semaphore::new(fair, permits), #[cfg(feature = "_rt-tokio")] inner: { @@ -46,61 +61,93 @@ impl AsyncSemaphore { } pub fn permits(&self) -> usize { - #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] - return self.inner.permits(); - - #[cfg(feature = "_rt-tokio")] - return self.inner.available_permits(); - - #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))] - crate::rt::missing_rt(()) + cfg_if! { + if #[cfg(all( + any( + feature = "_rt-async-global-executor", + feature = "_rt-async-std", + feature = "_rt-smol" + ), + not(feature = "_rt-tokio") + ))] { + self.inner.permits() + } else if #[cfg(feature = "_rt-tokio")] { + self.inner.available_permits() + } else { + crate::rt::missing_rt(()) + } + } } pub async fn acquire(&self, permits: u32) -> AsyncSemaphoreReleaser<'_> { - #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] - return AsyncSemaphoreReleaser { - inner: self.inner.acquire(permits as usize).await, - }; - - #[cfg(feature = "_rt-tokio")] - return AsyncSemaphoreReleaser { - inner: self - .inner - // Weird quirk: `tokio::sync::Semaphore` mostly uses `usize` for permit counts, - // but `u32` for this and `try_acquire_many()`. - .acquire_many(permits) - .await - .expect("BUG: we do not expose the `.close()` method"), - }; - - #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))] - crate::rt::missing_rt(permits) + cfg_if! { + if #[cfg(all( + any( + feature = "_rt-async-global-executor", + feature = "_rt-async-std", + feature = "_rt-smol" + ), + not(feature = "_rt-tokio") + ))] { + AsyncSemaphoreReleaser { + inner: self.inner.acquire(permits as usize).await, + } + } else if #[cfg(feature = "_rt-tokio")] { + AsyncSemaphoreReleaser { + inner: self + .inner + // Weird quirk: `tokio::sync::Semaphore` mostly uses `usize` for permit counts, + // but `u32` for this and `try_acquire_many()`. + .acquire_many(permits) + .await + .expect("BUG: we do not expose the `.close()` method"), + } + } else { + crate::rt::missing_rt(permits) + } + } } pub fn try_acquire(&self, permits: u32) -> Option> { - #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] - return Some(AsyncSemaphoreReleaser { - inner: self.inner.try_acquire(permits as usize)?, - }); - - #[cfg(feature = "_rt-tokio")] - return Some(AsyncSemaphoreReleaser { - inner: self.inner.try_acquire_many(permits).ok()?, - }); - - #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))] - crate::rt::missing_rt(permits) + cfg_if! { + if #[cfg(all( + any( + feature = "_rt-async-global-executor", + feature = "_rt-async-std", + feature = "_rt-smol" + ), + not(feature = "_rt-tokio") + ))] { + Some(AsyncSemaphoreReleaser { + inner: self.inner.try_acquire(permits as usize)?, + }) + } else if #[cfg(feature = "_rt-tokio")] { + Some(AsyncSemaphoreReleaser { + inner: self.inner.try_acquire_many(permits).ok()?, + }) + } else { + crate::rt::missing_rt(permits) + } + } } pub fn release(&self, permits: usize) { - #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] - return self.inner.release(permits); - - #[cfg(feature = "_rt-tokio")] - return self.inner.add_permits(permits); - - #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))] - crate::rt::missing_rt(permits) + cfg_if! { + if #[cfg(all( + any( + feature = "_rt-async-global-executor", + feature = "_rt-async-std", + feature = "_rt-smol" + ), + not(feature = "_rt-tokio") + ))] { + self.inner.release(permits); + } else if #[cfg(feature = "_rt-tokio")] { + self.inner.add_permits(permits); + } else { + crate::rt::missing_rt(permits); + } + } } } @@ -114,30 +161,46 @@ pub struct AsyncSemaphoreReleaser<'a> { // and there are some soundness concerns (although it turns out any intrusive future is unsound // in MIRI due to the necessitated mutable aliasing): // https://github.com/launchbadge/sqlx/issues/1668 - #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] + #[cfg(all( + any( + feature = "_rt-async-global-executor", + feature = "_rt-async-std", + feature = "_rt-smol" + ), + not(feature = "_rt-tokio") + ))] inner: futures_intrusive::sync::SemaphoreReleaser<'a>, #[cfg(feature = "_rt-tokio")] inner: tokio::sync::SemaphorePermit<'a>, - #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))] + #[cfg(not(any( + feature = "_rt-async-global-executor", + feature = "_rt-async-std", + feature = "_rt-smol", + feature = "_rt-tokio" + )))] _phantom: std::marker::PhantomData<&'a ()>, } impl AsyncSemaphoreReleaser<'_> { pub fn disarm(self) { - #[cfg(feature = "_rt-tokio")] - { - self.inner.forget(); - } - - #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] - { - let mut this = self; - this.inner.disarm(); + cfg_if! { + if #[cfg(all( + any( + feature = "_rt-async-global-executor", + feature = "_rt-async-std", + feature = "_rt-smol" + ), + not(feature = "_rt-tokio") + ))] { + let mut this = self; + this.inner.disarm(); + } else if #[cfg(feature = "_rt-tokio")] { + self.inner.forget(); + } else { + crate::rt::missing_rt(()); + } } - - #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))] - crate::rt::missing_rt(()) } } diff --git a/sqlx-macros-core/Cargo.toml b/sqlx-macros-core/Cargo.toml index 46786b7d8d..273bc52187 100644 --- a/sqlx-macros-core/Cargo.toml +++ b/sqlx-macros-core/Cargo.toml @@ -11,7 +11,9 @@ repository.workspace = true default = [] # for conditional compilation +_rt-async-global-executor = ["async-global-executor", "sqlx-core/_rt-async-global-executor"] _rt-async-std = ["async-std", "sqlx-core/_rt-async-std"] +_rt-smol = ["smol", "sqlx-core/_rt-smol"] _rt-tokio = ["tokio", "sqlx-core/_rt-tokio"] _tls-native-tls = ["sqlx-core/_tls-native-tls"] @@ -50,9 +52,12 @@ sqlx-mysql = { workspace = true, features = ["offline", "migrate"], optional = t sqlx-postgres = { workspace = true, features = ["offline", "migrate"], optional = true } sqlx-sqlite = { workspace = true, features = ["offline", "migrate"], optional = true } +async-global-executor = { workspace = true, optional = true } async-std = { workspace = true, optional = true } +smol = { workspace = true, optional = true } tokio = { workspace = true, optional = true } +cfg-if = { workspace = true} dotenvy = { workspace = true } hex = { version = "0.4.3" } diff --git a/sqlx-macros-core/src/lib.rs b/sqlx-macros-core/src/lib.rs index e8804f57fe..bacb62f52e 100644 --- a/sqlx-macros-core/src/lib.rs +++ b/sqlx-macros-core/src/lib.rs @@ -19,6 +19,8 @@ feature(track_path) )] +use cfg_if::cfg_if; + #[cfg(feature = "macros")] use crate::query::QueryDriver; @@ -55,28 +57,29 @@ pub fn block_on(f: F) -> F::Output where F: std::future::Future, { - #[cfg(feature = "_rt-tokio")] - { - use once_cell::sync::Lazy; - use tokio::runtime::{self, Runtime}; - - // We need a single, persistent Tokio runtime since we're caching connections, - // otherwise we'll get "IO driver has terminated" errors. - static TOKIO_RT: Lazy = Lazy::new(|| { - runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("failed to start Tokio runtime") - }); + cfg_if! { + if #[cfg(feature = "_rt-async-global-executor")] { + sqlx_core::rt::test_block_on(f) + } else if #[cfg(feature = "_rt-async-std")] { + async_std::task::block_on(f) + } else if #[cfg(feature = "_rt-smol")] { + sqlx_core::rt::test_block_on(f) + } else if #[cfg(feature = "_rt-tokio")] { + use once_cell::sync::Lazy; + use tokio::runtime::{self, Runtime}; - TOKIO_RT.block_on(f) - } + // We need a single, persistent Tokio runtime since we're caching connections, + // otherwise we'll get "IO driver has terminated" errors. + static TOKIO_RT: Lazy = Lazy::new(|| { + runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("failed to start Tokio runtime") + }); - #[cfg(all(feature = "_rt-async-std", not(feature = "tokio")))] - { - async_std::task::block_on(f) + TOKIO_RT.block_on(f) + } else { + sqlx_core::rt::missing_rt(f) + } } - - #[cfg(not(any(feature = "_rt-async-std", feature = "tokio")))] - sqlx_core::rt::missing_rt(f) } diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 5617d3f251..c0fda68ec9 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -14,7 +14,9 @@ proc-macro = true default = [] # for conditional compilation +_rt-async-global-executor = ["sqlx-macros-core/_rt-async-global-executor"] _rt-async-std = ["sqlx-macros-core/_rt-async-std"] +_rt-smol = ["sqlx-macros-core/_rt-smol"] _rt-tokio = ["sqlx-macros-core/_rt-tokio"] _tls-native-tls = ["sqlx-macros-core/_tls-native-tls"]