Spaces:
Runtime error
Runtime error
Apply superresolution with Real-ESRGAN
Browse files
app.py
CHANGED
|
@@ -60,6 +60,8 @@ def create_advanced_demo(model: Model) -> gr.Blocks:
|
|
| 60 |
step=1,
|
| 61 |
value=1234,
|
| 62 |
label='Seed')
|
|
|
|
|
|
|
| 63 |
run_button = gr.Button('Run')
|
| 64 |
with gr.Column():
|
| 65 |
with gr.Tabs():
|
|
@@ -80,6 +82,7 @@ def create_advanced_demo(model: Model) -> gr.Blocks:
|
|
| 80 |
num_steps,
|
| 81 |
randomize_seed,
|
| 82 |
seed,
|
|
|
|
| 83 |
],
|
| 84 |
outputs=[
|
| 85 |
result,
|
|
|
|
| 60 |
step=1,
|
| 61 |
value=1234,
|
| 62 |
label='Seed')
|
| 63 |
+
superresolve = gr.Checkbox(value=False,
|
| 64 |
+
label='Superresolve')
|
| 65 |
run_button = gr.Button('Run')
|
| 66 |
with gr.Column():
|
| 67 |
with gr.Tabs():
|
|
|
|
| 82 |
num_steps,
|
| 83 |
randomize_seed,
|
| 84 |
seed,
|
| 85 |
+
superresolve,
|
| 86 |
],
|
| 87 |
outputs=[
|
| 88 |
result,
|
model.py
CHANGED
|
@@ -6,6 +6,7 @@ import random
|
|
| 6 |
import sys
|
| 7 |
import tempfile
|
| 8 |
|
|
|
|
| 9 |
import imageio
|
| 10 |
import numpy as np
|
| 11 |
import PIL.Image
|
|
@@ -44,6 +45,8 @@ class Model:
|
|
| 44 |
self.scheduler_type)
|
| 45 |
self.rng = random.Random()
|
| 46 |
|
|
|
|
|
|
|
| 47 |
@staticmethod
|
| 48 |
def _load_pipeline(model_name: str,
|
| 49 |
scheduler_type: str) -> DiffusionPipeline:
|
|
@@ -140,17 +143,29 @@ class Model:
|
|
| 140 |
writer.close()
|
| 141 |
|
| 142 |
logger.info('--- done ---')
|
| 143 |
-
return res, out_file.name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
def run(self, model_name: str, scheduler_type: str, num_steps: int,
|
| 146 |
-
randomize_seed: bool,
|
| 147 |
-
|
| 148 |
self.set_pipeline(model_name, scheduler_type)
|
| 149 |
if scheduler_type == 'PNDM':
|
| 150 |
num_steps = max(4, min(num_steps, 100))
|
| 151 |
if randomize_seed:
|
| 152 |
seed = self.rng.randint(0, 100000)
|
| 153 |
res, filename = self.generate_with_video(seed, num_steps)
|
|
|
|
|
|
|
| 154 |
return res, seed, filename
|
| 155 |
|
| 156 |
@staticmethod
|
|
@@ -169,4 +184,5 @@ class Model:
|
|
| 169 |
self.set_pipeline(self.MODEL_NAMES[0], 'DDIM')
|
| 170 |
seed = self.rng.randint(0, 1000000)
|
| 171 |
images = self.generate(seed, num_steps=10, num_images=4)
|
|
|
|
| 172 |
return self.to_grid(images, 2)
|
|
|
|
| 6 |
import sys
|
| 7 |
import tempfile
|
| 8 |
|
| 9 |
+
import gradio as gr
|
| 10 |
import imageio
|
| 11 |
import numpy as np
|
| 12 |
import PIL.Image
|
|
|
|
| 45 |
self.scheduler_type)
|
| 46 |
self.rng = random.Random()
|
| 47 |
|
| 48 |
+
self.real_esrgan = gr.Interface.load('spaces/hysts/Real-ESRGAN-anime')
|
| 49 |
+
|
| 50 |
@staticmethod
|
| 51 |
def _load_pipeline(model_name: str,
|
| 52 |
scheduler_type: str) -> DiffusionPipeline:
|
|
|
|
| 143 |
writer.close()
|
| 144 |
|
| 145 |
logger.info('--- done ---')
|
| 146 |
+
return PIL.Image.fromarray(res), out_file.name
|
| 147 |
+
|
| 148 |
+
def superresolve(self, image: PIL.Image.Image) -> PIL.Image.Image:
|
| 149 |
+
logger.info('--- superresolve ---')
|
| 150 |
+
|
| 151 |
+
with tempfile.NamedTemporaryFile(suffix='.png') as f:
|
| 152 |
+
image.save(f.name)
|
| 153 |
+
out_file = self.real_esrgan(f.name)
|
| 154 |
+
|
| 155 |
+
logger.info('--- done ---')
|
| 156 |
+
return PIL.Image.open(out_file)
|
| 157 |
|
| 158 |
def run(self, model_name: str, scheduler_type: str, num_steps: int,
|
| 159 |
+
randomize_seed: bool, seed: int,
|
| 160 |
+
superresolve: bool) -> tuple[PIL.Image.Image, int, str]:
|
| 161 |
self.set_pipeline(model_name, scheduler_type)
|
| 162 |
if scheduler_type == 'PNDM':
|
| 163 |
num_steps = max(4, min(num_steps, 100))
|
| 164 |
if randomize_seed:
|
| 165 |
seed = self.rng.randint(0, 100000)
|
| 166 |
res, filename = self.generate_with_video(seed, num_steps)
|
| 167 |
+
if superresolve:
|
| 168 |
+
res = self.superresolve(res)
|
| 169 |
return res, seed, filename
|
| 170 |
|
| 171 |
@staticmethod
|
|
|
|
| 184 |
self.set_pipeline(self.MODEL_NAMES[0], 'DDIM')
|
| 185 |
seed = self.rng.randint(0, 1000000)
|
| 186 |
images = self.generate(seed, num_steps=10, num_images=4)
|
| 187 |
+
images = [self.superresolve(image) for image in images]
|
| 188 |
return self.to_grid(images, 2)
|