"""Image processor for Sybil CT scan preprocessing""" import numpy as np import torch from typing import Dict, List, Optional, Union, Tuple from transformers.image_processing_utils import BaseImageProcessor, BatchFeature from transformers.utils import TensorType import pydicom from PIL import Image import torchio as tio def order_slices(dicoms: List) -> List: """Order DICOM slices by their position""" # Sort by ImagePositionPatient if available try: dicoms = sorted(dicoms, key=lambda x: float(x.ImagePositionPatient[2])) except (AttributeError, TypeError): # Fall back to InstanceNumber if ImagePositionPatient not available try: dicoms = sorted(dicoms, key=lambda x: int(x.InstanceNumber)) except (AttributeError, TypeError): pass # Keep original order if neither attribute is available return dicoms class SybilImageProcessor(BaseImageProcessor): """ Constructs a Sybil image processor for preprocessing CT scans. Args: voxel_spacing (`List[float]`, *optional*, defaults to `[0.703125, 0.703125, 2.5]`): Target voxel spacing for resampling (row, column, slice thickness). img_size (`List[int]`, *optional*, defaults to `[512, 512]`): Target image size after resizing. num_images (`int`, *optional*, defaults to `208`): Number of slices to use from the CT scan. windowing (`Dict[str, float]`, *optional*): Windowing parameters for CT scan visualization. Default uses lung window: center=-600, width=1500. normalize (`bool`, *optional*, defaults to `True`): Whether to normalize pixel values to [0, 1]. **kwargs: Additional keyword arguments passed to the parent class. """ model_input_names = ["pixel_values"] def __init__( self, voxel_spacing: List[float] = None, img_size: List[int] = None, num_images: int = 208, windowing: Dict[str, float] = None, normalize: bool = True, **kwargs ): super().__init__(**kwargs) self.voxel_spacing = voxel_spacing if voxel_spacing is not None else [0.703125, 0.703125, 2.5] self.img_size = img_size if img_size is not None else [512, 512] self.num_images = num_images # Default lung window settings self.windowing = windowing if windowing is not None else { "center": -600, "width": 1500 } self.normalize = normalize # TorchIO transforms for standardization self.resample_transform = tio.transforms.Resample(target=self.voxel_spacing) # Note: Original Sybil uses 200 depth, 256x256 images self.default_depth = 200 self.default_size = [256, 256] self.padding_transform = tio.transforms.CropOrPad( target_shape=(self.default_depth, *self.default_size), padding_mode=0 ) def load_dicom_series(self, paths: List[str]) -> Tuple[np.ndarray, Dict]: """ Load a series of DICOM files. Args: paths: List of paths to DICOM files. Returns: Tuple of (volume array, metadata dict) """ dicoms = [] for path in paths: try: dcm = pydicom.dcmread(path, stop_before_pixels=False) dicoms.append(dcm) except Exception as e: print(f"Error reading DICOM file {path}: {e}") continue if not dicoms: raise ValueError("No valid DICOM files found") # Order slices by position dicoms = order_slices(dicoms) # Extract pixel arrays volume = np.stack([dcm.pixel_array.astype(np.float32) for dcm in dicoms]) # Extract metadata metadata = { "slice_thickness": float(dicoms[0].SliceThickness) if hasattr(dicoms[0], 'SliceThickness') else None, "pixel_spacing": list(map(float, dicoms[0].PixelSpacing)) if hasattr(dicoms[0], 'PixelSpacing') else None, "manufacturer": str(dicoms[0].Manufacturer) if hasattr(dicoms[0], 'Manufacturer') else None, "num_slices": len(dicoms) } # Apply rescale if present if hasattr(dicoms[0], 'RescaleSlope') and hasattr(dicoms[0], 'RescaleIntercept'): slope = float(dicoms[0].RescaleSlope) intercept = float(dicoms[0].RescaleIntercept) volume = volume * slope + intercept return volume, metadata def load_png_series(self, paths: List[str]) -> np.ndarray: """ Load a series of PNG files. Args: paths: List of paths to PNG files (must be in anatomical order). Returns: 3D volume array """ images = [] for path in paths: img = Image.open(path).convert('L') # Convert to grayscale images.append(np.array(img, dtype=np.float32)) return np.stack(images) def apply_windowing(self, volume: np.ndarray) -> np.ndarray: """ Apply windowing to CT scan for better visualization. Args: volume: 3D CT volume. Returns: Windowed volume. """ center = self.windowing["center"] width = self.windowing["width"] # Calculate window boundaries lower = center - width / 2 upper = center + width / 2 # Apply windowing volume = np.clip(volume, lower, upper) # Normalize to [0, 1] if requested if self.normalize: volume = (volume - lower) / (upper - lower) return volume def resample_volume( self, volume: torch.Tensor, original_spacing: Optional[List[float]] = None ) -> torch.Tensor: """ Resample volume to target voxel spacing. Args: volume: 3D volume tensor. original_spacing: Original voxel spacing. Returns: Resampled volume. """ # Create TorchIO subject subject = tio.Subject( image=tio.ScalarImage(tensor=volume.unsqueeze(0), spacing=original_spacing) ) # Apply resampling resampled = self.resample_transform(subject) return resampled['image'].data.squeeze(0) def pad_or_crop_volume(self, volume: torch.Tensor) -> torch.Tensor: """ Pad or crop volume to target shape. Args: volume: 3D volume tensor. Returns: Padded/cropped volume. """ # Create TorchIO subject subject = tio.Subject( image=tio.ScalarImage(tensor=volume.unsqueeze(0)) ) # Apply padding/cropping transformed = self.padding_transform(subject) return transformed['image'].data.squeeze(0) def preprocess( self, images: Union[List[str], np.ndarray, torch.Tensor], file_type: str = "dicom", voxel_spacing: Optional[List[float]] = None, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs ) -> BatchFeature: """ Preprocess CT scan images. Args: images: Either list of file paths or numpy/torch array of images. file_type: Type of input files ("dicom" or "png"). voxel_spacing: Original voxel spacing (required for PNG files). return_tensors: The type of tensors to return. Returns: BatchFeature with preprocessed images. """ # Load images if paths are provided if isinstance(images, list) and isinstance(images[0], str): if file_type == "dicom": volume, metadata = self.load_dicom_series(images) if voxel_spacing is None and metadata["pixel_spacing"]: voxel_spacing = metadata["pixel_spacing"] + [metadata["slice_thickness"]] elif file_type == "png": if voxel_spacing is None: raise ValueError("voxel_spacing must be provided for PNG files") volume = self.load_png_series(images) else: raise ValueError(f"Unknown file type: {file_type}") elif isinstance(images, (np.ndarray, torch.Tensor)): volume = images else: raise ValueError("Images must be file paths, numpy array, or torch tensor") # Convert to torch tensor if isinstance(volume, np.ndarray): volume = torch.from_numpy(volume).float() # Apply windowing if isinstance(volume, torch.Tensor): volume_np = volume.numpy() else: volume_np = volume volume_np = self.apply_windowing(volume_np) volume = torch.from_numpy(volume_np).float() # Resample if spacing is provided if voxel_spacing is not None: volume = self.resample_volume(volume, voxel_spacing) # Pad or crop to target shape volume = self.pad_or_crop_volume(volume) # Reshape to match original Sybil format: (D, H, W) -> (C, D, H, W) # The model expects 3 channels (RGB format), so repeat grayscale to 3 channels volume = volume.unsqueeze(0).repeat(3, 1, 1, 1) # Now (3, D, H, W) # Prepare output data = {"pixel_values": volume} # Convert to requested tensor type if return_tensors == "pt": return BatchFeature(data=data, tensor_type=TensorType.PYTORCH) elif return_tensors == "np": data = {k: v.numpy() for k, v in data.items()} return BatchFeature(data=data, tensor_type=TensorType.NUMPY) else: return BatchFeature(data=data) def __call__( self, images: Union[List[str], List[List[str]], np.ndarray, torch.Tensor], **kwargs ) -> BatchFeature: """ Main method to prepare images for the model. Args: images: Images to preprocess. Can be: - List of file paths for a single series - List of lists of file paths for multiple series - Numpy array or torch tensor Returns: BatchFeature with preprocessed images ready for model input. """ # Handle batch processing if isinstance(images, list) and images and isinstance(images[0], list): # Multiple series batch_volumes = [] for series_paths in images: result = self.preprocess(series_paths, **kwargs) batch_volumes.append(result["pixel_values"]) # Stack into batch (B, C, D, H, W) pixel_values = torch.stack(batch_volumes) return BatchFeature(data={"pixel_values": pixel_values}) else: # Single series return self.preprocess(images, **kwargs)