Skip to content

Commit 36b0449

Browse files
authored
Dictionary differentials (#183)
* Dictionary differentials * Removes unnecessary space * Tweaks comment * Update docs Manifest * Bump patch
1 parent be20cb0 commit 36b0449

File tree

5 files changed

+40
-4
lines changed

5 files changed

+40
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.9.4"
3+
version = "0.9.5"
44

55
[deps]
66
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"

docs/Manifest.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
77
deps = ["MuladdMacro"]
88
path = ".."
99
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
10-
version = "0.7.0"
10+
version = "0.9.3"
1111

1212
[[Dates]]
1313
deps = ["Printf"]
@@ -40,6 +40,7 @@ uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
4040
version = "0.21.0"
4141

4242
[[LibGit2]]
43+
deps = ["Printf"]
4344
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
4445

4546
[[Libdl]]
@@ -67,7 +68,7 @@ uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
6768
version = "0.3.12"
6869

6970
[[Pkg]]
70-
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Test", "UUIDs"]
71+
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
7172
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
7273

7374
[[Printf]]

src/differential_arithmetic.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ function Base.:+(a::P, d::Composite{P}) where P
115115
return construct(P, net_backing)
116116
end
117117
end
118+
Base.:+(a::Dict, d::Composite{P}) where {P} = merge(+, a, backing(d))
118119
Base.:+(a::Composite{P}, b::P) where P = b + a
119120

120121
# We intentionally do not define, `Base.*(::Composite, ::Composite)` as that is not meaningful

src/differentials/composite.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ To make a `Composite` have all the fields of the primal the [`canonicalize`](@re
2222
function is provided.
2323
"""
2424
struct Composite{P, T} <: AbstractDifferential
25-
# Note: If T is a Tuple, then P is also a Tuple
25+
# Note: If T is a Tuple/Dict, then P is also a Tuple/Dict
2626
# (but potentially a different one, as it doesn't contain differentials)
2727
backing::T
2828
end
@@ -41,6 +41,12 @@ function Composite{P}() where P<:Tuple
4141
return Composite{P, typeof(backing)}(backing)
4242
end
4343

44+
function Composite{P}(d::Dict) where {P<:Dict}
45+
return Composite{P, typeof(d)}(d)
46+
end
47+
48+
Base.:(==)(a::Composite, b::Composite) = backing(a) == backing(b)
49+
4450
function Base.show(io::IO, comp::Composite{P}) where P
4551
print(io, "Composite{")
4652
show(io, P)
@@ -51,6 +57,7 @@ end
5157

5258
Base.convert(::Type{<:NamedTuple}, comp::Composite{<:Any, <:NamedTuple}) = backing(comp)
5359
Base.convert(::Type{<:Tuple}, comp::Composite{<:Any, <:Tuple}) = backing(comp)
60+
Base.convert(::Type{<:Dict}, comp::Composite{<:Dict, <:Dict}) = backing(comp)
5461

5562
Base.getindex(comp::Composite, idx) = getindex(backing(comp), idx)
5663
Base.getproperty(comp::Composite, idx::Int) = getproperty(backing(comp), idx) # for Tuple
@@ -79,6 +86,9 @@ function Base.map(f, comp::Composite{P, <:NamedTuple{L}}) where{P, L}
7986
named_vals = NamedTuple{L, typeof(vals)}(vals)
8087
return Composite{P, typeof(named_vals)}(named_vals)
8188
end
89+
function Base.map(f, comp::Composite{P, <:Dict}) where {P<:Dict}
90+
return Composite{P}(Dict(k => f(v) for (k, v) in backing(comp)))
91+
end
8292

8393
Base.conj(comp::Composite) = map(conj, comp)
8494

@@ -97,6 +107,7 @@ primal types.
97107
"""
98108
backing(x::Tuple) = x
99109
backing(x::NamedTuple) = x
110+
backing(x::Dict) = x
100111
backing(x::Composite) = getfield(x, :backing)
101112

102113
function backing(x::T)::NamedTuple where T
@@ -235,6 +246,7 @@ function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn}
235246
end
236247
end
237248

249+
elementwise_add(a::Dict, b::Dict) = merge(+, a, b)
238250

239251
struct PrimalAdditionFailedException{P} <: Exception
240252
primal::P

test/differentials/composite.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ end
2626
@testset "convert" begin
2727
@test convert(NamedTuple, Composite{Foo}(x=2.5)) == (; x=2.5)
2828
@test convert(Tuple, Composite{Tuple{Float64,}}(2.0)) == (2.0,)
29+
@test convert(Dict, Composite{Dict}(Dict(4 => 3))) == Dict(4 => 3)
2930
end
3031

3132
@testset "indexing, iterating, and properties" begin
@@ -58,15 +59,21 @@ end
5859
conj(Composite{Tuple{Float64,}}(2.0+3.0im)),
5960
Composite{Tuple{Float64,}}(2.0-3.0im)
6061
)
62+
@test ==(
63+
conj(Composite{Dict}(Dict(4 => 2.0 + 3.0im))),
64+
Composite{Dict}(Dict(4 => 2.0 + -3.0im)),
65+
)
6166
end
6267

6368
@testset "extern" begin
6469
@test extern(Composite{Foo}(x=2.0)) == (;x=2.0)
6570
@test extern(Composite{Tuple{Float64,}}(2.0)) == (2.0,)
71+
@test extern(Composite{Dict}(Dict(4 => 3))) == Dict(4 => 3)
6672

6773
# with differentials on the inside
6874
@test extern(Composite{Foo}(x=@thunk(0+2.0))) == (;x=2.0)
6975
@test extern(Composite{Tuple{Float64,}}(@thunk(0+2.0))) == (2.0,)
76+
@test extern(Composite{Dict}(Dict(4 => @thunk(3)))) == Dict(4 => 3)
7077
end
7178

7279
@testset "canonicalize" begin
@@ -110,6 +117,13 @@ end
110117
Composite{typeof(nt2)}(; nt2...)
111118
) == Composite{typeof(nt_sum)}(; nt_sum...)
112119
end
120+
121+
@testset "Dicts" begin
122+
d1 = Composite{Dict}(Dict(4 => 3.0, 3 => 2.0))
123+
d2 = Composite{Dict}(Dict(4 => 3.0, 2 => 2.0))
124+
d_sum = Composite{Dict}(Dict(4 => 3.0 + 3.0, 3 => 2.0, 2 => 2.0))
125+
@test d1 + d2 == d_sum
126+
end
113127
end
114128

115129
@testset "+ with Primals" begin
@@ -133,6 +147,11 @@ end
133147
@test Composite{typeof(nty)}(; nty...) + nty == (; a=3.0, b=1.0)
134148
end
135149

150+
@testset "Dicts" begin
151+
d_primal = Dict(4 => 3.0, 3 => 2.0)
152+
d_tangent = Composite{typeof(d_primal)}(Dict(4 =>5.0))
153+
@test d_primal + d_tangent == Dict(4 => 3.0 + 5.0, 3 => 2.0)
154+
end
136155
end
137156

138157
@testset "+ with Primals, with inner constructor" begin
@@ -189,6 +208,9 @@ end
189208
== Composite{Tuple{Float64, Float64}}(4.0, 8.0)
190209
== Composite{Tuple{Float64, Float64}}(2.0, 4.0) * 2
191210
)
211+
d = Composite{Dict}(Dict(4 => 3.0))
212+
two_d = Composite{Dict}(Dict(4 => 2 * 3.0))
213+
@test 2 * d == two_d == d * 2
192214
end
193215

194216
@testset "show" begin

0 commit comments

Comments
 (0)