| """ | |
| Scream: python painterly_rendering.py imgs/scream.jpg --num_paths 2048 --max_width 4.0 | |
| Fallingwater: python painterly_rendering.py imgs/fallingwater.jpg --num_paths 2048 --max_width 4.0 | |
| Fallingwater: python painterly_rendering.py imgs/fallingwater.jpg --num_paths 2048 --max_width 4.0 --use_lpips_loss | |
| Baboon: python painterly_rendering.py imgs/baboon.png --num_paths 1024 --max_width 4.0 --num_iter 250 | |
| Baboon Lpips: python painterly_rendering.py imgs/baboon.png --num_paths 1024 --max_width 4.0 --num_iter 500 --use_lpips_loss | |
| smile: python painterly_rendering.py ../LIVE/figures/smile.png --num_paths 5 --use_blob --num_iter 500 | |
| """ | |
| import pydiffvg | |
| import torch | |
| import skimage | |
| import skimage.io | |
| import random | |
| import ttools.modules | |
| import argparse | |
| import math | |
# Print per-stage render timings from the diffvg rasterizer.
pydiffvg.set_print_timing(True)
# Gamma applied both when decoding the target image and when writing renders.
gamma = 1.0
def main(args):
    """Fit a painterly vector rendering to a target image by gradient descent.

    Initializes ``args.num_paths`` random cubic-Bezier paths (open strokes, or
    closed filled blobs when ``args.use_blob``), then optimizes their points,
    stroke widths, and colors through the differentiable rasterizer to match
    ``args.target``. Writes intermediate PNGs/SVGs, a final PNG, and an MP4 to
    results/painterly_rendering/.

    Args:
        args: parsed argparse namespace with attributes ``target``,
            ``num_paths``, ``max_width``, ``num_iter``, ``use_lpips_loss``,
            ``use_blob``.
    """
    # Use GPU if available.
    pydiffvg.set_use_gpu(torch.cuda.is_available())
    perception_loss = ttools.modules.LPIPS().to(pydiffvg.get_device())

    # Load the target as an NCHW float tensor in [0, 1] (gamma-adjusted).
    target = torch.from_numpy(skimage.io.imread(args.target)).to(torch.float32) / 255.0
    target = target.pow(gamma)
    target = target.to(pydiffvg.get_device())
    target = target.unsqueeze(0)
    target = target.permute(0, 3, 1, 2)  # NHWC -> NCHW
    canvas_width, canvas_height = target.shape[3], target.shape[2]

    num_paths = args.num_paths
    max_width = args.max_width
    # Fixed seeds so the random initialization is reproducible.
    random.seed(1234)
    torch.manual_seed(1234)

    # Random initial shapes plus one color group per shape. Blobs get a fill
    # color; open strokes get a stroke color and no fill.
    shapes = []
    shape_groups = []
    for i in range(num_paths):
        path = _random_path(canvas_width, canvas_height, is_closed=args.use_blob)
        shapes.append(path)
        color = torch.tensor([random.random(),
                              random.random(),
                              random.random(),
                              random.random()])
        if args.use_blob:
            path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]),
                                             fill_color=color)
        else:
            path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]),
                                             fill_color=None,
                                             stroke_color=color)
        shape_groups.append(path_group)

    scene_args = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    render = pydiffvg.RenderFunction.apply
    img = render(canvas_width,   # width
                 canvas_height,  # height
                 2,              # num_samples_x
                 2,              # num_samples_y
                 0,              # seed
                 None,
                 *scene_args)
    pydiffvg.imwrite(img.cpu(), 'results/painterly_rendering/init.png', gamma=gamma)

    # Collect the optimizable leaf tensors.
    points_vars = []
    for path in shapes:
        path.points.requires_grad = True
        points_vars.append(path.points)
    stroke_width_vars = []
    if not args.use_blob:
        for path in shapes:
            path.stroke_width.requires_grad = True
            stroke_width_vars.append(path.stroke_width)
    color_vars = []
    if args.use_blob:
        for group in shape_groups:
            group.fill_color.requires_grad = True
            color_vars.append(group.fill_color)
    else:
        for group in shape_groups:
            group.stroke_color.requires_grad = True
            color_vars.append(group.stroke_color)

    # Separate optimizers: point coordinates (in pixels) need a much larger
    # step size than widths or colors.
    points_optim = torch.optim.Adam(points_vars, lr=1.0)
    if len(stroke_width_vars) > 0:
        width_optim = torch.optim.Adam(stroke_width_vars, lr=0.1)
    color_optim = torch.optim.Adam(color_vars, lr=0.01)

    # Adam iterations.
    for t in range(args.num_iter):
        print('iteration:', t)
        points_optim.zero_grad()
        if len(stroke_width_vars) > 0:
            width_optim.zero_grad()
        color_optim.zero_grad()
        # Forward pass: re-serialize (shape parameters changed) and render.
        scene_args = pydiffvg.RenderFunction.serialize_scene(
            canvas_width, canvas_height, shapes, shape_groups)
        img = render(canvas_width,   # width
                     canvas_height,  # height
                     2,              # num_samples_x
                     2,              # num_samples_y
                     t,              # seed
                     None,
                     *scene_args)
        # Alpha-compose the RGBA render over a white background.
        img = img[:, :, 3:4] * img[:, :, :3] + \
            torch.ones(img.shape[0], img.shape[1], 3,
                       device=pydiffvg.get_device()) * (1 - img[:, :, 3:4])
        # Save the intermediate render.
        pydiffvg.imwrite(img.cpu(), 'results/painterly_rendering/iter_{}.png'.format(t), gamma=gamma)
        img = img[:, :, :3]
        # Convert img from HWC to NCHW.
        img = img.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)  # NHWC -> NCHW
        if args.use_lpips_loss:
            # Perceptual loss plus a mean-match term to anchor global brightness.
            loss = perception_loss(img, target) + (img.mean() - target.mean()).pow(2)
        else:
            loss = (img - target).pow(2).mean()
        print('render loss:', loss.item())
        # Backpropagate the gradients.
        loss.backward()

        # Gradient step, then project parameters back into their valid ranges.
        points_optim.step()
        if len(stroke_width_vars) > 0:
            width_optim.step()
        color_optim.step()
        if len(stroke_width_vars) > 0:
            for path in shapes:
                path.stroke_width.data.clamp_(1.0, max_width)
        if args.use_blob:
            for group in shape_groups:
                group.fill_color.data.clamp_(0.0, 1.0)
        else:
            for group in shape_groups:
                group.stroke_color.data.clamp_(0.0, 1.0)

        if t % 10 == 0 or t == args.num_iter - 1:
            pydiffvg.save_svg('results/painterly_rendering/iter_{}.svg'.format(t),
                              canvas_width, canvas_height, shapes, shape_groups)

    # Render the final result.
    # BUG FIX: the original passed target.shape[1]/target.shape[0] as
    # width/height, but target is NCHW at this point (shape[0]=batch=1,
    # shape[1]=channels=3), so the final image rendered at 3x1 pixels.
    img = render(canvas_width,   # width
                 canvas_height,  # height
                 2,              # num_samples_x
                 2,              # num_samples_y
                 0,              # seed
                 None,
                 *scene_args)
    pydiffvg.imwrite(img.cpu(), 'results/painterly_rendering/final.png', gamma=gamma)

    # Convert the intermediate renderings to a video.
    from subprocess import call
    call(["ffmpeg", "-framerate", "24", "-i",
          "results/painterly_rendering/iter_%d.png", "-vb", "20M",
          "results/painterly_rendering/out.mp4"])


def _random_path(canvas_width, canvas_height, is_closed):
    """Build one random cubic-Bezier pydiffvg.Path via a small random walk.

    Closed paths (blobs) use 3-5 segments and omit the last segment's end
    point (the path closes back to its start); open strokes use 1-3 segments.
    Consumes exactly the same sequence of random draws as the original
    per-branch loops, so seeded runs initialize identically.
    """
    num_segments = random.randint(3, 5) if is_closed else random.randint(1, 3)
    # Two control points per cubic segment.
    num_control_points = torch.zeros(num_segments, dtype=torch.int32) + 2
    points = []
    p0 = (random.random(), random.random())
    points.append(p0)
    for j in range(num_segments):
        radius = 0.05  # step scale of the walk, in normalized coordinates
        p1 = (p0[0] + radius * (random.random() - 0.5),
              p0[1] + radius * (random.random() - 0.5))
        p2 = (p1[0] + radius * (random.random() - 0.5),
              p1[1] + radius * (random.random() - 0.5))
        p3 = (p2[0] + radius * (random.random() - 0.5),
              p2[1] + radius * (random.random() - 0.5))
        points.append(p1)
        points.append(p2)
        # Closed paths drop the final end point; the shape closes implicitly.
        if (not is_closed) or j < num_segments - 1:
            points.append(p3)
        p0 = p3
    points = torch.tensor(points)
    # Scale normalized coordinates up to canvas pixels.
    points[:, 0] *= canvas_width
    points[:, 1] *= canvas_height
    return pydiffvg.Path(num_control_points=num_control_points,
                         points=points,
                         stroke_width=torch.tensor(1.0),
                         is_closed=is_closed)
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("target", help="target image path") | |
| parser.add_argument("--num_paths", type=int, default=512) | |
| parser.add_argument("--max_width", type=float, default=2.0) | |
| parser.add_argument("--use_lpips_loss", dest='use_lpips_loss', action='store_true') | |
| parser.add_argument("--num_iter", type=int, default=500) | |
| parser.add_argument("--use_blob", dest='use_blob', action='store_true') | |
| args = parser.parse_args() | |
| main(args) | |