Skip to content

Commit dfce4ca

Browse files
committed
fix for wmt dropout
1 parent 9325826 commit dfce4ca

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

algoperf/workloads/wmt/wmt_jax/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def __call__(
308308
bias_init=cfg.bias_init,
309309
use_bias=False,
310310
broadcast_dropout=False,
311-
dropout_rate=dropout_rate,
311+
dropout_rate=0.0, # Dropout applied after attention
312312
deterministic=cfg.deterministic,
313313
)(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask)
314314

0 commit comments

Comments
 (0)