Spaces:
Runtime error
Runtime error
| import random | |
| import torch | |
| import numpy as np | |
| from tqdm import tqdm | |
| from functools import partialmethod | |
| import gradio as gr | |
| from gradio.mix import Series | |
| from transformers import pipeline, FSMTForConditionalGeneration, FSMTTokenizer | |
| from rudalle.pipelines import generate_images | |
| from rudalle import get_rudalle_model, get_tokenizer, get_vae | |
# Silence the progress bars emitted through tqdm by the ruDALL-E pipeline:
# every tqdm instance created from here on defaults to disable=True.
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Half precision is only requested when running on CUDA. fp16 inference is
# not generally supported on CPU, so fall back to float32 when no GPU is
# available instead of loading half-precision weights that cannot be run.
use_fp16 = device.type == "cuda"

# English -> Russian translation model and its tokenizer. `torch_dtype`
# already selects the weight precision, so no extra `.half()` is needed.
translation_model = FSMTForConditionalGeneration.from_pretrained(
    "facebook/wmt19-en-ru",
    torch_dtype=torch.float16 if use_fp16 else torch.float32,
).to(device)
translation_tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-en-ru")

# ruDALL-E "Malevich" checkpoint plus the tokenizer and VAE that are passed
# to generate_images().
dalle = get_rudalle_model("Malevich", pretrained=True, fp16=use_fp16, device=device)
tokenizer = get_tokenizer()
vae = get_vae().to(device)
def translation_wrapper(text: str) -> str:
    """Translate *text* with the module-level ``facebook/wmt19-en-ru`` model.

    Args:
        text: The prompt typed by the user.

    Returns:
        The decoded translation with special tokens stripped.
    """
    input_ids = translation_tokenizer.encode(text, return_tensors="pt")
    generated = translation_model.generate(input_ids.to(device))
    # Token ids are integers; hand them to the tokenizer unchanged instead of
    # casting them to float and relying on implicit coercion back to int.
    return translation_tokenizer.decode(generated[0], skip_special_tokens=True)
def dalle_wrapper(prompt: str):
    """Generate one image for *prompt* with ruDALL-E.

    A (top_k, top_p) sampling pair is picked at random from three presets on
    every call, so repeated submissions of the same prompt can give different
    images.

    Args:
        prompt: Text prompt forwarded to ``generate_images``.

    Returns:
        A ``(title, image)`` tuple: ``title`` is the HTML-escaped prompt
        wrapped in ``<b>`` tags, ``image`` is the first image returned by
        ``generate_images``.
    """
    # Local import: keeps this block self-contained without touching the
    # file-level import list.
    import html

    top_k, top_p = random.choice([
        (1024, 0.98),
        (512, 0.97),
        (384, 0.96),
    ])
    images, _ = generate_images(
        prompt,
        tokenizer,
        dalle,
        vae,
        top_k=top_k,
        images_num=1,
        top_p=top_p,
    )
    # The prompt originates from user input and is rendered as HTML, so
    # escape it to prevent markup injection into the page.
    title = f"<b>{html.escape(prompt)}</b>"
    return title, images[0]
# First stage: a labelled textbox whose content goes through
# translation_wrapper and comes back as plain text.
translator = gr.Interface(
    inputs=[gr.inputs.Textbox(label='What would you like to see?')],
    outputs="text",
    fn=translation_wrapper,
)

# Second stage: text in, (HTML title, image) out - matching the two values
# returned by dalle_wrapper. Both output components are left unlabelled.
outputs = [gr.outputs.HTML(label=""), gr.outputs.Image(label="")]
generator = gr.Interface(
    inputs="text",
    outputs=outputs,
    fn=dalle_wrapper,
)
# Description text for the demo: three sentences separated by single spaces.
description = " ".join([
    "ruDALL-E is a 1.3B params text-to-image model by SberAI (links at the bottom).",
    "This demo uses an English-Russian translation model to adapt the prompts.",
    "Try pressing [Submit] multiple times to generate new images!",
])

# Article HTML: a centred paragraph holding the project links, pipe-separated.
article = "<p style='text-align: center'>{}</p>".format(
    " | ".join([
        "<a href='https://github.com/sberbank-ai/ru-dalle'>GitHub</a>",
        "<a href='https://habr.com/ru/company/sberbank/blog/586926/'>Article (in Russian)</a>",
    ])
)

# Example prompts; each example is a one-element row (a single text input).
examples = [
    [sample]
    for sample in (
        "A still life of grapes and a bottle of wine",
        "Город в стиле киберпанк",
        "A colorful photo of a coral reef",
        "A white cat sitting in a cardboard box",
    )
]
# Presentation and runtime options for the combined demo, gathered in one
# place and forwarded unchanged to Series.
_series_options = {
    "title": 'Kinda-English ruDALL-E',
    "description": description,
    "article": article,
    "layout": 'horizontal',
    "theme": 'huggingface',
    "examples": examples,
    "allow_flagging": False,
    "live": False,
    "enable_queue": True,
}

# Compose translator -> generator into a single interface and start serving.
series = Series(translator, generator, **_series_options)
series.launch()