Spaces:
Sleeping
Sleeping
| from functools import partial | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| from matplotlib.ticker import NullFormatter | |
| import numpy as np | |
| from sklearn import datasets, manifold | |
# One seed shared by every dataset generator so repeated runs give the
# same points.
SEED = 0
# Dimensionality of the t-SNE embedding (2 so it can be drawn as a scatter).
N_COMPONENTS = 2
# Also seed NumPy's global RNG for any code that draws from it implicitly.
np.random.seed(SEED)
def get_circles(n_samples):
    """Generate the noisy concentric-circles toy dataset.

    Args:
        n_samples: Number of points forwarded to ``make_circles``.

    Returns:
        The ``(X, color)`` pair produced by ``datasets.make_circles`` with
        ``factor=0.5``, ``noise=0.05`` and the module-wide ``SEED``.
    """
    circle_options = {
        "n_samples": n_samples,
        "factor": 0.5,
        "noise": 0.05,
        "random_state": SEED,
    }
    return datasets.make_circles(**circle_options)
def get_s_curve(n_samples):
    """Generate the S-curve toy dataset with its 2nd and 3rd columns swapped.

    Args:
        n_samples: Number of points forwarded to ``make_s_curve``.

    Returns:
        ``(X, color)`` from ``datasets.make_s_curve`` seeded with ``SEED``,
        where columns 1 and 2 of ``X`` have been exchanged.
    """
    points, color = datasets.make_s_curve(n_samples=n_samples, random_state=SEED)
    # Fancy indexing on the right-hand side copies the data first, so the
    # two columns are exchanged without one overwriting the other.
    points[:, [1, 2]] = points[:, [2, 1]]
    return points, color
def get_uniform_grid(n_samples):
    """Build a square, evenly spaced grid of 2-D points on [0, 1] x [0, 1].

    Args:
        n_samples: Requested number of points. The grid side is
            ``int(sqrt(n_samples))``, so the actual point count is that
            side squared (e.g. 150 yields a 12 x 12 grid of 144 points).

    Returns:
        ``(X, color)`` where ``X`` has one row per grid point (x in column 0,
        y in column 1) and ``color`` is the x coordinate of each point.
    """
    side = int(np.sqrt(n_samples))
    ticks = np.linspace(0, 1, side)
    grid_x, grid_y = np.meshgrid(ticks, ticks)
    flat_x = grid_x.ravel()
    flat_y = grid_y.ravel()
    # Stack the flattened coordinates as the two columns of the point matrix.
    X = np.column_stack((flat_x, flat_y))
    color = grid_x.ravel()
    return X, color
# Maps the label shown in the dataset selector to the function that
# generates that dataset; insertion order is the order the choices appear in.
DATA_MAPPING = {
    "Circles": get_circles,
    "S-curve": get_s_curve,
    "Uniform Grid": get_uniform_grid,
}
def _tsne_iteration_kwargs(n_iterations):
    """Return the TSNE iteration-budget keyword under the name this scikit-learn accepts."""
    import inspect

    # scikit-learn 1.5 renamed ``n_iter`` to ``max_iter`` and later dropped the
    # old name; prefer the new keyword whenever the installed TSNE offers it.
    accepted = inspect.signature(manifold.TSNE).parameters
    keyword = "max_iter" if "max_iter" in accepted else "n_iter"
    return {keyword: n_iterations}


def plot_data(dataset: str, perplexity: int, n_samples: int, tsne: bool):
    """Draw a scatter plot of a toy dataset or of its t-SNE embedding.

    Args:
        dataset: Key of ``DATA_MAPPING`` selecting the dataset generator.
        perplexity: t-SNE perplexity. A dict carrying the number under the
            ``'value'`` key is also accepted and unwrapped. Only used when
            ``tsne`` is true; values too large for the dataset are clamped
            to ``number_of_points - 1``.
        n_samples: Number of points requested from the generator (coerced
            to ``int``).
        tsne: When true, plot the ``N_COMPONENTS``-dimensional t-SNE
            embedding; otherwise plot the first two columns of the raw data.

    Returns:
        The matplotlib figure containing the scatter plot.

    Raises:
        KeyError: If ``dataset`` is not a key of ``DATA_MAPPING``.
    """
    if isinstance(perplexity, dict):
        perplexity = perplexity['value']
    else:
        perplexity = int(perplexity)

    # The scikit-learn generators validate n_samples as an integer, while a
    # UI slider may hand over a float.
    X, color = DATA_MAPPING[dataset](int(n_samples))

    if tsne:
        # scikit-learn rejects perplexity >= number of samples, which the
        # sliders allow (e.g. 100 samples with perplexity 100): clamp it.
        perplexity = min(perplexity, len(X) - 1)
        model = manifold.TSNE(
            n_components=N_COMPONENTS,
            init="random",
            random_state=SEED,
            perplexity=perplexity,
            **_tsne_iteration_kwargs(400),
        )
        Y = model.fit_transform(X)
    else:
        Y = X

    fig, ax = plt.subplots(figsize=(7, 7))
    ax.scatter(Y[:, 0], Y[:, 1], c=color)
    # Tick labels carry no meaning for an embedding, so hide them.
    ax.xaxis.set_major_formatter(NullFormatter())
    ax.yaxis.set_major_formatter(NullFormatter())
    ax.axis("tight")
    # Drop the figure from pyplot's global registry so repeated callbacks do
    # not accumulate open figures; the Figure object itself stays usable.
    plt.close(fig)
    return fig
# Page heading; also used as the browser-tab title of the demo.
title = "t-SNE: The effect of various perplexity values on the shape"
# Markdown introduction rendered under the heading.
description = """
t-distributed Stochastic Neighbor Embedding ([t-SNE](https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html)) is a powerful technique for dimensionality reduction and visualization of high-dimensional datasets.
One of the key parameters in t-SNE is perplexity, which controls the number of nearest neighbors used to represent each data point in the low-dimensional space.
In this illustration, we explore the impact of various perplexity values on t-SNE visualizations using three commonly used datasets: Concentric Circles, S-curve and Uniform Grid.
By comparing the resulting visualizations, we demonstrate how changing the perplexity value affects the shape of the visualization.
Created by [@Hnabil](https://huggingface.co/Hnabil) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/manifold/plot_t_sne_perplexity.html)
"""
# Page layout: heading and intro, three controls, then the raw data and its
# t-SNE embedding side by side.
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<b>{title}</b>")
    gr.Markdown(description)

    input_data = gr.Radio(list(DATA_MAPPING), value="Circles", label="dataset")
    n_samples = gr.Slider(
        minimum=100, maximum=1000, value=150, step=25, label='Number of Samples'
    )
    perplexity = gr.Slider(
        minimum=2, maximum=100, value=5, step=1, label='Perplexity'
    )
    # Every callback receives the same three control values, in this order.
    controls = [input_data, perplexity, n_samples]

    with gr.Row():
        with gr.Column():
            original_plot = gr.Plot(label="Original data")
            draw_original = partial(plot_data, tsne=False)
            # The raw-data view is not re-drawn on perplexity changes.
            for register in (input_data.change, n_samples.change, demo.load):
                register(fn=draw_original, inputs=controls, outputs=original_plot)
        with gr.Column():
            tsne_plot = gr.Plot(label="t-SNE")
            draw_tsne = partial(plot_data, tsne=True)
            for register in (
                input_data.change,
                perplexity.change,
                n_samples.change,
                demo.load,
            ):
                register(fn=draw_tsne, inputs=controls, outputs=tsne_plot)

demo.launch()