
Forward-mode differentiation rule for 'custom_lin' not implemented #2784

@jacobjinkelly


Running the following code produces the error in the title:

```python
from functools import partial
import jax
import jax.numpy as np

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def _clip_gradient(lo, hi, x):
    return x  # identity function

def clip_gradient_fwd(lo, hi, x):
    # return x, None
    return x, (hi,)

def clip_gradient_bwd(lo, hi, _, g):
    return (np.clip(g, lo, hi),)

_clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

def clip_gradient(x):
    lo = -1
    hi = x + 1  # causes things to break
    return _clip_gradient(lo, hi, x)

print(jax.grad(clip_gradient)(1.))
```

Replacing the residual with `None` (see the commented-out line in `clip_gradient_fwd`) changes the output to a leaked tracer rather than the expected gradient:

```
Traced<ConcreteArray(1.0)>with<JVPTrace(level=1/0)>
  with primal = DeviceArray(1., dtype=float32)
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=0/0)>
```

I was also able to get a mismatched tracer levels error on a more complex example.

To my understanding, setting `hi = x + 1` is the issue: it makes `hi` a traced value (it depends on `x`), which then ends up in one of the `nondiff_argnums` positions.
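As a sketch of a possible workaround (my assumption, not a confirmed fix from this thread): pass the bounds as ordinary differentiable arguments instead of via `nondiff_argnums`, so that traced values like `x + 1` are allowed, and return zero cotangents for them in the backward pass.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clipped(lo, hi, x):
    return x  # identity on the forward pass

def clipped_fwd(lo, hi, x):
    return x, (lo, hi)  # save the bounds as residuals

def clipped_bwd(res, g):
    lo, hi = res
    # Drop gradients for the bounds; clip only the cotangent of x.
    return (jnp.zeros_like(lo), jnp.zeros_like(hi), jnp.clip(g, lo, hi))

clipped.defvjp(clipped_fwd, clipped_bwd)

def f(x):
    # hi depends on x, which is fine now that it is a regular argument
    return clipped(-1.0, x + 1.0, x)

print(jax.grad(f)(1.0))  # clip(1.0, -1.0, 2.0) -> 1.0
```

Since the bounds' cotangents are zeros, the traced `hi = x + 1.0` contributes nothing extra to the gradient of `x`.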

This example may seem a bit contrived, but I originally came across it while trying to set an initial step size for odeint (see #2604).

IIUC, we want to make concrete all JVPTracer instances that appear in static args here.
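Until something like that lands, a small helper sketch (hypothetical, not part of JAX) could at least fail fast with a clearer message than the `custom_lin` one, by checking whether a value headed for `nondiff_argnums` is a tracer:

```python
import jax
import jax.core

def check_not_traced(name, *args):
    # Hypothetical helper: raise early if any supposedly static argument
    # is actually a traced value.
    for a in args:
        if isinstance(a, jax.core.Tracer):
            raise TypeError(
                f"{name}: got traced value {a!r} for a nondiff argument; "
                "pass it as a regular (differentiable) argument instead")
```

Calling `check_not_traced("clip", hi)` before `_clip_gradient(lo, hi, x)` would turn the cryptic forward-mode error into an explicit one whenever `hi` is derived from a traced input.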

Labels: bug