Skip to content

Commit 5a3122d

Browse files
authored
Merge pull request #23 from IBM/pipe_or_expression
support :| operation for selecting best learners
2 parents daae3ab + e3a3760 commit 5a3122d

File tree

4 files changed

+17
-9
lines changed

4 files changed

+17
-9
lines changed

src/AutoMLPipeline.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,6 @@ include("basefilters.jl")
1414
using .BaseFilters
1515
export OneHotEncoder
1616

17-
include("pipelines.jl")
18-
using .Pipelines
19-
export @pipeline
20-
export @pipelinex
21-
export Pipeline, ComboPipeline
2217

2318
include("featureselector.jl")
2419
using .FeatureSelectors
@@ -48,4 +43,10 @@ include("skcrossvalidator.jl")
4843
using .SKCrossValidators
4944
export crossvalidate
5045

46+
include("pipelines.jl")
47+
using .Pipelines
48+
export @pipeline
49+
export @pipelinex
50+
export Pipeline, ComboPipeline
51+
5152
end # module

src/featureselector.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ mutable struct FeatureSelector <: Transformer
4242
end
4343
end
4444

45-
```
45+
"""
4646
FeatureSelector(cols::Vector{Int})
4747
4848
Helper function for FeatureSelector.
49-
```
49+
"""
5050
function FeatureSelector(cols::Vector{Int})
5151
FeatureSelector(Dict(:columns => cols))
5252
end
@@ -195,11 +195,11 @@ mutable struct CatNumDiscriminator <: Transformer
195195
end
196196
end
197197

198-
```
198+
"""
199199
CatNumDiscriminator(maxcat::Int)
200200
201201
Helper function for CatNumDiscriminator.
202-
```
202+
"""
203203
function CatNumDiscriminator(maxcat::Int)
204204
CatNumDiscriminator(Dict(:maxcategories=>maxcat))
205205
end

src/pipelines.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Random
66
using AutoMLPipeline.AbsTypes
77
using AutoMLPipeline.BaseFilters
88
using AutoMLPipeline.Utils
9+
using AutoMLPipeline.EnsembleMethods: BestLearner
910

1011
import AutoMLPipeline.AbsTypes: fit!, transform!
1112
export fit!, transform!
@@ -189,6 +190,8 @@ function processexpr(args)
189190
args[ndx] = :Pipeline
190191
elseif args[ndx] == :+
191192
args[ndx] = :ComboPipeline
193+
elseif args[ndx] == :|
194+
args[ndx] = :VoteEnsemble
192195
end
193196
end
194197
return args

test/test_pipeline.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,12 @@ function test_pipeline()
5252
catf = CatFeatureSelector()
5353
numf = NumFeatureSelector()
5454
rf = RandomForest()
55+
ada = Adaboost()
56+
pt = PrunedTree()
5557
pcombo3 = @pipeline disc |> ((catf + numf) + (numf |> pca) + (numf |> ica) + (catf|>ohe)) |> rf
5658
(fit_transform!(pcombo3,X,Y) .== Y) |> sum == 150
59+
pcombo4 = @pipeline (numf |> pca) + (numf |> ica) |> (ada | rf | pt)
60+
@test crossvalidate(pcombo4,X,Y).mean >= 0.90
5761
end
5862
@testset "Pipelines" begin
5963
Random.seed!(123)

0 commit comments

Comments
 (0)