ENH Adds TargetEncoder #25334

Merged Mar 16, 2023 (91 commits)

Commits
f29c2d8
ENH Adds Target Regression Encoder
thomasjpfan Jan 8, 2023
336b17d
DOC Adds pr number
thomasjpfan Jan 8, 2023
c52ce42
DOC Adds pr number
thomasjpfan Jan 8, 2023
1863be9
DOC Adds example
thomasjpfan Jan 9, 2023
22af6a6
DOC Fixes example link
thomasjpfan Jan 9, 2023
f535f54
FIX Use fancy indexing
thomasjpfan Jan 9, 2023
c6d1fc4
FIX Fix issue with 32bit
thomasjpfan Jan 9, 2023
2868b3c
ENH Adds support for binary classification
thomasjpfan Jan 10, 2023
aa6e545
TST Adds test to check target_type_
thomasjpfan Jan 10, 2023
871ea45
DOC Update whats new with better wording
thomasjpfan Jan 10, 2023
71e6bad
DOC Update docs about target encoder itself
thomasjpfan Jan 10, 2023
fdd39d0
TST Update names of tests
thomasjpfan Jan 10, 2023
c8d5546
DOC better names for variables
thomasjpfan Jan 10, 2023
6a5771c
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Jan 10, 2023
6dbcb85
ENH Adds target_type parameter
thomasjpfan Jan 10, 2023
00d5230
TST Fixes test failures
thomasjpfan Jan 10, 2023
e75a51e
DOC Remove mention on RMSE
thomasjpfan Jan 10, 2023
2978e23
DOC Adds concluding statement in example
thomasjpfan Jan 10, 2023
429085d
CLN Refactor names
thomasjpfan Jan 10, 2023
93ea268
DOC Improves user guide
thomasjpfan Jan 13, 2023
cfe7afa
CLN Cleaner implementation
thomasjpfan Jan 13, 2023
3967934
ENH Adds auto for smoothing
thomasjpfan Jan 13, 2023
7875c2e
CLN Address comments
thomasjpfan Jan 13, 2023
5ff5b61
DOC Update example title
thomasjpfan Jan 13, 2023
2568e89
DOC Adds comment to point to equation
thomasjpfan Jan 13, 2023
afc4223
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Jan 13, 2023
6e4bf50
Apply suggestions from code review
thomasjpfan Jan 13, 2023
d613c61
DOC Only say classification
thomasjpfan Jan 13, 2023
ba85edb
DOC Update examples and docstring
thomasjpfan Jan 13, 2023
e3c8473
DOC Improve example
thomasjpfan Jan 14, 2023
0d6c4a2
DOC Clarify fit_transform vs fit.transform
thomasjpfan Jan 14, 2023
d9cd410
CLN Address comments
thomasjpfan Jan 14, 2023
ee1e26b
DOC Adds more code comments
thomasjpfan Jan 14, 2023
a399791
CLN Code formatting
thomasjpfan Jan 14, 2023
fd2b963
CLN Move test closer to comment
thomasjpfan Jan 14, 2023
cd6c650
DOC Adds more details about equations
thomasjpfan Jan 14, 2023
991fd9f
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Jan 16, 2023
962125d
CLN Address comments
thomasjpfan Jan 18, 2023
02d8346
TST Simplify test logic
thomasjpfan Jan 20, 2023
acb4b2e
CLN Address comments
thomasjpfan Jan 20, 2023
3261d10
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Jan 20, 2023
a57b117
DOC Adds note about type_of_target inference
thomasjpfan Jan 20, 2023
5a330be
ENH Use integers for counts
thomasjpfan Jan 23, 2023
f1a14d7
STY Formatting
thomasjpfan Jan 23, 2023
66fb8bd
STY Fix numpydoc linting error
thomasjpfan Jan 23, 2023
f5df8b8
DOC Adds example about low and high smoothing parameters
thomasjpfan Jan 27, 2023
906ac98
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Jan 30, 2023
43630cf
Apply suggestions from code review
thomasjpfan Jan 30, 2023
177128d
STY Black formatting
thomasjpfan Jan 31, 2023
e37b6ec
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Feb 2, 2023
86d357e
DOC Adds mixed encoder for high and low cardinality
thomasjpfan Feb 2, 2023
a7cc80f
DOC Address comments
thomasjpfan Feb 2, 2023
0bcad5b
ENH Make sure the colors align between graphs
thomasjpfan Feb 2, 2023
61081db
Apply suggestions from code review
thomasjpfan Feb 2, 2023
2e4bd85
CLN Reformat section headers and move sections around
thomasjpfan Feb 2, 2023
07a225c
CLN Try to work around sphinx 4.0.1
thomasjpfan Feb 2, 2023
7cd0c8d
CLN Address comments
thomasjpfan Feb 6, 2023
94fd097
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Feb 7, 2023
56dd2fa
CLN Better variable names
thomasjpfan Feb 10, 2023
317e13a
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Feb 10, 2023
c4eacfa
ENH Restrict cv to integers
thomasjpfan Feb 10, 2023
0ab474b
DOC Improve docstring
thomasjpfan Feb 10, 2023
c09556c
CLN Better variable names in test
thomasjpfan Feb 10, 2023
b7654bb
Apply suggestions from code review
thomasjpfan Feb 11, 2023
387f42c
DOC Simplify notes
thomasjpfan Feb 11, 2023
efe630c
DOC Adds diagram about cross validation
thomasjpfan Feb 12, 2023
803a988
DOC Explain fit vs fit_transform
thomasjpfan Feb 12, 2023
f5f7ab1
DOC Clarify fit
thomasjpfan Feb 12, 2023
0e784b1
Apply suggestions from code review
thomasjpfan Mar 7, 2023
444f589
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Mar 7, 2023
4994bc4
FIX Fixes compile issue
thomasjpfan Mar 7, 2023
8fb1a11
CLN Removes pow
thomasjpfan Mar 7, 2023
850694e
CLN Use nogil and pointers
thomasjpfan Mar 7, 2023
665e2e9
MNT Use vector instead
thomasjpfan Mar 7, 2023
5247123
DOC Adds note about missing values
thomasjpfan Mar 7, 2023
70f99fe
API Adds shuffle and random_state
thomasjpfan Mar 7, 2023
e09660c
STY Slight reformatting
thomasjpfan Mar 7, 2023
690fcdc
Check random seeds [all random seeds]
thomasjpfan Mar 8, 2023
a327b52
API Change attribute name to target_mean_
thomasjpfan Mar 8, 2023
6b27afd
Apply suggestions from code review
thomasjpfan Mar 8, 2023
9773765
STY Linting issue
thomasjpfan Mar 8, 2023
8e18746
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Mar 8, 2023
1ea0407
FIX Fixes edge case with smooth=0.0 and unknown categories in cv
thomasjpfan Mar 10, 2023
df5574d
Apply suggestions from code review
thomasjpfan Mar 10, 2023
d681e06
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Mar 10, 2023
0b458c3
TST Improves testing maintainability
thomasjpfan Mar 13, 2023
26d2429
TST add statistical-integration tests to check the benefit of interna…
ogrisel Mar 14, 2023
d77e519
Fix comment in test to more accurately describe what's happening
ogrisel Mar 14, 2023
05997bb
TST Use RNG for permutation
thomasjpfan Mar 14, 2023
6377419
More fixes in inline comments for the new test
ogrisel Mar 14, 2023
180ee0e
Use pandas to parse the fetched example dataset.
ogrisel Mar 16, 2023
3 changes: 3 additions & 0 deletions doc/images/target_encoder_cross_validation.svg
(SVG diagram; not rendered in the diff view)
1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -1438,6 +1438,7 @@ details.
preprocessing.RobustScaler
preprocessing.SplineTransformer
preprocessing.StandardScaler
preprocessing.TargetEncoder

.. autosummary::
:toctree: generated/
86 changes: 86 additions & 0 deletions doc/modules/preprocessing.rst
@@ -830,6 +830,92 @@ lexicon order.
>>> enc.infrequent_categories_
[array(['b', 'c'], dtype=object)]

.. _target_encoder:

Target Encoder
--------------

.. currentmodule:: sklearn.preprocessing

The :class:`TargetEncoder` uses the target mean conditioned on the categorical
feature for encoding unordered categories, i.e. nominal categories [PAR]_
[MIC]_. This encoding scheme is useful with categorical features with high
cardinality, where one-hot encoding would inflate the feature space, making it
more expensive for a downstream model to process. Classical examples of high
cardinality categories are location-based features such as zip code or region.
For a binary classification target, the target encoding is given by:

.. math::
S_i = \lambda_i\frac{n_{iY}}{n_i} + (1 - \lambda_i)\frac{n_Y}{n}

where :math:`S_i` is the encoding for category :math:`i`, :math:`n_{iY}` is the
number of observations with :math:`Y=1` and category :math:`i`, :math:`n_i` is
the number of observations with category :math:`i`, :math:`n_Y` is the number of
observations with :math:`Y=1`, :math:`n` is the number of observations, and
:math:`\lambda_i` is a shrinkage factor. The shrinkage factor is given by:

.. math::
\lambda_i = \frac{n_i}{m + n_i}

where :math:`m` is a smoothing factor, which is controlled with the `smooth`
parameter in :class:`TargetEncoder`. Large smoothing factors will put more
weight on the global mean. When `smooth="auto"`, the smoothing factor is
computed as an empirical Bayes estimate: :math:`m=\sigma_i^2/\tau^2`, where
:math:`\sigma_i^2` is the variance of `y` within category :math:`i` and
:math:`\tau^2` is the global variance of `y`.
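
As a concrete illustration, the following minimal sketch evaluates the two
equations above by hand on a toy binary target with a fixed smoothing factor
`m` (this is a hand computation for illustration, not a use of
:class:`TargetEncoder` itself)::

    import numpy as np

    y = np.array([1, 0, 1, 1, 0, 0, 1, 0])                 # binary target
    X = np.array(["a", "a", "a", "b", "b", "b", "b", "b"])  # one categorical feature

    m = 2.0                 # fixed smoothing factor (the `smooth` parameter)
    global_mean = y.mean()  # n_Y / n = 0.5

    for category in ("a", "b"):
        mask = X == category
        n_i = mask.sum()           # number of observations with category i
        cat_mean = y[mask].mean()  # n_iY / n_i
        lam = n_i / (m + n_i)      # shrinkage factor lambda_i
        print(category, lam * cat_mean + (1 - lam) * global_mean)
    # prints approximately: a 0.6 and b 0.4286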

For continuous targets, the formulation is similar to binary classification:

.. math::
S_i = \lambda_i\frac{\sum_{k\in L_i}y_k}{n_i} + (1 - \lambda_i)\frac{\sum_{k=1}^{n}y_k}{n}

where :math:`L_i` is the set of observations with category :math:`i` and
:math:`n_i` is the cardinality of :math:`L_i`.

:meth:`~TargetEncoder.fit_transform` internally relies on a cross validation
scheme to prevent information from the target from leaking into the train-time
representation, especially for non-informative high-cardinality categorical
variables, and to help prevent the downstream model from overfitting spurious
correlations. Note that, as a result, `fit(X, y).transform(X)` does not equal
`fit_transform(X, y)`. In :meth:`~TargetEncoder.fit_transform`, the training
data is split into multiple folds and each fold is encoded using the encodings
learned from the other folds.
After cross validation is complete in :meth:`~TargetEncoder.fit_transform`, the
target encoder learns one final encoding on the whole training set. This final
encoding is used to encode categories in :meth:`~TargetEncoder.transform`. The
following diagram shows the cross validation scheme in
:meth:`~TargetEncoder.fit_transform` with the default `cv=5`:

.. image:: ../images/target_encoder_cross_validation.svg
:width: 600
:align: center
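
The following simplified sketch reimplements the idea behind the cross fitting
scheme for a single feature, using plain per-category means and ignoring the
shrinkage term; `cross_fit_encode` is a hypothetical helper written for this
illustration, not part of scikit-learn::

    import numpy as np
    from sklearn.model_selection import KFold

    def cross_fit_encode(X_col, y, n_splits=5):
        """Encode each fold with per-category means learned on the other folds."""
        encoded = np.empty(len(y), dtype=float)
        cv = KFold(n_splits=n_splits, shuffle=True, random_state=0)
        for train_idx, test_idx in cv.split(X_col):
            # learn per-category target means on the training folds only
            means = {
                category: y[train_idx][X_col[train_idx] == category].mean()
                for category in np.unique(X_col[train_idx])
            }
            # encode the held-out fold; categories unseen in the training
            # folds fall back to the training folds' global mean
            fold_mean = y[train_idx].mean()
            encoded[test_idx] = [means.get(c, fold_mean) for c in X_col[test_idx]]
        return encoded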

The :meth:`~TargetEncoder.fit` method does **not** use any cross validation
scheme and learns one encoding on the entire training set, which is used to
encode categories in :meth:`~TargetEncoder.transform`. This encoding is the
same as the final encoding learned in :meth:`~TargetEncoder.fit_transform`.
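
The difference between the two code paths can be observed directly, as in the
following minimal sketch on toy data (the exact encoded values depend on the
cross-validation splits)::

    import numpy as np
    from sklearn.preprocessing import TargetEncoder

    X = np.array([["a"] * 5 + ["b"] * 5]).T  # shape (10, 1)
    y = np.array([1, 0, 1, 1, 0, 0, 1, 0, 0, 1])

    enc = TargetEncoder(target_type="binary", random_state=0)
    X_cv = enc.fit_transform(X, y)  # rows encoded with out-of-fold statistics
    X_full = enc.transform(X)       # rows encoded with the full-data encoding
    # X_cv and X_full generally differ row by row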

.. note::
:class:`TargetEncoder` considers missing values, such as `np.nan` or `None`,
as another category and encodes them like any other category. Categories
that are not seen during `fit` are encoded with the target mean, i.e.
`target_mean_`.
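
For instance, in the following minimal sketch on toy data, a category unseen
during `fit` is encoded with the global target mean::

    import numpy as np
    from sklearn.preprocessing import TargetEncoder

    X = np.array([["dog", "dog", "cat", "cat"]]).T
    y = np.array([90.0, 80.0, 70.0, 60.0])

    enc = TargetEncoder(target_type="continuous").fit(X, y)
    enc.transform(np.array([["snake"]]))  # equals enc.target_mean_, i.e. 75.0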

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_preprocessing_plot_target_encoder.py`

.. topic:: References

.. [MIC] :doi:`Micci-Barreca, Daniele. "A preprocessing scheme for high-cardinality
categorical attributes in classification and prediction problems"
SIGKDD Explor. Newsl. 3, 1 (July 2001), 27–32. <10.1145/507533.507538>`

.. [PAR] :doi:`Pargent, F., Pfisterer, F., Thomas, J. et al. "Regularized target
encoding outperforms traditional methods in supervised machine learning with
high cardinality features" Comput Stat 37, 2671–2692 (2022)
<10.1007/s00180-022-01207-6>`

.. _preprocessing_discretization:

Discretization
4 changes: 4 additions & 0 deletions doc/whats_new/v1.3.rst
@@ -317,6 +317,10 @@ Changelog
:mod:`sklearn.preprocessing`
............................

- |MajorFeature| Introduces :class:`preprocessing.TargetEncoder`, a
categorical encoder based on the target mean conditioned on the value of the
category. :pr:`25334` by `Thomas Fan`_.

- |Enhancement| Adds a `feature_name_combiner` parameter to
:class:`preprocessing.OneHotEncoder`. This specifies a custom callable to create
feature names to be returned by :meth:`get_feature_names_out`.
227 changes: 227 additions & 0 deletions examples/preprocessing/plot_target_encoder.py
@@ -0,0 +1,227 @@
"""
============================================
Comparing Target Encoder with Other Encoders
============================================

.. currentmodule:: sklearn.preprocessing

The :class:`TargetEncoder` uses the value of the target to encode each
categorical feature. In this example, we will compare four different approaches
for handling categorical features: :class:`TargetEncoder`,
:class:`OrdinalEncoder`, :class:`OneHotEncoder`, and dropping the category.

.. note::
`fit(X, y).transform(X)` does not equal `fit_transform(X, y)` because a
cross-validation scheme is used in `fit_transform` for encoding. See the
:ref:`User Guide <target_encoder>` for details.
"""

# %%
# Loading Data from OpenML
# ========================
# First, we load the wine reviews dataset, where the target is the points given
# by a reviewer:
from sklearn.datasets import fetch_openml

wine_reviews = fetch_openml(data_id=42074, as_frame=True, parser="liac-arff")

df = wine_reviews.frame
df.head()

# %%
# For this example, we use the following subset of numerical and categorical
# features in the data. The target is a continuous value ranging from 80 to 100:
numerical_features = ["price"]
categorical_features = [
"country",
"province",
"region_1",
"region_2",
"variety",
"winery",
]
target_name = "points"

X = df[numerical_features + categorical_features]
y = df[target_name]

_ = y.hist()

# %%
# Training and Evaluating Pipelines with Different Encoders
# =========================================================
# In this section, we will evaluate pipelines with
# :class:`~sklearn.ensemble.HistGradientBoostingRegressor` with different encoding
# strategies. First, we list out the encoders we will be using to preprocess
# the categorical features:
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import TargetEncoder

categorical_preprocessors = [
("drop", "drop"),
("ordinal", OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1)),
(
"one_hot",
OneHotEncoder(handle_unknown="ignore", max_categories=20, sparse_output=False),
),
("target", TargetEncoder(target_type="continuous")),
]

# %%
# Next, we evaluate the models using cross validation and record the results:
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import cross_validate
from sklearn.ensemble import HistGradientBoostingRegressor

n_cv_folds = 3
max_iter = 20
results = []


def evaluate_model_and_store(name, pipe):
result = cross_validate(
pipe,
X,
y,
scoring="neg_root_mean_squared_error",
cv=n_cv_folds,
return_train_score=True,
)
rmse_test_score = -result["test_score"]
rmse_train_score = -result["train_score"]
results.append(
{
"preprocessor": name,
"rmse_test_mean": rmse_test_score.mean(),
"rmse_test_std": rmse_train_score.std(),
"rmse_train_mean": rmse_train_score.mean(),
"rmse_train_std": rmse_train_score.std(),
}
)


for name, categorical_preprocessor in categorical_preprocessors:
preprocessor = ColumnTransformer(
[
("numerical", "passthrough", numerical_features),
("categorical", categorical_preprocessor, categorical_features),
]
)
pipe = make_pipeline(
preprocessor, HistGradientBoostingRegressor(random_state=0, max_iter=max_iter)
)
evaluate_model_and_store(name, pipe)


# %%
# Native Categorical Feature Support
# ==================================
# In this section, we build and evaluate a pipeline that uses native categorical
# feature support in :class:`~sklearn.ensemble.HistGradientBoostingRegressor`,
# which only supports up to 255 unique categories. In our dataset, most of
# the categorical features have more than 255 unique categories:
n_unique_categories = df[categorical_features].nunique().sort_values(ascending=False)
n_unique_categories

# %%
# To work around the limitation above, we group the categorical features into
# low-cardinality and high-cardinality features. The high-cardinality features
# will be target encoded and the low-cardinality features will use the native
# categorical feature support in gradient boosting.
high_cardinality_features = n_unique_categories[n_unique_categories > 255].index
low_cardinality_features = n_unique_categories[n_unique_categories <= 255].index
mixed_encoded_preprocessor = ColumnTransformer(
[
("numerical", "passthrough", numerical_features),
(
"high_cardinality",
TargetEncoder(target_type="continuous"),
high_cardinality_features,
),
(
"low_cardinality",
OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1),
low_cardinality_features,
),
],
verbose_feature_names_out=False,
)

# The output of the preprocessor must be set to pandas so the
# gradient boosting model can detect the low cardinality features.
mixed_encoded_preprocessor.set_output(transform="pandas")
mixed_pipe = make_pipeline(
mixed_encoded_preprocessor,
HistGradientBoostingRegressor(
random_state=0, max_iter=max_iter, categorical_features=low_cardinality_features
),
)
mixed_pipe

# %%
# Finally, we evaluate the pipeline using cross validation and record the results:
evaluate_model_and_store("mixed_target", mixed_pipe)

# %%
# Plotting the Results
# ====================
# In this section, we display the results by plotting the test and train scores:
import matplotlib.pyplot as plt
import pandas as pd

results_df = (
pd.DataFrame(results).set_index("preprocessor").sort_values("rmse_test_mean")
)

fig, (ax1, ax2) = plt.subplots(
1, 2, figsize=(12, 8), sharey=True, constrained_layout=True
)
xticks = range(len(results_df))
name_to_color = dict(
zip((r["preprocessor"] for r in results), ["C0", "C1", "C2", "C3", "C4"])
)

for subset, ax in zip(["test", "train"], [ax1, ax2]):
mean, std = f"rmse_{subset}_mean", f"rmse_{subset}_std"
data = results_df[[mean, std]].sort_values(mean)
ax.bar(
x=xticks,
height=data[mean],
yerr=data[std],
width=0.9,
color=[name_to_color[name] for name in data.index],
)
ax.set(
title=f"RMSE ({subset.title()})",
xlabel="Encoding Scheme",
xticks=xticks,
xticklabels=data.index,
)

# %%
# When evaluating the predictive performance on the test set, dropping the
# categories performs the worst and the target encoders perform the best. This
# can be explained as follows:
#
# - Dropping the categorical features makes the pipeline less expressive,
# causing it to underfit as a result;
# - Due to the high cardinality and to reduce the training time, the one-hot
# encoding scheme uses `max_categories=20`, which prevents the features from
# expanding too much but can result in underfitting;
# - If we had not set `max_categories=20`, the one-hot encoding scheme would
# likely have made the pipeline overfit as the number of features explodes with
# rare category occurrences that are correlated with the target by chance (on
# the training set only);
# - The ordinal encoding imposes an arbitrary order on the features, which are
# then treated as numerical values by the
# :class:`~sklearn.ensemble.HistGradientBoostingRegressor`. Since this
# model groups numerical features into 256 bins per feature, many unrelated
# categories can be grouped together, and as a result the overall pipeline can
# underfit;
# - When using the target encoder, the same binning happens, but since the
# encoded values are statistically ordered by marginal association with the
# target variable, the binning used by the
# :class:`~sklearn.ensemble.HistGradientBoostingRegressor` makes sense and
# leads to good results: the combination of smoothed target encoding and
# binning works as a good regularizing strategy against overfitting while not
# limiting the expressiveness of the pipeline too much.
6 changes: 6 additions & 0 deletions setup.py
@@ -286,6 +286,12 @@ def check_package_status(package, min_version):
],
"preprocessing": [
{"sources": ["_csr_polynomial_expansion.pyx"], "include_np": True},
{
"sources": ["_target_encoder_fast.pyx"],
"include_np": True,
"language": "c++",
"extra_compile_args": ["-std=c++11"],
},
],
"neighbors": [
{"sources": ["_ball_tree.pyx"], "include_np": True},
2 changes: 2 additions & 0 deletions sklearn/preprocessing/__init__.py
@@ -26,6 +26,7 @@

from ._encoders import OneHotEncoder
from ._encoders import OrdinalEncoder
from ._target_encoder import TargetEncoder

from ._label import label_binarize
from ._label import LabelBinarizer
@@ -56,6 +57,7 @@
"RobustScaler",
"SplineTransformer",
"StandardScaler",
"TargetEncoder",
"add_dummy_feature",
"PolynomialFeatures",
"binarize",
8 changes: 7 additions & 1 deletion sklearn/preprocessing/_encoders.py
@@ -414,6 +414,7 @@ class OneHotEncoder(_BaseEncoder):
--------
OrdinalEncoder : Performs an ordinal (integer)
encoding of the categorical features.
TargetEncoder : Encodes categorical features using the target.
sklearn.feature_extraction.DictVectorizer : Performs a one-hot encoding of
dictionary items (also handles string-valued features).
sklearn.feature_extraction.FeatureHasher : Performs an approximate one-hot
@@ -1228,7 +1229,12 @@ class OrdinalEncoder(OneToOneFeatureMixin, _BaseEncoder):

See Also
--------
OneHotEncoder : Performs a one-hot encoding of categorical features.
OneHotEncoder : Performs a one-hot encoding of categorical features. This encoding
is suitable for low to medium cardinality categorical variables, both in
supervised and unsupervised settings.
TargetEncoder : Encodes categorical features using supervised signal
in a classification or regression pipeline. This encoding is typically
suitable for high cardinality categorical variables.
LabelEncoder : Encodes target labels with values between 0 and
``n_classes-1``.
