Skip to content

Commit 0233147

Browse files
committed
Treat NewTypes like normal subclasses
NewTypes are assumed not to inherit any members from their base classes. This results in incorrect inference results. Avoid this by changing the transformation for NewTypes to treat them like any other subclass. pylint-dev/pylint#3162 pylint-dev/pylint#2296
1 parent 39c37c1 commit 0233147

File tree

3 files changed

+197
-11
lines changed

3 files changed

+197
-11
lines changed

ChangeLog

+5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ What's New in astroid 2.9.1?
1212
============================
1313
Release date: TBA
1414

15+
* Treat ``typing.NewType()`` values as normal subclasses.
16+
17+
Closes PyCQA/pylint#2296
18+
Closes PyCQA/pylint#3162
19+
1520
* Prefer the module loader get_source() method in AstroidBuilder's
1621
module_build() when possible to avoid assumptions about source
1722
code being available on a filesystem. Otherwise the source cannot

astroid/brain/brain_typing.py

+70-11
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import typing
1616
from functools import partial
1717

18-
from astroid import context, extract_node, inference_tip
18+
from astroid import context, extract_node, inference_tip, nodes
1919
from astroid.const import PY37_PLUS, PY38_PLUS, PY39_PLUS
2020
from astroid.exceptions import (
2121
AttributeInferenceError,
@@ -38,8 +38,6 @@
3838
from astroid.util import Uninferable
3939

4040
TYPING_NAMEDTUPLE_BASENAMES = {"NamedTuple", "typing.NamedTuple"}
41-
TYPING_TYPEVARS = {"TypeVar", "NewType"}
42-
TYPING_TYPEVARS_QUALIFIED = {"typing.TypeVar", "typing.NewType"}
4341
TYPING_TYPE_TEMPLATE = """
4442
class Meta(type):
4543
def __getitem__(self, item):
@@ -52,6 +50,13 @@ def __args__(self):
5250
class {0}(metaclass=Meta):
5351
pass
5452
"""
53+
# PEP484 suggests NewType is equivalent to this for typing purposes
54+
# https://www.python.org/dev/peps/pep-0484/#newtype-helper-function
55+
TYPING_NEWTYPE_TEMPLATE = """
56+
class {derived}({base}):
57+
def __init__(self, val: {base}) -> None:
58+
...
59+
"""
5560
TYPING_MEMBERS = set(getattr(typing, "__all__", []))
5661

5762
TYPING_ALIAS = frozenset(
@@ -106,23 +111,34 @@ def __class_getitem__(cls, item):
106111
"""
107112

108113

109-
def looks_like_typing_typevar_or_newtype(node):
114+
def looks_like_typing_typevar(node: nodes.Call) -> bool:
115+
func = node.func
116+
if isinstance(func, Attribute):
117+
return func.attrname == "TypeVar"
118+
if isinstance(func, Name):
119+
return func.name == "TypeVar"
120+
return False
121+
122+
123+
def looks_like_typing_newtype(node: nodes.Call) -> bool:
110124
func = node.func
111125
if isinstance(func, Attribute):
112-
return func.attrname in TYPING_TYPEVARS
126+
return func.attrname == "NewType"
113127
if isinstance(func, Name):
114-
return func.name in TYPING_TYPEVARS
128+
return func.name == "NewType"
115129
return False
116130

117131

118-
def infer_typing_typevar_or_newtype(node, context_itton=None):
119-
"""Infer a typing.TypeVar(...) or typing.NewType(...) call"""
132+
def infer_typing_typevar(
133+
node: nodes.Call, context_itton: typing.Optional[context.InferenceContext] = None
134+
) -> typing.Iterator[nodes.ClassDef]:
135+
"""Infer a typing.TypeVar(...) call"""
120136
try:
121137
func = next(node.func.infer(context=context_itton))
122138
except (InferenceError, StopIteration) as exc:
123139
raise UseInferenceDefault from exc
124140

125-
if func.qname() not in TYPING_TYPEVARS_QUALIFIED:
141+
if func.qname() != "typing.TypeVar":
126142
raise UseInferenceDefault
127143
if not node.args:
128144
raise UseInferenceDefault
@@ -132,6 +148,44 @@ def infer_typing_typevar_or_newtype(node, context_itton=None):
132148
return node.infer(context=context_itton)
133149

134150

151+
def infer_typing_newtype(
152+
node: nodes.Call, context_itton: typing.Optional[context.InferenceContext] = None
153+
) -> typing.Iterator[nodes.ClassDef]:
154+
"""Infer a typing.NewType(...) call"""
155+
try:
156+
func = next(node.func.infer(context=context_itton))
157+
except (InferenceError, StopIteration) as exc:
158+
raise UseInferenceDefault from exc
159+
160+
if func.qname() != "typing.NewType":
161+
raise UseInferenceDefault
162+
if len(node.args) != 2:
163+
raise UseInferenceDefault
164+
165+
derived, base = node.args
166+
derived_name = derived.as_string().strip("'")
167+
base_name = base.as_string().strip("'")
168+
169+
new_node: ClassDef = extract_node(
170+
TYPING_NEWTYPE_TEMPLATE.format(derived=derived_name, base=base_name)
171+
)
172+
new_node.parent = node.parent
173+
174+
# Base type arg is a normal reference, so no need to do special lookups
175+
if not isinstance(base, nodes.Const):
176+
new_node.bases = [base]
177+
178+
# If the base type is given as a string (e.g. for a forward reference),
179+
# make a naive attempt to find the corresponding node.
180+
# Note that this will not work with imported types.
181+
if isinstance(base, nodes.Const) and isinstance(base.value, str):
182+
_, resolved_base = node.frame().lookup(base_name)
183+
if resolved_base:
184+
new_node.bases = [resolved_base[0]]
185+
186+
return new_node.infer(context=context_itton)
187+
188+
135189
def _looks_like_typing_subscript(node):
136190
"""Try to figure out if a Subscript node *might* be a typing-related subscript"""
137191
if isinstance(node, Name):
@@ -409,8 +463,13 @@ def infer_typing_cast(
409463

410464
AstroidManager().register_transform(
411465
Call,
412-
inference_tip(infer_typing_typevar_or_newtype),
413-
looks_like_typing_typevar_or_newtype,
466+
inference_tip(infer_typing_typevar),
467+
looks_like_typing_typevar,
468+
)
469+
AstroidManager().register_transform(
470+
Call,
471+
inference_tip(infer_typing_newtype),
472+
looks_like_typing_newtype,
414473
)
415474
AstroidManager().register_transform(
416475
Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript

tests/unittest_brain.py

+122
Original file line numberDiff line numberDiff line change
@@ -1659,6 +1659,128 @@ def test_typing_types(self) -> None:
16591659
inferred = next(node.infer())
16601660
self.assertIsInstance(inferred, nodes.ClassDef, node.as_string())
16611661

1662+
def test_typing_newtype_attrs(self) -> None:
1663+
ast_nodes = builder.extract_node(
1664+
"""
1665+
from typing import NewType
1666+
import decimal
1667+
from decimal import Decimal
1668+
1669+
NewType("Foo", str) #@
1670+
NewType("Bar", "int") #@
1671+
NewType("Baz", Decimal) #@
1672+
NewType("Qux", decimal.Decimal) #@
1673+
"""
1674+
)
1675+
assert isinstance(ast_nodes, list)
1676+
1677+
# Base type given by reference
1678+
foo_node = ast_nodes[0]
1679+
foo_inferred = next(foo_node.infer())
1680+
self.assertIsInstance(foo_inferred, astroid.ClassDef)
1681+
1682+
# Check base type method is inferred by accessing one of its methods
1683+
foo_base_class_method = foo_inferred.getattr("endswith")[0]
1684+
self.assertIsInstance(foo_base_class_method, astroid.FunctionDef)
1685+
self.assertEqual("builtins.str.endswith", foo_base_class_method.qname())
1686+
1687+
# Base type given by string (i.e. "int")
1688+
bar_node = ast_nodes[1]
1689+
bar_inferred = next(bar_node.infer())
1690+
self.assertIsInstance(bar_inferred, astroid.ClassDef)
1691+
1692+
bar_base_class_method = bar_inferred.getattr("bit_length")[0]
1693+
self.assertIsInstance(bar_base_class_method, astroid.FunctionDef)
1694+
self.assertEqual("builtins.int.bit_length", bar_base_class_method.qname())
1695+
1696+
# Decimal may be reexported from an implementation-defined module. For
1697+
# example, in CPython 3.10 this is _decimal, but in PyPy 7.3 it's
1698+
# _pydecimal. So the expected qname needs to be grabbed dynamically.
1699+
decimal_quant_node = builder.extract_node(
1700+
"""
1701+
from decimal import Decimal
1702+
Decimal.quantize #@
1703+
"""
1704+
)
1705+
assert isinstance(decimal_quant_node, nodes.NodeNG)
1706+
decimal_quant_qname = next(decimal_quant_node.infer()).qname()
1707+
1708+
# Base type is from an "import from"
1709+
baz_node = ast_nodes[2]
1710+
baz_inferred = next(baz_node.infer())
1711+
self.assertIsInstance(baz_inferred, astroid.ClassDef)
1712+
1713+
baz_base_class_method = baz_inferred.getattr("quantize")[0]
1714+
self.assertIsInstance(baz_base_class_method, astroid.FunctionDef)
1715+
self.assertEqual(decimal_quant_qname, baz_base_class_method.qname())
1716+
1717+
# Base type is from an import
1718+
qux_node = ast_nodes[3]
1719+
qux_inferred = next(qux_node.infer())
1720+
self.assertIsInstance(qux_inferred, astroid.ClassDef)
1721+
1722+
qux_base_class_method = qux_inferred.getattr("quantize")[0]
1723+
self.assertIsInstance(qux_base_class_method, astroid.FunctionDef)
1724+
self.assertEqual(decimal_quant_qname, qux_base_class_method.qname())
1725+
1726+
def test_typing_newtype_user_defined(self) -> None:
1727+
ast_nodes = builder.extract_node(
1728+
"""
1729+
from typing import NewType
1730+
1731+
class A:
1732+
def __init__(self, value: int):
1733+
self.value = value
1734+
1735+
a = A(5)
1736+
a #@
1737+
1738+
B = NewType("B", A)
1739+
b = B(5)
1740+
b #@
1741+
"""
1742+
)
1743+
assert isinstance(ast_nodes, list)
1744+
1745+
for node in ast_nodes:
1746+
self._verify_node_has_expected_attr(node)
1747+
1748+
def test_typing_newtype_forward_reference(self) -> None:
1749+
# Similar to the test above, but using a forward reference for "A"
1750+
ast_nodes = builder.extract_node(
1751+
"""
1752+
from typing import NewType
1753+
1754+
B = NewType("B", "A")
1755+
1756+
class A:
1757+
def __init__(self, value: int):
1758+
self.value = value
1759+
1760+
a = A(5)
1761+
a #@
1762+
1763+
b = B(5)
1764+
b #@
1765+
"""
1766+
)
1767+
assert isinstance(ast_nodes, list)
1768+
1769+
for node in ast_nodes:
1770+
self._verify_node_has_expected_attr(node)
1771+
1772+
def _verify_node_has_expected_attr(self, node: nodes.NodeNG) -> None:
1773+
inferred = next(node.infer())
1774+
self.assertIsInstance(inferred, astroid.Instance)
1775+
1776+
# Should be able to infer that the "value" attr is present on both types
1777+
val = inferred.getattr("value")[0]
1778+
self.assertIsInstance(val, astroid.AssignAttr)
1779+
1780+
# Sanity check: nonexistent attr is not inferred
1781+
with self.assertRaises(AttributeInferenceError):
1782+
inferred.getattr("bad_attr")
1783+
16621784
def test_namedtuple_nested_class(self):
16631785
result = builder.extract_node(
16641786
"""

0 commit comments

Comments
 (0)