diff --git a/Cargo.lock b/Cargo.lock index c0081b7dd..ea1bdc495 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2753,7 +2753,7 @@ dependencies = [ [[package]] name = "typedb-protocol" version = "0.0.0" -source = "git+https://github.com/typedb/typedb-protocol?tag=3.0.0#111f1a9ed8aac3360c0b5d16e68c1ecebe823137" +source = "git+https://github.com/typedb/typedb-protocol?rev=b4b7ee87b08c16831a21629f81974347f38cce5c#b4b7ee87b08c16831a21629f81974347f38cce5c" dependencies = [ "prost", "tonic", diff --git a/dependencies/typedb/repositories.bzl b/dependencies/typedb/repositories.bzl index daf6f0958..7fe7aef96 100644 --- a/dependencies/typedb/repositories.bzl +++ b/dependencies/typedb/repositories.bzl @@ -25,11 +25,17 @@ def typedb_dependencies(): ) def typedb_protocol(): + # TODO: Temp, return typedb git_repository( name = "typedb_protocol", - remote = "https://github.com/typedb/typedb-protocol", - tag = "3.0.0", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_protocol + remote = "https://github.com/farost/typedb-protocol", + commit = "b4b7ee87b08c16831a21629f81974347f38cce5c", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_protocol ) +# git_repository( +# name = "typedb_protocol", +# remote = "https://github.com/typedb/typedb-protocol", +# tag = "3.0.0", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_protocol +# ) def typedb_behaviour(): git_repository( diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 108c688bf..3a183722f 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -60,8 +60,8 @@ [dependencies.typedb-protocol] features = [] + rev = "b4b7ee87b08c16831a21629f81974347f38cce5c" git = "https://github.com/typedb/typedb-protocol" - tag = "3.0.0" default-features = false [dependencies.log] diff --git a/rust/src/common/error.rs b/rust/src/common/error.rs index 6c8a284ca..c64a659d7 100644 --- a/rust/src/common/error.rs +++ b/rust/src/common/error.rs @@ -21,7 +21,7 @@ use std::{collections::HashSet, error::Error as StdError, fmt}; use itertools::Itertools; use tonic::{Code, Status}; -use tonic_types::StatusExt; +use tonic_types::{ErrorDetails, ErrorInfo, StatusExt}; use super::{address::Address, RequestID}; @@ -150,7 +150,7 @@ error_messages! { ConnectionError 15: "The replica is not the primary replica.", ClusterAllNodesFailed { errors: String } = 16: "Attempted connecting to all TypeDB Cluster servers, but the following errors occurred: \n{errors}.", - ClusterTokenCredentialInvalid = + TokenCredentialInvalid = 17: "Invalid token credentials.", EncryptionSettingsMismatch = 18: "Unable to connect to TypeDB: possible encryption settings mismatch.", @@ -275,6 +275,16 @@ impl Error { } } + fn try_extracting_connection_error(status: &Status, code: &str) -> Option { + // TODO: We should probably catch more connection errors instead of wrapping them into + // ServerErrors. However, the most valuable information even for connection is inside + // stacktraces now. + match code { + "AUT3" => Some(ConnectionError::TokenCredentialInvalid {}), + _ => None, + } + } + fn from_message(message: &str) -> Self { // TODO: Consider converting some of the messages to connection errors Self::Other(message.to_owned()) @@ -352,9 +362,13 @@ impl From for Error { }) } else if let Some(error_info) = details.error_info() { let code = error_info.reason.clone(); + if let Some(connection_error) = Self::try_extracting_connection_error(&status, &code) { + return Self::Connection(connection_error); + } let domain = error_info.domain.clone(); let stack_trace = if let Some(debug_info) = details.debug_info() { debug_info.stack_entries.clone() } else { vec![] }; + Self::Server(ServerError::new(code, domain, status.message().to_owned(), stack_trace)) } else { Self::from_message(status.message()) @@ -364,7 +378,6 @@ impl From for Error { Self::parse_unavailable(status.message()) } else if status.code() == Code::Unknown || is_rst_stream(&status) - || status.code() == Code::InvalidArgument || status.code() == Code::FailedPrecondition || status.code() == Code::AlreadyExists { diff --git a/rust/src/connection/message.rs b/rust/src/connection/message.rs index 73f860b0c..e9e080798 100644 --- a/rust/src/connection/message.rs +++ b/rust/src/connection/message.rs @@ -35,12 +35,12 @@ use crate::{ error::ServerError, info::UserInfo, user::User, - Options, TransactionType, + Credentials, Options, TransactionType, }; #[derive(Debug)] pub(super) enum Request { - ConnectionOpen { driver_lang: String, driver_version: String }, + ConnectionOpen { driver_lang: String, driver_version: String, credentials: Credentials }, ServersAll, diff --git a/rust/src/connection/network/channel.rs b/rust/src/connection/network/channel.rs index 1d972bfc9..a96ec9bdd 100644 --- a/rust/src/connection/network/channel.rs +++ b/rust/src/connection/network/channel.rs @@ -17,11 +17,12 @@ * under the License. */ -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use tonic::{ body::BoxBody, client::GrpcService, + metadata::MetadataValue, service::{ interceptor::{InterceptedService, ResponseFuture as InterceptorResponseFuture}, Interceptor, @@ -65,20 +66,33 @@ pub(super) fn open_callcred_channel( #[derive(Debug)] pub(super) struct CallCredentials { credentials: Credentials, + token: RwLock>, } impl CallCredentials { pub(super) fn new(credentials: Credentials) -> Self { - Self { credentials } + Self { credentials, token: RwLock::new(None) } } - pub(super) fn username(&self) -> &str { - self.credentials.username() + pub(super) fn credentials(&self) -> &Credentials { + &self.credentials + } + + pub(super) fn set_token(&self, token: String) { + *self.token.write().expect("Expected token write lock acquisition on set") = Some(token); + } + + pub(super) fn reset_token(&self) { + *self.token.write().expect("Expected token write lock acquisition on reset") = None; } pub(super) fn inject(&self, mut request: Request<()>) -> Request<()> { - request.metadata_mut().insert("username", self.credentials.username().try_into().unwrap()); - request.metadata_mut().insert("password", self.credentials.password().try_into().unwrap()); + if let Some(token) = &*self.token.read().expect("Expected token read lock acquisition on inject") { + request.metadata_mut().insert( + "authorization", + format!("Bearer {}", token).try_into().expect("Expected authorization header formatting"), + ); + } request } } diff --git a/rust/src/connection/network/proto/message.rs b/rust/src/connection/network/proto/message.rs index d03ff4023..900e96e82 100644 --- a/rust/src/connection/network/proto/message.rs +++ b/rust/src/connection/network/proto/message.rs @@ -19,8 +19,8 @@ use itertools::Itertools; use typedb_protocol::{ - connection, database, database_manager, query::initial_res::Res, server_manager, transaction, user, user_manager, - Version::Version, + authentication, connection, database, database_manager, query::initial_res::Res, server_manager, transaction, user, + user_manager, Version::Version, }; use uuid::Uuid; @@ -32,14 +32,18 @@ use crate::{ error::{ConnectionError, InternalError, ServerError}, info::UserInfo, user::User, + Credentials, }; impl TryIntoProto for Request { fn try_into_proto(self) -> Result { match self { - Self::ConnectionOpen { driver_lang, driver_version } => { - Ok(connection::open::Req { version: Version.into(), driver_lang, driver_version }) - } + Self::ConnectionOpen { driver_lang, driver_version, credentials } => Ok(connection::open::Req { + version: Version.into(), + driver_lang, + driver_version, + authentication: Some(credentials.try_into_proto()?), + }), other => Err(InternalError::UnexpectedRequestType { request_type: format!("{other:?}") }.into()), } } @@ -225,14 +229,28 @@ impl TryIntoProto for Request { } } +impl TryIntoProto for Credentials { + fn try_into_proto(self) -> Result { + Ok(authentication::sign_in::Req { + credentials: Some(authentication::sign_in::req::Credentials::Password( + authentication::sign_in::req::Password { + username: self.username().to_owned(), + password: self.password().to_owned(), + }, + )), + }) + } +} + impl TryFromProto for Response { fn try_from_proto(proto: connection::open::Res) -> Result { let mut database_infos = Vec::new(); - for database_info_proto in proto.databases_all.unwrap().databases { + for database_info_proto in proto.databases_all.expect("Expected databases data").databases { database_infos.push(DatabaseInfo::try_from_proto(database_info_proto)?); } Ok(Self::ConnectionOpen { - connection_id: Uuid::from_slice(proto.connection_id.unwrap().id.as_slice()).unwrap(), + connection_id: Uuid::from_slice(proto.connection_id.expect("Expected connection id").id.as_slice()) + .unwrap(), server_duration_millis: proto.server_duration_millis, databases: database_infos, }) diff --git a/rust/src/connection/network/stub.rs b/rust/src/connection/network/stub.rs index 38af93adf..0eb780470 100644 --- a/rust/src/connection/network/stub.rs +++ b/rust/src/connection/network/stub.rs @@ -25,12 +25,15 @@ use tokio::sync::mpsc::{unbounded_channel as unbounded_async, UnboundedSender}; use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::{Response, Status, Streaming}; use typedb_protocol::{ - connection, database, database_manager, server_manager, transaction, type_db_client::TypeDbClient as GRPC, user, - user_manager, + authentication, connection, database, database_manager, server_manager, transaction, + type_db_client::TypeDbClient as GRPC, user, user_manager, }; use super::channel::{CallCredentials, GRPCChannel}; -use crate::common::{error::ConnectionError, Error, Result, StdResult}; +use crate::{ + common::{error::ConnectionError, Error, Result, StdResult}, + connection::network::proto::TryIntoProto, +}; type TonicResult = StdResult, Status>; @@ -45,15 +48,41 @@ impl RPCStub { Self { grpc: GRPC::new(channel), call_credentials } } - async fn call(&mut self, call: F) -> Result + async fn call_with_auto_renew_token(&mut self, call: F) -> Result where for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, Result>, { - call(self).await + match call(self).await { + Err(Error::Connection(ConnectionError::TokenCredentialInvalid)) => { + debug!("Request rejected because token credential was invalid. Renewing token and trying again..."); + self.renew_token().await?; + call(self).await + } + res => res, + } + } + + async fn renew_token(&mut self) -> Result { + if let Some(call_credentials) = &self.call_credentials { + trace!("Renewing token..."); + call_credentials.reset_token(); + let request = call_credentials.credentials().clone().try_into_proto()?; + let token = self.grpc.sign_in(request).await?.into_inner().token; + call_credentials.set_token(token); + trace!("Token renewed"); + } + Ok(()) } pub(super) async fn connection_open(&mut self, req: connection::open::Req) -> Result { - self.single(|this| Box::pin(this.grpc.connection_open(req.clone()))).await + let result = self.single(|this| Box::pin(this.grpc.connection_open(req.clone()))).await; + if let Ok(response) = &result { + if let Some(call_credentials) = &self.call_credentials { + call_credentials + .set_token(response.authentication.as_ref().expect("Expected authentication token").token.clone()); + } + } + result } pub(super) async fn servers_all(&mut self, req: server_manager::all::Req) -> Result { @@ -107,7 +136,7 @@ impl RPCStub { &mut self, open_req: transaction::Req, ) -> Result<(UnboundedSender, Streaming)> { - self.call(|this| { + self.call_with_auto_renew_token(|this| { let transaction_req = transaction::Client { reqs: vec![open_req.clone()] }; Box::pin(async { let (sender, receiver) = unbounded_async(); @@ -154,6 +183,6 @@ impl RPCStub { for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, TonicResult> + Send + Sync, R: 'static, { - self.call(|this| Box::pin(call(this).map(|r| Ok(r?.into_inner())))).await + self.call_with_auto_renew_token(|this| Box::pin(call(this).map(|r| Ok(r?.into_inner())))).await } } diff --git a/rust/src/connection/server_connection.rs b/rust/src/connection/server_connection.rs index 58c0ed211..d92c25c3d 100644 --- a/rust/src/connection/server_connection.rs +++ b/rust/src/connection/server_connection.rs @@ -65,9 +65,9 @@ impl ServerConnection { ) -> crate::Result<(Self, Vec)> { let username = credentials.username().to_string(); let request_transmitter = - Arc::new(RPCTransmitter::start(address, credentials, driver_options, &background_runtime)?); + Arc::new(RPCTransmitter::start(address, credentials.clone(), driver_options, &background_runtime)?); let (connection_id, latency, database_info) = - Self::open_connection(&request_transmitter, driver_lang, driver_version).await?; + Self::open_connection(&request_transmitter, driver_lang, driver_version, credentials).await?; let latency_tracker = LatencyTracker::new(latency); let server_connection = Self { background_runtime, @@ -85,9 +85,13 @@ impl ServerConnection { request_transmitter: &RPCTransmitter, driver_lang: &str, driver_version: &str, + credentials: Credentials, ) -> crate::Result<(Uuid, Duration, Vec)> { - let message = - Request::ConnectionOpen { driver_lang: driver_lang.to_owned(), driver_version: driver_version.to_owned() }; + let message = Request::ConnectionOpen { + driver_lang: driver_lang.to_owned(), + driver_version: driver_version.to_owned(), + credentials, + }; let request_time = Instant::now(); match request_transmitter.request(message).await? { diff --git a/tool/test/start-community-server.sh b/tool/test/start-community-server.sh index 7695fb53d..fe52d1596 100755 --- a/tool/test/start-community-server.sh +++ b/tool/test/start-community-server.sh @@ -22,6 +22,7 @@ rm -rf typedb-all bazel run //tool/test:typedb-extractor -- typedb-all BAZEL_JAVA_HOME=$(bazel run //tool/test:echo-java-home) +# TODO: Can add `--server.authentication.token_ttl_seconds 15` to test token auto-renewal in BDDs! ./typedb-all/typedb server --development-mode.enabled & set +e