Prevent metric dtype conversion when metric is part of `LightningModule` (#1583)

Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>

(cherry picked from commit 21b23b6)
FarzanT authored and Borda committed Mar 10, 2023
1 parent d4c7e28 commit 377d5dc
Showing 6 changed files with 91 additions and 5 deletions.
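In practice, the change means that calling `nn.Module` dtype methods such as `.half()` on a module that contains a metric no longer touches the metric's states. A minimal before/after sketch (not part of the diff; `Model` is a hypothetical module and `SumMetric` is used only for illustration):

```python
import torch
from torch import nn
from torchmetrics import SumMetric


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)
        self.metric = SumMetric()  # metric state `value` is created as float32


model = Model().half()
# Before this fix, the metric state was converted along with the parameters
# (model.metric.value.dtype == torch.float16), which could break later updates.
# After this fix, the state keeps its dtype until `set_dtype` is called explicitly.
assert model.metric.value.dtype == torch.float32
```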
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -25,6 +25,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed evaluation of `R2Score` with near constant target ([#1576](https://github.com/Lightning-AI/metrics/pull/1576))


- Fixed dtype conversion when metric is submodule ([#1583](https://github.com/Lightning-AI/metrics/pull/1583))


- Fixed bug related to `top_k>1` and `ignore_index!=None` in `StatScores` based metrics ([#1589](https://github.com/Lightning-AI/metrics/pull/1589))


10 changes: 10 additions & 0 deletions src/torchmetrics/collections.py
@@ -471,3 +471,13 @@ def __repr__(self) -> str:
        if self.postfix:
            repr_str += f"{',' if not self.prefix else ''}\n postfix={self.postfix}"
        return repr_str + "\n)"

    def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "MetricCollection":
        """Transfer all metric state to a specific dtype. Special version of the standard `type` method.

        Arguments:
            dst_type (type or string): the desired type.
        """
        for _, m in self.items(keep_base=True, copy_state=False):
            m.set_dtype(dst_type)
        return self
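A short usage sketch of the new collection method (not part of the diff; the metric names and the `SumMetric`/`MeanMetric` choices are illustrative):

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.aggregation import MeanMetric, SumMetric

collection = MetricCollection({"sum": SumMetric(), "mean": MeanMetric()})

# set_dtype walks every metric in the collection and converts its states,
# whereas .double()/.half() on the collection now leave the states alone.
collection = collection.set_dtype(torch.double)
assert all(m.value.dtype == torch.float64 for m in collection.values())
```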
18 changes: 15 additions & 3 deletions src/torchmetrics/metric.py
@@ -129,6 +129,7 @@ def __init__(
        self._to_sync = self.sync_on_compute
        self._should_unsync = True
        self._enable_grad = False
+        self._dtype_convert = False

        # initialize state
        self._defaults: Dict[str, Union[List, Tensor]] = {}
@@ -619,12 +620,23 @@ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric":
        Arguments:
            dst_type (type or string): the desired type
        """
-        return super().type(dst_type)
+        self._dtype_convert = True
+        out = super().type(dst_type)
+        out._dtype_convert = False
+        return out

    def _apply(self, fn: Callable) -> Module:
-        """Overwrite _apply function such that we can also move metric states to the correct device when `.to`,
-        `.cuda`, etc methods are called."""
+        """Overwrite _apply function such that we can also move metric states to the correct device.
+
+        This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half`, etc.
+        methods are called. Dtype conversion is guarded and will only happen through the special `set_dtype` method.
+        """
        this = super()._apply(fn)
+        fs = str(fn)
+        cond = any(f in fs for f in ["Module.type", "Module.half", "Module.float", "Module.double", "Module.bfloat16"])
+        if not self._dtype_convert and cond:
+            return this

        # Also apply fn to metric states and defaults
        for key, value in this._defaults.items():
            if isinstance(value, Tensor):
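The effect of the guard at the `Metric` level, as a minimal sketch (not part of the diff; it relies on `SumMetric`'s state tensor being named `value`, as in the test below):

```python
import torch
from torchmetrics import SumMetric

metric = SumMetric()
metric.update(torch.tensor(2.0))

# .half() reaches _apply with a function whose repr contains "Module.half",
# so the guarded branch returns early and the state dtype is untouched.
metric = metric.half()
assert metric.value.dtype == torch.float32

# set_dtype temporarily sets _dtype_convert, so the conversion is applied.
metric = metric.set_dtype(torch.half)
assert metric.value.dtype == torch.float16
```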
46 changes: 46 additions & 0 deletions tests/integrations/test_lightning.py
@@ -16,6 +16,7 @@
import torch
from pytorch_lightning import LightningModule, Trainer
from torch import tensor
from torch.nn import Linear
from torch.utils.data import DataLoader

from integrations.helpers import no_warning_call
@@ -300,3 +301,48 @@ def training_step(self, batch, batch_idx):
    output = model(rand_input)
    script_output = script_model(rand_input)
    assert torch.allclose(output, script_output)


def test_dtype_in_pl_module_transfer(tmpdir):
    """Test that metric states don't change dtype when `.half()` or `.float()` is called on the LightningModule."""

    class BoringModel(LightningModule):
        def __init__(self, metric_dtype=torch.float32):
            super().__init__()
            self.layer = Linear(32, 32)
            self.metric = SumMetric()
            self.metric.set_dtype(metric_dtype)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            pred = self.forward(batch)
            loss = pred.sum()
            self.metric.update(torch.flatten(pred))
            return {"loss": loss}

        def configure_optimizers(self):
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    model = BoringModel()
    assert model.metric.value.dtype == torch.float32
    model = model.half()
    assert model.metric.value.dtype == torch.float32

    model = BoringModel()
    assert model.metric.value.dtype == torch.float32
    model = model.double()
    assert model.metric.value.dtype == torch.float32

    model = BoringModel(metric_dtype=torch.float16)
    assert model.metric.value.dtype == torch.float16
    model = model.float()
    assert model.metric.value.dtype == torch.float16

    model = BoringModel()
    assert model.metric.value.dtype == torch.float32

    model = model.type(torch.half)
    assert model.metric.value.dtype == torch.float32
4 changes: 2 additions & 2 deletions tests/unittests/bases/test_collections.py
@@ -91,11 +91,11 @@ def test_device_and_dtype_transfer_metriccollection(tmpdir):
    for _, metric in metric_collection.items():
        assert metric.x.is_cuda

-    metric_collection = metric_collection.double()
+    metric_collection = metric_collection.set_dtype(torch.double)
    for _, metric in metric_collection.items():
        assert metric.x.dtype == torch.float64

-    metric_collection = metric_collection.half()
+    metric_collection = metric_collection.set_dtype(torch.half)
    for _, metric in metric_collection.items():
        assert metric.x.dtype == torch.float16

15 changes: 15 additions & 0 deletions tests/unittests/bases/test_metric.py
@@ -285,6 +285,21 @@ def test_device_and_dtype_transfer(tmpdir):
    assert metric.x.dtype == torch.float16


def test_disable_of_normal_dtype_methods():
    """Check that the default dtype-changing methods do nothing."""
    metric = DummyMetricSum()
    assert metric.x.dtype == torch.float32

    metric = metric.half()
    assert metric.x.dtype == torch.float32

    metric = metric.double()
    assert metric.x.dtype == torch.float32

    metric = metric.type(torch.half)
    assert metric.x.dtype == torch.float32


def test_warning_on_compute_before_update():
    """Test that a warning is raised if the user tries to call compute before update."""
    metric = DummyMetricSum()
