Skip to content

Commit

Permalink
fixed for checking np.ndarrays and regular arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
trinhcon committed Mar 2, 2022
1 parent 5a6df33 commit 689e72d
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions sklearn/metrics/_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -1613,14 +1613,20 @@ def ndcg_score(y_true, y_score, *, k=None, sample_weight=None, ignore_ties=False
0.5
"""
if (y_true.min() < 0):
raise DeprecationWarning("ndcg_score should not use negative y_true values")

y_true = check_array(y_true, ensure_2d=False)
y_score = check_array(y_score, ensure_2d=False)
check_consistent_length(y_true, y_score, sample_weight)
_check_dcg_target_type(y_true)
gain = _ndcg_sample_scores(y_true, y_score, k=k, ignore_ties=ignore_ties)

if (isinstance(y_true, np.ndarray)):
if (y_true.min() < 0):
raise DeprecationWarning("ndcg_score should not use negative y_true values")
else:
for value in y_true:
if (value < 0):
raise DeprecationWarning("ndcg_score should not use negative y_true values")
return np.average(gain, weights=sample_weight)


Expand Down

0 comments on commit 689e72d

Please sign in to comment.