diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index c252a7c18f5c9..fbc99f3cbc80d 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -79,6 +79,15 @@ Changelog `encoded_missing_value` or `unknown_value` set to a categories' cardinality when there is missing values in the training data. :pr:`25704` by `Thomas Fan`_. +:mod:`sklearn.tree` +................... + +- |Fix| Fixed a regression in :class:`tree.DecisionTreeClassifier`, + :class:`tree.DecisionTreeRegressor`, :class:`tree.ExtraTreeClassifier` and + :class:`tree.ExtraTreeRegressor` where an error was no longer raised in version + 1.2 when `min_samples_split=1`. + :pr:`25744` by :user:`Jérémie du Boisberranger <jeremiedbb>`. + :mod:`sklearn.utils` .................... diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index e2e41f9aea78b..6e01b8b49e594 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -99,16 +99,16 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): "max_depth": [Interval(Integral, 1, None, closed="left"), None], "min_samples_split": [ Interval(Integral, 2, None, closed="left"), - Interval(Real, 0.0, 1.0, closed="right"), + Interval("real_not_int", 0.0, 1.0, closed="right"), ], "min_samples_leaf": [ Interval(Integral, 1, None, closed="left"), - Interval(Real, 0.0, 1.0, closed="neither"), + Interval("real_not_int", 0.0, 1.0, closed="neither"), ], "min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")], "max_features": [ Interval(Integral, 1, None, closed="left"), - Interval(Real, 0.0, 1.0, closed="right"), + Interval("real_not_int", 0.0, 1.0, closed="right"), StrOptions({"auto", "sqrt", "log2"}, deprecated={"auto"}), None, ], diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 9b1a29f02ead7..c796177ad814c 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2425,3 +2425,25 @@ def test_tree_deserialization_from_read_only_buffer(tmpdir): 
clf.tree_, "The trees of the original and loaded classifiers are not equal.", ) + + +@pytest.mark.parametrize("Tree", ALL_TREES.values()) +def test_min_sample_split_1_error(Tree): + """Check that an error is raised when min_sample_split=1. + + non-regression test for issue gh-25481. + """ + X = np.array([[0, 0], [1, 1]]) + y = np.array([0, 1]) + + # min_samples_split=1.0 is valid + Tree(min_samples_split=1.0).fit(X, y) + + # min_samples_split=1 is invalid + tree = Tree(min_samples_split=1) + msg = ( + r"'min_samples_split' .* must be an int in the range \[2, inf\) " + r"or a float in the range \(0.0, 1.0\]" + ) + with pytest.raises(ValueError, match=msg): + tree.fit(X, y) diff --git a/sklearn/utils/_param_validation.py b/sklearn/utils/_param_validation.py index aa8906071c6af..8d23f0b23b6eb 100644 --- a/sklearn/utils/_param_validation.py +++ b/sklearn/utils/_param_validation.py @@ -364,9 +364,12 @@ class Interval(_Constraint): Parameters ---------- - type : {numbers.Integral, numbers.Real} + type : {numbers.Integral, numbers.Real, "real_not_int"} The set of numbers in which to set the interval. + If "real_not_int", only reals that don't have the integer type + are allowed. For example 1.0 is allowed but 1 is not. + left : float or int or None The left bound of the interval. None means left bound is -∞. @@ -392,14 +395,6 @@ class Interval(_Constraint): `[0, +∞) U {+∞}`. """ - @validate_params( - { - "type": [type], - "left": [Integral, Real, None], - "right": [Integral, Real, None], - "closed": [StrOptions({"left", "right", "both", "neither"})], - } - ) def __init__(self, type, left, right, *, closed): super().__init__() self.type = type @@ -410,6 +405,18 @@ def __init__(self, type, left, right, *, closed): self._check_params() def _check_params(self): + if self.type not in (Integral, Real, "real_not_int"): + raise ValueError( + "type must be either numbers.Integral, numbers.Real or 'real_not_int'." + f" Got {self.type} instead." 
+ ) + + if self.closed not in ("left", "right", "both", "neither"): + raise ValueError( + "closed must be either 'left', 'right', 'both' or 'neither'. " + f"Got {self.closed} instead." + ) + if self.type is Integral: suffix = "for an interval over the integers." if self.left is not None and not isinstance(self.left, Integral): @@ -424,6 +431,11 @@ def _check_params(self): raise ValueError( f"right can't be None when closed == {self.closed} {suffix}" ) + else: + if self.left is not None and not isinstance(self.left, Real): + raise TypeError("Expecting left to be a real number.") + if self.right is not None and not isinstance(self.right, Real): + raise TypeError("Expecting right to be a real number.") if self.right is not None and self.left is not None and self.right <= self.left: raise ValueError( @@ -447,8 +459,13 @@ def __contains__(self, val): return False return True + def _has_valid_type(self, val): + if self.type == "real_not_int": + return isinstance(val, Real) and not isinstance(val, Integral) + return isinstance(val, self.type) + def is_satisfied_by(self, val): - if not isinstance(val, self.type): + if not self._has_valid_type(val): return False return val in self diff --git a/sklearn/utils/tests/test_param_validation.py b/sklearn/utils/tests/test_param_validation.py index 85cd06d0f38b8..ce8f9cdf939fd 100644 --- a/sklearn/utils/tests/test_param_validation.py +++ b/sklearn/utils/tests/test_param_validation.py @@ -662,3 +662,10 @@ def fit(self, X=None, y=None): # does not raise, even though "b" is not in the constraints dict and "a" is not # a parameter of the estimator. ThirdPartyEstimator(b=0).fit() + + +def test_interval_real_not_int(): + """Check for the type "real_not_int" in the Interval constraint.""" + constraint = Interval("real_not_int", 0, 1, closed="both") + assert constraint.is_satisfied_by(1.0) + assert not constraint.is_satisfied_by(1)