def caption_slide(image_path, slide_name, prompt="Diagnosis:", output_dir="./output"):
    """Captions a Whole Slide Image(WSI).
    
    Parameters
    ----------
    image_path: str
        Path to the whole slide image file.
    slide_name: str
        Name of the whole slide image file
    prompt: str, optional
        Starting prompt for the generated caption (default: "Diagnosis:")
    output_dir: str, optional
        Directory to save output files (default: "./output")
    Returns
    -------
    str
        Research log summarizing analysis and results
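
    Examples
    --------
    Illustrative sketch; the directory, slide name, and token below are
    assumptions, not shipped fixtures:

    >>> import os
    >>> os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "hf_..."  # doctest: +SKIP
    >>> log = caption_slide("./slides", "sample", prompt="Diagnosis:")  # doctest: +SKIP
    >>> print(log)  # doctest: +SKIP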
    """
    import os
    import glob
    import timm
    import torch
    from PIL import Image
    import lazyslide as zs
    from pathlib import Path
    from datetime import datetime
    from transformers import AutoModel
    from timm.layers import SwiGLUPacked
    from timm.data import resolve_data_config
    from huggingface_hub import login, whoami
    from timm.data.transforms_factory import create_transform
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Step 1: Login to HuggingFace
    login(token=os.getenv("HUGGINGFACE_ACCESS_TOKEN"))
    hf_user = whoami()
    username = hf_user['name']

    # Step 2: Setup models and transforms
    virchow2 = timm.create_model("hf-hub:paige-ai/Virchow2", pretrained=True, mlp_layer=SwiGLUPacked, act_layer=torch.nn.SiLU)
    virchow2 = virchow2.eval()
    prism = AutoModel.from_pretrained('paige-ai/Prism', trust_remote_code=True)
    prism = prism.to(device)
    transforms = create_transform(**resolve_data_config(virchow2.pretrained_cfg, model=virchow2))
    tile_embeddings = []
    # Step 3: Initialize, process, tile, and encode slide file(s)
    files = [f for f in glob.glob(f"{image_path}/*") if slide_name in os.path.basename(f)]
    if len(files) == 1 and files[0].endswith(".svs"):
        # dealing with the whole slide itself
        wsi = zs.open_wsi(files[0])
        tiles, tile_spec = zs.pp.tile_tissues(wsi, 224, mpp=0.5, return_tiles=True)

        tile_dir = Path("tiles")
        tile_dir.mkdir(exist_ok=True)
        for _, row in tiles.iterrows():
            tile_id = row["tile_id"]
            geometry = row["geometry"]  # shapely Polygon of the tile
            # Get top-left corner of the tile
            minx, miny, maxx, maxy = geometry.bounds
            width = int(maxx - minx)
            height = int(maxy - miny)

            # Read the tile from WSI
            tile_img = wsi.read_region(int(minx), int(miny), width, height, tile_spec.ops_level)
            tile_img = Image.fromarray(tile_img, 'RGB')
            tile_tensor = transforms(tile_img).unsqueeze(0)
            with torch.inference_mode():
                output = virchow2(tile_tensor)
            class_token = output[:, 0]
            # Virchow2 prepends 4 register tokens after the class token;
            # skip them when mean-pooling patch tokens (per the model card)
            patch_tokens = output[:, 5:]

            # 2560-dim tile embedding: class token + mean-pooled patch tokens
            embedding = torch.cat([class_token, patch_tokens.mean(1)], dim=-1)
            tile_embeddings.append(embedding)

            # Save as PNG
            tile_path = tile_dir / f"tile_{tile_id:05d}.png"
            tile_img.save(tile_path)
    else:
        # dealing with patches (not svs); need to encode tiles with Virchow directly
        for file in files:
            tile_img = Image.open(file).convert('RGB')
            tile_tensor = transforms(tile_img).unsqueeze(0)
            with torch.inference_mode():
                output = virchow2(tile_tensor)
            class_token = output[:, 0]
            patch_tokens = output[:, 5:]  # skip Virchow2's 4 register tokens
            embedding = torch.cat([class_token, patch_tokens.mean(1)], dim=-1)
            tile_embeddings.append(embedding)

    tile_embeddings = torch.cat(tile_embeddings, dim=0).unsqueeze(0).to(device)

    # Step 4: Generate the caption from the slide-level latent representation
    # NOTE: the `prompt` argument is not forwarded to PRISM here; generation
    # starts from the model's default prompt.
    with torch.autocast(device, torch.float16), torch.inference_mode():
        reprs = prism.slide_representations(tile_embeddings)
        genned_ids = prism.generate(
            key_value_states=reprs['image_latents'],
            do_sample=False,
            num_beams=5,
            num_beam_groups=1,
        )
        generated_caption = prism.untokenize(genned_ids)

    log = f"""
Research Log: Whole Slide Image Captioning
Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
Image Path: {os.path.basename(image_path)}
Slide Name: {slide_name}

Analysis Steps:
1. Logged into HuggingFace as {username}
2. Loaded PRISM and Virchow2 models for encoding and captioning
3. Initialized, processed, tiled, and encoded slide file(s)
4. Generated the caption via beam search (requested initial prompt: "{prompt}")

Results:

Caption
-------
{generated_caption}
"""

    return log


def segment_slide(image_path, seg_type, model, output_dir="./output"):
    """Segment a Whole Slide Image (WSI).
    
    Parameters
    ----------
    image_path: str
        Path to the whole slide image file.
    seg_type: str
        Type of segmentation to perform: "cells", "cell_type", "semantic", "tissue", or "artifact"
    model: str
        Segmentation model to use
    output_dir: str, optional
        Directory to save output files (default: "./output")
    Returns
    -------
    str
        Research log summarizing analysis and results
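
    Examples
    --------
    Illustrative sketch; the slide path below is an assumption:

    >>> log = segment_slide("./slides/sample.svs", "cells", "instanseg")  # doctest: +SKIP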
    """
    import os
    import lazyslide as zs
    from datetime import datetime
    from huggingface_hub import login, whoami
    
    # Step 1: Validate the requested segmentation type and model combination
    usable_models = set(zs.models.list_models("segmentation"))
    if seg_type not in {"cells", "cell_type", "semantic", "tissue", "artifact"}:
        return f"Error: unknown segmentation type '{seg_type}'"
    if model not in usable_models:
        return f"Error: '{model}' is not an available segmentation model"
    if seg_type == "tissue" and model not in {"grandqc", "pathprofiler"}:
        return f"Error: '{model}' cannot perform tissue segmentation"
    if seg_type == "artifact" and model != "grandqc":
        return f"Error: '{model}' cannot perform artifact segmentation"
    if seg_type == "cells" and model not in {"instanseg", "cellpose"}:
        return f"Error: '{model}' cannot perform cell segmentation"
    if seg_type == "cell_type" and model != "nulite":
        return f"Error: '{model}' cannot perform cell-type segmentation"

    # Step 2: Login to HuggingFace (several segmentation models are gated)
    login(token=os.getenv("HUGGINGFACE_ACCESS_TOKEN"))
    hf_user = whoami()
    username = hf_user['name']

    # Step 3: Open the WSI, detect tissue, and tile it
    wsi = zs.open_wsi(image_path)
    zs.pp.find_tissues(wsi)
    # TODO: tune tile size, background fraction, and mpp for the target model
    zs.pp.tile_tissues(wsi, 512, background_fraction=0.95, mpp=0.5)
    zs.pp.tile_graph(wsi)  # the tile graph needs tiles, so build it after tiling

    # Step 4: Appropriately Segment the slide
    if seg_type == "cells":
        zs.seg.cells(wsi, model=model)
    elif seg_type == "cell_type":
        zs.seg.cell_types(wsi, model=model)
    elif seg_type == "semantic":
        zs.seg.semantic(wsi, model=model)
    elif seg_type == "tissue":
        zs.seg.tissue(wsi, model=model)
    else:
        zs.seg.artifact(wsi, model=model)
    
    # Step 5: Summarize results in a research log

    log = f"""
Research Log: Whole Slide Image Segmentation
Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
Image: {os.path.basename(image_path)}

Analysis Steps:
1. Validated the segmentation type and model combination
2. Logged into HuggingFace as {username}
3. Opened the WSI; found tissues, tiled them, and built the tile graph
4. Performed {seg_type} segmentation using {model}

Results:
- {seg_type} segmentation annotations stored on the WSI object
"""
    return log

def zero_shot_classification(image_path, labels, output_dir="./output"):
    """Performs Zero-Shot Classification from Whole Slide Images (WSIs).

    Parameters
    ----------
    image_path: str
        Path to the whole slide image file.
    labels: list
        Class labels to score with zero-shot classification
    output_dir: str, optional
        Directory to save output files (default: "./output")
    
    Returns
    -------
    str
        Research log summarizing analysis and results
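
    Examples
    --------
    Illustrative sketch; the slide path and label set below are assumptions:

    >>> log = zero_shot_classification(
    ...     "./slides/sample.svs",
    ...     ["lung adenocarcinoma", "lung squamous cell carcinoma"],
    ... )  # doctest: +SKIP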
    """
    import os
    import lazyslide as zs
    from datetime import datetime
    from huggingface_hub import login, whoami

    # Step 1: Login to HuggingFace; the zero-shot models LazySlide uses are gated
    login(token=os.getenv("HUGGINGFACE_ACCESS_TOKEN"))
    hf_user = whoami()
    username = hf_user['name']
    wsi = zs.open_wsi(image_path)
    zs.pp.find_tissues(wsi)
    zs.pp.tile_tissues(wsi, 512, background_fraction=0.95, mpp=0.5)
    # optionally build a tile graph:
    # zs.pp.tile_graph(wsi)
    
    zs.tl.feature_extraction(wsi, "virchow")
    zs.tl.feature_aggregation(wsi, feature_key="virchow", encoder="prism")
    results = zs.tl.zero_shot_score(wsi, labels, feature_key="virchow_tiles")
    log = f"""
Research Log: Zero-Shot Classification
Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
Image: {os.path.basename(image_path)}

Analysis Steps:
1. Logged into HuggingFace as {username}
2. Loaded WSI: {wsi}
3. Found tissues
4. Tiled tissues
5. Extracted tile features with Virchow
6. Aggregated tile features into a slide embedding with PRISM
7. Computed zero-shot scores for the provided labels

Results:
{results}
"""
    print(log)
    return log

def quantify_tumor_infiltrating_lymphocytes(image_path, tile_size=256, tile_step=128, batch_size=4, output_dir="./output"):
    """Quantifies Tumor-Infiltrating Lymphocytes (TILs) from Whole-Slide Images (WSIs).

    Parameters
    ----------
    image_path: str
        Path to the whole slide image file.
    tile_size: int, optional
        Size of inference tiles (default: 256)
    tile_step: int, optional
        Step size between inference tiles (default: 128)
    batch_size: int, optional
        Number of tiles processed simultaneously during inference (default: 4)
    output_dir: str, optional
        Directory to save output files (default: "./output")
    Returns
    -------
    str
        Research log summarizing analysis and results
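
    Examples
    --------
    Illustrative sketch; the slide path below is an assumption:

    >>> log = quantify_tumor_infiltrating_lymphocytes("./slides/sample.svs")  # doctest: +SKIP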
        
    """
    import os
    import numpy as np
    import pandas as pd
    import lazyslide as zs
    from datetime import datetime
    import matplotlib.pyplot as plt
    
    # Step 1: Load WSI via LazySlide
    try:
        wsi = zs.open_wsi(image_path)
    except Exception as e:
        return f"Error loading WSI: {str(e)}"
        
    # Step 2: Build a tissue mask
    try:
        tissue_mask = zs.pp.find_tissues(wsi, refine_level=0, to_hsv=True)
    except Exception as e:
        return f"Error building tissue mask: {str(e)}"


    # Step 3: Cell type segmentation using LazySlide's seg.cell_types
    try:
        zs.seg.cell_types(wsi, batch_size=batch_size)
    except Exception as e:
        return f"Error during cell type segmentation: {str(e)}"

    # Step 4: Load the segmentation results and save them to disk
    instance_map = zs.io.load_annotations(wsi, "instance_map")
    type_map = zs.io.load_annotations(wsi, "cell_types")  # may include TIL labels

    os.makedirs(output_dir, exist_ok=True)
    instance_map_path = os.path.join(output_dir, "instance_map.npy")
    type_map_path = os.path.join(output_dir, "cell_type_map.npy")
    np.save(instance_map_path, instance_map)
    np.save(type_map_path, type_map)

    # Step 5: Define the TIL cell type ID (e.g., 1 for TILs)
    til_type_id = 1

    # Step 6: Compute nucleus counts inside the tissue mask
    # (nonzero-pixel counts are used as a proxy for cell counts here)
    nuclei_in_tissue = tissue_mask & (instance_map > 0)
    total_cells = np.count_nonzero(nuclei_in_tissue)
    til_cells = np.count_nonzero(nuclei_in_tissue & (type_map == til_type_id))

    # Step 7: Compute densities
    pixel_area_mm2 = (wsi.mpp ** 2) / 1e6  # convert μm² to mm²
    roi_area_mm2 = np.count_nonzero(tissue_mask) * pixel_area_mm2
    til_density = til_cells / roi_area_mm2 if roi_area_mm2 > 0 else float("nan")
    total_density = total_cells / roi_area_mm2 if roi_area_mm2 > 0 else float("nan")
    til_fraction = til_cells / total_cells if total_cells > 0 else float("nan")

    # Step 8: Save metrics CSV
    metrics = {
        "total_nuclei": total_cells,
        "til_nuclei": til_cells,
        "til_fraction": til_fraction,
        "til_density_per_mm2": til_density,
        "total_density_per_mm2": total_density,
        "roi_area_mm2": roi_area_mm2
    }
    metrics_df = pd.DataFrame([metrics])
    metrics_path = os.path.join(output_dir, "metrics.csv")
    metrics_df.to_csv(metrics_path, index=False)

    # Step 9: Create and save overlay visualization
    overlay = np.zeros((*type_map.shape, 3), dtype=np.uint8)
    overlay[type_map == til_type_id] = [255, 0, 0]  # red for TILs
    overlay[(type_map != til_type_id) & (instance_map > 0)] = [0, 255, 0]  # green for other nuclei
    overlay_path = os.path.join(output_dir, "overlay.png")
    plt.imsave(overlay_path, overlay)

    # Step 10: Create and return research log
    log = f"""
Research Log: Quantification of Tumor-Infiltrating Lymphocytes
Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
Image: {os.path.basename(image_path)}
    
Analysis Steps:
1. Loaded the whole slide image and built a tissue mask
2. Applied NuLite nucleus instance segmentation and classification
3. Computed TIL (inflamed cell class) and total nuclear densities

Results:
- Total Nuclei: {int(total_cells)}
- TIL (Inflamed) Nuclei: {int(til_cells)}
- TIL Density (per mm2): {til_density:.2f}

Output Files:
- Segmented Image: {os.path.basename(overlay_path)}
- Measurements: {os.path.basename(metrics_path)}
"""
    
    return log

def quantify_fibrosis(image_path, model="grandqc", output_dir="./output"):
    """Quantifies Fibrosis from Whole Slide Images (WSIs).

    Parameters
    ----------
    image_path: str
        Path to the whole slide image file.
    model: str, optional
        Tissue segmentation model to use (default: "grandqc")
    output_dir: str, optional
        Directory to save output files (default: "./output")

    Returns
    -------
    str
        Research log summarizing analysis and results
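
    Examples
    --------
    Illustrative sketch; the slide path below is an assumption:

    >>> log = quantify_fibrosis("./slides/sample.svs", model="grandqc")  # doctest: +SKIP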
    """
    import os
    import lazyslide as zs
    from datetime import datetime
    # Step 1: Load WSI via LazySlide
    try:
        wsi = zs.open_wsi(image_path)
    except Exception as e:
        return f"Error loading WSI: {str(e)}"

    # Step 2: Segment tissue using the requested model
    zs.seg.tissue(wsi, model=model)

    # Step 3: Summarize the analysis in a research log
    log = f"""
Research Log: Quantification of Fibrosis
Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
Image: {os.path.basename(image_path)}

Analysis Steps:
1. Loaded the whole slide image
2. Segmented tissue using {model}

Results:
- Tissue segmentation annotations stored on the WSI object
"""
    return log

# def template(image_path, output_dir="./output"):
#     """Template.

#     Parameters
#     ----------
#     image_path: str
#         Path to the image file.
#     output_dir: str, optional
#         Directory to save output files (default: "./output")

#     Returns
#     -------
#     str
#         Research log summarizing analysis and results
#     """
#     # Step X
    
#     log = f"""
# Research Log: Template
# Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
# Image: {os.path.basename(image_path)}

# Analysis Steps:
# 1.
# 2.
# 3.

# Results:
# -
# -
# -

# Output Files:
# - 
# - 


#     """
#     return log