Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using ndarray as init for KMeans raises a ValueError #26657

Merged
merged 18 commits into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,13 @@ Changelog
- |API| The `Xred` argument in :func:`cluster.FeatureAgglomeration.inverse_transform`
is renamed to `Xt` and will be removed in v1.5. :pr:`26503` by `Adrin Jalali`_.

- [Fix] :class:`cluster._BaseKMeans` :meth:`_check_params_vs_input` was checking
`if self.init == "k-means++"`.
This fails with numpy>=1.25.0 with a ValueError, asking to use `.any()` or
`.all()` to clarify. Fixed to first explicitly check if `self.init` is a str
and only then do a string comparison.
:pr:`26657` by :user:`Binesh Bannerjee <bnsh>`.
jeremiedbb marked this conversation as resolved.
Show resolved Hide resolved

:mod:`sklearn.compose`
......................

Expand Down
28 changes: 23 additions & 5 deletions sklearn/cluster/_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,10 @@ def k_means(
n_init consecutive runs in terms of inertia.

When `n_init='auto'`, the number of runs depends on the value of init:
10 if using `init='random'`, 1 if using `init='k-means++'`.
10 if using `init='random'`
1 if using `init='k-means++'`
10 if init is a callable
1 if init is array-like.
jeremiedbb marked this conversation as resolved.
Show resolved Hide resolved

.. versionadded:: 1.2
Added 'auto' option for `n_init`.
Expand Down Expand Up @@ -897,10 +900,19 @@ def _check_params_vs_input(self, X, default_n_init=None):
)
self._n_init = default_n_init
if self._n_init == "auto":
if self.init == "k-means++":
if isinstance(self.init, str) and self.init == "k-means++":
self._n_init = 1
else:
elif isinstance(self.init, str) and self.init == "random":
self._n_init = default_n_init
elif callable(self.init):
self._n_init = default_n_init
elif _is_arraylike_not_scalar(self.init):
self._n_init = 1
else:
raise ValueError(
'Expect init to be one of ["k-means++", "random", callable or'
"array-like of shape(n_clusters, n_features)]"
)
jeremiedbb marked this conversation as resolved.
Show resolved Hide resolved

if _is_arraylike_not_scalar(self.init) and self._n_init != 1:
warnings.warn(
Expand Down Expand Up @@ -1254,7 +1266,10 @@ class KMeans(_BaseKMeans):
high-dimensional problems (see :ref:`kmeans_sparse_high_dim`).

When `n_init='auto'`, the number of runs depends on the value of init:
10 if using `init='random'`, 1 if using `init='k-means++'`.
10 if using `init='random'`
1 if using `init='k-means++'`
10 if init is a callable
1 if init is array-like.

.. versionadded:: 1.2
Added 'auto' option for `n_init`.
Expand Down Expand Up @@ -1790,7 +1805,10 @@ class MiniBatchKMeans(_BaseKMeans):
:ref:`kmeans_sparse_high_dim`).

When `n_init='auto'`, the number of runs depends on the value of init:
3 if using `init='random'`, 1 if using `init='k-means++'`.
3 if using `init='random'`
1 if using `init='k-means++'`
3 if init is a callable
1 if init is array-like.

.. versionadded:: 1.2
Added 'auto' option for `n_init`.
Expand Down
31 changes: 31 additions & 0 deletions sklearn/cluster/tests/test_k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,37 @@ def test_minibatch_kmeans_partial_fit_init(init):
_check_fitted_model(km)


@pytest.mark.parametrize(
"init, expected_n_init",
[
("k-means++", 1),
("random", "default"),
(
lambda X, n_clusters, random_state: random_state.uniform(
size=(n_clusters, X.shape[1])
),
"default",
),
("array-like", 1),
],
)
@pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
def test_kmeans_init_auto_with_initial_centroids(Estimator, init, expected_n_init):
"""Check that `n_init="auto"` chooses the right number of initializations.
Non-regression test for #26657:
https://github.com/scikit-learn/scikit-learn/pull/26657
"""
n_sample, n_features, n_clusters = 100, 10, 5
X = np.random.randn(n_sample, n_features)
if init == "array-like":
init = np.random.randn(n_clusters, n_features)
if expected_n_init == "default":
expected_n_init = 3 if Estimator is MiniBatchKMeans else 10

kmeans = Estimator(n_clusters=n_clusters, init=init, n_init="auto").fit(X)
assert kmeans._n_init == expected_n_init


@pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
def test_fortran_aligned_data(Estimator, global_random_seed):
# Check that KMeans works with fortran-aligned data.
Expand Down