| """ | |
| Artist Style Embedding - Evaluation and Inference | |
| """ | |
import argparse
from typing import List, Tuple

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm
from sklearn.manifold import TSNE

# matplotlib is only needed for --mode visualize; degrade gracefully without it.
try:
    import matplotlib.pyplot as plt
    PLOT_AVAILABLE = True
except ImportError:
    PLOT_AVAILABLE = False

# Local project modules, expected alongside this script.
from config import get_config
from model import ArtistStyleModel
from dataset import ArtistDataset, build_dataset_splits, collate_fn


class ArtistEmbeddingInference:
    """Inference class for artist style embeddings."""

    def __init__(self, checkpoint_path: str, device: str = 'cuda'):
        requested_device = device
        if requested_device.startswith('cuda') and not torch.cuda.is_available():
            print(
                "[WARN] --device=cuda requested but torch.cuda.is_available() is False. "
                "Falling back to CPU. (Install a CUDA-enabled PyTorch build to use GPU.)"
            )
            requested_device = 'cpu'
        self.device = torch.device(requested_device)

        # Always load the checkpoint on CPU to avoid duplicating large tensors on GPU.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        self.artist_to_idx = checkpoint['artist_to_idx']
        self.idx_to_artist = {v: k for k, v in self.artist_to_idx.items()}

        config = get_config()
        self.model = ArtistStyleModel(
            num_classes=len(self.artist_to_idx),
            embedding_dim=config.model.embedding_dim,
            hidden_dim=config.model.hidden_dim,
        )
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # Reduce VRAM: keep weights in FP16 on CUDA. Inputs must be cast to the
        # same dtype (self.dtype below), or the first layer raises a dtype mismatch.
        self.dtype = torch.float16 if self.device.type == 'cuda' else torch.float32
        if self.device.type == 'cuda':
            self.model = self.model.to(dtype=torch.float16)
        self.model = self.model.to(self.device)
        self.model.eval()

        # Standard ImageNet preprocessing.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def get_embedding(self, image: Image.Image) -> torch.Tensor:
        """Return the style embedding for a single PIL image."""
        tensor = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
        # The model also takes face/eye crops; pass zero tensors with the
        # has_face/has_eye flags set to False so those branches are ignored.
        placeholder = torch.zeros(1, 3, 224, 224, dtype=self.dtype, device=self.device)
        has_false = torch.tensor([False], device=self.device)
        with torch.no_grad():
            embedding = self.model.get_embeddings(tensor, placeholder, placeholder, has_false, has_false)
        return embedding.squeeze(0)

    def predict_artist(self, image: Image.Image, top_k: int = 5) -> List[Tuple[str, float]]:
        """Return the top_k most likely (artist, probability) pairs for an image."""
        tensor = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
        placeholder = torch.zeros(1, 3, 224, 224, dtype=self.dtype, device=self.device)
        has_false = torch.tensor([False], device=self.device)
        with torch.no_grad():
            output = self.model(tensor, placeholder, placeholder, has_false, has_false)
        # Softmax in FP32 (logits may be FP16 on CUDA); clamp top_k to the class count.
        probs = F.softmax(output['cosine'].squeeze(0).float(), dim=0)
        top_probs, top_indices = probs.topk(min(top_k, probs.numel()))
        return [(self.idx_to_artist[idx.item()], prob.item()) for prob, idx in zip(top_probs, top_indices)]
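

# Illustrative sketch (not part of the original API): comparing the style of two
# images via cosine similarity of their embeddings. The function name and its
# placement here are assumptions; it simply reuses get_embedding() above.
def style_similarity(inference: ArtistEmbeddingInference, image_a: Image.Image, image_b: Image.Image) -> float:
    """Cosine similarity in [-1, 1] between the style embeddings of two images."""
    emb_a = inference.get_embedding(image_a).float()
    emb_b = inference.get_embedding(image_b).float()
    return F.cosine_similarity(emb_a, emb_b, dim=0).item()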


def evaluate_model(checkpoint_path: str, dataset_root: str, dataset_face_root: str,
                   dataset_eyes_root: str, device: str = 'cuda'):
    """Report top-1 and top-5 accuracy on the test split."""
    inference = ArtistEmbeddingInference(checkpoint_path, device)
    config = get_config()
    artist_to_idx, full_splits, face_splits, eye_splits = build_dataset_splits(
        dataset_root, dataset_face_root, dataset_eyes_root,
        min_images=config.data.min_images_per_artist
    )
    test_dataset = ArtistDataset(
        dataset_root, dataset_face_root, dataset_eyes_root,
        artist_to_idx, full_splits['test'], face_splits['test'], eye_splits['test'],
        is_training=False
    )
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4, collate_fn=collate_fn)

    total_correct = 0
    total_correct_top5 = 0
    total_samples = 0
    for batch in tqdm(test_loader, desc="Evaluating"):
        # Cast image tensors to the model dtype (FP16 on CUDA); flags/labels stay as-is.
        full = batch['full'].to(inference.device, dtype=inference.dtype)
        face = batch['face'].to(inference.device, dtype=inference.dtype)
        eye = batch['eye'].to(inference.device, dtype=inference.dtype)
        has_face = batch['has_face'].to(inference.device)
        has_eye = batch['has_eye'].to(inference.device)
        labels = batch['label'].to(inference.device)
        with torch.no_grad():
            output = inference.model(full, face, eye, has_face, has_eye)
        # Top-1
        preds = output['cosine'].argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        # Top-5
        _, top5_preds = output['cosine'].topk(5, dim=1)
        top5_correct = top5_preds.eq(labels.view(-1, 1).expand_as(top5_preds))
        total_correct_top5 += top5_correct.any(dim=1).sum().item()
        total_samples += labels.size(0)

    accuracy = total_correct / total_samples if total_samples > 0 else 0
    accuracy_top5 = total_correct_top5 / total_samples if total_samples > 0 else 0
    print("\nEvaluation Results:")
    print("-" * 40)
    print(f"Top-1 Accuracy: {accuracy:.4f} ({total_correct}/{total_samples})")
    print(f"Top-5 Accuracy: {accuracy_top5:.4f} ({total_correct_top5}/{total_samples})")


def visualize_embeddings(checkpoint_path: str, dataset_root: str, dataset_face_root: str,
                         dataset_eyes_root: str, output_path: str = 'tsne.png',
                         max_artists: int = 50, device: str = 'cuda'):
    """Project test-set embeddings to 2D with t-SNE and save a scatter plot."""
    if not PLOT_AVAILABLE:
        print("matplotlib not available; install it to use --mode visualize")
        return
    inference = ArtistEmbeddingInference(checkpoint_path, device)
    config = get_config()
    artist_to_idx, full_splits, face_splits, eye_splits = build_dataset_splits(
        dataset_root, dataset_face_root, dataset_eyes_root,
        min_images=config.data.min_images_per_artist
    )
    # Cap the plot at max_artists artists and 10 test images per artist.
    selected = list(artist_to_idx.keys())[:max_artists]
    filtered_full = {a: p[:10] for a, p in full_splits['test'].items() if a in selected}
    filtered_face = {a: face_splits['test'].get(a, []) for a in selected}
    filtered_eye = {a: eye_splits['test'].get(a, []) for a in selected}
    filtered_idx = {a: i for i, a in enumerate(selected)}
    dataset = ArtistDataset(
        dataset_root, dataset_face_root, dataset_eyes_root,
        filtered_idx, filtered_full, filtered_face, filtered_eye,
        is_training=False
    )
    loader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

    all_embeddings, all_labels = [], []
    for batch in tqdm(loader, desc="Extracting"):
        full = batch['full'].to(inference.device, dtype=inference.dtype)
        face = batch['face'].to(inference.device, dtype=inference.dtype)
        eye = batch['eye'].to(inference.device, dtype=inference.dtype)
        has_face = batch['has_face'].to(inference.device)
        has_eye = batch['has_eye'].to(inference.device)
        with torch.no_grad():
            embeddings = inference.model.get_embeddings(full, face, eye, has_face, has_eye)
        # Back to FP32 on CPU so t-SNE sees a plain float array.
        all_embeddings.append(embeddings.float().cpu())
        all_labels.extend(batch['label'].tolist())
    embeddings = torch.cat(all_embeddings, dim=0).numpy()

    print("Running t-SNE...")
    # Perplexity must be strictly smaller than the number of samples.
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings) - 1))
    embeddings_2d = tsne.fit_transform(embeddings)

    plt.figure(figsize=(14, 10))
    colors = plt.cm.tab20(np.linspace(0, 1, max_artists))
    labels_arr = np.array(all_labels)
    for label in set(all_labels):
        mask = labels_arr == label
        plt.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1], c=[colors[label]], alpha=0.7, s=50)
    plt.title('Artist Style Embeddings (t-SNE)')
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close()
    print(f"Saved to {output_path}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--dataset_root', type=str, default='./dataset')
    parser.add_argument('--dataset_face_root', type=str, default='./dataset_face')
    parser.add_argument('--dataset_eyes_root', type=str, default='./dataset_eyes')
    parser.add_argument('--mode', type=str, default='evaluate', choices=['evaluate', 'visualize', 'predict'])
    parser.add_argument('--image', type=str, default=None)
    parser.add_argument('--output', type=str, default='tsne.png')
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()

    if args.mode == 'evaluate':
        evaluate_model(args.checkpoint, args.dataset_root, args.dataset_face_root, args.dataset_eyes_root, args.device)
    elif args.mode == 'visualize':
        visualize_embeddings(args.checkpoint, args.dataset_root, args.dataset_face_root, args.dataset_eyes_root, args.output, device=args.device)
    elif args.mode == 'predict':
        if not args.image:
            print("--image is required for predict mode")
            return
        inference = ArtistEmbeddingInference(args.checkpoint, args.device)
        image = Image.open(args.image).convert('RGB')
        # Clamp to the class count so topk never exceeds the number of artists.
        top_k = min(10, len(inference.artist_to_idx))
        predictions = inference.predict_artist(image, top_k=top_k)
        print(f"\nTop {top_k} Predictions:")
        for i, (artist, prob) in enumerate(predictions, 1):
            print(f"{i}. {artist}: {prob:.4f}")


if __name__ == '__main__':
    main()