"""Image processor for Sybil CT scan preprocessing""" |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from typing import Dict, List, Optional, Union, Tuple |
|
|
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature |
|
|
from transformers.utils import TensorType |
|
|
import pydicom |
|
|
from PIL import Image |
|
|
import torchio as tio |
|
|
|
|
|
|
|
|
def order_slices(dicoms: List) -> List:
    """Order DICOM slices along the scan axis.

    Sorts by the z-coordinate of `ImagePositionPatient` when available and
    falls back to `InstanceNumber`; if neither attribute is usable, the
    original order is preserved.
    """
    try:
        dicoms = sorted(dicoms, key=lambda x: float(x.ImagePositionPatient[2]))
    except (AttributeError, TypeError, IndexError):
        # ImagePositionPatient is missing or malformed; fall back to the
        # acquisition order encoded in InstanceNumber.
        try:
            dicoms = sorted(dicoms, key=lambda x: int(x.InstanceNumber))
        except (AttributeError, TypeError):
            pass
    return dicoms
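

# Illustrative check of `order_slices` (hypothetical data, not part of the
# module's API): slices supplied out of order are sorted by z-position.
#
#     a, b = pydicom.Dataset(), pydicom.Dataset()
#     a.ImagePositionPatient = [0.0, 0.0, 10.0]
#     b.ImagePositionPatient = [0.0, 0.0, 2.5]
#     [d.ImagePositionPatient[2] for d in order_slices([a, b])]
#     # -> [2.5, 10.0]

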
class SybilImageProcessor(BaseImageProcessor):
    """
    Constructs a Sybil image processor for preprocessing CT scans.

    Args:
        voxel_spacing (`List[float]`, *optional*, defaults to `[0.703125, 0.703125, 2.5]`):
            Target voxel spacing for resampling, given as (row spacing,
            column spacing, slice thickness) in millimeters.
        img_size (`List[int]`, *optional*, defaults to `[512, 512]`):
            Target image size after resizing.
        num_images (`int`, *optional*, defaults to `208`):
            Number of slices to use from the CT scan.
        windowing (`Dict[str, float]`, *optional*):
            Windowing parameters for CT scan visualization.
            Defaults to the lung window: center=-600, width=1500.
        normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize pixel values to [0, 1].
        **kwargs:
            Additional keyword arguments passed to the parent class.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        voxel_spacing: Optional[List[float]] = None,
        img_size: Optional[List[int]] = None,
        num_images: int = 208,
        windowing: Optional[Dict[str, float]] = None,
        normalize: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.voxel_spacing = voxel_spacing if voxel_spacing is not None else [0.703125, 0.703125, 2.5]
        self.img_size = img_size if img_size is not None else [512, 512]
        self.num_images = num_images
        self.windowing = windowing if windowing is not None else {"center": -600, "width": 1500}
        self.normalize = normalize

        # Volumes are stacked as (slices, rows, columns), so the slice-thickness
        # component of the (row, column, slice) spacing must come first: torchio
        # applies target spacing per spatial dimension, in order.
        self.resample_transform = tio.transforms.Resample(
            target=(self.voxel_spacing[2], self.voxel_spacing[0], self.voxel_spacing[1])
        )

        # Fixed (depth, height, width) target for CropOrPad. Note that this is
        # independent of `num_images` and `img_size` above.
        self.default_depth = 200
        self.default_size = [256, 256]
        self.padding_transform = tio.transforms.CropOrPad(
            target_shape=(self.default_depth, *self.default_size),
            padding_mode=0,
        )

    def load_dicom_series(self, paths: List[str]) -> Tuple[np.ndarray, Dict]:
        """
        Load a series of DICOM files.

        Args:
            paths: List of paths to DICOM files.

        Returns:
            Tuple of (volume array, metadata dict).
        """
        dicoms = []
        for path in paths:
            try:
                dcm = pydicom.dcmread(path, stop_before_pixels=False)
                dicoms.append(dcm)
            except Exception as e:
                logger.warning(f"Error reading DICOM file {path}: {e}")
                continue

        if not dicoms:
            raise ValueError("No valid DICOM files found")

        dicoms = order_slices(dicoms)

        volume = np.stack([dcm.pixel_array.astype(np.float32) for dcm in dicoms])

        metadata = {
            "slice_thickness": float(dicoms[0].SliceThickness) if hasattr(dicoms[0], "SliceThickness") else None,
            "pixel_spacing": list(map(float, dicoms[0].PixelSpacing)) if hasattr(dicoms[0], "PixelSpacing") else None,
            "manufacturer": str(dicoms[0].Manufacturer) if hasattr(dicoms[0], "Manufacturer") else None,
            "num_slices": len(dicoms),
        }

        # Convert raw stored values to Hounsfield units. The rescale parameters
        # are read from the first slice and assumed constant across the series.
        if hasattr(dicoms[0], "RescaleSlope") and hasattr(dicoms[0], "RescaleIntercept"):
            slope = float(dicoms[0].RescaleSlope)
            intercept = float(dicoms[0].RescaleIntercept)
            volume = volume * slope + intercept

        return volume, metadata

    def load_png_series(self, paths: List[str]) -> np.ndarray:
        """
        Load a series of PNG files.

        Args:
            paths: List of paths to PNG files (must be in anatomical order).

        Returns:
            3D volume array of shape (num_slices, height, width).
        """
        images = []
        for path in paths:
            # Convert to single-channel grayscale before stacking.
            img = Image.open(path).convert("L")
            images.append(np.array(img, dtype=np.float32))

        return np.stack(images)

    def apply_windowing(self, volume: np.ndarray) -> np.ndarray:
        """
        Apply intensity windowing to a CT scan for better contrast.

        Values are clipped to [center - width/2, center + width/2]. With the
        default lung window (center=-600, width=1500) this clips to
        [-1350, 150] HU; when `normalize` is set, the result is then rescaled
        to [0, 1].

        Args:
            volume: 3D CT volume.

        Returns:
            Windowed volume.
        """
        center = self.windowing["center"]
        width = self.windowing["width"]

        lower = center - width / 2
        upper = center + width / 2

        volume = np.clip(volume, lower, upper)

        if self.normalize:
            volume = (volume - lower) / (upper - lower)

        return volume

    def resample_volume(
        self,
        volume: torch.Tensor,
        original_spacing: Optional[List[float]] = None,
    ) -> torch.Tensor:
        """
        Resample a volume to the target voxel spacing.

        Args:
            volume: 3D volume tensor of shape (depth, height, width).
            original_spacing: Original voxel spacing as (row spacing, column
                spacing, slice thickness). If `None`, unit spacing is assumed.

        Returns:
            Resampled volume.
        """
        # torchio reads spacing from the affine, so encode it there (a
        # `spacing=` keyword to ScalarImage would be silently ignored). As in
        # `__init__`, reorder the (row, column, slice) spacing to match the
        # (depth, height, width) tensor layout.
        if original_spacing is not None:
            spacing = (original_spacing[2], original_spacing[0], original_spacing[1])
            affine = np.diag([*spacing, 1.0])
        else:
            affine = np.eye(4)

        subject = tio.Subject(
            image=tio.ScalarImage(tensor=volume.unsqueeze(0), affine=affine)
        )

        resampled = self.resample_transform(subject)

        return resampled["image"].data.squeeze(0)

    def pad_or_crop_volume(self, volume: torch.Tensor) -> torch.Tensor:
        """
        Pad or crop a volume to the fixed target shape.

        Args:
            volume: 3D volume tensor of shape (depth, height, width).

        Returns:
            Padded/cropped volume of shape (200, 256, 256).
        """
        subject = tio.Subject(
            image=tio.ScalarImage(tensor=volume.unsqueeze(0))
        )

        transformed = self.padding_transform(subject)

        return transformed["image"].data.squeeze(0)

    def preprocess(
        self,
        images: Union[List[str], np.ndarray, torch.Tensor],
        file_type: str = "dicom",
        voxel_spacing: Optional[List[float]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess CT scan images.

        Args:
            images: Either a list of file paths or a numpy/torch array of images.
            file_type: Type of input files ("dicom" or "png").
            voxel_spacing: Original voxel spacing (required for PNG files).
            return_tensors: The type of tensors to return.

        Returns:
            BatchFeature with preprocessed images.
        """
        # Load the volume from file paths, or accept an in-memory array directly.
        if isinstance(images, list) and images and isinstance(images[0], str):
            if file_type == "dicom":
                volume, metadata = self.load_dicom_series(images)
                if (
                    voxel_spacing is None
                    and metadata["pixel_spacing"]
                    and metadata["slice_thickness"] is not None
                ):
                    voxel_spacing = metadata["pixel_spacing"] + [metadata["slice_thickness"]]
            elif file_type == "png":
                if voxel_spacing is None:
                    raise ValueError("voxel_spacing must be provided for PNG files")
                volume = self.load_png_series(images)
            else:
                raise ValueError(f"Unknown file type: {file_type}")
        elif isinstance(images, (np.ndarray, torch.Tensor)):
            volume = images
        else:
            raise ValueError("Images must be file paths, a numpy array, or a torch tensor")

        # Windowing operates on numpy, so window first and convert to a tensor once.
        if isinstance(volume, torch.Tensor):
            volume = volume.numpy()
        volume = self.apply_windowing(np.asarray(volume, dtype=np.float32))
        volume = torch.from_numpy(volume).float()

        # Resample to the target voxel spacing when the original spacing is known.
        if voxel_spacing is not None:
            volume = self.resample_volume(volume, voxel_spacing)

        volume = self.pad_or_crop_volume(volume)

        # Replicate the single CT channel into 3 channels: (3, depth, height, width).
        volume = volume.unsqueeze(0).repeat(3, 1, 1, 1)

        data = {"pixel_values": volume}

        if return_tensors == "pt":
            return BatchFeature(data=data, tensor_type=TensorType.PYTORCH)
        elif return_tensors == "np":
            data = {k: v.numpy() for k, v in data.items()}
            return BatchFeature(data=data, tensor_type=TensorType.NUMPY)
        else:
            return BatchFeature(data=data)

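    # Note: `preprocess` returns one unbatched (3, depth, height, width)
    # volume; `__call__` below stacks multiple series into a
    # (batch, 3, depth, height, width) tensor.
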
    def __call__(
        self,
        images: Union[List[str], List[List[str]], np.ndarray, torch.Tensor],
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to prepare images for the model.

        Args:
            images: Images to preprocess. Can be:
                - List of file paths for a single series
                - List of lists of file paths for multiple series
                - Numpy array or torch tensor

        Returns:
            BatchFeature with preprocessed images ready for model input.
        """
        if isinstance(images, list) and images and isinstance(images[0], list):
            # Multiple series: preprocess each as a torch tensor, stack into a
            # batch, and only then apply the requested tensor type.
            return_tensors = kwargs.pop("return_tensors", None)
            batch_volumes = []
            for series_paths in images:
                result = self.preprocess(series_paths, **kwargs)
                batch_volumes.append(result["pixel_values"])

            pixel_values = torch.stack(batch_volumes)
            data = {"pixel_values": pixel_values}
            if return_tensors == "np":
                data = {k: v.numpy() for k, v in data.items()}
            return BatchFeature(data=data)
        else:
            # Single series or in-memory volume.
            return self.preprocess(images, **kwargs)
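

# Minimal usage sketch (illustrative, not part of the processor API): runs the
# full pipeline on a synthetic volume. Real callers would pass DICOM paths, in
# which case spacing is read from the headers and resampling is applied; with a
# bare array no spacing is known, so resampling is skipped.
if __name__ == "__main__":
    processor = SybilImageProcessor()
    # Synthetic stand-in for a CT scan: 50 slices of 512x512 voxels in an HU-like range.
    fake_scan = np.random.uniform(-1000, 400, size=(50, 512, 512)).astype(np.float32)
    inputs = processor(fake_scan, return_tensors="pt")
    # Padded/cropped to the fixed target shape and replicated to 3 channels.
    print(inputs["pixel_values"].shape)  # torch.Size([3, 200, 256, 256])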