Skip to content

Commit 186bab8

Browse files
committed
docstrings for trainer builder
Signed-off-by: Maanu Grover <[email protected]>
1 parent b992496 commit 186bab8

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

examples/nlp/language_modeling/megatron_trainer_builder.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,39 @@
2525

2626

2727
class MegatronTrainerBuilder:
28+
"""
29+
Builder type to hide complex configuration of PTL Trainers for Megatron LLM models.
30+
Can be extended to change behavior for a specific model.
31+
"""
32+
2833
def __init__(self, cfg: DictConfig) -> None:
    """Keep a reference to the experiment configuration used by the build methods."""
    self.cfg = cfg

3136
def _training_strategy(self) -> NLPDDPStrategy:
    """Build the distributed-data-parallel strategy passed to ``Trainer(strategy=...)``."""
    # Collect the strategy arguments first so the construction reads as one unit.
    ddp_kwargs = dict(
        no_ddp_communication_hook=True,
        gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
    )
    return NLPDDPStrategy(**ddp_kwargs)
3745

3846
def _grad_scaler(self) -> GradScaler:
    """Build the gradient scaler handed to the precision plugins."""
    # Hoist the model sub-config; every argument below is read from it with a default.
    model_cfg = self.cfg.model
    init_scale = model_cfg.get('native_amp_init_scale', 2 ** 32)
    growth_interval = model_cfg.get('native_amp_growth_interval', 1000)
    hysteresis = model_cfg.get('hysteresis', 2)
    return GradScaler(init_scale=init_scale, growth_interval=growth_interval, hysteresis=hysteresis)
4455

4556
def _plugins(self) -> list:
57+
"""
58+
Returns:
59+
plugins: list of plugins passed to Trainer.plugins including precision plugins.
60+
"""
4661
megatron_amp_o2 = self.cfg.model.get('megatron_amp_O2', False)
4762
with_distributed_adam = self.cfg.model.optim.get('name') == 'distributed_fused_adam'
4863

@@ -72,6 +87,8 @@ def create_trainer(self) -> Trainer:
7287

7388

7489
class MegatronBertTrainerBuilder(MegatronTrainerBuilder):
90+
"""Builder for BERT model Trainer with overrides."""
91+
7592
def _grad_scaler(self) -> GradScaler:
7693
return GradScaler(
7794
init_scale=self.cfg.model.get('native_amp_init_scale', 2 ** 32),
@@ -80,6 +97,8 @@ def _grad_scaler(self) -> GradScaler:
8097

8198

8299
class MegatronT5TrainerBuilder(MegatronTrainerBuilder):
100+
"""Builder for T5 model Trainer with overrides."""
101+
83102
def create_trainer(self) -> Trainer:
84103
strategy = self._training_strategy()
85104
plugins = self._plugins()

0 commit comments

Comments
 (0)