-
Notifications
You must be signed in to change notification settings - Fork 36
Fix Zygote support #100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Zygote support #100
Conversation
src/threadsafe.jl
Outdated
################# | ||
# VectorOfLogps # | ||
################# | ||
|
||
struct VectorOfLogps{T1, T2 <: Vector{Base.RefValue{T1}}} | ||
v::T2 | ||
end | ||
VectorOfLogps(::Type{T}, n::Int) where {T} = VectorOfLogps(zero(T), n) | ||
function VectorOfLogps(val::T, n::Int) where {T} | ||
v = [val for i in 1:Threads.nthreads()] | ||
return VectorOfLogps(v) | ||
end | ||
VectorOfLogps(v::Vector) = VectorOfLogps(Ref.(v)) | ||
Base.getindex(v::VectorOfLogps, i::Integer) = v.v[i][] | ||
function Base.setindex!(v::VectorOfLogps, val, i::Integer) | ||
v.v[i][] = val | ||
return v | ||
end | ||
Base.sum(v::VectorOfLogps) = sum(v -> v[], v.v) | ||
function Base.fill!(v::VectorOfLogps, val) | ||
for i in 1:length(v.v) | ||
v.v[i][] = val | ||
end | ||
return v | ||
end | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we should add a VectorOfLogps
. For the change from array to tuple of Refs one only has to change the constructor and the logp methods, in a very simple and trivial way. IMO coming up with all the methods and implementations for VectorLogps
is much more complicated and less obvious.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's nice in that it separate the logic of what we do to logps from that of how we do it. It's not really complicated at all. We can remove some constructors to simplify it if you want.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO it's not a simplification since it's much more than what would be required otherwise. There is no need for a special type since it is only used in the logp methods of ThreadSafeVarInfo
.
I think it would be good to keep this focused on the Zygote changes and update the Turing folder in a separate PR. If something went undetected before it is better to write tests based on DynamicPPL only. |
Why separate them? We need the latest Turing to test DPPL properly. Besides, the the tests update are in a separate commit there is no point separating it to another completely different PR. |
The problem is not that the Turing version is not the most recent one, the problem is that we just don't run the full Turing test suite (and don't want to). Thus, e.g., Zygote was not included but that is already fixed by moving the AD test to DynamicPPL and extending it (without Turing dependencies). Just updating the Turing folder does not fix the issues, and hence should be done separately. |
Actually, your copy of the Turing folder breaks DynamicPPL since DynamicPPL was already updated (and its Turing folder) to support Libtask 0.4 whereas the PR to Turing is not merged yet. So please just stick to the changes of ThreadSafeVarInfo and the adjoints. |
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: David Widmann <[email protected]>
b200f86
to
08dac12
Compare
I'm fine with merging this if |
If you don't want to remove it, I can fix it in the other PR. IMO simplifying the constructor does not remedy the problem that it is still more complicated than just modifying the logp methods and hence does not provide an advantage. |
This PR fixes Zygote support in the branch
fixes_threaded
. It currently does some type piracy to define an adjoint fornthreads
andthreadid
. I will make a PR to Zygote with these now, then they can be removed from here when the PR is merged and released.