Skip to content

Commit

Permalink
MAINT Added Parameter Validation for mutual_info_classif (#25769)
Browse files Browse the repository at this point in the history
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
  • Loading branch information
still-learning-ev and jeremiedbb committed Mar 7, 2023
1 parent 9bf04b5 commit 7e8d3c5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
16 changes: 14 additions & 2 deletions sklearn/feature_selection/_mutual_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# License: 3-clause BSD

import numpy as np
from numbers import Integral
from scipy.sparse import issparse
from scipy.special import digamma

Expand All @@ -11,6 +12,7 @@
from ..utils import check_random_state
from ..utils.validation import check_array, check_X_y
from ..utils.multiclass import check_classification_targets
from ..utils._param_validation import Interval, StrOptions, validate_params


def _compute_mi_cc(x, y, n_neighbors):
Expand Down Expand Up @@ -388,6 +390,16 @@ def mutual_info_regression(
return _estimate_mi(X, y, discrete_features, False, n_neighbors, copy, random_state)


@validate_params(
{
"X": ["array-like", "sparse matrix"],
"y": ["array-like"],
"discrete_features": [StrOptions({"auto"}), "boolean", "array-like"],
"n_neighbors": [Interval(Integral, 1, None, closed="left")],
"copy": ["boolean"],
"random_state": ["random_state"],
}
)
def mutual_info_classif(
X, y, *, discrete_features="auto", n_neighbors=3, copy=True, random_state=None
):
Expand All @@ -407,13 +419,13 @@ def mutual_info_classif(
Parameters
----------
- X : array-like or sparse matrix, shape (n_samples, n_features)
+ X : {array-like, sparse matrix} of shape (n_samples, n_features)
Feature matrix.
y : array-like of shape (n_samples,)
Target vector.
- discrete_features : {'auto', bool, array-like}, default='auto'
+ discrete_features : 'auto', bool or array-like, default='auto'
If bool, then determines whether to consider all features discrete
or continuous. If array, then it should be either a boolean mask
with shape (n_features,) or array with indices of discrete features.
Expand Down
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def _check_function_param_validation(
"sklearn.feature_selection.chi2",
"sklearn.feature_selection.f_classif",
"sklearn.feature_selection.f_regression",
"sklearn.feature_selection.mutual_info_classif",
"sklearn.feature_selection.r_regression",
"sklearn.metrics.accuracy_score",
"sklearn.metrics.auc",
Expand Down

0 comments on commit 7e8d3c5

Please sign in to comment.