Fix type hints (#1472)
* Fix return type in MetricTracker.compute_all

* Fix return type in MetricTracker.best_metric

* Fix test_best_metric_for_not_well_defined_metric_collection

* Change the MetricTracker.best_metric return value to a tuple of the form (best, idx), instead of (idx, best).

* changelog

* add docstring

---------

Co-authored-by: SkafteNicki <skaftenicki@gmail.com>
ValerianRey and SkafteNicki committed Feb 1, 2023
1 parent 2ce0efb commit fbdc393
Showing 4 changed files with 33 additions and 8 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -40,10 +40,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fix compatibility between XLA in `_bincount` function ([#1471](https://github.com/Lightning-AI/metrics/pull/1471))
- Fixed compatibility between XLA in `_bincount` function ([#1471](https://github.com/Lightning-AI/metrics/pull/1471))


-
- Fixed type hints in methods belonging to `MetricTracker` wrapper ([#1472](https://github.com/Lightning-AI/metrics/pull/1472))


## [0.11.1] - 2023-01-30
1 change: 1 addition & 0 deletions docs/source/wrappers/metric_tracker.rst
@@ -14,3 +14,4 @@ ________________

.. autoclass:: torchmetrics.MetricTracker
:noindex:
:exclude-members: update, compute
34 changes: 29 additions & 5 deletions src/torchmetrics/wrappers/tracker.py
@@ -134,8 +134,17 @@ def compute(self) -> Any:
self._check_for_increment("compute")
return self[-1].compute()

def compute_all(self) -> Tensor:
"""Compute the metric value for all tracked metrics."""
def compute_all(self) -> Union[Tensor, Dict[str, Tensor]]:
"""Compute the metric value for all tracked metrics.
Returns:
Either a single tensor if the tracked base object is a single metric, or a dict of tensors if a metric
collection is being tracked.
Raises:
ValueError:
If `self.increment` has not been called before this method is called.
"""
self._check_for_increment("compute_all")
# The i != 0 skips self._base_metric, which should be ignored when collecting tracked values
res = [metric.compute() for i, metric in enumerate(self) if i != 0]
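
For context (not part of the diff), a minimal sketch of the two return shapes the new annotation covers. The specific metrics, class counts, and tensor shapes below are illustrative assumptions only:

```python
import torch
from torchmetrics import MetricCollection, MetricTracker
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision

# Tracking a single metric: compute_all() returns one stacked Tensor,
# with one entry per call to increment().
tracker = MetricTracker(MulticlassAccuracy(num_classes=3))
for _ in range(2):
    tracker.increment()
    tracker.update(torch.randint(3, (50,)), torch.randint(3, (50,)))
print(tracker.compute_all())  # tensor of shape (2,)

# Tracking a MetricCollection: compute_all() returns a Dict[str, Tensor] instead.
tracker = MetricTracker(
    MetricCollection([MulticlassAccuracy(num_classes=3), MulticlassPrecision(num_classes=3)])
)
for _ in range(2):
    tracker.increment()
    tracker.update(torch.randint(3, (50,)), torch.randint(3, (50,)))
print(tracker.compute_all())  # {'MulticlassAccuracy': tensor(...), 'MulticlassPrecision': tensor(...)}
```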
@@ -158,18 +167,32 @@ def best_metric(
) -> Union[
None,
float,
Tuple[int, float],
Tuple[float, int],
Tuple[None, None],
Dict[str, Union[float, None]],
Tuple[Dict[str, Union[int, None]], Dict[str, Union[float, None]]],
Tuple[Dict[str, Union[float, None]], Dict[str, Union[int, None]]],
]:
"""Returns the highest metric out of all tracked.
Args:
return_step: If ``True`` will also return the step with the highest metric value.
Returns:
The best metric value, and optionally the time-step.
Either a single value or a tuple, depending on the value of ``return_step`` and the object being tracked.
- If a single metric is being tracked and ``return_step=False`` then a single value will be returned
- If a single metric is being tracked and ``return_step=True`` then a 2-element tuple will be returned,
where the first value is the optimal value and the second value is the corresponding optimal step
- If a metric collection is being tracked and ``return_step=False`` then a single dict will be returned,
where the keys correspond to the different metrics of the collection and the values are the optimal
metric values
- If a metric collection is being tracked and ``return_step=True`` then a 2-element tuple will be returned
where each element is a dict, with keys corresponding to the different metrics of the collection, the
values of the first dict being the optimal values and the values of the second dict being the optimal steps
In addition, the value in all cases may be ``None`` if the underlying metric does not have a properly
defined way of being optimal.
"""
if isinstance(self._base_metric, Metric):
fn = torch.max if self.maximize else torch.min
@@ -212,5 +235,6 @@ def best_metric(
return value

def _check_for_increment(self, method: str) -> None:
"""Method for checking that a metric that can be updated/used for computations has been intialized."""
if not self._increment_called:
raise ValueError(f"`{method}` cannot be called before `.increment()` has been called")
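
Again outside the diff, a sketch of the return forms described in the docstring above, using illustrative metrics; after this change the tuple order is (best value, best step):

```python
import torch
from torchmetrics import MetricCollection, MetricTracker
from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall

# Single metric: best_metric() -> float; best_metric(return_step=True) -> (float, int).
tracker = MetricTracker(MulticlassAccuracy(num_classes=3), maximize=True)
for _ in range(5):
    tracker.increment()
    tracker.update(torch.randint(3, (100,)), torch.randint(3, (100,)))
best, step = tracker.best_metric(return_step=True)  # (best value, step index), in that order

# Metric collection: best_metric() -> dict; best_metric(return_step=True) -> (dict, dict).
tracker = MetricTracker(
    MetricCollection([MulticlassAccuracy(num_classes=3), MulticlassRecall(num_classes=3)])
)
for _ in range(5):
    tracker.increment()
    tracker.update(torch.randint(3, (100,)), torch.randint(3, (100,)))
best, steps = tracker.best_metric(return_step=True)
print(best["MulticlassAccuracy"], steps["MulticlassAccuracy"])
```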
2 changes: 1 addition & 1 deletion tests/unittests/wrappers/test_tracker.py
@@ -170,7 +170,7 @@ def test_best_metric_for_not_well_defined_metric_collection(base_metric):
assert best is None

with pytest.warns(UserWarning, match="Encountered the following error when trying to get the best metric.*"):
idx, best = tracker.best_metric(return_step=True)
best, idx = tracker.best_metric(return_step=True)

if isinstance(best, dict):
assert best["MulticlassAccuracy"] is not None
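
For reference, a hedged sketch of the behaviour this test exercises, assuming a metric such as a confusion matrix that has no single optimal value (the exact metric choice is illustrative, not taken from the test parametrization):

```python
import torch
from torchmetrics import MetricTracker
from torchmetrics.classification import MulticlassConfusionMatrix

tracker = MetricTracker(MulticlassConfusionMatrix(num_classes=3))
for _ in range(3):
    tracker.increment()
    tracker.update(torch.randint(3, (20,)), torch.randint(3, (20,)))

# A confusion matrix has no single scalar optimum, so the tracker warns and falls
# back to None; with return_step=True the result is (None, None), returned in
# (best, idx) order after this change.
best, idx = tracker.best_metric(return_step=True)
assert best is None and idx is None
```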
