"""
Inference script for waste classification
Optimized for CPU with fast preprocessing
"""
import torch
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import numpy as np
import base64
from io import BytesIO
import json
from pathlib import Path


class WasteClassifier:
    """Waste classification inference class"""

    def __init__(self, model_path='ml/models/best_model.pth', device=None):
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load checkpoint
        checkpoint = torch.load(model_path, map_location=self.device)
        self.categories = checkpoint['categories']

        # Create model (weights=None builds the architecture only; this is
        # the current torchvision API replacing the deprecated pretrained=False)
        self.model = models.efficientnet_b0(weights=None)
        num_features = self.model.classifier[1].in_features
        self.model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.3),
            torch.nn.Linear(num_features, len(self.categories))
        )

        # Load weights
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

        # Setup transforms
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        print(f"Model loaded successfully on {self.device}")
        print(f"Categories: {self.categories}")
    def preprocess_image(self, image_input):
        """
        Preprocess image from various input formats

        Accepts: PIL Image, file path, base64 data URI, or numpy array
        """
        if isinstance(image_input, str):
            if image_input.startswith('data:image'):
                # Base64-encoded data URI: strip the header, decode the payload
                image_data = image_input.split(',')[1]
                image_bytes = base64.b64decode(image_data)
                image = Image.open(BytesIO(image_bytes)).convert('RGB')
            else:
                # File path
                image = Image.open(image_input).convert('RGB')
        elif isinstance(image_input, np.ndarray):
            image = Image.fromarray(image_input).convert('RGB')
        elif isinstance(image_input, Image.Image):
            image = image_input.convert('RGB')
        else:
            raise ValueError(f"Unsupported image input type: {type(image_input)}")

        return self.transform(image).unsqueeze(0)
    def predict(self, image_input):
        """
        Predict waste category for input image

        Returns:
            dict: {
                'category': str,
                'confidence': float,
                'probabilities': dict,
                'timestamp': int (milliseconds since the Unix epoch)
            }
        """
        # Preprocess
        image_tensor = self.preprocess_image(image_input).to(self.device)

        # Inference
        with torch.no_grad():
            outputs = self.model(image_tensor)
            probabilities = F.softmax(outputs, dim=1)
            confidence, predicted_idx = torch.max(probabilities, 1)

        # Format results
        predicted_category = self.categories[predicted_idx.item()]
        confidence_score = confidence.item()

        # Get all probabilities
        prob_dict = {
            category: float(prob)
            for category, prob in zip(self.categories, probabilities[0].cpu().numpy())
        }

        return {
            'category': predicted_category,
            'confidence': confidence_score,
            'probabilities': prob_dict,
            # Milliseconds since the epoch; np.datetime64('now') alone is
            # second-precision, so request 'ms' explicitly
            'timestamp': int(np.datetime64('now', 'ms').astype(np.int64))
        }
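    # Example return value from predict() (illustrative numbers and labels):
    #   {'category': 'recyclable',
    #    'confidence': 0.93,
    #    'probabilities': {'recyclable': 0.93, ...},
    #    'timestamp': 1700000000000}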
    def predict_batch(self, image_inputs):
        """Predict for multiple images"""
        results = []
        for image_input in image_inputs:
            results.append(self.predict(image_input))
        return results
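# Example usage (illustrative paths; any of the input formats handled by
# preprocess_image works):
#
#   classifier = WasteClassifier()
#   classifier.predict("photo.jpg")                         # file path
#   classifier.predict(Image.open("photo.jpg"))             # PIL image
#   classifier.predict(np.array(Image.open("photo.jpg")))   # numpy array
#   classifier.predict("data:image/jpeg;base64,<payload>")  # base64 data URI
#   classifier.predict_batch(["a.jpg", "b.jpg"])            # list of inputs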
def export_to_onnx(model_path='ml/models/best_model.pth',
                   output_path='ml/models/model.onnx'):
    """Export PyTorch model to ONNX format for deployment"""
    classifier = WasteClassifier(model_path)

    # Create dummy input
    dummy_input = torch.randn(1, 3, 224, 224).to(classifier.device)

    # Export
    torch.onnx.export(
        classifier.model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=12,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )
    print(f"Model exported to ONNX: {output_path}")
if __name__ == "__main__":
    # Test inference
    classifier = WasteClassifier()

    # Example usage
    test_image = "ml/data/processed/test/recyclable/sample.jpg"
    if Path(test_image).exists():
        result = classifier.predict(test_image)
        print("\nPrediction Result:")
        print(json.dumps(result, indent=2))
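        # Also exercise the base64 data-URI path (a sketch: re-encodes the
        # same test image; the image/jpeg MIME type is an assumption)
        with open(test_image, 'rb') as f:
            encoded = base64.b64encode(f.read()).decode('ascii')
        b64_result = classifier.predict(f"data:image/jpeg;base64,{encoded}")
        print(f"\nBase64 prediction: {b64_result['category']} "
              f"({b64_result['confidence']:.2%})")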