Commit ec661fa

Merge pull request #165 from aryankeluskar/master
migrated to newer version of lightning
2 parents 20b5272 + b022f17

File tree

2 files changed: 8 additions, 22 deletions

deepblast/dataset/utils.py (1 addition, 1 deletion)

```diff
@@ -406,7 +406,7 @@ def gap_mask(states: str, sparse=False):
     if sparse:
         return mat
     else:
-        return mat.toarray().astype(np.bool)
+        return mat.toarray().astype(bool)
 
 
 def window(seq, n=2):
```
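For context, this one-line change tracks NumPy's removal of the `np.bool` alias (deprecated in NumPy 1.20, removed in 1.24); the builtin `bool` is the drop-in replacement. A minimal sketch of the before/after behavior, using a toy sparse matrix rather than the actual `gap_mask` output:

```python
import numpy as np
from scipy.sparse import coo_matrix

# Toy stand-in for the sparse mask that gap_mask builds; the real
# states string and matrix shape are DeepBLAST-specific.
mat = coo_matrix(([1, 1], ([0, 1], [1, 2])), shape=(3, 3))

mask = mat.toarray().astype(bool)    # works on every NumPy release
# mat.toarray().astype(np.bool)      # AttributeError on NumPy >= 1.24
print(mask)
```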

deepblast/trainer.py (7 additions, 21 deletions)

```diff
@@ -50,6 +50,7 @@ def __init__(self, batch_size=20,
                  ):
 
         super(DeepBLAST, self).__init__()
+        self.validation_step_outputs = []
         self.save_hyperparameters(ignore=['lm', 'tokenizer'])
 
         if device == 'gpu':  # this is for users, in case they specify gpu
@@ -74,6 +75,7 @@ def __init__(self, batch_size=20,
             n_input, n_units, n_embed, n_layers, dropout=dropout, lm=lm,
             alignment_mode=alignment_mode,
             device=device)
+        self.tokenizer = tokenizer
 
     def align(self, x, y):
         x_code = get_sequence(x, self.tokenizer)[0].to(self.device)
@@ -236,6 +238,7 @@ def validation_step(self, batch, batch_idx):
         predA, theta, gap = self.aligner(seq, order)
         x, xlen, y, ylen = unpack_sequences(seq, order)
         loss = self.compute_loss(xlen, ylen, predA, A, P, G, theta)
+        self.validation_step_outputs.append(loss)
 
         assert torch.isnan(loss).item() is False
 
@@ -291,27 +294,10 @@ def test_step(self, batch, batch_idx):
         statistics['key_name'] = other_names
         return statistics
 
-    def validation_epoch_end(self, outputs):
-        loss_f = lambda x: x['validation_loss']
-        losses = list(map(loss_f, outputs))
-        loss = sum(losses) / len(losses)
-        self.logger.experiment.add_scalar('val_loss', loss, self.global_step)
-        # self.log('validation_loss') = loss
-
-        # metrics = ['val_tp', 'val_fp', 'val_fn', 'val_perc_id',
-        #            'val_ppv', 'val_fnr', 'val_fdr']
-        # scores = []
-        # for i, m in enumerate(metrics):
-        #     loss_f = lambda x: x['log'][m]
-        #     losses = list(map(loss_f, outputs))
-        #     scalar = sum(losses) / len(losses)
-        #     scores.append(scalar)
-        #     self.logger.experiment.add_scalar(m, scalar, self.global_step)
-
-        tensorboard_logs = dict(
-            [('val_loss', loss)]  # + list(zip(metrics, scores))
-        )
-        return {'val_loss': loss, 'log': tensorboard_logs}
+    def on_validation_epoch_end(self):
+        epoch_average = torch.stack(self.validation_step_outputs).mean()
+        self.log("validation_epoch_average", epoch_average)
+        self.validation_step_outputs.clear()  # free memory
 
     def configure_optimizers(self):
         # Freeze language model
```
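The trainer changes are the standard PyTorch Lightning >= 2.0 migration: the `validation_epoch_end(outputs)` hook was removed, so a module now accumulates its own step outputs and aggregates them in `on_validation_epoch_end`. A minimal self-contained sketch of the pattern (the toy layer, loss, and optimizer below are placeholders, not DeepBLAST's model):

```python
import torch
import pytorch_lightning as pl


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        # Lightning >= 2.0 no longer passes `outputs` to the epoch-end
        # hook, so the module keeps its own list of per-step losses.
        self.validation_step_outputs = []

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.validation_step_outputs.append(loss)
        return loss

    def on_validation_epoch_end(self):
        # Aggregate the collected losses, log, then clear to free memory.
        epoch_average = torch.stack(self.validation_step_outputs).mean()
        self.log("validation_epoch_average", epoch_average)
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```

The diff applies exactly these three steps to `DeepBLAST`: initialize the list in `__init__`, append the loss in `validation_step`, and aggregate and clear in `on_validation_epoch_end`.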
