Prevent metric dtype conversion when metric is part of LightningModule #1583

Merged
17 commits merged on Mar 7, 2023
Changes from 12 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -103,6 +103,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed evaluation of `R2Score` with near constant target ([#1576](https://github.com/Lightning-AI/metrics/pull/1576))


- Fixed dtype conversion when metric is a submodule ([#1583](https://github.com/Lightning-AI/metrics/pull/1583))


- Fixed corner case for `PearsonCorrCoef` when running in ddp mode but only on single device ([#1587](https://github.com/Lightning-AI/metrics/pull/1587))


10 changes: 10 additions & 0 deletions src/torchmetrics/collections.py
@@ -472,3 +472,13 @@ def __repr__(self) -> str:
        if self.postfix:
            repr_str += f"{',' if not self.prefix else ''}\n postfix={self.postfix}"
        return repr_str + "\n)"

    def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "MetricCollection":
        """Transfer all metric state to specific dtype. Special version of standard `type` method.

        Arguments:
            dst_type (type or string): the desired type.
        """
        for _, m in self.items(keep_base=True, copy_state=False):
            m.set_dtype(dst_type)
        return self
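The new `MetricCollection.set_dtype` simply forwards the call to every contained metric and returns the collection. A minimal usage sketch (the specific metrics below are illustrative, not part of this PR):

import torch
from torchmetrics import MetricCollection
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError

# Convert the internal states of every metric in the collection to float64.
collection = MetricCollection({"mae": MeanAbsoluteError(), "mse": MeanSquaredError()})
collection = collection.set_dtype(torch.double)

collection.update(torch.rand(10), torch.rand(10))
print(collection.compute())  # both 'mae' and 'mse' results are float64 tensors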
14 changes: 12 additions & 2 deletions src/torchmetrics/metric.py
@@ -145,6 +145,7 @@ def __init__(
        self._to_sync = self.sync_on_compute
        self._should_unsync = True
        self._enable_grad = False
        self._dtype_convert = False

        # initialize state
        self._defaults: Dict[str, Union[List, Tensor]] = {}
@@ -692,14 +693,23 @@ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric":
        Arguments:
            dst_type (type or string): the desired type.
        """
        return super().type(dst_type)
        self._dtype_convert = True
        out = super().type(dst_type)
        out._dtype_convert = False
        return out

    def _apply(self, fn: Callable) -> Module:
        """Overwrite _apply function such that we can also move metric states to the correct device.

        This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, etc. methods are called.
        This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half` etc. methods
        are called. Dtype conversion is guarded and will only happen through the special `set_dtype` method.
        """
        this = super()._apply(fn)
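        # ``Module.type``, ``Module.half``, ``Module.float``, ``Module.double`` and ``Module.bfloat16`` call
        # ``_apply`` with a lambda whose qualified name contains the originating method, so its string
        # representation can be used to detect dtype-converting calls.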
        fs = str(fn)
        cond = any(f in fs for f in ["Module.type", "Module.half", "Module.float", "Module.double", "Module.bfloat16"])
        if not self._dtype_convert and cond:
            return this

        # Also apply fn to metric states and defaults
        for key, value in this._defaults.items():
            if isinstance(value, Tensor):
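With this guard in place, implicit dtype casts that reach a metric through ``nn.Module`` (e.g. calling ``.half()`` on a parent ``LightningModule``) leave the metric states untouched, while an explicit ``set_dtype`` still converts them. A small sketch of the resulting behaviour, mirroring the tests added below:

import torch
from torchmetrics.aggregation import SumMetric

metric = SumMetric()
print(metric.value.dtype)   # torch.float32

# Implicit conversion via the standard nn.Module methods is now a no-op for metric states ...
metric = metric.half()
print(metric.value.dtype)   # still torch.float32

# ... whereas the dedicated set_dtype method performs the conversion explicitly.
metric = metric.set_dtype(torch.float16)
print(metric.value.dtype)   # torch.float16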
46 changes: 46 additions & 0 deletions tests/integrations/test_lightning.py
@@ -16,6 +16,7 @@
import torch
from pytorch_lightning import LightningModule, Trainer
from torch import tensor
from torch.nn import Linear
from torch.utils.data import DataLoader

from integrations.helpers import no_warning_call
@@ -304,3 +305,48 @@ def training_step(self, batch, batch_idx):
    output = model(rand_input)
    script_output = script_model(rand_input)
    assert torch.allclose(output, script_output)


def test_dtype_in_pl_module_transfer(tmpdir):
    """Test that metric states don't change dtype when .half() or .float() is called on the LightningModule."""

    class BoringModel(LightningModule):
        def __init__(self, metric_dtype=torch.float32):
            super().__init__()
            self.layer = Linear(32, 32)
            self.metric = SumMetric()
            self.metric.set_dtype(metric_dtype)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            pred = self.forward(batch)
            loss = self(batch).sum()
            self.metric.update(torch.flatten(pred))

            return {"loss": loss}

        def configure_optimizers(self):
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    model = BoringModel()
    assert model.metric.value.dtype == torch.float32
    model = model.half()
    assert model.metric.value.dtype == torch.float32

    model = BoringModel()
    assert model.metric.value.dtype == torch.float32
    model = model.double()
    assert model.metric.value.dtype == torch.float32

    model = BoringModel(metric_dtype=torch.float16)
    assert model.metric.value.dtype == torch.float16
    model = model.float()
    assert model.metric.value.dtype == torch.float16

    model = BoringModel()
    assert model.metric.value.dtype == torch.float32

    model = model.type(torch.half)
    assert model.metric.value.dtype == torch.float32
4 changes: 2 additions & 2 deletions tests/unittests/bases/test_collections.py
@@ -93,11 +93,11 @@ def test_device_and_dtype_transfer_metriccollection(tmpdir):
    for _, metric in metric_collection.items():
        assert metric.x.is_cuda

    metric_collection = metric_collection.double()
    metric_collection = metric_collection.set_dtype(torch.double)
    for _, metric in metric_collection.items():
        assert metric.x.dtype == torch.float64

    metric_collection = metric_collection.half()
    metric_collection = metric_collection.set_dtype(torch.half)
    for _, metric in metric_collection.items():
        assert metric.x.dtype == torch.float16

18 changes: 18 additions & 0 deletions tests/unittests/bases/test_metric.py
@@ -308,6 +308,24 @@ def test_device_and_dtype_transfer(tmpdir):
    assert metric.x.dtype == torch.float16
    metric.reset()
    assert metric.x.dtype == torch.float16


def test_disable_of_normal_dtype_methods():
"""Check that the default dtype changing methods does nothing."""
    metric = DummyMetricSum()
    assert metric.x.dtype == torch.float32

    metric = metric.half()
    assert metric.x.dtype == torch.float32

    metric = metric.double()
    assert metric.x.dtype == torch.float32

    metric = metric.type(torch.half)
    assert metric.x.dtype == torch.float32


def test_warning_on_compute_before_update():