Skip to content

Conversation

vchuravy
Copy link
Member

No description provided.

Copy link
Contributor

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/docs/src/notebooks/structured.jl b/docs/src/notebooks/structured.jl
index 359b215..f21e9c7 100644
--- a/docs/src/notebooks/structured.jl
+++ b/docs/src/notebooks/structured.jl
@@ -28,19 +28,19 @@ import EnzymeCore: EnzymeRules
 
 # ╔═╡ 4aa3740e-211c-4fff-9707-b8731a3fa57f
 begin
-	struct MySymmetric{T,S<:AbstractMatrix{<:T}} <: AbstractMatrix{T}
-	    data::S
-	    uplo::Char
-	
-	    function MySymmetric{T,S}(data, uplo::Char) where {T,S<:AbstractMatrix{<:T}}
-	        LinearAlgebra.require_one_based_indexing(data)
-	        (uplo != 'U' && uplo != 'L') && LinearAlgebra.throw_uplo()
-	        new{T,S}(data, uplo)
-	    end
-	end
-	function MySymmetric(A, uplo='U')
-		 MySymmetric{eltype(A), typeof(A)}(A, 'U')
-	end
+    struct MySymmetric{T, S <: AbstractMatrix{<:T}} <: AbstractMatrix{T}
+        data::S
+        uplo::Char
+
+        function MySymmetric{T, S}(data, uplo::Char) where {T, S <: AbstractMatrix{<:T}}
+            LinearAlgebra.require_one_based_indexing(data)
+            (uplo != 'U' && uplo != 'L') && LinearAlgebra.throw_uplo()
+            return new{T, S}(data, uplo)
+        end
+    end
+    function MySymmetric(A, uplo = 'U')
+        return MySymmetric{eltype(A), typeof(A)}(A, 'U')
+    end
 end
 
 # ╔═╡ d6572143-fc11-4fdc-9c23-34867b33ad85
@@ -57,13 +57,15 @@ end
 
 # ╔═╡ 8f6588f2-754b-435c-9998-38da8f6b14ad
 begin
-	Base.size(A::MySymmetric) = size(A.data)
-	Base.length(A::MySymmetric) = length(A.data)
+    Base.size(A::MySymmetric) = size(A.data)
+    Base.length(A::MySymmetric) = length(A.data)
 end
 
 # ╔═╡ de04ba90-a397-4b75-b6b4-977d5881e848
-x = [1.0 0.0
-	 0.0 1.0]
+x = [
+    1.0 0.0
+    0.0 1.0
+]
 
 # ╔═╡ 9beada1d-6281-4dee-9049-8a66af1199a4
 norm(x)
@@ -72,11 +74,13 @@ norm(x)
 Enzyme.gradient(Reverse, norm, x) |> only
 
 # ╔═╡ 4be9a16a-c14b-4378-aa02-fc4bfa783d10
-Enzyme.gradient(Reverse, norm, MySymmetric(x))|> only
+Enzyme.gradient(Reverse, norm, MySymmetric(x)) |> only
 
 # ╔═╡ de7e14eb-4e55-4661-8614-e8026d09e6d3
-x2 = [0.0 1.0
-	  1.0 0.0]
+x2 = [
+    0.0 1.0
+    1.0 0.0
+]
 
 # ╔═╡ 999fb61b-20c6-4dcf-ad34-eca257bfda9f
 d_x2 = Enzyme.gradient(Reverse, norm, x2) |> only
@@ -91,26 +95,26 @@ d_x2 == d_x2_sym
 sum(d_x2) == sum(d_x2_sym.data)
 
 # ╔═╡ 827485bf-0973-4650-a969-6225f72e5d6a
- Symmetric(x2) |> dump
+Symmetric(x2) |> dump
 
 # ╔═╡ dbe34880-93bf-4a5d-b28b-5e6b76267742
- d_x2_sym |> dump
+d_x2_sym |> dump
 
 # ╔═╡ 0122a4df-75d9-444e-8d83-d7a93b6dfeb5
 begin
-	struct MySymmetric2{T,S<:AbstractMatrix{<:T}} <: AbstractMatrix{T}
-	    data::S
-	    uplo::Char
-	
-	    function MySymmetric2{T,S}(data, uplo::Char) where {T,S<:AbstractMatrix{<:T}}
-	        LinearAlgebra.require_one_based_indexing(data)
-	        (uplo != 'U' && uplo != 'L') && LinearAlgebra.throw_uplo()
-	        new{T,S}(data, uplo)
-	    end
-	end
-	function MySymmetric2(A, uplo='U')
-		 MySymmetric2{eltype(A), typeof(A)}(A, 'U')
-	end
+    struct MySymmetric2{T, S <: AbstractMatrix{<:T}} <: AbstractMatrix{T}
+        data::S
+        uplo::Char
+
+        function MySymmetric2{T, S}(data, uplo::Char) where {T, S <: AbstractMatrix{<:T}}
+            LinearAlgebra.require_one_based_indexing(data)
+            (uplo != 'U' && uplo != 'L') && LinearAlgebra.throw_uplo()
+            return new{T, S}(data, uplo)
+        end
+    end
+    function MySymmetric2(A, uplo = 'U')
+        return MySymmetric2{eltype(A), typeof(A)}(A, 'U')
+    end
 end
 
 # ╔═╡ 8497709d-d123-48bc-a86e-5f58aa1b0ebc
@@ -149,8 +153,8 @@ md"""
 
 # ╔═╡ d0a031e4-99a4-417b-8f57-58a67219fa23
 begin
-	Base.size(A::MySymmetric2) = size(A.data)
-	Base.length(A::MySymmetric2) = length(A.data)
+    Base.size(A::MySymmetric2) = size(A.data)
+    Base.length(A::MySymmetric2) = length(A.data)
 end
 
 # ╔═╡ d047c162-446f-4de4-b2fd-5f2550f0ad78
@@ -160,34 +164,36 @@ Now we can implement a rule where we adjust the gradient contribution to be half
 
 # ╔═╡ e81ac66b-75ba-4e88-8e9c-491a60a671dc
 begin
-	function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.getindex)}, ::Type{<:Active}, S::Duplicated{<:MySymmetric2}, i::Const, j::Const)
-	    # Compute primal
-	    if needs_primal(config)
-	        primal = func.val(S.val, i.val, j.val)
-	    else
-	        primal = nothing
-	    end
-	
-	    # Return an AugmentedReturn object with shadow = nothing
-	    return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
-	end
-
-	function EnzymeRules.reverse(config, ::Const{typeof(Base.getindex)}, dret::Active, tape,
-	                 S::Duplicated{<:MySymmetric2}, i::Const, j::Const)
-		i = i.val
-		j = j.val
-		A = S.val
-		dA = S.dval
-		@inbounds if i == j
-        	dA.data[i, j] += dret.val
-    	elseif (A.uplo == 'U') == (i < j)
-        	dA.data[i, j] += dret.val / 2
-    	else
-	        dA.data[j, i] += dret.val / 2
-    	end
-		
-	    return (nothing, nothing, nothing)
-	end
+    function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.getindex)}, ::Type{<:Active}, S::Duplicated{<:MySymmetric2}, i::Const, j::Const)
+        # Compute primal
+        if needs_primal(config)
+            primal = func.val(S.val, i.val, j.val)
+        else
+            primal = nothing
+        end
+
+        # Return an AugmentedReturn object with shadow = nothing
+        return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
+    end
+
+    function EnzymeRules.reverse(
+            config, ::Const{typeof(Base.getindex)}, dret::Active, tape,
+            S::Duplicated{<:MySymmetric2}, i::Const, j::Const
+        )
+        i = i.val
+        j = j.val
+        A = S.val
+        dA = S.dval
+        @inbounds if i == j
+            dA.data[i, j] += dret.val
+        elseif (A.uplo == 'U') == (i < j)
+            dA.data[i, j] += dret.val / 2
+        else
+            dA.data[j, i] += dret.val / 2
+        end
+
+        return (nothing, nothing, nothing)
+    end
 end
 
 # ╔═╡ fe7138ce-950a-4f79-b176-cdb227d4c898

Copy link
Contributor

Benchmark Results

main 71667d5... main / 71667d5...
basics/overhead 4.64 ± 0.01 ns 4.34 ± 0.01 ns 1.07 ± 0.0034
time_to_load 1.26 ± 0.0021 s 1.26 ± 0.022 s 1 ± 0.018

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/17343787618/artifacts/3889340236.

Copy link

codecov bot commented Aug 30, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 74.92%. Comparing base (e31a84a) to head (71667d5).

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2545   +/-   ##
=======================================
  Coverage   74.92%   74.92%           
=======================================
  Files          56       56           
  Lines       17428    17428           
=======================================
  Hits        13058    13058           
  Misses       4370     4370           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

dA = S.dval
@inbounds if i == j
dA.data[i, j] += dret.val
elseif (A.uplo == 'U') == (i < j)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about the correctness of this generally

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, which is why I haven't pushed this as a general rule for Symmetric, this notebook is more meant to explain to users what is happening.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants