AbstractPhil committed on
Commit
f459cdd
·
verified ·
1 Parent(s): b6349d4

Update vit_zana_v3.py

Browse files
Files changed (1) hide show
  1. vit_zana_v3.py +10 -10
vit_zana_v3.py CHANGED
@@ -21,7 +21,7 @@ class PentachoraEmbedding(nn.Module):
21
 
22
  def __init__(self, vertices: torch.Tensor):
23
  super().__init__()
24
- assert vertices.shape == (5, 128), f"Expected shape (5, 128), got {vertices.shape}"
25
 
26
  self.embed_dim = vertices.shape[-1]
27
 
@@ -135,7 +135,8 @@ class BaselineViT(nn.Module):
135
 
136
  def __init__(
137
  self,
138
- pentachora_list: list, # List of torch.Tensor, each [5, 128]
 
139
  img_size: int = 32,
140
  patch_size: int = 4,
141
  embed_dim: int = 512,
@@ -155,13 +156,12 @@ class BaselineViT(nn.Module):
155
  # Validate each pentachora
156
  for i, penta in enumerate(pentachora_list):
157
  assert isinstance(penta, torch.Tensor), f"Item {i} is not a tensor"
158
- assert penta.shape == (5, 128), f"Item {i} has shape {penta.shape}, expected (5, 128)"
159
 
160
  self.num_classes = len(pentachora_list)
161
  self.embed_dim = embed_dim
162
  self.num_patches = (img_size // patch_size) ** 2
163
  self.similarity_mode = similarity_mode
164
- self.pentachora_dim = 128 # Always 128 from vocab
165
 
166
  # Create individual pentachora embeddings from list
167
  self.class_pentachora = nn.ModuleList([
@@ -239,15 +239,15 @@ class BaselineViT(nn.Module):
239
  """
240
  if self.similarity_mode == 'rose':
241
  # Stack all vertices into single tensor for batch Rose scoring
242
- all_vertices = torch.stack([penta.vertices for penta in self.class_pentachora]) # [100, 5, 128]
243
  # Expand features for batch computation
244
- features_exp = features.unsqueeze(1).expand(-1, self.num_classes, -1) # [B, 100, 128]
245
  # Compute Rose scores in parallel
246
- return PentachoronStabilizer.rose_score_magnitude(features_exp.reshape(-1, 128), all_vertices.repeat(features.shape[0], 1, 1)).reshape(features.shape[0], -1)
247
  else:
248
  # Stack all centroids
249
- centroids = torch.stack([penta.centroid_norm for penta in self.class_pentachora]) # [100, 128]
250
- features_norm = F.normalize(features, dim=-1) # [B, 128]
251
  return torch.matmul(features_norm, centroids.T) # [B, 100]
252
 
253
 
@@ -319,5 +319,5 @@ class BaselineViT(nn.Module):
319
  if __name__ == "__main__":
320
  print("BaselineViT requires:")
321
  print(" 1. PentachoronStabilizer loaded externally")
322
- print(" 2. pentachora_batch tensor [num_classes, 5, 128]")
323
  print("\nNo random initialization. No fallbacks.")
 
21
 
22
  def __init__(self, vertices: torch.Tensor):
23
  super().__init__()
24
+ #assert vertices.shape == (5, 128), f"Expected shape (5, 128), got {vertices.shape}"
25
 
26
  self.embed_dim = vertices.shape[-1]
27
 
 
135
 
136
  def __init__(
137
  self,
138
+ pentachora_list: list, # List of torch.Tensor, each [5, vocab_dim]
139
+ vocab_dim: int = 256,
140
  img_size: int = 32,
141
  patch_size: int = 4,
142
  embed_dim: int = 512,
 
156
  # Validate each pentachora
157
  for i, penta in enumerate(pentachora_list):
158
  assert isinstance(penta, torch.Tensor), f"Item {i} is not a tensor"
 
159
 
160
  self.num_classes = len(pentachora_list)
161
  self.embed_dim = embed_dim
162
  self.num_patches = (img_size // patch_size) ** 2
163
  self.similarity_mode = similarity_mode
164
+ self.pentachora_dim = vocab_dim
165
 
166
  # Create individual pentachora embeddings from list
167
  self.class_pentachora = nn.ModuleList([
 
239
  """
240
  if self.similarity_mode == 'rose':
241
  # Stack all vertices into single tensor for batch Rose scoring
242
+ all_vertices = torch.stack([penta.vertices for penta in self.class_pentachora]) # [100, 5, vocab_dim]
243
  # Expand features for batch computation
244
+ features_exp = features.unsqueeze(1).expand(-1, self.num_classes, -1) # [B, 100, vocab_dim]
245
  # Compute Rose scores in parallel
246
+ return PentachoronStabilizer.rose_score_magnitude(features_exp.reshape(-1, self.embed_dim), all_vertices.repeat(features.shape[0], 1, 1)).reshape(features.shape[0], -1)
247
  else:
248
  # Stack all centroids
249
+ centroids = torch.stack([penta.centroid_norm for penta in self.class_pentachora]) # [100, vocab_dim]
250
+ features_norm = F.normalize(features, dim=-1) # [B, vocab_dim]
251
  return torch.matmul(features_norm, centroids.T) # [B, 100]
252
 
253
 
 
319
  if __name__ == "__main__":
320
  print("BaselineViT requires:")
321
  print(" 1. PentachoronStabilizer loaded externally")
322
+ print(" 2. pentachora_batch tensor [num_classes, 5, vocab_dim]")
323
  print("\nNo random initialization. No fallbacks.")