Spaces:
Sleeping
Sleeping
| """ | |
| Inference script for waste classification | |
| Optimized for CPU with fast preprocessing | |
| """ | |
import base64
import json
import time
from io import BytesIO
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms, models
class WasteClassifier:
    """Waste classification inference wrapper around an EfficientNet-B0 model.

    Loads a checkpoint that stores the category list and the trained weights,
    rebuilds the network with a classification head sized to those categories,
    and exposes single-image and multi-image prediction helpers.
    """

    def __init__(self, model_path='ml/models/best_model.pth', device=None):
        """Load the checkpoint and put the model in evaluation mode.

        Args:
            model_path: Path to a checkpoint containing the keys
                ``'categories'`` and ``'model_state_dict'``.
            device: Device to run inference on. When omitted, CUDA is used
                if available, otherwise the CPU.
        """
        self.device = device or torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )

        # NOTE: torch.load unpickles the file; only load checkpoints that
        # come from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device)
        self.categories = checkpoint['categories']

        # Build the bare architecture. The default arguments create a randomly
        # initialised network on both the legacy (`pretrained=False`) and the
        # current (`weights=None`) torchvision APIs, which avoids passing the
        # deprecated `pretrained` keyword. The real weights come from the
        # checkpoint below.
        self.model = models.efficientnet_b0()
        num_features = self.model.classifier[1].in_features
        self.model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.3),
            torch.nn.Linear(num_features, len(self.categories)),
        )

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

        # Resize to 224x224, convert to a tensor and normalise each channel
        # with a fixed mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        print(f"Model loaded successfully on {self.device}")
        print(f"Categories: {self.categories}")

    def preprocess_image(self, image_input):
        """Convert a supported image input into a model-ready tensor.

        Args:
            image_input: One of: a PIL image, a file path (``str`` or
                ``pathlib.Path``), a ``data:image...`` URI carrying a base64
                payload, or a numpy array.

        Returns:
            The transformed RGB image with a leading batch dimension of 1.

        Raises:
            ValueError: If the input type is unsupported or a data URI has
                no ``,`` separating the header from the payload.
        """
        if isinstance(image_input, Path):
            image_input = str(image_input)

        if isinstance(image_input, str):
            if image_input.startswith('data:image'):
                # Data URI layout: "data:image/<type>;base64,<payload>".
                _, separator, payload = image_input.partition(',')
                if not separator:
                    raise ValueError(
                        "Malformed data URI: missing ',' before base64 payload"
                    )
                image = Image.open(BytesIO(base64.b64decode(payload))).convert('RGB')
            else:
                # Treat any other string as a file path; the context manager
                # closes the file handle as soon as the pixels are converted.
                with Image.open(image_input) as opened:
                    image = opened.convert('RGB')
        elif isinstance(image_input, np.ndarray):
            image = Image.fromarray(image_input).convert('RGB')
        elif isinstance(image_input, Image.Image):
            image = image_input.convert('RGB')
        else:
            raise ValueError(f"Unsupported image input type: {type(image_input)}")

        return self.transform(image).unsqueeze(0)

    def predict(self, image_input):
        """Predict the waste category for a single image.

        Args:
            image_input: Any input accepted by ``preprocess_image``.

        Returns:
            dict: {
                'category': str, the highest-probability category,
                'confidence': float, the softmax probability of that category,
                'probabilities': dict mapping every category to its probability,
                'timestamp': int, Unix time of the prediction in milliseconds
            }
        """
        image_tensor = self.preprocess_image(image_input).to(self.device)

        # Gradients are not needed for inference.
        with torch.no_grad():
            outputs = self.model(image_tensor)
            probabilities = F.softmax(outputs, dim=1)
            confidence, predicted_idx = torch.max(probabilities, 1)

        prob_dict = {
            category: float(prob)
            for category, prob in zip(self.categories, probabilities[0].cpu().numpy())
        }

        return {
            'category': self.categories[predicted_idx.item()],
            'confidence': confidence.item(),
            'probabilities': prob_dict,
            # Milliseconds since the Unix epoch. The previous expression
            # divided a second-resolution numpy timestamp by 1e6, which
            # produced a value around 1700 instead of a usable timestamp.
            'timestamp': int(time.time() * 1000),
        }

    def predict_batch(self, image_inputs):
        """Predict for multiple images, one at a time.

        Args:
            image_inputs: Iterable of inputs accepted by ``predict``.

        Returns:
            list: One ``predict`` result dict per input, in input order.
        """
        return [self.predict(image_input) for image_input in image_inputs]
def export_to_onnx(model_path='ml/models/best_model.pth',
                   output_path='ml/models/model.onnx'):
    """Export the PyTorch model to ONNX format for deployment.

    Builds a ``WasteClassifier`` from the checkpoint at *model_path*, traces
    its model with a random 1x3x224x224 input, and writes the ONNX graph to
    *output_path*. The first axis of both the input and the output is marked
    as dynamic under the name ``batch_size``.
    """
    wrapper = WasteClassifier(model_path)

    # Random placeholder input on the same device as the model.
    sample = torch.randn(1, 3, 224, 224).to(wrapper.device)

    export_options = {
        'export_params': True,
        'opset_version': 12,
        'do_constant_folding': True,
        'input_names': ['input'],
        'output_names': ['output'],
        'dynamic_axes': {
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
    }
    torch.onnx.export(wrapper.model, sample, output_path, **export_options)

    print(f"Model exported to ONNX: {output_path}")
if __name__ == "__main__":
    # Smoke test: load the default checkpoint, then classify one sample
    # image if it is present on disk.
    classifier = WasteClassifier()

    sample_path = "ml/data/processed/test/recyclable/sample.jpg"
    if Path(sample_path).exists():
        prediction = classifier.predict(sample_path)
        print("\nPrediction Result:")
        formatted = json.dumps(prediction, indent=2)
        print(formatted)