import functools
import os
import random

import albumentations
import cv2
import gradio as gr
import numpy as np
import timm
import torch
# All tensors and the model are placed on the CPU.
device = torch.device('cpu')

# Maps each model output index to its class name. The order of this list
# defines the index, so it presumably must match the label encoding used at
# training time (TODO confirm against the training script).
labels = dict(
    enumerate(
        [
            'bacterial_leaf_blight',
            'bacterial_leaf_streak',
            'bacterial_panicle_blight',
            'blast',
            'brown_spot',
            'dead_heart',
            'downy_mildew',
            'hispa',
            'normal',
            'tungro',
        ]
    )
)
| def inference_fn(model, image=None): | |
| model.eval() | |
| image = image.to(device) | |
| with torch.no_grad(): | |
| output = model(image.unsqueeze(0)) | |
| out = output.sigmoid().detach().cpu().numpy().flatten() | |
| return out | |
| def predict(image=None) -> dict: | |
| mean = (0.485, 0.456, 0.406) | |
| std = (0.229, 0.224, 0.225) | |
| augmentations = albumentations.Compose( | |
| [ | |
| albumentations.Resize(256, 256), | |
| albumentations.HorizontalFlip(p=0.5), | |
| albumentations.VerticalFlip(p=0.5), | |
| albumentations.Normalize(mean, std, max_pixel_value=255.0, always_apply=True), | |
| ] | |
| ) | |
| augmented = augmentations(image=image) | |
| image = augmented["image"] | |
| image = np.transpose(image, (2, 0, 1)) | |
| image = torch.tensor(image, dtype=torch.float32) | |
| model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=10) | |
| model.load_state_dict(torch.load("/home/aswin/Downloads/paddy_model.pth", map_location=torch.device(device))) | |
| model.to(device) | |
| predicted = inference_fn(model, image) | |
| return {labels[i]: float(predicted[i]) for i in range(10)} | |
| interface = gr.Interface(fn=predict, | |
| inputs=gr.inputs.Image(), | |
| outputs=gr.outputs.Label(num_top_classes=10), | |
| interpretation='default').launch() | |
| interface.launch() |