@@ -152,6 +152,12 @@ class ModelnnUNetWrapper(torch.nn.Module):
152
152
The folder path where the model and related files are stored.
153
153
model_name : str, optional
154
154
The name of the model file, by default "model.pt".
155
+ dataset_json : dict, optional
156
+ The dataset JSON file containing dataset information.
157
+ plans : dict, optional
158
+ The plans JSON file containing model configuration.
159
+ nnunet_config : dict, optional
160
+ The nnUNet configuration dictionary containing model parameters.
155
161
156
162
Attributes
157
163
----------
@@ -166,7 +172,7 @@ class ModelnnUNetWrapper(torch.nn.Module):
166
172
restoring network architecture, and setting up the predictor for inference.
167
173
"""
168
174
169
- def __init__ (self , predictor : object , model_folder : Union [str , Path ], model_name : str = "model.pt" ): # type: ignore
175
+ def __init__ (self , predictor : object , model_folder : Union [str , Path ], model_name : str = "model.pt" , dataset_json : dict = None , plans : dict = None , nnunet_config : dict = None ): # type: ignore
170
176
super ().__init__ ()
171
177
self .predictor = predictor
172
178
@@ -175,23 +181,31 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
175
181
from nnunetv2 .utilities .plans_handling .plans_handler import PlansManager
176
182
177
183
# Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
178
- dataset_json = load_json (join (Path (model_training_output_dir ).parent , "dataset.json" ))
179
- plans = load_json (join (Path (model_training_output_dir ).parent , "plans.json" ))
184
+ if dataset_json is None :
185
+ dataset_json = load_json (join (Path (model_training_output_dir ).parent , "dataset.json" ))
186
+ if plans is None :
187
+ plans = load_json (join (Path (model_training_output_dir ).parent , "plans.json" ))
180
188
plans_manager = PlansManager (plans )
181
189
182
190
parameters = []
183
191
184
- checkpoint = torch .load (
185
- join (Path (model_training_output_dir ).parent , "nnunet_checkpoint.pth" ), map_location = torch .device ("cpu" )
186
- )
187
- trainer_name = checkpoint ["trainer_name" ]
188
- configuration_name = checkpoint ["init_args" ]["configuration" ]
189
- inference_allowed_mirroring_axes = (
190
- checkpoint ["inference_allowed_mirroring_axes" ]
191
- if "inference_allowed_mirroring_axes" in checkpoint .keys ()
192
- else None
193
- )
194
- if Path (model_training_output_dir ).joinpath (model_name ).is_file ():
192
+ if nnunet_config is None :
193
+ checkpoint = torch .load (
194
+ join (Path (model_training_output_dir ).parent , "nnunet_checkpoint.pth" ), map_location = torch .device ("cpu" )
195
+ )
196
+ trainer_name = checkpoint ["trainer_name" ]
197
+ configuration_name = checkpoint ["init_args" ]["configuration" ]
198
+ inference_allowed_mirroring_axes = (
199
+ checkpoint ["inference_allowed_mirroring_axes" ]
200
+ if "inference_allowed_mirroring_axes" in checkpoint .keys ()
201
+ else None
202
+ )
203
+ else :
204
+ trainer_name = nnunet_config ["trainer_name" ]
205
+ configuration_name = nnunet_config ["configuration" ]
206
+ inference_allowed_mirroring_axes = nnunet_config ["inference_allowed_mirroring_axes" ]
207
+
208
+ if Path (model_training_output_dir ).joinpath (model_name ).is_file () and model_name .endswith (".pt" ):
195
209
monai_checkpoint = torch .load (join (model_training_output_dir , model_name ), map_location = torch .device ("cpu" ))
196
210
if "network_weights" in monai_checkpoint .keys ():
197
211
parameters .append (monai_checkpoint ["network_weights" ])
@@ -255,7 +269,16 @@ def forward(self, x: MetaTensor) -> MetaTensor:
255
269
"""
256
270
if isinstance (x , MetaTensor ):
257
271
if "pixdim" in x .meta :
258
- properties_or_list_of_properties = {"spacing" : x .meta ["pixdim" ][0 ][1 :4 ].numpy ().tolist ()}
272
+ if x .meta ["pixdim" ].ndim == 1 :
273
+ if x .meta ["pixdim" ][0 ] == 1 :
274
+ properties_or_list_of_properties = {"spacing" : x .meta ["pixdim" ][1 :4 ].tolist ()}
275
+ else :
276
+ properties_or_list_of_properties = {"spacing" : x .meta ["pixdim" ][:3 ].tolist ()}
277
+ else :
278
+ if x .meta ["pixdim" ][0 ][0 ] == 1 :
279
+ properties_or_list_of_properties = {"spacing" : x .meta ["pixdim" ][0 ][1 :4 ].numpy ().tolist ()}
280
+ else :
281
+ properties_or_list_of_properties = {"spacing" : x .meta ["pixdim" ][0 ][:3 ].numpy ().tolist ()}
259
282
elif "affine" in x .meta :
260
283
spacing = [
261
284
abs (x .meta ["affine" ][0 ][0 ].item ()),
@@ -269,6 +292,8 @@ def forward(self, x: MetaTensor) -> MetaTensor:
269
292
raise TypeError ("Input must be a MetaTensor or a tuple of MetaTensors." )
270
293
271
294
image_or_list_of_images = x .cpu ().numpy ()[0 , :]
295
+ image_or_list_of_images = np .transpose (image_or_list_of_images , (0 , 3 , 2 , 1 ))
296
+ properties_or_list_of_properties ["spacing" ] = properties_or_list_of_properties ["spacing" ][::- 1 ]
272
297
273
298
# input_files should be a list of file paths, one per modality
274
299
prediction_output = self .predictor .predict_from_list_of_npy_arrays ( # type: ignore
@@ -286,11 +311,11 @@ def forward(self, x: MetaTensor) -> MetaTensor:
286
311
for out in prediction_output : # Add batch and channel dimensions
287
312
out_tensors .append (torch .from_numpy (np .expand_dims (np .expand_dims (out , 0 ), 0 )))
288
313
out_tensor = torch .cat (out_tensors , 0 ) # Concatenate along batch dimension
289
-
314
+ out_tensor = out_tensor . permute ( 0 , 1 , 4 , 3 , 2 )
290
315
return MetaTensor (out_tensor , meta = x .meta )
291
316
292
317
293
- def get_nnunet_monai_predictor (model_folder : Union [str , Path ], model_name : str = "model.pt" ) -> ModelnnUNetWrapper :
318
+ def get_nnunet_monai_predictor (model_folder : Union [str , Path ], model_name : str = "model.pt" , dataset_json : dict = None , plans : dict = None , nnunet_config : dict = None ) -> ModelnnUNetWrapper :
294
319
"""
295
320
Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`.
296
321
The model folder should contain the following files, created during training:
@@ -321,6 +346,12 @@ def get_nnunet_monai_predictor(model_folder: Union[str, Path], model_name: str =
321
346
The folder where the model is stored.
322
347
model_name : str, optional
323
348
The name of the model file, by default "model.pt".
349
+ dataset_json : dict, optional
350
+ The dataset JSON file containing dataset information.
351
+ plans : dict, optional
352
+ The plans JSON file containing model configuration.
353
+ nnunet_config : dict, optional
354
+ The nnUNet configuration dictionary containing model parameters.
324
355
325
356
Returns
326
357
-------
@@ -335,12 +366,12 @@ def get_nnunet_monai_predictor(model_folder: Union[str, Path], model_name: str =
335
366
use_gaussian = True ,
336
367
use_mirroring = False ,
337
368
device = torch .device ("cuda" , 0 ),
338
- verbose = False ,
339
- verbose_preprocessing = False ,
369
+ verbose = True ,
370
+ verbose_preprocessing = True ,
340
371
allow_tqdm = True ,
341
372
)
342
373
# initializes the network architecture, loads the checkpoint
343
- wrapper = ModelnnUNetWrapper (predictor , model_folder , model_name )
374
+ wrapper = ModelnnUNetWrapper (predictor , model_folder , model_name , dataset_json , plans , nnunet_config )
344
375
return wrapper
345
376
346
377
@@ -561,7 +592,8 @@ def subfiles(
561
592
f"{ bundle_root_folder } /models/fold_{ fold } /checkpoint_key_metric={ best_key_metric } .pt"
562
593
)
563
594
564
- nnunet_checkpoint ["optimizer_state" ] = monai_last_checkpoint ["optimizer_state" ]
595
+ if "optimizer_state" in monai_last_checkpoint :
596
+ nnunet_checkpoint ["optimizer_state" ] = monai_last_checkpoint ["optimizer_state" ]
565
597
566
598
nnunet_checkpoint ["network_weights" ] = odict ()
567
599
@@ -577,7 +609,8 @@ def subfiles(
577
609
578
610
nnunet_checkpoint ["network_weights" ] = odict ()
579
611
580
- nnunet_checkpoint ["optimizer_state" ] = monai_best_checkpoint ["optimizer_state" ]
612
+ if "optimizer_state" in monai_last_checkpoint :
613
+ nnunet_checkpoint ["optimizer_state" ] = monai_best_checkpoint ["optimizer_state" ]
581
614
582
615
for key in monai_best_checkpoint ["network_weights" ]:
583
616
nnunet_checkpoint ["network_weights" ][key ] = monai_best_checkpoint ["network_weights" ][key ]
0 commit comments