Skip to content

Commit

Permalink
Fix ListColumn.to_pandas() to retain list type (#15155)
Browse files Browse the repository at this point in the history
Fixes: #14568 

This PR fixes `ListColumn.to_pandas()` by calling `ArrowArray.to_pylist()` method to retain `list` type in pandas series.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Matthew Roeschke (https://github.com/mroeschke)
  - Richard (Rick) Zamora (https://github.com/rjzamora)

URL: #15155
  • Loading branch information
galipremsagar committed Mar 4, 2024
1 parent d158ccd commit c3cad1d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
18 changes: 18 additions & 0 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import pyarrow as pa
from typing_extensions import Self

Expand Down Expand Up @@ -288,6 +289,23 @@ def _transform_leaves(self, func, *args, **kwargs) -> Self:
)
return lc

def to_pandas(
self,
*,
index: Optional[pd.Index] = None,
nullable: bool = False,
) -> pd.Series:
# Can't rely on Column.to_pandas implementation for lists.
# Need to perform `to_pylist` to preserve list types.
if nullable:
raise NotImplementedError(f"{nullable=} is not implemented.")

pd_series = pd.Series(self.to_arrow().to_pylist(), dtype="object")

if index is not None:
pd_series.index = index
return pd_series


class ListMethods(ColumnMethods):
"""
Expand Down
4 changes: 3 additions & 1 deletion python/cudf/cudf/tests/test_list.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.

import functools
import operator
Expand Down Expand Up @@ -41,6 +41,8 @@ def test_create_list_series(data):
expect = pd.Series(data)
got = cudf.Series(data)
assert_eq(expect, got)
assert isinstance(got[0], type(expect[0]))
assert isinstance(got.to_pandas()[0], type(expect[0]))


@pytest.mark.parametrize(
Expand Down
6 changes: 1 addition & 5 deletions python/dask_cudf/dask_cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,13 +702,9 @@ def test_is_supported(arg, supported):

def test_groupby_unique_lists():
df = pd.DataFrame({"a": [0, 0, 0, 1, 1, 1], "b": [10, 10, 10, 7, 8, 9]})
ddf = dd.from_pandas(df, 2)
gdf = cudf.from_pandas(df)
gddf = dask_cudf.from_cudf(gdf, 2)
dd.assert_eq(
ddf.groupby("a").b.unique().compute(),
gddf.groupby("a").b.unique().compute(),
)

dd.assert_eq(
gdf.groupby("a").b.unique(),
gddf.groupby("a").b.unique().compute(),
Expand Down

0 comments on commit c3cad1d

Please sign in to comment.