# after SmilingWolf/wd-tagger import gradio as gr import torch import torch.nn as nn import timm import timm.layers.ml_decoder from transformers import AutoModel, AutoTokenizer import torchvision from torchvision import transforms import PIL from PIL import Image import requests from io import BytesIO import json import pickle headers = { "User-Agent": "Gradio 0-shot classification demo", } TITLE = "Danbooru 0-shot classifiction demo" DESCRIPTION = """ Demo for 0-shot classification on Danbooru images. Davit-tiny backbone, ML-Decoder classification head, Alibaba-NLP/gte-large-en-v1.5 text embedding model. Training set includes IDs with <= 5,400,000 and last 3 digits in range [0, 899], inclusive. Get image by uploading or fetching by post ID. Get tag description by input box or fetching by tag name. """ def scrape_img(postID): postURL = "https://danbooru.donmai.us/posts/" + str(postID) + ".json" postData = json.loads(requests.get(postURL, headers=headers).content) imageURL = postData['file_url'] print("Getting image from " + imageURL) response = requests.get(imageURL, headers=headers) image = Image.open(BytesIO(response.content)) image.load() return image def scrape_wiki(tagName): wikiHistoryURL = f"https://danbooru.donmai.us/wiki_page_versions.json?search[title]={tagName}" wikiHistory = json.loads(requests.get(wikiHistoryURL, headers=headers).content) wikiBody = (": " + wikiHistory[0]['body'] if len(wikiHistory) > 0 else "") return tagName + wikiBody class Predictor: def __init__(self): self.img_size = (288, 288) self.cls_model = None self.tokenizer = None self.text_emb_model = None self.class_embed = None self.tag_names = None self.load_model() def load_model(self): with open('tags1588.pkl', 'rb') as f: classes = pickle.load(f) tagNames = classes[0].to_list() self.tag_names = tagNames pretrained_weights = torch.load('model.pth', map_location=torch.device('cpu')) self.class_embed = pretrained_weights['0.head.head.class_embed.weight'] cls_model = timm.create_model('davit_tiny', num_classes=len(classes)) cls_model = timm.layers.ml_decoder.add_ml_decoder_head( cls_model, num_groups=len(classes), class_embed=self.class_embed, class_embed_merge='', shared_fc=True) cls_model = nn.Sequential(cls_model) cls_model.load_state_dict(pretrained_weights, strict=True) cls_model = cls_model.eval() self.cls_model = cls_model model_path = 'Alibaba-NLP/gte-large-en-v1.5' self.tokenizer = AutoTokenizer.from_pretrained(model_path) self.text_emb_model = AutoModel.from_pretrained(model_path, trust_remote_code=True) self.text_emb_model = self.text_emb_model.eval() def embed_text(self, input_strings): with torch.no_grad(): # Tokenize the input texts embeddingList = [] for text in input_strings: batch_dict = self.tokenizer(text, padding=True, truncation=False, return_tensors='pt') outputs = self.text_emb_model(**batch_dict.to(self.text_emb_model.device)) embeddings = outputs.last_hidden_state[:, 0] embeddingList.append(embeddings.cpu()) embeddings = torch.cat(embeddingList) return embeddings def prepare_image(self, image): image.load() # check if file valid image = image.convert("RGBA") color = (255,255,255) background = Image.new('RGB', image.size, color) background.paste(image, mask=image.split()[3]) image = background image = transforms.Resize(self.img_size, interpolation = torchvision.transforms.InterpolationMode.BICUBIC)(image) image = transforms.ToTensor()(image) return image def predict( self, image, query, tag_names, ): image = self.prepare_image(image) image_features = self.cls_model[0].forward_features(image.unsqueeze(0)) outputs = self.cls_model[0].head(image_features, q = query).sigmoid().float() general_tag_list = list(zip(tag_names, outputs[0].tolist())) general_tag_list.sort(key=lambda y: y[1], reverse=True) general_tag_preds_dict = {} for tag, prob in general_tag_list[:50]: general_tag_preds_dict[tag] = prob return general_tag_preds_dict def predict_seen_tags( self, image, ): return self.predict(image, self.class_embed, self.tag_names) def predict_new_tag( self, image, description, ): return self.predict(image, self.embed_text([description]), ["embedding"])["embedding"] def main(): predictor = Predictor() with gr.Blocks(title=TITLE) as demo: with gr.Column(): gr.Markdown( value=f"