Update vit_zana_v3.py
vit_zana_v3.py  CHANGED  (+10 -10)
@@ -21,7 +21,7 @@ class PentachoraEmbedding(nn.Module):
 
     def __init__(self, vertices: torch.Tensor):
         super().__init__()
-        assert vertices.shape == (5, 128), f"Expected shape (5, 128), got {vertices.shape}"
+        #assert vertices.shape == (5, 128), f"Expected shape (5, 128), got {vertices.shape}"
 
         self.embed_dim = vertices.shape[-1]
 
@@ -135,7 +135,8 @@ class BaselineViT(nn.Module):
 
     def __init__(
         self,
-        pentachora_list: list, # List of torch.Tensor, each [5, 128]
+        pentachora_list: list, # List of torch.Tensor, each [5, vocab_dim]
+        vocab_dim: int = 256,
        img_size: int = 32,
        patch_size: int = 4,
        embed_dim: int = 512,
@@ -155,13 +156,12 @@ class BaselineViT(nn.Module):
         # Validate each pentachora
         for i, penta in enumerate(pentachora_list):
             assert isinstance(penta, torch.Tensor), f"Item {i} is not a tensor"
-            assert penta.shape == (5, 128), f"Item {i} has shape {penta.shape}, expected (5, 128)"
 
         self.num_classes = len(pentachora_list)
         self.embed_dim = embed_dim
         self.num_patches = (img_size // patch_size) ** 2
         self.similarity_mode = similarity_mode
-        self.pentachora_dim = 128
+        self.pentachora_dim = vocab_dim
 
         # Create individual pentachora embeddings from list
         self.class_pentachora = nn.ModuleList([
@@ -239,15 +239,15 @@ class BaselineViT(nn.Module):
         """
         if self.similarity_mode == 'rose':
             # Stack all vertices into single tensor for batch Rose scoring
-            all_vertices = torch.stack([penta.vertices for penta in self.class_pentachora]) # [100, 5, 128]
+            all_vertices = torch.stack([penta.vertices for penta in self.class_pentachora]) # [100, 5, vocab_dim]
             # Expand features for batch computation
-            features_exp = features.unsqueeze(1).expand(-1, self.num_classes, -1) # [B, 100, 128]
+            features_exp = features.unsqueeze(1).expand(-1, self.num_classes, -1) # [B, 100, vocab_dim]
             # Compute Rose scores in parallel
-            return PentachoronStabilizer.rose_score_magnitude(features_exp.reshape(-1, 128), all_vertices.repeat(features.shape[0], 1, 1)).reshape(features.shape[0], -1)
+            return PentachoronStabilizer.rose_score_magnitude(features_exp.reshape(-1, self.embed_dim), all_vertices.repeat(features.shape[0], 1, 1)).reshape(features.shape[0], -1)
         else:
             # Stack all centroids
-            centroids = torch.stack([penta.centroid_norm for penta in self.class_pentachora]) # [100, 128]
-            features_norm = F.normalize(features, dim=-1) # [B, 128]
+            centroids = torch.stack([penta.centroid_norm for penta in self.class_pentachora]) # [100, vocab_dim]
+            features_norm = F.normalize(features, dim=-1) # [B, vocab_dim]
             return torch.matmul(features_norm, centroids.T) # [B, 100]
 
 
@@ -319,5 +319,5 @@ class BaselineViT(nn.Module):
 if __name__ == "__main__":
     print("BaselineViT requires:")
     print(" 1. PentachoronStabilizer loaded externally")
-    print(" 2. pentachora_batch tensor [num_classes, 5, 128]")
+    print(" 2. pentachora_batch tensor [num_classes, 5, vocab_dim]")
     print("\nNo random initialization. No fallbacks.")
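For reference, the centroid branch touched above reduces to cosine similarity between L2-normalized features and stacked class centroids. A self-contained shape check, where the batch size, class count, and random tensors are stand-ins for the real features and penta.centroid_norm:

import torch
import torch.nn.functional as F

B, num_classes, vocab_dim = 8, 100, 256
features = torch.randn(B, vocab_dim)
centroids = F.normalize(torch.randn(num_classes, vocab_dim), dim=-1)  # stand-in for stacked penta.centroid_norm

features_norm = F.normalize(features, dim=-1)      # [B, vocab_dim]
scores = torch.matmul(features_norm, centroids.T)  # [B, num_classes] cosine similarities
assert scores.shape == (B, num_classes)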