Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # Copyright (c) XiMing Xing. All rights reserved. | |
| # Author: XiMing Xing | |
| # Description: | |
| import shutil | |
| from PIL import Image | |
| from pathlib import Path | |
| import torch | |
| from torchvision import transforms | |
| import clip | |
| from tqdm.auto import tqdm | |
| import numpy as np | |
| from pytorch_svgrender.libs.engine import ModelState | |
| from pytorch_svgrender.painter.style_clipdraw import ( | |
| Painter, PainterOptimizer, VGG16Extractor, StyleLoss, sample_indices | |
| ) | |
| from pytorch_svgrender.plt import plot_img, plot_couple | |
class StyleCLIPDrawPipeline(ModelState):
    """StyleCLIPDraw painterly-rendering pipeline.

    Optimizes a set of vector strokes (rendered by a differentiable
    rasterizer via ``Painter``) so that the rasterized drawing (a) matches a
    text prompt under CLIP and (b) matches the style of a reference image
    through a VGG16 hypercolumn feature loss in the spirit of STROTSS
    [Kolkin et al., 2019].
    """

    def __init__(self, args):
        # Fold the key hyper-parameters into the log-directory suffix so runs
        # with different settings land in distinct result folders.
        logdir_ = f"sd{args.seed}" \
                  f"-P{args.x.num_paths}" \
                  f"-style{args.x.style_strength}" \
                  f"-n{args.x.num_aug}"
        super().__init__(args, log_path_suffix=logdir_)

        # Create log dirs for intermediate PNG renders and SVG snapshots.
        # Only the main process creates directories (multi-process safe
        # under HuggingFace Accelerate).
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)

        # Optional per-frame dumps used later to assemble an mp4 with ffmpeg.
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        # CLIP model + tokenizer used for the text-image similarity loss.
        self.clip, self.tokenize_fn = self.init_clip()

        # VGG16 feature extractor and style loss for the style-transfer term.
        self.style_extractor = VGG16Extractor(space="normal").to(self.device)
        self.style_loss = StyleLoss()

    def init_clip(self):
        """Load CLIP ViT-B/32 on ``self.device``.

        Returns:
            tuple: ``(model, tokenize_fn)`` — the CLIP model and the
            ``clip.tokenize`` function used to encode text prompts.
        """
        # jit=False: load the non-TorchScript model so gradients can flow
        # through encode_image during optimization.
        model, _ = clip.load('ViT-B/32', self.device, jit=False)
        return model, clip.tokenize

    def drawing_augment(self, image):
        """Augment the rendered drawing and encode each view with CLIP.

        Args:
            image: rendered drawing tensor (NCHW; presumably a single image
                in [0, 1] — fill=1 below suggests a white padding color).
                TODO(review): confirm expected range/batch at the caller.

        Returns:
            CLIP image embeddings, one row per augmented view
            (``num_aug`` views).
        """
        augment_trans = transforms.Compose([
            # Random perspective + crop make the CLIP loss robust to
            # adversarial, non-drawing-like solutions (as in CLIPDraw).
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
            # CLIP's published normalization statistics.
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])

        # Build a batch of independently augmented copies of the drawing.
        img_augs = []
        for n in range(self.x_cfg.num_aug):
            img_augs.append(augment_trans(image))
        im_batch = torch.cat(img_augs)

        # CLIP visual encoding (gradients flow back to the strokes).
        image_features = self.clip.encode_image(im_batch)
        return image_features

    def style_file_preprocess(self, style_file):
        """Convert a PIL style image to a batched tensor on ``self.device``.

        Resizes to 224x224 and adds a batch dimension.
        """
        process_comp = transforms.Compose([
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t.unsqueeze(0)),  # HWC PIL -> 1CHW batch
            # NOTE(review): ToTensor already yields values in [0, 1], so
            # (t + 1) / 2 compresses the range to [0.5, 1] (it would map
            # [-1, 1] -> [0, 1]). Looks suspicious — verify intended.
            transforms.Lambda(lambda t: (t + 1) / 2),
        ])
        style_file = process_comp(style_file)
        style_file = style_file.to(self.device)
        return style_file

    def painterly_rendering(self, prompt, style_fpath):
        """Run the full optimization loop.

        Args:
            prompt: text prompt guiding the drawing via CLIP.
            style_fpath: path to the style reference image.

        Side effects: writes PNG/SVG logs, the final render/SVG, optionally
        an mp4, into ``self.result_path``; closes the run at the end.
        """
        # Load the style file and keep a copy next to the results.
        style_path = Path(style_fpath)
        assert style_path.exists(), f"{style_fpath} is not exist!"
        self.print(f"load style file from: {style_path.as_posix()}")
        style_pil = Image.open(style_path.as_posix()).convert("RGB")
        style_img = self.style_file_preprocess(style_pil)
        shutil.copy(style_fpath, self.result_path)  # copy style file

        # Extract style features once (no grad needed): 5 rounds of 1000
        # random hypercolumn samples, concatenated along the sample dim,
        # following the STROTSS sampling scheme.
        feat_style = None
        for i in range(5):
            with torch.no_grad():
                # sample hypercolumn features from the style image
                feat_e = self.style_extractor.forward_samples_hypercolumn(style_img, samps=1000)
                feat_style = feat_e if feat_style is None else torch.cat((feat_style, feat_e), dim=2)

        # Text prompt encoding (fixed target; no grad).
        self.print(f"prompt: {prompt}")
        text_tokenize = self.tokenize_fn(prompt).to(self.device)
        with torch.no_grad():
            text_features = self.clip.encode_text(text_tokenize)

        # Differentiable vector renderer holding the stroke parameters.
        renderer = Painter(self.x_cfg,
                           self.args.diffvg,
                           num_strokes=self.x_cfg.num_paths,
                           canvas_size=self.x_cfg.image_size,
                           device=self.device)
        img = renderer.init_image(stage=0)
        self.print("init_image shape: ", img.shape)
        plot_img(img, self.result_path, fname="init_img")

        # Separate learning rates for point positions, widths, and colors.
        optimizer = PainterOptimizer(renderer, self.x_cfg.lr, self.x_cfg.width_lr, self.x_cfg.color_lr)
        optimizer.init_optimizers()

        # Map style_strength in [0, 100] to a loss weight in [0, 4].
        style_weight = 4 * (self.x_cfg.style_strength / 100)
        self.print(f'style_weight: {style_weight}')

        total_step = self.x_cfg.num_iter
        with tqdm(initial=self.step, total=total_step, disable=not self.accelerator.is_main_process) as pbar:
            while self.step < total_step:
                rendering = renderer.get_image(self.step).to(self.device)

                # Dump frames for the optional video log.
                if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_step - 1):
                    plot_img(rendering, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                rendering_aug = self.drawing_augment(rendering)

                loss = torch.tensor(0., device=self.device)

                # CLIP term: maximize cosine similarity between the text
                # embedding and each augmented view. Disabled for the last
                # 10% of steps so only the style term refines the result.
                if self.step < 0.9 * total_step:
                    for n in range(self.x_cfg.num_aug):
                        loss -= torch.cosine_similarity(text_features, rendering_aug[n:n + 1], dim=1).mean()

                # Style term: relaxed feature matching on VGG16 hypercolumns
                # (STROTSS-style), with freshly shuffled sample indices.
                feat_content = self.style_extractor(rendering)
                xx, xy = sample_indices(feat_content[0], feat_style)
                np.random.shuffle(xx)
                np.random.shuffle(xy)
                # NOTE(review): feat_content is passed twice (as content and
                # as the stylized features) — matches the original
                # StyleCLIPDraw reference implementation; confirm upstream.
                L_style = self.style_loss.forward(feat_content, feat_content, feat_style, [xx, xy], 0)
                loss += L_style * style_weight

                pbar.set_description(
                    f"lr: {optimizer.get_lr():.3f}, "
                    f"L_train: {loss.item():.4f}, "
                    f"L_style: {L_style.item():.4f}"
                )

                # Optimization step, then clamp stroke shapes to valid ranges.
                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                renderer.clip_curve_shape()

                if self.x_cfg.lr_schedule:
                    optimizer.update_lr(self.step)

                # Periodic PNG/SVG snapshots (main process only).
                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    plot_couple(style_img,
                                rendering,
                                self.step,
                                prompt=prompt,
                                output_dir=self.png_logs_dir.as_posix(),
                                fname=f"iter{self.step}")
                    renderer.save_svg(self.svg_logs_dir.as_posix(), f"svg_iter{self.step}")

                self.step += 1
                pbar.update(1)

        # Final outputs: side-by-side comparison plot and the final SVG.
        plot_couple(style_img,
                    rendering,
                    self.step,
                    prompt=prompt,
                    output_dir=self.result_path.as_posix(),
                    fname=f"final_iter")
        renderer.save_svg(self.result_path.as_posix(), "final_svg")

        # Stitch the logged frames into an mp4 with ffmpeg (must be on PATH).
        if self.make_video:
            from subprocess import call
            call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "styleclipdraw_rendering.mp4").as_posix()
            ])

        self.close(msg="painterly rendering complete.")