import gradio as gr
import spaces
import torch
import pandas as pd
import plotly.graph_objects as go
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator, SequentialEvaluator
from sentence_transformers.util import cos_sim
# Select GPU if available (resolved once at startup); the tiny tensor op below
# confirms the chosen device is actually usable.
device = "cuda" if torch.cuda.is_available() else "cpu"
zero = torch.tensor([0]).to(device)
print(f"Device being used: {zero.device}")
@spaces.GPU  # ZeroGPU: request a GPU for the duration of this call
def evaluate_model(model_id):
    model = SentenceTransformer(model_id, device=device)
    matryoshka_dimensions = [768, 512, 256, 128, 64]

    # Datasets to evaluate on. "size" is informational only; sampling below
    # uses "sample_size".
    datasets_info = [
        {
            "name": "Financial",
            # (sic) "finanical" matches the published dataset ID on the Hub
            "dataset_id": "Omartificial-Intelligence-Space/Arabic-finanical-rag-embedding-dataset",
            "split": "train",
            "size": 7000,
            "columns": ("question", "context"),
            "sample_size": 500
        },
        {
            "name": "MLQA",
            "dataset_id": "google/xtreme",
            "subset": "MLQA.ar.ar",
            "split": "validation",
            "size": 500,
            "columns": ("question", "context"),
            "sample_size": 500
        },
        {
            "name": "ARCD",
            "dataset_id": "hsseinmz/arcd",
            "split": "train",
            "size": None,
            "columns": ("question", "context"),
            "sample_size": 500,
            "last_rows": True  # take the last 500 rows instead of the first
        }
    ]
    evaluation_results = []
    scores_by_dataset = {}

    for dataset_info in datasets_info:
        # Load the dataset, passing the subset/config name if one is given
        if "subset" in dataset_info:
            dataset = load_dataset(dataset_info["dataset_id"], dataset_info["subset"], split=dataset_info["split"])
        else:
            dataset = load_dataset(dataset_info["dataset_id"], split=dataset_info["split"])

        # Sample either the last `sample_size` rows or the first ones
        if dataset_info.get("last_rows"):
            start = max(0, len(dataset) - dataset_info["sample_size"])
            dataset = dataset.select(range(start, len(dataset)))
        else:
            dataset = dataset.select(range(min(dataset_info["sample_size"], len(dataset))))

        # Rename columns to the anchor/positive convention used below
        dataset = dataset.rename_column(dataset_info["columns"][0], "anchor")
        dataset = dataset.rename_column(dataset_info["columns"][1], "positive")

        # Add a sequential "id" column unless the dataset already has one
        if "id" not in dataset.column_names:
            dataset = dataset.add_column("id", list(range(len(dataset))))

        # Build the retrieval corpus and queries, keyed by id
        corpus = dict(zip(dataset["id"], dataset["positive"]))
        queries = dict(zip(dataset["id"], dataset["anchor"]))

        # Each query has exactly one relevant document: its own paired context
        relevant_docs = {q_id: [q_id] for q_id in queries}
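        # Illustrative shape of the structures built above (hypothetical values):
        #   corpus        = {0: "First context ...", 1: "Second context ...", ...}
        #   queries       = {0: "First question ...", 1: "Second question ...", ...}
        #   relevant_docs = {0: [0], 1: [1], ...}  # query i's only relevant doc is context i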
        # One InformationRetrievalEvaluator per Matryoshka dimension; truncate_dim
        # slices embeddings to their first `dim` components before scoring
        matryoshka_evaluators = []
        for dim in matryoshka_dimensions:
            ir_evaluator = InformationRetrievalEvaluator(
                queries=queries,
                corpus=corpus,
                relevant_docs=relevant_docs,
                name=f"dim_{dim}",
                truncate_dim=dim,
                score_functions={"cosine": cos_sim},
            )
            matryoshka_evaluators.append(ir_evaluator)

        evaluator = SequentialEvaluator(matryoshka_evaluators)
        results = evaluator(model)
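        # With recent sentence-transformers releases (v3+), SequentialEvaluator
        # returns one merged dict keyed per sub-evaluator, e.g. (illustrative values):
        #   {"dim_768_cosine_ndcg@10": 0.71, "dim_512_cosine_ndcg@10": 0.70, ...,
        #    "sequential_score": ...}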
        scores = []
        for dim in matryoshka_dimensions:
            key = f"dim_{dim}_cosine_ndcg@10"
            score = results.get(key)
            evaluation_results.append({
                "Dataset": dataset_info["name"],
                "Dimension": dim,
                "Score": score
            })
            scores.append(score)

        # Keep per-dataset scores for the bar charts below
        scores_by_dataset[dataset_info["name"]] = scores
    # Convert results to a DataFrame for display
    result_df = pd.DataFrame(evaluation_results)

    # One bar chart per dataset, one bar per embedding dimension
    charts = []
    color_scale = ['#003f5c', '#2f4b7c', '#665191', '#a05195', '#d45087']
    for dataset_name, scores in scores_by_dataset.items():
        fig = go.Figure()
        fig.add_trace(go.Bar(
            x=[str(dim) for dim in matryoshka_dimensions],
            y=scores,
            marker_color=color_scale,
            text=[f"{score:.3f}" if score is not None else "N/A" for score in scores],
            textposition='auto'
        ))
        fig.update_layout(
            title=f"{dataset_name} Evaluation",
            xaxis_title="Embedding Dimension",
            yaxis_title="NDCG@10 Score",
            template="plotly_white"
        )
        charts.append(fig)

    return result_df, charts[0], charts[1], charts[2]
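# Hypothetical standalone usage outside Gradio (the model ID is only an example;
# any 768-dimensional SentenceTransformer checkpoint fits the dimensions above):
#
#   df, fin_fig, mlqa_fig, arcd_fig = evaluate_model("sentence-transformers/all-mpnet-base-v2")
#   print(df.pivot(index="Dimension", columns="Dataset", values="Score"))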
# Define the Gradio interface; the callback simply forwards to evaluate_model
def display_results(model_name):
    return evaluate_model(model_name)
demo = gr.Interface(
    fn=display_results,
    inputs=gr.Textbox(
        label="Enter a Hugging Face Model ID",
        placeholder="e.g., sentence-transformers/all-MiniLM-L6-v2"
    ),
    outputs=[
        gr.Dataframe(label="Evaluation Results"),
        gr.Plot(label="Financial Dataset"),
        gr.Plot(label="MLQA Dataset"),
        gr.Plot(label="ARCD Dataset")
    ],
    title="Arabic Embedding Evaluation",
    description=(
        "Evaluate your Sentence Transformer model on **Arabic retrieval tasks** using Matryoshka embeddings. "
        "Compare performance across financial, long-context, and short-context datasets.\n\n"
        "The evaluation uses **NDCG@10** to measure how well the model retrieves the relevant context for each question. "
        "Embeddings are truncated from 768 dimensions down to 64; models with fewer native dimensions are capped at their native size."
    ),
    theme="default",
    live=False,
    css="footer {visibility: hidden;}"
)
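# A minimal sketch of the Matryoshka truncation the description refers to,
# assuming some 768-dimensional `model` (names here are illustrative, not part of the app):
#
#   emb = model.encode(["ما هي عاصمة فرنسا؟"])  # shape (1, 768)
#   emb64 = emb[:, :64]                          # keep only the first 64 components
#   cos_sim(emb64, emb64)                        # retrieval then scores these truncated vectors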
demo.launch(share=True)  # note: share=True is ignored when running on a Hugging Face Space