-
Notifications
You must be signed in to change notification settings - Fork 387
/
precision_recall_curve.py
834 lines (718 loc) · 39.5 KB
/
precision_recall_curve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor, tensor
from torch.nn import functional as F
from typing_extensions import Literal
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.compute import _safe_divide
from torchmetrics.utilities.data import _bincount
def _binary_clf_curve(
    preds: Tensor,
    target: Tensor,
    sample_weights: Optional[Sequence] = None,
    pos_label: int = 1,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Count false and true positives at every distinct score threshold in ``preds``.

    Adapted from
    https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_ranking.py.

    Args:
        preds: prediction scores (expected 1d). If ``preds`` has more dimensions than ``target``,
            only ``preds[:, 0]`` is used.
        target: ground truth labels (expected 1d)
        sample_weights: optional weight per sample; a non-tensor sequence is converted to a float tensor
            on the device of ``preds``
        pos_label: value in ``target`` that marks the positive class

    Returns:
        fps: false positives (weighted sum when ``sample_weights`` is given) for each threshold
        tps: true positives (weighted sum when ``sample_weights`` is given) for each threshold
        thresholds: the distinct values of ``preds`` in decreasing order
    """
    with torch.no_grad():
        weighted = sample_weights is not None
        if weighted and not isinstance(sample_weights, Tensor):
            sample_weights = tensor(sample_weights, device=preds.device, dtype=torch.float)

        # drop the class dimension when preds carries one more axis than target
        if preds.ndim > target.ndim:
            preds = preds[:, 0]

        order = torch.argsort(preds, descending=True)
        sorted_preds = preds[order]
        sorted_target = target[order]
        weight = sample_weights[order] if weighted else 1.0

        # Scores are usually heavily tied: a non-zero step between neighbours marks the end of a run of
        # equal scores, and the last sample always closes the final run.
        score_steps = sorted_preds[1:] - sorted_preds[:-1]
        run_ends = torch.nonzero(score_steps, as_tuple=True)[0]
        threshold_idxs = F.pad(run_ends, [0, 1], value=sorted_target.size(0) - 1)

        is_pos = (sorted_target == pos_label).to(torch.long)
        tps = torch.cumsum(is_pos * weight, dim=0)[threshold_idxs]
        if weighted:
            # a cumulative sum keeps fps non-decreasing even under floating point rounding
            fps = torch.cumsum((1 - is_pos) * weight, dim=0)[threshold_idxs]
        else:
            # unweighted: every sample up to index i is either a tp or an fp
            fps = 1 + threshold_idxs - tps
        return fps, tps, sorted_preds[threshold_idxs]
def _adjust_threshold_arg(
    thresholds: Optional[Union[int, List[float], Tensor]] = None, device: Optional[torch.device] = None
) -> Optional[Tensor]:
    """Normalize the ``thresholds`` argument to tensor form.

    - an ``int`` n becomes ``torch.linspace(0, 1, n)`` on ``device``
    - a ``list`` becomes a tensor of its values on ``device``
    - ``None`` and tensors are returned unchanged (tensors are not moved to ``device``)
    """
    if isinstance(thresholds, int):
        return torch.linspace(0, 1, thresholds, device=device)
    if isinstance(thresholds, list):
        return torch.tensor(thresholds, device=device)
    return thresholds
def _binary_precision_recall_curve_arg_validation(
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
) -> None:
    """Check the non-tensor arguments of the binary pr-curve.

    Raises:
        ValueError: if ``thresholds`` is not ``None``, an int larger than 1, a list of floats in [0, 1]
            or a 1d tensor, or if ``ignore_index`` is neither ``None`` nor an int.
    """
    if thresholds is not None and not isinstance(thresholds, (list, int, Tensor)):
        raise ValueError(
            "Expected argument `thresholds` to either be an integer, list of floats or"
            f" tensor of floats, but got {thresholds}"
        )
    # the type-specific checks below are no-ops for ``None``
    if isinstance(thresholds, int) and thresholds < 2:
        raise ValueError(
            f"If argument `thresholds` is an integer, expected it to be larger than 1, but got {thresholds}"
        )
    if isinstance(thresholds, list):
        all_unit_floats = all(isinstance(t, float) and 0 <= t <= 1 for t in thresholds)
        if not all_unit_floats:
            raise ValueError(
                "If argument `thresholds` is a list, expected all elements to be floats in the [0,1] range,"
                f" but got {thresholds}"
            )
    if isinstance(thresholds, Tensor) and thresholds.ndim != 1:
        raise ValueError("If argument `thresholds` is an tensor, expected the tensor to be 1d")
    if ignore_index is not None and not isinstance(ignore_index, int):
        raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}")
def _binary_precision_recall_curve_tensor_validation(
    preds: Tensor, target: Tensor, ignore_index: Optional[int] = None
) -> None:
    """Validate tensor input.

    - tensors have to be of same shape
    - ``target`` must not be floating point, ``preds`` must be floating point
    - all values in ``target`` that are not ignored have to be in {0, 1}

    Raises:
        ValueError: if ``target`` is a floating point tensor or ``preds`` is not.
        RuntimeError: if ``target`` contains values other than 0, 1 and (when given) ``ignore_index``.
    """
    _check_same_shape(preds, target)
    if target.is_floating_point():
        raise ValueError(
            "Expected argument `target` to be an int or long tensor with ground truth labels"
            f" but got tensor with dtype {target.dtype}"
        )
    if not preds.is_floating_point():
        raise ValueError(
            "Expected argument `preds` to be an floating tensor with probability/logit scores,"
            f" but got tensor with dtype {preds.dtype}"
        )
    # Check that target only contains {0,1} values or value in ignore_index
    unique_values = torch.unique(target)
    if ignore_index is None:
        check = torch.any((unique_values != 0) & (unique_values != 1))
    else:
        check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index))
    if check:
        # Built explicitly: the previous inline ``[0,1] + [] if cond else [ignore_index]`` parsed as
        # ``([0,1] + []) if cond else [ignore_index]`` and dropped 0 and 1 whenever ignore_index was set.
        expected_values = [0, 1] if ignore_index is None else [0, 1, ignore_index]
        raise RuntimeError(
            f"Detected the following values in `target`: {unique_values} but expected only"
            f" the following values {expected_values}."
        )
def _binary_precision_recall_curve_format(
    preds: Tensor,
    target: Tensor,
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
    """Bring binary inputs into the layout expected by the update step.

    - both tensors are flattened to 1d
    - positions whose target equals ``ignore_index`` are dropped
    - ``preds`` is passed through a sigmoid unless every value already lies in [0, 1]
    - an int or list ``thresholds`` is converted to a tensor on the device of ``preds``
    """
    flat_preds = preds.flatten()
    flat_target = target.flatten()
    if ignore_index is not None:
        keep = flat_target != ignore_index
        flat_preds = flat_preds[keep]
        flat_target = flat_target[keep]
    # any value outside [0, 1] means the scores are treated as logits
    if not torch.all((flat_preds >= 0) & (flat_preds <= 1)):
        flat_preds = flat_preds.sigmoid()
    return flat_preds, flat_target, _adjust_threshold_arg(thresholds, flat_preds.device)
def _binary_precision_recall_curve_update(
    preds: Tensor,
    target: Tensor,
    thresholds: Optional[Tensor],
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Build the state from which the binary pr-curve is computed.

    Without ``thresholds`` (non-binned mode) the inputs are returned unchanged. With ``thresholds`` one
    2x2 confusion matrix per threshold is built using a single ``_bincount`` call.

    Returns:
        Either ``(preds, target)`` or a tensor of shape ``(n_thresholds, 2, 2)``, indexed as
        ``[threshold, target, prediction]`` (assuming ``target`` only holds 0 and 1).
    """
    if thresholds is None:
        return preds, target
    n_thresholds = len(thresholds)
    # num_samples x num_thresholds, 1 where the score reaches the threshold
    predicted_pos = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0)).long()
    # encode (threshold, target, prediction) as one integer so a single bincount fills every matrix
    threshold_offset = 4 * torch.arange(n_thresholds, device=target.device)
    codes = predicted_pos + 2 * target.unsqueeze(-1) + threshold_offset
    counts = _bincount(codes.flatten(), minlength=4 * n_thresholds)
    return counts.reshape(n_thresholds, 2, 2)
def _binary_precision_recall_curve_compute(
    state: Union[Tensor, Tuple[Tensor, Tensor]],
    thresholds: Optional[Tensor],
    pos_label: int = 1,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Turn the accumulated state into precision, recall and threshold tensors.

    A tensor ``state`` is a stack of per-threshold confusion matrices (binned mode). Otherwise ``state``
    holds the raw ``(preds, target)`` and the curve is derived from every distinct score. In both cases a
    final point with precision 1 and recall 0 is appended.
    """
    if not isinstance(state, Tensor):
        fps, tps, thresholds = _binary_clf_curve(state[0], state[1], pos_label=pos_label)
        precision = tps / (tps + fps)
        recall = tps / tps[-1]
        # keep points up to the first one reaching full recall, then flip so recall is decreasing
        stop = torch.nonzero(tps == tps[-1], as_tuple=True)[0][0].item() + 1
        precision = torch.cat([precision[:stop].flip(0), precision.new_ones(1)])
        recall = torch.cat([recall[:stop].flip(0), recall.new_zeros(1)])
        thresholds = thresholds[:stop].flip(0).detach().clone()
        return precision, recall, thresholds

    # binned mode: matrices are indexed [threshold, target, prediction]
    true_pos = state[:, 1, 1]
    false_pos = state[:, 0, 1]
    false_neg = state[:, 1, 0]
    precision = _safe_divide(true_pos, true_pos + false_pos)
    recall = _safe_divide(true_pos, true_pos + false_neg)
    precision = torch.cat([precision, precision.new_ones(1)])
    recall = torch.cat([recall, recall.new_zeros(1)])
    return precision, recall, thresholds
def binary_precision_recall_curve(
    preds: Tensor,
    target: Tensor,
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Computes the precision-recall curve for binary tasks. The curve consists of multiple pairs of precision
    and recall values evaluated at different thresholds, such that the tradeoff between the two values can be
    seen.

    Accepts the following input tensors:

    - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each
      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
      sigmoid per element.
    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
      only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the positive class.

    Additional dimension ``...`` will be flattened into the batch dimension.

    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
    size :math:`\mathcal{O}(n_{thresholds})` (constant memory).

    Args:
        preds: Tensor with predictions
        target: Tensor with true labels
        thresholds:
            Can be one of:

            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
            - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.

    Returns:
        (tuple): a tuple of 3 tensors containing:

        - precision: an 1d tensor of size (n_thresholds+1, ) with precision values
        - recall: an 1d tensor of size (n_thresholds+1, ) with recall values
        - thresholds: an 1d tensor of size (n_thresholds, ) with increasing threshold values

    Example:
        >>> from torchmetrics.functional.classification import binary_precision_recall_curve
        >>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
        >>> target = torch.tensor([0, 1, 1, 0])
        >>> binary_precision_recall_curve(preds, target, thresholds=None)  # doctest: +NORMALIZE_WHITESPACE
        (tensor([0.6667, 0.5000, 0.0000, 1.0000]),
         tensor([1.0000, 0.5000, 0.0000, 0.0000]),
         tensor([0.5000, 0.7000, 0.8000]))
        >>> binary_precision_recall_curve(preds, target, thresholds=5)  # doctest: +NORMALIZE_WHITESPACE
        (tensor([0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000]),
         tensor([1., 1., 1., 0., 0., 0.]),
         tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
    """
    if validate_args:
        _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
        _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index)
    preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index)
    state = _binary_precision_recall_curve_update(preds, target, thresholds)
    return _binary_precision_recall_curve_compute(state, thresholds)
def _multiclass_precision_recall_curve_arg_validation(
    num_classes: int,
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
) -> None:
    """Check the non-tensor arguments of the multiclass pr-curve.

    ``num_classes`` must be an int larger than 1; ``thresholds`` and ``ignore_index`` follow the same rules
    as in the binary case.

    Raises:
        ValueError: on any invalid argument.
    """
    num_classes_ok = isinstance(num_classes, int) and num_classes >= 2
    if not num_classes_ok:
        raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}")
    _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
def _multiclass_precision_recall_curve_tensor_validation(
    preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None
) -> None:
    """Check the tensor inputs of the multiclass pr-curve.

    - ``preds`` must have exactly one more dimension than ``target``
    - ``target`` must not be floating point and ``preds`` must be floating point
    - ``preds.shape[1]`` must equal ``num_classes``
    - ``preds`` must be ``(N, C, ...)`` with ``target`` being ``(N, ...)``
    - ``target`` may hold at most ``num_classes`` distinct values (one more when ``ignore_index`` is set);
      the values themselves are not range-checked here

    Raises:
        ValueError: on dimension, dtype or shape mismatches.
        RuntimeError: if ``target`` has too many distinct values.
    """
    if preds.ndim != target.ndim + 1:
        raise ValueError(
            f"Expected `preds` to have one more dimension than `target` but got {preds.ndim} and {target.ndim}"
        )
    if target.is_floating_point():
        raise ValueError(
            f"Expected argument `target` to be an int or long tensor, but got tensor with dtype {target.dtype}"
        )
    if not preds.is_floating_point():
        raise ValueError(f"Expected `preds` to be a float tensor, but got {preds.dtype}")
    if preds.shape[1] != num_classes:
        raise ValueError(
            "Expected `preds.shape[1]` to be equal to the number of classes but"
            f" got {preds.shape[1]} and {num_classes}."
        )
    same_batch = preds.shape[0] == target.shape[0]
    same_extra_dims = preds.shape[2:] == target.shape[1:]
    if not (same_batch and same_extra_dims):
        raise ValueError(
            "Expected the shape of `preds` should be (N, C, ...) and the shape of `target` should be (N, ...)"
            f" but got {preds.shape} and {target.shape}"
        )
    # the ignored value, if any, counts as one extra allowed distinct value
    max_unique = num_classes if ignore_index is None else num_classes + 1
    num_unique_values = len(torch.unique(target))
    if num_unique_values > max_unique:
        raise RuntimeError(
            "Detected more unique values in `target` than `num_classes`. Expected only "
            f"{max_unique} but found "
            f"{num_unique_values} in `target`."
        )
def _multiclass_precision_recall_curve_format(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
    """Bring multiclass inputs into the layout expected by the update step.

    - ``preds`` becomes 2d with one column per class (extra dimensions folded into the rows) and
      ``target`` becomes 1d
    - rows whose target equals ``ignore_index`` are dropped
    - ``preds`` is passed through a softmax over classes unless every value already lies in [0, 1]
    - an int or list ``thresholds`` is converted to a tensor on the device of ``preds``
    """
    # move the class axis first, fold everything else together, then transpose to (rows, classes)
    preds = preds.transpose(0, 1).reshape(num_classes, -1).T
    target = target.flatten()
    if ignore_index is not None:
        keep = target != ignore_index
        preds, target = preds[keep], target[keep]
    # any value outside [0, 1] means the scores are treated as logits
    if not torch.all((preds >= 0) & (preds <= 1)):
        preds = preds.softmax(1)
    return preds, target, _adjust_threshold_arg(thresholds, preds.device)
def _multiclass_precision_recall_curve_update(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    thresholds: Optional[Tensor],
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Build the state from which the multiclass pr-curve is computed.

    Without ``thresholds`` (non-binned mode) the inputs are returned unchanged. With ``thresholds`` a
    one-vs-rest 2x2 confusion matrix is built for every (threshold, class) pair with one ``_bincount`` call.

    Returns:
        Either ``(preds, target)`` or a tensor of shape ``(n_thresholds, num_classes, 2, 2)`` indexed as
        ``[threshold, class, target, prediction]``.
    """
    if thresholds is None:
        return preds, target
    n_thresholds = len(thresholds)
    dev = preds.device
    # num_samples x num_classes x num_thresholds
    predicted_pos = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0).unsqueeze(0)).long()
    is_class = torch.nn.functional.one_hot(target, num_classes=num_classes)
    class_offset = 4 * torch.arange(num_classes, device=dev).unsqueeze(0).unsqueeze(-1)
    threshold_offset = 4 * num_classes * torch.arange(n_thresholds, device=dev)
    # one integer per (threshold, class, target, prediction) so a single bincount fills every matrix
    codes = predicted_pos + 2 * is_class.unsqueeze(-1) + class_offset + threshold_offset
    counts = _bincount(codes.flatten(), minlength=4 * num_classes * n_thresholds)
    return counts.reshape(n_thresholds, num_classes, 2, 2)
def _multiclass_precision_recall_curve_compute(
    state: Union[Tensor, Tuple[Tensor, Tensor]],
    num_classes: int,
    thresholds: Optional[Tensor],
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
    """Turn the accumulated state into per-class precision, recall and thresholds.

    Binned mode (tensor ``state``): returns ``(num_classes, n_thresholds + 1)`` precision and recall tensors
    plus the shared ``thresholds``. Non-binned mode (``state`` holds ``(preds, target)``): each class is
    treated one-vs-rest, and lists with one curve per class are returned.
    """
    if not isinstance(state, Tensor):
        scores, labels = state[0], state[1]
        curves = [
            _binary_precision_recall_curve_compute([scores[:, cls], labels], thresholds=None, pos_label=cls)
            for cls in range(num_classes)
        ]
        precisions = [curve[0] for curve in curves]
        recalls = [curve[1] for curve in curves]
        class_thresholds = [curve[2] for curve in curves]
        return precisions, recalls, class_thresholds

    # binned mode: matrices are indexed [threshold, class, target, prediction]
    true_pos = state[:, :, 1, 1]
    false_pos = state[:, :, 0, 1]
    false_neg = state[:, :, 1, 0]
    precision = _safe_divide(true_pos, true_pos + false_pos)
    recall = _safe_divide(true_pos, true_pos + false_neg)
    precision = torch.cat([precision, precision.new_ones(1, num_classes)])
    recall = torch.cat([recall, recall.new_zeros(1, num_classes)])
    return precision.T, recall.T, thresholds
def multiclass_precision_recall_curve(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
    r"""Computes the precision-recall curve for multiclass tasks. The curve consists of multiple pairs of precision
    and recall values evaluated at different thresholds, such that the tradeoff between the two values can be
    seen.

    Accepts the following input tensors:

    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
      softmax per sample.
    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
      only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).

    Additional dimension ``...`` will be flattened into the batch dimension.

    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
    size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).

    Args:
        preds: Tensor with predictions
        target: Tensor with true labels
        num_classes: Integer specifying the number of classes
        thresholds:
            Can be one of:

            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
            - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.

    Returns:
        (tuple): a tuple of either 3 tensors or 3 lists containing

        - precision: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, )
          with precision values (length may differ between classes). If `thresholds` is set to something else,
          then a single 2d tensor of size (n_classes, n_thresholds+1) with precision values is returned.
        - recall: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, )
          with recall values (length may differ between classes). If `thresholds` is set to something else,
          then a single 2d tensor of size (n_classes, n_thresholds+1) with recall values is returned.
        - thresholds: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds, )
          with increasing threshold values (length may differ between classes). If `thresholds` is set to something
          else, then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all
          classes.

    Example:
        >>> from torchmetrics.functional.classification import multiclass_precision_recall_curve
        >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                       [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                       [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> precision, recall, thresholds = multiclass_precision_recall_curve(
        ...    preds, target, num_classes=5, thresholds=None
        ... )
        >>> precision  # doctest: +NORMALIZE_WHITESPACE
        [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
         tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
        >>> recall
        [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
        >>> thresholds
        [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
        >>> multiclass_precision_recall_curve(
        ...     preds, target, num_classes=5, thresholds=5
        ... )  # doctest: +NORMALIZE_WHITESPACE
        (tensor([[0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000],
                 [0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000],
                 [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
                 [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
                 [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]),
         tensor([[1., 1., 1., 1., 0., 0.],
                 [1., 1., 1., 1., 0., 0.],
                 [1., 0., 0., 0., 0., 0.],
                 [1., 0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0., 0.]]),
         tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
    """
    if validate_args:
        _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index)
        _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index)
    preds, target, thresholds = _multiclass_precision_recall_curve_format(
        preds, target, num_classes, thresholds, ignore_index
    )
    state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds)
    return _multiclass_precision_recall_curve_compute(state, num_classes, thresholds)
def _multilabel_precision_recall_curve_arg_validation(
    num_labels: int,
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
) -> None:
    """Check the non-tensor arguments of the multilabel pr-curve.

    ``num_labels`` obeys the same rule as ``num_classes`` in the multiclass case (an int larger than 1), so
    the multiclass validator is reused; its error message for this argument refers to ``num_classes``.
    """
    _multiclass_precision_recall_curve_arg_validation(
        num_classes=num_labels, thresholds=thresholds, ignore_index=ignore_index
    )
def _multilabel_precision_recall_curve_tensor_validation(
    preds: Tensor, target: Tensor, num_labels: int, ignore_index: Optional[int] = None
) -> None:
    """Check the tensor inputs of the multilabel pr-curve.

    Runs all binary checks (same shape, dtypes, target values in {0, 1} or ``ignore_index``) and
    additionally requires ``preds.shape[1]`` to equal ``num_labels``.

    Raises:
        ValueError: if the label dimension does not match ``num_labels`` (other errors come from the
            binary validation).
    """
    _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index)
    label_dim = preds.shape[1]
    if label_dim != num_labels:
        raise ValueError(
            "Expected both `target.shape[1]` and `preds.shape[1]` to be equal to the number of labels"
            f" but got {label_dim} and expected {num_labels}"
        )
def _multilabel_precision_recall_curve_format(
    preds: Tensor,
    target: Tensor,
    num_labels: int,
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
    """Bring multilabel inputs into the layout expected by the update step.

    - ``preds`` and ``target`` become 2d with one column per label (extra dimensions folded into rows)
    - ``preds`` is passed through a sigmoid unless every value already lies in [0, 1]
    - an int or list ``thresholds`` is converted to a tensor on the device of ``preds``
    - in binned mode (``thresholds`` given) ignored entries are overwritten with a large negative value
      instead of being removed, because removing entries would break the 2d layout. In non-binned mode they
      are left as-is and filtered per label when the curve is computed.
    """
    preds = preds.transpose(0, 1).reshape(num_labels, -1).T
    target = target.transpose(0, 1).reshape(num_labels, -1).T
    # any value outside [0, 1] means the scores are treated as logits
    if not torch.all((preds >= 0) & (preds <= 1)):
        preds = preds.sigmoid()
    thresholds = _adjust_threshold_arg(thresholds, preds.device)
    if ignore_index is None or thresholds is None:
        return preds, target, thresholds

    # The fill value makes the bin index built in the update step negative for every ignored entry,
    # so those entries can be filtered out before counting.
    fill_value = -4 * num_labels * len(thresholds)
    ignored = target == ignore_index
    preds = preds.clone()
    target = target.clone()
    preds[ignored] = fill_value
    target[ignored] = fill_value
    return preds, target, thresholds
def _multilabel_precision_recall_curve_update(
    preds: Tensor,
    target: Tensor,
    num_labels: int,
    thresholds: Optional[Tensor],
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Build the state from which the multilabel pr-curve is computed.

    Without ``thresholds`` (non-binned mode) the inputs are returned unchanged. With ``thresholds`` a 2x2
    confusion matrix is built for every (threshold, label) pair with one ``_bincount`` call. Entries whose
    encoded bin index is negative (ignored entries masked by the format step) are discarded first.

    Returns:
        Either ``(preds, target)`` or a tensor of shape ``(n_thresholds, num_labels, 2, 2)`` indexed as
        ``[threshold, label, target, prediction]``.
    """
    if thresholds is None:
        return preds, target
    n_thresholds = len(thresholds)
    dev = preds.device
    # num_samples x num_labels x num_thresholds
    predicted_pos = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0).unsqueeze(0)).long()
    label_offset = 4 * torch.arange(num_labels, device=dev).unsqueeze(0).unsqueeze(-1)
    threshold_offset = 4 * num_labels * torch.arange(n_thresholds, device=dev)
    codes = predicted_pos + 2 * target.unsqueeze(-1) + label_offset + threshold_offset
    codes = codes[codes >= 0]
    counts = _bincount(codes, minlength=4 * num_labels * n_thresholds)
    return counts.reshape(n_thresholds, num_labels, 2, 2)
def _multilabel_precision_recall_curve_compute(
    state: Union[Tensor, Tuple[Tensor, Tensor]],
    num_labels: int,
    thresholds: Optional[Tensor],
    ignore_index: Optional[int] = None,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
    """Turn the accumulated state into per-label precision, recall and thresholds.

    Binned mode (tensor ``state``): returns ``(num_labels, n_thresholds + 1)`` precision and recall tensors
    plus the shared ``thresholds``. Non-binned mode (``state`` holds ``(preds, target)``): each label column
    is treated as an independent binary problem, with entries equal to ``ignore_index`` dropped, and lists
    with one curve per label are returned.
    """
    if not isinstance(state, Tensor):
        all_scores, all_labels = state[0], state[1]
        precisions, recalls, label_thresholds = [], [], []
        for label in range(num_labels):
            scores = all_scores[:, label]
            labels = all_labels[:, label]
            if ignore_index is not None:
                keep = labels != ignore_index
                scores, labels = scores[keep], labels[keep]
            p, r, t = _binary_precision_recall_curve_compute([scores, labels], thresholds=None, pos_label=1)
            precisions.append(p)
            recalls.append(r)
            label_thresholds.append(t)
        return precisions, recalls, label_thresholds

    # binned mode: matrices are indexed [threshold, label, target, prediction]
    true_pos = state[:, :, 1, 1]
    false_pos = state[:, :, 0, 1]
    false_neg = state[:, :, 1, 0]
    precision = _safe_divide(true_pos, true_pos + false_pos)
    recall = _safe_divide(true_pos, true_pos + false_neg)
    precision = torch.cat([precision, precision.new_ones(1, num_labels)])
    recall = torch.cat([recall, recall.new_zeros(1, num_labels)])
    return precision.T, recall.T, thresholds
def multilabel_precision_recall_curve(
    preds: Tensor,
    target: Tensor,
    num_labels: int,
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
    r"""Computes the precision-recall curve for multilabel tasks.

    The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, such
    that the tradeoff between the two values can be seen.

    Accepts the following input tensors:

    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
      sigmoid per element.
    - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore
      only contain {0,1} values (except if `ignore_index` is specified).

    Additional dimension ``...`` will be flattened into the batch dimension.

    The implementation supports both a non-binned but accurate version and a binned version that is less accurate
    but more memory efficient. Setting the `thresholds` argument to `None` activates the non-binned version, which
    uses memory of size :math:`\mathcal{O}(n_{samples})`, whereas setting `thresholds` to an integer, a list or a
    1d tensor uses a binned version with memory of size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})`
    (constant memory).

    Args:
        preds: Tensor with predictions
        target: Tensor with true labels
        num_labels: Integer specifying the number of labels
        thresholds:
            Can be one of:

            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation.
            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation;
            positions whose target equals this value are excluded.
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.

    Returns:
        (tuple): a tuple of either 3 tensors or 3 lists containing

        - precision: if `thresholds=None` a list for each label is returned with a 1d tensor of size (n_thresholds+1, )
          with precision values (length may differ between labels). If `thresholds` is set to something else,
          then a single 2d tensor of size (n_labels, n_thresholds+1) with precision values is returned.
        - recall: if `thresholds=None` a list for each label is returned with a 1d tensor of size (n_thresholds+1, )
          with recall values (length may differ between labels). If `thresholds` is set to something else,
          then a single 2d tensor of size (n_labels, n_thresholds+1) with recall values is returned.
        - thresholds: if `thresholds=None` a list for each label is returned with a 1d tensor of size (n_thresholds, )
          with increasing threshold values (length may differ between labels). If `thresholds` is set to something
          else, then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels.

    Example:
        >>> from torchmetrics.functional.classification import multilabel_precision_recall_curve
        >>> preds = torch.tensor([[0.75, 0.05, 0.35],
        ...                       [0.45, 0.75, 0.05],
        ...                       [0.05, 0.55, 0.75],
        ...                       [0.05, 0.65, 0.05]])
        >>> target = torch.tensor([[1, 0, 1],
        ...                        [0, 0, 0],
        ...                        [0, 1, 1],
        ...                        [1, 1, 1]])
        >>> precision, recall, thresholds = multilabel_precision_recall_curve(
        ...    preds, target, num_labels=3, thresholds=None
        ... )
        >>> precision  # doctest: +NORMALIZE_WHITESPACE
        [tensor([0.5000, 0.5000, 1.0000, 1.0000]), tensor([0.6667, 0.5000, 0.0000, 1.0000]),
         tensor([0.7500, 1.0000, 1.0000, 1.0000])]
        >>> recall  # doctest: +NORMALIZE_WHITESPACE
        [tensor([1.0000, 0.5000, 0.5000, 0.0000]), tensor([1.0000, 0.5000, 0.0000, 0.0000]),
         tensor([1.0000, 0.6667, 0.3333, 0.0000])]
        >>> thresholds  # doctest: +NORMALIZE_WHITESPACE
        [tensor([0.0500, 0.4500, 0.7500]), tensor([0.5500, 0.6500, 0.7500]),
         tensor([0.0500, 0.3500, 0.7500])]
        >>> multilabel_precision_recall_curve(
        ...     preds, target, num_labels=3, thresholds=5
        ... )  # doctest: +NORMALIZE_WHITESPACE
        (tensor([[0.5000, 0.5000, 1.0000, 1.0000, 0.0000, 1.0000],
                 [0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000],
                 [0.7500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000]]),
         tensor([[1.0000, 0.5000, 0.5000, 0.5000, 0.0000, 0.0000],
                 [1.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000],
                 [1.0000, 0.6667, 0.3333, 0.3333, 0.0000, 0.0000]]),
         tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
    """
    if validate_args:
        _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index)
        _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index)
    # Format inputs and thresholds once, then reuse the formatted thresholds for both update and compute.
    fmt_preds, fmt_target, fmt_thresholds = _multilabel_precision_recall_curve_format(
        preds, target, num_labels, thresholds, ignore_index
    )
    curve_state = _multilabel_precision_recall_curve_update(fmt_preds, fmt_target, num_labels, fmt_thresholds)
    return _multilabel_precision_recall_curve_compute(
        curve_state, num_labels, thresholds=fmt_thresholds, ignore_index=ignore_index
    )
def precision_recall_curve(
    preds: Tensor,
    target: Tensor,
    task: Literal["binary", "multiclass", "multilabel"],
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
    r"""Computes the precision-recall curve.

    The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, such
    that the tradeoff between the two values can be seen.

    This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
    ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
    :func:`binary_precision_recall_curve`, :func:`multiclass_precision_recall_curve` and
    :func:`multilabel_precision_recall_curve` for the specific details of each argument influence and examples.

    Raises:
        ValueError:
            If ``task='multiclass'`` and ``num_classes`` is not an ``int``.
        ValueError:
            If ``task='multilabel'`` and ``num_labels`` is not an ``int``.
        ValueError:
            If ``task`` is not one of ``'binary'``, ``'multiclass'`` or ``'multilabel'``.

    Legacy Example:
        >>> pred = torch.tensor([0.0, 1.0, 2.0, 3.0])
        >>> target = torch.tensor([0, 1, 1, 0])
        >>> precision, recall, thresholds = precision_recall_curve(pred, target, task='binary')
        >>> precision
        tensor([0.6667, 0.5000, 0.0000, 1.0000])
        >>> recall
        tensor([1.0000, 0.5000, 0.0000, 0.0000])
        >>> thresholds
        tensor([0.7311, 0.8808, 0.9526])

        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> precision, recall, thresholds = precision_recall_curve(pred, target, task='multiclass', num_classes=5)
        >>> precision
        [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
         tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
        >>> recall
        [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
        >>> thresholds
        [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
    """
    if task == "binary":
        return binary_precision_recall_curve(preds, target, thresholds, ignore_index, validate_args)
    if task == "multiclass":
        # Explicit check rather than ``assert``: asserts are stripped under ``python -O`` and would let
        # ``None`` slip through to the task-specific implementation.
        if not isinstance(num_classes, int):
            raise ValueError(
                f"Expected argument `num_classes` to be an integer when `task='multiclass'`, but got {num_classes}"
            )
        return multiclass_precision_recall_curve(preds, target, num_classes, thresholds, ignore_index, validate_args)
    if task == "multilabel":
        if not isinstance(num_labels, int):
            raise ValueError(
                f"Expected argument `num_labels` to be an integer when `task='multilabel'`, but got {num_labels}"
            )
        return multilabel_precision_recall_curve(preds, target, num_labels, thresholds, ignore_index, validate_args)
    raise ValueError(
        f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
    )