Skip to content

Commit

Permalink
FIX TransformerMixin does not override index if transform=pandas (#25747)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasjpfan authored and jeremiedbb committed Mar 8, 2023
1 parent 067f78e commit 0ea5793
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 4 deletions.
7 changes: 7 additions & 0 deletions doc/whats_new/v1.2.rst
Expand Up @@ -12,6 +12,13 @@ Version 1.2.2
Changelog
---------

:mod:`sklearn.base`
...................

- |Fix| When `set_output(transform="pandas")` is used, :class:`base.TransformerMixin`
  preserves the index of the :term:`transform` output if that output is already a
  DataFrame. :pr:`25747` by `Thomas Fan`_.

:mod:`sklearn.calibration`
..........................

Expand Down
4 changes: 1 addition & 3 deletions sklearn/utils/_set_output.py
Expand Up @@ -34,7 +34,7 @@ def _wrap_in_pandas_container(
`range(n_features)`.
index : array-like, default=None
Index for data.
Index for data. `index` is ignored if `data_to_wrap` is already a DataFrame.
Returns
-------
Expand All @@ -55,8 +55,6 @@ def _wrap_in_pandas_container(
if isinstance(data_to_wrap, pd.DataFrame):
if columns is not None:
data_to_wrap.columns = columns
if index is not None:
data_to_wrap.index = index
return data_to_wrap

return pd.DataFrame(data_to_wrap, index=index, columns=columns)
Expand Down
34 changes: 33 additions & 1 deletion sklearn/utils/tests/test_set_output.py
Expand Up @@ -33,7 +33,9 @@ def test__wrap_in_pandas_container_dense_update_columns_and_index():

new_df = _wrap_in_pandas_container(X_df, columns=new_columns, index=new_index)
assert_array_equal(new_df.columns, new_columns)
assert_array_equal(new_df.index, new_index)

# Index does not change when the input is a DataFrame
assert_array_equal(new_df.index, X_df.index)


def test__wrap_in_pandas_container_error_validation():
Expand Down Expand Up @@ -260,3 +262,33 @@ class C(A, B):
pass

assert C().transform(None) == "B"


class EstimatorWithSetOutputIndex(_SetOutputMixin):
    """Toy estimator whose ``transform`` returns a DataFrame with its own index."""

    def fit(self, X, y=None):
        """Store the number of columns of ``X`` and return ``self``."""
        n_columns = X.shape[1]
        self.n_features_in_ = n_columns
        return self

    def transform(self, X, y=None):
        """Return the values of ``X`` in a new DataFrame indexed ``s0``, ``s1``, ..."""
        import pandas as pd

        # The output deliberately carries a fresh string index rather than
        # reusing the index of the input.
        values = X.to_numpy()
        row_labels = [f"s{row}" for row in range(X.shape[0])]
        return pd.DataFrame(values, index=row_labels)

    def get_feature_names_out(self, input_features=None):
        """Return ``X0``, ``X1``, ... as an object array, one name per fitted column."""
        names = [f"X{col}" for col in range(self.n_features_in_)]
        return np.asarray(names, dtype=object)


def test_set_output_pandas_keep_index():
    """Ensure the index returned by ``transform`` is kept with pandas output.

    Non-regression test for gh-25730.
    """
    pd = pytest.importorskip("pandas")

    frame = pd.DataFrame([[1, 2, 3], [4, 5, 6]], index=[0, 1])
    transformer = EstimatorWithSetOutputIndex().set_output(transform="pandas")
    transformer.fit(frame)

    # The estimator labels its output rows "s0", "s1"; the input index
    # [0, 1] must not replace them.
    expected_index = ["s0", "s1"]
    transformed = transformer.transform(frame)
    assert_array_equal(transformed.index, expected_index)

0 comments on commit 0ea5793

Please sign in to comment.