From ac9520da2ebf4f5460eac02c92f6e978c43ec708 Mon Sep 17 00:00:00 2001
From: Farzan Taj
Date: Fri, 3 Mar 2023 11:49:17 -0500
Subject: [PATCH 1/8] Added check in Metric._apply() to stop dtype conversions on metric states and defaults

---
 src/torchmetrics/metric.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py
index 3bca6ff0ce6..a29995f70a6 100644
--- a/src/torchmetrics/metric.py
+++ b/src/torchmetrics/metric.py
@@ -700,6 +700,10 @@ def _apply(self, fn: Callable) -> Module:
         This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, etc. methods are called.
         """
         this = super()._apply(fn)
+
+        if "Module.half" in str(fn) or "Module.float" in str(fn):
+            return this
+
         # Also apply fn to metric states and defaults
         for key, value in this._defaults.items():
             if isinstance(value, Tensor):

From 257eb46f8afedad6c10193ed2a63aee27c2f4be7 Mon Sep 17 00:00:00 2001
From: Farzan Taj
Date: Fri, 3 Mar 2023 12:18:29 -0500
Subject: [PATCH 2/8] Updated docstring

---
 src/torchmetrics/metric.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py
index a29995f70a6..465075d6600 100644
--- a/src/torchmetrics/metric.py
+++ b/src/torchmetrics/metric.py
@@ -697,7 +697,8 @@ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric":
 
     def _apply(self, fn: Callable) -> Module:
         """Overwrite _apply function such that we can also move metric states to the correct device.
 
-        This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, etc. methods are called.
+        This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, etc. methods are called;
+        however, `.half()` and `.float()` calls are not applied to metric states and defaults.
""" this = super()._apply(fn) From 48e38cfd70448466f2394bf15618e3fe5ccbcc39 Mon Sep 17 00:00:00 2001 From: Farzan Taj Date: Fri, 3 Mar 2023 12:57:34 -0500 Subject: [PATCH 3/8] Unit test for metric state dtype change in LightningModule --- tests/unittests/bases/test_metric.py | 36 +++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py index 8792779a719..e0ffe5bba2e 100644 --- a/tests/unittests/bases/test_metric.py +++ b/tests/unittests/bases/test_metric.py @@ -22,7 +22,8 @@ import pytest import torch from torch import Tensor, tensor -from torch.nn import Module +from torch.nn import Module, Linear +from pytorch_lightning import LightningModule from torchmetrics import PearsonCorrCoef from torchmetrics.classification import BinaryAccuracy @@ -296,6 +297,39 @@ def test_device_and_dtype_transfer(tmpdir): assert metric.x.dtype == torch.float16 +def test_dtype_in_pl_module_transfer(tmpdir): + """Test that metric states don't change dtype when .half() or .float() is called on the LightningModule.""" + class BoringModel(LightningModule): + def __init__(self, metric_dtype=torch.float32): + super().__init__() + self.layer = Linear(32, 32) + self.metric = DummyMetricSum() + self.metric.set_dtype(metric_dtype) + + def forward(self, x): + return self.layer(x) + + def training_step(self, batch, batch_idx): + pred = self.forward(batch) + loss = self(batch).sum() + self.metric.update(torch.flatten(pred), torch.flatten(batch)) + + return {"loss": loss} + + def configure_optimizers(self): + return torch.optim.SGD(self.layer.parameters(), lr=0.1) + + model = BoringModel() + assert model.metric.x.dtype == torch.float32 + model = model.half() + assert model.metric.x.dtype == torch.float32 + + model = BoringModel(metric_dtype=torch.float16) + assert model.metric.x.dtype == torch.float16 + model = model.float() + assert model.metric.x.dtype == torch.float16 + + def test_warning_on_compute_before_update(): """Test that an warning is raised if user tries to call compute before update.""" metric = DummyMetricSum() From b85df2a1850e8c3a7538317d7a8fb9bab6341468 Mon Sep 17 00:00:00 2001 From: Farzan Taj Date: Fri, 3 Mar 2023 13:06:01 -0500 Subject: [PATCH 4/8] Formatting fix with pre-commit --- tests/unittests/bases/test_metric.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py index e0ffe5bba2e..7543d6527f7 100644 --- a/tests/unittests/bases/test_metric.py +++ b/tests/unittests/bases/test_metric.py @@ -21,9 +21,9 @@ import psutil import pytest import torch -from torch import Tensor, tensor -from torch.nn import Module, Linear from pytorch_lightning import LightningModule +from torch import Tensor, tensor +from torch.nn import Linear, Module from torchmetrics import PearsonCorrCoef from torchmetrics.classification import BinaryAccuracy @@ -299,6 +299,7 @@ def test_device_and_dtype_transfer(tmpdir): def test_dtype_in_pl_module_transfer(tmpdir): """Test that metric states don't change dtype when .half() or .float() is called on the LightningModule.""" + class BoringModel(LightningModule): def __init__(self, metric_dtype=torch.float32): super().__init__() @@ -318,7 +319,7 @@ def training_step(self, batch, batch_idx): def configure_optimizers(self): return torch.optim.SGD(self.layer.parameters(), lr=0.1) - + model = BoringModel() assert model.metric.x.dtype == torch.float32 model = model.half() From 
From dcd0a86af0c99b95536053a486241820033a3384 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Sat, 4 Mar 2023 19:13:18 +0100
Subject: [PATCH 5/8] changelog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d1a55c77ce7..cc4828afa2b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -102,6 +102,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed evaluation of `R2Score` with near constant target ([#1576](https://github.com/Lightning-AI/metrics/pull/1576))
 
 
+- Fixed dtype conversion when metric is submodule ([#1583](https://github.com/Lightning-AI/metrics/pull/1583))
+
+
 ## [0.11.2] - 2023-02-21
 
 ### Fixed

From 6aa9ffd009fd918815204c140d017a40456cc0a3 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Sat, 4 Mar 2023 20:33:19 +0100
Subject: [PATCH 6/8] add more cases

---
 src/torchmetrics/collections.py | 10 ++++++++++
 src/torchmetrics/metric.py      | 15 ++++++++++-----
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py
index ab381007425..8a7c443482d 100644
--- a/src/torchmetrics/collections.py
+++ b/src/torchmetrics/collections.py
@@ -472,3 +472,13 @@ def __repr__(self) -> str:
         if self.postfix:
             repr_str += f"{',' if not self.prefix else ''}\n  postfix={self.postfix}"
         return repr_str + "\n)"
+
+    def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "MetricCollection":
+        """Transfer all metric state to a specific dtype. Special version of the standard `type` method.
+
+        Arguments:
+            dst_type (type or string): the desired type.
+        """
+        for _, m in self.items(keep_base=True, copy_state=False):
+            m.set_dtype(dst_type)
+        return self

diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py
index 465075d6600..efda07a076b 100644
--- a/src/torchmetrics/metric.py
+++ b/src/torchmetrics/metric.py
@@ -145,6 +145,7 @@ def __init__(
         self._to_sync = self.sync_on_compute
         self._should_unsync = True
         self._enable_grad = False
+        self._dtype_convert = False
 
         # initialize state
         self._defaults: Dict[str, Union[List, Tensor]] = {}
@@ -692,17 +693,21 @@ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric":
         Arguments:
             dst_type (type or string): the desired type.
         """
-        return super().type(dst_type)
+        self._dtype_convert = True
+        out = super().type(dst_type)
+        out._dtype_convert = False
+        return out
 
     def _apply(self, fn: Callable) -> Module:
         """Overwrite _apply function such that we can also move metric states to the correct device.
 
-        This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, etc. methods are called;
-        however, `.half()` and `.float()` calls are not applied to metric states and defaults.
+        This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half`, etc. methods
+        are called. Dtype conversion is guarded and will only happen through the special `set_dtype` method.
""" this = super()._apply(fn) - - if "Module.half" in str(fn) or "Module.float" in str(fn): + fs = str(fn) + cond = any(f in fs for f in ["Module.type", "Module.half", "Module.float", "Module.double", "Module.bfloat16"]) + if not self._dtype_convert and cond: return this # Also apply fn to metric states and defaults From 1e3530159d6851b32579abeec95dbdb3cae7b777 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 4 Mar 2023 20:33:47 +0100 Subject: [PATCH 7/8] fix cases --- tests/integrations/test_lightning.py | 46 +++++++++++++++++++++++ tests/unittests/bases/test_collections.py | 4 +- tests/unittests/bases/test_metric.py | 43 +++++++-------------- 3 files changed, 61 insertions(+), 32 deletions(-) diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_lightning.py index 6652d0a3114..96b852f0b65 100644 --- a/tests/integrations/test_lightning.py +++ b/tests/integrations/test_lightning.py @@ -16,6 +16,7 @@ import torch from pytorch_lightning import LightningModule, Trainer from torch import tensor +from torch.nn import Linear from torch.utils.data import DataLoader from integrations.helpers import no_warning_call @@ -302,3 +303,48 @@ def training_step(self, batch, batch_idx): output = model(rand_input) script_output = script_model(rand_input) assert torch.allclose(output, script_output) + + +def test_dtype_in_pl_module_transfer(tmpdir): + """Test that metric states don't change dtype when .half() or .float() is called on the LightningModule.""" + + class BoringModel(LightningModule): + def __init__(self, metric_dtype=torch.float32): + super().__init__() + self.layer = Linear(32, 32) + self.metric = SumMetric() + self.metric.set_dtype(metric_dtype) + + def forward(self, x): + return self.layer(x) + + def training_step(self, batch, batch_idx): + pred = self.forward(batch) + loss = self(batch).sum() + self.metric.update(torch.flatten(pred), torch.flatten(batch)) + + return {"loss": loss} + + def configure_optimizers(self): + return torch.optim.SGD(self.layer.parameters(), lr=0.1) + + model = BoringModel() + assert model.metric.value.dtype == torch.float32 + model = model.half() + assert model.metric.value.dtype == torch.float32 + + model = BoringModel() + assert model.metric.value.dtype == torch.float32 + model = model.double() + assert model.metric.value.dtype == torch.float32 + + model = BoringModel(metric_dtype=torch.float16) + assert model.metric.value.dtype == torch.float16 + model = model.float() + assert model.metric.value.dtype == torch.float16 + + model = BoringModel() + assert model.metric.value.dtype == torch.float32 + + model = model.type(torch.half) + assert model.metric.value.dtype == torch.float32 diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 30cfe857c68..a58380e16c6 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -91,11 +91,11 @@ def test_device_and_dtype_transfer_metriccollection(tmpdir): for _, metric in metric_collection.items(): assert metric.x.is_cuda - metric_collection = metric_collection.double() + metric_collection = metric_collection.set_dtype(torch.double) for _, metric in metric_collection.items(): assert metric.x.dtype == torch.float64 - metric_collection = metric_collection.half() + metric_collection = metric_collection.set_dtype(torch.half) for _, metric in metric_collection.items(): assert metric.x.dtype == torch.float16 diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py index 
index 7543d6527f7..2c9c333e081 100644
--- a/tests/unittests/bases/test_metric.py
+++ b/tests/unittests/bases/test_metric.py
@@ -21,9 +21,8 @@
 import psutil
 import pytest
 import torch
-from pytorch_lightning import LightningModule
 from torch import Tensor, tensor
-from torch.nn import Linear, Module
+from torch.nn import Module
 
 from torchmetrics import PearsonCorrCoef
 from torchmetrics.classification import BinaryAccuracy
@@ -295,40 +294,24 @@ def test_device_and_dtype_transfer(tmpdir):
     assert metric.x.dtype == torch.float16
     metric.reset()
     assert metric.x.dtype == torch.float16
+    import pdb
+
+    pdb.set_trace()
 
 
-def test_dtype_in_pl_module_transfer(tmpdir):
-    """Test that metric states don't change dtype when .half() or .float() is called on the LightningModule."""
-
-    class BoringModel(LightningModule):
-        def __init__(self, metric_dtype=torch.float32):
-            super().__init__()
-            self.layer = Linear(32, 32)
-            self.metric = DummyMetricSum()
-            self.metric.set_dtype(metric_dtype)
-
-        def forward(self, x):
-            return self.layer(x)
-
-        def training_step(self, batch, batch_idx):
-            pred = self.forward(batch)
-            loss = self(batch).sum()
-            self.metric.update(torch.flatten(pred), torch.flatten(batch))
-
-            return {"loss": loss}
-
-        def configure_optimizers(self):
-            return torch.optim.SGD(self.layer.parameters(), lr=0.1)
-
-    model = BoringModel()
-    assert model.metric.x.dtype == torch.float32
-    model = model.half()
-    assert model.metric.x.dtype == torch.float32
-
-    model = BoringModel(metric_dtype=torch.float16)
-    assert model.metric.x.dtype == torch.float16
-    model = model.float()
-    assert model.metric.x.dtype == torch.float16
+def test_disable_of_normal_dtype_methods():
+    """Check that the default dtype-changing methods do nothing."""
+    metric = DummyMetricSum()
+    assert metric.x.dtype == torch.float32
+
+    metric = metric.half()
+    assert metric.x.dtype == torch.float32
+
+    metric = metric.double()
+    assert metric.x.dtype == torch.float32
+
+    metric = metric.type(torch.half)
+    assert metric.x.dtype == torch.float32
 
 
 def test_warning_on_compute_before_update():

From cfee14a68d48390a522886f26d4ea56e8792cf59 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 7 Mar 2023 12:05:54 +0100
Subject: [PATCH 8/8] Update tests/unittests/bases/test_metric.py

---
 tests/unittests/bases/test_metric.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py
index b88081bd480..2af3de42668 100644
--- a/tests/unittests/bases/test_metric.py
+++ b/tests/unittests/bases/test_metric.py
@@ -308,9 +308,6 @@ def test_device_and_dtype_transfer(tmpdir):
     assert metric.x.dtype == torch.float16
     metric.reset()
     assert metric.x.dtype == torch.float16
-    import pdb
-
-    pdb.set_trace()
 
 
 def test_disable_of_normal_dtype_methods():
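
Usage note: a minimal sketch of the end-to-end behavior this series introduces, assuming a torchmetrics build that includes these patches (any release after 0.11.2). `SumMetric`, its `value` state, and `Metric.set_dtype` are the APIs exercised by the tests above; the `Wrapper` module here is a hypothetical stand-in for the Lightning `BoringModel` used in the integration test.

    import torch
    from torch.nn import Linear, Module
    from torchmetrics import SumMetric

    class Wrapper(Module):  # hypothetical host module holding a metric as submodule
        def __init__(self) -> None:
            super().__init__()
            self.layer = Linear(4, 4)
            self.metric = SumMetric()

    model = Wrapper()
    model.metric.update(torch.tensor(2.0))

    # .half() converts the module parameters, but the guarded Metric._apply
    # leaves the metric state and defaults untouched
    model = model.half()
    assert model.layer.weight.dtype == torch.float16
    assert model.metric.value.dtype == torch.float32

    # dtype conversion of metric state only happens through set_dtype
    model.metric.set_dtype(torch.half)
    assert model.metric.value.dtype == torch.float16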