
Commit e02e401

Merge remote-tracking branch 'upstream/dev' into prepare_v0.6
2 parents: 29a552a + 0643268


98 files changed, +1597 -504 lines changed


.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ jobs:
         pip install -e .
         python tests/reference_algorithm_tests.py --workload=ogbg --framework=pytorch --global_batch_size=8 --submission_path=algorithms/target_setting_algorithms/pytorch_nesterov.py --tuning_search_space=algorithms/target_setting_algorithms/ogbg/tuning_search_space.json
         python tests/reference_algorithm_tests.py --workload=ogbg --framework=jax --global_batch_size=8 --submission_path=algorithms/target_setting_algorithms/jax_nesterov.py --tuning_search_space=algorithms/target_setting_algorithms/ogbg/tuning_search_space.json
-  pytest:
+  pytest-params:
     runs-on: ubuntu-latest
     steps:
     - uses: actions/checkout@v3

README.md

Lines changed: 4 additions & 8 deletions
@@ -70,20 +70,16 @@ We recommend using the provided [**Docker container**](/docs/GETTING_STARTED.md#
 Alternatively, you can install the package and its dependencies in a Python virtual environment.
 Both options are described in more detail in the [**Getting Started**](/docs/GETTING_STARTED.md) document.
 
-_TL;DR: Install for JAX on GPU:_
+_TL;DR Install JAX version for GPU (with workload dependencies):_
 
 ```bash
-pip3 install -e '.[pytorch_cpu]'
-pip3 install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html'
-pip3 install -e '.[full]'
+pip3 install -e '.[pytorch_cpu,jax_gpu,full]' --extra-index-url https://download.pytorch.org/whl/cpu
 ```
 
-_TL;DR: Install for PyTorch on GPU:_
+_TL;DR Install PyTorch version for GPU (with workload dependencies):_
 
 ```bash
-pip3 install -e '.[jax_cpu]'
-pip3 install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'
-pip3 install -e '.[full]'
+pip3 install -e '.[jax_cpu,pytorch_gpu,full]'
 ```
 
 ### Run a Workload

algoperf/checkpoint_utils.py

Lines changed: 0 additions & 4 deletions
@@ -7,7 +7,6 @@
 import os
 from typing import Sequence, Tuple
 
-import jax
 import numpy as np
 import torch
 from absl import logging
@@ -210,10 +209,7 @@ def save_checkpoint(
       train_state, eval_results, global_step, preemption_count).
   """
   if framework == 'jax':
-    model_params = jax.device_get(jax_utils.unreplicate(model_params))
     opt_state, _ = optimizer_state
-    opt_state = jax.device_get(jax_utils.unreplicate(opt_state))
-    model_state = jax.device_get(jax_utils.unreplicate(model_state))
   else:
     if isinstance(
       model_params,

algoperf/data_utils.py

Lines changed: 9 additions & 4 deletions
@@ -62,14 +62,19 @@ def _prepare(x):
     if remainder_size != 0 or pad_to_global_batch_size:
       x = pad(x, pad_size, padding_value=padding_value)
 
-    # Reshape (global_batch_size, ...) to
-    # (local_device_count, per_device_batch_size, ...).
-    # Assumes that `global_batch_size % local_device_count == 0`.
-    return x.reshape((local_device_count, -1, *x.shape[1:]))
+    # return x.reshape((local_device_count, -1, *x.shape[1:]))
+    return x
 
   return jax.tree.map(_prepare, batch)
 
 
+def shard(batch):
+  local_device_count = max(torch.cuda.device_count(), jax.local_device_count())
+  return jax.tree.map(
+    lambda x: x.reshape((local_device_count, -1, *x.shape[1:])), batch
+  )
+
+
 def pad(
   tensor: np.ndarray, pad_size: int, padding_value: int = 0
 ) -> np.ndarray:
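The per-leaf reshape that `_prepare` used to do is now a standalone `shard` helper, which this commit uses, for example, in the Criteo PyTorch input queue below. A minimal sketch of what the helper does to a host-side batch, assuming both `jax` and `torch` are importable and the global batch size is divisible by the local device count; the batch contents and shapes are made up for illustration:

```python
import jax
import numpy as np
import torch


def shard(batch):
  """Mirror of the helper added above: fold the leading (global batch) axis
  into (local_device_count, per_device_batch_size)."""
  local_device_count = max(torch.cuda.device_count(), jax.local_device_count())
  return jax.tree.map(
      lambda x: x.reshape((local_device_count, -1, *x.shape[1:])), batch)


# Illustrative host batch: a global batch of 512 image-shaped inputs.
batch = {
    'inputs': np.zeros((512, 28, 28, 1), dtype=np.float32),
    'targets': np.zeros((512,), dtype=np.int32),
}
sharded = shard(batch)
# With 8 local devices: inputs -> (8, 64, 28, 28, 1), targets -> (8, 64).
print(jax.tree.map(lambda x: x.shape, sharded))
```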

algoperf/jax_sharding_utils.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+"""Utilities for dealing with sharding in JAX."""
+
+import jax
+from jax.sharding import NamedSharding
+from jax.sharding import PartitionSpec as P
+
+
+def get_replicate_sharding():
+  """Returns a sharding spec that replicates data across all devices."""
+  mesh = jax.sharding.Mesh(jax.devices(), ('batch',))
+  return NamedSharding(mesh, P())
+
+
+def get_batch_dim_sharding():
+  """Returns a sharding spec that shards data along the first axis."""
+  mesh = jax.sharding.Mesh(jax.devices(), ('batch',))
+  return NamedSharding(mesh, P('batch'))
+
+
+def shard_along_batch_dim(x):
+  """Shards a tensor across all devices."""
+  mesh = jax.sharding.Mesh(jax.devices(), ('batch',))
+  return jax.tree.map(
+    lambda x: jax.device_put(x, NamedSharding(mesh, P('batch'))), x
+  )
+
+
+def replicate(x):
+  """Replicates tensor across all devices."""
+  mesh = jax.sharding.Mesh(jax.devices(), ('batch',))
+  return jax.tree.map(lambda x: jax.device_put(x, NamedSharding(mesh, P())), x)
+
+
+def display_shard_info(x: jax.Array):
+  """Displays shard info of a jax array."""
+  for shard in x.addressable_shards:
+    print(
+      f'shard.device: {shard.device}, index: {shard.index}, replica_id:'
+      f' {shard.replica_id}.\n'
+    )
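This new module replaces the `flax.jax_utils.replicate` / `pmap` idiom with explicit `NamedSharding` placements. A small usage sketch under assumed conditions (the parameters, batch, and `loss_fn` below are placeholders, not code from the repo; the batch size is assumed divisible by the device count):

```python
import jax
import jax.numpy as jnp

from algoperf import jax_sharding_utils

# Toy replicated parameters and a batch-sharded input (placeholder shapes).
params = {'w': jnp.ones((784, 10)), 'b': jnp.zeros((10,))}
batch = {
    'inputs': jnp.ones((128, 784)),
    'targets': jnp.zeros((128,), dtype=jnp.int32),
}

params = jax_sharding_utils.replicate(params)
batch = jax_sharding_utils.shard_along_batch_dim(batch)


@jax.jit
def loss_fn(params, batch):
  # Plain jit: the shardings attached to the inputs drive the partitioning.
  logits = batch['inputs'] @ params['w'] + params['b']
  log_probs = jax.nn.log_softmax(logits)
  one_hot = jax.nn.one_hot(batch['targets'], 10)
  return -jnp.mean(jnp.sum(log_probs * one_hot, axis=-1))


print(loss_fn(params, batch))
jax_sharding_utils.display_shard_info(batch['inputs'])
```

Because the placement travels with the arrays themselves, the jitted function needs no `axis_name` bookkeeping or explicit collectives.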

algoperf/spec.py

Lines changed: 2 additions & 9 deletions
@@ -235,16 +235,9 @@ def model_params_types(self):
   def is_output_params(self, param_key: ParameterKey) -> bool:
     """Whether a key in ParameterContainer is the output layer parameters."""
 
-  # InitModelFn = Callable[
-  #     Tuple[RandomState, Optional[float], Optional[float]],
-  #     ParameterContainer]
+  # InitModelFn = Callable[Optional[float]], ParameterContainer]
   @abc.abstractmethod
-  def init_model_fn(
-    self,
-    rng: RandomState,
-    dropout_rate: Optional[float] = None,
-    aux_dropout_rate: Optional[float] = None,
-  ) -> ModelInitState:
+  def init_model_fn(self, rng: RandomState) -> ModelInitState:
    """Return (initial_params, initial_model_state)."""
 
   # ModelFn = Callable[
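Taken together with the workload diffs below, the net effect for workload authors is that dropout is no longer fixed at init time but passed per call to `model_fn`. A hedged sketch of the new shape of a workload; `MyWorkload` is a hypothetical class and the batch argument name is abbreviated, so only the signatures shown here are taken from this commit:

```python
from typing import Tuple

from algoperf import spec


class MyWorkload(spec.Workload):  # hypothetical workload, other methods omitted

  def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
    # The old signature also took dropout_rate / aux_dropout_rate; those are gone.
    ...

  def model_fn(
      self,
      params: spec.ParameterContainer,
      batch,  # name abbreviated for the sketch
      model_state: spec.ModelAuxiliaryState,
      mode: spec.ForwardPassMode,
      rng: spec.RandomState,
      update_batch_norm: bool,
      dropout_rate: float = 0.0,  # new per-call argument
  ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
    ...
```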

algoperf/workloads/cifar/cifar_jax/workload.py

Lines changed: 38 additions & 23 deletions
@@ -7,12 +7,11 @@
 import jax.numpy as jnp
 import optax
 import tensorflow_datasets as tfds
-from flax import jax_utils
 from flax import linen as nn
 from flax.core import pop
 from jax import lax
 
-from algoperf import param_utils, spec
+from algoperf import jax_sharding_utils, param_utils, spec
 from algoperf.workloads.cifar.cifar_jax import models
 from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter
 from algoperf.workloads.cifar.workload import BaseCifarWorkload
@@ -29,6 +28,7 @@ def _build_cifar_dataset(
     repeat_final_dataset: Optional[bool] = None,
   ) -> Iterator[Dict[str, spec.Tensor]]:
     ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir)
+    ds_builder.download_and_prepare()
     train = split == 'train'
     assert self.num_train_examples + self.num_validation_examples == 50000
     if split in ['train', 'eval_train']:
@@ -89,8 +89,8 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     model_state, params = pop(variables, 'params')
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    model_state = jax_utils.replicate(model_state)
-    params = jax_utils.replicate(params)
+    model_state = jax_sharding_utils.replicate(params)
+    params = jax_sharding_utils.replicate(params)
     return params, model_state
 
   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
@@ -105,9 +105,11 @@ def model_fn(
     rng: spec.RandomState,
     update_batch_norm: bool,
     use_running_average_bn: Optional[bool] = None,
+    dropout_rate: float = 0.0,
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode
     del rng
+    del dropout_rate
     variables = {'params': params, **model_state}
     if update_batch_norm:
       logits, new_model_state = self._model.apply(
@@ -171,15 +173,8 @@ def _compute_metrics(
       'loss': summed_loss,
       'accuracy': accuracy,
     }
-    metrics = lax.psum(metrics, axis_name='batch')
     return metrics
 
-  @functools.partial(
-    jax.pmap,
-    axis_name='batch',
-    in_axes=(None, 0, 0, 0, None),
-    static_broadcasted_argnums=(0,),
-  )
   def _eval_model(
     self,
     params: spec.ParameterContainer,
@@ -188,21 +183,41 @@
     rng: spec.RandomState,
   ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
     """Return the mean accuracy and loss as a dict."""
-    logits, _ = self.model_fn(
-      params,
-      batch,
-      model_state,
-      spec.ForwardPassMode.EVAL,
-      rng,
-      update_batch_norm=False,
+
+    @functools.partial(
+      jax.jit,
+      in_shardings=(
+        jax_sharding_utils.get_replicate_sharding(),  # params
+        jax_sharding_utils.get_batch_dim_sharding(),  # batch
+        jax_sharding_utils.get_replicate_sharding(),  # model_state
+        jax_sharding_utils.get_batch_dim_sharding(),  # rng
+      ),
     )
-    weights = batch.get('weights')
-    if weights is None:
-      weights = jnp.ones(len(logits))
-    return self._compute_metrics(logits, batch['targets'], weights)
+    def _eval_model_jitted(
+      params: spec.ParameterContainer,
+      batch: Dict[str, spec.Tensor],
+      model_state: spec.ModelAuxiliaryState,
+      rng: spec.RandomState,
+    ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
+      """Return the mean accuracy and loss as a dict."""
+      logits, _ = self.model_fn(
+        params,
+        batch,
+        model_state,
+        spec.ForwardPassMode.EVAL,
+        rng,
+        update_batch_norm=False,
+      )
+      weights = batch.get('weights')
+      if weights is None:
+        weights = jnp.ones(len(logits))
+      return self._compute_metrics(logits, batch['targets'], weights)
+
+    metrics = _eval_model_jitted(params, batch, model_state, rng)
+    return jax.tree.map(lambda x: x.item(), metrics)
 
   def _normalize_eval_metrics(
     self, num_examples: int, total_metrics: Dict[str, Any]
   ) -> Dict[str, float]:
     """Normalize eval metrics."""
-    return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)
+    return jax.tree_map(lambda x: x / num_examples, total_metrics)
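The eval path above is one instance of the general `pmap`-to-`jit`-with-shardings migration. A self-contained sketch of that pattern with made-up parameters and metrics, assuming the batch dimension is divisible by the device count:

```python
import functools

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('batch',))
replicated = NamedSharding(mesh, P())
batch_sharded = NamedSharding(mesh, P('batch'))


@functools.partial(
    jax.jit,
    in_shardings=(replicated, batch_sharded),
    out_shardings=replicated,
)
def eval_metrics(params, batch):
  logits = batch['inputs'] @ params['w']
  # No lax.psum / axis_name: jit with shardings inserts the collectives.
  return {'loss': jnp.sum((logits - batch['targets']) ** 2)}


params = jax.device_put({'w': jnp.ones((8, 4))}, replicated)
batch = jax.device_put(
    {'inputs': jnp.ones((16, 8)), 'targets': jnp.zeros((16, 4))},
    batch_sharded,
)
print(jax.tree.map(lambda x: x.item(), eval_metrics(params, batch)))
```

With replicated outputs, calling `.item()` on each metric (as the workload now does) pulls a single host scalar instead of indexing into a per-device copy.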

algoperf/workloads/cifar/cifar_pytorch/workload.py

Lines changed: 3 additions & 9 deletions
@@ -118,16 +118,8 @@ def _build_dataset(
     dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP)
     return dataloader
 
-  def init_model_fn(
-    self,
-    rng: spec.RandomState,
-    dropout_rate: Optional[float] = None,
-    aux_dropout_rate: Optional[float] = None,
-  ) -> spec.ModelInitState:
+  def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     """Dropout is unused."""
-    del dropout_rate
-    del aux_dropout_rate
-
     if hasattr(self, '_model'):
       if isinstance(self._model, (DDP, torch.nn.DataParallel)):
         self._model.module.reset_parameters()
@@ -158,9 +150,11 @@ def model_fn(
     mode: spec.ForwardPassMode,
     rng: spec.RandomState,
     update_batch_norm: bool,
+    dropout_rate: float = 0.0,
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del model_state
     del rng
+    del dropout_rate
     model = params
     if mode == spec.ForwardPassMode.EVAL:
       if update_batch_norm:

algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py

Lines changed: 35 additions & 11 deletions
@@ -6,9 +6,8 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from flax import jax_utils
 
-from algoperf import param_utils, spec
+from algoperf import jax_sharding_utils, param_utils, spec
 from algoperf.workloads.criteo1tb.criteo1tb_jax import models
 from algoperf.workloads.criteo1tb.workload import BaseCriteo1TbDlrmSmallWorkload
 
@@ -106,7 +105,7 @@ def init_model_fn(
     initial_params = initial_variables['params']
     self._param_shapes = param_utils.jax_param_shapes(initial_params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    return jax_utils.replicate(initial_params), None
+    return jax_sharding_utils.replicate(initial_params), None
 
   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
     return param_key == 'Dense_7'
@@ -132,13 +131,40 @@ def model_fn(
     logits_batch = self._model.apply({'params': params}, inputs, **apply_kwargs)
     return logits_batch, None
 
+  def _build_input_queue(
+    self,
+    data_rng: spec.RandomState,
+    split: str,
+    data_dir: str,
+    global_batch_size: int,
+    cache: Optional[bool] = None,
+    repeat_final_dataset: Optional[bool] = None,
+    num_batches: Optional[int] = None,
+  ):
+    it = super()._build_input_queue(
+      data_rng,
+      split,
+      data_dir,
+      global_batch_size,
+      cache,
+      repeat_final_dataset,
+      num_batches,
+    )
+    f = functools.partial(
+      jax.device_put, device=jax_sharding_utils.get_batch_dim_sharding()
+    )
+    return map(f, it)
+
   @functools.partial(
-    jax.pmap,
-    axis_name='batch',
-    in_axes=(None, 0, 0),
-    static_broadcasted_argnums=(0,),
+    jax.jit,
+    in_shardings=(
+      jax_sharding_utils.get_replicate_sharding(),
+      jax_sharding_utils.get_batch_dim_sharding(),
+    ),
+    static_argnums=(0,),
+    out_shardings=jax_sharding_utils.get_replicate_sharding(),
   )
-  def _eval_batch_pmapped(
+  def _eval_batch_jitted(
     self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]
   ) -> spec.Tensor:
     logits, _ = self.model_fn(
@@ -162,9 +188,7 @@ def _eval_batch(
   ) -> spec.Tensor:
     # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of
     # shape (local_device_count,) will all be different values.
-    return np.array(
-      self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64
-    )
+    return np.array(self._eval_batch_jitted(params, batch), dtype=np.float64)
 
 
 class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
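The `_build_input_queue` override above uses a lazy device-placement pattern: `map` a partially applied `jax.device_put` with a batch-dim sharding over the host iterator, so each batch is placed on devices only when the training loop pulls it. A minimal sketch with a stand-in iterator (the field names and shapes are illustrative, and the batch size is assumed divisible by the device count):

```python
import functools

import jax
import jax.numpy as jnp

from algoperf import jax_sharding_utils


def host_batches():
  """Stand-in for the numpy iterator returned by the base _build_input_queue."""
  while True:
    yield {'inputs': jnp.ones((64, 13)), 'targets': jnp.zeros((64,))}


put = functools.partial(
    jax.device_put, device=jax_sharding_utils.get_batch_dim_sharding())
queue = map(put, host_batches())  # lazy: nothing is placed until next() is called

batch = next(queue)
jax_sharding_utils.display_shard_info(batch['inputs'])
```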

algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,7 @@
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from algoperf import param_utils, spec
+from algoperf import data_utils, param_utils, spec
 from algoperf.pytorch_utils import pytorch_setup
 from algoperf.workloads.criteo1tb.criteo1tb_pytorch import models
 from algoperf.workloads.criteo1tb.workload import BaseCriteo1TbDlrmSmallWorkload
@@ -152,6 +152,7 @@ def _build_input_queue(
        num_batches=num_batches,
        repeat_final_dataset=repeat_final_dataset,
      )
+     np_iter = map(data_utils.shard, np_iter)
      weights = None
      while True:
        if RANK == 0:
