|
|
import os |
|
|
import subprocess
import sys


def _run_setup_command(command):
    """Run one setup command and report, but do not raise on, failure.

    Args:
        command: The command as a list of arguments (no shell is involved).

    Returns:
        The process exit status, or ``None`` if the program could not be
        started at all.
    """
    try:
        status = subprocess.run(command).returncode
    except OSError as exc:
        # e.g. ``git`` is not installed; keep going like the old best-effort
        # ``os.system`` call did, but say why the step was skipped.
        print("warning: could not run {!r}: {}".format(" ".join(command), exc),
              file=sys.stderr)
        return None
    if status != 0:
        print("warning: {!r} exited with status {}".format(" ".join(command), status),
              file=sys.stderr)
    return status


# Runtime bootstrap: the face detector and the 6DRepNet sources are fetched
# when the script starts, before the imports that depend on them.

# Install through the running interpreter so the package lands in the same
# environment that will import ``face_detection`` below.
_run_setup_command([
    sys.executable, "-m", "pip", "install",
    "git+https://github.com/elliottzheng/face-detection.git@master",
])

# Cloning into an existing directory fails, so only clone on the first run.
if not os.path.isdir("6DRepNet"):
    _run_setup_command(["git", "clone", "https://github.com/thohemp/6DRepNet"])

# Make the cloned repository's modules (``model``, ``utils``) importable.
sys.path.append("6DRepNet")
|
|
|
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
import torch |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
from face_detection import RetinaFace |
|
|
from model import SixDRepNet |
|
|
import utils |
|
|
import cv2 |
|
|
from PIL import Image |
|
|
|
|
|
# Fetch the pretrained 6DRepNet checkpoint from the Hugging Face Hub
# (hf_hub_download returns the local path of the downloaded file).
snapshot_path = hf_hub_download(
    repo_id="osanseviero/6DRepNet_300W_LP_AFLW2000",
    filename="model.pth",
)

# Use GPU 0 when CUDA is available; otherwise stay on the CPU instead of
# crashing at start-up on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = SixDRepNet(backbone_name='RepVGG-B1g2',
                   backbone_file='',
                   deploy=True,
                   pretrained=False)

# RetinaFace takes a GPU index; the face-detection package uses -1 for CPU.
detector = RetinaFace(0 if device.type == "cuda" else -1)

# Load the weights on the CPU first; the model is moved to ``device`` below.
saved_state_dict = torch.load(snapshot_path, map_location='cpu')

# The checkpoint is either a bare state dict or a dict that wraps it under
# the 'model_state_dict' key.
if 'model_state_dict' in saved_state_dict:
    model.load_state_dict(saved_state_dict['model_state_dict'])
else:
    model.load_state_dict(saved_state_dict)

model.to(device)
model.eval()
|
|
|
|
|
def predict(frame):
    """Draw the head-pose cube for the first confident face found in ``frame``.

    Args:
        frame: Image array from the Gradio webcam input, indexed as
            ``frame[y, x]`` with a trailing channel axis (the face crop is
            transposed with ``(2, 0, 1)``). May be ``None`` when no image is
            available.

    Returns:
        ``None`` if ``frame`` is ``None``. Otherwise the result of
        ``utils.plot_pose_cube`` for the first detection scoring at least
        0.95, or the unmodified ``frame`` when no detection qualifies.
    """
    if frame is None:
        return None

    # Run on whatever device the model lives on rather than assuming GPU 0.
    # NOTE(review): the 6DRepNet ``utils`` helpers may still assume CUDA —
    # confirm before relying on CPU-only execution.
    device = next(model.parameters()).device

    for box, _landmarks, score in detector(frame):
        # Ignore low-confidence detections.
        if score < .95:
            continue

        x_min = int(box[0])
        y_min = int(box[1])
        x_max = int(box[2])
        y_max = int(box[3])
        bbox_width = abs(x_max - x_min)
        bbox_height = abs(y_max - y_min)

        # Enlarge the box by 20% on each side before cropping. The x margin is
        # taken from the height and the y margin from the width; this is kept
        # exactly as it was. NOTE(review): confirm the swap is intentional.
        x_min = max(0, x_min - int(0.2 * bbox_height))
        y_min = max(0, y_min - int(0.2 * bbox_width))
        x_max = x_max + int(0.2 * bbox_height)
        y_max = y_max + int(0.2 * bbox_width)

        face = frame[y_min:y_max, x_min:x_max]
        if face.size == 0:
            # A degenerate box gives an empty crop, which cv2.resize rejects.
            continue

        # Scale pixel values by 1/255 and reorder to channel-first with a
        # leading batch axis.
        # NOTE(review): 244 may be a typo for 224 — confirm against the
        # preprocessing the checkpoint was trained with.
        face = cv2.resize(face, (244, 244)) / 255.0
        batch = torch.from_numpy(face.transpose(2, 0, 1)).float()
        batch = batch.unsqueeze(0).to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            R_pred = model(batch)
            euler = utils.compute_euler_angles_from_rotation_matrices(
                R_pred) * 180 / np.pi

        # Columns 0, 1 and 2 are used as pitch, yaw and roll, in degrees.
        p_pred_deg = euler[:, 0].cpu()
        y_pred_deg = euler[:, 1].cpu()
        r_pred_deg = euler[:, 2].cpu()

        # Only the first confident face is drawn; the cube is centred on the
        # enlarged box and sized by the original box width.
        return utils.plot_pose_cube(
            frame, y_pred_deg, p_pred_deg, r_pred_deg,
            x_min + int(.5 * (x_max - x_min)),
            y_min + int(.5 * (y_max - y_min)),
            size=bbox_width)

    # No confident face: show the input unchanged instead of a blank output.
    return frame
|
|
|
|
|
# Text shown on the demo page: heading, short usage hint, and footer links.
title = "6D Rotation Representation for Unconstrained Head Pose Estimation"

description = (
    "Gradio demo for 6DRepNet. "
    "To use it, simply click the camera picture. "
    "Read more at the links below."
)

article = (
    "<div style='text-align: center;'>"
    "<a href='https://github.com/thohemp/6DRepNet' target='_blank'>Github Repo</a>"
    " | "
    "<a href='https://arxiv.org/abs/2202.12555' target='_blank'>Paper</a>"
    "</div>"
)

# Custom CSS that mirrors both the input preview and the output image
# horizontally (scaleX(-1)).
image_flip_css = """
.input-image .image-preview img{
    -webkit-transform: scaleX(-1);
    transform: scaleX(-1) !important;
}

.output-image img {
    -webkit-transform: scaleX(-1);
    transform: scaleX(-1) !important;
}
"""
|
|
|
|
|
# Live webcam demo: each captured frame is passed to ``predict`` and the
# returned image is displayed as the output.
iface = gr.Interface(
    predict,
    gr.inputs.Image(source="webcam", label="Input Image"),
    'image',
    css=image_flip_css,
    article=article,
    description=description,
    title=title,
    live=True,
)

# Start the Gradio app.
iface.launch()