# tools/demo_matching.py
# MASt3R two-view matching demo for
# "Diving into the Fusion of Monocular Priors for Generalized Stereo Matching"
import os
import sys

# the demo expects a MASt3R checkout in ./mast3r (run from the repo root)
sys.path.insert(0, 'mast3r')

from mast3r.model import AsymmetricMASt3R
from mast3r.fast_nn import fast_reciprocal_NNs

import mast3r.utils.path_to_dust3r  # noqa: F401 -- side effect: puts the bundled dust3r on sys.path
from dust3r.inference import inference
from dust3r.utils.image import load_images
if __name__ == '__main__':
    device = 'cuda'
    # unused in this script (carried over from the dust3r global-alignment demo)
    schedule = 'cosine'
    lr = 0.01
    niter = 300
    # config of the checkpoint below, for reference:
    # AsymmetricMASt3R(enc_depth=24, dec_depth=12, enc_embed_dim=1024, dec_embed_dim=768,
    #                  enc_num_heads=16, dec_num_heads=12, pos_embed='RoPE100', img_size=(512, 512),
    #                  head_type='catmlp+dpt', output_mode='pts3d+desc24', depth_mode=('exp', -inf, inf),
    #                  conf_mode=('exp', 1, inf), patch_embed_cls='PatchEmbedDust3R', two_confs=True,
    #                  desc_conf_mode=('exp', 0, inf))
    model_name = "/data5/yao/pretrained/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
    # model_name is a local checkpoint path here; a Hugging Face hub ID works as well (see below)
    model = AsymmetricMASt3R.from_pretrained(model_name).to(device)
    print(next(model.parameters()).device)  # sanity check that the weights live on the GPU
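    # alternatively, as in the upstream MASt3R README, the checkpoint can be
    # fetched from the Hugging Face hub by name instead of a local path:
    #   model = AsymmetricMASt3R.from_pretrained(
    #       "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").to(device)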
    # load the two views of the "Piano" stereo pair, resized for the 512-px model
    images = load_images(['/data5/yao/tmp/Piano/im0.png', '/data5/yao/tmp/Piano/im1.png'], size=512)
    output = inference([tuple(images)], model, device, batch_size=1, verbose=False)
    # at this stage, you have the raw dust3r predictions
    view1, pred1 = output['view1'], output['pred1']
    view2, pred2 = output['view2'], output['pred2']
    desc1, desc2 = pred1['desc'].squeeze(0).detach(), pred2['desc'].squeeze(0).detach()
    # find 2D-2D matches between the two images
    matches_im0, matches_im1 = fast_reciprocal_NNs(desc1, desc2, subsample_or_initxy1=8,
                                                   device=device, dist='dot', block_size=2**13)
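    # subsample_or_initxy1=8 seeds the reciprocal nearest-neighbour search on an
    # 8-pixel grid, dist='dot' scores candidates by descriptor dot product, and
    # block_size caps how large a similarity block is materialized at once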
    # ignore matches that fall in a small (3 px) border around the image edge
    H0, W0 = view1['true_shape'][0]
    valid_matches_im0 = ((matches_im0[:, 0] >= 3) & (matches_im0[:, 0] < int(W0) - 3) &
                         (matches_im0[:, 1] >= 3) & (matches_im0[:, 1] < int(H0) - 3))
    H1, W1 = view2['true_shape'][0]
    valid_matches_im1 = ((matches_im1[:, 0] >= 3) & (matches_im1[:, 0] < int(W1) - 3) &
                         (matches_im1[:, 1] >= 3) & (matches_im1[:, 1] < int(H1) - 3))
    valid_matches = valid_matches_im0 & valid_matches_im1
    matches_im0, matches_im1 = matches_im0[valid_matches], matches_im1[valid_matches]
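    # not part of the original demo: since im0/im1 form a rectified stereo pair,
    # each surviving match yields a sparse disparity sample d = x_left - x_right;
    # a quick sanity-check sketch (assumes rectification holds at the loaded
    # resolution, so vertical offsets should stay near zero):
    import numpy as np
    disparities = matches_im0[:, 0].astype(np.float32) - matches_im1[:, 0].astype(np.float32)
    dy = matches_im0[:, 1].astype(np.float32) - matches_im1[:, 1].astype(np.float32)
    print(f'{len(disparities)} matches, disparity range '
          f'[{disparities.min():.1f}, {disparities.max():.1f}], '
          f'median |dy| = {np.median(np.abs(dy)):.2f} px')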
    # visualize a few matches
    import numpy as np
    import torch
    from matplotlib import pyplot as pl
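    # note: when running headless, pyplot may fail to open a display; forcing the
    # non-interactive Agg backend before the first pyplot import avoids this:
    #   import matplotlib
    #   matplotlib.use('Agg')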
    n_viz = 20
    num_matches = matches_im0.shape[0]
    # pick n_viz matches spread evenly over the full match list
    match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)
    viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]

    # undo dust3r's ImgNorm normalization (mean 0.5, std 0.5) for display
    image_mean = torch.as_tensor([0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
    image_std = torch.as_tensor([0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
    viz_imgs = []
    for view in [view1, view2]:
        rgb_tensor = view['img'] * image_std + image_mean
        viz_imgs.append(rgb_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy())

    # pad the shorter image so both have equal height, then place them side by side
    H0, W0 = viz_imgs[0].shape[:2]
    H1, W1 = viz_imgs[1].shape[:2]
    img0 = np.pad(viz_imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img1 = np.pad(viz_imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img = np.concatenate((img0, img1), axis=1)
    pl.figure()
    pl.imshow(img)
    cmap = pl.get_cmap('jet')
    for i in range(n_viz):
        (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T
        # shift x1 by W0 because im1 sits to the right of im0 on the concatenated canvas
        pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False)
    # pl.show(block=True)
    pl.savefig('./plot_image.png')
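    # the side-by-side figure with the n_viz color-coded correspondences ends up
    # in ./plot_image.png; if the matches are needed downstream, they can be
    # persisted as well, e.g.:
    #   np.save('matches_im0.npy', matches_im0)
    #   np.save('matches_im1.npy', matches_im1)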