diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index c252a7c18f5c9..2e378d478c100 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -27,6 +27,13 @@ Changes impacting all modules Changelog --------- +:mod:`sklearn.base` +................... + +- |Fix| When `set_output(transform="pandas")`, :class:`base.TransformerMixin` maintains + the index if the :term:`transform` output is already a DataFrame. :pr:`25747` by + `Thomas Fan`_. + :mod:`sklearn.calibration` .......................... diff --git a/sklearn/utils/_set_output.py b/sklearn/utils/_set_output.py index 335773c6af96c..0a07ee77b9fc1 100644 --- a/sklearn/utils/_set_output.py +++ b/sklearn/utils/_set_output.py @@ -34,7 +34,7 @@ def _wrap_in_pandas_container( `range(n_features)`. index : array-like, default=None - Index for data. + Index for data. `index` is ignored if `data_to_wrap` is already a DataFrame. Returns ------- @@ -55,8 +55,6 @@ def _wrap_in_pandas_container( if isinstance(data_to_wrap, pd.DataFrame): if columns is not None: data_to_wrap.columns = columns - if index is not None: - data_to_wrap.index = index return data_to_wrap return pd.DataFrame(data_to_wrap, index=index, columns=columns) diff --git a/sklearn/utils/tests/test_set_output.py b/sklearn/utils/tests/test_set_output.py index ac73ca09439ff..52213d771ee44 100644 --- a/sklearn/utils/tests/test_set_output.py +++ b/sklearn/utils/tests/test_set_output.py @@ -33,7 +33,9 @@ def test__wrap_in_pandas_container_dense_update_columns_and_index(): new_df = _wrap_in_pandas_container(X_df, columns=new_columns, index=new_index) assert_array_equal(new_df.columns, new_columns) - assert_array_equal(new_df.index, new_index) + + # Index does not change when the input is a DataFrame + assert_array_equal(new_df.index, X_df.index) def test__wrap_in_pandas_container_error_validation(): @@ -260,3 +262,33 @@ class C(A, B): pass assert C().transform(None) == "B" + + +class EstimatorWithSetOutputIndex(_SetOutputMixin): + def fit(self, X, y=None): + self.n_features_in_ = X.shape[1] + return self + + def transform(self, X, y=None): + import pandas as pd + + # transform by giving output a new index. + return pd.DataFrame(X.to_numpy(), index=[f"s{i}" for i in range(X.shape[0])]) + + def get_feature_names_out(self, input_features=None): + return np.asarray([f"X{i}" for i in range(self.n_features_in_)], dtype=object) + + +def test_set_output_pandas_keep_index(): + """Check that set_output does not override index. + + Non-regression test for gh-25730. + """ + pd = pytest.importorskip("pandas") + + X = pd.DataFrame([[1, 2, 3], [4, 5, 6]], index=[0, 1]) + est = EstimatorWithSetOutputIndex().set_output(transform="pandas") + est.fit(X) + + X_trans = est.transform(X) + assert_array_equal(X_trans.index, ["s0", "s1"])