@@ -407,13 +407,10 @@ def compute_validation_metrics(gt_folder, pred_folder, n_labels=1):
407
407
)
408
408
dice = dice_fn (to_onehot (pred_array [None ])[None ], to_onehot (gt_array [None ])[None ])
409
409
for label_id in range (1 ,1 + n_labels ):
410
- print (hd_95 )
411
- print (asd )
412
- print (dice )
413
410
summary ['metric_per_case' ][idx ]["metrics" ][str (label_id )] = {}
414
- summary ['metric_per_case' ][idx ]["metrics" ][str (label_id )]["HD95" ] = hd_95 [label_id - 1 ][ 0 ].item ()
415
- summary ['metric_per_case' ][idx ]["metrics" ][str (label_id )]["ASD" ] = asd [label_id - 1 ][ 0 ].item ()
416
- summary ['metric_per_case' ][idx ]["metrics" ][str (label_id )]["Dice" ] = dice [label_id - 1 ][ 0 ].item ()
411
+ summary ['metric_per_case' ][idx ]["metrics" ][str (label_id )]["HD95" ] = hd_95 [0 ][ label_id - 1 ].item ()
412
+ summary ['metric_per_case' ][idx ]["metrics" ][str (label_id )]["ASD" ] = asd [0 ][ label_id - 1 ].item ()
413
+ summary ['metric_per_case' ][idx ]["metrics" ][str (label_id )]["Dice" ] = dice [0 ][ label_id - 1 ].item ()
417
414
418
415
for label_id in range (1 ,1 + n_labels ):
419
416
summary ["mean" ] = {}
@@ -663,7 +660,7 @@ def train_api(nnunet_root_dir, dataset_name_or_id, experiment_name, trainer_clas
663
660
def validation_api (nnunet_root_dir , dataset_name_or_id , trainer_class_name = "nnUNetTrainer" , nnunet_plans_name = "nnUNetPlans" , fold = 0 ):
664
661
data_src_cfg = os .path .join (nnunet_root_dir , f"Task{ dataset_name_or_id } _data_src_cfg.yaml" )
665
662
runner = nnUNetV2Runner (input_config = data_src_cfg , trainer_class_name = trainer_class_name , work_dir = nnunet_root_dir )
666
- # runner.train_single_model(config="3d_fullres", fold=fold, val="")
663
+ runner .train_single_model (config = "3d_fullres" , fold = fold , val = "" )
667
664
dataset_file = os .path .join (
668
665
runner .nnunet_results ,
669
666
runner .dataset_name ,
0 commit comments