
Fix stale transform copy-chain leaks#3290

Merged
angeloskath merged 1 commit into ml-explore:main from mlx-node:fix/transform-stale-wrapper-leak
Mar 24, 2026

Conversation

@Brooooooklyn (Contributor)

vjp() and jvp() wrap each primal in copy(p, s) to create tracers. When user code stores a tracer into an external container and feeds it back as a primal in the next call, each iteration nests another Copy:

  call 1: container[0] = copy(original)
  call 2: container[0] = copy(copy(original))
  call N: container[0] = copy^N(original)

The container keeps the head alive, which transitively keeps every intermediate Copy node alive — linear memory growth per call.

Fix: before creating the new tracer copy, unwrap_stale_copy_wrappers() peels off Copy nodes that have is_tracer()=false, collapsing the chain to depth 1. Active tracers (is_tracer()=true, from nested transforms) are never unwrapped, preserving nested transform semantics. Copy's VJP is identity so flattening is gradient-safe.

@Brooooooklyn Brooooooklyn force-pushed the fix/transform-stale-wrapper-leak branch from 5fdb196 to 314a742 on March 23, 2026 at 12:45
@Brooooooklyn Brooooooklyn requested a review from zcbenz March 23, 2026 12:55
@zcbenz (Collaborator) left a comment


I wonder if we should make it a new primitive, but I think the current approach is good enough. 👍

@zcbenz zcbenz requested a review from angeloskath March 23, 2026 22:06
Fixes ml-explore#2841


Fix: before creating the new tracer copy, peel off one stale Copy
wrapper (non-tracer Copy primitive with inputs). Active tracers
(is_tracer()=true, from nested transforms) are never unwrapped,
preserving nested transform semantics. Copy's VJP is identity so
flattening is gradient-safe.
@Brooooooklyn Brooooooklyn force-pushed the fix/transform-stale-wrapper-leak branch from 314a742 to 3e253ce on March 24, 2026 at 01:48
@angeloskath (Member) left a comment


This looks good thanks.

However, it is mostly a convenience for the user, imho, because the pattern that exhibits the issue has a bug: it never evaluates the container. If it did, there would be no copy chain.

To make it clear why the problem is not the copy itself: the following code will "leak" even with the "fix".

  auto grad_fn = grad([&container](const std::vector<array>& inputs) {
    container[0] = 2 * inputs[0];
    return sum(inputs[1]);
  });

Here container[0] will no longer be just a Copy but 2x of a Copy, so we can't extract the original from its input.

Another way of seeing this: just calling the forward function in a loop will build up huge chains of arrays if you never evaluate anything.

@angeloskath angeloskath merged commit 604c825 into ml-explore:main Mar 24, 2026
16 checks passed


Successfully merging this pull request may close these issues.

[BUG] Memory leak in function transform with unused container
