Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # Copyright (c) XiMing Xing. All rights reserved. | |
| # Author: XiMing Xing | |
| # Description: | |
| from pathlib import Path | |
| from tqdm.auto import tqdm | |
| import torch | |
| from pytorch_svgrender.libs.engine import ModelState | |
| from pytorch_svgrender.painter.wordasimage import Painter, PainterOptimizer | |
| from pytorch_svgrender.painter.wordasimage.losses import ToneLoss, ConformalLoss | |
| from pytorch_svgrender.painter.vectorfusion import LSDSPipeline | |
| from pytorch_svgrender.plt import plot_img, plot_couple | |
| from pytorch_svgrender.diffusers_warp import init_StableDiffusion_pipeline | |
| from pytorch_svgrender.svgtools import FONT_LIST | |
class WordAsImagePipeline(ModelState):
    """Pipeline that optimizes the vector shape of one letter of a word.

    The letter outline is rendered by ``Painter`` and updated by
    ``PainterOptimizer`` using the loss returned by the diffusion pipeline's
    ``score_distillation_sampling``; a tone loss and a conformal loss are
    added when enabled in the config.
    """

    def __init__(self, args):
        """Validate the config, create the log directories and load the diffusion model.

        Args:
            args: run configuration. This method reads ``args.seed``,
                ``args.mv``, ``args.diffuser.*`` and the method-specific
                sub-config ``args.x`` (word, optim_letter, font, font_path,
                image_size, ...).

        Raises:
            AssertionError: if the letter to optimize is not part of the word,
                the font file does not exist, or the font is not in
                ``FONT_LIST``.
        """
        # assert
        assert args.x.optim_letter in args.x.word
        assert Path(args.x.font_path).exists(), f"{args.x.font_path} is not exist."
        assert args.x.font in FONT_LIST, f"{args.x.font} is not currently supported."

        # make logdir
        logdir_ = f"sd{args.seed}" \
                  f"-im{args.x.image_size}" \
                  f"-{args.x.word}-{args.x.optim_letter}"
        super().__init__(args, log_path_suffix=logdir_)

        # log dir
        self.png_log_dir = self.result_path / "png_logs"
        self.svg_log_dir = self.result_path / "svg_logs"

        # font
        self.font = self.x_cfg.font
        self.font_path = self.x_cfg.font_path
        self.optim_letter = self.x_cfg.optim_letter

        # letter
        self.letter = self.x_cfg.optim_letter
        # SVG of the letter that `painterly_rendering` loads as the initial shape
        self.target_letter = self.result_path / f"{self.font}_{self.optim_letter}_scaled.svg"

        # make log dir
        if self.accelerator.is_main_process:
            self.png_log_dir.mkdir(parents=True, exist_ok=True)
            self.svg_log_dir.mkdir(parents=True, exist_ok=True)

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            # running index of saved frames; keeps frame file names consecutive
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        self.diffusion = init_StableDiffusion_pipeline(
            self.x_cfg.model_id,
            custom_pipeline=LSDSPipeline,
            device=self.device,
            local_files_only=not args.diffuser.download,
            force_download=args.diffuser.force_download,
            resume_download=args.diffuser.resume_download,
            ldm_speed_up=self.x_cfg.ldm_speed_up,
            enable_xformers=self.x_cfg.enable_xformers,
            gradient_checkpoint=self.x_cfg.gradient_checkpoint,
            lora_path=self.x_cfg.lora_path
        )

        self.g_device = torch.Generator(device=self.device).manual_seed(args.seed)

    def painterly_rendering(self, word, semantic_concept, optimized_letter):
        """Run the optimization loop and write the resulting SVGs (and optional video).

        Args:
            word: the word passed to the renderer's ``preprocess_font`` and
                ``combine_word``.
            semantic_concept: text that is joined with
                ``self.x_cfg.prompt_suffix`` to form the diffusion prompt.
            optimized_letter: the letter handed to ``preprocess_font``,
                ``ConformalLoss`` and ``combine_word``.

        Side effects:
            Writes "word_init", "letter_init.svg", periodic PNG/SVG snapshots,
            "final_letter.svg" and, when video logging is enabled, the frame
            images plus "wordasimg_rendering.mp4" under ``self.result_path``.
        """
        prompt = semantic_concept + ". " + self.x_cfg.prompt_suffix
        self.print(f"prompt: {prompt}")

        # load the optimized letter
        renderer = Painter(self.font, canvas_size=self.x_cfg.image_size, device=self.device)

        # font to svg
        self.print(f"font type: {self.font}\n")
        renderer.preprocess_font(word,
                                 optimized_letter,
                                 self.x_cfg.level_of_cc,
                                 self.font_path,
                                 self.result_path.as_posix())

        # init letter shape
        img_init = renderer.init_shape(self.target_letter)
        plot_img(img_init, self.result_path, fname="word_init")
        # save init letter
        renderer.pretty_save_svg(self.result_path / "letter_init.svg")
        init_letter = renderer.get_image()

        n_iter = self.x_cfg.num_iter

        # init optimizer and lr_schedular
        optimizer = PainterOptimizer(renderer, n_iter, self.x_cfg.lr)
        optimizer.init_optimizers()

        # init Tone loss
        if self.x_cfg.tone_loss.use:
            tone_loss = ToneLoss(self.x_cfg.tone_loss)
            tone_loss.set_image_init(img_init)

        # init conformal loss
        if self.x_cfg.conformal.use:
            conformal_loss = ConformalLoss(renderer.get_point_parameters(),
                                           renderer.shape_groups,
                                           optimized_letter, self.device)

        with tqdm(initial=self.step, total=n_iter, disable=not self.accelerator.is_main_process) as pbar:
            for i in range(n_iter):
                raster_img = renderer.get_image(step=i)

                if self.make_video and (i % self.args.framefreq == 0 or i == n_iter - 1):
                    # Frames are numbered with a dedicated consecutive counter:
                    # ffmpeg is later fed the pattern "iter%d.png", and a
                    # numbered image sequence must not contain gaps. Naming by
                    # step would leave gaps whenever framefreq > 1.
                    plot_img(raster_img, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                # the second return value (gradient) is not used here
                L_sds, _ = self.diffusion.score_distillation_sampling(
                    raster_img,
                    im_size=self.x_cfg.sds.im_size,
                    prompt=[prompt],
                    negative_prompt=self.args.neg_prompt,
                    guidance_scale=self.x_cfg.sds.guidance_scale,
                    grad_scale=self.x_cfg.sds.grad_scale,
                    t_range=list(self.x_cfg.sds.t_range),
                )
                loss = L_sds

                if self.x_cfg.tone_loss.use:
                    tone_loss_res = tone_loss(raster_img, step=i)
                    loss = loss + tone_loss_res

                if self.x_cfg.conformal.use:
                    loss_angles = conformal_loss()
                    loss_angles = self.x_cfg.conformal.angeles_w * loss_angles
                    loss = loss + loss_angles

                pbar.set_description(
                    f"n_params: {len(renderer.get_point_parameters())}, "
                    f"lr: {optimizer.get_lr():.4f}, "
                    f"L_total: {loss.item():.4f}, "
                )

                # optimization
                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                if self.x_cfg.lr_schedule:
                    optimizer.update_lr()

                # periodic snapshot: initial letter next to the current raster, plus the SVG
                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    plot_couple(init_letter,
                                raster_img,
                                self.step,
                                output_dir=self.png_log_dir.as_posix(),
                                fname=f"iter{self.step}",
                                prompt=prompt)
                    renderer.pretty_save_svg(self.svg_log_dir / f"svg_iter{self.step}.svg")

                self.step += 1
                pbar.update(1)

        # save final optimized letter
        renderer.pretty_save_svg(self.result_path / "final_letter.svg")

        # combine word
        renderer.combine_word(word, optimized_letter, self.font, self.result_path)

        if self.make_video:
            from subprocess import call
            ret = call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "wordasimg_rendering.mp4").as_posix()
            ])
            if ret != 0:
                # surface encoder failures instead of silently dropping them
                self.print(f"ffmpeg exited with status {ret}; the video may be missing or incomplete.")

        self.close(msg="painterly rendering complete.")