File size: 10,951 Bytes
1206896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
"""Image processor for Sybil CT scan preprocessing"""

import logging
import os
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pydicom
import torch
import torchio as tio
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.utils import TensorType


def order_slices(dicoms: List) -> List:
    """
    Order DICOM slices of one series.

    Slices are sorted by ascending ``ImagePositionPatient[2]``. If that value
    cannot be read or converted to ``float`` for every slice, they are sorted
    by ascending ``InstanceNumber`` instead. If that also fails, the input is
    returned unchanged.

    Args:
        dicoms: Slice objects of a single series (in this module, datasets
            returned by ``pydicom.dcmread``).

    Returns:
        The slices in sorted order, or ``dicoms`` itself when neither
        attribute is usable.
    """
    sort_keys = (
        lambda ds: float(ds.ImagePositionPatient[2]),
        lambda ds: int(ds.InstanceNumber),
    )
    for key in sort_keys:
        try:
            return sorted(dicoms, key=key)
        # AttributeError: tag missing; TypeError: value is None;
        # ValueError: unparsable value such as ""; IndexError: short position vector.
        except (AttributeError, TypeError, ValueError, IndexError):
            continue
    return dicoms  # keep original order if no key is usable


class SybilImageProcessor(BaseImageProcessor):
    """
    Constructs a Sybil image processor for preprocessing CT scans.

    `preprocess` runs these steps in order:
    1. Load the slices.
    2. Apply intensity windowing (HU window for DICOM and arrays, /255 scaling for PNG).
    3. Resample to `voxel_spacing` when the source spacing is known.
    4. Crop or pad to (200, 256, 256).
    5. Repeat the grayscale volume to 3 channels.

    Args:
        voxel_spacing (`List[float]`, *optional*, defaults to `[0.703125, 0.703125, 2.5]`):
            Target voxel spacing for resampling (row, column, slice thickness).
        img_size (`List[int]`, *optional*, defaults to `[512, 512]`):
            Target image size after resizing. NOTE(review): stored but not read by
            any method of this class; the output in-plane size is fixed at 256x256.
        num_images (`int`, *optional*, defaults to `208`):
            Number of slices to use from the CT scan. NOTE(review): stored but not
            read by any method of this class; the output depth is fixed at 200.
        windowing (`Dict[str, float]`, *optional*):
            Windowing parameters (`"center"`, `"width"`) for CT scan visualization.
            Default uses lung window: center=-600, width=1500.
        normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize pixel values to [0, 1].
        **kwargs:
            Additional keyword arguments passed to the parent class.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        voxel_spacing: Optional[List[float]] = None,
        img_size: Optional[List[int]] = None,
        num_images: int = 208,
        windowing: Optional[Dict[str, float]] = None,
        normalize: bool = True,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.voxel_spacing = voxel_spacing if voxel_spacing is not None else [0.703125, 0.703125, 2.5]
        self.img_size = img_size if img_size is not None else [512, 512]
        self.num_images = num_images

        # Default lung window settings
        self.windowing = windowing if windowing is not None else {
            "center": -600,
            "width": 1500
        }
        self.normalize = normalize

        # TorchIO transforms for standardization. Resample documents its spacing
        # target as a tuple (sx, sy, sz), so pass that form explicitly.
        self.resample_transform = tio.transforms.Resample(
            target=tuple(float(s) for s in self.voxel_spacing)
        )
        # Note: Original Sybil uses 200 depth, 256x256 images
        self.default_depth = 200
        self.default_size = [256, 256]
        # NOTE(review): no in-plane resize happens before this, so e.g. a 512x512
        # slice at the target spacing is center-cropped to 256x256. Confirm this
        # matches the reference Sybil pipeline.
        self.padding_transform = tio.transforms.CropOrPad(
            target_shape=(self.default_depth, *self.default_size),
            padding_mode=0
        )

    def to_dict(self) -> Dict:
        """
        Serialize the configuration without the TorchIO transform objects.

        The transforms are not plain JSON values, so they would break JSON
        serialization of the config. `__init__` rebuilds them from
        `voxel_spacing` and the fixed target shape.
        """
        output = super().to_dict()
        output.pop("resample_transform", None)
        output.pop("padding_transform", None)
        return output

    @staticmethod
    def _slice_to_float(dcm) -> np.ndarray:
        """Return one slice's pixels as float32 with that slice's own rescale applied."""
        pixels = dcm.pixel_array.astype(np.float32)
        slope = getattr(dcm, "RescaleSlope", None)
        intercept = getattr(dcm, "RescaleIntercept", None)
        # Slope/intercept are per-instance attributes and may differ between
        # slices, so never reuse the first slice's values for the whole series.
        if slope is not None and intercept is not None:
            pixels = pixels * float(slope) + float(intercept)
        return pixels

    def load_dicom_series(self, paths: List[str]) -> Tuple[np.ndarray, Dict]:
        """
        Load a series of DICOM files into a (slices, rows, cols) float32 volume.

        Unreadable files are logged and skipped. The remaining slices are
        ordered with `order_slices`. Each slice's RescaleSlope/RescaleIntercept
        is applied when both are present on that slice.

        Args:
            paths: List of paths to DICOM files (any order).

        Returns:
            Tuple of (volume array, metadata dict). The metadata keys are
            "slice_thickness", "pixel_spacing", "manufacturer" and "num_slices".
            All values except "num_slices" come from the first slice after
            ordering.

        Raises:
            ValueError: If none of the files could be read.
        """
        logger = logging.getLogger(__name__)
        dicoms = []
        for path in paths:
            try:
                dicoms.append(pydicom.dcmread(path))
            except Exception as e:  # best effort: skip unreadable files, keep the rest
                logger.warning("Error reading DICOM file %s: %s", path, e)

        if not dicoms:
            raise ValueError("No valid DICOM files found")

        # Order slices by position
        dicoms = order_slices(dicoms)

        volume = np.stack([self._slice_to_float(dcm) for dcm in dicoms])

        first = dicoms[0]
        thickness = getattr(first, "SliceThickness", None)
        spacing = getattr(first, "PixelSpacing", None)
        manufacturer = getattr(first, "Manufacturer", None)
        metadata = {
            "slice_thickness": float(thickness) if thickness is not None else None,
            "pixel_spacing": list(map(float, spacing)) if spacing is not None else None,
            "manufacturer": str(manufacturer) if manufacturer is not None else None,
            "num_slices": len(dicoms)
        }

        return volume, metadata

    def load_png_series(self, paths: List[str]) -> np.ndarray:
        """
        Load a series of PNG files as an 8-bit grayscale (slices, rows, cols) volume.

        Args:
            paths: List of paths to PNG files (must be in anatomical order).

        Returns:
            3D float32 volume with values in [0, 255].
        """
        images = []
        for path in paths:
            # Context manager closes the file handle deterministically.
            with Image.open(path) as img:
                images.append(np.array(img.convert("L"), dtype=np.float32))

        return np.stack(images)

    def apply_windowing(self, volume: np.ndarray) -> np.ndarray:
        """
        Clip the volume to the configured window and optionally scale it to [0, 1].

        Args:
            volume: 3D CT volume (for DICOM input, values after modality rescale).

        Returns:
            Windowed volume.
        """
        center = self.windowing["center"]
        width = self.windowing["width"]

        # Calculate window boundaries
        lower = center - width / 2
        upper = center + width / 2

        volume = np.clip(volume, lower, upper)

        # Normalize to [0, 1] if requested
        if self.normalize:
            volume = (volume - lower) / (upper - lower)

        return volume

    def resample_volume(
        self,
        volume: torch.Tensor,
        original_spacing: Optional[List[float]] = None
    ) -> torch.Tensor:
        """
        Resample a (slices, rows, cols) volume to `self.voxel_spacing`.

        Args:
            volume: 3D volume tensor ordered (slices, rows, cols).
            original_spacing: Spacing of `volume` as (row, column, slice
                thickness), the same order as `voxel_spacing`. If None,
                1.0 is assumed on every axis.

        Returns:
            Resampled 3D tensor, still ordered (slices, rows, cols).

        Raises:
            ValueError: If `original_spacing` does not have exactly 3 entries.
        """
        if original_spacing is None:
            spacing = (1.0, 1.0, 1.0)
        else:
            spacing = tuple(float(s) for s in original_spacing)
            if len(spacing) != 3:
                raise ValueError(
                    f"original_spacing must have 3 values (row, column, slice), got {original_spacing}"
                )

        # TorchIO takes voxel spacing from the affine. Extra keyword arguments
        # such as `spacing=` are only stored as metadata and are ignored by
        # Resample. Move slices last, giving (rows, cols, slices), so the axes
        # line up with the (row, column, slice) order of both spacings.
        affine = np.diag([*spacing, 1.0])
        image = tio.ScalarImage(tensor=volume.permute(1, 2, 0).unsqueeze(0), affine=affine)

        resampled = self.resample_transform(tio.Subject(image=image))

        # Back to (slices, rows, cols).
        return resampled["image"].data.squeeze(0).permute(2, 0, 1).contiguous()

    def pad_or_crop_volume(self, volume: torch.Tensor) -> torch.Tensor:
        """
        Center-crop or zero-pad a (slices, rows, cols) volume to (200, 256, 256).

        Args:
            volume: 3D volume tensor.

        Returns:
            Padded/cropped volume.
        """
        subject = tio.Subject(
            image=tio.ScalarImage(tensor=volume.unsqueeze(0))
        )

        transformed = self.padding_transform(subject)

        return transformed["image"].data.squeeze(0)

    @staticmethod
    def _spacing_from_metadata(metadata: Dict) -> Optional[List[float]]:
        """Build (row, column, slice) spacing from DICOM metadata, or None if incomplete."""
        pixel_spacing = metadata.get("pixel_spacing")
        slice_thickness = metadata.get("slice_thickness")
        if pixel_spacing and len(pixel_spacing) == 2 and slice_thickness is not None:
            return list(pixel_spacing) + [slice_thickness]
        # Without a complete spacing, resampling would fail, so it is skipped.
        logging.getLogger(__name__).warning(
            "Incomplete DICOM spacing (pixel_spacing=%s, slice_thickness=%s); skipping resampling",
            pixel_spacing, slice_thickness,
        )
        return None

    @staticmethod
    def _to_batch_feature(
        pixel_values: torch.Tensor,
        return_tensors: Optional[Union[str, TensorType]],
    ) -> BatchFeature:
        """Wrap `pixel_values` in a BatchFeature converted per `return_tensors`."""
        if return_tensors == "pt":
            return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=TensorType.PYTORCH)
        if return_tensors == "np":
            return BatchFeature(data={"pixel_values": pixel_values.numpy()}, tensor_type=TensorType.NUMPY)
        return BatchFeature(data={"pixel_values": pixel_values})

    def preprocess(
        self,
        images: Union[List[str], np.ndarray, torch.Tensor],
        file_type: str = "dicom",
        voxel_spacing: Optional[List[float]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs
    ) -> BatchFeature:
        """
        Preprocess one CT series.

        Args:
            images: A list/tuple of file paths (str or os.PathLike) for one
                series, or a 3D numpy array / torch tensor. Arrays are treated
                as (slices, rows, cols), like the DICOM loader output.
            file_type: Type of input files ("dicom" or "png"); used only for paths.
            voxel_spacing: Original voxel spacing as (row, column, slice
                thickness). Required for PNG files. For DICOM it defaults to the
                first slice's PixelSpacing + SliceThickness when both exist.
                When None, no resampling is done.
            return_tensors: "pt", "np", or None (torch tensor kept as is).

        Returns:
            BatchFeature whose "pixel_values" has shape (3, 200, 256, 256).

        Raises:
            ValueError: On empty or unsupported input, an unknown `file_type`,
                PNG input without `voxel_spacing`, a non-3D array, or malformed
                spacing.
        """
        is_png = False
        if isinstance(images, (list, tuple)):
            if not images:
                raise ValueError("images must contain at least one file path")
            if not isinstance(images[0], (str, os.PathLike)):
                raise ValueError("Images must be file paths, numpy array, or torch tensor")
            if file_type == "dicom":
                volume, metadata = self.load_dicom_series(images)
                if voxel_spacing is None:
                    voxel_spacing = self._spacing_from_metadata(metadata)
            elif file_type == "png":
                if voxel_spacing is None:
                    raise ValueError("voxel_spacing must be provided for PNG files")
                volume = self.load_png_series(images)
                is_png = True
            else:
                raise ValueError(f"Unknown file type: {file_type}")
        elif isinstance(images, torch.Tensor):
            # detach/cpu so GPU tensors and tensors requiring grad can be converted.
            volume = images.detach().cpu().numpy()
        elif isinstance(images, np.ndarray):
            volume = images
        else:
            raise ValueError("Images must be file paths, numpy array, or torch tensor")

        # Work in float32 numpy: torch.from_numpy rejects some dtypes (e.g. uint16
        # on older torch), and this avoids a torch -> numpy -> torch round trip.
        volume = np.asarray(volume, dtype=np.float32)
        if volume.ndim != 3:
            raise ValueError(f"Expected a 3D volume (slices, rows, cols), got shape {volume.shape}")

        if is_png:
            # PNG slices are 8-bit display values (0-255), not Hounsfield units.
            # The HU lung window would saturate them, so only rescale to [0, 1].
            if self.normalize:
                volume = volume / 255.0
        else:
            volume = self.apply_windowing(volume)
        tensor = torch.from_numpy(np.ascontiguousarray(volume, dtype=np.float32))

        if voxel_spacing is not None:
            tensor = self.resample_volume(tensor, voxel_spacing)

        tensor = self.pad_or_crop_volume(tensor)

        # (D, H, W) -> (3, D, H, W): the model expects 3 channels, so the
        # grayscale volume is repeated along a new leading channel axis.
        pixel_values = tensor.unsqueeze(0).repeat(3, 1, 1, 1)

        return self._to_batch_feature(pixel_values, return_tensors)

    def __call__(
        self,
        images: Union[List[str], List[List[str]], np.ndarray, torch.Tensor],
        **kwargs
    ) -> BatchFeature:
        """
        Main method to prepare images for the model.

        Args:
            images: Images to preprocess. Can be:
                - List of file paths for a single series
                - List of lists of file paths for multiple series
                - Numpy array or torch tensor
            **kwargs: Forwarded to `preprocess` (e.g. `file_type`,
                `voxel_spacing`, `return_tensors`). In batch mode the same
                kwargs apply to every series.

        Returns:
            BatchFeature with "pixel_values" of shape (B, 3, D, H, W) for a list
            of series, or (3, D, H, W) for a single series.
        """
        if isinstance(images, list) and images and isinstance(images[0], (list, tuple)):
            return_tensors = kwargs.pop("return_tensors", None)
            # Keep each series as a torch tensor and convert once after stacking.
            # Converting per series (e.g. to numpy) would make torch.stack fail.
            batch_volumes = [
                self.preprocess(series_paths, return_tensors=None, **kwargs)["pixel_values"]
                for series_paths in images
            ]
            return self._to_batch_feature(torch.stack(batch_volumes), return_tensors)

        return self.preprocess(images, **kwargs)