ensemble-tts-annotation / scripts /test /test_real_audio.py
marcosremar
🚀 SkyPilot Multi-Cloud GPU Support + Synthetic Data Generation
13e402e
"""
Test ensemble annotation with real/synthetic audio files.
This script tests the complete annotation pipeline with actual audio,
validating both emotion and event detection.
"""
import logging
import argparse
from pathlib import Path
import sys
# Make the repository root importable so ``ensemble_tts`` resolves without
# installing the package. ``.parent.parent.parent`` walks from this file's
# folder (scripts/test/) up to scripts/ and then to the repo root, presumably
# where the ``ensemble_tts`` package lives. This must run before the
# ``ensemble_tts`` import below.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from ensemble_tts import EnsembleAnnotator
import numpy as np  # NOTE(review): ``np`` is not referenced elsewhere in this script; candidate for removal.
import soundfile as sf
# Bare-message INFO logging: this script reports its results through the log stream.
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)
def test_single_audio(annotator: EnsembleAnnotator, audio_path: Path):
    """Annotate a single audio file and log the emotion and event results.

    Multi-channel files are downmixed to mono (mean over channels) before
    annotation, so the annotator always receives a 1-D signal.

    Args:
        annotator: An ``EnsembleAnnotator`` whose models are already loaded.
        audio_path: Path to an audio file readable by ``soundfile``.

    Returns:
        The result object returned by ``annotator.annotate``.
    """
    logger.info(f"\n🎡 Testing: {audio_path.name}")
    logger.info("=" * 60)

    # soundfile returns shape (frames,) for mono files and
    # (frames, channels) for multi-channel files.
    audio, sr = sf.read(audio_path)
    if audio.ndim > 1:
        # NOTE(review): assumes the ensemble models expect mono input;
        # confirm against EnsembleAnnotator.annotate.
        audio = audio.mean(axis=1)
    logger.info(f" Audio: {len(audio)/sr:.2f}s, {sr}Hz")

    result = annotator.annotate(audio, sample_rate=sr)

    emotion = result['emotion']
    logger.info("\n πŸ“Š Emotion Results:")
    logger.info(f" Label: {emotion['label']}")
    logger.info(f" Confidence: {emotion['confidence']:.2%}")

    # The per-model breakdown is optional in the result.
    if 'predictions' in emotion:
        logger.info("\n Individual model predictions:")
        for pred in emotion['predictions']:
            logger.info(f" {pred['model_name']:15s}: {pred['label']:10s} ({pred.get('confidence', 0.0):.2%})")

    # Events may be absent (e.g. event detection disabled) or empty.
    events = result.get('events')
    if events and events.get('detected'):
        logger.info("\n 🎭 Events Detected:")
        for event in events['detected']:
            logger.info(f" - {event}")

    return result


# This file matches pytest's ``test_*.py`` pattern and this helper is named
# ``test_*``. Without this flag pytest would collect it as a test and error
# out looking for ``annotator`` / ``audio_path`` fixtures.
test_single_audio.__test__ = False
def test_dataset_sample(annotator: EnsembleAnnotator, dataset_path: Path, n_samples: int = 5,
                        seed=None):
    """Annotate a random sample of a prepared dataset and report accuracy.

    Each sampled row must provide ``audio`` (a mapping with ``array`` and
    ``sampling_rate``) and ``emotion``. The predicted emotion label is
    compared with ``emotion`` by equality.

    Args:
        annotator: An ``EnsembleAnnotator`` whose models are already loaded.
        dataset_path: Directory readable by ``datasets.load_from_disk``.
        n_samples: Maximum number of rows to test. It is clamped to the
            dataset size, and values <= 0 test nothing.
        seed: Optional seed for row selection, for reproducible runs.
            ``None`` (the default) draws a fresh random sample each call.

    Returns:
        One dict per tested row with keys ``true``, ``predicted``,
        ``confidence`` and ``correct``.
    """
    import random

    from datasets import ClassLabel, load_from_disk

    logger.info(f"\nπŸ“¦ Loading dataset from: {dataset_path}")
    dataset = load_from_disk(str(dataset_path))
    logger.info(f" Total samples: {len(dataset)}")

    # Clamp to [0, len(dataset)]: random.sample raises ValueError for a
    # negative k or a k larger than the population. Progress messages
    # should also report the number actually tested.
    n_to_test = max(0, min(n_samples, len(dataset)))
    logger.info(f" Testing {n_to_test} random samples...")

    # A private RNG keeps seeding local to this call instead of mutating the
    # global ``random`` state.
    rng = random.Random(seed)
    indices = rng.sample(range(len(dataset)), n_to_test)

    # If ``emotion`` is a ClassLabel column, rows hold integer class ids.
    # Decode them so they can match the annotator's label (presumably a string).
    emotion_feature = dataset.features.get('emotion')

    results = []
    for i, idx in enumerate(indices, 1):
        sample = dataset[idx]
        audio_array = sample['audio']['array']
        sr = sample['audio']['sampling_rate']
        true_emotion = sample['emotion']
        if isinstance(emotion_feature, ClassLabel) and isinstance(true_emotion, int):
            true_emotion = emotion_feature.int2str(true_emotion)

        logger.info(f"\n{'='*60}")
        logger.info(f"Sample {i}/{n_to_test} - True emotion: {true_emotion}")
        logger.info(f"{'='*60}")

        result = annotator.annotate(audio_array, sample_rate=sr)
        predicted_emotion = result['emotion']['label']
        confidence = result['emotion']['confidence']
        is_correct = predicted_emotion == true_emotion

        logger.info(f" Predicted: {predicted_emotion} ({confidence:.2%})")
        if is_correct:
            logger.info(" βœ… CORRECT")
        else:
            logger.info(f" ❌ INCORRECT (expected: {true_emotion})")

        results.append({
            'true': true_emotion,
            'predicted': predicted_emotion,
            'confidence': confidence,
            'correct': is_correct,
        })

    correct = sum(r['correct'] for r in results)

    logger.info(f"\n{'='*60}")
    logger.info("πŸ“Š TEST SUMMARY")
    logger.info(f"{'='*60}")
    logger.info(f" Samples tested: {len(results)}")
    logger.info(f" Correct: {correct}")
    if results:
        logger.info(f" Accuracy: {correct / len(results):.2%}")
    else:
        # Nothing was tested (empty dataset or n_samples <= 0); avoid
        # dividing by zero.
        logger.info(" Accuracy: n/a")
    logger.info(f"{'='*60}")
    return results


# This file matches pytest's ``test_*.py`` pattern and this helper is named
# ``test_*``. Without this flag pytest would collect it as a test and error
# out looking for an ``annotator`` fixture.
test_dataset_sample.__test__ = False
def main(argv=None):
    """Parse CLI arguments, build the ensemble annotator and run one test.

    Annotates the single file given by ``--audio`` if one is given.
    Otherwise it annotates a random sample of the prepared dataset at
    ``--dataset``. Input paths are validated before the potentially slow
    model-loading step, so a mistyped path fails immediately.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, as before.

    Returns:
        Process exit code: 0 on success, 1 when the audio file or dataset
        does not exist. Invalid arguments (including ``--samples < 1``)
        exit with argparse's usage error (status 2).
    """
    parser = argparse.ArgumentParser(description="Test annotation with real audio")
    parser.add_argument("--mode", type=str, default="quick",
                        choices=["quick", "balanced", "full"],
                        help="Ensemble mode")
    parser.add_argument("--device", type=str, default="cpu",
                        choices=["cpu", "cuda"],
                        help="Device to use")
    parser.add_argument("--audio", type=str, default=None,
                        help="Path to single audio file")
    parser.add_argument("--dataset", type=str, default="data/prepared/synthetic_prepared",
                        help="Path to prepared dataset")
    parser.add_argument("--samples", type=int, default=5,
                        help="Number of dataset samples to test")
    parser.add_argument("--no-events", action="store_true",
                        help="Disable event detection")
    args = parser.parse_args(argv)
    # Originally 0 crashed with ZeroDivisionError and negatives with
    # ValueError deep inside the run; reject them up front instead.
    if args.samples < 1:
        parser.error("--samples must be a positive integer")

    logger.info("\n" + "="*60)
    logger.info("🎯 Ensemble Audio Annotation Test")
    logger.info("="*60)
    logger.info(f" Mode: {args.mode}")
    logger.info(f" Device: {args.device}")
    logger.info(f" Events: {'disabled' if args.no_events else 'enabled'}")

    # Validate inputs before creating the annotator and loading models.
    # --audio takes precedence over --dataset.
    audio_path = Path(args.audio) if args.audio else None
    dataset_path = Path(args.dataset)
    if audio_path is not None:
        if not audio_path.exists():
            logger.error(f"❌ Audio file not found: {audio_path}")
            return 1
    elif not dataset_path.exists():
        logger.error(f"❌ Dataset not found: {args.dataset}")
        logger.error("\nCreate synthetic dataset first:")
        logger.error(" python scripts/data/create_synthetic_test_data.py")
        return 1

    logger.info("\nπŸ“¦ Creating annotator...")
    annotator = EnsembleAnnotator(
        mode=args.mode,
        device=args.device,
        enable_events=not args.no_events,
    )

    logger.info("πŸ“₯ Loading models...")
    annotator.load_models()
    logger.info("βœ… Models loaded!")

    if audio_path is not None:
        test_single_audio(annotator, audio_path)
    else:
        test_dataset_sample(annotator, dataset_path, args.samples)

    logger.info("\nβœ… Test complete!")
    return 0
if __name__ == "__main__":
    # Hand main()'s integer return value to the interpreter as the process
    # exit status (same effect as sys.exit(main())).
    raise SystemExit(main())