diff --git a/src/crossval.jl b/src/crossval.jl index b30fba2..0da59ec 100644 --- a/src/crossval.jl +++ b/src/crossval.jl @@ -11,12 +11,14 @@ struct Kfold <: CrossValGenerator k::Int coeff::Float64 - function Kfold(n::Int, k::Int) + function Kfold(rng::AbstractRNG, n::Int, k::Int) 2 <= k <= n || error("The value of k must be in [2, length(a)].") - new(randperm(n), k, n / k) + new(randperm(rng, n), k, n / k) end end +Kfold(n::Int, k::Int) = Kfold(Random.GLOBAL_RNG, n, k) + length(c::Kfold) = c.k struct KfoldState @@ -42,10 +44,10 @@ struct StratifiedKfold <: CrossValGenerator permseqs::Vector{Vector{Int}} #Vectors of vectors of indexes for each stratum k::Int #Number of splits coeffs::Vector{Float64} #About how many observations per strata are in a val set - function StratifiedKfold(strata, k) + function StratifiedKfold(rng::AbstractRNG, strata, k) 2 <= k <= length(strata) || error("The value of k must be in [2, length(strata)].") strata_labels, permseqs = unique_inverse(strata) - map(shuffle!, permseqs) + foreach(s -> shuffle!(rng, s), permseqs) coeffs = Float64[] for (stratum, permseq) in zip(strata_labels, permseqs) k <= length(permseq) || error("k is greater than the length of stratum $stratum") @@ -55,6 +57,8 @@ struct StratifiedKfold <: CrossValGenerator end end +StratifiedKfold(strata, k) = StratifiedKfold(Random.GLOBAL_RNG, strata, k) + length(c::StratifiedKfold) = c.k function Base.iterate(c::StratifiedKfold, s::Int=1) @@ -95,16 +99,19 @@ end # Repeated random sub-sampling struct RandomSub <: CrossValGenerator + rng::AbstractRNG # Random number generator n::Int # total length sn::Int # length of each subset k::Int # number of subsets end +RandomSub(n::Int, sn::Int, k::Int) = RandomSub(Random.GLOBAL_RNG, n, sn, k) + length(c::RandomSub) = c.k function iterate(c::RandomSub, s::Int=1) (s > c.k) && return nothing - return (sort!(sample(1:c.n, c.sn; replace=false)), s+1) + return (sort!(sample(c.rng, 1:c.n, c.sn; replace=false)), s+1) end # Stratified 
repeated random sub-sampling diff --git a/test/crossval.jl b/test/crossval.jl index a93712f..023332f 100644 --- a/test/crossval.jl +++ b/test/crossval.jl @@ -1,4 +1,5 @@ using MLBase +using Random using Test ## Kfold @@ -12,6 +13,25 @@ end x = vcat(map(s -> setdiff(1:12, s), ss)...) @test sort(x) == collect(1:12) + +# Verify that we are indeed controlling the random partitioning +# by specifying in the first argument a random number generator. +# The test involves calling Kfold with arbitrarily chosen n,k pairs +# for the same seed choice picked from an arbitrarily defined set. +# Desired outcome is that for the same seed, we should get the same folds. + +for seed in [1, 101, 10101] + for n in [20, 30, 31] + for k in [2, 3, 7] + aux1 = collect(Kfold(MersenneTwister(seed), n, k)) + aux2 = collect(Kfold(MersenneTwister(seed), n, k)) + @test aux1 == aux2 # folds must be the same + end + end +end + + + ## StratifiedKfold strat = [:a, :a, :b, :b, :c, :c, :b, :c, :a] @@ -25,6 +45,19 @@ end x = vcat(map(s -> setdiff(1:9, s), ss)...) @test sort(x) == collect(1:9) + +# Verify that we are indeed controlling the random partitioning +# by specifying in the first argument a random number generator. +# Try this out for a few arbitrarily chosen random seeds. +# Desired outcome is that for the same seed, we should get the same folds. + +for seed in [1, 101, 10101] + aux1 = collect(StratifiedKfold(MersenneTwister(seed), strat, 3)) + aux2 = collect(StratifiedKfold(MersenneTwister(seed), strat, 3)) + @test aux1 == aux2 # folds must be the same +end + + ## LOOCV ss = collect(LOOCV(4)) @@ -49,6 +82,28 @@ for i = 1:6 @test length(unique(ss[i])) == 5 end + +# Verify that we are indeed controlling the random partitioning +# by specifying in the first argument a random number generator. +# The test involves calling RandomSub with arbitrarily chosen +# n,k,sn arguments for the same seed choice picked from an +# arbitrarily defined set. 
+# Desired outcome is that for the same seed, we should get the +# same folds. + +for seed in [1, 101, 10101] + for n in [20, 30, 31] + for sn in [2, 7, 11] + for k in [2, 3, 7] + aux1 = collect(RandomSub(MersenneTwister(seed), n, sn, k)) + aux2 = collect(RandomSub(MersenneTwister(seed), n, sn, k)) + @test aux1 == aux2 # folds must be the same + end + end + end +end + + ## StratifiedRandomSub strat = [:a, :a, :b, :b, :c, :c, :b, :c, :a]