""" Training script for waste classification model Uses transfer learning with EfficientNet-B0 for optimal accuracy and speed """ import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, Dataset from torchvision import transforms, models from PIL import Image import os import json from pathlib import Path from tqdm import tqdm import numpy as np from sklearn.metrics import confusion_matrix, f1_score, classification_report import matplotlib.pyplot as plt import seaborn as sns # Configuration CONFIG = { 'data_dir': 'ml/data/processed', 'model_dir': 'ml/models', 'batch_size': 32, 'num_epochs': 50, 'learning_rate': 0.001, 'image_size': 224, 'num_classes': 7, 'early_stopping_patience': 7, 'device': 'cuda' if torch.cuda.is_available() else 'cpu', } # Waste categories mapping CATEGORIES = [ 'recyclable', 'organic', 'wet-waste', 'dry-waste', 'ewaste', 'hazardous', 'landfill' ] class WasteDataset(Dataset): """Custom dataset for waste classification""" def __init__(self, data_dir, split='train', transform=None): self.data_dir = Path(data_dir) / split self.transform = transform self.samples = [] # Load all images and labels for category_idx, category in enumerate(CATEGORIES): category_path = self.data_dir / category if category_path.exists(): for img_path in category_path.glob('*.jpg'): self.samples.append((str(img_path), category_idx)) for img_path in category_path.glob('*.png'): self.samples.append((str(img_path), category_idx)) print(f"Loaded {len(self.samples)} samples for {split} split") def __len__(self): return len(self.samples) def __getitem__(self, idx): img_path, label = self.samples[idx] image = Image.open(img_path).convert('RGB') if self.transform: image = self.transform(image) return image, label def get_transforms(split='train'): """Get data augmentation transforms""" if split == 'train': return transforms.Compose([ transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomRotation(15), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) else: return transforms.Compose([ transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def create_model(num_classes): """ Create EfficientNet-B0 model with pretrained weights EfficientNet provides excellent accuracy with low latency """ model = models.efficientnet_b0(pretrained=True) # Freeze early layers for param in model.features[:5].parameters(): param.requires_grad = False # Replace classifier num_features = model.classifier[1].in_features model.classifier = nn.Sequential( nn.Dropout(p=0.3), nn.Linear(num_features, num_classes) ) return model def train_epoch(model, dataloader, criterion, optimizer, device): """Train for one epoch""" model.train() running_loss = 0.0 correct = 0 total = 0 pbar = tqdm(dataloader, desc='Training') for images, labels in pbar: images, labels = images.to(device), labels.to(device) optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() pbar.set_postfix({ 'loss': f'{running_loss/len(pbar):.4f}', 'acc': f'{100.*correct/total:.2f}%' }) return 


def validate(model, dataloader, criterion, device):
    """Validate the model"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc='Validating'):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = 100. * correct / total
    avg_loss = running_loss / len(dataloader)

    # Calculate F1 scores (cast to plain floats so they serialise to JSON later)
    f1_macro = float(f1_score(all_labels, all_preds, average='macro'))
    f1_weighted = float(f1_score(all_labels, all_preds, average='weighted'))

    return avg_loss, accuracy, f1_macro, f1_weighted, all_preds, all_labels


def plot_confusion_matrix(y_true, y_pred, save_path):
    """Plot and save confusion matrix"""
    # Pass explicit labels so the matrix stays aligned with CATEGORIES even if
    # a class is missing from the validation split
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(CATEGORIES))))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CATEGORIES, yticklabels=CATEGORIES)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
    print(f"Confusion matrix saved to {save_path}")


def train_model():
    """Main training function"""
    # Create directories
    Path(CONFIG['model_dir']).mkdir(parents=True, exist_ok=True)

    # Setup device
    device = torch.device(CONFIG['device'])
    print(f"Using device: {device}")

    # Create datasets
    train_dataset = WasteDataset(
        CONFIG['data_dir'],
        split='train',
        transform=get_transforms('train')
    )
    val_dataset = WasteDataset(
        CONFIG['data_dir'],
        split='val',
        transform=get_transforms('val')
    )

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    # Create model
    model = create_model(CONFIG['num_classes']).to(device)
    print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3, verbose=True
    )

    # Training loop
    best_acc = 0.0
    patience_counter = 0
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'val_f1_macro': [], 'val_f1_weighted': []
    }

    for epoch in range(CONFIG['num_epochs']):
        print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
        print("-" * 50)

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc, f1_macro, f1_weighted, val_preds, val_labels = validate(
            model, val_loader, criterion, device
        )

        # Update scheduler
        scheduler.step(val_acc)

        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_f1_macro'].append(f1_macro)
        history['val_f1_weighted'].append(f1_weighted)

        print(f"\nTrain Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
        print(f"F1 Macro: {f1_macro:.4f} | F1 Weighted: {f1_weighted:.4f}")

        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': val_acc,
                'f1_macro': f1_macro,
                'f1_weighted': f1_weighted,
                'categories': CATEGORIES,
                'config': CONFIG
            }, f"{CONFIG['model_dir']}/best_model.pth")
            print(f"✓ Best model saved with accuracy: {best_acc:.2f}%")

            # Save confusion matrix for best model
            plot_confusion_matrix(
                val_labels, val_preds,
                f"{CONFIG['model_dir']}/confusion_matrix.png"
            )
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= CONFIG['early_stopping_patience']:
            print(f"\nEarly stopping triggered after {epoch+1} epochs")
            break

    # Save training history
    with open(f"{CONFIG['model_dir']}/training_history.json", 'w') as f:
        json.dump(history, f, indent=2)

    # Generate classification report for the last validation pass
    print("\nClassification Report:")
    print(classification_report(val_labels, val_preds,
                                labels=list(range(len(CATEGORIES))),
                                target_names=CATEGORIES))

    print(f"\nTraining complete! Best validation accuracy: {best_acc:.2f}%")

    return model, history


if __name__ == "__main__":
    train_model()
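

# ---------------------------------------------------------------------------
# Inference sketch (illustrative only, not part of the training pipeline).
# A minimal example of how the checkpoint written by train_model() could be
# loaded for single-image prediction, assuming the checkpoint format saved
# above. `predict_image` is a hypothetical helper named here for illustration;
# it is not used or referenced anywhere else in this script.
# ---------------------------------------------------------------------------
def predict_image(image_path, checkpoint_path=f"{CONFIG['model_dir']}/best_model.pth"):
    """Load the best checkpoint and classify a single image (sketch)."""
    device = torch.device(CONFIG['device'])
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Rebuild the architecture and restore the trained weights
    model = create_model(len(checkpoint['categories'])).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Apply the same preprocessing as validation and add a batch dimension
    image = Image.open(image_path).convert('RGB')
    tensor = get_transforms('val')(image).unsqueeze(0).to(device)

    with torch.no_grad():
        probs = torch.softmax(model(tensor), dim=1)[0]

    top_idx = int(probs.argmax())
    return checkpoint['categories'][top_idx], float(probs[top_idx])

# Example (hypothetical path):
#   label, confidence = predict_image('path/to/photo.jpg')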