add: new model, ds

Files changed:

- app.py (+20 −16)
- constants.py (+27 −4)
- eval-results/{results_1759479712_HuBERT-Base.json → results_1759491458_HuBERT-Base.json} (+10 −10)
- eval-results/{results_1759479712_HuBERT-fine-tuned.json → results_1759491458_HuBERT-fine-tuned.json} (+10 −10)
- eval-results/{results_1759479712_LJSpeech-Gruut.json → results_1759491458_LJSpeech-Gruut.json} (+10 −10)
- eval-results/{results_1759479712_Timit.json → results_1759491458_Timit.json} (+10 −10)
- eval-results/{results_1759479712_WavLM.json → results_1759491458_WavLM.json} (+10 −10)
- eval-results/{results_1759479712_Whisper.json → results_1759491458_Whisper.json} (+10 −10)
- utils/load_model.py (+91 −21)
- utils_display.py (+4 −0)
app.py  CHANGED

@@ -36,6 +36,14 @@ def load_results(results_dir: str) -> pd.DataFrame:
     rows = []
     all_dataset_keys = set()
 
+    def round_two_decimals(value):
+        try:
+            if value is None:
+                return None
+            return round(float(value), 2)
+        except Exception:
+            return value
+
     if not os.path.isdir(results_dir):
         return pd.DataFrame(columns=["Model", "Avg PER", "Avg Duration (s)"])
 
@@ -72,7 +80,7 @@ def load_results(results_dir: str) -> pd.DataFrame:
         dur_value = dataset_data.get("avg_duration") if dataset_data else None
 
         display_name = dataset_display_names[dataset_key]
-        per_values[f"{display_name}"] = per_value
+        per_values[f"{display_name}"] = round_two_decimals(per_value)
 
         if dur_value is not None:
             dur_values.append(dur_value)
@@ -80,9 +88,11 @@ def load_results(results_dir: str) -> pd.DataFrame:
     # Calculate average PER across all datasets
     per_vals = [v for v in per_values.values() if v is not None]
     avg_per = sum(per_vals) / len(per_vals) if per_vals else None
+    avg_per = round_two_decimals(avg_per)
 
     # Calculate average duration
     avg_dur = sum(dur_values) / len(dur_values) if dur_values else None
+    avg_dur = round_two_decimals(avg_dur)
 
     row = {
         "Model": make_clickable_model(model_name),
@@ -109,7 +119,15 @@ def load_results(results_dir: str) -> pd.DataFrame:
 
 # Load initial data
 try:
-    eval_queue_repo, requested_models, csv_results = load_all_info_from_dataset_hub()
+    # Support both legacy (3-tuple) and new (4-tuple) returns
+    hub_info = load_all_info_from_dataset_hub()
+    if isinstance(hub_info, tuple) and len(hub_info) >= 3:
+        eval_queue_repo = hub_info[0]
+        requested_models = hub_info[1]
+        csv_results = hub_info[2]
+        # Fourth value (if present) is not used in this app
+    else:
+        eval_queue_repo, requested_models, csv_results = None, None, None
     if eval_queue_repo is None or requested_models is None or csv_results is None:
         # No token provided, fallback to local results
         original_df = load_results(EVAL_RESULTS_DIR)
@@ -143,20 +161,6 @@ except Exception as e:
     # Fallback to local results
     original_df = load_results(EVAL_RESULTS_DIR)
 
-# If no data is loaded, create a sample empty dataframe with proper columns
-if original_df.empty:
-    print("No results found. Creating empty dataframe with sample data...")
-    # Create sample data to demonstrate the interface
-    sample_data = {
-        "Model": [make_clickable_model("sample/hubert-base"), make_clickable_model("sample/whisper-base")],
-        "Average PER ⬇️": [15.2, 18.5],
-        "Avg Duration (s)": [0.12, 0.15],
-        "PER phoneme_asr": [14.8, 17.2],
-        "PER kids_phoneme_md": [15.6, 19.8]
-    }
-    original_df = pd.DataFrame(sample_data)
-    print("Sample data created for demonstration.")
-
 COLS = [c.name for c in fields(PhonemeEvalColumn)]
 TYPES = [c.type for c in fields(PhonemeEvalColumn)]
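The `round_two_decimals` helper above is deliberately forgiving: it rounds anything numeric, passes `None` through (a model may simply have no score for a dataset), and returns any other value unchanged instead of raising. A quick standalone illustration of that contract (the function body is copied from the diff; the sample inputs are ours):

```python
def round_two_decimals(value):
    # Copied from the app.py diff above.
    try:
        if value is None:
            return None
        return round(float(value), 2)
    except Exception:
        return value

assert round_two_decimals(78.22004335857109) == 78.22  # raw PER from a results file
assert round_two_decimals(None) is None                # dataset missing for a model
assert round_two_decimals("n/a") == "n/a"              # non-numeric junk passes through
```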
constants.py  CHANGED

@@ -4,6 +4,8 @@ from pathlib import Path
 DIR_OUTPUT_REQUESTS = Path("requested_models")
 EVAL_REQUESTS_PATH = Path("eval_requests")
 
+FINAL_SIZE = 100
+
 ##########################
 #    Text definitions    #
 ##########################
@@ -64,19 +66,36 @@ P.S. We'd love to know which other models you'd like us to benchmark next. Contr
 
 Evaluating Phoneme Recognition systems requires diverse datasets with phonetic transcriptions. We use multiple datasets to obtain robust evaluation scores for each model.
 
-| Dataset | Description | Language |
-|---------|-------------|----------|
-| phoneme_asr | General phoneme recognition | English |
-| kids_phoneme_md | Children's speech phoneme dataset | English |
+| Dataset | Description | Language | Notes |
+|---------|-------------|----------|-------|
+| mirfan899/phoneme_asr | General phoneme recognition | English | split: train, field: phonetic |
+| mirfan899/kids_phoneme_md | Children's speech phoneme dataset | English | split: train, field: phonetic |
+| kylelovesllms/timit_asr_ipa | TIMIT phoneme transcriptions (IPA) | English | split: train, field: text |
+| openslr/librispeech_asr | LibriSpeech clean test subset | English | split: test.clean, field: text, streaming |
+| leduckhai/MultiMed | Multi-domain medical speech (English config) | English | split: test, config: English, streaming |
 
 For more details on the individual datasets and how models are evaluated, refer to our documentation.
 """
 
 LEADERBOARD_CSS = """
+#leaderboard-table {
+    max-height: 600px;
+    overflow-y: auto;
+}
+
 #leaderboard-table th .header-content {
     white-space: nowrap;
 }
 
+#leaderboard-table td:first-child {
+    min-width: 300px;
+}
+
+#phoneme-table {
+    max-height: 600px;
+    overflow-y: auto;
+}
+
 #phoneme-table th .header-content {
     white-space: nowrap;
 }
@@ -84,6 +103,10 @@ LEADERBOARD_CSS = """
 #phoneme-table th:hover {
     background-color: var(--table-row-focus);
 }
+
+#phoneme-table td:first-child {
+    min-width: 300px;
+}
 """
 
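The Notes column in the new dataset table maps directly onto `datasets.load_dataset` arguments. Below is a minimal sketch of how these corpora could be pulled in, assuming the Hub repo ids, splits, fields, and configs are exactly as listed and that the new `FINAL_SIZE = 100` caps how many examples are evaluated per dataset; the loader calls are illustrative, not code from this repo:

```python
from datasets import load_dataset

FINAL_SIZE = 100  # from constants.py above; assumed to cap examples per dataset

# Small corpora: load the named split eagerly and read the transcription field.
phoneme_asr = load_dataset("mirfan899/phoneme_asr", split="train")
reference = phoneme_asr[0]["phonetic"]  # field name taken from the Notes column

# Corpora marked "streaming" in the table: iterate lazily and stop early, so
# only FINAL_SIZE examples are ever downloaded and decoded.
multimed = load_dataset("leduckhai/MultiMed", "English", split="test", streaming=True)
for i, example in enumerate(multimed):
    if i >= FINAL_SIZE:
        break
    audio = example["audio"]  # transcription field for MultiMed is not given in the table
```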
eval-results/{results_1759479712_HuBERT-Base.json → results_1759491458_HuBERT-Base.json}  RENAMED

@@ -6,24 +6,24 @@
     },
     "results": {
         "phoneme_asr": {
-            "per":
-            "avg_duration":
+            "per": 78.22004335857109,
+            "avg_duration": 3.3285199880599974
         },
         "kids_phoneme_md": {
-            "per":
-            "avg_duration":
+            "per": 79.46124268247958,
+            "avg_duration": 7.384845638275147
         },
         "timit_asr_ipa": {
-            "per":
-            "avg_duration":
+            "per": 80.13455092277195,
+            "avg_duration": 3.2261718797683714
         },
         "librispeech_asr": {
-            "per": 81.
-            "avg_duration":
+            "per": 81.18908836624553,
+            "avg_duration": 7.476902644634247
         },
         "MultiMed": {
-            "per":
-            "avg_duration":
+            "per": 83.5727737665735,
+            "avg_duration": 10.891806457042694
         }
     }
 }
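The leaderboard's average-PER column is the plain mean of the per-dataset PERs (see the `avg_per` computation in the app.py diff above). A quick sanity check against the new HuBERT-Base numbers:

```python
import json

with open("eval-results/results_1759491458_HuBERT-Base.json") as f:
    results = json.load(f)["results"]

per_vals = [d["per"] for d in results.values() if d.get("per") is not None]
print(round(sum(per_vals) / len(per_vals), 2))  # 80.52 for the five PERs above
```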
eval-results/{results_1759479712_HuBERT-fine-tuned.json → results_1759491458_HuBERT-fine-tuned.json}  RENAMED

@@ -6,24 +6,24 @@
     },
     "results": {
         "phoneme_asr": {
-            "per":
-            "avg_duration":
+            "per": 2.0906059507271304,
+            "avg_duration": 3.4651901078224183
         },
         "kids_phoneme_md": {
-            "per":
-            "avg_duration":
+            "per": 20.20195546890277,
+            "avg_duration": 7.601937489509583
         },
         "timit_asr_ipa": {
-            "per":
-            "avg_duration":
+            "per": 2.6819661674832194,
+            "avg_duration": 3.3618062925338745
         },
         "librispeech_asr": {
-            "per":
-            "avg_duration":
+            "per": 1.6319143740707203,
+            "avg_duration": 7.760291111469269
         },
         "MultiMed": {
-            "per":
-            "avg_duration":
+            "per": 9.572457365078227,
+            "avg_duration": 11.040356299877168
         }
     }
 }
eval-results/{results_1759479712_LJSpeech-Gruut.json → results_1759491458_LJSpeech-Gruut.json}  RENAMED

@@ -6,24 +6,24 @@
     },
     "results": {
         "phoneme_asr": {
-            "per":
-            "avg_duration":
+            "per": 27.635612463370368,
+            "avg_duration": 2.216831774711609
         },
         "kids_phoneme_md": {
-            "per":
-            "avg_duration":
+            "per": 61.80856575663577,
+            "avg_duration": 4.8097358679771425
         },
         "timit_asr_ipa": {
-            "per":
-            "avg_duration":
+            "per": 28.17040265355878,
+            "avg_duration": 2.08021559715271
         },
         "librispeech_asr": {
-            "per":
-            "avg_duration":
+            "per": 20.67960537404926,
+            "avg_duration": 4.945555350780487
         },
         "MultiMed": {
-            "per":
-            "avg_duration":
+            "per": 31.53710463881287,
+            "avg_duration": 7.100828051567078
         }
     }
 }
eval-results/{results_1759479712_Timit.json → results_1759491458_Timit.json}  RENAMED

@@ -6,24 +6,24 @@
     },
     "results": {
         "phoneme_asr": {
-            "per":
-            "avg_duration":
+            "per": 31.917506576464163,
+            "avg_duration": 3.4731807804107664
         },
         "kids_phoneme_md": {
-            "per":
-            "avg_duration":
+            "per": 44.56843086404637,
+            "avg_duration": 7.674495687484741
         },
         "timit_asr_ipa": {
-            "per":
-            "avg_duration":
+            "per": 33.44181535059672,
+            "avg_duration": 3.374768352508545
         },
         "librispeech_asr": {
-            "per":
-            "avg_duration":
+            "per": 29.537610471893803,
+            "avg_duration": 7.891125264167786
         },
         "MultiMed": {
-            "per":
-            "avg_duration":
+            "per": 37.45253395374299,
+            "avg_duration": 11.265925951004029
         }
     }
 }
eval-results/{results_1759479712_WavLM.json → results_1759491458_WavLM.json}  RENAMED

@@ -6,24 +6,24 @@
     },
     "results": {
         "phoneme_asr": {
-            "per":
-            "avg_duration":
+            "per": 24.631130546986757,
+            "avg_duration": 3.4335393691062928
         },
         "kids_phoneme_md": {
-            "per": 63.
-            "avg_duration":
+            "per": 63.661901397475695,
+            "avg_duration": 7.561313712596894
         },
         "timit_asr_ipa": {
-            "per": 22.
-            "avg_duration":
+            "per": 22.054351601266735,
+            "avg_duration": 3.340735013484955
         },
         "librispeech_asr": {
-            "per":
-            "avg_duration":
+            "per": 32.58195540587739,
+            "avg_duration": 7.779554929733276
         },
         "MultiMed": {
-            "per":
-            "avg_duration":
+            "per": 45.96974612462279,
+            "avg_duration": 11.072271597385406
         }
     }
 }
eval-results/{results_1759479712_Whisper.json → results_1759491458_Whisper.json}  RENAMED

@@ -6,24 +6,24 @@
     },
     "results": {
         "phoneme_asr": {
-            "per":
-            "avg_duration":
+            "per": 78.71122630859638,
+            "avg_duration": 3.847285704612732
         },
         "kids_phoneme_md": {
-            "per":
-            "avg_duration":
+            "per": 77.85164413992199,
+            "avg_duration": 8.320557019710542
         },
         "timit_asr_ipa": {
-            "per":
-            "avg_duration":
+            "per": 80.6895957363744,
+            "avg_duration": 3.7425442838668825
         },
         "librispeech_asr": {
-            "per":
-            "avg_duration":
+            "per": 81.412840566159,
+            "avg_duration": 8.644328632354735
         },
         "MultiMed": {
-            "per":
-            "avg_duration":
+            "per": 80.89067869438723,
+            "avg_duration": 11.937099692821503
         }
     }
 }
utils/load_model.py  CHANGED

@@ -9,6 +9,7 @@ from transformers import (
 from .cmu_process import text_to_phoneme, cmu_to_ipa, clean_cmu
 
 from dotenv import load_dotenv
+import torch.backends.cudnn as cudnn
 
 # Load environment variables from .env file
 load_dotenv()
@@ -17,6 +18,10 @@ load_dotenv()
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print("Using device:", device)
 
+# Enable the cudnn autotuner (pays off most once input shapes repeat)
+if device.type == "cuda":
+    cudnn.benchmark = True
+
 # === Helper: move all tensors to model device ===
 def to_device(batch, device):
     if isinstance(batch, dict):
@@ -61,9 +66,16 @@ wavlm_model = AutoModelForCTC.from_pretrained("speech31/wavlm-large-english-phoneme")
 def run_hubert_base(wav):
     start = time.time()
     inputs = base_proc(wav, sampling_rate=16000, return_tensors="pt", padding=True).input_values
-    inputs = inputs.to(device)
+    if device.type == "cuda":
+        try:
+            inputs = inputs.pin_memory()
+        except Exception:
+            pass
+        inputs = inputs.to(device, non_blocking=True)
+    else:
+        inputs = inputs.to(device)
 
-    with torch.no_grad():
+    with torch.inference_mode():
        logits = base_model(inputs).logits
     ids = torch.argmax(logits, dim=-1)
     text = base_proc.batch_decode(ids)[0]
@@ -74,20 +86,47 @@ def run_hubert_base(wav):
 def run_whisper(wav):
     start = time.time()
 
-
-
-
-
-
-
-
+    inputs = whisper_proc(wav, sampling_rate=16000, return_tensors="pt")
+    input_features = inputs.input_features
+    if device.type == "cuda":
+        try:
+            input_features = input_features.pin_memory()
+        except Exception:
+            pass
+        input_features = input_features.to(device, non_blocking=True)
+    else:
+        input_features = input_features.to(device)
+    attention_mask = inputs.get("attention_mask", None)
+    gen_kwargs = {"language": "en"}
+    if attention_mask is not None:
+        if device.type == "cuda":
+            try:
+                attention_mask = attention_mask.pin_memory()
+            except Exception:
+                pass
+            gen_kwargs["attention_mask"] = attention_mask.to(device, non_blocking=True)
+        else:
+            gen_kwargs["attention_mask"] = attention_mask.to(device)
+
+    # Force English transcription and use greedy decoding with short max tokens for speed
+    try:
+        forced_ids = whisper_proc.get_decoder_prompt_ids(language="en", task="transcribe")
+    except Exception:
+        forced_ids = None
+
+    with torch.inference_mode():
+        pred_ids = whisper_model.generate(
+            input_features,
+            forced_decoder_ids=forced_ids,
+            do_sample=False,
+            num_beams=1,
+            max_new_tokens=64,
+            use_cache=True,
+            **gen_kwargs,
+        )
 
-    # Decode
     text = whisper_proc.batch_decode(pred_ids, skip_special_tokens=True)[0]
-
-    # Convert to phonemes
     phonemes = text_to_phoneme(text)
-
     return phonemes.strip(), time.time() - start
 
 
@@ -95,10 +134,18 @@ def run_model(wav):
     start = time.time()
 
     # Prepare input (BatchEncoding supports .to(device))
-    inputs = proc(wav, sampling_rate=16000, return_tensors="pt")
+    inputs = proc(wav, sampling_rate=16000, return_tensors="pt")
+    if device.type == "cuda":
+        try:
+            inputs = inputs.pin_memory()
+        except Exception:
+            pass
+        inputs = inputs.to(device, non_blocking=True)
+    else:
+        inputs = inputs.to(device)
 
     # Forward pass
-    with torch.no_grad():
+    with torch.inference_mode():
         logits = model(**inputs).logits
 
     # Greedy decode
@@ -112,10 +159,17 @@ def run_timit(wav):
     start = time.time()
     # Read and process the input
     inputs = timit_proc(wav, sampling_rate=16_000, return_tensors="pt", padding=True)
-    inputs = inputs.to(device)
+    if device.type == "cuda":
+        try:
+            inputs = inputs.pin_memory()
+        except Exception:
+            pass
+        inputs = inputs.to(device, non_blocking=True)
+    else:
+        inputs = inputs.to(device)
 
     # Forward pass
-    with torch.no_grad():
+    with torch.inference_mode():
         logits = timit_model(inputs.input_values, attention_mask=inputs.attention_mask).logits
 
     # Decode id into string
@@ -135,10 +189,18 @@ def run_gruut(wav):
         sampling_rate=16000,
         return_tensors="pt",
         padding=True
-    )
+    )
+    if device.type == "cuda":
+        try:
+            inputs = inputs.pin_memory()
+        except Exception:
+            pass
+        inputs = inputs.to(device, non_blocking=True)
+    else:
+        inputs = inputs.to(device)
 
     # Forward pass
-    with torch.no_grad():
+    with torch.inference_mode():
         logits = gruut_model(**inputs).logits
 
     # Greedy decode → IPA phonemes
@@ -157,13 +219,21 @@ def run_wavlm_large_phoneme(wav):
         sampling_rate=16000,
         return_tensors="pt",
         padding=True
-    )
+    )
+    if device.type == "cuda":
+        try:
+            inputs = inputs.pin_memory()
+        except Exception:
+            pass
+        inputs = inputs.to(device, non_blocking=True)
+    else:
+        inputs = inputs.to(device)
 
     input_values = inputs.input_values
     attention_mask = inputs.get("attention_mask", None)
 
     # Forward pass
-    with torch.no_grad():
+    with torch.inference_mode():
         logits = wavlm_model(input_values, attention_mask=attention_mask).logits
 
     # Greedy decode → phoneme tokens
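Every runner in this diff gains the same three-step pattern: pin host memory when CUDA is available, copy with `non_blocking=True`, and run the forward pass under `torch.inference_mode()`. A distilled sketch of that pattern, with a generic CTC model and processor standing in for the concrete ones above (illustrative, not code from this repo):

```python
import torch

def run_generic_ctc(model, processor, wav, device):
    inputs = processor(wav, sampling_rate=16000, return_tensors="pt", padding=True)
    if device.type == "cuda":
        # Pinned (page-locked) host memory lets the non_blocking copy below
        # overlap with other host work; pin_memory() is not guaranteed on every
        # object a processor returns, hence the same guard used in the diff.
        try:
            inputs = inputs.pin_memory()
        except Exception:
            pass
        inputs = inputs.to(device, non_blocking=True)
    else:
        inputs = inputs.to(device)

    # inference_mode() is a stricter no_grad(): it also skips autograd's view
    # and version-counter bookkeeping, shaving a little overhead per forward pass.
    with torch.inference_mode():
        logits = model(**inputs).logits

    ids = torch.argmax(logits, dim=-1)  # greedy decode, as in the runners above
    return processor.batch_decode(ids)[0]
```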
utils_display.py  CHANGED

@@ -34,6 +34,10 @@ def make_clickable_model(model_name):
         link = "https://huggingface.co/vitouphy/wav2vec2-xls-r-300m-timit-phoneme"
     elif model_name_list[0] == "Whisper":
         link = "https://huggingface.co/openai/whisper-base"
+    elif model_name_list[0] == "WavLM":
+        link = "https://huggingface.co/speech31/wavlm-large-english-phoneme"
+    elif model_name_list[0] == "LJSpeech Gruut":
+        link = "https://huggingface.co/bookbot/wav2vec2-ljspeech-gruut"
     else:
         link = f"https://huggingface.co/{model_name}"
     return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
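With the two new branches, the WavLM and LJSpeech Gruut rows now link to their model cards instead of falling through to the generic `huggingface.co/{model_name}` URL. Assuming the display name arrives exactly as shown on the leaderboard, the new WavLM branch yields:

```python
from utils_display import make_clickable_model

html = make_clickable_model("WavLM")
# -> '<a target="_blank" href="https://huggingface.co/speech31/wavlm-large-english-phoneme" ...>WavLM</a>'
```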