25
25
26
26
27
27
class MegatronTrainerBuilder :
28
+ """
29
+ Builder type to hide complex configuration of PTL Trainers for Megatron LLM models.
30
+ Can be extended to change behavior for a specific model.
31
+ """
32
+
28
33
    def __init__(self, cfg: DictConfig) -> None:
        """Store the experiment config; builder methods read ``cfg.model`` fields."""
        # NOTE(review): only cfg.model.* accesses are visible in this chunk —
        # presumably cfg is the full Hydra config; confirm against callers.
        self.cfg = cfg
30
35
31
36
def _training_strategy (self ) -> NLPDDPStrategy :
37
+ """
38
+ Returns a ddp strategy passed to Trainer.strategy.
39
+ """
32
40
return NLPDDPStrategy (
33
41
no_ddp_communication_hook = True ,
34
42
gradient_as_bucket_view = self .cfg .model .gradient_as_bucket_view ,
35
43
find_unused_parameters = False ,
36
44
)
37
45
38
46
def _grad_scaler (self ) -> GradScaler :
47
+ """
48
+ Returns a scaler for precision plugins.
49
+ """
39
50
return GradScaler (
40
51
init_scale = self .cfg .model .get ('native_amp_init_scale' , 2 ** 32 ),
41
52
growth_interval = self .cfg .model .get ('native_amp_growth_interval' , 1000 ),
42
53
hysteresis = self .cfg .model .get ('hysteresis' , 2 ),
43
54
)
44
55
45
56
def _plugins (self ) -> list :
57
+ """
58
+ Returns:
59
+ plugins: list of plugins passed to Trainer.plugins including precision plugins.
60
+ """
46
61
megatron_amp_o2 = self .cfg .model .get ('megatron_amp_O2' , False )
47
62
with_distributed_adam = self .cfg .model .optim .get ('name' ) == 'distributed_fused_adam'
48
63
@@ -72,6 +87,8 @@ def create_trainer(self) -> Trainer:
72
87
73
88
74
89
class MegatronBertTrainerBuilder (MegatronTrainerBuilder ):
90
+ """Builder for BERT model Trainer with overrides."""
91
+
75
92
def _grad_scaler (self ) -> GradScaler :
76
93
return GradScaler (
77
94
init_scale = self .cfg .model .get ('native_amp_init_scale' , 2 ** 32 ),
@@ -80,6 +97,8 @@ def _grad_scaler(self) -> GradScaler:
80
97
81
98
82
99
class MegatronT5TrainerBuilder (MegatronTrainerBuilder ):
100
+ """Builder for T5 model Trainer with overrides."""
101
+
83
102
def create_trainer (self ) -> Trainer :
84
103
strategy = self ._training_strategy ()
85
104
plugins = self ._plugins ()
0 commit comments