diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 2e54d000a13aa..2fd1366e18434 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -946,6 +946,7 @@ details.
    metrics.cohen_kappa_score
    metrics.confusion_matrix
    metrics.dcg_score
+   metrics.detection_error_tradeoff_curve
    metrics.f1_score
    metrics.fbeta_score
    metrics.hamming_loss
diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index 5bdb5091ef9c7..60e99a5102cb7 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -306,6 +306,7 @@ Some of these are restricted to the binary classification case:

    precision_recall_curve
    roc_curve
+   detection_error_tradeoff_curve

 Others also work in the multiclass case:

@@ -1437,6 +1438,93 @@ to the given limit.
        In Data Mining, 2001.
        Proceedings IEEE International Conference, pp. 131-138.

+.. _det_curve:
+
+Detection error tradeoff (DET)
+------------------------------
+
+The function :func:`detection_error_tradeoff_curve` computes the
+detection error tradeoff (DET) curve [WikipediaDET2017]_.
+Quoting Wikipedia:
+
+  "A detection error tradeoff (DET) graph is a graphical plot of error rates for
+  binary classification systems, plotting false reject rate vs. false accept
+  rate. The x- and y-axes are scaled non-linearly by their standard normal
+  deviates (or just by logarithmic transformation), yielding tradeoff curves
+  that are more linear than ROC curves, and use most of the image area to
+  highlight the differences of importance in the critical operating region."
+
+DET curves are a variation of receiver operating characteristic (ROC) curves
+where False Negative Rate is plotted on the ordinate instead of True Positive
+Rate.
+DET curves are commonly plotted in normal deviate scale by transformation with
+:math:`\phi^{-1}` (with :math:`\phi` being the cumulative distribution
+function).
+The resulting performance curves explicitly visualize the tradeoff of error
+types for given classification algorithms.
+See [Martin1997]_ for examples and further motivation.
+
+This figure compares the ROC and DET curves of two example classifiers on the
+same classification task:
+
+.. image:: ../auto_examples/model_selection/images/sphx_glr_plot_det_001.png
+   :target: ../auto_examples/model_selection/plot_det.html
+   :scale: 75
+   :align: center
+
+**Properties:**
+
+* DET curves form a linear curve in normal deviate scale if the detection
+  scores are normally (or close-to normally) distributed.
+  It was shown by [Navratil2007]_ that the reverse is not necessarily true and
+  that even more general distributions are able to produce linear DET curves.
+
+* The normal deviate scale transformation spreads out the points such that a
+  comparatively larger space of the plot is occupied.
+  Therefore curves with similar classification performance might be easier to
+  distinguish on a DET plot.
+
+* With False Negative Rate being "inverse" to True Positive Rate, the point
+  of perfection for DET curves is the origin (in contrast to the top left
+  corner for ROC curves).
+
+**Applications and limitations:**
+
+DET curves are intuitive to read and hence allow quick visual assessment of a
+classifier's performance.
+Additionally, DET curves can be consulted for threshold analysis and operating
+point selection.
+This is particularly helpful if a comparison of error types is required.
+
+On the other hand, DET curves do not provide their metric as a single number.
+Therefore, for either automated evaluation or comparison to other
+classification tasks, metrics like the derived area under the ROC curve might
+be better suited.
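+
+For instance, the error rates returned by
+:func:`detection_error_tradeoff_curve` can be mapped to normal deviate scale
+with ``scipy.stats.norm.ppf``; a minimal sketch (the scores are purely
+illustrative and taken from the function's docstring example)::
+
+  import numpy as np
+  from scipy.stats import norm
+  from sklearn.metrics import detection_error_tradeoff_curve
+
+  y_true = np.array([0, 0, 1, 1])
+  y_score = np.array([0.1, 0.4, 0.35, 0.8])   # illustrative scores only
+  fpr, fnr, thresholds = detection_error_tradeoff_curve(y_true, y_score)
+
+  # Transform the error rates to normal deviate scale. Error rates of
+  # exactly 0 or 1 map to -inf/+inf and are usually clipped before plotting.
+  fpr_scaled, fnr_scaled = norm.ppf(fpr), norm.ppf(fnr)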
+
+.. topic:: Examples:
+
+  * See :ref:`sphx_glr_auto_examples_model_selection_plot_det.py`
+    for an example comparison between receiver operating characteristic (ROC)
+    curves and detection error tradeoff (DET) curves.
+
+.. topic:: References:
+
+  .. [WikipediaDET2017] Wikipedia contributors. Detection error tradeoff.
+     Wikipedia, The Free Encyclopedia. September 4, 2017, 23:33 UTC.
+     Available at: https://en.wikipedia.org/w/index.php?title=Detection_error_tradeoff&oldid=798982054.
+     Accessed February 19, 2018.
+
+  .. [Martin1997] A. Martin, G. Doddington, T. Kamm, M. Ordowski, and M. Przybocki,
+     `The DET Curve in Assessment of Detection Task Performance `_,
+     NIST 1997.
+
+  .. [Navratil2007] J. Navratil and D. Klusacek,
+     "`On Linear DETs `_",
+     2007 IEEE International Conference on Acoustics,
+     Speech and Signal Processing - ICASSP '07, Honolulu,
+     HI, 2007, pp. IV-229-IV-232.

 .. _zero_one_loss:

diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst
index 764274c6dccdb..4005a31619771 100644
--- a/doc/whats_new/v0.24.rst
+++ b/doc/whats_new/v0.24.rst
@@ -270,6 +270,11 @@ Changelog
 :mod:`sklearn.metrics`
 ......................

+- |Feature| Added :func:`metrics.detection_error_tradeoff_curve` to compute
+  the Detection Error Tradeoff curve classification metric.
+  :pr:`10591` by :user:`Jeremy Karnowski ` and
+  :user:`Daniel Mohns `.
+
 - |Feature| Added :func:`metrics.mean_absolute_percentage_error` metric and
   the associated scorer for regression problems. :issue:`10708` fixed with the
   PR :pr:`15007` by :user:`Ashutosh Hathidara `. The scorer and
diff --git a/examples/model_selection/plot_det.py b/examples/model_selection/plot_det.py
new file mode 100644
index 0000000000000..6cfac7e5ce0ca
--- /dev/null
+++ b/examples/model_selection/plot_det.py
@@ -0,0 +1,145 @@
+"""
+=======================================
+Detection error tradeoff (DET) curve
+=======================================
+
+In this example, we compare receiver operating characteristic (ROC) and
+detection error tradeoff (DET) curves for different classification algorithms
+for the same classification task.
+
+DET curves are commonly plotted in normal deviate scale.
+To achieve this we transform the error rates as returned by the
+``detection_error_tradeoff_curve`` function and the axis scale using
+``scipy.stats.norm``.
+
+The point of this example is to demonstrate two properties of DET curves,
+namely:
+
+1. It might be easier to visually assess the overall performance of different
+   classification algorithms using DET curves over ROC curves.
+   Due to the linear scale used for plotting ROC curves, different classifiers
+   usually only differ in the top left corner of the graph and appear similar
+   for a large part of the plot. On the other hand, DET curves represent
+   straight lines in normal deviate scale. As such, they tend to be
+   distinguishable as a whole and the area of interest spans a large part of
+   the plot.
+2. DET curves give the user direct feedback of the detection error tradeoff to
+   aid in operating point analysis.
+   The user can deduce directly from the DET-curve plot at which rate the
+   false-negative error rate will improve when willing to accept an increase
+   in the false-positive error rate (or vice-versa).
+
+The plots in this example compare ROC curves on the left side to corresponding
+DET curves on the right.
+There is no particular reason why these classifiers have been chosen for the
+example plot over other classifiers available in scikit-learn.
+
+.. note::
+
+    - See :func:`sklearn.metrics.roc_curve` for further information about ROC
+      curves.
+
+    - See :func:`sklearn.metrics.detection_error_tradeoff_curve` for further
+      information about DET curves.
+
+    - This example is loosely based on
+      :ref:`sphx_glr_auto_examples_classification_plot_classifier_comparison.py`.
+
+"""
+import matplotlib.pyplot as plt
+
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler
+from sklearn.datasets import make_classification
+from sklearn.svm import SVC
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.metrics import detection_error_tradeoff_curve
+from sklearn.metrics import roc_curve
+
+from scipy.stats import norm
+from matplotlib.ticker import FuncFormatter
+
+N_SAMPLES = 1000
+
+names = [
+    "Linear SVM",
+    "Random Forest",
+]
+
+classifiers = [
+    SVC(kernel="linear", C=0.025),
+    RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),
+]
+
+X, y = make_classification(
+    n_samples=N_SAMPLES, n_features=2, n_redundant=0, n_informative=2,
+    random_state=1, n_clusters_per_class=1)
+
+# preprocess dataset, split into training and test part
+X = StandardScaler().fit_transform(X)
+
+X_train, X_test, y_train, y_test = train_test_split(
+    X, y, test_size=.4, random_state=0)
+
+# prepare plots
+fig, [ax_roc, ax_det] = plt.subplots(1, 2, figsize=(10, 5))
+
+# first prepare the ROC curve
+ax_roc.set_title('Receiver Operating Characteristic (ROC) curves')
+ax_roc.set_xlabel('False Positive Rate')
+ax_roc.set_ylabel('True Positive Rate')
+ax_roc.set_xlim(0, 1)
+ax_roc.set_ylim(0, 1)
+ax_roc.grid(linestyle='--')
+ax_roc.yaxis.set_major_formatter(
+    FuncFormatter(lambda y, _: '{:.0%}'.format(y)))
+ax_roc.xaxis.set_major_formatter(
+    FuncFormatter(lambda y, _: '{:.0%}'.format(y)))
+
+# second prepare the DET curve
+ax_det.set_title('Detection Error Tradeoff (DET) curves')
+ax_det.set_xlabel('False Positive Rate')
+ax_det.set_ylabel('False Negative Rate')
+ax_det.set_xlim(-3, 3)
+ax_det.set_ylim(-3, 3)
+ax_det.grid(linestyle='--')
+
+# customized ticks for DET curve plot to represent normal deviate scale
+ticks = [0.001, 0.01, 0.05, 0.20, 0.5, 0.80, 0.95, 0.99, 0.999]
+tick_locs = norm.ppf(ticks)
+tick_lbls = [
+    '{:.0%}'.format(s) if (100*s).is_integer() else '{:.1%}'.format(s)
+    for s in ticks
+]
+plt.sca(ax_det)
+plt.xticks(tick_locs, tick_lbls)
+plt.yticks(tick_locs, tick_lbls)
+
+# iterate over classifiers
+for name, clf in zip(names, classifiers):
+    clf.fit(X_train, y_train)
+
+    if hasattr(clf, "decision_function"):
+        y_score = clf.decision_function(X_test)
+    else:
+        y_score = clf.predict_proba(X_test)[:, 1]
+
+    roc_fpr, roc_tpr, _ = roc_curve(y_test, y_score)
+    det_fpr, det_fnr, _ = detection_error_tradeoff_curve(y_test, y_score)
+
+    ax_roc.plot(roc_fpr, roc_tpr)
+
+    # transform errors into normal deviate scale
+    ax_det.plot(
+        norm.ppf(det_fpr),
+        norm.ppf(det_fnr)
+    )
+
+# add a single legend
+plt.sca(ax_det)
+plt.legend(names, loc="upper right")
+
+# plot
+plt.tight_layout()
+plt.show()
diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py
index be28005631963..a69d5c618c20f 100644
--- a/sklearn/metrics/__init__.py
+++ b/sklearn/metrics/__init__.py
@@ -7,6 +7,7 @@
 from ._ranking import auc
 from ._ranking import average_precision_score
 from ._ranking import coverage_error
+from ._ranking import detection_error_tradeoff_curve
 from ._ranking import dcg_score
 from ._ranking import label_ranking_average_precision_score
 from ._ranking import label_ranking_loss
@@ -104,6 +105,7 @@
     'coverage_error',
     'dcg_score',
     'davies_bouldin_score',
+    'detection_error_tradeoff_curve',
     'euclidean_distances',
     'explained_variance_score',
     'f1_score',
diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index 547dd0a9a0e22..6727de0c05c65 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -218,6 +218,94 @@ def _binary_uninterpolated_average_precision(
                             average, sample_weight=sample_weight)


+def detection_error_tradeoff_curve(y_true, y_score, pos_label=None,
+                                   sample_weight=None):
+    """Compute error rates for different probability thresholds.
+
+    Note: This metric is used for ranking evaluation of a binary
+    classification task.
+
+    Read more in the :ref:`User Guide <det_curve>`.
+
+    Parameters
+    ----------
+    y_true : array, shape = [n_samples]
+        True targets of binary classification in range {-1, 1} or {0, 1}.
+
+    y_score : array, shape = [n_samples]
+        Estimated probabilities or decision function.
+
+    pos_label : int, optional (default=None)
+        The label of the positive class.
+
+    sample_weight : array-like of shape = [n_samples], optional
+        Sample weights.
+
+    Returns
+    -------
+    fpr : array, shape = [n_thresholds]
+        False positive rate (FPR) such that element i is the false positive
+        rate of predictions with score >= thresholds[i]. This is occasionally
+        referred to as false acceptance probability or fall-out.
+
+    fnr : array, shape = [n_thresholds]
+        False negative rate (FNR) such that element i is the false negative
+        rate of predictions with score >= thresholds[i]. This is occasionally
+        referred to as false rejection or miss rate.
+
+    thresholds : array, shape = [n_thresholds]
+        Decreasing score values.
+
+    See also
+    --------
+    roc_curve : Compute Receiver operating characteristic (ROC) curve.
+    precision_recall_curve : Compute precision-recall curve.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.metrics import detection_error_tradeoff_curve
+    >>> y_true = np.array([0, 0, 1, 1])
+    >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
+    >>> fpr, fnr, thresholds = detection_error_tradeoff_curve(y_true, y_scores)
+    >>> fpr
+    array([0.5, 0.5, 0. ])
+    >>> fnr
+    array([0. , 0.5, 0.5])
+    >>> thresholds
+    array([0.35, 0.4 , 0.8 ])
+
+    """
+    if len(np.unique(y_true)) != 2:
+        raise ValueError("Only one class present in y_true. Detection error "
Detection error " + "tradeoff curve is not defined in that case.") + + fps, tps, thresholds = _binary_clf_curve(y_true, y_score, + pos_label=pos_label, + sample_weight=sample_weight) + + fns = tps[-1] - tps + p_count = tps[-1] + n_count = fps[-1] + + # start with false positives zero + first_ind = ( + fps.searchsorted(fps[0], side='right') - 1 + if fps.searchsorted(fps[0], side='right') > 0 + else None + ) + # stop with false negatives zero + last_ind = tps.searchsorted(tps[-1]) + 1 + sl = slice(first_ind, last_ind) + + # reverse the output such that list of false positives is decreasing + return ( + fps[sl][::-1] / n_count, + fns[sl][::-1] / p_count, + thresholds[sl][::-1] + ) + + def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None): """Binary roc auc score.""" if len(np.unique(y_true)) != 2: diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 3f2ba83b474c7..24f01d46610a7 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -29,6 +29,7 @@ from sklearn.metrics import cohen_kappa_score from sklearn.metrics import confusion_matrix from sklearn.metrics import coverage_error +from sklearn.metrics import detection_error_tradeoff_curve from sklearn.metrics import explained_variance_score from sklearn.metrics import f1_score from sklearn.metrics import fbeta_score @@ -205,6 +206,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): CURVE_METRICS = { "roc_curve": roc_curve, "precision_recall_curve": precision_recall_curve_padded_thresholds, + "detection_error_tradeoff_curve": detection_error_tradeoff_curve, } THRESHOLDED_METRICS = { @@ -301,6 +303,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): # curves "roc_curve", "precision_recall_curve", + "detection_error_tradeoff_curve", } # Metric undefined with "binary" or "multiclass" input @@ -322,6 +325,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): METRICS_WITH_POS_LABEL = { "roc_curve", "precision_recall_curve", + "detection_error_tradeoff_curve", "brier_score_loss", @@ -352,6 +356,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "normalized_confusion_matrix", "roc_curve", "precision_recall_curve", + "detection_error_tradeoff_curve", "precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score", "jaccard_score", @@ -464,6 +469,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "normalized_confusion_matrix", "roc_curve", "precision_recall_curve", + "detection_error_tradeoff_curve", "precision_score", "recall_score", "f2_score", "f0.5_score", diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py index 3daafa8d196d3..e08a8909cfe72 100644 --- a/sklearn/metrics/tests/test_ranking.py +++ b/sklearn/metrics/tests/test_ranking.py @@ -16,11 +16,13 @@ from sklearn.utils._testing import assert_almost_equal from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_array_almost_equal +from sklearn.utils._testing import assert_raises from sklearn.utils._testing import assert_warns from sklearn.metrics import auc from sklearn.metrics import average_precision_score from sklearn.metrics import coverage_error +from sklearn.metrics import detection_error_tradeoff_curve from sklearn.metrics import label_ranking_average_precision_score from sklearn.metrics import precision_recall_curve from sklearn.metrics import label_ranking_loss @@ -925,6 +927,111 @@ def test_score_scale_invariance(): 
     assert pr_auc == pr_auc_shifted


+@pytest.mark.parametrize("y_true,y_score,expected_fpr,expected_fnr", [
+    ([0, 0, 1], [0, 0.5, 1], [0], [0]),
+    ([0, 0, 1], [0, 0.25, 0.5], [0], [0]),
+    ([0, 0, 1], [0.5, 0.75, 1], [0], [0]),
+    ([0, 0, 1], [0.25, 0.5, 0.75], [0], [0]),
+    ([0, 1, 0], [0, 0.5, 1], [0.5], [0]),
+    ([0, 1, 0], [0, 0.25, 0.5], [0.5], [0]),
+    ([0, 1, 0], [0.5, 0.75, 1], [0.5], [0]),
+    ([0, 1, 0], [0.25, 0.5, 0.75], [0.5], [0]),
+    ([0, 1, 1], [0, 0.5, 1], [0.0], [0]),
+    ([0, 1, 1], [0, 0.25, 0.5], [0], [0]),
+    ([0, 1, 1], [0.5, 0.75, 1], [0], [0]),
+    ([0, 1, 1], [0.25, 0.5, 0.75], [0], [0]),
+    ([1, 0, 0], [0, 0.5, 1], [1, 1, 0.5], [0, 1, 1]),
+    ([1, 0, 0], [0, 0.25, 0.5], [1, 1, 0.5], [0, 1, 1]),
+    ([1, 0, 0], [0.5, 0.75, 1], [1, 1, 0.5], [0, 1, 1]),
+    ([1, 0, 0], [0.25, 0.5, 0.75], [1, 1, 0.5], [0, 1, 1]),
+    ([1, 0, 1], [0, 0.5, 1], [1, 1, 0], [0, 0.5, 0.5]),
+    ([1, 0, 1], [0, 0.25, 0.5], [1, 1, 0], [0, 0.5, 0.5]),
+    ([1, 0, 1], [0.5, 0.75, 1], [1, 1, 0], [0, 0.5, 0.5]),
+    ([1, 0, 1], [0.25, 0.5, 0.75], [1, 1, 0], [0, 0.5, 0.5]),
+])
+def test_detection_error_tradeoff_curve_toydata(y_true, y_score,
+                                                expected_fpr, expected_fnr):
+    # Check on a batch of small examples.
+    fpr, fnr, _ = detection_error_tradeoff_curve(y_true, y_score)
+
+    assert_array_almost_equal(fpr, expected_fpr)
+    assert_array_almost_equal(fnr, expected_fnr)
+
+
+@pytest.mark.parametrize("y_true,y_score,expected_fpr,expected_fnr", [
+    ([1, 0], [0.5, 0.5], [1], [0]),
+    ([0, 1], [0.5, 0.5], [1], [0]),
+    ([0, 0, 1], [0.25, 0.5, 0.5], [0.5], [0]),
+    ([0, 1, 0], [0.25, 0.5, 0.5], [0.5], [0]),
+    ([0, 1, 1], [0.25, 0.5, 0.5], [0], [0]),
+    ([1, 0, 0], [0.25, 0.5, 0.5], [1], [0]),
+    ([1, 0, 1], [0.25, 0.5, 0.5], [1], [0]),
+    ([1, 1, 0], [0.25, 0.5, 0.5], [1], [0]),
+])
+def test_detection_error_tradeoff_curve_tie_handling(y_true, y_score,
+                                                     expected_fpr,
+                                                     expected_fnr):
+    fpr, fnr, _ = detection_error_tradeoff_curve(y_true, y_score)
+
+    assert_array_almost_equal(fpr, expected_fpr)
+    assert_array_almost_equal(fnr, expected_fnr)
+
+
+def test_detection_error_tradeoff_curve_sanity_check():
+    # Exactly duplicated inputs yield the same result.
+    assert_array_almost_equal(
+        detection_error_tradeoff_curve([0, 0, 1], [0, 0.5, 1]),
+        detection_error_tradeoff_curve(
+            [0, 0, 0, 0, 1, 1], [0, 0, 0.5, 0.5, 1, 1])
+    )
+
+
+@pytest.mark.parametrize("y_score", [
+    (0), (0.25), (0.5), (0.75), (1)
+])
+def test_detection_error_tradeoff_curve_constant_scores(y_score):
+    fpr, fnr, threshold = detection_error_tradeoff_curve(
+        y_true=[0, 1, 0, 1, 0, 1],
+        y_score=np.full(6, y_score)
+    )
+
+    assert_array_almost_equal(fpr, [1])
+    assert_array_almost_equal(fnr, [0])
+    assert_array_almost_equal(threshold, [y_score])
+
+
+@pytest.mark.parametrize("y_true", [
+    ([0, 0, 0, 0, 0, 1]),
+    ([0, 0, 0, 0, 1, 1]),
+    ([0, 0, 0, 1, 1, 1]),
+    ([0, 0, 1, 1, 1, 1]),
+    ([0, 1, 1, 1, 1, 1]),
+])
+def test_detection_error_tradeoff_curve_perfect_scores(y_true):
+    fpr, fnr, _ = detection_error_tradeoff_curve(
+        y_true=y_true,
+        y_score=y_true
+    )
+
+    assert_array_almost_equal(fpr, [0])
+    assert_array_almost_equal(fnr, [0])
+
+
+def test_detection_error_tradeoff_curve_bad_input():
+    # input variables with inconsistent numbers of samples
+    assert_raises(ValueError, detection_error_tradeoff_curve,
+                  [0, 1], [0, 0.5, 1])
+    assert_raises(ValueError, detection_error_tradeoff_curve,
+                  [0, 1, 1], [0, 0.5])
+
+    # When the y_true values are all the same a detection error tradeoff
+    # cannot be computed.
+    assert_raises(ValueError, detection_error_tradeoff_curve,
+                  [0, 0, 0], [0, 0.5, 1])
+    assert_raises(ValueError, detection_error_tradeoff_curve,
+                  [1, 1, 1], [0, 0.5, 1])
+
+
 def check_lrap_toy(lrap_score):
     # Check on several small example that it works
     assert_almost_equal(lrap_score([[0, 1]], [[0.25, 0.75]]), 1)