Spaces:
Paused
Paused
| import os | |
| import pickle | |
| import random | |
| import shutil | |
| import typing as tp | |
| import numpy as np | |
| import torch | |
| import torchvision.transforms as T | |
| import wandb | |
| from PIL import Image | |
| from joblib import Parallel, delayed | |
| from torch.utils.data import DataLoader, TensorDataset | |
| from torchmetrics.image.fid import FrechetInceptionDistance | |
| from tqdm.auto import tqdm | |
| from models.Encoders import ClipModel | |
def image_grid(imgs, rows, cols):
    """Paste ``imgs`` into a single RGB image laid out as ``rows`` x ``cols``.

    Images are placed left-to-right, top-to-bottom. The cell size is taken
    from the first image only; the other images are pasted at cell offsets
    without being resized.

    Args:
        imgs: sequence of PIL images; its length must equal ``rows * cols``.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new ``PIL.Image`` of size ``(cols * w, rows * h)``.

    Raises:
        AssertionError: if ``len(imgs) != rows * cols``.
    """
    assert len(imgs) == rows * cols

    cell_w, cell_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for index, tile in enumerate(imgs):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas
class WandbLogger:
    """Small wrapper around a single Weights & Biases run.

    The run is not opened by the constructor; call :meth:`start_logging`
    before using :meth:`log`, :meth:`log_scalars` or :meth:`save`.
    All values logged between two :meth:`next_step` calls share the same
    ``step`` value (``self.train_step``).
    """

    def __init__(self, name='base-name', project='HairFast'):
        """Store the run name and project; no network access happens here."""
        self.name = name
        self.project = project
        # Filled in by start_logging(). They are initialised here so that
        # next_step() and __del__ work on a logger that was never started.
        self.wandb = None
        self.run_dir = None
        self.train_step = 0

    def start_logging(self):
        """Log in with the key from ``WANDB_KEY`` and open the run.

        Raises:
            KeyError: if the ``WANDB_KEY`` environment variable is not set.
        """
        wandb.login(key=os.environ['WANDB_KEY'].strip(), relogin=True)
        wandb.init(
            project=self.project,
            name=self.name,
        )
        self.wandb = wandb
        self.run_dir = self.wandb.run.dir
        self.train_step = 0

    def log(self, scalar_name: str, scalar: tp.Any):
        """Log one value under ``scalar_name`` at the current step."""
        # commit=False: the call does not finalise the step itself.
        self.wandb.log({scalar_name: scalar}, step=self.train_step, commit=False)

    def log_scalars(self, scalars: dict):
        """Log a ``{name: value}`` mapping at the current step."""
        self.wandb.log(scalars, step=self.train_step, commit=False)

    def next_step(self):
        """Advance the step counter used by subsequent log calls."""
        self.train_step += 1

    def save(self, file_path, save_online=True):
        """Copy ``file_path`` into the run directory and optionally upload it.

        Args:
            file_path: path of the file to store with the run.
            save_online: when true, also pass the copied file to ``wandb.save``.
        """
        file = os.path.basename(file_path)
        new_path = os.path.join(self.run_dir, file)
        shutil.copy2(file_path, new_path)
        if save_online:
            self.wandb.save(new_path)

    def __del__(self):
        # getattr guards against __init__ having failed part-way; the None
        # check covers a logger whose start_logging() was never called.
        active = getattr(self, 'wandb', None)
        if active is not None:
            active.finish()
def toggle_grad(model, flag=True):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``model``.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        flag: value assigned to each parameter's ``requires_grad``.
    """
    parameters = model.parameters()
    for parameter in parameters:
        parameter.requires_grad = flag
class _LegacyUnpickler(pickle.Unpickler):
    """Unpickler that redirects legacy module paths while loading.

    ``dnnlib.tflib.network.Network`` resolves to ``_TFNetworkStub``; any
    other module path has ``torch_utils`` and ``dnnlib`` rewritten to their
    ``models.stylegan2.*`` locations before the normal lookup runs.
    """

    def find_class(self, module, name):
        """Resolve ``module.name`` after applying the legacy remapping."""
        if (module, name) == ('dnnlib.tflib.network', 'Network'):
            return _TFNetworkStub

        # Order matters: 'torch_utils' is rewritten first, then 'dnnlib'.
        rewrites = (
            ('torch_utils', 'models.stylegan2.torch_utils'),
            ('dnnlib', 'models.stylegan2.dnnlib'),
        )
        remapped = module
        for legacy, current in rewrites:
            remapped = remapped.replace(legacy, current)
        return super().find_class(remapped, name)
def seed_everything(seed: int = 1729) -> None:
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Also exports ``PYTHONHASHSEED``. NOTE(review): the interpreter reads that
    variable at start-up, so this presumably only affects child processes.

    Args:
        seed: value passed to every seeding function.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
def load_images_to_torch(paths, imgs=None, use_tqdm=True):
    """Load images from ``paths`` into one stacked ``uint8`` tensor.

    Every file whose name contains ``.jpg`` or ``.png`` is opened, converted
    to RGB, resized to 299x299 with Lanczos resampling and turned into a
    tensor with ``PILToTensor``.

    Args:
        paths: iterable of directory paths.
        imgs: optional iterable of file names to load from each directory;
            when ``None`` the sorted directory listing is used instead.
        use_tqdm: wrap the per-directory iteration in a progress bar.

    Returns:
        The stacked image tensor, or an empty ``uint8`` tensor when no image
        was loaded (in which case ``paths`` and ``imgs`` are printed).
    """
    to_tensor = T.PILToTensor()
    tensors = []
    for path in paths:
        names = sorted(os.listdir(path)) if imgs is None else imgs
        if use_tqdm:
            names = tqdm(names)

        for img_name in names:
            if '.jpg' not in img_name and '.png' not in img_name:
                continue
            img_path = os.path.join(path, img_name)
            # The context manager closes the file handle deterministically.
            with Image.open(img_path) as img:
                # Force three channels so grayscale / palette / RGBA files
                # produce tensors that can be stacked with RGB ones.
                resized = img.convert('RGB').resize((299, 299), resample=Image.LANCZOS)
            tensors.append(to_tensor(resized))

    # torch.stack rejects an empty list; handle that case explicitly instead
    # of hiding every possible error behind a bare ``except``.
    if not tensors:
        print(paths, imgs)
        return torch.tensor([], dtype=torch.uint8)
    return torch.stack(tensors)
def parallel_load_images(paths, imgs):
    """Load the named images with one joblib task per file name.

    Args:
        paths: a directory path or a list of directory paths; a non-list
            value is wrapped into a single-element list.
        imgs: iterable of image file names (must not be ``None``).

    Returns:
        The per-image tensors from ``load_images_to_torch`` concatenated
        with ``torch.cat``.

    Raises:
        AssertionError: if ``imgs`` is ``None``.
    """
    assert imgs is not None

    path_list = paths if isinstance(paths, list) else [paths]
    # One delayed call per image name; n_jobs=-1 is joblib's "all CPUs" setting.
    jobs = (
        delayed(load_images_to_torch)(path_list, [img_name], use_tqdm=False)
        for img_name in tqdm(imgs)
    )
    per_image_tensors = Parallel(n_jobs=-1)(jobs)
    return torch.cat(per_image_tensors)
def get_fid_calc(instance='fid.pkl', dataset_path='', device=torch.device('cuda')):
    """Build (or load from cache) an FID metric and return a scoring closure.

    If ``instance`` exists, the metric is unpickled from it. Otherwise a new
    ``FrechetInceptionDistance`` using ``ClipModel`` features is created, fed
    with every ``.png``/``.jpg`` file in ``dataset_path`` whose name does not
    contain ``flip``, and then pickled to ``instance``.

    Args:
        instance: path of the pickle file holding the cached metric.
        dataset_path: directory with the real images (used only when no
            cache file exists).
        device: device the metric is moved to and batches are sent to.

    Returns:
        ``compute_fid_datasets(images)``: resets the metric, updates it with
        ``images`` as fake samples in batches of 128 and returns
        ``fid.compute()``. Presumably ``images`` must be float tensors in
        [0, 1], matching how the real images are prepared here. TODO confirm.
    """
    if os.path.isfile(instance):
        # NOTE(review): pickle.load executes code from the file; only use
        # cache files from a trusted source.
        with open(instance, 'rb') as cache_file:
            fid = pickle.load(cache_file)
    else:
        fid = FrechetInceptionDistance(feature=ClipModel(), reset_real_features=False, normalize=True)
        fid.to(device).eval()

        real_names = [
            fname
            for fname in os.listdir(dataset_path)
            if 'flip' not in fname and os.path.splitext(fname)[1] in ['.png', '.jpg']
        ]

        # uint8 images are scaled to floats in [0, 1] before being fed in.
        real_images = parallel_load_images([dataset_path], real_names).float().div(255)
        real_loader = DataLoader(TensorDataset(real_images), batch_size=128)
        with torch.inference_mode():
            for real_batch in tqdm(real_loader):
                fid.update(real_batch[0].to(device), real=True)

        # The metric is moved to CPU for pickling and back to ``device`` below.
        with open(instance, 'wb') as cache_file:
            pickle.dump(fid.cpu(), cache_file)
    fid.to(device).eval()

    def compute_fid_datasets(images):
        """Score ``images`` against the accumulated real-image statistics."""
        fid.reset()
        fake_loader = DataLoader(TensorDataset(images), batch_size=128)
        for fake_batch in tqdm(fake_loader):
            fid.update(fake_batch[0].to(device), real=False)
        return fid.compute()

    return compute_fid_datasets