# sri-yantra-NN / app.py
# Uploaded by: shubhammjha
# Commit: "Update app.py" (ffba6d1, verified)
# final_app.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from torchvision import transforms
from PIL import Image
import gradio as gr
# --- SriYantra Custom Layer ---
class SriYantraLayer(nn.Module):
    """Average several parallel ReLU-activated linear heads, then layer-normalise.

    Every "triangle head" is an independent ``nn.Linear(in_dim, out_dim)``
    applied to the same input. The head outputs are passed through ReLU,
    averaged element-wise, and the average is normalised with ``nn.LayerNorm``.

    Args:
        in_dim: Size of the last dimension of the input tensor.
        out_dim: Size of the last dimension of the output tensor.
        num_heads: Number of parallel linear heads; must be at least 1.

    Raises:
        ValueError: If ``num_heads`` is smaller than 1.
    """

    def __init__(self, in_dim: int, out_dim: int, num_heads: int) -> None:
        super().__init__()
        # With zero heads ``forward`` would divide by ``len([]) == 0``; fail
        # here, at construction time, with a message that names the cause.
        if num_heads < 1:
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")
        self.triangle_heads = nn.ModuleList(
            [nn.Linear(in_dim, out_dim) for _ in range(num_heads)]
        )
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the layer-normalised mean of all ReLU-activated head outputs."""
        outputs = [F.relu(head(x)) for head in self.triangle_heads]
        combined = sum(outputs) / len(outputs)
        return self.norm(combined)
# --- Full Model ---
class SriYantraNet(nn.Module):
    """Stack of ``SriYantraLayer`` stages around a 64-wide linear "center".

    The layer widths run 896 -> 256 -> 256 -> 256 -> 64 -> 64 -> 256 -> 256,
    followed by a plain linear layer that produces 10 output scores.
    """

    def __init__(self):
        super().__init__()
        # Contracting half: two 4-head "outer" stages, then two 3-head "inner" stages.
        self.outer1 = SriYantraLayer(in_dim=896, out_dim=256, num_heads=4)
        self.outer2 = SriYantraLayer(in_dim=256, out_dim=256, num_heads=4)
        self.inner1 = SriYantraLayer(in_dim=256, out_dim=256, num_heads=3)
        self.inner2 = SriYantraLayer(in_dim=256, out_dim=64, num_heads=3)
        # Narrowest point of the network.
        self.center = nn.Linear(in_features=64, out_features=64)
        # Expanding half, then the output projection.
        self.decoder1 = SriYantraLayer(in_dim=64, out_dim=256, num_heads=3)
        self.decoder2 = SriYantraLayer(in_dim=256, out_dim=256, num_heads=4)
        self.final = nn.Linear(in_features=256, out_features=10)

    def forward(self, x):
        """Run ``x`` through every stage in order and return the 10 raw scores."""
        for stage in (self.outer1, self.outer2, self.inner1, self.inner2):
            x = stage(x)
        x = F.relu(self.center(x))
        for stage in (self.decoder1, self.decoder2):
            x = stage(x)
        return self.final(x)
# Load tokenizer and text model (both come from the same pretrained checkpoint).
print("Loading IndicBERTv2...")
tokenizer = AutoTokenizer.from_pretrained("ai4bharat/IndicBERTv2-MLM-only")
sanskrit_model = AutoModel.from_pretrained("ai4bharat/IndicBERTv2-MLM-only")

# Load symbol classifier and switch it to inference mode.
model = SriYantraNet()
model.eval()

# Dummy weights (replace with trained weights if available).
# They are drawn from a dedicated fixed-seed generator so that the same input
# maps to the same class after every restart, without reseeding the global RNG.
_dummy_weight_rng = torch.Generator().manual_seed(0)
with torch.no_grad():
    for param in model.parameters():
        param.uniform_(-0.1, 0.1, generator=_dummy_weight_rng)

# Image preprocessing: resize to 64x64, convert to grayscale, convert to a tensor.
image_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.Grayscale(),
    transforms.ToTensor(),
])
# --- Inference Function ---
def predict(image, sanskrit_text):
    """Classify an image/text pair and return a human-readable result string.

    The image is converted to RGB, passed through ``image_transform``,
    flattened, and truncated to its first 128 values. The text is tokenised
    and embedded with ``sanskrit_model`` (mean over the token axis). Both
    vectors are concatenated and fed to ``model``; the index of the largest
    output score is reported.

    Args:
        image: PIL image to classify, or ``None`` when no image was supplied.
        sanskrit_text: Text passed to the tokenizer.

    Returns:
        A message containing the predicted class index, or a message starting
        with "❌ Error during prediction:" if anything went wrong. This
        function never raises; every failure is reported in the returned text.
    """
    # Without this guard a missing image surfaces as the cryptic
    # "'NoneType' object has no attribute 'convert'".
    if image is None:
        return "❌ Error during prediction: no image was provided"

    try:
        image = image.convert("RGB")
        # Only the first 128 flattened pixel values are kept so that image and
        # text features together match the classifier's input width.
        img_tensor = image_transform(image).view(1, -1)[:, :128]  # shape: [1, 128]
        tokens = tokenizer(sanskrit_text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            text_emb = sanskrit_model(**tokens).last_hidden_state.mean(dim=1)  # shape: [1, 768]
            fused = torch.cat([img_tensor, text_emb], dim=1)  # shape: [1, 896]
            output = model(fused)
        pred = torch.argmax(output, dim=1).item()
        return f"🔮 Predicted Symbolic Pattern Class: {pred}"
    except Exception as e:
        # UI boundary: report the failure in the output box instead of raising.
        return f"❌ Error during prediction: {e}"
# --- Gradio Interface ---
# Two inputs (a PIL image and a text box) are passed to ``predict``; its
# returned string is shown in a single text box.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Symbol Image"),
        gr.Textbox(label="Enter Sanskrit Text"),
    ],
    outputs=gr.Textbox(label="Prediction"),
    title="🔺 SriYantra-Net: Symbolic Pattern Classifier",
    description="Upload a sacred symbol image and Sanskrit phrase to classify symbolic pattern using a fused image-text deep network.",
    theme="default",
)

# Start the web UI only when this file is run as a script, not on import.
if __name__ == "__main__":
    iface.launch()