Skip to content

Commit 7792bde

Browse files
author
jax authors
committed
Merge pull request #22574 from jakevdp:xla-computation
PiperOrigin-RevId: 655237152
2 parents 13e42ad + f887b66 commit 7792bde

File tree

5 files changed

+36
-24
lines changed

5 files changed

+36
-24
lines changed

jax/_src/api.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,14 @@ def xla_computation(fun: Callable,
346346
donate_argnums: int | Iterable[int] = ()) -> Callable:
347347
"""Creates a function that produces its XLA computation given example args.
348348
349+
.. warning::
350+
351+
This function is deprecated as of JAX v0.4.30, and will be removed in a future
352+
JAX release. You can replace it with :ref:`ahead-of-time-lowering` APIs; for
353+
example, ``jax.xla_computation(fn)(*args)`` can be replaced with
354+
``jax.jit(fn).lower(*args).compiler_ir('hlo')``.
355+
See the `JAX 0.4.30 Change log`_ for more examples.
356+
349357
Args:
350358
fun: Function from which to form XLA computations.
351359
static_argnums: See the :py:func:`jax.jit` docstring.
@@ -404,7 +412,7 @@ def xla_computation(fun: Callable,
404412
>>> import jax
405413
>>>
406414
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
407-
>>> c = jax.xla_computation(f)(3.)
415+
>>> c = jax.xla_computation(f)(3.) # doctest: +SKIP
408416
>>> print(c.as_hlo_text()) # doctest: +SKIP
409417
HloModule xla_computation_f.6
410418
<BLANKLINE>
@@ -423,13 +431,13 @@ def xla_computation(fun: Callable,
423431
424432
>>> import types
425433
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
426-
>>> c = jax.xla_computation(f)(scalar)
434+
>>> c = jax.xla_computation(f)(scalar) # doctest: +SKIP
427435
428436
429437
Here's an example that involves a parallel collective and axis name:
430438
431439
>>> def f(x): return x - jax.lax.psum(x, 'i')
432-
>>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2)
440+
>>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2) # doctest: +SKIP
433441
>>> print(c.as_hlo_text()) # doctest: +SKIP
434442
HloModule jaxpr_computation.9
435443
primitive_computation.3 {
@@ -457,7 +465,7 @@ def xla_computation(fun: Callable,
457465
... return rowsum, colsum, allsum
458466
...
459467
>>> axis_env = [('i', 4), ('j', 2)]
460-
>>> c = xla_computation(g, axis_env=axis_env)(5.)
468+
>>> c = jax.xla_computation(g, axis_env=axis_env)(5.) # doctest: +SKIP
461469
>>> print(c.as_hlo_text()) # doctest: +SKIP
462470
HloModule jaxpr_computation__1.19
463471
[removed uninteresting text here]
@@ -469,6 +477,8 @@ def xla_computation(fun: Callable,
469477
all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
470478
ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
471479
}
480+
481+
.. _JAX 0.4.30 Change log: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-30-june-18-2024
472482
"""
473483
if instantiate_const_outputs is not None:
474484
raise ValueError(

jax/_src/deprecations.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ def accelerate_getattr_deprecation(module: ModuleType, name: str) -> None:
6666
message, _ = module._deprecations[name]
6767
module._deprecations[name] = (message, None)
6868

69+
def is_accelerated_attribute(module: ModuleType, name: str) -> bool:
70+
"""Returns true if given name is accelerated.
71+
72+
Raises an error if name is not a deprecated attribute in module.
73+
"""
74+
return module._deprecations[name][1] is None
75+
6976
# The following mechanism is a separate one, for registering and
7077
# accelerating deprecations that are not imports (for example, deprecations
7178
# of a function argument).

jax/_src/test_util.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -188,18 +188,6 @@ def check_eq(xs, ys, err_msg=''):
188188
tree_all(tree_map(assert_close, xs, ys))
189189

190190

191-
# TODO(yashkatariya): Make this context manager check for deprecation message
192-
# in OSS.
193-
@contextmanager
194-
def unaccelerate_getattr_deprecation(module, name):
195-
message, prev_attr = module._deprecations[name]
196-
module._deprecations[name] = (message, getattr(module, f"_deprecated_{name}"))
197-
try:
198-
yield
199-
finally:
200-
module._deprecations[name] = (message, prev_attr)
201-
202-
203191
@contextmanager
204192
def _capture_output(fp: TextIO) -> Generator[Callable[[], str], None, None]:
205193
"""Context manager to capture all output written to a given file object.

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,16 @@ filterwarnings = [
5454
"default:Error (reading|writing) persistent compilation cache entry for 'jit_equal'",
5555
"default:Error (reading|writing) persistent compilation cache entry for 'jit__lambda_'",
5656
"default:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
57-
"default:jax.xla_computation is deprecated. Please use the AOT APIs.*:DeprecationWarning",
57+
5858
# TODO(jakevdp): remove when array_api_tests stabilize
59-
# start array_api_tests-related warnings
6059
"default:.*not machine-readable.*:UserWarning",
6160
"default:Special cases found for .* but none were parsed.*:UserWarning",
6261
"default:.*is not JSON-serializable. Using the repr instead.",
63-
# end array_api_tests-related warnings
64-
# This is a transitive warning coming from TensorFlow dependencies.
62+
63+
# These are transitive warnings coming from TensorFlow dependencies.
6564
# TODO(slebedev): Remove once we bump the minimum TensorFlow version.
6665
"default:The key path API is deprecated .*",
66+
"default:jax.xla_computation is deprecated.*:DeprecationWarning",
6767
]
6868
doctest_optionflags = [
6969
"NUMBER",

tests/api_test.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from jax._src import config
5151
from jax._src import core
5252
from jax._src import custom_derivatives
53+
from jax._src import deprecations
5354
from jax._src import linear_util as lu
5455
from jax._src import test_util as jtu
5556
from jax._src import xla_bridge
@@ -3020,16 +3021,19 @@ def fn(x):
30203021
axis_env = [(axis_name, jax.local_device_count())]
30213022
_ = api.xla_computation(fn, axis_env=axis_env, backend='cpu')(input_x)
30223023

3023-
@jtu.unaccelerate_getattr_deprecation(jax, 'xla_computation')
3024+
@jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated')
30243025
def test_xla_computation_axis_env(self):
3026+
is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation')
3027+
xla_computation = api.xla_computation if is_accelerated else jax.xla_computation
3028+
30253029
def fn(x):
30263030
z = x * jax.lax.axis_index('i').astype(jnp.float32)
30273031
def inner_fn(carry, a):
30283032
return carry + a, ()
30293033
return jax.lax.scan(inner_fn, jnp.zeros_like(z[0]), z)
30303034

30313035
x = jnp.ones((5, 6, 4), dtype=jnp.float32)
3032-
_ = jax.xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(x)
3036+
_ = xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(x)
30333037

30343038
def test_concurrent_device_get_and_put(self):
30353039
def f(x):
@@ -10760,8 +10764,11 @@ def test_pmap_nested_donate_ignored(self):
1076010764

1076110765
class NamedCallTest(jtu.JaxTestCase):
1076210766

10763-
@jtu.unaccelerate_getattr_deprecation(jax, 'xla_computation')
10767+
@jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated')
1076410768
def test_default_name(self):
10769+
is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation')
10770+
xla_computation = api.xla_computation if is_accelerated else jax.xla_computation
10771+
1076510772
@api.named_call
1076610773
def my_test_function(x):
1076710774
return x**2
@@ -10770,7 +10777,7 @@ def my_test_function(x):
1077010777
def f(x):
1077110778
return my_test_function(x)
1077210779

10773-
c = jax.xla_computation(f)(2)
10780+
c = xla_computation(f)(2)
1077410781
print_opts = xla_client._xla.HloPrintOptions.short_parsable()
1077510782
print_opts.print_metadata = True
1077610783
hlo_text = c.as_hlo_module().to_string(print_opts)

0 commit comments

Comments
 (0)