Spaces:
Runtime error
Runtime error
update
Browse files- app.py +17 -11
- inference.py +6 -4
app.py
CHANGED
|
@@ -131,7 +131,11 @@ def create_inference_demo(func: inference_fn) -> gr.Blocks:
|
|
| 131 |
maximum=10.,
|
| 132 |
step=1,
|
| 133 |
value=10)
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
run_button = gr.Button('Generate')
|
| 136 |
|
| 137 |
# gr.Markdown('''
|
|
@@ -146,23 +150,25 @@ def create_inference_demo(func: inference_fn) -> gr.Blocks:
|
|
| 146 |
# inputs=None,
|
| 147 |
# outputs=weight_name)
|
| 148 |
prompt.submit(fn=func,
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
|
|
|
| 157 |
run_button.click(fn=func,
|
| 158 |
inputs=[
|
| 159 |
model_id,
|
| 160 |
prompt,
|
| 161 |
num_samples,
|
| 162 |
guidance_scale,
|
|
|
|
| 163 |
],
|
| 164 |
-
|
| 165 |
-
|
| 166 |
return demo
|
| 167 |
|
| 168 |
|
|
|
|
| 131 |
maximum=10.,
|
| 132 |
step=1,
|
| 133 |
value=10)
|
| 134 |
+
ddim_steps = gr.Slider(label='Number of DDIM Sampling Steps',
|
| 135 |
+
minimum=10,
|
| 136 |
+
maximum=100,
|
| 137 |
+
step=1,
|
| 138 |
+
value=50)
|
| 139 |
run_button = gr.Button('Generate')
|
| 140 |
|
| 141 |
# gr.Markdown('''
|
|
|
|
| 150 |
# inputs=None,
|
| 151 |
# outputs=weight_name)
|
| 152 |
prompt.submit(fn=func,
|
| 153 |
+
inputs=[
|
| 154 |
+
model_id,
|
| 155 |
+
prompt,
|
| 156 |
+
num_samples,
|
| 157 |
+
guidance_scale,
|
| 158 |
+
ddim_steps,
|
| 159 |
+
],
|
| 160 |
+
outputs=result,
|
| 161 |
+
queue=False)
|
| 162 |
run_button.click(fn=func,
|
| 163 |
inputs=[
|
| 164 |
model_id,
|
| 165 |
prompt,
|
| 166 |
num_samples,
|
| 167 |
guidance_scale,
|
| 168 |
+
ddim_steps,
|
| 169 |
],
|
| 170 |
+
outputs=result,
|
| 171 |
+
queue=False)
|
| 172 |
return demo
|
| 173 |
|
| 174 |
|
inference.py
CHANGED
|
@@ -46,12 +46,14 @@ def inference_fn(
|
|
| 46 |
prompt: str,
|
| 47 |
num_samples: int,
|
| 48 |
guidance_scale: float,
|
|
|
|
| 49 |
) -> PIL.Image.Image:
|
| 50 |
|
| 51 |
# create inference pipeline
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
|
|
|
| 55 |
# make directory to save images
|
| 56 |
image_root_folder = os.path.join('experiments', model_id, 'inference')
|
| 57 |
os.makedirs(image_root_folder, exist_ok = True)
|
|
@@ -80,7 +82,7 @@ def inference_fn(
|
|
| 80 |
os.makedirs(image_folder, exist_ok = True)
|
| 81 |
|
| 82 |
# batch generation
|
| 83 |
-
images = pipe(prompt, num_inference_steps=
|
| 84 |
|
| 85 |
# save generated images
|
| 86 |
for idx, image in enumerate(images):
|
|
|
|
| 46 |
prompt: str,
|
| 47 |
num_samples: int,
|
| 48 |
guidance_scale: float,
|
| 49 |
+
ddim_steps: int,
|
| 50 |
) -> PIL.Image.Image:
|
| 51 |
|
| 52 |
# create inference pipeline
|
| 53 |
+
if torch.cuda.is_available():
|
| 54 |
+
pipe = StableDiffusionPipeline.from_pretrained(os.path.join('experiments', model_id),torch_dtype=torch.float16).to('cuda')
|
| 55 |
+
else:
|
| 56 |
+
pipe = StableDiffusionPipeline.from_pretrained(os.path.join('experiments', model_id)).to('cpu')
|
| 57 |
# make directory to save images
|
| 58 |
image_root_folder = os.path.join('experiments', model_id, 'inference')
|
| 59 |
os.makedirs(image_root_folder, exist_ok = True)
|
|
|
|
| 82 |
os.makedirs(image_folder, exist_ok = True)
|
| 83 |
|
| 84 |
# batch generation
|
| 85 |
+
images = pipe(prompt, num_inference_steps=ddim_steps, guidance_scale=guidance_scale, num_images_per_prompt=num_samples).images
|
| 86 |
|
| 87 |
# save generated images
|
| 88 |
for idx, image in enumerate(images):
|