Skip to content

Initialize random seed for distributed models. #21261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 8, 2025

Conversation

cantonios
Copy link
Contributor

Keep track of a global_random_seed, and ensure it is set when initializing keras.distribution.initialize(...).

In multi-host processes in JAX, all processes require consistent random number generation. Otherwise, initializers on different hosts would produce inconsistent values, resulting in both compilation and runtime failures.

@codecov-commenter
Copy link

codecov-commenter commented May 7, 2025

Codecov Report

Attention: Patch coverage is 77.27273% with 5 lines in your changes missing coverage. Please review.

Project coverage is 82.54%. Comparing base (f98b91f) to head (efbf9b0).

Files with missing lines Patch % Lines
keras/src/backend/jax/distribution_lib.py 70.58% 3 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21261      +/-   ##
==========================================
- Coverage   82.55%   82.54%   -0.01%     
==========================================
  Files         564      564              
  Lines       54629    54651      +22     
  Branches     8493     8496       +3     
==========================================
+ Hits        45097    45114      +17     
- Misses       7442     7445       +3     
- Partials     2090     2092       +2     
Flag Coverage Δ
keras 82.36% <77.27%> (-0.01%) ⬇️
keras-jax 63.64% <77.27%> (+<0.01%) ⬆️
keras-numpy 58.74% <36.36%> (-0.02%) ⬇️
keras-openvino 32.96% <36.36%> (+<0.01%) ⬆️
keras-tensorflow 64.03% <36.36%> (-0.02%) ⬇️
keras-torch 63.69% <36.36%> (-0.02%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Keep track of a `global_random_seed`, and ensure it is set
when initializing `keras.distribution.initialize(...)`.

In multi-host processes in JAX, all processes require consistent
random number generation.  Otherwise, initializers on different
hosts would produce inconsistent values, resulting in both
compilation and runtime failures.
@gbaned gbaned added this to PR Queue May 8, 2025
@github-project-automation github-project-automation bot moved this to Assigned Reviewer in PR Queue May 8, 2025
Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you

@fchollet fchollet merged commit 5f2793e into keras-team:master May 8, 2025
7 checks passed
@github-project-automation github-project-automation bot moved this from Assigned Reviewer to Merged in PR Queue May 8, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
Status: Merged
Development

Successfully merging this pull request may close these issues.

4 participants