Skip to content

Commit 5e893d6

Browse files
author
Pingchuan Ma
authored
Simplify trainining step in av-asr recipe (#3598)
* Simplify trainining step in av-asr recipe * Run pre-commit
1 parent 3e1d8f3 commit 5e893d6

File tree

3 files changed

+1
-24
lines changed

3 files changed

+1
-24
lines changed

examples/avsr/lightning.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,6 @@ def __init__(self, args=None, sp_model=None, pretrained_model_path=None):
8484
betas=(0.9, 0.98),
8585
)
8686

87-
self.automatic_optimization = False
88-
8987
def _step(self, batch, _, step_type):
9088
if batch is None:
9189
return None
@@ -123,20 +121,10 @@ def forward(self, batch):
123121
return post_process_hypos(hypotheses, self.sp_model)[0][0]
124122

125123
def training_step(self, batch, batch_idx):
126-
opt = self.optimizers()
127-
opt.zero_grad()
128124
loss = self._step(batch, batch_idx, "train")
129125
batch_size = batch.inputs.size(0)
130126
batch_sizes = self.all_gather(batch_size)
131-
132127
loss *= batch_sizes.size(0) / batch_sizes.sum() # world size / batch size
133-
self.manual_backward(loss)
134-
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
135-
opt.step()
136-
137-
sch = self.lr_schedulers()
138-
sch.step()
139-
140128
self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
141129

142130
return loss

examples/avsr/lightning_av.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,6 @@ def __init__(self, args=None, sp_model=None):
8080
betas=(0.9, 0.98),
8181
)
8282

83-
self.automatic_optimization = False
84-
8583
def _step(self, batch, _, step_type):
8684
if batch is None:
8785
return None
@@ -128,20 +126,10 @@ def forward(self, batch):
128126
return post_process_hypos(hypotheses, self.sp_model)[0][0]
129127

130128
def training_step(self, batch, batch_idx):
131-
opt = self.optimizers()
132-
opt.zero_grad()
133129
loss = self._step(batch, batch_idx, "train")
134130
batch_size = batch.videos.size(0)
135131
batch_sizes = self.all_gather(batch_size)
136-
137132
loss *= batch_sizes.size(0) / batch_sizes.sum() # world size / batch size
138-
self.manual_backward(loss)
139-
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
140-
opt.step()
141-
142-
sch = self.lr_schedulers()
143-
sch.step()
144-
145133
self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
146134

147135
return loss

examples/avsr/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def get_trainer(args):
3636
strategy=DDPStrategy(find_unused_parameters=False),
3737
callbacks=callbacks,
3838
reload_dataloaders_every_n_epochs=1,
39+
gradient_clip_val=10.0,
3940
)
4041

4142

0 commit comments

Comments
 (0)