
ENH Adds TargetEncoder #25334

Merged (91 commits, Mar 16, 2023)

Changes from 7 commits

Commits (91 commits)
f29c2d8
ENH Adds Target Regression Encoder
thomasjpfan Jan 8, 2023
336b17d
DOC Adds pr number
thomasjpfan Jan 8, 2023
c52ce42
DOC Adds pr number
thomasjpfan Jan 8, 2023
1863be9
DOC Adds example
thomasjpfan Jan 9, 2023
22af6a6
DOC Fixes example link
thomasjpfan Jan 9, 2023
f535f54
FIX Use fancy indexing
thomasjpfan Jan 9, 2023
c6d1fc4
FIX Fix issue with 32bit
thomasjpfan Jan 9, 2023
2868b3c
ENH Adds support for binary classification
thomasjpfan Jan 10, 2023
aa6e545
TST Adds test to check target_type_
thomasjpfan Jan 10, 2023
871ea45
DOC Update whats new with better wording
thomasjpfan Jan 10, 2023
71e6bad
DOC Update docs about target encoder itself
thomasjpfan Jan 10, 2023
fdd39d0
TST Update names of tests
thomasjpfan Jan 10, 2023
c8d5546
DOC better names for variables
thomasjpfan Jan 10, 2023
6a5771c
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Jan 10, 2023
6dbcb85
ENH Adds target_type parameter
thomasjpfan Jan 10, 2023
00d5230
TST Fixes test failures
thomasjpfan Jan 10, 2023
e75a51e
DOC Remove mention on RMSE
thomasjpfan Jan 10, 2023
2978e23
DOC Adds concluding statement in example
thomasjpfan Jan 10, 2023
429085d
CLN Refactor names
thomasjpfan Jan 10, 2023
93ea268
DOC Improves user guide
thomasjpfan Jan 13, 2023
cfe7afa
CLN Cleaner implementation
thomasjpfan Jan 13, 2023
3967934
ENH Adds auto for smoothing
thomasjpfan Jan 13, 2023
7875c2e
CLN Address comments
thomasjpfan Jan 13, 2023
5ff5b61
DOC Update example title
thomasjpfan Jan 13, 2023
2568e89
DOC Adds comment to point to equation
thomasjpfan Jan 13, 2023
afc4223
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Jan 13, 2023
6e4bf50
Apply suggestions from code review
thomasjpfan Jan 13, 2023
d613c61
DOC Only say classification
thomasjpfan Jan 13, 2023
ba85edb
DOC Update examples and docstring
thomasjpfan Jan 13, 2023
e3c8473
DOC Imrove example
thomasjpfan Jan 14, 2023
0d6c4a2
DOC Clarify fit_transform vs fit.transform
thomasjpfan Jan 14, 2023
d9cd410
CLN Address comments
thomasjpfan Jan 14, 2023
ee1e26b
DOC Adds more code comments
thomasjpfan Jan 14, 2023
a399791
CLN Code formatting
thomasjpfan Jan 14, 2023
fd2b963
CLN Move test closer to comment
thomasjpfan Jan 14, 2023
cd6c650
DOC Adds more details aboue equations
thomasjpfan Jan 14, 2023
991fd9f
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Jan 16, 2023
962125d
CLN Address comments
thomasjpfan Jan 18, 2023
02d8346
TST Simplify test logic
thomasjpfan Jan 20, 2023
acb4b2e
CLN Address comments
thomasjpfan Jan 20, 2023
3261d10
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Jan 20, 2023
a57b117
DOC Adds note about type_of_target inference
thomasjpfan Jan 20, 2023
5a330be
ENH Use integers for counts
thomasjpfan Jan 23, 2023
f1a14d7
STY Formatting
thomasjpfan Jan 23, 2023
66fb8bd
STY Fix numpydoc linting error
thomasjpfan Jan 23, 2023
f5df8b8
DOC Adds example about low and high smoothing parameters
thomasjpfan Jan 27, 2023
906ac98
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Jan 30, 2023
43630cf
Apply suggestions from code review
thomasjpfan Jan 30, 2023
177128d
STY Black formating
thomasjpfan Jan 31, 2023
e37b6ec
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Feb 2, 2023
86d357e
DOC Adds mixed encoder for high and low cardinality
thomasjpfan Feb 2, 2023
a7cc80f
DOC Address comments
thomasjpfan Feb 2, 2023
0bcad5b
ENH Make sure the colors align between graphs
thomasjpfan Feb 2, 2023
61081db
Apply suggestions from code review
thomasjpfan Feb 2, 2023
2e4bd85
CLN Reformat section headers and move sections around
thomasjpfan Feb 2, 2023
07a225c
CLN Try to work around sphinx 4.0.1
thomasjpfan Feb 2, 2023
7cd0c8d
CLN Address comments
thomasjpfan Feb 6, 2023
94fd097
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Feb 7, 2023
56dd2fa
CLN Better variable names
thomasjpfan Feb 10, 2023
317e13a
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Feb 10, 2023
c4eacfa
ENH Restrict cv to integers
thomasjpfan Feb 10, 2023
0ab474b
DOC Improve docstring
thomasjpfan Feb 10, 2023
c09556c
CLN Better variable names in test
thomasjpfan Feb 10, 2023
b7654bb
Apply suggestions from code review
thomasjpfan Feb 11, 2023
387f42c
DOC Simplify notes
thomasjpfan Feb 11, 2023
efe630c
DOC Adds diagram about cross validation
thomasjpfan Feb 12, 2023
803a988
DOC Explain fit vs fit_transform
thomasjpfan Feb 12, 2023
f5f7ab1
DOC Clarify fit
thomasjpfan Feb 12, 2023
0e784b1
Apply suggestions from code review
thomasjpfan Mar 7, 2023
444f589
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Mar 7, 2023
4994bc4
FIX Fixes compile issue
thomasjpfan Mar 7, 2023
8fb1a11
CLN Removes pow
thomasjpfan Mar 7, 2023
850694e
CLN Use nogil and pointers
thomasjpfan Mar 7, 2023
665e2e9
MNT Use vector instead
thomasjpfan Mar 7, 2023
5247123
DOC Adds note about missing values
thomasjpfan Mar 7, 2023
70f99fe
API Adds shuffle and random_state
thomasjpfan Mar 7, 2023
e09660c
STY Slight reformatting
thomasjpfan Mar 7, 2023
690fcdc
Check random seeds [all random seeds]
thomasjpfan Mar 8, 2023
a327b52
API Change attriubte name to target_mean_
thomasjpfan Mar 8, 2023
6b27afd
Apply suggestions from code review
thomasjpfan Mar 8, 2023
9773765
STY Linting issue
thomasjpfan Mar 8, 2023
8e18746
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Mar 8, 2023
1ea0407
FIX Fixes edge case with smooth=0.0 and unknown categories in cv
thomasjpfan Mar 10, 2023
df5574d
Apply suggestions from code review
thomasjpfan Mar 10, 2023
d681e06
Merge remote-tracking branch 'upstream/main' into target_regression_e…
thomasjpfan Mar 10, 2023
0b458c3
TST Improves testing maintainability
thomasjpfan Mar 13, 2023
26d2429
TST add statistical-integration tests to check the benefit of interna…
ogrisel Mar 14, 2023
d77e519
Fix comment in test to more accurately describe what's happening
ogrisel Mar 14, 2023
05997bb
TST Use RNG for permutation
thomasjpfan Mar 14, 2023
6377419
More fixes in inline comments for the new test
ogrisel Mar 14, 2023
180ee0e
Use pandas to parse the fetched the example dataset.
ogrisel Mar 16, 2023
1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -1437,6 +1437,7 @@ details.
preprocessing.RobustScaler
preprocessing.SplineTransformer
preprocessing.StandardScaler
preprocessing.TargetRegressorEncoder

.. autosummary::
:toctree: generated/
47 changes: 47 additions & 0 deletions doc/modules/preprocessing.rst
@@ -830,6 +830,53 @@ lexicon order.
>>> enc.infrequent_categories_
[array(['b', 'c'], dtype=object)]

.. _target_regressor_encoder:

Target Regressor Encoder
------------------------

.. currentmodule:: sklearn.preprocessing

The :class:`TargetRegressorEncoder` uses the target mean conditioned on the
categorical feature to encode the categories [PAR]_ [MIC]_. This encoding
scheme is useful for categorical features with high cardinality, where one-hot
encoding would inflate the feature space, making it more expensive for a
downstream model to process. A classical example of high-cardinality
categories is location-based features such as zip code or region.

    Reviewer comment (Member): I think it would be good to emphasise that
    "categorical features that have no order (nominal categories)" (then we can
    use "nominal categories" later as a shorthand) is the target (haha) use-case
    of the target encoder. High cardinality features could also be dealt with
    using ordinal encoding.

The :class:`TargetRegressorEncoder` implementation mixes the global target mean
with the target mean conditioned on the category:

.. math::

    E_c = \dfrac{\sum_{X_i = c}y_i + s\mu_y}{|X_c| + s}

where :math:`E_c` is the encoding for category :math:`c`, :math:`X_i` is the
category at sample :math:`i`, :math:`y_i` is the target at sample :math:`i`,
:math:`s` is a smoothing parameter, :math:`\mu_y` is the global target mean, and
:math:`X_c` is the set of data points with category :math:`c`.
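
For intuition, the encoding of the equation above can be computed by hand for a
toy column (a minimal illustration with NumPy only; it does not use the
estimator itself)::

    >>> import numpy as np
    >>> X = np.array(["a", "a", "b", "b"])
    >>> y = np.array([10.0, 20.0, 30.0, 40.0])
    >>> s = 2.0  # smoothing parameter
    >>> mu_y = y.mean()  # global target mean, 25.0
    >>> for c in ["a", "b"]:
    ...     mask = X == c
    ...     print(c, (y[mask].sum() + s * mu_y) / (mask.sum() + s))
    a 20.0
    b 30.0

With ``s=2``, each category mean (15 and 35) is pulled towards the global mean
of 25, giving 20 and 30.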

:class:`TargetRegressorEncoder` uses a cross validation scheme in
:meth:`~TargetRegressorEncoder.fit_transform` to prevent leaking the target
during training. In :meth:`~TargetRegressorEncoder.fit_transform`, categorical
encodings are obtained from one split and used to encode the other split.
Afterwards, a final categorical encoding is obtained from all the training data,
which is used to encode data during :meth:`~TargetRegressorEncoder.transform`.
This means that `fit().transform()` does not equal `fit_transform()`.
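
This cross-validation scheme can be sketched as follows (an illustrative
re-implementation for a single column using
:class:`~sklearn.model_selection.KFold`; the estimator performs the equivalent
work internally and handles details such as unknown categories)::

    import numpy as np
    from sklearn.model_selection import KFold

    def out_of_fold_encode(X_cat, y, smooth=5.0, n_splits=5):
        """Sketch of out-of-fold target encoding for one categorical column."""
        X_cat = np.asarray(X_cat)
        y = np.asarray(y, dtype=float)
        encoded = np.empty_like(y)
        for train_idx, test_idx in KFold(n_splits=n_splits).split(X_cat):
            y_train, X_train = y[train_idx], X_cat[train_idx]
            mu = y_train.mean()
            # Encodings are learned on the training split only ...
            encodings = {
                c: (y_train[X_train == c].sum() + smooth * mu)
                / ((X_train == c).sum() + smooth)
                for c in np.unique(X_train)
            }
            # ... and applied to the held-out split; categories unseen in the
            # training split fall back to its global mean.
            encoded[test_idx] = [encodings.get(c, mu) for c in X_cat[test_idx]]
        return encoded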

.. topic:: Examples:

  * :ref:`sphx_glr_auto_examples_preprocessing_plot_target_regressor_encoder.py`

.. topic:: References

  .. [MIC] :doi:`Micci-Barreca, Daniele. "A preprocessing scheme for high-cardinality
     categorical attributes in classification and prediction problems"
     SIGKDD Explor. Newsl. 3, 1 (July 2001), 27–32. <10.1145/507533.507538>`

  .. [PAR] :doi:`Pargent, F., Pfisterer, F., Thomas, J. et al. "Regularized target
     encoding outperforms traditional methods in supervised machine learning with
     high cardinality features" Comput Stat 37, 2671–2692 (2022)
     <10.1007/s00180-022-01207-6>`

.. _preprocessing_discretization:

Discretization
4 changes: 4 additions & 0 deletions doc/whats_new/v1.3.rst
@@ -133,6 +133,10 @@ Changelog

:mod:`sklearn.preprocessing`
............................
- |MajorFeature| Introduces :class:`preprocessing.TargetRegressorEncoder`, which uses
  the target mean conditioned on the categories to encode the categories.
  :pr:`25334` by `Thomas Fan`_.

- |Enhancement| Added support for `sample_weight` in
:class:`preprocessing.KBinsDiscretizer`. This allows specifying the parameter
`sample_weight` for each sample to be used while fitting. The option is only
121 changes: 121 additions & 0 deletions examples/preprocessing/plot_target_regressor_encoder.py
@@ -0,0 +1,121 @@
"""
=============================
Target Encoder for Regressors
=============================

.. currentmodule:: sklearn.preprocessing

The :class:`TargetRegressorEncoder` uses target statistics conditioned on
the categorical features for encoding. In this example, we compare
:class:`TargetRegressorEncoder`, :class:`OrdinalEncoder`, and dropping the
categorical features entirely, on a wine review dataset.
"""

# %%
# Loading Data from OpenML
# ========================
# First, we load the wine reviews dataset, where the target is the points given
# by a reviewer:
from sklearn.datasets import fetch_openml

wine_reviews = fetch_openml(data_id=42074, as_frame=True, parser="liac-arff")

df = wine_reviews.frame
df.head()

# %%
# For this example, we use the following subset of numerical and categorical
# features in the data. The categorical features have a cardinality ranging
# from 18 to 14810:
numerical_features = ["price"]
categorical_features = [
    "country",
    "province",
    "region_1",
    "region_2",
    "variety",
    "winery",
]

X = df[numerical_features + categorical_features]
y = df["points"]
X.nunique().sort_values(ascending=False)

# %%
# We split the dataset into a training and test set:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

print(f"Samples in training set: {len(X_train)}\nSamples in test set: {len(X_test)}")

# %%
# Building and Training Pipelines with Different Encoders
# =======================================================
# Dropping the categorical features
# ---------------------------------
# As a baseline, we construct a pipeline where the categorical features are
# dropped.
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.ensemble import HistGradientBoostingRegressor

prep = ColumnTransformer(
    [
        ("num", "passthrough", numerical_features),
        ("cat", "drop", categorical_features),
    ]
)

reg_drop_cats = Pipeline(
    [("prep", prep), ("hist", HistGradientBoostingRegressor(random_state=0))]
)
reg_drop_cats

# %%
# Here we train the model and use the root mean squared error (RMSE) to
# evaluate the baseline:
from sklearn.metrics import mean_squared_error

reg_drop_cats.fit(X_train, y_train)
reg_drop_cats_rmse = mean_squared_error(
    y_test, reg_drop_cats.predict(X_test), squared=False
)
print(f"RMSE for dropping categorical features: {reg_drop_cats_rmse:.4}")

# %%
# Using the OrdinalEncoder
# ------------------------
# The categorical features may contain categories at prediction time that were
# not seen during training, so we configure the :class:`OrdinalEncoder` to
# encode unknown categories with `-1`:
from sklearn.preprocessing import OrdinalEncoder

cat_prep = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1)
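
# %%
# As a small aside (a toy illustration, not part of the wine data), this
# configuration maps categories unseen during ``fit`` to ``-1``:
toy_enc = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1)
toy_enc.fit([["red"], ["white"]])
toy_enc.transform([["red"], ["rose"]])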

# %%
# We modify the original pipeline to use the ordinal categorical preprocessing:
reg_ordinal = reg_drop_cats.set_params(prep__cat=cat_prep)
reg_ordinal

# %%
# When we include the categorical features through ordinal encoding, the RMSE
# improves:
reg_ordinal.fit(X_train, y_train)
reg_ordinal_rmse = mean_squared_error(
    y_test, reg_ordinal.predict(X_test), squared=False
)
print(f"RMSE with ordinal encoding: {reg_ordinal_rmse:.4}")

# %%
# Using the TargetRegressorEncoder
# --------------------------------
# Finally, we replace the ordinal encoder with the
# :class:`TargetRegressorEncoder`:
from sklearn.preprocessing import TargetRegressorEncoder

reg_target = reg_ordinal.set_params(prep__cat=TargetRegressorEncoder())
reg_target

# %%
# The :class:`TargetRegressorEncoder` further improves the RMSE:
reg_target.fit(X_train, y_train)
reg_target_rmse = mean_squared_error(y_test, reg_target.predict(X_test), squared=False)
print(f"RMSE with target encoding: {reg_target_rmse:.4}")
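
# %%
# As a rough summary (assuming the three pipelines above were fit and evaluated
# in order), we gather the scores side by side:
import pandas as pd

pd.Series(
    {
        "drop categories": reg_drop_cats_rmse,
        "ordinal encoding": reg_ordinal_rmse,
        "target encoding": reg_target_rmse,
    },
    name="RMSE",
).sort_values()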
1 change: 1 addition & 0 deletions setup.py
@@ -353,6 +353,7 @@ def check_package_status(package, min_version):
],
"preprocessing": [
{"sources": ["_csr_polynomial_expansion.pyx"], "include_np": True},
{"sources": ["_target_encoder_fast.pyx"], "include_np": True},
],
"neighbors": [
{"sources": ["_ball_tree.pyx"], "include_np": True},
2 changes: 2 additions & 0 deletions sklearn/preprocessing/__init__.py
@@ -26,6 +26,7 @@

from ._encoders import OneHotEncoder
from ._encoders import OrdinalEncoder
from ._target_encoder import TargetRegressorEncoder

from ._label import label_binarize
from ._label import LabelBinarizer
@@ -56,6 +57,7 @@
"RobustScaler",
"SplineTransformer",
"StandardScaler",
"TargetRegressorEncoder",
"add_dummy_feature",
"PolynomialFeatures",
"binarize",