from __future__ import annotations

from typing import Dict, Optional

import torch
from torch import Tensor, nn

from .slide_transformer import VisionTransformer

__all__ = ["WSIEncoderHead"]


class WSIEncoderHead(nn.Module):
    """Adapter around VisionTransformer with aggregation over patch tokens.

    Inputs:
    - patch_features: [B, N, C_in]
    - patch_mask: [B, N] with 1 for valid tokens (required for correct masking)
    - patch_coords: optional [B, N, 2] integer coordinates for RoPE
    - patch_contour_index: optional per-patch contour indices, forwarded to the transformer

    Returns:
    - dict with exactly two keys:
      - patch_embedding: [B, N, C_in + C] concat(raw_patch_features, transformer_patch_tokens)
      - slide_embedding: [B, C_in + C] concat(masked_mean(raw_patch_features), masked_mean(transformer_patch_tokens))
    """

    def __init__(
        self,
        transformer: VisionTransformer,
        input_dim: int,
        embed_dim: int,
    ) -> None:
        super().__init__()
        self.transformer = transformer
        # Stored for introspection; the concatenated outputs produced in
        # forward() have width input_dim + embed_dim.
        self.embed_dim = int(embed_dim)
        self.input_dim = int(input_dim)

    def _masked_mean(self, tokens: Tensor, mask: Optional[Tensor]) -> Tensor:
        """Mask-aware mean over the sequence dimension.

        - tokens: [B, N, C]
        - mask: [B, N] with 1 valid, 0 invalid; a fully invalid row yields a
          zero vector (sum = 0, count clamped to 1) rather than NaN
        - when mask is None, falls back to a plain mean over all tokens
        """
        if mask is None:
            return tokens.mean(dim=1)
        valid = mask.to(dtype=tokens.dtype).unsqueeze(-1)  # [B, N, 1]
        sums = (tokens * valid).sum(dim=1)
        counts = valid.sum(dim=1).clamp_min(1.0)  # guard against division by zero
        return sums / counts
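
    # Worked example for _masked_mean (a sketch of the arithmetic, not
    # executed here): given tokens [[1, 1], [3, 3], [9, 9]] and mask
    # [1, 1, 0], the padded third row is zeroed before summing, so the
    # result is ([1, 1] + [3, 3]) / 2 = [2.0, 2.0] rather than the
    # unmasked mean [13/3, 13/3].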

    def forward(
        self,
        patch_features: Tensor,
        patch_mask: Tensor,
        patch_coords: Optional[Tensor] = None,
        patch_contour_index: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        if patch_mask is None:
            raise ValueError("WSIEncoderHead requires patch_mask (shape [B, N]) to be provided.")

        mask = patch_mask.to(device=patch_features.device)

        encoded = self.transformer(
            patch_features,
            masks=mask,
            coords=patch_coords,
            contour_index=patch_contour_index,
        )
        patch_tokens = encoded["x_norm_patchtokens"]  # [B, N, C]

        # Per-patch output keeps the raw features alongside the
        # contextualized transformer tokens: [B, N, C_in + C].
        patch_embedding = torch.cat([patch_features, patch_tokens], dim=-1)

        # Slide-level output is the masked mean of each stream,
        # concatenated to [B, C_in + C].
        raw_patch_mean = self._masked_mean(patch_features, mask)
        token_mean = self._masked_mean(patch_tokens, mask)
        slide_embedding = torch.cat([raw_patch_mean, token_mean], dim=-1)

        return {
            "patch_embedding": patch_embedding,
            "slide_embedding": slide_embedding,
        }
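

# Minimal shape-check sketch. `_StubTransformer` below is a hypothetical
# stand-in for the real VisionTransformer (whose constructor is not shown in
# this module); it only mimics the call signature and the
# "x_norm_patchtokens" output key that WSIEncoderHead relies on.
if __name__ == "__main__":

    class _StubTransformer(nn.Module):
        def __init__(self, input_dim: int, embed_dim: int) -> None:
            super().__init__()
            self.proj = nn.Linear(input_dim, embed_dim)

        def forward(self, x, masks=None, coords=None, contour_index=None):
            return {"x_norm_patchtokens": self.proj(x)}

    B, N, C_in, C = 2, 16, 384, 256
    head = WSIEncoderHead(_StubTransformer(C_in, C), input_dim=C_in, embed_dim=C)
    mask = torch.ones(B, N)
    mask[:, N // 2 :] = 0  # mark the second half of each sequence as padding
    out = head(torch.randn(B, N, C_in), patch_mask=mask)
    assert out["patch_embedding"].shape == (B, N, C_in + C)
    assert out["slide_embedding"].shape == (B, C_in + C)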