| """ | |
| Metrics for image quality evaluation. | |
| This module provides PSNR, SSIM, and LPIPS metrics for evaluating | |
| super-resolution model performance. | |
| """ | |
| import torch | |
| import lpips | |
def compute_psnr(img1, img2):
    """
    Compute PSNR between two images.

    Args:
        img1: First image tensor (B, C, H, W) in range [0, 1]
        img2: Second image tensor (B, C, H, W) in range [0, 1]

    Returns:
        psnr: PSNR value in dB
    """
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    psnr = 20 * torch.log10(1.0 / torch.sqrt(mse))
    return psnr.item()
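# Illustrative usage (not part of the original module; shapes and noise level
# are assumptions for demonstration):
#     sr = torch.rand(4, 3, 64, 64)
#     hr = torch.clamp(sr + 0.05 * torch.randn_like(sr), 0.0, 1.0)
#     compute_psnr(sr, hr)  # roughly 26 dB for Gaussian noise with std 0.05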
def compute_ssim(img1, img2):
    """
    Compute SSIM between two images.

    Note: this is a simplified global SSIM computed over the full image,
    not the windowed (e.g. 11x11 Gaussian) formulation of the original paper.

    Args:
        img1: First image tensor (B, C, H, W) in range [0, 1]
        img2: Second image tensor (B, C, H, W) in range [0, 1]

    Returns:
        ssim: SSIM value (0 to 1, higher is better)
    """
    # Stabilizing constants, with dynamic range L = 1 for [0, 1] inputs
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    # Per-image, per-channel means
    mu1 = torch.mean(img1, dim=[2, 3], keepdim=True)
    mu2 = torch.mean(img2, dim=[2, 3], keepdim=True)
    # Population variances (unbiased=False keeps them consistent with the
    # covariance estimate below, which divides by N)
    sigma1_sq = torch.var(img1, dim=[2, 3], keepdim=True, unbiased=False)
    sigma2_sq = torch.var(img2, dim=[2, 3], keepdim=True, unbiased=False)
    sigma12 = torch.mean((img1 - mu1) * (img2 - mu2), dim=[2, 3], keepdim=True)
    ssim_n = (2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)
    ssim_d = (mu1 ** 2 + mu2 ** 2 + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim = ssim_n / ssim_d
    return ssim.mean().item()
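# Illustrative usage (assumed tensors, not from the original module):
#     ones = torch.ones(1, 3, 32, 32)
#     compute_ssim(ones, ones)        # 1.0 for identical images
#     compute_ssim(ones, ones * 0.5)  # < 1.0 once the means differ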
def compute_lpips(img1, img2, lpips_model):
    """
    Compute LPIPS between two images.

    Args:
        img1: First image tensor (B, C, H, W) in range [0, 1]
        img2: Second image tensor (B, C, H, W) in range [0, 1]
        lpips_model: Initialized LPIPS model

    Returns:
        lpips: LPIPS value (lower is better, typically 0-1 range)
    """
    # LPIPS expects images in range [-1, 1]
    img1_lpips = img1 * 2.0 - 1.0  # [0, 1] -> [-1, 1]
    img2_lpips = img2 * 2.0 - 1.0  # [0, 1] -> [-1, 1]
    # Ensure tensors are on the correct device
    device = next(lpips_model.parameters()).device
    img1_lpips = img1_lpips.to(device)
    img2_lpips = img2_lpips.to(device)
    with torch.no_grad():
        lpips_value = lpips_model(img1_lpips, img2_lpips).mean().item()
    return lpips_value
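

# Minimal sketch of how the three metrics might be used together. The backbone
# choice (net='alex') and the tensor shapes below are assumptions for
# illustration; the lpips package (https://github.com/richzhang/PerceptualSimilarity)
# must be installed for this to run.
if __name__ == "__main__":
    sr = torch.rand(2, 3, 64, 64)                                 # stand-in model output
    hr = torch.clamp(sr + 0.05 * torch.randn_like(sr), 0.0, 1.0)  # stand-in ground truth
    lpips_model = lpips.LPIPS(net='alex')                         # downloads weights on first use
    print(f"PSNR:  {compute_psnr(sr, hr):.2f} dB")
    print(f"SSIM:  {compute_ssim(sr, hr):.4f}")
    print(f"LPIPS: {compute_lpips(sr, hr, lpips_model):.4f}")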