File size: 4,117 Bytes
98938e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
Local test - Checks whether the system works before provisioning a machine.

Runs a few smoke checks on CPU (package import, ``EnsembleAnnotator``
construction, ``EmotionEnsemble`` structure). When executed as a script the
process exits with status 0 if every check passed and 1 otherwise.
"""
# NOTE(review): the ``test_*`` helpers below catch their own exceptions and
# report failure by *returning* False (or ``(None, False)``) instead of
# raising. If pytest ever collects this file, it does not treat those return
# values as failures, so a broken check would still "pass" — confirm this
# module is only meant to be executed directly as a script.

import sys
import logging

# Configured at import time, so importing this module also configures the
# root logger. Bare '%(message)s' format: the checks print their own banners.
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)

def test_imports():
    """Check that ``EnsembleAnnotator`` can be imported from ``ensemble_tts``.

    Returns:
        bool: True when the import succeeds. False when anything inside the
        check raises; the error is logged and the traceback is printed with
        ``traceback.print_exc()``.
    """
    divider = "=" * 60
    for heading_line in (divider, "TEST: Imports", divider):
        logger.info(heading_line)

    succeeded = False
    try:
        logger.info("Importing ensemble_tts...")
        # The import itself is the check; the bound name is intentionally unused.
        from ensemble_tts import EnsembleAnnotator  # noqa: F401
        logger.info("✅ Import successful")
        succeeded = True
    except Exception as err:
        logger.error(f"❌ Import failed: {err}")
        import traceback
        traceback.print_exc()
    return succeeded

def test_create_annotator():
    """Build an ``EnsembleAnnotator`` in quick mode on CPU with events disabled.

    The banner labels this check "no model loading". Whether the
    ``EnsembleAnnotator`` constructor actually avoids loading weights is not
    visible from this file.

    Returns:
        tuple: ``(annotator, True)`` on success. ``(None, False)`` when the
        import, construction, or attribute report raises; the error is logged
        and the traceback is printed with ``traceback.print_exc()``.
    """
    banner = "=" * 60
    logger.info("\n" + banner)
    logger.info("TEST: Create Annotator (no model loading)")
    logger.info(banner)

    outcome = (None, False)
    try:
        from ensemble_tts import EnsembleAnnotator

        logger.info("Creating annotator in quick mode...")
        annotator = EnsembleAnnotator(
            mode='quick', device='cpu', enable_events=False
        )
        # Read each attribute right before logging it, so a missing attribute
        # still leaves the earlier lines in the log.
        report = (
            ("Mode", "mode"),
            ("Device", "device"),
            ("Voting", "voting_strategy"),
        )
        for label, attr_name in report:
            logger.info(f"  {label}: {getattr(annotator, attr_name)}")
        logger.info("✅ Annotator created successfully")
        outcome = (annotator, True)
    except Exception as err:
        logger.error(f"❌ Annotator creation failed: {err}")
        import traceback
        traceback.print_exc()
    return outcome

def test_model_structure():
    """Construct a quick-mode ``EmotionEnsemble`` on CPU and list its members.

    Logs the member count, then each member's ``name`` and ``weight``. The
    check is intended to run without loading weights. Whether the
    ``EmotionEnsemble`` constructor honours that is not visible from this
    file (TODO confirm).

    Returns:
        bool: True when construction and listing succeed. False when anything
        raises; the error is logged and the traceback is printed with
        ``traceback.print_exc()``.
    """
    rule = "=" * 60
    logger.info("\n" + rule)
    logger.info("TEST: Model Structure")
    logger.info(rule)

    passed = False
    try:
        from ensemble_tts.models.emotion import EmotionEnsemble

        logger.info("Creating emotion ensemble...")
        ensemble = EmotionEnsemble(mode='quick', device='cpu')

        logger.info(f"  Number of models: {len(ensemble.models)}")
        for member in ensemble.models:
            logger.info(f"    - {member.name} (weight: {member.weight})")

        logger.info("✅ Model structure correct")
        passed = True
    except Exception as err:
        logger.error(f"❌ Model structure test failed: {err}")
        import traceback
        traceback.print_exc()
    return passed

def main():
    """Run the local smoke checks and log a pass/fail summary.

    The import check runs first. If it fails, the remaining checks are
    skipped, because they import from the same ``ensemble_tts`` package, and
    an installation hint is logged instead. Otherwise the annotator and
    model-structure checks run, and a per-check PASS/FAIL table is logged,
    followed by next steps or a failure banner.

    Returns:
        bool: True when every executed check passed, False otherwise.
    """
    logger.info("\n" + "=" * 60)
    logger.info("ENSEMBLE TTS ANNOTATION - LOCAL TEST")
    logger.info("Testing without loading model weights")
    logger.info("=" * 60 + "\n")

    results = {}

    # Test 1: Imports
    results['imports'] = test_imports()

    if not results['imports']:
        logger.error("\n❌ Import test failed. Please install requirements:")
        logger.error("  pip install -r requirements.txt")
        return False

    # Test 2: Create annotator.
    # Only the success flag is used here. Indexing [1] instead of unpacking
    # into an unused local means main() does not keep the annotator alive for
    # the rest of the run, so it can be garbage-collected before Test 3
    # constructs its own ensemble.
    results['create_annotator'] = test_create_annotator()[1]

    # Test 3: Model structure
    results['model_structure'] = test_model_structure()

    # Summary
    logger.info("\n" + "=" * 60)
    logger.info("TEST SUMMARY")
    logger.info("=" * 60)

    for test_name, passed in results.items():
        status = "✅ PASS" if passed else "❌ FAIL"
        logger.info(f"  {test_name}: {status}")

    all_passed = all(results.values())

    if all_passed:
        logger.info("\n" + "=" * 60)
        logger.info("✅ ALL LOCAL TESTS PASSED!")
        logger.info("=" * 60)
        logger.info("\nNext steps:")
        logger.info("1. Run full test (downloads models):")
        logger.info("   python scripts/test/test_quick.py")
        logger.info("\n2. Or test on spot instance:")
        logger.info("   bash scripts/test/launch_spot_test.sh")
        logger.info("")
    else:
        logger.error("\n" + "=" * 60)
        logger.error("❌ SOME TESTS FAILED")
        logger.error("=" * 60)
        logger.error("\nPlease check errors above and fix before proceeding.")
        logger.error("")

    return all_passed

if __name__ == "__main__":
    # Map main()'s boolean outcome onto a conventional process exit status.
    sys.exit(0 if main() else 1)