Spaces:
Sleeping
Sleeping
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| from torchvision import transforms | |
| from torchvision.utils import make_grid | |
| from skimage.transform import resize | |
| from .u2net import U2NET | |
def plot_attn_dino(attn, threshold_map, inputs, inds, output_path):
    """Save a two-row overview figure of attention maps and threshold maps.

    Top row: the input image (with the points from ``inds`` drawn in red),
    the sum of ``attn`` over its first axis, then every ``attn[i]``.
    Bottom row: ``threshold_map[-1]``, the sum of ``threshold_map[:-1]`` over
    its first axis, then every ``threshold_map[i]``.

    Args:
        attn: maps indexed along the first axis; ``.numpy()`` is called on
            them, so presumably a CPU torch tensor (confirm with caller).
        threshold_map: indexable like ``attn``; its last entry is plotted
            separately from the others.
        inputs: image tensor handed to ``make_grid``.
        inds: 2-D array of points; column 0 is used as y, column 1 as x.
        output_path: destination passed to ``plt.savefig``.
    """
    # currently supports one image (and not a batch)
    n_maps = attn.shape[0]
    n_cols = n_maps + 2  # two summary panels precede the per-map panels

    def draw_panel(slot, image, title=None, **imshow_kwargs):
        # One grid cell: image, optional title, hidden axes.
        plt.subplot(2, n_cols, slot)
        plt.imshow(image, **imshow_kwargs)
        if title is not None:
            plt.title(title)
        plt.axis("off")

    plt.figure(figsize=(10, 5))

    # Input image with the selected points on top.
    plt.subplot(2, n_cols, 1)
    grid = make_grid(inputs, normalize=True, pad_value=2)
    grid_hwc = np.transpose(grid.cpu().numpy(), (1, 2, 0))
    plt.imshow(grid_hwc, interpolation='nearest')
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.title("input im")
    plt.axis("off")

    # Summary panels: slot 2 on the top row, the first two slots on the bottom row.
    draw_panel(2, attn.sum(0).numpy(), "atn map sum", interpolation='nearest')
    draw_panel(n_cols + 1, threshold_map[-1].numpy(), "prob sum",
               interpolation='nearest')
    draw_panel(n_cols + 2, threshold_map[:-1].sum(0).numpy(), "thresh sum",
               interpolation='nearest')

    # Per-map panels: attn[i] on top, threshold_map[i] directly below it.
    for idx in range(n_maps):
        draw_panel(idx + 3, attn[idx].numpy())
        draw_panel(n_cols + 3 + idx, threshold_map[idx].numpy())

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()
def plot_attn_clip(attn, threshold_map, inputs, inds, output_path):
    """Save a three-panel figure: input image, attention map, threshold map.

    The threshold map is min-max rescaled before display. Both maps are drawn
    with a fixed colour range of 0 to 1. The points in ``inds`` are drawn in
    red on the input image and on the rescaled threshold map.

    Args:
        attn: 2-D map passed straight to ``plt.imshow``.
        threshold_map: array-like supporting ``.min()``, ``.max()`` and
            arithmetic.
        inputs: image tensor handed to ``make_grid``.
        inds: 2-D array of points; column 0 is used as y, column 1 as x.
        output_path: destination passed to ``plt.savefig``.
    """
    # currently supports one image (and not a batch)
    plt.figure(figsize=(10, 5))
    marker_style = dict(s=10, c='red', marker='o')

    # Panel 1: input image with the selected points.
    plt.subplot(1, 3, 1)
    grid = make_grid(inputs, normalize=True, pad_value=2)
    grid_hwc = np.transpose(grid.cpu().numpy(), (1, 2, 0))
    plt.imshow(grid_hwc, interpolation='nearest')
    plt.scatter(inds[:, 1], inds[:, 0], **marker_style)
    plt.title("input im")
    plt.axis("off")

    # Panel 2: raw attention map.
    plt.subplot(1, 3, 2)
    plt.imshow(attn, interpolation='nearest', vmin=0, vmax=1)
    plt.title("attn map")
    plt.axis("off")

    # Panel 3: threshold map stretched to [0, 1], with the points overlaid.
    plt.subplot(1, 3, 3)
    lowest = threshold_map.min()
    value_range = threshold_map.max() - lowest
    rescaled = (threshold_map - lowest) / value_range
    plt.imshow(rescaled, interpolation='nearest', vmin=0, vmax=1)
    plt.title("prob softmax")
    plt.scatter(inds[:, 1], inds[:, 0], **marker_style)
    plt.axis("off")

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()
def plot_attn(attn, threshold_map, inputs, inds, output_path, saliency_model):
    """Dispatch to the plotting routine matching ``saliency_model``.

    ``"dino"`` uses ``plot_attn_dino`` and ``"clip"`` uses ``plot_attn_clip``.
    Any other value is silently ignored and nothing is plotted.
    """
    if saliency_model == "clip":
        plot_attn_clip(attn, threshold_map, inputs, inds, output_path)
        return
    if saliency_model == "dino":
        plot_attn_dino(attn, threshold_map, inputs, inds, output_path)
def fix_image_scale(im):
    """Paste ``im`` onto the centre of a white square canvas.

    The canvas side is the longer image side plus 20 pixels. Pixel values are
    divided by 255 before pasting and rescaled to ``uint8`` afterwards.

    Args:
        im: image convertible with ``np.array``; the canvas has 3 channels,
            so presumably an RGB image with 8-bit values (confirm with caller).

    Returns:
        A new PIL image built from the padded canvas.
    """
    pixels = np.array(im) / 255
    rows = pixels.shape[0]
    cols = pixels.shape[1]

    side = max(rows, cols) + 20
    canvas = np.ones((side, side, 3))  # ones == white after rescaling

    # Top-left corner that centres the image on the canvas.
    half = side // 2
    top = half - rows // 2
    left = half - cols // 2
    canvas[top: top + rows, left: left + cols] = pixels

    canvas_u8 = (canvas / canvas.max() * 255).astype(np.uint8)
    return Image.fromarray(canvas_u8)
def _scale_to_unit_max(values):
    """Return ``values`` divided by its maximum; all-zero input stays zero (as floats)."""
    peak = values.max()
    if peak > 0:
        return values / peak
    # Avoid 0/0 -> NaN for an all-black array.
    return values * 1.0


def get_mask_u2net(pil_im, output_dir, u2net_path, device="cpu"):
    """Mask the salient object of ``pil_im`` with U^2-Net and whiten the rest.

    The binary mask is written to ``<output_dir>/mask.png``. In the returned
    image every pixel outside the mask is set to white.

    Args:
        pil_im: input PIL image; a 3-channel normalisation is applied, so
            presumably RGB (confirm with caller).
        output_dir: directory for ``mask.png``; a ``str`` or a path object.
        u2net_path: path to the U^2-Net checkpoint (a state dict).
        device: torch device used for inference.

    Returns:
        A tuple ``(im_final, predict)``. ``im_final`` is the masked PIL
        image. ``predict`` is the binary (0/1) network-resolution mask
        tensor with a leading batch axis of size 1, left on ``device``.
    """
    from pathlib import Path  # local import: lets output_dir be a plain str

    # input preprocess
    w, h = pil_im.size[0], pil_im.size[1]
    im_size = min(w, h)
    data_transforms = transforms.Compose([
        transforms.Resize(min(320, im_size), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                             std=(0.26862954, 0.26130258, 0.27577711)),
    ])
    input_im_trans = data_transforms(pil_im).unsqueeze(0).to(device)

    # load U^2 Net model
    net = U2NET(in_ch=3, out_ch=1)
    # map_location lets a checkpoint saved from CUDA tensors load on a
    # CPU-only machine (the default device here is "cpu").
    net.load_state_dict(torch.load(u2net_path, map_location=device))
    net.to(device)
    net.eval()

    # get mask: only the first network output is used
    with torch.no_grad():
        d1, *_ = net(input_im_trans)
    pred = d1[:, 0, :, :]

    # Min-max normalise; a constant prediction would otherwise divide by zero
    # and give a NaN mask, so treat it as "nothing salient".
    pred_min = pred.min()
    pred_range = pred.max() - pred_min
    if pred_range > 0:
        pred = (pred - pred_min) / pred_range
    else:
        pred = torch.zeros_like(pred)
    predict = (pred >= 0.5).to(pred.dtype)

    # Replicate to 3 channels (H, W, 3), resize to the original image size,
    # then re-binarise because interpolation introduces in-between values.
    mask = torch.cat([predict, predict, predict], dim=0).permute(1, 2, 0)
    mask = mask.cpu().numpy()
    mask = resize(mask, (h, w), anti_aliasing=False)
    mask = (mask >= 0.5).astype(mask.dtype)

    im = Image.fromarray((mask[:, :, 0] * 255).astype(np.uint8)).convert('RGB')
    save_path_ = Path(output_dir) / "mask.png"
    im.save(save_path_)

    # Apply the mask and paint everything outside it white.
    im_np = _scale_to_unit_max(np.array(pil_im))
    im_np = mask * im_np
    im_np[mask == 0] = 1
    im_final = (_scale_to_unit_max(im_np) * 255).astype(np.uint8)
    im_final = Image.fromarray(im_final)

    # free u2net
    del net
    torch.cuda.empty_cache()

    return im_final, predict