diff --git a/Cargo.lock b/Cargo.lock index b51238e9..e62ba96f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2209,6 +2209,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive-where" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "510c292c8cf384b1a340b816a9a6cf2599eb8f566a44949024af88418000c50b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "derive_builder" version = "0.13.1" @@ -8070,7 +8081,9 @@ name = "wcn_rpc" version = "0.1.0" dependencies = [ "backoff", + "bytes", "derivative", + "derive-where", "derive_more 1.0.0", "futures", "governor", @@ -8081,12 +8094,14 @@ dependencies = [ "mini-moka", "nix", "pin-project", + "postcard", "quinn", "quinn-proto", "rand 0.8.5", "serde", "serde_json", "socket2", + "strum", "tap", "thiserror 1.0.64", "tokio", @@ -8118,12 +8133,18 @@ dependencies = [ name = "wcn_storage_api2" version = "0.1.0" dependencies = [ - "arc-swap", + "const-hex", + "derive_more 1.0.0", "futures", + "rand 0.8.5", "serde", + "strum", + "tap", "thiserror 1.0.64", "time", + "tokio", "tracing", + "tracing-subscriber", "wc", "wcn_auth", "wcn_rpc", diff --git a/Cargo.toml b/Cargo.toml index ae2fcead..2687c4f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,8 @@ derive_more = { version = "1.0.0", features = [ "deref", ] } derivative = "2" +derive-where = "1.5" +tap = "1.0" core = { package = "wcn_core", path = "crates/core" } auth = { package = "wcn_auth", path = "crates/auth" } domain = { package = "wcn_domain", path = "crates/domain" } @@ -50,11 +52,13 @@ time = "0.3" libp2p = { version = "0.55", default-features = false, features = ["serde"] } tokio-serde = { git = "https://github.com/xDarksome/tokio-serde.git", rev = "6df9ff9" } tokio-serde-postcard = { git = "https://github.com/xDarksome/tokio-serde-postcard.git", rev = "5e1b77a" } +postcard = { version = "1.0", default-features = false } itertools = "0.12" futures = "0.3" backoff = { version = "0.4", features = ["tokio"] } tracing = "0.1" tokio-stream = "0.1" +strum = "0.27" [workspace.lints.clippy] all = { level = "deny", priority = -1 } diff --git a/VERSION b/VERSION index b166d569..443a1952 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -250620.0 +250702.0 diff --git a/crates/rpc/Cargo.toml b/crates/rpc/Cargo.toml index c0887893..26d7643d 100644 --- a/crates/rpc/Cargo.toml +++ b/crates/rpc/Cargo.toml @@ -12,7 +12,9 @@ client = [] server = [] [dependencies] -derive_more = { workspace = true } +derive_more = { workspace = true, features = ["display", "try_from"] } +derive-where = { workspace = true } +strum = { workspace = true , features = ["derive"] } derivative = "2.2" futures = "0.3" indexmap = "2" @@ -26,6 +28,8 @@ tokio-stream = "0.1" tokio-util = { version = "0.7", features = ["compat", "codec"] } tokio-serde = { workspace = true, features = ["json"] } tokio-serde-postcard = { workspace = true } +postcard = { workspace = true, features = ["alloc"] } +bytes = "1" serde_json = "1" rand = { version = "0.8", features = ["small_rng"] } libp2p = { workspace = true } diff --git a/crates/rpc/src/client2.rs b/crates/rpc/src/client2.rs new file mode 100644 index 00000000..208b8991 --- /dev/null +++ b/crates/rpc/src/client2.rs @@ -0,0 +1,570 @@ +use { + crate::{ + quic::{self}, + transport, + transport2::{self, BiDirectionalStream, RecvStream, SendStream}, + BorrowedRequest, + ConnectionStatusCode, + RpcV2, + UnaryRpc, + }, + derive_where::derive_where, + futures::{ + sink::SinkMapErr, + stream::MapErr, + FutureExt, + Sink, + SinkExt, + Stream, + StreamExt as _, + TryStreamExt, + }, + libp2p::{identity, PeerId}, + std::{ + future::Future, + io, + net::{SocketAddr, SocketAddrV4}, + sync::Arc, + time::Duration, + }, + strum::{EnumDiscriminants, IntoDiscriminant, IntoStaticStr}, + tap::TapFallible as _, + tokio::{ + io::{AsyncReadExt as _, AsyncWriteExt}, + sync::{watch, Mutex}, + }, + wc::{ + future::FutureExt as _, + metrics::{self, enum_ordinalize::Ordinalize, EnumLabel, StringLabel}, + }, +}; + +// TODO: metrics, timeouts + +/// Client-specific part of an RPC [Api][`super::Api`]. +pub trait Api: super::Api + Sized { + /// Outbound [`Connection`] parameters. + type ConnectionParameters: Clone + Send + Sync + 'static; + + /// Implementor of [`HandleConnection`] of this RPC [`Api`]. + type ConnectionHandler: HandleConnection; + + //// Implementor of [`HandleRpc`] for the RPCs of this RPC [`Api`]. + type RpcHandler: Send + Sync + 'static; +} + +/// Handler of newly established outbound [`Connection`]s. +pub trait HandleConnection: Clone + Send + Sync + 'static { + /// Creates a new instance of [`Api::RpcHandler`]. + /// + /// Each outbound [`Connection`] gets a separate RPC handler. + fn new_rpc_handler(&self) -> API::RpcHandler; + + /// Handles the provided outbound [`Connection`]. + fn handle_connection( + &self, + conn: &Connection, + params: &API::ConnectionParameters, + ) -> impl Future> + Send; +} + +/// Handler of [`Outbound`] RPCs. +pub trait HandleRpc: Send + Sync + 'static { + type Input<'a>: Send + Sync + 'a; + type Output; + + fn handle_rpc<'a>( + &'a self, + rpc: Outbound, + input: &'a Self::Input<'a>, + ) -> impl Future> + Send + 'a; +} + +type RpcHandlerInput<'a, H, RPC> = >::Input<'a>; +type RpcHandlerOutput = >::Output; + +/// RPC [`Client`] config. +#[derive(Clone, Debug)] +pub struct Config { + /// [`identity::Keypair`] of the client. + pub keypair: identity::Keypair, + + /// Timeout of establishing an outbound connection. + pub connection_timeout: Duration, + + /// Highest allowed frequency of connection retries. + pub reconnect_interval: Duration, + + /// Maximum number of concurrent RPCs. + pub max_concurrent_rpcs: u32, + + /// [`transport::Priority`] of the client. + pub priority: transport::Priority, +} + +/// RPC client responsible for establishing outbound [`Connection`]s to remote +/// peers. +#[derive_where(Clone)] +pub struct Client { + config: Arc, + endpoint: quinn::Endpoint, + + connection_handler: API::ConnectionHandler, +} + +impl Client { + /// Creates a new RPC [`Client`]. + pub fn new(cfg: Config, connection_handler: API::ConnectionHandler) -> Result { + let transport_config = quic::new_quinn_transport_config(cfg.max_concurrent_rpcs); + let socket_addr = SocketAddr::new(std::net::Ipv4Addr::new(0, 0, 0, 0).into(), 0); + let endpoint = quic::new_quinn_endpoint( + socket_addr, + &cfg.keypair, + transport_config, + None, + cfg.priority, + ) + .map_err(Error::new)?; + + Ok(Client { + config: Arc::new(cfg), + endpoint, + connection_handler, + }) + } + + /// Establishes a new outbound [`Connection`]. + pub async fn connect( + &self, + addr: SocketAddrV4, + peer_id: &PeerId, + params: API::ConnectionParameters, + ) -> Result> { + async { + // `libp2p_tls` uses this "l" placeholder as server_name. + let conn = self.endpoint.connect(addr.into(), "l")?.await?; + + let remote_peer_id = quic::connection_peer_id(&conn)?; + + if *peer_id != remote_peer_id { + tracing::warn!( + expected = ?peer_id, + got = ?&remote_peer_id, + addr = ?addr, + "Wrong PeerId" + ); + + return Err(ErrorInner::WrongPeerId(remote_peer_id)); + } + + // handshake + let (mut tx, mut rx) = conn.open_bi().await?; + tx.write_u32(super::PROTOCOL_VERSION).await?; + tx.write_all(&API::NAME.0).await?; + check_connection_status(rx.read_i32().await?)?; + + let conn = self.new_connection_inner(addr, peer_id, params, Some(conn)); + + // we just created the `Connection`, the lock can't be locked + // NOTE: by holding this guard here we are also making sure that + // `ConnectionHandler::handle_connection` won't get into infinite recursion by + // trying to reconnect + let guard = conn.inner.watch_tx.try_lock().unwrap(); + let params = &guard.1; + + self.connection_handler + .handle_connection(&conn, params) + .await + .map_err(|err| err.0)?; + + drop(guard); + + tracing::info!( + api = %API::NAME, + addr = %conn.remote_peer_addr(), + peer_id = %conn.remote_peer_id(), + "Connection established" + ); + + Ok(conn) + } + .with_timeout(self.config.connection_timeout) + .await + .map_err(|_| ErrorInner::Timeout)? + .map_err(Error::new) + } + + /// Creates a new outbound [`Connection`] without waiting for it to be + /// established. + pub fn new_connection( + &self, + addr: SocketAddrV4, + peer_id: &PeerId, + params: API::ConnectionParameters, + ) -> Connection { + let conn = self.new_connection_inner(addr, peer_id, params, None); + conn.reconnect(); + conn + } + + fn new_connection_inner( + &self, + addr: SocketAddrV4, + peer_id: &PeerId, + params: API::ConnectionParameters, + quic: Option, + ) -> Connection { + let (tx, rx) = watch::channel(quic); + + Connection { + inner: Arc::new(ConnectionInner { + client: self.clone(), + remote_addr: addr, + remote_peer_id: *peer_id, + watch_rx: rx, + watch_tx: Arc::new(tokio::sync::Mutex::new((tx, params))), + rpc_handler: self.connection_handler.new_rpc_handler(), + }), + } + } +} + +/// Default implementation of [`ConnectionHandler`]. +/// +/// No-op, doesn't do anything with the [`Connection`]. +#[derive(Clone, Copy, Debug, Default)] +pub struct ConnectionHandler; + +impl HandleConnection for ConnectionHandler +where + API: Api, +{ + fn new_rpc_handler(&self) -> ::RpcHandler { + RpcHandler + } + + async fn handle_connection(&self, _conn: &Connection, _params: &()) -> Result<()> { + Ok(()) + } +} + +/// Outbound RPC of a specific type. +pub struct Outbound { + #[allow(clippy::type_complexity)] + send: SinkMapErr, fn(transport2::Error) -> Error>, + + #[allow(clippy::type_complexity)] + recv: MapErr, fn(transport2::Error) -> Error>, +} + +impl Outbound { + /// Returns mutable references to the underlying request/response streams. + pub fn streams_mut( + &mut self, + ) -> ( + &mut (impl for<'a, 'b> Sink<&'a BorrowedRequest<'b, RPC>, Error = Error> + 'static), + &mut impl Stream>, + ) { + (&mut self.send, &mut self.recv) + } +} + +/// Outbound connection. +/// +/// Existence of an instance of this type doesn't guarantee that the actual +/// network connection is already established (or will ever be established). +#[derive_where(Clone)] +pub struct Connection { + inner: Arc>, +} + +type ConnectionMutex = Mutex<(watch::Sender>, Params)>; + +struct ConnectionInner { + client: Client, + + remote_addr: SocketAddrV4, + remote_peer_id: PeerId, + + watch_rx: watch::Receiver>, + watch_tx: Arc>, + + rpc_handler: API::RpcHandler, +} + +impl Connection { + /// Returns [`SocketAddrV4`] of the remote peer. + pub fn remote_peer_addr(&self) -> &SocketAddrV4 { + &self.inner.remote_addr + } + + /// Returns [`PeerId`] of the remote peer. + pub fn remote_peer_id(&self) -> &PeerId { + &self.inner.remote_peer_id + } + + /// Indicates whether this [`Connection`] is closed. + pub fn is_closed(&self) -> bool { + self.inner + .watch_rx + .borrow() + .as_ref() + .map(|conn| conn.close_reason().is_some()) + .unwrap_or(true) + } + + /// Waits for this [`Connection`] to become open. + /// + /// IMPORTANT: This future may never resolve! Make sure that you use a + /// timeout. + pub async fn wait_open(&self) { + let mut watch_rx = self.inner.watch_rx.clone(); + match watch_rx.borrow_and_update().as_ref() { + Some(conn) if conn.close_reason().is_none() => return, + _ => {} + } + + self.reconnect(); + + drop(watch_rx.changed().await) + } + + /// Sends the provided RPC over this [`Connection`]. + pub fn send<'a, RPC: RpcV2>( + &'a self, + input: &'a RpcHandlerInput<'a, API::RpcHandler, RPC>, + ) -> Result>> + Send + 'a> + where + API::RpcHandler: HandleRpc, + { + let rpc = self.new_outbound_rpc::().tap_err(|err| { + if err.requires_reconnect() { + self.reconnect(); + } + })?; + + Ok(async move { + self.inner + .rpc_handler + .handle_rpc(rpc, input) + .await + .tap_err(|err| { + if err.0.requires_reconnect() { + self.reconnect(); + } + }) + }) + } + + fn new_outbound_rpc(&self) -> Result, ErrorInner> { + let quic = self.inner.watch_rx.borrow(); + let Some(conn) = quic.as_ref() else { + return Err(ErrorInner::NotConnected); + }; + + // `open_bi` only blocks if there are too many outbound streams. + let (mut tx, rx) = match conn.open_bi().now_or_never() { + Some(Ok(stream)) => stream, + Some(Err(err)) => return Err(err.into()), + None => return Err(ErrorInner::TooManyConcurrentRpcs), + }; + + // This can only block if send buffer is full. + tx.write_u8(RPC::ID) + .now_or_never() + .ok_or_else(|| ErrorInner::SendBufferFull)??; + + let (recv, send) = BiDirectionalStream::new(tx, rx).upgrade::(); + + Ok(Outbound { + send: SinkExt::<&RPC::Request>::sink_map_err(send, |err: transport2::Error| { + Error::new(err) + }), + recv: recv.map_err(Error::new), + }) + } + + fn reconnect(&self) { + // If we can't acquire the lock then reconnection is already in progress. + let Ok(guard) = self.inner.watch_tx.clone().try_lock_owned() else { + return; + }; + + let this = self.inner.clone(); + + tokio::spawn(async move { + let mut interval = tokio::time::interval(this.client.config.reconnect_interval); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + loop { + interval.tick().await; + + let res = this + .client + .connect(this.remote_addr, &this.remote_peer_id, guard.1.clone()) + .await; + + match res { + Ok(conn) => { + let quic = conn.inner.watch_rx.borrow(); + + // should always be `Some`, as we just established the connection + let _ = guard.0.send(Some(quic.as_ref().unwrap().clone())); + return; + } + Err(err) => { + metrics::counter!( + "wcn_rpc_client_connection_errors", + StringLabel<"remote_addr", SocketAddrV4> => &this.remote_addr, + StringLabel<"remote_peer_id", PeerId> => &this.remote_peer_id, + EnumLabel<"kind", ErrorKind> => err.0.discriminant() + ) + .increment(1); + } + } + } + }); + } +} + +/// Default implementation of [`HandleRpc`]. +/// +/// Automatically implements [`HandleRpc`] for all [`UnaryRpc`]s. +/// +/// You'll need to provide a manual implementation of [`HandleRpc`] for your +/// custom RPCs. +#[derive(Clone, Copy, Debug, Default)] +pub struct RpcHandler; + +impl HandleRpc for RpcHandler { + type Input<'a> = BorrowedRequest<'a, RPC>; + type Output = RPC::Response; + + async fn handle_rpc<'a>( + &'a self, + mut rpc: Outbound, + req: &'a BorrowedRequest<'a, RPC>, + ) -> Result { + let (tx, rx) = rpc.streams_mut(); + tx.send(req).await?; + rx.next() + .await + .ok_or_else(|| Error::new(transport2::Error::StreamFinished))? + } +} + +/// RPC [`Client`] error. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct Error(ErrorInner); + +/// RPC [`Client`] result. +pub type Result = std::result::Result; + +impl Error { + fn new(err: impl Into) -> Self { + Self(err.into()) + } +} + +impl From for Error { + fn from(err: ErrorInner) -> Self { + Self::new(err) + } +} + +#[derive(Debug, thiserror::Error, EnumDiscriminants)] +#[strum_discriminants(name(ErrorKind))] +#[strum_discriminants(derive(Ordinalize, IntoStaticStr))] +enum ErrorInner { + #[error("Not connected")] + NotConnected, + + #[error("QUIC: {0}")] + Quic(#[from] quic::Error), + + #[error(transparent)] + ExtractPeerId(#[from] quic::ExtractPeerIdError), + + #[error("Wrong PeerId: {0}")] + WrongPeerId(PeerId), + + #[error("Connect: {0:?}")] + Connect(#[from] quinn::ConnectError), + + #[error("Connection: {0:?}")] + Connection(#[from] quinn::ConnectionError), + + #[error("IO: {0:?}")] + Io(#[from] io::Error), + + #[error("Write: {0:?}")] + Write(#[from] quinn::WriteError), + + #[error("Too many concurrent RPCs")] + TooManyConcurrentRpcs, + + #[error("Send buffer is full")] + SendBufferFull, + + #[error("Timeout")] + Timeout, + + #[error("Transport: {0}")] + Transport(#[from] transport2::Error), + + #[error("Unknown ConnectionStatusCode({0})")] + UnknownConnectionStatusCode(i32), + + #[error("Unsupported protocol")] + UnsupportedProtocol, + + #[error("Unknown API")] + UnknownApi, + + #[error("Unauthorized")] + Unauthorized, +} + +impl ErrorInner { + fn requires_reconnect(&self) -> bool { + match self { + ErrorInner::NotConnected + | ErrorInner::Quic(_) + | ErrorInner::ExtractPeerId(_) + | ErrorInner::WrongPeerId(_) + | ErrorInner::Connect(_) + | ErrorInner::Connection(_) + | ErrorInner::Io(_) + | ErrorInner::Write(_) + | ErrorInner::Timeout + | ErrorInner::Transport(_) + | ErrorInner::UnknownConnectionStatusCode(_) + | ErrorInner::UnsupportedProtocol + | ErrorInner::UnknownApi + | ErrorInner::Unauthorized => true, + + ErrorInner::TooManyConcurrentRpcs | ErrorInner::SendBufferFull => false, + } + } +} + +fn check_connection_status(code: i32) -> Result<(), ErrorInner> { + let code = ConnectionStatusCode::try_from(code) + .map_err(|err| ErrorInner::UnknownConnectionStatusCode(err.input))?; + + Err(match code { + ConnectionStatusCode::Ok => return Ok(()), + ConnectionStatusCode::UnsupportedProtocol => ErrorInner::UnsupportedProtocol, + ConnectionStatusCode::UnknownApi => ErrorInner::UnknownApi, + ConnectionStatusCode::Unauthorized => ErrorInner::Unauthorized, + }) +} + +impl metrics::Enum for ErrorKind { + fn as_str(&self) -> &'static str { + self.into() + } +} + +// TODO: Vec Load Balancer diff --git a/crates/rpc/src/lib.rs b/crates/rpc/src/lib.rs index 66b5e887..d09f5a32 100644 --- a/crates/rpc/src/lib.rs +++ b/crates/rpc/src/lib.rs @@ -1,32 +1,120 @@ #![allow(async_fn_in_trait)] #![allow(clippy::manual_async_fn)] -pub use libp2p::{identity, Multiaddr, PeerId}; use { - derive_more::Display, - serde::{Deserialize, Serialize}, + derive_more::{derive::TryFrom, Display}, + serde::{de::DeserializeOwned, Deserialize, Serialize}, std::{borrow::Cow, fmt::Debug, marker::PhantomData, net::SocketAddr, str::FromStr}, transport::Codec, + transport2::Codec as CodecV2, +}; +pub use { + libp2p::{identity, Multiaddr, PeerId}, + transport2::PostcardCodec, }; #[cfg(feature = "client")] pub mod client; #[cfg(feature = "client")] +pub mod client2; +#[cfg(feature = "client")] pub use client::Client; #[cfg(feature = "server")] pub mod server; #[cfg(feature = "server")] +pub mod server2; +#[cfg(feature = "server")] pub use server::{IntoServer, Server}; pub mod middleware; pub mod quic; pub mod transport; +mod transport2; #[cfg(test)] mod test; +const PROTOCOL_VERSION: u32 = 0; + +/// RPC API specification. +pub trait Api: Clone + Send + Sync + 'static { + /// [`ApiName`] of this [`Api`]. + const NAME: ApiName; + + /// `enum` representation of all RPC IDs of this [`Api`]. + type RpcId: Copy + Into + TryFrom + Send + Sync + 'static; +} + +/// [`Api`] name. +pub type ApiName = ServerName; + +/// Remote procedure call. +pub trait RpcV2: Sized + Send + Sync + 'static { + /// ID of this [`Rpc`]. + const ID: u8; + + /// Request type of this [`Rpc`]. + type Request: MessageV2; + + /// Response type of this [`Rpc`]. + type Response: MessageV2; + + /// Serialization codec of this [`Rpc`]. + type Codec: CodecV2 + CodecV2; +} + +/// [`RpcV2::Request`]. +pub type Request = ::Request; + +/// [`RpcV2::Response`]. +pub type Response = ::Response; + +/// [`MessageV2::Borrowed`] of [`RpcV2::Request`]. +pub type BorrowedRequest<'a, RPC> = <::Request as MessageV2>::Borrowed<'a>; + +/// [`MessageV2::Borrowed`] of [`RpcV2::Response`]. +pub type BorrowedResponse<'a, RPC> = <::Response as MessageV2>::Borrowed<'a>; + +/// Request-response RPC. +pub trait UnaryRpc: RpcV2 {} + +/// Default implementation of [`UnaryRpc`]. +pub struct UnaryV2 + CodecV2> { + _marker: PhantomData<(Req, Resp, C)>, +} + +impl RpcV2 for UnaryV2 +where + Req: MessageV2, + Resp: MessageV2, + C: CodecV2 + CodecV2, +{ + const ID: u8 = ID; + type Request = Req; + type Response = Resp; + type Codec = C; +} + +impl UnaryRpc for UnaryV2 +where + Req: MessageV2, + Resp: MessageV2, + C: CodecV2 + CodecV2, +{ +} + +/// RPC message. +pub trait MessageV2: DeserializeOwned + Serialize + Unpin + Sync + Send + 'static { + type Borrowed<'a>: BorrowedMessage; +} + +/// Borrowed [Message][`MessageV2`]. +pub trait BorrowedMessage: Serialize + Unpin + Sync + Send {} + +impl BorrowedMessage for T where T: Serialize + Unpin + Sync + Send {} + /// Error codes produced by this module. pub mod error_code { #[allow(unused_imports)] @@ -120,7 +208,8 @@ impl Name { } /// RPC server name. -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +#[derive(Debug, Display, Clone, Copy, Hash, PartialEq, Eq)] +#[display("{}", self.as_str())] pub struct ServerName([u8; 16]); impl ServerName { @@ -293,3 +382,37 @@ impl Debug for PeerAddr { Debug::fmt(&self.to_string(), f) } } + +#[derive(Clone, Copy, Debug, TryFrom)] +#[try_from(repr)] +#[repr(i32)] +enum ConnectionStatusCode { + Ok = 0, + + UnsupportedProtocol = -1, + UnknownApi = -2, + Unauthorized = -3, +} + +impl MessageV2 for Result +where + T: MessageV2, + E: MessageV2, +{ + type Borrowed<'a> = Result, E::Borrowed<'a>>; +} + +impl MessageV2 for Option +where + T: MessageV2, +{ + type Borrowed<'a> = Option>; +} + +impl MessageV2 for () { + type Borrowed<'a> = (); +} + +impl MessageV2 for u64 { + type Borrowed<'a> = Self; +} diff --git a/crates/rpc/src/quic/mod.rs b/crates/rpc/src/quic/mod.rs index 3fae7f25..40806aa4 100644 --- a/crates/rpc/src/quic/mod.rs +++ b/crates/rpc/src/quic/mod.rs @@ -29,16 +29,17 @@ mod metrics; const PROTOCOL_VERSION: u32 = 1; -#[derive(Default)] -struct ConnectionHeader { - server_name: Option, +pub(crate) struct ConnectionHeader { + pub server_name: Option, } #[derive(Clone, Debug, thiserror::Error, Eq, PartialEq)] #[error("{0}: invalid QUIC Multiaddr")] pub struct InvalidMultiaddrError(Multiaddr); -fn new_quinn_transport_config(max_concurrent_streams: u32) -> Arc { +pub(crate) fn new_quinn_transport_config( + max_concurrent_streams: u32, +) -> Arc { const STREAM_WINDOW: u32 = 4 * 1024 * 1024; // 4 MiB // Our tests are too slow and connections get dropped because of missing keep @@ -61,7 +62,7 @@ fn new_quinn_transport_config(max_concurrent_streams: u32) -> Arc, @@ -177,7 +178,7 @@ pub enum Error { InvalidConnectionRate, } -fn connection_peer_id(conn: &quinn::Connection) -> Result { +pub(crate) fn connection_peer_id(conn: &quinn::Connection) -> Result { use ExtractPeerIdError as Error; let identity = conn.peer_identity().ok_or(Error::MissingPeerIdentity)?; diff --git a/crates/rpc/src/quic/server.rs b/crates/rpc/src/quic/server.rs index 697c37dc..4b5c12a5 100644 --- a/crates/rpc/src/quic/server.rs +++ b/crates/rpc/src/quic/server.rs @@ -24,7 +24,7 @@ use { }, }; -mod filter; +pub(crate) mod filter; /// QUIC RPC server config. pub struct Config { @@ -66,7 +66,11 @@ where S: Send + Sync + 'static, Server: Multiplexer, { - let filter = Filter::new(&cfg)?; + let filter = Filter::new(&filter::Config { + max_connections: cfg.max_connections, + max_connections_per_ip: cfg.max_connections_per_ip, + max_connection_rate_per_ip: cfg.max_connection_rate_per_ip, + })?; let server = Server::new(rpc_servers, cfg)?; Ok(server.serve(filter)) } diff --git a/crates/rpc/src/quic/server/filter.rs b/crates/rpc/src/quic/server/filter.rs index 5d893c6d..82078cf1 100644 --- a/crates/rpc/src/quic/server/filter.rs +++ b/crates/rpc/src/quic/server/filter.rs @@ -21,8 +21,14 @@ pub struct Filter { max_connections_per_ip: u32, } +pub(crate) struct Config { + pub max_connections: u32, + pub max_connections_per_ip: u32, + pub max_connection_rate_per_ip: u32, +} + impl Filter { - pub fn new(cfg: &super::Config) -> Result { + pub fn new(cfg: &Config) -> Result { let max_connection_rate: NonZeroU32 = cfg .max_connection_rate_per_ip .try_into() diff --git a/crates/rpc/src/server2.rs b/crates/rpc/src/server2.rs new file mode 100644 index 00000000..c1732c4f --- /dev/null +++ b/crates/rpc/src/server2.rs @@ -0,0 +1,602 @@ +use { + crate::{ + self as rpc, + quic::{ + self, + server::filter::{self, Filter, RejectionReason}, + }, + transport, + transport2::{self, BiDirectionalStream, RecvStream, SendStream}, + Api, + ApiName, + BorrowedResponse, + ConnectionStatusCode, + RpcV2, + ServerName, + UnaryRpc, + }, + derive_where::derive_where, + futures::{ + sink::SinkMapErr, + stream::MapErr, + FutureExt, + Sink, + SinkExt, + Stream, + StreamExt, + TryFutureExt as _, + TryStreamExt as _, + }, + libp2p::{identity, PeerId}, + quinn::crypto::rustls::{self, QuicServerConfig}, + std::{future::Future, io, marker::PhantomData, sync::Arc, time::Duration}, + tap::Pipe as _, + tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::{OwnedSemaphorePermit, Semaphore}, + }, + wc::{ + future::{FutureExt as _, StaticFutureExt}, + metrics::{self, future_metrics, Enum as _, EnumLabel, FutureExt as _, StringLabel}, + }, +}; + +// TODO: Authorization, metrics, timeouts + +/// Handler of newly established inbound [`Connection`]s. +pub trait HandleConnection: Clone + Send + Sync + 'static { + /// RPC [`Api`] the connections of which are being handled. + type Api: Api; + + /// Handles the provided inbound [`Connection`]. + fn handle_connection( + &self, + conn: Connection<'_, Self::Api>, + ) -> impl Future> + Send; +} + +/// Handler of [`Inbound`] RPCs. +pub trait HandleRpc: Send + Sync { + /// Handles the provided [`Inbound`] RPC. + fn handle_rpc<'a>( + &'a self, + rpc: &'a mut Inbound, + ) -> impl Future> + Send + 'a; +} + +/// [`HandleRpc`] specialization for [`UnaryRpc`]s. +pub trait HandleRequest: Send + Sync { + /// Handles the provided RPC request. + fn handle_request( + &self, + request: RPC::Request, + ) -> impl Future + Send + '_; +} + +impl HandleRpc for H +where + RPC: UnaryRpc, + H: HandleRequest, +{ + fn handle_rpc<'a>( + &'a self, + rpc: &'a mut Inbound, + ) -> impl Future> + Send + 'a { + async { + let req = rpc + .recv + .next() + .await + .ok_or_else(|| Error::new(transport2::Error::StreamFinished))??; + + let resp = self.handle_request(req).await; + + rpc.send.send(&resp).await?; + + Ok(()) + } + } +} + +/// RPC server config. +pub struct Config { + /// Name of the server. For metrics purposes only. + pub name: &'static str, + + /// [`Multiaddr`] to bind the server to. + pub port: u16, + + /// [`identity::Keypair`] of the server. + pub keypair: identity::Keypair, + + /// Timeout of establishing an inbound connection. + pub connection_timeout: Duration, + + /// Maximum global number of concurrent connections. + pub max_connections: u32, + + /// Maximum number of concurrent connections per client IP address. + pub max_connections_per_ip: u32, + + /// Maximum number of connections accepted per client IP address per second. + pub max_connection_rate_per_ip: u32, + + /// Maximum number of concurrent RPCs. + pub max_concurrent_rpcs: u32, + + /// [`transport::Priority`] of the server. + pub priority: transport::Priority, +} + +/// Creates a new RPC [`Api`] server. +pub fn new(connection_handler: impl HandleConnection) -> impl Server { + ApiServer { connection_handler } +} + +/// RPC server. +pub trait Server: Sized + Send + Sync + 'static +where + Self: sealed::ConnectionRouter, +{ + /// Multiplexes `self` with another [`Server`]. + fn multiplex(self, api_server: impl Server) -> impl Server { + Multiplexer { + head: api_server, + tail: self, + } + } + + /// Runs this RPC [`Server`] + fn serve(self, cfg: Config) -> Result + Send> { + let transport_config = quic::new_quinn_transport_config(cfg.max_concurrent_rpcs); + let server_tls_config = libp2p_tls::make_server_config(&cfg.keypair).map_err(Error::new)?; + let server_tls_config = + QuicServerConfig::try_from(server_tls_config).map_err(Error::new)?; + let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_tls_config)); + server_config.transport = transport_config.clone(); + server_config.migration(false); + + let endpoint = quic::new_quinn_endpoint( + ([0, 0, 0, 0], cfg.port).into(), + &cfg.keypair, + transport_config, + Some(server_config), + cfg.priority, + ) + .map_err(Error::new)?; + + let connection_filter = Filter::new(&filter::Config { + max_connections: cfg.max_connections, + max_connections_per_ip: cfg.max_connections_per_ip, + max_connection_rate_per_ip: cfg.max_connection_rate_per_ip, + }) + .map_err(Error::new)?; + + let rpc_semaphore = Arc::new(Semaphore::new(cfg.max_concurrent_rpcs as usize)); + + tracing::info!(port = cfg.port, server_name = cfg.name, "Serving"); + + Ok(accept_connections( + cfg, + endpoint, + connection_filter, + rpc_semaphore, + self, + )) + } +} + +mod sealed { + use super::*; + + pub trait ConnectionRouter: Clone + Send + Sync + 'static { + fn contains_api(&self, api_name: &ApiName) -> bool; + + /// Routes [`Connection`] to the [`Api`] connection handler. + fn route_connection( + &self, + api_name: &ApiName, + conn: Connection<'_>, + ) -> impl Future> + Send; + } +} + +#[derive_where(Clone)] +struct ApiServer { + connection_handler: H, +} + +#[derive(Clone)] +struct Multiplexer { + head: A, + tail: B, +} + +impl sealed::ConnectionRouter for ApiServer { + fn contains_api(&self, api_name: &ApiName) -> bool { + api_name == &H::Api::NAME + } + + async fn route_connection(&self, api_name: &ApiName, conn: Connection<'_>) -> Result<()> { + if !self.contains_api(api_name) { + return Err(ErrorInner::UnknownApi(*api_name).into()); + } + + self.connection_handler + .handle_connection(conn.specify_api()) + .await + } +} + +impl sealed::ConnectionRouter for Multiplexer +where + A: sealed::ConnectionRouter, + B: sealed::ConnectionRouter, +{ + fn contains_api(&self, api_name: &ApiName) -> bool { + self.head.contains_api(api_name) || self.tail.contains_api(api_name) + } + + async fn route_connection(&self, api_name: &ApiName, conn: Connection<'_>) -> Result<()> { + if self.head.contains_api(api_name) { + self.head.route_connection(api_name, conn).await + } else { + self.tail.route_connection(api_name, conn).await + } + } +} + +impl Server for R {} + +async fn accept_connections( + config: Config, + endpoint: quinn::Endpoint, + connection_filter: Filter, + rpc_semaphore: Arc, + router: R, +) { + while let Some(incoming) = endpoint.accept().await { + match connection_filter.try_acquire_permit(&incoming) { + Ok(permit) => match incoming.accept() { + Ok(connecting) => accept_connection( + config.connection_timeout, + config.name, + connecting, + permit, + rpc_semaphore.clone(), + router.clone(), + ), + + Err(err) => tracing::warn!(?err, "failed to accept incoming connection"), + }, + + Err(err) => { + if err == filter::RejectionReason::AddressNotValidated { + // Signal the client to retry with validated address. + let _ = incoming.retry(); + } else { + tracing::debug!( + server_name = config.name, + reason = err.as_str(), + remote_addr = ?incoming.remote_address().ip(), + "inbound connection dropped" + ); + + metrics::counter!( + "wcn_rpc_quic_server_connections_dropped", + EnumLabel<"reason", RejectionReason> => err, + StringLabel<"server_name"> => config.name + ) + .increment(1); + + // Calling `ignore()` instead of dropping avoids sending a response. + incoming.ignore(); + } + } + }; + } +} + +fn accept_connection( + timeout: Duration, + server_name: &'static str, + connecting: quinn::Connecting, + permit: filter::Permit, + rpc_semaphore: Arc, + router: R, +) { + use ConnectionStatusCode as Status; + + async move { + let (api_name, conn, mut tx) = async move { + let conn = connecting.await?; + let remote_peer_id = quic::connection_peer_id(&conn)?; + + let (mut tx, mut rx) = conn.accept_bi().await?; + + let protocol_version = rx.read_u32().await?; + + let api_name = match protocol_version { + super::PROTOCOL_VERSION => { + let mut buf = [0; 16]; + rx.read_exact(&mut buf).await?; + ServerName(buf) + } + ver => { + tx.write_i32(Status::UnsupportedProtocol as i32).await?; + return Err(ErrorInner::UnsupportedProtocolVersion(ver)); + } + }; + + let conn = ConnectionInner { + server_name, + _permit: permit, + remote_peer_id, + rpc_semaphore, + quic: conn, + }; + + Ok((api_name, conn, tx)) + } + .with_timeout(timeout) + .map_err(|_| ErrorInner::ConnectionTimeout) + .await? + .map_err(Error::new)?; + + let conn = Connection { + inner: &conn, + _marker: PhantomData, + }; + + let status = if router.contains_api(&api_name) { + ConnectionStatusCode::Ok + } else { + ConnectionStatusCode::UnknownApi + }; + tx.write_i32(status as i32).await.map_err(Error::new)?; + + router.route_connection(&api_name, conn).await + } + .map_err(|err| tracing::debug!(?err, "Inbound connection handler failed")) + .with_metrics(future_metrics!("wcn_rpc_server_inbound_connection")) + .pipe(tokio::spawn); +} + +/// Inbound connection. +#[derive(Clone, Copy)] +pub struct Connection<'a, API = ()> { + inner: &'a ConnectionInner, + _marker: PhantomData, +} + +impl<'a> Connection<'a> { + fn specify_api(self) -> Connection<'a, API> { + Connection { + inner: self.inner, + _marker: PhantomData, + } + } +} + +struct ConnectionInner { + server_name: &'static str, + + remote_peer_id: PeerId, + + _permit: filter::Permit, + rpc_semaphore: Arc, + + quic: quinn::Connection, +} + +impl<'a, API: Api> From<&'a mut ConnectionInner> for Connection<'a, API> { + fn from(inner: &'a mut ConnectionInner) -> Self { + Self { + inner, + _marker: PhantomData, + } + } +} + +/// Inbound RPC with yet undefined type. +pub struct InboundRpc { + id: API::RpcId, + stream: BiDirectionalStream, + permit: OwnedSemaphorePermit, +} + +impl InboundRpc { + /// Handles this RPC using the provided handler. + pub fn handle( + self, + handler: &impl HandleRpc, + ) -> impl Future> + Send + '_ { + async move { handler.handle_rpc(&mut self.upgrade()).await } + .with_metrics(future_metrics!("wcn_rpc_server_rpc")) + } + + /// Upgrades this untyped [`InboundRpc`] into a typed one. + /// + /// Caller is expected to ensure that the [`InboundRpc::id()`] is correct. + fn upgrade(self) -> Inbound { + if cfg!(debug_assertions) { + let id: u8 = self.id.into(); + assert_eq!(id, RPC::ID); + } + + let (recv, send) = self.stream.upgrade(); + + Inbound { + send: SinkExt::<&RPC::Response>::sink_map_err(send, |err: transport2::Error| { + Error::new(err) + }), + recv: recv.map_err(Error::new), + _permit: self.permit, + } + } +} + +impl InboundRpc { + /// Returns ID of this [`InboundRpc`]. + pub fn id(&self) -> API::RpcId { + self.id + } +} + +/// Inbound RPC of a specific type. +pub struct Inbound { + #[allow(clippy::type_complexity)] + recv: MapErr, fn(transport2::Error) -> Error>, + + #[allow(clippy::type_complexity)] + send: SinkMapErr, fn(transport2::Error) -> Error>, + + _permit: OwnedSemaphorePermit, +} + +impl Inbound { + /// Returns mutable references to the underlying request/response streams. + pub fn streams_mut( + &mut self, + ) -> ( + &mut impl Stream>, + &mut (impl for<'a, 'b> Sink<&'a BorrowedResponse<'b, RPC>, Error = Error> + 'static), + ) { + (&mut self.recv, &mut self.send) + } +} + +impl Connection<'_, API> { + /// Returns [`PeerId`] of the remote peer. + pub fn remote_peer_id(&self) -> &PeerId { + &self.inner.remote_peer_id + } + + /// Handles this [`Connection`] by handling all [`InboundRpc`] using the + /// provided `handler_fn`. + pub async fn handle( + &self, + rpc_handler: &H, + handler_fn: fn(InboundRpc, H) -> Fut, + ) -> Result<()> + where + H: Clone + Send + Sync + 'static, + Fut: Future> + Send + 'static, + { + loop { + let rpc = self.accept_rpc().await?; + let handler = rpc_handler.clone(); + + async move { handler_fn(rpc, handler).await }.spawn(); + } + } + + /// Handles the next [`InboundRpc`] using the provided `handler_fn`. + pub async fn handle_rpc(&self, f: impl FnOnce(InboundRpc) -> F) -> Result<()> + where + F: Future> + Send, + { + let rpc = self.accept_rpc().await?; + f(rpc).await + } + + /// Accepts the next [`InboundRpc`]. + async fn accept_rpc(&self) -> Result> { + loop { + let (tx, mut rx) = self.inner.quic.accept_bi().await.map_err(Error::new)?; + + let Some(permit) = self.acquire_stream_permit() else { + metrics::counter!( + "wcn_rpc_server_rpcs_dropped", + StringLabel<"server_name"> => self.inner.server_name + ) + .increment(1); + continue; + }; + + // when we receive a stream there's always at least some data in it + let id = rx + .read_u8() + .now_or_never() + .ok_or_else(|| ErrorInner::ReadRpcId)? + .map_err(Error::new)?; + + let id = id.try_into().map_err(|_| ErrorInner::UnknownRpcId(id))?; + + return Ok(InboundRpc { + id, + stream: BiDirectionalStream::new(tx, rx), + permit, + }); + } + } + + fn acquire_stream_permit(&self) -> Option { + metrics::gauge!("wcn_rpc_server_available_rpc_permits", StringLabel<"server_name"> => self.inner.server_name) + .set(self.inner.rpc_semaphore.available_permits() as f64); + + self.inner.rpc_semaphore.clone().try_acquire_owned().ok() + } +} + +/// RPC [`Server`] error. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct Error(ErrorInner); + +/// RPC [`Server`] result. +pub type Result = std::result::Result; + +impl Error { + fn new(err: impl Into) -> Self { + Self(err.into()) + } +} + +impl From for Error { + fn from(err: ErrorInner) -> Self { + Self::new(err) + } +} + +#[derive(Debug, thiserror::Error)] +enum ErrorInner { + #[error("Failed to generate TLS certificate: {0:?}")] + GenCertificate(#[from] libp2p_tls::certificate::GenError), + + #[error("quinn::rustls: {0}")] + Rustls(#[from] rustls::NoInitialCipherSuite), + + #[error("QUIC: {0}")] + Quic(#[from] quic::Error), + + #[error(transparent)] + ExtractPeerId(#[from] quic::ExtractPeerIdError), + + #[error("Connection: {0:?}")] + Connection(#[from] quinn::ConnectionError), + + #[error("Timeout establishing inbound connection")] + ConnectionTimeout, + + #[error("IO: {0:?}")] + Io(#[from] io::Error), + + #[error("Failed to read ConnectionHeader: {0:?}")] + ReadHeader(#[from] quinn::ReadExactError), + + #[error("Unsupported protocol version: {0}")] + UnsupportedProtocolVersion(u32), + + #[error("Unknown API: {0}")] + UnknownApi(rpc::ApiName), + + #[error("Transport: {0}")] + Transport(#[from] transport2::Error), + + #[error("Failed to read RPC ID without blocking")] + ReadRpcId, + + #[error("Unknown RPC ID: {0}")] + UnknownRpcId(u8), +} diff --git a/crates/rpc/src/transport2.rs b/crates/rpc/src/transport2.rs new file mode 100644 index 00000000..2fa4723e --- /dev/null +++ b/crates/rpc/src/transport2.rs @@ -0,0 +1,205 @@ +use { + crate::MessageV2, + bytes::{BufMut as _, Bytes, BytesMut}, + futures::{stream::MapErr, Sink, TryStreamExt}, + pin_project::pin_project, + serde::{Deserialize, Serialize}, + std::{ + io, + pin::Pin, + task::{self, ready}, + }, + tokio_serde::Framed, + tokio_stream::Stream, + tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}, +}; + +/// Serialization codec. +pub trait Codec: + for<'a> Serializer> + Serializer + Deserializer +{ +} + +impl Codec for C +where + M: MessageV2, + C: for<'a> Serializer> + Serializer + Deserializer, +{ +} + +pub trait Serializer: + tokio_serde::Serializer> + Unpin + Default + Send + Sync + 'static +{ +} + +impl Serializer for S where + S: tokio_serde::Serializer> + Unpin + Default + Send + Sync + 'static +{ +} + +pub trait Deserializer: + tokio_serde::Deserializer> + Unpin + Default + Send + Sync + 'static +{ +} + +impl Deserializer for D where + D: tokio_serde::Deserializer> + Unpin + Default + Send + Sync + 'static +{ +} + +#[derive(Clone, Copy, Debug, Default)] +pub struct PostcardCodec; + +impl tokio_serde::Deserializer for PostcardCodec +where + for<'a> T: Deserialize<'a>, +{ + type Error = io::Error; + + fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result { + postcard::from_bytes(src).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) + } +} + +impl tokio_serde::Serializer for PostcardCodec +where + T: Serialize, +{ + type Error = io::Error; + + fn serialize(self: Pin<&mut Self>, data: &T) -> Result { + postcard::experimental::serialized_size(data) + .and_then(|size| postcard::to_io(data, BytesMut::with_capacity(size).writer())) + .map(|writer| writer.into_inner().freeze()) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) + } +} + +/// Untyped bi-directional stream. +pub struct BiDirectionalStream { + pub(crate) rx: RawRecvStream, + pub(crate) tx: RawSendStream, +} + +type RawSendStream = FramedWrite; +type RawRecvStream = FramedRead; + +impl BiDirectionalStream { + pub fn new(tx: quinn::SendStream, rx: quinn::RecvStream) -> Self { + Self { + tx: FramedWrite::new(tx, LengthDelimitedCodec::new()), + rx: FramedRead::new(rx, LengthDelimitedCodec::new()), + } + } + + pub(crate) fn upgrade>( + self, + ) -> (RecvStream, SendStream) { + ( + RecvStream(Framed::new(self.rx.map_err(Into::into), C::default())), + SendStream { + inner: self.tx, + codec: C::default(), + }, + ) + } +} + +/// [`Stream`] of outbound [Message][`MessageV2`]s. +#[pin_project(project = SendStreamProj)] +pub struct SendStream { + #[pin] + inner: RawSendStream, + #[pin] + codec: C, +} + +impl Sink<&T> for SendStream +where + C: Serializer>, +{ + type Error = Error; + + fn poll_ready( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + self.project().inner.poll_ready(cx).map_err(Into::into) + } + + fn start_send(mut self: Pin<&mut Self>, item: &T) -> Result<(), Self::Error> { + let bytes = tokio_serde::Serializer::serialize(self.as_mut().project().codec, item) + .map_err(Into::into)?; + + self.as_mut().project().inner.start_send(bytes)?; + + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + self.project().inner.poll_flush(cx).map_err(Into::into) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + ready!(self.as_mut().project().inner.poll_flush(cx))?; + self.project().inner.poll_close(cx).map_err(Into::into) + } +} + +/// [`Stream`] of inbound [Message][`MessageV2`]s. +#[pin_project] +pub struct RecvStream>( + #[allow(clippy::type_complexity)] + #[pin] + Framed Error>, T, T, C>, +); + +impl> Stream for RecvStream { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + self.project().0.poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } +} + +#[derive(Clone, Debug, thiserror::Error, Eq, PartialEq)] +pub enum Error { + #[error("IO: {0:?}")] + IO(io::ErrorKind), + + #[error("Stream unexpectedly finished")] + StreamFinished, + + #[error("Codec: {_0}")] + Codec(String), + + #[error("{_0}")] + Other(String), +} + +impl From for Error { + fn from(err: io::Error) -> Self { + Self::IO(err.kind()) + } +} + +impl From for Error { + fn from(err: quinn::ConnectionError) -> Self { + Self::Other(format!("Connection: {err:?}")) + } +} + +pub type Result = std::result::Result; diff --git a/crates/storage_api2/Cargo.toml b/crates/storage_api2/Cargo.toml index 4b687423..116a43bf 100644 --- a/crates/storage_api2/Cargo.toml +++ b/crates/storage_api2/Cargo.toml @@ -8,12 +8,16 @@ publish = false workspace = true [features] -client = ["wcn_rpc/client", "arc-swap"] -server = ["wcn_rpc/server"] +default = ["rpc_client", "rpc_server"] +rpc_client = ["wcn_rpc/client"] +rpc_server = ["wcn_rpc/server"] [dependencies] -wc = { workspace = true, features = ["future"] } +wc = { workspace = true, features = ["future", "metrics"] } auth = { workspace = true } +derive_more = { workspace = true, features = ["from", "try_into"] } +strum = { workspace = true , features = ["derive"] } +tap = { workspace = true } wcn_rpc = { workspace = true } serde = "1" @@ -21,6 +25,9 @@ thiserror = "1" tracing = "0.1" futures = "0.3" time = "0.3" +const-hex = "1.14" -# client -arc-swap = { version = "1.7", optional = true } +[dev-dependencies] +tokio = { version = "1", default-features = false } +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +rand = "0.8" diff --git a/crates/storage_api2/src/client.rs b/crates/storage_api2/src/client.rs deleted file mode 100644 index 38f08b9f..00000000 --- a/crates/storage_api2/src/client.rs +++ /dev/null @@ -1,458 +0,0 @@ -use { - super::*, - arc_swap::ArcSwap, - futures::SinkExt, - std::{ - collections::HashSet, - future::Future, - result::Result as StdResult, - sync::Arc, - time::Duration, - }, - wcn_rpc::{ - client::middleware::{ - self, - MeteredExt, - Timeouts, - WithRetries, - WithRetriesExt, - WithTimeouts, - WithTimeoutsExt as _, - }, - identity::Keypair, - middleware::Metered, - transport::{self, PendingConnection}, - PeerAddr, - }, -}; - -/// Storage API client. -#[derive(Clone)] -pub struct Client { - rpc: RpcClient, -} - -type RpcClient = - WithRetries>>, RetryStrategy>; - -/// Storage API access token. -pub type AccessToken = Arc>; - -/// [`Client`] config. -#[derive(Clone, Debug)] -pub struct Config { - /// [`Keypair`] of the [`Client`]. - pub keypair: Keypair, - - /// Timeout of establishing a network connection. - pub connection_timeout: Duration, - - /// Timeout of a [`Client`] operation. - pub operation_timeout: Duration, - - /// Storage API access token. - pub access_token: AccessToken, - - /// Maximum number of attempts to try before failing an operation. - pub max_attempts: usize, - - /// Additional label to be used for all metrics of the [`Server`]. - pub metrics_tag: &'static str, -} - -impl Config { - pub fn new(access_token: AccessToken) -> Self { - Self { - keypair: Keypair::generate_ed25519(), - connection_timeout: Duration::from_secs(5), - operation_timeout: Duration::from_secs(10), - access_token, - max_attempts: 3, - metrics_tag: "default", - } - } - - /// Overwrites [`Config::keypair`]. - pub fn with_keypair(mut self, keypair: Keypair) -> Self { - self.keypair = keypair; - self - } - - /// Overwrites [`Config::connection_timeout`]. - pub fn with_connection_timeout(mut self, timeout: Duration) -> Self { - self.connection_timeout = timeout; - self - } - - /// Overwrites [`Config::operation_timeout`]. - pub fn with_operation_timeout(mut self, timeout: Duration) -> Self { - self.operation_timeout = timeout; - self - } - - pub fn with_max_attempts(mut self, max_attempts: usize) -> Self { - self.max_attempts = max_attempts; - self - } - - /// Overwrites [`Config::metrics_tag`]. - pub fn with_metrics_tag(mut self, tag: &'static str) -> Self { - self.metrics_tag = tag; - self - } -} - -impl Client { - /// Creates a new [`Client`]. - pub fn new(config: Config) -> StdResult { - let handshake = Handshake { - access_token: config.access_token, - }; - - let rpc_client_config = wcn_rpc::client::Config { - keypair: config.keypair, - known_peers: HashSet::new(), - handshake, - connection_timeout: config.connection_timeout, - server_name: crate::RPC_SERVER_NAME, - priority: transport::Priority::High, - }; - - let timeouts = Timeouts::new().with_default(config.operation_timeout); - - let rpc_client = wcn_rpc::quic::Client::new(rpc_client_config) - .map_err(|err| CreationError(err.to_string()))? - .with_timeouts(timeouts) - .metered_with_tag(config.metrics_tag) - .with_retries(RetryStrategy::new(config.max_attempts)); - - Ok(Self { rpc: rpc_client }) - } - - pub fn remote_storage<'a>(&'a self, server_addr: &'a PeerAddr) -> RemoteStorage<'a> { - RemoteStorage { - client: self, - server_addr, - expected_keyspace_version: None, - } - } -} - -#[derive(Clone, Debug)] -struct RetryStrategy { - max_attempts: usize, -} - -impl RetryStrategy { - fn new(max_attempts: usize) -> Self { - Self { max_attempts } - } -} - -impl middleware::RetryStrategy for RetryStrategy { - fn requires_retry( - &self, - _rpc_id: wcn_rpc::Id, - error: &wcn_rpc::client::Error, - attempt: usize, - ) -> Option { - use crate::error_code; - - if attempt >= self.max_attempts { - return None; - } - - let rpc_error = match error { - wcn_rpc::client::Error::Transport(_) => return Some(Duration::from_millis(50)), - wcn_rpc::client::Error::Rpc { error, .. } => error, - }; - - Some(match rpc_error.code.as_ref() { - // These errors are non-retryable - wcn_rpc::error_code::THROTTLED - | error_code::INVALID_KEY - | error_code::KEYSPACE_VERSION_MISMATCH - | error_code::UNAUTHORIZED => return None, - - // On the first attempt retry immediately. - _ if attempt == 1 => Duration::ZERO, - - _ => Duration::from_millis(100), - }) - } -} - -/// Handle to a remote Storage API (Server). -#[derive(Clone, Copy)] -pub struct RemoteStorage<'a> { - client: &'a Client, - server_addr: &'a PeerAddr, - expected_keyspace_version: Option, -} - -impl RemoteStorage<'_> { - fn extended_key(&self, key: Key) -> ExtendedKey { - ExtendedKey { - inner: key.0, - keyspace_version: self.expected_keyspace_version, - } - } - - fn rpc_client(&self) -> &RpcClient { - &self.client.rpc - } - - /// Specifies the expected version of the keyspace of the [`RemoteStorage`]. - pub fn expecting_keyspace_version(mut self, version: u64) -> Self { - self.expected_keyspace_version = Some(version); - self - } - - /// Gets a [`Record`] by the provided [`Key`]. - pub async fn get(self, key: Key) -> Result> { - Get::send(self.rpc_client(), self.server_addr, &GetRequest { - key: self.extended_key(key), - }) - .await - .map(|opt| opt.map(|resp| Record::new(resp.value, resp.expiration, resp.version))) - .map_err(Into::into) - } - - /// Sets the provided [`Entry`] only if the version of the existing - /// [`Entry`] is < than the new one. - pub async fn set(self, entry: Entry) -> Result<()> { - Set::send(self.rpc_client(), self.server_addr, &SetRequest { - key: self.extended_key(entry.key), - value: entry.value, - expiration: entry.expiration.timestamp(), - version: entry.version.timestamp(), - }) - .await - .map_err(Into::into) - } - - /// Deletes an [`Entry`] by the provided [`Key`] only if the version of the - /// [`Entry`] is < than the provided `version`. - pub async fn del(self, key: Key, version: EntryVersion) -> Result<()> { - Del::send(self.rpc_client(), self.server_addr, &DelRequest { - key: self.extended_key(key), - version: version.timestamp(), - }) - .await - .map_err(Into::into) - } - - /// Gets an [`EntryExpiration`] by the provided [`Key`]. - pub async fn get_exp(self, key: Key) -> Result> { - GetExp::send(self.rpc_client(), self.server_addr, &GetExpRequest { - key: self.extended_key(key), - }) - .await - .map(|opt| opt.map(|resp| EntryExpiration::from(resp.expiration))) - .map_err(Into::into) - } - - /// Sets [`Expiration`] on the [`Entry`] with the provided [`Key`] only if - /// the version of the [`Entry`] is < than the provided `version`. - pub async fn set_exp( - self, - key: Key, - expiration: impl Into, - version: EntryVersion, - ) -> Result<()> { - SetExp::send(self.rpc_client(), self.server_addr, &SetExpRequest { - key: self.extended_key(key), - expiration: expiration.into().timestamp(), - version: version.timestamp(), - }) - .await - .map_err(Into::into) - } - - /// Gets a map [`Record`] by the provided [`Key`] and [`Field`]. - pub async fn hget(self, key: Key, field: Field) -> Result> { - HGet::send(self.rpc_client(), self.server_addr, &HGetRequest { - key: self.extended_key(key), - field, - }) - .await - .map(|opt| opt.map(|resp| Record::new(resp.value, resp.expiration, resp.version))) - .map_err(Into::into) - } - - /// Sets the provided [`MapEntry`] only if the version of the existing - /// [`MapEntry`] is < than the new one. - pub async fn hset(self, entry: MapEntry) -> Result<()> { - HSet::send(self.rpc_client(), self.server_addr, &HSetRequest { - key: self.extended_key(entry.key), - field: entry.field, - value: entry.value, - expiration: entry.expiration.timestamp(), - version: entry.version.timestamp(), - }) - .await - .map_err(Into::into) - } - - /// Deletes a [`MapEntry`] by the provided [`Key`] only if the version of - /// the [`MapEntry`] is < than the provided `version`. - pub async fn hdel(self, key: Key, field: Field, version: EntryVersion) -> Result<()> { - HDel::send(self.rpc_client(), self.server_addr, &HDelRequest { - key: self.extended_key(key), - field, - version: version.timestamp(), - }) - .await - .map_err(Into::into) - } - - /// Gets an [`EntryExpiration`] by the provided [`Key`] and [`Field`]. - pub async fn hget_exp(self, key: Key, field: Field) -> Result> { - HGetExp::send(self.rpc_client(), self.server_addr, &HGetExpRequest { - key: self.extended_key(key), - field, - }) - .await - .map(|opt| opt.map(|resp| EntryExpiration::from(resp.expiration))) - .map_err(Into::into) - } - - /// Sets [`Expiration`] on the [`MapEntry`] with the provided [`Key`] and - /// [`Field`] only if the version of the [`MapEntry`] is < than the - /// provided `version`. - pub async fn hset_exp( - self, - key: Key, - field: Field, - expiration: impl Into, - version: EntryVersion, - ) -> Result<()> { - HSetExp::send(self.rpc_client(), self.server_addr, &HSetExpRequest { - key: self.extended_key(key), - field, - expiration: expiration.into().timestamp(), - version: version.timestamp(), - }) - .await - .map_err(Into::into) - } - - /// Returns cardinality of the map with the provided [`Key`]. - pub async fn hcard(self, key: Key) -> Result { - HCard::send(self.rpc_client(), self.server_addr, &HCardRequest { - key: self.extended_key(key), - }) - .await - .map(|resp| resp.cardinality) - .map_err(Into::into) - } - - /// Returns a [`MapPage`] by iterating over the [`Field`]s of the map with - /// the provided [`Key`]. - pub async fn hscan(self, key: Key, count: u32, cursor: Option) -> Result { - let resp = HScan::send(self.rpc_client(), self.server_addr, &HScanRequest { - key: self.extended_key(key), - count, - cursor, - }) - .await - .map_err(Error::from)?; - - Ok(MapPage { - has_next: resp.records.len() >= count as usize, - records: resp - .records - .into_iter() - .map(|record| MapRecord { - field: record.field, - value: record.value, - expiration: EntryExpiration::from(record.expiration), - version: EntryVersion::from(record.version), - }) - .collect(), - }) - } -} - -/// Error of [`Client::new`]. -#[derive(Clone, Debug, thiserror::Error)] -#[error("{_0}")] -pub struct CreationError(String); - -/// Error of a [`Client`] operation. -#[derive(Clone, Debug, PartialEq, Eq, Hash, thiserror::Error)] -pub enum Error { - /// Transport errort. - #[error("Transport: {_0}")] - Transport(String), - - /// Operation timed out. - #[error("Timeout")] - Timeout, - - /// Server is throttling. - #[error("Throttled")] - Throttled, - - /// Client is not authorized to perform the operation. - #[error("Unauthorized")] - Unauthorized, - - /// Keyspace versions of client and server don't match. - #[error("Keyspace version mismatch")] - KeyspaceVersionMismatch, - - /// Other client/server error. - #[error("{_0}")] - Other(String), -} - -impl From for Error { - fn from(err: wcn_rpc::client::Error) -> Self { - let rpc_err = match err { - wcn_rpc::client::Error::Transport(err) => return Self::Transport(err.to_string()), - wcn_rpc::client::Error::Rpc { error, .. } => error, - }; - - match rpc_err.code.as_ref() { - wcn_rpc::error_code::TIMEOUT => Self::Timeout, - crate::error_code::KEYSPACE_VERSION_MISMATCH => Self::KeyspaceVersionMismatch, - crate::error_code::UNAUTHORIZED => Self::Unauthorized, - _ => Self::Other(format!("{rpc_err:?}")), - } - } -} - -/// [`Client`] operation [`Result`]. -pub type Result = std::result::Result; - -/// Client part of the [`network::Handshake`]. -#[derive(Clone)] -struct Handshake { - access_token: AccessToken, -} - -impl transport::Handshake for Handshake { - type Ok = (); - type Err = HandshakeError; - - fn handle( - &self, - _peer_id: PeerId, - conn: PendingConnection, - ) -> impl Future> + Send { - async move { - let (mut rx, mut tx) = conn - .initiate_handshake::() - .await?; - - let req = HandshakeRequest { - access_token: self.access_token.load().as_ref().to_owned(), - }; - - tx.send(req).await.map_err(HandshakeError::Transport)?; - - rx.recv_message().await?.map_err(Into::into) - } - } -} diff --git a/crates/storage_api2/src/lib.rs b/crates/storage_api2/src/lib.rs index 76d34778..8bad3741 100644 --- a/crates/storage_api2/src/lib.rs +++ b/crates/storage_api2/src/lib.rs @@ -5,227 +5,247 @@ pub use { wcn_rpc::{identity, Multiaddr, PeerAddr, PeerId}, }; use { + futures::FutureExt as _, serde::{Deserialize, Serialize}, - std::{io, time::Duration}, + std::{borrow::Cow, future::Future, str::FromStr, time::Duration}, time::OffsetDateTime as DateTime, - wcn_rpc::{self as rpc, transport}, }; -#[cfg(feature = "client")] -pub mod client; -#[cfg(feature = "client")] -pub use client::Client; -#[cfg(feature = "server")] -pub mod server; -#[cfg(feature = "server")] -pub use server::Server; - -const RPC_SERVER_NAME: rpc::ServerName = rpc::ServerName::new("storage_api"); - -/// RPC error codes produced by this module. -mod error_code { - /// Client is not authorized to perform the operation. - pub const UNAUTHORIZED: &str = "unauthorized"; - - /// Keyspace versions of the client and the server don't match. - pub const KEYSPACE_VERSION_MISMATCH: &str = "keyspace_version_mismatch"; - - /// Provided key was invalid. - pub const INVALID_KEY: &str = "invalid_key"; -} - -/// Key in a KV storage. -#[derive(Clone, Debug)] -pub struct Key(Vec); - -impl Key { - /// Length of a [`Key`] namespace (prefix). - pub const NAMESPACE_LEN: usize = auth::PUBLIC_KEY_LEN; - - const KIND_SHARED: u8 = 0; - const KIND_PRIVATE: u8 = 1; - - /// Creates a new shared [`Key`] using the global namespace. - pub fn shared(bytes: impl AsRef<[u8]>) -> Self { - Self::new(bytes, None) +pub mod operation; +pub use operation::{Operation, OperationRef}; + +#[cfg(any(feature = "rpc_client", feature = "rpc_server"))] +pub mod rpc; + +/// Namespace within a WCN cluster. +/// +/// Namespaces are isolated and every [`StorageApi`] [`Operation`] gets executed +/// on a specific [`Namespace`]. +#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct Namespace { + /// ID of the node operator to which this namespace belongs. + /// + /// Currentry an Ethereum address. + node_operator_id: [u8; 20], + + /// ID of this [`Namespace`] within the node operator scope. + id: u8, +} + +/// Version of the keyspace of a WCN cluster. +/// +/// Keyspace changes after data rebalancing within a WCN Cluster. +/// For data consistency reasons WCN Coordinators and WCN Replicas need to +/// operate using the same [`KeyspaceVersion`]. So for Coordinator->Replica +/// [`StorageApi`] calls [`Operation`]s need to include the expected +/// [`KeyspaceVersion`]. +/// +/// For Client->Coordinator and Replica->Database calls the [`KeyspaceVersion`] +/// validation is not required. +pub type KeyspaceVersion = u64; + +/// WCN Storage API. +/// +/// Lingua franka of the WCN network: +/// - Clients use it to execute storage operations on the network via sending +/// them to Replication Coordinators (WCN nodes hosting coordinator RPC +/// servers). +/// - Replication Coordinators use it to replicate storage operations across the +/// network via sending them to Replicas (WCN nodes hosting replica RPC +/// servers). +/// - Replicas use it to finally execute the operations on their local WCN +/// Database instances. +pub trait StorageApi: Clone + Send + Sync + 'static { + /// Executes the provided [`operation::Get`]. + fn get<'a>( + &'a self, + get: &'a operation::Get<'a>, + ) -> impl Future>>> + Send + 'a { + self.execute_ref(get) + .map(operation::Output::downcast_result) } - /// Creates a new private [`Key`] using the provided `namespace`. - pub fn private(namespace: &auth::PublicKey, bytes: impl AsRef<[u8]>) -> Self { - Self::new(bytes, Some(namespace)) + /// Executes the provided [`operation::Set`]. + fn set<'a>( + &'a self, + set: &'a operation::Set<'a>, + ) -> impl Future> + Send + 'a { + self.execute_ref(set) + .map(operation::Output::downcast_result) } - /// Returns namespace of this [`Key`]. - pub fn namespace(&self) -> Option<&[u8; Self::NAMESPACE_LEN]> { - match *self.0.first()? { - Self::KIND_PRIVATE => Some(self.0[1..][..Self::NAMESPACE_LEN].try_into().ok()?), - _ => None, - } + /// Executes the provided [`operation::Del`]. + fn del<'a>( + &'a self, + del: &'a operation::Del<'a>, + ) -> impl Future> + Send + 'a { + self.execute_ref(del) + .map(operation::Output::downcast_result) } - /// Returns the full byte representation of this [`Key`] (including - /// namespace). - pub fn as_bytes(&self) -> &[u8] { - &self.0 + /// Executes the provided [`operation::GetExp`]. + fn get_exp<'a>( + &'a self, + get_exp: &'a operation::GetExp<'a>, + ) -> impl Future>> + Send + 'a { + self.execute_ref(get_exp) + .map(operation::Output::downcast_result) } - /// Converts this [`Key`] into bytes. - pub fn into_bytes(self) -> Vec { - self.0 + /// Executes the provided [`operation::SetExp`]. + fn set_exp<'a>( + &'a self, + set_exp: &'a operation::SetExp<'a>, + ) -> impl Future> + Send + 'a { + self.execute_ref(set_exp) + .map(operation::Output::downcast_result) } - #[cfg(feature = "server")] - fn from_raw_bytes(bytes: Vec) -> Option { - match *bytes.first()? { - Self::KIND_SHARED => Some(Self(bytes)), - Self::KIND_PRIVATE if bytes.len() > Self::NAMESPACE_LEN + 1 => Some(Self(bytes)), - _ => None, - } + /// Executes the provided [`operation::HGet`]. + fn hget<'a>( + &'a self, + hget: &'a operation::HGet<'a>, + ) -> impl Future>>> + Send + 'a { + self.execute_ref(hget) + .map(operation::Output::downcast_result) } - fn new(bytes: impl AsRef<[u8]>, namespace: Option<&auth::PublicKey>) -> Self { - let bytes = bytes.as_ref(); - - let prefix_len = if namespace.is_some() { - Self::NAMESPACE_LEN - } else { - 0 - }; + /// Executes the provided [`operation::HSet`]. + fn hset<'a>( + &'a self, + hset: &'a operation::HSet<'a>, + ) -> impl Future> + Send + 'a { + self.execute_ref(hset) + .map(operation::Output::downcast_result) + } - let mut data = Vec::with_capacity(1 + prefix_len + bytes.len()); + /// Executes the provided [`operation::HDel`]. + fn hdel<'a>( + &'a self, + hdel: &'a operation::HDel<'a>, + ) -> impl Future> + Send + 'a { + self.execute_ref(hdel) + .map(operation::Output::downcast_result) + } - if let Some(namespace) = namespace { - data.push(Self::KIND_PRIVATE); - data.extend_from_slice(namespace.as_ref()); - } else { - data.push(Self::KIND_SHARED); - }; + /// Executes the provided [`operation::HGetExp`]. + fn hget_exp<'a>( + &'a self, + hget_exp: &'a operation::HGetExp<'a>, + ) -> impl Future>> + Send + 'a { + self.execute_ref(hget_exp) + .map(operation::Output::downcast_result) + } - data.extend_from_slice(bytes); - Self(data) + /// Executes the provided [`operation::HSetExp`]. + fn hset_exp<'a>( + &'a self, + hset_exp: &'a operation::HSetExp<'a>, + ) -> impl Future> + Send + 'a { + self.execute_ref(hset_exp) + .map(operation::Output::downcast_result) } -} -/// Value in a KV storage. -pub type Value = Vec; + /// Executes the provided [`operation::HCard`]. + fn hcard<'a>( + &'a self, + hcard: &'a operation::HCard<'a>, + ) -> impl Future> + Send + 'a { + self.execute_ref(hcard) + .map(operation::Output::downcast_result) + } -/// Subkey of a [`MapEntry`]. -pub type Field = Vec; + /// Executes the provided [`operation::HScan`]. + fn hscan<'a>( + &'a self, + hscan: &'a operation::HScan<'a>, + ) -> impl Future>> + Send + 'a { + self.execute_ref(hscan) + .map(operation::Output::downcast_result) + } -/// Basic KV storage entry. -#[derive(Clone, Debug)] -pub struct Entry { - /// [`Key`] of this [`Entry`]. - pub key: Key, + /// Executes the provided [`StorageApi`] [`OperationRef`]. + fn execute_ref<'a>( + &'a self, + operation: impl Into> + Send + 'a, + ) -> impl Future>> + Send + 'a { + self.execute(operation.into().to_owned()) + } - /// [`Value`] of this [`Entry`]. - pub value: Value, + /// Executes the provided [`StorageApi`] [`Operation`]. + fn execute<'a>( + &'a self, + operation: Operation<'a>, + ) -> impl Future>> + Send + 'a; +} + +/// Raw bytes. +pub type Bytes<'a> = Cow<'a, [u8]>; + +/// Raw [`Bytes`] value with metadata related to this value. +/// +/// [`Record`]s are being deleted from WCN Database after specified +/// [`Record::expiration`] time. +/// +/// A [`Record`] can not be overwritten by another [`Record`] with a lesser +/// [version][`Record::version`]. +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +pub struct Record<'a> { + /// Value of this [`Record`]. + pub value: Bytes<'a>, - /// Expiration time of this [`Entry`]. - pub expiration: EntryExpiration, + /// Expiration time of [`Record`]. + pub expiration: RecordExpiration, - /// Version of this [`Entry`]. - pub version: EntryVersion, + /// Version of this [`Record`]. + pub version: RecordVersion, } -impl Entry { - /// Creates a new [`Entry`]. - pub fn new( - key: impl Into, - value: impl Into, - expiration: impl Into, - ) -> Self { - Self { - key: key.into(), - value: value.into(), - expiration: expiration.into(), - version: EntryVersion::new(), +impl Record<'_> { + /// Converts `Self` into 'static. + pub fn into_static(self) -> Record<'static> { + Record { + value: Cow::Owned(self.value.into_owned()), + expiration: self.expiration, + version: self.version, } } } -/// Map entry in which each [`Value`] is associated with both [`Key`] and subkey -/// ([`Field`]). -#[derive(Clone, Debug)] -pub struct MapEntry { - /// [`Key`] of this [`Entry`]. - pub key: Key, - - /// [`Field`] of this [`Entry`]. - pub field: Field, - - /// [`Value`] of this [`Entry`]. - pub value: Value, - - /// Expiration time of this [`Entry`]. - pub expiration: EntryExpiration, - - /// Version of this [`Entry`]. - pub version: EntryVersion, -} - -impl MapEntry { - /// Creates a new [`MapEntry`]. - pub fn new( - key: impl Into, - field: impl Into, - value: impl Into, - expiration: impl Into, - ) -> Self { - Self { - key: key.into(), - field: field.into(), - value: value.into(), - expiration: expiration.into(), - version: EntryVersion::new(), +/// Entry within a Map. +/// +/// Maps are a separate data type of WCN Database, similar to Redis Hashes. +/// They differ from regular KV pairs by having a subkey (AKA +/// [field][MapEntry::field]). +/// +/// Each Map key contains an ordered set of entries (ascending order by +/// [`MapEntry::field`]). +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct MapEntry<'a> { + /// Subkey of this [`MapEntry`]. + pub field: Bytes<'a>, + + /// [`Record`] of this [`MapEntry`]. + pub record: Record<'a>, +} + +impl MapEntry<'_> { + /// Converts `Self` into 'static. + pub fn into_static(self) -> MapEntry<'static> { + MapEntry { + field: Cow::Owned(self.field.into_owned()), + record: self.record.into_static(), } } } -/// [`Entry`]/[`MapEntry`] without the associated [`Key`]/[`Field`]. -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct Record { - /// Value of this [`Record`]. - pub value: Value, - - /// Expiration time of the associated [`Entry`]/[`MapEntry`]. - pub expiration: EntryExpiration, - - /// Version of the associated [`Entry`]/[`MapEntry`]. - pub version: EntryVersion, -} - -/// [`MapEntry`] without the associated [`Key`]. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub struct MapRecord { - /// Field of this [`MapRecord`]. - pub field: Field, - - /// Value of this [`MapRecord`]. - pub value: Value, - - /// Expiration time of the associated [`MapEntry`]. - pub expiration: EntryExpiration, - - /// Version of the associated [`MapEntry`]. - pub version: EntryVersion, -} - -/// [`Entry`]/[`MapEntry`] expiration time. -#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] -pub struct EntryExpiration { +/// Expiration time of a [`Record`]. +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)] +pub struct RecordExpiration { unix_timestamp_secs: u64, } -impl From for EntryExpiration { - fn from(timestamp: UnixTimestampSecs) -> Self { - Self { - unix_timestamp_secs: timestamp.0, - } - } -} - -impl From for EntryExpiration { +impl From for RecordExpiration { fn from(dur: Duration) -> Self { Self { unix_timestamp_secs: (DateTime::now_utc() + dur).unix_timestamp() as u64, @@ -233,7 +253,7 @@ impl From for EntryExpiration { } } -impl From for EntryExpiration { +impl From for RecordExpiration { fn from(dt: DateTime) -> Self { Self { unix_timestamp_secs: dt.unix_timestamp() as u64, @@ -241,48 +261,29 @@ impl From for EntryExpiration { } } -impl EntryExpiration { - pub fn from_unix_timestamp_secs(timestamp: u64) -> Self { - Self { - unix_timestamp_secs: timestamp, - } - } - - pub fn unix_timestamp_secs(&self) -> u64 { - self.unix_timestamp_secs - } - - pub fn to_duration(&self) -> Duration { - let expiry = DateTime::from_unix_timestamp(self.unix_timestamp_secs as i64) +impl From for Duration { + fn from(exp: RecordExpiration) -> Self { + let expiry = DateTime::from_unix_timestamp(exp.unix_timestamp_secs as i64) .unwrap_or(DateTime::UNIX_EPOCH); (expiry - DateTime::now_utc()) .try_into() .unwrap_or_default() } - - fn timestamp(&self) -> UnixTimestampSecs { - UnixTimestampSecs(self.unix_timestamp_secs) - } -} - -/// [`Entry`]/[`MapEntry`] version. -#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] -pub struct EntryVersion { - unix_timestamp_micros: u64, } -impl From for EntryVersion { - fn from(timestamp: UnixTimestampMicros) -> Self { - Self { - unix_timestamp_micros: timestamp.0, - } - } +/// Version of a [`Record`]. +/// +/// [`RecordVersion`] is a local client-side generated timestamp. +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)] +pub struct RecordVersion { + /// UNIX timestamp (microseconds) representation of this [`RecordVersion`]. + pub unix_timestamp_micros: u64, } -impl EntryVersion { - #[allow(clippy::new_without_default)] - pub fn new() -> EntryVersion { +impl RecordVersion { + /// Generates a new [`RecordVersion`] using the current timestamp. + pub fn now() -> RecordVersion { Self { unix_timestamp_micros: std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) @@ -290,241 +291,135 @@ impl EntryVersion { .as_micros() as u64, } } - - pub fn from_unix_timestamp_micros(timestamp: u64) -> Self { - Self { - unix_timestamp_micros: timestamp, - } - } - - pub fn unix_timestamp_micros(&self) -> u64 { - self.unix_timestamp_micros - } - - fn timestamp(&self) -> UnixTimestampMicros { - UnixTimestampMicros(self.unix_timestamp_micros) - } } -/// Page of [`MapRecord`]s. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct MapPage { - /// [`MapRecords`] of this [`Page`]. - pub records: Vec, +/// Page of [map entries][MapEntry]. +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct MapPage<'a> { + /// [`MapRecord`]s of this [`Page`]. + pub entries: Vec>, /// Indicator of whether there's a next [`Page`] or not. pub has_next: bool, } -impl MapPage { +impl<'a> MapPage<'a> { /// Returns cursor pointing to the next [`Page`] if there is one. - pub fn next_page_cursor(&self) -> Option<&Field> { + pub fn next_page_cursor(&self) -> Option<&Bytes<'a>> { self.has_next - .then(|| self.records.last().map(|entry| &entry.field)) + .then(|| self.entries.last().map(|entry| &entry.field)) .flatten() } -} -#[cfg(feature = "client")] -impl Record { - fn new(value: Value, expiration: UnixTimestampSecs, version: UnixTimestampMicros) -> Self { - Self { - value, - expiration: expiration.into(), - version: version.into(), + /// Converts `Self` into 'static. + pub fn into_static(self) -> MapPage<'static> { + MapPage { + entries: self + .entries + .into_iter() + .map(MapEntry::into_static) + .collect(), + has_next: self.has_next, } } } -#[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct UnixTimestampSecs(u64); - -#[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct UnixTimestampMicros(u64); - -#[derive(Clone, Debug, Serialize, Deserialize)] -struct ExtendedKey { - inner: Vec, - keyspace_version: Option, -} - -type Get = rpc::Unary<{ rpc::id(b"get") }, GetRequest, Option>; - -#[derive(Clone, Debug, Serialize, Deserialize)] -struct GetRequest { - key: ExtendedKey, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -struct GetResponse { - value: Value, - expiration: UnixTimestampSecs, - version: UnixTimestampMicros, -} - -type Set = rpc::Unary<{ rpc::id(b"set") }, SetRequest, ()>; - -#[derive(Clone, Debug, Serialize, Deserialize)] -struct SetRequest { - key: ExtendedKey, - value: Value, - expiration: UnixTimestampSecs, - version: UnixTimestampMicros, -} - -type Del = rpc::Unary<{ rpc::id(b"del") }, DelRequest, ()>; - -#[derive(Clone, Debug, Serialize, Deserialize)] -struct DelRequest { - key: ExtendedKey, - version: UnixTimestampMicros, -} - -type GetExp = rpc::Unary<{ rpc::id(b"get_exp") }, GetExpRequest, Option>; - -#[derive(Clone, Debug, Serialize, Deserialize)] -struct GetExpRequest { - key: ExtendedKey, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -struct GetExpResponse { - expiration: UnixTimestampSecs, -} - -type SetExp = rpc::Unary<{ rpc::id(b"set_exp") }, SetExpRequest, ()>; - -#[derive(Clone, Debug, Serialize, Deserialize)] -struct SetExpRequest { - key: ExtendedKey, - expiration: UnixTimestampSecs, - version: UnixTimestampMicros, -} - -type HGet = rpc::Unary<{ rpc::id(b"hget") }, HGetRequest, Option>; - -#[derive(Clone, Debug, Serialize, Deserialize)] -struct HGetRequest { - key: ExtendedKey, - field: Field, -} +impl FromStr for Namespace { + type Err = InvalidNamespaceError; -#[derive(Clone, Debug, Serialize, Deserialize)] -struct HGetResponse { - value: Value, - expiration: UnixTimestampSecs, - version: UnixTimestampMicros, -} - -type HSet = rpc::Unary<{ rpc::id(b"hset") }, HSetRequest, ()>; + fn from_str(s: &str) -> Result { + use InvalidNamespaceError as Error; -#[derive(Clone, Debug, Serialize, Deserialize)] -struct HSetRequest { - key: ExtendedKey, - field: Field, - value: Value, - expiration: UnixTimestampSecs, - version: UnixTimestampMicros, -} + let mut parts = s.split('/'); -type HDel = rpc::Unary<{ rpc::id(b"hdel") }, HDelRequest, ()>; + let (operator_id, id) = (|| Some((parts.next()?, parts.next()?)))() + .ok_or(Error("Not enough components".into()))?; -#[derive(Clone, Debug, Serialize, Deserialize)] -struct HDelRequest { - key: ExtendedKey, - field: Field, - version: UnixTimestampMicros, -} - -type HGetExp = rpc::Unary<{ rpc::id(b"hget_exp") }, HGetExpRequest, Option>; - -#[derive(Clone, Debug, Serialize, Deserialize)] -struct HGetExpRequest { - key: ExtendedKey, - field: Field, -} + if parts.next().is_some() { + return Err(Error("Too many components".into())); + } -#[derive(Clone, Debug, Serialize, Deserialize)] -struct HGetExpResponse { - expiration: UnixTimestampSecs, + Ok(Self { + node_operator_id: const_hex::decode_to_array(operator_id) + .map_err(|err| Error(format!("Invalid node operator id: {err}").into()))?, + id: id + .parse() + .map_err(|err| Error(format!("Invalid id: {err}").into()))?, + }) + } } -type HSetExp = rpc::Unary<{ rpc::id(b"hset_exp") }, HSetExpRequest, ()>; - -#[derive(Clone, Debug, Serialize, Deserialize)] -struct HSetExpRequest { - key: ExtendedKey, - field: Field, - expiration: UnixTimestampSecs, - version: UnixTimestampMicros, -} +/// Error of parsing [`Namespace`] from a string. +#[derive(Debug, thiserror::Error)] +#[error("Invalid namespace: {_0}")] +pub struct InvalidNamespaceError(Cow<'static, str>); -type HCard = rpc::Unary<{ rpc::id(b"hcard") }, HCardRequest, HCardResponse>; +/// [`StorageApi`] result. +pub type Result = std::result::Result; -#[derive(Clone, Debug, Serialize, Deserialize)] -struct HCardRequest { - key: ExtendedKey, +/// [`StorageApi`] error. +#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] +#[error("{kind:?}({details:?})")] +pub struct Error { + kind: ErrorKind, + details: Option, } -#[derive(Clone, Debug, Serialize, Deserialize)] -struct HCardResponse { - cardinality: u64, +impl Error { + /// Returns [`ErrorKind`] of this [`Error`]. + pub fn kind(&self) -> ErrorKind { + self.kind + } } -type HScan = rpc::Unary<{ rpc::id(b"hscan") }, HScanRequest, HScanResponse>; +/// [`Error`] kind. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ErrorKind { + /// Client is not authorized to perfrom an [`Operation`]. + Unauthorized, -#[derive(Clone, Debug, Serialize, Deserialize)] -struct HScanRequest { - key: ExtendedKey, - count: u32, - cursor: Option, -} + /// [`KeyspaceVersion`] mismatch. + KeyspaceVersionMismatch, -#[derive(Clone, Debug, Serialize, Deserialize)] -struct HScanResponse { - records: Vec, - has_more: bool, -} + /// [`Operation`] timeout. + Timeout, -#[derive(Clone, Debug, Serialize, Deserialize)] -struct HScanResponseRecord { - field: Field, - value: Value, - expiration: UnixTimestampSecs, - version: UnixTimestampMicros, -} + /// Internal error. + Internal, -#[derive(Debug, Serialize, Deserialize)] -struct HandshakeRequest { - access_token: auth::Token, -} + /// Transport error. + Transport, -#[derive(Clone, Debug, Serialize, Deserialize)] -enum HandshakeErrorResponse { - InvalidToken(String), + /// Unable to determine [`ErrorKind`] of an [`Error`]. + Unknown, } -type HandshakeResponse = Result<(), HandshakeErrorResponse>; - -#[derive(Clone, Debug, thiserror::Error)] -pub enum HandshakeError { - #[error(transparent)] - Transport(#[from] transport::Error), - - #[error("Invalid token: {_0}")] - InvalidToken(String), +impl Error { + /// Creates a new [`Error`]. + fn new(kind: ErrorKind, details: Option) -> Self { + Self { kind, details } + } } -impl From for HandshakeError { - fn from(err: HandshakeErrorResponse) -> Self { - match err { - HandshakeErrorResponse::InvalidToken(err) => Self::InvalidToken(err), +impl From for Error { + fn from(kind: ErrorKind) -> Self { + Self { + kind, + details: None, } } } -impl From for HandshakeError { - fn from(err: io::Error) -> Self { - Self::Transport(err.into()) +#[cfg(test)] +#[test] +fn test_namespace_from_str() { + fn ns(s: &str) -> Result { + s.parse() } + + assert!(ns("0x14Cb1e6fb683A83455cA283e10f4959740A49ed7/0").is_ok()); + assert!(ns("14Cb1e6fb683A83455cA283e10f4959740A49ed7/0").is_ok()); + assert!(ns("14Cb1e6fb683A83455cA283e10f4959740A49ed7/255").is_ok()); + assert!(ns("14Cb1e6fb683A83455cA283e10f4959740A49ed7/256").is_err()); + assert!(ns("4Cb1e6fb683A83455cA283e10f4959740A49ed7/1").is_err()); } diff --git a/crates/storage_api2/src/operation.rs b/crates/storage_api2/src/operation.rs new file mode 100644 index 00000000..3abf898b --- /dev/null +++ b/crates/storage_api2/src/operation.rs @@ -0,0 +1,303 @@ +//! Storage API operations. + +use { + crate::{ + Bytes, + Error, + ErrorKind, + KeyspaceVersion, + MapEntry, + MapPage, + Namespace, + Record, + RecordExpiration, + RecordVersion, + Result, + }, + derive_more::derive::{From, TryInto}, + serde::{Deserialize, Serialize}, + strum::{EnumDiscriminants, IntoDiscriminant}, + tap::TapFallible as _, + wc::metrics::{self, enum_ordinalize::Ordinalize}, +}; + +/// Sum type of all Storage API operations. +#[derive(Clone, Debug, From, PartialEq, Eq)] +pub enum Operation<'a> { + Get(Get<'a>), + Set(Set<'a>), + Del(Del<'a>), + GetExp(GetExp<'a>), + SetExp(SetExp<'a>), + + HGet(HGet<'a>), + HSet(HSet<'a>), + HDel(HDel<'a>), + HGetExp(HGetExp<'a>), + HSetExp(HSetExp<'a>), + HCard(HCard<'a>), + HScan(HScan<'a>), +} + +impl<'a> Operation<'a> { + /// Converts &self to [`OperationRef`]. + /// + /// Reference to reference conversion, does not re-allocate. + pub fn to_ref(&'a self) -> OperationRef<'a> { + match self { + Self::Get(op) => OperationRef::Get(op), + Self::Set(op) => OperationRef::Set(op), + Self::Del(op) => OperationRef::Del(op), + Self::GetExp(op) => OperationRef::GetExp(op), + Self::SetExp(op) => OperationRef::SetExp(op), + Self::HGet(op) => OperationRef::HGet(op), + Self::HSet(op) => OperationRef::HSet(op), + Self::HDel(op) => OperationRef::HDel(op), + Self::HGetExp(op) => OperationRef::HGetExp(op), + Self::HSetExp(op) => OperationRef::HSetExp(op), + Self::HCard(op) => OperationRef::HCard(op), + Self::HScan(op) => OperationRef::HScan(op), + } + } +} + +/// Sum type of references to all Storage API operations. +#[derive(Clone, Debug, From, EnumDiscriminants, PartialEq, Eq)] +#[strum_discriminants(name(Name))] +#[strum_discriminants(derive(Ordinalize))] +pub enum OperationRef<'a> { + Get(&'a Get<'a>), + Set(&'a Set<'a>), + Del(&'a Del<'a>), + GetExp(&'a GetExp<'a>), + SetExp(&'a SetExp<'a>), + + HGet(&'a HGet<'a>), + HSet(&'a HSet<'a>), + HDel(&'a HDel<'a>), + HGetExp(&'a HGetExp<'a>), + HSetExp(&'a HSetExp<'a>), + HCard(&'a HCard<'a>), + HScan(&'a HScan<'a>), +} + +impl<'a> OperationRef<'a> { + /// Converts `self` into owned [`Operation`]. + /// + /// Re-allocates the underying heap-allocated data. + pub fn to_owned(self) -> Operation<'a> { + match self { + Self::Get(op) => Operation::Get(op.clone()), + Self::Set(op) => Operation::Set(op.clone()), + Self::Del(op) => Operation::Del(op.clone()), + Self::GetExp(op) => Operation::GetExp(op.clone()), + Self::SetExp(op) => Operation::SetExp(op.clone()), + Self::HGet(op) => Operation::HGet(op.clone()), + Self::HSet(op) => Operation::HSet(op.clone()), + Self::HDel(op) => Operation::HDel(op.clone()), + Self::HGetExp(op) => Operation::HGetExp(op.clone()), + Self::HSetExp(op) => Operation::HSetExp(op.clone()), + Self::HCard(op) => Operation::HCard(op.clone()), + Self::HScan(op) => Operation::HScan(op.clone()), + } + } +} + +impl metrics::Enum for Name { + fn as_str(&self) -> &'static str { + match self { + Self::Get => "get", + Self::Set => "set", + Self::Del => "del", + Self::GetExp => "get_exp", + Self::SetExp => "set_exp", + Self::HGet => "hget", + Self::HSet => "hset", + Self::HDel => "hdel", + Self::HGetExp => "hget_exp", + Self::HSetExp => "hset_exp", + Self::HCard => "hcard", + Self::HScan => "hscan", + } + } +} + +impl<'a> OperationRef<'a> { + /// Returns [`Name`] of this [`Operation`]. + pub fn name(&self) -> Name { + self.discriminant() + } + + /// Returns key of this [`Operation`]. + pub fn key(&self) -> &Bytes<'a> { + match self { + Self::Get(get) => &get.key, + Self::Set(set) => &set.key, + Self::Del(del) => &del.key, + Self::GetExp(get_exp) => &get_exp.key, + Self::SetExp(set_exp) => &set_exp.key, + Self::HGet(hget) => &hget.key, + Self::HSet(hset) => &hset.key, + Self::HDel(hdel) => &hdel.key, + Self::HGetExp(hget_exp) => &hget_exp.key, + Self::HSetExp(hset_exp) => &hset_exp.key, + Self::HCard(hcard) => &hcard.key, + Self::HScan(hscan) => &hscan.key, + } + } +} + +/// Gets a [`Record`] by the provided key. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct Get<'a> { + pub namespace: Namespace, + pub key: Bytes<'a>, + pub keyspace_version: Option, +} + +/// Sets a new [`Record`] under the provided key. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct Set<'a> { + pub namespace: Namespace, + pub key: Bytes<'a>, + pub record: Record<'a>, + pub keyspace_version: Option, +} + +/// Deletes a [`Record`] by the provided key. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct Del<'a> { + pub namespace: Namespace, + pub key: Bytes<'a>, + pub version: RecordVersion, + pub keyspace_version: Option, +} + +/// Gets a [`RecordExpiration`] by the provided key. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct GetExp<'a> { + pub namespace: Namespace, + pub key: Bytes<'a>, + pub keyspace_version: Option, +} + +/// Sets [`RecordExpiration`] on the [`Record`] with the provided key. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct SetExp<'a> { + pub namespace: Namespace, + pub key: Bytes<'a>, + pub expiration: RecordExpiration, + pub version: RecordVersion, + pub keyspace_version: Option, +} + +/// Gets a Map [`Record`] by the provided key and field. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct HGet<'a> { + pub namespace: Namespace, + pub key: Bytes<'a>, + pub field: Bytes<'a>, + pub keyspace_version: Option, +} + +/// Sets a new [`MapEntry`]. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct HSet<'a> { + pub namespace: Namespace, + pub key: Bytes<'a>, + pub entry: MapEntry<'a>, + pub keyspace_version: Option, +} + +/// Deletes a [`MapEntry`] by the provided key and field. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct HDel<'a> { + pub namespace: Namespace, + pub key: Bytes<'a>, + pub field: Bytes<'a>, + pub version: RecordVersion, + pub keyspace_version: Option, +} + +/// Gets a [`RecordExpiration`] by the provided key and field. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct HGetExp<'a> { + pub namespace: Namespace, + pub key: Bytes<'a>, + pub field: Bytes<'a>, + pub keyspace_version: Option, +} + +/// Sets [`RecordExpiration`] on the [`MapEntry`] with the provided key and +/// field. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct HSetExp<'a> { + pub namespace: Namespace, + pub key: Bytes<'a>, + pub field: Bytes<'a>, + pub expiration: RecordExpiration, + pub version: RecordVersion, + pub keyspace_version: Option, +} + +/// Returns cardinality of the Map with the provided key. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct HCard<'a> { + pub namespace: Namespace, + pub key: Bytes<'a>, + pub keyspace_version: Option, +} + +/// Returns a [`MapPage`] by iterating over the fields of the Map with +/// the provided key. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct HScan<'a> { + pub namespace: Namespace, + pub key: Bytes<'a>, + pub count: u32, + pub cursor: Option>, + pub keyspace_version: Option, +} + +/// [`Operation`] output. +#[derive(Clone, Debug, From, PartialEq, Eq, EnumDiscriminants, TryInto)] +#[strum_discriminants(name(OutputName))] +#[strum_discriminants(derive(strum::Display))] +pub enum Output<'a> { + Record(Option>), + Expiration(Option), + MapPage(MapPage<'a>), + Cardinality(u64), + None, +} + +impl From<()> for Output<'_> { + fn from(_: ()) -> Self { + Self::None + } +} + +impl Output<'_> { + /// Tries to downcast an [`Output`] within a [`Result`] into a concrete + /// output type. + pub fn downcast_result(operation_result: Result) -> Result + where + Self: TryInto>, + { + operation_result? + .try_into() + .tap_err(|err| tracing::error!(?err, "Failed to downcast output")) + .map_err(|err| Error::new(ErrorKind::Internal, Some(err.to_string()))) + } + + /// Converts `Self` into 'static. + pub fn into_static(self) -> Output<'static> { + match self { + Self::Record(opt) => Output::Record(opt.map(Record::into_static)), + Self::Expiration(record_expiration) => Output::Expiration(record_expiration), + Self::MapPage(map_page) => Output::MapPage(map_page.into_static()), + Self::Cardinality(card) => Output::Cardinality(card), + Self::None => Output::None, + } + } +} diff --git a/crates/storage_api2/src/rpc/client.rs b/crates/storage_api2/src/rpc/client.rs new file mode 100644 index 00000000..071032a5 --- /dev/null +++ b/crates/storage_api2/src/rpc/client.rs @@ -0,0 +1,154 @@ +pub use wcn_rpc::client2::Config; +use { + super::*, + crate::{operation, MapPage, Operation, OperationRef, Record, Result, StorageApi}, + wcn_rpc::client2::{Client, Connection, ConnectionHandler, RpcHandler}, +}; + +/// RPC [`Client`] of [`CoordinatorApi`]. +pub type Coordinator = Client; + +/// Outbound [`Connection`] to [`CoordinatorApi`]. +pub type CoordinatorConnection = Connection; + +/// RPC [`Client`] of [`ReplicaApi`]. +pub type Replica = Client; + +/// Outbound [`Connection`] to [`ReplicaApi`]. +pub type ReplicaConnection = Connection; + +/// RPC [`Client`] of [`DatabaseApi`]. +pub type Database = Client; + +/// Outbound [`Connection`] to [`DatabaseApi`]. +pub type DatabaseConnection = Connection; + +/// Creates a new [`Coordinator`] RPC client. +pub fn coordinator(config: Config) -> wcn_rpc::client2::Result { + Client::new(config, ConnectionHandler) +} + +/// Creates a new [`ReplicaApi`] RPC client. +pub fn replica(config: Config) -> wcn_rpc::client2::Result { + Client::new(config, ConnectionHandler) +} + +/// Creates a new [`DatabaseApi`] RPC client. +pub fn database(config: Config) -> wcn_rpc::client2::Result { + Client::new(config, ConnectionHandler) +} + +impl wcn_rpc::client2::Api for Api +where + Self: wcn_rpc::Api, +{ + type ConnectionParameters = (); + type ConnectionHandler = ConnectionHandler; + type RpcHandler = RpcHandler; +} + +impl StorageApi for Connection> +where + Api: wcn_rpc::client2::Api< + ConnectionParameters = (), + ConnectionHandler = ConnectionHandler, + RpcHandler = RpcHandler, + >, +{ + async fn get(&self, op: &operation::Get<'_>) -> Result>> { + self.send::(op)?.await?.map_err(Into::into) + } + + async fn set(&self, op: &operation::Set<'_>) -> Result<()> { + self.send::(op)?.await?.map_err(Into::into) + } + + async fn del(&self, op: &operation::Del<'_>) -> Result<()> { + self.send::(op)?.await?.map_err(Into::into) + } + + async fn get_exp(&self, op: &operation::GetExp<'_>) -> Result> { + self.send::(op)?.await?.map_err(Into::into) + } + + async fn set_exp(&self, op: &operation::SetExp<'_>) -> Result<()> { + self.send::(op)?.await?.map_err(Into::into) + } + + async fn hget(&self, op: &operation::HGet<'_>) -> Result>> { + self.send::(op)?.await?.map_err(Into::into) + } + + async fn hset(&self, op: &operation::HSet<'_>) -> Result<()> { + self.send::(op)?.await?.map_err(Into::into) + } + + async fn hdel(&self, op: &operation::HDel<'_>) -> Result<()> { + self.send::(op)?.await?.map_err(Into::into) + } + + async fn hget_exp(&self, op: &operation::HGetExp<'_>) -> Result> { + self.send::(op)?.await?.map_err(Into::into) + } + + async fn hset_exp(&self, op: &operation::HSetExp<'_>) -> Result<()> { + self.send::(op)?.await?.map_err(Into::into) + } + + async fn hcard(&self, op: &operation::HCard<'_>) -> Result { + self.send::(op)?.await?.map_err(Into::into) + } + + async fn hscan(&self, op: &operation::HScan<'_>) -> Result> { + self.send::(op)?.await?.map_err(Into::into) + } + + async fn execute_ref<'a>( + &'a self, + operation: impl Into> + Send + 'a, + ) -> Result> { + match operation.into() { + OperationRef::Get(get) => self.get(get).await.map(Into::into), + OperationRef::Set(set) => self.set(set).await.map(Into::into), + OperationRef::Del(del) => self.del(del).await.map(Into::into), + OperationRef::GetExp(get_exp) => self.get_exp(get_exp).await.map(Into::into), + OperationRef::SetExp(set_exp) => self.set_exp(set_exp).await.map(Into::into), + OperationRef::HGet(hget) => self.hget(hget).await.map(Into::into), + OperationRef::HSet(hset) => self.hset(hset).await.map(Into::into), + OperationRef::HDel(hdel) => self.hdel(hdel).await.map(Into::into), + OperationRef::HGetExp(hget_exp) => self.hget_exp(hget_exp).await.map(Into::into), + OperationRef::HSetExp(hset_exp) => self.hset_exp(hset_exp).await.map(Into::into), + OperationRef::HCard(hcard) => self.hcard(hcard).await.map(Into::into), + OperationRef::HScan(hscan) => self.hscan(hscan).await.map(Into::into), + } + } + + async fn execute<'a>( + &'a self, + operation: crate::Operation<'a>, + ) -> Result> { + Ok(match operation { + Operation::Get(op) => self.send::(&op)?.await??.into(), + Operation::Set(op) => self.send::(&op)?.await??.into(), + Operation::Del(op) => self.send::(&op)?.await??.into(), + Operation::GetExp(op) => self.send::(&op)?.await??.into(), + Operation::SetExp(op) => self.send::(&op)?.await??.into(), + Operation::HGet(op) => self.send::(&op)?.await??.into(), + Operation::HSet(op) => self.send::(&op)?.await??.into(), + Operation::HDel(op) => self.send::(&op)?.await??.into(), + Operation::HGetExp(op) => self.send::(&op)?.await??.into(), + Operation::HSetExp(op) => self.send::(&op)?.await??.into(), + Operation::HCard(op) => self.send::(&op)?.await??.into(), + Operation::HScan(op) => self.send::(&op)?.await??.into(), + }) + } +} + +impl From for crate::Error { + fn from(err: wcn_rpc::client2::Error) -> Self { + Self::new( + crate::ErrorKind::Transport, + Some(format!("wcn_rpc::client::Error: {err}")), + ) + } +} diff --git a/crates/storage_api2/src/rpc/mod.rs b/crates/storage_api2/src/rpc/mod.rs new file mode 100644 index 00000000..04713b35 --- /dev/null +++ b/crates/storage_api2/src/rpc/mod.rs @@ -0,0 +1,229 @@ +use { + crate::{operation, MapPage, Record, RecordExpiration}, + derive_more::derive::TryFrom, + serde::{Deserialize, Serialize}, + std::marker::PhantomData, + wcn_rpc::{ApiName, MessageV2 as Message, PostcardCodec}, +}; + +#[cfg(feature = "rpc_client")] +pub mod client; +#[cfg(feature = "rpc_server")] +pub mod server; + +#[derive(Clone, Copy, Debug, TryFrom)] +#[try_from(repr)] +#[repr(u8)] +pub enum Id { + Get = 0, + Set = 1, + Del = 2, + SetExp = 3, + GetExp = 4, + + HGet = 5, + HSet = 6, + HDel = 7, + HSetExp = 8, + HGetExp = 9, + HCard = 10, + HScan = 11, +} + +impl From for u8 { + fn from(id: Id) -> Self { + id as u8 + } +} + +/// `wcn_rpc` implementation of [`StorageApi`](super::StorageApi). +#[derive(Clone, Copy, Debug)] +pub struct Api(PhantomData); + +mod api_kind { + #[derive(Clone, Copy, Debug)] + pub struct Coordinator; + + #[derive(Clone, Copy, Debug)] + pub struct Replica; + + #[derive(Clone, Copy, Debug)] + pub struct Database; +} + +pub type CoordinatorApi = Api; + +impl wcn_rpc::Api for CoordinatorApi { + const NAME: ApiName = ApiName::new("Coordinator"); + type RpcId = Id; +} + +pub type ReplicaApi = Api; + +impl wcn_rpc::Api for ReplicaApi { + const NAME: ApiName = ApiName::new("Replica"); + type RpcId = Id; +} + +pub type DatabaseApi = Api; + +impl wcn_rpc::Api for DatabaseApi { + const NAME: ApiName = ApiName::new("Database"); + type RpcId = Id; +} + +type UnaryRpc = wcn_rpc::UnaryV2; + +type Get = UnaryRpc<{ Id::Get as u8 }, operation::Get<'static>, Result>>>; +type Set = UnaryRpc<{ Id::Set as u8 }, operation::Set<'static>, Result<()>>; +type Del = UnaryRpc<{ Id::Del as u8 }, operation::Del<'static>, Result<()>>; + +type GetExp = + UnaryRpc<{ Id::GetExp as u8 }, operation::GetExp<'static>, Result>>; +type SetExp = UnaryRpc<{ Id::SetExp as u8 }, operation::SetExp<'static>, Result<()>>; + +type HGet = UnaryRpc<{ Id::HGet as u8 }, operation::HGet<'static>, Result>>>; +type HSet = UnaryRpc<{ Id::HSet as u8 }, operation::HSet<'static>, Result<()>>; +type HDel = UnaryRpc<{ Id::HDel as u8 }, operation::HDel<'static>, Result<()>>; + +type HGetExp = + UnaryRpc<{ Id::HGetExp as u8 }, operation::HGetExp<'static>, Result>>; +type HSetExp = UnaryRpc<{ Id::HSetExp as u8 }, operation::HSetExp<'static>, Result<()>>; + +type HCard = UnaryRpc<{ Id::HCard as u8 }, operation::HCard<'static>, Result>; +type HScan = UnaryRpc<{ Id::HScan as u8 }, operation::HScan<'static>, Result>>; + +impl Message for operation::Get<'static> { + type Borrowed<'a> = operation::Get<'a>; +} + +impl Message for operation::Set<'static> { + type Borrowed<'a> = operation::Set<'a>; +} + +impl Message for operation::Del<'static> { + type Borrowed<'a> = operation::Del<'a>; +} + +impl Message for operation::GetExp<'static> { + type Borrowed<'a> = operation::GetExp<'a>; +} + +impl Message for operation::SetExp<'static> { + type Borrowed<'a> = operation::SetExp<'a>; +} + +impl Message for operation::HSet<'static> { + type Borrowed<'a> = operation::HSet<'a>; +} + +impl Message for operation::HGet<'static> { + type Borrowed<'a> = operation::HGet<'a>; +} + +impl Message for operation::HDel<'static> { + type Borrowed<'a> = operation::HDel<'a>; +} + +impl Message for operation::HGetExp<'static> { + type Borrowed<'a> = operation::HGetExp<'a>; +} + +impl Message for operation::HSetExp<'static> { + type Borrowed<'a> = operation::HSetExp<'a>; +} + +impl Message for operation::HCard<'static> { + type Borrowed<'a> = operation::HCard<'a>; +} + +impl Message for operation::HScan<'static> { + type Borrowed<'a> = operation::HScan<'a>; +} + +impl Message for Record<'static> { + type Borrowed<'a> = Record<'a>; +} + +impl Message for RecordExpiration { + type Borrowed<'a> = Self; +} + +impl Message for MapPage<'static> { + type Borrowed<'a> = MapPage<'a>; +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +struct Error { + code: u8, + details: Option, +} + +impl Error { + fn new(code: ErrorCode, details: Option) -> Self { + Self { + code: code as u8, + details, + } + } +} + +impl From for Error { + fn from(code: ErrorCode) -> Self { + Self::new(code, None) + } +} + +#[derive(TryFrom)] +#[try_from(repr)] +#[repr(u8)] +enum ErrorCode { + Internal = 0, + Unauthorized = 1, + KeyspaceVersionMismatch = 2, +} + +type Result = std::result::Result; + +impl From for crate::Error { + fn from(err: Error) -> Self { + use crate::ErrorKind; + + let Ok(code) = ErrorCode::try_from(err.code) else { + return Self::new( + crate::ErrorKind::Unknown, + Some(format!("Unexpected error code: {}", err.code)), + ); + }; + + let kind = match code { + ErrorCode::Unauthorized => ErrorKind::Unauthorized, + ErrorCode::KeyspaceVersionMismatch => ErrorKind::KeyspaceVersionMismatch, + ErrorCode::Internal => ErrorKind::Internal, + }; + + Self::new(kind, err.details) + } +} + +impl From for Error { + fn from(err: crate::Error) -> Self { + use crate::ErrorKind; + + let code = match err.kind { + ErrorKind::Unauthorized => ErrorCode::Unauthorized, + ErrorKind::KeyspaceVersionMismatch => ErrorCode::KeyspaceVersionMismatch, + + ErrorKind::Internal + | ErrorKind::Timeout + | ErrorKind::Transport + | ErrorKind::Unknown => ErrorCode::Internal, + }; + + Error::new(code, err.details) + } +} + +impl Message for Error { + type Borrowed<'a> = Self; +} diff --git a/crates/storage_api2/src/rpc/server.rs b/crates/storage_api2/src/rpc/server.rs new file mode 100644 index 00000000..fedf1a4b --- /dev/null +++ b/crates/storage_api2/src/rpc/server.rs @@ -0,0 +1,199 @@ +use { + super::*, + crate::{rpc::Id as RpcId, Operation, StorageApi}, + futures::{FutureExt as _, TryFutureExt as _}, + wcn_rpc::{ + server2::{Connection, HandleConnection, HandleRequest, Result}, + Request, + Response, + }, +}; + +/// Creates a new [`CoordinatorApi`] RPC server. +pub fn coordinator(storage_api: impl StorageApi) -> impl wcn_rpc::server2::Server { + new::(storage_api) +} + +/// Creates a new [`ReplicaApi`] RPC server. +pub fn replica(storage_api: impl StorageApi) -> impl wcn_rpc::server2::Server { + new::(storage_api) +} + +/// Creates a new [`DatabaseApi`] RPC server. +pub fn database(storage_api: impl StorageApi) -> impl wcn_rpc::server2::Server { + new::(storage_api) +} + +fn new(storage_api: impl StorageApi) -> impl wcn_rpc::server2::Server +where + Kind: Clone + Send + Sync + 'static, + Api: wcn_rpc::Api, +{ + wcn_rpc::server2::new(ConnectionHandler { + rpc_handler: RpcHandler { storage_api }, + _marker: PhantomData, + }) +} + +#[derive(Clone)] +struct ConnectionHandler { + rpc_handler: RpcHandler, + _marker: PhantomData, +} + +impl HandleConnection for ConnectionHandler +where + S: StorageApi, + Kind: Clone + Send + Sync + 'static, + Api: wcn_rpc::Api, +{ + type Api = super::Api; + + async fn handle_connection(&self, conn: Connection<'_, Self::Api>) -> Result<()> { + conn.handle(&self.rpc_handler, |rpc, handler| async move { + match rpc.id() { + RpcId::Get => rpc.handle::(&handler).await, + RpcId::Set => rpc.handle::(&handler).await, + RpcId::Del => rpc.handle::(&handler).await, + RpcId::SetExp => rpc.handle::(&handler).await, + RpcId::GetExp => rpc.handle::(&handler).await, + RpcId::HGet => rpc.handle::(&handler).await, + RpcId::HSet => rpc.handle::(&handler).await, + RpcId::HDel => rpc.handle::(&handler).await, + RpcId::HSetExp => rpc.handle::(&handler).await, + RpcId::HGetExp => rpc.handle::(&handler).await, + RpcId::HCard => rpc.handle::(&handler).await, + RpcId::HScan => rpc.handle::(&handler).await, + } + }) + .await + } +} + +#[derive(Clone)] +struct RpcHandler { + storage_api: S, +} + +impl HandleRequest for RpcHandler { + async fn handle_request(&self, req: Request) -> Response { + self.storage_api + .execute(Operation::Get(req)) + .map_ok(operation::Output::into_static) + .map(operation::Output::downcast_result) + .await + .map_err(Into::into) + } +} + +impl HandleRequest for RpcHandler { + async fn handle_request(&self, req: Request) -> Response { + self.storage_api + .execute(Operation::Set(req)) + .map(operation::Output::downcast_result) + .await + .map_err(Into::into) + } +} + +impl HandleRequest for RpcHandler { + async fn handle_request(&self, req: Request) -> Response { + self.storage_api + .execute(Operation::Del(req)) + .map(operation::Output::downcast_result) + .await + .map_err(Into::into) + } +} + +impl HandleRequest for RpcHandler { + async fn handle_request(&self, req: Request) -> Response { + self.storage_api + .execute(Operation::GetExp(req)) + .map(operation::Output::downcast_result) + .await + .map_err(Into::into) + } +} + +impl HandleRequest for RpcHandler { + async fn handle_request(&self, req: Request) -> Response { + self.storage_api + .execute(Operation::SetExp(req)) + .map(operation::Output::downcast_result) + .await + .map_err(Into::into) + } +} + +impl HandleRequest for RpcHandler { + async fn handle_request(&self, req: Request) -> Response { + self.storage_api + .execute(Operation::HGet(req)) + .map_ok(operation::Output::into_static) + .map(operation::Output::downcast_result) + .await + .map_err(Into::into) + } +} + +impl HandleRequest for RpcHandler { + async fn handle_request(&self, req: Request) -> Response { + self.storage_api + .execute(Operation::HSet(req)) + .map(operation::Output::downcast_result) + .await + .map_err(Into::into) + } +} + +impl HandleRequest for RpcHandler { + async fn handle_request(&self, req: Request) -> Response { + self.storage_api + .execute(Operation::HDel(req)) + .map(operation::Output::downcast_result) + .await + .map_err(Into::into) + } +} + +impl HandleRequest for RpcHandler { + async fn handle_request(&self, req: Request) -> Response { + self.storage_api + .execute(Operation::HGetExp(req)) + .map(operation::Output::downcast_result) + .await + .map_err(Into::into) + } +} + +impl HandleRequest for RpcHandler { + async fn handle_request(&self, req: Request) -> Response { + self.storage_api + .execute(Operation::HSetExp(req)) + .map(operation::Output::downcast_result) + .await + .map_err(Into::into) + } +} + +impl HandleRequest for RpcHandler { + async fn handle_request(&self, req: Request) -> Response { + self.storage_api + .execute(Operation::HCard(req)) + .map(operation::Output::downcast_result) + .await + .map_err(Into::into) + } +} + +impl HandleRequest for RpcHandler { + async fn handle_request(&self, req: Request) -> Response { + self.storage_api + .execute(Operation::HScan(req)) + .map_ok(operation::Output::into_static) + .map(operation::Output::downcast_result) + .await + .map_err(Into::into) + } +} diff --git a/crates/storage_api2/src/server.rs b/crates/storage_api2/src/server.rs deleted file mode 100644 index a48bfddd..00000000 --- a/crates/storage_api2/src/server.rs +++ /dev/null @@ -1,470 +0,0 @@ -use { - super::*, - futures::SinkExt as _, - std::{collections::HashSet, future::Future, sync::Arc}, - wcn_rpc::{ - middleware::Timeouts, - server::{ - middleware::{MeteredExt, WithTimeoutsExt}, - ClientConnectionInfo, - ConnectionInfo, - }, - transport::{self, BiDirectionalStream, PendingConnection, PostcardCodec}, - }, -}; - -/// Storage namespace. -pub type Namespace = Vec; - -/// Storage API [`Server`] config. -pub struct Config { - /// Timeout of a [`Server`] operation. - pub operation_timeout: Duration, - - /// Inbound connection [`Authenticator`]. - pub authenticator: A, -} - -/// Storage API server. -pub trait Server: Clone + Send + Sync + 'static { - /// Returns the current keyspace version of this [`Server`]. - fn keyspace_version(&self) -> u64; - - /// Gets a [`Record`] by the provided [`Key`]. - fn get(&self, key: Key) -> impl Future>> + Send; - - /// Sets the provided [`Entry`] only if the version of the existing - /// [`Entry`] is < than the new one. - fn set(&self, entry: Entry) -> impl Future> + Send; - - /// Deletes an [`Entry`] by the provided [`Key`] only if the version of the - /// [`Entry`] is < than the provided `version`. - fn del(&self, key: Key, version: EntryVersion) -> impl Future> + Send; - - /// Gets an [`EntryExpiration`] by the provided [`Key`]. - fn get_exp(&self, key: Key) -> impl Future>> + Send; - - /// Sets [`Expiration`] on the [`Entry`] with the provided [`Key`] only if - /// the version of the [`Entry`] is < than the provided `version`. - fn set_exp( - &self, - key: Key, - expiration: impl Into, - version: EntryVersion, - ) -> impl Future> + Send; - - /// Gets a map [`Record`] by the provided [`Key`] and [`Field`]. - fn hget(&self, key: Key, field: Field) -> impl Future>> + Send; - - /// Sets the provided [`MapEntry`] only if the version of the existing - /// [`MapEntry`] is < than the new one. - fn hset(&self, entry: MapEntry) -> impl Future> + Send; - - /// Deletes a [`MapEntry`] by the provided [`Key`] only if the version of - /// the [`MapEntry`] is < than the provided `version`. - fn hdel( - &self, - key: Key, - field: Field, - version: EntryVersion, - ) -> impl Future> + Send; - - /// Gets an [`EntryExpiration`] by the provided [`Key`] and [`Field`]. - fn hget_exp( - &self, - key: Key, - field: Field, - ) -> impl Future>> + Send; - - /// Sets [`Expiration`] on the [`MapEntry`] with the provided [`Key`] and - /// [`Field`] only if the version of the [`MapEntry`] is < than the - /// provided `version`. - fn hset_exp( - &self, - key: Key, - field: Field, - expiration: impl Into, - version: EntryVersion, - ) -> impl Future> + Send; - - /// Returns cardinality of the map with the provided [`Key`]. - fn hcard(&self, key: Key) -> impl Future> + Send; - - /// Returns a [`MapPage`] by iterating over the [`Field`]s of the map with - /// the provided [`Key`]. - fn hscan( - &self, - key: Key, - count: u32, - cursor: Option, - ) -> impl Future> + Send; - - /// Converts this Storage API [`Server`] into an [`rpc::Server`]. - fn into_rpc_server(self, cfg: Config) -> impl rpc::Server { - let timeouts = Timeouts::new().with_default(cfg.operation_timeout); - - let rpc_server_config = wcn_rpc::server::Config { - name: crate::RPC_SERVER_NAME, - handshake: Handshake { - authenticator: cfg.authenticator, - }, - }; - - RpcServer { - api_server: self, - config: rpc_server_config, - } - .with_timeouts(timeouts) - .metered() - } -} - -struct RpcHandler<'a, S> { - api_server: &'a S, - conn_info: &'a ConnectionInfo, -} - -impl RpcHandler<'_, S> { - fn prepare_key(&self, key: ExtendedKey) -> wcn_rpc::Result { - if let Some(keyspace_version) = key.keyspace_version { - if keyspace_version != self.api_server.keyspace_version() { - return Err(wcn_rpc::Error::new(error_code::KEYSPACE_VERSION_MISMATCH)); - } - } - - let key = Key::from_raw_bytes(key.inner) - .ok_or_else(|| wcn_rpc::Error::new(error_code::INVALID_KEY))?; - - if let Some(namespace) = key.namespace() { - if !self - .conn_info - .handshake_data - .namespaces - .contains(namespace.as_slice()) - { - return Err(wcn_rpc::Error::new(error_code::UNAUTHORIZED)); - } - } - - Ok(key) - } - - async fn get(&self, req: GetRequest) -> wcn_rpc::Result> { - let record = self - .api_server - .get(self.prepare_key(req.key)?) - .await - .map_err(Error::into_rpc_error)?; - - Ok(record.map(|rec| GetResponse { - value: rec.value, - expiration: rec.expiration.timestamp(), - version: rec.version.timestamp(), - })) - } - - async fn set(&self, req: SetRequest) -> wcn_rpc::Result<()> { - let entry = Entry { - key: self.prepare_key(req.key)?, - value: req.value, - expiration: EntryExpiration::from(req.expiration), - version: EntryVersion::from(req.version), - }; - - self.api_server - .set(entry) - .await - .map_err(Error::into_rpc_error) - } - - async fn del(&self, req: DelRequest) -> wcn_rpc::Result<()> { - self.api_server - .del(self.prepare_key(req.key)?, EntryVersion::from(req.version)) - .await - .map_err(Error::into_rpc_error) - } - - async fn get_exp(&self, req: GetExpRequest) -> wcn_rpc::Result> { - let expiration = self - .api_server - .get_exp(self.prepare_key(req.key)?) - .await - .map_err(Error::into_rpc_error)?; - - Ok(expiration.map(|exp| GetExpResponse { - expiration: exp.timestamp(), - })) - } - - async fn set_exp(&self, req: SetExpRequest) -> wcn_rpc::Result<()> { - self.api_server - .set_exp( - self.prepare_key(req.key)?, - EntryExpiration::from(req.expiration), - EntryVersion::from(req.version), - ) - .await - .map_err(Error::into_rpc_error) - } - - async fn hget(&self, req: HGetRequest) -> wcn_rpc::Result> { - let record = self - .api_server - .hget(self.prepare_key(req.key)?, req.field) - .await - .map_err(Error::into_rpc_error)?; - - Ok(record.map(|rec| HGetResponse { - value: rec.value, - expiration: rec.expiration.timestamp(), - version: rec.version.timestamp(), - })) - } - - async fn hset(&self, req: HSetRequest) -> wcn_rpc::Result<()> { - let entry = MapEntry { - key: self.prepare_key(req.key)?, - field: req.field, - value: req.value, - expiration: EntryExpiration::from(req.expiration), - version: EntryVersion::from(req.version), - }; - - self.api_server - .hset(entry) - .await - .map_err(Error::into_rpc_error) - } - - async fn hdel(&self, req: HDelRequest) -> wcn_rpc::Result<()> { - self.api_server - .hdel( - self.prepare_key(req.key)?, - req.field, - EntryVersion::from(req.version), - ) - .await - .map_err(Error::into_rpc_error) - } - - async fn hget_exp(&self, req: HGetExpRequest) -> wcn_rpc::Result> { - let expiration = self - .api_server - .hget_exp(self.prepare_key(req.key)?, req.field) - .await - .map_err(Error::into_rpc_error)?; - - Ok(expiration.map(|exp| HGetExpResponse { - expiration: exp.timestamp(), - })) - } - - async fn hset_exp(&self, req: HSetExpRequest) -> wcn_rpc::Result<()> { - self.api_server - .hset_exp( - self.prepare_key(req.key)?, - req.field, - EntryExpiration::from(req.expiration), - EntryVersion::from(req.version), - ) - .await - .map_err(Error::into_rpc_error) - } - - async fn hcard(&self, req: HCardRequest) -> wcn_rpc::Result { - self.api_server - .hcard(self.prepare_key(req.key)?) - .await - .map(|cardinality| HCardResponse { cardinality }) - .map_err(Error::into_rpc_error) - } - - async fn hscan(&self, req: HScanRequest) -> wcn_rpc::Result { - let page = self - .api_server - .hscan(self.prepare_key(req.key)?, req.count, req.cursor) - .await - .map_err(Error::into_rpc_error)?; - - Ok(HScanResponse { - records: page - .records - .into_iter() - .map(|rec| HScanResponseRecord { - field: rec.field, - value: rec.value, - expiration: rec.expiration.timestamp(), - version: rec.version.timestamp(), - }) - .collect(), - has_more: page.has_next, - }) - } -} - -#[derive(Clone, Debug)] -struct RpcServer { - api_server: S, - config: rpc::server::Config>, -} - -impl rpc::Server for RpcServer -where - S: Server, - V: Authenticator, -{ - type Handshake = Handshake; - type ConnectionData = (); - type Codec = PostcardCodec; - - fn config(&self) -> &wcn_rpc::server::Config { - &self.config - } - - fn handle_rpc<'a>( - &'a self, - id: rpc::Id, - stream: BiDirectionalStream, - conn_info: &'a ClientConnectionInfo, - ) -> impl Future + Send + 'a { - async move { - let handler = RpcHandler { - api_server: &self.api_server, - conn_info, - }; - - let _ = match id { - Get::ID => Get::handle(stream, |req| handler.get(req)).await, - Set::ID => Set::handle(stream, |req| handler.set(req)).await, - Del::ID => Del::handle(stream, |req| handler.del(req)).await, - GetExp::ID => GetExp::handle(stream, |req| handler.get_exp(req)).await, - SetExp::ID => SetExp::handle(stream, |req| handler.set_exp(req)).await, - - HGet::ID => HGet::handle(stream, |req| handler.hget(req)).await, - HSet::ID => HSet::handle(stream, |req| handler.hset(req)).await, - HDel::ID => HDel::handle(stream, |req| handler.hdel(req)).await, - HGetExp::ID => HGetExp::handle(stream, |req| handler.hget_exp(req)).await, - HSetExp::ID => HSetExp::handle(stream, |req| handler.hset_exp(req)).await, - HCard::ID => HCard::handle(stream, |req| handler.hcard(req)).await, - HScan::ID => HScan::handle(stream, |req| handler.hscan(req)).await, - - id => return tracing::warn!("Unexpected RPC: {}", rpc::Name::new(id)), - } - .map_err( - |err| tracing::debug!(name = %rpc::Name::new(id), ?err, "Failed to handle RPC"), - ); - } - } -} - -/// Error of a [`Server`] operation. -#[derive(Clone, Debug)] -pub struct Error(String); - -impl Error { - pub fn new(err: E) -> Self { - Self(format!("{err}")) - } - - fn into_rpc_error(self) -> wcn_rpc::Error { - wcn_rpc::Error { - code: "internal".into(), - description: Some(self.0.into()), - } - } -} - -/// [`Server`] operation [`Result`]. -pub type Result = std::result::Result; - -/// Server part of the [`network::Handshake`]. -#[derive(Clone, Debug)] -pub struct Handshake { - authenticator: V, -} - -#[derive(Clone, Debug)] -pub struct HandshakeData { - pub namespaces: Arc>, -} - -impl transport::Handshake for Handshake { - type Ok = HandshakeData; - type Err = HandshakeError; - - fn handle( - &self, - peer_id: PeerId, - conn: PendingConnection, - ) -> impl Future> + Send { - async move { - let (mut rx, mut tx) = conn - .accept_handshake::() - .await?; - - let req = rx.recv_message().await?; - - let err_resp = match self - .authenticator - .validate_access_token(&req.access_token, peer_id) - { - Ok(data) => { - tx.send(Ok(())).await?; - return Ok(HandshakeData { - namespaces: Arc::new( - data.namespaces() - .into_iter() - .map(|ns| ns.as_bytes().to_vec()) - .collect(), - ), - }); - } - Err(err) => HandshakeErrorResponse::InvalidToken(err), - }; - - tx.send(Err(err_resp.clone())).await?; - Err(err_resp.into()) - } - } -} - -/// Inbound connection authenticator. -pub trait Authenticator: Clone + Send + Sync + 'static { - /// Indicates whether the specified peer is an authorized access token - /// issuer. - fn is_authorized_token_issuer(&self, peer_id: PeerId) -> bool; - - /// Network id of the local Storage API server. - fn network_id(&self) -> &str; - - /// Validates the provided access token. - fn validate_access_token( - &self, - token: &auth::Token, - client_peer_id: PeerId, - ) -> Result { - let claims = token.decode().map_err(|err| err.to_string())?; - - if claims.is_expired() { - return Err("Token expired".to_string()); - } - - match claims.purpose() { - auth::token::Purpose::Storage => {} - }; - - if self.network_id() != claims.network_id() { - return Err("Wrong network".to_string()); - } - - if !self.is_authorized_token_issuer(claims.issuer_peer_id()) { - return Err("Unauthorized token issuer".to_string()); - } - - if claims.client_peer_id() != client_peer_id { - return Err("Wrong PeerId".to_string()); - } - - Ok(claims) - } -} diff --git a/crates/storage_api2/tests/integration.rs b/crates/storage_api2/tests/integration.rs new file mode 100644 index 00000000..18773007 --- /dev/null +++ b/crates/storage_api2/tests/integration.rs @@ -0,0 +1,384 @@ +use { + rand::{random, Rng}, + std::{ + net::{Ipv4Addr, SocketAddrV4}, + sync::{Arc, Mutex}, + time::Duration, + }, + tracing_subscriber::EnvFilter, + wc::future::StaticFutureExt, + wcn_rpc::{ + client2::{Api, Client, Connection}, + identity::Keypair, + transport, + }, + wcn_storage_api2::{ + operation, + Bytes, + KeyspaceVersion, + MapEntry, + MapPage, + Namespace, + Operation, + Record, + RecordExpiration, + RecordVersion, + Result, + StorageApi, + }, +}; + +#[tokio::test] +async fn test_rpc() { + let subscriber = tracing_subscriber::FmtSubscriber::builder() + .with_max_level(tracing::Level::INFO) + .with_env_filter(EnvFilter::from_default_env()) + .finish(); + tracing::subscriber::set_global_default(subscriber).unwrap(); + + for _ in 0..10 { + test_rpc_api( + wcn_storage_api2::rpc::server::coordinator, + wcn_storage_api2::rpc::client::coordinator, + ) + .await; + + test_rpc_api( + wcn_storage_api2::rpc::server::replica, + wcn_storage_api2::rpc::client::replica, + ) + .await; + + test_rpc_api( + wcn_storage_api2::rpc::server::database, + wcn_storage_api2::rpc::client::database, + ) + .await; + } +} + +async fn test_rpc_api( + server: impl FnOnce(TestStorage) -> S, + client: impl FnOnce(wcn_rpc::client2::Config) -> wcn_rpc::client2::Result>, +) where + API: Api, + S: wcn_rpc::server2::Server, + Connection: StorageApi, +{ + let storage = TestStorage::default(); + + let server_port = find_available_port(); + let server_keypair = Keypair::generate_ed25519(); + let server_peer_id = server_keypair.public().to_peer_id(); + let server_cfg = wcn_rpc::server2::Config { + name: "test", + port: server_port, + keypair: server_keypair, + connection_timeout: Duration::from_secs(10), + max_connections: 1, + max_connections_per_ip: 1, + max_connection_rate_per_ip: 1, + max_concurrent_rpcs: 10, + priority: transport::Priority::High, + }; + + let server_handle = server(storage.clone()).serve(server_cfg).unwrap().spawn(); + + let client_config = wcn_rpc::client2::Config { + keypair: Keypair::generate_ed25519(), + connection_timeout: Duration::from_secs(10), + reconnect_interval: Duration::from_secs(1), + max_concurrent_rpcs: 10, + priority: transport::Priority::High, + }; + + let client = client(client_config).unwrap(); + + let server_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, server_port); + let client_conn = client + .connect(server_addr, &server_peer_id, ()) + .await + .unwrap(); + + let ctx = &TestContext { + storage, + client_conn, + }; + + ctx.test_operations().await; + + server_handle.abort(); +} + +struct TestContext { + storage: TestStorage, + client_conn: Connection, +} + +impl TestContext +where + Connection: StorageApi, +{ + async fn test_operations(&self) { + self.test_get().await; + self.test_set().await; + self.test_del().await; + self.test_get_exp().await; + self.test_set_exp().await; + + self.test_hget().await; + self.test_hset().await; + self.test_hdel().await; + self.test_hget_exp().await; + self.test_hset_exp().await; + + self.test_hcard().await; + self.test_hscan().await; + } + + async fn test_operation( + &self, + operation: impl Into>, + output: impl Into>, + ) { + let operation = operation.into(); + let result = Ok(output.into()); + + let expect = (operation.clone(), result.clone()); + let _ = self.storage.expect.lock().unwrap().insert(expect); + assert_eq!(self.client_conn.execute(operation).await, result); + } + + async fn test_get(&self) { + let op = operation::Get { + namespace: namespace(), + key: bytes(), + keyspace_version: keyspace_version(), + }; + + self.test_operation(op, opt(record)).await; + } + + async fn test_set(&self) { + let op = operation::Set { + namespace: namespace(), + key: bytes(), + record: record(), + keyspace_version: keyspace_version(), + }; + + self.test_operation(op, ()).await; + } + + async fn test_del(&self) { + let op = operation::Del { + namespace: namespace(), + key: bytes(), + version: RecordVersion::now(), + keyspace_version: keyspace_version(), + }; + + self.test_operation(op, ()).await; + } + + async fn test_get_exp(&self) { + let op = operation::GetExp { + namespace: namespace(), + key: bytes(), + keyspace_version: keyspace_version(), + }; + + self.test_operation(op, opt(record_expiration)).await; + } + + async fn test_set_exp(&self) { + let op = operation::SetExp { + namespace: namespace(), + key: bytes(), + expiration: record_expiration(), + version: RecordVersion::now(), + keyspace_version: keyspace_version(), + }; + + self.test_operation(op, ()).await; + } + + async fn test_hget(&self) { + let op = operation::HGet { + namespace: namespace(), + key: bytes(), + field: bytes(), + keyspace_version: keyspace_version(), + }; + + self.test_operation(op, opt(record)).await; + } + + async fn test_hset(&self) { + let op = operation::HSet { + namespace: namespace(), + key: bytes(), + entry: map_entry(), + keyspace_version: keyspace_version(), + }; + + self.test_operation(op, ()).await; + } + + async fn test_hdel(&self) { + let op = operation::HDel { + namespace: namespace(), + key: bytes(), + field: bytes(), + version: RecordVersion::now(), + keyspace_version: keyspace_version(), + }; + + self.test_operation(op, ()).await; + } + + async fn test_hget_exp(&self) { + let op = operation::HGetExp { + namespace: namespace(), + key: bytes(), + field: bytes(), + keyspace_version: keyspace_version(), + }; + + self.test_operation(op, opt(record_expiration)).await; + } + + async fn test_hset_exp(&self) { + let op = operation::HSetExp { + namespace: namespace(), + key: bytes(), + field: bytes(), + expiration: record_expiration(), + version: RecordVersion::now(), + keyspace_version: keyspace_version(), + }; + + self.test_operation(op, ()).await; + } + + async fn test_hcard(&self) { + let op = operation::HCard { + namespace: namespace(), + key: bytes(), + keyspace_version: keyspace_version(), + }; + + self.test_operation(op, random::()).await; + } + + async fn test_hscan(&self) { + let op = operation::HScan { + namespace: namespace(), + key: bytes(), + count: random(), + cursor: opt(bytes), + keyspace_version: keyspace_version(), + }; + + self.test_operation(op, map_page()).await; + } +} + +fn namespace() -> Namespace { + let operator_id = rand::random::<[u8; 20]>(); + let id = rand::random::(); + + let operator_id = const_hex::encode(operator_id); + format!("{operator_id}/{id}").parse().unwrap() +} + +fn keyspace_version() -> Option { + if random() { + Some(random()) + } else { + None + } +} + +fn record() -> Record<'static> { + Record { + value: bytes(), + expiration: record_expiration(), + version: RecordVersion::now(), + } +} + +fn map_entry() -> MapEntry<'static> { + MapEntry { + field: bytes(), + record: record(), + } +} + +fn map_page() -> MapPage<'static> { + let mut rng = rand::thread_rng(); + + let len = rng.gen_range(1..=1000); + let mut buf = Vec::with_capacity(len); + for _ in 0..len { + buf.push(map_entry()); + } + + MapPage { + entries: buf, + has_next: random(), + } +} + +fn opt(f: fn() -> T) -> Option { + if random() { + Some(f()) + } else { + None + } +} + +fn bytes() -> Bytes<'static> { + let mut rng = rand::thread_rng(); + + let mut buf = vec![0u8; rng.gen_range(1..=4096)]; + rng.fill(&mut buf[..]); + buf.into() +} + +fn record_expiration() -> RecordExpiration { + let secs = rand::thread_rng().gen_range(30..=60 * 60 * 24 * 30); + Duration::from_secs(secs).into() +} + +fn find_available_port() -> u16 { + use std::{ + net::UdpSocket, + sync::atomic::{AtomicU16, Ordering}, + }; + + static NEXT_PORT: AtomicU16 = AtomicU16::new(48100); + + loop { + let port = NEXT_PORT.fetch_add(1, Ordering::Relaxed); + assert!(port != u16::MAX, "failed to find a free port"); + + if UdpSocket::bind((Ipv4Addr::LOCALHOST, port)).is_ok() { + return port; + } + } +} + +#[derive(Clone, Default)] +struct TestStorage { + #[allow(clippy::type_complexity)] + expect: Arc, Result>)>>>, +} + +impl StorageApi for TestStorage { + async fn execute<'a>(&'a self, operation: Operation<'a>) -> Result> { + let expected = self.expect.lock().unwrap().take().unwrap(); + assert_eq!(operation, expected.0); + expected.1 + } +} diff --git a/flake.nix b/flake.nix index f530c2d3..be43f747 100644 --- a/flake.nix +++ b/flake.nix @@ -30,6 +30,7 @@ pkg-config openssl clang + gcc13 # jemalloc fails to build on gcc14 (in debug builds) ]; rustc = { stable = fenixPackages.stable.rustc;