File size: 6,594 Bytes
f0aa930 35dc227 8a15462 35dc227 f9309dc 20e6fde 35dc227 8a15462 35dc227 a064d1f 35dc227 8a15462 7e1ec9f 35dc227 a064d1f 35dc227 80eae63 35dc227 a064d1f 35dc227 8a15462 35dc227 8a15462 35dc227 a064d1f d1b6cf8 a064d1f 35dc227 c383967 35dc227 edf4c40 35dc227 c383967 35dc227 edf4c40 35dc227 e1a3ecf 35dc227 8a15462 35dc227 0cc2ea6 35dc227 8a15462 dd01375 35dc227 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import base64
import hmac
import os
import re
from io import BytesIO

import gradio as gr
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
# Shared secret that generate_images requires from every caller, read from the
# SECRET_TOKEN environment variable.
# NOTE(review): the fallback value is visible in this source, so a deployment
# that forgets to set the variable is effectively unprotected.
SECRET_TOKEN = os.environ.get('SECRET_TOKEN', 'default_secret')

# Optional "data:image/<format>;base64," header that browsers put in front of
# base64 image payloads; matched so it can be stripped before decoding.
data_uri_pattern = re.compile(r'data:image/(png|jpeg|jpg|webp);base64,')
def readb64(b64):
    """
    Decode a base64-encoded image string into a PIL image.

    Parameters:
    - b64 (str): Base64 image data, optionally prefixed with a data URI
      header such as "data:image/png;base64,".

    Returns:
    - PIL.Image.Image: The decoded image. PIL opens it lazily, so pixel data
      is read on first access.

    Raises:
    - binascii.Error: If the payload is not valid base64 (e.g. bad padding).
    - PIL.UnidentifiedImageError: If the decoded bytes are not a recognised image.
    """
    # `Image` was referenced here without ever being imported, so every call
    # raised NameError. Import Pillow locally where it is needed; it is
    # presumably already installed alongside diffusers/Gradio (the decoder is
    # asked for output_type="pil").
    from PIL import Image

    # Strip any data URI header so only the raw base64 payload remains.
    payload = data_uri_pattern.sub("", b64)
    return Image.open(BytesIO(base64.b64decode(payload)))
def writeb64(image):
    """
    Serialize an image to a base64-encoded PNG string.

    Parameters:
    - image: An object with a PIL-style ``save(fp, format=...)`` method,
      typically a ``PIL.Image.Image``.

    Returns:
    - str: The PNG bytes as base64 text, without any data URI header.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        raw_png = png_buffer.getvalue()
    return base64.b64encode(raw_png).decode("utf-8")
# Load both Stable Cascade stages once, at import time, and move them onto the
# CUDA device. The prior is kept in bfloat16 and the decoder in float16.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
).to("cuda")
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16,
).to("cuda")
def generate_images(
    secret_token="",
    prompt="",
    negative_prompt="bad,ugly,deformed",
    height=1024,
    width=1024,
    guidance_scale=4.0,
    seed=42,
    prior_inference_steps=20,
    decoder_inference_steps=10,
    num_images_per_prompt=1,
):
    """
    Generate an image from a text prompt with the Stable Cascade prior and
    decoder pipelines on the CUDA device, returned as base64 PNG text.

    Parameters:
    - secret_token (str): Must equal SECRET_TOKEN, otherwise the request is rejected.
    - prompt (str): The prompt to generate images for.
    - negative_prompt (str): The negative prompt to guide image generation away from.
    - height (int): The height of the generated image, in pixels.
    - width (int): The width of the generated image, in pixels.
    - guidance_scale (float): Guidance scale for the prior stage (the decoder
      always runs with 0.0).
    - seed (int | float): Seed for the CUDA random generator; converted with int().
    - prior_inference_steps (int): The number of inference steps for the prior model.
    - decoder_inference_steps (int): The number of inference steps for the decoder model.
    - num_images_per_prompt (int): Number of images generated for the prompt;
      only the first one is returned. Defaults to 1.

    Returns:
    - str: The first generated image, PNG-encoded as base64 text.

    Raises:
    - gr.Error: If secret_token does not match SECRET_TOKEN.
    """
    # Reject callers without the shared secret. compare_digest runs in time
    # independent of how many leading characters match, so response timing does
    # not leak the token; non-string input (e.g. None) is treated as a mismatch.
    token_ok = isinstance(secret_token, str) and hmac.compare_digest(
        secret_token.encode("utf-8"), SECRET_TOKEN.encode("utf-8")
    )
    if not token_ok:
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    # UI sliders / API clients may send floats; the pipelines need integer
    # image sizes and step counts.
    height = int(height)
    width = int(width)
    prior_inference_steps = int(prior_inference_steps)
    decoder_inference_steps = int(decoder_inference_steps)

    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    # Generate image embeddings using the prior model
    prior_output = prior(
        prompt=prompt,
        generator=generator,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        # Previously an undefined name here made every call fail with NameError.
        num_images_per_prompt=int(num_images_per_prompt),
        num_inference_steps=prior_inference_steps,
    )

    # Generate images using the decoder model and the embeddings from the prior model.
    # .half() casts the embeddings to float16 to match the decoder's dtype.
    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=prompt,
        generator=generator,
        negative_prompt=negative_prompt,
        guidance_scale=0.0,  # Guidance scale typically set to 0 for decoder as guidance is applied in the prior
        output_type="pil",
        num_inference_steps=decoder_inference_steps,
    ).images

    return writeb64(decoder_output[0])
def web_demo():
    """
    Build the Gradio UI and the "run" API endpoint that wrap generate_images.

    Returns:
    - gr.Blocks: The constructed app; call ``.launch()`` on it to serve it.
    """
    # Bind the Blocks context so the app can be returned. Previously it was
    # created unbound and discarded, leaving nothing that could be launched.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                secret_token = gr.Textbox(
                    placeholder="Secret token",
                    show_label=False,
                )
                text2image_prompt = gr.Textbox(
                    lines=1,
                    placeholder="Prompt",
                    show_label=False,
                )
                text2image_negative_prompt = gr.Textbox(
                    lines=1,
                    placeholder="Negative Prompt",
                    show_label=False,
                )
                # precision=0 makes Gradio deliver the seed as an int.
                text2image_seed = gr.Number(
                    value=42,
                    label="Seed",
                    precision=0,
                )
                with gr.Row():
                    with gr.Column():
                        text2image_height = gr.Slider(
                            minimum=128,
                            maximum=1024,
                            step=32,
                            value=1024,
                            label="Image Height",
                        )
                        text2image_width = gr.Slider(
                            minimum=128,
                            maximum=1024,
                            step=32,
                            value=1024,
                            label="Image Width",
                        )
                with gr.Row():
                    with gr.Column():
                        text2image_guidance_scale = gr.Slider(
                            minimum=0.1,
                            maximum=15,
                            step=0.1,
                            value=4.0,
                            label="Guidance Scale",
                        )
                        text2image_prior_inference_step = gr.Slider(
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=20,
                            label="Prior Inference Step",
                        )
                        text2image_decoder_inference_step = gr.Slider(
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=10,
                            label="Decoder Inference Step",
                        )
                text2image_predict = gr.Button(value="Generate Image")
                output_image_base64 = gr.Text()

        # Event listeners must be registered while the Blocks context is open.
        # Input order matches generate_images' positional parameters.
        text2image_predict.click(
            fn=generate_images,
            inputs=[
                secret_token,
                text2image_prompt,
                text2image_negative_prompt,
                text2image_height,
                text2image_width,
                text2image_guidance_scale,
                text2image_seed,
                text2image_prior_inference_step,
                text2image_decoder_inference_step,
            ],
            outputs=output_image_base64,
            api_name='run',
        )
    return demo