Skip to content

Commit

Permalink
Typing fixes for new version of mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
avylove committed Feb 6, 2023
1 parent f7693fc commit 2cbde7e
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 144 deletions.
6 changes: 2 additions & 4 deletions lisa/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,15 +489,13 @@ def load_environments(
class EnvironmentHookSpec:
@hookspec
def get_environment_information(self, environment: Environment) -> Dict[str, str]:
...
raise NotImplementedError


class EnvironmentHookImpl:
@hookimpl
def get_environment_information(self, environment: Environment) -> Dict[str, str]:
information: Dict[str, str] = {}
information["name"] = environment.name

information: Dict[str, str] = {"name": environment.name}
if environment.nodes:
node = environment.default_node
try:
Expand Down
13 changes: 5 additions & 8 deletions lisa/sut_orchestrator/aws/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import InitVar, dataclass, field
from dataclasses import dataclass, field
from typing import Dict, List, Optional

from dataclasses_json import dataclass_json
Expand Down Expand Up @@ -67,16 +67,13 @@ class AwsNodeSchema:
data_disk_size: int = 32
disk_type: str = ""

# for marketplace image, which need to accept terms
_marketplace: InitVar[Optional[AwsVmMarketplaceSchema]] = None
def __post_init__(self) -> None:
# Caching for marketplace image
self._marketplace: Optional[AwsVmMarketplaceSchema] = None

@property
def marketplace(self) -> AwsVmMarketplaceSchema:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_marketplace"):
self._marketplace: Optional[AwsVmMarketplaceSchema] = None

if not self._marketplace:
if self._marketplace is None:
assert isinstance(
self.marketplace_raw, str
), f"actual: {type(self.marketplace_raw)}"
Expand Down
212 changes: 94 additions & 118 deletions lisa/sut_orchestrator/azure/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import re
import sys
from dataclasses import InitVar, dataclass, field
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from threading import Lock
Expand Down Expand Up @@ -185,13 +185,12 @@ class AzureNodeSchema:
# image.
is_linux: Optional[bool] = None

_marketplace: InitVar[Optional[AzureVmMarketplaceSchema]] = None
def __post_init__(self) -> None:
# Caching
self._marketplace: Optional[AzureVmMarketplaceSchema] = None
self._shared_gallery: Optional[SharedImageGallerySchema] = None
self._vhd: Optional[VhdSchema] = None

_shared_gallery: InitVar[Optional[SharedImageGallerySchema]] = None

_vhd: InitVar[Optional[VhdSchema]] = None

def __post_init__(self, *args: Any, **kwargs: Any) -> None:
# trim whitespace of values.
strip_strs(
self,
Expand All @@ -214,109 +213,96 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:

@property
def marketplace(self) -> Optional[AzureVmMarketplaceSchema]:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_marketplace"):
self._marketplace: Optional[AzureVmMarketplaceSchema] = None
marketplace: Optional[AzureVmMarketplaceSchema] = self._marketplace
if not marketplace:
if isinstance(self.marketplace_raw, dict):
if self._marketplace is not None:
return self._marketplace

if isinstance(self.marketplace_raw, dict):
# Users decide the cases of image names,
# inconsistent cases cause a mismatch error in notifiers.
# lower() normalizes the image names, it has no impact on deployment
self.marketplace_raw = {
k: v.lower() for k, v in self.marketplace_raw.items()
}
self._marketplace = schema.load_by_type(
AzureVmMarketplaceSchema, self.marketplace_raw
)
# Validated marketplace_raw and filter out any unwanted content
self.marketplace_raw = self._marketplace.to_dict() # type: ignore

elif self.marketplace_raw:
assert isinstance(
self.marketplace_raw, str
), f"actual: {type(self.marketplace_raw)}"

self.marketplace_raw = self.marketplace_raw.strip()

if self.marketplace_raw:
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
self.marketplace_raw = dict(
(k, v.lower()) for k, v in self.marketplace_raw.items()
)
marketplace = schema.load_by_type(
AzureVmMarketplaceSchema, self.marketplace_raw
)
# this step makes marketplace_raw is validated, and
# filter out any unwanted content.
self.marketplace_raw = marketplace.to_dict() # type: ignore
elif self.marketplace_raw:
assert isinstance(
self.marketplace_raw, str
), f"actual: {type(self.marketplace_raw)}"

self.marketplace_raw = self.marketplace_raw.strip()

if self.marketplace_raw:
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
marketplace_strings = re.split(
r"[:\s]+", self.marketplace_raw.lower()
# inconsistent cases cause a mismatch error in notifiers.
# lower() normalizes the image names, it has no impact on deployment
marketplace_strings = re.split(r"[:\s]+", self.marketplace_raw.lower())

if len(marketplace_strings) != 4:
raise LisaException(
"Invalid value for the provided marketplace "
f"parameter: '{self.marketplace_raw}'."
"The marketplace parameter should be in the format: "
"'<Publisher> <Offer> <Sku> <Version>' "
"or '<Publisher>:<Offer>:<Sku>:<Version>'"
)
self._marketplace = AzureVmMarketplaceSchema(*marketplace_strings)
# marketplace_raw is used
self.marketplace_raw = (
self._marketplace.to_dict() # type: ignore [attr-defined]
)

if len(marketplace_strings) == 4:
marketplace = AzureVmMarketplaceSchema(*marketplace_strings)
# marketplace_raw is used
self.marketplace_raw = marketplace.to_dict() # type: ignore
else:
raise LisaException(
f"Invalid value for the provided marketplace "
f"parameter: '{self.marketplace_raw}'."
f"The marketplace parameter should be in the format: "
f"'<Publisher> <Offer> <Sku> <Version>' "
f"or '<Publisher>:<Offer>:<Sku>:<Version>'"
)
self._marketplace = marketplace
return marketplace
return self._marketplace

@marketplace.setter
def marketplace(self, value: Optional[AzureVmMarketplaceSchema]) -> None:
self._marketplace = value
if value is None:
self.marketplace_raw = None
else:
self.marketplace_raw = value.to_dict() # type: ignore
# dataclass_json doesn't use a protocol return type, so to_dict() is unknown
self.marketplace_raw = (
None if value is None else value.to_dict() # type: ignore [attr-defined]
)

@property
def shared_gallery(self) -> Optional[SharedImageGallerySchema]:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_shared_gallery"):
self._shared_gallery: Optional[SharedImageGallerySchema] = None
shared_gallery: Optional[SharedImageGallerySchema] = self._shared_gallery
if shared_gallery:
return shared_gallery
if self._shared_gallery is not None:
return self._shared_gallery

if isinstance(self.shared_gallery_raw, dict):
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
self.shared_gallery_raw = dict(
(k, v.lower()) for k, v in self.shared_gallery_raw.items()
)
shared_gallery = schema.load_by_type(
# inconsistent cases cause a mismatch error in notifiers.
# lower() normalizes the image names, it has no impact on deployment
self.shared_gallery_raw = {
k: v.lower() for k, v in self.shared_gallery_raw.items()
}

self._shared_gallery = schema.load_by_type(
SharedImageGallerySchema, self.shared_gallery_raw
)
if not shared_gallery.subscription_id:
shared_gallery.subscription_id = self.subscription_id
# this step makes shared_gallery_raw is validated, and
# filter out any unwanted content.
self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore
if not self._shared_gallery.subscription_id:
self._shared_gallery.subscription_id = self.subscription_id
# Validated shared_gallery_raw and filter out any unwanted content
self.shared_gallery_raw = self._shared_gallery.to_dict() # type: ignore

elif self.shared_gallery_raw:
assert isinstance(
self.shared_gallery_raw, str
), f"actual: {type(self.shared_gallery_raw)}"
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
# inconsistent cases cause a mismatch error in notifiers.
# lower() normalizes the image names, it has no impact on deployment
shared_gallery_strings = re.split(
r"[/]+", self.shared_gallery_raw.strip().lower()
)
if len(shared_gallery_strings) == 5:
shared_gallery = SharedImageGallerySchema(*shared_gallery_strings)
# shared_gallery_raw is used
self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore
self._shared_gallery = SharedImageGallerySchema(*shared_gallery_strings)
elif len(shared_gallery_strings) == 3:
shared_gallery = SharedImageGallerySchema(
self._shared_gallery = SharedImageGallerySchema(
self.subscription_id, None, *shared_gallery_strings
)
# shared_gallery_raw is used
self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore
else:
raise LisaException(
f"Invalid value for the provided shared gallery "
Expand All @@ -326,51 +312,43 @@ def shared_gallery(self) -> Optional[SharedImageGallerySchema]:
f"<image_definition>/<image_version>' or '<image_gallery>/"
f"<image_definition>/<image_version>'"
)
self._shared_gallery = shared_gallery
return shared_gallery
self.shared_gallery_raw = self._shared_gallery.to_dict() # type: ignore

return self._shared_gallery

@shared_gallery.setter
def shared_gallery(self, value: Optional[SharedImageGallerySchema]) -> None:
self._shared_gallery = value
if value is None:
self.shared_gallery_raw = None
else:
self.shared_gallery_raw = value.to_dict() # type: ignore
# dataclass_json doesn't use a protocol return type, so to_dict() is unknown
self.shared_gallery_raw = (
None if value is None else value.to_dict() # type: ignore [attr-defined]
)

@property
def vhd(self) -> Optional[VhdSchema]:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_vhd"):
self._vhd: Optional[VhdSchema] = None
vhd: Optional[VhdSchema] = self._vhd
if vhd:
return vhd
if self._vhd is not None:
return self._vhd

if isinstance(self.vhd_raw, dict):
vhd = schema.load_by_type(VhdSchema, self.vhd_raw)
add_secret(vhd.vhd_path, PATTERN_URL)
if vhd.vmgs_path:
add_secret(vhd.vmgs_path, PATTERN_URL)
# this step makes vhd_raw is validated, and
# filter out any unwanted content.
self.vhd_raw = vhd.to_dict() # type: ignore
self._vhd = schema.load_by_type(VhdSchema, self.vhd_raw)
add_secret(self._vhd.vhd_path, PATTERN_URL)
if self._vhd.vmgs_path:
add_secret(self._vhd.vmgs_path, PATTERN_URL)
# Validated vhd_raw and filter out any unwanted content
self.vhd_raw = self._vhd.to_dict() # type: ignore

elif self.vhd_raw is not None:
assert isinstance(self.vhd_raw, str), f"actual: {type(self.vhd_raw)}"
vhd = VhdSchema(self.vhd_raw)
add_secret(vhd.vhd_path, PATTERN_URL)
self.vhd_raw = vhd.to_dict() # type: ignore
self._vhd = vhd
if vhd:
return vhd
else:
return None
self._vhd = VhdSchema(self.vhd_raw)
add_secret(self._vhd.vhd_path, PATTERN_URL)
self.vhd_raw = self._vhd.to_dict() # type: ignore

return self._vhd

@vhd.setter
def vhd(self, value: Optional[VhdSchema]) -> None:
self._vhd = value
if value is None:
self.vhd_raw = None
else:
self.vhd_raw = self._vhd.to_dict() # type: ignore
self.vhd_raw = None if value is None else self._vhd.to_dict() # type: ignore

def get_image_name(self) -> str:
result = ""
Expand Down Expand Up @@ -416,9 +394,7 @@ def from_node_runbook(cls, runbook: AzureNodeSchema) -> "AzureNodeArmParameter":
parameters["vhd_raw"] = parameters["vhd"]
del parameters["vhd"]

arm_parameters = AzureNodeArmParameter(**parameters)

return arm_parameters
return AzureNodeArmParameter(**parameters)


class DataDiskCreateOption:
Expand Down
1 change: 1 addition & 0 deletions lisa/tools/bzip2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class Bzip2(Tool):
def command(self) -> str:
return "bzip2"

@property
def can_install(self) -> bool:
return True

Expand Down
16 changes: 2 additions & 14 deletions lisa/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,7 @@ def __init__(self, os: "OperatingSystem", message: str = "") -> None:
self.version = os.information.full_version
self.kernel_version = ""
if hasattr(os, "get_kernel_information"):
self.kernel_version = (
os.get_kernel_information().raw_version # type: ignore
)
self.kernel_version = os.get_kernel_information().raw_version
self._extended_message = message

def __str__(self) -> str:
Expand Down Expand Up @@ -504,18 +502,8 @@ def find_group_in_lines(


def deep_update_dict(src: Dict[str, Any], dest: Dict[str, Any]) -> Dict[str, Any]:
if (
dest is None
or isinstance(dest, int)
or isinstance(dest, bool)
or isinstance(dest, float)
or isinstance(dest, str)
):
result = dest
else:
if isinstance(dest, dict):
result = dest.copy()

if isinstance(result, dict):
for key, value in src.items():
if isinstance(value, dict) and key in dest:
value = deep_update_dict(value, dest[key])
Expand Down

0 comments on commit 2cbde7e

Please sign in to comment.