# paths
from pathlib import Path
# torch
import torch
import torchvision.transforms.functional as F
from einops import repeat

# Text2Punks and Tokenizer
from text2punks.text2punk import Text2Punks, CLIP
from text2punks.tokenizer import txt_tokenizer

# select device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# load decoder codebook
codebook = torch.load('./text2punks/data/codebook.pt')
# helper fns

def exists(val):
    return val is not None


def resize(image_tensor, size):
    return F.resize(image_tensor, (size, size), F.InterpolationMode.NEAREST)


def to_pil_image(image_tensor):
    return F.to_pil_image(image_tensor.type(torch.uint8))
def model_loader(text2punk_path, clip_path):
    # load pre-trained Text2Punks model
    text2punk_path = Path(text2punk_path)
    assert text2punk_path.exists(), 'trained Text2Punks checkpoint must exist'

    load_obj = torch.load(str(text2punk_path), map_location=device)
    text2punks_params, weights = load_obj.pop('hparams'), load_obj.pop('weights')

    text2punk = Text2Punks(**text2punks_params).to(device)
    text2punk.load_state_dict(weights)

    # load pre-trained CLIP model
    clip_path = Path(clip_path)
    assert clip_path.exists(), 'trained CLIP checkpoint must exist'

    load_obj = torch.load(str(clip_path), map_location=device)
    clip_params, weights = load_obj.pop('hparams'), load_obj.pop('weights')

    clip = CLIP(**clip_params).to(device)
    clip.load_state_dict(weights)

    return text2punk, clip
def generate_image(prompt_text, top_k, temperature, num_images, batch_size, top_prediction, text2punk_model, clip_model, codebook=codebook):
    # tokenize the prompt and repeat it once per requested image
    text = txt_tokenizer.tokenize(prompt_text, text2punk_model.text_seq_len, truncate_text=True).to(device)
    text = repeat(text, '() n -> b n', b=num_images)

    img_outputs = []
    score_outputs = []

    # generate in chunks of `batch_size` to bound memory usage
    for text_chunk in text.split(batch_size):
        images, scores = text2punk_model.generate_images(text_chunk, codebook.to(device), clip=clip_model, filter_thres=top_k, temperature=temperature)
        img_outputs.append(images)
        score_outputs.append(scores)

    img_outputs = torch.cat(img_outputs)
    score_outputs = torch.cat(score_outputs)

    # rank all candidates by CLIP score and keep only the top predictions
    similarity = score_outputs.softmax(dim=-1)
    values, indices = similarity.topk(top_prediction)

    img_outputs = img_outputs[indices]
    score_outputs = score_outputs[indices]

    return img_outputs, score_outputs
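

# Minimal usage sketch. Assumptions: the checkpoint paths, prompt, and
# sampling values below are hypothetical placeholders, not part of the
# original app -- point the paths at your own trained Text2Punks and CLIP
# .pt checkpoints.
if __name__ == '__main__':
    text2punk, clip = model_loader('./trained_models/text2punk.pt', './trained_models/clip.pt')

    images, scores = generate_image(
        prompt_text='An Alien CryptoPunk that has a cap and a pipe',
        top_k=0.9,          # passed through as filter_thres to generate_images
        temperature=1.0,
        num_images=32,
        batch_size=8,
        top_prediction=4,
        text2punk_model=text2punk,
        clip_model=clip,
    )

    # upscale each punk with nearest-neighbour so the pixels stay sharp,
    # then save the top-ranked results as PNGs
    for i, image in enumerate(images):
        to_pil_image(resize(image, 256)).save(f'punk_{i}.png')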