Skip to content

Commit 56dd098

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 03fb5ce commit 56dd098

File tree

4 files changed

+6
-13
lines changed

4 files changed

+6
-13
lines changed

examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@
4242
from typing import Any, Optional
4343

4444
import torch
45+
from lightning_fabric.utilities.cloud_io import _load as pl_load
4546
from megatron.core import parallel_state
4647
from pytorch_lightning.core.saving import _load_state as ptl_load_state
4748
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml
4849
from pytorch_lightning.trainer.trainer import Trainer
49-
from lightning_fabric.utilities.cloud_io import _load as pl_load
5050
from pytorch_lightning.utilities.migration import pl_legacy_patch
5151

5252
from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel

nemo/collections/nlp/models/nlp_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
import os
1919
from typing import Any, Optional
2020

21+
from lightning_fabric.utilities.cloud_io import _load as pl_load
2122
from omegaconf import DictConfig, OmegaConf
2223
from pytorch_lightning import Trainer
2324
from pytorch_lightning.core.saving import _load_state as ptl_load_state
2425
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml
2526
from pytorch_lightning.utilities import rank_zero_only
26-
from lightning_fabric.utilities.cloud_io import _load as pl_load
2727
from pytorch_lightning.utilities.migration import pl_legacy_patch
2828
from transformers import TRANSFORMERS_CACHE
2929

nemo/collections/nlp/parts/nlp_overrides.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@
2424
import pytorch_lightning as pl
2525
import torch
2626
from omegaconf import OmegaConf
27+
from pytorch_lightning.loops.fetchers import _DataFetcher
2728
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
2829
from pytorch_lightning.plugins import ClusterEnvironment
2930
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
3031
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
3132
from pytorch_lightning.strategies.ddp import DDPStrategy
3233
from pytorch_lightning.trainer.trainer import Trainer
3334
from pytorch_lightning.utilities.exceptions import MisconfigurationException
34-
from pytorch_lightning.loops.fetchers import _DataFetcher
3535
from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook
3636
from torch.nn.parallel import DistributedDataParallel
3737

tests/core/test_exp_manager.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -329,17 +329,13 @@ def test_resume(self, tmp_path):
329329
{"resume_if_exists": True, "explicit_log_dir": str(tmp_path / "test_resume" / "default" / "version_0")},
330330
)
331331
checkpoint = Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last.ckpt")
332-
assert (
333-
Path(test_trainer._checkpoint_connector._ckpt_path).resolve() == checkpoint.resolve()
334-
)
332+
assert Path(test_trainer._checkpoint_connector._ckpt_path).resolve() == checkpoint.resolve()
335333

336334
# Succeed again and make sure that run_0 exists and previous log files were moved
337335
test_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False)
338336
exp_manager(test_trainer, {"resume_if_exists": True, "explicit_log_dir": str(log_dir)})
339337
checkpoint = Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last.ckpt")
340-
assert (
341-
Path(test_trainer._checkpoint_connector._ckpt_path).resolve() == checkpoint.resolve()
342-
)
338+
assert Path(test_trainer._checkpoint_connector._ckpt_path).resolve() == checkpoint.resolve()
343339
prev_run_dir = Path(tmp_path / "test_resume" / "default" / "version_0" / "run_0")
344340
assert prev_run_dir.exists()
345341
prev_log = Path(tmp_path / "test_resume" / "default" / "version_0" / "run_0" / "lightning_logs.txt")
@@ -372,10 +368,7 @@ def test_resume(self, tmp_path):
372368
"explicit_log_dir": str(dirpath_log_dir),
373369
},
374370
)
375-
assert (
376-
Path(test_trainer._checkpoint_connector._ckpt_path).resolve()
377-
== dirpath_checkpoint.resolve()
378-
)
371+
assert Path(test_trainer._checkpoint_connector._ckpt_path).resolve() == dirpath_checkpoint.resolve()
379372

380373
@pytest.mark.unit
381374
def test_nemo_checkpoint_save_best_model_1(self, tmp_path):

0 commit comments

Comments (0)