# MagicQuillV2 / util.py
import base64
import io
import re

import numpy as np
import torch
from PIL import Image, ImageOps
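
# Aspect-ratio buckets, each covering roughly 1024 * 1024 pixels
# (e.g. 672 * 1568 = 1,053,696), ordered from tall portrait to wide
# landscape; presumably the preferred inference resolutions for
# FLUX.1 Kontext.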
PREFERRED_KONTEXT_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
(752, 1392),
(800, 1328),
(832, 1248),
(880, 1184),
(944, 1104),
(1024, 1024),
(1104, 944),
(1184, 880),
(1248, 832),
(1328, 800),
(1392, 752),
(1456, 720),
(1504, 688),
(1568, 672),
]
def get_bounding_box_from_mask(mask, padded=False):
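    """Return the bounding box of the mask's active (> 0.5) region as
    normalized (x1, y1, x2, y2) in [0, 1]. With padded=True, coordinates
    are normalized against a square canvas of side max(H, W), as if the
    mask were centered on it. Returns (0, 0, 0, 0) for an empty mask."""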
mask = mask.squeeze()
rows, cols = torch.where(mask > 0.5)
if len(rows) == 0 or len(cols) == 0:
return (0, 0, 0, 0)
height, width = mask.shape
if padded:
padded_size = max(width, height)
if width < height:
offset_x = (padded_size - width) / 2
offset_y = 0
else:
offset_y = (padded_size - height) / 2
offset_x = 0
top_left_x = round(float((torch.min(cols).item() + offset_x) / padded_size), 3)
bottom_right_x = round(float((torch.max(cols).item() + offset_x) / padded_size), 3)
top_left_y = round(float((torch.min(rows).item() + offset_y) / padded_size), 3)
bottom_right_y = round(float((torch.max(rows).item() + offset_y) / padded_size), 3)
else:
offset_x = 0
offset_y = 0
top_left_x = round(float(torch.min(cols).item() / width), 3)
bottom_right_x = round(float(torch.max(cols).item() / width), 3)
top_left_y = round(float(torch.min(rows).item() / height), 3)
bottom_right_y = round(float(torch.max(rows).item() / height), 3)
return (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
def extract_bbox(text):
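    """Extract the first "[x1, y1, x2, y2]" integer bbox found in text,
    or None if no such pattern is present."""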
pattern = r"\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]"
    match = re.search(pattern, text)
    if match is None:
        return None
    return (int(match.group(1)), int(match.group(2)), int(match.group(3)), int(match.group(4)))
def resize_bbox(bbox, width_ratio, height_ratio):
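    """Scale a pixel bbox by independent width/height ratios, truncating to int."""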
x1, y1, x2, y2 = bbox
new_x1 = int(x1 * width_ratio)
new_y1 = int(y1 * height_ratio)
new_x2 = int(x2 * width_ratio)
new_y2 = int(y2 * height_ratio)
return (new_x1, new_y1, new_x2, new_y2)
def tensor_to_base64(tensor, quality=80, method=6):
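    """Encode an image tensor (1xHxWxC, HxWxC, or HxW; float in [0, 1]
    or uint8) as a base64 string of lossy WEBP data. `quality` and
    `method` are passed through to the PIL WEBP encoder."""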
tensor = tensor.squeeze(0).clone().detach().cpu()
    if tensor.dtype in (torch.float16, torch.float32, torch.float64):
        tensor *= 255  # float inputs are assumed to be in [0, 1]
    tensor = tensor.to(torch.uint8)
    if tensor.ndim == 2:  # grayscale image
        pil_image = Image.fromarray(tensor.numpy(), 'L')
        pil_image = pil_image.convert('RGB')
    elif tensor.ndim == 3:
        if tensor.shape[2] == 1:  # single channel
            pil_image = Image.fromarray(tensor.numpy().squeeze(2), 'L')
            pil_image = pil_image.convert('RGB')
        elif tensor.shape[2] == 3:  # RGB
            pil_image = Image.fromarray(tensor.numpy(), 'RGB')
        elif tensor.shape[2] == 4:  # RGBA
            pil_image = Image.fromarray(tensor.numpy(), 'RGBA')
else:
raise ValueError(f"Unsupported number of channels: {tensor.shape[2]}")
else:
raise ValueError(f"Unsupported tensor dimensions: {tensor.ndim}")
buffered = io.BytesIO()
pil_image.save(buffered, format="WEBP", quality=quality, method=method, lossless=False)
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
return img_str
def load_and_preprocess_image(image_path, convert_to='RGB', has_alpha=False):
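    """Load an image, apply its EXIF orientation, composite any RGBA input
    over a white background, convert to `convert_to`, and return a
    1xHxWxC float tensor in [0, 1]. Unless has_alpha=True and
    convert_to='RGBA', channels beyond RGB are dropped."""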
image = Image.open(image_path)
image = ImageOps.exif_transpose(image)
if image.mode == 'RGBA':
background = Image.new('RGBA', image.size, (255, 255, 255, 255))
image = Image.alpha_composite(background, image)
image = image.convert(convert_to)
image_array = np.array(image).astype(np.float32) / 255.0
    if not (has_alpha and convert_to == 'RGBA'):
        # Drop any channels beyond RGB (e.g. a stray alpha channel).
        if image_array.ndim == 3 and image_array.shape[2] > 3:
            image_array = image_array[:, :, :3]
    image_tensor = torch.from_numpy(image_array)[None,]
    return image_tensor
def process_background(base64_image, convert_to='RGB', size=None):
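    """Decode a base64 image, resize it to the preferred Kontext resolution
    with the closest aspect ratio (snapped down to a multiple of 16), and
    return a 1xHxWxC uint8 tensor."""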
image_data = read_base64_image(base64_image)
image = Image.open(image_data)
image = ImageOps.exif_transpose(image)
image = image.convert(convert_to)
# Select preferred size by closest aspect ratio, then snap to multiple_of
w0, h0 = image.size
aspect_ratio = (w0 / h0) if h0 != 0 else 1.0
# Choose the (w, h) whose aspect ratio is closest to the input
_, tw, th = min((abs(aspect_ratio - w / h), w, h) for (w, h) in PREFERRED_KONTEXT_RESOLUTIONS)
multiple_of = 16 # default: vae_scale_factor (8) * 2
tw = (tw // multiple_of) * multiple_of
th = (th // multiple_of) * multiple_of
if (w0, h0) != (tw, th):
image = image.resize((tw, th), resample=Image.BICUBIC)
image_array = np.array(image).astype(np.uint8)
image_tensor = torch.from_numpy(image_array)[None,]
return image_tensor
def read_base64_image(base64_image):
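    """Decode a PNG/JPEG/WEBP data URL into a BytesIO buffer."""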
    # Accept PNG, JPEG, or WEBP data URLs; strip the "data:image/...;base64," header.
    if re.match(r"data:image/(png|jpeg|webp);base64,", base64_image):
        base64_image = base64_image.split(",", 1)[1]
    else:
        raise ValueError("Unsupported image format.")
image_data = base64.b64decode(base64_image)
return io.BytesIO(image_data)
def create_alpha_mask(image_path):
"""Create an alpha mask from the alpha channel of an image."""
image = Image.open(image_path)
image = ImageOps.exif_transpose(image)
mask = torch.zeros((1, image.height, image.width), dtype=torch.float32)
if 'A' in image.getbands():
alpha_channel = np.array(image.getchannel('A')).astype(np.float32) / 255.0
mask[0] = 1.0 - torch.from_numpy(alpha_channel)
return mask
def get_mask_bbox(mask_tensor, padding=10):
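    """Return the (x_min, y_min, x_max, y_max) pixel bbox of the nonzero
    region of a 1xHxW mask, expanded by `padding` on each side and clamped
    to the image bounds; None if the mask is empty."""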
assert len(mask_tensor.shape) == 3 and mask_tensor.shape[0] == 1
_, H, W = mask_tensor.shape
mask_2d = mask_tensor.squeeze(0)
y_coords, x_coords = torch.where(mask_2d > 0)
if len(y_coords) == 0:
return None
x_min = int(torch.min(x_coords))
y_min = int(torch.min(y_coords))
x_max = int(torch.max(x_coords))
y_max = int(torch.max(y_coords))
x_min = max(0, x_min - padding)
y_min = max(0, y_min - padding)
x_max = min(W - 1, x_max + padding)
y_max = min(H - 1, y_max + padding)
return x_min, y_min, x_max, y_max
def tensor_to_pil(tensor):
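    """Convert an image tensor (1xHxWxC, HxWxC, or HxW) to a PIL Image,
    scaling float values up to [0, 255] when they appear to be in [0, 1]."""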
tensor = tensor.squeeze(0).clone().detach().cpu()
if tensor.dtype in [torch.float32, torch.float64, torch.float16]:
if tensor.max() <= 1.0:
tensor *= 255
tensor = tensor.to(torch.uint8)
    if tensor.ndim == 2:  # grayscale image [H, W]
        return Image.fromarray(tensor.numpy(), 'L')
    elif tensor.ndim == 3:
        if tensor.shape[2] == 1:  # single channel [H, W, 1]
            return Image.fromarray(tensor.numpy().squeeze(2), 'L')
        elif tensor.shape[2] == 3:  # RGB [H, W, 3]
            return Image.fromarray(tensor.numpy(), 'RGB')
        elif tensor.shape[2] == 4:  # RGBA [H, W, 4]
            return Image.fromarray(tensor.numpy(), 'RGBA')
        else:
            raise ValueError(f"Unsupported number of channels: {tensor.shape[2]}")
    else:
        raise ValueError(f"Unsupported tensor dimensions: {tensor.ndim}")