Skip to content

Commit 01cfbb0

Browse files
committed
make attention dropout static
1 parent d1770c2 commit 01cfbb0

File tree

2 files changed

+2
-0
lines changed

2 files changed

+2
-0
lines changed

algoperf/workloads/librispeech_conformer/librispeech_jax/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ def __call__(self, inputs, paddings, train, dropout_rate=DROPOUT_RATE):
428428
use_bias=True,
429429
broadcast_dropout=False,
430430
attention_fn=attention_fn,
431+
dropout_rate=0.0,
431432
deterministic=not train,
432433
)(inputs_q=inputs, mask=attention_mask)
433434

algoperf/workloads/wmt/wmt_jax/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE):
223223
bias_init=cfg.bias_init,
224224
use_bias=False,
225225
broadcast_dropout=False,
226+
dropout_rate=0.0,
226227
deterministic=cfg.deterministic,
227228
)(cfg.attention_temp * x, x, mask=encoder_mask)
228229

0 commit comments

Comments
 (0)