add: simple leaderboard
- simple_leaderboard.py  +131 -0
- src/phoneme_eval.py  +43 -29
simple_leaderboard.py
ADDED
@@ -0,0 +1,131 @@
import os
import glob
import json
import pandas as pd
import gradio as gr


ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
EVAL_RESULTS_DIR = os.path.join(ROOT_DIR, "eval-results")


def load_results(results_dir: str) -> pd.DataFrame:
    rows = []
    all_dataset_keys = set()

    if not os.path.isdir(results_dir):
        return pd.DataFrame(columns=["Model", "Avg PER", "Avg Duration (s)"])

    # First pass: collect all dataset keys from all files
    for path in glob.glob(os.path.join(results_dir, "*.json")):
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            res = data.get("results", {})
            all_dataset_keys.update(res.keys())
        except Exception:
            continue

    # Use dataset keys directly as display names
    dataset_display_names = {key: key for key in all_dataset_keys}

    # Second pass: extract data
    for path in glob.glob(os.path.join(results_dir, "*.json")):
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            cfg = data.get("config", {})
            res = data.get("results", {})

            model_name = cfg.get("model_name", "unknown")

            # Extract PER for each dataset dynamically
            per_values = {}
            dur_values = []

            for dataset_key in all_dataset_keys:
                dataset_data = res.get(dataset_key, {})
                per_value = dataset_data.get("per") if dataset_data else None
                dur_value = dataset_data.get("avg_duration") if dataset_data else None

                display_name = dataset_display_names[dataset_key]
                per_values[f"PER {display_name}"] = per_value

                if dur_value is not None:
                    dur_values.append(dur_value)

            # Calculate average PER across all datasets
            per_vals = [v for v in per_values.values() if v is not None]
            avg_per = sum(per_vals) / len(per_vals) if per_vals else None

            # Calculate average duration
            avg_dur = sum(dur_values) / len(dur_values) if dur_values else None

            row = {
                "Model": model_name,
                "Avg PER": avg_per,
                "Avg Duration (s)": avg_dur,
                "_file": os.path.basename(path),
            }
            row.update(per_values)
            rows.append(row)

        except Exception:
            continue

    df = pd.DataFrame(rows)
    if df.empty:
        # Create default columns based on discovered datasets
        default_cols = ["Model", "Avg PER", "Avg Duration (s)"]
        for key in sorted(all_dataset_keys):
            display_name = dataset_display_names[key]
            default_cols.insert(-2, f"PER {display_name}")
        return pd.DataFrame(columns=default_cols)

    df = df.sort_values(by=["Avg PER"], ascending=True, na_position="last")
    return df.reset_index(drop=True)


def build_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Simple Phoneme Leaderboard")
        info = gr.Markdown(f"Results directory: `{EVAL_RESULTS_DIR}`")

        # Get initial data to determine columns dynamically
        initial_df = load_results(EVAL_RESULTS_DIR)
        if not initial_df.empty:
            headers = list(initial_df.columns)
            # Remove internal columns
            headers = [h for h in headers if not h.startswith('_')]
        else:
            headers = ["Model", "Avg PER", "Avg Duration (s)"]

        table = gr.Dataframe(headers=headers, row_count=5)

        def refresh():
            df = load_results(EVAL_RESULTS_DIR)
            if df.empty:
                return df

            # Get the column order from the dataframe
            cols = [c for c in df.columns if not c.startswith('_')]

            # Ensure all columns exist for the dataframe component
            for c in cols:
                if c not in df.columns:
                    df[c] = None
            return df[cols].round(3)

        btn = gr.Button("Refresh")
        btn.click(fn=refresh, outputs=table)

        # Auto-load on start
        table.value = refresh()
    return demo


if __name__ == "__main__":
    demo = build_interface()
    demo.queue().launch()
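For a quick local sanity check of the loader, the table can be printed without starting the Gradio UI; launching the app itself is just `python simple_leaderboard.py`. The snippet below is a minimal sketch: it assumes `pandas` and `gradio` are installed, that it runs from the repo root, and that `eval-results/` already holds JSON files in the schema written by `src/phoneme_eval.py` below.

    # sketch: inspect the leaderboard table without launching the UI
    from simple_leaderboard import EVAL_RESULTS_DIR, load_results

    df = load_results(EVAL_RESULTS_DIR)
    # drop internal columns such as _file, mirroring what refresh() does
    print(df[[c for c in df.columns if not c.startswith("_")]].round(3))
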
src/phoneme_eval.py
CHANGED
@@ -78,35 +78,58 @@ def benchmark_dataset(dataset):
 
 from datasets import load_dataset, Audio
 
+DATASET_LIST = [
+    "mirfan899/phoneme_asr",
+    "mirfan899/kids_phoneme_md",
+]
 
 def main():
-    dataset = load_dataset("mirfan899/phoneme_asr", split="train")
-    # Disable automatic audio decoding to avoid torchcodec requirement
-    dataset = dataset.cast_column("audio", Audio(decode=False))
     field = "phonetic"
 
+    # Collect per-model metrics across datasets
+    per_model_results = {}
 
+    for dataset_name in DATASET_LIST:
+        try:
+            dataset = load_dataset(dataset_name, split="train")
+        except Exception as e:
+            print(f"[warn] skip dataset {dataset_name}: {e}")
+            continue
 
+        try:
+            dataset = dataset.cast_column("audio", Audio(decode=False))
+        except Exception:
+            pass
 
+        unique_texts = dataset.unique(field)
+        print("Unique phonetic strings (", dataset_name, "):", len(unique_texts))
 
+        dataset_unique = dataset.filter(lambda x: x[field] in unique_texts)
 
+        def is_valid(example):
+            phoneme_tokens = example[field].split()
+            return len(phoneme_tokens) >= 10
 
+        dataset_filtered = dataset_unique.filter(is_valid)
+        dataset_final = dataset_filtered.shuffle(seed=42).select(range(min(100, len(dataset_filtered))))
 
+        print(dataset_final)
+        print("Final size:", len(dataset_final))
+
+        full_results, avg_stats = benchmark_dataset(dataset_final.select(range(min(10, len(dataset_final)))))
+        print("Average Statistic per model (", dataset_name, "):")
+        print(avg_stats)
+
+        # Use dataset name as key (extract the actual name part)
+        dataset_key = dataset_name.split("/")[-1]  # Get the last part after the slash
+        for _, row in avg_stats.iterrows():
+            model_name = str(row["model"]).replace(" ", "-")
+            per = float(row["Average PER"]) if row["Average PER"] is not None else None
+            avg_dur = float(row["Average Duration (s)"]) if row["Average Duration (s)"] is not None else None
+
+            if model_name not in per_model_results:
+                per_model_results[model_name] = {}
+            per_model_results[model_name][dataset_key] = {"per": per, "avg_duration": avg_dur}
 
     # Save results for leaderboard consumption (one JSON per model)
     import json, os, time
@@ -114,25 +137,16 @@ def main():
     os.makedirs(results_dir, exist_ok=True)
 
     timestamp = int(time.time())
-    for _, row in avg_stats.iterrows():
-        model_name = str(row["model"]).replace(" ", "-")
+    for model_name, task_results in per_model_results.items():
         org_model = f"local/{model_name}"
-        per = float(row["Average PER"]) if row["Average PER"] is not None else None
-        avg_dur = float(row["Average Duration (s)"]) if row["Average Duration (s)"] is not None else None
-
         payload = {
             "config": {
                 "model_name": org_model,
                 "model_dtype": "float32",
                 "model_sha": ""
             },
-            "results": {
-                # Populate both keys expected by Tasks to avoid NaNs in the leaderboard
-                "phoneme_dev": {"per": per, "avg_duration": avg_dur},
-                "phoneme_test": {"per": per, "avg_duration": avg_dur}
-            }
+            "results": task_results
         }
-
         out_path = os.path.join(results_dir, f"results_{timestamp}_{model_name}.json")
         with open(out_path, "w", encoding="utf-8") as f:
             json.dump(payload, f, ensure_ascii=False, indent=2)
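With this change each model gets a single `results_{timestamp}_{model_name}.json` under `eval-results/`, and its `results` block is keyed by the dataset short names (`phoneme_asr`, `kids_phoneme_md`) rather than the fixed `phoneme_dev`/`phoneme_test` pair. One emitted payload looks roughly like the sketch below (the model name and metric values are placeholders, not real results); this is the schema `simple_leaderboard.load_results` consumes, and each dataset key becomes a `PER <dataset>` column in the table.

    # illustrative payload for one model; PER/duration numbers are placeholders
    {
        "config": {
            "model_name": "local/<model-name>",
            "model_dtype": "float32",
            "model_sha": ""
        },
        "results": {
            "phoneme_asr": {"per": 0.21, "avg_duration": 1.4},
            "kids_phoneme_md": {"per": 0.35, "avg_duration": 1.6}
        }
    }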