Spaces:
Running
on
Zero
Running
on
Zero
| from scipy.interpolate import interp1d, PchipInterpolator | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| import torch | |
def sift_match(
    img1, img2,
    thr=0.5,
    topk=5, method="max_dist",
    output_path="sift_matches.png",
):
    """Match SIFT keypoints between two images and save a visualisation.

    Args:
        img1, img2: PIL images (converted with ``np.array``; the colour
            conversion and the 3-channel canvas assume RGB input).
        thr: Ratio-test threshold used by ``"max_dist"`` and ``"random"``;
            a match is kept when ``best.distance < thr * second.distance``.
            A smaller threshold keeps fewer points.
        topk: Maximum number of matches to keep.
        method: Selection strategy.
            ``"max_dist"``: of the matches passing the ratio test, keep the
            ``topk`` whose keypoints moved the farthest between the images.
            ``"random"``: keep ``topk`` random matches passing the ratio test.
            ``"max_score"``: keep the ``topk`` matches with the lowest
            best/second-best distance ratio.
            ``"max_score_even"``: like ``"max_score"``, but drop candidates
            whose keypoint in ``img1`` lies within 50 px of an already
            accepted one.
        output_path: Where the match visualisation is written.

    Returns:
        ``None`` when no match is kept, otherwise a tensor of shape
        ``(2, k, 2)``: axis 0 indexes the image (``img1``, ``img2``), axis 1
        the kept match, axis 2 the keypoint ``(x, y)`` coordinates.

    Raises:
        ValueError: if ``method`` is not one of the supported strategies.

    Side effects:
        Saves the visualisation to ``output_path`` and prints a message.
    """
    valid_methods = ("max_dist", "random", "max_score", "max_score_even")
    if method not in valid_methods:
        raise ValueError(f"Unknown method {method}")

    # 1. RGB arrays for drawing, grayscale arrays for SIFT.
    img1_rgb = np.array(img1).copy()
    img2_rgb = np.array(img2).copy()
    img1_gray = cv2.cvtColor(img1_rgb, cv2.COLOR_RGB2GRAY)
    img2_gray = cv2.cvtColor(img2_rgb, cv2.COLOR_RGB2GRAY)

    # 2. SIFT keypoints and descriptors.
    sift = cv2.SIFT_create()
    kp1, des1 = sift.detectAndCompute(img1_gray, None)
    kp2, des2 = sift.detectAndCompute(img2_gray, None)

    # 3. Brute-force 2-NN matching. Descriptors are missing/empty when an
    # image has no keypoints; treat that as "no matches" instead of crashing.
    if des1 is None or des2 is None or len(des1) == 0 or len(des2) == 0:
        matches = []
    else:
        matches = cv2.BFMatcher().knnMatch(des1, des2, k=2)
    # A query can come back with fewer than two neighbours (e.g. a single
    # descriptor in img2); the ratio test needs both, so drop such entries.
    matches = [pair for pair in matches if len(pair) == 2]

    good = []
    point_list = []
    distance_list = []

    def _ratio(pair):
        # Best/second-best distance ratio; lower means a more distinctive
        # match. A zero second-best distance means both neighbours are exact
        # duplicates, i.e. maximally ambiguous, so rank it last.
        best, second = pair
        if second.distance == 0:
            return float("inf")
        return best.distance / second.distance

    def _record(match):
        # Store the match, its displacement between images and its endpoints.
        pt1 = kp1[match.queryIdx].pt
        pt2 = kp2[match.trainIdx].pt
        good.append([match])
        distance_list.append(np.linalg.norm(np.array(pt1) - np.array(pt2)))
        point_list.append(torch.stack([torch.tensor(pt1), torch.tensor(pt2)]))  # (2, 2)

    if method in ("max_score", "max_score_even"):
        min_anchor_dist = 50  # pixels, spacing enforced by "max_score_even"
        anchor_points = []
        # NOTE: only the topk best-scoring candidates are examined, so
        # "max_score_even" may return fewer than topk matches.
        for m, _ in sorted(matches, key=_ratio)[:topk]:
            pt1 = kp1[m.queryIdx].pt
            if method == "max_score_even" and any(
                np.linalg.norm(np.array(pt1) - np.array(anchor)) < min_anchor_dist
                for anchor in anchor_points
            ):
                continue
            _record(m)
            anchor_points.append(pt1)
        kept_points, kept_matches = point_list, good
    else:
        # "max_dist" / "random": ratio test first, then sub-select.
        for m, n in matches:
            if m.distance < thr * n.distance:
                _record(m)

        order = np.argsort(np.array(distance_list))
        if method == "max_dist":
            # Largest displacements last; the explicit start index also
            # handles topk == 0 (``order[-0:]`` would keep everything).
            order = order[max(len(order) - topk, 0):]
        elif len(order) > 0:
            order = np.random.choice(order, min(topk, len(order)), replace=False)
        kept_points = [point_list[i] for i in order]
        kept_matches = [good[i] for i in order]

    # 4. Draw the matches: img1 on top, img2 below, separated by a white margin.
    margin = 10
    h1, w1 = img1_rgb.shape[:2]
    h2, w2 = img2_rgb.shape[:2]
    img3 = np.full((h1 + h2 + margin, max(w1, w2), 3), 255, dtype=np.uint8)
    img3[:h1, :w1] = img1_rgb
    img3[h1 + margin:, :w2] = img2_rgb

    # Cycle through six distinct colours so neighbouring matches are distinguishable.
    color_list = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255)]
    for color_idx, m in enumerate(kept_matches):
        pt1 = kp1[m[0].queryIdx].pt
        pt2 = kp2[m[0].trainIdx].pt
        pt1 = (int(pt1[0]), int(pt1[1]))
        # Shift the img2 endpoint down past img1 and the margin.
        pt2 = (int(pt2[0]), int(pt2[1]) + h1 + margin)
        color = color_list[color_idx % len(color_list)]
        # Anti-aliased line avoids the zigzag artifact.
        cv2.line(img3, pt1, pt2, color, 1, lineType=cv2.LINE_AA)
        # Hollow circles mark both endpoints.
        cv2.circle(img3, pt1, 3, color, lineType=cv2.LINE_AA)
        cv2.circle(img3, pt2, 3, color, lineType=cv2.LINE_AA)

    Image.fromarray(img3).save(output_path)
    print(f"Save the sift matches to {output_path}")

    # (f, topk, 2), f=2 (before interpolation)
    if not kept_points:
        return None
    return torch.stack(kept_points).permute(1, 0, 2)
def interpolate_trajectory(points_torch, num_frames, t=None):
    """Resample per-point trajectories to ``num_frames`` frames with PCHIP.

    Args:
        points_torch: tensor of shape ``(f, topk, 2)`` holding ``f`` key
            positions for each of ``topk`` points (``f = 2`` before
            interpolation).
        num_frames: number of frames in the output trajectory.
        t: optional knot positions for the ``f`` key positions; when omitted
            they are spread evenly over ``[0, 1]``.

    Returns:
        Tensor of shape ``(num_frames, topk, 2)`` sampled at ``num_frames``
        evenly spaced positions in ``[0, 1]``.
    """
    track_count = points_torch.shape[1]
    per_track = points_torch.permute(1, 0, 2)  # (topk, f, 2)

    def _resample(track):
        # track: (f, 2) key positions of a single point
        keyframes = track.cpu().numpy()
        knots = t if t is not None else np.linspace(0, 1, len(keyframes))
        interp_x = PchipInterpolator(knots, keyframes[:, 0])
        interp_y = PchipInterpolator(knots, keyframes[:, 1])
        sample_at = np.linspace(0, 1, num_frames)
        return list(zip(interp_x(sample_at), interp_y(sample_at)))

    trajectories = [_resample(per_track[k]) for k in range(track_count)]

    # (topk, num_frames, 2) -> (num_frames, topk, 2)
    return torch.tensor(trajectories).permute(1, 0, 2)
| # diffusion feature matching | |
def point_tracking(
    F0,
    F1,
    handle_points,
    handle_points_init,
    track_dist=5,
):
    """Re-locate each handle point in ``F1`` by nearest-feature search.

    For every point, the feature vector of ``F0`` at the point's initial
    position is compared (L1 distance over the channel axis) against every
    location of ``F1`` inside a square window of radius ``track_dist``
    around the point's current position, clamped to the feature-map bounds.
    The point moves to the best-matching location.

    Args:
        F0, F1: 4-D feature maps indexed as ``[:, :, row, col]``; the
            ``squeeze(dim=0)`` below presumes a leading dimension of 1.
        handle_points: ``(num_points, 2)`` current positions as ``(x, y)``.
        handle_points_init: ``(num_points, 2)`` initial positions as ``(x, y)``.
        track_dist: search radius in feature-map cells.

    Returns:
        ``(num_points, 2)`` tensor of updated ``(x, y)`` positions; the input
        tensors are not modified.
    """

    def _swap_columns(pts):
        # (x, y) <-> (row, col); torch.stack yields a fresh tensor.
        return torch.stack((pts[:, 1], pts[:, 0]), dim=-1)

    # Internally everything is handled in (row, col) order.
    current_rc = _swap_columns(handle_points)
    initial_rc = _swap_columns(handle_points_init)

    with torch.no_grad():
        n_rows, n_cols = F0.shape[2:]
        for k, cur in enumerate(current_rc):
            start = initial_rc[k]
            cur_r, cur_c = int(cur[0]), int(cur[1])

            # Reference feature taken at the initial location in F0.
            template = F0[:, :, int(start[0]), int(start[1])]

            # Search window around the current location, clamped to the map.
            top = max(0, cur_r - track_dist)
            bottom = min(n_rows, cur_r + track_dist + 1)
            left = max(0, cur_c - track_dist)
            right = min(n_cols, cur_c + track_dist + 1)
            window = F1[:, :, top:bottom, left:right]

            # L1 distance between the template and every window location.
            l1 = (template[:, :, None, None] - window).abs().sum(dim=1)
            l1 = l1.squeeze(dim=0)

            # Convert the flat argmin back into window-relative (row, col).
            flat_best = l1.argmin().item()
            win_width = l1.shape[-1]
            current_rc[k][0] = top + flat_best // win_width
            current_rc[k][1] = left + flat_best % win_width

    # Back to (x, y) order: (num_points, 2)
    return _swap_columns(current_rc)