Spaces:
Runtime error
Runtime error
Commit
·
527bf99
1
Parent(s):
d64b071
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,39 +2,11 @@ import gradio as gr
|
|
| 2 |
import numpy as np
|
| 3 |
import cv2
|
| 4 |
from PIL import Image
|
| 5 |
-
|
|
|
|
| 6 |
MAX_COLORS = 12
|
| 7 |
|
| 8 |
-
|
| 9 |
-
im = image.getcolors(maxcolors=1024*1024)
|
| 10 |
-
sorted_colors = sorted(im, key=lambda x: x[0], reverse=True)
|
| 11 |
-
|
| 12 |
-
freqs = [c[0] for c in sorted_colors]
|
| 13 |
-
mean_freq = sum(freqs) / len(freqs)
|
| 14 |
-
|
| 15 |
-
high_freq_colors = [c for c in sorted_colors if c[0] > max(2, mean_freq/3)] # Ignore colors that occur very few times (less than 2) or less than half the average frequency
|
| 16 |
-
return high_freq_colors
|
| 17 |
-
|
| 18 |
-
def color_quantization(image, n_colors):
|
| 19 |
-
# Get color histogram
|
| 20 |
-
hist, _ = np.histogramdd(image.reshape(-1, 3), bins=(256, 256, 256), range=((0, 256), (0, 256), (0, 256)))
|
| 21 |
-
# Get most frequent colors
|
| 22 |
-
colors = np.argwhere(hist > 0)
|
| 23 |
-
colors = colors[np.argsort(hist[colors[:, 0], colors[:, 1], colors[:, 2]])[::-1]]
|
| 24 |
-
colors = colors[:n_colors]
|
| 25 |
-
# Replace each pixel with the closest color
|
| 26 |
-
dists = np.sum((image.reshape(-1, 1, 3) - colors.reshape(1, -1, 3))**2, axis=2)
|
| 27 |
-
labels = np.argmin(dists, axis=1)
|
| 28 |
-
return colors[labels].reshape((image.shape[0], image.shape[1], 3)).astype(np.uint8)
|
| 29 |
-
|
| 30 |
-
def create_binary_matrix(img_arr, target_color):
|
| 31 |
-
print(target_color)
|
| 32 |
-
# Create mask of pixels with target color
|
| 33 |
-
mask = np.all(img_arr == target_color, axis=-1)
|
| 34 |
-
|
| 35 |
-
# Convert mask to binary matrix
|
| 36 |
-
binary_matrix = mask.astype(int)
|
| 37 |
-
return binary_matrix
|
| 38 |
|
| 39 |
def process_sketch(image, binary_matrixes):
|
| 40 |
high_freq_colors = get_high_freq_colors(image)
|
|
@@ -43,13 +15,12 @@ def process_sketch(image, binary_matrixes):
|
|
| 43 |
im2arr = color_quantization(im2arr, n_colors=how_many_colors)
|
| 44 |
|
| 45 |
colors_fixed = []
|
| 46 |
-
for color in high_freq_colors
|
| 47 |
-
r = color[1]
|
| 48 |
-
g
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
colors_fixed.append(gr.update(value=f'<div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div>'))
|
| 53 |
visibilities = []
|
| 54 |
colors = []
|
| 55 |
for n in range(MAX_COLORS):
|
|
@@ -62,8 +33,15 @@ def process_sketch(image, binary_matrixes):
|
|
| 62 |
|
| 63 |
def process_generation(binary_matrixes, master_prompt, *prompts):
|
| 64 |
clipped_prompts = prompts[:len(binary_matrixes)]
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
css = '''
|
| 69 |
#color-bg{display:flex;justify-content: center;align-items: center;}
|
|
@@ -72,15 +50,11 @@ css = '''
|
|
| 72 |
'''
|
| 73 |
def update_css(aspect):
|
| 74 |
if(aspect=='Square'):
|
| 75 |
-
|
| 76 |
-
height = 512
|
| 77 |
elif(aspect == 'Horizontal'):
|
| 78 |
-
|
| 79 |
-
height = 512
|
| 80 |
elif(aspect=='Vertical'):
|
| 81 |
-
|
| 82 |
-
height = 768
|
| 83 |
-
return gr.update(value=f"<style>#main-image{{width: {width}px}} .fixed-height{{height: {height}px !important}}</style>")
|
| 84 |
|
| 85 |
with gr.Blocks(css=css) as demo:
|
| 86 |
binary_matrixes = gr.State([])
|
|
@@ -89,11 +63,13 @@ with gr.Blocks(css=css) as demo:
|
|
| 89 |
''')
|
| 90 |
with gr.Row():
|
| 91 |
with gr.Box(elem_id="main-image"):
|
| 92 |
-
with gr.Row():
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
| 97 |
|
| 98 |
prompts = []
|
| 99 |
colors = []
|
|
@@ -111,9 +87,8 @@ with gr.Blocks(css=css) as demo:
|
|
| 111 |
gr.Markdown('''
|
| 112 |

|
| 113 |
''')
|
| 114 |
-
css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
|
| 115 |
-
|
| 116 |
-
aspect.change(update_css, inputs=aspect, outputs=css_height)
|
| 117 |
button_run.click(process_sketch, inputs=[image, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors])
|
| 118 |
final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)
|
| 119 |
-
demo.launch()
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import cv2
|
| 4 |
from PIL import Image
|
| 5 |
+
from region_control import MultiDiffusion, get_views, preprocess_mask
|
| 6 |
+
from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix
|
| 7 |
MAX_COLORS = 12
|
| 8 |
|
| 9 |
+
sd = MultiDiffusion("cuda", "2.1")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
def process_sketch(image, binary_matrixes):
|
| 12 |
high_freq_colors = get_high_freq_colors(image)
|
|
|
|
| 15 |
im2arr = color_quantization(im2arr, n_colors=how_many_colors)
|
| 16 |
|
| 17 |
colors_fixed = []
|
| 18 |
+
for color in high_freq_colors:
|
| 19 |
+
r, g, b = color[1]
|
| 20 |
+
if any(c != 255 for c in (r, g, b)):
|
| 21 |
+
binary_matrix = create_binary_matrix(im2arr, (r,g,b))
|
| 22 |
+
binary_matrixes.append(binary_matrix)
|
| 23 |
+
colors_fixed.append(gr.update(value=f'<div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div>'))
|
|
|
|
| 24 |
visibilities = []
|
| 25 |
colors = []
|
| 26 |
for n in range(MAX_COLORS):
|
|
|
|
| 33 |
|
| 34 |
def process_generation(binary_matrixes, master_prompt, *prompts):
|
| 35 |
clipped_prompts = prompts[:len(binary_matrixes)]
|
| 36 |
+
prompts = [master_prompt] + list(clipped_prompts)
|
| 37 |
+
neg_prompts = [""] * len(prompts)
|
| 38 |
+
fg_masks = torch.cat([preprocess_mask(mask_path, 512 // 8, 512 // 8, "cuda") for mask_path in binary_matrixes])
|
| 39 |
+
bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True)
|
| 40 |
+
bg_mask[bg_mask < 0] = 0
|
| 41 |
+
masks = torch.cat([bg_mask, fg_masks])
|
| 42 |
+
print(masks.size())
|
| 43 |
+
image = sd.generate(masks, prompts, neg_prompts, 512, 512, 50, bootstrapping=20)
|
| 44 |
+
return(image)
|
| 45 |
|
| 46 |
css = '''
|
| 47 |
#color-bg{display:flex;justify-content: center;align-items: center;}
|
|
|
|
| 50 |
'''
|
| 51 |
def update_css(aspect):
|
| 52 |
if(aspect=='Square'):
|
| 53 |
+
return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
|
|
|
|
| 54 |
elif(aspect == 'Horizontal'):
|
| 55 |
+
return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)]
|
|
|
|
| 56 |
elif(aspect=='Vertical'):
|
| 57 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
|
|
|
|
|
|
|
| 58 |
|
| 59 |
with gr.Blocks(css=css) as demo:
|
| 60 |
binary_matrixes = gr.State([])
|
|
|
|
| 63 |
''')
|
| 64 |
with gr.Row():
|
| 65 |
with gr.Box(elem_id="main-image"):
|
| 66 |
+
#with gr.Row():
|
| 67 |
+
image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512,512), brush_radius=45)
|
| 68 |
+
#image_horizontal = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(768,512), visible=False, brush_radius=45)
|
| 69 |
+
#image_vertical = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512, 768), visible=False, brush_radius=45)
|
| 70 |
+
#with gr.Row():
|
| 71 |
+
# aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
|
| 72 |
+
button_run = gr.Button("I've finished my sketch",elem_id="main_button", interactive=True)
|
| 73 |
|
| 74 |
prompts = []
|
| 75 |
colors = []
|
|
|
|
| 87 |
gr.Markdown('''
|
| 88 |

|
| 89 |
''')
|
| 90 |
+
#css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
|
| 91 |
+
#aspect.change(update_css, inputs=aspect, outputs=[image, image_horizontal, image_vertical])
|
|
|
|
| 92 |
button_run.click(process_sketch, inputs=[image, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors])
|
| 93 |
final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)
|
| 94 |
+
demo.launch(debug=True)
|