Skip to content

Commit

Permalink
fix(rust, python): make some list expressions aware of groupby context (
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Feb 10, 2023
1 parent 1fe6430 commit c2f16ec
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
16 changes: 11 additions & 5 deletions polars/polars-lazy/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,12 @@ impl ListNameSpace {

/// Get items in every sublist by index.
pub fn get(self, index: Expr) -> Expr {
self.0
.map_many_private(FunctionExpr::ListExpr(ListFunction::Get), &[index], false)
self.0.apply_many_private(
FunctionExpr::ListExpr(ListFunction::Get),
&[index],
false,
false,
)
}

/// Get items in every sublist by multiple indexes.
Expand All @@ -129,10 +133,11 @@ impl ListNameSpace {
/// This behavior is more expensive than defaulting to returing an `Error`.
#[cfg(feature = "list_take")]
pub fn take(self, index: Expr, null_on_oob: bool) -> Expr {
self.0.map_many_private(
self.0.apply_many_private(
FunctionExpr::ListExpr(ListFunction::Take(null_on_oob)),
&[index],
false,
false,
)
}

Expand All @@ -152,7 +157,7 @@ impl ListNameSpace {
pub fn join(self, separator: &str) -> Expr {
let separator = separator.to_string();
self.0
.map(
.apply(
move |s| {
s.list()?
.lst_join(&separator)
Expand Down Expand Up @@ -206,10 +211,11 @@ impl ListNameSpace {

/// Slice every sublist.
pub fn slice(self, offset: Expr, length: Expr) -> Expr {
self.0.map_many_private(
self.0.apply_many_private(
FunctionExpr::ListExpr(ListFunction::Slice),
&[offset, length],
false,
false,
)
}

Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/test_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,3 +698,25 @@ def test_list_eval_all_null() -> None:
assert df.select(pl.col("bar").arr.eval(pl.element())).to_dict(False) == {
"bar": [None, None, None]
}


def test_list_function_group_awareness() -> None:
df = pl.DataFrame(
{
"a": [100, 103, 105, 106, 105, 104, 103, 106, 100, 102],
"group": [0, 0, 1, 1, 1, 1, 1, 1, 2, 2],
}
)

assert df.groupby("group").agg(
[
pl.col("a").list().arr.get(0).alias("get"),
pl.col("a").list().arr.take([0]).alias("take"),
pl.col("a").list().arr.slice(0, 3).alias("slice"),
]
).sort("group").to_dict(False) == {
"group": [0, 1, 2],
"get": [[100], [105], [100]],
"take": [[[100]], [[105]], [[100]]],
"slice": [[[100, 103]], [[105, 106, 105]], [[100, 102]]],
}

0 comments on commit c2f16ec

Please sign in to comment.