Spaces:
Running
Running
Update
Browse filesThis view is limited to 50 files because it contains too many changes. Β
See raw diff
- .gitattributes +4 -0
- assets/demo.gif +3 -0
- assets/metrics.png +0 -0
- assets/network.png +0 -0
- assets/title_any_image.gif +0 -0
- assets/title_harmon.gif +0 -0
- assets/title_you_want.gif +0 -0
- assets/visualizations.png +0 -0
- assets/visualizations2.png +3 -0
- datasets/__init__.py +0 -0
- datasets/__pycache__/__init__.cpython-38.pyc +0 -0
- datasets/__pycache__/build_INR_dataset.cpython-38.pyc +0 -0
- datasets/__pycache__/build_dataset.cpython-38.pyc +0 -0
- datasets/build_INR_dataset.py +36 -0
- datasets/build_dataset.py +371 -0
- demo/demo_2k_composite.jpg +0 -0
- demo/demo_2k_mask.jpg +0 -0
- demo/demo_2k_real.jpg +0 -0
- demo/demo_6k_composite.jpg +3 -0
- demo/demo_6k_mask.jpg +0 -0
- demo/demo_6k_real.jpg +3 -0
- model/__init__.py +0 -0
- model/__pycache__/__init__.cpython-38.pyc +0 -0
- model/__pycache__/backbone.cpython-38.pyc +0 -0
- model/__pycache__/build_model.cpython-38.pyc +0 -0
- model/__pycache__/lut_transformation_net.cpython-38.pyc +0 -0
- model/backbone.py +79 -0
- model/base/__init__.py +0 -0
- model/base/__pycache__/__init__.cpython-38.pyc +0 -0
- model/base/__pycache__/basic_blocks.cpython-38.pyc +0 -0
- model/base/__pycache__/conv_autoencoder.cpython-38.pyc +0 -0
- model/base/__pycache__/ih_model.cpython-38.pyc +0 -0
- model/base/__pycache__/ops.cpython-38.pyc +0 -0
- model/base/basic_blocks.py +366 -0
- model/base/conv_autoencoder.py +519 -0
- model/base/ih_model.py +88 -0
- model/base/ops.py +397 -0
- model/build_model.py +24 -0
- model/hrnetv2/__init__.py +0 -0
- model/hrnetv2/__pycache__/__init__.cpython-38.pyc +0 -0
- model/hrnetv2/__pycache__/hrnet_ocr.cpython-38.pyc +0 -0
- model/hrnetv2/__pycache__/modifiers.cpython-38.pyc +0 -0
- model/hrnetv2/__pycache__/ocr.cpython-38.pyc +0 -0
- model/hrnetv2/__pycache__/resnetv1b.cpython-38.pyc +0 -0
- model/hrnetv2/hrnet_ocr.py +400 -0
- model/hrnetv2/modifiers.py +11 -0
- model/hrnetv2/ocr.py +140 -0
- model/hrnetv2/resnetv1b.py +276 -0
- model/lut_transformation_net.py +65 -0
- pretrained_models/Resolution_1024_HAdobe5K.pth +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/demo.gif filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/visualizations2.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
demo/demo_6k_composite.jpg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
demo/demo_6k_real.jpg filter=lfs diff=lfs merge=lfs -text
|
assets/demo.gif
ADDED
|
Git LFS Details
|
assets/metrics.png
ADDED
|
assets/network.png
ADDED
|
assets/title_any_image.gif
ADDED
|
assets/title_harmon.gif
ADDED
|
assets/title_you_want.gif
ADDED
|
assets/visualizations.png
ADDED
|
assets/visualizations2.png
ADDED
|
Git LFS Details
|
datasets/__init__.py
ADDED
|
File without changes
|
datasets/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (176 Bytes). View file
|
|
|
datasets/__pycache__/build_INR_dataset.cpython-38.pyc
ADDED
|
Binary file (1.31 kB). View file
|
|
|
datasets/__pycache__/build_dataset.cpython-38.pyc
ADDED
|
Binary file (6.96 kB). View file
|
|
|
datasets/build_INR_dataset.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils import misc
|
| 2 |
+
from albumentations import Resize
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Implicit2DGenerator(object):
|
| 6 |
+
def __init__(self, opt, mode):
|
| 7 |
+
if mode == 'Train':
|
| 8 |
+
sidelength = opt.INR_input_size
|
| 9 |
+
elif mode == 'Val':
|
| 10 |
+
sidelength = opt.input_size
|
| 11 |
+
else:
|
| 12 |
+
raise NotImplementedError
|
| 13 |
+
|
| 14 |
+
self.mode = mode
|
| 15 |
+
|
| 16 |
+
self.size = sidelength
|
| 17 |
+
|
| 18 |
+
if isinstance(sidelength, int):
|
| 19 |
+
sidelength = (sidelength, sidelength)
|
| 20 |
+
|
| 21 |
+
self.mgrid = misc.get_mgrid(sidelength)
|
| 22 |
+
|
| 23 |
+
self.transform = Resize(self.size, self.size)
|
| 24 |
+
|
| 25 |
+
def generator(self, torch_transforms, composite_image, real_image, mask):
|
| 26 |
+
composite_image = torch_transforms(self.transform(image=composite_image)['image'])
|
| 27 |
+
real_image = torch_transforms(self.transform(image=real_image)['image'])
|
| 28 |
+
|
| 29 |
+
fg_INR_RGB = composite_image.permute(1, 2, 0).contiguous().view(-1, 3)
|
| 30 |
+
fg_transfer_INR_RGB = real_image.permute(1, 2, 0).contiguous().view(-1, 3)
|
| 31 |
+
bg_INR_RGB = real_image.permute(1, 2, 0).contiguous().view(-1, 3)
|
| 32 |
+
|
| 33 |
+
fg_INR_coordinates = self.mgrid
|
| 34 |
+
bg_INR_coordinates = self.mgrid
|
| 35 |
+
|
| 36 |
+
return fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB
|
datasets/build_dataset.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torchvision
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
|
| 8 |
+
from utils.misc import prepare_cooridinate_input, customRandomCrop
|
| 9 |
+
|
| 10 |
+
from datasets.build_INR_dataset import Implicit2DGenerator
|
| 11 |
+
import albumentations
|
| 12 |
+
from albumentations import Resize, RandomResizedCrop, HorizontalFlip
|
| 13 |
+
from torch.utils.data import DataLoader
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class dataset_generator(torch.utils.data.Dataset):
|
| 17 |
+
def __init__(self, dataset_txt, alb_transforms, torch_transforms, opt, area_keep_thresh=0.2, mode='Train'):
|
| 18 |
+
super().__init__()
|
| 19 |
+
|
| 20 |
+
self.opt = opt
|
| 21 |
+
self.root_path = opt.dataset_path
|
| 22 |
+
self.mode = mode
|
| 23 |
+
|
| 24 |
+
self.alb_transforms = alb_transforms
|
| 25 |
+
self.torch_transforms = torch_transforms
|
| 26 |
+
self.kp_t = area_keep_thresh
|
| 27 |
+
|
| 28 |
+
with open(dataset_txt, 'r') as f:
|
| 29 |
+
self.dataset_samples = [os.path.join(self.root_path, x.strip()) for x in f.readlines()]
|
| 30 |
+
|
| 31 |
+
self.INR_dataset = Implicit2DGenerator(opt, self.mode)
|
| 32 |
+
|
| 33 |
+
def __len__(self):
|
| 34 |
+
return len(self.dataset_samples)
|
| 35 |
+
|
| 36 |
+
def __getitem__(self, idx):
|
| 37 |
+
composite_image = self.dataset_samples[idx]
|
| 38 |
+
|
| 39 |
+
if self.opt.hr_train:
|
| 40 |
+
if self.opt.isFullRes:
|
| 41 |
+
"Since in dataset preprocessing, we resize the image in HAdobe5k to a lower resolution for " \
|
| 42 |
+
"quick loading, we need to change the path here to that of the original resolution of HAdobe5k " \
|
| 43 |
+
"if `opt.isFullRes` is set to True."
|
| 44 |
+
composite_image = composite_image.replace("HAdobe5k", "HAdobe5kori")
|
| 45 |
+
|
| 46 |
+
real_image = '_'.join(composite_image.split('_')[:2]).replace("composite_images", "real_images") + '.jpg'
|
| 47 |
+
mask = '_'.join(composite_image.split('_')[:-1]).replace("composite_images", "masks") + '.png'
|
| 48 |
+
|
| 49 |
+
composite_image = cv2.imread(composite_image)
|
| 50 |
+
composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB)
|
| 51 |
+
|
| 52 |
+
real_image = cv2.imread(real_image)
|
| 53 |
+
real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB)
|
| 54 |
+
|
| 55 |
+
mask = cv2.imread(mask)
|
| 56 |
+
mask = mask[:, :, 0].astype(np.float32) / 255.
|
| 57 |
+
|
| 58 |
+
"""
|
| 59 |
+
If set `opt.hr_train` to True:
|
| 60 |
+
|
| 61 |
+
Apply multi resolution crop for HR image train. Specifically, for 1024/2048 `input_size` (not fullres),
|
| 62 |
+
the training phase is first to RandomResizeCrop 1024/2048 `input_size`, then to random crop a `base_size`
|
| 63 |
+
patch to feed in multiINR process. For inference, just resize it.
|
| 64 |
+
|
| 65 |
+
While for fullres, the RandomResizeCrop is removed and just do a random crop. For inference, just keep the size.
|
| 66 |
+
|
| 67 |
+
BTW, we implement LR and HR mixing train. I.e., the following `random.random() < 0.5`
|
| 68 |
+
"""
|
| 69 |
+
if self.opt.hr_train:
|
| 70 |
+
if self.mode == 'Train' and self.opt.isFullRes:
|
| 71 |
+
if random.random() < 0.5: # LR mix training
|
| 72 |
+
mixTransform = albumentations.Compose(
|
| 73 |
+
[
|
| 74 |
+
RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
|
| 75 |
+
HorizontalFlip()],
|
| 76 |
+
additional_targets={'real_image': 'image', 'object_mask': 'image'}
|
| 77 |
+
)
|
| 78 |
+
origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
|
| 79 |
+
origin_bg_ratio = 1 - origin_fg_ratio
|
| 80 |
+
|
| 81 |
+
"Ensure fg and bg not disappear after transformation"
|
| 82 |
+
valid_augmentation = False
|
| 83 |
+
transform_out = None
|
| 84 |
+
time = 0
|
| 85 |
+
while not valid_augmentation:
|
| 86 |
+
time += 1
|
| 87 |
+
# There are some extreme ratio pics, this code is to avoid being hindered by them.
|
| 88 |
+
if time == 20:
|
| 89 |
+
tmp_transform = albumentations.Compose(
|
| 90 |
+
[Resize(self.opt.base_size, self.opt.base_size)],
|
| 91 |
+
additional_targets={'real_image': 'image',
|
| 92 |
+
'object_mask': 'image'})
|
| 93 |
+
transform_out = tmp_transform(image=composite_image, real_image=real_image,
|
| 94 |
+
object_mask=mask)
|
| 95 |
+
valid_augmentation = True
|
| 96 |
+
else:
|
| 97 |
+
transform_out = mixTransform(image=composite_image, real_image=real_image,
|
| 98 |
+
object_mask=mask)
|
| 99 |
+
valid_augmentation = check_augmented_sample(transform_out['object_mask'],
|
| 100 |
+
origin_fg_ratio,
|
| 101 |
+
origin_bg_ratio,
|
| 102 |
+
self.kp_t)
|
| 103 |
+
composite_image = transform_out['image']
|
| 104 |
+
real_image = transform_out['real_image']
|
| 105 |
+
mask = transform_out['object_mask']
|
| 106 |
+
else: # Padding to ensure that the original resolution can be divided by 4. This is for pixel-aligned crop.
|
| 107 |
+
if real_image.shape[0] < 256:
|
| 108 |
+
bottom_pad = 256 - real_image.shape[0]
|
| 109 |
+
else:
|
| 110 |
+
bottom_pad = (4 - real_image.shape[0] % 4) % 4
|
| 111 |
+
if real_image.shape[1] < 256:
|
| 112 |
+
right_pad = 256 - real_image.shape[1]
|
| 113 |
+
else:
|
| 114 |
+
right_pad = (4 - real_image.shape[1] % 4) % 4
|
| 115 |
+
composite_image = cv2.copyMakeBorder(composite_image, 0, bottom_pad, 0, right_pad,
|
| 116 |
+
cv2.BORDER_REPLICATE)
|
| 117 |
+
real_image = cv2.copyMakeBorder(real_image, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
|
| 118 |
+
mask = cv2.copyMakeBorder(mask, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
|
| 119 |
+
|
| 120 |
+
origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
|
| 121 |
+
origin_bg_ratio = 1 - origin_fg_ratio
|
| 122 |
+
|
| 123 |
+
"Ensure fg and bg not disappear after transformation"
|
| 124 |
+
valid_augmentation = False
|
| 125 |
+
transform_out = None
|
| 126 |
+
time = 0
|
| 127 |
+
|
| 128 |
+
if self.opt.hr_train:
|
| 129 |
+
if self.mode == 'Train':
|
| 130 |
+
if not self.opt.isFullRes:
|
| 131 |
+
if random.random() < 0.5: # LR mix training
|
| 132 |
+
mixTransform = albumentations.Compose(
|
| 133 |
+
[
|
| 134 |
+
RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
|
| 135 |
+
HorizontalFlip()],
|
| 136 |
+
additional_targets={'real_image': 'image', 'object_mask': 'image'}
|
| 137 |
+
)
|
| 138 |
+
while not valid_augmentation:
|
| 139 |
+
time += 1
|
| 140 |
+
# There are some extreme ratio pics, this code is to avoid being hindered by them.
|
| 141 |
+
if time == 20:
|
| 142 |
+
tmp_transform = albumentations.Compose(
|
| 143 |
+
[Resize(self.opt.base_size, self.opt.base_size)],
|
| 144 |
+
additional_targets={'real_image': 'image',
|
| 145 |
+
'object_mask': 'image'})
|
| 146 |
+
transform_out = tmp_transform(image=composite_image, real_image=real_image,
|
| 147 |
+
object_mask=mask)
|
| 148 |
+
valid_augmentation = True
|
| 149 |
+
else:
|
| 150 |
+
transform_out = mixTransform(image=composite_image, real_image=real_image,
|
| 151 |
+
object_mask=mask)
|
| 152 |
+
valid_augmentation = check_augmented_sample(transform_out['object_mask'],
|
| 153 |
+
origin_fg_ratio,
|
| 154 |
+
origin_bg_ratio,
|
| 155 |
+
self.kp_t)
|
| 156 |
+
else:
|
| 157 |
+
while not valid_augmentation:
|
| 158 |
+
time += 1
|
| 159 |
+
# There are some extreme ratio pics, this code is to avoid being hindered by them.
|
| 160 |
+
if time == 20:
|
| 161 |
+
tmp_transform = albumentations.Compose(
|
| 162 |
+
[Resize(self.opt.input_size, self.opt.input_size)],
|
| 163 |
+
additional_targets={'real_image': 'image',
|
| 164 |
+
'object_mask': 'image'})
|
| 165 |
+
transform_out = tmp_transform(image=composite_image, real_image=real_image,
|
| 166 |
+
object_mask=mask)
|
| 167 |
+
valid_augmentation = True
|
| 168 |
+
else:
|
| 169 |
+
transform_out = self.alb_transforms(image=composite_image, real_image=real_image,
|
| 170 |
+
object_mask=mask)
|
| 171 |
+
valid_augmentation = check_augmented_sample(transform_out['object_mask'],
|
| 172 |
+
origin_fg_ratio,
|
| 173 |
+
origin_bg_ratio,
|
| 174 |
+
self.kp_t)
|
| 175 |
+
composite_image = transform_out['image']
|
| 176 |
+
real_image = transform_out['real_image']
|
| 177 |
+
mask = transform_out['object_mask']
|
| 178 |
+
|
| 179 |
+
origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
|
| 180 |
+
|
| 181 |
+
full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0)
|
| 182 |
+
|
| 183 |
+
tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)],
|
| 184 |
+
additional_targets={'real_image': 'image',
|
| 185 |
+
'object_mask': 'image'})
|
| 186 |
+
transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
|
| 187 |
+
compos_list = [self.torch_transforms(transform_out['image'])]
|
| 188 |
+
real_list = [self.torch_transforms(transform_out['real_image'])]
|
| 189 |
+
mask_list = [
|
| 190 |
+
torchvision.transforms.ToTensor()(transform_out['object_mask'][..., np.newaxis].astype(np.float32))]
|
| 191 |
+
coord_map_list = []
|
| 192 |
+
|
| 193 |
+
valid_augmentation = False
|
| 194 |
+
while not valid_augmentation:
|
| 195 |
+
# RSC strategy. To crop different resolutions.
|
| 196 |
+
transform_out, c_h, c_w = customRandomCrop([composite_image, real_image, mask, full_coord],
|
| 197 |
+
self.opt.base_size, self.opt.base_size)
|
| 198 |
+
valid_augmentation = check_hr_crop_sample(transform_out[2], origin_fg_ratio)
|
| 199 |
+
|
| 200 |
+
compos_list.append(self.torch_transforms(transform_out[0]))
|
| 201 |
+
real_list.append(self.torch_transforms(transform_out[1]))
|
| 202 |
+
mask_list.append(
|
| 203 |
+
torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32)))
|
| 204 |
+
coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
|
| 205 |
+
coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
|
| 206 |
+
for n in range(2):
|
| 207 |
+
tmp_comp = cv2.resize(composite_image, (
|
| 208 |
+
composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1)))
|
| 209 |
+
tmp_real = cv2.resize(real_image,
|
| 210 |
+
(real_image.shape[1] // 2 ** (n + 1), real_image.shape[0] // 2 ** (n + 1)))
|
| 211 |
+
tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1)))
|
| 212 |
+
tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0)
|
| 213 |
+
|
| 214 |
+
transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_real, tmp_mask, tmp_coord],
|
| 215 |
+
self.opt.base_size // 2 ** (n + 1),
|
| 216 |
+
self.opt.base_size // 2 ** (n + 1), c_h, c_w)
|
| 217 |
+
compos_list.append(self.torch_transforms(transform_out[0]))
|
| 218 |
+
real_list.append(self.torch_transforms(transform_out[1]))
|
| 219 |
+
mask_list.append(
|
| 220 |
+
torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32)))
|
| 221 |
+
coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
|
| 222 |
+
out_comp = compos_list
|
| 223 |
+
out_real = real_list
|
| 224 |
+
out_mask = mask_list
|
| 225 |
+
out_coord = coord_map_list
|
| 226 |
+
|
| 227 |
+
fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
|
| 228 |
+
self.torch_transforms, transform_out[0], transform_out[1], mask)
|
| 229 |
+
|
| 230 |
+
return {
|
| 231 |
+
'file_path': self.dataset_samples[idx],
|
| 232 |
+
'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
|
| 233 |
+
'composite_image': out_comp,
|
| 234 |
+
'real_image': out_real,
|
| 235 |
+
'mask': out_mask,
|
| 236 |
+
'coordinate_map': out_coord,
|
| 237 |
+
'composite_image0': out_comp[0],
|
| 238 |
+
'real_image0': out_real[0],
|
| 239 |
+
'mask0': out_mask[0],
|
| 240 |
+
'coordinate_map0': out_coord[0],
|
| 241 |
+
'composite_image1': out_comp[1],
|
| 242 |
+
'real_image1': out_real[1],
|
| 243 |
+
'mask1': out_mask[1],
|
| 244 |
+
'coordinate_map1': out_coord[1],
|
| 245 |
+
'composite_image2': out_comp[2],
|
| 246 |
+
'real_image2': out_real[2],
|
| 247 |
+
'mask2': out_mask[2],
|
| 248 |
+
'coordinate_map2': out_coord[2],
|
| 249 |
+
'composite_image3': out_comp[3],
|
| 250 |
+
'real_image3': out_real[3],
|
| 251 |
+
'mask3': out_mask[3],
|
| 252 |
+
'coordinate_map3': out_coord[3],
|
| 253 |
+
'fg_INR_coordinates': fg_INR_coordinates,
|
| 254 |
+
'bg_INR_coordinates': bg_INR_coordinates,
|
| 255 |
+
'fg_INR_RGB': fg_INR_RGB,
|
| 256 |
+
'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
|
| 257 |
+
'bg_INR_RGB': bg_INR_RGB
|
| 258 |
+
}
|
| 259 |
+
else:
|
| 260 |
+
if not self.opt.isFullRes:
|
| 261 |
+
tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
|
| 262 |
+
additional_targets={'real_image': 'image',
|
| 263 |
+
'object_mask': 'image'})
|
| 264 |
+
transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
|
| 265 |
+
|
| 266 |
+
coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])
|
| 267 |
+
|
| 268 |
+
"Generate INR dataset."
|
| 269 |
+
mask = (torchvision.transforms.ToTensor()(
|
| 270 |
+
transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
|
| 271 |
+
mask = np.bool_(mask.numpy())
|
| 272 |
+
|
| 273 |
+
fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
|
| 274 |
+
self.torch_transforms, transform_out['image'], transform_out['real_image'], mask)
|
| 275 |
+
|
| 276 |
+
return {
|
| 277 |
+
'file_path': self.dataset_samples[idx],
|
| 278 |
+
'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
|
| 279 |
+
'composite_image': self.torch_transforms(transform_out['image']),
|
| 280 |
+
'real_image': self.torch_transforms(transform_out['real_image']),
|
| 281 |
+
'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
|
| 282 |
+
# Can automatically transfer to Tensor.
|
| 283 |
+
'coordinate_map': coordinate_map,
|
| 284 |
+
'fg_INR_coordinates': fg_INR_coordinates,
|
| 285 |
+
'bg_INR_coordinates': bg_INR_coordinates,
|
| 286 |
+
'fg_INR_RGB': fg_INR_RGB,
|
| 287 |
+
'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
|
| 288 |
+
'bg_INR_RGB': bg_INR_RGB
|
| 289 |
+
}
|
| 290 |
+
else:
|
| 291 |
+
coordinate_map = prepare_cooridinate_input(mask)
|
| 292 |
+
|
| 293 |
+
"Generate INR dataset."
|
| 294 |
+
mask_tmp = (torchvision.transforms.ToTensor()(mask).squeeze() > 100 / 255.).view(-1)
|
| 295 |
+
mask_tmp = np.bool_(mask_tmp.numpy())
|
| 296 |
+
|
| 297 |
+
fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
|
| 298 |
+
self.torch_transforms, composite_image, real_image, mask_tmp)
|
| 299 |
+
|
| 300 |
+
return {
|
| 301 |
+
'file_path': self.dataset_samples[idx],
|
| 302 |
+
'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
|
| 303 |
+
'composite_image': self.torch_transforms(composite_image),
|
| 304 |
+
'real_image': self.torch_transforms(real_image),
|
| 305 |
+
'mask': mask[np.newaxis, ...].astype(np.float32),
|
| 306 |
+
# Can automatically transfer to Tensor.
|
| 307 |
+
'coordinate_map': coordinate_map,
|
| 308 |
+
'fg_INR_coordinates': fg_INR_coordinates,
|
| 309 |
+
'bg_INR_coordinates': bg_INR_coordinates,
|
| 310 |
+
'fg_INR_RGB': fg_INR_RGB,
|
| 311 |
+
'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
|
| 312 |
+
'bg_INR_RGB': bg_INR_RGB
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
while not valid_augmentation:
|
| 316 |
+
time += 1
|
| 317 |
+
# There are some extreme ratio pics, this code is to avoid being hindered by them.
|
| 318 |
+
if time == 20:
|
| 319 |
+
tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
|
| 320 |
+
additional_targets={'real_image': 'image',
|
| 321 |
+
'object_mask': 'image'})
|
| 322 |
+
transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
|
| 323 |
+
valid_augmentation = True
|
| 324 |
+
else:
|
| 325 |
+
transform_out = self.alb_transforms(image=composite_image, real_image=real_image, object_mask=mask)
|
| 326 |
+
valid_augmentation = check_augmented_sample(transform_out['object_mask'], origin_fg_ratio,
|
| 327 |
+
origin_bg_ratio,
|
| 328 |
+
self.kp_t)
|
| 329 |
+
|
| 330 |
+
coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])
|
| 331 |
+
|
| 332 |
+
"Generate INR dataset."
|
| 333 |
+
mask = (torchvision.transforms.ToTensor()(transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
|
| 334 |
+
mask = np.bool_(mask.numpy())
|
| 335 |
+
|
| 336 |
+
fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
|
| 337 |
+
self.torch_transforms, transform_out['image'], transform_out['real_image'], mask)
|
| 338 |
+
|
| 339 |
+
return {
|
| 340 |
+
'file_path': self.dataset_samples[idx],
|
| 341 |
+
'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
|
| 342 |
+
'composite_image': self.torch_transforms(transform_out['image']),
|
| 343 |
+
'real_image': self.torch_transforms(transform_out['real_image']),
|
| 344 |
+
'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
|
| 345 |
+
# Can automatically transfer to Tensor.
|
| 346 |
+
'coordinate_map': coordinate_map,
|
| 347 |
+
'fg_INR_coordinates': fg_INR_coordinates,
|
| 348 |
+
'bg_INR_coordinates': bg_INR_coordinates,
|
| 349 |
+
'fg_INR_RGB': fg_INR_RGB,
|
| 350 |
+
'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
|
| 351 |
+
'bg_INR_RGB': bg_INR_RGB
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def check_augmented_sample(mask, origin_fg_ratio, origin_bg_ratio, area_keep_thresh):
|
| 356 |
+
current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
|
| 357 |
+
current_bg_ratio = 1 - current_fg_ratio
|
| 358 |
+
|
| 359 |
+
if current_fg_ratio < origin_fg_ratio * area_keep_thresh or current_bg_ratio < origin_bg_ratio * area_keep_thresh:
|
| 360 |
+
return False
|
| 361 |
+
|
| 362 |
+
return True
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def check_hr_crop_sample(mask, origin_fg_ratio):
|
| 366 |
+
current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
|
| 367 |
+
|
| 368 |
+
if current_fg_ratio < 0.8 * origin_fg_ratio:
|
| 369 |
+
return False
|
| 370 |
+
|
| 371 |
+
return True
|
demo/demo_2k_composite.jpg
ADDED
|
demo/demo_2k_mask.jpg
ADDED
|
demo/demo_2k_real.jpg
ADDED
|
demo/demo_6k_composite.jpg
ADDED
|
Git LFS Details
|
demo/demo_6k_mask.jpg
ADDED
|
demo/demo_6k_real.jpg
ADDED
|
Git LFS Details
|
model/__init__.py
ADDED
|
File without changes
|
model/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (173 Bytes). View file
|
|
|
model/__pycache__/backbone.cpython-38.pyc
ADDED
|
Binary file (2.96 kB). View file
|
|
|
model/__pycache__/build_model.cpython-38.pyc
ADDED
|
Binary file (1.03 kB). View file
|
|
|
model/__pycache__/lut_transformation_net.cpython-38.pyc
ADDED
|
Binary file (2.43 kB). View file
|
|
|
model/backbone.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
from .hrnetv2.hrnet_ocr import HighResolutionNet
|
| 4 |
+
from .hrnetv2.modifiers import LRMult
|
| 5 |
+
from .base.basic_blocks import MaxPoolDownSize
|
| 6 |
+
from .base.ih_model import IHModelWithBackbone, DeepImageHarmonization
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def build_backbone(name, opt):
|
| 10 |
+
return eval(name)(opt)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class baseline(IHModelWithBackbone):
|
| 14 |
+
def __init__(self, opt, ocr=64):
|
| 15 |
+
base_config = {'model': DeepImageHarmonization,
|
| 16 |
+
'params': {'depth': 7, 'batchnorm_from': 2, 'image_fusion': True, 'opt': opt}}
|
| 17 |
+
|
| 18 |
+
params = base_config['params']
|
| 19 |
+
|
| 20 |
+
backbone = HRNetV2(opt, ocr=ocr)
|
| 21 |
+
|
| 22 |
+
params.update(dict(
|
| 23 |
+
backbone_from=2,
|
| 24 |
+
backbone_channels=backbone.output_channels,
|
| 25 |
+
backbone_mode='cat',
|
| 26 |
+
opt=opt
|
| 27 |
+
))
|
| 28 |
+
base_model = base_config['model'](**params)
|
| 29 |
+
|
| 30 |
+
super(baseline, self).__init__(base_model, backbone, False, 'sum', opt=opt)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class HRNetV2(nn.Module):
|
| 34 |
+
def __init__(
|
| 35 |
+
self, opt,
|
| 36 |
+
cat_outputs=True,
|
| 37 |
+
pyramid_channels=-1, pyramid_depth=4,
|
| 38 |
+
width=18, ocr=128, small=False,
|
| 39 |
+
lr_mult=0.1, pretained=True
|
| 40 |
+
):
|
| 41 |
+
super(HRNetV2, self).__init__()
|
| 42 |
+
self.opt = opt
|
| 43 |
+
self.cat_outputs = cat_outputs
|
| 44 |
+
self.ocr_on = ocr > 0 and cat_outputs
|
| 45 |
+
self.pyramid_on = pyramid_channels > 0 and cat_outputs
|
| 46 |
+
|
| 47 |
+
self.hrnet = HighResolutionNet(width, 2, ocr_width=ocr, small=small, opt=opt)
|
| 48 |
+
self.hrnet.apply(LRMult(lr_mult))
|
| 49 |
+
if self.ocr_on:
|
| 50 |
+
self.hrnet.ocr_distri_head.apply(LRMult(1.0))
|
| 51 |
+
self.hrnet.ocr_gather_head.apply(LRMult(1.0))
|
| 52 |
+
self.hrnet.conv3x3_ocr.apply(LRMult(1.0))
|
| 53 |
+
|
| 54 |
+
hrnet_cat_channels = [width * 2 ** i for i in range(4)]
|
| 55 |
+
if self.pyramid_on:
|
| 56 |
+
self.output_channels = [pyramid_channels] * 4
|
| 57 |
+
elif self.ocr_on:
|
| 58 |
+
self.output_channels = [ocr * 2]
|
| 59 |
+
elif self.cat_outputs:
|
| 60 |
+
self.output_channels = [sum(hrnet_cat_channels)]
|
| 61 |
+
else:
|
| 62 |
+
self.output_channels = hrnet_cat_channels
|
| 63 |
+
|
| 64 |
+
if self.pyramid_on:
|
| 65 |
+
downsize_in_channels = ocr * 2 if self.ocr_on else sum(hrnet_cat_channels)
|
| 66 |
+
self.downsize = MaxPoolDownSize(downsize_in_channels, pyramid_channels, pyramid_channels, pyramid_depth)
|
| 67 |
+
|
| 68 |
+
if pretained:
|
| 69 |
+
self.load_pretrained_weights(
|
| 70 |
+
r".\pretrained_models/hrnetv2_w18_imagenet_pretrained.pth")
|
| 71 |
+
|
| 72 |
+
self.output_resolution = (opt.input_size // 8) ** 2
|
| 73 |
+
|
| 74 |
+
def forward(self, image, mask, mask_features=None):
|
| 75 |
+
outputs = list(self.hrnet(image, mask, mask_features))
|
| 76 |
+
return outputs
|
| 77 |
+
|
| 78 |
+
def load_pretrained_weights(self, pretrained_path):
|
| 79 |
+
self.hrnet.load_pretrained_weights(pretrained_path)
|
model/base/__init__.py
ADDED
|
File without changes
|
model/base/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (178 Bytes). View file
|
|
|
model/base/__pycache__/basic_blocks.cpython-38.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
model/base/__pycache__/conv_autoencoder.cpython-38.pyc
ADDED
|
Binary file (13.8 kB). View file
|
|
|
model/base/__pycache__/ih_model.cpython-38.pyc
ADDED
|
Binary file (3.22 kB). View file
|
|
|
model/base/__pycache__/ops.cpython-38.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
model/base/basic_blocks.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def hyper_weight_init(m, in_features_main_net, activation):
|
| 7 |
+
if hasattr(m, 'weight'):
|
| 8 |
+
nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
|
| 9 |
+
m.weight.data = m.weight.data / 1.e2
|
| 10 |
+
|
| 11 |
+
if hasattr(m, 'bias'):
|
| 12 |
+
with torch.no_grad():
|
| 13 |
+
if activation == 'sine':
|
| 14 |
+
m.bias.uniform_(-np.sqrt(6 / in_features_main_net) / 30, np.sqrt(6 / in_features_main_net) / 30)
|
| 15 |
+
elif activation == 'leakyrelu_pe':
|
| 16 |
+
m.bias.uniform_(-np.sqrt(6 / in_features_main_net), np.sqrt(6 / in_features_main_net))
|
| 17 |
+
else:
|
| 18 |
+
raise NotImplementedError
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ConvBlock(nn.Module):
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
in_channels, out_channels,
|
| 25 |
+
kernel_size=4, stride=2, padding=1,
|
| 26 |
+
norm_layer=nn.BatchNorm2d, activation=nn.ELU,
|
| 27 |
+
bias=True,
|
| 28 |
+
):
|
| 29 |
+
super(ConvBlock, self).__init__()
|
| 30 |
+
self.block = nn.Sequential(
|
| 31 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
|
| 32 |
+
norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
|
| 33 |
+
activation(),
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
return self.block(x)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class MaxPoolDownSize(nn.Module):
|
| 41 |
+
def __init__(self, in_channels, mid_channels, out_channels, depth):
|
| 42 |
+
super(MaxPoolDownSize, self).__init__()
|
| 43 |
+
self.depth = depth
|
| 44 |
+
self.reduce_conv = ConvBlock(in_channels, mid_channels, kernel_size=1, stride=1, padding=0)
|
| 45 |
+
self.convs = nn.ModuleList([
|
| 46 |
+
ConvBlock(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 47 |
+
for conv_i in range(depth)
|
| 48 |
+
])
|
| 49 |
+
self.pool2d = nn.MaxPool2d(kernel_size=2)
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
outputs = []
|
| 53 |
+
|
| 54 |
+
output = self.reduce_conv(x)
|
| 55 |
+
|
| 56 |
+
for conv_i, conv in enumerate(self.convs):
|
| 57 |
+
output = output if conv_i == 0 else self.pool2d(output)
|
| 58 |
+
outputs.append(conv(output))
|
| 59 |
+
|
| 60 |
+
return outputs
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class convParams(nn.Module):
|
| 64 |
+
def __init__(self, input_dim, INR_in_out, opt, hidden_mlp_num, hidden_dim=512, toRGB=False):
|
| 65 |
+
super(convParams, self).__init__()
|
| 66 |
+
self.INR_in_out = INR_in_out
|
| 67 |
+
self.cont_split_weight = []
|
| 68 |
+
self.cont_split_bias = []
|
| 69 |
+
self.hidden_mlp_num = hidden_mlp_num
|
| 70 |
+
self.param_factorize_dim = opt.param_factorize_dim
|
| 71 |
+
output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num, toRGB)
|
| 72 |
+
self.output_dim = output_dim
|
| 73 |
+
self.toRGB = toRGB
|
| 74 |
+
self.cont_extraction_net = nn.Sequential(
|
| 75 |
+
nn.Conv2d(input_dim, hidden_dim, kernel_size=3, stride=2, padding=1, bias=False),
|
| 76 |
+
# nn.BatchNorm2d(hidden_dim),
|
| 77 |
+
nn.ReLU(inplace=True),
|
| 78 |
+
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False),
|
| 79 |
+
# nn.BatchNorm2d(hidden_dim),
|
| 80 |
+
nn.ReLU(inplace=True),
|
| 81 |
+
nn.Conv2d(hidden_dim, output_dim, kernel_size=1, stride=1, padding=0, bias=True),
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
self.cont_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation))
|
| 85 |
+
|
| 86 |
+
self.basic_params = nn.ParameterList()
|
| 87 |
+
if opt.param_factorize_dim > 0:
|
| 88 |
+
for id in range(self.hidden_mlp_num + 1):
|
| 89 |
+
if id == 0:
|
| 90 |
+
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
|
| 91 |
+
else:
|
| 92 |
+
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
|
| 93 |
+
self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, inp, outp)))
|
| 94 |
+
|
| 95 |
+
if toRGB:
|
| 96 |
+
self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, self.INR_in_out[1], 3)))
|
| 97 |
+
|
| 98 |
+
def forward(self, feat, outMore=False):
|
| 99 |
+
cont_params = self.cont_extraction_net(feat)
|
| 100 |
+
out_mlp = self.to_mlp(cont_params)
|
| 101 |
+
if outMore:
|
| 102 |
+
return out_mlp, cont_params
|
| 103 |
+
return out_mlp
|
| 104 |
+
|
| 105 |
+
def cal_params_num(self, INR_in_out, hidden_mlp_num, toRGB=False):
|
| 106 |
+
cont_params = 0
|
| 107 |
+
start = 0
|
| 108 |
+
if self.param_factorize_dim == -1:
|
| 109 |
+
cont_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1]
|
| 110 |
+
self.cont_split_weight.append([start, cont_params - INR_in_out[1]])
|
| 111 |
+
self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
|
| 112 |
+
start = cont_params
|
| 113 |
+
|
| 114 |
+
for id in range(hidden_mlp_num):
|
| 115 |
+
cont_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1]
|
| 116 |
+
self.cont_split_weight.append([start, cont_params - INR_in_out[1]])
|
| 117 |
+
self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
|
| 118 |
+
start = cont_params
|
| 119 |
+
|
| 120 |
+
if toRGB:
|
| 121 |
+
cont_params += INR_in_out[1] * 3 + 3
|
| 122 |
+
self.cont_split_weight.append([start, cont_params - 3])
|
| 123 |
+
self.cont_split_bias.append([cont_params - 3, cont_params])
|
| 124 |
+
|
| 125 |
+
elif self.param_factorize_dim > 0:
|
| 126 |
+
cont_params += INR_in_out[0] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
|
| 127 |
+
INR_in_out[1]
|
| 128 |
+
self.cont_split_weight.append(
|
| 129 |
+
[start, start + INR_in_out[0] * self.param_factorize_dim, cont_params - INR_in_out[1]])
|
| 130 |
+
self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
|
| 131 |
+
start = cont_params
|
| 132 |
+
|
| 133 |
+
for id in range(hidden_mlp_num):
|
| 134 |
+
cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
|
| 135 |
+
INR_in_out[1]
|
| 136 |
+
self.cont_split_weight.append(
|
| 137 |
+
[start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - INR_in_out[1]])
|
| 138 |
+
self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
|
| 139 |
+
start = cont_params
|
| 140 |
+
|
| 141 |
+
if toRGB:
|
| 142 |
+
cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3
|
| 143 |
+
self.cont_split_weight.append(
|
| 144 |
+
[start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - 3])
|
| 145 |
+
self.cont_split_bias.append([cont_params - 3, cont_params])
|
| 146 |
+
|
| 147 |
+
return cont_params
|
| 148 |
+
|
| 149 |
+
def to_mlp(self, params):
|
| 150 |
+
all_weight_bias = []
|
| 151 |
+
if self.param_factorize_dim == -1:
|
| 152 |
+
for id in range(self.hidden_mlp_num + 1):
|
| 153 |
+
if id == 0:
|
| 154 |
+
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
|
| 155 |
+
else:
|
| 156 |
+
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
|
| 157 |
+
weight = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :]
|
| 158 |
+
weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:],
|
| 159 |
+
inp, outp)
|
| 160 |
+
|
| 161 |
+
bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :]
|
| 162 |
+
bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
|
| 163 |
+
all_weight_bias.append([weight, bias])
|
| 164 |
+
|
| 165 |
+
if self.toRGB:
|
| 166 |
+
inp, outp = self.INR_in_out[1], 3
|
| 167 |
+
weight = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :]
|
| 168 |
+
weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:],
|
| 169 |
+
inp, outp)
|
| 170 |
+
|
| 171 |
+
bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :]
|
| 172 |
+
bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
|
| 173 |
+
all_weight_bias.append([weight, bias])
|
| 174 |
+
|
| 175 |
+
return all_weight_bias
|
| 176 |
+
|
| 177 |
+
else:
|
| 178 |
+
for id in range(self.hidden_mlp_num + 1):
|
| 179 |
+
if id == 0:
|
| 180 |
+
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
|
| 181 |
+
else:
|
| 182 |
+
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
|
| 183 |
+
weight1 = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :]
|
| 184 |
+
weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(weight1.shape[0], *weight1.shape[2:],
|
| 185 |
+
inp, self.param_factorize_dim)
|
| 186 |
+
|
| 187 |
+
weight2 = params[:, self.cont_split_weight[id][1]:self.cont_split_weight[id][2], :, :]
|
| 188 |
+
weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(weight2.shape[0], *weight2.shape[2:],
|
| 189 |
+
self.param_factorize_dim, outp)
|
| 190 |
+
|
| 191 |
+
bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :]
|
| 192 |
+
bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
|
| 193 |
+
|
| 194 |
+
all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])
|
| 195 |
+
|
| 196 |
+
if self.toRGB:
|
| 197 |
+
inp, outp = self.INR_in_out[1], 3
|
| 198 |
+
weight1 = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :]
|
| 199 |
+
weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(weight1.shape[0], *weight1.shape[2:],
|
| 200 |
+
inp, self.param_factorize_dim)
|
| 201 |
+
|
| 202 |
+
weight2 = params[:, self.cont_split_weight[-1][1]:self.cont_split_weight[-1][2], :, :]
|
| 203 |
+
weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(weight2.shape[0], *weight2.shape[2:],
|
| 204 |
+
self.param_factorize_dim, outp)
|
| 205 |
+
|
| 206 |
+
bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :]
|
| 207 |
+
bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
|
| 208 |
+
|
| 209 |
+
all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[-1], bias])
|
| 210 |
+
|
| 211 |
+
return all_weight_bias
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class lineParams(nn.Module):
|
| 215 |
+
def __init__(self, input_dim, INR_in_out, input_resolution, opt, hidden_mlp_num, toRGB=False,
|
| 216 |
+
hidden_dim=512):
|
| 217 |
+
super(lineParams, self).__init__()
|
| 218 |
+
self.INR_in_out = INR_in_out
|
| 219 |
+
self.app_split_weight = []
|
| 220 |
+
self.app_split_bias = []
|
| 221 |
+
self.toRGB = toRGB
|
| 222 |
+
self.hidden_mlp_num = hidden_mlp_num
|
| 223 |
+
self.param_factorize_dim = opt.param_factorize_dim
|
| 224 |
+
output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num)
|
| 225 |
+
self.output_dim = output_dim
|
| 226 |
+
|
| 227 |
+
self.compress_layer = nn.Sequential(
|
| 228 |
+
nn.Linear(input_resolution, 64, bias=False),
|
| 229 |
+
nn.BatchNorm1d(input_dim),
|
| 230 |
+
nn.ReLU(inplace=True),
|
| 231 |
+
nn.Linear(64, 1, bias=True)
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
self.app_extraction_net = nn.Sequential(
|
| 235 |
+
nn.Linear(input_dim, hidden_dim, bias=False),
|
| 236 |
+
# nn.BatchNorm1d(hidden_dim),
|
| 237 |
+
nn.ReLU(inplace=True),
|
| 238 |
+
nn.Linear(hidden_dim, hidden_dim, bias=False),
|
| 239 |
+
# nn.BatchNorm1d(hidden_dim),
|
| 240 |
+
nn.ReLU(inplace=True),
|
| 241 |
+
nn.Linear(hidden_dim, output_dim, bias=True)
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
self.app_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation))
|
| 245 |
+
|
| 246 |
+
self.basic_params = nn.ParameterList()
|
| 247 |
+
if opt.param_factorize_dim > 0:
|
| 248 |
+
for id in range(self.hidden_mlp_num + 1):
|
| 249 |
+
if id == 0:
|
| 250 |
+
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
|
| 251 |
+
else:
|
| 252 |
+
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
|
| 253 |
+
self.basic_params.append(nn.Parameter(torch.randn(1, inp, outp)))
|
| 254 |
+
if toRGB:
|
| 255 |
+
self.basic_params.append(nn.Parameter(torch.randn(1, self.INR_in_out[1], 3)))
|
| 256 |
+
|
| 257 |
+
def forward(self, feat):
|
| 258 |
+
app_params = self.app_extraction_net(self.compress_layer(torch.flatten(feat, 2)).squeeze(-1))
|
| 259 |
+
out_mlp = self.to_mlp(app_params)
|
| 260 |
+
return out_mlp, app_params
|
| 261 |
+
|
| 262 |
+
def cal_params_num(self, INR_in_out, hidden_mlp_num):
|
| 263 |
+
app_params = 0
|
| 264 |
+
start = 0
|
| 265 |
+
if self.param_factorize_dim == -1:
|
| 266 |
+
app_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1]
|
| 267 |
+
self.app_split_weight.append([start, app_params - INR_in_out[1]])
|
| 268 |
+
self.app_split_bias.append([app_params - INR_in_out[1], app_params])
|
| 269 |
+
start = app_params
|
| 270 |
+
|
| 271 |
+
for id in range(hidden_mlp_num):
|
| 272 |
+
app_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1]
|
| 273 |
+
self.app_split_weight.append([start, app_params - INR_in_out[1]])
|
| 274 |
+
self.app_split_bias.append([app_params - INR_in_out[1], app_params])
|
| 275 |
+
start = app_params
|
| 276 |
+
|
| 277 |
+
if self.toRGB:
|
| 278 |
+
app_params += INR_in_out[1] * 3 + 3
|
| 279 |
+
self.app_split_weight.append([start, app_params - 3])
|
| 280 |
+
self.app_split_bias.append([app_params - 3, app_params])
|
| 281 |
+
|
| 282 |
+
elif self.param_factorize_dim > 0:
|
| 283 |
+
app_params += INR_in_out[0] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
|
| 284 |
+
INR_in_out[1]
|
| 285 |
+
self.app_split_weight.append([start, start + INR_in_out[0] * self.param_factorize_dim,
|
| 286 |
+
app_params - INR_in_out[1]])
|
| 287 |
+
self.app_split_bias.append([app_params - INR_in_out[1], app_params])
|
| 288 |
+
start = app_params
|
| 289 |
+
|
| 290 |
+
for id in range(hidden_mlp_num):
|
| 291 |
+
app_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
|
| 292 |
+
INR_in_out[1]
|
| 293 |
+
self.app_split_weight.append(
|
| 294 |
+
[start, start + INR_in_out[1] * self.param_factorize_dim, app_params - INR_in_out[1]])
|
| 295 |
+
self.app_split_bias.append([app_params - INR_in_out[1], app_params])
|
| 296 |
+
start = app_params
|
| 297 |
+
|
| 298 |
+
if self.toRGB:
|
| 299 |
+
app_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3
|
| 300 |
+
self.app_split_weight.append([start, start + INR_in_out[1] * self.param_factorize_dim,
|
| 301 |
+
app_params - 3])
|
| 302 |
+
self.app_split_bias.append([app_params - 3, app_params])
|
| 303 |
+
|
| 304 |
+
return app_params
|
| 305 |
+
|
| 306 |
+
def to_mlp(self, params):
|
| 307 |
+
all_weight_bias = []
|
| 308 |
+
if self.param_factorize_dim == -1:
|
| 309 |
+
for id in range(self.hidden_mlp_num + 1):
|
| 310 |
+
if id == 0:
|
| 311 |
+
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
|
| 312 |
+
else:
|
| 313 |
+
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
|
| 314 |
+
weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
|
| 315 |
+
weight = weight.view(weight.shape[0], inp, outp)
|
| 316 |
+
|
| 317 |
+
bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
|
| 318 |
+
bias = bias.view(bias.shape[0], 1, outp)
|
| 319 |
+
|
| 320 |
+
all_weight_bias.append([weight, bias])
|
| 321 |
+
|
| 322 |
+
if self.toRGB:
|
| 323 |
+
id = -1
|
| 324 |
+
inp, outp = self.INR_in_out[1], 3
|
| 325 |
+
weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
|
| 326 |
+
weight = weight.view(weight.shape[0], inp, outp)
|
| 327 |
+
|
| 328 |
+
bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
|
| 329 |
+
bias = bias.view(bias.shape[0], 1, outp)
|
| 330 |
+
|
| 331 |
+
all_weight_bias.append([weight, bias])
|
| 332 |
+
|
| 333 |
+
return all_weight_bias
|
| 334 |
+
|
| 335 |
+
else:
|
| 336 |
+
for id in range(self.hidden_mlp_num + 1):
|
| 337 |
+
if id == 0:
|
| 338 |
+
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
|
| 339 |
+
else:
|
| 340 |
+
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
|
| 341 |
+
weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
|
| 342 |
+
weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim)
|
| 343 |
+
|
| 344 |
+
weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]]
|
| 345 |
+
weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp)
|
| 346 |
+
|
| 347 |
+
bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
|
| 348 |
+
bias = bias.view(bias.shape[0], 1, outp)
|
| 349 |
+
|
| 350 |
+
all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])
|
| 351 |
+
|
| 352 |
+
if self.toRGB:
|
| 353 |
+
id = -1
|
| 354 |
+
inp, outp = self.INR_in_out[1], 3
|
| 355 |
+
weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
|
| 356 |
+
weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim)
|
| 357 |
+
|
| 358 |
+
weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]]
|
| 359 |
+
weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp)
|
| 360 |
+
|
| 361 |
+
bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
|
| 362 |
+
bias = bias.view(bias.shape[0], 1, outp)
|
| 363 |
+
|
| 364 |
+
all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])
|
| 365 |
+
|
| 366 |
+
return all_weight_bias
|
model/base/conv_autoencoder.py
ADDED
|
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision
|
| 3 |
+
from torch import nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import numpy as np
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
from .basic_blocks import ConvBlock, lineParams, convParams
|
| 9 |
+
from .ops import MaskedChannelAttention, FeaturesConnector
|
| 10 |
+
from .ops import PosEncodingNeRF, INRGAN_embed, RandomFourier, CIPS_embed
|
| 11 |
+
from utils import misc
|
| 12 |
+
from utils.misc import lin2img
|
| 13 |
+
from ..lut_transformation_net import build_lut_transform
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Sine(nn.Module):
|
| 17 |
+
def __init__(self):
|
| 18 |
+
super().__init__()
|
| 19 |
+
|
| 20 |
+
def forward(self, input):
|
| 21 |
+
return torch.sin(30 * input)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Leaky_relu(nn.Module):
|
| 25 |
+
def __init__(self):
|
| 26 |
+
super().__init__()
|
| 27 |
+
|
| 28 |
+
def forward(self, input):
|
| 29 |
+
return torch.nn.functional.leaky_relu(input, 0.01, inplace=True)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def select_activation(type):
|
| 33 |
+
if type == 'sine':
|
| 34 |
+
return Sine()
|
| 35 |
+
elif type == 'leakyrelu_pe':
|
| 36 |
+
return Leaky_relu()
|
| 37 |
+
else:
|
| 38 |
+
raise NotImplementedError
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class ConvEncoder(nn.Module):
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
depth, ch,
|
| 45 |
+
norm_layer, batchnorm_from, max_channels,
|
| 46 |
+
backbone_from, backbone_channels=None, backbone_mode='', INRDecode=False
|
| 47 |
+
):
|
| 48 |
+
super(ConvEncoder, self).__init__()
|
| 49 |
+
self.depth = depth
|
| 50 |
+
self.INRDecode = INRDecode
|
| 51 |
+
self.backbone_from = backbone_from
|
| 52 |
+
backbone_channels = [] if backbone_channels is None else backbone_channels[::-1]
|
| 53 |
+
|
| 54 |
+
in_channels = 4
|
| 55 |
+
out_channels = ch
|
| 56 |
+
|
| 57 |
+
self.block0 = ConvBlock(in_channels, out_channels, norm_layer=norm_layer if batchnorm_from == 0 else None)
|
| 58 |
+
self.block1 = ConvBlock(out_channels, out_channels, norm_layer=norm_layer if 0 <= batchnorm_from <= 1 else None)
|
| 59 |
+
self.blocks_channels = [out_channels, out_channels]
|
| 60 |
+
|
| 61 |
+
self.blocks_connected = nn.ModuleDict()
|
| 62 |
+
self.connectors = nn.ModuleDict()
|
| 63 |
+
for block_i in range(2, depth):
|
| 64 |
+
if block_i % 2:
|
| 65 |
+
in_channels = out_channels
|
| 66 |
+
else:
|
| 67 |
+
in_channels, out_channels = out_channels, min(2 * out_channels, max_channels)
|
| 68 |
+
|
| 69 |
+
if 0 <= backbone_from <= block_i and len(backbone_channels):
|
| 70 |
+
if INRDecode:
|
| 71 |
+
self.blocks_connected[f'block{block_i}_decode'] = ConvBlock(
|
| 72 |
+
in_channels, out_channels,
|
| 73 |
+
norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None,
|
| 74 |
+
padding=int(block_i < depth - 1)
|
| 75 |
+
)
|
| 76 |
+
self.blocks_channels += [out_channels]
|
| 77 |
+
stage_channels = backbone_channels.pop()
|
| 78 |
+
connector = FeaturesConnector(backbone_mode, in_channels, stage_channels, in_channels)
|
| 79 |
+
self.connectors[f'connector{block_i}'] = connector
|
| 80 |
+
in_channels = connector.output_channels
|
| 81 |
+
|
| 82 |
+
self.blocks_connected[f'block{block_i}'] = ConvBlock(
|
| 83 |
+
in_channels, out_channels,
|
| 84 |
+
norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None,
|
| 85 |
+
padding=int(block_i < depth - 1)
|
| 86 |
+
)
|
| 87 |
+
self.blocks_channels += [out_channels]
|
| 88 |
+
|
| 89 |
+
def forward(self, x, backbone_features):
|
| 90 |
+
backbone_features = [] if backbone_features is None else backbone_features[::-1]
|
| 91 |
+
|
| 92 |
+
outputs = [self.block0(x)]
|
| 93 |
+
outputs += [self.block1(outputs[-1])]
|
| 94 |
+
|
| 95 |
+
for block_i in range(2, self.depth):
|
| 96 |
+
output = outputs[-1]
|
| 97 |
+
connector_name = f'connector{block_i}'
|
| 98 |
+
if connector_name in self.connectors:
|
| 99 |
+
if self.INRDecode:
|
| 100 |
+
block = self.blocks_connected[f'block{block_i}_decode']
|
| 101 |
+
outputs += [block(output)]
|
| 102 |
+
|
| 103 |
+
stage_features = backbone_features.pop()
|
| 104 |
+
connector = self.connectors[connector_name]
|
| 105 |
+
output = connector(output, stage_features)
|
| 106 |
+
block = self.blocks_connected[f'block{block_i}']
|
| 107 |
+
outputs += [block(output)]
|
| 108 |
+
|
| 109 |
+
return outputs[::-1]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class DeconvDecoder(nn.Module):
|
| 113 |
+
def __init__(self, depth, encoder_blocks_channels, norm_layer, attend_from=-1, image_fusion=False):
|
| 114 |
+
super(DeconvDecoder, self).__init__()
|
| 115 |
+
self.image_fusion = image_fusion
|
| 116 |
+
self.deconv_blocks = nn.ModuleList()
|
| 117 |
+
|
| 118 |
+
in_channels = encoder_blocks_channels.pop()
|
| 119 |
+
out_channels = in_channels
|
| 120 |
+
for d in range(depth):
|
| 121 |
+
out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2
|
| 122 |
+
self.deconv_blocks.append(SEDeconvBlock(
|
| 123 |
+
in_channels, out_channels,
|
| 124 |
+
norm_layer=norm_layer,
|
| 125 |
+
padding=0 if d == 0 else 1,
|
| 126 |
+
with_se=0 <= attend_from <= d
|
| 127 |
+
))
|
| 128 |
+
in_channels = out_channels
|
| 129 |
+
|
| 130 |
+
if self.image_fusion:
|
| 131 |
+
self.conv_attention = nn.Conv2d(out_channels, 1, kernel_size=1)
|
| 132 |
+
self.to_rgb = nn.Conv2d(out_channels, 3, kernel_size=1)
|
| 133 |
+
|
| 134 |
+
def forward(self, encoder_outputs, image, mask=None):
|
| 135 |
+
output = encoder_outputs[0]
|
| 136 |
+
for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
|
| 137 |
+
output = block(output, mask)
|
| 138 |
+
output = output + skip_output
|
| 139 |
+
output = self.deconv_blocks[-1](output, mask)
|
| 140 |
+
|
| 141 |
+
if self.image_fusion:
|
| 142 |
+
attention_map = torch.sigmoid(3.0 * self.conv_attention(output))
|
| 143 |
+
output = attention_map * image + (1.0 - attention_map) * self.to_rgb(output)
|
| 144 |
+
else:
|
| 145 |
+
output = self.to_rgb(output)
|
| 146 |
+
|
| 147 |
+
return output
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class SEDeconvBlock(nn.Module):
|
| 151 |
+
def __init__(
|
| 152 |
+
self,
|
| 153 |
+
in_channels, out_channels,
|
| 154 |
+
kernel_size=4, stride=2, padding=1,
|
| 155 |
+
norm_layer=nn.BatchNorm2d, activation=nn.ELU,
|
| 156 |
+
with_se=False
|
| 157 |
+
):
|
| 158 |
+
super(SEDeconvBlock, self).__init__()
|
| 159 |
+
self.with_se = with_se
|
| 160 |
+
self.block = nn.Sequential(
|
| 161 |
+
nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
|
| 162 |
+
norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
|
| 163 |
+
activation(),
|
| 164 |
+
)
|
| 165 |
+
if self.with_se:
|
| 166 |
+
self.se = MaskedChannelAttention(out_channels)
|
| 167 |
+
|
| 168 |
+
def forward(self, x, mask=None):
|
| 169 |
+
out = self.block(x)
|
| 170 |
+
if self.with_se:
|
| 171 |
+
out = self.se(out, mask)
|
| 172 |
+
return out
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class INRDecoder(nn.Module):
|
| 176 |
+
def __init__(self, depth, encoder_blocks_channels, norm_layer, opt, attend_from):
|
| 177 |
+
super(INRDecoder, self).__init__()
|
| 178 |
+
self.INR_encoding = None
|
| 179 |
+
if opt.embedding_type == "PosEncodingNeRF":
|
| 180 |
+
self.INR_encoding = PosEncodingNeRF(in_features=2, sidelength=opt.input_size)
|
| 181 |
+
elif opt.embedding_type == "RandomFourier":
|
| 182 |
+
self.INR_encoding = RandomFourier(std_scale=10, embedding_length=64, device=opt.device)
|
| 183 |
+
elif opt.embedding_type == "CIPS_embed":
|
| 184 |
+
self.INR_encoding = CIPS_embed(size=opt.base_size, embedding_length=32)
|
| 185 |
+
elif opt.embedding_type == "INRGAN_embed":
|
| 186 |
+
self.INR_encoding = INRGAN_embed(resolution=opt.INR_input_size)
|
| 187 |
+
else:
|
| 188 |
+
raise NotImplementedError
|
| 189 |
+
encoder_blocks_channels = encoder_blocks_channels[::-1]
|
| 190 |
+
max_hidden_mlp_num = attend_from + 1
|
| 191 |
+
self.opt = opt
|
| 192 |
+
self.max_hidden_mlp_num = max_hidden_mlp_num
|
| 193 |
+
self.content_mlp_blocks = nn.ModuleDict()
|
| 194 |
+
for n in range(max_hidden_mlp_num):
|
| 195 |
+
if n != max_hidden_mlp_num - 1:
|
| 196 |
+
self.content_mlp_blocks[f"block{n}"] = convParams(encoder_blocks_channels.pop(),
|
| 197 |
+
[self.INR_encoding.out_dim + opt.INR_MLP_dim + (
|
| 198 |
+
4 if opt.isMoreINRInput else 0), opt.INR_MLP_dim],
|
| 199 |
+
opt, n + 1)
|
| 200 |
+
else:
|
| 201 |
+
self.content_mlp_blocks[f"block{n}"] = convParams(encoder_blocks_channels.pop(),
|
| 202 |
+
[self.INR_encoding.out_dim + (
|
| 203 |
+
4 if opt.isMoreINRInput else 0), opt.INR_MLP_dim],
|
| 204 |
+
opt, n + 1)
|
| 205 |
+
|
| 206 |
+
self.deconv_blocks = nn.ModuleList()
|
| 207 |
+
|
| 208 |
+
encoder_blocks_channels = encoder_blocks_channels[::-1]
|
| 209 |
+
in_channels = encoder_blocks_channels.pop()
|
| 210 |
+
out_channels = in_channels
|
| 211 |
+
for d in range(depth - attend_from):
|
| 212 |
+
out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2
|
| 213 |
+
self.deconv_blocks.append(SEDeconvBlock(
|
| 214 |
+
in_channels, out_channels,
|
| 215 |
+
norm_layer=norm_layer,
|
| 216 |
+
padding=0 if d == 0 else 1,
|
| 217 |
+
with_se=False
|
| 218 |
+
))
|
| 219 |
+
in_channels = out_channels
|
| 220 |
+
|
| 221 |
+
self.appearance_mlps = lineParams(out_channels, [opt.INR_MLP_dim, opt.INR_MLP_dim],
|
| 222 |
+
(opt.base_size // (2 ** (max_hidden_mlp_num - 1))) ** 2,
|
| 223 |
+
opt, 2, toRGB=True)
|
| 224 |
+
|
| 225 |
+
self.lut_transform = build_lut_transform(self.appearance_mlps.output_dim, opt.LUT_dim,
|
| 226 |
+
None, opt)
|
| 227 |
+
|
| 228 |
+
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
|
| 229 |
+
|
| 230 |
+
def forward(self, encoder_outputs, image=None, mask=None, coord_samples=None, start_proportion=None):
|
| 231 |
+
"""For full resolution, do split."""
|
| 232 |
+
if self.opt.hr_train and not (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt,
|
| 233 |
+
'split_resolution')) and self.opt.isFullRes:
|
| 234 |
+
return self.forward_fullResInference(encoder_outputs, image=image, mask=mask, coord_samples=coord_samples)
|
| 235 |
+
|
| 236 |
+
encoder_outputs = encoder_outputs[::-1]
|
| 237 |
+
mlp_output = None
|
| 238 |
+
waitToRGB = []
|
| 239 |
+
for n in range(self.max_hidden_mlp_num):
|
| 240 |
+
if not self.opt.hr_train:
|
| 241 |
+
coord = misc.get_mgrid(self.opt.INR_input_size // (2 ** (self.max_hidden_mlp_num - n - 1))) \
|
| 242 |
+
.unsqueeze(0).repeat(encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)
|
| 243 |
+
else:
|
| 244 |
+
if self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution'):
|
| 245 |
+
coord = coord_samples[self.max_hidden_mlp_num - n - 1].permute(0, 2, 3, 1).view(
|
| 246 |
+
encoder_outputs[0].shape[0], -1, 2)
|
| 247 |
+
else:
|
| 248 |
+
coord = misc.get_mgrid(
|
| 249 |
+
self.opt.INR_input_size // (2 ** (self.max_hidden_mlp_num - n - 1))).unsqueeze(0).repeat(
|
| 250 |
+
encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)
|
| 251 |
+
|
| 252 |
+
"""Whether to leverage multiple input to INR decoder. See Section 3.4 in the paper."""
|
| 253 |
+
if self.opt.isMoreINRInput:
|
| 254 |
+
if not self.opt.isFullRes or (
|
| 255 |
+
self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
|
| 256 |
+
res_h = res_w = np.sqrt(coord.shape[1]).astype(int)
|
| 257 |
+
else:
|
| 258 |
+
res_h = image.shape[-2] // (2 ** (self.max_hidden_mlp_num - n - 1))
|
| 259 |
+
res_w = image.shape[-1] // (2 ** (self.max_hidden_mlp_num - n - 1))
|
| 260 |
+
|
| 261 |
+
res_image = torchvision.transforms.Resize([res_h, res_w])(image)
|
| 262 |
+
res_mask = torchvision.transforms.Resize([res_h, res_w])(mask)
|
| 263 |
+
coord = torch.cat([self.INR_encoding(coord), res_image.view(*res_image.shape[:2], -1).permute(0, 2, 1),
|
| 264 |
+
res_mask.view(*res_mask.shape[:2], -1).permute(0, 2, 1)], dim=-1)
|
| 265 |
+
else:
|
| 266 |
+
coord = self.INR_encoding(coord)
|
| 267 |
+
|
| 268 |
+
"""============ LRIP structure, see Section 3.3 =============="""
|
| 269 |
+
|
| 270 |
+
"""Local MLPs."""
|
| 271 |
+
if n == 0:
|
| 272 |
+
mlp_output = self.mlp_process(coord, self.INR_encoding.out_dim + (4 if self.opt.isMoreINRInput else 0),
|
| 273 |
+
self.opt, content_mlp=self.content_mlp_blocks[
|
| 274 |
+
f"block{self.max_hidden_mlp_num - 1 - n}"](
|
| 275 |
+
encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n)), start_proportion=start_proportion)
|
| 276 |
+
waitToRGB.append(mlp_output[1])
|
| 277 |
+
else:
|
| 278 |
+
mlp_output = self.mlp_process(coord, self.opt.INR_MLP_dim + self.INR_encoding.out_dim + (
|
| 279 |
+
4 if self.opt.isMoreINRInput else 0), self.opt, base_feat=mlp_output[0],
|
| 280 |
+
content_mlp=self.content_mlp_blocks[
|
| 281 |
+
f"block{self.max_hidden_mlp_num - 1 - n}"](
|
| 282 |
+
encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n)),
|
| 283 |
+
start_proportion=start_proportion)
|
| 284 |
+
waitToRGB.append(mlp_output[1])
|
| 285 |
+
|
| 286 |
+
encoder_outputs = encoder_outputs[::-1]
|
| 287 |
+
output = encoder_outputs[0]
|
| 288 |
+
for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
|
| 289 |
+
output = block(output)
|
| 290 |
+
output = output + skip_output
|
| 291 |
+
output = self.deconv_blocks[-1](output)
|
| 292 |
+
|
| 293 |
+
"""Global MLPs."""
|
| 294 |
+
app_mlp, app_params = self.appearance_mlps(output)
|
| 295 |
+
harm_out = []
|
| 296 |
+
for id in range(len(waitToRGB)):
|
| 297 |
+
output = self.mlp_process(None, self.opt.INR_MLP_dim, self.opt, base_feat=waitToRGB[id],
|
| 298 |
+
appearance_mlp=app_mlp)
|
| 299 |
+
harm_out.append(output[0])
|
| 300 |
+
|
| 301 |
+
"""Optional 3D LUT prediction."""
|
| 302 |
+
fit_lut3d, lut_transform_image = self.lut_transform(image, app_params, None)
|
| 303 |
+
|
| 304 |
+
return harm_out, fit_lut3d, lut_transform_image
|
| 305 |
+
|
| 306 |
+
def mlp_process(self, coorinates, INR_input_dim, opt, base_feat=None, content_mlp=None, appearance_mlp=None,
|
| 307 |
+
resolution=None, start_proportion=None):
|
| 308 |
+
|
| 309 |
+
activation = select_activation(opt.activation)
|
| 310 |
+
|
| 311 |
+
output = None
|
| 312 |
+
|
| 313 |
+
if content_mlp is not None:
|
| 314 |
+
if base_feat is not None:
|
| 315 |
+
coorinates = torch.cat([coorinates, base_feat], dim=2)
|
| 316 |
+
coorinates = lin2img(coorinates, resolution)
|
| 317 |
+
|
| 318 |
+
if hasattr(opt, 'split_resolution'):
|
| 319 |
+
"""
|
| 320 |
+
Here we crop the needed MLPs according to the region of the split input patches.
|
| 321 |
+
Note that this only support inferencing square images.
|
| 322 |
+
"""
|
| 323 |
+
for idx in range(len(content_mlp)):
|
| 324 |
+
content_mlp[idx][0] = content_mlp[idx][0][:,
|
| 325 |
+
(content_mlp[idx][0].shape[1] * start_proportion[0]).int():(
|
| 326 |
+
content_mlp[idx][0].shape[1] * start_proportion[2]).int(),
|
| 327 |
+
(content_mlp[idx][0].shape[2] * start_proportion[1]).int():(
|
| 328 |
+
content_mlp[idx][0].shape[2] * start_proportion[3]).int(), :,
|
| 329 |
+
:]
|
| 330 |
+
content_mlp[idx][1] = content_mlp[idx][1][:,
|
| 331 |
+
(content_mlp[idx][1].shape[1] * start_proportion[0]).int():(
|
| 332 |
+
content_mlp[idx][1].shape[1] * start_proportion[2]).int(),
|
| 333 |
+
(content_mlp[idx][1].shape[2] * start_proportion[1]).int():(
|
| 334 |
+
content_mlp[idx][1].shape[2] * start_proportion[3]).int(),
|
| 335 |
+
:,
|
| 336 |
+
:]
|
| 337 |
+
k_h = coorinates.shape[2] // content_mlp[0][0].shape[1]
|
| 338 |
+
k_w = coorinates.shape[3] // content_mlp[0][0].shape[1]
|
| 339 |
+
bs = coorinates.shape[0]
|
| 340 |
+
h_lr = w_lr = content_mlp[0][0].shape[1]
|
| 341 |
+
nci = INR_input_dim
|
| 342 |
+
|
| 343 |
+
coorinates = coorinates.unfold(2, k_h, k_h).unfold(3, k_w, k_w)
|
| 344 |
+
coorinates = coorinates.permute(0, 2, 3, 4, 5, 1).contiguous().view(
|
| 345 |
+
bs, h_lr, w_lr, int(k_h * k_w), nci)
|
| 346 |
+
|
| 347 |
+
for id, layer in enumerate(content_mlp):
|
| 348 |
+
if id == 0:
|
| 349 |
+
output = torch.matmul(coorinates, layer[0]) + layer[1]
|
| 350 |
+
output = activation(output)
|
| 351 |
+
else:
|
| 352 |
+
output = torch.matmul(output, layer[0]) + layer[1]
|
| 353 |
+
output = activation(output)
|
| 354 |
+
|
| 355 |
+
output = output.view(bs, h_lr, w_lr, k_h, k_w, opt.INR_MLP_dim).permute(
|
| 356 |
+
0, 1, 3, 2, 4, 5).contiguous().view(bs, -1, opt.INR_MLP_dim)
|
| 357 |
+
|
| 358 |
+
output_large = self.up(lin2img(output))
|
| 359 |
+
|
| 360 |
+
return output_large.view(bs, -1, opt.INR_MLP_dim), output
|
| 361 |
+
|
| 362 |
+
k_h = coorinates.shape[2] // content_mlp[0][0].shape[1]
|
| 363 |
+
k_w = coorinates.shape[3] // content_mlp[0][0].shape[1]
|
| 364 |
+
bs = coorinates.shape[0]
|
| 365 |
+
h_lr = w_lr = content_mlp[0][0].shape[1]
|
| 366 |
+
nci = INR_input_dim
|
| 367 |
+
|
| 368 |
+
"""(evaluation or not HR training) and not fullres evaluation"""
|
| 369 |
+
if (not self.opt.hr_train or not (self.training or hasattr(self.opt, 'split_num'))) and not (
|
| 370 |
+
not (self.training or hasattr(self.opt, 'split_num')) and self.opt.isFullRes and self.opt.hr_train):
|
| 371 |
+
coorinates = coorinates.unfold(2, k_h, k_h).unfold(3, k_w, k_w)
|
| 372 |
+
coorinates = coorinates.permute(0, 2, 3, 4, 5, 1).contiguous().view(
|
| 373 |
+
bs, h_lr, w_lr, int(k_h * k_w), nci)
|
| 374 |
+
|
| 375 |
+
for id, layer in enumerate(content_mlp):
|
| 376 |
+
if id == 0:
|
| 377 |
+
output = torch.matmul(coorinates, layer[0]) + layer[1]
|
| 378 |
+
output = activation(output)
|
| 379 |
+
else:
|
| 380 |
+
output = torch.matmul(output, layer[0]) + layer[1]
|
| 381 |
+
output = activation(output)
|
| 382 |
+
|
| 383 |
+
output = output.view(bs, h_lr, w_lr, k_h, k_w, opt.INR_MLP_dim).permute(
|
| 384 |
+
0, 1, 3, 2, 4, 5).contiguous().view(bs, -1, opt.INR_MLP_dim)
|
| 385 |
+
|
| 386 |
+
output_large = self.up(lin2img(output))
|
| 387 |
+
|
| 388 |
+
return output_large.view(bs, -1, opt.INR_MLP_dim), output
|
| 389 |
+
else:
|
| 390 |
+
coorinates = coorinates.permute(0, 2, 3, 1)
|
| 391 |
+
for id, layer in enumerate(content_mlp):
|
| 392 |
+
weigt_shape = layer[0].shape
|
| 393 |
+
bias_shape = layer[1].shape
|
| 394 |
+
layer[0] = layer[0].view(*layer[0].shape[:-2], -1).permute(0, 3, 1, 2).contiguous()
|
| 395 |
+
layer[1] = layer[1].view(*layer[1].shape[:-2], -1).permute(0, 3, 1, 2).contiguous()
|
| 396 |
+
layer[0] = F.grid_sample(layer[0], coorinates[..., :2].flip(-1), mode='nearest' if True
|
| 397 |
+
else 'bilinear', padding_mode='border', align_corners=False)
|
| 398 |
+
layer[1] = F.grid_sample(layer[1], coorinates[..., :2].flip(-1), mode='nearest' if True
|
| 399 |
+
else 'bilinear', padding_mode='border', align_corners=False)
|
| 400 |
+
layer[0] = layer[0].permute(0, 2, 3, 1).contiguous().view(*coorinates.shape[:-1], *weigt_shape[-2:])
|
| 401 |
+
layer[1] = layer[1].permute(0, 2, 3, 1).contiguous().view(*coorinates.shape[:-1], *bias_shape[-2:])
|
| 402 |
+
|
| 403 |
+
if id == 0:
|
| 404 |
+
output = torch.matmul(coorinates.unsqueeze(-2), layer[0]) + layer[1]
|
| 405 |
+
output = activation(output)
|
| 406 |
+
else:
|
| 407 |
+
output = torch.matmul(output, layer[0]) + layer[1]
|
| 408 |
+
output = activation(output)
|
| 409 |
+
|
| 410 |
+
output = output.squeeze(-2).view(bs, -1, opt.INR_MLP_dim)
|
| 411 |
+
|
| 412 |
+
output_large = self.up(lin2img(output, resolution))
|
| 413 |
+
|
| 414 |
+
return output_large.view(bs, -1, opt.INR_MLP_dim), output
|
| 415 |
+
|
| 416 |
+
elif appearance_mlp is not None:
|
| 417 |
+
output = base_feat
|
| 418 |
+
genMask = None
|
| 419 |
+
for id, layer in enumerate(appearance_mlp):
|
| 420 |
+
if id != len(appearance_mlp) - 1:
|
| 421 |
+
output = torch.matmul(output, layer[0]) + layer[1]
|
| 422 |
+
output = activation(output)
|
| 423 |
+
else:
|
| 424 |
+
output = torch.matmul(output, layer[0]) + layer[1] # last layer
|
| 425 |
+
if opt.activation == 'leakyrelu_pe':
|
| 426 |
+
output = torch.tanh(output)
|
| 427 |
+
return lin2img(output, resolution), None
|
| 428 |
+
|
| 429 |
+
def forward_fullResInference(self, encoder_outputs, image=None, mask=None, coord_samples=None):
|
| 430 |
+
encoder_outputs = encoder_outputs[::-1]
|
| 431 |
+
mlp_output = None
|
| 432 |
+
res_w = image.shape[-1]
|
| 433 |
+
res_h = image.shape[-2]
|
| 434 |
+
coord = misc.get_mgrid([image.shape[-2], image.shape[-1]]).unsqueeze(0).repeat(
|
| 435 |
+
encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)
|
| 436 |
+
|
| 437 |
+
if self.opt.isMoreINRInput:
|
| 438 |
+
coord = torch.cat(
|
| 439 |
+
[self.INR_encoding(coord, (res_h, res_w)), image.view(*image.shape[:2], -1).permute(0, 2, 1),
|
| 440 |
+
mask.view(*mask.shape[:2], -1).permute(0, 2, 1)], dim=-1)
|
| 441 |
+
else:
|
| 442 |
+
coord = self.INR_encoding(coord, (res_h, res_w))
|
| 443 |
+
|
| 444 |
+
total = coord.clone()
|
| 445 |
+
|
| 446 |
+
interval = 10
|
| 447 |
+
all_intervals = math.ceil(res_h / interval)
|
| 448 |
+
divisible = True
|
| 449 |
+
if res_h / interval != res_h // interval:
|
| 450 |
+
divisible = False
|
| 451 |
+
|
| 452 |
+
for n in range(self.max_hidden_mlp_num):
|
| 453 |
+
accum_mlp_output = []
|
| 454 |
+
for line in range(all_intervals):
|
| 455 |
+
if not divisible and line == all_intervals - 1:
|
| 456 |
+
coord = total[:, line * interval * res_w:, :]
|
| 457 |
+
else:
|
| 458 |
+
coord = total[:, line * interval * res_w: (line + 1) * interval * res_w, :]
|
| 459 |
+
if n == 0:
|
| 460 |
+
accum_mlp_output.append(self.mlp_process(coord,
|
| 461 |
+
self.INR_encoding.out_dim + (
|
| 462 |
+
4 if self.opt.isMoreINRInput else 0),
|
| 463 |
+
self.opt, content_mlp=self.content_mlp_blocks[
|
| 464 |
+
f"block{self.max_hidden_mlp_num - 1 - n}"](
|
| 465 |
+
encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n) if line == all_intervals - 1 else
|
| 466 |
+
encoder_outputs[self.max_hidden_mlp_num - 1 - n]),
|
| 467 |
+
resolution=(interval,
|
| 468 |
+
res_w) if divisible or line != all_intervals - 1 else (
|
| 469 |
+
res_h - interval * (all_intervals - 1), res_w))[1])
|
| 470 |
+
|
| 471 |
+
else:
|
| 472 |
+
accum_mlp_output.append(self.mlp_process(coord, self.opt.INR_MLP_dim + self.INR_encoding.out_dim + (
|
| 473 |
+
4 if self.opt.isMoreINRInput else 0), self.opt, base_feat=mlp_output[0][:,
|
| 474 |
+
line * interval * res_w: (
|
| 475 |
+
line + 1) * interval * res_w,
|
| 476 |
+
:]
|
| 477 |
+
if divisible or line != all_intervals - 1 else mlp_output[0][:, line * interval * res_w:, :],
|
| 478 |
+
content_mlp=self.content_mlp_blocks[
|
| 479 |
+
f"block{self.max_hidden_mlp_num - 1 - n}"](
|
| 480 |
+
encoder_outputs.pop(
|
| 481 |
+
self.max_hidden_mlp_num - 1 - n) if line == all_intervals - 1 else
|
| 482 |
+
encoder_outputs[self.max_hidden_mlp_num - 1 - n]),
|
| 483 |
+
resolution=(interval,
|
| 484 |
+
res_w) if divisible or line != all_intervals - 1 else (
|
| 485 |
+
res_h - interval * (all_intervals - 1), res_w))[1])
|
| 486 |
+
|
| 487 |
+
accum_mlp_output = torch.cat(accum_mlp_output, dim=1)
|
| 488 |
+
mlp_output = [accum_mlp_output, accum_mlp_output]
|
| 489 |
+
|
| 490 |
+
encoder_outputs = encoder_outputs[::-1]
|
| 491 |
+
output = encoder_outputs[0]
|
| 492 |
+
for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
|
| 493 |
+
output = block(output)
|
| 494 |
+
output = output + skip_output
|
| 495 |
+
output = self.deconv_blocks[-1](output)
|
| 496 |
+
|
| 497 |
+
app_mlp, app_params = self.appearance_mlps(output)
|
| 498 |
+
harm_out = []
|
| 499 |
+
|
| 500 |
+
accum_mlp_output = []
|
| 501 |
+
for line in range(all_intervals):
|
| 502 |
+
if not divisible and line == all_intervals - 1:
|
| 503 |
+
base = mlp_output[1][:, line * interval * res_w:, :]
|
| 504 |
+
else:
|
| 505 |
+
base = mlp_output[1][:, line * interval * res_w: (line + 1) * interval * res_w, :]
|
| 506 |
+
|
| 507 |
+
accum_mlp_output.append(self.mlp_process(None, self.opt.INR_MLP_dim, self.opt, base_feat=base,
|
| 508 |
+
appearance_mlp=app_mlp,
|
| 509 |
+
resolution=(
|
| 510 |
+
interval,
|
| 511 |
+
res_w) if divisible or line != all_intervals - 1 else (
|
| 512 |
+
res_h - interval * (all_intervals - 1), res_w))[0])
|
| 513 |
+
|
| 514 |
+
accum_mlp_output = torch.cat(accum_mlp_output, dim=2)
|
| 515 |
+
harm_out.append(accum_mlp_output)
|
| 516 |
+
|
| 517 |
+
fit_lut3d, lut_transform_image = self.lut_transform(image, app_params, None)
|
| 518 |
+
|
| 519 |
+
return harm_out, fit_lut3d, lut_transform_image
|
model/base/ih_model.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
from .conv_autoencoder import ConvEncoder, DeconvDecoder, INRDecoder
|
| 6 |
+
|
| 7 |
+
from .ops import ScaleLayer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class IHModelWithBackbone(nn.Module):
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
model, backbone,
|
| 14 |
+
downsize_backbone_input=False,
|
| 15 |
+
mask_fusion='sum',
|
| 16 |
+
backbone_conv1_channels=64, opt=None
|
| 17 |
+
):
|
| 18 |
+
super(IHModelWithBackbone, self).__init__()
|
| 19 |
+
self.downsize_backbone_input = downsize_backbone_input
|
| 20 |
+
self.mask_fusion = mask_fusion
|
| 21 |
+
|
| 22 |
+
self.backbone = backbone
|
| 23 |
+
self.model = model
|
| 24 |
+
self.opt = opt
|
| 25 |
+
|
| 26 |
+
self.mask_conv = nn.Sequential(
|
| 27 |
+
nn.Conv2d(1, backbone_conv1_channels, kernel_size=3, stride=2, padding=1, bias=True),
|
| 28 |
+
ScaleLayer(init_value=0.1, lr_mult=1)
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
def forward(self, image, mask, coord=None, start_proportion=None):
|
| 32 |
+
if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
|
| 33 |
+
backbone_image = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image[0])
|
| 34 |
+
backbone_mask = torch.cat(
|
| 35 |
+
(torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0]),
|
| 36 |
+
1.0 - torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0])), dim=1)
|
| 37 |
+
else:
|
| 38 |
+
backbone_image = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image)
|
| 39 |
+
backbone_mask = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask),
|
| 40 |
+
1.0 - torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask)), dim=1)
|
| 41 |
+
|
| 42 |
+
backbone_mask_features = self.mask_conv(backbone_mask[:, :1])
|
| 43 |
+
backbone_features = self.backbone(backbone_image, backbone_mask, backbone_mask_features)
|
| 44 |
+
|
| 45 |
+
output = self.model(image, mask, backbone_features, coord=coord, start_proportion=start_proportion)
|
| 46 |
+
return output
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class DeepImageHarmonization(nn.Module):
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
depth,
|
| 53 |
+
norm_layer=nn.BatchNorm2d, batchnorm_from=0,
|
| 54 |
+
attend_from=-1,
|
| 55 |
+
image_fusion=False,
|
| 56 |
+
ch=64, max_channels=512,
|
| 57 |
+
backbone_from=-1, backbone_channels=None, backbone_mode='', opt=None
|
| 58 |
+
):
|
| 59 |
+
super(DeepImageHarmonization, self).__init__()
|
| 60 |
+
self.depth = depth
|
| 61 |
+
self.encoder = ConvEncoder(
|
| 62 |
+
depth, ch,
|
| 63 |
+
norm_layer, batchnorm_from, max_channels,
|
| 64 |
+
backbone_from, backbone_channels, backbone_mode, INRDecode=opt.INRDecode
|
| 65 |
+
)
|
| 66 |
+
self.opt = opt
|
| 67 |
+
if opt.INRDecode:
|
| 68 |
+
"See Table 2 in the paper to test with different INR decoders' structures."
|
| 69 |
+
self.decoder = INRDecoder(depth, self.encoder.blocks_channels, norm_layer, opt, backbone_from)
|
| 70 |
+
else:
|
| 71 |
+
"Baseline: https://github.com/SamsungLabs/image_harmonization"
|
| 72 |
+
self.decoder = DeconvDecoder(depth, self.encoder.blocks_channels, norm_layer, attend_from, image_fusion)
|
| 73 |
+
|
| 74 |
+
def forward(self, image, mask, backbone_features=None, coord=None, start_proportion=None):
|
| 75 |
+
if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
|
| 76 |
+
x = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image[0]),
|
| 77 |
+
torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0])), dim=1)
|
| 78 |
+
else:
|
| 79 |
+
x = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image),
|
| 80 |
+
torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask)), dim=1)
|
| 81 |
+
|
| 82 |
+
intermediates = self.encoder(x, backbone_features)
|
| 83 |
+
|
| 84 |
+
if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
|
| 85 |
+
output = self.decoder(intermediates, image[1], mask[1], coord_samples=coord, start_proportion=start_proportion)
|
| 86 |
+
else:
|
| 87 |
+
output = self.decoder(intermediates, image, mask)
|
| 88 |
+
return output
|
model/base/ops.py
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
import math
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SimpleInputFusion(nn.Module):
|
| 9 |
+
def __init__(self, add_ch=1, rgb_ch=3, ch=8, norm_layer=nn.BatchNorm2d):
|
| 10 |
+
super(SimpleInputFusion, self).__init__()
|
| 11 |
+
|
| 12 |
+
self.fusion_conv = nn.Sequential(
|
| 13 |
+
nn.Conv2d(in_channels=add_ch + rgb_ch, out_channels=ch, kernel_size=1),
|
| 14 |
+
nn.LeakyReLU(negative_slope=0.2),
|
| 15 |
+
norm_layer(ch),
|
| 16 |
+
nn.Conv2d(in_channels=ch, out_channels=rgb_ch, kernel_size=1),
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
def forward(self, image, additional_input):
|
| 20 |
+
return self.fusion_conv(torch.cat((image, additional_input), dim=1))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class MaskedChannelAttention(nn.Module):
|
| 24 |
+
def __init__(self, in_channels, *args, **kwargs):
|
| 25 |
+
super(MaskedChannelAttention, self).__init__()
|
| 26 |
+
self.global_max_pool = MaskedGlobalMaxPool2d()
|
| 27 |
+
self.global_avg_pool = FastGlobalAvgPool2d()
|
| 28 |
+
|
| 29 |
+
intermediate_channels_count = max(in_channels // 16, 8)
|
| 30 |
+
self.attention_transform = nn.Sequential(
|
| 31 |
+
nn.Linear(3 * in_channels, intermediate_channels_count),
|
| 32 |
+
nn.ReLU(inplace=True),
|
| 33 |
+
nn.Linear(intermediate_channels_count, in_channels),
|
| 34 |
+
nn.Sigmoid(),
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
def forward(self, x, mask):
|
| 38 |
+
if mask.shape[2:] != x.shape[:2]:
|
| 39 |
+
mask = nn.functional.interpolate(
|
| 40 |
+
mask, size=x.size()[-2:],
|
| 41 |
+
mode='bilinear', align_corners=True
|
| 42 |
+
)
|
| 43 |
+
pooled_x = torch.cat([
|
| 44 |
+
self.global_max_pool(x, mask),
|
| 45 |
+
self.global_avg_pool(x)
|
| 46 |
+
], dim=1)
|
| 47 |
+
channel_attention_weights = self.attention_transform(pooled_x)[..., None, None]
|
| 48 |
+
|
| 49 |
+
return channel_attention_weights * x
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class MaskedGlobalMaxPool2d(nn.Module):
|
| 53 |
+
def __init__(self):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.global_max_pool = FastGlobalMaxPool2d()
|
| 56 |
+
|
| 57 |
+
def forward(self, x, mask):
|
| 58 |
+
return torch.cat((
|
| 59 |
+
self.global_max_pool(x * mask),
|
| 60 |
+
self.global_max_pool(x * (1.0 - mask))
|
| 61 |
+
), dim=1)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class FastGlobalAvgPool2d(nn.Module):
|
| 65 |
+
def __init__(self):
|
| 66 |
+
super(FastGlobalAvgPool2d, self).__init__()
|
| 67 |
+
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
in_size = x.size()
|
| 70 |
+
return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class FastGlobalMaxPool2d(nn.Module):
|
| 74 |
+
def __init__(self):
|
| 75 |
+
super(FastGlobalMaxPool2d, self).__init__()
|
| 76 |
+
|
| 77 |
+
def forward(self, x):
|
| 78 |
+
in_size = x.size()
|
| 79 |
+
return x.view((in_size[0], in_size[1], -1)).max(dim=2)[0]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class ScaleLayer(nn.Module):
|
| 83 |
+
def __init__(self, init_value=1.0, lr_mult=1):
|
| 84 |
+
super().__init__()
|
| 85 |
+
self.lr_mult = lr_mult
|
| 86 |
+
self.scale = nn.Parameter(
|
| 87 |
+
torch.full((1,), init_value / lr_mult, dtype=torch.float32)
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
def forward(self, x):
|
| 91 |
+
scale = torch.abs(self.scale * self.lr_mult)
|
| 92 |
+
return x * scale
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class FeaturesConnector(nn.Module):
|
| 96 |
+
def __init__(self, mode, in_channels, feature_channels, out_channels):
|
| 97 |
+
super(FeaturesConnector, self).__init__()
|
| 98 |
+
self.mode = mode if feature_channels else ''
|
| 99 |
+
|
| 100 |
+
if self.mode == 'catc':
|
| 101 |
+
self.reduce_conv = nn.Conv2d(in_channels + feature_channels, out_channels, kernel_size=1)
|
| 102 |
+
elif self.mode == 'sum':
|
| 103 |
+
self.reduce_conv = nn.Conv2d(feature_channels, out_channels, kernel_size=1)
|
| 104 |
+
|
| 105 |
+
self.output_channels = out_channels if self.mode != 'cat' else in_channels + feature_channels
|
| 106 |
+
|
| 107 |
+
def forward(self, x, features):
|
| 108 |
+
if self.mode == 'cat':
|
| 109 |
+
return torch.cat((x, features), 1)
|
| 110 |
+
if self.mode == 'catc':
|
| 111 |
+
return self.reduce_conv(torch.cat((x, features), 1))
|
| 112 |
+
if self.mode == 'sum':
|
| 113 |
+
return self.reduce_conv(features) + x
|
| 114 |
+
return x
|
| 115 |
+
|
| 116 |
+
def extra_repr(self):
|
| 117 |
+
return self.mode
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class PosEncodingNeRF(nn.Module):
|
| 121 |
+
def __init__(self, in_features, sidelength=None, fn_samples=None, use_nyquist=True):
|
| 122 |
+
super().__init__()
|
| 123 |
+
|
| 124 |
+
self.in_features = in_features
|
| 125 |
+
|
| 126 |
+
if self.in_features == 3:
|
| 127 |
+
self.num_frequencies = 10
|
| 128 |
+
elif self.in_features == 2:
|
| 129 |
+
assert sidelength is not None
|
| 130 |
+
if isinstance(sidelength, int):
|
| 131 |
+
sidelength = (sidelength, sidelength)
|
| 132 |
+
self.num_frequencies = 4
|
| 133 |
+
if use_nyquist:
|
| 134 |
+
self.num_frequencies = self.get_num_frequencies_nyquist(min(sidelength[0], sidelength[1]))
|
| 135 |
+
elif self.in_features == 1:
|
| 136 |
+
assert fn_samples is not None
|
| 137 |
+
self.num_frequencies = 4
|
| 138 |
+
if use_nyquist:
|
| 139 |
+
self.num_frequencies = self.get_num_frequencies_nyquist(fn_samples)
|
| 140 |
+
|
| 141 |
+
self.out_dim = in_features + 2 * in_features * self.num_frequencies
|
| 142 |
+
|
| 143 |
+
def get_num_frequencies_nyquist(self, samples):
|
| 144 |
+
nyquist_rate = 1 / (2 * (2 * 1 / samples))
|
| 145 |
+
return int(math.floor(math.log(nyquist_rate, 2)))
|
| 146 |
+
|
| 147 |
+
def forward(self, coords):
|
| 148 |
+
coords = coords.view(coords.shape[0], -1, self.in_features)
|
| 149 |
+
|
| 150 |
+
coords_pos_enc = coords
|
| 151 |
+
for i in range(self.num_frequencies):
|
| 152 |
+
for j in range(self.in_features):
|
| 153 |
+
c = coords[..., j]
|
| 154 |
+
|
| 155 |
+
sin = torch.unsqueeze(torch.sin((2 ** i) * np.pi * c), -1)
|
| 156 |
+
cos = torch.unsqueeze(torch.cos((2 ** i) * np.pi * c), -1)
|
| 157 |
+
|
| 158 |
+
coords_pos_enc = torch.cat((coords_pos_enc, sin, cos), axis=-1)
|
| 159 |
+
|
| 160 |
+
return coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class RandomFourier(nn.Module):
|
| 164 |
+
def __init__(self, std_scale, embedding_length, device):
|
| 165 |
+
super().__init__()
|
| 166 |
+
|
| 167 |
+
self.embed = torch.normal(0, 1, (2, embedding_length)) * std_scale
|
| 168 |
+
self.embed = self.embed.to(device)
|
| 169 |
+
|
| 170 |
+
self.out_dim = embedding_length * 2 + 2
|
| 171 |
+
|
| 172 |
+
def forward(self, coords):
|
| 173 |
+
coords_pos_enc = torch.cat([torch.sin(torch.matmul(2 * np.pi * coords, self.embed)),
|
| 174 |
+
torch.cos(torch.matmul(2 * np.pi * coords, self.embed))], dim=-1)
|
| 175 |
+
|
| 176 |
+
return torch.cat([coords, coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim)], dim=-1)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class CIPS_embed(nn.Module):
|
| 180 |
+
def __init__(self, size, embedding_length):
|
| 181 |
+
super().__init__()
|
| 182 |
+
self.fourier_embed = ConstantInput(size, embedding_length)
|
| 183 |
+
self.predict_embed = Predict_embed(embedding_length)
|
| 184 |
+
self.out_dim = embedding_length * 2 + 2
|
| 185 |
+
|
| 186 |
+
def forward(self, coord, res=None):
|
| 187 |
+
x = self.predict_embed(coord)
|
| 188 |
+
y = self.fourier_embed(x, coord, res)
|
| 189 |
+
|
| 190 |
+
return torch.cat([coord, x, y], dim=-1)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class Predict_embed(nn.Module):
|
| 194 |
+
def __init__(self, embedding_length):
|
| 195 |
+
super(Predict_embed, self).__init__()
|
| 196 |
+
self.ffm = nn.Linear(2, embedding_length, bias=True)
|
| 197 |
+
nn.init.uniform_(self.ffm.weight, -np.sqrt(9 / 2), np.sqrt(9 / 2))
|
| 198 |
+
|
| 199 |
+
def forward(self, x):
|
| 200 |
+
x = self.ffm(x)
|
| 201 |
+
x = torch.sin(x)
|
| 202 |
+
return x
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class ConstantInput(nn.Module):
|
| 206 |
+
def __init__(self, size, channel):
|
| 207 |
+
super().__init__()
|
| 208 |
+
|
| 209 |
+
self.input = nn.Parameter(torch.randn(1, size ** 2, channel))
|
| 210 |
+
|
| 211 |
+
def forward(self, input, coord, resolution=None):
|
| 212 |
+
batch = input.shape[0]
|
| 213 |
+
out = self.input.repeat(batch, 1, 1)
|
| 214 |
+
|
| 215 |
+
if coord.shape[1] != self.input.shape[1]:
|
| 216 |
+
x = out.permute(0, 2, 1).contiguous().view(batch, self.input.shape[-1],
|
| 217 |
+
int(self.input.shape[1] ** 0.5), int(self.input.shape[1] ** 0.5))
|
| 218 |
+
|
| 219 |
+
if resolution is None:
|
| 220 |
+
grid = coord.view(coord.shape[0], int(coord.shape[1] ** 0.5), int(coord.shape[1] ** 0.5), coord.shape[-1])
|
| 221 |
+
else:
|
| 222 |
+
grid = coord.view(coord.shape[0], *resolution, coord.shape[-1])
|
| 223 |
+
|
| 224 |
+
out = F.grid_sample(x, grid.flip(-1), mode='bilinear', padding_mode='border', align_corners=True)
|
| 225 |
+
|
| 226 |
+
out = out.permute(0, 2, 3, 1).contiguous().view(batch, -1, self.input.shape[-1])
|
| 227 |
+
|
| 228 |
+
return out
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class INRGAN_embed(nn.Module):
|
| 232 |
+
def __init__(self, resolution: int, w_dim=None):
|
| 233 |
+
super().__init__()
|
| 234 |
+
|
| 235 |
+
self.resolution = resolution
|
| 236 |
+
self.res_cfg = {"log_emb_size": 32,
|
| 237 |
+
"random_emb_size": 32,
|
| 238 |
+
"const_emb_size": 64,
|
| 239 |
+
"use_cosine": True}
|
| 240 |
+
self.log_emb_size = self.res_cfg.get('log_emb_size', 0)
|
| 241 |
+
self.random_emb_size = self.res_cfg.get('random_emb_size', 0)
|
| 242 |
+
self.shared_emb_size = self.res_cfg.get('shared_emb_size', 0)
|
| 243 |
+
self.predictable_emb_size = self.res_cfg.get('predictable_emb_size', 0)
|
| 244 |
+
self.const_emb_size = self.res_cfg.get('const_emb_size', 0)
|
| 245 |
+
self.fourier_scale = self.res_cfg.get('fourier_scale', np.sqrt(10))
|
| 246 |
+
self.use_cosine = self.res_cfg.get('use_cosine', False)
|
| 247 |
+
|
| 248 |
+
if self.log_emb_size > 0:
|
| 249 |
+
self.register_buffer('log_basis', generate_logarithmic_basis(
|
| 250 |
+
resolution, self.log_emb_size, use_diagonal=self.res_cfg.get('use_diagonal', False)))
|
| 251 |
+
|
| 252 |
+
if self.random_emb_size > 0:
|
| 253 |
+
self.register_buffer('random_basis', self.sample_w_matrix((2, self.random_emb_size), self.fourier_scale))
|
| 254 |
+
|
| 255 |
+
if self.shared_emb_size > 0:
|
| 256 |
+
self.shared_basis = nn.Parameter(self.sample_w_matrix((2, self.shared_emb_size), self.fourier_scale))
|
| 257 |
+
|
| 258 |
+
if self.predictable_emb_size > 0:
|
| 259 |
+
self.W_size = self.predictable_emb_size * self.cfg.coord_dim
|
| 260 |
+
self.b_size = self.predictable_emb_size
|
| 261 |
+
self.affine = nn.Linear(w_dim, self.W_size + self.b_size)
|
| 262 |
+
|
| 263 |
+
if self.const_emb_size > 0:
|
| 264 |
+
self.const_embs = nn.Parameter(torch.randn(1, resolution ** 2, self.const_emb_size))
|
| 265 |
+
|
| 266 |
+
self.out_dim = self.get_total_dim() + 2
|
| 267 |
+
|
| 268 |
+
def sample_w_matrix(self, shape, scale: float):
|
| 269 |
+
return torch.randn(shape) * scale
|
| 270 |
+
|
| 271 |
+
def get_total_dim(self) -> int:
|
| 272 |
+
total_dim = 0
|
| 273 |
+
if self.log_emb_size > 0:
|
| 274 |
+
total_dim += self.log_basis.shape[0] * (2 if self.use_cosine else 1)
|
| 275 |
+
total_dim += self.random_emb_size * (2 if self.use_cosine else 1)
|
| 276 |
+
total_dim += self.shared_emb_size * (2 if self.use_cosine else 1)
|
| 277 |
+
total_dim += self.predictable_emb_size * (2 if self.use_cosine else 1)
|
| 278 |
+
total_dim += self.const_emb_size
|
| 279 |
+
|
| 280 |
+
return total_dim
|
| 281 |
+
|
| 282 |
+
def forward(self, raw_coords, w=None):
|
| 283 |
+
batch_size, img_size, in_channels = raw_coords.shape
|
| 284 |
+
|
| 285 |
+
raw_embs = []
|
| 286 |
+
|
| 287 |
+
if self.log_emb_size > 0:
|
| 288 |
+
log_bases = self.log_basis.unsqueeze(0).repeat(batch_size, 1, 1).permute(0, 2, 1)
|
| 289 |
+
raw_log_embs = torch.matmul(raw_coords, log_bases)
|
| 290 |
+
raw_embs.append(raw_log_embs)
|
| 291 |
+
|
| 292 |
+
if self.random_emb_size > 0:
|
| 293 |
+
random_bases = self.random_basis.unsqueeze(0).repeat(batch_size, 1, 1)
|
| 294 |
+
raw_random_embs = torch.matmul(raw_coords, random_bases)
|
| 295 |
+
raw_embs.append(raw_random_embs)
|
| 296 |
+
|
| 297 |
+
if self.shared_emb_size > 0:
|
| 298 |
+
shared_bases = self.shared_basis.unsqueeze(0).repeat(batch_size, 1, 1)
|
| 299 |
+
raw_shared_embs = torch.matmul(raw_coords, shared_bases)
|
| 300 |
+
raw_embs.append(raw_shared_embs)
|
| 301 |
+
|
| 302 |
+
if self.predictable_emb_size > 0:
|
| 303 |
+
mod = self.affine(w)
|
| 304 |
+
W = self.fourier_scale * mod[:, :self.W_size]
|
| 305 |
+
W = W.view(batch_size, self.cfg.coord_dim, self.predictable_emb_size)
|
| 306 |
+
bias = mod[:, self.W_size:].view(batch_size, 1, self.predictable_emb_size)
|
| 307 |
+
raw_predictable_embs = (torch.matmul(raw_coords, W) + bias)
|
| 308 |
+
raw_embs.append(raw_predictable_embs)
|
| 309 |
+
|
| 310 |
+
if len(raw_embs) > 0:
|
| 311 |
+
raw_embs = torch.cat(raw_embs, dim=-1)
|
| 312 |
+
raw_embs = raw_embs.contiguous()
|
| 313 |
+
out = raw_embs.sin()
|
| 314 |
+
|
| 315 |
+
if self.use_cosine:
|
| 316 |
+
out = torch.cat([out, raw_embs.cos()], dim=-1)
|
| 317 |
+
|
| 318 |
+
if self.const_emb_size > 0:
|
| 319 |
+
const_embs = self.const_embs.repeat([batch_size, 1, 1])
|
| 320 |
+
const_embs = const_embs
|
| 321 |
+
out = torch.cat([out, const_embs], dim=-1)
|
| 322 |
+
|
| 323 |
+
return torch.cat([raw_coords, out], dim=-1)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def generate_logarithmic_basis(
|
| 327 |
+
resolution,
|
| 328 |
+
max_num_feats,
|
| 329 |
+
remove_lowest_freq: bool = False,
|
| 330 |
+
use_diagonal: bool = True):
|
| 331 |
+
"""
|
| 332 |
+
Generates a directional logarithmic basis with the following directions:
|
| 333 |
+
- horizontal
|
| 334 |
+
- vertical
|
| 335 |
+
- main diagonal
|
| 336 |
+
- anti-diagonal
|
| 337 |
+
"""
|
| 338 |
+
max_num_feats_per_direction = np.ceil(np.log2(resolution)).astype(int)
|
| 339 |
+
bases = [
|
| 340 |
+
generate_horizontal_basis(max_num_feats_per_direction),
|
| 341 |
+
generate_vertical_basis(max_num_feats_per_direction),
|
| 342 |
+
]
|
| 343 |
+
|
| 344 |
+
if use_diagonal:
|
| 345 |
+
bases.extend([
|
| 346 |
+
generate_diag_main_basis(max_num_feats_per_direction),
|
| 347 |
+
generate_anti_diag_basis(max_num_feats_per_direction),
|
| 348 |
+
])
|
| 349 |
+
|
| 350 |
+
if remove_lowest_freq:
|
| 351 |
+
bases = [b[1:] for b in bases]
|
| 352 |
+
|
| 353 |
+
# If we do not fit into `max_num_feats`, then trying to remove the features in the order:
|
| 354 |
+
# 1) anti-diagonal 2) main-diagonal
|
| 355 |
+
# while (max_num_feats_per_direction * len(bases) > max_num_feats) and (len(bases) > 2):
|
| 356 |
+
# bases = bases[:-1]
|
| 357 |
+
|
| 358 |
+
basis = torch.cat(bases, dim=0)
|
| 359 |
+
|
| 360 |
+
# If we still do not fit, then let's remove each second feature,
|
| 361 |
+
# then each third, each forth and so on
|
| 362 |
+
# We cannot drop the whole horizontal or vertical direction since otherwise
|
| 363 |
+
# model won't be able to locate the position
|
| 364 |
+
# (unless the previously computed embeddings encode the position)
|
| 365 |
+
# while basis.shape[0] > max_num_feats:
|
| 366 |
+
# num_exceeding_feats = basis.shape[0] - max_num_feats
|
| 367 |
+
# basis = basis[::2]
|
| 368 |
+
|
| 369 |
+
assert basis.shape[0] <= max_num_feats, \
|
| 370 |
+
f"num_coord_feats > max_num_fixed_coord_feats: {basis.shape, max_num_feats}."
|
| 371 |
+
|
| 372 |
+
return basis
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def generate_horizontal_basis(num_feats: int):
|
| 376 |
+
return generate_wavefront_basis(num_feats, [0.0, 1.0], 4.0)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def generate_vertical_basis(num_feats: int):
|
| 380 |
+
return generate_wavefront_basis(num_feats, [1.0, 0.0], 4.0)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def generate_diag_main_basis(num_feats: int):
|
| 384 |
+
return generate_wavefront_basis(num_feats, [-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)], 4.0 * np.sqrt(2))
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def generate_anti_diag_basis(num_feats: int):
|
| 388 |
+
return generate_wavefront_basis(num_feats, [1.0 / np.sqrt(2), 1.0 / np.sqrt(2)], 4.0 * np.sqrt(2))
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def generate_wavefront_basis(num_feats: int, basis_block, period_length: float):
|
| 392 |
+
period_coef = 2.0 * np.pi / period_length
|
| 393 |
+
basis = torch.tensor([basis_block]).repeat(num_feats, 1) # [num_feats, 2]
|
| 394 |
+
powers = torch.tensor([2]).repeat(num_feats).pow(torch.arange(num_feats)).unsqueeze(1) # [num_feats, 1]
|
| 395 |
+
result = basis * powers * period_coef # [num_feats, 2]
|
| 396 |
+
|
| 397 |
+
return result.float()
|
model/build_model.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from .backbone import build_backbone
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class build_model(nn.Module):
|
| 6 |
+
def __init__(self, opt):
|
| 7 |
+
super().__init__()
|
| 8 |
+
|
| 9 |
+
self.opt = opt
|
| 10 |
+
self.backbone = build_backbone('baseline', opt)
|
| 11 |
+
|
| 12 |
+
def forward(self, composite_image, mask, fg_INR_coordinates, start_proportion=None):
|
| 13 |
+
if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
|
| 14 |
+
"""
|
| 15 |
+
For HR Training, due to the designed RSC strategy in Section 3.4 in the paper,
|
| 16 |
+
here we need to pass in the coordinates of the cropped regions.
|
| 17 |
+
"""
|
| 18 |
+
extracted_features = self.backbone(composite_image, mask, fg_INR_coordinates, start_proportion=start_proportion)
|
| 19 |
+
else:
|
| 20 |
+
extracted_features = self.backbone(composite_image, mask)
|
| 21 |
+
|
| 22 |
+
if self.opt.INRDecode:
|
| 23 |
+
return extracted_features
|
| 24 |
+
return None, None, extracted_features
|
model/hrnetv2/__init__.py
ADDED
|
File without changes
|
model/hrnetv2/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (181 Bytes). View file
|
|
|
model/hrnetv2/__pycache__/hrnet_ocr.cpython-38.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
model/hrnetv2/__pycache__/modifiers.cpython-38.pyc
ADDED
|
Binary file (704 Bytes). View file
|
|
|
model/hrnetv2/__pycache__/ocr.cpython-38.pyc
ADDED
|
Binary file (4.54 kB). View file
|
|
|
model/hrnetv2/__pycache__/resnetv1b.cpython-38.pyc
ADDED
|
Binary file (7.54 kB). View file
|
|
|
model/hrnetv2/hrnet_ocr.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import torch._utils
|
| 7 |
+
|
| 8 |
+
from .ocr import SpatialOCR_Module, SpatialGather_Module
|
| 9 |
+
from .resnetv1b import BasicBlockV1b, BottleneckV1b
|
| 10 |
+
|
| 11 |
+
relu_inplace = True
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class HighResolutionModule(nn.Module):
|
| 15 |
+
def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
|
| 16 |
+
num_channels, fuse_method,multi_scale_output=True,
|
| 17 |
+
norm_layer=nn.BatchNorm2d, align_corners=True):
|
| 18 |
+
super(HighResolutionModule, self).__init__()
|
| 19 |
+
self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)
|
| 20 |
+
|
| 21 |
+
self.num_inchannels = num_inchannels
|
| 22 |
+
self.fuse_method = fuse_method
|
| 23 |
+
self.num_branches = num_branches
|
| 24 |
+
self.norm_layer = norm_layer
|
| 25 |
+
self.align_corners = align_corners
|
| 26 |
+
|
| 27 |
+
self.multi_scale_output = multi_scale_output
|
| 28 |
+
|
| 29 |
+
self.branches = self._make_branches(
|
| 30 |
+
num_branches, blocks, num_blocks, num_channels)
|
| 31 |
+
self.fuse_layers = self._make_fuse_layers()
|
| 32 |
+
self.relu = nn.ReLU(inplace=relu_inplace)
|
| 33 |
+
|
| 34 |
+
def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
|
| 35 |
+
if num_branches != len(num_blocks):
|
| 36 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
|
| 37 |
+
num_branches, len(num_blocks))
|
| 38 |
+
raise ValueError(error_msg)
|
| 39 |
+
|
| 40 |
+
if num_branches != len(num_channels):
|
| 41 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
|
| 42 |
+
num_branches, len(num_channels))
|
| 43 |
+
raise ValueError(error_msg)
|
| 44 |
+
|
| 45 |
+
if num_branches != len(num_inchannels):
|
| 46 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
|
| 47 |
+
num_branches, len(num_inchannels))
|
| 48 |
+
raise ValueError(error_msg)
|
| 49 |
+
|
| 50 |
+
def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
|
| 51 |
+
stride=1):
|
| 52 |
+
downsample = None
|
| 53 |
+
if stride != 1 or \
|
| 54 |
+
self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
|
| 55 |
+
downsample = nn.Sequential(
|
| 56 |
+
nn.Conv2d(self.num_inchannels[branch_index],
|
| 57 |
+
num_channels[branch_index] * block.expansion,
|
| 58 |
+
kernel_size=1, stride=stride, bias=False),
|
| 59 |
+
self.norm_layer(num_channels[branch_index] * block.expansion),
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
layers = []
|
| 63 |
+
layers.append(block(self.num_inchannels[branch_index],
|
| 64 |
+
num_channels[branch_index], stride,
|
| 65 |
+
downsample=downsample, norm_layer=self.norm_layer))
|
| 66 |
+
self.num_inchannels[branch_index] = \
|
| 67 |
+
num_channels[branch_index] * block.expansion
|
| 68 |
+
for i in range(1, num_blocks[branch_index]):
|
| 69 |
+
layers.append(block(self.num_inchannels[branch_index],
|
| 70 |
+
num_channels[branch_index],
|
| 71 |
+
norm_layer=self.norm_layer))
|
| 72 |
+
|
| 73 |
+
return nn.Sequential(*layers)
|
| 74 |
+
|
| 75 |
+
def _make_branches(self, num_branches, block, num_blocks, num_channels):
|
| 76 |
+
branches = []
|
| 77 |
+
|
| 78 |
+
for i in range(num_branches):
|
| 79 |
+
branches.append(
|
| 80 |
+
self._make_one_branch(i, block, num_blocks, num_channels))
|
| 81 |
+
|
| 82 |
+
return nn.ModuleList(branches)
|
| 83 |
+
|
| 84 |
+
def _make_fuse_layers(self):
|
| 85 |
+
if self.num_branches == 1:
|
| 86 |
+
return None
|
| 87 |
+
|
| 88 |
+
num_branches = self.num_branches
|
| 89 |
+
num_inchannels = self.num_inchannels
|
| 90 |
+
fuse_layers = []
|
| 91 |
+
for i in range(num_branches if self.multi_scale_output else 1):
|
| 92 |
+
fuse_layer = []
|
| 93 |
+
for j in range(num_branches):
|
| 94 |
+
if j > i:
|
| 95 |
+
fuse_layer.append(nn.Sequential(
|
| 96 |
+
nn.Conv2d(in_channels=num_inchannels[j],
|
| 97 |
+
out_channels=num_inchannels[i],
|
| 98 |
+
kernel_size=1,
|
| 99 |
+
bias=False),
|
| 100 |
+
self.norm_layer(num_inchannels[i])))
|
| 101 |
+
elif j == i:
|
| 102 |
+
fuse_layer.append(None)
|
| 103 |
+
else:
|
| 104 |
+
conv3x3s = []
|
| 105 |
+
for k in range(i - j):
|
| 106 |
+
if k == i - j - 1:
|
| 107 |
+
num_outchannels_conv3x3 = num_inchannels[i]
|
| 108 |
+
conv3x3s.append(nn.Sequential(
|
| 109 |
+
nn.Conv2d(num_inchannels[j],
|
| 110 |
+
num_outchannels_conv3x3,
|
| 111 |
+
kernel_size=3, stride=2, padding=1, bias=False),
|
| 112 |
+
self.norm_layer(num_outchannels_conv3x3)))
|
| 113 |
+
else:
|
| 114 |
+
num_outchannels_conv3x3 = num_inchannels[j]
|
| 115 |
+
conv3x3s.append(nn.Sequential(
|
| 116 |
+
nn.Conv2d(num_inchannels[j],
|
| 117 |
+
num_outchannels_conv3x3,
|
| 118 |
+
kernel_size=3, stride=2, padding=1, bias=False),
|
| 119 |
+
self.norm_layer(num_outchannels_conv3x3),
|
| 120 |
+
nn.ReLU(inplace=relu_inplace)))
|
| 121 |
+
fuse_layer.append(nn.Sequential(*conv3x3s))
|
| 122 |
+
fuse_layers.append(nn.ModuleList(fuse_layer))
|
| 123 |
+
|
| 124 |
+
return nn.ModuleList(fuse_layers)
|
| 125 |
+
|
| 126 |
+
def get_num_inchannels(self):
|
| 127 |
+
return self.num_inchannels
|
| 128 |
+
|
| 129 |
+
def forward(self, x):
|
| 130 |
+
if self.num_branches == 1:
|
| 131 |
+
return [self.branches[0](x[0])]
|
| 132 |
+
|
| 133 |
+
for i in range(self.num_branches):
|
| 134 |
+
x[i] = self.branches[i](x[i])
|
| 135 |
+
|
| 136 |
+
x_fuse = []
|
| 137 |
+
for i in range(len(self.fuse_layers)):
|
| 138 |
+
y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
|
| 139 |
+
for j in range(1, self.num_branches):
|
| 140 |
+
if i == j:
|
| 141 |
+
y = y + x[j]
|
| 142 |
+
elif j > i:
|
| 143 |
+
width_output = x[i].shape[-1]
|
| 144 |
+
height_output = x[i].shape[-2]
|
| 145 |
+
y = y + F.interpolate(
|
| 146 |
+
self.fuse_layers[i][j](x[j]),
|
| 147 |
+
size=[height_output, width_output],
|
| 148 |
+
mode='bilinear', align_corners=self.align_corners)
|
| 149 |
+
else:
|
| 150 |
+
y = y + self.fuse_layers[i][j](x[j])
|
| 151 |
+
x_fuse.append(self.relu(y))
|
| 152 |
+
|
| 153 |
+
return x_fuse
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class HighResolutionNet(nn.Module):
|
| 157 |
+
def __init__(self, width, num_classes, ocr_width=256, small=False,
|
| 158 |
+
norm_layer=nn.BatchNorm2d, align_corners=True, opt=None):
|
| 159 |
+
super(HighResolutionNet, self).__init__()
|
| 160 |
+
self.opt = opt
|
| 161 |
+
self.norm_layer = norm_layer
|
| 162 |
+
self.width = width
|
| 163 |
+
self.ocr_width = ocr_width
|
| 164 |
+
self.ocr_on = ocr_width > 0
|
| 165 |
+
self.align_corners = align_corners
|
| 166 |
+
|
| 167 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
|
| 168 |
+
self.bn1 = norm_layer(64)
|
| 169 |
+
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
|
| 170 |
+
self.bn2 = norm_layer(64)
|
| 171 |
+
self.relu = nn.ReLU(inplace=relu_inplace)
|
| 172 |
+
|
| 173 |
+
num_blocks = 2 if small else 4
|
| 174 |
+
|
| 175 |
+
stage1_num_channels = 64
|
| 176 |
+
self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
|
| 177 |
+
stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
|
| 178 |
+
|
| 179 |
+
self.stage2_num_branches = 2
|
| 180 |
+
num_channels = [width, 2 * width]
|
| 181 |
+
num_inchannels = [
|
| 182 |
+
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
|
| 183 |
+
self.transition1 = self._make_transition_layer(
|
| 184 |
+
[stage1_out_channel], num_inchannels)
|
| 185 |
+
self.stage2, pre_stage_channels = self._make_stage(
|
| 186 |
+
BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
|
| 187 |
+
num_blocks=2 * [num_blocks], num_channels=num_channels)
|
| 188 |
+
|
| 189 |
+
self.stage3_num_branches = 3
|
| 190 |
+
num_channels = [width, 2 * width, 4 * width]
|
| 191 |
+
num_inchannels = [
|
| 192 |
+
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
|
| 193 |
+
self.transition2 = self._make_transition_layer(
|
| 194 |
+
pre_stage_channels, num_inchannels)
|
| 195 |
+
self.stage3, pre_stage_channels = self._make_stage(
|
| 196 |
+
BasicBlockV1b, num_inchannels=num_inchannels,
|
| 197 |
+
num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
|
| 198 |
+
num_blocks=3 * [num_blocks], num_channels=num_channels)
|
| 199 |
+
|
| 200 |
+
self.stage4_num_branches = 4
|
| 201 |
+
num_channels = [width, 2 * width, 4 * width, 8 * width]
|
| 202 |
+
num_inchannels = [
|
| 203 |
+
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
|
| 204 |
+
self.transition3 = self._make_transition_layer(
|
| 205 |
+
pre_stage_channels, num_inchannels)
|
| 206 |
+
self.stage4, pre_stage_channels = self._make_stage(
|
| 207 |
+
BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
|
| 208 |
+
num_branches=self.stage4_num_branches,
|
| 209 |
+
num_blocks=4 * [num_blocks], num_channels=num_channels)
|
| 210 |
+
|
| 211 |
+
if self.ocr_on:
|
| 212 |
+
last_inp_channels = np.int(np.sum(pre_stage_channels))
|
| 213 |
+
ocr_mid_channels = 2 * ocr_width
|
| 214 |
+
ocr_key_channels = ocr_width
|
| 215 |
+
|
| 216 |
+
self.conv3x3_ocr = nn.Sequential(
|
| 217 |
+
nn.Conv2d(last_inp_channels, ocr_mid_channels,
|
| 218 |
+
kernel_size=3, stride=1, padding=1),
|
| 219 |
+
norm_layer(ocr_mid_channels),
|
| 220 |
+
nn.ReLU(inplace=relu_inplace),
|
| 221 |
+
)
|
| 222 |
+
self.ocr_gather_head = SpatialGather_Module(num_classes)
|
| 223 |
+
|
| 224 |
+
self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
|
| 225 |
+
key_channels=ocr_key_channels,
|
| 226 |
+
out_channels=ocr_mid_channels,
|
| 227 |
+
scale=1,
|
| 228 |
+
dropout=0.05,
|
| 229 |
+
norm_layer=norm_layer,
|
| 230 |
+
align_corners=align_corners, opt=opt)
|
| 231 |
+
|
| 232 |
+
def _make_transition_layer(
|
| 233 |
+
self, num_channels_pre_layer, num_channels_cur_layer):
|
| 234 |
+
num_branches_cur = len(num_channels_cur_layer)
|
| 235 |
+
num_branches_pre = len(num_channels_pre_layer)
|
| 236 |
+
|
| 237 |
+
transition_layers = []
|
| 238 |
+
for i in range(num_branches_cur):
|
| 239 |
+
if i < num_branches_pre:
|
| 240 |
+
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
|
| 241 |
+
transition_layers.append(nn.Sequential(
|
| 242 |
+
nn.Conv2d(num_channels_pre_layer[i],
|
| 243 |
+
num_channels_cur_layer[i],
|
| 244 |
+
kernel_size=3,
|
| 245 |
+
stride=1,
|
| 246 |
+
padding=1,
|
| 247 |
+
bias=False),
|
| 248 |
+
self.norm_layer(num_channels_cur_layer[i]),
|
| 249 |
+
nn.ReLU(inplace=relu_inplace)))
|
| 250 |
+
else:
|
| 251 |
+
transition_layers.append(None)
|
| 252 |
+
else:
|
| 253 |
+
conv3x3s = []
|
| 254 |
+
for j in range(i + 1 - num_branches_pre):
|
| 255 |
+
inchannels = num_channels_pre_layer[-1]
|
| 256 |
+
outchannels = num_channels_cur_layer[i] \
|
| 257 |
+
if j == i - num_branches_pre else inchannels
|
| 258 |
+
conv3x3s.append(nn.Sequential(
|
| 259 |
+
nn.Conv2d(inchannels, outchannels,
|
| 260 |
+
kernel_size=3, stride=2, padding=1, bias=False),
|
| 261 |
+
self.norm_layer(outchannels),
|
| 262 |
+
nn.ReLU(inplace=relu_inplace)))
|
| 263 |
+
transition_layers.append(nn.Sequential(*conv3x3s))
|
| 264 |
+
|
| 265 |
+
return nn.ModuleList(transition_layers)
|
| 266 |
+
|
| 267 |
+
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
|
| 268 |
+
downsample = None
|
| 269 |
+
if stride != 1 or inplanes != planes * block.expansion:
|
| 270 |
+
downsample = nn.Sequential(
|
| 271 |
+
nn.Conv2d(inplanes, planes * block.expansion,
|
| 272 |
+
kernel_size=1, stride=stride, bias=False),
|
| 273 |
+
self.norm_layer(planes * block.expansion),
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
layers = []
|
| 277 |
+
layers.append(block(inplanes, planes, stride,
|
| 278 |
+
downsample=downsample, norm_layer=self.norm_layer))
|
| 279 |
+
inplanes = planes * block.expansion
|
| 280 |
+
for i in range(1, blocks):
|
| 281 |
+
layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
|
| 282 |
+
|
| 283 |
+
return nn.Sequential(*layers)
|
| 284 |
+
|
| 285 |
+
def _make_stage(self, block, num_inchannels,
|
| 286 |
+
num_modules, num_branches, num_blocks, num_channels,
|
| 287 |
+
fuse_method='SUM',
|
| 288 |
+
multi_scale_output=True):
|
| 289 |
+
modules = []
|
| 290 |
+
for i in range(num_modules):
|
| 291 |
+
# multi_scale_output is only used last module
|
| 292 |
+
if not multi_scale_output and i == num_modules - 1:
|
| 293 |
+
reset_multi_scale_output = False
|
| 294 |
+
else:
|
| 295 |
+
reset_multi_scale_output = True
|
| 296 |
+
modules.append(
|
| 297 |
+
HighResolutionModule(num_branches,
|
| 298 |
+
block,
|
| 299 |
+
num_blocks,
|
| 300 |
+
num_inchannels,
|
| 301 |
+
num_channels,
|
| 302 |
+
fuse_method,
|
| 303 |
+
reset_multi_scale_output,
|
| 304 |
+
norm_layer=self.norm_layer,
|
| 305 |
+
align_corners=self.align_corners)
|
| 306 |
+
)
|
| 307 |
+
num_inchannels = modules[-1].get_num_inchannels()
|
| 308 |
+
|
| 309 |
+
return nn.Sequential(*modules), num_inchannels
|
| 310 |
+
|
| 311 |
+
def forward(self, x, mask=None, additional_features=None):
|
| 312 |
+
hrnet_feats = self.compute_hrnet_feats(x, additional_features)
|
| 313 |
+
if not self.ocr_on:
|
| 314 |
+
return hrnet_feats,
|
| 315 |
+
|
| 316 |
+
ocr_feats = self.conv3x3_ocr(hrnet_feats)
|
| 317 |
+
mask = nn.functional.interpolate(mask, size=ocr_feats.size()[2:], mode='bilinear', align_corners=True)
|
| 318 |
+
context = self.ocr_gather_head(ocr_feats, mask)
|
| 319 |
+
ocr_feats = self.ocr_distri_head(ocr_feats, context)
|
| 320 |
+
return ocr_feats,
|
| 321 |
+
|
| 322 |
+
def compute_hrnet_feats(self, x, additional_features, return_list=False):
|
| 323 |
+
x = self.compute_pre_stage_features(x, additional_features)
|
| 324 |
+
x = self.layer1(x)
|
| 325 |
+
|
| 326 |
+
x_list = []
|
| 327 |
+
for i in range(self.stage2_num_branches):
|
| 328 |
+
if self.transition1[i] is not None:
|
| 329 |
+
x_list.append(self.transition1[i](x))
|
| 330 |
+
else:
|
| 331 |
+
x_list.append(x)
|
| 332 |
+
y_list = self.stage2(x_list)
|
| 333 |
+
|
| 334 |
+
x_list = []
|
| 335 |
+
for i in range(self.stage3_num_branches):
|
| 336 |
+
if self.transition2[i] is not None:
|
| 337 |
+
if i < self.stage2_num_branches:
|
| 338 |
+
x_list.append(self.transition2[i](y_list[i]))
|
| 339 |
+
else:
|
| 340 |
+
x_list.append(self.transition2[i](y_list[-1]))
|
| 341 |
+
else:
|
| 342 |
+
x_list.append(y_list[i])
|
| 343 |
+
y_list = self.stage3(x_list)
|
| 344 |
+
|
| 345 |
+
x_list = []
|
| 346 |
+
for i in range(self.stage4_num_branches):
|
| 347 |
+
if self.transition3[i] is not None:
|
| 348 |
+
if i < self.stage3_num_branches:
|
| 349 |
+
x_list.append(self.transition3[i](y_list[i]))
|
| 350 |
+
else:
|
| 351 |
+
x_list.append(self.transition3[i](y_list[-1]))
|
| 352 |
+
else:
|
| 353 |
+
x_list.append(y_list[i])
|
| 354 |
+
x = self.stage4(x_list)
|
| 355 |
+
|
| 356 |
+
if return_list:
|
| 357 |
+
return x
|
| 358 |
+
|
| 359 |
+
# Upsampling
|
| 360 |
+
x0_h, x0_w = x[0].size(2), x[0].size(3)
|
| 361 |
+
x1 = F.interpolate(x[1], size=(x0_h, x0_w),
|
| 362 |
+
mode='bilinear', align_corners=self.align_corners)
|
| 363 |
+
x2 = F.interpolate(x[2], size=(x0_h, x0_w),
|
| 364 |
+
mode='bilinear', align_corners=self.align_corners)
|
| 365 |
+
x3 = F.interpolate(x[3], size=(x0_h, x0_w),
|
| 366 |
+
mode='bilinear', align_corners=self.align_corners)
|
| 367 |
+
|
| 368 |
+
return torch.cat([x[0], x1, x2, x3], 1)
|
| 369 |
+
|
| 370 |
+
def compute_pre_stage_features(self, x, additional_features):
|
| 371 |
+
x = self.conv1(x)
|
| 372 |
+
x = self.bn1(x)
|
| 373 |
+
x = self.relu(x)
|
| 374 |
+
if additional_features is not None:
|
| 375 |
+
x = x + additional_features
|
| 376 |
+
x = self.conv2(x)
|
| 377 |
+
x = self.bn2(x)
|
| 378 |
+
return self.relu(x)
|
| 379 |
+
|
| 380 |
+
def load_pretrained_weights(self, pretrained_path=''):
|
| 381 |
+
model_dict = self.state_dict()
|
| 382 |
+
|
| 383 |
+
if not os.path.exists(pretrained_path):
|
| 384 |
+
print(f'\nFile "{pretrained_path}" does not exist.')
|
| 385 |
+
print('You need to specify the correct path to the pre-trained weights.\n'
|
| 386 |
+
'You can download the weights for HRNet from the repository:\n'
|
| 387 |
+
'https://github.com/HRNet/HRNet-Image-Classification')
|
| 388 |
+
exit(1)
|
| 389 |
+
pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'})
|
| 390 |
+
pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in
|
| 391 |
+
pretrained_dict.items()}
|
| 392 |
+
params_count = len(pretrained_dict)
|
| 393 |
+
|
| 394 |
+
pretrained_dict = {k: v for k, v in pretrained_dict.items()
|
| 395 |
+
if k in model_dict.keys()}
|
| 396 |
+
|
| 397 |
+
print(f'Loaded {len(pretrained_dict)} of {params_count} pretrained parameters for HRNet')
|
| 398 |
+
|
| 399 |
+
model_dict.update(pretrained_dict)
|
| 400 |
+
self.load_state_dict(model_dict)
|
model/hrnetv2/modifiers.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
class LRMult(object):
|
| 4 |
+
def __init__(self, lr_mult=1.):
|
| 5 |
+
self.lr_mult = lr_mult
|
| 6 |
+
|
| 7 |
+
def __call__(self, m):
|
| 8 |
+
if getattr(m, 'weight', None) is not None:
|
| 9 |
+
m.weight.lr_mult = self.lr_mult
|
| 10 |
+
if getattr(m, 'bias', None) is not None:
|
| 11 |
+
m.bias.lr_mult = self.lr_mult
|
model/hrnetv2/ocr.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch._utils
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class SpatialGather_Module(nn.Module):
|
| 8 |
+
"""
|
| 9 |
+
Aggregate the context features according to the initial
|
| 10 |
+
predicted probability distribution.
|
| 11 |
+
Employ the soft-weighted method to aggregate the context.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, cls_num=0, scale=1):
|
| 15 |
+
super(SpatialGather_Module, self).__init__()
|
| 16 |
+
self.cls_num = cls_num
|
| 17 |
+
self.scale = scale
|
| 18 |
+
|
| 19 |
+
def forward(self, feats, probs):
|
| 20 |
+
batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
|
| 21 |
+
probs = probs.view(batch_size, c, -1)
|
| 22 |
+
feats = feats.view(batch_size, feats.size(1), -1)
|
| 23 |
+
feats = feats.permute(0, 2, 1) # batch x hw x c
|
| 24 |
+
probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw
|
| 25 |
+
ocr_context = torch.matmul(probs, feats) \
|
| 26 |
+
.permute(0, 2, 1).unsqueeze(3).contiguous() # batch x k x c
|
| 27 |
+
return ocr_context
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class SpatialOCR_Module(nn.Module):
|
| 31 |
+
"""
|
| 32 |
+
Implementation of the OCR module:
|
| 33 |
+
We aggregate the global object representation to update the representation for each pixel.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(self,
|
| 37 |
+
in_channels,
|
| 38 |
+
key_channels,
|
| 39 |
+
out_channels,
|
| 40 |
+
scale=1,
|
| 41 |
+
dropout=0.1,
|
| 42 |
+
norm_layer=nn.BatchNorm2d,
|
| 43 |
+
align_corners=True, opt=None):
|
| 44 |
+
super(SpatialOCR_Module, self).__init__()
|
| 45 |
+
self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale,
|
| 46 |
+
norm_layer, align_corners)
|
| 47 |
+
_in_channels = 2 * in_channels
|
| 48 |
+
self.conv_bn_dropout = nn.Sequential(
|
| 49 |
+
nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
|
| 50 |
+
nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
|
| 51 |
+
nn.Dropout2d(dropout)
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
def forward(self, feats, proxy_feats):
|
| 55 |
+
context = self.object_context_block(feats, proxy_feats)
|
| 56 |
+
|
| 57 |
+
output = self.conv_bn_dropout(torch.cat([context, feats], 1))
|
| 58 |
+
|
| 59 |
+
return output
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class ObjectAttentionBlock2D(nn.Module):
|
| 63 |
+
'''
|
| 64 |
+
The basic implementation for object context block
|
| 65 |
+
Input:
|
| 66 |
+
N X C X H X W
|
| 67 |
+
Parameters:
|
| 68 |
+
in_channels : the dimension of the input feature map
|
| 69 |
+
key_channels : the dimension after the key/query transform
|
| 70 |
+
scale : choose the scale to downsample the input feature maps (save memory cost)
|
| 71 |
+
bn_type : specify the bn type
|
| 72 |
+
Return:
|
| 73 |
+
N X C X H X W
|
| 74 |
+
'''
|
| 75 |
+
|
| 76 |
+
def __init__(self,
|
| 77 |
+
in_channels,
|
| 78 |
+
key_channels,
|
| 79 |
+
scale=1,
|
| 80 |
+
norm_layer=nn.BatchNorm2d,
|
| 81 |
+
align_corners=True):
|
| 82 |
+
super(ObjectAttentionBlock2D, self).__init__()
|
| 83 |
+
self.scale = scale
|
| 84 |
+
self.in_channels = in_channels
|
| 85 |
+
self.key_channels = key_channels
|
| 86 |
+
self.align_corners = align_corners
|
| 87 |
+
|
| 88 |
+
self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
|
| 89 |
+
self.f_pixel = nn.Sequential(
|
| 90 |
+
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
|
| 91 |
+
kernel_size=1, stride=1, padding=0, bias=False),
|
| 92 |
+
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
|
| 93 |
+
nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
|
| 94 |
+
kernel_size=1, stride=1, padding=0, bias=False),
|
| 95 |
+
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
|
| 96 |
+
)
|
| 97 |
+
self.f_object = nn.Sequential(
|
| 98 |
+
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
|
| 99 |
+
kernel_size=1, stride=1, padding=0, bias=False),
|
| 100 |
+
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
|
| 101 |
+
nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
|
| 102 |
+
kernel_size=1, stride=1, padding=0, bias=False),
|
| 103 |
+
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
|
| 104 |
+
)
|
| 105 |
+
self.f_down = nn.Sequential(
|
| 106 |
+
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
|
| 107 |
+
kernel_size=1, stride=1, padding=0, bias=False),
|
| 108 |
+
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
|
| 109 |
+
)
|
| 110 |
+
self.f_up = nn.Sequential(
|
| 111 |
+
nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
|
| 112 |
+
kernel_size=1, stride=1, padding=0, bias=False),
|
| 113 |
+
nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True))
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
def forward(self, x, proxy):
|
| 117 |
+
batch_size, h, w = x.size(0), x.size(2), x.size(3)
|
| 118 |
+
if self.scale > 1:
|
| 119 |
+
x = self.pool(x)
|
| 120 |
+
|
| 121 |
+
query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
|
| 122 |
+
query = query.permute(0, 2, 1)
|
| 123 |
+
key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
|
| 124 |
+
value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
|
| 125 |
+
value = value.permute(0, 2, 1)
|
| 126 |
+
|
| 127 |
+
sim_map = torch.matmul(query, key)
|
| 128 |
+
sim_map = (self.key_channels ** -.5) * sim_map
|
| 129 |
+
sim_map = F.softmax(sim_map, dim=-1)
|
| 130 |
+
|
| 131 |
+
# add bg context ...
|
| 132 |
+
context = torch.matmul(sim_map, value)
|
| 133 |
+
context = context.permute(0, 2, 1).contiguous()
|
| 134 |
+
context = context.view(batch_size, self.key_channels, *x.size()[2:])
|
| 135 |
+
context = self.f_up(context)
|
| 136 |
+
if self.scale > 1:
|
| 137 |
+
context = F.interpolate(input=context, size=(h, w),
|
| 138 |
+
mode='bilinear', align_corners=self.align_corners)
|
| 139 |
+
|
| 140 |
+
return context
|
model/hrnetv2/resnetv1b.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet'
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class BasicBlockV1b(nn.Module):
|
| 7 |
+
expansion = 1
|
| 8 |
+
|
| 9 |
+
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
|
| 10 |
+
previous_dilation=1, norm_layer=nn.BatchNorm2d):
|
| 11 |
+
super(BasicBlockV1b, self).__init__()
|
| 12 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
|
| 13 |
+
padding=dilation, dilation=dilation, bias=False)
|
| 14 |
+
self.bn1 = norm_layer(planes)
|
| 15 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
|
| 16 |
+
padding=previous_dilation, dilation=previous_dilation, bias=False)
|
| 17 |
+
self.bn2 = norm_layer(planes)
|
| 18 |
+
|
| 19 |
+
self.relu = nn.ReLU(inplace=True)
|
| 20 |
+
self.downsample = downsample
|
| 21 |
+
self.stride = stride
|
| 22 |
+
|
| 23 |
+
def forward(self, x):
|
| 24 |
+
residual = x
|
| 25 |
+
|
| 26 |
+
out = self.conv1(x)
|
| 27 |
+
out = self.bn1(out)
|
| 28 |
+
out = self.relu(out)
|
| 29 |
+
|
| 30 |
+
out = self.conv2(out)
|
| 31 |
+
out = self.bn2(out)
|
| 32 |
+
|
| 33 |
+
if self.downsample is not None:
|
| 34 |
+
residual = self.downsample(x)
|
| 35 |
+
|
| 36 |
+
out = out + residual
|
| 37 |
+
out = self.relu(out)
|
| 38 |
+
|
| 39 |
+
return out
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class BottleneckV1b(nn.Module):
|
| 43 |
+
expansion = 4
|
| 44 |
+
|
| 45 |
+
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
|
| 46 |
+
previous_dilation=1, norm_layer=nn.BatchNorm2d):
|
| 47 |
+
super(BottleneckV1b, self).__init__()
|
| 48 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
| 49 |
+
self.bn1 = norm_layer(planes)
|
| 50 |
+
|
| 51 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
| 52 |
+
padding=dilation, dilation=dilation, bias=False)
|
| 53 |
+
self.bn2 = norm_layer(planes)
|
| 54 |
+
|
| 55 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
|
| 56 |
+
self.bn3 = norm_layer(planes * self.expansion)
|
| 57 |
+
|
| 58 |
+
self.relu = nn.ReLU(inplace=True)
|
| 59 |
+
self.downsample = downsample
|
| 60 |
+
self.stride = stride
|
| 61 |
+
|
| 62 |
+
def forward(self, x):
|
| 63 |
+
residual = x
|
| 64 |
+
|
| 65 |
+
out = self.conv1(x)
|
| 66 |
+
out = self.bn1(out)
|
| 67 |
+
out = self.relu(out)
|
| 68 |
+
|
| 69 |
+
out = self.conv2(out)
|
| 70 |
+
out = self.bn2(out)
|
| 71 |
+
out = self.relu(out)
|
| 72 |
+
|
| 73 |
+
out = self.conv3(out)
|
| 74 |
+
out = self.bn3(out)
|
| 75 |
+
|
| 76 |
+
if self.downsample is not None:
|
| 77 |
+
residual = self.downsample(x)
|
| 78 |
+
|
| 79 |
+
out = out + residual
|
| 80 |
+
out = self.relu(out)
|
| 81 |
+
|
| 82 |
+
return out
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class ResNetV1b(nn.Module):
|
| 86 |
+
""" Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5.
|
| 87 |
+
|
| 88 |
+
Parameters
|
| 89 |
+
----------
|
| 90 |
+
block : Block
|
| 91 |
+
Class for the residual block. Options are BasicBlockV1, BottleneckV1.
|
| 92 |
+
layers : list of int
|
| 93 |
+
Numbers of layers in each block
|
| 94 |
+
classes : int, default 1000
|
| 95 |
+
Number of classification classes.
|
| 96 |
+
dilated : bool, default False
|
| 97 |
+
Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
|
| 98 |
+
typically used in Semantic Segmentation.
|
| 99 |
+
norm_layer : object
|
| 100 |
+
Normalization layer used (default: :class:`nn.BatchNorm2d`)
|
| 101 |
+
deep_stem : bool, default False
|
| 102 |
+
Whether to replace the 7x7 conv1 with 3 3x3 convolution layers.
|
| 103 |
+
avg_down : bool, default False
|
| 104 |
+
Whether to use average pooling for projection skip connection between stages/downsample.
|
| 105 |
+
final_drop : float, default 0.0
|
| 106 |
+
Dropout ratio before the final classification layer.
|
| 107 |
+
|
| 108 |
+
Reference:
|
| 109 |
+
- He, Kaiming, et al. "Deep residual learning for image recognition."
|
| 110 |
+
Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
|
| 111 |
+
|
| 112 |
+
- Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
|
| 113 |
+
"""
|
| 114 |
+
def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32,
|
| 115 |
+
avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d):
|
| 116 |
+
self.inplanes = stem_width*2 if deep_stem else 64
|
| 117 |
+
super(ResNetV1b, self).__init__()
|
| 118 |
+
if not deep_stem:
|
| 119 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
| 120 |
+
else:
|
| 121 |
+
self.conv1 = nn.Sequential(
|
| 122 |
+
nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
|
| 123 |
+
norm_layer(stem_width),
|
| 124 |
+
nn.ReLU(True),
|
| 125 |
+
nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
|
| 126 |
+
norm_layer(stem_width),
|
| 127 |
+
nn.ReLU(True),
|
| 128 |
+
nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False)
|
| 129 |
+
)
|
| 130 |
+
self.bn1 = norm_layer(self.inplanes)
|
| 131 |
+
self.relu = nn.ReLU(True)
|
| 132 |
+
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
|
| 133 |
+
self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down,
|
| 134 |
+
norm_layer=norm_layer)
|
| 135 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down,
|
| 136 |
+
norm_layer=norm_layer)
|
| 137 |
+
if dilated:
|
| 138 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2,
|
| 139 |
+
avg_down=avg_down, norm_layer=norm_layer)
|
| 140 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4,
|
| 141 |
+
avg_down=avg_down, norm_layer=norm_layer)
|
| 142 |
+
else:
|
| 143 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
| 144 |
+
avg_down=avg_down, norm_layer=norm_layer)
|
| 145 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
| 146 |
+
avg_down=avg_down, norm_layer=norm_layer)
|
| 147 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
| 148 |
+
self.drop = None
|
| 149 |
+
if final_drop > 0.0:
|
| 150 |
+
self.drop = nn.Dropout(final_drop)
|
| 151 |
+
self.fc = nn.Linear(512 * block.expansion, classes)
|
| 152 |
+
|
| 153 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
|
| 154 |
+
avg_down=False, norm_layer=nn.BatchNorm2d):
|
| 155 |
+
downsample = None
|
| 156 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 157 |
+
downsample = []
|
| 158 |
+
if avg_down:
|
| 159 |
+
if dilation == 1:
|
| 160 |
+
downsample.append(
|
| 161 |
+
nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False)
|
| 162 |
+
)
|
| 163 |
+
else:
|
| 164 |
+
downsample.append(
|
| 165 |
+
nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False)
|
| 166 |
+
)
|
| 167 |
+
downsample.extend([
|
| 168 |
+
nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
|
| 169 |
+
kernel_size=1, stride=1, bias=False),
|
| 170 |
+
norm_layer(planes * block.expansion)
|
| 171 |
+
])
|
| 172 |
+
downsample = nn.Sequential(*downsample)
|
| 173 |
+
else:
|
| 174 |
+
downsample = nn.Sequential(
|
| 175 |
+
nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
|
| 176 |
+
kernel_size=1, stride=stride, bias=False),
|
| 177 |
+
norm_layer(planes * block.expansion)
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
layers = []
|
| 181 |
+
if dilation in (1, 2):
|
| 182 |
+
layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample,
|
| 183 |
+
previous_dilation=dilation, norm_layer=norm_layer))
|
| 184 |
+
elif dilation == 4:
|
| 185 |
+
layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample,
|
| 186 |
+
previous_dilation=dilation, norm_layer=norm_layer))
|
| 187 |
+
else:
|
| 188 |
+
raise RuntimeError("=> unknown dilation size: {}".format(dilation))
|
| 189 |
+
|
| 190 |
+
self.inplanes = planes * block.expansion
|
| 191 |
+
for _ in range(1, blocks):
|
| 192 |
+
layers.append(block(self.inplanes, planes, dilation=dilation,
|
| 193 |
+
previous_dilation=dilation, norm_layer=norm_layer))
|
| 194 |
+
|
| 195 |
+
return nn.Sequential(*layers)
|
| 196 |
+
|
| 197 |
+
def forward(self, x):
|
| 198 |
+
x = self.conv1(x)
|
| 199 |
+
x = self.bn1(x)
|
| 200 |
+
x = self.relu(x)
|
| 201 |
+
x = self.maxpool(x)
|
| 202 |
+
|
| 203 |
+
x = self.layer1(x)
|
| 204 |
+
x = self.layer2(x)
|
| 205 |
+
x = self.layer3(x)
|
| 206 |
+
x = self.layer4(x)
|
| 207 |
+
|
| 208 |
+
x = self.avgpool(x)
|
| 209 |
+
x = x.view(x.size(0), -1)
|
| 210 |
+
if self.drop is not None:
|
| 211 |
+
x = self.drop(x)
|
| 212 |
+
x = self.fc(x)
|
| 213 |
+
|
| 214 |
+
return x
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def _safe_state_dict_filtering(orig_dict, model_dict_keys):
|
| 218 |
+
filtered_orig_dict = {}
|
| 219 |
+
for k, v in orig_dict.items():
|
| 220 |
+
if k in model_dict_keys:
|
| 221 |
+
filtered_orig_dict[k] = v
|
| 222 |
+
else:
|
| 223 |
+
print(f"[ERROR] Failed to load <{k}> in backbone")
|
| 224 |
+
return filtered_orig_dict
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def resnet34_v1b(pretrained=False, **kwargs):
|
| 228 |
+
model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
|
| 229 |
+
if pretrained:
|
| 230 |
+
model_dict = model.state_dict()
|
| 231 |
+
filtered_orig_dict = _safe_state_dict_filtering(
|
| 232 |
+
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(),
|
| 233 |
+
model_dict.keys()
|
| 234 |
+
)
|
| 235 |
+
model_dict.update(filtered_orig_dict)
|
| 236 |
+
model.load_state_dict(model_dict)
|
| 237 |
+
return model
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def resnet50_v1s(pretrained=False, **kwargs):
|
| 241 |
+
model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs)
|
| 242 |
+
if pretrained:
|
| 243 |
+
model_dict = model.state_dict()
|
| 244 |
+
filtered_orig_dict = _safe_state_dict_filtering(
|
| 245 |
+
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(),
|
| 246 |
+
model_dict.keys()
|
| 247 |
+
)
|
| 248 |
+
model_dict.update(filtered_orig_dict)
|
| 249 |
+
model.load_state_dict(model_dict)
|
| 250 |
+
return model
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def resnet101_v1s(pretrained=False, **kwargs):
|
| 254 |
+
model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs)
|
| 255 |
+
if pretrained:
|
| 256 |
+
model_dict = model.state_dict()
|
| 257 |
+
filtered_orig_dict = _safe_state_dict_filtering(
|
| 258 |
+
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(),
|
| 259 |
+
model_dict.keys()
|
| 260 |
+
)
|
| 261 |
+
model_dict.update(filtered_orig_dict)
|
| 262 |
+
model.load_state_dict(model_dict)
|
| 263 |
+
return model
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def resnet152_v1s(pretrained=False, **kwargs):
|
| 267 |
+
model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs)
|
| 268 |
+
if pretrained:
|
| 269 |
+
model_dict = model.state_dict()
|
| 270 |
+
filtered_orig_dict = _safe_state_dict_filtering(
|
| 271 |
+
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(),
|
| 272 |
+
model_dict.keys()
|
| 273 |
+
)
|
| 274 |
+
model_dict.update(filtered_orig_dict)
|
| 275 |
+
model.load_state_dict(model_dict)
|
| 276 |
+
return model
|
model/lut_transformation_net.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from utils.misc import normalize
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class build_lut_transform(nn.Module):
|
| 9 |
+
|
| 10 |
+
def __init__(self, input_dim, lut_dim, input_resolution, opt):
|
| 11 |
+
super().__init__()
|
| 12 |
+
|
| 13 |
+
self.lut_dim = lut_dim
|
| 14 |
+
self.opt = opt
|
| 15 |
+
|
| 16 |
+
# self.compress_layer = nn.Linear(input_resolution, 1)
|
| 17 |
+
|
| 18 |
+
self.transform_layers = nn.Sequential(
|
| 19 |
+
nn.Linear(input_dim, 3 * lut_dim ** 3, bias=True),
|
| 20 |
+
# nn.BatchNorm1d(3 * lut_dim ** 3, affine=False),
|
| 21 |
+
nn.ReLU(inplace=True),
|
| 22 |
+
nn.Linear(3 * lut_dim ** 3, 3 * lut_dim ** 3, bias=True),
|
| 23 |
+
)
|
| 24 |
+
self.transform_layers[-1].apply(lambda m: hyper_weight_init(m))
|
| 25 |
+
|
| 26 |
+
def forward(self, composite_image, fg_appearance_features, bg_appearance_features):
|
| 27 |
+
composite_image = normalize(composite_image, self.opt, 'inv')
|
| 28 |
+
|
| 29 |
+
features = fg_appearance_features
|
| 30 |
+
|
| 31 |
+
lut_params = self.transform_layers(features)
|
| 32 |
+
|
| 33 |
+
fit_3DLUT = lut_params.view(lut_params.shape[0], 3, self.lut_dim, self.lut_dim, self.lut_dim)
|
| 34 |
+
|
| 35 |
+
lut_transform_image = torch.stack(
|
| 36 |
+
[TrilinearInterpolation(lut, image)[0] for lut, image in zip(fit_3DLUT, composite_image)], dim=0)
|
| 37 |
+
|
| 38 |
+
return fit_3DLUT, normalize(lut_transform_image, self.opt)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def TrilinearInterpolation(LUT, img):
|
| 42 |
+
img = (img - 0.5) * 2.
|
| 43 |
+
|
| 44 |
+
img = img.unsqueeze(0).permute(0, 2, 3, 1)[:, None].flip(-1)
|
| 45 |
+
|
| 46 |
+
# Note that the coordinates in the grid_sample are inverse to LUT DHW, i.e., xyz is to WHD not DHW.
|
| 47 |
+
LUT = LUT[None]
|
| 48 |
+
|
| 49 |
+
# grid sample
|
| 50 |
+
result = F.grid_sample(LUT, img, mode='bilinear', padding_mode='border', align_corners=True)
|
| 51 |
+
|
| 52 |
+
# drop added dimensions and permute back
|
| 53 |
+
result = result[:, :, 0]
|
| 54 |
+
|
| 55 |
+
return result
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def hyper_weight_init(m):
|
| 59 |
+
if hasattr(m, 'weight'):
|
| 60 |
+
nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
|
| 61 |
+
m.weight.data = m.weight.data / 1.e2
|
| 62 |
+
|
| 63 |
+
if hasattr(m, 'bias'):
|
| 64 |
+
with torch.no_grad():
|
| 65 |
+
m.bias.uniform_(0., 1.)
|
pretrained_models/Resolution_1024_HAdobe5K.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4917e99cc20c2530b6d248d530368929c1784113d20365085b96bbb10860a2f8
|
| 3 |
+
size 477235439
|