@@ -2970,7 +2970,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
2970
2970
not local_errors .has_new_errors ()
2971
2971
and cont_type
2972
2972
and self .dangerous_comparison (
2973
- left_type , cont_type , original_container = right_type
2973
+ left_type , cont_type , original_container = right_type , prefer_literal = False
2974
2974
)
2975
2975
):
2976
2976
self .msg .dangerous_comparison (left_type , cont_type , "container" , e )
@@ -2988,21 +2988,19 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
2988
2988
# testCustomEqCheckStrictEquality for an example.
2989
2989
if not w .has_new_errors () and operator in ("==" , "!=" ):
2990
2990
right_type = self .accept (right )
2991
- # Also flag non-overlapping literals in situations like:
2992
- # x: Literal['a', 'b']
2993
- # if x == 'c':
2994
- # ...
2995
- left_type = try_getting_literal (left_type )
2996
- right_type = try_getting_literal (right_type )
2997
2991
if self .dangerous_comparison (left_type , right_type ):
2992
+ # Show the most specific literal types possible
2993
+ left_type = try_getting_literal (left_type )
2994
+ right_type = try_getting_literal (right_type )
2998
2995
self .msg .dangerous_comparison (left_type , right_type , "equality" , e )
2999
2996
3000
2997
elif operator == "is" or operator == "is not" :
3001
2998
right_type = self .accept (right ) # validate the right operand
3002
2999
sub_result = self .bool_type ()
3003
- left_type = try_getting_literal (left_type )
3004
- right_type = try_getting_literal (right_type )
3005
3000
if self .dangerous_comparison (left_type , right_type ):
3001
+ # Show the most specific literal types possible
3002
+ left_type = try_getting_literal (left_type )
3003
+ right_type = try_getting_literal (right_type )
3006
3004
self .msg .dangerous_comparison (left_type , right_type , "identity" , e )
3007
3005
method_type = None
3008
3006
else :
@@ -3036,7 +3034,12 @@ def find_partial_type_ref_fast_path(self, expr: Expression) -> Type | None:
3036
3034
return None
3037
3035
3038
3036
def dangerous_comparison (
3039
- self , left : Type , right : Type , original_container : Type | None = None
3037
+ self ,
3038
+ left : Type ,
3039
+ right : Type ,
3040
+ original_container : Type | None = None ,
3041
+ * ,
3042
+ prefer_literal : bool = True ,
3040
3043
) -> bool :
3041
3044
"""Check for dangerous non-overlapping comparisons like 42 == 'no'.
3042
3045
@@ -3064,6 +3067,14 @@ def dangerous_comparison(
3064
3067
if custom_special_method (left , "__eq__" ) or custom_special_method (right , "__eq__" ):
3065
3068
return False
3066
3069
3070
+ if prefer_literal :
3071
+ # Also flag non-overlapping literals in situations like:
3072
+ # x: Literal['a', 'b']
3073
+ # if x == 'c':
3074
+ # ...
3075
+ left = try_getting_literal (left )
3076
+ right = try_getting_literal (right )
3077
+
3067
3078
if self .chk .binder .is_unreachable_warning_suppressed ():
3068
3079
# We are inside a function that contains type variables with value restrictions in
3069
3080
# its signature. In this case we just suppress all strict-equality checks to avoid
0 commit comments