
Commit 816c167

formatting
1 parent 56ab820 commit 816c167

22 files changed: +23 additions, -26 deletions

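Apart from a deleted blank line in the WMT JAX workload, every hunk below makes the same trailing-comma fix: the final dropout_rate argument of a multi-line call inside update_params now ends with a comma. A minimal before/after sketch of the style, assuming a workload.model_fn call and positional arguments (params, batch, model_state) that are not shown in full in the hunks:

# Sketch only: the call target and positional arguments are assumed for
# illustration; the keyword arguments mirror the PyTorch hunks in this commit.

# Before: the last argument has no trailing comma.
logits, new_model_state = workload.model_fn(
    params,
    batch,
    model_state,
    mode=spec.ForwardPassMode.TRAIN,
    rng=rng,
    update_batch_norm=True,
    dropout_rate=hyperparameters.dropout_rate
)

# After: a trailing comma is added, so appending another argument later
# changes only one line and keeps auto-formatters from re-wrapping the call.
logits, new_model_state = workload.model_fn(
    params,
    batch,
    model_state,
    mode=spec.ForwardPassMode.TRAIN,
    rng=rng,
    update_batch_norm=True,
    dropout_rate=hyperparameters.dropout_rate,
)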

algoperf/workloads/wmt/wmt_jax/workload.py

Lines changed: 0 additions & 1 deletion
@@ -240,7 +240,6 @@ def translate_and_calculate_bleu(
     return bleu_score
 
   def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
-
     init_fake_batch_size = 8
     input_shape = (init_fake_batch_size, 256)
     target_shape = (init_fake_batch_size, 256)

reference_algorithms/paper_baselines/adafactor/jax/submission.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ def update_params(
       per_device_rngs,
       grad_clip,
       label_smoothing,
-      dropout_rate
+      dropout_rate,
   )
   new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs

reference_algorithms/paper_baselines/adafactor/pytorch/submission.py

Lines changed: 1 addition & 1 deletion
@@ -227,7 +227,7 @@ def update_params(
       mode=spec.ForwardPassMode.TRAIN,
       rng=rng,
       update_batch_norm=True,
-      dropout_rate=hyperparameters.dropout_rate
+      dropout_rate=hyperparameters.dropout_rate,
   )
 
   label_smoothing = (

reference_algorithms/paper_baselines/adamw/jax/submission.py

Lines changed: 1 addition & 1 deletion
@@ -189,7 +189,7 @@ def update_params(
         rng,
         grad_clip,
         label_smoothing,
-        dropout_rate
+        dropout_rate,
       )
   )

reference_algorithms/paper_baselines/adamw/pytorch/submission.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def update_params(
       mode=spec.ForwardPassMode.TRAIN,
       rng=rng,
       update_batch_norm=True,
-      dropout_rate=hyperparameters.dropout_rate
+      dropout_rate=hyperparameters.dropout_rate,
   )
 
   label_smoothing = (

reference_algorithms/paper_baselines/lamb/jax/submission.py

Lines changed: 1 addition & 1 deletion
@@ -176,7 +176,7 @@ def update_params(
       per_device_rngs,
       grad_clip,
       label_smoothing,
-      dropout_rate
+      dropout_rate,
   )
   new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs

reference_algorithms/paper_baselines/lamb/pytorch/submission.py

Lines changed: 1 addition & 1 deletion
@@ -225,7 +225,7 @@ def update_params(
       mode=spec.ForwardPassMode.TRAIN,
       rng=rng,
       update_batch_norm=True,
-      dropout_rate=hyperparameters.dropout_rate
+      dropout_rate=hyperparameters.dropout_rate,
   )
 
   label_smoothing = (

reference_algorithms/paper_baselines/momentum/jax/submission.py

Lines changed: 1 addition & 1 deletion
@@ -227,7 +227,7 @@ def update_params(
       rng,
       grad_clip,
       label_smoothing,
-      dropout_rate
+      dropout_rate,
   )
   new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs

reference_algorithms/paper_baselines/momentum/pytorch/submission.py

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ def update_params(
       mode=spec.ForwardPassMode.TRAIN,
       rng=rng,
       update_batch_norm=True,
-      dropout_rate=hyperparameters.dropout_rate
+      dropout_rate=hyperparameters.dropout_rate,
   )
 
   label_smoothing = (

reference_algorithms/paper_baselines/nadamw/jax/submission.py

Lines changed: 1 addition & 1 deletion
@@ -340,7 +340,7 @@ def update_params(
         rng,
         grad_clip,
         label_smoothing,
-        dropout_rate
+        dropout_rate,
       )
   )
