Spaces:
Build error
Build error
| from numpy.core.numeric import require | |
| from numpy.lib.function_base import quantile | |
| import torch | |
| import torch.nn.functional as F | |
| import copy | |
| import numpy as np | |
| import os | |
| import sys | |
| import cv2 | |
| import argparse | |
| import tqdm | |
| from utils.commons.multiprocess_utils import multiprocess_run_tqdm | |
| from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker | |
| from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel | |
| import pickle | |
# Module-level face model used by fit_3dmm_for_a_image. It is built at import
# time in 'mediapipe' keypoint mode and moved to "cuda"; the __main__ block
# rebinds it with the command-line keypoint mode and device.
face_model = ParametricFaceModel(
    bfm_folder='deep_3drecon/BFM',
    camera_distance=10,
    focal=1015,
    keypoint_mode='mediapipe',
)
face_model.to("cuda")

# 68 MediaPipe face-mesh indices, one per classic 68-point landmark
# (presumably in the standard 68-point order: jaw, brows, nose, eyes, mouth).
index_lm68_from_lm468 = [
    127, 234, 93, 132, 58, 136, 150, 176, 152, 400, 379, 365, 288, 361, 323, 454, 356,
    70, 63, 105, 66, 107, 336, 296, 334, 293, 300,
    168, 197, 5, 4, 75, 97, 2, 326, 305,
    33, 160, 158, 133, 153, 144, 362, 385, 387, 263, 373, 380,
    61, 40, 37, 0, 267, 270, 291, 321, 314, 17, 84, 91, 78, 81, 13, 311, 308, 402, 14, 178,
]

# Absolute directory that contains this script.
dir_path = os.path.dirname(os.path.realpath(__file__))

# Weights of the L2 regularisers on the identity / expression coefficients.
LAMBDA_REG_ID = 0.3
LAMBDA_REG_EXP = 0.05
def save_file(name, content):
    """Pickle ``content`` into the file at path ``name``, replacing any existing file.

    Counterpart of ``load_file``.
    """
    with open(name, mode="wb") as out_fh:
        pickle.dump(content, out_fh)
def load_file(name):
    """Unpickle and return the object stored in the file at path ``name``.

    NOTE: unpickling can execute arbitrary code; only load files that this
    pipeline wrote itself (e.g. via ``save_file``).
    """
    with open(name, mode="rb") as in_fh:
        return pickle.load(in_fh)
def cal_lan_loss_mp(proj_lan, gt_lan):
    """Weighted mean squared landmark error for MediaPipe face-mesh landmarks.

    Args:
        proj_lan: projected landmarks indexed as [B, N, 2]. MediaPipe indices up
            to 466 are addressed, so N must be at least 467 (fit_3dmm_for_a_image
            passes 468).
        gt_lan: target landmarks, broadcastable to ``proj_lan``.

    Returns:
        Scalar tensor holding the mean of the weighted squared coordinate errors.
        Eye-contour points weigh 5 and inner/outer-lip points weigh 2. The points
        in ``unmatch_mask`` weigh 0, so they are excluded (presumably because they
        do not correspond well to the face model - confirm). All other points
        weigh 1.
    """
    sq_err = (proj_lan - gt_lan) ** 2
    # loss = (proj_lan - gt_lan).abs()
    left_eye = [33, 246, 161, 160, 159, 158, 157, 173, 133, 155, 154, 153, 145, 144, 163, 7]
    right_eye = [263, 466, 388, 387, 386, 385, 384, 398, 362, 382, 381, 380, 374, 373, 390, 249]
    inner_lip = [78, 191, 80, 81, 82, 13, 312, 311, 310, 415, 308, 324, 318, 402, 317, 14, 87, 178, 88, 95]
    outer_lip = [61, 185, 40, 39, 37, 0, 267, 269, 270, 409, 291, 375, 321, 405, 314, 17, 84, 181, 91, 146]
    unmatch_mask = [93, 127, 132, 234, 323, 356, 361, 454]

    weights = torch.ones_like(sq_err)
    weight_table = (
        (left_eye + right_eye, 5),
        (inner_lip, 2),
        (outer_lip, 2),
        (unmatch_mask, 0),
    )
    for point_ids, weight in weight_table:
        weights[:, point_ids] = weight
    return (sq_err * weights).mean()
def cal_lan_loss(proj_lan, gt_lan):
    """Weighted mean squared landmark error for the 68-point landmark layout.

    Uses the landmark weights from deep3drecon (see
    deep_3drecon/deep_3drecon_models/losses.py). Eye points 36-47, nose
    points 28-30 and the last 8 points (inner lip) count 3x; all other points
    count 1x.

    Args:
        proj_lan: projected landmarks indexed as [B, 68, 2].
        gt_lan: target landmarks, broadcastable to ``proj_lan``.

    Returns:
        Scalar tensor holding the mean of the weighted squared coordinate errors.
    """
    loss = (proj_lan - gt_lan) ** 2
    weights = torch.ones_like(loss)
    weights[:, 36:48, :] = 3  # eye, 12 points
    weights[:, -8:, :] = 3  # inner lip, 8 points
    weights[:, 28:31, :] = 3  # nose, 3 points
    return torch.mean(loss * weights)
def set_requires_grad(tensor_list):
    """Turn on gradient tracking for every tensor in ``tensor_list`` (in place).

    The ``requires_grad`` attribute is assigned directly, so each tensor is
    expected to be a leaf. Returns None.
    """
    for leaf in tensor_list:
        leaf.requires_grad = True
def read_video_to_frames(img_name):
    """Decode every frame of a video file into one RGB array.

    Args:
        img_name: path to a video readable by ``cv2.VideoCapture``.

    Returns:
        np.ndarray that stacks all decoded frames along a new leading axis.
        Each frame is converted from OpenCV's BGR channel order to RGB.

    Raises:
        ValueError: if no frame could be decoded (missing or unreadable file).
    """
    frames = []
    cap = cv2.VideoCapture(img_name)
    try:
        while cap.isOpened():
            ret, frame_bgr = cap.read()
            if not ret or frame_bgr is None:
                break
            frames.append(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
    finally:
        # Always free the decoder handle, even if a conversion raises.
        cap.release()
    if not frames:
        raise ValueError(f"no frames could be read from {img_name}")
    return np.stack(frames)
def fit_3dmm_for_a_image(img_name, debug=False, keypoint_mode='mediapipe', device="cuda:0", save=True):
    """Fit BFM 3DMM coefficients to the 2D face landmarks of a single image.

    Landmarks are read from the cached ``lms_2d`` file when one exists;
    otherwise they are extracted with MediaPipe. The module-level
    ``face_model`` is then fitted in three stages:

    1. pose only (euler angles + translation);
    2. identity, expression and pose jointly, with L2 regularisation on id/exp;
    3. a per-batch fine-tuning pass at a lower learning rate, written back
       into the full parameter tensors.

    Args:
        img_name: path to a square image. The ``/images_512/`` and ``.png``
            parts of the path are rewritten to build the landmark, output and
            debug paths.
        debug: if True, also write an image with the projected landmarks
            (filled (0, 0, 255) dots) and the detected landmarks
            ((255, 0, 0) circles) drawn on the RGB image.
        keypoint_mode: 'mediapipe' selects ``cal_lan_loss_mp`` and the
            ``coeff_fit_mp`` output directory. Any other value selects
            ``cal_lan_loss`` and ``coeff_fit_lm68``.
        device: torch device for the landmark tensor and the fitted parameters.
        save: if True, write the coefficient dict with ``np.save``.

    Returns:
        dict of numpy arrays with shapes 'id' (1, 80), 'exp' (1, 64),
        'euler' (1, 3) and 'trans' (1, 3). Returns None when the image or its
        landmarks cannot be obtained.

    Raises:
        ValueError: if the image is not square.
    """
    img = cv2.imread(img_name)
    if img is None:
        print(f"failed to read image {img_name}, skipping...")
        return None
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_h, img_w = img.shape[0], img.shape[1]
    # A single scale factor is applied to both axes below, so the image must be square.
    if img_h != img_w:
        raise ValueError(f"expected a square image, got {img_w}x{img_h}: {img_name}")
    num_frames = 1

    # 2D landmarks: use the cached lms_2d file if present, otherwise run MediaPipe.
    lm_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png", "_lms.npy")
    if lm_name.endswith('_lms.npy') and os.path.exists(lm_name):
        lms = np.load(lm_name)
    else:
        try:
            landmarker = MediapipeLandmarker()
            lms = landmarker.extract_lm478_from_img_name(img_name)
        except Exception as e:  # best effort: report and skip this image
            print(e)
            return None
    if lms is None:
        print("get None lms_2d, please check whether each frame has one head, exiting...")
        return None
    # Keep the first 468 points (the extractor returns 478; the extra 10 are presumably iris points).
    lms = lms[:468].reshape([468, 2])
    lms = torch.FloatTensor(lms).to(device=device)
    lms[..., 1] = img_h - lms[..., 1]  # flip the height axis

    # NOTE(review): lms always holds the 468 MediaPipe points. For a
    # non-'mediapipe' keypoint_mode, the global face_model must emit landmarks
    # whose count matches these, or the loss shapes will not line up - confirm.
    if keypoint_mode == 'mediapipe':
        cal_lan_loss_fn = cal_lan_loss_mp
        out_name = img_name.replace("/images_512/", "/coeff_fit_mp/").replace(".png", "_coeff_fit_mp.npy")
    else:
        cal_lan_loss_fn = cal_lan_loss
        out_name = img_name.replace("/images_512/", "/coeff_fit_lm68/").replace(".png", "_coeff_fit_lm68.npy")
    if save:
        out_dir = os.path.dirname(out_name)
        if out_dir:  # a bare file name has no directory to create
            os.makedirs(out_dir, exist_ok=True)

    id_dim, exp_dim = 80, 64
    # Rescale landmarks from image pixels to the face model's (2 * center)-pixel frame.
    img_scale_factor = img_h / (face_model.center * 2)
    lms /= img_scale_factor
    gt_lms = lms.detach()

    id_para = lms.new_zeros((num_frames, id_dim), requires_grad=True)
    exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
    euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
    trans = lms.new_zeros((num_frames, 3), requires_grad=True)
    optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=.1)
    optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=.1)

    # Stage 1: fit the pose (euler angles + translation) only.
    for _ in range(200):
        proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
        loss = cal_lan_loss_fn(proj_geo[:, :, :2], gt_lms)
        optimizer_frame.zero_grad()
        loss.backward()
        optimizer_frame.step()

    # Stage 2: roughly fit identity, expression and pose jointly.
    for _ in range(200):
        proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
        loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], gt_lms)
        loss_regid = torch.mean(id_para * id_para)  # L2 regularisation
        loss_regexp = torch.mean(exp_para * exp_para)
        loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP
        optimizer_idexp.zero_grad()
        optimizer_frame.zero_grad()
        loss.backward()
        optimizer_idexp.step()
        optimizer_frame.step()

    # Stage 3: fine-tune all parameters batch by batch, starting from the stage-2 result.
    batch_size = 1
    num_batches = (num_frames + batch_size - 1) // batch_size
    for i in range(num_batches):
        # The final batch is shifted back so that it always holds batch_size frames.
        start_n = min(i * batch_size, max(num_frames - batch_size, 0))
        sel_ids = torch.arange(start_n, min(start_n + batch_size, num_frames), device=lms.device)
        sel_id_para = id_para[sel_ids].detach().clone().requires_grad_(True)
        sel_exp_para = exp_para[sel_ids].detach().clone().requires_grad_(True)
        sel_euler_angle = euler_angle[sel_ids].detach().clone().requires_grad_(True)
        sel_trans = trans[sel_ids].detach().clone().requires_grad_(True)
        optimizer_cur_batch = torch.optim.Adam(
            [sel_id_para, sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
        for _ in range(50):
            proj_geo = face_model.compute_for_landmark_fit(
                sel_id_para, sel_exp_para, sel_euler_angle, sel_trans)
            loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], gt_lms.unsqueeze(0))
            loss_regid = torch.mean(sel_id_para * sel_id_para)  # L2 regularisation
            loss_regexp = torch.mean(sel_exp_para * sel_exp_para)
            loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP
            optimizer_cur_batch.zero_grad()
            loss.backward()
            optimizer_cur_batch.step()
        print(f"batch {i} | loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f}")
        # Write the refined values back in place. `param[ids]` with index
        # tensors returns a copy, so assigning `.data` on it would be lost;
        # index assignment through __setitem__ updates the real tensor.
        with torch.no_grad():
            id_para[sel_ids] = sel_id_para
            exp_para[sel_ids] = sel_exp_para
            euler_angle[sel_ids] = sel_euler_angle
            trans[sel_ids] = sel_trans

    coeff_dict = {'id': id_para.detach().cpu().numpy(), 'exp': exp_para.detach().cpu().numpy(),
                  'euler': euler_angle.detach().cpu().numpy(), 'trans': trans.detach().cpu().numpy()}
    if save:
        np.save(out_name, coeff_dict, allow_pickle=True)

    if debug:
        import imageio
        debug_name = img_name.replace("/images_512/", "/coeff_fit_mp_debug/").replace(".png", "_debug.png").replace(".jpg", "_debug.jpg")
        debug_dir = os.path.dirname(debug_name)
        if debug_dir:
            os.makedirs(debug_dir, exist_ok=True)
        with torch.no_grad():
            proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
        # Map both landmark sets back to image pixels and undo the height flip.
        proj_px = proj_geo[:, :, :2].cpu().numpy() * img_scale_factor
        proj_px[..., 1] = img_h - proj_px[..., 1]
        gt_px = (gt_lms * img_scale_factor).cpu().numpy()
        gt_px[..., 1] = img_h - gt_px[..., 1]
        # OpenCV expects circle centers as plain int tuples.
        for x, y in proj_px.astype(int).reshape([-1, 2]):
            img = cv2.circle(img, (int(x), int(y)), 1, (0, 0, 255), thickness=-1)
        for x, y in gt_px.astype(int).reshape([-1, 2]):
            img = cv2.circle(img, (int(x), int(y)), 2, (255, 0, 0), thickness=1)
        imageio.imwrite(debug_name, img)
        print(f"debug img saved at {debug_name}")
    return coeff_dict
def out_exist_job(vid_name):
    """Return ``vid_name`` if its mediapipe coefficient file is missing, else None.

    Used as a filter job: images that already have a ``coeff_fit_mp``
    output (path derived from the ``/images_512/`` + ``.png`` naming) map to
    None so the caller can drop them from the to-do list.
    """
    coeff_path = vid_name.replace("/images_512/", "/coeff_fit_mp/").replace(".png", "_coeff_fit_mp.npy")
    return None if os.path.exists(coeff_path) else vid_name
def get_todo_img_names(img_names):
    """Return the subset of ``img_names`` that still needs fitting.

    Runs ``out_exist_job`` over the names with 16 workers via
    ``multiprocess_run_tqdm`` and keeps every non-None result, in the order
    the helper yields them.
    """
    return [
        name
        for _, name in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=16)
        if name is not None
    ]
if __name__ == '__main__':
    import argparse, glob, tqdm
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512')
    parser.add_argument("--ds_name", default='FFHQ')
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--process_id", default=0, type=int)
    parser.add_argument("--total_process", default=1, type=int)
    parser.add_argument("--keypoint_mode", default='mediapipe', type=str)
    parser.add_argument("--debug", action='store_true')
    parser.add_argument("--reset", action='store_true')
    parser.add_argument("--device", default="cuda:0", type=str)
    parser.add_argument("--output_log", action='store_true')
    parser.add_argument("--load_names", action="store_true")
    args = parser.parse_args()
    img_dir = args.img_dir
    load_names = args.load_names
    print(f"args {args}")

    # Collect image paths: a single image, a cached name list, or a per-dataset glob.
    if args.ds_name == 'single_img':
        img_names = [img_dir]
    else:
        img_names_path = os.path.join(img_dir, "img_dir.pkl")
        if os.path.exists(img_names_path) and load_names:
            print(f"loading vid names from {img_names_path}")
            img_names = load_file(img_names_path)
        else:
            ds_patterns = {
                'FFHQ_MV': ["ref_imgs/*.png", "mv_imgs/*.png"],
                'FFHQ': ["*.png"],
                'PanoHeadGen': ["ref/*/*.png"],
            }
            if args.ds_name not in ds_patterns:
                raise ValueError(f"unsupported --ds_name: {args.ds_name}")
            img_names = sorted(
                name
                for pattern in ds_patterns[args.ds_name]
                for name in glob.glob(os.path.join(img_dir, pattern))
            )
            print(f"saving image names to {img_names_path}")
            save_file(img_names_path, img_names)

    # Rebind the module-level model used by fit_3dmm_for_a_image.
    face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
                                     camera_distance=10, focal=1015, keypoint_mode=args.keypoint_mode)
    face_model.to(torch.device(args.device))

    # Shard the image list across processes; the last process takes the remainder.
    process_id = args.process_id
    total_process = args.total_process
    if total_process > 1:
        if not 0 <= process_id < total_process:
            raise ValueError(f"--process_id must be in [0, {total_process - 1}], got {process_id}")
        num_samples_per_process = len(img_names) // total_process
        start = process_id * num_samples_per_process
        end = None if process_id == total_process - 1 else start + num_samples_per_process
        img_names = img_names[start:end]
    print(f"image names number (before fileter): {len(img_names)}")
    # NOTE: the skip-if-done filter only looks for coeff_fit_mp outputs.
    if not args.reset:
        img_names = get_todo_img_names(img_names)
    print(f"image names number (after fileter): {len(img_names)}")

    log_every = max(int(len(img_names) * 0.003), 1)
    for i, img_name in enumerate(tqdm.tqdm(img_names, desc=f"process {process_id}: fitting 3dmm ...")):
        try:
            # Use the same keypoint mode the face model was built with.
            fit_3dmm_for_a_image(img_name, args.debug, keypoint_mode=args.keypoint_mode, device=args.device)
        except Exception as e:  # keep going over the dataset; report the failing image
            print(img_name, e)
        if args.output_log and i % log_every == 0:
            print(f"process {process_id}: {i + 1} / {len(img_names)} done")
            sys.stdout.flush()
            sys.stderr.flush()
    print(f"process {process_id}: fitting 3dmm all done")