import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import numpy as np
from tqdm import tqdm
import math, random

try:
    from sklearn.cluster import KMeans  # needed by tensor_kmeans_sklearn below
except ImportError:
    KMeans = None

def tensor_kmeans_sklearn(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
    N,C,H,W = data_vecs.shape
    assert N == 1, 'only supports a single-image tensor'
    assert KMeans is not None, 'scikit-learn is required for tensor_kmeans_sklearn'
    ## (1,C,H,W) -> (HW,C)
    data_vecs = data_vecs.permute(0,2,3,1).view(-1,C)
    ## convert tensor to a numpy array for sklearn
    data_vecs_np = data_vecs.detach().cpu().numpy()
    km = KMeans(n_clusters=n_clusters, init='k-means++', n_init=10, max_iter=300)
    km.fit_predict(data_vecs_np)
    cluster_ids_x = torch.from_numpy(km.labels_).to(data_vecs.device)
    id_maps = cluster_ids_x.reshape(1,1,H,W).long()
    if need_layer_masks:
        one_hot_labels = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
        cluster_mask = one_hot_labels.permute(0,3,1,2)
        return cluster_mask
    return id_maps
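
## Usage sketch for the sklearn path (a minimal example, assuming scikit-learn
## is installed; the 64-dim feature map is purely illustrative):
##   feats = torch.randn(1, 64, 32, 32)                        # (N,C,H,W), N must be 1
##   masks = tensor_kmeans_sklearn(feats, n_clusters=7, need_layer_masks=True)
##   # masks: (1,7,32,32) float one-hot layers, exactly one 1 per pixel across clusters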

def tensor_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
    N,C,H,W = data_vecs.shape
    assert N == 1, 'only supports a single-image tensor'
    ## (1,C,H,W) -> (HW,C)
    data_vecs = data_vecs.permute(0,2,3,1).view(-1,C)
    ## metric: cosine | euclidean
    cluster_ids_x, cluster_centers = kmeans(X=data_vecs, num_clusters=n_clusters, distance=metric,
                                            tqdm_flag=False, iter_limit=max_iters, device=data_vecs.device)
    id_maps = cluster_ids_x.reshape(1,1,H,W)
    if need_layer_masks:
        one_hot_labels = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
        cluster_mask = one_hot_labels.permute(0,3,1,2)
        return cluster_mask
    return id_maps
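
## Usage sketch for the pure-pytorch path, returning the id map instead of masks
## (shapes are illustrative; clustering runs on the input tensor's device):
##   feats = torch.randn(1, 64, 32, 32)
##   id_maps = tensor_kmeans_pytorch(feats, n_clusters=7)      # (1,1,32,32) labels in [0,7)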

def batch_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', use_sklearn_kmeans=False):
    N,C,H,W = data_vecs.shape
    sample_list = []
    for idx in range(N):
        if use_sklearn_kmeans:
            cluster_mask = tensor_kmeans_sklearn(data_vecs[idx:idx+1,:,:,:], n_clusters, metric, True)
        else:
            cluster_mask = tensor_kmeans_pytorch(data_vecs[idx:idx+1,:,:,:], n_clusters, metric, True)
        sample_list.append(cluster_mask)
    return torch.cat(sample_list, dim=0)
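
## Usage sketch for batched inputs. Each sample is clustered independently, so
## cluster ids are NOT aligned across the batch (shapes illustrative):
##   feats = torch.randn(4, 64, 32, 32)
##   masks = batch_kmeans_pytorch(feats, n_clusters=7)         # (4,7,32,32)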

def get_centroid_candidates(data_vecs, n_clusters=7, metric='euclidean', max_iters=20):
    N,C,H,W = data_vecs.shape
    data_vecs = data_vecs.permute(0,2,3,1).view(-1,C)
    cluster_ids_x, cluster_centers = kmeans(X=data_vecs, num_clusters=n_clusters, distance=metric,
                                            tqdm_flag=False, iter_limit=max_iters, device=data_vecs.device)
    return cluster_centers

def find_distinctive_elements(data_tensor, n_clusters=7, topk=3, metric='euclidean'):
    N,C,H,W = data_tensor.shape
    centroid_list = []
    for idx in range(N):
        cluster_centers = get_centroid_candidates(data_tensor[idx:idx+1,:,:,:], n_clusters, metric)
        centroid_list.append(cluster_centers)
    batch_centroids = torch.stack(centroid_list, dim=0)    # (N,K,C)
    data_vecs = data_tensor.flatten(2)                     # (N,C,HW)
    ## squared distance map (N,K,HW) via ||a-b||^2 = ||a||^2 - 2<a,b> + ||b||^2,
    ## computed without materializing the (HW,HW) Gram matrix
    AtB = torch.matmul(batch_centroids, data_vecs)         # (N,K,HW)
    A2 = (batch_centroids ** 2).sum(dim=2, keepdim=True)   # (N,K,1)
    B2 = (data_vecs ** 2).sum(dim=1, keepdim=True)         # (N,1,HW)
    distance_map = A2 - 2*AtB + B2
    ## keep the topk pixels closest to each centroid; ties at the kth value may add extras
    values, _ = distance_map.topk(topk, dim=2, largest=False, sorted=True)
    cluster_mask = torch.where(distance_map <= values[:,:,topk-1:], torch.ones_like(distance_map), torch.zeros_like(distance_map))
    cluster_mask = cluster_mask.view(N,n_clusters,H,W)
    return cluster_mask
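
## Usage sketch: unlike batch_kmeans_pytorch, this yields a sparse "distinctive
## element" mask rather than a full partition (shapes illustrative):
##   feats = torch.randn(2, 64, 32, 32)
##   sparse_masks = find_distinctive_elements(feats, n_clusters=7, topk=3)  # (2,7,32,32)
##   # each (n,k) slice holds ~topk ones marking the pixels nearest centroid k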

##---------------------------------------------------------------------------------
'''
adapted from github: https://github.com/subhadarship/kmeans_pytorch
'''
##---------------------------------------------------------------------------------

def initialize(X, num_clusters):
    """
    initialize cluster centers by sampling rows of X
    :param X: (torch.tensor) matrix
    :param num_clusters: (int) number of clusters
    :return: (torch.tensor) initial state
    """
    np.random.seed(1)  # fixed seed: initialization is deterministic across calls
    num_samples = len(X)
    indices = np.random.choice(num_samples, num_clusters, replace=False)
    initial_state = X[indices]
    return initial_state

def kmeans(
        X,
        num_clusters,
        distance='euclidean',
        cluster_centers=[],
        tol=1e-4,
        tqdm_flag=True,
        iter_limit=0,
        device=torch.device('cpu'),
        gamma_for_soft_dtw=0.001
):
    """
    perform kmeans
    :param X: (torch.tensor) matrix
    :param num_clusters: (int) number of clusters
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param cluster_centers: (torch.tensor) centers to resume from [default: [] -> random init]
    :param tol: (float) threshold [default: 0.0001]
    :param tqdm_flag: allows turning logs on and off
    :param iter_limit: hard limit for max number of iterations (0 = no limit)
    :param device: (torch.device) device [default: cpu]
    :param gamma_for_soft_dtw: approaches (hard) DTW as gamma -> 0 (unused here)
    :return: (torch.tensor, torch.tensor) cluster ids, cluster centers
    """
    if tqdm_flag:
        print(f'running k-means on {device}..')
    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    else:
        raise NotImplementedError
    # convert to float and transfer to device
    X = X.float().to(device)
    # initialize
    if isinstance(cluster_centers, list):  # empty list -> start from randomly sampled points
        initial_state = initialize(X, num_clusters)
    else:
        if tqdm_flag:
            print('resuming')
        # snap each given center to its closest data point
        initial_state = cluster_centers
        dis = pairwise_distance_function(X, initial_state)
        choice_points = torch.argmin(dis, dim=0)
        initial_state = X[choice_points]
        initial_state = initial_state.to(device)
    iteration = 0
    if tqdm_flag:
        tqdm_meter = tqdm(desc='[running kmeans]')
    while True:
        dis = pairwise_distance_function(X, initial_state)
        choice_cluster = torch.argmin(dis, dim=1)
        initial_state_pre = initial_state.clone()
        for index in range(num_clusters):
            # squeeze(1), not squeeze(): a cluster with a single member must stay a 1-D index
            # https://github.com/subhadarship/kmeans_pytorch/issues/16
            selected = torch.nonzero(choice_cluster == index).squeeze(1).to(device)
            selected = torch.index_select(X, 0, selected)
            # re-seed an empty cluster with a random point
            if selected.shape[0] == 0:
                selected = X[torch.randint(len(X), (1,))]
            initial_state[index] = selected.mean(dim=0)
        center_shift = torch.sum(
            torch.sqrt(
                torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
            ))
        # increment iteration
        iteration = iteration + 1
        # update tqdm meter
        if tqdm_flag:
            tqdm_meter.set_postfix(
                iteration=f'{iteration}',
                center_shift=f'{center_shift ** 2:0.6f}',
                tol=f'{tol:0.6f}'
            )
            tqdm_meter.update()
        if center_shift ** 2 < tol:
            break
        if iter_limit != 0 and iteration >= iter_limit:
            break
    return choice_cluster.to(device), initial_state.to(device)
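
## Usage sketch (a minimal example; data shape and cluster count are illustrative):
##   x = torch.randn(1000, 16)
##   ids, centers = kmeans(X=x, num_clusters=5, distance='euclidean',
##                         tqdm_flag=False, iter_limit=50, device=torch.device('cpu'))
##   # ids: (1000,) cluster assignment per row, centers: (5,16)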

def kmeans_predict(
        X,
        cluster_centers,
        distance='euclidean',
        device=torch.device('cpu'),
        gamma_for_soft_dtw=0.001,
        tqdm_flag=True
):
    """
    predict using cluster centers
    :param X: (torch.tensor) matrix
    :param cluster_centers: (torch.tensor) cluster centers
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param device: (torch.device) device [default: 'cpu']
    :param gamma_for_soft_dtw: approaches (hard) DTW as gamma -> 0
    :return: (torch.tensor) cluster ids
    """
    if tqdm_flag:
        print(f'predicting on {device}..')
    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    elif distance == 'soft_dtw':
        # note: SoftDTW / pairwise_soft_dtw are not defined in this file; the
        # 'soft_dtw' option needs the upstream repo's soft-DTW implementation
        sdtw = SoftDTW(use_cuda=device.type == 'cuda', gamma=gamma_for_soft_dtw)
        pairwise_distance_function = partial(pairwise_soft_dtw, sdtw=sdtw, device=device)
    else:
        raise NotImplementedError
    # convert to float and transfer to device
    X = X.float().to(device)
    dis = pairwise_distance_function(X, cluster_centers)
    choice_cluster = torch.argmin(dis, dim=1)
    return choice_cluster.cpu()
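
## Usage sketch: assign new points to centers produced by kmeans() above
## (illustrative shapes; the distance must match the one used for fitting):
##   new_ids = kmeans_predict(torch.randn(100, 16), centers, distance='euclidean', tqdm_flag=False)
##   # new_ids: (100,) on cpu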

def pairwise_distance(data1, data2, device=torch.device('cpu'), tqdm_flag=True):
    if tqdm_flag:
        print(f'device is :{device}')
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)
    # (N,D) -> (N,1,D)
    A = data1.unsqueeze(dim=1)
    # (K,D) -> (1,K,D)
    B = data2.unsqueeze(dim=0)
    dis = (A - B) ** 2.0
    # sum over the feature dim -> (N,K) matrix of squared euclidean distances
    dis = dis.sum(dim=-1).squeeze()
    return dis
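
## Sanity-check sketch (not part of the original API): for 2-D inputs this
## matches torch.cdist(...) squared, up to float error:
##   a, b = torch.randn(10, 4), torch.randn(3, 4)
##   assert torch.allclose(pairwise_distance(a, b, tqdm_flag=False),
##                         torch.cdist(a, b) ** 2, atol=1e-4)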

def pairwise_cosine(data1, data2, device=torch.device('cpu')):
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)
    # (N,D) -> (N,1,D)
    A = data1.unsqueeze(dim=1)
    # (K,D) -> (1,K,D)
    B = data2.unsqueeze(dim=0)
    # normalize the points | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5]
    A_normalized = A / A.norm(dim=-1, keepdim=True)
    B_normalized = B / B.norm(dim=-1, keepdim=True)
    cosine = A_normalized * B_normalized
    # cosine distance = 1 - cosine similarity -> (N,K) matrix
    cosine_dis = 1 - cosine.sum(dim=-1).squeeze()
    return cosine_dis
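
## A minimal end-to-end sketch, runnable when this file is executed directly.
## All shapes, seeds, and cluster counts below are illustrative assumptions,
## not values prescribed by any caller of this module:
if __name__ == '__main__':
    torch.manual_seed(0)
    feats = torch.randn(2, 16, 8, 8)                   # toy (N,C,H,W) feature maps
    masks = batch_kmeans_pytorch(feats, n_clusters=4)  # (2,4,8,8) one-hot layer masks
    print('layer masks:', masks.shape, 'ones per image:', masks.sum(dim=(1,2,3)))
    sparse = find_distinctive_elements(feats, n_clusters=4, topk=3)
    print('distinctive-element masks:', sparse.shape)  # (2,4,8,8), ~topk ones per cluster slice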