FIX TransformerMixin does not override index if transform=pandas #25747

Merged
merged 2 commits into from Mar 8, 2023
7 changes: 7 additions & 0 deletions doc/whats_new/v1.2.rst
@@ -27,6 +27,13 @@ Changes impacting all modules
 Changelog
 ---------

+:mod:`sklearn.base`
+...................
+
+- |Fix| When `set_output(transform="pandas")`, :class:`base.TransformerMixin` maintains
+  the index if the :term:`transform` output is already a DataFrame. :pr:`25747` by
+  `Thomas Fan`_.
+
 :mod:`sklearn.calibration`
 ..........................

4 changes: 1 addition & 3 deletions sklearn/utils/_set_output.py
@@ -34,7 +34,7 @@ def _wrap_in_pandas_container(
         `range(n_features)`.

     index : array-like, default=None
-        Index for data.
+        Index for data. `index` is ignored if `data_to_wrap` is already a DataFrame.

     Returns
     -------
@@ -55,8 +55,6 @@ def _wrap_in_pandas_container(
     if isinstance(data_to_wrap, pd.DataFrame):
         if columns is not None:
             data_to_wrap.columns = columns
-        if index is not None:
-            data_to_wrap.index = index
         return data_to_wrap

     return pd.DataFrame(data_to_wrap, index=index, columns=columns)
34 changes: 33 additions & 1 deletion sklearn/utils/tests/test_set_output.py
@@ -33,7 +33,9 @@ def test__wrap_in_pandas_container_dense_update_columns_and_index():

     new_df = _wrap_in_pandas_container(X_df, columns=new_columns, index=new_index)
     assert_array_equal(new_df.columns, new_columns)
-    assert_array_equal(new_df.index, new_index)
+
+    # Index does not change when the input is a DataFrame
+    assert_array_equal(new_df.index, X_df.index)


 def test__wrap_in_pandas_container_error_validation():
@@ -260,3 +262,33 @@ class C(A, B):
         pass

     assert C().transform(None) == "B"
+
+
+class EstimatorWithSetOutputIndex(_SetOutputMixin):
+    def fit(self, X, y=None):
+        self.n_features_in_ = X.shape[1]
+        return self
+
+    def transform(self, X, y=None):
+        import pandas as pd
+
+        # transform by giving output a new index.
+        return pd.DataFrame(X.to_numpy(), index=[f"s{i}" for i in range(X.shape[0])])
+
+    def get_feature_names_out(self, input_features=None):
+        return np.asarray([f"X{i}" for i in range(self.n_features_in_)], dtype=object)
+
+
+def test_set_output_pandas_keep_index():
+    """Check that set_output does not override index.
+
+    Non-regression test for gh-25730.
+    """
+    pd = pytest.importorskip("pandas")
+
+    X = pd.DataFrame([[1, 2, 3], [4, 5, 6]], index=[0, 1])
+    est = EstimatorWithSetOutputIndex().set_output(transform="pandas")
+    est.fit(X)
+
+    X_trans = est.transform(X)
+    assert_array_equal(X_trans.index, ["s0", "s1"])
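
For readers who want to see the behavior change without installing a specific scikit-learn version, here is a minimal sketch that uses only pandas. The functions `wrap_before` and `wrap_after` are hypothetical, simplified stand-ins for the DataFrame branch of `_wrap_in_pandas_container` before and after this change; they are not the library's actual code. The aggregated-output case at the end is an assumed scenario, added to show one way overriding the index could fail.

```python
import pandas as pd


def wrap_before(data_to_wrap, *, columns=None, index=None):
    """Hypothetical stand-in for the pre-change logic: `index` is always applied."""
    if isinstance(data_to_wrap, pd.DataFrame):
        if columns is not None:
            data_to_wrap.columns = columns
        if index is not None:
            data_to_wrap.index = index
        return data_to_wrap
    return pd.DataFrame(data_to_wrap, index=index, columns=columns)


def wrap_after(data_to_wrap, *, columns=None, index=None):
    """Hypothetical stand-in for the post-change logic: a DataFrame keeps its index."""
    if isinstance(data_to_wrap, pd.DataFrame):
        if columns is not None:
            data_to_wrap.columns = columns
        return data_to_wrap
    return pd.DataFrame(data_to_wrap, index=index, columns=columns)


X = pd.DataFrame([[1, 2, 3], [4, 5, 6]], index=[0, 1])
names = ["X0", "X1", "X2"]

# A transform output that already carries its own index.
out = pd.DataFrame(X.to_numpy(), index=["s0", "s1"])
overridden = wrap_before(out.copy(), columns=names, index=X.index)
kept = wrap_after(out.copy(), columns=names, index=X.index)

# Assumed scenario: an aggregating transform returns fewer rows than the input.
agg = pd.DataFrame(X.to_numpy().sum(axis=0, keepdims=True), index=["total"])
try:
    wrap_before(agg.copy(), index=X.index)
    mismatch_error = None
except ValueError as exc:
    # pandas rejects assigning an index whose length differs from the row count.
    mismatch_error = exc
agg_kept = wrap_after(agg.copy(), index=X.index)
```

With the old logic the input index replaces the transformer's own index, and a row-count mismatch raises a `ValueError`. With the new logic the DataFrame's index is left untouched in both cases, which matches what `test_set_output_pandas_keep_index` asserts.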