FIX Raise an error when min_samples_split=1 in trees #25744

Merged
merged 4 commits into from Mar 6, 2023
Changes from 2 commits
9 changes: 9 additions & 0 deletions doc/whats_new/v1.2.rst
@@ -79,6 +79,15 @@ Changelog
`encoded_missing_value` or `unknown_value` set to a categories' cardinality
when there are missing values in the training data. :pr:`25704` by `Thomas Fan`_.

:mod:`sklearn.tree`
...................

- |Fix| Fixed a regression in :class:`tree.DecisionTreeClassifier`,
:class:`tree.DecisionTreeRegressor`, :class:`tree.ExtraTreeClassifier` and
:class:`tree.ExtraTreeRegressor` where an error was no longer raised in version
1.2 when `min_samples_split=1`.
:pr:`25744` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

:mod:`sklearn.utils`
....................

6 changes: 3 additions & 3 deletions sklearn/tree/_classes.py
@@ -99,16 +99,16 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta):
         "max_depth": [Interval(Integral, 1, None, closed="left"), None],
         "min_samples_split": [
             Interval(Integral, 2, None, closed="left"),
-            Interval(Real, 0.0, 1.0, closed="right"),
+            Interval("real_not_int", 0.0, 1.0, closed="right"),
         ],
         "min_samples_leaf": [
             Interval(Integral, 1, None, closed="left"),
-            Interval(Real, 0.0, 1.0, closed="neither"),
+            Interval("real_not_int", 0.0, 1.0, closed="neither"),
         ],
         "min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")],
         "max_features": [
             Interval(Integral, 1, None, closed="left"),
-            Interval(Real, 0.0, 1.0, closed="right"),
+            Interval("real_not_int", 0.0, 1.0, closed="right"),
             StrOptions({"auto", "sqrt", "log2"}, deprecated={"auto"}),
             None,
         ],
22 changes: 22 additions & 0 deletions sklearn/tree/tests/test_tree.py
@@ -2425,3 +2425,25 @@ def test_tree_deserialization_from_read_only_buffer(tmpdir):
clf.tree_,
"The trees of the original and loaded classifiers are not equal.",
)


@pytest.mark.parametrize("Tree", ALL_TREES.values())
def test_min_sample_split_1_error(Tree):
    """Check that an error is raised when min_samples_split=1.

    Non-regression test for issue gh-25481.
    """
    X = np.array([[0, 0], [1, 1]])
    y = np.array([0, 1])

    # min_samples_split=1.0 is valid
    Tree(min_samples_split=1.0).fit(X, y)

    # min_samples_split=1 is invalid
    tree = Tree(min_samples_split=1)
    msg = (
        r"'min_samples_split' .* must be an int in the range \[2, inf\) "
        r"or a float in the range \(0.0, 1.0\]"
    )
    with pytest.raises(ValueError, match=msg):
        tree.fit(X, y)
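
To illustrate why the test treats `min_samples_split=1.0` as valid and `min_samples_split=1` as invalid, here is a minimal standalone sketch. It assumes the convention described in the scikit-learn documentation, where an int is an absolute sample count and a float is a fraction of `n_samples`. The helper name `resolve_min_samples_split` is hypothetical and is not part of scikit-learn.

```python
import math
from numbers import Integral, Real


def resolve_min_samples_split(value, n_samples):
    """Hypothetical helper: map min_samples_split to an absolute sample count.

    Assumption: ints are counts (must be >= 2), non-integer reals are
    fractions of n_samples in the range (0.0, 1.0].
    """
    if isinstance(value, Integral):
        # A node with a single sample cannot be split, so 1 is rejected.
        if value < 2:
            raise ValueError("an int min_samples_split must be in [2, inf)")
        return int(value)
    if isinstance(value, Real):
        if not 0.0 < value <= 1.0:
            raise ValueError("a float min_samples_split must be in (0.0, 1.0]")
        # The fraction is rounded up, and never resolves to fewer than 2 samples.
        return max(2, math.ceil(value * n_samples))
    raise TypeError("min_samples_split must be an int or a float")


print(resolve_min_samples_split(1.0, 10))  # fraction: all samples are required
print(resolve_min_samples_split(2, 10))    # count: used as given
```

Under these assumptions, `1.0` means "100% of the samples" while `1` would mean "one sample", which is why only the integer form is an error.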
32 changes: 22 additions & 10 deletions sklearn/utils/_param_validation.py
@@ -364,9 +364,12 @@ class Interval(_Constraint):

     Parameters
     ----------
-    type : {numbers.Integral, numbers.Real}
+    type : {numbers.Integral, numbers.Real, "real_not_int"}
         The set of numbers in which to set the interval.
Contributor

Would it be good to describe these in the docstring? That would help community developers who aren't familiar with what each of them is intended to mean.

Member Author

I added a description of the new option.

A short description of all the constraints and what they represent can be found here https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/utils/_param_validation.py#L28

Contributor

Looks good!

I mainly meant the real_not_int option, not the others.


+        If "real_not_int", only reals that don't have the integer type
+        are allowed. For example 1.0 is allowed but 1 is not.

     left : float or int or None
         The left bound of the interval. None means left bound is -∞.

@@ -392,14 +395,6 @@ class Interval(_Constraint):
`[0, +∞) U {+∞}`.
"""

-    @validate_params(
-        {
-            "type": [type],
thomasjpfan marked this conversation as resolved.
-            "left": [Integral, Real, None],
-            "right": [Integral, Real, None],
-            "closed": [StrOptions({"left", "right", "both", "neither"})],
-        }
-    )
def __init__(self, type, left, right, *, closed):
super().__init__()
self.type = type
@@ -410,6 +405,18 @@ def __init__(self, type, left, right, *, closed):
         self._check_params()

     def _check_params(self):
+        if self.type not in (Integral, Real, "real_not_int"):
+            raise ValueError(
+                "type must be either numbers.Integral, numbers.Real or 'real_not_int'."
+                f" Got {self.type} instead."
+            )
+
+        if self.closed not in ("left", "right", "both", "neither"):
+            raise ValueError(
+                "closed must be either 'left', 'right', 'both' or 'neither'. "
+                f"Got {self.closed} instead."
+            )

Member

With the removal of @validate_params, do self.left and self.right need to be validated here?

Member Author

@jeremiedbb, Mar 6, 2023

Indeed, I added the missing validation.

if self.type is Integral:
suffix = "for an interval over the integers."
if self.left is not None and not isinstance(self.left, Integral):
@@ -447,8 +454,13 @@ def __contains__(self, val):
             return False
         return True

+    def _has_valid_type(self, val):
+        if self.type == "real_not_int":
+            return isinstance(val, Real) and not isinstance(val, Integral)
+        return isinstance(val, self.type)
+
     def is_satisfied_by(self, val):
-        if not isinstance(val, self.type):
+        if not self._has_valid_type(val):
             return False

         return val in self
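
The type check above can be tried outside scikit-learn. The following is a minimal sketch that mirrors the logic of `_has_valid_type` as a free function; the name `has_valid_type` and its two-argument signature are assumptions made for illustration. It shows why a plain `Real` constraint accepted the integer `1`: in Python's `numbers` hierarchy every `Integral` is also a `Real`.

```python
from fractions import Fraction
from numbers import Integral, Real


def has_valid_type(val, type_):
    """Standalone sketch of the check in the diff (hypothetical free function)."""
    if type_ == "real_not_int":
        # Accept reals, but exclude anything registered as an integer type.
        return isinstance(val, Real) and not isinstance(val, Integral)
    return isinstance(val, type_)


# int is an Integral, and Integral is a subclass of Real in the numbers tower,
# so a Real constraint alone cannot tell 1 apart from 1.0.
print(has_valid_type(1, Real))
print(has_valid_type(1, "real_not_int"))
print(has_valid_type(1.0, "real_not_int"))
print(has_valid_type(Fraction(1, 2), "real_not_int"))
```

Note that `bool` is a subclass of `int`, so `True` is also rejected by the `"real_not_int"` branch in this sketch.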
7 changes: 7 additions & 0 deletions sklearn/utils/tests/test_param_validation.py
@@ -662,3 +662,10 @@ def fit(self, X=None, y=None):
# does not raise, even though "b" is not in the constraints dict and "a" is not
# a parameter of the estimator.
ThirdPartyEstimator(b=0).fit()


def test_interval_real_not_int():
    """Check for the type "real_not_int" in the Interval constraint."""
    constraint = Interval("real_not_int", 0, 1, closed="both")
    assert constraint.is_satisfied_by(1.0)
    assert not constraint.is_satisfied_by(1)
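
The test above builds an interval that is closed on both sides. As a rough sketch of how the four `closed` options could decide membership, the function below checks a value against optional bounds. The name `in_interval` is hypothetical, and treating `None` as "unbounded" follows the `Interval` docstring rather than the actual `__contains__` implementation, which is only partially visible in this diff.

```python
import operator


def in_interval(val, left, right, closed):
    """Hypothetical bounds check; None means the interval is unbounded on that side."""
    if closed not in ("left", "right", "both", "neither"):
        raise ValueError(
            "closed must be either 'left', 'right', 'both' or 'neither'. "
            f"Got {closed} instead."
        )
    # A closed bound is part of the interval, so only strictly smaller (or
    # strictly larger) values fall outside; an open bound also excludes equality.
    below_left = operator.lt if closed in ("left", "both") else operator.le
    above_right = operator.gt if closed in ("right", "both") else operator.ge

    if left is not None and below_left(val, left):
        return False
    if right is not None and above_right(val, right):
        return False
    return True


print(in_interval(1.0, 0.0, 1.0, "right"))    # upper bound included
print(in_interval(1.0, 0.0, 1.0, "neither"))  # upper bound excluded
```

This check only looks at bounds. The distinction between `1.0` and `1` that the test asserts comes from the separate type check, not from the interval limits.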