diff --git a/CHANGELOG.md b/CHANGELOG.md
index c5ddeb65e9c..df2bf32b38b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,6 +25,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed evaluation of `R2Score` with near constant target ([#1576](https://github.com/Lightning-AI/metrics/pull/1576))
 
+- Fixed dtype conversion when metric is a submodule ([#1583](https://github.com/Lightning-AI/metrics/pull/1583))
+
+
 - Fixed bug related to `top_k>1` and `ignore_index!=None` in `StatScores` based metrics ([#1589](https://github.com/Lightning-AI/metrics/pull/1589))
diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py
index 8e647d3f220..5554dbaa13a 100644
--- a/src/torchmetrics/collections.py
+++ b/src/torchmetrics/collections.py
@@ -471,3 +471,13 @@ def __repr__(self) -> str:
         if self.postfix:
             repr_str += f"{',' if not self.prefix else ''}\n  postfix={self.postfix}"
         return repr_str + "\n)"
+
+    def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "MetricCollection":
+        """Transfer all metric state to specific dtype. Special version of standard `type` method.
+
+        Arguments:
+            dst_type (type or string): the desired type.
+        """
+        for _, m in self.items(keep_base=True, copy_state=False):
+            m.set_dtype(dst_type)
+        return self
diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py
index b6c8a928ca9..13a21214fcd 100644
--- a/src/torchmetrics/metric.py
+++ b/src/torchmetrics/metric.py
@@ -129,6 +129,7 @@ def __init__(
         self._to_sync = self.sync_on_compute
         self._should_unsync = True
         self._enable_grad = False
+        self._dtype_convert = False
 
         # initialize state
         self._defaults: Dict[str, Union[List, Tensor]] = {}
@@ -619,12 +620,23 @@ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric":
         Arguments:
             dst_type (type or string): the desired type
         """
-        return super().type(dst_type)
+        self._dtype_convert = True
+        out = super().type(dst_type)
+        out._dtype_convert = False
+        return out
 
     def _apply(self, fn: Callable) -> Module:
-        """Overwrite _apply function such that we can also move metric states to the correct device when `.to`,
-        `.cuda`, etc methods are called."""
+        """Overwrite _apply function such that we can also move metric states to the correct device.
+
+        This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half` etc. methods
+        are called. Dtype conversion is guarded and will only happen through the special `set_dtype` method.
+        """
         this = super()._apply(fn)
+        fs = str(fn)
+        cond = any(f in fs for f in ["Module.type", "Module.half", "Module.float", "Module.double", "Module.bfloat16"])
+        if not self._dtype_convert and cond:
+            return this
+
         # Also apply fn to metric states and defaults
         for key, value in this._defaults.items():
             if isinstance(value, Tensor):
diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_lightning.py
index e05d71b3cb7..80bcff587c7 100644
--- a/tests/integrations/test_lightning.py
+++ b/tests/integrations/test_lightning.py
@@ -16,6 +16,7 @@
 import torch
 from pytorch_lightning import LightningModule, Trainer
 from torch import tensor
+from torch.nn import Linear
 from torch.utils.data import DataLoader
 
 from integrations.helpers import no_warning_call
@@ -300,3 +301,48 @@ def training_step(self, batch, batch_idx):
     output = model(rand_input)
     script_output = script_model(rand_input)
     assert torch.allclose(output, script_output)
+
+
+def test_dtype_in_pl_module_transfer(tmpdir):
+    """Test that metric states don't change dtype when .half() or .float() is called on the LightningModule."""
+
+    class BoringModel(LightningModule):
+        def __init__(self, metric_dtype=torch.float32):
+            super().__init__()
+            self.layer = Linear(32, 32)
+            self.metric = SumMetric()
+            self.metric.set_dtype(metric_dtype)
+
+        def forward(self, x):
+            return self.layer(x)
+
+        def training_step(self, batch, batch_idx):
+            pred = self.forward(batch)
+            loss = self(batch).sum()
+            self.metric.update(torch.flatten(pred), torch.flatten(batch))
+
+            return {"loss": loss}
+
+        def configure_optimizers(self):
+            return torch.optim.SGD(self.layer.parameters(), lr=0.1)
+
+    model = BoringModel()
+    assert model.metric.value.dtype == torch.float32
+    model = model.half()
+    assert model.metric.value.dtype == torch.float32
+
+    model = BoringModel()
+    assert model.metric.value.dtype == torch.float32
+    model = model.double()
+    assert model.metric.value.dtype == torch.float32
+
+    model = BoringModel(metric_dtype=torch.float16)
+    assert model.metric.value.dtype == torch.float16
+    model = model.float()
+    assert model.metric.value.dtype == torch.float16
+
+    model = BoringModel()
+    assert model.metric.value.dtype == torch.float32
+
+    model = model.type(torch.half)
+    assert model.metric.value.dtype == torch.float32
diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py
index 28e7b2e1de0..17ed66ee242 100644
--- a/tests/unittests/bases/test_collections.py
+++ b/tests/unittests/bases/test_collections.py
@@ -91,11 +91,11 @@ def test_device_and_dtype_transfer_metriccollection(tmpdir):
     for _, metric in metric_collection.items():
         assert metric.x.is_cuda
 
-    metric_collection = metric_collection.double()
+    metric_collection = metric_collection.set_dtype(torch.double)
     for _, metric in metric_collection.items():
         assert metric.x.dtype == torch.float64
 
-    metric_collection = metric_collection.half()
+    metric_collection = metric_collection.set_dtype(torch.half)
     for _, metric in metric_collection.items():
         assert metric.x.dtype == torch.float16
diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py
index 46ea1ca988b..71d9f9b497b 100644
--- a/tests/unittests/bases/test_metric.py
+++ b/tests/unittests/bases/test_metric.py
@@ -285,6 +285,21 @@ def test_device_and_dtype_transfer(tmpdir):
     assert metric.x.dtype == torch.float16
 
 
+def test_disable_of_normal_dtype_methods():
+    """Check that the default dtype changing methods do nothing."""
+    metric = DummyMetricSum()
+    assert metric.x.dtype == torch.float32
+
+    metric = metric.half()
+    assert metric.x.dtype == torch.float32
+
+    metric = metric.double()
+    assert metric.x.dtype == torch.float32
+
+    metric = metric.type(torch.half)
+    assert metric.x.dtype == torch.float32
+
+
 def test_warning_on_compute_before_update():
     """test that an warning is raised if user tries to call compute before update."""
     metric = DummyMetricSum()
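Usage sketch (not part of the patch above; assumes torchmetrics with this change applied and the `SumMetric` aggregation metric, whose state tensor is `value`): blanket dtype casts inherited from `nn.Module` now leave metric states untouched, while the explicit `set_dtype` call still converts them.

    import torch
    from torchmetrics import SumMetric

    metric = SumMetric()
    metric.update(torch.tensor([1.0, 2.0]))

    # Module-wide casts such as .half() no longer touch the metric state ...
    metric = metric.half()
    assert metric.value.dtype == torch.float32

    # ... whereas the dedicated set_dtype method still converts it.
    metric = metric.set_dtype(torch.half)
    assert metric.value.dtype == torch.float16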