"""Gradio REST endpoint that generates a single image with Stable Cascade.

The generated image is returned as a base64-encoded PNG string. Each call must
supply the secret token configured via the ``SECRET_TOKEN`` environment variable.
"""

import base64
import hmac
import os
import re
from io import BytesIO

import gradio as gr
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from PIL import Image  # needed by readb64 to decode incoming images

# NOTE(review): when SECRET_TOKEN is unset this falls back to a fixed, publicly
# visible default, so the endpoint is effectively unprotected. Consider
# requiring the variable in deployed environments.
SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')

# Regex pattern to match a data URI scheme prefix, e.g. "data:image/png;base64,"
data_uri_pattern = re.compile(r'data:image/(png|jpeg|jpg|webp);base64,')


def readb64(b64):
    """Decode a base64 string (optionally data-URI prefixed) into a PIL image."""
    # Remove any data URI scheme prefix with regex
    b64 = data_uri_pattern.sub("", b64)
    # Decode and open the image with PIL
    return Image.open(BytesIO(base64.b64decode(b64)))


def writeb64(image):
    """Encode a PIL image as a base64 PNG string (without a data-URI prefix)."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def _is_valid_token(token):
    """Return True when *token* is a str equal to SECRET_TOKEN (constant-time)."""
    if not isinstance(token, str):
        return False
    # compare_digest avoids leaking the secret through timing differences.
    # Encode first: compare_digest raises TypeError on non-ASCII str arguments.
    return hmac.compare_digest(token.encode("utf-8"), SECRET_TOKEN.encode("utf-8"))


# Load both Stable Cascade stages once at import time and keep them on the GPU.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to("cuda")
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.float16
).to("cuda")


def generate_images(
    secret_token="",
    prompt="",
    negative_prompt="bad,ugly,deformed",
    height=1024,
    width=1024,
    guidance_scale=4.0,
    seed=42,
    prior_inference_steps=20,
    decoder_inference_steps=10,
):
    """
    Generate one image from a prompt using the Stable Cascade prior and decoder on CUDA.

    Parameters:
    - secret_token (str): Must equal the SECRET_TOKEN environment variable.
    - prompt (str): The prompt to generate the image for.
    - negative_prompt (str): The negative prompt to guide image generation away from.
    - height (int): The height of the generated image.
    - width (int): The width of the generated image.
    - guidance_scale (float): The scale of guidance used by the prior model.
    - seed (int): Seed for the CUDA random generator shared by both stages.
    - prior_inference_steps (int): The number of inference steps for the prior model.
    - decoder_inference_steps (int): The number of inference steps for the decoder model.

    Returns:
    - str: The generated image as a base64-encoded PNG (no data-URI prefix).

    Raises:
    - gr.Error: If the secret token does not match SECRET_TOKEN.
    """
    if not _is_valid_token(secret_token):
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.'
        )

    # Numeric inputs may arrive as floats (gr.Number, sliders, raw API calls);
    # the pipelines expect integer sizes and step counts.
    height = int(height)
    width = int(width)
    guidance_scale = float(guidance_scale)
    prior_inference_steps = int(prior_inference_steps)
    decoder_inference_steps = int(decoder_inference_steps)

    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    # Generate image embeddings using the prior model
    prior_output = prior(
        prompt=prompt,
        generator=generator,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        num_inference_steps=prior_inference_steps,
    )

    # Generate images using the decoder model and the embeddings from the prior.
    # The prior runs in bfloat16 while the decoder was loaded in float16, so the
    # embeddings are cast with .half() before being handed over.
    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=prompt,
        generator=generator,
        negative_prompt=negative_prompt,
        guidance_scale=0.0,  # Guidance scale typically set to 0 for decoder as guidance is applied in the prior
        output_type="pil",
        num_inference_steps=decoder_inference_steps,
    ).images

    return writeb64(decoder_output[0])


with gr.Blocks() as gradio_app:
    gr.HTML("""

This space is a REST API to programmatically generate an image.

Interested in using it? Please use the original space, thank you!

""")
    secret_token = gr.Textbox(
        placeholder="Secret token",
        show_label=False,
    )
    text2image_prompt = gr.Textbox(
        lines=1,
        placeholder="Prompt",
        show_label=False,
    )
    text2image_negative_prompt = gr.Textbox(
        lines=1,
        placeholder="Negative Prompt",
        show_label=False,
    )
    text2image_seed = gr.Number(
        value=42,
        label="Seed",
    )
    text2image_height = gr.Slider(
        minimum=128,
        maximum=1024,
        step=32,
        value=1024,
        label="Image Height",
    )
    text2image_width = gr.Slider(
        minimum=128,
        maximum=1024,
        step=32,
        value=1024,
        label="Image Width",
    )
    text2image_guidance_scale = gr.Slider(
        minimum=0.1,
        maximum=15,
        step=0.1,
        value=4.0,
        label="Guidance Scale",
    )
    text2image_prior_inference_step = gr.Slider(
        minimum=1,
        maximum=50,
        step=1,
        value=20,
        label="Prior Inference Step",
    )
    text2image_decoder_inference_step = gr.Slider(
        minimum=1,
        maximum=50,
        step=1,
        value=10,
        label="Decoder Inference Step",
    )
    text2image_predict = gr.Button(value="Generate Image")
    output_image_base64 = gr.Text()

    # Input order must match generate_images' positional parameters.
    text2image_predict.click(
        fn=generate_images,
        inputs=[
            secret_token,
            text2image_prompt,
            text2image_negative_prompt,
            text2image_height,
            text2image_width,
            text2image_guidance_scale,
            text2image_seed,
            text2image_prior_inference_step,
            text2image_decoder_inference_step,
        ],
        outputs=output_image_base64,
        api_name='run',
    )

gradio_app.queue(max_size=20).launch()