Skip to content

Commit

Permalink
Fix super() in match func
Browse files Browse the repository at this point in the history
Resolves   #728.
  • Loading branch information
evhub committed Apr 16, 2023
1 parent 7281f15 commit f725b58
Show file tree
Hide file tree
Showing 14 changed files with 138 additions and 76 deletions.
2 changes: 1 addition & 1 deletion coconut/__coconut__.pyi
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from __coconut__ import *
from __coconut__ import _coconut_tail_call, _coconut_tco, _coconut_call_set_names, _coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, _namedtuple_of, _coconut, _coconut_super, _coconut_Expected, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_multi_dim_arr, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in
from __coconut__ import _coconut_tail_call, _coconut_tco, _coconut_call_set_names, _coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, _namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_multi_dim_arr, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in
3 changes: 2 additions & 1 deletion coconut/command/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
interpreter_uses_auto_compilation,
interpreter_uses_coconut_breakpoint,
interpreter_compiler_var,
must_use_specific_target_builtins,
)

if PY26:
Expand Down Expand Up @@ -568,7 +569,7 @@ def fix_pickle(self):
"""Fix pickling of Coconut header objects."""
from coconut import __coconut__ # this is expensive, so only do it here
for var in self.vars:
if not var.startswith("__") and var in dir(__coconut__):
if not var.startswith("__") and var in dir(__coconut__) and var not in must_use_specific_target_builtins:
cur_val = self.vars[var]
static_val = getattr(__coconut__, var)
if getattr(cur_val, "__doc__", None) == getattr(static_val, "__doc__", None):
Expand Down
39 changes: 26 additions & 13 deletions coconut/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
streamline_grammar_for_len,
all_builtins,
in_place_op_funcs,
match_first_arg_var,
)
from coconut.util import (
pickleable_obj,
Expand Down Expand Up @@ -176,6 +177,13 @@
# -----------------------------------------------------------------------------------------------------------------------


match_func_paramdef = "{match_first_arg_var}=_coconut_sentinel, *{match_to_args_var}, **{match_to_kwargs_var}".format(
match_first_arg_var=match_first_arg_var,
match_to_args_var=match_to_args_var,
match_to_kwargs_var=match_to_kwargs_var,
)


def set_to_tuple(tokens):
"""Converts set literal tokens to tuples."""
internal_assert(len(tokens) == 1, "invalid set maker tokens", tokens)
Expand Down Expand Up @@ -1901,10 +1909,7 @@ def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method,
def_name = self.get_temp_var(undotted_name)

# detect pattern-matching functions
is_match_func = func_paramdef == "*{match_to_args_var}, **{match_to_kwargs_var}".format(
match_to_args_var=match_to_args_var,
match_to_kwargs_var=match_to_kwargs_var,
)
is_match_func = func_paramdef == match_func_paramdef

# handle addpattern functions
if addpattern:
Expand Down Expand Up @@ -2612,23 +2617,28 @@ def match_datadef_handle(self, original, loc, tokens):
matcher = self.get_matcher(original, loc, check_var, name_list=[])

pos_only_args, req_args, default_args, star_arg, kwd_only_args, dubstar_arg = split_args_list(matches, loc)
matcher.match_function(match_to_args_var, match_to_kwargs_var, pos_only_args, req_args + default_args, star_arg, kwd_only_args, dubstar_arg)
matcher.match_function(
pos_only_match_args=pos_only_args,
match_args=req_args + default_args,
star_arg=star_arg,
kwd_only_match_args=kwd_only_args,
dubstar_arg=dubstar_arg,
)

if cond is not None:
matcher.add_guard(cond)

extra_stmts = handle_indentation(
'''
def __new__(_coconut_cls, *{match_to_args_var}, **{match_to_kwargs_var}):
def __new__(_coconut_cls, {match_func_paramdef}):
{check_var} = False
{matching}
{pattern_error}
return _coconut.tuple.__new__(_coconut_cls, {arg_tuple})
''',
add_newline=True,
).format(
match_to_args_var=match_to_args_var,
match_to_kwargs_var=match_to_kwargs_var,
match_func_paramdef=match_func_paramdef,
check_var=check_var,
matching=matcher.out(),
pattern_error=self.pattern_error(original, loc, match_to_args_var, check_var, function_match_error_var),
Expand Down Expand Up @@ -3129,15 +3139,18 @@ def name_match_funcdef_handle(self, original, loc, tokens):
matcher = self.get_matcher(original, loc, check_var)

pos_only_args, req_args, default_args, star_arg, kwd_only_args, dubstar_arg = split_args_list(matches, loc)
matcher.match_function(match_to_args_var, match_to_kwargs_var, pos_only_args, req_args + default_args, star_arg, kwd_only_args, dubstar_arg)
matcher.match_function(
pos_only_match_args=pos_only_args,
match_args=req_args + default_args,
star_arg=star_arg,
kwd_only_match_args=kwd_only_args,
dubstar_arg=dubstar_arg,
)

if cond is not None:
matcher.add_guard(cond)

before_colon = (
"def " + func
+ "(*" + match_to_args_var + ", **" + match_to_kwargs_var + ")"
)
before_colon = "def " + func + "(" + match_func_paramdef + ")"
after_docstring = (
openindent
+ check_var + " = False\n"
Expand Down
35 changes: 19 additions & 16 deletions coconut/compiler/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from collections import defaultdict
from contextlib import contextmanager
from functools import partial

from coconut._pyparsing import (
CaselessLiteral,
Expand Down Expand Up @@ -96,7 +97,7 @@
split_trailing_indent,
split_leading_indent,
collapse_indents,
keyword,
base_keyword,
match_in,
disallow_keywords,
regex_item,
Expand Down Expand Up @@ -717,17 +718,20 @@ class Grammar(object):
questionmark = ~dubquestion + Literal("?")
bang = ~Literal("!=") + Literal("!")

keyword = partial(base_keyword, explicit_prefix=colon)

except_star_kwd = combine(keyword("except") + star)
except_kwd = ~except_star_kwd + keyword("except")
lambda_kwd = keyword("lambda") | fixto(keyword("\u03bb", explicit_prefix=colon), "lambda")
data_kwd = keyword("data", explicit_prefix=colon)
match_kwd = keyword("match", explicit_prefix=colon)
case_kwd = keyword("case", explicit_prefix=colon)
cases_kwd = keyword("cases", explicit_prefix=colon)
where_kwd = keyword("where", explicit_prefix=colon)
addpattern_kwd = keyword("addpattern", explicit_prefix=colon)
then_kwd = keyword("then", explicit_prefix=colon)
type_kwd = keyword("type", explicit_prefix=colon)
lambda_kwd = keyword("lambda") | fixto(keyword("\u03bb"), "lambda")
operator_kwd = keyword("operator", require_whitespace=True)
data_kwd = keyword("data")
match_kwd = keyword("match")
case_kwd = keyword("case")
cases_kwd = keyword("cases")
where_kwd = keyword("where")
addpattern_kwd = keyword("addpattern")
then_kwd = keyword("then")
type_kwd = keyword("type")

ellipsis = Forward()
ellipsis_tokens = Literal("...") | fixto(Literal("\u2026"), "...")
Expand Down Expand Up @@ -1905,7 +1909,7 @@ class Grammar(object):
+ testlist_star_namedexpr
+ match_guard
# avoid match match-case blocks
+ ~FollowedBy(colon + newline + indent + keyword("case", explicit_prefix=colon))
+ ~FollowedBy(colon + newline + indent + case_kwd)
- full_suite
)
match_stmt = trace(condense(full_match - Optional(else_stmt)))
Expand Down Expand Up @@ -2369,10 +2373,10 @@ def get_tre_return_grammar(self, func_name):
"""The TRE return grammar is parameterized by the name of the function being optimized."""
return (
self.start_marker
+ keyword("return").suppress()
+ self.keyword("return").suppress()
+ maybeparens(
self.lparen,
keyword(func_name, explicit_prefix=False).suppress()
base_keyword(func_name).suppress()
+ self.original_function_call_tokens,
self.rparen,
) + self.end_marker
Expand Down Expand Up @@ -2421,8 +2425,8 @@ def get_tre_return_grammar(self, func_name):
| ~comma + ~rparen + ~equals + any_char,
),
)
tfpdef_tokens = unsafe_name - Optional(colon.suppress() - rest_of_tfpdef.suppress())
tfpdef_default_tokens = tfpdef_tokens - Optional(equals.suppress() - rest_of_tfpdef)
tfpdef_tokens = unsafe_name - Optional(colon - rest_of_tfpdef).suppress()
tfpdef_default_tokens = tfpdef_tokens - Optional(equals - rest_of_tfpdef)
type_comment = Optional(
comment_tokens.suppress()
| passthrough_item.suppress(),
Expand Down Expand Up @@ -2481,7 +2485,6 @@ def get_tre_return_grammar(self, func_name):

string_start = start_marker + quotedString

operator_kwd = keyword("operator", explicit_prefix=colon, require_whitespace=True)
operator_stmt = (
start_marker
+ operator_kwd.suppress()
Expand Down

0 comments on commit f725b58

Please sign in to comment.