Skip to content

Commit 7582999

Browse files
authored
Merge pull request #226 from JuliaDiff/ox/accum
Introduce add!!
2 parents 81106f5 + 899efb7 commit 7582999

File tree

8 files changed

+101
-9
lines changed

8 files changed

+101
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.9.11"
3+
version = "0.9.12"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

docs/Manifest.toml

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
1010
version = "0.5.10"
1111

1212
[[ChainRulesCore]]
13-
deps = ["MuladdMacro"]
13+
deps = ["LinearAlgebra", "MuladdMacro"]
1414
path = ".."
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.9.5"
16+
version = "0.9.12"
1717

1818
[[Dates]]
1919
deps = ["Printf"]
@@ -25,9 +25,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
2525

2626
[[DocStringExtensions]]
2727
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
28-
git-tree-sha1 = "c5714d9bcdba66389612dc4c47ed827c64112997"
28+
git-tree-sha1 = "50ddf44c53698f5e784bbebb3f4b21c5807401b1"
2929
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
30-
version = "0.8.2"
30+
version = "0.8.3"
3131

3232
[[Documenter]]
3333
deps = ["Base64", "Dates", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
@@ -50,9 +50,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
5050

5151
[[JSON]]
5252
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
53-
git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e"
53+
git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4"
5454
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
55-
version = "0.21.0"
55+
version = "0.21.1"
5656

5757
[[LibGit2]]
5858
deps = ["Printf"]
@@ -61,6 +61,10 @@ uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
6161
[[Libdl]]
6262
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
6363

64+
[[LinearAlgebra]]
65+
deps = ["Libdl"]
66+
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
67+
6468
[[Logging]]
6569
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
6670

@@ -91,7 +95,7 @@ deps = ["Unicode"]
9195
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
9296

9397
[[REPL]]
94-
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
98+
deps = ["InteractiveUtils", "Markdown", "Sockets"]
9599
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
96100

97101
[[Random]]

docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ Pages = [
2727
Private = false
2828
```
2929

30+
## Accumulation
31+
```@autodocs
32+
Modules = [ChainRulesCore]
33+
Pages = ["accumulation.jl"]
34+
Private = false
35+
```
36+
3037
## Ruleset Loading
3138
```@autodocs
3239
Modules = [ChainRulesCore]

src/ChainRulesCore.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ export on_new_rule, refresh_rules # generation tools
77
export frule, rrule # core function
88
export @non_differentiable, @scalar_rule, @thunk # definition helper macros
99
export canonicalize, extern, unthunk # differential operations
10+
export add!! # gradient accumulation operations
1011
# differentials
1112
export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero, AbstractZero, AbstractThunk
1213
export NO_FIELDS
@@ -21,6 +22,7 @@ include("differentials/thunks.jl")
2122
include("differentials/composite.jl")
2223

2324
include("differential_arithmetic.jl")
25+
include("accumulation.jl")
2426

2527
include("rules.jl")
2628
include("rule_definition_tools.jl")

src/accumulation.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
add!!(x, y)
3+
4+
Returns `x+y`, potentially mutating `x` in-place to hold this value.
5+
This avoids allocations when `x` can be mutated in this way.
6+
7+
See also: [`InplaceableThunk`](@ref).
8+
"""
9+
add!!(x, y) = x + y
10+
11+
add!!(x, t::InplaceableThunk) = t.add!(x)
12+
13+
function add!!(x::Array{<:Any, N}, y::AbstractArray{<:Any, N}) where N
14+
return x .+= y
15+
end

test/accumulation.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
@testset "accumulation.jl" begin
2+
@testset "scalar" begin
3+
@test 16 == add!!(12, 4)
4+
end
5+
6+
@testset "Differentials" begin
7+
@test 16 == add!!(12, @thunk(2*2))
8+
@test 16 == add!!(16, Zero())
9+
10+
@test 16 == add!!(16, DoesNotExist()) # Should this be an error?
11+
end
12+
13+
@testset "Array" begin
14+
@testset "Happy Path" begin
15+
@testset "RHS Array" begin
16+
A = [1.0 2.0; 3.0 4.0]
17+
result = -1.0*ones(2,2)
18+
ret = add!!(result, A)
19+
@test ret === result # must be same object
20+
@test result == [0.0 1.0; 2.0 3.0]
21+
end
22+
23+
@testset "RHS StaticArray" begin
24+
A = @SMatrix[1.0 2.0; 3.0 4.0]
25+
result = -1.0*ones(2,2)
26+
ret = add!!(result, A)
27+
@test ret === result # must be same object
28+
@test result == [0.0 1.0; 2.0 3.0]
29+
end
30+
31+
@testset "RHS Diagonal" begin
32+
A = Diagonal([1.0, 2.0])
33+
result = -1.0*ones(2,2)
34+
ret = add!!(result, A)
35+
@test ret === result # must be same object
36+
@test result == [0.0 -1.0; -1.0 1.0]
37+
end
38+
end
39+
40+
@testset "Unhappy Path" begin
41+
# wrong length
42+
@test_throws DimensionMismatch add!!(ones(4,4), ones(2,2))
43+
# wrong shape
44+
@test_throws DimensionMismatch add!!(ones(4,4), ones(16))
45+
# wrong type (adding scalar to array)
46+
@test_throws MethodError add!!(ones(4), 21.0)
47+
end
48+
end
49+
50+
@testset "InplaceableThunk" begin
51+
A=[1.0 2.0; 3.0 4.0]
52+
ithunk = InplaceableThunk(
53+
@thunk(A*B),
54+
x -> x.+=A
55+
)
56+
57+
result = -1.0*ones(2,2)
58+
ret = add!!(result, ithunk)
59+
@test ret === result # must be same object
60+
@test result == [0.0 1.0; 2.0 3.0]
61+
end
62+
end

test/rules.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#######
22
# Demo setup
3-
using StaticArrays: @SVector
43

54
cool(x) = x + 1
65
cool(x, y) = x + y + 1

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using BenchmarkTools
33
using ChainRulesCore
44
using LinearAlgebra: Diagonal, dot
55
using FiniteDifferences
6+
using StaticArrays
67
using Test
78

89
@testset "ChainRulesCore" begin
@@ -13,6 +14,8 @@ using Test
1314
include("differentials/composite.jl")
1415
end
1516

17+
include("accumulation.jl")
18+
1619
include("ruleset_loading.jl")
1720
include("rules.jl")
1821
include("rule_definition_tools.jl")

0 commit comments

Comments
 (0)