From 2d1118161a2b51dd2c10311c13afc5d8e525ee4d Mon Sep 17 00:00:00 2001 From: Avram Lubkin Date: Thu, 17 Nov 2022 22:06:14 -0500 Subject: [PATCH] Typing fixes for new version of mypy --- lisa/environment.py | 6 +- lisa/sut_orchestrator/aws/common.py | 13 ++- lisa/sut_orchestrator/azure/common.py | 142 +++++++++++++------------- lisa/tools/bzip2.py | 1 + lisa/util/__init__.py | 17 +-- 5 files changed, 83 insertions(+), 96 deletions(-) diff --git a/lisa/environment.py b/lisa/environment.py index 5de2f7d69c..8da64c1915 100644 --- a/lisa/environment.py +++ b/lisa/environment.py @@ -489,15 +489,13 @@ def load_environments( class EnvironmentHookSpec: @hookspec def get_environment_information(self, environment: Environment) -> Dict[str, str]: - ... + raise NotImplementedError class EnvironmentHookImpl: @hookimpl def get_environment_information(self, environment: Environment) -> Dict[str, str]: - information: Dict[str, str] = {} - information["name"] = environment.name - + information: Dict[str, str] = {"name": environment.name} if environment.nodes: node = environment.default_node try: diff --git a/lisa/sut_orchestrator/aws/common.py b/lisa/sut_orchestrator/aws/common.py index b014494864..6d36ea9aae 100644 --- a/lisa/sut_orchestrator/aws/common.py +++ b/lisa/sut_orchestrator/aws/common.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from dataclasses import InitVar, dataclass, field +from dataclasses import dataclass, field from typing import Dict, List, Optional from dataclasses_json import dataclass_json @@ -67,16 +67,15 @@ class AwsNodeSchema: data_disk_size: int = 32 disk_type: str = "" - # for marketplace image, which need to accept terms - _marketplace: InitVar[Optional[AwsVmMarketplaceSchema]] = None + def __post_init__(self) -> None: + + # Caching for marketplace image + self._marketplace: Optional[AwsVmMarketplaceSchema] = None @property def marketplace(self) -> AwsVmMarketplaceSchema: - # this is a safe guard and prevent mypy error on typing - if not hasattr(self, "_marketplace"): - self._marketplace: Optional[AwsVmMarketplaceSchema] = None - if not self._marketplace: + if self._marketplace is None: assert isinstance( self.marketplace_raw, str ), f"actual: {type(self.marketplace_raw)}" diff --git a/lisa/sut_orchestrator/azure/common.py b/lisa/sut_orchestrator/azure/common.py index 3aa6d92ea7..f4d159f317 100644 --- a/lisa/sut_orchestrator/azure/common.py +++ b/lisa/sut_orchestrator/azure/common.py @@ -3,7 +3,7 @@ import re import sys -from dataclasses import InitVar, dataclass, field +from dataclasses import dataclass, field from datetime import datetime, timedelta from pathlib import Path from threading import Lock @@ -175,11 +175,12 @@ class AzureNodeSchema: # image. is_linux: Optional[bool] = None - _marketplace: InitVar[Optional[AzureVmMarketplaceSchema]] = None + def __post_init__(self) -> None: - _shared_gallery: InitVar[Optional[SharedImageGallerySchema]] = None + # Caching + self._marketplace: Optional[AzureVmMarketplaceSchema] = None + self._shared_gallery: Optional[SharedImageGallerySchema] = None - def __post_init__(self, *args: Any, **kwargs: Any) -> None: # trim whitespace of values. strip_strs( self, @@ -201,80 +202,78 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None: @property def marketplace(self) -> Optional[AzureVmMarketplaceSchema]: - # this is a safe guard and prevent mypy error on typing - if not hasattr(self, "_marketplace"): - self._marketplace: Optional[AzureVmMarketplaceSchema] = None - marketplace: Optional[AzureVmMarketplaceSchema] = self._marketplace - if not marketplace: - if isinstance(self.marketplace_raw, dict): + + if self._marketplace is not None: + return self._marketplace + + if isinstance(self.marketplace_raw, dict): + # Users decide the cases of image names, + # the inconsistent cases cause the mismatched error in notifiers. + # The lower() normalizes the image names, + # it has no impact on deployment. + self.marketplace_raw = { + k: v.lower() for k, v in self.marketplace_raw.items() + } + self._marketplace = schema.load_by_type( + AzureVmMarketplaceSchema, self.marketplace_raw + ) + # This step makes sure marketplace_raw is validated, and + # filters out any unwanted content. + self.marketplace_raw = self._marketplace.to_dict() # type: ignore + + elif self.marketplace_raw: + assert isinstance( + self.marketplace_raw, str + ), f"actual: {type(self.marketplace_raw)}" + + self.marketplace_raw = self.marketplace_raw.strip() + + if self.marketplace_raw: # Users decide the cases of image names, # the inconsistent cases cause the mismatched error in notifiers. # The lower() normalizes the image names, # it has no impact on deployment. - self.marketplace_raw = dict( - (k, v.lower()) for k, v in self.marketplace_raw.items() - ) - marketplace = schema.load_by_type( - AzureVmMarketplaceSchema, self.marketplace_raw - ) - # this step makes marketplace_raw is validated, and - # filter out any unwanted content. - self.marketplace_raw = marketplace.to_dict() # type: ignore - elif self.marketplace_raw: - assert isinstance( - self.marketplace_raw, str - ), f"actual: {type(self.marketplace_raw)}" - - self.marketplace_raw = self.marketplace_raw.strip() - - if self.marketplace_raw: - # Users decide the cases of image names, - # the inconsistent cases cause the mismatched error in notifiers. - # The lower() normalizes the image names, - # it has no impact on deployment. - marketplace_strings = re.split( - r"[:\s]+", self.marketplace_raw.lower() + marketplace_strings = re.split(r"[:\s]+", self.marketplace_raw.lower()) + + if len(marketplace_strings) != 4: + raise LisaException( + "Invalid value for the provided marketplace " + f"parameter: '{self.marketplace_raw}'." + "The marketplace parameter should be in the format: " + "' ' " + "or ':::'" ) + self._marketplace = AzureVmMarketplaceSchema(*marketplace_strings) + # marketplace_raw is used + self.marketplace_raw = ( + self._marketplace.to_dict() # type: ignore [attr-defined] + ) - if len(marketplace_strings) == 4: - marketplace = AzureVmMarketplaceSchema(*marketplace_strings) - # marketplace_raw is used - self.marketplace_raw = marketplace.to_dict() # type: ignore - else: - raise LisaException( - f"Invalid value for the provided marketplace " - f"parameter: '{self.marketplace_raw}'." - f"The marketplace parameter should be in the format: " - f"' ' " - f"or ':::'" - ) - self._marketplace = marketplace - return marketplace + return self._marketplace @marketplace.setter def marketplace(self, value: Optional[AzureVmMarketplaceSchema]) -> None: self._marketplace = value - if value is None: - self.marketplace_raw = None - else: - self.marketplace_raw = value.to_dict() # type: ignore + # dataclass_json doesn't use a protocol return type, so to_dict() is unknown + self.marketplace_raw = ( + None if value is None else value.to_dict() # type: ignore [attr-defined] + ) @property def shared_gallery(self) -> Optional[SharedImageGallerySchema]: - # this is a safe guard and prevent mypy error on typing - if not hasattr(self, "_shared_gallery"): - self._shared_gallery: Optional[SharedImageGallerySchema] = None - shared_gallery: Optional[SharedImageGallerySchema] = self._shared_gallery - if shared_gallery: - return shared_gallery + + if self._shared_gallery is not None: + return self._shared_gallery + if isinstance(self.shared_gallery_raw, dict): # Users decide the cases of image names, # the inconsistent cases cause the mismatched error in notifiers. # The lower() normalizes the image names, # it has no impact on deployment. - self.shared_gallery_raw = dict( - (k, v.lower()) for k, v in self.shared_gallery_raw.items() - ) + self.shared_gallery_raw = { + k: v.lower() for k, v in self.shared_gallery_raw.items() + } + shared_gallery = schema.load_by_type( SharedImageGallerySchema, self.shared_gallery_raw ) @@ -283,6 +282,8 @@ def shared_gallery(self) -> Optional[SharedImageGallerySchema]: # this step makes shared_gallery_raw is validated, and # filter out any unwanted content. self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore + self._shared_gallery = shared_gallery + elif self.shared_gallery_raw: assert isinstance( self.shared_gallery_raw, str @@ -299,11 +300,12 @@ def shared_gallery(self) -> Optional[SharedImageGallerySchema]: # shared_gallery_raw is used self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore elif len(shared_gallery_strings) == 3: - shared_gallery = SharedImageGallerySchema( + self._shared_gallery = SharedImageGallerySchema( self.subscription_id, None, *shared_gallery_strings ) # shared_gallery_raw is used - self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore + self.shared_gallery_raw = self._shared_gallery.to_dict() # type: ignore + else: raise LisaException( f"Invalid value for the provided shared gallery " @@ -313,16 +315,16 @@ def shared_gallery(self) -> Optional[SharedImageGallerySchema]: f"/' or '/" f"/'" ) - self._shared_gallery = shared_gallery - return shared_gallery + + return self._shared_gallery @shared_gallery.setter def shared_gallery(self, value: Optional[SharedImageGallerySchema]) -> None: self._shared_gallery = value - if value is None: - self.shared_gallery_raw = None - else: - self.shared_gallery_raw = value.to_dict() # type: ignore + # dataclass_json doesn't use a protocol return type, so to_dict() is unknown + self.shared_gallery_raw = ( + None if value is None else value.to_dict() # type: ignore [attr-defined] + ) def get_image_name(self) -> str: result = "" @@ -365,9 +367,7 @@ def from_node_runbook(cls, runbook: AzureNodeSchema) -> "AzureNodeArmParameter": parameters["shared_gallery_raw"] = parameters["shared_gallery"] del parameters["shared_gallery"] - arm_parameters = AzureNodeArmParameter(**parameters) - - return arm_parameters + return AzureNodeArmParameter(**parameters) class DataDiskCreateOption: diff --git a/lisa/tools/bzip2.py b/lisa/tools/bzip2.py index 525a50ac8d..341431c79f 100644 --- a/lisa/tools/bzip2.py +++ b/lisa/tools/bzip2.py @@ -9,6 +9,7 @@ class Bzip2(Tool): def command(self) -> str: return "bzip2" + @property def can_install(self) -> bool: return True diff --git a/lisa/util/__init__.py b/lisa/util/__init__.py index 8a7dc102d2..2697e42c30 100644 --- a/lisa/util/__init__.py +++ b/lisa/util/__init__.py @@ -159,9 +159,7 @@ def __init__(self, os: "OperatingSystem", message: str = "") -> None: self.version = os.information.full_version self.kernel_version = "" if hasattr(os, "get_kernel_information"): - self.kernel_version = ( - os.get_kernel_information().raw_version # type: ignore - ) + self.kernel_version = os.get_kernel_information().raw_version self._extended_message = message def __str__(self) -> str: @@ -505,18 +503,9 @@ def find_group_in_lines( def deep_update_dict(src: Dict[str, Any], dest: Dict[str, Any]) -> Dict[str, Any]: - if ( - dest is None - or isinstance(dest, int) - or isinstance(dest, bool) - or isinstance(dest, float) - or isinstance(dest, str) - ): - result = dest - else: - result = dest.copy() - if isinstance(result, dict): + if isinstance(dest, dict): + result = dest.copy() for key, value in src.items(): if isinstance(value, dict) and key in dest: value = deep_update_dict(value, dest[key])