"""
Automated validation script for rmtariq/multilingual-emotion-classifier.

Runs a suite of automated tests against the model and generates a validation report.

Usage:
    python validate_model.py
    python validate_model.py --model <model-name-or-path>
    python validate_model.py --output report.json

Author: rmtariq
"""
import argparse
import json
import sys
import time
from datetime import datetime
from transformers import pipeline
import torch
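
# Requires the 'transformers' and 'torch' packages (pip install transformers torch).
# Note: the pass/fail thresholds used below (0.8 accuracy, 3 predictions/second,
# 0.7 average confidence) are heuristic gates chosen for this script.
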
def validate_model(model_name="rmtariq/multilingual-emotion-classifier", output_file=None):
"""Run comprehensive validation and generate report"""
print("πŸ” AUTOMATED MODEL VALIDATION")
print("=" * 60)
print(f"Model: {model_name}")
print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print()
# Initialize results
validation_results = {
"model_name": model_name,
"timestamp": datetime.now().isoformat(),
"device": "GPU" if torch.cuda.is_available() else "CPU",
"tests": {},
"overall_status": "UNKNOWN"
}
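    # Each entry in "tests" will map a test name to its metrics plus a boolean
    # "passed" flag; this whole dict is what --output serializes to JSON.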
try:
# Load model
print("πŸ“₯ Loading model...")
classifier = pipeline(
"text-classification",
model=model_name,
device=0 if torch.cuda.is_available() else -1
)
print(f"βœ… Model loaded on {validation_results['device']}")
# Test 1: Critical functionality
print("\nπŸ§ͺ Test 1: Critical Functionality")
print("-" * 40)
critical_cases = [
("I am happy", "happy"),
("I am angry", "anger"),
("I love this", "love"),
("I am scared", "fear"),
("I am sad", "sadness"),
("What a surprise", "surprise")
]
critical_correct = 0
for text, expected in critical_cases:
result = classifier(text)
predicted = result[0]['label'].lower()
is_correct = predicted == expected
if is_correct:
critical_correct += 1
status = "βœ…" if is_correct else "❌"
print(f" {status} '{text}' β†’ {predicted}")
critical_accuracy = critical_correct / len(critical_cases)
validation_results["tests"]["critical_functionality"] = {
"accuracy": critical_accuracy,
"passed": critical_accuracy >= 0.8,
"details": f"{critical_correct}/{len(critical_cases)} correct"
}
print(f" πŸ“Š Critical Accuracy: {critical_accuracy:.1%}")
# Test 2: Malay fixes validation
print("\nπŸ§ͺ Test 2: Malay Fixes Validation")
print("-" * 40)
malay_fixes = [
("Ini adalah hari jadi terbaik", "happy"),
("Terbaik!", "happy"),
("Ini adalah hari yang baik", "happy"),
("Pengalaman terbaik", "happy")
]
malay_correct = 0
for text, expected in malay_fixes:
result = classifier(text)
predicted = result[0]['label'].lower()
is_correct = predicted == expected
if is_correct:
malay_correct += 1
status = "βœ…" if is_correct else "❌"
print(f" {status} '{text}' β†’ {predicted}")
malay_accuracy = malay_correct / len(malay_fixes)
validation_results["tests"]["malay_fixes"] = {
"accuracy": malay_accuracy,
"passed": malay_accuracy >= 0.8,
"details": f"{malay_correct}/{len(malay_fixes)} correct"
}
print(f" πŸ“Š Malay Fixes Accuracy: {malay_accuracy:.1%}")
# Test 3: Performance benchmark
print("\nπŸ§ͺ Test 3: Performance Benchmark")
print("-" * 40)
benchmark_texts = ["I am happy"] * 20
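        # Warm-up inference (assumption: the first call can include one-time
        # setup overhead, which would otherwise skew the timing below)
        _ = classifier("warm up")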
start_time = time.time()
for text in benchmark_texts:
_ = classifier(text)
end_time = time.time()
total_time = end_time - start_time
predictions_per_second = len(benchmark_texts) / total_time
validation_results["tests"]["performance"] = {
"predictions_per_second": predictions_per_second,
"passed": predictions_per_second >= 3.0,
"details": f"{predictions_per_second:.1f} predictions/second"
}
print(f" ⚑ Speed: {predictions_per_second:.1f} predictions/second")
# Test 4: Confidence validation
print("\nπŸ§ͺ Test 4: Confidence Validation")
print("-" * 40)
confidence_cases = [
"I am extremely happy today!",
"I absolutely love this!",
"I am terrified!",
"Saya sangat gembira!",
"Terbaik!"
]
high_confidence_count = 0
total_confidence = 0
for text in confidence_cases:
result = classifier(text)
confidence = result[0]['score']
total_confidence += confidence
if confidence > 0.8:
high_confidence_count += 1
print(f" πŸ“Š '{text[:30]}...' β†’ {confidence:.1%}")
avg_confidence = total_confidence / len(confidence_cases)
high_confidence_rate = high_confidence_count / len(confidence_cases)
validation_results["tests"]["confidence"] = {
"average_confidence": avg_confidence,
"high_confidence_rate": high_confidence_rate,
"passed": avg_confidence >= 0.7 and high_confidence_rate >= 0.6,
"details": f"Avg: {avg_confidence:.1%}, High: {high_confidence_rate:.1%}"
}
print(f" πŸ“Š Average Confidence: {avg_confidence:.1%}")
print(f" πŸ“Š High Confidence Rate: {high_confidence_rate:.1%}")
# Overall assessment
print("\n🎯 VALIDATION SUMMARY")
print("=" * 60)
all_tests_passed = all(test["passed"] for test in validation_results["tests"].values())
if all_tests_passed:
validation_results["overall_status"] = "PASS"
print("πŸŽ‰ VALIDATION PASSED!")
print("βœ… All tests passed successfully")
print("βœ… Model is ready for production use")
else:
validation_results["overall_status"] = "FAIL"
print("❌ VALIDATION FAILED!")
print("⚠️ Some tests did not meet requirements")
failed_tests = [name for name, test in validation_results["tests"].items() if not test["passed"]]
print(f"❌ Failed tests: {', '.join(failed_tests)}")
# Print detailed results
print("\nπŸ“‹ DETAILED RESULTS:")
for test_name, test_result in validation_results["tests"].items():
status = "βœ… PASS" if test_result["passed"] else "❌ FAIL"
print(f" {status} {test_name.replace('_', ' ').title()}: {test_result['details']}")
# Save results if output file specified
if output_file:
with open(output_file, 'w') as f:
json.dump(validation_results, f, indent=2)
print(f"\nπŸ’Ύ Results saved to: {output_file}")
return validation_results
except Exception as e:
print(f"❌ Validation failed with error: {e}")
validation_results["overall_status"] = "ERROR"
validation_results["error"] = str(e)
        return validation_results


def main():
"""Main validation function"""
parser = argparse.ArgumentParser(description="Validate the multilingual emotion classifier")
parser.add_argument(
"--model",
default="rmtariq/multilingual-emotion-classifier",
help="Model name or path to validate"
)
parser.add_argument(
"--output",
help="Output file for validation results (JSON format)"
)
args = parser.parse_args()
results = validate_model(args.model, args.output)
# Exit with appropriate code
if results["overall_status"] == "PASS":
return 0
else:
return 1


if __name__ == "__main__":
    sys.exit(main())