Skip to content

Commit

Permalink
mk-sql: improve registry detection logic
Browse files Browse the repository at this point in the history
Improved logic to determine whether connection is local.
Added new testing logic based on Windows registry dumps.
Tested on the vagrant machine.

CMK-17054

Change-Id: If5c3ae1549365776260545364e01222e2ccb4125
  • Loading branch information
s-kipnis committed Feb 12, 2025
1 parent a86d8cc commit 8a49b6e
Show file tree
Hide file tree
Showing 4 changed files with 328 additions and 116 deletions.
76 changes: 55 additions & 21 deletions packages/host/mk-sql/src/config/ms_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ impl Config {
}

pub fn get_registry_instance_info(name: &InstanceName) -> Option<InstanceInfo> {
let all = get_instances();
let all = get_instances(None);
let a = all.iter().find(|i| &i.name == name);
a.cloned()
}
Expand All @@ -260,12 +260,8 @@ fn get_additional_registry_instances(
auth: &Authentication,
conn: &Connection,
) -> Vec<CustomInstance> {
let work_host = calc_real_host(auth, conn).to_string().to_lowercase();
if work_host != "localhost" {
log::info!(
"skipping registry instances: the reason the host `{} `is not localhost",
work_host
);
if !is_local_endpoint(auth, conn) {
log::info!("skipping registry instances: the host is not enough localhost");
return vec![];
}

Expand All @@ -274,27 +270,34 @@ fn get_additional_registry_instances(
.map(|i| i.name().to_string().to_lowercase().clone())
.collect();
log::info!("localhost is defined, adding registry instances");
platform::registry::get_instances()
platform::registry::get_instances(None)
.into_iter()
.filter_map(|i| {
if names.contains(&i.name.to_string().to_lowercase()) {
.filter_map(|registry_instance_info| {
if names.contains(&registry_instance_info.name.to_string().to_lowercase()) {
log::info!(
"{} is ignored as already defined in custom instances",
i.name
registry_instance_info.name
);
return None;
}

Some(CustomInstance::from_registry(
&i.name,
&registry_instance_info.name,
auth,
conn,
&i.final_port(),
&registry_instance_info.final_host(),
&registry_instance_info.final_port(),
))
})
.collect::<Vec<CustomInstance>>()
}

pub fn is_local_endpoint(auth: &Authentication, conn: &Connection) -> bool {
auth.auth_type() == &AuthType::Integrated
|| conn.hostname() == HostName::from("localhost".to_owned())
|| conn.hostname() == HostName::from("127.0.0.1".to_owned())
}

#[derive(PartialEq, Debug, Clone)]
pub struct Authentication {
username: String,
Expand Down Expand Up @@ -696,12 +699,14 @@ impl CustomInstance {
name: &InstanceName,
main_auth: &Authentication,
main_conn: &Connection,
port: &Option<&Port>,
hostname: &Option<HostName>,
port: &Option<Port>,
) -> Self {
let (auth, conn) = CustomInstance::make_registry_auth_and_conn(
main_auth,
main_conn,
port.unwrap_or(&Port::from(0)),
hostname,
port.as_ref().unwrap_or(&Port::from(0)),
);
Self {
name: name.clone(),
Expand Down Expand Up @@ -746,9 +751,13 @@ impl CustomInstance {
fn make_registry_auth_and_conn(
main_auth: &Authentication,
main_conn: &Connection,
hostname: &Option<HostName>,
port: &Port,
) -> (Authentication, Connection) {
let conn = Connection {
hostname: hostname
.clone()
.unwrap_or_else(|| main_conn.hostname().clone()),
port: port.clone(),
..main_conn.clone()
};
Expand Down Expand Up @@ -780,19 +789,15 @@ impl CustomInstance {
}

pub fn calc_real_host(auth: &Authentication, conn: &Connection) -> HostName {
if is_local_host(auth, conn) {
if is_local_endpoint(auth, conn) {
"localhost".to_string().into()
} else {
conn.hostname().clone()
}
}

pub fn is_local_host(auth: &Authentication, _conn: &Connection) -> bool {
auth.auth_type() == &AuthType::Integrated
}

pub fn is_use_tcp(name: &InstanceName, auth: &Authentication, conn: &Connection) -> bool {
if is_local_host(auth, conn) {
if is_local_endpoint(auth, conn) {
get_registry_instance_info(name)
.map(|i| i.is_tcp())
.unwrap_or(true)
Expand Down Expand Up @@ -1709,4 +1714,33 @@ mssql:
let c = Connection::default();
assert!(is_use_tcp(&"MSSQLSERVER".to_string().into(), &a, &c));
}

#[test]
fn test_is_local_endpoint() {
let auth_integrated = Authentication {
auth_type: AuthType::Integrated,
..Default::default()
};
let auth_sql = Authentication {
auth_type: AuthType::SqlServer,
..Default::default()
};
let conn_non_local = Connection {
hostname: HostName::from("localhost.com".to_string()),
..Default::default()
};
let conn_local = Connection {
hostname: HostName::from("localhost".to_string()),
..Default::default()
};
let conn_127 = Connection {
hostname: HostName::from("127.0.0.1".to_string()),
..Default::default()
};
assert!(is_local_endpoint(&auth_integrated, &conn_local));
assert!(is_local_endpoint(&auth_integrated, &conn_non_local));
assert!(is_local_endpoint(&auth_sql, &conn_local));
assert!(is_local_endpoint(&auth_sql, &conn_127));
assert!(!is_local_endpoint(&auth_sql, &conn_non_local));
}
}
18 changes: 14 additions & 4 deletions packages/host/mk-sql/src/ms_sql/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use super::client::{self, UniClient};
use super::custom::get_sql_dir;
use super::section::{Section, SectionKind};
use crate::config::defines::defaults::MAX_CONNECTIONS;
use crate::config::ms_sql::{is_local_host, is_use_tcp, Discovery};
use crate::config::ms_sql::{is_local_endpoint, is_use_tcp, Discovery};
use crate::config::section;
use crate::config::{
self,
Expand Down Expand Up @@ -2164,7 +2164,13 @@ fn print_builders(title: &str, builders: &[SqlInstanceBuilder]) {
builders.len(),
builders
.iter()
.map(|i| format!("{}", i.get_name()))
.map(|i| format!(
"{} {}",
i.get_name(),
i.get_endpoint()
.map(|e| e.dump_compact())
.unwrap_or_else(|| "None".to_string())
))
.collect::<Vec<_>>()
.join(", ")
);
Expand All @@ -2177,7 +2183,7 @@ async fn get_custom_instance_builder(
let instance_name = &builder.get_name();
let auth = endpoint.auth();
let conn = endpoint.conn();
if is_local_host(auth, conn) && !is_use_tcp(instance_name, auth, conn) {
if is_local_endpoint(auth, conn) && !is_use_tcp(instance_name, auth, conn) {
if let Ok(mut client) = create_odbc_client(instance_name, None) {
log::debug!("Trying to connect to `{instance_name}` using ODBC");
let b = obtain_properties(&mut client, instance_name)
Expand Down Expand Up @@ -2342,7 +2348,11 @@ fn determine_reconnect(
Some(customization)
if Some(&customization.endpoint()) != instance_builder.get_endpoint() =>
{
log::info!("Instance {} to be reconnected", instance_builder.get_name(),);
log::info!(
"Instance {} to be reconnected with {}",
instance_builder.get_name(),
customization.endpoint().dump_compact()
);
(instance_builder, Some(customization.endpoint()))
}
_ => {
Expand Down
Loading

0 comments on commit 8a49b6e

Please sign in to comment.