@@ -72,6 +72,138 @@ Additional reading:
72
72
73
73
* JAX_sharp_bits _
74
74
75
+ .. comment We refer to the anchor below in JAX error messages
76
+
77
+ `Abstract tracer value encountered where concrete value is expected ` error
78
+ --------------------------------------------------------------------------
79
+
80
+ If you are getting an error that a library function is called with
81
+ *"Abstract tracer value encountered where concrete value is expected" *, you may need to
82
+ change how you invoke JAX transformations. We give first an example, and
83
+ a couple of solutions, and then we explain in more detail what is actually
84
+ happening, if you are curious or the simple solution does not work for you.
85
+
86
+ Some library functions take arguments that specify shapes or axes,
87
+ such as the 2nd and 3rd arguments for :func: `jax.numpy.split `::
88
+
89
+ # def np.split(arr, num_sections: Union[int, Sequence[int]], axis: int):
90
+ np.split(np.zeros(2), 2, 0) # works
91
+
92
+ If you try the following code::
93
+
94
+ jax.jit(np.split)(np.zeros(4), 2, 0)
95
+
96
+ you will get the following error::
97
+
98
+ ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in jax.numpy.split argument 1).
99
+ Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
100
+ See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`.
101
+ Encountered value: Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=-1/1)>
102
+
103
+ We must change the way we use :func: `jax.jit ` to ensure that the ``num_sections ``
104
+ and ``axis `` arguments use their concrete values (``2 `` and ``0 `` respectively).
105
+ The best mechanism is to use special transformation parameters
106
+ to declare some arguments to be static, e.g., ``static_argnums `` for :func: `jax.jit `::
107
+
108
+ jax.jit(np.split, static_argnums=(1, 2))(np.zeros(4), 2, 0)
109
+
110
+ An alternative is to apply the transformation to a closure
111
+ that encapsulates the arguments to be protected, either manually as below
112
+ or by using ``functools.partial ``::
113
+
114
+ jax.jit(lambda arr: np.split(arr, 2, 0))(np.zeros(4))
115
+
116
+ **Note a new closure is created at every invocation, which defeats the
117
+ compilation caching mechanism, which is why static_argnums is preferred. **
118
+
119
+ To understand more subtleties having to do with tracers vs. regular values, and
120
+ concrete vs. abstract values, you may want to read `Different kinds of JAX values `_.
121
+
122
+ Different kinds of JAX values
123
+ ------------------------------
124
+
125
+ In the process of transforming functions, JAX replaces some some function
126
+ arguments with special tracer values.
127
+ You could see this if you use a ``print `` statement::
128
+
129
+ def func(x):
130
+ print(x)
131
+ return np.cos(x)
132
+
133
+ res = jax.jit(func)(0.)
134
+
135
+ The above code does return the correct value ``1. `` but it also prints
136
+ ``Traced<ShapedArray(float32[])> `` for the value of ``x ``. Normally, JAX
137
+ handles these tracer values internally in a transparent way, e.g.,
138
+ in the numeric JAX primitives that are used to implement the
139
+ ``jax.numpy `` functions. This is why ``np.cos `` works in the example above.
140
+
141
+ More precisely, a **tracer ** value is introduced for the argument of
142
+ a JAX-transformed function, except the arguments identified by special
143
+ parameters such as ``static_argnums `` for :func: `jax.jit ` or
144
+ ``static_broadcasted_argnums `` for :func: `jax.pmap `. Typically, computations
145
+ that involve at least a tracer value will produce a tracer value. Besides tracer
146
+ values, there are **regular ** Python values: values that are computed outside JAX
147
+ transformations, or arise from above-mentioned static arguments of certain JAX
148
+ transformations, or computed solely from other regular Python values.
149
+ These are the values that are used everywhere in absence of JAX transformations.
150
+
151
+ A tracer value carries an **abstract ** value, e.g., ``ShapedArray `` with information
152
+ about the shape and dtype of an array. We will refer here to such tracers as
153
+ **abstract tracers **. Some tracers, e.g., those that are
154
+ introduced for arguments of autodiff transformations, carry ``ConcreteArray ``
155
+ abstract values that actually include the regular array data, and are used,
156
+ e.g., for resolving conditionals. We will refer here to such tracers
157
+ as **concrete tracers **. Tracer values computed from these concrete tracers,
158
+ perhaps in combination with regular values, result in concrete tracers.
159
+ A **concrete value ** is either a regular value or a concrete tracer.
160
+
161
+ Most often values computed from tracer values are themselves tracer values.
162
+ There are very few exceptions, when a computation can be entirely done
163
+ using the abstract value carried by a tracer, in which case the result
164
+ can be a regular value. For example, getting the shape of a tracer
165
+ with ``ShapedArray `` abstract value. Another example, is when explicitly
166
+ casting a concrete tracer value to a regular type, e.g., ``int(x) `` or
167
+ ``x.astype(float) ``.
168
+ Another such situation is for ``bool(x) ``, which produces a Python bool when
169
+ concreteness makes it possible. That case is especially salient because
170
+ of how often it arises in control flow.
171
+
172
+ Here is how the transformations introduce abstract or concrete tracers:
173
+
174
+ * :func: `jax.jit `: introduces **abstract tracers ** for all positional arguments
175
+ except those denoted by ``static_argnums ``, which remain regular
176
+ values.
177
+ * :func: `jax.pmap `: introduces **abstract tracers ** for all positional arguments
178
+ except those denoted by ``static_broadcasted_argnums ``.
179
+ * :func: `jax.vmap `, :func: `jax.make_jaxpr `, :func: `xla_computation `:
180
+ introduce **abstract tracers ** for all positional arguments.
181
+ * :func: `jax.jvp ` and :func: `jax.grad ` introduce **concrete tracers **
182
+ for all positional arguments. An exception is when these transformations
183
+ are within an outer transformation and the actual arguments are
184
+ themselves abstract tracers; in that case, the tracers introduced
185
+ by the autodiff transformations are also abstract tracers.
186
+ * All higher-order control-flow primitives (:func: `lax.cond `, :func: `lax.while_loop `,
187
+ :func: `lax.fori_loop `, :func: `lax.scan `) when they process the functionals
188
+ introduce **abstract tracers **, whether or not there is a JAX transformation
189
+ in progress.
190
+
191
+ All of this is relevant when you have code that can operate
192
+ only on regular Python values, such as code that has conditional
193
+ control-flow based on data::
194
+
195
+ def divide(x, y):
196
+ return x / y if y >= 1. else 0.
197
+
198
+ If we want to apply :func: `jax.jit `, we must ensure to specify ``static_argnums=1 ``
199
+ to ensure ``y `` stays a regular value. This is due to the boolean expression
200
+ ``y >= 1. ``, which requires concrete values (regular or tracers). The
201
+ same would happen if we write explicitly ``bool(y >= 1.) ``, or ``int(y) ``,
202
+ or ``float(y) ``.
203
+
204
+ Interestingly, ``jax.grad(divide)(3., 2.) ``, works because :func: `jax.grad `
205
+ uses concrete tracers, and resolves the conditional using the concrete
206
+ value of ``y ``.
75
207
76
208
Gradients contain `NaN ` where using ``where ``
77
209
------------------------------------------------
0 commit comments