diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py
index fe3a1e179..90a12b779 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py
@@ -93,7 +93,7 @@ def __init__(self,
         out_features=self.encoder_dim,
         bias=True)
     self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim)
-    self.dropout = nn.Dropout(p=self.input_dropout_rate)
+    self.dropout = nn.Dropout(p=self.input_dropout_rate, inplace=True)
 
   def forward(self, inputs, input_paddings):
     output_paddings = input_paddings
@@ -195,7 +195,7 @@ def __init__(self, config: ConformerConfig):
         in_features=config.encoder_dim,
         out_features=config.encoder_dim * config.feed_forward_expansion_factor,
         bias=True)
-    self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate)
+    self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate, inplace=True)
     self.linear2 = nn.Linear(
         in_features=config.encoder_dim * config.feed_forward_expansion_factor,
         out_features=config.encoder_dim,
@@ -206,8 +206,9 @@ def __init__(self, config: ConformerConfig):
     else:
       feed_forward_residual_dropout_rate = (
           config.feed_forward_residual_dropout_rate)
-    self.dropout2 = nn.Dropout(p=feed_forward_residual_dropout_rate)
-
+    self.dropout2 = nn.Dropout(
+        p=feed_forward_residual_dropout_rate, inplace=True)
+
   def forward(self, inputs, padding_mask):
     inputs = self.ln(inputs)
     inputs = self.linear1(inputs)
@@ -316,7 +317,7 @@ def __init__(self, config: ConformerConfig):
       attention_residual_dropout_rate = 0.1
     else:
       attention_residual_dropout_rate = config.attention_residual_dropout_rate
-    self.dropout = nn.Dropout(p=attention_residual_dropout_rate)
+    self.dropout = nn.Dropout(p=attention_residual_dropout_rate, inplace=True)
 
   def forward(self, outputs, paddings):
     outputs = self.ln(outputs)
@@ -407,7 +408,7 @@ def __init__(self, config):
       conv_residual_dropout_rate = 0.0
     else:
       conv_residual_dropout_rate = config.conv_residual_dropout_rate
-    self.dropout = nn.Dropout(p=conv_residual_dropout_rate)
+    self.dropout = nn.Dropout(p=conv_residual_dropout_rate, inplace=True)
 
   def forward(self, inputs, input_paddings):
     inputs = self.ln(inputs)
diff --git a/submission_runner.py b/submission_runner.py
index ff290079b..7c8d7fb53 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -203,6 +203,7 @@ def train_once(
     log_dir: Optional[str] = None,
     save_checkpoints: Optional[bool] = True
 ) -> Tuple[spec.Timing, Dict[str, Any]]:
+  _reset_cuda_mem()
   data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)
 
   # Workload setup.
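
Note on the models.py hunks: adding inplace=True makes each nn.Dropout overwrite its input activation rather than allocate a fresh tensor, lowering peak activation memory in every Conformer block. A minimal sketch of the aliasing behavior (standard PyTorch semantics, not code from this patch):

    import torch
    import torch.nn as nn

    x = torch.ones(2, 3)
    drop = nn.Dropout(p=0.5, inplace=True)  # a fresh module defaults to training mode
    out = drop(x)
    assert out is x  # the input tensor itself was mutated; no new buffer allocated

In-place dropout is only safe when the pre-dropout activation is never read again (e.g. by a residual branch), a constraint the patch relies on at each of these call sites.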
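
The submission_runner.py hunk calls _reset_cuda_mem() at the top of train_once(), but the helper's definition lies outside this diff. A hypothetical sketch, assuming the helper simply clears cached CUDA memory and resets the peak-usage counters (the name comes from the patch; the body below is an assumption, not the repository's actual implementation):

    import gc

    import torch

    def _reset_cuda_mem():
      # Assumed behavior: give each training run a clean CUDA memory baseline.
      if torch.cuda.is_available():
        gc.collect()                          # drop dead Python references first
        torch.cuda.empty_cache()              # return cached blocks to the driver
        torch.cuda.reset_peak_memory_stats()  # zero the peak-usage counters

Read together with the in-place dropouts above, both changes appear aimed at reducing GPU memory pressure in this workload.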