Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # Copyright (c) XiMing Xing. All rights reserved. | |
| # Author: XiMing Xing | |
| # Description: | |
| import pathlib | |
| from typing import Union, List, Text, BinaryIO, AnyStr | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import torchvision.transforms as transforms | |
| from torchvision.utils import make_grid | |
# Public API of this module: tensor<->image transform pipelines and
# grid-plotting/saving helpers built on matplotlib.
__all__ = [
    'sample2pil_transforms',
    'pt2numpy_transforms',
    'plt_pt_img',
    'save_grid_images_and_labels',
    'save_grid_images_and_captions',
]
# Pipeline: generated sample tensor (CHW, values in [-1, 1]) -> PIL image.
sample2pil_transforms = transforms.Compose([
    # Undo the [-1, 1] normalization, clamping into [0, 1].
    transforms.Lambda(lambda img: torch.clamp(img.add(1).div(2), min=0.0, max=1.0)),
    # Scale to [0, 255]; the +0.5 makes the later uint8 cast round to nearest.
    transforms.Lambda(lambda img: torch.clamp(img.mul(255.).add(0.5), min=0, max=255)),
    # CHW -> HWC.
    transforms.Lambda(lambda img: img.permute(1, 2, 0)),
    # Move to CPU and cast to a uint8 numpy ndarray.
    transforms.Lambda(lambda img: img.to('cpu', torch.uint8).numpy()),
    # H x W x C uint8 ndarray -> PIL Image.
    transforms.ToPILImage(),
])
# Pipeline: tensor in [0, 1] (CHW) -> uint8 HWC numpy ndarray.
pt2numpy_transforms = transforms.Compose([
    # Scale to [0, 255]; the +0.5 makes the later uint8 cast round to nearest.
    transforms.Lambda(lambda img: torch.clamp(img.mul(255.).add(0.5), min=0, max=255)),
    # CHW -> HWC.
    transforms.Lambda(lambda img: img.permute(1, 2, 0)),
    # Move to CPU and cast to a uint8 numpy ndarray.
    transforms.Lambda(lambda img: img.to('cpu', torch.uint8).numpy()),
])
def plt_pt_img(
        pt_img: torch.Tensor,
        save_path: AnyStr = None,
        title: AnyStr = None,
        dpi: int = 300
):
    """Display a tensor image (or batch) with matplotlib, optionally saving it.

    Args:
        pt_img: a single CHW image tensor or a batch of them
            (``make_grid`` accepts both and tiles a batch into one grid).
        save_path: if given, the rendered figure is written to this path.
        title: optional figure title.
        dpi: resolution used when saving. Defaults to 300.
    """
    grid = make_grid(pt_img, normalize=True, pad_value=2)
    ndarr = pt2numpy_transforms(grid)
    plt.imshow(ndarr)
    plt.axis("off")
    plt.tight_layout()
    if title is not None:
        plt.title(f"{title}")
    # BUGFIX: save BEFORE plt.show(). With interactive backends, show() hands
    # the figure to the GUI and may leave a cleared canvas, so calling
    # savefig() afterwards wrote a blank image.
    if save_path is not None:
        plt.savefig(save_path, dpi=dpi)
    plt.show()
    plt.close()
def save_grid_images_and_labels(
        images: Union[torch.Tensor, List[torch.Tensor]],
        probs: Union[torch.Tensor, List[torch.Tensor]],
        labels: Union[torch.Tensor, List[torch.Tensor]],
        classes: Union[torch.Tensor, List[torch.Tensor]],
        fp: Union[Text, pathlib.Path, BinaryIO],
        nrow: int = 4,
        normalize: bool = True
) -> None:
    """Save a grid of images, each titled with its true and predicted label.

    Args:
        images: images to display.
        probs: per-class probability vector for each image.
        labels: ground-truth label index for each image.
        classes: class names, indexed by label.
        fp: destination path (or binary stream) for the figure.
        nrow: maximum number of images per row. Defaults to 4.
        normalize: when True, map each sample from [-1, 1] to a PIL image
            via ``sample2pil_transforms`` before plotting. Defaults to True.
    """
    total = len(images)
    rows, cols = _get_subplot_shape(total, nrow)
    fig = plt.figure(figsize=(25, 20))
    for idx in range(total):
        axis = fig.add_subplot(rows, cols, idx + 1)
        img, gt_label, prob = images[idx], labels[idx], probs[idx]
        gt_prob = prob[gt_label]
        # Highest-probability class is the model's prediction (it may or may
        # not equal the ground truth).
        pred_prob, pred_label = torch.max(prob, dim=0)
        gt_class = classes[gt_label]
        pred_class = classes[pred_label]
        if normalize:
            img = sample2pil_transforms(img)
        axis.imshow(img)
        caption = f'true label: {gt_class} ({gt_prob:.3f})\n ' \
                  f'pred label: {pred_class} ({pred_prob:.3f})'
        axis.set_title(caption, fontsize=20)
        axis.axis('off')
    fig.subplots_adjust(hspace=0.3)
    plt.savefig(fp)
    plt.close()
def save_grid_images_and_captions(
        images: Union[torch.Tensor, List[torch.Tensor]],
        captions: List,
        fp: Union[Text, pathlib.Path, BinaryIO],
        nrow: int = 4,
        normalize: bool = True
) -> None:
    """Save a grid of images and their captions into an image file.

    Args:
        images (Union[torch.Tensor, List[torch.Tensor]]): A list of images to display.
        captions (List): A caption per image, or a single caption string.
        fp (Union[Text, pathlib.Path, BinaryIO]): The file path to save the image to.
        nrow (int, optional): The number of images to display in each row. Defaults to 4.
        normalize (bool, optional): Whether to normalize the image or not. Defaults to True.
    """
    num_images = len(images)
    num_rows, num_cols = _get_subplot_shape(num_images, nrow)
    fig = plt.figure(figsize=(25, 20))
    for i in range(num_images):
        ax = fig.add_subplot(num_rows, num_cols, i + 1)
        image = images[i]
        # BUGFIX: select the caption by type, not by image count. The old
        # `if num_images > 1 ... else f'"{captions}"'` rendered the repr of
        # the whole list (e.g. "['text']") for a single image whose caption
        # was passed inside a list; a bare-string caption still works.
        caption = captions[i] if isinstance(captions, (list, tuple)) else captions
        if normalize:
            image = sample2pil_transforms(image)
        ax.imshow(image)
        title = _insert_newline(f'"{caption}"')
        ax.set_title(title, fontsize=20)
        ax.axis('off')
    fig.subplots_adjust(hspace=0.3)
    plt.savefig(fp)
    plt.close()
| def _get_subplot_shape(num_images, nrow): | |
| """ | |
| Calculate the number of rows and columns required to display images in a grid. | |
| Args: | |
| num_images (int): The total number of images to display. | |
| nrow (int): The maximum number of images to display in each row. | |
| Returns: | |
| Tuple[int, int]: The number of rows and columns required to display images in a grid. | |
| """ | |
| num_cols = min(num_images, nrow) | |
| num_rows = (num_images + num_cols - 1) // num_cols | |
| return num_rows, num_cols | |
| def _insert_newline(string, point=9): | |
| # split by blank | |
| words = string.split() | |
| if len(words) <= point: | |
| return string | |
| word_chunks = [words[i:i + point] for i in range(0, len(words), point)] | |
| new_string = "\n".join(" ".join(chunk) for chunk in word_chunks) | |
| return new_string | |