Spaces:
Sleeping
Sleeping
| """ | |
| Farm Disease Detection API - Gradio Interface | |
| ViT and specialized models for plant disease detection | |
| """ | |
| import gradio as gr | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import json | |
| import base64 | |
| import io | |
| import time | |
| from typing import List, Dict, Any | |
| # Import models | |
| try: | |
| from transformers import ViTImageProcessor, ViTForImageClassification | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| MODELS_AVAILABLE = True | |
| except ImportError: | |
| MODELS_AVAILABLE = False | |
class DiseaseDetectionAPI:
    """Wrapper around Hugging Face image classifiers for plant-health checks.

    Loads the checkpoints listed in ``model_configs`` and turns a model's
    predictions into a report: health score, most likely labels, severity
    and treatment recommendations.
    """

    # Number of highest-probability labels inspected per image.
    TOP_K = 5
    # Labels after the first whose probability is at or below this are dropped.
    RELEVANCE_THRESHOLD = 0.1
    # Treatment text used when no entry in ``self.treatments`` matches a label.
    DEFAULT_TREATMENT = "Consult agricultural expert"

    def __init__(self):
        # model key -> loaded model / image processor (filled by load_models)
        self.models = {}
        self.processors = {}
        # model key -> Hugging Face Hub checkpoint id
        self.model_configs = {
            "vit_base_224": "google/vit-base-patch16-224",
            "vit_base_384": "google/vit-base-patch16-384",
            "plant_disease": "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification",
        }
        # Disease treatment database, keyed by normalized label
        # (lowercase words joined by single underscores, see _normalize_label).
        self.treatments = {
            "corn_blight": "Apply fungicide containing azoxystrobin or propiconazole",
            "tomato_late_blight": "Remove affected leaves, apply copper-based fungicide",
            "wheat_rust": "Apply triazole fungicides, improve air circulation",
            "potato_early_blight": "Use preventive fungicide spray, improve drainage",
            "apple_scab": "Apply sulfur-based fungicide, prune for air circulation",
            "healthy": "Continue current care routine, monitor regularly",
        }
        # Without transformers nothing can be loaded; analyze_plant_health
        # then reports every model as unavailable.
        if MODELS_AVAILABLE:
            self.load_models()

    def load_models(self):
        """Instantiate the processor and model for every entry in ``model_configs``.

        A checkpoint that fails to load is reported on stdout and skipped, so
        the remaining models stay usable.
        """
        for model_key, model_name in self.model_configs.items():
            try:
                print(f"Loading {model_name}...")
                if "vit" in model_key:
                    processor = ViTImageProcessor.from_pretrained(model_name)
                    model = ViTForImageClassification.from_pretrained(model_name)
                else:
                    processor = AutoImageProcessor.from_pretrained(model_name)
                    model = AutoModelForImageClassification.from_pretrained(model_name)
                self.processors[model_key] = processor
                self.models[model_key] = model
                print(f"✅ {model_name} loaded successfully")
            except Exception as e:
                print(f"❌ Failed to load {model_name}: {e}")

    @staticmethod
    def _normalize_label(label: str) -> str:
        """Lowercase *label* and collapse each run of non-alphanumerics to "_".

        ``"Tomato___Late_blight"`` and ``"Tomato Late Blight"`` both become
        ``"tomato_late_blight"``, the key format used by ``self.treatments``.
        """
        spaced = "".join(ch if ch.isalnum() else " " for ch in label.lower())
        return "_".join(spaced.split())

    @classmethod
    def _is_healthy(cls, label: str) -> bool:
        """Return True when "healthy" is one of the words of *label*."""
        return "healthy" in cls._normalize_label(label).split("_")

    def _treatment_for(self, label: str) -> str:
        """Look up the treatment text for a model label.

        Exact normalized matches win; otherwise any label containing the word
        "healthy" maps to the ``"healthy"`` entry; otherwise the default text.
        """
        key = self._normalize_label(label)
        if key in self.treatments:
            return self.treatments[key]
        if self._is_healthy(label) and "healthy" in self.treatments:
            return self.treatments["healthy"]
        return self.DEFAULT_TREATMENT

    def analyze_plant_health(self, image: "Image.Image", model_key: str = "plant_disease") -> Dict[str, Any]:
        """Classify *image* with the model registered under *model_key*.

        Args:
            image: Plant photo. Images whose ``mode`` is not ``"RGB"``
                (e.g. RGBA, greyscale, palette) are converted to RGB first.
            model_key: Key into ``self.model_configs``.

        Returns:
            On success a dict with ``health_score``, ``primary_disease``
            (``name``, ``confidence``, ``severity``), ``diseases_detected``
            (at most 3 entries), ``recommendations``, ``processing_time``
            (seconds) and ``model_used``. On any failure ``{"error": message}``;
            exceptions are not propagated.
        """
        if not MODELS_AVAILABLE or model_key not in self.models:
            return {"error": "Model not available"}
        start_time = time.perf_counter()
        try:
            processor = self.processors[model_key]
            model = self.models[model_key]
            # Uploads may carry an alpha channel or be greyscale/palette images;
            # hand the processor plain three-channel RGB.
            if getattr(image, "mode", "RGB") != "RGB":
                image = image.convert("RGB")
            inputs = processor(images=image, return_tensors="pt")

            with torch.no_grad():
                outputs = model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            # topk raises if k exceeds the number of labels, so clamp it.
            k = min(self.TOP_K, probabilities.shape[-1])
            top = torch.topk(probabilities, k)

            diseases_detected = []
            # topk returns scores in descending order; [0] selects the single input image.
            for rank, (score, idx) in enumerate(zip(top.values[0], top.indices[0])):
                confidence = float(score)
                # Keep the top-1 label unconditionally so a primary result always
                # exists; later labels must clear the relevance threshold.
                if rank > 0 and confidence <= self.RELEVANCE_THRESHOLD:
                    break
                disease_name = model.config.id2label[idx.item()]
                diseases_detected.append({
                    "disease": disease_name,
                    "confidence": confidence,
                    "treatment": self._treatment_for(disease_name),
                })

            primary_disease = diseases_detected[0]
            if self._is_healthy(primary_disease["disease"]):
                # Confident "healthy" prediction -> high health, no disease severity.
                health_score = primary_disease["confidence"]
                severity = "none"
            else:
                # Probability mass not assigned to the top disease label.
                health_score = 1.0 - primary_disease["confidence"]
                severity = self.get_severity(primary_disease["confidence"])

            recommendations = self.generate_recommendations(diseases_detected, health_score)
            processing_time = time.perf_counter() - start_time
            return {
                "health_score": round(float(health_score), 2),
                "primary_disease": {
                    "name": primary_disease["disease"],
                    "confidence": round(primary_disease["confidence"], 2),
                    "severity": severity,
                },
                "diseases_detected": diseases_detected[:3],  # Top 3
                "recommendations": recommendations,
                "processing_time": round(processing_time, 2),
                "model_used": model_key,
            }
        except Exception as e:
            # UI boundary: report the failure to the caller instead of raising.
            return {"error": str(e)}

    def get_severity(self, confidence: float) -> str:
        """Map a disease confidence to "severe" / "moderate" / "mild" / "minimal".

        Thresholds are strict: >0.8 severe, >0.5 moderate, >0.3 mild.
        """
        if confidence > 0.8:
            return "severe"
        elif confidence > 0.5:
            return "moderate"
        elif confidence > 0.3:
            return "mild"
        else:
            return "minimal"

    def generate_recommendations(self, diseases: List[Dict], health_score: float) -> List[str]:
        """Build up to five recommendations from the health score and top labels.

        Args:
            diseases: Ranked prediction dicts with ``confidence`` and ``treatment``.
            health_score: Value in [0, 1]; higher means healthier.

        Returns:
            Three generic recommendations for the health band, followed by the
            treatments of the top two labels whose confidence exceeds 0.3.
            Duplicate entries are skipped, and the list is truncated to five.
        """
        if health_score > 0.8:
            recommendations = [
                "Plant appears healthy - continue current care",
                "Monitor regularly for early disease signs",
                "Maintain proper watering and nutrition",
            ]
        elif health_score > 0.5:
            recommendations = [
                "Early intervention recommended",
                "Improve growing conditions",
                "Consider preventive treatments",
            ]
        else:
            recommendations = [
                "Immediate treatment required",
                "Isolate affected plants if possible",
                "Consult agricultural specialist",
            ]
        # Add specific disease treatments, skipping repeats (e.g. two labels
        # that both fall back to the default treatment text).
        for disease in diseases[:2]:
            treatment = disease["treatment"]
            if disease["confidence"] > 0.3 and treatment not in recommendations:
                recommendations.append(treatment)
        return recommendations[:5]  # Limit to 5 recommendations
# Initialize API: a single shared analyser instance serves every Gradio request.
# Its constructor loads the configured models when transformers imported successfully.
api = DiseaseDetectionAPI()
def format_health_report(results: Dict[str, Any]) -> str:
    """Render a successful ``analyze_plant_health`` result as report text.

    The text uses Markdown bold markers and a traffic-light emoji for the
    health score (green above 0.7, yellow above 0.4, red otherwise).

    Args:
        results: Dict with ``health_score``, ``primary_disease`` (``name``,
            ``confidence``, ``severity``), ``diseases_detected`` (entries with
            ``disease`` and ``confidence``) and ``recommendations``.

    Returns:
        The report: a leading newline, a summary section, the numbered
        detected issues, then the numbered recommendations.
    """
    health_score = results["health_score"]
    primary_disease = results["primary_disease"]
    if health_score > 0.7:
        health_color = "🟢"
    elif health_score > 0.4:
        health_color = "🟡"
    else:
        health_color = "🔴"

    lines = [
        "",
        "🩺 **Plant Health Analysis**",
        f"{health_color} **Health Score**: {health_score:.1%}",
        f"🦠 **Primary Issue**: {primary_disease['name']} ({primary_disease['confidence']:.1%} confidence)",
        f"⚠️ **Severity**: {primary_disease['severity'].title()}",
        "**🔬 Detected Issues**:",
        "",
    ]
    lines.extend(
        f"{i}. **{disease['disease']}** ({disease['confidence']:.1%})"
        for i, disease in enumerate(results["diseases_detected"], 1)
    )
    lines.extend(["", "**💡 Recommendations**:"])
    lines.extend(f"{i}. {rec}" for i, rec in enumerate(results["recommendations"], 1))
    return "\n".join(lines)


def predict_disease(image, model_choice):
    """Gradio callback: analyse an uploaded plant image.

    Args:
        image: Uploaded image as a PIL image or NumPy array, or ``None`` when
            nothing was uploaded.
        model_choice: Model key passed to ``api.analyze_plant_health``.

    Returns:
        ``(image, report_text)`` for the two output components. When there is
        no image or the analysis fails, the image slot is ``None`` and the text
        explains why.
    """
    if image is None:
        return None, "Please upload an image"
    # Array inputs are converted so the analyser always sees a PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    results = api.analyze_plant_health(image, model_choice)
    if "error" in results:
        return None, f"Error: {results['error']}"

    # Nothing is drawn on the image yet; the output panel shows a copy of the input.
    return image.copy(), format_health_report(results)
# Gradio Interface
with gr.Blocks(title="🩺 Farm Disease Detection API") as app:
    gr.Markdown("# 🩺 Farm Disease Detection API")
    gr.Markdown("AI-powered plant disease detection and health assessment")

    with gr.Tab("🌱 Plant Analysis"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Plant Image")
                model_choice = gr.Dropdown(
                    choices=["plant_disease", "vit_base_224", "vit_base_384"],
                    value="plant_disease",
                    label="Select Model",
                )
                analyze_btn = gr.Button("🔍 Analyze Plant Health", variant="primary")
            with gr.Column():
                output_image = gr.Image(label="Plant Image")
                results_text = gr.Textbox(label="Health Analysis", lines=15)
        analyze_btn.click(
            predict_disease,
            inputs=[image_input, model_choice],
            outputs=[output_image, results_text],
            # Pin the endpoint name advertised in the API documentation tab.
            api_name="predict",
        )

    with gr.Tab("📡 API Documentation"):
        # Example values below follow the code: severity comes from
        # DiseaseDetectionAPI.get_severity (0.92 -> "severe") and a disease
        # prediction yields health_score = 1 - confidence.
        gr.Markdown("""
## 🔗 API Endpoint

**POST** `/api/predict` (endpoint name: `predict`)

### Request Format

```json
{
  "data": ["<base64_image>", "<model_choice>"]
}
```

### Model Options

- **plant_disease**: Specialized plant disease model (recommended)
- **vit_base_224**: Fast Vision Transformer
- **vit_base_384**: High resolution Vision Transformer

### Response Format

The response `data` array holds the two interface outputs, in order:

```json
{
  "data": ["<plant_image>", "<health_report_text>"]
}
```

The report text is rendered from this structured analysis result
(example for a diseased plant):

```json
{
  "health_score": 0.08,
  "primary_disease": {
    "name": "corn_blight",
    "confidence": 0.92,
    "severity": "severe"
  },
  "diseases_detected": [...],
  "recommendations": [...],
  "processing_time": 1.2,
  "model_used": "plant_disease"
}
```
""")

if __name__ == "__main__":
    app.launch()