|
|
import gradio as gr |
|
|
import torch, os |
|
|
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline |
|
|
import gradio as gr |
|
|
from io import BytesIO |
|
|
import base64 |
|
|
import re |
|
|
|
|
|
# Shared secret that every API caller must pass as the first argument of the
# "run" endpoint. It is read from the SECRET_TOKEN environment variable.
# NOTE(review): when the variable is unset this silently falls back to the
# literal below, i.e. the endpoint fails open with a publicly visible token.
SECRET_TOKEN = os.environ.get('SECRET_TOKEN', 'default_secret')

# Recognizes a data-URI header such as "data:image/png;base64," so that it can
# be stripped off before the remaining payload is base64-decoded.
data_uri_pattern = re.compile(
    r'data:image/(png|jpeg|jpg|webp);base64,'
)
|
|
|
|
|
def readb64(b64):
    """Decode a base64-encoded image string into an image object.

    Parameters:
    - b64 (str): Base64 payload. It may carry a data-URI header such as
      ``data:image/png;base64,`` (png, jpeg, jpg or webp). Every match of
      ``data_uri_pattern`` is removed, not only a leading one.

    Returns:
    - The object returned by ``Image.open`` for the decoded bytes.
    """
    # NOTE(review): `Image` is not imported by this file's visible import block
    # (presumably `from PIL import Image` is missing), so calling this function
    # raises NameError as written. Neither generate_images nor the Gradio UI
    # below calls it.
    # Drop the data-URI header so only the raw base64 payload is decoded.
    b64 = data_uri_pattern.sub("", b64)
    # Wrap the decoded bytes in an in-memory file for Image.open.
    img = Image.open(BytesIO(base64.b64decode(b64)))
    return img
|
|
|
|
|
|
|
|
def writeb64(image):
    """Serialize ``image`` as PNG and return the bytes as a base64 text string.

    Parameters:
    - image: Any object exposing ``save(fp, format=...)``; in this file it is
      the image produced by the decoder pipeline.

    Returns:
    - str: Base64 of the PNG bytes, without any ``data:`` URI header.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
|
|
|
|
|
|
|
|
# Both Stable Cascade stages are loaded once, at import time, and moved to the
# GPU. The prior is held in bfloat16 and the decoder in float16, which is why
# generate_images casts the prior's embeddings with .half() before decoding.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
).to("cuda")
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16,
).to("cuda")
|
|
|
|
|
def _secret_matches(candidate):
    """Return True if ``candidate`` equals SECRET_TOKEN, compared in constant time.

    Non-string input (e.g. ``None`` sent by a raw REST client) never matches.
    """
    import secrets

    if not isinstance(candidate, str):
        return False
    # compare_digest avoids leaking, through response timing, how many leading
    # characters were correct. Comparing UTF-8 bytes (rather than str) also
    # keeps non-ASCII input from raising TypeError inside compare_digest.
    return secrets.compare_digest(
        candidate.encode("utf-8"), SECRET_TOKEN.encode("utf-8")
    )


def generate_images(
    secret_token="",
    prompt="",
    negative_prompt="bad,ugly,deformed",
    height=1024,
    width=1024,
    guidance_scale=4.0,
    seed=42,
    prior_inference_steps=20,
    decoder_inference_steps=10
):
    """Generate one image with the two-stage Stable Cascade pipeline on CUDA.

    Parameters:
    - secret_token (str): Must equal SECRET_TOKEN, otherwise gr.Error is raised.
    - prompt (str): The prompt to generate an image for.
    - negative_prompt (str): Text to steer the generation away from.
    - height (int | float): Height of the generated image; converted with int().
    - width (int | float): Width of the generated image; converted with int().
    - guidance_scale (int | float): Guidance scale for the prior stage;
      converted with float(). The decoder stage always uses 0.0.
    - seed (int | float): Seed for the CUDA torch.Generator shared by both
      stages; converted with int().
    - prior_inference_steps (int | float): Number of prior denoising steps;
      converted with int().
    - decoder_inference_steps (int | float): Number of decoder denoising steps;
      converted with int().

    Returns:
    - str: The first generated image, PNG-encoded and then base64-encoded
      (no data-URI header).

    Raises:
    - gr.Error: If ``secret_token`` does not match SECRET_TOKEN.
    """
    if not _secret_matches(secret_token):
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    # Gradio number/slider components and raw REST calls may deliver floats
    # (e.g. 20.0). Step counts and image sizes must be integers for the
    # pipelines' schedulers and latent shapes, so normalize them once here.
    height = int(height)
    width = int(width)
    prior_inference_steps = int(prior_inference_steps)
    decoder_inference_steps = int(decoder_inference_steps)
    guidance_scale = float(guidance_scale)

    # One generator is shared by both stages, so a given seed reproduces the
    # whole two-stage run.
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    # Stage 1: the prior turns the prompt into image embeddings.
    prior_output = prior(
        prompt=prompt,
        generator=generator,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        num_inference_steps=prior_inference_steps,
    )

    # Stage 2: the decoder renders the image from those embeddings. The decoder
    # was loaded in float16, so the embeddings are cast with .half() first.
    # guidance_scale=0.0 presumably disables classifier-free guidance for this
    # stage (diffusers only applies it when the scale is > 1).
    images = decoder(
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=prompt,
        generator=generator,
        negative_prompt=negative_prompt,
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=decoder_inference_steps,
    ).images

    # Only one image is requested (num_images_per_prompt=1), so return the
    # first one as base64 text.
    return writeb64(images[0])
|
|
|
|
|
|
|
|
# Gradio UI that doubles as the REST API definition. Components render in the
# order they are created, and the `inputs` list passed to .click() below must
# stay in the same order as generate_images' positional parameters.
with gr.Blocks() as gradio_app:
    # Fixed, full-viewport white overlay (z-index 100) that appears intended to
    # cover the controls for browser visitors and point them to the original
    # space; the "run" API endpoint defined below is unaffected by it.
    gr.HTML("""
<div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
<div style="text-align: center; color: black;">
<p style="color: black;">This space is a REST API to programmatically generate an image.</p>
<p style="color: black;">Interested in using it? Please use the <a href="https://huggingface.co/spaces/ArtGAN/Diffusion-API" target="_blank">original space</a>, thank you!</p>
</div>
</div>""")
    # First API argument: compared against SECRET_TOKEN by generate_images.
    secret_token = gr.Textbox(
        placeholder="Secret token",
        show_label=False,
    )
    text2image_prompt = gr.Textbox(
        lines=1,
        placeholder="Prompt",
        show_label=False,
    )
    text2image_negative_prompt = gr.Textbox(
        lines=1,
        placeholder="Negative Prompt",
        show_label=False,
    )
    # May arrive as a float; generate_images converts the seed with int().
    text2image_seed = gr.Number(
        value=42,
        label="Seed",
    )
    # Sizes are limited to 128..1024 in multiples of 32 by these sliders.
    text2image_height = gr.Slider(
        minimum=128,
        maximum=1024,
        step=32,
        value=1024,
        label="Image Height",
    )
    text2image_width = gr.Slider(
        minimum=128,
        maximum=1024,
        step=32,
        value=1024,
        label="Image Width",
    )
    # Applied to the prior stage only; the decoder call uses 0.0.
    text2image_guidance_scale = gr.Slider(
        minimum=0.1,
        maximum=15,
        step=0.1,
        value=4.0,
        label="Guidance Scale",
    )
    text2image_prior_inference_step = gr.Slider(
        minimum=1,
        maximum=50,
        step=1,
        value=20,
        label="Prior Inference Step",
    )
    text2image_decoder_inference_step = gr.Slider(
        minimum=1,
        maximum=50,
        step=1,
        value=10,
        label="Decoder Inference Step",
    )
    text2image_predict = gr.Button(value="Generate Image")
    # Receives the base64 PNG string returned by generate_images.
    output_image_base64 = gr.Text()
    # api_name='run' names the endpoint that programmatic clients call.
    text2image_predict.click(
        fn=generate_images,
        inputs=[
            secret_token,
            text2image_prompt,
            text2image_negative_prompt,
            text2image_height,
            text2image_width,
            text2image_guidance_scale,
            text2image_seed,
            text2image_prior_inference_step,
            text2image_decoder_inference_step
        ],
        outputs=output_image_base64,
        api_name='run',
    )

# Enable request queuing (at most 20 pending events) and start the server;
# this runs unconditionally at import time.
gradio_app.queue(max_size=20).launch()