@@ -199,14 +199,21 @@ def prepare_data_folder_api(data_dir,
             modality_id = list(modality_dict.keys())[0]
             case_id = Path(case[modality_id]).name[: -len(modality_dict[modality_id])]
             data_list["training"][idx]["label"] = str(Path("labelsTr").joinpath(case_id + modality_dict["label"]))
+    elif dataset_format == "monai-label":
+        data_list = data_dir
     else:
         raise ValueError("Dataset format not supported")
 
     for idx, train_case in enumerate(data_list["training"]):
         for modality_id in modality_dict:
-            data_list["training"][idx][modality_id + "_is_file"] = (
-                Path(data_dir).joinpath(data_list["training"][idx][modality_id]).is_file()
-            )
+            if dataset_format == "monai-label":
+                data_list["training"][idx][modality_id + "_is_file"] = (
+                    Path(data_list["training"][idx][modality_id]).is_file()
+                )
+            else:
+                data_list["training"][idx][modality_id + "_is_file"] = (
+                    Path(data_dir).joinpath(data_list["training"][idx][modality_id]).is_file()
+                )
             if "image" not in data_list["training"][idx] and modality_id != "label":
                 data_list["training"][idx]["image"] = data_list["training"][idx][modality_id]
         data_list["training"][idx]["fold"] = 0
@@ -222,7 +229,13 @@ def prepare_data_folder_api(data_dir,
         for j in range(fold_size):
             data_list["training"][i * fold_size + j]["fold"] = i
 
-    datalist_file = Path(data_dir).joinpath(f"{experiment_name}_folds.json")
+    if dataset_format == "monai-label":
+        monai_label_data_dir = Path(data_list["training"][0]["image"]).parent
+        datalist_file = Path(monai_label_data_dir).joinpath(f"{experiment_name}_folds.json")
+        dataroot = str(monai_label_data_dir)
+    else:
+        datalist_file = Path(data_dir).joinpath(f"{experiment_name}_folds.json")
+        dataroot = str(data_dir)
     with open(datalist_file, "w", encoding="utf-8") as f:
         json.dump(data_list, f, ensure_ascii=False, indent=4)
 
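For the `monai-label` case, the folds datalist and `dataroot` are derived from the parent folder of the first training image rather than from `data_dir`. A quick illustration with made-up paths:

from pathlib import Path

# Hypothetical monai-label datalist entry; image paths are absolute.
data_list = {"training": [{"image": "/workspace/datastore/img_0001.nii.gz", "fold": 0}]}
experiment_name = "spleen_experiment"  # placeholder

monai_label_data_dir = Path(data_list["training"][0]["image"]).parent
datalist_file = monai_label_data_dir.joinpath(f"{experiment_name}_folds.json")

print(datalist_file)              # /workspace/datastore/spleen_experiment_folds.json
print(str(monai_label_data_dir))  # value used as "dataroot" in data_src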
@@ -236,11 +249,14 @@ def prepare_data_folder_api(data_dir,
         "modality": modality_list,
         "dataset_name_or_id": dataset_name_or_id,
         "datalist": str(datalist_file),
-        "dataroot": str(data_dir),
+        "dataroot": dataroot,
     }
     if labels is not None:
         print("Labels: ", labels)
         data_src["labels"] = labels
+        for label in labels:
+            if isinstance(labels[label], str):
+                labels[label] = labels[label].split(",")
     if regions_class_order is not None:
         data_src["regions_class_order"] = regions_class_order
 
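The added loop normalises region-based label entries: a comma-separated string of label indices becomes a list, while plain integer labels are left untouched. For example (label names and indices are illustrative only):

# Hypothetical label dictionary mixing plain labels and region strings.
labels = {"background": 0, "tumor_core": "2,3", "whole_tumor": "1,2,3"}

for label in labels:
    if isinstance(labels[label], str):
        labels[label] = labels[label].split(",")

print(labels)
# {'background': 0, 'tumor_core': ['2', '3'], 'whole_tumor': ['1', '2', '3']}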
@@ -429,6 +445,43 @@ def plan_and_preprocess_api(nnunet_root_dir, dataset_name_or_id, trainer_class_n
     return nnunet_plans
 
 def prepare_bundle_api(bundle_config, train_extra_configs=None, is_federated=False):
+    """
+    Prepare and update MONAI bundle configuration files for training and evaluation, supporting both standard and federated workflows.
+    This function loads, modifies, and saves configuration files (YAML/JSON) for MONAI nnUNet bundles, injecting runtime parameters,
+    handling federated learning specifics, and updating label dictionaries and metrics. It also manages MLflow tracking parameters
+    and ensures all necessary configuration files are present and up-to-date.
+    Parameters
+    ----------
+    bundle_config : dict
+        Dictionary containing bundle configuration parameters. Expected keys:
+        - bundle_root (str): Root directory of the bundle.
+        - tracking_uri (str): URI for MLflow tracking.
+        - mlflow_experiment_name (str): MLflow experiment name.
+        - mlflow_run_name (str): MLflow run name.
+        - dataset_name_or_id (str): Dataset identifier or name.
+        - label_dict (dict): Mapping of label indices to label names.
+        - nnunet_plans_identifier (str, optional): Identifier for nnUNet plans.
+        - nnunet_trainer_class_name (str, optional): Name of the nnUNet trainer class.
+        - dataset_name (str, optional): Human-readable dataset name.
+    train_extra_configs : dict, optional
+        Additional training configuration parameters. May include:
+        - resume_epoch (int): Epoch to resume training from.
+        - region_based (bool): Whether to use region-based postprocessing and metrics.
+        Any other keys will be injected into the training configuration.
+    is_federated : bool, default=False
+        Whether to prepare the configuration for federated learning.
+    Returns
+    -------
+    dict
+        Dictionary containing the updated configuration objects:
+        - "evaluate_config": The evaluation configuration dictionary.
+        - "train_config": The training configuration dictionary.
+    Notes
+    -----
+    - This function modifies and overwrites configuration files in-place within the bundle directory.
+    - Handles both standard and federated learning scenarios, including metric renaming and handler removal for federated mode.
+    - Updates label dictionaries and ensures consistency across all configuration files.
+    """
     with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.yaml")) as f:
         train_config = yaml.safe_load(f)
     train_config["bundle_root"] = bundle_config["bundle_root"]
@@ -501,7 +554,7 @@ def prepare_bundle_api(bundle_config, train_extra_configs=None, is_federated=Fal
     else:
         train_config["train"]["train_data"] = "$[{'case_identifier':k} for k in @nnunet_trainer.dataloader_train.generator._data.identifiers]"
 
-    if "region_based" in train_extra_configs:
+    if train_extra_configs is not None and "region_based" in train_extra_configs:
         if "train_postprocessing_label_based" not in train_config:
             train_config["train_postprocessing_label_based"] = train_config["train_postprocessing"]
         train_config["train_postprocessing"] = train_config["train_postprocessing_region_based"]
@@ -563,8 +616,8 @@ def prepare_bundle_api(bundle_config, train_extra_configs=None, is_federated=Fal
     if "nnunet_trainer_class_name" in bundle_config:
         mlflow_params["nnunet_trainer_class_name"] = bundle_config["nnunet_trainer_class_name"]
 
-    if "dataset_name" in bundle_config:
-        mlflow_params["dataset_name"] = bundle_config["dataset_name"]
+    if "dataset_name_or_id" in bundle_config:
+        mlflow_params["dataset_name_or_id"] = bundle_config["dataset_name_or_id"]
 
     with open(Path(bundle_config["bundle_root"]).joinpath("nnUNet", "params.yaml"), "w") as f:
         yaml.dump(mlflow_params, f)