Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # Author: ximing | |
| # Copyright (c) 2023, XiMing Xing. | |
| # License: MPL-2.0 License | |
| from typing import AnyStr | |
| import matplotlib.pyplot as plt | |
| import torch | |
| from torchvision.utils import make_grid | |
def plot_couple(input_1: torch.Tensor,
                input_2: torch.Tensor,
                step: int,
                output_dir: str,
                fname: str,  # file name
                prompt: str = '',  # text prompt used as the figure title
                dpi: int = 300):
    """Save a side-by-side comparison of two image batches as a PNG.

    The left panel shows ``input_1`` titled "Input"; the right panel shows
    ``input_2`` titled "Rendering - {step} steps". Each batch is tiled with
    ``torchvision.utils.make_grid`` (min-max normalized) before display.
    The figure is written to ``{output_dir}/{fname}.png``.

    Args:
        input_1: Image tensor for the left panel, in a layout accepted by
            ``make_grid`` (presumably (B, C, H, W) -- confirm with callers).
        input_2: Image tensor for the right panel; must have the same shape
            as ``input_1``.
        step: Step count shown in the right panel's title.
        output_dir: Directory the PNG is written into. It is not created
            here, so it must already exist.
        fname: Output file name without the ``.png`` extension.
        prompt: Text used as the figure's super-title; wrapped onto a new
            line every 9 words.
        dpi: Resolution passed to ``plt.savefig``.

    Raises:
        ValueError: If ``input_1`` and ``input_2`` differ in shape.
    """
    if input_1.shape != input_2.shape:
        raise ValueError("inputs and outputs must have the same dimensions")

    def insert_newline(string, point=9):
        """Break ``string`` into lines of at most ``point`` words."""
        # split by blank
        words = string.split()
        if len(words) <= point:
            # Short enough: hand back the original text untouched.
            return string

        word_chunks = [words[i:i + point] for i in range(0, len(words), point)]
        return "\n".join(" ".join(chunk) for chunk in word_chunks)

    fig = plt.figure()
    try:
        panels = (
            (input_1, "Input"),
            (input_2, f"Rendering - {step} steps"),
        )
        for index, (batch, title) in enumerate(panels, start=1):
            plt.subplot(1, 2, index)  # nrows=1, ncols=2
            grid = make_grid(batch, normalize=True, pad_value=2)
            # Scale to 0..255 with round-to-nearest (+0.5 before the uint8
            # cast) and move channels last for imshow. pad_value=2 is above
            # the normalized range, so it saturates to 255 after the clamp.
            ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
            plt.imshow(ndarr)
            plt.axis("off")
            plt.title(title)

        plt.suptitle(insert_newline(prompt), fontsize=10)

        plt.tight_layout()
        plt.savefig(f"{output_dir}/{fname}.png", dpi=dpi)
    finally:
        # Always release this figure, even when rendering or saving fails;
        # otherwise it stays registered (and current) in pyplot's global
        # state and leaks across repeated calls.
        plt.close(fig)
def plot_img(inputs: torch.Tensor,
             output_dir: AnyStr,
             fname: str,  # file name
             dpi: int = 100):
    """Save an image tensor as a tightly cropped PNG grid.

    ``inputs`` is tiled with ``torchvision.utils.make_grid`` (min-max
    normalized), converted to a uint8 channels-last array, drawn without
    axes and written to ``{output_dir}/{fname}.png``.

    Args:
        inputs: Image tensor in a layout accepted by ``make_grid``
            (presumably (B, C, H, W) -- confirm with callers).
        output_dir: Directory the PNG is written into. It is interpolated
            into the path with an f-string and is not created here.
        fname: Output file name without the ``.png`` extension.
        dpi: Resolution passed to ``plt.savefig``.

    Raises:
        AssertionError: If ``inputs`` is not a tensor (the check is an
            ``assert``, so it is skipped when Python runs with ``-O``).
    """
    assert torch.is_tensor(inputs), f"The input must be tensor type, but got {type(inputs)}"

    grid = make_grid(inputs, normalize=True, pad_value=2)
    # Scale to 0..255 with round-to-nearest (+0.5 before the uint8 cast) and
    # move channels last for imshow.
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()

    try:
        plt.imshow(ndarr)
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{fname}.png", dpi=dpi, bbox_inches='tight')
    finally:
        # Close the figure even when drawing or saving fails, so a stale
        # half-drawn figure is not left current for the next plotting call.
        plt.close()
def plot_img_title(inputs: torch.Tensor,
                   title: str,
                   output_dir: AnyStr,
                   fname: str,  # file name
                   dpi: int = 500):
    """Save an image tensor as a titled PNG grid.

    ``inputs`` is tiled with ``torchvision.utils.make_grid`` (min-max
    normalized), converted to a uint8 channels-last array, drawn without
    axes under ``title`` and written to ``{output_dir}/{fname}.png``.

    Args:
        inputs: Image tensor in a layout accepted by ``make_grid``
            (presumably (B, C, H, W) -- confirm with callers).
        title: Text shown above the image.
        output_dir: Directory the PNG is written into. It is interpolated
            into the path with an f-string and is not created here.
        fname: Output file name without the ``.png`` extension.
        dpi: Resolution passed to ``plt.savefig``.

    Raises:
        AssertionError: If ``inputs`` is not a tensor (the check is an
            ``assert``, so it is skipped when Python runs with ``-O``).
    """
    assert torch.is_tensor(inputs), f"The input must be tensor type, but got {type(inputs)}"

    grid = make_grid(inputs, normalize=True, pad_value=2)
    # Scale to 0..255 with round-to-nearest (+0.5 before the uint8 cast) and
    # move channels last for imshow.
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()

    try:
        plt.imshow(ndarr)
        plt.axis("off")
        plt.title(f"{title}")
        plt.savefig(f"{output_dir}/{fname}.png", dpi=dpi)
    finally:
        # Close the figure even when drawing or saving fails, so a stale
        # figure (and its title) is not left current for the next call.
        plt.close()