Skip to content

Commit

Permalink
Resolve wemake-services#3265: add WPS476 TypeVarTupleFollowsTypeVarWi…
Browse files Browse the repository at this point in the history
…thDefaultViolation and corresponding tests
  • Loading branch information
Tapeline committed Feb 10, 2025
1 parent 413b261 commit e4c5c98
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 1 deletion.
14 changes: 14 additions & 0 deletions tests/fixtures/noqa/noqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,3 +747,17 @@ def my_function():

def pos_only_problem(first_argpm=0, second_argpm=1, /): # noqa: WPS475
my_print(first_argpm, second_argpm)


TypeVarDefault = TypeVar("T", default=int)
FollowingTuple = TypeVarTuple("Ts")


class NewStyleGenerics[TypeVarDefault, *FollowingTuple]: # noqa: WPS476
...


class OldStyleGenerics(
Generic[TypeVarDefault, *FollowingTuple] # noqa: WPS476
):
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from typing import Final

import pytest

from wemake_python_styleguide.violations.best_practices import (
TypeVarTupleFollowsTypeVarWithDefaultViolation,
)
from wemake_python_styleguide.visitors.ast.classes import (
ConsecutiveDefaultTypeVarsVisitor,
)


class_header_formats: Final[list[str]] = [
"Class[{0}]",
"Class(Generic[{0}])"
]
various_code: Final[str] = (
"pi = 3.14\n"
"a = obj.method_call()\n"
"w, h = get_size()\n"
"obj.field = function_call()\n"
"AlmostTypeVar = NotReallyATypeVar()\n"
"NonDefault = TypeVar('NonDefault')\n"
)
classes_with_various_bases: Final[str] = (
"class SimpleBase(object): ...\n"
"class NotANameSubscript(Some.Class[object]): ...\n"
"class NotAGenericBase(NotAGeneric[T]): ...\n"
"class GenericButNotMultiple(Generic[T]): ...\n"
)


@pytest.mark.parametrize(
"class_header_format",
class_header_formats,
)
def test_type_var_tuple_after_type_var_with_default(
assert_errors,
parse_ast_tree,
default_options,
class_header_format,
):
"""Test that WPS476 works correctly."""
class_header = class_header_format.format("T, *Ts")
src = (
various_code +
classes_with_various_bases +
"T = TypeVar('T', default=int)\n"
"Ts = TypeVarTuple('Ts')\n"
"\n"
"class {0}:\n"
" ..."
).format(class_header)

tree = parse_ast_tree(src)

visitor = ConsecutiveDefaultTypeVarsVisitor(default_options, tree=tree)
visitor.run()

assert_errors(visitor, [TypeVarTupleFollowsTypeVarWithDefaultViolation])


@pytest.mark.parametrize(
"class_header_format",
class_header_formats,
)
def test_type_var_tuple_after_type_var_without_default(
assert_errors,
parse_ast_tree,
default_options,
class_header_format,
):
"""Test that WPS476 ignores non-defaulted TypeVars."""
class_header = class_header_format.format("T, *Ts")
src = (
various_code +
classes_with_various_bases +
"T = TypeVar('T')\n"
"Ts = TypeVarTuple('Ts')\n"
"\n"
"class {0}:\n"
" ..."
).format(class_header)

tree = parse_ast_tree(src)

visitor = ConsecutiveDefaultTypeVarsVisitor(default_options, tree=tree)
visitor.run()

assert_errors(visitor, [])
1 change: 1 addition & 0 deletions wemake_python_styleguide/presets/types/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
decorators.WrongDecoratorVisitor,
redundancy.RedundantEnumerateVisitor,
pm.MatchSubjectVisitor,
classes.ConsecutiveDefaultTypeVarsVisitor,
# Modules:
modules.EmptyModuleContentsVisitor,
modules.MagicModuleFunctionsVisitor,
Expand Down
37 changes: 37 additions & 0 deletions wemake_python_styleguide/violations/best_practices.py
Original file line number Diff line number Diff line change
Expand Up @@ -2968,3 +2968,40 @@ def function(first=0, *args): ...

error_template = 'Found problematic function parameters'
code = 475


@final
class TypeVarTupleFollowsTypeVarWithDefaultViolation(ASTViolation):
"""
Forbid using TypeVarTuple after a TypeVar with default.
Reasoning:
Following a defaulted TypeVar with a TypeVarTuple is bad,
because you cannot specify the TypeVarTuple without
specifying TypeVar.
Solution:
Consider refactoring and getting rid of that pattern.
Example::
# Wrong:
T = TypeVar("T", default=int)
class Class[T, *Ts]:
...
# Correct (no default):
class Class[T, *Ts]:
...
# Correct (no tuple):
T = TypeVar("T", default=int)
class Class[T]:
...
.. versionadded:: 1.1.0
"""

error_template = 'Found a TypeVarTuple following a TypeVar with default'
code = 476
107 changes: 106 additions & 1 deletion wemake_python_styleguide/visitors/ast/classes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import ast
from collections import defaultdict
from typing import ClassVar, final
from collections.abc import Sequence
from typing import ClassVar, final, cast

from attrs import frozen

from wemake_python_styleguide import constants, types
from wemake_python_styleguide.compat.aliases import AssignNodes, FunctionNodes
Expand All @@ -15,8 +18,12 @@
getters_setters,
strings,
)
from wemake_python_styleguide.options.validation import ValidatedOptions
from wemake_python_styleguide.violations import best_practices as bp
from wemake_python_styleguide.violations import consistency, oop
from wemake_python_styleguide.violations.best_practices import (
TypeVarTupleFollowsTypeVarWithDefaultViolation,
)
from wemake_python_styleguide.visitors import base, decorators


Expand Down Expand Up @@ -493,3 +500,101 @@ def _check_buggy_super_context(self, node: ast.Call):

if walk.get_closest_parent(node, self._buggy_super_contexts):
self.add_violation(oop.BuggySuperContextViolation(node))


@frozen
class TypeVarInfo:
name: str
has_default: bool


@final
@decorators.alias(
'visit_any_assign',
(
'visit_Assign',
'visit_AnnAssign',
),
)
class ConsecutiveDefaultTypeVarsVisitor(base.BaseNodeVisitor):
"""Responsible for finding TypeVarTuple after a TypeVar with default."""

def __init__(
self, options: ValidatedOptions, tree: ast.AST, **kwargs
) -> None:
super().__init__(options, tree, **kwargs)
self._defaulted_typevars: set[str] = set()

def visit_any_assign(self, node: types.AnyAssign) -> None:
typevar = self._assume_typevar_creation(node)
if not typevar or not typevar.has_default:
return
self._defaulted_typevars.add(typevar.name)

def _assume_typevar_creation(
self, node: types.AnyAssign
) -> TypeVarInfo | None:
if not isinstance(node.value, ast.Call):
return None
if not isinstance(node.value.func, ast.Name):
return None
if len(node.targets) != 1: # pragma: no cover
return None
if not isinstance(node.targets[0], ast.Name):
return None
if node.value.func.id != "TypeVar":
return None
return TypeVarInfo(
name=cast(ast.Name, node.targets[0]).id,
has_default=any(
cast(ast.keyword, kw).arg == "default"
for kw in node.value.keywords
)
)

def visit_ClassDef(self, node: ast.ClassDef) -> None:
if hasattr(node, "type_params"): # pragma: no cover
self._check_new_style_generics(node.type_params)
self._check_old_style_generics(node.bases)

def _check_new_style_generics(
self, type_params: Sequence[ast.type_param]
) -> None:
had_default = False
for type_param in type_params:
had_default = had_default or (
isinstance(type_param, ast.TypeVar)
and type_param.name in self._defaulted_typevars
)
if isinstance(type_param, ast.TypeVarTuple) and had_default:
self.add_violation(
TypeVarTupleFollowsTypeVarWithDefaultViolation(type_param)
)

def _check_old_style_generics(self, bases: Sequence[ast.expr]) -> None:
for cls_base in bases:
if self._is_generic_tuple_base(cls_base):
self._check_generic_tuple(
cls_base.slice.elts
)

def _is_generic_tuple_base(self, cls_base: ast.expr) -> bool:
if not isinstance(cls_base, ast.Subscript):
return False
if not isinstance(cls_base.value, ast.Name):
return False
if cls_base.value.id != "Generic":
return False
return isinstance(cls_base.slice, ast.Tuple)

def _check_generic_tuple(self, elts: Sequence[ast.expr]) -> None:
had_default = False
for expr in elts:
had_default = had_default or (
isinstance(expr, ast.Name)
and expr.id in self._defaulted_typevars
)
if isinstance(expr, ast.Starred) and had_default:
self.add_violation(
TypeVarTupleFollowsTypeVarWithDefaultViolation(expr)
)

0 comments on commit e4c5c98

Please sign in to comment.