# Source: marcosremar — "Initial commit: Ensemble TTS Annotation System" (commit 06b4215)
#!/usr/bin/env python3
"""
Main script for ensemble annotation.
Usage:
python annotate_ensemble.py --input dataset_name --mode balanced --output results.parquet
"""
import argparse
import logging
import sys
from pathlib import Path
import pandas as pd
from tqdm import tqdm
# Add parent directory to path
# HACK: prepend the repository root (three levels above this file) to sys.path
# so `ensemble_tts` is importable when the script is run directly instead of
# being installed as a package.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from ensemble_tts.models.emotion import EmotionEnsemble
# Configure root logging once for the whole script: timestamped INFO-level
# messages in "time - LEVEL - message" form.
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
def _build_parser():
    """Build the CLI argument parser for the annotation script."""
    parser = argparse.ArgumentParser(description="Ensemble TTS Annotation")
    parser.add_argument('--input', type=str, required=True,
                        help='Input dataset name or path')
    parser.add_argument('--mode', type=str, default='balanced',
                        choices=['quick', 'balanced', 'full'],
                        help='Ensemble mode (quick=2 models, balanced=3, full=5)')
    parser.add_argument('--output', type=str, default='data/annotated/ensemble_results.parquet',
                        help='Output file path')
    parser.add_argument('--device', type=str, default='cpu',
                        choices=['cpu', 'cuda'],
                        help='Device to use (cpu or cuda)')
    parser.add_argument('--voting', type=str, default='weighted',
                        choices=['majority', 'weighted', 'confidence'],
                        help='Voting strategy')
    parser.add_argument('--max-samples', type=int, default=None,
                        help='Maximum number of samples to process (for testing)')
    return parser


def _load_dataset(name, max_samples):
    """Load the 'train' split of a HuggingFace dataset, optionally truncated.

    Args:
        name: Dataset name or path understood by ``datasets.load_dataset``.
        max_samples: Cap on the number of samples, or None for no cap.
            An explicit 0 yields an empty selection (was previously treated
            as "no limit" because of a truthiness check).
    """
    # Imported lazily so the script can report model-loading failures before
    # requiring the (heavy) datasets dependency.
    from datasets import load_dataset
    dataset = load_dataset(name, split='train')
    if max_samples is not None:
        dataset = dataset.select(range(min(max_samples, len(dataset))))
    return dataset


def _annotate_sample(ensemble, idx, sample):
    """Annotate one dataset sample with the ensemble.

    Returns a flat result dict, or None when the sample carries no usable
    audio. Exceptions from ``ensemble.predict`` propagate to the caller.
    """
    audio_data = sample.get('audio', {})
    if not isinstance(audio_data, dict):
        logger.warning(f"Sample {idx}: No audio data")
        return None
    audio_array = audio_data.get('array', [])
    # NOTE(review): 16000 Hz default assumed when the sample omits a rate —
    # confirm this matches the models' expected input rate.
    sample_rate = audio_data.get('sampling_rate', 16000)
    prediction = ensemble.predict(audio_array, sample_rate)
    return {
        'index': idx,
        'text': sample.get('text', ''),
        'emotion_label': prediction['label'],
        'emotion_confidence': prediction['confidence'],
        'emotion_agreement': prediction['agreement'],
        'emotion_votes': str(prediction.get('votes', {})),
        'num_models': len(prediction.get('predictions', []))
    }


def _log_statistics(df):
    """Log summary statistics for the annotated DataFrame.

    Guards against an empty frame: an empty DataFrame has no
    'emotion_confidence' column, so the mean/value_counts calls below would
    raise KeyError (this previously crashed the script when every sample
    failed to annotate).
    """
    logger.info("\n" + "=" * 60)
    logger.info("STATISTICS")
    logger.info("=" * 60)
    logger.info(f"Total samples: {len(df)}")
    if df.empty:
        logger.warning("No samples were annotated successfully.")
        return
    logger.info(f"Average confidence: {df['emotion_confidence'].mean():.3f}")
    logger.info(f"Average agreement: {df['emotion_agreement'].mean():.3f}")
    logger.info("\nEmotion distribution:")
    logger.info(df['emotion_label'].value_counts())


def main():
    """CLI entry point: load ensemble + dataset, annotate, save, report.

    Exits with status 1 when model or dataset loading fails.
    """
    args = _build_parser().parse_args()

    logger.info("=" * 60)
    logger.info("ENSEMBLE TTS ANNOTATION")
    logger.info("=" * 60)
    logger.info(f"Mode: {args.mode}")
    logger.info(f"Device: {args.device}")
    logger.info(f"Voting: {args.voting}")

    # [1/4] Initialize ensemble (construction is cheap; weights load below).
    logger.info("\n[1/4] Initializing Ensemble...")
    ensemble = EmotionEnsemble(
        mode=args.mode,
        device=args.device,
        voting_strategy=args.voting
    )

    # [2/4] Load model weights; a failure here is fatal.
    logger.info("\n[2/4] Loading Models...")
    try:
        ensemble.load_models()
    except Exception as e:
        logger.error(f"Failed to load models: {e}")
        logger.info("\nPlease ensure all dependencies are installed:")
        logger.info("  pip install -r requirements.txt")
        sys.exit(1)

    # [3/4] Load the dataset; a failure here is fatal.
    logger.info(f"\n[3/4] Loading Dataset: {args.input}")
    try:
        dataset = _load_dataset(args.input, args.max_samples)
        logger.info(f"Dataset loaded: {len(dataset)} samples")
    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        sys.exit(1)

    # [4/4] Annotate sample-by-sample; individual failures are logged and
    # skipped so one bad sample cannot abort the whole run.
    logger.info("\n[4/4] Annotating...")
    results = []
    for idx, sample in enumerate(tqdm(dataset, desc="Processing")):
        try:
            result = _annotate_sample(ensemble, idx, sample)
        except Exception as e:
            logger.error(f"Error processing sample {idx}: {e}")
            continue
        if result is not None:
            results.append(result)

    # Persist results as parquet, creating parent directories as needed.
    logger.info(f"\nSaving results to: {args.output}")
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame(results)
    df.to_parquet(output_path, index=False)
    logger.info(f"✅ Saved {len(df)} annotated samples")

    _log_statistics(df)
    logger.info("\n✅ DONE!")


if __name__ == "__main__":
    main()