-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Added clearer error message for tracers in numpy.split #2508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7885271
1904f36
fabd9ea
e087afa
c5f4b84
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,6 +65,138 @@ Additional reading: | |
|
||
* `JAX - The Sharp Bits: Pure Functions <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Pure-functions>`_. | ||
|
||
.. comment We refer to the anchor below in JAX error messages | ||
|
||
`Abstract tracer value encountered where concrete value is expected` error | ||
-------------------------------------------------------------------------- | ||
|
||
If you are getting an error that a library function is called with | ||
*"Abstract tracer value encountered where concrete value is expected"*, you may need to | ||
change how you invoke JAX transformations. We give first an example, and | ||
a couple of solutions, and then we explain in more detail what is actually | ||
happening, if you are curious or the simple solution does not work for you. | ||
|
||
Some library functions take arguments that specify shapes or axes, | ||
such as the 2nd and 3rd arguments for :func:`jax.numpy.split`:: | ||
|
||
# def np.split(arr, num_sections: Union[int, Sequence[int]], axis: int): | ||
np.split(np.zeros(2), 2, 0) # works | ||
|
||
If you try the following code:: | ||
|
||
jax.jit(np.split)(np.zeros(4), 2, 0) | ||
|
||
you will get the following error:: | ||
|
||
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in jax.numpy.split argument 1). | ||
Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values. | ||
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`. | ||
Encountered value: Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=-1/1)> | ||
|
||
We must change the way we use :func:`jax.jit` to ensure that the ``num_sections`` | ||
and ``axis`` arguments use their concrete values (``2`` and ``0`` respectively). | ||
The best mechanism is to use special transformation parameters | ||
to declare some arguments to be static, e.g., ``static_argnums`` for :func:`jax.jit`:: | ||
|
||
jax.jit(np.split, static_argnums=(1, 2))(np.zeros(4), 2, 0) | ||
|
||
An alternative is to apply the transformation to a closure | ||
that encapsulates the arguments to be protected, either manually as below | ||
or by using ``functools.partial``:: | ||
|
||
jax.jit(lambda arr: np.split(arr, 2, 0))(np.zeros(4)) | ||
|
||
**Note a new closure is created at every invocation, which defeats the | ||
compilation caching mechanism, which is why static_argnums is preferred.** | ||
|
||
To understand more subtleties having to do with tracers vs. regular values, and | ||
concrete vs. abstract values, you may want to read `Different kinds of JAX values`_. | ||
|
||
Different kinds of JAX values | ||
------------------------------ | ||
|
||
In the process of transforming functions, JAX replaces some some function | ||
arguments with special tracer values. | ||
You could see this if you use a ``print`` statement:: | ||
|
||
def func(x): | ||
print(x) | ||
return np.cos(x) | ||
|
||
res = jax.jit(func)(0.) | ||
|
||
The above code does return the correct value ``1.`` but it also prints | ||
``Traced<ShapedArray(float32[])>`` for the value of ``x``. Normally, JAX | ||
handles these tracer values internally in a transparent way, e.g., | ||
in the numeric JAX primitives that are used to implement the | ||
``jax.numpy`` functions. This is why ``np.cos`` works in the example above. | ||
|
||
More precisely, a **tracer** value is introduced for the argument of | ||
a JAX-transformed function, except the arguments identified by special | ||
parameters such as ``static_argnums`` for :func:`jax.jit` or | ||
``static_broadcasted_argnums`` for :func:`jax.pmap`. Typically, computations | ||
that involve at least a tracer value will produce a tracer value. Besides tracer | ||
values, there are **regular** Python values: values that are computed outside JAX | ||
transformations, or arise from above-mentioned static arguments of certain JAX | ||
transformations, or computed solely from other regular Python values. | ||
These are the values that are used everywhere in absence of JAX transformations. | ||
|
||
A tracer value carries an **abstract** value, e.g., ``ShapedArray`` with information | ||
about the shape and dtype of an array. We will refer here to such tracers as | ||
**abstract tracers**. Some tracers, e.g., those that are | ||
introduced for arguments of autodiff transformations, carry ``ConcreteArray`` | ||
abstract values that actually include the regular array data, and are used, | ||
e.g., for resolving conditionals. We will refer here to such tracers | ||
as **concrete tracers**. Tracer values computed from these concrete tracers, | ||
perhaps in combination with regular values, result in concrete tracers. | ||
A **concrete value** is either a regular value or a concrete tracer. | ||
|
||
Most often values computed from tracer values are themselves tracer values. | ||
There are very few exceptions, when a computation can be entirely done | ||
using the abstract value carried by a tracer, in which case the result | ||
can be a regular value. For example, getting the shape of a tracer | ||
with ``ShapedArray`` abstract value. Another example, is when explicitly | ||
casting a concrete tracer value to a regular type, e.g., ``int(x)`` or | ||
``x.astype(float)``. | ||
Another such situation is for ``bool(x)``, which produces a Python bool when | ||
concreteness makes it possible. That case is especially salient because | ||
of how often it arises in control flow. | ||
|
||
Here is how the transformations introduce abstract or concrete tracers: | ||
|
||
* :func:`jax.jit`: introduces **abstract tracers** for all positional arguments | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of "abstract tracers" how about "tracers with abstract values raised to at least the Shaped level"? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same below. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added definition above |
||
except those denoted by ``static_argnums``, which remain regular | ||
values. | ||
* :func:`jax.pmap`: introduces **abstract tracers** for all positional arguments | ||
except those denoted by ``static_broadcasted_argnums``. | ||
* :func:`jax.vmap`, :func:`jax.make_jaxpr`, :func:`xla_computation`: | ||
introduce **abstract tracers** for all positional arguments. | ||
* :func:`jax.jvp` and :func:`jax.grad` introduce **concrete tracers** | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This isn't quite true, because There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have added additional text to the explanation for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with James that it's better to say "raise the level of abstraction". For example, instead of "introduces abstract tracers" how about "introduces tracers with abstract values raised to at least the ____ level", where ____ is different for different transformations? |
||
for all positional arguments. An exception is when these transformations | ||
are within an outer transformation and the actual arguments are | ||
themselves abstract tracers; in that case, the tracers introduced | ||
by the autodiff transformations are also abstract tracers. | ||
* All higher-order control-flow primitives (:func:`lax.cond`, :func:`lax.while_loop`, | ||
:func:`lax.fori_loop`, :func:`lax.scan`) when they process the functionals | ||
introduce **abstract tracers**, whether or not there is a JAX transformation | ||
in progress. | ||
|
||
All of this is relevant when you have code that can operate | ||
only on regular Python values, such as code that has conditional | ||
control-flow based on data:: | ||
|
||
def divide(x, y): | ||
return x / y if y >= 1. else 0. | ||
|
||
If we want to apply :func:`jax.jit`, we must ensure to specify ``static_argnums=1`` | ||
to ensure ``y`` stays a regular value. This is due to the boolean expression | ||
``y >= 1.``, which requires concrete values (regular or tracers). The | ||
same would happen if we write explicitly ``bool(y >= 1.)``, or ``int(y)``, | ||
or ``float(y)``. | ||
|
||
Interestingly, ``jax.grad(divide)(3., 2.)``, works because :func:`jax.grad` | ||
uses concrete tracers, and resolves the conditional using the concrete | ||
value of ``y``. | ||
|
||
Gradients contain `NaN` where using ``where`` | ||
------------------------------------------------ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think "abstract tracers" really means "tracers with abstract values raised above the concrete level." Maybe it'd be good to define that term precisely?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The text above said:
I have added
So that I can refer to them as "abstract tracers". I have not used the word "raised" because
I do not think it is necessary here to talk about abstraction levels and why a level is higher than others.