Skip to content

Conversation

priyakasimbeg
Copy link
Contributor

@priyakasimbeg priyakasimbeg commented Jul 31, 2025

Dropout feature changes
Move formatting from yapf to ruff
OGBG OOM in metrics calculation fix

priyakasimbeg and others added 30 commits April 17, 2025 03:20
It seems that the problem affecting the pytorch ogbg workloads (but only if they run for some length of time) has to do with jax/xla cpu compilation of the metrics computation. By converting the jax arrays to numpy, hopefully this can be avoided. The next step is to test on schedule free and shampoo, which I hope to do very soon.
The problem with torchrun and jax seems to be caused by jax.nn.sigmoid.
Changed from lambda expression which pylint doesn't like.
Defined np sigmoid inside use_pytorch_ddp
Added white space before and after sigmoid_np
@priyakasimbeg priyakasimbeg requested a review from a team as a code owner July 31, 2025 19:04
Copy link

github-actions bot commented Jul 31, 2025

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

@priyakasimbeg priyakasimbeg merged commit fb2f492 into main Aug 2, 2025
48 of 49 checks passed
@github-actions github-actions bot locked and limited conversation to collaborators Aug 2, 2025
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants