import shutil
from pathlib import Path

import imageio
import numpy as np
import torch
from PIL import Image
from skimage.transform import resize
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from tqdm.auto import tqdm

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.painter.clipascene import Painter, PainterOptimizer, Loss
from pytorch_svgrender.painter.clipascene.lama_utils import apply_inpaint
from pytorch_svgrender.painter.clipascene.scripts_utils import read_svg
from pytorch_svgrender.painter.clipascene.sketch_utils import plot_attn, get_mask_u2net, fix_image_scale
from pytorch_svgrender.plt import plot_img, plot_couple
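
# The CLIPascene pipeline splits a scene photo into a foreground object and an
# inpainted background, sketches each with CLIP-guided strokes, and finally
# composites the two SVG sketches into a single drawing.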

class CLIPascenePipeline(ModelState):

    def __init__(self, args):
        # tag the log directory with the seed, image size, and stroke settings
        logdir_ = f"sd{args.seed}" \
                  f"-im{args.x.image_size}" \
                  f"-P{args.x.num_paths}W{args.x.width}"
        super().__init__(args, log_path_suffix=logdir_)

    def painterly_rendering(self, image_path):
        # end-to-end run: preprocess, sketch background, sketch foreground, composite
        foreground_target, background_target = self.preprocess_image(image_path)
        background_output_dir = self.run_background(background_target)
        foreground_output_dir = self.run_foreground(foreground_target)
        self.combine(background_output_dir, foreground_output_dir, self.device)
        self.close(msg="painterly rendering complete.")

    def preprocess_image(self, image_path):
        image_path = Path(image_path)
        scene_path = self.result_path / "scene"
        background_path = self.result_path / "background"
        if self.accelerator.is_main_process:
            scene_path.mkdir(parents=True, exist_ok=True)
            background_path.mkdir(parents=True, exist_ok=True)

        # downscale large inputs to 512x512 (aspect ratio is not preserved)
        im = Image.open(image_path)
        max_size = max(im.size[0], im.size[1])
        scaled_path = scene_path / f"{image_path.stem}.png"
        if max_size > 512:
            im = Image.open(image_path).convert("RGB").resize((512, 512))
            im.save(scaled_path)
        else:
            shutil.copyfile(image_path, scaled_path)

        # extract the salient-object mask with U^2-Net
        scaled_img = Image.open(scaled_path)
        mask = get_mask_u2net(scaled_img, scene_path, self.args.x.u2net_path, preprocess=True, device=self.device)
        masked_path = scene_path / f"{image_path.stem}_mask.png"
        imageio.imsave(masked_path, mask)

        # remove the object and inpaint the hole to obtain a clean background
        apply_inpaint(scene_path, background_path, self.device)
        return scaled_path, background_path / f"{image_path.stem}_mask.png"

    def run_background(self, target_file):
        print("=====Start background=====")
        self.args.x.resize_obj = 0
        self.args.x.mask_object = 0

        # activate only the chosen CLIP conv layer for the background loss
        clip_conv_layer_weights_int = [0 for _ in range(12)]
        clip_conv_layer_weights_int[self.args.x.background_layer] = 1
        clip_conv_layer_weights_str = [str(j) for j in clip_conv_layer_weights_int]
        self.args.x.clip_conv_layer_weights = ','.join(clip_conv_layer_weights_str)

        output_dir = self.result_path / "background"
        if self.accelerator.is_main_process:
            output_dir.mkdir(parents=True, exist_ok=True)
        self.paint(target_file, output_dir, self.args.x.background_num_iter)
        print("=====End background=====")
        return output_dir

    def run_foreground(self, target_file):
        print("=====Start foreground=====")
        self.args.x.resize_obj = 1
        if self.args.x.foreground_layer != 4:
            self.args.x.gradnorm = 1
        self.args.x.mask_object = 1

        # weight CLIP conv layer 4 at 0.5 alongside the chosen foreground layer
        clip_conv_layer_weights_int = [0 for _ in range(12)]
        clip_conv_layer_weights_int[4] = 0.5
        clip_conv_layer_weights_int[self.args.x.foreground_layer] = 1
        clip_conv_layer_weights_str = [str(j) for j in clip_conv_layer_weights_int]
        self.args.x.clip_conv_layer_weights = ','.join(clip_conv_layer_weights_str)

        output_dir = self.result_path / "object"
        if self.accelerator.is_main_process:
            output_dir.mkdir(parents=True, exist_ok=True)
        self.paint(target_file, output_dir, self.args.x.foreground_num_iter)
        print("=====End foreground=====")
        return output_dir
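
    # `paint` runs the core sketch optimization: strokes are rasterized by the
    # differentiable renderer, scored against the target through the CLIP-based
    # loss, and updated by backpropagation; the sketch with the best eval loss
    # is checkpointed as "best_iter".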
    def paint(self, target, output_dir, num_iter):
        png_log_dir = output_dir / "png_logs"
        svg_log_dir = output_dir / "svg_logs"
        if self.accelerator.is_main_process:
            png_log_dir.mkdir(parents=True, exist_ok=True)
            svg_log_dir.mkdir(parents=True, exist_ok=True)

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = output_dir / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        # preprocess input image
        inputs, mask = self.get_target(target,
                                       self.args.x.image_size,
                                       output_dir,
                                       self.args.x.resize_obj,
                                       self.args.x.u2net_path,
                                       self.args.x.mask_object,
                                       self.args.x.fix_scale,
                                       self.device)
        plot_img(inputs, output_dir, fname="target")

        loss_func = Loss(self.x_cfg, mask, self.device)

        # init renderer and optimizer
        renderer = self.load_renderer(inputs, mask)
        optimizer = PainterOptimizer(self.x_cfg, renderer)

        best_loss, best_fc_loss, best_num_strokes = 100, 100, self.args.x.num_paths
        best_iter, best_iter_fc = 0, 0
        min_delta = 1e-7

        renderer.set_random_noise(0)
        renderer.init_image(stage=0)
        renderer.save_svg(svg_log_dir, "init_svg")
        optimizer.init_optimizers()

        if self.args.x.switch_loss:
            # start with width optimization, then switch every `switch_loss` iterations
            renderer.turn_off_points_optim()
            optimizer.turn_off_points_optim()

        with torch.no_grad():
            renderer.get_image("init").to(self.device)
            renderer.save_svg(self.result_path, "init")

        total_step = num_iter
        step = 0
        with tqdm(initial=step, total=total_step, disable=not self.accelerator.is_main_process) as pbar:
            while step < total_step:
                optimizer.zero_grad_()
                sketches = renderer.get_image().to(self.device)

                if self.make_video and (step % self.args.framefreq == 0 or step == total_step - 1):
                    plot_img(sketches, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                losses_dict_weighted, _, _ = loss_func(sketches, inputs.detach(), step,
                                                       renderer.get_widths(), renderer,
                                                       optimizer, mode="train",
                                                       width_opt=renderer.width_optim)
                loss = sum(list(losses_dict_weighted.values()))
                loss.backward()
                optimizer.step_()

                if step % self.args.x.save_step == 0:
                    plot_couple(inputs,
                                sketches,
                                step,
                                output_dir=png_log_dir.as_posix(),
                                fname=f"iter{step}")
                    renderer.save_svg(svg_log_dir.as_posix(), f"svg_iter{step}")

                if step % self.args.x.eval_step == 0:
                    with torch.no_grad():
                        losses_dict_weighted_eval, _, _ = loss_func(
                            sketches,
                            inputs,
                            step,
                            renderer.get_widths(),
                            renderer=renderer,
                            mode="eval",
                            width_opt=renderer.width_optim)
                        loss_eval = sum(list(losses_dict_weighted_eval.values()))

                        # keep a snapshot of the sketch whenever the eval loss improves
                        cur_delta = loss_eval.item() - best_loss
                        if abs(cur_delta) > min_delta:
                            if cur_delta < 0:
                                best_loss = loss_eval.item()
                                best_iter = step
                                plot_couple(inputs,
                                            sketches,
                                            best_iter,
                                            output_dir=output_dir.as_posix(),
                                            fname="best_iter")
                                renderer.save_svg(output_dir.as_posix(), "best_iter")

                if step == 0 and self.x_cfg.attention_init and self.accelerator.is_main_process:
                    plot_attn(renderer.get_attn(),
                              renderer.get_thresh(),
                              inputs,
                              renderer.get_inds(),
                              (output_dir / "attention_map.png").as_posix(),
                              self.x_cfg.saliency_model)

                if self.args.x.switch_loss:
                    if step > 0 and step % self.args.x.switch_loss == 0:
                        renderer.switch_opt()
                        optimizer.switch_opt()

                step += 1
                pbar.update(1)

        if self.make_video:
            from subprocess import call
            # stitch the logged frames into an mp4 with the ffmpeg CLI
            call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (output_dir / "clipascene_sketch.mp4").as_posix()
            ])

    def load_renderer(self, target_im=None, mask=None):
        renderer = Painter(method_cfg=self.x_cfg,
                           diffvg_cfg=self.args.diffvg,
                           num_strokes=self.x_cfg.num_paths,
                           canvas_size=self.x_cfg.image_size,
                           device=self.device,
                           target_im=target_im,
                           mask=mask)
        return renderer

    def get_target(self,
                   target_file,
                   image_size,
                   output_dir,
                   resize_obj,
                   u2net_path,
                   mask_object,
                   fix_scale,
                   device):
        target = Image.open(target_file)
        if target.mode == "RGBA":
            # composite RGBA images onto a white background
            new_image = Image.new("RGBA", target.size, "WHITE")
            new_image.paste(target, (0, 0), target)
            target = new_image
        target = target.convert("RGB")

        # U^2-Net mask
        masked_im, mask = get_mask_u2net(target, output_dir, u2net_path, resize_obj=resize_obj, device=device)
        if mask_object:
            target = masked_im
        if fix_scale:
            target = fix_image_scale(target)

        # non-square images are squashed to image_size x image_size;
        # square ones are resized and center-cropped
        transforms_ = []
        if target.size[0] != target.size[1]:
            transforms_.append(
                transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC)
            )
        else:
            transforms_.append(transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC))
            transforms_.append(transforms.CenterCrop(image_size))
        transforms_.append(transforms.ToTensor())
        data_transforms = transforms.Compose(transforms_)

        target_ = data_transforms(target).unsqueeze(0).to(self.device)
        return target_, mask

    def combine(self, background_output_dir, foreground_output_dir, device, output_size=448):
        # reload the resize parameters saved while sketching the foreground
        params_path = foreground_output_dir / "resize_params.npy"
        params = None
        if params_path.exists():
            params = np.load(params_path, allow_pickle=True)[()]

        mask_path = foreground_output_dir / "mask.png"
        mask = imageio.imread(mask_path)
        mask = resize(mask, (output_size, output_size), anti_aliasing=False)

        # rasterize both best-iteration SVGs at the output resolution
        object_svg_path = foreground_output_dir / "best_iter.svg"
        raster_o = read_svg(object_svg_path, resize_obj=1, params=params, multiply=1.8, device=device)
        background_svg_path = background_output_dir / "best_iter.svg"
        raster_b = read_svg(background_svg_path, resize_obj=0, params=params, multiply=1.8, device=device)

        # white out the object region in the background sketch, then copy the
        # object strokes (non-white pixels) on top
        raster_b[mask == 1] = 1
        raster_b[raster_o != 1] = raster_o[raster_o != 1]

        raster_b = torch.from_numpy(raster_b).unsqueeze(0).permute(0, 3, 1, 2).to(device)
        plot_img(raster_b, self.result_path, fname="combined")
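
# Usage sketch (not part of the original file): `build_args()` below is a
# hypothetical stand-in for however the project assembles its config object;
# this class only assumes the fields it reads, e.g. args.seed, args.mv,
# args.framefreq, args.framerate, args.diffvg, and the method options under
# args.x (image_size, num_paths, width, u2net_path, background_layer,
# foreground_layer, background_num_iter, foreground_num_iter, save_step,
# eval_step, switch_loss, fix_scale, ...).
#
#   args = build_args()  # hypothetical config loader
#   pipeline = CLIPascenePipeline(args)
#   pipeline.painterly_rendering("./data/ballerina.png")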