-
Notifications
You must be signed in to change notification settings - Fork 50
FEAT: Add ActiveDirectoryMSI support for bulk copy #573
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
9e0b418
3ca2165
66b8eec
6451a4e
e106d5f
ff1aa41
1b7b434
df24965
9ad28a1
cbb673d
303b66c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,10 +16,24 @@ | |
| # Reusing credential objects allows the Azure Identity SDK's built-in | ||
| # in-memory token cache to work, avoiding redundant token acquisitions. | ||
| # See: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/TOKEN_CACHING.md | ||
| _credential_cache: Dict[str, object] = {} | ||
| _credential_cache: Dict[object, object] = {} | ||
| _credential_cache_lock = threading.Lock() | ||
|
|
||
|
|
||
| def _credential_cache_key(auth_type: str, credential_kwargs: Optional[Dict[str, str]]): | ||
| """Build a hashable cache key from auth_type and optional credential kwargs. | ||
|
|
||
| Returns the plain auth_type string when no kwargs are provided so that | ||
| callers caching by string (the original behavior) keep working. When | ||
| kwargs are present (e.g. user-assigned MSI client_id), the key is a | ||
| tuple of ``(auth_type, sorted_kwargs_items)`` so different kwargs map | ||
| to different cached credentials. | ||
| """ | ||
| if not credential_kwargs: | ||
| return auth_type | ||
| return (auth_type, tuple(sorted(credential_kwargs.items()))) | ||
|
|
||
|
|
||
| class AADAuth: | ||
| """Handles Azure Active Directory authentication""" | ||
|
|
||
|
|
@@ -37,24 +51,30 @@ def get_token_struct(token: str) -> bytes: | |
| return struct.pack(f"<I{len(token_bytes)}s", len(token_bytes), token_bytes) | ||
|
|
||
| @staticmethod | ||
| def get_token(auth_type: str) -> bytes: | ||
| def get_token( | ||
| auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None | ||
| ) -> bytes: | ||
| """Get DDBC token struct for the specified authentication type.""" | ||
| token_struct, _ = AADAuth._acquire_token(auth_type) | ||
| token_struct, _ = AADAuth._acquire_token(auth_type, credential_kwargs) | ||
| return token_struct | ||
|
|
||
| @staticmethod | ||
| def get_raw_token(auth_type: str) -> str: | ||
| def get_raw_token( | ||
| auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None | ||
| ) -> str: | ||
| """Acquire a raw JWT for the mssql-py-core connection (bulk copy). | ||
|
|
||
| Uses the cached credential instance so the Azure Identity SDK's | ||
| built-in token cache can serve a valid token without a round-trip | ||
| when the previous token has not yet expired. | ||
| """ | ||
| _, raw_token = AADAuth._acquire_token(auth_type) | ||
| _, raw_token = AADAuth._acquire_token(auth_type, credential_kwargs) | ||
| return raw_token | ||
|
|
||
| @staticmethod | ||
| def _acquire_token(auth_type: str) -> Tuple[bytes, str]: | ||
| def _acquire_token( | ||
| auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None | ||
| ) -> Tuple[bytes, str]: | ||
| """Internal: acquire token and return (ddbc_struct, raw_jwt).""" | ||
| # Import Azure libraries inside method to support test mocking | ||
| # pylint: disable=import-outside-toplevel | ||
|
|
@@ -63,6 +83,7 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: | |
| DefaultAzureCredential, | ||
| DeviceCodeCredential, | ||
| InteractiveBrowserCredential, | ||
| ManagedIdentityCredential, | ||
| ) | ||
| from azure.core.exceptions import ClientAuthenticationError | ||
| except ImportError as e: | ||
|
|
@@ -76,6 +97,7 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: | |
| "default": DefaultAzureCredential, | ||
| "devicecode": DeviceCodeCredential, | ||
| "interactive": InteractiveBrowserCredential, | ||
| "msi": ManagedIdentityCredential, | ||
| } | ||
|
|
||
| credential_class = credential_map.get(auth_type) | ||
|
|
@@ -89,20 +111,22 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: | |
| credential_class.__name__, | ||
| ) | ||
|
|
||
| kwargs = credential_kwargs or {} | ||
| cache_key = _credential_cache_key(auth_type, kwargs) | ||
| try: | ||
| with _credential_cache_lock: | ||
| if auth_type not in _credential_cache: | ||
| if cache_key not in _credential_cache: | ||
| logger.debug( | ||
| "get_token: Creating new credential instance for auth_type=%s", | ||
| auth_type, | ||
| ) | ||
| _credential_cache[auth_type] = credential_class() | ||
| _credential_cache[cache_key] = credential_class(**kwargs) | ||
| else: | ||
| logger.debug( | ||
| "get_token: Reusing cached credential instance for auth_type=%s", | ||
| auth_type, | ||
| ) | ||
| credential = _credential_cache[auth_type] | ||
| credential = _credential_cache[cache_key] | ||
| raw_token = credential.get_token("https://database.windows.net/.default").token | ||
| logger.info( | ||
| "get_token: Azure AD token acquired successfully - token_length=%d chars", | ||
|
|
@@ -130,6 +154,20 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: | |
| raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e | ||
|
|
||
|
|
||
| def _extract_msi_client_id(parameters: List[str]) -> Optional[str]: | ||
| """Pull UID out of connection parameters for user-assigned MSI. | ||
|
|
||
| For ActiveDirectoryMSI, UID (when present) carries the user-assigned | ||
| identity's client_id. Returns None for system-assigned MSI. | ||
| """ | ||
| for param in parameters: | ||
| key, _, value = param.strip().partition("=") | ||
|
Contributor
Is it possible to get this from the connection string parser map? Assuming we need the connection string UID.
Contributor
I am not sure what will happen if someone provides UID={hello=world}.
Collaborator
Author
Thanks, I missed this. The parser should be used as the standard throughout. |
||
| if key.strip().lower() == "uid": | ||
| value = value.strip() | ||
| return value or None | ||
| return None | ||
|
|
||
|
|
||
| def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[str]]: | ||
| """ | ||
| Process connection parameters and extract authentication type. | ||
|
|
@@ -180,6 +218,10 @@ def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[ | |
| # Default authentication (uses DefaultAzureCredential) | ||
| logger.debug("process_auth_parameters: Default Azure authentication detected") | ||
| auth_type = "default" | ||
| elif value_lower == AuthType.MSI.value: | ||
| # Managed identity authentication (system- or user-assigned) | ||
| logger.debug("process_auth_parameters: Managed identity authentication detected") | ||
| auth_type = "msi" | ||
| modified_parameters.append(param) | ||
|
|
||
| logger.debug( | ||
|
|
@@ -212,7 +254,9 @@ def remove_sensitive_params(parameters: List[str]) -> List[str]: | |
| return result | ||
|
|
||
|
|
||
| def get_auth_token(auth_type: str) -> Optional[bytes]: | ||
| def get_auth_token( | ||
| auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None | ||
| ) -> Optional[bytes]: | ||
| """Get DDBC authentication token struct based on auth type.""" | ||
| logger.debug("get_auth_token: Starting - auth_type=%s", auth_type) | ||
| if not auth_type: | ||
|
|
@@ -225,7 +269,7 @@ def get_auth_token(auth_type: str) -> Optional[bytes]: | |
| return None # Let Windows handle AADInteractive natively | ||
|
|
||
| try: | ||
| token = AADAuth.get_token(auth_type) | ||
| token = AADAuth.get_token(auth_type, credential_kwargs) | ||
| logger.info("get_auth_token: Token acquired successfully - auth_type=%s", auth_type) | ||
| return token | ||
| except (ValueError, RuntimeError) as e: | ||
|
|
@@ -246,6 +290,7 @@ def extract_auth_type(connection_string: str) -> Optional[str]: | |
| AuthType.INTERACTIVE.value: "interactive", | ||
| AuthType.DEVICE_CODE.value: "devicecode", | ||
| AuthType.DEFAULT.value: "default", | ||
| AuthType.MSI.value: "msi", | ||
| } | ||
| for part in connection_string.split(";"): | ||
| key, _, value = part.strip().partition("=") | ||
|
|
@@ -254,6 +299,20 @@ def extract_auth_type(connection_string: str) -> Optional[str]: | |
| return None | ||
|
|
||
|
|
||
def _find_connection_string_value(connection_string: str, wanted_key: str) -> Optional[str]:
    """Return the value of the first ``wanted_key`` entry, or None if absent.

    Keys are matched case-insensitively. A value wrapped in braces may
    contain ``;`` and ``=``; inside braces a doubled ``}}`` stands for a
    literal ``}``. The returned value has the braces removed and
    surrounding whitespace stripped.
    """
    wanted = wanted_key.lower()
    pos, end = 0, len(connection_string)
    while pos < end:
        eq = connection_string.find("=", pos)
        if eq == -1:
            break
        semi = connection_string.find(";", pos, eq)
        if semi != -1:
            # Segment without '=' (for example an empty ";;"): skip it.
            pos = semi + 1
            continue
        key = connection_string[pos:eq].strip().lower()
        cursor = eq + 1
        while cursor < end and connection_string[cursor].isspace():
            cursor += 1
        if cursor < end and connection_string[cursor] == "{":
            # Braced value: read up to the closing brace so that ';' and
            # '=' inside the braces are not mistaken for delimiters.
            chars: List[str] = []
            cursor += 1
            while cursor < end:
                ch = connection_string[cursor]
                if ch == "}":
                    if connection_string.startswith("}}", cursor):
                        chars.append("}")
                        cursor += 2
                        continue
                    cursor += 1
                    break
                chars.append(ch)
                cursor += 1
            value = "".join(chars)
            semi = connection_string.find(";", cursor)
        else:
            semi = connection_string.find(";", cursor)
            value = connection_string[cursor : end if semi == -1 else semi]
        if key == wanted:
            return value.strip()
        pos = end if semi == -1 else semi + 1
    return None


def extract_credential_kwargs(
    connection_string: str, auth_type: Optional[str]
) -> Dict[str, str]:
    """Extract credential constructor kwargs for the given auth type.

    For ActiveDirectoryMSI: returns ``{"client_id": uid}`` when UID is
    set (user-assigned MSI) and ``{}`` for system-assigned MSI. Any other
    auth type yields ``{}``.

    The UID lookup is brace-aware: ``UID={value}`` yields ``value`` without
    the braces, and a ``;`` or ``UID=`` that appears inside another key's
    braced value is not misread as a separate parameter.

    Args:
        connection_string: Raw ``key=value;...`` connection string.
        auth_type: Short auth type name (``"msi"`` enables extraction).

    Returns:
        Keyword arguments for the credential constructor; possibly empty.
    """
    if auth_type != "msi":
        return {}
    client_id = _find_connection_string_value(connection_string, "uid")
    return {"client_id": client_id} if client_id else {}
|
bewithgaurav marked this conversation as resolved.
Outdated
|
||
|
|
||
|
|
||
| def process_connection_string( | ||
| connection_string: str, | ||
| ) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str]]: | ||
|
|
@@ -301,12 +360,20 @@ def process_connection_string( | |
|
|
||
| modified_parameters, auth_type = process_auth_parameters(parameters) | ||
|
|
||
| # Capture credential kwargs (e.g. user-assigned MSI client_id) before | ||
| # remove_sensitive_params strips UID from the parameter list. | ||
| credential_kwargs: Dict[str, str] = {} | ||
| if auth_type == "msi": | ||
| client_id = _extract_msi_client_id(modified_parameters) | ||
| if client_id: | ||
| credential_kwargs["client_id"] = client_id | ||
|
|
||
| if auth_type: | ||
| logger.info( | ||
| "process_connection_string: Authentication type detected - auth_type=%s", auth_type | ||
| ) | ||
| modified_parameters = remove_sensitive_params(modified_parameters) | ||
| token_struct = get_auth_token(auth_type) | ||
| token_struct = get_auth_token(auth_type, credential_kwargs or None) | ||
| if token_struct: | ||
| logger.info( | ||
| "process_connection_string: Token authentication configured successfully - auth_type=%s", | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.