Skip to content

Commit

Permalink
Plot method for aggregation + refactor tests (#1485)
Browse files Browse the repository at this point in the history
* starting point
* fix aggregation methods
* update testing
* fix confusion matrix
* changelog
* cleaning
* drop unused plot_options

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
5 people committed Feb 27, 2023
1 parent fe86adf commit 2850524
Show file tree
Hide file tree
Showing 6 changed files with 412 additions and 205 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for deterministic evaluation on GPU for metrics that uses `torch.cumsum` operator ([#1499](https://github.com/Lightning-AI/metrics/pull/1499))


- Added support for plotting of aggregation metrics through `.plot()` method ([#1485](https://github.com/Lightning-AI/metrics/pull/1485))

### Changed

- Changed `update_count` and `update_called` from private to public methods ([#1370](https://github.com/Lightning-AI/metrics/pull/1370))
Expand Down
187 changes: 182 additions & 5 deletions src/torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, List, Union
from typing import Any, Callable, List, Optional, Sequence, Union

import torch
from torch import Tensor

from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["SumMetric.plot", "MeanMetric.plot", "MaxMetric.plot", "MinMetric.plot"]


class BaseAggregator(Metric):
Expand All @@ -44,7 +48,7 @@ class BaseAggregator(Metric):
value: Tensor
is_differentiable = None
higher_is_better = None
full_state_update = False
full_state_update: bool = False

def __init__(
self,
Expand Down Expand Up @@ -128,7 +132,7 @@ class MaxMetric(BaseAggregator):
tensor(3.)
"""

full_state_update = True
full_state_update: bool = True

def __init__(
self,
Expand All @@ -153,6 +157,49 @@ def update(self, value: Union[float, Tensor]) -> None:
if value.numel(): # make sure tensor not empty
self.value = torch.max(self.value, torch.max(value))

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> from torchmetrics import MaxMetric
>>> metric = MaxMetric()
>>> metric.update([1, 2, 3])
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> from torchmetrics import MaxMetric
>>> metric = MaxMetric()
>>> values = [ ]
>>> for i in range(10):
... values.append(metric(i))
>>> fig_, ax_ = metric.plot(values)
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val, ax=ax, higher_is_better=self.higher_is_better, name=self.__class__.__name__
)
return fig, ax


class MinMetric(BaseAggregator):
"""Aggregate a stream of value into their minimum value.
Expand Down Expand Up @@ -189,7 +236,7 @@ class MinMetric(BaseAggregator):
tensor(1.)
"""

full_state_update = True
full_state_update: bool = True

def __init__(
self,
Expand All @@ -214,6 +261,49 @@ def update(self, value: Union[float, Tensor]) -> None:
if value.numel(): # make sure tensor not empty
self.value = torch.min(self.value, torch.min(value))

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> from torchmetrics import MinMetric
>>> metric = MinMetric()
>>> metric.update([1, 2, 3])
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> from torchmetrics import MinMetric
>>> metric = MinMetric()
>>> values = [ ]
>>> for i in range(10):
... values.append(metric(i))
>>> fig_, ax_ = metric.plot(values)
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val, ax=ax, higher_is_better=self.higher_is_better, name=self.__class__.__name__
)
return fig, ax


class SumMetric(BaseAggregator):
"""Aggregate a stream of value into their sum.
Expand Down Expand Up @@ -273,6 +363,50 @@ def update(self, value: Union[float, Tensor]) -> None:
if value.numel():
self.value += value.sum()

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> from torchmetrics import SumMetric
>>> metric = SumMetric()
>>> metric.update([1, 2, 3])
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics import SumMetric
>>> metric = SumMetric()
>>> values = [ ]
>>> for i in range(10):
... values.append(metric([i, i+1]))
>>> fig_, ax_ = metric.plot(values)
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val, ax=ax, higher_is_better=self.higher_is_better, name=self.__class__.__name__
)
return fig, ax


class CatMetric(BaseAggregator):
"""Concatenate a stream of values.
Expand Down Expand Up @@ -407,3 +541,46 @@ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0
def compute(self) -> Tensor:
"""Compute the aggregated value."""
return self.value / self.weight

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> from torchmetrics import MeanMetric
>>> metric = MeanMetric()
>>> metric.update([1, 2, 3])
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> from torchmetrics import MeanMetric
>>> metric = MeanMetric()
>>> values = [ ]
>>> for i in range(10):
... values.append(metric([i, i+1]))
>>> fig_, ax_ = metric.plot(values)
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val, ax=ax, higher_is_better=self.higher_is_better, name=self.__class__.__name__
)
return fig, ax
82 changes: 78 additions & 4 deletions src/torchmetrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing import Any, List, Optional

import torch
from torch import Tensor
Expand Down Expand Up @@ -40,7 +40,11 @@
from torchmetrics.utilities.plot import _PLOT_OUT_TYPE, plot_confusion_matrix

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["MulticlassConfusionMatrix.plot"]
__doctest_skip__ = [
"BinaryConfusionMatrix.plot",
"MulticlassConfusionMatrix.plot",
"MultilabelConfusionMatrix.plot",
]


class BinaryConfusionMatrix(Metric):
Expand Down Expand Up @@ -126,6 +130,39 @@ def compute(self) -> Tensor:
"""Compute confusion matrix."""
return _binary_confusion_matrix_compute(self.confmat, self.normalize)

def plot(
self, val: Optional[Tensor] = None, add_text: bool = True, labels: Optional[List[str]] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
add_text: if the value of each cell should be added to the plot
labels: a list of strings, if provided will be added to the plot to indicate the different classes
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> metric = MulticlassConfusionMatrix(num_classes=5)
>>> metric.update(randint(5, (20,)), randint(5, (20,)))
>>> fig_, ax_ = metric.plot()
"""
val = val or self.compute()
if not isinstance(val, Tensor):
raise TypeError(f"Expected val to be a single tensor but got {val}")
fig, ax = plot_confusion_matrix(val, add_text=add_text, labels=labels)
return fig, ax


class MulticlassConfusionMatrix(Metric):
r"""Compute the `confusion matrix`_ for multiclass tasks.
Expand Down Expand Up @@ -231,12 +268,16 @@ def compute(self) -> Tensor:
"""Compute confusion matrix."""
return _multiclass_confusion_matrix_compute(self.confmat, self.normalize)

def plot(self, val: Optional[Tensor] = None) -> _PLOT_OUT_TYPE:
def plot(
self, val: Optional[Tensor] = None, add_text: bool = True, labels: Optional[List[str]] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
add_text: if the value of each cell should be added to the plot
labels: a list of strings, if provided will be added to the plot to indicate the different classes
Returns:
Figure and Axes object
Expand All @@ -257,7 +298,7 @@ def plot(self, val: Optional[Tensor] = None) -> _PLOT_OUT_TYPE:
val = val or self.compute()
if not isinstance(val, Tensor):
raise TypeError(f"Expected val to be a single tensor but got {val}")
fig, ax = plot_confusion_matrix(val)
fig, ax = plot_confusion_matrix(val, add_text=add_text, labels=labels)
return fig, ax


Expand Down Expand Up @@ -351,6 +392,39 @@ def compute(self) -> Tensor:
"""Compute confusion matrix."""
return _multilabel_confusion_matrix_compute(self.confmat, self.normalize)

def plot(
self, val: Optional[Tensor] = None, add_text: bool = True, labels: Optional[List[str]] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
add_text: if the value of each cell should be added to the plot
labels: a list of strings, if provided will be added to the plot to indicate the different classes
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> metric = MulticlassConfusionMatrix(num_classes=5)
>>> metric.update(randint(5, (20,)), randint(5, (20,)))
>>> fig_, ax_ = metric.plot()
"""
val = val or self.compute()
if not isinstance(val, Tensor):
raise TypeError(f"Expected val to be a single tensor but got {val}")
fig, ax = plot_confusion_matrix(val, add_text=add_text, labels=labels)
return fig, ax


class ConfusionMatrix:
r"""Compute the `confusion matrix`_.
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class Metric(Module, ABC):
is_differentiable: Optional[bool] = None
higher_is_better: Optional[bool] = None
full_state_update: Optional[bool] = None
plot_options: Dict[str, Union[str, float]] = {}

def __init__(
self,
Expand Down

0 comments on commit 2850524

Please sign in to comment.