# --- Imports
import io
import os
import shutil
from zipfile import ZipFile

import cv2
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from rembg import remove
from sklearn.cluster import KMeans


# --- Dominant color extractor
def get_dominant_colors(image_path, k=3):
    """Return the ``k`` dominant colors of the image stored at *image_path*.

    The image is converted to RGB and downscaled 10x (never below 1x1 px) to
    keep clustering cheap, then its pixels are clustered with k-means.

    Args:
        image_path: Path of an image file Pillow can open.
        k: Number of colors to extract. It is capped at the number of sampled
            pixels, because k-means cannot form more clusters than samples.

    Returns:
        ``(rgb_colors, hex_colors)``: ``rgb_colors`` is a list of ``[r, g, b]``
        int lists (cluster centres rounded to the nearest 8-bit value) and
        ``hex_colors`` the matching ``"#rrggbb"`` strings, in cluster order.
    """
    with Image.open(image_path) as src:
        rgb_image = src.convert("RGB")
    small = rgb_image.resize(
        (max(1, rgb_image.width // 10), max(1, rgb_image.height // 10))
    )
    pixels = np.asarray(small).reshape((-1, 3))

    n_clusters = min(k, len(pixels))
    # Fixed seed and explicit n_init: the same sheet always yields the same
    # palette, and sklearn does not warn about its changing n_init default.
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
    kmeans.fit(pixels)

    # Round (instead of truncating) the float centroids to valid 8-bit values.
    centres = np.clip(np.rint(kmeans.cluster_centers_), 0, 255).astype(int)
    rgb_colors = centres.tolist()
    hex_colors = ["#{:02x}{:02x}{:02x}".format(*color) for color in rgb_colors]
    return rgb_colors, hex_colors


# --- Create palette image
def create_color_palette_image(hex_colors, save_path):
    """Render *hex_colors* as a row of labelled swatches and save it to *save_path*.

    Each color becomes a 100x100 square with its hex code centred beneath it,
    on a white canvas with 20 px gaps between and around the swatches.
    """
    swatch_size = 100
    spacing = 20
    count = len(hex_colors)
    width = swatch_size * count + spacing * (count + 1)
    height = swatch_size + 60

    palette_image = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(palette_image)
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", 16)
    except OSError:
        # Font file not available on this machine: use Pillow's built-in font.
        font = ImageFont.load_default()

    y = 20
    for i, hex_color in enumerate(hex_colors):
        x = spacing + i * (swatch_size + spacing)
        draw.rectangle([x, y, x + swatch_size, y + swatch_size], fill=hex_color)
        # textbbox may report a non-zero left offset; centre the actual ink box.
        left, _, right, _ = draw.textbbox((0, 0), hex_color, font=font)
        text_x = x + (swatch_size - (left + right)) // 2
        text_y = y + swatch_size + 5
        draw.text((text_x, text_y), hex_color, fill="black", font=font)

    palette_image.save(save_path)


# --- Theme determiner
def determine_theme(rgb_colors):
    """Classify a palette into a coarse theme name.

    Themes are checked in the order listed below. The first theme that has
    any color in *rgb_colors* within ``atol=40`` per channel (``np.allclose``)
    of one of its reference colors wins. Returns ``'mixed'`` if none match.
    """
    color_ranges = {
        'pastel': [(230, 230, 230), (180, 180, 180)],
        'neon': [(255, 0, 255), (0, 255, 255)],
        'earthy': [(139, 69, 19), (128, 128, 0)],
        'neutral': [(211, 211, 211), (169, 169, 169)],
        'bold': [(255, 0, 0), (0, 0, 255)],
        'dark': [(0, 0, 0), (25, 25, 112)],
    }
    for theme_name, ref_colors in color_ranges.items():
        for color in rgb_colors:
            if any(np.allclose(color, ref, atol=40) for ref in ref_colors):
                return theme_name
    return 'mixed'


# --- Brightness checker
def is_valid_crop(roi):
    """Return True if the BGR crop's mean gray level is strictly between 30 and 240.

    Rejects crops that are almost entirely black or almost entirely white.
    """
    gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
    mean_brightness = np.mean(gray)
    return 30 < mean_brightness < 240


# --- Sticker extraction helpers
def _candidate_boxes(image, min_size=50, pad=10):
    """Yield padded ``(x1, y1, x2, y2)`` boxes around the external contours of a BGR sheet."""
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Inverted adaptive threshold: pixels darker than their local neighbourhood
    # become foreground (255), which outlines the stickers.
    thresh = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
    )
    # [-2] selects the contour list under both OpenCV 3 (3 return values)
    # and OpenCV 4 (2 return values).
    contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    img_h, img_w = image.shape[:2]
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        # Skip small specks and regions spanning the full sheet width/height.
        if min_size < w < img_w and min_size < h < img_h:
            yield (
                max(x - pad, 0),
                max(y - pad, 0),
                min(x + w + pad, img_w),
                min(y + h + pad, img_h),
            )


def _save_sticker(roi, path):
    """Strip the background from a BGR crop with rembg and write the result to *path*."""
    buffer = io.BytesIO()
    Image.fromarray(cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)).save(buffer, format="PNG")
    # Given bytes, rembg.remove returns the encoded cut-out image as bytes.
    output_bytes = remove(buffer.getvalue())
    with open(path, "wb") as out:
        out.write(output_bytes)


def _zip_directory(src_dir, zip_path):
    """Write every file under *src_dir* into *zip_path*, stored by bare file name."""
    with ZipFile(zip_path, 'w') as zipf:
        for root, _, files in os.walk(src_dir):
            for name in sorted(files):
                zipf.write(os.path.join(root, name), arcname=name)


# --- Main sticker extraction
def extract_stickers(image_pil):
    """Split an uploaded sticker sheet into individual background-free stickers.

    Pipeline:
    1. Save a copy of the sheet and compute its dominant-color palette.
    2. Find candidate regions via adaptive thresholding and external contours.
    3. Drop regions that are too small, span the whole sheet, or are nearly
       black or white.
    4. Remove each remaining region's background with rembg.
    5. Zip everything up.

    Args:
        image_pil: The uploaded sheet as a PIL image, or ``None`` if nothing
            was uploaded.

    Returns:
        Path of ``extracted_stickers.zip`` (in the working directory). It holds
        the sheet copy, ``palette.png`` and one ``sticker_N.png`` per sticker.

    Raises:
        gr.Error: If no image was provided.
    """
    if image_pil is None:
        raise gr.Error("Please upload a sticker sheet image first.")

    temp_dir = "temp_stickers"
    # Start from an empty scratch directory. Otherwise stickers left over from
    # a previous (larger) upload would be zipped together with this run's.
    shutil.rmtree(temp_dir, ignore_errors=True)
    os.makedirs(temp_dir, exist_ok=True)

    # Normalise to 3-channel RGB so RGBA / palette / grayscale uploads work too.
    sheet_rgb = image_pil.convert("RGB")
    image = cv2.cvtColor(np.array(sheet_rgb), cv2.COLOR_RGB2BGR)

    sheet_path = os.path.join(temp_dir, "uploaded_image.png")
    sheet_rgb.save(sheet_path)

    rgb_colors, hex_colors = get_dominant_colors(sheet_path)
    # NOTE(review): the theme is computed but not yet surfaced in the output.
    theme = determine_theme(rgb_colors)  # noqa: F841
    create_color_palette_image(hex_colors, os.path.join(temp_dir, "palette.png"))

    sticker_count = 0
    for x1, y1, x2, y2 in _candidate_boxes(image):
        roi = image[y1:y2, x1:x2]
        if not is_valid_crop(roi):
            continue
        sticker_count += 1
        _save_sticker(roi, os.path.join(temp_dir, f"sticker_{sticker_count}.png"))

    zip_filename = "extracted_stickers.zip"
    _zip_directory(temp_dir, zip_filename)
    return zip_filename


# --- Gradio UI
interface = gr.Interface(
    fn=extract_stickers,
    inputs=gr.Image(type="pil", label="Upload Sticker Sheet (PNG)"),
    outputs=gr.File(label="Download Extracted Stickers (ZIP)"),
    title="Sticker Extractor ✂️",
    description="Upload your PNG sticker sheet to extract individual stickers. Click the button below to begin!",
    theme=gr.themes.Base(primary_hue="pink", font=["Fira Code", "monospace"]),
)

interface.launch()