import pathlib
import random

import numpy as np
import omegaconf
import pydiffvg
import torch
import torch.nn as nn
from PIL import Image
from pytorch_svgrender.diffvg_warp import DiffVGState
from pytorch_svgrender.libs.modules.edge_map.DoG import XDoG
from pytorch_svgrender.painter.clipasso import modified_clip as clip
from pytorch_svgrender.painter.clipasso.grad_cam import gradCAM
from torchvision import transforms


class Painter(DiffVGState):

    def __init__(
            self,
            method_cfg: omegaconf.DictConfig,
            diffvg_cfg: omegaconf.DictConfig,
            num_strokes: int = 4,
            canvas_size: int = 224,
            device=None,
            target_im=None,
            mask=None
    ):
        super(Painter, self).__init__(device, print_timing=diffvg_cfg.print_timing,
                                      canvas_width=canvas_size, canvas_height=canvas_size)

        self.args = method_cfg
        self.num_paths = num_strokes
        self.num_segments = method_cfg.num_segments
        self.width = method_cfg.width
        self.control_points_per_seg = method_cfg.control_points_per_seg
        self.num_control_points = torch.zeros(self.num_segments, dtype=torch.int32) + (self.control_points_per_seg - 2)
        self.opacity_optim = method_cfg.force_sparse
        self.num_stages = method_cfg.num_stages
        self.noise_thresh = method_cfg.noise_thresh
        self.softmax_temp = method_cfg.softmax_temp

        self.add_random_noise = "noise" in method_cfg.augemntations
        self.optimize_points = method_cfg.optimize_points
        self.optimize_points_global = method_cfg.optimize_points
        self.points_init = []  # initial control points, kept for MLP training
        self.color_vars_threshold = method_cfg.color_vars_threshold

        self.path_svg = method_cfg.path_svg
        self.strokes_per_stage = self.num_paths
        self.optimize_flag = []

        # attention-related settings for stroke initialisation
        self.attention_init = method_cfg.attention_init
        self.saliency_model = method_cfg.saliency_model
        self.xdog_intersec = method_cfg.xdog_intersec
        self.mask_object_attention = method_cfg.mask_object_attention

        self.text_target = method_cfg.text_target  # for CLIP gradients
        self.saliency_clip_model = method_cfg.saliency_clip_model
        self.image2clip_input = self.clip_preprocess(target_im)

        self.mask = mask
        self.attention_map = self.set_attention_map() if self.attention_init else None
        self.thresh = self.set_attention_threshold_map() if self.attention_init else None

        self.strokes_counter = 0  # counts the number of calls to "get_path"
        self.epoch = 0
        self.final_epoch = method_cfg.num_iter - 1

        if "for" in method_cfg.loss_mask:
            # by default the mask removes the background;
            # a "for" mask loss means we want to mask out the foreground instead
            self.mask = 1 - mask

        self.mlp_train = method_cfg.mlp_train
        self.width_optim = method_cfg.width_optim
        self.width_optim_global = method_cfg.width_optim

        if self.width_optim:
            self.init_widths = torch.ones((self.num_paths)).to(device) * 1.5
            self.mlp_width = WidthMLP(num_strokes=self.num_paths, num_cp=self.control_points_per_seg,
                                      width_optim=self.width_optim).to(device)
            self.mlp_width_weights_path = method_cfg.mlp_width_weights_path
            self.mlp_width_weight_init()

        self.gumbel_temp = method_cfg.gumbel_temp
        self.mlp = MLP(num_strokes=self.num_paths, num_cp=self.control_points_per_seg,
                       width_optim=self.width_optim).to(device) if self.mlp_train else None
        self.mlp_points_weights_path = method_cfg.mlp_points_weights_path
        self.mlp_points_weight_init()

        self.out_of_canvas_mask = torch.ones((self.num_paths)).to(self.device)
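
    # Summary of the width/pruning setup above: when `width_optim` is on, every
    # stroke starts with width 1.5 and a small WidthMLP predicts a per-stroke keep
    # probability in (0, 1); `mlp_pass`/`render_warp` then draw a Gumbel-softmax
    # sample over (log p, log(1 - p)) and scale `init_widths` by it, so strokes
    # whose probability collapses towards zero are effectively pruned.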

    def turn_off_points_optim(self):
        self.optimize_points = False

    def switch_opt(self):
        self.width_optim = not self.width_optim
        self.optimize_points = not self.optimize_points

    def mlp_points_weight_init(self):
        if self.mlp_points_weights_path != "none":
            checkpoint = torch.load(self.mlp_points_weights_path)
            self.mlp.load_state_dict(checkpoint['model_state_dict'])
            print("mlp checkpoint loaded from ", self.mlp_points_weights_path)

    def mlp_width_weight_init(self):
        if self.mlp_width_weights_path == "none":
            self.mlp_width.apply(init_weights)
        else:
            checkpoint = torch.load(self.mlp_width_weights_path)
            self.mlp_width.load_state_dict(checkpoint['model_state_dict'])
            print("mlp checkpoint loaded from ", self.mlp_width_weights_path)

    def init_image(self, stage=0):
        if stage > 0:
            # Note: in multi-stage training we add new strokes on top of the existing ones
            # and do not optimize the previous strokes
            self.optimize_flag = [False for i in range(len(self.shapes))]
            for i in range(self.strokes_per_stage):
                stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0])
                path = self.get_path()
                self.shapes.append(path)
                path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(self.shapes) - 1]),
                                                 fill_color=None,
                                                 stroke_color=stroke_color)
                self.shape_groups.append(path_group)
                self.optimize_flag.append(True)
        else:
            num_paths_exists = 0
            if self.path_svg is not None and pathlib.Path(self.path_svg).exists():
                print(f"-> init svg from `{self.path_svg}` ...")
                self.canvas_width, self.canvas_height, self.shapes, self.shape_groups = self.load_svg(self.path_svg)
                # if you want to add more strokes to existing ones and optimize on all of them
                num_paths_exists = len(self.shapes)
                for path in self.shapes:
                    self.points_init.append(path.points)
            for i in range(num_paths_exists, self.num_paths):
                stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0])
                path = self.get_path()
                self.shapes.append(path)
                path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(self.shapes) - 1]),
                                                 fill_color=None,
                                                 stroke_color=stroke_color)
                self.shape_groups.append(path_group)
            self.optimize_flag = [True for i in range(len(self.shapes))]

    def get_image(self, mode="train"):
        if self.mlp_train:
            img = self.mlp_pass(mode)
        else:
            img = self.render_warp(mode)

        opacity = img[:, :, 3:4]
        img = opacity * img[:, :, :3] + torch.ones(img.shape[0], img.shape[1], 3, device=self.device) * (1 - opacity)
        img = img[:, :, :3]
        # Convert img from HWC to NCHW
        img = img.unsqueeze(0)
        img = img.permute(0, 3, 1, 2).to(self.device)  # NHWC -> NCHW
        return img

    def mlp_pass(self, mode, eps=1e-4):
        """
        Update self.shapes etc. through an MLP pass instead of directly
        (the MLP parameters should be updated by the optimizer as well).
        """
        if self.optimize_points_global:
            points_vars = self.points_init
            # reshape and normalise to the [-1, 1] range
            points_vars = torch.stack(points_vars).unsqueeze(0).to(self.device)
            points_vars = points_vars / self.canvas_width
            points_vars = 2 * points_vars - 1
            if self.optimize_points:
                points = self.mlp(points_vars)
            else:
                with torch.no_grad():
                    points = self.mlp(points_vars)
        else:
            points = torch.stack(self.points_init).unsqueeze(0).to(self.device)

        if self.width_optim and mode != "init":  # the first iteration uses just the location mlp
            widths_ = self.mlp_width(self.init_widths).clamp(min=1e-8)
            mask_flipped = (1 - widths_).clamp(min=1e-8)
            v = torch.stack((torch.log(widths_), torch.log(mask_flipped)), dim=-1)
            hard_mask = torch.nn.functional.gumbel_softmax(v, self.gumbel_temp, False)
            self.stroke_probs = hard_mask[:, 0] * self.out_of_canvas_mask
            self.widths = self.stroke_probs * self.init_widths

        # normalize back to canvas size [0, 224] and reshape
        all_points = 0.5 * (points + 1.0) * self.canvas_width
        all_points = all_points + eps * torch.randn_like(all_points)
        all_points = all_points.reshape((-1, self.num_paths, self.control_points_per_seg, 2))

        if self.width_optim_global and not self.width_optim:
            self.widths = self.widths.detach()
            # all_points = all_points.detach()

        # define new primitives to render
        shapes = []
        shape_groups = []
        for p in range(self.num_paths):
            width = torch.tensor(self.width)
            if self.width_optim_global and mode != "init":
                width = self.widths[p]
            path = pydiffvg.Path(
                num_control_points=self.num_control_points, points=all_points[:, p].reshape((-1, 2)),
                stroke_width=width, is_closed=False)
            if mode == "init":
                # done once at the beginning: define a mask for strokes that fall outside the canvas
                is_in_canvas_ = self.is_in_canvas(self.canvas_width, self.canvas_height, path)
                if not is_in_canvas_:
                    self.out_of_canvas_mask[p] = 0
            shapes.append(path)
            path_group = pydiffvg.ShapeGroup(
                shape_ids=torch.tensor([len(shapes) - 1]),
                fill_color=None,
                stroke_color=torch.tensor([0, 0, 0, 1]))
            shape_groups.append(path_group)

        _render = pydiffvg.RenderFunction.apply
        scene_method_cfg = pydiffvg.RenderFunction.serialize_scene(
            self.canvas_width, self.canvas_height, shapes, shape_groups)
        img = _render(self.canvas_width,  # width
                      self.canvas_height,  # height
                      2,  # num_samples_x
                      2,  # num_samples_y
                      0,  # seed
                      None,
                      *scene_method_cfg)
        self.shapes = shapes.copy()
        self.shape_groups = shape_groups.copy()
        return img
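
    # The MLP pass above normalises the initial control points to [-1, 1],
    # lets `self.mlp` predict small offsets, maps the result back to canvas
    # coordinates, and rebuilds fresh pydiffvg Paths/ShapeGroups each call, so
    # gradients flow into the MLP weights rather than into the points directly.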

    def get_path(self):
        points = []
        p0 = self.inds_normalised[self.strokes_counter] if self.attention_init else (random.random(), random.random())
        points.append(p0)

        for j in range(self.num_segments):
            radius = 0.05
            for k in range(self.control_points_per_seg - 1):
                p1 = (p0[0] + radius * (random.random() - 0.5), p0[1] + radius * (random.random() - 0.5))
                points.append(p1)
                p0 = p1
        points = torch.tensor(points).to(self.device)
        points[:, 0] *= self.canvas_width
        points[:, 1] *= self.canvas_height
        self.points_init.append(points)

        path = pydiffvg.Path(num_control_points=self.num_control_points,
                             points=points,
                             stroke_width=torch.tensor(self.width),
                             is_closed=False)
        self.strokes_counter += 1
        return path
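
    # `get_path` above seeds each stroke at an attention-sampled (or random)
    # canvas location and builds a Bezier path by adding small random offsets
    # (radius 0.05 in normalised coordinates) for the remaining control points.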

    def render_warp(self, mode):
        if not self.mlp_train:
            if self.opacity_optim:
                for group in self.shape_groups:
                    group.stroke_color.data[:3].clamp_(0., 0.)  # to force black strokes
                    group.stroke_color.data[-1].clamp_(0., 1.)  # opacity
                    # group.stroke_color.data[-1] = (group.stroke_color.data[-1] >= self.color_vars_threshold).float()
            # optionally perturb the control points with random noise
            if self.add_random_noise:
                if random.random() > self.noise_thresh:
                    eps = 0.01 * min(self.canvas_width, self.canvas_height)
                    for path in self.shapes:
                        path.points.data.add_(eps * torch.randn_like(path.points))

        if self.width_optim and mode != "init":
            widths_ = self.mlp_width(self.init_widths).clamp(min=1e-8)
            mask_flipped = 1 - widths_
            v = torch.stack((torch.log(widths_), torch.log(mask_flipped)), dim=-1)
            hard_mask = torch.nn.functional.gumbel_softmax(v, self.gumbel_temp, False)
            self.stroke_probs = hard_mask[:, 0] * self.out_of_canvas_mask
            self.widths = self.stroke_probs * self.init_widths

        if self.optimize_points:
            _render = pydiffvg.RenderFunction.apply
            scene_method_cfg = pydiffvg.RenderFunction.serialize_scene(
                self.canvas_width, self.canvas_height, self.shapes, self.shape_groups)
            img = _render(self.canvas_width,  # width
                          self.canvas_height,  # height
                          2,  # num_samples_x
                          2,  # num_samples_y
                          0,  # seed
                          None,
                          *scene_method_cfg)
        else:
            points = torch.stack(self.points_init).unsqueeze(0).to(self.device)
            shapes = []
            shape_groups = []
            for p in range(self.num_paths):
                width = torch.tensor(self.width)
                if self.width_optim:
                    width = self.widths[p]
                path = pydiffvg.Path(
                    num_control_points=self.num_control_points, points=points[:, p].reshape((-1, 2)),
                    stroke_width=width, is_closed=False)
                shapes.append(path)
                path_group = pydiffvg.ShapeGroup(
                    shape_ids=torch.tensor([len(shapes) - 1]),
                    fill_color=None,
                    stroke_color=torch.tensor([0, 0, 0, 1]))
                shape_groups.append(path_group)
            _render = pydiffvg.RenderFunction.apply
            scene_method_cfg = pydiffvg.RenderFunction.serialize_scene(
                self.canvas_width, self.canvas_height, shapes, shape_groups)
            img = _render(self.canvas_width,  # width
                          self.canvas_height,  # height
                          2,  # num_samples_x
                          2,  # num_samples_y
                          0,  # seed
                          None,
                          *scene_method_cfg)
            self.shapes = shapes.copy()
            self.shape_groups = shape_groups.copy()
        return img
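
    # `render_warp` above is the non-MLP rendering path: it optionally clamps
    # stroke colours to black, optionally jitters the control points, and either
    # renders `self.shapes` directly (when points are optimised) or rebuilds the
    # paths from `self.points_init` with the current (possibly pruned) widths.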

    def parameters(self):
        if self.optimize_points:
            if self.mlp_train:
                self.points_vars = self.mlp.parameters()
            else:
                self.points_vars = []
                # strokes' location optimization
                for i, path in enumerate(self.shapes):
                    if self.optimize_flag[i]:
                        path.points.requires_grad = True
                        self.points_vars.append(path.points)
                        self.optimize_flag[i] = False
        if self.width_optim:
            return self.points_vars, self.mlp_width.parameters()
        return self.points_vars

    def get_mlp(self):
        return self.mlp

    def get_width_mlp(self):
        if self.width_optim_global:
            return self.mlp_width
        else:
            return None

    def set_color_parameters(self):
        # for strokes' color optimization (opacity)
        self.color_vars = []
        for i, group in enumerate(self.shape_groups):
            if self.optimize_flag[i]:
                group.stroke_color.requires_grad = True
                self.color_vars.append(group.stroke_color)
        return self.color_vars

    def get_color_parameters(self):
        return self.color_vars

    def get_widths(self):
        if self.width_optim_global:
            return self.stroke_probs
        return None

    def get_strokes_in_canvas_count(self):
        return self.out_of_canvas_mask.sum()

    def get_strokes_count(self):
        if self.width_optim_global:
            with torch.no_grad():
                return torch.sum(self.stroke_probs)
        return self.num_paths

    def is_in_canvas(self, canvas_width, canvas_height, path):
        shapes, shape_groups = [], []
        stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0])
        shapes.append(path)
        path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]),
                                         fill_color=None,
                                         stroke_color=stroke_color)
        shape_groups.append(path_group)
        _render = pydiffvg.RenderFunction.apply
        scene_method_cfg = pydiffvg.RenderFunction.serialize_scene(
            canvas_width, canvas_height, shapes, shape_groups)
        img = _render(canvas_width,  # width
                      canvas_height,  # height
                      2,  # num_samples_x
                      2,  # num_samples_y
                      0,  # seed
                      None,
                      *scene_method_cfg)
        img = img[:, :, 3:4] * img[:, :, :3] + \
              torch.ones(img.shape[0], img.shape[1], 3,
                         device=self.device) * (1 - img[:, :, 3:4])
        img = img[:, :, :3].detach().cpu().numpy()
        return (1 - img).sum()
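
    # `is_in_canvas` above renders the single path on a white canvas and returns
    # the total amount of "ink"; a zero value (treated as falsy by the callers)
    # means the path lies entirely outside the canvas.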

    def save_svg(self, output_dir, name):
        if not self.width_optim:
            pydiffvg.save_svg('{}/{}.svg'.format(output_dir, name), self.canvas_width, self.canvas_height,
                              self.shapes, self.shape_groups)
        else:
            stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0])
            new_shapes, new_shape_groups = [], []
            for path in self.shapes:
                is_in_canvas_ = True
                w = path.stroke_width / 1.5
                if w > 0.7 and is_in_canvas_:
                    new_shapes.append(path)
                    path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(new_shapes) - 1]),
                                                     fill_color=None,
                                                     stroke_color=stroke_color)
                    new_shape_groups.append(path_group)
            pydiffvg.save_svg('{}/{}.svg'.format(output_dir, name), self.canvas_width, self.canvas_height,
                              new_shapes, new_shape_groups)

    def clip_preprocess(self, target_im):
        model, preprocess = clip.load(self.saliency_clip_model, device=self.device, jit=False)
        model.eval().to(self.device)
        data_transforms = transforms.Compose([
            preprocess.transforms[-1],
        ])
        return data_transforms(target_im).to(self.device)

    def dino_attn(self):
        patch_size = 8  # dino hyperparameter
        threshold = 0.6

        # normalisation for the dino model
        mean_imagenet = torch.Tensor([0.485, 0.456, 0.406])[None, :, None, None].to(self.device)
        std_imagenet = torch.Tensor([0.229, 0.224, 0.225])[None, :, None, None].to(self.device)
        totens = transforms.Compose([
            transforms.Resize((self.canvas_height, self.canvas_width)),
            transforms.ToTensor()
        ])

        dino_model = torch.hub.load('facebookresearch/dino:main', 'dino_vits8').eval().to(self.device)

        self.main_im = Image.open(self.target_path).convert("RGB")
        main_im_tensor = totens(self.main_im).to(self.device)
        img = (main_im_tensor.unsqueeze(0) - mean_imagenet) / std_imagenet
        w_featmap = img.shape[-2] // patch_size
        h_featmap = img.shape[-1] // patch_size

        with torch.no_grad():
            attn = dino_model.get_last_selfattention(img).detach().cpu()[0]

        nh = attn.shape[0]
        attn = attn[:, 0, 1:].reshape(nh, -1)
        val, idx = torch.sort(attn)
        val /= torch.sum(val, dim=1, keepdim=True)
        cumval = torch.cumsum(val, dim=1)
        th_attn = cumval > (1 - threshold)
        idx2 = torch.argsort(idx)
        for head in range(nh):
            th_attn[head] = th_attn[head][idx2[head]]
        th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
        th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu()

        attn = attn.reshape(nh, w_featmap, h_featmap).float()
        attn = nn.functional.interpolate(attn.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu()

        return attn

    def clip_attn(self):
        model, preprocess = clip.load(self.saliency_clip_model, device=self.device, jit=False)
        model.eval().to(self.device)

        if "RN" in self.saliency_clip_model:
            text_input = clip.tokenize([self.text_target]).to(self.device)
            saliency_layer = "layer4"
            attn_map = gradCAM(
                model.visual,
                self.image2clip_input,
                model.encode_text(text_input).float(),
                getattr(model.visual, saliency_layer)
            )
            attn_map = attn_map.squeeze().detach().cpu().numpy()
            attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
        else:  # ViT
            attn_map = interpret(self.image2clip_input, model, device=self.device)

        del model
        return attn_map

    def set_attention_map(self):
        assert self.saliency_model in ["dino", "clip"]
        if self.saliency_model == "dino":
            return self.dino_attn()
        elif self.saliency_model == "clip":
            return self.clip_attn()

    def softmax(self, x, tau=0.2):
        e_x = np.exp(x / tau)
        return e_x / e_x.sum()

    def set_inds_clip(self):
        attn_map = (self.attention_map - self.attention_map.min()) / (
                self.attention_map.max() - self.attention_map.min())
        if self.xdog_intersec:
            xdog = XDoG(k=10)
            im_xdog = xdog(self.image2clip_input[0].permute(1, 2, 0).cpu().numpy())
            intersec_map = (1 - im_xdog) * attn_map
            attn_map = intersec_map
        if self.mask_object_attention:
            attn_map = attn_map * self.mask[0, 0].cpu().numpy()

        attn_map_soft = np.copy(attn_map)
        attn_map_soft[attn_map > 0] = self.softmax(attn_map[attn_map > 0], tau=self.softmax_temp)

        k = self.num_stages * self.num_paths
        self.inds = np.random.choice(range(attn_map.flatten().shape[0]), size=k, replace=False,
                                     p=attn_map_soft.flatten())
        self.inds = np.array(np.unravel_index(self.inds, attn_map.shape)).T

        self.inds_normalised = np.zeros(self.inds.shape)
        self.inds_normalised[:, 0] = self.inds[:, 1] / self.canvas_width
        self.inds_normalised[:, 1] = self.inds[:, 0] / self.canvas_height
        self.inds_normalised = self.inds_normalised.tolist()
        return attn_map_soft
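
    # `set_inds_clip` above turns the CLIP relevancy map into a sampling
    # distribution (optionally intersected with an XDoG edge map and/or the
    # object mask), softmaxes it with `softmax_temp`, and samples
    # num_stages * num_paths pixel locations used to seed the strokes.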

    def set_inds_dino(self):
        k = max(3, (self.num_stages * self.num_paths) // 6 + 1)  # sample the top-k points from each attention head
        num_heads = self.attention_map.shape[0]
        self.inds = np.zeros((k * num_heads, 2))
        # "thresh" is used for visualisation purposes only
        thresh = torch.zeros(num_heads + 1, self.attention_map.shape[1], self.attention_map.shape[2])
        softmax = nn.Softmax(dim=1)
        for i in range(num_heads):
            # replace "self.attention_map[i]" with "self.attention_map" to get the highest values among
            # all heads.
            topk, indices = np.unique(self.attention_map[i].numpy(), return_index=True)
            topk = topk[::-1][:k]
            cur_attn_map = self.attention_map[i].numpy()
            # prob function for uniform sampling
            prob = cur_attn_map.flatten()
            prob[prob > topk[-1]] = 1
            prob[prob <= topk[-1]] = 0
            prob = prob / prob.sum()
            thresh[i] = torch.Tensor(prob.reshape(cur_attn_map.shape))

            # choose k pixels from each head
            inds = np.random.choice(range(cur_attn_map.flatten().shape[0]), size=k, replace=False, p=prob)
            inds = np.unravel_index(inds, cur_attn_map.shape)
            self.inds[i * k: i * k + k, 0] = inds[0]
            self.inds[i * k: i * k + k, 1] = inds[1]

        # for visualisation
        sum_attn = self.attention_map.sum(0).numpy()
        mask = np.zeros(sum_attn.shape)
        mask[thresh[:-1].sum(0) > 0] = 1
        sum_attn = sum_attn * mask
        sum_attn = sum_attn / sum_attn.sum()
        thresh[-1] = torch.Tensor(sum_attn)

        # sample num_paths from the chosen pixels.
        prob_sum = sum_attn[self.inds[:, 0].astype(np.int64), self.inds[:, 1].astype(np.int64)]
        prob_sum = prob_sum / prob_sum.sum()
        new_inds = []
        for i in range(self.num_stages):
            new_inds.extend(np.random.choice(range(self.inds.shape[0]), size=self.num_paths, replace=False, p=prob_sum))
        self.inds = self.inds[new_inds]

        self.inds_normalised = np.zeros(self.inds.shape)
        self.inds_normalised[:, 0] = self.inds[:, 1] / self.canvas_width
        self.inds_normalised[:, 1] = self.inds[:, 0] / self.canvas_height
        self.inds_normalised = self.inds_normalised.tolist()
        return thresh
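
    # `set_inds_dino` above first picks the top-k pixels per DINO attention head,
    # then re-samples `num_paths` of them per stage in proportion to the summed
    # attention; `thresh` keeps per-head binary maps purely for visualisation.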

    def set_attention_threshold_map(self):
        assert self.saliency_model in ["dino", "clip"]
        if self.saliency_model == "dino":
            return self.set_inds_dino()
        elif self.saliency_model == "clip":
            return self.set_inds_clip()

    def get_attn(self):
        return self.attention_map

    def get_thresh(self):
        return self.thresh

    def get_inds(self):
        return self.inds

    def get_mask(self):
        return self.mask

    def set_random_noise(self, epoch):
        if epoch % self.args.save_step == 0:
            self.add_random_noise = False
        else:
            self.add_random_noise = "noise" in self.args.augemntations


class PainterOptimizer:

    def __init__(self, args, renderer):
        self.renderer = renderer
        self.points_lr = args.lr
        self.color_lr = args.color_lr
        self.args = args
        self.optim_color = args.force_sparse
        self.width_optim = args.width_optim
        self.width_optim_global = args.width_optim
        self.width_lr = args.width_lr
        self.optimize_points = args.optimize_points
        self.optimize_points_global = args.optimize_points
        self.points_optim = None
        self.width_optimizer = None
        self.mlp_width_weights_path = args.mlp_width_weights_path
        self.mlp_points_weights_path = args.mlp_points_weights_path
        self.load_points_opt_weights = args.load_points_opt_weights
        # self.only_width = args.only_width

    def turn_off_points_optim(self):
        self.optimize_points = False

    def switch_opt(self):
        self.width_optim = not self.width_optim
        self.optimize_points = not self.optimize_points

    def init_optimizers(self):
        if self.width_optim:
            points_params, width_params = self.renderer.parameters()
            self.width_optimizer = torch.optim.Adam(width_params, lr=self.width_lr)
            if self.mlp_width_weights_path != "none":
                checkpoint = torch.load(self.mlp_width_weights_path)
                self.width_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                print("optimizer checkpoint loaded from ", self.mlp_width_weights_path)
        else:
            points_params = self.renderer.parameters()

        if self.optimize_points:
            self.points_optim = torch.optim.Adam(points_params, lr=self.points_lr)
            if self.mlp_points_weights_path != "none" and self.load_points_opt_weights:
                checkpoint = torch.load(self.mlp_points_weights_path)
                self.points_optim.load_state_dict(checkpoint['optimizer_state_dict'])
                print("optimizer checkpoint loaded from ", self.mlp_points_weights_path)

        if self.optim_color:
            self.color_optim = torch.optim.Adam(self.renderer.set_color_parameters(), lr=self.color_lr)

    def zero_grad_(self):
        if self.optimize_points:
            self.points_optim.zero_grad()
        if self.width_optim:
            self.width_optimizer.zero_grad()
        if self.optim_color:
            self.color_optim.zero_grad()

    def step_(self):
        if self.optimize_points:
            self.points_optim.step()
        if self.width_optim:
            self.width_optimizer.step()
        if self.optim_color:
            self.color_optim.step()

    def get_lr(self, optim="points"):
        if optim == "points" and self.optimize_points_global:
            return self.points_optim.param_groups[0]['lr']
        if optim == "width" and self.width_optim_global:
            return self.width_optimizer.param_groups[0]['lr']
        else:
            return None

    def get_points_optim(self):
        return self.points_optim

    def get_width_optim(self):
        return self.width_optimizer


class LinearDecayLR:

    def __init__(self, decay_every, decay_ratio):
        self.decay_every = decay_every
        self.decay_ratio = decay_ratio

    def __call__(self, n):
        decay_time = n // self.decay_every
        decay_step = n % self.decay_every
        lr_s = self.decay_ratio ** decay_time
        lr_e = self.decay_ratio ** (decay_time + 1)
        r = decay_step / self.decay_every
        lr = lr_s * (1 - r) + lr_e * r
        return lr
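
# LinearDecayLR is a callable lr multiplier (iteration -> factor), so one
# typical (hypothetical) wiring is through torch.optim.lr_scheduler.LambdaLR:
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=LinearDecayLR(decay_every=100, decay_ratio=0.9))
#   scheduler.step()  # once per iteration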


def interpret(image, clip_model, device):
    # virtual forward pass to populate the attention maps
    images = image.repeat(1, 1, 1, 1)
    _ = clip_model.encode_image(images)  # ensure `attn_probs` in the attention blocks is not empty
    clip_model.zero_grad()

    image_attn_blocks = list(dict(clip_model.visual.transformer.resblocks.named_children()).values())

    # create R to accumulate the attention rollout
    num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
    R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
    R = R.unsqueeze(0).expand(1, num_tokens, num_tokens)
    cams = []
    for i, blk in enumerate(image_attn_blocks):  # 12 attention blocks
        cam = blk.attn_probs.detach()  # attn_probs shape: [12, 50, 50]
        # the image is split into a 7x7 grid of patches, so there are 49 patch tokens + 1 class token
        cam = cam.reshape(1, -1, cam.shape[-1], cam.shape[-1])
        cam = cam.clamp(min=0)
        cam = cam.clamp(min=0).mean(dim=1)  # mean over the attention heads
        cams.append(cam)
        R = R + torch.bmm(cam, R)

    cams_avg = torch.cat(cams)  # [12, 50, 50]
    cams_avg = cams_avg[:, 0, 1:]  # [12, 49]
    image_relevance = cams_avg.mean(dim=0).unsqueeze(0)  # [1, 49]
    image_relevance = image_relevance.reshape(1, 1, 7, 7)  # [1, 1, 7, 7]
    # interpolate: [1, 1, 7, 7] -> [1, 1, 224, 224]
    image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bicubic')
    image_relevance = image_relevance.reshape(224, 224).data.cpu().numpy().astype(np.float32)
    # normalize to [0, 1]
    image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
    return image_relevance
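
# `interpret` above builds a relevance map for CLIP ViT backbones: it averages
# the class-token attention (head-averaged, clamped to be non-negative) over the
# 12 blocks and upsamples the resulting 7x7 map to 224x224; the rollout matrix R
# is accumulated but not used for the final map.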


class MLP(nn.Module):

    def __init__(self, num_strokes, num_cp, width_optim=False):
        super().__init__()
        outdim = 1000
        self.width_optim = width_optim
        self.layers_points = nn.Sequential(
            nn.Flatten(),
            nn.Linear(num_strokes * num_cp * 2, outdim),
            nn.SELU(inplace=True),
            nn.Linear(outdim, outdim),
            nn.SELU(inplace=True),
            nn.Linear(outdim, num_strokes * num_cp * 2),
        )

    def forward(self, x, widths=None):
        '''Forward pass'''
        deltas = self.layers_points(x)
        # if self.width_optim:
        #     return x.flatten() + 0.1 * deltas, self.layers_width(widths)
        return x.flatten() + 0.1 * deltas


class WidthMLP(nn.Module):

    def __init__(self, num_strokes, num_cp, width_optim=False):
        super().__init__()
        outdim = 1000
        self.width_optim = width_optim
        self.layers_width = nn.Sequential(
            nn.Linear(num_strokes, outdim),
            nn.SELU(inplace=True),
            nn.Linear(outdim, outdim),
            nn.SELU(inplace=True),
            nn.Linear(outdim, num_strokes),
            nn.Sigmoid()
        )

    def forward(self, widths=None):
        '''Forward pass'''
        return self.layers_width(widths)


def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
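
# `init_weights` is applied module-wise (e.g. `self.mlp_width.apply(init_weights)`
# in `mlp_width_weight_init`) to give fresh Linear layers Xavier-uniform weights
# and a small constant bias when no checkpoint is provided.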