diff --git a/CHANGELOG.md b/CHANGELOG.md index f1951c1b9a6..7cf7115e8b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,10 +40,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- Fix compatibility between XLA in `_bincount` function ([#1471](https://github.com/Lightning-AI/metrics/pull/1471)) +- Fixed compatibility between XLA in `_bincount` function ([#1471](https://github.com/Lightning-AI/metrics/pull/1471)) -- +- Fixed type hints in methods belonging to `MetricTracker` wrapper ([#1472](https://github.com/Lightning-AI/metrics/pull/1472)) ## [0.11.1] - 2023-01-30 diff --git a/docs/source/wrappers/metric_tracker.rst b/docs/source/wrappers/metric_tracker.rst index 19ad2d3b343..ed98fd80663 100644 --- a/docs/source/wrappers/metric_tracker.rst +++ b/docs/source/wrappers/metric_tracker.rst @@ -14,3 +14,4 @@ ________________ .. autoclass:: torchmetrics.MetricTracker :noindex: + :exclude-members: update, compute diff --git a/src/torchmetrics/wrappers/tracker.py b/src/torchmetrics/wrappers/tracker.py index af958cd0a49..c9b5f799fb6 100644 --- a/src/torchmetrics/wrappers/tracker.py +++ b/src/torchmetrics/wrappers/tracker.py @@ -134,8 +134,17 @@ def compute(self) -> Any: self._check_for_increment("compute") return self[-1].compute() - def compute_all(self) -> Tensor: - """Compute the metric value for all tracked metrics.""" + def compute_all(self) -> Union[Tensor, Dict[str, Tensor]]: + """Compute the metric value for all tracked metrics. + + Return: + Either a single tensor if the tracked base object is a single metric, else if a metric collection is + provided, a dict of tensors will be returned + + Raises: + ValueError: + If `self.increment` has not been called before this method is called. 
+ """ self._check_for_increment("compute_all") # The i!=0 accounts for the self._base_metric should be ignored res = [metric.compute() for i, metric in enumerate(self) if i != 0] @@ -158,10 +167,10 @@ def best_metric( ) -> Union[ None, float, - Tuple[int, float], + Tuple[float, int], Tuple[None, None], Dict[str, Union[float, None]], - Tuple[Dict[str, Union[int, None]], Dict[str, Union[float, None]]], + Tuple[Dict[str, Union[float, None]], Dict[str, Union[int, None]]], ]: """Returns the highest metric out of all tracked. @@ -169,7 +178,21 @@ def best_metric( return_step: If ``True`` will also return the step with the highest metric value. Returns: - The best metric value, and optionally the time-step. + Either a single value or a tuple, depends on the value of ``return_step`` and the object being tracked. + + + - If a single metric is being tracked and ``return_step=False`` then a single tensor will be returned + - If a single metric is being tracked and ``return_step=True`` then a 2-element tuple will be returned, + where the first value is the optimal value and second value is the corresponding optimal step + - If a metric collection is being tracked and ``return_step=False`` then a single dict will be returned, + where keys correspond to the different values of the collection and the values are the optimal metric + value + - If a metric collection is being tracked and ``return_step=True`` then a 2-element tuple will be returned + where each is a dict, with keys corresponding to the different values of the collection and the values + of the first dict being the optimal values and the values of the second dict being the optimal step + + In addition the value in all cases may be ``None`` if the underlying metric does not have a properly defined way + of being optimal. 
+ """ if isinstance(self._base_metric, Metric): fn = torch.max if self.maximize else torch.min @@ -212,5 +235,6 @@ def best_metric( return value def _check_for_increment(self, method: str) -> None: + """Method for checking that a metric that can be updated/used for computations has been initialized.""" if not self._increment_called: raise ValueError(f"`{method}` cannot be called before `.increment()` has been called") diff --git a/tests/unittests/wrappers/test_tracker.py b/tests/unittests/wrappers/test_tracker.py index 9485cf0e276..b1aff2c33dc 100644 --- a/tests/unittests/wrappers/test_tracker.py +++ b/tests/unittests/wrappers/test_tracker.py @@ -170,7 +170,7 @@ def test_best_metric_for_not_well_defined_metric_collection(base_metric): assert best is None with pytest.warns(UserWarning, match="Encountered the following error when trying to get the best metric.*"): - idx, best = tracker.best_metric(return_step=True) + best, idx = tracker.best_metric(return_step=True) if isinstance(best, dict): assert best["MulticlassAccuracy"] is not None