Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions src/ExplainabilityMethods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,11 @@ export AbstractXAIMethod
export Gradient, InputTimesGradient
export LRP, LRPZero, LRPEpsilon, LRPGamma

const ANALYZERS = Dict(
"Gradient" => Gradient,
"InputTimesGradient" => InputTimesGradient,
"LRP" => LRP,
"LRPZero" => LRPZero,
"LRPEpsilon" => LRPEpsilon,
"LRPGamma" => LRPGamma,
)

# LRP rules
export AbstractLRPRule
export ZeroRule, EpsilonRule, GammaRule, ZBoxRule
export modify_layer, modify_params, modify_denominator

const RULES = Dict(
"ZeroRule" => ZeroRule,
"EpsilonRule" => EpsilonRule,
"GammaRule" => GammaRule,
"ZBoxRule" => ZBoxRule,
)

# heatmapping
export heatmap

Expand Down
Binary file removed test/references/vgg19/LRP.jld2
Binary file not shown.
19 changes: 13 additions & 6 deletions test/test_rules.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
using ExplainabilityMethods
using ExplainabilityMethods: RULES, modify_params
using ExplainabilityMethods: modify_params
using Flux
using LinearAlgebra
using ReferenceTests
using Random

const RULES = Dict(
"ZeroRule" => ZeroRule,
"EpsilonRule" => EpsilonRule,
"GammaRule" => GammaRule,
"ZBoxRule" => ZBoxRule,
)

## Hand-written tests
@testset "ZeroRule analytic" begin
rule = ZeroRule()

## Simple dense layer
Rₖ₊₁ = [1/3, 2/3]
Rₖ₊₁ = [1 / 3, 2 / 3]
aₖ = [1.0, 2.0]
W = [3.0 4.0; 5.0 6.0]
b = [7.0, 8.0]
Rₖ = [17/90, 316/675] # expected output
Rₖ = [17 / 90, 316 / 675] # expected output

layer = Dense(W, b, relu)
@test rule(layer, aₖ, Rₖ₊₁) ≈ Rₖ
Expand All @@ -26,10 +33,10 @@ using Random

# Repeat in color channel dim and add batch dim
Rₖ₊₁ = reshape(repeat(Rₖ₊₁, 1, 3), 2, 2, 3, 1)
aₖ = reshape(repeat(aₖ,1, 3), 3, 3, 3, 1)
Rₖ = reshape(repeat(Rₖ,1, 3), 3, 3, 3, 1)
aₖ = reshape(repeat(aₖ, 1, 3), 3, 3, 3, 1)
Rₖ = reshape(repeat(Rₖ, 1, 3), 3, 3, 3, 1)

layer = MaxPool((2,2), stride=(1,1))
layer = MaxPool((2, 2); stride=(1, 1))
@test rule(layer, aₖ, Rₖ₊₁) ≈ Rₖ
end

Expand Down
43 changes: 31 additions & 12 deletions test/test_vgg19.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
using ExplainabilityMethods
using ExplainabilityMethods: ANALYZERS
using Flux
using JLD2

const GRADIENT_ANALYZERS = Dict(
"Gradient" => Gradient, "InputTimesGradient" => InputTimesGradient
)
const LRP_ANALYZERS = Dict(
"LRPZero" => LRPZero, "LRPEpsilon" => LRPEpsilon, "LRPGamma" => LRPGamma
)

using Random
pseudorand(T, dims...) = rand(MersenneTwister(123), T, dims...)
img = pseudorand(Float32, (224, 224, 3, 1))
input_size = (224, 224, 3, 1)
img = pseudorand(Float32, input_size)

# Load VGG model:
# We run the reference test on the randomly intialized weights
Expand All @@ -14,25 +21,37 @@ include("./vgg19.jl")
vgg19 = VGG19(; pretrain=false)
model = flatten_chain(strip_softmax(vgg19.layers))

# Run analyzers
analyzers = ANALYZERS
function LRPCustom(model::Chain)
return LRP(model, [ZBoxRule(), repeat([GammaRule()], length(model.layers) - 1)...])
end
analyzers["LRPCustom"] = LRPCustom

for (name, method) in analyzers
println("Running tests on VGG16...")

function test_vgg16(name, method)
@time @testset "$name" begin
print("Timing $name on VGG19...")
if name == "LRP"
analyzer = method(model, ZeroRule())
else
analyzer = method(model)
end
print("Timing $name...\t")
analyzer = method(model)
expl, _ = analyze(img, analyzer)

@test size(expl) == size(img)
@test_reference "references/vgg19/$(name).jld2" Dict("expl" => expl) by =
(r, a) -> isapprox(r["expl"], a["expl"]; rtol=0.05)
end
return nothing
end

# Run analyzers
@testset "LRP analyzers" begin
for (name, method) in LRP_ANALYZERS
test_vgg16(name, method)
end
end
@testset "Custom LRP composite" begin
test_vgg16("LRPCustom", LRPCustom)
end

@testset "Gradient analyzers" begin
for (name, method) in GRADIENT_ANALYZERS
test_vgg16(name, method)
end
end