Spaces:
Runtime error
Runtime error
| import base64 | |
| import imghdr | |
| import io | |
| import os | |
| import sys | |
| from typing import List, Optional, Dict, Tuple | |
| from urllib.parse import urlparse | |
| import cv2 | |
| from PIL import Image, ImageOps, PngImagePlugin | |
| import numpy as np | |
| import torch | |
| from iopaint.const import MPS_UNSUPPORT_MODELS | |
| from loguru import logger | |
| from torch.hub import download_url_to_file, get_dir | |
| import hashlib | |
def md5sum(filename):
    """Return the hexadecimal MD5 digest of the file at ``filename``.

    The file is opened in binary mode and consumed in chunks of
    ``128 * block_size`` bytes, so large files are never held in memory whole.
    """
    digest = hashlib.md5()
    chunk_size = 128 * digest.block_size
    with open(filename, "rb") as stream:
        while True:
            chunk = stream.read(chunk_size)
            if not chunk:
                # An empty read means end of file.
                break
            digest.update(chunk)
    return digest.hexdigest()
def switch_mps_device(model_name, device):
    """Return the device that ``model_name`` should actually run on.

    When ``device`` is "mps" and the model is listed in
    ``MPS_UNSUPPORT_MODELS``, a CPU device is returned instead (and the switch
    is logged). In every other case ``device`` is returned unchanged.
    """
    if model_name not in MPS_UNSUPPORT_MODELS or str(device) != "mps":
        return device

    logger.info(f"{model_name} not support mps, switch to cpu")
    return torch.device("cpu")
def get_cache_path_by_url(url):
    """Return the local cache path for the file that ``url`` points to.

    The file name is the basename of the URL path (query string and fragment
    are ignored) and it lives in the ``checkpoints`` directory under the torch
    hub directory. The ``checkpoints`` directory is created if missing.

    Args:
        url: URL of the model file.

    Returns:
        Path of the cached file; the file itself may not exist yet.
    """
    model_dir = os.path.join(get_dir(), "checkpoints")
    # exist_ok avoids the check-then-create race when several processes
    # start at the same time and all try to create the directory.
    os.makedirs(model_dir, exist_ok=True)
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(model_dir, filename)
def get_cache_path_by_local(url):
    """Return the path of ``pretrained-model/big-lama.pt`` under the current working directory.

    ``url`` is accepted for signature parity with ``get_cache_path_by_url``
    but is not used.
    """
    return os.path.join(os.getcwd(), "pretrained-model", "big-lama.pt")
def download_model(url, model_md5: Optional[str] = None):
    """Download the model at ``url`` into the cache unless it is already there.

    Args:
        url: URL of the model file.
        model_md5: Expected MD5 hex digest. When given, a freshly downloaded
            file is verified against it.

    Returns:
        Path of the cached model file.

    Raises:
        SystemExit: If the freshly downloaded file does not match
            ``model_md5``. The bad file is removed first when possible.
    """
    cached_file = get_cache_path_by_url(url)
    # To use the checkpoint bundled in the working directory instead, switch to
    # get_cache_path_by_local(url).
    if os.path.exists(cached_file):
        # Already cached: the checksum is only verified right after a download.
        return cached_file

    sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
    hash_prefix = None
    download_url_to_file(url, cached_file, hash_prefix, progress=True)
    if not model_md5:
        return cached_file

    _md5 = md5sum(cached_file)
    if model_md5 == _md5:
        logger.info(f"Download model success, md5: {_md5}")
        return cached_file

    try:
        os.remove(cached_file)
    except OSError:
        # Could not delete the corrupt file; ask the user to do it.
        logger.error(
            f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart iopaint."
        )
    else:
        logger.error(
            f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart iopaint. "
            f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
        )
    # sys.exit rather than the builtin exit(), which is only injected by the
    # site module and may be missing (e.g. python -S, frozen apps).
    sys.exit(-1)
def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod``.

    ``x`` is returned unchanged when it is already a multiple of ``mod``.
    """
    quotient, remainder = divmod(x, mod)
    return x if remainder == 0 else (quotient + 1) * mod
def handle_error(model_path, model_md5, e):
    """Log why loading ``model_path`` failed, then terminate the process.

    If the file's MD5 differs from ``model_md5`` the file is treated as corrupt
    and removed when possible. Otherwise the original exception ``e`` is
    reported so the user can file an issue.

    Args:
        model_path: Path of the model file that failed to load.
        model_md5: Expected MD5 hex digest of the model file.
        e: The exception raised while loading the model.

    Raises:
        SystemExit: Always, with exit status -1.
    """
    _md5 = md5sum(model_path)
    if _md5 != model_md5:
        try:
            os.remove(model_path)
        except OSError:
            # Could not delete the corrupt file; ask the user to do it.
            logger.error(
                f"Model md5: {_md5}, expected md5: {model_md5}, please delete {model_path} and restart iopaint."
            )
        else:
            logger.error(
                f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart iopaint. "
                f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
            )
    else:
        logger.error(
            f"Failed to load model {model_path}, "
            f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
        )
    # sys.exit rather than the builtin exit(), which is only injected by the
    # site module and may be missing (e.g. python -S, frozen apps).
    sys.exit(-1)
def load_jit_model(url_or_path, device, model_md5: str):
    """Load a TorchScript model from a local path or URL and put it in eval mode.

    Args:
        url_or_path: Existing local file path, or a URL that is downloaded via
            ``download_model``.
        device: Device the loaded model is moved to.
        model_md5: Expected MD5 hex digest, used for download verification and
            for diagnosing load failures.

    Returns:
        The loaded model, on ``device``, in eval mode.
    """
    model_path = (
        url_or_path
        if os.path.exists(url_or_path)
        else download_model(url_or_path, model_md5)
    )
    logger.info(f"Loading model from: {model_path}")
    try:
        # Deserialize on CPU first, then move to the requested device.
        jit_model = torch.jit.load(model_path, map_location="cpu").to(device)
    except Exception as e:
        # handle_error logs the failure and exits the process.
        handle_error(model_path, model_md5, e)
    jit_model.eval()
    return jit_model
def load_model(model: torch.nn.Module, url_or_path, device, model_md5):
    """Load a state dict into ``model`` from a local path or URL.

    Args:
        model: Module whose weights are populated (strict key matching).
        url_or_path: Existing local file path, or a URL that is downloaded via
            ``download_model``.
        device: Device the model is moved to after loading.
        model_md5: Expected MD5 hex digest, used for download verification and
            for diagnosing load failures.

    Returns:
        The same ``model`` instance, on ``device``, in eval mode.
    """
    if not os.path.exists(url_or_path):
        model_path = download_model(url_or_path, model_md5)
    else:
        model_path = url_or_path

    try:
        logger.info(f"Loading model from: {model_path}")
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        weights = torch.load(model_path, map_location="cpu")
        model.load_state_dict(weights, strict=True)
        model.to(device)
    except Exception as e:
        # handle_error logs the failure and exits the process.
        handle_error(model_path, model_md5, e)
    model.eval()
    return model
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    """Encode ``image_numpy`` with OpenCV into the format named by ``ext``.

    ``ext`` is given without the leading dot (e.g. "png"). JPEG quality is set
    to 100 and PNG compression to 0.
    """
    encode_params = [
        int(cv2.IMWRITE_JPEG_QUALITY),
        100,
        int(cv2.IMWRITE_PNG_COMPRESSION),
        0,
    ]
    _, encoded = cv2.imencode(f".{ext}", image_numpy, encode_params)
    return encoded.tobytes()
def pil_to_bytes(pil_img, ext: str, quality: int = 95, infos: Optional[Dict] = None) -> bytes:
    """Serialize ``pil_img`` to bytes in the format named by ``ext``.

    Args:
        pil_img: Image object exposing ``save(fp, format=..., **params)``.
        ext: Target format name; "jpg" is mapped to "jpeg".
        quality: Passed to ``save`` as ``quality``.
        infos: Optional extra keyword arguments for ``save``. Entries whose
            value is ``None`` are dropped. For PNG output, a "parameters"
            entry is additionally embedded as a PNG text chunk via ``pnginfo``.

    Returns:
        The encoded image bytes.
    """
    # A None default replaces the shared mutable ``{}`` default and also makes
    # an explicit ``infos=None`` from a caller safe.
    if infos is None:
        infos = {}
    kwargs = {k: v for k, v in infos.items() if v is not None}
    if ext == "jpg":
        ext = "jpeg"
    if ext.lower() == "png" and "parameters" in kwargs:
        pnginfo_data = PngImagePlugin.PngInfo()
        pnginfo_data.add_text("parameters", kwargs["parameters"])
        kwargs["pnginfo"] = pnginfo_data
    with io.BytesIO() as output:
        pil_img.save(output, format=ext, quality=quality, **kwargs)
        return output.getvalue()
def pil_to_bytes_single(pil_img, ext: str, quality: int = 95, infos=None) -> bytes:
    """Serialize ``pil_img`` to bytes in the format named by ``ext``.

    "jpg" is mapped to "jpeg". Non-``None`` entries of ``infos`` are forwarded
    to ``pil_img.save``; for PNG output a "parameters" entry is also embedded
    as a PNG text chunk through ``pnginfo``.
    """
    save_options = {}
    # A missing/empty ``infos`` behaves like an empty dict.
    for key, value in (infos or {}).items():
        if value is not None:
            save_options[key] = value

    fmt = "jpeg" if ext == "jpg" else ext
    if fmt.lower() == "png" and "parameters" in save_options:
        png_meta = PngImagePlugin.PngInfo()
        png_meta.add_text("parameters", save_options["parameters"])
        save_options["pnginfo"] = png_meta

    with io.BytesIO() as buffer:
        pil_img.save(buffer, format=fmt, quality=quality, **save_options)
        return buffer.getvalue()
def load_img(img_bytes, gray: bool = False, return_info: bool = False):
    """Decode encoded image bytes into a numpy array.

    Args:
        img_bytes: Encoded image file content.
        gray: If True, convert to single-channel "L" mode; otherwise produce
            an RGB array.
        return_info: If True, also return the PIL ``info`` dict of the image.

    Returns:
        ``(np_img, alpha_channel)`` or, when ``return_info`` is True,
        ``(np_img, alpha_channel, infos)``. ``alpha_channel`` is the last
        channel of an RGBA input in the non-gray path and ``None`` otherwise.
    """
    alpha_channel = None
    image = Image.open(io.BytesIO(img_bytes))

    try:
        image = ImageOps.exif_transpose(image)
    except Exception as exc:
        # Best effort: a broken EXIF block must not prevent loading the image.
        logger.debug(f"exif_transpose failed, using image as-is: {exc}")

    # exif_transpose removes the EXIF rotation info, so image.info must be read
    # after it; otherwise the returned metadata still carries an orientation
    # that has already been applied to the pixels.
    infos = image.info if return_info else None

    if gray:
        image = image.convert("L")
        np_img = np.array(image)
    elif image.mode == "RGBA":
        np_img = np.array(image)
        alpha_channel = np_img[:, :, -1]
        np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
    else:
        image = image.convert("RGB")
        np_img = np.array(image)

    if return_info:
        return np_img, alpha_channel, infos
    return np_img, alpha_channel
def norm_img(np_img):
    """Convert an HWC (or HW) image to a channel-first float32 array scaled by 1/255.

    A 2-D input first gets a trailing channel axis, so the result always has
    three dimensions ordered (channel, height, width).
    """
    is_single_channel = len(np_img.shape) == 2
    if is_single_channel:
        np_img = np_img[:, :, np.newaxis]
    channel_first = np.transpose(np_img, (2, 0, 1))
    return channel_first.astype("float32") / 255
def resize_max_size(
    np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
) -> np.ndarray:
    """Shrink ``np_img`` so that its longer side equals ``size_limit``.

    The aspect ratio is kept (sizes rounded to the nearest integer). Images
    whose longer side does not exceed ``size_limit`` are returned unchanged.
    """
    height, width = np_img.shape[:2]
    longest = max(height, width)
    if longest <= size_limit:
        return np_img

    scale = size_limit / longest
    target_size = (int(width * scale + 0.5), int(height * scale + 0.5))
    return cv2.resize(np_img, dsize=target_size, interpolation=interpolation)
def pad_img_to_modulo(
    img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
):
    """Pad the bottom and right of ``img`` so both sides are multiples of ``mod``.

    Args:
        img: [H, W, C] image; an [H, W] image gets a trailing channel axis.
        mod: Value the padded height and width must be divisible by.
        square: Whether to pad to a square (both sides become the larger one).
        min_size: Optional lower bound for both sides; must be a multiple of
            ``mod``.

    Returns:
        The padded [H', W', C] array, filled using symmetric reflection.
    """
    if len(img.shape) == 2:
        img = img[:, :, np.newaxis]

    height, width = img.shape[:2]
    out_height, out_width = ceil_modulo(height, mod), ceil_modulo(width, mod)

    if min_size is not None:
        assert min_size % mod == 0
        out_width = max(min_size, out_width)
        out_height = max(min_size, out_height)

    if square:
        out_height = out_width = max(out_height, out_width)

    # Pad only after the existing rows/columns; channels are left untouched.
    pad_widths = ((0, out_height - height), (0, out_width - width), (0, 0))
    return np.pad(img, pad_widths, mode="symmetric")
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
    """Return the bounding boxes of the external contours found in ``mask``.

    Args:
        mask: (h, w, 1) array with values in 0~255; it is thresholded at 127
            before contour detection.

    Returns:
        One integer array ``[x1, y1, x2, y2]`` per contour, clipped to the
        mask's width and height.
    """
    height, width = mask.shape[:2]
    _, binary = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    def _clipped_box(contour):
        x, y, w, h = cv2.boundingRect(contour)
        box = np.array([x, y, x + w, y + h]).astype(int)
        box[::2] = np.clip(box[::2], 0, width)  # x coordinates
        box[1::2] = np.clip(box[1::2], 0, height)  # y coordinates
        return box

    return [_clipped_box(contour) for contour in contours]
def only_keep_largest_contour(mask: np.ndarray) -> np.ndarray:
    """Return a mask that contains only the largest external contour, filled.

    Args:
        mask: (h, w) array with values in 0~255; it is thresholded at 127
            before contour detection.

    Returns:
        A new mask of the same shape and dtype with the largest-area contour
        filled with 255. If no contour has a positive area, the input ``mask``
        is returned unchanged.
    """
    _, thresh = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    areas = [cv2.contourArea(cnt) for cnt in contours]
    if not areas or max(areas) <= 0:
        return mask

    # list.index returns the first occurrence, so ties resolve to the earliest
    # contour, as a strict ">" scan would.
    max_index = areas.index(max(areas))
    new_mask = np.zeros_like(mask)
    return cv2.drawContours(new_mask, contours, max_index, 255, -1)
def is_mac():
    """Return True when ``sys.platform`` is "darwin"."""
    platform_name = sys.platform
    return platform_name == "darwin"
def get_image_ext(img_bytes):
    """Return the image type detected from ``img_bytes``, defaulting to "jpeg".

    Detection is delegated to ``imghdr.what``; when it cannot identify the
    data, "jpeg" is returned.
    """
    # NOTE(review): imghdr is deprecated (PEP 594) and removed in Python 3.13.
    detected = imghdr.what("", img_bytes)
    return "jpeg" if detected is None else detected
def decode_base64_to_image(
    encoding: str, gray=False
) -> Tuple[np.ndarray, Optional[np.ndarray], Dict]:
    """Decode a base64 string (optionally a data URI) into a numpy image.

    Args:
        encoding: Base64 payload, optionally prefixed with a ``data:image/...``
            or ``data:application/octet-stream;base64,`` header.
        gray: If True, convert to single-channel "L" mode; otherwise produce
            an RGB array.

    Returns:
        ``(np_img, alpha_channel, infos)``. ``alpha_channel`` is the last
        channel of an RGBA input in the non-gray path and ``None`` otherwise;
        ``infos`` is the PIL ``info`` dict read after EXIF transposition.
    """
    if encoding.startswith(
        ("data:image/", "data:application/octet-stream;base64,")
    ):
        # The payload is everything after the first comma. Splitting on ";"
        # first would break on headers with extra parameters such as
        # "data:image/png;charset=utf-8;base64,...".
        encoding = encoding.partition(",")[2]
    image = Image.open(io.BytesIO(base64.b64decode(encoding)))

    alpha_channel = None
    try:
        image = ImageOps.exif_transpose(image)
    except Exception as exc:
        # Best effort: a broken EXIF block must not prevent decoding the image.
        logger.debug(f"exif_transpose failed, using image as-is: {exc}")

    # exif_transpose will remove exif rotate info, we must call image.info
    # after exif_transpose
    infos = image.info

    if gray:
        image = image.convert("L")
        np_img = np.array(image)
    elif image.mode == "RGBA":
        np_img = np.array(image)
        alpha_channel = np_img[:, :, -1]
        np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
    else:
        image = image.convert("RGB")
        np_img = np.array(image)

    return np_img, alpha_channel, infos
def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes:
    """Encode ``image`` as PNG via ``pil_to_bytes`` and return it base64-encoded.

    ``quality`` and ``infos`` are forwarded to ``pil_to_bytes`` unchanged.
    """
    png_bytes = pil_to_bytes(image, "png", quality=quality, infos=infos)
    return base64.b64encode(png_bytes)
def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray:
    """Append ``alpha_channel`` to ``rgb_np_img`` as an extra last channel.

    When ``alpha_channel`` is ``None`` the image is returned as-is. If the
    alpha plane's height/width differ from the image's, it is resized to match
    before being concatenated.
    """
    if alpha_channel is None:
        return rgb_np_img

    if alpha_channel.shape[:2] != rgb_np_img.shape[:2]:
        # cv2.resize takes dsize as (width, height).
        target_size = (rgb_np_img.shape[1], rgb_np_img.shape[0])
        alpha_channel = cv2.resize(alpha_channel, dsize=target_size)

    alpha_plane = alpha_channel[:, :, np.newaxis]
    return np.concatenate((rgb_np_img, alpha_plane), axis=-1)
def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
    """Reverse, expand or shrink a mask and render it as a 4-channel overlay.

    Args:
        mask: Mask array; it is binarised in place (values >= 127 become 255,
            the rest 0) before the operation.
        kernel_size: Radius of the elliptical structuring element; the kernel
            side is ``2 * kernel_size + 1``.
        operate: "reverse" inverts the mask, "expand" dilates it, any other
            value erodes it.

    Returns:
        An (h, w, 4) uint8 image in which pixels where the processed mask is
        above 128 carry the brush colour and everything else is zero.
    """
    # fronted brush color "ffcc00bb"
    mask[mask >= 127] = 255
    mask[mask < 127] = 0

    if operate == "reverse":
        mask = 255 - mask
    else:
        side = 2 * kernel_size + 1
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
        morph = cv2.dilate if operate == "expand" else cv2.erode
        mask = morph(mask, kernel, iterations=1)

    overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
    overlay[mask > 128] = [255, 203, 0, int(255 * 0.73)]
    return cv2.cvtColor(overlay, cv2.COLOR_BGRA2RGBA)
def gen_frontend_mask(bgr_or_gray_mask):
    """Dilate a mask and render it as a 4-channel overlay for the frontend.

    A 3-channel input is first converted from BGR to gray. The mask is then
    dilated with a 9x9 square kernel, and pixels above 128 receive the brush
    colour in an (h, w, 4) uint8 image; all other pixels stay zero.
    """
    gray = bgr_or_gray_mask
    if len(gray.shape) == 3 and gray.shape[2] != 1:
        gray = cv2.cvtColor(gray, cv2.COLOR_BGR2GRAY)

    # fronted brush color "ffcc00bb"
    # TODO: how to set kernel size?
    kernel_size = 9
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    gray = cv2.dilate(gray, kernel, iterations=1)

    overlay = np.zeros((gray.shape[0], gray.shape[1], 4), dtype=np.uint8)
    overlay[gray > 128] = [255, 203, 0, int(255 * 0.73)]
    return cv2.cvtColor(overlay, cv2.COLOR_BGRA2RGBA)