Skip to content

Commit

Permalink
Support semi string type hint. (#36)
Browse files Browse the repository at this point in the history
* Support semi string type hints.

* Support semi string type hint.

* Fix: Python3.9 ast.Subscript change.
  • Loading branch information
hadialqattan authored Oct 13, 2020
1 parent 5d1d062 commit fe0dbeb
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 20 deletions.
1 change: 1 addition & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm

### Added

- [Support semi string type hint by @hadialqattan](https://github.com/hadialqattan/pycln/pull/35)
- [Support casting case by @hadialqattan](https://github.com/hadialqattan/pycln/pull/34)

## [0.0.1-alpha.3] - 2020-10-07
Expand Down
5 changes: 1 addition & 4 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -763,11 +763,8 @@ All bellow imports are considered as used:

- Semi string:

> Not supported yet, on the roadmap:
> [# TODO](https://github.com/hadialqattan/pycln/projects/1#card-46611579).
```python
from ast import Import # With the current version Pycln will remove this.
from ast import Import
from typing import List

def foo(bar: List["Import"]):
Expand Down
99 changes: 87 additions & 12 deletions pycln/utils/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

# Constants.
PY38_PLUS = sys.version_info >= (3, 8)
PY39_PLUS = sys.version_info >= (3, 9)
IMPORT_EXCEPTIONS = {"ImportError", "ImportWarning", "ModuleNotFoundError"}
__ALL__ = "__all__"
NAMES_TO_SKIP = frozenset(
Expand All @@ -27,6 +28,63 @@
__ALL__,
}
)
SUBSCRIPT_TYPE_VARIABLE = frozenset(
{
"AbstractSet",
"AsyncContextManager",
"AsyncGenerator",
"AsyncIterable",
"AsyncIterator",
"Awaitable",
"ByteString",
"Callable",
"ChainMap",
"ClassVar",
"Collection",
"Container",
"ContextManager",
"Coroutine",
"Counter",
"DefaultDict",
"Deque",
"Dict",
"FrozenSet",
"Generator",
"IO",
"ItemsView",
"Iterable",
"Iterator",
"KeysView",
"List",
"Mapping",
"MappingView",
"Match",
"MutableMapping",
"MutableSequence",
"MutableSet",
"Optional",
"Pattern",
"Reversible",
"Sequence",
"Set",
"SupportsRound",
"Tuple",
"Type",
"Union",
"ValuesView",
# Python >=3.7:
"Literal",
# Python >=3.8:
"OrderedDict",
# Python >=3.9:
"tuple",
"list",
"dict",
"set",
"frozenset",
"type",
}
)

# Custom types.
FunctionT = TypeVar("FunctionT", bound=Callable[..., Any])
Expand Down Expand Up @@ -137,10 +195,26 @@ def visit_Call(self, node: ast.Call):
getattr(func, "attr", "") == "cast"
and getattr(func.value, "id", "") == "typing" # type: ignore
):
type_ = node.args[0]
value = getattr(type_, "value", "") or getattr(type_, "s", "")
if value:
self._source_stats.name_.add(value)
self._parse_string(node.args[0]) # type: ignore

@recursive
def visit_Subscript(self, node: ast.Subscript) -> None:
#: Support semi string type hints.
#: >>> from ast import Import
#: >>> from typing import List
#: >>> def foo(bar: List["Import"]):
#: >>> pass
#: Issue: https://github.com/hadialqattan/pycln/issues/32
value = getattr(node, "value", "")
if getattr(value, "id", "") in SUBSCRIPT_TYPE_VARIABLE or (
hasattr(value, "value") and getattr(value.value, "id", "") == "typing"
):
if PY39_PLUS:
s_val = node.slice # type: ignore
else:
s_val = node.slice.value # type: ignore
for elt in getattr(s_val, "elts", ()) or (s_val,):
self._parse_string(elt) # type: ignore

@recursive
def visit_Try(self, node: ast.Try):
Expand Down Expand Up @@ -250,14 +324,7 @@ def _visit_string_type_annotation(
annotation = node.annotation
else:
annotation = node.returns
if isinstance(annotation, (ast.Constant, ast.Str)):
if hasattr(annotation, "value"):
value = annotation.value
else:
value = annotation.s
if value:
tree = parse_ast(value, mode="eval")
self._add_name_attr(tree)
self._parse_string(annotation) # type: ignore

def _visit_type_comment(
self, node: Union[ast.Assign, ast.arg, FunctionDefT]
Expand All @@ -278,6 +345,14 @@ def _visit_type_comment(
tree = parse_ast(type_comment, mode=mode)
self._add_name_attr(tree)

def _parse_string(self, node: Union[ast.Constant, ast.Str]) -> None:
# Parse string names/attrs.
if isinstance(node, (ast.Constant, ast.Str)):
val = getattr(node, "value", "") or getattr(node, "s", "")
if val and isinstance(val, str):
tree = parse_ast(val, mode="eval")
self._add_name_attr(tree)

def _add_name_attr(self, tree: ast.AST):
# Add any `ast.Name` or `ast.Attribute`
# child to `self._source_stats`.
Expand Down
70 changes: 66 additions & 4 deletions tests/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def test_visit_Attribute(self, code, expec_attrs):
self.assert_set_equal_or_not(source_stats.attr_, expec_attrs)

@pytest.mark.parametrize(
"code, expec_names",
"code, expec_names, expec_attrs",
[
pytest.param(
(
Expand All @@ -202,7 +202,8 @@ def test_visit_Attribute(self, code, expec_attrs):
"baz = cast('foo', bar)\n"
),
{"cast", "foo", "bar", "baz"},
id="cast",
set(),
id="cast name",
),
pytest.param(
(
Expand All @@ -211,14 +212,75 @@ def test_visit_Attribute(self, code, expec_attrs):
"baz = typing.cast('foo', bar)\n"
),
{"typing", "foo", "bar", "baz"},
id="typing.cast",
{"cast"},
id="typing.cast name",
),
pytest.param(
(
"from typing import cast\n"
"import foo.x, bar\n"
"baz = cast('foo.x', bar)\n"
),
{"cast", "foo", "bar", "baz"},
{"x"},
id="cast attr-0",
),
pytest.param(
("import typing\n" "import foo, bar\n" "baz = cast('foo.x', bar)\n"),
{"cast", "foo", "bar", "baz"},
{"x"},
id="cast attr-1",
),
],
)
def test_visit_Call(self, code, expec_names):
def test_visit_Call(self, code, expec_names, expec_attrs):
analyzer = self._get_analyzer(code)
source_stats, _ = analyzer.get_stats()
self.assert_set_equal_or_not(source_stats.name_, expec_names)
self.assert_set_equal_or_not(source_stats.attr_, expec_attrs)

@pytest.mark.parametrize(
"code, expec_names, expec_attrs",
[
pytest.param(
("from typing import List\n" "import foo\n" "baz = List['foo']\n"),
{"List", "foo", "baz"},
set(),
id="str-name",
),
pytest.param(
("from typing import List\n" "import foo.x\n" "baz = List['foo.x']\n"),
{"List", "foo", "baz"},
{"x"},
id="str-attr",
),
pytest.param(
(
"from typing import Union\n"
"import foo, bar\n"
"baz = Union['foo', 'bar']\n"
),
{"Union", "foo", "bar", "baz"},
set(),
id="tuple-name",
),
pytest.param(
(
"from typing import Union\n"
"import foo.x, bar.y\n"
"baz = Union['foo.x', 'bar.y']\n"
),
{"Union", "foo", "bar", "baz"},
{"x", "y"},
id="tuple-attr",
),
],
)
def test_visit_Subscript(self, code, expec_names, expec_attrs):
analyzer = self._get_analyzer(code)
source_stats, _ = analyzer.get_stats()
self.assert_set_equal_or_not(source_stats.name_, expec_names)
self.assert_set_equal_or_not(source_stats.attr_, expec_attrs)

@pytest.mark.parametrize(
"code, expec_name",
Expand Down

0 comments on commit fe0dbeb

Please sign in to comment.