Spaces:
Paused
Paused
import os

import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoImageProcessor

from starvector.metrics.base_metric import BaseMetric
class DINOScoreCalculator(BaseMetric):
    """Image-similarity metric built on DINOv2 embeddings.

    An image is embedded by running it through a DINOv2 model and averaging
    ``last_hidden_state`` over dimension 1.  Two inputs are scored with the
    cosine similarity of their embeddings, rescaled from [-1, 1] to [0, 1].
    """

    def __init__(self, config=None, device='cuda', model_size="base"):
        """Load the DINOv2 model and processor and place the model on *device*.

        Args:
            config: Stored as ``self.config``; not otherwise read by this class.
            device: Device the model is moved to and on which features are
                produced.  Defaults to ``'cuda'``.
            model_size: One of ``"small"``, ``"base"`` or ``"large"``.
                Defaults to ``"base"``, the size previously hard-coded here.

        Raises:
            ValueError: If *model_size* is not a supported size.
        """
        super().__init__()
        self.class_name = self.__class__.__name__
        self.config = config
        self.model, self.processor = self.get_DINOv2_model(model_size)
        self.model = self.model.to(device)
        self.device = device
        self.metric = self.calculate_DINOv2_similarity_score

    def get_DINOv2_model(self, model_size):
        """Return ``(model, image_processor)`` for the requested DINOv2 size.

        Args:
            model_size: ``"small"``, ``"base"`` or ``"large"``.

        Raises:
            ValueError: If *model_size* is anything else.
        """
        checkpoints = {
            "small": "facebook/dinov2-small",
            "base": "facebook/dinov2-base",
            "large": "facebook/dinov2-large",
        }
        try:
            checkpoint = checkpoints[model_size]
        except (KeyError, TypeError):
            # TypeError covers unhashable values, which must still be reported
            # as an invalid size rather than leaking a lookup error.
            raise ValueError(
                f"model_size should be either 'small', 'base' or 'large', got {model_size}"
            ) from None
        return (
            AutoModel.from_pretrained(checkpoint),
            AutoImageProcessor.from_pretrained(checkpoint),
        )

    def process_input(self, image, processor):
        """Turn *image* into a feature tensor located on ``self.device``.

        Args:
            image: A file path (``str`` or ``os.PathLike``), a PIL image, or a
                tensor of precomputed features.  A 1-D tensor is given a
                leading dimension of size 1; other tensors are used as-is.
            processor: Image processor used to prepare PIL images for the
                model.

        Returns:
            The feature tensor for *image*, on ``self.device``.

        Raises:
            ValueError: If *image* is none of the supported types.
        """
        if isinstance(image, (str, os.PathLike)):
            # Read the pixel data eagerly so the file handle is released
            # here instead of whenever the image happens to be collected.
            with Image.open(image) as opened:
                opened.load()
            image = opened

        if isinstance(image, Image.Image):
            with torch.no_grad():
                inputs = processor(images=image, return_tensors="pt").to(self.device)
                outputs = self.model(**inputs)
                features = outputs.last_hidden_state.mean(dim=1)
        elif isinstance(image, torch.Tensor):
            features = image.unsqueeze(0) if image.dim() == 1 else image
            # Image-derived features live on self.device; precomputed ones
            # must join them there or the cosine similarity fails with a
            # device mismatch when the two kinds of input are mixed.
            features = features.to(self.device)
        else:
            raise ValueError("Input must be a file path, PIL Image, or tensor of features")
        return features

    def calculate_DINOv2_similarity_score(self, **kwargs):
        """Return the similarity of ``gt_im`` and ``gen_im`` as a float in [0, 1].

        Args:
            **kwargs: Must provide ``gt_im`` and ``gen_im``, each in any form
                accepted by ``process_input``.  Other keys are ignored.

        Returns:
            ``(cosine_similarity + 1) / 2`` as a Python float.

        Raises:
            ValueError: If either input is missing or of an unsupported type
                (a missing key is seen as ``None`` by ``process_input``).
        """
        gt_features = self.process_input(kwargs.get('gt_im'), self.processor)
        gen_features = self.process_input(kwargs.get('gen_im'), self.processor)

        cos = nn.CosineSimilarity(dim=1)
        # .item() requires the similarity to hold exactly one value, so
        # batched feature tensors with more than one row are not supported.
        sim = cos(gt_features, gen_features).item()
        # Map the cosine range [-1, 1] onto [0, 1].
        return (sim + 1) / 2