From e9a58117b3af93032c7cd56c3cee6636e947b5b2 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Fri, 16 Feb 2024 16:30:13 -0500 Subject: [PATCH] Switch to ruff preview formatting So I can get https://github.com/astral-sh/ruff/issues/5822#issuecomment-1949350841 --- pyproject.toml | 4 + python/egglog/builtins.py | 348 ++++++---------- python/egglog/declarations.py | 3 +- python/egglog/egraph.py | 126 ++---- python/egglog/examples/bool.py | 3 +- python/egglog/examples/eqsat_basic.py | 13 +- python/egglog/examples/fib.py | 4 +- python/egglog/examples/lambda_.py | 40 +- python/egglog/examples/matrix.py | 4 +- python/egglog/examples/ndarrays.py | 43 +- python/egglog/exp/array_api.py | 452 +++++++-------------- python/egglog/exp/array_api_program_gen.py | 60 +-- python/egglog/runtime.py | 2 +- python/tests/test_convert.py | 24 +- python/tests/test_high_level.py | 138 +++---- python/tests/test_program_gen.py | 21 +- 16 files changed, 436 insertions(+), 849 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f5a0e0a9..73bf7cbf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,6 +182,10 @@ select = ["ALL"] extend-exclude = ["python/tests/__snapshots__"] unsafe-fixes = true + +[tool.ruff.format] +preview = true + [tool.ruff.lint.per-file-ignores] # Don't require annotations for tests "python/tests/**" = ["ANN001", "ANN201"] diff --git a/python/egglog/builtins.py b/python/egglog/builtins.py index 7ce89c70..c2b3a5f9 100644 --- a/python/egglog/builtins.py +++ b/python/egglog/builtins.py @@ -35,20 +35,17 @@ class String(Expr, builtin=True): - def __init__(self, value: str) -> None: - ... + def __init__(self, value: str) -> None: ... @method(egg_fn="replace") - def replace(self, old: StringLike, new: StringLike) -> String: - ... + def replace(self, old: StringLike, new: StringLike) -> String: ... StringLike: TypeAlias = String | str @function(egg_fn="+", builtin=True) -def join(*strings: StringLike) -> String: - ... +def join(*strings: StringLike) -> String: ... converter(str, String, String) @@ -57,28 +54,22 @@ def join(*strings: StringLike) -> String: class Bool(Expr, egg_sort="bool", builtin=True): - def __init__(self, value: bool) -> None: - ... + def __init__(self, value: bool) -> None: ... @method(egg_fn="not") - def __invert__(self) -> Bool: - ... + def __invert__(self) -> Bool: ... @method(egg_fn="and") - def __and__(self, other: BoolLike) -> Bool: - ... + def __and__(self, other: BoolLike) -> Bool: ... @method(egg_fn="or") - def __or__(self, other: BoolLike) -> Bool: - ... + def __or__(self, other: BoolLike) -> Bool: ... @method(egg_fn="xor") - def __xor__(self, other: BoolLike) -> Bool: - ... + def __xor__(self, other: BoolLike) -> Bool: ... @method(egg_fn="=>") - def implies(self, other: BoolLike) -> Bool: - ... + def implies(self, other: BoolLike) -> Bool: ... converter(bool, Bool, Bool) @@ -88,218 +79,167 @@ def implies(self, other: BoolLike) -> Bool: class i64(Expr, builtin=True): # noqa: N801 - def __init__(self, value: int) -> None: - ... + def __init__(self, value: int) -> None: ... @method(egg_fn="+") - def __add__(self, other: i64Like) -> i64: - ... + def __add__(self, other: i64Like) -> i64: ... @method(egg_fn="-") - def __sub__(self, other: i64Like) -> i64: - ... + def __sub__(self, other: i64Like) -> i64: ... @method(egg_fn="*") - def __mul__(self, other: i64Like) -> i64: - ... + def __mul__(self, other: i64Like) -> i64: ... @method(egg_fn="/") - def __truediv__(self, other: i64Like) -> i64: - ... + def __truediv__(self, other: i64Like) -> i64: ... @method(egg_fn="%") - def __mod__(self, other: i64Like) -> i64: - ... + def __mod__(self, other: i64Like) -> i64: ... @method(egg_fn="&") - def __and__(self, other: i64Like) -> i64: - ... + def __and__(self, other: i64Like) -> i64: ... @method(egg_fn="|") - def __or__(self, other: i64Like) -> i64: - ... + def __or__(self, other: i64Like) -> i64: ... @method(egg_fn="^") - def __xor__(self, other: i64Like) -> i64: - ... + def __xor__(self, other: i64Like) -> i64: ... @method(egg_fn="<<") - def __lshift__(self, other: i64Like) -> i64: - ... + def __lshift__(self, other: i64Like) -> i64: ... @method(egg_fn=">>") - def __rshift__(self, other: i64Like) -> i64: - ... + def __rshift__(self, other: i64Like) -> i64: ... - def __radd__(self, other: i64Like) -> i64: - ... + def __radd__(self, other: i64Like) -> i64: ... - def __rsub__(self, other: i64Like) -> i64: - ... + def __rsub__(self, other: i64Like) -> i64: ... - def __rmul__(self, other: i64Like) -> i64: - ... + def __rmul__(self, other: i64Like) -> i64: ... - def __rtruediv__(self, other: i64Like) -> i64: - ... + def __rtruediv__(self, other: i64Like) -> i64: ... - def __rmod__(self, other: i64Like) -> i64: - ... + def __rmod__(self, other: i64Like) -> i64: ... - def __rand__(self, other: i64Like) -> i64: - ... + def __rand__(self, other: i64Like) -> i64: ... - def __ror__(self, other: i64Like) -> i64: - ... + def __ror__(self, other: i64Like) -> i64: ... - def __rxor__(self, other: i64Like) -> i64: - ... + def __rxor__(self, other: i64Like) -> i64: ... - def __rlshift__(self, other: i64Like) -> i64: - ... + def __rlshift__(self, other: i64Like) -> i64: ... - def __rrshift__(self, other: i64Like) -> i64: - ... + def __rrshift__(self, other: i64Like) -> i64: ... @method(egg_fn="not-i64") - def __invert__(self) -> i64: - ... + def __invert__(self) -> i64: ... @method(egg_fn="<") def __lt__(self, other: i64Like) -> Unit: # type: ignore[empty-body,has-type] ... @method(egg_fn=">") - def __gt__(self, other: i64Like) -> Unit: - ... + def __gt__(self, other: i64Like) -> Unit: ... @method(egg_fn="<=") def __le__(self, other: i64Like) -> Unit: # type: ignore[empty-body,has-type] ... @method(egg_fn=">=") - def __ge__(self, other: i64Like) -> Unit: - ... + def __ge__(self, other: i64Like) -> Unit: ... @method(egg_fn="min") - def min(self, other: i64Like) -> i64: - ... + def min(self, other: i64Like) -> i64: ... @method(egg_fn="max") - def max(self, other: i64Like) -> i64: - ... + def max(self, other: i64Like) -> i64: ... @method(egg_fn="to-string") - def to_string(self) -> String: - ... + def to_string(self) -> String: ... @method(egg_fn="bool-<") - def bool_lt(self, other: i64Like) -> Bool: - ... + def bool_lt(self, other: i64Like) -> Bool: ... @method(egg_fn="bool->") - def bool_gt(self, other: i64Like) -> Bool: - ... + def bool_gt(self, other: i64Like) -> Bool: ... @method(egg_fn="bool-<=") - def bool_le(self, other: i64Like) -> Bool: - ... + def bool_le(self, other: i64Like) -> Bool: ... @method(egg_fn="bool->=") - def bool_ge(self, other: i64Like) -> Bool: - ... + def bool_ge(self, other: i64Like) -> Bool: ... converter(int, i64, i64) @function(builtin=True, egg_fn="count-matches") -def count_matches(s: StringLike, pattern: StringLike) -> i64: - ... +def count_matches(s: StringLike, pattern: StringLike) -> i64: ... f64Like = Union["f64", float] # noqa: N816 class f64(Expr, builtin=True): # noqa: N801 - def __init__(self, value: float) -> None: - ... + def __init__(self, value: float) -> None: ... @method(egg_fn="neg") - def __neg__(self) -> f64: - ... + def __neg__(self) -> f64: ... @method(egg_fn="+") - def __add__(self, other: f64Like) -> f64: - ... + def __add__(self, other: f64Like) -> f64: ... @method(egg_fn="-") - def __sub__(self, other: f64Like) -> f64: - ... + def __sub__(self, other: f64Like) -> f64: ... @method(egg_fn="*") - def __mul__(self, other: f64Like) -> f64: - ... + def __mul__(self, other: f64Like) -> f64: ... @method(egg_fn="/") - def __truediv__(self, other: f64Like) -> f64: - ... + def __truediv__(self, other: f64Like) -> f64: ... @method(egg_fn="%") - def __mod__(self, other: f64Like) -> f64: - ... + def __mod__(self, other: f64Like) -> f64: ... - def __radd__(self, other: f64Like) -> f64: - ... + def __radd__(self, other: f64Like) -> f64: ... - def __rsub__(self, other: f64Like) -> f64: - ... + def __rsub__(self, other: f64Like) -> f64: ... - def __rmul__(self, other: f64Like) -> f64: - ... + def __rmul__(self, other: f64Like) -> f64: ... - def __rtruediv__(self, other: f64Like) -> f64: - ... + def __rtruediv__(self, other: f64Like) -> f64: ... - def __rmod__(self, other: f64Like) -> f64: - ... + def __rmod__(self, other: f64Like) -> f64: ... @method(egg_fn="<") def __lt__(self, other: f64Like) -> Unit: # type: ignore[empty-body,has-type] ... @method(egg_fn=">") - def __gt__(self, other: f64Like) -> Unit: - ... + def __gt__(self, other: f64Like) -> Unit: ... @method(egg_fn="<=") def __le__(self, other: f64Like) -> Unit: # type: ignore[empty-body,has-type] ... @method(egg_fn=">=") - def __ge__(self, other: f64Like) -> Unit: - ... + def __ge__(self, other: f64Like) -> Unit: ... @method(egg_fn="min") - def min(self, other: f64Like) -> f64: - ... + def min(self, other: f64Like) -> f64: ... @method(egg_fn="max") - def max(self, other: f64Like) -> f64: - ... + def max(self, other: f64Like) -> f64: ... @method(egg_fn="to-i64") - def to_i64(self) -> i64: - ... + def to_i64(self) -> i64: ... @method(egg_fn="to-f64") @classmethod - def from_i64(cls, i: i64) -> f64: - ... + def from_i64(cls, i: i64) -> f64: ... @method(egg_fn="to-string") - def to_string(self) -> String: - ... + def to_string(self) -> String: ... converter(float, f64, f64) @@ -312,243 +252,188 @@ def to_string(self) -> String: class Map(Expr, Generic[T, V], builtin=True): @method(egg_fn="map-empty") @classmethod - def empty(cls) -> Map[T, V]: - ... + def empty(cls) -> Map[T, V]: ... @method(egg_fn="map-insert") - def insert(self, key: T, value: V) -> Map[T, V]: - ... + def insert(self, key: T, value: V) -> Map[T, V]: ... @method(egg_fn="map-get") - def __getitem__(self, key: T) -> V: - ... + def __getitem__(self, key: T) -> V: ... @method(egg_fn="map-not-contains") - def not_contains(self, key: T) -> Unit: - ... + def not_contains(self, key: T) -> Unit: ... @method(egg_fn="map-contains") - def contains(self, key: T) -> Unit: - ... + def contains(self, key: T) -> Unit: ... @method(egg_fn="map-remove") - def remove(self, key: T) -> Map[T, V]: - ... + def remove(self, key: T) -> Map[T, V]: ... @method(egg_fn="rebuild") - def rebuild(self) -> Map[T, V]: - ... + def rebuild(self) -> Map[T, V]: ... class Set(Expr, Generic[T], builtin=True): @method(egg_fn="set-of") - def __init__(self, *args: T) -> None: - ... + def __init__(self, *args: T) -> None: ... @method(egg_fn="set-empty") @classmethod - def empty(cls) -> Set[T]: - ... + def empty(cls) -> Set[T]: ... @method(egg_fn="set-insert") - def insert(self, value: T) -> Set[T]: - ... + def insert(self, value: T) -> Set[T]: ... @method(egg_fn="set-not-contains") - def not_contains(self, value: T) -> Unit: - ... + def not_contains(self, value: T) -> Unit: ... @method(egg_fn="set-contains") - def contains(self, value: T) -> Unit: - ... + def contains(self, value: T) -> Unit: ... @method(egg_fn="set-remove") - def remove(self, value: T) -> Set[T]: - ... + def remove(self, value: T) -> Set[T]: ... @method(egg_fn="set-union") - def __or__(self, other: Set[T]) -> Set[T]: - ... + def __or__(self, other: Set[T]) -> Set[T]: ... @method(egg_fn="set-diff") - def __sub__(self, other: Set[T]) -> Set[T]: - ... + def __sub__(self, other: Set[T]) -> Set[T]: ... @method(egg_fn="set-intersect") - def __and__(self, other: Set[T]) -> Set[T]: - ... + def __and__(self, other: Set[T]) -> Set[T]: ... @method(egg_fn="rebuild") - def rebuild(self) -> Set[T]: - ... + def rebuild(self) -> Set[T]: ... class Rational(Expr, builtin=True): @method(egg_fn="rational") - def __init__(self, num: i64Like, den: i64Like) -> None: - ... + def __init__(self, num: i64Like, den: i64Like) -> None: ... @method(egg_fn="to-f64") - def to_f64(self) -> f64: - ... + def to_f64(self) -> f64: ... @method(egg_fn="+") - def __add__(self, other: Rational) -> Rational: - ... + def __add__(self, other: Rational) -> Rational: ... @method(egg_fn="-") - def __sub__(self, other: Rational) -> Rational: - ... + def __sub__(self, other: Rational) -> Rational: ... @method(egg_fn="*") - def __mul__(self, other: Rational) -> Rational: - ... + def __mul__(self, other: Rational) -> Rational: ... @method(egg_fn="/") - def __truediv__(self, other: Rational) -> Rational: - ... + def __truediv__(self, other: Rational) -> Rational: ... @method(egg_fn="min") - def min(self, other: Rational) -> Rational: - ... + def min(self, other: Rational) -> Rational: ... @method(egg_fn="max") - def max(self, other: Rational) -> Rational: - ... + def max(self, other: Rational) -> Rational: ... @method(egg_fn="neg") - def __neg__(self) -> Rational: - ... + def __neg__(self) -> Rational: ... @method(egg_fn="abs") - def __abs__(self) -> Rational: - ... + def __abs__(self) -> Rational: ... @method(egg_fn="floor") - def floor(self) -> Rational: - ... + def floor(self) -> Rational: ... @method(egg_fn="ceil") - def ceil(self) -> Rational: - ... + def ceil(self) -> Rational: ... @method(egg_fn="round") - def round(self) -> Rational: - ... + def round(self) -> Rational: ... @method(egg_fn="pow") - def __pow__(self, other: Rational) -> Rational: - ... + def __pow__(self, other: Rational) -> Rational: ... @method(egg_fn="log") - def log(self) -> Rational: - ... + def log(self) -> Rational: ... @method(egg_fn="sqrt") - def sqrt(self) -> Rational: - ... + def sqrt(self) -> Rational: ... @method(egg_fn="cbrt") - def cbrt(self) -> Rational: - ... + def cbrt(self) -> Rational: ... @method(egg_fn="numer") # type: ignore[misc] @property - def numer(self) -> i64: - ... + def numer(self) -> i64: ... @method(egg_fn="denom") # type: ignore[misc] @property - def denom(self) -> i64: - ... + def denom(self) -> i64: ... class Vec(Expr, Generic[T], builtin=True): @method(egg_fn="vec-of") - def __init__(self, *args: T) -> None: - ... + def __init__(self, *args: T) -> None: ... @method(egg_fn="vec-empty") @classmethod - def empty(cls) -> Vec[T]: - ... + def empty(cls) -> Vec[T]: ... @method(egg_fn="vec-append") - def append(self, *others: Vec[T]) -> Vec[T]: - ... + def append(self, *others: Vec[T]) -> Vec[T]: ... @method(egg_fn="vec-push") - def push(self, value: T) -> Vec[T]: - ... + def push(self, value: T) -> Vec[T]: ... @method(egg_fn="vec-pop") - def pop(self) -> Vec[T]: - ... + def pop(self) -> Vec[T]: ... @method(egg_fn="vec-not-contains") - def not_contains(self, value: T) -> Unit: - ... + def not_contains(self, value: T) -> Unit: ... @method(egg_fn="vec-contains") - def contains(self, value: T) -> Unit: - ... + def contains(self, value: T) -> Unit: ... @method(egg_fn="vec-length") - def length(self) -> i64: - ... + def length(self) -> i64: ... @method(egg_fn="vec-get") - def __getitem__(self, index: i64Like) -> T: - ... + def __getitem__(self, index: i64Like) -> T: ... @method(egg_fn="rebuild") - def rebuild(self) -> Vec[T]: - ... + def rebuild(self) -> Vec[T]: ... class PyObject(Expr, builtin=True): - def __init__(self, value: object) -> None: - ... + def __init__(self, value: object) -> None: ... @method(egg_fn="py-from-string") @classmethod - def from_string(cls, s: StringLike) -> PyObject: - ... + def from_string(cls, s: StringLike) -> PyObject: ... @method(egg_fn="py-to-string") - def to_string(self) -> String: - ... + def to_string(self) -> String: ... @method(egg_fn="py-to-bool") - def to_bool(self) -> Bool: - ... + def to_bool(self) -> Bool: ... @method(egg_fn="py-dict-update") - def dict_update(self, *keys_and_values: object) -> PyObject: - ... + def dict_update(self, *keys_and_values: object) -> PyObject: ... @method(egg_fn="py-from-int") @classmethod - def from_int(cls, i: i64Like) -> PyObject: - ... + def from_int(cls, i: i64Like) -> PyObject: ... @method(egg_fn="py-dict") @classmethod - def dict(cls, *keys_and_values: object) -> PyObject: - ... + def dict(cls, *keys_and_values: object) -> PyObject: ... converter(object, PyObject, PyObject) @function(builtin=True, egg_fn="py-eval") -def py_eval(code: StringLike, globals: object = PyObject.dict(), locals: object = PyObject.dict()) -> PyObject: - ... +def py_eval(code: StringLike, globals: object = PyObject.dict(), locals: object = PyObject.dict()) -> PyObject: ... class PyObjectFunction(Protocol): - def __call__(self, *__args: PyObject) -> PyObject: - ... + def __call__(self, *__args: PyObject) -> PyObject: ... def py_eval_fn(fn: Callable) -> PyObjectFunction: @@ -563,8 +448,7 @@ def inner(*__args: PyObject, __fn: Callable = fn) -> PyObject: new_kvs: list[object] = [] eval_str = "__fn(" for i, arg in enumerate(__args): - new_kvs.append(f"__arg_{i}") - new_kvs.append(arg) + new_kvs.extend((f"__arg_{i}", arg)) eval_str += f"__arg_{i}, " eval_str += ")" return py_eval(eval_str, PyObject({"__fn": __fn}).dict_update(*new_kvs), __fn.__globals__) diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index 7e3e80cb..fc396b44 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -103,8 +103,7 @@ @runtime_checkable class HasDeclerations(Protocol): @property - def __egg_decls__(self) -> Declarations: - ... + def __egg_decls__(self) -> Declarations: ... DeclerationsLike: TypeAlias = Union[HasDeclerations, None, "Declarations"] diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 60067443..f6d68a93 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -170,13 +170,11 @@ def __post_init__(self, modules: list[Module]) -> None: @deprecated("Remove this decorator and move the egg_sort to the class statement, i.e. E(Expr, egg_sort='MySort').") @overload - def class_(self, *, egg_sort: str) -> Callable[[TYPE], TYPE]: - ... + def class_(self, *, egg_sort: str) -> Callable[[TYPE], TYPE]: ... @deprecated("Remove this decorator. Simply subclassing Expr is enough now.") @overload - def class_(self, cls: TYPE, /) -> TYPE: - ... + def class_(self, cls: TYPE, /) -> TYPE: ... def class_(self, *args, **kwargs) -> Any: """ @@ -201,8 +199,7 @@ def method( self, *, preserve: Literal[True], - ) -> Callable[[CALLABLE], CALLABLE]: - ... + ) -> Callable[[CALLABLE], CALLABLE]: ... @overload def method( @@ -214,8 +211,7 @@ def method( on_merge: Callable[[Any, Any], Iterable[ActionLike]] | None = None, mutates_self: bool = False, unextractable: bool = False, - ) -> Callable[[CALLABLE], CALLABLE]: - ... + ) -> Callable[[CALLABLE], CALLABLE]: ... @overload def method( @@ -228,8 +224,7 @@ def method( on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None, mutates_self: bool = False, unextractable: bool = False, - ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: - ... + ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ... @deprecated("Use top level method function instead") def method( @@ -249,8 +244,7 @@ def method( ) @overload - def function(self, fn: CALLABLE, /) -> CALLABLE: - ... + def function(self, fn: CALLABLE, /) -> CALLABLE: ... @overload def function( @@ -262,8 +256,7 @@ def function( on_merge: Callable[[Any, Any], Iterable[ActionLike]] | None = None, mutates_first_arg: bool = False, unextractable: bool = False, - ) -> Callable[[CALLABLE], CALLABLE]: - ... + ) -> Callable[[CALLABLE], CALLABLE]: ... @overload def function( @@ -276,8 +269,7 @@ def function( on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None, mutates_first_arg: bool = False, unextractable: bool = False, - ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: - ... + ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ... @deprecated("Use top level function `function` instead") def function(self, *args, **kwargs) -> Any: @@ -300,24 +292,19 @@ def ruleset(self, name: str) -> Ruleset: @overload def relation( self, name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], tp4: type[E4], / - ) -> Callable[[E1, E2, E3, E4], Unit]: - ... + ) -> Callable[[E1, E2, E3, E4], Unit]: ... @overload - def relation(self, name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], /) -> Callable[[E1, E2, E3], Unit]: - ... + def relation(self, name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], /) -> Callable[[E1, E2, E3], Unit]: ... @overload - def relation(self, name: str, tp1: type[E1], tp2: type[E2], /) -> Callable[[E1, E2], Unit]: - ... + def relation(self, name: str, tp1: type[E1], tp2: type[E2], /) -> Callable[[E1, E2], Unit]: ... @overload - def relation(self, name: str, tp1: type[T], /, *, egg_fn: str | None = None) -> Callable[[T], Unit]: - ... + def relation(self, name: str, tp1: type[T], /, *, egg_fn: str | None = None) -> Callable[[T], Unit]: ... @overload - def relation(self, name: str, /, *, egg_fn: str | None = None) -> Callable[[], Unit]: - ... + def relation(self, name: str, /, *, egg_fn: str | None = None) -> Callable[[], Unit]: ... @deprecated("Use top level relation function instead") def relation(self, name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..., Unit]: @@ -362,8 +349,7 @@ def _register_commands(self, cmds: list[Command]) -> None: def method( *, preserve: Literal[True], -) -> Callable[[CALLABLE], CALLABLE]: - ... +) -> Callable[[CALLABLE], CALLABLE]: ... # We have to seperate method/function overloads for those that use the T params and those that don't @@ -380,8 +366,7 @@ def method( on_merge: Callable[[Any, Any], Iterable[ActionLike]] | None = None, mutates_self: bool = False, unextractable: bool = False, -) -> Callable[[CALLABLE], CALLABLE]: - ... +) -> Callable[[CALLABLE], CALLABLE]: ... @overload @@ -394,8 +379,7 @@ def method( on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None, mutates_self: bool = False, unextractable: bool = False, -) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: - ... +) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ... def method( @@ -565,8 +549,7 @@ class _Dummytype: @overload -def function(fn: CALLABLE, /) -> CALLABLE: - ... +def function(fn: CALLABLE, /) -> CALLABLE: ... @overload @@ -579,8 +562,7 @@ def function( mutates_first_arg: bool = False, unextractable: bool = False, builtin: bool = False, -) -> Callable[[CALLABLE], CALLABLE]: - ... +) -> Callable[[CALLABLE], CALLABLE]: ... @overload @@ -593,8 +575,7 @@ def function( on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None, mutates_first_arg: bool = False, unextractable: bool = False, -) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: - ... +) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ... def function(*args, **kwargs) -> Any: @@ -760,28 +741,23 @@ def _register_function( @overload def relation( name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], tp4: type[E4], / -) -> Callable[[E1, E2, E3, E4], Unit]: - ... +) -> Callable[[E1, E2, E3, E4], Unit]: ... @overload -def relation(name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], /) -> Callable[[E1, E2, E3], Unit]: - ... +def relation(name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], /) -> Callable[[E1, E2, E3], Unit]: ... @overload -def relation(name: str, tp1: type[E1], tp2: type[E2], /) -> Callable[[E1, E2], Unit]: - ... +def relation(name: str, tp1: type[E1], tp2: type[E2], /) -> Callable[[E1, E2], Unit]: ... @overload -def relation(name: str, tp1: type[T], /, *, egg_fn: str | None = None) -> Callable[[T], Unit]: - ... +def relation(name: str, tp1: type[T], /, *, egg_fn: str | None = None) -> Callable[[T], Unit]: ... @overload -def relation(name: str, /, *, egg_fn: str | None = None) -> Callable[[], Unit]: - ... +def relation(name: str, /, *, egg_fn: str | None = None) -> Callable[[], Unit]: ... def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..., Unit]: @@ -1053,12 +1029,10 @@ def let(self, name: str, expr: EXPR) -> EXPR: return cast(EXPR, RuntimeExpr(expr.__egg_decls__, TypedExprDecl(expr.__egg_typed_expr__.tp, VarDecl(name)))) @overload - def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> EXPR: - ... + def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> EXPR: ... @overload - def simplify(self, expr: EXPR, schedule: Schedule, /) -> EXPR: - ... + def simplify(self, expr: EXPR, schedule: Schedule, /) -> EXPR: ... def simplify( self, expr: EXPR, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None @@ -1102,12 +1076,10 @@ def output(self) -> None: raise NotImplementedError(msg) @overload - def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> bindings.RunReport: - ... + def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> bindings.RunReport: ... @overload - def run(self, schedule: Schedule, /) -> bindings.RunReport: - ... + def run(self, schedule: Schedule, /) -> bindings.RunReport: ... def run( self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None @@ -1147,12 +1119,10 @@ def _facts_to_check(self, facts: Iterable[FactLike]) -> bindings.Check: return bindings.Check(egg_facts) @overload - def extract(self, expr: EXPR, /, include_cost: Literal[False] = False) -> EXPR: - ... + def extract(self, expr: EXPR, /, include_cost: Literal[False] = False) -> EXPR: ... @overload - def extract(self, expr: EXPR, /, include_cost: Literal[True]) -> tuple[EXPR, int]: - ... + def extract(self, expr: EXPR, /, include_cost: Literal[True]) -> tuple[EXPR, int]: ... def extract(self, expr: EXPR, include_cost: bool = False) -> EXPR | tuple[EXPR, int]: """ @@ -1229,24 +1199,19 @@ def __exit__(self, exc_type, exc, exc_tb) -> None: # noqa: ANN001 self.pop() @overload - def eval(self, expr: i64) -> int: - ... + def eval(self, expr: i64) -> int: ... @overload - def eval(self, expr: f64) -> float: - ... + def eval(self, expr: f64) -> float: ... @overload - def eval(self, expr: Bool) -> bool: - ... + def eval(self, expr: Bool) -> bool: ... @overload - def eval(self, expr: String) -> str: - ... + def eval(self, expr: String) -> str: ... @overload - def eval(self, expr: PyObject) -> object: - ... + def eval(self, expr: PyObject) -> object: ... def eval(self, expr: Expr) -> object: """ @@ -1347,8 +1312,7 @@ class Unit(Expr, egg_sort="Unit", builtin=True): The unit type. This is also used to reprsent if a value exists, if it is resolved or not. """ - def __init__(self) -> None: - ... + def __init__(self) -> None: ... def ruleset( @@ -1836,13 +1800,11 @@ def __egg_decls__(self) -> Declarations: @deprecated("Use .register() instead of passing rulesets as arguments to rewrites.") @overload -def rewrite(lhs: EXPR, ruleset: Ruleset) -> _RewriteBuilder[EXPR]: - ... +def rewrite(lhs: EXPR, ruleset: Ruleset) -> _RewriteBuilder[EXPR]: ... @overload -def rewrite(lhs: EXPR, ruleset: None = None) -> _RewriteBuilder[EXPR]: - ... +def rewrite(lhs: EXPR, ruleset: None = None) -> _RewriteBuilder[EXPR]: ... def rewrite(lhs: EXPR, ruleset: Ruleset | None = None) -> _RewriteBuilder[EXPR]: @@ -1852,13 +1814,11 @@ def rewrite(lhs: EXPR, ruleset: Ruleset | None = None) -> _RewriteBuilder[EXPR]: @deprecated("Use .register() instead of passing rulesets as arguments to birewrites.") @overload -def birewrite(lhs: EXPR, ruleset: Ruleset) -> _BirewriteBuilder[EXPR]: - ... +def birewrite(lhs: EXPR, ruleset: Ruleset) -> _BirewriteBuilder[EXPR]: ... @overload -def birewrite(lhs: EXPR, ruleset: None = None) -> _BirewriteBuilder[EXPR]: - ... +def birewrite(lhs: EXPR, ruleset: None = None) -> _BirewriteBuilder[EXPR]: ... def birewrite(lhs: EXPR, ruleset: Ruleset | None = None) -> _BirewriteBuilder[EXPR]: @@ -1911,13 +1871,11 @@ def set_(lhs: EXPR) -> _SetBuilder[EXPR]: @deprecated("Use .register() instead of passing rulesets as arguments to rules.") @overload -def rule(*facts: FactLike, ruleset: Ruleset, name: str | None = None) -> _RuleBuilder: - ... +def rule(*facts: FactLike, ruleset: Ruleset, name: str | None = None) -> _RuleBuilder: ... @overload -def rule(*facts: FactLike, ruleset: None = None, name: str | None = None) -> _RuleBuilder: - ... +def rule(*facts: FactLike, ruleset: None = None, name: str | None = None) -> _RuleBuilder: ... def rule(*facts: FactLike, ruleset: Ruleset | None = None, name: str | None = None) -> _RuleBuilder: diff --git a/python/egglog/examples/bool.py b/python/egglog/examples/bool.py index 4a074435..be2cc8d4 100644 --- a/python/egglog/examples/bool.py +++ b/python/egglog/examples/bool.py @@ -27,8 +27,7 @@ @function -def f(i: i64Like) -> Bool: - ... +def f(i: i64Like) -> Bool: ... i = var("i", i64) diff --git a/python/egglog/examples/eqsat_basic.py b/python/egglog/examples/eqsat_basic.py index 35a6613e..e2bf99f2 100644 --- a/python/egglog/examples/eqsat_basic.py +++ b/python/egglog/examples/eqsat_basic.py @@ -3,6 +3,7 @@ Basic equality saturation example. ================================== """ + from __future__ import annotations from egglog import * @@ -11,18 +12,14 @@ class Num(Expr): - def __init__(self, value: i64Like) -> None: - ... + def __init__(self, value: i64Like) -> None: ... @classmethod - def var(cls, name: StringLike) -> Num: - ... + def var(cls, name: StringLike) -> Num: ... - def __add__(self, other: Num) -> Num: - ... + def __add__(self, other: Num) -> Num: ... - def __mul__(self, other: Num) -> Num: - ... + def __mul__(self, other: Num) -> Num: ... expr1 = Num(2) * (Num.var("x") + Num(3)) diff --git a/python/egglog/examples/fib.py b/python/egglog/examples/fib.py index 4d7dc381..c91bb96d 100644 --- a/python/egglog/examples/fib.py +++ b/python/egglog/examples/fib.py @@ -3,14 +3,14 @@ Fibonacci numbers example ========================= """ + from __future__ import annotations from egglog import * @function -def fib(x: i64Like) -> i64: - ... +def fib(x: i64Like) -> i64: ... f0, f1, x = vars_("f0 f1 x", i64) diff --git a/python/egglog/examples/lambda_.py b/python/egglog/examples/lambda_.py index 76aa0efe..d22cddb8 100644 --- a/python/egglog/examples/lambda_.py +++ b/python/egglog/examples/lambda_.py @@ -4,6 +4,7 @@ Lambda Calculus =============== """ + from __future__ import annotations from typing import TYPE_CHECKING, ClassVar @@ -22,66 +23,53 @@ class Val(Expr): TRUE: ClassVar[Val] FALSE: ClassVar[Val] - def __init__(self, v: i64Like) -> None: - ... + def __init__(self, v: i64Like) -> None: ... class Var(Expr): - def __init__(self, v: StringLike) -> None: - ... + def __init__(self, v: StringLike) -> None: ... class Term(Expr): @classmethod - def val(cls, v: Val) -> Term: - ... + def val(cls, v: Val) -> Term: ... @classmethod - def var(cls, v: Var) -> Term: - ... + def var(cls, v: Var) -> Term: ... - def __add__(self, other: Term) -> Term: - ... + def __add__(self, other: Term) -> Term: ... def __eq__(self, other: Term) -> Term: # type: ignore[override] ... - def __call__(self, other: Term) -> Term: - ... + def __call__(self, other: Term) -> Term: ... - def eval(self) -> Val: - ... + def eval(self) -> Val: ... - def v(self) -> Var: - ... + def v(self) -> Var: ... @function -def lam(x: Var, t: Term) -> Term: - ... +def lam(x: Var, t: Term) -> Term: ... @function -def let_(x: Var, t: Term, b: Term) -> Term: - ... +def let_(x: Var, t: Term, b: Term) -> Term: ... @function -def fix(x: Var, t: Term) -> Term: - ... +def fix(x: Var, t: Term) -> Term: ... @function -def if_(c: Term, t: Term, f: Term) -> Term: - ... +def if_(c: Term, t: Term, f: Term) -> Term: ... StringSet = Set[Var] @function(merge=lambda old, new: old & new) -def freer(t: Term) -> StringSet: - ... +def freer(t: Term) -> StringSet: ... (v, v1, v2) = vars_("v v1 v2", Val) diff --git a/python/egglog/examples/matrix.py b/python/egglog/examples/matrix.py index ddacf46e..2fdd3861 100644 --- a/python/egglog/examples/matrix.py +++ b/python/egglog/examples/matrix.py @@ -2,6 +2,7 @@ Matrix multiplication and Kronecker product. ============================================ """ + from __future__ import annotations from egglog import * @@ -18,8 +19,7 @@ class Dim(Expr): """ @method(egg_fn="Lit") - def __init__(self, value: i64Like) -> None: - ... + def __init__(self, value: i64Like) -> None: ... @method(egg_fn="NamedDim") @classmethod diff --git a/python/egglog/examples/ndarrays.py b/python/egglog/examples/ndarrays.py index 7639b76c..943fd256 100644 --- a/python/egglog/examples/ndarrays.py +++ b/python/egglog/examples/ndarrays.py @@ -6,6 +6,7 @@ Example of building NDarray in the vein of Mathemetics of Arrays. """ + from __future__ import annotations from egglog import * @@ -14,14 +15,11 @@ class Value(Expr): - def __init__(self, v: i64Like) -> None: - ... + def __init__(self, v: i64Like) -> None: ... - def __mul__(self, other: Value) -> Value: - ... + def __mul__(self, other: Value) -> Value: ... - def __add__(self, other: Value) -> Value: - ... + def __add__(self, other: Value) -> Value: ... i, j = vars_("i j", i64) @@ -32,17 +30,13 @@ def __add__(self, other: Value) -> Value: class Values(Expr): - def __init__(self, v: Vec[Value]) -> None: - ... + def __init__(self, v: Vec[Value]) -> None: ... - def __getitem__(self, idx: Value) -> Value: - ... + def __getitem__(self, idx: Value) -> Value: ... - def length(self) -> Value: - ... + def length(self) -> Value: ... - def concat(self, other: Values) -> Values: - ... + def concat(self, other: Values) -> Values: ... @egraph.register @@ -59,16 +53,13 @@ class NDArray(Expr): An n-dimensional array. """ - def __getitem__(self, idx: Values) -> Value: - ... + def __getitem__(self, idx: Values) -> Value: ... - def shape(self) -> Values: - ... + def shape(self) -> Values: ... @function -def arange(n: Value) -> NDArray: - ... +def arange(n: Value) -> NDArray: ... @egraph.register @@ -94,8 +85,7 @@ def assert_simplifies(left: Expr, right: Expr) -> None: @function -def py_value(s: StringLike) -> Value: - ... +def py_value(s: StringLike) -> Value: ... @egraph.register @@ -105,8 +95,7 @@ def _py_value(l: String, r: String): @function -def py_values(s: StringLike) -> Values: - ... +def py_values(s: StringLike) -> Values: ... @egraph.register @@ -117,8 +106,7 @@ def _py_values(l: String, r: String): @function -def py_ndarray(s: StringLike) -> NDArray: - ... +def py_ndarray(s: StringLike) -> NDArray: ... @egraph.register @@ -134,8 +122,7 @@ def _py_ndarray(l: String, r: String): @function -def cross(l: NDArray, r: NDArray) -> NDArray: - ... +def cross(l: NDArray, r: NDArray) -> NDArray: ... @egraph.register diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index 80acc792..36a82e6a 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -34,14 +34,11 @@ def __bool__(self) -> bool: return try_evaling(self, self.bool) @property - def bool(self) -> Bool: - ... + def bool(self) -> Bool: ... - def __or__(self, other: Boolean) -> Boolean: - ... + def __or__(self, other: Boolean) -> Boolean: ... - def __and__(self, other: Boolean) -> Boolean: - ... + def __and__(self, other: Boolean) -> Boolean: ... TRUE = constant("TRUE", Boolean) @@ -91,22 +88,18 @@ class IsDtypeKind(Expr): NULL: ClassVar[IsDtypeKind] @classmethod - def string(cls, s: StringLike) -> IsDtypeKind: - ... + def string(cls, s: StringLike) -> IsDtypeKind: ... @classmethod - def dtype(cls, d: DType) -> IsDtypeKind: - ... + def dtype(cls, d: DType) -> IsDtypeKind: ... @method(cost=10) - def __or__(self, other: IsDtypeKind) -> IsDtypeKind: - ... + def __or__(self, other: IsDtypeKind) -> IsDtypeKind: ... # TODO: Make kind more generic to support tuples. @function -def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: - ... +def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: ... converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x)) @@ -142,17 +135,13 @@ def _isdtype(d: DType, k1: IsDtypeKind, k2: IsDtypeKind): class Int(Expr): - def __init__(self, value: i64Like) -> None: - ... + def __init__(self, value: i64Like) -> None: ... - def __invert__(self) -> Int: - ... + def __invert__(self) -> Int: ... - def __lt__(self, other: Int) -> Boolean: - ... + def __lt__(self, other: Int) -> Boolean: ... - def __le__(self, other: Int) -> Boolean: - ... + def __le__(self, other: Int) -> Boolean: ... def __eq__(self, other: Int) -> Boolean: # type: ignore[override] ... @@ -163,93 +152,64 @@ def __eq__(self, other: Int) -> Boolean: # type: ignore[override] def __ne__(self, other: Int) -> bool: # type: ignore[override] return not (self == other) - def __gt__(self, other: Int) -> Boolean: - ... + def __gt__(self, other: Int) -> Boolean: ... - def __ge__(self, other: Int) -> Boolean: - ... + def __ge__(self, other: Int) -> Boolean: ... - def __add__(self, other: Int) -> Int: - ... + def __add__(self, other: Int) -> Int: ... - def __sub__(self, other: Int) -> Int: - ... + def __sub__(self, other: Int) -> Int: ... - def __mul__(self, other: Int) -> Int: - ... + def __mul__(self, other: Int) -> Int: ... - def __truediv__(self, other: Int) -> Int: - ... + def __truediv__(self, other: Int) -> Int: ... - def __floordiv__(self, other: Int) -> Int: - ... + def __floordiv__(self, other: Int) -> Int: ... - def __mod__(self, other: Int) -> Int: - ... + def __mod__(self, other: Int) -> Int: ... - def __divmod__(self, other: Int) -> Int: - ... + def __divmod__(self, other: Int) -> Int: ... - def __pow__(self, other: Int) -> Int: - ... + def __pow__(self, other: Int) -> Int: ... - def __lshift__(self, other: Int) -> Int: - ... + def __lshift__(self, other: Int) -> Int: ... - def __rshift__(self, other: Int) -> Int: - ... + def __rshift__(self, other: Int) -> Int: ... - def __and__(self, other: Int) -> Int: - ... + def __and__(self, other: Int) -> Int: ... - def __xor__(self, other: Int) -> Int: - ... + def __xor__(self, other: Int) -> Int: ... - def __or__(self, other: Int) -> Int: - ... + def __or__(self, other: Int) -> Int: ... - def __radd__(self, other: Int) -> Int: - ... + def __radd__(self, other: Int) -> Int: ... - def __rsub__(self, other: Int) -> Int: - ... + def __rsub__(self, other: Int) -> Int: ... - def __rmul__(self, other: Int) -> Int: - ... + def __rmul__(self, other: Int) -> Int: ... - def __rmatmul__(self, other: Int) -> Int: - ... + def __rmatmul__(self, other: Int) -> Int: ... - def __rtruediv__(self, other: Int) -> Int: - ... + def __rtruediv__(self, other: Int) -> Int: ... - def __rfloordiv__(self, other: Int) -> Int: - ... + def __rfloordiv__(self, other: Int) -> Int: ... - def __rmod__(self, other: Int) -> Int: - ... + def __rmod__(self, other: Int) -> Int: ... - def __rpow__(self, other: Int) -> Int: - ... + def __rpow__(self, other: Int) -> Int: ... - def __rlshift__(self, other: Int) -> Int: - ... + def __rlshift__(self, other: Int) -> Int: ... - def __rrshift__(self, other: Int) -> Int: - ... + def __rrshift__(self, other: Int) -> Int: ... - def __rand__(self, other: Int) -> Int: - ... + def __rand__(self, other: Int) -> Int: ... - def __rxor__(self, other: Int) -> Int: - ... + def __rxor__(self, other: Int) -> Int: ... - def __ror__(self, other: Int) -> Int: - ... + def __ror__(self, other: Int) -> Int: ... @property - def i64(self) -> i64: - ... + def i64(self) -> i64: ... @method(preserve=True) def __int__(self) -> int: @@ -304,31 +264,23 @@ def _int(i: i64, j: i64, r: Boolean, o: Int): class Float(Expr): - def __init__(self, value: f64Like) -> None: - ... + def __init__(self, value: f64Like) -> None: ... - def abs(self) -> Float: - ... + def abs(self) -> Float: ... @classmethod - def rational(cls, r: Rational) -> Float: - ... + def rational(cls, r: Rational) -> Float: ... @classmethod - def from_int(cls, i: Int) -> Float: - ... + def from_int(cls, i: Int) -> Float: ... - def __truediv__(self, other: Float) -> Float: - ... + def __truediv__(self, other: Float) -> Float: ... - def __mul__(self, other: Float) -> Float: - ... + def __mul__(self, other: Float) -> Float: ... - def __add__(self, other: Float) -> Float: - ... + def __add__(self, other: Float) -> Float: ... - def __sub__(self, other: Float) -> Float: - ... + def __sub__(self, other: Float) -> Float: ... converter(float, Float, lambda x: Float(x)) @@ -356,14 +308,11 @@ def _float(f: f64, f2: f64, i: i64, r: Rational, r1: Rational): class TupleInt(Expr): EMPTY: ClassVar[TupleInt] - def __init__(self, head: Int) -> None: - ... + def __init__(self, head: Int) -> None: ... - def __add__(self, other: TupleInt) -> TupleInt: - ... + def __add__(self, other: TupleInt) -> TupleInt: ... - def length(self) -> Int: - ... + def length(self) -> Int: ... @method(preserve=True) def __len__(self) -> int: @@ -373,11 +322,9 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[Int]: return iter(self[Int(i)] for i in range(len(self))) - def __getitem__(self, i: Int) -> Int: - ... + def __getitem__(self, i: Int) -> Int: ... - def product(self) -> Int: - ... + def product(self) -> Int: ... converter( @@ -412,8 +359,7 @@ class OptionalInt(Expr): none: ClassVar[OptionalInt] @classmethod - def some(cls, value: Int) -> OptionalInt: - ... + def some(cls, value: Int) -> OptionalInt: ... converter(type(None), OptionalInt, lambda _: OptionalInt.none) @@ -426,8 +372,7 @@ def __init__( start: OptionalInt = OptionalInt.none, stop: OptionalInt = OptionalInt.none, step: OptionalInt = OptionalInt.none, - ) -> None: - ... + ) -> None: ... converter( @@ -442,12 +387,10 @@ class MultiAxisIndexKeyItem(Expr): NONE: ClassVar[MultiAxisIndexKeyItem] @classmethod - def int(cls, i: Int) -> MultiAxisIndexKeyItem: - ... + def int(cls, i: Int) -> MultiAxisIndexKeyItem: ... @classmethod - def slice(cls, slice: Slice) -> MultiAxisIndexKeyItem: - ... + def slice(cls, slice: Slice) -> MultiAxisIndexKeyItem: ... converter(type(...), MultiAxisIndexKeyItem, lambda _: MultiAxisIndexKeyItem.ELLIPSIS) @@ -457,13 +400,11 @@ def slice(cls, slice: Slice) -> MultiAxisIndexKeyItem: class MultiAxisIndexKey(Expr): - def __init__(self, item: MultiAxisIndexKeyItem) -> None: - ... + def __init__(self, item: MultiAxisIndexKeyItem) -> None: ... EMPTY: ClassVar[MultiAxisIndexKey] - def __add__(self, other: MultiAxisIndexKey) -> MultiAxisIndexKey: - ... + def __add__(self, other: MultiAxisIndexKey) -> MultiAxisIndexKey: ... converter( @@ -491,12 +432,10 @@ class IndexKey(Expr): ELLIPSIS: ClassVar[IndexKey] @classmethod - def int(cls, i: Int) -> IndexKey: - ... + def int(cls, i: Int) -> IndexKey: ... @classmethod - def slice(cls, slice: Slice) -> IndexKey: - ... + def slice(cls, slice: Slice) -> IndexKey: ... # Disabled until we support late binding # @classmethod @@ -504,8 +443,7 @@ def slice(cls, slice: Slice) -> IndexKey: # ... @classmethod - def multi_axis(cls, key: MultiAxisIndexKey) -> IndexKey: - ... + def multi_axis(cls, key: MultiAxisIndexKey) -> IndexKey: ... converter(type(...), IndexKey, lambda _: IndexKey.ELLIPSIS) @@ -514,8 +452,7 @@ def multi_axis(cls, key: MultiAxisIndexKey) -> IndexKey: converter(MultiAxisIndexKey, IndexKey, IndexKey.multi_axis) -class Device(Expr): - ... +class Device(Expr): ... ALL_INDICES: TupleInt = constant("ALL_INDICES", TupleInt) @@ -527,28 +464,21 @@ class Device(Expr): class Value(Expr): @classmethod - def int(cls, i: Int) -> Value: - ... + def int(cls, i: Int) -> Value: ... @classmethod - def float(cls, f: Float) -> Value: - ... + def float(cls, f: Float) -> Value: ... @classmethod - def bool(cls, b: Boolean) -> Value: - ... + def bool(cls, b: Boolean) -> Value: ... - def isfinite(self) -> Boolean: - ... + def isfinite(self) -> Boolean: ... - def __lt__(self, other: Value) -> Value: - ... + def __lt__(self, other: Value) -> Value: ... - def __truediv__(self, other: Value) -> Value: - ... + def __truediv__(self, other: Value) -> Value: ... - def astype(self, dtype: DType) -> Value: - ... + def astype(self, dtype: DType) -> Value: ... # TODO: Add all operations @@ -559,12 +489,10 @@ def dtype(self) -> DType: """ @property - def to_bool(self) -> Boolean: - ... + def to_bool(self) -> Boolean: ... @property - def to_int(self) -> Int: - ... + def to_int(self) -> Int: ... @property def to_truthy_value(self) -> Value: @@ -599,20 +527,15 @@ def _value(i: Int, f: Float, b: Boolean): class TupleValue(Expr): EMPTY: ClassVar[TupleValue] - def __init__(self, head: Value) -> None: - ... + def __init__(self, head: Value) -> None: ... - def __add__(self, other: TupleValue) -> TupleValue: - ... + def __add__(self, other: TupleValue) -> TupleValue: ... - def length(self) -> Int: - ... + def length(self) -> Int: ... - def __getitem__(self, i: Int) -> Value: - ... + def __getitem__(self, i: Int) -> Value: ... - def includes(self, value: Value) -> Boolean: - ... + def includes(self, value: Value) -> Boolean: ... converter( @@ -660,41 +583,34 @@ def possible_values(values: Value) -> TupleValue: class NDArray(Expr): - def __init__(self, py_array: PyObject) -> None: - ... + def __init__(self, py_array: PyObject) -> None: ... @method(cost=200) @classmethod - def var(cls, name: StringLike) -> NDArray: - ... + def var(cls, name: StringLike) -> NDArray: ... @method(preserve=True) def __array_namespace__(self, api_version: object = None) -> ModuleType: return sys.modules[__name__] @property - def ndim(self) -> Int: - ... + def ndim(self) -> Int: ... @property - def dtype(self) -> DType: - ... + def dtype(self) -> DType: ... @property - def device(self) -> Device: - ... + def device(self) -> Device: ... @property - def shape(self) -> TupleInt: - ... + def shape(self) -> TupleInt: ... @method(preserve=True) def __bool__(self) -> bool: return bool(self.to_value().to_bool) @property - def size(self) -> Int: - ... + def size(self) -> Int: ... @method(preserve=True) def __len__(self) -> int: @@ -705,17 +621,13 @@ def __iter__(self) -> Iterator[NDArray]: for i in range(len(self)): yield self[IndexKey.int(Int(i))] - def __getitem__(self, key: IndexKey) -> NDArray: - ... + def __getitem__(self, key: IndexKey) -> NDArray: ... - def __setitem__(self, key: IndexKey, value: NDArray) -> None: - ... + def __setitem__(self, key: IndexKey, value: NDArray) -> None: ... - def __lt__(self, other: NDArray) -> NDArray: - ... + def __lt__(self, other: NDArray) -> NDArray: ... - def __le__(self, other: NDArray) -> NDArray: - ... + def __le__(self, other: NDArray) -> NDArray: ... def __eq__(self, other: NDArray) -> NDArray: # type: ignore[override] ... @@ -724,99 +636,68 @@ def __eq__(self, other: NDArray) -> NDArray: # type: ignore[override] # def __ne__(self, other: NDArray) -> NDArray: # type: ignore[override] # ... - def __gt__(self, other: NDArray) -> NDArray: - ... + def __gt__(self, other: NDArray) -> NDArray: ... - def __ge__(self, other: NDArray) -> NDArray: - ... + def __ge__(self, other: NDArray) -> NDArray: ... - def __add__(self, other: NDArray) -> NDArray: - ... + def __add__(self, other: NDArray) -> NDArray: ... - def __sub__(self, other: NDArray) -> NDArray: - ... + def __sub__(self, other: NDArray) -> NDArray: ... - def __mul__(self, other: NDArray) -> NDArray: - ... + def __mul__(self, other: NDArray) -> NDArray: ... - def __matmul__(self, other: NDArray) -> NDArray: - ... + def __matmul__(self, other: NDArray) -> NDArray: ... - def __truediv__(self, other: NDArray) -> NDArray: - ... + def __truediv__(self, other: NDArray) -> NDArray: ... - def __floordiv__(self, other: NDArray) -> NDArray: - ... + def __floordiv__(self, other: NDArray) -> NDArray: ... - def __mod__(self, other: NDArray) -> NDArray: - ... + def __mod__(self, other: NDArray) -> NDArray: ... - def __divmod__(self, other: NDArray) -> NDArray: - ... + def __divmod__(self, other: NDArray) -> NDArray: ... - def __pow__(self, other: NDArray) -> NDArray: - ... + def __pow__(self, other: NDArray) -> NDArray: ... - def __lshift__(self, other: NDArray) -> NDArray: - ... + def __lshift__(self, other: NDArray) -> NDArray: ... - def __rshift__(self, other: NDArray) -> NDArray: - ... + def __rshift__(self, other: NDArray) -> NDArray: ... - def __and__(self, other: NDArray) -> NDArray: - ... + def __and__(self, other: NDArray) -> NDArray: ... - def __xor__(self, other: NDArray) -> NDArray: - ... + def __xor__(self, other: NDArray) -> NDArray: ... - def __or__(self, other: NDArray) -> NDArray: - ... + def __or__(self, other: NDArray) -> NDArray: ... - def __radd__(self, other: NDArray) -> NDArray: - ... + def __radd__(self, other: NDArray) -> NDArray: ... - def __rsub__(self, other: NDArray) -> NDArray: - ... + def __rsub__(self, other: NDArray) -> NDArray: ... - def __rmul__(self, other: NDArray) -> NDArray: - ... + def __rmul__(self, other: NDArray) -> NDArray: ... - def __rmatmul__(self, other: NDArray) -> NDArray: - ... + def __rmatmul__(self, other: NDArray) -> NDArray: ... - def __rtruediv__(self, other: NDArray) -> NDArray: - ... + def __rtruediv__(self, other: NDArray) -> NDArray: ... - def __rfloordiv__(self, other: NDArray) -> NDArray: - ... + def __rfloordiv__(self, other: NDArray) -> NDArray: ... - def __rmod__(self, other: NDArray) -> NDArray: - ... + def __rmod__(self, other: NDArray) -> NDArray: ... - def __rpow__(self, other: NDArray) -> NDArray: - ... + def __rpow__(self, other: NDArray) -> NDArray: ... - def __rlshift__(self, other: NDArray) -> NDArray: - ... + def __rlshift__(self, other: NDArray) -> NDArray: ... - def __rrshift__(self, other: NDArray) -> NDArray: - ... + def __rrshift__(self, other: NDArray) -> NDArray: ... - def __rand__(self, other: NDArray) -> NDArray: - ... + def __rand__(self, other: NDArray) -> NDArray: ... - def __rxor__(self, other: NDArray) -> NDArray: - ... + def __rxor__(self, other: NDArray) -> NDArray: ... - def __ror__(self, other: NDArray) -> NDArray: - ... + def __ror__(self, other: NDArray) -> NDArray: ... @classmethod - def scalar(cls, value: Value) -> NDArray: - ... + def scalar(cls, value: Value) -> NDArray: ... - def to_value(self) -> Value: - ... + def to_value(self) -> Value: ... @property def T(self) -> NDArray: @@ -825,8 +706,7 @@ def T(self) -> NDArray: """ @classmethod - def vector(cls, values: TupleValue) -> NDArray: - ... + def vector(cls, values: TupleValue) -> NDArray: ... def index(self, indices: TupleInt) -> Value: """ @@ -877,14 +757,11 @@ def _ndarray(x: NDArray, b: Boolean, f: Float, fi1: f64, fi2: f64): class TupleNDArray(Expr): EMPTY: ClassVar[TupleNDArray] - def __init__(self, head: NDArray) -> None: - ... + def __init__(self, head: NDArray) -> None: ... - def __add__(self, other: TupleNDArray) -> TupleNDArray: - ... + def __add__(self, other: TupleNDArray) -> TupleNDArray: ... - def length(self) -> Int: - ... + def length(self) -> Int: ... @method(preserve=True) def __len__(self) -> int: @@ -894,8 +771,7 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[NDArray]: return iter(self[Int(i)] for i in range(len(self))) - def __getitem__(self, i: Int) -> NDArray: - ... + def __getitem__(self, i: Int) -> NDArray: ... converter( @@ -927,8 +803,7 @@ class OptionalBool(Expr): none: ClassVar[OptionalBool] @classmethod - def some(cls, value: Boolean) -> OptionalBool: - ... + def some(cls, value: Boolean) -> OptionalBool: ... converter(type(None), OptionalBool, lambda _: OptionalBool.none) @@ -939,8 +814,7 @@ class OptionalDType(Expr): none: ClassVar[OptionalDType] @classmethod - def some(cls, value: DType) -> OptionalDType: - ... + def some(cls, value: DType) -> OptionalDType: ... converter(type(None), OptionalDType, lambda _: OptionalDType.none) @@ -951,8 +825,7 @@ class OptionalDevice(Expr): none: ClassVar[OptionalDevice] @classmethod - def some(cls, value: Device) -> OptionalDevice: - ... + def some(cls, value: Device) -> OptionalDevice: ... converter(type(None), OptionalDevice, lambda _: OptionalDevice.none) @@ -963,8 +836,7 @@ class OptionalTupleInt(Expr): none: ClassVar[OptionalTupleInt] @classmethod - def some(cls, value: TupleInt) -> OptionalTupleInt: - ... + def some(cls, value: TupleInt) -> OptionalTupleInt: ... converter(type(None), OptionalTupleInt, lambda _: OptionalTupleInt.none) @@ -975,12 +847,10 @@ class IntOrTuple(Expr): none: ClassVar[IntOrTuple] @classmethod - def int(cls, value: Int) -> IntOrTuple: - ... + def int(cls, value: Int) -> IntOrTuple: ... @classmethod - def tuple(cls, value: TupleInt) -> IntOrTuple: - ... + def tuple(cls, value: TupleInt) -> IntOrTuple: ... converter(Int, IntOrTuple, IntOrTuple.int) @@ -991,8 +861,7 @@ class OptionalIntOrTuple(Expr): none: ClassVar[OptionalIntOrTuple] @classmethod - def some(cls, value: IntOrTuple) -> OptionalIntOrTuple: - ... + def some(cls, value: IntOrTuple) -> OptionalIntOrTuple: ... converter(type(None), OptionalIntOrTuple, lambda _: OptionalIntOrTuple.none) @@ -1000,8 +869,9 @@ def some(cls, value: IntOrTuple) -> OptionalIntOrTuple: @function -def asarray(a: NDArray, dtype: OptionalDType = OptionalDType.none, copy: OptionalBool = OptionalBool.none) -> NDArray: - ... +def asarray( + a: NDArray, dtype: OptionalDType = OptionalDType.none, copy: OptionalBool = OptionalBool.none +) -> NDArray: ... @array_api_ruleset.register @@ -1011,8 +881,7 @@ def _assarray(a: NDArray, d: OptionalDType, ob: OptionalBool): @function -def isfinite(x: NDArray) -> NDArray: - ... +def isfinite(x: NDArray) -> NDArray: ... @function @@ -1031,8 +900,7 @@ def _sum(x: NDArray, y: NDArray, v: Value, dtype: DType): @function -def reshape(x: NDArray, shape: TupleInt, copy: OptionalBool = OptionalBool.none) -> NDArray: - ... +def reshape(x: NDArray, shape: TupleInt, copy: OptionalBool = OptionalBool.none) -> NDArray: ... # @function @@ -1074,8 +942,7 @@ def reshape(x: NDArray, shape: TupleInt, copy: OptionalBool = OptionalBool.none) @function -def unique_values(x: NDArray) -> NDArray: - ... +def unique_values(x: NDArray) -> NDArray: ... @array_api_ruleset.register @@ -1086,8 +953,7 @@ def _unique_values(x: NDArray): @function -def concat(arrays: TupleNDArray, axis: OptionalInt = OptionalInt.none) -> NDArray: - ... +def concat(arrays: TupleNDArray, axis: OptionalInt = OptionalInt.none) -> NDArray: ... @array_api_ruleset.register @@ -1098,8 +964,7 @@ def _concat(x: NDArray): @function -def astype(x: NDArray, dtype: DType) -> NDArray: - ... +def astype(x: NDArray, dtype: DType) -> NDArray: ... @array_api_ruleset.register @@ -1113,8 +978,7 @@ def _astype(x: NDArray, dtype: DType, i: i64): @function(cost=500) -def unique_counts(x: NDArray) -> TupleNDArray: - ... +def unique_counts(x: NDArray) -> TupleNDArray: ... @array_api_ruleset.register @@ -1130,23 +994,19 @@ def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value, dtype: DTyp @function -def square(x: NDArray) -> NDArray: - ... +def square(x: NDArray) -> NDArray: ... @function -def any(x: NDArray) -> NDArray: - ... +def any(x: NDArray) -> NDArray: ... @function(egg_fn="ndarray-abs") -def abs(x: NDArray) -> NDArray: - ... +def abs(x: NDArray) -> NDArray: ... @function(egg_fn="ndarray-log") -def log(x: NDArray) -> NDArray: - ... +def log(x: NDArray) -> NDArray: ... @array_api_ruleset.register @@ -1157,8 +1017,7 @@ def _abs(f: Float): @function(cost=100) -def unique_inverse(x: NDArray) -> TupleNDArray: - ... +def unique_inverse(x: NDArray) -> TupleNDArray: ... @array_api_ruleset.register @@ -1173,29 +1032,24 @@ def _unique_inverse(x: NDArray, i: Int): @function def zeros( shape: TupleInt, dtype: OptionalDType = OptionalDType.none, device: OptionalDevice = OptionalDevice.none -) -> NDArray: - ... +) -> NDArray: ... @function -def expand_dims(x: NDArray, axis: Int = Int(0)) -> NDArray: - ... +def expand_dims(x: NDArray, axis: Int = Int(0)) -> NDArray: ... @function(cost=100000) -def mean(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none, keepdims: Boolean = FALSE) -> NDArray: - ... +def mean(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none, keepdims: Boolean = FALSE) -> NDArray: ... # TODO: Possibly change names to include modules. @function(egg_fn="ndarray-sqrt") -def sqrt(x: NDArray) -> NDArray: - ... +def sqrt(x: NDArray) -> NDArray: ... @function(cost=100000) -def std(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray: - ... +def std(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray: ... linalg = sys.modules[__name__] @@ -1484,13 +1338,11 @@ def _size(x: NDArray): @overload -def try_evaling(expr: Expr, prim_expr: i64) -> int: - ... +def try_evaling(expr: Expr, prim_expr: i64) -> int: ... @overload -def try_evaling(expr: Expr, prim_expr: Bool) -> bool: - ... +def try_evaling(expr: Expr, prim_expr: Bool) -> bool: ... def try_evaling(expr: Expr, prim_expr: i64 | Bool) -> int | bool: diff --git a/python/egglog/exp/array_api_program_gen.py b/python/egglog/exp/array_api_program_gen.py index 486465f8..adaf2b2d 100644 --- a/python/egglog/exp/array_api_program_gen.py +++ b/python/egglog/exp/array_api_program_gen.py @@ -19,8 +19,7 @@ @function -def bool_program(x: Boolean) -> Program: - ... +def bool_program(x: Boolean) -> Program: ... @array_api_program_gen_ruleset.register @@ -30,8 +29,7 @@ def _bool_program(): @function -def int_program(x: Int) -> Program: - ... +def int_program(x: Int) -> Program: ... @array_api_program_gen_ruleset.register @@ -58,13 +56,11 @@ def _int_program(i64_: i64, i: Int, j: Int): @function -def tuple_int_program(x: TupleInt) -> Program: - ... +def tuple_int_program(x: TupleInt) -> Program: ... @function -def tuple_int_program_inner(x: TupleInt) -> Program: - ... +def tuple_int_program_inner(x: TupleInt) -> Program: ... @array_api_program_gen_ruleset.register @@ -79,13 +75,11 @@ def _tuple_int_program(i: Int, j: Int, ti: TupleInt, ti1: TupleInt, ti2: TupleIn @function -def ndarray_program(x: NDArray) -> Program: - ... +def ndarray_program(x: NDArray) -> Program: ... @function -def ndarray_function_two(res: NDArray, l: NDArray, r: NDArray) -> Program: - ... +def ndarray_function_two(res: NDArray, l: NDArray, r: NDArray) -> Program: ... @array_api_program_gen_ruleset.register @@ -98,8 +92,7 @@ def _ndarray_function_two(f: Program, res: NDArray, l: NDArray, r: NDArray, o: P @function -def dtype_program(x: DType) -> Program: - ... +def dtype_program(x: DType) -> Program: ... @array_api_program_gen_ruleset.register @@ -113,8 +106,7 @@ def _dtype_program(): @function -def float_program(x: Float) -> Program: - ... +def float_program(x: Float) -> Program: ... @array_api_program_gen_ruleset.register @@ -136,8 +128,7 @@ def _float_program(f: Float, g: Float, f64_: f64, i: Int, r: Rational): @function -def value_program(x: Value) -> Program: - ... +def value_program(x: Value) -> Program: ... @array_api_program_gen_ruleset.register @@ -154,13 +145,11 @@ def _value_program(i: Int, b: Boolean, f: Float, x: NDArray, v1: Value, v2: Valu @function -def tuple_value_program(x: TupleValue) -> Program: - ... +def tuple_value_program(x: TupleValue) -> Program: ... @function -def tuple_value_program_inner(x: TupleValue) -> Program: - ... +def tuple_value_program_inner(x: TupleValue) -> Program: ... @array_api_program_gen_ruleset.register @@ -173,8 +162,7 @@ def _tuple_value_program(tv1: TupleValue, tv2: TupleValue, v: Value): @function -def tuple_ndarray_program(x: TupleNDArray) -> Program: - ... +def tuple_ndarray_program(x: TupleNDArray) -> Program: ... @function @@ -197,8 +185,7 @@ def _tuple_ndarray_program(x: NDArray, l: TupleNDArray, r: TupleNDArray, i: Int) @function -def optional_dtype_program(x: OptionalDType) -> Program: - ... +def optional_dtype_program(x: OptionalDType) -> Program: ... @array_api_program_gen_ruleset.register @@ -208,8 +195,7 @@ def _optional_dtype_program(dtype: DType): @function -def optional_int_program(x: OptionalInt) -> Program: - ... +def optional_int_program(x: OptionalInt) -> Program: ... @array_api_program_gen_ruleset.register @@ -232,8 +218,7 @@ def _optional_int_slice_program(x: Int): @function -def slice_program(x: Slice) -> Program: - ... +def slice_program(x: Slice) -> Program: ... @array_api_program_gen_ruleset.register @@ -247,8 +232,7 @@ def _slice_program(start: OptionalInt, stop: OptionalInt, i: Int): @function -def multi_axis_index_key_item_program(x: MultiAxisIndexKeyItem) -> Program: - ... +def multi_axis_index_key_item_program(x: MultiAxisIndexKeyItem) -> Program: ... @array_api_program_gen_ruleset.register @@ -260,8 +244,7 @@ def _multi_axis_index_key_item_program(i: Int, s: Slice): @function -def multi_axis_index_key_program(x: MultiAxisIndexKey) -> Program: - ... +def multi_axis_index_key_program(x: MultiAxisIndexKey) -> Program: ... @array_api_program_gen_ruleset.register @@ -274,8 +257,7 @@ def _multi_axis_index_key_program(l: MultiAxisIndexKey, r: MultiAxisIndexKey, it @function -def index_key_program(x: IndexKey) -> Program: - ... +def index_key_program(x: IndexKey) -> Program: ... @array_api_program_gen_ruleset.register @@ -288,8 +270,7 @@ def _index_key_program(i: Int, s: Slice, key: MultiAxisIndexKey, a: NDArray): @function -def int_or_tuple_program(x: IntOrTuple) -> Program: - ... +def int_or_tuple_program(x: IntOrTuple) -> Program: ... @array_api_program_gen_ruleset.register @@ -299,8 +280,7 @@ def _int_or_tuple_program(x: Int, t: TupleInt): @function -def optional_int_or_tuple_program(x: OptionalIntOrTuple) -> Program: - ... +def optional_int_or_tuple_program(x: OptionalIntOrTuple) -> Program: ... @array_api_program_gen_ruleset.register diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index 7f10c5f4..97814512 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -359,7 +359,7 @@ def __call__(self, *args: object) -> RuntimeExpr | None: def __getattr__(self, name: str) -> RuntimeClassMethod | RuntimeClass: # Special case so when get_type_annotations proccessed it can work - if name in {"__origin__"}: + if name == "__origin__": return RuntimeClass(self.__egg_decls__.update_other, self.__egg_tp__.name) return RuntimeClassMethod(self.__egg_decls__, class_to_ref(self), name) diff --git a/python/tests/test_convert.py b/python/tests/test_convert.py index e4c53610..a2e651c2 100644 --- a/python/tests/test_convert.py +++ b/python/tests/test_convert.py @@ -11,8 +11,7 @@ class MyType(metaclass=MyMeta): EGraph() class MyTypeExpr(Expr): - def __init__(self) -> None: - ... + def __init__(self) -> None: ... converter(MyMeta, MyTypeExpr, lambda x: MyTypeExpr()) assert expr_parts(convert(MyType(), MyTypeExpr)) == expr_parts(MyTypeExpr()) @@ -25,8 +24,7 @@ class MyType: pass class MyTypeExpr(Expr): - def __init__(self) -> None: - ... + def __init__(self) -> None: ... converter(MyType, MyTypeExpr, lambda x: MyTypeExpr()) @@ -40,12 +38,10 @@ class MyType: pass class MyTypeExpr(Expr): - def __init__(self) -> None: - ... + def __init__(self) -> None: ... class MyTypeExpr2(Expr): - def __init__(self) -> None: - ... + def __init__(self) -> None: ... converter(MyType, MyTypeExpr, lambda x: MyTypeExpr()) converter(MyTypeExpr, MyTypeExpr2, lambda x: MyTypeExpr2()) @@ -60,12 +56,10 @@ class MyType: pass class MyTypeExpr(Expr): - def __init__(self) -> None: - ... + def __init__(self) -> None: ... class MyTypeExpr2(Expr): - def __init__(self) -> None: - ... + def __init__(self) -> None: ... converter(MyTypeExpr, MyTypeExpr2, lambda x: MyTypeExpr2()) converter(MyType, MyTypeExpr, lambda x: MyTypeExpr()) @@ -79,12 +73,10 @@ class MyType: pass class MyTypeExpr(Expr): - def __init__(self) -> None: - ... + def __init__(self) -> None: ... class MyTypeExpr2(Expr): - def __init__(self) -> None: - ... + def __init__(self) -> None: ... converter(MyType, MyTypeExpr, lambda x: MyTypeExpr()) converter(MyTypeExpr, MyTypeExpr2, lambda x: MyTypeExpr2()) diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 4f5a2c24..90e32863 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -39,18 +39,14 @@ def test_eqsat_basic(): egraph = EGraph() class Math(Expr): - def __init__(self, value: i64Like) -> None: - ... + def __init__(self, value: i64Like) -> None: ... @classmethod - def var(cls, v: StringLike) -> Math: - ... + def var(cls, v: StringLike) -> Math: ... - def __add__(self, other: Math) -> Math: - ... + def __add__(self, other: Math) -> Math: ... - def __mul__(self, other: Math) -> Math: - ... + def __mul__(self, other: Math) -> Math: ... # expr1 = 2 * (x + 3) expr1 = egraph.let("expr1", Math(2) * (Math.var("x") + Math(3))) @@ -77,8 +73,7 @@ def test_fib(): egraph = EGraph() @function - def fib(x: i64Like) -> i64: - ... + def fib(x: i64Like) -> i64: ... f0, f1, x = vars_("f0 f1 x", i64) egraph.register( @@ -97,15 +92,12 @@ def test_fib_demand(): egraph = EGraph() class Num(Expr): - def __init__(self, i: i64Like) -> None: - ... + def __init__(self, i: i64Like) -> None: ... - def __add__(self, other: Num) -> Num: - ... + def __add__(self, other: Num) -> Num: ... @function(cost=20) - def fib(x: i64Like) -> Num: - ... + def fib(x: i64Like) -> Num: ... @egraph.register def _fib(a: i64, b: i64): @@ -124,8 +116,7 @@ def test_push_pop(): egraph = EGraph() @function(merge=lambda old, new: old.max(new)) - def foo() -> i64: - ... + def foo() -> i64: ... egraph.register(set_(foo()).to(i64(1))) egraph.check(eq(foo()).to(i64(1))) @@ -216,8 +207,7 @@ def test_keyword_args(): EGraph() @function - def foo(x: i64Like, y: i64Like) -> i64: - ... + def foo(x: i64Like, y: i64Like) -> i64: ... pos = expr_parts(foo(i64(1), i64(2))) assert expr_parts(foo(i64(1), y=i64(2))) == pos @@ -228,8 +218,7 @@ def test_keyword_args_init(): EGraph() class Foo(Expr): - def __init__(self, x: i64Like) -> None: - ... + def __init__(self, x: i64Like) -> None: ... assert expr_parts(Foo(1)) == expr_parts(Foo(x=1)) @@ -246,14 +235,12 @@ class Numeric(Expr): @m2.class_ class OtherNumeric(Expr): @m2.method(cost=10) - def __init__(self, v: i64Like) -> None: - ... + def __init__(self, v: i64Like) -> None: ... egraph = EGraph([m, m2]) @function - def from_numeric(n: Numeric) -> OtherNumeric: - ... + def from_numeric(n: Numeric) -> OtherNumeric: ... egraph.register(rewrite(OtherNumeric(1)).to(from_numeric(Numeric.ONE))) assert expr_parts(egraph.simplify(OtherNumeric(i64(1)), 10)) == expr_parts(from_numeric(Numeric.ONE)) @@ -263,12 +250,10 @@ def test_property(): egraph = EGraph() class Foo(Expr): - def __init__(self) -> None: - ... + def __init__(self) -> None: ... @property - def bar(self) -> i64: - ... + def bar(self) -> i64: ... egraph.register(set_(Foo().bar).to(i64(1))) egraph.check(eq(Foo().bar).to(i64(1))) @@ -278,8 +263,7 @@ def test_default_args(): EGraph() @function - def foo(x: i64Like, y: i64Like = i64(1)) -> i64: - ... + def foo(x: i64Like, y: i64Like = i64(1)) -> i64: ... assert expr_parts(foo(i64(1))) == expr_parts(foo(i64(1), i64(1))) @@ -348,8 +332,7 @@ def test_custom_equality(): egraph = EGraph() class Boolean(Expr): - def __init__(self, value: BoolLike) -> None: - ... + def __init__(self, value: BoolLike) -> None: ... def __eq__(self, other: Boolean) -> Boolean: # type: ignore[override] ... @@ -373,11 +356,9 @@ def test_setitem_defaults(self): EGraph() class Foo(Expr): - def __init__(self) -> None: - ... + def __init__(self) -> None: ... - def __setitem__(self, key: i64Like, value: i64Like) -> None: - ... + def __setitem__(self, key: i64Like, value: i64Like) -> None: ... foo = Foo() foo[10] = 20 @@ -391,15 +372,12 @@ def test_function(self): egraph = EGraph() class Math(Expr): - def __init__(self, i: i64Like) -> None: - ... + def __init__(self, i: i64Like) -> None: ... - def __add__(self, other: Math) -> Math: - ... + def __add__(self, other: Math) -> Math: ... @function(mutates_first_arg=True) - def incr(x: Math) -> None: - ... + def incr(x: Math) -> None: ... x = Math(i64(10)) x_copied = copy(x) @@ -436,14 +414,11 @@ def test_reflected_binary_method(): EGraph() class Math(Expr): - def __init__(self, value: i64Like) -> None: - ... + def __init__(self, value: i64Like) -> None: ... - def __add__(self, other: Math) -> Math: - ... + def __add__(self, other: Math) -> Math: ... - def __radd__(self, other: Math) -> Math: - ... + def __radd__(self, other: Math) -> Math: ... converter(i64, Math, Math) @@ -460,22 +435,17 @@ def test_upcast_args(): EGraph() class Int(Expr): - def __init__(self, value: i64Like) -> None: - ... + def __init__(self, value: i64Like) -> None: ... - def __add__(self, other: Int) -> Int: - ... + def __add__(self, other: Int) -> Int: ... class Float(Expr): - def __init__(self, value: f64Like) -> None: - ... + def __init__(self, value: f64Like) -> None: ... - def __add__(self, other: Float) -> Float: - ... + def __add__(self, other: Float) -> Float: ... @classmethod - def from_int(cls, other: Int) -> Float: - ... + def from_int(cls, other: Int) -> Float: ... converter(i64, Int, Int) converter(f64, Float, Float) @@ -496,8 +466,7 @@ def test_function_default_upcasts(): EGraph() @function - def f(x: i64Like) -> i64: - ... + def f(x: i64Like) -> i64: ... assert expr_parts(f(1)) == expr_parts(f(i64(1))) @@ -508,30 +477,23 @@ def test_upcast_self_lower_cost(): EGraph() class Int(Expr): - def __init__(self, name: StringLike) -> None: - ... + def __init__(self, name: StringLike) -> None: ... - def __add__(self, other: Int) -> Int: - ... + def __add__(self, other: Int) -> Int: ... NDArrayLike = Union[Int, "NDArray"] class NDArray(Expr): - def __init__(self, name: StringLike) -> None: - ... + def __init__(self, name: StringLike) -> None: ... - def __add__(self, other: NDArrayLike) -> NDArray: - ... + def __add__(self, other: NDArrayLike) -> NDArray: ... - def __radd__(self, other: NDArrayLike) -> NDArray: - ... + def __radd__(self, other: NDArrayLike) -> NDArray: ... - def to_int(self) -> Int: - ... + def to_int(self) -> Int: ... @classmethod - def from_int(cls, other: Int) -> NDArray: - ... + def from_int(cls, other: Int) -> NDArray: ... converter(Int, NDArray, NDArray.from_int) converter(NDArray, Int, lambda a: a.to_int(), 100) @@ -588,14 +550,11 @@ def _locals_make_tuple(x): def test_lazy_types(): class A(Expr): - def __init__(self) -> None: - ... + def __init__(self) -> None: ... - def b(self) -> B: - ... + def b(self) -> B: ... - class B(Expr): - ... + class B(Expr): ... simplify(A().b()) @@ -605,22 +564,19 @@ def test_functions_seperate_pop(): egraph = EGraph() class T(Expr): - def __init__(self, x: i64Like) -> None: - ... + def __init__(self, x: i64Like) -> None: ... with egraph: @function - def f(x: T) -> T: - ... + def f(x: T) -> T: ... egraph.register(f(T(1))) with egraph: @function - def f(x: T, y: T) -> T: - ... + def f(x: T, y: T) -> T: ... egraph.register(f(T(1), T(2))) # type: ignore[call-arg] @@ -628,12 +584,10 @@ def f(x: T, y: T) -> T: # https://github.com/egraphs-good/egglog/issues/113 def test_multiple_generics(): @function - def f() -> Vec[i64]: - ... + def f() -> Vec[i64]: ... @function - def g() -> Vec[String]: - ... + def g() -> Vec[String]: ... egraph = EGraph() diff --git a/python/tests/test_program_gen.py b/python/tests/test_program_gen.py index 7c6e8e78..e3c156f6 100644 --- a/python/tests/test_program_gen.py +++ b/python/tests/test_program_gen.py @@ -8,31 +8,24 @@ class Math(Expr): - def __init__(self, value: i64Like) -> None: - ... + def __init__(self, value: i64Like) -> None: ... @classmethod - def var(cls, v: StringLike) -> Math: - ... + def var(cls, v: StringLike) -> Math: ... - def __add__(self, other: Math) -> Math: - ... + def __add__(self, other: Math) -> Math: ... - def __mul__(self, other: Math) -> Math: - ... + def __mul__(self, other: Math) -> Math: ... - def __neg__(self) -> Math: - ... + def __neg__(self) -> Math: ... @method(cost=1000) # type: ignore[misc] @property - def program(self) -> Program: - ... + def program(self) -> Program: ... @function -def assume_pos(x: Math) -> Math: - ... +def assume_pos(x: Math) -> Math: ... @ruleset