lataon committed
Commit 99d9342 · Parent(s): 45089ef

add: new model, ds

app.py CHANGED
@@ -36,6 +36,14 @@ def load_results(results_dir: str) -> pd.DataFrame:
     rows = []
     all_dataset_keys = set()
 
+    def round_two_decimals(value):
+        try:
+            if value is None:
+                return None
+            return round(float(value), 2)
+        except Exception:
+            return value
+
     if not os.path.isdir(results_dir):
         return pd.DataFrame(columns=["Model", "Avg PER", "Avg Duration (s)"])
 
@@ -72,7 +80,7 @@ def load_results(results_dir: str) -> pd.DataFrame:
         dur_value = dataset_data.get("avg_duration") if dataset_data else None
 
         display_name = dataset_display_names[dataset_key]
-        per_values[f"PER {display_name}"] = per_value
+        per_values[f"{display_name}"] = round_two_decimals(per_value)
 
         if dur_value is not None:
             dur_values.append(dur_value)
@@ -80,9 +88,11 @@ def load_results(results_dir: str) -> pd.DataFrame:
         # Calculate average PER across all datasets
         per_vals = [v for v in per_values.values() if v is not None]
         avg_per = sum(per_vals) / len(per_vals) if per_vals else None
+        avg_per = round_two_decimals(avg_per)
 
         # Calculate average duration
         avg_dur = sum(dur_values) / len(dur_values) if dur_values else None
+        avg_dur = round_two_decimals(avg_dur)
 
         row = {
             "Model": make_clickable_model(model_name),
@@ -109,7 +119,15 @@ def load_results(results_dir: str) -> pd.DataFrame:
 
 # Load initial data
 try:
-    eval_queue_repo, requested_models, csv_results = load_all_info_from_dataset_hub()
+    # Support both legacy (3-tuple) and new (4-tuple) returns
+    hub_info = load_all_info_from_dataset_hub()
+    if isinstance(hub_info, tuple) and len(hub_info) >= 3:
+        eval_queue_repo = hub_info[0]
+        requested_models = hub_info[1]
+        csv_results = hub_info[2]
+        # Fourth value (if present) is not used in this app
+    else:
+        eval_queue_repo, requested_models, csv_results = None, None, None
     if eval_queue_repo is None or requested_models is None or csv_results is None:
         # No token provided, fallback to local results
         original_df = load_results(EVAL_RESULTS_DIR)
@@ -143,20 +161,6 @@ except Exception as e:
     # Fallback to local results
     original_df = load_results(EVAL_RESULTS_DIR)
 
-# If no data is loaded, create a sample empty dataframe with proper columns
-if original_df.empty:
-    print("No results found. Creating empty dataframe with sample data...")
-    # Create sample data to demonstrate the interface
-    sample_data = {
-        "Model": [make_clickable_model("sample/hubert-base"), make_clickable_model("sample/whisper-base")],
-        "Average PER ⬇️": [15.2, 18.5],
-        "Avg Duration (s)": [0.12, 0.15],
-        "PER phoneme_asr": [14.8, 17.2],
-        "PER kids_phoneme_md": [15.6, 19.8]
-    }
-    original_df = pd.DataFrame(sample_data)
-    print("Sample data created for demonstration.")
-
 COLS = [c.name for c in fields(PhonemeEvalColumn)]
 TYPES = [c.type for c in fields(PhonemeEvalColumn)]
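For reference, the new `round_two_decimals` helper deliberately passes non-numeric values through unchanged rather than raising; a minimal standalone sketch (sample values are illustrative only):

```python
def round_two_decimals(value):
    try:
        if value is None:
            return None
        return round(float(value), 2)
    except Exception:
        return value  # non-numeric values (e.g. "n/a") pass through unchanged

print(round_two_decimals(78.22004335857109))  # 78.22
print(round_two_decimals(None))               # None
print(round_two_decimals("n/a"))              # n/a
```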
constants.py CHANGED
@@ -4,6 +4,8 @@ from pathlib import Path
 DIR_OUTPUT_REQUESTS = Path("requested_models")
 EVAL_REQUESTS_PATH = Path("eval_requests")
 
+FINAL_SIZE = 100
+
 ##########################
 # Text definitions #
 ##########################
@@ -64,19 +66,36 @@ P.S. We'd love to know which other models you'd like us to benchmark next. Contr
 
 Evaluating Phoneme Recognition systems requires diverse datasets with phonetic transcriptions. We use multiple datasets to obtain robust evaluation scores for each model.
 
-| Dataset | Description | Language | License |
-|---------|-------------|----------|---------|
-| phoneme_asr | General phoneme recognition dataset | English | Open |
-| kids_phoneme_md | Children's speech phoneme dataset | English | Open |
+| Dataset | Description | Language | Notes |
+|---------|-------------|----------|-------|
+| mirfan899/phoneme_asr | General phoneme recognition | English | split: train, field: phonetic |
+| mirfan899/kids_phoneme_md | Children's speech phoneme dataset | English | split: train, field: phonetic |
+| kylelovesllms/timit_asr_ipa | TIMIT phoneme transcriptions (IPA) | English | split: train, field: text |
+| openslr/librispeech_asr | LibriSpeech clean test subset | English | split: test.clean, field: text, streaming |
+| leduckhai/MultiMed | Multi-domain medical speech (English config) | English | split: test, config: English, streaming |
 
 For more details on the individual datasets and how models are evaluated, refer to our documentation.
 """
 
 LEADERBOARD_CSS = """
+#leaderboard-table {
+    max-height: 600px;
+    overflow-y: auto;
+}
+
 #leaderboard-table th .header-content {
     white-space: nowrap;
 }
 
+#leaderboard-table td:first-child {
+    min-width: 300px;
+}
+
+#phoneme-table {
+    max-height: 600px;
+    overflow-y: auto;
+}
+
 #phoneme-table th .header-content {
     white-space: nowrap;
 }
@@ -84,6 +103,10 @@ LEADERBOARD_CSS = """
 #phoneme-table th:hover {
     background-color: var(--table-row-focus);
 }
+
+#phoneme-table td:first-child {
+    min-width: 300px;
+}
 """
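The Notes column above maps directly onto `datasets.load_dataset` arguments; a hedged sketch of how these five sets could be opened (split, config, and field names are taken from the table, not re-verified against each dataset card):

```python
from datasets import load_dataset

# Small corpora: download the listed split outright
phoneme_asr = load_dataset("mirfan899/phoneme_asr", split="train")      # reference in "phonetic"
kids_md = load_dataset("mirfan899/kids_phoneme_md", split="train")      # reference in "phonetic"
timit_ipa = load_dataset("kylelovesllms/timit_asr_ipa", split="train")  # reference in "text"

# Large corpora: stream to avoid a full download, as the table notes
librispeech = load_dataset("openslr/librispeech_asr", split="test.clean", streaming=True)
multimed = load_dataset("leduckhai/MultiMed", "English", split="test", streaming=True)
```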
eval-results/{results_1759479712_HuBERT-Base.json → results_1759491458_HuBERT-Base.json} RENAMED
@@ -6,24 +6,24 @@
   },
   "results": {
     "phoneme_asr": {
-      "per": 80.73712068409569,
-      "avg_duration": 1.006052589416504
+      "per": 78.22004335857109,
+      "avg_duration": 3.3285199880599974
     },
     "kids_phoneme_md": {
-      "per": 74.8274712307235,
-      "avg_duration": 1.4053531885147095
+      "per": 79.46124268247958,
+      "avg_duration": 7.384845638275147
     },
     "timit_asr_ipa": {
-      "per": 79.21011611385504,
-      "avg_duration": 0.8184992551803589
+      "per": 80.13455092277195,
+      "avg_duration": 3.2261718797683714
     },
     "librispeech_asr": {
-      "per": 81.8414587948362,
-      "avg_duration": 2.6552599668502808
+      "per": 81.18908836624553,
+      "avg_duration": 7.476902644634247
     },
     "MultiMed": {
-      "per": 86.31836686921642,
-      "avg_duration": 2.520846700668335
+      "per": 83.5727737665735,
+      "avg_duration": 10.891806457042694
     }
   }
 }
eval-results/{results_1759479712_HuBERT-fine-tuned.json → results_1759491458_HuBERT-fine-tuned.json} RENAMED
@@ -6,24 +6,24 @@
   },
   "results": {
     "phoneme_asr": {
-      "per": 3.1765040500162365,
-      "avg_duration": 1.0928319931030273
+      "per": 2.0906059507271304,
+      "avg_duration": 3.4651901078224183
     },
     "kids_phoneme_md": {
-      "per": 13.847118841760139,
-      "avg_duration": 1.43447744846344
+      "per": 20.20195546890277,
+      "avg_duration": 7.601937489509583
     },
     "timit_asr_ipa": {
-      "per": 3.5624700539646397,
-      "avg_duration": 0.8138290405273437
+      "per": 2.6819661674832194,
+      "avg_duration": 3.3618062925338745
     },
     "librispeech_asr": {
-      "per": 2.1361935038679745,
-      "avg_duration": 2.591994023323059
+      "per": 1.6319143740707203,
+      "avg_duration": 7.760291111469269
     },
     "MultiMed": {
-      "per": 12.195454796657222,
-      "avg_duration": 2.4015810966491697
+      "per": 9.572457365078227,
+      "avg_duration": 11.040356299877168
     }
   }
 }
eval-results/{results_1759479712_LJSpeech-Gruut.json → results_1759491458_LJSpeech-Gruut.json} RENAMED
@@ -6,24 +6,24 @@
   },
   "results": {
     "phoneme_asr": {
-      "per": 28.34934978626287,
-      "avg_duration": 0.3894784927368164
+      "per": 27.635612463370368,
+      "avg_duration": 2.216831774711609
     },
     "kids_phoneme_md": {
-      "per": 62.007568280756246,
-      "avg_duration": 0.5734055519104004
+      "per": 61.80856575663577,
+      "avg_duration": 4.8097358679771425
     },
     "timit_asr_ipa": {
-      "per": 24.322912970242964,
-      "avg_duration": 0.3130455732345581
+      "per": 28.17040265355878,
+      "avg_duration": 2.08021559715271
     },
     "librispeech_asr": {
-      "per": 21.098893815003613,
-      "avg_duration": 1.034156036376953
+      "per": 20.67960537404926,
+      "avg_duration": 4.945555350780487
     },
     "MultiMed": {
-      "per": 37.90138577574676,
-      "avg_duration": 1.0464757680892944
+      "per": 31.53710463881287,
+      "avg_duration": 7.100828051567078
     }
   }
 }
eval-results/{results_1759479712_Timit.json → results_1759491458_Timit.json} RENAMED
@@ -6,24 +6,24 @@
   },
   "results": {
     "phoneme_asr": {
-      "per": 32.78310772297904,
-      "avg_duration": 1.0769179582595825
+      "per": 31.917506576464163,
+      "avg_duration": 3.4731807804107664
     },
     "kids_phoneme_md": {
-      "per": 42.393439204382865,
-      "avg_duration": 1.4808897733688355
+      "per": 44.56843086404637,
+      "avg_duration": 7.674495687484741
     },
     "timit_asr_ipa": {
-      "per": 28.852864777541704,
-      "avg_duration": 0.8038362503051758
+      "per": 33.44181535059672,
+      "avg_duration": 3.374768352508545
     },
     "librispeech_asr": {
-      "per": 28.88432664616071,
-      "avg_duration": 2.5855883836746214
+      "per": 29.537610471893803,
+      "avg_duration": 7.891125264167786
     },
     "MultiMed": {
-      "per": 42.29417929178023,
-      "avg_duration": 2.4689067125320436
+      "per": 37.45253395374299,
+      "avg_duration": 11.265925951004029
     }
   }
 }
eval-results/{results_1759479712_WavLM.json → results_1759491458_WavLM.json} RENAMED
@@ -6,24 +6,24 @@
   },
   "results": {
     "phoneme_asr": {
-      "per": 25.04219454527341,
-      "avg_duration": 1.054517960548401
+      "per": 24.631130546986757,
+      "avg_duration": 3.4335393691062928
     },
     "kids_phoneme_md": {
-      "per": 63.40875812391994,
-      "avg_duration": 1.476344680786133
+      "per": 63.661901397475695,
+      "avg_duration": 7.561313712596894
     },
     "timit_asr_ipa": {
-      "per": 22.821457511149568,
-      "avg_duration": 0.7534051895141601
+      "per": 22.054351601266735,
+      "avg_duration": 3.340735013484955
     },
     "librispeech_asr": {
-      "per": 36.13438162282092,
-      "avg_duration": 2.5621693611145018
+      "per": 32.58195540587739,
+      "avg_duration": 7.779554929733276
     },
     "MultiMed": {
-      "per": 57.01443813462704,
-      "avg_duration": 2.337135744094849
+      "per": 45.96974612462279,
+      "avg_duration": 11.072271597385406
     }
   }
 }
eval-results/{results_1759479712_Whisper.json → results_1759491458_Whisper.json} RENAMED
@@ -6,24 +6,24 @@
   },
   "results": {
     "phoneme_asr": {
-      "per": 83.44842270480702,
-      "avg_duration": 1.5802977561950684
+      "per": 78.71122630859638,
+      "avg_duration": 3.847285704612732
     },
     "kids_phoneme_md": {
-      "per": 73.97112058868787,
-      "avg_duration": 1.4796640157699585
+      "per": 77.85164413992199,
+      "avg_duration": 8.320557019710542
     },
     "timit_asr_ipa": {
-      "per": 78.25013458573484,
-      "avg_duration": 1.2946593046188355
+      "per": 80.6895957363744,
+      "avg_duration": 3.7425442838668825
     },
     "librispeech_asr": {
-      "per": 82.02327697665437,
-      "avg_duration": 1.9603740453720093
+      "per": 81.412840566159,
+      "avg_duration": 8.644328632354735
     },
     "MultiMed": {
-      "per": 77.10185035170976,
-      "avg_duration": 1.68308687210083
+      "per": 80.89067869438723,
+      "avg_duration": 11.937099692821503
     }
   }
 }
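All six results files share the schema shown above ("results" → dataset key → "per" / "avg_duration"); a minimal sketch of reducing one file to the averages the leaderboard displays (the filename is one of the renamed files above; the loop mirrors, but is not, the app's `load_results`):

```python
import json

with open("eval-results/results_1759491458_HuBERT-Base.json") as f:
    data = json.load(f)

pers = [d["per"] for d in data["results"].values() if d.get("per") is not None]
durs = [d["avg_duration"] for d in data["results"].values() if d.get("avg_duration") is not None]

print("Avg PER:", round(sum(pers) / len(pers), 2))
print("Avg Duration (s):", round(sum(durs) / len(durs), 2))
```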
utils/load_model.py CHANGED
@@ -9,6 +9,7 @@ from transformers import (
 from .cmu_process import text_to_phoneme, cmu_to_ipa, clean_cmu
 
 from dotenv import load_dotenv
+import torch.backends.cudnn as cudnn
 
 # Load environment variables from .env file
 load_dotenv()
@@ -17,6 +18,10 @@ load_dotenv()
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print("Using device:", device)
 
+# Enable faster cudnn autotuner for variable input lengths
+if device.type == "cuda":
+    cudnn.benchmark = True
+
 # === Helper: move all tensors to model device ===
 def to_device(batch, device):
     if isinstance(batch, dict):
@@ -61,9 +66,16 @@ wavlm_model = AutoModelForCTC.from_pretrained("speech31/wavlm-large-english-phon
 def run_hubert_base(wav):
     start = time.time()
     inputs = base_proc(wav, sampling_rate=16000, return_tensors="pt", padding=True).input_values
-    inputs = inputs.to(device)
+    if device.type == "cuda":
+        try:
+            inputs = inputs.pin_memory()
+        except Exception:
+            pass
+        inputs = inputs.to(device, non_blocking=True)
+    else:
+        inputs = inputs.to(device)
 
-    with torch.no_grad():
+    with torch.inference_mode():
         logits = base_model(inputs).logits
         ids = torch.argmax(logits, dim=-1)
         text = base_proc.batch_decode(ids)[0]
@@ -74,20 +86,47 @@ def run_hubert_base(wav):
 def run_whisper(wav):
     start = time.time()
 
-    # Preprocess
-    inputs = whisper_proc(wav, sampling_rate=16000, return_tensors="pt").input_features
-    inputs = inputs.to(device)
-
-    # Forward pass
-    with torch.no_grad():
-        pred_ids = whisper_model.generate(inputs)
+    inputs = whisper_proc(wav, sampling_rate=16000, return_tensors="pt")
+    input_features = inputs.input_features
+    if device.type == "cuda":
+        try:
+            input_features = input_features.pin_memory()
+        except Exception:
+            pass
+        input_features = input_features.to(device, non_blocking=True)
+    else:
+        input_features = input_features.to(device)
+    attention_mask = inputs.get("attention_mask", None)
+    gen_kwargs = {"language": "en"}
+    if attention_mask is not None:
+        if device.type == "cuda":
+            try:
+                attention_mask = attention_mask.pin_memory()
+            except Exception:
+                pass
+            gen_kwargs["attention_mask"] = attention_mask.to(device, non_blocking=True)
+        else:
+            gen_kwargs["attention_mask"] = attention_mask.to(device)
+
+    # Force English transcription and use greedy decoding with short max tokens for speed
+    try:
+        forced_ids = whisper_proc.get_decoder_prompt_ids(language="en", task="transcribe")
+    except Exception:
+        forced_ids = None
+
+    with torch.inference_mode():
+        pred_ids = whisper_model.generate(
+            input_features,
+            forced_decoder_ids=forced_ids,
+            do_sample=False,
+            num_beams=1,
+            max_new_tokens=64,
+            use_cache=True,
+            **gen_kwargs,
+        )
 
-    # Decode
     text = whisper_proc.batch_decode(pred_ids, skip_special_tokens=True)[0]
-
-    # Convert to phonemes
     phonemes = text_to_phoneme(text)
-
     return phonemes.strip(), time.time() - start
 
 
@@ -95,10 +134,18 @@ def run_model(wav):
     start = time.time()
 
     # Prepare input (BatchEncoding supports .to(device))
-    inputs = proc(wav, sampling_rate=16000, return_tensors="pt").to(device)
+    inputs = proc(wav, sampling_rate=16000, return_tensors="pt")
+    if device.type == "cuda":
+        try:
+            inputs = inputs.pin_memory()
+        except Exception:
+            pass
+        inputs = inputs.to(device, non_blocking=True)
+    else:
+        inputs = inputs.to(device)
 
     # Forward pass
-    with torch.no_grad():
+    with torch.inference_mode():
         logits = model(**inputs).logits
 
     # Greedy decode
@@ -112,10 +159,17 @@ def run_timit(wav):
     start = time.time()
     # Read and process the input
    inputs = timit_proc(wav, sampling_rate=16_000, return_tensors="pt", padding=True)
-    inputs = inputs.to(device)
+    if device.type == "cuda":
+        try:
+            inputs = inputs.pin_memory()
+        except Exception:
+            pass
+        inputs = inputs.to(device, non_blocking=True)
+    else:
+        inputs = inputs.to(device)
 
     # Forward pass
-    with torch.no_grad():
+    with torch.inference_mode():
         logits = timit_model(inputs.input_values, attention_mask=inputs.attention_mask).logits
 
     # Decode id into string
@@ -135,10 +189,18 @@ def run_gruut(wav):
         sampling_rate=16000,
         return_tensors="pt",
         padding=True
-    ).to(device)
+    )
+    if device.type == "cuda":
+        try:
+            inputs = inputs.pin_memory()
+        except Exception:
+            pass
+        inputs = inputs.to(device, non_blocking=True)
+    else:
+        inputs = inputs.to(device)
 
     # Forward pass
-    with torch.no_grad():
+    with torch.inference_mode():
         logits = gruut_model(**inputs).logits
 
     # Greedy decode → IPA phonemes
@@ -157,13 +219,21 @@ def run_wavlm_large_phoneme(wav):
         sampling_rate=16000,
         return_tensors="pt",
         padding=True
-    ).to(device)
+    )
+    if device.type == "cuda":
+        try:
+            inputs = inputs.pin_memory()
+        except Exception:
+            pass
+        inputs = inputs.to(device, non_blocking=True)
+    else:
+        inputs = inputs.to(device)
 
     input_values = inputs.input_values
     attention_mask = inputs.get("attention_mask", None)
 
     # Forward pass
-    with torch.no_grad():
+    with torch.inference_mode():
         logits = wavlm_model(input_values, attention_mask=attention_mask).logits
 
     # Greedy decode → phoneme tokens
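The pin-then-async-copy block now repeated in every runner follows one pattern: page-locked (pinned) host memory lets the host-to-device copy run with `non_blocking=True`. A sketch of how it could be factored into a single helper (the helper name is ours, not part of the commit):

```python
import torch

def copy_to_device(t, device):
    """Move a tensor (or BatchEncoding) to `device`, pinning host memory on CUDA."""
    if device.type == "cuda":
        try:
            t = t.pin_memory()  # pinning can fail for some input types; fall back silently
        except Exception:
            pass
        return t.to(device, non_blocking=True)
    return t.to(device)
```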
utils_display.py CHANGED
@@ -34,6 +34,10 @@ def make_clickable_model(model_name):
         link = "https://huggingface.co/vitouphy/wav2vec2-xls-r-300m-timit-phoneme"
     elif model_name_list[0] == "Whisper":
         link = "https://huggingface.co/openai/whisper-base"
+    elif model_name_list[0] == "WavLM":
+        link = "https://huggingface.co/speech31/wavlm-large-english-phoneme"
+    elif model_name_list[0] == "LJSpeech Gruut":
+        link = "https://huggingface.co/bookbot/wav2vec2-ljspeech-gruut"
     else:
         link = f"https://huggingface.co/{model_name}"
     return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
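As models accumulate, the `elif` chain could become a lookup table; a sketch of the same mapping as a dict (our suggestion, not part of the commit; keys match the leaderboard's display names):

```python
MODEL_LINKS = {
    "Whisper": "https://huggingface.co/openai/whisper-base",
    "WavLM": "https://huggingface.co/speech31/wavlm-large-english-phoneme",
    "LJSpeech Gruut": "https://huggingface.co/bookbot/wav2vec2-ljspeech-gruut",
}

def model_link(display_name: str, model_name: str) -> str:
    # Fall back to the model's own Hub page when no explicit mapping exists
    return MODEL_LINKS.get(display_name, f"https://huggingface.co/{model_name}")
```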