Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/pipelines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,25 @@ function transform!(pipe::ComboPipeline, features::DataFrame=DataFrame())::Union
new_instances = DataFrame()
for t_index in eachindex(machines)
machine = machines[t_index]
current_instances = transform!(machine, instances)
new_instances = hcat(new_instances,current_instances,makeunique=true)
current_instances = transform!(machine, instances)
new_instances = mcat(new_instances,current_instances)
end

return new_instances
end

# mcat: horizontal concatenation, dispatched on whether each argument is a DataFrame or a Vector
"""
    mcat(x::DataFrame, y::DataFrame)

Horizontally concatenate two data frames by forwarding to `hcat` with
`makeunique=true`.
"""
mcat(x::DataFrame, y::DataFrame) = hcat(x, y; makeunique=true)

"""
    mcat(x::DataFrame, y::Vector)

Wrap `y` in a one-column `DataFrame` (column name `v`) and horizontally
concatenate it to the right of `x` using `hcat` with `makeunique=true`.
"""
function mcat(x::DataFrame, y::Vector)
    ycol = DataFrame(v=y)
    return hcat(x, ycol; makeunique=true)
end

"""
    mcat(x::Vector, y::DataFrame)

Wrap `x` in a one-column `DataFrame` (column name `v`) and horizontally
concatenate `y` to its right using `hcat` with `makeunique=true`.
"""
function mcat(x::Vector, y::DataFrame)
    xcol = DataFrame(v=x)
    return hcat(xcol, y; makeunique=true)
end

"""
    transform(pipe::ComboPipeline, features::DataFrame=DataFrame())

Delegate to `transform!(pipe, features)` and return its result.

No copy of `pipe` or `features` is made here; both are handed to
`transform!` exactly as received.
"""
function transform(pipe::ComboPipeline, features::DataFrame=DataFrame())::Union{Vector,DataFrame}
    result = transform!(pipe, features)
    return result
end
Expand Down
6 changes: 3 additions & 3 deletions test/test_featureselector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@ function iris_test()
@test (fit_transform(catf,X) .== X[:,5]) |> Matrix |> sum == 150
@test (fit_transform!(numf,X) .== X[:,1:4]) |> Matrix |> sum == 600
@test (fit_transform(numf,X) .== X[:,1:4]) |> Matrix |> sum == 600
catnumdata = hcat(X,repeat([1,2,3,4,5],30))
catnumdata = hcat(X,DataFrame(x1=repeat([1,2,3,4,5],30)))
catnum = CatNumDiscriminator()
res = fit_transform!(catnum,catnumdata)
@test infer_eltype(catnumdata[:,[2,4,6]]) <: Number
@test infer_eltype(res[:,[2,4,6]]) <: String
catnumdata = hcat(X,repeat([1,2,3,4,5],30))
catnumdata = hcat(X,DataFrame(x1=repeat([1,2,3,4,5],30)))
catnum = CatNumDiscriminator(0)
res = fit_transform!(catnum,catnumdata)
@test eltype(res[:,6]) <: Number
catnumdata = hcat(X,repeat([1,2,3,4,5],30))
catnumdata = hcat(X,DataFrame(x1=repeat([1,2,3,4,5],30)))
res1 = fit_transform(catnum,catnumdata)
@test eltype(res1[:,6]) <: Number
end
Expand Down
16 changes: 15 additions & 1 deletion test/test_pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ const noop = Identity()
const rf = RandomForest()
const ada = Adaboost()
const pt = PrunedTree()
const numf = NumFeatureSelector()
const catf = CatFeatureSelector()

# Scoring helper: forwards both arguments to `score` with the `:accuracy` metric.
function acc(X, Y)
    return score(:accuracy, X, Y)
end

function test_pipeline()
# test initialization of types
Expand Down Expand Up @@ -58,7 +62,6 @@ end
test_pipeline()
end

acc(X,Y) = score(:accuracy,X,Y)

function test_sympipeline()
pcombo5 = @pipeline :((ohe + noop) |> (ada * rf * pt))
Expand All @@ -78,6 +81,17 @@ end
test_sympipeline()
end

# Builds a nested pipeline expression in which the learners `ada` and `rf`
# appear in non-terminal positions, then checks that the `mean` field of the
# `crossvalidate` result is at least 0.90.
function test_advancedpipeline()
    pipe = @pipeline (((ada + rf) |> ohe) + numf) |> ada
    cv = crossvalidate(pipe, X, Y, acc, 5, false)
    @test cv.mean >= 0.90
end
# Test driver for the advanced-pipeline case (learners used as filters).
@testset "Advanced Pipeline: Learners as filters" begin
# Seed the global RNG so the run starts from a fixed random state;
# presumably this keeps the cross-validation score reproducible.
Random.seed!(123)
test_advancedpipeline()
end



function test_pipeline()
# test symbolic pipeline expression
pcombo2 = @pipeline ohe + noop
Expand Down