@@ -7,12 +7,11 @@
 import jax.numpy as jnp
 import optax
 import tensorflow_datasets as tfds
-from flax import jax_utils
 from flax import linen as nn
 from flax.core import pop
 from jax import lax
 
-from algoperf import param_utils, spec
+from algoperf import jax_sharding_utils, param_utils, spec
 from algoperf.workloads.cifar.cifar_jax import models
 from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter
 from algoperf.workloads.cifar.workload import BaseCifarWorkload
@@ -29,6 +28,7 @@ def _build_cifar_dataset(
     repeat_final_dataset: Optional[bool] = None,
   ) -> Iterator[Dict[str, spec.Tensor]]:
     ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir)
+    ds_builder.download_and_prepare()
     train = split == 'train'
     assert self.num_train_examples + self.num_validation_examples == 50000
     if split in ['train', 'eval_train']:
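Note on the added call: `download_and_prepare()` is idempotent, so it downloads and builds the TFDS records only when they are not already present under `data_dir` and is a no-op otherwise. A minimal standalone sketch (the `data_dir` path here is made up for illustration):

```python
import tensorflow_datasets as tfds

# Hypothetical data_dir for illustration; the workload passes its own.
ds_builder = tfds.builder('cifar10:3.0.2', data_dir='/tmp/tfds')
ds_builder.download_and_prepare()  # no-op if the dataset is already prepared
train_ds = ds_builder.as_dataset(split='train')
```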
@@ -89,8 +89,8 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     model_state, params = pop(variables, 'params')
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    model_state = jax_utils.replicate(model_state)
-    params = jax_utils.replicate(params)
+    model_state = jax_sharding_utils.replicate(model_state)
+    params = jax_sharding_utils.replicate(params)
     return params, model_state
 
   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
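The body of `jax_sharding_utils.replicate` is not shown in this diff. Unlike `flax.jax_utils.replicate`, which stacked a leading per-device axis for `pmap`, a jit-style replicate presumably places a fully replicated copy of each array on every device without changing shapes. A minimal sketch of such a helper, assuming a 1D device mesh (the real algoperf implementation may differ):

```python
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def replicate(pytree):
  """Place a fully replicated copy of every array on all local devices."""
  mesh = Mesh(jax.devices(), axis_names=('batch',))
  # An empty PartitionSpec shards no dimension, i.e. full replication.
  return jax.device_put(pytree, NamedSharding(mesh, P()))
```

Because no leading device axis is added, the `pmap` decorators and per-device indexing elsewhere in the file can be dropped, as the later hunks do.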
@@ -105,9 +105,11 @@ def model_fn(
     rng: spec.RandomState,
     update_batch_norm: bool,
     use_running_average_bn: Optional[bool] = None,
+    dropout_rate: float = 0.0,
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode
     del rng
+    del dropout_rate
     variables = {'params': params, **model_state}
     if update_batch_norm:
       logits, new_model_state = self._model.apply(
@@ -171,15 +173,8 @@ def _compute_metrics(
       'loss': summed_loss,
       'accuracy': accuracy,
     }
-    metrics = lax.psum(metrics, axis_name='batch')
     return metrics
 
-  @functools.partial(
-    jax.pmap,
-    axis_name='batch',
-    in_axes=(None, 0, 0, 0, None),
-    static_broadcasted_argnums=(0,),
-  )
   def _eval_model(
     self,
     params: spec.ParameterContainer,
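Dropping the `lax.psum` works because the reduction model changes: under `pmap`, `_compute_metrics` saw only a per-device shard, so the summed metrics had to be combined explicitly across the `'batch'` axis; under `jit` with batch-sharded inputs, `jnp.sum` already yields the global value and XLA inserts the cross-device all-reduce itself. A self-contained illustration of that behavior (assuming the array length divides the local device count):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=('batch',))

@jax.jit
def total(per_example_loss):
  return jnp.sum(per_example_loss)  # global sum over the sharded axis

x = jax.device_put(jnp.ones(32), NamedSharding(mesh, P('batch')))
print(total(x))  # 32.0 on any device count that divides 32
```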
@@ -188,21 +183,41 @@ def _eval_model(
     rng: spec.RandomState,
   ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
     """Return the mean accuracy and loss as a dict."""
-    logits, _ = self.model_fn(
-      params,
-      batch,
-      model_state,
-      spec.ForwardPassMode.EVAL,
-      rng,
-      update_batch_norm=False,
+
+    @functools.partial(
+      jax.jit,
+      in_shardings=(
+        jax_sharding_utils.get_replicate_sharding(),  # params
+        jax_sharding_utils.get_batch_dim_sharding(),  # batch
+        jax_sharding_utils.get_replicate_sharding(),  # model_state
+        jax_sharding_utils.get_batch_dim_sharding(),  # rng
+      ),
     )
-    weights = batch.get('weights')
-    if weights is None:
-      weights = jnp.ones(len(logits))
-    return self._compute_metrics(logits, batch['targets'], weights)
+    def _eval_model_jitted(
+      params: spec.ParameterContainer,
+      batch: Dict[str, spec.Tensor],
+      model_state: spec.ModelAuxiliaryState,
+      rng: spec.RandomState,
+    ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
+      """Return the mean accuracy and loss as a dict."""
+      logits, _ = self.model_fn(
+        params,
+        batch,
+        model_state,
+        spec.ForwardPassMode.EVAL,
+        rng,
+        update_batch_norm=False,
+      )
+      weights = batch.get('weights')
+      if weights is None:
+        weights = jnp.ones(len(logits))
+      return self._compute_metrics(logits, batch['targets'], weights)
+
+    metrics = _eval_model_jitted(params, batch, model_state, rng)
+    return jax.tree.map(lambda x: x.item(), metrics)
 
   def _normalize_eval_metrics(
     self, num_examples: int, total_metrics: Dict[str, Any]
   ) -> Dict[str, float]:
     """Normalize eval metrics."""
-    return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)
+    return jax.tree.map(lambda x: x / num_examples, total_metrics)
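The `in_shardings` above rely on two further `jax_sharding_utils` getters whose implementations are outside this diff; `jax.jit` accepts `Sharding` objects there to pin the placement of each argument. A plausible minimal sketch, assuming the same 1D `'batch'` mesh as the `replicate` helper sketched earlier (the real helpers may differ):

```python
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def _mesh() -> Mesh:
  return Mesh(jax.devices(), axis_names=('batch',))

def get_replicate_sharding() -> NamedSharding:
  # Empty spec: the value is fully replicated on every device.
  return NamedSharding(_mesh(), P())

def get_batch_dim_sharding() -> NamedSharding:
  # Shard dimension 0 (the batch dimension) across the 'batch' mesh axis.
  return NamedSharding(_mesh(), P('batch'))
```

With the jitted eval returning unreplicated scalars, `.item()` converts each metric to a plain Python number, which is why `_normalize_eval_metrics` no longer indexes a per-device axis (`x[0]`) before dividing.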