"""
Farm Disease Detection API - Gradio Interface
ViT and specialized models for plant disease detection
"""

import re
import time
from typing import Any, Dict, List

import gradio as gr
import numpy as np
import torch
from PIL import Image

# Import models (optional: without transformers the UI still starts, but every
# analysis request returns "Model not available")
try:
    from transformers import AutoImageProcessor, AutoModelForImageClassification
    MODELS_AVAILABLE = True
except ImportError:
    MODELS_AVAILABLE = False

RELEVANCE_THRESHOLD = 0.1  # minimum softmax probability for a secondary prediction
TOP_K = 5                  # number of candidate labels inspected per image


class DiseaseDetectionAPI:
    def __init__(self):
        self.models = {}
        self.processors = {}
        # The two google/vit-* checkpoints are general-purpose ImageNet-1k
        # classifiers: their labels are ImageNet classes, not plant diseases.
        self.model_configs = {
            "vit_base_224": "google/vit-base-patch16-224",
            "vit_base_384": "google/vit-base-patch16-384",
            "plant_disease": "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification",
        }

        # Disease treatment database. Keys are compared with the model's
        # id2label names after _normalize_label, so a treatment is only found
        # when the model uses an equivalent label name.
        self.treatments = {
            "corn_blight": "Apply fungicide containing azoxystrobin or propiconazole",
            "tomato_late_blight": "Remove affected leaves, apply copper-based fungicide",
            "wheat_rust": "Apply triazole fungicides, improve air circulation",
            "potato_early_blight": "Use preventive fungicide spray, improve drainage",
            "apple_scab": "Apply sulfur-based fungicide, prune for air circulation",
            "healthy": "Continue current care routine, monitor regularly",
        }

        if MODELS_AVAILABLE:
            self.load_models()

    @staticmethod
    def _normalize_label(label: str) -> str:
        """Lower-case a label and join its words with single underscores,
        e.g. 'Tomato___Late_blight' -> 'tomato_late_blight'."""
        return re.sub(r"[\s_\-]+", "_", label.lower()).strip("_")

    def load_models(self):
        """Load (or download) every configured disease detection model."""
        for model_key, model_name in self.model_configs.items():
            try:
                print(f"Loading {model_name}...")
                # The Auto* classes resolve to ViTImageProcessor /
                # ViTForImageClassification for the ViT checkpoints.
                processor = AutoImageProcessor.from_pretrained(model_name)
                model = AutoModelForImageClassification.from_pretrained(model_name)
                self.processors[model_key] = processor
                self.models[model_key] = model
                print(f"✅ {model_name} loaded successfully")
            except Exception as e:
                print(f"❌ Failed to load {model_name}: {e}")

    def analyze_plant_health(self, image: Image.Image, model_key: str = "plant_disease") -> Dict[str, Any]:
        """Analyze plant health and detect diseases."""
        if not MODELS_AVAILABLE or model_key not in self.models:
            return {"error": "Model not available"}

        start_time = time.perf_counter()

        try:
            processor = self.processors[model_key]
            model = self.models[model_key]

            # Preprocess: the processors expect 3-channel input, so convert
            # RGBA or greyscale uploads to RGB first.
            inputs = processor(images=image.convert("RGB"), return_tensors="pt")

            # Run inference
            with torch.no_grad():
                outputs = model(**inputs)

            # Convert logits to probabilities and keep the top-k labels
            # (k is capped at the number of classes the model has).
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
            top_predictions = torch.topk(predictions, k=min(TOP_K, predictions.shape[-1]))

            # Format results; topk returns scores in descending order.
            diseases_detected = []
            for score, idx in zip(top_predictions.values[0], top_predictions.indices[0]):
                confidence = float(score)
                # Always keep the top-1 prediction so the list is never empty;
                # keep the others only if they clear the relevance threshold.
                if diseases_detected and confidence <= RELEVANCE_THRESHOLD:
                    break
                disease_name = model.config.id2label[idx.item()]
                treatment = self.treatments.get(
                    self._normalize_label(disease_name), "Consult agricultural expert"
                )
                diseases_detected.append({
                    "disease": disease_name,
                    "confidence": confidence,
                    "treatment": treatment,
                })

            primary_disease = diseases_detected[0]
            is_healthy = "healthy" in primary_disease["disease"].lower()

            # Health score: the top-1 confidence when the top label is a healthy
            # class, otherwise one minus the confidence of the top disease.
            if is_healthy:
                health_score = primary_disease["confidence"]
            else:
                health_score = 1.0 - primary_disease["confidence"]

            recommendations = self.generate_recommendations(diseases_detected, health_score)
            processing_time = time.perf_counter() - start_time

            return {
                "health_score": round(float(health_score), 2),
                "primary_disease": {
                    "name": primary_disease["disease"],
                    "confidence": round(primary_disease["confidence"], 2),
                    # A confident "healthy" prediction must not be reported as severe.
                    "severity": "none" if is_healthy else self.get_severity(primary_disease["confidence"]),
                },
                "diseases_detected": diseases_detected[:3],  # Top 3
                "recommendations": recommendations,
                "processing_time": round(processing_time, 2),
                "model_used": model_key,
            }

        except Exception as e:
            return {"error": str(e)}

    def get_severity(self, confidence: float) -> str:
        """Map the primary disease's confidence to a severity level."""
        if confidence > 0.8:
            return "severe"
        elif confidence > 0.5:
            return "moderate"
        elif confidence > 0.3:
            return "mild"
        else:
            return "minimal"

    def generate_recommendations(self, diseases: List[Dict], health_score: float) -> List[str]:
        """Generate treatment recommendations."""
        recommendations = []

        if health_score > 0.8:
            recommendations.extend([
                "Plant appears healthy - continue current care",
                "Monitor regularly for early disease signs",
                "Maintain proper watering and nutrition",
            ])
        elif health_score > 0.5:
            recommendations.extend([
                "Early intervention recommended",
                "Improve growing conditions",
                "Consider preventive treatments",
            ])
        else:
            recommendations.extend([
                "Immediate treatment required",
                "Isolate affected plants if possible",
                "Consult agricultural specialist",
            ])

        # Add specific treatments for the top two confident findings
        for disease in diseases[:2]:
            if disease["confidence"] > 0.3:
                recommendations.append(disease["treatment"])

        # Drop duplicates (e.g. two "Consult agricultural expert" entries) while
        # preserving order, then limit to 5 recommendations.
        return list(dict.fromkeys(recommendations))[:5]


# Initialize API (downloads the models on first run)
api = DiseaseDetectionAPI()


def predict_disease(image, model_choice):
    """Gradio prediction function: returns (image, formatted analysis text)."""
    if image is None:
        return None, "Please upload an image"

    # Convert to PIL Image (gr.Image(type="pil") already does this; kept for
    # callers that pass a NumPy array)
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Run analysis
    results = api.analyze_plant_health(image, model_choice)

    if "error" in results:
        return None, f"Error: {results['error']}"

    # Visualization placeholder: no annotations are drawn yet, so the
    # uploaded image is returned unchanged.
    annotated_image = image.copy()

    # Format results text
    health_score = results["health_score"]
    primary_disease = results["primary_disease"]
    health_color = "🟢" if health_score > 0.7 else "🟡" if health_score > 0.4 else "🔴"

    results_text = f"""🩺 **Plant Health Analysis**

{health_color} **Health Score**: {health_score:.1%}
🦠 **Primary Issue**: {primary_disease['name']} ({primary_disease['confidence']:.1%} confidence)
⚠️ **Severity**: {primary_disease['severity'].title()}

**🔬 Detected Issues**:"""

    for i, disease in enumerate(results["diseases_detected"], 1):
        results_text += f"\n{i}. **{disease['disease']}** ({disease['confidence']:.1%})"

    results_text += "\n\n**💡 Recommendations**:"
    for i, rec in enumerate(results["recommendations"], 1):
        results_text += f"\n{i}. {rec}"

    return annotated_image, results_text


# Gradio Interface
with gr.Blocks(title="🩺 Farm Disease Detection API") as app:
    gr.Markdown("# 🩺 Farm Disease Detection API")
    gr.Markdown("AI-powered plant disease detection and health assessment")

    with gr.Tab("🌱 Plant Analysis"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Plant Image")
                model_choice = gr.Dropdown(
                    choices=["plant_disease", "vit_base_224", "vit_base_384"],
                    value="plant_disease",
                    label="Select Model",
                )
                analyze_btn = gr.Button("🔍 Analyze Plant Health", variant="primary")

            with gr.Column():
                output_image = gr.Image(label="Plant Image")
                results_output = gr.Textbox(label="Health Analysis", lines=15)

        analyze_btn.click(
            predict_disease,
            inputs=[image_input, model_choice],
            outputs=[output_image, results_output],
            # Name the event "predict" to match the endpoint documented in the
            # API Documentation tab.
            api_name="predict",
        )

    with gr.Tab("📡 API Documentation"):
        gr.Markdown("""
## 🚀 API Endpoint

**POST** `/api/predict`

The exact route and the image encoding depend on the installed Gradio version.

### Request Format

```json
{
  "data": ["<image>", "<model_choice>"]
}
```

The `data` list follows the order of the interface inputs: the plant image, then one of the model keys below.

### Model Options

- **plant_disease**: Specialized plant disease model (recommended)
- **vit_base_224**: General-purpose ImageNet Vision Transformer, 224×224 input (faster; not trained on plant diseases)
- **vit_base_384**: General-purpose ImageNet Vision Transformer, 384×384 input (higher resolution; not trained on plant diseases)

### Response Format

The endpoint returns the two interface outputs in `data`, in order: the plant image (currently unannotated) and the formatted analysis text. That text is built from the structured result of `DiseaseDetectionAPI.analyze_plant_health`, which has this shape:

```json
{
  "health_score": 0.08,
  "primary_disease": {
    "name": "corn_blight",
    "confidence": 0.92,
    "severity": "severe"
  },
  "diseases_detected": [...],
  "recommendations": [...],
  "processing_time": 1.2,
  "model_used": "plant_disease"
}
```
""")


if __name__ == "__main__":
    app.launch()
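

# ---------------------------------------------------------------------------
# Hypothetical helper: a sketch that is NOT used by the app above.
# analyze_plant_health derives the health score from the top-1 label only.
# A possible alternative is to sum the probability mass of every label that
# mentions "healthy", which helps with models that have one healthy class per
# crop (several healthy classes can together outweigh a single disease label).
# Assumption: healthy classes contain "healthy" in their id2label names. This
# depends on the checkpoint and is not guaranteed; the ImageNet ViT models have
# no such classes. Inside analyze_plant_health the input mapping could be built
# as {model.config.id2label[i]: float(p) for i, p in enumerate(predictions[0])}.
# ---------------------------------------------------------------------------
from typing import Mapping


def healthy_probability_mass(label_probs: Mapping[str, float]) -> float:
    """Return the total probability assigned to labels containing 'healthy'.

    label_probs maps label names to softmax probabilities. The sum is clamped
    to [0.0, 1.0] to absorb floating-point drift.
    """
    mass = sum(
        (p for label, p in label_probs.items() if "healthy" in label.lower()),
        0.0,
    )
    return min(max(mass, 0.0), 1.0)


# Example with hypothetical labels:
# healthy_probability_mass({"Healthy Apple": 0.5, "Apple Scab": 0.25, "healthy_corn": 0.25})
# -> 0.75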