Skip to content

Commit

Permalink
ENH Introduces the __sklearn_clone__ protocol (#24568)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasjpfan committed Jan 19, 2023
1 parent 4b55dee commit 3f82f84
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 1 deletion.
25 changes: 25 additions & 0 deletions doc/developers/develop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,31 @@ Objects that do not provide this method will be deep-copied
(using the Python standard function ``copy.deepcopy``)
if ``safe=False`` is passed to ``clone``.

Estimators can customize the behavior of :func:`base.clone` by defining a
`__sklearn_clone__` method. `__sklearn_clone__` must return an instance of the
estimator. `__sklearn_clone__` is useful when an estimator needs to hold on to
some state when :func:`base.clone` is called on the estimator. For example, a
frozen meta-estimator for transformers can be defined as follows::

class FrozenTransformer(BaseEstimator):
def __init__(self, fitted_transformer):
self.fitted_transformer = fitted_transformer

def __getattr__(self, name):
# `fitted_transformer`'s attributes are now accessible
return getattr(self.fitted_transformer, name)

def __sklearn_clone__(self):
return self

def fit(self, X, y):
# Fitting does not change the state of the estimator
return self

def fit_transform(self, X, y=None):
# fit_transform only transforms the data
        return self.fitted_transformer.transform(X)

Pipeline compatibility
----------------------
For an estimator to be usable together with ``pipeline.Pipeline`` in any but the
Expand Down
6 changes: 6 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ Changelog
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
where 123456 is the *pull request* number, not the issue number.
:mod:`sklearn.base`
...................

- |Feature| A `__sklearn_clone__` protocol is now available to override the
default behavior of :func:`base.clone`. :pr:`24568` by `Thomas Fan`_.

:mod:`sklearn.cluster`
......................

Expand Down
17 changes: 16 additions & 1 deletion sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,18 @@ def clone(estimator, *, safe=True):
without actually copying attached data. It returns a new estimator
with the same parameters that has not been fitted on any data.
.. versionchanged:: 1.3
Delegates to `estimator.__sklearn_clone__` if the method exists.
Parameters
----------
estimator : {list, tuple, set} of estimator instance or a single \
estimator instance
The estimator or group of estimators to be cloned.
safe : bool, default=True
If safe is False, clone will fall back to a deep copy on objects
that are not estimators.
that are not estimators. Ignored if `estimator.__sklearn_clone__`
exists.
Returns
-------
Expand All @@ -62,6 +66,14 @@ def clone(estimator, *, safe=True):
return different results from the original estimator. More details can be
found in :ref:`randomness`.
"""
if hasattr(estimator, "__sklearn_clone__") and not inspect.isclass(estimator):
return estimator.__sklearn_clone__()
return _clone_parametrized(estimator, safe=safe)


def _clone_parametrized(estimator, *, safe=True):
"""Default implementation of clone. See :func:`sklearn.base.clone` for details."""

estimator_type = type(estimator)
# XXX: not handling dictionaries
if estimator_type in (list, tuple, set, frozenset):
Expand Down Expand Up @@ -219,6 +231,9 @@ def set_params(self, **params):

return self

def __sklearn_clone__(self):
    # Default implementation of the ``__sklearn_clone__`` protocol hook
    # consulted by :func:`base.clone`: delegate to the module-level
    # parametrized-clone helper. Estimators override this method to
    # customize cloning (e.g. a "frozen" estimator returning ``self``).
    return _clone_parametrized(self)

def __repr__(self, N_CHAR_MAX=700):
# N_CHAR_MAX is the (approximate) maximum number of non-blank
# characters to render. We pass it as an optional parameter to ease
Expand Down
46 changes: 46 additions & 0 deletions sklearn/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import scipy.sparse as sp
import pytest
import warnings
from numpy.testing import assert_allclose

import sklearn
from sklearn.utils._testing import assert_array_equal
Expand All @@ -17,6 +18,7 @@
from sklearn.preprocessing import StandardScaler
from sklearn.utils._set_output import _get_output_config
from sklearn.pipeline import Pipeline
from sklearn.decomposition import PCA
from sklearn.model_selection import GridSearchCV

from sklearn.tree import DecisionTreeClassifier
Expand Down Expand Up @@ -362,6 +364,50 @@ def transform(self, X):
assert e.scalar_param == cloned_e.scalar_param


def test_clone_protocol():
    """`clone` must honor an estimator's `__sklearn_clone__` override."""

    class FrozenEstimator(BaseEstimator):
        """Wrapper whose fitted state survives both `fit` and `clone`."""

        def __init__(self, fitted_estimator):
            self.fitted_estimator = fitted_estimator

        def __getattr__(self, name):
            # Delegate attribute access to the wrapped, fitted estimator.
            return getattr(self.fitted_estimator, name)

        def __sklearn_clone__(self):
            # Cloning a frozen estimator is a no-op.
            return self

        def fit(self, *args, **kwargs):
            # Refitting never mutates the wrapped estimator.
            return self

        def fit_transform(self, *args, **kwargs):
            # Only transform the data; never refit the wrapped estimator.
            return self.fitted_estimator.transform(*args, **kwargs)

    X_train = np.array([[-1, -1], [-2, -1], [-3, -2]])
    pca = PCA().fit(X_train)
    expected_components = pca.components_

    frozen_pca = FrozenEstimator(pca)
    assert_allclose(frozen_pca.components_, expected_components)

    # Attribute delegation also exposes PCA methods such as
    # `get_feature_names_out`.
    assert_array_equal(frozen_pca.get_feature_names_out(), pca.get_feature_names_out())

    # Neither refitting on new data nor `fit_transform` alters `components_`.
    X_new = np.asarray([[-1, 2], [3, 4], [1, 2]])
    frozen_pca.fit(X_new)
    assert_allclose(frozen_pca.components_, expected_components)
    frozen_pca.fit_transform(X_new)
    assert_allclose(frozen_pca.components_, expected_components)

    # Cloning returns the identical object, fitted state intact.
    clone_frozen_pca = clone(frozen_pca)
    assert clone_frozen_pca is frozen_pca
    assert_allclose(clone_frozen_pca.components_, expected_components)


def test_pickle_version_warning_is_not_raised_with_matching_version():
iris = datasets.load_iris()
tree = DecisionTreeClassifier().fit(iris.data, iris.target)
Expand Down

0 comments on commit 3f82f84

Please sign in to comment.