Spaces:
Sleeping
Sleeping
Upload 33 files
Browse files- gfpgan/__init__.py +7 -0
- gfpgan/__pycache__/__init__.cpython-38.pyc +0 -0
- gfpgan/__pycache__/train.cpython-38.pyc +0 -0
- gfpgan/__pycache__/utils.cpython-38.pyc +0 -0
- gfpgan/__pycache__/version.cpython-38.pyc +0 -0
- gfpgan/archs/__init__.py +10 -0
- gfpgan/archs/__pycache__/__init__.cpython-38.pyc +0 -0
- gfpgan/archs/__pycache__/arcface_arch.cpython-38.pyc +0 -0
- gfpgan/archs/__pycache__/gfpgan_bilinear_arch.cpython-38.pyc +0 -0
- gfpgan/archs/__pycache__/gfpganv1_arch.cpython-38.pyc +0 -0
- gfpgan/archs/__pycache__/gfpganv1_clean_arch.cpython-38.pyc +0 -0
- gfpgan/archs/__pycache__/restoreformer_arch.cpython-38.pyc +0 -0
- gfpgan/archs/__pycache__/stylegan2_bilinear_arch.cpython-38.pyc +0 -0
- gfpgan/archs/__pycache__/stylegan2_clean_arch.cpython-38.pyc +0 -0
- gfpgan/archs/arcface_arch.py +245 -0
- gfpgan/archs/gfpgan_bilinear_arch.py +312 -0
- gfpgan/archs/gfpganv1_arch.py +439 -0
- gfpgan/archs/gfpganv1_clean_arch.py +324 -0
- gfpgan/archs/restoreformer_arch.py +658 -0
- gfpgan/archs/stylegan2_bilinear_arch.py +613 -0
- gfpgan/archs/stylegan2_clean_arch.py +368 -0
- gfpgan/data/__init__.py +10 -0
- gfpgan/data/__pycache__/__init__.cpython-38.pyc +0 -0
- gfpgan/data/__pycache__/ffhq_degradation_dataset.cpython-38.pyc +0 -0
- gfpgan/data/ffhq_degradation_dataset.py +230 -0
- gfpgan/models/__init__.py +10 -0
- gfpgan/models/__pycache__/__init__.cpython-38.pyc +0 -0
- gfpgan/models/__pycache__/gfpgan_model.cpython-38.pyc +0 -0
- gfpgan/models/gfpgan_model.py +579 -0
- gfpgan/train.py +11 -0
- gfpgan/utils.py +148 -0
- gfpgan/version.py +5 -0
- gfpgan/weights/README.md +3 -0
gfpgan/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
from .archs import *
|
| 3 |
+
from .data import *
|
| 4 |
+
from .models import *
|
| 5 |
+
from .utils import *
|
| 6 |
+
|
| 7 |
+
# from .version import *
|
gfpgan/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (218 Bytes). View file
|
|
|
gfpgan/__pycache__/train.cpython-38.pyc
ADDED
|
Binary file (418 Bytes). View file
|
|
|
gfpgan/__pycache__/utils.cpython-38.pyc
ADDED
|
Binary file (4.23 kB). View file
|
|
|
gfpgan/__pycache__/version.cpython-38.pyc
ADDED
|
Binary file (229 Bytes). View file
|
|
|
gfpgan/archs/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
from basicsr.utils import scandir
|
| 3 |
+
from os import path as osp
|
| 4 |
+
|
| 5 |
+
# automatically scan and import arch modules for registry
|
| 6 |
+
# scan all the files that end with '_arch.py' under the archs folder
|
| 7 |
+
arch_folder = osp.dirname(osp.abspath(__file__))
|
| 8 |
+
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
| 9 |
+
# import all the arch modules
|
| 10 |
+
_arch_modules = [importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames]
|
gfpgan/archs/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (709 Bytes). View file
|
|
|
gfpgan/archs/__pycache__/arcface_arch.cpython-38.pyc
ADDED
|
Binary file (7.4 kB). View file
|
|
|
gfpgan/archs/__pycache__/gfpgan_bilinear_arch.cpython-38.pyc
ADDED
|
Binary file (9.14 kB). View file
|
|
|
gfpgan/archs/__pycache__/gfpganv1_arch.cpython-38.pyc
ADDED
|
Binary file (12.9 kB). View file
|
|
|
gfpgan/archs/__pycache__/gfpganv1_clean_arch.cpython-38.pyc
ADDED
|
Binary file (9.57 kB). View file
|
|
|
gfpgan/archs/__pycache__/restoreformer_arch.cpython-38.pyc
ADDED
|
Binary file (14.3 kB). View file
|
|
|
gfpgan/archs/__pycache__/stylegan2_bilinear_arch.cpython-38.pyc
ADDED
|
Binary file (18 kB). View file
|
|
|
gfpgan/archs/__pycache__/stylegan2_clean_arch.cpython-38.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
gfpgan/archs/arcface_arch.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def conv3x3(inplanes, outplanes, stride=1):
|
| 6 |
+
"""A simple wrapper for 3x3 convolution with padding.
|
| 7 |
+
|
| 8 |
+
Args:
|
| 9 |
+
inplanes (int): Channel number of inputs.
|
| 10 |
+
outplanes (int): Channel number of outputs.
|
| 11 |
+
stride (int): Stride in convolution. Default: 1.
|
| 12 |
+
"""
|
| 13 |
+
return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BasicBlock(nn.Module):
|
| 17 |
+
"""Basic residual block used in the ResNetArcFace architecture.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
inplanes (int): Channel number of inputs.
|
| 21 |
+
planes (int): Channel number of outputs.
|
| 22 |
+
stride (int): Stride in convolution. Default: 1.
|
| 23 |
+
downsample (nn.Module): The downsample module. Default: None.
|
| 24 |
+
"""
|
| 25 |
+
expansion = 1 # output channel expansion ratio
|
| 26 |
+
|
| 27 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
| 28 |
+
super(BasicBlock, self).__init__()
|
| 29 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
| 30 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 31 |
+
self.relu = nn.ReLU(inplace=True)
|
| 32 |
+
self.conv2 = conv3x3(planes, planes)
|
| 33 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 34 |
+
self.downsample = downsample
|
| 35 |
+
self.stride = stride
|
| 36 |
+
|
| 37 |
+
def forward(self, x):
|
| 38 |
+
residual = x
|
| 39 |
+
|
| 40 |
+
out = self.conv1(x)
|
| 41 |
+
out = self.bn1(out)
|
| 42 |
+
out = self.relu(out)
|
| 43 |
+
|
| 44 |
+
out = self.conv2(out)
|
| 45 |
+
out = self.bn2(out)
|
| 46 |
+
|
| 47 |
+
if self.downsample is not None:
|
| 48 |
+
residual = self.downsample(x)
|
| 49 |
+
|
| 50 |
+
out += residual
|
| 51 |
+
out = self.relu(out)
|
| 52 |
+
|
| 53 |
+
return out
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class IRBlock(nn.Module):
|
| 57 |
+
"""Improved residual block (IR Block) used in the ResNetArcFace architecture.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
inplanes (int): Channel number of inputs.
|
| 61 |
+
planes (int): Channel number of outputs.
|
| 62 |
+
stride (int): Stride in convolution. Default: 1.
|
| 63 |
+
downsample (nn.Module): The downsample module. Default: None.
|
| 64 |
+
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
|
| 65 |
+
"""
|
| 66 |
+
expansion = 1 # output channel expansion ratio
|
| 67 |
+
|
| 68 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
|
| 69 |
+
super(IRBlock, self).__init__()
|
| 70 |
+
self.bn0 = nn.BatchNorm2d(inplanes)
|
| 71 |
+
self.conv1 = conv3x3(inplanes, inplanes)
|
| 72 |
+
self.bn1 = nn.BatchNorm2d(inplanes)
|
| 73 |
+
self.prelu = nn.PReLU()
|
| 74 |
+
self.conv2 = conv3x3(inplanes, planes, stride)
|
| 75 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 76 |
+
self.downsample = downsample
|
| 77 |
+
self.stride = stride
|
| 78 |
+
self.use_se = use_se
|
| 79 |
+
if self.use_se:
|
| 80 |
+
self.se = SEBlock(planes)
|
| 81 |
+
|
| 82 |
+
def forward(self, x):
|
| 83 |
+
residual = x
|
| 84 |
+
out = self.bn0(x)
|
| 85 |
+
out = self.conv1(out)
|
| 86 |
+
out = self.bn1(out)
|
| 87 |
+
out = self.prelu(out)
|
| 88 |
+
|
| 89 |
+
out = self.conv2(out)
|
| 90 |
+
out = self.bn2(out)
|
| 91 |
+
if self.use_se:
|
| 92 |
+
out = self.se(out)
|
| 93 |
+
|
| 94 |
+
if self.downsample is not None:
|
| 95 |
+
residual = self.downsample(x)
|
| 96 |
+
|
| 97 |
+
out += residual
|
| 98 |
+
out = self.prelu(out)
|
| 99 |
+
|
| 100 |
+
return out
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Bottleneck(nn.Module):
|
| 104 |
+
"""Bottleneck block used in the ResNetArcFace architecture.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
inplanes (int): Channel number of inputs.
|
| 108 |
+
planes (int): Channel number of outputs.
|
| 109 |
+
stride (int): Stride in convolution. Default: 1.
|
| 110 |
+
downsample (nn.Module): The downsample module. Default: None.
|
| 111 |
+
"""
|
| 112 |
+
expansion = 4 # output channel expansion ratio
|
| 113 |
+
|
| 114 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
| 115 |
+
super(Bottleneck, self).__init__()
|
| 116 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
| 117 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 118 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
| 119 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 120 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
|
| 121 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
| 122 |
+
self.relu = nn.ReLU(inplace=True)
|
| 123 |
+
self.downsample = downsample
|
| 124 |
+
self.stride = stride
|
| 125 |
+
|
| 126 |
+
def forward(self, x):
|
| 127 |
+
residual = x
|
| 128 |
+
|
| 129 |
+
out = self.conv1(x)
|
| 130 |
+
out = self.bn1(out)
|
| 131 |
+
out = self.relu(out)
|
| 132 |
+
|
| 133 |
+
out = self.conv2(out)
|
| 134 |
+
out = self.bn2(out)
|
| 135 |
+
out = self.relu(out)
|
| 136 |
+
|
| 137 |
+
out = self.conv3(out)
|
| 138 |
+
out = self.bn3(out)
|
| 139 |
+
|
| 140 |
+
if self.downsample is not None:
|
| 141 |
+
residual = self.downsample(x)
|
| 142 |
+
|
| 143 |
+
out += residual
|
| 144 |
+
out = self.relu(out)
|
| 145 |
+
|
| 146 |
+
return out
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class SEBlock(nn.Module):
|
| 150 |
+
"""The squeeze-and-excitation block (SEBlock) used in the IRBlock.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
channel (int): Channel number of inputs.
|
| 154 |
+
reduction (int): Channel reduction ration. Default: 16.
|
| 155 |
+
"""
|
| 156 |
+
|
| 157 |
+
def __init__(self, channel, reduction=16):
|
| 158 |
+
super(SEBlock, self).__init__()
|
| 159 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
|
| 160 |
+
self.fc = nn.Sequential(
|
| 161 |
+
nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
|
| 162 |
+
nn.Sigmoid())
|
| 163 |
+
|
| 164 |
+
def forward(self, x):
|
| 165 |
+
b, c, _, _ = x.size()
|
| 166 |
+
y = self.avg_pool(x).view(b, c)
|
| 167 |
+
y = self.fc(y).view(b, c, 1, 1)
|
| 168 |
+
return x * y
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@ARCH_REGISTRY.register()
|
| 172 |
+
class ResNetArcFace(nn.Module):
|
| 173 |
+
"""ArcFace with ResNet architectures.
|
| 174 |
+
|
| 175 |
+
Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
block (str): Block used in the ArcFace architecture.
|
| 179 |
+
layers (tuple(int)): Block numbers in each layer.
|
| 180 |
+
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
|
| 181 |
+
"""
|
| 182 |
+
|
| 183 |
+
def __init__(self, block, layers, use_se=True):
|
| 184 |
+
if block == 'IRBlock':
|
| 185 |
+
block = IRBlock
|
| 186 |
+
self.inplanes = 64
|
| 187 |
+
self.use_se = use_se
|
| 188 |
+
super(ResNetArcFace, self).__init__()
|
| 189 |
+
|
| 190 |
+
self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
|
| 191 |
+
self.bn1 = nn.BatchNorm2d(64)
|
| 192 |
+
self.prelu = nn.PReLU()
|
| 193 |
+
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
|
| 194 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
| 195 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
| 196 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
| 197 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
| 198 |
+
self.bn4 = nn.BatchNorm2d(512)
|
| 199 |
+
self.dropout = nn.Dropout()
|
| 200 |
+
self.fc5 = nn.Linear(512 * 8 * 8, 512)
|
| 201 |
+
self.bn5 = nn.BatchNorm1d(512)
|
| 202 |
+
|
| 203 |
+
# initialization
|
| 204 |
+
for m in self.modules():
|
| 205 |
+
if isinstance(m, nn.Conv2d):
|
| 206 |
+
nn.init.xavier_normal_(m.weight)
|
| 207 |
+
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
|
| 208 |
+
nn.init.constant_(m.weight, 1)
|
| 209 |
+
nn.init.constant_(m.bias, 0)
|
| 210 |
+
elif isinstance(m, nn.Linear):
|
| 211 |
+
nn.init.xavier_normal_(m.weight)
|
| 212 |
+
nn.init.constant_(m.bias, 0)
|
| 213 |
+
|
| 214 |
+
def _make_layer(self, block, planes, num_blocks, stride=1):
|
| 215 |
+
downsample = None
|
| 216 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 217 |
+
downsample = nn.Sequential(
|
| 218 |
+
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
| 219 |
+
nn.BatchNorm2d(planes * block.expansion),
|
| 220 |
+
)
|
| 221 |
+
layers = []
|
| 222 |
+
layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
|
| 223 |
+
self.inplanes = planes
|
| 224 |
+
for _ in range(1, num_blocks):
|
| 225 |
+
layers.append(block(self.inplanes, planes, use_se=self.use_se))
|
| 226 |
+
|
| 227 |
+
return nn.Sequential(*layers)
|
| 228 |
+
|
| 229 |
+
def forward(self, x):
|
| 230 |
+
x = self.conv1(x)
|
| 231 |
+
x = self.bn1(x)
|
| 232 |
+
x = self.prelu(x)
|
| 233 |
+
x = self.maxpool(x)
|
| 234 |
+
|
| 235 |
+
x = self.layer1(x)
|
| 236 |
+
x = self.layer2(x)
|
| 237 |
+
x = self.layer3(x)
|
| 238 |
+
x = self.layer4(x)
|
| 239 |
+
x = self.bn4(x)
|
| 240 |
+
x = self.dropout(x)
|
| 241 |
+
x = x.view(x.size(0), -1)
|
| 242 |
+
x = self.fc5(x)
|
| 243 |
+
x = self.bn5(x)
|
| 244 |
+
|
| 245 |
+
return x
|
gfpgan/archs/gfpgan_bilinear_arch.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from .gfpganv1_arch import ResUpBlock
|
| 8 |
+
from .stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
|
| 9 |
+
StyleGAN2GeneratorBilinear)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
|
| 13 |
+
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
|
| 14 |
+
|
| 15 |
+
It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for
|
| 16 |
+
deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
out_size (int): The spatial size of outputs.
|
| 20 |
+
num_style_feat (int): Channel number of style features. Default: 512.
|
| 21 |
+
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
| 22 |
+
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
| 23 |
+
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
| 24 |
+
narrow (float): The narrow ratio for channels. Default: 1.
|
| 25 |
+
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self,
|
| 29 |
+
out_size,
|
| 30 |
+
num_style_feat=512,
|
| 31 |
+
num_mlp=8,
|
| 32 |
+
channel_multiplier=2,
|
| 33 |
+
lr_mlp=0.01,
|
| 34 |
+
narrow=1,
|
| 35 |
+
sft_half=False):
|
| 36 |
+
super(StyleGAN2GeneratorBilinearSFT, self).__init__(
|
| 37 |
+
out_size,
|
| 38 |
+
num_style_feat=num_style_feat,
|
| 39 |
+
num_mlp=num_mlp,
|
| 40 |
+
channel_multiplier=channel_multiplier,
|
| 41 |
+
lr_mlp=lr_mlp,
|
| 42 |
+
narrow=narrow)
|
| 43 |
+
self.sft_half = sft_half
|
| 44 |
+
|
| 45 |
+
def forward(self,
|
| 46 |
+
styles,
|
| 47 |
+
conditions,
|
| 48 |
+
input_is_latent=False,
|
| 49 |
+
noise=None,
|
| 50 |
+
randomize_noise=True,
|
| 51 |
+
truncation=1,
|
| 52 |
+
truncation_latent=None,
|
| 53 |
+
inject_index=None,
|
| 54 |
+
return_latents=False):
|
| 55 |
+
"""Forward function for StyleGAN2GeneratorBilinearSFT.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
styles (list[Tensor]): Sample codes of styles.
|
| 59 |
+
conditions (list[Tensor]): SFT conditions to generators.
|
| 60 |
+
input_is_latent (bool): Whether input is latent style. Default: False.
|
| 61 |
+
noise (Tensor | None): Input noise or None. Default: None.
|
| 62 |
+
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
| 63 |
+
truncation (float): The truncation ratio. Default: 1.
|
| 64 |
+
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
| 65 |
+
inject_index (int | None): The injection index for mixing noise. Default: None.
|
| 66 |
+
return_latents (bool): Whether to return style latents. Default: False.
|
| 67 |
+
"""
|
| 68 |
+
# style codes -> latents with Style MLP layer
|
| 69 |
+
if not input_is_latent:
|
| 70 |
+
styles = [self.style_mlp(s) for s in styles]
|
| 71 |
+
# noises
|
| 72 |
+
if noise is None:
|
| 73 |
+
if randomize_noise:
|
| 74 |
+
noise = [None] * self.num_layers # for each style conv layer
|
| 75 |
+
else: # use the stored noise
|
| 76 |
+
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
|
| 77 |
+
# style truncation
|
| 78 |
+
if truncation < 1:
|
| 79 |
+
style_truncation = []
|
| 80 |
+
for style in styles:
|
| 81 |
+
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
| 82 |
+
styles = style_truncation
|
| 83 |
+
# get style latents with injection
|
| 84 |
+
if len(styles) == 1:
|
| 85 |
+
inject_index = self.num_latent
|
| 86 |
+
|
| 87 |
+
if styles[0].ndim < 3:
|
| 88 |
+
# repeat latent code for all the layers
|
| 89 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
| 90 |
+
else: # used for encoder with different latent code for each layer
|
| 91 |
+
latent = styles[0]
|
| 92 |
+
elif len(styles) == 2: # mixing noises
|
| 93 |
+
if inject_index is None:
|
| 94 |
+
inject_index = random.randint(1, self.num_latent - 1)
|
| 95 |
+
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
| 96 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
| 97 |
+
latent = torch.cat([latent1, latent2], 1)
|
| 98 |
+
|
| 99 |
+
# main generation
|
| 100 |
+
out = self.constant_input(latent.shape[0])
|
| 101 |
+
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
| 102 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
| 103 |
+
|
| 104 |
+
i = 1
|
| 105 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
|
| 106 |
+
noise[2::2], self.to_rgbs):
|
| 107 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
| 108 |
+
|
| 109 |
+
# the conditions may have fewer levels
|
| 110 |
+
if i < len(conditions):
|
| 111 |
+
# SFT part to combine the conditions
|
| 112 |
+
if self.sft_half: # only apply SFT to half of the channels
|
| 113 |
+
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
|
| 114 |
+
out_sft = out_sft * conditions[i - 1] + conditions[i]
|
| 115 |
+
out = torch.cat([out_same, out_sft], dim=1)
|
| 116 |
+
else: # apply SFT to all the channels
|
| 117 |
+
out = out * conditions[i - 1] + conditions[i]
|
| 118 |
+
|
| 119 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
| 120 |
+
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
| 121 |
+
i += 2
|
| 122 |
+
|
| 123 |
+
image = skip
|
| 124 |
+
|
| 125 |
+
if return_latents:
|
| 126 |
+
return image, latent
|
| 127 |
+
else:
|
| 128 |
+
return image, None
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@ARCH_REGISTRY.register()
|
| 132 |
+
class GFPGANBilinear(nn.Module):
|
| 133 |
+
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
|
| 134 |
+
|
| 135 |
+
It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for
|
| 136 |
+
deployment. It can be easily converted to the clean version: GFPGANv1Clean.
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
out_size (int): The spatial size of outputs.
|
| 143 |
+
num_style_feat (int): Channel number of style features. Default: 512.
|
| 144 |
+
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
| 145 |
+
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
|
| 146 |
+
fix_decoder (bool): Whether to fix the decoder. Default: True.
|
| 147 |
+
|
| 148 |
+
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
| 149 |
+
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
| 150 |
+
input_is_latent (bool): Whether input is latent style. Default: False.
|
| 151 |
+
different_w (bool): Whether to use different latent w for different layers. Default: False.
|
| 152 |
+
narrow (float): The narrow ratio for channels. Default: 1.
|
| 153 |
+
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
def __init__(
|
| 157 |
+
self,
|
| 158 |
+
out_size,
|
| 159 |
+
num_style_feat=512,
|
| 160 |
+
channel_multiplier=1,
|
| 161 |
+
decoder_load_path=None,
|
| 162 |
+
fix_decoder=True,
|
| 163 |
+
# for stylegan decoder
|
| 164 |
+
num_mlp=8,
|
| 165 |
+
lr_mlp=0.01,
|
| 166 |
+
input_is_latent=False,
|
| 167 |
+
different_w=False,
|
| 168 |
+
narrow=1,
|
| 169 |
+
sft_half=False):
|
| 170 |
+
|
| 171 |
+
super(GFPGANBilinear, self).__init__()
|
| 172 |
+
self.input_is_latent = input_is_latent
|
| 173 |
+
self.different_w = different_w
|
| 174 |
+
self.num_style_feat = num_style_feat
|
| 175 |
+
|
| 176 |
+
unet_narrow = narrow * 0.5 # by default, use a half of input channels
|
| 177 |
+
channels = {
|
| 178 |
+
'4': int(512 * unet_narrow),
|
| 179 |
+
'8': int(512 * unet_narrow),
|
| 180 |
+
'16': int(512 * unet_narrow),
|
| 181 |
+
'32': int(512 * unet_narrow),
|
| 182 |
+
'64': int(256 * channel_multiplier * unet_narrow),
|
| 183 |
+
'128': int(128 * channel_multiplier * unet_narrow),
|
| 184 |
+
'256': int(64 * channel_multiplier * unet_narrow),
|
| 185 |
+
'512': int(32 * channel_multiplier * unet_narrow),
|
| 186 |
+
'1024': int(16 * channel_multiplier * unet_narrow)
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
self.log_size = int(math.log(out_size, 2))
|
| 190 |
+
first_out_size = 2**(int(math.log(out_size, 2)))
|
| 191 |
+
|
| 192 |
+
self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)
|
| 193 |
+
|
| 194 |
+
# downsample
|
| 195 |
+
in_channels = channels[f'{first_out_size}']
|
| 196 |
+
self.conv_body_down = nn.ModuleList()
|
| 197 |
+
for i in range(self.log_size, 2, -1):
|
| 198 |
+
out_channels = channels[f'{2**(i - 1)}']
|
| 199 |
+
self.conv_body_down.append(ResBlock(in_channels, out_channels))
|
| 200 |
+
in_channels = out_channels
|
| 201 |
+
|
| 202 |
+
self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)
|
| 203 |
+
|
| 204 |
+
# upsample
|
| 205 |
+
in_channels = channels['4']
|
| 206 |
+
self.conv_body_up = nn.ModuleList()
|
| 207 |
+
for i in range(3, self.log_size + 1):
|
| 208 |
+
out_channels = channels[f'{2**i}']
|
| 209 |
+
self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
|
| 210 |
+
in_channels = out_channels
|
| 211 |
+
|
| 212 |
+
# to RGB
|
| 213 |
+
self.toRGB = nn.ModuleList()
|
| 214 |
+
for i in range(3, self.log_size + 1):
|
| 215 |
+
self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))
|
| 216 |
+
|
| 217 |
+
if different_w:
|
| 218 |
+
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
|
| 219 |
+
else:
|
| 220 |
+
linear_out_channel = num_style_feat
|
| 221 |
+
|
| 222 |
+
self.final_linear = EqualLinear(
|
| 223 |
+
channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
|
| 224 |
+
|
| 225 |
+
# the decoder: stylegan2 generator with SFT modulations
|
| 226 |
+
self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
|
| 227 |
+
out_size=out_size,
|
| 228 |
+
num_style_feat=num_style_feat,
|
| 229 |
+
num_mlp=num_mlp,
|
| 230 |
+
channel_multiplier=channel_multiplier,
|
| 231 |
+
lr_mlp=lr_mlp,
|
| 232 |
+
narrow=narrow,
|
| 233 |
+
sft_half=sft_half)
|
| 234 |
+
|
| 235 |
+
# load pre-trained stylegan2 model if necessary
|
| 236 |
+
if decoder_load_path:
|
| 237 |
+
self.stylegan_decoder.load_state_dict(
|
| 238 |
+
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
|
| 239 |
+
# fix decoder without updating params
|
| 240 |
+
if fix_decoder:
|
| 241 |
+
for _, param in self.stylegan_decoder.named_parameters():
|
| 242 |
+
param.requires_grad = False
|
| 243 |
+
|
| 244 |
+
# for SFT modulations (scale and shift)
|
| 245 |
+
self.condition_scale = nn.ModuleList()
|
| 246 |
+
self.condition_shift = nn.ModuleList()
|
| 247 |
+
for i in range(3, self.log_size + 1):
|
| 248 |
+
out_channels = channels[f'{2**i}']
|
| 249 |
+
if sft_half:
|
| 250 |
+
sft_out_channels = out_channels
|
| 251 |
+
else:
|
| 252 |
+
sft_out_channels = out_channels * 2
|
| 253 |
+
self.condition_scale.append(
|
| 254 |
+
nn.Sequential(
|
| 255 |
+
EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
|
| 256 |
+
ScaledLeakyReLU(0.2),
|
| 257 |
+
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
|
| 258 |
+
self.condition_shift.append(
|
| 259 |
+
nn.Sequential(
|
| 260 |
+
EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
|
| 261 |
+
ScaledLeakyReLU(0.2),
|
| 262 |
+
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
|
| 263 |
+
|
| 264 |
+
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
|
| 265 |
+
"""Forward function for GFPGANBilinear.
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
x (Tensor): Input images.
|
| 269 |
+
return_latents (bool): Whether to return style latents. Default: False.
|
| 270 |
+
return_rgb (bool): Whether return intermediate rgb images. Default: True.
|
| 271 |
+
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
| 272 |
+
"""
|
| 273 |
+
conditions = []
|
| 274 |
+
unet_skips = []
|
| 275 |
+
out_rgbs = []
|
| 276 |
+
|
| 277 |
+
# encoder
|
| 278 |
+
feat = self.conv_body_first(x)
|
| 279 |
+
for i in range(self.log_size - 2):
|
| 280 |
+
feat = self.conv_body_down[i](feat)
|
| 281 |
+
unet_skips.insert(0, feat)
|
| 282 |
+
|
| 283 |
+
feat = self.final_conv(feat)
|
| 284 |
+
|
| 285 |
+
# style code
|
| 286 |
+
style_code = self.final_linear(feat.view(feat.size(0), -1))
|
| 287 |
+
if self.different_w:
|
| 288 |
+
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
|
| 289 |
+
|
| 290 |
+
# decode
|
| 291 |
+
for i in range(self.log_size - 2):
|
| 292 |
+
# add unet skip
|
| 293 |
+
feat = feat + unet_skips[i]
|
| 294 |
+
# ResUpLayer
|
| 295 |
+
feat = self.conv_body_up[i](feat)
|
| 296 |
+
# generate scale and shift for SFT layers
|
| 297 |
+
scale = self.condition_scale[i](feat)
|
| 298 |
+
conditions.append(scale.clone())
|
| 299 |
+
shift = self.condition_shift[i](feat)
|
| 300 |
+
conditions.append(shift.clone())
|
| 301 |
+
# generate rgb images
|
| 302 |
+
if return_rgb:
|
| 303 |
+
out_rgbs.append(self.toRGB[i](feat))
|
| 304 |
+
|
| 305 |
+
# decoder
|
| 306 |
+
image, _ = self.stylegan_decoder([style_code],
|
| 307 |
+
conditions,
|
| 308 |
+
return_latents=return_latents,
|
| 309 |
+
input_is_latent=self.input_is_latent,
|
| 310 |
+
randomize_noise=randomize_noise)
|
| 311 |
+
|
| 312 |
+
return image, out_rgbs
|
gfpgan/archs/gfpganv1_arch.py
ADDED
|
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
|
| 5 |
+
StyleGAN2Generator)
|
| 6 |
+
from basicsr.ops.fused_act import FusedLeakyReLU
|
| 7 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class StyleGAN2GeneratorSFT(StyleGAN2Generator):
|
| 13 |
+
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
out_size (int): The spatial size of outputs.
|
| 17 |
+
num_style_feat (int): Channel number of style features. Default: 512.
|
| 18 |
+
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
| 19 |
+
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
| 20 |
+
resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
|
| 21 |
+
applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
|
| 22 |
+
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
| 23 |
+
narrow (float): The narrow ratio for channels. Default: 1.
|
| 24 |
+
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self,
|
| 28 |
+
out_size,
|
| 29 |
+
num_style_feat=512,
|
| 30 |
+
num_mlp=8,
|
| 31 |
+
channel_multiplier=2,
|
| 32 |
+
resample_kernel=(1, 3, 3, 1),
|
| 33 |
+
lr_mlp=0.01,
|
| 34 |
+
narrow=1,
|
| 35 |
+
sft_half=False):
|
| 36 |
+
super(StyleGAN2GeneratorSFT, self).__init__(
|
| 37 |
+
out_size,
|
| 38 |
+
num_style_feat=num_style_feat,
|
| 39 |
+
num_mlp=num_mlp,
|
| 40 |
+
channel_multiplier=channel_multiplier,
|
| 41 |
+
resample_kernel=resample_kernel,
|
| 42 |
+
lr_mlp=lr_mlp,
|
| 43 |
+
narrow=narrow)
|
| 44 |
+
self.sft_half = sft_half
|
| 45 |
+
|
| 46 |
+
def forward(self,
|
| 47 |
+
styles,
|
| 48 |
+
conditions,
|
| 49 |
+
input_is_latent=False,
|
| 50 |
+
noise=None,
|
| 51 |
+
randomize_noise=True,
|
| 52 |
+
truncation=1,
|
| 53 |
+
truncation_latent=None,
|
| 54 |
+
inject_index=None,
|
| 55 |
+
return_latents=False):
|
| 56 |
+
"""Forward function for StyleGAN2GeneratorSFT.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
styles (list[Tensor]): Sample codes of styles.
|
| 60 |
+
conditions (list[Tensor]): SFT conditions to generators.
|
| 61 |
+
input_is_latent (bool): Whether input is latent style. Default: False.
|
| 62 |
+
noise (Tensor | None): Input noise or None. Default: None.
|
| 63 |
+
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
| 64 |
+
truncation (float): The truncation ratio. Default: 1.
|
| 65 |
+
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
| 66 |
+
inject_index (int | None): The injection index for mixing noise. Default: None.
|
| 67 |
+
return_latents (bool): Whether to return style latents. Default: False.
|
| 68 |
+
"""
|
| 69 |
+
# style codes -> latents with Style MLP layer
|
| 70 |
+
if not input_is_latent:
|
| 71 |
+
styles = [self.style_mlp(s) for s in styles]
|
| 72 |
+
# noises
|
| 73 |
+
if noise is None:
|
| 74 |
+
if randomize_noise:
|
| 75 |
+
noise = [None] * self.num_layers # for each style conv layer
|
| 76 |
+
else: # use the stored noise
|
| 77 |
+
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
|
| 78 |
+
# style truncation
|
| 79 |
+
if truncation < 1:
|
| 80 |
+
style_truncation = []
|
| 81 |
+
for style in styles:
|
| 82 |
+
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
| 83 |
+
styles = style_truncation
|
| 84 |
+
# get style latents with injection
|
| 85 |
+
if len(styles) == 1:
|
| 86 |
+
inject_index = self.num_latent
|
| 87 |
+
|
| 88 |
+
if styles[0].ndim < 3:
|
| 89 |
+
# repeat latent code for all the layers
|
| 90 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
| 91 |
+
else: # used for encoder with different latent code for each layer
|
| 92 |
+
latent = styles[0]
|
| 93 |
+
elif len(styles) == 2: # mixing noises
|
| 94 |
+
if inject_index is None:
|
| 95 |
+
inject_index = random.randint(1, self.num_latent - 1)
|
| 96 |
+
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
| 97 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
| 98 |
+
latent = torch.cat([latent1, latent2], 1)
|
| 99 |
+
|
| 100 |
+
# main generation
|
| 101 |
+
out = self.constant_input(latent.shape[0])
|
| 102 |
+
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
| 103 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
| 104 |
+
|
| 105 |
+
i = 1
|
| 106 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
|
| 107 |
+
noise[2::2], self.to_rgbs):
|
| 108 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
| 109 |
+
|
| 110 |
+
# the conditions may have fewer levels
|
| 111 |
+
if i < len(conditions):
|
| 112 |
+
# SFT part to combine the conditions
|
| 113 |
+
if self.sft_half: # only apply SFT to half of the channels
|
| 114 |
+
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
|
| 115 |
+
out_sft = out_sft * conditions[i - 1] + conditions[i]
|
| 116 |
+
out = torch.cat([out_same, out_sft], dim=1)
|
| 117 |
+
else: # apply SFT to all the channels
|
| 118 |
+
out = out * conditions[i - 1] + conditions[i]
|
| 119 |
+
|
| 120 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
| 121 |
+
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
| 122 |
+
i += 2
|
| 123 |
+
|
| 124 |
+
image = skip
|
| 125 |
+
|
| 126 |
+
if return_latents:
|
| 127 |
+
return image, latent
|
| 128 |
+
else:
|
| 129 |
+
return image, None
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class ConvUpLayer(nn.Module):
|
| 133 |
+
"""Convolutional upsampling layer. It uses bilinear upsampler + Conv.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
in_channels (int): Channel number of the input.
|
| 137 |
+
out_channels (int): Channel number of the output.
|
| 138 |
+
kernel_size (int): Size of the convolving kernel.
|
| 139 |
+
stride (int): Stride of the convolution. Default: 1
|
| 140 |
+
padding (int): Zero-padding added to both sides of the input. Default: 0.
|
| 141 |
+
bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
|
| 142 |
+
bias_init_val (float): Bias initialized value. Default: 0.
|
| 143 |
+
activate (bool): Whether use activateion. Default: True.
|
| 144 |
+
"""
|
| 145 |
+
|
| 146 |
+
def __init__(self,
|
| 147 |
+
in_channels,
|
| 148 |
+
out_channels,
|
| 149 |
+
kernel_size,
|
| 150 |
+
stride=1,
|
| 151 |
+
padding=0,
|
| 152 |
+
bias=True,
|
| 153 |
+
bias_init_val=0,
|
| 154 |
+
activate=True):
|
| 155 |
+
super(ConvUpLayer, self).__init__()
|
| 156 |
+
self.in_channels = in_channels
|
| 157 |
+
self.out_channels = out_channels
|
| 158 |
+
self.kernel_size = kernel_size
|
| 159 |
+
self.stride = stride
|
| 160 |
+
self.padding = padding
|
| 161 |
+
# self.scale is used to scale the convolution weights, which is related to the common initializations.
|
| 162 |
+
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
| 163 |
+
|
| 164 |
+
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
|
| 165 |
+
|
| 166 |
+
if bias and not activate:
|
| 167 |
+
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
|
| 168 |
+
else:
|
| 169 |
+
self.register_parameter('bias', None)
|
| 170 |
+
|
| 171 |
+
# activation
|
| 172 |
+
if activate:
|
| 173 |
+
if bias:
|
| 174 |
+
self.activation = FusedLeakyReLU(out_channels)
|
| 175 |
+
else:
|
| 176 |
+
self.activation = ScaledLeakyReLU(0.2)
|
| 177 |
+
else:
|
| 178 |
+
self.activation = None
|
| 179 |
+
|
| 180 |
+
def forward(self, x):
|
| 181 |
+
# bilinear upsample
|
| 182 |
+
out = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
|
| 183 |
+
# conv
|
| 184 |
+
out = F.conv2d(
|
| 185 |
+
out,
|
| 186 |
+
self.weight * self.scale,
|
| 187 |
+
bias=self.bias,
|
| 188 |
+
stride=self.stride,
|
| 189 |
+
padding=self.padding,
|
| 190 |
+
)
|
| 191 |
+
# activation
|
| 192 |
+
if self.activation is not None:
|
| 193 |
+
out = self.activation(out)
|
| 194 |
+
return out
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class ResUpBlock(nn.Module):
|
| 198 |
+
"""Residual block with upsampling.
|
| 199 |
+
|
| 200 |
+
Args:
|
| 201 |
+
in_channels (int): Channel number of the input.
|
| 202 |
+
out_channels (int): Channel number of the output.
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
def __init__(self, in_channels, out_channels):
|
| 206 |
+
super(ResUpBlock, self).__init__()
|
| 207 |
+
|
| 208 |
+
self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
|
| 209 |
+
self.conv2 = ConvUpLayer(in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True)
|
| 210 |
+
self.skip = ConvUpLayer(in_channels, out_channels, 1, bias=False, activate=False)
|
| 211 |
+
|
| 212 |
+
def forward(self, x):
|
| 213 |
+
out = self.conv1(x)
|
| 214 |
+
out = self.conv2(out)
|
| 215 |
+
skip = self.skip(x)
|
| 216 |
+
out = (out + skip) / math.sqrt(2)
|
| 217 |
+
return out
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
@ARCH_REGISTRY.register()
|
| 221 |
+
class GFPGANv1(nn.Module):
|
| 222 |
+
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
|
| 223 |
+
|
| 224 |
+
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
out_size (int): The spatial size of outputs.
|
| 228 |
+
num_style_feat (int): Channel number of style features. Default: 512.
|
| 229 |
+
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
| 230 |
+
resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
|
| 231 |
+
applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
|
| 232 |
+
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
|
| 233 |
+
fix_decoder (bool): Whether to fix the decoder. Default: True.
|
| 234 |
+
|
| 235 |
+
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
| 236 |
+
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
| 237 |
+
input_is_latent (bool): Whether input is latent style. Default: False.
|
| 238 |
+
different_w (bool): Whether to use different latent w for different layers. Default: False.
|
| 239 |
+
narrow (float): The narrow ratio for channels. Default: 1.
|
| 240 |
+
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
| 241 |
+
"""
|
| 242 |
+
|
| 243 |
+
def __init__(
|
| 244 |
+
self,
|
| 245 |
+
out_size,
|
| 246 |
+
num_style_feat=512,
|
| 247 |
+
channel_multiplier=1,
|
| 248 |
+
resample_kernel=(1, 3, 3, 1),
|
| 249 |
+
decoder_load_path=None,
|
| 250 |
+
fix_decoder=True,
|
| 251 |
+
# for stylegan decoder
|
| 252 |
+
num_mlp=8,
|
| 253 |
+
lr_mlp=0.01,
|
| 254 |
+
input_is_latent=False,
|
| 255 |
+
different_w=False,
|
| 256 |
+
narrow=1,
|
| 257 |
+
sft_half=False):
|
| 258 |
+
|
| 259 |
+
super(GFPGANv1, self).__init__()
|
| 260 |
+
self.input_is_latent = input_is_latent
|
| 261 |
+
self.different_w = different_w
|
| 262 |
+
self.num_style_feat = num_style_feat
|
| 263 |
+
|
| 264 |
+
unet_narrow = narrow * 0.5 # by default, use a half of input channels
|
| 265 |
+
channels = {
|
| 266 |
+
'4': int(512 * unet_narrow),
|
| 267 |
+
'8': int(512 * unet_narrow),
|
| 268 |
+
'16': int(512 * unet_narrow),
|
| 269 |
+
'32': int(512 * unet_narrow),
|
| 270 |
+
'64': int(256 * channel_multiplier * unet_narrow),
|
| 271 |
+
'128': int(128 * channel_multiplier * unet_narrow),
|
| 272 |
+
'256': int(64 * channel_multiplier * unet_narrow),
|
| 273 |
+
'512': int(32 * channel_multiplier * unet_narrow),
|
| 274 |
+
'1024': int(16 * channel_multiplier * unet_narrow)
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
self.log_size = int(math.log(out_size, 2))
|
| 278 |
+
first_out_size = 2**(int(math.log(out_size, 2)))
|
| 279 |
+
|
| 280 |
+
self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)
|
| 281 |
+
|
| 282 |
+
# downsample
|
| 283 |
+
in_channels = channels[f'{first_out_size}']
|
| 284 |
+
self.conv_body_down = nn.ModuleList()
|
| 285 |
+
for i in range(self.log_size, 2, -1):
|
| 286 |
+
out_channels = channels[f'{2**(i - 1)}']
|
| 287 |
+
self.conv_body_down.append(ResBlock(in_channels, out_channels, resample_kernel))
|
| 288 |
+
in_channels = out_channels
|
| 289 |
+
|
| 290 |
+
self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)
|
| 291 |
+
|
| 292 |
+
# upsample
|
| 293 |
+
in_channels = channels['4']
|
| 294 |
+
self.conv_body_up = nn.ModuleList()
|
| 295 |
+
for i in range(3, self.log_size + 1):
|
| 296 |
+
out_channels = channels[f'{2**i}']
|
| 297 |
+
self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
|
| 298 |
+
in_channels = out_channels
|
| 299 |
+
|
| 300 |
+
# to RGB
|
| 301 |
+
self.toRGB = nn.ModuleList()
|
| 302 |
+
for i in range(3, self.log_size + 1):
|
| 303 |
+
self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))
|
| 304 |
+
|
| 305 |
+
if different_w:
|
| 306 |
+
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
|
| 307 |
+
else:
|
| 308 |
+
linear_out_channel = num_style_feat
|
| 309 |
+
|
| 310 |
+
self.final_linear = EqualLinear(
|
| 311 |
+
channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
|
| 312 |
+
|
| 313 |
+
# the decoder: stylegan2 generator with SFT modulations
|
| 314 |
+
self.stylegan_decoder = StyleGAN2GeneratorSFT(
|
| 315 |
+
out_size=out_size,
|
| 316 |
+
num_style_feat=num_style_feat,
|
| 317 |
+
num_mlp=num_mlp,
|
| 318 |
+
channel_multiplier=channel_multiplier,
|
| 319 |
+
resample_kernel=resample_kernel,
|
| 320 |
+
lr_mlp=lr_mlp,
|
| 321 |
+
narrow=narrow,
|
| 322 |
+
sft_half=sft_half)
|
| 323 |
+
|
| 324 |
+
# load pre-trained stylegan2 model if necessary
|
| 325 |
+
if decoder_load_path:
|
| 326 |
+
self.stylegan_decoder.load_state_dict(
|
| 327 |
+
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
|
| 328 |
+
# fix decoder without updating params
|
| 329 |
+
if fix_decoder:
|
| 330 |
+
for _, param in self.stylegan_decoder.named_parameters():
|
| 331 |
+
param.requires_grad = False
|
| 332 |
+
|
| 333 |
+
# for SFT modulations (scale and shift)
|
| 334 |
+
self.condition_scale = nn.ModuleList()
|
| 335 |
+
self.condition_shift = nn.ModuleList()
|
| 336 |
+
for i in range(3, self.log_size + 1):
|
| 337 |
+
out_channels = channels[f'{2**i}']
|
| 338 |
+
if sft_half:
|
| 339 |
+
sft_out_channels = out_channels
|
| 340 |
+
else:
|
| 341 |
+
sft_out_channels = out_channels * 2
|
| 342 |
+
self.condition_scale.append(
|
| 343 |
+
nn.Sequential(
|
| 344 |
+
EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
|
| 345 |
+
ScaledLeakyReLU(0.2),
|
| 346 |
+
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
|
| 347 |
+
self.condition_shift.append(
|
| 348 |
+
nn.Sequential(
|
| 349 |
+
EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
|
| 350 |
+
ScaledLeakyReLU(0.2),
|
| 351 |
+
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
|
| 352 |
+
|
| 353 |
+
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
|
| 354 |
+
"""Forward function for GFPGANv1.
|
| 355 |
+
|
| 356 |
+
Args:
|
| 357 |
+
x (Tensor): Input images.
|
| 358 |
+
return_latents (bool): Whether to return style latents. Default: False.
|
| 359 |
+
return_rgb (bool): Whether return intermediate rgb images. Default: True.
|
| 360 |
+
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
| 361 |
+
"""
|
| 362 |
+
conditions = []
|
| 363 |
+
unet_skips = []
|
| 364 |
+
out_rgbs = []
|
| 365 |
+
|
| 366 |
+
# encoder
|
| 367 |
+
feat = self.conv_body_first(x)
|
| 368 |
+
for i in range(self.log_size - 2):
|
| 369 |
+
feat = self.conv_body_down[i](feat)
|
| 370 |
+
unet_skips.insert(0, feat)
|
| 371 |
+
|
| 372 |
+
feat = self.final_conv(feat)
|
| 373 |
+
|
| 374 |
+
# style code
|
| 375 |
+
style_code = self.final_linear(feat.view(feat.size(0), -1))
|
| 376 |
+
if self.different_w:
|
| 377 |
+
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
|
| 378 |
+
|
| 379 |
+
# decode
|
| 380 |
+
for i in range(self.log_size - 2):
|
| 381 |
+
# add unet skip
|
| 382 |
+
feat = feat + unet_skips[i]
|
| 383 |
+
# ResUpLayer
|
| 384 |
+
feat = self.conv_body_up[i](feat)
|
| 385 |
+
# generate scale and shift for SFT layers
|
| 386 |
+
scale = self.condition_scale[i](feat)
|
| 387 |
+
conditions.append(scale.clone())
|
| 388 |
+
shift = self.condition_shift[i](feat)
|
| 389 |
+
conditions.append(shift.clone())
|
| 390 |
+
# generate rgb images
|
| 391 |
+
if return_rgb:
|
| 392 |
+
out_rgbs.append(self.toRGB[i](feat))
|
| 393 |
+
|
| 394 |
+
# decoder
|
| 395 |
+
image, _ = self.stylegan_decoder([style_code],
|
| 396 |
+
conditions,
|
| 397 |
+
return_latents=return_latents,
|
| 398 |
+
input_is_latent=self.input_is_latent,
|
| 399 |
+
randomize_noise=randomize_noise)
|
| 400 |
+
|
| 401 |
+
return image, out_rgbs
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
@ARCH_REGISTRY.register()
|
| 405 |
+
class FacialComponentDiscriminator(nn.Module):
|
| 406 |
+
"""Facial component (eyes, mouth, noise) discriminator used in GFPGAN.
|
| 407 |
+
"""
|
| 408 |
+
|
| 409 |
+
def __init__(self):
|
| 410 |
+
super(FacialComponentDiscriminator, self).__init__()
|
| 411 |
+
# It now uses a VGG-style architectrue with fixed model size
|
| 412 |
+
self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
| 413 |
+
self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
| 414 |
+
self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
| 415 |
+
self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
| 416 |
+
self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
| 417 |
+
self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
|
| 418 |
+
|
| 419 |
+
def forward(self, x, return_feats=False, **kwargs):
|
| 420 |
+
"""Forward function for FacialComponentDiscriminator.
|
| 421 |
+
|
| 422 |
+
Args:
|
| 423 |
+
x (Tensor): Input images.
|
| 424 |
+
return_feats (bool): Whether to return intermediate features. Default: False.
|
| 425 |
+
"""
|
| 426 |
+
feat = self.conv1(x)
|
| 427 |
+
feat = self.conv3(self.conv2(feat))
|
| 428 |
+
rlt_feats = []
|
| 429 |
+
if return_feats:
|
| 430 |
+
rlt_feats.append(feat.clone())
|
| 431 |
+
feat = self.conv5(self.conv4(feat))
|
| 432 |
+
if return_feats:
|
| 433 |
+
rlt_feats.append(feat.clone())
|
| 434 |
+
out = self.final_conv(feat)
|
| 435 |
+
|
| 436 |
+
if return_feats:
|
| 437 |
+
return out, rlt_feats
|
| 438 |
+
else:
|
| 439 |
+
return out, None
|
gfpgan/archs/gfpganv1_clean_arch.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
|
| 8 |
+
from .stylegan2_clean_arch import StyleGAN2GeneratorClean
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
|
| 12 |
+
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
|
| 13 |
+
|
| 14 |
+
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
out_size (int): The spatial size of outputs.
|
| 18 |
+
num_style_feat (int): Channel number of style features. Default: 512.
|
| 19 |
+
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
| 20 |
+
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
| 21 |
+
narrow (float): The narrow ratio for channels. Default: 1.
|
| 22 |
+
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
|
| 26 |
+
super(StyleGAN2GeneratorCSFT, self).__init__(
|
| 27 |
+
out_size,
|
| 28 |
+
num_style_feat=num_style_feat,
|
| 29 |
+
num_mlp=num_mlp,
|
| 30 |
+
channel_multiplier=channel_multiplier,
|
| 31 |
+
narrow=narrow)
|
| 32 |
+
self.sft_half = sft_half
|
| 33 |
+
|
| 34 |
+
def forward(self,
|
| 35 |
+
styles,
|
| 36 |
+
conditions,
|
| 37 |
+
input_is_latent=False,
|
| 38 |
+
noise=None,
|
| 39 |
+
randomize_noise=True,
|
| 40 |
+
truncation=1,
|
| 41 |
+
truncation_latent=None,
|
| 42 |
+
inject_index=None,
|
| 43 |
+
return_latents=False):
|
| 44 |
+
"""Forward function for StyleGAN2GeneratorCSFT.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
styles (list[Tensor]): Sample codes of styles.
|
| 48 |
+
conditions (list[Tensor]): SFT conditions to generators.
|
| 49 |
+
input_is_latent (bool): Whether input is latent style. Default: False.
|
| 50 |
+
noise (Tensor | None): Input noise or None. Default: None.
|
| 51 |
+
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
| 52 |
+
truncation (float): The truncation ratio. Default: 1.
|
| 53 |
+
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
| 54 |
+
inject_index (int | None): The injection index for mixing noise. Default: None.
|
| 55 |
+
return_latents (bool): Whether to return style latents. Default: False.
|
| 56 |
+
"""
|
| 57 |
+
# style codes -> latents with Style MLP layer
|
| 58 |
+
if not input_is_latent:
|
| 59 |
+
styles = [self.style_mlp(s) for s in styles]
|
| 60 |
+
# noises
|
| 61 |
+
if noise is None:
|
| 62 |
+
if randomize_noise:
|
| 63 |
+
noise = [None] * self.num_layers # for each style conv layer
|
| 64 |
+
else: # use the stored noise
|
| 65 |
+
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
|
| 66 |
+
# style truncation
|
| 67 |
+
if truncation < 1:
|
| 68 |
+
style_truncation = []
|
| 69 |
+
for style in styles:
|
| 70 |
+
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
| 71 |
+
styles = style_truncation
|
| 72 |
+
# get style latents with injection
|
| 73 |
+
if len(styles) == 1:
|
| 74 |
+
inject_index = self.num_latent
|
| 75 |
+
|
| 76 |
+
if styles[0].ndim < 3:
|
| 77 |
+
# repeat latent code for all the layers
|
| 78 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
| 79 |
+
else: # used for encoder with different latent code for each layer
|
| 80 |
+
latent = styles[0]
|
| 81 |
+
elif len(styles) == 2: # mixing noises
|
| 82 |
+
if inject_index is None:
|
| 83 |
+
inject_index = random.randint(1, self.num_latent - 1)
|
| 84 |
+
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
| 85 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
| 86 |
+
latent = torch.cat([latent1, latent2], 1)
|
| 87 |
+
|
| 88 |
+
# main generation
|
| 89 |
+
out = self.constant_input(latent.shape[0])
|
| 90 |
+
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
| 91 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
| 92 |
+
|
| 93 |
+
i = 1
|
| 94 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
|
| 95 |
+
noise[2::2], self.to_rgbs):
|
| 96 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
| 97 |
+
|
| 98 |
+
# the conditions may have fewer levels
|
| 99 |
+
if i < len(conditions):
|
| 100 |
+
# SFT part to combine the conditions
|
| 101 |
+
if self.sft_half: # only apply SFT to half of the channels
|
| 102 |
+
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
|
| 103 |
+
out_sft = out_sft * conditions[i - 1] + conditions[i]
|
| 104 |
+
out = torch.cat([out_same, out_sft], dim=1)
|
| 105 |
+
else: # apply SFT to all the channels
|
| 106 |
+
out = out * conditions[i - 1] + conditions[i]
|
| 107 |
+
|
| 108 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
| 109 |
+
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
| 110 |
+
i += 2
|
| 111 |
+
|
| 112 |
+
image = skip
|
| 113 |
+
|
| 114 |
+
if return_latents:
|
| 115 |
+
return image, latent
|
| 116 |
+
else:
|
| 117 |
+
return image, None
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class ResBlock(nn.Module):
|
| 121 |
+
"""Residual block with bilinear upsampling/downsampling.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
in_channels (int): Channel number of the input.
|
| 125 |
+
out_channels (int): Channel number of the output.
|
| 126 |
+
mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
def __init__(self, in_channels, out_channels, mode='down'):
|
| 130 |
+
super(ResBlock, self).__init__()
|
| 131 |
+
|
| 132 |
+
self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
|
| 133 |
+
self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
|
| 134 |
+
self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
|
| 135 |
+
if mode == 'down':
|
| 136 |
+
self.scale_factor = 0.5
|
| 137 |
+
elif mode == 'up':
|
| 138 |
+
self.scale_factor = 2
|
| 139 |
+
|
| 140 |
+
def forward(self, x):
|
| 141 |
+
out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
|
| 142 |
+
# upsample/downsample
|
| 143 |
+
out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
|
| 144 |
+
out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
|
| 145 |
+
# skip
|
| 146 |
+
x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
|
| 147 |
+
skip = self.skip(x)
|
| 148 |
+
out = out + skip
|
| 149 |
+
return out
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@ARCH_REGISTRY.register()
|
| 153 |
+
class GFPGANv1Clean(nn.Module):
|
| 154 |
+
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
|
| 155 |
+
|
| 156 |
+
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
|
| 157 |
+
|
| 158 |
+
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
out_size (int): The spatial size of outputs.
|
| 162 |
+
num_style_feat (int): Channel number of style features. Default: 512.
|
| 163 |
+
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
| 164 |
+
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
|
| 165 |
+
fix_decoder (bool): Whether to fix the decoder. Default: True.
|
| 166 |
+
|
| 167 |
+
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
| 168 |
+
input_is_latent (bool): Whether input is latent style. Default: False.
|
| 169 |
+
different_w (bool): Whether to use different latent w for different layers. Default: False.
|
| 170 |
+
narrow (float): The narrow ratio for channels. Default: 1.
|
| 171 |
+
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(
|
| 175 |
+
self,
|
| 176 |
+
out_size,
|
| 177 |
+
num_style_feat=512,
|
| 178 |
+
channel_multiplier=1,
|
| 179 |
+
decoder_load_path=None,
|
| 180 |
+
fix_decoder=True,
|
| 181 |
+
# for stylegan decoder
|
| 182 |
+
num_mlp=8,
|
| 183 |
+
input_is_latent=False,
|
| 184 |
+
different_w=False,
|
| 185 |
+
narrow=1,
|
| 186 |
+
sft_half=False):
|
| 187 |
+
|
| 188 |
+
super(GFPGANv1Clean, self).__init__()
|
| 189 |
+
self.input_is_latent = input_is_latent
|
| 190 |
+
self.different_w = different_w
|
| 191 |
+
self.num_style_feat = num_style_feat
|
| 192 |
+
|
| 193 |
+
unet_narrow = narrow * 0.5 # by default, use a half of input channels
|
| 194 |
+
channels = {
|
| 195 |
+
'4': int(512 * unet_narrow),
|
| 196 |
+
'8': int(512 * unet_narrow),
|
| 197 |
+
'16': int(512 * unet_narrow),
|
| 198 |
+
'32': int(512 * unet_narrow),
|
| 199 |
+
'64': int(256 * channel_multiplier * unet_narrow),
|
| 200 |
+
'128': int(128 * channel_multiplier * unet_narrow),
|
| 201 |
+
'256': int(64 * channel_multiplier * unet_narrow),
|
| 202 |
+
'512': int(32 * channel_multiplier * unet_narrow),
|
| 203 |
+
'1024': int(16 * channel_multiplier * unet_narrow)
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
self.log_size = int(math.log(out_size, 2))
|
| 207 |
+
first_out_size = 2**(int(math.log(out_size, 2)))
|
| 208 |
+
|
| 209 |
+
self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)
|
| 210 |
+
|
| 211 |
+
# downsample
|
| 212 |
+
in_channels = channels[f'{first_out_size}']
|
| 213 |
+
self.conv_body_down = nn.ModuleList()
|
| 214 |
+
for i in range(self.log_size, 2, -1):
|
| 215 |
+
out_channels = channels[f'{2**(i - 1)}']
|
| 216 |
+
self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
|
| 217 |
+
in_channels = out_channels
|
| 218 |
+
|
| 219 |
+
self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
|
| 220 |
+
|
| 221 |
+
# upsample
|
| 222 |
+
in_channels = channels['4']
|
| 223 |
+
self.conv_body_up = nn.ModuleList()
|
| 224 |
+
for i in range(3, self.log_size + 1):
|
| 225 |
+
out_channels = channels[f'{2**i}']
|
| 226 |
+
self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
|
| 227 |
+
in_channels = out_channels
|
| 228 |
+
|
| 229 |
+
# to RGB
|
| 230 |
+
self.toRGB = nn.ModuleList()
|
| 231 |
+
for i in range(3, self.log_size + 1):
|
| 232 |
+
self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))
|
| 233 |
+
|
| 234 |
+
if different_w:
|
| 235 |
+
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
|
| 236 |
+
else:
|
| 237 |
+
linear_out_channel = num_style_feat
|
| 238 |
+
|
| 239 |
+
self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
|
| 240 |
+
|
| 241 |
+
# the decoder: stylegan2 generator with SFT modulations
|
| 242 |
+
self.stylegan_decoder = StyleGAN2GeneratorCSFT(
|
| 243 |
+
out_size=out_size,
|
| 244 |
+
num_style_feat=num_style_feat,
|
| 245 |
+
num_mlp=num_mlp,
|
| 246 |
+
channel_multiplier=channel_multiplier,
|
| 247 |
+
narrow=narrow,
|
| 248 |
+
sft_half=sft_half)
|
| 249 |
+
|
| 250 |
+
# load pre-trained stylegan2 model if necessary
|
| 251 |
+
if decoder_load_path:
|
| 252 |
+
self.stylegan_decoder.load_state_dict(
|
| 253 |
+
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
|
| 254 |
+
# fix decoder without updating params
|
| 255 |
+
if fix_decoder:
|
| 256 |
+
for _, param in self.stylegan_decoder.named_parameters():
|
| 257 |
+
param.requires_grad = False
|
| 258 |
+
|
| 259 |
+
# for SFT modulations (scale and shift)
|
| 260 |
+
self.condition_scale = nn.ModuleList()
|
| 261 |
+
self.condition_shift = nn.ModuleList()
|
| 262 |
+
for i in range(3, self.log_size + 1):
|
| 263 |
+
out_channels = channels[f'{2**i}']
|
| 264 |
+
if sft_half:
|
| 265 |
+
sft_out_channels = out_channels
|
| 266 |
+
else:
|
| 267 |
+
sft_out_channels = out_channels * 2
|
| 268 |
+
self.condition_scale.append(
|
| 269 |
+
nn.Sequential(
|
| 270 |
+
nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
|
| 271 |
+
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
|
| 272 |
+
self.condition_shift.append(
|
| 273 |
+
nn.Sequential(
|
| 274 |
+
nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
|
| 275 |
+
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
|
| 276 |
+
|
| 277 |
+
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
|
| 278 |
+
"""Forward function for GFPGANv1Clean.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
x (Tensor): Input images.
|
| 282 |
+
return_latents (bool): Whether to return style latents. Default: False.
|
| 283 |
+
return_rgb (bool): Whether return intermediate rgb images. Default: True.
|
| 284 |
+
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
| 285 |
+
"""
|
| 286 |
+
conditions = []
|
| 287 |
+
unet_skips = []
|
| 288 |
+
out_rgbs = []
|
| 289 |
+
|
| 290 |
+
# encoder
|
| 291 |
+
feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
|
| 292 |
+
for i in range(self.log_size - 2):
|
| 293 |
+
feat = self.conv_body_down[i](feat)
|
| 294 |
+
unet_skips.insert(0, feat)
|
| 295 |
+
feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
|
| 296 |
+
|
| 297 |
+
# style code
|
| 298 |
+
style_code = self.final_linear(feat.view(feat.size(0), -1))
|
| 299 |
+
if self.different_w:
|
| 300 |
+
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
|
| 301 |
+
|
| 302 |
+
# decode
|
| 303 |
+
for i in range(self.log_size - 2):
|
| 304 |
+
# add unet skip
|
| 305 |
+
feat = feat + unet_skips[i]
|
| 306 |
+
# ResUpLayer
|
| 307 |
+
feat = self.conv_body_up[i](feat)
|
| 308 |
+
# generate scale and shift for SFT layers
|
| 309 |
+
scale = self.condition_scale[i](feat)
|
| 310 |
+
conditions.append(scale.clone())
|
| 311 |
+
shift = self.condition_shift[i](feat)
|
| 312 |
+
conditions.append(shift.clone())
|
| 313 |
+
# generate rgb images
|
| 314 |
+
if return_rgb:
|
| 315 |
+
out_rgbs.append(self.toRGB[i](feat))
|
| 316 |
+
|
| 317 |
+
# decoder
|
| 318 |
+
image, _ = self.stylegan_decoder([style_code],
|
| 319 |
+
conditions,
|
| 320 |
+
return_latents=return_latents,
|
| 321 |
+
input_is_latent=self.input_is_latent,
|
| 322 |
+
randomize_noise=randomize_noise)
|
| 323 |
+
|
| 324 |
+
return image, out_rgbs
|
gfpgan/archs/restoreformer_arch.py
ADDED
|
@@ -0,0 +1,658 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Modified from https://github.com/wzhouxiff/RestoreFormer
|
| 2 |
+
"""
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class VectorQuantizer(nn.Module):
|
| 10 |
+
"""
|
| 11 |
+
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
|
| 12 |
+
____________________________________________
|
| 13 |
+
Discretization bottleneck part of the VQ-VAE.
|
| 14 |
+
Inputs:
|
| 15 |
+
- n_e : number of embeddings
|
| 16 |
+
- e_dim : dimension of embedding
|
| 17 |
+
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
| 18 |
+
_____________________________________________
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, n_e, e_dim, beta):
|
| 22 |
+
super(VectorQuantizer, self).__init__()
|
| 23 |
+
self.n_e = n_e
|
| 24 |
+
self.e_dim = e_dim
|
| 25 |
+
self.beta = beta
|
| 26 |
+
|
| 27 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
| 28 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
| 29 |
+
|
| 30 |
+
def forward(self, z):
|
| 31 |
+
"""
|
| 32 |
+
Inputs the output of the encoder network z and maps it to a discrete
|
| 33 |
+
one-hot vector that is the index of the closest embedding vector e_j
|
| 34 |
+
z (continuous) -> z_q (discrete)
|
| 35 |
+
z.shape = (batch, channel, height, width)
|
| 36 |
+
quantization pipeline:
|
| 37 |
+
1. get encoder input (B,C,H,W)
|
| 38 |
+
2. flatten input to (B*H*W,C)
|
| 39 |
+
"""
|
| 40 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
| 41 |
+
z = z.permute(0, 2, 3, 1).contiguous()
|
| 42 |
+
z_flattened = z.view(-1, self.e_dim)
|
| 43 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
| 44 |
+
|
| 45 |
+
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
|
| 46 |
+
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
|
| 47 |
+
torch.matmul(z_flattened, self.embedding.weight.t())
|
| 48 |
+
|
| 49 |
+
# could possible replace this here
|
| 50 |
+
# #\start...
|
| 51 |
+
# find closest encodings
|
| 52 |
+
|
| 53 |
+
min_value, min_encoding_indices = torch.min(d, dim=1)
|
| 54 |
+
|
| 55 |
+
min_encoding_indices = min_encoding_indices.unsqueeze(1)
|
| 56 |
+
|
| 57 |
+
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
|
| 58 |
+
min_encodings.scatter_(1, min_encoding_indices, 1)
|
| 59 |
+
|
| 60 |
+
# dtype min encodings: torch.float32
|
| 61 |
+
# min_encodings shape: torch.Size([2048, 512])
|
| 62 |
+
# min_encoding_indices.shape: torch.Size([2048, 1])
|
| 63 |
+
|
| 64 |
+
# get quantized latent vectors
|
| 65 |
+
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
| 66 |
+
# .........\end
|
| 67 |
+
|
| 68 |
+
# with:
|
| 69 |
+
# .........\start
|
| 70 |
+
# min_encoding_indices = torch.argmin(d, dim=1)
|
| 71 |
+
# z_q = self.embedding(min_encoding_indices)
|
| 72 |
+
# ......\end......... (TODO)
|
| 73 |
+
|
| 74 |
+
# compute loss for embedding
|
| 75 |
+
loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)
|
| 76 |
+
|
| 77 |
+
# preserve gradients
|
| 78 |
+
z_q = z + (z_q - z).detach()
|
| 79 |
+
|
| 80 |
+
# perplexity
|
| 81 |
+
|
| 82 |
+
e_mean = torch.mean(min_encodings, dim=0)
|
| 83 |
+
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
| 84 |
+
|
| 85 |
+
# reshape back to match original input shape
|
| 86 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
| 87 |
+
|
| 88 |
+
return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)
|
| 89 |
+
|
| 90 |
+
def get_codebook_entry(self, indices, shape):
|
| 91 |
+
# shape specifying (batch, height, width, channel)
|
| 92 |
+
# TODO: check for more easy handling with nn.Embedding
|
| 93 |
+
min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
|
| 94 |
+
min_encodings.scatter_(1, indices[:, None], 1)
|
| 95 |
+
|
| 96 |
+
# get quantized latent vectors
|
| 97 |
+
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
| 98 |
+
|
| 99 |
+
if shape is not None:
|
| 100 |
+
z_q = z_q.view(shape)
|
| 101 |
+
|
| 102 |
+
# reshape back to match original input shape
|
| 103 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
| 104 |
+
|
| 105 |
+
return z_q
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# pytorch_diffusion + derived encoder decoder
|
| 109 |
+
def nonlinearity(x):
|
| 110 |
+
# swish
|
| 111 |
+
return x * torch.sigmoid(x)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def Normalize(in_channels):
|
| 115 |
+
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class Upsample(nn.Module):
|
| 119 |
+
|
| 120 |
+
def __init__(self, in_channels, with_conv):
|
| 121 |
+
super().__init__()
|
| 122 |
+
self.with_conv = with_conv
|
| 123 |
+
if self.with_conv:
|
| 124 |
+
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
| 125 |
+
|
| 126 |
+
def forward(self, x):
|
| 127 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode='nearest')
|
| 128 |
+
if self.with_conv:
|
| 129 |
+
x = self.conv(x)
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class Downsample(nn.Module):
|
| 134 |
+
|
| 135 |
+
def __init__(self, in_channels, with_conv):
|
| 136 |
+
super().__init__()
|
| 137 |
+
self.with_conv = with_conv
|
| 138 |
+
if self.with_conv:
|
| 139 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
| 140 |
+
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
| 141 |
+
|
| 142 |
+
def forward(self, x):
|
| 143 |
+
if self.with_conv:
|
| 144 |
+
pad = (0, 1, 0, 1)
|
| 145 |
+
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
|
| 146 |
+
x = self.conv(x)
|
| 147 |
+
else:
|
| 148 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
| 149 |
+
return x
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class ResnetBlock(nn.Module):
|
| 153 |
+
|
| 154 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
|
| 155 |
+
super().__init__()
|
| 156 |
+
self.in_channels = in_channels
|
| 157 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 158 |
+
self.out_channels = out_channels
|
| 159 |
+
self.use_conv_shortcut = conv_shortcut
|
| 160 |
+
|
| 161 |
+
self.norm1 = Normalize(in_channels)
|
| 162 |
+
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 163 |
+
if temb_channels > 0:
|
| 164 |
+
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
| 165 |
+
self.norm2 = Normalize(out_channels)
|
| 166 |
+
self.dropout = torch.nn.Dropout(dropout)
|
| 167 |
+
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 168 |
+
if self.in_channels != self.out_channels:
|
| 169 |
+
if self.use_conv_shortcut:
|
| 170 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 171 |
+
else:
|
| 172 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
| 173 |
+
|
| 174 |
+
def forward(self, x, temb):
|
| 175 |
+
h = x
|
| 176 |
+
h = self.norm1(h)
|
| 177 |
+
h = nonlinearity(h)
|
| 178 |
+
h = self.conv1(h)
|
| 179 |
+
|
| 180 |
+
if temb is not None:
|
| 181 |
+
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
| 182 |
+
|
| 183 |
+
h = self.norm2(h)
|
| 184 |
+
h = nonlinearity(h)
|
| 185 |
+
h = self.dropout(h)
|
| 186 |
+
h = self.conv2(h)
|
| 187 |
+
|
| 188 |
+
if self.in_channels != self.out_channels:
|
| 189 |
+
if self.use_conv_shortcut:
|
| 190 |
+
x = self.conv_shortcut(x)
|
| 191 |
+
else:
|
| 192 |
+
x = self.nin_shortcut(x)
|
| 193 |
+
|
| 194 |
+
return x + h
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class MultiHeadAttnBlock(nn.Module):
|
| 198 |
+
|
| 199 |
+
def __init__(self, in_channels, head_size=1):
|
| 200 |
+
super().__init__()
|
| 201 |
+
self.in_channels = in_channels
|
| 202 |
+
self.head_size = head_size
|
| 203 |
+
self.att_size = in_channels // head_size
|
| 204 |
+
assert (in_channels % head_size == 0), 'The size of head should be divided by the number of channels.'
|
| 205 |
+
|
| 206 |
+
self.norm1 = Normalize(in_channels)
|
| 207 |
+
self.norm2 = Normalize(in_channels)
|
| 208 |
+
|
| 209 |
+
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 210 |
+
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 211 |
+
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 212 |
+
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 213 |
+
self.num = 0
|
| 214 |
+
|
| 215 |
+
def forward(self, x, y=None):
|
| 216 |
+
h_ = x
|
| 217 |
+
h_ = self.norm1(h_)
|
| 218 |
+
if y is None:
|
| 219 |
+
y = h_
|
| 220 |
+
else:
|
| 221 |
+
y = self.norm2(y)
|
| 222 |
+
|
| 223 |
+
q = self.q(y)
|
| 224 |
+
k = self.k(h_)
|
| 225 |
+
v = self.v(h_)
|
| 226 |
+
|
| 227 |
+
# compute attention
|
| 228 |
+
b, c, h, w = q.shape
|
| 229 |
+
q = q.reshape(b, self.head_size, self.att_size, h * w)
|
| 230 |
+
q = q.permute(0, 3, 1, 2) # b, hw, head, att
|
| 231 |
+
|
| 232 |
+
k = k.reshape(b, self.head_size, self.att_size, h * w)
|
| 233 |
+
k = k.permute(0, 3, 1, 2)
|
| 234 |
+
|
| 235 |
+
v = v.reshape(b, self.head_size, self.att_size, h * w)
|
| 236 |
+
v = v.permute(0, 3, 1, 2)
|
| 237 |
+
|
| 238 |
+
q = q.transpose(1, 2)
|
| 239 |
+
v = v.transpose(1, 2)
|
| 240 |
+
k = k.transpose(1, 2).transpose(2, 3)
|
| 241 |
+
|
| 242 |
+
scale = int(self.att_size)**(-0.5)
|
| 243 |
+
q.mul_(scale)
|
| 244 |
+
w_ = torch.matmul(q, k)
|
| 245 |
+
w_ = F.softmax(w_, dim=3)
|
| 246 |
+
|
| 247 |
+
w_ = w_.matmul(v)
|
| 248 |
+
|
| 249 |
+
w_ = w_.transpose(1, 2).contiguous() # [b, h*w, head, att]
|
| 250 |
+
w_ = w_.view(b, h, w, -1)
|
| 251 |
+
w_ = w_.permute(0, 3, 1, 2)
|
| 252 |
+
|
| 253 |
+
w_ = self.proj_out(w_)
|
| 254 |
+
|
| 255 |
+
return x + w_
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class MultiHeadEncoder(nn.Module):
|
| 259 |
+
|
| 260 |
+
def __init__(self,
|
| 261 |
+
ch,
|
| 262 |
+
out_ch,
|
| 263 |
+
ch_mult=(1, 2, 4, 8),
|
| 264 |
+
num_res_blocks=2,
|
| 265 |
+
attn_resolutions=(16, ),
|
| 266 |
+
dropout=0.0,
|
| 267 |
+
resamp_with_conv=True,
|
| 268 |
+
in_channels=3,
|
| 269 |
+
resolution=512,
|
| 270 |
+
z_channels=256,
|
| 271 |
+
double_z=True,
|
| 272 |
+
enable_mid=True,
|
| 273 |
+
head_size=1,
|
| 274 |
+
**ignore_kwargs):
|
| 275 |
+
super().__init__()
|
| 276 |
+
self.ch = ch
|
| 277 |
+
self.temb_ch = 0
|
| 278 |
+
self.num_resolutions = len(ch_mult)
|
| 279 |
+
self.num_res_blocks = num_res_blocks
|
| 280 |
+
self.resolution = resolution
|
| 281 |
+
self.in_channels = in_channels
|
| 282 |
+
self.enable_mid = enable_mid
|
| 283 |
+
|
| 284 |
+
# downsampling
|
| 285 |
+
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
| 286 |
+
|
| 287 |
+
curr_res = resolution
|
| 288 |
+
in_ch_mult = (1, ) + tuple(ch_mult)
|
| 289 |
+
self.down = nn.ModuleList()
|
| 290 |
+
for i_level in range(self.num_resolutions):
|
| 291 |
+
block = nn.ModuleList()
|
| 292 |
+
attn = nn.ModuleList()
|
| 293 |
+
block_in = ch * in_ch_mult[i_level]
|
| 294 |
+
block_out = ch * ch_mult[i_level]
|
| 295 |
+
for i_block in range(self.num_res_blocks):
|
| 296 |
+
block.append(
|
| 297 |
+
ResnetBlock(
|
| 298 |
+
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
| 299 |
+
block_in = block_out
|
| 300 |
+
if curr_res in attn_resolutions:
|
| 301 |
+
attn.append(MultiHeadAttnBlock(block_in, head_size))
|
| 302 |
+
down = nn.Module()
|
| 303 |
+
down.block = block
|
| 304 |
+
down.attn = attn
|
| 305 |
+
if i_level != self.num_resolutions - 1:
|
| 306 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
| 307 |
+
curr_res = curr_res // 2
|
| 308 |
+
self.down.append(down)
|
| 309 |
+
|
| 310 |
+
# middle
|
| 311 |
+
if self.enable_mid:
|
| 312 |
+
self.mid = nn.Module()
|
| 313 |
+
self.mid.block_1 = ResnetBlock(
|
| 314 |
+
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
| 315 |
+
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
|
| 316 |
+
self.mid.block_2 = ResnetBlock(
|
| 317 |
+
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
| 318 |
+
|
| 319 |
+
# end
|
| 320 |
+
self.norm_out = Normalize(block_in)
|
| 321 |
+
self.conv_out = torch.nn.Conv2d(
|
| 322 |
+
block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1)
|
| 323 |
+
|
| 324 |
+
def forward(self, x):
|
| 325 |
+
hs = {}
|
| 326 |
+
# timestep embedding
|
| 327 |
+
temb = None
|
| 328 |
+
|
| 329 |
+
# downsampling
|
| 330 |
+
h = self.conv_in(x)
|
| 331 |
+
hs['in'] = h
|
| 332 |
+
for i_level in range(self.num_resolutions):
|
| 333 |
+
for i_block in range(self.num_res_blocks):
|
| 334 |
+
h = self.down[i_level].block[i_block](h, temb)
|
| 335 |
+
if len(self.down[i_level].attn) > 0:
|
| 336 |
+
h = self.down[i_level].attn[i_block](h)
|
| 337 |
+
|
| 338 |
+
if i_level != self.num_resolutions - 1:
|
| 339 |
+
# hs.append(h)
|
| 340 |
+
hs['block_' + str(i_level)] = h
|
| 341 |
+
h = self.down[i_level].downsample(h)
|
| 342 |
+
|
| 343 |
+
# middle
|
| 344 |
+
# h = hs[-1]
|
| 345 |
+
if self.enable_mid:
|
| 346 |
+
h = self.mid.block_1(h, temb)
|
| 347 |
+
hs['block_' + str(i_level) + '_atten'] = h
|
| 348 |
+
h = self.mid.attn_1(h)
|
| 349 |
+
h = self.mid.block_2(h, temb)
|
| 350 |
+
hs['mid_atten'] = h
|
| 351 |
+
|
| 352 |
+
# end
|
| 353 |
+
h = self.norm_out(h)
|
| 354 |
+
h = nonlinearity(h)
|
| 355 |
+
h = self.conv_out(h)
|
| 356 |
+
# hs.append(h)
|
| 357 |
+
hs['out'] = h
|
| 358 |
+
|
| 359 |
+
return hs
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
class MultiHeadDecoder(nn.Module):
|
| 363 |
+
|
| 364 |
+
def __init__(self,
|
| 365 |
+
ch,
|
| 366 |
+
out_ch,
|
| 367 |
+
ch_mult=(1, 2, 4, 8),
|
| 368 |
+
num_res_blocks=2,
|
| 369 |
+
attn_resolutions=(16, ),
|
| 370 |
+
dropout=0.0,
|
| 371 |
+
resamp_with_conv=True,
|
| 372 |
+
in_channels=3,
|
| 373 |
+
resolution=512,
|
| 374 |
+
z_channels=256,
|
| 375 |
+
give_pre_end=False,
|
| 376 |
+
enable_mid=True,
|
| 377 |
+
head_size=1,
|
| 378 |
+
**ignorekwargs):
|
| 379 |
+
super().__init__()
|
| 380 |
+
self.ch = ch
|
| 381 |
+
self.temb_ch = 0
|
| 382 |
+
self.num_resolutions = len(ch_mult)
|
| 383 |
+
self.num_res_blocks = num_res_blocks
|
| 384 |
+
self.resolution = resolution
|
| 385 |
+
self.in_channels = in_channels
|
| 386 |
+
self.give_pre_end = give_pre_end
|
| 387 |
+
self.enable_mid = enable_mid
|
| 388 |
+
|
| 389 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
| 390 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
| 391 |
+
curr_res = resolution // 2**(self.num_resolutions - 1)
|
| 392 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
| 393 |
+
print('Working with z of shape {} = {} dimensions.'.format(self.z_shape, np.prod(self.z_shape)))
|
| 394 |
+
|
| 395 |
+
# z to block_in
|
| 396 |
+
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
| 397 |
+
|
| 398 |
+
# middle
|
| 399 |
+
if self.enable_mid:
|
| 400 |
+
self.mid = nn.Module()
|
| 401 |
+
self.mid.block_1 = ResnetBlock(
|
| 402 |
+
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
| 403 |
+
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
|
| 404 |
+
self.mid.block_2 = ResnetBlock(
|
| 405 |
+
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
| 406 |
+
|
| 407 |
+
# upsampling
|
| 408 |
+
self.up = nn.ModuleList()
|
| 409 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 410 |
+
block = nn.ModuleList()
|
| 411 |
+
attn = nn.ModuleList()
|
| 412 |
+
block_out = ch * ch_mult[i_level]
|
| 413 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 414 |
+
block.append(
|
| 415 |
+
ResnetBlock(
|
| 416 |
+
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
| 417 |
+
block_in = block_out
|
| 418 |
+
if curr_res in attn_resolutions:
|
| 419 |
+
attn.append(MultiHeadAttnBlock(block_in, head_size))
|
| 420 |
+
up = nn.Module()
|
| 421 |
+
up.block = block
|
| 422 |
+
up.attn = attn
|
| 423 |
+
if i_level != 0:
|
| 424 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
| 425 |
+
curr_res = curr_res * 2
|
| 426 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 427 |
+
|
| 428 |
+
# end
|
| 429 |
+
self.norm_out = Normalize(block_in)
|
| 430 |
+
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
| 431 |
+
|
| 432 |
+
def forward(self, z):
|
| 433 |
+
# assert z.shape[1:] == self.z_shape[1:]
|
| 434 |
+
self.last_z_shape = z.shape
|
| 435 |
+
|
| 436 |
+
# timestep embedding
|
| 437 |
+
temb = None
|
| 438 |
+
|
| 439 |
+
# z to block_in
|
| 440 |
+
h = self.conv_in(z)
|
| 441 |
+
|
| 442 |
+
# middle
|
| 443 |
+
if self.enable_mid:
|
| 444 |
+
h = self.mid.block_1(h, temb)
|
| 445 |
+
h = self.mid.attn_1(h)
|
| 446 |
+
h = self.mid.block_2(h, temb)
|
| 447 |
+
|
| 448 |
+
# upsampling
|
| 449 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 450 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 451 |
+
h = self.up[i_level].block[i_block](h, temb)
|
| 452 |
+
if len(self.up[i_level].attn) > 0:
|
| 453 |
+
h = self.up[i_level].attn[i_block](h)
|
| 454 |
+
if i_level != 0:
|
| 455 |
+
h = self.up[i_level].upsample(h)
|
| 456 |
+
|
| 457 |
+
# end
|
| 458 |
+
if self.give_pre_end:
|
| 459 |
+
return h
|
| 460 |
+
|
| 461 |
+
h = self.norm_out(h)
|
| 462 |
+
h = nonlinearity(h)
|
| 463 |
+
h = self.conv_out(h)
|
| 464 |
+
return h
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
class MultiHeadDecoderTransformer(nn.Module):
|
| 468 |
+
|
| 469 |
+
def __init__(self,
|
| 470 |
+
ch,
|
| 471 |
+
out_ch,
|
| 472 |
+
ch_mult=(1, 2, 4, 8),
|
| 473 |
+
num_res_blocks=2,
|
| 474 |
+
attn_resolutions=(16, ),
|
| 475 |
+
dropout=0.0,
|
| 476 |
+
resamp_with_conv=True,
|
| 477 |
+
in_channels=3,
|
| 478 |
+
resolution=512,
|
| 479 |
+
z_channels=256,
|
| 480 |
+
give_pre_end=False,
|
| 481 |
+
enable_mid=True,
|
| 482 |
+
head_size=1,
|
| 483 |
+
**ignorekwargs):
|
| 484 |
+
super().__init__()
|
| 485 |
+
self.ch = ch
|
| 486 |
+
self.temb_ch = 0
|
| 487 |
+
self.num_resolutions = len(ch_mult)
|
| 488 |
+
self.num_res_blocks = num_res_blocks
|
| 489 |
+
self.resolution = resolution
|
| 490 |
+
self.in_channels = in_channels
|
| 491 |
+
self.give_pre_end = give_pre_end
|
| 492 |
+
self.enable_mid = enable_mid
|
| 493 |
+
|
| 494 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
| 495 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
| 496 |
+
curr_res = resolution // 2**(self.num_resolutions - 1)
|
| 497 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
| 498 |
+
print('Working with z of shape {} = {} dimensions.'.format(self.z_shape, np.prod(self.z_shape)))
|
| 499 |
+
|
| 500 |
+
# z to block_in
|
| 501 |
+
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
| 502 |
+
|
| 503 |
+
# middle
|
| 504 |
+
if self.enable_mid:
|
| 505 |
+
self.mid = nn.Module()
|
| 506 |
+
self.mid.block_1 = ResnetBlock(
|
| 507 |
+
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
| 508 |
+
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
|
| 509 |
+
self.mid.block_2 = ResnetBlock(
|
| 510 |
+
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
| 511 |
+
|
| 512 |
+
# upsampling
|
| 513 |
+
self.up = nn.ModuleList()
|
| 514 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 515 |
+
block = nn.ModuleList()
|
| 516 |
+
attn = nn.ModuleList()
|
| 517 |
+
block_out = ch * ch_mult[i_level]
|
| 518 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 519 |
+
block.append(
|
| 520 |
+
ResnetBlock(
|
| 521 |
+
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
| 522 |
+
block_in = block_out
|
| 523 |
+
if curr_res in attn_resolutions:
|
| 524 |
+
attn.append(MultiHeadAttnBlock(block_in, head_size))
|
| 525 |
+
up = nn.Module()
|
| 526 |
+
up.block = block
|
| 527 |
+
up.attn = attn
|
| 528 |
+
if i_level != 0:
|
| 529 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
| 530 |
+
curr_res = curr_res * 2
|
| 531 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 532 |
+
|
| 533 |
+
# end
|
| 534 |
+
self.norm_out = Normalize(block_in)
|
| 535 |
+
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
| 536 |
+
|
| 537 |
+
def forward(self, z, hs):
|
| 538 |
+
# assert z.shape[1:] == self.z_shape[1:]
|
| 539 |
+
# self.last_z_shape = z.shape
|
| 540 |
+
|
| 541 |
+
# timestep embedding
|
| 542 |
+
temb = None
|
| 543 |
+
|
| 544 |
+
# z to block_in
|
| 545 |
+
h = self.conv_in(z)
|
| 546 |
+
|
| 547 |
+
# middle
|
| 548 |
+
if self.enable_mid:
|
| 549 |
+
h = self.mid.block_1(h, temb)
|
| 550 |
+
h = self.mid.attn_1(h, hs['mid_atten'])
|
| 551 |
+
h = self.mid.block_2(h, temb)
|
| 552 |
+
|
| 553 |
+
# upsampling
|
| 554 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 555 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 556 |
+
h = self.up[i_level].block[i_block](h, temb)
|
| 557 |
+
if len(self.up[i_level].attn) > 0:
|
| 558 |
+
h = self.up[i_level].attn[i_block](h, hs['block_' + str(i_level) + '_atten'])
|
| 559 |
+
# hfeature = h.clone()
|
| 560 |
+
if i_level != 0:
|
| 561 |
+
h = self.up[i_level].upsample(h)
|
| 562 |
+
|
| 563 |
+
# end
|
| 564 |
+
if self.give_pre_end:
|
| 565 |
+
return h
|
| 566 |
+
|
| 567 |
+
h = self.norm_out(h)
|
| 568 |
+
h = nonlinearity(h)
|
| 569 |
+
h = self.conv_out(h)
|
| 570 |
+
return h
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
class RestoreFormer(nn.Module):
|
| 574 |
+
|
| 575 |
+
def __init__(self,
|
| 576 |
+
n_embed=1024,
|
| 577 |
+
embed_dim=256,
|
| 578 |
+
ch=64,
|
| 579 |
+
out_ch=3,
|
| 580 |
+
ch_mult=(1, 2, 2, 4, 4, 8),
|
| 581 |
+
num_res_blocks=2,
|
| 582 |
+
attn_resolutions=(16, ),
|
| 583 |
+
dropout=0.0,
|
| 584 |
+
in_channels=3,
|
| 585 |
+
resolution=512,
|
| 586 |
+
z_channels=256,
|
| 587 |
+
double_z=False,
|
| 588 |
+
enable_mid=True,
|
| 589 |
+
fix_decoder=False,
|
| 590 |
+
fix_codebook=True,
|
| 591 |
+
fix_encoder=False,
|
| 592 |
+
head_size=8):
|
| 593 |
+
super(RestoreFormer, self).__init__()
|
| 594 |
+
|
| 595 |
+
self.encoder = MultiHeadEncoder(
|
| 596 |
+
ch=ch,
|
| 597 |
+
out_ch=out_ch,
|
| 598 |
+
ch_mult=ch_mult,
|
| 599 |
+
num_res_blocks=num_res_blocks,
|
| 600 |
+
attn_resolutions=attn_resolutions,
|
| 601 |
+
dropout=dropout,
|
| 602 |
+
in_channels=in_channels,
|
| 603 |
+
resolution=resolution,
|
| 604 |
+
z_channels=z_channels,
|
| 605 |
+
double_z=double_z,
|
| 606 |
+
enable_mid=enable_mid,
|
| 607 |
+
head_size=head_size)
|
| 608 |
+
self.decoder = MultiHeadDecoderTransformer(
|
| 609 |
+
ch=ch,
|
| 610 |
+
out_ch=out_ch,
|
| 611 |
+
ch_mult=ch_mult,
|
| 612 |
+
num_res_blocks=num_res_blocks,
|
| 613 |
+
attn_resolutions=attn_resolutions,
|
| 614 |
+
dropout=dropout,
|
| 615 |
+
in_channels=in_channels,
|
| 616 |
+
resolution=resolution,
|
| 617 |
+
z_channels=z_channels,
|
| 618 |
+
enable_mid=enable_mid,
|
| 619 |
+
head_size=head_size)
|
| 620 |
+
|
| 621 |
+
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
|
| 622 |
+
|
| 623 |
+
self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
|
| 624 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
|
| 625 |
+
|
| 626 |
+
if fix_decoder:
|
| 627 |
+
for _, param in self.decoder.named_parameters():
|
| 628 |
+
param.requires_grad = False
|
| 629 |
+
for _, param in self.post_quant_conv.named_parameters():
|
| 630 |
+
param.requires_grad = False
|
| 631 |
+
for _, param in self.quantize.named_parameters():
|
| 632 |
+
param.requires_grad = False
|
| 633 |
+
elif fix_codebook:
|
| 634 |
+
for _, param in self.quantize.named_parameters():
|
| 635 |
+
param.requires_grad = False
|
| 636 |
+
|
| 637 |
+
if fix_encoder:
|
| 638 |
+
for _, param in self.encoder.named_parameters():
|
| 639 |
+
param.requires_grad = False
|
| 640 |
+
|
| 641 |
+
def encode(self, x):
|
| 642 |
+
|
| 643 |
+
hs = self.encoder(x)
|
| 644 |
+
h = self.quant_conv(hs['out'])
|
| 645 |
+
quant, emb_loss, info = self.quantize(h)
|
| 646 |
+
return quant, emb_loss, info, hs
|
| 647 |
+
|
| 648 |
+
def decode(self, quant, hs):
|
| 649 |
+
quant = self.post_quant_conv(quant)
|
| 650 |
+
dec = self.decoder(quant, hs)
|
| 651 |
+
|
| 652 |
+
return dec
|
| 653 |
+
|
| 654 |
+
def forward(self, input, **kwargs):
|
| 655 |
+
quant, diff, info, hs = self.encode(input)
|
| 656 |
+
dec = self.decode(quant, hs)
|
| 657 |
+
|
| 658 |
+
return dec, None
|
gfpgan/archs/stylegan2_bilinear_arch.py
ADDED
|
@@ -0,0 +1,613 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
|
| 5 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class NormStyleCode(nn.Module):
|
| 11 |
+
|
| 12 |
+
def forward(self, x):
|
| 13 |
+
"""Normalize the style codes.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
x (Tensor): Style codes with shape (b, c).
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
Tensor: Normalized tensor.
|
| 20 |
+
"""
|
| 21 |
+
return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class EqualLinear(nn.Module):
|
| 25 |
+
"""Equalized Linear as StyleGAN2.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
in_channels (int): Size of each sample.
|
| 29 |
+
out_channels (int): Size of each output sample.
|
| 30 |
+
bias (bool): If set to ``False``, the layer will not learn an additive
|
| 31 |
+
bias. Default: ``True``.
|
| 32 |
+
bias_init_val (float): Bias initialized value. Default: 0.
|
| 33 |
+
lr_mul (float): Learning rate multiplier. Default: 1.
|
| 34 |
+
activation (None | str): The activation after ``linear`` operation.
|
| 35 |
+
Supported: 'fused_lrelu', None. Default: None.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
|
| 39 |
+
super(EqualLinear, self).__init__()
|
| 40 |
+
self.in_channels = in_channels
|
| 41 |
+
self.out_channels = out_channels
|
| 42 |
+
self.lr_mul = lr_mul
|
| 43 |
+
self.activation = activation
|
| 44 |
+
if self.activation not in ['fused_lrelu', None]:
|
| 45 |
+
raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
|
| 46 |
+
"Supported ones are: ['fused_lrelu', None].")
|
| 47 |
+
self.scale = (1 / math.sqrt(in_channels)) * lr_mul
|
| 48 |
+
|
| 49 |
+
self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
|
| 50 |
+
if bias:
|
| 51 |
+
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
|
| 52 |
+
else:
|
| 53 |
+
self.register_parameter('bias', None)
|
| 54 |
+
|
| 55 |
+
def forward(self, x):
|
| 56 |
+
if self.bias is None:
|
| 57 |
+
bias = None
|
| 58 |
+
else:
|
| 59 |
+
bias = self.bias * self.lr_mul
|
| 60 |
+
if self.activation == 'fused_lrelu':
|
| 61 |
+
out = F.linear(x, self.weight * self.scale)
|
| 62 |
+
out = fused_leaky_relu(out, bias)
|
| 63 |
+
else:
|
| 64 |
+
out = F.linear(x, self.weight * self.scale, bias=bias)
|
| 65 |
+
return out
|
| 66 |
+
|
| 67 |
+
def __repr__(self):
|
| 68 |
+
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
|
| 69 |
+
f'out_channels={self.out_channels}, bias={self.bias is not None})')
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class ModulatedConv2d(nn.Module):
|
| 73 |
+
"""Modulated Conv2d used in StyleGAN2.
|
| 74 |
+
|
| 75 |
+
There is no bias in ModulatedConv2d.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
in_channels (int): Channel number of the input.
|
| 79 |
+
out_channels (int): Channel number of the output.
|
| 80 |
+
kernel_size (int): Size of the convolving kernel.
|
| 81 |
+
num_style_feat (int): Channel number of style features.
|
| 82 |
+
demodulate (bool): Whether to demodulate in the conv layer.
|
| 83 |
+
Default: True.
|
| 84 |
+
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
|
| 85 |
+
Default: None.
|
| 86 |
+
eps (float): A value added to the denominator for numerical stability.
|
| 87 |
+
Default: 1e-8.
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
def __init__(self,
|
| 91 |
+
in_channels,
|
| 92 |
+
out_channels,
|
| 93 |
+
kernel_size,
|
| 94 |
+
num_style_feat,
|
| 95 |
+
demodulate=True,
|
| 96 |
+
sample_mode=None,
|
| 97 |
+
eps=1e-8,
|
| 98 |
+
interpolation_mode='bilinear'):
|
| 99 |
+
super(ModulatedConv2d, self).__init__()
|
| 100 |
+
self.in_channels = in_channels
|
| 101 |
+
self.out_channels = out_channels
|
| 102 |
+
self.kernel_size = kernel_size
|
| 103 |
+
self.demodulate = demodulate
|
| 104 |
+
self.sample_mode = sample_mode
|
| 105 |
+
self.eps = eps
|
| 106 |
+
self.interpolation_mode = interpolation_mode
|
| 107 |
+
if self.interpolation_mode == 'nearest':
|
| 108 |
+
self.align_corners = None
|
| 109 |
+
else:
|
| 110 |
+
self.align_corners = False
|
| 111 |
+
|
| 112 |
+
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
| 113 |
+
# modulation inside each modulated conv
|
| 114 |
+
self.modulation = EqualLinear(
|
| 115 |
+
num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
|
| 116 |
+
|
| 117 |
+
self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
|
| 118 |
+
self.padding = kernel_size // 2
|
| 119 |
+
|
| 120 |
+
def forward(self, x, style):
|
| 121 |
+
"""Forward function.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
x (Tensor): Tensor with shape (b, c, h, w).
|
| 125 |
+
style (Tensor): Tensor with shape (b, num_style_feat).
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
Tensor: Modulated tensor after convolution.
|
| 129 |
+
"""
|
| 130 |
+
b, c, h, w = x.shape # c = c_in
|
| 131 |
+
# weight modulation
|
| 132 |
+
style = self.modulation(style).view(b, 1, c, 1, 1)
|
| 133 |
+
# self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
|
| 134 |
+
weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
|
| 135 |
+
|
| 136 |
+
if self.demodulate:
|
| 137 |
+
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
|
| 138 |
+
weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
|
| 139 |
+
|
| 140 |
+
weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
|
| 141 |
+
|
| 142 |
+
if self.sample_mode == 'upsample':
|
| 143 |
+
x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
|
| 144 |
+
elif self.sample_mode == 'downsample':
|
| 145 |
+
x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners)
|
| 146 |
+
|
| 147 |
+
b, c, h, w = x.shape
|
| 148 |
+
x = x.view(1, b * c, h, w)
|
| 149 |
+
# weight: (b*c_out, c_in, k, k), groups=b
|
| 150 |
+
out = F.conv2d(x, weight, padding=self.padding, groups=b)
|
| 151 |
+
out = out.view(b, self.out_channels, *out.shape[2:4])
|
| 152 |
+
|
| 153 |
+
return out
|
| 154 |
+
|
| 155 |
+
def __repr__(self):
|
| 156 |
+
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
|
| 157 |
+
f'out_channels={self.out_channels}, '
|
| 158 |
+
f'kernel_size={self.kernel_size}, '
|
| 159 |
+
f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class StyleConv(nn.Module):
|
| 163 |
+
"""Style conv.
|
| 164 |
+
|
| 165 |
+
Args:
|
| 166 |
+
in_channels (int): Channel number of the input.
|
| 167 |
+
out_channels (int): Channel number of the output.
|
| 168 |
+
kernel_size (int): Size of the convolving kernel.
|
| 169 |
+
num_style_feat (int): Channel number of style features.
|
| 170 |
+
demodulate (bool): Whether demodulate in the conv layer. Default: True.
|
| 171 |
+
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
|
| 172 |
+
Default: None.
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
def __init__(self,
|
| 176 |
+
in_channels,
|
| 177 |
+
out_channels,
|
| 178 |
+
kernel_size,
|
| 179 |
+
num_style_feat,
|
| 180 |
+
demodulate=True,
|
| 181 |
+
sample_mode=None,
|
| 182 |
+
interpolation_mode='bilinear'):
|
| 183 |
+
super(StyleConv, self).__init__()
|
| 184 |
+
self.modulated_conv = ModulatedConv2d(
|
| 185 |
+
in_channels,
|
| 186 |
+
out_channels,
|
| 187 |
+
kernel_size,
|
| 188 |
+
num_style_feat,
|
| 189 |
+
demodulate=demodulate,
|
| 190 |
+
sample_mode=sample_mode,
|
| 191 |
+
interpolation_mode=interpolation_mode)
|
| 192 |
+
self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
|
| 193 |
+
self.activate = FusedLeakyReLU(out_channels)
|
| 194 |
+
|
| 195 |
+
def forward(self, x, style, noise=None):
|
| 196 |
+
# modulate
|
| 197 |
+
out = self.modulated_conv(x, style)
|
| 198 |
+
# noise injection
|
| 199 |
+
if noise is None:
|
| 200 |
+
b, _, h, w = out.shape
|
| 201 |
+
noise = out.new_empty(b, 1, h, w).normal_()
|
| 202 |
+
out = out + self.weight * noise
|
| 203 |
+
# activation (with bias)
|
| 204 |
+
out = self.activate(out)
|
| 205 |
+
return out
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class ToRGB(nn.Module):
|
| 209 |
+
"""To RGB from features.
|
| 210 |
+
|
| 211 |
+
Args:
|
| 212 |
+
in_channels (int): Channel number of input.
|
| 213 |
+
num_style_feat (int): Channel number of style features.
|
| 214 |
+
upsample (bool): Whether to upsample. Default: True.
|
| 215 |
+
"""
|
| 216 |
+
|
| 217 |
+
def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'):
|
| 218 |
+
super(ToRGB, self).__init__()
|
| 219 |
+
self.upsample = upsample
|
| 220 |
+
self.interpolation_mode = interpolation_mode
|
| 221 |
+
if self.interpolation_mode == 'nearest':
|
| 222 |
+
self.align_corners = None
|
| 223 |
+
else:
|
| 224 |
+
self.align_corners = False
|
| 225 |
+
self.modulated_conv = ModulatedConv2d(
|
| 226 |
+
in_channels,
|
| 227 |
+
3,
|
| 228 |
+
kernel_size=1,
|
| 229 |
+
num_style_feat=num_style_feat,
|
| 230 |
+
demodulate=False,
|
| 231 |
+
sample_mode=None,
|
| 232 |
+
interpolation_mode=interpolation_mode)
|
| 233 |
+
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
| 234 |
+
|
| 235 |
+
def forward(self, x, style, skip=None):
|
| 236 |
+
"""Forward function.
|
| 237 |
+
|
| 238 |
+
Args:
|
| 239 |
+
x (Tensor): Feature tensor with shape (b, c, h, w).
|
| 240 |
+
style (Tensor): Tensor with shape (b, num_style_feat).
|
| 241 |
+
skip (Tensor): Base/skip tensor. Default: None.
|
| 242 |
+
|
| 243 |
+
Returns:
|
| 244 |
+
Tensor: RGB images.
|
| 245 |
+
"""
|
| 246 |
+
out = self.modulated_conv(x, style)
|
| 247 |
+
out = out + self.bias
|
| 248 |
+
if skip is not None:
|
| 249 |
+
if self.upsample:
|
| 250 |
+
skip = F.interpolate(
|
| 251 |
+
skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
|
| 252 |
+
out = out + skip
|
| 253 |
+
return out
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class ConstantInput(nn.Module):
|
| 257 |
+
"""Constant input.
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
num_channel (int): Channel number of constant input.
|
| 261 |
+
size (int): Spatial size of constant input.
|
| 262 |
+
"""
|
| 263 |
+
|
| 264 |
+
def __init__(self, num_channel, size):
|
| 265 |
+
super(ConstantInput, self).__init__()
|
| 266 |
+
self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
|
| 267 |
+
|
| 268 |
+
def forward(self, batch):
|
| 269 |
+
out = self.weight.repeat(batch, 1, 1, 1)
|
| 270 |
+
return out
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
@ARCH_REGISTRY.register()
|
| 274 |
+
class StyleGAN2GeneratorBilinear(nn.Module):
|
| 275 |
+
"""StyleGAN2 Generator.
|
| 276 |
+
|
| 277 |
+
Args:
|
| 278 |
+
out_size (int): The spatial size of outputs.
|
| 279 |
+
num_style_feat (int): Channel number of style features. Default: 512.
|
| 280 |
+
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
| 281 |
+
channel_multiplier (int): Channel multiplier for large networks of
|
| 282 |
+
StyleGAN2. Default: 2.
|
| 283 |
+
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
| 284 |
+
narrow (float): Narrow ratio for channels. Default: 1.0.
|
| 285 |
+
"""
|
| 286 |
+
|
| 287 |
+
def __init__(self,
|
| 288 |
+
out_size,
|
| 289 |
+
num_style_feat=512,
|
| 290 |
+
num_mlp=8,
|
| 291 |
+
channel_multiplier=2,
|
| 292 |
+
lr_mlp=0.01,
|
| 293 |
+
narrow=1,
|
| 294 |
+
interpolation_mode='bilinear'):
|
| 295 |
+
super(StyleGAN2GeneratorBilinear, self).__init__()
|
| 296 |
+
# Style MLP layers
|
| 297 |
+
self.num_style_feat = num_style_feat
|
| 298 |
+
style_mlp_layers = [NormStyleCode()]
|
| 299 |
+
for i in range(num_mlp):
|
| 300 |
+
style_mlp_layers.append(
|
| 301 |
+
EqualLinear(
|
| 302 |
+
num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
|
| 303 |
+
activation='fused_lrelu'))
|
| 304 |
+
self.style_mlp = nn.Sequential(*style_mlp_layers)
|
| 305 |
+
|
| 306 |
+
channels = {
|
| 307 |
+
'4': int(512 * narrow),
|
| 308 |
+
'8': int(512 * narrow),
|
| 309 |
+
'16': int(512 * narrow),
|
| 310 |
+
'32': int(512 * narrow),
|
| 311 |
+
'64': int(256 * channel_multiplier * narrow),
|
| 312 |
+
'128': int(128 * channel_multiplier * narrow),
|
| 313 |
+
'256': int(64 * channel_multiplier * narrow),
|
| 314 |
+
'512': int(32 * channel_multiplier * narrow),
|
| 315 |
+
'1024': int(16 * channel_multiplier * narrow)
|
| 316 |
+
}
|
| 317 |
+
self.channels = channels
|
| 318 |
+
|
| 319 |
+
self.constant_input = ConstantInput(channels['4'], size=4)
|
| 320 |
+
self.style_conv1 = StyleConv(
|
| 321 |
+
channels['4'],
|
| 322 |
+
channels['4'],
|
| 323 |
+
kernel_size=3,
|
| 324 |
+
num_style_feat=num_style_feat,
|
| 325 |
+
demodulate=True,
|
| 326 |
+
sample_mode=None,
|
| 327 |
+
interpolation_mode=interpolation_mode)
|
| 328 |
+
self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)
|
| 329 |
+
|
| 330 |
+
self.log_size = int(math.log(out_size, 2))
|
| 331 |
+
self.num_layers = (self.log_size - 2) * 2 + 1
|
| 332 |
+
self.num_latent = self.log_size * 2 - 2
|
| 333 |
+
|
| 334 |
+
self.style_convs = nn.ModuleList()
|
| 335 |
+
self.to_rgbs = nn.ModuleList()
|
| 336 |
+
self.noises = nn.Module()
|
| 337 |
+
|
| 338 |
+
in_channels = channels['4']
|
| 339 |
+
# noise
|
| 340 |
+
for layer_idx in range(self.num_layers):
|
| 341 |
+
resolution = 2**((layer_idx + 5) // 2)
|
| 342 |
+
shape = [1, 1, resolution, resolution]
|
| 343 |
+
self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
|
| 344 |
+
# style convs and to_rgbs
|
| 345 |
+
for i in range(3, self.log_size + 1):
|
| 346 |
+
out_channels = channels[f'{2**i}']
|
| 347 |
+
self.style_convs.append(
|
| 348 |
+
StyleConv(
|
| 349 |
+
in_channels,
|
| 350 |
+
out_channels,
|
| 351 |
+
kernel_size=3,
|
| 352 |
+
num_style_feat=num_style_feat,
|
| 353 |
+
demodulate=True,
|
| 354 |
+
sample_mode='upsample',
|
| 355 |
+
interpolation_mode=interpolation_mode))
|
| 356 |
+
self.style_convs.append(
|
| 357 |
+
StyleConv(
|
| 358 |
+
out_channels,
|
| 359 |
+
out_channels,
|
| 360 |
+
kernel_size=3,
|
| 361 |
+
num_style_feat=num_style_feat,
|
| 362 |
+
demodulate=True,
|
| 363 |
+
sample_mode=None,
|
| 364 |
+
interpolation_mode=interpolation_mode))
|
| 365 |
+
self.to_rgbs.append(
|
| 366 |
+
ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
|
| 367 |
+
in_channels = out_channels
|
| 368 |
+
|
| 369 |
+
def make_noise(self):
|
| 370 |
+
"""Make noise for noise injection."""
|
| 371 |
+
device = self.constant_input.weight.device
|
| 372 |
+
noises = [torch.randn(1, 1, 4, 4, device=device)]
|
| 373 |
+
|
| 374 |
+
for i in range(3, self.log_size + 1):
|
| 375 |
+
for _ in range(2):
|
| 376 |
+
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
|
| 377 |
+
|
| 378 |
+
return noises
|
| 379 |
+
|
| 380 |
+
def get_latent(self, x):
|
| 381 |
+
return self.style_mlp(x)
|
| 382 |
+
|
| 383 |
+
def mean_latent(self, num_latent):
|
| 384 |
+
latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
|
| 385 |
+
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
|
| 386 |
+
return latent
|
| 387 |
+
|
| 388 |
+
def forward(self,
|
| 389 |
+
styles,
|
| 390 |
+
input_is_latent=False,
|
| 391 |
+
noise=None,
|
| 392 |
+
randomize_noise=True,
|
| 393 |
+
truncation=1,
|
| 394 |
+
truncation_latent=None,
|
| 395 |
+
inject_index=None,
|
| 396 |
+
return_latents=False):
|
| 397 |
+
"""Forward function for StyleGAN2Generator.
|
| 398 |
+
|
| 399 |
+
Args:
|
| 400 |
+
styles (list[Tensor]): Sample codes of styles.
|
| 401 |
+
input_is_latent (bool): Whether input is latent style.
|
| 402 |
+
Default: False.
|
| 403 |
+
noise (Tensor | None): Input noise or None. Default: None.
|
| 404 |
+
randomize_noise (bool): Randomize noise, used when 'noise' is
|
| 405 |
+
False. Default: True.
|
| 406 |
+
truncation (float): TODO. Default: 1.
|
| 407 |
+
truncation_latent (Tensor | None): TODO. Default: None.
|
| 408 |
+
inject_index (int | None): The injection index for mixing noise.
|
| 409 |
+
Default: None.
|
| 410 |
+
return_latents (bool): Whether to return style latents.
|
| 411 |
+
Default: False.
|
| 412 |
+
"""
|
| 413 |
+
# style codes -> latents with Style MLP layer
|
| 414 |
+
if not input_is_latent:
|
| 415 |
+
styles = [self.style_mlp(s) for s in styles]
|
| 416 |
+
# noises
|
| 417 |
+
if noise is None:
|
| 418 |
+
if randomize_noise:
|
| 419 |
+
noise = [None] * self.num_layers # for each style conv layer
|
| 420 |
+
else: # use the stored noise
|
| 421 |
+
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
|
| 422 |
+
# style truncation
|
| 423 |
+
if truncation < 1:
|
| 424 |
+
style_truncation = []
|
| 425 |
+
for style in styles:
|
| 426 |
+
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
| 427 |
+
styles = style_truncation
|
| 428 |
+
# get style latent with injection
|
| 429 |
+
if len(styles) == 1:
|
| 430 |
+
inject_index = self.num_latent
|
| 431 |
+
|
| 432 |
+
if styles[0].ndim < 3:
|
| 433 |
+
# repeat latent code for all the layers
|
| 434 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
| 435 |
+
else: # used for encoder with different latent code for each layer
|
| 436 |
+
latent = styles[0]
|
| 437 |
+
elif len(styles) == 2: # mixing noises
|
| 438 |
+
if inject_index is None:
|
| 439 |
+
inject_index = random.randint(1, self.num_latent - 1)
|
| 440 |
+
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
| 441 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
| 442 |
+
latent = torch.cat([latent1, latent2], 1)
|
| 443 |
+
|
| 444 |
+
# main generation
|
| 445 |
+
out = self.constant_input(latent.shape[0])
|
| 446 |
+
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
| 447 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
| 448 |
+
|
| 449 |
+
i = 1
|
| 450 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
|
| 451 |
+
noise[2::2], self.to_rgbs):
|
| 452 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
| 453 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
| 454 |
+
skip = to_rgb(out, latent[:, i + 2], skip)
|
| 455 |
+
i += 2
|
| 456 |
+
|
| 457 |
+
image = skip
|
| 458 |
+
|
| 459 |
+
if return_latents:
|
| 460 |
+
return image, latent
|
| 461 |
+
else:
|
| 462 |
+
return image, None
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
class ScaledLeakyReLU(nn.Module):
|
| 466 |
+
"""Scaled LeakyReLU.
|
| 467 |
+
|
| 468 |
+
Args:
|
| 469 |
+
negative_slope (float): Negative slope. Default: 0.2.
|
| 470 |
+
"""
|
| 471 |
+
|
| 472 |
+
def __init__(self, negative_slope=0.2):
|
| 473 |
+
super(ScaledLeakyReLU, self).__init__()
|
| 474 |
+
self.negative_slope = negative_slope
|
| 475 |
+
|
| 476 |
+
def forward(self, x):
|
| 477 |
+
out = F.leaky_relu(x, negative_slope=self.negative_slope)
|
| 478 |
+
return out * math.sqrt(2)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
class EqualConv2d(nn.Module):
|
| 482 |
+
"""Equalized Linear as StyleGAN2.
|
| 483 |
+
|
| 484 |
+
Args:
|
| 485 |
+
in_channels (int): Channel number of the input.
|
| 486 |
+
out_channels (int): Channel number of the output.
|
| 487 |
+
kernel_size (int): Size of the convolving kernel.
|
| 488 |
+
stride (int): Stride of the convolution. Default: 1
|
| 489 |
+
padding (int): Zero-padding added to both sides of the input.
|
| 490 |
+
Default: 0.
|
| 491 |
+
bias (bool): If ``True``, adds a learnable bias to the output.
|
| 492 |
+
Default: ``True``.
|
| 493 |
+
bias_init_val (float): Bias initialized value. Default: 0.
|
| 494 |
+
"""
|
| 495 |
+
|
| 496 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
|
| 497 |
+
super(EqualConv2d, self).__init__()
|
| 498 |
+
self.in_channels = in_channels
|
| 499 |
+
self.out_channels = out_channels
|
| 500 |
+
self.kernel_size = kernel_size
|
| 501 |
+
self.stride = stride
|
| 502 |
+
self.padding = padding
|
| 503 |
+
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
| 504 |
+
|
| 505 |
+
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
|
| 506 |
+
if bias:
|
| 507 |
+
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
|
| 508 |
+
else:
|
| 509 |
+
self.register_parameter('bias', None)
|
| 510 |
+
|
| 511 |
+
def forward(self, x):
|
| 512 |
+
out = F.conv2d(
|
| 513 |
+
x,
|
| 514 |
+
self.weight * self.scale,
|
| 515 |
+
bias=self.bias,
|
| 516 |
+
stride=self.stride,
|
| 517 |
+
padding=self.padding,
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
return out
|
| 521 |
+
|
| 522 |
+
def __repr__(self):
|
| 523 |
+
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
|
| 524 |
+
f'out_channels={self.out_channels}, '
|
| 525 |
+
f'kernel_size={self.kernel_size},'
|
| 526 |
+
f' stride={self.stride}, padding={self.padding}, '
|
| 527 |
+
f'bias={self.bias is not None})')
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
class ConvLayer(nn.Sequential):
|
| 531 |
+
"""Conv Layer used in StyleGAN2 Discriminator.
|
| 532 |
+
|
| 533 |
+
Args:
|
| 534 |
+
in_channels (int): Channel number of the input.
|
| 535 |
+
out_channels (int): Channel number of the output.
|
| 536 |
+
kernel_size (int): Kernel size.
|
| 537 |
+
downsample (bool): Whether downsample by a factor of 2.
|
| 538 |
+
Default: False.
|
| 539 |
+
bias (bool): Whether with bias. Default: True.
|
| 540 |
+
activate (bool): Whether use activateion. Default: True.
|
| 541 |
+
"""
|
| 542 |
+
|
| 543 |
+
def __init__(self,
|
| 544 |
+
in_channels,
|
| 545 |
+
out_channels,
|
| 546 |
+
kernel_size,
|
| 547 |
+
downsample=False,
|
| 548 |
+
bias=True,
|
| 549 |
+
activate=True,
|
| 550 |
+
interpolation_mode='bilinear'):
|
| 551 |
+
layers = []
|
| 552 |
+
self.interpolation_mode = interpolation_mode
|
| 553 |
+
# downsample
|
| 554 |
+
if downsample:
|
| 555 |
+
if self.interpolation_mode == 'nearest':
|
| 556 |
+
self.align_corners = None
|
| 557 |
+
else:
|
| 558 |
+
self.align_corners = False
|
| 559 |
+
|
| 560 |
+
layers.append(
|
| 561 |
+
torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners))
|
| 562 |
+
stride = 1
|
| 563 |
+
self.padding = kernel_size // 2
|
| 564 |
+
# conv
|
| 565 |
+
layers.append(
|
| 566 |
+
EqualConv2d(
|
| 567 |
+
in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
|
| 568 |
+
and not activate))
|
| 569 |
+
# activation
|
| 570 |
+
if activate:
|
| 571 |
+
if bias:
|
| 572 |
+
layers.append(FusedLeakyReLU(out_channels))
|
| 573 |
+
else:
|
| 574 |
+
layers.append(ScaledLeakyReLU(0.2))
|
| 575 |
+
|
| 576 |
+
super(ConvLayer, self).__init__(*layers)
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
class ResBlock(nn.Module):
|
| 580 |
+
"""Residual block used in StyleGAN2 Discriminator.
|
| 581 |
+
|
| 582 |
+
Args:
|
| 583 |
+
in_channels (int): Channel number of the input.
|
| 584 |
+
out_channels (int): Channel number of the output.
|
| 585 |
+
"""
|
| 586 |
+
|
| 587 |
+
def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
|
| 588 |
+
super(ResBlock, self).__init__()
|
| 589 |
+
|
| 590 |
+
self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
|
| 591 |
+
self.conv2 = ConvLayer(
|
| 592 |
+
in_channels,
|
| 593 |
+
out_channels,
|
| 594 |
+
3,
|
| 595 |
+
downsample=True,
|
| 596 |
+
interpolation_mode=interpolation_mode,
|
| 597 |
+
bias=True,
|
| 598 |
+
activate=True)
|
| 599 |
+
self.skip = ConvLayer(
|
| 600 |
+
in_channels,
|
| 601 |
+
out_channels,
|
| 602 |
+
1,
|
| 603 |
+
downsample=True,
|
| 604 |
+
interpolation_mode=interpolation_mode,
|
| 605 |
+
bias=False,
|
| 606 |
+
activate=False)
|
| 607 |
+
|
| 608 |
+
def forward(self, x):
|
| 609 |
+
out = self.conv1(x)
|
| 610 |
+
out = self.conv2(out)
|
| 611 |
+
skip = self.skip(x)
|
| 612 |
+
out = (out + skip) / math.sqrt(2)
|
| 613 |
+
return out
|
gfpgan/archs/stylegan2_clean_arch.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
from basicsr.archs.arch_util import default_init_weights
|
| 5 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class NormStyleCode(nn.Module):
|
| 11 |
+
|
| 12 |
+
def forward(self, x):
|
| 13 |
+
"""Normalize the style codes.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
x (Tensor): Style codes with shape (b, c).
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
Tensor: Normalized tensor.
|
| 20 |
+
"""
|
| 21 |
+
return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ModulatedConv2d(nn.Module):
|
| 25 |
+
"""Modulated Conv2d used in StyleGAN2.
|
| 26 |
+
|
| 27 |
+
There is no bias in ModulatedConv2d.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
in_channels (int): Channel number of the input.
|
| 31 |
+
out_channels (int): Channel number of the output.
|
| 32 |
+
kernel_size (int): Size of the convolving kernel.
|
| 33 |
+
num_style_feat (int): Channel number of style features.
|
| 34 |
+
demodulate (bool): Whether to demodulate in the conv layer. Default: True.
|
| 35 |
+
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
|
| 36 |
+
eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(self,
|
| 40 |
+
in_channels,
|
| 41 |
+
out_channels,
|
| 42 |
+
kernel_size,
|
| 43 |
+
num_style_feat,
|
| 44 |
+
demodulate=True,
|
| 45 |
+
sample_mode=None,
|
| 46 |
+
eps=1e-8):
|
| 47 |
+
super(ModulatedConv2d, self).__init__()
|
| 48 |
+
self.in_channels = in_channels
|
| 49 |
+
self.out_channels = out_channels
|
| 50 |
+
self.kernel_size = kernel_size
|
| 51 |
+
self.demodulate = demodulate
|
| 52 |
+
self.sample_mode = sample_mode
|
| 53 |
+
self.eps = eps
|
| 54 |
+
|
| 55 |
+
# modulation inside each modulated conv
|
| 56 |
+
self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
|
| 57 |
+
# initialization
|
| 58 |
+
default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')
|
| 59 |
+
|
| 60 |
+
self.weight = nn.Parameter(
|
| 61 |
+
torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
|
| 62 |
+
math.sqrt(in_channels * kernel_size**2))
|
| 63 |
+
self.padding = kernel_size // 2
|
| 64 |
+
|
| 65 |
+
def forward(self, x, style):
|
| 66 |
+
"""Forward function.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
x (Tensor): Tensor with shape (b, c, h, w).
|
| 70 |
+
style (Tensor): Tensor with shape (b, num_style_feat).
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
Tensor: Modulated tensor after convolution.
|
| 74 |
+
"""
|
| 75 |
+
b, c, h, w = x.shape # c = c_in
|
| 76 |
+
# weight modulation
|
| 77 |
+
style = self.modulation(style).view(b, 1, c, 1, 1)
|
| 78 |
+
# self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
|
| 79 |
+
weight = self.weight * style # (b, c_out, c_in, k, k)
|
| 80 |
+
|
| 81 |
+
if self.demodulate:
|
| 82 |
+
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
|
| 83 |
+
weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
|
| 84 |
+
|
| 85 |
+
weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
|
| 86 |
+
|
| 87 |
+
# upsample or downsample if necessary
|
| 88 |
+
if self.sample_mode == 'upsample':
|
| 89 |
+
x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
|
| 90 |
+
elif self.sample_mode == 'downsample':
|
| 91 |
+
x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
|
| 92 |
+
|
| 93 |
+
b, c, h, w = x.shape
|
| 94 |
+
x = x.view(1, b * c, h, w)
|
| 95 |
+
# weight: (b*c_out, c_in, k, k), groups=b
|
| 96 |
+
out = F.conv2d(x, weight, padding=self.padding, groups=b)
|
| 97 |
+
out = out.view(b, self.out_channels, *out.shape[2:4])
|
| 98 |
+
|
| 99 |
+
return out
|
| 100 |
+
|
| 101 |
+
def __repr__(self):
|
| 102 |
+
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
|
| 103 |
+
f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class StyleConv(nn.Module):
|
| 107 |
+
"""Style conv used in StyleGAN2.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
in_channels (int): Channel number of the input.
|
| 111 |
+
out_channels (int): Channel number of the output.
|
| 112 |
+
kernel_size (int): Size of the convolving kernel.
|
| 113 |
+
num_style_feat (int): Channel number of style features.
|
| 114 |
+
demodulate (bool): Whether demodulate in the conv layer. Default: True.
|
| 115 |
+
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
|
| 119 |
+
super(StyleConv, self).__init__()
|
| 120 |
+
self.modulated_conv = ModulatedConv2d(
|
| 121 |
+
in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
|
| 122 |
+
self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
|
| 123 |
+
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
|
| 124 |
+
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 125 |
+
|
| 126 |
+
def forward(self, x, style, noise=None):
|
| 127 |
+
# modulate
|
| 128 |
+
out = self.modulated_conv(x, style) * 2**0.5 # for conversion
|
| 129 |
+
# noise injection
|
| 130 |
+
if noise is None:
|
| 131 |
+
b, _, h, w = out.shape
|
| 132 |
+
noise = out.new_empty(b, 1, h, w).normal_()
|
| 133 |
+
out = out + self.weight * noise
|
| 134 |
+
# add bias
|
| 135 |
+
out = out + self.bias
|
| 136 |
+
# activation
|
| 137 |
+
out = self.activate(out)
|
| 138 |
+
return out
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class ToRGB(nn.Module):
|
| 142 |
+
"""To RGB (image space) from features.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
in_channels (int): Channel number of input.
|
| 146 |
+
num_style_feat (int): Channel number of style features.
|
| 147 |
+
upsample (bool): Whether to upsample. Default: True.
|
| 148 |
+
"""
|
| 149 |
+
|
| 150 |
+
def __init__(self, in_channels, num_style_feat, upsample=True):
|
| 151 |
+
super(ToRGB, self).__init__()
|
| 152 |
+
self.upsample = upsample
|
| 153 |
+
self.modulated_conv = ModulatedConv2d(
|
| 154 |
+
in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
|
| 155 |
+
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
| 156 |
+
|
| 157 |
+
def forward(self, x, style, skip=None):
|
| 158 |
+
"""Forward function.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
x (Tensor): Feature tensor with shape (b, c, h, w).
|
| 162 |
+
style (Tensor): Tensor with shape (b, num_style_feat).
|
| 163 |
+
skip (Tensor): Base/skip tensor. Default: None.
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
Tensor: RGB images.
|
| 167 |
+
"""
|
| 168 |
+
out = self.modulated_conv(x, style)
|
| 169 |
+
out = out + self.bias
|
| 170 |
+
if skip is not None:
|
| 171 |
+
if self.upsample:
|
| 172 |
+
skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
|
| 173 |
+
out = out + skip
|
| 174 |
+
return out
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class ConstantInput(nn.Module):
|
| 178 |
+
"""Constant input.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
num_channel (int): Channel number of constant input.
|
| 182 |
+
size (int): Spatial size of constant input.
|
| 183 |
+
"""
|
| 184 |
+
|
| 185 |
+
def __init__(self, num_channel, size):
|
| 186 |
+
super(ConstantInput, self).__init__()
|
| 187 |
+
self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
|
| 188 |
+
|
| 189 |
+
def forward(self, batch):
|
| 190 |
+
out = self.weight.repeat(batch, 1, 1, 1)
|
| 191 |
+
return out
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
@ARCH_REGISTRY.register()
|
| 195 |
+
class StyleGAN2GeneratorClean(nn.Module):
|
| 196 |
+
"""Clean version of StyleGAN2 Generator.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
out_size (int): The spatial size of outputs.
|
| 200 |
+
num_style_feat (int): Channel number of style features. Default: 512.
|
| 201 |
+
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
| 202 |
+
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
| 203 |
+
narrow (float): Narrow ratio for channels. Default: 1.0.
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
|
| 207 |
+
super(StyleGAN2GeneratorClean, self).__init__()
|
| 208 |
+
# Style MLP layers
|
| 209 |
+
self.num_style_feat = num_style_feat
|
| 210 |
+
style_mlp_layers = [NormStyleCode()]
|
| 211 |
+
for i in range(num_mlp):
|
| 212 |
+
style_mlp_layers.extend(
|
| 213 |
+
[nn.Linear(num_style_feat, num_style_feat, bias=True),
|
| 214 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True)])
|
| 215 |
+
self.style_mlp = nn.Sequential(*style_mlp_layers)
|
| 216 |
+
# initialization
|
| 217 |
+
default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
|
| 218 |
+
|
| 219 |
+
# channel list
|
| 220 |
+
channels = {
|
| 221 |
+
'4': int(512 * narrow),
|
| 222 |
+
'8': int(512 * narrow),
|
| 223 |
+
'16': int(512 * narrow),
|
| 224 |
+
'32': int(512 * narrow),
|
| 225 |
+
'64': int(256 * channel_multiplier * narrow),
|
| 226 |
+
'128': int(128 * channel_multiplier * narrow),
|
| 227 |
+
'256': int(64 * channel_multiplier * narrow),
|
| 228 |
+
'512': int(32 * channel_multiplier * narrow),
|
| 229 |
+
'1024': int(16 * channel_multiplier * narrow)
|
| 230 |
+
}
|
| 231 |
+
self.channels = channels
|
| 232 |
+
|
| 233 |
+
self.constant_input = ConstantInput(channels['4'], size=4)
|
| 234 |
+
self.style_conv1 = StyleConv(
|
| 235 |
+
channels['4'],
|
| 236 |
+
channels['4'],
|
| 237 |
+
kernel_size=3,
|
| 238 |
+
num_style_feat=num_style_feat,
|
| 239 |
+
demodulate=True,
|
| 240 |
+
sample_mode=None)
|
| 241 |
+
self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)
|
| 242 |
+
|
| 243 |
+
self.log_size = int(math.log(out_size, 2))
|
| 244 |
+
self.num_layers = (self.log_size - 2) * 2 + 1
|
| 245 |
+
self.num_latent = self.log_size * 2 - 2
|
| 246 |
+
|
| 247 |
+
self.style_convs = nn.ModuleList()
|
| 248 |
+
self.to_rgbs = nn.ModuleList()
|
| 249 |
+
self.noises = nn.Module()
|
| 250 |
+
|
| 251 |
+
in_channels = channels['4']
|
| 252 |
+
# noise
|
| 253 |
+
for layer_idx in range(self.num_layers):
|
| 254 |
+
resolution = 2**((layer_idx + 5) // 2)
|
| 255 |
+
shape = [1, 1, resolution, resolution]
|
| 256 |
+
self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
|
| 257 |
+
# style convs and to_rgbs
|
| 258 |
+
for i in range(3, self.log_size + 1):
|
| 259 |
+
out_channels = channels[f'{2**i}']
|
| 260 |
+
self.style_convs.append(
|
| 261 |
+
StyleConv(
|
| 262 |
+
in_channels,
|
| 263 |
+
out_channels,
|
| 264 |
+
kernel_size=3,
|
| 265 |
+
num_style_feat=num_style_feat,
|
| 266 |
+
demodulate=True,
|
| 267 |
+
sample_mode='upsample'))
|
| 268 |
+
self.style_convs.append(
|
| 269 |
+
StyleConv(
|
| 270 |
+
out_channels,
|
| 271 |
+
out_channels,
|
| 272 |
+
kernel_size=3,
|
| 273 |
+
num_style_feat=num_style_feat,
|
| 274 |
+
demodulate=True,
|
| 275 |
+
sample_mode=None))
|
| 276 |
+
self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
|
| 277 |
+
in_channels = out_channels
|
| 278 |
+
|
| 279 |
+
def make_noise(self):
|
| 280 |
+
"""Make noise for noise injection."""
|
| 281 |
+
device = self.constant_input.weight.device
|
| 282 |
+
noises = [torch.randn(1, 1, 4, 4, device=device)]
|
| 283 |
+
|
| 284 |
+
for i in range(3, self.log_size + 1):
|
| 285 |
+
for _ in range(2):
|
| 286 |
+
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
|
| 287 |
+
|
| 288 |
+
return noises
|
| 289 |
+
|
| 290 |
+
def get_latent(self, x):
|
| 291 |
+
return self.style_mlp(x)
|
| 292 |
+
|
| 293 |
+
def mean_latent(self, num_latent):
|
| 294 |
+
latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
|
| 295 |
+
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
|
| 296 |
+
return latent
|
| 297 |
+
|
| 298 |
+
def forward(self,
|
| 299 |
+
styles,
|
| 300 |
+
input_is_latent=False,
|
| 301 |
+
noise=None,
|
| 302 |
+
randomize_noise=True,
|
| 303 |
+
truncation=1,
|
| 304 |
+
truncation_latent=None,
|
| 305 |
+
inject_index=None,
|
| 306 |
+
return_latents=False):
|
| 307 |
+
"""Forward function for StyleGAN2GeneratorClean.
|
| 308 |
+
|
| 309 |
+
Args:
|
| 310 |
+
styles (list[Tensor]): Sample codes of styles.
|
| 311 |
+
input_is_latent (bool): Whether input is latent style. Default: False.
|
| 312 |
+
noise (Tensor | None): Input noise or None. Default: None.
|
| 313 |
+
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
| 314 |
+
truncation (float): The truncation ratio. Default: 1.
|
| 315 |
+
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
| 316 |
+
inject_index (int | None): The injection index for mixing noise. Default: None.
|
| 317 |
+
return_latents (bool): Whether to return style latents. Default: False.
|
| 318 |
+
"""
|
| 319 |
+
# style codes -> latents with Style MLP layer
|
| 320 |
+
if not input_is_latent:
|
| 321 |
+
styles = [self.style_mlp(s) for s in styles]
|
| 322 |
+
# noises
|
| 323 |
+
if noise is None:
|
| 324 |
+
if randomize_noise:
|
| 325 |
+
noise = [None] * self.num_layers # for each style conv layer
|
| 326 |
+
else: # use the stored noise
|
| 327 |
+
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
|
| 328 |
+
# style truncation
|
| 329 |
+
if truncation < 1:
|
| 330 |
+
style_truncation = []
|
| 331 |
+
for style in styles:
|
| 332 |
+
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
| 333 |
+
styles = style_truncation
|
| 334 |
+
# get style latents with injection
|
| 335 |
+
if len(styles) == 1:
|
| 336 |
+
inject_index = self.num_latent
|
| 337 |
+
|
| 338 |
+
if styles[0].ndim < 3:
|
| 339 |
+
# repeat latent code for all the layers
|
| 340 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
| 341 |
+
else: # used for encoder with different latent code for each layer
|
| 342 |
+
latent = styles[0]
|
| 343 |
+
elif len(styles) == 2: # mixing noises
|
| 344 |
+
if inject_index is None:
|
| 345 |
+
inject_index = random.randint(1, self.num_latent - 1)
|
| 346 |
+
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
| 347 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
| 348 |
+
latent = torch.cat([latent1, latent2], 1)
|
| 349 |
+
|
| 350 |
+
# main generation
|
| 351 |
+
out = self.constant_input(latent.shape[0])
|
| 352 |
+
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
| 353 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
| 354 |
+
|
| 355 |
+
i = 1
|
| 356 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
|
| 357 |
+
noise[2::2], self.to_rgbs):
|
| 358 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
| 359 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
| 360 |
+
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
| 361 |
+
i += 2
|
| 362 |
+
|
| 363 |
+
image = skip
|
| 364 |
+
|
| 365 |
+
if return_latents:
|
| 366 |
+
return image, latent
|
| 367 |
+
else:
|
| 368 |
+
return image, None
|
gfpgan/data/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
from basicsr.utils import scandir
|
| 3 |
+
from os import path as osp
|
| 4 |
+
|
| 5 |
+
# automatically scan and import dataset modules for registry
|
| 6 |
+
# scan all the files that end with '_dataset.py' under the data folder
|
| 7 |
+
data_folder = osp.dirname(osp.abspath(__file__))
|
| 8 |
+
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
| 9 |
+
# import all the dataset modules
|
| 10 |
+
_dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames]
|
gfpgan/data/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (716 Bytes). View file
|
|
|
gfpgan/data/__pycache__/ffhq_degradation_dataset.cpython-38.pyc
ADDED
|
Binary file (7.14 kB). View file
|
|
|
gfpgan/data/ffhq_degradation_dataset.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os.path as osp
|
| 5 |
+
import torch
|
| 6 |
+
import torch.utils.data as data
|
| 7 |
+
from basicsr.data import degradations as degradations
|
| 8 |
+
from basicsr.data.data_util import paths_from_folder
|
| 9 |
+
from basicsr.data.transforms import augment
|
| 10 |
+
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
| 11 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 12 |
+
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
|
| 13 |
+
normalize)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@DATASET_REGISTRY.register()
|
| 17 |
+
class FFHQDegradationDataset(data.Dataset):
|
| 18 |
+
"""FFHQ dataset for GFPGAN.
|
| 19 |
+
|
| 20 |
+
It reads high resolution images, and then generate low-quality (LQ) images on-the-fly.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
opt (dict): Config for train datasets. It contains the following keys:
|
| 24 |
+
dataroot_gt (str): Data root path for gt.
|
| 25 |
+
io_backend (dict): IO backend type and other kwarg.
|
| 26 |
+
mean (list | tuple): Image mean.
|
| 27 |
+
std (list | tuple): Image std.
|
| 28 |
+
use_hflip (bool): Whether to horizontally flip.
|
| 29 |
+
Please see more options in the codes.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, opt):
|
| 33 |
+
super(FFHQDegradationDataset, self).__init__()
|
| 34 |
+
self.opt = opt
|
| 35 |
+
# file client (io backend)
|
| 36 |
+
self.file_client = None
|
| 37 |
+
self.io_backend_opt = opt['io_backend']
|
| 38 |
+
|
| 39 |
+
self.gt_folder = opt['dataroot_gt']
|
| 40 |
+
self.mean = opt['mean']
|
| 41 |
+
self.std = opt['std']
|
| 42 |
+
self.out_size = opt['out_size']
|
| 43 |
+
|
| 44 |
+
self.crop_components = opt.get('crop_components', False) # facial components
|
| 45 |
+
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) # whether enlarge eye regions
|
| 46 |
+
|
| 47 |
+
if self.crop_components:
|
| 48 |
+
# load component list from a pre-process pth files
|
| 49 |
+
self.components_list = torch.load(opt.get('component_path'))
|
| 50 |
+
|
| 51 |
+
# file client (lmdb io backend)
|
| 52 |
+
if self.io_backend_opt['type'] == 'lmdb':
|
| 53 |
+
self.io_backend_opt['db_paths'] = self.gt_folder
|
| 54 |
+
if not self.gt_folder.endswith('.lmdb'):
|
| 55 |
+
raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
|
| 56 |
+
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
|
| 57 |
+
self.paths = [line.split('.')[0] for line in fin]
|
| 58 |
+
else:
|
| 59 |
+
# disk backend: scan file list from a folder
|
| 60 |
+
self.paths = paths_from_folder(self.gt_folder)
|
| 61 |
+
|
| 62 |
+
# degradation configurations
|
| 63 |
+
self.blur_kernel_size = opt['blur_kernel_size']
|
| 64 |
+
self.kernel_list = opt['kernel_list']
|
| 65 |
+
self.kernel_prob = opt['kernel_prob']
|
| 66 |
+
self.blur_sigma = opt['blur_sigma']
|
| 67 |
+
self.downsample_range = opt['downsample_range']
|
| 68 |
+
self.noise_range = opt['noise_range']
|
| 69 |
+
self.jpeg_range = opt['jpeg_range']
|
| 70 |
+
|
| 71 |
+
# color jitter
|
| 72 |
+
self.color_jitter_prob = opt.get('color_jitter_prob')
|
| 73 |
+
self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
|
| 74 |
+
self.color_jitter_shift = opt.get('color_jitter_shift', 20)
|
| 75 |
+
# to gray
|
| 76 |
+
self.gray_prob = opt.get('gray_prob')
|
| 77 |
+
|
| 78 |
+
logger = get_root_logger()
|
| 79 |
+
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
|
| 80 |
+
logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
|
| 81 |
+
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
|
| 82 |
+
logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
|
| 83 |
+
|
| 84 |
+
if self.color_jitter_prob is not None:
|
| 85 |
+
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
|
| 86 |
+
if self.gray_prob is not None:
|
| 87 |
+
logger.info(f'Use random gray. Prob: {self.gray_prob}')
|
| 88 |
+
self.color_jitter_shift /= 255.
|
| 89 |
+
|
| 90 |
+
@staticmethod
|
| 91 |
+
def color_jitter(img, shift):
|
| 92 |
+
"""jitter color: randomly jitter the RGB values, in numpy formats"""
|
| 93 |
+
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
|
| 94 |
+
img = img + jitter_val
|
| 95 |
+
img = np.clip(img, 0, 1)
|
| 96 |
+
return img
|
| 97 |
+
|
| 98 |
+
@staticmethod
|
| 99 |
+
def color_jitter_pt(img, brightness, contrast, saturation, hue):
|
| 100 |
+
"""jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
|
| 101 |
+
fn_idx = torch.randperm(4)
|
| 102 |
+
for fn_id in fn_idx:
|
| 103 |
+
if fn_id == 0 and brightness is not None:
|
| 104 |
+
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
|
| 105 |
+
img = adjust_brightness(img, brightness_factor)
|
| 106 |
+
|
| 107 |
+
if fn_id == 1 and contrast is not None:
|
| 108 |
+
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
|
| 109 |
+
img = adjust_contrast(img, contrast_factor)
|
| 110 |
+
|
| 111 |
+
if fn_id == 2 and saturation is not None:
|
| 112 |
+
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
|
| 113 |
+
img = adjust_saturation(img, saturation_factor)
|
| 114 |
+
|
| 115 |
+
if fn_id == 3 and hue is not None:
|
| 116 |
+
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
|
| 117 |
+
img = adjust_hue(img, hue_factor)
|
| 118 |
+
return img
|
| 119 |
+
|
| 120 |
+
def get_component_coordinates(self, index, status):
|
| 121 |
+
"""Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
|
| 122 |
+
components_bbox = self.components_list[f'{index:08d}']
|
| 123 |
+
if status[0]: # hflip
|
| 124 |
+
# exchange right and left eye
|
| 125 |
+
tmp = components_bbox['left_eye']
|
| 126 |
+
components_bbox['left_eye'] = components_bbox['right_eye']
|
| 127 |
+
components_bbox['right_eye'] = tmp
|
| 128 |
+
# modify the width coordinate
|
| 129 |
+
components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
|
| 130 |
+
components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
|
| 131 |
+
components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]
|
| 132 |
+
|
| 133 |
+
# get coordinates
|
| 134 |
+
locations = []
|
| 135 |
+
for part in ['left_eye', 'right_eye', 'mouth']:
|
| 136 |
+
mean = components_bbox[part][0:2]
|
| 137 |
+
half_len = components_bbox[part][2]
|
| 138 |
+
if 'eye' in part:
|
| 139 |
+
half_len *= self.eye_enlarge_ratio
|
| 140 |
+
loc = np.hstack((mean - half_len + 1, mean + half_len))
|
| 141 |
+
loc = torch.from_numpy(loc).float()
|
| 142 |
+
locations.append(loc)
|
| 143 |
+
return locations
|
| 144 |
+
|
| 145 |
+
def __getitem__(self, index):
|
| 146 |
+
if self.file_client is None:
|
| 147 |
+
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
| 148 |
+
|
| 149 |
+
# load gt image
|
| 150 |
+
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
|
| 151 |
+
gt_path = self.paths[index]
|
| 152 |
+
img_bytes = self.file_client.get(gt_path)
|
| 153 |
+
img_gt = imfrombytes(img_bytes, float32=True)
|
| 154 |
+
|
| 155 |
+
# random horizontal flip
|
| 156 |
+
img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
|
| 157 |
+
h, w, _ = img_gt.shape
|
| 158 |
+
|
| 159 |
+
# get facial component coordinates
|
| 160 |
+
if self.crop_components:
|
| 161 |
+
locations = self.get_component_coordinates(index, status)
|
| 162 |
+
loc_left_eye, loc_right_eye, loc_mouth = locations
|
| 163 |
+
|
| 164 |
+
# ------------------------ generate lq image ------------------------ #
|
| 165 |
+
# blur
|
| 166 |
+
kernel = degradations.random_mixed_kernels(
|
| 167 |
+
self.kernel_list,
|
| 168 |
+
self.kernel_prob,
|
| 169 |
+
self.blur_kernel_size,
|
| 170 |
+
self.blur_sigma,
|
| 171 |
+
self.blur_sigma, [-math.pi, math.pi],
|
| 172 |
+
noise_range=None)
|
| 173 |
+
img_lq = cv2.filter2D(img_gt, -1, kernel)
|
| 174 |
+
# downsample
|
| 175 |
+
scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
|
| 176 |
+
img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
|
| 177 |
+
# noise
|
| 178 |
+
if self.noise_range is not None:
|
| 179 |
+
img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
|
| 180 |
+
# jpeg compression
|
| 181 |
+
if self.jpeg_range is not None:
|
| 182 |
+
img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)
|
| 183 |
+
|
| 184 |
+
# resize to original size
|
| 185 |
+
img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
|
| 186 |
+
|
| 187 |
+
# random color jitter (only for lq)
|
| 188 |
+
if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
|
| 189 |
+
img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
|
| 190 |
+
# random to gray (only for lq)
|
| 191 |
+
if self.gray_prob and np.random.uniform() < self.gray_prob:
|
| 192 |
+
img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
|
| 193 |
+
img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
|
| 194 |
+
if self.opt.get('gt_gray'): # whether convert GT to gray images
|
| 195 |
+
img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
|
| 196 |
+
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # repeat the color channels
|
| 197 |
+
|
| 198 |
+
# BGR to RGB, HWC to CHW, numpy to tensor
|
| 199 |
+
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
|
| 200 |
+
|
| 201 |
+
# random color jitter (pytorch version) (only for lq)
|
| 202 |
+
if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
|
| 203 |
+
brightness = self.opt.get('brightness', (0.5, 1.5))
|
| 204 |
+
contrast = self.opt.get('contrast', (0.5, 1.5))
|
| 205 |
+
saturation = self.opt.get('saturation', (0, 1.5))
|
| 206 |
+
hue = self.opt.get('hue', (-0.1, 0.1))
|
| 207 |
+
img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
|
| 208 |
+
|
| 209 |
+
# round and clip
|
| 210 |
+
img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
|
| 211 |
+
|
| 212 |
+
# normalize
|
| 213 |
+
normalize(img_gt, self.mean, self.std, inplace=True)
|
| 214 |
+
normalize(img_lq, self.mean, self.std, inplace=True)
|
| 215 |
+
|
| 216 |
+
if self.crop_components:
|
| 217 |
+
return_dict = {
|
| 218 |
+
'lq': img_lq,
|
| 219 |
+
'gt': img_gt,
|
| 220 |
+
'gt_path': gt_path,
|
| 221 |
+
'loc_left_eye': loc_left_eye,
|
| 222 |
+
'loc_right_eye': loc_right_eye,
|
| 223 |
+
'loc_mouth': loc_mouth
|
| 224 |
+
}
|
| 225 |
+
return return_dict
|
| 226 |
+
else:
|
| 227 |
+
return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}
|
| 228 |
+
|
| 229 |
+
def __len__(self):
|
| 230 |
+
return len(self.paths)
|
gfpgan/models/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
from basicsr.utils import scandir
|
| 3 |
+
from os import path as osp
|
| 4 |
+
|
| 5 |
+
# automatically scan and import model modules for registry
|
| 6 |
+
# scan all the files that end with '_model.py' under the model folder
|
| 7 |
+
model_folder = osp.dirname(osp.abspath(__file__))
|
| 8 |
+
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
|
| 9 |
+
# import all the model modules
|
| 10 |
+
_model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames]
|
gfpgan/models/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (715 Bytes). View file
|
|
|
gfpgan/models/__pycache__/gfpgan_model.cpython-38.pyc
ADDED
|
Binary file (14.2 kB). View file
|
|
|
gfpgan/models/gfpgan_model.py
ADDED
|
@@ -0,0 +1,579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import torch
|
| 4 |
+
from basicsr.archs import build_network
|
| 5 |
+
from basicsr.losses import build_loss
|
| 6 |
+
from basicsr.losses.gan_loss import r1_penalty
|
| 7 |
+
from basicsr.metrics import calculate_metric
|
| 8 |
+
from basicsr.models.base_model import BaseModel
|
| 9 |
+
from basicsr.utils import get_root_logger, imwrite, tensor2img
|
| 10 |
+
from basicsr.utils.registry import MODEL_REGISTRY
|
| 11 |
+
from collections import OrderedDict
|
| 12 |
+
from torch.nn import functional as F
|
| 13 |
+
from torchvision.ops import roi_align
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@MODEL_REGISTRY.register()
|
| 18 |
+
class GFPGANModel(BaseModel):
|
| 19 |
+
"""The GFPGAN model for Towards real-world blind face restoratin with generative facial prior"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, opt):
|
| 22 |
+
super(GFPGANModel, self).__init__(opt)
|
| 23 |
+
self.idx = 0 # it is used for saving data for check
|
| 24 |
+
|
| 25 |
+
# define network
|
| 26 |
+
self.net_g = build_network(opt['network_g'])
|
| 27 |
+
self.net_g = self.model_to_device(self.net_g)
|
| 28 |
+
self.print_network(self.net_g)
|
| 29 |
+
|
| 30 |
+
# load pretrained model
|
| 31 |
+
load_path = self.opt['path'].get('pretrain_network_g', None)
|
| 32 |
+
if load_path is not None:
|
| 33 |
+
param_key = self.opt['path'].get('param_key_g', 'params')
|
| 34 |
+
self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
|
| 35 |
+
|
| 36 |
+
self.log_size = int(math.log(self.opt['network_g']['out_size'], 2))
|
| 37 |
+
|
| 38 |
+
if self.is_train:
|
| 39 |
+
self.init_training_settings()
|
| 40 |
+
|
| 41 |
+
def init_training_settings(self):
|
| 42 |
+
train_opt = self.opt['train']
|
| 43 |
+
|
| 44 |
+
# ----------- define net_d ----------- #
|
| 45 |
+
self.net_d = build_network(self.opt['network_d'])
|
| 46 |
+
self.net_d = self.model_to_device(self.net_d)
|
| 47 |
+
self.print_network(self.net_d)
|
| 48 |
+
# load pretrained model
|
| 49 |
+
load_path = self.opt['path'].get('pretrain_network_d', None)
|
| 50 |
+
if load_path is not None:
|
| 51 |
+
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
|
| 52 |
+
|
| 53 |
+
# ----------- define net_g with Exponential Moving Average (EMA) ----------- #
|
| 54 |
+
# net_g_ema only used for testing on one GPU and saving. There is no need to wrap with DistributedDataParallel
|
| 55 |
+
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
|
| 56 |
+
# load pretrained model
|
| 57 |
+
load_path = self.opt['path'].get('pretrain_network_g', None)
|
| 58 |
+
if load_path is not None:
|
| 59 |
+
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
|
| 60 |
+
else:
|
| 61 |
+
self.model_ema(0) # copy net_g weight
|
| 62 |
+
|
| 63 |
+
self.net_g.train()
|
| 64 |
+
self.net_d.train()
|
| 65 |
+
self.net_g_ema.eval()
|
| 66 |
+
|
| 67 |
+
# ----------- facial component networks ----------- #
|
| 68 |
+
if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
|
| 69 |
+
self.use_facial_disc = True
|
| 70 |
+
else:
|
| 71 |
+
self.use_facial_disc = False
|
| 72 |
+
|
| 73 |
+
if self.use_facial_disc:
|
| 74 |
+
# left eye
|
| 75 |
+
self.net_d_left_eye = build_network(self.opt['network_d_left_eye'])
|
| 76 |
+
self.net_d_left_eye = self.model_to_device(self.net_d_left_eye)
|
| 77 |
+
self.print_network(self.net_d_left_eye)
|
| 78 |
+
load_path = self.opt['path'].get('pretrain_network_d_left_eye')
|
| 79 |
+
if load_path is not None:
|
| 80 |
+
self.load_network(self.net_d_left_eye, load_path, True, 'params')
|
| 81 |
+
# right eye
|
| 82 |
+
self.net_d_right_eye = build_network(self.opt['network_d_right_eye'])
|
| 83 |
+
self.net_d_right_eye = self.model_to_device(self.net_d_right_eye)
|
| 84 |
+
self.print_network(self.net_d_right_eye)
|
| 85 |
+
load_path = self.opt['path'].get('pretrain_network_d_right_eye')
|
| 86 |
+
if load_path is not None:
|
| 87 |
+
self.load_network(self.net_d_right_eye, load_path, True, 'params')
|
| 88 |
+
# mouth
|
| 89 |
+
self.net_d_mouth = build_network(self.opt['network_d_mouth'])
|
| 90 |
+
self.net_d_mouth = self.model_to_device(self.net_d_mouth)
|
| 91 |
+
self.print_network(self.net_d_mouth)
|
| 92 |
+
load_path = self.opt['path'].get('pretrain_network_d_mouth')
|
| 93 |
+
if load_path is not None:
|
| 94 |
+
self.load_network(self.net_d_mouth, load_path, True, 'params')
|
| 95 |
+
|
| 96 |
+
self.net_d_left_eye.train()
|
| 97 |
+
self.net_d_right_eye.train()
|
| 98 |
+
self.net_d_mouth.train()
|
| 99 |
+
|
| 100 |
+
# ----------- define facial component gan loss ----------- #
|
| 101 |
+
self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
|
| 102 |
+
|
| 103 |
+
# ----------- define losses ----------- #
|
| 104 |
+
# pixel loss
|
| 105 |
+
if train_opt.get('pixel_opt'):
|
| 106 |
+
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
|
| 107 |
+
else:
|
| 108 |
+
self.cri_pix = None
|
| 109 |
+
|
| 110 |
+
# perceptual loss
|
| 111 |
+
if train_opt.get('perceptual_opt'):
|
| 112 |
+
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
|
| 113 |
+
else:
|
| 114 |
+
self.cri_perceptual = None
|
| 115 |
+
|
| 116 |
+
# L1 loss is used in pyramid loss, component style loss and identity loss
|
| 117 |
+
self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)
|
| 118 |
+
|
| 119 |
+
# gan loss (wgan)
|
| 120 |
+
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
|
| 121 |
+
|
| 122 |
+
# ----------- define identity loss ----------- #
|
| 123 |
+
if 'network_identity' in self.opt:
|
| 124 |
+
self.use_identity = True
|
| 125 |
+
else:
|
| 126 |
+
self.use_identity = False
|
| 127 |
+
|
| 128 |
+
if self.use_identity:
|
| 129 |
+
# define identity network
|
| 130 |
+
self.network_identity = build_network(self.opt['network_identity'])
|
| 131 |
+
self.network_identity = self.model_to_device(self.network_identity)
|
| 132 |
+
self.print_network(self.network_identity)
|
| 133 |
+
load_path = self.opt['path'].get('pretrain_network_identity')
|
| 134 |
+
if load_path is not None:
|
| 135 |
+
self.load_network(self.network_identity, load_path, True, None)
|
| 136 |
+
self.network_identity.eval()
|
| 137 |
+
for param in self.network_identity.parameters():
|
| 138 |
+
param.requires_grad = False
|
| 139 |
+
|
| 140 |
+
# regularization weights
|
| 141 |
+
self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator
|
| 142 |
+
self.net_d_iters = train_opt.get('net_d_iters', 1)
|
| 143 |
+
self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
|
| 144 |
+
self.net_d_reg_every = train_opt['net_d_reg_every']
|
| 145 |
+
|
| 146 |
+
# set up optimizers and schedulers
|
| 147 |
+
self.setup_optimizers()
|
| 148 |
+
self.setup_schedulers()
|
| 149 |
+
|
| 150 |
+
def setup_optimizers(self):
|
| 151 |
+
train_opt = self.opt['train']
|
| 152 |
+
|
| 153 |
+
# ----------- optimizer g ----------- #
|
| 154 |
+
net_g_reg_ratio = 1
|
| 155 |
+
normal_params = []
|
| 156 |
+
for _, param in self.net_g.named_parameters():
|
| 157 |
+
normal_params.append(param)
|
| 158 |
+
optim_params_g = [{ # add normal params first
|
| 159 |
+
'params': normal_params,
|
| 160 |
+
'lr': train_opt['optim_g']['lr']
|
| 161 |
+
}]
|
| 162 |
+
optim_type = train_opt['optim_g'].pop('type')
|
| 163 |
+
lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
|
| 164 |
+
betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
|
| 165 |
+
self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
|
| 166 |
+
self.optimizers.append(self.optimizer_g)
|
| 167 |
+
|
| 168 |
+
# ----------- optimizer d ----------- #
|
| 169 |
+
net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
|
| 170 |
+
normal_params = []
|
| 171 |
+
for _, param in self.net_d.named_parameters():
|
| 172 |
+
normal_params.append(param)
|
| 173 |
+
optim_params_d = [{ # add normal params first
|
| 174 |
+
'params': normal_params,
|
| 175 |
+
'lr': train_opt['optim_d']['lr']
|
| 176 |
+
}]
|
| 177 |
+
optim_type = train_opt['optim_d'].pop('type')
|
| 178 |
+
lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
|
| 179 |
+
betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
|
| 180 |
+
self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
|
| 181 |
+
self.optimizers.append(self.optimizer_d)
|
| 182 |
+
|
| 183 |
+
# ----------- optimizers for facial component networks ----------- #
|
| 184 |
+
if self.use_facial_disc:
|
| 185 |
+
# setup optimizers for facial component discriminators
|
| 186 |
+
optim_type = train_opt['optim_component'].pop('type')
|
| 187 |
+
lr = train_opt['optim_component']['lr']
|
| 188 |
+
# left eye
|
| 189 |
+
self.optimizer_d_left_eye = self.get_optimizer(
|
| 190 |
+
optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99))
|
| 191 |
+
self.optimizers.append(self.optimizer_d_left_eye)
|
| 192 |
+
# right eye
|
| 193 |
+
self.optimizer_d_right_eye = self.get_optimizer(
|
| 194 |
+
optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99))
|
| 195 |
+
self.optimizers.append(self.optimizer_d_right_eye)
|
| 196 |
+
# mouth
|
| 197 |
+
self.optimizer_d_mouth = self.get_optimizer(
|
| 198 |
+
optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99))
|
| 199 |
+
self.optimizers.append(self.optimizer_d_mouth)
|
| 200 |
+
|
| 201 |
+
def feed_data(self, data):
|
| 202 |
+
self.lq = data['lq'].to(self.device)
|
| 203 |
+
if 'gt' in data:
|
| 204 |
+
self.gt = data['gt'].to(self.device)
|
| 205 |
+
|
| 206 |
+
if 'loc_left_eye' in data:
|
| 207 |
+
# get facial component locations, shape (batch, 4)
|
| 208 |
+
self.loc_left_eyes = data['loc_left_eye']
|
| 209 |
+
self.loc_right_eyes = data['loc_right_eye']
|
| 210 |
+
self.loc_mouths = data['loc_mouth']
|
| 211 |
+
|
| 212 |
+
# uncomment to check data
|
| 213 |
+
# import torchvision
|
| 214 |
+
# if self.opt['rank'] == 0:
|
| 215 |
+
# import os
|
| 216 |
+
# os.makedirs('tmp/gt', exist_ok=True)
|
| 217 |
+
# os.makedirs('tmp/lq', exist_ok=True)
|
| 218 |
+
# print(self.idx)
|
| 219 |
+
# torchvision.utils.save_image(
|
| 220 |
+
# self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
|
| 221 |
+
# torchvision.utils.save_image(
|
| 222 |
+
# self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
|
| 223 |
+
# self.idx = self.idx + 1
|
| 224 |
+
|
| 225 |
+
def construct_img_pyramid(self):
|
| 226 |
+
"""Construct image pyramid for intermediate restoration loss"""
|
| 227 |
+
pyramid_gt = [self.gt]
|
| 228 |
+
down_img = self.gt
|
| 229 |
+
for _ in range(0, self.log_size - 3):
|
| 230 |
+
down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False)
|
| 231 |
+
pyramid_gt.insert(0, down_img)
|
| 232 |
+
return pyramid_gt
|
| 233 |
+
|
| 234 |
+
def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
|
| 235 |
+
face_ratio = int(self.opt['network_g']['out_size'] / 512)
|
| 236 |
+
eye_out_size *= face_ratio
|
| 237 |
+
mouth_out_size *= face_ratio
|
| 238 |
+
|
| 239 |
+
rois_eyes = []
|
| 240 |
+
rois_mouths = []
|
| 241 |
+
for b in range(self.loc_left_eyes.size(0)): # loop for batch size
|
| 242 |
+
# left eye and right eye
|
| 243 |
+
img_inds = self.loc_left_eyes.new_full((2, 1), b)
|
| 244 |
+
bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0) # shape: (2, 4)
|
| 245 |
+
rois = torch.cat([img_inds, bbox], dim=-1) # shape: (2, 5)
|
| 246 |
+
rois_eyes.append(rois)
|
| 247 |
+
# mouse
|
| 248 |
+
img_inds = self.loc_left_eyes.new_full((1, 1), b)
|
| 249 |
+
rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1) # shape: (1, 5)
|
| 250 |
+
rois_mouths.append(rois)
|
| 251 |
+
|
| 252 |
+
rois_eyes = torch.cat(rois_eyes, 0).to(self.device)
|
| 253 |
+
rois_mouths = torch.cat(rois_mouths, 0).to(self.device)
|
| 254 |
+
|
| 255 |
+
# real images
|
| 256 |
+
all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
|
| 257 |
+
self.left_eyes_gt = all_eyes[0::2, :, :, :]
|
| 258 |
+
self.right_eyes_gt = all_eyes[1::2, :, :, :]
|
| 259 |
+
self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
|
| 260 |
+
# output
|
| 261 |
+
all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
|
| 262 |
+
self.left_eyes = all_eyes[0::2, :, :, :]
|
| 263 |
+
self.right_eyes = all_eyes[1::2, :, :, :]
|
| 264 |
+
self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
|
| 265 |
+
|
| 266 |
+
def _gram_mat(self, x):
|
| 267 |
+
"""Calculate Gram matrix.
|
| 268 |
+
|
| 269 |
+
Args:
|
| 270 |
+
x (torch.Tensor): Tensor with shape of (n, c, h, w).
|
| 271 |
+
|
| 272 |
+
Returns:
|
| 273 |
+
torch.Tensor: Gram matrix.
|
| 274 |
+
"""
|
| 275 |
+
n, c, h, w = x.size()
|
| 276 |
+
features = x.view(n, c, w * h)
|
| 277 |
+
features_t = features.transpose(1, 2)
|
| 278 |
+
gram = features.bmm(features_t) / (c * h * w)
|
| 279 |
+
return gram
|
| 280 |
+
|
| 281 |
+
def gray_resize_for_identity(self, out, size=128):
|
| 282 |
+
out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
|
| 283 |
+
out_gray = out_gray.unsqueeze(1)
|
| 284 |
+
out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
|
| 285 |
+
return out_gray
|
| 286 |
+
|
| 287 |
+
def optimize_parameters(self, current_iter):
|
| 288 |
+
# optimize net_g
|
| 289 |
+
for p in self.net_d.parameters():
|
| 290 |
+
p.requires_grad = False
|
| 291 |
+
self.optimizer_g.zero_grad()
|
| 292 |
+
|
| 293 |
+
# do not update facial component net_d
|
| 294 |
+
if self.use_facial_disc:
|
| 295 |
+
for p in self.net_d_left_eye.parameters():
|
| 296 |
+
p.requires_grad = False
|
| 297 |
+
for p in self.net_d_right_eye.parameters():
|
| 298 |
+
p.requires_grad = False
|
| 299 |
+
for p in self.net_d_mouth.parameters():
|
| 300 |
+
p.requires_grad = False
|
| 301 |
+
|
| 302 |
+
# image pyramid loss weight
|
| 303 |
+
pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 0)
|
| 304 |
+
if pyramid_loss_weight > 0 and current_iter > self.opt['train'].get('remove_pyramid_loss', float('inf')):
|
| 305 |
+
pyramid_loss_weight = 1e-12 # very small weight to avoid unused param error
|
| 306 |
+
if pyramid_loss_weight > 0:
|
| 307 |
+
self.output, out_rgbs = self.net_g(self.lq, return_rgb=True)
|
| 308 |
+
pyramid_gt = self.construct_img_pyramid()
|
| 309 |
+
else:
|
| 310 |
+
self.output, out_rgbs = self.net_g(self.lq, return_rgb=False)
|
| 311 |
+
|
| 312 |
+
# get roi-align regions
|
| 313 |
+
if self.use_facial_disc:
|
| 314 |
+
self.get_roi_regions(eye_out_size=80, mouth_out_size=120)
|
| 315 |
+
|
| 316 |
+
l_g_total = 0
|
| 317 |
+
loss_dict = OrderedDict()
|
| 318 |
+
if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
|
| 319 |
+
# pixel loss
|
| 320 |
+
if self.cri_pix:
|
| 321 |
+
l_g_pix = self.cri_pix(self.output, self.gt)
|
| 322 |
+
l_g_total += l_g_pix
|
| 323 |
+
loss_dict['l_g_pix'] = l_g_pix
|
| 324 |
+
|
| 325 |
+
# image pyramid loss
|
| 326 |
+
if pyramid_loss_weight > 0:
|
| 327 |
+
for i in range(0, self.log_size - 2):
|
| 328 |
+
l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight
|
| 329 |
+
l_g_total += l_pyramid
|
| 330 |
+
loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid
|
| 331 |
+
|
| 332 |
+
# perceptual loss
|
| 333 |
+
if self.cri_perceptual:
|
| 334 |
+
l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
|
| 335 |
+
if l_g_percep is not None:
|
| 336 |
+
l_g_total += l_g_percep
|
| 337 |
+
loss_dict['l_g_percep'] = l_g_percep
|
| 338 |
+
if l_g_style is not None:
|
| 339 |
+
l_g_total += l_g_style
|
| 340 |
+
loss_dict['l_g_style'] = l_g_style
|
| 341 |
+
|
| 342 |
+
# gan loss
|
| 343 |
+
fake_g_pred = self.net_d(self.output)
|
| 344 |
+
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
|
| 345 |
+
l_g_total += l_g_gan
|
| 346 |
+
loss_dict['l_g_gan'] = l_g_gan
|
| 347 |
+
|
| 348 |
+
# facial component loss
|
| 349 |
+
if self.use_facial_disc:
|
| 350 |
+
# left eye
|
| 351 |
+
fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(self.left_eyes, return_feats=True)
|
| 352 |
+
l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False)
|
| 353 |
+
l_g_total += l_g_gan
|
| 354 |
+
loss_dict['l_g_gan_left_eye'] = l_g_gan
|
| 355 |
+
# right eye
|
| 356 |
+
fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(self.right_eyes, return_feats=True)
|
| 357 |
+
l_g_gan = self.cri_component(fake_right_eye, True, is_disc=False)
|
| 358 |
+
l_g_total += l_g_gan
|
| 359 |
+
loss_dict['l_g_gan_right_eye'] = l_g_gan
|
| 360 |
+
# mouth
|
| 361 |
+
fake_mouth, fake_mouth_feats = self.net_d_mouth(self.mouths, return_feats=True)
|
| 362 |
+
l_g_gan = self.cri_component(fake_mouth, True, is_disc=False)
|
| 363 |
+
l_g_total += l_g_gan
|
| 364 |
+
loss_dict['l_g_gan_mouth'] = l_g_gan
|
| 365 |
+
|
| 366 |
+
if self.opt['train'].get('comp_style_weight', 0) > 0:
|
| 367 |
+
# get gt feat
|
| 368 |
+
_, real_left_eye_feats = self.net_d_left_eye(self.left_eyes_gt, return_feats=True)
|
| 369 |
+
_, real_right_eye_feats = self.net_d_right_eye(self.right_eyes_gt, return_feats=True)
|
| 370 |
+
_, real_mouth_feats = self.net_d_mouth(self.mouths_gt, return_feats=True)
|
| 371 |
+
|
| 372 |
+
def _comp_style(feat, feat_gt, criterion):
|
| 373 |
+
return criterion(self._gram_mat(feat[0]), self._gram_mat(
|
| 374 |
+
feat_gt[0].detach())) * 0.5 + criterion(
|
| 375 |
+
self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach()))
|
| 376 |
+
|
| 377 |
+
# facial component style loss
|
| 378 |
+
comp_style_loss = 0
|
| 379 |
+
comp_style_loss += _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_l1)
|
| 380 |
+
comp_style_loss += _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_l1)
|
| 381 |
+
comp_style_loss += _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_l1)
|
| 382 |
+
comp_style_loss = comp_style_loss * self.opt['train']['comp_style_weight']
|
| 383 |
+
l_g_total += comp_style_loss
|
| 384 |
+
loss_dict['l_g_comp_style_loss'] = comp_style_loss
|
| 385 |
+
|
| 386 |
+
# identity loss
|
| 387 |
+
if self.use_identity:
|
| 388 |
+
identity_weight = self.opt['train']['identity_weight']
|
| 389 |
+
# get gray images and resize
|
| 390 |
+
out_gray = self.gray_resize_for_identity(self.output)
|
| 391 |
+
gt_gray = self.gray_resize_for_identity(self.gt)
|
| 392 |
+
|
| 393 |
+
identity_gt = self.network_identity(gt_gray).detach()
|
| 394 |
+
identity_out = self.network_identity(out_gray)
|
| 395 |
+
l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight
|
| 396 |
+
l_g_total += l_identity
|
| 397 |
+
loss_dict['l_identity'] = l_identity
|
| 398 |
+
|
| 399 |
+
l_g_total.backward()
|
| 400 |
+
self.optimizer_g.step()
|
| 401 |
+
|
| 402 |
+
# EMA
|
| 403 |
+
self.model_ema(decay=0.5**(32 / (10 * 1000)))
|
| 404 |
+
|
| 405 |
+
# ----------- optimize net_d ----------- #
|
| 406 |
+
for p in self.net_d.parameters():
|
| 407 |
+
p.requires_grad = True
|
| 408 |
+
self.optimizer_d.zero_grad()
|
| 409 |
+
if self.use_facial_disc:
|
| 410 |
+
for p in self.net_d_left_eye.parameters():
|
| 411 |
+
p.requires_grad = True
|
| 412 |
+
for p in self.net_d_right_eye.parameters():
|
| 413 |
+
p.requires_grad = True
|
| 414 |
+
for p in self.net_d_mouth.parameters():
|
| 415 |
+
p.requires_grad = True
|
| 416 |
+
self.optimizer_d_left_eye.zero_grad()
|
| 417 |
+
self.optimizer_d_right_eye.zero_grad()
|
| 418 |
+
self.optimizer_d_mouth.zero_grad()
|
| 419 |
+
|
| 420 |
+
fake_d_pred = self.net_d(self.output.detach())
|
| 421 |
+
real_d_pred = self.net_d(self.gt)
|
| 422 |
+
l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
|
| 423 |
+
loss_dict['l_d'] = l_d
|
| 424 |
+
# In WGAN, real_score should be positive and fake_score should be negative
|
| 425 |
+
loss_dict['real_score'] = real_d_pred.detach().mean()
|
| 426 |
+
loss_dict['fake_score'] = fake_d_pred.detach().mean()
|
| 427 |
+
l_d.backward()
|
| 428 |
+
|
| 429 |
+
# regularization loss
|
| 430 |
+
if current_iter % self.net_d_reg_every == 0:
|
| 431 |
+
self.gt.requires_grad = True
|
| 432 |
+
real_pred = self.net_d(self.gt)
|
| 433 |
+
l_d_r1 = r1_penalty(real_pred, self.gt)
|
| 434 |
+
l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
|
| 435 |
+
loss_dict['l_d_r1'] = l_d_r1.detach().mean()
|
| 436 |
+
l_d_r1.backward()
|
| 437 |
+
|
| 438 |
+
self.optimizer_d.step()
|
| 439 |
+
|
| 440 |
+
# optimize facial component discriminators
|
| 441 |
+
if self.use_facial_disc:
|
| 442 |
+
# left eye
|
| 443 |
+
fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
|
| 444 |
+
real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
|
| 445 |
+
l_d_left_eye = self.cri_component(
|
| 446 |
+
real_d_pred, True, is_disc=True) + self.cri_gan(
|
| 447 |
+
fake_d_pred, False, is_disc=True)
|
| 448 |
+
loss_dict['l_d_left_eye'] = l_d_left_eye
|
| 449 |
+
l_d_left_eye.backward()
|
| 450 |
+
# right eye
|
| 451 |
+
fake_d_pred, _ = self.net_d_right_eye(self.right_eyes.detach())
|
| 452 |
+
real_d_pred, _ = self.net_d_right_eye(self.right_eyes_gt)
|
| 453 |
+
l_d_right_eye = self.cri_component(
|
| 454 |
+
real_d_pred, True, is_disc=True) + self.cri_gan(
|
| 455 |
+
fake_d_pred, False, is_disc=True)
|
| 456 |
+
loss_dict['l_d_right_eye'] = l_d_right_eye
|
| 457 |
+
l_d_right_eye.backward()
|
| 458 |
+
# mouth
|
| 459 |
+
fake_d_pred, _ = self.net_d_mouth(self.mouths.detach())
|
| 460 |
+
real_d_pred, _ = self.net_d_mouth(self.mouths_gt)
|
| 461 |
+
l_d_mouth = self.cri_component(
|
| 462 |
+
real_d_pred, True, is_disc=True) + self.cri_gan(
|
| 463 |
+
fake_d_pred, False, is_disc=True)
|
| 464 |
+
loss_dict['l_d_mouth'] = l_d_mouth
|
| 465 |
+
l_d_mouth.backward()
|
| 466 |
+
|
| 467 |
+
self.optimizer_d_left_eye.step()
|
| 468 |
+
self.optimizer_d_right_eye.step()
|
| 469 |
+
self.optimizer_d_mouth.step()
|
| 470 |
+
|
| 471 |
+
self.log_dict = self.reduce_loss_dict(loss_dict)
|
| 472 |
+
|
| 473 |
+
def test(self):
|
| 474 |
+
with torch.no_grad():
|
| 475 |
+
if hasattr(self, 'net_g_ema'):
|
| 476 |
+
self.net_g_ema.eval()
|
| 477 |
+
self.output, _ = self.net_g_ema(self.lq)
|
| 478 |
+
else:
|
| 479 |
+
logger = get_root_logger()
|
| 480 |
+
logger.warning('Do not have self.net_g_ema, use self.net_g.')
|
| 481 |
+
self.net_g.eval()
|
| 482 |
+
self.output, _ = self.net_g(self.lq)
|
| 483 |
+
self.net_g.train()
|
| 484 |
+
|
| 485 |
+
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
| 486 |
+
if self.opt['rank'] == 0:
|
| 487 |
+
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
| 488 |
+
|
| 489 |
+
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
| 490 |
+
dataset_name = dataloader.dataset.opt['name']
|
| 491 |
+
with_metrics = self.opt['val'].get('metrics') is not None
|
| 492 |
+
use_pbar = self.opt['val'].get('pbar', False)
|
| 493 |
+
|
| 494 |
+
if with_metrics:
|
| 495 |
+
if not hasattr(self, 'metric_results'): # only execute in the first run
|
| 496 |
+
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
|
| 497 |
+
# initialize the best metric results for each dataset_name (supporting multiple validation datasets)
|
| 498 |
+
self._initialize_best_metric_results(dataset_name)
|
| 499 |
+
# zero self.metric_results
|
| 500 |
+
self.metric_results = {metric: 0 for metric in self.metric_results}
|
| 501 |
+
|
| 502 |
+
metric_data = dict()
|
| 503 |
+
if use_pbar:
|
| 504 |
+
pbar = tqdm(total=len(dataloader), unit='image')
|
| 505 |
+
|
| 506 |
+
for idx, val_data in enumerate(dataloader):
|
| 507 |
+
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
|
| 508 |
+
self.feed_data(val_data)
|
| 509 |
+
self.test()
|
| 510 |
+
|
| 511 |
+
sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1))
|
| 512 |
+
metric_data['img'] = sr_img
|
| 513 |
+
if hasattr(self, 'gt'):
|
| 514 |
+
gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1))
|
| 515 |
+
metric_data['img2'] = gt_img
|
| 516 |
+
del self.gt
|
| 517 |
+
|
| 518 |
+
# tentative for out of GPU memory
|
| 519 |
+
del self.lq
|
| 520 |
+
del self.output
|
| 521 |
+
torch.cuda.empty_cache()
|
| 522 |
+
|
| 523 |
+
if save_img:
|
| 524 |
+
if self.opt['is_train']:
|
| 525 |
+
save_img_path = osp.join(self.opt['path']['visualization'], img_name,
|
| 526 |
+
f'{img_name}_{current_iter}.png')
|
| 527 |
+
else:
|
| 528 |
+
if self.opt['val']['suffix']:
|
| 529 |
+
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
| 530 |
+
f'{img_name}_{self.opt["val"]["suffix"]}.png')
|
| 531 |
+
else:
|
| 532 |
+
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
| 533 |
+
f'{img_name}_{self.opt["name"]}.png')
|
| 534 |
+
imwrite(sr_img, save_img_path)
|
| 535 |
+
|
| 536 |
+
if with_metrics:
|
| 537 |
+
# calculate metrics
|
| 538 |
+
for name, opt_ in self.opt['val']['metrics'].items():
|
| 539 |
+
self.metric_results[name] += calculate_metric(metric_data, opt_)
|
| 540 |
+
if use_pbar:
|
| 541 |
+
pbar.update(1)
|
| 542 |
+
pbar.set_description(f'Test {img_name}')
|
| 543 |
+
if use_pbar:
|
| 544 |
+
pbar.close()
|
| 545 |
+
|
| 546 |
+
if with_metrics:
|
| 547 |
+
for metric in self.metric_results.keys():
|
| 548 |
+
self.metric_results[metric] /= (idx + 1)
|
| 549 |
+
# update the best metric result
|
| 550 |
+
self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
|
| 551 |
+
|
| 552 |
+
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
|
| 553 |
+
|
| 554 |
+
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
|
| 555 |
+
log_str = f'Validation {dataset_name}\n'
|
| 556 |
+
for metric, value in self.metric_results.items():
|
| 557 |
+
log_str += f'\t # {metric}: {value:.4f}'
|
| 558 |
+
if hasattr(self, 'best_metric_results'):
|
| 559 |
+
log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
|
| 560 |
+
f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
|
| 561 |
+
log_str += '\n'
|
| 562 |
+
|
| 563 |
+
logger = get_root_logger()
|
| 564 |
+
logger.info(log_str)
|
| 565 |
+
if tb_logger:
|
| 566 |
+
for metric, value in self.metric_results.items():
|
| 567 |
+
tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
|
| 568 |
+
|
| 569 |
+
def save(self, epoch, current_iter):
|
| 570 |
+
# save net_g and net_d
|
| 571 |
+
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
|
| 572 |
+
self.save_network(self.net_d, 'net_d', current_iter)
|
| 573 |
+
# save component discriminators
|
| 574 |
+
if self.use_facial_disc:
|
| 575 |
+
self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
|
| 576 |
+
self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
|
| 577 |
+
self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
|
| 578 |
+
# save training state
|
| 579 |
+
self.save_training_state(epoch, current_iter)
|
gfpgan/train.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
import os.path as osp
|
| 3 |
+
from basicsr.train import train_pipeline
|
| 4 |
+
|
| 5 |
+
import gfpgan.archs
|
| 6 |
+
import gfpgan.data
|
| 7 |
+
import gfpgan.models
|
| 8 |
+
|
| 9 |
+
if __name__ == '__main__':
|
| 10 |
+
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
|
| 11 |
+
train_pipeline(root_path)
|
gfpgan/utils.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from basicsr.utils import img2tensor, tensor2img
|
| 5 |
+
from basicsr.utils.download_util import load_file_from_url
|
| 6 |
+
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
| 7 |
+
from torchvision.transforms.functional import normalize
|
| 8 |
+
|
| 9 |
+
from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
|
| 10 |
+
from gfpgan.archs.gfpganv1_arch import GFPGANv1
|
| 11 |
+
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
|
| 12 |
+
|
| 13 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class GFPGANer():
|
| 17 |
+
"""Helper for restoration with GFPGAN.
|
| 18 |
+
|
| 19 |
+
It will detect and crop faces, and then resize the faces to 512x512.
|
| 20 |
+
GFPGAN is used to restored the resized faces.
|
| 21 |
+
The background is upsampled with the bg_upsampler.
|
| 22 |
+
Finally, the faces will be pasted back to the upsample background image.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
|
| 26 |
+
upscale (float): The upscale of the final output. Default: 2.
|
| 27 |
+
arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
|
| 28 |
+
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
| 29 |
+
bg_upsampler (nn.Module): The upsampler for the background. Default: None.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
|
| 33 |
+
self.upscale = upscale
|
| 34 |
+
self.bg_upsampler = bg_upsampler
|
| 35 |
+
|
| 36 |
+
# initialize model
|
| 37 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
| 38 |
+
# initialize the GFP-GAN
|
| 39 |
+
if arch == 'clean':
|
| 40 |
+
self.gfpgan = GFPGANv1Clean(
|
| 41 |
+
out_size=512,
|
| 42 |
+
num_style_feat=512,
|
| 43 |
+
channel_multiplier=channel_multiplier,
|
| 44 |
+
decoder_load_path=None,
|
| 45 |
+
fix_decoder=False,
|
| 46 |
+
num_mlp=8,
|
| 47 |
+
input_is_latent=True,
|
| 48 |
+
different_w=True,
|
| 49 |
+
narrow=1,
|
| 50 |
+
sft_half=True)
|
| 51 |
+
elif arch == 'bilinear':
|
| 52 |
+
self.gfpgan = GFPGANBilinear(
|
| 53 |
+
out_size=512,
|
| 54 |
+
num_style_feat=512,
|
| 55 |
+
channel_multiplier=channel_multiplier,
|
| 56 |
+
decoder_load_path=None,
|
| 57 |
+
fix_decoder=False,
|
| 58 |
+
num_mlp=8,
|
| 59 |
+
input_is_latent=True,
|
| 60 |
+
different_w=True,
|
| 61 |
+
narrow=1,
|
| 62 |
+
sft_half=True)
|
| 63 |
+
elif arch == 'original':
|
| 64 |
+
self.gfpgan = GFPGANv1(
|
| 65 |
+
out_size=512,
|
| 66 |
+
num_style_feat=512,
|
| 67 |
+
channel_multiplier=channel_multiplier,
|
| 68 |
+
decoder_load_path=None,
|
| 69 |
+
fix_decoder=True,
|
| 70 |
+
num_mlp=8,
|
| 71 |
+
input_is_latent=True,
|
| 72 |
+
different_w=True,
|
| 73 |
+
narrow=1,
|
| 74 |
+
sft_half=True)
|
| 75 |
+
elif arch == 'RestoreFormer':
|
| 76 |
+
from gfpgan.archs.restoreformer_arch import RestoreFormer
|
| 77 |
+
self.gfpgan = RestoreFormer()
|
| 78 |
+
# initialize face helper
|
| 79 |
+
self.face_helper = FaceRestoreHelper(
|
| 80 |
+
upscale,
|
| 81 |
+
face_size=512,
|
| 82 |
+
crop_ratio=(1, 1),
|
| 83 |
+
det_model='retinaface_resnet50',
|
| 84 |
+
save_ext='png',
|
| 85 |
+
use_parse=True,
|
| 86 |
+
device=self.device,
|
| 87 |
+
model_rootpath='gfpgan/weights')
|
| 88 |
+
|
| 89 |
+
if model_path.startswith('https://'):
|
| 90 |
+
model_path = load_file_from_url(
|
| 91 |
+
url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
|
| 92 |
+
loadnet = torch.load(model_path)
|
| 93 |
+
if 'params_ema' in loadnet:
|
| 94 |
+
keyname = 'params_ema'
|
| 95 |
+
else:
|
| 96 |
+
keyname = 'params'
|
| 97 |
+
self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
|
| 98 |
+
self.gfpgan.eval()
|
| 99 |
+
self.gfpgan = self.gfpgan.to(self.device)
|
| 100 |
+
|
| 101 |
+
@torch.no_grad()
|
| 102 |
+
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5):
|
| 103 |
+
self.face_helper.clean_all()
|
| 104 |
+
|
| 105 |
+
if has_aligned: # the inputs are already aligned
|
| 106 |
+
img = cv2.resize(img, (512, 512))
|
| 107 |
+
self.face_helper.cropped_faces = [img]
|
| 108 |
+
else:
|
| 109 |
+
self.face_helper.read_image(img)
|
| 110 |
+
# get face landmarks for each face
|
| 111 |
+
self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
|
| 112 |
+
# eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
|
| 113 |
+
# TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
|
| 114 |
+
# align and warp each face
|
| 115 |
+
self.face_helper.align_warp_face()
|
| 116 |
+
|
| 117 |
+
# face restoration
|
| 118 |
+
for cropped_face in self.face_helper.cropped_faces:
|
| 119 |
+
# prepare data
|
| 120 |
+
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
| 121 |
+
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
| 122 |
+
cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
|
| 123 |
+
|
| 124 |
+
try:
|
| 125 |
+
output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
|
| 126 |
+
# convert to image
|
| 127 |
+
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
|
| 128 |
+
except RuntimeError as error:
|
| 129 |
+
print(f'\tFailed inference for GFPGAN: {error}.')
|
| 130 |
+
restored_face = cropped_face
|
| 131 |
+
|
| 132 |
+
restored_face = restored_face.astype('uint8')
|
| 133 |
+
self.face_helper.add_restored_face(restored_face)
|
| 134 |
+
|
| 135 |
+
if not has_aligned and paste_back:
|
| 136 |
+
# upsample the background
|
| 137 |
+
if self.bg_upsampler is not None:
|
| 138 |
+
# Now only support RealESRGAN for upsampling background
|
| 139 |
+
bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
|
| 140 |
+
else:
|
| 141 |
+
bg_img = None
|
| 142 |
+
|
| 143 |
+
self.face_helper.get_inverse_affine(None)
|
| 144 |
+
# paste each restored face to the input image
|
| 145 |
+
restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
|
| 146 |
+
return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
|
| 147 |
+
else:
|
| 148 |
+
return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
|
gfpgan/version.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GENERATED VERSION FILE
|
| 2 |
+
# TIME: Fri Sep 16 11:35:59 2022
|
| 3 |
+
__version__ = '1.3.8'
|
| 4 |
+
__gitsha__ = '2eac203'
|
| 5 |
+
version_info = (1, 3, 8)
|
gfpgan/weights/README.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Weights
|
| 2 |
+
|
| 3 |
+
Put the downloaded weights to this folder.
|