Commit d1770c2

temporary fix: remove dropout from attention layers
1 parent e716ac9 commit d1770c2

2 files changed: 0 additions & 2 deletions

algoperf/workloads/librispeech_conformer/librispeech_jax/models.py

Lines changed: 0 additions & 1 deletion
@@ -428,7 +428,6 @@ def __call__(self, inputs, paddings, train, dropout_rate=DROPOUT_RATE):
         use_bias=True,
         broadcast_dropout=False,
         attention_fn=attention_fn,
-        dropout_rate=dropout_rate,
         deterministic=not train,
     )(inputs_q=inputs, mask=attention_mask)

algoperf/workloads/wmt/wmt_jax/models.py

Lines changed: 0 additions & 1 deletion
@@ -223,7 +223,6 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE):
         bias_init=cfg.bias_init,
         use_bias=False,
         broadcast_dropout=False,
-        dropout_rate=dropout_rate,
         deterministic=cfg.deterministic,
     )(cfg.attention_temp * x, x, mask=encoder_mask)
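
Both hunks drop the explicit dropout_rate argument from a Flax nn.MultiHeadDotProductAttention call. Below is a minimal sketch of the effect, assuming only Flax's public API: dropout_rate defaults to 0.0 in Flax, so omitting the argument disables dropout on the attention weights. The num_heads value and tensor shapes are illustrative, not the workloads' actual settings.

import jax
import jax.numpy as jnp
import flax.linen as nn

# Sketch: with dropout_rate omitted, Flax's default of 0.0 applies, so no
# dropout is performed on the attention weights.
attn = nn.MultiHeadDotProductAttention(
    num_heads=4,          # illustrative, not the workload's setting
    use_bias=True,
    broadcast_dropout=False,
    deterministic=True,   # stands in for `not train` in the first hunk
)

x = jnp.ones((2, 16, 32))  # (batch, seq_len, features), illustrative
variables = attn.init(jax.random.PRNGKey(0), inputs_q=x)
out = attn.apply(variables, inputs_q=x)  # keys/values default to inputs_q
print(out.shape)  # (2, 16, 32)

Reinstating the argument would re-enable attention dropout whenever deterministic=False, in which case apply would also need a dropout PRNG stream, e.g. rngs={'dropout': jax.random.PRNGKey(1)}.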
