
ENH Adds Target Regression Encoder (Impact Encoder) #17323

Closed

Conversation

thomasjpfan
Member

@thomasjpfan thomasjpfan commented May 24, 2020

Reference Issues/PRs

Partially Addresses #5853
Closes #9614

What does this implement/fix? Explain your changes.

  • This encoding scheme automatically chooses the weighting between the per-category target mean and the overall target mean based on variance and count, as described in Chapter 12 of Andrew Gelman and Jennifer Hill, Data Analysis Using Regression and Multilevel/Hierarchical Models:

[image: multilevel partial pooling estimate of the per-category mean, from Chapter 12 of Gelman & Hill]

There are some interesting points about this scheme:

  1. As n_j increases, the mean conditioned on a category plays a bigger role.
  2. If the variance conditioned on a category is much smaller than the overall variance, then the mean conditioned on the category plays a bigger role. This leads to cases like the following:
import numpy as np
from sklearn.preprocessing import TargetRegressorEncoder

X = np.array([[0] * 200 + [1] * 200]).T
y = np.array([1, 1.01] * 100 + [100, 100.01] * 100)
enc = TargetRegressorEncoder().fit(X, y)
enc.transform([[0], [1]])
# array([[  1.005],
#        [100.005]])
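As a rough check of the example above, here is a minimal sketch of the weighting (illustrative only; the variable names are mine, and the formula mirrors the partial-pooling expression quoted in the diff below):

import numpy as np

X = np.array([[0] * 200 + [1] * 200]).T
y = np.array([1, 1.01] * 100 + [100, 100.01] * 100)

y_mean, y_var = y.mean(), y.var()
for cat in (0, 1):
    mask = X[:, 0] == cat
    n, cat_mean, cat_var = mask.sum(), y[mask].mean(), y[mask].var()
    var_ratio = cat_var / y_var  # tiny here, so there is almost no shrinkage
    encoding = (n * cat_mean + var_ratio * y_mean) / (n + var_ratio)
    print(cat, encoding)  # ~1.005 and ~100.005, matching the transform above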

Any other comments?

  • There are ways to run the encoding with CV at train time, but that would break our API contract of having fit().transform() == fit_transform().

  • This is missing an example that shows where it is more useful than OneHotEncoder.

  • The classifier version of this encoder will follow if we decide this encoder should be included in scikit-learn.

@amueller
Member

I'm not sure what you're illustrating in 2). That in this case there's very little smoothing? That's good, right? This seems like the encoding I'd want?

@amueller
Member

I think this is great, though I haven't reviewed in detail yet.
The name is a bit awkward but I don't have a better idea. I prefer having separate regressor and classifier classes as you do here, as that makes things cleaner API-wise, I think.
So none of the housing data worked for this (Ames, Kings County, Melbourne)? There's also a Craigslist car price dataset that might work.

the target conditioned on the categorical feature. The target encoding scheme
takes a weighted average of the overall target mean and the target mean
conditioned on categories. A multilevel linear model, where the levels are the
categories, is used to construct the weighted average is estimated in Chapter
Member

They call it pooled average, right? Maybe say that? Also maybe provide the other references, even though not entirely applicable?

Member

Also maybe say that this is the "GLM" version of target encoding and reference Max Kuhn or something and also the thesis I posted? I feel like it's hard to have too many references ;)

Member

Isn't it just "multilevel linear model", i.e. not "generalized"? Or did I miss something essential?

Just as info: a multilevel linear model with only a random intercept, i.e. the approach chosen here, is equivalent to the (Bühlmann-Straub) "credibility estimator". Maybe too much information for the UG.

Member

@amueller amueller left a comment

I haven't checked the test yet; otherwise looks good.

@@ -19,6 +21,25 @@
]


def _get_counts(values, uniques):
"""Get the number of times each of the values comes up `values`
Member

Suggested change
"""Get the number of times each of the values comes up `values`
"""Get the number of times each of the values comes up in `values`

This is just bincounts for integers and value_counts if we had pandas, right?
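For illustration (a sketch with made-up arrays), the integer case is indeed just a bincount:

import numpy as np

values = np.array([0, 1, 1, 2, 2, 2])
uniques = np.array([0, 1, 2])
counts = np.bincount(values, minlength=len(uniques))  # array([1, 2, 3])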

the target conditioned on the categorical feature. The target encoding scheme
takes a weighted average of the overall target mean and the target mean
conditioned on categories. A multilevel linear model, where the levels are the
categories, is used to construct the weighted average is estimated in Chapter
Member

Also maybe say that this is the "GLM" version of target encoding and reference Max Kuhn or something and also the thesis I posted? I feel like it's hard to have too many references ;)

sklearn.preprocessing.OrdinalEncoder : Performs an ordinal (integer)
encoding of the categorical features.
sklearn.preprocessing.OneHotEncoder : Performs a one-hot encoding of
categorical features.
Member

Add references here as well? We're not super consistent with that, are we?

Member

I think we should avoid references in the docstrings because they lack context, and we sometimes end up with big lists of refs that are impractical because some of them are about super specific details.

It's better to have them in the UG so that we can explicitly refer to them as in "this is detailed in blahblah" or "this was introduced in ...".

In my recent UG-cleaning PRs, I've removed refs from docstrings.


Attributes
----------
cat_encodings_ : list of ndarray
Member

Maybe worth giving the shape of this somehow? Took me a minute to understand the short description.
"For each feature, provide the encoding corresponding to each category, in the order of categories_?"

cat_var_ratio = np.ones(n_cats, dtype=float)

for encoding in range(n_cats):
np.equal(X_int[:, i], encoding, out=tmp_mask)
Member

is that the same as tmp_mask[:] = X_int[:, i] == encoding and you find this one more readable?

Member Author

I used np.equal so we do not need to allocate more memory in the inner for loop.

From my understanding, the right-hand side will allocate a temporary ndarray and then copy it into tmp_mask.

tmp_mask[:] = X_int[:, i] == encoding
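For illustration, the two forms being compared (a sketch with made-up data):

import numpy as np

X_col = np.array([0, 1, 1, 0, 2])
tmp_mask = np.empty(X_col.shape, dtype=bool)

np.equal(X_col, 1, out=tmp_mask)  # writes into the preallocated buffer in place
tmp_mask[:] = X_col == 1          # allocates a temporary bool array, then copies it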

cat_means = np.zeros(n_cats, dtype=float)
cat_var_ratio = np.ones(n_cats, dtype=float)

for encoding in range(n_cats):
Member

I'm slightly afraid of this being slow but I'm not sure how to do it quickly without pandas or cython. I guess if it's integer encoded you could do some csr_matrix trick? Is it worth it?
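As a sketch of one loop-free alternative (using np.bincount rather than the csr_matrix trick mentioned above; names and data are illustrative):

import numpy as np

codes = np.array([0, 0, 1, 2, 2, 2])    # integer-encoded category per sample
y = np.array([1., 3., 5., 2., 4., 6.])

counts = np.bincount(codes, minlength=3)
means = np.bincount(codes, weights=y, minlength=3) / counts  # per-category target means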

Member

@NicolasHug NicolasHug left a comment

Thanks @thomasjpfan

First pass on docs and tests

I haven't reviewed the code yet

Comment on lines 638 to 639
.. [GEL] Andrew Gelman and Jennifer Hill. Data Analysis Using Regression
and Multilevel/Hierarchical Models. Cambridge University Press, 2007
Member

This is a $50+ textbook, can we have a free reference?

Member

There's a pdf online, right? But yes to more references like the paper that no-one cites and the masters' thesis ;)

Member Author

I do not know if it is okay to share those links to the book.

@@ -593,6 +593,51 @@ the 2 features with respectively 3 categories each.
See :ref:`dict_feature_extraction` for categorical features that are
represented as a dict, not as scalars.

.. _target_regressor_encoder:

Target Regressor Encoder
Member

This should be a subsection of the "Encoding Categorical Feature" section above, not a separate one.

Maybe you'll need to create subsections for OE and OHE too.

Comment on lines 603 to 605
conditioned on categories. A multilevel linear model, where the levels are the
categories, is used to construct the weighted average is estimated in Chapter
12 of [GEL]_:
Member

This sentence is grammatically incorrect

Comment on lines +623 to +624
encoding for `'cat'` is pulled toward the overall mean of `53` when compared to
`'dog'` because the `'cat'` category appears less frequently::
Member

This isn't obvious at first since the encoding for dog is still much closer to the global mean than the encoding for cat.

It might be clearer if the mean for dog was significantly different from 53

Member Author

Since there are so many dogs, its mean will be fairly close to the global mean.

@@ -19,6 +21,25 @@
]


def _get_counts(values, uniques):
"""Get the number of times each of the values comes up `values`
Member

this sentence is not correct?

Comment on lines 94 to 96
# unknown
X_trans = enc.transform([unknown_X])
assert_allclose(X_trans, [[y.mean()]])
Member

is this relevant to test that here?

assert_allclose(X_trans[-1], [y.mean()])

assert len(enc.cat_encodings_) == 1
# unknown category seen during fitting is mapped to the mean
Member

Isn't this instead a known category unseen during fitting?

i.e. 2 and 'cow'?

y_mean = np.mean(y)

# manually compute multilevel partial pooling
feat_0_cat_0_encoding = ((4 * 4.0 / 5.0 + 4.0 / 9.5) /
Member

maybe avoid using the .0 notation which adds noise

Comment on lines 57 to 62
# check known test data
X_test = np.array([[2, 0, 1]], dtype=int).T
X_input = categories[X_test]
X_trans = enc.transform(X_input)
expected_encoding = cat_encoded[X_test]
assert_allclose(expected_encoding, X_trans)
Member

This part does not seem to test anything more than the one just above with the training data? Could it be removed?


Target Regressor Encoder
========================
The :class:`~sklearn.preprocessing.TargetRegressorEncoder` uses statistics of
Member

We should start by describing in which cases this Encoder could or should be used. AFAIK it's useful when there are lots of unordered categories, so the OHE would be too expensive? Zipcodes sounds like the go-to illustration use-case?

@zachmayer
Contributor

@thomasjpfan this looks really useful, and I am excited you are adding it to sklearn! Out of curiosity, is this similar to the encoding scheme used by [category_encoders.target_encoder.TargetEncoder](http://contrib.scikit-learn.org/category_encoders/targetencoder.html)?

@thomasjpfan thomasjpfan changed the title ENH Adds Target Regression Encoder ENH Adds Target Regression Encoder (Impact Encoder) Jul 17, 2020
@lorentzenchr
Member

@jnothman If the house price does depend on population density, then the marginal mean of house price per zip code does also depend on population density. If this marginal mean is then used to encode zip code (e.g. instead of OHE), for linear models or certain interpretability/explainability tools, the effect of (residual, i.e. without pop density) zip code is blurred by the effect of population density.

@jnothman
Member

jnothman commented Aug 2, 2020 via email

@thomasjpfan
Member Author

As for naming, I would want the word "Encoder" somewhere in the name. In the end, I am thinking of having two encoders, one for categorical targets and another for regression targets. Maybe: RegressionTargetEncoder, and ClassificationTargetEncoder?

Initially, I had it named TargetEncoder with a default parameter target_type='continuous'; this would mean that it could treat classification targets encoded as ints as a regression problem.

@lorentzenchr
Member

Interesting, #18012 just linked the package feature-engine in the docs, which calls it MeanCategoricalEncoder, see
https://feature-engine.readthedocs.io/en/latest/encoders/MeanCategoricalEncoder.html.

@jnothman
Member

jnothman commented Aug 4, 2020 via email

@ogrisel
Member

ogrisel commented Nov 26, 2020

Related to this PR, here is a very interesting report on the TargetEncoder implemented in cuML that was used by the RAPIDS.ai team to win the RecSys 2020 challenge.

https://medium.com/rapids-ai/target-encoding-with-rapids-cuml-do-more-with-your-categorical-data-8c762c79e784

https://medium.com/rapids-ai/winning-solution-of-recsys2020-challenge-gpu-accelerated-feature-engineering-and-training-for-cd67c5a87b1f

In particular, it's very interesting that they identified that the internal CV is really needed to avoid introducing a distribution shift between train and test for the downstream classifier.

There are also notes about scalability and parallelism to tackle very large datasets efficiently because apparently this pre-processing was a significant part of the computational load of their winning pipeline for this large scale problem (because of the internal CV).

That being said, the fit().transform() != fit_transform() discrepancy would be a problem if we implement the internal CV mode.
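For reference, a minimal sketch of the internal CV ("out-of-fold") idea being discussed; this is a hypothetical helper, not this PR's API:

import numpy as np
from sklearn.model_selection import KFold

def fit_transform_oof(make_encoder, X, y, n_splits=5):
    # each training row is encoded by an encoder fitted on the other folds,
    # which avoids leaking a row's own target into its encoding
    out = np.empty(X.shape, dtype=float)
    for train_idx, val_idx in KFold(n_splits=n_splits).split(X):
        enc = make_encoder().fit(X[train_idx], y[train_idx])
        out[val_idx] = enc.transform(X[val_idx])
    return out  # note: differs from make_encoder().fit(X, y).transform(X)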

@mfeurer
Contributor

mfeurer commented Nov 26, 2020

There's another benchmark on using different categorical encoders (medium)(github) which might be of interest for this PR as well. The gist is that using CV is crucial for any supervised categorical encoding, as was also found by the NVIDIA team winning the RecSys challenge. In addition to CV, there's also the possibility of adding a 2nd layer of CV (as I understand, the implementation from cuML only does one layer of CV), advertised to reduce overfitting even further, which is also benchmarked in this medium blogpost.

@ogrisel
Member

ogrisel commented Nov 27, 2020

Thanks very much. For those annoyed by the medium paywall, just open the URL in a private browsing session.

@thomasjpfan
Member Author

That being said, the fit().transform() != fit_transform() discrepancy would be a problem if we implement the internal CV mode.

This was the primary reason why I did not use an internal CV for this PR. When I was looking at sources on impact/target encoders, I commonly saw that a CV was important.

The current implementation is trying to automatically weight the "group mean" and the "overall mean" by multilevel estimates of the mean for each group.

@amueller
Member

amueller commented Dec 3, 2020

That's the James-Stein encoder, that you're implementing, right?

@thomasjpfan
Member Author

That's the James-Stein encoder, that you're implementing, right?

Oooo it is! It is nice to know this encoding scheme has a name.


cat_encoded = cat_counts * cat_means + cat_var_ratio * y_mean
cat_encoded /= cat_counts + cat_var_ratio
cat_encodings.append(cat_encoded)
Member

I am not sure that this matches the James Stein estimator as described in this post:

https://towardsdatascience.com/benchmarking-categorical-encoders-9c322bd77ee8

@ogrisel
Member

ogrisel commented Dec 8, 2020

According to https://towardsdatascience.com/benchmarking-categorical-encoders-9c322bd77ee8 , even when using the James Stein encoder you need to do some form of nested cross-validation, otherwise the generalization performance of the overall pipeline is degraded.

@ogrisel
Member

ogrisel commented Dec 9, 2020

To summarize some offline discussions: as an alternative to the strategy "cross_val_predict to fit_transform the training set for the downstream classifier, then retrain the encoder on the full train data to transform the test data", which leads to the .fit_transform != .fit.transform discrepancy, we could also consider the following bagging strategy: "train k encoders on k sub/resampled training sets and then use the k encoders' averaged predictions to transform both the training and testing data for the downstream classifier". This strategy would not induce the .fit_transform != .fit.transform discrepancy and, while not exactly nested CV, it should allow us to mitigate most of the train/test distribution shift induced by using supervised encoders as pre-processors. This bagging strategy could also be useful for general (multivariate) stacking ensembles.

But AFAIK, this strategy has not been well studied in the literature, so it could be considered problematic to implement in scikit-learn as it kind of breaks our rule of only implementing standard, "not invented here" methods.
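A minimal sketch of this bagging strategy (hypothetical helpers, not an existing API):

import numpy as np

def fit_bagged_encoders(make_encoder, X, y, k=10, seed=0):
    rng = np.random.RandomState(seed)
    encoders = []
    for _ in range(k):
        idx = rng.randint(0, len(X), size=len(X))  # bootstrap resample
        encoders.append(make_encoder().fit(X[idx], y[idx]))
    return encoders

def transform_bagged(encoders, X):
    # the same averaged mapping is applied to train and test data,
    # so fit_transform stays consistent with fit().transform()
    return np.mean([enc.transform(X) for enc in encoders], axis=0)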

@ogrisel
Member

ogrisel commented Dec 9, 2020

The strategy above would be very similar to BaggingRegressor(base_estimator=TargetEncoder(cv=None), n_estimators=10) if we think that sampling with replacement is a good idea. In practice, for low values of n_estimators (e.g. around 10), I am not sure whether sampling without replacement and with excluded val sets such as done in StratifiedKFold(n_splits=10) would not be better than sampling with replacement.

@lorentzenchr
Member

There might be yet another good option to avoid overfitting on the training set.
A bit reformulated, @thomasjpfan implemented for categorical level c:

encode_c = alpha_c * mean(y|c) + (1 - alpha_c) * mean(y)
alpha_c = n_c / (n_c + var(y|c) / var(y))

mean and var being the empirical estimators for expectation and variance and n_c the number of samples having level c. Note that var(y|c) is different for each individual categorical level c.

Another, statistically well-founded option is to use one and the same value for all var(y|c), namely their average (see the sketch after the references below):

new_var(y|c) = 1/number_of_levels * sum_c var(y|c)

In addition, I'd use the homogeneous estimator for mean(y), cf. Theorem 8.16 of the reference, and an unbiased estimator for var(y), cf. Eq. (8.15).

Reference:
https://dx.doi.org/10.2139/ssrn.2319328 (Last revised: 17 Dec 2020), Chapter

  • 8.2.2 Bühlmann-Straub credibility formula
  • 8.2.3 Estimation of structural parameters
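A sketch contrasting the per-level variance weight in this PR with the averaged (pooled) within-level variance suggested above; names are illustrative, all levels are assumed to be present, and this is not the PR's code:

import numpy as np

def encode(y, codes, n_levels, pool_variance=False):
    y_mean, y_var = y.mean(), y.var()
    counts = np.bincount(codes, minlength=n_levels)
    means = np.bincount(codes, weights=y, minlength=n_levels) / counts
    variances = np.array([y[codes == c].var() for c in range(n_levels)])
    if pool_variance:
        # Bühlmann-Straub flavour: one common value for all var(y|c)
        variances = np.full(n_levels, variances.mean())
    alpha = counts / (counts + variances / y_var)
    return alpha * means + (1 - alpha) * y_mean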

@thomasjpfan
Member Author

@lorentzenchr That looks interesting. Do you have a reference for classification targets?

Going through the code and diagrams of the blog post mentioned in the comments, I think all of them include a sense of bagging. The "single" and "double" validation both generate models on different folds. At prediction time, the predictions are combined as shown here.

I extended this PR with benchmarks here: https://github.com/thomasjpfan/sk_encoder_cv where the readme shows the results. All the datasets are from openml.

  1. SKTargetEncoderCV is the CV version of the estimator in this PR.
  2. James-Stein is from category_encoders and there is also a CV version.
  3. The sk_encoder_cv repo adds a NestedEncoderCV that adds the CV functionality.

TLDR: For the most part, the CV version of the target encoder does better than or on par with the non-CV version. For the telco and amazon_access datasets, the CV version does quite a bit better.

Base automatically changed from master to main January 22, 2021 10:52
@thomasjpfan
Member Author

I recently added more benchmarks with more datasets with cv=10 here: https://github.com/thomasjpfan/sk_encoder_cv. Note I do not show datasets where all the encoders perform the same. The "bagging" approach used 10 bootstrap samples to train 10 encoders, and its transform averages the encodings together. The CV approach results in fit.transform != fit_transform. I also included the target encoder with the Bühlmann-Straub (BS) estimator provided by @lorentzenchr.

Thoughts:

  1. CV performed better: telco, amazon_access, kicks, SpeedDating, colleges, KDDCup09_upselling, KDDCup09_appetency, rl
  2. CV did not change the result: dress_sales
  3. CV performed worse: phishing_websites, black_friday (but only slightly)
  4. It looks like bagging does help, but not as well as the CV approach.

@thomasjpfan
Member Author

I am closing this PR in favor of #25334. #25334 uses a normal CV scheme for encoding and is simpler compared to this PR's encoding scheme.

@thomasjpfan thomasjpfan closed this Jan 8, 2023