Spaces:
Runtime error
Runtime error
Update visualization.py
Browse files- visualization.py +33 -1
visualization.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import matplotlib.pyplot as plt
|
|
|
|
| 2 |
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
| 3 |
import matplotlib.colors as mcolors
|
| 4 |
from matplotlib.colors import LinearSegmentedColormap
|
|
@@ -315,4 +316,35 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
|
|
| 315 |
return heatmap_video_path
|
| 316 |
else:
|
| 317 |
print(f"Failed to create heatmap video at: {heatmap_video_path}")
|
| 318 |
-
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import matplotlib.pyplot as plt
|
| 2 |
+
from mpl_toolkits.mplot3d import Axes3D
|
| 3 |
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
| 4 |
import matplotlib.colors as mcolors
|
| 5 |
from matplotlib.colors import LinearSegmentedColormap
|
|
|
|
| 316 |
return heatmap_video_path
|
| 317 |
else:
|
| 318 |
print(f"Failed to create heatmap video at: {heatmap_video_path}")
|
| 319 |
+
return None
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
# Function to create the correlation heatmap
|
| 323 |
+
def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
|
| 324 |
+
mse_data = {
|
| 325 |
+
'Facial Features MSE': mse_embeddings,
|
| 326 |
+
'Body Posture MSE': mse_posture,
|
| 327 |
+
'Voice MSE': mse_voice
|
| 328 |
+
}
|
| 329 |
+
mse_df = pd.DataFrame(mse_data)
|
| 330 |
+
correlation_matrix = mse_df.corr()
|
| 331 |
+
|
| 332 |
+
plt.figure(figsize=(8, 6))
|
| 333 |
+
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
|
| 334 |
+
plt.title("Correlation Heatmap of MSEs")
|
| 335 |
+
plt.close()
|
| 336 |
+
return plt.gcf()
|
| 337 |
+
|
| 338 |
+
# Function to create the 3D scatter plot
|
| 339 |
+
def plot_3d_scatter(mse_embeddings, mse_posture, mse_voice):
|
| 340 |
+
fig = plt.figure(figsize=(10, 8))
|
| 341 |
+
ax = fig.add_subplot(111, projection='3d')
|
| 342 |
+
ax.scatter(mse_posture, mse_embeddings, mse_voice, c='b', marker='o')
|
| 343 |
+
|
| 344 |
+
ax.set_xlabel('Body Posture MSE')
|
| 345 |
+
ax.set_ylabel('Facial Features MSE')
|
| 346 |
+
ax.set_zlabel('Voice MSE')
|
| 347 |
+
ax.set_title('3D Scatter Plot of MSEs')
|
| 348 |
+
|
| 349 |
+
plt.close()
|
| 350 |
+
return fig
|