From 4437a2ac0b46b700a65fb8ad946c97b8f212c52e Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Fri, 8 Aug 2025 12:58:53 +0530 Subject: [PATCH 01/25] Refactor codebase to use a unified http client Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/auth.py | 4 +- src/databricks/sql/auth/authenticators.py | 2 + src/databricks/sql/auth/common.py | 61 +++-- src/databricks/sql/auth/oauth.py | 28 ++- src/databricks/sql/backend/sea/queue.py | 4 + src/databricks/sql/backend/sea/result_set.py | 1 + src/databricks/sql/client.py | 38 ++- .../sql/cloudfetch/download_manager.py | 3 + src/databricks/sql/cloudfetch/downloader.py | 79 +++--- src/databricks/sql/common/feature_flag.py | 16 +- src/databricks/sql/common/http.py | 112 --------- .../sql/common/unified_http_client.py | 226 ++++++++++++++++++ src/databricks/sql/result_set.py | 1 + src/databricks/sql/session.py | 39 ++- .../sql/telemetry/telemetry_client.py | 22 +- src/databricks/sql/utils.py | 15 +- 16 files changed, 440 insertions(+), 211 deletions(-) create mode 100644 src/databricks/sql/common/unified_http_client.py diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 3792d6d05..a8d0671b0 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -35,6 +35,7 @@ def get_auth_provider(cfg: ClientContext): cfg.oauth_client_id, cfg.oauth_scopes, cfg.auth_type, + http_client=http_client, ) elif cfg.access_token is not None: return AccessTokenAuthProvider(cfg.access_token) @@ -53,6 +54,7 @@ def get_auth_provider(cfg: ClientContext): cfg.oauth_redirect_port_range, cfg.oauth_client_id, cfg.oauth_scopes, + http_client=http_client, ) else: raise RuntimeError("No valid authentication settings!") @@ -79,7 +81,7 @@ def get_client_id_and_redirect_port(use_azure_auth: bool): ) -def get_python_sql_connector_auth_provider(hostname: str, **kwargs): +def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs): # TODO : unify all the auth mechanisms with the Python SDK auth_type = kwargs.get("auth_type") diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 26c1f3708..80f44812c 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -63,6 +63,7 @@ def __init__( redirect_port_range: List[int], client_id: str, scopes: List[str], + http_client, auth_type: str = "databricks-oauth", ): try: @@ -79,6 +80,7 @@ def __init__( port_range=redirect_port_range, client_id=client_id, idp_endpoint=idp_endpoint, + http_client=http_client, ) self._hostname = hostname self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(cloud_scopes) diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 5cfbc37c0..262166a52 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -2,7 +2,6 @@ import logging from typing import Optional, List from urllib.parse import urlparse -from databricks.sql.common.http import DatabricksHttpClient, HttpMethod logger = logging.getLogger(__name__) @@ -36,6 +35,21 @@ def __init__( tls_client_cert_file: Optional[str] = None, oauth_persistence=None, credentials_provider=None, + # HTTP client configuration parameters + ssl_options=None, # SSLOptions type + socket_timeout: Optional[float] = None, + retry_stop_after_attempts_count: Optional[int] = None, + retry_delay_min: Optional[float] = None, + retry_delay_max: Optional[float] = None, + retry_stop_after_attempts_duration: Optional[float] = None, + 
retry_delay_default: Optional[float] = None,
+        retry_dangerous_codes: Optional[List[int]] = None,
+        http_proxy: Optional[str] = None,
+        proxy_username: Optional[str] = None,
+        proxy_password: Optional[str] = None,
+        pool_connections: Optional[int] = None,
+        pool_maxsize: Optional[int] = None,
+        user_agent: Optional[str] = None,
     ):
         self.hostname = hostname
         self.access_token = access_token
@@ -51,6 +65,22 @@ def __init__(
         self.tls_client_cert_file = tls_client_cert_file
         self.oauth_persistence = oauth_persistence
         self.credentials_provider = credentials_provider
+
+        # HTTP client configuration
+        self.ssl_options = ssl_options
+        self.socket_timeout = socket_timeout
+        self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 30
+        self.retry_delay_min = retry_delay_min or 1.0
+        self.retry_delay_max = retry_delay_max or 60.0
+        self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration or 900.0
+        self.retry_delay_default = retry_delay_default or 5.0
+        self.retry_dangerous_codes = retry_dangerous_codes or []
+        self.http_proxy = http_proxy
+        self.proxy_username = proxy_username
+        self.proxy_password = proxy_password
+        self.pool_connections = pool_connections or 10
+        self.pool_maxsize = pool_maxsize or 20
+        self.user_agent = user_agent
 
 
 def get_effective_azure_login_app_id(hostname) -> str:
@@ -69,7 +99,7 @@ def get_effective_azure_login_app_id(hostname) -> str:
     return AzureAppId.PROD.value[1]
 
 
-def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
+def get_azure_tenant_id_from_host(host: str, http_client) -> str:
     """
     Load the Azure tenant ID from the Azure Databricks login page.
 
@@ -78,23 +108,22 @@ def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
     the Azure login page, and the tenant ID is extracted from the redirect URL.
     """
 
-    if http_client is None:
-        http_client = DatabricksHttpClient.get_instance()
-
     login_url = f"{host}/aad/auth"
     logger.debug("Loading tenant ID from %s", login_url)
-    with http_client.execute(HttpMethod.GET, login_url, allow_redirects=False) as resp:
-        if resp.status_code // 100 != 3:
+
+    # urllib3 uses `redirect`, not requests' `allow_redirects`; redirects must not
+    # be followed here so the 3xx Location header can be inspected.
+    with http_client.request_context('GET', login_url, redirect=False) as resp:
+        if resp.status // 100 != 3:
             raise ValueError(
-                f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}"
+                f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}"
            )
-        entra_id_endpoint = resp.headers.get("Location")
+        entra_id_endpoint = dict(resp.headers).get("Location")
         if entra_id_endpoint is None:
             raise ValueError(f"No Location header in response from {login_url}")
-    # The Location header has the following form: https://login.microsoftonline.com//oauth2/authorize?...
-    # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
-    url = urlparse(entra_id_endpoint)
-    path_segments = url.path.split("/")
-    if len(path_segments) < 2:
-        raise ValueError(f"Invalid path in Location header: {url.path}")
-    return path_segments[1]
+
+    # The Location header has the following form: https://login.microsoftonline.com//oauth2/authorize?...
+    # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
+ url = urlparse(entra_id_endpoint) + path_segments = url.path.split("/") + if len(path_segments) < 2: + raise ValueError(f"Invalid path in Location header: {url.path}") + return path_segments[1] diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index aa3184d88..0d67929a3 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -9,10 +9,8 @@ from typing import List, Optional import oauthlib.oauth2 -import requests from oauthlib.oauth2.rfc6749.errors import OAuth2Error -from requests.exceptions import RequestException -from databricks.sql.common.http import HttpMethod, DatabricksHttpClient, HttpHeader +from databricks.sql.common.http import HttpMethod, HttpHeader from databricks.sql.common.http import OAuthResponse from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler from databricks.sql.auth.endpoint import OAuthEndpointCollection @@ -85,11 +83,13 @@ def __init__( port_range: List[int], client_id: str, idp_endpoint: OAuthEndpointCollection, + http_client, ): self.port_range = port_range self.client_id = client_id self.redirect_port = None self.idp_endpoint = idp_endpoint + self.http_client = http_client @staticmethod def __token_urlsafe(nbytes=32): @@ -103,8 +103,12 @@ def __fetch_well_known_config(self, hostname: str): known_config_url = self.idp_endpoint.get_openid_config_url(hostname) try: - response = requests.get(url=known_config_url, auth=IgnoreNetrcAuth()) - except RequestException as e: + from databricks.sql.common.unified_http_client import IgnoreNetrcAuth + response = self.http_client.request('GET', url=known_config_url) + # Convert urllib3 response to requests-like response for compatibility + response.status_code = response.status + response.json = lambda: json.loads(response.data.decode()) + except Exception as e: logger.error( f"Unable to fetch OAuth configuration from {known_config_url}.\n" "Verify it is a valid workspace URL and that OAuth is " @@ -122,7 +126,7 @@ def __fetch_well_known_config(self, hostname: str): raise RuntimeError(msg) try: return response.json() - except requests.exceptions.JSONDecodeError as e: + except Exception as e: logger.error( f"Unable to decode OAuth configuration from {known_config_url}.\n" "Verify it is a valid workspace URL and that OAuth is " @@ -209,10 +213,13 @@ def __send_token_request(token_request_url, data): "Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded", } - response = requests.post( - url=token_request_url, data=data, headers=headers, auth=IgnoreNetrcAuth() + # Use unified HTTP client + from databricks.sql.common.unified_http_client import IgnoreNetrcAuth + response = self.http_client.request( + 'POST', url=token_request_url, body=data, headers=headers ) - return response.json() + # Convert urllib3 response to dict for compatibility + return json.loads(response.data.decode()) def __send_refresh_token_request(self, hostname, refresh_token): oauth_config = self.__fetch_well_known_config(hostname) @@ -320,6 +327,7 @@ def __init__( token_url, client_id, client_secret, + http_client, extra_params: dict = {}, ): self.client_id = client_id @@ -327,7 +335,7 @@ def __init__( self.token_url = token_url self.extra_params = extra_params self.token: Optional[Token] = None - self._http_client = DatabricksHttpClient.get_instance() + self._http_client = http_client def get_token(self) -> Token: if self.token is None or self.token.is_expired(): diff --git a/src/databricks/sql/backend/sea/queue.py 
b/src/databricks/sql/backend/sea/queue.py index 130f0c5bf..4a319c442 100644 --- a/src/databricks/sql/backend/sea/queue.py +++ b/src/databricks/sql/backend/sea/queue.py @@ -50,6 +50,7 @@ def build_queue( max_download_threads: int, sea_client: SeaDatabricksClient, lz4_compressed: bool, + http_client, ) -> ResultSetQueue: """ Factory method to build a result set queue for SEA backend. @@ -94,6 +95,7 @@ def build_queue( total_chunk_count=manifest.total_chunk_count, lz4_compressed=lz4_compressed, description=description, + http_client=http_client, ) raise ProgrammingError("Invalid result format") @@ -309,6 +311,7 @@ def __init__( sea_client: SeaDatabricksClient, statement_id: str, total_chunk_count: int, + http_client, lz4_compressed: bool = False, description: List[Tuple] = [], ): @@ -337,6 +340,7 @@ def __init__( # TODO: fix these arguments when telemetry is implemented in SEA session_id_hex=None, chunk_id=0, + http_client=http_client, ) logger.debug( diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index afa70bc89..17838ed81 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -64,6 +64,7 @@ def __init__( max_download_threads=sea_client.max_download_threads, sea_client=sea_client, lz4_compressed=execute_response.lz4_compressed, + http_client=connection.session.http_client, ) # Call parent constructor with common attributes diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 73ee0e03c..295be29dc 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -6,7 +6,6 @@ import pyarrow except ImportError: pyarrow = None -import requests import json import os import decimal @@ -292,6 +291,7 @@ def read(self) -> Optional[OAuthToken]: auth_provider=self.session.auth_provider, host_url=self.session.host, batch_size=self.telemetry_batch_size, + http_client=self.session.http_client, ) self._telemetry_client = TelemetryClientFactory.get_telemetry_client( @@ -744,16 +744,20 @@ def _handle_staging_put( ) with open(local_file, "rb") as fh: - r = requests.put(url=presigned_url, data=fh, headers=headers) + r = self.connection.session.http_client.request('PUT', presigned_url, body=fh.read(), headers=headers) + # Add compatibility attributes for urllib3 response + r.status_code = r.status + if hasattr(r, 'data'): + r.content = r.data + r.ok = r.status < 400 + r.text = r.data.decode() if r.data else "" # fmt: off - # Design borrowed from: https://stackoverflow.com/a/2342589/5093960 - - OK = requests.codes.ok # 200 - CREATED = requests.codes.created # 201 - ACCEPTED = requests.codes.accepted # 202 - NO_CONTENT = requests.codes.no_content # 204 - + # HTTP status codes + OK = 200 + CREATED = 201 + ACCEPTED = 202 + NO_CONTENT = 204 # fmt: on if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: @@ -783,7 +787,13 @@ def _handle_staging_get( session_id_hex=self.connection.get_session_id_hex(), ) - r = requests.get(url=presigned_url, headers=headers) + r = self.connection.session.http_client.request('GET', presigned_url, headers=headers) + # Add compatibility attributes for urllib3 response + r.status_code = r.status + if hasattr(r, 'data'): + r.content = r.data + r.ok = r.status < 400 + r.text = r.data.decode() if r.data else "" # response.ok verifies the status code is not between 400-600. 
# Any 2xx or 3xx will evaluate r.ok == True @@ -802,7 +812,13 @@ def _handle_staging_remove( ): """Make an HTTP DELETE request to the presigned_url""" - r = requests.delete(url=presigned_url, headers=headers) + r = self.connection.session.http_client.request('DELETE', presigned_url, headers=headers) + # Add compatibility attributes for urllib3 response + r.status_code = r.status + if hasattr(r, 'data'): + r.content = r.data + r.ok = r.status < 400 + r.text = r.data.decode() if r.data else "" if not r.ok: raise OperationalError( diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 32b698bed..27265720f 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -25,6 +25,7 @@ def __init__( session_id_hex: Optional[str], statement_id: str, chunk_id: int, + http_client, ): self._pending_links: List[Tuple[int, TSparkArrowResultLink]] = [] self.chunk_id = chunk_id @@ -47,6 +48,7 @@ def __init__( self._ssl_options = ssl_options self.session_id_hex = session_id_hex self.statement_id = statement_id + self._http_client = http_client def get_next_downloaded_file( self, next_row_offset: int @@ -109,6 +111,7 @@ def _schedule_downloads(self): chunk_id=chunk_id, session_id_hex=self.session_id_hex, statement_id=self.statement_id, + http_client=self._http_client, ) task = self._thread_pool.submit(handler.run) self._download_tasks.append(task) diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 1331fa203..ea375fbbb 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -2,10 +2,9 @@ from dataclasses import dataclass from typing import Optional -from requests.adapters import Retry import lz4.frame import time -from databricks.sql.common.http import DatabricksHttpClient, HttpMethod +from databricks.sql.common.http import HttpMethod from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink from databricks.sql.exc import Error from databricks.sql.types import SSLOptions @@ -16,16 +15,6 @@ # TODO: Ideally, we should use a common retry policy (DatabricksRetryPolicy) for all the requests across the library. # But DatabricksRetryPolicy should be updated first - currently it can work only with Thrift requests -retryPolicy = Retry( - total=5, # max retry attempts - backoff_factor=1, # min delay, 1 second - # TODO: `backoff_max` is supported since `urllib3` v2.0.0, but we allow >= 1.26. 
- # The default value (120 seconds) used since v1.26 looks reasonable enough - # backoff_max=60, # max delay, 60 seconds - # retry all status codes below 100, 429 (Too Many Requests), and all codes above 500, - # excluding 501 Not implemented - status_forcelist=[*range(0, 101), 429, 500, *range(502, 1000)], -) @dataclass @@ -73,11 +62,12 @@ def __init__( chunk_id: int, session_id_hex: Optional[str], statement_id: str, + http_client, ): self.settings = settings self.link = link self._ssl_options = ssl_options - self._http_client = DatabricksHttpClient.get_instance() + self._http_client = http_client self.chunk_id = chunk_id self.session_id_hex = session_id_hex self.statement_id = statement_id @@ -104,50 +94,47 @@ def run(self) -> DownloadedFile: start_time = time.time() - with self._http_client.execute( - method=HttpMethod.GET, + with self._http_client.request_context( + method='GET', url=self.link.fileLink, timeout=self.settings.download_timeout, - verify=self._ssl_options.tls_verify, headers=self.link.httpHeaders - # TODO: Pass cert from `self._ssl_options` ) as response: - response.raise_for_status() - - # Save (and decompress if needed) the downloaded file - compressed_data = response.content - - # Log download metrics - download_duration = time.time() - start_time - self._log_download_metrics( - self.link.fileLink, len(compressed_data), download_duration - ) - - decompressed_data = ( - ResultSetDownloadHandler._decompress_data(compressed_data) - if self.settings.is_lz4_compressed - else compressed_data - ) + if response.status >= 400: + raise Exception(f"HTTP {response.status}: {response.data.decode()}") + compressed_data = response.data + + # Log download metrics + download_duration = time.time() - start_time + self._log_download_metrics( + self.link.fileLink, len(compressed_data), download_duration + ) - # The size of the downloaded file should match the size specified from TSparkArrowResultLink - if len(decompressed_data) != self.link.bytesNum: - logger.debug( - "ResultSetDownloadHandler: downloaded file size {} does not match the expected value {}".format( - len(decompressed_data), self.link.bytesNum - ) - ) + decompressed_data = ( + ResultSetDownloadHandler._decompress_data(compressed_data) + if self.settings.is_lz4_compressed + else compressed_data + ) + # The size of the downloaded file should match the size specified from TSparkArrowResultLink + if len(decompressed_data) != self.link.bytesNum: logger.debug( - "ResultSetDownloadHandler: successfully downloaded file, offset {}, row count {}".format( - self.link.startRowOffset, self.link.rowCount + "ResultSetDownloadHandler: downloaded file size {} does not match the expected value {}".format( + len(decompressed_data), self.link.bytesNum ) ) - return DownloadedFile( - decompressed_data, - self.link.startRowOffset, - self.link.rowCount, + logger.debug( + "ResultSetDownloadHandler: successfully downloaded file, offset {}, row count {}".format( + self.link.startRowOffset, self.link.rowCount ) + ) + + return DownloadedFile( + decompressed_data, + self.link.startRowOffset, + self.link.rowCount, + ) def _log_download_metrics( self, url: str, bytes_downloaded: int, duration_seconds: float diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 53add9253..8e7029805 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -1,6 +1,6 @@ +import json import threading import time -import requests from dataclasses import dataclass, field from 
concurrent.futures import ThreadPoolExecutor from typing import Dict, Optional, List, Any, TYPE_CHECKING @@ -49,7 +49,7 @@ class FeatureFlagsContext: in the background, returning stale data until the refresh completes. """ - def __init__(self, connection: "Connection", executor: ThreadPoolExecutor): + def __init__(self, connection: "Connection", executor: ThreadPoolExecutor, http_client): from databricks.sql import __version__ self._connection = connection @@ -65,6 +65,9 @@ def __init__(self, connection: "Connection", executor: ThreadPoolExecutor): self._feature_flag_endpoint = ( f"https://{self._connection.session.host}{endpoint_suffix}" ) + + # Use the provided HTTP client + self._http_client = http_client def _is_refresh_needed(self) -> bool: """Checks if the cache is due for a proactive background refresh.""" @@ -105,9 +108,12 @@ def _refresh_flags(self): self._connection.session.auth_provider.add_headers(headers) headers["User-Agent"] = self._connection.session.useragent_header - response = requests.get( - self._feature_flag_endpoint, headers=headers, timeout=30 + response = self._http_client.request( + 'GET', self._feature_flag_endpoint, headers=headers, timeout=30 ) + # Add compatibility attributes for urllib3 response + response.status_code = response.status + response.json = lambda: json.loads(response.data.decode()) if response.status_code == 200: ff_response = FeatureFlagsResponse.from_dict(response.json()) @@ -159,7 +165,7 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: # Use the unique session ID as the key key = connection.get_session_id_hex() if key not in cls._context_map: - cls._context_map[key] = FeatureFlagsContext(connection, cls._executor) + cls._context_map[key] = FeatureFlagsContext(connection, cls._executor, connection.session.http_client) return cls._context_map[key] @classmethod diff --git a/src/databricks/sql/common/http.py b/src/databricks/sql/common/http.py index 0cd2919c0..cf76a5fba 100644 --- a/src/databricks/sql/common/http.py +++ b/src/databricks/sql/common/http.py @@ -38,115 +38,3 @@ class OAuthResponse: resource: str = "" access_token: str = "" refresh_token: str = "" - - -# Singleton class for common Http Client -class DatabricksHttpClient: - ## TODO: Unify all the http clients in the PySQL Connector - - _instance = None - _lock = threading.Lock() - - def __init__(self): - self.session = requests.Session() - adapter = HTTPAdapter( - pool_connections=5, - pool_maxsize=10, - max_retries=Retry(total=10, backoff_factor=0.1), - ) - self.session.mount("https://", adapter) - self.session.mount("http://", adapter) - - @classmethod - def get_instance(cls) -> "DatabricksHttpClient": - if cls._instance is None: - with cls._lock: - if cls._instance is None: - cls._instance = DatabricksHttpClient() - return cls._instance - - @contextmanager - def execute( - self, method: HttpMethod, url: str, **kwargs - ) -> Generator[requests.Response, None, None]: - logger.info("Executing HTTP request: %s with url: %s", method.value, url) - response = None - try: - response = self.session.request(method.value, url, **kwargs) - yield response - except Exception as e: - logger.error("Error executing HTTP request in DatabricksHttpClient: %s", e) - raise e - finally: - if response is not None: - response.close() - - def close(self): - self.session.close() - - -class TelemetryHTTPAdapter(HTTPAdapter): - """ - Custom HTTP adapter to prepare our DatabricksRetryPolicy before each request. 
- This ensures the retry timer is started and the command type is set correctly, - allowing the policy to manage its state for the duration of the request retries. - """ - - def send(self, request, **kwargs): - self.max_retries.command_type = CommandType.OTHER - self.max_retries.start_retry_timer() - return super().send(request, **kwargs) - - -class TelemetryHttpClient: # TODO: Unify all the http clients in the PySQL Connector - """Singleton HTTP client for sending telemetry data.""" - - _instance: Optional["TelemetryHttpClient"] = None - _lock = threading.Lock() - - TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3 - TELEMETRY_RETRY_DELAY_MIN = 1.0 - TELEMETRY_RETRY_DELAY_MAX = 10.0 - TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0 - - def __init__(self): - """Initializes the session and mounts the custom retry adapter.""" - retry_policy = DatabricksRetryPolicy( - delay_min=self.TELEMETRY_RETRY_DELAY_MIN, - delay_max=self.TELEMETRY_RETRY_DELAY_MAX, - stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT, - stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION, - delay_default=1.0, - force_dangerous_codes=[], - ) - adapter = TelemetryHTTPAdapter(max_retries=retry_policy) - self.session = requests.Session() - self.session.mount("https://", adapter) - self.session.mount("http://", adapter) - - @classmethod - def get_instance(cls) -> "TelemetryHttpClient": - """Get the singleton instance of the TelemetryHttpClient.""" - if cls._instance is None: - with cls._lock: - if cls._instance is None: - logger.debug("Initializing singleton TelemetryHttpClient") - cls._instance = TelemetryHttpClient() - return cls._instance - - def post(self, url: str, **kwargs) -> requests.Response: - """ - Executes a POST request using the configured session. - - This is a blocking call intended to be run in a background thread. - """ - logger.debug("Executing telemetry POST request to: %s", url) - return self.session.post(url, **kwargs) - - def close(self): - """Closes the underlying requests.Session.""" - logger.debug("Closing TelemetryHttpClient session.") - self.session.close() - # Clear the instance to allow for re-initialization if needed - with TelemetryHttpClient._lock: - TelemetryHttpClient._instance = None diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py new file mode 100644 index 000000000..8c3be2bfd --- /dev/null +++ b/src/databricks/sql/common/unified_http_client.py @@ -0,0 +1,226 @@ +import logging +import ssl +import urllib.parse +from contextlib import contextmanager +from typing import Dict, Any, Optional, Generator, Union + +import urllib3 +from urllib3 import PoolManager, ProxyManager +from urllib3.util import make_headers +from urllib3.exceptions import MaxRetryError + +from databricks.sql.auth.retry import DatabricksRetryPolicy +from databricks.sql.exc import RequestError + +logger = logging.getLogger(__name__) + + +class UnifiedHttpClient: + """ + Unified HTTP client for all Databricks SQL connector HTTP operations. + + This client uses urllib3 for robust HTTP communication with retry policies, + connection pooling, SSL support, and proxy support. It replaces the various + singleton HTTP clients and direct requests usage throughout the codebase. + """ + + def __init__(self, client_context): + """ + Initialize the unified HTTP client. 
+ + Args: + client_context: ClientContext instance containing HTTP configuration + """ + self.config = client_context + self._pool_manager = None + self._setup_pool_manager() + + def _setup_pool_manager(self): + """Set up the urllib3 PoolManager with configuration from ClientContext.""" + + # SSL context setup + ssl_context = None + if self.config.ssl_options: + ssl_context = ssl.create_default_context() + + # Configure SSL verification + if not self.config.ssl_options.tls_verify: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + elif not self.config.ssl_options.tls_verify_hostname: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_REQUIRED + + # Load custom CA file if specified + if self.config.ssl_options.tls_trusted_ca_file: + ssl_context.load_verify_locations(self.config.ssl_options.tls_trusted_ca_file) + + # Load client certificate if specified + if (self.config.ssl_options.tls_client_cert_file and + self.config.ssl_options.tls_client_cert_key_file): + ssl_context.load_cert_chain( + self.config.ssl_options.tls_client_cert_file, + self.config.ssl_options.tls_client_cert_key_file, + self.config.ssl_options.tls_client_cert_key_password + ) + + # Create retry policy + retry_policy = DatabricksRetryPolicy( + delay_min=self.config.retry_delay_min, + delay_max=self.config.retry_delay_max, + stop_after_attempts_count=self.config.retry_stop_after_attempts_count, + stop_after_attempts_duration=self.config.retry_stop_after_attempts_duration, + delay_default=self.config.retry_delay_default, + force_dangerous_codes=self.config.retry_dangerous_codes, + ) + + # Common pool manager kwargs + pool_kwargs = { + 'num_pools': self.config.pool_connections, + 'maxsize': self.config.pool_maxsize, + 'retries': retry_policy, + 'timeout': urllib3.Timeout( + connect=self.config.socket_timeout, + read=self.config.socket_timeout + ) if self.config.socket_timeout else None, + 'ssl_context': ssl_context, + } + + # Create proxy or regular pool manager + if self.config.http_proxy: + proxy_headers = None + if self.config.proxy_username and self.config.proxy_password: + proxy_headers = make_headers( + proxy_basic_auth=f"{self.config.proxy_username}:{self.config.proxy_password}" + ) + + self._pool_manager = ProxyManager( + self.config.http_proxy, + proxy_headers=proxy_headers, + **pool_kwargs + ) + else: + self._pool_manager = PoolManager(**pool_kwargs) + + def _prepare_headers(self, headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: + """Prepare headers for the request, including User-Agent.""" + request_headers = {} + + if self.config.user_agent: + request_headers['User-Agent'] = self.config.user_agent + + if headers: + request_headers.update(headers) + + return request_headers + + @contextmanager + def request_context( + self, + method: str, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs + ) -> Generator[urllib3.HTTPResponse, None, None]: + """ + Context manager for making HTTP requests with proper resource cleanup. 
+ + Args: + method: HTTP method (GET, POST, PUT, DELETE) + url: URL to request + headers: Optional headers dict + **kwargs: Additional arguments passed to urllib3 request + + Yields: + urllib3.HTTPResponse: The HTTP response object + """ + logger.debug("Making %s request to %s", method, url) + + request_headers = self._prepare_headers(headers) + response = None + + try: + response = self._pool_manager.request( + method=method, + url=url, + headers=request_headers, + **kwargs + ) + yield response + except MaxRetryError as e: + logger.error("HTTP request failed after retries: %s", e) + raise RequestError(f"HTTP request failed: {e}") + except Exception as e: + logger.error("HTTP request error: %s", e) + raise RequestError(f"HTTP request error: {e}") + finally: + if response: + response.close() + + def request(self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs) -> urllib3.HTTPResponse: + """ + Make an HTTP request. + + Args: + method: HTTP method (GET, POST, PUT, DELETE, etc.) + url: URL to request + headers: Optional headers dict + **kwargs: Additional arguments passed to urllib3 request + + Returns: + urllib3.HTTPResponse: The HTTP response object with data pre-loaded + """ + with self.request_context(method, url, headers=headers, **kwargs) as response: + # Read the response data to ensure it's available after context exit + response._body = response.data + return response + + def upload_file(self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None) -> urllib3.HTTPResponse: + """ + Upload a file using PUT method. + + Args: + url: URL to upload to + file_path: Path to the file to upload + headers: Optional headers + + Returns: + urllib3.HTTPResponse: The response from the server + """ + with open(file_path, 'rb') as file_obj: + return self.request('PUT', url, body=file_obj.read(), headers=headers) + + def download_file(self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None) -> None: + """ + Download a file using GET method. + + Args: + url: URL to download from + file_path: Path where to save the downloaded file + headers: Optional headers + """ + response = self.request('GET', url, headers=headers) + with open(file_path, 'wb') as file_obj: + file_obj.write(response.data) + + def close(self): + """Close the underlying connection pools.""" + if self._pool_manager: + self._pool_manager.clear() + self._pool_manager = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +# Compatibility class to maintain requests-like interface for OAuth +class IgnoreNetrcAuth: + """ + Compatibility class for OAuth code that expects requests.auth.AuthBase interface. + This is a no-op auth handler since OAuth handles auth differently. 
+ """ + def __call__(self, request): + return request \ No newline at end of file diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 9feb6e924..77673db9a 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -244,6 +244,7 @@ def __init__( session_id_hex=connection.get_session_id_hex(), statement_id=execute_response.command_id.to_hex_guid(), chunk_id=self.num_chunks, + http_client=connection.session.http_client, ) if t_row_set.resultLinks: self.num_chunks += len(t_row_set.resultLinks) diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index f1bc35bee..d0c94b6ba 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -4,6 +4,7 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider +from databricks.sql.auth.common import ClientContext from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME @@ -11,6 +12,7 @@ from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.common.unified_http_client import UnifiedHttpClient logger = logging.getLogger(__name__) @@ -42,10 +44,6 @@ def __init__( self.schema = schema self.http_path = http_path - self.auth_provider = get_python_sql_connector_auth_provider( - server_hostname, **kwargs - ) - user_agent_entry = kwargs.get("user_agent_entry") if user_agent_entry is None: user_agent_entry = kwargs.get("_user_agent_entry") @@ -77,6 +75,15 @@ def __init__( tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) + # Create HTTP client configuration and unified HTTP client + self.client_context = self._build_client_context(server_hostname, **kwargs) + self.http_client = UnifiedHttpClient(self.client_context) + + # Create auth provider with HTTP client context + self.auth_provider = get_python_sql_connector_auth_provider( + server_hostname, http_client=self.http_client, **kwargs + ) + self.backend = self._create_backend( server_hostname, http_path, @@ -88,6 +95,26 @@ def __init__( self.protocol_version = None + def _build_client_context(self, server_hostname: str, **kwargs) -> ClientContext: + """Build ClientContext with HTTP configuration from kwargs.""" + return ClientContext( + hostname=server_hostname, + ssl_options=self.ssl_options, + socket_timeout=kwargs.get("_socket_timeout"), + retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count"), + retry_delay_min=kwargs.get("_retry_delay_min"), + retry_delay_max=kwargs.get("_retry_delay_max"), + retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration"), + retry_delay_default=kwargs.get("_retry_delay_default"), + retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), + http_proxy=kwargs.get("http_proxy"), + proxy_username=kwargs.get("proxy_username"), + proxy_password=kwargs.get("proxy_password"), + pool_connections=kwargs.get("pool_connections"), + pool_maxsize=kwargs.get("pool_maxsize"), + user_agent=self.useragent_header, + ) + def _create_backend( self, server_hostname: str, @@ -185,3 +212,7 @@ def close(self) -> None: logger.error("Attempt to close session raised a local exception: %s", e) self.is_open = False + + # Close HTTP 
client if it exists
+        if hasattr(self, 'http_client') and self.http_client:
+            self.http_client.close()
diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py
index 55f06c8df..93cef3600 100644
--- a/src/databricks/sql/telemetry/telemetry_client.py
+++ b/src/databricks/sql/telemetry/telemetry_client.py
@@ -168,6 +168,7 @@ def __init__(
         host_url,
         executor,
         batch_size,
+        http_client,
     ):
         logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex)
         self._telemetry_enabled = telemetry_enabled
@@ -180,7 +181,7 @@ def __init__(
         self._driver_connection_params = None
         self._host_url = host_url
         self._executor = executor
-        self._http_client = TelemetryHttpClient.get_instance()
+        self._http_client = http_client
 
     def _export_event(self, event):
         """Add an event to the batch queue and flush if batch is full"""
@@ -228,19 +229,34 @@ def _send_telemetry(self, events):
 
         try:
             logger.debug("Submitting telemetry request to thread pool")
+
+            # Use unified HTTP client
             future = self._executor.submit(
-                self._http_client.post,
+                self._send_with_unified_client,
                 url,
                 data=request.to_json(),
                 headers=headers,
                 timeout=900,
             )
+
             future.add_done_callback(
                 lambda fut: self._telemetry_request_callback(fut, sent_count=sent_count)
             )
         except Exception as e:
             logger.debug("Failed to submit telemetry request: %s", e)
 
+    def _send_with_unified_client(self, url, data, headers, timeout=900):
+        """Helper method to send telemetry using the unified HTTP client."""
+        # `timeout` must be accepted here because _send_telemetry submits it as a kwarg.
+        try:
+            response = self._http_client.request('POST', url, body=data, headers=headers, timeout=timeout)
+            # Convert urllib3 response to requests-like response for compatibility
+            response.status_code = response.status
+            response.json = lambda: json.loads(response.data.decode()) if response.data else {}
+            return response
+        except Exception as e:
+            logger.error("Failed to send telemetry with unified client: %s", e)
+            raise
+
     def _telemetry_request_callback(self, future, sent_count: int):
         """Callback function to handle telemetry request completion"""
         try:
@@ -431,6 +447,7 @@ def initialize_telemetry_client(
         auth_provider,
         host_url,
         batch_size,
+        http_client,
     ):
         """Initialize a telemetry client for a specific connection if telemetry is enabled"""
         try:
@@ -453,6 +470,7 @@ def initialize_telemetry_client(
                         host_url=host_url,
                         executor=TelemetryClientFactory._executor,
                         batch_size=batch_size,
+                        http_client=http_client,
                     )
                 else:
                     TelemetryClientFactory._clients[
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index c1d89ca5c..ff48e0e91 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -64,6 +64,7 @@ def build_queue(
         session_id_hex: Optional[str],
         statement_id: str,
         chunk_id: int,
+        http_client,
         lz4_compressed: bool = True,
         description: List[Tuple] = [],
     ) -> ResultSetQueue:
@@ -104,15 +105,16 @@ def build_queue(
         elif row_set_type == TSparkRowSetType.URL_BASED_SET:
             return ThriftCloudFetchQueue(
                 schema_bytes=arrow_schema_bytes,
-                start_row_offset=t_row_set.startRowOffset,
-                result_links=t_row_set.resultLinks,
-                lz4_compressed=lz4_compressed,
-                description=description,
                 max_download_threads=max_download_threads,
                 ssl_options=ssl_options,
                 session_id_hex=session_id_hex,
                 statement_id=statement_id,
                 chunk_id=chunk_id,
+                http_client=http_client,
+                start_row_offset=t_row_set.startRowOffset,
+                result_links=t_row_set.resultLinks,
+                lz4_compressed=lz4_compressed,
+                description=description,
             )
         else:
             raise AssertionError("Row set type is not valid")
@@ -224,6 +226,7 @@ def __init__(
         session_id_hex: 
Optional[str], statement_id: str, chunk_id: int, + http_client, schema_bytes: Optional[bytes] = None, lz4_compressed: bool = True, description: List[Tuple] = [], @@ -247,6 +250,7 @@ def __init__( self.session_id_hex = session_id_hex self.statement_id = statement_id self.chunk_id = chunk_id + self._http_client = http_client # Table state self.table = None @@ -261,6 +265,7 @@ def __init__( session_id_hex=session_id_hex, statement_id=statement_id, chunk_id=chunk_id, + http_client=http_client, ) def next_n_rows(self, num_rows: int) -> "pyarrow.Table": @@ -370,6 +375,7 @@ def __init__( session_id_hex: Optional[str], statement_id: str, chunk_id: int, + http_client, start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, @@ -396,6 +402,7 @@ def __init__( session_id_hex=session_id_hex, statement_id=statement_id, chunk_id=chunk_id, + http_client=http_client, ) self.start_row_index = start_row_offset From 30c04a66c7abd88f455b57d78dd2ae230ff4b0cc Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Fri, 8 Aug 2025 19:04:13 +0530 Subject: [PATCH 02/25] Some more fixes and aligned tests Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/auth.py | 4 +- src/databricks/sql/auth/oauth.py | 18 -- src/databricks/sql/backend/thrift_backend.py | 10 +- src/databricks/sql/client.py | 48 +++++ src/databricks/sql/session.py | 27 +-- .../sql/telemetry/telemetry_client.py | 6 +- tests/unit/test_auth.py | 58 ++++-- tests/unit/test_cloud_fetch_queue.py | 183 ++++-------------- tests/unit/test_download_manager.py | 2 + tests/unit/test_downloader.py | 162 +++++++++------- tests/unit/test_telemetry.py | 73 +++++-- tests/unit/test_telemetry_retry.py | 88 ++++----- 12 files changed, 336 insertions(+), 343 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index a8d0671b0..cc421e69e 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -10,7 +10,7 @@ from databricks.sql.auth.common import AuthType, ClientContext -def get_auth_provider(cfg: ClientContext): +def get_auth_provider(cfg: ClientContext, http_client): if cfg.credentials_provider: return ExternalAuthProvider(cfg.credentials_provider) elif cfg.auth_type == AuthType.AZURE_SP_M2M.value: @@ -113,4 +113,4 @@ def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs) oauth_persistence=kwargs.get("experimental_oauth_persistence"), credentials_provider=kwargs.get("credentials_provider"), ) - return get_auth_provider(cfg) + return get_auth_provider(cfg, http_client) diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 0d67929a3..270287953 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -61,22 +61,6 @@ def refresh(self) -> Token: pass -class IgnoreNetrcAuth(requests.auth.AuthBase): - """This auth method is a no-op. - - We use it to force requestslib to not use .netrc to write auth headers - when making .post() requests to the oauth token endpoints, since these - don't require authentication. - - In cases where .netrc is outdated or corrupt, these requests will fail. 
- - See issue #121 - """ - - def __call__(self, r): - return r - - class OAuthManager: def __init__( self, @@ -103,7 +87,6 @@ def __fetch_well_known_config(self, hostname: str): known_config_url = self.idp_endpoint.get_openid_config_url(hostname) try: - from databricks.sql.common.unified_http_client import IgnoreNetrcAuth response = self.http_client.request('GET', url=known_config_url) # Convert urllib3 response to requests-like response for compatibility response.status_code = response.status @@ -214,7 +197,6 @@ def __send_token_request(token_request_url, data): "Content-Type": "application/x-www-form-urlencoded", } # Use unified HTTP client - from databricks.sql.common.unified_http_client import IgnoreNetrcAuth response = self.http_client.request( 'POST', url=token_request_url, body=data, headers=headers ) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index b404b1669..801632a41 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -105,6 +105,7 @@ def __init__( http_headers, auth_provider: AuthProvider, ssl_options: SSLOptions, + http_client=None, **kwargs, ): # Internal arguments in **kwargs: @@ -145,10 +146,8 @@ def __init__( # Number of threads for handling cloud fetch downloads. Defaults to 10 logger.debug( - "ThriftBackend.__init__(server_hostname=%s, port=%s, http_path=%s)", - server_hostname, - port, - http_path, + "ThriftBackend.__init__(server_hostname=%s, port=%s, http_path=%s)" + % (server_hostname, port, http_path) ) port = port or 443 @@ -177,8 +176,8 @@ def __init__( self._max_download_threads = kwargs.get("max_download_threads", 10) self._ssl_options = ssl_options - self._auth_provider = auth_provider + self._http_client = http_client # Connector version 3 retry approach self.enable_v3_retries = kwargs.get("_enable_v3_retries", True) @@ -1292,6 +1291,7 @@ def fetch_results( session_id_hex=self._session_id_hex, statement_id=command_id.to_hex_guid(), chunk_id=chunk_id, + http_client=self._http_client, ) return ( diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 295be29dc..50f252dbc 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -50,6 +50,9 @@ from databricks.sql.session import Session from databricks.sql.backend.types import CommandId, BackendType, CommandState, SessionId +from databricks.sql.auth.common import ClientContext +from databricks.sql.common.unified_http_client import UnifiedHttpClient + from databricks.sql.thrift_api.TCLIService.ttypes import ( TOpenSessionResp, TSparkParameter, @@ -251,10 +254,14 @@ def read(self) -> Optional[OAuthToken]: "telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE ) + client_context = self._build_client_context(server_hostname, **kwargs) + http_client = UnifiedHttpClient(client_context) + try: self.session = Session( server_hostname, http_path, + http_client, http_headers, session_configuration, catalog, @@ -270,6 +277,7 @@ def read(self) -> Optional[OAuthToken]: host_url=server_hostname, http_path=http_path, port=kwargs.get("_port", 443), + http_client=http_client, user_agent=self.session.useragent_header if hasattr(self, "session") else None, @@ -342,6 +350,46 @@ def _set_use_inline_params_with_warning(self, value: Union[bool, str]): return value + def _build_client_context(self, server_hostname: str, **kwargs): + """Build ClientContext for HTTP client configuration.""" + from databricks.sql.auth.common import ClientContext + from 
databricks.sql.types import SSLOptions + + # Extract SSL options + ssl_options = SSLOptions( + tls_verify=not kwargs.get("_tls_no_verify", False), + tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), + tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file"), + tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), + tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), + ) + + # Build user agent + user_agent_entry = kwargs.get("user_agent_entry", "") + if user_agent_entry: + user_agent = f"PyDatabricksSqlConnector/{__version__} ({user_agent_entry})" + else: + user_agent = f"PyDatabricksSqlConnector/{__version__}" + + return ClientContext( + hostname=server_hostname, + ssl_options=ssl_options, + socket_timeout=kwargs.get("_socket_timeout"), + retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count", 30), + retry_delay_min=kwargs.get("_retry_delay_min", 1.0), + retry_delay_max=kwargs.get("_retry_delay_max", 60.0), + retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration", 900.0), + retry_delay_default=kwargs.get("_retry_delay_default", 1.0), + retry_dangerous_codes=kwargs.get("_retry_dangerous_codes", []), + http_proxy=kwargs.get("_http_proxy"), + proxy_username=kwargs.get("_proxy_username"), + proxy_password=kwargs.get("_proxy_password"), + pool_connections=kwargs.get("_pool_connections", 1), + pool_maxsize=kwargs.get("_pool_maxsize", 1), + user_agent=user_agent, + ) + # The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently. def __enter__(self) -> "Connection": return self diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index d0c94b6ba..c9b4f939a 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -22,6 +22,7 @@ def __init__( self, server_hostname: str, http_path: str, + http_client: UnifiedHttpClient, http_headers: Optional[List[Tuple[str, str]]] = None, session_configuration: Optional[Dict[str, Any]] = None, catalog: Optional[str] = None, @@ -75,9 +76,8 @@ def __init__( tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) - # Create HTTP client configuration and unified HTTP client - self.client_context = self._build_client_context(server_hostname, **kwargs) - self.http_client = UnifiedHttpClient(self.client_context) + # Use the provided HTTP client (created in Connection) + self.http_client = http_client # Create auth provider with HTTP client context self.auth_provider = get_python_sql_connector_auth_provider( @@ -95,26 +95,6 @@ def __init__( self.protocol_version = None - def _build_client_context(self, server_hostname: str, **kwargs) -> ClientContext: - """Build ClientContext with HTTP configuration from kwargs.""" - return ClientContext( - hostname=server_hostname, - ssl_options=self.ssl_options, - socket_timeout=kwargs.get("_socket_timeout"), - retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count"), - retry_delay_min=kwargs.get("_retry_delay_min"), - retry_delay_max=kwargs.get("_retry_delay_max"), - retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration"), - retry_delay_default=kwargs.get("_retry_delay_default"), - retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), - http_proxy=kwargs.get("http_proxy"), - proxy_username=kwargs.get("proxy_username"), - proxy_password=kwargs.get("proxy_password"), - 
pool_connections=kwargs.get("pool_connections"),
-            pool_maxsize=kwargs.get("pool_maxsize"),
-            user_agent=self.useragent_header,
-        )
-
     def _create_backend(
         self,
         server_hostname: str,
@@ -142,6 +122,7 @@ def _create_backend(
             "http_headers": all_headers,
             "auth_provider": auth_provider,
             "ssl_options": self.ssl_options,
+            "http_client": self.http_client,
             "_use_arrow_native_complex_types": _use_arrow_native_complex_types,
             **kwargs,
         }
diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py
index 93cef3600..13c15486d 100644
--- a/src/databricks/sql/telemetry/telemetry_client.py
+++ b/src/databricks/sql/telemetry/telemetry_client.py
@@ -3,7 +3,6 @@ import logging
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, Optional, TYPE_CHECKING
-from databricks.sql.common.http import TelemetryHttpClient
 from databricks.sql.telemetry.models.event import (
     TelemetryEvent,
     DriverSystemConfiguration,
@@ -38,6 +37,8 @@
 from databricks.sql.telemetry.utils import BaseTelemetryClient
 from databricks.sql.common.feature_flag import FeatureFlagsContextFactory
 
+from databricks.sql.common.unified_http_client import UnifiedHttpClient
+
 if TYPE_CHECKING:
     from databricks.sql.client import Connection
 
@@ -511,7 +512,6 @@ def close(session_id_hex):
             try:
                 TelemetryClientFactory._stop_flush_thread()
                 TelemetryClientFactory._executor.shutdown(wait=True)
-                TelemetryHttpClient.close()
             except Exception as e:
                 logger.debug("Failed to shutdown thread pool executor: %s", e)
             TelemetryClientFactory._executor = None
@@ -524,6 +524,7 @@ def connection_failure_log(
         host_url: str,
         http_path: str,
         port: int,
+        http_client: UnifiedHttpClient,
         user_agent: Optional[str] = None,
     ):
         """Send error telemetry when connection creation fails, without requiring a session"""
@@ -536,6 +537,7 @@ def connection_failure_log(
             auth_provider=None,
             host_url=host_url,
             batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
+            http_client=http_client,
         )
 
         telemetry_client = TelemetryClientFactory.get_telemetry_client(
diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py
index 8bf914708..2e210a9e0 100644
--- a/tests/unit/test_auth.py
+++ b/tests/unit/test_auth.py
@@ -24,8 +24,8 @@
     AzureOAuthEndpointCollection,
 )
 from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory
-from databricks.sql.common.http import DatabricksHttpClient
 from databricks.sql.experimental.oauth_persistence import OAuthPersistenceCache
+import json
 
 
 class Auth(unittest.TestCase):
@@ -98,12 +98,14 @@ def test_oauth_auth_provider(self, mock_get_tokens, mock_check_and_refresh):
         ) in params:
             with self.subTest(cloud_type.value):
                 oauth_persistence = OAuthPersistenceCache()
+                mock_http_client = MagicMock()
                 auth_provider = DatabricksOAuthProvider(
                     hostname=host,
                     oauth_persistence=oauth_persistence,
                     redirect_port_range=[8020],
                     client_id=client_id,
                     scopes=scopes,
+                    http_client=mock_http_client,
                     auth_type=AuthType.AZURE_OAUTH.value
                     if use_azure_auth
                     else AuthType.DATABRICKS_OAUTH.value,
@@ -142,7 +144,8 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
     def test_get_python_sql_connector_auth_provider_access_token(self):
         hostname = "moderakh-test.cloud.databricks.com"
         kwargs = {"access_token": "dpi123"}
-        auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs)
+        mock_http_client = MagicMock()
+        auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
         self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider")
 
         headers = 
{} @@ -159,7 +162,8 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: hostname = "moderakh-test.cloud.databricks.com" kwargs = {"credentials_provider": MyProvider()} - auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs) + mock_http_client = MagicMock() + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider") headers = {} @@ -174,7 +178,8 @@ def test_get_python_sql_connector_auth_provider_noop(self): "_tls_client_cert_file": tls_client_cert_file, "_use_cert_as_auth": use_cert_as_auth, } - auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs) + mock_http_client = MagicMock() + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) self.assertTrue(type(auth_provider).__name__, "CredentialProvider") def test_get_python_sql_connector_basic_auth(self): @@ -182,8 +187,9 @@ def test_get_python_sql_connector_basic_auth(self): "username": "username", "password": "password", } + mock_http_client = MagicMock() with self.assertRaises(ValueError) as e: - get_python_sql_connector_auth_provider("foo.cloud.databricks.com", **kwargs) + get_python_sql_connector_auth_provider("foo.cloud.databricks.com", mock_http_client, **kwargs) self.assertIn( "Username/password authentication is no longer supported", str(e.exception) ) @@ -191,7 +197,8 @@ def test_get_python_sql_connector_basic_auth(self): @patch.object(DatabricksOAuthProvider, "_initial_get_token") def test_get_python_sql_connector_default_auth(self, mock__initial_get_token): hostname = "foo.cloud.databricks.com" - auth_provider = get_python_sql_connector_auth_provider(hostname) + mock_http_client = MagicMock() + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client) self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider") self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID) @@ -223,10 +230,12 @@ def status_response(response_status_code): @pytest.fixture def token_source(self): + mock_http_client = MagicMock() return ClientCredentialsTokenSource( token_url="https://token_url.com", client_id="client_id", client_secret="client_secret", + http_client=mock_http_client, ) def test_no_token_refresh__when_token_is_not_expired( @@ -249,10 +258,21 @@ def test_no_token_refresh__when_token_is_not_expired( assert mock_get_token.call_count == 1 def test_get_token_success(self, token_source, http_response): - databricks_http_client = DatabricksHttpClient.get_instance() - with patch.object( - databricks_http_client.session, "request", return_value=http_response(200) - ) as mock_request: + mock_http_client = MagicMock() + + with patch.object(token_source, "_http_client", mock_http_client): + # Create a mock response with the expected format + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "abc123", + "token_type": "Bearer", + "refresh_token": None, + } + # Mock the context manager (execute returns context manager) + mock_http_client.execute.return_value.__enter__.return_value = mock_response + mock_http_client.execute.return_value.__exit__.return_value = None + token = token_source.get_token() # Assert @@ -262,11 +282,19 @@ def test_get_token_success(self, token_source, http_response): assert token.refresh_token is None def test_get_token_failure(self, token_source, http_response): - databricks_http_client = DatabricksHttpClient.get_instance() - with 
-            databricks_http_client.session, "request", return_value=http_response(400)
-        ) as mock_request:
-            with pytest.raises(Exception) as e:
+        mock_http_client = MagicMock()
+
+        with patch.object(token_source, "_http_client", mock_http_client):
+            # Create a mock response with error
+            mock_response = MagicMock()
+            mock_response.status_code = 400
+            mock_response.text = "Bad Request"
+            mock_response.json.return_value = {"error": "invalid_client"}
+            # Mock the context manager (execute returns context manager)
+            mock_http_client.execute.return_value.__enter__.return_value = mock_response
+            mock_http_client.execute.return_value.__exit__.return_value = None
+
+            with pytest.raises(Exception):
                 token_source.get_token()
             assert "Failed to get token: 400" in str(e.value)
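The mock setup used in these token tests generalizes to any client method that is consumed
as a context manager. A standalone sketch of the pattern (names here are illustrative, not
connector code):

    # Sketch: faking a context-manager-returning method with MagicMock.
    from unittest.mock import MagicMock

    mock_client = MagicMock()
    mock_response = MagicMock(status_code=200)
    # Callers use `with mock_client.execute(...) as resp:`, so wire up __enter__/__exit__.
    mock_client.execute.return_value.__enter__.return_value = mock_response
    mock_client.execute.return_value.__exit__.return_value = None

    with mock_client.execute("GET", "https://example.invalid") as resp:
        assert resp.status_code == 200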
diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py
index faa8e2f99..0c3fc7103 100644
--- a/tests/unit/test_cloud_fetch_queue.py
+++ b/tests/unit/test_cloud_fetch_queue.py
@@ -13,6 +13,31 @@
 @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed")
 class CloudFetchQueueSuite(unittest.TestCase):

+    def create_queue(self, schema_bytes=None, result_links=None, description=None, **kwargs):
+        """Helper method to create ThriftCloudFetchQueue with sensible defaults"""
+        # Set up defaults for commonly used parameters
+        defaults = {
+            'max_download_threads': 10,
+            'ssl_options': SSLOptions(),
+            'session_id_hex': Mock(),
+            'statement_id': Mock(),
+            'chunk_id': 0,
+            'start_row_offset': 0,
+            'lz4_compressed': True,
+        }
+
+        # Override defaults with any provided kwargs
+        defaults.update(kwargs)
+
+        mock_http_client = MagicMock()
+        return utils.ThriftCloudFetchQueue(
+            schema_bytes=schema_bytes or MagicMock(),
+            result_links=result_links or [],
+            description=description or [],
+            http_client=mock_http_client,
+            **defaults
+        )
+
     def create_result_link(
         self,
         file_link: str = "fileLink",
@@ -58,15 +83,7 @@ def get_schema_bytes():
     def test_initializer_adds_links(self, mock_create_next_table):
         schema_bytes = MagicMock()
         result_links = self.create_result_links(10)
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=result_links,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, result_links=result_links)

         assert len(queue.download_manager._pending_links) == 10
         assert len(queue.download_manager._download_tasks) == 0
@@ -74,16 +91,7 @@
     def test_initializer_no_links_to_add(self):
         schema_bytes = MagicMock()
-        result_links = []
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=result_links,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, result_links=[])

         assert len(queue.download_manager._pending_links) == 0
         assert len(queue.download_manager._download_tasks) == 0
@@ -94,15 +102,7 @@
         return_value=None,
     )
     def test_create_next_table_no_download(self, mock_get_next_downloaded_file):
-        queue = utils.ThriftCloudFetchQueue(
-            MagicMock(),
-            result_links=[],
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=MagicMock(), result_links=[])

         assert queue._create_next_table() is None
         mock_get_next_downloaded_file.assert_called_with(0)
@@ -117,16 +117,7 @@ def test_initializer_create_next_table_success(
     ):
         mock_create_arrow_table.return_value = self.make_arrow_table()
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         expected_result = self.make_arrow_table()

         mock_get_next_downloaded_file.assert_called_with(0)
@@ -145,16 +136,7 @@
     def test_next_n_rows_0_rows(self, mock_create_next_table):
         mock_create_next_table.return_value = self.make_arrow_table()
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -167,16 +149,7 @@
     def test_next_n_rows_partial_table(self, mock_create_next_table):
         mock_create_next_table.return_value = self.make_arrow_table()
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -190,16 +163,7 @@
     def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
         mock_create_next_table.return_value = self.make_arrow_table()
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -218,16 +182,7 @@
     def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
         mock_create_next_table.side_effect = [self.make_arrow_table(), None]
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -242,17 +197,9 @@
     )
     def test_next_n_rows_empty_table(self, mock_create_next_table):
         schema_bytes = self.get_schema_bytes()
-        description = MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        # Create description that matches the 4-column schema
+        description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")]
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table is None

         result = queue.next_n_rows(100)
@@ -263,16 +210,7 @@
     def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table):
         mock_create_next_table.side_effect = [self.make_arrow_table(), None, 0]
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         queue.table_row_index = 4
@@ -285,16 +223,7 @@
     def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_table):
         mock_create_next_table.side_effect = [self.make_arrow_table(), None]
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         queue.table_row_index = 2
@@ -307,16 +236,7 @@
     def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table):
         mock_create_next_table.side_effect = [self.make_arrow_table(), None]
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -335,16 +255,7 @@ def test_remaining_rows_multiple_tables_fully_returned(
             None,
         ]
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         queue.table_row_index = 3
@@ -365,17 +276,9 @@
     )
     def test_remaining_rows_empty_table(self, mock_create_next_table):
         schema_bytes = self.get_schema_bytes()
-        description = MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        # Create description that matches the 4-column schema
+        description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")]
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table is None

         result = queue.remaining_rows()
diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py
index 6eb17a05a..1c77226a9 100644
--- a/tests/unit/test_download_manager.py
+++ b/tests/unit/test_download_manager.py
@@ -14,6 +14,7 @@ class DownloadManagerTests(unittest.TestCase):
     def create_download_manager(
         self, links, max_download_threads=10, lz4_compressed=True
     ):
+        mock_http_client = MagicMock()
         return download_manager.ResultFileDownloadManager(
             links,
             max_download_threads,
@@ -22,6 +23,7 @@ def create_download_manager(
             session_id_hex=Mock(),
             statement_id=Mock(),
             chunk_id=0,
+            http_client=mock_http_client,
         )

     def create_result_link(
diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py
index c514980ee..00b1b849a 100644
--- a/tests/unit/test_downloader.py
+++ b/tests/unit/test_downloader.py
@@ -1,21 +1,19 @@
-from contextlib import contextmanager
 import unittest
-from unittest.mock import Mock, patch, MagicMock
-
+from unittest.mock import patch, MagicMock, Mock
 import requests

 import databricks.sql.cloudfetch.downloader as downloader
-from databricks.sql.common.http import DatabricksHttpClient
 from databricks.sql.exc import Error
 from databricks.sql.types import SSLOptions


-def create_response(**kwargs) -> requests.Response:
-    result = requests.Response()
+def create_mock_response(**kwargs):
+    """Create a mock response object for testing"""
+    mock_response = MagicMock()
     for k, v in kwargs.items():
-        setattr(result, k, v)
-    result.close = Mock()
-    return result
+        setattr(mock_response, k, v)
+    mock_response.close = Mock()
+    return mock_response


 class DownloaderTests(unittest.TestCase):
@@ -23,6 +21,17 @@ class DownloaderTests(unittest.TestCase):
     Unit tests for checking downloader logic.
""" + def _setup_mock_http_response(self, mock_http_client, status=200, data=b""): + """Helper method to setup mock HTTP client with response context manager.""" + mock_response = MagicMock() + mock_response.status = status + mock_response.data = data + mock_context_manager = MagicMock() + mock_context_manager.__enter__.return_value = mock_response + mock_context_manager.__exit__.return_value = None + mock_http_client.request_context.return_value = mock_context_manager + return mock_response + def _setup_time_mock_for_download(self, mock_time, end_time): """Helper to setup time mock that handles logging system calls.""" call_count = [0] @@ -38,6 +47,7 @@ def time_side_effect(): @patch("time.time", return_value=1000) def test_run_link_expired(self, mock_time): + mock_http_client = MagicMock() settings = Mock() result_link = Mock() # Already expired @@ -49,6 +59,7 @@ def test_run_link_expired(self, mock_time): chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), + http_client=mock_http_client, ) with self.assertRaises(Error) as context: @@ -59,6 +70,7 @@ def test_run_link_expired(self, mock_time): @patch("time.time", return_value=1000) def test_run_link_past_expiry_buffer(self, mock_time): + mock_http_client = MagicMock() settings = Mock(link_expiry_buffer_secs=5) result_link = Mock() # Within the expiry buffer time @@ -70,6 +82,7 @@ def test_run_link_past_expiry_buffer(self, mock_time): chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), + http_client=mock_http_client, ) with self.assertRaises(Error) as context: @@ -80,46 +93,45 @@ def test_run_link_past_expiry_buffer(self, mock_time): @patch("time.time", return_value=1000) def test_run_get_response_not_ok(self, mock_time): - http_client = DatabricksHttpClient.get_instance() + mock_http_client = MagicMock() settings = Mock(link_expiry_buffer_secs=0, download_timeout=0) settings.download_timeout = 0 settings.use_proxy = False result_link = Mock(expiryTime=1001) - with patch.object( - http_client, - "execute", - return_value=create_response(status_code=404, _content=b"1234"), - ): - d = downloader.ResultSetDownloadHandler( - settings, - result_link, - ssl_options=SSLOptions(), - chunk_id=0, - session_id_hex=Mock(), - statement_id=Mock(), - ) - with self.assertRaises(requests.exceptions.HTTPError) as context: - d.run() - self.assertTrue("404" in str(context.exception)) + # Setup mock HTTP response using helper method + self._setup_mock_http_response(mock_http_client, status=404, data=b"1234") + + d = downloader.ResultSetDownloadHandler( + settings, + result_link, + ssl_options=SSLOptions(), + chunk_id=0, + session_id_hex=Mock(), + statement_id=Mock(), + http_client=mock_http_client, + ) + with self.assertRaises(Exception) as context: + d.run() + self.assertTrue("404" in str(context.exception)) @patch("time.time") def test_run_uncompressed_successful(self, mock_time): self._setup_time_mock_for_download(mock_time, 1000.5) - http_client = DatabricksHttpClient.get_instance() + mock_http_client = MagicMock() file_bytes = b"1234567890" * 10 settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) settings.is_lz4_compressed = False settings.min_cloudfetch_download_speed = 1.0 - result_link = Mock(bytesNum=100, expiryTime=1001) - result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=abc123" + result_link = Mock(expiryTime=1001, bytesNum=len(file_bytes)) + result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789" + + # Setup mock HTTP response using helper method + 
+        self._setup_mock_http_response(mock_http_client, status=200, data=file_bytes)

-        with patch.object(
-            http_client,
-            "execute",
-            return_value=create_response(status_code=200, _content=file_bytes),
-        ):
+        # Patch the log metrics method to avoid division by zero
+        with patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'):
             d = downloader.ResultSetDownloadHandler(
                 settings,
                 result_link,
@@ -127,29 +139,32 @@
                 chunk_id=0,
                 session_id_hex=Mock(),
                 statement_id=Mock(),
+                http_client=mock_http_client,
             )
             file = d.run()
-
-        assert file.file_bytes == b"1234567890" * 10
+            self.assertEqual(file.file_bytes, file_bytes)
+            self.assertEqual(file.start_row_offset, result_link.startRowOffset)
+            self.assertEqual(file.row_count, result_link.rowCount)

     @patch("time.time")
     def test_run_compressed_successful(self, mock_time):
         self._setup_time_mock_for_download(mock_time, 1000.2)

-        http_client = DatabricksHttpClient.get_instance()
+        mock_http_client = MagicMock()
         file_bytes = b"1234567890" * 10
         compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
-
         settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
         settings.is_lz4_compressed = True
         settings.min_cloudfetch_download_speed = 1.0
-        result_link = Mock(bytesNum=100, expiryTime=1001)
+        result_link = Mock(expiryTime=1001, bytesNum=len(file_bytes))
         result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789"
-        with patch.object(
-            http_client,
-            "execute",
-            return_value=create_response(status_code=200, _content=compressed_bytes),
-        ):
+
+        # Setup mock HTTP response using helper method
+        self._setup_mock_http_response(mock_http_client, status=200, data=compressed_bytes)
+
+        # Mock the decompression method and log metrics to avoid issues
+        with patch.object(downloader.ResultSetDownloadHandler, '_decompress_data', return_value=file_bytes), \
+             patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'):
             d = downloader.ResultSetDownloadHandler(
                 settings,
                 result_link,
@@ -157,48 +172,53 @@
                 chunk_id=0,
                 session_id_hex=Mock(),
                 statement_id=Mock(),
+                http_client=mock_http_client,
             )
             file = d.run()
-
-        assert file.file_bytes == b"1234567890" * 10
+            self.assertEqual(file.file_bytes, file_bytes)
+            self.assertEqual(file.start_row_offset, result_link.startRowOffset)
+            self.assertEqual(file.row_count, result_link.rowCount)

     @patch("time.time", return_value=1000)
     def test_download_connection_error(self, mock_time):
-
-        http_client = DatabricksHttpClient.get_instance()
+        mock_http_client = MagicMock()
         settings = Mock(
             link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True
         )
         result_link = Mock(bytesNum=100, expiryTime=1001)

-        with patch.object(http_client, "execute", side_effect=ConnectionError("foo")):
-            d = downloader.ResultSetDownloadHandler(
-                settings,
-                result_link,
-                ssl_options=SSLOptions(),
-                chunk_id=0,
-                session_id_hex=Mock(),
-                statement_id=Mock(),
-            )
-            with self.assertRaises(ConnectionError):
-                d.run()
+        mock_http_client.request_context.side_effect = ConnectionError("foo")
+
+        d = downloader.ResultSetDownloadHandler(
+            settings,
+            result_link,
+            ssl_options=SSLOptions(),
+            chunk_id=0,
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            http_client=mock_http_client,
+        )
+        with self.assertRaises(ConnectionError):
+            d.run()

     @patch("time.time", return_value=1000)
     def test_download_timeout(self, mock_time):
-
-        http_client = DatabricksHttpClient.get_instance()
+        mock_http_client = MagicMock()
         settings = Mock(
             link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True
         )
         result_link = Mock(bytesNum=100, expiryTime=1001)

-        with patch.object(http_client, "execute", side_effect=TimeoutError("foo")):
-            d = downloader.ResultSetDownloadHandler(
-                settings,
-                result_link,
-                ssl_options=SSLOptions(),
-                chunk_id=0,
-                session_id_hex=Mock(),
-                statement_id=Mock(),
-            )
-            with self.assertRaises(TimeoutError):
-                d.run()
+        mock_http_client.request_context.side_effect = TimeoutError("foo")
+
+        d = downloader.ResultSetDownloadHandler(
+            settings,
+            result_link,
+            ssl_options=SSLOptions(),
+            chunk_id=0,
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            http_client=mock_http_client,
+        )
+        with self.assertRaises(TimeoutError):
+            d.run()
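These tests stub `request_context`, which the downloader consumes roughly as shown in the
downloader.py hunk later in this series. A simplified, standalone sketch of that call shape
(error handling abbreviated; `fetch` is a hypothetical name, not connector API):

    # Sketch: how a request_context-style client is typically consumed.
    def fetch(http_client, url, headers=None, timeout=60):
        with http_client.request_context(
            method="GET", url=url, timeout=timeout, headers=headers
        ) as response:
            if response.status >= 400:
                raise Exception(f"HTTP {response.status}: {response.data.decode()}")
            return response.data  # raw bytes of the downloaded file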
"test-session" + mock_http_client = MagicMock() TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=False, @@ -241,6 +252,7 @@ def test_disabled_telemetry_flow(self): auth_provider=None, host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, + http_client=mock_http_client, ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) @@ -249,6 +261,7 @@ def test_disabled_telemetry_flow(self): def test_factory_error_handling(self): """Test that factory errors fall back to NoopTelemetryClient.""" session_id = "test-session" + mock_http_client = MagicMock() # Simulate initialization error with patch( @@ -261,6 +274,7 @@ def test_factory_error_handling(self): auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, + http_client=mock_http_client, ) # Should fall back to NoopTelemetryClient @@ -271,6 +285,7 @@ def test_factory_shutdown_flow(self): """Test factory shutdown when last client is removed.""" session1 = "session-1" session2 = "session-2" + mock_http_client = MagicMock() # Initialize multiple clients for session in [session1, session2]: @@ -280,6 +295,7 @@ def test_factory_shutdown_flow(self): auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, + http_client=mock_http_client, ) # Factory should be initialized @@ -325,10 +341,11 @@ def test_connection_failure_sends_correct_telemetry_payload( class TestTelemetryFeatureFlag: """Tests the interaction between the telemetry feature flag and connection parameters.""" - def _mock_ff_response(self, mock_requests_get, enabled: bool): - """Helper to configure the mock response for the feature flag endpoint.""" + def _mock_ff_response(self, mock_http_request, enabled: bool): + """Helper method to mock feature flag response for unified HTTP client.""" mock_response = MagicMock() - mock_response.status_code = 200 + mock_response.status = 200 + mock_response.status_code = 200 # Compatibility attribute payload = { "flags": [ { @@ -339,15 +356,21 @@ def _mock_ff_response(self, mock_requests_get, enabled: bool): "ttl_seconds": 3600, } mock_response.json.return_value = payload - mock_requests_get.return_value = mock_response + mock_response.data = json.dumps(payload).encode() + mock_http_request.return_value = mock_response - @patch("databricks.sql.common.feature_flag.requests.get") - def test_telemetry_enabled_when_flag_is_true(self, mock_requests_get, MockSession): + @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") + def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSession): """Telemetry should be ON when enable_telemetry=True and server flag is 'true'.""" - self._mock_ff_response(mock_requests_get, enabled=True) + self._mock_ff_response(mock_http_request, enabled=True) mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + + # Set up mock HTTP client on the session + mock_http_client = MagicMock() + mock_http_client.request = mock_http_request + mock_session_instance.http_client = mock_http_client conn = sql.client.Connection( server_hostname="test", @@ -357,19 +380,24 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_requests_get, MockSessio ) assert conn.telemetry_enabled is True - mock_requests_get.assert_called_once() + mock_http_request.assert_called_once() client = 
         assert isinstance(client, TelemetryClient)

-    @patch("databricks.sql.common.feature_flag.requests.get")
+    @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request")
     def test_telemetry_disabled_when_flag_is_false(
-        self, mock_requests_get, MockSession
+        self, mock_http_request, MockSession
     ):
         """Telemetry should be OFF when enable_telemetry=True but server flag is 'false'."""
-        self._mock_ff_response(mock_requests_get, enabled=False)
+        self._mock_ff_response(mock_http_request, enabled=False)
         mock_session_instance = MockSession.return_value
         mock_session_instance.guid_hex = "test-session-ff-false"
         mock_session_instance.auth_provider = AccessTokenAuthProvider("token")
+
+        # Set up mock HTTP client on the session
+        mock_http_client = MagicMock()
+        mock_http_client.request = mock_http_request
+        mock_session_instance.http_client = mock_http_client

         conn = sql.client.Connection(
             server_hostname="test",
@@ -379,19 +407,24 @@
         )

         assert conn.telemetry_enabled is False
-        mock_requests_get.assert_called_once()
+        mock_http_request.assert_called_once()
         client = TelemetryClientFactory.get_telemetry_client("test-session-ff-false")
         assert isinstance(client, NoopTelemetryClient)

-    @patch("databricks.sql.common.feature_flag.requests.get")
+    @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request")
     def test_telemetry_disabled_when_flag_request_fails(
-        self, mock_requests_get, MockSession
+        self, mock_http_request, MockSession
     ):
         """Telemetry should default to OFF if the feature flag network request fails."""
-        mock_requests_get.side_effect = Exception("Network is down")
+        mock_http_request.side_effect = Exception("Network is down")
         mock_session_instance = MockSession.return_value
         mock_session_instance.guid_hex = "test-session-ff-fail"
         mock_session_instance.auth_provider = AccessTokenAuthProvider("token")
+
+        # Set up mock HTTP client on the session
+        mock_http_client = MagicMock()
+        mock_http_client.request = mock_http_request
+        mock_session_instance.http_client = mock_http_client

         conn = sql.client.Connection(
             server_hostname="test",
@@ -401,6 +434,6 @@
         )

         assert conn.telemetry_enabled is False
-        mock_requests_get.assert_called_once()
+        mock_http_request.assert_called_once()
         client = TelemetryClientFactory.get_telemetry_client("test-session-ff-fail")
         assert isinstance(client, NoopTelemetryClient)
diff --git a/tests/unit/test_telemetry_retry.py b/tests/unit/test_telemetry_retry.py
index d5287deb9..f0bdddd60 100644
--- a/tests/unit/test_telemetry_retry.py
+++ b/tests/unit/test_telemetry_retry.py
@@ -6,27 +6,23 @@
 from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
 from databricks.sql.auth.retry import DatabricksRetryPolicy

-PATCH_TARGET = "urllib3.connectionpool.HTTPSConnectionPool._get_conn"
+PATCH_TARGET = "databricks.sql.common.unified_http_client.UnifiedHttpClient.request"


-def create_mock_conn(responses):
-    """Creates a mock connection object whose getresponse() method yields a series of responses."""
-    mock_conn = MagicMock()
-    mock_http_responses = []
+def create_mock_response(responses):
+    """Creates mock urllib3 HTTPResponse objects for the given response specifications."""
+    mock_responses = []
     for resp in responses:
-        mock_http_response = MagicMock()
-        mock_http_response.status = resp.get("status")
-        mock_http_response.headers = resp.get("headers", {})
-        body = resp.get("body", b"{}")
resp.get("body", b"{}") - mock_http_response.fp = io.BytesIO(body) - - def release(): - mock_http_response.fp.close() - - mock_http_response.release_conn = release - mock_http_responses.append(mock_http_response) - mock_conn.getresponse.side_effect = mock_http_responses - return mock_conn + mock_response = MagicMock() + mock_response.status = resp.get("status") + mock_response.status_code = resp.get("status") # Add status_code for compatibility + mock_response.headers = resp.get("headers", {}) + mock_response.data = resp.get("body", b"{}") + mock_response.ok = resp.get("status", 200) < 400 + mock_response.text = resp.get("body", b"{}").decode() if isinstance(resp.get("body", b"{}"), bytes) else str(resp.get("body", "{}")) + mock_response.json = lambda: {} # Simple json mock + mock_responses.append(mock_response) + return mock_responses class TestTelemetryClientRetries: @@ -43,30 +39,16 @@ def setup_and_teardown(self): TelemetryClientFactory._executor = None def get_client(self, session_id, num_retries=3): - """ - Configures a client with a specific number of retries. - """ + mock_http_client = MagicMock() TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id, auth_provider=None, host_url="test.databricks.com", - batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - ) - client = TelemetryClientFactory.get_telemetry_client(session_id) - - retry_policy = DatabricksRetryPolicy( - delay_min=0.01, - delay_max=0.02, - stop_after_attempts_duration=2.0, - stop_after_attempts_count=num_retries, - delay_default=0.1, - force_dangerous_codes=[], - urllib3_kwargs={"total": num_retries}, + batch_size=1, # Use batch size of 1 to trigger immediate HTTP requests + http_client=mock_http_client, ) - adapter = client._http_client.session.adapters.get("https://") - adapter.max_retries = retry_policy - return client + return TelemetryClientFactory.get_telemetry_client(session_id) @pytest.mark.parametrize( "status_code, description", @@ -85,13 +67,19 @@ def test_non_retryable_status_codes_are_not_retried(self, status_code, descripti client = self.get_client(f"session-{status_code}") mock_responses = [{"status": status_code}] - with patch( - PATCH_TARGET, return_value=create_mock_conn(mock_responses) - ) as mock_get_conn: + mock_response = create_mock_response(mock_responses)[0] + with patch(PATCH_TARGET, return_value=mock_response) as mock_request: client.export_failure_log("TestError", "Test message") + + # Wait a moment for async operations to complete + time.sleep(0.1) + TelemetryClientFactory.close(client._session_id_hex) + + # Wait a bit more for any final operations + time.sleep(0.1) - mock_get_conn.return_value.getresponse.assert_called_once() + mock_request.assert_called_once() def test_exceeds_retry_count_limit(self): """ @@ -103,22 +91,28 @@ def test_exceeds_retry_count_limit(self): retry_after = 1 client = self.get_client("session-exceed-limit", num_retries=num_retries) mock_responses = [ - {"status": 503, "headers": {"Retry-After": str(retry_after)}}, - {"status": 429}, + {"status": 429, "headers": {"Retry-After": str(retry_after)}}, {"status": 502}, {"status": 503}, + {"status": 200}, ] - with patch( - PATCH_TARGET, return_value=create_mock_conn(mock_responses) - ) as mock_get_conn: + mock_response_objects = create_mock_response(mock_responses) + with patch(PATCH_TARGET, side_effect=mock_response_objects) as mock_request: start_time = time.time() client.export_failure_log("TestError", "Test message") + + # Wait for async operations to complete + 
+            time.sleep(0.2)
+
+            TelemetryClientFactory.close(client._session_id_hex)
+
+            # Wait for any final operations
+            time.sleep(0.2)
+
             end_time = time.time()

             assert (
-                mock_get_conn.return_value.getresponse.call_count
+                mock_request.call_count
                 == expected_total_calls
             )
-            assert end_time - start_time > retry_after

From 429460082749de360c9e86e55772f093deeca05e Mon Sep 17 00:00:00 2001
From: Vikrant Puppala
Date: Fri, 8 Aug 2025 19:25:27 +0530
Subject: [PATCH 03/25] Fix all tests

Signed-off-by: Vikrant Puppala
---
 src/databricks/sql/client.py       |   2 +-
 tests/unit/test_auth.py            |   2 +-
 tests/unit/test_sea_queue.py       |  23 +++++-
 tests/unit/test_session.py         |   3 +-
 tests/unit/test_telemetry_retry.py | 118 -----------------------------
 5 files changed, 23 insertions(+), 125 deletions(-)
 delete mode 100644 tests/unit/test_telemetry_retry.py

diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index 50f252dbc..7323b939a 100755
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -443,7 +443,7 @@ def get_protocol_version(openSessionResp: TOpenSessionResp):
     @property
     def open(self) -> bool:
         """Return whether the connection is open by checking if the session is open."""
-        return self.session.is_open
+        return hasattr(self, 'session') and self.session.is_open

     def cursor(
         self,
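The hasattr guard above matters because Connection.__init__ can now fail before
self.session is assigned (for example while constructing the HTTP client), and cleanup
paths may still consult .open on the half-built object. A standalone illustration of the
failure mode (not connector code):

    # Illustration only: why `open` checks hasattr before touching self.session.
    class Fragile:
        def __init__(self, fail=False):
            if fail:
                raise RuntimeError("init failed before session was set")
            self.session = object()

        @property
        def open(self):
            # Without the hasattr check, cleanup of a partially initialized
            # instance would raise AttributeError here.
            return hasattr(self, "session") and self.session is not None

    try:
        Fragile(fail=True)
    except RuntimeError:
        pass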
diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py
index 2e210a9e0..333782fd8 100644
--- a/tests/unit/test_auth.py
+++ b/tests/unit/test_auth.py
@@ -294,7 +294,7 @@ def test_get_token_failure(self, token_source, http_response):
             mock_http_client.execute.return_value.__enter__.return_value = mock_response
             mock_http_client.execute.return_value.__exit__.return_value = None

-            with pytest.raises(Exception):
+            with pytest.raises(Exception) as e:
                 token_source.get_token()
             assert "Failed to get token: 400" in str(e.value)

diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py
index cbeae098b..6471cb4fd 100644
--- a/tests/unit/test_sea_queue.py
+++ b/tests/unit/test_sea_queue.py
@@ -7,7 +7,7 @@
 """

 import pytest
-from unittest.mock import Mock, patch
+from unittest.mock import Mock, patch, MagicMock

 from databricks.sql.backend.sea.queue import (
     JsonQueue,
@@ -184,6 +184,7 @@ def description(self):
     def test_build_queue_json_array(self, json_manifest, sample_data):
         """Test building a JSON array queue."""
         result_data = ResultData(data=sample_data)
+        mock_http_client = MagicMock()

         queue = SeaResultSetQueueFactory.build_queue(
             result_data=result_data,
@@ -194,6 +195,7 @@
             max_download_threads=10,
             sea_client=Mock(),
             lz4_compressed=False,
+            http_client=mock_http_client,
         )

         assert isinstance(queue, JsonQueue)
@@ -217,6 +219,8 @@ def test_build_queue_arrow_stream(
         ]
         result_data = ResultData(data=None, external_links=external_links)

+        mock_http_client = MagicMock()
+
         with patch(
             "databricks.sql.backend.sea.queue.ResultFileDownloadManager"
         ), patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None):
@@ -229,6 +233,7 @@
                 max_download_threads=10,
                 sea_client=mock_sea_client,
                 lz4_compressed=False,
+                http_client=mock_http_client,
             )

         assert isinstance(queue, SeaCloudFetchQueue)
@@ -236,6 +241,7 @@
     def test_build_queue_invalid_format(self, invalid_manifest):
         """Test building a queue with invalid format."""
         result_data = ResultData(data=[])
+        mock_http_client = MagicMock()

         with pytest.raises(ProgrammingError, match="Invalid result format"):
             SeaResultSetQueueFactory.build_queue(
@@ -247,6 +253,7 @@
                 max_download_threads=10,
                 sea_client=Mock(),
                 lz4_compressed=False,
+                http_client=mock_http_client,
             )


@@ -339,6 +346,7 @@ def test_init_with_valid_initial_link(
     ):
         """Test initialization with valid initial link."""
         # Create a queue with valid initial link
+        mock_http_client = MagicMock()
         with patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None):
             queue = SeaCloudFetchQueue(
                 result_data=ResultData(external_links=[sample_external_link]),
@@ -349,6 +357,7 @@
                 total_chunk_count=1,
                 lz4_compressed=False,
                 description=description,
+                http_client=mock_http_client,
             )

         # Verify attributes
@@ -367,6 +376,7 @@ def test_init_no_initial_links(
     ):
         """Test initialization with no initial links."""
         # Create a queue with empty initial links
+        mock_http_client = MagicMock()
         queue = SeaCloudFetchQueue(
             result_data=ResultData(external_links=[]),
             max_download_threads=5,
@@ -376,6 +386,7 @@
             total_chunk_count=0,
             lz4_compressed=False,
             description=description,
+            http_client=mock_http_client,
         )

         assert queue.table is None
@@ -462,7 +473,7 @@ def test_hybrid_disposition_with_attachment(
         # Create result data with attachment
         attachment_data = b"mock_arrow_data"
         result_data = ResultData(attachment=attachment_data)
-
+        mock_http_client = MagicMock()
         # Build queue
         queue = SeaResultSetQueueFactory.build_queue(
             result_data=result_data,
@@ -473,6 +484,7 @@
             max_download_threads=10,
             sea_client=mock_sea_client,
             lz4_compressed=False,
+            http_client=mock_http_client,
         )

         # Verify ArrowQueue was created
@@ -508,7 +520,8 @@ def test_hybrid_disposition_with_external_links(
         # Create result data with external links but no attachment
         result_data = ResultData(external_links=external_links, attachment=None)

-        # Build queue
+        # Build queue
+        mock_http_client = MagicMock()
         queue = SeaResultSetQueueFactory.build_queue(
             result_data=result_data,
             manifest=arrow_manifest,
@@ -518,6 +531,7 @@
             max_download_threads=10,
             sea_client=mock_sea_client,
             lz4_compressed=False,
+            http_client=mock_http_client,
         )

         # Verify SeaCloudFetchQueue was created
@@ -548,7 +562,7 @@ def test_hybrid_disposition_with_compressed_attachment(
         # Create result data with attachment
         result_data = ResultData(attachment=compressed_data)
-
+        mock_http_client = MagicMock()
         # Build queue with lz4_compressed=True
         queue = SeaResultSetQueueFactory.build_queue(
             result_data=result_data,
@@ -559,6 +573,7 @@
             max_download_threads=10,
             sea_client=mock_sea_client,
             lz4_compressed=True,
+            http_client=mock_http_client,
         )

         # Verify ArrowQueue was created with decompressed data
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index 6823b1b33..e019e05a2 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -75,8 +75,9 @@ def test_http_header_passthrough(self, mock_client_class):
         call_kwargs = mock_client_class.call_args[1]
         assert ("foo", "bar") in call_kwargs["http_headers"]

+    @patch("%s.client.UnifiedHttpClient" % PACKAGE_NAME)
     @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
-    def test_tls_arg_passthrough(self, mock_client_class):
+    def test_tls_arg_passthrough(self, mock_client_class, mock_http_client):
         databricks.sql.connect(
             **self.DUMMY_CONNECTION_ARGS,
             _tls_verify_hostname="hostname",
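The dedicated retry test file is deleted below: with retries now handled inside the
unified client's urllib3 pool rather than by swapping a requests adapter, the old
connection-level hooks no longer exist. For orientation, a roughly equivalent urllib3
retry configuration looks like this (a sketch only; the connector's DatabricksRetryPolicy
configures more than is shown here):

    # Sketch: urllib3-level retry setup comparable to what the pool manager uses.
    import urllib3
    from urllib3.util.retry import Retry

    retry = Retry(
        total=3,                           # give up after N attempts
        status_forcelist=[429, 502, 503],  # statuses worth retrying
        respect_retry_after_header=True,   # honor Retry-After on 429/503
        backoff_factor=0.1,                # exponential backoff between tries
    )
    http = urllib3.PoolManager(retries=retry)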
diff --git a/tests/unit/test_telemetry_retry.py b/tests/unit/test_telemetry_retry.py
deleted file mode 100644
index f0bdddd60..000000000
--- a/tests/unit/test_telemetry_retry.py
+++ /dev/null
@@ -1,118 +0,0 @@
-import pytest
-from unittest.mock import patch, MagicMock
-import io
-import time
-
-from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
-from databricks.sql.auth.retry import DatabricksRetryPolicy
-
-PATCH_TARGET = "databricks.sql.common.unified_http_client.UnifiedHttpClient.request"
-
-
-def create_mock_response(responses):
-    """Creates mock urllib3 HTTPResponse objects for the given response specifications."""
-    mock_responses = []
-    for resp in responses:
-        mock_response = MagicMock()
-        mock_response.status = resp.get("status")
-        mock_response.status_code = resp.get("status")  # Add status_code for compatibility
-        mock_response.headers = resp.get("headers", {})
-        mock_response.data = resp.get("body", b"{}")
-        mock_response.ok = resp.get("status", 200) < 400
-        mock_response.text = resp.get("body", b"{}").decode() if isinstance(resp.get("body", b"{}"), bytes) else str(resp.get("body", "{}"))
-        mock_response.json = lambda: {}  # Simple json mock
-        mock_responses.append(mock_response)
-    return mock_responses
-
-
-class TestTelemetryClientRetries:
-    @pytest.fixture(autouse=True)
-    def setup_and_teardown(self):
-        TelemetryClientFactory._initialized = False
-        TelemetryClientFactory._clients = {}
-        TelemetryClientFactory._executor = None
-        yield
-        if TelemetryClientFactory._executor:
-            TelemetryClientFactory._executor.shutdown(wait=True)
-        TelemetryClientFactory._initialized = False
-        TelemetryClientFactory._clients = {}
-        TelemetryClientFactory._executor = None
-
-    def get_client(self, session_id, num_retries=3):
-        mock_http_client = MagicMock()
-        TelemetryClientFactory.initialize_telemetry_client(
-            telemetry_enabled=True,
-            session_id_hex=session_id,
-            auth_provider=None,
-            host_url="test.databricks.com",
-            batch_size=1,  # Use batch size of 1 to trigger immediate HTTP requests
-            http_client=mock_http_client,
-        )
-        return TelemetryClientFactory.get_telemetry_client(session_id)
-
-    @pytest.mark.parametrize(
-        "status_code, description",
-        [
-            (401, "Unauthorized"),
-            (403, "Forbidden"),
-            (501, "Not Implemented"),
-            (200, "Success"),
-        ],
-    )
-    def test_non_retryable_status_codes_are_not_retried(self, status_code, description):
-        """
-        Verifies that terminal error codes (401, 403, 501) and success codes (200) are not retried.
-        """
-        # Use the status code in the session ID for easier debugging if it fails
-        client = self.get_client(f"session-{status_code}")
-        mock_responses = [{"status": status_code}]
-
-        mock_response = create_mock_response(mock_responses)[0]
-        with patch(PATCH_TARGET, return_value=mock_response) as mock_request:
-            client.export_failure_log("TestError", "Test message")
-
-            # Wait a moment for async operations to complete
-            time.sleep(0.1)
-
-            TelemetryClientFactory.close(client._session_id_hex)
-
-            # Wait a bit more for any final operations
-            time.sleep(0.1)
-
-            mock_request.assert_called_once()
-
-    def test_exceeds_retry_count_limit(self):
-        """
-        Verifies that the client retries up to the specified number of times before giving up.
-        Verifies that the client respects the Retry-After header and retries on 429, 502, 503.
- """ - num_retries = 3 - expected_total_calls = num_retries + 1 - retry_after = 1 - client = self.get_client("session-exceed-limit", num_retries=num_retries) - mock_responses = [ - {"status": 429, "headers": {"Retry-After": str(retry_after)}}, - {"status": 502}, - {"status": 503}, - {"status": 200}, - ] - - mock_response_objects = create_mock_response(mock_responses) - with patch(PATCH_TARGET, side_effect=mock_response_objects) as mock_request: - start_time = time.time() - client.export_failure_log("TestError", "Test message") - - # Wait for async operations to complete - time.sleep(0.2) - - TelemetryClientFactory.close(client._session_id_hex) - - # Wait for any final operations - time.sleep(0.2) - - end_time = time.time() - - assert ( - mock_request.call_count - == expected_total_calls - ) From 31552117d01160d59980a201a5c47d7135eb4040 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Fri, 8 Aug 2025 19:27:20 +0530 Subject: [PATCH 04/25] fmt Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/common.py | 12 +- src/databricks/sql/auth/oauth.py | 4 +- src/databricks/sql/client.py | 34 ++++-- src/databricks/sql/cloudfetch/downloader.py | 4 +- src/databricks/sql/common/feature_flag.py | 12 +- .../sql/common/unified_http_client.py | 109 +++++++++--------- src/databricks/sql/session.py | 4 +- .../sql/telemetry/telemetry_client.py | 12 +- 8 files changed, 108 insertions(+), 83 deletions(-) diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 262166a52..61b07cb91 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -65,14 +65,16 @@ def __init__( self.tls_client_cert_file = tls_client_cert_file self.oauth_persistence = oauth_persistence self.credentials_provider = credentials_provider - + # HTTP client configuration self.ssl_options = ssl_options self.socket_timeout = socket_timeout self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 30 self.retry_delay_min = retry_delay_min or 1.0 self.retry_delay_max = retry_delay_max or 60.0 - self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration or 900.0 + self.retry_stop_after_attempts_duration = ( + retry_stop_after_attempts_duration or 900.0 + ) self.retry_delay_default = retry_delay_default or 5.0 self.retry_dangerous_codes = retry_dangerous_codes or [] self.http_proxy = http_proxy @@ -110,8 +112,8 @@ def get_azure_tenant_id_from_host(host: str, http_client) -> str: login_url = f"{host}/aad/auth" logger.debug("Loading tenant ID from %s", login_url) - - with http_client.request_context('GET', login_url, allow_redirects=False) as resp: + + with http_client.request_context("GET", login_url, allow_redirects=False) as resp: if resp.status // 100 != 3: raise ValueError( f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}" @@ -119,7 +121,7 @@ def get_azure_tenant_id_from_host(host: str, http_client) -> str: entra_id_endpoint = dict(resp.headers).get("Location") if entra_id_endpoint is None: raise ValueError(f"No Location header in response from {login_url}") - + # The Location header has the following form: https://login.microsoftonline.com//oauth2/authorize?... # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud). 
diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py
index 270287953..7f96a2303 100644
--- a/src/databricks/sql/auth/oauth.py
+++ b/src/databricks/sql/auth/oauth.py
@@ -87,7 +87,7 @@ def __fetch_well_known_config(self, hostname: str):
         known_config_url = self.idp_endpoint.get_openid_config_url(hostname)

         try:
-            response = self.http_client.request('GET', url=known_config_url)
+            response = self.http_client.request("GET", url=known_config_url)
             # Convert urllib3 response to requests-like response for compatibility
             response.status_code = response.status
             response.json = lambda: json.loads(response.data.decode())
@@ -198,7 +198,7 @@ def __send_token_request(token_request_url, data):
         }
         # Use unified HTTP client
         response = self.http_client.request(
-            'POST', url=token_request_url, body=data, headers=headers
+            "POST", url=token_request_url, body=data, headers=headers
         )
         # Convert urllib3 response to dict for compatibility
         return json.loads(response.data.decode())
diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index 7323b939a..1a35f97da 100755
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -354,7 +354,7 @@ def _build_client_context(self, server_hostname: str, **kwargs):
         """Build ClientContext for HTTP client configuration."""
         from databricks.sql.auth.common import ClientContext
         from databricks.sql.types import SSLOptions
-
+
         # Extract SSL options
         ssl_options = SSLOptions(
             tls_verify=not kwargs.get("_tls_no_verify", False),
@@ -364,22 +364,26 @@ def _build_client_context(self, server_hostname: str, **kwargs):
             tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
             tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
         )
-
+
         # Build user agent
         user_agent_entry = kwargs.get("user_agent_entry", "")
         if user_agent_entry:
             user_agent = f"PyDatabricksSqlConnector/{__version__} ({user_agent_entry})"
         else:
             user_agent = f"PyDatabricksSqlConnector/{__version__}"
-
+
         return ClientContext(
             hostname=server_hostname,
             ssl_options=ssl_options,
             socket_timeout=kwargs.get("_socket_timeout"),
-            retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count", 30),
+            retry_stop_after_attempts_count=kwargs.get(
+                "_retry_stop_after_attempts_count", 30
+            ),
             retry_delay_min=kwargs.get("_retry_delay_min", 1.0),
             retry_delay_max=kwargs.get("_retry_delay_max", 60.0),
-            retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration", 900.0),
+            retry_stop_after_attempts_duration=kwargs.get(
+                "_retry_stop_after_attempts_duration", 900.0
+            ),
             retry_delay_default=kwargs.get("_retry_delay_default", 1.0),
             retry_dangerous_codes=kwargs.get("_retry_dangerous_codes", []),
             http_proxy=kwargs.get("_http_proxy"),
@@ -443,7 +447,7 @@ def get_protocol_version(openSessionResp: TOpenSessionResp):
     @property
     def open(self) -> bool:
         """Return whether the connection is open by checking if the session is open."""
-        return hasattr(self, 'session') and self.session.is_open
+        return hasattr(self, "session") and self.session.is_open

     def cursor(
         self,
@@ -792,10 +796,12 @@ def _handle_staging_put(
         )

         with open(local_file, "rb") as fh:
-            r = self.connection.session.http_client.request('PUT', presigned_url, body=fh.read(), headers=headers)
+            r = self.connection.session.http_client.request(
+                "PUT", presigned_url, body=fh.read(), headers=headers
+            )
         # Add compatibility attributes for urllib3 response
         r.status_code = r.status
-        if hasattr(r, 'data'):
+        if hasattr(r, "data"):
             r.content = r.data
             r.ok = r.status < 400
             r.text = r.data.decode() if r.data else ""
@@ -835,10 +841,12 @@ def _handle_staging_get(
             session_id_hex=self.connection.get_session_id_hex(),
         )

-        r = self.connection.session.http_client.request('GET', presigned_url, headers=headers)
+        r = self.connection.session.http_client.request(
+            "GET", presigned_url, headers=headers
+        )
         # Add compatibility attributes for urllib3 response
         r.status_code = r.status
-        if hasattr(r, 'data'):
+        if hasattr(r, "data"):
             r.content = r.data
             r.ok = r.status < 400
             r.text = r.data.decode() if r.data else ""
@@ -860,10 +868,12 @@ def _handle_staging_remove(
     ):
         """Make an HTTP DELETE request to the presigned_url"""

-        r = self.connection.session.http_client.request('DELETE', presigned_url, headers=headers)
+        r = self.connection.session.http_client.request(
+            "DELETE", presigned_url, headers=headers
+        )
         # Add compatibility attributes for urllib3 response
         r.status_code = r.status
-        if hasattr(r, 'data'):
+        if hasattr(r, "data"):
             r.content = r.data
             r.ok = r.status < 400
             r.text = r.data.decode() if r.data else ""
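The status_code/content/ok/text block above is repeated for each staging call. Its intent,
expressed as a standalone helper (a sketch only; the patch itself inlines the logic rather
than introducing such a function):

    # Sketch: adapting a urllib3 HTTPResponse so requests-style callers keep working.
    def as_requests_like(r):
        """Attach requests-style attributes to a urllib3 HTTPResponse in place."""
        r.status_code = r.status
        if hasattr(r, "data"):
            r.content = r.data
            r.ok = r.status < 400
            r.text = r.data.decode() if r.data else ""
        return r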
""" - def __init__(self, connection: "Connection", executor: ThreadPoolExecutor, http_client): + def __init__( + self, connection: "Connection", executor: ThreadPoolExecutor, http_client + ): from databricks.sql import __version__ self._connection = connection @@ -65,7 +67,7 @@ def __init__(self, connection: "Connection", executor: ThreadPoolExecutor, http_ self._feature_flag_endpoint = ( f"https://{self._connection.session.host}{endpoint_suffix}" ) - + # Use the provided HTTP client self._http_client = http_client @@ -109,7 +111,7 @@ def _refresh_flags(self): headers["User-Agent"] = self._connection.session.useragent_header response = self._http_client.request( - 'GET', self._feature_flag_endpoint, headers=headers, timeout=30 + "GET", self._feature_flag_endpoint, headers=headers, timeout=30 ) # Add compatibility attributes for urllib3 response response.status_code = response.status @@ -165,7 +167,9 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: # Use the unique session ID as the key key = connection.get_session_id_hex() if key not in cls._context_map: - cls._context_map[key] = FeatureFlagsContext(connection, cls._executor, connection.session.http_client) + cls._context_map[key] = FeatureFlagsContext( + connection, cls._executor, connection.session.http_client + ) return cls._context_map[key] @classmethod diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index 8c3be2bfd..a296704b4 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -18,7 +18,7 @@ class UnifiedHttpClient: """ Unified HTTP client for all Databricks SQL connector HTTP operations. - + This client uses urllib3 for robust HTTP communication with retry policies, connection pooling, SSL support, and proxy support. It replaces the various singleton HTTP clients and direct requests usage throughout the codebase. 
@@ -37,12 +37,12 @@ def __init__(self, client_context):

     def _setup_pool_manager(self):
         """Set up the urllib3 PoolManager with configuration from ClientContext."""
-
+
         # SSL context setup
         ssl_context = None
         if self.config.ssl_options:
             ssl_context = ssl.create_default_context()
-
+
             # Configure SSL verification
             if not self.config.ssl_options.tls_verify:
                 ssl_context.check_hostname = False
@@ -50,18 +50,22 @@ def _setup_pool_manager(self):
             elif not self.config.ssl_options.tls_verify_hostname:
                 ssl_context.check_hostname = False
                 ssl_context.verify_mode = ssl.CERT_REQUIRED
-
+
             # Load custom CA file if specified
             if self.config.ssl_options.tls_trusted_ca_file:
-                ssl_context.load_verify_locations(self.config.ssl_options.tls_trusted_ca_file)
-
+                ssl_context.load_verify_locations(
+                    self.config.ssl_options.tls_trusted_ca_file
+                )
+
             # Load client certificate if specified
-            if (self.config.ssl_options.tls_client_cert_file and
-                self.config.ssl_options.tls_client_cert_key_file):
+            if (
+                self.config.ssl_options.tls_client_cert_file
+                and self.config.ssl_options.tls_client_cert_key_file
+            ):
                 ssl_context.load_cert_chain(
                     self.config.ssl_options.tls_client_cert_file,
                     self.config.ssl_options.tls_client_cert_key_file,
-                    self.config.ssl_options.tls_client_cert_key_password
+                    self.config.ssl_options.tls_client_cert_key_password,
                 )

         # Create retry policy
@@ -76,14 +80,15 @@ def _setup_pool_manager(self):

         # Common pool manager kwargs
         pool_kwargs = {
-            'num_pools': self.config.pool_connections,
-            'maxsize': self.config.pool_maxsize,
-            'retries': retry_policy,
-            'timeout': urllib3.Timeout(
-                connect=self.config.socket_timeout,
-                read=self.config.socket_timeout
-            ) if self.config.socket_timeout else None,
-            'ssl_context': ssl_context,
+            "num_pools": self.config.pool_connections,
+            "maxsize": self.config.pool_maxsize,
+            "retries": retry_policy,
+            "timeout": urllib3.Timeout(
+                connect=self.config.socket_timeout, read=self.config.socket_timeout
+            )
+            if self.config.socket_timeout
+            else None,
+            "ssl_context": ssl_context,
         }

         # Create proxy or regular pool manager
@@ -93,58 +98,51 @@ def _setup_pool_manager(self):
             proxy_headers = make_headers(
                 proxy_basic_auth=f"{self.config.proxy_username}:{self.config.proxy_password}"
             )
-
+
             self._pool_manager = ProxyManager(
-                self.config.http_proxy,
-                proxy_headers=proxy_headers,
-                **pool_kwargs
+                self.config.http_proxy, proxy_headers=proxy_headers, **pool_kwargs
             )
         else:
             self._pool_manager = PoolManager(**pool_kwargs)

-    def _prepare_headers(self, headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
+    def _prepare_headers(
+        self, headers: Optional[Dict[str, str]] = None
+    ) -> Dict[str, str]:
         """Prepare headers for the request, including User-Agent."""
         request_headers = {}
-
+
         if self.config.user_agent:
-            request_headers['User-Agent'] = self.config.user_agent
-
+            request_headers["User-Agent"] = self.config.user_agent
+
         if headers:
             request_headers.update(headers)
-
+
         return request_headers

     @contextmanager
     def request_context(
-        self,
-        method: str,
-        url: str,
-        headers: Optional[Dict[str, str]] = None,
-        **kwargs
+        self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs
     ) -> Generator[urllib3.HTTPResponse, None, None]:
         """
         Context manager for making HTTP requests with proper resource cleanup.
- + Args: method: HTTP method (GET, POST, PUT, DELETE) url: URL to request headers: Optional headers dict **kwargs: Additional arguments passed to urllib3 request - + Yields: urllib3.HTTPResponse: The HTTP response object """ logger.debug("Making %s request to %s", method, url) - + request_headers = self._prepare_headers(headers) response = None - + try: response = self._pool_manager.request( - method=method, - url=url, - headers=request_headers, - **kwargs + method=method, url=url, headers=request_headers, **kwargs ) yield response except MaxRetryError as e: @@ -157,16 +155,18 @@ def request_context( if response: response.close() - def request(self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs) -> urllib3.HTTPResponse: + def request( + self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs + ) -> urllib3.HTTPResponse: """ Make an HTTP request. - + Args: method: HTTP method (GET, POST, PUT, DELETE, etc.) url: URL to request headers: Optional headers dict **kwargs: Additional arguments passed to urllib3 request - + Returns: urllib3.HTTPResponse: The HTTP response object with data pre-loaded """ @@ -175,32 +175,36 @@ def request(self, method: str, url: str, headers: Optional[Dict[str, str]] = Non response._body = response.data return response - def upload_file(self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None) -> urllib3.HTTPResponse: + def upload_file( + self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None + ) -> urllib3.HTTPResponse: """ Upload a file using PUT method. - + Args: url: URL to upload to file_path: Path to the file to upload headers: Optional headers - + Returns: urllib3.HTTPResponse: The response from the server """ - with open(file_path, 'rb') as file_obj: - return self.request('PUT', url, body=file_obj.read(), headers=headers) + with open(file_path, "rb") as file_obj: + return self.request("PUT", url, body=file_obj.read(), headers=headers) - def download_file(self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None) -> None: + def download_file( + self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None + ) -> None: """ Download a file using GET method. - + Args: url: URL to download from file_path: Path where to save the downloaded file headers: Optional headers """ - response = self.request('GET', url, headers=headers) - with open(file_path, 'wb') as file_obj: + response = self.request("GET", url, headers=headers) + with open(file_path, "wb") as file_obj: file_obj.write(response.data) def close(self): @@ -222,5 +226,6 @@ class IgnoreNetrcAuth: Compatibility class for OAuth code that expects requests.auth.AuthBase interface. This is a no-op auth handler since OAuth handles auth differently. 
""" + def __call__(self, request): - return request \ No newline at end of file + return request diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index c9b4f939a..0cba8be48 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -193,7 +193,7 @@ def close(self) -> None: logger.error("Attempt to close session raised a local exception: %s", e) self.is_open = False - + # Close HTTP client if it exists - if hasattr(self, 'http_client') and self.http_client: + if hasattr(self, "http_client") and self.http_client: self.http_client.close() diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 13c15486d..2785d3cca 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -230,7 +230,7 @@ def _send_telemetry(self, events): try: logger.debug("Submitting telemetry request to thread pool") - + # Use unified HTTP client future = self._executor.submit( self._send_with_unified_client, @@ -239,7 +239,7 @@ def _send_telemetry(self, events): headers=headers, timeout=900, ) - + future.add_done_callback( lambda fut: self._telemetry_request_callback(fut, sent_count=sent_count) ) @@ -249,10 +249,14 @@ def _send_telemetry(self, events): def _send_with_unified_client(self, url, data, headers): """Helper method to send telemetry using the unified HTTP client.""" try: - response = self._http_client.request('POST', url, body=data, headers=headers, timeout=900) + response = self._http_client.request( + "POST", url, body=data, headers=headers, timeout=900 + ) # Convert urllib3 response to requests-like response for compatibility response.status_code = response.status - response.json = lambda: json.loads(response.data.decode()) if response.data else {} + response.json = ( + lambda: json.loads(response.data.decode()) if response.data else {} + ) return response except Exception as e: logger.error("Failed to send telemetry with unified client: %s", e) From 1143838ad277f4a0309fdb40cdf682ad8e98ad9a Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Sat, 9 Aug 2025 23:45:22 +0530 Subject: [PATCH 05/25] preliminary connection closure func --- src/databricks/sql/auth/thrift_http_client.py | 7 ++++--- src/databricks/sql/backend/sea/backend.py | 21 ++++++++++++++++--- .../sql/backend/sea/utils/http_client.py | 2 +- src/databricks/sql/backend/thrift_backend.py | 9 ++++---- 4 files changed, 28 insertions(+), 11 deletions(-) diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index f0daae162..a60540712 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -105,7 +105,6 @@ def startRetryTimer(self): self.retry_policy and self.retry_policy.start_retry_timer() def open(self): - # self.__pool replaces the self.__http used by the original THttpClient _pool_kwargs = {"maxsize": self.max_connections} @@ -140,11 +139,14 @@ def open(self): else: self.__pool = pool_class(self.host, self.port, **_pool_kwargs) - def close(self): + def release_connection(self): self.__resp and self.__resp.drain_conn() self.__resp and self.__resp.release_conn() self.__resp = None + def close(self): + self.__pool.close() + def read(self, sz): return self.__resp.read(sz) @@ -152,7 +154,6 @@ def isOpen(self): return self.__resp is not None def flush(self): - # Pull data out of buffer that will be sent in this request data = self.__wbuf.getvalue() self.__wbuf = BytesIO() diff 
--git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 75d2c665c..68f41084c 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -273,7 +273,7 @@ def open_session( return SessionId.from_sea_session_id(session_id) - def close_session(self, session_id: SessionId) -> None: + def _close_session(self, session_id: SessionId) -> None: """ Closes an existing session with the Databricks SQL service. @@ -285,8 +285,6 @@ def close_session(self, session_id: SessionId) -> None: OperationalError: If there's an error closing the session """ - logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id) - if session_id.backend_type != BackendType.SEA: raise ValueError("Not a valid SEA session ID") sea_session_id = session_id.to_sea_session_id() @@ -302,6 +300,23 @@ def close_session(self, session_id: SessionId) -> None: data=request_data.to_dict(), ) + def close_session(self, session_id: SessionId) -> None: + """ + Closes the session and the underlying HTTP client. + + Args: + session_id: The session identifier returned by open_session() + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error closing the session + """ + + logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id) + + self._close_session(session_id) + self._http_client.close() + def _extract_description_from_manifest( self, manifest: ResultManifest ) -> List[Tuple]: diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index ef9a14353..f0aec2b2d 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -197,7 +197,7 @@ def _open(self): def close(self): """Close the connection pool.""" if self._pool: - self._pool.clear() + self._pool.close() def using_proxy(self) -> bool: """Check if proxy is being used.""" diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index b404b1669..7598f8291 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -232,7 +232,7 @@ def __init__( try: self._transport.open() except: - self._transport.close() + self._transport.release_connection() raise self._request_lock = threading.RLock() @@ -478,7 +478,7 @@ def attempt_request(attempt): ) finally: # Calling `close()` here releases the active HTTP connection back to the pool - self._transport.close() + self._transport.release_connection() return RequestErrorInfo( error=error, @@ -607,7 +607,7 @@ def open_session(self, session_configuration, catalog, schema) -> SessionId: self._session_id_hex = session_id.hex_guid return session_id except: - self._transport.close() + self._transport.release_connection() raise def close_session(self, session_id: SessionId) -> None: @@ -619,7 +619,8 @@ def close_session(self, session_id: SessionId) -> None: try: self.make_request(self._client.CloseSession, req) finally: - self._transport.close() + self._transport.release_connection() + self._transport.close() def _check_command_not_in_error_or_closed_state( self, op_handle, get_operations_resp From 68cc8221457f75ac22b1ac6d877d9982f8dae2e5 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Sat, 9 Aug 2025 23:49:26 +0530 Subject: [PATCH 06/25] unit test for backend closure --- tests/unit/test_thrift_backend.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git 
a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index 0cdb43f5c..c671e4900 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -1406,8 +1406,12 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class):
         )
 
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
-    def test_session_handle_respected_in_close_session(self, tcli_service_class):
+    @patch("databricks.sql.auth.thrift_http_client.THttpClient", autospec=True)
+    def test_session_handle_respected_in_close_session(
+        self, mock_http_client_class, tcli_service_class
+    ):
         tcli_service_instance = tcli_service_class.return_value
+        mock_http_client_instance = mock_http_client_class.return_value
         thrift_backend = ThriftDatabricksClient(
             "foobar",
             443,
@@ -1416,12 +1420,16 @@ def test_session_handle_respected_in_close_session(self, tcli_service_class):
             auth_provider=AuthProvider(),
             ssl_options=SSLOptions(),
         )
+        # Manually set the mocked transport instance
+        thrift_backend._transport = mock_http_client_instance
+
         session_id = SessionId.from_thrift_handle(self.session_handle)
         thrift_backend.close_session(session_id)
         self.assertEqual(
             tcli_service_instance.CloseSession.call_args[0][0].sessionHandle,
             self.session_handle,
         )
+        mock_http_client_instance.close.assert_called_once()
 
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
     def test_non_arrow_non_column_based_set_triggers_exception(

From ef1d9fd0fdfc2af5387786d24372726f69866091 Mon Sep 17 00:00:00 2001
From: Varun0157
Date: Sun, 10 Aug 2025 00:11:38 +0530
Subject: [PATCH 07/25] remove redundant comment

---
 tests/unit/test_thrift_backend.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index c671e4900..8e1a0065a 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -1420,7 +1420,6 @@ def test_session_handle_respected_in_close_session(
             auth_provider=AuthProvider(),
             ssl_options=SSLOptions(),
         )
-        # Manually set the mocked transport instance
         thrift_backend._transport = mock_http_client_instance
 
         session_id = SessionId.from_thrift_handle(self.session_handle)

From 4bb2e4b0fb4238612d9b1f5b8401706a5295113c Mon Sep 17 00:00:00 2001
From: Varun0157
Date: Sun, 10 Aug 2025 10:36:18 +0530
Subject: [PATCH 08/25] assert SEA http client closure in unit tests

---
 tests/unit/test_sea_backend.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py
index f604f2874..1e8da7d34 100644
--- a/tests/unit/test_sea_backend.py
+++ b/tests/unit/test_sea_backend.py
@@ -220,6 +220,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i
             path=sea_client.SESSION_PATH_WITH_ID.format("test-session-789"),
             data={"session_id": "test-session-789", "warehouse_id": "abc123"},
         )
+        mock_http_client.close.assert_called_once()
 
         # Test close_session with invalid ID type
         with pytest.raises(ValueError) as excinfo:

From 734dd06e131b274b6af7e5fab58dc4fabaa4902f Mon Sep 17 00:00:00 2001
From: Varun0157
Date: Sun, 10 Aug 2025 15:56:43 +0530
Subject: [PATCH 09/25] correct docstring

---
 src/databricks/sql/backend/thrift_backend.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py
index 7598f8291..ee6ed547e 100644
--- a/src/databricks/sql/backend/thrift_backend.py
+++ b/src/databricks/sql/backend/thrift_backend.py
@@
-477,7 +477,7 @@ def attempt_request(attempt): ) ) finally: - # Calling `close()` here releases the active HTTP connection back to the pool + # Calling `release_connection()` here releases the active HTTP connection back to the pool self._transport.release_connection() return RequestErrorInfo( From d00e3c86a798bc8c7f89c9cc59fcae20ef7eff61 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 11 Aug 2025 12:02:32 +0530 Subject: [PATCH 10/25] fix e2e Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/common.py | 6 ++--- src/databricks/sql/auth/retry.py | 6 ++--- src/databricks/sql/backend/thrift_backend.py | 2 +- src/databricks/sql/client.py | 20 +++++++--------- .../sql/common/unified_http_client.py | 24 ++++++++++++++++--- tests/e2e/common/retry_test_mixins.py | 5 +++- tests/e2e/common/staging_ingestion_tests.py | 10 +++++--- tests/e2e/common/uc_volume_tests.py | 9 +++++-- tests/e2e/test_driver.py | 14 ++++++----- 9 files changed, 61 insertions(+), 35 deletions(-) diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 61b07cb91..cec869027 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -69,12 +69,10 @@ def __init__( # HTTP client configuration self.ssl_options = ssl_options self.socket_timeout = socket_timeout - self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 30 + self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 5 self.retry_delay_min = retry_delay_min or 1.0 self.retry_delay_max = retry_delay_max or 60.0 - self.retry_stop_after_attempts_duration = ( - retry_stop_after_attempts_duration or 900.0 - ) + self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration or 900.0 self.retry_delay_default = retry_delay_default or 5.0 self.retry_dangerous_codes = retry_dangerous_codes or [] self.http_proxy = http_proxy diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index 368edc9a2..9c9988971 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -294,7 +294,7 @@ def sleep_for_retry(self, response: BaseHTTPResponse) -> bool: else: proposed_wait = self.get_backoff_time() - proposed_wait = max(proposed_wait, self.delay_max) + proposed_wait = min(proposed_wait, self.delay_max) self.check_proposed_wait(proposed_wait) logger.debug(f"Retrying after {proposed_wait} seconds") time.sleep(proposed_wait) @@ -355,8 +355,8 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: logger.info(f"Received status code {status_code} for {method} request") # Request succeeded. Don't retry. - if status_code == 200: - return False, "200 codes are not retried" + if status_code // 100 == 2: + return False, "2xx codes are not retried" if status_code == 401: return ( diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 801632a41..1a1849bb7 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -194,7 +194,7 @@ def __init__( if _max_redirects: if _max_redirects > self._retry_stop_after_attempts_count: - logger.warn( + logger.warning( "_retry_max_redirects > _retry_stop_after_attempts_count so it will have no affect!" 
) urllib3_kwargs = {"redirect": _max_redirects} diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 1a35f97da..d3a72c86a 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -376,21 +376,17 @@ def _build_client_context(self, server_hostname: str, **kwargs): hostname=server_hostname, ssl_options=ssl_options, socket_timeout=kwargs.get("_socket_timeout"), - retry_stop_after_attempts_count=kwargs.get( - "_retry_stop_after_attempts_count", 30 - ), - retry_delay_min=kwargs.get("_retry_delay_min", 1.0), - retry_delay_max=kwargs.get("_retry_delay_max", 60.0), - retry_stop_after_attempts_duration=kwargs.get( - "_retry_stop_after_attempts_duration", 900.0 - ), - retry_delay_default=kwargs.get("_retry_delay_default", 1.0), - retry_dangerous_codes=kwargs.get("_retry_dangerous_codes", []), + retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count"), + retry_delay_min=kwargs.get("_retry_delay_min"), + retry_delay_max=kwargs.get("_retry_delay_max"), + retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration"), + retry_delay_default=kwargs.get("_retry_delay_default"), + retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), http_proxy=kwargs.get("_http_proxy"), proxy_username=kwargs.get("_proxy_username"), proxy_password=kwargs.get("_proxy_password"), - pool_connections=kwargs.get("_pool_connections", 1), - pool_maxsize=kwargs.get("_pool_maxsize", 1), + pool_connections=kwargs.get("_pool_connections"), + pool_maxsize=kwargs.get("_pool_maxsize"), user_agent=user_agent, ) diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index a296704b4..13fd9ddd2 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -9,7 +9,7 @@ from urllib3.util import make_headers from urllib3.exceptions import MaxRetryError -from databricks.sql.auth.retry import DatabricksRetryPolicy +from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType from databricks.sql.exc import RequestError logger = logging.getLogger(__name__) @@ -33,6 +33,7 @@ def __init__(self, client_context): """ self.config = client_context self._pool_manager = None + self._retry_policy = None self._setup_pool_manager() def _setup_pool_manager(self): @@ -69,7 +70,7 @@ def _setup_pool_manager(self): ) # Create retry policy - retry_policy = DatabricksRetryPolicy( + self._retry_policy = DatabricksRetryPolicy( delay_min=self.config.retry_delay_min, delay_max=self.config.retry_delay_max, stop_after_attempts_count=self.config.retry_stop_after_attempts_count, @@ -77,12 +78,17 @@ def _setup_pool_manager(self): delay_default=self.config.retry_delay_default, force_dangerous_codes=self.config.retry_dangerous_codes, ) + + # Initialize the required attributes that DatabricksRetryPolicy expects + # but doesn't initialize in its constructor + self._retry_policy._command_type = None + self._retry_policy._retry_start_time = None # Common pool manager kwargs pool_kwargs = { "num_pools": self.config.pool_connections, "maxsize": self.config.pool_maxsize, - "retries": retry_policy, + "retries": self._retry_policy, "timeout": urllib3.Timeout( connect=self.config.socket_timeout, read=self.config.socket_timeout ) @@ -119,6 +125,14 @@ def _prepare_headers( return request_headers + def _prepare_retry_policy(self): + """Set up the retry policy for the current request.""" + if isinstance(self._retry_policy, DatabricksRetryPolicy): + # Set command type 
for HTTP requests to OTHER (not database commands) + self._retry_policy.command_type = CommandType.OTHER + # Start the retry timer for duration-based retry limits + self._retry_policy.start_retry_timer() + @contextmanager def request_context( self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs @@ -138,6 +152,10 @@ def request_context( logger.debug("Making %s request to %s", method, url) request_headers = self._prepare_headers(headers) + + # Prepare retry policy for this request + self._prepare_retry_policy() + response = None try: diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index e1c32d68e..e5ff3dcb7 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -247,6 +247,7 @@ def test_retry_exponential_backoff(self, mock_send_telemetry, extra_params): """ retry_policy = self._retry_policy.copy() retry_policy["_retry_delay_min"] = 1 + retry_policy["_retry_delay_max"] = 10 time_start = time.time() with mocked_server_response( @@ -282,9 +283,11 @@ def test_retry_max_duration_not_exceeded(self, extra_params): WHEN the server sends a Retry-After header of 60 seconds THEN the connector raises a MaxRetryDurationError """ + retry_policy = self._retry_policy.copy() + retry_policy["_retry_delay_max"] = 60 with mocked_server_response(status=429, headers={"Retry-After": "60"}): with pytest.raises(RequestError) as cm: - extra_params = {**extra_params, **self._retry_policy} + extra_params = {**extra_params, **retry_policy} with self.connection(extra_params=extra_params) as conn: pass assert isinstance(cm.value.args[1], MaxRetryDurationError) diff --git a/tests/e2e/common/staging_ingestion_tests.py b/tests/e2e/common/staging_ingestion_tests.py index 825f830f3..377d51ef4 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/tests/e2e/common/staging_ingestion_tests.py @@ -68,15 +68,19 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): # REMOVE should succeed remove_query = f"REMOVE 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv'" - - with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + # Use minimal retry settings to fail fast for staging operations + extra_params = { + "staging_allowed_local_path": "/", + "_retry_stop_after_attempts_count": 1, + } + with self.connection(extra_params=extra_params) as conn: cursor = conn.cursor() cursor.execute(remove_query) # GET after REMOVE should fail with pytest.raises( - Error, match="Staging operation over HTTP was unsuccessful: 404" + Error, match="too many 404 error responses" ): cursor = conn.cursor() query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'" diff --git a/tests/e2e/common/uc_volume_tests.py b/tests/e2e/common/uc_volume_tests.py index 72e2f5020..c60e10e6d 100644 --- a/tests/e2e/common/uc_volume_tests.py +++ b/tests/e2e/common/uc_volume_tests.py @@ -68,14 +68,19 @@ def test_uc_volume_life_cycle(self, catalog, schema): remove_query = f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/file1.csv'" - with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + # Use minimal retry settings to fail fast + extra_params = { + "staging_allowed_local_path": "/", + "_retry_stop_after_attempts_count": 1, + } + with self.connection(extra_params=extra_params) as conn: cursor = conn.cursor() cursor.execute(remove_query) # GET after REMOVE should fail with pytest.raises( - Error, match="Staging operation over HTTP was unsuccessful: 404" + Error, match="too 
many 404 error responses" ): cursor = conn.cursor() query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 3fa87b1af..53b7383e6 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -60,12 +60,14 @@ unsafe_logger.addHandler(logging.FileHandler("./tests-unsafe.log")) # manually decorate DecimalTestsMixin to need arrow support -for name in loader.getTestCaseNames(DecimalTestsMixin, "test_"): - fn = getattr(DecimalTestsMixin, name) - decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")( - fn - ) - setattr(DecimalTestsMixin, name, decorated) +test_loader = loader.TestLoader() +for name in test_loader.getTestCaseNames(DecimalTestsMixin): + if name.startswith("test_"): + fn = getattr(DecimalTestsMixin, name) + decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")( + fn + ) + setattr(DecimalTestsMixin, name, decorated) class PySQLPytestTestCase: From 000d3a360000c0b7f3c5914d061296685224cf5f Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 11 Aug 2025 12:12:39 +0530 Subject: [PATCH 11/25] fix unit Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/common.py | 4 +++- src/databricks/sql/client.py | 8 ++++++-- src/databricks/sql/common/unified_http_client.py | 6 +++--- tests/unit/test_retry.py | 4 ++-- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index cec869027..e80fac189 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -72,7 +72,9 @@ def __init__( self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 5 self.retry_delay_min = retry_delay_min or 1.0 self.retry_delay_max = retry_delay_max or 60.0 - self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration or 900.0 + self.retry_stop_after_attempts_duration = ( + retry_stop_after_attempts_duration or 900.0 + ) self.retry_delay_default = retry_delay_default or 5.0 self.retry_dangerous_codes = retry_dangerous_codes or [] self.http_proxy = http_proxy diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index d3a72c86a..d2e94df63 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -376,10 +376,14 @@ def _build_client_context(self, server_hostname: str, **kwargs): hostname=server_hostname, ssl_options=ssl_options, socket_timeout=kwargs.get("_socket_timeout"), - retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count"), + retry_stop_after_attempts_count=kwargs.get( + "_retry_stop_after_attempts_count" + ), retry_delay_min=kwargs.get("_retry_delay_min"), retry_delay_max=kwargs.get("_retry_delay_max"), - retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration"), + retry_stop_after_attempts_duration=kwargs.get( + "_retry_stop_after_attempts_duration" + ), retry_delay_default=kwargs.get("_retry_delay_default"), retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), http_proxy=kwargs.get("_http_proxy"), diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index 13fd9ddd2..03f784ee2 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -78,7 +78,7 @@ def _setup_pool_manager(self): delay_default=self.config.retry_delay_default, force_dangerous_codes=self.config.retry_dangerous_codes, ) - 
+ # Initialize the required attributes that DatabricksRetryPolicy expects # but doesn't initialize in its constructor self._retry_policy._command_type = None @@ -152,10 +152,10 @@ def request_context( logger.debug("Making %s request to %s", method, url) request_headers = self._prepare_headers(headers) - + # Prepare retry policy for this request self._prepare_retry_policy() - + response = None try: diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py index 897a1d111..40096bf08 100644 --- a/tests/unit/test_retry.py +++ b/tests/unit/test_retry.py @@ -34,7 +34,7 @@ def test_sleep__no_retry_after(self, t_mock, retry_policy, error_history): retry_policy.history = [error_history, error_history] retry_policy.sleep(HTTPResponse(status=503)) - expected_backoff_time = max( + expected_backoff_time = min( self.calculate_backoff_time( 0, retry_policy.delay_min, retry_policy.delay_max ), @@ -57,7 +57,7 @@ def test_sleep__no_retry_after_header__multiple_retries(self, t_mock, retry_poli expected_backoff_times = [] for attempt in range(num_attempts): expected_backoff_times.append( - max( + min( self.calculate_backoff_time( attempt, retry_policy.delay_min, retry_policy.delay_max ), From cba3da70ce26696b645debd3a6d3dd523f7074f6 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 11 Aug 2025 14:29:12 +0530 Subject: [PATCH 12/25] more fixes Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/retry.py | 2 +- src/databricks/sql/backend/thrift_backend.py | 2 +- src/databricks/sql/client.py | 4 +- src/databricks/sql/cloudfetch/downloader.py | 3 - .../sql/common/unified_http_client.py | 43 ----------- .../sql/telemetry/telemetry_client.py | 72 ++++++++++++++++--- tests/e2e/common/retry_test_mixins.py | 5 +- tests/e2e/common/staging_ingestion_tests.py | 1 + tests/e2e/common/uc_volume_tests.py | 1 + 9 files changed, 68 insertions(+), 65 deletions(-) diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index 9c9988971..ad8e455f1 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -294,7 +294,7 @@ def sleep_for_retry(self, response: BaseHTTPResponse) -> bool: else: proposed_wait = self.get_backoff_time() - proposed_wait = min(proposed_wait, self.delay_max) + proposed_wait = max(proposed_wait, self.delay_max) self.check_proposed_wait(proposed_wait) logger.debug(f"Retrying after {proposed_wait} seconds") time.sleep(proposed_wait) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 1a1849bb7..25cc8428a 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -105,7 +105,7 @@ def __init__( http_headers, auth_provider: AuthProvider, ssl_options: SSLOptions, - http_client=None, + http_client, **kwargs, ): # Internal arguments in **kwargs: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index d2e94df63..74630cebc 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -277,7 +277,7 @@ def read(self) -> Optional[OAuthToken]: host_url=server_hostname, http_path=http_path, port=kwargs.get("_port", 443), - http_client=http_client, + client_context=client_context, user_agent=self.session.useragent_header if hasattr(self, "session") else None, @@ -299,7 +299,7 @@ def read(self) -> Optional[OAuthToken]: auth_provider=self.session.auth_provider, host_url=self.session.host, batch_size=self.telemetry_batch_size, - http_client=self.session.http_client, + client_context=client_context, ) 
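For reference, a hedged sketch of the wiring this commit converges on: the connection hands the factory its ClientContext, and the telemetry layer constructs and owns its own HTTP client (via the singleton added to telemetry_client.py below). Keyword names follow the patch; the values are placeholders.

    from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory

    TelemetryClientFactory.initialize_telemetry_client(
        telemetry_enabled=True,
        session_id_hex=session_id_hex,   # hex GUID of the open session
        auth_provider=auth_provider,     # e.g. an AccessTokenAuthProvider
        host_url=server_hostname,
        batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
        client_context=client_context,   # built by Connection._build_client_context
    )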
self._telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index cef4ca274..a2a7837f0 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -13,9 +13,6 @@ logger = logging.getLogger(__name__) -# TODO: Ideally, we should use a common retry policy (DatabricksRetryPolicy) for all the requests across the library. -# But DatabricksRetryPolicy should be updated first - currently it can work only with Thrift requests - @dataclass class DownloadedFile: diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index 03f784ee2..bb26ae9de 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -193,38 +193,6 @@ def request( response._body = response.data return response - def upload_file( - self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None - ) -> urllib3.HTTPResponse: - """ - Upload a file using PUT method. - - Args: - url: URL to upload to - file_path: Path to the file to upload - headers: Optional headers - - Returns: - urllib3.HTTPResponse: The response from the server - """ - with open(file_path, "rb") as file_obj: - return self.request("PUT", url, body=file_obj.read(), headers=headers) - - def download_file( - self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None - ) -> None: - """ - Download a file using GET method. - - Args: - url: URL to download from - file_path: Path where to save the downloaded file - headers: Optional headers - """ - response = self.request("GET", url, headers=headers) - with open(file_path, "wb") as file_obj: - file_obj.write(response.data) - def close(self): """Close the underlying connection pools.""" if self._pool_manager: @@ -236,14 +204,3 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() - - -# Compatibility class to maintain requests-like interface for OAuth -class IgnoreNetrcAuth: - """ - Compatibility class for OAuth code that expects requests.auth.AuthBase interface. - This is a no-op auth handler since OAuth handles auth differently. - """ - - def __call__(self, request): - return request diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 2785d3cca..9887b67a7 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -1,8 +1,11 @@ import threading import time import logging +import json from concurrent.futures import ThreadPoolExecutor -from typing import Dict, Optional, TYPE_CHECKING +from concurrent.futures import Future +from datetime import datetime, timezone +from typing import List, Dict, Any, Optional, TYPE_CHECKING from databricks.sql.telemetry.models.event import ( TelemetryEvent, DriverSystemConfiguration, @@ -36,8 +39,7 @@ import locale from databricks.sql.telemetry.utils import BaseTelemetryClient from databricks.sql.common.feature_flag import FeatureFlagsContextFactory - -from src.databricks.sql.common.unified_http_client import UnifiedHttpClient +from databricks.sql.common.unified_http_client import UnifiedHttpClient if TYPE_CHECKING: from databricks.sql.client import Connection @@ -151,6 +153,44 @@ def _flush(self): pass +class TelemetryHttpClientSingleton: + """ + Singleton HTTP client for telemetry operations. 
+ + This ensures that telemetry has its own dedicated HTTP client that + is independent of individual connection lifecycles. + """ + + _instance = None + _lock = threading.RLock() + + def __new__(cls): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._http_client = None + cls._instance._initialized = False + return cls._instance + + def get_http_client(self, client_context): + """Get or create the singleton HTTP client.""" + if not self._initialized and client_context: + with self._lock: + if not self._initialized: + self._http_client = UnifiedHttpClient(client_context) + self._initialized = True + return self._http_client + + def close(self): + """Close the singleton HTTP client.""" + with self._lock: + if self._http_client: + self._http_client.close() + self._http_client = None + self._initialized = False + + class TelemetryClient(BaseTelemetryClient): """ Telemetry client class that handles sending telemetry events in batches to the server. @@ -169,7 +209,7 @@ def __init__( host_url, executor, batch_size, - http_client, + client_context, ): logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled @@ -182,7 +222,10 @@ def __init__( self._driver_connection_params = None self._host_url = host_url self._executor = executor - self._http_client = http_client + + # Use singleton HTTP client for telemetry instead of connection-specific client + self._http_client_singleton = TelemetryHttpClientSingleton() + self._http_client = self._http_client_singleton.get_http_client(client_context) def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" @@ -246,17 +289,24 @@ def _send_telemetry(self, events): except Exception as e: logger.debug("Failed to submit telemetry request: %s", e) - def _send_with_unified_client(self, url, data, headers): + def _send_with_unified_client(self, url, data, headers, timeout=900): """Helper method to send telemetry using the unified HTTP client.""" try: response = self._http_client.request( - "POST", url, body=data, headers=headers, timeout=900 + "POST", url, body=data, headers=headers, timeout=timeout ) # Convert urllib3 response to requests-like response for compatibility response.status_code = response.status + response.ok = 200 <= response.status < 300 response.json = ( lambda: json.loads(response.data.decode()) if response.data else {} ) + # Add raise_for_status method + def raise_for_status(): + if not response.ok: + raise Exception(f"HTTP {response.status_code}") + + response.raise_for_status = raise_for_status return response except Exception as e: logger.error("Failed to send telemetry with unified client: %s", e) @@ -452,7 +502,7 @@ def initialize_telemetry_client( auth_provider, host_url, batch_size, - http_client, + client_context, ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: @@ -475,7 +525,7 @@ def initialize_telemetry_client( host_url=host_url, executor=TelemetryClientFactory._executor, batch_size=batch_size, - http_client=http_client, + client_context=client_context, ) else: TelemetryClientFactory._clients[ @@ -528,7 +578,7 @@ def connection_failure_log( host_url: str, http_path: str, port: int, - http_client: UnifiedHttpClient, + client_context, user_agent: Optional[str] = None, ): """Send error telemetry when connection creation fails, without requiring a session""" @@ -541,7 +591,7 @@ def connection_failure_log( 
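The TelemetryHttpClientSingleton above relies on double-checked locking. A standalone sketch of the same idiom (generic Python, not the patch's class):

    import threading

    class Singleton:
        _instance = None
        _lock = threading.RLock()

        def __new__(cls):
            # Cheap unlocked check on the hot path; re-check under the lock
            # so two racing threads cannot both construct an instance.
            if cls._instance is None:
                with cls._lock:
                    if cls._instance is None:
                        cls._instance = super().__new__(cls)
            return cls._instance

    assert Singleton() is Singleton()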
auth_provider=None, host_url=host_url, batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - http_client=http_client, + client_context=client_context, ) telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index e5ff3dcb7..e1c32d68e 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -247,7 +247,6 @@ def test_retry_exponential_backoff(self, mock_send_telemetry, extra_params): """ retry_policy = self._retry_policy.copy() retry_policy["_retry_delay_min"] = 1 - retry_policy["_retry_delay_max"] = 10 time_start = time.time() with mocked_server_response( @@ -283,11 +282,9 @@ def test_retry_max_duration_not_exceeded(self, extra_params): WHEN the server sends a Retry-After header of 60 seconds THEN the connector raises a MaxRetryDurationError """ - retry_policy = self._retry_policy.copy() - retry_policy["_retry_delay_max"] = 60 with mocked_server_response(status=429, headers={"Retry-After": "60"}): with pytest.raises(RequestError) as cm: - extra_params = {**extra_params, **retry_policy} + extra_params = {**extra_params, **self._retry_policy} with self.connection(extra_params=extra_params) as conn: pass assert isinstance(cm.value.args[1], MaxRetryDurationError) diff --git a/tests/e2e/common/staging_ingestion_tests.py b/tests/e2e/common/staging_ingestion_tests.py index 377d51ef4..73aa0a113 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/tests/e2e/common/staging_ingestion_tests.py @@ -72,6 +72,7 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): extra_params = { "staging_allowed_local_path": "/", "_retry_stop_after_attempts_count": 1, + "_retry_delay_max": 10, } with self.connection(extra_params=extra_params) as conn: cursor = conn.cursor() diff --git a/tests/e2e/common/uc_volume_tests.py b/tests/e2e/common/uc_volume_tests.py index c60e10e6d..93e63bd28 100644 --- a/tests/e2e/common/uc_volume_tests.py +++ b/tests/e2e/common/uc_volume_tests.py @@ -72,6 +72,7 @@ def test_uc_volume_life_cycle(self, catalog, schema): extra_params = { "staging_allowed_local_path": "/", "_retry_stop_after_attempts_count": 1, + "_retry_delay_max": 10, } with self.connection(extra_params=extra_params) as conn: cursor = conn.cursor() From 2a1f719025224d9ec4ae9e53fbb1359a193948b7 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 11 Aug 2025 15:29:25 +0530 Subject: [PATCH 13/25] more fixes Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/auth.py | 6 ++- src/databricks/sql/auth/authenticators.py | 3 ++ src/databricks/sql/auth/oauth.py | 5 +- src/databricks/sql/backend/thrift_backend.py | 3 +- src/databricks/sql/client.py | 42 ++++++--------- src/databricks/sql/common/feature_flag.py | 9 ++-- tests/unit/test_auth.py | 1 + tests/unit/test_retry.py | 4 +- tests/unit/test_telemetry.py | 55 ++++++++++++++------ tests/unit/test_thrift_backend.py | 48 +++++++++++++++++ 10 files changed, 121 insertions(+), 55 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index cc421e69e..59da2a422 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -19,6 +19,7 @@ def get_auth_provider(cfg: ClientContext, http_client): cfg.hostname, cfg.azure_client_id, cfg.azure_client_secret, + http_client, cfg.azure_tenant_id, cfg.azure_workspace_resource_id, ) @@ -34,8 +35,8 @@ def get_auth_provider(cfg: ClientContext, http_client): cfg.oauth_redirect_port_range, cfg.oauth_client_id, 
cfg.oauth_scopes, + http_client, cfg.auth_type, - http_client=http_client, ) elif cfg.access_token is not None: return AccessTokenAuthProvider(cfg.access_token) @@ -54,7 +55,8 @@ def get_auth_provider(cfg: ClientContext, http_client): cfg.oauth_redirect_port_range, cfg.oauth_client_id, cfg.oauth_scopes, - http_client=http_client, + http_client, + cfg.auth_type or "databricks-oauth", ) else: raise RuntimeError("No valid authentication settings!") diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 80f44812c..66e2cbe53 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -190,6 +190,7 @@ def __init__( hostname, azure_client_id, azure_client_secret, + http_client, azure_tenant_id=None, azure_workspace_resource_id=None, ): @@ -200,6 +201,7 @@ def __init__( self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host( hostname ) + self._http_client = http_client def auth_type(self) -> str: return AuthType.AZURE_SP_M2M.value @@ -209,6 +211,7 @@ def get_token_source(self, resource: str) -> RefreshableTokenSource: token_url=f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}", client_id=self.azure_client_id, client_secret=self.azure_client_secret, + http_client=self._http_client, extra_params={"resource": resource}, ) diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 7f96a2303..9fdf3955a 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -190,8 +190,7 @@ def __send_auth_code_token_request( data = f"{token_request_body}&code_verifier={verifier}" return self.__send_token_request(token_request_url, data) - @staticmethod - def __send_token_request(token_request_url, data): + def __send_token_request(self, token_request_url, data): headers = { "Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded", @@ -210,7 +209,7 @@ def __send_refresh_token_request(self, hostname, refresh_token): token_request_body = client.prepare_refresh_body( refresh_token=refresh_token, client_id=client.client_id ) - return OAuthManager.__send_token_request(token_request_url, token_request_body) + return self.__send_token_request(token_request_url, token_request_body) @staticmethod def __get_tokens_from_response(oauth_response): diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 25cc8428a..59cf69b6e 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -8,6 +8,7 @@ from typing import List, Optional, Union, Any, TYPE_CHECKING from uuid import UUID +from databricks.sql.common.unified_http_client import UnifiedHttpClient from databricks.sql.result_set import ThriftResultSet from databricks.sql.telemetry.models.event import StatementType @@ -105,7 +106,7 @@ def __init__( http_headers, auth_provider: AuthProvider, ssl_options: SSLOptions, - http_client, + http_client: UnifiedHttpClient, **kwargs, ): # Internal arguments in **kwargs: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 74630cebc..634c7e261 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -799,12 +799,6 @@ def _handle_staging_put( r = self.connection.session.http_client.request( "PUT", presigned_url, body=fh.read(), headers=headers ) - # Add compatibility attributes for urllib3 response - r.status_code = r.status - if hasattr(r, "data"): - r.content = r.data 
- r.ok = r.status < 400 - r.text = r.data.decode() if r.data else "" # fmt: off # HTTP status codes @@ -814,13 +808,15 @@ def _handle_staging_put( NO_CONTENT = 204 # fmt: on - if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: + if r.status not in [OK, CREATED, NO_CONTENT, ACCEPTED]: + # Decode response data for error message + error_text = r.data.decode() if r.data else "" raise OperationalError( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}", session_id_hex=self.connection.get_session_id_hex(), ) - if r.status_code == ACCEPTED: + if r.status == ACCEPTED: logger.debug( f"Response code {ACCEPTED} from server indicates ingestion command was accepted " + "but not yet applied on the server. It's possible this command may fail later." @@ -844,23 +840,19 @@ def _handle_staging_get( r = self.connection.session.http_client.request( "GET", presigned_url, headers=headers ) - # Add compatibility attributes for urllib3 response - r.status_code = r.status - if hasattr(r, "data"): - r.content = r.data - r.ok = r.status < 400 - r.text = r.data.decode() if r.data else "" # response.ok verifies the status code is not between 400-600. # Any 2xx or 3xx will evaluate r.ok == True - if not r.ok: + if r.status >= 400: + # Decode response data for error message + error_text = r.data.decode() if r.data else "" raise OperationalError( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}", session_id_hex=self.connection.get_session_id_hex(), ) with open(local_file, "wb") as fp: - fp.write(r.content) + fp.write(r.data) @log_latency(StatementType.SQL) def _handle_staging_remove( @@ -871,16 +863,12 @@ def _handle_staging_remove( r = self.connection.session.http_client.request( "DELETE", presigned_url, headers=headers ) - # Add compatibility attributes for urllib3 response - r.status_code = r.status - if hasattr(r, "data"): - r.content = r.data - r.ok = r.status < 400 - r.text = r.data.decode() if r.data else "" - - if not r.ok: + + if r.status >= 400: + # Decode response data for error message + error_text = r.data.decode() if r.data else "" raise OperationalError( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}", session_id_hex=self.connection.get_session_id_hex(), ) diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 1b920b008..2b7e27ab3 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -113,12 +113,11 @@ def _refresh_flags(self): response = self._http_client.request( "GET", self._feature_flag_endpoint, headers=headers, timeout=30 ) - # Add compatibility attributes for urllib3 response - response.status_code = response.status - response.json = lambda: json.loads(response.data.decode()) - if response.status_code == 200: - ff_response = FeatureFlagsResponse.from_dict(response.json()) + if response.status == 200: + # Parse JSON response from urllib3 response data + response_data = json.loads(response.data.decode()) + ff_response = FeatureFlagsResponse.from_dict(response_data) self._update_cache_from_response(ff_response) else: # On failure, initialize with an empty dictionary to prevent re-blocking. 
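The manual json.loads(response.data.decode()) pattern above is the portable way to read a JSON body from urllib3: responses expose .status and .data rather than the requests-style .status_code and .json(), and urllib3 1.x has no .json() helper at all. A minimal standalone sketch (the URL is a placeholder):

    import json
    import urllib3

    http = urllib3.PoolManager()
    resp = http.request("GET", "https://example.com/api/flags")
    if resp.status == 200:
        flags = json.loads(resp.data.decode())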
diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index 333782fd8..d574ed27e 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -306,6 +306,7 @@ def credential_provider(self): hostname="hostname", azure_client_id="client_id", azure_client_secret="client_secret", + http_client=MagicMock(), azure_tenant_id="tenant_id", ) diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py index 40096bf08..897a1d111 100644 --- a/tests/unit/test_retry.py +++ b/tests/unit/test_retry.py @@ -34,7 +34,7 @@ def test_sleep__no_retry_after(self, t_mock, retry_policy, error_history): retry_policy.history = [error_history, error_history] retry_policy.sleep(HTTPResponse(status=503)) - expected_backoff_time = min( + expected_backoff_time = max( self.calculate_backoff_time( 0, retry_policy.delay_min, retry_policy.delay_max ), @@ -57,7 +57,7 @@ def test_sleep__no_retry_after_header__multiple_retries(self, t_mock, retry_poli expected_backoff_times = [] for attempt in range(num_attempts): expected_backoff_times.append( - min( + max( self.calculate_backoff_time( attempt, retry_policy.delay_min, retry_policy.delay_max ), diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 989b2351c..0e828497f 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -19,12 +19,17 @@ @pytest.fixture -def mock_telemetry_client(): +@patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") +def mock_telemetry_client(mock_singleton_class): """Create a mock telemetry client for testing.""" session_id = str(uuid.uuid4()) auth_provider = AccessTokenAuthProvider("test-token") executor = MagicMock() - mock_http_client = MagicMock() + mock_client_context = MagicMock() + + # Mock the singleton to return a mock HTTP client + mock_singleton = mock_singleton_class.return_value + mock_singleton.get_http_client.return_value = MagicMock() return TelemetryClient( telemetry_enabled=True, @@ -33,7 +38,7 @@ def mock_telemetry_client(): host_url="test-host.com", executor=executor, batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - http_client=mock_http_client, + client_context=mock_client_context, ) @@ -212,11 +217,16 @@ def telemetry_system_reset(self): TelemetryClientFactory._executor = None TelemetryClientFactory._initialized = False - def test_client_lifecycle_flow(self): + @patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") + def test_client_lifecycle_flow(self, mock_singleton_class): """Test complete client lifecycle: initialize -> use -> close.""" session_id_hex = "test-session" auth_provider = AccessTokenAuthProvider("token") - mock_http_client = MagicMock() + mock_client_context = MagicMock() + + # Mock the singleton to return a mock HTTP client + mock_singleton = mock_singleton_class.return_value + mock_singleton.get_http_client.return_value = MagicMock() # Initialize enabled client TelemetryClientFactory.initialize_telemetry_client( @@ -225,7 +235,7 @@ def test_client_lifecycle_flow(self): auth_provider=auth_provider, host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - http_client=mock_http_client, + client_context=mock_client_context, ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) @@ -241,10 +251,15 @@ def test_client_lifecycle_flow(self): client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) - def test_disabled_telemetry_flow(self): + 
@patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") + def test_disabled_telemetry_flow(self, mock_singleton_class): """Test that disabled telemetry uses NoopTelemetryClient.""" session_id_hex = "test-session" - mock_http_client = MagicMock() + mock_client_context = MagicMock() + + # Mock the singleton to return a mock HTTP client + mock_singleton = mock_singleton_class.return_value + mock_singleton.get_http_client.return_value = MagicMock() TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=False, @@ -252,16 +267,21 @@ def test_disabled_telemetry_flow(self): auth_provider=None, host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - http_client=mock_http_client, + client_context=mock_client_context, ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) - def test_factory_error_handling(self): + @patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") + def test_factory_error_handling(self, mock_singleton_class): """Test that factory errors fall back to NoopTelemetryClient.""" session_id = "test-session" - mock_http_client = MagicMock() + mock_client_context = MagicMock() + + # Mock the singleton to return a mock HTTP client + mock_singleton = mock_singleton_class.return_value + mock_singleton.get_http_client.return_value = MagicMock() # Simulate initialization error with patch( @@ -274,18 +294,23 @@ def test_factory_error_handling(self): auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - http_client=mock_http_client, + client_context=mock_client_context, ) # Should fall back to NoopTelemetryClient client = TelemetryClientFactory.get_telemetry_client(session_id) assert isinstance(client, NoopTelemetryClient) - def test_factory_shutdown_flow(self): + @patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") + def test_factory_shutdown_flow(self, mock_singleton_class): """Test factory shutdown when last client is removed.""" session1 = "session-1" session2 = "session-2" - mock_http_client = MagicMock() + mock_client_context = MagicMock() + + # Mock the singleton to return a mock HTTP client + mock_singleton = mock_singleton_class.return_value + mock_singleton.get_http_client.return_value = MagicMock() # Initialize multiple clients for session in [session1, session2]: @@ -295,7 +320,7 @@ def test_factory_shutdown_flow(self): auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - http_client=mock_http_client, + client_context=mock_client_context, ) # Factory should be initialized diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 0cdb43f5c..396e0e3f1 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -83,6 +83,7 @@ def test_make_request_checks_thrift_status_code(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError): thrift_backend.make_request(mock_method, Mock()) @@ -102,6 +103,7 @@ def _make_fake_thrift_backend(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) thrift_backend._hive_schema_to_arrow_schema = Mock() thrift_backend._hive_schema_to_description = Mock() @@ -196,6 +198,7 @@ def test_headers_are_set(self, t_http_client_class): [("header", "value")], 
auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) t_http_client_class.return_value.setCustomHeaders.assert_called_with( {"header": "value"} @@ -243,6 +246,7 @@ def test_tls_cert_args_are_propagated( [], auth_provider=AuthProvider(), ssl_options=mock_ssl_options, + http_client=MagicMock(), ) mock_ssl_context.load_cert_chain.assert_called_once_with( @@ -329,6 +333,7 @@ def test_tls_no_verify_is_respected( [], auth_provider=AuthProvider(), ssl_options=mock_ssl_options, + http_client=MagicMock(), ) self.assertFalse(mock_ssl_context.check_hostname) @@ -353,6 +358,7 @@ def test_tls_verify_hostname_is_respected( [], auth_provider=AuthProvider(), ssl_options=mock_ssl_options, + http_client=MagicMock(), ) self.assertFalse(mock_ssl_context.check_hostname) @@ -370,6 +376,7 @@ def test_port_and_host_are_respected(self, t_http_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) self.assertEqual( t_http_client_class.call_args[1]["uri_or_host"], @@ -385,6 +392,7 @@ def test_host_with_https_does_not_duplicate(self, t_http_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) self.assertEqual( t_http_client_class.call_args[1]["uri_or_host"], @@ -400,6 +408,7 @@ def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_cla [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) self.assertEqual( t_http_client_class.call_args[1]["uri_or_host"], @@ -415,6 +424,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), _socket_timeout=129, ) self.assertEqual( @@ -427,6 +437,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), _socket_timeout=0, ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 0) @@ -437,6 +448,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) self.assertEqual( t_http_client_class.return_value.setTimeout.call_args[0][0], 900 * 1000 @@ -448,6 +460,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), _socket_timeout=None, ) self.assertEqual( @@ -559,6 +572,7 @@ def test_make_request_checks_status_code(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) for code in error_codes: @@ -604,6 +618,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -647,6 +662,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) execute_response, _ = thrift_backend._handle_execute_response( @@ -691,6 +707,7 @@ def test_handle_execute_response_checks_operation_state_in_polls( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -729,6 +746,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + 
http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: thrift_backend.execute_command( @@ -772,6 +790,7 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: thrift_backend.execute_command( @@ -840,6 +859,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -892,6 +912,7 @@ def test_handle_execute_response_can_handle_without_direct_results( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) ( execute_response, @@ -930,6 +951,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) thrift_backend._results_message_to_execute_response = Mock() @@ -1154,6 +1176,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) arrow_queue, has_more_results, _ = thrift_backend.fetch_results( command_id=Mock(), @@ -1183,6 +1206,7 @@ def test_execute_statement_calls_client_and_handle_execute_response( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) thrift_backend._handle_execute_response = Mock() thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) @@ -1219,6 +1243,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) thrift_backend._handle_execute_response = Mock() thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) @@ -1252,6 +1277,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) thrift_backend._handle_execute_response = Mock() thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) @@ -1294,6 +1320,7 @@ def test_get_tables_calls_client_and_handle_execute_response( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) thrift_backend._handle_execute_response = Mock() thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) @@ -1340,6 +1367,7 @@ def test_get_columns_calls_client_and_handle_execute_response( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) thrift_backend._handle_execute_response = Mock() thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) @@ -1383,6 +1411,7 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) thrift_backend.open_session({}, None, None) self.assertEqual(len(tcli_service_instance.OpenSession.call_args_list), 1) @@ -1397,6 +1426,7 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) command_id = CommandId.from_thrift_handle(self.operation_handle) thrift_backend.close_command(command_id) @@ -1415,6 +1445,7 @@ def test_session_handle_respected_in_close_session(self, tcli_service_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), 
+ http_client=MagicMock(), ) session_id = SessionId.from_thrift_handle(self.session_handle) thrift_backend.close_session(session_id) @@ -1470,6 +1501,7 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) @@ -1490,6 +1522,7 @@ def test_create_arrow_table_calls_correct_conversion_method( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) convert_arrow_mock.return_value = (MagicMock(), Mock()) convert_col_mock.return_value = (MagicMock(), Mock()) @@ -1525,6 +1558,7 @@ def test_convert_arrow_based_set_to_arrow_table( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) lz4_decompress_mock.return_value = bytearray("Testing", "utf-8") @@ -1745,6 +1779,7 @@ def test_make_request_will_retry_GetOperationStatus( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), _retry_stop_after_attempts_count=EXPECTED_RETRIES, _retry_delay_default=1, ) @@ -1823,6 +1858,7 @@ def test_make_request_will_retry_GetOperationStatus_for_http_error( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), _retry_stop_after_attempts_count=EXPECTED_RETRIES, _retry_delay_default=1, ) @@ -1855,6 +1891,7 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(OperationalError) as cm: @@ -1884,6 +1921,7 @@ def test_make_request_will_retry_stop_after_attempts_count_if_retryable( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), _retry_stop_after_attempts_count=14, _retry_delay_max=0, _retry_delay_min=0, @@ -1913,6 +1951,7 @@ def test_make_request_will_read_error_message_headers_if_set( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) error_headers = [ @@ -2037,6 +2076,7 @@ def test_retry_args_passthrough(self, mock_http_client): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), **retry_delay_args, ) for arg, val in retry_delay_args.items(): @@ -2068,6 +2108,7 @@ def test_retry_args_bounding(self, mock_http_client): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), **retry_delay_args, ) retry_delay_expected_vals = { @@ -2096,6 +2137,7 @@ def test_configuration_passthrough(self, tcli_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) backend.open_session(mock_config, None, None) @@ -2114,6 +2156,7 @@ def test_cant_set_timestamp_as_string_to_true(self, tcli_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(databricks.sql.Error) as cm: @@ -2141,6 +2184,7 @@ def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) initial_cat_schem_args = [("cat", None), (None, "schem"), ("cat", "schem")] @@ -2172,6 +2216,7 @@ def test_can_use_multiple_catalogs_is_set_in_open_session_req( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) backend.open_session({}, None, None) @@ -2191,6 +2236,7 @@ def 
test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) # If the initial catalog is set, but server returns canUseMultipleCatalogs=False, we # expect failure. If the initial catalog isn't set, then canUseMultipleCatalogs=False @@ -2237,6 +2283,7 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(InvalidServerResponseError) as cm: @@ -2283,6 +2330,7 @@ def test_execute_command_sets_complex_type_fields_correctly( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), **complex_arg_types, ) thrift_backend.execute_command( From 1dd40a10b0227e837e5adda96816f2d54c22d58d Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Tue, 12 Aug 2025 11:38:28 +0530 Subject: [PATCH 14/25] review comments Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/auth.py | 2 +- src/databricks/sql/auth/common.py | 36 +++++----- src/databricks/sql/auth/oauth.py | 4 +- src/databricks/sql/client.py | 70 +++++-------------- src/databricks/sql/cloudfetch/downloader.py | 22 +++--- src/databricks/sql/common/feature_flag.py | 4 +- .../sql/common/unified_http_client.py | 24 +++++-- src/databricks/sql/session.py | 4 -- .../sql/telemetry/telemetry_client.py | 3 +- src/databricks/sql/utils.py | 59 ++++++++++++++-- 10 files changed, 127 insertions(+), 101 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 59da2a422..a8accac06 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -56,7 +56,7 @@ def get_auth_provider(cfg: ClientContext, http_client): cfg.oauth_client_id, cfg.oauth_scopes, http_client, - cfg.auth_type or "databricks-oauth", + cfg.auth_type or AuthType.DATABRICKS_OAUTH.value, ) else: raise RuntimeError("No valid authentication settings!") diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index e80fac189..4ae7afb0b 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -2,6 +2,8 @@ import logging from typing import Optional, List from urllib.parse import urlparse +from databricks.sql.auth.retry import DatabricksRetryPolicy +from databricks.sql.common.http import HttpMethod logger = logging.getLogger(__name__) @@ -38,17 +40,17 @@ def __init__( # HTTP client configuration parameters ssl_options=None, # SSLOptions type socket_timeout: Optional[float] = None, - retry_stop_after_attempts_count: Optional[int] = None, - retry_delay_min: Optional[float] = None, - retry_delay_max: Optional[float] = None, - retry_stop_after_attempts_duration: Optional[float] = None, - retry_delay_default: Optional[float] = None, + retry_stop_after_attempts_count: int = 5, + retry_delay_min: float = 1.0, + retry_delay_max: float = 60.0, + retry_stop_after_attempts_duration: float = 900.0, + retry_delay_default: float = 5.0, retry_dangerous_codes: Optional[List[int]] = None, http_proxy: Optional[str] = None, proxy_username: Optional[str] = None, proxy_password: Optional[str] = None, - pool_connections: Optional[int] = None, - pool_maxsize: Optional[int] = None, + pool_connections: int = 10, + pool_maxsize: int = 20, user_agent: Optional[str] = None, ): self.hostname = hostname @@ -69,19 +71,17 @@ def __init__( # HTTP client configuration self.ssl_options = ssl_options self.socket_timeout = socket_timeout - 
self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 5 - self.retry_delay_min = retry_delay_min or 1.0 - self.retry_delay_max = retry_delay_max or 60.0 - self.retry_stop_after_attempts_duration = ( - retry_stop_after_attempts_duration or 900.0 - ) - self.retry_delay_default = retry_delay_default or 5.0 + self.retry_stop_after_attempts_count = retry_stop_after_attempts_count + self.retry_delay_min = retry_delay_min + self.retry_delay_max = retry_delay_max + self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration + self.retry_delay_default = retry_delay_default self.retry_dangerous_codes = retry_dangerous_codes or [] self.http_proxy = http_proxy self.proxy_username = proxy_username self.proxy_password = proxy_password - self.pool_connections = pool_connections or 10 - self.pool_maxsize = pool_maxsize or 20 + self.pool_connections = pool_connections + self.pool_maxsize = pool_maxsize self.user_agent = user_agent @@ -113,7 +113,9 @@ def get_azure_tenant_id_from_host(host: str, http_client) -> str: login_url = f"{host}/aad/auth" logger.debug("Loading tenant ID from %s", login_url) - with http_client.request_context("GET", login_url, allow_redirects=False) as resp: + with http_client.request_context( + HttpMethod.GET, login_url, allow_redirects=False + ) as resp: if resp.status // 100 != 3: raise ValueError( f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}" diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 9fdf3955a..09753c9ff 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -87,7 +87,7 @@ def __fetch_well_known_config(self, hostname: str): known_config_url = self.idp_endpoint.get_openid_config_url(hostname) try: - response = self.http_client.request("GET", url=known_config_url) + response = self.http_client.request(HttpMethod.GET, url=known_config_url) # Convert urllib3 response to requests-like response for compatibility response.status_code = response.status response.json = lambda: json.loads(response.data.decode()) @@ -197,7 +197,7 @@ def __send_token_request(self, token_request_url, data): } # Use unified HTTP client response = self.http_client.request( - "POST", url=token_request_url, body=data, headers=headers + HttpMethod.POST, url=token_request_url, body=data, headers=headers ) # Convert urllib3 response to dict for compatibility return json.loads(response.data.decode()) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 634c7e261..3cd7bcacf 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -31,6 +31,7 @@ transform_paramstyle, ColumnTable, ColumnQueue, + build_client_context, ) from databricks.sql.parameters.native import ( DbsqlParameterBase, @@ -52,6 +53,7 @@ from databricks.sql.auth.common import ClientContext from databricks.sql.common.unified_http_client import UnifiedHttpClient +from databricks.sql.common.http import HttpMethod from databricks.sql.thrift_api.TCLIService.ttypes import ( TOpenSessionResp, @@ -254,14 +256,14 @@ def read(self) -> Optional[OAuthToken]: "telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE ) - client_context = self._build_client_context(server_hostname, **kwargs) - http_client = UnifiedHttpClient(client_context) + client_context = build_client_context(server_hostname, __version__, **kwargs) + self.http_client = UnifiedHttpClient(client_context) try: self.session = Session( server_hostname, http_path, - http_client, + 
self.http_client, http_headers, session_configuration, catalog, @@ -350,50 +352,6 @@ def _set_use_inline_params_with_warning(self, value: Union[bool, str]): return value - def _build_client_context(self, server_hostname: str, **kwargs): - """Build ClientContext for HTTP client configuration.""" - from databricks.sql.auth.common import ClientContext - from databricks.sql.types import SSLOptions - - # Extract SSL options - ssl_options = SSLOptions( - tls_verify=not kwargs.get("_tls_no_verify", False), - tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), - tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), - tls_client_cert_file=kwargs.get("_tls_client_cert_file"), - tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), - tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), - ) - - # Build user agent - user_agent_entry = kwargs.get("user_agent_entry", "") - if user_agent_entry: - user_agent = f"PyDatabricksSqlConnector/{__version__} ({user_agent_entry})" - else: - user_agent = f"PyDatabricksSqlConnector/{__version__}" - - return ClientContext( - hostname=server_hostname, - ssl_options=ssl_options, - socket_timeout=kwargs.get("_socket_timeout"), - retry_stop_after_attempts_count=kwargs.get( - "_retry_stop_after_attempts_count" - ), - retry_delay_min=kwargs.get("_retry_delay_min"), - retry_delay_max=kwargs.get("_retry_delay_max"), - retry_stop_after_attempts_duration=kwargs.get( - "_retry_stop_after_attempts_duration" - ), - retry_delay_default=kwargs.get("_retry_delay_default"), - retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), - http_proxy=kwargs.get("_http_proxy"), - proxy_username=kwargs.get("_proxy_username"), - proxy_password=kwargs.get("_proxy_password"), - pool_connections=kwargs.get("_pool_connections"), - pool_maxsize=kwargs.get("_pool_maxsize"), - user_agent=user_agent, - ) - # The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently. def __enter__(self) -> "Connection": return self @@ -447,7 +405,7 @@ def get_protocol_version(openSessionResp: TOpenSessionResp): @property def open(self) -> bool: """Return whether the connection is open by checking if the session is open.""" - return hasattr(self, "session") and self.session.is_open + return self.session.is_open def cursor( self, @@ -497,6 +455,10 @@ def _close(self, close_cursors=True) -> None: TelemetryClientFactory.close(self.get_session_id_hex()) + # Close HTTP client that was created by this connection + if self.http_client: + self.http_client.close() + def commit(self): """No-op because Databricks does not support transactions""" pass @@ -796,8 +758,8 @@ def _handle_staging_put( ) with open(local_file, "rb") as fh: - r = self.connection.session.http_client.request( - "PUT", presigned_url, body=fh.read(), headers=headers + r = self.connection.http_client.request( + HttpMethod.PUT, presigned_url, body=fh.read(), headers=headers ) # fmt: off @@ -837,8 +799,8 @@ def _handle_staging_get( session_id_hex=self.connection.get_session_id_hex(), ) - r = self.connection.session.http_client.request( - "GET", presigned_url, headers=headers + r = self.connection.http_client.request( + HttpMethod.GET, presigned_url, headers=headers ) # response.ok verifies the status code is not between 400-600. 
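The staging handlers above and below now share one call pattern: the connection-owned UnifiedHttpClient issues the presigned-URL request and returns a urllib3 response, so success is checked via the integer `response.status` rather than requests' `response.ok`. A minimal sketch of that pattern, mirroring the PUT hunk above (`connection`, `presigned_url`, and `headers` are placeholder values, not part of this diff):

from databricks.sql.common.http import HttpMethod

def put_staging_file(connection, presigned_url, local_file, headers):
    # Upload the local file body through the shared urllib3-backed client.
    with open(local_file, "rb") as fh:
        response = connection.http_client.request(
            HttpMethod.PUT, presigned_url, body=fh.read(), headers=headers
        )
    # urllib3 responses expose `status` (an int); there is no `status_code` or `ok`.
    if response.status >= 400:
        raise RuntimeError(f"Staging PUT failed with HTTP {response.status}")
    return response
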
@@ -860,8 +822,8 @@ def _handle_staging_remove( ): """Make an HTTP DELETE request to the presigned_url""" - r = self.connection.session.http_client.request( - "DELETE", presigned_url, headers=headers + r = self.connection.http_client.request( + HttpMethod.DELETE, presigned_url, headers=headers ) if r.status >= 400: diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index a2a7837f0..e6d1c6d10 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -10,6 +10,7 @@ from databricks.sql.types import SSLOptions from databricks.sql.telemetry.latency_logger import log_latency from databricks.sql.telemetry.models.event import StatementType +from databricks.sql.common.unified_http_client import UnifiedHttpClient logger = logging.getLogger(__name__) @@ -79,9 +80,10 @@ def run(self) -> DownloadedFile: """ logger.debug( - "ResultSetDownloadHandler: starting file download, chunk id {}, offset {}, row count {}".format( - self.chunk_id, self.link.startRowOffset, self.link.rowCount - ) + "ResultSetDownloadHandler: starting file download, chunk id %s, offset %s, row count %s", + self.chunk_id, + self.link.startRowOffset, + self.link.rowCount, ) # Check if link is already expired or is expiring @@ -92,7 +94,7 @@ def run(self) -> DownloadedFile: start_time = time.time() with self._http_client.request_context( - method="GET", + method=HttpMethod.GET, url=self.link.fileLink, timeout=self.settings.download_timeout, headers=self.link.httpHeaders, @@ -116,15 +118,15 @@ def run(self) -> DownloadedFile: # The size of the downloaded file should match the size specified from TSparkArrowResultLink if len(decompressed_data) != self.link.bytesNum: logger.debug( - "ResultSetDownloadHandler: downloaded file size {} does not match the expected value {}".format( - len(decompressed_data), self.link.bytesNum - ) + "ResultSetDownloadHandler: downloaded file size %s does not match the expected value %s", + len(decompressed_data), + self.link.bytesNum, ) logger.debug( - "ResultSetDownloadHandler: successfully downloaded file, offset {}, row count {}".format( - self.link.startRowOffset, self.link.rowCount - ) + "ResultSetDownloadHandler: successfully downloaded file, offset %s, row count %s", + self.link.startRowOffset, + self.link.rowCount, ) return DownloadedFile( diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 2b7e27ab3..8a1cf5bd5 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -5,6 +5,8 @@ from concurrent.futures import ThreadPoolExecutor from typing import Dict, Optional, List, Any, TYPE_CHECKING +from databricks.sql.common.http import HttpMethod + if TYPE_CHECKING: from databricks.sql.client import Connection @@ -111,7 +113,7 @@ def _refresh_flags(self): headers["User-Agent"] = self._connection.session.useragent_header response = self._http_client.request( - "GET", self._feature_flag_endpoint, headers=headers, timeout=30 + HttpMethod.GET, self._feature_flag_endpoint, headers=headers, timeout=30 ) if response.status == 200: diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index bb26ae9de..62cfb3001 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -2,7 +2,7 @@ import ssl import urllib.parse from contextlib import contextmanager -from typing import Dict, Any, Optional, Generator, 
Union +from typing import Dict, Any, Optional, Generator import urllib3 from urllib3 import PoolManager, ProxyManager @@ -11,6 +11,7 @@ from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType from databricks.sql.exc import RequestError +from databricks.sql.common.http import HttpMethod logger = logging.getLogger(__name__) @@ -135,13 +136,17 @@ def _prepare_retry_policy(self): @contextmanager def request_context( - self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs, ) -> Generator[urllib3.HTTPResponse, None, None]: """ Context manager for making HTTP requests with proper resource cleanup. Args: - method: HTTP method (GET, POST, PUT, DELETE) + method: HTTP method (HttpMethod.GET, HttpMethod.POST, HttpMethod.PUT, HttpMethod.DELETE) url: URL to request headers: Optional headers dict **kwargs: Additional arguments passed to urllib3 request @@ -160,7 +165,7 @@ def request_context( try: response = self._pool_manager.request( - method=method, url=url, headers=request_headers, **kwargs + method=method.value, url=url, headers=request_headers, **kwargs ) yield response except MaxRetryError as e: @@ -174,22 +179,27 @@ def request_context( response.close() def request( - self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs, ) -> urllib3.HTTPResponse: """ Make an HTTP request. Args: - method: HTTP method (GET, POST, PUT, DELETE, etc.) + method: HTTP method (HttpMethod.GET, HttpMethod.POST, HttpMethod.PUT, HttpMethod.DELETE, etc.) url: URL to request headers: Optional headers dict **kwargs: Additional arguments passed to urllib3 request Returns: - urllib3.HTTPResponse: The HTTP response object with data pre-loaded + urllib3.HTTPResponse: The HTTP response object with data and metadata pre-loaded """ with self.request_context(method, url, headers=headers, **kwargs) as response: # Read the response data to ensure it's available after context exit + # Note: status and headers remain accessible after close(), only data needs caching response._body = response.data return response diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 0cba8be48..d8ba5d125 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -193,7 +193,3 @@ def close(self) -> None: logger.error("Attempt to close session raised a local exception: %s", e) self.is_open = False - - # Close HTTP client if it exists - if hasattr(self, "http_client") and self.http_client: - self.http_client.close() diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 9887b67a7..29935dc3a 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -40,6 +40,7 @@ from databricks.sql.telemetry.utils import BaseTelemetryClient from databricks.sql.common.feature_flag import FeatureFlagsContextFactory from databricks.sql.common.unified_http_client import UnifiedHttpClient +from databricks.sql.common.http import HttpMethod if TYPE_CHECKING: from databricks.sql.client import Connection @@ -293,7 +294,7 @@ def _send_with_unified_client(self, url, data, headers, timeout=900): """Helper method to send telemetry using the unified HTTP client.""" try: response = self._http_client.request( - "POST", url, body=data, headers=headers, 
timeout=timeout + HttpMethod.POST, url, body=data, headers=headers, timeout=timeout ) # Convert urllib3 response to requests-like response for compatibility response.status_code = response.status diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index ff48e0e91..7b9746df9 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -105,16 +105,16 @@ def build_queue( elif row_set_type == TSparkRowSetType.URL_BASED_SET: return ThriftCloudFetchQueue( schema_bytes=arrow_schema_bytes, + start_row_offset=t_row_set.startRowOffset, + result_links=t_row_set.resultLinks, + lz4_compressed=lz4_compressed, + description=description, max_download_threads=max_download_threads, ssl_options=ssl_options, session_id_hex=session_id_hex, statement_id=statement_id, chunk_id=chunk_id, http_client=http_client, - start_row_offset=t_row_set.startRowOffset, - result_links=t_row_set.resultLinks, - lz4_compressed=lz4_compressed, - description=description, ) else: raise AssertionError("Row set type is not valid") @@ -882,3 +882,54 @@ def concat_table_chunks( return ColumnTable(result_table, table_chunks[0].column_names) else: return pyarrow.concat_tables(table_chunks, use_threads=True) + + +def build_client_context(server_hostname: str, version: str, **kwargs): + """Build ClientContext for HTTP client configuration.""" + from databricks.sql.auth.common import ClientContext + from databricks.sql.types import SSLOptions + + # Extract SSL options + ssl_options = SSLOptions( + tls_verify=not kwargs.get("_tls_no_verify", False), + tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), + tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file"), + tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), + tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), + ) + + # Build user agent + user_agent_entry = kwargs.get("user_agent_entry", "") + if user_agent_entry: + user_agent = f"PyDatabricksSqlConnector/{version} ({user_agent_entry})" + else: + user_agent = f"PyDatabricksSqlConnector/{version}" + + # Build ClientContext kwargs, excluding None values to use defaults + context_kwargs = { + "hostname": server_hostname, + "ssl_options": ssl_options, + "user_agent": user_agent, + } + + # Only add non-None values to let defaults work + for param, kwarg_key in [ + ("socket_timeout", "_socket_timeout"), + ("retry_stop_after_attempts_count", "_retry_stop_after_attempts_count"), + ("retry_delay_min", "_retry_delay_min"), + ("retry_delay_max", "_retry_delay_max"), + ("retry_stop_after_attempts_duration", "_retry_stop_after_attempts_duration"), + ("retry_delay_default", "_retry_delay_default"), + ("retry_dangerous_codes", "_retry_dangerous_codes"), + ("http_proxy", "_http_proxy"), + ("proxy_username", "_proxy_username"), + ("proxy_password", "_proxy_password"), + ("pool_connections", "_pool_connections"), + ("pool_maxsize", "_pool_maxsize"), + ]: + value = kwargs.get(kwarg_key) + if value is not None: + context_kwargs[param] = value + + return ClientContext(**context_kwargs) From 3847acac62be645cba8294c513399f75eb9bc1b2 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Tue, 12 Aug 2025 12:23:53 +0530 Subject: [PATCH 15/25] fix warnings Signed-off-by: Vikrant Puppala --- tests/unit/test_telemetry.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 0e828497f..ab07b400c 100644 --- a/tests/unit/test_telemetry.py +++ 
b/tests/unit/test_telemetry.py @@ -349,7 +349,11 @@ def test_connection_failure_sends_correct_telemetry_payload( """ error_message = "Could not connect to host" - mock_session.side_effect = Exception(error_message) + # Set up the mock to create a session instance first, then make open() fail + mock_session_instance = MagicMock() + mock_session_instance.is_open = False # Ensure cleanup is safe + mock_session_instance.open.side_effect = Exception(error_message) + mock_session.return_value = mock_session_instance try: sql.connect(server_hostname="test-host", http_path="/test-path") @@ -391,6 +395,7 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + mock_session_instance.is_open = False # Connection starts closed for test cleanup # Set up mock HTTP client on the session mock_http_client = MagicMock() @@ -418,6 +423,7 @@ def test_telemetry_disabled_when_flag_is_false( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + mock_session_instance.is_open = False # Connection starts closed for test cleanup # Set up mock HTTP client on the session mock_http_client = MagicMock() @@ -445,6 +451,7 @@ def test_telemetry_disabled_when_flag_request_fails( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + mock_session_instance.is_open = False # Connection starts closed for test cleanup # Set up mock HTTP client on the session mock_http_client = MagicMock() From d9a4797bd54632375eb3487233926c208e2abfbc Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Tue, 12 Aug 2025 12:41:42 +0530 Subject: [PATCH 16/25] fix check-types Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/common.py | 30 ++++++++++--------- src/databricks/sql/utils.py | 50 +++++++++++++------------------ 2 files changed, 37 insertions(+), 43 deletions(-) diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 4ae7afb0b..36a0f3707 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -40,17 +40,17 @@ def __init__( # HTTP client configuration parameters ssl_options=None, # SSLOptions type socket_timeout: Optional[float] = None, - retry_stop_after_attempts_count: int = 5, - retry_delay_min: float = 1.0, - retry_delay_max: float = 60.0, - retry_stop_after_attempts_duration: float = 900.0, - retry_delay_default: float = 5.0, + retry_stop_after_attempts_count: Optional[int] = None, + retry_delay_min: Optional[float] = None, + retry_delay_max: Optional[float] = None, + retry_stop_after_attempts_duration: Optional[float] = None, + retry_delay_default: Optional[float] = None, retry_dangerous_codes: Optional[List[int]] = None, http_proxy: Optional[str] = None, proxy_username: Optional[str] = None, proxy_password: Optional[str] = None, - pool_connections: int = 10, - pool_maxsize: int = 20, + pool_connections: Optional[int] = None, + pool_maxsize: Optional[int] = None, user_agent: Optional[str] = None, ): self.hostname = hostname @@ -71,17 +71,19 @@ def __init__( # HTTP client configuration self.ssl_options = ssl_options self.socket_timeout = socket_timeout - self.retry_stop_after_attempts_count = 
retry_stop_after_attempts_count - self.retry_delay_min = retry_delay_min - self.retry_delay_max = retry_delay_max - self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration - self.retry_delay_default = retry_delay_default + self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 5 + self.retry_delay_min = retry_delay_min or 1.0 + self.retry_delay_max = retry_delay_max or 10.0 + self.retry_stop_after_attempts_duration = ( + retry_stop_after_attempts_duration or 300.0 + ) + self.retry_delay_default = retry_delay_default or 5.0 self.retry_dangerous_codes = retry_dangerous_codes or [] self.http_proxy = http_proxy self.proxy_username = proxy_username self.proxy_password = proxy_password - self.pool_connections = pool_connections - self.pool_maxsize = pool_maxsize + self.pool_connections = pool_connections or 10 + self.pool_maxsize = pool_maxsize or 20 self.user_agent = user_agent diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 7b9746df9..ce2ba5eaf 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Sequence from dateutil import parser import datetime @@ -9,7 +9,6 @@ from collections.abc import Mapping from decimal import Decimal from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Union, Sequence import re import lz4.frame @@ -906,30 +905,23 @@ def build_client_context(server_hostname: str, version: str, **kwargs): else: user_agent = f"PyDatabricksSqlConnector/{version}" - # Build ClientContext kwargs, excluding None values to use defaults - context_kwargs = { - "hostname": server_hostname, - "ssl_options": ssl_options, - "user_agent": user_agent, - } - - # Only add non-None values to let defaults work - for param, kwarg_key in [ - ("socket_timeout", "_socket_timeout"), - ("retry_stop_after_attempts_count", "_retry_stop_after_attempts_count"), - ("retry_delay_min", "_retry_delay_min"), - ("retry_delay_max", "_retry_delay_max"), - ("retry_stop_after_attempts_duration", "_retry_stop_after_attempts_duration"), - ("retry_delay_default", "_retry_delay_default"), - ("retry_dangerous_codes", "_retry_dangerous_codes"), - ("http_proxy", "_http_proxy"), - ("proxy_username", "_proxy_username"), - ("proxy_password", "_proxy_password"), - ("pool_connections", "_pool_connections"), - ("pool_maxsize", "_pool_maxsize"), - ]: - value = kwargs.get(kwarg_key) - if value is not None: - context_kwargs[param] = value - - return ClientContext(**context_kwargs) + # Explicitly construct ClientContext with proper types + return ClientContext( + hostname=server_hostname, + ssl_options=ssl_options, + user_agent=user_agent, + socket_timeout=kwargs.get("_socket_timeout"), + retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count"), + retry_delay_min=kwargs.get("_retry_delay_min"), + retry_delay_max=kwargs.get("_retry_delay_max"), + retry_stop_after_attempts_duration=kwargs.get( + "_retry_stop_after_attempts_duration" + ), + retry_delay_default=kwargs.get("_retry_delay_default"), + retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), + http_proxy=kwargs.get("_http_proxy"), + proxy_username=kwargs.get("_proxy_username"), + proxy_password=kwargs.get("_proxy_password"), + pool_connections=kwargs.get("_pool_connections"), + pool_maxsize=kwargs.get("_pool_maxsize"), + ) From ba2a3a9827478dad7e983c8d0ac4c9e49d4bcb21 Mon Sep 17 00:00:00 
2001 From: Vikrant Puppala Date: Tue, 12 Aug 2025 22:38:22 +0530 Subject: [PATCH 17/25] remove separate http client for telemetry Signed-off-by: Vikrant Puppala --- src/databricks/sql/client.py | 4 +- .../sql/telemetry/telemetry_client.py | 55 +++-------------- tests/unit/test_telemetry.py | 59 +++++-------------- 3 files changed, 26 insertions(+), 92 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 3cd7bcacf..ecdf66401 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -279,7 +279,7 @@ def read(self) -> Optional[OAuthToken]: host_url=server_hostname, http_path=http_path, port=kwargs.get("_port", 443), - client_context=client_context, + http_client=self.http_client, user_agent=self.session.useragent_header if hasattr(self, "session") else None, @@ -301,7 +301,7 @@ def read(self) -> Optional[OAuthToken]: auth_provider=self.session.auth_provider, host_url=self.session.host, batch_size=self.telemetry_batch_size, - client_context=client_context, + http_client=self.http_client, ) self._telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 29935dc3a..f933885cf 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -154,44 +154,6 @@ def _flush(self): pass -class TelemetryHttpClientSingleton: - """ - Singleton HTTP client for telemetry operations. - - This ensures that telemetry has its own dedicated HTTP client that - is independent of individual connection lifecycles. - """ - - _instance = None - _lock = threading.RLock() - - def __new__(cls): - if cls._instance is None: - with cls._lock: - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._http_client = None - cls._instance._initialized = False - return cls._instance - - def get_http_client(self, client_context): - """Get or create the singleton HTTP client.""" - if not self._initialized and client_context: - with self._lock: - if not self._initialized: - self._http_client = UnifiedHttpClient(client_context) - self._initialized = True - return self._http_client - - def close(self): - """Close the singleton HTTP client.""" - with self._lock: - if self._http_client: - self._http_client.close() - self._http_client = None - self._initialized = False - - class TelemetryClient(BaseTelemetryClient): """ Telemetry client class that handles sending telemetry events in batches to the server. 
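With the singleton gone, telemetry simply borrows the HTTP client that the owning Connection already created, and the factory passes it straight through, as the hunks below show. A rough sketch of the resulting wiring, with heavily simplified signatures and hypothetical argument values:

from databricks.sql.common.http import HttpMethod

class TelemetryClientSketch:
    # Simplified stand-in for the real TelemetryClient after this patch.
    def __init__(self, session_id_hex, host_url, http_client):
        self._session_id_hex = session_id_hex
        self._host_url = host_url
        # Borrowed from the Connection; no per-telemetry pool, lock, or
        # singleton lifecycle to manage.
        self._http_client = http_client

    def send(self, url, data, headers, timeout=900):
        # Same POST path the real _send_with_unified_client uses.
        return self._http_client.request(
            HttpMethod.POST, url, body=data, headers=headers, timeout=timeout
        )

# Hypothetical wiring at connection-open time:
# client = TelemetryClientSketch("0a1b2c", "test-host.com", connection.http_client)
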
@@ -210,7 +172,7 @@ def __init__( host_url, executor, batch_size, - client_context, + http_client, ): logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled @@ -224,9 +186,8 @@ def __init__( self._host_url = host_url self._executor = executor - # Use singleton HTTP client for telemetry instead of connection-specific client - self._http_client_singleton = TelemetryHttpClientSingleton() - self._http_client = self._http_client_singleton.get_http_client(client_context) + # Use the provided HTTP client directly + self._http_client = http_client def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" @@ -503,7 +464,7 @@ def initialize_telemetry_client( auth_provider, host_url, batch_size, - client_context, + http_client, ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: @@ -526,7 +487,7 @@ def initialize_telemetry_client( host_url=host_url, executor=TelemetryClientFactory._executor, batch_size=batch_size, - client_context=client_context, + http_client=http_client, ) else: TelemetryClientFactory._clients[ @@ -579,10 +540,10 @@ def connection_failure_log( host_url: str, http_path: str, port: int, - client_context, + http_client, user_agent: Optional[str] = None, ): - """Send error telemetry when connection creation fails, without requiring a session""" + """Send error telemetry when connection creation fails, using existing HTTP client""" UNAUTH_DUMMY_SESSION_ID = "unauth_session_id" @@ -592,7 +553,7 @@ def connection_failure_log( auth_provider=None, host_url=host_url, batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - client_context=client_context, + http_client=http_client, ) telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index ab07b400c..ee0590ff8 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -19,17 +19,12 @@ @pytest.fixture -@patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") -def mock_telemetry_client(mock_singleton_class): +def mock_telemetry_client(): """Create a mock telemetry client for testing.""" session_id = str(uuid.uuid4()) auth_provider = AccessTokenAuthProvider("test-token") executor = MagicMock() - mock_client_context = MagicMock() - - # Mock the singleton to return a mock HTTP client - mock_singleton = mock_singleton_class.return_value - mock_singleton.get_http_client.return_value = MagicMock() + mock_http_client = MagicMock() return TelemetryClient( telemetry_enabled=True, @@ -38,7 +33,7 @@ def mock_telemetry_client(mock_singleton_class): host_url="test-host.com", executor=executor, batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - client_context=mock_client_context, + http_client=mock_http_client, ) @@ -217,16 +212,11 @@ def telemetry_system_reset(self): TelemetryClientFactory._executor = None TelemetryClientFactory._initialized = False - @patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") - def test_client_lifecycle_flow(self, mock_singleton_class): + def test_client_lifecycle_flow(self): """Test complete client lifecycle: initialize -> use -> close.""" session_id_hex = "test-session" auth_provider = AccessTokenAuthProvider("token") - mock_client_context = MagicMock() - - # Mock the singleton to return a mock HTTP client - mock_singleton = mock_singleton_class.return_value - mock_singleton.get_http_client.return_value = MagicMock() + 
mock_http_client = MagicMock() # Initialize enabled client TelemetryClientFactory.initialize_telemetry_client( @@ -235,7 +225,7 @@ def test_client_lifecycle_flow(self, mock_singleton_class): auth_provider=auth_provider, host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - client_context=mock_client_context, + http_client=mock_http_client, ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) @@ -248,18 +238,11 @@ def test_client_lifecycle_flow(self, mock_singleton_class): mock_close.assert_called_once() # Should get NoopTelemetryClient after close - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) - assert isinstance(client, NoopTelemetryClient) - @patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") - def test_disabled_telemetry_flow(self, mock_singleton_class): - """Test that disabled telemetry uses NoopTelemetryClient.""" + def test_disabled_telemetry_creates_noop_client(self): + """Test that disabled telemetry creates NoopTelemetryClient.""" session_id_hex = "test-session" - mock_client_context = MagicMock() - - # Mock the singleton to return a mock HTTP client - mock_singleton = mock_singleton_class.return_value - mock_singleton.get_http_client.return_value = MagicMock() + mock_http_client = MagicMock() TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=False, @@ -267,21 +250,16 @@ def test_disabled_telemetry_flow(self, mock_singleton_class): auth_provider=None, host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - client_context=mock_client_context, + http_client=mock_http_client, ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) - @patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") - def test_factory_error_handling(self, mock_singleton_class): + def test_factory_error_handling(self): """Test that factory errors fall back to NoopTelemetryClient.""" session_id = "test-session" - mock_client_context = MagicMock() - - # Mock the singleton to return a mock HTTP client - mock_singleton = mock_singleton_class.return_value - mock_singleton.get_http_client.return_value = MagicMock() + mock_http_client = MagicMock() # Simulate initialization error with patch( @@ -294,23 +272,18 @@ def test_factory_error_handling(self, mock_singleton_class): auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - client_context=mock_client_context, + http_client=mock_http_client, ) # Should fall back to NoopTelemetryClient client = TelemetryClientFactory.get_telemetry_client(session_id) assert isinstance(client, NoopTelemetryClient) - @patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") - def test_factory_shutdown_flow(self, mock_singleton_class): + def test_factory_shutdown_flow(self): """Test factory shutdown when last client is removed.""" session1 = "session-1" session2 = "session-2" - mock_client_context = MagicMock() - - # Mock the singleton to return a mock HTTP client - mock_singleton = mock_singleton_class.return_value - mock_singleton.get_http_client.return_value = MagicMock() + mock_http_client = MagicMock() # Initialize multiple clients for session in [session1, session2]: @@ -320,7 +293,7 @@ def test_factory_shutdown_flow(self, mock_singleton_class): auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", 
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - client_context=mock_client_context, + http_client=mock_http_client, ) # Factory should be initialized From d1f045ebe252883997bd1f67643da6da0cff5ada Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Tue, 12 Aug 2025 22:55:27 +0530 Subject: [PATCH 18/25] more clean up Signed-off-by: Vikrant Puppala --- src/databricks/sql/result_set.py | 2 +- .../sql/telemetry/telemetry_client.py | 24 +++++++------------ 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 77673db9a..6c4c3a43a 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -244,7 +244,7 @@ def __init__( session_id_hex=connection.get_session_id_hex(), statement_id=execute_response.command_id.to_hex_guid(), chunk_id=self.num_chunks, - http_client=connection.session.http_client, + http_client=connection.http_client, ) if t_row_set.resultLinks: self.num_chunks += len(t_row_set.resultLinks) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index f933885cf..f6ad4433d 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -257,18 +257,6 @@ def _send_with_unified_client(self, url, data, headers, timeout=900): response = self._http_client.request( HttpMethod.POST, url, body=data, headers=headers, timeout=timeout ) - # Convert urllib3 response to requests-like response for compatibility - response.status_code = response.status - response.ok = 200 <= response.status < 300 - response.json = ( - lambda: json.loads(response.data.decode()) if response.data else {} - ) - # Add raise_for_status method - def raise_for_status(): - if not response.ok: - raise Exception(f"HTTP {response.status_code}") - - response.raise_for_status = raise_for_status return response except Exception as e: logger.error("Failed to send telemetry with unified client: %s", e) @@ -279,14 +267,18 @@ def _telemetry_request_callback(self, future, sent_count: int): try: response = future.result() - if not response.ok: + # Check if response is successful (urllib3 uses response.status) + is_success = 200 <= response.status < 300 + if not is_success: logger.debug( "Telemetry request failed with status code: %s, response: %s", - response.status_code, - response.text, + response.status, + response.data.decode() if response.data else "", ) - telemetry_response = TelemetryResponse(**response.json()) + # Parse JSON response (urllib3 uses response.data) + response_data = json.loads(response.data.decode()) if response.data else {} + telemetry_response = TelemetryResponse(**response_data) logger.debug( "Pushed Telemetry logs with success count: %s, error count: %s", From 4e6623009be8c21c46088c4da420b1078dc80c59 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Wed, 13 Aug 2025 13:03:07 +0530 Subject: [PATCH 19/25] remove excess release_connection call --- src/databricks/sql/backend/thrift_backend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 0df35bab9..1654a1d5a 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -619,8 +619,7 @@ def close_session(self, session_id: SessionId) -> None: try: self.make_request(self._client.CloseSession, req) finally: - self._transport.release_connection() - self._transport.close() + 
self._transport.close() def _check_command_not_in_error_or_closed_state( self, op_handle, get_operations_resp From 67020f1afb82dd15527a94ce90c97e3b5c6b6e2c Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Wed, 13 Aug 2025 21:59:58 +0530 Subject: [PATCH 20/25] formatting (black) - fix some closures --- src/databricks/sql/backend/thrift_backend.py | 4 +-- src/databricks/sql/client.py | 2 ++ .../sql/telemetry/telemetry_client.py | 2 +- tests/e2e/common/staging_ingestion_tests.py | 4 +-- tests/e2e/common/uc_volume_tests.py | 4 +-- tests/e2e/test_concurrent_telemetry.py | 8 +++-- tests/e2e/test_driver.py | 6 ++-- tests/unit/test_auth.py | 32 +++++++++++------ tests/unit/test_cloud_fetch_queue.py | 36 ++++++++++++------- tests/unit/test_downloader.py | 13 ++++--- tests/unit/test_sea_queue.py | 2 +- tests/unit/test_telemetry.py | 32 +++++++++++------ tests/unit/test_thrift_backend.py | 16 ++++----- 13 files changed, 101 insertions(+), 60 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 1654a1d5a..b089eacd5 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -232,7 +232,7 @@ def __init__( try: self._transport.open() except: - self._transport.release_connection() + self._transport.close() raise self._request_lock = threading.RLock() @@ -607,7 +607,7 @@ def open_session(self, session_configuration, catalog, schema) -> SessionId: self._session_id_hex = session_id.hex_guid return session_id except: - self._transport.release_connection() + self._transport.close() raise def close_session(self, session_id: SessionId) -> None: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 3cd7bcacf..0d4f71ae3 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -284,6 +284,8 @@ def read(self) -> Optional[OAuthToken]: if hasattr(self, "session") else None, ) + if self.http_client: + self.http_client.close() raise e self.use_inline_params = self._set_use_inline_params_with_warning( diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 71fcc40c6..fb5c3a116 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -359,6 +359,7 @@ def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) self._flush() + self._http_client.close() class TelemetryClientFactory: @@ -460,7 +461,6 @@ def initialize_telemetry_client( ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: - with TelemetryClientFactory._lock: TelemetryClientFactory._initialize() diff --git a/tests/e2e/common/staging_ingestion_tests.py b/tests/e2e/common/staging_ingestion_tests.py index 73aa0a113..5b5f4e693 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/tests/e2e/common/staging_ingestion_tests.py @@ -80,9 +80,7 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): # GET after REMOVE should fail - with pytest.raises( - Error, match="too many 404 error responses" - ): + with pytest.raises(Error, match="too many 404 error responses"): cursor = conn.cursor() query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'" cursor.execute(query) diff --git a/tests/e2e/common/uc_volume_tests.py b/tests/e2e/common/uc_volume_tests.py index 93e63bd28..0eeb22789 100644 --- 
a/tests/e2e/common/uc_volume_tests.py +++ b/tests/e2e/common/uc_volume_tests.py @@ -80,9 +80,7 @@ def test_uc_volume_life_cycle(self, catalog, schema): # GET after REMOVE should fail - with pytest.raises( - Error, match="too many 404 error responses" - ): + with pytest.raises(Error, match="too many 404 error responses"): cursor = conn.cursor() query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" cursor.execute(query) diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py index d2ac4227d..615a7245e 100644 --- a/tests/e2e/test_concurrent_telemetry.py +++ b/tests/e2e/test_concurrent_telemetry.py @@ -122,9 +122,13 @@ def execute_query_worker(thread_id): response = future.result() # Check status using urllib3 method (response.status instead of response.raise_for_status()) if response.status >= 400: - raise Exception(f"HTTP {response.status}: {getattr(response, 'reason', 'Unknown')}") + raise Exception( + f"HTTP {response.status}: {getattr(response, 'reason', 'Unknown')}" + ) # Parse JSON using urllib3 method (response.data.decode() instead of response.json()) - response_data = json.loads(response.data.decode()) if response.data else {} + response_data = ( + json.loads(response.data.decode()) if response.data else {} + ) captured_responses.append(response_data) except Exception as e: captured_exceptions.append(e) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 53b7383e6..c8ae8a0bc 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -64,9 +64,9 @@ for name in test_loader.getTestCaseNames(DecimalTestsMixin): if name.startswith("test_"): fn = getattr(DecimalTestsMixin, name) - decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")( - fn - ) + decorated = skipUnless( + pysql_supports_arrow(), "Decimal tests need arrow support" + )(fn) setattr(DecimalTestsMixin, name, decorated) diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index a5ad7562e..e20c58d3d 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -145,7 +145,9 @@ def test_get_python_sql_connector_auth_provider_access_token(self): hostname = "moderakh-test.cloud.databricks.com" kwargs = {"access_token": "dpi123"} mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) + auth_provider = get_python_sql_connector_auth_provider( + hostname, mock_http_client, **kwargs + ) self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider") headers = {} @@ -163,7 +165,9 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: hostname = "moderakh-test.cloud.databricks.com" kwargs = {"credentials_provider": MyProvider()} mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) + auth_provider = get_python_sql_connector_auth_provider( + hostname, mock_http_client, **kwargs + ) self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider") headers = {} @@ -179,7 +183,9 @@ def test_get_python_sql_connector_auth_provider_noop(self): "_use_cert_as_auth": use_cert_as_auth, } mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) + auth_provider = get_python_sql_connector_auth_provider( + hostname, mock_http_client, **kwargs + ) self.assertTrue(type(auth_provider).__name__, "CredentialProvider") def test_get_python_sql_connector_basic_auth(self): @@ -189,7 +195,9 @@ def 
test_get_python_sql_connector_basic_auth(self): } mock_http_client = MagicMock() with self.assertRaises(ValueError) as e: - get_python_sql_connector_auth_provider("foo.cloud.databricks.com", mock_http_client, **kwargs) + get_python_sql_connector_auth_provider( + "foo.cloud.databricks.com", mock_http_client, **kwargs + ) self.assertIn( "Username/password authentication is no longer supported", str(e.exception) ) @@ -198,7 +206,9 @@ def test_get_python_sql_connector_basic_auth(self): def test_get_python_sql_connector_default_auth(self, mock__initial_get_token): hostname = "foo.cloud.databricks.com" mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client) + auth_provider = get_python_sql_connector_auth_provider( + hostname, mock_http_client + ) self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider") self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID) @@ -259,16 +269,16 @@ def test_no_token_refresh__when_token_is_not_expired( def test_get_token_success(self, token_source, http_response): mock_http_client = MagicMock() - + with patch.object(token_source, "_http_client", mock_http_client): # Create a mock response with the expected format mock_response = MagicMock() mock_response.status = 200 mock_response.data.decode.return_value = '{"access_token": "abc123", "token_type": "Bearer", "refresh_token": null}' - + # Mock the request method to return the response directly mock_http_client.request.return_value = mock_response - + token = token_source.get_token() # Assert @@ -279,16 +289,16 @@ def test_get_token_success(self, token_source, http_response): def test_get_token_failure(self, token_source, http_response): mock_http_client = MagicMock() - + with patch.object(token_source, "_http_client", mock_http_client): # Create a mock response with error mock_response = MagicMock() mock_response.status = 400 mock_response.data.decode.return_value = "Bad Request" - + # Mock the request method to return the response directly mock_http_client.request.return_value = mock_response - + with pytest.raises(Exception) as e: token_source.get_token() assert "Failed to get token: 400" in str(e.value) diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index 0c3fc7103..aeaf5bce6 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -13,22 +13,24 @@ @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") class CloudFetchQueueSuite(unittest.TestCase): - def create_queue(self, schema_bytes=None, result_links=None, description=None, **kwargs): + def create_queue( + self, schema_bytes=None, result_links=None, description=None, **kwargs + ): """Helper method to create ThriftCloudFetchQueue with sensible defaults""" # Set up defaults for commonly used parameters defaults = { - 'max_download_threads': 10, - 'ssl_options': SSLOptions(), - 'session_id_hex': Mock(), - 'statement_id': Mock(), - 'chunk_id': 0, - 'start_row_offset': 0, - 'lz4_compressed': True, + "max_download_threads": 10, + "ssl_options": SSLOptions(), + "session_id_hex": Mock(), + "statement_id": Mock(), + "chunk_id": 0, + "start_row_offset": 0, + "lz4_compressed": True, } - + # Override defaults with any provided kwargs defaults.update(kwargs) - + mock_http_client = MagicMock() return utils.ThriftCloudFetchQueue( schema_bytes=schema_bytes or MagicMock(), @@ -198,7 +200,12 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): def 
test_next_n_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() # Create description that matches the 4-column schema - description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")] + description = [ + ("col0", "uint32"), + ("col1", "uint32"), + ("col2", "uint32"), + ("col3", "uint32"), + ] queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table is None @@ -277,7 +284,12 @@ def test_remaining_rows_multiple_tables_fully_returned( def test_remaining_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() # Create description that matches the 4-column schema - description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")] + description = [ + ("col0", "uint32"), + ("col1", "uint32"), + ("col2", "uint32"), + ("col3", "uint32"), + ] queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table is None diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index 00b1b849a..4d3570dc6 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -131,7 +131,7 @@ def test_run_uncompressed_successful(self, mock_time): self._setup_mock_http_response(mock_http_client, status=200, data=file_bytes) # Patch the log metrics method to avoid division by zero - with patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'): + with patch.object(downloader.ResultSetDownloadHandler, "_log_download_metrics"): d = downloader.ResultSetDownloadHandler( settings, result_link, @@ -160,11 +160,16 @@ def test_run_compressed_successful(self, mock_time): result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789" # Setup mock HTTP response using helper method - self._setup_mock_http_response(mock_http_client, status=200, data=compressed_bytes) + self._setup_mock_http_response( + mock_http_client, status=200, data=compressed_bytes + ) # Mock the decompression method and log metrics to avoid issues - with patch.object(downloader.ResultSetDownloadHandler, '_decompress_data', return_value=file_bytes), \ - patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'): + with patch.object( + downloader.ResultSetDownloadHandler, + "_decompress_data", + return_value=file_bytes, + ), patch.object(downloader.ResultSetDownloadHandler, "_log_download_metrics"): d = downloader.ResultSetDownloadHandler( settings, result_link, diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py index 6471cb4fd..00e6d4939 100644 --- a/tests/unit/test_sea_queue.py +++ b/tests/unit/test_sea_queue.py @@ -520,7 +520,7 @@ def test_hybrid_disposition_with_external_links( # Create result data with external links but no attachment result_data = ResultData(external_links=external_links, attachment=None) - # Build queue + # Build queue mock_http_client = MagicMock() queue = SeaResultSetQueueFactory.build_queue( result_data=result_data, diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 738c617bd..b8430b9fc 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -27,7 +27,9 @@ def mock_telemetry_client(): client_context = MagicMock() # Patch the _setup_pool_manager method to avoid SSL file loading - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager" + ): return 
TelemetryClient( telemetry_enabled=True, session_id_hex=session_id, @@ -85,7 +87,7 @@ def test_network_request_flow(self, mock_http_request, mock_telemetry_client): mock_response.status = 200 mock_response.status_code = 200 mock_http_request.return_value = mock_response - + client = mock_telemetry_client # Create mock events @@ -221,7 +223,9 @@ def test_client_lifecycle_flow(self): client_context = MagicMock() # Initialize enabled client - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager" + ): TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, @@ -289,7 +293,9 @@ def test_factory_shutdown_flow(self): client_context = MagicMock() # Initialize multiple clients - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager" + ): for session in [session1, session2]: TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, @@ -372,8 +378,10 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -400,8 +408,10 @@ def test_telemetry_disabled_when_flag_is_false( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -428,8 +438,10 @@ def test_telemetry_disabled_when_flag_request_fails( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index d4d501c64..a71bce597 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -618,7 +618,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -662,7 +662,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) execute_response, _ = 
thrift_backend._handle_execute_response( @@ -707,7 +707,7 @@ def test_handle_execute_response_checks_operation_state_in_polls( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -859,7 +859,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -912,7 +912,7 @@ def test_handle_execute_response_can_handle_without_direct_results( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) ( execute_response, @@ -951,7 +951,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) thrift_backend._results_message_to_execute_response = Mock() @@ -2115,7 +2115,7 @@ def test_retry_args_bounding(self, mock_http_client): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), **retry_delay_args, ) retry_delay_expected_vals = { @@ -2337,7 +2337,7 @@ def test_execute_command_sets_complex_type_fields_correctly( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), **complex_arg_types, ) thrift_backend.execute_command( From 496d7f7e949b1e4a51c2d4eb2c2318752093eea6 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Wed, 13 Aug 2025 22:00:52 +0530 Subject: [PATCH 21/25] Revert "formatting (black) - fix some closures" This reverts commit 67020f1afb82dd15527a94ce90c97e3b5c6b6e2c. 
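The test updates above repeat one injection pattern: every component now receives its HTTP client as a constructor argument (http_client=MagicMock()), so unit tests can stub urllib3-style responses without opening sockets. A minimal, self-contained sketch of that pattern, assuming only the urllib3 response shape (.status, .data) used in the tests above; the URL below is a placeholder:

import json
from unittest.mock import MagicMock

# Build a urllib3-style response stub: code under test reads
# resp.status and resp.data rather than requests' status_code/json().
mock_response = MagicMock()
mock_response.status = 200
mock_response.data = b'{"access_token": "abc123", "token_type": "Bearer"}'

mock_http_client = MagicMock()
mock_http_client.request.return_value = mock_response

# Anything accepting http_client= can now run offline; parsing follows
# the urllib3 convention noted in the tests above.
resp = mock_http_client.request("GET", "https://example.invalid/oidc/token")
payload = json.loads(resp.data.decode()) if resp.data else {}
assert resp.status == 200 and payload["access_token"] == "abc123"
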
--- src/databricks/sql/backend/thrift_backend.py | 4 +-- src/databricks/sql/client.py | 2 -- .../sql/telemetry/telemetry_client.py | 2 +- tests/e2e/common/staging_ingestion_tests.py | 4 ++- tests/e2e/common/uc_volume_tests.py | 4 ++- tests/e2e/test_concurrent_telemetry.py | 8 ++--- tests/e2e/test_driver.py | 6 ++-- tests/unit/test_auth.py | 32 ++++++----------- tests/unit/test_cloud_fetch_queue.py | 36 +++++++------------ tests/unit/test_downloader.py | 13 +++---- tests/unit/test_sea_queue.py | 2 +- tests/unit/test_telemetry.py | 32 ++++++----------- tests/unit/test_thrift_backend.py | 16 ++++----- 13 files changed, 60 insertions(+), 101 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index b089eacd5..1654a1d5a 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -232,7 +232,7 @@ def __init__( try: self._transport.open() except: - self._transport.close() + self._transport.release_connection() raise self._request_lock = threading.RLock() @@ -607,7 +607,7 @@ def open_session(self, session_configuration, catalog, schema) -> SessionId: self._session_id_hex = session_id.hex_guid return session_id except: - self._transport.close() + self._transport.release_connection() raise def close_session(self, session_id: SessionId) -> None: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 0d4f71ae3..3cd7bcacf 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -284,8 +284,6 @@ def read(self) -> Optional[OAuthToken]: if hasattr(self, "session") else None, ) - if self.http_client: - self.http_client.close() raise e self.use_inline_params = self._set_use_inline_params_with_warning( diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index fb5c3a116..71fcc40c6 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -359,7 +359,6 @@ def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) self._flush() - self._http_client.close() class TelemetryClientFactory: @@ -461,6 +460,7 @@ def initialize_telemetry_client( ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: + with TelemetryClientFactory._lock: TelemetryClientFactory._initialize() diff --git a/tests/e2e/common/staging_ingestion_tests.py b/tests/e2e/common/staging_ingestion_tests.py index 5b5f4e693..73aa0a113 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/tests/e2e/common/staging_ingestion_tests.py @@ -80,7 +80,9 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): # GET after REMOVE should fail - with pytest.raises(Error, match="too many 404 error responses"): + with pytest.raises( + Error, match="too many 404 error responses" + ): cursor = conn.cursor() query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'" cursor.execute(query) diff --git a/tests/e2e/common/uc_volume_tests.py b/tests/e2e/common/uc_volume_tests.py index 0eeb22789..93e63bd28 100644 --- a/tests/e2e/common/uc_volume_tests.py +++ b/tests/e2e/common/uc_volume_tests.py @@ -80,7 +80,9 @@ def test_uc_volume_life_cycle(self, catalog, schema): # GET after REMOVE should fail - with pytest.raises(Error, match="too many 404 error responses"): + with pytest.raises( + Error, match="too many 404 error responses" + ): 
cursor = conn.cursor() query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" cursor.execute(query) diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py index 615a7245e..d2ac4227d 100644 --- a/tests/e2e/test_concurrent_telemetry.py +++ b/tests/e2e/test_concurrent_telemetry.py @@ -122,13 +122,9 @@ def execute_query_worker(thread_id): response = future.result() # Check status using urllib3 method (response.status instead of response.raise_for_status()) if response.status >= 400: - raise Exception( - f"HTTP {response.status}: {getattr(response, 'reason', 'Unknown')}" - ) + raise Exception(f"HTTP {response.status}: {getattr(response, 'reason', 'Unknown')}") # Parse JSON using urllib3 method (response.data.decode() instead of response.json()) - response_data = ( - json.loads(response.data.decode()) if response.data else {} - ) + response_data = json.loads(response.data.decode()) if response.data else {} captured_responses.append(response_data) except Exception as e: captured_exceptions.append(e) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index c8ae8a0bc..53b7383e6 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -64,9 +64,9 @@ for name in test_loader.getTestCaseNames(DecimalTestsMixin): if name.startswith("test_"): fn = getattr(DecimalTestsMixin, name) - decorated = skipUnless( - pysql_supports_arrow(), "Decimal tests need arrow support" - )(fn) + decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")( + fn + ) setattr(DecimalTestsMixin, name, decorated) diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index e20c58d3d..a5ad7562e 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -145,9 +145,7 @@ def test_get_python_sql_connector_auth_provider_access_token(self): hostname = "moderakh-test.cloud.databricks.com" kwargs = {"access_token": "dpi123"} mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider( - hostname, mock_http_client, **kwargs - ) + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider") headers = {} @@ -165,9 +163,7 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: hostname = "moderakh-test.cloud.databricks.com" kwargs = {"credentials_provider": MyProvider()} mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider( - hostname, mock_http_client, **kwargs - ) + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider") headers = {} @@ -183,9 +179,7 @@ def test_get_python_sql_connector_auth_provider_noop(self): "_use_cert_as_auth": use_cert_as_auth, } mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider( - hostname, mock_http_client, **kwargs - ) + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) self.assertTrue(type(auth_provider).__name__, "CredentialProvider") def test_get_python_sql_connector_basic_auth(self): @@ -195,9 +189,7 @@ def test_get_python_sql_connector_basic_auth(self): } mock_http_client = MagicMock() with self.assertRaises(ValueError) as e: - get_python_sql_connector_auth_provider( - "foo.cloud.databricks.com", mock_http_client, **kwargs - ) + get_python_sql_connector_auth_provider("foo.cloud.databricks.com", mock_http_client, **kwargs) 
self.assertIn( "Username/password authentication is no longer supported", str(e.exception) ) @@ -206,9 +198,7 @@ def test_get_python_sql_connector_basic_auth(self): def test_get_python_sql_connector_default_auth(self, mock__initial_get_token): hostname = "foo.cloud.databricks.com" mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider( - hostname, mock_http_client - ) + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client) self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider") self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID) @@ -269,16 +259,16 @@ def test_no_token_refresh__when_token_is_not_expired( def test_get_token_success(self, token_source, http_response): mock_http_client = MagicMock() - + with patch.object(token_source, "_http_client", mock_http_client): # Create a mock response with the expected format mock_response = MagicMock() mock_response.status = 200 mock_response.data.decode.return_value = '{"access_token": "abc123", "token_type": "Bearer", "refresh_token": null}' - + # Mock the request method to return the response directly mock_http_client.request.return_value = mock_response - + token = token_source.get_token() # Assert @@ -289,16 +279,16 @@ def test_get_token_success(self, token_source, http_response): def test_get_token_failure(self, token_source, http_response): mock_http_client = MagicMock() - + with patch.object(token_source, "_http_client", mock_http_client): # Create a mock response with error mock_response = MagicMock() mock_response.status = 400 mock_response.data.decode.return_value = "Bad Request" - + # Mock the request method to return the response directly mock_http_client.request.return_value = mock_response - + with pytest.raises(Exception) as e: token_source.get_token() assert "Failed to get token: 400" in str(e.value) diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index aeaf5bce6..0c3fc7103 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -13,24 +13,22 @@ @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") class CloudFetchQueueSuite(unittest.TestCase): - def create_queue( - self, schema_bytes=None, result_links=None, description=None, **kwargs - ): + def create_queue(self, schema_bytes=None, result_links=None, description=None, **kwargs): """Helper method to create ThriftCloudFetchQueue with sensible defaults""" # Set up defaults for commonly used parameters defaults = { - "max_download_threads": 10, - "ssl_options": SSLOptions(), - "session_id_hex": Mock(), - "statement_id": Mock(), - "chunk_id": 0, - "start_row_offset": 0, - "lz4_compressed": True, + 'max_download_threads': 10, + 'ssl_options': SSLOptions(), + 'session_id_hex': Mock(), + 'statement_id': Mock(), + 'chunk_id': 0, + 'start_row_offset': 0, + 'lz4_compressed': True, } - + # Override defaults with any provided kwargs defaults.update(kwargs) - + mock_http_client = MagicMock() return utils.ThriftCloudFetchQueue( schema_bytes=schema_bytes or MagicMock(), @@ -200,12 +198,7 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): def test_next_n_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() # Create description that matches the 4-column schema - description = [ - ("col0", "uint32"), - ("col1", "uint32"), - ("col2", "uint32"), - ("col3", "uint32"), - ] + description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), 
("col3", "uint32")] queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table is None @@ -284,12 +277,7 @@ def test_remaining_rows_multiple_tables_fully_returned( def test_remaining_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() # Create description that matches the 4-column schema - description = [ - ("col0", "uint32"), - ("col1", "uint32"), - ("col2", "uint32"), - ("col3", "uint32"), - ] + description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")] queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table is None diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index 4d3570dc6..00b1b849a 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -131,7 +131,7 @@ def test_run_uncompressed_successful(self, mock_time): self._setup_mock_http_response(mock_http_client, status=200, data=file_bytes) # Patch the log metrics method to avoid division by zero - with patch.object(downloader.ResultSetDownloadHandler, "_log_download_metrics"): + with patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'): d = downloader.ResultSetDownloadHandler( settings, result_link, @@ -160,16 +160,11 @@ def test_run_compressed_successful(self, mock_time): result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789" # Setup mock HTTP response using helper method - self._setup_mock_http_response( - mock_http_client, status=200, data=compressed_bytes - ) + self._setup_mock_http_response(mock_http_client, status=200, data=compressed_bytes) # Mock the decompression method and log metrics to avoid issues - with patch.object( - downloader.ResultSetDownloadHandler, - "_decompress_data", - return_value=file_bytes, - ), patch.object(downloader.ResultSetDownloadHandler, "_log_download_metrics"): + with patch.object(downloader.ResultSetDownloadHandler, '_decompress_data', return_value=file_bytes), \ + patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'): d = downloader.ResultSetDownloadHandler( settings, result_link, diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py index 00e6d4939..6471cb4fd 100644 --- a/tests/unit/test_sea_queue.py +++ b/tests/unit/test_sea_queue.py @@ -520,7 +520,7 @@ def test_hybrid_disposition_with_external_links( # Create result data with external links but no attachment result_data = ResultData(external_links=external_links, attachment=None) - # Build queue + # Build queue mock_http_client = MagicMock() queue = SeaResultSetQueueFactory.build_queue( result_data=result_data, diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index b8430b9fc..738c617bd 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -27,9 +27,7 @@ def mock_telemetry_client(): client_context = MagicMock() # Patch the _setup_pool_manager method to avoid SSL file loading - with patch( - "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager" - ): + with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager'): return TelemetryClient( telemetry_enabled=True, session_id_hex=session_id, @@ -87,7 +85,7 @@ def test_network_request_flow(self, mock_http_request, mock_telemetry_client): mock_response.status = 200 mock_response.status_code = 200 mock_http_request.return_value = mock_response - + client = mock_telemetry_client # Create mock events @@ -223,9 
+221,7 @@ def test_client_lifecycle_flow(self): client_context = MagicMock() # Initialize enabled client - with patch( - "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager" - ): + with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager'): TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, @@ -293,9 +289,7 @@ def test_factory_shutdown_flow(self): client_context = MagicMock() # Initialize multiple clients - with patch( - "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager" - ): + with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager'): for session in [session1, session2]: TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, @@ -378,10 +372,8 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = ( - False # Connection starts closed for test cleanup - ) - + mock_session_instance.is_open = False # Connection starts closed for test cleanup + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -408,10 +400,8 @@ def test_telemetry_disabled_when_flag_is_false( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = ( - False # Connection starts closed for test cleanup - ) - + mock_session_instance.is_open = False # Connection starts closed for test cleanup + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -438,10 +428,8 @@ def test_telemetry_disabled_when_flag_request_fails( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = ( - False # Connection starts closed for test cleanup - ) - + mock_session_instance.is_open = False # Connection starts closed for test cleanup + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index a71bce597..d4d501c64 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -618,7 +618,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -662,7 +662,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) execute_response, _ = thrift_backend._handle_execute_response( @@ -707,7 +707,7 @@ def test_handle_execute_response_checks_operation_state_in_polls( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -859,7 +859,7 @@ def 
test_handle_execute_response_checks_direct_results_for_error_statuses(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -912,7 +912,7 @@ def test_handle_execute_response_can_handle_without_direct_results( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) ( execute_response, @@ -951,7 +951,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) thrift_backend._results_message_to_execute_response = Mock() @@ -2115,7 +2115,7 @@ def test_retry_args_bounding(self, mock_http_client): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), **retry_delay_args, ) retry_delay_expected_vals = { @@ -2337,7 +2337,7 @@ def test_execute_command_sets_complex_type_fields_correctly( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), **complex_arg_types, ) thrift_backend.execute_command( From 84ec33a01ca33d0580e65aa686fe51921d094dd0 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Wed, 13 Aug 2025 22:02:19 +0530 Subject: [PATCH 22/25] add more http_client closures --- src/databricks/sql/backend/thrift_backend.py | 4 ++-- src/databricks/sql/client.py | 6 ++++-- src/databricks/sql/telemetry/telemetry_client.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 1654a1d5a..b089eacd5 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -232,7 +232,7 @@ def __init__( try: self._transport.open() except: - self._transport.release_connection() + self._transport.close() raise self._request_lock = threading.RLock() @@ -607,7 +607,7 @@ def open_session(self, session_configuration, catalog, schema) -> SessionId: self._session_id_hex = session_id.hex_guid return session_id except: - self._transport.release_connection() + self._transport.close() raise def close_session(self, session_id: SessionId) -> None: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 3cd7bcacf..8150b9663 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -284,6 +284,7 @@ def read(self) -> Optional[OAuthToken]: if hasattr(self, "session") else None, ) + self.http_client.close() raise e self.use_inline_params = self._set_use_inline_params_with_warning( @@ -362,8 +363,9 @@ def __exit__(self, exc_type, exc_value, traceback): def __del__(self): if self.open: logger.debug( - "Closing unclosed connection for session " - "{}".format(self.get_session_id_hex()) + "Closing unclosed connection for session " "{}".format( + self.get_session_id_hex() + ) ) try: self._close(close_cursors=False) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 71fcc40c6..fb5c3a116 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -359,6 +359,7 @@ def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) self._flush() + self._http_client.close() class TelemetryClientFactory: @@ -460,7 +461,6 @@ def 
initialize_telemetry_client( ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: - with TelemetryClientFactory._lock: TelemetryClientFactory._initialize() From 76ce5ce3fe083b25297e142a22b65d759fab1556 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Wed, 13 Aug 2025 22:13:20 +0530 Subject: [PATCH 23/25] remove excess close call --- src/databricks/sql/client.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 8150b9663..3cd7bcacf 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -284,7 +284,6 @@ def read(self) -> Optional[OAuthToken]: if hasattr(self, "session") else None, ) - self.http_client.close() raise e self.use_inline_params = self._set_use_inline_params_with_warning( @@ -363,9 +362,8 @@ def __exit__(self, exc_type, exc_value, traceback): def __del__(self): if self.open: logger.debug( - "Closing unclosed connection for session " "{}".format( - self.get_session_id_hex() - ) + "Closing unclosed connection for session " + "{}".format(self.get_session_id_hex()) ) try: self._close(close_cursors=False) From 4452725590dbd56d2b17dde1bd1c1e3f15c1ba3a Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Thu, 14 Aug 2025 10:42:33 +0530 Subject: [PATCH 24/25] wait for _flush before closing HTTP client --- .../sql/telemetry/telemetry_client.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index fb5c3a116..2a13d8747 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -4,6 +4,7 @@ import json from concurrent.futures import ThreadPoolExecutor from concurrent.futures import Future +from concurrent.futures import wait from datetime import datetime, timezone from typing import List, Dict, Any, Optional, TYPE_CHECKING from databricks.sql.telemetry.models.event import ( @@ -182,6 +183,7 @@ def __init__( self._user_agent = None self._events_batch = [] self._lock = threading.RLock() + self._pending_futures = set() self._driver_connection_params = None self._host_url = host_url self._executor = executor @@ -245,6 +247,9 @@ def _send_telemetry(self, events): timeout=900, ) + with self._lock: + self._pending_futures.add(future) + future.add_done_callback( lambda fut: self._telemetry_request_callback(fut, sent_count=sent_count) ) @@ -303,6 +308,9 @@ def _telemetry_request_callback(self, future, sent_count: int): except Exception as e: logger.debug("Telemetry request failed with exception: %s", e) + finally: + with self._lock: + self._pending_futures.discard(future) def _export_telemetry_log(self, **telemetry_event_kwargs): """ @@ -356,9 +364,20 @@ def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): ) def close(self): - """Flush remaining events before closing""" + """Flush remaining events and wait for them to complete before closing""" logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) self._flush() + + with self._lock: + futures_to_wait_on = list(self._pending_futures) + + if futures_to_wait_on: + logger.debug( + "Waiting for %s pending telemetry requests to complete.", + len(futures_to_wait_on), + ) + wait(futures_to_wait_on) + self._http_client.close() From d90ac80a693e0c52ddb40551093225c4d0af5c60 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Thu, 14 Aug 2025 10:53:19 +0530 Subject: [PATCH 
25/25] make close() async --- src/databricks/sql/telemetry/telemetry_client.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 2a13d8747..7245b64d5 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -364,8 +364,15 @@ def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): ) def close(self): - """Flush remaining events and wait for them to complete before closing""" - logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) + """Schedules the client to be closed in the background.""" + logger.debug( + "Scheduling background closure for TelemetryClient of connection %s", + self._session_id_hex, + ) + self._executor.submit(self._close_and_wait) + + def _close_and_wait(self): + """Flush remaining events and wait for them to complete before closing.""" self._flush() with self._lock:
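
Taken together, the last two patches implement a track-and-drain shutdown for the telemetry client: each POST's Future is recorded in a set under a lock, a done-callback discards it, and close() flushes, waits on the outstanding futures, and only then releases the HTTP client, with the final patch moving that wait onto the background executor so Connection.close() never blocks. A minimal, self-contained sketch of the pattern, with illustrative names (only wait(), add_done_callback(), and the lock/set bookkeeping mirror the diffs above; _GracefulSender and its methods are hypothetical):

import threading
from concurrent.futures import ThreadPoolExecutor, wait


class _GracefulSender:
    def __init__(self):
        self._executor = ThreadPoolExecutor(max_workers=2)
        self._lock = threading.RLock()
        self._pending = set()

    def send(self, payload):
        # Stand-in for the telemetry POST submitted to the executor.
        future = self._executor.submit(lambda: payload)
        with self._lock:
            self._pending.add(future)
        # Discard the future once it settles, as in the patched
        # _telemetry_request_callback's finally block.
        future.add_done_callback(self._discard)
        return future

    def _discard(self, future):
        with self._lock:
            self._pending.discard(future)

    def close(self):
        # Patch 25 submits this drain to the executor so the caller
        # returns immediately; shown synchronously here for brevity.
        with self._lock:
            outstanding = list(self._pending)
        if outstanding:
            wait(outstanding)
        self._executor.shutdown(wait=True)  # where the HTTP client would be closed


sender = _GracefulSender()
sender.send({"event": "example"})
sender.close()  # returns only after every pending send has settled

Draining before releasing the client avoids the race the patch subject describes: without the wait(), the unified HTTP client's connection pool could be torn down while a telemetry request is still in flight.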