import base64
import os
import re
from io import BytesIO

import gradio as gr
import torch
from PIL import Image
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')

# Regex pattern to match data URI scheme
data_uri_pattern = re.compile(r'data:image/(png|jpeg|jpg|webp);base64,')

def readb64(b64):
    # Remove any data URI scheme prefix with regex
    b64 = data_uri_pattern.sub("", b64)
    # Decode and open the image with PIL
    img = Image.open(BytesIO(base64.b64decode(b64)))
    return img
    
# convert from PIL to base64
def writeb64(image):
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    b64image = base64.b64encode(buffered.getvalue())
    b64image_str = b64image.decode("utf-8")
    return b64image_str


# Load the Stable Cascade prior (bfloat16) and decoder (float16) pipelines onto the GPU
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16).to("cuda")
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.float16).to("cuda")

def generate_images(
    secret_token="",
    prompt="",
    negative_prompt="bad,ugly,deformed",
    height=1024,
    width=1024, 
    guidance_scale=4.0, 
    seed=42,
    prior_inference_steps=20, 
    decoder_inference_steps=10
    ):
    """
    Generates images based on a given prompt using Stable Diffusion models on CUDA device.
    Parameters:
    - prompt (str): The prompt to generate images for.
    - negative_prompt (str): The negative prompt to guide image generation away from.
    - height (int): The height of the generated images.
    - width (int): The width of the generated images.
    - guidance_scale (float): The scale of guidance for the image generation.
    - prior_inference_steps (int): The number of inference steps for the prior model.
    - decoder_inference_steps (int): The number of inference steps for the decoder model.
    Returns:
    - List[PIL.Image]: A list of generated PIL Image objects.
    """
    if secret_token != SECRET_TOKEN:
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    # Generate image embeddings using the prior model
    prior_output = prior(
        prompt=prompt,
        generator=generator,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        num_inference_steps=prior_inference_steps
    )

    # Generate images using the decoder model and the embeddings from the prior model
    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=prompt,
        generator=generator,
        negative_prompt=negative_prompt,
        guidance_scale=0.0,  # Guidance scale typically set to 0 for decoder as guidance is applied in the prior
        output_type="pil",
        num_inference_steps=decoder_inference_steps
    ).images

    image = decoder_output[0]
        
    image_base64 = writeb64(image)
        
    return image_base64


with gr.Blocks() as gradio_app:
    gr.HTML("""
        <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
        <div style="text-align: center; color: black;">
        <p style="color: black;">This space is a REST API to programmatically generate an image.</p>
        <p style="color: black;">Interested in using it? Please use the <a href="https://huggingface.co/spaces/ArtGAN/Diffusion-API" target="_blank">original space</a>, thank you!</p>
        </div>
        </div>""")
    secret_token = gr.Textbox(
        placeholder="Secret token",
        show_label=False,
    )
    text2image_prompt = gr.Textbox(
        lines=1,
        placeholder="Prompt",
        show_label=False,
    )
    text2image_negative_prompt = gr.Textbox(
        lines=1,
        placeholder="Negative Prompt",
        show_label=False,
    )
    text2image_seed = gr.Number(
        value=42,
        label="Seed",
    )
    text2image_height = gr.Slider(
        minimum=128,
        maximum=1024,
        step=32,
        value=1024,
        label="Image Height",
    )
    text2image_width = gr.Slider(
        minimum=128,
        maximum=1024,
        step=32,
        value=1024,
        label="Image Width",
    )
    text2image_guidance_scale = gr.Slider(
        minimum=0.1,
        maximum=15,
        step=0.1,
        value=4.0,
        label="Guidance Scale",
    )                
    text2image_prior_inference_step = gr.Slider(
        minimum=1,
        maximum=50,
        step=1,
        value=20,
        label="Prior Inference Step",
    )                
    text2image_decoder_inference_step = gr.Slider(
        minimum=1,
        maximum=50,
        step=1,
        value=10,
        label="Decoder Inference Step",
    )               
    text2image_predict = gr.Button(value="Generate Image")
    output_image_base64 = gr.Text()
    text2image_predict.click(
        fn=generate_images,
        inputs=[
            secret_token,
            text2image_prompt,
            text2image_negative_prompt,
            text2image_height,
            text2image_width,
            text2image_guidance_scale,
            text2image_seed,
            text2image_prior_inference_step,
            text2image_decoder_inference_step
        ],
        outputs=output_image_base64,
        api_name='run',
    )

gradio_app.queue(max_size=20).launch()
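
# Example client usage (sketch): the Blocks app above exposes generate_images under
# api_name='run', so it can be called programmatically with gradio_client. The space
# id below is a placeholder, and the token must match the SECRET_TOKEN configured on
# the server; adjust both before running.
#
# from gradio_client import Client
#
# client = Client("your-username/your-space")   # hypothetical space id
# image_base64 = client.predict(
#     "default_secret",           # secret_token (must match the server's SECRET_TOKEN)
#     "a photo of a red panda",   # prompt
#     "bad,ugly,deformed",        # negative_prompt
#     1024,                       # height
#     1024,                       # width
#     4.0,                        # guidance_scale
#     42,                         # seed
#     20,                         # prior_inference_steps
#     10,                         # decoder_inference_steps
#     api_name="/run",
# )
# # image_base64 is the PNG encoded as a base64 string, as returned by writeb64 above.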