# Author: jbilcke-hf
# Commit: "Update app.py" (15b0643, verified)
import gradio as gr
import torch, os
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
import gradio as gr
from io import BytesIO
import base64
import re
# Shared secret that API callers must present; read from the environment,
# falling back to a placeholder value when the variable is not set.
SECRET_TOKEN = os.environ.get('SECRET_TOKEN', 'default_secret')

# Matches the "data:image/<type>;base64," prefix of a data URI for the
# png, jpeg, jpg and webp image types.
data_uri_pattern = re.compile(
    r'data:image/(png|jpeg|jpg|webp);base64,'
)
def readb64(b64):
    """
    Decode a base64-encoded image (optionally wrapped in a data URI) into a
    PIL image.

    Parameters:
    - b64 (str): Base64 image data, with or without a leading
      "data:image/<type>;base64," prefix.

    Returns:
    - The image object returned by PIL's Image.open for the decoded bytes.
    """
    # The file never imports Image at module level, so the original body
    # raised NameError on every call. Import it here to keep the fix local.
    from PIL import Image

    # Remove any data URI scheme prefix with regex
    payload = data_uri_pattern.sub("", b64)
    # Decode and open the image with PIL
    return Image.open(BytesIO(base64.b64decode(payload)))
def writeb64(image):
    """
    Serialize an image to PNG and return the bytes as base64 text.

    Parameters:
    - image: Object exposing save(file_obj, format=...), e.g. a PIL image.

    Returns:
    - str: The base64 encoding of the PNG data, decoded as UTF-8 text
      (no data URI prefix is added).
    """
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    return base64.b64encode(png_buffer.getvalue()).decode("utf-8")
# Load both Stable Cascade stages once at import time and move them to the GPU.
# The prior is loaded in bfloat16 and the decoder in float16.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
).to("cuda")

decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16,
).to("cuda")
def generate_images(
    secret_token="",
    prompt="",
    negative_prompt="bad,ugly,deformed",
    height=1024,
    width=1024,
    guidance_scale=4.0,
    seed=42,
    prior_inference_steps=20,
    decoder_inference_steps=10
):
    """
    Generate one image from a prompt with the Stable Cascade prior and decoder
    pipelines on the CUDA device, and return it as base64-encoded PNG text.

    Parameters:
    - secret_token (str): Must equal SECRET_TOKEN, otherwise the call is rejected.
    - prompt (str): The prompt to generate the image for.
    - negative_prompt (str): The negative prompt to guide image generation away from.
    - height (int): The height of the generated image.
    - width (int): The width of the generated image.
    - guidance_scale (float): The guidance scale passed to the prior model.
    - seed (int): Seed for the CUDA random generator shared by both stages.
    - prior_inference_steps (int): The number of inference steps for the prior model.
    - decoder_inference_steps (int): The number of inference steps for the decoder model.

    Returns:
    - str: The first generated image, PNG-encoded and base64-encoded by writeb64.

    Raises:
    - gr.Error: If secret_token does not match SECRET_TOKEN.
    """
    # Function-scope import keeps this block self-contained (stdlib only).
    import hmac

    # Constant-time comparison so response timing does not leak how much of the
    # token matched. Comparing bytes avoids compare_digest's TypeError on
    # non-ASCII str input; non-string tokens are rejected outright.
    token_ok = isinstance(secret_token, str) and hmac.compare_digest(
        secret_token.encode("utf-8"), SECRET_TOKEN.encode("utf-8")
    )
    if not token_ok:
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    # Numeric inputs may arrive as floats from the UI/API (the seed was already
    # coerced for that reason); sizes and step counts are coerced the same way.
    height = int(height)
    width = int(width)
    prior_inference_steps = int(prior_inference_steps)
    decoder_inference_steps = int(decoder_inference_steps)

    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    # Generate image embeddings using the prior model
    prior_output = prior(
        prompt=prompt,
        generator=generator,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        num_inference_steps=prior_inference_steps
    )

    # Generate images using the decoder model and the embeddings from the prior model.
    # The embeddings are converted to half precision to match the float16 decoder.
    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=prompt,
        generator=generator,
        negative_prompt=negative_prompt,
        guidance_scale=0.0,  # Guidance scale typically set to 0 for decoder as guidance is applied in the prior
        output_type="pil",
        num_inference_steps=decoder_inference_steps
    ).images

    # Only one image is requested per prompt; return it as base64 text.
    image = decoder_output[0]
    image_base64 = writeb64(image)
    return image_base64
# Build the Gradio app. The page is covered by a fixed, full-viewport notice;
# the components below exist to define the inputs of the 'run' API endpoint.
with gr.Blocks() as gradio_app:
    gr.HTML("""
<div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
<div style="text-align: center; color: black;">
<p style="color: black;">This space is a REST API to programmatically generate an image.</p>
<p style="color: black;">Interested in using it? Please use the <a href="https://huggingface.co/spaces/ArtGAN/Diffusion-API" target="_blank">original space</a>, thank you!</p>
</div>
</div>""")

    # Text inputs
    secret_token = gr.Textbox(placeholder="Secret token", show_label=False)
    text2image_prompt = gr.Textbox(lines=1, placeholder="Prompt", show_label=False)
    text2image_negative_prompt = gr.Textbox(lines=1, placeholder="Negative Prompt", show_label=False)

    # Numeric inputs
    text2image_seed = gr.Number(value=42, label="Seed")
    text2image_height = gr.Slider(minimum=128, maximum=1024, step=32, value=1024, label="Image Height")
    text2image_width = gr.Slider(minimum=128, maximum=1024, step=32, value=1024, label="Image Width")
    text2image_guidance_scale = gr.Slider(minimum=0.1, maximum=15, step=0.1, value=4.0, label="Guidance Scale")
    text2image_prior_inference_step = gr.Slider(minimum=1, maximum=50, step=1, value=20, label="Prior Inference Step")
    text2image_decoder_inference_step = gr.Slider(minimum=1, maximum=50, step=1, value=10, label="Decoder Inference Step")

    text2image_predict = gr.Button(value="Generate Image")

    # The generated image is returned as base64 text.
    output_image_base64 = gr.Text()

    # The inputs are listed in the positional order of generate_images' parameters.
    text2image_predict.click(
        fn=generate_images,
        inputs=[
            secret_token,
            text2image_prompt,
            text2image_negative_prompt,
            text2image_height,
            text2image_width,
            text2image_guidance_scale,
            text2image_seed,
            text2image_prior_inference_step,
            text2image_decoder_inference_step
        ],
        outputs=output_image_base64,
        api_name='run',
    )

# Enable the request queue with a maximum size of 20, then start the app.
gradio_app.queue(max_size=20).launch()