[callback] Allow external callbacks to return 64-bit values in 32-bit mode #20534
Prior to #20433, if the Python callback returned a Python literal (which is natively a 64-bit value) and `result_shape_dtypes` specified a 32-bit expected return value, we would just get garbage results. In #20433, I introduced an error in this situation. However, when trying to port the internal code that uses `host_callback` to `io_callback`, I am getting many instances of this error. The common scenario is a Python callback function, `f_host`, that returns a Python scalar such as `42.` while `result_shape_dtypes` declares `np.float32`.

Yet if `f_host` were called directly, JAX would canonicalize the value `42.` to a `float32` (when `jax_enable_x64` is not set). I do not think it makes sense for `io_callback` to have stricter behavior than a direct call.

In this PR we add a canonicalization step on the returned values of Python callbacks, which casts the values to 32 bits if JAX is running in 32-bit mode.
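A minimal NumPy-only sketch of what such a canonicalization step does to a callback's return value. The helpers `canonicalize_result` and `check_result`, the `x64_enabled` flag, and the `TypeError` are hypothetical illustrations, not the code in this PR; the dtype mapping assumes JAX's usual 32-bit narrowing (64-bit floats, signed and unsigned ints, and complex numbers become their 32-bit counterparts).

```python
import numpy as np

# Assumption: mirrors the narrowing JAX applies when jax_enable_x64 is not set.
_NARROW_TO_32 = {
    np.dtype(np.float64): np.dtype(np.float32),
    np.dtype(np.int64): np.dtype(np.int32),
    np.dtype(np.uint64): np.dtype(np.uint32),
    np.dtype(np.complex128): np.dtype(np.complex64),
}


def canonicalize_result(value, x64_enabled):
    """Hypothetical: convert a callback result to an array and, in 32-bit
    mode, narrow any 64-bit dtype to its 32-bit counterpart."""
    arr = np.asarray(value)
    if not x64_enabled:
        target = _NARROW_TO_32.get(arr.dtype)
        if target is not None:
            arr = arr.astype(target)
    return arr


def check_result(value, expected_dtype, x64_enabled):
    """Hypothetical: canonicalize first, then reject any remaining mismatch."""
    arr = canonicalize_result(value, x64_enabled)
    expected = np.dtype(expected_dtype)
    if arr.dtype != expected:
        raise TypeError(f"callback returned {arr.dtype}, expected {expected}")
    return arr


# 32-bit mode: the Python float 42. (float64) is narrowed and accepted.
print(check_result(42., np.float32, x64_enabled=False).dtype)  # float32

# 64-bit mode: no narrowing happens, so the same return value is rejected.
try:
    check_result(42., np.float32, x64_enabled=True)
except TypeError as e:
    print(e)  # callback returned float64, expected float32
```

The check runs after canonicalization, so a mismatch is only reported when narrowing cannot reconcile the returned and declared types, which is the 64-bit-mode case.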
Note that the above example should raise an error in 64-bit mode, because the actual returned value is a 64-bit value but the declared expected value is `np.float32`. To avoid the error in both 64-bit and 32-bit mode, the Python callback should return `np.float32(42.)`.

In some sense, this replaces the error introduced in #20433 with a canonicalization step.
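A minimal sketch of the recommended pattern with `io_callback`: the callback returns `np.float32(42.)`, so the declared and actual types agree whether or not `jax_enable_x64` is set. The `f_host(x)` signature, the `jax.jit` wrapper, and the use of `jax.ShapeDtypeStruct` for `result_shape_dtypes` are assumptions for illustration.

```python
import jax
import numpy as np
from jax.experimental import io_callback


def f_host(x):
    # Return an explicitly 32-bit scalar so it matches the declared
    # result type in both 32-bit and 64-bit mode.
    return np.float32(42.)


@jax.jit
def f(x):
    # Declare a float32 scalar result for the host callback.
    return io_callback(f_host, jax.ShapeDtypeStruct((), np.float32), x)


result = f(np.float32(1.))
print(result.dtype, float(result))
```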