""" |
|
|
Automated validation script for rmtariq/multilingual-emotion-classifier |
|
|
This script runs automated tests and generates a validation report. |
|
|
|
|
|
Usage: |
|
|
python validate_model.py |
|
|
python validate_model.py --output report.txt |
|
|
|
|
|
Author: rmtariq |
|
|
""" |
|
|
|
|
|
import argparse
import json
import sys
import time
from datetime import datetime

import torch
from transformers import pipeline
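
# Illustrative shape of the report written with --output (and returned by
# validate_model). Field names follow the validation_results dict built below;
# the values shown here are made-up examples, not real results:
#
#     {
#       "model_name": "rmtariq/multilingual-emotion-classifier",
#       "timestamp": "2025-01-01T12:00:00",
#       "device": "CPU",
#       "tests": {
#         "critical_functionality": {"accuracy": 1.0, "passed": true, "details": "6/6 correct"},
#         "malay_fixes": {"accuracy": 1.0, "passed": true, "details": "4/4 correct"},
#         "performance": {"predictions_per_second": 5.2, "passed": true, "details": "5.2 predictions/second"},
#         "confidence": {"average_confidence": 0.85, "high_confidence_rate": 0.8, "passed": true, "details": "Avg: 85.0%, High: 80.0%"}
#       },
#       "overall_status": "PASS"
#     }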
def validate_model(model_name="rmtariq/multilingual-emotion-classifier", output_file=None):
    """Run comprehensive validation and generate a report."""
    print("🚀 AUTOMATED MODEL VALIDATION")
    print("=" * 60)
    print(f"Model: {model_name}")
    print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print()

    validation_results = {
        "model_name": model_name,
        "timestamp": datetime.now().isoformat(),
        "device": "GPU" if torch.cuda.is_available() else "CPU",
        "tests": {},
        "overall_status": "UNKNOWN"
    }

    try:
        print("📥 Loading model...")
        classifier = pipeline(
            "text-classification",
            model=model_name,
            # device=0 selects the first CUDA GPU; -1 keeps inference on the CPU.
            device=0 if torch.cuda.is_available() else -1
        )
        print(f"✅ Model loaded on {validation_results['device']}")
print("\nπ§ͺ Test 1: Critical Functionality") |
|
|
print("-" * 40) |
|
|
|
|
|
critical_cases = [ |
|
|
("I am happy", "happy"), |
|
|
("I am angry", "anger"), |
|
|
("I love this", "love"), |
|
|
("I am scared", "fear"), |
|
|
("I am sad", "sadness"), |
|
|
("What a surprise", "surprise") |
|
|
] |
|
|
|
|
|
critical_correct = 0 |
|
|
for text, expected in critical_cases: |
|
|
result = classifier(text) |
|
|
predicted = result[0]['label'].lower() |
|
|
is_correct = predicted == expected |
|
|
if is_correct: |
|
|
critical_correct += 1 |
|
|
|
|
|
status = "β
" if is_correct else "β" |
|
|
print(f" {status} '{text}' β {predicted}") |
|
|
|
|
|
critical_accuracy = critical_correct / len(critical_cases) |
|
|
validation_results["tests"]["critical_functionality"] = { |
|
|
"accuracy": critical_accuracy, |
|
|
"passed": critical_accuracy >= 0.8, |
|
|
"details": f"{critical_correct}/{len(critical_cases)} correct" |
|
|
} |
|
|
|
|
|
print(f" π Critical Accuracy: {critical_accuracy:.1%}") |

        print("\n🧪 Test 2: Malay Fixes Validation")
        print("-" * 40)

        # Malay phrases that should map to "happy" (translations: "This is the
        # best birthday", "The best!", "This is a good day", "Best experience").
        malay_fixes = [
            ("Ini adalah hari jadi terbaik", "happy"),
            ("Terbaik!", "happy"),
            ("Ini adalah hari yang baik", "happy"),
            ("Pengalaman terbaik", "happy")
        ]

        malay_correct = 0
        for text, expected in malay_fixes:
            result = classifier(text)
            predicted = result[0]['label'].lower()
            is_correct = predicted == expected
            if is_correct:
                malay_correct += 1

            status = "✅" if is_correct else "❌"
            print(f"  {status} '{text}' → {predicted}")

        malay_accuracy = malay_correct / len(malay_fixes)
        validation_results["tests"]["malay_fixes"] = {
            "accuracy": malay_accuracy,
            "passed": malay_accuracy >= 0.8,
            "details": f"{malay_correct}/{len(malay_fixes)} correct"
        }

        print(f"  📊 Malay Fixes Accuracy: {malay_accuracy:.1%}")
print("\nπ§ͺ Test 3: Performance Benchmark") |
|
|
print("-" * 40) |
|
|
|
|
|
benchmark_texts = ["I am happy"] * 20 |
|
|
|
|
|
start_time = time.time() |
|
|
for text in benchmark_texts: |
|
|
_ = classifier(text) |
|
|
end_time = time.time() |
|
|
|
|
|
total_time = end_time - start_time |
|
|
predictions_per_second = len(benchmark_texts) / total_time |
|
|
|
|
|
validation_results["tests"]["performance"] = { |
|
|
"predictions_per_second": predictions_per_second, |
|
|
"passed": predictions_per_second >= 3.0, |
|
|
"details": f"{predictions_per_second:.1f} predictions/second" |
|
|
} |
|
|
|
|
|
print(f" β‘ Speed: {predictions_per_second:.1f} predictions/second") |

        print("\n🧪 Test 4: Confidence Validation")
        print("-" * 40)

        confidence_cases = [
            "I am extremely happy today!",
            "I absolutely love this!",
            "I am terrified!",
            "Saya sangat gembira!",  # Malay: "I am very happy!"
            "Terbaik!"               # Malay: "The best!"
        ]

        high_confidence_count = 0
        total_confidence = 0

        for text in confidence_cases:
            result = classifier(text)
            # The pipeline's 'score' is the probability of the top predicted
            # label, used here as the model's confidence.
            confidence = result[0]['score']
            total_confidence += confidence

            if confidence > 0.8:
                high_confidence_count += 1

            print(f"  📊 '{text[:30]}...' → {confidence:.1%}")

        avg_confidence = total_confidence / len(confidence_cases)
        high_confidence_rate = high_confidence_count / len(confidence_cases)

        validation_results["tests"]["confidence"] = {
            "average_confidence": avg_confidence,
            "high_confidence_rate": high_confidence_rate,
            "passed": avg_confidence >= 0.7 and high_confidence_rate >= 0.6,
            "details": f"Avg: {avg_confidence:.1%}, High: {high_confidence_rate:.1%}"
        }

        print(f"  📊 Average Confidence: {avg_confidence:.1%}")
        print(f"  📊 High Confidence Rate: {high_confidence_rate:.1%}")
print("\nπ― VALIDATION SUMMARY") |
|
|
print("=" * 60) |
|
|
|
|
|
all_tests_passed = all(test["passed"] for test in validation_results["tests"].values()) |
|
|
|
|
|
if all_tests_passed: |
|
|
validation_results["overall_status"] = "PASS" |
|
|
print("π VALIDATION PASSED!") |
|
|
print("β
All tests passed successfully") |
|
|
print("β
Model is ready for production use") |
|
|
else: |
|
|
validation_results["overall_status"] = "FAIL" |
|
|
print("β VALIDATION FAILED!") |
|
|
print("β οΈ Some tests did not meet requirements") |
|
|
|
|
|
failed_tests = [name for name, test in validation_results["tests"].items() if not test["passed"]] |
|
|
print(f"β Failed tests: {', '.join(failed_tests)}") |
|
|
|
|
|
|
|
|
print("\nπ DETAILED RESULTS:") |
|
|
for test_name, test_result in validation_results["tests"].items(): |
|
|
status = "β
PASS" if test_result["passed"] else "β FAIL" |
|
|
print(f" {status} {test_name.replace('_', ' ').title()}: {test_result['details']}") |
|
|
|
|
|
|
|
|

        if output_file:
            with open(output_file, 'w') as f:
                json.dump(validation_results, f, indent=2)
            print(f"\n💾 Results saved to: {output_file}")

        return validation_results

    except Exception as e:
        print(f"❌ Validation failed with error: {e}")
        validation_results["overall_status"] = "ERROR"
        validation_results["error"] = str(e)
        return validation_results
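

# Example programmatic use (illustrative; validate_model is defined above):
#
#     from validate_model import validate_model
#     report = validate_model(output_file="report.json")
#     if report["overall_status"] != "PASS":
#         raise SystemExit("Model validation failed")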


def main():
    """Parse CLI arguments, run validation, and return an exit code."""
    parser = argparse.ArgumentParser(description="Validate the multilingual emotion classifier")
    parser.add_argument(
        "--model",
        default="rmtariq/multilingual-emotion-classifier",
        help="Model name or path to validate"
    )
    parser.add_argument(
        "--output",
        help="Output file for validation results (JSON format)"
    )

    args = parser.parse_args()

    results = validate_model(args.model, args.output)

    # Exit code 0 only when every test passed; FAIL or ERROR returns 1.
    if results["overall_status"] == "PASS":
        return 0
    else:
        return 1


if __name__ == "__main__":
    sys.exit(main())