File size: 7,555 Bytes
0a78f68 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import os
import torch
import json
import pandas as pd
from pathlib import Path
from safetensors.torch import load_file
from faster_whisper import WhisperModel
from google import genai
from google.genai import types
import soundfile as sf
import re
from tqdm import tqdm
import itertools
# Internal Modules
from src.config import TrainConfig
from src.chatterbox_.mtl_tts import ChatterboxMultilingualTTS
from src.chatterbox_.models.t3.t3 import T3
# ==============================================================================
# CONFIGURATION
# ==============================================================================
# Gemini API key used for MOS scoring. Prefer the GEMINI_API_KEY environment
# variable so the secret never has to be written into the source file; the
# placeholder remains the fallback so existing behavior is unchanged.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") or "INSERT_API_KEY_HERE"
# Checkpoint directory whose model.safetensors supplies the fine-tuned T3 weights.
CHECKPOINT_PATH = "./chatterbox_stage2_output/checkpoint-16"
# Voice prompt passed to every generate() call in the sweep.
REFERENCE_WAV = "/workspaces/work/Chatterbox-Finnish/GrowthMindset_Chatterbox_Dataset/wavs/growthmindset_00000.wav"
# Holdout utterance ids, matched against the first column of the metadata CSV.
# Align with evaluate_checkpoints.py
LEAN_HOLDOUT_IDS = [
    "growthmindset_00547",  # Short
    "growthmindset_00548",  # Medium/Long
    "growthmindset_00564",  # Very expressive
]
# Out-of-dataset Finnish sentences appended to the holdout transcripts.
EVERYDAY_PHRASES = [
    # Short: "Could you kindly help me with this matter?"
    "Voisitko ystävällisesti auttaa minua tämän asian kanssa?",
    # Long 1: "Today is a really beautiful day, so I thought I'd go out for a
    # walk and enjoy the sun before the evening cools down."
    "Tänään on todella kaunis päivä, joten ajattelin lähteä ulos kävelemään ja nauttimaan auringosta ennen kuin ilta viilenee.",
    # Long 2: "Good morning everyone, hopefully you've had a nice morning and
    # are ready to start a new day full of interesting challenges and successes."
    "Huomenta kaikille, toivottavasti teillä on ollut mukava aamu ja olette valmiita aloittamaan uuden päivän täynnä mielenkiintoisia haasteita ja onnistumisia.",
]
# Parameter grid: every combination (Cartesian product) of these values is
# passed as keyword arguments to engine.generate().
PARAM_GRID = {
    "repetition_penalty": [1.2, 1.5],
    "temperature": [0.7, 0.8],
    "exaggeration": [0.5, 0.6],
    "cfg_weight": [0.3, 0.5],
}
# Directory for the generated WAV files and the JSON sweep summaries.
OUTPUT_BASE_DIR = "./param_sweep_results"
# ==============================================================================
def setup_gemini():
    """Create the Gemini client used for MOS scoring.

    Prints a warning when ``GEMINI_API_KEY`` is empty or still the
    placeholder. In that case MOS requests will presumably fail, and
    ``get_mos_score`` reports a score of 0 on any failure, so the sweep
    would otherwise rank combinations on WER alone without saying so.

    Returns:
        A ``genai.Client`` configured with ``GEMINI_API_KEY``.
    """
    if not GEMINI_API_KEY or GEMINI_API_KEY == "INSERT_API_KEY_HERE":
        print("Warning: GEMINI_API_KEY is not set; MOS scores are likely to all be 0.")
    return genai.Client(api_key=GEMINI_API_KEY)
def get_mos_score(client, audio_path, target_text):
    """Ask Gemini to rate a synthesised clip and return its JSON verdict.

    Uploads ``audio_path`` through ``client.files``. It then polls for up to
    ~10 seconds until the upload is ACTIVE, and fails fast if it becomes
    FAILED. Finally it asks the model to rate naturalness, clarity and Finnish
    prosody on a 1-5 scale, given that the clip should say ``target_text``.
    The uploaded file is deleted afterwards on a best-effort basis.

    Args:
        client: a ``google.genai`` client.
        audio_path: path of the audio file to rate.
        target_text: the sentence the TTS model was asked to say.

    Returns:
        The JSON object returned by the model. Its ``"mos"`` value is coerced
        to ``float``, and it normally also contains ``"reason"``. On any
        failure the error is printed and ``{"mos": 0}`` is returned, so
        callers can treat a score of 0 as "no rating".
    """
    import time

    audio_file = None
    try:
        audio_file = client.files.upload(file=audio_path)
        # Uploads are processed asynchronously; wait until the file is usable.
        for _ in range(10):
            file_info = client.files.get(name=audio_file.name)
            if file_info.state == "ACTIVE":
                break
            if file_info.state == "FAILED":
                raise RuntimeError(f"Gemini failed to process upload {audio_file.name}")
            time.sleep(1)
        prompt = f"""
Olet asiantunteva puheenlaadun arvioija.
Arvioi oheinen äänitiedosto, jossa hienoviritetty TTS-malli sanoo: "{target_text}"
Arvioi asteikolla 1-5 (1=huono, 5=erinomainen):
1. Luonnollisuus: Kuulostaako se ihmiseltä?
2. Selkeys: Ovatko sanat helposti erotettavissa?
3. Prosodia: Kuulostaako rytmi luonnolliselta suomen kielelle?
Vastaa TARKALLEEN tässä JSON-muodossa: {{"mos": <numero>, "reason": "<lyhyt_perustelu>"}}
"""
        response = client.models.generate_content(
            model='gemini-3-flash-preview',
            contents=[prompt, audio_file],
            config=types.GenerateContentConfig(response_mime_type="application/json"),
        )
        result = json.loads(response.text)
        # The model sometimes wraps the object in a one-element list.
        if isinstance(result, list):
            result = result[0]
        if not isinstance(result, dict):
            raise ValueError(f"unexpected MOS response: {response.text!r}")
        # Normalize so callers can compare numerically ("4" or 4 -> 4.0).
        result["mos"] = float(result.get("mos", 0))
        return result
    except Exception as exc:
        # Best-effort: a single failed rating must not abort a long sweep,
        # but the cause should be visible instead of silently becoming 0.
        print(f"Warning: MOS scoring failed for {audio_path}: {exc}")
        return {"mos": 0}
    finally:
        if audio_file is not None:
            try:
                client.files.delete(name=audio_file.name)
            except Exception:
                pass  # cleanup is best-effort; the score is already decided
def calculate_wer(reference, hypothesis):
    """Return the word error rate of ``hypothesis`` against ``reference``.

    Both strings are lower-cased and stripped of punctuation before
    comparison. Casing and punctuation differences between the script and
    the ASR transcript therefore do not count as errors, on either code path.
    ``jiwer`` is used when it is installed. Otherwise the same metric is
    computed with a word-level Levenshtein distance:
    ``(substitutions + deletions + insertions) / number_of_reference_words``.

    Returns:
        0.0 for an empty reference, 1.0 for an empty hypothesis against a
        non-empty reference, and otherwise the WER. The WER may exceed 1.0
        when the hypothesis contains many inserted words.
    """
    def clean(t):
        return re.sub(r'[^\w\s]', '', t.lower()).strip()

    ref_words = clean(reference).split()
    hyp_words = clean(hypothesis).split()
    # Guard the degenerate cases up front: jiwer may raise on empty input.
    if not ref_words:
        return 0.0
    if not hyp_words:
        return 1.0

    try:
        import jiwer
    except ImportError:
        jiwer = None
    if jiwer is not None:
        return float(jiwer.wer(" ".join(ref_words), " ".join(hyp_words)))

    # Word-level edit distance, keeping only the previous DP row.
    prev_row = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        row = [i] + [0] * len(hyp_words)
        for j, hyp_word in enumerate(hyp_words, start=1):
            row[j] = min(
                prev_row[j] + 1,                            # deletion
                row[j - 1] + 1,                             # insertion
                prev_row[j - 1] + (ref_word != hyp_word),   # substitution / match
            )
        prev_row = row
    return prev_row[-1] / len(ref_words)
def _load_sweep_sentences(csv_path):
    """Return the holdout transcripts (in file order) followed by EVERYDAY_PHRASES."""
    # quoting=3 is csv.QUOTE_NONE: quote characters in transcripts are kept literally.
    meta = pd.read_csv(csv_path, sep="|", header=None, quoting=3)
    lean_meta = meta[meta[0].isin(LEAN_HOLDOUT_IDS)]
    # Report missing ids instead of silently shrinking the sweep.
    missing = sorted(set(LEAN_HOLDOUT_IDS) - set(lean_meta[0]))
    if missing:
        print(f"Warning: holdout ids not found in {csv_path}: {missing}")
    return list(lean_meta[1]) + EVERYDAY_PHRASES


def _load_engine(model_dir, checkpoint_path, device):
    """Load the base TTS engine and replace its T3 weights with the checkpoint's."""
    engine = ChatterboxMultilingualTTS.from_local(model_dir, device=device)
    weights_path = Path(checkpoint_path) / "model.safetensors"
    checkpoint_state = load_file(str(weights_path))
    # Checkpoint keys may carry a "t3." prefix; strip it to match engine.t3.
    t3_state_dict = {k[3:] if k.startswith("t3.") else k: v for k, v in checkpoint_state.items()}
    if "text_emb.weight" in t3_state_dict:
        # Rebuild T3 with the checkpoint's text-vocabulary size so the
        # embedding weights fit (the fine-tune may have changed the vocabulary).
        engine.t3.hp.text_tokens_dict_size = t3_state_dict["text_emb.weight"].shape[0]
        engine.t3 = T3(hp=engine.t3.hp).to(device)
    load_result = engine.t3.load_state_dict(t3_state_dict, strict=False)
    # strict=False ignores key mismatches; surface them so a wrong checkpoint is noticed.
    missing = list(getattr(load_result, "missing_keys", []))
    unexpected = list(getattr(load_result, "unexpected_keys", []))
    if missing:
        print(f"Warning: {len(missing)} T3 keys missing from checkpoint, e.g. {missing[:5]}")
    if unexpected:
        print(f"Warning: {len(unexpected)} unexpected checkpoint keys, e.g. {unexpected[:5]}")
    engine.t3.eval()
    return engine


def _evaluate_combination(trial_id, params, sentences, engine, whisper_model, gemini_client):
    """Synthesise every sentence with ``params``, then score the clips by WER and MOS.

    Clips are written to OUTPUT_BASE_DIR. MOS values <= 0 (failed ratings) are
    excluded from the MOS average; every sentence counts toward the WER average.
    """
    # Key params in the filename make manual listening review easy.
    param_str = (
        f"rp{params['repetition_penalty']}_temp{params['temperature']}"
        f"_ex{params['exaggeration']}_cfg{params['cfg_weight']}"
    )
    per_sentence = []
    for j, text in enumerate(sentences):
        wav_tensor = engine.generate(
            text=text,
            language_id="fi",
            audio_prompt_path=REFERENCE_WAV,
            **params,
        )
        audio_path = os.path.join(OUTPUT_BASE_DIR, f"trial_{trial_id}_sent_{j}_{param_str}.wav")
        sf.write(audio_path, wav_tensor.squeeze().cpu().numpy(), engine.sr)

        # WER: transcribe the clip back and compare it with the target text.
        segments, _ = whisper_model.transcribe(audio_path, language="fi")
        hyp = " ".join(s.text.strip() for s in segments)
        wer = calculate_wer(text, hyp)

        # MOS: treat a missing or non-numeric score as a failed rating (0).
        mos_data = get_mos_score(gemini_client, audio_path, text)
        try:
            mos = float(mos_data.get("mos", 0))
        except (AttributeError, TypeError, ValueError):
            mos = 0.0
        reason = mos_data.get("reason") if isinstance(mos_data, dict) else None

        per_sentence.append({
            "audio_path": audio_path,
            "text": text,
            "hypothesis": hyp,
            "wer": wer,
            "mos": mos,
            "mos_reason": reason,
        })

    valid_mos = [s["mos"] for s in per_sentence if s["mos"] > 0]
    if not valid_mos:
        print("Warning: no valid MOS scores for this combination; avg_mos is 0.")
    avg_wer = sum(s["wer"] for s in per_sentence) / len(per_sentence) if per_sentence else 0.0
    avg_mos = sum(valid_mos) / len(valid_mos) if valid_mos else 0.0
    return {
        "trial_id": trial_id,
        "params": params,
        "avg_wer": avg_wer,
        "avg_mos": avg_mos,
        "sentences": per_sentence,
    }


def _write_json(filename, data):
    """Write ``data`` as indented UTF-8 JSON to OUTPUT_BASE_DIR/filename."""
    with open(os.path.join(OUTPUT_BASE_DIR, filename), "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)


def main():
    """Sweep the TTS decoding parameters in PARAM_GRID and report the best set.

    Every combination is used to synthesise the holdout and everyday
    sentences. The clips are scored with Faster Whisper WER and Gemini MOS.
    Intermediate results go to ``sweep_summary_partial.json`` after each
    combination, and the final results to ``sweep_summary.json``. The best
    combination maximises ``MOS * (1 - WER)``.
    """
    cfg = TrainConfig()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs(OUTPUT_BASE_DIR, exist_ok=True)

    sweep_sentences = _load_sweep_sentences(cfg.csv_path)

    print("Loading Faster Whisper...")
    whisper_model = WhisperModel("large-v3", device=device, compute_type="float16" if device == "cuda" else "int8")
    gemini_client = setup_gemini()
    # Load the engine and checkpoint weights once for the whole sweep.
    engine = _load_engine(cfg.model_dir, CHECKPOINT_PATH, device)

    # Cartesian product of all grid values, as keyword-argument dicts.
    keys = list(PARAM_GRID)
    combinations = [dict(zip(keys, combo)) for combo in itertools.product(*PARAM_GRID.values())]
    print(f"Starting sweep of {len(combinations)} combinations using {len(sweep_sentences)} sentences...")

    sweep_results = []
    for i, params in enumerate(combinations):
        print(f"\n[{i+1}/{len(combinations)}] Testing: {params}")
        result_entry = _evaluate_combination(i, params, sweep_sentences, engine, whisper_model, gemini_client)
        sweep_results.append(result_entry)
        print(f"Result: WER={result_entry['avg_wer']:.4f}, MOS={result_entry['avg_mos']:.2f}")
        # Save after every combination so a crash does not lose finished trials.
        _write_json("sweep_summary_partial.json", sweep_results)

    if not sweep_results:
        print("No parameter combinations were evaluated.")
        return

    # Low WER and high MOS are both wanted; score = MOS * (1 - WER).
    # max() always returns a result (the first on ties); a fixed -1 threshold
    # could leave none when WER > 1 drives every score below it.
    best = max(sweep_results, key=lambda r: r["avg_mos"] * (1 - r["avg_wer"]))

    print("\n" + "=" * 60)
    print("SWEEP COMPLETE")
    print(f"Best Params: {best['params']}")
    print(f"Best Metrics: WER={best['avg_wer']:.4f}, MOS={best['avg_mos']:.2f}")
    print("=" * 60)
    _write_json("sweep_summary.json", sweep_results)
if __name__ == "__main__":
    # Start the sweep only when executed as a script, not when imported.
    main()
|