diff --git a/sklearn/cluster/_agglomerative.py b/sklearn/cluster/_agglomerative.py index 48e2d38ebf32b..2c259e0287065 100644 --- a/sklearn/cluster/_agglomerative.py +++ b/sklearn/cluster/_agglomerative.py @@ -16,8 +16,8 @@ from ..base import BaseEstimator, ClusterMixin from ..metrics.pairwise import paired_distances, pairwise_distances -from ..neighbors import DistanceMetric -from ..neighbors._dist_metrics import METRIC_MAPPING +from ..metrics import DistanceMetric +from ..metrics._dist_metrics import METRIC_MAPPING from ..utils import check_array from ..utils._fast_dict import IntFloatDict from ..utils.fixes import _astype_copy_false diff --git a/sklearn/cluster/_hierarchical_fast.pyx b/sklearn/cluster/_hierarchical_fast.pyx index 2a58757ce327d..11ea3294c086a 100644 --- a/sklearn/cluster/_hierarchical_fast.pyx +++ b/sklearn/cluster/_hierarchical_fast.pyx @@ -13,7 +13,7 @@ ctypedef np.int8_t INT8 np.import_array() -from ..neighbors._dist_metrics cimport DistanceMetric +from ..metrics._dist_metrics cimport DistanceMetric from ..utils._fast_dict cimport IntFloatDict # C++ @@ -236,8 +236,8 @@ def max_merge(IntFloatDict a, IntFloatDict b, def average_merge(IntFloatDict a, IntFloatDict b, np.ndarray[ITYPE_t, ndim=1] mask, ITYPE_t n_a, ITYPE_t n_b): - """Merge two IntFloatDicts with the average strategy: when the - same key is present in the two dicts, the weighted average of the two + """Merge two IntFloatDicts with the average strategy: when the + same key is present in the two dicts, the weighted average of the two values is used. Parameters @@ -290,13 +290,13 @@ def average_merge(IntFloatDict a, IntFloatDict b, ############################################################################### -# An edge object for fast comparisons +# An edge object for fast comparisons cdef class WeightedEdge: cdef public ITYPE_t a cdef public ITYPE_t b cdef public DTYPE_t weight - + def __init__(self, DTYPE_t weight, ITYPE_t a, ITYPE_t b): self.weight = weight self.a = a @@ -326,7 +326,7 @@ cdef class WeightedEdge: return self.weight > other.weight elif op == 5: return self.weight >= other.weight - + def __repr__(self): return "%s(weight=%f, a=%i, b=%i)" % (self.__class__.__name__, self.weight, @@ -475,7 +475,7 @@ def mst_linkage_core( dist_metric: DistanceMetric A DistanceMetric object conforming to the API from - ``sklearn.neighbors._dist_metrics.pxd`` that will be + ``sklearn.metrics._dist_metrics.pxd`` that will be used to compute distances. 
Returns @@ -534,4 +534,3 @@ def mst_linkage_core( current_node = new_node return np.array(result) - diff --git a/sklearn/cluster/tests/test_hierarchical.py b/sklearn/cluster/tests/test_hierarchical.py index 8aff7136c574f..73fee94b1b016 100644 --- a/sklearn/cluster/tests/test_hierarchical.py +++ b/sklearn/cluster/tests/test_hierarchical.py @@ -16,7 +16,7 @@ from scipy.cluster import hierarchy from sklearn.metrics.cluster import adjusted_rand_score -from sklearn.neighbors.tests.test_dist_metrics import METRICS_DEFAULT_PARAMS +from sklearn.metrics.tests.test_dist_metrics import METRICS_DEFAULT_PARAMS from sklearn.utils._testing import assert_almost_equal, create_memmap_backed_data from sklearn.utils._testing import assert_array_almost_equal from sklearn.utils._testing import ignore_warnings @@ -30,6 +30,7 @@ _fix_connectivity, ) from sklearn.feature_extraction.image import grid_to_graph +from sklearn.metrics import DistanceMetric from sklearn.metrics.pairwise import ( PAIRED_DISTANCES, cosine_distances, @@ -37,7 +38,7 @@ pairwise_distances, ) from sklearn.metrics.cluster import normalized_mutual_info_score -from sklearn.neighbors import kneighbors_graph, DistanceMetric +from sklearn.neighbors import kneighbors_graph from sklearn.cluster._hierarchical_fast import ( average_merge, max_merge, diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index a0b06a02ad6d1..68409a7f85d35 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -36,6 +36,8 @@ from ._classification import brier_score_loss from ._classification import multilabel_confusion_matrix +from ._dist_metrics import DistanceMetric + from . import cluster from .cluster import adjusted_mutual_info_score from .cluster import adjusted_rand_score @@ -113,6 +115,7 @@ "davies_bouldin_score", "DetCurveDisplay", "det_curve", + "DistanceMetric", "euclidean_distances", "explained_variance_score", "f1_score", diff --git a/sklearn/metrics/_argkmin_fast.pyx b/sklearn/metrics/_argkmin_fast.pyx deleted file mode 100644 index 46549816d3b1b..0000000000000 --- a/sklearn/metrics/_argkmin_fast.pyx +++ /dev/null @@ -1,467 +0,0 @@ -# cython: language_level=3 -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False -# cython: profile=False -# cython: linetrace=False -# cython: binding=False -# distutils: define_macros=CYTHON_TRACE_NOGIL=0 - -import numpy as np -cimport numpy as np - -from libc.math cimport floor, sqrt -from libc.stdlib cimport free, malloc - -from cython cimport floating -from cython.parallel cimport parallel, prange - -DEF CHUNK_SIZE = 256 # number of vectors - -DEF MIN_CHUNK_SAMPLES = 20 - -DEF FLOAT_INF = 1e36 - -from ..utils._cython_blas cimport ( - BLAS_Order, - BLAS_Trans, - ColMajor, - NoTrans, - RowMajor, - Trans, - _gemm, -) - -from ..utils._heap cimport _simultaneous_sort, _push -from ..utils._openmp_helpers import _openmp_effective_n_threads -from ..utils._typedefs cimport ITYPE_t -from ..utils._typedefs import ITYPE - - -### argkmin helpers - -cdef void _argkmin_on_chunk( - floating[:, ::1] X_c, # IN - floating[:, ::1] Y_c, # IN - floating[::1] Y_sq_norms, # IN - floating *dist_middle_terms, # IN - floating *heaps_red_distances, # IN/OUT - ITYPE_t *heaps_indices, # IN/OUT - ITYPE_t k, # IN - # ID of the first element of Y_c - ITYPE_t Y_idx_offset, -) nogil: - """ - Critical part of the computation of pairwise distances. - - "Fast Squared Euclidean" distances strategy relying - on the gemm-trick. 
- """ - cdef: - ITYPE_t i, j - # Instead of computing the full pairwise squared distances matrix, - # ||X_c - Y_c||² = ||X_c||² - 2 X_c.Y_c^T + ||Y_c||², - # we only need to store the - 2 X_c.Y_c^T + ||Y_c||² - # term since the argmin for a given sample X_c^{i} does not depend on - # ||X_c^{i}||² - - # Careful: LDA, LDB and LDC are given for F-ordered arrays. - # Here, we use their counterpart values as indicated in the documentation. - # See the documentation of parameters here: - # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html - # - # dist_middle_terms = -2 * X_c.dot(Y_c.T) - _gemm(RowMajor, NoTrans, Trans, - X_c.shape[0], Y_c.shape[0], X_c.shape[1], - -2.0, - &X_c[0, 0], X_c.shape[1], - &Y_c[0, 0], X_c.shape[1], 0.0, - dist_middle_terms, Y_c.shape[0]) - - # Computing argmins here - for i in range(X_c.shape[0]): - for j in range(Y_c.shape[0]): - _push(heaps_red_distances + i * k, - heaps_indices + i * k, - k, - # reduced distance: - 2 X_c_i.Y_c_j^T + ||Y_c_j||² - dist_middle_terms[i * Y_c.shape[0] + j] + Y_sq_norms[j], - j + Y_idx_offset) - - - -cdef int _argkmin_on_X( - floating[:, ::1] X, # IN - floating[:, ::1] Y, # IN - floating[::1] Y_sq_norms, # IN - ITYPE_t chunk_size, # IN - ITYPE_t effective_n_threads, # IN - ITYPE_t[:, ::1] argkmin_indices, # OUT - floating[:, ::1] argkmin_red_distances, # OUT -) nogil: - """Computes the argkmin of each vector (row) of X on Y - by parallelising computation on chunks of X. - """ - cdef: - ITYPE_t k = argkmin_indices.shape[1] - ITYPE_t d = X.shape[1] - ITYPE_t sf = sizeof(floating) - ITYPE_t si = sizeof(ITYPE_t) - ITYPE_t n_samples_chunk = max(MIN_CHUNK_SAMPLES, chunk_size) - - ITYPE_t n_train = Y.shape[0] - ITYPE_t Y_n_samples_chunk = min(n_train, n_samples_chunk) - ITYPE_t Y_n_full_chunks = n_train / Y_n_samples_chunk - ITYPE_t Y_n_samples_rem = n_train % Y_n_samples_chunk - - ITYPE_t n_test = X.shape[0] - ITYPE_t X_n_samples_chunk = min(n_test, n_samples_chunk) - ITYPE_t X_n_full_chunks = n_test // X_n_samples_chunk - ITYPE_t X_n_samples_rem = n_test % X_n_samples_chunk - - # Counting remainder chunk in total number of chunks - ITYPE_t Y_n_chunks = Y_n_full_chunks + ( - n_train != (Y_n_full_chunks * Y_n_samples_chunk) - ) - - ITYPE_t X_n_chunks = X_n_full_chunks + ( - n_test != (X_n_full_chunks * X_n_samples_chunk) - ) - - ITYPE_t num_threads = min(Y_n_chunks, effective_n_threads) - - ITYPE_t Y_start, Y_end, X_start, X_end - ITYPE_t X_chunk_idx, Y_chunk_idx, idx, jdx - - floating *dist_middle_terms_chunks - floating *heaps_red_distances_chunks - - - with nogil, parallel(num_threads=num_threads): - # Thread local buffers - - # Temporary buffer for the -2 * X_c.dot(Y_c.T) term - dist_middle_terms_chunks = malloc(Y_n_samples_chunk * X_n_samples_chunk * sf) - heaps_red_distances_chunks = malloc(X_n_samples_chunk * k * sf) - - for X_chunk_idx in prange(X_n_chunks, schedule='static'): - # We reset the heap between X chunks (memset isn't suitable here) - for idx in range(X_n_samples_chunk * k): - heaps_red_distances_chunks[idx] = FLOAT_INF - - X_start = X_chunk_idx * X_n_samples_chunk - if X_chunk_idx == X_n_chunks - 1 and X_n_samples_rem > 0: - X_end = X_start + X_n_samples_rem - else: - X_end = X_start + X_n_samples_chunk - - for Y_chunk_idx in range(Y_n_chunks): - Y_start = Y_chunk_idx * Y_n_samples_chunk - if Y_chunk_idx == Y_n_chunks - 1 and Y_n_samples_rem > 0: - Y_end = Y_start + Y_n_samples_rem - else: - Y_end = Y_start + Y_n_samples_chunk - - _argkmin_on_chunk( - 
X[X_start:X_end, :], - Y[Y_start:Y_end, :], - Y_sq_norms[Y_start:Y_end], - dist_middle_terms_chunks, - heaps_red_distances_chunks, - &argkmin_indices[X_start, 0], - k, - Y_start - ) - - # Sorting indices so that the closests' come first. - for idx in range(X_end - X_start): - _simultaneous_sort( - heaps_red_distances_chunks + idx * k, - &argkmin_indices[X_start + idx, 0], - k - ) - - # end: for X_chunk_idx - free(dist_middle_terms_chunks) - free(heaps_red_distances_chunks) - - # end: with nogil, parallel - return X_n_chunks - - -cdef int _argkmin_on_Y( - floating[:, ::1] X, # IN - floating[:, ::1] Y, # IN - floating[::1] Y_sq_norms, # IN - ITYPE_t chunk_size, # IN - ITYPE_t effective_n_threads, # IN - ITYPE_t[:, ::1] argkmin_indices, # OUT - floating[:, ::1] argkmin_red_distances, # OUT -) nogil: - """Computes the argkmin of each vector (row) of X on Y - by parallelising computation on chunks of Y. - - This parallelisation strategy is more costly (as we need - extra heaps and synchronisation), yet it is useful in - most contexts. - """ - cdef: - ITYPE_t k = argkmin_indices.shape[1] - ITYPE_t d = X.shape[1] - ITYPE_t sf = sizeof(floating) - ITYPE_t si = sizeof(ITYPE_t) - ITYPE_t n_samples_chunk = max(MIN_CHUNK_SAMPLES, chunk_size) - - ITYPE_t n_train = Y.shape[0] - ITYPE_t Y_n_samples_chunk = min(n_train, n_samples_chunk) - ITYPE_t Y_n_full_chunks = n_train / Y_n_samples_chunk - ITYPE_t Y_n_samples_rem = n_train % Y_n_samples_chunk - - ITYPE_t n_test = X.shape[0] - ITYPE_t X_n_samples_chunk = min(n_test, n_samples_chunk) - ITYPE_t X_n_full_chunks = n_test // X_n_samples_chunk - ITYPE_t X_n_samples_rem = n_test % X_n_samples_chunk - - # Counting remainder chunk in total number of chunks - ITYPE_t Y_n_chunks = Y_n_full_chunks + ( - n_train != (Y_n_full_chunks * Y_n_samples_chunk) - ) - - ITYPE_t X_n_chunks = X_n_full_chunks + ( - n_test != (X_n_full_chunks * X_n_samples_chunk) - ) - - ITYPE_t num_threads = min(Y_n_chunks, effective_n_threads) - - ITYPE_t Y_start, Y_end, X_start, X_end - ITYPE_t X_chunk_idx, Y_chunk_idx, idx, jdx - - floating *dist_middle_terms_chunks - floating *heaps_red_distances_chunks - - # As chunks of X are shared across threads, so must their - # heaps. To solve this, each thread has its own locals - # heaps which are then synchronised back in the main ones. 
- ITYPE_t *heaps_indices_chunks - - for X_chunk_idx in range(X_n_chunks): - X_start = X_chunk_idx * X_n_samples_chunk - if X_chunk_idx == X_n_chunks - 1 and X_n_samples_rem > 0: - X_end = X_start + X_n_samples_rem - else: - X_end = X_start + X_n_samples_chunk - - with nogil, parallel(num_threads=num_threads): - # Thread local buffers - - # Temporary buffer for the -2 * X_c.dot(Y_c.T) term - dist_middle_terms_chunks = malloc( - Y_n_samples_chunk * X_n_samples_chunk * sf) - heaps_red_distances_chunks = malloc( - X_n_samples_chunk * k * sf) - heaps_indices_chunks = malloc( - X_n_samples_chunk * k * sf) - - # Initialising heaps (memset can't be used here) - for idx in range(X_n_samples_chunk * k): - heaps_red_distances_chunks[idx] = FLOAT_INF - heaps_indices_chunks[idx] = -1 - - for Y_chunk_idx in prange(Y_n_chunks, schedule='static'): - Y_start = Y_chunk_idx * Y_n_samples_chunk - if Y_chunk_idx == Y_n_chunks - 1 \ - and Y_n_samples_rem > 0: - Y_end = Y_start + Y_n_samples_rem - else: - Y_end = Y_start + Y_n_samples_chunk - - _argkmin_on_chunk( - X[X_start:X_end, :], - Y[Y_start:Y_end, :], - Y_sq_norms[Y_start:Y_end], - dist_middle_terms_chunks, - heaps_red_distances_chunks, - heaps_indices_chunks, - k, - Y_start, - ) - - # end: for Y_chunk_idx - with gil: - # Synchronising the thread local heaps - # with the main heaps - for idx in range(X_end - X_start): - for jdx in range(k): - _push( - &argkmin_red_distances[X_start + idx, 0], - &argkmin_indices[X_start + idx, 0], - k, - heaps_red_distances_chunks[idx * k + jdx], - heaps_indices_chunks[idx * k + jdx], - ) - - free(dist_middle_terms_chunks) - free(heaps_red_distances_chunks) - free(heaps_indices_chunks) - - # end: with nogil, parallel - # Sorting indices of the argkmin for each query vector of X - for idx in prange(n_test,schedule='static', - nogil=True, num_threads=num_threads): - _simultaneous_sort( - &argkmin_red_distances[idx, 0], - &argkmin_indices[idx, 0], - k, - ) - # end: prange - - # end: for X_chunk_idx - return Y_n_chunks - -cdef inline floating _euclidean_dist( - floating[:, ::1] X, - floating[:, ::1] Y, - ITYPE_t i, - ITYPE_t j, -) nogil: - cdef: - floating dist = 0 - ITYPE_t k - ITYPE_t upper_unrolled_idx = (X.shape[1] // 4) * 4 - - # Unrolling loop to help with vectorisation - for k in range(0, upper_unrolled_idx, 4): - dist += (X[i, k] - Y[j, k]) * (X[i, k] - Y[j, k]) - dist += (X[i, k + 1] - Y[j, k + 1]) * (X[i, k + 1] - Y[j, k + 1]) - dist += (X[i, k + 2] - Y[j, k + 2]) * (X[i, k + 2] - Y[j, k + 2]) - dist += (X[i, k + 3] - Y[j, k + 3]) * (X[i, k + 3] - Y[j, k + 3]) - - for k in range(upper_unrolled_idx, X.shape[1]): - dist += (X[i, k] - Y[j, k]) * (X[i, k] - Y[j, k]) - - return sqrt(dist) - -cdef int _exact_euclidean_dist( - floating[:, ::1] X, # IN - floating[:, ::1] Y, # IN - ITYPE_t[:, ::1] Y_indices, # IN - ITYPE_t effective_n_threads, # IN - floating[:, ::1] distances, # OUT -) nogil: - """ - Compute exact pairwise euclidean distances in parallel. - - The pairwise distances considered are X vectors - and a subset of Y given for each row if X given in - Y_indices. - - Notes: the body of this function could have been inlined, - but we use a function to have a cdef nogil context. 
- """ - cdef: - ITYPE_t i, k - - for i in prange(X.shape[0], schedule='static', - nogil=True, num_threads=effective_n_threads): - for k in range(Y_indices.shape[1]): - distances[i, k] = _euclidean_dist(X, Y, i, - Y_indices[i, k]) - - -# Python interface - -def _argkmin( - floating[:, ::1] X, - floating[:, ::1] Y, - ITYPE_t k, - ITYPE_t chunk_size = CHUNK_SIZE, - str strategy = "auto", - bint return_distance = False, -): - """Computes the argkmin of vectors (rows) of X on Y for - the euclidean distance. - - The implementation is parallelised on chunks whose size can - be set using ``chunk_size``. - - Parameters - ---------- - X: ndarray of shape (n, d) - Rows represent vectors - - Y: ndarray of shape (m, d) - Rows represent vectors - - chunk_size: int - The number of vectors per chunk. - - strategy: str, {'auto', 'chunk_on_X', 'chunk_on_Y'} - The chunking strategy defining which dataset - parallelisation are made on. - - - 'chunk_on_X' is embarassingly parallel but - is less used in practice. - - 'chunk_on_Y' comes with synchronisation but - is more useful in practice. - -'auto' relies on a simple heuristic to choose - between 'chunk_on_X' and 'chunk_on_Y'. - - return_distance: boolean - Return distances between each X vectory and its - argkmin if set to True. - - Returns - ------- - distances: ndarray of shape (n, k) - Distances between each X vector and its argkmin - in Y. Only returned if ``return_distance=True``. - - indices: ndarray of shape (n, k) - Indices of each X vector argkmin in Y. - """ - int_dtype = np.intp - float_dtype = np.float32 if floating is float else np.float64 - cdef: - ITYPE_t[:, ::1] argkmin_indices = np.full((X.shape[0], k), 0, - dtype=ITYPE) - floating[:, ::1] argkmin_distances = np.full((X.shape[0], k), - FLOAT_INF, - dtype=float_dtype) - floating[::1] Y_sq_norms = np.einsum('ij,ij->i', Y, Y) - ITYPE_t effective_n_threads = _openmp_effective_n_threads() - - if strategy == 'auto': - # This is a simple heuristic whose constant for the - # comparison has been chosen based on experiments. - if 4 * chunk_size * effective_n_threads < X.shape[0]: - strategy = 'chunk_on_X' - else: - strategy = 'chunk_on_Y' - - if strategy == 'chunk_on_Y': - _argkmin_on_Y( - X, Y, Y_sq_norms, - chunk_size, effective_n_threads, - argkmin_indices, argkmin_distances - ) - elif strategy == 'chunk_on_X': - _argkmin_on_X( - X, Y, Y_sq_norms, - chunk_size, effective_n_threads, - argkmin_indices, argkmin_distances - ) - else: - raise RuntimeError(f"strategy '{strategy}' not supported.") - - if return_distance: - # We need to recompute distances because we relied on - # reduced distances using _gemm, which are missing a - # term for squarred norms and which are not the most - # precise (catastrophic cancellation might have happened). 
-        _exact_euclidean_dist(X, Y, argkmin_indices,
-                              effective_n_threads,
-                              argkmin_distances)
-        return (np.asarray(argkmin_distances),
-                np.asarray(argkmin_indices))
-
-    return np.asarray(argkmin_indices)
diff --git a/sklearn/neighbors/_dist_metrics.pxd b/sklearn/metrics/_dist_metrics.pxd
similarity index 100%
rename from sklearn/neighbors/_dist_metrics.pxd
rename to sklearn/metrics/_dist_metrics.pxd
diff --git a/sklearn/neighbors/_dist_metrics.pyx b/sklearn/metrics/_dist_metrics.pyx
similarity index 99%
rename from sklearn/neighbors/_dist_metrics.pyx
rename to sklearn/metrics/_dist_metrics.pyx
index c9941cab0fc60..8d28773821127 100755
--- a/sklearn/neighbors/_dist_metrics.pyx
+++ b/sklearn/metrics/_dist_metrics.pyx
@@ -108,7 +108,7 @@ cdef class DistanceMetric:
     Examples
     --------
-    >>> from sklearn.neighbors import DistanceMetric
+    >>> from sklearn.metrics import DistanceMetric
    >>> dist = DistanceMetric.get_metric('euclidean')
    >>> X = [[0, 1, 2],
             [3, 4, 5]]
@@ -513,7 +513,7 @@ cdef class ChebyshevDistance(DistanceMetric):
     Examples
     --------
-    >>> from sklearn.neighbors.dist_metrics import DistanceMetric
+    >>> from sklearn.metrics import DistanceMetric
    >>> dist = DistanceMetric.get_metric('chebyshev')
    >>> X = [[0, 1, 2],
    ...      [3, 4, 5]]
diff --git a/sklearn/metrics/_parallel_reductions.pyx b/sklearn/metrics/_parallel_reductions.pyx
new file mode 100644
index 0000000000000..f57d7bcd9fd5e
--- /dev/null
+++ b/sklearn/metrics/_parallel_reductions.pyx
@@ -0,0 +1,765 @@
+# cython: language_level=3
+# cython: cdivision=True
+# cython: boundscheck=False
+# cython: wraparound=False
+# cython: profile=False
+# cython: linetrace=False
+# cython: initializedcheck=False
+# cython: binding=False
+# distutils: define_macros=CYTHON_TRACE_NOGIL=0
+
+import numpy as np
+cimport numpy as np
+cimport openmp
+
+from libc.math cimport sqrt
+from libc.stdlib cimport free, malloc
+
+from cython.parallel cimport parallel, prange
+
+from ._dist_metrics cimport DistanceMetric
+from ._dist_metrics import METRIC_MAPPING
+from ..utils import check_array
+
+DEF CHUNK_SIZE = 256  # number of vectors
+
+DEF MIN_CHUNK_SAMPLES = 20
+
+DEF FLOAT_INF = 1e36
+
+from ..utils._cython_blas cimport (
+  BLAS_Order,
+  BLAS_Trans,
+  ColMajor,
+  NoTrans,
+  RowMajor,
+  Trans,
+  _gemm,
+)
+
+from ..utils._heap cimport _simultaneous_sort, _push
+from ..utils._openmp_helpers import _openmp_effective_n_threads
+from ..utils._typedefs cimport ITYPE_t, DTYPE_t
+from ..utils._typedefs import ITYPE, DTYPE
+
+
+cdef inline DTYPE_t _euclidean_dist(
+    DTYPE_t[:, ::1] X,
+    DTYPE_t[:, ::1] Y,
+    ITYPE_t i,
+    ITYPE_t j,
+) nogil:
+    cdef:
+        DTYPE_t dist = 0
+        ITYPE_t k
+        ITYPE_t upper_unrolled_idx = (X.shape[1] // 4) * 4
+
+    # Unrolling loop to help with vectorisation
+    for k in range(0, upper_unrolled_idx, 4):
+        dist += (X[i, k] - Y[j, k]) * (X[i, k] - Y[j, k])
+        dist += (X[i, k + 1] - Y[j, k + 1]) * (X[i, k + 1] - Y[j, k + 1])
+        dist += (X[i, k + 2] - Y[j, k + 2]) * (X[i, k + 2] - Y[j, k + 2])
+        dist += (X[i, k + 3] - Y[j, k + 3]) * (X[i, k + 3] - Y[j, k + 3])
+
+    for k in range(upper_unrolled_idx, X.shape[1]):
+        dist += (X[i, k] - Y[j, k]) * (X[i, k] - Y[j, k])
+
+    return sqrt(dist)
+
+
+cdef class ParallelReduction:
+    """Abstract class to compute a reduction of a set of
+    vectors (rows) of X on another set of vectors (rows) of Y.
+
+    The implementation of the reduction is parallelized
+    on chunks whose size can be set using ``chunk_size``.
+
+    Parameters
+    ----------
+    X: ndarray of shape (n, d)
+        Rows represent vectors
+    Y: ndarray of shape (m, d)
+        Rows represent vectors
+    distance_metric: DistanceMetric
+        The distance to use
+    chunk_size: int
+        The number of vectors per chunk
+    """
+
+    cdef:
+        const DTYPE_t[:, ::1] X  # shape: (n_X, d)
+        const DTYPE_t[:, ::1] Y  # shape: (n_Y, d)
+
+        DistanceMetric distance_metric
+
+        ITYPE_t effective_omp_n_thread
+        ITYPE_t n_samples_chunk, chunk_size
+
+        ITYPE_t d
+
+        # dtypes sizes
+        ITYPE_t sf, si
+
+        ITYPE_t n_X, X_n_samples_chunk, X_n_chunks, X_n_samples_rem
+        ITYPE_t n_Y, Y_n_samples_chunk, Y_n_chunks, Y_n_samples_rem
+
+    @classmethod
+    def valid_metrics(cls):
+        return {*METRIC_MAPPING.keys()}
+
+    def __cinit__(self):
+        # Initializing memory views to prevent memory errors and seg-faults
+        # in rare cases where __init__ is not called
+        self.X = np.empty((1, 1), dtype=DTYPE, order='c')
+        self.Y = np.empty((1, 1), dtype=DTYPE, order='c')
+
+    def __init__(self,
+        X,
+        Y,
+        DistanceMetric distance_metric,
+        ITYPE_t chunk_size = CHUNK_SIZE,
+    ):
+        cdef:
+            ITYPE_t X_n_full_chunks, Y_n_full_chunks
+
+        self.effective_omp_n_thread = _openmp_effective_n_threads()
+
+        self.X = check_array(X, dtype=DTYPE)
+        self.Y = check_array(Y, dtype=DTYPE)
+
+        assert X.shape[1] == Y.shape[1], "Vectors of X and Y must have the " \
+                                         "same dimension but currently are " \
+                                         f"respectively {X.shape[1]}-dimensional " \
+                                         f"and {Y.shape[1]}-dimensional."
+        distance_metric._validate_data(X)
+        distance_metric._validate_data(Y)
+
+        self.d = X.shape[1]
+        self.sf = sizeof(DTYPE_t)
+        self.si = sizeof(ITYPE_t)
+        self.chunk_size = chunk_size
+        self.n_samples_chunk = max(MIN_CHUNK_SAMPLES, chunk_size)
+
+        self.distance_metric = distance_metric
+
+        self.n_Y = Y.shape[0]
+        self.Y_n_samples_chunk = min(self.n_Y, self.n_samples_chunk)
+        Y_n_full_chunks = self.n_Y // self.Y_n_samples_chunk
+        self.Y_n_samples_rem = self.n_Y % self.Y_n_samples_chunk
+
+        self.n_X = X.shape[0]
+        self.X_n_samples_chunk = min(self.n_X, self.n_samples_chunk)
+        X_n_full_chunks = self.n_X // self.X_n_samples_chunk
+        self.X_n_samples_rem = self.n_X % self.X_n_samples_chunk
+
+        # Counting remainder chunk in total number of chunks
+        self.Y_n_chunks = Y_n_full_chunks + (
+            self.n_Y != (Y_n_full_chunks * self.Y_n_samples_chunk)
+        )
+
+        self.X_n_chunks = X_n_full_chunks + (
+            self.n_X != (X_n_full_chunks * self.X_n_samples_chunk)
+        )
+
+    def __dealloc__(self):
+        pass
+
+    cdef void _on_X_parallel_init(self,
+        ITYPE_t thread_num,
+    ) nogil:
+        return
+
+    cdef void _on_X_parallel_finalize(self,
+        ITYPE_t thread_num
+    ) nogil:
+        return
+
+    cdef void _on_X_prange_iter_init(self,
+        ITYPE_t thread_num,
+        ITYPE_t X_chunk_idx,
+        ITYPE_t X_start,
+        ITYPE_t X_end,
+    ) nogil:
+        return
+
+    cdef void _on_X_prange_iter_finalize(self,
+        ITYPE_t thread_num,
+        ITYPE_t X_chunk_idx,
+        ITYPE_t X_start,
+        ITYPE_t X_end,
+    ) nogil:
+        return
+
+    cdef void _parallel_on_X(self) nogil:
+        """Computes the reduction of each vector (row) of X on Y
+        by parallelizing computation on chunks of X.
+
+        Private datastructures are modified internally by threads.
+
+        Private template methods can be implemented in subclasses to
+        interact with those datastructures at various stages.
+ """ + cdef: + ITYPE_t Y_start, Y_end, X_start, X_end, X_chunk_idx, Y_chunk_idx + ITYPE_t num_threads = min(self.X_n_chunks, self.effective_omp_n_thread) + ITYPE_t thread_num + + with nogil, parallel(num_threads=num_threads): + thread_num = openmp.omp_get_thread_num() + + # Allocating thread local datastructures + self._on_X_parallel_init(thread_num) + + for X_chunk_idx in prange(self.X_n_chunks, schedule='static'): + X_start = X_chunk_idx * self.X_n_samples_chunk + if X_chunk_idx == self.X_n_chunks - 1 and self.X_n_samples_rem > 0: + X_end = X_start + self.X_n_samples_rem + else: + X_end = X_start + self.X_n_samples_chunk + + # Reinitializing thread local datastructures for the new X chunk + self._on_X_prange_iter_init(thread_num, X_chunk_idx, X_start, X_end) + + for Y_chunk_idx in range(self.Y_n_chunks): + Y_start = Y_chunk_idx * self.Y_n_samples_chunk + if Y_chunk_idx == self.Y_n_chunks - 1 and self.Y_n_samples_rem > 0: + Y_end = Y_start + self.Y_n_samples_rem + else: + Y_end = Y_start + self.Y_n_samples_chunk + + self._reduce_on_chunks( + self.X, + self.Y, + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + + # Adjusting thread local datastructures on the full pass on Y + self._on_X_prange_iter_finalize(thread_num, X_chunk_idx, X_start, X_end) + + # end: for X_chunk_idx + + # Deallocating thread local datastructures + self._on_X_parallel_finalize(thread_num) + + # end: with nogil, parallel + return + + cdef void _on_Y_parallel_init(self, + ITYPE_t thread_num, + ) nogil: + return + + cdef void _on_Y_parallel_finalize(self, + ITYPE_t thread_num, + ITYPE_t X_chunk_idx, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + return + + cdef void _on_Y_finalize(self, + ITYPE_t thread_num, + ) nogil: + return + + cdef void _parallel_on_Y(self) nogil: + """Computes the argkmin of each vector (row) of X on Y + by parallelizing computation on chunks of Y. + + Private datastructures are modified internally by threads. + + Private template methods can be implemented on subclasses to + interact with those datastructures at various stages. 
+ """ + cdef: + ITYPE_t Y_start, Y_end, X_start, X_end, X_chunk_idx, Y_chunk_idx + ITYPE_t num_threads = min(self.X_n_chunks, self.effective_omp_n_thread) + ITYPE_t thread_num + + for X_chunk_idx in range(self.X_n_chunks): + X_start = X_chunk_idx * self.X_n_samples_chunk + if X_chunk_idx == self.X_n_chunks - 1 and self.X_n_samples_rem > 0: + X_end = X_start + self.X_n_samples_rem + else: + X_end = X_start + self.X_n_samples_chunk + + with nogil, parallel(num_threads=num_threads): + # Thread local buffers + thread_num = openmp.omp_get_thread_num() + + # Allocating thread local datastructures + self._on_Y_parallel_init(thread_num) + + for Y_chunk_idx in prange(self.Y_n_chunks, schedule='static'): + Y_start = Y_chunk_idx * self.Y_n_samples_chunk + if Y_chunk_idx == self.Y_n_chunks - 1 \ + and self.Y_n_samples_rem > 0: + Y_end = Y_start + self.Y_n_samples_rem + else: + Y_end = Y_start + self.Y_n_samples_chunk + + self._reduce_on_chunks( + self.X, + self.Y, + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + # end: prange + + # Synchronizing thread local datastructures with the main ones + # This can potentially block + self._on_Y_parallel_finalize(thread_num, X_chunk_idx, X_start, X_end) + # end: with nogil, parallel + + # end: for X_chunk_idx + # Adjusting main datastructures before returning + self._on_Y_finalize(num_threads) + return + + cdef int _reduce_on_chunks(self, + const DTYPE_t[:, ::1] X, + const DTYPE_t[:, ::1] Y, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil except -1: + """ Abstract method: Sub-classes implemented the reduction + on a pair of chunks""" + return -1 + +cdef class ArgKmin(ParallelReduction): + """Computes the argkmin of vectors (rows) of a set of + vectors (rows) of X on another set of vectors (rows) of Y. + + The implementation is parallelized on chunks whose size can + be set using ``chunk_size``. 
+
+    Parameters
+    ----------
+    X: ndarray of shape (n, d)
+        Rows represent vectors
+    Y: ndarray of shape (m, d)
+        Rows represent vectors
+    distance_metric: DistanceMetric
+        The distance to use
+    k: int
+        The k for the argkmin reduction
+    chunk_size: int
+        The number of vectors per chunk
+    """
+
+    cdef:
+        ITYPE_t k
+
+        DTYPE_t ** heaps_approx_distances_chunks
+        ITYPE_t ** heaps_indices_chunks
+
+        ITYPE_t[:, ::1] argkmin_indices
+        DTYPE_t[:, ::1] argkmin_distances
+
+    @classmethod
+    def valid_metrics(cls):
+        return {"fast_sqeuclidean", *METRIC_MAPPING.keys()}
+
+    @classmethod
+    def get_for(cls,
+        X,
+        Y,
+        ITYPE_t k,
+        str metric="fast_sqeuclidean",
+        ITYPE_t chunk_size=CHUNK_SIZE,
+        dict metric_kwargs=dict(),
+    ):
+        if metric == "fast_sqeuclidean":
+            return FastSquaredEuclideanArgKmin(X=X, Y=Y, k=k, chunk_size=chunk_size)
+        return ArgKmin(X=X, Y=Y,
+                       distance_metric=DistanceMetric.get_metric(metric, **metric_kwargs),
+                       k=k,
+                       chunk_size=chunk_size)
+
+    def __init__(self,
+        X,
+        Y,
+        DistanceMetric distance_metric,
+        ITYPE_t k,
+        ITYPE_t chunk_size = CHUNK_SIZE,
+    ):
+        ParallelReduction.__init__(self, X, Y, distance_metric, chunk_size)
+
+        self.k = k
+
+        # Results returned by ArgKmin.compute
+        self.argkmin_indices = np.full((self.n_X, self.k), 0, dtype=ITYPE)
+        self.argkmin_distances = np.full((self.n_X, self.k), FLOAT_INF, dtype=DTYPE)
+
+        # Temporary datastructures used in threads
+        self.heaps_approx_distances_chunks = malloc(sizeof(DTYPE_t *) * self.effective_omp_n_thread)
+        self.heaps_indices_chunks = malloc(sizeof(ITYPE_t *) * self.effective_omp_n_thread)
+
+    def __dealloc__(self):
+        ParallelReduction.__dealloc__(self)
+        if self.heaps_indices_chunks is not NULL:
+            free(self.heaps_indices_chunks)
+        else:
+            raise RuntimeError("Trying to free heaps_indices_chunks which is NULL")
+
+        if self.heaps_approx_distances_chunks is not NULL:
+            free(self.heaps_approx_distances_chunks)
+        else:
+            raise RuntimeError("Trying to free heaps_approx_distances_chunks which is NULL")
+
+    cdef int _reduce_on_chunks(self,
+        const DTYPE_t[:, ::1] X,
+        const DTYPE_t[:, ::1] Y,
+        ITYPE_t X_start,
+        ITYPE_t X_end,
+        ITYPE_t Y_start,
+        ITYPE_t Y_end,
+        ITYPE_t thread_num,
+    ) nogil except -1:
+        cdef:
+            ITYPE_t i, j
+            const DTYPE_t[:, ::1] X_c = X[X_start:X_end, :]
+            const DTYPE_t[:, ::1] Y_c = Y[Y_start:Y_end, :]
+            ITYPE_t k = self.k
+            DTYPE_t *heaps_approx_distances = self.heaps_approx_distances_chunks[thread_num]
+            ITYPE_t *heaps_indices = self.heaps_indices_chunks[thread_num]
+
+            ITYPE_t n_x = X_end - X_start
+            ITYPE_t n_y = Y_end - Y_start
+
+        for i in range(X_c.shape[0]):
+            for j in range(Y_c.shape[0]):
+                _push(heaps_approx_distances + i * self.k,
+                      heaps_indices + i * self.k,
+                      k,
+                      self.distance_metric.rdist(&X_c[i, 0],
+                                                 &Y_c[j, 0],
+                                                 self.d),
+                      Y_start + j)
+
+        return 0
+
+    cdef void _on_X_parallel_init(self,
+        ITYPE_t thread_num,
+    ) nogil:
+        cdef:
+            # in bytes
+            ITYPE_t heap_size = self.X_n_samples_chunk * self.k * self.sf
+
+        # Temporary buffer for the thread-local heap of approximate distances
+        self.heaps_approx_distances_chunks[thread_num] = malloc(heap_size)
+
+    cdef void _on_X_prange_iter_init(self,
+        ITYPE_t thread_num,
+        ITYPE_t X_chunk_idx,
+        ITYPE_t X_start,
+        ITYPE_t X_end,
+    ) nogil:
+
+        # We reset the heap between X chunks (memset can't be used here)
+        for idx in range(self.X_n_samples_chunk * self.k):
+            self.heaps_approx_distances_chunks[thread_num][idx] = FLOAT_INF
+
+        # Referencing the thread-local heaps via the thread-scope pointer
+        # of pointers attached to the instance
+        self.heaps_indices_chunks[thread_num] = &self.argkmin_indices[X_start, 0]
+
+    cdef void _on_X_prange_iter_finalize(self,
+        ITYPE_t thread_num,
+        ITYPE_t X_chunk_idx,
+        ITYPE_t X_start,
+        ITYPE_t X_end,
+    ) nogil:
+        cdef:
+            ITYPE_t idx, jdx
+
+        # Sorting indices of the argkmin for each query vector of X
+        for idx in range(X_end - X_start):
+            _simultaneous_sort(
+                self.heaps_approx_distances_chunks[thread_num] + idx * self.k,
+                &self.argkmin_indices[X_start + idx, 0],
+                self.k
+            )
+
+    cdef void _on_X_parallel_finalize(self,
+        ITYPE_t thread_num
+    ) nogil:
+        free(self.heaps_approx_distances_chunks[thread_num])
+
+    cdef void _on_Y_parallel_init(self,
+        ITYPE_t thread_num,
+    ) nogil:
+        cdef:
+            # in bytes
+            ITYPE_t int_heap_size = self.X_n_samples_chunk * self.k * self.si
+            ITYPE_t float_heap_size = self.X_n_samples_chunk * self.k * self.sf
+
+        self.heaps_approx_distances_chunks[thread_num] = malloc(float_heap_size)
+
+        # As chunks of X are shared across threads, so must their
+        # heaps. To solve this, each thread has its own local
+        # heaps, which are then synchronized back into the main ones.
+        self.heaps_indices_chunks[thread_num] = malloc(int_heap_size)
+
+        # Initialising heaps (memset can't be used here)
+        for idx in range(self.X_n_samples_chunk * self.k):
+            self.heaps_approx_distances_chunks[thread_num][idx] = FLOAT_INF
+            self.heaps_indices_chunks[thread_num][idx] = -1
+
+    cdef void _on_Y_parallel_finalize(self,
+        ITYPE_t thread_num,
+        ITYPE_t X_chunk_idx,
+        ITYPE_t X_start,
+        ITYPE_t X_end,
+    ) nogil:
+        cdef:
+            ITYPE_t idx, jdx
+        with gil:
+            # Synchronising the thread local heaps
+            # with the main heaps
+            for idx in range(X_end - X_start):
+                for jdx in range(self.k):
+                    _push(
+                        &self.argkmin_distances[X_start + idx, 0],
+                        &self.argkmin_indices[X_start + idx, 0],
+                        self.k,
+                        self.heaps_approx_distances_chunks[thread_num][idx * self.k + jdx],
+                        self.heaps_indices_chunks[thread_num][idx * self.k + jdx],
+                    )
+
+        free(self.heaps_approx_distances_chunks[thread_num])
+        free(self.heaps_indices_chunks[thread_num])
+
+    cdef void _on_Y_finalize(self,
+        ITYPE_t thread_num,
+    ) nogil:
+        cdef:
+            ITYPE_t num_threads = min(self.X_n_chunks, self.effective_omp_n_thread)
+            ITYPE_t idx
+
+        # Sorting indices of the argkmin for each query vector of X
+        for idx in prange(self.n_X, schedule='static',
+                          nogil=True, num_threads=num_threads):
+            _simultaneous_sort(
+                &self.argkmin_distances[idx, 0],
+                &self.argkmin_indices[idx, 0],
+                self.k,
+            )
+        return
+
+    cdef void _exact_distances(self,
+        ITYPE_t[:, ::1] Y_indices,          # IN
+        DTYPE_t[:, ::1] distances,          # IN/OUT
+    ) nogil:
+        """Convert reduced distances to pairwise distances in parallel."""
+        cdef:
+            ITYPE_t i, j
+
+        for i in prange(self.n_X, schedule='static', nogil=True,
+                        num_threads=self.effective_omp_n_thread):
+            for j in range(self.k):
+                distances[i, j] = self.distance_metric.dist(&self.X[i, 0],
+                                                            &self.Y[Y_indices[i, j], 0],
+                                                            self.d)
+
+    # Python interface
+    def compute(self,
+        str strategy = "auto",
+        bint return_distance = False
+    ):
+        """Computes the reduction of vectors (rows) of X on Y.
+
+        strategy: str, {'auto', 'parallel_on_X', 'parallel_on_Y'}
+            The chunking strategy defining which dataset
+            the parallelization is made on.
+
+            - 'parallel_on_X' is embarrassingly parallel but
+            is less used in practice.
+            - 'parallel_on_Y' comes with synchronisation but
+            is more useful in practice.
+            - 'auto' relies on a simple heuristic to choose
+            between 'parallel_on_X' and 'parallel_on_Y'.
+
+        return_distance: boolean
+            Return distances between each X vector and its
+            argkmin if set to True.
+
+        Returns
+        -------
+        distances: ndarray of shape (n, k)
+            Distances between each X vector and its argkmin
+            in Y. Only returned if ``return_distance=True``.
+
+        indices: ndarray of shape (n, k)
+            Indices of each X vector argkmin in Y.
+        """
+        if strategy == 'auto':
+            # This is a simple heuristic whose constant for the
+            # comparison has been chosen based on experiments.
+            if 4 * self.chunk_size * self.effective_omp_n_thread < self.n_X:
+                strategy = 'parallel_on_X'
+            else:
+                strategy = 'parallel_on_Y'
+
+        if strategy == 'parallel_on_Y':
+            self._parallel_on_Y()
+        elif strategy == 'parallel_on_X':
+            self._parallel_on_X()
+        else:
+            raise RuntimeError(f"strategy '{strategy}' not supported.")
+
+        if return_distance:
+            # We need to recompute distances because we relied on
+            # reduced distances.
+            self._exact_distances(self.argkmin_indices, self.argkmin_distances)
+            return np.asarray(self.argkmin_distances), np.asarray(self.argkmin_indices)
+
+        return np.asarray(self.argkmin_indices)
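Editorial aside, not part of the patch: the `'auto'` heuristic and the reduction that `compute` performs can be summarized in a few lines of NumPy. The helper name `argkmin_reference` and the `n_threads` default are illustrative only; the `4 *` constant mirrors the comparison in the Cython code above.

```python
import numpy as np

def argkmin_reference(X, Y, k, chunk_size=256, n_threads=8):
    # Mirror of the 'auto' heuristic: parallelizing on chunks of X only
    # pays off when there are many more query rows than the threads can
    # each fill with whole chunks.
    strategy = ("parallel_on_X"
                if 4 * chunk_size * n_threads < X.shape[0]
                else "parallel_on_Y")

    # Sequential reference for the reduction itself: for each row of X,
    # the indices of its k nearest rows of Y.
    distances = np.linalg.norm(X[:, None, :] - Y[None, :, :], axis=-1)
    indices = np.argsort(distances, axis=1)[:, :k]
    return strategy, indices
```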
+
+cdef class FastSquaredEuclideanArgKmin(ArgKmin):
+    """Fast specialized alternative for ArgKmin on
+    EuclideanDistance.
+
+    Computes the argkmin of vectors (rows) of a set of
+    vectors (rows) of X on another set of vectors (rows) of Y
+    using the GEMM-trick.
+
+    This implementation has a superior arithmetic intensity
+    and hence a lower running time, but it can suffer from numerical
+    instability. We recommend using ArgKmin with
+    EuclideanDistance when exact precision is needed.
+    """
+
+    cdef:
+        DTYPE_t[::1] Y_sq_norms
+
+        # Buffers for GEMM
+        DTYPE_t ** dist_middle_terms_chunks
+
+    def __init__(self,
+        X,
+        Y,
+        ITYPE_t k,
+        ITYPE_t chunk_size = CHUNK_SIZE,
+    ):
+        ArgKmin.__init__(self, X, Y,
+                         distance_metric=DistanceMetric.get_metric("euclidean"),
+                         k=k,
+                         chunk_size=chunk_size)
+        self.Y_sq_norms = np.einsum('ij,ij->i', self.Y, self.Y)
+        # Temporary datastructures used in threads
+        self.dist_middle_terms_chunks = malloc(sizeof(DTYPE_t *) * self.effective_omp_n_thread)
+
+    def __dealloc__(self):
+        ArgKmin.__dealloc__(self)
+        if self.dist_middle_terms_chunks is not NULL:
+            free(self.dist_middle_terms_chunks)
+        else:
+            raise RuntimeError("Trying to free dist_middle_terms_chunks which is NULL")
+
+    cdef void _on_X_parallel_init(self,
+        ITYPE_t thread_num,
+    ) nogil:
+        ArgKmin._on_X_parallel_init(self, thread_num)
+        # Temporary buffer for the -2 * X_c.dot(Y_c.T) term
+        self.dist_middle_terms_chunks[thread_num] = malloc(
+            self.Y_n_samples_chunk * self.X_n_samples_chunk * self.sf)
+
+    cdef void _on_X_parallel_finalize(self,
+        ITYPE_t thread_num
+    ) nogil:
+        ArgKmin._on_X_parallel_finalize(self, thread_num)
+        free(self.dist_middle_terms_chunks[thread_num])
+
+    cdef void _on_Y_parallel_init(self,
+        ITYPE_t thread_num,
+    ) nogil:
+        ArgKmin._on_Y_parallel_init(self, thread_num)
+        # Temporary buffer for the -2 * X_c.dot(Y_c.T) term
+        self.dist_middle_terms_chunks[thread_num] = malloc(
+            self.Y_n_samples_chunk * self.X_n_samples_chunk * self.sf)
+
+    cdef void _on_Y_parallel_finalize(self,
+        ITYPE_t thread_num,
+        ITYPE_t X_chunk_idx,
+        ITYPE_t X_start,
+        ITYPE_t X_end,
+    ) nogil:
+        ArgKmin._on_Y_parallel_finalize(self, thread_num, X_chunk_idx, X_start, X_end)
+        free(self.dist_middle_terms_chunks[thread_num])
+
+    cdef int _reduce_on_chunks(self,
+        const DTYPE_t[:, ::1] X,
+        const DTYPE_t[:, ::1] Y,
+        ITYPE_t X_start,
+        ITYPE_t X_end,
+        ITYPE_t Y_start,
+        ITYPE_t Y_end,
+        ITYPE_t thread_num,
+    ) nogil except -1:
+        """
+        Critical part of the computation of pairwise distances.
+
+        "Fast Squared Euclidean" distances strategy relying
+        on the gemm-trick.
+        """
+        cdef:
+            ITYPE_t i, j
+            const DTYPE_t[:, ::1] X_c = X[X_start:X_end, :]
+            const DTYPE_t[:, ::1] Y_c = Y[Y_start:Y_end, :]
+            ITYPE_t k = self.k
+            DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num]
+            DTYPE_t *heaps_approx_distances = self.heaps_approx_distances_chunks[thread_num]
+            ITYPE_t *heaps_indices = self.heaps_indices_chunks[thread_num]
+
+            # Instead of computing the full pairwise squared distances matrix,
+            #
+            #   ||X_c - Y_c||² = ||X_c||² - 2 X_c.Y_c^T + ||Y_c||²,
+            #
+            # we only need to store the
+            #   - 2 X_c.Y_c^T + ||Y_c||²
+            #
+            # term since the argkmin for a given sample X_c^{i} does not depend on
+            # ||X_c^{i}||²
+            #
+            # This term gets computed efficiently below using GEMM from BLAS Level 3.
+            #
+            # Careful: LDA, LDB and LDC are given for F-ordered arrays in BLAS documentations,
+            # for instance:
+            # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html
+            #
+            # Here, we use their counterpart values to work with C-ordered arrays.
+            BLAS_Order order = RowMajor
+            BLAS_Trans ta = NoTrans
+            BLAS_Trans tb = Trans
+            ITYPE_t m = X_c.shape[0]
+            ITYPE_t n = Y_c.shape[0]
+            ITYPE_t K = X_c.shape[1]
+            DTYPE_t alpha = - 2.
+            DTYPE_t * A = & X_c[0, 0]
+            ITYPE_t lda = X_c.shape[1]
+            DTYPE_t * B = & Y_c[0, 0]
+            ITYPE_t ldb = X_c.shape[1]
+            DTYPE_t beta = 0.
+            DTYPE_t * C = dist_middle_terms
+            ITYPE_t ldc = Y_c.shape[0]
+
+        # dist_middle_terms = -2 * X_c.dot(Y_c.T)
+        _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, C, ldc)
+
+        # Pushing the distances and their associated indices on heaps
+        # which keep track of the argkmin.
+        for i in range(X_c.shape[0]):
+            for j in range(Y_c.shape[0]):
+                _push(heaps_approx_distances + i * k,
+                      heaps_indices + i * k,
+                      k,
+                      # approximated distance: - 2 X_c_i.Y_c_j^T + ||Y_c_j||²
+                      dist_middle_terms[i * Y_c.shape[0] + j] + self.Y_sq_norms[j + Y_start],
+                      j + Y_start)
+        return 0
diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py
index a20c49a20346c..bcf42ac2ca9fa 100644
--- a/sklearn/metrics/pairwise.py
+++ b/sklearn/metrics/pairwise.py
@@ -31,7 +31,7 @@
 from ..utils.fixes import delayed
 from ..utils.fixes import sp_version, parse_version
 
-from ._argkmin_fast import _argkmin
+from ._parallel_reductions import ArgKmin
 from ._pairwise_fast import _chi2_kernel_fast, _sparse_manhattan
 from ..exceptions import DataConversionWarning
 
@@ -646,18 +646,24 @@ def pairwise_distances_argmin_min(
     """
     X, Y = check_pairwise_arrays(X, Y)
 
-    if metric == "fast_sqeuclidean":
-        # TODO: generalise this simple plug here
-        values, indices = _argkmin(X, Y, k=1, strategy="auto", return_distance=True)
+    if axis == 0:
+        X, Y = Y, X
+
+    if metric_kwargs is None:
+        metric_kwargs = {}
+
+    if (
+        # TODO: support sparse arrays
+        not issparse(X)
+        and not issparse(Y)
+        and metric in ArgKmin.valid_metrics()
+    ):
+        values, indices = ArgKmin.get_for(
+            X=X, Y=Y, k=1, metric=metric, metric_kwargs=metric_kwargs
+        ).compute(strategy="auto", return_distance=True)
         values = np.ndarray.flatten(values)
         indices = np.ndarray.flatten(indices)
     else:
-        if metric_kwargs is None:
-            metric_kwargs = {}
-
-        if axis == 0:
-            X, Y = Y, X
-
         indices, values = zip(
             *pairwise_distances_chunked(
                 X, Y, reduce_func=_argmin_min_reduce, metric=metric, **metric_kwargs
@@ -786,7 +792,7 @@ def haversine_distances(X, Y=None):
     array([[ 0.        , 11099.54035582],
            [11099.54035582,     0.
]]) """ - from ..neighbors import DistanceMetric + from ..metrics import DistanceMetric return DistanceMetric.get_metric("haversine").pairwise(X, Y) diff --git a/sklearn/metrics/setup.py b/sklearn/metrics/setup.py index 1edd6fe368d5e..6fd445d2c1a00 100644 --- a/sklearn/metrics/setup.py +++ b/sklearn/metrics/setup.py @@ -1,4 +1,5 @@ import os +import numpy as np from numpy.distutils.misc_util import Configuration @@ -19,7 +20,16 @@ def configuration(parent_package="", top_path=None): ) config.add_extension( - "_argkmin_fast", sources=["_argkmin_fast.pyx"], libraries=libraries + "_parallel_reductions", + sources=["_parallel_reductions.pyx"], + libraries=libraries, + ) + + config.add_extension( + "_dist_metrics", + sources=["_dist_metrics.pyx"], + include_dirs=[np.get_include(), os.path.join(np.get_include(), "numpy")], + libraries=libraries, ) config.add_subpackage("tests") diff --git a/sklearn/neighbors/tests/test_dist_metrics.py b/sklearn/metrics/tests/test_dist_metrics.py similarity index 95% rename from sklearn/neighbors/tests/test_dist_metrics.py rename to sklearn/metrics/tests/test_dist_metrics.py index 0703819536916..efa8031c53935 100644 --- a/sklearn/neighbors/tests/test_dist_metrics.py +++ b/sklearn/metrics/tests/test_dist_metrics.py @@ -7,8 +7,7 @@ import pytest from scipy.spatial.distance import cdist -from sklearn.neighbors import DistanceMetric -from sklearn.neighbors import BallTree +from sklearn.metrics import DistanceMetric from sklearn.utils import check_random_state from sklearn.utils._testing import create_memmap_backed_data from sklearn.utils.fixes import sp_version, parse_version @@ -230,16 +229,6 @@ def test_pyfunc_metric(): assert_array_almost_equal(D1_pkl, D2_pkl) -def test_bad_pyfunc_metric(): - def wrong_distance(x, y): - return "1" - - X = np.ones((5, 2)) - msg = "Custom distance function must accept two vectors" - with pytest.raises(TypeError, match=msg): - BallTree(X, metric=wrong_distance) - - def test_input_data_size(): # Regression test for #6288 # Previously, a metric requiring a particular input dimension would fail diff --git a/sklearn/neighbors/__init__.py b/sklearn/neighbors/__init__.py index 8a0934eecf142..3cd1d7925acf6 100644 --- a/sklearn/neighbors/__init__.py +++ b/sklearn/neighbors/__init__.py @@ -5,7 +5,6 @@ from ._ball_tree import BallTree from ._kd_tree import KDTree -from ._dist_metrics import DistanceMetric from ._graph import kneighbors_graph, radius_neighbors_graph from ._graph import KNeighborsTransformer, RadiusNeighborsTransformer from ._unsupervised import NearestNeighbors @@ -19,7 +18,6 @@ __all__ = [ "BallTree", - "DistanceMetric", "KDTree", "KNeighborsClassifier", "KNeighborsRegressor", diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index be561e0bd3f64..85c3cd743cb57 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -23,7 +23,7 @@ from ..base import is_classifier from ..metrics import pairwise_distances_chunked from ..metrics.pairwise import PAIRWISE_DISTANCE_FUNCTIONS -from ..metrics._argkmin_fast import _argkmin +from ..metrics._parallel_reductions import ArgKmin from ..utils import ( check_array, gen_even_slices, @@ -737,13 +737,19 @@ class from an array representing our data set and ask who's ) elif ( - self._fit_method == "brute" and self.effective_metric_ == "fast_sqeuclidean" + # TODO: support sparse arrays + not issparse(X) + and not issparse(self._fit_X) + and self._fit_method == "brute" + and self.effective_metric_ in ArgKmin.valid_metrics() ): - # TODO: generalise this simple 
plug here - results = _argkmin( - X, + results = ArgKmin.get_for( + X=X, Y=self._fit_X, k=n_neighbors, + metric=self.effective_metric_, + metric_kwargs=self.effective_metric_params_, + ).compute( strategy="auto", return_distance=return_distance, ) @@ -755,12 +761,6 @@ class from an array representing our data set and ask who's return_distance=return_distance, ) - # for efficiency, use squared euclidean distances - if self.effective_metric_ == "euclidean": - kwds = {"squared": True} - else: - kwds = self.effective_metric_params_ - chunked_results = list( pairwise_distances_chunked( X, @@ -768,7 +768,7 @@ class from an array representing our data set and ask who's reduce_func=reduce_func, metric=self.effective_metric_, n_jobs=n_jobs, - **kwds, + **self.effective_metric_params_, ) ) diff --git a/sklearn/neighbors/_binary_tree.pxi b/sklearn/neighbors/_binary_tree.pxi index 37aa13b0a4f30..b64dbecac9e24 100755 --- a/sklearn/neighbors/_binary_tree.pxi +++ b/sklearn/neighbors/_binary_tree.pxi @@ -142,7 +142,6 @@ # BinaryTree tree2, ITYPE_t i_node2): # """Compute the maximum distance between two nodes""" -cimport cython cimport numpy as np from libc.math cimport fabs, sqrt, exp, cos, pow, log, lgamma from libc.math cimport fmin, fmax @@ -152,8 +151,7 @@ from libc.string cimport memcpy import numpy as np import warnings -from ._dist_metrics cimport (DistanceMetric, euclidean_dist, euclidean_rdist, - euclidean_dist_to_rdist, euclidean_rdist_to_dist) +from ..metrics._dist_metrics cimport (DistanceMetric, euclidean_dist, euclidean_rdist, euclidean_dist_to_rdist) from ._partition_nodes cimport partition_node_indices @@ -796,7 +794,7 @@ def newObj(obj): ###################################################################### # define the reverse mapping of VALID_METRICS -from ._dist_metrics import get_valid_metric_ids +from sklearn.metrics._dist_metrics import get_valid_metric_ids VALID_METRIC_IDS = get_valid_metric_ids(VALID_METRICS) diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index 1e47e1b8020f2..bf433fea30aea 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -67,8 +67,8 @@ class KNeighborsClassifier(KNeighborsMixin, ClassifierMixin, NeighborsBase): metric : str or callable, default='minkowski' the distance metric to use for the tree. The default metric is minkowski, and with p=2 is equivalent to the standard Euclidean - metric. See the documentation of :class:`DistanceMetric` for a - list of available metrics. + metric. See the documentation of :class:`metrics.DistanceMetric` + for a list of available metrics. If metric is "precomputed", X is assumed to be a distance matrix and must be square during fit. X may be a :term:`sparse graph`, in which case only "nonzero" elements may be considered neighbors. @@ -339,8 +339,8 @@ class RadiusNeighborsClassifier(RadiusNeighborsMixin, ClassifierMixin, Neighbors metric : str or callable, default='minkowski' the distance metric to use for the tree. The default metric is minkowski, and with p=2 is equivalent to the standard Euclidean - metric. See the documentation of :class:`DistanceMetric` for a - list of available metrics. + metric. See the documentation of :class:`metrics.DistanceMetric` + for a list of available metrics. If metric is "precomputed", X is assumed to be a distance matrix and must be square during fit. X may be a :term:`sparse graph`, in which case only "nonzero" elements may be considered neighbors. 
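For intuition on the "fast squared Euclidean" path that the hunks above dispatch to, here is a hedged NumPy sketch (not part of the patch) of the GEMM trick used by `FastSquaredEuclideanArgKmin`, including the precision caveat that motivates recomputing exact distances. Array names follow the Cython code; everything else is illustrative.

```python
import numpy as np

rng = np.random.RandomState(42)
X, Y, k = rng.rand(100, 20), rng.rand(300, 20), 5

# GEMM trick: ||x - y||² = ||x||² - 2 x.yᵀ + ||y||².  The argkmin over y is
# unchanged when the ||x||² term is dropped, so ranking only needs
# -2 X @ Y.T + ||Y||², and the matrix product is a single BLAS Level 3 GEMM.
Y_sq_norms = np.einsum("ij,ij->i", Y, Y)
approx = -2 * X @ Y.T + Y_sq_norms        # the "reduced" distances

indices = np.argsort(approx, axis=1)[:, :k]

# Exact distances are recomputed afterwards, as compute(return_distance=True)
# does, because the reduced terms can lose precision to catastrophic
# cancellation.
exact = np.sqrt(((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1))
reference = np.argsort(exact, axis=1)[:, :k]
assert (indices == reference).mean() > 0.99  # near-ties may order differently
```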
diff --git a/sklearn/neighbors/_graph.py b/sklearn/neighbors/_graph.py
index d5bcaf9408c72..1fcb568e5dff4 100644
--- a/sklearn/neighbors/_graph.py
+++ b/sklearn/neighbors/_graph.py
@@ -65,10 +65,11 @@ def kneighbors_graph(
         between neighbors according to the given metric.
 
     metric : str, default='minkowski'
-        The distance metric used to calculate the k-Neighbors for each sample
-        point. The DistanceMetric class gives a list of available metrics.
-        The default distance is 'euclidean' ('minkowski' metric with the p
-        param equal to 2.)
+        The distance metric used to calculate the k-Neighbors for each
+        sample point. The default distance is 'euclidean'
+        ('minkowski' metric with the p param equal to 2.)
+        See the documentation of :class:`metrics.DistanceMetric`
+        for a list of available metrics.
 
     p : int, default=2
         Power parameter for the Minkowski metric. When p = 1, this is
@@ -158,9 +159,10 @@ def radius_neighbors_graph(
 
     metric : str, default='minkowski'
         The distance metric used to calculate the neighbors within a
-        given radius for each sample point. The DistanceMetric class
-        gives a list of available metrics. The default distance is
+        given radius for each sample point. The default distance is
         'euclidean' ('minkowski' metric with the param equal to 2.)
+        See the documentation of :class:`metrics.DistanceMetric`
+        for a list of available metrics.
 
     p : int, default=2
         Power parameter for the Minkowski metric. When p = 1, this is
diff --git a/sklearn/neighbors/_partition_nodes.pxd b/sklearn/neighbors/_partition_nodes.pxd
index 1659801db469d..94b02002d7a1e 100644
--- a/sklearn/neighbors/_partition_nodes.pxd
+++ b/sklearn/neighbors/_partition_nodes.pxd
@@ -1,4 +1,4 @@
-from sklearn.utils._typedefs cimport DTYPE_t, ITYPE_t
+from ..utils._typedefs cimport DTYPE_t, ITYPE_t
 
 cdef int partition_node_indices(
     DTYPE_t *data,
diff --git a/sklearn/neighbors/_regression.py b/sklearn/neighbors/_regression.py
index fe536f06c20a5..77179f3bb317f 100644
--- a/sklearn/neighbors/_regression.py
+++ b/sklearn/neighbors/_regression.py
@@ -75,8 +75,8 @@ class KNeighborsRegressor(KNeighborsMixin, RegressorMixin, NeighborsBase):
     metric : str or callable, default='minkowski'
         the distance metric to use for the tree.  The default metric is
         minkowski, and with p=2 is equivalent to the standard Euclidean
-        metric. See the documentation of :class:`DistanceMetric` for a
-        list of available metrics.
+        metric. See the documentation of :class:`metrics.DistanceMetric`
+        for a list of available metrics.
         If metric is "precomputed", X is assumed to be a distance matrix and
         must be square during fit. X may be a :term:`sparse graph`,
         in which case only "nonzero" elements may be considered neighbors.
@@ -301,8 +301,8 @@ class RadiusNeighborsRegressor(RadiusNeighborsMixin, RegressorMixin, NeighborsBa
     metric : str or callable, default='minkowski'
         the distance metric to use for the tree.  The default metric is
         minkowski, and with p=2 is equivalent to the standard Euclidean
-        metric. See the documentation of :class:`DistanceMetric` for a
-        list of available metrics.
+        metric. See the documentation of :class:`metrics.DistanceMetric`
+        for a list of available metrics.
         If metric is "precomputed", X is assumed to be a distance matrix and
         must be square during fit. X may be a :term:`sparse graph`,
         in which case only "nonzero" elements may be considered neighbors.
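The docstrings above now point to `metrics.DistanceMetric`. A small sketch of the new import path, also checking the "minkowski with p=2 equals Euclidean" claim these docstrings repeat (assumes this patch is applied):

```python
import numpy as np
from sklearn.metrics import DistanceMetric  # new location after this patch

X = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])

euclidean = DistanceMetric.get_metric("euclidean")
minkowski_p2 = DistanceMetric.get_metric("minkowski", p=2)

# As the updated docstrings state, minkowski with p=2 is the
# standard Euclidean metric.
assert np.allclose(euclidean.pairwise(X), minkowski_p2.pairwise(X))
```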
diff --git a/sklearn/neighbors/_unsupervised.py b/sklearn/neighbors/_unsupervised.py index 06566b0807b7a..b11df8af8790f 100644 --- a/sklearn/neighbors/_unsupervised.py +++ b/sklearn/neighbors/_unsupervised.py @@ -41,8 +41,8 @@ class NearestNeighbors(KNeighborsMixin, RadiusNeighborsMixin, NeighborsBase): metric : str or callable, default='minkowski' the distance metric to use for the tree. The default metric is minkowski, and with p=2 is equivalent to the standard Euclidean - metric. See the documentation of :class:`DistanceMetric` for a - list of available metrics. + metric. See the documentation of :class:`metrics.DistanceMetric` + for a list of available metrics. If metric is "precomputed", X is assumed to be a distance matrix and must be square during fit. X may be a :term:`sparse graph`, in which case only "nonzero" elements may be considered neighbors. diff --git a/sklearn/neighbors/setup.py b/sklearn/neighbors/setup.py index 34921de75041a..aa19ba501b18d 100644 --- a/sklearn/neighbors/setup.py +++ b/sklearn/neighbors/setup.py @@ -32,13 +32,6 @@ def configuration(parent_package="", top_path=None): libraries=libraries, ) - config.add_extension( - "_dist_metrics", - sources=["_dist_metrics.pyx"], - include_dirs=[numpy.get_include(), os.path.join(numpy.get_include(), "numpy")], - libraries=libraries, - ) - config.add_extension( "_quad_tree", sources=["_quad_tree.pyx"], diff --git a/sklearn/neighbors/tests/test_ball_tree.py b/sklearn/neighbors/tests/test_ball_tree.py index c751539f2a1ae..a823a03251a1b 100644 --- a/sklearn/neighbors/tests/test_ball_tree.py +++ b/sklearn/neighbors/tests/test_ball_tree.py @@ -4,7 +4,6 @@ import pytest from numpy.testing import assert_array_almost_equal from sklearn.neighbors._ball_tree import BallTree -from sklearn.neighbors import DistanceMetric from sklearn.utils import check_random_state from sklearn.utils.validation import check_array from sklearn.utils._testing import _convert_container @@ -40,6 +39,8 @@ def brute_force_neighbors(X, Y, k, metric, **kwargs): + from sklearn.metrics import DistanceMetric + X, Y = check_array(X), check_array(Y) D = DistanceMetric.get_metric(metric, **kwargs).pairwise(Y, X) ind = np.argsort(D, axis=1)[:, :k] @@ -84,3 +85,13 @@ def test_array_object_type(): X = np.array([(1, 2, 3), (2, 5), (5, 5, 1, 2)], dtype=object) with pytest.raises(ValueError, match="setting an array element with a sequence"): BallTree(X) + + +def test_bad_pyfunc_metric(): + def wrong_distance(x, y): + return "1" + + X = np.ones((5, 2)) + msg = "Custom distance function must accept two vectors" + with pytest.raises(TypeError, match=msg): + BallTree(X, metric=wrong_distance) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index ee9e92b0347ee..959ed6bfd7210 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -57,6 +57,7 @@ SPARSE_OR_DENSE = SPARSE_TYPES + (np.asarray,) ALGORITHMS = ("ball_tree", "brute", "kd_tree", "auto") +COMMON_VALID_METRICS = set.intersection(*map(set, neighbors.VALID_METRICS.values())) P = (1, 2, 3, 4, np.inf) JOBLIB_BACKENDS = list(joblib.parallel.BACKENDS.keys()) @@ -77,31 +78,65 @@ def _weight_func(dist): return retval ** 2 +@pytest.mark.parametrize("n_samples", [10 ** i for i in [2, 3]]) +@pytest.mark.parametrize("n_features", [5, 10, 100]) +@pytest.mark.parametrize("n_query_pts", [1, 10, 100]) +@pytest.mark.parametrize("n_neighbors", [1, 10, 100]) +@pytest.mark.parametrize("metric", COMMON_VALID_METRICS) def 
test_unsupervised_kneighbors(
-    n_samples=20, n_features=5, n_query_pts=2, n_neighbors=5
+    n_samples,
+    n_features,
+    n_query_pts,
+    n_neighbors,
+    metric,
 ):
-    # Test unsupervised neighbors methods
-    X = rng.rand(n_samples, n_features)
+    # The different algorithms must return identical results
+    # on their common metrics, with and without returning
+    # distances
 
-    test = rng.rand(n_query_pts, n_features)
+    # Redefining the rng locally to use the same generated X
+    local_rng = np.random.RandomState(0)
+    X = local_rng.rand(n_samples, n_features)
 
-    for p in P:
-        results_nodist = []
-        results = []
+    test = local_rng.rand(n_query_pts, n_features)
 
-        for algorithm in ALGORITHMS:
-            neigh = neighbors.NearestNeighbors(
-                n_neighbors=n_neighbors, algorithm=algorithm, p=p
-            )
-            neigh.fit(X)
+    results_nodist = []
+    results = []
 
-            results_nodist.append(neigh.kneighbors(test, return_distance=False))
-            results.append(neigh.kneighbors(test, return_distance=True))
+    for algorithm in ALGORITHMS:
+        neigh = neighbors.NearestNeighbors(
+            n_neighbors=n_neighbors, algorithm=algorithm, metric=metric
+        )
+        neigh.fit(X)
 
-        for i in range(len(results) - 1):
-            assert_array_almost_equal(results_nodist[i], results[i][1])
-            assert_array_almost_equal(results[i][0], results[i + 1][0])
-            assert_array_almost_equal(results[i][1], results[i + 1][1])
+        results_nodist.append(neigh.kneighbors(test, return_distance=False))
+        results.append(neigh.kneighbors(test, return_distance=True))
+
+    for i in range(len(results) - 1):
+        algorithm = ALGORITHMS[i]
+        next_algorithm = ALGORITHMS[i + 1]
+
+        indices_no_dist = results_nodist[i]
+        distances, next_distances = results[i][0], results[i + 1][0]
+        indices, next_indices = results[i][1], results[i + 1][1]
+        assert_array_equal(
+            indices_no_dist,
+            indices,
+            err_msg=f"The '{algorithm}' algorithm returns different "
+            f"indices depending on 'return_distance'.",
+        )
+        assert_array_equal(
+            indices,
+            next_indices,
+            err_msg=f"The '{algorithm}' and '{next_algorithm}' "
+            f"algorithms return different indices.",
+        )
+        assert_array_equal(
+            distances,
+            next_distances,
+            err_msg=f"The '{algorithm}' and '{next_algorithm}' "
+            f"algorithms return different distances.",
+        )
 
 
 @pytest.mark.parametrize(
@@ -1299,19 +1334,9 @@ def test_neighbors_metrics(n_samples=20, n_features=3, n_query_pts=2, n_neighbor
 
         neigh.fit(X[:, feature_sl])
 
-        # wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0
-        ExceptionToAssert = None
-        if (
-            metric == "wminkowski"
-            and algorithm == "brute"
-            and sp_version >= parse_version("1.6.0")
-        ):
-            ExceptionToAssert = DeprecationWarning
-
-        with pytest.warns(ExceptionToAssert):
-            results[algorithm] = neigh.kneighbors(
-                test[:, feature_sl], return_distance=True
-            )
+        results[algorithm] = neigh.kneighbors(
+            test[:, feature_sl], return_distance=True
+        )
 
     assert_array_almost_equal(results["brute"][0], results["ball_tree"][0])
     assert_array_almost_equal(results["brute"][1], results["ball_tree"][1])
@@ -1518,49 +1543,48 @@ def test_k_and_radius_neighbors_X_None():
     )
 
 
-def test_k_and_radius_neighbors_duplicates():
+@pytest.mark.parametrize("algorithm", ALGORITHMS)
+def test_k_and_radius_neighbors_duplicates(algorithm):
     # Test behavior of kneighbors when duplicates are present in query
-
-    for algorithm in ALGORITHMS:
-        nn = neighbors.NearestNeighbors(n_neighbors=1, algorithm=algorithm)
-        nn.fit([[0], [1]])
-
-        # Do not do anything special to duplicates.
@@ -1299,19 +1334,9 @@ def test_neighbors_metrics(n_samples=20, n_features=3, n_query_pts=2, n_neighbor
         neigh.fit(X[:, feature_sl])
 
-        # wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0
-        ExceptionToAssert = None
-        if (
-            metric == "wminkowski"
-            and algorithm == "brute"
-            and sp_version >= parse_version("1.6.0")
-        ):
-            ExceptionToAssert = DeprecationWarning
-
-        with pytest.warns(ExceptionToAssert):
-            results[algorithm] = neigh.kneighbors(
-                test[:, feature_sl], return_distance=True
-            )
+        results[algorithm] = neigh.kneighbors(
+            test[:, feature_sl], return_distance=True
+        )
 
     assert_array_almost_equal(results["brute"][0], results["ball_tree"][0])
     assert_array_almost_equal(results["brute"][1], results["ball_tree"][1])
@@ -1518,49 +1543,48 @@ def test_k_and_radius_neighbors_X_None():
     )
 
 
-def test_k_and_radius_neighbors_duplicates():
+@pytest.mark.parametrize("algorithm", ALGORITHMS)
+def test_k_and_radius_neighbors_duplicates(algorithm):
     # Test behavior of kneighbors when duplicates are present in query
-
-    for algorithm in ALGORITHMS:
-        nn = neighbors.NearestNeighbors(n_neighbors=1, algorithm=algorithm)
-        nn.fit([[0], [1]])
-
-        # Do not do anything special to duplicates.
-        kng = nn.kneighbors_graph([[0], [1]], mode="distance")
-        assert_array_equal(kng.A, np.array([[0.0, 0.0], [0.0, 0.0]]))
-        assert_array_equal(kng.data, [0.0, 0.0])
-        assert_array_equal(kng.indices, [0, 1])
-
-        dist, ind = nn.radius_neighbors([[0], [1]], radius=1.5)
-        check_object_arrays(dist, [[0, 1], [1, 0]])
-        check_object_arrays(ind, [[0, 1], [0, 1]])
-
-        rng = nn.radius_neighbors_graph([[0], [1]], radius=1.5)
-        assert_array_equal(rng.A, np.ones((2, 2)))
-
-        rng = nn.radius_neighbors_graph([[0], [1]], radius=1.5, mode="distance")
-        rng.sort_indices()
-        assert_array_equal(rng.A, [[0, 1], [1, 0]])
-        assert_array_equal(rng.indices, [0, 1, 0, 1])
-        assert_array_equal(rng.data, [0, 1, 1, 0])
-
-        # Mask the first duplicates when n_duplicates > n_neighbors.
-        X = np.ones((3, 1))
-        nn = neighbors.NearestNeighbors(n_neighbors=1, algorithm="brute")
-        nn.fit(X)
-        dist, ind = nn.kneighbors()
-        assert_array_equal(dist, np.zeros((3, 1)))
-        assert_array_equal(ind, [[1], [0], [1]])
-
-        # Test that zeros are explicitly marked in kneighbors_graph.
-        kng = nn.kneighbors_graph(mode="distance")
-        assert_array_equal(kng.A, np.zeros((3, 3)))
-        assert_array_equal(kng.data, np.zeros(3))
-        assert_array_equal(kng.indices, [1.0, 0.0, 1.0])
-        assert_array_equal(
-            nn.kneighbors_graph().A,
-            np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),
-        )
+    nn = neighbors.NearestNeighbors(n_neighbors=1, algorithm=algorithm)
+    nn.fit([[0], [1]])
+
+    # Do not do anything special to duplicates.
+    kng = nn.kneighbors_graph([[0], [1]], mode="distance")
+    assert_array_equal(kng.A, np.array([[0.0, 0.0], [0.0, 0.0]]))
+    assert_array_equal(kng.data, [0.0, 0.0])
+    assert_array_equal(kng.indices, [0, 1])
+
+    dist, ind = nn.radius_neighbors([[0], [1]], radius=1.5)
+    check_object_arrays(dist, [[0, 1], [1, 0]])
+    check_object_arrays(ind, [[0, 1], [0, 1]])
+
+    rng = nn.radius_neighbors_graph([[0], [1]], radius=1.5)
+    assert_array_equal(rng.A, np.ones((2, 2)))
+
+    rng = nn.radius_neighbors_graph([[0], [1]], radius=1.5, mode="distance")
+    rng.sort_indices()
+    assert_array_equal(rng.A, [[0, 1], [1, 0]])
+    assert_array_equal(rng.indices, [0, 1, 0, 1])
+    assert_array_equal(rng.data, [0, 1, 1, 0])
+
+    # Mask the first duplicates when n_duplicates > n_neighbors.
+    X = np.ones((3, 1))
+    nn = neighbors.NearestNeighbors(n_neighbors=1, algorithm="brute")
+    nn.fit(X)
+    dist, ind = nn.kneighbors()
+    assert_array_equal(dist, np.zeros((3, 1)))
+    assert_array_equal(ind, [[1], [0], [1]])
+
+    # Test that zeros are explicitly marked in kneighbors_graph.
+    kng = nn.kneighbors_graph(mode="distance")
+    assert_array_equal(kng.A, np.zeros((3, 3)))
+    assert_array_equal(kng.data, np.zeros(3))
+    assert_array_equal(kng.indices, [1.0, 0.0, 1.0])
+    assert_array_equal(
+        nn.kneighbors_graph().A,
+        np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),
+    )
 
 
 def test_include_self_neighbors_graph():
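Note on the duplicates hunk above: `kneighbors_graph(mode="distance")` must store zero distances as explicit CSR entries, otherwise a zero distance to a duplicate point would be indistinguishable from "no edge". A standalone SciPy sketch of that distinction (illustrative only, not scikit-learn code):

import numpy as np
from scipy.sparse import csr_matrix

# Three explicitly stored zeros: the dense view is all zeros, but the
# neighbor structure survives in .indices/.data, which is what the
# assertions above inspect.
explicit = csr_matrix((np.zeros(3), [1, 0, 1], [0, 1, 2, 3]), shape=(3, 3))
empty = csr_matrix((3, 3))

assert np.array_equal(explicit.toarray(), empty.toarray())  # both all-zero
assert explicit.nnz == 3 and empty.nnz == 0  # but not the same graph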
@@ -1801,15 +1825,11 @@ def test_pairwise_deprecated(NearestNeighbors):
 @pytest.mark.parametrize("d", [5, 10, 100])
 @pytest.mark.parametrize("ratio_train_test", [10, 2, 1, 0.5])
 @pytest.mark.parametrize("n_neighbors", [1, 10, 100, 1000])
-@pytest.mark.parametrize("chunk_size", [2 ** i for i in range(8, 11)])
-@pytest.mark.parametrize("strategy", ["chunk_on_train", "chunk_on_test"])
 def test_fast_sqeuclidean_correctness(
     n,
     d,
     ratio_train_test,
     n_neighbors,
-    chunk_size,
-    strategy,
     dtype=np.float64,
 ):
     # The fast squared euclidean strategy must return results
@@ -1850,6 +1870,10 @@ def test_fast_sqeuclidean_correctness(
 @pytest.mark.parametrize("d", [5, 10, 100, 500])
 @pytest.mark.parametrize("n_neighbors", [1, 10, 100, 1000])
 @pytest.mark.parametrize("translation", [10 ** i for i in [2, 3, 4, 5, 6, 7]])
+@pytest.mark.skip(
+    reason="Long test, translation invariance should "
+    "have its own study: skipping for now"
+)
 def test_fast_sqeuclidean_translation_invariance(
     n,
     d,
diff --git a/sklearn/neighbors/tests/test_neighbors_tree.py b/sklearn/neighbors/tests/test_neighbors_tree.py
index de34b4d230171..e043ffb730708 100644
--- a/sklearn/neighbors/tests/test_neighbors_tree.py
+++ b/sklearn/neighbors/tests/test_neighbors_tree.py
@@ -6,7 +6,7 @@ import numpy as np
 import pytest
 
-from sklearn.neighbors import DistanceMetric
+from sklearn.metrics import DistanceMetric
 from sklearn.neighbors._ball_tree import (
     BallTree,
     kernel_norm,
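Closing note: the rewritten `test_unsupervised_kneighbors` earlier in this patch relies on the `kneighbors` return convention that its error messages describe: `return_distance=True` yields a `(distances, indices)` pair, while `return_distance=False` yields only the indices. A minimal usage sketch, with made-up data, of the invariant being asserted:

import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.random.RandomState(0).rand(10, 3)
nn = NearestNeighbors(n_neighbors=2).fit(X)

dist, ind = nn.kneighbors(X[:2], return_distance=True)
ind_only = nn.kneighbors(X[:2], return_distance=False)

# Indices must not depend on whether distances are also returned.
assert np.array_equal(ind, ind_only)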