Skip to content

Commit

Permalink
Improve fmap for pandas
Browse files Browse the repository at this point in the history
Resolves   #734.
  • Loading branch information
evhub committed Apr 30, 2023
1 parent 23d4959 commit b81719e
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 11 deletions.
4 changes: 3 additions & 1 deletion DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3168,6 +3168,8 @@ For `dict`, or any other `collections.abc.Mapping`, `fmap` will map over the map

For [`numpy`](#numpy-integration) objects, `fmap` will use [`np.vectorize`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html) to produce the result.

For [`pandas`](https://pandas.pydata.org/) objects, `fmap` will use [`.apply`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.apply.html) along the last axis (so row-wise for `DataFrame`'s, element-wise for `Series`'s).

For asynchronous iterables, `fmap` will map asynchronously, making `fmap` equivalent in that case to
```coconut_python
async def fmap_over_async_iters(func, async_iter):
Expand Down Expand Up @@ -3198,7 +3200,7 @@ _Can't be done without a series of method definitions for each data type. See th

**call**(_func_, /, *_args_, \*\*_kwargs_)

Coconut's `call` simply implements function application. Thus, `call` is equivalent to
Coconut's `call` simply implements function application. Thus, `call` is effectively equivalent to
```coconut
def call(f, /, *args, **kwargs) = f(*args, **kwargs)
```
Expand Down
11 changes: 5 additions & 6 deletions coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -1486,16 +1486,15 @@ def fmap(func, obj, **kwargs):
if result is not _coconut.NotImplemented:
return result
obj_module = _coconut_get_base_module(obj)
if obj_module in _coconut.pandas_numpy_modules:
if obj.ndim <= 1:
return obj.apply(func)
return obj.apply(func, axis=obj.ndim-1)
if obj_module in _coconut.jax_numpy_modules:
import jax.numpy as jnp
return jnp.vectorize(func)(obj)
if obj_module in _coconut.numpy_modules:
got = _coconut.numpy.vectorize(func)(obj)
if obj_module in _coconut.pandas_numpy_modules:
new_obj = obj.copy()
new_obj[:] = got
return new_obj
return got
return _coconut.numpy.vectorize(func)(obj)
obj_aiter = _coconut.getattr(obj, "__aiter__", None)
if obj_aiter is not None and _coconut_amap is not None:
try:
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "3.0.0"
VERSION_NAME = None
# False for release, int >= 1 for develop
DEVELOP = 39
DEVELOP = 40
ALPHA = True # for pre releases rather than post releases

# -----------------------------------------------------------------------------------------------------------------------
Expand Down
8 changes: 5 additions & 3 deletions coconut/tests/src/extras.coco
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,9 @@ def test_pandas() -> bool:
assert [d1; d1].keys() |> list == ["nums", "chars"] * 2 # type: ignore
assert [d1;; d1].itertuples() |> list == [(0, 1, 'a'), (1, 2, 'b'), (2, 3, 'c'), (0, 1, 'a'), (1, 2, 'b'), (2, 3, 'c')] # type: ignore
d2 = pd.DataFrame({"a": range(3) |> list, "b": range(1, 4) |> list})
new_d2 = d2 |> fmap$(.+1)
assert new_d2["a"] |> list == range(1, 4) |> list
assert new_d2["b"] |> list == range(2, 5) |> list
d3 = d2 |> fmap$(fmap$(.+1))
assert d3["a"] |> list == range(1, 4) |> list
assert d3["b"] |> list == range(2, 5) |> list
assert multi_enumerate(d1) |> list == [((0, 0), 1), ((1, 0), 2), ((2, 0), 3), ((0, 1), 'a'), ((1, 1), 'b'), ((2, 1), 'c')]
assert not all_equal(d1)
assert not all_equal(d2)
Expand All @@ -489,6 +489,8 @@ def test_pandas() -> bool:
3; 'b';;
3; 'c';;
], dtype=object) # type: ignore
d4 = d1 |> fmap$(def r -> r["nums2"] = r["nums"]*2; r)
assert (d4["nums"] * 2 == d4["nums2"]).all()
return True


Expand Down

0 comments on commit b81719e

Please sign in to comment.