"""Artist Style Embedding - Evaluation and Inference.

Command-line entry points for a trained ``ArtistStyleModel``:

* ``evaluate``  - top-1 / top-5 accuracy on the test split.
* ``visualize`` - t-SNE scatter plot of test-split embeddings.
* ``predict``   - top-k artist predictions for a single image.
"""

import argparse
from pathlib import Path
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm
from sklearn.manifold import TSNE

try:
    import matplotlib.pyplot as plt
    PLOT_AVAILABLE = True
except ImportError:
    PLOT_AVAILABLE = False

from config import get_config
from model import ArtistStyleModel
from dataset import ArtistDataset, build_dataset_splits, collate_fn


class ArtistEmbeddingInference:
    """Load a trained checkpoint and run single-image inference.

    Attributes:
        device: ``torch.device`` the model lives on (falls back to CPU when
            CUDA is requested but unavailable).
        dtype: Floating-point dtype of the model weights on ``device``
            (``float16`` on CUDA, ``float32`` otherwise). Inputs must be cast
            to this dtype before being fed to the model.
        artist_to_idx: Mapping stored in the checkpoint (artist -> class index).
        idx_to_artist: Inverse of ``artist_to_idx``.
        model: The ``ArtistStyleModel`` in eval mode.
        transform: Preprocessing applied to PIL images (resize to 224x224,
            tensor conversion, per-channel normalisation).
    """

    def __init__(self, checkpoint_path: str, device: str = 'cuda'):
        """Build the model from ``checkpoint_path`` and move it to ``device``.

        Args:
            checkpoint_path: Path to a checkpoint containing the keys
                ``'artist_to_idx'`` and ``'model_state_dict'``.
            device: Requested device string, e.g. ``'cuda'`` or ``'cpu'``.
        """
        requested_device = device
        if requested_device.startswith('cuda') and not torch.cuda.is_available():
            print(
                "[WARN] --device=cuda requested but torch.cuda.is_available() is False. "
                "Falling back to CPU. (Install a CUDA-enabled PyTorch build to use GPU.)"
            )
            requested_device = 'cpu'
        self.device = torch.device(requested_device)

        # Always load checkpoint on CPU to avoid duplicating large tensors on GPU.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        self.artist_to_idx = checkpoint['artist_to_idx']
        self.idx_to_artist = {v: k for k, v in self.artist_to_idx.items()}

        config = get_config()
        self.model = ArtistStyleModel(
            num_classes=len(self.artist_to_idx),
            embedding_dim=config.model.embedding_dim,
            hidden_dim=config.model.hidden_dim,
        )
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # Reduce VRAM: keep weights in FP16 on CUDA.
        if self.device.type == 'cuda':
            self.dtype = torch.float16
            self.model = self.model.to(dtype=torch.float16)
        else:
            self.dtype = torch.float32
        self.model = self.model.to(self.device)
        self.model.eval()

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _prepare_inputs(self, image: Image.Image) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Turn one PIL image into the (full, placeholder, has_false) model inputs.

        The face/eye branches receive an all-zero placeholder together with a
        ``False`` presence flag. Float tensors are created in ``self.dtype`` so
        they match the (possibly FP16) model weights.
        """
        tensor = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
        placeholder = torch.zeros(1, 3, 224, 224, device=self.device, dtype=self.dtype)
        has_false = torch.tensor([False], device=self.device)
        return tensor, placeholder, has_false

    def get_embedding(self, image: Image.Image) -> torch.Tensor:
        """Return the model's embedding for ``image`` with the batch dim removed."""
        tensor, placeholder, has_false = self._prepare_inputs(image)
        with torch.no_grad():
            embedding = self.model.get_embeddings(
                tensor, placeholder, placeholder, has_false, has_false
            )
        return embedding.squeeze(0)

    def predict_artist(self, image: Image.Image, top_k: int = 5) -> List[Tuple[str, float]]:
        """Return up to ``top_k`` ``(artist, probability)`` pairs, best first.

        Probabilities are a softmax over the model's ``'cosine'`` output.
        ``top_k`` is clamped to the number of classes so that asking for more
        predictions than there are artists does not raise.
        """
        tensor, placeholder, has_false = self._prepare_inputs(image)
        with torch.no_grad():
            output = self.model(tensor, placeholder, placeholder, has_false, has_false)
            # Softmax in float32: FP16 logits lose precision in the exponentials.
            probs = F.softmax(output['cosine'].squeeze(0).float(), dim=0)
            k = min(top_k, probs.size(0))
            top_probs, top_indices = probs.topk(k)
        return [
            (self.idx_to_artist[idx.item()], prob.item())
            for prob, idx in zip(top_probs, top_indices)
        ]


def evaluate_model(checkpoint_path: str, dataset_root: str, dataset_face_root: str,
                   dataset_eyes_root: str, device: str = 'cuda'):
    """Print top-1 and top-5 accuracy of the checkpoint on the test split.

    Args:
        checkpoint_path: Checkpoint to evaluate.
        dataset_root: Root of the full-image dataset.
        dataset_face_root: Root of the face-crop dataset.
        dataset_eyes_root: Root of the eye-crop dataset.
        device: Requested device string.
    """
    inference = ArtistEmbeddingInference(checkpoint_path, device)
    config = get_config()

    artist_to_idx, full_splits, face_splits, eye_splits = build_dataset_splits(
        dataset_root, dataset_face_root, dataset_eyes_root,
        min_images=config.data.min_images_per_artist
    )
    if artist_to_idx != inference.artist_to_idx:
        # Labels are produced from the dataset mapping, predictions follow the
        # checkpoint mapping; if they disagree the accuracies are not meaningful.
        print(
            "[WARN] artist_to_idx built from the dataset differs from the one stored "
            "in the checkpoint; reported accuracy may be unreliable."
        )

    test_dataset = ArtistDataset(
        dataset_root, dataset_face_root, dataset_eyes_root,
        artist_to_idx, full_splits['test'], face_splits['test'], eye_splits['test'],
        is_training=False
    )
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False,
                             num_workers=4, collate_fn=collate_fn)

    total_correct = 0
    total_correct_top5 = 0
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            # Image tensors must match the model's weight dtype (FP16 on CUDA).
            full = batch['full'].to(inference.device, dtype=inference.dtype)
            face = batch['face'].to(inference.device, dtype=inference.dtype)
            eye = batch['eye'].to(inference.device, dtype=inference.dtype)
            has_face = batch['has_face'].to(inference.device)
            has_eye = batch['has_eye'].to(inference.device)
            labels = batch['label'].to(inference.device)

            output = inference.model(full, face, eye, has_face, has_eye)
            logits = output['cosine']

            # Top-1
            preds = logits.argmax(dim=1)
            total_correct += (preds == labels).sum().item()

            # Top-5 (clamped so models with fewer than 5 classes do not raise)
            k = min(5, logits.size(1))
            _, top5_preds = logits.topk(k, dim=1)
            top5_correct = top5_preds.eq(labels.view(-1, 1))
            total_correct_top5 += top5_correct.any(dim=1).sum().item()

            total_samples += labels.size(0)

    accuracy = total_correct / total_samples if total_samples > 0 else 0.0
    accuracy_top5 = total_correct_top5 / total_samples if total_samples > 0 else 0.0

    print("\nEvaluation Results:")
    print("-" * 40)
    print(f"Top-1 Accuracy: {accuracy:.4f} ({total_correct}/{total_samples})")
    print(f"Top-5 Accuracy: {accuracy_top5:.4f} ({total_correct_top5}/{total_samples})")


def visualize_embeddings(checkpoint_path: str, dataset_root: str, dataset_face_root: str,
                         dataset_eyes_root: str, output_path: str = 'tsne.png',
                         max_artists: int = 50, device: str = 'cuda'):
    """Save a 2-D t-SNE scatter plot of test-split embeddings to ``output_path``.

    Uses the first ``max_artists`` artists of the dataset mapping and at most
    10 full images per artist. Does nothing (beyond printing a message) when
    matplotlib is missing or when fewer than two samples are available.

    Args:
        checkpoint_path: Checkpoint to load.
        dataset_root: Root of the full-image dataset.
        dataset_face_root: Root of the face-crop dataset.
        dataset_eyes_root: Root of the eye-crop dataset.
        output_path: Destination image file.
        max_artists: Maximum number of artists to include.
        device: Requested device string.
    """
    if not PLOT_AVAILABLE:
        print("matplotlib not available")
        return

    inference = ArtistEmbeddingInference(checkpoint_path, device)
    config = get_config()

    artist_to_idx, full_splits, face_splits, eye_splits = build_dataset_splits(
        dataset_root, dataset_face_root, dataset_eyes_root,
        min_images=config.data.min_images_per_artist
    )

    selected = list(artist_to_idx.keys())[:max_artists]
    selected_set = set(selected)  # O(1) membership for the filter below
    filtered_full = {a: p[:10] for a, p in full_splits['test'].items() if a in selected_set}
    filtered_face = {a: face_splits['test'].get(a, []) for a in selected}
    filtered_eye = {a: eye_splits['test'].get(a, []) for a in selected}
    # Re-index the selected artists densely so labels can index the colour table.
    filtered_idx = {a: i for i, a in enumerate(selected)}

    dataset = ArtistDataset(
        dataset_root, dataset_face_root, dataset_eyes_root,
        filtered_idx, filtered_full, filtered_face, filtered_eye,
        is_training=False
    )
    loader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

    all_embeddings, all_labels = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Extracting"):
            # Image tensors must match the model's weight dtype (FP16 on CUDA).
            full = batch['full'].to(inference.device, dtype=inference.dtype)
            face = batch['face'].to(inference.device, dtype=inference.dtype)
            eye = batch['eye'].to(inference.device, dtype=inference.dtype)
            has_face = batch['has_face'].to(inference.device)
            has_eye = batch['has_eye'].to(inference.device)

            embeddings = inference.model.get_embeddings(full, face, eye, has_face, has_eye)
            # Store as float32 on CPU so numpy / t-SNE do not work on half precision.
            all_embeddings.append(embeddings.float().cpu())
            all_labels.extend(batch['label'].tolist())

    n_samples = sum(e.size(0) for e in all_embeddings)
    if n_samples < 2:
        print("Not enough samples to run t-SNE")
        return

    embeddings = torch.cat(all_embeddings, dim=0).numpy()

    print("Running t-SNE...")
    # t-SNE requires perplexity < n_samples; clamp for small test sets.
    perplexity = min(30, n_samples - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    embeddings_2d = tsne.fit_transform(embeddings)

    plt.figure(figsize=(14, 10))
    # One colour per possible label; never fewer entries than selected artists.
    n_colors = max(max_artists, len(selected))
    colors = plt.cm.tab20(np.linspace(0, 1, n_colors))

    labels_arr = np.array(all_labels)
    for label in sorted(set(all_labels)):
        mask = labels_arr == label
        plt.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
                    c=[colors[label]], alpha=0.7, s=50)

    plt.title('Artist Style Embeddings (t-SNE)')
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close()
    print(f"Saved to {output_path}")


def main():
    """Parse command-line arguments and dispatch to the selected mode."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--dataset_root', type=str, default='./dataset')
    parser.add_argument('--dataset_face_root', type=str, default='./dataset_face')
    parser.add_argument('--dataset_eyes_root', type=str, default='./dataset_eyes')
    parser.add_argument('--mode', type=str, default='evaluate',
                        choices=['evaluate', 'visualize', 'predict'])
    parser.add_argument('--image', type=str, default=None)
    parser.add_argument('--output', type=str, default='tsne.png')
    parser.add_argument('--device', type=str, default='cuda')

    args = parser.parse_args()

    if args.mode == 'evaluate':
        evaluate_model(args.checkpoint, args.dataset_root, args.dataset_face_root,
                       args.dataset_eyes_root, args.device)
    elif args.mode == 'visualize':
        visualize_embeddings(args.checkpoint, args.dataset_root, args.dataset_face_root,
                             args.dataset_eyes_root, args.output, device=args.device)
    elif args.mode == 'predict':
        if not args.image:
            # Report as a usage error (non-zero exit) instead of exiting successfully.
            parser.error("--image required for predict mode")

        inference = ArtistEmbeddingInference(args.checkpoint, args.device)
        # convert() returns a new, fully loaded image, so the file can be closed.
        with Image.open(args.image) as raw_image:
            image = raw_image.convert('RGB')
        predictions = inference.predict_artist(image, top_k=10)

        print("\nTop 10 Predictions:")
        for i, (artist, prob) in enumerate(predictions, 1):
            print(f"{i}. {artist}: {prob:.4f}")


if __name__ == '__main__':
    main()