Spaces:
Sleeping
Sleeping
| # final_app.py | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModel | |
| from torchvision import transforms | |
| from PIL import Image | |
| import gradio as gr | |
# --- SriYantra Custom Layer ---
class SriYantraLayer(nn.Module):
    """Average of several ReLU-activated linear "triangle" heads, then LayerNorm.

    Every head is an ``nn.Linear(in_dim, out_dim)`` applied to the same input.
    The ReLU outputs of all heads are averaged element-wise and the result is
    normalised with ``nn.LayerNorm(out_dim)``.

    Args:
        in_dim: Size of the input's last dimension.
        out_dim: Size of the output's last dimension.
        num_heads: Number of parallel heads to average; must be at least 1.

    Raises:
        ValueError: If ``num_heads`` is less than 1 (the average would
            otherwise divide by zero on the first forward pass).
    """

    def __init__(self, in_dim, out_dim, num_heads):
        super().__init__()
        if num_heads < 1:
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")
        self.triangle_heads = nn.ModuleList([
            nn.Linear(in_dim, out_dim) for _ in range(num_heads)
        ])
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x):
        """Map ``(..., in_dim)`` to ``(..., out_dim)``."""
        outputs = [F.relu(head(x)) for head in self.triangle_heads]
        combined = sum(outputs) / len(outputs)
        return self.norm(combined)


# --- Full Model ---
class SriYantraNet(nn.Module):
    """Stack of SriYantraLayers (encoder -> bottleneck -> decoder) plus a linear classifier.

    Feature widths: in_dim -> 256 -> 256 -> 256 -> 64 (encoder), a 64 -> 64
    linear "center" with ReLU, 64 -> 256 -> 256 (decoder), then a linear
    projection to ``num_classes`` logits.

    Args:
        in_dim: Width of the input feature vector. Defaults to 896, the width
            of the fused image+text vector that ``predict`` builds in this app.
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, in_dim=896, num_classes=10):
        super().__init__()
        # Encoder ("outer" then "inner" triangles).
        self.outer1 = SriYantraLayer(in_dim, 256, 4)
        self.outer2 = SriYantraLayer(256, 256, 4)
        self.inner1 = SriYantraLayer(256, 256, 3)
        self.inner2 = SriYantraLayer(256, 64, 3)
        # 64-wide bottleneck.
        self.center = nn.Linear(64, 64)
        # Decoder back up to 256 features, then the classification head.
        self.decoder1 = SriYantraLayer(64, 256, 3)
        self.decoder2 = SriYantraLayer(256, 256, 4)
        self.final = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map ``(..., in_dim)`` features to ``(..., num_classes)`` logits."""
        x = self.outer1(x)
        x = self.outer2(x)
        x = self.inner1(x)
        x = self.inner2(x)
        x = F.relu(self.center(x))
        x = self.decoder1(x)
        x = self.decoder2(x)
        return self.final(x)
# Text encoder: the IndicBERTv2 tokenizer and model share one Hub identifier,
# kept in a single constant so the two can never drift apart.
TEXT_MODEL_NAME = "ai4bharat/IndicBERTv2-MLM-only"

print("Loading IndicBERTv2...")
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
sanskrit_model = AutoModel.from_pretrained(TEXT_MODEL_NAME)
# Load symbol classifier
model = SriYantraNet()
model.eval()

# Dummy weights (replace with trained weights if available).
# They are drawn from a fixed-seed generator so the placeholder weights, and
# therefore the predictions, are identical every time the app starts instead
# of changing on each restart.
DUMMY_WEIGHT_SEED = 0
_dummy_weight_rng = torch.Generator().manual_seed(DUMMY_WEIGHT_SEED)
with torch.no_grad():
    for param in model.parameters():
        param.uniform_(-0.1, 0.1, generator=_dummy_weight_rng)
# Image preprocessing pipeline, applied in order:
#   1. resize to 64x64 pixels,
#   2. collapse to a single grayscale channel,
#   3. convert the PIL image to a tensor.
_preprocessing_steps = [
    transforms.Resize((64, 64)),
    transforms.Grayscale(),
    transforms.ToTensor(),
]
image_transform = transforms.Compose(_preprocessing_steps)
# --- Inference Function ---
def predict(image, sanskrit_text):
    """Classify a symbol image together with a Sanskrit phrase.

    The image goes through ``image_transform`` and is reduced to 128 features
    by average-pooling the whole picture onto an 8x16 grid. The text is
    embedded by mean-pooling IndicBERTv2's last hidden state. The two vectors
    are concatenated and passed to ``model``; the arg-max class is reported.

    Args:
        image: PIL image from the Gradio ``Image`` component, or ``None``
            when nothing was uploaded.
        sanskrit_text: Text from the Gradio ``Textbox``; ``None`` is treated
            as an empty string.

    Returns:
        A user-facing string with the predicted class index, or an error
        message. Errors are returned as text rather than raised so the UI
        always shows a result.
    """
    if image is None:
        return "❌ Error during prediction: please upload an image."
    try:
        # image_transform ends in Grayscale + ToTensor, giving (1, H, W).
        pixels = image_transform(image.convert("RGB"))
        # Flattening and keeping the first 128 values would use only the top
        # two pixel rows of a 64x64 image; pooling onto an 8x16 grid instead
        # summarises the whole image in the same 128 features.
        pooled = F.adaptive_avg_pool2d(pixels.unsqueeze(0), (8, 16))
        img_tensor = pooled.flatten(start_dim=1)  # shape: [1, 128]

        tokens = tokenizer(sanskrit_text or "", return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            text_emb = sanskrit_model(**tokens).last_hidden_state.mean(dim=1)  # shape: [1, 768]
            fused = torch.cat([img_tensor, text_emb], dim=1)  # shape: [1, 896]
            output = model(fused)
        pred = torch.argmax(output, dim=1).item()
        return f"🔮 Predicted Symbolic Pattern Class: {pred}"
    except Exception as e:
        # UI boundary: report the failure in the output box, and print the
        # traceback so it also shows up in the server logs.
        import traceback
        traceback.print_exc()
        return f"❌ Error during prediction: {str(e)}"
# --- Gradio Interface ---
# Wires predict() to an image upload + text box and shows its string result.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Symbol Image"),
        gr.Textbox(label="Enter Sanskrit Text"),
    ],
    outputs=gr.Textbox(label="Prediction"),
    # The emoji here had been mis-decoded (UTF-8 read as Thai TIS-620);
    # restored as the red triangle. NOTE(review): confirm the intended glyph.
    title="🔺 SriYantra-Net: Symbolic Pattern Classifier",
    description="Upload a sacred symbol image and Sanskrit phrase to classify symbolic pattern using a fused image-text deep network.",
    theme="default",
)

# Start the web server only when this file is run directly, not on import.
if __name__ == "__main__":
    iface.launch()