3
3
import jax
4
4
import numpy as np
5
5
6
+ from keras .src .backend .common import global_state
7
+ from keras .src .random import seed_generator
6
8
from keras .src .utils import jax_utils
9
+ from keras .src .utils import rng_utils
7
10
8
11
9
12
def list_devices (device_type = None ):
@@ -185,6 +188,52 @@ def distribute_data_input(per_process_batch, layout, batch_dim_name):
185
188
return global_batch_array
186
189
187
190
191
+ def initialize_rng ():
192
+ """Initializes the global random number generator across processes.
193
+
194
+ This is required for consistent initialization in multi-host settings.
195
+ """
196
+ global_seed = rng_utils .get_random_seed ()
197
+ # Only set a random seed if not already set
198
+ # via keras.config.set_random_seed()
199
+ if global_seed is None :
200
+ # Generate a random seed on each CPU host and psum them to get a single
201
+ # consistent seed across all processes.
202
+ cpu_devices = jax .devices ("cpu" )
203
+ num_local_cpu_devices = jax .local_device_count ("cpu" )
204
+ # Seed must be in range [0, 2^32 - 1], so to ensure proper range and
205
+ # avoid signed integer overflow, we use uint32.
206
+ local_seed = jax .numpy .asarray (
207
+ [seed_generator .make_default_seed ()] * num_local_cpu_devices ,
208
+ dtype = jax .numpy .uint32 ,
209
+ )
210
+ # Sum across processes and pull out the first item.
211
+ global_seed = jax .pmap (
212
+ lambda x : jax .lax .psum (x , "all" ),
213
+ axis_name = "all" ,
214
+ devices = cpu_devices ,
215
+ )(local_seed ).item (0 )
216
+ # Set the global seed.
217
+ rng_utils .set_random_seed (global_seed )
218
+
219
+ # Check if the global seed generator is set and ensure it has an initialized
220
+ # seed. Otherwise, reset the seed to the global seed.
221
+ global_seed_generator = global_state .get_global_attribute (
222
+ "global_seed_generator"
223
+ )
224
+ if global_seed_generator is not None :
225
+ seed = global_seed_generator .get_config ()["seed" ]
226
+ if seed is None :
227
+ global_state .set_global_attribute (
228
+ "global_seed_generator" ,
229
+ seed_generator .SeedGenerator (
230
+ seed = global_seed ,
231
+ name = global_seed_generator .name ,
232
+ backend = global_seed_generator .backend ,
233
+ ),
234
+ )
235
+
236
+
188
237
def initialize (job_addresses , num_processes , process_id ):
189
238
if job_addresses and "," in job_addresses :
190
239
# When user provide all the job addresses, we will split and get the
@@ -208,6 +257,9 @@ def initialize(job_addresses, num_processes, process_id):
208
257
process_id = process_id ,
209
258
)
210
259
260
+ # Ensure the random number generator is initialized across processes.
261
+ initialize_rng ()
262
+
211
263
212
264
def num_processes ():
213
265
"""Return the number of processes for the current distribution setting."""
0 commit comments