Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Added
- New feature: Support for macOS and Linux.
- Documentation: Added API documentation in the Wiki.
- Bulk copy: `Authentication=ActiveDirectoryMSI` support (system- and user-assigned managed identity). UID is interpreted as the user-assigned identity's `client_id`. Partial fix for #534.

### Changed
- Improved error handling in the connection module.
Expand Down
91 changes: 79 additions & 12 deletions mssql_python/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,24 @@
# Reusing credential objects allows the Azure Identity SDK's built-in
# in-memory token cache to work, avoiding redundant token acquisitions.
# See: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/TOKEN_CACHING.md
_credential_cache: Dict[str, object] = {}
_credential_cache: Dict[object, object] = {}
_credential_cache_lock = threading.Lock()


def _credential_cache_key(auth_type: str, credential_kwargs: Optional[Dict[str, str]]):
"""Build a hashable cache key from auth_type and optional credential kwargs.

Returns the plain auth_type string when no kwargs are provided so that
callers caching by string (the original behavior) keep working. When
kwargs are present (e.g. user-assigned MSI client_id), the key is a
tuple of ``(auth_type, sorted_kwargs_items)`` so different kwargs map
to different cached credentials.
"""
if not credential_kwargs:
return auth_type
return (auth_type, tuple(sorted(credential_kwargs.items())))


class AADAuth:
"""Handles Azure Active Directory authentication"""

Expand All @@ -37,24 +51,30 @@ def get_token_struct(token: str) -> bytes:
return struct.pack(f"<I{len(token_bytes)}s", len(token_bytes), token_bytes)

@staticmethod
def get_token(auth_type: str) -> bytes:
def get_token(
auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None
) -> bytes:
"""Get DDBC token struct for the specified authentication type."""
token_struct, _ = AADAuth._acquire_token(auth_type)
token_struct, _ = AADAuth._acquire_token(auth_type, credential_kwargs)
return token_struct

@staticmethod
def get_raw_token(auth_type: str) -> str:
def get_raw_token(
auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None
) -> str:
"""Acquire a raw JWT for the mssql-py-core connection (bulk copy).

Uses the cached credential instance so the Azure Identity SDK's
built-in token cache can serve a valid token without a round-trip
when the previous token has not yet expired.
"""
_, raw_token = AADAuth._acquire_token(auth_type)
_, raw_token = AADAuth._acquire_token(auth_type, credential_kwargs)
return raw_token

@staticmethod
def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
def _acquire_token(
auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None
) -> Tuple[bytes, str]:
"""Internal: acquire token and return (ddbc_struct, raw_jwt)."""
# Import Azure libraries inside method to support test mocking
# pylint: disable=import-outside-toplevel
Expand All @@ -63,6 +83,7 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
DefaultAzureCredential,
DeviceCodeCredential,
InteractiveBrowserCredential,
ManagedIdentityCredential,
)
from azure.core.exceptions import ClientAuthenticationError
except ImportError as e:
Expand All @@ -76,6 +97,7 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
"default": DefaultAzureCredential,
"devicecode": DeviceCodeCredential,
"interactive": InteractiveBrowserCredential,
"msi": ManagedIdentityCredential,
}

credential_class = credential_map.get(auth_type)
Expand All @@ -89,20 +111,22 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
credential_class.__name__,
)

kwargs = credential_kwargs or {}
cache_key = _credential_cache_key(auth_type, kwargs)
try:
with _credential_cache_lock:
if auth_type not in _credential_cache:
if cache_key not in _credential_cache:
Comment thread
bewithgaurav marked this conversation as resolved.
logger.debug(
"get_token: Creating new credential instance for auth_type=%s",
auth_type,
)
_credential_cache[auth_type] = credential_class()
_credential_cache[cache_key] = credential_class(**kwargs)
else:
logger.debug(
"get_token: Reusing cached credential instance for auth_type=%s",
auth_type,
)
credential = _credential_cache[auth_type]
credential = _credential_cache[cache_key]
raw_token = credential.get_token("https://database.windows.net/.default").token
logger.info(
"get_token: Azure AD token acquired successfully - token_length=%d chars",
Expand Down Expand Up @@ -130,6 +154,20 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e


def _extract_msi_client_id(parameters: List[str]) -> Optional[str]:
"""Pull UID out of connection parameters for user-assigned MSI.

For ActiveDirectoryMSI, UID (when present) carries the user-assigned
identity's client_id. Returns None for system-assigned MSI.
"""
for param in parameters:
key, _, value = param.strip().partition("=")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to get this from the conn str parser map? Assuming we need the conn str UID

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure what will happen if someone provides a UID={hello=world}

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks - missed this, parser should be used as a standard across
that handles {hello=world} correctly

if key.strip().lower() == "uid":
value = value.strip()
return value or None
return None


def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[str]]:
"""
Process connection parameters and extract authentication type.
Expand Down Expand Up @@ -180,6 +218,10 @@ def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[
# Default authentication (uses DefaultAzureCredential)
logger.debug("process_auth_parameters: Default Azure authentication detected")
auth_type = "default"
elif value_lower == AuthType.MSI.value:
# Managed identity authentication (system- or user-assigned)
logger.debug("process_auth_parameters: Managed identity authentication detected")
auth_type = "msi"
modified_parameters.append(param)

logger.debug(
Expand Down Expand Up @@ -212,7 +254,9 @@ def remove_sensitive_params(parameters: List[str]) -> List[str]:
return result


def get_auth_token(auth_type: str) -> Optional[bytes]:
def get_auth_token(
auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None
) -> Optional[bytes]:
"""Get DDBC authentication token struct based on auth type."""
logger.debug("get_auth_token: Starting - auth_type=%s", auth_type)
if not auth_type:
Expand All @@ -225,7 +269,7 @@ def get_auth_token(auth_type: str) -> Optional[bytes]:
return None # Let Windows handle AADInteractive natively

try:
token = AADAuth.get_token(auth_type)
token = AADAuth.get_token(auth_type, credential_kwargs)
logger.info("get_auth_token: Token acquired successfully - auth_type=%s", auth_type)
return token
except (ValueError, RuntimeError) as e:
Expand All @@ -246,6 +290,7 @@ def extract_auth_type(connection_string: str) -> Optional[str]:
AuthType.INTERACTIVE.value: "interactive",
AuthType.DEVICE_CODE.value: "devicecode",
AuthType.DEFAULT.value: "default",
AuthType.MSI.value: "msi",
}
for part in connection_string.split(";"):
key, _, value = part.strip().partition("=")
Expand All @@ -254,6 +299,20 @@ def extract_auth_type(connection_string: str) -> Optional[str]:
return None


def extract_credential_kwargs(
connection_string: str, auth_type: Optional[str]
) -> Dict[str, str]:
"""Extract credential constructor kwargs for the given auth type.

For ActiveDirectoryMSI: returns ``{"client_id": uid}`` when UID is
set (user-assigned MSI) and ``{}`` for system-assigned MSI.
"""
if auth_type != "msi":
return {}
client_id = _extract_msi_client_id(connection_string.split(";"))
return {"client_id": client_id} if client_id else {}
Comment thread
bewithgaurav marked this conversation as resolved.
Outdated


def process_connection_string(
connection_string: str,
) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str]]:
Expand Down Expand Up @@ -301,12 +360,20 @@ def process_connection_string(

modified_parameters, auth_type = process_auth_parameters(parameters)

# Capture credential kwargs (e.g. user-assigned MSI client_id) before
# remove_sensitive_params strips UID from the parameter list.
credential_kwargs: Dict[str, str] = {}
if auth_type == "msi":
client_id = _extract_msi_client_id(modified_parameters)
if client_id:
credential_kwargs["client_id"] = client_id

if auth_type:
logger.info(
"process_connection_string: Authentication type detected - auth_type=%s", auth_type
)
modified_parameters = remove_sensitive_params(modified_parameters)
token_struct = get_auth_token(auth_type)
token_struct = get_auth_token(auth_type, credential_kwargs or None)
if token_struct:
logger.info(
"process_connection_string: Token authentication configured successfully - auth_type=%s",
Expand Down
1 change: 1 addition & 0 deletions mssql_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ class AuthType(Enum):
INTERACTIVE = "activedirectoryinteractive"
DEVICE_CODE = "activedirectorydevicecode"
DEFAULT = "activedirectorydefault"
MSI = "activedirectorymsi"


class SQLTypes:
Expand Down
9 changes: 7 additions & 2 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2912,10 +2912,15 @@ def bulkcopy(
# Token acquisition — only thing cursor must handle (needs azure-identity SDK)
if self.connection._auth_type:
# Fresh token acquisition for mssql-py-core connection
from mssql_python.auth import AADAuth
from mssql_python.auth import AADAuth, extract_credential_kwargs

credential_kwargs = extract_credential_kwargs(
self.connection.connection_str, self.connection._auth_type
)
try:
raw_token = AADAuth.get_raw_token(self.connection._auth_type)
raw_token = AADAuth.get_raw_token(
self.connection._auth_type, credential_kwargs or None
)
Comment thread
bewithgaurav marked this conversation as resolved.
Outdated
except (RuntimeError, ValueError) as e:
raise RuntimeError(
f"Bulk copy failed: unable to acquire Azure AD token "
Expand Down
Loading
Loading