Commit 0643268

Merge pull request #884 from Niccolo-Ajroldi/dropout_subs_fix
Dropout Fix - Final Polishing: updated submissions, tests, docs.
2 parents bfd72bb + 816c167 commit 0643268

64 files changed (+153 / -53 lines)


algoperf/spec.py

Lines changed: 2 additions & 9 deletions
@@ -235,16 +235,9 @@ def model_params_types(self):
   def is_output_params(self, param_key: ParameterKey) -> bool:
     """Whether a key in ParameterContainer is the output layer parameters."""

-  # InitModelFn = Callable[
-  #     Tuple[RandomState, Optional[float], Optional[float]],
-  #     ParameterContainer]
+  # InitModelFn = Callable[Optional[float]], ParameterContainer]
   @abc.abstractmethod
-  def init_model_fn(
-      self,
-      rng: RandomState,
-      dropout_rate: Optional[float] = None,
-      aux_dropout_rate: Optional[float] = None,
-  ) -> ModelInitState:
+  def init_model_fn(self, rng: RandomState) -> ModelInitState:
     """Return (initial_params, initial_model_state)."""

   # ModelFn = Callable[
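This spec change moves dropout out of model construction: `init_model_fn` now takes only an `rng`, and the per-step rate is passed to `model_fn` instead (see the workload diffs below). A minimal usage sketch from a submission's point of view; the `workload`, `batch`, and `rng` objects here are placeholders, not code from this commit:

```python
# Hedged sketch of the new calling convention; not taken from the commit.
params, model_state = workload.init_model_fn(rng)  # no dropout arguments anymore

logits, new_model_state = workload.model_fn(
    params,
    batch,
    model_state,
    spec.ForwardPassMode.TRAIN,
    rng,
    update_batch_norm=True,
    dropout_rate=0.1,  # tuned per call; omit it to use the workload default
)
```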

algoperf/workloads/cifar/cifar_jax/workload.py

Lines changed: 2 additions & 0 deletions
@@ -105,9 +105,11 @@ def model_fn(
       rng: spec.RandomState,
       update_batch_norm: bool,
       use_running_average_bn: Optional[bool] = None,
+      dropout_rate: float = 0.0,
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode
     del rng
+    del dropout_rate
     variables = {'params': params, **model_state}
     if update_batch_norm:
       logits, new_model_state = self._model.apply(

algoperf/workloads/cifar/cifar_pytorch/workload.py

Lines changed: 3 additions & 9 deletions
@@ -118,16 +118,8 @@ def _build_dataset(
     dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP)
     return dataloader

-  def init_model_fn(
-      self,
-      rng: spec.RandomState,
-      dropout_rate: Optional[float] = None,
-      aux_dropout_rate: Optional[float] = None,
-  ) -> spec.ModelInitState:
+  def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     """Dropout is unused."""
-    del dropout_rate
-    del aux_dropout_rate
-
     if hasattr(self, '_model'):
       if isinstance(self._model, (DDP, torch.nn.DataParallel)):
         self._model.module.reset_parameters()

@@ -158,9 +150,11 @@ def model_fn(
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool,
+      dropout_rate: float = 0.0,
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del model_state
     del rng
+    del dropout_rate
     model = params
     if mode == spec.ForwardPassMode.EVAL:
       if update_batch_norm:
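Workloads without dropout, like CIFAR above, accept the new keyword and immediately `del` it. A PyTorch workload that does use it would thread the rate into the forward pass, for example via functional dropout. A generic sketch (hypothetical helper, not the CIFAR model):

```python
import torch
import torch.nn.functional as F


def forward_with_dropout(model: torch.nn.Module,
                         inputs: torch.Tensor,
                         dropout_rate: float,
                         training: bool) -> torch.Tensor:
  """Apply a per-call dropout rate on top of a model's output features."""
  features = model(inputs)
  return F.dropout(features, p=dropout_rate, training=training)
```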

algoperf/workloads/fastmri/fastmri_jax/workload.py

Lines changed: 0 additions & 1 deletion
@@ -19,7 +19,6 @@ def init_model_fn(
       self,
       rng: spec.RandomState,
   ) -> spec.ModelInitState:
-    """aux_dropout_rate is unused."""
     fake_batch = jnp.zeros((13, 320, 320))
     self._model = UNet(
         num_pool_layers=self.num_pool_layers,

algoperf/workloads/imagenet_vit/imagenet_jax/models.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ class ViT(nn.Module):
   mlp_dim: Optional[int] = None  # Defaults to 4x input dim.
   num_heads: int = 12
   rep_size: Union[int, bool] = True
-  dropout_rate: [float] = DROPOUT_RATE
+  dropout_rate: float = DROPOUT_RATE
   reinit: Optional[Sequence[str]] = None
   head_zeroinit: bool = True
   use_glu: bool = False
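The one-character fix above matters because `[float]` is a list literal, not a type: the dataclass machinery behind `nn.Module` tolerates it at runtime, but the annotation is meaningless and type checkers flag it. A small illustration with plain dataclasses (not the ViT code itself):

```python
from dataclasses import dataclass


@dataclass
class Good:
  dropout_rate: float = 0.0  # proper type annotation


@dataclass
class Confusing:
  dropout_rate: [float] = 0.0  # runs, but the "type" is the list [float]


print(Good.__annotations__)       # {'dropout_rate': <class 'float'>}
print(Confusing.__annotations__)  # {'dropout_rate': [<class 'float'>]}
```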

algoperf/workloads/mnist/mnist_jax/workload.py

Lines changed: 2 additions & 0 deletions
@@ -48,10 +48,12 @@ def model_fn(
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool,
+      dropout_rate: float = 0.0,
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del model_state
     del rng
     del update_batch_norm
+    del dropout_rate
     train = mode == spec.ForwardPassMode.TRAIN
     logits_batch = self._model.apply(
         {'params': params},

algoperf/workloads/mnist/mnist_pytorch/workload.py

Lines changed: 3 additions & 10 deletions
@@ -138,16 +138,7 @@ def shard(batch):
         }
         yield batch

-  def init_model_fn(
-      self,
-      rng: spec.RandomState,
-      dropout_rate: Optional[float] = None,
-      aux_dropout_rate: Optional[float] = None,
-  ) -> spec.ModelInitState:
-    """Dropout is unused."""
-    del dropout_rate
-    del aux_dropout_rate
-
+  def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     if hasattr(self, '_model'):
       if isinstance(self._model, (DDP, torch.nn.DataParallel)):
         self._model.module.reset_parameters()

@@ -178,10 +169,12 @@ def model_fn(
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool,
+      dropout_rate: float = 0.0,
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del model_state
     del rng
     del update_batch_norm
+    del dropout_rate
     model = params
     if mode == spec.ForwardPassMode.EVAL:
       model.eval()

algoperf/workloads/wmt/wmt_jax/workload.py

Lines changed: 2 additions & 9 deletions
@@ -239,14 +239,7 @@ def translate_and_calculate_bleu(
     bleu_score = bleu.corpus_bleu(predictions, [references]).score
     return bleu_score

-  def init_model_fn(
-      self,
-      rng: spec.RandomState,
-      dropout_rate: Optional[float] = None,
-      aux_dropout_rate: Optional[float] = None,
-  ) -> spec.ModelInitState:
-    """aux_dropout_rate is used as attention_dropout_rate."""
-
+  def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     init_fake_batch_size = 8
     input_shape = (init_fake_batch_size, 256)
     target_shape = (init_fake_batch_size, 256)

@@ -295,7 +288,7 @@ def model_fn(
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool,
-      dropout_rate: [float] = models.DROPOUT_RATE,
+      dropout_rate: float = models.DROPOUT_RATE,
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del model_state
     del update_batch_norm
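For JAX workloads like WMT, a per-call dropout rate is ultimately consumed by `flax.linen.Dropout`, which also needs a 'dropout' PRNG stream at apply time. A minimal, generic Flax sketch of that pattern (a toy module, not the WMT transformer):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyBlock(nn.Module):
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, train: bool = False):
    x = nn.Dense(8)(x)
    # Dropout is active only in training and draws from the 'dropout' RNG stream.
    x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
    return x


model = TinyBlock(dropout_rate=0.1)
x = jnp.ones((2, 4))
variables = model.init(jax.random.PRNGKey(0), x)
out = model.apply(
    variables, x, train=True, rngs={'dropout': jax.random.PRNGKey(1)})
```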

docs/DOCUMENTATION.md

Lines changed: 4 additions & 7 deletions
@@ -104,12 +104,7 @@ def _build_input_queue(
 ###### Model initialization

 ```python
-def init_model_fn(
-    self,
-    rng: RandomState,
-    dropout_rate: Optional[float] = None,
-    aux_dropout_rate: Optional[float] = None
-) -> initial model parameters
+def init_model_fn(self, rng: RandomState) -> initial model parameters
 ```

 - Unlike in the *Model Track*, this function that initializes the parameters of the model, is fixed. While it can be called by the submission (e.g. to restart the model after a failed training effort) it cannot be changed.

@@ -125,7 +120,8 @@ def model_fn(
   mode: ForwardPassMode,  # mode \in {train, eval}
   rng: RandomState,
   hyperparameters: Hyperparameters,
-  update_batch_norm: bool
+  update_batch_norm: bool,
+  dropout_rate: float
 ) -> (logits_output_batch, new_model_state): Tuple[Tensor, ModelAuxiliaryState]
 ```

@@ -134,6 +130,7 @@ def model_fn(
 - `logits_output_batch` is before the output activation
 - `new_model_state` is for batch norm or similar side effects and will only be updated if `update_batch_norm` is set
 - `hyperparameters` will contain only dropout rates, which will be used in the models that support it. These can be tuned or will default to documented model-specific values. Note that adding additional dropout would be considered changing the model, which is not allowed, but the tuning of dropout in existing dropout layers can be considered a regularizer, so we allow it. There should be at most two dropout rates in a model (if there are more than two we will reuse the same values).
+- `dropout_rate` is used in the model forward pass. If not provided, the workload's default value is used (see below for the list of defaults).

 ###### Loss function
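The new documentation bullet says the workload default applies when no rate is passed. A sketch of that default-handling convention as a plain Python function (illustrative names, not the actual API):

```python
DROPOUT_RATE = 0.1  # assumed workload-specific default


def model_fn(params, batch, dropout_rate: float = DROPOUT_RATE):
  """Toy stand-in: falls back to the workload default when no rate is given."""
  del params, batch  # no real computation in this sketch
  return dropout_rate


assert model_fn(None, None) == DROPOUT_RATE            # default applied
assert model_fn(None, None, dropout_rate=0.2) == 0.2   # tuned value wins
```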

reference_algorithms/paper_baselines/adafactor/jax/submission.py

Lines changed: 4 additions & 0 deletions
@@ -78,6 +78,7 @@ def pmapped_train_step(
     rng,
     grad_clip,
     label_smoothing,
+    dropout_rate,
 ):
   def _loss_fn(params):
     """Loss function used for training."""

@@ -94,6 +95,7 @@ def _loss_fn(params):
         logits_batch=logits,
         mask_batch=batch.get('weights'),
         label_smoothing=label_smoothing,
+        dropout_rate=dropout_rate,
     )
     summed_loss = loss_dict['summed']
     n_valid_examples = loss_dict['n_valid_examples']

@@ -156,6 +158,7 @@ def update_params(
     grad_clip = hyperparameters.grad_clip
   else:
     grad_clip = None
+  dropout_rate = hyperparameters.dropout_rate
   outputs = pmapped_train_step(
       workload,
       opt_update_fn,

@@ -166,6 +169,7 @@ def update_params(
       per_device_rngs,
       grad_clip,
       label_smoothing,
+      dropout_rate,
   )
   new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
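The baseline reads `dropout_rate` from the hyperparameters in `update_params` and passes it as one extra argument through the pmapped step. As background, a scalar like this can be broadcast to all devices with `in_axes=None` rather than sharded; a standalone toy sketch, not the baseline's actual pmap configuration:

```python
import jax
import jax.numpy as jnp


def _step(x, dropout_rate):
  # Toy stand-in for a per-device train step that consumes a scalar rate.
  return x * (1.0 - dropout_rate)


# in_axes=None broadcasts the Python scalar to every device unchanged.
pmapped_step = jax.pmap(_step, in_axes=(0, None))

x = jnp.ones((jax.local_device_count(), 4))
out = pmapped_step(x, 0.1)
print(out.shape)  # (num_devices, 4)
```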
