
Commit 9f9261d

Add dataset_name parameter to MLflow logging in prepare_bundle_api
1 parent ac94578 commit 9f9261d

2 files changed: +49 −24 lines

monai/nvflare/nvflare_nnunet.py

Lines changed: 46 additions & 24 deletions
@@ -252,6 +252,7 @@ def plan_and_preprocess(
     mlflow_token=None,
     nnunet_plans_name="nnUNetPlans",
     trainer_class_name="nnUNetTrainer",
+    dataset_name=None,
 ):
     """
     Plan and preprocess the dataset using nnUNetV2Runner and log the plans to MLflow.
@@ -299,18 +300,23 @@ def plan_and_preprocess(
         print(e)
     mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id))
 
-    filter = f"""
-    tags."client" = "{client_name}"
-    """
+    run_name = f"run_plan_and_preprocess_{client_name}"
 
-    runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"])
+    runs = mlflow.search_runs(
+        experiment_names=[experiment_name],
+        filter_string=f"tags.mlflow.runName = '{run_name}'",
+        order_by=["start_time DESC"]
+    )
+    tags = {"client": client_name}
+    if dataset_name is not None:
+        tags["dataset_name"] = dataset_name
 
     if len(runs) == 0:
-        with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}):
+        with mlflow.start_run(run_name=f"run_plan_and_preprocess_{client_name}", tags=tags):
             mlflow.log_dict(nnunet_plans, nnunet_plans_name + ".json")
 
     else:
-        with mlflow.start_run(run_id=runs.iloc[0].run_id, tags={"client": client_name}):
+        with mlflow.start_run(run_id=runs.iloc[0].run_id, tags=tags):
             mlflow.log_dict(nnunet_plans, nnunet_plans_name + ".json")
 
     return nnunet_plans
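
Note: every function touched by this commit repeats the same pattern: search the experiment for the newest run whose mlflow.runName tag matches a function-specific run name, resume it if found, and otherwise start a fresh run carrying the client tag (plus dataset_name when given). The helper below is only an illustrative sketch of that pattern, not code from this commit; the name start_named_run is invented.

import mlflow

def start_named_run(experiment_name, run_name, client_name, dataset_name=None):
    """Resume the newest run with this runName, or start a new tagged run (illustrative only)."""
    runs = mlflow.search_runs(
        experiment_names=[experiment_name],
        filter_string=f"tags.mlflow.runName = '{run_name}'",
        order_by=["start_time DESC"],
    )
    tags = {"client": client_name}
    if dataset_name is not None:
        tags["dataset_name"] = dataset_name  # optional dataset tag introduced by this commit
    if len(runs) == 0:
        # No run with this name yet: create one with the tags attached
        return mlflow.start_run(run_name=run_name, tags=tags)
    # Resume the most recent matching run; the tags are merged into it
    return mlflow.start_run(run_id=runs.iloc[0].run_id, tags=tags)

Used as a context manager, for example: with start_named_run("exp", "run_plan_and_preprocess_site-1", "site-1", "MyDataset"): mlflow.log_dict(nnunet_plans, "nnUNetPlans.json").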
@@ -330,6 +336,7 @@ def prepare_data_folder(
     subfolder_suffix=None,
     patient_id_in_file_identifier=True,
     trainer_class_name="nnUNetTrainer",
+    dataset_name=None,
 ):
     """
     Prepare the data folder for nnUNet training and log the data to MLflow.
@@ -394,18 +401,23 @@ def prepare_data_folder(
         print(e)
     mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id))
 
-    filter = f"""
-    tags."client" = "{client_name}"
-    """
 
-    runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"])
+    run_name = f"run_prepare_{client_name}"
 
+    runs = mlflow.search_runs(
+        experiment_names=[experiment_name],
+        filter_string=f"tags.mlflow.runName = '{run_name}'",
+        order_by=["start_time DESC"]
+    )
+    tags = {"client": client_name}
+    if dataset_name is not None:
+        tags["dataset_name"] = dataset_name
     try:
         if len(runs) == 0:
-            with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}):
+            with mlflow.start_run(run_name=f"run_prepare_{client_name}", tags=tags):
                 mlflow.log_table(pd.DataFrame.from_records(data_list["training"]), f"{client_name}_train.json")
         else:
-            with mlflow.start_run(run_id=runs.iloc[0].run_id, tags={"client": client_name}):
+            with mlflow.start_run(run_id=runs.iloc[0].run_id, tags=tags):
                 mlflow.log_table(pd.DataFrame.from_records(data_list["training"]), f"{client_name}_train.json")
     except (BrokenPipeError, ConnectionError) as e:
         logging.error(f"Failed to log data to MLflow: {e}")
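
Note: prepare_data_folder logs the client's training list as an MLflow table and keeps the existing try/except so that a broken connection to the tracking server does not abort data preparation. A minimal, self-contained sketch of that step; the data_list content and client name below are invented for illustration.

import logging

import mlflow
import pandas as pd

# Invented example data; in the real function data_list comes from the prepared data folder
data_list = {"training": [{"image": "case_0001.nii.gz", "label": "case_0001_seg.nii.gz"}]}
client_name = "site-1"

try:
    with mlflow.start_run(run_name=f"run_prepare_{client_name}", tags={"client": client_name}):
        # Logged as a JSON artifact that the MLflow UI renders as a table
        mlflow.log_table(pd.DataFrame.from_records(data_list["training"]), f"{client_name}_train.json")
except (BrokenPipeError, ConnectionError) as e:
    # A flaky tracking server should not abort data preparation
    logging.error(f"Failed to log data to MLflow: {e}")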
@@ -537,7 +549,7 @@ def prepare_bundle(bundle_config, train_extra_configs=None):
 def finalize_bundle(bundle_root, nnunet_root_dir=None, validate_with_nnunet=True,
                     experiment_name=None, client_name=None, tracking_uri=None,
                     dataset_name_or_id=None, trainer_class_name="nnUNetTrainer",
-                    nnunet_plans_name="nnUNetPlans", fold=0, mlflow_token=None):
+                    nnunet_plans_name="nnUNetPlans", fold=0, mlflow_token=None, dataset_name=None):
     """
     Finalizes a MONAI bundle by converting model and dataset configurations to nnUNet format,
     saving checkpoints, and optionally validating the model using nnUNet.
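
Note: the new dataset_name keyword is optional and defaults to None, so existing callers of finalize_bundle are unaffected. A hypothetical call for illustration; all paths and names below are placeholders, assuming the module is importable as monai.nvflare.nvflare_nnunet.

from monai.nvflare.nvflare_nnunet import finalize_bundle

finalize_bundle(
    bundle_root="/workspace/bundles/my_bundle",   # placeholder path
    nnunet_root_dir="/workspace/nnunet",          # placeholder path
    experiment_name="my_experiment",
    client_name="site-1",
    tracking_uri="http://localhost:5000",         # placeholder tracking server
    dataset_name_or_id="001",
    fold=0,
    dataset_name="MyDataset",                     # new: recorded as the dataset_name run tag
)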
@@ -601,23 +613,28 @@ def finalize_bundle(bundle_root, nnunet_root_dir=None, validate_with_nnunet=True
         print(e)
     mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name("FedLearning-"+experiment_name).experiment_id))
 
-    filter = f"""
-    tags."client" = "{client_name}"
-    """
+    run_name = f"run_validation_{client_name}"
 
-    runs = mlflow.search_runs(experiment_names=["FedLearning-"+experiment_name], filter_string=filter, order_by=["start_time DESC"])
+    runs = mlflow.search_runs(
+        experiment_names=["FedLearning-"+experiment_name],
+        filter_string=f"tags.mlflow.runName = '{run_name}'",
+        order_by=["start_time DESC"]
+    )
+    tags = {"client": client_name}
+    if dataset_name is not None:
+        tags["dataset_name"] = dataset_name
 
 
     if len(runs) == 0:
-        with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}):
+        with mlflow.start_run(run_name=f"run_validation_{client_name}", tags=tags):
             mlflow.log_dict(validation_summary_dict, "validation_summary.json")
             for label in validation_summary_dict["mean"]:
                 for metric in validation_summary_dict["mean"][label]:
                     label_name = labels[label]
                     mlflow.log_metric(f"{label_name}_{metric}", float(validation_summary_dict["mean"][label][metric]))
 
     else:
-        with mlflow.start_run(run_id=runs.iloc[0].run_id, tags={"client": client_name}):
+        with mlflow.start_run(run_id=runs.iloc[0].run_id, tags=tags):
             mlflow.log_dict(validation_summary_dict, "validation_summary.json")
             for label in validation_summary_dict["mean"]:
                 for metric in validation_summary_dict["mean"][label]:
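
Note: the validation summary is a nested dict of mean metrics per label, which the loop above flattens into individual MLflow metrics named <label>_<metric>. A toy, runnable illustration of that flattening; the summary values and label mapping below are invented.

import mlflow

validation_summary_dict = {"mean": {"1": {"Dice": 0.91}, "2": {"Dice": 0.78}}}
labels = {"1": "liver", "2": "tumor"}

with mlflow.start_run(run_name="run_validation_site-1", tags={"client": "site-1"}):
    mlflow.log_dict(validation_summary_dict, "validation_summary.json")
    for label in validation_summary_dict["mean"]:
        for metric in validation_summary_dict["mean"][label]:
            label_name = labels[label]
            # e.g. logs "liver_Dice" = 0.91 and "tumor_Dice" = 0.78
            mlflow.log_metric(f"{label_name}_{metric}", float(validation_summary_dict["mean"][label][metric]))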
@@ -629,7 +646,7 @@ def finalize_bundle(bundle_root, nnunet_root_dir=None, validate_with_nnunet=True
 
 def run_cross_site_validation(nnunet_root_dir, dataset_name_or_id, app_path, app_model_path, app_output_path, trainer_class_name="nnUNetTrainer", fold=0,
                               experiment_name=None, client_name=None, tracking_uri=None,
-                              nnunet_plans_name="nnUNetPlans", mlflow_token=None, skip_prediction=False):
+                              nnunet_plans_name="nnUNetPlans", mlflow_token=None, skip_prediction=False, dataset_name=None):
 
     validation_summary_dict, labels = cross_site_evaluation_api(
         nnunet_root_dir,
@@ -653,11 +670,16 @@ def run_cross_site_validation(nnunet_root_dir, dataset_name_or_id, app_path, app
         print(e)
     mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id))
 
-    filter = f"""
-    tags."client" = "{client_name}"
-    """
+    run_name = f"run_cross_site_validation_{client_name}"
 
-    runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"])
+    runs = mlflow.search_runs(
+        experiment_names=[experiment_name],
+        filter_string=f"tags.mlflow.runName = '{run_name}'",
+        order_by=["start_time DESC"]
+    )
+    tags = {"client": client_name}
+    if dataset_name is not None:
+        tags["dataset_name"] = dataset_name
 
 
     if len(runs) == 0:

monai/nvflare/utils.py

Lines changed: 3 additions & 0 deletions
@@ -540,6 +540,9 @@ def prepare_bundle_api(bundle_config, train_extra_configs=None, is_federated=Fal
 
     if "nnunet_trainer_class_name" in bundle_config:
         mlflow_params["nnunet_trainer_class_name"] = bundle_config["nnunet_trainer_class_name"]
+
+    if "dataset_name" in bundle_config:
+        mlflow_params["dataset_name"] = bundle_config["dataset_name"]
 
     with open(Path(bundle_config["bundle_root"]).joinpath("nnUNet", "params.yaml"), "w") as f:
         yaml.dump(mlflow_params, f)
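
Note: prepare_bundle_api now copies dataset_name from the bundle config into mlflow_params, which is written to nnUNet/params.yaml inside the bundle root, presumably the file through which the value is handed on to the MLflow-logging functions in nvflare_nnunet.py. A minimal sketch of that round-trip; the config values and temporary directory below are invented.

import tempfile
from pathlib import Path

import yaml

# Invented example config; the real bundle_config is built by the NVFlare workflow
bundle_config = {"bundle_root": tempfile.mkdtemp(), "nnunet_trainer_class_name": "nnUNetTrainer", "dataset_name": "MyDataset"}
mlflow_params = {}

if "nnunet_trainer_class_name" in bundle_config:
    mlflow_params["nnunet_trainer_class_name"] = bundle_config["nnunet_trainer_class_name"]

if "dataset_name" in bundle_config:
    mlflow_params["dataset_name"] = bundle_config["dataset_name"]

params_path = Path(bundle_config["bundle_root"]).joinpath("nnUNet", "params.yaml")
params_path.parent.mkdir(parents=True, exist_ok=True)  # the real bundle root already contains this folder
with open(params_path, "w") as f:
    yaml.dump(mlflow_params, f)  # params.yaml now carries dataset_name alongside the trainer class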
