Refactor LLM pretraining examples #7159
Changes from all commits
7bce6ec
a188faf
6763963
6fab8cc
bfd62b9
e18579d
2cb1b0d
77f1c86
0344f4e
96dc68e
caaebd0
84c0130
703f3e4
ca25161
fd098f0
f81cc4d
744f9e4
5d8bb3a
4040ccc
d7000ee
a9c4a65
0278cc6
The first new file in the diff introduces a builder hierarchy for constructing Megatron Trainers:

@@ -0,0 +1,105 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from omegaconf import DictConfig
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary
from pytorch_lightning.plugins.environments import TorchElasticEnvironment

from nemo.collections.nlp.parts.nlp_overrides import (
    GradScaler,
    MegatronHalfPrecisionPlugin,
    NLPDDPStrategy,
    PipelineMixedPrecisionPlugin,
)


class MegatronTrainerBuilder:
    """
    Builder type to hide complex configuration of PTL Trainers for Megatron LLM models.
    Can be extended to change behavior for a specific model.
    """

    def __init__(self, cfg: DictConfig) -> None:
        self.cfg = cfg

    def _training_strategy(self) -> NLPDDPStrategy:
        """
        Returns a ddp strategy passed to Trainer.strategy.
        """
        return NLPDDPStrategy(
            no_ddp_communication_hook=True,
            gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view,
            find_unused_parameters=False,
        )

    def _grad_scaler(self) -> GradScaler:
        """
        Returns a scaler for precision plugins.
        """
        return GradScaler(
            init_scale=self.cfg.model.get('native_amp_init_scale', 2 ** 32),
            growth_interval=self.cfg.model.get('native_amp_growth_interval', 1000),
            hysteresis=self.cfg.model.get('hysteresis', 2),
        )

    def _plugins(self) -> list:
        """
        Returns:
            plugins: list of plugins passed to Trainer.plugins including precision plugins.
        """
        megatron_amp_o2 = self.cfg.model.get('megatron_amp_O2', False)
        with_distributed_adam = self.cfg.model.optim.get('name') == 'distributed_fused_adam'

        plugins = []
        if self.cfg.trainer.precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']:
            scaler = None
            if self.cfg.trainer.precision in [16, '16', '16-mixed']:
                scaler = self._grad_scaler()
                plugin_precision = '16-mixed'
            else:
                plugin_precision = 'bf16-mixed'

            if megatron_amp_o2 and not with_distributed_adam:
                plugins.append(MegatronHalfPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
            else:
                plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))

        if self.cfg.get('cluster_type', None) == 'BCP':
            plugins.append(TorchElasticEnvironment())

        return plugins

    def create_trainer(self) -> Trainer:
        strategy = self._training_strategy()
        plugins = self._plugins()
        return Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer)


class MegatronBertTrainerBuilder(MegatronTrainerBuilder):
    """Builder for BERT model Trainer with overrides."""

    def _grad_scaler(self) -> GradScaler:
        return GradScaler(
            init_scale=self.cfg.model.get('native_amp_init_scale', 2 ** 32),
            growth_interval=self.cfg.model.get('native_amp_growth_interval', 1000),
        )


class MegatronT5TrainerBuilder(MegatronTrainerBuilder):
    """Builder for T5 model Trainer with overrides."""

    def create_trainer(self) -> Trainer:
        strategy = self._training_strategy()
        plugins = self._plugins()
        return Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=[ModelSummary(max_depth=3)])
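
For context on how the builder is consumed, the sketch below shows roughly how a refactored GPT pretraining script could delegate all Trainer wiring to it. This is a hedged illustration only: the builder's module path, the config name, and the script structure are assumptions, not part of this diff.

```python
# Illustrative sketch of a refactored pretraining entry point (assumed usage, not this PR's exact code).
from omegaconf import OmegaConf

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.megatron_trainer_builders import MegatronTrainerBuilder  # assumed module path
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf", config_name="megatron_gpt_config")
def main(cfg) -> None:
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    # All strategy/plugin/precision setup now lives in the builder instead of the script.
    trainer = MegatronTrainerBuilder(cfg).create_trainer()
    exp_manager(trainer, cfg.exp_manager)

    model = MegatronGPTModel(cfg.model, trainer)
    trainer.fit(model)


if __name__ == '__main__':
    main()
```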
The remaining hunks modify exp_manager to support resuming from an explicitly specified checkpoint path:
@@ -139,6 +139,7 @@ class ExpManagerConfig:
    resume_if_exists: Optional[bool] = False
    resume_past_end: Optional[bool] = False
    resume_ignore_no_checkpoint: Optional[bool] = False
    resume_from_checkpoint: Optional[str] = None
    # Logging parameters
    create_tensorboard_logger: Optional[bool] = True
    summary_writer_kwargs: Optional[Dict[Any, Any]] = None
@@ -257,6 +258,8 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
        - resume_ignore_no_checkpoint (bool): exp_manager errors out if resume_if_exists is True and no checkpoint
            could be found. This behaviour can be disabled, in which case exp_manager will print a message and
            continue without restoring, by setting resume_ignore_no_checkpoint to True. Defaults to False.
        - resume_from_checkpoint (str): Can be used to specify a path to a specific checkpoint file to load from. This will
            override any checkpoint found when resume_if_exists is True. Defaults to None.
        - create_tensorboard_logger (bool): Whether to create a tensorboard logger and attach it to the pytorch
            lightning trainer. Defaults to True.
        - summary_writer_kwargs (dict): A dictionary of kwargs that can be passed to lightning's TensorboardLogger
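
For illustration, here is a hedged sketch of how the new key might be set alongside the existing resume flags; the checkpoint path below is hypothetical and not taken from this PR.

```python
from omegaconf import OmegaConf

# Hypothetical exp_manager section; only the keys shown in the diff above are real options.
exp_manager_cfg = OmegaConf.create(
    {
        'resume_if_exists': True,
        'resume_ignore_no_checkpoint': True,
        # Illustrative path; per the docstring above, an explicit resume_from_checkpoint
        # overrides whatever checkpoint resume_if_exists would otherwise find.
        'resume_from_checkpoint': '/results/checkpoints/megatron_gpt--last.ckpt',
    }
)
```

That override behavior is exactly what the TODO in the next hunk flags as undesirable.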
@@ -343,6 +346,12 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
        else:
            check_resume(trainer, log_dir, cfg.resume_past_end, cfg.resume_ignore_no_checkpoint)

        # TODO: this behavior is undesirable, need ckpts in exp_dir to take priority if present over resume_from_checkpoint
        # if cfg.resume_from_checkpoint is not None:
        #     trainer.ckpt_path = cfg.resume_from_checkpoint

        logging.info(f'Resuming training from checkpoint: {trainer.ckpt_path}')

    checkpoint_name = name
    # If name returned from get_log_dir is "", use cfg.name for checkpointing
    if checkpoint_name is None or checkpoint_name == '':

Review thread on lines +350 to +351:

Reviewer: @maanug-nv where are we taking care of the below lines then, since the pre-training scripts were assigning the checkpoint to ...?

maanug-nv: Yes, so initially I moved those lines exactly as-is to this place in ... If 'resume_from_checkpoint' is set, that checkpoint is always used despite what is in the log dir. What makes more sense is that 'resume_from_checkpoint' is used if no log_dir is present, but log_dir takes priority if present.

Check notice (Code scanning / CodeQL): Commented-out code. This comment appears to contain commented-out code.
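
The thread above converges on a different priority order: a checkpoint already present in the experiment's log_dir should win, with resume_from_checkpoint only used as a fallback when nothing is found there. A minimal sketch of that intent (illustrative helper, not code from this PR):

```python
from typing import Optional


def resolve_resume_ckpt(log_dir_ckpt: Optional[str], resume_from_checkpoint: Optional[str]) -> Optional[str]:
    """Pick the checkpoint to resume from.

    log_dir_ckpt: checkpoint discovered in the experiment's log_dir (or None).
    resume_from_checkpoint: explicit path from the exp_manager config (or None).
    A checkpoint in log_dir takes priority; the explicit path is only a fallback.
    """
    if log_dir_ckpt is not None:
        return log_dir_ckpt
    return resume_from_checkpoint
```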