XGBoost.jl models have non-persistent `fitresult`s, which means they cannot be directly serialised. That's not a problem in itself, because such models can overload MLJModelInterface's `save` and `restore`. However, it has been reported that XGBoost models wrapped in `TunedModel` don't deserialise properly. Here is a MWE:
First, we define a supervised model, `EphemeralRegressor`, with an ephemeral fitresult. For this model we overload `save`/`restore` to ensure deserialization works, provided the correct API is used:
```julia
using Statistics, MLJBase, Test, MLJTuning, Serialization, StatisticalMeasures
import MLJModelInterface

# define a model with non-persistent fitresult:
thing = []
struct EphemeralRegressor <: Deterministic end
function MLJModelInterface.fit(::EphemeralRegressor, verbosity, X, y)
    # if we serialize/deserialize `thing`, then `view` below changes:
    view = objectid(thing)
    fitresult = (thing, view, mean(y))
    return fitresult, nothing, NamedTuple()
end
function MLJModelInterface.predict(::EphemeralRegressor, fitresult, X)
    thing, view, μ = fitresult
    return view == objectid(thing) ? fill(μ, nrows(X)) :
        throw(ErrorException("dead fitresult"))
end
function MLJModelInterface.save(::EphemeralRegressor, fitresult)
    thing, _, μ = fitresult
    return (thing, μ)
end
function MLJModelInterface.restore(::EphemeralRegressor, serialized_fitresult)
    thing, μ = serialized_fitresult
    view = objectid(thing)
    return (thing, view, μ)
end

# EphemeralRegressor cannot be directly serialized:
X, y = (; x = rand(3)), fill(42.0, 3)
model = EphemeralRegressor()
mach = machine(model, X, y) |> fit!
io = IOBuffer()
serialize(io, mach)
seekstart(io)
mach2 = deserialize(io)
@test_throws ErrorException("dead fitresult") predict(mach2, 42)

# But it can be serialized/deserialized using the correct API:
io = IOBuffer()
serialize(io, serializable(mach))
seekstart(io)
mach2 = restore!(deserialize(io))
@test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2)
```
But wrapping this model in `TunedModel` leads to deserialization failure:
```julia
tmodel = TunedModel(
    models=fill(EphemeralRegressor(), 2),
    measure=l2,
)
mach = machine(tmodel, X, y) |> fit!
io = IOBuffer()
serialize(io, serializable(mach))
seekstart(io)
mach2 = restore!(deserialize(io))
MLJBase.predict(mach2, (; x = rand(2)))
# ERROR: dead fitresult
# Stacktrace:
#  [1] predict(::EphemeralRegressor, fitresult::Tuple{Vector{Any}, UInt64, Float64}, X::@NamedTuple{x::Vector{Float64}})
#    @ Main ./REPL[7]:3
#
# < truncated trace >
```
The remedy is to properly "forward" the `save`/`restore` methods of the atomic models. We can exclude any wrapper model implemented as a `NetworkComposite` (i.e., using learning networks), as these already overload `save` and `restore` properly.
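To illustrate what "forwarding" means, here is a minimal sketch. The `Wrapper` type and its fitresult layout are hypothetical (real wrappers such as `TunedModel` store additional state in their fitresults); the point is only that the wrapper's `save`/`restore` delegate the atomic part of the fitresult to the atomic model's own `save`/`restore`:

```julia
import MLJModelInterface
const MMI = MLJModelInterface

# Hypothetical wrapper whose fitresult is assumed to be exactly the
# atomic model's fitresult:
struct Wrapper{M<:MMI.Deterministic} <: MMI.Deterministic
    atom::M
end

# forward serialization to the atomic model, so any non-persistent
# parts of the atomic fitresult are handled by the atom's own methods:
MMI.save(model::Wrapper, fitresult) = MMI.save(model.atom, fitresult)
MMI.restore(model::Wrapper, serializable_fitresult) =
    MMI.restore(model.atom, serializable_fitresult)
```

In a real wrapper the fitresult typically bundles the atomic fitresult with other data, so `save`/`restore` would unpack it, forward only the atomic part, and repack the rest unchanged.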
To do (waiting on review):

- [ ] `TunedModel`: Overload `save` and `restore` (MLJTuning.jl#208)
- [ ] `IteratedModel`: Overload `save` and `restore` to fix a serialization issue (MLJIteration.jl#59)
- [ ] `EnsembleModel`: Add extra serialization test (MLJEnsembles.jl#32)
- [ ] `BinaryThresholdPredictor`: Overload `save` and `restore` to address serialization issue for `BinaryThresholdPredictor` (MLJModels.jl#550)