Skip to content

Commit e2d15dd

Browse files
committed
make StatisticalMeasures a [weakdep]; add DefaultMeasuresExt.jl
1 parent 91a6f54 commit e2d15dd

File tree

14 files changed

+52
-15
lines changed

14 files changed

+52
-15
lines changed

Project.toml

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,27 +26,35 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2626
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
2727
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
2828
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
29+
StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
2930
StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9"
3031
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3132
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3233
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
3334

35+
[weakdeps]
36+
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
37+
38+
[extensions]
39+
DefaultMeasuresExt = "StatisticalMeasures"
40+
3441
[compat]
3542
CategoricalArrays = "0.9, 0.10"
3643
CategoricalDistributions = "0.1"
3744
ComputationalResources = "0.3"
3845
Distributions = "0.25.3"
3946
InvertedIndices = "1"
47+
LearnAPI = "0.1"
4048
MLJModelInterface = "1.7"
4149
Missings = "0.4, 1"
42-
LearnAPI = "0.1"
4350
OrderedCollections = "1.1"
4451
Parameters = "0.12"
4552
PrettyTables = "1, 2"
4653
ProgressMeter = "1.7.1"
4754
Reexport = "1.2"
4855
ScientificTypes = "3"
49-
StatisticalMeasures = "0.1"
56+
StatisticalMeasures = "0.1.1"
57+
StatisticalMeasuresBase = "0.1.1"
5058
StatisticalTraits = "3.2"
5159
StatsBase = "0.32, 0.33, 0.34"
5260
Tables = "0.2, 1.0"
@@ -60,8 +68,9 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
6068
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
6169
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
6270
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
71+
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
6372
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6473
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
6574

6675
[targets]
67-
test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]
76+
test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "StatisticalMeasures", "Test", "TypedTables"]

ext/DefaultMeasuresExt.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module DefaultMeasuresExt
2+
3+
using MLJBase
4+
import MLJBase:default_measure, ProbabilisticDetector, DeterministicDetector
5+
using StatisticalMeasures
6+
using StatisticalMeasures.ScientificTypesBase
7+
8+
default_measure(::Deterministic, ::Type{<:Union{Continuous,Count}}) = l2
9+
default_measure(::Deterministic, ::Type{<:Finite}) = misclassification_rate
10+
default_measure(::Probabilistic, ::Type{<:Union{Finite,Count}}) = log_loss
11+
default_measure(::Probabilistic, ::Type{<:Continuous}) = log_loss
12+
default_measure(::ProbabilisticDetector, ::Type{<:OrderedFactor{2}}) = area_under_curve
13+
default_measure(::DeterministicDetector, ::Type{<:OrderedFactor{2}}) = balanced_accuracy
14+
15+
end # module

src/MLJBase.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,7 @@ import Distributions: pdf, logpdf, sampler
8989
const Dist = Distributions
9090

9191
# Measures
92-
@reexport using StatisticalMeasures
93-
import StatisticalMeasures.StatisticalMeasuresBase
92+
import StatisticalMeasuresBase
9493

9594
# from Standard Library:
9695
using Statistics, LinearAlgebra, Random, InteractiveUtils
@@ -312,4 +311,10 @@ export default_measure
312311
export pdf, sampler, mode, median, mean, shuffle!, categorical, shuffle,
313312
levels, levels!, std, Not, support, logpdf, LittleDict
314313

314+
# for julia < 1.9
315+
if !isdefined(Base, :get_extension)
316+
include(joinpath("..","ext", "DefaultMeasuresExt.jl"))
317+
@reexport using .DefaultMeasuresExt.StatisticalMeasures
318+
end
319+
315320
end # module

src/composition/models/stacking.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ function internal_stack_report(
425425
model_results.operation,
426426
))
427427
ypred = operation(mach, Xtest)
428-
measurements = StatisticalMeasures.measurements(measure, ypred, ytest)
428+
measurements = StatisticalMeasuresBase.measurements(measure, ypred, ytest)
429429

430430
# Update per observation:
431431
model_results.per_observation[i][foldid] = measurements

src/default_measures.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@ reliably inferred.
99
For Julia 1.9 and higher, `nothing` is returned, unless StatisticalMeasures.jl is
1010
loaded.
1111
12+
# New implementations
13+
14+
This method dispatches `default_measure(model, observation_scitype)`, which has
15+
`nothing` as the fallback return value. Extend `default_measure` by overloading this
16+
version of the method. See for example the MLJBase.jl package extension,
17+
DefaultMeausuresExt.jl.
18+
1219
"""
1320
default_measure(m) = nothing
1421
default_measure(m::Union{Supervised,Annotator}) =
1522
default_measure(m, nonmissingtype(guess_model_target_observation_scitype(m)))
1623
default_measure(m, S) = nothing
17-
default_measure(::Deterministic, ::Type{<:Union{Continuous,Count}}) = l2
18-
default_measure(::Deterministic, ::Type{<:Finite}) = misclassification_rate
19-
default_measure(::Probabilistic, ::Type{<:Union{Finite,Count}}) = log_loss
20-
default_measure(::Probabilistic, ::Type{<:Continuous}) = log_loss
21-
default_measure(::ProbabilisticDetector, ::Type{<:OrderedFactor{2}}) = area_under_curve
22-
default_measure(::DeterministicDetector, ::Type{<:OrderedFactor{2}}) = balanced_accuracy

test/composition/learning_networks/deprecated_machines.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using MLJBase
99
using Tables
1010
using StableRNGs
1111
using Serialization
12+
using StatisticalMeasures
1213
rng = StableRNG(616161)
1314

1415
# A dummy clustering model:

test/composition/learning_networks/nodes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using MLJBase
66
using ..Models
77
using ..TestUtilities
88
using CategoricalArrays
9+
using StatisticalMeasures
910
import Random.seed!
1011
seed!(1234)
1112

test/composition/learning_networks/signatures.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using Tables
77
using Test
88
using MLJModelInterface
99
using OrderedCollections
10+
using StatisticalMeasures
1011

1112
@testset "signatures - accessor functions" begin
1213
a = source(:a)

test/composition/models/network_composite.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module TestNetowrkComposite
1+
module TestNetoworkComposite
22

33
using Test
44
using MLJBase
@@ -9,6 +9,7 @@ using Tables
99
using MLJModelInterface
1010
using CategoricalArrays
1111
using OrderedCollections
12+
using StatisticalMeasures
1213
using Serialization
1314

1415
const MMI = MLJModelInterface

test/composition/models/stacking.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ module TestStacking
22

33
using Test
44
using MLJBase
5+
using StatisticalMeasures
56
using MLJModelInterface
67
using ..Models
78
using Random
89
using StableRNGs
9-
1010
import Distributions
1111

1212
rng = StableRNGs.StableRNG(1234)

0 commit comments

Comments
 (0)