# Improved aesthetic predictor: CLIP ViT-L/14 image encoder + linear MSE score head.
import os
import torch
import torch.nn as nn
from transformers import CLIPModel
from torchvision.datasets.utils import download_url
# Pretrained aesthetic-score head (sac+logos+ava1, linear-MSE variant), mirrored on Aliyun OSS.
URL = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/sac%2Blogos%2Bava1-l14-linearMSE.pth"
# Local filename used when caching the checkpoint in the torch hub directory.
FILENAME = "sac+logos+ava1-l14-linearMSE.pth"
# Expected MD5 of the checkpoint; download_url verifies it after fetching.
MD5 = "b1047fd767a00134b8fd6529bf19521a"
class MLP(nn.Module):
    """Linear score head mapping a 768-dim CLIP embedding to a single scalar.

    Five Linear layers with Dropout after the first three; there are no
    activation functions between layers, matching the architecture of the
    pretrained sac+logos+ava1 linear-MSE checkpoint. The Sequential module
    indices (layers.0/2/4/6/7 for the Linears) are kept identical to the
    checkpoint's state-dict keys so the weights load without remapping.
    """

    def __init__(self):
        super().__init__()
        # (fan_in, fan_out, dropout-after) for each Linear; None means no Dropout.
        spec = [
            (768, 1024, 0.2),
            (1024, 128, 0.2),
            (128, 64, 0.1),
            (64, 16, None),
            (16, 1, None),
        ]
        modules = []
        for fan_in, fan_out, p in spec:
            modules.append(nn.Linear(fan_in, fan_out))
            if p is not None:
                modules.append(nn.Dropout(p))
        self.layers = nn.Sequential(*modules)

    def forward(self, embed):
        """Return the raw aesthetic score for embedding batch *embed*."""
        return self.layers(embed)
class ImprovedAestheticPredictor(nn.Module):
    """Aesthetic-score model: CLIP image encoder + pretrained linear head.

    Encodes images with a CLIP vision tower, L2-normalizes the embedding,
    and feeds it to the pretrained MLP head to produce one score per image.

    Args:
        encoder_path: HuggingFace model id or local path for the CLIP encoder.
        predictor_path: Path to the MLP head checkpoint. If None or missing,
            the published checkpoint is downloaded into the torch hub cache
            (MD5-verified by download_url).
    """

    def __init__(self, encoder_path="openai/clip-vit-large-patch14", predictor_path=None):
        super().__init__()
        self.encoder = CLIPModel.from_pretrained(encoder_path)
        self.predictor = MLP()
        if predictor_path is None or not os.path.exists(predictor_path):
            # Fetch the head weights into the hub cache; download_url checks the MD5.
            download_url(URL, torch.hub.get_dir(), FILENAME, md5=MD5)
            predictor_path = os.path.join(torch.hub.get_dir(), FILENAME)
        # weights_only=True: the checkpoint is fetched from an external URL, so
        # restrict torch.load to tensors/containers instead of arbitrary pickle
        # objects (the sac+logos+ava1 file is a plain state dict, so this is safe).
        state_dict = torch.load(predictor_path, map_location="cpu", weights_only=True)
        self.predictor.load_state_dict(state_dict)
        # Inference-only module: disable Dropout in the head by default.
        self.eval()

    def forward(self, pixel_values):
        """Return a 1-D tensor of aesthetic scores, one per input image.

        Args:
            pixel_values: Preprocessed image batch accepted by
                CLIPModel.get_image_features (CLIP's pixel_values format).
        """
        embed = self.encoder.get_image_features(pixel_values=pixel_values)
        # L2-normalize so the head sees unit-norm embeddings, as it was trained on.
        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
        return self.predictor(embed).squeeze(1)