Skip to content

Commit

Permalink
FIX Raise an error when min_samples_split=1 in trees (#25744)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiedbb committed Mar 8, 2023
1 parent 12f1675 commit 0cae7df
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 13 deletions.
9 changes: 9 additions & 0 deletions doc/whats_new/v1.2.rst
Expand Up @@ -79,6 +79,15 @@ Changelog
`encoded_missing_value` or `unknown_value` set to a categories' cardinality
when there is missing values in the training data. :pr:`25704` by `Thomas Fan`_.

:mod:`sklearn.tree`
...................

- |Fix| Fixed a regression in :class:`tree.DecisionTreeClassifier`,
:class:`tree.DecisionTreeRegressor`, :class:`tree.ExtraTreeClassifier` and
:class:`tree.ExtraTreeRegressor` where an error was no longer raised in version
1.2 when `min_sample_split=1`.
:pr:`25744` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

:mod:`sklearn.utils`
....................

Expand Down
6 changes: 3 additions & 3 deletions sklearn/tree/_classes.py
Expand Up @@ -99,16 +99,16 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta):
"max_depth": [Interval(Integral, 1, None, closed="left"), None],
"min_samples_split": [
Interval(Integral, 2, None, closed="left"),
Interval(Real, 0.0, 1.0, closed="right"),
Interval("real_not_int", 0.0, 1.0, closed="right"),
],
"min_samples_leaf": [
Interval(Integral, 1, None, closed="left"),
Interval(Real, 0.0, 1.0, closed="neither"),
Interval("real_not_int", 0.0, 1.0, closed="neither"),
],
"min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")],
"max_features": [
Interval(Integral, 1, None, closed="left"),
Interval(Real, 0.0, 1.0, closed="right"),
Interval("real_not_int", 0.0, 1.0, closed="right"),
StrOptions({"auto", "sqrt", "log2"}, deprecated={"auto"}),
None,
],
Expand Down
22 changes: 22 additions & 0 deletions sklearn/tree/tests/test_tree.py
Expand Up @@ -2406,3 +2406,25 @@ def test_tree_deserialization_from_read_only_buffer(tmpdir):
clf.tree_,
"The trees of the original and loaded classifiers are not equal.",
)


@pytest.mark.parametrize("Tree", ALL_TREES.values())
def test_min_sample_split_1_error(Tree):
"""Check that an error is raised when min_sample_split=1.
non-regression test for issue gh-25481.
"""
X = np.array([[0, 0], [1, 1]])
y = np.array([0, 1])

# min_samples_split=1.0 is valid
Tree(min_samples_split=1.0).fit(X, y)

# min_samples_split=1 is invalid
tree = Tree(min_samples_split=1)
msg = (
r"'min_samples_split' .* must be an int in the range \[2, inf\) "
r"or a float in the range \(0.0, 1.0\]"
)
with pytest.raises(ValueError, match=msg):
tree.fit(X, y)
37 changes: 27 additions & 10 deletions sklearn/utils/_param_validation.py
Expand Up @@ -364,9 +364,12 @@ class Interval(_Constraint):
Parameters
----------
type : {numbers.Integral, numbers.Real}
type : {numbers.Integral, numbers.Real, "real_not_int"}
The set of numbers in which to set the interval.
If "real_not_int", only reals that don't have the integer type
are allowed. For example 1.0 is allowed but 1 is not.
left : float or int or None
The left bound of the interval. None means left bound is -∞.
Expand All @@ -392,14 +395,6 @@ class Interval(_Constraint):
`[0, +∞) U {+∞}`.
"""

@validate_params(
{
"type": [type],
"left": [Integral, Real, None],
"right": [Integral, Real, None],
"closed": [StrOptions({"left", "right", "both", "neither"})],
}
)
def __init__(self, type, left, right, *, closed):
super().__init__()
self.type = type
Expand All @@ -410,6 +405,18 @@ def __init__(self, type, left, right, *, closed):
self._check_params()

def _check_params(self):
if self.type not in (Integral, Real, "real_not_int"):
raise ValueError(
"type must be either numbers.Integral, numbers.Real or 'real_not_int'."
f" Got {self.type} instead."
)

if self.closed not in ("left", "right", "both", "neither"):
raise ValueError(
"closed must be either 'left', 'right', 'both' or 'neither'. "
f"Got {self.closed} instead."
)

if self.type is Integral:
suffix = "for an interval over the integers."
if self.left is not None and not isinstance(self.left, Integral):
Expand All @@ -424,6 +431,11 @@ def _check_params(self):
raise ValueError(
f"right can't be None when closed == {self.closed} {suffix}"
)
else:
if self.left is not None and not isinstance(self.left, Real):
raise TypeError("Expecting left to be a real number.")
if self.right is not None and not isinstance(self.right, Real):
raise TypeError("Expecting right to be a real number.")

if self.right is not None and self.left is not None and self.right <= self.left:
raise ValueError(
Expand All @@ -447,8 +459,13 @@ def __contains__(self, val):
return False
return True

def _has_valid_type(self, val):
if self.type == "real_not_int":
return isinstance(val, Real) and not isinstance(val, Integral)
return isinstance(val, self.type)

def is_satisfied_by(self, val):
if not isinstance(val, self.type):
if not self._has_valid_type(val):
return False

return val in self
Expand Down
7 changes: 7 additions & 0 deletions sklearn/utils/tests/test_param_validation.py
Expand Up @@ -662,3 +662,10 @@ def fit(self, X=None, y=None):
# does not raise, even though "b" is not in the constraints dict and "a" is not
# a parameter of the estimator.
ThirdPartyEstimator(b=0).fit()


def test_interval_real_not_int():
"""Check for the type "real_not_int" in the Interval constraint."""
constraint = Interval("real_not_int", 0, 1, closed="both")
assert constraint.is_satisfied_by(1.0)
assert not constraint.is_satisfied_by(1)

0 comments on commit 0cae7df

Please sign in to comment.