Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # Author: ximing | |
| # Description: DiffVG pipeline | |
| # Copyright (c) 2023, XiMing Xing. | |
| # License: MIT License | |
| import shutil | |
| from pathlib import Path | |
| from functools import partial | |
| from typing import AnyStr | |
| from PIL import Image | |
| from tqdm.auto import tqdm | |
| import torch | |
| from torchvision import transforms | |
| from pytorch_svgrender.libs.engine import ModelState | |
| from pytorch_svgrender.painter.diffvg import Painter, PainterOptimizer | |
| from pytorch_svgrender.plt import plot_img, plot_couple | |
| from pytorch_svgrender.libs.metric.lpips_origin import LPIPS | |
class DiffVGPipeline(ModelState):
    """Pipeline that optimizes DiffVG vector paths to reconstruct a target image.

    The pipeline loads a raster target, initializes a ``Painter`` with
    ``x_cfg.num_paths`` paths, and minimizes a reconstruction loss
    (``x_cfg.loss_type``: ``'l1'``, ``'l2'``, ``'lpips'`` or ``'l2+lpips'``)
    between the rendered image and the target. Intermediate PNG/SVG logs and
    the final SVG are written below ``self.result_path``.
    """

    def __init__(self, args):
        """Set up the run directory layout and optional video-frame logging.

        Args:
            args: run configuration. This method reads ``args.seed``,
                ``args.x.path_type``, ``args.x.num_paths`` and ``args.mv``.
        """
        # the log directory suffix encodes seed, path type and number of paths
        logdir_ = f"sd{args.seed}" \
                  f"-{args.x.path_type}" \
                  f"-P{args.x.num_paths}"
        super().__init__(args, log_path_suffix=logdir_)

        assert self.x_cfg.path_type in ['unclosed', 'closed']

        # create log dir
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

    def target_file_preprocess(self, tar_path) -> torch.Tensor:
        """Load the target image as an RGB tensor on ``self.device``.

        Args:
            tar_path: path of the image file to open.

        Returns:
            The image converted with ``transforms.ToTensor`` with a leading
            batch dimension added via ``unsqueeze(0)``.
        """
        process_comp = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t.unsqueeze(0)),
        ])

        # Use a context manager so the underlying file handle is released;
        # ``convert`` returns a new image that stays usable afterwards.
        with Image.open(tar_path) as img_file:
            tar_pil = img_file.convert("RGB")

        target_img = process_comp(tar_pil)  # preprocess
        return target_img.to(self.device)

    def _build_recon_loss_fn(self):
        """Return a callable ``(raster_img, target_img) -> scalar loss``.

        The callable is selected once from ``self.x_cfg.loss_type`` so the
        training loop does not re-dispatch on every iteration. The LPIPS
        network is only constructed for the ``'lpips'`` and ``'l2+lpips'``
        loss types.

        Raises:
            ValueError: if ``self.x_cfg.loss_type`` is not one of
                ``'l1'``, ``'l2'``, ``'lpips'``, ``'l2+lpips'``.
        """
        loss_type = self.x_cfg.loss_type

        if loss_type == 'l1':
            return torch.nn.functional.l1_loss
        if loss_type == 'l2':  # MSE loss
            return torch.nn.functional.mse_loss

        if loss_type in ('lpips', 'l2+lpips'):
            lpips_loss_fn = LPIPS(net=self.x_cfg.perceptual.lpips_net).to(self.device)
            perceptual_loss_fn = partial(lpips_loss_fn.forward, return_per_layer=False, normalize=False)

            def lpips_only(raster_img, target_img):
                return perceptual_loss_fn(raster_img, target_img).mean()

            def l2_plus_lpips(raster_img, target_img):
                # sum of the perceptual term and the pixel-wise MSE term
                lpips = perceptual_loss_fn(raster_img, target_img).mean()
                loss_mse = torch.nn.functional.mse_loss(raster_img, target_img)
                return loss_mse + lpips

            return lpips_only if loss_type == 'lpips' else l2_plus_lpips

        # Previously an unknown loss type only failed inside the loop with an
        # UnboundLocalError on ``loss_recon``; fail with a clear message instead.
        raise ValueError(
            f"unsupported loss_type: '{loss_type}', "
            f"expected one of 'l1', 'l2', 'lpips', 'l2+lpips'"
        )

    def _export_video(self):
        """Assemble the logged frames into ``live_rendering.mp4`` using ffmpeg.

        Video export is auxiliary (the final SVG is already saved when this
        runs), so a missing ffmpeg binary or a non-zero exit code is reported
        through ``self.print`` instead of aborting the run.
        """
        from subprocess import call

        cmd = [
            "ffmpeg",
            "-framerate", f"{self.args.framerate}",
            "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
            "-vb", "20M",
            (self.result_path / "live_rendering.mp4").as_posix()
        ]
        try:
            return_code = call(cmd)
        except FileNotFoundError:
            self.print("ffmpeg executable not found, skip video export.")
            return
        if return_code != 0:
            self.print(f"ffmpeg exited with code {return_code}, video export may have failed.")

    def painterly_rendering(self, img_path: AnyStr) -> None:
        """Optimize vector paths so that their rendering matches ``img_path``.

        The target file is copied into ``self.result_path``, the painter and
        its optimizer are initialized, and the reconstruction loss is
        minimized for ``x_cfg.num_iter`` steps. Every ``args.save_step`` steps
        the main process writes a comparison PNG and an SVG snapshot; the
        final SVG is saved as ``final_svg.svg``.

        Args:
            img_path: path of the target image file.

        Raises:
            AssertionError: if ``img_path`` does not exist.
            ValueError: if ``x_cfg.loss_type`` is not supported.
        """
        # load target file
        target_file = Path(img_path)
        assert target_file.exists(), f"{target_file} is not exist!"
        shutil.copy(target_file, self.result_path)  # copy target file
        target_img = self.target_file_preprocess(target_file.as_posix())
        self.print(f"load image from: '{target_file.as_posix()}'")

        # init Painter
        # canvas_size is [shape[3], shape[2]], i.e. (width, height) for the
        # (1, C, H, W) tensor produced by target_file_preprocess
        renderer = Painter(target_img,
                           self.args.diffvg,
                           canvas_size=[target_img.shape[3], target_img.shape[2]],
                           path_type=self.x_cfg.path_type,
                           max_width=self.x_cfg.max_width,
                           device=self.device)
        init_img = renderer.init_image(num_paths=self.x_cfg.num_paths)
        self.print("init_image shape: ", init_img.shape)
        plot_img(init_img, self.result_path, fname="init_img")

        # init Painter Optimizer
        num_iter = self.x_cfg.num_iter
        optimizer = PainterOptimizer(renderer,
                                     num_iter,
                                     self.x_cfg.lr_base,
                                     trainable_stroke=self.x_cfg.path_type == 'unclosed')
        optimizer.init_optimizer()

        # Set Loss
        # built here (after painter/optimizer init) to keep the original
        # construction order
        recon_loss_fn = self._build_recon_loss_fn()

        with tqdm(initial=self.step, total=num_iter, disable=not self.accelerator.is_main_process) as pbar:
            while self.step < num_iter:
                raster_img = renderer.get_image(self.step).to(self.device)

                # log a video frame every `framefreq` steps and on the last step
                if self.make_video and (self.step % self.args.framefreq == 0 or self.step == num_iter - 1):
                    plot_img(raster_img, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                # Reconstruction Loss
                loss_recon = recon_loss_fn(raster_img, target_img)

                # total loss
                loss = loss_recon

                pbar.set_description(
                    f"lr: {optimizer.get_lr():.4f}, "
                    f"L_recon: {loss_recon.item():.4f}"
                )

                # optimization
                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                renderer.clip_curve_shape()

                if self.x_cfg.lr_schedule:
                    optimizer.update_lr()

                # periodic logging; `raster_img` is the image rendered at the
                # start of this iteration, i.e. before the parameter update
                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    plot_couple(target_img,
                                raster_img,
                                self.step,
                                output_dir=self.png_logs_dir.as_posix(),
                                fname=f"iter{self.step}")
                    renderer.save_svg(self.svg_logs_dir / f"svg_iter{self.step}.svg")

                self.step += 1
                pbar.update(1)

        # end rendering
        renderer.save_svg(self.result_path / "final_svg.svg")

        if self.make_video:
            self._export_video()

        self.close(msg="painterly rendering complete.")