"""
Quick test script for OPTION A ensemble.

Tests:
1. Model loading
2. Single audio annotation
3. Batch processing
4. Performance benchmarking
"""

import logging
import sys
import time
from pathlib import Path

import numpy as np

# Make the directory containing ensemble_tts importable when this script
# is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from ensemble_tts import EnsembleAnnotator

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def test_model_loading():
    """Test 1: Model Loading"""
    logger.info("=" * 60)
    logger.info("TEST 1: Model Loading")
    logger.info("=" * 60)

    try:
        annotator = EnsembleAnnotator(
            mode='quick',
            device='cpu',
            enable_events=False
        )

        start = time.time()
        annotator.load_models()
        elapsed = time.time() - start

        logger.info(f"✅ Models loaded successfully in {elapsed:.2f}s")
        return annotator, True
    except Exception as e:
        logger.error(f"❌ Model loading failed: {e}")
        return None, False


def test_single_annotation(annotator):
    """Test 2: Single Audio Annotation"""
    logger.info("\n" + "=" * 60)
    logger.info("TEST 2: Single Audio Annotation")
    logger.info("=" * 60)

    try:
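        # 3 seconds of Gaussian noise at 16 kHz: enough to exercise the full
        # pipeline without depending on any real recording.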
        audio = np.random.randn(16000 * 3).astype(np.float32)

        start = time.time()
        result = annotator.annotate(audio, sample_rate=16000)
        elapsed = time.time() - start

        logger.info("\n📊 Annotation Result:")
        logger.info(f" Emotion: {result['emotion']['label']}")
        logger.info(f" Confidence: {result['emotion']['confidence']:.2%}")
        logger.info(f" Agreement: {result['emotion']['agreement']:.2%}")
        logger.info(f" Votes: {result['emotion']['votes']}")
        logger.info(f" Time: {elapsed:.2f}s")

        assert 'emotion' in result
        assert 'label' in result['emotion']
        assert 'confidence' in result['emotion']
        assert 0 <= result['emotion']['confidence'] <= 1

        logger.info("\n✅ Single annotation successful")
        return True
    except Exception as e:
        logger.error(f"❌ Single annotation failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def test_batch_processing(annotator):
    """Test 3: Batch Processing"""
    logger.info("\n" + "=" * 60)
    logger.info("TEST 3: Batch Processing")
    logger.info("=" * 60)

    try:
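        # Five clips of increasing length (1-5 s) to check that the batch API
        # copes with variable-length inputs.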
        batch_size = 5
        audios = [np.random.randn(16000 * (i + 1)).astype(np.float32)
                  for i in range(batch_size)]

        start = time.time()
        results = annotator.annotate_batch(audios, sample_rates=[16000] * batch_size)
        elapsed = time.time() - start

        logger.info("\n📊 Batch Results:")
        for i, result in enumerate(results):
            logger.info(f" Sample {i+1}: {result['emotion']['label']} "
                        f"({result['emotion']['confidence']:.2%})")

        logger.info(f"\n Total time: {elapsed:.2f}s")
        logger.info(f" Average time per sample: {elapsed/batch_size:.2f}s")

        assert len(results) == batch_size

        logger.info("\n✅ Batch processing successful")
        return True
    except Exception as e:
        logger.error(f"❌ Batch processing failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def test_balanced_mode():
    """Test 4: Balanced Mode (OPTION A)"""
    logger.info("\n" + "=" * 60)
    logger.info("TEST 4: Balanced Mode (OPTION A)")
    logger.info("=" * 60)

    try:
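        # Balanced mode is expected to run three emotion models (OPTION A);
        # the assertion below verifies that all three contribute a prediction.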
        annotator_balanced = EnsembleAnnotator(
            mode='balanced',
            device='cpu',
            enable_events=False
        )

        start = time.time()
        annotator_balanced.load_models()
        load_time = time.time() - start
        logger.info(f" Load time: {load_time:.2f}s")

        audio = np.random.randn(16000 * 3).astype(np.float32)

        start = time.time()
        result = annotator_balanced.annotate(audio, sample_rate=16000)
        annotate_time = time.time() - start

        logger.info("\n📊 Balanced Mode Result:")
        logger.info(f" Emotion: {result['emotion']['label']}")
        logger.info(f" Confidence: {result['emotion']['confidence']:.2%}")
        logger.info(f" Agreement: {result['emotion']['agreement']:.2%}")
        logger.info(f" Number of predictions: {len(result['emotion']['predictions'])}")
        logger.info(f" Annotation time: {annotate_time:.2f}s")

        assert len(result['emotion']['predictions']) == 3, \
            f"Expected 3 predictions, got {len(result['emotion']['predictions'])}"

        logger.info("\n✅ Balanced mode (OPTION A) successful")
        return True
    except Exception as e:
        logger.error(f"❌ Balanced mode failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def benchmark_modes():
    """Test 5: Benchmark All Modes"""
    logger.info("\n" + "=" * 60)
    logger.info("TEST 5: Performance Benchmark")
    logger.info("=" * 60)

    modes = ['quick', 'balanced']
    audio = np.random.randn(16000 * 3).astype(np.float32)

    results = {}

    for mode in modes:
        logger.info(f"\n📊 Testing {mode.upper()} mode...")

        try:
            annotator = EnsembleAnnotator(
                mode=mode,
                device='cpu',
                enable_events=False
            )

            start = time.time()
            annotator.load_models()
            load_time = time.time() - start
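
            # Average over three runs to smooth out timing jitter.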
            times = []
            for _ in range(3):
                start = time.time()
                result = annotator.annotate(audio, sample_rate=16000)
                times.append(time.time() - start)

            avg_time = np.mean(times)

            results[mode] = {
                'load_time': load_time,
                'avg_annotation_time': avg_time,
                'num_models': len(result['emotion']['predictions'])
            }

            logger.info(f" Load time: {load_time:.2f}s")
            logger.info(f" Avg annotation time: {avg_time:.2f}s")
            logger.info(f" Models: {results[mode]['num_models']}")

        except Exception as e:
            logger.error(f" ❌ {mode} mode failed: {e}")
            results[mode] = {'error': str(e)}

    logger.info("\n" + "=" * 60)
    logger.info("BENCHMARK SUMMARY")
    logger.info("=" * 60)

    for mode, metrics in results.items():
        if 'error' not in metrics:
            logger.info(f"\n{mode.upper()} MODE:")
            logger.info(f" Models: {metrics['num_models']}")
            logger.info(f" Load: {metrics['load_time']:.2f}s")
            logger.info(f" Annotation: {metrics['avg_annotation_time']:.2f}s/sample")
        else:
            logger.info(f"\n{mode.upper()} MODE: failed ({metrics['error']})")

    # Report failure if any mode errored, rather than always returning True.
    return all('error' not in metrics for metrics in results.values())


def main():
    """Run all tests"""
    logger.info("\n" + "=" * 60)
    logger.info("ENSEMBLE TTS ANNOTATION - QUICK TEST")
    logger.info("OPTION A - Balanced Mode (3 models)")
    logger.info("=" * 60)

    results = {
        'model_loading': False,
        'single_annotation': False,
        'batch_processing': False,
        'balanced_mode': False,
        'benchmark': False
    }

    annotator, success = test_model_loading()
    results['model_loading'] = success

    if not success:
        logger.error("\n❌ Model loading failed. Cannot continue tests.")
        return False

    results['single_annotation'] = test_single_annotation(annotator)
    results['batch_processing'] = test_batch_processing(annotator)
    results['balanced_mode'] = test_balanced_mode()
    results['benchmark'] = benchmark_modes()

    logger.info("\n" + "=" * 60)
    logger.info("TEST SUMMARY")
    logger.info("=" * 60)

    for test_name, success in results.items():
        status = "✅ PASS" if success else "❌ FAIL"
        logger.info(f" {test_name}: {status}")

    all_passed = all(results.values())

    if all_passed:
        logger.info("\n🎉 ALL TESTS PASSED!")
        logger.info("\nSystem is ready for production use.")
    else:
        logger.error("\n❌ SOME TESTS FAILED")
        logger.error("\nPlease check the logs above for details.")

    logger.info("\n" + "=" * 60)

    return all_passed


if __name__ == "__main__":
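    # Exit nonzero on failure so CI pipelines can gate on this script.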
    success = main()
    sys.exit(0 if success else 1)