Add missing dtype check in PrecisionRecallCurve (#1457)
Co-authored-by: stancld <daniel.stancl@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
(cherry picked from commit 94c49f3)
SkafteNicki authored and Borda committed Jan 30, 2023
1 parent 4cafd20 commit b45213e
Showing 3 changed files with 42 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -34,6 +34,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed type checking on the `maximize` parameter at the initialization of `MetricTracker` ([#1428](https://github.com/Lightning-AI/metrics/issues/1428))


- Fixed dtype checking in `PrecisionRecallCurve` for `target` tensor ([#1457](https://github.com/Lightning-AI/metrics/pull/1457))

## [0.11.0] - 2022-11-30

### Added
10 changes: 10 additions & 0 deletions src/torchmetrics/functional/classification/precision_recall_curve.py
@@ -132,6 +132,12 @@ def _binary_precision_recall_curve_tensor_validation(
"""
_check_same_shape(preds, target)

if target.is_floating_point():
raise ValueError(
"Expected argument `target` to be an int or long tensor with ground truth labels"
f" but got tensor with dtype {target.dtype}"
)

if not preds.is_floating_point():
raise ValueError(
"Expected argument `preds` to be an floating tensor with probability/logit scores,"
@@ -334,6 +340,10 @@ def _multiclass_precision_recall_curve_tensor_validation(
        raise ValueError(
            f"Expected `preds` to have one more dimension than `target` but got {preds.ndim} and {target.ndim}"
        )
    if target.is_floating_point():
        raise ValueError(
            f"Expected argument `target` to be an int or long tensor, but got tensor with dtype {target.dtype}"
        )
    if not preds.is_floating_point():
        raise ValueError(f"Expected `preds` to be a float tensor, but got {preds.dtype}")
    if preds.shape[1] != num_classes:
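Analogously, a hedged sketch (not part of the commit) of how the multiclass guard surfaces; only `multiclass_precision_recall_curve` and its error message appear in the diff, the rest is illustrative:

```python
import torch
from torchmetrics.functional.classification import multiclass_precision_recall_curve

num_classes = 3
preds = torch.randn(10, num_classes).softmax(dim=-1)  # (N, C) float scores
target = torch.randint(0, num_classes, (10,))         # (N,) integer labels

# A float `target` now fails validation immediately with the message above.
try:
    multiclass_precision_recall_curve(preds, target.float(), num_classes=num_classes)
except ValueError as err:
    print(err)  # Expected argument `target` to be an int or long tensor, but got ...
```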
30 changes: 30 additions & 0 deletions tests/unittests/classification/test_precision_recall_curve.py
@@ -132,6 +132,16 @@ def test_binary_precision_recall_curve_threshold_arg(self, input, threshold_fn):
        assert torch.allclose(r1, r2)
        assert torch.allclose(t1, t2)

    def test_binary_error_on_wrong_dtypes(self, input):
        """Test that errors are raised on wrong dtype."""
        preds, target = input

        with pytest.raises(ValueError, match="Expected argument `target` to be an int or long tensor with ground.*"):
            binary_precision_recall_curve(preds[0], target[0].to(torch.float32))

        with pytest.raises(ValueError, match="Expected argument `preds` to be a floating tensor with probability.*"):
            binary_precision_recall_curve(preds[0].long(), target[0])


def _sk_precision_recall_curve_multiclass(preds, target, ignore_index=None):
preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1]))
@@ -243,6 +253,16 @@ def test_multiclass_precision_recall_curve_threshold_arg(self, input, threshold_
        assert torch.allclose(r1[i], r2[i])
        assert torch.allclose(t1[i], t2)

    def test_multiclass_error_on_wrong_dtypes(self, input):
        """Test that errors are raised on wrong dtype."""
        preds, target = input

        with pytest.raises(ValueError, match="Expected argument `target` to be an int or long tensor, but got.*"):
            multiclass_precision_recall_curve(preds[0], target[0].to(torch.float32), num_classes=NUM_CLASSES)

        with pytest.raises(ValueError, match="Expected `preds` to be a float tensor, but got.*"):
            multiclass_precision_recall_curve(preds[0].long(), target[0], num_classes=NUM_CLASSES)


def _sk_precision_recall_curve_multilabel(preds, target, ignore_index=None):
precision, recall, thresholds = [], [], []
@@ -345,6 +365,16 @@ def test_multilabel_precision_recall_curve_threshold_arg(self, input, threshold_
        assert torch.allclose(r1[i], r2[i])
        assert torch.allclose(t1[i], t2)

    def test_multilabel_error_on_wrong_dtypes(self, input):
        """Test that errors are raised on wrong dtype."""
        preds, target = input

        with pytest.raises(ValueError, match="Expected argument `target` to be an int or long tensor with ground.*"):
            multilabel_precision_recall_curve(preds[0], target[0].to(torch.float32), num_labels=NUM_CLASSES)

        with pytest.raises(ValueError, match="Expected argument `preds` to be a floating tensor with probability.*"):
            multilabel_precision_recall_curve(preds[0].long(), target[0], num_labels=NUM_CLASSES)


@pytest.mark.parametrize(
"metric",
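For completeness, a small sketch (not part of the commit) of the multilabel call exercised by `test_multilabel_error_on_wrong_dtypes`; only `multilabel_precision_recall_curve` and the `num_labels` parameter appear in the diff, the import path and sample shapes are assumptions:

```python
import torch
from torchmetrics.functional.classification import multilabel_precision_recall_curve

num_labels = 3
preds = torch.rand(10, num_labels)              # (N, L) float probabilities
target = torch.randint(0, 2, (10, num_labels))  # (N, L) integer {0, 1} labels

# Integer target passes validation and yields per-label curves;
# a float target raises the ValueError the new test expects.
precision, recall, thresholds = multilabel_precision_recall_curve(
    preds, target, num_labels=num_labels
)
```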
