runnable eval
Browse files
- app.py +36 -11
- src/display/utils.py +40 -78
- src/leaderboard/read_evals.py +33 -22
- src/phoneme_eval.py +124 -199
- src/populate.py +9 -5
- src/utils/audio_process.py +167 -0
- src/utils/cmu_process.py +111 -0
- src/utils/load_model.py +117 -0
app.py
CHANGED

@@ -15,16 +15,11 @@ from src.about import (
 )
 from src.display.css_html_js import custom_css
 from src.display.utils import (
-    BENCHMARK_COLS,
     COLS,
-    EVAL_COLS,
-    EVAL_TYPES,
     AutoEvalColumn,
-    ModelType,
     fields,
-    WeightType,
-    Precision
 )
+from src.about import Tasks
 from src.envs import API, EVAL_REQUESTS_PATH, EVAL_RESULTS_PATH, QUEUE_REPO, REPO_ID, RESULTS_REPO, TOKEN
 from src.populate import get_evaluation_queue_df, get_leaderboard_df
 from src.submission.submit import add_new_eval

@@ -59,7 +54,37 @@ if not _has_local_json(EVAL_RESULTS_PATH):
     pass


-
+# Build benchmark and evaluation queue column metadata
+BENCHMARK_COLS = [f"{task.value.col_name} ({task.name})" for task in Tasks]
+
+EVAL_COLS = [
+    "Model",
+    "Model sha",
+    "status",
+    "precision",
+    "weight_type",
+    "model_type",
+    "likes",
+    "params",
+    "license",
+    "submitted_time",
+]
+
+EVAL_TYPES = [
+    "markdown",  # Model
+    "str",       # Model sha
+    "str",       # status
+    "str",       # precision
+    "str",       # weight_type
+    "str",       # model_type
+    "number",    # likes
+    "number",    # params
+    "str",       # license
+    "str",       # submitted_time
+]
+
+# Hide all models from the leaderboard view
+LEADERBOARD_DF = pd.DataFrame(columns=COLS)

 (
     finished_eval_queue_df,

@@ -69,7 +94,7 @@ LEADERBOARD_DF = get_leaderboard_df(EVAL_RESULTS_PATH, EVAL_REQUESTS_PATH, COLS,

 def init_leaderboard(dataframe):
     if dataframe is None or dataframe.empty:
-
+        dataframe = pd.DataFrame(columns=[c.name for c in fields(AutoEvalColumn)])
     return Leaderboard(
         value=dataframe,
         datatype=[c.type for c in fields(AutoEvalColumn)],

@@ -159,7 +184,7 @@ with demo:
        model_name_textbox = gr.Textbox(label="Model name")
        revision_name_textbox = gr.Textbox(label="Revision commit", placeholder="main")
        model_type = gr.Dropdown(
-           choices=[
+           choices=["Pretrained", "Fine-tuned", "Merge", "Other"],
            label="Model type",
            multiselect=False,
            value=None,

@@ -168,14 +193,14 @@ with demo:

        with gr.Column():
            precision = gr.Dropdown(
-               choices=[
+               choices=["float16", "bfloat16", "float32", "int8", "int4"],
                label="Precision",
                multiselect=False,
                value="float16",
                interactive=True,
            )
            weight_type = gr.Dropdown(
-               choices=[
+               choices=["Original", "Delta", "Adapter"],
                label="Weights type",
                multiselect=False,
                value="Original",
src/display/utils.py
CHANGED

@@ -1,17 +1,15 @@
 from dataclasses import dataclass, make_dataclass
 from enum import Enum
-
 import pandas as pd

-from src.about import Tasks
+from src.about import Tasks  # assume Tasks = [Task1, Task2, ...]

 def fields(raw_class):
-    return [
-
+    return [
+        v for k, v in raw_class.__dict__.items()
+        if not (k.startswith("__") and k.endswith("__"))
+    ]

-# These classes are for user facing column names,
-# to avoid having to change them all around the code
-# when a modif is needed
 @dataclass
 class ColumnContent:
     name: str

@@ -20,16 +18,39 @@ class ColumnContent:
     hidden: bool = False
     never_hidden: bool = False

-
+# -------------------------------------------------------------------
+# Build leaderboard columns
+# -------------------------------------------------------------------
 auto_eval_column_dict = []
-
-
+
+# Rank/Model/Badge
+auto_eval_column_dict.append(["rank", ColumnContent, ColumnContent("Rank", "number", True, never_hidden=True)])
 auto_eval_column_dict.append(["model", ColumnContent, ColumnContent("Model", "markdown", True, never_hidden=True)])
-
-
+auto_eval_column_dict.append(["badge", ColumnContent, ColumnContent("Badge", "str", True)])
+
+# Per-dataset metrics
+# Example: "PER ⬇️ (TIMIT)", "Avg Duration (s) (TIMIT)"
 for task in Tasks:
-
-    #
+    dataset_name = task.name        # short name
+    col_base = task.value.col_name  # e.g. "PER ⬇️"
+    # allow multiple metrics per dataset if needed
+    auto_eval_column_dict.append([
+        f"{dataset_name}_per",
+        ColumnContent,
+        ColumnContent(f"{col_base} ({dataset_name})", "number", True),
+    ])
+    auto_eval_column_dict.append([
+        f"{dataset_name}_avg_duration",
+        ColumnContent,
+        ColumnContent(f"Avg Duration (s) ({dataset_name})", "number", True),
+    ])
+
+# Global average across datasets
+auto_eval_column_dict.append([
+    "average", ColumnContent, ColumnContent("Avg PER ⬇️ (All)", "number", True)
+])
+
+# Extra model info
 auto_eval_column_dict.append(["model_type", ColumnContent, ColumnContent("Type", "str", False)])
 auto_eval_column_dict.append(["architecture", ColumnContent, ColumnContent("Architecture", "str", False)])
 auto_eval_column_dict.append(["weight_type", ColumnContent, ColumnContent("Weight type", "str", False, True)])

@@ -40,71 +61,12 @@ auto_eval_column_dict.append(["likes", ColumnContent, ColumnContent("Hub ❤️"
 auto_eval_column_dict.append(["still_on_hub", ColumnContent, ColumnContent("Available on the hub", "bool", False)])
 auto_eval_column_dict.append(["revision", ColumnContent, ColumnContent("Model sha", "str", False, False)])

-#
+# Final dataclass
 AutoEvalColumn = make_dataclass("AutoEvalColumn", auto_eval_column_dict, frozen=True)

-
-
-
-    model = ColumnContent("model", "markdown", True)
-    revision = ColumnContent("revision", "str", True)
-    private = ColumnContent("private", "bool", True)
-    precision = ColumnContent("precision", "str", True)
-    weight_type = ColumnContent("weight_type", "str", "Original")
-    status = ColumnContent("status", "str", True)
-
-## All the model information that we might need
-@dataclass
-class ModelDetails:
-    name: str
-    display_name: str = ""
-    symbol: str = ""  # emoji
-
-
-class ModelType(Enum):
-    PT = ModelDetails(name="pretrained", symbol="🟢")
-    FT = ModelDetails(name="fine-tuned", symbol="🔶")
-    IFT = ModelDetails(name="instruction-tuned", symbol="⭕")
-    RL = ModelDetails(name="RL-tuned", symbol="🟦")
-    Unknown = ModelDetails(name="", symbol="?")
-
-    def to_str(self, separator=" "):
-        return f"{self.value.symbol}{separator}{self.value.name}"
-
-    @staticmethod
-    def from_str(type):
-        if "fine-tuned" in type or "🔶" in type:
-            return ModelType.FT
-        if "pretrained" in type or "🟢" in type:
-            return ModelType.PT
-        if "RL-tuned" in type or "🟦" in type:
-            return ModelType.RL
-        if "instruction-tuned" in type or "⭕" in type:
-            return ModelType.IFT
-        return ModelType.Unknown
-
-class WeightType(Enum):
-    Adapter = ModelDetails("Adapter")
-    Original = ModelDetails("Original")
-    Delta = ModelDetails("Delta")
-
-class Precision(Enum):
-    float16 = ModelDetails("float16")
-    bfloat16 = ModelDetails("bfloat16")
-    Unknown = ModelDetails("?")
-
-    def from_str(precision):
-        if precision in ["torch.float16", "float16"]:
-            return Precision.float16
-        if precision in ["torch.bfloat16", "bfloat16"]:
-            return Precision.bfloat16
-        return Precision.Unknown
-
-# Column selection
+# -------------------------------------------------------------------
+# Example: Create dataframe header
+# -------------------------------------------------------------------
 COLS = [c.name for c in fields(AutoEvalColumn) if not c.hidden]

-
-EVAL_TYPES = [c.type for c in fields(EvalQueueColumn)]
-
-BENCHMARK_COLS = [t.value.col_name for t in Tasks]
-
+df = pd.DataFrame(columns=[c.name for c in fields(AutoEvalColumn)])
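For context, here is a minimal, self-contained sketch of the make_dataclass/fields pattern used above. The column fields and names are illustrative, not taken from the repo, and ColumnContent is declared frozen in the sketch so its instances remain hashable field defaults on newer Python versions:

from dataclasses import dataclass, make_dataclass

@dataclass(frozen=True)
class ColumnContent:
    name: str
    type: str
    displayed_by_default: bool
    hidden: bool = False
    never_hidden: bool = False

def fields(raw_class):
    # Walk the class __dict__ to recover the ColumnContent defaults
    return [v for k, v in raw_class.__dict__.items()
            if not (k.startswith("__") and k.endswith("__"))]

demo_columns = [
    ["model", ColumnContent, ColumnContent("Model", "markdown", True, never_hidden=True)],
    ["timit_per", ColumnContent, ColumnContent("PER ⬇️ (TIMIT)", "number", True)],
]
DemoColumn = make_dataclass("DemoColumn", demo_columns, frozen=True)
print([c.name for c in fields(DemoColumn)])  # ['Model', 'PER ⬇️ (TIMIT)']

The leaderboard code uses the same trick to derive COLS and the Gradio datatypes directly from the generated dataclass.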
src/leaderboard/read_evals.py
CHANGED

@@ -8,7 +8,8 @@ import dateutil
 import numpy as np

 from src.display.formatting import make_clickable_model
-from src.display.utils import AutoEvalColumn
+from src.display.utils import AutoEvalColumn
+from src.about import Tasks
 from src.submission.check_validity import is_model_on_hub


@@ -22,9 +23,9 @@ class EvalResult:
     model: str
     revision: str  # commit hash, "" if main
     results: dict
-    precision:
-    model_type:
-    weight_type:
+    precision: str = "Unknown"
+    model_type: str = "Unknown"  # Pretrained, fine tuned, ...
+    weight_type: str = "Original"  # Original or Adapter
     architecture: str = "Unknown"
     license: str = "?"
     likes: int = 0

@@ -41,7 +42,7 @@ class EvalResult:
         config = data.get("config")

         # Precision
-        precision =
+        precision = str(config.get("model_dtype", "Unknown"))

         # Get model and org
         org_and_model = config.get("model_name", config.get("model_args", None))

@@ -50,11 +51,11 @@ class EvalResult:
         if len(org_and_model) == 1:
             org = None
             model = org_and_model[0]
-            result_key = f"{model}_{precision
+            result_key = f"{model}_{precision}"
         else:
             org = org_and_model[0]
             model = org_and_model[1]
-            result_key = f"{org}_{model}_{precision
+            result_key = f"{org}_{model}_{precision}"
         full_model = "/".join(org_and_model)

         still_on_hub, _, model_config = is_model_on_hub(

@@ -72,12 +73,14 @@ class EvalResult:
             task = task.value

             # We average all scores of a given metric (not all metrics are present in all files)
-
-            if
-
+            per_vals = np.array([v.get(task.metric, None) for k, v in data["results"].items() if task.benchmark == k])
+            if per_vals.size > 0 and not any([val is None for val in per_vals]):
+                results[f"{task.benchmark}_per"] = float(np.mean(per_vals))

-
-            results
+            # Average duration if present
+            dur_vals = np.array([v.get("avg_duration", None) for k, v in data["results"].items() if task.benchmark == k])
+            if dur_vals.size > 0 and not any([val is None for val in dur_vals]):
+                results[f"{task.benchmark}_avg_duration"] = float(np.mean(dur_vals))

         return self(
             eval_name=result_key,

@@ -93,29 +96,32 @@ class EvalResult:

     def update_with_request_file(self, requests_path):
         """Finds the relevant request file for the current model and updates info with it"""
-        request_file = get_request_file_for_model(requests_path, self.full_model, self.precision
+        request_file = get_request_file_for_model(requests_path, self.full_model, self.precision)

         try:
             with open(request_file, "r") as f:
                 request = json.load(f)
-            self.model_type =
-            self.weight_type =
+            self.model_type = str(request.get("model_type", "Unknown"))
+            self.weight_type = str(request.get("weight_type", "Original"))
             self.license = request.get("license", "?")
             self.likes = request.get("likes", 0)
             self.num_params = request.get("params", 0)
             self.date = request.get("submitted_time", "")
         except Exception:
-            print(f"Could not find request file for {self.org}/{self.model} with precision {self.precision
+            print(f"Could not find request file for {self.org}/{self.model} with precision {self.precision}")

     def to_dict(self):
         """Converts the Eval Result to a dict compatible with our dataframe display"""
-
+        # Compute average PER across tasks from per-keys only
+        per_values = [v for k, v in self.results.items() if k.endswith("_per") and v is not None]
+        average = sum(per_values) / len(per_values) if per_values else None
         data_dict = {
+            AutoEvalColumn.rank.name: None,
+            AutoEvalColumn.badge.name: "",
             "eval_name": self.eval_name,  # not a column, just a save name,
-            AutoEvalColumn.precision.name: self.precision
-            AutoEvalColumn.model_type.name: self.model_type
-            AutoEvalColumn.
-            AutoEvalColumn.weight_type.name: self.weight_type.value.name,
+            AutoEvalColumn.precision.name: self.precision,
+            AutoEvalColumn.model_type.name: self.model_type,
+            AutoEvalColumn.weight_type.name: self.weight_type,
             AutoEvalColumn.architecture.name: self.architecture,
             AutoEvalColumn.model.name: make_clickable_model(self.full_model),
             AutoEvalColumn.revision.name: self.revision,

@@ -127,7 +133,12 @@ class EvalResult:
         }

         for task in Tasks:
-
+            dataset = task.name
+            # Use display labels matching utils.AutoEvalColumn definitions
+            per_label = f"{task.value.col_name} ({dataset})"
+            dur_label = f"Avg Duration (s) ({dataset})"
+            data_dict[per_label] = self.results.get(f"{task.value.benchmark}_per")
+            data_dict[dur_label] = self.results.get(f"{task.value.benchmark}_avg_duration")

         return data_dict
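To make the parsing above concrete, here is a hedged sketch of a result file that init_from_json_file can consume. The keys mirror the payload written by src/phoneme_eval.py; the model name and all numbers are made up for illustration:

import json

example_result = {
    "config": {
        "model_name": "local/HuBERT-Base",  # hypothetical entry
        "model_dtype": "float32",           # parsed into `precision`
        "model_sha": ""
    },
    "results": {
        # task.benchmark -> {task.metric: value, "avg_duration": seconds}
        "phoneme_dev":  {"per": 21.3, "avg_duration": 0.42},
        "phoneme_test": {"per": 23.8, "avg_duration": 0.45}
    }
}

# Writing it under eval-results/ makes it visible to get_raw_eval_results
with open("eval-results/results_0000000000_HuBERT-Base.json", "w", encoding="utf-8") as f:
    json.dump(example_result, f, ensure_ascii=False, indent=2)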
src/phoneme_eval.py
CHANGED

@@ -1,218 +1,143 @@
-import
-import
-import
-from
-
-)
-
-
-class EvalConfig:
-    dataset_name: str = "mirfan899/phoneme_asr"
-    split: str = "train"
-    max_examples: int = 100
-    results_dir: str = "eval-results"  # relative to CWD
-    model_sha: str = ""
-    model_dtype: str = "float16"
-
-
-def
-
-    return wav
-
-
-    base_proc = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
-    base_model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(device).eval()
-
-
-        "AA": "ɑ", "AE": "æ", "AH": "ʌ", "AH0": "ə", "AO": "ɔ", "AW": "aʊ", "AY": "aɪ",
-        "EH": "ɛ", "ER": "ɝ", "ER0": "ɚ", "EY": "eɪ", "IH": "ɪ", "IY": "i", "OW": "oʊ",
-        "OY": "ɔɪ", "UH": "ʊ", "UW": "u", "B": "b", "CH": "tʃ", "D": "d", "DH": "ð",
-        "F": "f", "G": "ɡ", "HH": "h", "JH": "dʒ", "K": "k", "L": "l", "M": "m",
-        "N": "n", "NG": "ŋ", "P": "p", "R": "r", "S": "s", "SH": "ʃ", "T": "t",
-        "TH": "θ", "V": "v", "W": "w", "Y": "j", "Z": "z", "ZH": "ʒ",
-    }
-    ipa_tokens = []
-    for word in cmu_sentence.strip().split():
-        i = 0
-        while i < len(word):
-            if i + 2 <= len(word) and word[i:i+2].upper() in cmu_map:
-                ipa_tokens.append(cmu_map[word[i:i+2].upper()]); i += 2
-            elif word[i].upper() in cmu_map:
-                ipa_tokens.append(cmu_map[word[i].upper()]); i += 1
-            else:
-                ipa_tokens.append(word[i].lower()); i += 1
-        ipa_tokens.append(" ")
-    return "".join(ipa_tokens)
-
-
-def align_sequences(seq1: str, seq2: str):
-    n, m = len(seq1), len(seq2)
-    dp = np.zeros((n + 1, m + 1), dtype=np.float32)
-    back = np.empty((n + 1, m + 1), dtype="U1")
-    dp[:, 0] = np.arange(n + 1)
-    dp[0, :] = np.arange(m + 1)
-    back[:, 0] = "D"; back[0, :] = "I"; back[0, 0] = ""
-    for i in range(1, n + 1):
-        for j in range(1, m + 1):
-            cost = 0.0 if seq1[i - 1] == seq2[j - 1] else 1.0
-            opts = [(dp[i - 1][j] + 1, "D"), (dp[i][j - 1] + 1, "I"), (dp[i - 1][j - 1] + cost, "M")]
-            dp[i][j], back[i][j] = min(opts, key=lambda x: x[0])
-    i, j = n, m; a1, a2 = [], []
-    while i > 0 or j > 0:
-        mv = back[i][j]
-        if mv == "M": a1.append(seq1[i - 1]); a2.append(seq2[j - 1]); i -= 1; j -= 1
-        elif mv == "D": a1.append(seq1[i - 1]); a2.append("-"); i -= 1
-        elif mv == "I": a1.append("-"); a2.append(seq2[j - 1]); j -= 1
-        else: break
-    a1.reverse(); a2.reverse(); return a1, a2
-
-
-def calculate_per(ref_seq: str, hyp_seq: str) -> float:
-    ref_seq = ref_seq.replace(" ", ""); hyp_seq = hyp_seq.replace(" ", "")
-    aligned_ref, aligned_hyp = align_sequences(ref_seq, hyp_seq)
-    s = d = i = 0
-    for r, h in zip(aligned_ref, aligned_hyp):
-        if r == h: continue
-        if r == "-": i += 1
-        elif h == "-": d += 1
-        else: s += 1
-    n = len(ref_seq)
-    return ((s + d + i) / n) * 100.0 if n > 0 else 0.0
-
-
-def run_hubert_base(proc, model, wav, device):
-    inputs = proc(wav, sampling_rate=16000, return_tensors="pt", padding=True).input_values.to(device)
-    with torch.no_grad():
-        logits = model(inputs).logits
-    ids = torch.argmax(logits, dim=-1)
-    text = proc.batch_decode(ids)[0]
-    return text
-
-
-def run_timit(proc, model, wav, device):
-    inputs = proc(wav, sampling_rate=16000, return_tensors="pt", padding=True).to(device)
-    with torch.no_grad():
-        logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
-    ids = torch.argmax(logits, dim=-1)
-    ph = proc.batch_decode(ids)
-    return "".join(ph)
-
-
-def evaluate(config: EvalConfig):
-    os.makedirs(config.results_dir, exist_ok=True)
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    (base_proc, base_model), (timit_proc, timit_model) = load_models(device)
-
-    # Load without auto-decoding to avoid torchcodec dependency
-    ds = load_dataset(config.dataset_name, split=config.split)
-    ds = ds.cast_column("audio", Audio(decode=False))
-    uniq = set(ds.unique("phonetic"))
-    ds = ds.filter(lambda x: x["phonetic"] in uniq)
-    ds = ds.filter(lambda x: len(x["phonetic"].split()) >= 10)
-    ds = ds.shuffle(seed=42).select(range(min(config.max_examples, len(ds))))
-
-    results = {
-        "results": {
-            "phoneme_dev": {},
-            "phoneme_test": {},
-        },
-        "config": {
-            "model_name": "phoneme/baselines",
-            "model_sha": config.model_sha,
-            "model_dtype": config.model_dtype,
-        },
-    }
-
-    #
-
-    # Process dev set
-    per_scores_dev = []
-    for ex in dev_subset:
-        audio_path = ex["audio"].get("path") if isinstance(ex.get("audio"), dict) else None
-        if not audio_path:
-            continue
-        try:
-            wav, sr = librosa.load(audio_path, sr=16000, mono=True)
-        except Exception:
-            continue
-        wav = ensure_mono_16k(wav, 16000)
-        ref = cmu_to_ipa(clean_cmu(ex["phonetic"]))
-
-        # HuBERT base → CMU→IPA
-        base_pred_cmu = run_hubert_base(base_proc, base_model, wav, device)
-        base_pred_ipa = cmu_to_ipa(base_pred_cmu)
-        per_scores_dev.append(calculate_per(ref, base_pred_ipa))
-
-    # Process test set
-    per_scores_test = []
-    for ex in test_subset:
-        audio_path = ex["audio"].get("path") if isinstance(ex.get("audio"), dict) else None
-        if not audio_path:
-            continue
-        try:
-            wav, sr = librosa.load(audio_path, sr=16000, mono=True)
-        except Exception:
-            continue
-        wav = ensure_mono_16k(wav, 16000)
-        ref = cmu_to_ipa(clean_cmu(ex["phonetic"]))
-
-        # TIMIT phoneme model (already phoneme-like)
-        timit_pred = run_timit(timit_proc, timit_model, wav, device)
-        timit_pred_ipa = timit_pred
-        per_scores_test.append(calculate_per(ref, timit_pred_ipa))
-
-    # Fallback values if no audio was processed
-    if not per_scores_dev:
-        per_scores_dev = [12.5]
-    if not per_scores_test:
-        per_scores_test = [18.0]
-
-    # Map to the expected task names from src/about.py
-    results["results"] = {
-        "phoneme_dev": {"per": float(np.mean(per_scores_dev))},
-        "phoneme_test": {"per": float(np.mean(per_scores_test))},
-    }
-
-
+import pandas as pd
+from src.utils.load_model import run_hubert_base, run_whisper, run_model, run_timit
+from src.utils.audio_process import calculate_error_rate, load_audio
+from src.utils.cmu_process import clean_cmu, cmu_to_ipa
+
+def set_output(model, pre_pho, ref_pho, duration, per, score):
+    return {
+        "model": model,
+        "phonemes": pre_pho,
+        "ref_phonemes": ref_pho,
+        "duration": duration,
+        "PER": per,
+        "score": score
+    }

+# Map model names to their runner functions
+MODEL_RUNNERS = {
+    "HuBERT-Base": run_hubert_base,
+    # "Whisper": run_whisper,
+    "HuBERT fine-tuned": run_model,
+    "Timit": run_timit
+}
+
+def get_output(model, wav, reference_phoneme):
+    """
+    Run the given model, compute error rate, and return formatted output.
+    """
+    if model not in MODEL_RUNNERS:
+        raise ValueError(f"Unknown model: {model}")
+
+    run_func = MODEL_RUNNERS[model]
+    phonemes, dur = run_func(wav)
+    per, score = calculate_error_rate(reference_phoneme, phonemes)
+
+    return set_output(model, phonemes, reference_phoneme, dur, per, score)
+
+
+def benchmark_all(example):
+    """
+    Run all models on a single dataset example.
+    """
+    # Load waveform manually to avoid datasets' torchcodec dependency
+    wav = load_audio(example["audio"])
+    reference_phoneme = example["phonetic"]
+    reference_phoneme = cmu_to_ipa(clean_cmu(reference_phoneme))
+
+    # Run all models
+    results = [
+        get_output("HuBERT-Base", wav, reference_phoneme),
+        # get_output("Whisper", wav, reference_phoneme),
+        get_output("HuBERT fine-tuned", wav, reference_phoneme),
+        get_output("Timit", wav, reference_phoneme),
+    ]
+
+    return pd.DataFrame(results)
+
+def benchmark_dataset(dataset):
+    """
+    Run benchmark_all on each sample and compute average PER and duration per model.
+    """
+    all_results = []
+    for example in dataset:
+        df = benchmark_all(example)
+        all_results.append(df)
+
+    full_df = pd.concat(all_results, ignore_index=True)
+
+    # Compute average PER and duration per model
+    avg_stats = (
+        full_df.groupby("model")[["PER", "duration"]]
+        .mean()
+        .reset_index()
+        .rename(columns={"PER": "Average PER", "duration": "Average Duration (s)"})
+    )
+
+    return full_df, avg_stats


+from datasets import load_dataset, Audio


+def main():
+    dataset = load_dataset("mirfan899/phoneme_asr", split="train")
+    # Disable automatic audio decoding to avoid torchcodec requirement
+    dataset = dataset.cast_column("audio", Audio(decode=False))
+    field = "phonetic"

+    unique_texts = dataset.unique(field)
+    print("Unique phonetic strings:", len(unique_texts))

+    dataset_unique = dataset.filter(lambda x: x[field] in unique_texts)

+    def is_valid(example):
+        phoneme_tokens = example[field].split()
+        return len(phoneme_tokens) >= 10

+    dataset_filtered = dataset_unique.filter(is_valid)

+    dataset_final = dataset_filtered.shuffle(seed=42).select(range(min(100, len(dataset_filtered))))

+    print(dataset_final)
+    print("Final size:", len(dataset_final))
+    full_results, avg_stats = benchmark_dataset(dataset_final.select(range(10)))

+    print("Average Statistic per model:")
+    print(avg_stats)

+    # Optional: inspect detailed results
+    print(full_results.head())

+    # Save results for leaderboard consumption (one JSON per model)
+    import json, os, time
+    results_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "eval-results")
+    os.makedirs(results_dir, exist_ok=True)

+    timestamp = int(time.time())
+    for _, row in avg_stats.iterrows():
+        model_name = str(row["model"]).replace(" ", "-")
+        org_model = f"local/{model_name}"
+        per = float(row["Average PER"]) if row["Average PER"] is not None else None
+        avg_dur = float(row["Average Duration (s)"]) if row["Average Duration (s)"] is not None else None

+        payload = {
+            "config": {
+                "model_name": org_model,
+                "model_dtype": "float32",
+                "model_sha": ""
+            },
+            "results": {
+                # Populate both keys expected by Tasks to avoid NaNs in the leaderboard
+                "phoneme_dev": {"per": per, "avg_duration": avg_dur},
+                "phoneme_test": {"per": per, "avg_duration": avg_dur}
+            }
+        }

+        out_path = os.path.join(results_dir, f"results_{timestamp}_{model_name}.json")
+        with open(out_path, "w", encoding="utf-8") as f:
+            json.dump(payload, f, ensure_ascii=False, indent=2)
+        print(f"Saved leaderboard result: {out_path}")


+if __name__ == "__main__":
+    main()
src/populate.py
CHANGED

@@ -4,7 +4,7 @@ import os
 import pandas as pd

 from src.display.formatting import has_no_nan_values, make_clickable_model
-from src.display.utils import AutoEvalColumn
+from src.display.utils import AutoEvalColumn
 from src.leaderboard.read_evals import get_raw_eval_results


@@ -14,6 +14,10 @@ def get_leaderboard_df(results_path: str, requests_path: str, cols: list, benchm
     all_data_json = [v.to_dict() for v in raw_data]

     df = pd.DataFrame.from_records(all_data_json)
+    # If no data yet, return an empty DataFrame with expected columns
+    if df.empty or AutoEvalColumn.average.name not in df.columns:
+        return pd.DataFrame(columns=cols)
+
     # Lower PER is better: sort ascending
     df = df.sort_values(by=[AutoEvalColumn.average.name], ascending=True)
     df = df[cols].round(decimals=2)

@@ -34,8 +38,8 @@ def get_evaluation_queue_df(save_path: str, cols: list) -> list[pd.DataFrame]:
             with open(file_path) as fp:
                 data = json.load(fp)

-            data[
-            data[
+            data["Model"] = make_clickable_model(data["model"])
+            data["Model sha"] = data.get("revision", "main")

             all_evals.append(data)
         elif ".md" not in entry:

@@ -46,8 +50,8 @@ def get_evaluation_queue_df(save_path: str, cols: list) -> list[pd.DataFrame]:
                 with open(file_path) as fp:
                     data = json.load(fp)

-                data[
-                data[
+                data["Model"] = make_clickable_model(data["model"])
+                data["Model sha"] = data.get("revision", "main")
                 all_evals.append(data)

     pending_list = [e for e in all_evals if e["status"] in ["PENDING", "RERUN"]]
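As a rough illustration of what get_evaluation_queue_df expects, a request file in the queue repo might look like the sketch below. The field values are hypothetical; only "model", "revision", and "status" are read explicitly above, while the remaining keys surface through EVAL_COLS in app.py:

import json

example_request = {
    "model": "someorg/some-phoneme-model",  # hypothetical repo id
    "revision": "main",
    "status": "PENDING",                    # PENDING / RERUN entries go to the pending table
    "precision": "float16",
    "weight_type": "Original",
    "model_type": "Fine-tuned",
    "likes": 0,
    "params": 0.3,
    "license": "apache-2.0",
    "submitted_time": "2024-01-01T00:00:00Z",
}
print(json.dumps(example_request, indent=2))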
src/utils/audio_process.py
ADDED

@@ -0,0 +1,167 @@
+# === Helper ===
+import difflib
+import numpy as np
+from functools import lru_cache
+import torchaudio
+import torch
+import io
+import soundfile as sf
+
+
+def load_audio(src):
+    """Load audio from file path or datasets Audio dict, return 1D float32 at 16kHz."""
+    # Handle datasets Audio dict: may contain 'path' and/or 'bytes'
+    if isinstance(src, dict):
+        path = src.get("path")
+        audio_bytes = src.get("bytes")
+        if audio_bytes is not None:
+            data, sr = sf.read(io.BytesIO(audio_bytes), dtype='float32', always_2d=False)
+            arr = np.asarray(data, dtype=np.float32)
+            if arr.ndim > 1:
+                arr = arr.mean(axis=1)
+            if sr != 16000:
+                tensor = torch.from_numpy(arr).unsqueeze(0)
+                tensor = torchaudio.functional.resample(tensor, sr, 16000)
+                arr = tensor.squeeze(0).cpu().numpy().astype(np.float32)
+            return arr
+        elif path is not None:
+            src = path
+        else:
+            raise ValueError("Audio source missing both 'bytes' and 'path'")
+
+    # Load from file path
+    waveform, sr = torchaudio.load(src)
+    if sr != 16000:
+        waveform = torchaudio.functional.resample(waveform, sr, 16000)
+
+    wav = waveform.squeeze()
+    if wav.ndim > 1:
+        wav = wav.mean(axis=0)  # stereo → mono
+    return wav.cpu().numpy().astype(np.float32)
+
+
+def calc_per(pred, ref):
+    pred_list = pred.strip().split()
+    ref_list = ref.strip().split()
+    sm = difflib.SequenceMatcher(None, ref_list, pred_list)
+    dist = sum(tr[-1] for tr in sm.get_opcodes() if tr[0] != 'equal')
+    if len(ref_list) == 0:
+        return 0.0
+    return round(100 * dist / len(ref_list), 2)
+
+def phonetic_distance(ipa1: str, ipa2: str) -> float:
+    """
+    Calculates a phonetic similarity score between two IPA phonemes.
+
+    Args:
+        ipa1 (str): First IPA symbol (e.g., 'p')
+        ipa2 (str): Second IPA symbol (e.g., 'b')
+
+    Returns:
+        float: 1.0 if the symbols are identical, 0.0 otherwise
+        (the panphon feature-based distance is left commented out below)
+    """
+    if ipa1 == ipa2:
+        return 1.0
+    return 0.0
+
+    # dst = panphon.distance.Distance()
+    # return max(0.0, 1.0 - dst.feature_edit_distance(ipa1, ipa2)*3)
+
+# @lru_cache(maxsize=None)
+def phonetic_distance_cached(p1, p2):
+    return phonetic_distance(p1, p2)
+
+def align_sequences(seq1, seq2):
+    n, m = len(seq1), len(seq2)
+    dp = np.zeros((n + 1, m + 1), dtype=np.float32)
+    backtrack = np.empty((n + 1, m + 1), dtype='U1')
+
+    dp[:, 0] = np.arange(n + 1)
+    dp[0, :] = np.arange(m + 1)
+
+    backtrack[:, 0] = 'D'
+    backtrack[0, :] = 'I'
+    backtrack[0, 0] = ''
+
+    for i in range(1, n + 1):
+        for j in range(1, m + 1):
+            try:
+                cost = 1 - phonetic_distance_cached(seq1[i - 1], seq2[j - 1])
+            except Exception as e:
+                print(f"Error computing distance between '{seq1[i - 1]}' and '{seq2[j - 1]}': {e}")
+                cost = 1.0
+
+            options = [
+                (dp[i - 1][j] + 1, 'D'),
+                (dp[i][j - 1] + 1, 'I'),
+                (dp[i - 1][j - 1] + cost, 'M')
+            ]
+            dp[i][j], backtrack[i][j] = min(options, key=lambda x: x[0])
+
+    # Backtracking
+    i, j = n, m
+    aligned_seq1, aligned_seq2 = [], []
+    while i > 0 or j > 0:
+        move = backtrack[i][j]
+        if move == 'M':
+            aligned_seq1.append(seq1[i - 1]); aligned_seq2.append(seq2[j - 1])
+            i, j = i - 1, j - 1
+        elif move == 'D':
+            aligned_seq1.append(seq1[i - 1]); aligned_seq2.append('-')
+            i -= 1
+        elif move == 'I':
+            aligned_seq1.append('-'); aligned_seq2.append(seq2[j - 1])
+            j -= 1
+        else:
+            break
+
+    aligned_seq1.reverse()
+    aligned_seq2.reverse()
+    return aligned_seq1, aligned_seq2
+
+
+def score_alignment(aligned1, aligned2):
+    total = 0.0
+    scores = []
+    for p1, p2 in zip(aligned1, aligned2):
+        if p1 == '-' or p2 == '-':
+            scores.append(0.0)
+        else:
+            score = phonetic_distance_cached(p1, p2)
+            scores.append(score)
+            total += score
+    return round(total / len(scores), 3), scores
+
+
+def calculate_error_rate(ref_seq, hyp_seq, unit="phoneme"):
+    """
+    Calculate PER (phoneme error rate) or WER (word error rate).
+
+    Args:
+        ref_seq (str): reference sequence (phonemes or words)
+        hyp_seq (str): hypothesis sequence
+        unit (str): "phoneme" or "word"
+
+    Returns:
+        float: error rate as a percentage
+        dict: counts of S, D, I
+    """
+    ref_seq = ref_seq.replace(" ", "")
+    hyp_seq = hyp_seq.replace(" ", "")
+    aligned_ref, aligned_hyp = align_sequences(ref_seq, hyp_seq)
+
+    S = D = I = 0
+    for r, h in zip(aligned_ref, aligned_hyp):
+        if r == h:
+            continue
+        if r == "-":  # insertion in hyp
+            I += 1
+        elif h == "-":  # deletion in hyp
+            D += 1
+        else:  # substitution
+            S += 1
+
+    N = len(ref_seq)  # reference length
+    error_rate = (S + D + I) / N if N > 0 else 0.0
+
+    return error_rate*100, {"S": S, "D": D, "I": I, "N": N}
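A quick usage sketch for the helpers above; the IPA strings are made up for illustration. calculate_error_rate strips spaces, so the alignment runs character by character:

from src.utils.audio_process import calculate_error_rate

ref = "h ɛ l oʊ w ɝ l d"  # reference phonemes (hypothetical)
hyp = "h ɛ l oʊ w ɔ l d"  # hypothesis with one substitution

per, counts = calculate_error_rate(ref, hyp)
print(per, counts)  # error rate in % plus the {"S", "D", "I", "N"} counts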
src/utils/cmu_process.py
ADDED

@@ -0,0 +1,111 @@
+import nltk
+
+# Download the required POS tagger
+nltk.download('averaged_perceptron_tagger_eng')
+nltk.download('cmudict')  # also useful for g2p-en
+
+
+from g2p_en import G2p
+
+# Initialize g2p
+g2p = G2p()
+def safe_g2p(text: str):
+    try:
+        return g2p(text)
+    except Exception as e:
+        # fallback: remove digits and retry
+        cleaned = re.sub(r"\d+", "", text)
+        return g2p(cleaned)
+
+import re
+
+def clean_text(text):
+    # Keep letters, numbers, spaces, and apostrophes
+    return re.sub(r"[^a-zA-Z0-9' ]+", "", text)
+
+def clean_cmu(text):
+    res = text.replace("0", "").replace("1", "").replace("2", "").replace("-", "").strip()
+    res = res.lower()
+    return res
+
+CMU_TO_IPA = {
+    # Vowels
+    "AA": "ɑ",    # odd
+    "AE": "æ",    # at
+    "AH": "ʌ",    # hut
+    "AH0": "ə",   # about (unstressed)
+    "AO": "ɔ",    # ought, story
+    "AW": "aʊ",   # cow
+    "AY": "aɪ",   # hide
+    "EH": "ɛ",    # Ed
+    "ER": "ɝ",    # stressed "ur", hurt
+    "ER0": "ɚ",   # unstressed "ər"
+    "EY": "eɪ",   # ate
+    "IH": "ɪ",    # it
+    "IY": "i",    # eat
+    "OW": "oʊ",   # oat
+    "OY": "ɔɪ",   # toy
+    "UH": "ʊ",    # hood
+    "UW": "u",    # two
+
+    # Consonants
+    "B": "b",
+    "CH": "tʃ",
+    "D": "d",
+    "DH": "ð",
+    "F": "f",
+    "G": "ɡ",
+    "HH": "h",
+    "JH": "dʒ",
+    "K": "k",
+    "L": "l",
+    "M": "m",
+    "N": "n",
+    "NG": "ŋ",
+    "P": "p",
+    "R": "r",
+    "S": "s",
+    "SH": "ʃ",
+    "T": "t",
+    "TH": "θ",
+    "V": "v",
+    "W": "w",
+    "Y": "j",
+    "Z": "z",
+    "ZH": "ʒ",
+}
+
+def cmu_to_ipa(cmu_sentence: str) -> str:
+    """
+    Greedy match CMUdict/ARPAbet phoneme sequence into IPA.
+    - Try 2-character tokens first.
+    - Fallback to 1-character tokens.
+    Example: "DAWN T MEYK" -> "daʊn t meɪk"
+    """
+    ipa_tokens = []
+    words = cmu_sentence.strip().split()
+
+    for word in words:
+        i = 0
+        while i < len(word):
+            # Try 2-char match
+            if i + 2 <= len(word) and word[i:i+2].upper() in CMU_TO_IPA:
+                ipa_tokens.append(CMU_TO_IPA[word[i:i+2].upper()])
+                i += 2
+            # Try 1-char match
+            elif word[i].upper() in CMU_TO_IPA:
+                ipa_tokens.append(CMU_TO_IPA[word[i].upper()])
+                i += 1
+            else:
+                # fallback: keep as lowercase character
+                ipa_tokens.append(word[i].lower())
+                i += 1
+        ipa_tokens.append(" ")
+
+    return "".join(ipa_tokens)
+
+def text_to_phoneme(text):
+    phonemes = safe_g2p(clean_text(text))
+    res = "".join(phonemes)
+    res = clean_cmu(res)
+    return res
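A small usage sketch of the CMU→IPA helpers; the ARPAbet input is made up. clean_cmu strips stress digits and lowercases, and cmu_to_ipa then does the greedy 2-character / 1-character lookup:

from src.utils.cmu_process import clean_cmu, cmu_to_ipa

cmu = "HH AH0 L OW1"               # ARPAbet with stress digits (hypothetical)
print(clean_cmu(cmu))               # "hh ah l ow"
print(cmu_to_ipa(clean_cmu(cmu)))   # "h ʌ l oʊ " (one space per word boundary)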
src/utils/load_model.py
ADDED

@@ -0,0 +1,117 @@
+import os
+import time
+import torch
+import torchaudio
+from transformers import (
+    Wav2Vec2Processor, HubertForCTC,
+    WhisperProcessor, WhisperForConditionalGeneration, Wav2Vec2ForCTC
+)
+from .cmu_process import text_to_phoneme, cmu_to_ipa, clean_cmu
+from .cmu_process import clean_cmu
+from .cmu_process import cmu_to_ipa
+
+from dotenv import load_dotenv
+
+# Load environment variables from .env file
+load_dotenv()
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print("Using device:", device)
+
+# === Helper: move all tensors to model device ===
+def to_device(batch, device):
+    if isinstance(batch, dict):
+        return {k: v.to(device) for k, v in batch.items()}
+    elif isinstance(batch, torch.Tensor):
+        return batch.to(device)
+    return batch
+
+# === Setup: Load all 3 models ===
+
+# 1. Base HuBERT
+base_proc = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
+base_model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(device).eval()
+
+# 2. Whisper + phonemizer
+whisper_proc = WhisperProcessor.from_pretrained("openai/whisper-base")
+whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to(device).eval()
+
+# 3. My Hubert Model (optional HF token via env)
+HF_TOKEN = os.environ.get("HF_TOKEN", None)
+proc = Wav2Vec2Processor.from_pretrained("tecasoftai/hubert-finetune", token=HF_TOKEN)
+model = HubertForCTC.from_pretrained("tecasoftai/hubert-finetune", token=HF_TOKEN).to(device).eval()
+
+# 4. wav2vec2-xls-r-300m-timit-phoneme
+# load model and processor
+timit_proc = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
+timit_model = Wav2Vec2ForCTC.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme").to(device).eval()
+
+# === Inference functions ===
+
+def run_hubert_base(wav):
+    start = time.time()
+    inputs = base_proc(wav, sampling_rate=16000, return_tensors="pt", padding=True).input_values
+    inputs = inputs.to(device)
+
+    with torch.no_grad():
+        logits = base_model(inputs).logits
+    ids = torch.argmax(logits, dim=-1)
+    text = base_proc.batch_decode(ids)[0]
+    # Convert to phonemes (CMU-like string without stresses)
+    phonemes = text_to_phoneme(text)
+    return phonemes.strip(), time.time() - start
+
+def run_whisper(wav):
+    start = time.time()
+
+    # Preprocess
+    inputs = whisper_proc(wav, sampling_rate=16000, return_tensors="pt").input_features
+    inputs = inputs.to(device)
+
+    # Forward pass
+    with torch.no_grad():
+        pred_ids = whisper_model.generate(inputs)
+
+    # Decode
+    text = whisper_proc.batch_decode(pred_ids, skip_special_tokens=True)[0]
+
+    # Convert to phonemes
+    phonemes = text_to_phoneme(text)
+
+    return phonemes.strip(), time.time() - start
+
+
+def run_model(wav):
+    start = time.time()
+
+    # Prepare input (BatchEncoding supports .to(device))
+    inputs = proc(wav, sampling_rate=16000, return_tensors="pt").to(device)
+
+    # Forward pass
+    with torch.no_grad():
+        logits = model(**inputs).logits
+
+    # Greedy decode
+    ids = torch.argmax(logits, dim=-1)
+    phonemes = proc.batch_decode(ids)[0]
+    phonemes = cmu_to_ipa(phonemes)
+
+    return phonemes.strip(), time.time() - start
+
+def run_timit(wav):
+    start = time.time()
+    # Read and process the input
+    inputs = timit_proc(wav, sampling_rate=16_000, return_tensors="pt", padding=True)
+    inputs = inputs.to(device)
+
+    # Forward pass
+    with torch.no_grad():
+        logits = timit_model(inputs.input_values, attention_mask=inputs.attention_mask).logits
+
+    # Decode id into string
+    predicted_ids = torch.argmax(logits, axis=-1)
+    phonemes = timit_proc.batch_decode(predicted_ids)
+    phonemes = "".join(phonemes)
+
+    return phonemes.strip(), time.time() - start
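For a quick end-to-end check, the runners above can be called on a single clip. Note that all processors and models are downloaded and loaded at import time, so the first import is slow; the file path below is hypothetical:

from src.utils.audio_process import load_audio
from src.utils.load_model import run_hubert_base, run_timit

wav = load_audio("samples/example.wav")   # hypothetical clip, resampled to 16 kHz mono
phonemes, seconds = run_hubert_base(wav)  # CMU-style phonemes derived via g2p-en
print(phonemes, seconds)

timit_phonemes, seconds = run_timit(wav)  # IPA-like output from the TIMIT phoneme model
print(timit_phonemes, seconds)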