Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Substitute credentials authentication by token authentication with auto-renewal #743

Draft
wants to merge 3 commits into
base: 3.0
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 8 additions & 2 deletions dependencies/typedb/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,17 @@ def typedb_dependencies():
)

def typedb_protocol():
# TODO: Temp, return typedb
git_repository(
name = "typedb_protocol",
remote = "https://github.com/typedb/typedb-protocol",
tag = "3.0.0", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_protocol
remote = "https://github.com/farost/typedb-protocol",
commit = "b4b7ee87b08c16831a21629f81974347f38cce5c", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_protocol
)
# git_repository(
# name = "typedb_protocol",
# remote = "https://github.com/typedb/typedb-protocol",
# tag = "3.0.0", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_protocol
# )

def typedb_behaviour():
git_repository(
Expand Down
2 changes: 1 addition & 1 deletion rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@

[dependencies.typedb-protocol]
features = []
rev = "b4b7ee87b08c16831a21629f81974347f38cce5c"
git = "https://github.com/typedb/typedb-protocol"
tag = "3.0.0"
default-features = false

[dependencies.log]
Expand Down
19 changes: 16 additions & 3 deletions rust/src/common/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::{collections::HashSet, error::Error as StdError, fmt};

use itertools::Itertools;
use tonic::{Code, Status};
use tonic_types::StatusExt;
use tonic_types::{ErrorDetails, ErrorInfo, StatusExt};

use super::{address::Address, RequestID};

Expand Down Expand Up @@ -150,7 +150,7 @@ error_messages! { ConnectionError
15: "The replica is not the primary replica.",
ClusterAllNodesFailed { errors: String } =
16: "Attempted connecting to all TypeDB Cluster servers, but the following errors occurred: \n{errors}.",
ClusterTokenCredentialInvalid =
TokenCredentialInvalid =
17: "Invalid token credentials.",
EncryptionSettingsMismatch =
18: "Unable to connect to TypeDB: possible encryption settings mismatch.",
Expand Down Expand Up @@ -275,6 +275,16 @@ impl Error {
}
}

fn try_extracting_connection_error(status: &Status, code: &str) -> Option<ConnectionError> {
// TODO: We should probably catch more connection errors instead of wrapping them into
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really want to refactor error messaging here.

// ServerErrors. However, the most valuable information even for connection is inside
// stacktraces now.
match code {
"AUT3" => Some(ConnectionError::TokenCredentialInvalid {}),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ouu ok make a comment in the server-side that we depend on those error messages client-side and not to change them randomly -- if we don't already have that warning!

Copy link
Member Author

@farost farost Mar 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We did a similar thing (even for a bigger number of errors) in 2.x in other domains. But I will. And the BDDs with server running with shorter tokens will show if we don't actually renew them.

_ => None,
}
}

fn from_message(message: &str) -> Self {
// TODO: Consider converting some of the messages to connection errors
Self::Other(message.to_owned())
Expand Down Expand Up @@ -352,9 +362,13 @@ impl From<Status> for Error {
})
} else if let Some(error_info) = details.error_info() {
let code = error_info.reason.clone();
if let Some(connection_error) = Self::try_extracting_connection_error(&status, &code) {
return Self::Connection(connection_error);
}
let domain = error_info.domain.clone();
let stack_trace =
if let Some(debug_info) = details.debug_info() { debug_info.stack_entries.clone() } else { vec![] };

Self::Server(ServerError::new(code, domain, status.message().to_owned(), stack_trace))
} else {
Self::from_message(status.message())
Expand All @@ -364,7 +378,6 @@ impl From<Status> for Error {
Self::parse_unavailable(status.message())
} else if status.code() == Code::Unknown
|| is_rst_stream(&status)
|| status.code() == Code::InvalidArgument
|| status.code() == Code::FailedPrecondition
|| status.code() == Code::AlreadyExists
{
Expand Down
4 changes: 2 additions & 2 deletions rust/src/connection/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ use crate::{
error::ServerError,
info::UserInfo,
user::User,
Options, TransactionType,
Credentials, Options, TransactionType,
};

#[derive(Debug)]
pub(super) enum Request {
ConnectionOpen { driver_lang: String, driver_version: String },
ConnectionOpen { driver_lang: String, driver_version: String, credentials: Credentials },

ServersAll,

Expand Down
26 changes: 20 additions & 6 deletions rust/src/connection/network/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
* under the License.
*/

use std::sync::Arc;
use std::sync::{Arc, RwLock};

use tonic::{
body::BoxBody,
client::GrpcService,
metadata::MetadataValue,
service::{
interceptor::{InterceptedService, ResponseFuture as InterceptorResponseFuture},
Interceptor,
Expand Down Expand Up @@ -65,20 +66,33 @@ pub(super) fn open_callcred_channel(
#[derive(Debug)]
pub(super) struct CallCredentials {
credentials: Credentials,
token: RwLock<Option<String>>,
}

impl CallCredentials {
pub(super) fn new(credentials: Credentials) -> Self {
Self { credentials }
Self { credentials, token: RwLock::new(None) }
}

pub(super) fn username(&self) -> &str {
self.credentials.username()
pub(super) fn credentials(&self) -> &Credentials {
&self.credentials
}

pub(super) fn set_token(&self, token: String) {
*self.token.write().expect("Expected token write lock acquisition on set") = Some(token);
}

pub(super) fn reset_token(&self) {
*self.token.write().expect("Expected token write lock acquisition on reset") = None;
}

pub(super) fn inject(&self, mut request: Request<()>) -> Request<()> {
request.metadata_mut().insert("username", self.credentials.username().try_into().unwrap());
request.metadata_mut().insert("password", self.credentials.password().try_into().unwrap());
if let Some(token) = &*self.token.read().expect("Expected token read lock acquisition on inject") {
request.metadata_mut().insert(
"authorization",
format!("Bearer {}", token).try_into().expect("Expected authorization header formatting"),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's more a question about the server, but I decided to follow the standard HTTP format of authorization: Bearer <TOKEN> metadata records. It requires manual parsing in tonic and is done better in axum (for http), but I haven't found any explicit recommendation to turn away from the HTTP standard in gRPC and drop the Bearer part. Although we used to just set token: ... in 2.x.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a strong knowledge/opinion on this, maybe @lolski has some?

);
}
request
}
}
Expand Down
32 changes: 25 additions & 7 deletions rust/src/connection/network/proto/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

use itertools::Itertools;
use typedb_protocol::{
connection, database, database_manager, query::initial_res::Res, server_manager, transaction, user, user_manager,
Version::Version,
authentication, connection, database, database_manager, query::initial_res::Res, server_manager, transaction, user,
user_manager, Version::Version,
};
use uuid::Uuid;

Expand All @@ -32,14 +32,18 @@ use crate::{
error::{ConnectionError, InternalError, ServerError},
info::UserInfo,
user::User,
Credentials,
};

impl TryIntoProto<connection::open::Req> for Request {
fn try_into_proto(self) -> Result<connection::open::Req> {
match self {
Self::ConnectionOpen { driver_lang, driver_version } => {
Ok(connection::open::Req { version: Version.into(), driver_lang, driver_version })
}
Self::ConnectionOpen { driver_lang, driver_version, credentials } => Ok(connection::open::Req {
version: Version.into(),
driver_lang,
driver_version,
authentication: Some(credentials.try_into_proto()?),
}),
other => Err(InternalError::UnexpectedRequestType { request_type: format!("{other:?}") }.into()),
}
}
Expand Down Expand Up @@ -225,14 +229,28 @@ impl TryIntoProto<user::delete::Req> for Request {
}
}

impl TryIntoProto<authentication::sign_in::Req> for Credentials {
fn try_into_proto(self) -> Result<authentication::sign_in::Req> {
Ok(authentication::sign_in::Req {
credentials: Some(authentication::sign_in::req::Credentials::Password(
authentication::sign_in::req::Password {
username: self.username().to_owned(),
password: self.password().to_owned(),
},
)),
})
}
}

impl TryFromProto<connection::open::Res> for Response {
fn try_from_proto(proto: connection::open::Res) -> Result<Self> {
let mut database_infos = Vec::new();
for database_info_proto in proto.databases_all.unwrap().databases {
for database_info_proto in proto.databases_all.expect("Expected databases data").databases {
database_infos.push(DatabaseInfo::try_from_proto(database_info_proto)?);
}
Ok(Self::ConnectionOpen {
connection_id: Uuid::from_slice(proto.connection_id.unwrap().id.as_slice()).unwrap(),
connection_id: Uuid::from_slice(proto.connection_id.expect("Expected connection id").id.as_slice())
.unwrap(),
server_duration_millis: proto.server_duration_millis,
databases: database_infos,
})
Expand Down
45 changes: 37 additions & 8 deletions rust/src/connection/network/stub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ use tokio::sync::mpsc::{unbounded_channel as unbounded_async, UnboundedSender};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::{Response, Status, Streaming};
use typedb_protocol::{
connection, database, database_manager, server_manager, transaction, type_db_client::TypeDbClient as GRPC, user,
user_manager,
authentication, connection, database, database_manager, server_manager, transaction,
type_db_client::TypeDbClient as GRPC, user, user_manager,
};

use super::channel::{CallCredentials, GRPCChannel};
use crate::common::{error::ConnectionError, Error, Result, StdResult};
use crate::{
common::{error::ConnectionError, Error, Result, StdResult},
connection::network::proto::TryIntoProto,
};

type TonicResult<T> = StdResult<Response<T>, Status>;

Expand All @@ -45,15 +48,41 @@ impl<Channel: GRPCChannel> RPCStub<Channel> {
Self { grpc: GRPC::new(channel), call_credentials }
}

async fn call<F, R>(&mut self, call: F) -> Result<R>
async fn call_with_auto_renew_token<F, R>(&mut self, call: F) -> Result<R>
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is actually a copypaste from 2.x.

where
for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, Result<R>>,
{
call(self).await
match call(self).await {
Err(Error::Connection(ConnectionError::TokenCredentialInvalid)) => {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it failed specifically, renew the token and retry.

debug!("Request rejected because token credential was invalid. Renewing token and trying again...");
self.renew_token().await?;
call(self).await
}
res => res,
}
}

async fn renew_token(&mut self) -> Result {
if let Some(call_credentials) = &self.call_credentials {
trace!("Renewing token...");
call_credentials.reset_token();
let request = call_credentials.credentials().clone().try_into_proto()?;
let token = self.grpc.sign_in(request).await?.into_inner().token;
call_credentials.set_token(token);
trace!("Token renewed");
}
Ok(())
}

pub(super) async fn connection_open(&mut self, req: connection::open::Req) -> Result<connection::open::Res> {
self.single(|this| Box::pin(this.grpc.connection_open(req.clone()))).await
let result = self.single(|this| Box::pin(this.grpc.connection_open(req.clone()))).await;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sneak into the response to get the token on the network level without asking the more business-like logic to work with it for us.

if let Ok(response) = &result {
if let Some(call_credentials) = &self.call_credentials {
call_credentials
.set_token(response.authentication.as_ref().expect("Expected authentication token").token.clone());
}
}
result
}

pub(super) async fn servers_all(&mut self, req: server_manager::all::Req) -> Result<server_manager::all::Res> {
Expand Down Expand Up @@ -107,7 +136,7 @@ impl<Channel: GRPCChannel> RPCStub<Channel> {
&mut self,
open_req: transaction::Req,
) -> Result<(UnboundedSender<transaction::Client>, Streaming<transaction::Server>)> {
self.call(|this| {
self.call_with_auto_renew_token(|this| {
let transaction_req = transaction::Client { reqs: vec![open_req.clone()] };
Box::pin(async {
let (sender, receiver) = unbounded_async();
Expand Down Expand Up @@ -154,6 +183,6 @@ impl<Channel: GRPCChannel> RPCStub<Channel> {
for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, TonicResult<R>> + Send + Sync,
R: 'static,
{
self.call(|this| Box::pin(call(this).map(|r| Ok(r?.into_inner())))).await
self.call_with_auto_renew_token(|this| Box::pin(call(this).map(|r| Ok(r?.into_inner())))).await
}
}
12 changes: 8 additions & 4 deletions rust/src/connection/server_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ impl ServerConnection {
) -> crate::Result<(Self, Vec<DatabaseInfo>)> {
let username = credentials.username().to_string();
let request_transmitter =
Arc::new(RPCTransmitter::start(address, credentials, driver_options, &background_runtime)?);
Arc::new(RPCTransmitter::start(address, credentials.clone(), driver_options, &background_runtime)?);
let (connection_id, latency, database_info) =
Self::open_connection(&request_transmitter, driver_lang, driver_version).await?;
Self::open_connection(&request_transmitter, driver_lang, driver_version, credentials).await?;
let latency_tracker = LatencyTracker::new(latency);
let server_connection = Self {
background_runtime,
Expand All @@ -85,9 +85,13 @@ impl ServerConnection {
request_transmitter: &RPCTransmitter,
driver_lang: &str,
driver_version: &str,
credentials: Credentials,
) -> crate::Result<(Uuid, Duration, Vec<DatabaseInfo>)> {
let message =
Request::ConnectionOpen { driver_lang: driver_lang.to_owned(), driver_version: driver_version.to_owned() };
let message = Request::ConnectionOpen {
driver_lang: driver_lang.to_owned(),
driver_version: driver_version.to_owned(),
credentials,
};

let request_time = Instant::now();
match request_transmitter.request(message).await? {
Expand Down
1 change: 1 addition & 0 deletions tool/test/start-community-server.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ rm -rf typedb-all

bazel run //tool/test:typedb-extractor -- typedb-all
BAZEL_JAVA_HOME=$(bazel run //tool/test:echo-java-home)
# TODO: Can add `--server.authentication.token_ttl_seconds 15` to test token auto-renewal in BDDs!
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have the artifacts (I just tested locally), so this feature can be good. We have small integration tests, some small behaviour tests, and some big behaviour tests. Everything will be covered!

./typedb-all/typedb server --development-mode.enabled &

set +e
Expand Down