Skip to content

Commit 76f2e85

Browse files
authored
Merge pull request #182 from JuliaDiff/ox/decoratormacros
Allow AD systems to register hooks so they can create new overloads in response to new rules
2 parents 67996f2 + eccb894 commit 76f2e85

15 files changed

+579
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.9.5"
3+
version = "0.9.6"
44

55
[deps]
66
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"

docs/Manifest.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ version = "0.8.2"
3131

3232
[[Documenter]]
3333
deps = ["Base64", "Dates", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
34-
git-tree-sha1 = "1c593d1efa27437ed9dd365d1143c594b563e138"
34+
git-tree-sha1 = "fb1ff838470573adc15c71ba79f8d31328f035da"
3535
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
36-
version = "0.25.1"
36+
version = "0.25.2"
3737

3838
[[DocumenterTools]]
3939
deps = ["Base64", "DocStringExtensions", "Documenter", "FileWatching", "LibGit2", "Sass"]
@@ -78,9 +78,9 @@ version = "0.2.2"
7878

7979
[[Parsers]]
8080
deps = ["Dates", "Test"]
81-
git-tree-sha1 = "10134f2ee0b1978ae7752c41306e131a684e1f06"
81+
git-tree-sha1 = "8077624b3c450b15c087944363606a6ba12f925e"
8282
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
83-
version = "1.0.7"
83+
version = "1.0.10"
8484

8585
[[Pkg]]
8686
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@@ -91,7 +91,7 @@ deps = ["Unicode"]
9191
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
9292

9393
[[REPL]]
94-
deps = ["InteractiveUtils", "Markdown", "Sockets"]
94+
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
9595
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
9696

9797
[[Random]]

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
33
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
DocumenterTools = "35a29f4d-8980-5a13-9543-d66fff28ecb8"
5+
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
56

67
[compat]
78
Documenter = "0.25"

docs/make.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using ChainRulesCore
22
using Documenter
33
using DocumenterTools: Themes
4+
using Markdown
45

56
DocMeta.setdocmeta!(
67
ChainRulesCore,
@@ -36,6 +37,10 @@ makedocs(
3637
"Complex Numbers" => "complex.md",
3738
"Deriving Array Rules" => "arrays.md",
3839
"Debug Mode" => "debug_mode.md",
40+
"Usage in AD" => [
41+
"Overview" => "autodiff/overview.md",
42+
"Operator Overloading" => "autodiff/operator_overloading.md"
43+
],
3944
"Design" => [
4045
"Many Differential Types" => "design/many_differentials.md",
4146
],

docs/src/api.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,16 @@ Pages = [
2727
Private = false
2828
```
2929

30+
## Ruleset Loading
31+
```@autodocs
32+
Modules = [ChainRulesCore]
33+
Pages = ["ruleset_loading.jl"]
34+
Private = false
35+
```
36+
3037
## Internal
3138
```@docs
3239
ChainRulesCore.AbstractDifferential
3340
ChainRulesCore.debug_mode
41+
ChainRulesCore.clear_new_rule_hooks!
3442
```
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Operator Overloading
2+
3+
The principal interface for using the operator overload generation method is [`on_new_rule`](@ref).
4+
This function allows one to register a hook to be run every time a new rule is defined.
5+
The hook receives a signature type-type as input, and generally will use `eval` to define
6+
an overload of an AD system's overloaded type.
7+
For example, using the signature type `Tuple{typeof(+), Real, Real}` to make
8+
`+(::DualNumber, ::DualNumber)` call the `frule` for `+`.
9+
A signature type tuple always has the form:
10+
`Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2}, ...}`, where `pos_arg1` is the
11+
first positional argument.
12+
One can dispatch on the signature type to make rules with argument types your AD does not support not call `eval`;
13+
or more simply you can just use conditions for this.
14+
For example if your AD only supports `AbstractMatrix{Float64}` and `Float64` inputs you might write:
15+
```julia
16+
const ACCEPT_TYPE = Union{Float64, AbstractMatrix{Float64}}
17+
function define_overload(sig::Type{<:Tuple{F, Vararg{ACCEPT_TYPE}}) where F
18+
@eval quote
19+
# ...
20+
end
21+
end
22+
define_overload(::Any) = nothing # don't do anything for any other signature
23+
24+
on_new_rule(frule, define_overload)
25+
```
26+
27+
or you might write:
28+
```julia
29+
const ACCEPT_TYPES = (Float64, AbstractMatrix{Float64})
30+
function define_overload(sig) where F
31+
sig = Base.unwrap_unionall(sig) # not really handling most UnionAll,
32+
opT, argTs = Iterators.peel(sig.parameters)
33+
all(any(acceptT<: argT for acceptT in ACCEPT_TYPES) for argT in argTs) || return
34+
@eval quote
35+
# ...
36+
end
37+
end
38+
39+
on_new_rule(frule, define_overload)
40+
```
41+
42+
The generation of overloaded code is the responsibility of the AD implementor.
43+
Packages like [ExprTools.jl](https://github.com/invenia/ExprTools.jl) can be helpful for this.
44+
Its generally fairly simple, though can become complex if you need to handle complicated type-constraints.
45+
Examples are shown below.
46+
47+
The hook is automatically triggered whenever a package is loaded.
48+
It can also be triggers manually using `refresh_rules`(@ref).
49+
This is useful for example if new rules are define in the REPL, or if a package defining rules is modified.
50+
(Revise.jl will not automatically trigger).
51+
When the rules are refreshed (automatically or manually), the hooks are only triggered on new/modified rules; not ones that have already had the hooks triggered on.
52+
53+
`clear_new_rule_hooks!`(@ref) clears all registered hooks.
54+
It is useful to undo [`on_new_rule`] hook registration if you are iteratively developing your overload generation function.
55+
56+
## Examples
57+
58+
### ForwardDiffZero
59+
The overload generation hook in this example is: `define_dual_overload`.
60+
61+
````@eval
62+
using Markdown
63+
Markdown.parse("""
64+
```julia
65+
$(read(joinpath(@__DIR__,"../../../test/demos/forwarddiffzero.jl"), String))
66+
```
67+
""")
68+
````
69+
70+
### ReverseDiffZero
71+
The overload generation hook in this example is: `define_tracked_overload`.
72+
73+
````@eval
74+
using Markdown
75+
Markdown.parse("""
76+
```julia
77+
$(read(joinpath(@__DIR__,"../../../test/demos/reversediffzero.jl"), String))
78+
```
79+
""")
80+
````

docs/src/autodiff/overview.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Using ChainRules in your AD system
2+
3+
This section is for authors of AD systems.
4+
It assumes a pretty solid understanding of both Julia and automatic differentiation.
5+
It explains how to make use of ChainRule's "rulesets" ([`frule`](@ref)s, [`rrule`](@ref)s,)
6+
to avoid having to code all your own AD primitives / custom sensitives.
7+
8+
There are 3 main ways to access ChainRules rule sets in your AutoDiff system.
9+
10+
1. [Operation Overloading Generation](operator_overloading.html)
11+
- This is primarily intended for operator overloading based AD systems which will generate overloads for primal functions based for their overloaded types based on the existence of an `rrule`/`frule`.
12+
- A source code generation based AD can also use this by overloading their transform generating function directly so as not to recursively generate a transform but to just return the rule.
13+
- This does not play nice with Revise.jl, adding or modifying rules in loaded files will not be reflected until a manual refresh, and deleting rules will not be reflected at all.
14+
2. Source code tranform based on inserting branches that check of `rrule`/`frule` return `nothing`
15+
- If the `rrule`/`frule` returns a rule result then use it, if it returns `nothing` then do normal AD path.
16+
- In theory type inference optimizes these branchs out; in practice it may not.
17+
- This is a fairly simple Cassette overdub (or similar) of all calls, and is suitable for overloading based AD or source code transformation.
18+
3. Source code transform based on `rrule`/`frule` method-table
19+
- If an applicable `rrule`/`frule` exists in the method table then use it, else generate normal AD path.
20+
- This avoids having branches in your generated code.
21+
- This requires maintaining your own back-edges.
22+
- This is pretty hardcore even by the standard of source code tranformations.

src/ChainRulesCore.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
module ChainRulesCore
22
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
3+
using MuladdMacro: @muladd
34

4-
export frule, rrule
5-
export @scalar_rule, @thunk
6-
export canonicalize, extern, unthunk
5+
export on_new_rule, refresh_rules # generation tools
6+
export frule, rrule # core function
7+
export @scalar_rule, @thunk # definition helper macros
8+
export canonicalize, extern, unthunk # differential operations
9+
# differentials
710
export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero, AbstractZero, AbstractThunk
811
export NO_FIELDS
912

@@ -20,5 +23,6 @@ include("differential_arithmetic.jl")
2023

2124
include("rules.jl")
2225
include("rule_definition_tools.jl")
26+
include("ruleset_loading.jl")
2327

2428
end # module

src/rule_definition_tools.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
# These are some macros (and supporting functions) to make it easier to define rules.
2-
using MuladdMacro: @muladd
3-
42
"""
53
@scalar_rule(f(x₁, x₂, ...),
64
@setup(statement₁, statement₂, ...),

src/rules.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
#####
2-
##### `frule`/`rrule`
3-
#####
4-
51
"""
62
frule((Δf, Δx...), f, x...)
73

0 commit comments

Comments
 (0)