HMWCS commited on
Commit
9390992
·
verified ·
1 Parent(s): bd6d077

Upload 8 files

Browse files
app.py CHANGED
@@ -1,7 +1,93 @@
1
  import gradio as gr
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from PIL import Image
3
+ import os
4
+ from classifier import GarbageClassifier
5
+ from config import Config
6
 
7
+ # Initialize classifier
8
+ config = Config()
9
+ classifier = GarbageClassifier(config)
10
 
11
+ # Load model at startup
12
+ print("Loading model...")
13
+ classifier.load_model()
14
+ print("Model loaded successfully!")
15
+
16
+
17
+ def classify_garbage(image):
18
+ """
19
+ Classify garbage in uploaded image
20
+ """
21
+ if image is None:
22
+ return "Please upload an image", "No image provided"
23
+
24
+ try:
25
+ classification, full_response = classifier.classify_image(image)
26
+ return classification, full_response
27
+ except Exception as e:
28
+ return "Error", f"Classification failed: {str(e)}"
29
+
30
+
31
+ def get_example_images():
32
+ """Get example images if they exist"""
33
+ example_dir = "test_images"
34
+ examples = []
35
+ if os.path.exists(example_dir):
36
+ for file in os.listdir(example_dir):
37
+ if file.lower().endswith((".png", ".jpg", ".jpeg")):
38
+ examples.append(os.path.join(example_dir, file))
39
+ return examples[:3] # Limit to 3 examples
40
+
41
+
42
+ # Create Gradio interface
43
+ with gr.Blocks(title="Garbage Classification System") as demo:
44
+ gr.Markdown("# 🗂️ Garbage Classification System")
45
+ gr.Markdown(
46
+ "Upload an image to classify garbage into: Recyclable Waste, Food/Kitchen Waste, Hazardous Waste, or Other Waste"
47
+ )
48
+
49
+ with gr.Row():
50
+ with gr.Column():
51
+ image_input = gr.Image(type="pil", label="Upload Garbage Image")
52
+
53
+ classify_btn = gr.Button("Classify Garbage", variant="primary", size="lg")
54
+
55
+ with gr.Column():
56
+ classification_output = gr.Textbox(
57
+ label="Classification Result",
58
+ placeholder="Upload an image and click classify",
59
+ )
60
+
61
+ full_response_output = gr.Textbox(
62
+ label="Detailed Analysis",
63
+ placeholder="Detailed reasoning will appear here",
64
+ lines=10,
65
+ )
66
+
67
+ # Category information
68
+ with gr.Accordion("📋 Garbage Categories Information", open=False):
69
+ category_info = classifier.get_categories_info()
70
+ for category, description in category_info.items():
71
+ gr.Markdown(f"**{category}**: {description}")
72
+
73
+ # Examples
74
+ examples = get_example_images()
75
+ if examples:
76
+ gr.Examples(examples=examples, inputs=image_input, label="Example Images")
77
+
78
+ # Event handlers
79
+ classify_btn.click(
80
+ fn=classify_garbage,
81
+ inputs=image_input,
82
+ outputs=[classification_output, full_response_output],
83
+ )
84
+
85
+ # Auto-classify on image upload
86
+ image_input.change(
87
+ fn=classify_garbage,
88
+ inputs=image_input,
89
+ outputs=[classification_output, full_response_output],
90
+ )
91
+
92
+ if __name__ == "__main__":
93
+ demo.launch()
classifier.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoProcessor, AutoModelForImageTextToText
2
+ from PIL import Image
3
+ import torch
4
+ import logging
5
+ from typing import Union, Tuple
6
+ from config import Config
7
+ from knowledge_base import GarbageClassificationKnowledge
8
+
9
+
10
+ class GarbageClassifier:
11
+ def __init__(self, config: Config = None):
12
+ self.config = config or Config()
13
+ self.knowledge = GarbageClassificationKnowledge()
14
+ self.processor = None
15
+ self.model = None
16
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+
18
+ # Setup logging
19
+ logging.basicConfig(level=logging.INFO)
20
+ self.logger = logging.getLogger(__name__)
21
+
22
+ def load_model(self):
23
+ """Load the model and processor"""
24
+ try:
25
+ self.logger.info(f"Loading model: {self.config.MODEL_NAME}")
26
+
27
+ # Load processor
28
+ kwargs = {}
29
+ if self.config.HF_TOKEN:
30
+ kwargs["token"] = self.config.HF_TOKEN
31
+
32
+ self.processor = AutoProcessor.from_pretrained(
33
+ self.config.MODEL_NAME, **kwargs
34
+ )
35
+
36
+ # Load model
37
+ self.model = AutoModelForImageTextToText.from_pretrained(
38
+ self.config.MODEL_NAME,
39
+ torch_dtype=self.config.TORCH_DTYPE,
40
+ device_map=self.config.DEVICE_MAP,
41
+ )
42
+
43
+ self.logger.info("Model loaded successfully")
44
+
45
+ except Exception as e:
46
+ self.logger.error(f"Error loading model: {str(e)}")
47
+ raise
48
+
49
+ def preprocess_image(self, image: Image.Image) -> Image.Image:
50
+ """
51
+ Preprocess image to meet Gemma3n requirements (512x512)
52
+ """
53
+ # Convert to RGB if necessary
54
+ if image.mode != "RGB":
55
+ image = image.convert("RGB")
56
+
57
+ # Resize to 512x512 as required by Gemma3n
58
+ target_size = (512, 512)
59
+
60
+ # Calculate aspect ratio preserving resize
61
+ original_width, original_height = image.size
62
+ aspect_ratio = original_width / original_height
63
+
64
+ if aspect_ratio > 1:
65
+ # Width is larger
66
+ new_width = target_size[0]
67
+ new_height = int(target_size[0] / aspect_ratio)
68
+ else:
69
+ # Height is larger or equal
70
+ new_height = target_size[1]
71
+ new_width = int(target_size[1] * aspect_ratio)
72
+
73
+ # Resize image maintaining aspect ratio
74
+ image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
75
+
76
+ # Create a new image with target size and paste the resized image
77
+ processed_image = Image.new(
78
+ "RGB", target_size, (255, 255, 255)
79
+ ) # White background
80
+
81
+ # Calculate position to center the image
82
+ x_offset = (target_size[0] - new_width) // 2
83
+ y_offset = (target_size[1] - new_height) // 2
84
+
85
+ processed_image.paste(image, (x_offset, y_offset))
86
+
87
+ return processed_image
88
+
89
+ def classify_image(self, image: Union[str, Image.Image]) -> Tuple[str, str]:
90
+ """
91
+ Classify garbage in the image
92
+
93
+ Args:
94
+ image: PIL Image or path to image file
95
+
96
+ Returns:
97
+ Tuple of (classification_result, full_response)
98
+ """
99
+ if self.model is None or self.processor is None:
100
+ raise RuntimeError("Model not loaded. Call load_model() first.")
101
+
102
+ try:
103
+ # Load and process image
104
+ if isinstance(image, str):
105
+ image = Image.open(image)
106
+ elif not isinstance(image, Image.Image):
107
+ raise ValueError("Image must be a PIL Image or file path")
108
+
109
+ # Preprocess image to meet Gemma3n requirements
110
+ processed_image = self.preprocess_image(image)
111
+
112
+ # Prepare messages with system prompt and user query
113
+ messages = [
114
+ {
115
+ "role": "system",
116
+ "content": [
117
+ {
118
+ "type": "text",
119
+ "text": self.knowledge.get_system_prompt(),
120
+ }
121
+ ],
122
+ },
123
+ {
124
+ "role": "user",
125
+ "content": [
126
+ {"type": "image", "image": processed_image},
127
+ {
128
+ "type": "text",
129
+ "text": "Please classify the garbage in this image and explain your reasoning.",
130
+ },
131
+ ],
132
+ },
133
+ ]
134
+
135
+ # Apply chat template and tokenize
136
+ inputs = self.processor.apply_chat_template(
137
+ messages,
138
+ add_generation_prompt=True,
139
+ tokenize=True,
140
+ return_dict=True,
141
+ return_tensors="pt",
142
+ ).to(self.model.device, dtype=self.model.dtype)
143
+ input_len = inputs["input_ids"].shape[-1]
144
+
145
+ outputs = self.model.generate(
146
+ **inputs,
147
+ max_new_tokens=self.config.MAX_NEW_TOKENS,
148
+ disable_compile=True,
149
+ )
150
+ response = self.processor.batch_decode(
151
+ outputs[:, input_len:],
152
+ skip_special_tokens=True,
153
+ )[0]
154
+
155
+ # Extract classification from response
156
+ classification = self._extract_classification(response)
157
+
158
+ # Create formatted response
159
+ formatted_response = self._format_response(classification, response)
160
+
161
+ return classification, formatted_response
162
+
163
+ except Exception as e:
164
+ self.logger.error(f"Error during classification: {str(e)}")
165
+ import traceback
166
+
167
+ traceback.print_exc()
168
+ return "Error", f"Classification failed: {str(e)}"
169
+
170
+ def _extract_classification(self, response: str) -> str:
171
+ """Extract the main classification from the response"""
172
+ categories = self.knowledge.get_categories()
173
+
174
+ # Convert response to lowercase for matching
175
+ response_lower = response.lower()
176
+
177
+ # Look for exact category matches first
178
+ for category in categories:
179
+ if category.lower() in response_lower:
180
+ return category
181
+
182
+ # Look for key terms if no exact match
183
+ category_keywords = {
184
+ "Recyclable Waste": [
185
+ "recyclable",
186
+ "recycle",
187
+ "plastic",
188
+ "paper",
189
+ "metal",
190
+ "glass",
191
+ "bottle",
192
+ "can",
193
+ "aluminum",
194
+ "cardboard",
195
+ ],
196
+ "Food/Kitchen Waste": [
197
+ "food",
198
+ "kitchen",
199
+ "organic",
200
+ "fruit",
201
+ "vegetable",
202
+ "leftovers",
203
+ "scraps",
204
+ "peel",
205
+ "core",
206
+ "bone",
207
+ ],
208
+ "Hazardous Waste": [
209
+ "hazardous",
210
+ "dangerous",
211
+ "toxic",
212
+ "battery",
213
+ "chemical",
214
+ "medicine",
215
+ "paint",
216
+ "pharmaceutical",
217
+ ],
218
+ "Other Waste": [
219
+ "other",
220
+ "general",
221
+ "trash",
222
+ "garbage",
223
+ "waste",
224
+ "cigarette",
225
+ "ceramic",
226
+ "dust",
227
+ ],
228
+ }
229
+
230
+ for category, keywords in category_keywords.items():
231
+ if any(keyword in response_lower for keyword in keywords):
232
+ return category
233
+
234
+ return "Unable to classify"
235
+
236
+ def _format_response(self, classification: str, full_response: str) -> str:
237
+ """Format the response with classification and reasoning"""
238
+ if not full_response.strip():
239
+ return f"**Classification**: {classification}\n**Reasoning**: No detailed analysis available."
240
+
241
+ # If response already contains structured format, return as is
242
+ if "**Classification**" in full_response and "**Reasoning**" in full_response:
243
+ return full_response
244
+
245
+ # Otherwise, format it
246
+ return f"**Classification**: {classification}\n\n**Reasoning**: {full_response}"
247
+
248
+ def get_categories_info(self):
249
+ """Get information about all categories"""
250
+ return self.knowledge.get_category_descriptions()
config.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from dataclasses import dataclass
4
+
5
+
6
+ @dataclass
7
+ class Config:
8
+ # Gemma3n model configuration
9
+ MODEL_NAME: str = "google/gemma-3n-E2B-it"
10
+
11
+ # Generation parameters
12
+ MAX_NEW_TOKENS: int = 512
13
+
14
+ # Device configuration
15
+ TORCH_DTYPE: str = torch.bfloat16
16
+ if torch.cuda.is_available():
17
+ DEVICE_MAP: str = "cuda:0" # Use first GPU if available
18
+ else:
19
+ DEVICE_MAP: str = "cpu"
20
+
21
+ # Image preprocessing
22
+ IMAGE_SIZE: int = 512
23
+
24
+ # Hugging Face token
25
+ HF_TOKEN: str = os.getenv("HF_TOKEN", "")
knowledge_base.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class GarbageClassificationKnowledge:
2
+ @staticmethod
3
+ def get_system_prompt():
4
+ return """You are a professional garbage classification expert. You need to carefully observe the items in the picture, analyze their materials, properties and uses, and then make accurate judgments according to garbage classification standards.
5
+
6
+ Garbage classification standards:
7
+
8
+ **Recyclable Waste**:
9
+ - Paper: newspapers, magazines, books, various packaging papers, office paper, advertising flyers, cardboard boxes, copy paper, etc.
10
+ - Plastics: various plastic bags, plastic packaging, disposable plastic food containers and utensils, toothbrushes, cups, water bottles, plastic toys, etc.
11
+ - Metals: aluminum cans, tin cans, toothpaste tubes, metal toys, metal stationery, nails, metal sheets, aluminum foil, etc.
12
+ - Glass: glass bottles, broken glass pieces, mirrors, light bulbs, vacuum flasks, etc.
13
+ - Textiles: old clothing, textile products, shoes, curtains, towels, bags, etc.
14
+
15
+ **Food/Kitchen Waste**:
16
+ - Food scraps: rice, noodles, bread, meat, fish, shrimp shells, crab shells, bones, etc.
17
+ - Fruit peels and cores: watermelon rinds, apple cores, orange peels, banana peels, nut shells, etc.
18
+ - Plants: withered branches and leaves, flowers, traditional Chinese medicine residue, etc.
19
+ - Expired food: expired canned food, cookies, candy, etc.
20
+
21
+ **Hazardous Waste**:
22
+ - Batteries: dry batteries, rechargeable batteries, button batteries, and all types of batteries
23
+ - Light tubes: energy-saving lamps, fluorescent tubes, incandescent bulbs, LED lights, etc.
24
+ - Pharmaceuticals: expired medicines, medicine packaging, thermometers, blood pressure monitors, etc.
25
+ - Paints: paint, coatings, glue, nail polish, cosmetics, etc.
26
+ - Others: pesticides, cleaning agents, agricultural chemicals, X-ray films, etc.
27
+
28
+ **Other Waste**:
29
+ - Contaminated non-recyclable paper: toilet paper, diapers, wet wipes, napkins, etc.
30
+ - Cigarette butts, ceramics, dust, disposable tableware (non-plastic)
31
+ - Large bones, hard shells, hard fruit pits (coconut shells, durian shells, walnut shells, corn cobs, etc.)
32
+ - Hair, pet waste, cat litter, etc.
33
+
34
+ Please observe the items in the image carefully according to the above classification standards, provide accurate garbage classification results, and briefly explain the classification reasoning. Format your response as:
35
+
36
+ **Classification**: [Category Name]
37
+ **Reasoning**: [Brief explanation of why this item belongs to this category]"""
38
+
39
+ @staticmethod
40
+ def get_categories():
41
+ return [
42
+ "Recyclable Waste",
43
+ "Food/Kitchen Waste",
44
+ "Hazardous Waste",
45
+ "Other Waste",
46
+ ]
47
+
48
+ @staticmethod
49
+ def get_category_descriptions():
50
+ return {
51
+ "Recyclable Waste": "Items that can be processed and reused, including paper, plastic, metal, glass, and textiles",
52
+ "Food/Kitchen Waste": "Organic waste from food preparation and consumption",
53
+ "Hazardous Waste": "Items containing harmful substances that require special disposal",
54
+ "Other Waste": "Items that don't fit into other categories and go to general waste",
55
+ }
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ pillow
3
+ torch
4
+ torchvision
5
+ transformers >= 4.53
6
+ accelerate
7
+ timm
8
+ gradio
test_images/cardboard1.jpg ADDED
test_images/glass2.jpg ADDED
test_images/metal5.jpg ADDED