
Setting typing.TYPE_CHECKING=True breaks Jax's API #17385

Open
jeertmans opened this issue Aug 31, 2023 · 10 comments
Labels
bug Something isn't working

Comments

@jeertmans
Contributor

jeertmans commented Aug 31, 2023

Description

I am not sure if this is intended, but I recently encountered an issue when documenting my JAX-related library, where arrays were of type ArrayImpl, which broke my code.

This is caused by a change in Sphinx's autodoc, which, as of v7.2, imports packages with TYPE_CHECKING=True; see the discussion in sphinx-doc/sphinx#11652.

However, the Sphinx maintainers consider it a JAX issue that it does not behave as expected when type checking is enabled.

I could not reproduce the exact same error, but we can see that importing jax after setting TYPE_CHECKING already results in an error:

import typing

typing.TYPE_CHECKING = True

import jax

which outputs:

Traceback (most recent call last):
  File "/export/home/eertmans/repositories/jeertmans.github.io/test.py", line 5, in <module>
    import jax
  File "/home/eertmans/.local/lib/python3.10/site-packages/jax/__init__.py", line 169, in <module>
    from jax import scipy as scipy
  File "/home/eertmans/.local/lib/python3.10/site-packages/jax/scipy/__init__.py", line 21, in <module>
    from jax.scipy import interpolate as interpolate
  File "/home/eertmans/.local/lib/python3.10/site-packages/jax/scipy/interpolate/__init__.py", line 17, in <module>
    from jax._src.third_party.scipy.interpolate import (
  File "/home/eertmans/.local/lib/python3.10/site-packages/jax/_src/third_party/scipy/interpolate.py", line 2, in <module>
    import scipy.interpolate as osp_interpolate
  File "/home/eertmans/.local/lib/python3.10/site-packages/scipy/interpolate/__init__.py", line 166, in <module>
    from ._interpolate import *
  File "/home/eertmans/.local/lib/python3.10/site-packages/scipy/interpolate/_interpolate.py", line 21, in <module>
    from .interpnd import _ndim_coords_from_arrays
  File "interpnd.pyx", line 1, in init scipy.interpolate.interpnd
  File "/home/eertmans/.local/lib/python3.10/site-packages/scipy/spatial/__init__.py", line 108, in <module>
    from ._geometric_slerp import geometric_slerp
  File "/home/eertmans/.local/lib/python3.10/site-packages/scipy/spatial/_geometric_slerp.py", line 12, in <module>
    import numpy.typing as npt
  File "/home/eertmans/.local/lib/python3.10/site-packages/numpy/typing/__init__.py", line 158, in <module>
    from numpy._typing import (
  File "/home/eertmans/.local/lib/python3.10/site-packages/numpy/_typing/__init__.py", line 209, in <module>
    from ._ufunc import (
ModuleNotFoundError: No module named 'numpy._typing._ufunc'

What jax/jaxlib version are you using?

jax v0.4.14, jaxlib v0.4.14+cuda12.cudnn89

Which accelerator(s) are you using?

CPU/GPU

Additional system info

Python 3.10, Linux

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.125.06   Driver Version: 525.125.06   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0 Off |                  N/A |
|  0%   39C    P8    16W / 220W |      9MiB /  8192MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      2024      G   /usr/lib/xorg/Xorg                  4MiB |
|    0   N/A  N/A      2330      G   ...ome-remote-desktop-daemon        3MiB |
+-----------------------------------------------------------------------------+
@jeertmans jeertmans added the bug Something isn't working label Aug 31, 2023
@jakevdp
Collaborator

jakevdp commented Aug 31, 2023

It looks like the error is coming from import numpy.typing, which is pulled in via import scipy. If this is a bug, I think it's a bug in numpy or scipy, not in JAX itself.

@jakevdp
Collaborator

jakevdp commented Aug 31, 2023

Looking closer at this: I don't think this would be expected to work. typing.TYPE_CHECKING=True implies that the code is not being executed by a Python runtime, but only type-checked. For example, numpy/_typing/_ufunc.py does not exist, so you cannot import it. But numpy/_typing/_ufunc.pyi does exist, so when a static type checker parses this code, it will find the type declarations at that path.
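To make that concrete, here is a minimal, self-contained sketch of the mechanism. The package name demo_pkg and its contents are hypothetical and only mimic the shape of the numpy._typing case: a module that exists solely as a .pyi stub and is imported only inside an if TYPE_CHECKING: block. The import succeeds normally, but fails with the same kind of ModuleNotFoundError once typing.TYPE_CHECKING is forced to True at runtime, because the import system does not load .pyi files.

```python
import importlib
import sys
import tempfile
import typing
from pathlib import Path

# Build a throwaway package "demo_pkg" (hypothetical): _ufunc exists only as a
# .pyi stub and is imported only under `if TYPE_CHECKING:`.
pkg_root = Path(tempfile.mkdtemp())
pkg_dir = pkg_root / "demo_pkg"
pkg_dir.mkdir()
(pkg_dir / "__init__.py").write_text(
    "from typing import TYPE_CHECKING\n"
    "if TYPE_CHECKING:\n"
    "    from ._ufunc import UFuncAlias\n"
)
(pkg_dir / "_ufunc.pyi").write_text("UFuncAlias = int\n")
sys.path.insert(0, str(pkg_root))
importlib.invalidate_caches()


def try_import() -> str:
    """Import demo_pkg from scratch and report whether it succeeded."""
    sys.modules.pop("demo_pkg", None)
    try:
        importlib.import_module("demo_pkg")
        return "ok"
    except ModuleNotFoundError as exc:
        return f"failed: {exc}"


before = try_import()  # normal runtime: the guarded import never executes
typing.TYPE_CHECKING = True  # mimics what the issue reports Sphinx >= 7.2 doing
try:
    after = try_import()  # the guarded import now runs; no .py module exists
finally:
    typing.TYPE_CHECKING = False  # restore the runtime default

print(before)  # ok
print(after)   # failed: No module named 'demo_pkg._ufunc'
```

The flag is read when the module body executes, so any module imported while it is True takes its type-checker-only branches.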

The Sphinx bug looks like it's real, but the code snippet above seems unrelated to the original issue.

@jeertmans
Contributor Author

I think this issue comes down to what different Python modules assume TYPE_CHECKING=True means...

@jakevdp
Collaborator

jakevdp commented Aug 31, 2023

Looking at the error from the Sphinx issue, I also don't understand where it is coming from. ArrayImpl is in fact subscriptable in the current version of JAX:

# contents of main.py
import jax.numpy as jnp
from jax._src.array import ArrayImpl

x: ArrayImpl = jnp.arange(5)
print(x[0])
$ mypy main.py
Success: no issues found in 1 source file

What version of JAX are you using when you're seeing this error?

@jeertmans
Contributor Author

Version 0.4.14 for both jax and jaxlib.

@jakevdp
Collaborator

jakevdp commented Aug 31, 2023

Strange: ArrayImpl is definitely indexable in jax v0.4.14 (see https://github.com/google/jax/blob/jax-v0.4.14/jax/_src/array.py#L297).

If you can come up with a self-contained reproducer, let me know.

@jakevdp
Collaborator

jakevdp commented Aug 31, 2023

I suspect the bug you're hitting is related to this TODO, which works around a static typing issue in pytype:

jax/jax/_src/array.py

Lines 544 to 548 in 88a60b8

# TODO(b/273265390): ideally we would write this as a decorator on the ArrayImpl
# class, however this triggers a pytype bug. Workaround: apply the decorator
# after the fact.
if not TYPE_CHECKING:
  ArrayImpl = use_cpp_class(xc.ArrayImpl)(ArrayImpl)

But I'd still like to understand in which situations this would come up.
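The effect of that guard can be modeled in plain Python. This is a simplified sketch under stated assumptions, not JAX's actual code: use_native_class and NativeArray are hypothetical stand-ins for use_cpp_class and the C++ xc.ArrayImpl, and the sketch assumes the real decorator works by attaching the Python-side methods to the C++ class. If that assumption holds, forcing TYPE_CHECKING=True at runtime skips the transfer, so arrays created by the runtime would lack __getitem__, which would be consistent with the indexing error discussed above.

```python
import types


def use_native_class(native_cls):
    """Hypothetical stand-in for use_cpp_class: copy the Python class's
    functions onto the native class and return the native class instead."""
    def wrap(py_cls):
        for name, attr in vars(py_cls).items():
            if isinstance(attr, types.FunctionType):
                setattr(native_cls, name, attr)
        return native_cls
    return wrap


def build_native_array(type_checking: bool):
    """Replay the array.py pattern with TYPE_CHECKING forced to a given value
    and return the class that runtime arrays are actually instances of."""

    class NativeArray:
        # Hypothetical stand-in for xc.ArrayImpl: objects created by the
        # runtime are instances of this class, which has no __getitem__.
        def __init__(self, data):
            self._data = list(data)

    class ArrayImpl:
        # Python-side definition carrying the user-facing methods.
        def __getitem__(self, idx):
            return self._data[idx]

    if not type_checking:
        # Mirrors the guarded line: methods are grafted only at real runtime.
        ArrayImpl = use_native_class(NativeArray)(ArrayImpl)
    return NativeArray


def first_element(type_checking: bool) -> str:
    native_cls = build_native_array(type_checking)
    x = native_cls([10, 20, 30])  # an array as the runtime would create it
    try:
        return f"x[0] = {x[0]}"
    except TypeError as exc:
        return f"TypeError: {exc}"


print(first_element(False))  # x[0] = 10
print(first_element(True))   # TypeError: 'NativeArray' object is not subscriptable
```

In this model, nothing fails at import time; the breakage only shows up later, when an array produced by the runtime is indexed.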

@jeertmans
Contributor Author

Did you check the MWE I posted in the Sphinx issue I mentioned above?

@jakevdp
Collaborator

jakevdp commented Sep 5, 2023

Yes, I saw that, but it's not very minimal...

If it's a JAX bug, we should be able to reproduce it without setting up a Sphinx project directory. If we can't reproduce it without Sphinx, then I'd assume it's a Sphinx bug. The fact that Sphinx sets TYPE_CHECKING = True before doing a runtime import is a bit suspicious.

@jeertmans
Contributor Author

I'll try to put together a better example when I can.
