#!/usr/bin/env python3
"""
Main script for ensemble annotation.

Usage:
    python annotate_ensemble.py --input dataset_name --mode balanced --output results.parquet
"""

import argparse
import logging
import sys
from pathlib import Path

import pandas as pd
from tqdm import tqdm

# Make the package importable when this script is run directly from its folder.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from ensemble_tts.models.emotion import EmotionEnsemble

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def _parse_args():
    """Parse and return the command-line arguments for an annotation run."""
    parser = argparse.ArgumentParser(description="Ensemble TTS Annotation")
    parser.add_argument('--input', type=str, required=True,
                        help='Input dataset name or path')
    parser.add_argument('--mode', type=str, default='balanced',
                        choices=['quick', 'balanced', 'full'],
                        help='Ensemble mode (quick=2 models, balanced=3, full=5)')
    parser.add_argument('--output', type=str,
                        default='data/annotated/ensemble_results.parquet',
                        help='Output file path')
    parser.add_argument('--device', type=str, default='cpu',
                        choices=['cpu', 'cuda'],
                        help='Device to use (cpu or cuda)')
    parser.add_argument('--voting', type=str, default='weighted',
                        choices=['majority', 'weighted', 'confidence'],
                        help='Voting strategy')
    parser.add_argument('--max-samples', type=int, default=None,
                        help='Maximum number of samples to process (for testing)')
    return parser.parse_args()


def _build_ensemble(args):
    """Construct the ensemble and load its models; exit(1) if loading fails."""
    logger.info("\n[1/4] Initializing Ensemble...")
    ensemble = EmotionEnsemble(
        mode=args.mode,
        device=args.device,
        voting_strategy=args.voting
    )

    logger.info("\n[2/4] Loading Models...")
    try:
        ensemble.load_models()
    except Exception:
        logger.exception("Failed to load models")
        logger.info("\nPlease ensure all dependencies are installed:")
        logger.info("  pip install -r requirements.txt")
        sys.exit(1)
    return ensemble


def _load_input_dataset(args):
    """Load the HF dataset named by --input (optionally truncated); exit(1) on failure."""
    logger.info("\n[3/4] Loading Dataset: %s", args.input)
    try:
        # Imported lazily so the script can print argparse help without the
        # (heavy) datasets package being importable.
        from datasets import load_dataset
        dataset = load_dataset(args.input, split='train')
        if args.max_samples:
            dataset = dataset.select(range(min(args.max_samples, len(dataset))))
        logger.info("Dataset loaded: %d samples", len(dataset))
        return dataset
    except Exception:
        logger.exception("Failed to load dataset")
        sys.exit(1)


def _annotate_sample(ensemble, idx, sample):
    """Run the ensemble on one sample; return a result row dict, or None to skip.

    Samples without a dict-shaped 'audio' field are skipped with a warning.
    """
    audio_data = sample.get('audio', {})
    if not isinstance(audio_data, dict):
        logger.warning("Sample %d: No audio data", idx)
        return None
    audio_array = audio_data.get('array', [])
    # NOTE(review): 16000 assumed as the fallback sampling rate — matches the
    # common speech-model default, but confirm against the ensemble's models.
    sample_rate = audio_data.get('sampling_rate', 16000)

    prediction = ensemble.predict(audio_array, sample_rate)
    return {
        'index': idx,
        'text': sample.get('text', ''),
        'emotion_label': prediction['label'],
        'emotion_confidence': prediction['confidence'],
        'emotion_agreement': prediction['agreement'],
        # Votes stringified so the row stays parquet-serializable.
        'emotion_votes': str(prediction.get('votes', {})),
        'num_models': len(prediction.get('predictions', [])),
    }


def _annotate(ensemble, dataset):
    """Annotate every sample, logging (not raising) per-sample failures."""
    logger.info("\n[4/4] Annotating...")
    results = []
    for idx, sample in enumerate(tqdm(dataset, desc="Processing")):
        try:
            row = _annotate_sample(ensemble, idx, sample)
        except Exception:
            logger.exception("Error processing sample %d", idx)
            continue
        if row is not None:
            results.append(row)
    return results


def _save_and_report(results, output):
    """Write results to parquet and log summary statistics.

    Guards against an empty result set: a DataFrame built from an empty list
    has no columns, so the statistics lookups below would raise KeyError.
    """
    if not results:
        logger.warning("No samples were successfully annotated; nothing to save.")
        return

    logger.info("\nSaving results to: %s", output)
    output_path = Path(output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    df = pd.DataFrame(results)
    df.to_parquet(output_path, index=False)
    logger.info("āœ… Saved %d annotated samples", len(df))

    logger.info("\n" + "=" * 60)
    logger.info("STATISTICS")
    logger.info("=" * 60)
    logger.info("Total samples: %d", len(df))
    logger.info("Average confidence: %.3f", df['emotion_confidence'].mean())
    logger.info("Average agreement: %.3f", df['emotion_agreement'].mean())
    logger.info("\nEmotion distribution:")
    logger.info(df['emotion_label'].value_counts())


def main():
    """Entry point: parse args, build the ensemble, annotate, save, report."""
    args = _parse_args()

    logger.info("=" * 60)
    logger.info("ENSEMBLE TTS ANNOTATION")
    logger.info("=" * 60)
    logger.info("Mode: %s", args.mode)
    logger.info("Device: %s", args.device)
    logger.info("Voting: %s", args.voting)

    ensemble = _build_ensemble(args)
    dataset = _load_input_dataset(args)
    results = _annotate(ensemble, dataset)
    _save_and_report(results, args.output)

    logger.info("\nāœ… DONE!")


if __name__ == "__main__":
    main()