Skip to content

Commit

Permalink
Improve pandas support
Browse files Browse the repository at this point in the history
Resolves   #734.
  • Loading branch information
evhub committed Apr 29, 2023
1 parent a4514dc commit 2748093
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 37 deletions.
4 changes: 2 additions & 2 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ To allow for better use of [`numpy`](https://numpy.org/) objects in Coconut, all
- `numpy` objects are allowed seamlessly in Coconut's [implicit coefficient syntax](#implicit-function-application-and-coefficients), allowing the use of e.g. `A B**2` shorthand for `A * B**2` when `A` and `B` are `numpy` arrays (note: **not** `A @ B**2`).
- Coconut supports `@` for matrix multiplication of `numpy` arrays on all Python versions, as well as supplying the `(@)` [operator function](#operator-functions).

Additionally, Coconut provides the exact same support for [`pandas`](https://pandas.pydata.org/), [`pytorch`](https://pytorch.org/), and [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html), including using `jax.numpy` methods over `numpy` methods when given `jax` arrays.
Additionally, Coconut provides the exact same support for [`pandas`](https://pandas.pydata.org/), [`pytorch`](https://pytorch.org/), and [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html), including using `pandas`/`jax`-specific methods over `numpy` methods when given `pandas`/`jax` objects.

#### `xonsh` Support

Expand Down Expand Up @@ -1911,7 +1911,7 @@ class CanAddAndSub(typing.Protocol, typing.Generic[T, U, V]):

Coconut supports multidimensional array literal and array [concatenation](https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html)/[stack](https://numpy.org/doc/stable/reference/generated/numpy.stack.html) syntax.

By default, all multidimensional array syntax will simply operate on Python lists of lists. However, if [`numpy`](#numpy-integration) objects are used, the appropriate `numpy` calls will be made instead. To give custom objects multidimensional array concatenation support, define `type(obj).__matconcat__` (should behave as `np.concat`), `obj.ndim` (should behave as `np.ndarray.ndim`), and `obj.reshape` (should behave as `np.ndarray.reshape`).
By default, all multidimensional array syntax will simply operate on Python lists of lists (or any non-`str` `Sequence`). However, if [`numpy`](#numpy-integration) objects are used, the appropriate `numpy` calls will be made instead. To give custom objects multidimensional array concatenation support, define `type(obj).__matconcat__` (should behave as `np.concat`), `obj.ndim` (should behave as `np.ndarray.ndim`), and `obj.reshape` (should behave as `np.ndarray.reshape`).

As a simple example, 2D matrices can be constructed by separating the rows with `;;` inside of a list literal:
```coconut_pycon
Expand Down
1 change: 1 addition & 0 deletions _coconut/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ npt = _npt # Fake, like typing
zip_longest = _zip_longest

numpy_modules: _t.Any = ...
pandas_numpy_modules: _t.Any = ...
jax_numpy_modules: _t.Any = ...
tee_type: _t.Any = ...
reiterables: _t.Any = ...
Expand Down
4 changes: 3 additions & 1 deletion coconut/compiler/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
justify_len,
report_this_text,
numpy_modules,
pandas_numpy_modules,
jax_numpy_modules,
self_match_types,
is_data_var,
Expand Down Expand Up @@ -227,6 +228,7 @@ def process_header_args(which, use_hash, target, no_tco, strict, no_wrap):
comma_slash=", /" if target_info >= (3, 8) else "",
report_this_text=report_this_text,
numpy_modules=tuple_str_of(numpy_modules, add_quotes=True),
pandas_numpy_modules=tuple_str_of(pandas_numpy_modules, add_quotes=True),
jax_numpy_modules=tuple_str_of(jax_numpy_modules, add_quotes=True),
self_match_types=tuple_str_of(self_match_types),
set_super=(
Expand Down Expand Up @@ -420,7 +422,7 @@ def _coconut_matmul(a, b, **kwargs):
else:
if result is not _coconut.NotImplemented:
return result
if "numpy" in (a.__class__.__module__, b.__class__.__module__):
if "numpy" in (_coconut_get_base_module(a), _coconut_get_base_module(b)):
from numpy import matmul
return matmul(a, b)
raise _coconut.TypeError("unsupported operand type(s) for @: " + _coconut.repr(_coconut.type(a)) + " and " + _coconut.repr(_coconut.type(b)))
Expand Down
73 changes: 49 additions & 24 deletions coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class _coconut{object}:{COMMENT.EVERYTHING_HERE_MUST_BE_COPIED_TO_STUB_FILE}
else:
abc.Sequence.register(numpy.ndarray)
numpy_modules = {numpy_modules}
pandas_numpy_modules = {pandas_numpy_modules}
jax_numpy_modules = {jax_numpy_modules}
tee_type = type(itertools.tee((), 1)[0])
reiterables = abc.Sequence, abc.Mapping, abc.Set
Expand Down Expand Up @@ -78,6 +79,8 @@ class _coconut_Sentinel(_coconut_baseclass):
def __reduce__(self):
return (self.__class__, ())
_coconut_sentinel = _coconut_Sentinel()
def _coconut_get_base_module(obj):
return obj.__class__.__module__.split(".", 1)[0]
class MatchError(_coconut_baseclass, Exception):
"""Pattern-matching error. Has attributes .pattern, .value, and .message."""{COMMENT.no_slots_to_allow_setattr_below}
max_val_repr_len = 500
Expand Down Expand Up @@ -650,17 +653,21 @@ Additionally supports Cartesian products of numpy arrays."""
repeat = 1
if repeat < 0:
raise _coconut.ValueError("cartesian_product: repeat cannot be negative")
if iterables and _coconut.all(it.__class__.__module__ in _coconut.numpy_modules for it in iterables):
if _coconut.any(it.__class__.__module__ in _coconut.jax_numpy_modules for it in iterables):
from jax import numpy
else:
numpy = _coconut.numpy
iterables *= repeat
dtype = numpy.result_type(*iterables)
arr = numpy.empty([_coconut.len(a) for a in iterables] + [_coconut.len(iterables)], dtype=dtype)
for i, a in _coconut.enumerate(numpy.ix_(*iterables)):
arr[..., i] = a
return arr.reshape(-1, _coconut.len(iterables))
if iterables:
it_modules = [_coconut_get_base_module(it) for it in iterables]
if _coconut.all(mod in _coconut.numpy_modules for mod in it_modules):
if _coconut.any(mod in _coconut.pandas_numpy_modules for mod in it_modules):
iterables = tuple((it.to_numpy() if _coconut_get_base_module(it) in _coconut.pandas_numpy_modules else it) for it in iterables)
if _coconut.any(mod in _coconut.jax_numpy_modules for mod in it_modules):
from jax import numpy
else:
numpy = _coconut.numpy
iterables *= repeat
dtype = numpy.result_type(*iterables)
arr = numpy.empty([_coconut.len(a) for a in iterables] + [_coconut.len(iterables)], dtype=dtype)
for i, a in _coconut.enumerate(numpy.ix_(*iterables)):
arr[..., i] = a
return arr.reshape(-1, _coconut.len(iterables))
self = _coconut.object.__new__(cls)
self.iters = iterables
self.repeat = repeat
Expand Down Expand Up @@ -973,7 +980,7 @@ class multi_enumerate(_coconut_has_iter):
return self.__class__(self.get_new_iter())
@property
def is_numpy(self):
return self.iter.__class__.__module__ in _coconut.numpy_modules
return _coconut_get_base_module(self.iter) in _coconut.numpy_modules
def __iter__(self):
if self.is_numpy:
it = _coconut.numpy.nditer(self.iter, ["multi_index", "refs_ok"], [["readonly"]])
Expand Down Expand Up @@ -1474,11 +1481,17 @@ def fmap(func, obj, **kwargs):
else:
if result is not _coconut.NotImplemented:
return result
if obj.__class__.__module__ in _coconut.jax_numpy_modules:
obj_module = _coconut_get_base_module(obj)
if obj_module in _coconut.jax_numpy_modules:
import jax.numpy as jnp
return jnp.vectorize(func)(obj)
if obj.__class__.__module__ in _coconut.numpy_modules:
return _coconut.numpy.vectorize(func)(obj)
if obj_module in _coconut.numpy_modules:
got = _coconut.numpy.vectorize(func)(obj)
if obj_module in _coconut.pandas_numpy_modules:
new_obj = obj.copy()
new_obj[:] = got
return new_obj
return got
obj_aiter = _coconut.getattr(obj, "__aiter__", None)
if obj_aiter is not None and _coconut_amap is not None:
try:
Expand Down Expand Up @@ -1744,7 +1757,10 @@ def all_equal(iterable):

Supports numpy arrays. Assumes transitivity and 'x != y' being equivalent to 'not (x == y)'.
"""
if iterable.__class__.__module__ in _coconut.numpy_modules:
iterable_module = _coconut_get_base_module(iterable)
if iterable_module in _coconut.numpy_modules:
if iterable_module in _coconut.pandas_numpy_modules:
iterable = iterable.to_numpy()
return not _coconut.len(iterable) or (iterable == iterable[0]).all()
first_item = _coconut_sentinel
for item in iterable:
Expand Down Expand Up @@ -1787,40 +1803,49 @@ def _coconut_mk_anon_namedtuple(fields, types=None, of_kwargs=None):
return NT
return NT(**of_kwargs)
def _coconut_ndim(arr):
if (arr.__class__.__module__ in _coconut.numpy_modules or _coconut.hasattr(arr.__class__, "__matconcat__")) and _coconut.hasattr(arr, "ndim"):
if (_coconut_get_base_module(arr) in _coconut.numpy_modules or _coconut.hasattr(arr.__class__, "__matconcat__")) and _coconut.hasattr(arr, "ndim"):
return arr.ndim
if not _coconut.isinstance(arr, _coconut.abc.Sequence):
if not _coconut.isinstance(arr, _coconut.abc.Sequence) or _coconut.isinstance(arr, (_coconut.str, _coconut.bytes)):
return 0
if _coconut.len(arr) == 0:
return 1
arr_dim = 1
inner_arr = arr[0]
if inner_arr == arr:
return 0
while _coconut.isinstance(inner_arr, _coconut.abc.Sequence):
arr_dim += 1
if _coconut.len(inner_arr) < 1:
break
inner_arr = inner_arr[0]
new_inner_arr = inner_arr[0]
if new_inner_arr == inner_arr:
break
inner_arr = new_inner_arr
return arr_dim
def _coconut_expand_arr(arr, new_dims):
if (arr.__class__.__module__ in _coconut.numpy_modules or _coconut.hasattr(arr.__class__, "__matconcat__")) and _coconut.hasattr(arr, "reshape"):
if (_coconut_get_base_module(arr) in _coconut.numpy_modules or _coconut.hasattr(arr.__class__, "__matconcat__")) and _coconut.hasattr(arr, "reshape"):
return arr.reshape((1,) * new_dims + arr.shape)
for _ in _coconut.range(new_dims):
arr = [arr]
return arr
def _coconut_concatenate(arrs, axis):
matconcat = None
for a in arrs:
if a.__class__.__module__ in _coconut.jax_numpy_modules:
a_module = _coconut_get_base_module(a)
if a_module in _coconut.pandas_numpy_modules:
from pandas import concat as matconcat
break
if a_module in _coconut.jax_numpy_modules:
from jax.numpy import concatenate as matconcat
break
if a.__class__.__module__ in _coconut.numpy_modules:
if a_module in _coconut.numpy_modules:
matconcat = _coconut.numpy.concatenate
break
if _coconut.hasattr(a.__class__, "__matconcat__"):
matconcat = a.__class__.__matconcat__
break
if matconcat is not None:
return matconcat(arrs, axis)
return matconcat(arrs, axis=axis)
if not axis:
return _coconut.list(_coconut.itertools.chain.from_iterable(arrs))
return [_coconut_concatenate(rows, axis - 1) for rows in _coconut.zip(*arrs)]
Expand All @@ -1833,7 +1858,7 @@ def _coconut_multi_dim_arr(arrs, dim):
def _coconut_call_or_coefficient(func, *args):
if _coconut.callable(func):
return func(*args)
if not _coconut.isinstance(func, (_coconut.int, _coconut.float, _coconut.complex)) and func.__class__.__module__ not in _coconut.numpy_modules:
if not _coconut.isinstance(func, (_coconut.int, _coconut.float, _coconut.complex)) and _coconut_get_base_module(func) not in _coconut.numpy_modules:
raise _coconut.TypeError("implicit function application and coefficient syntax only supported for Callable, int, float, complex, and numpy objects")
func = func
for x in args:
Expand Down
26 changes: 17 additions & 9 deletions coconut/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,19 @@ def get_bool_env_var(env_var, default=False):
sys.setrecursionlimit(default_recursion_limit)

# modules that numpy-like arrays can live in
pandas_numpy_modules = (
"pandas",
)
jax_numpy_modules = (
"jaxlib.xla_extension",
"jaxlib",
)
numpy_modules = (
"numpy",
"pandas",
"torch",
) + jax_numpy_modules
) + (
pandas_numpy_modules
+ jax_numpy_modules
)

legal_indent_chars = " \t" # the only Python-legal indent chars

Expand Down Expand Up @@ -828,8 +833,8 @@ def get_bool_env_var(env_var, default=False):
("jupyter-client", "py<35"),
("jupyter-client", "py==35"),
("jupyter-client", "py36"),
("jedi", "py<37"),
("jedi", "py37"),
("jedi", "py<39"),
("jedi", "py39"),
("pywinpty", "py2;windows"),
),
"jupyter": (
Expand Down Expand Up @@ -879,6 +884,7 @@ def get_bool_env_var(env_var, default=False):
"pexpect",
("numpy", "py34"),
("numpy", "py2;cpy"),
("pandas", "py36"),
),
}

Expand All @@ -905,14 +911,15 @@ def get_bool_env_var(env_var, default=False):
"mypy[python2]": (1, 1),
("jupyter-console", "py37"): (6,),
("typing", "py<35"): (3, 10),
("jedi", "py37"): (0, 18),
("typing_extensions", "py37"): (4, 4),
("ipython", "py39"): (8,),
("ipykernel", "py39"): (6,),
("jedi", "py39"): (0, 18),

# pinned reqs: (must be added to pinned_reqs below)

# don't upgrade this; it breaks on Python 3.6
("pandas", "py36"): (1,),
("jupyter-client", "py36"): (7, 1, 2),
("typing_extensions", "py==36"): (4, 1),
# don't upgrade these; they break on Python 3.5
Expand Down Expand Up @@ -942,13 +949,14 @@ def get_bool_env_var(env_var, default=False):
"watchdog": (0, 10),
"papermill": (1, 2),
# don't upgrade this; it breaks with old IPython versions
("jedi", "py<37"): (0, 17),
("jedi", "py<39"): (0, 17),
# Coconut requires pyparsing 2
"pyparsing": (2, 4, 7),
}

# should match the reqs with comments above
pinned_reqs = (
("pandas", "py36"),
("jupyter-client", "py36"),
("typing_extensions", "py==36"),
("jupyter-client", "py<35"),
Expand All @@ -971,7 +979,7 @@ def get_bool_env_var(env_var, default=False):
("prompt_toolkit", "mark2"),
"watchdog",
"papermill",
("jedi", "py<37"),
("jedi", "py<39"),
"pyparsing",
)

Expand All @@ -984,7 +992,7 @@ def get_bool_env_var(env_var, default=False):
"pyparsing": _,
"cPyparsing": (_, _, _),
("prompt_toolkit", "mark2"): _,
("jedi", "py<37"): _,
("jedi", "py<39"): _,
("pywinpty", "py2;windows"): _,
("ipython", "py3;py<39"): _,
}
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "3.0.0"
VERSION_NAME = None
# False for release, int >= 1 for develop
DEVELOP = 36
DEVELOP = 37
ALPHA = True # for pre releases rather than post releases

# -----------------------------------------------------------------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions coconut/tests/src/cocotest/agnostic/primary.coco
Original file line number Diff line number Diff line change
Expand Up @@ -1585,4 +1585,6 @@ def primary_test() -> bool:
a_dict = {"a": 1, "b": 2}
a_dict |= {"a": 10, "c": 20}
assert a_dict == {"a": 10, "b": 2, "c": 20} == {"a": 1, "b": 2} | {"a": 10, "c": 20}
assert ["abc" ; "def"] == ['abc', 'def']
assert ["abc" ;; "def"] == [['abc'], ['def']]
return True
30 changes: 30 additions & 0 deletions coconut/tests/src/extras.coco
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ from coconut.constants import (
PY2,
PY34,
PY35,
PY36,
WINDOWS,
PYPY,
) # type: ignore
Expand Down Expand Up @@ -464,9 +465,38 @@ def test_numpy() -> bool:
return True


def test_pandas() -> bool:
import pandas as pd
import numpy as np
d1 = pd.DataFrame({"nums": [1, 2, 3], "chars": ["a", "b", "c"]})
assert [d1; d1].keys() |> list == ["nums", "chars"] * 2 # type: ignore
assert [d1;; d1].itertuples() |> list == [(0, 1, 'a'), (1, 2, 'b'), (2, 3, 'c'), (0, 1, 'a'), (1, 2, 'b'), (2, 3, 'c')] # type: ignore
d2 = pd.DataFrame({"a": range(3) |> list, "b": range(1, 4) |> list})
new_d2 = d2 |> fmap$(.+1)
assert new_d2["a"] |> list == range(1, 4) |> list
assert new_d2["b"] |> list == range(2, 5) |> list
assert multi_enumerate(d1) |> list == [((0, 0), 1), ((1, 0), 2), ((2, 0), 3), ((0, 1), 'a'), ((1, 1), 'b'), ((2, 1), 'c')]
assert not all_equal(d1)
assert not all_equal(d2)
assert cartesian_product(d1["nums"], d1["chars"]) `np.array_equal` np.array([
1; 'a';;
1; 'b';;
1; 'c';;
2; 'a';;
2; 'b';;
2; 'c';;
3; 'a';;
3; 'b';;
3; 'c';;
], dtype=object)
return True


def test_extras() -> bool:
if not PYPY and (PY2 or PY34):
assert test_numpy() is True
if not PYPY and PY36:
assert test_pandas() is True
if CoconutKernel is not None:
assert test_kernel() is True
assert test_setup_none() is True
Expand Down

0 comments on commit 2748093

Please sign in to comment.