
[MPS] Gradient error with LeakyRELU/Mish/GELU + .sum() on 1D and scalar tensors #123178

Closed
jtang98 opened this issue Apr 2, 2024 · 4 comments
Labels
- module: correctness (silent): issue that returns an incorrect result silently
- module: mps: Related to Apple Metal Performance Shaders framework
- module: regression: It used to work, and now it doesn't
- triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module


@jtang98
Contributor

jtang98 commented Apr 2, 2024

🐛 Describe the bug

Hi,

Calling `.sum()` on a scalar or 1D tensor that went through LeakyReLU, GELU, or Mish produces wrong gradients on MPS.

ReLU is fine.
Tensors with 2 or more dimensions are fine as well.

This might be related to #117826.

Snippet of code to reproduce the issue:

```
import torch
import torch.nn.functional as F

# Compare the gradient of func(x).sum() on MPS vs CPU; the last entry is the identity.
for func in [F.relu, F.leaky_relu, F.gelu, F.mish, lambda x: x]:
    for device in ['mps', 'cpu']:
        x1 = torch.tensor(3.0).to(device)  # Wrong gradient on MPS for leaky_relu, gelu and mish
        # x1 = torch.tensor([3.0, -3.1]).to(device)  # Wrong gradient on MPS for leaky_relu, gelu and mish
        # x1 = torch.tensor([[3.0, -3.1], [2.1, 3.4]]).to(device)  # Correct for all functions
        x1.requires_grad = True
        y1 = func(x1).sum()
        y1.backward()
        print("Gradient on " + device + ":")
        print(x1.grad)
    print('----')
```
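Because `y1 = func(x1).sum()` seeds the backward pass with a gradient of ones, `x1.grad` should equal the elementwise derivative of `func` whether `x1` is a scalar, 1D, or 2D tensor. As a backend-independent cross-check, here is a minimal pure-Python sketch (no PyTorch; all helper names are hypothetical) that estimates that gradient with central finite differences for the 1D input `[3.0, -3.1]`. It assumes PyTorch's defaults: `negative_slope=0.01` for `leaky_relu` and the exact erf-based form of `gelu`.

```python
import math

# Pure-Python re-implementations of the activations, assuming PyTorch defaults
# (leaky_relu negative_slope=0.01, exact erf-based gelu).
def leaky_relu(x, slope=0.01):
    return x if x >= 0 else slope * x

def gelu(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def mish(x):
    # mish(x) = x * tanh(softplus(x)), with softplus(x) = log(1 + exp(x))
    return x * math.tanh(math.log1p(math.exp(x)))

def sum_of(func, xs):
    # Scalar loss playing the role of func(x1).sum() in the repro
    return sum(func(v) for v in xs)

def finite_diff_grad(func, xs, h=1e-5):
    # Central-difference estimate of d sum_of(func, xs) / d xs[i] for each i
    grads = []
    for i in range(len(xs)):
        plus = list(xs)
        minus = list(xs)
        plus[i] += h
        minus[i] -= h
        grads.append((sum_of(func, plus) - sum_of(func, minus)) / (2 * h))
    return grads

x = [3.0, -3.1]  # the 1D input from the commented-out line in the repro
for name, f in [("leaky_relu", leaky_relu), ("gelu", gelu), ("mish", mish)]:
    print(name, [round(g, 4) for g in finite_diff_grad(f, x)])
```

Since the sum is separable, each entry depends only on its own element, so the estimate for `3.0` is the same in the 1D case as in the scalar case.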

I'm happy to help once this is confirmed to be a bug, and would welcome input on the possible root cause.

Versions

PyTorch version: 2.4.0a0+git0ff6d76
Is debug build: True
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.4.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.28.4
Libc version: N/A

Python version: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:49:36) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-14.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.8.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.0
[pip3] optree==0.10.0
[pip3] torch==2.4.0a0+gitb27ee65
[conda] numpy 1.26.0 pypi_0 pypi
[conda] optree 0.10.0 pypi_0 pypi
[conda] torch 2.4.0a0+gitb27ee65 dev_0
[conda] torchfix 0.4.0 pypi_0 pypi

The build was done by running: `DEBUG=1 USE_DISTRIBUTED=0 USE_MKLDNN=0 USE_CUDA=0 BUILD_TEST=0 USE_FBGEMM=0 USE_NNPACK=0 USE_QNNPACK=0 USE_XNNPACK=0 python setup.py develop`

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen @razarmehr

@MariaPonomarenko38

Hi @jtang98! Could you share your output for `x1 = torch.tensor(3.0).to(device)`?

@jtang98
Contributor Author

jtang98 commented Apr 2, 2024

Sure, here it is (one block per function, in loop order: `relu`, `leaky_relu`, `gelu`, `mish`, identity):

```
Gradient on mps:
tensor(1., device='mps:0')
Gradient on cpu:
tensor(1.)
----
Gradient on mps:
tensor(3., device='mps:0')
Gradient on cpu:
tensor(1.)
----
Gradient on mps:
tensor(3.0317, device='mps:0')
Gradient on cpu:
tensor(1.0119)
----
Gradient on mps:
tensor(3.0496, device='mps:0')
Gradient on cpu:
tensor(1.0211)
----
Gradient on mps:
tensor(1., device='mps:0')
Gradient on cpu:
tensor(1.)
```
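For reference, the CPU values above agree (to four decimals) with the closed-form derivatives at x = 3.0, while the MPS values for `leaky_relu`, `gelu`, and `mish` do not. A minimal pure-Python sketch (no PyTorch; the function names are hypothetical, and PyTorch's default `negative_slope=0.01` and exact erf-based GELU are assumed):

```python
import math

def d_relu(x):
    # ReLU'(x) is 1 for x > 0, else 0
    return 1.0 if x > 0 else 0.0

def d_leaky_relu(x, slope=0.01):
    # LeakyReLU'(x) is 1 for x > 0, else `slope`
    return 1.0 if x > 0 else slope

def d_gelu(x):
    # GELU(x) = x * Phi(x)  =>  GELU'(x) = Phi(x) + x * phi(x)
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf

def d_mish(x):
    # Mish(x) = x * tanh(sp), sp = softplus(x)
    # Mish'(x) = tanh(sp) + x * (1 - tanh(sp)^2) * sigmoid(x)
    sp = math.log1p(math.exp(x))
    t = math.tanh(sp)
    sigmoid = 1.0 / (1.0 + math.exp(-x))
    return t + x * (1.0 - t * t) * sigmoid

x = 3.0
for name, value in [
    ("relu", d_relu(x)),
    ("leaky_relu", d_leaky_relu(x)),
    ("gelu", d_gelu(x)),
    ("mish", d_mish(x)),
    ("identity", 1.0),  # d/dx of x is 1
]:
    print(f"{name}: {value:.4f}")
```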

@malfet malfet added module: mps Related to Apple Metal Performance Shaders framework module: regression It used to work, and now it doesn't module: correctness (silent) issue that returns an incorrect result silently labels Apr 2, 2024
@malfet
Contributor

malfet commented Apr 2, 2024

I suspect it's the same regression as another macOS 14.4.1 problem with scalar tensors.

@jbschlosser jbschlosser added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Apr 2, 2024
pytorchmergebot pushed a commit that referenced this issue Apr 3, 2024
Fixes #122016 and #123178. This regression is related to an OS-side change that requires a slight adjustment on the PyTorch side to restore the previous behavior. Additionally, we removed workarounds for macOS versions older than 13.

Before the fix, on macOS 14.4:

```
python -c "import torch;x=torch.zeros(3, device='mps');x[1] = 1; x[2] = 3; print(x)"
tensor([0., 3., 3.], device='mps:0')
```

After the fix:
```
python -c "import torch;x=torch.zeros(3, device='mps');x[1] = 1; x[2] = 3; print(x)"
tensor([0., 1., 3.], device='mps:0')
```

This also fixes complex number initialization, which makes `nn.functional.rms_norm` pass on macOS 14+.

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: #123234
Approved by: https://github.com/malfet, https://github.com/kulinseth
@malfet
Contributor

malfet commented Apr 4, 2024

Validated that it was fixed by #123234.
Closing, but please don't hesitate to reopen or file a new issue if this is not the case.

@malfet malfet closed this as completed Apr 4, 2024
pytorchbot pushed a commit that referenced this issue Apr 4, 2024
(cherry picked from commit 05289a2)
pytorchmergebot pushed a commit that referenced this issue Apr 5, 2024
(cherry picked from commit 05289a2)
atalman pushed a commit that referenced this issue Apr 5, 2024
(cherry picked from commit 05289a2)

Co-authored-by: Joona Havukainen <jhavukainen@apple.com>
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this issue Apr 22, 2024
…123234)
