Commit ac94578

Refactor training and validation processes by introducing API functions for better modularity and maintainability
1 parent f63cc5e commit ac94578

File tree

2 files changed (+241 additions, -212 deletions)

monai/nvflare/nvflare_nnunet.py

Lines changed: 9 additions & 211 deletions

@@ -36,7 +36,7 @@
 from pyhocon import ConfigFactory
 from pyhocon.converter import HOCONConverter

-from monai.nvflare.utils import prepare_data_folder_api, compute_validation_metrics, cross_site_evaluation_api, plan_and_preprocess_api
+from monai.nvflare.utils import prepare_data_folder_api, finalize_bundle_api, cross_site_evaluation_api, plan_and_preprocess_api, prepare_bundle_api, train_api, validation_api


@@ -146,36 +146,10 @@ def train(
     dict
         Dictionary containing validation summary metrics.
     """
-    data_src_cfg = os.path.join(nnunet_root_dir, f"Task{dataset_name_or_id}_data_src_cfg.yaml")
-    runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir)
-
-    if not skip_training:
-        if not run_with_bundle:
-            if continue_training:
-                runner.train_single_model(config="3d_fullres", fold=fold, c="")
-            else:
-                runner.train_single_model(config="3d_fullres", fold=fold)
-        else:
-            os.environ["BUNDLE_ROOT"] = bundle_root
-            os.environ["PYTHONPATH"] = os.environ["PYTHONPATH"] + ":" + bundle_root
-            config_files = os.path.join(bundle_root, "configs", "train.yaml")
-            if continue_training:
-                config_files = [os.path.join(bundle_root, "configs", "train.yaml"), os.path.join(bundle_root, "configs", "train_resume.yaml")]
-            monai.bundle.run(
-                config_file=config_files,
-                bundle_root=bundle_root,
-                nnunet_trainer_class_name=trainer_class_name,
-                mlflow_experiment_name=experiment_name,
-                mlflow_run_name="run_" + client_name,
-                tracking_uri=tracking_uri,
-                fold_id=fold,
-                nnunet_root_folder=nnunet_root_dir,
-                reload_checkpoint_epoch=resume_epoch
-            )
-            nnunet_config = {"dataset_name_or_id": dataset_name_or_id, "nnunet_trainer": trainer_class_name}
-            convert_monai_bundle_to_nnunet(nnunet_config, bundle_root)
-    runner.train_single_model(config="3d_fullres", fold=fold, val="")
+
+    train_api(nnunet_root_dir, dataset_name_or_id, experiment_name, trainer_class_name, run_with_bundle, bundle_root, skip_training, continue_training, fold, tracking_uri, client_name, resume_epoch)

+    validation_summary_dict, labels = validation_api(nnunet_root_dir, dataset_name_or_id, trainer_class_name, nnunet_plans_name, fold)
     if mlflow_token is not None:
         os.environ["MLFLOW_TRACKING_TOKEN"] = mlflow_token
     if tracking_uri is not None:

@@ -193,21 +167,7 @@ def train(

     runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"])

-    dataset_file = os.path.join(
-        runner.nnunet_results,
-        runner.dataset_name,
-        f"{trainer_class_name}__{nnunet_plans_name}__3d_fullres",
-        "dataset.json",
-    )
-
-    with open(dataset_file, "r") as f:
-        dataset_dict = json.load(f)
-    labels = dataset_dict["labels"]
-    labels = {str(v): k for k, v in labels.items()}
-
-    validation_summary_dict = compute_validation_metrics(str(Path(runner.nnunet_raw).joinpath(runner.dataset_name, "labelsTr")),
-                                                         str(Path(runner.nnunet_results).joinpath(runner.dataset_name,f"{trainer_class_name}__{nnunet_plans_name}__3d_fullres", f"fold_{fold}", "validation")),
-                                                         len(labels)-1)
+

     if len(runs) == 0:
         with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}):
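
Note: taken together, the two hunks above reduce the body of train() to a pair of delegating calls. The following is a minimal sketch of the resulting flow, reconstructed from the added lines; the wrapper name train_flow is hypothetical, and the actual bodies of train_api and validation_api live in monai.nvflare.utils and are not shown in this commit.

    # Sketch only: argument order follows the calls added in the diff above.
    from monai.nvflare.utils import train_api, validation_api

    def train_flow(nnunet_root_dir, dataset_name_or_id, experiment_name, trainer_class_name,
                   run_with_bundle, bundle_root, skip_training, continue_training, fold,
                   tracking_uri, client_name, resume_epoch, nnunet_plans_name):
        # Training (plain nnUNetV2Runner run or MONAI bundle execution) is now encapsulated in train_api.
        train_api(nnunet_root_dir, dataset_name_or_id, experiment_name, trainer_class_name,
                  run_with_bundle, bundle_root, skip_training, continue_training, fold,
                  tracking_uri, client_name, resume_epoch)
        # validation_api replaces the inline dataset.json parsing and compute_validation_metrics call,
        # returning the validation summary and the label dictionary used for MLflow logging.
        validation_summary_dict, labels = validation_api(
            nnunet_root_dir, dataset_name_or_id, trainer_class_name, nnunet_plans_name, fold
        )
        return validation_summary_dict, labels
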

@@ -569,112 +529,7 @@ def prepare_bundle(bundle_config, train_extra_configs=None):
     None
     """

-    with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.yaml")) as f:
-        train_config = yaml.safe_load(f)
-    train_config["bundle_root"] = bundle_config["bundle_root"]
-    train_config["tracking_uri"] = bundle_config["tracking_uri"]
-    train_config["mlflow_experiment_name"] = bundle_config["mlflow_experiment_name"]
-    train_config["mlflow_run_name"] = bundle_config["mlflow_run_name"]
-
-    train_config["dataset_name_or_id"] = bundle_config["dataset_name_or_id"]
-    train_config["data_src_cfg"] = "$@nnunet_root_folder+'/Task'+@dataset_name_or_id+'_data_src_cfg.yaml'"
-    train_config["nnunet_root_folder"] = "."
-    train_config["runner"] = {
-        "_target_": "nnUNetV2Runner",
-        "input_config": "$@data_src_cfg",
-        "trainer_class_name": "@nnunet_trainer_class_name",
-        "work_dir": "@nnunet_root_folder",
-    }
-
-    train_config["network"] = "$@nnunet_trainer.network._orig_mod"
-
-    train_handlers = train_config["train_handlers"]["handlers"]
-
-    for idx, handler in enumerate(train_handlers):
-        if handler["_target_"] == "ValidationHandler":
-            train_handlers.pop(idx)
-            break
-
-    train_config["train_handlers"]["handlers"] = train_handlers
-
-    if train_extra_configs is not None and "resume_epoch" in train_extra_configs:
-        resume_epoch = train_extra_configs["resume_epoch"]
-        train_config["initialize"] = [
-            "$monai.utils.set_determinism(seed=123)",
-            "[email protected]_name_or_id",
-            f"$src.trainer.reload_checkpoint(@train#trainer, {resume_epoch}, @iterations, @ckpt_dir, @lr_scheduler)",
-        ]
-    else:
-        train_config["initialize"] = ["$monai.utils.set_determinism(seed=123)", "[email protected]_name_or_id"]
-
-    if train_extra_configs is not None:
-        for key in train_extra_configs:
-            if key != "resume_epoch":
-                train_config[key] = train_extra_configs[key]
-
-    if "Val_Dice" in train_config["val_key_metric"]:
-        train_config["val_key_metric"] = {"Val_Dice_Local": train_config["val_key_metric"]["Val_Dice"]}
-
-    if "Val_Dice_per_class" in train_config["val_additional_metrics"]:
-        train_config["val_additional_metrics"] = {
-            "Val_Dice_per_class_Local": train_config["val_additional_metrics"]["Val_Dice_per_class"]
-        }
-    if "nnunet_plans_identifier" in bundle_config:
-        train_config["nnunet_plans_identifier"] = bundle_config["nnunet_plans_identifier"]
-
-    if "nnunet_trainer_class_name" in bundle_config:
-        train_config["nnunet_trainer_class_name"] = bundle_config["nnunet_trainer_class_name"]
-
-
-    train_config["label_dict"] = {
-        "0": "background",
-    }
-
-    for k, v in bundle_config["label_dict"].items():
-        if k != "0":
-            train_config["label_dict"][str(v)] = k
-
-
-    if "region_based" in train_extra_configs:
-        train_config["train_postprocessing_label_based"] = train_config["train_postprocessing"]
-        train_config["train_postprocessing"] = train_config["train_postprocessing_region_based"]
-        train_config["val_additional_metrics"]["Val_Dice_per_class_Local"]["include_background"] = True
-        train_config["train_additional_metrics"]["Train_Dice_per_class"]["include_background"] = True
-
-
-    train_config["num_classes"] = len(train_config["label_dict"])
-    with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.json"), "w") as f:
-        json.dump(train_config, f)
-
-    with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.yaml"), "w") as f:
-        yaml.dump(train_config, f)
-
-    if not Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml").exists():
-        shutil.copy(
-            Path(bundle_config["bundle_root"]).joinpath("nnUNet", "evaluator", "evaluator.yaml"),
-            Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml"),
-        )
-
-    with open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml")) as f:
-        evaluate_config = yaml.safe_load(f)
-    evaluate_config["bundle_root"] = bundle_config["bundle_root"]
-
-    evaluate_config["tracking_uri"] = bundle_config["tracking_uri"]
-    evaluate_config["mlflow_experiment_name"] = bundle_config["mlflow_experiment_name"]
-    evaluate_config["mlflow_run_name"] = bundle_config["mlflow_run_name"]
-
-    if "nnunet_plans_identifier" in bundle_config:
-        evaluate_config["nnunet_plans_identifier"] = bundle_config["nnunet_plans_identifier"]
-    if "nnunet_trainer_class_name" in bundle_config:
-        evaluate_config["nnunet_trainer_class_name"] = bundle_config["nnunet_trainer_class_name"]
-
-    with open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.json"), "w") as f:
-        json.dump(evaluate_config, f)
-
-    with open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml"), "w") as f:
-        yaml.dump(evaluate_config, f)
-
-    return {"evaluate_config": evaluate_config, "train_config": train_config}
+    prepare_bundle_api(bundle_config, train_extra_configs=train_extra_configs, is_federated=True)

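
Note: prepare_bundle now simply forwards to prepare_bundle_api with is_federated=True. The deleted body shows which keys of the configuration dictionary are consumed, so a plausible call is sketched below; the key names come from the removed code, while every value is an illustrative placeholder rather than something taken from this commit.

    # Sketch only: key names mirror the deleted body above; all values are hypothetical.
    from monai.nvflare.utils import prepare_bundle_api

    bundle_config = {
        "bundle_root": "/workspace/bundles/nnunet_bundle",   # hypothetical bundle path
        "tracking_uri": "http://mlflow.example.com:5000",    # hypothetical MLflow server
        "mlflow_experiment_name": "fl_experiment",           # hypothetical experiment name
        "mlflow_run_name": "run_site-1",                     # hypothetical run name
        "dataset_name_or_id": "009",                         # hypothetical nnUNet dataset id
        "label_dict": {"background": 0, "organ": 1},         # hypothetical label mapping
        # Optional keys forwarded into the train/evaluate configs when present:
        # "nnunet_plans_identifier", "nnunet_trainer_class_name"
    }

    prepare_bundle_api(bundle_config, train_extra_configs=None, is_federated=True)
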


@@ -726,71 +581,14 @@ def finalize_bundle(bundle_root, nnunet_root_dir=None, validate_with_nnunet=True
     trains a single model, and logs validation metrics to MLflow.
     - The function creates and saves nnUNet-compatible checkpoints in the `models` directory.
     """
-    print("Finalizing bundle...")
-    if nnunet_root_dir is None:
-        raise ValueError("nnunet_root_dir must be provided if validate_with_nnunet is True")
-    if not Path(bundle_root).joinpath("models", "plans.json").exists():
-        raise ValueError("plans.json file not found in the models directory of the bundle")
-    if not Path(bundle_root).joinpath("models", "dataset.json").exists():
-        raise ValueError("dataset.json file not found in the models directory of the bundle")
-
-    print("Converting bundle to nnUNet format...")
-
-    with open(Path(bundle_root).joinpath("models","plans.json"),"r") as f:
-        plans = json.load(f)
-
-    with open(Path(bundle_root).joinpath("configs","plans.yaml"),"w") as f:
-        yaml.dump({"plans": plans}, f)
-
-    with open(Path(bundle_root).joinpath("models","dataset.json"),"r") as f:
-        dataset_json = json.load(f)
-
-    with open(Path(bundle_root).joinpath("configs","dataset.yaml"),"w") as f:
-        yaml.dump({"dataset_json": dataset_json}, f)
-
-    checkpoint = {
-        "trainer_name": trainer_class_name,
-        "inference_allowed_mirroring_axes": (0, 1, 2),
-        "init_args": {
-            "configuration": "3d_fullres",
-        }
-    }
-
-    torch.save(checkpoint, Path(bundle_root).joinpath("models","nnunet_checkpoint.pth"))
-
-    checkpoint_dict = torch.load(Path(bundle_root).joinpath("models",f"fold_{fold}","FL_global_model.pt"))
-
-    new_checkpoint_dict = {}
-    new_checkpoint_dict["network_weights"] = checkpoint_dict["model"]
-    torch.save(new_checkpoint_dict, Path(bundle_root).joinpath("models",f"fold_{fold}","checkpoint_epoch=1000.pt"))
-    torch.save(new_checkpoint_dict, Path(bundle_root).joinpath("models",f"fold_{fold}","checkpoint_key_metric=1.0.pt"))
+    finalize_bundle_api(nnunet_root_dir, bundle_root, trainer_class_name, fold)
+

     if validate_with_nnunet:
-        data_src_cfg = os.path.join(nnunet_root_dir, f"Task{dataset_name_or_id}_data_src_cfg.yaml")
-        runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir)
-
         nnunet_config = {"dataset_name_or_id": dataset_name_or_id, "nnunet_trainer": trainer_class_name}
         convert_monai_bundle_to_nnunet(nnunet_config, bundle_root)
+        validation_summary_dict, labels = validation_api(nnunet_root_dir, dataset_name_or_id, trainer_class_name, nnunet_plans_name, fold)

-        runner.train_single_model(config="3d_fullres", fold=fold, val="")
-        from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
-        dataset_name = maybe_convert_to_dataset_name(int(dataset_name_or_id))
-
-        dataset_file = os.path.join(
-            runner.nnunet_results,
-            runner.dataset_name,
-            f"{trainer_class_name}__{nnunet_plans_name}__3d_fullres",
-            "dataset.json",
-        )
-
-        with open(dataset_file, "r") as f:
-            dataset_dict = json.load(f)
-        labels = dataset_dict["labels"]
-        labels = {str(v): k for k, v in labels.items()}
-
-        validation_summary_dict = compute_validation_metrics(str(Path(runner.nnunet_raw).joinpath(dataset_name, "labelsTr")),
-                                                             str(Path(runner.nnunet_results).joinpath(dataset_name,f"{trainer_class_name}__{nnunet_plans_name}__3d_fullres", f"fold_{fold}", "validation")),
-                                                             len(labels)-1)

     if mlflow_token is not None:
         os.environ["MLFLOW_TRACKING_TOKEN"] = mlflow_token
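
Note: finalize_bundle_api absorbs the bundle-to-nnUNet export that was previously inlined. The sketch below covers only the checkpoint re-wrapping portion, reconstructed from the removed lines (the plans.json/dataset.json-to-YAML export is omitted for brevity); the helper name convert_fl_checkpoint is hypothetical, while the file names and dictionary keys follow the deleted code.

    # Sketch only: mirrors the checkpoint handling removed in the hunk above.
    from pathlib import Path
    import torch

    def convert_fl_checkpoint(bundle_root: str, trainer_class_name: str, fold: int) -> None:
        models_dir = Path(bundle_root).joinpath("models")

        # Minimal nnUNet checkpoint metadata, as written by the removed code.
        checkpoint = {
            "trainer_name": trainer_class_name,
            "inference_allowed_mirroring_axes": (0, 1, 2),
            "init_args": {"configuration": "3d_fullres"},
        }
        torch.save(checkpoint, models_dir.joinpath("nnunet_checkpoint.pth"))

        # Re-wrap the aggregated FL weights so nnUNet's loader finds them under "network_weights".
        fl_ckpt = torch.load(models_dir.joinpath(f"fold_{fold}", "FL_global_model.pt"))
        new_ckpt = {"network_weights": fl_ckpt["model"]}
        torch.save(new_ckpt, models_dir.joinpath(f"fold_{fold}", "checkpoint_epoch=1000.pt"))
        torch.save(new_ckpt, models_dir.joinpath(f"fold_{fold}", "checkpoint_key_metric=1.0.pt"))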
