HMWCS committed on
Commit
702c64f
·
verified ·
1 Parent(s): ea9a3cb

feat: enhance mixed garbage detection and food residue assessment

Browse files
test_images/classifier.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoProcessor, AutoModelForImageTextToText
2
+ from PIL import Image
3
+ import torch
4
+ import logging
5
+ from typing import Union, Tuple
6
+ from config import Config
7
+ from knowledge_base import GarbageClassificationKnowledge
8
+ import re
9
+
10
+
11
def preprocess_image(
    image: Image.Image, target_size: Tuple[int, int] = (512, 512)
) -> Image.Image:
    """
    Letterbox an image onto a white canvas of ``target_size``.

    The image is converted to RGB, scaled (aspect ratio preserved) so that it
    fits inside ``target_size``, and pasted centered on a white background.
    The default of 512x512 is the input size this project feeds to Gemma3n.

    Args:
        image: Source PIL image in any mode.
        target_size: ``(width, height)`` of the returned canvas.

    Returns:
        A new RGB image of exactly ``target_size``.

    Raises:
        ValueError: If the source image or ``target_size`` has a zero or
            negative dimension.
    """
    target_width, target_height = target_size
    if target_width <= 0 or target_height <= 0:
        raise ValueError(f"target_size must be positive, got {target_size!r}")

    original_width, original_height = image.size
    if original_width <= 0 or original_height <= 0:
        # Previously surfaced as a bare ZeroDivisionError from the ratio math.
        raise ValueError(f"Cannot preprocess an empty image of size {image.size!r}")

    # Convert to RGB if necessary
    if image.mode != "RGB":
        image = image.convert("RGB")

    # A single scale factor that makes the image fit in both directions; this
    # is also correct when the target canvas is not square.
    scale = min(target_width / original_width, target_height / original_height)

    # Clamp to at least one pixel: very elongated inputs (e.g. 2000x1) used to
    # truncate to a 0-pixel side, which makes resize() fail.
    new_width = max(1, min(target_width, round(original_width * scale)))
    new_height = max(1, min(target_height, round(original_height * scale)))

    # Resize image maintaining aspect ratio
    resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # White canvas of the target size with the resized image centered on it
    processed_image = Image.new("RGB", (target_width, target_height), (255, 255, 255))
    x_offset = (target_width - new_width) // 2
    y_offset = (target_height - new_height) // 2
    processed_image.paste(resized, (x_offset, y_offset))

    return processed_image
50
+
51
+
52
class GarbageClassifier:
    """Classify waste photos with an image-text-to-text model.

    The model is prompted (see ``GarbageClassificationKnowledge``) to answer
    with three labelled fields -- ``Classification``, ``Reasoning`` and
    ``Confidence Score``.  The private ``_extract_*`` helpers turn that free
    text reply into a ``(category, reasoning, confidence)`` tuple and fall back
    to keyword heuristics when the reply does not follow the requested format.
    """

    # A field label such as "**Classification**:", "**Reasoning:**" or a plain
    # "Confidence Score:" at the start of a line.  Bold markers are tolerated on
    # either side of the colon because models emit both variants.
    _FIELD_LABEL_RE = re.compile(
        r"(?:^[^\w\n]*|\*\*[ \t]*)"
        r"(classification|reasoning|confidence(?:[ \t]+score)?)"
        r"[ \t]*\**[ \t]*:[ \t]*\**[ \t]*",
        re.IGNORECASE | re.MULTILINE,
    )

    # Negation words looked for immediately before a category mention.
    _NEGATION_RE = re.compile(r"\b(?:not|cannot|isn['\u2019]t)\b")

    # Integer or decimal number, e.g. "8" or "8.5".
    _NUMBER = r"(\d+(?:\.\d+)?)"

    # Tried in order against the lower-cased reply.  Group 1 is the number,
    # group 2 is "%" when the value was given as a percentage.
    _CONFIDENCE_PATTERNS = (
        re.compile(r"confidence(?:\s+score)?[\s:*]*" + _NUMBER + r"\s*(%)?"),
        re.compile(r"confident[\s:*]*" + _NUMBER + r"\s*(%)?"),
        re.compile(r"certainty[\s:*]*" + _NUMBER + r"\s*(%)?"),
        re.compile(r"score[\s:*]*" + _NUMBER + r"\s*(%)?"),
        re.compile(_NUMBER + r"\s*/\s*10\b()"),
        re.compile(_NUMBER + r"\s*out\s*of\s*10\b()"),
    )

    def __init__(self, config: Config = None):
        """Create an unloaded classifier; call ``load_model()`` before use.

        Args:
            config: Project configuration; a default ``Config()`` is created
                when omitted.
        """
        self.config = config or Config()
        self.knowledge = GarbageClassificationKnowledge()
        self.processor = None
        self.model = None

        # Setup logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def load_model(self):
        """Load the processor and the model named by ``config.MODEL_NAME``.

        Raises:
            Exception: Whatever the underlying ``from_pretrained`` calls raise;
                the error is logged and re-raised unchanged.
        """
        try:
            self.logger.info(f"Loading model: {self.config.MODEL_NAME}")

            # The access token must reach BOTH downloads: a gated repository
            # rejects the weight download just like the processor download.
            auth_kwargs = {}
            if self.config.HF_TOKEN:
                auth_kwargs["token"] = self.config.HF_TOKEN

            # Load processor
            self.processor = AutoProcessor.from_pretrained(
                self.config.MODEL_NAME, **auth_kwargs
            )

            # Load model
            self.model = AutoModelForImageTextToText.from_pretrained(
                self.config.MODEL_NAME,
                torch_dtype=self.config.TORCH_DTYPE,
                device_map=self.config.DEVICE_MAP,
                **auth_kwargs,
            )

            self.logger.info("Model loaded successfully")

        except Exception as e:
            self.logger.error(f"Error loading model: {str(e)}")
            raise

    def classify_image(self, image: Union[str, Image.Image]) -> Tuple[str, str, int]:
        """
        Classify garbage in the image

        Args:
            image: PIL Image or path to image file

        Returns:
            Tuple of (classification_result, detailed_analysis, confidence_score).
            Any failure during classification is logged and reported as
            ``("Error", "Classification failed: <message>", 0)`` instead of
            being raised.

        Raises:
            RuntimeError: If ``load_model()`` has not been called yet.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        try:
            # Load and process image
            if isinstance(image, str):
                # preprocess_image returns a brand-new image, so the source
                # file can be closed as soon as preprocessing is done.
                with Image.open(image) as opened:
                    processed_image = preprocess_image(opened)
            elif isinstance(image, Image.Image):
                processed_image = preprocess_image(image)
            else:
                raise ValueError("Image must be a PIL Image or file path")

            # Prepare messages with system prompt and user query
            messages = [
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "text",
                            "text": self.knowledge.get_system_prompt(),
                        }
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": processed_image},
                        {
                            "type": "text",
                            "text": "Please classify what you see in this image. If it shows garbage/waste items, classify them according to the garbage classification standards. If it shows people, living things, or other non-waste items, classify it as 'Unable to classify' and explain why it's not garbage. Also provide a confidence score from 1-10 indicating how certain you are about your classification.",
                        },
                    ],
                },
            ]

            # Apply chat template and tokenize
            inputs = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.model.device, dtype=self.model.dtype)
            input_len = inputs["input_ids"].shape[-1]

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.config.MAX_NEW_TOKENS,
                disable_compile=True,
            )
            # Decode only the newly generated tokens, not the echoed prompt.
            response = self.processor.batch_decode(
                outputs[:, input_len:],
                skip_special_tokens=True,
            )[0]

            classification = self._extract_classification(response)
            reasoning = self._extract_reasoning(response)
            confidence_score = self._extract_confidence_score(response, classification)

            return classification, reasoning, confidence_score

        except Exception as e:
            # Deliberate best-effort boundary: callers get an "Error" result.
            # logger.exception records the traceback through the logging system.
            self.logger.exception(f"Error during classification: {str(e)}")
            return "Error", f"Classification failed: {str(e)}", 0

    def _split_fields(self, response: str) -> list:
        """Split a reply into ``(label, text)`` segments in document order.

        ``label`` is ``"classification"``, ``"reasoning"`` or ``"confidence"``;
        text that precedes the first label is returned with the label ``""``.
        Surrounding whitespace and stray bold markers are removed from ``text``.
        """
        segments = []
        label, cursor = "", 0
        for match in self._FIELD_LABEL_RE.finditer(response):
            segments.append((label, response[cursor:match.start()]))
            # "confidence score" and "confidence" share one key.
            label = match.group(1).lower().split()[0]
            cursor = match.end()
        segments.append((label, response[cursor:]))
        return [(name, text.strip().strip("*").strip()) for name, text in segments]

    def _find_category(self, text: str) -> Union[str, None]:
        """Return the category whose first non-negated mention comes earliest.

        A mention counts as negated when "not", "cannot" or "isn't" appears in
        the ten characters right before it.  Returns ``None`` when ``text``
        names no category.
        """
        lowered = text.lower()
        best_position, best_category = None, None
        for category in self.knowledge.get_categories():
            needle = category.lower()
            position = lowered.find(needle)
            while position != -1:
                preceding = lowered[max(0, position - 10):position]
                if not self._NEGATION_RE.search(preceding):
                    if best_position is None or position < best_position:
                        best_position, best_category = position, category
                    break
                # Negated mention: keep looking for a later, affirmative one.
                position = lowered.find(needle, position + 1)
        return best_category

    def _strip_leading_category(self, text: str) -> str:
        """Drop a category name (and trailing punctuation) that starts ``text``."""
        candidate = text.lstrip("\"'\u201c\u2018[(* \t\r\n")
        lowered = candidate.lower()
        for category in self.knowledge.get_categories():
            if lowered.startswith(category.lower()):
                return candidate[len(category):].lstrip("\"'\u201d\u2019])* \t\r\n.,;:-")
        return text

    def _calculate_confidence_heuristic(self, response_lower: str, classification: str) -> int:
        """Estimate a 1-10 confidence from wording when no explicit score exists.

        Args:
            response_lower: The model reply, already lower-cased.
            classification: The category chosen for the reply.

        Returns:
            A score clamped to the range 1-10, starting from a neutral 5.
        """

        def mentions(words) -> bool:
            return any(word in response_lower for word in words)

        confidence = 5

        # Each group contributes at most once, however many words match.
        if mentions(("clearly", "obviously", "definitely", "certainly", "exactly")):
            confidence += 2
        if mentions(("appears", "seems", "likely", "probably")):
            confidence += 1
        if mentions(("might", "could", "possibly", "maybe", "unclear", "difficult")):
            confidence -= 2

        if classification == "Unable to classify":
            # Confident when the reply clearly describes a non-waste subject,
            # less so when it simply could not decide.
            if mentions(("person", "people", "human", "living")):
                confidence += 1
            else:
                confidence -= 1
        elif classification == "Error":
            confidence = 1
        elif mentions(("aluminum", "plastic", "glass", "metal", "cardboard", "paper")):
            # Naming a concrete material suggests a grounded decision.
            confidence += 1

        return min(max(confidence, 1), 10)

    def _extract_confidence_score(self, response: str, classification: str) -> int:
        """Return the 1-10 confidence stated in the reply, or a heuristic one.

        Accepts "**Confidence Score**: 8" as well as "**Confidence Score:** 8",
        decimals ("7.4/10"), "N out of 10" and percentages ("70%" -> 7).
        Values are rounded half-up and clamped to 1-10.
        """
        response_lower = response.lower()

        for pattern in self._CONFIDENCE_PATTERNS:
            match = pattern.search(response_lower)
            if match:
                value = float(match.group(1))
                if match.group(2):  # percentage -> ten-point scale
                    value /= 10.0
                return min(max(int(value + 0.5), 1), 10)

        # If no explicit score found, calculate based on classification indicators
        return self._calculate_confidence_heuristic(response_lower, classification)

    def _extract_classification(self, response: str) -> str:
        """Return the waste category expressed by the model reply.

        Resolution order:
          1. the category named in the explicit ``Classification`` field;
          2. the earliest non-negated category name anywhere in the reply;
          3. keyword heuristics (mixed-waste warnings, materials, food, ...);
          4. ``"Unable to classify"``.
        """
        declared = next(
            (
                text
                for label, text in self._split_fields(response)
                if label == "classification" and text
            ),
            None,
        )
        if declared:
            # The declared field wins over category names that are merely
            # mentioned in the reasoning (e.g. "would be Recyclable Waste if
            # rinsed").
            category = self._find_category(declared)
            if category:
                return category

        category = self._find_category(response)
        if category:
            return category

        response_lower = response.lower()

        def mentions(words) -> bool:
            return any(word in response_lower for word in words)

        # Explicit mixed garbage warnings from the model
        if mentions((
            "multiple garbage types detected",
            "separate items",
            "different garbage types",
            "mixed together",
        )):
            return "Unable to classify"

        # Basic material detection, downgraded when contamination is described
        if mentions(("recyclable", "aluminum", "plastic", "glass", "metal", "cardboard")):
            if mentions(("obvious food", "substantial residue", "chunks", "liquids")):
                return "Food/Kitchen Waste"
            return "Recyclable Waste"

        if mentions(("food", "organic", "kitchen", "fruit", "vegetable")):
            return "Food/Kitchen Waste"

        if mentions(("battery", "hazardous", "chemical", "toxic")):
            return "Hazardous Waste"

        if mentions(("cigarette", "ceramic", "styrofoam")):
            return "Other Waste"

        # Non-garbage subjects, explicit refusals and anything still unclear
        # all end up in the same bucket.
        return "Unable to classify"

    def _extract_reasoning(self, response: str) -> str:
        """Return the explanatory text of the reply as one clean paragraph.

        The ``Reasoning`` field is used when present, together with any note
        that follows the category name in the ``Classification`` field (such
        as a rinse tip).  Without a labelled reasoning, the unlabelled text is
        used instead.  The confidence field, field labels, bold markers and a
        leading category name are removed; sentence wording is left untouched.

        Returns:
            The cleaned reasoning, capitalized and ending in punctuation, or
            ``"Analysis not available"`` when nothing is left.
        """
        segments = self._split_fields(response)
        has_reasoning = any(label == "reasoning" and text for label, text in segments)

        pieces = []
        for label, text in segments:
            if label == "confidence":
                continue
            if label == "" and has_reasoning:
                # Preamble such as "Here is my analysis:" adds nothing once a
                # labelled reasoning exists.
                continue
            if label != "reasoning":
                text = self._strip_leading_category(text)
            if text:
                pieces.append(text)

        # Remove remaining bold markers and collapse all whitespace runs.
        reasoning = " ".join(" ".join(pieces).replace("**", "").split())
        if not reasoning:
            return "Analysis not available"

        # Ensure proper capitalization and punctuation
        reasoning = reasoning[0].upper() + reasoning[1:]
        if not reasoning.endswith((".", "!", "?")):
            reasoning += "."

        return reasoning

    def get_categories_info(self):
        """Get information about all categories"""
        return self.knowledge.get_category_descriptions()
test_images/knowledge_base.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class GarbageClassificationKnowledge:
    """Static reference data for the garbage classifier.

    Provides the system prompt sent to the model, the list of category names
    the model is expected to answer with, and a short description of each
    category.
    """

    @staticmethod
    def get_system_prompt():
        """Return the system prompt describing the classification rules and reply format."""
        return """You are a professional garbage classification expert. You need to carefully observe the items in the picture, analyze their materials, properties and uses, and then make accurate judgments according to garbage classification standards.

IMPORTANT: You should ONLY classify items that are actually garbage/waste. If the image contains people, living things, furniture, electronics in use, or other non-waste items, you should classify it as "Unable to classify" and explain that it's not garbage.

**MIXED GARBAGE HANDLING RULES:**

1. **Food Residue Assessment**:
- OBVIOUSLY VISIBLE FOOD (chunks, liquids, substantial residue): Container goes to "Food/Kitchen Waste" with warning "⚠️ Tip: Empty and rinse this container first, then it can be recycled!"
- MINOR RESIDUE (grease stains, light film, pizza box grease spots): Container remains "Recyclable Waste"

2. **Multiple Different Garbage Types**:
- If image shows clearly different waste categories mixed together (electronics + organic waste, batteries + food scraps, multiple unrelated garbage types): classify as "Unable to classify" with warning "⚠️ Warning: Multiple garbage types detected. Please separate items for proper classification."
- Recyclable container with food is the ONLY allowed mixed situation - handle with rule 1 above
- ALL other mixed scenarios must be classified as "Unable to classify"

STRICTLY ENFORCE: Only recyclable containers with food are permitted mixed classification. Everything else mixed = "Unable to classify" with separation warning.

Garbage classification standards:

**Recyclable Waste**:
- Paper: newspapers, magazines, books, various packaging papers, office paper, advertising flyers, cardboard boxes with light grease stains, copy paper, etc.
- Plastics: clean plastic bottles (#1 PETE, #2 HDPE), clean plastic containers, plastic bags, toothbrushes, cups, water bottles, plastic toys, etc. (NOT styrofoam #6 or heavily coated containers)
- Metals: clean aluminum cans, clean tin cans, toothpaste tubes, metal toys, metal stationery, nails, metal sheets, aluminum foil, etc.
- Glass: clean glass bottles and jars, broken glass pieces, mirrors, light bulbs, vacuum flasks, etc.
- Textiles: old clothing, textile products, shoes, curtains, towels, bags, etc.
- NOTE: Light grease stains or minor residue are acceptable for recycling. Only obvious food content requires cleaning first.

**Food/Kitchen Waste**:
- Food scraps: rice, noodles, bread, meat, fish, shrimp shells, crab shells, bones, etc.
- Fruit peels and cores: watermelon rinds, apple cores, orange peels, banana peels, nut shells, etc.
- Plants: withered branches and leaves, flowers, traditional Chinese medicine residue, etc.
- Expired food: expired canned food, cookies, candy, etc.
- Containers with obvious food content (chunks, liquids, substantial residue)

**Hazardous Waste**:
- Batteries: dry batteries, rechargeable batteries, button batteries, and all types of batteries
- Light tubes: energy-saving lamps, fluorescent tubes, incandescent bulbs, LED lights, etc.
- Pharmaceuticals: expired medicines, medicine packaging, thermometers, blood pressure monitors, etc.
- Paints: paint, coatings, glue, nail polish, cosmetics, etc.
- Others: pesticides, cleaning agents, agricultural chemicals, X-ray films, etc.

**Other Waste**:
- Contaminated non-recyclable paper: toilet paper, diapers, wet wipes, napkins, etc.
- Non-recyclable containers: styrofoam containers (#6 polystyrene), wax-coated containers, multi-material packaging
- Cigarette butts, ceramics, dust, disposable tableware (non-plastic)
- Large bones, hard shells, hard fruit pits (coconut shells, durian shells, walnut shells, corn cobs, etc.)
- Hair, pet waste, cat litter, etc.

**Unable to classify**:
- People, human faces, human body parts
- Living animals, pets
- Furniture, appliances, electronics in normal use
- Buildings, landscapes, vehicles
- Any item that is not intended to be discarded as waste
- Multiple different garbage types mixed together

Please observe the items in the image carefully according to the above classification standards and provide accurate classification results.

Format your response EXACTLY as follows:

**Classification**: [Category Name or "Unable to classify"]
**Reasoning**: [Brief explanation of why this item belongs to this category, or why it cannot be classified as garbage]
**Confidence Score**: [Number from 1-10]"""

    @staticmethod
    def get_categories():
        """Return the category names as a fresh list; "Unable to classify" comes last."""
        names = (
            "Recyclable Waste",
            "Food/Kitchen Waste",
            "Hazardous Waste",
            "Other Waste",
            "Unable to classify",
        )
        return list(names)

    @staticmethod
    def get_category_descriptions():
        """Return a new dict mapping each category name to a one-sentence description."""
        descriptions = {}
        descriptions["Recyclable Waste"] = (
            "Items that can be processed and reused, including paper, plastic, metal, glass, and textiles (light grease stains acceptable)"
        )
        descriptions["Food/Kitchen Waste"] = (
            "Organic waste from food preparation and consumption, including containers with obvious food content"
        )
        descriptions["Hazardous Waste"] = (
            "Items containing harmful substances that require special disposal"
        )
        descriptions["Other Waste"] = (
            "Items that don't fit into other categories and go to general waste"
        )
        descriptions["Unable to classify"] = (
            "Items that are not garbage/waste, such as people, living things, functioning objects, or multiple different garbage types mixed together"
        )
        return descriptions