Deserialisation fails for wrappers like TunedModel when atomic model overloads save/restore #1099

@ablaom

Description

XGBoost.jl models have non-persistent fitresults, which means they cannot be serialised directly. That is not a problem in itself, because such models can overload MLJModelInterface's save and restore methods. However, it has been reported that XGBoost models wrapped in TunedModel do not deserialise properly. Here is an MWE:

First, we define a supervised model, EphemeralRegressor, with an ephemeral fitresult. For this model we overload save/restore to ensure deserialization works, provided the correct API (serializable/restore!) is used.

using Statistics, MLJBase, Test, MLJTuning, Serialization, StatisticalMeasures
import MLJModelInterface

# define a model with non-persistent fitresult:
thing = []
struct EphemeralRegressor <: Deterministic end
function MLJModelInterface.fit(::EphemeralRegressor, verbosity, X, y)
    # if `thing` is serialized and then deserialized, `view` below changes:
    view = objectid(thing)
    fitresult = (thing, view, mean(y))
    return fitresult, nothing, NamedTuple()
end
function MLJModelInterface.predict(::EphemeralRegressor, fitresult, X)
    thing, view, μ = fitresult
    return view == objectid(thing) ? fill(μ, nrows(X)) :
        throw(ErrorException("dead fitresult"))
end
function MLJModelInterface.save(::EphemeralRegressor, fitresult)
    thing, _, μ = fitresult
    return (thing, μ)
end
function MLJModelInterface.restore(::EphemeralRegressor, serialized_fitresult)
    thing, μ = serialized_fitresult
    view = objectid(thing)
    return (thing, view, μ)
end

# EphemeralRegressor cannot be directly serialized:
X, y = (; x = rand(3)), fill(42.0, 3)
model = EphemeralRegressor()
mach = machine(model, X, y) |> fit!
io = IOBuffer()
serialize(io, mach)
seekstart(io)
mach2 = deserialize(io)
@test_throws ErrorException("dead fitresult") predict(mach2, 42)

# But it can be serialized/deserialized using the correct API:
io = IOBuffer()
serialize(io, serializable(mach))
seekstart(io)
mach2 = restore!(deserialize(io))
@test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2)

But wrapping this model in a TunedModel leads to a deserialization failure:

tmodel = TunedModel(
    models=fill(EphemeralRegressor(), 2),
    measure=l2,
)
mach = machine(tmodel, X, y) |> fit!
io = IOBuffer()
serialize(io, serializable(mach))
seekstart(io)
mach2 = restore!(deserialize(io))
MLJBase.predict(mach2, (; x = rand(2)))
# ERROR: dead fitresult
# Stacktrace:
#  [1] predict(::EphemeralRegressor, fitresult::Tuple{Vector{Any}, UInt64, Float64}, X::@NamedTuple{x::Vector{Float64}})                                                            
#    @ Main ./REPL[7]:3
#
# < truncated trace >

The remedy is to properly "forward" the save/restore methods of the atomic models. Wrapper models implemented as NetworkComposite (i.e., using learning networks) can be excluded, as they already overload save and restore correctly.
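For concreteness, here is a minimal sketch of what such forwarding could look like for TunedModel. It assumes (an assumption about MLJTuning internals, not confirmed here) that the fitresult of a TunedModel is the machine trained on the best model, so that save/restore can simply delegate to MLJBase's serializable/restore! for machines, which in turn call the atomic model's save/restore:

import MLJModelInterface
using MLJBase, MLJTuning

# Sketch only: `EitherTunedModel` is MLJTuning's union of tuned model types,
# and the fitresult is assumed to be the machine fitted on the best model.
function MLJModelInterface.save(::MLJTuning.EitherTunedModel, fitresult)
    # `serializable` swaps ephemeral fitresults for persistent ones by
    # calling the atomic model's `save`:
    return MLJBase.serializable(fitresult)
end
function MLJModelInterface.restore(::MLJTuning.EitherTunedModel, serializable_fitresult)
    # `restore!` revives the fitresult by calling the atomic model's `restore`:
    return MLJBase.restore!(serializable_fitresult)
end

With forwarding along these lines in place, the TunedModel example above would be expected to deserialize just as the bare EphemeralRegressor does.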

To do (waiting on review):
