# Scraped from a Hugging Face Space file view (status: Sleeping; file size: 10,096 bytes; revision 1932a55).
"""
Farm Disease Detection API - Gradio Interface
ViT and specialized models for plant disease detection
"""
import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
import json
import base64
import io
import time
from typing import List, Dict, Any
# Import models
# transformers is treated as optional: when it cannot be imported the module
# still loads and MODELS_AVAILABLE records that no model classes are usable.
try:
    from transformers import ViTImageProcessor, ViTForImageClassification
    from transformers import AutoImageProcessor, AutoModelForImageClassification
    MODELS_AVAILABLE = True
except ImportError:
    MODELS_AVAILABLE = False
class DiseaseDetectionAPI:
    """Classify plant photos with Hugging Face image models and suggest treatments.

    ``models`` and ``processors`` are dicts keyed by the short names in
    ``model_configs``; they contain only the entries ``load_models`` managed
    to load. ``treatments`` maps disease keys to advice strings.
    """

    # Predictions at or below this probability are treated as noise.
    RELEVANCE_THRESHOLD = 0.1
    # Number of top-ranked classes inspected per image.
    TOP_K = 5
    # Advice used when a label has no entry in the treatment table.
    DEFAULT_TREATMENT = "Consult agricultural expert"

    def __init__(self):
        """Set up the model/treatment tables and load models when possible."""
        self.models = {}
        self.processors = {}
        self.model_configs = {
            "vit_base_224": "google/vit-base-patch16-224",
            "vit_base_384": "google/vit-base-patch16-384",
            "plant_disease": "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification",
        }
        # Disease treatment database
        self.treatments = {
            "corn_blight": "Apply fungicide containing azoxystrobin or propiconazole",
            "tomato_late_blight": "Remove affected leaves, apply copper-based fungicide",
            "wheat_rust": "Apply triazole fungicides, improve air circulation",
            "potato_early_blight": "Use preventive fungicide spray, improve drainage",
            "apple_scab": "Apply sulfur-based fungicide, prune for air circulation",
            "healthy": "Continue current care routine, monitor regularly",
        }
        if MODELS_AVAILABLE:
            self.load_models()

    def load_models(self):
        """Load every model in ``model_configs`` into ``models``/``processors``.

        Keys containing ``"vit"`` use the ViT classes; everything else goes
        through the ``Auto*`` classes. Failures are printed and skipped.
        """
        for model_key, model_name in self.model_configs.items():
            try:
                print(f"Loading {model_name}...")
                if "vit" in model_key:
                    processor = ViTImageProcessor.from_pretrained(model_name)
                    model = ViTForImageClassification.from_pretrained(model_name)
                else:
                    processor = AutoImageProcessor.from_pretrained(model_name)
                    model = AutoModelForImageClassification.from_pretrained(model_name)
                self.processors[model_key] = processor
                self.models[model_key] = model
                print(f"✅ {model_name} loaded successfully")
            except Exception as e:
                # Best effort: one model failing to load must not stop the others.
                print(f"❌ Failed to load {model_name}: {e}")

    def analyze_plant_health(self, image: "Image.Image", model_key: str = "plant_disease") -> Dict[str, Any]:
        """Analyze plant health and detect diseases.

        Args:
            image: PIL image of the plant; converted to RGB before preprocessing.
            model_key: Key into ``self.models`` selecting the classifier.

        Returns:
            On success a dict with ``health_score``, ``primary_disease``
            (``name``, ``confidence``, ``severity``), ``diseases_detected``
            (at most three), ``recommendations``, ``processing_time`` and
            ``model_used``. On any failure ``{"error": <message>}``.
        """
        # ``models`` is only populated by load_models, so this single check also
        # covers the case where the model libraries could not be imported.
        if model_key not in self.models:
            return {"error": "Model not available"}
        start_time = time.time()
        try:
            # Preprocess image
            processor = self.processors[model_key]
            model = self.models[model_key]
            # RGBA / greyscale uploads would otherwise reach the processor with
            # a channel count the 3-channel normalisation does not expect.
            if image.mode != "RGB":
                image = image.convert("RGB")
            inputs = processor(images=image, return_tensors="pt")
            # Run inference
            with torch.no_grad():
                outputs = model(**inputs)
            # Get predictions
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            # topk raises when k exceeds the number of classes.
            k = min(self.TOP_K, probabilities.shape[-1])
            top_predictions = torch.topk(probabilities, k)
            id2label = model.config.id2label
            scores = top_predictions.values[0].tolist()
            labels = [id2label[index] for index in top_predictions.indices[0].tolist()]
            # Format results
            diseases_detected = self._build_detections(scores, labels)
            if not diseases_detected:
                raise ValueError("Model returned no predictions")
            primary_disease = diseases_detected[0]
            health_score, severity = self._assess_primary(primary_disease)
            # Generate recommendations
            recommendations = self.generate_recommendations(diseases_detected, health_score)
            processing_time = time.time() - start_time
            return {
                "health_score": round(float(health_score), 2),
                "primary_disease": {
                    "name": primary_disease["disease"],
                    "confidence": round(primary_disease["confidence"], 2),
                    "severity": severity,
                },
                "diseases_detected": diseases_detected[:3],  # Top 3
                "recommendations": recommendations,
                "processing_time": round(processing_time, 2),
                "model_used": model_key,
            }
        except Exception as e:
            # Callers expect an error dict rather than an exception.
            return {"error": str(e)}

    def _lookup_treatment(self, label: str) -> str:
        """Return the treatment advice for a model label.

        Tries the lower-cased label first, then a form with runs of spaces,
        hyphens and underscores collapsed to single underscores, then the
        ``"healthy"`` entry for any label mentioning "healthy".
        """
        lowered = label.lower()
        if lowered in self.treatments:
            return self.treatments[lowered]
        normalized = "_".join(lowered.replace("-", " ").replace("_", " ").split())
        if normalized in self.treatments:
            return self.treatments[normalized]
        if "healthy" in lowered:
            return self.treatments.get("healthy", self.DEFAULT_TREATMENT)
        return self.DEFAULT_TREATMENT

    def _build_detections(self, scores: List[float], labels: List[str]) -> List[Dict[str, Any]]:
        """Turn ranked (score, label) pairs into detection dicts.

        Entries above ``RELEVANCE_THRESHOLD`` are kept; when none qualifies the
        single best prediction is returned so callers always have a primary
        result. ``scores`` must be in descending order.
        """
        ranked = [
            {
                "disease": label,
                "confidence": float(score),
                "treatment": self._lookup_treatment(label),
            }
            for score, label in zip(scores, labels)
        ]
        relevant = [entry for entry in ranked if entry["confidence"] > self.RELEVANCE_THRESHOLD]
        return relevant or ranked[:1]

    def _assess_primary(self, primary: Dict[str, Any]):
        """Return ``(health_score, severity)`` for the top detection.

        A "healthy" label scores its own confidence and is never reported as
        more than "minimal" severity; any other label scores ``1 - confidence``
        and takes its severity from ``get_severity``.
        """
        confidence = primary["confidence"]
        if "healthy" in primary["disease"].lower():
            return confidence, "minimal"
        return 1.0 - confidence, self.get_severity(confidence)

    def get_severity(self, confidence: float) -> str:
        """Determine disease severity based on confidence.

        Returns "severe" above 0.8, "moderate" above 0.5, "mild" above 0.3
        and "minimal" otherwise.
        """
        if confidence > 0.8:
            return "severe"
        elif confidence > 0.5:
            return "moderate"
        elif confidence > 0.3:
            return "mild"
        else:
            return "minimal"

    def generate_recommendations(self, diseases: List[Dict], health_score: float) -> List[str]:
        """Generate treatment recommendations.

        Starts with three general tips chosen by ``health_score`` (> 0.8,
        > 0.5, otherwise), then appends the treatment of each of the first two
        detections whose confidence exceeds 0.3, skipping duplicates. At most
        five recommendations are returned.
        """
        recommendations = []
        if health_score > 0.8:
            recommendations.extend([
                "Plant appears healthy - continue current care",
                "Monitor regularly for early disease signs",
                "Maintain proper watering and nutrition",
            ])
        elif health_score > 0.5:
            recommendations.extend([
                "Early intervention recommended",
                "Improve growing conditions",
                "Consider preventive treatments",
            ])
        else:
            recommendations.extend([
                "Immediate treatment required",
                "Isolate affected plants if possible",
                "Consult agricultural specialist",
            ])
        # Add specific disease treatments
        for disease in diseases[:2]:
            treatment = disease["treatment"]
            # Two detections often share the same (default) advice; list it once.
            if disease["confidence"] > 0.3 and treatment not in recommendations:
                recommendations.append(treatment)
        return recommendations[:5]  # Limit to 5 recommendations
# Initialize API: module-level instance used by predict_disease below.
# Constructing it calls load_models() when MODELS_AVAILABLE is true, so the
# configured models are loaded at import time.
api = DiseaseDetectionAPI()
def predict_disease(image, model_choice):
    """Gradio prediction function.

    Args:
        image: Uploaded picture as a PIL image or a numpy array, or ``None``
            when nothing was uploaded.
        model_choice: Model key passed to ``api.analyze_plant_health``.

    Returns:
        ``(image_copy, report_text)`` on success, or ``(None, message)`` when
        no image was supplied or the analysis returned an error.
    """
    if image is None:
        return None, "Please upload an image"
    # Convert to PIL Image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Run analysis
    results = api.analyze_plant_health(image, model_choice)
    if "error" in results:
        return None, f"Error: {results['error']}"
    # Create visualization (currently an unmodified copy of the input)
    annotated_image = image.copy()
    # Format results text
    health_score = results["health_score"]
    primary_disease = results["primary_disease"]
    if health_score > 0.7:
        health_color = "🟢"
    elif health_score > 0.4:
        health_color = "🟡"
    else:
        health_color = "🔴"
    header = "\n".join([
        "",
        "🩺 **Plant Health Analysis**",
        f"{health_color} **Health Score**: {health_score:.1%}",
        f"🦠 **Primary Issue**: {primary_disease['name']} ({primary_disease['confidence']:.1%} confidence)",
        f"⚠️ **Severity**: {primary_disease['severity'].title()}",
        "**🔬 Detected Issues**:",
        "",
    ])
    parts = [header]
    parts.extend(
        f"\n{i}. **{disease['disease']}** ({disease['confidence']:.1%})"
        for i, disease in enumerate(results["diseases_detected"], 1)
    )
    parts.append("\n\n**💡 Recommendations**:")
    parts.extend(
        f"\n{i}. {rec}"
        for i, rec in enumerate(results["recommendations"], 1)
    )
    results_text = "".join(parts)
    return annotated_image, results_text
# Gradio Interface
# Two tabs: an interactive analysis form wired to predict_disease, and a
# static Markdown page describing the HTTP API.
with gr.Blocks(title="🩺 Farm Disease Detection API") as app:
    gr.Markdown("# 🩺 Farm Disease Detection API")
    gr.Markdown("AI-powered plant disease detection and health assessment")
    with gr.Tab("🌱 Plant Analysis"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Plant Image")
                model_choice = gr.Dropdown(
                    choices=["plant_disease", "vit_base_224", "vit_base_384"],
                    value="plant_disease",
                    label="Select Model",
                )
                analyze_btn = gr.Button("🔍 Analyze Plant Health", variant="primary")
            with gr.Column():
                output_image = gr.Image(label="Plant Image")
                results_text = gr.Textbox(label="Health Analysis", lines=15)
        analyze_btn.click(
            predict_disease,
            inputs=[image_input, model_choice],
            outputs=[output_image, results_text],
        )
    with gr.Tab("📡 API Documentation"):
        # NOTE(review): predict_disease returns (image, text); the JSON below
        # mirrors the dict built by analyze_plant_health - confirm what the
        # endpoint actually returns before relying on this page.
        # The Markdown is kept at column 0 so no line is read as an indented
        # code block.
        gr.Markdown("""
## 🔗 API Endpoint
**POST** `/api/predict`
### Request Format
```json
{
  "data": ["<base64_image>", "<model_choice>"]
}
```
### Model Options
- **plant_disease**: Specialized plant disease model (recommended)
- **vit_base_224**: Fast Vision Transformer
- **vit_base_384**: High resolution Vision Transformer
### Response Format
```json
{
  "health_score": 0.08,
  "primary_disease": {
    "name": "corn_blight",
    "confidence": 0.92,
    "severity": "severe"
  },
  "diseases_detected": [...],
  "recommendations": [...],
  "processing_time": 1.2,
  "model_used": "plant_disease"
}
```
""")
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    app.launch()