shubhammjha commited on
Commit
ffba6d1
ยท
verified ยท
1 Parent(s): 8787115

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -69
app.py CHANGED
@@ -1,36 +1,38 @@
1
- # app.py
2
- import gradio as gr
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
- import torchvision.transforms as T
 
7
  from PIL import Image
 
8
 
9
- # --- Sacred Geometry Symbolic Classifier Core ---
10
  class SriYantraLayer(nn.Module):
11
- def __init__(self, in_features, out_features, num_heads=3):
12
- super(SriYantraLayer, self).__init__()
13
  self.triangle_heads = nn.ModuleList([
14
- nn.Linear(in_features, out_features) for _ in range(num_heads)
15
  ])
16
- self.norm = nn.LayerNorm(out_features)
17
 
18
  def forward(self, x):
19
- head_outputs = [F.relu(head(x)) for head in self.triangle_heads]
20
- merged = sum(head_outputs) / len(head_outputs)
21
- return self.norm(merged)
22
 
 
23
  class SriYantraNet(nn.Module):
24
- def __init__(self, input_dim, hidden_dim, latent_dim, output_dim):
25
- super(SriYantraNet, self).__init__()
26
- self.outer1 = SriYantraLayer(input_dim, hidden_dim, num_heads=4)
27
- self.outer2 = SriYantraLayer(hidden_dim, hidden_dim, num_heads=4)
28
- self.inner1 = SriYantraLayer(hidden_dim, hidden_dim, num_heads=3)
29
- self.inner2 = SriYantraLayer(hidden_dim, latent_dim, num_heads=3)
30
- self.center = nn.Linear(latent_dim, latent_dim)
31
- self.decoder1 = SriYantraLayer(latent_dim, hidden_dim, num_heads=3)
32
- self.decoder2 = SriYantraLayer(hidden_dim, hidden_dim, num_heads=4)
33
- self.final = nn.Linear(hidden_dim, output_dim)
34
 
35
  def forward(self, x):
36
  x = self.outer1(x)
@@ -42,72 +44,59 @@ class SriYantraNet(nn.Module):
42
  x = self.decoder2(x)
43
  return self.final(x)
44
 
45
- # --- Image Feature Extractor ---
46
- class ImageToVector(nn.Module):
47
- def __init__(self, out_dim=128):
48
- super(ImageToVector, self).__init__()
49
- self.encoder = nn.Sequential(
50
- nn.Conv2d(1, 16, 3, stride=2),
51
- nn.ReLU(),
52
- nn.Conv2d(16, 32, 3, stride=2),
53
- nn.ReLU(),
54
- nn.Conv2d(32, 64, 3, stride=2),
55
- nn.ReLU(),
56
- nn.Flatten(),
57
- nn.Linear(64*6*6, out_dim),
58
- nn.ReLU()
59
- )
60
-
61
- def forward(self, x):
62
- return self.encoder(x)
63
-
64
- # --- Complete Classifier ---
65
- class SacredSymbolClassifier(nn.Module):
66
- def __init__(self):
67
- super(SacredSymbolClassifier, self).__init__()
68
- self.visual = ImageToVector(out_dim=128)
69
- self.symbolic = SriYantraNet(128, 256, 64, 10)
70
-
71
- def forward(self, x):
72
- v = self.visual(x)
73
- return self.symbolic(v)
74
 
75
- # Instantiate model
76
- model = SacredSymbolClassifier()
77
  model.eval()
78
 
79
- # Dummy weights for demo (replace with trained model)
80
  with torch.no_grad():
81
  for param in model.parameters():
82
  param.uniform_(-0.1, 0.1)
83
 
84
- # --- Preprocessing ---
85
- transform = T.Compose([
86
- T.Grayscale(),
87
- T.Resize((64, 64)),
88
- T.ToTensor()
89
  ])
90
 
91
  # --- Inference Function ---
92
- def classify_symbol(img: Image.Image):
93
  try:
94
- img_tensor = transform(img).unsqueeze(0) # [1, 1, 64, 64]
 
 
 
95
  with torch.no_grad():
96
- output = model(img_tensor)
 
 
 
 
 
97
  pred = torch.argmax(output, dim=1).item()
 
98
  return f"๐Ÿ”ฎ Predicted Symbolic Pattern Class: {pred}"
99
  except Exception as e:
100
  return f"โŒ Error during prediction: {str(e)}"
101
 
102
  # --- Gradio Interface ---
103
- demo = gr.Interface(
104
- fn=classify_symbol,
105
- inputs=gr.Image(type="pil", label="Upload Sacred Image (Yantra, Mandala, etc.)"),
106
- outputs=gr.Textbox(label="Prediction Result"),
107
- title="๐Ÿ›ธ Sacred Geometry Symbol Classifier",
108
- description="A Sri-Yantra inspired deep learning model to classify symbolic sacred patterns.",
109
- theme="compact"
 
 
 
110
  )
111
 
112
  if __name__ == "__main__":
113
- demo.launch()
 
1
+ # final_app.py
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
+ from transformers import AutoTokenizer, AutoModel
6
+ from torchvision import transforms
7
  from PIL import Image
8
+ import gradio as gr
9
 
10
+ # --- SriYantra Custom Layer ---
11
  class SriYantraLayer(nn.Module):
12
+ def __init__(self, in_dim, out_dim, num_heads):
13
+ super().__init__()
14
  self.triangle_heads = nn.ModuleList([
15
+ nn.Linear(in_dim, out_dim) for _ in range(num_heads)
16
  ])
17
+ self.norm = nn.LayerNorm(out_dim)
18
 
19
  def forward(self, x):
20
+ outputs = [F.relu(head(x)) for head in self.triangle_heads]
21
+ combined = sum(outputs) / len(outputs)
22
+ return self.norm(combined)
23
 
24
+ # --- Full Model ---
25
  class SriYantraNet(nn.Module):
26
+ def __init__(self):
27
+ super().__init__()
28
+ self.outer1 = SriYantraLayer(896, 256, 4)
29
+ self.outer2 = SriYantraLayer(256, 256, 4)
30
+ self.inner1 = SriYantraLayer(256, 256, 3)
31
+ self.inner2 = SriYantraLayer(256, 64, 3)
32
+ self.center = nn.Linear(64, 64)
33
+ self.decoder1 = SriYantraLayer(64, 256, 3)
34
+ self.decoder2 = SriYantraLayer(256, 256, 4)
35
+ self.final = nn.Linear(256, 10)
36
 
37
  def forward(self, x):
38
  x = self.outer1(x)
 
44
  x = self.decoder2(x)
45
  return self.final(x)
46
 
47
+ # Load tokenizer and text model
48
+ print("Loading IndicBERTv2...")
49
+ tokenizer = AutoTokenizer.from_pretrained("ai4bharat/IndicBERTv2-MLM-only")
50
+ sanskrit_model = AutoModel.from_pretrained("ai4bharat/IndicBERTv2-MLM-only")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ # Load symbol classifier
53
+ model = SriYantraNet()
54
  model.eval()
55
 
56
+ # Dummy weights (replace with trained weights if available)
57
  with torch.no_grad():
58
  for param in model.parameters():
59
  param.uniform_(-0.1, 0.1)
60
 
61
+ # Image preprocessing
62
+ image_transform = transforms.Compose([
63
+ transforms.Resize((64, 64)),
64
+ transforms.Grayscale(),
65
+ transforms.ToTensor()
66
  ])
67
 
68
  # --- Inference Function ---
69
+ def predict(image, sanskrit_text):
70
  try:
71
+ image = image.convert("RGB")
72
+ img_tensor = image_transform(image).view(1, -1)[:, :128] # shape: [1, 128]
73
+
74
+ tokens = tokenizer(sanskrit_text, return_tensors="pt", truncation=True, padding=True)
75
  with torch.no_grad():
76
+ text_emb = sanskrit_model(**tokens).last_hidden_state.mean(dim=1) # shape: [1, 768]
77
+
78
+ fused = torch.cat([img_tensor, text_emb], dim=1) # shape: [1, 896]
79
+
80
+ with torch.no_grad():
81
+ output = model(fused)
82
  pred = torch.argmax(output, dim=1).item()
83
+
84
  return f"๐Ÿ”ฎ Predicted Symbolic Pattern Class: {pred}"
85
  except Exception as e:
86
  return f"โŒ Error during prediction: {str(e)}"
87
 
88
  # --- Gradio Interface ---
89
+ iface = gr.Interface(
90
+ fn=predict,
91
+ inputs=[
92
+ gr.Image(type="pil", label="Upload Symbol Image"),
93
+ gr.Textbox(label="Enter Sanskrit Text")
94
+ ],
95
+ outputs=gr.Textbox(label="Prediction"),
96
+ title="๐Ÿ”บ SriYantra-Net: Symbolic Pattern Classifier",
97
+ description="Upload a sacred symbol image and Sanskrit phrase to classify symbolic pattern using a fused image-text deep network.",
98
+ theme="default"
99
  )
100
 
101
  if __name__ == "__main__":
102
+ iface.launch()