Skip to content

Commit cdb485e

Browse files
committed
Add default values for args to fix Attribute Error
Signed-off-by: Abhishree <[email protected]>
1 parent ca2c900 commit cdb485e

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

nemo/core/classes/modelPT.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -849,8 +849,11 @@ def test_dataloader(self):
849849
if self._test_dl is not None:
850850
return self._test_dl
851851

852+
#TODO: Confirm if outputs default vals are required
852853
def on_validation_epoch_end(
853-
self, outputs: Union[List[Dict[str, torch.Tensor]], List[List[Dict[str, torch.Tensor]]]]
854+
self, outputs: Union[List[Dict[str, torch.Tensor]], List[List[Dict[str, torch.Tensor]]]] = [{'val_loss':torch.tensor(0.), 'val_wer_num':torch.tensor(0.), 'val_wer_denom':torch.tensor(0.),
855+
'val_final_loss': torch.tensor(0.), 'val_inter_ctc_loss_l2': torch.tensor(0.), 'val_inter_ctc_loss_l4': torch.tensor(0.), 'val_inter_wer_l2': torch.tensor(0.)}]
856+
#[[torch.tensor(0)], [torch.tensor(0)]]
854857
) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
855858
"""
856859
Default DataLoader for Validation set which automatically supports multiple data loaders

nemo/utils/exp_manager.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,10 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
197197
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
198198
self._on_batch_end("train_step_timing", pl_module)
199199

200-
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
200+
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
201201
self._on_batch_start("validation_step_timing")
202202

203-
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
203+
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
204204
self._on_batch_end("validation_step_timing", pl_module)
205205

206206
def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
@@ -453,7 +453,6 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
453453
if cfg.disable_validation_on_resume:
454454
# extend training loop to skip initial validation when resuming from checkpoint
455455
configure_no_restart_validation_training_loop(trainer)
456-
457456
# Setup a stateless timer for use on clusters.
458457
if cfg.max_time_per_run is not None:
459458
found_ptl_timer = False
@@ -937,8 +936,8 @@ def configure_no_restart_validation_training_loop(trainer: pytorch_lightning.Tra
937936
if type(trainer.fit_loop.epoch_loop) != _TrainingEpochLoop:
938937
warnings.warn("Detected custom epoch loop. Skipping no validation on restart support.", UserWarning)
939938
return
940-
loop = SkipResumeTrainingValidationLoop(trainer.min_steps, trainer.max_steps)
941-
loop.trainer = trainer
939+
## Pass trainer object to avoid trainer getting overwritten as None
940+
loop = SkipResumeTrainingValidationLoop(trainer, trainer.min_steps, trainer.max_steps)
942941
trainer.fit_loop.epoch_loop = loop
943942

944943

tests/collections/common/test_ema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import os.path
16-
from typing import Any, Dict, Union
16+
from typing import Any, Dict, Union, Optional
1717

1818
import pytest
1919
import pytorch_lightning as pl
@@ -105,7 +105,7 @@ def setup_validation_data(self, val_data_config: Union[DictConfig, Dict]):
105105
def setup_test_data(self, val_data_config: Union[DictConfig, Dict]):
106106
pass
107107

108-
def on_validation_epoch_end(self, loss):
108+
def on_validation_epoch_end(self, loss: torch.tensor = [torch.tensor([0.0])]):
109109
self.log("val_loss", torch.stack(loss).mean())
110110

111111

0 commit comments

Comments
 (0)