File size: 1,786 Bytes
1835c38
 
 
 
77afe44
1835c38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77afe44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import math
from pathlib import Path
from collections import defaultdict
import statistics
from typing import Literal

from pydantic import BaseModel
import pandas as pd
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
import lpips
import torch
import torchvision.transforms.v2 as T
import torchvision.transforms.v2.functional as TF


# Preprocessing applied to each image before it is fed to the LPIPS model:
# convert to an image tensor, force three channels, scale to float32 in
# [0, 1], then shift/scale each channel to [-1, 1] via (x - 0.5) / 0.5.
_transforms = T.Compose(
    [
        T.ToImage(),                           # image -> tensor image
        T.RGB(),                               # ensure 3 (RGB) channels
        T.ToDtype(torch.float32, scale=True),  # float32 in [0, 1]
        T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),  # [0, 1] -> [-1, 1]
    ]
)

def compare_lpips(loss_fn, image1, image2, resize=False, device="cuda", to_item=True):
    """Return the LPIPS distance between two images.

    Each image is passed through the module-level ``_transforms`` pipeline,
    given a leading batch dimension of size 1, moved to ``device`` and scored
    by ``loss_fn`` with autograd disabled.

    Args:
        loss_fn: Callable taking the two batched tensors (e.g. an
            ``lpips.LPIPS`` model already placed on ``device``).
        image1: Reference image; its ``.size`` is the target size. Assumed
            to be a PIL image (``.size`` / ``.resize`` are used) — confirm
            for other input types.
        image2: Image compared against ``image1``.
        resize: When true and the sizes differ, ``image2`` is resized to
            ``image1.size`` with Lanczos resampling instead of raising.
        device: Device the input tensors are moved to.
        to_item: When true, return ``score.float().item()`` (a Python float);
            otherwise return the raw output of ``loss_fn``.

    Raises:
        ValueError: If the image sizes differ and ``resize`` is false.
    """
    mismatched = image1.size != image2.size
    if mismatched and not resize:
        raise ValueError(f"Got mismatch {image1.size=} {image2.size=}")
    if mismatched:
        # Only the second image is resampled; image1 defines the target size.
        image2 = image2.resize(image1.size, Image.LANCZOS)

    batch1, batch2 = (
        _transforms(img).unsqueeze(0).to(device=device) for img in (image1, image2)
    )

    with torch.no_grad():
        distance = loss_fn(batch1, batch2)

    return distance.float().item() if to_item else distance

class LpipsCalculator:
    """Reusable LPIPS scorer that owns its model and default options.

    Builds an ``lpips.LPIPS`` network once, places it on a resolved device,
    and delegates each comparison to ``compare_lpips``.
    """

    def __init__(self, resize=False, device="cuda", to_item=True, net="alex"):
        """Build the LPIPS model and store per-instance defaults.

        Args:
            resize: Default for the ``resize`` argument of ``__call__``.
            device: Requested device (string or ``torch.device``). A CUDA
                device falls back to ``"cpu"`` when CUDA is unavailable;
                ``None`` behaves like the default ``"cuda"``. Any other
                device (e.g. ``"cpu"``) is used as given.
            to_item: Default for the ``to_item`` argument of ``__call__``.
            net: Backbone name forwarded to ``lpips.LPIPS``.
        """
        self.resize = resize
        self.to_item = to_item
        # Honor the caller's device; the old code ignored it and always
        # auto-detected, so e.g. device="cpu" still ran on CUDA.
        self.device = self._resolve_device(device)
        self.loss_fn = lpips.LPIPS(net=net).to(device=self.device)

    @staticmethod
    def _resolve_device(device):
        """Return the device string to use, falling back to CPU without CUDA."""
        resolved = torch.device("cuda" if device is None else device)
        if resolved.type == "cuda" and not torch.cuda.is_available():
            return "cpu"
        # Keep a plain string so ``self.device`` matches the old "cuda"/"cpu" values.
        return str(resolved)

    def __call__(self, image1, image2, resize=None, to_item=None):
        """Score ``image1`` against ``image2`` with this instance's model.

        ``resize`` and ``to_item`` override the instance defaults when they
        are not ``None``. See ``compare_lpips`` for the return value and the
        ``ValueError`` raised on a size mismatch without resizing.
        """
        return compare_lpips(
            self.loss_fn,
            image1,
            image2,
            resize=self.resize if resize is None else resize,
            device=self.device,
            to_item=self.to_item if to_item is None else to_item,
        )