We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9325826 commit dfce4caCopy full SHA for dfce4ca
algoperf/workloads/wmt/wmt_jax/models.py
@@ -308,7 +308,7 @@ def __call__(
308
bias_init=cfg.bias_init,
309
use_bias=False,
310
broadcast_dropout=False,
311
- dropout_rate=dropout_rate,
+ dropout_rate=0.0, # Dropout applied after attention
312
deterministic=cfg.deterministic,
313
)(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask)
314
0 commit comments