| """ | |
| Training script for waste classification model | |
| Uses transfer learning with EfficientNet-B0 for optimal accuracy and speed | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader, Dataset | |
| from torchvision import transforms, models | |
| from PIL import Image | |
| import os | |
| import json | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| import numpy as np | |
| from sklearn.metrics import confusion_matrix, f1_score, classification_report | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| # Configuration | |
| CONFIG = { | |
| 'data_dir': 'ml/data/processed', | |
| 'model_dir': 'ml/models', | |
| 'batch_size': 32, | |
| 'num_epochs': 50, | |
| 'learning_rate': 0.001, | |
| 'image_size': 224, | |
| 'num_classes': 7, | |
| 'early_stopping_patience': 7, | |
| 'device': 'cuda' if torch.cuda.is_available() else 'cpu', | |
| } | |
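
# Expected on-disk layout: CONFIG['data_dir'] should contain 'train/' and 'val/'
# subdirectories, each holding one folder per category listed in CATEGORIES
# (WasteDataset below globs <data_dir>/<split>/<category>/*.jpg and *.png).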

# Waste categories mapping
CATEGORIES = [
    'recyclable',
    'organic',
    'wet-waste',
    'dry-waste',
    'ewaste',
    'hazardous',
    'landfill'
]
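
# NOTE: the order of CATEGORIES defines the integer class labels used during
# training, and the list is stored inside the saved checkpoint (see
# train_model) so inference code can map predictions back to category names.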


class WasteDataset(Dataset):
    """Custom dataset for waste classification"""

    def __init__(self, data_dir, split='train', transform=None):
        self.data_dir = Path(data_dir) / split
        self.transform = transform
        self.samples = []

        # Load all images and labels
        for category_idx, category in enumerate(CATEGORIES):
            category_path = self.data_dir / category
            if category_path.exists():
                for img_path in category_path.glob('*.jpg'):
                    self.samples.append((str(img_path), category_idx))
                for img_path in category_path.glob('*.png'):
                    self.samples.append((str(img_path), category_idx))

        print(f"Loaded {len(self.samples)} samples for {split} split")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label


def get_transforms(split='train'):
    """Get data augmentation transforms"""
    # Normalization uses the standard ImageNet mean/std expected by the
    # pretrained EfficientNet backbone.
    if split == 'train':
        return transforms.Compose([
            transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    else:
        return transforms.Compose([
            transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])


def create_model(num_classes):
    """
    Create EfficientNet-B0 model with pretrained weights.
    EfficientNet provides excellent accuracy with low latency.
    """
    # Use the torchvision weights enum (the `pretrained=True` flag is deprecated).
    model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)

    # Freeze early feature blocks so only the later blocks and the new
    # classifier head are fine-tuned.
    for param in model.features[:5].parameters():
        param.requires_grad = False

    # Replace classifier
    num_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(num_features, num_classes)
    )

    return model


def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train for one epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(dataloader, desc='Training')
    for batch_idx, (images, labels) in enumerate(pbar, start=1):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        pbar.set_postfix({
            # Average over the batches seen so far, not the full epoch length
            'loss': f'{running_loss / batch_idx:.4f}',
            'acc': f'{100. * correct / total:.2f}%'
        })

    return running_loss / len(dataloader), 100. * correct / total


def validate(model, dataloader, criterion, device):
    """Validate the model"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc='Validating'):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = 100. * correct / total
    avg_loss = running_loss / len(dataloader)

    # Calculate F1 scores
    f1_macro = f1_score(all_labels, all_preds, average='macro')
    f1_weighted = f1_score(all_labels, all_preds, average='weighted')

    return avg_loss, accuracy, f1_macro, f1_weighted, all_preds, all_labels


def plot_confusion_matrix(y_true, y_pred, save_path):
    """Plot and save confusion matrix"""
    cm = confusion_matrix(y_true, y_pred)

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CATEGORIES, yticklabels=CATEGORIES)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

    print(f"Confusion matrix saved to {save_path}")


def train_model():
    """Main training function"""
    # Create directories
    Path(CONFIG['model_dir']).mkdir(parents=True, exist_ok=True)

    # Setup device
    device = torch.device(CONFIG['device'])
    print(f"Using device: {device}")

    # Create datasets
    train_dataset = WasteDataset(
        CONFIG['data_dir'],
        split='train',
        transform=get_transforms('train')
    )
    val_dataset = WasteDataset(
        CONFIG['data_dir'],
        split='val',
        transform=get_transforms('val')
    )

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    # Create model
    model = create_model(CONFIG['num_classes']).to(device)
    print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3, verbose=True
    )

    # Training loop
    best_acc = 0.0
    patience_counter = 0
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'val_f1_macro': [], 'val_f1_weighted': []
    }

    for epoch in range(CONFIG['num_epochs']):
        print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
        print("-" * 50)

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc, f1_macro, f1_weighted, val_preds, val_labels = validate(
            model, val_loader, criterion, device
        )

        # Update scheduler
        scheduler.step(val_acc)

        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_f1_macro'].append(f1_macro)
        history['val_f1_weighted'].append(f1_weighted)

        print(f"\nTrain Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
        print(f"F1 Macro: {f1_macro:.4f} | F1 Weighted: {f1_weighted:.4f}")

        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': val_acc,
                'f1_macro': f1_macro,
                'f1_weighted': f1_weighted,
                'categories': CATEGORIES,
                'config': CONFIG
            }, f"{CONFIG['model_dir']}/best_model.pth")
            print(f"✓ Best model saved with accuracy: {best_acc:.2f}%")

            # Save confusion matrix for best model
            plot_confusion_matrix(
                val_labels,
                val_preds,
                f"{CONFIG['model_dir']}/confusion_matrix.png"
            )
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= CONFIG['early_stopping_patience']:
            print(f"\nEarly stopping triggered after {epoch+1} epochs")
            break

    # Save training history
    with open(f"{CONFIG['model_dir']}/training_history.json", 'w') as f:
        json.dump(history, f, indent=2)

    # Generate classification report
    print("\nClassification Report:")
    print(classification_report(val_labels, val_preds, target_names=CATEGORIES))

    print(f"\nTraining complete! Best validation accuracy: {best_acc:.2f}%")

    return model, history
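

def load_checkpoint_for_inference(checkpoint_path, device='cpu'):
    """
    Minimal sketch (not used by the training loop): shows how a consumer could
    reload the checkpoint saved above. It assumes the checkpoint layout written
    by train_model() ('model_state_dict', 'categories') and rebuilds the same
    EfficientNet-B0 architecture before loading the weights.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = create_model(len(checkpoint['categories']))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return model, checkpoint['categories']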


if __name__ == "__main__":
    train_model()