Spaces:
Build error
Build error
| from openTSNE import TSNE | |
| import numpy as np | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| import random | |
def visualize(
    x,
    y,
    ax=None,
    title=None,
    draw_legend=True,
    draw_centers=False,
    draw_cluster_labels=False,
    colors=None,
    legend_kwargs=None,
    label_order=None,
    **kwargs
):
    """Scatter-plot a 2-D embedding, coloring every point by its class label.

    Parameters
    ----------
    x : array-like, shape (n_samples, >= 2)
        Embedding coordinates; only the first two columns are drawn.
    y : sequence of length n_samples
        Class label of each row of ``x``. Lists and arrays are both accepted.
    ax : matplotlib Axes, optional
        Axes to draw on. A new 10x8 inch figure is created when omitted.
    title : str, optional
        Axes title; left unset when ``None``.
    draw_legend : bool
        Add one square legend entry per class.
    draw_centers : bool
        Mark the coordinate-wise median of every class with a large dot.
    draw_cluster_labels : bool
        Write each class label slightly above its median position. Works
        independently of ``draw_centers``.
    colors : mapping of label -> color, optional
        Defaults to matplotlib's ``axes.prop_cycle`` colors, assigned to the
        classes in plotting order.
    legend_kwargs : dict, optional
        Overrides for the default ``ax.legend`` keyword arguments.
    label_order : sequence, optional
        Order in which classes are colored and listed in the legend. Must
        contain every label that occurs in ``y``.
    **kwargs
        Only ``alpha`` (default 0.6), ``s`` (default 1) and ``fontsize``
        (default 6) are read; any other keyword is ignored.

    Raises
    ------
    ValueError
        If ``label_order`` is given but lacks a label present in ``y``.
    """
    x = np.asarray(x)
    # Comparing a label against a plain Python list does not broadcast, so
    # the per-class masks below need an array.
    y = np.asarray(y)

    if ax is None:
        _, ax = matplotlib.pyplot.subplots(figsize=(10, 8))

    if title is not None:
        ax.set_title(title)

    plot_params = {"alpha": kwargs.get("alpha", 0.6), "s": kwargs.get("s", 1)}

    # Decide which classes are drawn and in which order.
    present = np.unique(y)
    if label_order is not None:
        missing = present[~np.isin(present, label_order)]
        if missing.size:
            raise ValueError(
                f"label_order is missing labels present in y: {missing.tolist()}"
            )
        classes = [label for label in label_order if label in present]
    else:
        classes = present

    if colors is None:
        default_colors = matplotlib.rcParams["axes.prop_cycle"]
        colors = {k: v["color"] for k, v in zip(classes, default_colors())}

    # Main scatter plot.
    point_colors = list(map(colors.get, y))
    ax.scatter(x[:, 0], x[:, 1], c=point_colors, rasterized=True, **plot_params)

    # Per-class centers (coordinate-wise medians), needed by either option.
    if draw_centers or draw_cluster_labels:
        centers = np.array(
            [np.median(x[y == label, :2], axis=0) for label in classes]
        ).reshape(-1, 2)  # keeps a (0, 2) shape when there are no classes

        if draw_centers:
            center_colors = list(map(colors.get, classes))
            ax.scatter(
                centers[:, 0],
                centers[:, 1],
                c=center_colors,
                s=48,
                alpha=1,
                edgecolor="k",
            )

        if draw_cluster_labels:
            for (center_x, center_y), label in zip(centers, classes):
                ax.text(
                    center_x,
                    center_y + 2.2,  # nudge the text above the center marker
                    label,
                    fontsize=kwargs.get("fontsize", 6),
                    horizontalalignment="center",
                )

    # Hide ticks and axis.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis("off")

    if draw_legend:
        legend_handles = [
            matplotlib.lines.Line2D(
                [],
                [],
                marker="s",
                color="w",
                markerfacecolor=colors[yi],
                ms=10,
                alpha=1,
                linewidth=0,
                label=yi,
                markeredgecolor="k",
            )
            for yi in classes
        ]
        legend_kwargs_ = dict(loc="best", bbox_to_anchor=(0.05, 0.5), frameon=False)
        if legend_kwargs is not None:
            legend_kwargs_.update(legend_kwargs)
        ax.legend(handles=legend_handles, **legend_kwargs_)
# Jointly embed four sets of 204-dimensional idexp_lm3d vectors with t-SNE and
# save a scatter plot colored by the set each vector came from.
tsne = TSNE(
    perplexity=30,
    metric="euclidean",
    n_jobs=8,
    random_state=42,
    verbose=True,
)

# NOTE(security): np.load(..., allow_pickle=True) unpickles the file contents,
# which can execute arbitrary code -- only point these paths at trusted files.

# Subsample the LRS3 predictions. Rows are drawn without replacement (no
# duplicated points in the embedding) from a seeded generator, so the plot is
# reproducible together with the seeded TSNE above. At most 10000 rows are
# kept; a smaller dataset is used in full.
idexp_lm3d_pred_lrs3 = np.load("infer_out/tmp_npys/lrs3_pred_all.npy")
rng = np.random.default_rng(42)
n_subsample = min(10000, len(idexp_lm3d_pred_lrs3))
idx = rng.choice(len(idexp_lm3d_pred_lrs3), size=n_subsample, replace=False)
idexp_lm3d_pred_lrs3 = idexp_lm3d_pred_lrs3[idx]

# Person-specific dataset: a pickled dict holding normalization statistics and
# per-sample normalized landmarks.
person_ds = np.load("data/binary/videos/May/trainval_dataset.npy", allow_pickle=True).tolist()
person_idexp_mean = person_ds['idexp_lm3d_mean'].reshape([1, 204])
person_idexp_std = person_ds['idexp_lm3d_std'].reshape([1, 204])
person_idexp_lm3d_train = np.stack(
    [s['idexp_lm3d_normalized'].reshape([204]) for s in person_ds['train_samples']]
)
person_idexp_lm3d_val = np.stack(
    [s['idexp_lm3d_normalized'].reshape([204]) for s in person_ds['val_samples']]
)

# LRS3 normalization statistics.
# NOTE(review): absolute, machine-specific path -- adjust for your environment.
lrs3_stats = np.load('/home/yezhenhui/datasets/binary/lrs3_0702/stats.npy', allow_pickle=True).tolist()
lrs3_idexp_mean = lrs3_stats['idexp_lm3d_mean'].reshape([1, 204])
lrs3_idexp_std = lrs3_stats['idexp_lm3d_std'].reshape([1, 204])

# Undo each dataset's own normalization so all sets share one scale.
# (Re-normalizing the person data with the LRS3 statistics was tried in an
# earlier revision and is intentionally not applied.)
person_idexp_lm3d_train = person_idexp_lm3d_train * person_idexp_std + person_idexp_mean
# The validation split is de-normalized too but is not part of the embedding.
person_idexp_lm3d_val = person_idexp_lm3d_val * person_idexp_std + person_idexp_mean
idexp_lm3d_pred_lrs3 = idexp_lm3d_pred_lrs3 * lrs3_idexp_std + lrs3_idexp_mean

# Model predictions, used as stored.
# NOTE(review): no LRS3 de-normalization is applied to these two arrays --
# confirm the files already hold un-normalized landmarks.
idexp_lm3d_pred_vae = np.load("infer_out/tmp_npys/pred_exp_0_vae.npy").reshape([-1, 204])
idexp_lm3d_pred_postnet = np.load("infer_out/tmp_npys/pred_exp_0_postnet_hubert.npy").reshape([-1, 204])

idexp_lm3d_all = np.concatenate(
    [
        idexp_lm3d_pred_lrs3,
        person_idexp_lm3d_train,
        idexp_lm3d_pred_vae,
        idexp_lm3d_pred_postnet,
    ]
)
idexp_lm3d_all_emb = tsne.fit(idexp_lm3d_all)  # [N, 204] -> [N, 2]

# One label per row, in the same order as the concatenation above.
y1 = ["pred_lrs3"] * len(idexp_lm3d_pred_lrs3)
y2 = ["person_train"] * len(person_idexp_lm3d_train)
y3 = ["vae"] * len(idexp_lm3d_pred_vae)
y4 = ["postnet"] * len(idexp_lm3d_pred_postnet)

visualize(idexp_lm3d_all_emb, y1 + y2 + y3 + y4)
plt.savefig("infer_out/tmp_npys/lrs3_pred_all_0k.png")