Skip to content

Commit 3b61e9f

Browse files
Enhance ModelnnUNetWrapper to accept additional parameters for dataset JSON, plans, and nnUNet configuration; improve loading logic for optional inputs.
1 parent 02eec30 commit 3b61e9f

File tree

1 file changed

+55
-22
lines changed

1 file changed

+55
-22
lines changed

monai/apps/nnunet/nnunet_bundle.py

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ class ModelnnUNetWrapper(torch.nn.Module):
152152
The folder path where the model and related files are stored.
153153
model_name : str, optional
154154
The name of the model file, by default "model.pt".
155+
dataset_json : dict, optional
156+
The dataset JSON file containing dataset information.
157+
plans : dict, optional
158+
The plans JSON file containing model configuration.
159+
nnunet_config : dict, optional
160+
The nnUNet configuration dictionary containing model parameters.
155161
156162
Attributes
157163
----------
@@ -166,7 +172,7 @@ class ModelnnUNetWrapper(torch.nn.Module):
166172
restoring network architecture, and setting up the predictor for inference.
167173
"""
168174

169-
def __init__(self, predictor: object, model_folder: Union[str, Path], model_name: str = "model.pt"): # type: ignore
175+
def __init__(self, predictor: object, model_folder: Union[str, Path], model_name: str = "model.pt", dataset_json: dict = None, plans: dict = None, nnunet_config: dict = None): # type: ignore
170176
super().__init__()
171177
self.predictor = predictor
172178

@@ -175,23 +181,31 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
175181
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
176182

177183
# Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
178-
dataset_json = load_json(join(Path(model_training_output_dir).parent, "dataset.json"))
179-
plans = load_json(join(Path(model_training_output_dir).parent, "plans.json"))
184+
if dataset_json is None:
185+
dataset_json = load_json(join(Path(model_training_output_dir).parent, "dataset.json"))
186+
if plans is None:
187+
plans = load_json(join(Path(model_training_output_dir).parent, "plans.json"))
180188
plans_manager = PlansManager(plans)
181189

182190
parameters = []
183191

184-
checkpoint = torch.load(
185-
join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
186-
)
187-
trainer_name = checkpoint["trainer_name"]
188-
configuration_name = checkpoint["init_args"]["configuration"]
189-
inference_allowed_mirroring_axes = (
190-
checkpoint["inference_allowed_mirroring_axes"]
191-
if "inference_allowed_mirroring_axes" in checkpoint.keys()
192-
else None
193-
)
194-
if Path(model_training_output_dir).joinpath(model_name).is_file():
192+
if nnunet_config is None:
193+
checkpoint = torch.load(
194+
join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
195+
)
196+
trainer_name = checkpoint["trainer_name"]
197+
configuration_name = checkpoint["init_args"]["configuration"]
198+
inference_allowed_mirroring_axes = (
199+
checkpoint["inference_allowed_mirroring_axes"]
200+
if "inference_allowed_mirroring_axes" in checkpoint.keys()
201+
else None
202+
)
203+
else:
204+
trainer_name = nnunet_config["trainer_name"]
205+
configuration_name = nnunet_config["configuration"]
206+
inference_allowed_mirroring_axes = nnunet_config["inference_allowed_mirroring_axes"]
207+
208+
if Path(model_training_output_dir).joinpath(model_name).is_file() and model_name.endswith(".pt"):
195209
monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu"))
196210
if "network_weights" in monai_checkpoint.keys():
197211
parameters.append(monai_checkpoint["network_weights"])
@@ -255,7 +269,16 @@ def forward(self, x: MetaTensor) -> MetaTensor:
255269
"""
256270
if isinstance(x, MetaTensor):
257271
if "pixdim" in x.meta:
258-
properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()}
272+
if x.meta["pixdim"].ndim == 1:
273+
if x.meta["pixdim"][0] == 1:
274+
properties_or_list_of_properties = {"spacing": x.meta["pixdim"][1:4].tolist()}
275+
else:
276+
properties_or_list_of_properties = {"spacing": x.meta["pixdim"][:3].tolist()}
277+
else:
278+
if x.meta["pixdim"][0][0] == 1:
279+
properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()}
280+
else:
281+
properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][:3].numpy().tolist()}
259282
elif "affine" in x.meta:
260283
spacing = [
261284
abs(x.meta["affine"][0][0].item()),
@@ -269,6 +292,8 @@ def forward(self, x: MetaTensor) -> MetaTensor:
269292
raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.")
270293

271294
image_or_list_of_images = x.cpu().numpy()[0, :]
295+
image_or_list_of_images = np.transpose(image_or_list_of_images, (0, 3, 2, 1))
296+
properties_or_list_of_properties["spacing"] = properties_or_list_of_properties["spacing"][::-1]
272297

273298
# input_files should be a list of file paths, one per modality
274299
prediction_output = self.predictor.predict_from_list_of_npy_arrays( # type: ignore
@@ -286,11 +311,11 @@ def forward(self, x: MetaTensor) -> MetaTensor:
286311
for out in prediction_output: # Add batch and channel dimensions
287312
out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0)))
288313
out_tensor = torch.cat(out_tensors, 0) # Concatenate along batch dimension
289-
314+
out_tensor = out_tensor.permute(0, 1, 4, 3, 2)
290315
return MetaTensor(out_tensor, meta=x.meta)
291316

292317

293-
def get_nnunet_monai_predictor(model_folder: Union[str, Path], model_name: str = "model.pt") -> ModelnnUNetWrapper:
318+
def get_nnunet_monai_predictor(model_folder: Union[str, Path], model_name: str = "model.pt", dataset_json: dict = None, plans: dict = None, nnunet_config: dict = None) -> ModelnnUNetWrapper:
294319
"""
295320
Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`.
296321
The model folder should contain the following files, created during training:
@@ -321,6 +346,12 @@ def get_nnunet_monai_predictor(model_folder: Union[str, Path], model_name: str =
321346
The folder where the model is stored.
322347
model_name : str, optional
323348
The name of the model file, by default "model.pt".
349+
dataset_json : dict, optional
350+
The dataset JSON file containing dataset information.
351+
plans : dict, optional
352+
The plans JSON file containing model configuration.
353+
nnunet_config : dict, optional
354+
The nnUNet configuration dictionary containing model parameters.
324355
325356
Returns
326357
-------
@@ -335,12 +366,12 @@ def get_nnunet_monai_predictor(model_folder: Union[str, Path], model_name: str =
335366
use_gaussian=True,
336367
use_mirroring=False,
337368
device=torch.device("cuda", 0),
338-
verbose=False,
339-
verbose_preprocessing=False,
369+
verbose=True,
370+
verbose_preprocessing=True,
340371
allow_tqdm=True,
341372
)
342373
# initializes the network architecture, loads the checkpoint
343-
wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name)
374+
wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name, dataset_json, plans, nnunet_config)
344375
return wrapper
345376

346377

@@ -561,7 +592,8 @@ def subfiles(
561592
f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt"
562593
)
563594

564-
nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"]
595+
if "optimizer_state" in monai_last_checkpoint:
596+
nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"]
565597

566598
nnunet_checkpoint["network_weights"] = odict()
567599

@@ -577,7 +609,8 @@ def subfiles(
577609

578610
nnunet_checkpoint["network_weights"] = odict()
579611

580-
nnunet_checkpoint["optimizer_state"] = monai_best_checkpoint["optimizer_state"]
612+
if "optimizer_state" in monai_last_checkpoint:
613+
nnunet_checkpoint["optimizer_state"] = monai_best_checkpoint["optimizer_state"]
581614

582615
for key in monai_best_checkpoint["network_weights"]:
583616
nnunet_checkpoint["network_weights"][key] = monai_best_checkpoint["network_weights"][key]

0 commit comments

Comments
 (0)