diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
index 3747e85fc622..98711d9a41cd 100755
--- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -33,6 +33,7 @@ exp_manager:
   name: null
   resume_if_exists: True
   resume_ignore_no_checkpoint: True
+  resume_from_checkpoint: ${model.resume_from_checkpoint}
   create_checkpoint_callback: True
   checkpoint_callback_params:
     monitor: val_loss
diff --git a/examples/nlp/language_modeling/megatron_bert_pretraining.py b/examples/nlp/language_modeling/megatron_bert_pretraining.py
index 9199f03f0890..5b4876141f74 100644
--- a/examples/nlp/language_modeling/megatron_bert_pretraining.py
+++ b/examples/nlp/language_modeling/megatron_bert_pretraining.py
@@ -14,17 +14,9 @@
 
 import torch.multiprocessing as mp
 from omegaconf.omegaconf import OmegaConf, open_dict
-from pytorch_lightning import Trainer
-from pytorch_lightning.plugins.environments import TorchElasticEnvironment
-from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector
 
 from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel
-from nemo.collections.nlp.parts.nlp_overrides import (
-    GradScaler,
-    MegatronHalfPrecisionPlugin,
-    NLPDDPStrategy,
-    PipelineMixedPrecisionPlugin,
-)
+from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronBertTrainerBuilder
 from nemo.core.config import hydra_runner
 from nemo.utils import logging
 from nemo.utils.exp_manager import exp_manager
@@ -38,42 +30,9 @@ def main(cfg) -> None:
     logging.info("\n\n************** Experiment configuration ***********")
     logging.info(f'\n{OmegaConf.to_yaml(cfg)}')
 
-    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
-    with_distributed_adam = cfg.model.optim.get('name') == 'distributed_fused_adam'
-
-    plugins = []
-    strategy = NLPDDPStrategy(
-        no_ddp_communication_hook=True,
-        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
-        find_unused_parameters=False,
-    )
-
-    if cfg.trainer.precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']:
-        scaler = None
-        if cfg.trainer.precision in [16, '16', '16-mixed']:
-            scaler = GradScaler(
-                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
-                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
-            )
-            # MixedPrecisionPlugin in PTL >= 2.0 requires precision to be 16-mixed or bf16-mixed
-            plugin_precision = '16-mixed'
-        else:
-            plugin_precision = 'bf16-mixed'
-        if megatron_amp_o2 and not with_distributed_adam:
-            plugins.append(MegatronHalfPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
-        else:
-            plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
-
-    if cfg.get('cluster_type', None) == 'BCP':
-        plugins.append(TorchElasticEnvironment())
-
-    trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)
-
+    trainer = MegatronBertTrainerBuilder(cfg).create_trainer()
     exp_manager(trainer, cfg.exp_manager)
 
-    # resume_from_checkpoint = uninject_model_parallel_rank(resume_from_checkpoint)
-    logging.info(f'Resuming training from checkpoint: {trainer.ckpt_path}')
-
     # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
     with open_dict(cfg):
         cfg.model.precision = cfg.trainer.precision
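After this patch, all three pretraining entrypoints (BERT above, GPT and T5 below) collapse to the same shape. A minimal sketch of the resulting script, for orientation only — the `hydra_runner` arguments and the model-construction lines come from the pre-existing scripts rather than this diff, so treat them as assumptions:

```python
from omegaconf.omegaconf import open_dict

from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronBertTrainerBuilder
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf", config_name="megatron_bert_config")
def main(cfg) -> None:
    # Strategy, precision plugins, and cluster environment are now resolved inside the builder.
    trainer = MegatronBertTrainerBuilder(cfg).create_trainer()
    exp_manager(trainer, cfg.exp_manager)

    # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    model = MegatronBertModel(cfg.model, trainer)
    trainer.fit(model)


if __name__ == '__main__':
    main()
```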
diff --git a/examples/nlp/language_modeling/megatron_gpt_pretraining.py b/examples/nlp/language_modeling/megatron_gpt_pretraining.py
index 5068f5d2222d..291a85ac8a05 100644
--- a/examples/nlp/language_modeling/megatron_gpt_pretraining.py
+++ b/examples/nlp/language_modeling/megatron_gpt_pretraining.py
@@ -15,17 +15,9 @@
 
 import torch.multiprocessing as mp
 from omegaconf.omegaconf import OmegaConf, open_dict
-from pytorch_lightning import Trainer
-from pytorch_lightning.plugins.environments import TorchElasticEnvironment
-from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector
 
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
-from nemo.collections.nlp.parts.nlp_overrides import (
-    GradScaler,
-    MegatronHalfPrecisionPlugin,
-    NLPDDPStrategy,
-    PipelineMixedPrecisionPlugin,
-)
+from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
 from nemo.core.config import hydra_runner
 from nemo.utils import logging
 from nemo.utils.exp_manager import exp_manager
@@ -38,46 +30,9 @@ def main(cfg) -> None:
     logging.info("\n\n************** Experiment configuration ***********")
     logging.info(f'\n{OmegaConf.to_yaml(cfg)}')
 
-    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
-    with_distributed_adam = cfg.model.optim.get('name') == 'distributed_fused_adam'
-
-    plugins = []
-    strategy = NLPDDPStrategy(
-        no_ddp_communication_hook=True,  # we don't use DDP for async grad allreduce
-        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
-        find_unused_parameters=False,
-    )
-    if cfg.trainer.precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']:
-        scaler = None
-        if cfg.trainer.precision in [16, '16', '16-mixed']:
-            scaler = GradScaler(
-                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
-                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
-                hysteresis=cfg.model.get('hysteresis', 2),
-            )
-            # MixedPrecisionPlugin in PTL >= 2.0 requires precision to be 16-mixed or bf16-mixed
-            plugin_precision = '16-mixed'
-        else:
-            plugin_precision = 'bf16-mixed'
-
-        if megatron_amp_o2 and not with_distributed_adam:
-            plugins.append(MegatronHalfPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
-        else:
-            plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
-
-    if cfg.get('cluster_type', None) == 'BCP':
-        plugins.append(TorchElasticEnvironment())
-
-    trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)
-
+    trainer = MegatronTrainerBuilder(cfg).create_trainer()
     exp_manager(trainer, cfg.exp_manager)
 
-    # update resume from checkpoint found by exp_manager
-    if cfg.model.resume_from_checkpoint is not None:
-        trainer.ckpt_path = cfg.model.resume_from_checkpoint
-
-    logging.info(f'Resuming training from checkpoint: {trainer.ckpt_path}')
-
     # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
     with open_dict(cfg):
         cfg.model.precision = cfg.trainer.precision
diff --git a/examples/nlp/language_modeling/megatron_t5_pretraining.py b/examples/nlp/language_modeling/megatron_t5_pretraining.py
index 1674faab773f..ea5d751ab59a 100644
--- a/examples/nlp/language_modeling/megatron_t5_pretraining.py
+++ b/examples/nlp/language_modeling/megatron_t5_pretraining.py
@@ -14,18 +14,9 @@
 
 from omegaconf.omegaconf import OmegaConf, open_dict
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelSummary
-from pytorch_lightning.plugins.environments import TorchElasticEnvironment
-from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector
 
 from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model
-from nemo.collections.nlp.parts.nlp_overrides import (
-    GradScaler,
-    MegatronHalfPrecisionPlugin,
-    NLPDDPStrategy,
-    PipelineMixedPrecisionPlugin,
-)
+from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronT5TrainerBuilder
 from nemo.core.config import hydra_runner
 from nemo.utils import logging
 from nemo.utils.exp_manager import exp_manager
@@ -36,41 +27,9 @@ def main(cfg) -> None:
     logging.info("\n\n************** Experiment configuration ***********")
     logging.info(f'\n{OmegaConf.to_yaml(cfg)}')
 
-    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
-    with_distributed_adam = cfg.model.optim.get('name') == 'distributed_fused_adam'
-    plugins = []
-    strategy = NLPDDPStrategy(
-        no_ddp_communication_hook=True,  # we don't use DDP for async grad allreduce
-        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
-        find_unused_parameters=False,
-    )
-    if cfg.trainer.precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']:
-        scaler = None
-        if cfg.trainer.precision in [16, '16', '16-mixed']:
-            scaler = GradScaler(
-                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
-                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
-                hysteresis=cfg.model.get('hysteresis', 2),
-            )
-            plugin_precision = '16-mixed'
-        else:
-            plugin_precision = 'bf16-mixed'
-        if megatron_amp_o2 and not with_distributed_adam:
-            plugins.append(MegatronHalfPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
-        else:
-            plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
-
-    if cfg.get('cluster_type', None) == 'BCP':
-        plugins.append(TorchElasticEnvironment())
-
-    trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer, callbacks=[ModelSummary(max_depth=3)])
+    trainer = MegatronT5TrainerBuilder(cfg).create_trainer()
     exp_manager(trainer, cfg.exp_manager)
 
-    # update resume from checkpoint found by exp_manager
-    if cfg.model.resume_from_checkpoint is not None:
-        trainer.ckpt_path = cfg.model.resume_from_checkpoint
-    logging.info(f'Resuming training from checkpoint: {trainer.ckpt_path}')
-
     # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
     with open_dict(cfg):
         cfg.model.precision = cfg.trainer.precision
diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py
new file mode 100644
index 000000000000..e5af76cbd1ec
--- /dev/null
+++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py
@@ -0,0 +1,105 @@
+# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from omegaconf import DictConfig
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelSummary
+from pytorch_lightning.plugins.environments import TorchElasticEnvironment
+from nemo.collections.nlp.parts.nlp_overrides import (
+    GradScaler,
+    MegatronHalfPrecisionPlugin,
+    NLPDDPStrategy,
+    PipelineMixedPrecisionPlugin,
+)
+
+
+class MegatronTrainerBuilder:
+    """
+    Builder type to hide complex configuration of PTL Trainers for Megatron LLM models.
+    Can be extended to change behavior for a specific model.
+    """
+
+    def __init__(self, cfg: DictConfig) -> None:
+        self.cfg = cfg
+
+    def _training_strategy(self) -> NLPDDPStrategy:
+        """
+        Returns a ddp strategy passed to Trainer.strategy.
+        """
+        return NLPDDPStrategy(
+            no_ddp_communication_hook=True,
+            gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view,
+            find_unused_parameters=False,
+        )
+
+    def _grad_scaler(self) -> GradScaler:
+        """
+        Returns a scaler for precision plugins.
+        """
+        return GradScaler(
+            init_scale=self.cfg.model.get('native_amp_init_scale', 2 ** 32),
+            growth_interval=self.cfg.model.get('native_amp_growth_interval', 1000),
+            hysteresis=self.cfg.model.get('hysteresis', 2),
+        )
+
+    def _plugins(self) -> list:
+        """
+        Returns:
+            plugins: list of plugins passed to Trainer.plugins including precision plugins.
+        """
+        megatron_amp_o2 = self.cfg.model.get('megatron_amp_O2', False)
+        with_distributed_adam = self.cfg.model.optim.get('name') == 'distributed_fused_adam'
+
+        plugins = []
+        if self.cfg.trainer.precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']:
+            scaler = None
+            if self.cfg.trainer.precision in [16, '16', '16-mixed']:
+                scaler = self._grad_scaler()
+                plugin_precision = '16-mixed'
+            else:
+                plugin_precision = 'bf16-mixed'
+
+            if megatron_amp_o2 and not with_distributed_adam:
+                plugins.append(MegatronHalfPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
+            else:
+                plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
+
+        if self.cfg.get('cluster_type', None) == 'BCP':
+            plugins.append(TorchElasticEnvironment())
+
+        return plugins
+
+    def create_trainer(self) -> Trainer:
+        strategy = self._training_strategy()
+        plugins = self._plugins()
+        return Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer)
+
+
+class MegatronBertTrainerBuilder(MegatronTrainerBuilder):
+    """Builder for BERT model Trainer with overrides."""
+
+    def _grad_scaler(self) -> GradScaler:
+        return GradScaler(
+            init_scale=self.cfg.model.get('native_amp_init_scale', 2 ** 32),
+            growth_interval=self.cfg.model.get('native_amp_growth_interval', 1000),
+        )
+
+
+class MegatronT5TrainerBuilder(MegatronTrainerBuilder):
+    """Builder for T5 model Trainer with overrides."""
+
+    def create_trainer(self) -> Trainer:
+        strategy = self._training_strategy()
+        plugins = self._plugins()
+        return Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=[ModelSummary(max_depth=3)])
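The per-model subclasses above override only what differs from the base builder. As a rough sketch of how another model could hook in (the class name and the smaller `max_depth` are purely illustrative, not part of this diff):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary

from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder


class MegatronExampleTrainerBuilder(MegatronTrainerBuilder):
    """Hypothetical builder: reuse the shared strategy/plugin wiring, customize only the Trainer call."""

    def create_trainer(self) -> Trainer:
        strategy = self._training_strategy()
        plugins = self._plugins()
        # Same pattern as MegatronT5TrainerBuilder, with a shallower model summary callback.
        return Trainer(
            plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=[ModelSummary(max_depth=1)]
        )
```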
diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py
index 31d188776a41..63775f4058c5 100644
--- a/nemo/utils/exp_manager.py
+++ b/nemo/utils/exp_manager.py
@@ -139,6 +139,7 @@ class ExpManagerConfig:
     resume_if_exists: Optional[bool] = False
     resume_past_end: Optional[bool] = False
     resume_ignore_no_checkpoint: Optional[bool] = False
+    resume_from_checkpoint: Optional[str] = None
     # Logging parameters
     create_tensorboard_logger: Optional[bool] = True
     summary_writer_kwargs: Optional[Dict[Any, Any]] = None
@@ -257,6 +258,8 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
         - resume_ignore_no_checkpoint (bool): exp_manager errors out if resume_if_exists is True and no checkpoint
             could be found. This behaviour can be disabled, in which case exp_manager will print a message and
             continue without restoring, by setting resume_ignore_no_checkpoint to True. Defaults to False.
+        - resume_from_checkpoint (str): Can be used to specify a path to a specific checkpoint file to load from. This will
+            override any checkpoint found when resume_if_exists is True. Defaults to None.
         - create_tensorboard_logger (bool): Whether to create a tensorboard logger and attach it to the pytorch
             lightning trainer. Defaults to True.
         - summary_writer_kwargs (dict): A dictionary of kwargs that can be passed to lightning's TensorboardLogger
@@ -343,6 +346,12 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
     else:
         check_resume(trainer, log_dir, cfg.resume_past_end, cfg.resume_ignore_no_checkpoint)
 
+    # TODO: this behavior is undesirable, need ckpts in exp_dir to take priority if present over resume_from_checkpoint
+    # if cfg.resume_from_checkpoint is not None:
+    #     trainer.ckpt_path = cfg.resume_from_checkpoint
+
+    logging.info(f'Resuming training from checkpoint: {trainer.ckpt_path}')
+
     checkpoint_name = name
     # If name returned from get_log_dir is "", use cfg.name for checkpointing
     if checkpoint_name is None or checkpoint_name == '':
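For reference, a hedged sketch of how the new field is expected to be wired in a user config, mirroring the megatron_gpt_config.yaml hunk at the top of this diff (the surrounding keys and the example path are illustrative):

```yaml
exp_manager:
  resume_if_exists: True
  resume_ignore_no_checkpoint: True
  # New in this change: lets exp_manager log / honor an explicit checkpoint path.
  # Interpolates the existing model-level setting so the two stay in sync.
  resume_from_checkpoint: ${model.resume_from_checkpoint}

model:
  resume_from_checkpoint: null  # or a path such as /results/checkpoints/megatron_gpt--last.ckpt
```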