Upload 10 files
- utils/PCA_utils.py +29 -0
- utils/__init__.py +0 -0
- utils/bicubic.py +75 -0
- utils/image_utils.py +108 -0
- utils/save_utils.py +38 -0
- utils/seed.py +31 -0
- utils/shape_predictor.py +2 -1
- utils/time.py +36 -0
- utils/train.py +161 -0
utils/PCA_utils.py
ADDED
@@ -0,0 +1,29 @@
from sklearn.decomposition import IncrementalPCA
import numpy as np
class IPCAEstimator():
    def __init__(self, n_components):
        self.n_components = n_components
        self.whiten = False
        self.transformer = IncrementalPCA(n_components, whiten=self.whiten, batch_size=max(100, 5 * n_components))
        self.batch_support = True

    def get_param_str(self):
        return "ipca_c{}{}".format(self.n_components, '_w' if self.whiten else '')

    def fit(self, X):
        self.transformer.fit(X)

    def fit_partial(self, X):
        try:
            self.transformer.partial_fit(X)
            self.transformer.n_samples_seen_ = \
                self.transformer.n_samples_seen_.astype(np.int64)  # avoid overflow
            return True
        except ValueError as e:
            print('\nIPCA error:', e)
            return False

    def get_components(self):
        stdev = np.sqrt(self.transformer.explained_variance_)  # already sorted
        var_ratio = self.transformer.explained_variance_ratio_
        return self.transformer.components_, stdev, var_ratio  # PCA outputs are normalized
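
A minimal usage sketch of IPCAEstimator; the sample count and feature size below are illustrative assumptions, not values from the repository:

import numpy as np
from utils.PCA_utils import IPCAEstimator

X = np.random.randn(10_000, 512).astype(np.float32)  # illustrative data: 10k samples, 512 features

estimator = IPCAEstimator(n_components=50)
estimator.fit(X)  # or call estimator.fit_partial(batch) repeatedly for streamed batches
components, stdev, var_ratio = estimator.get_components()
print(components.shape)  # (50, 512)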
utils/__init__.py
ADDED
File without changes
utils/bicubic.py
ADDED
@@ -0,0 +1,75 @@
import torch
from torch import nn
from torch.nn import functional as F


class BicubicDownSample(nn.Module):
    def bicubic_kernel(self, x, a=-0.50):
        """
        This equation is exactly copied from the website below:
        https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic
        """
        abs_x = torch.abs(x)
        if abs_x <= 1.:
            return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1
        elif 1. < abs_x < 2.:
            return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. * a
        else:
            return 0.0

    def __init__(self, factor=4, cuda=True, padding='reflect'):
        super().__init__()
        self.factor = factor
        size = factor * 4
        k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor)
                          for i in range(size)], dtype=torch.float32)
        k = k / torch.sum(k)
        # k = torch.einsum('i,j->ij', (k, k))
        k1 = torch.reshape(k, shape=(1, 1, size, 1))
        self.k1 = torch.cat([k1, k1, k1], dim=0)
        k2 = torch.reshape(k, shape=(1, 1, 1, size))
        self.k2 = torch.cat([k2, k2, k2], dim=0)
        self.cuda = '.cuda' if cuda else ''
        self.padding = padding
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x, nhwc=False, clip_round=False, byte_output=False):
        # x = torch.from_numpy(x).type('torch.FloatTensor')
        filter_height = self.factor * 4
        filter_width = self.factor * 4
        stride = self.factor

        pad_along_height = max(filter_height - stride, 0)
        pad_along_width = max(filter_width - stride, 0)
        filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda))
        filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda))

        # compute actual padding values for each side
        pad_top = pad_along_height // 2
        pad_bottom = pad_along_height - pad_top
        pad_left = pad_along_width // 2
        pad_right = pad_along_width - pad_left

        # apply mirror padding
        if nhwc:
            x = torch.transpose(torch.transpose(
                x, 2, 3), 1, 2)  # NHWC to NCHW

        # downscaling performed by 1-d convolution
        x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding)
        x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3)
        if clip_round:
            x = torch.clamp(torch.round(x), 0.0, 255.)

        x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding)
        x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3)
        if clip_round:
            x = torch.clamp(torch.round(x), 0.0, 255.)

        if nhwc:
            x = torch.transpose(torch.transpose(x, 1, 3), 1, 2)
        if byte_output:
            return x.type('torch{}.ByteTensor'.format(self.cuda))
        else:
            return x
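
A minimal usage sketch of BicubicDownSample; it assumes a CUDA device, since the default cuda=True casts the filters to torch.cuda.FloatTensor, and the batch shape is an illustrative assumption:

import torch
from utils.bicubic import BicubicDownSample

downsample = BicubicDownSample(factor=4)      # 1024x1024 -> 256x256
images = torch.rand(2, 3, 1024, 1024).cuda()  # NCHW batch of RGB images, values in [0, 1]
small = downsample(images)
print(small.shape)                            # torch.Size([2, 3, 256, 256])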
utils/image_utils.py
ADDED
@@ -0,0 +1,108 @@
import os
import subprocess
import tempfile
from pathlib import Path

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import transforms
from torchvision.utils import save_image

from models.Net import get_segmentation


def equal_replacer(images: list[torch.Tensor]) -> list[torch.Tensor]:
    for i in range(len(images)):
        if images[i].dtype is torch.uint8:
            images[i] = images[i] / 255

    for i in range(len(images)):
        for j in range(i + 1, len(images)):
            if torch.allclose(images[i], images[j]):
                images[j] = images[i]
    return images


class DilateErosion:
    def __init__(self, dilate_erosion=5, device='cuda'):
        self.dilate_erosion = dilate_erosion
        self.weight = torch.Tensor([
            [False, True, False],
            [True, True, True],
            [False, True, False]
        ]).float()[None, None, ...].to(device)

    def hair_from_mask(self, mask):
        mask = torch.where(mask == 13, torch.ones_like(mask), torch.zeros_like(mask))
        mask = F.interpolate(mask, size=(256, 256), mode='nearest')
        dilate, erosion = self.mask(mask)
        return dilate, erosion

    def mask(self, mask):
        masks = mask.clone().repeat(*([2] + [1] * (len(mask.shape) - 1))).float()
        sum_w = self.weight.sum().item()
        n = len(mask)

        for _ in range(self.dilate_erosion):
            masks = F.conv2d(masks, self.weight,
                             bias=None, stride=1, padding='same', dilation=1, groups=1)
            masks[:n] = (masks[:n] > 0).float()
            masks[n:] = (masks[n:] == sum_w).float()

        hair_mask_dilate, hair_mask_erode = masks[:n], masks[n:]

        return hair_mask_dilate, hair_mask_erode


def poisson_image_blending(final_image, face_image, dilate_erosion=30, maxn=115):
    dilate_erosion = DilateErosion(dilate_erosion=dilate_erosion)
    transform = transforms.ToTensor()

    if isinstance(face_image, str):
        face_image = transform(Image.open(face_image))
    elif not isinstance(face_image, torch.Tensor):
        face_image = transform(face_image)

    final_mask = get_segmentation(final_image.cuda().unsqueeze(0), resize=False)
    face_mask = get_segmentation(face_image.cuda().unsqueeze(0), resize=False)

    hair_target = torch.where(final_mask == 13, torch.ones_like(final_mask),
                              torch.zeros_like(final_mask))
    hair_face = torch.where(face_mask == 13, torch.ones_like(face_mask),
                            torch.zeros_like(face_mask))

    final_mask = F.interpolate(((1 - hair_target) * (1 - hair_face)).float(), size=(1024, 1024), mode='bicubic')
    dilation, _ = dilate_erosion.mask(1 - final_mask)
    mask_save = 1 - dilation[0]

    with tempfile.TemporaryDirectory() as temp_dir:
        final_image_path = os.path.join(temp_dir, 'final_image.png')
        face_image_path = os.path.join(temp_dir, 'face_image.png')
        mask_path = os.path.join(temp_dir, 'mask_save.png')
        save_image(final_image, final_image_path)
        save_image(face_image, face_image_path)
        save_image(mask_save, mask_path)

        out_image_path = os.path.join(temp_dir, 'out_image_path.png')
        result = subprocess.run(
            ["fpie", "-s", face_image_path, "-m", mask_path, "-t", final_image_path, "-o", out_image_path, "-n",
             str(maxn), "-b", "taichi-gpu", "-g", "max"],
            check=True
        )

        return Image.open(out_image_path), Image.open(mask_path)


def list_image_files(directory):
    image_extensions = ['.jpg', '.jpeg', '.png']
    image_files = []

    for entry in sorted(os.listdir(directory)):
        file_path = os.path.join(directory, entry)
        if os.path.isfile(file_path):
            file_extension = Path(file_path).suffix.lower()
            if file_extension in image_extensions:
                image_files.append(entry)

    return image_files
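
A minimal sketch of applying DilateErosion to a face-parsing map, where label 13 marks hair as in the code above; the input map and its size are illustrative assumptions:

import torch
from utils.image_utils import DilateErosion

dilate_erosion = DilateErosion(dilate_erosion=5, device='cuda')
parsing = torch.randint(0, 16, (1, 1, 512, 512), device='cuda').float()  # fake segmentation map
hair_dilate, hair_erode = dilate_erosion.hair_from_mask(parsing)
print(hair_dilate.shape, hair_erode.shape)  # both torch.Size([1, 1, 256, 256])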
utils/save_utils.py
ADDED
@@ -0,0 +1,38 @@
import os

import numpy as np
import torchvision.transforms as T
from PIL import Image

from models.CtrlHair.util.mask_color_util import mask_to_rgb

toPIL = T.ToPILImage()


def save_gen_image(output_dir, path, name, gen_im):
    if len(gen_im.shape) == 4:
        gen_im = gen_im[0]
    save_im = toPIL(((gen_im + 1) / 2).detach().cpu().clamp(0, 1))

    save_dir = output_dir / path
    os.makedirs(save_dir, exist_ok=True)

    image_path = save_dir / name
    save_im.save(image_path)


def save_vis_mask(output_dir, path, name, mask):
    out_dir = output_dir / path
    os.makedirs(out_dir, exist_ok=True)
    out_mask_path = out_dir / name

    rgb_img = Image.fromarray(mask_to_rgb(mask.detach().cpu().squeeze(), 0))
    rgb_img.save(out_mask_path)


def save_latents(output_dir, path, file_name, **latents):
    save_dir = output_dir / path
    os.makedirs(save_dir, exist_ok=True)

    latent_path = save_dir / file_name
    np.savez(latent_path, **{key: latent.detach().cpu().numpy() for key, latent in latents.items()})
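
A minimal usage sketch of the save helpers; it assumes the repository environment (the module imports mask_to_rgb from models.CtrlHair), and the directory names and tensors below are illustrative. Since output_dir is joined with /, it should be a pathlib.Path:

from pathlib import Path
import torch
from utils.save_utils import save_gen_image, save_latents

output_dir = Path('output')                    # illustrative output root
gen_im = torch.rand(1, 3, 1024, 1024) * 2 - 1  # fake generator output in [-1, 1]
latent = torch.randn(1, 18, 512)               # fake W+ latent

save_gen_image(output_dir, 'results', 'final.png', gen_im)
save_latents(output_dir, 'results', 'final', latent_W=latent)  # writes final.npz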
utils/seed.py
ADDED
@@ -0,0 +1,31 @@
import functools
import random

import numpy as np
import torch


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    np.random.seed(seed)
    random.seed(seed)


def seed_setter(func):
    default_seed = 3407

    @functools.wraps(func)
    def wraps(*args, **kwargs):
        seed = kwargs.pop('seed', None)
        if seed is None:
            seed = default_seed
        set_seed(seed)

        result = func(*args, **kwargs)
        return result

    return wraps
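
A minimal sketch of the seed_setter decorator; the decorated function is an illustrative assumption. The wrapper pops an optional seed keyword (default 3407) and fixes all RNGs before the call:

import torch
from utils.seed import seed_setter

@seed_setter
def sample_noise(batch_size):
    return torch.randn(batch_size, 512)

a = sample_noise(4, seed=42)
b = sample_noise(4, seed=42)
assert torch.equal(a, b)  # same seed, same samples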
utils/shape_predictor.py
CHANGED
@@ -19,6 +19,7 @@ date: 2020.1.5
 note: code is heavily borrowed from
 https://github.com/NVlabs/ffhq-dataset
 http://dlib.net/face_landmark_detection.py.html
+
 requirements:
 apt install cmake
 conda install Pillow numpy scipy

@@ -82,7 +83,7 @@ def align_face(data, predictor=None, is_filepath=False, return_tensors=True):
     :return: list of PIL Images
     """
     if predictor is None:
-        predictor_path = 'shape_predictor_68_face_landmarks.dat'
+        predictor_path = 'pretrained_models/ShapeAdaptor/shape_predictor_68_face_landmarks.dat'

     if not os.path.isfile(predictor_path):
         print("Downloading Shape Predictor")
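
A minimal usage sketch of align_face after this change; the input photo is an illustrative assumption, and it assumes return_tensors=False yields the PIL images described in the docstring. With the new default path, the predictor weights are looked up under pretrained_models/ShapeAdaptor/ and downloaded if missing:

from utils.shape_predictor import align_face

faces = align_face('photo.jpg', is_filepath=True, return_tensors=False)  # list of aligned face images
faces[0].save('aligned_face.png')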
utils/time.py
ADDED
@@ -0,0 +1,36 @@
import functools
import sys
import time

import numpy as np
import torch


def get_time():
    torch.cuda.current_stream().synchronize()
    return time.time()


def bench_session(func):
    times = []

    @functools.wraps(func)
    def wraps(*args, **kwargs):
        if kwargs.pop('benchmark', False):
            nonlocal times
            start = get_time()

            result = func(*args, **kwargs)

            eval_time = get_time() - start
            times.append(eval_time)

            print(f'\n{len(times)} experiment ended in {eval_time:.3f}(s)', file=sys.stderr)
            print(f'min time: {np.min(times):.3f}(s),'
                  f' median time: {np.median(times):.3f}(s),'
                  f' std time: {np.std(times):.3f}(s)', file=sys.stderr)
            return result
        else:
            return func(*args, **kwargs)

    return wraps
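
A minimal sketch of bench_session; the benchmarked function is an illustrative assumption, and get_time requires a CUDA device because it synchronizes the current CUDA stream:

import torch
from utils.time import bench_session

@bench_session
def run_step(x):
    return x @ x

x = torch.randn(1024, 1024, device='cuda')
run_step(x, benchmark=True)  # timed; per-call and aggregate stats go to stderr
run_step(x)                  # benchmark not passed: runs without timing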
utils/train.py
ADDED
@@ -0,0 +1,161 @@
import os
import pickle
import random
import shutil
import typing as tp

import numpy as np
import torch
import torchvision.transforms as T
import wandb
from PIL import Image
from joblib import Parallel, delayed
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics.image.fid import FrechetInceptionDistance
from tqdm.auto import tqdm

from models.Encoders import ClipModel


def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


class WandbLogger:
    def __init__(self, name='base-name', project='HairFast'):
        self.name = name
        self.project = project

    def start_logging(self):
        wandb.login(key=os.environ['WANDB_KEY'].strip(), relogin=True)
        wandb.init(
            project=self.project,
            name=self.name
        )
        self.wandb = wandb
        self.run_dir = self.wandb.run.dir
        self.train_step = 0

    def log(self, scalar_name: str, scalar: tp.Any):
        self.wandb.log({scalar_name: scalar}, step=self.train_step, commit=False)

    def log_scalars(self, scalars: dict):
        self.wandb.log(scalars, step=self.train_step, commit=False)

    def next_step(self):
        self.train_step += 1

    def save(self, file_path, save_online=True):
        file = os.path.basename(file_path)
        new_path = os.path.join(self.run_dir, file)
        shutil.copy2(file_path, new_path)
        if save_online:
            self.wandb.save(new_path)

    def __del__(self):
        self.wandb.finish()


def toggle_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag


class _TFNetworkStub:
    # placeholder class returned for legacy TensorFlow network objects found in old pickles
    pass


class _LegacyUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == 'dnnlib.tflib.network' and name == 'Network':
            return _TFNetworkStub
        module = module.replace('torch_utils', 'models.stylegan2.torch_utils')
        module = module.replace('dnnlib', 'models.stylegan2.dnnlib')
        return super().find_class(module, name)


def seed_everything(seed: int = 1729) -> None:
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def load_images_to_torch(paths, imgs=None, use_tqdm=True):
    transform = T.PILToTensor()
    tensor = []
    for path in paths:
        if imgs is None:
            pbar = sorted(os.listdir(path))
        else:
            pbar = imgs

        if use_tqdm:
            pbar = tqdm(pbar)

        for img_name in pbar:
            if '.jpg' in img_name or '.png' in img_name:
                img_path = os.path.join(path, img_name)
                img = Image.open(img_path).resize((299, 299), resample=Image.LANCZOS)
                tensor.append(transform(img))
    try:
        return torch.stack(tensor)
    except:
        print(paths, imgs)
        return torch.tensor([], dtype=torch.uint8)


def parallel_load_images(paths, imgs):
    assert imgs is not None
    if not isinstance(paths, list):
        paths = [paths]

    list_torch_images = Parallel(n_jobs=-1)(delayed(load_images_to_torch)(
        paths, [i], use_tqdm=False
    ) for i in tqdm(imgs))
    return torch.cat(list_torch_images)


def get_fid_calc(instance='fid.pkl', dataset_path='', device=torch.device('cuda')):
    if os.path.isfile(instance):
        with open(instance, 'rb') as f:
            fid = pickle.load(f)
    else:
        fid = FrechetInceptionDistance(feature=ClipModel(), reset_real_features=False, normalize=True)
        fid.to(device).eval()

        imgs_file = []
        for file in os.listdir(dataset_path):
            if 'flip' not in file and os.path.splitext(file)[1] in ['.png', '.jpg']:
                imgs_file.append(file)

        tensor_images = parallel_load_images([dataset_path], imgs_file).float().div(255)
        real_dataloader = DataLoader(TensorDataset(tensor_images), batch_size=128)
        with torch.inference_mode():
            for batch in tqdm(real_dataloader):
                batch = batch[0].to(device)
                fid.update(batch, real=True)

        with open(instance, 'wb') as f:
            pickle.dump(fid.cpu(), f)
    fid.to(device).eval()

    @torch.inference_mode()
    def compute_fid_datasets(images):
        nonlocal fid, device
        fid.reset()

        fake_dataloader = DataLoader(TensorDataset(images), batch_size=128)
        for batch in tqdm(fake_dataloader):
            batch = batch[0].to(device)
            fid.update(batch, real=False)

        return fid.compute()

    return compute_fid_datasets
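
A minimal sketch of get_fid_calc; it assumes the repository environment (models.Encoders.ClipModel) and a CUDA device, and the dataset path and fake images below are illustrative assumptions. The first call builds the real-image statistics from dataset_path and caches the metric in the instance pickle, which later calls reuse:

import torch
from utils.train import get_fid_calc

compute_fid = get_fid_calc(instance='fid.pkl', dataset_path='datasets/real_images')  # illustrative path
fake_images = torch.rand(256, 3, 299, 299)  # generated samples in [0, 1], sized like the real-image loader
print(compute_fid(fake_images))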