@@ -46,6 +46,8 @@ def train(
46
46
fold = 0 ,
47
47
bundle_root = None ,
48
48
mlflow_token = None ,
49
+ continue_training = False ,
50
+ resume_epoch = "latest" ,
49
51
):
50
52
"""
51
53
@@ -75,6 +77,10 @@ def train(
75
77
Root directory for MONAI bundle, by default None.
76
78
mlflow_token : str, optional
77
79
Token for MLflow authentication, by default None.
80
+ continue_training : bool, optional
81
+ Whether to continue training from a checkpoint, by default False.
82
+ resume_epoch : int, optional
83
+ Epoch to resume training from, by default "latest".
78
84
79
85
Returns
80
86
-------
@@ -89,15 +95,19 @@ def train(
89
95
else :
90
96
os .environ ["BUNDLE_ROOT" ] = bundle_root
91
97
os .environ ["PYTHONPATH" ] = os .environ ["PYTHONPATH" ] + ":" + bundle_root
98
+ config_files = os .path .join (bundle_root , "configs" , "train_resume.yaml" )
99
+ if continue_training :
100
+ config_files = [os .path .join (bundle_root , "configs" , "train.yaml" ), os .path .join (bundle_root , "configs" , "train_continue.yaml" )]
92
101
monai .bundle .run (
93
- config_file = Path ( bundle_root ). joinpath ( "configs/train.yaml" ) ,
102
+ config_file = config_files ,
94
103
bundle_root = bundle_root ,
95
104
nnunet_trainer_class_name = trainer_class_name ,
96
105
mlflow_experiment_name = experiment_name ,
97
106
mlflow_run_name = "run_" + client_name ,
98
107
tracking_uri = tracking_uri ,
99
108
fold_id = fold ,
100
109
nnunet_root_folder = nnunet_root_dir ,
110
+ reload_checkpoint_epoch = resume_epoch
101
111
)
102
112
nnunet_config = {"dataset_name_or_id" : dataset_name_or_id , "nnunet_trainer" : trainer_class_name }
103
113
convert_monai_bundle_to_nnunet (nnunet_config , bundle_root )
@@ -619,6 +629,7 @@ def prepare_bundle(bundle_config, train_extra_configs=None):
619
629
train_config ["mlflow_run_name" ] = bundle_config ["mlflow_run_name" ]
620
630
621
631
train_config ["data_src_cfg" ] = "$@nnunet_root_folder+'/data_src_cfg.yaml'"
632
+ train_config ["nnunet_root_folder" ] = "."
622
633
train_config ["runner" ] = {
623
634
"_target_" : "nnUNetV2Runner" ,
624
635
"input_config" : "$@data_src_cfg" ,
0 commit comments