jbilcke-hf committed on
Commit
15b0643
·
verified ·
1 Parent(s): c02d7b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -8
app.py CHANGED
@@ -1,9 +1,96 @@
1
- from stable_cascade import web_demo
2
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
 
5
- gradio_app = gr.Blocks()
6
- with gradio_app:
7
  gr.HTML("""
8
  <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
9
  <div style="text-align: center; color: black;">
@@ -11,9 +98,76 @@ with gradio_app:
11
  <p style="color: black;">Interested in using it? Please use the <a href="https://huggingface.co/spaces/ArtGAN/Diffusion-API" target="_blank">original space</a>, thank you!</p>
12
  </div>
13
  </div>""")
14
- with gr.Row():
15
- with gr.Column():
16
- web_demo()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- gradio_app.queue()
19
- gradio_app.launch(debug=True)
 
 
1
  import gradio as gr
2
+ import torch, os
3
+ from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
4
+ import gradio as gr
5
+ from io import BytesIO
6
+ import base64
7
+ import re
8
+
9
# Shared secret that callers must send with every generation request.
# NOTE(review): the fallback value is hard-coded, so the token is only
# private when the SECRET_TOKEN environment variable is actually set.
SECRET_TOKEN = os.environ.get('SECRET_TOKEN', 'default_secret')

# Matches the "data:image/<type>;base64," header of an image data URI
data_uri_pattern = re.compile(
    r'data:image/(png|jpeg|jpg|webp);base64,'
)
13
+
14
def readb64(b64):
    """
    Decodes a base64-encoded image string into a PIL image.
    Parameters:
    - b64 (str): Base64 image data, optionally carrying a
      "data:image/<png|jpeg|jpg|webp>;base64," data URI prefix.
    Returns:
    - PIL.Image.Image: The image opened from the decoded bytes.
    """
    # Local import: PIL's Image is not imported at module level in this file,
    # and without it the Image.open call below raises NameError.
    from PIL import Image

    # Remove any data URI scheme prefix with regex
    payload = data_uri_pattern.sub("", b64)
    # Decode and open the image with PIL
    return Image.open(BytesIO(base64.b64decode(payload)))
20
+
21
# Serialize a PIL image to base64 text
def writeb64(image):
    """
    Encodes an image as the base64 text of its PNG bytes.
    Parameters:
    - image: Object exposing save(fp, format=...), e.g. a PIL image.
    Returns:
    - str: Base64 string (decoded as UTF-8) of the PNG data, without any
      data URI prefix.
    """
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    return base64.b64encode(png_buffer.getvalue()).decode("utf-8")
28
+
29
+
30
# Load both Stable Cascade stages once, at import time, and move them to CUDA.
# The prior runs in bfloat16 while the decoder runs in float16.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
).to("cuda")
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16,
).to("cuda")
32
+
33
def generate_images(
    secret_token="",
    prompt="",
    negative_prompt="bad,ugly,deformed",
    height=1024,
    width=1024,
    guidance_scale=4.0,
    seed=42,
    prior_inference_steps=20,
    decoder_inference_steps=10
):
    """
    Generates a single image for a prompt using the Stable Cascade prior and decoder pipelines on a CUDA device.
    Parameters:
    - secret_token (str): Token that must equal SECRET_TOKEN for the request to be served.
    - prompt (str): The prompt to generate the image for.
    - negative_prompt (str): The negative prompt to guide image generation away from.
    - height (int): The height of the generated image.
    - width (int): The width of the generated image.
    - guidance_scale (float): The scale of guidance applied in the prior model.
    - seed (int): Seed for the CUDA random generator shared by both pipelines.
    - prior_inference_steps (int): The number of inference steps for the prior model.
    - decoder_inference_steps (int): The number of inference steps for the decoder model.
    Returns:
    - str: The generated image as base64-encoded PNG data (no data URI prefix).
    Raises:
    - gr.Error: If secret_token does not match SECRET_TOKEN.
    """
    import hmac

    # Compare in constant time so response timing does not leak how much of
    # the token matched; any non-string value is rejected outright.
    token_ok = isinstance(secret_token, str) and hmac.compare_digest(
        secret_token.encode("utf-8"), SECRET_TOKEN.encode("utf-8")
    )
    if not token_ok:
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    # UI widgets and API clients may deliver whole numbers as floats; the
    # pipelines are given integer sizes and step counts.
    height = int(height)
    width = int(width)
    prior_inference_steps = int(prior_inference_steps)
    decoder_inference_steps = int(decoder_inference_steps)

    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    # Generate image embeddings using the prior model
    prior_output = prior(
        prompt=prompt,
        generator=generator,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        num_inference_steps=prior_inference_steps
    )

    # Generate images using the decoder model and the embeddings from the prior model
    decoder_output = decoder(
        # The prior is loaded in bfloat16 and the decoder in float16, so the
        # embeddings are converted to half precision first.
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=prompt,
        generator=generator,
        negative_prompt=negative_prompt,
        guidance_scale=0.0,  # Guidance scale typically set to 0 for decoder as guidance is applied in the prior
        output_type="pil",
        num_inference_steps=decoder_inference_steps
    ).images

    # Only one image is requested per prompt, so return the first (only) one.
    image = decoder_output[0]

    return writeb64(image)
91
 
92
 
93
+ with gr.Blocks() as gradio_app:
 
94
  gr.HTML("""
95
  <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
96
  <div style="text-align: center; color: black;">
 
98
  <p style="color: black;">Interested in using it? Please use the <a href="https://huggingface.co/spaces/ArtGAN/Diffusion-API" target="_blank">original space</a>, thank you!</p>
99
  </div>
100
  </div>""")
101
+ secret_token = gr.Textbox(
102
+ placeholder="Secret token",
103
+ show_label=False,
104
+ )
105
+ text2image_prompt = gr.Textbox(
106
+ lines=1,
107
+ placeholder="Prompt",
108
+ show_label=False,
109
+ )
110
+ text2image_negative_prompt = gr.Textbox(
111
+ lines=1,
112
+ placeholder="Negative Prompt",
113
+ show_label=False,
114
+ )
115
+ text2image_seed = gr.Number(
116
+ value=42,
117
+ label="Seed",
118
+ )
119
+ text2image_height = gr.Slider(
120
+ minimum=128,
121
+ maximum=1024,
122
+ step=32,
123
+ value=1024,
124
+ label="Image Height",
125
+ )
126
+ text2image_width = gr.Slider(
127
+ minimum=128,
128
+ maximum=1024,
129
+ step=32,
130
+ value=1024,
131
+ label="Image Width",
132
+ )
133
+ text2image_guidance_scale = gr.Slider(
134
+ minimum=0.1,
135
+ maximum=15,
136
+ step=0.1,
137
+ value=4.0,
138
+ label="Guidance Scale",
139
+ )
140
+ text2image_prior_inference_step = gr.Slider(
141
+ minimum=1,
142
+ maximum=50,
143
+ step=1,
144
+ value=20,
145
+ label="Prior Inference Step",
146
+ )
147
+ text2image_decoder_inference_step = gr.Slider(
148
+ minimum=1,
149
+ maximum=50,
150
+ step=1,
151
+ value=10,
152
+ label="Decoder Inference Step",
153
+ )
154
+ text2image_predict = gr.Button(value="Generate Image")
155
+ output_image_base64 = gr.Text()
156
+ text2image_predict.click(
157
+ fn=generate_images,
158
+ inputs=[
159
+ secret_token,
160
+ text2image_prompt,
161
+ text2image_negative_prompt,
162
+ text2image_height,
163
+ text2image_width,
164
+ text2image_guidance_scale,
165
+ text2image_seed,
166
+ text2image_prior_inference_step,
167
+ text2image_decoder_inference_step
168
+ ],
169
+ outputs=output_image_base64,
170
+ api_name='run',
171
+ )
172
 
173
# Enable the request queue (max_size=20), then start serving the app.
gradio_app.queue(max_size=20)
gradio_app.launch()