@@ -50,6 +50,7 @@ def __init__(self, batch_size=20,
50
50
):
51
51
52
52
super (DeepBLAST , self ).__init__ ()
53
+ self .validation_step_outputs = []
53
54
self .save_hyperparameters (ignore = ['lm' , 'tokenizer' ])
54
55
55
56
if device == 'gpu' : # this is for users, in case they specify gpu
@@ -74,6 +75,7 @@ def __init__(self, batch_size=20,
74
75
n_input , n_units , n_embed , n_layers , dropout = dropout , lm = lm ,
75
76
alignment_mode = alignment_mode ,
76
77
device = device )
78
+ self .tokenizer = tokenizer
77
79
78
80
def align (self , x , y ):
79
81
x_code = get_sequence (x , self .tokenizer )[0 ].to (self .device )
@@ -236,6 +238,7 @@ def validation_step(self, batch, batch_idx):
236
238
predA , theta , gap = self .aligner (seq , order )
237
239
x , xlen , y , ylen = unpack_sequences (seq , order )
238
240
loss = self .compute_loss (xlen , ylen , predA , A , P , G , theta )
241
+ self .validation_step_outputs .append (loss )
239
242
240
243
assert torch .isnan (loss ).item () is False
241
244
@@ -291,27 +294,10 @@ def test_step(self, batch, batch_idx):
291
294
statistics ['key_name' ] = other_names
292
295
return statistics
293
296
294
- def validation_epoch_end (self , outputs ):
295
- loss_f = lambda x : x ['validation_loss' ]
296
- losses = list (map (loss_f , outputs ))
297
- loss = sum (losses ) / len (losses )
298
- self .logger .experiment .add_scalar ('val_loss' , loss , self .global_step )
299
- # self.log('validation_loss') = loss
300
-
301
- # metrics = ['val_tp', 'val_fp', 'val_fn', 'val_perc_id',
302
- # 'val_ppv', 'val_fnr', 'val_fdr']
303
- # scores = []
304
- # for i, m in enumerate(metrics):
305
- # loss_f = lambda x: x['log'][m]
306
- # losses = list(map(loss_f, outputs))
307
- # scalar = sum(losses) / len(losses)
308
- # scores.append(scalar)
309
- # self.logger.experiment.add_scalar(m, scalar, self.global_step)
310
-
311
- tensorboard_logs = dict (
312
- [('val_loss' , loss )] # + list(zip(metrics, scores))
313
- )
314
- return {'val_loss' : loss , 'log' : tensorboard_logs }
297
+ def on_validation_epoch_end (self ):
298
+ epoch_average = torch .stack (self .validation_step_outputs ).mean ()
299
+ self .log ("validation_epoch_average" , epoch_average )
300
+ self .validation_step_outputs .clear () # free memory
315
301
316
302
def configure_optimizers (self ):
317
303
# Freeze language model
0 commit comments