Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
# ExplainableAI.jl
## Version `v0.9.0`
- ![Feature][badge-feature] Support selection of AD backend via DifferentiationInterface.jl ([#167])
- `Gradient`, `InputTimesGradient` and `GradCAM` analyzers now have an additional `backend` field and type parameter
- ![Maintenance][badge-maintenance] Update XAIBase interface to v4 ([#166])

## Version `v0.8.0`
This release removes the automatic reexport of heatmapping functionality.
Users are now required to manually load
Expand Down Expand Up @@ -210,6 +215,8 @@ Performance improvements:
[VisionHeatmaps]: https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/
[TextHeatmaps]: https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/

[#167]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/167
[#166]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/166
[#162]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/162
[#159]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/159
[#157]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/157
Expand Down
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
name = "ExplainableAI"
uuid = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b"
authors = ["Adrian Hill <[email protected]>"]
version = "0.8.1"
version = "0.9.0-DEV"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -12,6 +14,8 @@ XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1"
DifferentiationInterface = "0.5"
Distributions = "0.25"
Random = "<0.0.1, 1"
Reexport = "1"
Expand Down
5 changes: 5 additions & 0 deletions src/ExplainableAI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ import XAIBase: call_analyzer
using Base.Iterators
using Distributions: Distribution, Sampleable, Normal
using Random: AbstractRNG, GLOBAL_RNG

# Automatic differentiation
using ADTypes: AbstractADType, AutoZygote
using DifferentiationInterface: value_and_pullback
using Zygote
const DEFAULT_AD_BACKEND = AutoZygote()

include("compat.jl")
include("bibliography.jl")
Expand Down
13 changes: 11 additions & 2 deletions src/gradcam.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,25 @@ GradCAM is compatible with a wide variety of CNN model-families.
# References
- $REF_SELVARAJU_GRADCAM
"""
struct GradCAM{F,A} <: AbstractXAIMethod
struct GradCAM{F,A,B<:AbstractADType} <: AbstractXAIMethod
# NOTE(review): the two header lines above look like the removed and the added side of
# a diff hunk; only the second one (with type parameter `B`) matches the fields and the
# inner constructor below.

# Layers applied to the input to obtain the feature map.
feature_layers::F
# Layers applied to the feature map; their gradient w.r.t. the feature map is computed.
adaptation_layers::A
# Automatic differentiation backend handed to `gradient_wrt_input`.
backend::B

# Inner constructor: `backend` is optional and falls back to `DEFAULT_AD_BACKEND`.
function GradCAM(
feature_layers::F, adaptation_layers::A, backend::B=DEFAULT_AD_BACKEND
) where {F,A,B<:AbstractADType}
new{F,A,B}(feature_layers, adaptation_layers, backend)
end
end
function call_analyzer(input, analyzer::GradCAM, ns::AbstractOutputSelector; kwargs...)
A = analyzer.feature_layers(input) # feature map
feature_map_size = size(A, 1) * size(A, 2)

# Determine neuron importance αₖᶜ = 1/Z * ∑ᵢ ∑ⱼ ∂yᶜ / ∂Aᵢⱼᵏ
grad, output, output_indices = gradient_wrt_input(analyzer.adaptation_layers, A, ns)
grad, output, output_indices = gradient_wrt_input(
analyzer.adaptation_layers, A, ns, analyzer.backend
)
αᶜ = sum(grad; dims=(1, 2)) / feature_map_size
Lᶜ = max.(sum(αᶜ .* A; dims=3), 0)
return Explanation(Lᶜ, input, output, output_indices, :GradCAM, :cam, nothing)
Expand Down
57 changes: 42 additions & 15 deletions src/gradient.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,45 @@
function gradient_wrt_input(model, input, ns::AbstractOutputSelector)
output, back = Zygote.pullback(model, input)
output_indices = ns(output)

# Compute VJP w.r.t. full model output, selecting vector s.t. it masks output neurons
v = zero(output)
v[output_indices] .= 1
grad = only(back(v))
return grad, output, output_indices
"""
    forward_with_output_selection(model, input, selector)

Run `model` on `input` and return only the output entries picked by `selector`.

The `selector` is applied to the full output; its result indexes into that output.
"""
function forward_with_output_selection(model, input, selector::AbstractOutputSelector)
    full_output = model(input)
    return full_output[selector(full_output)]
end

"""
    gradient_wrt_input(model, input, output_selector, backend)

Evaluate `model` on `input`, then forward to the five-argument method of
`gradient_wrt_input` with the computed output.

Returns whatever that method returns: the tuple `(grad, output, output_selection)`.
"""
function gradient_wrt_input(
    model, input, output_selector::AbstractOutputSelector, backend::AbstractADType
)
    return gradient_wrt_input(model, input, model(input), output_selector, backend)
end

"""
    gradient_wrt_input(model, input, output, output_selector, backend)

Compute the vector-Jacobian product of `model` at `input` using the AD `backend`.

`output` is expected to be `model(input)`, as passed by the four-argument method.
It is used to run `output_selector` and to build the seed of the pullback:
zero everywhere except for ones at the selected output entries.

Returns the tuple `(grad, output, output_selection)`, where `output` is the value
reported by `value_and_pullback`.
"""
function gradient_wrt_input(
    model, input, output, output_selector::AbstractOutputSelector, backend::AbstractADType
)
    output_selection = output_selector(output)

    # Seed masks the output: only the selected neurons contribute to the VJP.
    seed = zero(output)
    seed[output_selection] .= 1

    primal, grad = value_and_pullback(model, backend, input, seed)
    return grad, primal, output_selection
end

"""
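    Gradient(model)
    Gradient(model, backend)

Analyze model by calculating the gradient of a neuron activation with respect to the input.
The gradient is computed with the automatic differentiation `backend`,
which defaults to `AutoZygote()`.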
"""
struct Gradient{M} <: AbstractXAIMethod
struct Gradient{M,B<:AbstractADType} <: AbstractXAIMethod
# NOTE(review): the two header lines above look like the removed and the added side of
# a diff hunk; only the second one (with type parameter `B`) matches the `backend` field
# and the two-argument inner constructor below.

# Model whose gradient with respect to the input is computed.
model::M
# NOTE(review): presumably the removed single-argument constructor of the old struct.
Gradient(model) = new{typeof(model)}(model)
# Automatic differentiation backend handed to `gradient_wrt_input`.
backend::B

# Inner constructor: `backend` is optional and falls back to `DEFAULT_AD_BACKEND`.
function Gradient(model::M, backend::B=DEFAULT_AD_BACKEND) where {M,B<:AbstractADType}
new{M,B}(model, backend)
end
end

function call_analyzer(input, analyzer::Gradient, ns::AbstractOutputSelector; kwargs...)
grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
grad, output, output_indices = gradient_wrt_input(
analyzer.model, input, ns, analyzer.backend
)
return Explanation(
grad, input, output, output_indices, :Gradient, :sensitivity, nothing
)
Expand All @@ -32,15 +51,23 @@ end
Analyze model by calculating the gradient of a neuron activation with respect to the input.
This gradient is then multiplied element-wise with the input.
"""
struct InputTimesGradient{M} <: AbstractXAIMethod
struct InputTimesGradient{M,B<:AbstractADType} <: AbstractXAIMethod
# NOTE(review): the two header lines above look like the removed and the added side of
# a diff hunk; only the second one (with type parameter `B`) matches the `backend` field
# and the two-argument inner constructor below.

# Model whose gradient with respect to the input is computed.
model::M
# NOTE(review): presumably the removed single-argument constructor of the old struct.
InputTimesGradient(model) = new{typeof(model)}(model)
# Automatic differentiation backend handed to `gradient_wrt_input`.
backend::B

# Inner constructor: `backend` is optional and falls back to `DEFAULT_AD_BACKEND`.
function InputTimesGradient(
model::M, backend::B=DEFAULT_AD_BACKEND
) where {M,B<:AbstractADType}
new{M,B}(model, backend)
end
end

function call_analyzer(
input, analyzer::InputTimesGradient, ns::AbstractOutputSelector; kwargs...
)
grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
grad, output, output_indices = gradient_wrt_input(
analyzer.model, input, ns, analyzer.backend
)
attr = input .* grad
return Explanation(
attr, input, output, output_indices, :InputTimesGradient, :attribution, nothing
Expand Down
Loading