Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNT Checking function _estimator_has also raises AttributeError #28167

Merged
merged 10 commits into from
Feb 13, 2024
20 changes: 13 additions & 7 deletions sklearn/ensemble/_stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,20 @@
def _estimator_has(attr):
"""Check if we can delegate a method to the underlying estimator.

First, we check the first fitted final estimator if available, otherwise we
check the unfitted final estimator.
First, we check the fitted final_estimator if available, otherwise we check the
StefanieSenger marked this conversation as resolved.
Show resolved Hide resolved
unfitted final_estimator. We raise the original `AttributeError` if `attr` does
StefanieSenger marked this conversation as resolved.
Show resolved Hide resolved
not exist. This function is used together with `avaliable_if`.
StefanieSenger marked this conversation as resolved.
Show resolved Hide resolved
"""
return lambda self: (
hasattr(self.final_estimator_, attr)
if hasattr(self, "final_estimator_")
else hasattr(self.final_estimator, attr)
)

def check(self):
if hasattr(self, "final_estimator_"):
getattr(self.final_estimator_, attr)
return True
else:
getattr(self.final_estimator, attr)
return True
StefanieSenger marked this conversation as resolved.
Show resolved Hide resolved

return check
StefanieSenger marked this conversation as resolved.
Show resolved Hide resolved


class _BaseStacking(TransformerMixin, _BaseHeterogeneousEnsemble, metaclass=ABCMeta):
Expand Down
20 changes: 13 additions & 7 deletions sklearn/feature_selection/_from_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,20 @@ def _calculate_threshold(estimator, importances, threshold):
def _estimator_has(attr):
"""Check if we can delegate a method to the underlying estimator.

First, we check the fitted estimator if available, otherwise we
check the unfitted estimator.
First, we check the fitted estimator if available, otherwise we check the
StefanieSenger marked this conversation as resolved.
Show resolved Hide resolved
unfitted estimator. We raise the original `AttributeError` if `attr` does
StefanieSenger marked this conversation as resolved.
Show resolved Hide resolved
not exist. This function is used together with `avaliable_if`.
StefanieSenger marked this conversation as resolved.
Show resolved Hide resolved
"""
return lambda self: (
hasattr(self.estimator_, attr)
if hasattr(self, "estimator_")
else hasattr(self.estimator, attr)
)

def check(self):
if hasattr(self, "estimator_"):
getattr(self.estimator_, attr)
return True
else:
getattr(self.estimator, attr)
return True

return check
StefanieSenger marked this conversation as resolved.
Show resolved Hide resolved


class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
Expand Down
20 changes: 13 additions & 7 deletions sklearn/feature_selection/_rfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,20 @@ def _rfe_single_fit(rfe, estimator, X, y, train, test, scorer):
def _estimator_has(attr):
"""Check if we can delegate a method to the underlying estimator.

First, we check the first fitted estimator if available, otherwise we
check the unfitted estimator.
First, we check the fitted estimator if available, otherwise we check the
unfitted estimator. We raise the original `AttributeError` if `attr` does
not exist. This function is used together with `avaliable_if`.
StefanieSenger marked this conversation as resolved.
Show resolved Hide resolved
"""
return lambda self: (
hasattr(self.estimator_, attr)
if hasattr(self, "estimator_")
else hasattr(self.estimator, attr)
)

def check(self):
if hasattr(self, "estimator_"):
getattr(self.estimator_, attr)
return True
else:
getattr(self.estimator, attr)
return True

return check
StefanieSenger marked this conversation as resolved.
Show resolved Hide resolved


class RFE(_RoutingNotSupportedMixin, SelectorMixin, MetaEstimatorMixin, BaseEstimator):
Expand Down
23 changes: 12 additions & 11 deletions sklearn/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,19 @@ def _estimators_has(attr):
"""Check if self.estimator or self.estimators_[0] has attr.

If `self.estimators_[0]` has the attr, then its safe to assume that other
values has it too. This function is used together with `avaliable_if`.
estimators have it too. We raise the original `AttributeError` if `attr`
does not exist. This function is used together with `avaliable_if`.
StefanieSenger marked this conversation as resolved.
Show resolved Hide resolved
"""
return lambda self: (
hasattr(self.estimator, attr)
or (hasattr(self, "estimators_") and hasattr(self.estimators_[0], attr))
)

def check(self):
if hasattr(self, "estimators_"):
getattr(self.estimators_[0], attr)
return True
else:
getattr(self.estimator, attr)
return True

StefanieSenger marked this conversation as resolved.
Show resolved Hide resolved
return check


class OneVsRestClassifier(
Expand Down Expand Up @@ -434,12 +441,6 @@ def partial_fit(self, X, y, classes=None, **partial_fit_params):
)

if _check_partial_fit_first_call(self, classes):
if not hasattr(self.estimator, "partial_fit"):
raise ValueError(
("Base estimator {0}, doesn't have partial_fit method").format(
self.estimator
)
)
self.estimators_ = [clone(self.estimator) for _ in range(self.n_classes_)]

# A sparse LabelBinarizer, with sparse_output=True, has been
Expand Down
22 changes: 16 additions & 6 deletions sklearn/semi_supervised/_self_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,22 @@


def _estimator_has(attr):
"""Check if `self.base_estimator_ `or `self.base_estimator_` has `attr`."""
return lambda self: (
hasattr(self.base_estimator_, attr)
if hasattr(self, "base_estimator_")
else hasattr(self.base_estimator, attr)
)
"""Check if we can delegate a method to the underlying estimator.

First, we check the fitted base_estimator if available, otherwise we check
the unfitted base_estimator. We raise the original `AttributeError` if
`attr` does not exist. This function is used together with `avaliable_if`.
StefanieSenger marked this conversation as resolved.
Show resolved Hide resolved
"""

def check(self):
if hasattr(self, "base_estimator_"):
getattr(self.base_estimator_, attr)
return True
else:
getattr(self.base_estimator, attr)
return True

return check
StefanieSenger marked this conversation as resolved.
Show resolved Hide resolved


class SelfTrainingClassifier(
Expand Down