Skip to content

Control randomness of Kfold,StratifiedKfold,RandomSub #57

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/crossval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ struct Kfold <: CrossValGenerator
k::Int
coeff::Float64

function Kfold(n::Int, k::Int)
function Kfold(rng::AbstractRNG, n::Int, k::Int)
2 <= k <= n || error("The value of k must be in [2, length(a)].")
new(randperm(n), k, n / k)
new(randperm(rng, n), k, n / k)
end
end

Kfold(n::Int, k::Int) = Kfold(Random.GLOBAL_RNG, n, k)

length(c::Kfold) = c.k

struct KfoldState
Expand All @@ -42,10 +44,10 @@ struct StratifiedKfold <: CrossValGenerator
permseqs::Vector{Vector{Int}} #Vectors of vectors of indexes for each stratum
k::Int #Number of splits
coeffs::Vector{Float64} #About how many observations per strata are in a val set
function StratifiedKfold(strata, k)
function StratifiedKfold(rng::AbstractRNG, strata, k)
2 <= k <= length(strata) || error("The value of k must be in [2, length(strata)].")
strata_labels, permseqs = unique_inverse(strata)
map(shuffle!, permseqs)
map( s -> shuffle!(rng, s), permseqs)
coeffs = Float64[]
for (stratum, permseq) in zip(strata_labels, permseqs)
k <= length(permseq) || error("k is greater than the length of stratum $stratum")
Expand All @@ -55,6 +57,8 @@ struct StratifiedKfold <: CrossValGenerator
end
end

StratifiedKfold(strata, k) = StratifiedKfold(Random.GLOBAL_RNG, strata, k)

length(c::StratifiedKfold) = c.k

function Base.iterate(c::StratifiedKfold, s::Int=1)
Expand Down Expand Up @@ -95,16 +99,19 @@ end
# Repeated random sub-sampling

struct RandomSub <: CrossValGenerator
rng::AbstractRNG # Random number generator
n::Int # total length
sn::Int # length of each subset
k::Int # number of subsets
end

RandomSub(n::Int, sn::Int, k::Int) = RandomSub(Random.GLOBAL_RNG, n::Int, sn::Int, k::Int)

length(c::RandomSub) = c.k

function iterate(c::RandomSub, s::Int=1)
(s > c.k) && return nothing
return (sort!(sample(1:c.n, c.sn; replace=false)), s+1)
return (sort!(sample(c.rng, 1:c.n, c.sn; replace=false)), s+1)
end

# Stratified repeated random sub-sampling
Expand Down
55 changes: 55 additions & 0 deletions test/crossval.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using MLBase
using Random
using Test

## Kfold
Expand All @@ -12,6 +13,25 @@ end
x = vcat(map(s -> setdiff(1:12, s), ss)...)
@test sort(x) == collect(1:12)


# Verify that we are indeed controlling the random partitioning
# by specifying in the first argument a random number generator.
# The test involves calling Kfold with arbitrary chosen n,k pairs
# for the same seed choice picked from an arbitrarily defined set.
# Desired outcome is that for the same seed, we should get the same folds.

for seed in [1, 101, 10101]
for n in [20, 30, 31]
for k in [2, 3, 7]
aux1 = collect(Kfold(MersenneTwister(seed), n, k))
aux2 = collect(Kfold(MersenneTwister(seed), n, k))
@test aux1 == aux2 # folds must be the same
end
end
end



## StratifiedKfold

strat = [:a, :a, :b, :b, :c, :c, :b, :c, :a]
Expand All @@ -25,6 +45,19 @@ end
x = vcat(map(s -> setdiff(1:9, s), ss)...)
@test sort(x) == collect(1:9)


# Verify that we are indeed controlling the random partitioning
# by specifying in the first argument a random number generator.
# Try this out for a few arbitrarily chosen random seeds.
# Desired outcome is that for the same seed, we should get the same folds.

for seed in [1, 101, 10101]
aux1 = collect(StratifiedKfold(MersenneTwister(seed), strat, 3))
aux2 = collect(StratifiedKfold(MersenneTwister(seed), strat, 3))
@test aux1 == aux2 # folds must be the same
end


## LOOCV

ss = collect(LOOCV(4))
Expand All @@ -49,6 +82,28 @@ for i = 1:6
@test length(unique(ss[i])) == 5
end


# Verify that we are indeed controlling the random partitioning
# by specifying in the first argument a random number generator.
# The test involves calling RandomSub with arbitrary chosen
# n,k,sn arguments for for the same seed choice picked from an
# arbitrarily defined set.
# Desired outcome is that for the same seed, we should get the
# same folds.

for seed in [1, 101, 10101]
for n in [20, 30, 31]
for sn in [2, 7, 11]
for k in [2, 3, 7]
aux1 = collect(RandomSub(MersenneTwister(seed), n, k, sn))
aux2 = collect(RandomSub(MersenneTwister(seed), n, k, sn))
@test aux1 == aux2 # folds must be the same
end
end
end
end


## StratifiedRandomSub

strat = [:a, :a, :b, :b, :c, :c, :b, :c, :a]
Expand Down