|
|
import os |
|
|
import sys |
|
|
import time |
|
|
import numpy as np |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
from torch.nn import functional as F |
|
|
from PIL import Image |
|
|
|
|
|
import frame_utils |
|
|
import vis |
|
|
|
|
|
|
|
|
def get_pos(H,W,disp=None,slant="slant",slant_norm=False,patch_size=None,device=None):
    """Build the per-pixel coordinate tensor (u, v[, disparity]).

    Args:
        H, W: height and width of the image grid.
        disp: optional disparity tensor, B,1,H,W. When given it is appended
            as a third channel and the coordinate grids are repeated along
            the batch axis to match it.
        slant: "slant" for image-wide coordinates (0..W-1, 0..H-1), or
            "slant_local" for coordinates that restart in every
            patch_size x patch_size patch, centred on the patch.
        slant_norm: for "slant", divide u by W and v by H; for
            "slant_local", use offsets spanning (-1, 1) with step
            2/patch_size instead of pixel offsets.
        patch_size: patch edge length, required for "slant_local".
        device: device for the coordinate grids. Defaults to ``disp.device``
            when ``disp`` is given, so the concatenation cannot mix devices.

    Returns:
        float32 tensor of shape B,3,H,W when ``disp`` is given,
        otherwise 1,2,H,W.

    Raises:
        ValueError: unknown ``slant`` or missing ``patch_size``.
        AssertionError: H or W is not a multiple of ``patch_size``.
    """
    if device is None and disp is not None:
        device = disp.device

    if slant == "slant":
        u = torch.arange(W, device=device)
        v = torch.arange(H, device=device)
        if slant_norm:
            u = u / W
            v = v / H
    elif slant == "slant_local":
        if patch_size is None:
            raise ValueError('patch_size is required when slant="slant_local"')
        assert H % patch_size == 0 and W % patch_size == 0, \
            "H and W must be multiples of patch_size"
        if slant_norm:
            # patch-centred offsets strictly inside (-1, 1)
            offsets = torch.arange(-1 + 1 / patch_size, 1,
                                   step=2 / patch_size, device=device)
        else:
            # patch-centred pixel offsets: -p/2+0.5, ..., p/2-0.5
            offsets = torch.arange(-patch_size / 2 + 0.5, patch_size / 2 - 0.5 + 1,
                                   step=1, device=device)
        # the same offsets repeat for every patch along each axis
        u = offsets.tile((W // patch_size,))
        v = offsets.tile((H // patch_size,))
    else:
        raise ValueError(
            'slant must be "slant" or "slant_local", got {!r}'.format(slant))

    # indexing="xy" yields H x W grids: u varies along columns, v along rows
    grid_u, grid_v = torch.meshgrid(u, v, indexing="xy")
    channels = [grid_u.reshape(1, 1, H, W).float(),
                grid_v.reshape(1, 1, H, W).float()]

    if disp is not None:
        # torch.cat needs matching batch sizes; expand is a view, not a copy
        batch = disp.shape[0]
        channels = [c.expand(batch, -1, -1, -1) for c in channels]
        channels.append(disp)

    return torch.cat(channels, dim=1).float()
|
|
|
|
|
def convert2patch(data, patch_size, div_last=False):
    """Cut a feature map into non-overlapping square patches.

    data: B,C,H,W; H and W must both be multiples of patch_size.
    div_last: when True, the last channel of the result is divided by
        patch_size (in place on the freshly unfolded tensor; ``data`` itself
        is left untouched).
    return: B,C,patch_size*patch_size,H//patch_size,W//patch_size
    """
    _, channels, height, width = data.shape
    assert height % patch_size == 0 and width % patch_size == 0

    rows = height // patch_size
    cols = width // patch_size

    # stride == kernel size, so every pixel lands in exactly one patch
    columns = F.unfold(data, kernel_size=patch_size, dilation=1,
                       padding=0, stride=patch_size)
    patches = columns.view((-1, channels, patch_size * patch_size, rows, cols))

    if div_last:
        patches[:, -1].div_(patch_size)
    return patches
|
|
|
|
|
def intra_dist4patch(patch_data, patch_size):
    """Pairwise Euclidean distance between the pixels of each patch.

    patch_data: B,C,patch_size*patch_size,H,W
    return: B,patch_size*patch_size,patch_size*patch_size,H,W, where entry
        [b,i,j,h,w] is the distance (over the channel axis) between pixel i
        and pixel j of patch (h,w).
    """
    n_pix = patch_size * patch_size
    # broadcast the pixel axis against itself: rows index i, columns index j
    src = patch_data.unsqueeze(3).expand(-1, -1, -1, n_pix, -1, -1)
    tar = patch_data.unsqueeze(2).expand(-1, -1, n_pix, -1, -1, -1)
    diff = src - tar
    return diff.square().sum(dim=1).sqrt()
|
|
|
|
|
def get_adjacent_matrix(dist,patch_size,thold=3):
    """Count, for every pixel of a patch, how many pixels it is connected to.

    dist: B,L,L,H,W pairwise distances with L = patch_size*patch_size.
    thold: two pixels are directly linked when their distance is below it.
    return: B,L,H,W integer tensor; entry [b,i,h,w] is the number of pixels
        reachable from pixel i of patch (h,w) through chains of direct links.
    """
    # put the two pixel axes last so matmul composes the links per patch
    reach = (dist < thold).float().permute(0, 3, 4, 1, 2)

    # each squaring doubles the chain length covered, so ceil(log2(L))
    # rounds are run for a patch of L pixels
    n_rounds = int(np.ceil(np.log2(patch_size * patch_size)))
    for _ in range(n_rounds):
        reach = (torch.matmul(reach, reach) > 0).float()

    counts = (reach > 0).sum(dim=-1)
    return counts.permute(0, 3, 1, 2).contiguous()
|
|
|
|
|
def reduce_noise(patch_coord, mask):
    """Replace the masked-out pixels of each patch by the mean of the kept ones.

    patch_coord: B,C,patch_size*patch_size,H,W;
    mask: B,patch_size*patch_size,H,W; True marks the pixels to keep.
    return: B,C,patch_size*patch_size,H,W; kept pixels are unchanged, every
        other pixel carries the per-patch mean of the kept pixels.
    """
    keep = mask.bool().unsqueeze(1)                      # B,1,L,H,W
    kept = patch_coord * keep

    # keepdim so the count lines up with the channel axis (B,1,H,W); without
    # it the batch axis is broadcast against the channels, which breaks for B>1
    n_kept = keep.sum(dim=2)
    center_coord = kept.sum(dim=2) / n_kept              # B,C,H,W

    chs_coord = kept + (~keep) * center_coord.unsqueeze(2)
    return chs_coord
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_plane_lstsq(chs_coord, slant, patch_coord=None):
    """Least-squares fit of a quadratic disparity surface to every patch.

    chs_coord: B,C,patch_size*patch_size,H,W; channel 0 is u, channel 1 is v
        and channel 2 is the disparity sample.
    slant, patch_coord: accepted for interface compatibility; not used here.
    return:
    cab: B,6,H,W; (disparity, a, b, g_uu, g_vv, g_uv), the coefficients of
        d = c + a*u + b*v + g_uu*u^2/2 + g_vv*v^2/2 + g_uv*u*v
    """
    _, _, _, height, width = chs_coord.shape

    # B,C,L,H,W -> B,C,H*W,L: one row of L samples per patch
    samples = chs_coord.flatten(-2, -1).transpose(-2, -1)
    u = samples[:, 0]
    v = samples[:, 1]
    d = samples[:, 2]

    # design matrix, B,H*W,L,6, one column per coefficient
    basis = (torch.ones_like(u), u, v, u * u / 2, v * v / 2, u * v)
    design = torch.stack(basis, dim=3)

    coeffs = torch.linalg.lstsq(design, d).solution      # B,H*W,6
    return coeffs.transpose(1, 2).view((-1, 6, height, width))
|
|
|
|
|
def extract_plane(disp,slant="slant", slant_norm=False, patch_size=4,thold=3,vis=False):
    """Fit one quadratic disparity surface per patch of a disparity map.

    disp: B,1,H,W;
    slant, slant_norm: coordinate convention, forwarded to get_pos.
    patch_size: edge length of the square patches.
    thold: distance threshold below which two pixels of a patch are linked.
    vis: when True, the pixel-selection mask is returned as well. (Inside
        this function the name shadows the module-level ``vis`` import.)
    return:
    cab: B,6,H,W; (disparity, a, b, g_uu, g_vv, g_uv)
    mask (only if vis): B,patch_size*patch_size,H//patch_size,W//patch_size,
        True for the pixels that were used as-is in the fit.
    """
    # group the pixels of every patch by similarity of their disparity
    patch_pos = convert2patch(disp, patch_size=patch_size)
    dist = intra_dist4patch(patch_pos, patch_size=patch_size)
    connect = get_adjacent_matrix(dist, patch_size=patch_size, thold=thold)

    # keep only the pixels that belong to the largest group of their patch
    mask = connect - torch.amax(connect, dim=1, keepdim=True)
    mask = mask >= -0.0001

    B, _, H, W = disp.shape
    # build the grids on disp's device so they can be concatenated with it
    coord = get_pos(H, W, disp=disp, slant=slant, slant_norm=slant_norm,
                    patch_size=patch_size, device=disp.device)
    # div_last rescales the disparity channel by 1/patch_size
    patch_coord = convert2patch(coord, patch_size=patch_size, div_last=True)

    # rejected pixels are replaced by the mean of the kept ones
    chs_coord = reduce_noise(patch_coord, mask)

    cab = get_plane_lstsq(chs_coord, slant, patch_coord)

    if vis:
        return cab, mask
    return cab
|
|
|
|
|
def predict_disp(cab, uv_coord, patch_size, mul_last=False):
    """Evaluate the fitted quadratic surfaces at the given coordinates.

    cab: B,6,H,W; (disparity, a, b, g_uu, g_vv, g_uv)
    uv_coord: B,2,patch_size*patch_size,H,W;
    mul_last: when True, the result is multiplied by patch_size.
    return: B,patch_size*patch_size,H,W
    """
    u = uv_coord[:, 0]
    v = uv_coord[:, 1]

    # one basis term per coefficient, stacked on the channel axis
    terms = [torch.ones_like(u), u, v, u * u / 2, v * v / 2, u * v]
    basis = torch.stack(terms, dim=1)                    # B,6,L,H,W

    # broadcast the coefficients over the pixel axis and sum the terms
    weighted = basis * cab.unsqueeze(dim=2)
    d_coord = weighted.sum(dim=1)

    if mul_last:
        d_coord.mul_(patch_size)
    return d_coord
|
|
|
|
|
def _min_max_normalize(values):
    """Rescale to [0, 1]; a constant map becomes all zeros instead of NaN."""
    low = values.min()
    span = values.max() - low
    if span == 0:
        return torch.zeros_like(values)
    return (values - low) / span


def compute_curvature(cab):
    """Curvature maps from the second-order coefficients of the fit.

    cab: B,6,H,W; (disparity, a, b, g_uu, g_vv, g_uv)
    Only the first batch element (cab[0]) is used.

    return:
    Gaussian_cur: H,W; |product of the Hessian eigenvalues|, values above
        0.03 zeroed, then min-max normalised to [0, 1].
    mean_cur: H,W; |mean of the Hessian eigenvalues|, values above 0.01
        zeroed, then min-max normalised to [0, 1].
    """
    B, C, H, W = cab.shape
    g_uu, g_vv, g_uv = cab[0, -3], cab[0, -2], cab[0, -1]

    # per-pixel symmetric 2x2 Hessian [[g_uu, g_uv], [g_uv, g_vv]]
    hessian = torch.stack([g_uu, g_uv, g_uv, g_vv], dim=-1).reshape(H, W, 2, 2)
    eigen_val, _ = torch.linalg.eigh(hessian)

    Gaussian_cur = (eigen_val[..., 0] * eigen_val[..., 1]).abs()
    mean_cur = ((eigen_val[..., 0] + eigen_val[..., 1]) / 2).abs()

    # suppress large responses before normalising
    Gaussian_cur[Gaussian_cur > 0.03] = 0
    mean_cur[mean_cur > 0.01] = 0

    # guarded normalisation: flat input would otherwise give 0/0 = NaN
    Gaussian_cur = _min_max_normalize(Gaussian_cur)
    mean_cur = _min_max_normalize(mean_cur)

    return Gaussian_cur, mean_cur
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo: fit per-patch surfaces on one sample, rebuild the disparity from
    # the fit and save a figure comparing the result with the ground truth.

    # ---- configuration ----
    slant = "slant_local"
    slant_norm = False
    patch_size = 4
    root = "/horizon-bucket/saturn_v_dev/01_users/chengtang.yao/Sceneflow"
    disp_path = root+"/flyingthings3d/disparity/TRAIN/A/0717/left/0006.pfm"
    left_path = root+"/flyingthings3d/frames_cleanpass/TRAIN/A/0717/left/0006.png"
    sv_path = "./tmp.png"

    # ---- load inputs ----
    img0 = np.array(Image.open(left_path))
    disp = np.array(frame_utils.readPFM(disp_path))
    H, W = disp.shape

    start_time = time.time()
    disp = torch.from_numpy(disp)[None, None]                  # 1,1,H,W
    img0 = torch.from_numpy(img0).permute((2, 0, 1))[None]     # 1,C,H,W

    # ---- fit the surfaces ----
    cab, mask = extract_plane(
        disp,
        slant=slant, slant_norm=slant_norm,
        patch_size=patch_size, thold=3, vis=True,
    )

    # ---- evaluate the fit on the pixel grid ----
    uv_coord = get_pos(H, W, slant=slant, slant_norm=slant_norm, patch_size=patch_size)
    patch_uv_coord = convert2patch(uv_coord, patch_size=patch_size)
    d_coord = predict_disp(cab, patch_uv_coord, patch_size=patch_size, mul_last=True)

    patch_disp = convert2patch(disp, patch_size=patch_size, div_last=True)

    # fold the per-patch layout back to image layout
    fold_kwargs = dict(kernel_size=patch_size, stride=patch_size)
    rec_disp = F.fold(d_coord.flatten(-2, -1), disp.shape[-2:], **fold_kwargs).view(1, 1, H, W)
    rec_mask = F.fold(mask.flatten(-2, -1).float(), disp.shape[-2:], **fold_kwargs).view(1, 1, H, W).bool()

    end_time = time.time()
    print(f"cost time: {end_time-start_time}", cab.shape)

    # ---- back to numpy for plotting ----
    disp = disp.squeeze(0).squeeze(0).cpu().data.numpy()
    img0 = img0.squeeze(0).permute((1, 2, 0)).cpu().data.numpy()
    patch_disp = patch_disp[0, 0, 0, ...].cpu().data.numpy()
    rec_disp = rec_disp[0, 0, ...].cpu().data.numpy()
    rec_mask = rec_mask[0, 0, ...].cpu().data.numpy()

    error_map = np.abs(rec_disp-disp)
    color_error_map = vis.colorize_error_map(error_map)

    # angle derived from the two first-order coefficients
    degree = torch.atan(cab[0, 1] / cab[0, 2])

    Gaussian_cur, mean_cur = compute_curvature(cab)
    for cur in (Gaussian_cur, mean_cur):
        print("-"*10, cur.min(), cur.max(), cur.mean(), cur.median())

    # ---- figure panels ----
    atom_dict = [
        {"img": img0, "title": "Left Image"},
        {"img": disp, "title": "GT Disparity", "cmap": 'jet'},
        {"img": patch_disp, "title": "GT Patch Disparity", "cmap": 'jet'},
        {"img": rec_disp, "title": "GT recover Disparity", "cmap": 'jet'},
        {"img": rec_mask, "title": "rec_mask", "cmap": "gray"},
        {"img": color_error_map, "title": "color_error_map"},
        {"img": degree, "title": "GT ab", "cmap": 'jet'},
        {"img": cab[0, 0], "title": "GT c", "cmap": 'jet'},
        {"img": Gaussian_cur.abs(), "title": "Gaussian curvature", "cmap": 'jet'},
        {"img": mean_cur.abs(), "title": "mean curvature", "cmap": 'jet'},
    ]

    if slant == "slant_local":
        # constant term of each fit compared with the first pixel of each
        # patch of the scaled disparity
        d_p = cab[0, 0]
        error_map = np.abs(d_p-patch_disp)
        color_error_map = vis.colorize_error_map(error_map)
        tmp_dict = [
            {"img": d_p, "title": "GT Disparity of Plane", "cmap": 'jet'},
            {"img": color_error_map, "title": "color_error_map of Plane"},
        ]
        atom_dict += tmp_dict

    vis.show_imgs(
        atom_dict,
        sv_img=True, save2where=sv_path, if_inter=False,
        fontsize=20, szWidth=10, szHeight=5, group=2,
    )