File size: 6,594 Bytes
f0aa930
35dc227
 
8a15462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35dc227
f9309dc
20e6fde
35dc227
 
8a15462
 
35dc227
 
 
 
a064d1f
35dc227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a15462
 
 
 
7e1ec9f
35dc227
 
 
 
a064d1f
35dc227
 
 
 
80eae63
35dc227
 
 
 
 
 
 
a064d1f
35dc227
 
 
 
 
 
8a15462
 
 
 
 
35dc227
 
 
 
 
 
8a15462
 
 
 
35dc227
 
 
 
 
 
 
 
 
 
 
a064d1f
d1b6cf8
a064d1f
 
 
 
35dc227
 
 
 
c383967
35dc227
edf4c40
35dc227
 
 
 
 
c383967
35dc227
edf4c40
35dc227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1a3ecf
35dc227
 
 
 
8a15462
35dc227
 
 
 
 
0cc2ea6
35dc227
 
 
8a15462
dd01375
35dc227
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import torch, os
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
import gradio as gr
from io import BytesIO
import base64
import re

# Shared secret that API callers must present; read from the environment.
# NOTE(review): when SECRET_TOKEN is unset this falls back to a publicly
# visible placeholder, which effectively disables the check — confirm the
# variable is always configured in deployment.
SECRET_TOKEN = os.environ.get('SECRET_TOKEN', 'default_secret')

# Matches a data-URI header such as "data:image/png;base64," so it can be
# stripped from an incoming payload before base64 decoding.
data_uri_pattern = re.compile(
    r'data:image/(png|jpeg|jpg|webp);base64,'
)

def readb64(b64):
    """Decode a base64 image payload into a PIL image.

    A leading ``data:image/<png|jpeg|jpg|webp>;base64,`` header (as matched by
    ``data_uri_pattern``) is stripped first, so both raw base64 strings and
    data URIs are accepted.

    Args:
        b64: base64-encoded image data, optionally prefixed by a data-URI header.

    Returns:
        The PIL image opened from the decoded bytes.
    """
    # `Image` is referenced here but was never imported at module level, which
    # made every call fail with NameError. PIL is already a runtime requirement
    # (the decoder below is asked for output_type="pil"), so import it locally.
    from PIL import Image

    # Remove any data URI scheme prefix with regex
    b64 = data_uri_pattern.sub("", b64)
    # Decode and open the image with PIL
    return Image.open(BytesIO(base64.b64decode(b64)))
    
# convert from PIL to base64
def writeb64(image):
    """Serialize *image* as PNG and return the bytes base64-encoded as a str.

    Args:
        image: any object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The base64 text of the PNG data, decoded as UTF-8.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")


# Load both Stable Cascade stages once at import time and move them to the
# CUDA device, so every request reuses the same pipelines.
# The prior runs in bfloat16, the decoder in float16.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
).to("cuda")
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16,
).to("cuda")

def generate_images(
    secret_token="",
    prompt="",
    negative_prompt="bad,ugly,deformed",
    height=1024,
    width=1024, 
    guidance_scale=4.0, 
    seed=42,
    prior_inference_steps=20, 
    decoder_inference_steps=10
    ):
    """
    Generate one image with the Stable Cascade prior + decoder pipelines on CUDA.

    Parameters:
    - secret_token (str): Must equal SECRET_TOKEN; otherwise the request is rejected.
    - prompt (str): The prompt to generate an image for.
    - negative_prompt (str): The negative prompt to guide image generation away from.
    - height (int): Height of the generated image in pixels.
    - width (int): Width of the generated image in pixels.
    - guidance_scale (float): Guidance scale applied in the prior stage.
    - seed (int): Seed for the CUDA random generator (cast to int).
    - prior_inference_steps (int): Number of inference steps for the prior model.
    - decoder_inference_steps (int): Number of inference steps for the decoder model.
    Returns:
    - str: The first generated image, PNG-encoded and base64-encoded.
    Raises:
    - gr.Error: If secret_token does not match SECRET_TOKEN.
    """
    import hmac

    # The token comes from untrusted callers: compare in constant time so the
    # response latency does not reveal how much of the secret matched.
    token_ok = isinstance(secret_token, str) and hmac.compare_digest(
        secret_token.encode("utf-8"), SECRET_TOKEN.encode("utf-8")
    )
    if not token_ok:
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    # Generate image embeddings using the prior model.
    # Only the first image is returned below, so request exactly one.
    # Numeric inputs are cast to int because API/UI clients may send them as
    # floats (assumption — Gradio number widgets can deliver floats).
    prior_output = prior(
        prompt=prompt,
        generator=generator,
        height=int(height),
        width=int(width),
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        num_inference_steps=int(prior_inference_steps),
    )

    # Generate images using the decoder model and the embeddings from the prior model
    images = decoder(
        # The decoder was loaded in float16, while the prior emits bfloat16.
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=prompt,
        generator=generator,
        negative_prompt=negative_prompt,
        guidance_scale=0.0,  # Guidance scale typically set to 0 for decoder as guidance is applied in the prior
        output_type="pil",
        num_inference_steps=int(decoder_inference_steps),
    ).images

    return writeb64(images[0])


def web_demo():
    """Build the Gradio UI and its ``run`` API endpoint, wired to ``generate_images``.

    Returns:
        The constructed ``gr.Blocks`` app. Previously the Blocks object was
        discarded, so the UI could never be launched. Nothing in this file
        calls ``web_demo()`` or launches it; the caller must do so, e.g.
        ``web_demo().launch()``.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                secret_token = gr.Textbox(
                    placeholder="Secret token",
                    show_label=False,
                )
                text2image_prompt = gr.Textbox(
                    lines=1,
                    placeholder="Prompt",
                    show_label=False,
                )

                text2image_negative_prompt = gr.Textbox(
                    lines=1,
                    placeholder="Negative Prompt",
                    show_label=False,
                )

                text2image_seed = gr.Number(
                    value=42,
                    label="Seed",
                )

                with gr.Row():
                    with gr.Column():
                        text2image_height = gr.Slider(
                            minimum=128,
                            maximum=1024,
                            step=32,
                            value=1024,
                            label="Image Height",
                        )

                        text2image_width = gr.Slider(
                            minimum=128,
                            maximum=1024,
                            step=32,
                            value=1024,
                            label="Image Width",
                        )
                        with gr.Row():
                            with gr.Column():
                                text2image_guidance_scale = gr.Slider(
                                    minimum=0.1,
                                    maximum=15,
                                    step=0.1,
                                    value=4.0,
                                    label="Guidance Scale",
                                )
                                text2image_prior_inference_step = gr.Slider(
                                    minimum=1,
                                    maximum=50,
                                    step=1,
                                    value=20,
                                    label="Prior Inference Step",
                                )

                                text2image_decoder_inference_step = gr.Slider(
                                    minimum=1,
                                    maximum=50,
                                    step=1,
                                    value=10,
                                    label="Decoder Inference Step",
                                )
                text2image_predict = gr.Button(value="Generate Image")

            # The generated image is returned as base64 PNG text.
            output_image_base64 = gr.Text()

            # Input order must match generate_images' positional parameters.
            text2image_predict.click(
                fn=generate_images,
                inputs=[
                    secret_token,
                    text2image_prompt,
                    text2image_negative_prompt,
                    text2image_height,
                    text2image_width,
                    text2image_guidance_scale,
                    text2image_seed,
                    text2image_prior_inference_step,
                    text2image_decoder_inference_step
                ],
                outputs=output_image_base64,
                api_name='run',
            )

    return demo