
Commit ea66793

Merge pull request #789 from Niccolo-Ajroldi/prepare_for_eval
Introduce prepare for eval, fix evaluation bug
2 parents: 9d37d3e + 364ce41

33 files changed (+812, -93 lines)

DOCUMENTATION.md

Lines changed: 30 additions & 3 deletions
@@ -80,7 +80,7 @@ In principle, submissions are allowed to use the available hardware systems in a

Submissions provide a [per-workload batch size](#batch-size-getter) to use. Specification of the batch size for each workload is necessary to avoid running out of memory for different workloads. Therefore, submitters can determine this batch size in advance and specify it as part of the submission. Submitters may also provide per-workload batch sizes for all [randomized workloads](#randomized-workloads). If no such batch size is provided for a randomized workload, by default, submissions will then use the batch size of the most similar [fixed workload](#fixed-workloads) (for example, if there is an ImageNet fixed workload and also a randomized workload with a similarly sized model on similarly sized images, the ImageNet batch size will be used for held-out workloads generated from this randomized workload).

Note that submitters are *not* allowed to modify the *evaluation batch size*, which is set by the benchmarking codebase. However, you can file an issue if you believe that the evaluation batch size of a particular workload is set inappropriately. The working group will review this request and consider adjusting the evaluation batch size in the benchmarking codebase, thus affecting all submitters equally.

Removed:

The **submission functions** are the *batch size getter*, *optimizer state initializer*, *variable update*, and *data selection functions*. The *fixed functions* are the *data augmentation/preprocessing*, *model initialization*, *forward pass*, and *loss function*. The trained model will be evaluated in a separate step that does not call any of the submitted code.

Added:

The **submission functions** are the *batch size getter*, *optimizer state initializer*, *variable update*, *prepare for evaluation function*, and *data selection functions*. The *fixed functions* are the *data augmentation/preprocessing*, *model initialization*, *forward pass*, and *loss function*. The trained model will be evaluated in a separate step that does not call any of the submitted code.

@@ -220,9 +220,35 @@ def update_params(

- Cannot modify the given hyperparameters in a workload-conditional way (please see the [Valid submission](#valid-submissions) section). This rule is intended to prohibit circumventing the tuning rules by looking up a pre-tuned optimal set of hyperparameters for each workload. It is not intended to prohibit line searches and other similar techniques.
- The fixed `init_model_fn` can optionally be called during training, for example, to reinitialize the model after a failed training effort.
- Cannot replace the model parameters with pre-trained ones.
- Batch norm should work here because the `model_fn` will return updated batch norm moving averages when it is told to with `update_batch_norm`.

Removed (this bullet now applies to `prepare_for_eval`, below):

- This API supports Polyak averaging and similar methods that implement moving averages of model parameters.

Added:

###### Prepare for evaluation function

```python
def prepare_for_eval(
    workload: Workload,
    current_param_container: ParameterContainer,
    current_params_types: ParameterTypeTree,
    model_state: ModelAuxiliaryState,
    hyperparameters: Hyperparameters,
    loss_type: LossType,
    optimizer_state: OptimizerState,
    eval_results: List[Tuple[int, float]],
    global_step: int,
    rng: RandomState
) -> (updated_optimizer_state, updated_variables, updated_model_state)
```

- Arguments are the same as for `update_params`, except that `batch` is not passed.
- This function is called when a submission is deemed eligible for an evaluation (see the [Evaluation during training](#evaluation-during-training) section).
- The call to `prepare_for_eval` is timed, and its runtime counts toward the overall submission time.
- The returned model parameters are evaluated on the validation and test sets, provided that the accumulated submission time does not exceed the maximum runtime after this call.
- This API supports Polyak averaging and similar methods that implement moving averages of model parameters.
- Allowed to update the model parameters and the model state.
- Allowed to update the optimizer state.
- Cannot replace the model parameters with pre-trained ones.
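
To illustrate the Polyak-averaging use case listed above, here is a minimal sketch of a submission-side `prepare_for_eval` that swaps in an exponential moving average of the parameters for evaluation. It assumes the submission's own `update_params` has been storing that average in a dict-like `optimizer_state` under a hypothetical `'ema_params'` key; this key is not part of the spec, it is just one way a submission might use the hook.

```python
from typing import List, Tuple

from algorithmic_efficiency import spec


def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Swap in EMA parameters for evaluation (hypothetical Polyak-averaging sketch)."""
  del workload, current_params_types, hyperparameters, loss_type
  del eval_results, global_step, rng
  # Assumes this submission's `update_params` keeps a dict-like optimizer state
  # holding an exponential moving average of the parameters under 'ema_params'.
  eval_params = optimizer_state.get('ema_params', current_param_container)
  return optimizer_state, eval_params, model_state
```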
@@ -252,7 +278,8 @@ def data_selection(

In general, with noisy, non-deterministic training, evaluation frequency can affect training time measurements as more "bites of the apple" potentially allows the training code to exploit instability. We also want to discourage submissions from complicated and unrealistic logic that attempts to guess when training is close to complete and increases the evaluation rate, while not producing a well-sampled training curve at the start of training. Simply allowing submissions complete freedom over evaluation frequency encourages competitors to work to minimize the number of evaluations, which distracts from the primary goal of finding better training algorithms.

Removed:

Submissions are eligible for an untimed eval every `eval_period` seconds, run as soon as the current call of `update_params` completes. Any additional evaluations performed by the submission code count against the runtime for scoring. The harness that runs the submission code will attempt to eval every `eval_period` seconds by checking between each submission step (call of `update_params`) whether it has been at least `eval_period` seconds since that last eval and, if so, pausing the clock and running an eval. This means that if calls to `update_params` typically take a lot more than `eval_period` seconds, such submissions will not receive as many untimed evals as a submission that had an `update_params` function that took less time. However, for appropriate settings of `eval_period`, we expect this to be quite rare. Submissions are always free to restructure their `update_params` code to split work into two subsequent steps to regain the potential benefits of these untimed model evaluations. For each workload, the `eval_period` will be set such that the total evaluation time is roughly between 10% and 20% of the total training time for the target-setting runs.

Added:

Submissions are eligible for an untimed eval every `eval_period` seconds. Before proceeding to evaluation, the submission can prepare the model through a call to `prepare_for_eval`, effectively modifying the model parameters and state as well as the optimizer state. Any additional evaluations performed by the submission code count against the runtime for scoring.
The harness that runs the submission code will attempt to eval every `eval_period` seconds by checking between each submission step (call of `update_params`) whether it has been at least `eval_period` seconds since the last eval; if so, the submission is given the opportunity to prepare for evaluation through a timed call to `prepare_for_eval`. If the accumulated runtime does not exceed the maximum allowed runtime after this preparation step, the clock is paused and the submission is evaluated. This means that if calls to `update_params` typically take a lot more than `eval_period` seconds, such submissions will not receive as many untimed evals as a submission that had an `update_params` function that took less time. However, for appropriate settings of `eval_period`, we expect this to be quite rare. Submissions are always free to restructure their `update_params` code to split work into two subsequent steps to regain the potential benefits of these untimed model evaluations. For each workload, the `eval_period` will be set such that the total evaluation time is roughly between 10% and 20% of the total training time for the target-setting runs.
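
To make the timing rules above concrete, here is a simplified sketch of the harness behavior described in the added paragraph. The loop below is illustrative only: `run_step`, `run_prepare_for_eval`, `run_eval`, `eval_period`, and `max_runtime` are placeholder names, not the benchmark's actual `submission_runner` API.

```python
import time


def training_loop(run_step, run_prepare_for_eval, run_eval,
                  eval_period: float, max_runtime: float):
  """Illustrative timing loop: timed steps, timed prepare_for_eval, untimed evals."""
  accumulated_runtime = 0.0
  time_since_last_eval = 0.0
  while accumulated_runtime < max_runtime:
    # Timed: one submission step (a call to `update_params`).
    start = time.monotonic()
    run_step()
    elapsed = time.monotonic() - start
    accumulated_runtime += elapsed
    time_since_last_eval += elapsed

    # Between steps, check whether the submission is eligible for an eval.
    if time_since_last_eval >= eval_period:
      # Timed: the submission prepares parameters/optimizer state for eval.
      start = time.monotonic()
      run_prepare_for_eval()
      accumulated_runtime += time.monotonic() - start
      # Evaluate only if the preparation step did not exhaust the budget;
      # the eval itself is untimed (the clock is paused here).
      if accumulated_runtime < max_runtime:
        run_eval()
      time_since_last_eval = 0.0
```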

algorithmic_efficiency/spec.py

Lines changed: 30 additions & 0 deletions
@@ -431,6 +431,36 @@ def update_params(workload: Workload,

```python
PrepareForEvalFn = Callable[[
    Workload,
    ParameterContainer,
    ParameterTypeTree,
    ModelAuxiliaryState,
    Hyperparameters,
    LossType,
    OptimizerState,
    List[Tuple[int, float]],
    int,
    RandomState
],
                            UpdateReturn]


# Prepare model and optimizer for evaluation.
def prepare_for_eval(workload: Workload,
                     current_param_container: ParameterContainer,
                     current_params_types: ParameterTypeTree,
                     model_state: ModelAuxiliaryState,
                     hyperparameters: Hyperparameters,
                     loss_type: LossType,
                     optimizer_state: OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: RandomState) -> UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  pass
```

prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py

Lines changed: 21 additions & 0 deletions
@@ -302,6 +302,27 @@ def update_params(

```python
def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)
```

prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py

Lines changed: 21 additions & 0 deletions
@@ -302,6 +302,27 @@ def update_params(

```python
def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)
```

prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py

Lines changed: 21 additions & 0 deletions
@@ -304,6 +304,27 @@ def update_params(

```python
def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)
```

prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py

Lines changed: 21 additions & 0 deletions
@@ -304,6 +304,27 @@ def update_params(

```python
def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)
```

prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py

Lines changed: 21 additions & 0 deletions
@@ -317,6 +317,27 @@ def update_params(

```python
def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)
```

prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py

Lines changed: 21 additions & 0 deletions
@@ -317,6 +317,27 @@ def update_params(

```python
def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)
```

prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py

Lines changed: 21 additions & 0 deletions
@@ -319,6 +319,27 @@ def update_params(

```python
def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)
```

prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py

Lines changed: 21 additions & 0 deletions
@@ -319,6 +319,27 @@ def update_params(

```python
def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)
```
