Open
Labels
question (User queries)
Description
_Quaxify.__call__ is converting values to tracers when they shouldn't be tracers. For example, here's a stable MWE:
import jax.numpy as jnp
import quax
xbool = jnp.array([ True, False, True], dtype=bool)
x1 = jnp.array([1., 2., 3.], dtype=float)
compress = quax.quaxify(jnp.compress)
compress(xbool, x1)

Here's the traceback. I think it's saying that some values are becoming tracers, and that's what triggers the error.
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../quax/quax/_core.py:311: in __call__
out = fn(*args, **kwargs)
args = (Array([ True, False, True], dtype=bool), Array([1., 2., 3.], dtype=float32))
dynamic = (<function compress at 0x1084b2340>, (Array([ True, False, True], dtype=bool), Array([1., 2., 3.], dtype=float32)), {})
fn = <function compress at 0x1084b2340>
kwargs = {}
parent_trace = EvalTrace
self = _Quaxify(fn=<function compress>, filter_spec=True, dynamic=False)
static = (None, (None, None), {})
tag = <jax._src.core.TraceTag object at 0x110d3be90>
trace = _QuaxTrace
.venv/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:9595: in compress
if reductions.any(extra):
a = Array([1., 2., 3.], dtype=float32)
arr = Traced<ShapedArray(float32[3])>with<_QuaxTrace> with
value = _DenseArrayValue(array=f32[3])
axis = 0
condition = Array([ True, False, True], dtype=bool)
condition_arr = Array([ True, False, True], dtype=bool)
extra = Traced<ShapedArray(bool[0])>with<_QuaxTrace> with
value = _DenseArrayValue(array=bool[0])
fill_value = Traced<ShapedArray(int32[], weak_type=True)>with<_QuaxTrace> with
value = _DenseArrayValue(array=weak_i32[])
out = None
size = None
.venv/lib/python3.11/site-packages/jax/_src/core.py:842: in __bool__
return self.aval._bool(self)
self = Traced<ShapedArray(bool[])>with<_QuaxTrace> with
value = _DenseArrayValue(array=bool[])
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = ShapedArray(bool[]), arg = Traced<ShapedArray(bool[])>with<_QuaxTrace> with
value = _DenseArrayValue(array=bool[])
def error(self, arg):
> raise TracerBoolConversionError(arg)
E jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
E See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
arg = Traced<ShapedArray(bool[])>with<_QuaxTrace> with
value = _DenseArrayValue(array=bool[])
self = ShapedArray(bool[])
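For reference, the failing line in the traceback, `if reductions.any(extra):`, is a Python-level branch on an array value. Below is a minimal sketch of the same failure mode using plain `jax.jit` instead of quax, so it is an analogy rather than a reproduction of the `_QuaxTrace` behaviour. The helper `has_extra` is hypothetical and only mimics the shape of that check.

```python
import jax
import jax.numpy as jnp


def has_extra(extra):
    # Hypothetical helper: a Python `if` on an array forces bool(array),
    # similar in shape to `if reductions.any(extra):` in the traceback.
    if jnp.any(extra):
        return True
    return False


# Same shape and dtype as `extra` in the traceback locals (bool[0]).
extra = jnp.zeros((0,), dtype=bool)

# Eagerly, `extra` is a concrete array, so the boolean conversion succeeds.
eager_result = has_extra(extra)

# Under a transformation the argument is a tracer, so the conversion fails.
try:
    jax.jit(has_extra)(extra)
    raised = False
except jax.errors.TracerBoolConversionError:
    raised = True

print(eager_result, raised)
```

If this analogy holds, the question is why `extra` ends up wrapped in a `_QuaxTrace` tracer at all when both inputs are ordinary arrays.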