Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: SVGDreamer pipeline — SVG particle optimization via VPSD.
| import pathlib | |
| from PIL import Image | |
| from typing import AnyStr | |
| import numpy as np | |
| from tqdm.auto import tqdm | |
| import torch | |
| from torch.optim.lr_scheduler import LambdaLR | |
| import torchvision | |
| from torchvision import transforms | |
| from pytorch_svgrender.libs.engine import ModelState | |
| from pytorch_svgrender.libs.solver.optim import get_optimizer | |
| from pytorch_svgrender.painter.svgdreamer import Painter, PainterOptimizer | |
| from pytorch_svgrender.painter.svgdreamer.painter_params import CosineWithWarmupLRLambda | |
| from pytorch_svgrender.painter.live import xing_loss_fn | |
| from pytorch_svgrender.painter.svgdreamer import VectorizedParticleSDSPipeline | |
| from pytorch_svgrender.plt import plot_img | |
| from pytorch_svgrender.utils.color_attrs import init_tensor_with_color | |
| from pytorch_svgrender.token2attn.ptp_utils import view_images | |
| from pytorch_svgrender.diffusers_warp import model2res | |
| import ImageReward as RM | |
class SVGDreamerPipeline(ModelState):
    """SVGDreamer rendering pipeline built on ModelState.

    Optimizes a set of SVG "particles" with Variational Particle-based
    Score Distillation (VPSD); an optional SIVE stage produces the
    initial SVG (see ``painterly_rendering``).
    """

    def __init__(self, args):
        """Validate the config, create log directories and load the guidance pipeline.

        Args:
            args: run configuration; method-specific options live under ``args.x``.
        """
        # supported painting styles — anything else is a config error
        assert args.x.style in ["iconography", "pixelart", "low-poly", "painting", "sketch", "ink"]
        # per-role particle counts can never exceed the total number of particles
        assert args.x.guidance.n_particle >= args.x.guidance.vsd_n_particle
        assert args.x.guidance.n_particle >= args.x.guidance.phi_n_particle
        assert args.x.guidance.n_phi_sample >= 1

        # human-readable log-dir suffix encoding the key run settings
        logdir_ = f"sd{args.seed}" \
                  f"-{'vpsd' if args.x.skip_sive else 'sive'}" \
                  f"-{args.x.model_id}" \
                  f"-{args.x.style}" \
                  f"-P{args.x.num_paths}" \
                  f"{'-RePath' if args.x.path_reinit.use else ''}"
        super().__init__(args, log_path_suffix=logdir_)

        # create log dir
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        self.ft_png_logs_dir = self.result_path / "ft_png_logs"
        self.ft_svg_logs_dir = self.result_path / "ft_svg_logs"
        self.sd_sample_dir = self.result_path / 'sd_samples'
        self.reinit_dir = self.result_path / "reinit_logs"
        self.init_stage_two_dir = self.result_path / "stage_two_init_logs"
        self.phi_samples_dir = self.result_path / "phi_sampling_logs"
        # only the main process touches the filesystem (safe under multi-process launch)
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)
            self.ft_png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.ft_svg_logs_dir.mkdir(parents=True, exist_ok=True)
            self.sd_sample_dir.mkdir(parents=True, exist_ok=True)
            self.reinit_dir.mkdir(parents=True, exist_ok=True)
            self.init_stage_two_dir.mkdir(parents=True, exist_ok=True)
            self.phi_samples_dir.mkdir(parents=True, exist_ok=True)
        self.select_fpth = self.result_path / 'select_sample.png'

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        # seeded torch generator so diffusion sampling is reproducible per run
        self.g_device = torch.Generator(device=self.device).manual_seed(args.seed)
        self.pipeline = VectorizedParticleSDSPipeline(args, args.diffuser, self.x_cfg.guidance, self.device)

        # load reward model
        # ImageReward is only required when reward feedback learning (phi_ReFL) is enabled
        self.reward_model = None
        if self.x_cfg.guidance.phi_ReFL:
            self.reward_model = RM.load("ImageReward-v1.0", device=self.device, download_root=self.x_cfg.reward_path)

        self.style = self.x_cfg.style
        if self.style == "pixelart":
            # pixel-art runs use a fixed learning rate: both stage schedules are forced off
            self.x_cfg.lr_stage_one.lr_schedule = False
            self.x_cfg.lr_stage_two.lr_schedule = False
| def target_file_preprocess(self, tar_path: AnyStr): | |
| process_comp = transforms.Compose([ | |
| transforms.Resize(size=(self.x_cfg.image_size, self.x_cfg.image_size)), | |
| transforms.ToTensor(), | |
| transforms.Lambda(lambda t: t.unsqueeze(0)), | |
| ]) | |
| tar_pil = Image.open(tar_path).convert("RGB") # open file | |
| target_img = process_comp(tar_pil) # preprocess | |
| target_img = target_img.to(self.device) | |
| return target_img | |
| def SIVE_stage(self, text_prompt: str): | |
| # TODO: SIVE implementation | |
| pass | |
    def painterly_rendering(self, text_prompt: str, target_file: AnyStr = None):
        """Optimize SVG particles against *text_prompt* via VPSD.

        Three modes, selected by config and *target_file*:
          1. ``skip_sive`` and no target file: start from random/colored init.
          2. *target_file* exists: load that SVG and fine-tune it.
          3. otherwise: run the SIVE stage first, then fine-tune its output.

        Args:
            text_prompt: positive text prompt driving score distillation.
            target_file: optional path to an existing SVG to fine-tune.

        Side effects: writes PNG/SVG logs, phi samples, optional video frames,
        and the final SVGs under ``self.result_path``; calls ``self.close``.
        """
        # log prompts
        self.print(f"prompt: {text_prompt}")
        self.print(f"neg_prompt: {self.args.neg_prompt}\n")

        # for convenience
        im_size = self.x_cfg.image_size
        guidance_cfg = self.x_cfg.guidance
        n_particle = self.x_cfg.guidance.n_particle
        total_step = self.x_cfg.guidance.num_iter
        path_reinit = self.x_cfg.path_reinit

        # fine-tune an existing SVG only if the file actually exists
        init_from_target = True if (target_file and pathlib.Path(target_file).exists()) else False

        # switch mode
        if self.x_cfg.skip_sive and not init_from_target:
            # mode 1: optimization with VPSD from scratch
            # randomly init
            self.print("optimization with VPSD from scratch...")
            if self.x_cfg.color_init == 'rand':
                target_img = torch.randn(1, 3, im_size, im_size)
                self.print("color: randomly init")
            else:
                # interpret color_init as a color spec for a solid init image
                target_img = init_tensor_with_color(self.x_cfg.color_init, 1, im_size, im_size)
                self.print(f"color: {self.x_cfg.color_init}")
            # log init target_img
            plot_img(target_img, self.result_path, fname='init_target_img')
            final_svg_path = None
        elif init_from_target:
            # mode 2: load the SVG file and finetune it
            self.print(f"load svg from {target_file} ...")
            self.print(f"SVG fine-tuning via VPSD...")
            final_svg_path = target_file
            if self.x_cfg.color_init == 'target_randn':
                # special order: init newly paths color use random color
                target_img = torch.randn(1, 3, im_size, im_size)
                self.print("color: randomly init")
            else:
                # load the SVG and init newly paths color use target_img
                # note: the target will be converted to png via pydiffvg when load_renderer called
                target_img = None
        else:
            # mode 3: text-to-img-to-svg (two stage)
            # NOTE(review): SIVE_stage is currently a stub returning None, so this
            # unpacking will fail until SIVE is implemented — confirm before using mode 3.
            target_img, final_svg_path = self.SIVE_stage(text_prompt)
            self.x_cfg.path_svg = final_svg_path
            self.print("\n SVG fine-tuning via VPSD...")
            plot_img(target_img, self.result_path, fname='init_target_img')

        # create svg renderer: one independent renderer per particle
        renderers = [self.load_renderer(final_svg_path) for _ in range(n_particle)]

        # randomly initialize the particles
        if self.x_cfg.skip_sive or init_from_target:
            if target_img is None:
                # rasterized by load_renderer above; reload it as a tensor
                target_img = self.target_file_preprocess(self.result_path / 'target_img.png')
            for render in renderers:
                render.component_wise_path_init(gt=target_img, pred=None, init_type='random')

        # log init images
        for i, r in enumerate(renderers):
            init_imgs = r.init_image(stage=0, num_paths=self.x_cfg.num_paths)
            plot_img(init_imgs, self.init_stage_two_dir, fname=f"init_img_stage_two_{i}")

        # init renderer optimizer: one optimizer wrapper per particle
        optimizers = []
        for renderer in renderers:
            optim_ = PainterOptimizer(renderer,
                                      self.style,
                                      guidance_cfg.num_iter,
                                      self.x_cfg.lr_stage_two,
                                      self.x_cfg.trainable_bg)
            optim_.init_optimizers()
            optimizers.append(optim_)

        # init phi_model optimizer (the LoRA "phi" network inside the pipeline)
        phi_optimizer = get_optimizer('adamW',
                                      self.pipeline.phi_params,
                                      guidance_cfg.phi_lr,
                                      guidance_cfg.phi_optim)
        # init phi_model lr scheduler
        phi_scheduler = None
        schedule_cfg = guidance_cfg.phi_schedule
        if schedule_cfg.use:
            phi_lr_lambda = CosineWithWarmupLRLambda(num_steps=schedule_cfg.total_step,
                                                     warmup_steps=schedule_cfg.warmup_steps,
                                                     warmup_start_lr=schedule_cfg.warmup_start_lr,
                                                     warmup_end_lr=schedule_cfg.warmup_end_lr,
                                                     cosine_end_lr=schedule_cfg.cosine_end_lr)
            phi_scheduler = LambdaLR(phi_optimizer, lr_lambda=phi_lr_lambda, last_epoch=-1)

        self.print(f"-> Painter point Params: {len(renderers[0].get_point_parameters())}")
        self.print(f"-> Painter color Params: {len(renderers[0].get_color_parameters())}")
        self.print(f"-> Painter width Params: {len(renderers[0].get_width_parameters())}")

        L_reward = torch.tensor(0.)  # placeholder so logging works before first ReFL step

        self.step = 0  # reset global step
        self.print(f"\ntotal VPSD optimization steps: {total_step}")
        with tqdm(initial=self.step, total=total_step, disable=not self.accelerator.is_main_process) as pbar:
            while self.step < total_step:
                # set particles: render every particle and stack along the batch dim
                particles = [renderer.get_image() for renderer in renderers]
                raster_imgs = torch.cat(particles, dim=0)

                if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_step - 1):
                    plot_img(raster_imgs, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                # VPSD guidance loss plus the raw gradient and latents for phi training
                L_guide, grad, latents, t_step = self.pipeline.variational_score_distillation(
                    raster_imgs,
                    self.step,
                    prompt=[text_prompt],
                    negative_prompt=self.args.neg_prompt,
                    grad_scale=guidance_cfg.grad_scale,
                    enhance_particle=guidance_cfg.particle_aug,
                    im_size=model2res(self.x_cfg.model_id)
                )

                # Xing Loss for Self-Interaction Problem
                L_add = torch.tensor(0.)
                if self.style == "iconography" or self.x_cfg.xing_loss.use:
                    for r in renderers:
                        L_add += xing_loss_fn(r.get_point_parameters()) * self.x_cfg.xing_loss.weight

                loss = L_guide + L_add

                # optimization: one backward pass, all particle optimizers step
                for opt_ in optimizers:
                    opt_.zero_grad_()
                loss.backward()
                for opt_ in optimizers:
                    opt_.step_()

                # phi_model optimization
                # NOTE(review): if phi_update_step is 0, L_lora is never bound and the
                # pbar logging below would raise NameError — confirm config guarantees >= 1
                for _ in range(guidance_cfg.phi_update_step):
                    L_lora = self.pipeline.train_phi_model(latents, guidance_cfg.phi_t, as_latent=True)
                    phi_optimizer.zero_grad()
                    L_lora.backward()
                    phi_optimizer.step()

                # reward learning: periodically sample, rank by ImageReward,
                # then fine-tune phi on the ranked samples
                if guidance_cfg.phi_ReFL and self.step % guidance_cfg.phi_sample_step == 0:
                    with torch.no_grad():
                        phi_outputs = []
                        phi_sample_paths = []
                        for idx in range(guidance_cfg.n_phi_sample):
                            phi_output = self.pipeline.sample(text_prompt,
                                                              num_inference_steps=guidance_cfg.phi_infer_step,
                                                              generator=self.g_device)
                            sample_path = (self.phi_samples_dir / f'iter{idx}.png').as_posix()
                            phi_output.images[0].save(sample_path)
                            phi_sample_paths.append(sample_path)
                            phi_output_np = np.array(phi_output.images[0])
                            phi_outputs.append(phi_output_np)
                        # save all samples
                        view_images(phi_outputs, save_image=True,
                                    num_rows=max(len(phi_outputs) // 6, 1),
                                    fp=self.phi_samples_dir / f'samples_iter{self.step}.png')
                        ranking, rewards = self.reward_model.inference_rank(text_prompt, phi_sample_paths)
                        self.print(f"ranking: {ranking}, reward score: {rewards}")
                    # ReFL updates need gradients, so this loop runs outside no_grad;
                    # ranking entries are 1-based, hence the -1 index
                    for k in range(guidance_cfg.n_phi_sample):
                        phi = self.target_file_preprocess(phi_sample_paths[ranking[k] - 1])
                        L_reward = self.pipeline.train_phi_model_refl(phi, weight=rewards[k])
                        phi_optimizer.zero_grad()
                        L_reward.backward()
                        phi_optimizer.step()

                # update the learning rate of the phi_model
                if phi_scheduler is not None:
                    phi_scheduler.step()

                # curve regularization
                for r in renderers:
                    r.clip_curve_shape()

                # re-init paths: periodically replace weak paths, but never at step 0
                # and never after stop_step
                if self.step % path_reinit.freq == 0 and self.step < path_reinit.stop_step and self.step != 0:
                    for i, r in enumerate(renderers):
                        r.reinitialize_paths(path_reinit.use,  # on-off
                                             path_reinit.opacity_threshold,
                                             path_reinit.area_threshold,
                                             fpath=self.reinit_dir / f"reinit-{self.step}_p{i}.svg")

                # update lr
                if self.x_cfg.lr_stage_two.lr_schedule:
                    for opt_ in optimizers:
                        opt_.update_lr()

                # log pretrained model lr
                lr_str = ""
                for k, lr in optimizers[0].get_lr().items():
                    lr_str += f"{k}_lr: {lr:.4f}, "
                # log phi model lr
                cur_phi_lr = phi_optimizer.param_groups[0]['lr']
                lr_str += f"phi_lr: {cur_phi_lr:.3e}, "

                pbar.set_description(
                    lr_str +
                    f"t: {t_step.item():.2f}, "
                    f"L_total: {loss.item():.4f}, "
                    f"L_add: {L_add.item():.4e}, "
                    f"L_lora: {L_lora.item():.4f}, "
                    f"L_reward: {L_reward.item():.4f}, "
                    f"vpsd: {grad.item():.4e}"
                )

                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    # save png
                    torchvision.utils.save_image(raster_imgs,
                                                 fp=self.ft_png_logs_dir / f'iter{self.step}.png')
                    # save svg
                    for i, r in enumerate(renderers):
                        r.pretty_save_svg(self.ft_svg_logs_dir / f"svg_iter{self.step}_p{i}.svg")

                self.step += 1
                pbar.update(1)

        # save final
        for i, r in enumerate(renderers):
            final_svg_path = self.result_path / f"finetune_final_p_{i}.svg"
            r.pretty_save_svg(final_svg_path)
        # save SVGs
        torchvision.utils.save_image(raster_imgs, fp=self.result_path / f'all_particles.png')

        if self.make_video:
            from subprocess import call
            call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "svgdreamer_rendering.mp4").as_posix()
            ])

        self.close(msg="painterly rendering complete.")
| def load_renderer(self, path_svg=None): | |
| renderer = Painter(self.args.diffvg, | |
| self.style, | |
| self.x_cfg.num_segments, | |
| self.x_cfg.segment_init, | |
| self.x_cfg.radius, | |
| self.x_cfg.image_size, | |
| self.x_cfg.grid, | |
| self.x_cfg.trainable_bg, | |
| self.x_cfg.width, | |
| path_svg=path_svg, | |
| device=self.device) | |
| # if load a svg file, then rasterize it | |
| save_path = self.result_path / 'target_img.png' | |
| if path_svg is not None and (not save_path.exists()): | |
| canvas_width, canvas_height, shapes, shape_groups = renderer.load_svg(path_svg) | |
| render_img = renderer.render_image(canvas_width, canvas_height, shapes, shape_groups) | |
| torchvision.utils.save_image(render_img, fp=save_path) | |
| return renderer | |