Skip to content

Commit e2a47e2

Browse files
authored
Support ==-based narrowing of Optional (#18163)
Closes #18135 This change implements the third approach mentioned in #18135, which is stricter than similar narrowings, as clarified by the new/modified code comments. Personally, I prefer this more stringent way but could also switch this PR to approach two if there is consent that convenience is more important than type safety here.
1 parent 87998c8 commit e2a47e2

File tree

2 files changed

+24
-17
lines changed

2 files changed

+24
-17
lines changed

mypy/checker.py

+22-15
Original file line numberDiff line numberDiff line change
@@ -6274,10 +6274,6 @@ def has_no_custom_eq_checks(t: Type) -> bool:
62746274
coerce_only_in_literal_context,
62756275
)
62766276

6277-
# Strictly speaking, we should also skip this check if the objects in the expr
6278-
# chain have custom __eq__ or __ne__ methods. But we (maybe optimistically)
6279-
# assume nobody would actually create a custom objects that considers itself
6280-
# equal to None.
62816277
if if_map == {} and else_map == {}:
62826278
if_map, else_map = self.refine_away_none_in_comparison(
62836279
operands, operand_types, expr_indices, narrowable_operand_index_to_hash.keys()
@@ -6602,25 +6598,36 @@ def refine_away_none_in_comparison(
66026598
For more details about what the different arguments mean, see the
66036599
docstring of 'refine_identity_comparison_expression' up above.
66046600
"""
6601+
66056602
non_optional_types = []
66066603
for i in chain_indices:
66076604
typ = operand_types[i]
66086605
if not is_overlapping_none(typ):
66096606
non_optional_types.append(typ)
66106607

6611-
# Make sure we have a mixture of optional and non-optional types.
6612-
if len(non_optional_types) == 0 or len(non_optional_types) == len(chain_indices):
6613-
return {}, {}
6608+
if_map, else_map = {}, {}
66146609

6615-
if_map = {}
6616-
for i in narrowable_operand_indices:
6617-
expr_type = operand_types[i]
6618-
if not is_overlapping_none(expr_type):
6619-
continue
6620-
if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types):
6621-
if_map[operands[i]] = remove_optional(expr_type)
6610+
if not non_optional_types or (len(non_optional_types) != len(chain_indices)):
66226611

6623-
return if_map, {}
6612+
# Narrow e.g. `Optional[A] == "x"` or `Optional[A] is "x"` to `A` (which may be
6613+
# convenient but is strictly not type-safe):
6614+
for i in narrowable_operand_indices:
6615+
expr_type = operand_types[i]
6616+
if not is_overlapping_none(expr_type):
6617+
continue
6618+
if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types):
6619+
if_map[operands[i]] = remove_optional(expr_type)
6620+
6621+
# Narrow e.g. `Optional[A] != None` to `A` (which is stricter than the above step and
6622+
# so type-safe but less convenient, because e.g. `Optional[A] == None` still results
6623+
# in `Optional[A]`):
6624+
if any(isinstance(get_proper_type(ot), NoneType) for ot in operand_types):
6625+
for i in narrowable_operand_indices:
6626+
expr_type = operand_types[i]
6627+
if is_overlapping_none(expr_type):
6628+
else_map[operands[i]] = remove_optional(expr_type)
6629+
6630+
return if_map, else_map
66246631

66256632
def is_len_of_tuple(self, expr: Expression) -> bool:
66266633
"""Is this expression a `len(x)` call where x is a tuple or union of tuples?"""

test-data/unit/check-narrowing.test

+2-2
Original file line numberDiff line numberDiff line change
@@ -1385,9 +1385,9 @@ val: Optional[A]
13851385
if val == None:
13861386
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
13871387
else:
1388-
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
1388+
reveal_type(val) # N: Revealed type is "__main__.A"
13891389
if val != None:
1390-
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
1390+
reveal_type(val) # N: Revealed type is "__main__.A"
13911391
else:
13921392
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
13931393

0 commit comments

Comments
 (0)