Skip to content

Commit 9d1c957

Browse files
Merge pull request #811 from init-22/python311
Moving from Python3.8 to Python 3.11
2 parents ea66793 + 53eff1d commit 9d1c957

File tree

18 files changed

+598
-125
lines changed

18 files changed

+598
-125
lines changed

.github/workflows/CI.yml

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ jobs:
77
runs-on: ubuntu-latest
88
steps:
99
- uses: actions/checkout@v3
10-
- name: Set up Python 3.9
10+
- name: Set up Python 3.11.10
1111
uses: actions/setup-python@v4
1212
with:
13-
python-version: 3.9
13+
python-version: 3.11.10
1414
cache: 'pip' # Cache pip dependencies\.
1515
cache-dependency-path: '**/setup.py'
1616
- name: Install Modules and Run
@@ -25,10 +25,10 @@ jobs:
2525
runs-on: ubuntu-latest
2626
steps:
2727
- uses: actions/checkout@v3
28-
- name: Set up Python 3.9
28+
- name: Set up Python 3.11.10
2929
uses: actions/setup-python@v4
3030
with:
31-
python-version: 3.9
31+
python-version: 3.11.10
3232
cache: 'pip' # Cache pip dependencies\.
3333
cache-dependency-path: '**/setup.py'
3434
- name: Install Modules and Run
@@ -42,10 +42,10 @@ jobs:
4242
runs-on: ubuntu-latest
4343
steps:
4444
- uses: actions/checkout@v3
45-
- name: Set up Python 3.9
45+
- name: Set up Python 3.11.10
4646
uses: actions/setup-python@v4
4747
with:
48-
python-version: 3.9
48+
python-version: 3.11.10
4949
cache: 'pip' # Cache pip dependencies\.
5050
cache-dependency-path: '**/setup.py'
5151
- name: Install Modules and Run
@@ -59,10 +59,10 @@ jobs:
5959
runs-on: ubuntu-latest
6060
steps:
6161
- uses: actions/checkout@v3
62-
- name: Set up Python 3.9
62+
- name: Set up Python 3.11.10
6363
uses: actions/setup-python@v4
6464
with:
65-
python-version: 3.9
65+
python-version: 3.11.10
6666
cache: 'pip' # Cache pip dependencies\.
6767
cache-dependency-path: '**/setup.py'
6868
- name: Install Modules and Run
@@ -77,10 +77,10 @@ jobs:
7777
runs-on: ubuntu-latest
7878
steps:
7979
- uses: actions/checkout@v3
80-
- name: Set up Python 3.9
80+
- name: Set up Python 3.11.10
8181
uses: actions/setup-python@v4
8282
with:
83-
python-version: 3.9
83+
python-version: 3.11.10
8484
cache: 'pip' # Cache pip dependencies\.
8585
cache-dependency-path: '**/setup.py'
8686
- name: Install Modules and Run
@@ -96,10 +96,10 @@ jobs:
9696
runs-on: ubuntu-latest
9797
steps:
9898
- uses: actions/checkout@v3
99-
- name: Set up Python 3.9
99+
- name: Set up Python 3.11.10
100100
uses: actions/setup-python@v4
101101
with:
102-
python-version: 3.9
102+
python-version: 3.11.10
103103
cache: 'pip' # Cache pip dependencies\.
104104
cache-dependency-path: '**/setup.py'
105105
- name: Install Modules and Run
@@ -113,10 +113,10 @@ jobs:
113113
runs-on: ubuntu-latest
114114
steps:
115115
- uses: actions/checkout@v3
116-
- name: Set up Python 3.9
116+
- name: Set up Python 3.11.10
117117
uses: actions/setup-python@v4
118118
with:
119-
python-version: 3.9
119+
python-version: 3.11.10
120120
cache: 'pip' # Cache pip dependencies\.
121121
cache-dependency-path: '**/setup.py'
122122
- name: Install Modules and Run
@@ -130,10 +130,10 @@ jobs:
130130
runs-on: ubuntu-latest
131131
steps:
132132
- uses: actions/checkout@v3
133-
- name: Set up Python 3.9
133+
- name: Set up Python 3.11.10
134134
uses: actions/setup-python@v4
135135
with:
136-
python-version: 3.9
136+
python-version: 3.11.10
137137
cache: 'pip' # Cache pip dependencies\.
138138
cache-dependency-path: '**/setup.py'
139139
- name: Install Modules and Run
@@ -148,10 +148,10 @@ jobs:
148148
runs-on: ubuntu-latest
149149
steps:
150150
- uses: actions/checkout@v3
151-
- name: Set up Python 3.9
151+
- name: Set up Python 3.11.10
152152
uses: actions/setup-python@v4
153153
with:
154-
python-version: 3.9
154+
python-version: 3.11.10
155155
cache: 'pip' # Cache pip dependencies\.
156156
cache-dependency-path: '**/setup.py'
157157
- name: Install Modules and Run
@@ -166,10 +166,10 @@ jobs:
166166
runs-on: ubuntu-latest
167167
steps:
168168
- uses: actions/checkout@v3
169-
- name: Set up Python 3.9
169+
- name: Set up Python 3.11.10
170170
uses: actions/setup-python@v4
171171
with:
172-
python-version: 3.9
172+
python-version: 3.11.10
173173
cache: 'pip' # Cache pip dependencies\.
174174
cache-dependency-path: '**/setup.py'
175175
- name: Install Modules and Run
@@ -184,10 +184,10 @@ jobs:
184184
runs-on: ubuntu-latest
185185
steps:
186186
- uses: actions/checkout@v3
187-
- name: Set up Python 3.9
187+
- name: Set up Python 3.11.10
188188
uses: actions/setup-python@v4
189189
with:
190-
python-version: 3.9
190+
python-version: 3.11.10
191191
cache: 'pip' # Cache pip dependencies\.
192192
cache-dependency-path: '**/setup.py'
193193
- name: Install pytest
@@ -208,10 +208,10 @@ jobs:
208208
runs-on: ubuntu-latest
209209
steps:
210210
- uses: actions/checkout@v3
211-
- name: Set up Python 3.9
211+
- name: Set up Python 3.11.10
212212
uses: actions/setup-python@v4
213213
with:
214-
python-version: 3.9
214+
python-version: 3.11.10
215215
cache: 'pip' # Cache pip dependencies\.
216216
cache-dependency-path: '**/setup.py'
217217
- name: Install pytest

.github/workflows/linting.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ jobs:
77
runs-on: ubuntu-latest
88
steps:
99
- uses: actions/checkout@v2
10-
- name: Set up Python 3.9
10+
- name: Set up Python 3.11.10
1111
uses: actions/setup-python@v2
1212
with:
13-
python-version: 3.9
13+
python-version: 3.11.10
1414
- name: Install pylint
1515
run: |
1616
python -m pip install --upgrade pip
@@ -27,10 +27,10 @@ jobs:
2727
runs-on: ubuntu-latest
2828
steps:
2929
- uses: actions/checkout@v2
30-
- name: Set up Python 3.9
30+
- name: Set up Python 3.11.10
3131
uses: actions/setup-python@v2
3232
with:
33-
python-version: 3.9
33+
python-version: 3.11.10
3434
- name: Install isort
3535
run: |
3636
python -m pip install --upgrade pip
@@ -43,10 +43,10 @@ jobs:
4343
runs-on: ubuntu-latest
4444
steps:
4545
- uses: actions/checkout@v2
46-
- name: Set up Python 3.9
46+
- name: Set up Python 3.11.10
4747
uses: actions/setup-python@v2
4848
with:
49-
python-version: 3.9
49+
python-version: 3.11.10
5050
- name: Install yapf
5151
run: |
5252
python -m pip install --upgrade pip

GETTING_STARTED.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ The specs on the benchmarking machines are:
3535

3636
> **Prerequisites:**
3737
>
38-
> - Python minimum requirement >= 3.8
38+
> - Python minimum requirement >= 3.11
3939
> - CUDA 12.1
4040
> - NVIDIA Driver version 535.104.05
4141

algorithmic_efficiency/checkpoint_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def save_checkpoint(framework: str,
231231
target=checkpoint_state,
232232
step=global_step,
233233
overwrite=True,
234-
keep=np.Inf if save_intermediate_checkpoints else 1)
234+
keep=np.inf if save_intermediate_checkpoints else 1)
235235
else:
236236
if not save_intermediate_checkpoints:
237237
checkpoint_files = gfile.glob(

algorithmic_efficiency/halton.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
import functools
1111
import itertools
1212
import math
13-
from typing import Any, Callable, Dict, List, Sequence, Text, Tuple, Union
13+
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
1414

1515
from absl import logging
1616
from numpy import random
1717

18-
_SweepSequence = List[Dict[Text, Any]]
19-
_GeneratorFn = Callable[[float], Tuple[Text, float]]
18+
_SweepSequence = List[Dict[str, Any]]
19+
_GeneratorFn = Callable[[float], Tuple[str, float]]
2020

2121

2222
def generate_primes(n: int) -> List[int]:
@@ -195,10 +195,10 @@ def generate_sequence(num_samples: int,
195195
return halton_sequence
196196

197197

198-
def _generate_double_point(name: Text,
198+
def _generate_double_point(name: str,
199199
min_val: float,
200200
max_val: float,
201-
scaling: Text,
201+
scaling: str,
202202
halton_point: float) -> Tuple[str, float]:
203203
"""Generate a float hyperparameter value from a Halton sequence point."""
204204
if scaling not in ['linear', 'log']:
@@ -234,7 +234,7 @@ def interval(start: int, end: int) -> Tuple[int, int]:
234234
return start, end
235235

236236

237-
def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn:
237+
def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn:
238238
min_val, max_val = range_endpoints
239239
return functools.partial(_generate_double_point,
240240
name,
@@ -244,8 +244,8 @@ def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn:
244244

245245

246246
def uniform(
247-
name: Text, search_points: Union[_DiscretePoints,
248-
Tuple[int, int]]) -> _GeneratorFn:
247+
name: str, search_points: Union[_DiscretePoints,
248+
Tuple[int, int]]) -> _GeneratorFn:
249249
if isinstance(search_points, _DiscretePoints):
250250
return functools.partial(_generate_discrete_point,
251251
name,

algorithmic_efficiency/logger_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def _get_system_software_info() -> Dict:
211211
system_software_info['os_platform'] = \
212212
platform.platform() # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29'
213213
system_software_info['python_version'] = platform.python_version(
214-
) # Ex. '3.8.10'
214+
) # Ex. '3.11.10'
215215
system_software_info['python_compiler'] = platform.python_compiler(
216216
) # Ex. 'GCC 9.3.0'
217217
# Note: do not store hostname as that may be sensitive

algorithmic_efficiency/random_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,30 +18,30 @@
1818

1919
# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an
2020
# unsigned int), while RandomState.randint only accepts and returns signed ints.
21-
MAX_INT32 = 2**31
22-
MIN_INT32 = -MAX_INT32
21+
MAX_UINT32 = 2**32 - 1
22+
MIN_UINT32 = 0
2323

2424
SeedType = Union[int, list, np.ndarray]
2525

2626

2727
def _signed_to_unsigned(seed: SeedType) -> SeedType:
2828
if isinstance(seed, int):
29-
return seed % 2**32
29+
return seed % MAX_UINT32
3030
if isinstance(seed, list):
31-
return [s % 2**32 for s in seed]
31+
return [s % MAX_UINT32 for s in seed]
3232
if isinstance(seed, np.ndarray):
33-
return np.array([s % 2**32 for s in seed.tolist()])
33+
return np.array([s % MAX_UINT32 for s in seed.tolist()])
3434

3535

3636
def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
3737
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
38-
new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
38+
new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32)
3939
return [new_seed, data]
4040

4141

4242
def _split(seed: SeedType, num: int = 2) -> SeedType:
4343
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
44-
return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
44+
return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2])
4545

4646

4747
def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
@@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType:
7575
def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
7676
if FLAGS.framework == 'jax':
7777
_check_jax_install()
78-
return jax_rng.PRNGKey(seed)
78+
return jax_rng.key(seed)
7979
return _PRNGKey(seed)

algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from flax import jax_utils
77
from flax import linen as nn
8+
from flax.core import pop
89
import jax
910
from jax import lax
1011
import jax.numpy as jnp
@@ -75,8 +76,8 @@ def sync_batch_stats(
7576
# In this case each device has its own version of the batch statistics
7677
# and we average them.
7778
avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
78-
new_model_state = model_state.copy(
79-
{'batch_stats': avg_fn(model_state['batch_stats'])})
79+
new_model_state = model_state.copy()
80+
new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
8081
return new_model_state
8182

8283
def init_model_fn(
@@ -93,7 +94,7 @@ def init_model_fn(
9394
input_shape = (1, 32, 32, 3)
9495
variables = jax.jit(model.init)({'params': rng},
9596
jnp.ones(input_shape, model.dtype))
96-
model_state, params = variables.pop('params')
97+
model_state, params = pop(variables, 'params')
9798
self._param_shapes = param_utils.jax_param_shapes(params)
9899
self._param_types = param_utils.jax_param_types(self._param_shapes)
99100
model_state = jax_utils.replicate(model_state)

0 commit comments

Comments
 (0)