File size: 10,096 Bytes
1932a55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
"""

Farm Disease Detection API - Gradio Interface

Vision Transformer (ViT) and specialized image-classification models for plant disease detection and health assessment

"""

import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
import json
import base64
import io
import time
from typing import List, Dict, Any

# Import models
try:
    from transformers import ViTImageProcessor, ViTForImageClassification
    from transformers import AutoImageProcessor, AutoModelForImageClassification
    MODELS_AVAILABLE = True
except ImportError:
    MODELS_AVAILABLE = False

class DiseaseDetectionAPI:
    """Plant-health analysis backed by Hugging Face image classifiers.

    On construction every checkpoint in ``model_configs`` is loaded (only when
    the ``transformers`` import succeeded). ``analyze_plant_health`` then runs
    one of the loaded models on an image and turns its class probabilities
    into a health score, a severity label and care recommendations.
    """

    def __init__(self):
        # Loaded processors / models, keyed by the short names in model_configs.
        self.models = {}
        self.processors = {}
        # Short model key -> Hugging Face Hub checkpoint name.
        # NOTE(review): the two ViT checkpoints look like generic image
        # classifiers rather than plant-disease models, so their labels are
        # probably not disease names — confirm before relying on them.
        self.model_configs = {
            "vit_base_224": "google/vit-base-patch16-224",
            "vit_base_384": "google/vit-base-patch16-384",
            "plant_disease": "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification"
        }

        # Treatment advice keyed by normalized label (lower_snake_case);
        # see _treatment_for for how model labels are normalized.
        self.treatments = {
            "corn_blight": "Apply fungicide containing azoxystrobin or propiconazole",
            "tomato_late_blight": "Remove affected leaves, apply copper-based fungicide",
            "wheat_rust": "Apply triazole fungicides, improve air circulation",
            "potato_early_blight": "Use preventive fungicide spray, improve drainage",
            "apple_scab": "Apply sulfur-based fungicide, prune for air circulation",
            "healthy": "Continue current care routine, monitor regularly"
        }

        if MODELS_AVAILABLE:
            self.load_models()

    def load_models(self):
        """Load every checkpoint listed in ``model_configs``.

        Keys containing "vit" use the ViT-specific classes; all others use the
        ``Auto*`` classes. A failure is printed and skipped so that one
        unavailable checkpoint does not prevent the others from loading.
        """
        for model_key, model_name in self.model_configs.items():
            try:
                print(f"Loading {model_name}...")
                if "vit" in model_key:
                    processor = ViTImageProcessor.from_pretrained(model_name)
                    model = ViTForImageClassification.from_pretrained(model_name)
                else:
                    processor = AutoImageProcessor.from_pretrained(model_name)
                    model = AutoModelForImageClassification.from_pretrained(model_name)

                self.processors[model_key] = processor
                self.models[model_key] = model
                print(f"✅ {model_name} loaded successfully")
            except Exception as e:
                # Best-effort loading: report and continue with the next model.
                print(f"❌ Failed to load {model_name}: {e}")

    def _treatment_for(self, label: str) -> str:
        """Return treatment advice for a model label, or a generic fallback.

        The label is normalized to lower_snake_case ("Tomato Late Blight" ->
        "tomato_late_blight") before lookup. Labels containing "healthy" fall
        back to the "healthy" advice, mirroring how the health score is computed.
        """
        key = label.strip().lower().replace(" ", "_").replace("-", "_")
        if key in self.treatments:
            return self.treatments[key]
        if "healthy" in key and "healthy" in self.treatments:
            return self.treatments["healthy"]
        return "Consult agricultural expert"

    def analyze_plant_health(self, image: "Image.Image", model_key: str = "plant_disease") -> Dict[str, Any]:
        """Classify *image* with model ``model_key`` and assess plant health.

        Args:
            image: Plant photo, normally a PIL image. PIL images in a mode other
                than RGB are converted to RGB before preprocessing.
            model_key: Key into ``model_configs`` selecting a loaded model.

        Returns:
            On success, a dict with ``health_score``, ``primary_disease``
            (name / confidence / severity), ``diseases_detected`` (up to three
            predictions above the relevance threshold, or just the best guess
            when none qualifies), ``recommendations``, ``processing_time``
            (seconds) and ``model_used``. On any failure, ``{"error": message}``;
            exceptions are not propagated to the caller.
        """
        # self.models is only filled by load_models(), which __init__ calls only
        # when transformers imported, so this also covers "no transformers".
        if model_key not in self.models:
            return {"error": "Model not available"}

        start_time = time.time()

        try:
            processor = self.processors[model_key]
            model = self.models[model_key]

            # Normalize RGBA / greyscale / palette uploads to 3 channels, which
            # the image processors presumably expect.
            if hasattr(image, "convert") and getattr(image, "mode", "RGB") != "RGB":
                image = image.convert("RGB")

            inputs = processor(images=image, return_tensors="pt")

            with torch.no_grad():
                outputs = model(**inputs)

            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            # torch.topk raises if k exceeds the number of classes.
            k = min(5, probabilities.shape[-1])
            top = torch.topk(probabilities, k)

            # Ranked (descending confidence) predictions for the first image.
            ranked = []
            for score, idx in zip(top.values[0], top.indices[0]):
                label = model.config.id2label[idx.item()]
                ranked.append({
                    "disease": label,
                    "confidence": float(score),
                    "treatment": self._treatment_for(label)
                })

            # Keep predictions above the relevance threshold; if none qualifies,
            # still report the single best guess instead of failing on an empty list.
            diseases_detected = [d for d in ranked if d["confidence"] > 0.1] or ranked[:1]
            primary_disease = ranked[0]

            # Health score: confidence of a "healthy" class, otherwise the
            # complement of the top disease's confidence.
            if "healthy" in primary_disease["disease"].lower():
                health_score = primary_disease["confidence"]
            else:
                health_score = 1.0 - primary_disease["confidence"]

            recommendations = self.generate_recommendations(diseases_detected, health_score)

            processing_time = time.time() - start_time

            return {
                "health_score": round(float(health_score), 2),
                "primary_disease": {
                    "name": primary_disease["disease"],
                    "confidence": round(primary_disease["confidence"], 2),
                    "severity": self.get_severity(primary_disease["confidence"])
                },
                "diseases_detected": diseases_detected[:3],  # Top 3
                "recommendations": recommendations,
                "processing_time": round(processing_time, 2),
                "model_used": model_key
            }

        except Exception as e:
            # API boundary: surface the failure as data rather than raising.
            return {"error": str(e)}

    def get_severity(self, confidence: float) -> str:
        """Map a prediction confidence to a severity label.

        > 0.8 -> "severe", > 0.5 -> "moderate", > 0.3 -> "mild", else "minimal".
        """
        if confidence > 0.8:
            return "severe"
        elif confidence > 0.5:
            return "moderate"
        elif confidence > 0.3:
            return "mild"
        else:
            return "minimal"

    def generate_recommendations(self, diseases: List[Dict], health_score: float) -> List[str]:
        """Build at most five care recommendations.

        Three generic messages are chosen by health-score band (> 0.8, > 0.5,
        otherwise). They are followed by the ``treatment`` text of each of the
        first two entries in *diseases* whose confidence exceeds 0.3.
        """
        recommendations = []

        if health_score > 0.8:
            recommendations.extend([
                "Plant appears healthy - continue current care",
                "Monitor regularly for early disease signs",
                "Maintain proper watering and nutrition"
            ])
        elif health_score > 0.5:
            recommendations.extend([
                "Early intervention recommended",
                "Improve growing conditions",
                "Consider preventive treatments"
            ])
        else:
            recommendations.extend([
                "Immediate treatment required",
                "Isolate affected plants if possible",
                "Consult agricultural specialist"
            ])

        # Add specific disease treatments
        for disease in diseases[:2]:
            if disease["confidence"] > 0.3:
                recommendations.append(disease["treatment"])

        return recommendations[:5]  # Limit to 5 recommendations

# Initialize API
# Module-level singleton used by predict_disease for every Gradio request.
# Constructing it loads every configured model (when transformers imported),
# so model loading happens at import time of this module.
api = DiseaseDetectionAPI()

def _format_analysis_text(results: Dict[str, Any]) -> str:
    """Render a successful analyze_plant_health result as a readable report."""
    health_score = results["health_score"]
    primary_disease = results["primary_disease"]

    # Traffic-light indicator: green above 0.7, yellow above 0.4, red otherwise.
    if health_score > 0.7:
        health_color = "🟢"
    elif health_score > 0.4:
        health_color = "🟡"
    else:
        health_color = "🔴"

    lines = [
        "🩺 **Plant Health Analysis**",
        "",
        f"{health_color} **Health Score**: {health_score:.1%}",
        f"🦠 **Primary Issue**: {primary_disease['name']} ({primary_disease['confidence']:.1%} confidence)",
        f"⚠️ **Severity**: {primary_disease['severity'].title()}",
        "",
        "**🔬 Detected Issues**:",
    ]
    lines.extend(
        f"{i}. **{disease['disease']}** ({disease['confidence']:.1%})"
        for i, disease in enumerate(results["diseases_detected"], 1)
    )
    lines.append("")
    lines.append("**💡 Recommendations**:")
    lines.extend(f"{i}. {rec}" for i, rec in enumerate(results["recommendations"], 1))
    return "\n".join(lines)


def predict_disease(image, model_choice):
    """Gradio callback: analyse an uploaded plant image.

    Args:
        image: The uploaded image (PIL image; a NumPy array is converted to
            PIL), or None when nothing was uploaded.
        model_choice: Model key selected in the "Select Model" dropdown.

    Returns:
        A ``(image, report_text)`` pair for the two Gradio outputs. The image
        is None when no analysis could be produced; ``report_text`` then holds
        the prompt or error message.
    """
    if image is None:
        return None, "Please upload an image"

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    results = api.analyze_plant_health(image, model_choice)

    if "error" in results:
        return None, f"Error: {results['error']}"

    # Nothing is drawn yet; a copy is returned so the displayed image is a
    # separate object from the input.
    annotated_image = image.copy()

    return annotated_image, _format_analysis_text(results)

# Gradio Interface: an analysis tab wired to predict_disease plus static API docs.
with gr.Blocks(title="🩺 Farm Disease Detection API") as app:
    gr.Markdown("# 🩺 Farm Disease Detection API")
    gr.Markdown("AI-powered plant disease detection and health assessment")

    with gr.Tab("🌱 Plant Analysis"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Plant Image")
                model_choice = gr.Dropdown(
                    choices=["plant_disease", "vit_base_224", "vit_base_384"],
                    value="plant_disease",
                    label="Select Model"
                )
                analyze_btn = gr.Button("🔍 Analyze Plant Health", variant="primary")

            with gr.Column():
                output_image = gr.Image(label="Plant Image")
                results_text = gr.Textbox(label="Health Analysis", lines=15)

        # api_name pins the endpoint name to "predict" so it matches the
        # documented route instead of relying on Gradio's version-dependent default.
        analyze_btn.click(
            predict_disease,
            inputs=[image_input, model_choice],
            outputs=[output_image, results_text],
            api_name="predict"
        )

    with gr.Tab("📡 API Documentation"):
        gr.Markdown("""
## 🚀 API Endpoint

**POST** `/api/predict`

### Request Format

```json
{
  "data": ["<base64_image>", "<model_choice>"]
}
```

### Model Options

- **plant_disease**: Specialized plant disease model (recommended)
- **vit_base_224**: Fast Vision Transformer
- **vit_base_384**: High resolution Vision Transformer

### Response Format

The `data` array follows the output order: the plant image and the
health-analysis text.

```json
{
  "data": ["<plant_image>", "<health_analysis_text>"]
}
```

The analysis text is rendered from a result with this structure:

```json
{
  "health_score": 0.08,
  "primary_disease": {
    "name": "corn_blight",
    "confidence": 0.92,
    "severity": "severe"
  },
  "diseases_detected": [...],
  "recommendations": [...],
  "processing_time": 1.2,
  "model_used": "plant_disease"
}
```
        """)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not when imported.
    app.launch()