Commit 0816e10

Initialize random seed for distributed models.
Keep track of a `global_random_seed` and ensure it is set when `keras.distribution.initialize(...)` is called. In multi-host JAX jobs, all processes require consistent random number generation; otherwise, initializers on different hosts would produce inconsistent values, resulting in both compilation and runtime failures.
1 parent: f98b91f
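
To make the seed-agreement idea concrete, here is a minimal single-machine sketch of the psum reduction the commit relies on. It is illustrative only, not the commit's code: the four CPU "hosts" are faked with an `XLA_FLAGS` override (an assumption for demonstration; the commit sums one seed per real process).

import os

# Fake 4 CPU devices on one machine so pmap/psum can run locally.
# This flag must be set before JAX is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import numpy as np

cpu_devices = jax.devices("cpu")
# Each simulated host draws an independent seed...
local_seeds = jax.numpy.asarray(
    np.random.randint(0, 2**16, size=len(cpu_devices)),
    dtype=jax.numpy.uint32,
)
# ...and the all-reduce sum leaves every device holding the same value.
shared = jax.pmap(
    lambda x: jax.lax.psum(x, "hosts"),
    axis_name="hosts",
    devices=cpu_devices,
)(local_seeds)
assert int(shared[0]) == int(local_seeds.sum())  # identical on every device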

2 files changed, +92 -20 lines


keras/src/backend/jax/distribution_lib.py

Lines changed: 77 additions & 20 deletions
@@ -3,7 +3,10 @@
 import jax
 import numpy as np
 
+from keras.src.backend.common import global_state
+from keras.src.random import seed_generator
 from keras.src.utils import jax_utils
+from keras.src.utils import rng_utils
 
 
 def list_devices(device_type=None):
@@ -185,28 +188,82 @@ def distribute_data_input(per_process_batch, layout, batch_dim_name):
     return global_batch_array
 
 
-def initialize(job_addresses, num_processes, process_id):
-    if job_addresses and "," in job_addresses:
-        # When user provide all the job addresses, we will split and get the
-        # first one, which is the coordinator.
-        job_addresses = job_addresses.split(",")
-        # Do a sanity check to make sure the number of addresses also match
-        # the num_processes.
-        if num_processes is not None and num_processes != len(job_addresses):
-            raise ValueError(
-                f"The provided job_addresses {job_addresses} has "
-                f"{len(job_addresses)} jobs, but num_processes is "
-                f"{num_processes}"
-            )
-        coordinator_address = job_addresses[0]
-    else:
-        coordinator_address = job_addresses
+def initialize_rng():
+    """Initializes the global random number generator across processes.
 
-    jax.distributed.initialize(
-        coordinator_address=coordinator_address,
-        num_processes=num_processes,
-        process_id=process_id,
+    This is required for consistent initialization in multi-host settings.
+    """
+    global_seed = rng_utils.get_random_seed()
+    # Only set a random seed if one was not already set
+    # via keras.utils.set_random_seed().
+    if global_seed is None:
+        # Generate a random seed on each CPU host and psum them to get a
+        # single consistent seed across all processes.
+        cpu_devices = jax.devices("cpu")
+        num_local_cpu_devices = jax.local_device_count("cpu")
+        # Seed must be in range [0, 2^32 - 1], so to ensure proper range
+        # and avoid signed integer overflow, we use uint32.
+        local_seed = jax.numpy.asarray(
+            [seed_generator.make_default_seed()] * num_local_cpu_devices,
+            dtype=jax.numpy.uint32,
+        )
+        # Sum across processes and pull out the first item.
+        global_seed = jax.pmap(
+            lambda x: jax.lax.psum(x, "all"),
+            axis_name="all",
+            devices=cpu_devices,
+        )(local_seed).item(0)
+        # Set the global seed.
+        rng_utils.set_random_seed(global_seed)
+
+    # Check whether the global seed generator is set and ensure it has an
+    # initialized seed. Otherwise, reset its seed to the global seed.
+    global_seed_generator = global_state.get_global_attribute(
+        "global_seed_generator"
     )
+    if global_seed_generator is not None:
+        seed = global_seed_generator.get_config()["seed"]
+        if seed is None:
+            global_state.set_global_attribute(
+                "global_seed_generator",
+                seed_generator.SeedGenerator(
+                    seed=global_seed,
+                    name=global_seed_generator.name,
+                    backend=global_seed_generator.backend,
+                ),
+            )
+
+
+def initialize(job_addresses, num_processes, process_id):
+    # Only call the JAX initialize if it is not already initialized.
+    # Some JAX processes set this up themselves (e.g. multiprocess tests).
+    if not jax.distributed.is_initialized():
+        if job_addresses and "," in job_addresses:
+            # When the user provides all the job addresses, split them and
+            # take the first one, which is the coordinator.
+            job_addresses = job_addresses.split(",")
+            # Sanity-check that the number of addresses matches
+            # num_processes.
+            if num_processes is not None and num_processes != len(
+                job_addresses
+            ):
+                raise ValueError(
+                    f"The provided job_addresses {job_addresses} has "
+                    f"{len(job_addresses)} jobs, but num_processes is "
+                    f"{num_processes}"
+                )
+            coordinator_address = job_addresses[0]
+        else:
+            coordinator_address = job_addresses
+
+        jax.distributed.initialize(
+            coordinator_address=coordinator_address,
+            num_processes=num_processes,
+            process_id=process_id,
+        )
+
+    # Ensure the random number generator is initialized across processes.
+    initialize_rng()
 
 
 def num_processes():
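
Because `initialize_rng()` first consults `rng_utils.get_random_seed()`, an explicitly pinned seed takes precedence over the negotiated one. A usage sketch, assuming a multi-host JAX job whose coordinator address and process count are already supplied through the environment (it will not run on a plain single machine):

import keras

# Pin the seed before initialization; it is stored in global state.
keras.utils.set_random_seed(1337)

# initialize_rng() sees the stored 1337 and skips the psum negotiation,
# so every host seeds its generators identically.
keras.distribution.initialize()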

keras/src/utils/rng_utils.py

Lines changed: 15 additions & 0 deletions
@@ -4,8 +4,11 @@
 
 from keras.src import backend
 from keras.src.api_export import keras_export
+from keras.src.backend.common import global_state
 from keras.src.utils.module_utils import tensorflow as tf
 
+GLOBAL_RANDOM_SEED = "global_random_seed"
+
 
 @keras_export("keras.utils.set_random_seed")
 def set_random_seed(seed):
@@ -46,6 +49,9 @@ def set_random_seed(seed):
             "Expected `seed` argument to be an integer. "
             f"Received: seed={seed} (of type {type(seed)})"
         )
+
+    # Store seed in global state so we can query it if set.
+    global_state.set_global_attribute(GLOBAL_RANDOM_SEED, seed)
     random.seed(seed)
     np.random.seed(seed)
     if tf.available:
@@ -54,3 +60,12 @@ def set_random_seed(seed):
         import torch
 
         torch.manual_seed(seed)
+
+
+def get_random_seed():
+    """Returns the explicit integer random seed if set.
+
+    If the seed has been explicitly set via `set_random_seed`, then
+    returns the seed. Otherwise, returns `None`.
+    """
+    return global_state.get_global_attribute(GLOBAL_RANDOM_SEED)
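
Note that `get_random_seed` is not exported under `keras.utils`, so internal callers import it from `keras.src.utils.rng_utils`. A round-trip sketch, assuming a fresh process where no seed has been stored yet:

from keras.src.utils import rng_utils

# Nothing is stored under "global_random_seed" in a fresh process.
assert rng_utils.get_random_seed() is None

rng_utils.set_random_seed(42)  # also seeds random, numpy, and tf/torch
assert rng_utils.get_random_seed() == 42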
