Skip to content

Commit 0665ce9

Browse files
authored
Fix strict equality with enum type with custom __eq__ (#14518)
Fixes regression introduced in #14513.
1 parent 757e0d4 commit 0665ce9

File tree

2 files changed

+47
-10
lines changed

2 files changed

+47
-10
lines changed

mypy/checkexpr.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -2970,7 +2970,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
29702970
not local_errors.has_new_errors()
29712971
and cont_type
29722972
and self.dangerous_comparison(
2973-
left_type, cont_type, original_container=right_type
2973+
left_type, cont_type, original_container=right_type, prefer_literal=False
29742974
)
29752975
):
29762976
self.msg.dangerous_comparison(left_type, cont_type, "container", e)
@@ -2988,21 +2988,19 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
29882988
# testCustomEqCheckStrictEquality for an example.
29892989
if not w.has_new_errors() and operator in ("==", "!="):
29902990
right_type = self.accept(right)
2991-
# Also flag non-overlapping literals in situations like:
2992-
# x: Literal['a', 'b']
2993-
# if x == 'c':
2994-
# ...
2995-
left_type = try_getting_literal(left_type)
2996-
right_type = try_getting_literal(right_type)
29972991
if self.dangerous_comparison(left_type, right_type):
2992+
# Show the most specific literal types possible
2993+
left_type = try_getting_literal(left_type)
2994+
right_type = try_getting_literal(right_type)
29982995
self.msg.dangerous_comparison(left_type, right_type, "equality", e)
29992996

30002997
elif operator == "is" or operator == "is not":
30012998
right_type = self.accept(right) # validate the right operand
30022999
sub_result = self.bool_type()
3003-
left_type = try_getting_literal(left_type)
3004-
right_type = try_getting_literal(right_type)
30053000
if self.dangerous_comparison(left_type, right_type):
3001+
# Show the most specific literal types possible
3002+
left_type = try_getting_literal(left_type)
3003+
right_type = try_getting_literal(right_type)
30063004
self.msg.dangerous_comparison(left_type, right_type, "identity", e)
30073005
method_type = None
30083006
else:
@@ -3036,7 +3034,12 @@ def find_partial_type_ref_fast_path(self, expr: Expression) -> Type | None:
30363034
return None
30373035

30383036
def dangerous_comparison(
3039-
self, left: Type, right: Type, original_container: Type | None = None
3037+
self,
3038+
left: Type,
3039+
right: Type,
3040+
original_container: Type | None = None,
3041+
*,
3042+
prefer_literal: bool = True,
30403043
) -> bool:
30413044
"""Check for dangerous non-overlapping comparisons like 42 == 'no'.
30423045
@@ -3064,6 +3067,14 @@ def dangerous_comparison(
30643067
if custom_special_method(left, "__eq__") or custom_special_method(right, "__eq__"):
30653068
return False
30663069

3070+
if prefer_literal:
3071+
# Also flag non-overlapping literals in situations like:
3072+
# x: Literal['a', 'b']
3073+
# if x == 'c':
3074+
# ...
3075+
left = try_getting_literal(left)
3076+
right = try_getting_literal(right)
3077+
30673078
if self.chk.binder.is_unreachable_warning_suppressed():
30683079
# We are inside a function that contains type variables with value restrictions in
30693080
# its signature. In this case we just suppress all strict-equality checks to avoid

test-data/unit/check-expressions.test

+26
Original file line numberDiff line numberDiff line change
@@ -2221,6 +2221,32 @@ int == y
22212221
y == int
22222222
[builtins fixtures/bool.pyi]
22232223

2224+
[case testStrictEqualityAndEnumWithCustomEq]
2225+
# flags: --strict-equality
2226+
from enum import Enum
2227+
2228+
class E1(Enum):
2229+
X = 0
2230+
Y = 1
2231+
2232+
class E2(Enum):
2233+
X = 0
2234+
Y = 1
2235+
2236+
def __eq__(self, other: object) -> bool:
2237+
return bool()
2238+
2239+
E1.X == E1.Y # E: Non-overlapping equality check (left operand type: "Literal[E1.X]", right operand type: "Literal[E1.Y]")
2240+
E2.X == E2.Y
2241+
[builtins fixtures/bool.pyi]
2242+
2243+
[case testStrictEqualityWithBytesContains]
2244+
# flags: --strict-equality
2245+
data = b"xy"
2246+
b"x" in data
2247+
[builtins fixtures/primitives.pyi]
2248+
[typing fixtures/typing-full.pyi]
2249+
22242250
[case testUnimportedHintAny]
22252251
def f(x: Any) -> None: # E: Name "Any" is not defined \
22262252
# N: Did you forget to import it from "typing"? (Suggestion: "from typing import Any")

0 commit comments

Comments
 (0)