---
license: cc-by-nc-4.0
---

# SMB-RAD-Encoder-v1

SMB-RAD-Encoder-v1 is a pure-vision backbone for medical imaging foundation models, with a strong focus on radiology modalities such as CT, MRI, and X‑ray. It implements efficient 3D patch embedding, rotary position embeddings, scalable Transformer blocks, multi‑scale deep feature extraction, and two self‑supervised objectives tailored to medical imagery: masked image modeling (MIM) and a joint embedding predictive architecture (JEPA).

## Architecture Overview

The implementation lives in `modeling_smb_vision.py` and exposes three main classes:

- `SMBVisionEncoder`: Vision encoder with 3D patch embedding and stacked Transformer blocks
- `SMBVisionPredictor`: Lightweight Transformer for JEPA next‑embedding prediction
- `SMBVisionModel`: Wrapper that combines encoder + predictor and computes the MIM and JEPA losses

Key components and how they map to the code:

- **3D Patch Embedding (`SMBVisionPatchEmbed`)**
  - A `Conv3d` with kernel=stride=`[temporal_patch_size, patch_size, patch_size]` applied to per‑patch tensors (a minimal sketch follows this list)
  - Supports `in_channels` = 1 (grayscale), 3 (RGB), or 4; radiology typically uses 1
  - Produces per‑patch embeddings of size `hidden_size`

- **Learned 2D Positional Embedding + Fast Interpolation**
  - `pos_embed: nn.Embedding(num_position_embeddings, hidden_size)` with bilinear‑style interpolation (`fast_pos_embed_interpolate`) to the target grid size (height × width) per frame

- **Rotary Position Embedding (RoPE) in Space (and Time)**
  - `SMBVisionRotaryEmbedding` generates the frequencies, which are applied in attention via `apply_rotary_pos_emb_vision`
  - Encodes spatial (and slice/temporal) structure for robust geometric reasoning

- **Transformer Blocks (`SMBVisionBlock`)**
  - Pre‑norm residual blocks with `SMBVisionAttention` and `SMBVisionMLP`
  - Attention backends: eager, SDPA, or FlashAttention‑2 (config‑selectable)

- **DeepStack Multi‑scale Features**
  - `deepstack_visual_indexes` selects the block indices whose outputs are merged by `SMBVisionPatchMerger`
  - Produces multi‑level visual descriptors for downstream tasks (e.g., detection, retrieval)

- **Masked Image Modeling (MIM)**
  - Randomly masks a ratio of patch tokens and reconstructs pixels via `to_pixels: Linear(hidden_size -> patch_volume)`
  - Reconstruction loss: L1 (MAE) on masked patches
  - Note: for medical grayscale data, set `in_channels=1` so the reconstruction target matches the output shape

- **JEPA Next‑Embedding Prediction**
  - Context/target partitions at the study level expand to patch tokens internally
  - `SMBVisionPredictor` predicts the target encoder embeddings; the loss is MSE on target tokens
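
To make the patching step concrete, here is a minimal sketch of a `Conv3d`-based 3D patch embed. This is our illustration, not the repository's exact code: `hidden_size=1152` is inferred from the output shape in Quick Start below, and the kernel is applied to a whole volume, which is equivalent to per‑patch application because kernel == stride.

```python
import torch
import torch.nn as nn

class PatchEmbed3D(nn.Module):
    """Illustrative 3D patch embed: one embedding per non-overlapping 3D patch."""

    def __init__(self, in_channels=1, temporal_patch_size=16, patch_size=16, hidden_size=1152):
        super().__init__()
        kernel = (temporal_patch_size, patch_size, patch_size)
        # kernel == stride => non-overlapping patches, one output position per patch
        self.proj = nn.Conv3d(in_channels, hidden_size, kernel_size=kernel, stride=kernel)

    def forward(self, x):
        # x: (batch, in_channels, depth, height, width)
        x = self.proj(x)                     # (batch, hidden_size, T, H, W)
        return x.flatten(2).transpose(1, 2)  # (batch, T*H*W, hidden_size)

embed = PatchEmbed3D()
volume = torch.randn(1, 1, 64, 160, 160)  # a CT-like grayscale volume
print(embed(volume).shape)                 # torch.Size([1, 400, 1152])
```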

## Radiology‑centric Design Notes

- **Modalities**: CT/MRI volumes (slice stacks) and X‑ray images are supported via patch tokenization
- **Through‑plane handling**: `temporal_patch_size` acts as the slice depth for 3D patching along the Z/through‑plane axis
- **Grayscale emphasis**: use `in_channels=1` for CT/MRI/X‑ray to align the MIM reconstruction shapes
- **Scalability**: attention backends support SDPA and FlashAttention‑2 for large studies and high‑resolution inputs
- **Multi‑scale features**: `deepstack_visual_indexes` provides hooks for detection/segmentation heads

## Installation

```bash
pip install torch torchvision
pip install transformers nibabel monai smb_biopan_utils
pip install 'monai[all]'
```

## Quick Start (CT volumes)

The encoder expects a flattened sequence of patch tokens and a per‑sample grid descriptor `grid_thw = [T, H, W]`, where:

- `T = num_slices / temporal_patch_size`
- `H = image_height / patch_size`
- `W = image_width / patch_size`

For example, a `64×160×160` CT volume with `patch_size=16` and `temporal_patch_size=16` yields `grid_thw = [4, 10, 10]`, i.e. 400 tokens per volume. You must first patchify the volume into non‑overlapping 3D patches of shape `[in_channels, temporal_patch_size, patch_size, patch_size]`, flatten each patch to a token, and concatenate all tokens for the batch.
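
As a rough illustration of that patchify step, the sketch below cuts a single‑channel volume into non‑overlapping 3D patches with `Tensor.unfold` and flattens each patch to a token. The helper name `patchify_volume` and the token ordering (T, then H, then W) are our assumptions; `process_mm_info` below handles this (and file loading) for real studies:

```python
import torch

def patchify_volume(volume, patch_size=16, temporal_patch_size=16):
    """Hypothetical helper: split a (C, D, H, W) volume into flattened 3D patch tokens."""
    c, d, h, w = volume.shape
    t, gh, gw = d // temporal_patch_size, h // patch_size, w // patch_size
    patches = (
        volume
        .unfold(1, temporal_patch_size, temporal_patch_size)  # split depth
        .unfold(2, patch_size, patch_size)                    # split height
        .unfold(3, patch_size, patch_size)                    # split width
        # now (C, T, H, W, tp, p, p): grid dims first, then flatten each patch
        .permute(1, 2, 3, 0, 4, 5, 6)
        .reshape(t * gh * gw, c * temporal_patch_size * patch_size * patch_size)
    )
    return patches, torch.tensor([t, gh, gw])

tokens, grid_thw = patchify_volume(torch.randn(1, 64, 160, 160))
print(tokens.shape, grid_thw.tolist())  # torch.Size([400, 4096]) [4, 10, 10]
```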

Example helper for NIfTI volumes:

```python
import torch
from smb_biopan_utils import process_mm_info
from transformers import AutoModel

# Prepare a message spec for your volume(s). Each "image" can be a path to a NIfTI/DICOM file.
messages = [
    {
        "content": [
            {"type": "image", "image": "dummy.nii.gz"},  # volume size is [1, 64, 160, 160]
            {"type": "image", "image": "dummy.nii.gz"},
        ]
    }
]

# Convert to the patch tokens and grid descriptor expected by SMB-Vision.
# The default patch size is 16 for all dimensions.
images, grid_thw = process_mm_info(messages)  # images has shape [800 (400 * 2), 4096]

# Optional: dummy images and grid_thw with the same shapes
images, grid_thw = torch.randn(800, 4096), torch.tensor([[4, 10, 10], [4, 10, 10]])

# Load the backbone from the HF Hub (uses this repo's modeling code via trust_remote_code)
model = AutoModel.from_pretrained(
    "standardmodelbio/SMB-RAD-Encoder-v1",
    trust_remote_code=True,
    dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model.to("cuda")

# Encode features
encoded_patches, deepstack_features = model.forward_features(
    images.to("cuda"), grid_thw=grid_thw.to("cuda")
)
print(encoded_patches.shape)
# torch.Size([800, 1152])
```

## API Summary

- `SMBVisionEncoder.forward(hidden_states, grid_thw)` → `(encoded_patches, deepstack_features)`
  - `hidden_states`: float tensor of shape `(num_patches, in_channels * temporal_patch_size * patch_size^2)`
  - `grid_thw`: int tensor of shape `(num_studies, 3)` with `[T, H, W]` per study

- `SMBVisionModel.forward(hidden_states, grid_thw, context_mask, target_mask)` → `SMBVisionModelOutput`
  - Computes the MIM loss (always) and the JEPA loss (if masks are provided); see the sketch after this list
  - The output contains the losses and (optionally) the encoder/predicted hidden states

- `SMBVisionModel.forward_features(hidden_states, grid_thw)` → `(encoded_patches, deepstack_features)`
  - Convenience wrapper that calls the encoder directly for feature extraction
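
As a hedged sketch of the training entry point, reusing the `model` loaded in Quick Start. The exact mask shapes and semantics should be verified against `modeling_smb_vision.py`; here we assume boolean study‑level masks that the model expands to patch tokens internally, as described in the architecture notes:

```python
import torch

images = torch.randn(800, 4096)                      # two studies of 400 tokens each
grid_thw = torch.tensor([[4, 10, 10], [4, 10, 10]])
context_mask = torch.tensor([True, False])           # assumption: study 0 is context
target_mask = torch.tensor([False, True])            # assumption: study 1 is target

outputs = model(
    images.to("cuda"),
    grid_thw=grid_thw.to("cuda"),
    context_mask=context_mask.to("cuda"),
    target_mask=target_mask.to("cuda"),
)
# outputs is an SMBVisionModelOutput carrying the MIM (L1) and JEPA (MSE) losses
```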

## Recommended Radiology Settings

- **CT chest/abdomen**: `patch_size=16`, `temporal_patch_size=16`, `in_channels=1`
- **MRI brain**: `patch_size=16`, `temporal_patch_size=16` (or per‑sequence 2D with `temporal_patch_size=1`)
- **X‑ray**: `patch_size=16`, `temporal_patch_size=1`, `in_channels=1`
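
If you train from scratch with non‑default settings, the patching attributes documented above can be overridden on the config. A hedged sketch follows: the attribute names are taken from this README and should be verified against the repo's config class, and note that changing the patch geometry makes pretrained weights incompatible, hence `from_config` with random initialization:

```python
from transformers import AutoConfig, AutoModel

# Load the remote config, then apply the X-ray settings from the list above.
config = AutoConfig.from_pretrained(
    "standardmodelbio/SMB-RAD-Encoder-v1",
    trust_remote_code=True,
)
config.patch_size = 16
config.temporal_patch_size = 1  # 2D patching for X-ray
config.in_channels = 1          # grayscale

# Randomly initialized model with the new geometry (not pretrained weights)
model = AutoModel.from_config(config, trust_remote_code=True)
```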

## Notes

- FlashAttention‑2 can be enabled via the attention‑implementation setting in the vision config
- Ensure volume dimensions are divisible by `patch_size` and `temporal_patch_size`, or center‑crop/pad before patchifying (a padding sketch follows below)
- For multi‑sequence MRI or 4‑channel inputs, set `in_channels=4` and adapt the reconstruction paths accordingly
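
A minimal sketch of the pad‑before‑patchify step (our own helper, assuming zero padding at the end of each axis is acceptable for your modality):

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(volume, patch_size=16, temporal_patch_size=16):
    """Hypothetical helper: zero-pad (C, D, H, W) so D, H, W divide evenly into patches."""
    _, d, h, w = volume.shape
    pd = (-d) % temporal_patch_size  # extra slices needed
    ph = (-h) % patch_size
    pw = (-w) % patch_size
    # F.pad takes pairs from the last dim backwards: (w_l, w_r, h_t, h_b, d_front, d_back)
    return F.pad(volume, (0, pw, 0, ph, 0, pd))

print(pad_to_multiple(torch.randn(1, 60, 155, 155)).shape)  # torch.Size([1, 64, 160, 160])
```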

## Citation

If you use SMB‑Vision in your research, please cite this repository.