Skip to content

Commit 9ee2d80

Browse files
committed
Merge branch 'dev' into rename_algoperf
2 parents bc666a7 + b356844 commit 9ee2d80

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

algoperf/random_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,32 @@
1616

1717
FLAGS = flags.FLAGS
1818

19-
# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an
19+
# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 31 - 1] (an
2020
# unsigned int), while RandomState.randint only accepts and returns signed ints.
21-
MAX_UINT32 = 2**32 - 1
22-
MIN_UINT32 = 0
21+
MAX_INT32 = 2**31 - 1
22+
MIN_INT32 = 0
2323

2424
SeedType = Union[int, list, np.ndarray]
2525

2626

2727
def _signed_to_unsigned(seed: SeedType) -> SeedType:
2828
if isinstance(seed, int):
29-
return seed % MAX_UINT32
29+
return seed % MAX_INT32
3030
if isinstance(seed, list):
31-
return [s % MAX_UINT32 for s in seed]
31+
return [s % MAX_INT32 for s in seed]
3232
if isinstance(seed, np.ndarray):
33-
return np.array([s % MAX_UINT32 for s in seed.tolist()])
33+
return np.array([s % MAX_INT32 for s in seed.tolist()])
3434

3535

3636
def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
3737
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
38-
new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32)
38+
new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
3939
return [new_seed, data]
4040

4141

4242
def _split(seed: SeedType, num: int = 2) -> SeedType:
4343
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
44-
return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2])
44+
return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
4545

4646

4747
def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name

0 commit comments

Comments
 (0)