Skip to content

Commit c8104f4

Browse files
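Enhance prepare_data_folder_api to support the 'monai-label' dataset format and update its dataroot handling; in prepare_bundle_api, add a docstring, guard against train_extra_configs being None, and log dataset_name_or_id to the MLflow params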
1 parent 84d25c6 commit c8104f4

File tree

1 file changed

+61
-8
lines changed

1 file changed

+61
-8
lines changed

monai/nvflare/utils.py

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -199,14 +199,21 @@ def prepare_data_folder_api(data_dir,
199199
modality_id = list(modality_dict.keys())[0]
200200
case_id = Path(case[modality_id]).name[: -len(modality_dict[modality_id])]
201201
data_list["training"][idx]["label"] = str(Path("labelsTr").joinpath(case_id + modality_dict["label"]))
202+
elif dataset_format == "monai-label":
203+
data_list = data_dir
202204
else:
203205
raise ValueError("Dataset format not supported")
204206

205207
for idx, train_case in enumerate(data_list["training"]):
206208
for modality_id in modality_dict:
207-
data_list["training"][idx][modality_id + "_is_file"] = (
208-
Path(data_dir).joinpath(data_list["training"][idx][modality_id]).is_file()
209-
)
209+
if dataset_format == "monai-label":
210+
data_list["training"][idx][modality_id + "_is_file"] = (
211+
Path(data_list["training"][idx][modality_id]).is_file()
212+
)
213+
else:
214+
data_list["training"][idx][modality_id + "_is_file"] = (
215+
Path(data_dir).joinpath(data_list["training"][idx][modality_id]).is_file()
216+
)
210217
if "image" not in data_list["training"][idx] and modality_id != "label":
211218
data_list["training"][idx]["image"] = data_list["training"][idx][modality_id]
212219
data_list["training"][idx]["fold"] = 0
@@ -222,7 +229,13 @@ def prepare_data_folder_api(data_dir,
222229
for j in range(fold_size):
223230
data_list["training"][i * fold_size + j]["fold"] = i
224231

225-
datalist_file = Path(data_dir).joinpath(f"{experiment_name}_folds.json")
232+
if dataset_format == "monai-label":
233+
monai_label_data_dir = Path(data_list["training"][0]["image"]).parent
234+
datalist_file = Path(monai_label_data_dir).joinpath(f"{experiment_name}_folds.json")
235+
dataroot = str(monai_label_data_dir)
236+
else:
237+
datalist_file = Path(data_dir).joinpath(f"{experiment_name}_folds.json")
238+
dataroot = str(data_dir)
226239
with open(datalist_file, "w", encoding="utf-8") as f:
227240
json.dump(data_list, f, ensure_ascii=False, indent=4)
228241

@@ -236,11 +249,14 @@ def prepare_data_folder_api(data_dir,
236249
"modality": modality_list,
237250
"dataset_name_or_id": dataset_name_or_id,
238251
"datalist": str(datalist_file),
239-
"dataroot": str(data_dir),
252+
"dataroot": dataroot,
240253
}
241254
if labels is not None:
242255
print("Labels: ", labels)
243256
data_src["labels"] = labels
257+
for label in labels:
258+
if isinstance(labels[label], str):
259+
labels[label] = labels[label].split(",")
244260
if regions_class_order is not None:
245261
data_src["regions_class_order"] = regions_class_order
246262

@@ -429,6 +445,43 @@ def plan_and_preprocess_api(nnunet_root_dir, dataset_name_or_id, trainer_class_n
429445
return nnunet_plans
430446

431447
def prepare_bundle_api(bundle_config, train_extra_configs=None, is_federated=False):
448+
"""
449+
Prepare and update MONAI bundle configuration files for training and evaluation, supporting both standard and federated workflows.
450+
This function loads, modifies, and saves configuration files (YAML/JSON) for MONAI nnUNet bundles, injecting runtime parameters,
451+
handling federated learning specifics, and updating label dictionaries and metrics. It also manages MLflow tracking parameters
452+
and ensures all necessary configuration files are present and up-to-date.
453+
Parameters
454+
----------
455+
bundle_config : dict
456+
Dictionary containing bundle configuration parameters. Expected keys:
457+
- bundle_root (str): Root directory of the bundle.
458+
- tracking_uri (str): URI for MLflow tracking.
459+
- mlflow_experiment_name (str): MLflow experiment name.
460+
- mlflow_run_name (str): MLflow run name.
461+
- dataset_name_or_id (str): Dataset identifier or name.
462+
- label_dict (dict): Mapping of label indices to label names.
463+
- nnunet_plans_identifier (str, optional): Identifier for nnUNet plans.
464+
- nnunet_trainer_class_name (str, optional): Name of the nnUNet trainer class.
465+
- dataset_name (str, optional): Human-readable dataset name.
466+
train_extra_configs : dict, optional
467+
Additional training configuration parameters. May include:
468+
- resume_epoch (int): Epoch to resume training from.
469+
- region_based (bool): Whether to use region-based postprocessing and metrics.
470+
Any other keys will be injected into the training configuration.
471+
is_federated : bool, default=False
472+
Whether to prepare the configuration for federated learning.
473+
Returns
474+
-------
475+
dict
476+
Dictionary containing the updated configuration objects:
477+
- "evaluate_config": The evaluation configuration dictionary.
478+
- "train_config": The training configuration dictionary.
479+
Notes
480+
-----
481+
- This function modifies and overwrites configuration files in-place within the bundle directory.
482+
- Handles both standard and federated learning scenarios, including metric renaming and handler removal for federated mode.
483+
- Updates label dictionaries and ensures consistency across all configuration files.
484+
"""
432485
with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.yaml")) as f:
433486
train_config = yaml.safe_load(f)
434487
train_config["bundle_root"] = bundle_config["bundle_root"]
@@ -501,7 +554,7 @@ def prepare_bundle_api(bundle_config, train_extra_configs=None, is_federated=Fal
501554
else:
502555
train_config["train"]["train_data"] = "$[{'case_identifier':k} for k in @nnunet_trainer.dataloader_train.generator._data.identifiers]"
503556

504-
if "region_based" in train_extra_configs:
557+
if train_extra_configs is not None and "region_based" in train_extra_configs:
505558
if "train_postprocessing_label_based" not in train_config:
506559
train_config["train_postprocessing_label_based"] = train_config["train_postprocessing"]
507560
train_config["train_postprocessing"] = train_config["train_postprocessing_region_based"]
@@ -563,8 +616,8 @@ def prepare_bundle_api(bundle_config, train_extra_configs=None, is_federated=Fal
563616
if "nnunet_trainer_class_name" in bundle_config:
564617
mlflow_params["nnunet_trainer_class_name"] = bundle_config["nnunet_trainer_class_name"]
565618

566-
if "dataset_name" in bundle_config:
567-
mlflow_params["dataset_name"] = bundle_config["dataset_name"]
619+
if "dataset_name_or_id" in bundle_config:
620+
mlflow_params["dataset_name_or_id"] = bundle_config["dataset_name_or_id"]
568621

569622
with open(Path(bundle_config["bundle_root"]).joinpath("nnUNet", "params.yaml"), "w") as f:
570623
yaml.dump(mlflow_params, f)

0 commit comments

Comments
 (0)