""" Continuous learning script for model improvement Fine-tunes existing model with new corrected samples """ import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import models from pathlib import Path import shutil from datetime import datetime import json from .train import WasteDataset, get_transforms, validate, CATEGORIES, CONFIG def get_model_version(): """Get next model version number""" model_dir = Path(CONFIG['model_dir']) existing_versions = list(model_dir.glob('model_v*.pth')) if not existing_versions: return 1 versions = [int(p.stem.split('_v')[1]) for p in existing_versions] return max(versions) + 1 def prepare_retraining_data(): """Organize retraining data into proper structure""" retraining_dir = Path('ml/data/retraining') processed_dir = Path(CONFIG['data_dir']) if not retraining_dir.exists(): print("No retraining data found") return 0 # Count new samples new_samples = 0 for category in CATEGORIES: category_dir = retraining_dir / category if category_dir.exists(): images = list(category_dir.glob('*.jpg')) + list(category_dir.glob('*.png')) new_samples += len(images) # Copy to training set target_dir = processed_dir / 'train' / category target_dir.mkdir(parents=True, exist_ok=True) for img_path in images: target_path = target_dir / f"retrain_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{img_path.name}" shutil.copy(img_path, target_path) print(f"Added {new_samples} new samples to training set") return new_samples def retrain_model(base_model_path='ml/models/best_model.pth', num_epochs=10, learning_rate=0.0001): """ Fine-tune existing model with new data Uses lower learning rate for incremental learning """ print("Starting retraining process...") # Prepare new data new_samples = prepare_retraining_data() if new_samples == 0: print("No new samples to train on") return None # Setup device device = torch.device(CONFIG['device']) print(f"Using device: {device}") # Load base model checkpoint = torch.load(base_model_path, map_location=device) model = models.efficientnet_b0(pretrained=False) num_features = model.classifier[1].in_features model.classifier = nn.Sequential( nn.Dropout(p=0.3), nn.Linear(num_features, CONFIG['num_classes']) ) model.load_state_dict(checkpoint['model_state_dict']) model.to(device) print(f"Loaded base model with accuracy: {checkpoint['accuracy']:.2f}%") # Create datasets with updated data train_dataset = WasteDataset( CONFIG['data_dir'], split='train', transform=get_transforms('train') ) val_dataset = WasteDataset( CONFIG['data_dir'], split='val', transform=get_transforms('val') ) train_loader = DataLoader( train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=4 ) val_loader = DataLoader( val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=4 ) # Setup training criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=learning_rate) best_acc = checkpoint['accuracy'] improvement_threshold = 1.0 # Must improve by at least 1% # Fine-tuning loop for epoch in range(num_epochs): print(f"\nRetraining Epoch {epoch+1}/{num_epochs}") print("-" * 50) # Train model.train() for images, labels in train_loader: images, labels = images.to(device), labels.to(device) optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() # Validate val_loss, val_acc, f1_macro, f1_weighted, val_preds, val_labels = validate( model, val_loader, criterion, device ) print(f"Val Acc: {val_acc:.2f}% | F1 Macro: {f1_macro:.4f}") # Check improvement if val_acc > best_acc: improvement = val_acc - best_acc best_acc = val_acc # Save improved model version = get_model_version() new_model_path = f"{CONFIG['model_dir']}/model_v{version}.pth" torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'accuracy': val_acc, 'f1_macro': f1_macro, 'f1_weighted': f1_weighted, 'categories': CATEGORIES, 'config': CONFIG, 'base_model': base_model_path, 'new_samples': new_samples, 'improvement': improvement, 'retrain_date': datetime.now().isoformat() }, new_model_path) print(f"✓ Improved model saved as v{version} (+{improvement:.2f}%)") # If significant improvement, promote to production if improvement >= improvement_threshold: production_path = f"{CONFIG['model_dir']}/best_model.pth" # Backup old production model if Path(production_path).exists(): backup_path = f"{CONFIG['model_dir']}/best_model_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pth" shutil.copy(production_path, backup_path) # Promote new model shutil.copy(new_model_path, production_path) print(f"✓ Model promoted to production!") # Log retraining event log_retraining_event(version, val_acc, improvement, new_samples) # Clean up retraining directory retraining_dir = Path('ml/data/retraining') archive_dir = Path('ml/data/retraining_archive') / datetime.now().strftime('%Y%m%d_%H%M%S') archive_dir.mkdir(parents=True, exist_ok=True) for category in CATEGORIES: category_dir = retraining_dir / category if category_dir.exists(): shutil.move(str(category_dir), str(archive_dir / category)) print(f"\nRetraining complete! Final accuracy: {best_acc:.2f}%") return model def log_retraining_event(version, accuracy, improvement, new_samples): """Log retraining events for monitoring""" log_file = Path(CONFIG['model_dir']) / 'retraining_log.json' event = { 'version': version, 'timestamp': datetime.now().isoformat(), 'accuracy': accuracy, 'improvement': improvement, 'new_samples': new_samples } # Load existing log if log_file.exists(): with open(log_file, 'r') as f: log = json.load(f) else: log = [] log.append(event) # Save updated log with open(log_file, 'w') as f: json.dump(log, f, indent=2) print(f"Retraining event logged") if __name__ == "__main__": retrain_model()