diff --git a/lisa/environment.py b/lisa/environment.py index 6bbfcd8944..0d38158ffc 100644 --- a/lisa/environment.py +++ b/lisa/environment.py @@ -493,15 +493,13 @@ def load_environments( class EnvironmentHookSpec: @hookspec def get_environment_information(self, environment: Environment) -> Dict[str, str]: - ... + raise NotImplementedError class EnvironmentHookImpl: @hookimpl def get_environment_information(self, environment: Environment) -> Dict[str, str]: - information: Dict[str, str] = {} - information["name"] = environment.name - + information: Dict[str, str] = {"name": environment.name} if environment.nodes: node = environment.default_node try: diff --git a/lisa/node.py b/lisa/node.py index bafbb04a78..4c9752271f 100644 --- a/lisa/node.py +++ b/lisa/node.py @@ -737,7 +737,7 @@ def quick_connect( class NodeHookSpec: @hookspec def get_node_information(self, node: Node) -> Dict[str, str]: - ... + raise NotImplementedError class NodeHookImpl: diff --git a/lisa/sut_orchestrator/aws/common.py b/lisa/sut_orchestrator/aws/common.py index b014494864..169bacdfd5 100644 --- a/lisa/sut_orchestrator/aws/common.py +++ b/lisa/sut_orchestrator/aws/common.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from dataclasses import InitVar, dataclass, field +from dataclasses import dataclass, field from typing import Dict, List, Optional from dataclasses_json import dataclass_json @@ -67,16 +67,13 @@ class AwsNodeSchema: data_disk_size: int = 32 disk_type: str = "" - # for marketplace image, which need to accept terms - _marketplace: InitVar[Optional[AwsVmMarketplaceSchema]] = None + def __post_init__(self) -> None: + # Caching for marketplace image + self._marketplace: Optional[AwsVmMarketplaceSchema] = None @property def marketplace(self) -> AwsVmMarketplaceSchema: - # this is a safe guard and prevent mypy error on typing - if not hasattr(self, "_marketplace"): - self._marketplace: Optional[AwsVmMarketplaceSchema] = None - - if not self._marketplace: + if self._marketplace is None: assert isinstance( self.marketplace_raw, str ), f"actual: {type(self.marketplace_raw)}" diff --git a/lisa/sut_orchestrator/azure/common.py b/lisa/sut_orchestrator/azure/common.py index dd478fba67..4976195e1f 100644 --- a/lisa/sut_orchestrator/azure/common.py +++ b/lisa/sut_orchestrator/azure/common.py @@ -4,7 +4,7 @@ import os import re import sys -from dataclasses import InitVar, dataclass, field +from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from functools import lru_cache from pathlib import Path @@ -225,13 +225,12 @@ class AzureNodeSchema: # image. is_linux: Optional[bool] = None - _marketplace: InitVar[Optional[AzureVmMarketplaceSchema]] = None + def __post_init__(self) -> None: + # Caching + self._marketplace: Optional[AzureVmMarketplaceSchema] = None + self._shared_gallery: Optional[SharedImageGallerySchema] = None + self._vhd: Optional[VhdSchema] = None - _shared_gallery: InitVar[Optional[SharedImageGallerySchema]] = None - - _vhd: InitVar[Optional[VhdSchema]] = None - - def __post_init__(self, *args: Any, **kwargs: Any) -> None: # trim whitespace of values. strip_strs( self, @@ -254,109 +253,96 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None: @property def marketplace(self) -> Optional[AzureVmMarketplaceSchema]: - # this is a safe guard and prevent mypy error on typing - if not hasattr(self, "_marketplace"): - self._marketplace: Optional[AzureVmMarketplaceSchema] = None - marketplace: Optional[AzureVmMarketplaceSchema] = self._marketplace - if not marketplace: - if isinstance(self.marketplace_raw, dict): + if self._marketplace is not None: + return self._marketplace + + if isinstance(self.marketplace_raw, dict): + # Users decide the cases of image names, + # inconsistent cases cause a mismatch error in notifiers. + # lower() normalizes the image names, it has no impact on deployment + self.marketplace_raw = { + k: v.lower() for k, v in self.marketplace_raw.items() + } + self._marketplace = schema.load_by_type( + AzureVmMarketplaceSchema, self.marketplace_raw + ) + # Validated marketplace_raw and filter out any unwanted content + self.marketplace_raw = self._marketplace.to_dict() # type: ignore + + elif self.marketplace_raw: + assert isinstance( + self.marketplace_raw, str + ), f"actual: {type(self.marketplace_raw)}" + + self.marketplace_raw = self.marketplace_raw.strip() + + if self.marketplace_raw: # Users decide the cases of image names, - # the inconsistent cases cause the mismatched error in notifiers. - # The lower() normalizes the image names, - # it has no impact on deployment. - self.marketplace_raw = dict( - (k, v.lower()) for k, v in self.marketplace_raw.items() - ) - marketplace = schema.load_by_type( - AzureVmMarketplaceSchema, self.marketplace_raw - ) - # this step makes marketplace_raw is validated, and - # filter out any unwanted content. - self.marketplace_raw = marketplace.to_dict() # type: ignore - elif self.marketplace_raw: - assert isinstance( - self.marketplace_raw, str - ), f"actual: {type(self.marketplace_raw)}" - - self.marketplace_raw = self.marketplace_raw.strip() - - if self.marketplace_raw: - # Users decide the cases of image names, - # the inconsistent cases cause the mismatched error in notifiers. - # The lower() normalizes the image names, - # it has no impact on deployment. - marketplace_strings = re.split( - r"[:\s]+", self.marketplace_raw.lower() + # inconsistent cases cause a mismatch error in notifiers. + # lower() normalizes the image names, it has no impact on deployment + marketplace_strings = re.split(r"[:\s]+", self.marketplace_raw.lower()) + + if len(marketplace_strings) != 4: + raise LisaException( + "Invalid value for the provided marketplace " + f"parameter: '{self.marketplace_raw}'." + "The marketplace parameter should be in the format: " + "' ' " + "or ':::'" ) + self._marketplace = AzureVmMarketplaceSchema(*marketplace_strings) + # marketplace_raw is used + self.marketplace_raw = ( + self._marketplace.to_dict() # type: ignore [attr-defined] + ) - if len(marketplace_strings) == 4: - marketplace = AzureVmMarketplaceSchema(*marketplace_strings) - # marketplace_raw is used - self.marketplace_raw = marketplace.to_dict() # type: ignore - else: - raise LisaException( - f"Invalid value for the provided marketplace " - f"parameter: '{self.marketplace_raw}'." - f"The marketplace parameter should be in the format: " - f"' ' " - f"or ':::'" - ) - self._marketplace = marketplace - return marketplace + return self._marketplace @marketplace.setter def marketplace(self, value: Optional[AzureVmMarketplaceSchema]) -> None: self._marketplace = value - if value is None: - self.marketplace_raw = None - else: - self.marketplace_raw = value.to_dict() # type: ignore + # dataclass_json doesn't use a protocol return type, so to_dict() is unknown + self.marketplace_raw = ( + None if value is None else value.to_dict() # type: ignore [attr-defined] + ) @property def shared_gallery(self) -> Optional[SharedImageGallerySchema]: - # this is a safe guard and prevent mypy error on typing - if not hasattr(self, "_shared_gallery"): - self._shared_gallery: Optional[SharedImageGallerySchema] = None - shared_gallery: Optional[SharedImageGallerySchema] = self._shared_gallery - if shared_gallery: - return shared_gallery + if self._shared_gallery is not None: + return self._shared_gallery + if isinstance(self.shared_gallery_raw, dict): # Users decide the cases of image names, - # the inconsistent cases cause the mismatched error in notifiers. - # The lower() normalizes the image names, - # it has no impact on deployment. - self.shared_gallery_raw = dict( - (k, v.lower()) for k, v in self.shared_gallery_raw.items() - ) - shared_gallery = schema.load_by_type( + # inconsistent cases cause a mismatch error in notifiers. + # lower() normalizes the image names, it has no impact on deployment + self.shared_gallery_raw = { + k: v.lower() for k, v in self.shared_gallery_raw.items() + } + + self._shared_gallery = schema.load_by_type( SharedImageGallerySchema, self.shared_gallery_raw ) - if not shared_gallery.subscription_id: - shared_gallery.subscription_id = self.subscription_id - # this step makes shared_gallery_raw is validated, and - # filter out any unwanted content. - self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore + if not self._shared_gallery.subscription_id: + self._shared_gallery.subscription_id = self.subscription_id + # Validated shared_gallery_raw and filter out any unwanted content + self.shared_gallery_raw = self._shared_gallery.to_dict() # type: ignore + elif self.shared_gallery_raw: assert isinstance( self.shared_gallery_raw, str ), f"actual: {type(self.shared_gallery_raw)}" # Users decide the cases of image names, - # the inconsistent cases cause the mismatched error in notifiers. - # The lower() normalizes the image names, - # it has no impact on deployment. + # inconsistent cases cause a mismatch error in notifiers. + # lower() normalizes the image names, it has no impact on deployment shared_gallery_strings = re.split( r"[/]+", self.shared_gallery_raw.strip().lower() ) if len(shared_gallery_strings) == 5: - shared_gallery = SharedImageGallerySchema(*shared_gallery_strings) - # shared_gallery_raw is used - self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore + self._shared_gallery = SharedImageGallerySchema(*shared_gallery_strings) elif len(shared_gallery_strings) == 3: - shared_gallery = SharedImageGallerySchema( + self._shared_gallery = SharedImageGallerySchema( self.subscription_id, None, *shared_gallery_strings ) - # shared_gallery_raw is used - self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore else: raise LisaException( f"Invalid value for the provided shared gallery " @@ -366,51 +352,43 @@ def shared_gallery(self) -> Optional[SharedImageGallerySchema]: f"/' or '/" f"/'" ) - self._shared_gallery = shared_gallery - return shared_gallery + self.shared_gallery_raw = self._shared_gallery.to_dict() # type: ignore + + return self._shared_gallery @shared_gallery.setter def shared_gallery(self, value: Optional[SharedImageGallerySchema]) -> None: self._shared_gallery = value - if value is None: - self.shared_gallery_raw = None - else: - self.shared_gallery_raw = value.to_dict() # type: ignore + # dataclass_json doesn't use a protocol return type, so to_dict() is unknown + self.shared_gallery_raw = ( + None if value is None else value.to_dict() # type: ignore [attr-defined] + ) @property def vhd(self) -> Optional[VhdSchema]: - # this is a safe guard and prevent mypy error on typing - if not hasattr(self, "_vhd"): - self._vhd: Optional[VhdSchema] = None - vhd: Optional[VhdSchema] = self._vhd - if vhd: - return vhd + if self._vhd is not None: + return self._vhd + if isinstance(self.vhd_raw, dict): - vhd = schema.load_by_type(VhdSchema, self.vhd_raw) - add_secret(vhd.vhd_path, PATTERN_URL) - if vhd.vmgs_path: - add_secret(vhd.vmgs_path, PATTERN_URL) - # this step makes vhd_raw is validated, and - # filter out any unwanted content. - self.vhd_raw = vhd.to_dict() # type: ignore + self._vhd = schema.load_by_type(VhdSchema, self.vhd_raw) + add_secret(self._vhd.vhd_path, PATTERN_URL) + if self._vhd.vmgs_path: + add_secret(self._vhd.vmgs_path, PATTERN_URL) + # Validated vhd_raw and filter out any unwanted content + self.vhd_raw = self._vhd.to_dict() # type: ignore + elif self.vhd_raw is not None: assert isinstance(self.vhd_raw, str), f"actual: {type(self.vhd_raw)}" - vhd = VhdSchema(self.vhd_raw) - add_secret(vhd.vhd_path, PATTERN_URL) - self.vhd_raw = vhd.to_dict() # type: ignore - self._vhd = vhd - if vhd: - return vhd - else: - return None + self._vhd = VhdSchema(self.vhd_raw) + add_secret(self._vhd.vhd_path, PATTERN_URL) + self.vhd_raw = self._vhd.to_dict() # type: ignore + + return self._vhd @vhd.setter def vhd(self, value: Optional[VhdSchema]) -> None: self._vhd = value - if value is None: - self.vhd_raw = None - else: - self.vhd_raw = self._vhd.to_dict() # type: ignore + self.vhd_raw = None if value is None else self._vhd.to_dict() # type: ignore def get_image_name(self) -> str: result = "" @@ -421,7 +399,7 @@ def get_image_name(self) -> str: self.shared_gallery_raw, dict ), f"actual type: {type(self.shared_gallery_raw)}" if self.shared_gallery.resource_group_name: - result = "/".join([x for x in self.shared_gallery_raw.values()]) + result = "/".join(self.shared_gallery_raw.values()) else: result = ( f"{self.shared_gallery.image_gallery}/" @@ -432,7 +410,7 @@ def get_image_name(self) -> str: assert isinstance( self.marketplace_raw, dict ), f"actual type: {type(self.marketplace_raw)}" - result = " ".join([x for x in self.marketplace_raw.values()]) + result = " ".join(self.marketplace_raw.values()) return result @@ -457,9 +435,7 @@ def from_node_runbook(cls, runbook: AzureNodeSchema) -> "AzureNodeArmParameter": parameters["vhd_raw"] = parameters["vhd"] del parameters["vhd"] - arm_parameters = AzureNodeArmParameter(**parameters) - - return arm_parameters + return AzureNodeArmParameter(**parameters) class DataDiskCreateOption: diff --git a/lisa/sut_orchestrator/azure/platform_.py b/lisa/sut_orchestrator/azure/platform_.py index 73ebd08a45..1f4f16e86c 100644 --- a/lisa/sut_orchestrator/azure/platform_.py +++ b/lisa/sut_orchestrator/azure/platform_.py @@ -9,7 +9,7 @@ import re import sys from copy import deepcopy -from dataclasses import InitVar, dataclass, field +from dataclasses import dataclass, field from datetime import datetime from difflib import SequenceMatcher from functools import lru_cache, partial @@ -265,7 +265,6 @@ class AzurePlatformSchema: cloud_raw: Optional[Union[Dict[str, Any], str]] = field( default=None, metadata=field_metadata(data_key="cloud") ) - _cloud: InitVar[Cloud] = None shared_resource_group_name: str = AZURE_SHARED_RG_NAME resource_group_name: str = field(default="") @@ -315,6 +314,8 @@ class AzurePlatformSchema: azcopy_path: str = field(default="") def __post_init__(self, *args: Any, **kwargs: Any) -> None: + self._cloud: Optional[Cloud] = None + strip_strs( self, [ @@ -352,61 +353,59 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None: @property def cloud(self) -> Cloud: - # this is a safe guard and prevent mypy error on typing - if not hasattr(self, "_cloud"): - self._cloud: Cloud = None - cloud: Cloud = self._cloud - if not cloud: - # if pass str into cloud, it should be one of below values, case insensitive - # azurecloud - # azurechinacloud - # azuregermancloud - # azureusgovernment - # example - # cloud: AzureCloud - if isinstance(self.cloud_raw, str): - cloud = CLOUD.get(self.cloud_raw.lower(), None) - assert cloud, ( - f"cannot find cloud type {self.cloud_raw}," - f" current support list is {list(CLOUD.keys())}" - ) - # if pass dict to construct a cloud instance, the full example is - # cloud: - # name: AzureCloud - # endpoints: - # management: https://management.core.windows.net/ - # resource_manager: https://management.azure.com/ - # sql_management: https://management.core.windows.net:8443/ - # batch_resource_id: https://batch.core.windows.net/ - # gallery: https://gallery.azure.com/ - # active_directory: https://login.microsoftonline.com - # active_directory_resource_id: https://management.core.windows.net/ - # active_directory_graph_resource_id: https://graph.windows.net/ - # microsoft_graph_resource_id: https://graph.microsoft.com/ - # suffixes: - # storage_endpoint: core.windows.net - # keyvault_dns: .vault.azure.net - # sql_server_hostname: .database.windows.net - # azure_datalake_store_file_system_endpoint: azuredatalakestore.net - # azure_datalake_analytics_catalog_and_job_endpoint: azuredatalakeanalytics.net # noqa: E501 - elif isinstance(self.cloud_raw, dict): - cloudschema = schema.load_by_type(CloudSchema, self.cloud_raw) - cloud = Cloud( - cloudschema.name, cloudschema.endpoints, cloudschema.suffixes - ) - else: - # by default use azure public cloud - cloud = AZURE_PUBLIC_CLOUD - self._cloud = cloud + if self._cloud is not None: + return self._cloud + + # if pass str into cloud, it should be one of below values, case insensitive + # azurecloud + # azurechinacloud + # azuregermancloud + # azureusgovernment + # example + # cloud: AzureCloud + if isinstance(self.cloud_raw, str): + cloud = CLOUD.get(self.cloud_raw.lower(), None) + assert cloud, ( + f"cannot find cloud type {self.cloud_raw}," + f" current support list is {list(CLOUD.keys())}" + ) + # if pass dict to construct a cloud instance, the full example is + # cloud: + # name: AzureCloud + # endpoints: + # management: https://management.core.windows.net/ + # resource_manager: https://management.azure.com/ + # sql_management: https://management.core.windows.net:8443/ + # batch_resource_id: https://batch.core.windows.net/ + # gallery: https://gallery.azure.com/ + # active_directory: https://login.microsoftonline.com + # active_directory_resource_id: https://management.core.windows.net/ + # active_directory_graph_resource_id: https://graph.windows.net/ + # microsoft_graph_resource_id: https://graph.microsoft.com/ + # suffixes: + # storage_endpoint: core.windows.net + # keyvault_dns: .vault.azure.net + # sql_server_hostname: .database.windows.net + # azure_datalake_store_file_system_endpoint: azuredatalakestore.net + # azure_datalake_analytics_catalog_and_job_endpoint: azuredatalakeanalytics.net # noqa: E501 + + elif isinstance(self.cloud_raw, dict): + cloudschema = schema.load_by_type(CloudSchema, self.cloud_raw) + cloud = Cloud(cloudschema.name, cloudschema.endpoints, cloudschema.suffixes) + + else: + # by default use azure public cloud + cloud = AZURE_PUBLIC_CLOUD + + self._cloud = cloud return cloud @cloud.setter def cloud(self, value: Optional[CloudSchema]) -> None: self._cloud = value - if value is None: - self.cloud_raw = None - else: - self.cloud_raw = value.to_dict() # type: ignore + self.cloud_raw = ( + None if value is None else value.to_dict() # type: ignore[attr-defined] + ) class AzurePlatform(Platform): @@ -2168,7 +2167,7 @@ def _get_vhd_os_disk_size(self, blob_url: str) -> int: assert properties.size, f"fail to get blob size of {blob_url}" # Azure requires only megabyte alignment of vhds, round size up # for cases where the size is megabyte aligned - return math.ceil(properties.size / 1024 / 1024 / 1024) + return int(math.ceil(properties.size / 1024 / 1024 / 1024)) def _get_sig_info( self, shared_image: SharedImageGallerySchema diff --git a/lisa/testsuite.py b/lisa/testsuite.py index e034756a63..d6a0fffa98 100644 --- a/lisa/testsuite.py +++ b/lisa/testsuite.py @@ -575,7 +575,7 @@ def start( suite_error_message, suite_error_stacktrace, ) = self.__suite_method( - self.before_suite, # type: ignore + self.before_suite, test_kwargs=test_kwargs, log=suite_log, ) @@ -654,7 +654,7 @@ def start( if hasattr(self, "after_suite"): self.__suite_method( - self.after_suite, # type: ignore + self.after_suite, test_kwargs=test_kwargs, log=suite_log, ) diff --git a/lisa/util/__init__.py b/lisa/util/__init__.py index 6289d05896..14d826dbf6 100644 --- a/lisa/util/__init__.py +++ b/lisa/util/__init__.py @@ -164,9 +164,7 @@ def __init__(self, os: "OperatingSystem", message: str = "") -> None: self.version = os.information.full_version self.kernel_version = "" if hasattr(os, "get_kernel_information"): - self.kernel_version = ( - os.get_kernel_information().raw_version # type: ignore - ) + self.kernel_version = os.get_kernel_information().raw_version self._extended_message = message def __str__(self) -> str: @@ -516,18 +514,8 @@ def find_group_in_lines( def deep_update_dict(src: Dict[str, Any], dest: Dict[str, Any]) -> Dict[str, Any]: - if ( - dest is None - or isinstance(dest, int) - or isinstance(dest, bool) - or isinstance(dest, float) - or isinstance(dest, str) - ): - result = dest - else: + if isinstance(dest, dict): result = dest.copy() - - if isinstance(result, dict): for key, value in src.items(): if isinstance(value, dict) and key in dest: value = deep_update_dict(value, dest[key])