|
|
import torch |
|
|
import torch.nn as nn |
|
|
from monai.networks.nets import ViT |
|
|
import os |
|
|
|
|
|
|
|
|
class ViTBackboneNet(nn.Module):
    """ViT feature extractor, optionally initialised from a SimCLR checkpoint.

    The backbone is a ``ViT`` configured for single-channel 96x96x96 inputs
    cut into 16x16x16 patches, with hidden size 768, MLP size 3072, 12 layers
    and 12 attention heads.

    Args:
        simclr_ckpt_path: Path to a checkpoint file. When the path is empty
            or does not exist, the backbone keeps its random initialisation
            and a warning is printed.
    """

    # Prefix that backbone parameters carry inside the checkpoint state dict.
    _CKPT_PREFIX = "backbone."

    def __init__(self, simclr_ckpt_path: str):
        super().__init__()
        self.backbone = ViT(
            in_channels=1,
            img_size=(96, 96, 96),
            patch_size=(16, 16, 16),
            hidden_size=768,
            mlp_dim=3072,
            num_layers=12,
            num_heads=12,
            save_attn=True,
        )

        if simclr_ckpt_path and os.path.exists(simclr_ckpt_path):
            self._load_simclr_weights(simclr_ckpt_path)
        else:
            print("Warning: SimCLR checkpoint not found or not provided. Using randomly initialized backbone.")

    @staticmethod
    def _extract_backbone_weights(state_dict, prefix: str = _CKPT_PREFIX) -> dict:
        """Return the entries of ``state_dict`` whose key starts with ``prefix``, prefix stripped."""
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def _load_simclr_weights(self, ckpt_path: str) -> None:
        """Load the ``backbone.*`` weights found in ``ckpt_path`` into ``self.backbone``."""
        # SECURITY: weights_only=False unpickles arbitrary Python objects, so
        # this must only ever be pointed at checkpoints from a trusted source.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        # Accept both a wrapped checkpoint ({"state_dict": ...}) and a bare state dict.
        state_dict = ckpt.get("state_dict", ckpt)
        backbone_state_dict = self._extract_backbone_weights(state_dict)

        if not backbone_state_dict:
            # With strict=False an empty dict would "load" silently and the
            # success message would be misleading; report it instead.
            print(
                f"Warning: no '{self._CKPT_PREFIX}' weights found in SimCLR checkpoint. "
                "Using randomly initialized backbone."
            )
            return

        missing, unexpected = self.backbone.load_state_dict(backbone_state_dict, strict=False)
        print(f"Loaded SimCLR backbone weights. Missing: {len(missing)}, Unexpected: {len(unexpected)}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the embedding at token index 0 of the backbone output for each sample.

        Args:
            x: Input batch passed unchanged to the backbone.

        Returns:
            ``backbone(x)[0][:, 0]``, i.e. the first token of the first backbone output.
        """
        features = self.backbone(x)
        # The first backbone output is indexed as (batch, tokens, ...); token 0
        # is used as the CLS representation.
        # NOTE(review): the ViT above is built without classification=True;
        # confirm that token 0 really is a CLS token and not the first patch.
        cls_token = features[0][:, 0]
        return cls_token
|
|
|
|
|
|
|
|
class Classifier(nn.Module):
    """Single linear layer mapping feature vectors to class scores.

    Args:
        d_model: Size of the last dimension of the incoming features.
        num_classes: Number of output values produced per feature vector.
    """

    def __init__(self, d_model: int = 768, num_classes: int = 1):
        super().__init__()
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` and return the raw (unactivated) outputs."""
        return self.fc(x)
|
|
|
|
|
|
|
|
class SingleScanModelBP(nn.Module):
    """Encode each scan along dim 1, average the features, and classify.

    Args:
        backbone: Module applied independently to every slice taken along
            dim 1 of the input.
        classifier: Module applied to the averaged features after dropout.
    """

    def __init__(self, backbone: nn.Module, classifier: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone per scan, mean-pool over scans, then classify.

        Args:
            x: Tensor whose dim 1 indexes the scans; each slice along that
                dim (with the dim removed) is fed to the backbone.

        Returns:
            The classifier output for the dropout-regularised mean feature.
        """
        # unbind(dim=1) yields one tensor per scan with dim 1 removed, which is
        # what split(1, dim=1) followed by squeeze(1) produced. The backbone is
        # still called once per scan rather than on a merged batch, so
        # batch-dependent layers inside it behave as before.
        scan_features = [self.backbone(scan) for scan in x.unbind(dim=1)]

        # Re-stack on a scan axis and average it away.
        merged_features = torch.stack(scan_features, dim=1).mean(dim=1)
        merged_features = self.dropout(merged_features)
        return self.classifier(merged_features)