Skip to content

Commit a2c06d6

Browse files
authored
Added clearer error message for tracers in numpy.split (#2508)
* Added clearer error message for tracers in numpy.split Now we print: ConcretizationTypeError: Abstract tracer value where concrete value is expected (in jax.numpy.split argument 1). Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values. See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`. Encountered value: Traced<ShapedArray> * Fixed tests, slight change to the error message * Expanded the FAQ entry about abstract tracers for higher-order primitives * Added clarification for tracers inside jit of grad * Updated FAQ language in response to reviews
1 parent 2e34dbc commit a2c06d6

File tree

7 files changed

+194
-17
lines changed

7 files changed

+194
-17
lines changed

docs/faq.rst

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,138 @@ Additional reading:
7272

7373
* JAX_sharp_bits_
7474

75+
.. comment We refer to the anchor below in JAX error messages
76+
77+
`Abstract tracer value encountered where concrete value is expected` error
78+
--------------------------------------------------------------------------
79+
80+
If you are getting an error that a library function is called with
81+
*"Abstract tracer value encountered where concrete value is expected"*, you may need to
82+
change how you invoke JAX transformations. We give first an example, and
83+
a couple of solutions, and then we explain in more detail what is actually
84+
happening, if you are curious or the simple solution does not work for you.
85+
86+
Some library functions take arguments that specify shapes or axes,
87+
such as the 2nd and 3rd arguments for :func:`jax.numpy.split`::
88+
89+
# def np.split(arr, num_sections: Union[int, Sequence[int]], axis: int):
90+
np.split(np.zeros(2), 2, 0) # works
91+
92+
If you try the following code::
93+
94+
jax.jit(np.split)(np.zeros(4), 2, 0)
95+
96+
you will get the following error::
97+
98+
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in jax.numpy.split argument 1).
99+
Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
100+
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`.
101+
Encountered value: Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=-1/1)>
102+
103+
We must change the way we use :func:`jax.jit` to ensure that the ``num_sections``
104+
and ``axis`` arguments use their concrete values (``2`` and ``0`` respectively).
105+
The best mechanism is to use special transformation parameters
106+
to declare some arguments to be static, e.g., ``static_argnums`` for :func:`jax.jit`::
107+
108+
jax.jit(np.split, static_argnums=(1, 2))(np.zeros(4), 2, 0)
109+
110+
An alternative is to apply the transformation to a closure
111+
that encapsulates the arguments to be protected, either manually as below
112+
or by using ``functools.partial``::
113+
114+
jax.jit(lambda arr: np.split(arr, 2, 0))(np.zeros(4))
115+
116+
**Note a new closure is created at every invocation, which defeats the
117+
compilation caching mechanism, which is why static_argnums is preferred.**
118+
119+
To understand more subtleties having to do with tracers vs. regular values, and
120+
concrete vs. abstract values, you may want to read `Different kinds of JAX values`_.
121+
122+
Different kinds of JAX values
123+
------------------------------
124+
125+
In the process of transforming functions, JAX replaces some some function
126+
arguments with special tracer values.
127+
You could see this if you use a ``print`` statement::
128+
129+
def func(x):
130+
print(x)
131+
return np.cos(x)
132+
133+
res = jax.jit(func)(0.)
134+
135+
The above code does return the correct value ``1.`` but it also prints
136+
``Traced<ShapedArray(float32[])>`` for the value of ``x``. Normally, JAX
137+
handles these tracer values internally in a transparent way, e.g.,
138+
in the numeric JAX primitives that are used to implement the
139+
``jax.numpy`` functions. This is why ``np.cos`` works in the example above.
140+
141+
More precisely, a **tracer** value is introduced for the argument of
142+
a JAX-transformed function, except the arguments identified by special
143+
parameters such as ``static_argnums`` for :func:`jax.jit` or
144+
``static_broadcasted_argnums`` for :func:`jax.pmap`. Typically, computations
145+
that involve at least a tracer value will produce a tracer value. Besides tracer
146+
values, there are **regular** Python values: values that are computed outside JAX
147+
transformations, or arise from above-mentioned static arguments of certain JAX
148+
transformations, or computed solely from other regular Python values.
149+
These are the values that are used everywhere in absence of JAX transformations.
150+
151+
A tracer value carries an **abstract** value, e.g., ``ShapedArray`` with information
152+
about the shape and dtype of an array. We will refer here to such tracers as
153+
**abstract tracers**. Some tracers, e.g., those that are
154+
introduced for arguments of autodiff transformations, carry ``ConcreteArray``
155+
abstract values that actually include the regular array data, and are used,
156+
e.g., for resolving conditionals. We will refer here to such tracers
157+
as **concrete tracers**. Tracer values computed from these concrete tracers,
158+
perhaps in combination with regular values, result in concrete tracers.
159+
A **concrete value** is either a regular value or a concrete tracer.
160+
161+
Most often values computed from tracer values are themselves tracer values.
162+
There are very few exceptions, when a computation can be entirely done
163+
using the abstract value carried by a tracer, in which case the result
164+
can be a regular value. For example, getting the shape of a tracer
165+
with ``ShapedArray`` abstract value. Another example, is when explicitly
166+
casting a concrete tracer value to a regular type, e.g., ``int(x)`` or
167+
``x.astype(float)``.
168+
Another such situation is for ``bool(x)``, which produces a Python bool when
169+
concreteness makes it possible. That case is especially salient because
170+
of how often it arises in control flow.
171+
172+
Here is how the transformations introduce abstract or concrete tracers:
173+
174+
* :func:`jax.jit`: introduces **abstract tracers** for all positional arguments
175+
except those denoted by ``static_argnums``, which remain regular
176+
values.
177+
* :func:`jax.pmap`: introduces **abstract tracers** for all positional arguments
178+
except those denoted by ``static_broadcasted_argnums``.
179+
* :func:`jax.vmap`, :func:`jax.make_jaxpr`, :func:`xla_computation`:
180+
introduce **abstract tracers** for all positional arguments.
181+
* :func:`jax.jvp` and :func:`jax.grad` introduce **concrete tracers**
182+
for all positional arguments. An exception is when these transformations
183+
are within an outer transformation and the actual arguments are
184+
themselves abstract tracers; in that case, the tracers introduced
185+
by the autodiff transformations are also abstract tracers.
186+
* All higher-order control-flow primitives (:func:`lax.cond`, :func:`lax.while_loop`,
187+
:func:`lax.fori_loop`, :func:`lax.scan`) when they process the functionals
188+
introduce **abstract tracers**, whether or not there is a JAX transformation
189+
in progress.
190+
191+
All of this is relevant when you have code that can operate
192+
only on regular Python values, such as code that has conditional
193+
control-flow based on data::
194+
195+
def divide(x, y):
196+
return x / y if y >= 1. else 0.
197+
198+
If we want to apply :func:`jax.jit`, we must ensure to specify ``static_argnums=1``
199+
to ensure ``y`` stays a regular value. This is due to the boolean expression
200+
``y >= 1.``, which requires concrete values (regular or tracers). The
201+
same would happen if we write explicitly ``bool(y >= 1.)``, or ``int(y)``,
202+
or ``float(y)``.
203+
204+
Interestingly, ``jax.grad(divide)(3., 2.)``, works because :func:`jax.grad`
205+
uses concrete tracers, and resolves the conditional using the concrete
206+
value of ``y``.
75207

76208
Gradients contain `NaN` where using ``where``
77209
------------------------------------------------

jax/abstract_arrays.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
AbstractToken = core.AbstractToken
2727
abstract_token = core.abstract_token
2828
canonicalize_shape = core.canonicalize_shape
29-
concretization_err_msg = core.concretization_err_msg
3029
raise_to_shaped = core.raise_to_shaped
3130

3231

jax/core.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -714,21 +714,39 @@ def __repr__(self): return '*'
714714
identity_p.def_impl(lambda x: x)
715715
identity_p.def_custom_bind(lambda x: x)
716716

717-
def concretization_err_msg(fun, context=None):
717+
class ConcretizationTypeError(TypeError): pass
718+
719+
def raise_concretization_error(val, context=""):
720+
msg = (f"Abstract tracer value encountered where concrete value is expected ({context}).\n"
721+
"Use transformation parameters such as `static_argnums` for `jit` "
722+
"to avoid tracing input values.\n"
723+
"See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error`.\n"
724+
f"Encountered value: {val}")
725+
raise ConcretizationTypeError(msg)
726+
727+
728+
def concretization_function_error(fun, context=""):
718729
fname = getattr(fun, "__name__", fun)
719-
if context is None:
720-
context = ("The function to be transformed can't be traced at the required level "
721-
"of abstraction. If using `jit`, try using `static_argnums` or "
722-
"applying `jit` to smaller subfunctions instead.")
723-
msg = "Abstract value passed to `{}`, which requires a concrete value. {}"
724-
return msg.format(fname, context)
725-
726-
def concretization_function_error(fun, context=None):
727-
def error(self, *args):
728-
raise TypeError(concretization_err_msg(fun, context))
730+
fname_context = f"in `{fname}`"
731+
if context:
732+
fname_context += f" {context}"
733+
def error(self, arg):
734+
raise_concretization_error(arg, fname_context)
729735
return error
730736

731737

738+
def concrete_or_error(typ: Type, val: Any, context=""):
739+
"""Like typ(val), but gives the context in the error message.
740+
Use with typ either `int`, or `bool`.
741+
"""
742+
if isinstance(val, Tracer):
743+
if isinstance(val.aval, ConcreteArray):
744+
return typ(val.aval.val)
745+
else:
746+
raise_concretization_error(val, context)
747+
else:
748+
return typ(val)
749+
732750
class UnshapedArray(AbstractValue):
733751
__slots__ = ['dtype', 'weak_type']
734752
array_abstraction_level = 2

jax/numpy/lax_numpy.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,6 +1298,14 @@ def broadcast_to(arr, shape):
12981298
@_wraps(onp.split)
12991299
def split(ary, indices_or_sections, axis=0):
13001300
dummy_val = onp.broadcast_to(0, ary.shape) # zero strides
1301+
if isinstance(indices_or_sections, (tuple, list) + _arraylike_types):
1302+
indices_or_sections = [core.concrete_or_error(int, i_s, "in jax.numpy.split argument 1")
1303+
for i_s in indices_or_sections]
1304+
else:
1305+
indices_or_sections = core.concrete_or_error(int, indices_or_sections,
1306+
"in jax.numpy.split argument 1")
1307+
axis = core.concrete_or_error(int, axis, "in jax.numpy.split argument `axis`")
1308+
13011309
subarrays = onp.split(dummy_val, indices_or_sections, axis) # shapes
13021310
split_indices = onp.cumsum([0] + [onp.shape(sub)[axis] for sub in subarrays])
13031311
starts, ends = [0] * ndim(ary), shape(ary)

tests/api_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from jax.core import Primitive
3636
from jax.interpreters import ad
3737
from jax.interpreters import xla
38-
from jax.abstract_arrays import concretization_err_msg
3938
from jax.lib import xla_bridge as xb
4039
from jax import test_util as jtu
4140
from jax import tree_util
@@ -225,7 +224,9 @@ def f(x):
225224

226225
assert grad(f)(1.0) == 1.0
227226
assert grad(f)(-1.0) == -1.0
228-
jtu.check_raises(lambda: jit(f)(1), TypeError, concretization_err_msg(bool))
227+
with self.assertRaisesRegex(core.ConcretizationTypeError,
228+
"Abstract tracer value encountered where concrete value is expected"):
229+
jit(f)(1)
229230

230231
def test_range_err(self):
231232
def f(x, n):
@@ -246,7 +247,7 @@ def test_casts(self):
246247
self.assertRaisesRegex(
247248
TypeError,
248249
"('JaxprTracer' object cannot be interpreted as an integer"
249-
"|Abstract value passed to .*)", lambda: jit(f)(0))
250+
"|Abstract tracer value encountered where concrete value is expected .*)", lambda: jit(f)(0))
250251

251252
def test_unimplemented_interpreter_rules(self):
252253
foo_p = Primitive('foo')

tests/lax_numpy_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,6 +1494,25 @@ def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng_factory):
14941494
self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
14951495
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
14961496

1497+
def testSplitTypeError(self):
1498+
# If we pass an ndarray for indices_or_sections -> no error
1499+
self.assertEqual(3, len(jnp.split(jnp.zeros(3), jnp.array([1, 2]))))
1500+
1501+
CONCRETIZATION_MSG = "Abstract tracer value encountered where concrete value is expected."
1502+
with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG):
1503+
# An abstract tracer for idx
1504+
api.jit(lambda idx: jnp.split(jnp.zeros((12, 2)), idx))(2.)
1505+
with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG):
1506+
# A list including an abstract tracer
1507+
api.jit(lambda idx: jnp.split(jnp.zeros((12, 2)), [2, idx]))(2.)
1508+
1509+
# A concrete tracer -> no error
1510+
api.jvp(lambda idx: jnp.split(jnp.zeros((12, 2)), idx),
1511+
(2.,), (1.,))
1512+
# A tuple including a concrete tracer -> no error
1513+
api.jvp(lambda idx: jnp.split(jnp.zeros((12, 2)), (1, idx)),
1514+
(2,), (1,))
1515+
14971516
@parameterized.named_parameters(jtu.cases_from_list(
14981517
{"testcase_name": "_{}_axis={}_{}sections".format(
14991518
jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),

tests/loops_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,9 @@ def f_op(start, end, inc):
285285
self.assertAllClose(16., f_op(0, 4, 4.), check_dtypes=True)
286286
# Ok to jit, as long as the start and end are static
287287
self.assertAllClose(16., api.jit(f_op, static_argnums=(0, 1))(0, 4, 4.), check_dtypes=True)
288-
with self.assertRaisesRegex(TypeError, "Abstract value passed to `int`, which requires a concrete value"):
288+
with self.assertRaisesRegex(TypeError, "Abstract tracer value encountered where concrete value is expected"):
289289
self.assertAllClose(16., api.jit(f_op)(0, 4, 4.), check_dtypes=True)
290-
with self.assertRaisesRegex(TypeError, "Abstract value passed to `int`, which requires a concrete value"):
290+
with self.assertRaisesRegex(TypeError, "Abstract tracer value encountered where concrete value is expected"):
291291
self.assertAllClose(16., api.vmap(f_op)(np.zeros(10), np.ones(10), np.array([4.] * 10)), check_dtypes=True)
292292

293293
def test_cond(self):

0 commit comments

Comments
 (0)