import cv2
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.patches import ConnectionPatch

from controller.controller import AttentionControl

from einops import repeat, rearrange
from typing import Tuple, Callable

from vidtome.patch import PCA_token
from utils.flow_utils import coords_grid


def do_nothing(x: torch.Tensor, mode: str = None):
    """Identity fallback returned when no merging is performed."""
    return x


def mps_gather_workaround(input, dim, index):
    """Workaround for a torch.gather issue on MPS devices when the last
    dimension of the input is 1: gather on an unsqueezed copy instead."""
    if input.shape[-1] == 1:
        return torch.gather(
            input.unsqueeze(-1),
            # After unsqueeze(-1), a negative dim shifts one position left.
            dim - 1 if dim < 0 else dim,
            index.unsqueeze(-1)
        ).squeeze(-1)
    else:
        return torch.gather(input, dim, index)


def visualize_flow_correspondence(src_img: torch.Tensor, tar_img: torch.Tensor, flow: torch.Tensor, flow_confid: torch.Tensor,
                                  ratio: float, H: int = 64, out: str = "correspondence.png") -> None:
    """Visualize flow-based correspondences between source and target tokens.

    The top-ratio most confident flow matches are drawn as connections
    between the PCA projections of the two token maps.
    """
    if len(src_img.shape) == 4:
        B, C, H, W = src_img.shape
        src_img = rearrange(src_img, 'b c h w -> b (h w) c', h=H)
        tar_img = rearrange(tar_img, 'b c h w -> b (h w) c', h=H)
    else:
        B, N, C = src_img.shape
        W = N // H

    src_PCA_token = PCA_token(src_img, token_h=H)
    tar_PCA_token = PCA_token(tar_img, token_h=H)

    with torch.no_grad():
        # Number of correspondences to draw.
        r = min(src_img.shape[1], int(src_img.shape[1] * ratio))
        print(f"[INFO] flow r {r}")

        # Rank source tokens by flow confidence and keep the top r.
        flow_confid = rearrange(flow_confid, 'b h w -> b (h w)')
        edge_idx = flow_confid.argsort(dim=-1, descending=True)[..., None]
        src_idx = edge_idx[..., :r, :]

        # Source pixel coordinates and their flow-displaced targets.
        src_xy = [(id.item() % W, id.item() // W) for id in src_idx[0]]
        grid = coords_grid(B, H, W).to(flow.device) + flow
        tar_xy = [(grid[0, 0, y, x].clamp(0, W - 1).item(),
                   grid[0, 1, y, x].clamp(0, H - 1).item()) for (x, y) in src_xy]

    fig, ax = plt.subplots(1, 2, figsize=(8, 3))

    ax[0].imshow(src_PCA_token, cmap='gray')
    ax[1].imshow(tar_PCA_token, cmap='gray')
    ax[0].axis('off')
    ax[1].axis('off')

    colors = cm.Greens(np.linspace(0.5, 1, len(src_xy)))

    for (x1, y1), (x2, y2), color in zip(src_xy, tar_xy, colors):
        ax[0].plot(x1, y1, marker='o', color=color, markersize=0.5)
        ax[1].plot(x2, y2, marker='o', color=color, markersize=1)
        con = ConnectionPatch(xyA=(x2, y2), xyB=(x1, y1), coordsA="data", coordsB="data",
                              axesA=ax[1], axesB=ax[0], color=color, linewidth=0.2)
        ax[1].add_artist(con)

    plt.savefig(out, bbox_inches="tight")
    plt.close()


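# A minimal, hedged usage sketch for visualize_flow_correspondence. The
# shapes, random inputs, and output filename are illustrative assumptions,
# not values from the actual pipeline; it also assumes PCA_token accepts
# these token maps. `flow` follows the [B, 2, H, W] layout expected by
# coords_grid, and `flow_confid` is a per-pixel confidence map.
def _example_visualize_flow_correspondence():
    B, C, H, W = 1, 16, 64, 64
    src = torch.randn(B, C, H, W)       # source frame tokens (illustrative)
    tar = torch.randn(B, C, H, W)       # target frame tokens (illustrative)
    flow = torch.zeros(B, 2, H, W)      # identity flow for the sketch
    flow_confid = torch.rand(B, H, W)   # per-pixel flow confidence
    visualize_flow_correspondence(src, tar, flow, flow_confid,
                                  ratio=0.01, H=H, out="flow_corr_demo.png")

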
def visualize_correspondence_score(src_img: torch.Tensor, tar_img: torch.Tensor, score: torch.Tensor,
                                   ratio: float, H: int = 64, out: str = "correspondence_idx.png") -> None:
    """Visualize correspondences given a precomputed similarity score matrix.

    For the top-ratio source tokens (ranked by their best score), draw a
    connection to their highest-scoring target token.
    """
    if len(src_img.shape) == 4:
        B, C, H, W = src_img.shape
        src_img = rearrange(src_img, 'b c h w -> b (h w) c', h=H)
        tar_img = rearrange(tar_img, 'b c h w -> b (h w) c', h=H)
        N = H * W
    else:
        B, N, C = src_img.shape
        W = N // H

    src_PCA_token = PCA_token(src_img, token_h=H)
    tar_PCA_token = PCA_token(tar_img, token_h=H)

    with torch.no_grad():
        # Number of correspondences to draw.
        r = min(N, int(N * ratio))

        # For each source token, its best-matching target token.
        node_max, node_idx = score.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        src_idx = edge_idx[0, :r, 0]
        tar_idx = torch.gather(node_idx[0], dim=0, index=src_idx)

        # Convert flat token indices to (x, y) pixel coordinates.
        src_xy = [(id.item() % W, id.item() // W) for id in src_idx]
        tar_xy = [(id.item() % W, id.item() // W) for id in tar_idx]

    fig, ax = plt.subplots(1, 2, figsize=(8, 3))

    ax[0].imshow(src_PCA_token, cmap='gray')
    ax[1].imshow(tar_PCA_token, cmap='gray')

    colors = cm.cool(np.linspace(0, 1, len(src_xy)))

    for (x1, y1), (x2, y2), color in zip(src_xy, tar_xy, colors):
        ax[0].plot(x1, y1, marker='o', color=color, markersize=1)
        ax[1].plot(x2, y2, marker='o', color=color, markersize=1)
        con = ConnectionPatch(xyA=(x2, y2), xyB=(x1, y1), coordsA="data", coordsB="data",
                              axesA=ax[1], axesB=ax[0], color=color, linewidth=0.2)
        ax[1].add_artist(con)

    plt.savefig(out, bbox_inches="tight")
    plt.close()


def visualize_cosine_correspondence(src_img: torch.Tensor, tar_img: torch.Tensor,
                                    ratio: float, H: int = 64, out: str = "correspondence.png",
                                    flow: torch.Tensor = None, flow_confid: torch.Tensor = None,
                                    controller: AttentionControl = None) -> None:
    """Visualize correspondences computed from cosine similarity of tokens."""
    if len(src_img.shape) == 4:
        B, C, H, W = src_img.shape
        src_img = rearrange(src_img, 'b c h w -> b (h w) c', h=H)
        tar_img = rearrange(tar_img, 'b c h w -> b (h w) c', h=H)
    else:
        B, N, C = src_img.shape
        W = N // H

    src_PCA_token = PCA_token(src_img, token_h=H)
    tar_PCA_token = PCA_token(tar_img, token_h=H)

    gather = mps_gather_workaround if src_img.device.type == "mps" else torch.gather

    with torch.no_grad():
        # Cosine similarity between all source and target tokens.
        a = src_img / src_img.norm(dim=-1, keepdim=True)
        b = tar_img / tar_img.norm(dim=-1, keepdim=True)
        scores = a @ b.transpose(-1, -2)

        r = min(a.shape[1], int(a.shape[1] * ratio))
        print(f"[INFO] cosine r {r}")

        # For each source token, its best-matching target token.
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        # Skip the top 4*r ranked matches and visualize the next chunk of r.
        src_idx = edge_idx[..., int(4 * r):int(5 * r), :]
        tar_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

        src_xy = [(id.item() % W, id.item() // W) for id in src_idx[0]]
        tar_xy = [(id.item() % W, id.item() // W) for id in tar_idx[0]]

    fig, ax = plt.subplots(1, 2, figsize=(8, 3))

    ax[0].imshow(src_PCA_token, cmap='spring')
    ax[1].imshow(tar_PCA_token, cmap='spring')
    ax[0].axis('off')
    ax[1].axis('off')

    colors = cm.cool(np.linspace(0.5, 1, len(src_xy)))

    for (x1, y1), (x2, y2), color in zip(src_xy, tar_xy, colors):
        ax[0].plot(x1, y1, marker='o', color=color, markersize=0.5)
        ax[1].plot(x2, y2, marker='o', color=color, markersize=1)
        con = ConnectionPatch(xyA=(x2, y2), xyB=(x1, y1), coordsA="data", coordsB="data",
                              axesA=ax[1], axesB=ax[0], color=color, linewidth=0.2)
        ax[1].add_artist(con)

    plt.savefig(out, bbox_inches="tight")
    plt.close()


def visualize_correspondence(src_img: torch.Tensor, tar_img: torch.Tensor,
                             ratio: float, H: int = 64, out: str = "correspondence.png",
                             flow: torch.Tensor = None, flow_confid: torch.Tensor = None,
                             controller: AttentionControl = None) -> None:
    """Visualize cosine-similarity correspondences (red) against flow-based
    correspondences (green) in a 2x2 figure."""
    if len(src_img.shape) == 4:
        B, C, H, W = src_img.shape
        src_img = rearrange(src_img, 'b c h w -> b (h w) c', h=H)
        tar_img = rearrange(tar_img, 'b c h w -> b (h w) c', h=H)
    else:
        B, N, C = src_img.shape
        W = N // H
    src_PCA_token = PCA_token(src_img, token_h=H, n=1)
    tar_PCA_token = PCA_token(tar_img, token_h=H, n=1)

    # The PCA sign is arbitrary; flip one map if the two projections have
    # strongly mismatched brightness in the top-left corner.
    if abs(np.mean(src_PCA_token[:20, :20]) - np.mean(tar_PCA_token[:20, :20])) > 50:
        if np.mean(src_PCA_token[:20, :20]) > np.mean(tar_PCA_token[:20, :20]):
            src_PCA_token = 255 - src_PCA_token
        else:
            tar_PCA_token = 255 - tar_PCA_token
    print(f"[INFO] src_PCA_token mean {np.mean(src_PCA_token[:20, :20])} tar_PCA_token mean {np.mean(tar_PCA_token[:20, :20])}")

    gather = mps_gather_workaround if src_img.device.type == "mps" else torch.gather

    with torch.no_grad():
        # Cosine-similarity correspondences.
        a = src_img / src_img.norm(dim=-1, keepdim=True)
        b = tar_img / tar_img.norm(dim=-1, keepdim=True)
        scores = a @ b.transpose(-1, -2)

        r = min(a.shape[1], int(a.shape[1] * ratio))

        print("[INFO] no distance weighted ...")

        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
        src_idx = edge_idx[..., :r, :]
        tar_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

        src_xy = [(id.item() % W, id.item() // W) for id in src_idx[0]]
        tar_xy = [(id.item() % W, id.item() // W) for id in tar_idx[0]]

        # Flow-based correspondences for comparison.
        flow_confid = rearrange(flow_confid, 'b h w -> b (h w)')
        edge_idx = flow_confid.argsort(dim=-1, descending=True)[..., None]
        src_idx = edge_idx[..., :r, :]

        flow_src_xy = [(id.item() % W, id.item() // W) for id in src_idx[0]]

        grid = coords_grid(B, H, W).to(flow.device) + flow
        flow_tar_xy = [(grid[0, 0, y, x].clamp(0, W - 1).item(),
                        grid[0, 1, y, x].clamp(0, H - 1).item()) for (x, y) in flow_src_xy]

    fig, ax = plt.subplots(2, 2, figsize=(8, 4))

    if controller is not None and len(controller.decoded_imgs):
        step = out.split("/")[-1].split(".")[0]

        _, h_, w_, _ = controller.decoded_imgs[0].shape
        mul = h_ // H
        decoded_img = controller.decoded_imgs[1]
        decoded_img = decoded_img[0, :, :int(W * mul), :]
        if step == "49":
            # Hard-coded debug image; note cv2.imread returns BGR.
            decoded_img = cv2.imread("/project/DiffBVR_eval/DAVIS/BDx8_results/DiffBIR_ours/cows/00001.png")
            decoded_img = cv2.resize(decoded_img, (W, H))
        ax[0, 0].imshow(decoded_img, aspect='auto')
        decoded_img = controller.decoded_imgs[2]
        decoded_img = decoded_img[0, :, :int(W * mul), :]
        if step == "49":
            decoded_img = cv2.imread("/project/DiffBVR_eval/DAVIS/BDx8_results/DiffBIR_ours/cows/00002.png")
            decoded_img = cv2.resize(decoded_img, (W, H))
        ax[0, 1].imshow(decoded_img, aspect='auto')
    else:
        ax[0, 0].imshow(src_PCA_token, cmap='ocean', aspect='auto')
        ax[0, 1].imshow(tar_PCA_token, cmap='ocean', aspect='auto')

    ax[0, 0].axis('off')
    ax[0, 1].axis('off')

    ax[1, 0].imshow(src_PCA_token, cmap='Blues', aspect='auto')
    ax[1, 1].imshow(tar_PCA_token, cmap='Blues', aspect='auto')
    ax[1, 0].axis('off')
    ax[1, 1].axis('off')

    # Flow correspondences in green.
    colors = cm.Greens(np.linspace(0.25, 0.75, len(flow_src_xy)))

    for (x1, y1), (x2, y2), color in zip(flow_src_xy, flow_tar_xy, colors):
        ax[1, 1].plot(x2, y2, marker='o', color=color, markersize=1)
        con = ConnectionPatch(xyA=(x2, y2), xyB=(x1, y1), coordsA="data", coordsB="data",
                              axesA=ax[1, 1], axesB=ax[1, 0], color=color, linewidth=0.2)
        ax[1, 1].add_artist(con)

    # Cosine-similarity correspondences in red.
    colors = cm.Reds(np.linspace(0.25, 0.75, len(src_xy)))

    for (x1, y1), (x2, y2), color in zip(src_xy, tar_xy, colors):
        ax[1, 1].plot(x2, y2, marker='o', color=color, markersize=1)
        con = ConnectionPatch(xyA=(x2, y2), xyB=(x1, y1), coordsA="data", coordsB="data",
                              axesA=ax[1, 1], axesB=ax[1, 0], color=color, linewidth=0.2)
        ax[1, 1].add_artist(con)

    plt.subplots_adjust(wspace=0.05, hspace=0.1)
    plt.savefig(out, bbox_inches="tight")
    plt.close()


def bipartite_soft_matching_randframe(metric: torch.Tensor,
                                      F: int, ratio: float, unm_pre: int, generator: torch.Generator = None,
                                      target_stride: int = 4, align_batch: bool = False,
                                      merge_mode: str = "replace", H: int = 64,
                                      flow_merge: bool = False,
                                      controller: AttentionControl = None) -> Tuple[Callable, Callable, dict]:
    """
    Partitions the multi-frame tokens into src and dst and merges ratio of src tokens from src to dst.
    Dst tokens are partitioned by choosing one random frame.

    Args:
        - metric [B, N, C]: metric to use for similarity.
        - F: frame number.
        - ratio: ratio of src tokens to be removed (by merging).
        - unm_pre: number of src tokens not merged at previous ToMe. Pre-sequence: [unm_pre|F_0|F_1|...]
        - generator: random number generator.
        - target_stride: stride of target frame.
        - align_batch: whether to align similarity matching maps of samples in the batch. True when using PnP.
        - merge_mode: how to merge tokens. "mean": tokens -> Mean(src_token, dst_token); "replace": tokens -> dst_token.
        - H: token map height, used for flow correspondence and distance weighting.
        - flow_merge: if True, use precomputed flow correspondences from the controller instead of similarity matching.
        - controller: AttentionControl holding flow correspondences and distance maps.

    Returns:
        Merge and unmerge operation according to the matching result. Return a dict including other values.
    """
    B, N, _ = metric.shape
    A = N // F    # tokens per frame
    W = A // H    # token map width

    tnum = (N - unm_pre) // F
    if ratio <= 0:
        return do_nothing, do_nothing, {"unm_num": tnum}

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        idx_buffer = torch.arange(
            N - unm_pre, device=metric.device, dtype=torch.int64)

        target_stride = min(target_stride, F)

        # Pick the dst frame: random within the stride window, or the middle
        # frame (deterministic) when a controller is given.
        if controller is None:
            randf = torch.randint(0, target_stride, torch.Size(
                [1]), generator=generator, device=generator.device)
        else:
            randf = torch.tensor(target_stride // 2).to(metric.device)

        dst_select = ((torch.div(idx_buffer, tnum, rounding_mode='floor')) %
                      target_stride == randf).to(torch.bool)

        # a_idx: src token indices; b_idx: dst token indices.
        a_idx = idx_buffer[None, ~dst_select, None] + unm_pre
        b_idx = idx_buffer[None, dst_select, None] + unm_pre

        # Tokens left unmerged by the previous ToMe stage are treated as dst.
        unm_buffer = torch.arange(unm_pre, device=metric.device, dtype=torch.int64)[
            None, :, None]
        b_idx = torch.cat([b_idx, unm_buffer], dim=1)

        del idx_buffer, unm_buffer

        num_dst = b_idx.shape[1]

        def split(x):
            # Split the token sequence into src and dst partitions.
            b, n, c = x.shape
            src = gather(x, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(x, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        # Cosine similarity between src and dst tokens.
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        r = min(a.shape[1], int(a.shape[1] * ratio))

        if align_batch:
            # Share one matching result across the batch by concatenating
            # the batch's score maps along the dst dimension.
            scores = torch.cat([*scores], dim=-1)
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

            unm_idx = edge_idx[..., r:, :]
            src_idx = edge_idx[..., :r, :]
            tar_idx = gather(node_idx[..., None],
                             dim=-2, index=src_idx) % num_dst

            unm_idx = unm_idx.expand(B, -1, -1)
            src_idx = src_idx.expand(B, -1, -1)
            tar_idx = tar_idx.expand(B, -1, -1)
        else:
            if flow_merge:
                # Use precomputed flow correspondences, already sorted by
                # confidence, from the controller.
                unm_idx = controller.flow_correspondence[H][0][:, r:, :]
                src_idx = controller.flow_correspondence[H][0][:, :r, :]
                tar_idx = controller.flow_correspondence[H][1][:, :r, :]
            else:
                # Distance-weighted similarity matching.
                if controller is not None:
                    if H not in controller.distances:
                        controller.set_distance(F - 1, H, W, W // 30, metric.device)
                    print("[INFO] distance weighted ...")
                    scores *= controller.distances[H]

                # node_idx maps each src token to its best dst token.
                node_max, node_idx = scores.max(dim=-1)
                edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

                unm_idx = edge_idx[..., r:, :]
                # Indices of the r most mergeable src tokens.
                src_idx = edge_idx[..., :r, :]
                tar_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode=None) -> torch.Tensor:
        # Merge the selected src tokens into their matched dst tokens.
        src, dst = split(x)
        n, t1, c = src.shape
        u_idx, s_idx, t_idx = unm_idx, src_idx, tar_idx

        unm = gather(src, dim=-2, index=u_idx.expand(-1, -1, c))
        mode = mode if mode is not None else merge_mode
        if mode != "replace":
            src = gather(src, dim=-2, index=s_idx.expand(-1, -1, c))
            dst = dst.scatter_reduce(-2, t_idx.expand(-1, -1, c),
                                     src, reduce=mode, include_self=True)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor, **kwarg) -> torch.Tensor:
        # Scatter merged tokens back to their original positions.
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        b, _, c = unm.shape
        u_idx, s_idx, t_idx = unm_idx, src_idx, tar_idx

        # Merged src tokens are restored from their dst counterparts.
        src = gather(dst, dim=-2, index=t_idx.expand(-1, -1, c))

        out = torch.zeros(b, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(b, -1, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1),
                     dim=1, index=u_idx).expand(-1, -1, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1),
                     dim=1, index=s_idx).expand(-1, -1, c), src=src)

        return out

    ret_dict = {"scores": scores, "unm_num": unm_idx.shape[1]}
    return merge, unmerge, ret_dict


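# A minimal, hedged round-trip sketch for bipartite_soft_matching_randframe.
# All shapes and the 0.5 ratio are illustrative assumptions; in the pipeline
# `metric` comes from attention features and `x` from the token sequence.
def _example_randframe_round_trip():
    B, F, H, W, C = 1, 4, 8, 8, 16                  # 4 frames of 8x8 tokens
    N = F * H * W
    metric = torch.randn(B, N, C)                   # similarity metric
    gen = torch.Generator().manual_seed(0)
    merge, unmerge, info = bipartite_soft_matching_randframe(
        metric, F=F, ratio=0.5, unm_pre=0, generator=gen, H=H)
    x = torch.randn(B, N, C)
    merged = merge(x)                               # shorter sequence: [B, N - r, C]
    restored = unmerge(merged)                      # back to [B, N, C]
    assert restored.shape == x.shape

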
def bipartite_soft_matching_random2d_hier(metric: torch.Tensor, frame_num: int, ratio: float, unm_pre: int, generator: torch.Generator, target_stride: int = 4, adhere_src: bool = False, merge_mode: str = "replace", scores=None, coord=None, rec_field=2) -> Tuple[Callable, Callable, dict]:
    """
    Partitions the multi-frame tokens into src and dst and merges ratio of src tokens from src to dst.
    Dst tokens are partitioned by choosing one random frame within each target_stride window.

    Args:
        - metric [B, N, C]: metric to use for similarity.
        - frame_num: number of frames.
        - ratio: ratio of src tokens to be removed (by merging).
        - unm_pre: number of src tokens not merged at previous ToMe. Pre-sequence: [unm_pre|F_0|F_1|...]
        - generator: random number generator.
        - target_stride: stride of target frame.
        - adhere_src: whether to share one matching result across the batch.
        - merge_mode: how to merge tokens. "mean": tokens -> Mean(src_token, dst_token); "replace": tokens -> dst_token.
        - coord [B, N, 2]: optional token coordinates; matches farther than rec_field are masked out.
        - rec_field: receptive-field radius used with coord.
    """
    B, N, _ = metric.shape
    F = frame_num
    nf = (N - unm_pre) // F    # tokens per frame

    if ratio <= 0:
        # Nothing merged: all per-frame src tokens stay unmerged.
        return do_nothing, do_nothing, {"unm_num": nf}

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        idx_buffer = torch.arange(N - unm_pre, device=metric.device, dtype=torch.int64)

        # Randomly pick the dst frame within the stride window.
        max_f = min(target_stride, F)
        randn = torch.randint(0, max_f, torch.Size([1]), generator=generator, device=generator.device)

        dst_select = ((torch.div(idx_buffer, nf, rounding_mode='floor')) % max_f == randn).to(torch.bool)

        a_idx = idx_buffer[None, ~dst_select, None] + unm_pre
        b_idx = idx_buffer[None, dst_select, None] + unm_pre

        # Tokens left unmerged by the previous ToMe stage are treated as dst.
        unm_buffer = torch.arange(unm_pre, device=metric.device, dtype=torch.int64)[None, :, None]
        b_idx = torch.cat([b_idx, unm_buffer], dim=1)

        del idx_buffer, unm_buffer

        num_dst = b_idx.shape[1]

        def split(x):
            b, n, c = x.shape
            src = gather(x, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(x, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        def split_coord(coord):
            b, n, c = coord.shape
            src = gather(coord, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(coord, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        # Cosine similarity between src and dst tokens.
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)

        # Optionally restrict matches to a local window of radius rec_field.
        if coord is not None:
            src_coord, dst_coord = split_coord(coord)
            mask = torch.norm(src_coord[:, :, None, :] - dst_coord[:, None, :, :], dim=-1) > rec_field

        scores = a @ b.transpose(-1, -2)

        if coord is not None:
            scores[mask] = 0

        r = int(a.shape[1] * ratio)
        r = min(a.shape[1], r)

        if adhere_src:
            # Share one matching result across the batch.
            scores = torch.cat([*scores], dim=-1)
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

            unm_idx = edge_idx[..., r:, :]
            src_idx = edge_idx[..., :r, :]
            dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) % num_dst

            unm_idx = unm_idx.expand(B, -1, -1)
            src_idx = src_idx.expand(B, -1, -1)
            dst_idx = dst_idx.expand(B, -1, -1)
        else:
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

            unm_idx = edge_idx[..., r:, :]
            src_idx = edge_idx[..., :r, :]
            dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode=None, b_select=None, **kwarg) -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape
        if b_select is not None:
            if not isinstance(b_select, list):
                b_select = [b_select]
            u_idx, s_idx, d_idx = unm_idx[b_select], src_idx[b_select], dst_idx[b_select]
        else:
            u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx

        unm = gather(src, dim=-2, index=u_idx.expand(-1, -1, c))
        src = gather(src, dim=-2, index=s_idx.expand(-1, -1, c))
        mode = mode if mode is not None else merge_mode
        if mode != "replace":
            dst = dst.scatter_reduce(-2, d_idx.expand(-1, -1, c), src, reduce=mode, include_self=True)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor, b_select=None, unm_modi=None, **kwarg) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        b, _, c = unm.shape
        if b_select is not None:
            if not isinstance(b_select, list):
                b_select = [b_select]
            u_idx, s_idx, d_idx = unm_idx[b_select], src_idx[b_select], dst_idx[b_select]
        else:
            u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx
        if unm_modi is not None:
            if unm_modi == "zero":
                unm = torch.zeros_like(unm)
        # Merged src tokens are restored from their dst counterparts.
        src = gather(dst, dim=-2, index=d_idx.expand(-1, -1, c))

        out = torch.zeros(b, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(b, -1, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), dim=1, index=u_idx).expand(-1, -1, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), dim=1, index=s_idx).expand(-1, -1, c), src=src)

        return out

    ret_dict = {"unm_num": unm_idx.shape[1]}
    return merge, unmerge, ret_dict


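# A hedged sketch of the coord / rec_field masking above: token pairs farther
# apart than rec_field get their similarity zeroed, pushing distant pairs
# down the merge ranking. All values here are illustrative assumptions.
def _example_hier_local_merge():
    B, F, nf, C = 1, 2, 16, 8                       # 2 frames of 16 tokens
    N = F * nf
    metric = torch.randn(B, N, C)
    coord = torch.rand(B, N, 2) * 4                 # token positions in [0, 4)
    gen = torch.Generator().manual_seed(0)
    merge, unmerge, info = bipartite_soft_matching_random2d_hier(
        metric, frame_num=F, ratio=0.25, unm_pre=0, generator=gen,
        coord=coord, rec_field=2)
    merged = merge(metric)                          # merge within local windows
    restored = unmerge(merged)
    assert restored.shape == metric.shape

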
def bipartite_soft_matching_2s(metric: torch.Tensor,
                               src_len: int, ratio: float, align_batch: bool,
                               merge_mode: str = "replace", unmerge_chunk: int = 0) -> Tuple[Callable, Callable, dict]:
    """
    Partitions the tokens into src and dst and merges ratio of src tokens from src to dst.
    Src tokens are partitioned as first src_len tokens. Others are dst tokens.

    Args:
        - metric [B, N, C]: metric to use for similarity.
        - src_len: src token length. [ src | dst ]: [ src_len | N - src_len ]
        - ratio: ratio of src tokens to be removed (by merging).
        - align_batch: whether to align similarity matching maps of samples in the batch. True when using PnP.
        - merge_mode: how to merge tokens. "mean": tokens -> Mean(src_token, dst_token); "replace": tokens -> dst_token.
        - unmerge_chunk: return which partition in unmerge. 0 for src and 1 for dst.

    Returns:
        Merge and unmerge operation according to the matching result. Return a dict including other values.
    """
    B, N, _ = metric.shape

    if ratio <= 0:
        # Nothing merged: all src tokens stay unmerged.
        return do_nothing, do_nothing, {"unm_num": src_len}

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        idx_buffer = torch.arange(N, device=metric.device, dtype=torch.int64)

        # Fixed partition: the first src_len tokens are src, the rest dst.
        a_idx = idx_buffer[None, :src_len, None]
        b_idx = idx_buffer[None, src_len:, None]

        del idx_buffer

        num_dst = b_idx.shape[1]

        def split(x):
            b, n, c = x.shape
            src = gather(x, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(x, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        # Cosine similarity between src and dst tokens.
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)

        scores = a @ b.transpose(-1, -2)

        r = min(a.shape[1], int(a.shape[1] * ratio))

        if align_batch:
            # Share one matching result across the batch.
            scores = torch.cat([*scores], dim=-1)
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

            unm_idx = edge_idx[..., r:, :]
            src_idx = edge_idx[..., :r, :]
            dst_idx = gather(node_idx[..., None],
                             dim=-2, index=src_idx) % num_dst

            unm_idx = unm_idx.expand(B, -1, -1)
            src_idx = src_idx.expand(B, -1, -1)
            dst_idx = dst_idx.expand(B, -1, -1)
        else:
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

            unm_idx = edge_idx[..., r:, :]
            src_idx = edge_idx[..., :r, :]
            dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode=None) -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape
        u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx

        unm = gather(src, dim=-2, index=u_idx.expand(-1, -1, c))
        mode = mode if mode is not None else merge_mode
        if mode != "replace":
            src = gather(src, dim=-2, index=s_idx.expand(-1, -1, c))
            dst = dst.scatter_reduce(-2, d_idx.expand(-1, -1, c),
                                     src, reduce=mode, include_self=True)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor, **kwarg) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        b, _, c = unm.shape
        u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx

        # Merged src tokens are restored from their dst counterparts.
        src = gather(dst, dim=-2, index=d_idx.expand(-1, -1, c))

        out = torch.zeros(b, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(b, -1, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1),
                     dim=1, index=u_idx).expand(-1, -1, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1),
                     dim=1, index=s_idx).expand(-1, -1, c), src=src)

        # Return only the requested partition.
        out = out[:, :src_len, :] if unmerge_chunk == 0 else out[:, src_len:, :]
        return out

    ret_dict = {"unm_num": unm_idx.shape[1]}
    return merge, unmerge, ret_dict


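# A hedged usage sketch for bipartite_soft_matching_2s on a concatenated
# [ src | dst ] sequence (shapes and ratio are illustrative assumptions).
def _example_2s_round_trip():
    B, L, C = 2, 64, 16
    tokens = torch.randn(B, 2 * L, C)               # [ src | dst ]
    merge, unmerge, info = bipartite_soft_matching_2s(
        tokens, src_len=L, ratio=0.5, align_batch=False)
    merged = merge(tokens)                          # [B, 2L - r, C]
    src_recon = unmerge(merged)                     # unmerge_chunk=0 -> src part
    assert src_recon.shape == (B, L, C)

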
def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                     w: int, h: int, sx: int, sy: int, r: int,
                                     no_rand: bool = False,
                                     generator: torch.Generator = None) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

    Args:
        - metric [B, N, C]: metric to use for similarity
        - w: image width in tokens
        - h: image height in tokens
        - sx: stride in the x dimension for dst, must divide w
        - sy: stride in the y dimension for dst, must divide h
        - r: number of tokens to remove (by merging)
        - no_rand: if true, disable randomness (use top left corner only)
        - generator: random number generator, used when no_rand is false
    """
    B, N, _ = metric.shape

    if r <= 0:
        return do_nothing, do_nothing

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        hsy, wsx = h // sy, w // sx

        # Pick one dst token inside each (sy, sx) region.
        if no_rand:
            rand_idx = torch.zeros(
                hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        else:
            rand_idx = torch.randint(
                sy * sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to(metric.device)

        # Mark the chosen dst positions with -1, then flatten back to (h, w).
        idx_buffer_view = torch.zeros(
            hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64)
        idx_buffer_view.scatter_(
            dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
        idx_buffer_view = idx_buffer_view.view(
            hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # Handle image sizes that are not divisible by the stride.
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(
                h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # Sorting moves the -1 dst entries to the front.
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        del idx_buffer, idx_buffer_view

        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst

        def split(x):
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between src and dst tokens.
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        r = min(a.shape[1], r)

        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]
        src_idx = edge_idx[..., :r, :]
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape

        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B,
                     a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B,
                     a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge


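# A hedged usage sketch for bipartite_soft_matching_random2d on a single
# 16x16 token map (the shapes and merge count are illustrative assumptions;
# no_rand=True keeps the sketch deterministic).
def _example_random2d_round_trip():
    B, h, w, C = 1, 16, 16, 8
    metric = torch.randn(B, h * w, C)
    merge, unmerge = bipartite_soft_matching_random2d(
        metric, w=w, h=h, sx=2, sy=2, r=32, no_rand=True)
    x = torch.randn(B, h * w, C)
    merged = merge(x)                               # [B, h*w - 32, C]
    restored = unmerge(merged)
    assert restored.shape == x.shape

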
def bipartite_soft_matching_2f(metric: torch.Tensor, src_len: int, ratio: float, adhere_src: bool, merge_mode: str = "replace", scores=None, coord=None, rec_field=2, unmerge_chunk=0) -> Tuple[Callable, Callable, dict]:
    """
    Partitions the tokens into src and dst and merges ratio of src tokens from src to dst.
    Src tokens are partitioned as first src_len tokens. Others are dst tokens.

    Args:
        - metric [B, N, C]: metric to use for similarity
        - src_len: src token length. [ src | dst ]: [ src_len | N - src_len ]
        - ratio: ratio of src tokens to be removed (by merging)
        - adhere_src: whether to share one matching result across the batch
        - merge_mode: how to merge tokens. "mean": tokens -> Mean(src_token, dst_token); "replace": tokens -> dst_token.
        - coord [B, N, 2]: optional token coordinates; matches farther than rec_field are masked out
        - rec_field: receptive-field radius used with coord
        - unmerge_chunk: return which partition in unmerge. 0 for src and 1 for dst.
    """
    B, N, _ = metric.shape

    if ratio <= 0:
        # Nothing merged: all src tokens stay unmerged.
        return do_nothing, do_nothing, {"unm_num": src_len}

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        idx_buffer = torch.arange(N, device=metric.device, dtype=torch.int64)

        # Fixed partition: the first src_len tokens are src, the rest dst.
        a_idx = idx_buffer[None, :src_len, None]
        b_idx = idx_buffer[None, src_len:, None]

        del idx_buffer

        num_dst = b_idx.shape[1]

        def split(x):
            b, n, c = x.shape
            src = gather(x, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(x, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        def split_coord(coord):
            b, n, c = coord.shape
            src = gather(coord, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(coord, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        # Cosine similarity between src and dst tokens.
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)

        # Optionally restrict matches to a local window of radius rec_field.
        if coord is not None:
            src_coord, dst_coord = split_coord(coord)
            mask = torch.norm(src_coord[:, :, None, :] - dst_coord[:, None, :, :], dim=-1) > rec_field

        scores = a @ b.transpose(-1, -2)

        if coord is not None:
            scores[mask] = 0

        r = int(a.shape[1] * ratio)
        r = min(a.shape[1], r)

        if adhere_src:
            # Share one matching result across the batch.
            scores = torch.cat([*scores], dim=-1)

            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

            unm_idx = edge_idx[..., r:, :]
            src_idx = edge_idx[..., :r, :]
            dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) % num_dst

            unm_idx = unm_idx.expand(B, -1, -1)
            src_idx = src_idx.expand(B, -1, -1)
            dst_idx = dst_idx.expand(B, -1, -1)
        else:
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

            unm_idx = edge_idx[..., r:, :]
            src_idx = edge_idx[..., :r, :]
            dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode=None, b_select=None) -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape
        if b_select is not None:
            if not isinstance(b_select, list):
                b_select = [b_select]
            u_idx, s_idx, d_idx = unm_idx[b_select], src_idx[b_select], dst_idx[b_select]
        else:
            u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx

        unm = gather(src, dim=-2, index=u_idx.expand(-1, -1, c))

        mode = mode if mode is not None else merge_mode
        if mode != "replace":
            # Gather the r merged src tokens before reducing them into dst.
            src = gather(src, dim=-2, index=s_idx.expand(-1, -1, c))
            dst = dst.scatter_reduce(-2, d_idx.expand(-1, -1, c), src, reduce=mode, include_self=True)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor, b_select=None, unm_modi=None) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        b, _, c = unm.shape
        if b_select is not None:
            if not isinstance(b_select, list):
                b_select = [b_select]
            u_idx, s_idx, d_idx = unm_idx[b_select], src_idx[b_select], dst_idx[b_select]
        else:
            u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx
        if unm_modi is not None:
            if unm_modi == "zero":
                unm = torch.zeros_like(unm)
        # Merged src tokens are restored from their dst counterparts.
        src = gather(dst, dim=-2, index=d_idx.expand(-1, -1, c))

        out = torch.zeros(b, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(b, -1, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), dim=1, index=u_idx).expand(-1, -1, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), dim=1, index=s_idx).expand(-1, -1, c), src=src)

        # Return only the requested partition.
        if unmerge_chunk == 0:
            out = out[:, :src_len, :]
        else:
            out = out[:, src_len:, :]

        return out

    ret_dict = {"unm_num": unm_idx.shape[1]}
    return merge, unmerge, ret_dict
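

# A hedged usage sketch for bipartite_soft_matching_2f, recovering the dst
# chunk after merging (shapes and ratio are illustrative assumptions).
def _example_2f_dst_chunk():
    B, L, C = 1, 32, 8
    tokens = torch.randn(B, 2 * L, C)               # [ src | dst ]
    merge, unmerge, info = bipartite_soft_matching_2f(
        tokens, src_len=L, ratio=0.5, adhere_src=False, unmerge_chunk=1)
    merged = merge(tokens, mode="replace")
    dst_recon = unmerge(merged)                     # dst partition only
    assert dst_recon.shape == (B, L, C)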