Skip to content
Merged
10 changes: 7 additions & 3 deletions src/databricks/sql/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from databricks.sql.auth.common import AuthType, ClientContext


def get_auth_provider(cfg: ClientContext):
def get_auth_provider(cfg: ClientContext, http_client):
if cfg.credentials_provider:
return ExternalAuthProvider(cfg.credentials_provider)
elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
Expand All @@ -19,6 +19,7 @@ def get_auth_provider(cfg: ClientContext):
cfg.hostname,
cfg.azure_client_id,
cfg.azure_client_secret,
http_client,
cfg.azure_tenant_id,
cfg.azure_workspace_resource_id,
)
Expand All @@ -34,6 +35,7 @@ def get_auth_provider(cfg: ClientContext):
cfg.oauth_redirect_port_range,
cfg.oauth_client_id,
cfg.oauth_scopes,
http_client,
cfg.auth_type,
)
elif cfg.access_token is not None:
Expand All @@ -53,6 +55,8 @@ def get_auth_provider(cfg: ClientContext):
cfg.oauth_redirect_port_range,
cfg.oauth_client_id,
cfg.oauth_scopes,
http_client,
cfg.auth_type or AuthType.DATABRICKS_OAUTH.value,
)
else:
raise RuntimeError("No valid authentication settings!")
Expand All @@ -79,7 +83,7 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):
)


def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs):
# TODO : unify all the auth mechanisms with the Python SDK

auth_type = kwargs.get("auth_type")
Expand Down Expand Up @@ -111,4 +115,4 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
credentials_provider=kwargs.get("credentials_provider"),
)
return get_auth_provider(cfg)
return get_auth_provider(cfg, http_client)
7 changes: 6 additions & 1 deletion src/databricks/sql/auth/authenticators.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
redirect_port_range: List[int],
client_id: str,
scopes: List[str],
http_client,
auth_type: str = "databricks-oauth",
):
try:
Expand All @@ -79,6 +80,7 @@ def __init__(
port_range=redirect_port_range,
client_id=client_id,
idp_endpoint=idp_endpoint,
http_client=http_client,
)
self._hostname = hostname
self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(cloud_scopes)
Expand Down Expand Up @@ -188,6 +190,7 @@ def __init__(
hostname,
azure_client_id,
azure_client_secret,
http_client,
azure_tenant_id=None,
azure_workspace_resource_id=None,
):
Expand All @@ -196,8 +199,9 @@ def __init__(
self.azure_client_secret = azure_client_secret
self.azure_workspace_resource_id = azure_workspace_resource_id
self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host(
hostname
hostname, http_client
)
self._http_client = http_client

def auth_type(self) -> str:
return AuthType.AZURE_SP_M2M.value
Expand All @@ -207,6 +211,7 @@ def get_token_source(self, resource: str) -> RefreshableTokenSource:
token_url=f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}",
client_id=self.azure_client_id,
client_secret=self.azure_client_secret,
http_client=self._http_client,
extra_params={"resource": resource},
)

Expand Down
67 changes: 49 additions & 18 deletions src/databricks/sql/auth/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import logging
from typing import Optional, List
from urllib.parse import urlparse
from databricks.sql.common.http import DatabricksHttpClient, HttpMethod
from databricks.sql.auth.retry import DatabricksRetryPolicy
from databricks.sql.common.http import HttpMethod

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -36,6 +37,21 @@ def __init__(
tls_client_cert_file: Optional[str] = None,
oauth_persistence=None,
credentials_provider=None,
# HTTP client configuration parameters
ssl_options=None, # SSLOptions type
socket_timeout: Optional[float] = None,
retry_stop_after_attempts_count: Optional[int] = None,
retry_delay_min: Optional[float] = None,
retry_delay_max: Optional[float] = None,
retry_stop_after_attempts_duration: Optional[float] = None,
retry_delay_default: Optional[float] = None,
retry_dangerous_codes: Optional[List[int]] = None,
http_proxy: Optional[str] = None,
proxy_username: Optional[str] = None,
proxy_password: Optional[str] = None,
pool_connections: Optional[int] = None,
pool_maxsize: Optional[int] = None,
user_agent: Optional[str] = None,
):
self.hostname = hostname
self.access_token = access_token
Expand All @@ -52,6 +68,24 @@ def __init__(
self.oauth_persistence = oauth_persistence
self.credentials_provider = credentials_provider

# HTTP client configuration
self.ssl_options = ssl_options
self.socket_timeout = socket_timeout
self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 5
self.retry_delay_min = retry_delay_min or 1.0
self.retry_delay_max = retry_delay_max or 10.0
self.retry_stop_after_attempts_duration = (
retry_stop_after_attempts_duration or 300.0
)
self.retry_delay_default = retry_delay_default or 5.0
self.retry_dangerous_codes = retry_dangerous_codes or []
self.http_proxy = http_proxy
self.proxy_username = proxy_username
self.proxy_password = proxy_password
self.pool_connections = pool_connections or 10
self.pool_maxsize = pool_maxsize or 20
self.user_agent = user_agent


def get_effective_azure_login_app_id(hostname) -> str:
"""
Expand All @@ -69,7 +103,7 @@ def get_effective_azure_login_app_id(hostname) -> str:
return AzureAppId.PROD.value[1]


def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
def get_azure_tenant_id_from_host(host: str, http_client) -> str:
"""
Load the Azure tenant ID from the Azure Databricks login page.

Expand All @@ -78,23 +112,20 @@ def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
the Azure login page, and the tenant ID is extracted from the redirect URL.
"""

if http_client is None:
http_client = DatabricksHttpClient.get_instance()

login_url = f"{host}/aad/auth"
logger.debug("Loading tenant ID from %s", login_url)
with http_client.execute(HttpMethod.GET, login_url, allow_redirects=False) as resp:
if resp.status_code // 100 != 3:

with http_client.request_context(HttpMethod.GET, login_url) as resp:
entra_id_endpoint = resp.retries.history[-1].redirect_location
if entra_id_endpoint is None:
raise ValueError(
f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}"
f"No Location header in response from {login_url}: {entra_id_endpoint}"
)
entra_id_endpoint = resp.headers.get("Location")
if entra_id_endpoint is None:
raise ValueError(f"No Location header in response from {login_url}")
# The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
url = urlparse(entra_id_endpoint)
path_segments = url.path.split("/")
if len(path_segments) < 2:
raise ValueError(f"Invalid path in Location header: {url.path}")
return path_segments[1]

# The final redirect URL has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
url = urlparse(entra_id_endpoint)
path_segments = url.path.split("/")
if len(path_segments) < 2:
raise ValueError(f"Invalid path in Location header: {url.path}")
return path_segments[1]
75 changes: 32 additions & 43 deletions src/databricks/sql/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
from typing import List, Optional

import oauthlib.oauth2
import requests
from oauthlib.oauth2.rfc6749.errors import OAuth2Error
from requests.exceptions import RequestException
from databricks.sql.common.http import HttpMethod, DatabricksHttpClient, HttpHeader
from databricks.sql.common.http import HttpMethod, HttpHeader
from databricks.sql.common.http import OAuthResponse
from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler
from databricks.sql.auth.endpoint import OAuthEndpointCollection
Expand Down Expand Up @@ -63,33 +61,19 @@ def refresh(self) -> Token:
pass


class IgnoreNetrcAuth(requests.auth.AuthBase):
"""This auth method is a no-op.

We use it to force requestslib to not use .netrc to write auth headers
when making .post() requests to the oauth token endpoints, since these
don't require authentication.

In cases where .netrc is outdated or corrupt, these requests will fail.

See issue #121
"""

def __call__(self, r):
return r


class OAuthManager:
def __init__(
self,
port_range: List[int],
client_id: str,
idp_endpoint: OAuthEndpointCollection,
http_client,
):
self.port_range = port_range
self.client_id = client_id
self.redirect_port = None
self.idp_endpoint = idp_endpoint
self.http_client = http_client

@staticmethod
def __token_urlsafe(nbytes=32):
Expand All @@ -103,8 +87,11 @@ def __fetch_well_known_config(self, hostname: str):
known_config_url = self.idp_endpoint.get_openid_config_url(hostname)

try:
response = requests.get(url=known_config_url, auth=IgnoreNetrcAuth())
except RequestException as e:
response = self.http_client.request(HttpMethod.GET, url=known_config_url)
# Convert urllib3 response to requests-like response for compatibility
response.status_code = response.status
response.json = lambda: json.loads(response.data.decode())
except Exception as e:
logger.error(
f"Unable to fetch OAuth configuration from {known_config_url}.\n"
"Verify it is a valid workspace URL and that OAuth is "
Expand All @@ -122,7 +109,7 @@ def __fetch_well_known_config(self, hostname: str):
raise RuntimeError(msg)
try:
return response.json()
except requests.exceptions.JSONDecodeError as e:
except Exception as e:
logger.error(
f"Unable to decode OAuth configuration from {known_config_url}.\n"
"Verify it is a valid workspace URL and that OAuth is "
Expand Down Expand Up @@ -203,16 +190,17 @@ def __send_auth_code_token_request(
data = f"{token_request_body}&code_verifier={verifier}"
return self.__send_token_request(token_request_url, data)

@staticmethod
def __send_token_request(token_request_url, data):
def __send_token_request(self, token_request_url, data):
headers = {
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded",
}
response = requests.post(
url=token_request_url, data=data, headers=headers, auth=IgnoreNetrcAuth()
# Use unified HTTP client
response = self.http_client.request(
HttpMethod.POST, url=token_request_url, body=data, headers=headers
)
return response.json()
# Convert urllib3 response to dict for compatibility
return json.loads(response.data.decode())

def __send_refresh_token_request(self, hostname, refresh_token):
oauth_config = self.__fetch_well_known_config(hostname)
Expand All @@ -221,7 +209,7 @@ def __send_refresh_token_request(self, hostname, refresh_token):
token_request_body = client.prepare_refresh_body(
refresh_token=refresh_token, client_id=client.client_id
)
return OAuthManager.__send_token_request(token_request_url, token_request_body)
return self.__send_token_request(token_request_url, token_request_body)

@staticmethod
def __get_tokens_from_response(oauth_response):
Expand Down Expand Up @@ -320,14 +308,15 @@ def __init__(
token_url,
client_id,
client_secret,
http_client,
extra_params: dict = {},
):
self.client_id = client_id
self.client_secret = client_secret
self.token_url = token_url
self.extra_params = extra_params
self.token: Optional[Token] = None
self._http_client = DatabricksHttpClient.get_instance()
self._http_client = http_client

def get_token(self) -> Token:
if self.token is None or self.token.is_expired():
Expand All @@ -348,17 +337,17 @@ def refresh(self) -> Token:
}
)

with self._http_client.execute(
method=HttpMethod.POST, url=self.token_url, headers=headers, data=data
) as response:
if response.status_code == 200:
oauth_response = OAuthResponse(**response.json())
return Token(
oauth_response.access_token,
oauth_response.token_type,
oauth_response.refresh_token,
)
else:
raise Exception(
f"Failed to get token: {response.status_code} {response.text}"
)
response = self._http_client.request(
method=HttpMethod.POST, url=self.token_url, headers=headers, body=data
)
if response.status == 200:
oauth_response = OAuthResponse(**json.loads(response.data.decode("utf-8")))
return Token(
oauth_response.access_token,
oauth_response.token_type,
oauth_response.refresh_token,
)
else:
raise Exception(
f"Failed to get token: {response.status} {response.data.decode('utf-8')}"
)
10 changes: 8 additions & 2 deletions src/databricks/sql/auth/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,14 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
logger.info(f"Received status code {status_code} for {method} request")

# Request succeeded. Don't retry.
if status_code == 200:
return False, "200 codes are not retried"
if status_code // 100 <= 3:
return False, "2xx/3xx codes are not retried"

if status_code == 400:
return (
False,
"Received 400 - BAD_REQUEST. Please check the request parameters.",
)

if status_code == 401:
return (
Expand Down
4 changes: 4 additions & 0 deletions src/databricks/sql/backend/sea/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def build_queue(
max_download_threads: int,
sea_client: SeaDatabricksClient,
lz4_compressed: bool,
http_client,
) -> ResultSetQueue:
"""
Factory method to build a result set queue for SEA backend.
Expand Down Expand Up @@ -94,6 +95,7 @@ def build_queue(
total_chunk_count=manifest.total_chunk_count,
lz4_compressed=lz4_compressed,
description=description,
http_client=http_client,
)
raise ProgrammingError("Invalid result format")

Expand Down Expand Up @@ -309,6 +311,7 @@ def __init__(
sea_client: SeaDatabricksClient,
statement_id: str,
total_chunk_count: int,
http_client,
lz4_compressed: bool = False,
description: List[Tuple] = [],
):
Expand Down Expand Up @@ -337,6 +340,7 @@ def __init__(
# TODO: fix these arguments when telemetry is implemented in SEA
session_id_hex=None,
chunk_id=0,
http_client=http_client,
)

logger.debug(
Expand Down
Loading
Loading