"""Processor class for MarkupDM."""

import math
import re
import shutil
import subprocess
import tempfile
from pathlib import Path

import numpy as np
import torch
from PIL import Image, ImageDraw
from transformers import (
    ImageProcessingMixin,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    ProcessorMixin,
)
from transformers.utils import logging

from .fonts import FontManager

logger = logging.get_logger(__name__)

MAXIMUM_DECODE_IMAGE_SIZE = 4096
IMG_FORMAT = "{:03d}.png"
FONT_FORMAT = "{:03d}.ttf"


class MarkupDMProcessor(ProcessorMixin):
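    """Processor for MarkupDM that interleaves SVG text tokens with image tokens.

    Wraps a text tokenizer and an image processor so that raster images
    referenced by filename inside SVG markup can be encoded into discrete
    image-token spans (and decoded back) using a quantizing vision model.

    A minimal usage sketch (variable names are illustrative):

        processor = MarkupDMProcessor(tokenizer, image_processor)
        batch = processor(svg=svg, images=[image], filenames=["000.png"],
                          vision_model=vision_model)
        decoded = processor.decode(batch["input_ids"], vision_model=vision_model)
    """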

    attributes = ["tokenizer", "image_processor"]

    tokenizer_class = "AutoTokenizer"
    tokenizer: PreTrainedTokenizerBase

    image_processor_class = "AutoImageProcessor"
    image_processor: ImageProcessingMixin

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        image_processor: ImageProcessingMixin,
    ):
        super().__init__(tokenizer, image_processor)

        if "<begin_of_image>" not in tokenizer.additional_special_tokens:
            self.extend_base_tokenizer(self.tokenizer)

        boi = "<begin_of_image>"
        img_sep = "<image_sep>"
        self.re_img_size = re.compile(rf"{boi}(\d+){img_sep}(\d+){img_sep}")
        self.re_svg_width = re.compile(r'<svg[^>]*\bwidth="(\d+)"[^>]*>')
        self.re_svg_height = re.compile(r'<svg[^>]*\bheight="(\d+)"[^>]*>')

        self.font_manager = None

    def extend_base_tokenizer(self, tokenizer: PreTrainedTokenizerBase) -> None:
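        """Register the image-delimiter special tokens on the tokenizer.

        Adds `<begin_of_image>`, `<end_of_image>`, `<image_sep>`, and
        `<image_token>` as additional special tokens without replacing any
        that are already registered.
        """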
        logger.info("Extending tokenizer...")
        tokenizer.clean_up_tokenization_spaces = False

        additional_special_tokens = [
            "<begin_of_image>",
            "<end_of_image>",
            "<image_sep>",
            "<image_token>",
        ]
        logger.info(f"Add special tokens: {additional_special_tokens}")
        tokenizer.add_special_tokens(
            {"additional_special_tokens": additional_special_tokens},
            replace_additional_special_tokens=False,
        )

    def __call__(
        self,
        svg: str | None = None,
        images: list[Image.Image] | None = None,
        filenames: list[str] | None = None,
        vision_model: PreTrainedModel | None = None,
    ) -> dict:
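        """Encode an SVG string and the images it references into model inputs.

        When images are given, they are preprocessed and quantized with
        `vision_model`; the SVG is then tokenized with every filename
        occurrence replaced by its image-token span. Returns the dict from
        `tokenize_example`: `input_ids`, `image_mask`, and `image_pos_ids`.
        """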
        if not isinstance(images, list):
            images = [images]

        if len(images) > 0 and images[0] is not None:
            output = self.preprocess_images(images)
            output = self.encode_images(output, vision_model)
        else:
            output = {"width": [], "height": [], "image_ids": []}

        output.update({"svg": svg, "filenames": filenames})
        output = self.tokenize_example(output)

        return output

    def preprocess_images(self, images: list[Image.Image]) -> dict:
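        """Apply the image processor to each image and stack the results.

        Returns a dict holding the stacked `image` tensor along with the
        per-image `width` and `height` values reported by the processor.
        """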
        assert images is not None, "Images must be provided."
        output: dict = {"image": [], "width": [], "height": []}

        for image in images:
            processed = self.image_processor(image)
            for key, value in processed.items():
                output[key].append(value)

        output["image"] = torch.stack(output["image"])

        return output

    def encode_images(self, example: dict, vision_model: PreTrainedModel) -> dict:
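        """Quantize preprocessed images into discrete image ids.

        Pops the `image` tensor from `example` and stores `image_ids`: one
        flat tensor of codebook indices per image, produced by the vision
        model's encoder.
        """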
        if "images" in example and "width" not in example:
            example = self.preprocess_images(example["images"])

        assert vision_model is not None, "Vision model must be provided."
        image = example.pop("image")
        image = image.to(dtype=vision_model.dtype, device=vision_model.device)
        with torch.inference_mode():
            _, _, (_, _, image_ids) = vision_model.model.encode(image)
        example["image_ids"] = list(image_ids.view(image.size(0), -1).cpu())

        return example

    def tokenize_example(self, example: dict) -> dict:
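        """Tokenize the SVG, splicing an image-token span over each filename.

        Each image is serialized as
        `<begin_of_image> W <image_sep> H <image_sep> ids... <end_of_image>`,
        with the image ids offset by the tokenizer's vocabulary size so they
        occupy a disjoint id range. Returns `input_ids`, a boolean
        `image_mask` over the image-id positions, and `image_pos_ids` giving
        each image id's position within its own image.
        """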
        for key in ["svg", "filenames", "width", "height", "image_ids"]:
            msg = f"Missing key: {key}."
            if key in ["width", "height", "image_ids"]:
                msg += " Images must be encoded first using `encode_images`."
            assert example.get(key, None) is not None, msg

        tokenizer = self.tokenizer
        bos_id = tokenizer.bos_token_id
        eos_id = tokenizer.eos_token_id
        bos_id = bos_id if bos_id is not None else eos_id
        boi_id = tokenizer.convert_tokens_to_ids("<begin_of_image>")
        eoi_id = tokenizer.convert_tokens_to_ids("<end_of_image>")
        img_sep_id = tokenizer.convert_tokens_to_ids("<image_sep>")

        # Build the image-token span for each filename.
        name2token = {}
        for filename, image_ids, width, height in zip(
            example["filenames"],
            example["image_ids"],
            example["width"],
            example["height"],
        ):
            _image_ids = (image_ids + len(tokenizer)).tolist()
            W_tokens = tokenizer.encode(str(width))
            H_tokens = tokenizer.encode(str(height))

            image_tokens = [
                boi_id,
                *W_tokens,
                img_sep_id,
                *H_tokens,
                img_sep_id,
                *_image_ids,
                eoi_id,
            ]

            name2token[filename] = image_tokens

        # Tokenize the SVG text, replacing each filename with its image tokens.
        tokens = [bos_id]
        svg = example["svg"]
        while svg:
            # Find the earliest filename occurrence in the remaining SVG.
            start, end = len(svg), len(svg)
            for name in name2token.keys():
                _start = svg.find(name)
                if -1 < _start < start:
                    start = _start
                    end = start + len(name)

            tokens += tokenizer.encode(svg[:start])

            if start < end:
                tokens += name2token[svg[start:end]]

            svg = svg[end:]

        tokens.append(eos_id)

        input_ids = torch.tensor(tokens)
        image_mask = input_ids >= len(tokenizer)

        # Position of each image id within its own image (0..length-1).
        image_pos_ids = torch.zeros_like(input_ids)
        if len(example["image_ids"]) > 0:
            length = example["image_ids"][0].size(0)
            num_images = int(image_mask.sum()) // length
            image_pos_ids[image_mask] = torch.arange(length).repeat(num_images)

        return {
            "input_ids": input_ids,
            "image_mask": image_mask,
            "image_pos_ids": image_pos_ids,
        }

    def decode(
        self,
        tokens: torch.Tensor | np.ndarray,
        vision_model: PreTrainedModel | None = None,
    ) -> dict:
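        """Invert `tokenize_example`: recover the SVG string and its images.

        Scans the sequence for `<begin_of_image> ... <end_of_image>` spans,
        reconstructs each image with `decode_image`, and splices a generated
        filename (`000.png`, `001.png`, ...) back into the SVG text. Returns
        a dict with `svg`, `images`, and `filenames`.
        """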
        tokenizer = self.tokenizer
        bos = tokenizer.bos_token
        eos = tokenizer.eos_token
        bos = bos if bos is not None else eos

        # Refuse to decode sequences still in fill-in-the-middle (FIM) order.
        msg = "Should be reverted from FIM format before decoding."
        for fim_type in ["prefix", "middle", "suffix"]:
            token_id = tokenizer.convert_tokens_to_ids(f"<fim_{fim_type}>")
            if token_id is None:
                token_id = tokenizer.convert_tokens_to_ids(f"<|fim_{fim_type}|>")
            assert token_id is not None, f"{fim_type} token not found"
            assert token_id not in tokens, msg

        tokens = torch.asarray(tokens).detach().cpu()
        assert tokens.ndim == 1, "Tokens must be 1D."
        boi_id = tokenizer.convert_tokens_to_ids("<begin_of_image>")
        eoi_id = tokenizer.convert_tokens_to_ids("<end_of_image>")

        svg = ""
        images: list = []
        filenames: list = []
        while len(tokens) > 0:
            # Locate the next image span, if any.
            boi_idx = torch.where(tokens == boi_id)[0]
            eoi_idx = torch.where(tokens == eoi_id)[0]
            if boi_idx.size(0) > 0:
                start = int(boi_idx[0].item())
                end = int(eoi_idx[0].item()) + 1 if eoi_idx.size(0) > 0 else len(tokens)
                assert start < end, "Invalid image tokens."
            else:
                start, end = len(tokens), len(tokens)

            svg += tokenizer.decode(tokens[:start])

            if start < end:
                # Parse the declared width/height, falling back to the
                # image processor's default size.
                image_tokens = tokens[start:end]
                image_text = tokenizer.decode(image_tokens)
                matched = self.re_img_size.match(image_text)
                if matched is not None:
                    width, height = map(int, matched.groups())
                else:
                    width = self.image_processor.size
                    height = self.image_processor.size

                image_mask = image_tokens >= len(tokenizer)
                image_ids = image_tokens[image_mask] - len(tokenizer)
                image = self.decode_image(vision_model, image_ids, width, height)
                filename = IMG_FORMAT.format(len(images))
                svg += filename

                images.append(image)
                filenames.append(filename)

            tokens = tokens[end:]

        # Collapse repeated BOS/EOS and trim everything outside them.
        svg = re.sub(rf"({re.escape(bos)})+", bos, svg)
        svg = re.sub(rf"({re.escape(eos)})+", eos, svg)

        i_bos = svg.find(bos)
        svg = svg[i_bos + len(bos) :] if i_bos > -1 else svg
        i_eos = svg.find(eos, i_bos + 1)
        svg = svg[:i_eos] if i_eos > -1 else svg

        return {"svg": svg, "images": images, "filenames": filenames}

    def decode_image(
        self,
        vision_model: PreTrainedModel | None = None,
        image_ids: torch.Tensor | np.ndarray | None = None,
        width: int | None = None,
        height: int | None = None,
        dummy_color: tuple[int, int, int, int] = (200,) * 4,
        pad_value: int = 0,
    ) -> Image.Image:
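        """Reconstruct a PIL image from quantized image ids.

        Returns a flat `dummy_color` placeholder when neither a vision model
        nor ids are provided. Id sequences shorter than the latent grid are
        padded with `pad_value`, and the padded region is masked out
        afterwards via `mask_padded_area`.
        """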
        width = width or self.image_processor.size
        height = height or self.image_processor.size
        width, height = self.compute_safe_image_size(width, height)

        if vision_model is None and image_ids is None:
            # No model and no ids: return a placeholder image.
            return Image.new("RGBA", (width, height), dummy_color)

        assert vision_model is not None, "Vision model must be provided."
        scale_factor = 2 ** (vision_model.model.encoder.num_resolutions - 1)
        latent_size = self.image_processor.size // scale_factor
        required_length = latent_size**2

        # Pad short id sequences up to the full latent grid.
        image_ids = torch.asarray(image_ids, device=vision_model.device)
        code_length = image_ids.shape[0]
        if code_length < required_length:
            pad_size = required_length - code_length
            pad = torch.full((pad_size,), pad_value).to(image_ids)
            image_ids = torch.cat([image_ids, pad])

        with torch.inference_mode():
            codebook_entry = vision_model.model.quantize.get_codebook_entry(
                image_ids, (1, latent_size, latent_size, -1)
            )
            recon = vision_model.model.decode(codebook_entry)[0].float()

        img = self.image_processor.postprocess(
            recon, self.image_processor.size, self.image_processor.size
        )

        if code_length < required_length:
            img = self.mask_padded_area(img, code_length, scale_factor)

        img = img.resize((width, height), resample=self.image_processor.resample)

        return img

    def compute_safe_image_size(self, width: int, height: int) -> tuple[int, int]:
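        """Clamp the longer edge to `MAXIMUM_DECODE_IMAGE_SIZE`, preserving aspect ratio."""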
        long_edge = max(width, height)
        if MAXIMUM_DECODE_IMAGE_SIZE < long_edge:
            scale = MAXIMUM_DECODE_IMAGE_SIZE / long_edge
            width = min(max(int(width * scale), 1), MAXIMUM_DECODE_IMAGE_SIZE)
            height = min(max(int(height * scale), 1), MAXIMUM_DECODE_IMAGE_SIZE)
        return width, height

    def mask_padded_area(
        self,
        img: Image.Image,
        code_length: int,
        scale_factor: int,
        fill: tuple[int, int, int, int] = (200, 200, 200, 255),
    ) -> Image.Image:
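        """Grey out the region not covered by the first `code_length` codes.

        Latent codes fill the image in raster order, so the padded area is
        the tail of the current code row plus all rows below it; it is
        covered with a single polygon drawn in `fill`.
        """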
        draw = ImageDraw.Draw(img, mode="RGBA")
        width, height = img.size
        zw = math.ceil(width / scale_factor)
        cw = code_length % zw
        ch = code_length // zw
        draw.polygon(
            [
                (cw * scale_factor, ch * scale_factor),
                (width, ch * scale_factor),
                (width, height),
                (0, height),
                (0, (ch + 1) * scale_factor),
                (cw * scale_factor, (ch + 1) * scale_factor),
            ],
            fill=fill,
        )
        return img

    def set_font_manager(self, fonts_path: str | None = None) -> None:
        self.font_manager = FontManager(fonts_path)

    def render_preprocess(self, example: dict, out_dir: str | Path) -> None:
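        """Prepare the SVG for rendering: extract fonts and inject a <style> block.

        For each `<text>` element, resolves the referenced font through the
        font manager, writes it to `out_dir`, and declares it with an
        `@font-face` rule so the headless browser can load it locally.
        """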
        msg = "Font manager is not set. Call `set_font_manager` first."
        assert self.font_manager is not None, msg

        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        svg = example["svg"]

        # Collect one @font-face rule per (family, weight, style) triple.
        found = set()
        style_text = "text{dominant-baseline:text-before-edge}"
        for i, text_str in enumerate(re.findall("<text[^>]*>", svg)):
            matched = re.search('font-family="([^"]*)"', text_str)
            if matched is None:
                logger.warning(f"Font family not found in {text_str}")
                continue

            font_family = matched.group(1)
            is_bold = 'font-weight="bold"' in text_str
            is_italic = 'font-style="italic"' in text_str
            font_weight = "bold" if is_bold else "regular"
            if is_italic:
                font_style = "bolditalic" if is_bold else "italic"
            else:
                font_style = font_weight
            key = (font_family, font_weight, font_style)
            if key in found:
                continue

            font_bytes = self.font_manager.lookup(
                font_family=font_family,
                font_weight=font_weight,
                font_style=font_style,
            )

            font_path = FONT_FORMAT.format(i)
            font_face = "@font-face{"
            font_face += f"font-family:'{font_family}';"
            font_face += f"font-weight:{font_weight};"
            font_face += f"font-style:{font_style};"
            font_face += f"src:url('{font_path}');"
            font_face += "}"
            style_text += font_face

            Path(f"{out_dir}/{font_path}").write_bytes(font_bytes)
            found.add(key)

        # Inject the style block right after the opening <svg> tag.
        matched = re.search("<svg[^>]*>", svg)
        assert matched is not None, "SVG tag not found"
        i = matched.span()[1]
        style = f"<style>{style_text}</style>"
        example["svg"] = svg[:i] + style + svg[i:]

    def render(self, example: dict, save_dir: str | Path | None = None) -> Image.Image:
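        """Render a decoded example to a PIL image via headless Chrome.

        Writes the SVG (wrapped in a minimal HTML page) and its images to a
        temporary directory, screenshots the page at the SVG's declared size,
        and optionally copies the working directory to `save_dir`. Assumes a
        `google-chrome` binary is available on PATH.
        """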
        with tempfile.TemporaryDirectory() as tmp_dir:
            self.render_preprocess(example, tmp_dir)

            # Read the canvas size from the SVG attributes.
            matched = self.re_svg_width.search(example["svg"])
            assert matched is not None, "Width not found in SVG."
            width = int(matched.group(1))
            matched = self.re_svg_height.search(example["svg"])
            assert matched is not None, "Height not found in SVG."
            height = int(matched.group(1))

            html = '<!DOCTYPE html><html><body style="margin: 0px">'
            html += f"{example['svg']}</body></html>"

            Path(f"{tmp_dir}/index.html").write_text(html, encoding="utf-8")

            for img, filename in zip(example["images"], example["filenames"]):
                Path(f"{tmp_dir}/{filename}").parent.mkdir(parents=True, exist_ok=True)
                img.save(f"{tmp_dir}/{filename}")

            # Screenshot the page with headless Chrome.
            command = [
                "google-chrome",
                "--headless",
                "--disable-web-security",
                "--allow-running-insecure-content",
                "--no-sandbox",
                "--disable-infobars",
                "--hide-scrollbars",
                "--disable-dev-shm-usage",
                "--no-zygote",
                f"--window-size={width},{height}",
                f"--screenshot={tmp_dir}/screenshot.png",
                f"{tmp_dir}/index.html",
            ]
            subprocess.run(command, check=True, stderr=subprocess.DEVNULL)

            out = Image.open(f"{tmp_dir}/screenshot.png")
            size = (width, height)
            out = out.resize(size, resample=Image.Resampling.LANCZOS)

            if save_dir is not None:
                shutil.copytree(tmp_dir, save_dir, dirs_exist_ok=True)

        return out