| | import json |
| | from tqdm import tqdm |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import xml.etree.ElementTree as ET |
| | from xml.dom import minidom |
| | import os |
| | from PIL import Image |
| | import matplotlib.animation as animation |
| | import copy |
| | from PIL import ImageEnhance |
| | import colorsys |
| | import matplotlib.colors as mcolors |
| | from matplotlib.collections import LineCollection |
| | from matplotlib.patheffects import withStroke |
| | import random |
| | import warnings |
| | from matplotlib.figure import Figure |
| | from io import BytesIO |
| | from matplotlib.animation import FuncAnimation, FFMpegWriter, PillowWriter |
| | import requests |
| | import zipfile |
| | import base64 |
| |
|
| |
|
# Install a global "ignore" filter: every warning in the process is silenced
# from here on, including warnings raised by code in other modules.
warnings.filterwarnings(action="ignore")
| |
|
| |
|
def get_svg_content(svg_path, encoding="utf-8"):
    """Return the full text of the SVG file at ``svg_path``.

    Args:
        svg_path: Path to the SVG file.
        encoding: Text encoding used to decode the file. Defaults to UTF-8
            (the XML default). The previous implicit ``open`` used the
            platform's locale encoding, which can mis-decode or reject
            non-ASCII SVG content (e.g. on Windows).

    Returns:
        The file contents as a ``str``.
    """
    with open(svg_path, "r", encoding=encoding) as file:
        return file.read()
| |
|
| |
|
def download_file(url, filename, timeout=60, chunk_size=1 << 20):
    """Download ``url`` to ``filename`` unless ``filename`` already exists.

    An existing ``filename`` is treated as a finished download and the
    function returns immediately. Because of that cache check, the body is
    streamed into ``<filename>.part`` and renamed onto ``filename`` only after
    a successful (non-4xx/5xx) response has been fully received. Otherwise an
    HTTP error page or an interrupted transfer would be written to
    ``filename`` and silently reused on every later call.

    Args:
        url: URL to fetch with ``requests.get``.
        filename: Destination path.
        timeout: Seconds passed to ``requests.get`` as its timeout; ``None``
            waits indefinitely.
        chunk_size: Number of bytes requested per ``iter_content`` chunk.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
        OSError: If the file cannot be written or renamed.
    """
    if os.path.exists(filename):
        return
    tmp_filename = f"{filename}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_filename, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Atomic on the same filesystem: readers never observe a partial file.
        os.replace(tmp_filename, filename)
    except BaseException:
        # Remove the partial download (also on Ctrl-C) before propagating.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)
        raise
| |
|
| |
|
def unzip_file(filename, extract_to="."):
    """Extract every member of the ZIP archive ``filename`` into ``extract_to``.

    Args:
        filename: Path to the ZIP archive.
        extract_to: Target directory; defaults to the current working directory.
    """
    with zipfile.ZipFile(filename, mode="r") as archive:
        archive.extractall(path=extract_to)
| |
|
| |
|
def get_base64_encoded_gif(gif_path):
    """Return the raw bytes of ``gif_path`` encoded as a Base64 text string.

    Args:
        gif_path: Path to the file (read in binary mode; not validated as GIF).

    Returns:
        The standard Base64 encoding of the file contents, as ``str``.
    """
    with open(gif_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
| |
|
| |
|
def load_and_pad_img_dir(file_dir, size=224, fill=(255, 255, 255)):
    """Load an image, scale it to fit a ``size`` x ``size`` square and pad it to that square.

    The aspect ratio is preserved: the image is resized (up or down) so that
    its longer side equals ``size``, then centred on a ``size`` x ``size``
    RGB canvas filled with ``fill``.

    Args:
        file_dir: Path to the image file (despite the name, a file, not a
            directory).
        size: Side length of the square output; defaults to 224.
        fill: RGB colour of the padding; defaults to white.

    Returns:
        A new ``size`` x ``size`` RGB ``PIL.Image.Image``.
    """
    # The context manager closes the underlying file handle deterministically.
    with Image.open(file_dir) as image:
        width, height = image.size
        ratio = min(size / width, size / height)
        # round(), not int(): a float product slightly below an integer
        # (e.g. 223.999...) used to truncate to 223, leaving *both* sides short
        # of ``size``. The old code then padded only one axis and returned a
        # non-square image. Clamp to [1, size] so very thin images never
        # request a zero-sized resize.
        new_width = min(size, max(1, round(width * ratio)))
        new_height = min(size, max(1, round(height * ratio)))
        resized = image.resize((new_width, new_height))

    padded_image = Image.new("RGB", (size, size), fill)
    # Centre on both axes; the offset is 0 on the side that already equals ``size``.
    padded_image.paste(resized, ((size - new_width) // 2, (size - new_height) // 2))
    return padded_image
| |
|
| |
|
def plot_ink(ink, ax, lw=1.8, input_image=None, with_path=True, path_color="white"):
    """Draw every stroke of ``ink`` onto the matplotlib axes ``ax``.

    Each stroke gets its own colour sampled from the ``rainbow`` colormap (the
    first stroke uses the last sample), darkened to 65% of its HSV value. Its
    alpha fades linearly from 1.0 towards 0.5 along the stroke's points.
    Axis limits are fixed to 0..224 on both axes and the y axis is inverted
    (image coordinates).

    Args:
        ink: Object with a ``strokes`` list; each stroke exposes ``x`` and
            ``y`` coordinate sequences.
        ax: Matplotlib axes to draw on.
        lw: Line width of the strokes.
        input_image: Optional PIL image drawn underneath at 45% brightness.
        with_path: If True, outline every line with ``path_color`` using a
            ``withStroke`` path effect.
        path_color: Colour of that outline.
    """
    if input_image is not None:
        # ImageEnhance produces a new image, so the caller's image is not modified
        # and no defensive copy is needed.
        ax.imshow(ImageEnhance.Brightness(input_image).enhance(0.45))

    num_strokes = len(ink.strokes)
    # pyplot.get_cmap replaces plt.cm.get_cmap: matplotlib.cm.get_cmap was
    # deprecated in Matplotlib 3.7 and removed in 3.9. lut must be >= 1, so
    # clamp it for an ink without strokes.
    base_colors = plt.get_cmap("rainbow", max(num_strokes, 1))

    for i, stroke in enumerate(ink.strokes):
        x, y = np.array(stroke.x), np.array(stroke.y)

        base_color = base_colors(num_strokes - 1 - i)
        hue, saturation, value = colorsys.rgb_to_hsv(*base_color[:3])
        # value is in [0, 1], so scaling by 0.65 can never go negative.
        darker_color = colorsys.hsv_to_rgb(hue, saturation, value * 0.65)
        colors = [mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / len(x))) for j in range(len(x))]

        # Consecutive point pairs -> array of line segments, shape (n - 1, 2, 2).
        points = np.array([x, y]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)

        lc = LineCollection(segments, colors=colors, linewidth=lw)
        if with_path:
            lc.set_path_effects([withStroke(linewidth=lw * 1.25, foreground=path_color)])
        ax.add_collection(lc)

    ax.set_xlim(0, 224)
    ax.set_ylim(0, 224)
    ax.invert_yaxis()
| |
|
| |
|
def plot_ink_to_video(ink, output_name, lw=1.8, input_image=None, path_color="white", fps=30):
    """Animate ``ink`` being drawn segment by segment and save it as a video file.

    Frame ``k`` shows the first ``k`` line segments of the ink, in stroke
    order. There are ``total points + 1`` frames; a stroke of ``n`` points has
    only ``n - 1`` segments, so the final frames repeat the finished drawing.

    Colouring: each stroke takes a colour sampled from the ``rainbow``
    colormap (first stroke = last sample), darkened to 65% of its HSV value.
    Within the visible part of a stroke, alpha fades from 1.0 towards 0.5.
    Every line is outlined with ``path_color``.

    Args:
        ink: Object with a ``strokes`` list; each stroke exposes ``x`` and
            ``y`` coordinate sequences.
        output_name: Destination file passed to ``Animation.save``.
        lw: Line width of the strokes.
        input_image: Optional PIL image shown beneath the ink at 45% brightness.
        path_color: Colour of the outline around every line.
        fps: Frames per second of the output video.

    The video is written with ``FFMpegWriter``, so an ``ffmpeg`` executable
    must be available. The figure is closed even if saving fails.
    """
    fig, ax = plt.subplots(figsize=(4, 4), dpi=150)
    try:
        # ImageEnhance returns a new image, so the caller's image is never modified.
        background = None
        if input_image is not None:
            background = ImageEnhance.Brightness(input_image).enhance(0.45)

        num_strokes = len(ink.strokes)
        # pyplot.get_cmap replaces plt.cm.get_cmap (matplotlib.cm.get_cmap was
        # deprecated in 3.7 and removed in 3.9); lut must be >= 1 for empty ink.
        base_colors = plt.get_cmap("rainbow", max(num_strokes, 1))

        # Segment arrays and stroke colours do not depend on the frame, so build
        # them once instead of in every update() call.
        stroke_data = []
        for stroke_index, stroke in enumerate(ink.strokes):
            x, y = np.array(stroke.x), np.array(stroke.y)
            points = np.array([x, y]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            base_color = base_colors(num_strokes - 1 - stroke_index)
            hue, saturation, value = colorsys.rgb_to_hsv(*base_color[:3])
            stroke_data.append((segments, colorsys.hsv_to_rgb(hue, saturation, value * 0.65)))
        total_points = sum(len(stroke.x) for stroke in ink.strokes)

        def reset_axes():
            """Clear the axes and restore background, limits and hidden axis."""
            ax.clear()
            if background is not None:
                ax.imshow(background)
            ax.set_xlim(0, 224)
            ax.set_ylim(0, 224)
            ax.invert_yaxis()
            ax.axis("off")

        def update(frame):
            """Redraw the first ``frame`` segments of the ink."""
            reset_axes()
            segments_drawn = 0
            for segments, darker_color in stroke_data:
                # The loop stops once segments_drawn >= frame, so this bound is
                # never negative; slicing past the end yields the whole stroke.
                visible_segments = segments[: frame - segments_drawn]
                if len(visible_segments) > 0:
                    colors = [
                        mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / len(visible_segments)))
                        for j in range(len(visible_segments))
                    ]
                    lc = LineCollection(visible_segments, colors=colors, linewidth=lw)
                    lc.set_path_effects([withStroke(linewidth=lw * 1.25, foreground=path_color)])
                    ax.add_collection(lc)
                segments_drawn += len(segments)
                if segments_drawn >= frame:
                    break

        reset_axes()
        ani = FuncAnimation(fig, update, frames=total_points + 1, blit=False)
        writer = FFMpegWriter(fps=fps)
        fig.tight_layout()
        ani.save(output_name, writer=writer)
    finally:
        # Without this, an exception in ani.save leaked the figure.
        plt.close(fig)
| |
|
| |
|
class Stroke:
    """A single pen stroke, stored as parallel ``x`` and ``y`` coordinate lists."""

    def __init__(self, list_of_coordinates=None) -> None:
        """Build the stroke from a sequence of points.

        Only ``point[0]`` (x) and ``point[1]`` (y) of each point are kept; any
        further entries are ignored. ``None`` or an empty sequence yields an
        empty stroke.
        """
        xs, ys = [], []
        for coord in list_of_coordinates or ():
            xs.append(coord[0])
            ys.append(coord[1])
        self.x, self.y = xs, ys

    def __len__(self):
        """Return the number of points in the stroke."""
        return len(self.x)

    def __getitem__(self, index):
        """Return ``(x[index], y[index])``; a slice yields a tuple of two lists."""
        x_at = self.x[index]
        y_at = self.y[index]
        return (x_at, y_at)
| |
|
| |
|
class Ink:
    """An ordered collection of strokes with list-like length and indexing."""

    def __init__(self, list_of_strokes=None) -> None:
        # A truthy argument is stored as-is (the caller's list is shared, not
        # copied); None or an empty sequence becomes a fresh empty list.
        self.strokes = list_of_strokes if list_of_strokes else []

    def __len__(self):
        """Return the number of strokes."""
        return len(self.strokes)

    def __getitem__(self, index):
        """Return the stroke (or slice of strokes) at ``index``."""
        return self.strokes[index]
| |
|
| |
|
def inkml_to_ink(inkml_file):
    """Convert an InkML file to an ``Ink``.

    Every ``<trace>`` that is a direct child of the root element, in the
    InkML namespace ``http://www.w3.org/2003/InkML``, becomes one ``Stroke``.
    Trace text is parsed as whitespace-separated ``x,y`` pairs, e.g.
    ``"10.5,20 11,21.25"``.

    Args:
        inkml_file: Path or file object accepted by
            ``xml.etree.ElementTree.parse``.

    Returns:
        ``Ink`` with one ``Stroke`` per trace, in document order. A trace with
        no text (e.g. ``<trace/>``) yields an empty ``Stroke`` instead of
        raising ``AttributeError``, so stroke indices stay aligned with traces.

    Raises:
        ValueError: If a point is not exactly two comma-separated numbers.
        xml.etree.ElementTree.ParseError: If the file is not well-formed XML.
    """
    root = ET.parse(inkml_file).getroot()
    inkml_namespace = {"inkml": "http://www.w3.org/2003/InkML"}

    strokes = []
    for trace in root.findall("inkml:trace", inkml_namespace):
        stroke_points = []
        # .text is None for an empty element; split() with no argument also
        # discards leading/trailing whitespace.
        for point in (trace.text or "").split():
            x, y = point.split(",")
            stroke_points.append((float(x), float(y)))
        strokes.append(Stroke(stroke_points))
    return Ink(strokes)
| |
|
| |
|
def parse_inkml_annotations(inkml_file):
    """Collect the InkML ``<annotation>`` elements of a file into a dict.

    Every ``annotation`` element (InkML namespace) anywhere below the root is
    included. Keys are the ``type`` attribute (``None`` when absent) and
    values are the element's text (``None`` when empty). If several
    annotations share a type, the last one in document order wins.

    Args:
        inkml_file: Path or file object accepted by
            ``xml.etree.ElementTree.parse``.

    Returns:
        ``dict`` mapping annotation type to annotation text.
    """
    document_root = ET.parse(inkml_file).getroot()
    return {
        node.get("type"): node.text
        for node in document_root.findall(".//{http://www.w3.org/2003/InkML}annotation")
    }
| |
|
| |
|
def pregenerate_videos(video_cache_dir):
    """Render every not-yet-cached drawing video for the bundled samples.

    The function iterates over:

    - datasets ``IAM``, ``IMGUR5K``, ``HierText``;
    - models ``Small-i``, ``Large-i``, ``Small-p``;
    - query modes ``d+t``, ``r+d``, ``vanilla``.

    Each image in ``./derendering_supp/<Dataset>/images_sample`` is paired
    with ``./derendering_supp/<model>_<Dataset>_inkml/<mode>/<id>.inkml``
    (``<model>`` lower-cased, ``<id>`` the image name without extension). The
    pair is rendered to ``<video_cache_dir>/<Model>_<Dataset>_<mode>_<id>.mp4``.

    Missing sample directories, missing InkML files and existing videos are
    skipped. The ``./derendering_supp`` paths are resolved against the
    current working directory.

    Args:
        video_cache_dir: Output directory, as ``str`` or ``pathlib.Path``.
            Created if it does not exist.
    """
    from pathlib import Path  # local import: pathlib is not among the file-level imports

    video_cache_dir = Path(video_cache_dir)
    video_cache_dir.mkdir(parents=True, exist_ok=True)

    datasets = ["IAM", "IMGUR5K", "HierText"]
    models = ["Small-i", "Large-i", "Small-p"]
    query_modes = ["d+t", "r+d", "vanilla"]
    for dataset in datasets:
        # The sample directory depends only on the dataset, so check/list it once.
        path = f"./derendering_supp/{dataset}/images_sample"
        if not os.path.exists(path):
            continue
        # os.listdir order is arbitrary; sorting makes runs reproducible.
        samples = sorted(os.listdir(path))
        for model in models:
            inkml_path_base = f"./derendering_supp/{model.lower()}_{dataset}_inkml"
            for mode in query_modes:
                for name in tqdm(samples, desc=f"Generating {model}-{dataset}-{mode} videos"):
                    # splitext, not name.strip(".png"): str.strip removes any of the
                    # characters '.', 'p', 'n', 'g' from both ends ("img.png" -> "im").
                    example_id = os.path.splitext(name)[0]
                    inkml_file = os.path.join(inkml_path_base, mode, f"{example_id}.inkml")
                    if not os.path.exists(inkml_file):
                        continue
                    video_filepath = video_cache_dir / f"{model}_{dataset}_{mode}_{example_id}.mp4"
                    if video_filepath.exists():
                        continue
                    img = load_and_pad_img_dir(os.path.join(path, name))
                    ink = inkml_to_ink(inkml_file)
                    try:
                        plot_ink_to_video(ink, str(video_filepath), input_image=img)
                    except BaseException:
                        # A half-written video would pass the exists() check above
                        # on the next run and be served as if it were complete.
                        if video_filepath.exists():
                            video_filepath.unlink()
                        raise
| |
|