RASMUS commited on
Commit
0a78f68
·
verified ·
1 Parent(s): 9669f51

Upload sweep_params.py

Browse files
Files changed (1) hide show
  1. sweep_params.py +201 -0
sweep_params.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import json
4
+ import pandas as pd
5
+ from pathlib import Path
6
+ from safetensors.torch import load_file
7
+ from faster_whisper import WhisperModel
8
+ from google import genai
9
+ from google.genai import types
10
+ import soundfile as sf
11
+ import re
12
+ from tqdm import tqdm
13
+ import itertools
14
+
15
+ # Internal Modules
16
+ from src.config import TrainConfig
17
+ from src.chatterbox_.mtl_tts import ChatterboxMultilingualTTS
18
+ from src.chatterbox_.models.t3.t3 import T3
19
+
20
+ # ==============================================================================
21
+ # CONFIGURATION
22
+ # ==============================================================================
23
+ GEMINI_API_KEY = "INSERT_API_KEY_HERE"
24
+ CHECKPOINT_PATH = "./chatterbox_stage2_output/checkpoint-16"
25
+ REFERENCE_WAV = "/workspaces/work/Chatterbox-Finnish/GrowthMindset_Chatterbox_Dataset/wavs/growthmindset_00000.wav"
26
+
27
+ # Align with evaluate_checkpoints.py
28
+ LEAN_HOLDOUT_IDS = [
29
+ "growthmindset_00547", # Short
30
+ "growthmindset_00548", # Medium/Long
31
+ "growthmindset_00564" # Very expressive
32
+ ]
33
+
34
+ EVERYDAY_PHRASES = [
35
+ "Voisitko ystävällisesti auttaa minua tämän asian kanssa?", # Short
36
+ "Tänään on todella kaunis päivä, joten ajattelin lähteä ulos kävelemään ja nauttimaan auringosta ennen kuin ilta viilenee.", # Long 1
37
+ "Huomenta kaikille, toivottavasti teillä on ollut mukava aamu ja olette valmiita aloittamaan uuden päivän täynnä mielenkiintoisia haasteita ja onnistumisia." # Long 2
38
+ ]
39
+
40
+ # Parameter Grid
41
+ PARAM_GRID = {
42
+ "repetition_penalty": [1.2, 1.5],
43
+ "temperature": [0.7, 0.8],
44
+ "exaggeration": [0.5, 0.6],
45
+ "cfg_weight": [0.3, 0.5]
46
+ }
47
+
48
+ OUTPUT_BASE_DIR = "./param_sweep_results"
49
+ # ==============================================================================
50
+
51
+ def setup_gemini():
52
+ return genai.Client(api_key=GEMINI_API_KEY)
53
+
54
+ def get_mos_score(client, audio_path, target_text):
55
+ try:
56
+ audio_file = client.files.upload(file=audio_path)
57
+ import time
58
+ for _ in range(10):
59
+ file_info = client.files.get(name=audio_file.name)
60
+ if file_info.state == "ACTIVE": break
61
+ time.sleep(1)
62
+
63
+ prompt = f"""
64
+ Olet asiantunteva puheenlaadun arvioija.
65
+ Arvioi oheinen äänitiedosto, jossa hienoviritetty TTS-malli sanoo: "{target_text}"
66
+
67
+ Arvioi asteikolla 1-5 (1=huono, 5=erinomainen):
68
+ 1. Luonnollisuus: Kuulostaako se ihmiseltä?
69
+ 2. Selkeys: Ovatko sanat helposti erotettavissa?
70
+ 3. Prosodia: Kuulostaako rytmi luonnolliselta suomen kielelle?
71
+
72
+ Vastaa TARKALLEEN tässä JSON-muodossa: {{"mos": <numero>, "reason": "<lyhyt_perustelu>"}}
73
+ """
74
+ response = client.models.generate_content(
75
+ model='gemini-3-flash-preview',
76
+ contents=[prompt, audio_file],
77
+ config=types.GenerateContentConfig(response_mime_type="application/json")
78
+ )
79
+ result = json.loads(response.text)
80
+ if isinstance(result, list): result = result[0]
81
+ return result
82
+ except Exception:
83
+ return {"mos": 0}
84
+
85
+ def calculate_wer(reference, hypothesis):
86
+ try:
87
+ import jiwer
88
+ return jiwer.wer(reference, hypothesis)
89
+ except ImportError:
90
+ def clean(t): return re.sub(r'[^\w\s]', '', t.lower()).strip()
91
+ ref_words = clean(reference).split()
92
+ hyp_words = clean(hypothesis).split()
93
+ if not ref_words: return 0.0
94
+ import difflib
95
+ return 1.0 - difflib.SequenceMatcher(None, ref_words, hyp_words).ratio()
96
+
97
+ def main():
98
+ cfg = TrainConfig()
99
+ device = "cuda" if torch.cuda.is_available() else "cpu"
100
+ os.makedirs(OUTPUT_BASE_DIR, exist_ok=True)
101
+
102
+ # Load metadata for holdouts
103
+ meta = pd.read_csv(cfg.csv_path, sep="|", header=None, quoting=3)
104
+ lean_meta = meta[meta[0].isin(LEAN_HOLDOUT_IDS)]
105
+ sweep_sentences = list(lean_meta[1]) + EVERYDAY_PHRASES
106
+
107
+ print("Loading Faster Whisper...")
108
+ whisper_model = WhisperModel("large-v3", device=device, compute_type="float16" if device == "cuda" else "int8")
109
+
110
+ gemini_client = setup_gemini()
111
+
112
+ # Load engine and checkpoint weights once
113
+ engine = ChatterboxMultilingualTTS.from_local(cfg.model_dir, device=device)
114
+ weights_path = Path(CHECKPOINT_PATH) / "model.safetensors"
115
+ checkpoint_state = load_file(str(weights_path))
116
+ t3_state_dict = {k[3:] if k.startswith("t3.") else k: v for k, v in checkpoint_state.items()}
117
+ if "text_emb.weight" in t3_state_dict:
118
+ engine.t3.hp.text_tokens_dict_size = t3_state_dict["text_emb.weight"].shape[0]
119
+ engine.t3 = T3(hp=engine.t3.hp).to(device)
120
+ engine.t3.load_state_dict(t3_state_dict, strict=False)
121
+ engine.t3.eval()
122
+
123
+ # Generate parameter combinations
124
+ keys, values = zip(*PARAM_GRID.items())
125
+ combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]
126
+
127
+ print(f"Starting sweep of {len(combinations)} combinations using {len(sweep_sentences)} sentences...")
128
+
129
+ sweep_results = []
130
+
131
+ for i, params in enumerate(combinations):
132
+ print(f"\n[{i+1}/{len(combinations)}] Testing: {params}")
133
+
134
+ total_wer = 0
135
+ total_mos = 0
136
+ valid_mos_count = 0
137
+
138
+ for j, text in enumerate(sweep_sentences):
139
+ wav_tensor = engine.generate(
140
+ text=text,
141
+ language_id="fi",
142
+ audio_prompt_path=REFERENCE_WAV,
143
+ **params
144
+ )
145
+
146
+ # Format filename with key params for easy manual review
147
+ param_str = f"rp{params['repetition_penalty']}_temp{params['temperature']}_ex{params['exaggeration']}_cfg{params['cfg_weight']}"
148
+ audio_path = os.path.join(OUTPUT_BASE_DIR, f"trial_{i}_sent_{j}_{param_str}.wav")
149
+ sf.write(audio_path, wav_tensor.squeeze().cpu().numpy(), engine.sr)
150
+
151
+ # WER
152
+ segments, _ = whisper_model.transcribe(audio_path, language="fi")
153
+ hyp = " ".join([s.text for s in segments])
154
+ wer = calculate_wer(text, hyp)
155
+ total_wer += wer
156
+
157
+ # MOS
158
+ mos_data = get_mos_score(gemini_client, audio_path, text)
159
+ if mos_data.get('mos', 0) > 0:
160
+ total_mos += mos_data['mos']
161
+ valid_mos_count += 1
162
+
163
+ avg_wer = total_wer / len(sweep_sentences)
164
+ avg_mos = total_mos / valid_mos_count if valid_mos_count > 0 else 0
165
+
166
+ result_entry = {
167
+ "trial_id": i,
168
+ "params": params,
169
+ "avg_wer": avg_wer,
170
+ "avg_mos": avg_mos
171
+ }
172
+ sweep_results.append(result_entry)
173
+ print(f"Result: WER={avg_wer:.4f}, MOS={avg_mos:.2f}")
174
+
175
+ # Save intermediate results
176
+ with open(os.path.join(OUTPUT_BASE_DIR, "sweep_summary_partial.json"), "w") as f:
177
+ json.dump(sweep_results, f, indent=4)
178
+
179
+ # Find the best combination
180
+ # We want low WER and high MOS. A simple score: MOS * (1 - WER)
181
+ best_score = -1
182
+ best_params = None
183
+
184
+ for r in sweep_results:
185
+ score = r['avg_mos'] * (1 - r['avg_wer'])
186
+ if score > best_score:
187
+ best_score = score
188
+ best_params = r
189
+
190
+ print("\n" + "="*60)
191
+ print("SWEEP COMPLETE")
192
+ print(f"Best Params: {best_params['params']}")
193
+ print(f"Best Metrics: WER={best_params['avg_wer']:.4f}, MOS={best_params['avg_mos']:.2f}")
194
+ print("="*60)
195
+
196
+ with open(os.path.join(OUTPUT_BASE_DIR, "sweep_summary.json"), "w") as f:
197
+ json.dump(sweep_results, f, indent=4)
198
+
199
+ if __name__ == "__main__":
200
+ main()
201
+