
ENH Adds TargetEncoder #25334

Merged
merged 91 commits into scikit-learn:main on Mar 16, 2023

Conversation

thomasjpfan
Member

@thomasjpfan thomasjpfan commented Jan 8, 2023

Reference Issues/PRs

Closes #5853
Closes #9614
Supersedes #17323
Fixes or at least related to #24967

What does this implement/fix? Explain your changes.

This PR implements a target encoder which uses CV during fit_transform to prevent the target from leaking. transform uses the target encoding learned from all the training data. This means that fit_transform() != fit().transform().

The implementation uses Cython to learn the encoding, which provides a 10x speed up compared to a pure Python+NumPy approach. Cython is required because many encodings are learned during cross validation in fit_transform.
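As an illustration, here is a minimal sketch (using the TargetEncoder API as merged) of the fit_transform vs. fit + transform discrepancy; the exact values depend on the internal CV splits:

import numpy as np
from sklearn.preprocessing import TargetEncoder

rng = np.random.default_rng(0)
X = rng.integers(0, 5, size=(100, 1)).astype(str)  # one categorical feature
y = rng.standard_normal(100)

enc = TargetEncoder(random_state=0)
X_cv = enc.fit_transform(X, y)       # encodings learned with internal CV
X_full = enc.fit(X, y).transform(X)  # encodings learned on all training data
print(np.allclose(X_cv, X_full))     # False: fit_transform() != fit().transform()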

Any other comments?

The implementation uses the same scheme as cuML's TargetEncoder, which they used to win RecSys 2020.

@thomasjpfan thomasjpfan changed the title ENH Adds Target Regression Encoder ENH Adds Target Regressor Encoder Jan 8, 2023
@glemaitre
Member

Do we want to have Regressor in the naming since we could make the encoder work on both regression and classification problems?

@betatim
Member

betatim commented Jan 9, 2023

Do we want to have Regressor in the naming since we could make the encoder work on both regression and classification problems?

To have an encoder that works for both classification and regression I think we'd have to detect the type of problem based on the values of y. How good is the automatic detection of "regression vs classification"? In addition it feels like the code for classification, in particular multi-class classification, would be more complicated than for the regression case (you can't use np.mean(y)).


I think the way this works is that when performing the change from "categories" to "encoded categories", _BaseEncoder._transform returns a version of X where a column representing a categorical feature is replaced by a column of integers. This is then translated to the actual encoded value in _transform_X_int. On top, we have the complexity of fit().transform() and fit_transform() not doing exactly the same thing. It took a lot of going back and forth and working out what _BaseEncoder does to get to that understanding.
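A toy analogue of that two-step mapping (not the PR's actual helpers, just the idea):

import numpy as np

categories = np.array(["a", "b", "c"])      # sorted categories of one feature
encodings = np.array([0.2, 0.5, 0.9])       # learned target mean per category
X_col = np.array(["b", "a", "c", "b"])
X_int = np.searchsorted(categories, X_col)  # step 1: categories -> integer codes
print(encodings[X_int])                     # step 2: integer codes -> encoded values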

After reading the code and example once my big picture thoughts are:

  • nice!
  • this looks like it should work,
  • can we have more doc strings and comments that explain why something is the way it is
  • both doc string and naming stuff should be done after code structure, etc
  • right now sample weights aren't supported, would it be easy to add that?

What is your plan? Is it ready to go or does it need more tweaking?

@glemaitre
Member

How good is the automatic detection of "regression vs classification"?

type_of_target is quite good and would do the job.

In addition it feels like the code for classification, in particular multi-class classification, would be more complicated than for the regression case (you can't use np.mean(y)).

dirty_cat implements both but this is not the same branch of code: https://github.com/dirty-cat/dirty_cat/blob/main/dirty_cat/_target_encoder.py

@betatim
Member

betatim commented Jan 9, 2023

What can the TargetEncoder from dirty_cat not do/what is the advantage of making yet another implementation over taking it verbatim? For example, judging from its constructor args, it seems it supports regression, binary and multi-class classification. It seems like an excellent starting point.

@glemaitre
Member

From the dev meeting, we thought that we could put multiclass aside for the moment and support binary classification and regression. We need to make sure that the name of the encoder reflects that, but we don't have to support all possible classification and regression problems at first.

I will make a review having those points in mind.

@glemaitre
Member

What can the TargetEncoder from dirty cat not do/what is the advantage of making yet another implementation over taking it verbatim?

Since OneHotEncoder and OrdinalEncoder work for all problems, it could come as a surprise to our users to have to select the right encoder type. There is also the precedent of categorical_encoder, cuML, and dirty_cat (the ones that I am aware of), which expose a single encoder for all usages.

@betatim
Member

betatim commented Jan 9, 2023

Since OneHotEncoder and OrdinalEncoder work for all problems, it could come as a surprise to our users to have to select the right encoder type.

I agree. Having one encoder for all types of problems is nicer than having to choose.

My question was "Why not take the dirty_cat implementation of TargetEncoder (as a starting point) for the implementation in scikit-learn?"

@thomasjpfan
Member Author

thomasjpfan commented Jan 9, 2023

Here are the key differences between this PR and dirty_cat's version:

  1. This PR takes advantage of the existing code in _BaseEncoder, which greatly simplifies the implementation.
  2. This PR does cross validation in fit_transform while dirty cat does not. The cross validation is required to prevent the target from leaking during training.
  3. This PR's implementation is much faster because of Cython, even with cross validation in fit_transform:
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import TargetRegressorEncoder
from dirty_cat import TargetEncoder as DirtyCatTargetEncoder
import numpy as np

rng = np.random.default_rng()
n_samples, n_features = 500_000, 20
X = rng.integers(0, high=30, size=(n_samples, n_features))
y = rng.standard_normal(size=n_samples)

%%timeit
_ = TargetRegressorEncoder().fit_transform(X, y)
# 3.37 s ± 132 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
_ = DirtyCatTargetEncoder().fit_transform(X, y)
# 9.9 s ± 220 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

This PR is faster even when its fit_transform learns 6 encodings: five for cv=5 and one more encoding over the whole dataset. dirty_cat only learns one encoding for the whole dataset because it does not do cross validation.
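For intuition, a rough pure-NumPy sketch of the cross fitting scheme (omitting the shrinkage toward the global mean that the actual encoder applies, and much slower than the Cython version):

import numpy as np
from sklearn.model_selection import KFold

def target_mean_encoding(X_col, y):
    # per-category mean of y: the core quantity the encoder learns
    cats, inv = np.unique(X_col, return_inverse=True)
    sums = np.bincount(inv, weights=y, minlength=len(cats))
    counts = np.bincount(inv, minlength=len(cats))
    return cats, sums / counts

def cross_fit_encode(X_col, y, cv=5):
    # learn the encoding on each train split, apply it to the held-out split
    out = np.empty(len(y), dtype=float)
    for train, test in KFold(n_splits=cv).split(X_col):
        cats, enc = target_mean_encoding(X_col[train], y[train])
        mapping = dict(zip(cats, enc))
        default = y[train].mean()  # fallback for categories unseen in this fold
        out[test] = [mapping.get(c, default) for c in X_col[test]]
    return out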

@thomasjpfan
Member Author

thomasjpfan commented Jan 10, 2023

What is your plan? Is it ready to go or does it need more tweaking?

It's mostly a matter of deciding how we want to extend the API for classification targets. Currently, this PR is the minimum requirement for regression targets. The core computation in this PR can be extended to classification without too much trouble.

type_of_target is quite good and would do the job.

I do not like how type_of_target treats floats as classification targets. For example:

import numpy as np
from sklearn.utils.multiclass import type_of_target

type_of_target(np.asarray([1.0] * 10 + [2.0] * 30 + [4.0] * 10 + [5.0]))
# 'multiclass'

I prefer two more explicit options:

  1. TargetRegressorEncoder and TargetClassificationEncoder
  2. TargetEncoder with a target_type parameter to switch between regression and classification.

I went with option 1 in this PR, but I am okay with either option. For option 2, I am +0.5 on having a target_type="auto" option that will infer the type of the target.

@thomasjpfan thomasjpfan changed the title ENH Adds Target Regressor Encoder ENH Adds TargetEncoder Jan 10, 2023
@thomasjpfan
Member Author

thomasjpfan commented Jan 10, 2023

After thinking about it a little more, I am okay with just inferring the target type with type_of_target. I updated this PR:

  1. Inferred the target type with type_of_target.
  2. Added a target_type="auto" option to allow the user to control the inference.
  3. Added support for both binary classification and regression.
  4. Renamed the encoder to TargetEncoder.

Member

@ogrisel ogrisel left a comment


Looks good to me. Based on test_target_encoding_for_linear_regression I think we need to do the nested cross-val by default and break the usual fit_transform <=> fit + transform implicit equivalence of other scikit-learn estimators. In particular, I don't see how to compute the "real" training accuracy: to do so we would need a fit_score method on pipelines (which could be a good idea by the way to save some redundant computation, but this is a digression).

Anyways, I don't see any other way around it, and to me the protection against catastrophic overfitting caused by noisy high-cardinality categorical features outweighs the potentially surprising (but well documented) behavior of fit_transform.

@thomasjpfan
Member Author

I am still not decided whether this should better be a pitfall-style example or a test.

I think both are useful. I have not seen an example similar to your test case that demonstrates why the internal validation is useful.

In a follow-up PR, we can convert the test into a pitfall-style example and link it in the docstring for cv. (I'm still unsure about adding cv=None)

Member

@ogrisel ogrisel left a comment


Spotted another typo in the inline comment of the new test.

sklearn/preprocessing/tests/test_target_encoder.py
@ogrisel
Member

ogrisel commented Mar 14, 2023

@lorentzenchr @betatim @glemaitre any more feedback?

@jovan-stojanovic you might be interested in the new test: I checked that dirty_cat's TargetEncoder can also lead to this kind of overfitting problem because of the lack of built-in CV during fit_transform.

@ogrisel
Member

ogrisel commented Mar 15, 2023

@thomasjpfan I had a devil-inspired idea at coffee: we could store a weakref to the training set at fit time to detect if the X passed to .transform is the same as the one passed to .fit/.fit_transform. That would make it possible to hide the fit_transform / fit + transform discrepancy like cuML does, but without holding a strong ref to X_train, which would be problematic: it would prevent early collection of otherwise GC-able data and would make model serialization inefficient without additional care.
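A hypothetical sketch of the idea, just to make it concrete (not code from this PR):

import weakref
import numpy as np

class WeakrefDetectSketch:
    def fit(self, X):
        # weak reference only: X stays garbage-collectable after fit
        self._X_train_ref = weakref.ref(X)
        return self

    def sees_training_data(self, X):
        return self._X_train_ref() is X

X_train = np.random.randn(10, 3)
det = WeakrefDetectSketch().fit(X_train)
print(det.sees_training_data(X_train))         # True: the exact same object
print(det.sees_training_data(X_train.copy()))  # False: same values, new object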

Still, the weakref hack could lead to surprising behaviors. For instance, while the following would work:

X_train = load_dataset_from_disk()
X_trans_1 = target_transformer.fit_transform(X_train)
X_trans_2 = target_transformer.transform(X_train)
np.testing.assert_allclose(X_trans_1, X_trans_2)

this seemingly innocuous variation would fail:

X_train = load_dataset_from_disk()
X_trans_1 = target_transformer.fit_transform(X_train)

X_train = load_dataset_from_disk()
X_trans_2 = target_transformer.transform(X_train)
np.testing.assert_allclose(X_trans_1, X_trans_2)

so overall, I am not 100% sure whether the weakref hack would be a usability improvement or not.

Feel free to pretend that you haven't read this comment and not reply. I would perfectly understand.

@ogrisel
Member

ogrisel commented Mar 15, 2023

Another pitfall I discovered when experimenting with this PR:

If you have a mix of informative and non-informative categorical features (e.g. f_i and f_u), and if you apply StandardScaler() after TargetEncoder(), then you can run into big trouble: the variance of target_encoder.fit_transform(X_train[[f_u]]) can be very small but non-zero, so StandardScaler can rescale this massively, which is probably a bad idea for a downstream distance-based model (e.g. a kernel machine, nearest neighbors and so on).

However, if you use the raw target encoded values of f_i and f_u in conjunction with true numerical features that have all been scaled (via a dedicated entry in a ColumnTransformer), and if y has a much larger standard deviation, then f_i will have a very large standard deviation compared to the independently scaled numerical features, which is problematic.
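A tiny synthetic illustration of the first point:

import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
# f_u is uninformative, so its target encoding is nearly constant around y's mean
X_enc = 0.5 + 1e-6 * rng.standard_normal((1000, 1))
X_scaled = StandardScaler().fit_transform(X_enc)
print(X_enc.std(), X_scaled.std())  # ~1e-6 -> 1.0: pure noise amplified ~1e6x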

I see two possible solutions:

  • a) add an option to TargetEncoder to internally use a scaled version of y_train to compute the encodings;
  • b) make it possible for StandardScaler or other scalers to use shared mean and scale statistics for a group of columns:
preprocessor = ColumnTransformer(
    [
        (
            "categorical",
            make_pipeline(TargetEncoder(), StandardScaler(shared_mean=True, shared_scale=True)),
            ["f_i", "f_u"],
        ),
    ],
    remainder=StandardScaler(),
)

Option b) is probably useful beyond post-processing target encoded categorical values. For instance @TomDLT and I encountered a similar problem a long time ago when running make_pipeline(StandardScaler(), LogisticRegression()) with the saga solver on MNIST: some pixel values are nearly constant 0 or 1, and per-feature scaling catastrophically hurts the condition number of X.T @ X if I recall correctly. At the time we just decided to use MinMaxScaler instead and move along. But this is a real problem when groups of features should be re-scaled together to preserve meaning (e.g. pixel intensities, EEG channels, mel features for audio signals and so on).

EDIT: we should probably do StandardScaler(shared_mean_and_scale=True) instead because decoupling the two might be challenging from a code maintenance point of view, and not that useful to the end users.

Even if we decide that a) would also be a convenient UX improvement, we can delay this to a follow-up PR.

I just wanted to brain-dump this here so that we can think about it when we work on a pitfall example for TargetEncoder-based feature engineering.

/cc @jovan-stojanovic who might also be interested for dirty_cat.

@thomasjpfan
Member Author

so overall, I am not 100% sure whether the weakref hack would be a usability improvement or not.

I think it borders on being too magical. For example, if the data is sliced the same way or copied, the references are not the same:

import numpy as np
import weakref

X = np.random.randn(10, 10)

X1 = X[:4]
X2 = X[:4]
X3 = X1.copy()

X1_ref = weakref.ref(X1)

assert X1 is X1_ref()
assert X2 is not X1_ref()
assert X3 is not X1_ref()

For reference, cuML's TargetEncoder holds the training data and checks all the values.
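Roughly, that hold-and-compare approach amounts to the following (an illustration, not cuML's actual code):

import numpy as np

class ValueCheckSketch:
    def fit(self, X):
        # strong ref: costs memory and makes pickled models larger
        self._X_train = np.array(X, copy=True)
        return self

    def sees_training_data(self, X):
        X = np.asarray(X)
        return X.shape == self._X_train.shape and np.array_equal(X, self._X_train)

X = np.random.randn(5, 2)
print(ValueCheckSketch().fit(X).sees_training_data(X.copy()))  # True: values match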

a) add an option to TargetEncoder to internally use a scaled version of y_train to compute the encodings;

At one point, I had something similar implemented in #17323 as the default. I think it's reasonable to use a scaled version of the target for encoding purposes.

@TomDLT
Member

TomDLT commented Mar 15, 2023

For instance @TomDLT and I encountered a similar problem a long time ago when running make_pipeline(StandardScaler(), LogisticRegression()) with the saga solver on MNIST: some pixel values are nearly constant 0 or 1 and having per-feature scaling is catastrophic.

Yes, some MNIST pixels reach a value of 250 after rescaling (see the figure and code below), which is quite out of distribution for a standard normal distribution. We should add a warning when some outputs of StandardScaler are large (say above 100, or even 10).

This is a real problem when groups of features should be re-scaled together to preserve meaning (e.g. pixel intensities, EEG channels, mel features for audio signals and so on).

Yes, this is a big limitation of StandardScaler. +1 for adding StandardScaler(shared_mean=True, shared_scale=True).

[Figure: maximum value of each scaled MNIST feature, produced by the code below]

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

from sklearn.datasets import fetch_openml
from sklearn.preprocessing import StandardScaler

mnist = fetch_openml("mnist_784", as_frame=False, parser="pandas")
X, y = mnist.data, mnist.target

X_scaled = StandardScaler().fit_transform(X)
max_values = X_scaled.max(axis=0)

fig, ax = plt.subplots()
image = ax.imshow(max_values.reshape(28, 28), cmap=plt.get_cmap("viridis", 6),
                  norm=LogNorm())
ax.set(xticks=[], yticks=[], title="Maximum value of each scaled MNIST feature")
fig.colorbar(image)
plt.show()

@ogrisel
Member

ogrisel commented Mar 16, 2023

a) add an option to TargetEncoder to internally use a scaled version of y_train to compute the encodings;

At one point, I had something similar implemented in #17323 as the default. I think it's reasonable to use a scaled version of the target for encoding purposes.

Let's keep that in mind for a follow-up PR. But if we want to make it the default (which would probably be helpful), we should probably do that before the 1.3 release.

@ogrisel
Member

ogrisel commented Mar 16, 2023

For reference, cuML's TargetEncoder holds the training data and checks all the values.

Note that we could use a weakref + a concrete value check. But even that would feel too complex/magical. +0.5 for keeping the code as it is.

@glemaitre glemaitre self-requested a review March 16, 2023 12:53
Member

@glemaitre glemaitre left a comment


I have only one nitpick.

examples/preprocessing/plot_target_encoder.py
@ogrisel ogrisel merged commit 392fdee into scikit-learn:main Mar 16, 2023
@ogrisel
Member

ogrisel commented Mar 16, 2023

Merged! Thank you very much @thomasjpfan!

@lorentzenchr
Member

Just an idea: For detection of the training set during transform, we could store mean, variance and maybe 2 quantiles. This should be pretty unique.
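Sketching that fingerprint idea on numeric data (the next comments explain why raw categorical input makes this harder):

import numpy as np

def fingerprint(X):
    # hypothetical per-feature summary statistics stored at fit time
    return np.vstack([X.mean(axis=0), X.var(axis=0),
                      np.quantile(X, 0.25, axis=0), np.quantile(X, 0.75, axis=0)])

rng = np.random.default_rng(0)
X_train = rng.standard_normal((1000, 3))
fp = fingerprint(X_train)
print(np.allclose(fp, fingerprint(X_train)))        # True: same data
print(np.allclose(fp, fingerprint(X_train[:500])))  # almost surely False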

@ogrisel
Member

ogrisel commented Mar 17, 2023

The checks would need to happen before encoding, on the categorical variables. We could store the per-feature category counts instead. Maybe with a few random probe records that contain several features with infrequent categories.

@ogrisel
Member

ogrisel commented Mar 17, 2023

The checks would need to happen before encoding, on the categorical variables. We could store the per-feature category counts instead. Maybe with a few random probe records.

But this would be quite catastrophic in case of false positives.

@lorentzenchr
Member

Now I see the difficulty. Maybe it is good enough as is. In principle, we would need to detect every single row of the training set, and that's the responsibility of the user, isn't it?

@betatim
Member

betatim commented Mar 17, 2023

Whoop whoop! Nice work!

@BenjaminBossan
Contributor

BenjaminBossan commented Mar 20, 2023

Pretty nice addition, thanks for this.

A small question: according to the tags, TargetEncoder does not require y; should this be changed?

>>> from sklearn.preprocessing import TargetEncoder
>>> from sklearn.utils._tags import _safe_tags
>>> _safe_tags(TargetEncoder())['requires_y']
False

Veghit pushed a commit to Veghit/scikit-learn that referenced this pull request Apr 15, 2023
MohitBurkule added a commit to MohitBurkule/scikit-learn that referenced this pull request May 7, 2023

Successfully merging this pull request may close these issues.

CountFeaturizer for categorical data