"""
Continuous learning script for model improvement
Fine-tunes existing model with new corrected samples
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models
from pathlib import Path
import shutil
from datetime import datetime
import json
from .train import WasteDataset, get_transforms, validate, CATEGORIES, CONFIG
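
# Assumed contract with train.py (this module relies on all of it):
#   - CONFIG provides at least 'model_dir', 'data_dir', 'device', 'num_classes' and 'batch_size'
#   - WasteDataset(data_dir, split=..., transform=...) and get_transforms(split) build the datasets
#   - validate(model, loader, criterion, device) returns
#     (loss, accuracy, f1_macro, f1_weighted, predictions, labels)
#   - CATEGORIES is the ordered list of class names used for folder layout and checkpoint metadata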

def get_model_version():
    """Get next model version number"""
    model_dir = Path(CONFIG['model_dir'])
    existing_versions = list(model_dir.glob('model_v*.pth'))

    if not existing_versions:
        return 1

    versions = [int(p.stem.split('_v')[1]) for p in existing_versions]
    return max(versions) + 1
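
# Corrected samples are expected under ml/data/retraining/<category>/ as .jpg or .png files;
# they are copied, with a timestamped "retrain_" prefix, into CONFIG['data_dir']/train/<category>/.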

def prepare_retraining_data():
    """Organize retraining data into proper structure"""
    retraining_dir = Path('ml/data/retraining')
    processed_dir = Path(CONFIG['data_dir'])

    if not retraining_dir.exists():
        print("No retraining data found")
        return 0

    # Count new samples
    new_samples = 0
    for category in CATEGORIES:
        category_dir = retraining_dir / category
        if category_dir.exists():
            images = list(category_dir.glob('*.jpg')) + list(category_dir.glob('*.png'))
            new_samples += len(images)

            # Copy to training set
            target_dir = processed_dir / 'train' / category
            target_dir.mkdir(parents=True, exist_ok=True)

            for img_path in images:
                target_path = target_dir / f"retrain_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{img_path.name}"
                shutil.copy(img_path, target_path)

    print(f"Added {new_samples} new samples to training set")
    return new_samples
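
# Note: validation reuses the fixed split at CONFIG['data_dir']/val, so the reported accuracy
# is directly comparable to the accuracy stored in the base checkpoint, assuming that split
# has not changed since the base model was trained.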

def retrain_model(base_model_path='ml/models/best_model.pth',
                  num_epochs=10,
                  learning_rate=0.0001):
    """
    Fine-tune existing model with new data
    Uses lower learning rate for incremental learning
    """
    print("Starting retraining process...")

    # Prepare new data
    new_samples = prepare_retraining_data()
    if new_samples == 0:
        print("No new samples to train on")
        return None

    # Setup device
    device = torch.device(CONFIG['device'])
    print(f"Using device: {device}")
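
    # The checkpoint produced by train.py is assumed to contain at least 'model_state_dict'
    # and 'accuracy'; both are read here, and the same keys are written back whenever an
    # improved model is saved further down.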

    # Load base model
    checkpoint = torch.load(base_model_path, map_location=device)

    # weights=None builds a randomly initialised EfficientNet-B0 (the torchvision >= 0.13
    # replacement for the deprecated pretrained=False); the trained weights are restored
    # from the checkpoint below.
    model = models.efficientnet_b0(weights=None)
    num_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(num_features, CONFIG['num_classes'])
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)

    print(f"Loaded base model with accuracy: {checkpoint['accuracy']:.2f}%")

    # Create datasets with updated data
    train_dataset = WasteDataset(
        CONFIG['data_dir'],
        split='train',
        transform=get_transforms('train')
    )
    val_dataset = WasteDataset(
        CONFIG['data_dir'],
        split='val',
        transform=get_transforms('val')
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        num_workers=4
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=4
    )
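
    # Fine-tuning updates every parameter (no layers are frozen); the small learning rate,
    # 1e-4 by default, is what keeps the update incremental rather than a full retrain.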

    # Setup training
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_acc = checkpoint['accuracy']
    improvement_threshold = 1.0  # Must improve by at least 1%

    # Fine-tuning loop
    for epoch in range(num_epochs):
        print(f"\nRetraining Epoch {epoch+1}/{num_epochs}")
        print("-" * 50)

        # Train
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Validate
        val_loss, val_acc, f1_macro, f1_weighted, val_preds, val_labels = validate(
            model, val_loader, criterion, device
        )

        print(f"Val Acc: {val_acc:.2f}% | F1 Macro: {f1_macro:.4f}")
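
        # A new checkpoint version is written only when validation accuracy beats the best
        # seen so far (initialised from the base checkpoint's accuracy); promotion to
        # best_model.pth additionally requires at least improvement_threshold (1%) of
        # absolute gain.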

        # Check improvement
        if val_acc > best_acc:
            improvement = val_acc - best_acc
            best_acc = val_acc

            # Save improved model
            version = get_model_version()
            new_model_path = f"{CONFIG['model_dir']}/model_v{version}.pth"

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': val_acc,
                'f1_macro': f1_macro,
                'f1_weighted': f1_weighted,
                'categories': CATEGORIES,
                'config': CONFIG,
                'base_model': base_model_path,
                'new_samples': new_samples,
                'improvement': improvement,
                'retrain_date': datetime.now().isoformat()
            }, new_model_path)

            print(f"✓ Improved model saved as v{version} (+{improvement:.2f}%)")

            # If significant improvement, promote to production
            if improvement >= improvement_threshold:
                production_path = f"{CONFIG['model_dir']}/best_model.pth"

                # Backup old production model
                if Path(production_path).exists():
                    backup_path = f"{CONFIG['model_dir']}/best_model_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pth"
                    shutil.copy(production_path, backup_path)

                # Promote new model
                shutil.copy(new_model_path, production_path)
                print("✓ Model promoted to production!")

                # Log retraining event
                log_retraining_event(version, val_acc, improvement, new_samples)
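
    # The consumed retraining data is moved (not deleted) into a timestamped archive below,
    # both for audit purposes and so the same images are not copied into the training set
    # again on the next run.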

    # Clean up retraining directory
    retraining_dir = Path('ml/data/retraining')
    archive_dir = Path('ml/data/retraining_archive') / datetime.now().strftime('%Y%m%d_%H%M%S')
    archive_dir.mkdir(parents=True, exist_ok=True)

    for category in CATEGORIES:
        category_dir = retraining_dir / category
        if category_dir.exists():
            shutil.move(str(category_dir), str(archive_dir / category))

    print(f"\nRetraining complete! Final accuracy: {best_acc:.2f}%")
    return model
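
# Retraining history lives in a single JSON list at CONFIG['model_dir']/retraining_log.json;
# an entry is appended each time a model is promoted, recording the version, timestamp,
# validation accuracy, absolute improvement, and the number of newly added samples.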

def log_retraining_event(version, accuracy, improvement, new_samples):
    """Log retraining events for monitoring"""
    log_file = Path(CONFIG['model_dir']) / 'retraining_log.json'

    event = {
        'version': version,
        'timestamp': datetime.now().isoformat(),
        'accuracy': accuracy,
        'improvement': improvement,
        'new_samples': new_samples
    }

    # Load existing log
    if log_file.exists():
        with open(log_file, 'r') as f:
            log = json.load(f)
    else:
        log = []

    log.append(event)

    # Save updated log
    with open(log_file, 'w') as f:
        json.dump(log, f, indent=2)

    print("Retraining event logged")
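
# Because of the relative import at the top, run this as a module from the project root
# (for example: python -m ml.retrain) rather than executing the file directly.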

if __name__ == "__main__":
    retrain_model()