Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions docs/faq.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,138 @@ Additional reading:

* `JAX - The Sharp Bits: Pure Functions <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Pure-functions>`_.

.. comment We refer to the anchor below in JAX error messages

`Abstract tracer value encountered where concrete value is expected` error
--------------------------------------------------------------------------

If you are getting an error that a library function is called with
*"Abstract tracer value encountered where concrete value is expected"*, you may need to
change how you invoke JAX transformations. We give first an example, and
a couple of solutions, and then we explain in more detail what is actually
happening, if you are curious or the simple solution does not work for you.

Some library functions take arguments that specify shapes or axes,
such as the 2nd and 3rd arguments for :func:`jax.numpy.split`::

# def np.split(arr, num_sections: Union[int, Sequence[int]], axis: int):
np.split(np.zeros(2), 2, 0) # works

If you try the following code::

jax.jit(np.split)(np.zeros(4), 2, 0)

you will get the following error::

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in jax.numpy.split argument 1).
Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=-1/1)>

We must change the way we use :func:`jax.jit` to ensure that the ``num_sections``
and ``axis`` arguments use their concrete values (``2`` and ``0`` respectively).
The best mechanism is to use special transformation parameters
to declare some arguments to be static, e.g., ``static_argnums`` for :func:`jax.jit`::

jax.jit(np.split, static_argnums=(1, 2))(np.zeros(4), 2, 0)

An alternative is to apply the transformation to a closure
that encapsulates the arguments to be protected, either manually as below
or by using ``functools.partial``::

jax.jit(lambda arr: np.split(arr, 2, 0))(np.zeros(4))

**Note a new closure is created at every invocation, which defeats the
compilation caching mechanism, which is why static_argnums is preferred.**

To understand more subtleties having to do with tracers vs. regular values, and
concrete vs. abstract values, you may want to read `Different kinds of JAX values`_.

Different kinds of JAX values
------------------------------

In the process of transforming functions, JAX replaces some some function
arguments with special tracer values.
You could see this if you use a ``print`` statement::

def func(x):
print(x)
return np.cos(x)

res = jax.jit(func)(0.)

The above code does return the correct value ``1.`` but it also prints
``Traced<ShapedArray(float32[])>`` for the value of ``x``. Normally, JAX
handles these tracer values internally in a transparent way, e.g.,
in the numeric JAX primitives that are used to implement the
``jax.numpy`` functions. This is why ``np.cos`` works in the example above.

More precisely, a **tracer** value is introduced for the argument of
a JAX-transformed function, except the arguments identified by special
parameters such as ``static_argnums`` for :func:`jax.jit` or
``static_broadcasted_argnums`` for :func:`jax.pmap`. Typically, computations
that involve at least a tracer value will produce a tracer value. Besides tracer
values, there are **regular** Python values: values that are computed outside JAX
transformations, or arise from above-mentioned static arguments of certain JAX
transformations, or computed solely from other regular Python values.
These are the values that are used everywhere in absence of JAX transformations.

A tracer value carries an **abstract** value, e.g., ``ShapedArray`` with information
about the shape and dtype of an array. We will refer here to such tracers as
**abstract tracers**. Some tracers, e.g., those that are
introduced for arguments of autodiff transformations, carry ``ConcreteArray``
abstract values that actually include the regular array data, and are used,
e.g., for resolving conditionals. We will refer here to such tracers
as **concrete tracers**. Tracer values computed from these concrete tracers,
perhaps in combination with regular values, result in concrete tracers.
A **concrete value** is either a regular value or a concrete tracer.

Most often values computed from tracer values are themselves tracer values.
There are very few exceptions, when a computation can be entirely done
using the abstract value carried by a tracer, in which case the result
can be a regular value. For example, getting the shape of a tracer
with ``ShapedArray`` abstract value. Another example, is when explicitly
casting a concrete tracer value to a regular type, e.g., ``int(x)`` or
``x.astype(float)``.
Another such situation is for ``bool(x)``, which produces a Python bool when
concreteness makes it possible. That case is especially salient because
of how often it arises in control flow.

Here is how the transformations introduce abstract or concrete tracers:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think "abstract tracers" really means "tracers with abstract values raised above the concrete level." Maybe it'd be good to define that term precisely?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The text above said:

A tracer value carries an abstract value, e.g., ShapedArray with information
about the shape and dtype of an array.

I have added

We will refer here to such tracers as abstract tracers.

So that I can refer to them as "abstract tracers". I have not used the word "raised" because
I do not think it is necessary here to talk about abstraction levels and why a level is higher than others.


* :func:`jax.jit`: introduces **abstract tracers** for all positional arguments
Copy link
Collaborator

@mattjj mattjj Apr 22, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of "abstract tracers" how about "tracers with abstract values raised to at least the Shaped level"?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same below.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added definition above

except those denoted by ``static_argnums``, which remain regular
values.
* :func:`jax.pmap`: introduces **abstract tracers** for all positional arguments
except those denoted by ``static_broadcasted_argnums``.
* :func:`jax.vmap`, :func:`jax.make_jaxpr`, :func:`xla_computation`:
introduce **abstract tracers** for all positional arguments.
* :func:`jax.jvp` and :func:`jax.grad` introduce **concrete tracers**
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't quite true, because grad might be called inside a jit, in which case the trace happens at shape level. It's more accurate to talk about levels of abstraction and say that transformations like jit and vmap raise the abstraction level to "shaped" while jvp doesn't raise it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added additional text to the explanation for grad.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with James that it's better to say "raise the level of abstraction". For example, instead of "introduces abstract tracers" how about "introduces tracers with abstract values raised to at least the ____ level", where ____ is different for different transformations?

for all positional arguments. An exception is when these transformations
are within an outer transformation and the actual arguments are
themselves abstract tracers; in that case, the tracers introduced
by the autodiff transformations are also abstract tracers.
* All higher-order control-flow primitives (:func:`lax.cond`, :func:`lax.while_loop`,
:func:`lax.fori_loop`, :func:`lax.scan`) when they process the functionals
introduce **abstract tracers**, whether or not there is a JAX transformation
in progress.

All of this is relevant when you have code that can operate
only on regular Python values, such as code that has conditional
control-flow based on data::

def divide(x, y):
return x / y if y >= 1. else 0.

If we want to apply :func:`jax.jit`, we must ensure to specify ``static_argnums=1``
to ensure ``y`` stays a regular value. This is due to the boolean expression
``y >= 1.``, which requires concrete values (regular or tracers). The
same would happen if we write explicitly ``bool(y >= 1.)``, or ``int(y)``,
or ``float(y)``.

Interestingly, ``jax.grad(divide)(3., 2.)``, works because :func:`jax.grad`
uses concrete tracers, and resolves the conditional using the concrete
value of ``y``.

Gradients contain `NaN` where using ``where``
------------------------------------------------
Expand Down
1 change: 0 additions & 1 deletion jax/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
AbstractToken = core.AbstractToken
abstract_token = core.abstract_token
canonicalize_shape = core.canonicalize_shape
concretization_err_msg = core.concretization_err_msg
raise_to_shaped = core.raise_to_shaped


Expand Down
40 changes: 29 additions & 11 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,21 +714,39 @@ def __repr__(self): return '*'
identity_p.def_impl(lambda x: x)
identity_p.def_custom_bind(lambda x: x)

def concretization_err_msg(fun, context=None):
class ConcretizationTypeError(TypeError): pass

def raise_concretization_error(val, context=""):
msg = (f"Abstract tracer value encountered where concrete value is expected ({context}).\n"
"Use transformation parameters such as `static_argnums` for `jit` "
"to avoid tracing input values.\n"
"See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error`.\n"
f"Encountered value: {val}")
raise ConcretizationTypeError(msg)


def concretization_function_error(fun, context=""):
fname = getattr(fun, "__name__", fun)
if context is None:
context = ("The function to be transformed can't be traced at the required level "
"of abstraction. If using `jit`, try using `static_argnums` or "
"applying `jit` to smaller subfunctions instead.")
msg = "Abstract value passed to `{}`, which requires a concrete value. {}"
return msg.format(fname, context)

def concretization_function_error(fun, context=None):
def error(self, *args):
raise TypeError(concretization_err_msg(fun, context))
fname_context = f"in `{fname}`"
if context:
fname_context += f" {context}"
def error(self, arg):
raise_concretization_error(arg, fname_context)
return error


def concrete_or_error(typ: Type, val: Any, context=""):
"""Like typ(val), but gives the context in the error message.
Use with typ either `int`, or `bool`.
"""
if isinstance(val, Tracer):
if isinstance(val.aval, ConcreteArray):
return typ(val.aval.val)
else:
raise_concretization_error(val, context)
else:
return typ(val)

class UnshapedArray(AbstractValue):
__slots__ = ['dtype', 'weak_type']
array_abstraction_level = 2
Expand Down
8 changes: 8 additions & 0 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,14 @@ def broadcast_to(arr, shape):
@_wraps(onp.split)
def split(ary, indices_or_sections, axis=0):
dummy_val = onp.broadcast_to(0, ary.shape) # zero strides
if isinstance(indices_or_sections, (tuple, list) + _arraylike_types):
indices_or_sections = [core.concrete_or_error(int, i_s, "in jax.numpy.split argument 1")
for i_s in indices_or_sections]
else:
indices_or_sections = core.concrete_or_error(int, indices_or_sections,
"in jax.numpy.split argument 1")
axis = core.concrete_or_error(int, axis, "in jax.numpy.split argument `axis`")

subarrays = onp.split(dummy_val, indices_or_sections, axis) # shapes
split_indices = onp.cumsum([0] + [onp.shape(sub)[axis] for sub in subarrays])
starts, ends = [0] * ndim(ary), shape(ary)
Expand Down
7 changes: 4 additions & 3 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from jax.core import Primitive
from jax.interpreters import ad
from jax.interpreters import xla
from jax.abstract_arrays import concretization_err_msg
from jax.lib import xla_bridge as xb
from jax import test_util as jtu
from jax import tree_util
Expand Down Expand Up @@ -225,7 +224,9 @@ def f(x):

assert grad(f)(1.0) == 1.0
assert grad(f)(-1.0) == -1.0
jtu.check_raises(lambda: jit(f)(1), TypeError, concretization_err_msg(bool))
with self.assertRaisesRegex(core.ConcretizationTypeError,
"Abstract tracer value encountered where concrete value is expected"):
jit(f)(1)

def test_range_err(self):
def f(x, n):
Expand All @@ -246,7 +247,7 @@ def test_casts(self):
self.assertRaisesRegex(
TypeError,
"('JaxprTracer' object cannot be interpreted as an integer"
"|Abstract value passed to .*)", lambda: jit(f)(0))
"|Abstract tracer value encountered where concrete value is expected .*)", lambda: jit(f)(0))

def test_unimplemented_interpreter_rules(self):
foo_p = Primitive('foo')
Expand Down
19 changes: 19 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,25 @@ def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng_factory):
self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)

def testSplitTypeError(self):
# If we pass an ndarray for indices_or_sections -> no error
self.assertEqual(3, len(jnp.split(jnp.zeros(3), jnp.array([1, 2]))))

CONCRETIZATION_MSG = "Abstract tracer value encountered where concrete value is expected."
with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG):
# An abstract tracer for idx
api.jit(lambda idx: jnp.split(jnp.zeros((12, 2)), idx))(2.)
with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG):
# A list including an abstract tracer
api.jit(lambda idx: jnp.split(jnp.zeros((12, 2)), [2, idx]))(2.)

# A concrete tracer -> no error
api.jvp(lambda idx: jnp.split(jnp.zeros((12, 2)), idx),
(2.,), (1.,))
# A tuple including a concrete tracer -> no error
api.jvp(lambda idx: jnp.split(jnp.zeros((12, 2)), (1, idx)),
(2,), (1,))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_{}sections".format(
jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),
Expand Down
4 changes: 2 additions & 2 deletions tests/loops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,9 @@ def f_op(start, end, inc):
self.assertAllClose(16., f_op(0, 4, 4.), check_dtypes=True)
# Ok to jit, as long as the start and end are static
self.assertAllClose(16., api.jit(f_op, static_argnums=(0, 1))(0, 4, 4.), check_dtypes=True)
with self.assertRaisesRegex(TypeError, "Abstract value passed to `int`, which requires a concrete value"):
with self.assertRaisesRegex(TypeError, "Abstract tracer value encountered where concrete value is expected"):
self.assertAllClose(16., api.jit(f_op)(0, 4, 4.), check_dtypes=True)
with self.assertRaisesRegex(TypeError, "Abstract value passed to `int`, which requires a concrete value"):
with self.assertRaisesRegex(TypeError, "Abstract tracer value encountered where concrete value is expected"):
self.assertAllClose(16., api.vmap(f_op)(np.zeros(10), np.ones(10), np.array([4.] * 10)), check_dtypes=True)

def test_cond(self):
Expand Down