ENH Adds TargetEncoder (scikit-learn#25334)
Co-authored-by: Andreas Mueller <t3kcit@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Jovan Stojanovic <62058944+jovan-stojanovic@users.noreply.github.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
5 people authored and Itay committed Apr 15, 2023
1 parent 23f9694 commit d13af4c
Showing 11 changed files with 1,394 additions and 1 deletion.
3 changes: 3 additions & 0 deletions doc/images/target_encoder_cross_validation.svg
1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -1438,6 +1438,7 @@ details.
preprocessing.RobustScaler
preprocessing.SplineTransformer
preprocessing.StandardScaler
preprocessing.TargetEncoder

.. autosummary::
:toctree: generated/
86 changes: 86 additions & 0 deletions doc/modules/preprocessing.rst
@@ -830,6 +830,92 @@ lexicon order.
>>> enc.infrequent_categories_
[array(['b', 'c'], dtype=object)]

.. _target_encoder:

Target Encoder
--------------

.. currentmodule:: sklearn.preprocessing

The :class:`TargetEncoder` uses the target mean conditioned on the categorical
feature for encoding unordered categories, i.e. nominal categories [PAR]_
[MIC]_. This encoding scheme is useful with categorical features of high
cardinality, where one-hot encoding would inflate the feature space and make it
more expensive for a downstream model to process. Classical examples of
high-cardinality categories are location-based features such as zip code or
region. For a binary classification target, the target encoding is given by:

.. math::
    S_i = \lambda_i\frac{n_{iY}}{n_i} + (1 - \lambda_i)\frac{n_Y}{n}

where :math:`S_i` is the encoding for category :math:`i`, :math:`n_{iY}` is the
number of observations with :math:`Y=1` and category :math:`i`, :math:`n_i` is
the number of observations with category :math:`i`, :math:`n_Y` is the number of
observations with :math:`Y=1`, :math:`n` is the number of observations, and
:math:`\lambda_i` is a shrinkage factor. The shrinkage factor is given by:

.. math::
    \lambda_i = \frac{n_i}{m + n_i}

where :math:`m` is a smoothing factor, which is controlled with the `smooth`
parameter in :class:`TargetEncoder`. Large smoothing factors will put more
weight on the global mean. When `smooth="auto"`, the smoothing factor is
computed as an empirical Bayes estimate: :math:`m = \sigma_i^2 / \tau^2`, where
:math:`\sigma_i^2` is the variance of `y` within category :math:`i` and
:math:`\tau^2` is the global variance of `y`.
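
As a minimal illustrative sketch of the two formulas above (hand-computed, not
an excerpt from the estimator itself), the encoding for one category of a toy
binary target can be worked out as follows, assuming a fixed smoothing factor
of `m = 2.0`::

  >>> import numpy as np
  >>> y = np.array([1, 1, 0, 1, 0, 0, 0, 0])  # toy binary target
  >>> cat = np.array(list("aaabbbbb"))        # toy categorical feature
  >>> mask = cat == "a"                       # observations in category "a"
  >>> n_i, m = mask.sum(), 2.0                # category count and fixed smoothing
  >>> lam = n_i / (m + n_i)                   # shrinkage factor
  >>> S_a = lam * y[mask].mean() + (1 - lam) * y.mean()
  >>> float(round(S_a, 3))
  0.55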

For continuous targets, the formulation is similar to binary classification:

.. math::
    S_i = \lambda_i\frac{\sum_{k\in L_i}y_k}{n_i} + (1 - \lambda_i)\frac{\sum_{k=1}^{n}y_k}{n}

where :math:`L_i` is the set of observations with category :math:`i` and
:math:`n_i` is the cardinality of :math:`L_i`.
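
Continuing the hand-computed sketch above with an assumed continuous toy target
(and still a fixed smoothing factor of `m = 2.0`), the shrinkage works the same
way::

  >>> import numpy as np
  >>> y = np.array([15.0, 25.0, 20.0, 60.0, 40.0])  # toy continuous target
  >>> cat = np.array(list("aaabb"))
  >>> mask = cat == "a"
  >>> lam = mask.sum() / (2.0 + mask.sum())
  >>> float(round(lam * y[mask].mean() + (1 - lam) * y.mean(), 2))
  24.8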

:meth:`~TargetEncoder.fit_transform` internally relies on a cross-validation
scheme to prevent information from the target from leaking into the train-time
representation, especially for non-informative high-cardinality categorical
variables, and to help prevent the downstream model from overfitting spurious
correlations. Note that, as a result, `fit(X, y).transform(X)` does not equal
`fit_transform(X, y)`. In :meth:`~TargetEncoder.fit_transform`, the training
data is split into multiple folds, and each fold is encoded using the encodings
learned from the other folds. Once cross validation is complete,
:meth:`~TargetEncoder.fit_transform` learns one final encoding on the whole
training set; this final encoding is used to encode categories in
:meth:`~TargetEncoder.transform`. The following diagram shows the
cross-validation scheme used in :meth:`~TargetEncoder.fit_transform` with the
default `cv=5`:

.. image:: ../images/target_encoder_cross_validation.svg
:width: 600
:align: center

The :meth:`~TargetEncoder.fit` method does **not** use any cross-validation
scheme: it learns one encoding on the entire training set, which is used to
encode categories in :meth:`~TargetEncoder.transform`. This encoding is the
same as the final encoding learned in :meth:`~TargetEncoder.fit_transform`.
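
The following is a minimal sketch showing that `fit(X, y).transform(X)` and
`fit_transform(X, y)` indeed differ; the toy data, the three category labels
and the random seed are illustrative assumptions, and the exact encoded values
depend on the random fold assignment::

  >>> import numpy as np
  >>> from sklearn.preprocessing import TargetEncoder
  >>> rng = np.random.default_rng(0)
  >>> X = rng.choice(["dog", "cat", "snake"], size=(200, 1))  # toy feature
  >>> y = rng.integers(0, 2, size=200)                        # toy binary target
  >>> enc = TargetEncoder()
  >>> X_cv = enc.fit_transform(X, y)        # per-fold (cross-fitted) encodings
  >>> X_full = enc.fit(X, y).transform(X)   # single encoding learned on all rows
  >>> bool(np.allclose(X_cv, X_full))
  False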

.. note::
:class:`TargetEncoder` considers missing values, such as `np.nan` or `None`,
as another category and encodes them like any other category. Categories
that are not seen during `fit` are encoded with the target mean, i.e.
`target_mean_`.
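
As a small illustration of the note above (a hedged sketch on a hand-made toy
dataset, not taken from the example gallery), a category unseen during `fit` is
mapped to the overall target mean::

  >>> import numpy as np
  >>> from sklearn.preprocessing import TargetEncoder
  >>> X = np.array([["dog"], ["dog"], ["cat"], ["cat"]], dtype=object)
  >>> y = np.array([10.0, 20.0, 30.0, 40.0])
  >>> enc = TargetEncoder(target_type="continuous").fit(X, y)
  >>> enc.transform(np.array([["snake"]], dtype=object))  # unseen category
  array([[25.]])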

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_preprocessing_plot_target_encoder.py`

.. topic:: References

.. [MIC] :doi:`Micci-Barreca, Daniele. "A preprocessing scheme for high-cardinality
categorical attributes in classification and prediction problems"
SIGKDD Explor. Newsl. 3, 1 (July 2001), 27–32. <10.1145/507533.507538>`
.. [PAR] :doi:`Pargent, F., Pfisterer, F., Thomas, J. et al. "Regularized target
encoding outperforms traditional methods in supervised machine learning with
high cardinality features" Comput Stat 37, 2671–2692 (2022)
<10.1007/s00180-022-01207-6>`
.. _preprocessing_discretization:

Discretization
4 changes: 4 additions & 0 deletions doc/whats_new/v1.3.rst
@@ -322,6 +322,10 @@ Changelog
:mod:`sklearn.preprocessing`
............................

- |MajorFeature| Introduces :class:`preprocessing.TargetEncoder`, a categorical
  encoding scheme based on the target mean conditioned on the value of the
  category. :pr:`25334` by `Thomas Fan`_.

- |Enhancement| Adds a `feature_name_combiner` parameter to
:class:`preprocessing.OneHotEncoder`. This specifies a custom callable to create
feature names to be returned by :meth:`get_feature_names_out`.
227 changes: 227 additions & 0 deletions examples/preprocessing/plot_target_encoder.py
@@ -0,0 +1,227 @@
"""
============================================
Comparing Target Encoder with Other Encoders
============================================
.. currentmodule:: sklearn.preprocessing
The :class:`TargetEncoder` uses the value of the target to encode each
categorical feature. In this example, we will compare four different approaches
for handling categorical features: :class:`TargetEncoder`,
:class:`OrdinalEncoder`, :class:`OneHotEncoder` and dropping the category.
.. note::
`fit(X, y).transform(X)` does not equal `fit_transform(X, y)` because a
cross-validation scheme is used in `fit_transform` for encoding. See the
:ref:`User Guide <target_encoder>` for details.
"""

# %%
# Loading Data from OpenML
# ========================
# First, we load the wine reviews dataset, where the target is the points given
# by a reviewer:
from sklearn.datasets import fetch_openml

wine_reviews = fetch_openml(data_id=42074, as_frame=True, parser="pandas")

df = wine_reviews.frame
df.head()

# %%
# For this example, we use the following subset of numerical and categorical
# features in the data. The target consists of continuous values from 80 to 100:
numerical_features = ["price"]
categorical_features = [
"country",
"province",
"region_1",
"region_2",
"variety",
"winery",
]
target_name = "points"

X = df[numerical_features + categorical_features]
y = df[target_name]

_ = y.hist()

# %%
# Training and Evaluating Pipelines with Different Encoders
# =========================================================
# In this section, we will evaluate pipelines with
# :class:`~sklearn.ensemble.HistGradientBoostingRegressor` with different encoding
# strategies. First, we list out the encoders we will be using to preprocess
# the categorical features:
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import TargetEncoder

categorical_preprocessors = [
("drop", "drop"),
("ordinal", OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1)),
(
"one_hot",
OneHotEncoder(handle_unknown="ignore", max_categories=20, sparse_output=False),
),
("target", TargetEncoder(target_type="continuous")),
]

# %%
# Next, we evaluate the models using cross validation and record the results:
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import cross_validate
from sklearn.ensemble import HistGradientBoostingRegressor

n_cv_folds = 3
max_iter = 20
results = []


def evaluate_model_and_store(name, pipe):
result = cross_validate(
pipe,
X,
y,
scoring="neg_root_mean_squared_error",
cv=n_cv_folds,
return_train_score=True,
)
rmse_test_score = -result["test_score"]
rmse_train_score = -result["train_score"]
results.append(
{
"preprocessor": name,
"rmse_test_mean": rmse_test_score.mean(),
"rmse_test_std": rmse_train_score.std(),
"rmse_train_mean": rmse_train_score.mean(),
"rmse_train_std": rmse_train_score.std(),
}
)


for name, categorical_preprocessor in categorical_preprocessors:
preprocessor = ColumnTransformer(
[
("numerical", "passthrough", numerical_features),
("categorical", categorical_preprocessor, categorical_features),
]
)
pipe = make_pipeline(
preprocessor, HistGradientBoostingRegressor(random_state=0, max_iter=max_iter)
)
evaluate_model_and_store(name, pipe)


# %%
# Native Categorical Feature Support
# ==================================
# In this section, we build and evaluate a pipeline that uses native categorical
# feature support in :class:`~sklearn.ensemble.HistGradientBoostingRegressor`,
# which only supports up to 255 unique categories. In our dataset, most of
# the categorical features have more than 255 unique categories:
n_unique_categories = df[categorical_features].nunique().sort_values(ascending=False)
n_unique_categories

# %%
# To work around the limitation above, we group the categorical features into
# low cardinality and high cardinality features. The high cardinality features
# will be target encoded and the low cardinality features will use the native
# categorical feature support in gradient boosting.
high_cardinality_features = n_unique_categories[n_unique_categories > 255].index
low_cardinality_features = n_unique_categories[n_unique_categories <= 255].index
mixed_encoded_preprocessor = ColumnTransformer(
[
("numerical", "passthrough", numerical_features),
(
"high_cardinality",
TargetEncoder(target_type="continuous"),
high_cardinality_features,
),
(
"low_cardinality",
OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1),
low_cardinality_features,
),
],
verbose_feature_names_out=False,
)

# The output of the preprocessor must be set to pandas so that the
# gradient boosting model can detect the low cardinality features.
mixed_encoded_preprocessor.set_output(transform="pandas")
mixed_pipe = make_pipeline(
mixed_encoded_preprocessor,
HistGradientBoostingRegressor(
random_state=0, max_iter=max_iter, categorical_features=low_cardinality_features
),
)
mixed_pipe

# %%
# Finally, we evaluate the pipeline using cross validation and record the results:
evaluate_model_and_store("mixed_target", mixed_pipe)

# %%
# Plotting the Results
# ====================
# In this section, we display the results by plotting the test and train scores:
import matplotlib.pyplot as plt
import pandas as pd

results_df = (
pd.DataFrame(results).set_index("preprocessor").sort_values("rmse_test_mean")
)

fig, (ax1, ax2) = plt.subplots(
1, 2, figsize=(12, 8), sharey=True, constrained_layout=True
)
xticks = range(len(results_df))
name_to_color = dict(
zip((r["preprocessor"] for r in results), ["C0", "C1", "C2", "C3", "C4"])
)

for subset, ax in zip(["test", "train"], [ax1, ax2]):
mean, std = f"rmse_{subset}_mean", f"rmse_{subset}_std"
data = results_df[[mean, std]].sort_values(mean)
ax.bar(
x=xticks,
height=data[mean],
yerr=data[std],
width=0.9,
color=[name_to_color[name] for name in data.index],
)
ax.set(
title=f"RMSE ({subset.title()})",
xlabel="Encoding Scheme",
xticks=xticks,
xticklabels=data.index,
)

# %%
# When evaluating the predictive performance on the test set, dropping the
# categories performs the worst and the target encoders perform the best. This
# can be explained as follows:
#
# - Dropping the categorical features makes the pipeline less expressive,
#   causing it to underfit as a result;
# - Due to the high cardinality and to reduce the training time, the one-hot
#   encoding scheme uses `max_categories=20`, which prevents the features from
#   expanding too much but can still result in underfitting.
# - If we had not set `max_categories=20`, the one-hot encoding scheme would
#   likely have made the pipeline overfit, as the number of features explodes
#   with rare category occurrences that are correlated with the target by chance
#   (on the training set only);
# - The ordinal encoding imposes an arbitrary order on the features, which are
#   then treated as numerical values by the
#   :class:`~sklearn.ensemble.HistGradientBoostingRegressor`. Since this model
#   groups numerical features into 256 bins per feature, many unrelated
#   categories can be grouped together, and as a result the overall pipeline can
#   underfit;
# - When using the target encoder, the same binning happens, but since the
#   encoded values are statistically ordered by marginal association with the
#   target variable, the binning used by the
#   :class:`~sklearn.ensemble.HistGradientBoostingRegressor` makes sense and
#   leads to good results: the combination of smoothed target encoding and
#   binning works as a good regularizing strategy against overfitting while not
#   limiting the expressiveness of the pipeline too much.
6 changes: 6 additions & 0 deletions setup.py
@@ -295,6 +295,12 @@ def check_package_status(package, min_version):
],
"preprocessing": [
{"sources": ["_csr_polynomial_expansion.pyx"], "include_np": True},
{
"sources": ["_target_encoder_fast.pyx"],
"include_np": True,
"language": "c++",
"extra_compile_args": ["-std=c++11"],
},
],
"neighbors": [
{"sources": ["_ball_tree.pyx"], "include_np": True},
2 changes: 2 additions & 0 deletions sklearn/preprocessing/__init__.py
@@ -26,6 +26,7 @@

from ._encoders import OneHotEncoder
from ._encoders import OrdinalEncoder
from ._target_encoder import TargetEncoder

from ._label import label_binarize
from ._label import LabelBinarizer
@@ -56,6 +57,7 @@
"RobustScaler",
"SplineTransformer",
"StandardScaler",
"TargetEncoder",
"add_dummy_feature",
"PolynomialFeatures",
"binarize",
8 changes: 7 additions & 1 deletion sklearn/preprocessing/_encoders.py
@@ -415,6 +415,7 @@ class OneHotEncoder(_BaseEncoder):
--------
OrdinalEncoder : Performs an ordinal (integer)
encoding of the categorical features.
TargetEncoder : Encodes categorical features using the target.
sklearn.feature_extraction.DictVectorizer : Performs a one-hot encoding of
dictionary items (also handles string-valued features).
sklearn.feature_extraction.FeatureHasher : Performs an approximate one-hot
@@ -1229,7 +1230,12 @@ class OrdinalEncoder(OneToOneFeatureMixin, _BaseEncoder):
See Also
--------
OneHotEncoder : Performs a one-hot encoding of categorical features.
OneHotEncoder : Performs a one-hot encoding of categorical features. This encoding
is suitable for low to medium cardinality categorical variables, both in
supervised and unsupervised settings.
TargetEncoder : Encodes categorical features using supervised signal
in a classification or regression pipeline. This encoding is typically
suitable for high cardinality categorical variables.
LabelEncoder : Encodes target labels with values between 0 and
``n_classes-1``.
