Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Narrow individual items when matching a tuple to a sequence pattern #16905

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 17 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5086,6 +5086,9 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
)
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
self.push_type_map(pattern_map)
if pattern_map:
for expr, typ in pattern_map.items():
self.push_type_map(self._get_recursive_sub_patterns_map(expr, typ))
self.push_type_map(pattern_type.captures)
if g is not None:
with self.binder.frame_context(can_skip=False, fall_through=3):
Expand Down Expand Up @@ -5123,6 +5126,20 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
with self.binder.frame_context(can_skip=False, fall_through=2):
pass

def _get_recursive_sub_patterns_map(
self, expr: Expression, typ: Type
) -> dict[Expression, Type]:
sub_patterns_map: dict[Expression, Type] = {}
typ_ = get_proper_type(typ)
if isinstance(expr, TupleExpr) and isinstance(typ_, TupleType):
# When matching a tuple expression with a sequence pattern, narrow individual tuple items
assert len(expr.items) == len(typ_.items)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some test cases where the length doesn't match, to make sure we don't fail this assertion? Especially in nested tuples.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! I did not check anything in non-matching cases, just that the assertion didn't fail, eg.

match m, n, o:
    case [3, "foo"]:
        pass
    case [3, "foo", True, True]:
        pass

Interestingly, these paths are not considered unreachable by mypy. They probably should, but that's unrelated to this issue!

for item_expr, item_typ in zip(expr.items, typ_.items):
sub_patterns_map[item_expr] = item_typ
sub_patterns_map.update(self._get_recursive_sub_patterns_map(item_expr, item_typ))

return sub_patterns_map

def infer_variable_types_from_type_maps(self, type_maps: list[TypeMap]) -> dict[Var, Type]:
all_captures: dict[Var, list[tuple[NameExpr, Type]]] = defaultdict(list)
for tm in type_maps:
Expand Down
66 changes: 66 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,72 @@ match m:
reveal_type(m) # N: Revealed type is "builtins.list[builtins.list[builtins.str]]"
[builtins fixtures/list.pyi]

[case testMatchSequencePatternNarrowSubjectItems]
m: int
n: str
o: bool

match m, n, o:
case [3, "foo", True]:
reveal_type(m) # N: Revealed type is "Literal[3]"
reveal_type(n) # N: Revealed type is "Literal['foo']"
reveal_type(o) # N: Revealed type is "Literal[True]"
case [a, b, c]:
reveal_type(m) # N: Revealed type is "builtins.int"
reveal_type(n) # N: Revealed type is "builtins.str"
reveal_type(o) # N: Revealed type is "builtins.bool"

reveal_type(m) # N: Revealed type is "builtins.int"
reveal_type(n) # N: Revealed type is "builtins.str"
reveal_type(o) # N: Revealed type is "builtins.bool"
[builtins fixtures/tuple.pyi]

[case testMatchSequencePatternNarrowSubjectItemsRecursive]
m: int
n: int
o: int
p: int
q: int
r: int

match m, (n, o), (p, (q, r)):
case [0, [1, 2], [3, [4, 5]]]:
reveal_type(m) # N: Revealed type is "Literal[0]"
reveal_type(n) # N: Revealed type is "Literal[1]"
reveal_type(o) # N: Revealed type is "Literal[2]"
reveal_type(p) # N: Revealed type is "Literal[3]"
reveal_type(q) # N: Revealed type is "Literal[4]"
reveal_type(r) # N: Revealed type is "Literal[5]"
[builtins fixtures/tuple.pyi]

[case testMatchSequencePatternSequencesLengthMismatchNoNarrowing]
m: int
n: str
o: bool

match m, n, o:
case [3, "foo"]:
pass
case [3, "foo", True, True]:
pass
[builtins fixtures/tuple.pyi]

[case testMatchSequencePatternSequencesLengthMismatchNoNarrowingRecursive]
m: int
n: int
o: int

match m, (n, o):
case [0]:
pass
case [0, 1, [2]]:
pass
case [0, [1]]:
pass
case [0, [1, 2, 3]]:
pass
[builtins fixtures/tuple.pyi]

-- Mapping Pattern --

[case testMatchMappingPatternCaptures]
Expand Down