
Problem in _Quaxify.__call__ #58

@nstarman


`_Quaxify.__call__` converts values to tracers even when they shouldn't be traced. Here's a minimal example that reproduces it reliably:

```python
import jax.numpy as jnp
import quax

xbool = jnp.array([True, False, True], dtype=bool)
x1 = jnp.array([1., 2., 3.], dtype=float)

compress = quax.quaxify(jnp.compress)
compress(xbool, x1)  # raises jax.errors.TracerBoolConversionError
```
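For comparison, here is the same call in plain JAX, without `quaxify`. It runs eagerly and keeps the entries of `x1` where `xbool` is `True`. Presumably the quaxified version should behave the same way here, since both inputs are ordinary JAX arrays:

```python
import jax.numpy as jnp

xbool = jnp.array([True, False, True], dtype=bool)
x1 = jnp.array([1., 2., 3.], dtype=float)

# Plain JAX, no quaxify: keep the entries of x1 where xbool is True.
out = jnp.compress(xbool, x1)
print(out.tolist())  # [1.0, 3.0]
```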

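Here's the traceback. As far as I can tell, values inside `jnp.compress` that should stay concrete (`arr`, `extra`, and `fill_value`) are coming back as `_QuaxTrace` tracers, even though both inputs are concrete arrays. The `if reductions.any(extra):` check then tries to convert a traced boolean to a Python `bool`, which raises `TracerBoolConversionError`.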

```
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../quax/quax/_core.py:311: in __call__
    out = fn(*args, **kwargs)
        args       = (Array([ True, False,  True], dtype=bool), Array([1., 2., 3.], dtype=float32))
        dynamic    = (<function compress at 0x1084b2340>, (Array([ True, False,  True], dtype=bool), Array([1., 2., 3.], dtype=float32)), {})
        fn         = <function compress at 0x1084b2340>
        kwargs     = {}
        parent_trace = EvalTrace
        self       = _Quaxify(fn=<function compress>, filter_spec=True, dynamic=False)
        static     = (None, (None, None), {})
        tag        = <jax._src.core.TraceTag object at 0x110d3be90>
        trace      = _QuaxTrace
.venv/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:9595: in compress
    if reductions.any(extra):
        a          = Array([1., 2., 3.], dtype=float32)
        arr        = Traced<ShapedArray(float32[3])>with<_QuaxTrace> with
  value = _DenseArrayValue(array=f32[3])
        axis       = 0
        condition  = Array([ True, False,  True], dtype=bool)
        condition_arr = Array([ True, False,  True], dtype=bool)
        extra      = Traced<ShapedArray(bool[0])>with<_QuaxTrace> with
  value = _DenseArrayValue(array=bool[0])
        fill_value = Traced<ShapedArray(int32[], weak_type=True)>with<_QuaxTrace> with
  value = _DenseArrayValue(array=weak_i32[])
        out        = None
        size       = None
.venv/lib/python3.11/site-packages/jax/_src/core.py:842: in __bool__
    return self.aval._bool(self)
        self       = Traced<ShapedArray(bool[])>with<_QuaxTrace> with
  value = _DenseArrayValue(array=bool[])
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = ShapedArray(bool[]), arg = Traced<ShapedArray(bool[])>with<_QuaxTrace> with
  value = _DenseArrayValue(array=bool[])

    def error(self, arg):
>     raise TracerBoolConversionError(arg)
E     jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
E     See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

arg        = Traced<ShapedArray(bool[])>with<_QuaxTrace> with
  value = _DenseArrayValue(array=bool[])
self       = ShapedArray(bool[])
```
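The error itself is the standard one JAX raises when Python control flow hits a traced value. As a point of comparison, here is a pure-JAX sketch (no quax involved; `CONDITION` and `check_bounds` are illustrative names, not anything from quax or `jnp.compress`). It mimics the `extra` check from the traceback and shows that `jax.jit` raises the same `TracerBoolConversionError`, because under `jit` every operation is staged out, even a slice of a concrete array captured from outside the function:

```python
import jax
import jax.numpy as jnp

# Illustrative: a concrete array captured from outside the function,
# not passed in as an argument.
CONDITION = jnp.array([True, False, True])


def check_bounds(x):
    # Mimics the `extra` check in the traceback: slicing past the end gives
    # a shape-(0,) array, so jnp.any(extra) is False when evaluated eagerly.
    extra = CONDITION[x.shape[0]:]
    if jnp.any(extra):  # Python `if` calls bool() on the result
        raise ValueError("condition contains out-of-bounds entries")
    return x


x = jnp.array([1.0, 2.0, 3.0])

# Eager evaluation: every value is concrete, so bool() succeeds.
eager_out = check_bounds(x)

# Under jax.jit, the slice and the reduction are both staged, so the
# result of jnp.any is a tracer and bool() on it raises.
try:
    jax.jit(check_bounds)(x)
    jit_error = None
except jax.errors.TracerBoolConversionError as err:
    jit_error = err
```

If the analogy holds, `_QuaxTrace` is playing the role that `jit`'s trace plays here: `condition_arr` stays concrete in the traceback, yet `extra` (a slice of it) comes back as a tracer.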
