from pyhocon import ConfigFactory
from pyhocon.converter import HOCONConverter

-from monai.nvflare.utils import prepare_data_folder_api, compute_validation_metrics, cross_site_evaluation_api, plan_and_preprocess_api
+from monai.nvflare.utils import prepare_data_folder_api, finalize_bundle_api, cross_site_evaluation_api, plan_and_preprocess_api, prepare_bundle_api, train_api, validation_api

@@ -146,36 +146,10 @@ def train(
    dict
        Dictionary containing validation summary metrics.
    """
-    data_src_cfg = os.path.join(nnunet_root_dir, f"Task{dataset_name_or_id}_data_src_cfg.yaml")
-    runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir)
-
-    if not skip_training:
-        if not run_with_bundle:
-            if continue_training:
-                runner.train_single_model(config="3d_fullres", fold=fold, c="")
-            else:
-                runner.train_single_model(config="3d_fullres", fold=fold)
-        else:
-            os.environ["BUNDLE_ROOT"] = bundle_root
-            os.environ["PYTHONPATH"] = os.environ["PYTHONPATH"] + ":" + bundle_root
-            config_files = os.path.join(bundle_root, "configs", "train.yaml")
-            if continue_training:
-                config_files = [os.path.join(bundle_root, "configs", "train.yaml"), os.path.join(bundle_root, "configs", "train_resume.yaml")]
-            monai.bundle.run(
-                config_file=config_files,
-                bundle_root=bundle_root,
-                nnunet_trainer_class_name=trainer_class_name,
-                mlflow_experiment_name=experiment_name,
-                mlflow_run_name="run_" + client_name,
-                tracking_uri=tracking_uri,
-                fold_id=fold,
-                nnunet_root_folder=nnunet_root_dir,
-                reload_checkpoint_epoch=resume_epoch
-            )
-            nnunet_config = {"dataset_name_or_id": dataset_name_or_id, "nnunet_trainer": trainer_class_name}
-            convert_monai_bundle_to_nnunet(nnunet_config, bundle_root)
-            runner.train_single_model(config="3d_fullres", fold=fold, val="")
+
+    train_api(nnunet_root_dir, dataset_name_or_id, experiment_name, trainer_class_name, run_with_bundle, bundle_root, skip_training, continue_training, fold, tracking_uri, client_name, resume_epoch)

+    validation_summary_dict, labels = validation_api(nnunet_root_dir, dataset_name_or_id, trainer_class_name, nnunet_plans_name, fold)

    if mlflow_token is not None:
        os.environ["MLFLOW_TRACKING_TOKEN"] = mlflow_token
    if tracking_uri is not None:
@@ -193,21 +167,7 @@ def train(

    runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"])

-    dataset_file = os.path.join(
-        runner.nnunet_results,
-        runner.dataset_name,
-        f"{trainer_class_name}__{nnunet_plans_name}__3d_fullres",
-        "dataset.json",
-    )
-
-    with open(dataset_file, "r") as f:
-        dataset_dict = json.load(f)
-    labels = dataset_dict["labels"]
-    labels = {str(v): k for k, v in labels.items()}
-
-    validation_summary_dict = compute_validation_metrics(
-        str(Path(runner.nnunet_raw).joinpath(runner.dataset_name, "labelsTr")),
-        str(Path(runner.nnunet_results).joinpath(runner.dataset_name, f"{trainer_class_name}__{nnunet_plans_name}__3d_fullres", f"fold_{fold}", "validation")),
-        len(labels) - 1,
-    )
+

    if len(runs) == 0:
        with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}):
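
For reference, a minimal client-side usage sketch of the two consolidated helpers introduced above. It relies only on the positional argument order visible in this hunk; every concrete value (paths, experiment and client names, dataset ID) is an illustrative assumption, not part of this PR.

```python
# Hypothetical usage sketch: argument order follows the calls in this diff;
# all concrete values are placeholders.
from monai.nvflare.utils import train_api, validation_api

nnunet_root_dir = "./nnunet_workdir"  # assumed nnU-Net working directory
dataset_name_or_id = "9"              # assumed dataset ID

train_api(
    nnunet_root_dir,
    dataset_name_or_id,
    "fl_experiment",          # experiment_name
    "nnUNetTrainer",          # trainer_class_name
    True,                     # run_with_bundle
    "./bundle",               # bundle_root
    False,                    # skip_training
    False,                    # continue_training
    0,                        # fold
    "http://localhost:5000",  # tracking_uri (MLflow)
    "site-1",                 # client_name
    None,                     # resume_epoch
)

# Returns the validation summary metrics plus the label map that the removed
# inline code used to rebuild from nnU-Net's dataset.json.
validation_summary_dict, labels = validation_api(
    nnunet_root_dir, dataset_name_or_id, "nnUNetTrainer", "nnUNetPlans", 0
)
```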
@@ -569,112 +529,7 @@ def prepare_bundle(bundle_config, train_extra_configs=None):
    None
    """

-    with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.yaml")) as f:
-        train_config = yaml.safe_load(f)
-    train_config["bundle_root"] = bundle_config["bundle_root"]
-    train_config["tracking_uri"] = bundle_config["tracking_uri"]
-    train_config["mlflow_experiment_name"] = bundle_config["mlflow_experiment_name"]
-    train_config["mlflow_run_name"] = bundle_config["mlflow_run_name"]
-
-    train_config["dataset_name_or_id"] = bundle_config["dataset_name_or_id"]
-    train_config["data_src_cfg"] = "$@nnunet_root_folder+'/Task'+@dataset_name_or_id+'_data_src_cfg.yaml'"
-    train_config["nnunet_root_folder"] = "."
-    train_config["runner"] = {
-        "_target_": "nnUNetV2Runner",
-        "input_config": "$@data_src_cfg",
-        "trainer_class_name": "@nnunet_trainer_class_name",
-        "work_dir": "@nnunet_root_folder",
-    }
-
-    train_config["network"] = "$@nnunet_trainer.network._orig_mod"
-
-    train_handlers = train_config["train_handlers"]["handlers"]
-
-    for idx, handler in enumerate(train_handlers):
-        if handler["_target_"] == "ValidationHandler":
-            train_handlers.pop(idx)
-            break
-
-    train_config["train_handlers"]["handlers"] = train_handlers
-
-    if train_extra_configs is not None and "resume_epoch" in train_extra_configs:
-        resume_epoch = train_extra_configs["resume_epoch"]
-        train_config["initialize"] = [
-            "$monai.utils.set_determinism(seed=123)",
-            f"$src.trainer.reload_checkpoint(@train#trainer, {resume_epoch}, @iterations, @ckpt_dir, @lr_scheduler)",
-        ]
-    else:
-        train_config["initialize"] = ["$monai.utils.set_determinism(seed=123)", "[email protected]_name_or_id"]
-
-    if train_extra_configs is not None:
-        for key in train_extra_configs:
-            if key != "resume_epoch":
-                train_config[key] = train_extra_configs[key]
-
-    if "Val_Dice" in train_config["val_key_metric"]:
-        train_config["val_key_metric"] = {"Val_Dice_Local": train_config["val_key_metric"]["Val_Dice"]}
-
-    if "Val_Dice_per_class" in train_config["val_additional_metrics"]:
-        train_config["val_additional_metrics"] = {
-            "Val_Dice_per_class_Local": train_config["val_additional_metrics"]["Val_Dice_per_class"]
-        }
-    if "nnunet_plans_identifier" in bundle_config:
-        train_config["nnunet_plans_identifier"] = bundle_config["nnunet_plans_identifier"]
-
-    if "nnunet_trainer_class_name" in bundle_config:
-        train_config["nnunet_trainer_class_name"] = bundle_config["nnunet_trainer_class_name"]
-
-    train_config["label_dict"] = {
-        "0": "background",
-    }
-
-    for k, v in bundle_config["label_dict"].items():
-        if k != "0":
-            train_config["label_dict"][str(v)] = k
-
-    if "region_based" in train_extra_configs:
-        train_config["train_postprocessing_label_based"] = train_config["train_postprocessing"]
-        train_config["train_postprocessing"] = train_config["train_postprocessing_region_based"]
-        train_config["val_additional_metrics"]["Val_Dice_per_class_Local"]["include_background"] = True
-        train_config["train_additional_metrics"]["Train_Dice_per_class"]["include_background"] = True
-
-    train_config["num_classes"] = len(train_config["label_dict"])
-    with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.json"), "w") as f:
-        json.dump(train_config, f)
-
-    with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.yaml"), "w") as f:
-        yaml.dump(train_config, f)
-
-    if not Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml").exists():
-        shutil.copy(
-            Path(bundle_config["bundle_root"]).joinpath("nnUNet", "evaluator", "evaluator.yaml"),
-            Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml"),
-        )
-
-    with open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml")) as f:
-        evaluate_config = yaml.safe_load(f)
-    evaluate_config["bundle_root"] = bundle_config["bundle_root"]
-
-    evaluate_config["tracking_uri"] = bundle_config["tracking_uri"]
-    evaluate_config["mlflow_experiment_name"] = bundle_config["mlflow_experiment_name"]
-    evaluate_config["mlflow_run_name"] = bundle_config["mlflow_run_name"]
-
-    if "nnunet_plans_identifier" in bundle_config:
-        evaluate_config["nnunet_plans_identifier"] = bundle_config["nnunet_plans_identifier"]
-    if "nnunet_trainer_class_name" in bundle_config:
-        evaluate_config["nnunet_trainer_class_name"] = bundle_config["nnunet_trainer_class_name"]
-
-    with open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.json"), "w") as f:
-        json.dump(evaluate_config, f)
-
-    with open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml"), "w") as f:
-        yaml.dump(evaluate_config, f)
-
-    return {"evaluate_config": evaluate_config, "train_config": train_config}
+    prepare_bundle_api(bundle_config, train_extra_configs=train_extra_configs, is_federated=True)

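A hedged sketch of the consolidated bundle-preparation call. Only the keyword arguments `train_extra_configs` and `is_federated` are confirmed by this hunk; the `bundle_config` keys mirror the fields the removed inline code read, and every value is illustrative.

```python
# Hypothetical call; bundle_config keys mirror those consumed by the removed
# inline code (bundle_root, tracking_uri, MLflow names, dataset ID, labels).
from monai.nvflare.utils import prepare_bundle_api

bundle_config = {
    "bundle_root": "./bundle",
    "tracking_uri": "http://localhost:5000",
    "mlflow_experiment_name": "fl_experiment",
    "mlflow_run_name": "run_site-1",
    "dataset_name_or_id": "9",
    "label_dict": {"0": "background", "spleen": 1},  # illustrative labels
}

# is_federated=True presumably keeps the federated-specific adjustments the
# removed code applied, e.g. dropping the local ValidationHandler.
prepare_bundle_api(bundle_config, train_extra_configs=None, is_federated=True)
```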
@@ -726,71 +581,14 @@ def finalize_bundle(bundle_root, nnunet_root_dir=None, validate_with_nnunet=True
    trains a single model, and logs validation metrics to MLflow.
-    The function creates and saves nnUNet-compatible checkpoints in the `models` directory.
    """
-    print("Finalizing bundle...")
-    if nnunet_root_dir is None:
-        raise ValueError("nnunet_root_dir must be provided if validate_with_nnunet is True")
-    if not Path(bundle_root).joinpath("models", "plans.json").exists():
-        raise ValueError("plans.json file not found in the models directory of the bundle")
-    if not Path(bundle_root).joinpath("models", "dataset.json").exists():
-        raise ValueError("dataset.json file not found in the models directory of the bundle")
-
-    print("Converting bundle to nnUNet format...")
-
-    with open(Path(bundle_root).joinpath("models", "plans.json"), "r") as f:
-        plans = json.load(f)
-
-    with open(Path(bundle_root).joinpath("configs", "plans.yaml"), "w") as f:
-        yaml.dump({"plans": plans}, f)
-
-    with open(Path(bundle_root).joinpath("models", "dataset.json"), "r") as f:
-        dataset_json = json.load(f)
-
-    with open(Path(bundle_root).joinpath("configs", "dataset.yaml"), "w") as f:
-        yaml.dump({"dataset_json": dataset_json}, f)
-
-    checkpoint = {
-        "trainer_name": trainer_class_name,
-        "inference_allowed_mirroring_axes": (0, 1, 2),
-        "init_args": {
-            "configuration": "3d_fullres",
-        }
-    }
-
-    torch.save(checkpoint, Path(bundle_root).joinpath("models", "nnunet_checkpoint.pth"))
-
-    checkpoint_dict = torch.load(Path(bundle_root).joinpath("models", f"fold_{fold}", "FL_global_model.pt"))
-
-    new_checkpoint_dict = {}
-    new_checkpoint_dict["network_weights"] = checkpoint_dict["model"]
-    torch.save(new_checkpoint_dict, Path(bundle_root).joinpath("models", f"fold_{fold}", "checkpoint_epoch=1000.pt"))
-    torch.save(new_checkpoint_dict, Path(bundle_root).joinpath("models", f"fold_{fold}", "checkpoint_key_metric=1.0.pt"))
+    finalize_bundle_api(nnunet_root_dir, bundle_root, trainer_class_name, fold)
+

    if validate_with_nnunet:
-        data_src_cfg = os.path.join(nnunet_root_dir, f"Task{dataset_name_or_id}_data_src_cfg.yaml")
-        runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir)
-
        nnunet_config = {"dataset_name_or_id": dataset_name_or_id, "nnunet_trainer": trainer_class_name}
        convert_monai_bundle_to_nnunet(nnunet_config, bundle_root)
+        validation_summary_dict, labels = validation_api(nnunet_root_dir, dataset_name_or_id, trainer_class_name, nnunet_plans_name, fold)

-        runner.train_single_model(config="3d_fullres", fold=fold, val="")
-        from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
-        dataset_name = maybe_convert_to_dataset_name(int(dataset_name_or_id))
-
-        dataset_file = os.path.join(
-            runner.nnunet_results,
-            runner.dataset_name,
-            f"{trainer_class_name}__{nnunet_plans_name}__3d_fullres",
-            "dataset.json",
-        )
-
-        with open(dataset_file, "r") as f:
-            dataset_dict = json.load(f)
-        labels = dataset_dict["labels"]
-        labels = {str(v): k for k, v in labels.items()}
-
-        validation_summary_dict = compute_validation_metrics(
-            str(Path(runner.nnunet_raw).joinpath(dataset_name, "labelsTr")),
-            str(Path(runner.nnunet_results).joinpath(dataset_name, f"{trainer_class_name}__{nnunet_plans_name}__3d_fullres", f"fold_{fold}", "validation")),
-            len(labels) - 1,
-        )

    if mlflow_token is not None:
        os.environ["MLFLOW_TRACKING_TOKEN"] = mlflow_token