# Waste_Classification / predict.py
# Uploaded by Terence9 in commit 9df438a ("Upload 6 files").
import torch
from torchvision import transforms
from PIL import Image
from model import WasteCNN # Import the model architecture
def predict_waste(image_path, model_path='waste_classifier.pth'):
    """Classify a single image as dry or wet waste with ``WasteCNN``.

    The network weights are loaded onto the CPU from *model_path*. The image
    is converted to RGB, resized to 128x128 and turned into a tensor. The
    label of the highest-scoring output class is returned.

    Args:
        image_path: Filesystem path of the image to classify.
        model_path: Path to the saved ``state_dict`` for ``WasteCNN``.
            Defaults to ``'waste_classifier.pth'``, resolved relative to the
            current working directory.

    Returns:
        ``"Dry Waste"`` when output class 0 scores highest, otherwise
        ``"Wet Waste"``.

    Raises:
        FileNotFoundError: If *image_path* or *model_path* does not exist.
        PIL.UnidentifiedImageError: If *image_path* is not a readable image.
    """
    # Load the model on the CPU so inference works without a GPU.
    # NOTE(review): torch.load unpickles the checkpoint, so only load trusted
    # files (on PyTorch >= 1.13 consider passing weights_only=True).
    model = WasteCNN()
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()

    # Resize + ToTensor only, with no normalization. This presumably mirrors
    # the training pipeline; TODO confirm against the training script.
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])
    # The context manager guarantees the image file is closed even on error.
    with Image.open(image_path) as img:
        batch = transform(img.convert('RGB')).unsqueeze(0)  # add batch dimension

    # Run inference without building an autograd graph.
    with torch.no_grad():
        logits = model(batch)
    predicted = logits.argmax(dim=1)  # index of the top-scoring class
    return "Dry Waste" if predicted.item() == 0 else "Wet Waste"
if __name__ == "__main__":
    # Example usage: classify one image whose path is typed at the prompt.
    # Strip whitespace and surrounding quotes. Terminals often add these when
    # a file is dragged onto the window or a quoted path is pasted.
    image_path = input("Enter the path to your waste image: ").strip().strip('"\'')
    try:
        result = predict_waste(image_path)
    except OSError as err:
        # Missing or unreadable image/checkpoint: report it without a traceback
        # and exit with a non-zero status.
        raise SystemExit(f"Error: {err}") from err
    print(f"Prediction: {result}")