"""
Artist Style Embedding - Evaluation and Inference
"""
import argparse
from typing import List, Tuple
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm
from sklearn.manifold import TSNE
try:
    import matplotlib.pyplot as plt
    PLOT_AVAILABLE = True
except ImportError:
    PLOT_AVAILABLE = False
from config import get_config
from model import ArtistStyleModel
from dataset import ArtistDataset, build_dataset_splits, collate_fn

class ArtistEmbeddingInference:
    """Inference wrapper for the artist style embedding model."""

    def __init__(self, checkpoint_path: str, device: str = 'cuda'):
        requested_device = device
        if requested_device.startswith('cuda') and not torch.cuda.is_available():
            print(
                "[WARN] --device=cuda requested but torch.cuda.is_available() is False. "
                "Falling back to CPU. (Install a CUDA-enabled PyTorch build to use GPU.)"
            )
            requested_device = 'cpu'
        self.device = torch.device(requested_device)

        # Always load the checkpoint on CPU to avoid duplicating large tensors on GPU.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        self.artist_to_idx = checkpoint['artist_to_idx']
        self.idx_to_artist = {v: k for k, v in self.artist_to_idx.items()}

        config = get_config()
        self.model = ArtistStyleModel(
            num_classes=len(self.artist_to_idx),
            embedding_dim=config.model.embedding_dim,
            hidden_dim=config.model.hidden_dim,
        )
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # Reduce VRAM: keep weights in FP16 on CUDA. Inputs must be cast to the
        # same dtype, so remember it for the inference helpers below.
        self.dtype = torch.float16 if self.device.type == 'cuda' else torch.float32
        if self.device.type == 'cuda':
            self.model = self.model.to(dtype=torch.float16)
        self.model = self.model.to(self.device)
        self.model.eval()

        # Standard ImageNet preprocessing statistics.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
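    # The model consumes a full image plus optional face/eye crops. For
    # single-image inference no crops are available, so zero tensors are passed
    # as placeholders with the has_face/has_eye flags set to False.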
    def get_embedding(self, image: Image.Image) -> torch.Tensor:
        """Return the style embedding for a single PIL image."""
        tensor = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
        placeholder = torch.zeros(1, 3, 224, 224, dtype=self.dtype, device=self.device)
        has_false = torch.tensor([False], device=self.device)
        with torch.no_grad():
            embedding = self.model.get_embeddings(tensor, placeholder, placeholder, has_false, has_false)
        return embedding.squeeze(0)
    def predict_artist(self, image: Image.Image, top_k: int = 5) -> List[Tuple[str, float]]:
        """Return the top-k (artist, probability) pairs for a single PIL image."""
        tensor = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
        placeholder = torch.zeros(1, 3, 224, 224, dtype=self.dtype, device=self.device)
        has_false = torch.tensor([False], device=self.device)
        with torch.no_grad():
            output = self.model(tensor, placeholder, placeholder, has_false, has_false)
            probs = F.softmax(output['cosine'].squeeze(0), dim=0)
        top_probs, top_indices = probs.topk(top_k)
        return [(self.idx_to_artist[idx.item()], prob.item()) for prob, idx in zip(top_probs, top_indices)]
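
# Example (hypothetical paths): using the wrapper above programmatically.
#   inference = ArtistEmbeddingInference('checkpoints/best.pt')
#   image = Image.open('sample.png').convert('RGB')
#   vec = inference.get_embedding(image)             # 1-D embedding tensor
#   top5 = inference.predict_artist(image, top_k=5)  # [(artist, prob), ...]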

def evaluate_model(checkpoint_path: str, dataset_root: str, dataset_face_root: str,
                   dataset_eyes_root: str, device: str = 'cuda'):
    """Compute top-1 and top-5 accuracy on the test split."""
    inference = ArtistEmbeddingInference(checkpoint_path, device)
    config = get_config()
    artist_to_idx, full_splits, face_splits, eye_splits = build_dataset_splits(
        dataset_root, dataset_face_root, dataset_eyes_root,
        min_images=config.data.min_images_per_artist
    )
    test_dataset = ArtistDataset(
        dataset_root, dataset_face_root, dataset_eyes_root,
        artist_to_idx, full_splits['test'], face_splits['test'], eye_splits['test'],
        is_training=False
    )
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4, collate_fn=collate_fn)

    total_correct = 0
    total_correct_top5 = 0
    total_samples = 0
    for batch in tqdm(test_loader, desc="Evaluating"):
        # Cast image tensors to the model dtype (FP16 on CUDA, see __init__).
        full = batch['full'].to(inference.device, dtype=inference.dtype)
        face = batch['face'].to(inference.device, dtype=inference.dtype)
        eye = batch['eye'].to(inference.device, dtype=inference.dtype)
        has_face = batch['has_face'].to(inference.device)
        has_eye = batch['has_eye'].to(inference.device)
        labels = batch['label'].to(inference.device)

        with torch.no_grad():
            output = inference.model(full, face, eye, has_face, has_eye)

        # Top-1: the highest-scoring class must match the label.
        preds = output['cosine'].argmax(dim=1)
        total_correct += (preds == labels).sum().item()

        # Top-5: the label must appear among the five highest-scoring classes.
        _, top5_preds = output['cosine'].topk(5, dim=1)
        top5_correct = top5_preds.eq(labels.view(-1, 1).expand_as(top5_preds))
        total_correct_top5 += top5_correct.any(dim=1).sum().item()
        total_samples += labels.size(0)

    accuracy = total_correct / total_samples if total_samples > 0 else 0
    accuracy_top5 = total_correct_top5 / total_samples if total_samples > 0 else 0
    print("\nEvaluation Results:")
    print("-" * 40)
    print(f"Top-1 Accuracy: {accuracy:.4f} ({total_correct}/{total_samples})")
    print(f"Top-5 Accuracy: {accuracy_top5:.4f} ({total_correct_top5}/{total_samples})")

def visualize_embeddings(checkpoint_path: str, dataset_root: str, dataset_face_root: str,
                         dataset_eyes_root: str, output_path: str = 'tsne.png',
                         max_artists: int = 50, device: str = 'cuda'):
    """Project test-set embeddings to 2-D with t-SNE and save a scatter plot."""
    if not PLOT_AVAILABLE:
        print("matplotlib not available")
        return
    inference = ArtistEmbeddingInference(checkpoint_path, device)
    config = get_config()
    artist_to_idx, full_splits, face_splits, eye_splits = build_dataset_splits(
        dataset_root, dataset_face_root, dataset_eyes_root,
        min_images=config.data.min_images_per_artist
    )

    # Keep at most `max_artists` artists and up to 10 test images per artist.
    selected = list(artist_to_idx.keys())[:max_artists]
    filtered_full = {a: p[:10] for a, p in full_splits['test'].items() if a in selected}
    filtered_face = {a: face_splits['test'].get(a, []) for a in selected}
    filtered_eye = {a: eye_splits['test'].get(a, []) for a in selected}
    filtered_idx = {a: i for i, a in enumerate(selected)}

    dataset = ArtistDataset(
        dataset_root, dataset_face_root, dataset_eyes_root,
        filtered_idx, filtered_full, filtered_face, filtered_eye,
        is_training=False
    )
    loader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

    all_embeddings, all_labels = [], []
    for batch in tqdm(loader, desc="Extracting"):
        full = batch['full'].to(inference.device, dtype=inference.dtype)
        face = batch['face'].to(inference.device, dtype=inference.dtype)
        eye = batch['eye'].to(inference.device, dtype=inference.dtype)
        has_face = batch['has_face'].to(inference.device)
        has_eye = batch['has_eye'].to(inference.device)
        with torch.no_grad():
            embeddings = inference.model.get_embeddings(full, face, eye, has_face, has_eye)
        # Cast back to FP32 so t-SNE does not operate on half precision.
        all_embeddings.append(embeddings.float().cpu())
        all_labels.extend(batch['label'].tolist())

    embeddings = torch.cat(all_embeddings, dim=0).numpy()
    print("Running t-SNE...")
    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    embeddings_2d = tsne.fit_transform(embeddings)

    plt.figure(figsize=(14, 10))
    colors = plt.cm.tab20(np.linspace(0, 1, max_artists))
    for label in set(all_labels):
        mask = np.array(all_labels) == label
        plt.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1], c=[colors[label]], alpha=0.7, s=50)
    plt.title('Artist Style Embeddings (t-SNE)')
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close()
    print(f"Saved to {output_path}")

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--dataset_root', type=str, default='./dataset')
    parser.add_argument('--dataset_face_root', type=str, default='./dataset_face')
    parser.add_argument('--dataset_eyes_root', type=str, default='./dataset_eyes')
    parser.add_argument('--mode', type=str, default='evaluate', choices=['evaluate', 'visualize', 'predict'])
    parser.add_argument('--image', type=str, default=None)
    parser.add_argument('--output', type=str, default='tsne.png')
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()

    if args.mode == 'evaluate':
        evaluate_model(args.checkpoint, args.dataset_root, args.dataset_face_root, args.dataset_eyes_root, args.device)
    elif args.mode == 'visualize':
        visualize_embeddings(args.checkpoint, args.dataset_root, args.dataset_face_root, args.dataset_eyes_root, args.output, device=args.device)
    elif args.mode == 'predict':
        if not args.image:
            print("--image required for predict mode")
            return
        inference = ArtistEmbeddingInference(args.checkpoint, args.device)
        image = Image.open(args.image).convert('RGB')
        predictions = inference.predict_artist(image, top_k=10)
        print("\nTop 10 Predictions:")
        for i, (artist, prob) in enumerate(predictions, 1):
            print(f"{i}. {artist}: {prob:.4f}")


if __name__ == '__main__':
    main()