Skip to content

Commit 4196305

Browse files
authored
Merge pull request #1064 from alan-turing-institute/mljbalancing
Add MLJBalancing to MLJ and add class imbalance docs
2 parents 7b32802 + 8253bfb commit 4196305

10 files changed

+93
-37
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
99
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1010
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12+
MLJBalancing = "45f359ea-796d-4f51-95a5-deb1a414c586"
1213
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
1314
MLJEnsembles = "50ed68f4-41fd-4504-931a-ed422449fee0"
1415
MLJFlow = "7b7b8358-b45c-48ea-a8ef-7ca328ad328f"
@@ -31,6 +32,7 @@ CategoricalArrays = "0.8,0.9, 0.10"
3132
ComputationalResources = "0.3"
3233
Distributions = "0.21,0.22,0.23, 0.24, 0.25"
3334
MLJBase = "1"
35+
MLJBalancing = "0.1"
3436
MLJEnsembles = "0.4"
3537
MLJFlow = "0.2"
3638
MLJIteration = "0.6"
@@ -40,8 +42,8 @@ OpenML = "0.2,0.3"
4042
ProgressMeter = "1.1"
4143
Reexport = "1.2"
4244
ScientificTypes = "3"
43-
StatsBase = "0.32,0.33, 0.34"
4445
StatisticalMeasures = "0.1"
46+
StatsBase = "0.32,0.33, 0.34"
4547
Tables = "0.2,1.0"
4648
julia = "1.6"
4749

docs/ModelDescriptors.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,20 @@ AgglomerativeClustering_MLJScikitLearnInterface = ["clustering", "static_models"
1010
BM25Transformer_MLJText = ["encoders", "text_analysis"]
1111
BaggingClassifier_MLJScikitLearnInterface = ["classification", "ensemble_models"]
1212
BaggingRegressor_MLJScikitLearnInterface = ["regression", "ensemble_models"]
13+
BalancedBaggingClassifier_MLJBalancing = ["class_imbalance", "classification"]
1314
BayesianLDA_MultivariateStats = ["dimension_reduction", "classification", "Bayesian_models"]
1415
BayesianLDA_MLJScikitLearnInterface = ["dimension_reduction", "classification", "Bayesian_models"]
1516
BayesianQDA_MLJScikitLearnInterface = ["dimension_reduction", "classification", "Bayesian_models"]
1617
BayesianRidgeRegressor_MLJScikitLearnInterface = ["regression", "Bayesian_models"]
1718
BayesianSubspaceLDA_MultivariateStats = ["dimension_reduction", "classification", "Bayesian_models"]
1819
BernoulliNBClassifier_MLJScikitLearnInterface = ["classification", "Bayesian_models"]
1920
Birch_MLJScikitLearnInterface = ["clustering", "dimension_reduction", ]
21+
BorderlineSMOTE1_Imbalance = ["class_imbalance"]
2022
CatBoostClassifier_CatBoost = ["classification", "ensemble_models", "iterative_models"]
2123
CatBoostRegressor_CatBoost = ["regression", "ensemble_models", "iterative_models"]
2224
CBLOFDetector_OutlierDetectionPython = ["outlier_detection"]
2325
CDDetector_OutlierDetectionPython = ["outlier_detection"]
26+
ClusterUndersampler_Imbalance = ["class_imbalance"]
2427
COFDetector_OutlierDetectionNeighbors = ["outlier_detection"]
2528
COFDetector_OutlierDetectionPython = ["outlier_detection"]
2629
COPODDetector_OutlierDetectionPython = ["outlier_detection"]
@@ -46,6 +49,7 @@ ESADDetector_OutlierDetectionNetworks = ["outlier_detection"]
4649
ElasticNetCVRegressor_MLJScikitLearnInterface = ["regression"]
4750
ElasticNetRegressor_MLJLinearModels = ["regression"]
4851
ElasticNetRegressor_MLJScikitLearnInterface = ["regression"]
52+
ENNUndersampler_Imbalance = ["class_imbalance"]
4953
EpsilonSVR_LIBSVM = ["regression"]
5054
EvoLinearRegressor_EvoLinear = ["regression"]
5155
EvoTreeClassifier_EvoTrees = ["classification", "ensemble_models", "iterative_models"]
@@ -167,8 +171,12 @@ ProbabilisticNuSVC_LIBSVM = ["classification"]
167171
ProbabilisticSGDClassifier_MLJScikitLearnInterface = ["classification"]
168172
ProbabilisticSVC_LIBSVM = ["classification"]
169173
QuantileRegressor_MLJLinearModels = ["regression"]
174+
RandomOversampler_Imbalance = ["class_imbalance"]
175+
RandomUndersampler_Imbalance = ["class_imbalance"]
176+
RandomWalkOversampler_Imbalance = ["class_imbalance"]
170177
RANSACRegressor_MLJScikitLearnInterface = ["regression"]
171178
RODDetector_OutlierDetectionPython = ["outlier_detection"]
179+
ROSE_Imbalance = ["class_imbalance"]
172180
RandomForestClassifier_BetaML = ["classification", "ensemble_models", "iterative_models"]
173181
RandomForestClassifier_DecisionTree = ["classification", "ensemble_models", "iterative_models"]
174182
RandomForestClassifier_MLJScikitLearnInterface = ["classification", "ensemble_models", "iterative_models"]
@@ -186,6 +194,9 @@ RobustRegressor_MLJLinearModels = ["regression"]
186194
SelfOrganizingMap_SelfOrganizingMaps = ["dimension_reduction", "clustering"]
187195
SGDClassifier_MLJScikitLearnInterface = ["classification"]
188196
SGDRegressor_MLJScikitLearnInterface = ["regression"]
197+
SMOTE_Imbalance = ["class_imbalance"]
198+
SMOTEN_Imbalance = ["class_imbalance"]
199+
SMOTENC_Imbalance = ["class_imbalance"]
189200
SODDetector_OutlierDetectionPython = ["outlier_detection", "outlier_detection"]
190201
SOSDetector_OutlierDetectionPython = ["outlier_detection"]
191202
SRRegressor_SymbolicRegression = ["regression"]
@@ -204,6 +215,7 @@ SimpleImputer_BetaML = ["missing_value_imputation"]
204215
SpectralClustering_MLJScikitLearnInterface = ["clustering", "static_models"]
205216
Standardizer_MLJModels = ["encoders"]
206217
SubspaceLDA_MultivariateStats = ["classification", "dimension_reduction"]
218+
TomekUndersampler_Imbalance = ["class_imbalance"]
207219
TSVDTransformer_TSVD = ["dimension_reduction"]
208220
TfidfTransformer_MLJText = ["encoders", "text_analysis"]
209221
TheilSenRegressor_MLJScikitLearnInterface = ["regression"]

docs/Project.toml

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,18 @@ CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
44
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
55
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
66
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
7-
EarlyStopping = "792122b4-ca99-40de-a6bc-6742525f08b6"
87
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
98
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
109
IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
1110
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
1211
MLJClusteringInterface = "d354fa79-ed1c-40d4-88ef-b8c7bd1568af"
1312
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
14-
MLJEnsembles = "50ed68f4-41fd-4504-931a-ed422449fee0"
15-
MLJFlow = "7b7b8358-b45c-48ea-a8ef-7ca328ad328f"
1613
MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c"
17-
MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
1814
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
19-
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
20-
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
2115
MLJMultivariateStatsInterface = "1b6a4a23-ba22-4f51-9698-8599985d3728"
22-
MLJTuning = "03970b2e-30c4-11ea-3135-d1576263f10f"
2316
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
2417
NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36"
2518
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
26-
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
2719
ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
2820
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
2921
StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"

docs/make.jl

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@ end
55
using Pkg
66
using Documenter
77
using MLJ
8-
import MLJIteration
9-
import IterationControl
10-
import EarlyStopping
11-
import MLJBase
12-
import MLJTuning
13-
import MLJModels
14-
import MLJEnsembles
15-
import ScientificTypes
16-
import MLJModelInterface
17-
import ScientificTypes
8+
using MLJBase
9+
import MLJ.MLJBase.MLJModelInterface
10+
import MLJ.MLJIteration
11+
import MLJ.MLJIteration.IterationControl
12+
import MLJ.MLJIteration.IterationControl.EarlyStopping
13+
import MLJ.MLJTuning
14+
import MLJ.MLJModels
15+
import MLJ.MLJEnsembles
16+
import MLJ.ScientificTypes
17+
import MLJ.MLJBalancing
1818
import ScientificTypesBase
1919
import Distributions
2020
using CategoricalArrays
@@ -72,6 +72,7 @@ pages = [
7272
"Linear Pipelines" => "linear_pipelines.md",
7373
"Target Transformations" => "target_transformations.md",
7474
"Homogeneous Ensembles" => "homogeneous_ensembles.md",
75+
"Correcting Class Imbalance" => "correcting_class_imbalance.md",
7576
"Model Stacking" => "model_stacking.md",
7677
"Learning Networks" => "learning_networks.md",
7778
"Controlling Iterative Models" => "controlling_iterative_models.md",
@@ -101,20 +102,23 @@ makedocs(
101102
doctest = true,
102103
sitename = "MLJ",
103104
format = Documenter.HTML(),
104-
modules = [MLJ,
105-
MLJBase,
106-
MLJTuning,
107-
MLJModels,
108-
MLJEnsembles,
109-
ScientificTypes,
110-
MLJModelInterface,
111-
ScientificTypesBase,
112-
StatisticalMeasures,
113-
MLJIteration,
114-
EarlyStopping,
115-
IterationControl,
116-
CategoricalDistributions,
117-
StatisticalMeasures],
105+
modules = [
106+
MLJ,
107+
MLJBase,
108+
MLJTuning,
109+
MLJModels,
110+
MLJEnsembles,
111+
MLJBalancing,
112+
MLJIteration,
113+
ScientificTypes,
114+
MLJModelInterface,
115+
ScientificTypesBase,
116+
StatisticalMeasures,
117+
EarlyStopping,
118+
IterationControl,
119+
CategoricalDistributions,
120+
StatisticalMeasures,
121+
],
118122
pages = pages,
119123
warnonly = [:cross_references, :missing_docs],
120124
)

docs/model_docstring_tools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ const HANDLES = keys(DESCRIPTORS_GIVEN_HANDLE)
6262
"""
6363
models_missing_descriptors()
6464
65-
Return a list of handles for those models in the registry not have the corresponding
65+
Return a list of handles for those models in the registry not having the corresponding
6666
handle as key in /docs/ModelDescriptors.toml.
6767
6868
"""
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Correcting Class Imbalance
2+
3+
## Oversampling and undersampling methods
4+
5+
Models providing oversampling or undersampling methods, to correct for class imbalance,
6+
are listed under [Class Imbalance](@ref). In particular, several popular algorithms are
7+
provided by the [Imbalance.jl](https://github.com/JuliaAI/Imbalance.jl) package, which includes detailed documentation and
8+
tutorials.
9+
10+
## Incorporating class imbalance in supervised learning pipelines
11+
12+
One or more oversampling/undersampling algorithms can be fused with an MLJ classifier
13+
using the [`BalancedModel`](@ref) wrapper. This creates a new classifier which can be
14+
treated like any other; resampling to correct for class imbalance, relevant only for
15+
*training* of the atomic classifier, is then carried out internally. If, for example, one
16+
applies cross-validation to the wrapped classifier (using [`evaluate!`](@ref), say) then
17+
this means over/undersampling is repeated for each training fold automatically.
18+
19+
Refer to the
20+
[MLJBalancing.jl](https://juliaai.github.io/Imbalance.jl/dev/algorithms/mlj_balancing/)
21+
documentation for further details.
22+
23+
```@docs
24+
MLJBalancing.BalancedModel
25+
```

docs/src/index.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ To support MLJ development, please cite these works or star the repo:
4949
[Working with Categorical Data](@ref) |
5050
[Preparing Data](@ref) |
5151
[Generating Synthetic Data](@ref) |
52-
[OpenML Integration](@ref)
52+
[OpenML Integration](@ref) |
53+
[Correcting Class Imbalance](@ref)
5354

5455
### Models
5556
[Model Search](@ref model_search) |
@@ -65,15 +66,18 @@ To support MLJ development, please cite these works or star the repo:
6566
[Evaluating Model Performance](@ref) |
6667
[Tuning Models](@ref) |
6768
[Controlling Iterative Models](@ref) |
68-
[Learning Curves](@ref)
69+
[Learning Curves](@ref) |
70+
[Correcting Class Imbalance](@ref)
6971

7072
### Composition
7173
[Composing Models](@ref) |
7274
[Linear Pipelines](@ref) |
7375
[Target Transformations](@ref) |
7476
[Homogeneous Ensembles](@ref) |
7577
[Model Stacking](@ref) |
76-
[Learning Networks](@ref)
78+
[Learning Networks](@ref) |
79+
[Correcting Class Imbalance](@ref)
80+
7781

7882
### Integration
7983
[Logging Workflows](@ref) |

docs/src/list_of_supported_models.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# [List of Supported Models](@id model_list)
22

3+
For a list of models organized around function ("classification", "regression", etc.), see
4+
the [Model Browser](@ref).
5+
36
MLJ provides access to a wide variety of machine learning models.
47
We are always looking for
58
[help](https://github.com/alan-turing-institute/MLJ.jl/blob/master/CONTRIBUTING.md)
@@ -34,9 +37,11 @@ independent assessment.
3437
[EvoTrees.jl](https://github.com/Evovest/EvoTrees.jl) | - | EvoTreeRegressor, EvoTreeClassifier, EvoTreeCount, EvoTreeGaussian, EvoTreeMLE | medium | tree-based gradient boosting models
3538
[EvoLinear.jl](https://github.com/jeremiedb/EvoLinear.jl) | - | EvoLinearRegressor | medium | linear boosting models
3639
[GLM.jl](https://github.com/JuliaStats/GLM.jl) | [MLJGLMInterface.jl](https://github.com/JuliaAI/MLJGLMInterface.jl) | LinearRegressor, LinearBinaryClassifier, LinearCountRegressor | medium² |
40+
[Imbalance.jl](https://github.com/JuliaAI/Imbalance.jl) | - | RandomOversampler, RandomWalkOversampler, ROSE, SMOTE, BorderlineSMOTE1, SMOTEN, SMOTENC, RandomUndersampler, ClusterUndersampler, ENNUndersampler, TomekUndersampler | low |
3741
[LIBSVM.jl](https://github.com/mpastell/LIBSVM.jl) | [MLJLIBSVMInterface.jl](https://github.com/JuliaAI/MLJLIBSVMInterface.jl) | LinearSVC, SVC, NuSVC, NuSVR, EpsilonSVR, OneClassSVM | high | also via ScikitLearn.jl
3842
[LightGBM.jl](https://github.com/IQVIA-ML/LightGBM.jl) | - | LGBMClassifier, LGBMRegressor | high |
3943
[Flux.jl](https://github.com/FluxML/Flux.jl) | [MLJFlux.jl](https://github.com/FluxML/MLJFlux.jl) | NeuralNetworkRegressor, NeuralNetworkClassifier, MultitargetNeuralNetworkRegressor, ImageClassifier | low |
44+
[MLJBalancing.jl](https://github.com/JuliaAI/MLJBalancing.jl) | - | BalancedBaggingClassifier | low |
4045
[MLJLinearModels.jl](https://github.com/JuliaAI/MLJLinearModels.jl) | - | LinearRegressor, RidgeRegressor, LassoRegressor, ElasticNetRegressor, QuantileRegressor, HuberRegressor, RobustRegressor, LADRegressor, LogisticClassifier, MultinomialClassifier | medium |
4146
[MLJModels.jl](https://github.com/JuliaAI/MLJModels.jl) (built-in) | - | ConstantClassifier, ConstantRegressor, ContinuousEncoder, DeterministicConstantClassifier, DeterministicConstantRegressor, FeatureSelector, FillImputer, InteractionTransformer, OneHotEncoder, Standardizer, UnivariateBoxCoxTransformer, UnivariateDiscretizer, UnivariateFillImputer, UnivariateTimeTypeToContinuous, Standardizer, BinaryThresholdPredictor | medium |
4247
[MLJText.jl](https://github.com/JuliaAI/MLJText.jl) | - | TfidfTransformer, BM25Transformer, CountTransformer | low |

src/MLJ.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ using MLJModels
5050
using OpenML
5151
@reexport using MLJFlow
5252
@reexport using StatisticalMeasures
53+
import MLJBalancing
54+
@reexport using MLJBalancing: BalancedModel
5355
using MLJIteration
5456
import MLJIteration.IterationControl
5557

test/exported_names.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@ IterationControl.with_state_do(Step(2))
1212
IteratedModel
1313
MLJIteration
1414

15+
# MLJBalancing
16+
17+
bmodel = @test_logs(
18+
(:warn, r"^No balancer"),
19+
BalancedModel(model=ConstantClassifier()),
20+
)
21+
22+
@test bmodel isa Probabilistic
23+
24+
1525
# MLJSerialization
1626

1727
Save()

0 commit comments

Comments
 (0)