Spaces:
Runtime error
Runtime error
Commit
·
f797ad8
1
Parent(s):
7b90989
Create sketch_helper.py
Browse files- sketch_helper.py +38 -0
sketch_helper.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
from PIL import Image
|
| 4 |
+
|
| 5 |
+
def get_high_freq_colors(image):
|
| 6 |
+
im = image.getcolors(maxcolors=1024*1024)
|
| 7 |
+
sorted_colors = sorted(im, key=lambda x: x[0], reverse=True)
|
| 8 |
+
|
| 9 |
+
freqs = [c[0] for c in sorted_colors]
|
| 10 |
+
mean_freq = sum(freqs) / len(freqs)
|
| 11 |
+
|
| 12 |
+
high_freq_colors = [c for c in sorted_colors if c[0] > max(2, mean_freq)] # Ignore colors that occur very few times (less than 2) or less than half the average frequency
|
| 13 |
+
return high_freq_colors
|
| 14 |
+
|
| 15 |
+
def color_quantization(image, n_colors):
|
| 16 |
+
# Get color histogram
|
| 17 |
+
hist, _ = np.histogramdd(image.reshape(-1, 3), bins=(256, 256, 256), range=((0, 256), (0, 256), (0, 256)))
|
| 18 |
+
# Get most frequent colors
|
| 19 |
+
colors = np.argwhere(hist > 0)
|
| 20 |
+
colors = colors[np.argsort(hist[colors[:, 0], colors[:, 1], colors[:, 2]])[::-1]]
|
| 21 |
+
colors = colors[:n_colors]
|
| 22 |
+
# Replace each pixel with the closest color
|
| 23 |
+
dists = np.sum((image.reshape(-1, 1, 3) - colors.reshape(1, -1, 3))**2, axis=2)
|
| 24 |
+
labels = np.argmin(dists, axis=1)
|
| 25 |
+
return colors[labels].reshape((image.shape[0], image.shape[1], 3)).astype(np.uint8)
|
| 26 |
+
|
| 27 |
+
def create_binary_matrix(img_arr, target_color):
|
| 28 |
+
# Create mask of pixels with target color
|
| 29 |
+
mask = np.all(img_arr == target_color, axis=-1)
|
| 30 |
+
|
| 31 |
+
# Convert mask to binary matrix
|
| 32 |
+
binary_matrix = mask.astype(int)
|
| 33 |
+
from datetime import datetime
|
| 34 |
+
binary_file_name = f'mask-{datetime.now().timestamp()}.png'
|
| 35 |
+
cv2.imwrite(binary_file_name, binary_matrix * 255)
|
| 36 |
+
|
| 37 |
+
#binary_matrix = torch.from_numpy(binary_matrix).unsqueeze(0).unsqueeze(0)
|
| 38 |
+
return binary_file_name
|