# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:

import math
import copy
import random
import pathlib
from typing import Dict

from shapely.geometry.polygon import Polygon
import omegaconf
import cv2
import numpy as np
import pydiffvg
import torch
from torch.optim.lr_scheduler import LambdaLR

from pytorch_svgrender.diffvg_warp import DiffVGState
from pytorch_svgrender.libs.solver.optim import get_optimizer

class Painter(DiffVGState):

    def __init__(
            self,
            diffvg_cfg: omegaconf.DictConfig,
            style: str,
            num_segments: int,
            segment_init: str,
            radius: int = 20,
            canvas_size: int = 600,
            n_grid: int = 32,
            trainable_bg: bool = False,
            stroke_width: int = 3,
            path_svg=None,
            device=None,
    ):
        super().__init__(device, print_timing=diffvg_cfg.print_timing,
                         canvas_width=canvas_size, canvas_height=canvas_size)

        self.style = style

        self.num_segments = num_segments
        self.segment_init = segment_init
        self.radius = radius

        """pixelart params"""
        self.n_grid = n_grid  # divide the canvas into n grids
        self.pixel_per_grid = self.canvas_width // self.n_grid
        """sketch params"""
        self.stroke_width = stroke_width
        """iconography params"""
        self.color_ref = None

        self.path_svg = path_svg
        self.optimize_flag = []

        self.strokes_counter = 0  # counts the number of calls to "get_path"

        # Background color
        self.para_bg = torch.tensor([1., 1., 1.], requires_grad=trainable_bg, device=self.device)

        self.target_img = None
        self.pos_init_method = None
    def component_wise_path_init(self, gt, pred, init_type: str = 'sparse'):
        # set the target image
        self.target_img = gt

        if init_type == 'random':
            self.pos_init_method = RandomCoordInit(self.canvas_height, self.canvas_width)
        elif init_type == 'sparse':
            # on the very first call the render result (pred) is None,
            # so compare against a blank, background-colored canvas
            if pred is None:
                pred = self.para_bg.view(1, -1, 1, 1).repeat(1, 1, self.canvas_height, self.canvas_width)
            # on later calls, pred is the current render result
            self.pos_init_method = SparseCoordInit(pred, gt)
        elif init_type == 'naive':
            if pred is None:
                pred = self.para_bg.view(1, -1, 1, 1).repeat(1, 1, self.canvas_height, self.canvas_width)
            self.pos_init_method = NaiveCoordInit(pred, gt)
        else:
            raise NotImplementedError(f"'{init_type}' is not supported.")
    def init_image(self, stage=0, num_paths=0):
        self.cur_shapes, self.cur_shape_groups = [], []

        # or init svg by pydiffvg
        if self.style in ['pixelart', 'low-poly']:  # update path definition
            num_paths = self.n_grid

        if stage > 0:
            # Note: in multi-stage training, new strokes are added on top of the existing ones,
            # and the previous strokes are not optimized
            self.optimize_flag = [False for i in range(len(self.shapes))]
            for i in range(num_paths):
                if self.style == 'iconography':
                    path = self.get_path()
                    self.shapes.append(path)
                    self.cur_shapes.append(path)

                    fill_color_init = torch.FloatTensor(np.random.uniform(size=[4]))
                    fill_color_init[-1] = 1.0
                    path_group = pydiffvg.ShapeGroup(
                        shape_ids=torch.tensor([self.strokes_counter - 1]),
                        fill_color=fill_color_init,
                        stroke_color=None
                    )
                    self.shape_groups.append(path_group)
                    self.cur_shape_groups.append(path_group)
                    self.optimize_flag.append(True)

                elif self.style in ['pixelart', 'low-poly']:
                    for j in range(num_paths):
                        path = self.get_path(coord=[i, j])
                        self.shapes.append(path)
                        self.cur_shapes.append(path)

                        fill_color_init = torch.FloatTensor(np.random.uniform(size=[4]))
                        fill_color_init[-1] = 1.0
                        path_group = pydiffvg.ShapeGroup(
                            shape_ids=torch.LongTensor([i * num_paths + j]),
                            fill_color=fill_color_init,
                            stroke_color=None,
                        )
                        self.shape_groups.append(path_group)
                        self.cur_shape_groups.append(path_group)
                        self.optimize_flag.append(True)

                elif self.style in ['ink', 'sketch']:
                    path = self.get_path()
                    self.shapes.append(path)
                    self.cur_shapes.append(path)

                    stroke_color_init = [0.0, 0.0, 0.0] + [random.random()]
                    stroke_color_init = torch.FloatTensor(stroke_color_init)
                    path_group = pydiffvg.ShapeGroup(
                        shape_ids=torch.tensor([len(self.shapes) - 1]),
                        fill_color=None,
                        stroke_color=stroke_color_init
                    )
                    self.shape_groups.append(path_group)
                    self.cur_shape_groups.append(path_group)

                elif self.style == 'painting':
                    path = self.get_path()
                    self.shapes.append(path)
                    self.cur_shapes.append(path)

                    wref, href = self.color_ref
                    wref = max(0, min(int(wref), self.canvas_width - 1))
                    href = max(0, min(int(href), self.canvas_height - 1))
                    stroke_color_init = list(self.target_img[0, :, href, wref]) + [1.]
                    path_group = pydiffvg.ShapeGroup(
                        shape_ids=torch.tensor([len(self.shapes) - 1]),
                        fill_color=None,
                        stroke_color=torch.FloatTensor(stroke_color_init)
                    )
                    self.shape_groups.append(path_group)
                    self.cur_shape_groups.append(path_group)
        else:
            num_paths_exists = 0
            if self.path_svg is not None and pathlib.Path(self.path_svg).exists():
                print(f"-> init svg from `{self.path_svg}` ...")

                self.canvas_width, self.canvas_height, self.shapes, self.shape_groups = self.load_svg(self.path_svg)
                # if you want to add more strokes to existing ones and optimize on all of them
                num_paths_exists = len(self.shapes)

                self.cur_shapes = self.shapes
                self.cur_shape_groups = self.shape_groups

            for i in range(num_paths_exists, num_paths):
                if self.style == 'iconography':
                    path = self.get_path()
                    self.shapes.append(path)
                    self.cur_shapes.append(path)

                    wref, href = self.color_ref
                    wref = max(0, min(int(wref), self.canvas_width - 1))
                    href = max(0, min(int(href), self.canvas_height - 1))
                    fill_color_init = list(self.target_img[0, :, href, wref]) + [1.]
                    path_group = pydiffvg.ShapeGroup(
                        shape_ids=torch.tensor([self.strokes_counter - 1]),
                        fill_color=torch.FloatTensor(fill_color_init),
                        stroke_color=None
                    )
                    self.shape_groups.append(path_group)
                    self.cur_shape_groups.append(path_group)

                elif self.style in ['pixelart', 'low-poly']:
                    for j in range(num_paths):
                        path = self.get_path(coord=[i, j])
                        self.shapes.append(path)
                        self.cur_shapes.append(path)

                        fill_color_init = torch.FloatTensor(np.random.uniform(size=[4]))
                        fill_color_init[-1] = 1.0
                        path_group = pydiffvg.ShapeGroup(
                            shape_ids=torch.LongTensor([i * num_paths + j]),
                            fill_color=fill_color_init,
                            stroke_color=None,
                        )
                        self.shape_groups.append(path_group)
                        self.cur_shape_groups.append(path_group)

                elif self.style in ['sketch', 'ink']:
                    path = self.get_path()
                    self.shapes.append(path)
                    self.cur_shapes.append(path)

                    stroke_color_init = [0.0, 0.0, 0.0] + [random.random()]
                    stroke_color_init = torch.FloatTensor(stroke_color_init)

                    path_group = pydiffvg.ShapeGroup(
                        shape_ids=torch.tensor([len(self.shapes) - 1]),
                        fill_color=None,
                        stroke_color=stroke_color_init
                    )
                    self.shape_groups.append(path_group)
                    self.cur_shape_groups.append(path_group)

                elif self.style in ['painting']:
                    path = self.get_path()
                    self.shapes.append(path)
                    self.cur_shapes.append(path)

                    if self.color_ref is None:
                        stroke_color_val = np.random.uniform(size=[4])
                        stroke_color_val[-1] = 1.0
                        stroke_color_init = torch.FloatTensor(stroke_color_val)
                    else:
                        wref, href = self.color_ref
                        wref = max(0, min(int(wref), self.canvas_width - 1))
                        href = max(0, min(int(href), self.canvas_height - 1))
                        stroke_color_init = list(self.target_img[0, :, href, wref]) + [1.]
                        stroke_color_init = torch.FloatTensor(stroke_color_init)

                    path_group = pydiffvg.ShapeGroup(
                        shape_ids=torch.tensor([len(self.shapes) - 1]),
                        fill_color=None,
                        stroke_color=stroke_color_init
                    )
                    self.shape_groups.append(path_group)
                    self.cur_shape_groups.append(path_group)

            self.optimize_flag = [True for i in range(len(self.shapes))]

        img = self.get_image()
        return img
    def get_image(self, step: int = 0):
        img = self.render_warp(step)
        img = img[:, :, 3:4] * img[:, :, :3] + self.para_bg * (1 - img[:, :, 3:4])
        img = img.unsqueeze(0)  # add batch dim: HWC -> NHWC
        img = img.permute(0, 3, 1, 2).to(self.device)  # NHWC -> NCHW
        return img
    def get_path(self, coord=None):
        num_segments = self.num_segments

        points = []
        if self.style == 'iconography':
            # init segment
            if self.segment_init == 'circle':
                num_control_points = [2] * num_segments
                radius = self.radius if self.radius is not None else np.random.uniform(0.5, 1)
                if self.pos_init_method is not None:
                    center = self.pos_init_method()
                else:
                    center = (random.random(), random.random())
                bias = center
                self.color_ref = copy.deepcopy(bias)

                avg_degree = 360 / (num_segments * 3)
                for i in range(0, num_segments * 3):
                    point = (
                        np.cos(np.deg2rad(i * avg_degree)), np.sin(np.deg2rad(i * avg_degree))
                    )
                    points.append(point)

                points = torch.FloatTensor(points) * radius + torch.FloatTensor(bias).unsqueeze(dim=0)
            elif self.segment_init == 'random':
                num_control_points = [2] * num_segments
                p0 = self.pos_init_method()
                self.color_ref = copy.deepcopy(p0)
                points.append(p0)

                for j in range(num_segments):
                    radius = self.radius
                    p1 = (p0[0] + radius * np.random.uniform(-0.5, 0.5),
                          p0[1] + radius * np.random.uniform(-0.5, 0.5))
                    p2 = (p1[0] + radius * np.random.uniform(-0.5, 0.5),
                          p1[1] + radius * np.random.uniform(-0.5, 0.5))
                    p3 = (p2[0] + radius * np.random.uniform(-0.5, 0.5),
                          p2[1] + radius * np.random.uniform(-0.5, 0.5))
                    points.append(p1)
                    points.append(p2)
                    if j < num_segments - 1:
                        points.append(p3)
                        p0 = p3
                points = torch.FloatTensor(points)
            else:
                raise NotImplementedError(f"'{self.segment_init}' does not exist.")

            path = pydiffvg.Path(
                num_control_points=torch.LongTensor(num_control_points),
                points=points,
                stroke_width=torch.tensor(0.0),
                is_closed=True
            )
        elif self.style in ['sketch', 'painting', 'ink']:
            num_control_points = torch.zeros(num_segments, dtype=torch.long) + 2
            points = []
            p0 = [random.random(), random.random()]
            points.append(p0)

            # select the color by the coordinate of the first point
            color_ref = copy.deepcopy(p0)
            color_ref[0] *= self.canvas_width
            color_ref[1] *= self.canvas_height
            self.color_ref = color_ref

            for j in range(num_segments):
                radius = 0.1
                p1 = (p0[0] + radius * (random.random() - 0.5), p0[1] + radius * (random.random() - 0.5))
                p2 = (p1[0] + radius * (random.random() - 0.5), p1[1] + radius * (random.random() - 0.5))
                p3 = (p2[0] + radius * (random.random() - 0.5), p2[1] + radius * (random.random() - 0.5))
                points.append(p1)
                points.append(p2)
                points.append(p3)
                p0 = p3

            points = torch.tensor(points).to(self.device)
            points[:, 0] *= self.canvas_width
            points[:, 1] *= self.canvas_height

            path = pydiffvg.Path(num_control_points=torch.LongTensor(num_control_points),
                                 points=points,
                                 stroke_width=torch.tensor(float(self.stroke_width)),
                                 is_closed=False)
        elif self.style in ['pixelart', 'low-poly']:
            x = coord[0] * self.pixel_per_grid
            y = coord[1] * self.pixel_per_grid
            points = torch.FloatTensor([
                [x, y],
                [x + self.pixel_per_grid, y],
                [x + self.pixel_per_grid, y + self.pixel_per_grid],
                [x, y + self.pixel_per_grid]
            ]).to(self.device)
            path = pydiffvg.Polygon(points=points,
                                    stroke_width=torch.tensor(0.0),
                                    is_closed=True)

        self.strokes_counter += 1
        return path
    def clip_curve_shape(self):
        if self.style in ['sketch', 'ink']:
            for group in self.shape_groups:
                group.stroke_color.data[:3].clamp_(0., 0.)  # force black strokes
                group.stroke_color.data[-1].clamp_(0., 1.)  # clip alpha
        else:
            for group in self.shape_groups:
                if group.stroke_color is not None:
                    group.stroke_color.data.clamp_(0.0, 1.0)  # clip rgba
                if group.fill_color is not None:
                    group.fill_color.data.clamp_(0.0, 1.0)  # clip rgba
    def reinitialize_paths(self,
                           reinit_path: bool = False,
                           opacity_threshold: float = None,
                           area_threshold: float = None,
                           fpath: pathlib.Path = None):
        """
        Reinitialize paths, as described in the 'Reinitializing Paths' section of the VectorFusion paper.

        Args:
            reinit_path: whether to reinitialize paths or not.
            opacity_threshold: opacity threshold; groups below it are reinitialized.
            area_threshold: threshold on the area of the closed polygon; smaller paths are reinitialized.
            fpath: the path used to save the reinitialized SVG.
        """
        if not reinit_path:
            return
        if self.style not in ['iconography', 'low-poly', 'painting']:
            return

        def get_keys_below_threshold(my_dict, threshold):
            keys_below_threshold = [key for key, value in my_dict.items() if value < threshold]
            return keys_below_threshold

        select_path_ids_by_opc = []
        select_path_ids_by_area = []
        if self.style in ['iconography', 'low-poly']:
            # re-init by opacity_threshold
            if opacity_threshold != 0 and opacity_threshold is not None:
                opacity_record_ = {group.shape_ids.item(): group.fill_color[-1].item()
                                   for group in self.cur_shape_groups}
                # print("-> opacity_record: ", opacity_record_)
                print("-> opacity_record: ", [f"{k}: {v:.3f}" for k, v in opacity_record_.items()])
                select_path_ids_by_opc = get_keys_below_threshold(opacity_record_, opacity_threshold)
                print("select_path_ids_by_opc: ", select_path_ids_by_opc)

            # remove paths by area_threshold
            if area_threshold != 0 and area_threshold is not None:
                area_records = [Polygon(shape.points.detach().cpu().numpy()).area for shape in self.cur_shapes]
                # print("-> area_records: ", area_records)
                print("-> area_records: ", ['%.2f' % i for i in area_records])
                for i, shape in enumerate(self.cur_shapes):
                    points_ = shape.points.detach().cpu().numpy()
                    if Polygon(points_).area < area_threshold:
                        select_path_ids_by_area.append(shape.id)
                print("select_path_ids_by_area: ", select_path_ids_by_area)
        elif self.style in ['painting']:
            # re-init by opacity_threshold
            if opacity_threshold != 0 and opacity_threshold is not None:
                opacity_record_ = {group.shape_ids.item(): group.stroke_color[-1].item()
                                   for group in self.cur_shape_groups}
                # print("-> opacity_record: ", opacity_record_)
                print("-> opacity_record: ", [f"{k}: {v:.3f}" for k, v in opacity_record_.items()])
                select_path_ids_by_opc = get_keys_below_threshold(opacity_record_, opacity_threshold)
                print("select_path_ids_by_opc: ", select_path_ids_by_opc)

        # re-init paths
        reinit_union = list(set(select_path_ids_by_opc + select_path_ids_by_area))
        if len(reinit_union) > 0:
            for i, path in enumerate(self.cur_shapes):
                if path.id in reinit_union:
                    coord = [i, i] if self.style == 'low-poly' else None
                    self.cur_shapes[i] = self.get_path(coord=coord)
            for i, group in enumerate(self.cur_shape_groups):
                shp_ids = group.shape_ids.cpu().numpy().tolist()
                if set(shp_ids).issubset(reinit_union):
                    if self.style in ['iconography', 'low-poly']:
                        fill_color_init = torch.FloatTensor(np.random.uniform(size=[4]))
                        fill_color_init[-1] = 1.0
                        self.cur_shape_groups[i] = pydiffvg.ShapeGroup(
                            shape_ids=torch.tensor(list(shp_ids)),
                            fill_color=fill_color_init,
                            stroke_color=None)
                    elif self.style in ['painting']:
                        stroke_color_init = torch.FloatTensor(np.random.uniform(size=[4]))
                        stroke_color_init[-1] = 1.0
                        self.cur_shape_groups[i] = pydiffvg.ShapeGroup(
                            shape_ids=torch.tensor(list(shp_ids)),
                            fill_color=None,
                            stroke_color=stroke_color_init
                        )
            # save the reinitialized SVG
            self.pretty_save_svg(fpath)

        print("-" * 40)
    def calc_distance_weight(self, loss_weight_keep):
        shapes_forsdf = copy.deepcopy(self.cur_shapes)
        shape_groups_forsdf = copy.deepcopy(self.cur_shape_groups)
        for si in shapes_forsdf:
            si.stroke_width = torch.FloatTensor([0]).to(self.device)
        for sg_idx, sgi in enumerate(shape_groups_forsdf):
            sgi.fill_color = torch.FloatTensor([1, 1, 1, 1]).to(self.device)
            sgi.shape_ids = torch.LongTensor([sg_idx]).to(self.device)

        sargs_forsdf = pydiffvg.RenderFunction.serialize_scene(
            self.canvas_width, self.canvas_height, shapes_forsdf, shape_groups_forsdf
        )
        _render = pydiffvg.RenderFunction.apply
        with torch.no_grad():
            im_forsdf = _render(self.canvas_width,  # width
                                self.canvas_height,  # height
                                2,  # num_samples_x
                                2,  # num_samples_y
                                0,  # seed
                                None,
                                *sargs_forsdf)

        # using the alpha channel is a trick to obtain a 0-1 occupancy image
        im_forsdf = (im_forsdf[:, :, 3]).detach().cpu().numpy()
        loss_weight = get_sdf(im_forsdf, normalize='to1')
        loss_weight += loss_weight_keep
        loss_weight = np.clip(loss_weight, 0, 1)
        loss_weight = torch.FloatTensor(loss_weight).to(self.device)
        return loss_weight
    def set_point_parameters(self, id_delta=0):
        self.point_vars = []
        for i, path in enumerate(self.cur_shapes):
            path.id = i + id_delta  # set point id
            path.points.requires_grad = True
            self.point_vars.append(path.points)

    def get_point_parameters(self):
        return self.point_vars

    def set_color_parameters(self):
        self.color_vars = []
        for i, group in enumerate(self.cur_shape_groups):
            if group.fill_color is not None:
                group.fill_color.requires_grad = True
                self.color_vars.append(group.fill_color)
            if group.stroke_color is not None:
                group.stroke_color.requires_grad = True
                self.color_vars.append(group.stroke_color)

    def get_color_parameters(self):
        return self.color_vars

    def set_width_parameters(self):
        # stroke width optimization
        self.width_vars = []
        for i, path in enumerate(self.shapes):
            path.stroke_width.requires_grad = True
            self.width_vars.append(path.stroke_width)

    def get_width_parameters(self):
        return self.width_vars
    def pretty_save_svg(self, filename, width=None, height=None, shapes=None, shape_groups=None):
        width = self.canvas_width if width is None else width
        height = self.canvas_height if height is None else height
        shapes = self.shapes if shapes is None else shapes
        shape_groups = self.shape_groups if shape_groups is None else shape_groups

        self.save_svg(filename, width, height, shapes, shape_groups, use_gamma=False, background=None)

    def load_svg(self, path_svg):
        canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(path_svg)
        return canvas_width, canvas_height, shapes, shape_groups
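

# --- Usage sketch (illustrative only, not part of the module API). The style,
# path count, and output filename below are assumptions; only `print_timing`
# is actually read from `diffvg_cfg` here. ---
def _example_painter_init(diffvg_cfg: omegaconf.DictConfig, target_img: torch.Tensor, device=None):
    """Minimal sketch: create a Painter, seed path positions from the target
    image, render once, and dump the initial SVG."""
    painter = Painter(diffvg_cfg, style='iconography', num_segments=4,
                      segment_init='circle', radius=20, canvas_size=600, device=device)
    # pred=None on the first call: the 'sparse' initializer compares the target
    # against a blank, background-colored render
    painter.component_wise_path_init(gt=target_img, pred=None, init_type='sparse')
    img = painter.init_image(stage=0, num_paths=32)  # NCHW render in [0, 1]
    painter.pretty_save_svg("init_iconography.svg")
    return painter, img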

def get_sdf(phi, **kwargs):
    import skfmm  # local import

    phi = (phi - 0.5) * 2
    if (phi.max() <= 0) or (phi.min() >= 0):
        return np.zeros(phi.shape).astype(np.float32)
    sd = skfmm.distance(phi, dx=1)

    flip_negative = kwargs.get('flip_negative', True)
    if flip_negative:
        sd = np.abs(sd)

    truncate = kwargs.get('truncate', 10)
    sd = np.clip(sd, -truncate, truncate)
    # print(f"max sd value is: {sd.max()}")

    zero2max = kwargs.get('zero2max', True)
    if zero2max and flip_negative:
        sd = sd.max() - sd
    elif zero2max:
        raise ValueError

    normalize = kwargs.get('normalize', 'sum')
    if normalize == 'sum':
        sd /= sd.sum()
    elif normalize == 'to1':
        sd /= sd.max()

    return sd
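

# --- Illustrative check (assumes scikit-fmm is installed, as required above).
# `get_sdf` expects a soft occupancy map in [0, 1]; with normalize='to1' the
# returned weight peaks near the region boundary (the zero level set) and
# decays with distance from it, after truncation. ---
def _example_sdf_weight():
    mask = np.zeros((64, 64), dtype=np.float32)
    mask[16:48, 16:48] = 1.0  # a filled square stands in for the covered region
    return get_sdf(mask, normalize='to1')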

class SparseCoordInit:

    def __init__(self, pred, gt, format='[bs x c x 2D]', quantile_interval=200, nodiff_thres=0.1):
        if torch.is_tensor(pred):
            pred = pred.detach().cpu().numpy()
        if torch.is_tensor(gt):
            gt = gt.detach().cpu().numpy()

        if format == '[bs x c x 2D]':
            self.map = ((pred[0] - gt[0]) ** 2).sum(0)
            self.reference_gt = copy.deepcopy(np.transpose(gt[0], (1, 2, 0)))
        elif format == '[2D x c]':
            self.map = (np.abs(pred - gt)).sum(-1)
            self.reference_gt = copy.deepcopy(gt[0])
        else:
            raise ValueError

        # Option A: zero out errors below `nodiff_thres` so that tiny residuals
        # cannot be sampled over and over again (avoids a dead loop)
        self.map[self.map < nodiff_thres] = 0
        quantile_interval = np.linspace(0., 1., quantile_interval)
        quantized_interval = np.quantile(self.map, quantile_interval)
        # remove redundant quantile values
        quantized_interval = np.unique(quantized_interval)
        quantized_interval = sorted(quantized_interval[1:-1])
        self.map = np.digitize(self.map, quantized_interval, right=False)
        self.map = np.clip(self.map, 0, 255).astype(np.uint8)
        self.idcnt = {}
        for idi in sorted(np.unique(self.map)):
            self.idcnt[idi] = (self.map == idi).sum()
        # remove the smallest bin: it corresponds to the region that is already correct
        self.idcnt.pop(min(self.idcnt.keys()))

    def __call__(self):
        if len(self.idcnt) == 0:
            h, w = self.map.shape
            return [np.random.uniform(0, 1) * w, np.random.uniform(0, 1) * h]
        target_id = max(self.idcnt, key=self.idcnt.get)
        _, component, cstats, ccenter = cv2.connectedComponentsWithStats(
            (self.map == target_id).astype(np.uint8),
            connectivity=4
        )
        # skip cid = 0, it is the invalid (background) area
        csize = [ci[-1] for ci in cstats[1:]]
        target_cid = csize.index(max(csize)) + 1
        center = ccenter[target_cid][::-1]
        coord = np.stack(np.where(component == target_cid)).T
        dist = np.linalg.norm(coord - center, axis=1)
        target_coord_id = np.argmin(dist)
        coord_h, coord_w = coord[target_coord_id]

        # sampling without replacement: retire this component so it is not picked again
        self.idcnt[target_id] -= max(csize)
        if self.idcnt[target_id] == 0:
            self.idcnt.pop(target_id)
        self.map[component == target_cid] = 0

        return [coord_w, coord_h]

class RandomCoordInit:

    def __init__(self, canvas_width, canvas_height):
        self.canvas_width, self.canvas_height = canvas_width, canvas_height

    def __call__(self):
        w, h = self.canvas_width, self.canvas_height
        return [np.random.uniform(0, 1) * w, np.random.uniform(0, 1) * h]

class NaiveCoordInit:

    def __init__(self, pred, gt, format='[bs x c x 2D]', replace_sampling=True):
        if isinstance(pred, torch.Tensor):
            pred = pred.detach().cpu().numpy()
        if isinstance(gt, torch.Tensor):
            gt = gt.detach().cpu().numpy()

        if format == '[bs x c x 2D]':
            self.map = ((pred[0] - gt[0]) ** 2).sum(0)
        elif format == '[2D x c]':
            self.map = ((pred - gt) ** 2).sum(-1)
        else:
            raise ValueError
        self.replace_sampling = replace_sampling

    def __call__(self):
        coord = np.where(self.map == self.map.max())
        coord_h, coord_w = coord[0][0], coord[1][0]
        if self.replace_sampling:
            self.map[coord_h, coord_w] = -1
        return [coord_w, coord_h]
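

# --- Usage sketch (illustrative; the array shapes below are assumptions).
# The three *CoordInit classes are interchangeable: each call returns a single
# [x, y] pixel coordinate at which to place a new path. ---
def _example_coord_init(pred: torch.Tensor, gt: torch.Tensor, canvas_size: int = 600):
    """pred / gt: render and target of shape [1, 3, H, W] with values in [0, 1]."""
    sparse = SparseCoordInit(pred, gt)                # center of the largest high-error region
    naive = NaiveCoordInit(pred, gt)                  # single highest-error pixel
    rand = RandomCoordInit(canvas_size, canvas_size)  # uniform over the canvas
    return sparse(), naive(), rand()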

class PainterOptimizer:

    def __init__(self,
                 renderer: Painter,
                 style: str,
                 num_iter: int,
                 lr_config: omegaconf.DictConfig,
                 trainable_bg: bool = False):
        self.renderer = renderer
        self.num_iter = num_iter
        self.trainable_bg = trainable_bg
        self.lr_config = lr_config

        # select the optimized parameter groups via style
        self.optim_point, self.optim_color, self.optim_width = {
            "iconography": (True, True, False),
            "pixelart": (False, True, False),
            "low-poly": (True, True, False),
            "sketch": (True, True, False),
            "ink": (True, True, True),
            "painting": (True, True, True)
        }.get(style, (False, False, False))
        self.optim_bg = trainable_bg

        # set the lr schedule
        schedule_cfg = lr_config.schedule
        if schedule_cfg.name == 'linear':
            self.lr_lambda = LinearDecayWithKeepLRLambda(init_lr=lr_config.point,
                                                         keep_ratio=schedule_cfg.keep_ratio,
                                                         decay_every=self.num_iter,
                                                         decay_ratio=schedule_cfg.decay_ratio)
        elif schedule_cfg.name == 'cosine':
            self.lr_lambda = CosineWithWarmupLRLambda(num_steps=self.num_iter,
                                                      warmup_steps=schedule_cfg.warmup_steps,
                                                      warmup_start_lr=schedule_cfg.warmup_start_lr,
                                                      warmup_end_lr=schedule_cfg.warmup_end_lr,
                                                      cosine_end_lr=schedule_cfg.cosine_end_lr)
        else:
            print(f"{schedule_cfg.name} is not supported.")
            self.lr_lambda = None

        self.point_optimizer = None
        self.color_optimizer = None
        self.width_optimizer = None
        self.bg_optimizer = None
        self.point_scheduler = None

    def init_optimizers(self, pid_delta: int = 0):
        # optimizer
        optim_cfg = self.lr_config.optim
        optim_name = optim_cfg.name

        params = {}
        if self.optim_point:
            self.renderer.set_point_parameters(pid_delta)
            params['point'] = self.renderer.get_point_parameters()
            self.point_optimizer = get_optimizer(optim_name, params['point'], self.lr_config.point, optim_cfg)

        if self.optim_color:
            self.renderer.set_color_parameters()
            params['color'] = self.renderer.get_color_parameters()
            self.color_optimizer = get_optimizer(optim_name, params['color'], self.lr_config.color, optim_cfg)

        if self.optim_width:
            self.renderer.set_width_parameters()
            params['width'] = self.renderer.get_width_parameters()
            if len(params['width']) > 0:
                self.width_optimizer = get_optimizer(optim_name, params['width'], self.lr_config.width, optim_cfg)

        if self.optim_bg:
            self.renderer.para_bg.requires_grad = True
            self.bg_optimizer = get_optimizer(optim_name, self.renderer.para_bg, self.lr_config.bg, optim_cfg)

        # lr schedule
        if self.lr_lambda is not None and self.optim_point:
            self.point_scheduler = LambdaLR(self.point_optimizer, lr_lambda=self.lr_lambda, last_epoch=-1)

    def update_lr(self):
        if self.point_scheduler is not None:
            self.point_scheduler.step()

    def zero_grad_(self):
        if self.point_optimizer is not None:
            self.point_optimizer.zero_grad()
        if self.color_optimizer is not None:
            self.color_optimizer.zero_grad()
        if self.width_optimizer is not None:
            self.width_optimizer.zero_grad()
        if self.bg_optimizer is not None:
            self.bg_optimizer.zero_grad()

    def step_(self):
        if self.point_optimizer is not None:
            self.point_optimizer.step()
        if self.color_optimizer is not None:
            self.color_optimizer.step()
        if self.width_optimizer is not None:
            self.width_optimizer.step()
        if self.bg_optimizer is not None:
            self.bg_optimizer.step()

    def get_lr(self) -> Dict:
        lr = {}
        if self.point_optimizer is not None:
            lr['pnt'] = self.point_optimizer.param_groups[0]['lr']
        if self.color_optimizer is not None:
            lr['clr'] = self.color_optimizer.param_groups[0]['lr']
        if self.width_optimizer is not None:
            lr['wd'] = self.width_optimizer.param_groups[0]['lr']
        if self.bg_optimizer is not None:
            lr['bg'] = self.bg_optimizer.param_groups[0]['lr']
        return lr
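

# --- Training-loop sketch (illustrative; `loss_fn` is a placeholder and
# `lr_config` is assumed to match the project's config schema). It shows the
# intended call order of the optimizer wrapper above. ---
def _example_optimize(renderer: Painter, lr_config: omegaconf.DictConfig, loss_fn, num_iter: int = 500):
    optim = PainterOptimizer(renderer, style='iconography', num_iter=num_iter, lr_config=lr_config)
    optim.init_optimizers()
    for step in range(num_iter):
        optim.zero_grad_()
        raster = renderer.get_image(step)  # differentiable NCHW render
        loss = loss_fn(raster)
        loss.backward()
        optim.step_()
        renderer.clip_curve_shape()        # keep colors in a valid range
        optim.update_lr()
    return renderer.get_image()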

class LinearDecayWithKeepLRLambda:
    """Applied in the LIVE stage: keep the initial lr for the first `keep_ratio` fraction of steps, then decay linearly."""

    def __init__(self, init_lr, keep_ratio, decay_every, decay_ratio):
        self.init_lr = init_lr
        self.keep_ratio = keep_ratio
        self.decay_every = decay_every
        self.decay_ratio = decay_ratio

    def __call__(self, n):
        if n < self.keep_ratio * self.decay_every:
            return self.init_lr

        decay_time = n // self.decay_every
        decay_step = n % self.decay_every
        lr_s = self.decay_ratio ** decay_time
        lr_e = self.decay_ratio ** (decay_time + 1)
        r = decay_step / self.decay_every
        lr = lr_s * (1 - r) + lr_e * r
        return lr

class CosineWithWarmupLRLambda:
    """Applied in the fine-tuning stage: linear warmup followed by cosine decay."""

    def __init__(self, num_steps, warmup_steps, warmup_start_lr, warmup_end_lr, cosine_end_lr):
        self.n_steps = num_steps
        self.n_warmup = warmup_steps
        self.warmup_start_lr = warmup_start_lr
        self.warmup_end_lr = warmup_end_lr
        self.cosine_end_lr = cosine_end_lr

    def __call__(self, n):
        if n < self.n_warmup:
            # linear warmup
            return self.warmup_start_lr + (n / self.n_warmup) * (self.warmup_end_lr - self.warmup_start_lr)
        else:
            # cosine-decayed schedule
            return self.cosine_end_lr + 0.5 * (self.warmup_end_lr - self.cosine_end_lr) * (
                    1 + math.cos(math.pi * (n - self.n_warmup) / (self.n_steps - self.n_warmup)))
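

# --- Scheduler sketch (illustrative; the numbers are arbitrary). Note that
# torch's LambdaLR multiplies the optimizer's base lr by the lambda's return
# value, so with a base lr of 1.0 the values produced above become the
# effective learning rates. ---
def _example_scheduler():
    dummy = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.Adam([dummy], lr=1.0)
    sched = LambdaLR(opt, lr_lambda=CosineWithWarmupLRLambda(
        num_steps=100, warmup_steps=10,
        warmup_start_lr=0.02, warmup_end_lr=0.9, cosine_end_lr=0.4))
    lrs = []
    for _ in range(100):
        opt.step()
        sched.step()
        lrs.append(opt.param_groups[0]['lr'])
    return lrs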