diff --git a/benchmarks/bench_saga.py b/benchmarks/bench_saga.py index 581f7e3881e9e..340549ef240e1 100644 --- a/benchmarks/bench_saga.py +++ b/benchmarks/bench_saga.py @@ -7,8 +7,7 @@ import time import os -from joblib import Parallel -from sklearn.utils.fixes import delayed +from sklearn.utils.parallel import delayed, Parallel import matplotlib.pyplot as plt import numpy as np diff --git a/build_tools/azure/linting.sh b/build_tools/azure/linting.sh index 21ef53c8012dc..9cc57c5f06066 100755 --- a/build_tools/azure/linting.sh +++ b/build_tools/azure/linting.sh @@ -34,10 +34,15 @@ then exit 1 fi -joblib_import="$(git grep -l -A 10 -E "joblib import.+delayed" -- "*.py" ":!sklearn/utils/_joblib.py" ":!sklearn/utils/fixes.py")" - -if [ ! -z "$joblib_import" ]; then - echo "Use from sklearn.utils.fixes import delayed instead of joblib delayed. The following files contains imports to joblib.delayed:" - echo "$joblib_import" +joblib_delayed_import="$(git grep -l -A 10 -E "joblib import.+delayed" -- "*.py" ":!sklearn/utils/_joblib.py" ":!sklearn/utils/parallel.py")" +if [ ! -z "$joblib_delayed_import" ]; then + echo "Use from sklearn.utils.parallel import delayed instead of joblib delayed. The following files contains imports to joblib.delayed:" + echo "$joblib_delayed_import" + exit 1 +fi +joblib_Parallel_import="$(git grep -l -A 10 -E "joblib import.+Parallel" -- "*.py" ":!sklearn/utils/_joblib.py" ":!sklearn/utils/parallel.py")" +if [ ! -z "$joblib_Parallel_import" ]; then + echo "Use from sklearn.utils.parallel import Parallel instead of joblib Parallel. The following files contains imports to joblib.Parallel:" + echo "$joblib_Parallel_import" exit 1 fi diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index d55becb0c512a..b1a50f65f0f7b 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -1666,9 +1666,16 @@ Utilities from joblib: :toctree: generated/ :template: function.rst + utils.parallel.delayed utils.parallel_backend utils.register_parallel_backend +.. autosummary:: + :toctree: generated/ + :template: class.rst + + utils.parallel.Parallel + Recently deprecated =================== diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 3bc338e68f5de..0189ae36ba1aa 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -9,6 +9,16 @@ Version 1.2.1 **In Development** +Changes impacting all modules +----------------------------- + +- |Fix| Fix a bug where the current configuration was ignored in estimators using + `n_jobs > 1`. This bug was triggered for tasks dispatched by the auxillary + thread of `joblib` as :func:`sklearn.get_config` used to access an empty thread + local configuration instead of the configuration visible from the thread where + `joblib.Parallel` was first called. + :pr:`25363` by :user:`Guillaume Lemaitre `. + Changed models -------------- @@ -139,6 +149,13 @@ Changelog boolean. The type is maintained, instead of converting to `float64.` :pr:`25147` by :user:`Tim Head `. +- |API| :func:`utils.fixes.delayed` is deprecated in 1.2.1 and will be removed + in 1.5. Instead, import :func:`utils.parallel.delayed` and use it in + conjunction with the newly introduced :func:`utils.parallel.Parallel` + to ensure proper propagation of the scikit-learn configuration to + the workers. + :pr:`25363` by :user:`Guillaume Lemaitre `. + .. _changes_1_2: Version 1.2.0 diff --git a/sklearn/calibration.py b/sklearn/calibration.py index 8b3e39e31f3bb..2c4a33616d22c 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -14,7 +14,6 @@ from math import log import numpy as np -from joblib import Parallel from scipy.special import expit from scipy.special import xlogy @@ -36,7 +35,7 @@ ) from .utils.multiclass import check_classification_targets -from .utils.fixes import delayed +from .utils.parallel import delayed, Parallel from .utils._param_validation import StrOptions, HasMethods, Hidden from .utils.validation import ( _check_fit_params, diff --git a/sklearn/cluster/_mean_shift.py b/sklearn/cluster/_mean_shift.py index 29723ad81d2d0..8a60bf770c958 100644 --- a/sklearn/cluster/_mean_shift.py +++ b/sklearn/cluster/_mean_shift.py @@ -16,13 +16,12 @@ import numpy as np import warnings -from joblib import Parallel from numbers import Integral, Real from collections import defaultdict from ..utils._param_validation import Interval from ..utils.validation import check_is_fitted -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel from ..utils import check_random_state, gen_batches, check_array from ..base import BaseEstimator, ClusterMixin from ..neighbors import NearestNeighbors diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index ba54441e2b63e..aabe846ea174a 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -11,7 +11,6 @@ import numpy as np from scipy import sparse -from joblib import Parallel from ..base import clone, TransformerMixin from ..utils._estimator_html_repr import _VisualBlock @@ -24,7 +23,7 @@ from ..utils import check_pandas_support from ..utils.metaestimators import _BaseComposition from ..utils.validation import check_array, check_is_fitted, _check_feature_names_in -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel __all__ = ["ColumnTransformer", "make_column_transformer", "make_column_selector"] diff --git a/sklearn/covariance/_graph_lasso.py b/sklearn/covariance/_graph_lasso.py index 6b6116ecce040..dfbb7e75753c5 100644 --- a/sklearn/covariance/_graph_lasso.py +++ b/sklearn/covariance/_graph_lasso.py @@ -13,7 +13,6 @@ from numbers import Integral, Real import numpy as np from scipy import linalg -from joblib import Parallel from . import empirical_covariance, EmpiricalCovariance, log_likelihood @@ -23,7 +22,7 @@ check_random_state, check_scalar, ) -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel from ..utils._param_validation import Interval, StrOptions # mypy error: Module 'sklearn.linear_model' has no attribute '_cd_fast' diff --git a/sklearn/decomposition/_dict_learning.py b/sklearn/decomposition/_dict_learning.py index f734f7361944b..98e1e77742de2 100644 --- a/sklearn/decomposition/_dict_learning.py +++ b/sklearn/decomposition/_dict_learning.py @@ -13,7 +13,7 @@ import numpy as np from scipy import linalg -from joblib import Parallel, effective_n_jobs +from joblib import effective_n_jobs from ..base import BaseEstimator, TransformerMixin, ClassNamePrefixFeaturesOutMixin from ..utils import check_array, check_random_state, gen_even_slices, gen_batches @@ -21,7 +21,7 @@ from ..utils._param_validation import Hidden, Interval, StrOptions from ..utils.extmath import randomized_svd, row_norms, svd_flip from ..utils.validation import check_is_fitted -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel from ..linear_model import Lasso, orthogonal_mp_gram, LassoLars, Lars diff --git a/sklearn/decomposition/_lda.py b/sklearn/decomposition/_lda.py index d187611251eda..b7ef0af8eafd9 100644 --- a/sklearn/decomposition/_lda.py +++ b/sklearn/decomposition/_lda.py @@ -15,13 +15,13 @@ import numpy as np import scipy.sparse as sp from scipy.special import gammaln, logsumexp -from joblib import Parallel, effective_n_jobs +from joblib import effective_n_jobs from ..base import BaseEstimator, TransformerMixin, ClassNamePrefixFeaturesOutMixin from ..utils import check_random_state, gen_batches, gen_even_slices from ..utils.validation import check_non_negative from ..utils.validation import check_is_fitted -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel from ..utils._param_validation import Interval, StrOptions from ._online_lda_fast import ( diff --git a/sklearn/decomposition/tests/test_dict_learning.py b/sklearn/decomposition/tests/test_dict_learning.py index 16a695d1189ab..7c383390f5b60 100644 --- a/sklearn/decomposition/tests/test_dict_learning.py +++ b/sklearn/decomposition/tests/test_dict_learning.py @@ -5,8 +5,6 @@ from functools import partial import itertools -from joblib import Parallel - import sklearn from sklearn.base import clone @@ -14,6 +12,7 @@ from sklearn.exceptions import ConvergenceWarning from sklearn.utils import check_array +from sklearn.utils.parallel import Parallel from sklearn.utils._testing import assert_allclose from sklearn.utils._testing import assert_array_almost_equal diff --git a/sklearn/ensemble/_bagging.py b/sklearn/ensemble/_bagging.py index fbe54a8afb530..4586e55a59f97 100644 --- a/sklearn/ensemble/_bagging.py +++ b/sklearn/ensemble/_bagging.py @@ -12,8 +12,6 @@ from warnings import warn from functools import partial -from joblib import Parallel - from ._base import BaseEnsemble, _partition_estimators from ..base import ClassifierMixin, RegressorMixin from ..metrics import r2_score, accuracy_score @@ -25,7 +23,7 @@ from ..utils.random import sample_without_replacement from ..utils._param_validation import Interval, HasMethods, StrOptions from ..utils.validation import has_fit_parameter, check_is_fitted, _check_sample_weight -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel __all__ = ["BaggingClassifier", "BaggingRegressor"] diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index b13315a5c00a7..69f48e6344ce7 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -48,7 +48,6 @@ class calls the ``fit`` method of each sub-estimator on random samples import numpy as np from scipy.sparse import issparse from scipy.sparse import hstack as sparse_hstack -from joblib import Parallel from ..base import is_classifier from ..base import ClassifierMixin, MultiOutputMixin, RegressorMixin, TransformerMixin @@ -66,7 +65,7 @@ class calls the ``fit`` method of each sub-estimator on random samples from ..utils import check_random_state, compute_sample_weight from ..exceptions import DataConversionWarning from ._base import BaseEnsemble, _partition_estimators -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel from ..utils.multiclass import check_classification_targets, type_of_target from ..utils.validation import ( check_is_fitted, diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index 2e3d1b6db5798..e7468ddb5ac22 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -8,7 +8,6 @@ from numbers import Integral import numpy as np -from joblib import Parallel import scipy.sparse as sparse from ..base import clone @@ -33,7 +32,7 @@ from ..utils.metaestimators import available_if from ..utils.validation import check_is_fitted from ..utils.validation import column_or_1d -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel from ..utils._param_validation import HasMethods, StrOptions from ..utils.validation import _check_feature_names_in diff --git a/sklearn/ensemble/_voting.py b/sklearn/ensemble/_voting.py index 2e1e27b3b3de9..97db90fc0c172 100644 --- a/sklearn/ensemble/_voting.py +++ b/sklearn/ensemble/_voting.py @@ -18,8 +18,6 @@ import numpy as np -from joblib import Parallel - from ..base import ClassifierMixin from ..base import RegressorMixin from ..base import TransformerMixin @@ -36,7 +34,7 @@ from ..utils._param_validation import StrOptions from ..exceptions import NotFittedError from ..utils._estimator_html_repr import _VisualBlock -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel class _BaseVoting(TransformerMixin, _BaseHeterogeneousEnsemble): diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 088b0f12d5cc1..aebc5ea87d52b 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -17,16 +17,15 @@ from typing import Dict, Any import numpy as np -from joblib import Parallel from scipy.sparse import csr_matrix from scipy.sparse import csc_matrix from scipy.sparse import coo_matrix from scipy.special import comb -import pytest - import joblib +import pytest + import sklearn from sklearn.dummy import DummyRegressor from sklearn.metrics import mean_poisson_deviance @@ -55,6 +54,7 @@ >>>>>>> c3fca81536 (FIX Support read-only sparse datasets for `Tree`-based estimators (#25341)) from sklearn.model_selection import GridSearchCV from sklearn.svm import LinearSVC +from sklearn.utils.parallel import Parallel from sklearn.utils.validation import check_random_state from sklearn.metrics import mean_squared_error diff --git a/sklearn/feature_selection/_rfe.py b/sklearn/feature_selection/_rfe.py index a025fe7c36490..d105ba1ae3567 100644 --- a/sklearn/feature_selection/_rfe.py +++ b/sklearn/feature_selection/_rfe.py @@ -8,7 +8,7 @@ import numpy as np from numbers import Integral, Real -from joblib import Parallel, effective_n_jobs +from joblib import effective_n_jobs from ..utils.metaestimators import available_if @@ -16,7 +16,7 @@ from ..utils._param_validation import HasMethods, Interval from ..utils._tags import _safe_tags from ..utils.validation import check_is_fitted -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel from ..base import BaseEstimator from ..base import MetaEstimatorMixin from ..base import clone diff --git a/sklearn/inspection/_permutation_importance.py b/sklearn/inspection/_permutation_importance.py index 63ab4f69968d8..a418cb34b3540 100644 --- a/sklearn/inspection/_permutation_importance.py +++ b/sklearn/inspection/_permutation_importance.py @@ -1,7 +1,6 @@ """Permutation importance for estimators.""" import numbers import numpy as np -from joblib import Parallel from ..ensemble._bagging import _generate_indices from ..metrics import check_scoring @@ -10,7 +9,7 @@ from ..utils import Bunch, _safe_indexing from ..utils import check_random_state from ..utils import check_array -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel def _weights_scorer(scorer, estimator, X, y, sample_weight): diff --git a/sklearn/inspection/_plot/partial_dependence.py b/sklearn/inspection/_plot/partial_dependence.py index e3392eeb911b1..bba8e0dea826f 100644 --- a/sklearn/inspection/_plot/partial_dependence.py +++ b/sklearn/inspection/_plot/partial_dependence.py @@ -6,7 +6,6 @@ import numpy as np from scipy import sparse from scipy.stats.mstats import mquantiles -from joblib import Parallel from .. import partial_dependence from .._pd_utils import _check_feature_names, _get_feature_index @@ -16,7 +15,7 @@ from ...utils import check_matplotlib_support # noqa from ...utils import check_random_state from ...utils import _safe_indexing -from ...utils.fixes import delayed +from ...utils.parallel import delayed, Parallel from ...utils._encode import _unique diff --git a/sklearn/linear_model/_base.py b/sklearn/linear_model/_base.py index a3ac37257a98b..987ae57c12250 100644 --- a/sklearn/linear_model/_base.py +++ b/sklearn/linear_model/_base.py @@ -25,7 +25,6 @@ from scipy import sparse from scipy.sparse.linalg import lsqr from scipy.special import expit -from joblib import Parallel from numbers import Integral from ..base import BaseEstimator, ClassifierMixin, RegressorMixin, MultiOutputMixin @@ -40,7 +39,7 @@ from ..utils._seq_dataset import ArrayDataset32, CSRDataset32 from ..utils._seq_dataset import ArrayDataset64, CSRDataset64 from ..utils.validation import check_is_fitted, _check_sample_weight -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel # TODO: bayesian_ridge_regression and bayesian_regression_ard # should be squashed into its respective objects. diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py index 67fffac39d466..c1783be1ae45f 100644 --- a/sklearn/linear_model/_coordinate_descent.py +++ b/sklearn/linear_model/_coordinate_descent.py @@ -14,7 +14,7 @@ import numpy as np from scipy import sparse -from joblib import Parallel, effective_n_jobs +from joblib import effective_n_jobs from ._base import LinearModel, _pre_fit from ..base import RegressorMixin, MultiOutputMixin @@ -30,7 +30,7 @@ check_is_fitted, column_or_1d, ) -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel # mypy error: Module 'sklearn.linear_model' has no attribute '_cd_fast' from . import _cd_fast as cd_fast # type: ignore diff --git a/sklearn/linear_model/_least_angle.py b/sklearn/linear_model/_least_angle.py index f3b3ba33517b8..7ed7b27811ec6 100644 --- a/sklearn/linear_model/_least_angle.py +++ b/sklearn/linear_model/_least_angle.py @@ -16,7 +16,6 @@ import numpy as np from scipy import linalg, interpolate from scipy.linalg.lapack import get_lapack_funcs -from joblib import Parallel from ._base import LinearModel, LinearRegression from ._base import _deprecate_normalize, _preprocess_data @@ -28,7 +27,7 @@ from ..utils._param_validation import Hidden, Interval, StrOptions from ..model_selection import check_cv from ..exceptions import ConvergenceWarning -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel SOLVE_TRIANGULAR_ARGS = {"check_finite": False} diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 42f0ffb56d8e3..0d0da3983c664 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -16,7 +16,7 @@ import numpy as np from scipy import optimize -from joblib import Parallel, effective_n_jobs +from joblib import effective_n_jobs from sklearn.metrics import get_scorer_names @@ -34,7 +34,7 @@ from ..utils.optimize import _newton_cg, _check_optimize_result from ..utils.validation import check_is_fitted, _check_sample_weight from ..utils.multiclass import check_classification_targets -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel from ..utils._param_validation import StrOptions, Interval from ..model_selection import check_cv from ..metrics import get_scorer diff --git a/sklearn/linear_model/_omp.py b/sklearn/linear_model/_omp.py index 819cfbfb21adc..f0bd04568c473 100644 --- a/sklearn/linear_model/_omp.py +++ b/sklearn/linear_model/_omp.py @@ -12,12 +12,11 @@ import numpy as np from scipy import linalg from scipy.linalg.lapack import get_lapack_funcs -from joblib import Parallel from ._base import LinearModel, _pre_fit, _deprecate_normalize from ..base import RegressorMixin, MultiOutputMixin from ..utils import as_float_array, check_array -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel from ..utils._param_validation import Hidden, Interval, StrOptions from ..model_selection import check_cv diff --git a/sklearn/linear_model/_stochastic_gradient.py b/sklearn/linear_model/_stochastic_gradient.py index baa4361cac9ef..1361ef6a1c609 100644 --- a/sklearn/linear_model/_stochastic_gradient.py +++ b/sklearn/linear_model/_stochastic_gradient.py @@ -12,8 +12,6 @@ from abc import ABCMeta, abstractmethod from numbers import Integral, Real -from joblib import Parallel - from ..base import clone, is_classifier from ._base import LinearClassifierMixin, SparseCoefMixin from ._base import make_dataset @@ -26,7 +24,7 @@ from ..utils._param_validation import Interval from ..utils._param_validation import StrOptions from ..utils._param_validation import Hidden -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel from ..exceptions import ConvergenceWarning from ..model_selection import StratifiedShuffleSplit, ShuffleSplit diff --git a/sklearn/linear_model/_theil_sen.py b/sklearn/linear_model/_theil_sen.py index ab9c883add0c6..67d6ca532a8ab 100644 --- a/sklearn/linear_model/_theil_sen.py +++ b/sklearn/linear_model/_theil_sen.py @@ -15,13 +15,13 @@ from scipy import linalg from scipy.special import binom from scipy.linalg.lapack import get_lapack_funcs -from joblib import Parallel, effective_n_jobs +from joblib import effective_n_jobs from ._base import LinearModel from ..base import RegressorMixin from ..utils import check_random_state from ..utils._param_validation import Interval -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel from ..exceptions import ConvergenceWarning _EPSILON = np.finfo(np.double).eps diff --git a/sklearn/manifold/_mds.py b/sklearn/manifold/_mds.py index bfa4c6160d9ce..d6f99c84f55f1 100644 --- a/sklearn/manifold/_mds.py +++ b/sklearn/manifold/_mds.py @@ -8,7 +8,7 @@ from numbers import Integral, Real import numpy as np -from joblib import Parallel, effective_n_jobs +from joblib import effective_n_jobs import warnings @@ -17,7 +17,7 @@ from ..utils import check_random_state, check_array, check_symmetric from ..isotonic import IsotonicRegression from ..utils._param_validation import Interval, StrOptions, Hidden -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel def _smacof_single( diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index b34722dc25df7..3d01d5eeaf12d 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -15,7 +15,7 @@ from scipy.spatial import distance from scipy.sparse import csr_matrix from scipy.sparse import issparse -from joblib import Parallel, effective_n_jobs +from joblib import effective_n_jobs from .. import config_context from ..utils.validation import _num_samples @@ -27,7 +27,7 @@ from ..utils.extmath import row_norms, safe_sparse_dot from ..preprocessing import normalize from ..utils._mask import _get_mask -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel from ..utils.fixes import sp_version, parse_version from ._pairwise_distances_reduction import ArgKmin diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index 14ef0d9651f2e..636a1429bd21b 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -33,13 +33,12 @@ from ._validation import _normalize_score_results from ._validation import _warn_or_raise_about_fit_failures from ..exceptions import NotFittedError -from joblib import Parallel from ..utils import check_random_state from ..utils.random import sample_without_replacement from ..utils._tags import _safe_tags from ..utils.validation import indexable, check_is_fitted, _check_fit_params from ..utils.metaestimators import available_if -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel from ..metrics._scorer import _check_multimetric_scoring from ..metrics import check_scoring diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index 1a13b770cd1c5..8a626fe1ce1f3 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -21,13 +21,13 @@ import numpy as np import scipy.sparse as sp -from joblib import Parallel, logger +from joblib import logger from ..base import is_classifier, clone from ..utils import indexable, check_random_state, _safe_indexing from ..utils.validation import _check_fit_params from ..utils.validation import _num_samples -from ..utils.fixes import delayed +from ..utils.parallel import delayed, Parallel from ..utils.metaestimators import _safe_split from ..metrics import check_scoring from ..metrics._scorer import _check_multimetric_scoring, _MultimetricScorer diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index 318066391a638..863435e5c7a1f 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -56,9 +56,7 @@ _ovr_decision_function, ) from .utils.metaestimators import _safe_split, available_if -from .utils.fixes import delayed - -from joblib import Parallel +from .utils.parallel import delayed, Parallel __all__ = [ "OneVsRestClassifier", diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py index ba3ddf9572232..4b7015dd40ece 100644 --- a/sklearn/multioutput.py +++ b/sklearn/multioutput.py @@ -17,7 +17,6 @@ import numpy as np import scipy.sparse as sp -from joblib import Parallel from abc import ABCMeta, abstractmethod from .base import BaseEstimator, clone, MetaEstimatorMixin @@ -31,7 +30,7 @@ has_fit_parameter, _check_fit_params, ) -from .utils.fixes import delayed +from .utils.parallel import delayed, Parallel from .utils._param_validation import HasMethods, StrOptions __all__ = [ diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 3b01824a3a73a..2e97bab4a4f8d 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -16,7 +16,7 @@ import numpy as np from scipy.sparse import csr_matrix, issparse -from joblib import Parallel, effective_n_jobs +from joblib import effective_n_jobs from ._ball_tree import BallTree from ._kd_tree import KDTree @@ -37,8 +37,8 @@ from ..utils.validation import check_is_fitted from ..utils.validation import check_non_negative from ..utils._param_validation import Interval, StrOptions -from ..utils.fixes import delayed, sp_version -from ..utils.fixes import parse_version +from ..utils.parallel import delayed, Parallel +from ..utils.fixes import parse_version, sp_version from ..exceptions import DataConversionWarning, EfficiencyWarning VALID_METRICS = dict( diff --git a/sklearn/neighbors/tests/test_kd_tree.py b/sklearn/neighbors/tests/test_kd_tree.py index d8d9437636d1d..525c15436e24c 100644 --- a/sklearn/neighbors/tests/test_kd_tree.py +++ b/sklearn/neighbors/tests/test_kd_tree.py @@ -1,7 +1,6 @@ import numpy as np import pytest -from joblib import Parallel -from sklearn.utils.fixes import delayed +from sklearn.utils.parallel import delayed, Parallel from sklearn.neighbors._kd_tree import KDTree diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index b3a4d180a4c68..50f8f11fa212e 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -14,7 +14,6 @@ import numpy as np from scipy import sparse -from joblib import Parallel from .base import clone, TransformerMixin from .preprocessing import FunctionTransformer @@ -29,7 +28,7 @@ from .utils.validation import check_is_fitted from .utils import check_pandas_support from .utils._set_output import _safe_set_output, _get_output_config -from .utils.fixes import delayed +from .utils.parallel import delayed, Parallel from .exceptions import NotFittedError from .utils.metaestimators import _BaseComposition diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index a0b8f29662b69..bcc4c233e7ea3 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -1,11 +1,10 @@ import time from concurrent.futures import ThreadPoolExecutor -from joblib import Parallel import pytest from sklearn import get_config, set_config, config_context -from sklearn.utils.fixes import delayed +from sklearn.utils.parallel import delayed, Parallel def test_config_context(): @@ -120,15 +119,15 @@ def test_config_threadsafe_joblib(backend): should be the same as the value passed to the function. In other words, it is not influenced by the other job setting assume_finite to True. """ - assume_finites = [False, True] - sleep_durations = [0.1, 0.2] + assume_finites = [False, True, False, True] + sleep_durations = [0.1, 0.2, 0.1, 0.2] items = Parallel(backend=backend, n_jobs=2)( delayed(set_assume_finite)(assume_finite, sleep_dur) for assume_finite, sleep_dur in zip(assume_finites, sleep_durations) ) - assert items == [False, True] + assert items == [False, True, False, True] def test_config_threadsafe(): @@ -136,8 +135,8 @@ def test_config_threadsafe(): between threads. Same test as `test_config_threadsafe_joblib` but with `ThreadPoolExecutor`.""" - assume_finites = [False, True] - sleep_durations = [0.1, 0.2] + assume_finites = [False, True, False, True] + sleep_durations = [0.1, 0.2, 0.1, 0.2] with ThreadPoolExecutor(max_workers=2) as e: items = [ @@ -145,4 +144,4 @@ def test_config_threadsafe(): for output in e.map(set_assume_finite, assume_finites, sleep_durations) ] - assert items == [False, True] + assert items == [False, True, False, True] diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index 53fe56ea68144..37a25ff96ba00 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -10,9 +10,7 @@ # # License: BSD 3 clause -from functools import update_wrapper from importlib import resources -import functools import sys import sklearn @@ -20,7 +18,8 @@ import scipy import scipy.stats import threadpoolctl -from .._config import config_context, get_config + +from .deprecation import deprecated from ..externals._packaging.version import parse as parse_version @@ -106,30 +105,6 @@ def _eigh(*args, **kwargs): return scipy.linalg.eigh(*args, eigvals=eigvals, **kwargs) -# remove when https://github.com/joblib/joblib/issues/1071 is fixed -def delayed(function): - """Decorator used to capture the arguments of a function.""" - - @functools.wraps(function) - def delayed_function(*args, **kwargs): - return _FuncWrapper(function), args, kwargs - - return delayed_function - - -class _FuncWrapper: - """ "Load the global configuration before calling the function.""" - - def __init__(self, function): - self.function = function - self.config = get_config() - update_wrapper(self, self.function) - - def __call__(self, *args, **kwargs): - with config_context(**self.config): - return self.function(*args, **kwargs) - - # Rename the `method` kwarg to `interpolation` for NumPy < 1.22, because # `interpolation` kwarg was deprecated in favor of `method` in NumPy >= 1.22. def _percentile(a, q, *, method="linear", **kwargs): @@ -178,6 +153,16 @@ def threadpool_info(): threadpool_info.__doc__ = threadpoolctl.threadpool_info.__doc__ +@deprecated( + "The function `delayed` has been moved from `sklearn.utils.fixes` to " + "`sklearn.utils.parallel`. This import path will be removed in 1.5." +) +def delayed(function): + from sklearn.utils.parallel import delayed + + return delayed(function) + + # TODO: Remove when SciPy 1.11 is the minimum supported version def _mode(a, axis=0): if sp_version >= parse_version("1.9.0"): diff --git a/sklearn/utils/parallel.py b/sklearn/utils/parallel.py new file mode 100644 index 0000000000000..48a31ee93d8a0 --- /dev/null +++ b/sklearn/utils/parallel.py @@ -0,0 +1,123 @@ +"""Module that customize joblib tools for scikit-learn usage.""" + +import functools +import warnings +from functools import update_wrapper + +import joblib + +from .._config import config_context, get_config + + +def _with_config(delayed_func, config): + """Helper function that intends to attach a config to a delayed function.""" + if hasattr(delayed_func, "with_config"): + return delayed_func.with_config(config) + else: + warnings.warn( + "`sklearn.utils.parallel.Parallel` needs to be used in " + "conjunction with `sklearn.utils.parallel.delayed` instead of " + "`joblib.delayed` to correctly propagate the scikit-learn " + "configuration to the joblib workers.", + UserWarning, + ) + return delayed_func + + +class Parallel(joblib.Parallel): + """Tweak of :class:`joblib.Parallel` that propagates the scikit-learn configuration. + + This subclass of :class:`joblib.Parallel` ensures that the active configuration + (thread-local) of scikit-learn is propagated to the parallel workers for the + duration of the execution of the parallel tasks. + + The API does not change and you can refer to :class:`joblib.Parallel` + documentation for more details. + + .. versionadded:: 1.3 + """ + + def __call__(self, iterable): + """Dispatch the tasks and return the results. + + Parameters + ---------- + iterable : iterable + Iterable containing tuples of (delayed_function, args, kwargs) that should + be consumed. + + Returns + ------- + results : list + List of results of the tasks. + """ + # Capture the thread-local scikit-learn configuration at the time + # Parallel.__call__ is issued since the tasks can be dispatched + # in a different thread depending on the backend and on the value of + # pre_dispatch and n_jobs. + config = get_config() + iterable_with_config = ( + (_with_config(delayed_func, config), args, kwargs) + for delayed_func, args, kwargs in iterable + ) + return super().__call__(iterable_with_config) + + +# remove when https://github.com/joblib/joblib/issues/1071 is fixed +def delayed(function): + """Decorator used to capture the arguments of a function. + + This alternative to `joblib.delayed` is meant to be used in conjunction + with `sklearn.utils.parallel.Parallel`. The latter captures the the scikit- + learn configuration by calling `sklearn.get_config()` in the current + thread, prior to dispatching the first task. The captured configuration is + then propagated and enabled for the duration of the execution of the + delayed function in the joblib workers. + + .. versionchanged:: 1.3 + `delayed` was moved from `sklearn.utils.fixes` to `sklearn.utils.parallel` + in scikit-learn 1.3. + + Parameters + ---------- + function : callable + The function to be delayed. + + Returns + ------- + output: tuple + Tuple containing the delayed function, the positional arguments, and the + keyword arguments. + """ + + @functools.wraps(function) + def delayed_function(*args, **kwargs): + return _FuncWrapper(function), args, kwargs + + return delayed_function + + +class _FuncWrapper: + """Load the global configuration before calling the function.""" + + def __init__(self, function): + self.function = function + update_wrapper(self, self.function) + + def with_config(self, config): + self.config = config + return self + + def __call__(self, *args, **kwargs): + config = getattr(self, "config", None) + if config is None: + warnings.warn( + "`sklearn.utils.parallel.delayed` should be used with " + "`sklearn.utils.parallel.Parallel` to make it possible to propagate " + "the scikit-learn configuration of the current thread to the " + "joblib workers.", + UserWarning, + ) + config = {} + with config_context(**config): + return self.function(*args, **kwargs) diff --git a/sklearn/utils/tests/test_fixes.py b/sklearn/utils/tests/test_fixes.py index 3566897da5efc..64db4006b5f1a 100644 --- a/sklearn/utils/tests/test_fixes.py +++ b/sklearn/utils/tests/test_fixes.py @@ -11,8 +11,7 @@ from sklearn.utils._testing import assert_array_equal -from sklearn.utils.fixes import _object_dtype_isnan -from sklearn.utils.fixes import loguniform +from sklearn.utils.fixes import _object_dtype_isnan, delayed, loguniform @pytest.mark.parametrize("dtype, val", ([object, 1], [object, "a"], [float, 1])) @@ -46,3 +45,14 @@ def test_loguniform(low, high, base): assert loguniform(base**low, base**high).rvs(random_state=0) == loguniform( base**low, base**high ).rvs(random_state=0) + + +def test_delayed_deprecation(): + """Check that we issue the FutureWarning regarding the deprecation of delayed.""" + + def func(x): + return x + + warn_msg = "The function `delayed` has been moved from `sklearn.utils.fixes`" + with pytest.warns(FutureWarning, match=warn_msg): + delayed(func) diff --git a/sklearn/utils/tests/test_parallel.py b/sklearn/utils/tests/test_parallel.py index dfecd7b464168..2f56c584300d1 100644 --- a/sklearn/utils/tests/test_parallel.py +++ b/sklearn/utils/tests/test_parallel.py @@ -1,10 +1,19 @@ -import pytest -from joblib import Parallel +import time +import joblib +import numpy as np +import pytest from numpy.testing import assert_array_equal -from sklearn._config import config_context, get_config -from sklearn.utils.fixes import delayed +from sklearn import config_context, get_config +from sklearn.compose import make_column_transformer +from sklearn.datasets import load_iris +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import GridSearchCV +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler + +from sklearn.utils.parallel import delayed, Parallel def get_working_memory(): @@ -22,3 +31,71 @@ def test_configuration_passes_through_to_joblib(n_jobs, backend): ) assert_array_equal(results, [123] * 2) + + +def test_parallel_delayed_warnings(): + """Informative warnings should be raised when mixing sklearn and joblib API""" + # We should issue a warning when one wants to use sklearn.utils.fixes.Parallel + # with joblib.delayed. The config will not be propagated to the workers. + warn_msg = "`sklearn.utils.parallel.Parallel` needs to be used in conjunction" + with pytest.warns(UserWarning, match=warn_msg) as records: + Parallel()(joblib.delayed(time.sleep)(0) for _ in range(10)) + assert len(records) == 10 + + # We should issue a warning if one wants to use sklearn.utils.fixes.delayed with + # joblib.Parallel + warn_msg = ( + "`sklearn.utils.parallel.delayed` should be used with " + "`sklearn.utils.parallel.Parallel` to make it possible to propagate" + ) + with pytest.warns(UserWarning, match=warn_msg) as records: + joblib.Parallel()(delayed(time.sleep)(0) for _ in range(10)) + assert len(records) == 10 + + +@pytest.mark.parametrize("n_jobs", [1, 2]) +def test_dispatch_config_parallel(n_jobs): + """Check that we properly dispatch the configuration in parallel processing. + + Non-regression test for: + https://github.com/scikit-learn/scikit-learn/issues/25239 + """ + pd = pytest.importorskip("pandas") + iris = load_iris(as_frame=True) + + class TransformerRequiredDataFrame(StandardScaler): + def fit(self, X, y=None): + assert isinstance(X, pd.DataFrame), "X should be a DataFrame" + return super().fit(X, y) + + def transform(self, X, y=None): + assert isinstance(X, pd.DataFrame), "X should be a DataFrame" + return super().transform(X, y) + + dropper = make_column_transformer( + ("drop", [0]), + remainder="passthrough", + n_jobs=n_jobs, + ) + param_grid = {"randomforestclassifier__max_depth": [1, 2, 3]} + search_cv = GridSearchCV( + make_pipeline( + dropper, + TransformerRequiredDataFrame(), + RandomForestClassifier(n_estimators=5, n_jobs=n_jobs), + ), + param_grid, + cv=5, + n_jobs=n_jobs, + error_score="raise", # this search should not fail + ) + + # make sure that `fit` would fail in case we don't request dataframe + with pytest.raises(AssertionError, match="X should be a DataFrame"): + search_cv.fit(iris.data, iris.target) + + with config_context(transform_output="pandas"): + # we expect each intermediate steps to output a DataFrame + search_cv.fit(iris.data, iris.target) + + assert not np.isnan(search_cv.cv_results_["mean_test_score"]).any()