| """ | |
| Three-View-Style-Embedder - Inference Utilities | |
| Lazy loading for Hugging Face Spaces compatibility | |
| """ | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from torchvision import transforms | |
| import threading | |

# Shim for `spaces` so the module can also run locally without the package.
try:
    import spaces
except ImportError:
    class spaces:
        @staticmethod
        def GPU(func):
            return func
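# Note: the real `spaces.GPU` also supports a parameterized form such as
# `@spaces.GPU(duration=120)`; this shim only mirrors the bare-decorator form.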


def _import_gradio():
    try:
        import gradio as gr  # type: ignore
        return gr
    except Exception as e:
        raise RuntimeError(
            "Failed to import Gradio. This usually means you're running with the wrong Python interpreter "
            "(e.g., system Python instead of the workspace .venv) or you have incompatible package versions.\n"
            "Fix: run with the venv interpreter: .\\.venv\\Scripts\\python.exe app.py ...\n"
            "Or on Windows: run run.bat"
        ) from e


def _default_path(path_str: str) -> Path:
    return (Path(__file__).resolve().parent / path_str).resolve()
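# e.g. _default_path('yolov5_anime') resolves against this file's directory,
# so bundled assets are found regardless of the current working directory.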

from config import get_config
from model import ArtistStyleModel


class FaceEyeExtractor:
    def __init__(
        self,
        yolo_dir: Path,
        weights_path: Path,
        cascade_path: Path,
        device: str = 'cpu',
        imgsz: int = 640,
        conf: float = 0.5,
        iou: float = 0.5,
        eye_roi_frac: float = 0.70,
        eye_min_size: int = 12,
        eye_margin: float = 0.60,
        neighbors: int = 9,
        eye_fallback_to_face: bool = True,
    ):
        self.yolo_dir = Path(yolo_dir)
        self.weights_path = Path(weights_path)
        self.cascade_path = Path(cascade_path)
        self.device = device
        self.imgsz = imgsz
        self.conf = conf
        self.iou = iou
        self.eye_roi_frac = eye_roi_frac
        self.eye_min_size = eye_min_size
        self.eye_margin = eye_margin
        self.neighbors = neighbors
        self.eye_fallback_to_face = eye_fallback_to_face
        # No lock needed: Gradio calls handlers synchronously.
        self._yolo_model = None
        self._yolo_device = None
        self._stride = 32
        # cv2.CascadeClassifier is not picklable, so it is kept in thread-local
        # storage (see __getstate__/__setstate__ below).
        self._tl = threading.local()

    def __getstate__(self):
        # Drop the thread-local slot: the cascade it holds cannot be pickled.
        state = self.__dict__.copy()
        if "_tl" in state:
            del state["_tl"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._tl = threading.local()

    def _patch_torch_load_for_old_ckpt(self):
        # torch >= 2.6 defaults torch.load to weights_only=True with a strict
        # unpickler; allowlist the numpy globals old YOLOv5 checkpoints use.
        import torch as _torch
        import numpy as _np
        try:
            _torch.serialization.add_safe_globals([
                _np.core.multiarray._reconstruct,
                _np.ndarray,
            ])
        except Exception:
            # Older torch versions have no add_safe_globals; nothing to do.
            pass

    def _ensure_ready(self):
        # The cascade lives in thread-local storage (self._tl), not on self.
        cascade_ready = getattr(self._tl, 'cascade', None) is not None
        if self._yolo_model is not None and cascade_ready:
            return
        # Lazy import so the app can still run if OpenCV/YOLO deps are missing.
        import sys
        import cv2
        # Try to locate yolov5_anime if it is not exactly at yolo_dir.
        if not self.yolo_dir.exists():
            # Fallback: check the current working directory.
            cwd_yolo = Path("yolov5_anime").resolve()
            if cwd_yolo.exists():
                self.yolo_dir = cwd_yolo
            else:
                # Try relative to this file.
                file_yolo = Path(__file__).parent / "yolov5_anime"
                if file_yolo.exists():
                    self.yolo_dir = file_yolo
        if not self.yolo_dir.exists():
            raise RuntimeError(
                f"yolov5_anime directory not found. Tried: {self.yolo_dir}, "
                f"current dir: {Path.cwd()}, file dir: {Path(__file__).parent}"
            )
        # Add to sys.path if not already there.
        yolo_path_str = str(self.yolo_dir.resolve())
        if yolo_path_str not in sys.path:
            sys.path.insert(0, yolo_path_str)
        self._patch_torch_load_for_old_ckpt()
        import torch as _torch
        # Attempt the imports; failure usually means yolo_dir or its deps are missing.
        try:
            from models.experimental import attempt_load  # type: ignore
            from utils.torch_utils import select_device  # type: ignore
        except ImportError as e:
            raise RuntimeError(
                f"Failed to import YOLOv5 modules. Make sure the yolov5_anime directory exists at {self.yolo_dir}. "
                f"sys.path includes: {[p for p in sys.path if 'yolo' in p.lower()]}. "
                f"Original error: {e}"
            ) from e
        # Ensure YOLOv5 loads old .pt files even on torch 2.6+ (the weights_only default changed).
        orig_load = _torch.load

        def patched_load(*args, **kwargs):
            kwargs.setdefault('weights_only', False)
            return orig_load(*args, **kwargs)

        _torch.load = patched_load
        try:
            # For Spaces, run the detector on CPU to avoid CUDA init in the main process.
            detector_device = 'cpu' if self.device.startswith('cuda') else self.device
            self._yolo_device = select_device(detector_device)
            if not self.weights_path.exists():
                raise RuntimeError(f"YOLO weights not found: {self.weights_path}")
            self._yolo_model = attempt_load(str(self.weights_path), map_location=self._yolo_device)
            self._yolo_model.eval()
            self._stride = int(self._yolo_model.stride.max())
        finally:
            _torch.load = orig_load
        if not self.cascade_path.exists():
            raise RuntimeError(f"Cascade xml not found: {self.cascade_path}")
        cascade = cv2.CascadeClassifier(str(self.cascade_path))
        if cascade.empty():
            raise RuntimeError(f"cascade load failed: {self.cascade_path}")
        self._tl.cascade = cascade

    def _letterbox_compat(self, img0, new_shape, stride):
        from utils.datasets import letterbox  # type: ignore
        try:
            out = letterbox(img0, new_shape, stride=stride, auto=False)
        except TypeError:
            try:
                out = letterbox(img0, new_shape, auto=False)
            except TypeError:
                out = letterbox(img0, new_shape)
        return out[0]
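    # The TypeError ladder above exists because letterbox() signatures differ
    # across yolov5 forks; each fallback drops to an older calling convention.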

    def _detect_faces(self, rgb: np.ndarray) -> List[Tuple[int, int, int, int]]:
        self._ensure_ready()
        import cv2
        import torch as _torch
        from utils.general import non_max_suppression, scale_coords  # type: ignore
        img0 = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        h0, w0 = img0.shape[:2]
        # Round the inference size up to a multiple of the model stride.
        imgsz = int(np.ceil(self.imgsz / self._stride) * self._stride)
        img = self._letterbox_compat(img0, imgsz, self._stride)
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR -> RGB, HWC -> CHW
        img = np.ascontiguousarray(img)
        im = _torch.from_numpy(img).to(self._yolo_device)
        im = im.float() / 255.0
        if im.ndim == 3:
            im = im[None]
        with _torch.no_grad():
            pred = self._yolo_model(im)[0]
        pred = non_max_suppression(
            pred,
            conf_thres=self.conf,
            iou_thres=self.iou,
            classes=None,
            agnostic=False,
        )
        boxes: List[Tuple[int, int, int, int, float]] = []
        det = pred[0]
        if det is not None and len(det):
            # Map detections from letterboxed coords back to the original image.
            det[:, :4] = scale_coords((imgsz, imgsz), det[:, :4], (h0, w0)).round()
            for *xyxy, conf, _cls in det.tolist():
                x1, y1, x2, y2 = [int(v) for v in xyxy]
                boxes.append((x1, y1, x2, y2, float(conf)))
        # Return only coordinates.
        boxes_xyxy = [(b[0], b[1], b[2], b[3]) for b in boxes]
        return boxes_xyxy

    def _expand(self, box, margin, W, H):
        x1, y1, x2, y2 = box
        cx = (x1 + x2) / 2.0
        cy = (y1 + y2) / 2.0
        w = (x2 - x1) * (1 + margin)
        h = (y2 - y1) * (1 + margin)
        nx1 = int(round(cx - w / 2))
        ny1 = int(round(cy - h / 2))
        nx2 = int(round(cx + w / 2))
        ny2 = int(round(cy + h / 2))
        nx1 = max(0, min(W, nx1))
        ny1 = max(0, min(H, ny1))
        nx2 = max(0, min(W, nx2))
        ny2 = max(0, min(H, ny2))
        return nx1, ny1, nx2, ny2
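    # Example: _expand((10, 10, 30, 30), margin=0.60, W=100, H=100) keeps the
    # center at (20, 20), grows the 20x20 box to 32x32, and returns
    # (4, 4, 36, 36); results are clamped to the image bounds.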

    def _pre(self, gray):
        # Light denoise, then boost local contrast (CLAHE) for the cascade.
        import cv2
        gray = cv2.GaussianBlur(gray, (3, 3), 0)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        return clahe.apply(gray)

    def _shrink_for_eye(self, img, limit=900):
        # Downscale so the longer side is at most `limit`; returns (image, scale).
        import cv2
        h, w = img.shape[:2]
        m = max(h, w)
        if m <= limit:
            return img, 1.0
        s = limit / float(m)
        nh, nw = int(h * s), int(w * s)
        small = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_AREA)
        return small, s

    def _detect_eyes_in_roi(self, rgb_roi):
        import cv2
        gray = cv2.cvtColor(rgb_roi, cv2.COLOR_RGB2GRAY)
        proc = self._pre(gray)
        H, W = proc.shape[:2]
        min_side = max(1, min(W, H))
        dyn_min = int(0.07 * min_side)
        min_sz = max(8, int(self.eye_min_size), dyn_min)
        cascade = getattr(self._tl, 'cascade', None)
        if cascade is None:
            cascade = cv2.CascadeClassifier(str(self.cascade_path))
            if cascade.empty():
                raise RuntimeError(f"cascade load failed: {self.cascade_path}")
            self._tl.cascade = cascade
        raw = cascade.detectMultiScale(
            proc,
            scaleFactor=1.15,
            minNeighbors=self.neighbors,
            minSize=(min_sz, min_sz),
            flags=cv2.CASCADE_SCALE_IMAGE,
        )
        try:
            arr = np.asarray(raw if not isinstance(raw, tuple) else raw[0])
        except Exception:
            arr = np.empty((0, 4), dtype=int)
        if arr.size == 0:
            return []
        if arr.ndim == 1:
            arr = arr.reshape(1, -1)
        boxes = []
        for r in arr:
            x, y, w, h = [int(v) for v in r[:4]]
            if w <= 0 or h <= 0:
                continue
            boxes.append((x, y, x + w, y + h))
        return boxes

    def _best_pair(self, boxes, W, H):
        import itertools
        clean = [(int(b[0]), int(b[1]), int(b[2]), int(b[3])) for b in boxes]
        if len(clean) < 2:
            return []

        def cxcy(b):
            x1, y1, x2, y2 = b
            return (x1 + x2) / 2.0, (y1 + y2) / 2.0

        def area(b):
            x1, y1, x2, y2 = b
            return max(1, (x2 - x1) * (y2 - y1))

        best = None
        best_s = 1e9
        for b1, b2 in itertools.combinations(clean, 2):
            c1x, c1y = cxcy(b1)
            c2x, c2y = cxcy(b2)
            a1, a2 = area(b1), area(b2)
            horiz = 0.0 if c1x < c2x else 0.5
            y_aln = abs(c1y - c2y) / max(1.0, H)
            szsim = abs(a1 - a2) / float(max(a1, a2))
            gap = abs(c2x - c1x) / max(1.0, W)
            if 0.05 <= gap <= 0.5:
                gap_pen = 0.0
            else:
                gap_pen = 0.5 * ((0.5 + abs(gap - 0.05) * 10) if gap < 0.05 else (gap - 0.5) * 2.0)
            mean_y = (c1y + c2y) / 2.0 / max(1.0, H)
            upper = 0.3 * max(0.0, (mean_y - 0.67) * 2.0)
            s = y_aln + szsim + gap_pen + upper + horiz
            if s < best_s:
                best_s = s
                best = (b1, b2)
        if best is None:
            return []
        b1, b2 = best
        left, right = (b1, b2) if (b1[0] + b1[2]) <= (b2[0] + b2[2]) else (b2, b1)
        return [("left", left), ("right", right)]

    def extract_face(self, full_image: Image.Image) -> Optional[Image.Image]:
        rgb = np.array(full_image.convert('RGB'))
        boxes = self._detect_faces(rgb)
        if not boxes:
            return None

        # Choose the largest face.
        def area(b):
            x1, y1, x2, y2 = b
            return max(0, x2 - x1) * max(0, y2 - y1)

        x1, y1, x2, y2 = max(boxes, key=area)
        H, W = rgb.shape[:2]
        x1 = max(0, min(W, x1))
        x2 = max(0, min(W, x2))
        y1 = max(0, min(H, y1))
        y2 = max(0, min(H, y2))
        if x2 <= x1 or y2 <= y1:
            return None
        face = rgb[y1:y2, x1:x2]
        return Image.fromarray(face)

    def extract_eye_region(self, face_image: Image.Image) -> Optional[Image.Image]:
        # Ensure ready (Gradio runs synchronously, so thread-safety is not critical).
        self._ensure_ready()
        rgb_face = np.array(face_image.convert('RGB'))
        H, W = rgb_face.shape[:2]
        if H < 2 or W < 2:
            return None
        roi_h = int(H * float(self.eye_roi_frac))
        roi_h = max(1, min(H, roi_h))
        roi = rgb_face[0:roi_h, :]
        roi_small, s_roi = self._shrink_for_eye(roi, limit=512)
        face_small, s_face = self._shrink_for_eye(rgb_face, limit=768)
        eyes_roi = self._detect_eyes_in_roi(roi_small)
        # Map detections back to full-resolution coordinates.
        eyes_roi = [
            (int(x1 / s_roi), int(y1 / s_roi), int(x2 / s_roi), int(y2 / s_roi))
            for (x1, y1, x2, y2) in eyes_roi
        ]
        labs = self._best_pair(eyes_roi, W, roi_h)
        origin = 'roi' if labs else None
        eyes_full = []
        if self.eye_fallback_to_face and (not labs or len(labs) < 2):
            eyes_full = self._detect_eyes_in_roi(face_small)
            eyes_full = [
                (int(x1 / s_face), int(y1 / s_face), int(x2 / s_face), int(y2 / s_face))
                for (x1, y1, x2, y2) in eyes_full
            ]
            if len(eyes_full) >= 2:
                labs = self._best_pair(eyes_full, W, H)
                origin = 'face' if labs else origin
        if not labs:
            cand = eyes_roi
            cand_origin = 'roi'
            if self.eye_fallback_to_face and len(eyes_full) >= 1:
                cand = eyes_full
                cand_origin = 'face'
            if len(cand) >= 2:
                top2 = sorted(cand, key=lambda b: (b[2] - b[0]) * (b[3] - b[1]), reverse=True)[:2]
                top2 = sorted(top2, key=lambda b: (b[0] + b[2]))
                labs = [("left", top2[0]), ("right", top2[1])]
                origin = cand_origin
            elif len(cand) == 1:
                labs = [("left", cand[0])]
                origin = cand_origin
        if not labs:
            return None
        boxes = [box for _label, box in labs]
        if len(boxes) >= 2:
            boxes = sorted(boxes, key=lambda b: (b[0] + b[2]))[:2]
        src_img = roi if origin == 'roi' else rgb_face
        bound_h = roi_h if origin == 'roi' else H
        # Extract only one eye (prefer the left eye) as a square crop.
        target_box = boxes[0]  # boxes are sorted left-to-right
        bx1, by1, bx2, by2 = target_box
        # Expand with margin.
        ex1, ey1, ex2, ey2 = self._expand((bx1, by1, bx2, by2), self.eye_margin, W, bound_h)
        # Make it square by growing the smaller dimension to match the larger.
        ew = ex2 - ex1
        eh = ey2 - ey1
        if ew > eh:
            # Width is larger: expand height.
            diff = ew - eh
            ey1 = max(0, ey1 - diff // 2)
            ey2 = min(bound_h, ey2 + (diff - diff // 2))
        elif eh > ew:
            # Height is larger: expand width.
            diff = eh - ew
            ex1 = max(0, ex1 - diff // 2)
            ex2 = min(W, ex2 + (diff - diff // 2))
        crop = src_img[ey1:ey2, ex1:ex2]
        if crop.size == 0 or min(crop.shape[0], crop.shape[1]) < self.eye_min_size:
            return None
        return Image.fromarray(crop)


class StyleEmbedderApp:
    """Web UI app - lazy loading for Spaces compatibility."""

    def __init__(
        self,
        checkpoint_path: str,
        embeddings_path: str,
        device: str = 'cuda',
        yolo_dir: Optional[str] = None,
        yolo_weights: Optional[str] = None,
        eyes_cascade: Optional[str] = None,
        detector_device: str = 'cpu',
    ):
        # Store paths only - load nothing yet, to avoid CUDA init in the main process.
        self.checkpoint_path = checkpoint_path
        self.embeddings_path = embeddings_path
        self.requested_device = device
        self.detector_device = detector_device
        # The model is loaded lazily inside a @spaces.GPU decorated function.
        self._model = None
        self._model_loading = False  # flag to prevent concurrent loading
        self._embeddings_loaded = False
        self._artist_names = None
        self._embeddings = None
        # Face/eye extractor - lazy-loaded to avoid pickle issues with cv2.CascadeClassifier.
        self._extractor = None
        self._extractor_yolo_dir = yolo_dir
        self._extractor_yolo_weights = yolo_weights
        self._extractor_eyes_cascade = eyes_cascade
        # Transform (no CUDA needed).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _ensure_model_loaded(self):
        """Lazy-load the model - only called inside a @spaces.GPU decorated function."""
        if self._model is not None:
            return
        # Simple double-check pattern (Gradio runs synchronously, so a race is unlikely).
        if self._model_loading:
            # Wait for the in-flight load to complete.
            import time
            while self._model_loading and self._model is None:
                time.sleep(0.01)
            return
        if self._model is not None:
            return
        self._model_loading = True
        try:
            print("Loading model (lazy)...")
            # Load the checkpoint on CPU first.
            checkpoint = torch.load(self.checkpoint_path, map_location='cpu')
            config = get_config()
            self._model = ArtistStyleModel(
                num_classes=len(checkpoint['artist_to_idx']),
                embedding_dim=config.model.embedding_dim,
                hidden_dim=config.model.hidden_dim,
            )
            self._model.load_state_dict(checkpoint['model_state_dict'])
            # Determine the device - in a @spaces.GPU context, CUDA should be available.
            if self.requested_device.startswith('cuda') and torch.cuda.is_available():
                device = torch.device(self.requested_device)
                # Reduce VRAM: keep weights in FP16 on CUDA.
                self._model = self._model.to(dtype=torch.float16)
            else:
                device = torch.device('cpu')
            self._model = self._model.to(device)
            self._model.eval()
            self.device = device
            self.embedding_dim = config.model.embedding_dim
            print("Model loaded successfully")
        finally:
            self._model_loading = False

    def _ensure_embeddings_loaded(self):
        """Lazy-load embeddings - no CUDA needed."""
        if self._embeddings_loaded:
            return
        print("Loading embeddings...")
        data = np.load(self.embeddings_path)
        self._artist_names = data['artist_names'].tolist()
        self._embeddings = data['embeddings']
        self._embeddings_loaded = True
        print(f"Loaded {len(self._artist_names)} artist embeddings")

    def preprocess_image(self, image: Optional[Image.Image]) -> Optional[torch.Tensor]:
        """Preprocess an image: flatten alpha onto white, then resize/normalize."""
        if image is None:
            return None
        try:
            if image.mode in ('RGBA', 'LA', 'P'):
                background = Image.new('RGB', image.size, (255, 255, 255))
                if image.mode == 'P':
                    image = image.convert('RGBA')
                if image.mode in ('RGBA', 'LA'):
                    background.paste(image, mask=image.split()[-1])
                    image = background
                else:
                    image = image.convert('RGB')
            else:
                image = image.convert('RGB')
            return self.transform(image)
        except Exception:
            return None

    def get_embedding(
        self,
        full_image: Image.Image,
        face_image: Optional[Image.Image] = None,
        eye_image: Optional[Image.Image] = None,
    ) -> np.ndarray:
        """Extract an embedding from the images - GPU lazy loading."""
        # Load the model on first call (inside the @spaces.GPU context).
        self._ensure_model_loaded()
        full_tensor = self.preprocess_image(full_image)
        if full_tensor is None:
            raise ValueError("Full image is required")
        full = full_tensor.unsqueeze(0).to(self.device)
        # Auto face/eye extraction if not provided.
        auto_face_image = face_image
        auto_eye_image = eye_image
        if auto_face_image is None or auto_eye_image is None:
            try:
                extractor = self._get_extractor()
                if auto_face_image is None:
                    auto_face_image = extractor.extract_face(full_image)
                if auto_eye_image is None:
                    # Prefer detecting eyes from the face crop if available.
                    if auto_face_image is not None:
                        auto_eye_image = extractor.extract_eye_region(auto_face_image)
            except Exception as e:
                # If the detector fails, proceed without the face/eye branches.
                print(f"[WARN] Auto face/eye extraction failed: {e}")
        face_tensor = self.preprocess_image(auto_face_image)
        if face_tensor is not None:
            face = face_tensor.unsqueeze(0).to(self.device)
            has_face = torch.tensor([True]).to(self.device)
        else:
            face = torch.zeros(1, 3, 224, 224).to(self.device)
            has_face = torch.tensor([False]).to(self.device)
        eye_tensor = self.preprocess_image(auto_eye_image)
        if eye_tensor is not None:
            eye = eye_tensor.unsqueeze(0).to(self.device)
            has_eye = torch.tensor([True]).to(self.device)
        else:
            eye = torch.zeros(1, 3, 224, 224).to(self.device)
            has_eye = torch.tensor([False]).to(self.device)
        with torch.cuda.amp.autocast(enabled=(self.device.type == 'cuda')):
            embedding = self._model.get_embeddings(full, face, eye, has_face, has_eye)
        # Keep the output float32 for downstream numpy similarity math.
        return embedding.squeeze(0).float().cpu().numpy()

    def find_similar_artists(
        self,
        query_embedding: np.ndarray,
        top_k: int = 10,
    ) -> List[Tuple[str, float]]:
        """Find artists with similar styles."""
        # Load embeddings if not already loaded.
        self._ensure_embeddings_loaded()
        query_norm = query_embedding / np.linalg.norm(query_embedding)
        embeddings_norm = self._embeddings / np.linalg.norm(self._embeddings, axis=1, keepdims=True)
        similarities = embeddings_norm @ query_norm
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [(self._artist_names[i], float(similarities[i])) for i in top_indices]
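    # The scores above are cosine similarities: both the query and the stored
    # embeddings are L2-normalized before the dot product, so values fall in
    # [-1, 1], with 1.0 meaning identical direction.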

    def _get_extractor(self):
        """Lazy load extractor to avoid pickle issues"""
        if self._extractor is None:
            self._extractor = FaceEyeExtractor(
                yolo_dir=_default_path('yolov5_anime') if self._extractor_yolo_dir is None else Path(self._extractor_yolo_dir),
                weights_path=_default_path('yolov5x_anime.pt') if self._extractor_yolo_weights is None else Path(self._extractor_yolo_weights),
                cascade_path=_default_path('anime-eyes-cascade.xml') if self._extractor_eyes_cascade is None else Path(self._extractor_eyes_cascade),
                device='cpu',  # always use CPU for the detector to avoid CUDA init
            )
        return self._extractor

    def _extract_crops_impl(self, full_image: Image.Image) -> Tuple[Optional[Image.Image], Optional[Image.Image], str]:
        """Auto-extract face and eye crops - internal implementation."""
        if full_image is None:
            return None, None, "❌ Please upload the full image first."
        try:
            extractor = self._get_extractor()
            face = extractor.extract_face(full_image)
            eye = None
            if face is not None:
                eye = extractor.extract_eye_region(face)
            status = "✅ Extraction complete:\n"
            status += f"- Face: {'found' if face else 'not found'}\n"
            status += f"- Eye: {'found' if eye else 'not found'}\n\n"
            if face is None:
                status += "💡 No face was detected. Please upload one manually."
            elif eye is None:
                status += "💡 No eye was detected. Please upload one manually."
            return face, eye, status
        except Exception as e:
            return None, None, f"❌ Extraction failed: {str(e)}"

    def extract_crops(self, full_image: Image.Image) -> Tuple[Optional[Image.Image], Optional[Image.Image], str]:
        """Auto-extract face and eye crops - Gradio wrapper."""
        # The extractor is created on demand to avoid pickle issues;
        # _ensure_ready() handles caching of the underlying models.
        return self._extract_crops_impl(full_image)

    def search(
        self,
        full_image: Image.Image,
        face_image: Optional[Image.Image],
        eye_image: Optional[Image.Image],
        top_k: int,
    ) -> str:
        """Run the search."""
        if full_image is None:
            return "❌ Please upload the full image."
        try:
            # Extract the embedding (face/eye crops are auto-extracted when missing).
            auto_extracted = False
            if face_image is None or eye_image is None:
                auto_extracted = True
            # This calls the @spaces.GPU decorated function.
            embedding = self.get_embedding(full_image, face_image, eye_image)
            # Find similar artists.
            results = self.find_similar_artists(embedding, top_k=top_k)
            # Format the results.
            output = "## 🎨 Search Results\n\n"
            if auto_extracted:
                output += "_ℹ️ No face/eye crops were uploaded, so automatic extraction was attempted._\n\n"
            output += "| Rank | Artist | Similarity |\n"
            output += "|:----:|:-----|:------:|\n"
            for i, (name, score) in enumerate(results, 1):
                # Clamp so negative similarities don't break the bar.
                filled = max(0, min(20, int(score * 20)))
                bar = "█" * filled + "░" * (20 - filled)
                output += f"| {i} | **{name}** | {score:.4f} {bar} |\n"
            return output
        except Exception as e:
            return f"❌ Error: {str(e)}"

    def create_ui(self):
        """Create the Gradio UI."""
        gr = _import_gradio()
        with gr.Blocks(title="Three-View-Style-Embedder", theme=gr.themes.Soft()) as demo:
            gr.Markdown("""
            # 🎨 Three-View-Style-Embedder
            Upload an illustration and find the artists whose style is most similar.
            - **Full image**: required (the whole artwork)
            - **Face/eye images**: optional (auto-extracted or uploaded manually)

            💡 **If you don't upload face/eye crops, they are detected and extracted automatically!**
            """)
            with gr.Row():
                with gr.Column(scale=1):
                    full_input = gr.Image(
                        label="Full image (required)",
                        type="pil",
                        height=256,
                    )
                    extract_btn = gr.Button("✂️ Auto-extract face/eye", variant="secondary")
                    extract_status = gr.Markdown(value="")
                    with gr.Row():
                        face_input = gr.Image(
                            label="Face (optional - can be auto-extracted)",
                            type="pil",
                            height=128,
                        )
                        eye_input = gr.Image(
                            label="Eye (optional - can be auto-extracted)",
                            type="pil",
                            height=128,
                        )
                    top_k = gr.Slider(
                        minimum=5,
                        maximum=50,
                        value=10,
                        step=5,
                        label="Number of results",
                    )
                    search_btn = gr.Button("🔍 Search", variant="primary", size="lg")
                with gr.Column(scale=1):
                    output = gr.Markdown(
                        value="Upload an image and press the Search button.",
                        label="Results",
                    )
            # Wire up events.
            extract_btn.click(
                fn=self.extract_crops,
                inputs=[full_input],
                outputs=[face_input, eye_input, extract_status],
            )
            search_btn.click(
                fn=self.search,
                inputs=[full_input, face_input, eye_input, top_k],
                outputs=output,
            )
            # Usage notes.
            gr.Markdown("""
            ---
            ### 💡 How to use
            1. Upload the **full image**
            2. Click **[✂️ Auto-extract face/eye]** (optional)
               - Or upload face/eye images yourself
               - If you do nothing, they are extracted automatically at search time
            3. Click **[🔍 Search]** to find similar artists

            ### 💡 Tips
            - Uploading face/eye crops manually can give more accurate results
            - The closer the similarity is to 1.0, the more similar the style
            """)
        return demo
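

# A minimal local smoke test, assuming checkpoint/embedding filenames that are
# not defined in this module (adjust the paths to your setup).
if __name__ == "__main__":
    app = StyleEmbedderApp(
        checkpoint_path="checkpoint.pt",          # assumed path
        embeddings_path="artist_embeddings.npz",  # assumed path
        device="cuda",
    )
    demo = app.create_ui()
    demo.launch()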