File size: 5,697 Bytes
2204ef0 15b0643 2204ef0 1e1a292 15b0643 7e7b8c7 15b0643 f5da000 15b0643 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import gradio as gr
import torch, os
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
import gradio as gr
from io import BytesIO
import base64
import re
# Shared secret that callers must present to generate_images.
# NOTE(review): when the SECRET_TOKEN environment variable is unset this
# falls back to the literal 'default_secret'.
SECRET_TOKEN = os.environ.get('SECRET_TOKEN', 'default_secret')

# Matches a "data:image/<type>;base64," prefix for png, jpeg, jpg or webp,
# so it can be stripped before base64-decoding.
data_uri_pattern = re.compile(r'data:image/(png|jpeg|jpg|webp);base64,')
def readb64(b64):
    """Decode a base64 string (optionally a data URI) into a PIL image.

    Parameters:
    - b64 (str): Base64-encoded image data. Any
      "data:image/<png|jpeg|jpg|webp>;base64," prefix is removed first.

    Returns:
    - The object returned by ``PIL.Image.open`` for the decoded bytes.
    """
    # The original code referenced `Image` without importing it anywhere in
    # the file, so every call raised NameError. Imported locally so this
    # function works regardless of the module-level import list.
    from PIL import Image

    # Remove any data URI scheme prefix with the shared regex
    payload = data_uri_pattern.sub("", b64)
    # Decode and open the image with PIL
    return Image.open(BytesIO(base64.b64decode(payload)))
def writeb64(image):
    """Serialize an image to a base64 string of its PNG encoding.

    Parameters:
    - image: An object exposing ``save(fp, format=...)`` (a PIL image in
      this file's usage).

    Returns:
    - str: The base64 text (UTF-8 decoded) of the bytes written by
      ``image.save`` with ``format="PNG"``.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        raw_png = png_buffer.getvalue()
    return base64.b64encode(raw_png).decode("utf-8")
# Load both Stable Cascade stages once at import time and move them to the
# CUDA device: the prior in bfloat16, the decoder in float16.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
)
prior = prior.to("cuda")

decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16,
)
decoder = decoder.to("cuda")
def generate_images(
    secret_token="",
    prompt="",
    negative_prompt="bad,ugly,deformed",
    height=1024,
    width=1024,
    guidance_scale=4.0,
    seed=42,
    prior_inference_steps=20,
    decoder_inference_steps=10
):
    """
    Generate one image with the Stable Cascade prior + decoder pipelines on
    the CUDA device and return it as a base64-encoded PNG string.

    Parameters:
    - secret_token (str): Must equal the module-level SECRET_TOKEN.
    - prompt (str): The prompt to generate the image for.
    - negative_prompt (str): The negative prompt to guide image generation away from.
    - height (int): The height passed to the prior pipeline.
    - width (int): The width passed to the prior pipeline.
    - guidance_scale (float): The guidance scale used by the prior pipeline.
    - seed (int): Seed for the CUDA random generator shared by both stages.
    - prior_inference_steps (int): The number of inference steps for the prior model.
    - decoder_inference_steps (int): The number of inference steps for the decoder model.

    Returns:
    - str: The first generated image, PNG-encoded and base64-encoded
      (the output of writeb64).

    Raises:
    - gr.Error: If secret_token does not match SECRET_TOKEN.
    """
    import hmac  # local import: constant-time comparison for the token check

    # Compare as bytes so non-ASCII tokens are supported, and in constant
    # time so the check does not leak how much of the token matched.
    token_ok = isinstance(secret_token, str) and hmac.compare_digest(
        secret_token.encode("utf-8"), SECRET_TOKEN.encode("utf-8")
    )
    if not token_ok:
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    # UI widgets and API clients may deliver numbers as floats; coerce them
    # once here so the pipelines always receive the expected numeric types.
    height = int(height)
    width = int(width)
    guidance_scale = float(guidance_scale)
    prior_inference_steps = int(prior_inference_steps)
    decoder_inference_steps = int(decoder_inference_steps)

    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    # Generate image embeddings using the prior model
    prior_output = prior(
        prompt=prompt,
        generator=generator,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        num_inference_steps=prior_inference_steps,
    )

    # Generate images using the decoder model and the embeddings from the prior model
    decoder_output = decoder(
        # The decoder was loaded in float16, so convert the prior's embeddings to half precision
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=prompt,
        generator=generator,
        negative_prompt=negative_prompt,
        guidance_scale=0.0,  # Guidance scale typically set to 0 for decoder as guidance is applied in the prior
        output_type="pil",
        num_inference_steps=decoder_inference_steps,
    ).images

    # Only one image is requested per prompt; return it as base64 text
    return writeb64(decoder_output[0])
with gr.Blocks() as gradio_app:
    # Fixed-position white panel (z-index 100) spanning the whole page; it
    # tells visitors that this space is meant to be used as a REST API.
    gr.HTML("""
<div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
<div style="text-align: center; color: black;">
<p style="color: black;">This space is a REST API to programmatically generate an image.</p>
<p style="color: black;">Interested in using it? Please use the <a href="https://huggingface.co/spaces/ArtGAN/Diffusion-API" target="_blank">original space</a>, thank you!</p>
</div>
</div>""")

    # --- Text inputs ---
    secret_token = gr.Textbox(placeholder="Secret token", show_label=False)
    text2image_prompt = gr.Textbox(lines=1, placeholder="Prompt", show_label=False)
    text2image_negative_prompt = gr.Textbox(lines=1, placeholder="Negative Prompt", show_label=False)

    # --- Numeric inputs ---
    text2image_seed = gr.Number(value=42, label="Seed")
    text2image_height = gr.Slider(minimum=128, maximum=1024, step=32, value=1024, label="Image Height")
    text2image_width = gr.Slider(minimum=128, maximum=1024, step=32, value=1024, label="Image Width")
    text2image_guidance_scale = gr.Slider(minimum=0.1, maximum=15, step=0.1, value=4.0, label="Guidance Scale")
    text2image_prior_inference_step = gr.Slider(minimum=1, maximum=50, step=1, value=20, label="Prior Inference Step")
    text2image_decoder_inference_step = gr.Slider(minimum=1, maximum=50, step=1, value=10, label="Decoder Inference Step")

    # --- Trigger and output ---
    text2image_predict = gr.Button(value="Generate Image")
    output_image_base64 = gr.Text()

    # Listed in the positional parameter order of generate_images.
    api_inputs = [
        secret_token,
        text2image_prompt,
        text2image_negative_prompt,
        text2image_height,
        text2image_width,
        text2image_guidance_scale,
        text2image_seed,
        text2image_prior_inference_step,
        text2image_decoder_inference_step,
    ]

    # Wire the button to generate_images, registered under the API name 'run'.
    text2image_predict.click(
        fn=generate_images,
        inputs=api_inputs,
        outputs=output_image_base64,
        api_name='run',
    )
# Enable the request queue (max_size=20), then start serving the app.
queued_app = gradio_app.queue(max_size=20)
queued_app.launch()