Mooncake gives the wrong aggregator and thus does not give contextualized error messages #1230

@ChrisRackauckas

Description

MWE:

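```julia
using OrdinaryDiffEq, SciMLSensitivity, Mooncake, Test

# Gradient of `f` at `x` via Mooncake; `[2][2]` picks the gradient w.r.t. `x`
mooncake_gradient(f, x) = Mooncake.value_and_gradient!!(Mooncake.build_rrule(f, x), f, x)[2][2]

odef(du, u, p, t) = du .= u .* p
const prob = ODEProblem(odef, [2.0], (0.0, 1.0), [3.0])

# Loss that forces a specific sensitivity algorithm
struct senseloss{T}
    sense::T
end
function (f::senseloss)(u0p)
    sum(solve(prob, Tsit5(), u0 = u0p[1:1], p = u0p[2:2], abstol = 1e-12,
        reltol = 1e-12, saveat = 0.1, sensealg = f.sense))
end
# Same loss using the default sensitivity algorithm
function loss(u0p)
    sum(solve(prob, Tsit5(), u0 = u0p[1:1], p = u0p[2:2], abstol = 1e-12, reltol = 1e-12,
        saveat = 0.1))
end
u0p = [2.0, 3.0]
# TrackerAdjoint cannot be used under Mooncake, so this should throw the
# Mooncake-specific error; currently a less informative error is thrown instead.
@test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss(TrackerAdjoint()), u0p)
```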

Mooncake works, but it doesn't throw the right error message: SciMLSensitivity cannot tell that it is being differentiated by Mooncake, so it skips the simplified, Mooncake-specific error messages, and the error the user sees instead is confusing.
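To make the failure mode concrete, here is a minimal sketch in plain Julia. It assumes, for illustration only, that the specific error is selected by dispatching on whichever AD backend the library believes is active. Every name below (`DetectedAD`, `UnknownADSketch`, `MooncakeADSketch`, `tracker_error`, `check_sensealg`, and the `*Sketch` error types) is hypothetical and is not SciMLSensitivity's real API or internals.

```julia
# Hypothetical sketch: these names are illustrative, not SciMLSensitivity's API.
# Models the bug: the helpful error is chosen by the detected AD backend, so if
# detection cannot recognize Mooncake, only the generic error is ever thrown.

abstract type DetectedAD end
struct UnknownADSketch <: DetectedAD end    # backend could not be identified
struct MooncakeADSketch <: DetectedAD end   # backend identified as Mooncake

struct MooncakeTrackedRealErrorSketch <: Exception end
struct GenericTrackerErrorSketch <: Exception
    msg::String
end

# Simplified, contextualized error when Mooncake is known to be the caller.
tracker_error(::MooncakeADSketch) = MooncakeTrackedRealErrorSketch()
# Fallback used for any backend that was not recognized.
tracker_error(::DetectedAD) =
    GenericTrackerErrorSketch("TrackerAdjoint received an unsupported number type")

# Throws whichever error matches the detected backend.
check_sensealg(detected::DetectedAD) = throw(tracker_error(detected))
```

In this model, `check_sensealg(UnknownADSketch())` throws the generic error, which is why a `@test_throws` expecting the Mooncake-specific error fails; once detection yields `MooncakeADSketch()`, the specific error is thrown.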
