Skip to content

Commit 460481f

Browse files
Add ResultDownloader class for job result processing and model file management
1 parent c267fdf commit 460481f

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

monai/nvflare/scripts.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import os
2+
import shutil
3+
from nvflare.apis.fl_context import FLContext
4+
from nvflare.apis.server_engine_spec import ServerEngineSpec
5+
from nvflare.app_common.app_constant import AppConstants
6+
from nvflare.app_common.abstract.fl_app_script import FLAppScript
7+
8+
9+
class ResultDownloader(FLAppScript):
10+
def __init__(self, job_id: str, destination_path: str,bundle_root: str, fold_id: int):
11+
super().__init__()
12+
self.job_id = job_id
13+
self.destination_path = destination_path
14+
self.fold_id = fold_id
15+
self.bundle_root = bundle_root
16+
17+
def execute(self, fl_ctx: FLContext):
18+
engine: ServerEngineSpec = fl_ctx.get_prop(AppConstants.ENGINE)
19+
job_meta = engine.get_job_store().get_job_meta(self.job_id)
20+
job_dir = job_meta.get("folder")
21+
22+
if not job_dir or not os.path.exists(job_dir):
23+
self.log_error(fl_ctx, f"Job directory for {self.job_id} not found.")
24+
return
25+
26+
try:
27+
# 1. Copy job directory to destination_path
28+
shutil.copytree(job_dir, self.destination_path, dirs_exist_ok=True)
29+
self.log_info(fl_ctx, f"Copied results for job {self.job_id} to {self.destination_path}")
30+
31+
# 2. Locate the global model file
32+
source_model = os.path.join(self.destination_path, "job", "workspace", "app_server", "FL_global_model.pt")
33+
if not os.path.exists(source_model):
34+
self.log_error(fl_ctx, f"Model file not found at expected location: {source_model}")
35+
return
36+
37+
# 3. Determine target path inside BUNDLE_ROOT
38+
target_model_dir = os.path.join(self.bundle_root, "models", f"fold_{self.fold_id}")
39+
os.makedirs(target_model_dir, exist_ok=True)
40+
41+
target_model_path = os.path.join(target_model_dir, "FL_global_model.pt")
42+
43+
# 4. Copy model file
44+
shutil.copy2(source_model, target_model_path)
45+
self.log_info(fl_ctx, f"Model copied to {target_model_path}")
46+
47+
except Exception as e:
48+
self.log_error(fl_ctx, f"Failed to process job results: {str(e)}")

0 commit comments

Comments
 (0)