""" Inference script for waste classification Optimized for CPU with fast preprocessing """ import torch import torch.nn.functional as F from torchvision import transforms, models from PIL import Image import numpy as np import base64 from io import BytesIO import json from pathlib import Path class WasteClassifier: """Waste classification inference class""" def __init__(self, model_path='ml/models/best_model.pth', device=None): self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Load checkpoint checkpoint = torch.load(model_path, map_location=self.device) self.categories = checkpoint['categories'] # Create model self.model = models.efficientnet_b0(pretrained=False) num_features = self.model.classifier[1].in_features self.model.classifier = torch.nn.Sequential( torch.nn.Dropout(p=0.3), torch.nn.Linear(num_features, len(self.categories)) ) # Load weights self.model.load_state_dict(checkpoint['model_state_dict']) self.model.to(self.device) self.model.eval() # Setup transforms self.transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) print(f"Model loaded successfully on {self.device}") print(f"Categories: {self.categories}") def preprocess_image(self, image_input): """ Preprocess image from various input formats Accepts: PIL Image, file path, base64 string, or numpy array """ if isinstance(image_input, str): if image_input.startswith('data:image'): # Base64 encoded image image_data = image_input.split(',')[1] image_bytes = base64.b64decode(image_data) image = Image.open(BytesIO(image_bytes)).convert('RGB') else: # File path image = Image.open(image_input).convert('RGB') elif isinstance(image_input, np.ndarray): image = Image.fromarray(image_input).convert('RGB') elif isinstance(image_input, Image.Image): image = image_input.convert('RGB') else: raise ValueError(f"Unsupported image input type: {type(image_input)}") return self.transform(image).unsqueeze(0) def predict(self, image_input): """ Predict waste category for input image Returns: dict: { 'category': str, 'confidence': float, 'probabilities': dict } """ # Preprocess image_tensor = self.preprocess_image(image_input).to(self.device) # Inference with torch.no_grad(): outputs = self.model(image_tensor) probabilities = F.softmax(outputs, dim=1) confidence, predicted_idx = torch.max(probabilities, 1) # Format results predicted_category = self.categories[predicted_idx.item()] confidence_score = confidence.item() # Get all probabilities prob_dict = { category: float(prob) for category, prob in zip(self.categories, probabilities[0].cpu().numpy()) } return { 'category': predicted_category, 'confidence': confidence_score, 'probabilities': prob_dict, 'timestamp': int(np.datetime64('now').astype(int) / 1000000) } def predict_batch(self, image_inputs): """Predict for multiple images""" results = [] for image_input in image_inputs: results.append(self.predict(image_input)) return results def export_to_onnx(model_path='ml/models/best_model.pth', output_path='ml/models/model.onnx'): """Export PyTorch model to ONNX format for deployment""" classifier = WasteClassifier(model_path) # Create dummy input dummy_input = torch.randn(1, 3, 224, 224).to(classifier.device) # Export torch.onnx.export( classifier.model, dummy_input, output_path, export_params=True, opset_version=12, do_constant_folding=True, input_names=['input'], output_names=['output'], dynamic_axes={ 'input': {0: 'batch_size'}, 'output': {0: 'batch_size'} } ) print(f"Model exported to ONNX: {output_path}") if __name__ == "__main__": # Test inference classifier = WasteClassifier() # Example usage test_image = "ml/data/processed/test/recyclable/sample.jpg" if Path(test_image).exists(): result = classifier.predict(test_image) print("\nPrediction Result:") print(json.dumps(result, indent=2))