FIX EmptyRequest.get defaults to Bunch of METHODS #28371

Merged
15 changes: 15 additions & 0 deletions sklearn/metrics/tests/test_score_objects.py
@@ -1490,3 +1490,18 @@ def test_make_scorer_deprecation(deprecated_params, new_params, warn_msg):
     assert deprecated_roc_auc_scorer(classifier, X, y) == pytest.approx(
         roc_auc_scorer(classifier, X, y)
     )
+
+
+@pytest.mark.parametrize("enable_metadata_routing", [True, False])
+def test_metadata_routing_multimetric_metadata_routing(enable_metadata_routing):
+    """Test multimetric scorer works with and without metadata routing enabled when
+    there is no actual metadata to pass.
+
+    Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28256
+    """
+    X, y = make_classification(n_samples=50, n_features=10, random_state=0)
+    estimator = EstimatorWithFitAndPredict().fit(X, y)
+
+    multimetric_scorer = _MultimetricScorer(scorers={"acc": get_scorer("accuracy")})
+    with config_context(enable_metadata_routing=enable_metadata_routing):
+        multimetric_scorer(estimator, X, y)
23 changes: 23 additions & 0 deletions sklearn/tests/test_metadata_routing.py
eddiebergman marked this conversation as resolved.
@@ -30,6 +30,7 @@
     assert_request_is_empty,
     check_recorded_metadata,
 )
+from sklearn.utils._bunch import Bunch
 from sklearn.utils import metadata_routing
 from sklearn.utils._metadata_requests import (
     COMPOSITE_METHODS,
@@ -239,6 +240,28 @@ class InvalidObject:
         process_routing(InvalidObject(), "fit", groups=my_groups)


+def test_process_routing_empty_params_get_with_default():
+    empty_params = {}
+    routed_params = process_routing(ConsumingClassifier(), "fit", **empty_params)
+
+    # Behaviour should be an empty dictionary returned for each method when retrieved.
+    for method in METHODS:
+        params_for_method = routed_params[method]
+
+        # An empty dictionary for each method
+        assert isinstance(params_for_method, dict)
+        assert len(params_for_method) == 0
+
+        # This behaviour should be equivalent with using `get` with no default
+        assert routed_params.get(method) == params_for_method
+
+        # However, with a default, should return that instead.
+        assert routed_params.get(method, default="default") == "default"
+
+        # This would fail due to use of `if not default` instead of `if default is None`
+        # assert routed_params.get(method, default=[]) == []
Contributor Author

Please advise on this part. I raised it earlier in the PR but there's been no comment there.

It's hard to find exact usages without type annotations, but searching the repo for .get( turned up a few places that call params.get("x", {}). It's hard to tell whether params there is a dict or an EmptyRequest.

It might not be a problem yet, but I feel this could silently cause hard-to-debug issues in the future, especially in cases where you expected a {} but instead got a Bunch(**{method: {} for method in METHODS}).

if len(params.get("x", {})) == 0:
    # Can never get here, the `Bunch` was returned, not the suggested default of `{}`
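
To make the concern concrete, here is a minimal, self-contained sketch of the truthiness check from this diff. It does not import scikit-learn: the three-entry METHODS list, the plain dict used in place of Bunch, and the class name TruthinessEmptyRequest are stand-ins assumed purely for illustration.

```python
# Stand-ins for illustration only; the real METHODS and Bunch live in sklearn.
METHODS = ["fit", "predict", "score"]  # hypothetical subset


class TruthinessEmptyRequest:
    """Mimics the `if not default` check shown in this diff."""

    def get(self, name, default=None):
        if not default:
            # Every falsy default ({}, [], 0, "", None) takes this branch.
            return {method: {} for method in METHODS}
        return default


params = TruthinessEmptyRequest()
result = params.get("x", {})

# The caller asked for `{}` as the fallback but received a mapping with
# one entry per method, so a `len(...) == 0` check never succeeds.
print(len(result))
```

Under these assumptions, a truthy default such as "default" is returned unchanged, while any falsy default is silently replaced, so the caller cannot tell the two cases apart.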

My first recommendation is to use a sentinel value to indicate that nothing was passed in, similar to how more_itertools handles defaults. This allows things to work like so:

default_bunch = params.get("x")
none_value = params.get("x", default=None)
list_value = params.get("x", default=[])

My second recommendation, if you do not wish to introduce a sentinel pattern, is to use an explicit if default is None check instead of implicit falsiness. However, this might not work as expected:

default_bunch = params.get("x")
dict_value = params.get("x", default={})  # This gives back what was expected

if params.get("x", default=None) is None:
   # This can never happen
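
A rough sketch of what the sentinel approach could look like is shown below. It is not the code merged in this PR: the names _MISSING and SentinelEmptyRequest are hypothetical, METHODS is an assumed three-entry stand-in, and a plain dict replaces Bunch so the snippet runs without scikit-learn.

```python
METHODS = ["fit", "predict", "score"]  # hypothetical subset, for illustration

# A unique object that no caller can pass by accident. It lets `get`
# distinguish "no default given" from "default=None".
_MISSING = object()


class SentinelEmptyRequest:
    """Illustrative stand-in; the real class is defined inside process_routing."""

    def get(self, name, default=_MISSING):
        if default is _MISSING:
            # No default supplied: return an empty dict for every method.
            return {method: {} for method in METHODS}
        # Any explicit default, including None, [] or {}, is returned as-is.
        return default

    def __getitem__(self, name):
        return {method: {} for method in METHODS}


params = SentinelEmptyRequest()
default_mapping = params.get("x")
none_value = params.get("x", default=None)
list_value = params.get("x", default=[])
```

With this design, the identity check `default is _MISSING` replaces the truthiness check, so falsy defaults are no longer swallowed.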

Member

IMO, the most intuitive approach here is the sentinel value. Basically, not passing anything will always return a Bunch, while setting default will return that default unchanged.

Such semantics are unsurprising and expected. Right now, having None return a Bunch is indeed surprising.

I don't know what @adrinjalali thinks?

Contributor Author

I implemented the sentinel value approach in the meantime, happy to revert it if @adrinjalali thinks this should not be done.

Member

The behavior should be like a dictionary: when a default is passed and the key doesn't exist, we return the default. In this case, I wonder if we should ignore default completely (keeping it only to imitate dict) and always return the empty routing object. After all, the whole point of this class is to return an empty routed_params object.

Member

I'm not opposed to removing the default param. However, we would need to change the pattern:

params.get("fit", default={})

that is used in the pipeline for instance.

Member

I wouldn't remove it; get would just always return the empty result and ignore default. The default parameter needs to stay there to mimic dict().get.
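
For comparison, the "keep the parameter but ignore it" alternative could be sketched roughly as follows. This is an illustration under assumptions rather than the final implementation: AlwaysEmptyRequest is a hypothetical name, METHODS is an assumed three-entry stand-in, and a plain dict replaces Bunch.

```python
METHODS = ["fit", "predict", "score"]  # hypothetical subset, for illustration


class AlwaysEmptyRequest:
    """Illustrative stand-in that ignores `default` entirely."""

    def get(self, name, default=None):
        # `default` is accepted only so the signature resembles dict.get;
        # its value is never used.
        return {method: {} for method in METHODS}

    def __getitem__(self, name):
        return {method: {} for method in METHODS}


params = AlwaysEmptyRequest()
with_default = params.get("fit", default={"sample_weight": None})
without_default = params.get("fit")
```

In this variant both calls return the same per-method mapping, so call sites such as params.get("fit", default={}) would no longer receive their default and would need to be reviewed.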



 def test_simple_metadata_routing():
     # Tests that metadata is properly routed
5 changes: 4 additions & 1 deletion sklearn/utils/_metadata_requests.py
@@ -1530,7 +1530,10 @@ def process_routing(_obj, _method, /, **kwargs):
         # an empty dict on routed_params.ANYTHING.ANY_METHOD.
         class EmptyRequest:
             def get(self, name, default=None):
-                return default if default else {}
+                if not default:
+                    return Bunch(**{method: dict() for method in METHODS})
+
+                return default
Member

codecov is not happy here. I need to figure out when it is the case that default=None.

Member

@adrinjalali I assume that we should be able to cover this one because it would be equivalent to calling, e.g.:

routed_params = _process_routing(self, "score", **kwargs)
routed_params.get("score", default="default")

I don't know where the best place to test this is. This looks like a metadata routing test to me.


             def __getitem__(self, name):
                 return Bunch(**{method: dict() for method in METHODS})