Skip to content

Commit 26049b1

Browse files
Jieying Luojax authors
authored andcommitted
Fix cuda array interface with old jaxlib.
arg name was added in xla_extension_version 261. PiperOrigin-RevId: 629745255
1 parent e75e4a5 commit 26049b1

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2538,9 +2538,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
25382538
cai=cai, gpu_backend=backend, device_id=device_id
25392539
)
25402540
else:
2541-
object = xc._xla.cuda_array_interface_to_buffer(
2542-
cai=cai, gpu_backend=backend
2543-
)
2541+
object = xc._xla.cuda_array_interface_to_buffer(cai, backend)
25442542

25452543
object = tree_map(lambda leaf: leaf.__jax_array__()
25462544
if hasattr(leaf, "__jax_array__") else leaf, object)

0 commit comments

Comments
 (0)