File size: 3,289 Bytes
ffba6d1
5e32834
 
9015bf0
ffba6d1
 
5e32834
ffba6d1
5e32834
ffba6d1
5e32834
ffba6d1
 
9015bf0
ffba6d1
9015bf0
ffba6d1
5e32834
 
ffba6d1
 
 
5e32834
ffba6d1
5e32834
ffba6d1
 
 
 
 
 
 
 
 
 
5e32834
 
 
 
 
 
9015bf0
5e32834
 
9015bf0
5e32834
ffba6d1
 
 
 
8b8a54f
ffba6d1
 
3bf1e19
8b8a54f
ffba6d1
3bf1e19
 
 
9015bf0
ffba6d1
 
 
 
 
3bf1e19
 
 
ffba6d1
3bf1e19
ffba6d1
 
 
 
9015bf0
ffba6d1
 
 
 
 
 
1b86128
ffba6d1
8787115
9015bf0
8787115
1b86128
3bf1e19
ffba6d1
 
 
 
 
 
 
 
 
 
1b86128
 
 
ffba6d1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# final_app.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from torchvision import transforms
from PIL import Image
import gradio as gr

# --- SriYantra Custom Layer ---
class SriYantraLayer(nn.Module):
    """Average several ReLU-activated linear projections, then layer-normalize.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        num_heads: Number of parallel ``nn.Linear`` heads whose activations
            are averaged.
    """

    def __init__(self, in_dim, out_dim, num_heads):
        super().__init__()
        # Heads are created before the norm so parameter order (and therefore
        # state-dict layout and random initialization order) stays fixed.
        heads = []
        for _ in range(num_heads):
            heads.append(nn.Linear(in_dim, out_dim))
        self.triangle_heads = nn.ModuleList(heads)
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x):
        """Return the layer-normalized mean of every head's ReLU output."""
        accumulated = 0
        for head in self.triangle_heads:
            accumulated = accumulated + F.relu(head(x))
        mean_activation = accumulated / len(self.triangle_heads)
        return self.norm(mean_activation)

# --- Full Model ---
class SriYantraNet(nn.Module):
    """Stack of SriYantraLayer stages that narrows to a 64-wide centre and widens again.

    The network accepts 896 input features and produces 10 output values
    per row via the final linear layer.
    """

    def __init__(self):
        super().__init__()
        wide, narrow = 256, 64

        # Narrowing path: 896 -> 256 -> 256 -> 256 -> 64.
        self.outer1 = SriYantraLayer(896, wide, 4)
        self.outer2 = SriYantraLayer(wide, wide, 4)
        self.inner1 = SriYantraLayer(wide, wide, 3)
        self.inner2 = SriYantraLayer(wide, narrow, 3)

        # Central plain linear map at the narrowest width.
        self.center = nn.Linear(narrow, narrow)

        # Widening path: 64 -> 256 -> 256, then the 10-way output layer.
        self.decoder1 = SriYantraLayer(narrow, wide, 3)
        self.decoder2 = SriYantraLayer(wide, wide, 4)
        self.final = nn.Linear(wide, 10)

    def forward(self, x):
        """Run ``x`` through every stage in order and return the final layer's output."""
        for stage in (self.outer1, self.outer2, self.inner1, self.inner2):
            x = stage(x)

        x = F.relu(self.center(x))

        for stage in (self.decoder1, self.decoder2):
            x = stage(x)

        return self.final(x)

# Load tokenizer and text model (both come from the same checkpoint name).
print("Loading IndicBERTv2...")
tokenizer = AutoTokenizer.from_pretrained("ai4bharat/IndicBERTv2-MLM-only")
sanskrit_model = AutoModel.from_pretrained("ai4bharat/IndicBERTv2-MLM-only")

# Load symbol classifier and switch it to inference mode right away.
model = SriYantraNet().eval()

# Dummy weights (replace with trained weights if available): every parameter
# is overwritten in place with samples drawn uniformly from [-0.1, 0.1).
for param in model.parameters():
    nn.init.uniform_(param, -0.1, 0.1)

# Image preprocessing: resize to 64x64, convert to grayscale, then to a tensor.
image_transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ]
)

# --- Inference Function ---
def predict(image, sanskrit_text):
    """Classify an image/text pair and return a human-readable result string.

    The image is converted to RGB, passed through ``image_transform``,
    flattened, and truncated to its first 128 values. The text is tokenized
    and embedded by mean-pooling the last hidden state of ``sanskrit_model``.
    Both vectors are concatenated and fed to ``model``; the index of the
    largest output is reported.

    Args:
        image: PIL image from the UI, or ``None`` when no image was supplied
            (e.g. the image input was left empty).
        sanskrit_text: Text to tokenize and embed.

    Returns:
        A string with the predicted class index, or a string starting with
        "❌ Error during prediction:" when the image is missing or any step
        of the pipeline raises.
    """
    if image is None:
        # Report the missing upload directly instead of surfacing an
        # AttributeError about NoneType from the conversion below.
        return "❌ Error during prediction: no image was provided"

    try:
        pixels = image_transform(image.convert("RGB"))
        # NOTE(review): only the first 128 flattened pixel values are kept so
        # that image + text features add up to the 896 inputs `model` expects;
        # the rest of the image is discarded -- confirm this is intended.
        img_tensor = pixels.view(1, -1)[:, :128]  # shape: [1, 128]

        tokens = tokenizer(sanskrit_text, return_tensors="pt", truncation=True, padding=True)

        # One no-grad region covers both the text encoder and the classifier.
        with torch.no_grad():
            # Mean over the token dimension; expected shape: [1, 768]
            text_emb = sanskrit_model(**tokens).last_hidden_state.mean(dim=1)
            fused = torch.cat([img_tensor, text_emb], dim=1)  # expected shape: [1, 896]
            output = model(fused)

        pred = torch.argmax(output, dim=1).item()
        return f"🔮 Predicted Symbolic Pattern Class: {pred}"
    except Exception as e:
        # UI boundary: turn any failure into a message the user can read.
        return f"❌ Error during prediction: {e}"

# --- Gradio Interface ---
# Two inputs (symbol image as a PIL object, Sanskrit text) are passed to
# `predict`; its returned string is shown in a single output textbox.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Symbol Image"),
        gr.Textbox(label="Enter Sanskrit Text")
    ],
    outputs=gr.Textbox(label="Prediction"),
    # Title emoji restored from its mis-decoded (mojibake) form.
    title="🔺 SriYantra-Net: Symbolic Pattern Classifier",
    description="Upload a sacred symbol image and Sanskrit phrase to classify symbolic pattern using a fused image-text deep network.",
    theme="default"
)

if __name__ == "__main__":
    # Start the web UI only when executed as a script, not on import.
    iface.launch()