jbilcke-hf committed on
Commit
8a15462
·
verified ·
1 Parent(s): 94a2689

Update stable_cascade.py

Browse files
Files changed (1) hide show
  1. stable_cascade.py +42 -9
stable_cascade.py CHANGED
@@ -1,12 +1,37 @@
1
  import torch, os
2
  from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
3
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16).to("cuda")
6
  decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.float16).to("cuda")
7
 
8
  def generate_images(
9
- prompt="a photo of a girl",
 
10
  negative_prompt="bad,ugly,deformed",
11
  height=1024,
12
  width=1024,
@@ -29,6 +54,10 @@ def generate_images(
29
  Returns:
30
  - List[PIL.Image]: A list of generated PIL Image objects.
31
  """
 
 
 
 
32
  generator = torch.Generator(device="cuda").manual_seed(int(seed))
33
 
34
  # Generate image embeddings using the prior model
@@ -54,13 +83,21 @@ def generate_images(
54
  num_inference_steps=decoder_inference_steps
55
  ).images
56
 
57
- return decoder_output
 
 
 
 
58
 
59
 
60
  def web_demo():
61
  with gr.Blocks():
62
  with gr.Row():
63
  with gr.Column():
 
 
 
 
64
  text2image_prompt = gr.Textbox(
65
  lines=1,
66
  placeholder="Prompt",
@@ -129,16 +166,12 @@ def web_demo():
129
  )
130
  text2image_predict = gr.Button(value="Generate Image")
131
 
132
- with gr.Column():
133
- output_image = gr.Gallery(
134
- label="Generated images",
135
- show_label=False,
136
- elem_id="gallery",
137
- ).style(grid=(1, 2), height=300)
138
 
139
  text2image_predict.click(
140
  fn=generate_images,
141
  inputs=[
 
142
  text2image_prompt,
143
  text2image_negative_prompt,
144
  text2image_height,
@@ -149,5 +182,5 @@ def web_demo():
149
  text2image_prior_inference_step,
150
  text2image_decoder_inference_step
151
  ],
152
- outputs=output_image,
153
  )
 
1
  import torch, os
2
  from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
3
  import gradio as gr
4
+ from io import BytesIO
5
+ import base64
6
+ import re
7
+
8
# Shared-secret gate: API callers must supply this token to run generation.
SECRET_TOKEN = os.environ.get('SECRET_TOKEN', 'default_secret')

# Matches a leading data-URI prefix (e.g. "data:image/png;base64,") so it
# can be stripped before base64 decoding.
data_uri_pattern = re.compile(r'data:image/(png|jpeg|jpg|webp);base64,')
12
+
13
def readb64(b64):
    """Decode a base64 string (optionally a data URI) into a PIL image.

    Args:
        b64: Base64-encoded image bytes, with or without a
            "data:image/...;base64," data-URI prefix.

    Returns:
        PIL.Image.Image: the decoded image.
    """
    # BUG FIX: the module never imported PIL's Image, so the original body
    # raised NameError on first call. Import locally to keep the top-level
    # import block untouched.
    from PIL import Image

    # Remove any data URI scheme prefix so only raw base64 remains.
    b64 = data_uri_pattern.sub("", b64)
    # Decode and open the image with PIL.
    return Image.open(BytesIO(base64.b64decode(b64)))
19
+
20
# Serialize a PIL image to a base64 string (PNG-encoded).
def writeb64(image):
    """Return *image* encoded as PNG and wrapped in a base64 UTF-8 string."""
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    return base64.b64encode(png_buffer.getvalue()).decode("utf-8")
27
+
28
 
29
  prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16).to("cuda")
30
  decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.float16).to("cuda")
31
 
32
  def generate_images(
33
+ secret_token="",
34
+ prompt="",
35
  negative_prompt="bad,ugly,deformed",
36
  height=1024,
37
  width=1024,
 
54
  Returns:
55
  - List[PIL.Image]: A list of generated PIL Image objects.
56
  """
57
+ if secret_token != SECRET_TOKEN:
58
+ raise gr.Error(
59
+ f'Invalid secret token. Please fork the original space if you want to use it for yourself.')
60
+
61
  generator = torch.Generator(device="cuda").manual_seed(int(seed))
62
 
63
  # Generate image embeddings using the prior model
 
83
  num_inference_steps=decoder_inference_steps
84
  ).images
85
 
86
+ image = decoder_output[0]
87
+
88
+ image_base64 = writeb64(image)
89
+
90
+ return image_base64
91
 
92
 
93
  def web_demo():
94
  with gr.Blocks():
95
  with gr.Row():
96
  with gr.Column():
97
+ secret_token = gr.Textbox(
98
+ placeholder="Secret token",
99
+ show_label=False,
100
+ )
101
  text2image_prompt = gr.Textbox(
102
  lines=1,
103
  placeholder="Prompt",
 
166
  )
167
  text2image_predict = gr.Button(value="Generate Image")
168
 
169
+ output_image_base64 = gr.Text()
 
 
 
 
 
170
 
171
  text2image_predict.click(
172
  fn=generate_images,
173
  inputs=[
174
+ secret_token,
175
  text2image_prompt,
176
  text2image_negative_prompt,
177
  text2image_height,
 
182
  text2image_prior_inference_step,
183
  text2image_decoder_inference_step
184
  ],
185
+ outputs=output_image_base64,
186
  )