-
Notifications
You must be signed in to change notification settings - Fork 74
add example for symmetric matrices #2545
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/docs/src/notebooks/structured.jl b/docs/src/notebooks/structured.jl
index 359b215..f21e9c7 100644
--- a/docs/src/notebooks/structured.jl
+++ b/docs/src/notebooks/structured.jl
@@ -28,19 +28,19 @@ import EnzymeCore: EnzymeRules
# ╔═╡ 4aa3740e-211c-4fff-9707-b8731a3fa57f
begin
- struct MySymmetric{T,S<:AbstractMatrix{<:T}} <: AbstractMatrix{T}
- data::S
- uplo::Char
-
- function MySymmetric{T,S}(data, uplo::Char) where {T,S<:AbstractMatrix{<:T}}
- LinearAlgebra.require_one_based_indexing(data)
- (uplo != 'U' && uplo != 'L') && LinearAlgebra.throw_uplo()
- new{T,S}(data, uplo)
- end
- end
- function MySymmetric(A, uplo='U')
- MySymmetric{eltype(A), typeof(A)}(A, 'U')
- end
+ struct MySymmetric{T, S <: AbstractMatrix{<:T}} <: AbstractMatrix{T}
+ data::S
+ uplo::Char
+
+ function MySymmetric{T, S}(data, uplo::Char) where {T, S <: AbstractMatrix{<:T}}
+ LinearAlgebra.require_one_based_indexing(data)
+ (uplo != 'U' && uplo != 'L') && LinearAlgebra.throw_uplo()
+ return new{T, S}(data, uplo)
+ end
+ end
+ function MySymmetric(A, uplo = 'U')
+ return MySymmetric{eltype(A), typeof(A)}(A, 'U')
+ end
end
# ╔═╡ d6572143-fc11-4fdc-9c23-34867b33ad85
@@ -57,13 +57,15 @@ end
# ╔═╡ 8f6588f2-754b-435c-9998-38da8f6b14ad
begin
- Base.size(A::MySymmetric) = size(A.data)
- Base.length(A::MySymmetric) = length(A.data)
+ Base.size(A::MySymmetric) = size(A.data)
+ Base.length(A::MySymmetric) = length(A.data)
end
# ╔═╡ de04ba90-a397-4b75-b6b4-977d5881e848
-x = [1.0 0.0
- 0.0 1.0]
+x = [
+ 1.0 0.0
+ 0.0 1.0
+]
# ╔═╡ 9beada1d-6281-4dee-9049-8a66af1199a4
norm(x)
@@ -72,11 +74,13 @@ norm(x)
Enzyme.gradient(Reverse, norm, x) |> only
# ╔═╡ 4be9a16a-c14b-4378-aa02-fc4bfa783d10
-Enzyme.gradient(Reverse, norm, MySymmetric(x))|> only
+Enzyme.gradient(Reverse, norm, MySymmetric(x)) |> only
# ╔═╡ de7e14eb-4e55-4661-8614-e8026d09e6d3
-x2 = [0.0 1.0
- 1.0 0.0]
+x2 = [
+ 0.0 1.0
+ 1.0 0.0
+]
# ╔═╡ 999fb61b-20c6-4dcf-ad34-eca257bfda9f
d_x2 = Enzyme.gradient(Reverse, norm, x2) |> only
@@ -91,26 +95,26 @@ d_x2 == d_x2_sym
sum(d_x2) == sum(d_x2_sym.data)
# ╔═╡ 827485bf-0973-4650-a969-6225f72e5d6a
- Symmetric(x2) |> dump
+Symmetric(x2) |> dump
# ╔═╡ dbe34880-93bf-4a5d-b28b-5e6b76267742
- d_x2_sym |> dump
+d_x2_sym |> dump
# ╔═╡ 0122a4df-75d9-444e-8d83-d7a93b6dfeb5
begin
- struct MySymmetric2{T,S<:AbstractMatrix{<:T}} <: AbstractMatrix{T}
- data::S
- uplo::Char
-
- function MySymmetric2{T,S}(data, uplo::Char) where {T,S<:AbstractMatrix{<:T}}
- LinearAlgebra.require_one_based_indexing(data)
- (uplo != 'U' && uplo != 'L') && LinearAlgebra.throw_uplo()
- new{T,S}(data, uplo)
- end
- end
- function MySymmetric2(A, uplo='U')
- MySymmetric2{eltype(A), typeof(A)}(A, 'U')
- end
+ struct MySymmetric2{T, S <: AbstractMatrix{<:T}} <: AbstractMatrix{T}
+ data::S
+ uplo::Char
+
+ function MySymmetric2{T, S}(data, uplo::Char) where {T, S <: AbstractMatrix{<:T}}
+ LinearAlgebra.require_one_based_indexing(data)
+ (uplo != 'U' && uplo != 'L') && LinearAlgebra.throw_uplo()
+ return new{T, S}(data, uplo)
+ end
+ end
+ function MySymmetric2(A, uplo = 'U')
+ return MySymmetric2{eltype(A), typeof(A)}(A, 'U')
+ end
end
# ╔═╡ 8497709d-d123-48bc-a86e-5f58aa1b0ebc
@@ -149,8 +153,8 @@ md"""
# ╔═╡ d0a031e4-99a4-417b-8f57-58a67219fa23
begin
- Base.size(A::MySymmetric2) = size(A.data)
- Base.length(A::MySymmetric2) = length(A.data)
+ Base.size(A::MySymmetric2) = size(A.data)
+ Base.length(A::MySymmetric2) = length(A.data)
end
# ╔═╡ d047c162-446f-4de4-b2fd-5f2550f0ad78
@@ -160,34 +164,36 @@ Now we can implement a rule where we adjust the gradient contribution to be half
# ╔═╡ e81ac66b-75ba-4e88-8e9c-491a60a671dc
begin
- function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.getindex)}, ::Type{<:Active}, S::Duplicated{<:MySymmetric2}, i::Const, j::Const)
- # Compute primal
- if needs_primal(config)
- primal = func.val(S.val, i.val, j.val)
- else
- primal = nothing
- end
-
- # Return an AugmentedReturn object with shadow = nothing
- return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
- end
-
- function EnzymeRules.reverse(config, ::Const{typeof(Base.getindex)}, dret::Active, tape,
- S::Duplicated{<:MySymmetric2}, i::Const, j::Const)
- i = i.val
- j = j.val
- A = S.val
- dA = S.dval
- @inbounds if i == j
- dA.data[i, j] += dret.val
- elseif (A.uplo == 'U') == (i < j)
- dA.data[i, j] += dret.val / 2
- else
- dA.data[j, i] += dret.val / 2
- end
-
- return (nothing, nothing, nothing)
- end
+ function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.getindex)}, ::Type{<:Active}, S::Duplicated{<:MySymmetric2}, i::Const, j::Const)
+ # Compute primal
+ if needs_primal(config)
+ primal = func.val(S.val, i.val, j.val)
+ else
+ primal = nothing
+ end
+
+ # Return an AugmentedReturn object with shadow = nothing
+ return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
+ end
+
+ function EnzymeRules.reverse(
+ config, ::Const{typeof(Base.getindex)}, dret::Active, tape,
+ S::Duplicated{<:MySymmetric2}, i::Const, j::Const
+ )
+ i = i.val
+ j = j.val
+ A = S.val
+ dA = S.dval
+ @inbounds if i == j
+ dA.data[i, j] += dret.val
+ elseif (A.uplo == 'U') == (i < j)
+ dA.data[i, j] += dret.val / 2
+ else
+ dA.data[j, i] += dret.val / 2
+ end
+
+ return (nothing, nothing, nothing)
+ end
end
# ╔═╡ fe7138ce-950a-4f79-b176-cdb227d4c898 |
Benchmark Results
Benchmark PlotsA plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/17343787618/artifacts/3889340236. |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #2545 +/- ##
=======================================
Coverage 74.92% 74.92%
=======================================
Files 56 56
Lines 17428 17428
=======================================
Hits 13058 13058
Misses 4370 4370 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
dA = S.dval | ||
@inbounds if i == j | ||
dA.data[i, j] += dret.val | ||
elseif (A.uplo == 'U') == (i < j) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure about the correctness of this generally
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, which is why I haven't pushed this as a general rule for Symmetric
, this notebook is more meant to explain to users what is happening.
No description provided.