"""Processor class for MarkupDM."""

import math
import re
import shutil
import subprocess
import tempfile
from pathlib import Path

import numpy as np
import torch
from PIL import Image, ImageDraw
from transformers import (
    ImageProcessingMixin,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    ProcessorMixin,
)
from transformers.utils import logging

from .fonts import FontManager

logger = logging.get_logger(__name__)

MAXIMUM_DECODE_IMAGE_SIZE = 4096
IMG_FORMAT = "{:03d}.png"
FONT_FORMAT = "{:03d}.ttf"


class MarkupDMProcessor(ProcessorMixin):  # type: ignore
    attributes = ["tokenizer", "image_processor"]

    # The superclass checks if the tokenizer is a subclass of `PreTrainedTokenizerBase`
    tokenizer_class = "AutoTokenizer"
    tokenizer: PreTrainedTokenizerBase

    # and the image_processor is a subclass of `ImageProcessingMixin`.
    image_processor_class = "AutoImageProcessor"
    image_processor: ImageProcessingMixin

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        image_processor: ImageProcessingMixin,
    ):
        super().__init__(tokenizer, image_processor)

        # Extend the tokenizer if it has not been extended yet.
        if "<begin_of_image>" not in tokenizer.additional_special_tokens:
            self.extend_base_tokenizer(self.tokenizer)

        # Regular expressions
        boi = "<begin_of_image>"
        img_sep = "<image_sep>"
        self.re_img_size = re.compile(rf"{boi}(\d+){img_sep}(\d+){img_sep}")
        self.re_svg_width = re.compile(r'<svg[^>]*\bwidth="(\d+)"[^>]*>')
        self.re_svg_height = re.compile(r'<svg[^>]*\bheight="(\d+)"[^>]*>')

        # Font manager
        self.font_manager = None

    def extend_base_tokenizer(self, tokenizer: PreTrainedTokenizerBase) -> None:
        logger.info("Extending tokenizer...")
        tokenizer.clean_up_tokenization_spaces = False

        # Add special tokens
        additional_special_tokens = [
            "<begin_of_image>",
            "<end_of_image>",
            "<image_sep>",
            "<image_token>",
        ]
        logger.info(f"Add special tokens: {additional_special_tokens}")
        tokenizer.add_special_tokens(
            {"additional_special_tokens": additional_special_tokens},
            replace_additional_special_tokens=False,
        )

    def __call__(
        self,
        svg: str | None = None,
        images: list[Image.Image] | None = None,
        filenames: list[str] | None = None,
        vision_model: PreTrainedModel | None = None,
    ) -> dict:
        # Process images
        if not isinstance(images, list):
            images = [images]  # type: ignore

        if len(images) > 0 and images[0] is not None:
            output = self.preprocess_images(images)
            output = self.encode_images(output, vision_model)
        else:
            output = {"width": [], "height": [], "image_ids": []}

        # Process the entire example
        output.update({"svg": svg, "filenames": filenames})
        output = self.tokenize_example(output)

        return output
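
    # Illustrative usage (hypothetical variable names; assumes a pretrained
    # tokenizer, image processor, and MarkupDM vision model loaded elsewhere):
    #
    #   processor = MarkupDMProcessor(tokenizer, image_processor)
    #   batch = processor(
    #       svg=svg_markup,
    #       images=[Image.open("000.png")],
    #       filenames=["000.png"],
    #       vision_model=vision_model,
    #   )
    #   # batch -> {"input_ids", "image_mask", "image_pos_ids"}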

    def preprocess_images(self, images: list[Image.Image]) -> dict:
        assert images is not None, "Images must be provided."
        output: dict = {"image": [], "width": [], "height": []}

        for image in images:
            processed = self.image_processor(image)
            for key, value in processed.items():
                output[key].append(value)

        # Stack tensors
        output["image"] = torch.stack(output["image"])

        return output

    def encode_images(self, example: dict, vision_model: PreTrainedModel) -> dict:
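        """Encode preprocessed images into discrete codebook ids with the
        vision model's quantizing encoder; stores one flat id tensor per
        image under "image_ids".
        """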
        if "images" in example and "width" not in example:
            example = self.preprocess_images(example["images"])

        assert vision_model is not None, "Vision model must be provided."
        image = example.pop("image")
        image = image.to(dtype=vision_model.dtype, device=vision_model.device)
        with torch.inference_mode():
            _, _, (_, _, image_ids) = vision_model.model.encode(image)
        example["image_ids"] = list(image_ids.view(image.size(0), -1).cpu())

        return example

    def tokenize_example(self, example: dict) -> dict:
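        """Tokenize the SVG markup, splicing in image tokens wherever a known
        image filename occurs. Each image is serialized as

            <begin_of_image> W <image_sep> H <image_sep> ids... <end_of_image>

        where W and H are the text-encoded pixel dimensions and `ids` are the
        vision codebook indices shifted above the text vocabulary.
        """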
        # Validate the input example
        for key in ["svg", "filenames", "width", "height", "image_ids"]:
            msg = f"Missing key: {key}."
            if key in ["width", "height", "image_ids"]:
                msg += " Images must be encoded first using `encode_images`."
            assert example.get(key, None) is not None, msg

        tokenizer = self.tokenizer
        bos_id = tokenizer.bos_token_id
        eos_id = tokenizer.eos_token_id
        bos_id = bos_id if bos_id is not None else eos_id
        boi_id = tokenizer.convert_tokens_to_ids("<begin_of_image>")
        eoi_id = tokenizer.convert_tokens_to_ids("<end_of_image>")
        img_sep_id = tokenizer.convert_tokens_to_ids("<image_sep>")

        # Tokenize images and build a mapping from image filenames to tokens
        name2token = {}
        for filename, image_ids, width, height in zip(
            example["filenames"],
            example["image_ids"],
            example["width"],
            example["height"],
        ):
            _image_ids = (image_ids + len(tokenizer)).tolist()
            W_tokens = tokenizer.encode(str(width))
            H_tokens = tokenizer.encode(str(height))

            # Image tokens
            image_tokens = [
                boi_id,
                *W_tokens,
                img_sep_id,
                *H_tokens,
                img_sep_id,
                *_image_ids,
                eoi_id,
            ]

            name2token[filename] = image_tokens

        # Tokenize SVG
        # TODO: remove bos_id, as it no longer seems necessary in modern practice
        tokens = [bos_id]
        svg = example["svg"]
        while svg:
            # Find the start position of the next image filename
            start, end = len(svg), len(svg)
            for name in name2token:
                _start = svg.find(name)
                if -1 < _start < start:
                    start = _start
                    end = start + len(name)

            # Tokenize the text before the image filename
            tokens += tokenizer.encode(svg[:start])

            # Append the tokenized image
            if start < end:
                tokens += name2token[svg[start:end]]

            # Update the remaining text
            svg = svg[end:]

        tokens.append(eos_id)

        # Format output data
        input_ids = torch.tensor(tokens)
        image_mask = input_ids >= len(tokenizer)

        # Compute image position ids
        image_pos_ids = torch.zeros_like(input_ids)
        if len(example["image_ids"]) > 0:
            length = example["image_ids"][0].size(0)
            num_images = int(image_mask.sum()) // length
            image_pos_ids[image_mask] = torch.arange(length).repeat(num_images)

        return {
            "input_ids": input_ids,
            "image_mask": image_mask,
            "image_pos_ids": image_pos_ids,
        }

    def decode(
        self,
        tokens: torch.Tensor | np.ndarray,
        vision_model: PreTrainedModel | None = None,
    ) -> dict:
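        """Invert `tokenize_example`: turn a 1D token sequence back into SVG
        markup plus the PIL images reconstructed from image-token spans.
        Each image span is replaced in the markup by a generated filename
        ("000.png", "001.png", ...). FIM-formatted output must be reverted
        to plain order before decoding.
        """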
        tokenizer = self.tokenizer
        bos = tokenizer.bos_token
        eos = tokenizer.eos_token
        bos = bos if bos is not None else eos

        # Validate the input tokens
        msg = "Should be reverted from FIM format before decoding."
        for fim_type in ["prefix", "middle", "suffix"]:
            token_id = tokenizer.convert_tokens_to_ids(f"<fim_{fim_type}>")
            if token_id is None:
                token_id = tokenizer.convert_tokens_to_ids(f"<|fim_{fim_type}|>")
            assert token_id is not None, f"{fim_type} token not found"
            assert token_id not in tokens, msg

        tokens = torch.asarray(tokens).detach().cpu()
        assert tokens.ndim == 1, "Tokens must be 1D."
        boi_id = tokenizer.convert_tokens_to_ids("<begin_of_image>")
        eoi_id = tokenizer.convert_tokens_to_ids("<end_of_image>")

        # Decode tokens
        svg = ""
        images: list = []
        filenames: list = []
        while len(tokens) > 0:
            # Find the start position of the next image filename
            boi_idx = torch.where(tokens == boi_id)[0]
            eoi_idx = torch.where(tokens == eoi_id)[0]
            if boi_idx.size(0) > 0:
                start = int(boi_idx[0].item())
                end = int(eoi_idx[0].item()) + 1 if eoi_idx.size(0) > 0 else len(tokens)
                assert start < end, "Invalid image tokens."
            else:
                start, end = len(tokens), len(tokens)

            # Decode the tokens before the image tokens
            svg += tokenizer.decode(tokens[:start])

            # Decode the image tokens
            if start < end:
                # Extract image size
                image_tokens = tokens[start:end]
                image_text = tokenizer.decode(image_tokens)
                matched = self.re_img_size.match(image_text)
                if matched is not None:
                    width, height = map(int, matched.groups())
                else:
                    width = self.image_processor.size
                    height = self.image_processor.size

                # Decode tokens to PIL image
                image_mask = image_tokens >= len(tokenizer)
                image_ids = image_tokens[image_mask] - len(tokenizer)
                image = self.decode_image(vision_model, image_ids, width, height)
                filename = IMG_FORMAT.format(len(images))
                svg += filename

                images.append(image)
                filenames.append(filename)

            # Update the remaining tokens
            tokens = tokens[end:]

        # Remove consecutive <bos> and <eos>
        svg = re.sub(rf"({re.escape(bos)})+", bos, svg)
        svg = re.sub(rf"({re.escape(eos)})+", eos, svg)

        # Extract the text between <bos> and <eos>
        i_bos = svg.find(bos)
        svg = svg[i_bos + len(bos) :] if i_bos > -1 else svg
        i_eos = svg.find(eos)
        svg = svg[:i_eos] if i_eos > -1 else svg

        return {"svg": svg, "images": images, "filenames": filenames}

    def decode_image(
        self,
        vision_model: PreTrainedModel | None = None,
        image_ids: torch.Tensor | np.ndarray | None = None,
        width: int | None = None,
        height: int | None = None,
        dummy_color: tuple[int, int, int, int] = (200,) * 4,
        pad_value: int = 0,
    ) -> Image.Image:
        # Prepare image size
        width = width or self.image_processor.size
        height = height or self.image_processor.size
        width, height = self.compute_safe_image_size(width, height)

        if vision_model is None or image_ids is None:
            # Return a dummy image
            return Image.new("RGBA", (width, height), dummy_color)

        # Compute required length
        assert vision_model is not None, "Vision model must be provided."
        scale_factor = 2 ** (vision_model.model.encoder.num_resolutions - 1)
        latent_size = self.image_processor.size // scale_factor
        required_length = latent_size**2

        # Pad image ids if necessary
        image_ids = torch.asarray(image_ids, device=vision_model.device)
        code_length = image_ids.shape[0]  # type: ignore
        if code_length < required_length:
            pad_size = required_length - code_length
            pad = torch.full((pad_size,), pad_value).to(image_ids)
            image_ids = torch.cat([image_ids, pad])

        # Decode image
        with torch.inference_mode():
            codebook_entry = vision_model.model.quantize.get_codebook_entry(
                image_ids, (1, latent_size, latent_size, -1)
            )
            recon = vision_model.model.decode(codebook_entry)[0].float()

        # Postprocess image
        img = self.image_processor.postprocess(
            recon, self.image_processor.size, self.image_processor.size
        )

        # Mask the padded area
        if code_length < required_length:
            img = self.mask_padded_area(img, code_length, scale_factor)

        # Resize the image to the original size
        img = img.resize((width, height), resample=self.image_processor.resample)

        return img  # type: ignore

    def compute_safe_image_size(self, width: int, height: int) -> tuple[int, int]:
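        """Clamp the size so the long edge does not exceed
        MAXIMUM_DECODE_IMAGE_SIZE while preserving the aspect ratio,
        e.g. (8192, 4096) is scaled by 4096 / 8192 = 0.5 to (4096, 2048).
        """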
        long_edge = max(width, height)
        if MAXIMUM_DECODE_IMAGE_SIZE < long_edge:
            scale = MAXIMUM_DECODE_IMAGE_SIZE / long_edge
            width = min(max(int(width * scale), 1), MAXIMUM_DECODE_IMAGE_SIZE)
            height = min(max(int(height * scale), 1), MAXIMUM_DECODE_IMAGE_SIZE)
        return width, height

    def mask_padded_area(
        self,
        img: Image.Image,
        code_length: int,
        scale_factor: int,
        fill: tuple[int, int, int, int] = (200, 200, 200, 255),
    ) -> Image.Image:
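        """Cover the area whose latent codes were padding rather than model
        output. Codes fill the latent grid in row-major order, so the valid
        region ends at column `cw` of row `ch`; the polygon paints everything
        after that point with `fill`.
        """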
        draw = ImageDraw.Draw(img, mode="RGBA")
        width, height = img.size
        zw = math.ceil(width / scale_factor)
        cw = code_length % zw
        ch = code_length // zw
        draw.polygon(
            [
                (cw * scale_factor, ch * scale_factor),
                (width, ch * scale_factor),
                (width, height),
                (0, height),
                (0, (ch + 1) * scale_factor),
                (cw * scale_factor, (ch + 1) * scale_factor),
            ],
            fill=fill,
        )
        return img

    def set_font_manager(self, fonts_path: str | None = None) -> None:
        self.font_manager = FontManager(fonts_path)

    def render_preprocess(self, example: dict, out_dir: str | Path) -> None:
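        """Resolve the fonts referenced by <text> elements, save them next to
        the rendered SVG, and inject a <style> tag with matching @font-face
        rules so the browser render uses the intended fonts.
        """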
        msg = "Font manager is not set. Call `set_font_manager` first."
        assert self.font_manager is not None, msg

        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        svg = example["svg"]

        # Construct style tag
        found = set()
        style_text = "text{dominant-baseline:text-before-edge}"
        for i, text_str in enumerate(re.findall("<text[^>]*>", svg)):
            matched = re.search('font-family="([^"]*)"', text_str)
            if matched is None:
                logger.warning(f"Font family not found in {text_str}")
                continue

            # Parse font attributes
            font_family = matched.group(1)
            is_bold = 'font-weight="bold"' in text_str
            is_italic = 'font-style="italic"' in text_str
            font_weight = "bold" if is_bold else "regular"
            if is_italic:
                font_style = "bolditalic" if is_bold else "italic"
            else:
                font_style = font_weight
            key = (font_family, font_weight, font_style)
            if key in found:
                continue

            font_bytes = self.font_manager.lookup(
                font_family=font_family,
                font_weight=font_weight,
                font_style=font_style,
            )

            # @font-face
            font_path = FONT_FORMAT.format(i)
            font_face = "@font-face{"
            font_face += f"font-family:'{font_family}';"
            font_face += f"font-weight:{font_weight};"
            font_face += f"font-style:{font_style};"
            font_face += f"src:url('{font_path}');"
            font_face += "}"
            style_text += font_face

            # Save font
            Path(f"{out_dir}/{font_path}").write_bytes(font_bytes)
            found.add(key)

        # Insert style tag
        matched = re.search("<svg[^>]*>", svg)
        assert matched is not None, "SVG tag not found"
        i = matched.span()[1]
        style = f"<style>{style_text}</style>"
        example["svg"] = svg[:i] + style + svg[i:]

    def render(self, example: dict, save_dir: str | Path | None = None) -> Image.Image:
        with tempfile.TemporaryDirectory() as tmp_dir:
            self.render_preprocess(example, tmp_dir)

            # Parse the SVG size
            matched = self.re_svg_width.search(example["svg"])
            assert matched is not None, "Width not found in SVG."
            width = int(matched.group(1))
            matched = self.re_svg_height.search(example["svg"])
            assert matched is not None, "Height not found in SVG."
            height = int(matched.group(1))

            # Convert SVG to HTML
            html = '<!DOCTYPE html><html><body style="margin: 0px">'
            html += f"{example['svg']}</body></html>"

            # Save HTML
            Path(f"{tmp_dir}/index.html").write_text(html, encoding="utf-8")

            # Save images
            for img, filename in zip(example["images"], example["filenames"]):
                Path(f"{tmp_dir}/{filename}").parent.mkdir(parents=True, exist_ok=True)
                img.save(f"{tmp_dir}/{filename}")

            # Take screenshot
            command = [
                "google-chrome",
                "--headless",
                "--disable-web-security",
                "--allow-running-insecure-content",
                "--no-sandbox",
                "--disable-infobars",
                "--hide-scrollbars",
                "--disable-dev-shm-usage",
                "--no-zygote",
                f"--window-size={width},{height}",
                f"--screenshot={tmp_dir}/screenshot.png",
                f"{tmp_dir}/index.html",
            ]
            subprocess.run(command, check=True, stderr=subprocess.DEVNULL)

            # Load the screenshot as PIL image
            out = Image.open(f"{tmp_dir}/screenshot.png")
            size = (width, height)
            out = out.resize(size, resample=Image.Resampling.LANCZOS)  # type: ignore

            # Copy the result if save_dir is specified
            if save_dir is not None:
                shutil.copytree(tmp_dir, save_dir, dirs_exist_ok=True)

        return out
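

# Illustrative end-to-end sketch (hypothetical names; assumes generated
# tokens from a MarkupDM model, its vision model, a local fonts directory,
# and a headless google-chrome on PATH for `render`):
#
#   processor.set_font_manager("fonts/")
#   decoded = processor.decode(generated_tokens, vision_model=vision_model)
#   preview = processor.render(decoded, save_dir="out")
#   preview.save("out/preview.png")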