|
|
|
|
|
""" |
|
|
Main script for ensemble annotation. |
|
|
|
|
|
Usage: |
|
|
python annotate_ensemble.py --input dataset_name --mode balanced --output results.parquet |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import logging |
|
|
import sys |
|
|
from pathlib import Path |
|
|
import pandas as pd |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
# Put the directory three levels above this script at the front of the import
# path (presumably the repository root that holds the ``ensemble_tts``
# package -- confirm against the project layout).
_project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(_project_root))
|
|
|
|
|
from ensemble_tts.models.emotion import EmotionEnsemble |
|
|
|
|
|
# Root logging configuration for the script: INFO and above, with each record
# rendered as "<timestamp> - <level> - <message>".
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used for all progress and error reporting in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the annotation script.

    Returns:
        A parser exposing ``--input`` (required), ``--mode``, ``--output``,
        ``--device``, ``--voting`` and ``--max-samples``.
    """
    parser = argparse.ArgumentParser(description="Ensemble TTS Annotation")

    parser.add_argument('--input', type=str, required=True,
                        help='Input dataset name or path')
    parser.add_argument('--mode', type=str, default='balanced',
                        choices=['quick', 'balanced', 'full'],
                        help='Ensemble mode (quick=2 models, balanced=3, full=5)')
    parser.add_argument('--output', type=str, default='data/annotated/ensemble_results.parquet',
                        help='Output file path')
    parser.add_argument('--device', type=str, default='cpu',
                        choices=['cpu', 'cuda'],
                        help='Device to use (cpu or cuda)')
    parser.add_argument('--voting', type=str, default='weighted',
                        choices=['majority', 'weighted', 'confidence'],
                        help='Voting strategy')
    parser.add_argument('--max-samples', type=int, default=None,
                        help='Maximum number of samples to process (for testing)')

    return parser


def _extract_audio(sample):
    """Return ``(audio_array, sample_rate)`` for a sample, or ``None``.

    ``None`` is returned when the sample's ``'audio'`` entry is missing, is
    not a dict, or carries no waveform under ``'array'``. The sampling rate
    falls back to 16000 when ``'sampling_rate'`` is absent.
    """
    audio_data = sample.get('audio')
    if not isinstance(audio_data, dict):
        return None

    audio_array = audio_data.get('array')
    # Use len() instead of truthiness: the waveform may be a NumPy array
    # (assumption -- depends on the dataset), and truth-testing a
    # multi-element array raises ValueError.
    if audio_array is None or len(audio_array) == 0:
        return None

    return audio_array, audio_data.get('sampling_rate', 16000)


def _log_statistics(df: pd.DataFrame) -> None:
    """Log summary statistics for the annotated samples in ``df``.

    When ``df`` is empty the per-column statistics are skipped: a frame built
    from an empty result list has no ``emotion_*`` columns, so indexing them
    would raise ``KeyError``.
    """
    logger.info("\n" + "="*60)
    logger.info("STATISTICS")
    logger.info("="*60)
    logger.info(f"Total samples: {len(df)}")

    if df.empty:
        logger.warning("No samples were annotated; skipping statistics")
        return

    logger.info(f"Average confidence: {df['emotion_confidence'].mean():.3f}")
    logger.info(f"Average agreement: {df['emotion_agreement'].mean():.3f}")
    logger.info("\nEmotion distribution:")
    logger.info(df['emotion_label'].value_counts())


def main():
    """Run the ensemble annotation pipeline from the command line.

    Steps: parse arguments, build the ``EmotionEnsemble``, load its models,
    load the ``train`` split of the input dataset, run ``ensemble.predict`` on
    every sample with audio, write the results to a parquet file and log
    summary statistics.

    Exits the process with status 1 if the models or the dataset fail to
    load. Failures on individual samples are logged and the sample skipped.
    """
    args = _build_arg_parser().parse_args()

    logger.info("="*60)
    logger.info("ENSEMBLE TTS ANNOTATION")
    logger.info("="*60)
    logger.info(f"Mode: {args.mode}")
    logger.info(f"Device: {args.device}")
    logger.info(f"Voting: {args.voting}")

    logger.info("\n[1/4] Initializing Ensemble...")
    ensemble = EmotionEnsemble(
        mode=args.mode,
        device=args.device,
        voting_strategy=args.voting
    )

    logger.info("\n[2/4] Loading Models...")
    try:
        ensemble.load_models()
    except Exception as e:
        logger.error(f"Failed to load models: {e}")
        logger.info("\nPlease ensure all dependencies are installed:")
        logger.info("  pip install -r requirements.txt")
        sys.exit(1)

    logger.info(f"\n[3/4] Loading Dataset: {args.input}")
    try:
        # Imported here so a missing `datasets` package is reported through
        # the same "Failed to load dataset" path as any other load error.
        from datasets import load_dataset
        dataset = load_dataset(args.input, split='train')

        if args.max_samples:
            dataset = dataset.select(range(min(args.max_samples, len(dataset))))

        logger.info(f"Dataset loaded: {len(dataset)} samples")
    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        sys.exit(1)

    logger.info("\n[4/4] Annotating...")
    results = []

    for idx, sample in enumerate(tqdm(dataset, desc="Processing")):
        try:
            audio = _extract_audio(sample)
            if audio is None:
                logger.warning(f"Sample {idx}: No audio data")
                continue
            audio_array, sample_rate = audio

            prediction = ensemble.predict(audio_array, sample_rate)

            results.append({
                'index': idx,
                'text': sample.get('text', ''),
                'emotion_label': prediction['label'],
                'emotion_confidence': prediction['confidence'],
                'emotion_agreement': prediction['agreement'],
                # Stored as a string so the votes mapping fits in one column.
                'emotion_votes': str(prediction.get('votes', {})),
                'num_models': len(prediction.get('predictions', []))
            })
        except Exception as e:
            # Best effort: one bad sample must not abort the whole run.
            logger.error(f"Error processing sample {idx}: {e}")
            continue

    logger.info(f"\nSaving results to: {args.output}")
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    df = pd.DataFrame(results)
    df.to_parquet(output_path, index=False)

    logger.info(f"✅ Saved {len(df)} annotated samples")

    _log_statistics(df)

    logger.info("\n✅ DONE!")
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|