Skip to content

Commit

Permalink
Improve set matching
Browse files Browse the repository at this point in the history
Resolves   #714.
  • Loading branch information
evhub committed Jan 8, 2023
1 parent 269de48 commit d6e378d
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 23 deletions.
11 changes: 9 additions & 2 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,10 @@ base_pattern ::= (
| "class" NAME "(" patterns ")" # classes
| "{" pattern_pairs # dictionaries
["," "**" (NAME | "{}")] "}" # (keys must be constants or equality checks)
| ["s"] "{" pattern_consts "}" # sets
| ["s" | "f" | "m"] "{"
pattern_consts
["," ("*_" | "*()")]
"}" # sets
| (EXPR) -> pattern # view patterns
| "(" patterns ")" # sequences can be in tuple form
| "[" patterns "]" # or in list form
Expand Down Expand Up @@ -1088,7 +1091,6 @@ base_pattern ::= (
- Constants, Numbers, and Strings: will only match to the same constant, number, or string in the same position in the arguments.
- Equality Checks (`==<expr>`): will check that whatever is in that position is `==` to the expression `<expr>`.
- Identity Checks (`is <expr>`): will check that whatever is in that position `is` the expression `<expr>`.
- Sets (`{<constants>}`): will only match a set (`collections.abc.Set`) of the same length and contents.
- Arbitrary Function Patterns:
- Infix Checks (`` <pattern> `<op>` <expr> ``): will check that the operator `<op>$(?, <expr>)` returns a truthy value when called on whatever is in that position, then matches `<pattern>`. For example, `` x `isinstance` int `` will check that whatever is in that position `isinstance$(?, int)` and bind it to `x`. If `<expr>` is not given, will simply check `<op>` directly rather than `<op>$(<expr>)`. Additionally, `` `<op>` `` can instead be a [custom operator](#custom-operators) (in that case, no backticks should be used).
- View Patterns (`(<expression>) -> <pattern>`): calls `<expression>` on the item being matched and matches the result to `<pattern>`. The match fails if a [`MatchError`](#matcherror) is raised. `<expression>` may be unparenthesized only when it is a single atom.
Expand All @@ -1099,6 +1101,11 @@ base_pattern ::= (
- Mapping Destructuring:
- Dicts (`{<key>: <value>, ...}`): will match any mapping (`collections.abc.Mapping`) with the given keys and values that match the value patterns. Keys must be constants or equality checks.
- Dicts With Rest (`{<pairs>, **<rest>}`): will match a mapping (`collections.abc.Mapping`) containing all the `<pairs>`, and will put a `dict` of everything else into `<rest>`. If `<rest>` is `{}`, will enforce that the mapping is exactly the same length as `<pairs>`.
- Set Destructuring:
- Sets (`s{<constants>, *_}`): will match a set (`collections.abc.Set`) that contains the given `<constants>`, though it may also contain other items. The `s` prefix and the `*_` are optional.
- Fixed-length Sets (`s{<constants>, *()}`): will match a `set` (`collections.abc.Set`) that contains the given `<constants>`, and nothing else.
- Frozensets (`f{<constants>}`): will match a `frozenset` (`frozenset`) that contains the given `<constants>`. May use either normal or fixed-length syntax.
- Multisets (`m{<constants>}`): will match a [`multiset`](#multiset) (`collections.Counter`) that contains at least the given `<constants>`. May use either normal or fixed-length syntax.
- Sequence Destructuring:
- Lists (`[<patterns>]`), Tuples (`(<patterns>)`): will only match a sequence (`collections.abc.Sequence`) of the same length, and will check the contents against `<patterns>` (Coconut automatically registers `numpy` arrays and `collections.deque` objects as sequences).
- Lazy lists (`(|<patterns>|)`): same as list or tuple matching, but checks for an Iterable (`collections.abc.Iterable`) instead of a Sequence.
Expand Down
22 changes: 19 additions & 3 deletions coconut/compiler/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
untcoable_funcs,
early_passthrough_wrapper,
new_operators,
wildcard,
)
from coconut.compiler.util import (
combine,
Expand Down Expand Up @@ -110,6 +111,7 @@
any_len_perm,
boundary,
compile_regex,
always_match,
)


Expand Down Expand Up @@ -1782,10 +1784,16 @@ class Grammar(object):
| Optional(neg_minus) + number
| match_dotted_name_const,
)
empty_const = fixto(
lparen + rparen
| lbrack + rbrack
| set_letter + lbrace + rbrace,
"()",
)

matchlist_set = Group(Optional(tokenlist(match_const, comma)))
match_pair = Group(match_const + colon.suppress() + match)
matchlist_dict = Group(Optional(tokenlist(match_pair, comma)))
set_star = star.suppress() + (keyword(wildcard) | empty_const)

matchlist_tuple_items = (
match + OneOrMore(comma.suppress() + match) + Optional(comma.suppress())
Expand Down Expand Up @@ -1834,13 +1842,21 @@ class Grammar(object):
| match_const("const")
| (keyword_atom | keyword("is").suppress() + negable_atom_item)("is")
| (keyword("in").suppress() + negable_atom_item)("in")
| (lbrace.suppress() + matchlist_dict + Optional(dubstar.suppress() + (setname | condense(lbrace + rbrace))) + rbrace.suppress())("dict")
| (Optional(set_s.suppress()) + lbrace.suppress() + matchlist_set + rbrace.suppress())("set")
| iter_match
| match_lazy("lazy")
| sequence_match
| star_match
| (lparen.suppress() + match + rparen.suppress())("paren")
| (lbrace.suppress() + matchlist_dict + Optional(dubstar.suppress() + (setname | condense(lbrace + rbrace)) + Optional(comma.suppress())) + rbrace.suppress())("dict")
| (
Group(Optional(set_letter))
+ lbrace.suppress()
+ (
Group(tokenlist(match_const, comma, allow_trailing=False)) + Optional(comma.suppress() + set_star + Optional(comma.suppress()))
| Group(always_match) + set_star + Optional(comma.suppress())
| Group(Optional(tokenlist(match_const, comma)))
) + rbrace.suppress()
)("set")
| (data_kwd.suppress() + dotted_refname + lparen.suppress() + matchlist_data + rparen.suppress())("data")
| (keyword("class").suppress() + dotted_refname + lparen.suppress() + matchlist_data + rparen.suppress())("class")
| (dotted_refname + lparen.suppress() + matchlist_data + rparen.suppress())("data_or_class")
Expand Down
63 changes: 52 additions & 11 deletions coconut/compiler/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,12 @@ def using_python_rules(self):
"""Whether the current style uses PEP 622 rules."""
return self.style.startswith("python")

def rule_conflict_warn(self, message, if_coconut=None, if_python=None, extra=None):
def rule_conflict_warn(self, message, if_coconut=None, if_python=None, extra=None, only_strict=False):
"""Warns on conflicting style rules if callback was given."""
if self.style.endswith("warn") or self.style.endswith("strict") and self.comp.strict:
if (
self.style.endswith("warn") and (not only_strict or self.comp.strict)
or self.style.endswith("strict") and self.comp.strict
):
full_msg = message
if if_python or if_coconut:
full_msg += " (" + (if_python if self.using_python_rules else if_coconut) + ")"
Expand Down Expand Up @@ -475,15 +478,16 @@ def match_dict(self, tokens, item):
self.rule_conflict_warn(
"found pattern with new behavior in Coconut v2; dict patterns now allow the dictionary being matched against to contain extra keys",
extra="use explicit '{..., **_}' or '{..., **{}}' syntax to resolve",
only_strict=True,
)
check_len = not self.using_python_rules
strict_len = not self.using_python_rules
elif rest == "{}":
check_len = True
strict_len = True
rest = None
else:
check_len = False
strict_len = False

if check_len:
if strict_len:
self.add_check("_coconut.len(" + item + ") == " + str(len(matches)))

seen_keys = set()
Expand Down Expand Up @@ -900,11 +904,48 @@ def match_in(self, tokens, item):

def match_set(self, tokens, item):
"""Matches a set."""
match, = tokens
self.add_check("_coconut.isinstance(" + item + ", _coconut.abc.Set)")
self.add_check("_coconut.len(" + item + ") == " + str(len(match)))
for const in match:
self.add_check(const + " in " + item)
if len(tokens) == 2:
letter_toks, match = tokens
star = None
else:
letter_toks, match, star = tokens

if letter_toks:
letter, = letter_toks
else:
letter = "s"

# process *() or *_
if star is None:
self.rule_conflict_warn(
"found pattern with new behavior in Coconut v3; set patterns now allow the set being matched against to contain extra items",
extra="use explicit '{..., *_}' or '{..., *()}' syntax to resolve",
)
strict_len = not self.using_python_rules
elif star == wildcard:
strict_len = False
else:
internal_assert(star == "()", "invalid set match tokens", tokens)
strict_len = True

# handle set letter
if letter == "s":
self.add_check("_coconut.isinstance(" + item + ", _coconut.abc.Set)")
elif letter == "f":
self.add_check("_coconut.isinstance(" + item + ", _coconut.frozenset)")
elif letter == "m":
self.add_check("_coconut.isinstance(" + item + ", _coconut.collections.Counter)")
else:
raise CoconutInternalException("invalid set match letter", letter)

# match set contents
if letter == "m":
self.add_check("_coconut_multiset(" + tuple_str_of(match) + ") " + ("== " if strict_len else "<= ") + item)
else:
if strict_len:
self.add_check("_coconut.len(" + item + ") == " + str(len(match)))
for const in match:
self.add_check(const + " in " + item)

def split_data_or_class_match(self, tokens):
"""Split data/class match tokens into cls_name, pos_matches, name_matches, star_match."""
Expand Down
3 changes: 2 additions & 1 deletion coconut/compiler/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,8 @@ def stores_loc_action(loc, tokens):
stores_loc_action.ignore_tokens = True


stores_loc_item = attach(Empty(), stores_loc_action, make_copy=False)
always_match = Empty()
stores_loc_item = attach(always_match, stores_loc_action)


def disallow_keywords(kwds, with_suffix=None):
Expand Down
2 changes: 1 addition & 1 deletion coconut/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def get_bool_env_var(env_var, default=False):
justify_len = 79 # ideal line length

# for pattern-matching
default_matcher_style = "python warn on strict"
default_matcher_style = "python warn"
wildcard = "_"

in_place_op_funcs = {
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "3.0.0"
VERSION_NAME = None
# False for release, int >= 1 for develop
DEVELOP = 1
DEVELOP = 2
ALPHA = True # for pre releases rather than post releases

# -----------------------------------------------------------------------------------------------------------------------
Expand Down
33 changes: 32 additions & 1 deletion coconut/tests/src/cocotest/agnostic/primary.coco
Original file line number Diff line number Diff line change
Expand Up @@ -1418,7 +1418,7 @@ def primary_test() -> bool:
assert weakref.ref(hardref)() |> list == [2, 3, 4]
my_match_err = MatchError("my match error", 123)
assert parallel_map(ident, [my_match_err]) |> list |> str == str([my_match_err])
# repeat the same thin again now that my_match_err.str has been called
# repeat the same thing again now that my_match_err.str has been called
assert parallel_map(ident, [my_match_err]) |> list |> str == str([my_match_err])
match data tuple(1, 2) in (1, 2, 3):
assert False
Expand Down Expand Up @@ -1504,6 +1504,37 @@ def primary_test() -> bool:
optx <**?..= const None
assert optx() is None

s{} = s{1, 2}
s{*_} = s{1, 2}
s{*()} = s{}
s{*[]} = s{}
s{*s{}} = s{}
s{*f{}} = s{}
s{*m{}} = s{}
match s{*()} in s{1, 2}:
assert False
s{} = f{1, 2}
f{1} = f{1, 2}
f{1, *_} = f{1, 2}
f{1, 2, *()} = f{1, 2}
match f{} in s{}:
assert False
s{} = m{1, 1}
s{1} = m{1}
m{1, 1} = m{1, 1}
m{1} = m{1, 1}
match m{1, 1} in m{1}:
assert False
m{1, *_} = m{1, 1}
match m{1, *()} in m{1, 1}:
assert False
s{*(),} = s{}
s{1, *_,} = s{1, 2}
{**{},} = {}
m{} = collections.Counter()
match m{1, 1} in collections.Counter((1, 1)):
assert False

assert_raises(() :: 1 .. 2, TypeError)
assert 1.0 2 3 ** -4 5 == 2*5/3**4
x = 10
Expand Down
6 changes: 3 additions & 3 deletions coconut/tests/src/cocotest/agnostic/util.coco
Original file line number Diff line number Diff line change
Expand Up @@ -649,10 +649,10 @@ def classify(value):
return "empty dict"
else:
return "dict"
match _ `isinstance` (set, frozenset) in value:
match s{} in value:
match s{*_} in value:
match s{*()} in value:
return "empty set"
match {0} in value:
match {0, *()} in value:
return "set of 0"
return "set"
raise TypeError()
Expand Down

0 comments on commit d6e378d

Please sign in to comment.