Skip to content

Commit

Permalink
FIX Using custom init for KMeans does a single init (#26657)
Browse files Browse the repository at this point in the history
  • Loading branch information
bnsh committed Jun 26, 2023
1 parent 4e88150 commit 2579841
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 5 deletions.
5 changes: 5 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,11 @@ Changelog
:user:`Jérémie du Boisberranger <jeremiedbb>`,
:user:`Guillaume Lemaitre <glemaitre>`.

- |Fix| :class:`cluster.KMeans`, :class:`cluster.MiniBatchKMeans` and
:func:`cluster.k_means` now correctly handle the combination of `n_init="auto"`
and `init` being an array-like, running one initialization in that case.
:pr:`26657` by :user:`Binesh Bannerjee <bnsh>`.

- |API| The `sample_weight` parameter in `predict` for
:meth:`cluster.KMeans.predict` and :meth:`cluster.MiniBatchKMeans.predict`
is now deprecated and will be removed in v1.5.
Expand Down
17 changes: 12 additions & 5 deletions sklearn/cluster/_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ def k_means(
n_init consecutive runs in terms of inertia.
When `n_init='auto'`, the number of runs depends on the value of init:
10 if using `init='random'`, 1 if using `init='k-means++'`.
10 if using `init='random'` or `init` is a callable;
1 if using `init='k-means++'` or `init` is an array-like.
.. versionadded:: 1.2
Added 'auto' option for `n_init`.
Expand Down Expand Up @@ -884,10 +885,14 @@ def _check_params_vs_input(self, X, default_n_init=None):
)
self._n_init = default_n_init
if self._n_init == "auto":
if self.init == "k-means++":
if isinstance(self.init, str) and self.init == "k-means++":
self._n_init = 1
else:
elif isinstance(self.init, str) and self.init == "random":
self._n_init = default_n_init
elif callable(self.init):
self._n_init = default_n_init
else: # array-like
self._n_init = 1

if _is_arraylike_not_scalar(self.init) and self._n_init != 1:
warnings.warn(
Expand Down Expand Up @@ -1241,7 +1246,8 @@ class KMeans(_BaseKMeans):
high-dimensional problems (see :ref:`kmeans_sparse_high_dim`).
When `n_init='auto'`, the number of runs depends on the value of init:
10 if using `init='random'`, 1 if using `init='k-means++'`.
10 if using `init='random'` or `init` is a callable;
1 if using `init='k-means++'` or `init` is an array-like.
.. versionadded:: 1.2
Added 'auto' option for `n_init`.
Expand Down Expand Up @@ -1777,7 +1783,8 @@ class MiniBatchKMeans(_BaseKMeans):
:ref:`kmeans_sparse_high_dim`).
When `n_init='auto'`, the number of runs depends on the value of init:
3 if using `init='random'`, 1 if using `init='k-means++'`.
3 if using `init='random'` or `init` is a callable;
1 if using `init='k-means++'` or `init` is an array-like.
.. versionadded:: 1.2
Added 'auto' option for `n_init`.
Expand Down
31 changes: 31 additions & 0 deletions sklearn/cluster/tests/test_k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,37 @@ def test_minibatch_kmeans_partial_fit_init(init):
_check_fitted_model(km)


@pytest.mark.parametrize(
    "init, expected_n_init",
    [
        ("k-means++", 1),
        ("random", "default"),
        (
            lambda X, n_clusters, random_state: random_state.uniform(
                size=(n_clusters, X.shape[1])
            ),
            "default",
        ),
        ("array-like", 1),
    ],
)
@pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
def test_kmeans_init_auto_with_initial_centroids(Estimator, init, expected_n_init):
    """Check that `n_init="auto"` chooses the right number of initializations.

    `init` is either one of the string options, a callable, or the sentinel
    string "array-like", which is swapped for an actual array of initial
    centroids below (arrays cannot conveniently be used as parametrize ids).
    `expected_n_init` is either the exact expected value or the sentinel
    "default", which is resolved to the estimator-specific default.

    Non-regression test for #26657:
    https://github.com/scikit-learn/scikit-learn/pull/26657
    """
    n_sample, n_features, n_clusters = 100, 10, 5

    # Use a local, seeded generator rather than the global numpy random state
    # so that the data (and any incidental failure) is reproducible.
    rng = np.random.RandomState(0)
    X = rng.randn(n_sample, n_features)

    # Only compare against the sentinel when `init` really is a string; this
    # keeps the check a plain scalar comparison whatever type `init` has.
    if isinstance(init, str) and init == "array-like":
        init = rng.randn(n_clusters, n_features)

    # The "auto" fallback differs per estimator: 3 runs for MiniBatchKMeans,
    # 10 runs for KMeans.
    if expected_n_init == "default":
        expected_n_init = 3 if Estimator is MiniBatchKMeans else 10

    kmeans = Estimator(
        n_clusters=n_clusters, init=init, n_init="auto", random_state=0
    ).fit(X)
    assert kmeans._n_init == expected_n_init


@pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
def test_fortran_aligned_data(Estimator, global_random_seed):
# Check that KMeans works with fortran-aligned data.
Expand Down

0 comments on commit 2579841

Please sign in to comment.