Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions src/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use api::v1::prometheus_gateway_client::PrometheusGatewayClient;
use api::v1::region::region_client::RegionClient as PbRegionClient;
use arrow_flight::flight_service_client::FlightServiceClient;
use common_grpc::channel_manager::{
ChannelConfig, ChannelManager, ClientTlsOption, load_tls_config,
ChannelConfig, ChannelManager, ClientTlsOption, load_client_tls_config,
};
use parking_lot::RwLock;
use snafu::{OptionExt, ResultExt};
Expand Down Expand Up @@ -95,9 +95,9 @@ impl Client {
U: AsRef<str>,
A: AsRef<[U]>,
{
let channel_config = ChannelConfig::default().client_tls_config(client_tls);
let tls_config = load_tls_config(channel_config.client_tls.as_ref())
.context(error::CreateTlsChannelSnafu)?;
let channel_config = ChannelConfig::default().client_tls_config(client_tls.clone());
let tls_config =
load_client_tls_config(Some(client_tls)).context(error::CreateTlsChannelSnafu)?;
let channel_manager = ChannelManager::with_config(channel_config, tls_config);
Ok(Self::with_manager_and_urls(channel_manager, urls))
}
Expand Down
2 changes: 2 additions & 0 deletions src/common/grpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ datatypes.workspace = true
flatbuffers = "25.2"
hyper.workspace = true
lazy_static.workspace = true
notify.workspace = true
prost.workspace = true
serde.workspace = true
serde_json.workspace = true
Expand All @@ -37,6 +38,7 @@ vec1 = "1.12"
criterion = "0.4"
hyper-util = { workspace = true, features = ["tokio"] }
rand.workspace = true
tempfile.workspace = true

[[bench]]
name = "bench_main"
Expand Down
103 changes: 90 additions & 13 deletions src/common/grpc/src/channel_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::time::Duration;
Expand All @@ -30,6 +31,7 @@ use tonic::transport::{
use tower::Service;

use crate::error::{CreateChannelSnafu, InvalidConfigFilePathSnafu, Result};
use crate::reloadable_tls::{ReloadableTlsConfig, TlsConfigLoader, maybe_watch_tls_config};

const RECYCLE_CHANNEL_INTERVAL_SECS: u64 = 60;
pub const DEFAULT_GRPC_REQUEST_TIMEOUT_SECS: u64 = 10;
Expand All @@ -50,7 +52,7 @@ pub struct ChannelManager {
struct Inner {
id: u64,
config: ChannelConfig,
client_tls_config: Option<ClientTlsConfig>,
reloadable_client_tls_config: Option<Arc<ReloadableClientTlsConfig>>,
pool: Arc<Pool>,
channel_recycle_started: AtomicBool,
cancel: CancellationToken,
Expand Down Expand Up @@ -78,7 +80,7 @@ impl Inner {
Self {
id,
config,
client_tls_config: None,
reloadable_client_tls_config: None,
pool,
channel_recycle_started: AtomicBool::new(false),
cancel,
Expand All @@ -91,13 +93,17 @@ impl ChannelManager {
Default::default()
}

/// unified with config function that support tls config
/// use [`load_tls_config`] to load tls config from file system
pub fn with_config(config: ChannelConfig, tls_config: Option<ClientTlsConfig>) -> Self {
/// Create a ChannelManager with configuration and optional TLS config
///
/// Use [`load_client_tls_config`] to create TLS configuration from `ClientTlsOption`.
/// The TLS config supports both static (watch disabled) and dynamic reloading (watch enabled).
/// If you want to use dynamic reloading, please **manually** invoke [`maybe_watch_client_tls_config`] after this method.
pub fn with_config(
config: ChannelConfig,
reloadable_tls_config: Option<Arc<ReloadableClientTlsConfig>>,
) -> Self {
let mut inner = Inner::with_config(config.clone());
if let Some(tls_config) = tls_config {
inner.client_tls_config = Some(tls_config);
}
inner.reloadable_client_tls_config = reloadable_tls_config;
Self {
inner: Arc::new(inner),
}
Expand Down Expand Up @@ -172,8 +178,21 @@ impl ChannelManager {
self.pool().retain_channel(f);
}

/// Clear all channels to force reconnection.
/// This should be called when TLS configuration changes to ensure new connections use updated certificates.
pub fn clear_all_channels(&self) {
self.pool().retain_channel(|_, _| false);
}

fn build_endpoint(&self, addr: &str) -> Result<Endpoint> {
let http_prefix = if self.inner.client_tls_config.is_some() {
// Get the latest TLS config from reloadable config (which handles both static and dynamic cases)
let tls_config = self
.inner
.reloadable_client_tls_config
.as_ref()
.and_then(|c| c.get_config());

let http_prefix = if tls_config.is_some() {
"https"
} else {
"http"
Expand Down Expand Up @@ -212,9 +231,9 @@ impl ChannelManager {
if let Some(enabled) = self.config().http2_adaptive_window {
endpoint = endpoint.http2_adaptive_window(enabled);
}
if let Some(tls_config) = &self.inner.client_tls_config {
if let Some(tls_config) = tls_config {
endpoint = endpoint
.tls_config(tls_config.clone())
.tls_config(tls_config)
.context(CreateChannelSnafu { addr })?;
}

Expand Down Expand Up @@ -248,7 +267,7 @@ impl ChannelManager {
}
}

pub fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result<Option<ClientTlsConfig>> {
fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result<Option<ClientTlsConfig>> {
let path_config = match tls_option {
Some(path_config) if path_config.enabled => path_config,
_ => return Ok(None),
Expand Down Expand Up @@ -276,13 +295,69 @@ pub fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result<Option<Cl
Ok(Some(tls_config))
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
impl TlsConfigLoader<ClientTlsConfig> for ClientTlsOption {
type Error = crate::error::Error;

fn load(&self) -> Result<Option<ClientTlsConfig>> {
load_tls_config(Some(self))
}

fn watch_paths(&self) -> Vec<&Path> {
let mut paths = Vec::new();
if let Some(cert_path) = &self.client_cert_path {
paths.push(Path::new(cert_path.as_str()));
}
if let Some(key_path) = &self.client_key_path {
paths.push(Path::new(key_path.as_str()));
}
if let Some(ca_path) = &self.server_ca_cert_path {
paths.push(Path::new(ca_path.as_str()));
}
paths
}

fn watch_enabled(&self) -> bool {
self.enabled && self.watch
}
}

/// Type alias for client-side reloadable TLS config
pub type ReloadableClientTlsConfig = ReloadableTlsConfig<ClientTlsConfig, ClientTlsOption>;

/// Load client TLS configuration from `ClientTlsOption` and return a `ReloadableClientTlsConfig`.
/// This is the primary way to create TLS configuration for the ChannelManager.
pub fn load_client_tls_config(
tls_option: Option<ClientTlsOption>,
) -> Result<Option<Arc<ReloadableClientTlsConfig>>> {
match tls_option {
Some(option) if option.enabled => {
let reloadable = ReloadableClientTlsConfig::try_new(option)?;
Ok(Some(Arc::new(reloadable)))
}
_ => Ok(None),
}
}

pub fn maybe_watch_client_tls_config(
client_tls_config: Arc<ReloadableClientTlsConfig>,
channel_manager: ChannelManager,
) -> Result<()> {
maybe_watch_tls_config(client_tls_config, move || {
// Clear all existing channels to force reconnection with new certificates
channel_manager.clear_all_channels();
info!("Cleared all existing channels to use new TLS certificates.");
})
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct ClientTlsOption {
/// Whether to enable TLS for client.
pub enabled: bool,
pub server_ca_cert_path: Option<String>,
pub client_cert_path: Option<String>,
pub client_key_path: Option<String>,
#[serde(default)]
pub watch: bool,
}

#[derive(Clone, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -602,6 +677,7 @@ mod tests {
server_ca_cert_path: Some("some_server_path".to_string()),
client_cert_path: Some("some_cert_path".to_string()),
client_key_path: Some("some_key_path".to_string()),
watch: false,
});

assert_eq!(
Expand All @@ -623,6 +699,7 @@ mod tests {
server_ca_cert_path: Some("some_server_path".to_string()),
client_cert_path: Some("some_cert_path".to_string()),
client_key_path: Some("some_key_path".to_string()),
watch: false,
}),
max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
Expand Down
10 changes: 10 additions & 0 deletions src/common/grpc/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ pub enum Error {
location: Location,
},

#[snafu(display("Failed to watch config file path: {}", path))]
FileWatch {
path: String,
#[snafu(source)]
error: notify::Error,
#[snafu(implicit)]
location: Location,
},

#[snafu(display(
"Write type mismatch, column name: {}, expected: {}, actual: {}",
column_name,
Expand Down Expand Up @@ -108,6 +117,7 @@ impl ErrorExt for Error {
match self {
Error::InvalidTlsConfig { .. }
| Error::InvalidConfigFilePath { .. }
| Error::FileWatch { .. }
| Error::TypeMismatch { .. }
| Error::InvalidFlightData { .. }
| Error::NotSupported { .. } => StatusCode::InvalidArguments,
Expand Down
1 change: 1 addition & 0 deletions src/common/grpc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub mod channel_manager;
pub mod error;
pub mod flight;
pub mod precision;
pub mod reloadable_tls;
pub mod select;

pub use arrow_flight::FlightData;
Expand Down
Loading