Fix stale transform copy-chain leaks #3290
Merged
angeloskath merged 1 commit into ml-explore:main on Mar 24, 2026
Conversation
zcbenz requested changes on Mar 22, 2026
Force-pushed 5fdb196 to 314a742
zcbenz (Collaborator) approved these changes on Mar 23, 2026 and left a comment:
I wonder if we should make it a new primitive, but I think current approach is good enough. 👍
Fixes ml-explore#2841

vjp() and jvp() wrap each primal in copy(p, s) to create tracers. When user code stores a tracer into an external container and feeds it back as a primal in the next call, each iteration nests another Copy:

call 1: container[0] = copy(original)
call 2: container[0] = copy(copy(original))
call N: container[0] = copy^N(original)

The container keeps the head alive, which transitively keeps every intermediate Copy node alive: linear memory growth per call.

Fix: before creating the new tracer copy, peel off one stale Copy wrapper (a non-tracer Copy primitive with inputs). Active tracers (is_tracer() = true, from nested transforms) are never unwrapped, preserving nested-transform semantics. Copy's VJP is the identity, so flattening is gradient-safe.
Force-pushed 314a742 to 3e253ce
angeloskath (Member) approved these changes on Mar 24, 2026 and left a comment:
This looks good thanks.
However, it is mostly a convenience for the user imho, because the pattern that exhibits the issue has a bug: it doesn't evaluate the container. If it did, there would be no copy-chain.
Just to make it clear why the problem is not the copy. The following code will "leak" even with the "fix".
```cpp
auto grad_fn = grad([&container](const std::vector<array>& inputs) {
  container[0] = 2 * inputs[0];
  return sum(inputs[1]);
});
```

container[0] will no longer be just a copy but 2x of a copy, so we can't extract it from the input.
Another way of seeing this is that just calling the forward pass in a loop will end up building huge chains of arrays if you never evaluate anything.
vjp() and jvp() wrap each primal in copy(p, s) to create tracers. When user code stores a tracer into an external container and feeds it back as a primal in the next call, each iteration nests another Copy. The container keeps the head alive, which transitively keeps every intermediate Copy node alive: linear memory growth per call.

Fix: before creating the new tracer copy, unwrap_stale_copy_wrappers() peels off Copy nodes that have is_tracer() = false, collapsing the chain to depth 1. Active tracers (is_tracer() = true, from nested transforms) are never unwrapped, preserving nested-transform semantics. Copy's VJP is the identity, so flattening is gradient-safe.