jbilcke-hf commited on
Commit
110a710
·
verified ·
1 Parent(s): 15b0643

Delete stable_cascade.py

Browse files
Files changed (1) hide show
  1. stable_cascade.py +0 -177
stable_cascade.py DELETED
@@ -1,177 +0,0 @@
1
- import torch, os
2
- from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
3
- import gradio as gr
4
- from io import BytesIO
5
- import base64
6
- import re
7
-
8
- SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
9
-
10
- # Regex pattern to match data URI scheme
11
- data_uri_pattern = re.compile(r'data:image/(png|jpeg|jpg|webp);base64,')
12
-
13
- def readb64(b64):
14
- # Remove any data URI scheme prefix with regex
15
- b64 = data_uri_pattern.sub("", b64)
16
- # Decode and open the image with PIL
17
- img = Image.open(BytesIO(base64.b64decode(b64)))
18
- return img
19
-
20
- # convert from PIL to base64
21
- def writeb64(image):
22
- buffered = BytesIO()
23
- image.save(buffered, format="PNG")
24
- b64image = base64.b64encode(buffered.getvalue())
25
- b64image_str = b64image.decode("utf-8")
26
- return b64image_str
27
-
28
-
29
- prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16).to("cuda")
30
- decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.float16).to("cuda")
31
-
32
- def generate_images(
33
- secret_token="",
34
- prompt="",
35
- negative_prompt="bad,ugly,deformed",
36
- height=1024,
37
- width=1024,
38
- guidance_scale=4.0,
39
- seed=42,
40
- prior_inference_steps=20,
41
- decoder_inference_steps=10
42
- ):
43
- """
44
- Generates images based on a given prompt using Stable Diffusion models on CUDA device.
45
- Parameters:
46
- - prompt (str): The prompt to generate images for.
47
- - negative_prompt (str): The negative prompt to guide image generation away from.
48
- - height (int): The height of the generated images.
49
- - width (int): The width of the generated images.
50
- - guidance_scale (float): The scale of guidance for the image generation.
51
- - prior_inference_steps (int): The number of inference steps for the prior model.
52
- - decoder_inference_steps (int): The number of inference steps for the decoder model.
53
- Returns:
54
- - List[PIL.Image]: A list of generated PIL Image objects.
55
- """
56
- if secret_token != SECRET_TOKEN:
57
- raise gr.Error(
58
- f'Invalid secret token. Please fork the original space if you want to use it for yourself.')
59
-
60
- generator = torch.Generator(device="cuda").manual_seed(int(seed))
61
-
62
- # Generate image embeddings using the prior model
63
- prior_output = prior(
64
- prompt=prompt,
65
- generator=generator,
66
- height=height,
67
- width=width,
68
- negative_prompt=negative_prompt,
69
- guidance_scale=guidance_scale,
70
- num_images_per_prompt=1,
71
- num_inference_steps=prior_inference_steps
72
- )
73
-
74
- # Generate images using the decoder model and the embeddings from the prior model
75
- decoder_output = decoder(
76
- image_embeddings=prior_output.image_embeddings.half(),
77
- prompt=prompt,
78
- generator=generator,
79
- negative_prompt=negative_prompt,
80
- guidance_scale=0.0, # Guidance scale typically set to 0 for decoder as guidance is applied in the prior
81
- output_type="pil",
82
- num_inference_steps=decoder_inference_steps
83
- ).images
84
-
85
- image = decoder_output[0]
86
-
87
- image_base64 = writeb64(image)
88
-
89
- return image_base64
90
-
91
-
92
- def web_demo():
93
- with gr.Blocks():
94
- with gr.Row():
95
- with gr.Column():
96
- secret_token = gr.Textbox(
97
- placeholder="Secret token",
98
- show_label=False,
99
- )
100
- text2image_prompt = gr.Textbox(
101
- lines=1,
102
- placeholder="Prompt",
103
- show_label=False,
104
- )
105
-
106
- text2image_negative_prompt = gr.Textbox(
107
- lines=1,
108
- placeholder="Negative Prompt",
109
- show_label=False,
110
- )
111
-
112
- text2image_seed = gr.Number(
113
- value=42,
114
- label="Seed",
115
- )
116
-
117
- with gr.Row():
118
- with gr.Column():
119
- text2image_height = gr.Slider(
120
- minimum=128,
121
- maximum=1024,
122
- step=32,
123
- value=1024,
124
- label="Image Height",
125
- )
126
-
127
- text2image_width = gr.Slider(
128
- minimum=128,
129
- maximum=1024,
130
- step=32,
131
- value=1024,
132
- label="Image Width",
133
- )
134
- with gr.Row():
135
- with gr.Column():
136
- text2image_guidance_scale = gr.Slider(
137
- minimum=0.1,
138
- maximum=15,
139
- step=0.1,
140
- value=4.0,
141
- label="Guidance Scale",
142
- )
143
- text2image_prior_inference_step = gr.Slider(
144
- minimum=1,
145
- maximum=50,
146
- step=1,
147
- value=20,
148
- label="Prior Inference Step",
149
- )
150
-
151
- text2image_decoder_inference_step = gr.Slider(
152
- minimum=1,
153
- maximum=50,
154
- step=1,
155
- value=10,
156
- label="Decoder Inference Step",
157
- )
158
- text2image_predict = gr.Button(value="Generate Image")
159
-
160
- output_image_base64 = gr.Text()
161
-
162
- text2image_predict.click(
163
- fn=generate_images,
164
- inputs=[
165
- secret_token,
166
- text2image_prompt,
167
- text2image_negative_prompt,
168
- text2image_height,
169
- text2image_width,
170
- text2image_guidance_scale,
171
- text2image_seed,
172
- text2image_prior_inference_step,
173
- text2image_decoder_inference_step
174
- ],
175
- outputs=output_image_base64,
176
- api_name='run',
177
- )