Skip to content

Commit fd12ece

Browse files
Add continue_training parameter to nnUNetExecutor and update job config timeout
1 parent 47798af commit fd12ece

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

monai/nvflare/nnunet_executor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def __init__(
124124
modality_list=None,
125125
train_extra_configs=None,
126126
exclude_vars=None,
127+
continue_training=False,
127128
):
128129
super().__init__()
129130

@@ -147,6 +148,7 @@ def __init__(
147148
self.bundle_root = bundle_root
148149
self.train_extra_configs = train_extra_configs
149150
self.modality_list = modality_list
151+
self.continue_training = continue_training
150152

151153
def handle_event(self, event_type: str, fl_ctx: FLContext):
152154
if event_type == EventType.START_RUN:
@@ -305,6 +307,7 @@ def train(self):
305307
dataset_name_or_id=self.nnunet_config["dataset_name_or_id"],
306308
run_with_bundle=True if self.bundle_root is not None else False,
307309
bundle_root=self.bundle_root,
310+
continue_training=self.continue_training
308311
)
309312
outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=validation_summary, meta={})
310313
return outgoing_dxo.to_shareable()

monai/nvflare/nvflare_generate_job_configs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ def train_config(clients, experiment, root_dir, script_dir, nvflare_exec):
665665
"min_responses_required": 0,
666666
"wait_time_after_min_received": 10,
667667
"task_name": task_name,
668-
"timeout": 6000,
668+
"timeout": 600000,
669669
},
670670
}
671671
],
@@ -699,6 +699,7 @@ def train_config(clients, experiment, root_dir, script_dir, nvflare_exec):
699699
},
700700
"client_name": clients[client_id]["client_name"],
701701
"tracking_uri": experiment["tracking_uri"],
702+
"continue_training": experiment["continue_training"],
702703
},
703704
},
704705
}

monai/nvflare/nvflare_nnunet.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def train(
4646
fold=0,
4747
bundle_root=None,
4848
mlflow_token=None,
49+
continue_training=False,
50+
resume_epoch="latest",
4951
):
5052
"""
5153
@@ -75,6 +77,10 @@ def train(
7577
Root directory for MONAI bundle, by default None.
7678
mlflow_token : str, optional
7779
Token for MLflow authentication, by default None.
80+
continue_training : bool, optional
81+
Whether to continue training from a checkpoint, by default False.
82+
resume_epoch : int or str, optional
83+
Epoch to resume training from (an epoch number or "latest"), by default "latest".
7884
7985
Returns
8086
-------
@@ -89,15 +95,19 @@ def train(
8995
else:
9096
os.environ["BUNDLE_ROOT"] = bundle_root
9197
os.environ["PYTHONPATH"] = os.environ["PYTHONPATH"] + ":" + bundle_root
98+
config_files = os.path.join(bundle_root, "configs", "train_resume.yaml")
99+
if continue_training:
100+
config_files = [os.path.join(bundle_root, "configs", "train.yaml"), os.path.join(bundle_root, "configs", "train_continue.yaml")]
92101
monai.bundle.run(
93-
config_file=Path(bundle_root).joinpath("configs/train.yaml"),
102+
config_file=config_files,
94103
bundle_root=bundle_root,
95104
nnunet_trainer_class_name=trainer_class_name,
96105
mlflow_experiment_name=experiment_name,
97106
mlflow_run_name="run_" + client_name,
98107
tracking_uri=tracking_uri,
99108
fold_id=fold,
100109
nnunet_root_folder=nnunet_root_dir,
110+
reload_checkpoint_epoch=resume_epoch
101111
)
102112
nnunet_config = {"dataset_name_or_id": dataset_name_or_id, "nnunet_trainer": trainer_class_name}
103113
convert_monai_bundle_to_nnunet(nnunet_config, bundle_root)
@@ -619,6 +629,7 @@ def prepare_bundle(bundle_config, train_extra_configs=None):
619629
train_config["mlflow_run_name"] = bundle_config["mlflow_run_name"]
620630

621631
train_config["data_src_cfg"] = "$@nnunet_root_folder+'/data_src_cfg.yaml'"
632+
train_config["nnunet_root_folder"] = "."
622633
train_config["runner"] = {
623634
"_target_": "nnUNetV2Runner",
624635
"input_config": "$@data_src_cfg",

0 commit comments

Comments
 (0)