"""Processor class for MarkupDM.""" import math import re import shutil import subprocess import tempfile from pathlib import Path import numpy as np import torch from .fonts import FontManager from PIL import Image, ImageDraw from transformers import ( ImageProcessingMixin, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, ) from transformers.utils import logging logger = logging.get_logger(__name__) MAXIMUM_DECODE_IMAGE_SIZE = 4096 IMG_FORMAT = "{:03d}.png" FONT_FORMAT = "{:03d}.ttf" class MarkupDMProcessor(ProcessorMixin): # type: ignore attributes = ["tokenizer", "image_processor"] # The superclass checks if the tokenizer is a subclass of `PreTrainedTokenizerBase` tokenizer_class = "AutoTokenizer" tokenizer: PreTrainedTokenizerBase # and the image_processor is a subclass of `ImageProcessingMixin`. image_processor_class = "AutoImageProcessor" image_processor: ImageProcessingMixin def __init__( self, tokenizer: PreTrainedTokenizerBase, image_processor: ImageProcessingMixin, ): super().__init__(tokenizer, image_processor) # Extend the tokenizer if it has not been extended yet. if "" not in tokenizer.additional_special_tokens: self.extend_base_tokenizer(self.tokenizer) # Regular expressions boi = "" img_sep = "" self.re_img_size = re.compile(rf"{boi}(\d+){img_sep}(\d+){img_sep}") self.re_svg_width = re.compile(r']*\bwidth="(\d+)"[^>]*>') self.re_svg_height = re.compile(r']*\bheight="(\d+)"[^>]*>') # Font manager self.font_manager = None def extend_base_tokenizer(self, tokenizer: PreTrainedTokenizerBase) -> None: logger.info("Extending tokenizer...") tokenizer.clean_up_tokenization_spaces = False # Add special tokens additional_special_tokens = [ "", "", "", "", ] logger.info(f"Add special tokens: {additional_special_tokens}") tokenizer.add_special_tokens( {"additional_special_tokens": additional_special_tokens}, replace_additional_special_tokens=False, ) def __call__( self, svg: str | None = None, images: list[Image.Image] | None = None, filenames: list[str] | None = None, vision_model: PreTrainedModel | None = None, ) -> dict: # Process images if not isinstance(images, list): images = [images] # type: ignore if len(images) > 0 and images[0] is not None: output = self.preprocess_images(images) output = self.encode_images(output, vision_model) else: output = {"width": [], "height": [], "image_ids": []} # Process the entire example output.update({"svg": svg, "filenames": filenames}) output = self.tokenize_example(output) return output def preprocess_images(self, images: list[Image.Image]) -> dict: assert images is not None, "Images must be provided." output: dict = {"image": [], "width": [], "height": []} for image in images: processed = self.image_processor(image) for key, value in processed.items(): output[key].append(value) # Stack tensors output["image"] = torch.stack(output["image"]) return output def encode_images(self, example: dict, vision_model: PreTrainedModel) -> dict: if "images" in example and "width" not in example: example = self.preprocess_images(example["images"]) assert vision_model is not None, "Vision model must be provided." image = example.pop("image") image = image.to(dtype=vision_model.dtype, device=vision_model.device) with torch.inference_mode(): _, _, (_, _, image_ids) = vision_model.model.encode(image) example["image_ids"] = list(image_ids.view(image.size(0), -1).cpu()) return example def tokenize_example(self, example: dict) -> dict: # Validate the input example for key in ["svg", "filenames", "width", "height", "image_ids"]: msg = f"Missing key: {key}." 
if key in ["width", "height", "image_ids"]: msg += " Images must be encoded first using `encode_images`." assert example.get(key, None) is not None, msg tokenizer = self.tokenizer bos_id = tokenizer.bos_token_id eos_id = tokenizer.eos_token_id bos_id = bos_id if bos_id is not None else eos_id boi_id = tokenizer.convert_tokens_to_ids("") eoi_id = tokenizer.convert_tokens_to_ids("") img_sep_id = tokenizer.convert_tokens_to_ids("") # Tokenize images and build a mapping from image filenames to tokens name2token = {} for filename, image_ids, width, height in zip( example["filenames"], example["image_ids"], example["width"], example["height"], ): _image_ids = (image_ids + len(tokenizer)).tolist() W_tokens = tokenizer.encode(str(width)) H_tokens = tokenizer.encode(str(height)) # Image tokens image_tokens = [ boi_id, *W_tokens, img_sep_id, *H_tokens, img_sep_id, *_image_ids, eoi_id, ] name2token[filename] = image_tokens # Tokenize SVG # TODO: remove bos_id as it seems to be not necessary in modern practice tokens = [bos_id] svg = example["svg"] while svg: # Find the start position of the next image filename start, end = len(svg), len(svg) for name in name2token.keys(): _start = svg.find(name) if -1 < _start and _start < start: start = _start end = start + len(name) # Tokenize the text before the image filename tokens += tokenizer.encode(svg[:start]) # Append the tokenized image if start < end: tokens += name2token[svg[start:end]] # Update the remaining text svg = svg[end:] tokens.append(eos_id) # Format output data input_ids = torch.tensor(tokens) image_mask = input_ids >= len(tokenizer) # Compute image position ids image_pos_ids = torch.zeros_like(input_ids) if len(example["image_ids"]) > 0: length = example["image_ids"][0].size(0) num_images = sum(image_mask) // length image_pos_ids[image_mask] = torch.arange(length).repeat(num_images) return { "input_ids": input_ids, "image_mask": image_mask, "image_pos_ids": image_pos_ids, } def decode( self, tokens: torch.Tensor | np.ndarray, vision_model: PreTrainedModel | None = None, ) -> dict: tokenizer = self.tokenizer bos = tokenizer.bos_token eos = tokenizer.eos_token bos = bos if bos is not None else eos # Validate the input tokens msg = "Should be reverted from FIM format before decoding." for fim_type in ["prefix", "middle", "suffix"]: token_id = tokenizer.convert_tokens_to_ids(f"") if token_id is None: token_id = tokenizer.convert_tokens_to_ids(f"<|fim_{fim_type}|>") assert token_id is not None, f"{fim_type} token not found" assert token_id not in tokens, msg tokens = torch.asarray(tokens).detach().cpu() assert tokens.ndim == 1, "Tokens must be 1D." boi_id = tokenizer.convert_tokens_to_ids("") eoi_id = tokenizer.convert_tokens_to_ids("") # Decode tokens svg = "" images: list = [] filenames: list = [] while len(tokens) > 0: # Find the start position of the next image filename boi_idx = torch.where(tokens == boi_id)[0] eoi_idx = torch.where(tokens == eoi_id)[0] if boi_idx.size(0) > 0: start = int(boi_idx[0].item()) end = int(eoi_idx[0].item()) + 1 if eoi_idx.size(0) > 0 else len(tokens) assert start < end, "Invalid image tokens." 
    def decode(
        self,
        tokens: torch.Tensor | np.ndarray,
        vision_model: PreTrainedModel | None = None,
    ) -> dict:
        tokenizer = self.tokenizer
        bos = tokenizer.bos_token
        eos = tokenizer.eos_token
        bos = bos if bos is not None else eos

        # Validate the input tokens
        msg = "Should be reverted from FIM format before decoding."
        for fim_type in ["prefix", "middle", "suffix"]:
            token_id = tokenizer.convert_tokens_to_ids(f"<fim_{fim_type}>")
            if token_id is None:
                token_id = tokenizer.convert_tokens_to_ids(f"<|fim_{fim_type}|>")
            assert token_id is not None, f"{fim_type} token not found"
            assert token_id not in tokens, msg

        tokens = torch.asarray(tokens).detach().cpu()
        assert tokens.ndim == 1, "Tokens must be 1D."
        boi_id = tokenizer.convert_tokens_to_ids("<boi>")
        eoi_id = tokenizer.convert_tokens_to_ids("<eoi>")

        # Decode tokens
        svg = ""
        images: list = []
        filenames: list = []
        while len(tokens) > 0:
            # Find the start position of the next image block
            boi_idx = torch.where(tokens == boi_id)[0]
            eoi_idx = torch.where(tokens == eoi_id)[0]
            if boi_idx.size(0) > 0:
                start = int(boi_idx[0].item())
                end = int(eoi_idx[0].item()) + 1 if eoi_idx.size(0) > 0 else len(tokens)
                assert start < end, "Invalid image tokens."
            else:
                start, end = len(tokens), len(tokens)

            # Decode the tokens before the image tokens
            svg += tokenizer.decode(tokens[:start])

            # Decode the image tokens
            if start < end:
                # Extract image size
                image_tokens = tokens[start:end]
                image_text = tokenizer.decode(image_tokens)
                matched = self.re_img_size.match(image_text)
                if matched is not None:
                    width, height = map(int, matched.groups())
                else:
                    width = self.image_processor.size
                    height = self.image_processor.size

                # Decode tokens to a PIL image
                image_mask = image_tokens >= len(tokenizer)
                image_ids = image_tokens[image_mask] - len(tokenizer)
                image = self.decode_image(vision_model, image_ids, width, height)
                filename = IMG_FORMAT.format(len(images))
                svg += filename
                images.append(image)
                filenames.append(filename)

            # Update the remaining tokens
            tokens = tokens[end:]

        # Collapse consecutive bos and eos tokens
        svg = re.sub(rf"({re.escape(bos)})+", bos, svg)
        svg = re.sub(rf"({re.escape(eos)})+", eos, svg)

        # Extract the text between bos and eos
        i_bos = svg.find(bos)
        svg = svg[i_bos + len(bos) :] if i_bos > -1 else svg
        i_eos = svg.find(eos)
        svg = svg[:i_eos] if i_eos > -1 else svg
        return {"svg": svg, "images": images, "filenames": filenames}
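    # Round-trip sketch (illustrative; assumes the ids were produced by
    # `__call__` or generated by a model sharing the same `vision_model`):
    #
    #   decoded = processor.decode(example["input_ids"], vision_model=vqgan)
    #   decoded["svg"]        # markup with image blocks replaced by "000.png", ...
    #   decoded["images"]     # reconstructed PIL images
    #   decoded["filenames"]  # ["000.png", ...]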
    def decode_image(
        self,
        vision_model: PreTrainedModel | None = None,
        image_ids: torch.Tensor | np.ndarray | None = None,
        width: int | None = None,
        height: int | None = None,
        dummy_color: tuple[int, int, int, int] = (200,) * 4,
        pad_value: int = 0,
    ) -> Image.Image:
        # Prepare image size
        width = width or self.image_processor.size
        height = height or self.image_processor.size
        width, height = self.compute_safe_image_size(width, height)
        if vision_model is None and image_ids is None:
            # Return a dummy image
            return Image.new("RGBA", (width, height), dummy_color)

        # Compute required length
        assert vision_model is not None, "Vision model must be provided."
        scale_factor = 2 ** (vision_model.model.encoder.num_resolutions - 1)
        latent_size = self.image_processor.size // scale_factor
        required_length = latent_size**2

        # Pad image ids if necessary
        image_ids = torch.asarray(image_ids, device=vision_model.device)
        code_length = image_ids.shape[0]  # type: ignore
        if code_length < required_length:
            pad_size = required_length - code_length
            pad = torch.full((pad_size,), pad_value).to(image_ids)
            image_ids = torch.cat([image_ids, pad])

        # Decode image
        with torch.inference_mode():
            codebook_entry = vision_model.model.quantize.get_codebook_entry(
                image_ids, (1, latent_size, latent_size, -1)
            )
            recon = vision_model.model.decode(codebook_entry)[0].float()

        # Postprocess image
        img = self.image_processor.postprocess(
            recon, self.image_processor.size, self.image_processor.size
        )

        # Mask the padded area
        if code_length < required_length:
            img = self.mask_padded_area(img, code_length, scale_factor)

        # Resize the image to the original size
        img = img.resize((width, height), resample=self.image_processor.resample)
        return img  # type: ignore

    def compute_safe_image_size(self, width: int, height: int) -> tuple[int, int]:
        long_edge = max(width, height)
        if MAXIMUM_DECODE_IMAGE_SIZE < long_edge:
            scale = MAXIMUM_DECODE_IMAGE_SIZE / long_edge
            width = min(max(int(width * scale), 1), MAXIMUM_DECODE_IMAGE_SIZE)
            height = min(max(int(height * scale), 1), MAXIMUM_DECODE_IMAGE_SIZE)
        return width, height

    def mask_padded_area(
        self,
        img: Image.Image,
        code_length: int,
        scale_factor: int,
        fill: tuple[int, int, int, int] = (200, 200, 200, 255),
    ) -> Image.Image:
        draw = ImageDraw.Draw(img, mode="RGBA")
        width, height = img.size
        zw = math.ceil(width / scale_factor)
        cw = code_length % zw
        ch = code_length // zw
        draw.polygon(
            [
                (cw * scale_factor, ch * scale_factor),
                (width, ch * scale_factor),
                (width, height),
                (0, height),
                (0, (ch + 1) * scale_factor),
                (cw * scale_factor, (ch + 1) * scale_factor),
            ],
            fill=fill,
        )
        return img

    def set_font_manager(self, fonts_path: str | None = None) -> None:
        self.font_manager = FontManager(fonts_path)
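    # Worked example for `mask_padded_area` (values illustrative): the codes
    # fill the latent grid in raster order, so with scale_factor=16 and a
    # 256x256 canvas, zw = ceil(256/16) = 16 codes per row; code_length=40
    # covers rows 0-1 fully plus cw = 40 % 16 = 8 codes of row ch = 40 // 16
    # = 2, and the polygon grays out everything past latent cell (8, 2).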
    def render_preprocess(self, example: dict, out_dir: str | Path) -> None:
        msg = "Font manager is not set. Call `set_font_manager` first."
        assert self.font_manager is not None, msg
        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        svg = example["svg"]

        # Construct the style tag
        found = set()
        style_text = "text{dominant-baseline:text-before-edge}"
        for i, text_str in enumerate(re.findall("<text[^>]*>", svg)):
            matched = re.search('font-family="([^"]*)"', text_str)
            if matched is None:
                logger.warning(f"Font family not found in {text_str}")
                continue

            # Parse font attributes
            font_family = matched.group(1)
            is_bold = 'font-weight="bold"' in text_str
            is_italic = 'font-style="italic"' in text_str
            font_weight = "bold" if is_bold else "regular"
            if is_italic:
                font_style = "bolditalic" if is_bold else "italic"
            else:
                font_style = font_weight
            key = (font_family, font_weight, font_style)
            if key in found:
                continue
            font_bytes = self.font_manager.lookup(
                font_family=font_family,
                font_weight=font_weight,
                font_style=font_style,
            )

            # @font-face
            font_path = FONT_FORMAT.format(i)
            font_face = "@font-face{"
            font_face += f"font-family:'{font_family}';"
            font_face += f"font-weight:{font_weight};"
            font_face += f"font-style:{font_style};"
            font_face += f"src:url('{font_path}');"
            font_face += "}"
            style_text += font_face

            # Save font
            Path(f"{out_dir}/{font_path}").write_bytes(font_bytes)
            found.add(key)

        # Insert the style tag right after the opening <svg> tag
        matched = re.search("<svg[^>]*>", svg)
        assert matched is not None, "SVG tag not found"
        i = matched.span()[1]
        style = f"<style>{style_text}</style>"
        example["svg"] = svg[:i] + style + svg[i:]

    def render(self, example: dict, save_dir: str | Path | None = None) -> Image.Image:
        with tempfile.TemporaryDirectory() as tmp_dir:
            self.render_preprocess(example, tmp_dir)

            # Parse the SVG size
            matched = self.re_svg_width.search(example["svg"])
            assert matched is not None, "Width not found in SVG."
            width = int(matched.group(1))
            matched = self.re_svg_height.search(example["svg"])
            assert matched is not None, "Height not found in SVG."
            height = int(matched.group(1))

            # Convert SVG to HTML
            html = '<meta charset="utf-8">'
            html += f"{example['svg']}"

            # Save HTML
            Path(f"{tmp_dir}/index.html").write_text(html, encoding="utf-8")

            # Save images
            for img, filename in zip(example["images"], example["filenames"]):
                Path(f"{tmp_dir}/{filename}").parent.mkdir(parents=True, exist_ok=True)
                img.save(f"{tmp_dir}/{filename}")

            # Take a screenshot with headless Chrome
            command = [
                "google-chrome",
                "--headless",
                "--disable-web-security",
                "--allow-running-insecure-content",
                "--no-sandbox",
                "--disable-infobars",
                "--hide-scrollbars",
                "--disable-dev-shm-usage",
                "--no-zygote",
                f"--window-size={width},{height}",
                f"--screenshot={tmp_dir}/screenshot.png",
                f"{tmp_dir}/index.html",
            ]
            subprocess.run(command, check=True, stderr=subprocess.DEVNULL)

            # Load the screenshot as a PIL image
            out = Image.open(f"{tmp_dir}/screenshot.png")
            size = (width, height)
            out = out.resize(size, resample=Image.Resampling.LANCZOS)  # type: ignore

            # Copy the result if save_dir is specified
            if save_dir is not None:
                shutil.copytree(tmp_dir, save_dir, dirs_exist_ok=True)
        return out
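# End-to-end rendering sketch (illustrative; `tokenizer`, `image_processor`,
# and `vqgan` are assumed to be loaded elsewhere, and `google-chrome` must be
# on PATH for `render` to take the screenshot):
#
#   processor = MarkupDMProcessor(tokenizer, image_processor)
#   processor.set_font_manager("path/to/fonts")
#   decoded = processor.decode(generated_ids, vision_model=vqgan)
#   preview = processor.render(decoded, save_dir="out/")
#   preview.save("preview.png")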