miguelmuzo committed (verified)
Commit a88f97d · 1 Parent(s): d01a7d5

Upload 10 files
utils/PCA_utils.py ADDED
@@ -0,0 +1,29 @@
+ from sklearn.decomposition import IncrementalPCA
+ import numpy as np
+ class IPCAEstimator:
+     def __init__(self, n_components):
+         self.n_components = n_components
+         self.whiten = False
+         self.transformer = IncrementalPCA(n_components, whiten=self.whiten, batch_size=max(100, 5 * n_components))
+         self.batch_support = True
+
+     def get_param_str(self):
+         return "ipca_c{}{}".format(self.n_components, '_w' if self.whiten else '')
+
+     def fit(self, X):
+         self.transformer.fit(X)
+
+     def fit_partial(self, X):
+         try:
+             self.transformer.partial_fit(X)
+             self.transformer.n_samples_seen_ = \
+                 self.transformer.n_samples_seen_.astype(np.int64)  # avoid overflow
+             return True
+         except ValueError as e:
+             print('\nIPCA error:', e)
+             return False
+
+     def get_components(self):
+         stdev = np.sqrt(self.transformer.explained_variance_)  # already sorted
+         var_ratio = self.transformer.explained_variance_ratio_
+         return self.transformer.components_, stdev, var_ratio  # PCA outputs are normalized
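
Not part of the commit: a minimal usage sketch of IPCAEstimator, feeding batches through fit_partial and reading the principal directions back with get_components. The array sizes are illustrative assumptions.

import numpy as np
from utils.PCA_utils import IPCAEstimator

estimator = IPCAEstimator(n_components=64)
for _ in range(10):
    X = np.random.randn(512, 512).astype(np.float32)  # 512 samples x 512 features per batch (illustrative)
    estimator.fit_partial(X)  # returns False and prints the error instead of raising

components, stdev, var_ratio = estimator.get_components()
print(components.shape, stdev.shape, var_ratio.shape)  # (64, 512), (64,), (64,)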
utils/__init__.py ADDED
File without changes
utils/bicubic.py ADDED
@@ -0,0 +1,75 @@
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+
+ class BicubicDownSample(nn.Module):
+     def bicubic_kernel(self, x, a=-0.50):
+         """
+         This equation is exactly copied from the website below:
+         https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic
+         """
+         abs_x = torch.abs(x)
+         if abs_x <= 1.:
+             return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1
+         elif 1. < abs_x < 2.:
+             return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. * a
+         else:
+             return 0.0
+
+     def __init__(self, factor=4, cuda=True, padding='reflect'):
+         super().__init__()
+         self.factor = factor
+         size = factor * 4
+         k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor)
+                           for i in range(size)], dtype=torch.float32)
+         k = k / torch.sum(k)
+         # k = torch.einsum('i,j->ij', (k, k))
+         k1 = torch.reshape(k, shape=(1, 1, size, 1))
+         self.k1 = torch.cat([k1, k1, k1], dim=0)
+         k2 = torch.reshape(k, shape=(1, 1, 1, size))
+         self.k2 = torch.cat([k2, k2, k2], dim=0)
+         self.cuda = '.cuda' if cuda else ''
+         self.padding = padding
+         for param in self.parameters():
+             param.requires_grad = False
+
+     def forward(self, x, nhwc=False, clip_round=False, byte_output=False):
+         # x = torch.from_numpy(x).type('torch.FloatTensor')
+         filter_height = self.factor * 4
+         filter_width = self.factor * 4
+         stride = self.factor
+
+         pad_along_height = max(filter_height - stride, 0)
+         pad_along_width = max(filter_width - stride, 0)
+         filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda))
+         filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda))
+
+         # compute actual padding values for each side
+         pad_top = pad_along_height // 2
+         pad_bottom = pad_along_height - pad_top
+         pad_left = pad_along_width // 2
+         pad_right = pad_along_width - pad_left
+
+         # apply mirror padding
+         if nhwc:
+             x = torch.transpose(torch.transpose(
+                 x, 2, 3), 1, 2)  # NHWC to NCHW
+
+         # downscaling performed by 1-d convolution
+         x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding)
+         x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3)
+         if clip_round:
+             x = torch.clamp(torch.round(x), 0.0, 255.)
+
+         x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding)
+         x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3)
+         if clip_round:
+             x = torch.clamp(torch.round(x), 0.0, 255.)
+
+         if nhwc:
+             x = torch.transpose(torch.transpose(x, 1, 3), 1, 2)
+         if byte_output:
+             return x.type('torch{}.ByteTensor'.format(self.cuda))
+         else:
+             return x
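
Not part of the commit: a usage sketch that downsamples a batch of 1024x1024 RGB tensors by the default factor of 4. cuda=False keeps the example CPU-only; the module expects 3-channel NCHW input.

import torch
from utils.bicubic import BicubicDownSample

downsample = BicubicDownSample(factor=4, cuda=False)  # cuda=True expects CUDA tensors
x = torch.rand(1, 3, 1024, 1024)  # NCHW, values in [0, 1]; use [0, 255] with clip_round=True
out = downsample(x)
print(out.shape)  # torch.Size([1, 3, 256, 256]): 16-tap separable bicubic filter, stride 4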
utils/image_utils.py ADDED
@@ -0,0 +1,108 @@
+ import os
+ import subprocess
+ import tempfile
+ from pathlib import Path
+
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ from torchvision.transforms import transforms
+ from torchvision.utils import save_image
+
+ from models.Net import get_segmentation
+
+
+ def equal_replacer(images: list[torch.Tensor]) -> list[torch.Tensor]:
+     for i in range(len(images)):
+         if images[i].dtype is torch.uint8:
+             images[i] = images[i] / 255
+
+     for i in range(len(images)):
+         for j in range(i + 1, len(images)):
+             if torch.allclose(images[i], images[j]):
+                 images[j] = images[i]
+     return images
+
+
+ class DilateErosion:
+     def __init__(self, dilate_erosion=5, device='cuda'):
+         self.dilate_erosion = dilate_erosion
+         self.weight = torch.Tensor([
+             [False, True, False],
+             [True, True, True],
+             [False, True, False]
+         ]).float()[None, None, ...].to(device)
+
+     def hair_from_mask(self, mask):
+         mask = torch.where(mask == 13, torch.ones_like(mask), torch.zeros_like(mask))
+         mask = F.interpolate(mask, size=(256, 256), mode='nearest')
+         dilate, erosion = self.mask(mask)
+         return dilate, erosion
+
+     def mask(self, mask):
+         masks = mask.clone().repeat(*([2] + [1] * (len(mask.shape) - 1))).float()
+         sum_w = self.weight.sum().item()
+         n = len(mask)
+
+         for _ in range(self.dilate_erosion):
+             masks = F.conv2d(masks, self.weight,
+                              bias=None, stride=1, padding='same', dilation=1, groups=1)
+             masks[:n] = (masks[:n] > 0).float()
+             masks[n:] = (masks[n:] == sum_w).float()
+
+         hair_mask_dilate, hair_mask_erode = masks[:n], masks[n:]
+
+         return hair_mask_dilate, hair_mask_erode
+
+
+ def poisson_image_blending(final_image, face_image, dilate_erosion=30, maxn=115):
+     dilate_erosion = DilateErosion(dilate_erosion=dilate_erosion)
+     transform = transforms.ToTensor()
+
+     if isinstance(face_image, str):
+         face_image = transform(Image.open(face_image))
+     elif not isinstance(face_image, torch.Tensor):
+         face_image = transform(face_image)
+
+     final_mask = get_segmentation(final_image.cuda().unsqueeze(0), resize=False)
+     face_mask = get_segmentation(face_image.cuda().unsqueeze(0), resize=False)
+
+     hair_target = torch.where(final_mask == 13, torch.ones_like(final_mask),
+                               torch.zeros_like(final_mask))
+     hair_face = torch.where(face_mask == 13, torch.ones_like(face_mask),
+                             torch.zeros_like(face_mask))
+
+     final_mask = F.interpolate(((1 - hair_target) * (1 - hair_face)).float(), size=(1024, 1024), mode='bicubic')
+     dilation, _ = dilate_erosion.mask(1 - final_mask)
+     mask_save = 1 - dilation[0]
+
+     with tempfile.TemporaryDirectory() as temp_dir:
+         final_image_path = os.path.join(temp_dir, 'final_image.png')
+         face_image_path = os.path.join(temp_dir, 'face_image.png')
+         mask_path = os.path.join(temp_dir, 'mask_save.png')
+         save_image(final_image, final_image_path)
+         save_image(face_image, face_image_path)
+         save_image(mask_save, mask_path)
+
+         out_image_path = os.path.join(temp_dir, 'out_image_path.png')
+         subprocess.run(
+             ["fpie", "-s", face_image_path, "-m", mask_path, "-t", final_image_path, "-o", out_image_path, "-n",
+              str(maxn), "-b", "taichi-gpu", "-g", "max"],
+             check=True
+         )
+
+         return Image.open(out_image_path), Image.open(mask_path)
+
+
+ def list_image_files(directory):
+     image_extensions = ['.jpg', '.jpeg', '.png']
+     image_files = []
+
+     for entry in sorted(os.listdir(directory)):
+         file_path = os.path.join(directory, entry)
+         if os.path.isfile(file_path):
+             file_extension = Path(file_path).suffix.lower()
+             if file_extension in image_extensions:
+                 image_files.append(entry)
+
+     return image_files
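
Not part of the commit: a sketch of DilateErosion on its own. Label 13 is the hair class in the segmentation convention the code above uses; the input shape is an illustrative assumption, and device='cpu' keeps the example portable (the pipeline itself defaults to 'cuda').

import torch
from utils.image_utils import DilateErosion

seg = torch.randint(0, 16, (1, 1, 512, 512)).float()  # NCHW integer label map (illustrative)
de = DilateErosion(dilate_erosion=5, device='cpu')
hair_dilate, hair_erode = de.hair_from_mask(seg)
print(hair_dilate.shape, hair_erode.shape)  # both torch.Size([1, 1, 256, 256]), binary masks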
utils/save_utils.py ADDED
@@ -0,0 +1,38 @@
+ import os
+
+ import numpy as np
+ import torchvision.transforms as T
+ from PIL import Image
+
+ from models.CtrlHair.util.mask_color_util import mask_to_rgb
+
+ toPIL = T.ToPILImage()
+
+
+ def save_gen_image(output_dir, path, name, gen_im):
+     if len(gen_im.shape) == 4:
+         gen_im = gen_im[0]
+     save_im = toPIL(((gen_im + 1) / 2).detach().cpu().clamp(0, 1))
+
+     save_dir = output_dir / path
+     os.makedirs(save_dir, exist_ok=True)
+
+     image_path = save_dir / name
+     save_im.save(image_path)
+
+
+ def save_vis_mask(output_dir, path, name, mask):
+     out_dir = output_dir / path
+     os.makedirs(out_dir, exist_ok=True)
+     out_mask_path = out_dir / name
+
+     rgb_img = Image.fromarray(mask_to_rgb(mask.detach().cpu().squeeze(), 0))
+     rgb_img.save(out_mask_path)
+
+
+ def save_latents(output_dir, path, file_name, **latents):
+     save_dir = output_dir / path
+     os.makedirs(save_dir, exist_ok=True)
+
+     latent_path = save_dir / file_name
+     np.savez(latent_path, **{key: latent.detach().cpu().numpy() for key, latent in latents.items()})
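
Not part of the commit: a usage sketch with illustrative names. output_dir must be a pathlib.Path, since the helpers join paths with the / operator.

from pathlib import Path

import torch

from utils.save_utils import save_gen_image, save_latents

output_dir = Path('output')
gen_im = torch.rand(1, 3, 1024, 1024) * 2 - 1  # generator output in [-1, 1], mapped to [0, 1] before saving
save_gen_image(output_dir, 'blending', 'result.png', gen_im)

latent_w = torch.randn(1, 18, 512)  # illustrative W+ latent
save_latents(output_dir, 'latents', 'result.npz', latent_W=latent_w)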
utils/seed.py ADDED
@@ -0,0 +1,31 @@
+ import functools
+ import random
+
+ import numpy as np
+ import torch
+
+
+ def set_seed(seed):
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.benchmark = False
+     torch.backends.cudnn.deterministic = True
+     np.random.seed(seed)
+     random.seed(seed)
+
+
+ def seed_setter(func):
+     default_seed = 3407
+
+     @functools.wraps(func)
+     def wraps(*args, **kwargs):
+         seed = kwargs.pop('seed', None)
+         if seed is None:
+             seed = default_seed
+         set_seed(seed)
+
+         result = func(*args, **kwargs)
+         return result
+
+     return wraps
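
Not part of the commit: a sketch of how the decorator is meant to be used. The seed keyword is consumed by the wrapper, so the decorated function does not declare it.

import torch

from utils.seed import seed_setter

@seed_setter
def sample_noise(batch_size):
    return torch.randn(batch_size, 512)

a = sample_noise(4)             # seeds with the default 3407 before calling
b = sample_noise(4, seed=3407)  # 'seed' is popped and never reaches sample_noise
assert torch.equal(a, b)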
utils/shape_predictor.py CHANGED
@@ -19,6 +19,7 @@ date: 2020.1.5
 note: code is heavily borrowed from
 https://github.com/NVlabs/ffhq-dataset
 http://dlib.net/face_landmark_detection.py.html
+
 requirements:
 apt install cmake
 conda install Pillow numpy scipy
@@ -82,7 +83,7 @@ def align_face(data, predictor=None, is_filepath=False, return_tensors=True):
     :return: list of PIL Images
     """
     if predictor is None:
-        predictor_path = 'shape_predictor_68_face_landmarks.dat'
+        predictor_path = 'pretrained_models/ShapeAdaptor/shape_predictor_68_face_landmarks.dat'
 
         if not os.path.isfile(predictor_path):
             print("Downloading Shape Predictor")
utils/time.py ADDED
@@ -0,0 +1,36 @@
+ import functools
+ import sys
+ import time
+
+ import numpy as np
+ import torch
+
+
+ def get_time():
+     torch.cuda.current_stream().synchronize()
+     return time.time()
+
+
+ def bench_session(func):
+     times = []
+
+     @functools.wraps(func)
+     def wraps(*args, **kwargs):
+         if kwargs.pop('benchmark', False):
+             nonlocal times
+             start = get_time()
+
+             result = func(*args, **kwargs)
+
+             eval_time = get_time() - start
+             times.append(eval_time)
+
+             print(f'\nexperiment {len(times)} ended in {eval_time:.3f}(s)', file=sys.stderr)
+             print(f'min time: {np.min(times):.3f}(s),'
+                   f' median time: {np.median(times):.3f}(s),'
+                   f' std time: {np.std(times):.3f}(s)', file=sys.stderr)
+             return result
+         else:
+             return func(*args, **kwargs)
+
+     return wraps
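
Not part of the commit: a usage sketch. Timing is opt-in per call via the benchmark keyword, and get_time synchronizes the current CUDA stream, so this example assumes a CUDA device.

import torch

from utils.time import bench_session

@bench_session
def run_model(x):
    return x @ x

x = torch.rand(1024, 1024, device='cuda')
run_model(x)                  # untimed
run_model(x, benchmark=True)  # prints this call's time plus min/median/std so far to stderr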
utils/train.py ADDED
@@ -0,0 +1,167 @@
+ import os
+ import pickle
+ import random
+ import shutil
+ import typing as tp
+
+ import numpy as np
+ import torch
+ import torchvision.transforms as T
+ import wandb
+ from PIL import Image
+ from joblib import Parallel, delayed
+ from torch.utils.data import DataLoader, TensorDataset
+ from torchmetrics.image.fid import FrechetInceptionDistance
+ from tqdm.auto import tqdm
+
+ from models.Encoders import ClipModel
+
+
+ def image_grid(imgs, rows, cols):
+     assert len(imgs) == rows * cols
+
+     w, h = imgs[0].size
+     grid = Image.new('RGB', size=(cols * w, rows * h))
+
+     for i, img in enumerate(imgs):
+         grid.paste(img, box=(i % cols * w, i // cols * h))
+     return grid
+
+
+ class WandbLogger:
+     def __init__(self, name='base-name', project='HairFast'):
+         self.name = name
+         self.project = project
+
+     def start_logging(self):
+         wandb.login(key=os.environ['WANDB_KEY'].strip(), relogin=True)
+         wandb.init(
+             project=self.project,
+             name=self.name
+         )
+         self.wandb = wandb
+         self.run_dir = self.wandb.run.dir
+         self.train_step = 0
+
+     def log(self, scalar_name: str, scalar: tp.Any):
+         self.wandb.log({scalar_name: scalar}, step=self.train_step, commit=False)
+
+     def log_scalars(self, scalars: dict):
+         self.wandb.log(scalars, step=self.train_step, commit=False)
+
+     def next_step(self):
+         self.train_step += 1
+
+     def save(self, file_path, save_online=True):
+         file = os.path.basename(file_path)
+         new_path = os.path.join(self.run_dir, file)
+         shutil.copy2(file_path, new_path)
+         if save_online:
+             self.wandb.save(new_path)
+
+     def __del__(self):
+         self.wandb.finish()
+
+
+ def toggle_grad(model, flag=True):
+     for p in model.parameters():
+         p.requires_grad = flag
+
+
+ class _TFNetworkStub:
+     # stands in for legacy TensorFlow 'Network' objects found in old pickles
+     pass
+
+
+ class _LegacyUnpickler(pickle.Unpickler):
+     def find_class(self, module, name):
+         if module == 'dnnlib.tflib.network' and name == 'Network':
+             return _TFNetworkStub
+         module = module.replace('torch_utils', 'models.stylegan2.torch_utils')
+         module = module.replace('dnnlib', 'models.stylegan2.dnnlib')
+         return super().find_class(module, name)
+
+
+ def seed_everything(seed: int = 1729) -> None:
+     random.seed(seed)
+     os.environ["PYTHONHASHSEED"] = str(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+
+
+ def load_images_to_torch(paths, imgs=None, use_tqdm=True):
+     transform = T.PILToTensor()
+     tensor = []
+     for path in paths:
+         if imgs is None:
+             pbar = sorted(os.listdir(path))
+         else:
+             pbar = imgs
+
+         if use_tqdm:
+             pbar = tqdm(pbar)
+
+         for img_name in pbar:
+             if '.jpg' in img_name or '.png' in img_name:
+                 img_path = os.path.join(path, img_name)
+                 img = Image.open(img_path).resize((299, 299), resample=Image.LANCZOS)
+                 tensor.append(transform(img))
+     try:
+         return torch.stack(tensor)
+     except RuntimeError:  # stack fails when no image matched
+         print(paths, imgs)
+         return torch.tensor([], dtype=torch.uint8)
+
+
+ def parallel_load_images(paths, imgs):
+     assert imgs is not None
+     if not isinstance(paths, list):
+         paths = [paths]
+
+     list_torch_images = Parallel(n_jobs=-1)(delayed(load_images_to_torch)(
+         paths, [i], use_tqdm=False
+     ) for i in tqdm(imgs))
+     return torch.cat(list_torch_images)
+
+
+ def get_fid_calc(instance='fid.pkl', dataset_path='', device=torch.device('cuda')):
+     if os.path.isfile(instance):
+         with open(instance, 'rb') as f:
+             fid = pickle.load(f)
+         fid.to(device).eval()
+     else:
+         fid = FrechetInceptionDistance(feature=ClipModel(), reset_real_features=False, normalize=True)
+         fid.to(device).eval()
+
+         imgs_file = []
+         for file in os.listdir(dataset_path):
+             if 'flip' not in file and os.path.splitext(file)[1] in ['.png', '.jpg']:
+                 imgs_file.append(file)
+
+         tensor_images = parallel_load_images([dataset_path], imgs_file).float().div(255)
+         real_dataloader = DataLoader(TensorDataset(tensor_images), batch_size=128)
+         with torch.inference_mode():
+             for batch in tqdm(real_dataloader):
+                 batch = batch[0].to(device)
+                 fid.update(batch, real=True)
+
+         with open(instance, 'wb') as f:
+             pickle.dump(fid.cpu(), f)
+         fid.to(device).eval()
+
+     @torch.inference_mode()
+     def compute_fid_datasets(images):
+         nonlocal fid, device
+         fid.reset()
+
+         fake_dataloader = DataLoader(TensorDataset(images), batch_size=128)
+         for batch in tqdm(fake_dataloader):
+             batch = batch[0].to(device)
+             fid.update(batch, real=False)
+
+         return fid.compute()
+
+     return compute_fid_datasets
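
Not part of the commit: a sketch of the intended call pattern for get_fid_calc. The paths and tensor shapes are illustrative assumptions; the returned closure scores a fake batch against the cached real statistics.

import torch

from utils.train import get_fid_calc

fid_calc = get_fid_calc(instance='fid.pkl', dataset_path='datasets/FFHQ')  # builds or loads the cached metric

fake_images = torch.rand(256, 3, 299, 299)  # generated images in [0, 1], matching the 299x299 real crops
print(fid_calc(fake_images))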