Setting typing.TYPE_CHECKING=True breaks Jax's API
#17385
Comments
It looks like the error is coming from
Looking closer at this, I don't think this would be expected to work. The Sphinx bug looks like it's real, but the code snippet above seems unrelated to the original issue.
I think this issue comes from what different Python modules assume
Looking at the error from the Sphinx issue, I also don't understand where that is coming from.

```python
# contents of main.py
import jax.numpy as jnp
from jax._src.array import ArrayImpl

x: ArrayImpl = jnp.arange(5)
print(x[0])
```
What version of JAX are you using when you're seeing this error?
Version 0.4.14 for both.
Strange: ArrayImpl is definitely indexable in jax v0.4.14 (see https://github.com/google/jax/blob/jax-v0.4.14/jax/_src/array.py#L297). If you can come up with a self-contained reproducer, let me know.
I suspect the bug you're hitting is related to this TODO, which works around a static type issue in pytype (lines 544 to 548 in 88a60b8).
But I'd still like to understand in which situations this would come up.
Did you check the MWE I put in the Sphinx issue I mentioned above?
Yes, I saw that, but it's not very minimal. If it's a JAX bug, we should be able to reproduce it without setting up a Sphinx project directory. If we can't reproduce it without Sphinx, then I'd assume it's a Sphinx bug. The fact that Sphinx sets
I'll try to have a better example when I can.
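A reproducer that needs no Sphinx project directory could take roughly the following shape. This is a minimal sketch: run_with_type_checking is a hypothetical helper, not part of JAX or Sphinx. It runs a snippet in a fresh interpreter after forcing typing.TYPE_CHECKING = True, which is what Sphinx autodoc (v7.2 and later) is reported in this issue to do. Using a subprocess keeps the patched flag from leaking into the calling process.

```python
import subprocess
import sys
import textwrap


def run_with_type_checking(snippet: str) -> subprocess.CompletedProcess:
    """Run `snippet` in a fresh Python process with typing.TYPE_CHECKING forced to True.

    Hypothetical helper for building a minimal reproducer without Sphinx.
    """
    program = (
        "import typing\n"
        "typing.TYPE_CHECKING = True  # mimic what Sphinx autodoc is reported to do\n"
        + textwrap.dedent(snippet)
    )
    # A separate interpreter keeps the patched flag out of the current process.
    return subprocess.run(
        [sys.executable, "-c", program],
        capture_output=True,
        text=True,
    )


if __name__ == "__main__":
    # Stdlib-only sanity check: code run after the patch sees the flag as True.
    result = run_with_type_checking("import typing\nprint(typing.TYPE_CHECKING)\n")
    print(result.returncode, result.stdout.strip())

    # For JAX, something like this could be tried (requires jax installed):
    # result = run_with_type_checking("""
    #     import jax.numpy as jnp
    #     x = jnp.arange(5)
    #     print(type(x).__name__, x[0])
    # """)
    # print(result.returncode, result.stderr)
```

The JAX call is left commented out because its outcome depends on the installed jax/jaxlib versions; the point of the sketch is only that the whole reproducer fits in one script.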
Description
I am not sure whether this is intended, but I recently encountered an issue when documenting my JAX-related library: arrays would be of type ArrayImpl, which broke my code. This is caused by a change in Sphinx's autodoc, which, as of v7.2, imports packages with TYPE_CHECKING=True; see the discussion in sphinx-doc/sphinx#11652. However, the Sphinx maintainers consider it a JAX issue that JAX does not behave as expected when type checking is enabled.
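To illustrate why flipping this flag at runtime can change a library's behavior, here is a toy module that exposes one class under `if TYPE_CHECKING:` for static analysis and a different one at runtime. This is an assumption about the general mechanism, not a reconstruction of JAX's source; the module text, the class names, and the helpers load_module and try_index are all hypothetical. The stub class lacks `__getitem__`, so indexing fails only when the module is loaded with the flag forced on.

```python
import types
import typing

# Hypothetical library source: the `if TYPE_CHECKING:` branch is meant only
# for static type checkers and is never expected to run.
MODULE_SOURCE = '''
from typing import TYPE_CHECKING

class _RuntimeArray:
    def __init__(self, data):
        self._data = list(data)

    def __getitem__(self, i):
        return self._data[i]

if TYPE_CHECKING:
    class Array:  # stub seen by type checkers; no __getitem__
        def __init__(self, data):
            pass
else:
    Array = _RuntimeArray
'''


def load_module(name: str) -> types.ModuleType:
    """Execute MODULE_SOURCE in a fresh module object, as an import would."""
    module = types.ModuleType(name)
    exec(MODULE_SOURCE, module.__dict__)
    return module


def try_index(module: types.ModuleType):
    """Build an Array from the module and index it, reporting any TypeError."""
    arr = module.Array([10, 20, 30])
    try:
        return arr[0]
    except TypeError as err:
        return f"TypeError: {err}"


normal = load_module("toy_normal")
print(try_index(normal))  # prints 10

typing.TYPE_CHECKING = True  # simulate importing under Sphinx autodoc >= 7.2
try:
    patched = load_module("toy_patched")
finally:
    typing.TYPE_CHECKING = False  # restore the global flag
print(try_index(patched))  # the stub class is not subscriptable
```

Because `typing.TYPE_CHECKING` is a plain module attribute, any package imported while it is `True` takes its type-checking-only branches, whatever they happen to contain.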
I could not reproduce the same error, but we can see that importing jax after TYPE_CHECKING has been set to True results in an error, which outputs:
What jax/jaxlib version are you using?
jax v0.4.14, jaxlib v0.4.14+cuda12.cudnn89
Which accelerator(s) are you using?
CPU/GPU
Additional system info
3.10, Linux
NVIDIA GPU info