Description
Hi All,
I noticed during development that StructArrays and Zygote do not seem to work together. If a function accesses a property of the StructArray (for example x.U), Zygote/ChainRules does not keep the gradient as a StructArray, and this causes an error during gradient accumulation. An MWE is
using Zygote
using StructArrays
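f(p) = p.U^2 + p.V^2                 # scalar function of one element
l1(x) = sum(f, x) + sum(x.U)         # elementwise sum plus a property access
l2(x) = sum(f.(x) + x.U)             # broadcast f, then add the U column
l3(x) = sum(f.(x) .+ x.U)            # same as l2, with a fused broadcast
x = StructArray{NamedTuple{(:U,:V)}}((U=rand(10), V=rand(10)))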
Zygote.gradient(l1, x)
ERROR: MethodError: no method matching +(::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, ::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
+(::ChainRulesCore.AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_arithmetic.jl:122
+(::Array, ::Array...) at arraymath.jl:12
...
Stacktrace:
[1] accum(x::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, y::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:17
[2] collect_similar
@ ./array.jl:716 [inlined]
[3] map
@ ./abstractarray.jl:2933 [inlined]
[4] wrap_chainrules_output
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:122 [inlined]
[5] map
@ ./tuple.jl:223 [inlined]
[6] wrap_chainrules_output
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:106 [inlined]
[7] ZBack
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:206 [inlined]
[8] Pullback
@ ~/struct_array_issue.jl:5 [inlined]
[9] (::typeof(∂(l0)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[10] (::Zygote.var"#60#61"{typeof(∂(l0))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
[11] gradient(f::Function, args::StructVector{NamedTuple{(:U, :V)}, NamedTuple{(:U, :V), Tuple{Vector{Float64}, Vector{Float64}}}, Int64})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
[12] top-level scope
@ REPL[50]:1
#####################################################################
Zygote.gradient(l2, x)
ERROR: MethodError: no method matching +(::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, ::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
+(::ChainRulesCore.AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_arithmetic.jl:122
+(::Array, ::Array...) at arraymath.jl:12
...
Stacktrace:
[1] accum(x::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, y::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:17
[2] Pullback
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:43 [inlined]
[3] Pullback
@ ~/struct_array_issue.jl:6 [inlined]
[4] (::typeof(∂(l1)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[5] (::Zygote.var"#60#61"{typeof(∂(l1))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
[6] gradient(f::Function, args::StructVector{NamedTuple{(:U, :V)}, NamedTuple{(:U, :V), Tuple{Vector{Float64}, Vector{Float64}}}, Int64})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
[7] top-level scope
@ REPL[51]:1
##########################################################################
Zygote.gradient(l3, x)
ERROR: MethodError: no method matching +(::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, ::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
+(::ChainRulesCore.AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_arithmetic.jl:122
+(::Array, ::Array...) at arraymath.jl:12
...
Stacktrace:
[1] accum(x::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, y::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:17
[2] Pullback
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:43 [inlined]
[3] Pullback
@ ~/struct_array_issue.jl:7 [inlined]
[4] (::typeof(∂(l2)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[5] (::Zygote.var"#60#61"{typeof(∂(l2))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
[6] gradient(f::Function, args::StructVector{NamedTuple{(:U, :V)}, NamedTuple{(:U, :V), Tuple{Vector{Float64}, Vector{Float64}}}, Int64})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
[7] top-level scope
@ REPL[52]:1
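All three traces end in the same frame, accum in Zygote's lib.jl, which is asked to add a NamedTuple with a components field to a Vector of NamedTuples. The sketch below reproduces only that last step with hand-built stand-ins for the two tangents. The values are made up, a plain Vector replaces the FillArrays.Fill from the error message, and Zygote.accum is an internal function, so the assumption that its generic fallback reduces to + is inferred from the stack trace rather than from documented API.

```julia
using Zygote

# Stand-in for the structural tangent that presumably comes from the property access x.U
# (the real one holds a FillArrays.Fill for U; a plain Vector is used here for simplicity).
structural = (components = (U = ones(3), V = nothing),)

# Stand-in for the array-of-NamedTuples tangent that presumably comes from sum(f, x).
elementwise = [(U = 1.0, V = 2.0) for _ in 1:3]

# Zygote.accum is internal; with no method for this pair of types it is expected
# to fall back to `structural + elementwise`, for which no `+` method exists.
err = try
    Zygote.accum(structural, elementwise)
    nothing
catch e
    e
end

println(typeof(err))
```

If this throws as expected, the problem is not in the user functions themselves but in the fact that the two gradient contributions for x use different representations that Zygote cannot add.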
On the other hand, the version without the extra property access,
l0(x) = sum(f, x)
Zygote.gradient(l0, x)
seems to work fine and returns a StructArray.
I have been playing with ChainRulesCore and ProjectTo to see if I could get this to work, but I am not sure of the best way to store everything internally.
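As a possible stopgap, and not a fix for the StructArray case, the same loss can be written against the component vectors directly, so that only plain arrays are differentiated. The sketch below assumes that restructuring the loss this way is acceptable. The name loss_components is hypothetical, and the computation mirrors l1 from the MWE.

```julia
using Zygote

U, V = rand(10), rand(10)

# Same value as l1(x) = sum(f, x) + sum(x.U), written on the component vectors
# (hypothetical helper name; no StructArray is involved).
loss_components(U, V) = sum(U .^ 2 .+ V .^ 2) + sum(U)

gU, gV = Zygote.gradient(loss_components, U, V)

# Analytic gradients for comparison: d/dU = 2U + 1, d/dV = 2V
println(gU ≈ 2 .* U .+ 1)
println(gV ≈ 2 .* V)
```

The gradients come back as two separate arrays, so they would have to be reassembled into a StructArray by hand if that layout is needed downstream.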
Working environment
julia> Pkg.status()
Status /tmp/jl_sssQXD/Project.toml
[09ab397b] StructArrays v0.6.13 https://github.com/JuliaArrays/StructArrays.jl.git#master
[e88e6eb3] Zygote v0.6.51
julia> versioninfo()
Julia Version 1.8.3
Commit 0434deb161e (2022-11-14 20:14 UTC)
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 32 × AMD Ryzen 9 7950X 16-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-13.0.1 (ORCJIT, znver3)
Threads: 1 on 32 virtual cores
Environment:
JULIA_EDITOR = code
JULIA_NUM_THREADS = 1