Skip to content

Commit 310745f

Browse files
Refactor compute_validation_metrics to correct metric indexing and remove debug prints
1 parent f5ea8e8 commit 310745f

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

monai/nvflare/utils.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -407,13 +407,10 @@ def compute_validation_metrics(gt_folder, pred_folder, n_labels=1):
407407
)
408408
dice = dice_fn(to_onehot(pred_array[None])[None], to_onehot(gt_array[None])[None])
409409
for label_id in range(1,1+n_labels):
410-
print(hd_95)
411-
print(asd)
412-
print(dice)
413410
summary['metric_per_case'][idx]["metrics"][str(label_id)] = {}
414-
summary['metric_per_case'][idx]["metrics"][str(label_id)]["HD95"] = hd_95[label_id-1][0].item()
415-
summary['metric_per_case'][idx]["metrics"][str(label_id)]["ASD"] = asd[label_id-1][0].item()
416-
summary['metric_per_case'][idx]["metrics"][str(label_id)]["Dice"] = dice[label_id-1][0].item()
411+
summary['metric_per_case'][idx]["metrics"][str(label_id)]["HD95"] = hd_95[0][label_id-1].item()
412+
summary['metric_per_case'][idx]["metrics"][str(label_id)]["ASD"] = asd[0][label_id-1].item()
413+
summary['metric_per_case'][idx]["metrics"][str(label_id)]["Dice"] = dice[0][label_id-1].item()
417414

418415
for label_id in range(1,1+n_labels):
419416
summary["mean"] = {}
@@ -663,7 +660,7 @@ def train_api(nnunet_root_dir, dataset_name_or_id, experiment_name, trainer_clas
663660
def validation_api(nnunet_root_dir, dataset_name_or_id, trainer_class_name="nnUNetTrainer", nnunet_plans_name="nnUNetPlans", fold=0):
664661
data_src_cfg = os.path.join(nnunet_root_dir, f"Task{dataset_name_or_id}_data_src_cfg.yaml")
665662
runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir)
666-
#runner.train_single_model(config="3d_fullres", fold=fold, val="")
663+
runner.train_single_model(config="3d_fullres", fold=fold, val="")
667664
dataset_file = os.path.join(
668665
runner.nnunet_results,
669666
runner.dataset_name,

0 commit comments

Comments
 (0)