Skip to content

Commit

Permalink
auth azure: support access token
Browse files Browse the repository at this point in the history
  • Loading branch information
LiliDeng committed Jan 14, 2025
1 parent 3290e82 commit 9e89205
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 9 deletions.
88 changes: 82 additions & 6 deletions lisa/sut_orchestrator/azure/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import base64
import hashlib
import json
import os
Expand All @@ -26,6 +27,7 @@

import requests
from assertpy import assert_that
from azure.core.credentials import AccessToken, TokenCredential
from azure.core.exceptions import ResourceExistsError
from azure.keyvault.certificates import (
CertificateClient,
Expand Down Expand Up @@ -1634,12 +1636,14 @@ def generate_user_delegation_sas_token(
connection_string: Optional[str] = None,
writable: bool = False,
expired_hours: int = 1,
platform: Optional["AzurePlatform"] = None,
) -> Any:
blob_service_client = get_blob_service_client(
cloud=cloud,
credential=credential,
account_name=account_name,
connection_string=connection_string,
platform=platform,
)
start_time = datetime.now(timezone.utc)
expiry_time = start_time + timedelta(hours=expired_hours)
Expand All @@ -1664,6 +1668,7 @@ def get_blob_service_client(
credential: Optional[Any] = None,
account_name: Optional[str] = None,
connection_string: Optional[str] = None,
platform: Optional["AzurePlatform"] = None,
) -> BlobServiceClient:
"""
Create a Azure Storage container if it does not exist.
Expand All @@ -1677,7 +1682,10 @@ def get_blob_service_client(
assert (
account_name
), "account_name is required, if connection_string is not set."

if platform and platform._azure_runbook.azure_storage_access_token:
credential = get_static_access_token(
platform._azure_runbook.azure_storage_access_token
)
blob_service_client = BlobServiceClient(
f"https://{account_name}.blob.{cloud.suffixes.storage_endpoint}",
credential,
Expand All @@ -1692,15 +1700,21 @@ def get_or_create_storage_container(
account_name: Optional[str] = None,
connection_string: Optional[str] = None,
allow_create: bool = True,
platform: Optional["AzurePlatform"] = None,
) -> ContainerClient:
"""
Create a Azure Storage container if it does not exist.
"""
if platform and platform._azure_runbook.azure_storage_access_token:
credential = get_static_access_token(
platform._azure_runbook.azure_storage_access_token
)
blob_service_client = get_blob_service_client(
cloud=cloud,
credential=credential,
account_name=account_name,
connection_string=connection_string,
platform=platform,
)
container_client = blob_service_client.get_container_client(container_name)
if not container_client.exists():
Expand Down Expand Up @@ -1836,6 +1850,7 @@ def copy_vhd_to_storage(
cloud=platform.cloud,
account_name=storage_name,
container_name=SAS_COPIED_CONTAINER_NAME,
platform=platform,
)
full_vhd_path = f"{container_client.url}/{dst_vhd_name}"

Expand Down Expand Up @@ -1882,6 +1897,7 @@ def copy_vhd_to_storage(
cloud=platform.cloud,
account_name=storage_name,
writable=True,
platform=platform,
)
dst_vhd_sas_url = f"{full_vhd_path}?{sas_token}"
log.info(f"copying vhd by azcopy {dst_vhd_name}")
Expand Down Expand Up @@ -2254,6 +2270,7 @@ def _generate_sas_token_for_vhd(
cloud=platform.cloud,
account_name=sc_name,
container_name=container_name,
platform=platform,
)
source_blob = source_container_client.get_blob_client(blob_name)
sas_token = generate_user_delegation_sas_token(
Expand All @@ -2262,6 +2279,7 @@ def _generate_sas_token_for_vhd(
credential=platform.credential,
cloud=platform.cloud,
account_name=sc_name,
platform=platform,
)
source_url = source_blob.url + "?" + sas_token
return source_url
Expand Down Expand Up @@ -2546,6 +2564,7 @@ def check_blob_exist(
account_name=account_name,
container_name=container_name,
allow_create=False,
platform=platform,
)
blob_client = container_client.get_blob_client(blob_name)
blob_exist = blob_client.exists()
Expand Down Expand Up @@ -2696,14 +2715,62 @@ def get_size(disk_type: schema.DiskType, data_disk_iops: int = 1) -> int:
raise LisaException(f"Data disk type {disk_type} is unsupported.")


class StaticAccessTokenCredential(TokenCredential):
def __init__(self, token: str) -> None:
"""
Initialize StaticAccessTokenCredential with the provided token.
:param token: The Azure access token as a string.
"""
self._token = token
self._expires_on = self._get_exp()

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
"""
Get the access token for the specified scopes.
:param scopes: The OAuth 2.0 scopes the token applies to.
:param kwargs: Additional keyword arguments that may be required by the SDK.
:return: An AccessToken instance containing the token and its expiry time.
"""
# You can choose to print or log the scopes and kwargs for debugging if needed
return AccessToken(self._token, self._expires_on)

def _get_exp(self) -> Any:
# The second part of the JWT is the payload
payload = self._token.split(".")[1]
# Add padding to ensure Base64 decoding works properly
padded_payload = payload + "=" * (4 - len(payload) % 4)
# Decode the Base64 URL-safe encoded payload
decoded_payload = base64.urlsafe_b64decode(padded_payload)
# Convert the payload into a dictionary and get the expiration time
# 'exp' is the UNIX timestamp for expiration
return json.loads(decoded_payload).get("exp")


def get_static_access_token(token: str) -> Any:
credential = None
if token:
credential = StaticAccessTokenCredential(token)
return credential


def get_certificate_client(
vault_url: str, platform: "AzurePlatform"
) -> CertificateClient:
return CertificateClient(vault_url, platform.credential)
credential = (
get_static_access_token(platform._azure_runbook.azure_keyvault_access_token)
or platform.credential
)
return CertificateClient(vault_url, credential)


def get_secret_client(vault_url: str, platform: "AzurePlatform") -> SecretClient:
return SecretClient(vault_url, platform.credential)
credential = (
get_static_access_token(platform._azure_runbook.azure_keyvault_access_token)
or platform.credential
)
return SecretClient(vault_url, credential)


def get_key_vault_management_client(
Expand Down Expand Up @@ -2799,7 +2866,14 @@ def get_identity_id(
else:
endpoint = "me"
graph_api_url = f"{base_url}{api_version}/{endpoint}"
token = platform.credential.get_token("https://graph.microsoft.com/.default").token
credential = (
get_static_access_token(platform._azure_runbook.azure_graph_access_token)
or platform.credential
)
if isinstance(credential, StaticAccessTokenCredential):
token = credential._token
else:
token = credential.get_token("https://graph.microsoft.com/.default").token
# Set up the API call headers
headers = {
"Authorization": f"Bearer {token}",
Expand Down Expand Up @@ -3002,9 +3076,11 @@ def create_certificate(
def check_certificate_existence(
vault_url: str, cert_name: str, log: Logger, platform: "AzurePlatform"
) -> bool:
certificate_client = CertificateClient(
vault_url=vault_url, credential=platform.credential
credential = (
get_static_access_token(platform._azure_runbook.azure_keyvault_access_token)
or platform.credential
)
certificate_client = CertificateClient(vault_url=vault_url, credential=credential)

try:
certificate = certificate_client.get_certificate(cert_name)
Expand Down
27 changes: 24 additions & 3 deletions lisa/sut_orchestrator/azure/platform_.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast

import requests
from azure.core.credentials import TokenCredential
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from azure.identity import DefaultAzureCredential
from azure.mgmt.compute.models import (
Expand Down Expand Up @@ -116,6 +117,7 @@
get_or_create_storage_container,
get_primary_ip_addresses,
get_resource_management_client,
get_static_access_token,
get_storage_account_name,
get_vhd_details,
get_vm,
Expand Down Expand Up @@ -246,6 +248,10 @@ class AzurePlatformSchema:
),
)
service_principal_key: str = field(default="")
azure_arm_access_token: str = field(default="")
azure_storage_access_token: str = field(default="")
azure_keyvault_access_token: str = field(default="")
azure_graph_access_token: str = field(default="")
subscription_id: str = field(
default="",
metadata=field_metadata(
Expand Down Expand Up @@ -320,6 +326,10 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:
"service_principal_tenant_id",
"service_principal_client_id",
"service_principal_key",
"azure_arm_access_token",
"azure_storage_access_token",
"azure_keyvault_access_token",
"azure_graph_access_token",
"subscription_id",
"shared_resource_group_name",
"resource_group_name",
Expand All @@ -338,6 +348,14 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:
add_secret(self.subscription_id, mask=PATTERN_GUID)
if self.service_principal_key:
add_secret(self.service_principal_key)
if self.azure_arm_access_token:
add_secret(self.azure_arm_access_token)
if self.azure_storage_access_token:
add_secret(self.azure_storage_access_token)
if self.azure_keyvault_access_token:
add_secret(self.azure_keyvault_access_token)
if self.azure_graph_access_token:
add_secret(self.azure_graph_access_token)
if self.service_principal_client_id:
add_secret(self.service_principal_client_id, mask=PATTERN_GUID)

Expand Down Expand Up @@ -407,14 +425,14 @@ class AzurePlatform(Platform):
)
_arm_template: Any = None

_credentials: Dict[str, DefaultAzureCredential] = {}
_credentials: Dict[str, Union[DefaultAzureCredential, TokenCredential]] = {}
_locations_data_cache: Dict[str, AzureLocation] = {}

def __init__(self, runbook: schema.Platform) -> None:
super().__init__(runbook=runbook)

# for type detection
self.credential: DefaultAzureCredential
self.credential: Union[DefaultAzureCredential, TokenCredential]
self.cloud: Cloud

# It has to be defined after the class definition is loaded. So it
Expand Down Expand Up @@ -937,7 +955,9 @@ def _initialize_credential(self) -> None:
if azure_runbook.service_principal_key:
os.environ["AZURE_CLIENT_SECRET"] = azure_runbook.service_principal_key

credential = DefaultAzureCredential(
credential = get_static_access_token(
azure_runbook.azure_arm_access_token
) or DefaultAzureCredential(
authority=self.cloud.endpoints.active_directory,
)

Expand Down Expand Up @@ -2270,6 +2290,7 @@ def _get_vhd_os_disk_size(self, blob_url: str) -> int:
cloud=self.cloud,
account_name=result_dict["account_name"],
container_name=result_dict["container_name"],
platform=self,
)

vhd_blob = container_client.get_blob_client(result_dict["blob_name"])
Expand Down
1 change: 1 addition & 0 deletions lisa/sut_orchestrator/azure/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def _export_vhd(
cloud=platform.cloud,
account_name=runbook.storage_account_name,
container_name=runbook.container_name,
platform=platform,
)

if runbook.custom_blob_name:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def retrieve_storage_blob_url(
cloud=platform.cloud,
account_name=storage_account_name,
container_name=container_name,
platform=platform,
)

blob = container_client.get_blob_client(blob_name)
Expand Down

0 comments on commit 9e89205

Please sign in to comment.