|
16 | 16 |
|
17 | 17 | FLAGS = flags.FLAGS
|
18 | 18 |
|
19 |
| -# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an |
| 19 | +# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 31 - 1] (an |
20 | 20 | # unsigned int), while RandomState.randint only accepts and returns signed ints.
|
21 |
| -MAX_UINT32 = 2**32 - 1 |
22 |
| -MIN_UINT32 = 0 |
| 21 | +MAX_INT32 = 2**31 - 1 |
| 22 | +MIN_INT32 = 0 |
23 | 23 |
|
24 | 24 | SeedType = Union[int, list, np.ndarray]
|
25 | 25 |
|
26 | 26 |
|
27 | 27 | def _signed_to_unsigned(seed: SeedType) -> SeedType:
|
28 | 28 | if isinstance(seed, int):
|
29 |
| - return seed % MAX_UINT32 |
| 29 | + return seed % MAX_INT32 |
30 | 30 | if isinstance(seed, list):
|
31 |
| - return [s % MAX_UINT32 for s in seed] |
| 31 | + return [s % MAX_INT32 for s in seed] |
32 | 32 | if isinstance(seed, np.ndarray):
|
33 |
| - return np.array([s % MAX_UINT32 for s in seed.tolist()]) |
| 33 | + return np.array([s % MAX_INT32 for s in seed.tolist()]) |
34 | 34 |
|
35 | 35 |
|
36 | 36 | def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
|
37 | 37 | rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
|
38 |
| - new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32) |
| 38 | + new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) |
39 | 39 | return [new_seed, data]
|
40 | 40 |
|
41 | 41 |
|
42 | 42 | def _split(seed: SeedType, num: int = 2) -> SeedType:
|
43 | 43 | rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
|
44 |
| - return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2]) |
| 44 | + return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) |
45 | 45 |
|
46 | 46 |
|
47 | 47 | def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
|
|
0 commit comments