"""
Training script for waste classification model
Uses transfer learning with EfficientNet-B0 for optimal accuracy and speed
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import os
import json
from pathlib import Path
from tqdm import tqdm
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
# Configuration
CONFIG = {
    'data_dir': 'ml/data/processed',
    'model_dir': 'ml/models',
    'batch_size': 32,
    'num_epochs': 50,
    'learning_rate': 0.001,
    'image_size': 224,
    'num_classes': 7,
    'early_stopping_patience': 7,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
}
# Waste categories mapping
CATEGORIES = [
    'recyclable',
    'organic',
    'wet-waste',
    'dry-waste',
    'ewaste',
    'hazardous',
    'landfill'
]
class WasteDataset(Dataset):
    """Custom dataset for waste classification"""

    def __init__(self, data_dir, split='train', transform=None):
        self.data_dir = Path(data_dir) / split
        self.transform = transform
        self.samples = []
        # Load all images and labels, one subfolder per category
        for category_idx, category in enumerate(CATEGORIES):
            category_path = self.data_dir / category
            if category_path.exists():
                for img_path in category_path.glob('*.jpg'):
                    self.samples.append((str(img_path), category_idx))
                for img_path in category_path.glob('*.png'):
                    self.samples.append((str(img_path), category_idx))
        print(f"Loaded {len(self.samples)} samples for {split} split")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label
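
# Illustrative helper (an addition for debugging, not part of the training
# flow): list a few (path, label) pairs to confirm the expected layout
# data_dir/{split}/{category}/*.jpg|*.png before launching a full run.
def preview_dataset(split='train', n=3):
    """Sanity-check the dataset wiring by printing the first few samples."""
    ds = WasteDataset(CONFIG['data_dir'], split=split)
    for img_path, label in ds.samples[:n]:
        print(f"{img_path} -> {CATEGORIES[label]}")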
def get_transforms(split='train'):
    """Get data augmentation transforms"""
    if split == 'train':
        return transforms.Compose([
            transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    else:
        return transforms.Compose([
            transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
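
# Illustrative helper, assuming the ImageNet statistics used above: invert
# the normalization so augmented tensors can be displayed, e.g. with
# plt.imshow(denormalize(img).permute(1, 2, 0)). Not used during training.
def denormalize(tensor):
    """Map a normalized CHW tensor back to displayable [0, 1] values."""
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return (tensor * std + mean).clamp(0, 1)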
def create_model(num_classes):
    """
    Create EfficientNet-B0 model with pretrained ImageNet weights.
    EfficientNet provides excellent accuracy with low latency.
    """
    # The weights= API replaces the deprecated pretrained=True flag
    model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
    # Freeze the first five feature blocks; later blocks and the new head train
    for param in model.features[:5].parameters():
        param.requires_grad = False
    # Replace classifier head for the waste categories
    num_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(num_features, num_classes)
    )
    return model
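
# Sketch of an optional second fine-tuning stage (not called by train_model
# below): once the new head has converged, unfreeze the frozen backbone
# blocks and continue training with a lower learning rate. The 1e-4 default
# is an assumption, not a value tuned for this dataset.
def unfreeze_for_finetuning(model, lr=1e-4):
    """Unfreeze all feature blocks and return a fresh low-LR optimizer."""
    for param in model.features.parameters():
        param.requires_grad = True
    return optim.Adam(model.parameters(), lr=lr)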
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train for one epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(dataloader, desc='Training')
    for batch_idx, (images, labels) in enumerate(pbar, start=1):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        # Average over batches seen so far (dividing by len(dataloader)
        # would understate the running loss mid-epoch)
        pbar.set_postfix({
            'loss': f'{running_loss/batch_idx:.4f}',
            'acc': f'{100.*correct/total:.2f}%'
        })
    return running_loss / len(dataloader), 100. * correct / total
def validate(model, dataloader, criterion, device):
    """Validate the model"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc='Validating'):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    accuracy = 100. * correct / total
    avg_loss = running_loss / len(dataloader)
    # Calculate F1 scores (cast to plain floats so json.dump can
    # serialize them in the training history)
    f1_macro = float(f1_score(all_labels, all_preds, average='macro'))
    f1_weighted = float(f1_score(all_labels, all_preds, average='weighted'))
    return avg_loss, accuracy, f1_macro, f1_weighted, all_preds, all_labels
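
# Illustrative helper complementing validate(): per-category recall from the
# confusion matrix, handy for spotting weak classes (e.g. wet-waste vs.
# dry-waste) that the aggregate F1 scores can hide. Not part of the original
# training flow.
def per_class_accuracy(y_true, y_pred):
    """Return {category: recall} computed from the confusion matrix."""
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(CATEGORIES))))
    # clip(min=1) guards against division by zero for empty classes
    recalls = cm.diagonal() / cm.sum(axis=1).clip(min=1)
    return dict(zip(CATEGORIES, recalls.round(4).tolist()))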
def plot_confusion_matrix(y_true, y_pred, save_path):
    """Plot and save confusion matrix"""
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CATEGORIES, yticklabels=CATEGORIES)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
    print(f"Confusion matrix saved to {save_path}")
def train_model():
    """Main training function"""
    # Create directories
    Path(CONFIG['model_dir']).mkdir(parents=True, exist_ok=True)

    # Setup device
    device = torch.device(CONFIG['device'])
    print(f"Using device: {device}")

    # Create datasets
    train_dataset = WasteDataset(
        CONFIG['data_dir'],
        split='train',
        transform=get_transforms('train')
    )
    val_dataset = WasteDataset(
        CONFIG['data_dir'],
        split='val',
        transform=get_transforms('val')
    )

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    # Create model
    model = create_model(CONFIG['num_classes']).to(device)
    print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")

    # Loss and optimizer; the scheduler halves the LR when validation
    # accuracy plateaus for 3 epochs (the deprecated 'verbose' flag is omitted)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3
    )

    # Training loop
    best_acc = 0.0
    patience_counter = 0
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'val_f1_macro': [], 'val_f1_weighted': []
    }
    for epoch in range(CONFIG['num_epochs']):
        print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
        print("-" * 50)

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc, f1_macro, f1_weighted, val_preds, val_labels = validate(
            model, val_loader, criterion, device
        )

        # Update scheduler
        scheduler.step(val_acc)

        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_f1_macro'].append(f1_macro)
        history['val_f1_weighted'].append(f1_weighted)

        print(f"\nTrain Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
        print(f"F1 Macro: {f1_macro:.4f} | F1 Weighted: {f1_weighted:.4f}")

        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': val_acc,
                'f1_macro': f1_macro,
                'f1_weighted': f1_weighted,
                'categories': CATEGORIES,
                'config': CONFIG
            }, f"{CONFIG['model_dir']}/best_model.pth")
            print(f"✓ Best model saved with accuracy: {best_acc:.2f}%")

            # Save confusion matrix for best model
            plot_confusion_matrix(
                val_labels,
                val_preds,
                f"{CONFIG['model_dir']}/confusion_matrix.png"
            )
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= CONFIG['early_stopping_patience']:
            print(f"\nEarly stopping triggered after {epoch+1} epochs")
            break

    # Save training history
    with open(f"{CONFIG['model_dir']}/training_history.json", 'w') as f:
        json.dump(history, f, indent=2)

    # Classification report for the final epoch's validation predictions
    print("\nClassification Report:")
    print(classification_report(val_labels, val_preds, target_names=CATEGORIES))
    print(f"\nTraining complete! Best validation accuracy: {best_acc:.2f}%")
    return model, history
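
# Illustrative inference sketch showing how the checkpoint written above
# could be consumed by the backend (the serving path is an assumption, not
# taken from this repo): rebuild the model, load best_model.pth, and
# classify one image.
def predict_image(image_path, checkpoint_path=None, device=None):
    """Return (category, softmax confidence) for a single image file."""
    device = device or torch.device(CONFIG['device'])
    checkpoint_path = checkpoint_path or f"{CONFIG['model_dir']}/best_model.pth"
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = create_model(len(checkpoint['categories'])).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    image = Image.open(image_path).convert('RGB')
    batch = get_transforms('val')(image).unsqueeze(0).to(device)
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)[0]
    conf, idx = probs.max(0)
    return checkpoint['categories'][idx.item()], conf.item()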
if __name__ == "__main__":
    train_model()