From b9cb032c87156c0b33ea8c934cacb9b1bb4482ac Mon Sep 17 00:00:00 2001 From: Georgii Novoselov Date: Thu, 13 Mar 2025 10:39:31 +0000 Subject: [PATCH 1/4] Add new authentication --- Cargo.lock | 2 +- dependencies/typedb/repositories.bzl | 9 ++++-- rust/Cargo.toml | 2 +- rust/src/common/error.rs | 2 +- rust/src/connection/network/channel.rs | 24 +++++++++----- rust/src/connection/network/stub.rs | 43 +++++++++++++++++++++----- 6 files changed, 62 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c0081b7dd..df94993d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2753,7 +2753,7 @@ dependencies = [ [[package]] name = "typedb-protocol" version = "0.0.0" -source = "git+https://github.com/typedb/typedb-protocol?tag=3.0.0#111f1a9ed8aac3360c0b5d16e68c1ecebe823137" +source = "git+https://github.com/typedb/typedb-protocol?rev=1ce9fc93c005b19d63fef835c1ee3a536ecf4c20#1ce9fc93c005b19d63fef835c1ee3a536ecf4c20" dependencies = [ "prost", "tonic", diff --git a/dependencies/typedb/repositories.bzl b/dependencies/typedb/repositories.bzl index daf6f0958..d48c4e589 100644 --- a/dependencies/typedb/repositories.bzl +++ b/dependencies/typedb/repositories.bzl @@ -27,9 +27,14 @@ def typedb_dependencies(): def typedb_protocol(): git_repository( name = "typedb_protocol", - remote = "https://github.com/typedb/typedb-protocol", - tag = "3.0.0", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_protocol + remote = "https://github.com/farost/typedb-protocol", + commit = "1ce9fc93c005b19d63fef835c1ee3a536ecf4c20", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_protocol ) +# git_repository( +# name = "typedb_protocol", +# remote = "https://github.com/typedb/typedb-protocol", +# tag = "3.0.0", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_protocol +# ) def typedb_behaviour(): git_repository( diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 108c688bf..b92a1f71e 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -60,8 +60,8 @@ [dependencies.typedb-protocol] features = [] + rev = "1ce9fc93c005b19d63fef835c1ee3a536ecf4c20" git = "https://github.com/typedb/typedb-protocol" - tag = "3.0.0" default-features = false [dependencies.log] diff --git a/rust/src/common/error.rs b/rust/src/common/error.rs index 6c8a284ca..c18b21c46 100644 --- a/rust/src/common/error.rs +++ b/rust/src/common/error.rs @@ -150,7 +150,7 @@ error_messages! { ConnectionError 15: "The replica is not the primary replica.", ClusterAllNodesFailed { errors: String } = 16: "Attempted connecting to all TypeDB Cluster servers, but the following errors occurred: \n{errors}.", - ClusterTokenCredentialInvalid = + TokenCredentialInvalid = 17: "Invalid token credentials.", EncryptionSettingsMismatch = 18: "Unable to connect to TypeDB: possible encryption settings mismatch.", diff --git a/rust/src/connection/network/channel.rs b/rust/src/connection/network/channel.rs index 1d972bfc9..3260bc5cb 100644 --- a/rust/src/connection/network/channel.rs +++ b/rust/src/connection/network/channel.rs @@ -17,7 +17,7 @@ * under the License. */ -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use tonic::{ body::BoxBody, @@ -29,7 +29,7 @@ use tonic::{ transport::{channel::ResponseFuture as ChannelResponseFuture, Channel, Error as TonicError}, Request, Status, }; - +use tonic::metadata::MetadataValue; use crate::{ common::{address::Address, Result, StdResult}, Credentials, DriverOptions, @@ -65,20 +65,30 @@ pub(super) fn open_callcred_channel( #[derive(Debug)] pub(super) struct CallCredentials { credentials: Credentials, + token: RwLock>, } impl CallCredentials { pub(super) fn new(credentials: Credentials) -> Self { - Self { credentials } + Self { credentials, token: RwLock::new(None) } + } + + pub(super) fn credentials(&self) -> &Credentials { + &self.credentials + } + + pub(super) fn set_token(&self, token: String) { + *self.token.write().expect("Expected token write lock acquisition on set") = Some(token); } - pub(super) fn username(&self) -> &str { - self.credentials.username() + pub(super) fn reset_token(&self) { + *self.token.write().expect("Expected token write lock acquisition on reset") = None; } pub(super) fn inject(&self, mut request: Request<()>) -> Request<()> { - request.metadata_mut().insert("username", self.credentials.username().try_into().unwrap()); - request.metadata_mut().insert("password", self.credentials.password().try_into().unwrap()); + if let Some(token) = &*self.token.read().expect("Expected token read lock acquisition on inject") { + request.metadata_mut().insert("authorization", format!("Bearer {}", token).try_into().expect("Expected authorization header formatting")); + } request } } diff --git a/rust/src/connection/network/stub.rs b/rust/src/connection/network/stub.rs index 38af93adf..6abe922dc 100644 --- a/rust/src/connection/network/stub.rs +++ b/rust/src/connection/network/stub.rs @@ -25,7 +25,7 @@ use tokio::sync::mpsc::{unbounded_channel as unbounded_async, UnboundedSender}; use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::{Response, Status, Streaming}; use typedb_protocol::{ - connection, database, database_manager, server_manager, transaction, type_db_client::TypeDbClient as GRPC, user, + authentication, connection, database, database_manager, server_manager, transaction, type_db_client::TypeDbClient as GRPC, user, user_manager, }; @@ -42,18 +42,45 @@ pub(super) struct RPCStub { impl RPCStub { pub(super) async fn new(channel: Channel, call_credentials: Option>) -> Self { - Self { grpc: GRPC::new(channel), call_credentials } + let mut this = Self { grpc: GRPC::new(channel), call_credentials }; + if let Err(err) = this.renew_token().await { + warn!("{err:?}"); + } + this } - async fn call(&mut self, call: F) -> Result + async fn call_with_auto_renew_token(&mut self, call: F) -> Result where - for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, Result>, + for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, Result>, { - call(self).await + match call(self).await { + Err(Error::Connection(ConnectionError::TokenCredentialInvalid)) => { + debug!("Request rejected because token credential was invalid. Renewing token and trying again..."); + self.renew_token().await?; + call(self).await + } + res => res, + } + } + + async fn renew_token(&mut self) -> Result { + if let Some(call_credentials) = &self.call_credentials { + trace!("Renewing token..."); + call_credentials.reset_token(); + let req = authentication::sign_in::Req { credentials: Some(authentication::sign_in::req::Credentials::Password(authentication::sign_in::req::Password { username: call_credentials.credentials().username().to_owned(), password: call_credentials.credentials().password().to_owned() })) }; + trace!("Sending token request..."); + let token = self.grpc.sign_in(req).await?.into_inner().token; + call_credentials.set_token(token); + trace!("Token renewed"); + } + Ok(()) } pub(super) async fn connection_open(&mut self, req: connection::open::Req) -> Result { - self.single(|this| Box::pin(this.grpc.connection_open(req.clone()))).await + let res = self.single(|this| Box::pin(this.grpc.connection_open(req.clone()))).await?; + trace!("Connection opened"); + self.renew_token().await?; + Ok(res) } pub(super) async fn servers_all(&mut self, req: server_manager::all::Req) -> Result { @@ -107,7 +134,7 @@ impl RPCStub { &mut self, open_req: transaction::Req, ) -> Result<(UnboundedSender, Streaming)> { - self.call(|this| { + self.call_with_auto_renew_token(|this| { let transaction_req = transaction::Client { reqs: vec![open_req.clone()] }; Box::pin(async { let (sender, receiver) = unbounded_async(); @@ -154,6 +181,6 @@ impl RPCStub { for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, TonicResult> + Send + Sync, R: 'static, { - self.call(|this| Box::pin(call(this).map(|r| Ok(r?.into_inner())))).await + self.call_with_auto_renew_token(|this| Box::pin(call(this).map(|r| Ok(r?.into_inner())))).await } } From 29c2a24636d66313cbd8b06145b41c79e907d38b Mon Sep 17 00:00:00 2001 From: Georgii Novoselov Date: Thu, 13 Mar 2025 17:06:41 +0000 Subject: [PATCH 2/4] Add token integration --- Cargo.lock | 2 +- dependencies/typedb/repositories.bzl | 3 +- rust/Cargo.toml | 2 +- rust/src/common/error.rs | 17 ++++++++-- rust/src/connection/message.rs | 20 ++++-------- rust/src/connection/network/proto/message.rs | 34 +++++++++++--------- rust/src/connection/network/stub.rs | 27 ++++++++-------- rust/src/connection/server_connection.rs | 7 ++-- 8 files changed, 60 insertions(+), 52 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index df94993d4..ea1bdc495 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2753,7 +2753,7 @@ dependencies = [ [[package]] name = "typedb-protocol" version = "0.0.0" -source = "git+https://github.com/typedb/typedb-protocol?rev=1ce9fc93c005b19d63fef835c1ee3a536ecf4c20#1ce9fc93c005b19d63fef835c1ee3a536ecf4c20" +source = "git+https://github.com/typedb/typedb-protocol?rev=b4b7ee87b08c16831a21629f81974347f38cce5c#b4b7ee87b08c16831a21629f81974347f38cce5c" dependencies = [ "prost", "tonic", diff --git a/dependencies/typedb/repositories.bzl b/dependencies/typedb/repositories.bzl index d48c4e589..7fe7aef96 100644 --- a/dependencies/typedb/repositories.bzl +++ b/dependencies/typedb/repositories.bzl @@ -25,10 +25,11 @@ def typedb_dependencies(): ) def typedb_protocol(): + # TODO: Temp, return typedb git_repository( name = "typedb_protocol", remote = "https://github.com/farost/typedb-protocol", - commit = "1ce9fc93c005b19d63fef835c1ee3a536ecf4c20", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_protocol + commit = "b4b7ee87b08c16831a21629f81974347f38cce5c", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_protocol ) # git_repository( # name = "typedb_protocol", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index b92a1f71e..3a183722f 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -60,7 +60,7 @@ [dependencies.typedb-protocol] features = [] - rev = "1ce9fc93c005b19d63fef835c1ee3a536ecf4c20" + rev = "b4b7ee87b08c16831a21629f81974347f38cce5c" git = "https://github.com/typedb/typedb-protocol" default-features = false diff --git a/rust/src/common/error.rs b/rust/src/common/error.rs index c18b21c46..c64a659d7 100644 --- a/rust/src/common/error.rs +++ b/rust/src/common/error.rs @@ -21,7 +21,7 @@ use std::{collections::HashSet, error::Error as StdError, fmt}; use itertools::Itertools; use tonic::{Code, Status}; -use tonic_types::StatusExt; +use tonic_types::{ErrorDetails, ErrorInfo, StatusExt}; use super::{address::Address, RequestID}; @@ -275,6 +275,16 @@ impl Error { } } + fn try_extracting_connection_error(status: &Status, code: &str) -> Option { + // TODO: We should probably catch more connection errors instead of wrapping them into + // ServerErrors. However, the most valuable information even for connection is inside + // stacktraces now. + match code { + "AUT3" => Some(ConnectionError::TokenCredentialInvalid {}), + _ => None, + } + } + fn from_message(message: &str) -> Self { // TODO: Consider converting some of the messages to connection errors Self::Other(message.to_owned()) @@ -352,9 +362,13 @@ impl From for Error { }) } else if let Some(error_info) = details.error_info() { let code = error_info.reason.clone(); + if let Some(connection_error) = Self::try_extracting_connection_error(&status, &code) { + return Self::Connection(connection_error); + } let domain = error_info.domain.clone(); let stack_trace = if let Some(debug_info) = details.debug_info() { debug_info.stack_entries.clone() } else { vec![] }; + Self::Server(ServerError::new(code, domain, status.message().to_owned(), stack_trace)) } else { Self::from_message(status.message()) @@ -364,7 +378,6 @@ impl From for Error { Self::parse_unavailable(status.message()) } else if status.code() == Code::Unknown || is_rst_stream(&status) - || status.code() == Code::InvalidArgument || status.code() == Code::FailedPrecondition || status.code() == Code::AlreadyExists { diff --git a/rust/src/connection/message.rs b/rust/src/connection/message.rs index 73f860b0c..17db6fb24 100644 --- a/rust/src/connection/message.rs +++ b/rust/src/connection/message.rs @@ -24,23 +24,15 @@ use tonic::Streaming; use typedb_protocol::transaction; use uuid::Uuid; -use crate::{ - answer::{ - concept_document::{ConceptDocumentHeader, Node}, - concept_row::ConceptRowHeader, - QueryType, - }, - common::{address::Address, info::DatabaseInfo, RequestID}, - concept::Concept, - error::ServerError, - info::UserInfo, - user::User, - Options, TransactionType, -}; +use crate::{answer::{ + concept_document::{ConceptDocumentHeader, Node}, + concept_row::ConceptRowHeader, + QueryType, +}, common::{address::Address, info::DatabaseInfo, RequestID}, concept::Concept, error::ServerError, info::UserInfo, user::User, Credentials, Options, TransactionType}; #[derive(Debug)] pub(super) enum Request { - ConnectionOpen { driver_lang: String, driver_version: String }, + ConnectionOpen { driver_lang: String, driver_version: String, credentials: Credentials }, ServersAll, diff --git a/rust/src/connection/network/proto/message.rs b/rust/src/connection/network/proto/message.rs index d03ff4023..8886141d7 100644 --- a/rust/src/connection/network/proto/message.rs +++ b/rust/src/connection/network/proto/message.rs @@ -18,27 +18,18 @@ */ use itertools::Itertools; -use typedb_protocol::{ - connection, database, database_manager, query::initial_res::Res, server_manager, transaction, user, user_manager, - Version::Version, -}; +use typedb_protocol::{authentication, connection, database, database_manager, query::initial_res::Res, server_manager, transaction, user, user_manager, Version::Version}; use uuid::Uuid; use super::{FromProto, IntoProto, TryFromProto, TryIntoProto}; -use crate::{ - answer::{concept_document::ConceptDocumentHeader, concept_row::ConceptRowHeader, QueryType}, - common::{info::DatabaseInfo, RequestID, Result}, - connection::message::{QueryRequest, QueryResponse, Request, Response, TransactionRequest, TransactionResponse}, - error::{ConnectionError, InternalError, ServerError}, - info::UserInfo, - user::User, -}; +use crate::{answer::{concept_document::ConceptDocumentHeader, concept_row::ConceptRowHeader, QueryType}, common::{info::DatabaseInfo, RequestID, Result}, connection::message::{QueryRequest, QueryResponse, Request, Response, TransactionRequest, TransactionResponse}, error::{ConnectionError, InternalError, ServerError}, info::UserInfo, user::User, Credentials}; impl TryIntoProto for Request { fn try_into_proto(self) -> Result { match self { - Self::ConnectionOpen { driver_lang, driver_version } => { - Ok(connection::open::Req { version: Version.into(), driver_lang, driver_version }) + Self::ConnectionOpen { driver_lang, driver_version, credentials } => { + Ok(connection::open::Req { version: Version.into(), driver_lang, driver_version, + authentication: Some(credentials.try_into_proto()?) }) } other => Err(InternalError::UnexpectedRequestType { request_type: format!("{other:?}") }.into()), } @@ -225,14 +216,25 @@ impl TryIntoProto for Request { } } +impl TryIntoProto for Credentials { + fn try_into_proto(self) -> Result { + Ok(authentication::sign_in::Req { + credentials: Some(authentication::sign_in::req::Credentials::Password(authentication::sign_in::req::Password { + username: self.username().to_owned(), + password: self.password().to_owned(), + })), + }) + } +} + impl TryFromProto for Response { fn try_from_proto(proto: connection::open::Res) -> Result { let mut database_infos = Vec::new(); - for database_info_proto in proto.databases_all.unwrap().databases { + for database_info_proto in proto.databases_all.expect("Expected databases data").databases { database_infos.push(DatabaseInfo::try_from_proto(database_info_proto)?); } Ok(Self::ConnectionOpen { - connection_id: Uuid::from_slice(proto.connection_id.unwrap().id.as_slice()).unwrap(), + connection_id: Uuid::from_slice(proto.connection_id.expect("Expected connection id").id.as_slice()).unwrap(), server_duration_millis: proto.server_duration_millis, databases: database_infos, }) diff --git a/rust/src/connection/network/stub.rs b/rust/src/connection/network/stub.rs index 6abe922dc..259c37142 100644 --- a/rust/src/connection/network/stub.rs +++ b/rust/src/connection/network/stub.rs @@ -31,6 +31,7 @@ use typedb_protocol::{ use super::channel::{CallCredentials, GRPCChannel}; use crate::common::{error::ConnectionError, Error, Result, StdResult}; +use crate::connection::network::proto::TryIntoProto; type TonicResult = StdResult, Status>; @@ -42,11 +43,7 @@ pub(super) struct RPCStub { impl RPCStub { pub(super) async fn new(channel: Channel, call_credentials: Option>) -> Self { - let mut this = Self { grpc: GRPC::new(channel), call_credentials }; - if let Err(err) = this.renew_token().await { - warn!("{err:?}"); - } - this + Self { grpc: GRPC::new(channel), call_credentials } } async fn call_with_auto_renew_token(&mut self, call: F) -> Result @@ -67,9 +64,8 @@ impl RPCStub { if let Some(call_credentials) = &self.call_credentials { trace!("Renewing token..."); call_credentials.reset_token(); - let req = authentication::sign_in::Req { credentials: Some(authentication::sign_in::req::Credentials::Password(authentication::sign_in::req::Password { username: call_credentials.credentials().username().to_owned(), password: call_credentials.credentials().password().to_owned() })) }; - trace!("Sending token request..."); - let token = self.grpc.sign_in(req).await?.into_inner().token; + let request = call_credentials.credentials().clone().try_into_proto()?; + let token = self.grpc.sign_in(request).await?.into_inner().token; call_credentials.set_token(token); trace!("Token renewed"); } @@ -77,10 +73,13 @@ impl RPCStub { } pub(super) async fn connection_open(&mut self, req: connection::open::Req) -> Result { - let res = self.single(|this| Box::pin(this.grpc.connection_open(req.clone()))).await?; - trace!("Connection opened"); - self.renew_token().await?; - Ok(res) + let result = self.single(|this| Box::pin(this.grpc.connection_open(req.clone()))).await; + if let Ok(response) = &result { + if let Some(call_credentials) = &self.call_credentials { + call_credentials.set_token(response.authentication.as_ref().expect("Expected authentication token").token.clone()); + } + } + result } pub(super) async fn servers_all(&mut self, req: server_manager::all::Req) -> Result { @@ -178,8 +177,8 @@ impl RPCStub { async fn single(&mut self, call: F) -> Result where - for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, TonicResult> + Send + Sync, - R: 'static, + for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, TonicResult> + Send + Sync, + R: 'static, { self.call_with_auto_renew_token(|this| Box::pin(call(this).map(|r| Ok(r?.into_inner())))).await } diff --git a/rust/src/connection/server_connection.rs b/rust/src/connection/server_connection.rs index 58c0ed211..774108b18 100644 --- a/rust/src/connection/server_connection.rs +++ b/rust/src/connection/server_connection.rs @@ -65,9 +65,9 @@ impl ServerConnection { ) -> crate::Result<(Self, Vec)> { let username = credentials.username().to_string(); let request_transmitter = - Arc::new(RPCTransmitter::start(address, credentials, driver_options, &background_runtime)?); + Arc::new(RPCTransmitter::start(address, credentials.clone(), driver_options, &background_runtime)?); let (connection_id, latency, database_info) = - Self::open_connection(&request_transmitter, driver_lang, driver_version).await?; + Self::open_connection(&request_transmitter, driver_lang, driver_version, credentials).await?; let latency_tracker = LatencyTracker::new(latency); let server_connection = Self { background_runtime, @@ -85,9 +85,10 @@ impl ServerConnection { request_transmitter: &RPCTransmitter, driver_lang: &str, driver_version: &str, + credentials: Credentials, ) -> crate::Result<(Uuid, Duration, Vec)> { let message = - Request::ConnectionOpen { driver_lang: driver_lang.to_owned(), driver_version: driver_version.to_owned() }; + Request::ConnectionOpen { driver_lang: driver_lang.to_owned(), driver_version: driver_version.to_owned(), credentials }; let request_time = Instant::now(); match request_transmitter.request(message).await? { From 6637234c96fe01bc3bb419cf1d11e9ae81f55f51 Mon Sep 17 00:00:00 2001 From: Georgii Novoselov Date: Thu, 13 Mar 2025 17:24:40 +0000 Subject: [PATCH 3/4] Rustfmt + TODOs --- rust/src/connection/message.rs | 18 +++++++--- rust/src/connection/network/channel.rs | 8 +++-- rust/src/connection/network/proto/message.rs | 38 ++++++++++++++------ rust/src/connection/network/stub.rs | 19 +++++----- rust/src/connection/server_connection.rs | 7 ++-- tool/test/start-community-server.sh | 1 + 6 files changed, 63 insertions(+), 28 deletions(-) diff --git a/rust/src/connection/message.rs b/rust/src/connection/message.rs index 17db6fb24..e9e080798 100644 --- a/rust/src/connection/message.rs +++ b/rust/src/connection/message.rs @@ -24,11 +24,19 @@ use tonic::Streaming; use typedb_protocol::transaction; use uuid::Uuid; -use crate::{answer::{ - concept_document::{ConceptDocumentHeader, Node}, - concept_row::ConceptRowHeader, - QueryType, -}, common::{address::Address, info::DatabaseInfo, RequestID}, concept::Concept, error::ServerError, info::UserInfo, user::User, Credentials, Options, TransactionType}; +use crate::{ + answer::{ + concept_document::{ConceptDocumentHeader, Node}, + concept_row::ConceptRowHeader, + QueryType, + }, + common::{address::Address, info::DatabaseInfo, RequestID}, + concept::Concept, + error::ServerError, + info::UserInfo, + user::User, + Credentials, Options, TransactionType, +}; #[derive(Debug)] pub(super) enum Request { diff --git a/rust/src/connection/network/channel.rs b/rust/src/connection/network/channel.rs index 3260bc5cb..a96ec9bdd 100644 --- a/rust/src/connection/network/channel.rs +++ b/rust/src/connection/network/channel.rs @@ -22,6 +22,7 @@ use std::sync::{Arc, RwLock}; use tonic::{ body::BoxBody, client::GrpcService, + metadata::MetadataValue, service::{ interceptor::{InterceptedService, ResponseFuture as InterceptorResponseFuture}, Interceptor, @@ -29,7 +30,7 @@ use tonic::{ transport::{channel::ResponseFuture as ChannelResponseFuture, Channel, Error as TonicError}, Request, Status, }; -use tonic::metadata::MetadataValue; + use crate::{ common::{address::Address, Result, StdResult}, Credentials, DriverOptions, @@ -87,7 +88,10 @@ impl CallCredentials { pub(super) fn inject(&self, mut request: Request<()>) -> Request<()> { if let Some(token) = &*self.token.read().expect("Expected token read lock acquisition on inject") { - request.metadata_mut().insert("authorization", format!("Bearer {}", token).try_into().expect("Expected authorization header formatting")); + request.metadata_mut().insert( + "authorization", + format!("Bearer {}", token).try_into().expect("Expected authorization header formatting"), + ); } request } diff --git a/rust/src/connection/network/proto/message.rs b/rust/src/connection/network/proto/message.rs index 8886141d7..900e96e82 100644 --- a/rust/src/connection/network/proto/message.rs +++ b/rust/src/connection/network/proto/message.rs @@ -18,19 +18,32 @@ */ use itertools::Itertools; -use typedb_protocol::{authentication, connection, database, database_manager, query::initial_res::Res, server_manager, transaction, user, user_manager, Version::Version}; +use typedb_protocol::{ + authentication, connection, database, database_manager, query::initial_res::Res, server_manager, transaction, user, + user_manager, Version::Version, +}; use uuid::Uuid; use super::{FromProto, IntoProto, TryFromProto, TryIntoProto}; -use crate::{answer::{concept_document::ConceptDocumentHeader, concept_row::ConceptRowHeader, QueryType}, common::{info::DatabaseInfo, RequestID, Result}, connection::message::{QueryRequest, QueryResponse, Request, Response, TransactionRequest, TransactionResponse}, error::{ConnectionError, InternalError, ServerError}, info::UserInfo, user::User, Credentials}; +use crate::{ + answer::{concept_document::ConceptDocumentHeader, concept_row::ConceptRowHeader, QueryType}, + common::{info::DatabaseInfo, RequestID, Result}, + connection::message::{QueryRequest, QueryResponse, Request, Response, TransactionRequest, TransactionResponse}, + error::{ConnectionError, InternalError, ServerError}, + info::UserInfo, + user::User, + Credentials, +}; impl TryIntoProto for Request { fn try_into_proto(self) -> Result { match self { - Self::ConnectionOpen { driver_lang, driver_version, credentials } => { - Ok(connection::open::Req { version: Version.into(), driver_lang, driver_version, - authentication: Some(credentials.try_into_proto()?) }) - } + Self::ConnectionOpen { driver_lang, driver_version, credentials } => Ok(connection::open::Req { + version: Version.into(), + driver_lang, + driver_version, + authentication: Some(credentials.try_into_proto()?), + }), other => Err(InternalError::UnexpectedRequestType { request_type: format!("{other:?}") }.into()), } } @@ -219,10 +232,12 @@ impl TryIntoProto for Request { impl TryIntoProto for Credentials { fn try_into_proto(self) -> Result { Ok(authentication::sign_in::Req { - credentials: Some(authentication::sign_in::req::Credentials::Password(authentication::sign_in::req::Password { - username: self.username().to_owned(), - password: self.password().to_owned(), - })), + credentials: Some(authentication::sign_in::req::Credentials::Password( + authentication::sign_in::req::Password { + username: self.username().to_owned(), + password: self.password().to_owned(), + }, + )), }) } } @@ -234,7 +249,8 @@ impl TryFromProto for Response { database_infos.push(DatabaseInfo::try_from_proto(database_info_proto)?); } Ok(Self::ConnectionOpen { - connection_id: Uuid::from_slice(proto.connection_id.expect("Expected connection id").id.as_slice()).unwrap(), + connection_id: Uuid::from_slice(proto.connection_id.expect("Expected connection id").id.as_slice()) + .unwrap(), server_duration_millis: proto.server_duration_millis, databases: database_infos, }) diff --git a/rust/src/connection/network/stub.rs b/rust/src/connection/network/stub.rs index 259c37142..0eb780470 100644 --- a/rust/src/connection/network/stub.rs +++ b/rust/src/connection/network/stub.rs @@ -25,13 +25,15 @@ use tokio::sync::mpsc::{unbounded_channel as unbounded_async, UnboundedSender}; use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::{Response, Status, Streaming}; use typedb_protocol::{ - authentication, connection, database, database_manager, server_manager, transaction, type_db_client::TypeDbClient as GRPC, user, - user_manager, + authentication, connection, database, database_manager, server_manager, transaction, + type_db_client::TypeDbClient as GRPC, user, user_manager, }; use super::channel::{CallCredentials, GRPCChannel}; -use crate::common::{error::ConnectionError, Error, Result, StdResult}; -use crate::connection::network::proto::TryIntoProto; +use crate::{ + common::{error::ConnectionError, Error, Result, StdResult}, + connection::network::proto::TryIntoProto, +}; type TonicResult = StdResult, Status>; @@ -48,7 +50,7 @@ impl RPCStub { async fn call_with_auto_renew_token(&mut self, call: F) -> Result where - for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, Result>, + for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, Result>, { match call(self).await { Err(Error::Connection(ConnectionError::TokenCredentialInvalid)) => { @@ -76,7 +78,8 @@ impl RPCStub { let result = self.single(|this| Box::pin(this.grpc.connection_open(req.clone()))).await; if let Ok(response) = &result { if let Some(call_credentials) = &self.call_credentials { - call_credentials.set_token(response.authentication.as_ref().expect("Expected authentication token").token.clone()); + call_credentials + .set_token(response.authentication.as_ref().expect("Expected authentication token").token.clone()); } } result @@ -177,8 +180,8 @@ impl RPCStub { async fn single(&mut self, call: F) -> Result where - for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, TonicResult> + Send + Sync, - R: 'static, + for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, TonicResult> + Send + Sync, + R: 'static, { self.call_with_auto_renew_token(|this| Box::pin(call(this).map(|r| Ok(r?.into_inner())))).await } diff --git a/rust/src/connection/server_connection.rs b/rust/src/connection/server_connection.rs index 774108b18..d92c25c3d 100644 --- a/rust/src/connection/server_connection.rs +++ b/rust/src/connection/server_connection.rs @@ -87,8 +87,11 @@ impl ServerConnection { driver_version: &str, credentials: Credentials, ) -> crate::Result<(Uuid, Duration, Vec)> { - let message = - Request::ConnectionOpen { driver_lang: driver_lang.to_owned(), driver_version: driver_version.to_owned(), credentials }; + let message = Request::ConnectionOpen { + driver_lang: driver_lang.to_owned(), + driver_version: driver_version.to_owned(), + credentials, + }; let request_time = Instant::now(); match request_transmitter.request(message).await? { diff --git a/tool/test/start-community-server.sh b/tool/test/start-community-server.sh index 7695fb53d..fe52d1596 100755 --- a/tool/test/start-community-server.sh +++ b/tool/test/start-community-server.sh @@ -22,6 +22,7 @@ rm -rf typedb-all bazel run //tool/test:typedb-extractor -- typedb-all BAZEL_JAVA_HOME=$(bazel run //tool/test:echo-java-home) +# TODO: Can add `--server.authentication.token_ttl_seconds 15` to test token auto-renewal in BDDs! ./typedb-all/typedb server --development-mode.enabled & set +e From 017a699611841c3ededc2f78c62ab42d25f8d3d9 Mon Sep 17 00:00:00 2001 From: Georgii Novoselov Date: Fri, 21 Mar 2025 13:59:06 +0000 Subject: [PATCH 4/4] Remove unnecessary code? --- rust/tests/behaviour/steps/params.rs | 46 ---------------------------- 1 file changed, 46 deletions(-) diff --git a/rust/tests/behaviour/steps/params.rs b/rust/tests/behaviour/steps/params.rs index e92d52107..b4dfdcc75 100644 --- a/rust/tests/behaviour/steps/params.rs +++ b/rust/tests/behaviour/steps/params.rs @@ -27,52 +27,6 @@ use typedb_driver::{ TransactionType as TypeDBTransactionType, }; -#[derive(Debug, Parameter)] -#[param(name = "containment", regex = r"(?:do not )?contain")] -pub struct ContainmentParam(bool); - -impl ContainmentParam { - pub fn assert(&self, actuals: &[T], item: U) - where - T: Comparable + fmt::Debug, - U: PartialEq + fmt::Debug, - { - if self.0 { - assert!(actuals.iter().any(|actual| actual.equals(&item)), "{item:?} not found in {actuals:?}") - } else { - assert!(actuals.iter().all(|actual| !actual.equals(&item)), "{item:?} found in {actuals:?}") - } - } -} - -impl FromStr for ContainmentParam { - type Err = Infallible; - - fn from_str(s: &str) -> Result { - Ok(Self(s == "contain")) - } -} - -pub trait Comparable { - fn equals(&self, item: &U) -> bool; -} - -impl, U: PartialEq + ?Sized> Comparable<&U> for T { - fn equals(&self, item: &&U) -> bool { - self.borrow() == *item - } -} - -impl<'a, T1, T2, U1, U2> Comparable<(&'a U1, &'a U2)> for (T1, T2) -where - T1: Comparable<&'a U1>, - T2: Comparable<&'a U2>, -{ - fn equals(&self, (first, second): &(&'a U1, &'a U2)) -> bool { - self.0.equals(first) && self.1.equals(second) - } -} - #[derive(Debug, Default, Parameter, Clone)] #[param(name = "value", regex = ".*?")] pub(crate) struct Value {