@@ -6274,10 +6274,6 @@ def has_no_custom_eq_checks(t: Type) -> bool:
6274
6274
coerce_only_in_literal_context ,
6275
6275
)
6276
6276
6277
- # Strictly speaking, we should also skip this check if the objects in the expr
6278
- # chain have custom __eq__ or __ne__ methods. But we (maybe optimistically)
6279
- # assume nobody would actually create a custom objects that considers itself
6280
- # equal to None.
6281
6277
if if_map == {} and else_map == {}:
6282
6278
if_map , else_map = self .refine_away_none_in_comparison (
6283
6279
operands , operand_types , expr_indices , narrowable_operand_index_to_hash .keys ()
@@ -6602,25 +6598,36 @@ def refine_away_none_in_comparison(
6602
6598
For more details about what the different arguments mean, see the
6603
6599
docstring of 'refine_identity_comparison_expression' up above.
6604
6600
"""
6601
+
6605
6602
non_optional_types = []
6606
6603
for i in chain_indices :
6607
6604
typ = operand_types [i ]
6608
6605
if not is_overlapping_none (typ ):
6609
6606
non_optional_types .append (typ )
6610
6607
6611
- # Make sure we have a mixture of optional and non-optional types.
6612
- if len (non_optional_types ) == 0 or len (non_optional_types ) == len (chain_indices ):
6613
- return {}, {}
6608
+ if_map , else_map = {}, {}
6614
6609
6615
- if_map = {}
6616
- for i in narrowable_operand_indices :
6617
- expr_type = operand_types [i ]
6618
- if not is_overlapping_none (expr_type ):
6619
- continue
6620
- if any (is_overlapping_erased_types (expr_type , t ) for t in non_optional_types ):
6621
- if_map [operands [i ]] = remove_optional (expr_type )
6610
+ if not non_optional_types or (len (non_optional_types ) != len (chain_indices )):
6622
6611
6623
- return if_map , {}
6612
+ # Narrow e.g. `Optional[A] == "x"` or `Optional[A] is "x"` to `A` (which may be
6613
+ # convenient but is strictly not type-safe):
6614
+ for i in narrowable_operand_indices :
6615
+ expr_type = operand_types [i ]
6616
+ if not is_overlapping_none (expr_type ):
6617
+ continue
6618
+ if any (is_overlapping_erased_types (expr_type , t ) for t in non_optional_types ):
6619
+ if_map [operands [i ]] = remove_optional (expr_type )
6620
+
6621
+ # Narrow e.g. `Optional[A] != None` to `A` (which is stricter than the above step and
6622
+ # so type-safe but less convenient, because e.g. `Optional[A] == None` still results
6623
+ # in `Optional[A]`):
6624
+ if any (isinstance (get_proper_type (ot ), NoneType ) for ot in operand_types ):
6625
+ for i in narrowable_operand_indices :
6626
+ expr_type = operand_types [i ]
6627
+ if is_overlapping_none (expr_type ):
6628
+ else_map [operands [i ]] = remove_optional (expr_type )
6629
+
6630
+ return if_map , else_map
6624
6631
6625
6632
def is_len_of_tuple (self , expr : Expression ) -> bool :
6626
6633
"""Is this expression a `len(x)` call where x is a tuple or union of tuples?"""
0 commit comments