Upload sweep_params.py
Browse files- sweep_params.py +201 -0
sweep_params.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import json
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from safetensors.torch import load_file
|
| 7 |
+
from faster_whisper import WhisperModel
|
| 8 |
+
from google import genai
|
| 9 |
+
from google.genai import types
|
| 10 |
+
import soundfile as sf
|
| 11 |
+
import re
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
import itertools
|
| 14 |
+
|
| 15 |
+
# Internal Modules
|
| 16 |
+
from src.config import TrainConfig
|
| 17 |
+
from src.chatterbox_.mtl_tts import ChatterboxMultilingualTTS
|
| 18 |
+
from src.chatterbox_.models.t3.t3 import T3
|
| 19 |
+
|
| 20 |
+
# ==============================================================================
|
| 21 |
+
# CONFIGURATION
|
| 22 |
+
# ==============================================================================
|
| 23 |
+
GEMINI_API_KEY = "INSERT_API_KEY_HERE"
|
| 24 |
+
CHECKPOINT_PATH = "./chatterbox_stage2_output/checkpoint-16"
|
| 25 |
+
REFERENCE_WAV = "/workspaces/work/Chatterbox-Finnish/GrowthMindset_Chatterbox_Dataset/wavs/growthmindset_00000.wav"
|
| 26 |
+
|
| 27 |
+
# Align with evaluate_checkpoints.py
|
| 28 |
+
LEAN_HOLDOUT_IDS = [
|
| 29 |
+
"growthmindset_00547", # Short
|
| 30 |
+
"growthmindset_00548", # Medium/Long
|
| 31 |
+
"growthmindset_00564" # Very expressive
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
EVERYDAY_PHRASES = [
|
| 35 |
+
"Voisitko ystävällisesti auttaa minua tämän asian kanssa?", # Short
|
| 36 |
+
"Tänään on todella kaunis päivä, joten ajattelin lähteä ulos kävelemään ja nauttimaan auringosta ennen kuin ilta viilenee.", # Long 1
|
| 37 |
+
"Huomenta kaikille, toivottavasti teillä on ollut mukava aamu ja olette valmiita aloittamaan uuden päivän täynnä mielenkiintoisia haasteita ja onnistumisia." # Long 2
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
# Parameter Grid
|
| 41 |
+
PARAM_GRID = {
|
| 42 |
+
"repetition_penalty": [1.2, 1.5],
|
| 43 |
+
"temperature": [0.7, 0.8],
|
| 44 |
+
"exaggeration": [0.5, 0.6],
|
| 45 |
+
"cfg_weight": [0.3, 0.5]
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
OUTPUT_BASE_DIR = "./param_sweep_results"
|
| 49 |
+
# ==============================================================================
|
| 50 |
+
|
| 51 |
+
def setup_gemini():
|
| 52 |
+
return genai.Client(api_key=GEMINI_API_KEY)
|
| 53 |
+
|
| 54 |
+
def get_mos_score(client, audio_path, target_text):
|
| 55 |
+
try:
|
| 56 |
+
audio_file = client.files.upload(file=audio_path)
|
| 57 |
+
import time
|
| 58 |
+
for _ in range(10):
|
| 59 |
+
file_info = client.files.get(name=audio_file.name)
|
| 60 |
+
if file_info.state == "ACTIVE": break
|
| 61 |
+
time.sleep(1)
|
| 62 |
+
|
| 63 |
+
prompt = f"""
|
| 64 |
+
Olet asiantunteva puheenlaadun arvioija.
|
| 65 |
+
Arvioi oheinen äänitiedosto, jossa hienoviritetty TTS-malli sanoo: "{target_text}"
|
| 66 |
+
|
| 67 |
+
Arvioi asteikolla 1-5 (1=huono, 5=erinomainen):
|
| 68 |
+
1. Luonnollisuus: Kuulostaako se ihmiseltä?
|
| 69 |
+
2. Selkeys: Ovatko sanat helposti erotettavissa?
|
| 70 |
+
3. Prosodia: Kuulostaako rytmi luonnolliselta suomen kielelle?
|
| 71 |
+
|
| 72 |
+
Vastaa TARKALLEEN tässä JSON-muodossa: {{"mos": <numero>, "reason": "<lyhyt_perustelu>"}}
|
| 73 |
+
"""
|
| 74 |
+
response = client.models.generate_content(
|
| 75 |
+
model='gemini-3-flash-preview',
|
| 76 |
+
contents=[prompt, audio_file],
|
| 77 |
+
config=types.GenerateContentConfig(response_mime_type="application/json")
|
| 78 |
+
)
|
| 79 |
+
result = json.loads(response.text)
|
| 80 |
+
if isinstance(result, list): result = result[0]
|
| 81 |
+
return result
|
| 82 |
+
except Exception:
|
| 83 |
+
return {"mos": 0}
|
| 84 |
+
|
| 85 |
+
def calculate_wer(reference, hypothesis):
|
| 86 |
+
try:
|
| 87 |
+
import jiwer
|
| 88 |
+
return jiwer.wer(reference, hypothesis)
|
| 89 |
+
except ImportError:
|
| 90 |
+
def clean(t): return re.sub(r'[^\w\s]', '', t.lower()).strip()
|
| 91 |
+
ref_words = clean(reference).split()
|
| 92 |
+
hyp_words = clean(hypothesis).split()
|
| 93 |
+
if not ref_words: return 0.0
|
| 94 |
+
import difflib
|
| 95 |
+
return 1.0 - difflib.SequenceMatcher(None, ref_words, hyp_words).ratio()
|
| 96 |
+
|
| 97 |
+
def main():
|
| 98 |
+
cfg = TrainConfig()
|
| 99 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 100 |
+
os.makedirs(OUTPUT_BASE_DIR, exist_ok=True)
|
| 101 |
+
|
| 102 |
+
# Load metadata for holdouts
|
| 103 |
+
meta = pd.read_csv(cfg.csv_path, sep="|", header=None, quoting=3)
|
| 104 |
+
lean_meta = meta[meta[0].isin(LEAN_HOLDOUT_IDS)]
|
| 105 |
+
sweep_sentences = list(lean_meta[1]) + EVERYDAY_PHRASES
|
| 106 |
+
|
| 107 |
+
print("Loading Faster Whisper...")
|
| 108 |
+
whisper_model = WhisperModel("large-v3", device=device, compute_type="float16" if device == "cuda" else "int8")
|
| 109 |
+
|
| 110 |
+
gemini_client = setup_gemini()
|
| 111 |
+
|
| 112 |
+
# Load engine and checkpoint weights once
|
| 113 |
+
engine = ChatterboxMultilingualTTS.from_local(cfg.model_dir, device=device)
|
| 114 |
+
weights_path = Path(CHECKPOINT_PATH) / "model.safetensors"
|
| 115 |
+
checkpoint_state = load_file(str(weights_path))
|
| 116 |
+
t3_state_dict = {k[3:] if k.startswith("t3.") else k: v for k, v in checkpoint_state.items()}
|
| 117 |
+
if "text_emb.weight" in t3_state_dict:
|
| 118 |
+
engine.t3.hp.text_tokens_dict_size = t3_state_dict["text_emb.weight"].shape[0]
|
| 119 |
+
engine.t3 = T3(hp=engine.t3.hp).to(device)
|
| 120 |
+
engine.t3.load_state_dict(t3_state_dict, strict=False)
|
| 121 |
+
engine.t3.eval()
|
| 122 |
+
|
| 123 |
+
# Generate parameter combinations
|
| 124 |
+
keys, values = zip(*PARAM_GRID.items())
|
| 125 |
+
combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]
|
| 126 |
+
|
| 127 |
+
print(f"Starting sweep of {len(combinations)} combinations using {len(sweep_sentences)} sentences...")
|
| 128 |
+
|
| 129 |
+
sweep_results = []
|
| 130 |
+
|
| 131 |
+
for i, params in enumerate(combinations):
|
| 132 |
+
print(f"\n[{i+1}/{len(combinations)}] Testing: {params}")
|
| 133 |
+
|
| 134 |
+
total_wer = 0
|
| 135 |
+
total_mos = 0
|
| 136 |
+
valid_mos_count = 0
|
| 137 |
+
|
| 138 |
+
for j, text in enumerate(sweep_sentences):
|
| 139 |
+
wav_tensor = engine.generate(
|
| 140 |
+
text=text,
|
| 141 |
+
language_id="fi",
|
| 142 |
+
audio_prompt_path=REFERENCE_WAV,
|
| 143 |
+
**params
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
# Format filename with key params for easy manual review
|
| 147 |
+
param_str = f"rp{params['repetition_penalty']}_temp{params['temperature']}_ex{params['exaggeration']}_cfg{params['cfg_weight']}"
|
| 148 |
+
audio_path = os.path.join(OUTPUT_BASE_DIR, f"trial_{i}_sent_{j}_{param_str}.wav")
|
| 149 |
+
sf.write(audio_path, wav_tensor.squeeze().cpu().numpy(), engine.sr)
|
| 150 |
+
|
| 151 |
+
# WER
|
| 152 |
+
segments, _ = whisper_model.transcribe(audio_path, language="fi")
|
| 153 |
+
hyp = " ".join([s.text for s in segments])
|
| 154 |
+
wer = calculate_wer(text, hyp)
|
| 155 |
+
total_wer += wer
|
| 156 |
+
|
| 157 |
+
# MOS
|
| 158 |
+
mos_data = get_mos_score(gemini_client, audio_path, text)
|
| 159 |
+
if mos_data.get('mos', 0) > 0:
|
| 160 |
+
total_mos += mos_data['mos']
|
| 161 |
+
valid_mos_count += 1
|
| 162 |
+
|
| 163 |
+
avg_wer = total_wer / len(sweep_sentences)
|
| 164 |
+
avg_mos = total_mos / valid_mos_count if valid_mos_count > 0 else 0
|
| 165 |
+
|
| 166 |
+
result_entry = {
|
| 167 |
+
"trial_id": i,
|
| 168 |
+
"params": params,
|
| 169 |
+
"avg_wer": avg_wer,
|
| 170 |
+
"avg_mos": avg_mos
|
| 171 |
+
}
|
| 172 |
+
sweep_results.append(result_entry)
|
| 173 |
+
print(f"Result: WER={avg_wer:.4f}, MOS={avg_mos:.2f}")
|
| 174 |
+
|
| 175 |
+
# Save intermediate results
|
| 176 |
+
with open(os.path.join(OUTPUT_BASE_DIR, "sweep_summary_partial.json"), "w") as f:
|
| 177 |
+
json.dump(sweep_results, f, indent=4)
|
| 178 |
+
|
| 179 |
+
# Find the best combination
|
| 180 |
+
# We want low WER and high MOS. A simple score: MOS * (1 - WER)
|
| 181 |
+
best_score = -1
|
| 182 |
+
best_params = None
|
| 183 |
+
|
| 184 |
+
for r in sweep_results:
|
| 185 |
+
score = r['avg_mos'] * (1 - r['avg_wer'])
|
| 186 |
+
if score > best_score:
|
| 187 |
+
best_score = score
|
| 188 |
+
best_params = r
|
| 189 |
+
|
| 190 |
+
print("\n" + "="*60)
|
| 191 |
+
print("SWEEP COMPLETE")
|
| 192 |
+
print(f"Best Params: {best_params['params']}")
|
| 193 |
+
print(f"Best Metrics: WER={best_params['avg_wer']:.4f}, MOS={best_params['avg_mos']:.2f}")
|
| 194 |
+
print("="*60)
|
| 195 |
+
|
| 196 |
+
with open(os.path.join(OUTPUT_BASE_DIR, "sweep_summary.json"), "w") as f:
|
| 197 |
+
json.dump(sweep_results, f, indent=4)
|
| 198 |
+
|
| 199 |
+
if __name__ == "__main__":
|
| 200 |
+
main()
|
| 201 |
+
|