|
3 | 3 | import jax
|
4 | 4 | import numpy as np
|
5 | 5 |
|
| 6 | +from keras.src.backend.common import global_state |
| 7 | +from keras.src.random import seed_generator |
6 | 8 | from keras.src.utils import jax_utils
|
| 9 | +from keras.src.utils import rng_utils |
7 | 10 |
|
8 | 11 |
|
9 | 12 | def list_devices(device_type=None):
|
@@ -185,28 +188,82 @@ def distribute_data_input(per_process_batch, layout, batch_dim_name):
|
185 | 188 | return global_batch_array
|
186 | 189 |
|
187 | 190 |
|
188 |
| -def initialize(job_addresses, num_processes, process_id): |
189 |
| - if job_addresses and "," in job_addresses: |
190 |
| - # When user provide all the job addresses, we will split and get the |
191 |
| - # first one, which is the coordinator. |
192 |
| - job_addresses = job_addresses.split(",") |
193 |
| - # Do a sanity check to make sure the number of addresses also match |
194 |
| - # the num_processes. |
195 |
| - if num_processes is not None and num_processes != len(job_addresses): |
196 |
| - raise ValueError( |
197 |
| - f"The provided job_addresses {job_addresses} has " |
198 |
| - f"{len(job_addresses)} jobs, but num_processes is " |
199 |
| - f"{num_processes}" |
200 |
| - ) |
201 |
| - coordinator_address = job_addresses[0] |
202 |
| - else: |
203 |
| - coordinator_address = job_addresses |
| 191 | +def initialize_rng(): |
| 192 | + """Initializes the global random number generator across processes. |
204 | 193 |
|
205 |
| - jax.distributed.initialize( |
206 |
| - coordinator_address=coordinator_address, |
207 |
| - num_processes=num_processes, |
208 |
| - process_id=process_id, |
| 194 | + This is required for consistent initialization in multi-host settings. |
| 195 | + """ |
| 196 | + global_seed = rng_utils.get_random_seed() |
| 197 | + # Only set a random seed if not already set |
| 198 | + # via keras.config.set_random_seed() |
| 199 | + if global_seed is None: |
| 200 | + # Generate a random seed on each CPU host and psum them to get a single |
| 201 | + # consistent seed across all processes. |
| 202 | + cpu_devices = jax.devices("cpu") |
| 203 | + num_local_cpu_devices = jax.local_device_count("cpu") |
| 204 | + # Seed must be in range [0, 2^32 - 1], so to ensure proper range and |
| 205 | + # avoid signed integer overflow, we use uint32. |
| 206 | + local_seed = jax.numpy.asarray( |
| 207 | + [seed_generator.make_default_seed()] * num_local_cpu_devices, |
| 208 | + dtype=jax.numpy.uint32, |
| 209 | + ) |
| 210 | + # Sum across processes and pull out the first item. |
| 211 | + global_seed = jax.pmap( |
| 212 | + lambda x: jax.lax.psum(x, "all"), |
| 213 | + axis_name="all", |
| 214 | + devices=cpu_devices, |
| 215 | + )(local_seed).item(0) |
| 216 | + # Set the global seed. |
| 217 | + rng_utils.set_random_seed(global_seed) |
| 218 | + |
| 219 | + # Check if the global seed generator is set and ensure it has an initialized |
| 220 | + # seed. Otherwise, reset the seed to the global seed. |
| 221 | + global_seed_generator = global_state.get_global_attribute( |
| 222 | + "global_seed_generator" |
209 | 223 | )
|
| 224 | + if global_seed_generator is not None: |
| 225 | + seed = global_seed_generator.get_config()["seed"] |
| 226 | + if seed is None: |
| 227 | + global_state.set_global_attribute( |
| 228 | + "global_seed_generator", |
| 229 | + seed_generator.SeedGenerator( |
| 230 | + seed=global_seed, |
| 231 | + name=global_seed_generator.name, |
| 232 | + backend=global_seed_generator.backend, |
| 233 | + ), |
| 234 | + ) |
| 235 | + |
| 236 | + |
| 237 | +def initialize(job_addresses, num_processes, process_id): |
| 238 | + # Only call JAX initialize if not already initialized. |
| 239 | + # Some JAX processes already set this up (e.g. multiprocess tests). |
| 240 | + if not jax.distributed.is_initialized(): |
| 241 | + if job_addresses and "," in job_addresses: |
| 242 | + # When user provide all the job addresses, we will split and get the |
| 243 | + # first one, which is the coordinator. |
| 244 | + job_addresses = job_addresses.split(",") |
| 245 | + # Do a sanity check to make sure the number of addresses also match |
| 246 | + # the num_processes. |
| 247 | + if num_processes is not None and num_processes != len( |
| 248 | + job_addresses |
| 249 | + ): |
| 250 | + raise ValueError( |
| 251 | + f"The provided job_addresses {job_addresses} has " |
| 252 | + f"{len(job_addresses)} jobs, but num_processes is " |
| 253 | + f"{num_processes}" |
| 254 | + ) |
| 255 | + coordinator_address = job_addresses[0] |
| 256 | + else: |
| 257 | + coordinator_address = job_addresses |
| 258 | + |
| 259 | + jax.distributed.initialize( |
| 260 | + coordinator_address=coordinator_address, |
| 261 | + num_processes=num_processes, |
| 262 | + process_id=process_id, |
| 263 | + ) |
| 264 | + |
| 265 | + # Ensure the random number generator is initialized across processes. |
| 266 | + initialize_rng() |
210 | 267 |
|
211 | 268 |
|
212 | 269 | def num_processes():
|
|
0 commit comments