From c8745c0a710ab08f02e7c17fc7674d0b94c66222 Mon Sep 17 00:00:00 2001
From: Kav
Date: Fri, 25 Apr 2025 14:00:27 -0400
Subject: [PATCH 1/2] Add TorchScript model (model.ts) for Swin UNETR segmentation

---
 .../large_files.yml                          |   4 +
 .../scripts/standalone_model_test.py         | 168 ++++++++++++++++++
 2 files changed, 172 insertions(+)
 create mode 100644 models/swin_unetr_btcv_segmentation/scripts/standalone_model_test.py

diff --git a/models/swin_unetr_btcv_segmentation/large_files.yml b/models/swin_unetr_btcv_segmentation/large_files.yml
index 4ca671f5..a4837ca9 100644
--- a/models/swin_unetr_btcv_segmentation/large_files.yml
+++ b/models/swin_unetr_btcv_segmentation/large_files.yml
@@ -3,3 +3,7 @@ large_files:
     url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_swin_unetr_btcv_segmentation_v1.pt"
     hash_val: "50dd67a01b28a1d5487fd9ac27e682fb"
     hash_type: "md5"
+  - path: "models/model.ts"
+    url: "https://drive.google.com/file/d/1byxFoe4XUGLjYT9LAIXj3fxiAWT7v1-T/"
+    hash_val: "28fe0edc4c533e0ee41d952f1d3962e0"
+    hash_type: "md5"
\ No newline at end of file
diff --git a/models/swin_unetr_btcv_segmentation/scripts/standalone_model_test.py b/models/swin_unetr_btcv_segmentation/scripts/standalone_model_test.py
new file mode 100644
index 00000000..94c08339
--- /dev/null
+++ b/models/swin_unetr_btcv_segmentation/scripts/standalone_model_test.py
@@ -0,0 +1,168 @@
+import os
+import torch
+import numpy as np
+import nibabel as nib
+import pydicom
+from pathlib import Path
+from glob import glob
+import SimpleITK as sitk
+from monai.transforms import (
+    Compose, 
+    ScaleIntensityRange, 
+    Spacing, 
+    Orientation,
+    EnsureChannelFirst,
+    CropForeground
+)
+
+# Paths
+input_dir = "input/patient1/study1/series1" ## Please supply input data.
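+## Note: input_dir should point at a directory containing either a DICOM series
+## (one .dcm file per slice) or a single NIfTI volume; model.ts is resolved
+## relative to this scripts/ directory, so run the script from there or adjust the paths.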
+model_path = "../models/model.ts"
+output_dir = "output"
+os.makedirs(output_dir, exist_ok=True)
+
+# Load the traced model on CPU to avoid CUDA requirements
+model = torch.jit.load(model_path, map_location=torch.device('cpu'))
+model.eval()
+
+# Check file types
+files = glob(os.path.join(input_dir, "*"))
+
+# Determine file types and load accordingly
+if len(files) > 0:
+    # For multiple DICOM files (one per slice)
+    if files[0].endswith('.dcm') or len(files) > 10:  # Assume multiple files is a DICOM series
+        reader = sitk.ImageSeriesReader()
+        dicom_names = reader.GetGDCMSeriesFileNames(input_dir)
+        reader.SetFileNames(dicom_names)
+        image = reader.Execute()
+        image_array = sitk.GetArrayFromImage(image)
+
+        # Get spacing information from the DICOM
+        spacing = image.GetSpacing()
+    else:
+        # For NIfTI or other formats
+        image = nib.load(files[0])
+        image_array = image.get_fdata()
+        # NIfTI is typically (x, y, z), so transpose to (z, y, x) for MONAI
+        image_array = np.transpose(image_array, (2, 1, 0))
+
+    # Handling different dimensionality cases
+    if len(image_array.shape) == 3:
+        z, y, x = image_array.shape
+
+        # Check if we have a single slice (or very few slices)
+        if z == 1:
+            image_array = np.repeat(image_array, 96, axis=0)  # Repeat along z to get desired depth
+
+            # Add channel dimension for MONAI: (C, Z, Y, X)
+            image_array = np.expand_dims(image_array, 0)
+            image_tensor = torch.from_numpy(image_array).float()
+        else:
+            # Regular 3D data - add channel dimension: (C, Z, Y, X)
+            image_array = np.expand_dims(image_array, 0)
+            image_tensor = torch.from_numpy(image_array).float()
+    else:
+        # Already has channel dimension or other unusual shape
+        image_tensor = torch.from_numpy(image_array).float()
+
+
+    try:
+        # Skip the EnsureChannelFirst transform as tensor already has channel dimension first
+
+        # Apply Spacing
+        # Doesn't work for 2D, only 3d 
+        if len(image_tensor.shape) >= 4:  # For tensors with at least 4 dimensions (C, Z, Y, X)
+            transform = Spacing(pixdim=(1.5, 1.5, 2.0), mode="bilinear")
+            image_tensor = transform(image_tensor)
+
+        # Apply Orientation for 3d 
+        if len(image_tensor.shape) >= 4:
+            transform = Orientation(axcodes="RAS")
+            image_tensor = transform(image_tensor)
+
+        # Scale Intensity - works for both 2d & 3d
+        transform = ScaleIntensityRange(a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True)
+        image_tensor = transform(image_tensor)
+
+        # Crop Foreground - use allow_smaller=False to prevent dimension issues
+        transform = CropForeground(select_fn=lambda x: x > 0, margin=0, allow_smaller=False)
+        image_tensor = transform(image_tensor)
+
+        # Add batch dimension 
+        if len(image_tensor.shape) == 3:  # 2d case: (C, H, W)
+            image_tensor = image_tensor.unsqueeze(0)  # Add batch: (B, C, H, W)
+        elif len(image_tensor.shape) == 4:  # 3d case: (C, D, H, W)
+            image_tensor = image_tensor.unsqueeze(0)  # Add batch: (B, C, D, H, W)
+
+        # Check tensor shape against model requirements
+        expected_size = (96, 96, 96) 
+
+        # Center crop or pad to match expected dimensions
+        def center_crop_or_pad(tensor, target_size):
+            # Get current spatial dimensions (skip batch and channel)
+            current_size = tensor.shape[2:]
+
+            # Create padded tensor with target size
+            if len(current_size) == 2:  # 2d case
+                # For 2d, we'd need to handle differently or convert to 3d
+                raise ValueError("2D input not supported for 3D model")
+            elif len(current_size) == 3:  # 3d case
+                d, h, w = current_size
+                td, th, tw = target_size
+
+                # Calculate start/end indices for cropping/padding
+                d_start = max(0, (d - td) // 2)
+                d_end = min(d, d_start + td)
+                h_start = max(0, (h - th) // 2)
+                h_end = min(h, h_start + th)
+                w_start = max(0, (w - tw) // 2)
+                w_end = min(w, w_start + tw)
+
+                # Crop
+                result = tensor[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
+
+                # Pad if necessary
+                pad_d = max(0, td - (d_end - d_start))
+                pad_h = max(0, th - (h_end - h_start))
+                pad_w = max(0, tw - (w_end - w_start))
+
+                if pad_d > 0 or pad_h > 0 or pad_w > 0:
+                    pad_d_before = pad_d // 2
+                    pad_d_after = pad_d - pad_d_before
+                    pad_h_before = pad_h // 2
+                    pad_h_after = pad_h - pad_h_before
+                    pad_w_before = pad_w // 2
+                    pad_w_after = pad_w - pad_w_before
+
+                    padding = (pad_w_before, pad_w_after, 
+                               pad_h_before, pad_h_after, 
+                               pad_d_before, pad_d_after, 
+                               0, 0)
+
+                    result = torch.nn.functional.pad(result, padding)
+
+                return result
+
+        # Only resize if the shape doesn't match expected
+        spatial_dims = image_tensor.shape[2:]
+        if spatial_dims != expected_size:
+            image_tensor = center_crop_or_pad(image_tensor, expected_size)
+
+        # Run inference
+        with torch.no_grad():
+            outputs = model(image_tensor)
+
+        # Post-process
+        output_array = outputs[0].argmax(dim=0).numpy().astype(np.uint8)
+
+        # Save output 
+        output_nifti = nib.Nifti1Image(output_array, np.eye(4))
+        output_path = os.path.join(output_dir, "segmentation.nii.gz")
+        nib.save(output_nifti, output_path)
+
+    except Exception as e:
+        import traceback
+        traceback.print_exc()
+else:
+    print(f"No files found in {input_dir}")
\ No newline at end of file

From 7e2b4b6132cb045221c46f96ef6970da5cc1d366 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 25 Apr 2025 18:03:06 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../large_files.yml                          |  2 +-
 .../scripts/standalone_model_test.py         | 78 +++++++++----------
 2 files changed, 40 insertions(+), 40 deletions(-)

diff --git a/models/swin_unetr_btcv_segmentation/large_files.yml b/models/swin_unetr_btcv_segmentation/large_files.yml
index a4837ca9..f0deb57e 100644
--- a/models/swin_unetr_btcv_segmentation/large_files.yml
+++ b/models/swin_unetr_btcv_segmentation/large_files.yml
@@ -6,4 +6,4 @@ large_files:
   - path: "models/model.ts"
     url: "https://drive.google.com/file/d/1byxFoe4XUGLjYT9LAIXj3fxiAWT7v1-T/"
     hash_val: "28fe0edc4c533e0ee41d952f1d3962e0"
-    hash_type: "md5"
\ No newline at end of file
+    hash_type: "md5"
diff --git a/models/swin_unetr_btcv_segmentation/scripts/standalone_model_test.py b/models/swin_unetr_btcv_segmentation/scripts/standalone_model_test.py
index 94c08339..a116bb98 100644
--- a/models/swin_unetr_btcv_segmentation/scripts/standalone_model_test.py
+++ b/models/swin_unetr_btcv_segmentation/scripts/standalone_model_test.py
@@ -7,9 +7,9 @@
 from glob import glob
 import SimpleITK as sitk
 from monai.transforms import (
-    Compose, 
-    ScaleIntensityRange, 
-    Spacing, 
+    Compose,
+    ScaleIntensityRange,
+    Spacing,
     Orientation,
     EnsureChannelFirst,
     CropForeground
@@ -37,7 +37,7 @@
         reader.SetFileNames(dicom_names)
         image = reader.Execute()
         image_array = sitk.GetArrayFromImage(image)
-        
+
         # Get spacing information from the DICOM
         spacing = image.GetSpacing()
     else:
@@ -46,15 +46,15 @@
         image_array = image.get_fdata()
         # NIfTI is typically (x, y, z), so transpose to (z, y, x) for MONAI
         image_array = np.transpose(image_array, (2, 1, 0))
-    
+
     # Handling different dimensionality cases
     if len(image_array.shape) == 3:
         z, y, x = image_array.shape
-    
+
         # Check if we have a single slice (or very few slices)
         if z == 1:
             image_array = np.repeat(image_array, 96, axis=0)  # Repeat along z to get desired depth
-    
+
             # Add channel dimension for MONAI: (C, Z, Y, X)
             image_array = np.expand_dims(image_array, 0)
             image_tensor = torch.from_numpy(image_array).float()
@@ -65,44 +65,44 @@
     else:
         # Already has channel dimension or other unusual shape
         image_tensor = torch.from_numpy(image_array).float()
-    
-    
+
+
     try:
         # Skip the EnsureChannelFirst transform as tensor already has channel dimension first
-        
+
         # Apply Spacing
-        # Doesn't work for 2D, only 3d 
+        # Doesn't work for 2D, only 3d
         if len(image_tensor.shape) >= 4:  # For tensors with at least 4 dimensions (C, Z, Y, X)
             transform = Spacing(pixdim=(1.5, 1.5, 2.0), mode="bilinear")
             image_tensor = transform(image_tensor)
-        
-        # Apply Orientation for 3d 
+
+        # Apply Orientation for 3d
         if len(image_tensor.shape) >= 4:
             transform = Orientation(axcodes="RAS")
             image_tensor = transform(image_tensor)
-        
+
         # Scale Intensity - works for both 2d & 3d
         transform = ScaleIntensityRange(a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True)
         image_tensor = transform(image_tensor)
-        
+
         # Crop Foreground - use allow_smaller=False to prevent dimension issues
         transform = CropForeground(select_fn=lambda x: x > 0, margin=0, allow_smaller=False)
         image_tensor = transform(image_tensor)
-        
-        # Add batch dimension 
+
+        # Add batch dimension
         if len(image_tensor.shape) == 3:  # 2d case: (C, H, W)
             image_tensor = image_tensor.unsqueeze(0)  # Add batch: (B, C, H, W)
         elif len(image_tensor.shape) == 4:  # 3d case: (C, D, H, W)
             image_tensor = image_tensor.unsqueeze(0)  # Add batch: (B, C, D, H, W)
-        
+
         # Check tensor shape against model requirements
-        expected_size = (96, 96, 96) 
-        
+        expected_size = (96, 96, 96)
+
         # Center crop or pad to match expected dimensions
         def center_crop_or_pad(tensor, target_size):
             # Get current spatial dimensions (skip batch and channel)
             current_size = tensor.shape[2:]
-            
+
             # Create padded tensor with target size
             if len(current_size) == 2:  # 2d case
                 # For 2d, we'd need to handle differently or convert to 3d
@@ -110,7 +110,7 @@ def center_crop_or_pad(tensor, target_size):
             elif len(current_size) == 3:  # 3d case
                 d, h, w = current_size
                 td, th, tw = target_size
-                
+
                 # Calculate start/end indices for cropping/padding
                 d_start = max(0, (d - td) // 2)
                 d_end = min(d, d_start + td)
@@ -118,15 +118,15 @@ def center_crop_or_pad(tensor, target_size):
                 h_start = max(0, (h - th) // 2)
                 h_end = min(h, h_start + th)
                 w_start = max(0, (w - tw) // 2)
                 w_end = min(w, w_start + tw)
-                
+
                 # Crop
                 result = tensor[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
-                
+
                 # Pad if necessary
                 pad_d = max(0, td - (d_end - d_start))
                 pad_h = max(0, th - (h_end - h_start))
                 pad_w = max(0, tw - (w_end - w_start))
-                
+
                 if pad_d > 0 or pad_h > 0 or pad_w > 0:
                     pad_d_before = pad_d // 2
                     pad_d_after = pad_d - pad_d_before
@@ -134,35 +134,35 @@ def center_crop_or_pad(tensor, target_size):
                     pad_h_after = pad_h - pad_h_before
                     pad_w_before = pad_w // 2
                     pad_w_after = pad_w - pad_w_before
-                    
-                    padding = (pad_w_before, pad_w_after, 
-                               pad_h_before, pad_h_after, 
-                               pad_d_before, pad_d_after, 
-                               0, 0)
-                    
+
+                    padding = (pad_w_before, pad_w_after,
+                               pad_h_before, pad_h_after,
+                               pad_d_before, pad_d_after,
+                               0, 0)
+
                     result = torch.nn.functional.pad(result, padding)
-                
+
                 return result
-        
+
         # Only resize if the shape doesn't match expected
         spatial_dims = image_tensor.shape[2:]
         if spatial_dims != expected_size:
             image_tensor = center_crop_or_pad(image_tensor, expected_size)
-        
+
         # Run inference
         with torch.no_grad():
             outputs = model(image_tensor)
-        
+
         # Post-process
         output_array = outputs[0].argmax(dim=0).numpy().astype(np.uint8)
-        
-        # Save output 
+
+        # Save output
         output_nifti = nib.Nifti1Image(output_array, np.eye(4))
         output_path = os.path.join(output_dir, "segmentation.nii.gz")
         nib.save(output_nifti, output_path)
-        
+
     except Exception as e:
         import traceback
         traceback.print_exc()
 else:
-    print(f"No files found in {input_dir}")
\ No newline at end of file
+    print(f"No files found in {input_dir}")
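
Note: the snippet below is a minimal smoke-test sketch for the exported TorchScript file that needs no patient data. The relative model path and the expected number of output channels (14 for BTCV: background plus 13 organs) are assumptions based on the Swin UNETR BTCV bundle, not anything verified by this patch.

import torch

# Load the TorchScript model on CPU, mirroring standalone_model_test.py
# (path assumed relative to the scripts/ directory).
model = torch.jit.load("../models/model.ts", map_location=torch.device("cpu"))
model.eval()

# Random single-channel volume at the spatial size the test script crops/pads to: (B, C, D, H, W).
dummy = torch.rand(1, 1, 96, 96, 96)

with torch.no_grad():
    logits = model(dummy)

print(logits.shape)            # expected (assumed): torch.Size([1, 14, 96, 96, 96])
labels = logits.argmax(dim=1)  # per-voxel class indices, as in the script's post-processing
print(labels.shape, labels.unique())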