Closed
Description
Currently, evaluating rules can be pretty slow (and in some cases extremely slow) on the first call. It seems to be due to JIT compiling the anonymous functions, the use of which is a central design point of ChainRules.
Here's a basic example, where it takes 2.6 seconds(!!!) to evaluate the derivative for svd. Note that the ChainRules version is a simple port of the Nabla version, so the underlying code that does the computation is nearly identical.
julia> using ChainRules, LinearAlgebra
julia> F, dX = rrule(svd, randn(4, 4));
julia> nt = (U=F.U, S=F.S, V=F.V);
julia> @time dX(nt)
2.644154 seconds (11.32 M allocations: 589.599 MiB, 6.32% gc time)
4×4 Array{Float64,2}:
0.189976 0.714807 0.456714 2.18944
-0.313175 -0.760914 -0.0347824 0.465126
-0.53937 1.4229 -1.20063 -1.46798
-0.517822 1.03534 -1.66866 0.263543
julia> @time dX(nt)
0.017788 seconds (2.15 k allocations: 146.201 KiB)
4×4 Array{Float64,2}:
0.189976 0.714807 0.456714 2.18944
-0.313175 -0.760914 -0.0347824 0.465126
-0.53937 1.4229 -1.20063 -1.46798
-0.517822 1.03534 -1.66866 0.263543
Compare this to Nabla:
julia> using Nabla, LinearAlgebra
julia> X = randn(4, 4); F = svd(X); nt = (U=F.U, S=F.S, V=F.V);
julia> @time ∇(svd, Arg{1}, (), F, nt, X)
0.631797 seconds (2.37 M allocations: 114.680 MiB, 3.77% gc time)
4×4 Array{Float64,2}:
-1.09959 1.16123 2.27612 -0.97398
1.42299 -0.17832 -2.52324 1.48229
-3.32982 -0.746237 1.1226 2.98457
0.346265 0.402947 -0.0502055 0.661208
julia> @time dX(nt)
0.006206 seconds (2.15 k allocations: 146.201 KiB)
4×4 Array{Float64,2}:
0.189976 0.714807 0.456714 2.18944
-0.313175 -0.760914 -0.0347824 0.465126
-0.53937 1.4229 -1.20063 -1.46798
-0.517822 1.03534 -1.66866 0.263543
We should find some way(s) to mitigate this so that AD systems which switch to using ChainRules underneath won't take an enormous performance hit by doing so.
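To make the first-call cost easier to isolate from the svd computation itself, here is a minimal sketch that uses only Base Julia. The function `toy_rule` is a hypothetical stand-in, not part of ChainRules or Nabla; it only mimics the pattern of a rule that returns its pullback as an anonymous function, under the assumption that this pattern is what triggers the extra compilation.

```julia
# Hypothetical toy rule for y = x * x'. It returns the primal result together
# with a pullback closure, imitating the closure-returning design described above.
function toy_rule(x::Matrix{Float64})
    y = x * x'
    pullback = ȳ -> (ȳ + ȳ') * x   # anonymous function capturing x
    return y, pullback
end

x = randn(4, 4)
ȳ = ones(4, 4)
_, pb = toy_rule(x)

# The first call pays for JIT compilation of the closure for this argument type.
t_first = @elapsed pb(ȳ)
# The second call reuses the already compiled code.
t_second = @elapsed pb(ȳ)

println("first call:  ", t_first, " s")
println("second call: ", t_second, " s")
```

Run in a fresh session, the gap between the two timings gives a rough estimate of the compilation overhead for one closure. Actual numbers depend on the Julia version and machine, and a toy closure this small will show a much smaller gap than the svd rule.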