Spaces:
Sleeping
Sleeping
| """ | |
| Enhanced Message Classification System for Arabic Regulatory Chatbot | |
| Improves differentiation between casual chat and regulatory queries. | |
| """ | |
| import re | |
| import logging | |
| from typing import Dict, List, Tuple, Optional | |
| from enum import Enum | |
| from dataclasses import dataclass | |
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class ClassificationConfidence(Enum):
    """Coarse confidence bands for a classification score in [0, 1]."""

    HIGH = "high"          # score >= 0.8
    MEDIUM = "medium"      # 0.5 <= score < 0.8
    LOW = "low"            # 0.3 <= score < 0.5
    VERY_LOW = "very_low"  # score < 0.3


@dataclass
class ClassificationResult:
    """Outcome of classifying one message.

    Declared as a dataclass so it can be built with keyword arguments
    (``ClassificationResult(message_type=..., confidence=..., ...)``);
    without the decorator the class has no ``__init__`` accepting fields.
    """

    # Winning category label, e.g. 'greeting', 'casual' or 'regulatory'.
    message_type: str
    # Score of the winning category.
    confidence: float
    # Band the score falls into.
    confidence_level: ClassificationConfidence
    # Short human-readable explanation of the decision.
    reasoning: str
    # Other (category, score) candidates, highest score first.
    alternative_types: List[Tuple[str, float]]
class EnhancedMessageClassifier:
    """
    Rule-based classifier separating greetings, casual chat and regulatory queries.

    Each category ('greeting', 'casual', 'regulatory') receives a score in
    [0, 1] from regex pattern hits plus small context bonuses; the highest
    score wins and is mapped to a ClassificationConfidence band.

    Matching conventions used by the patterns:
    * English keywords are anchored on word boundaries so they do not fire
      inside longer words ("timeline" is not "time", "refund" is not "fund").
    * Arabic keywords are mostly matched as substrings, because clitics such
      as the article "ال" or the conjunction "و" attach directly to the word.
    """

    # Pattern category -> key in self.weights. Spelled out explicitly because
    # deriving the key from the first word of the category name maps
    # 'casual_requests' to the non-existent key 'casual_casual'.
    _CASUAL_WEIGHT_KEYS = {
        'time_queries': 'casual_time',
        'weather_queries': 'casual_weather',
        'personal_inquiries': 'casual_personal',
        'general_questions': 'casual_general',
        'social_interactions': 'casual_social',
        'casual_requests': 'casual_requests',
    }
    _REGULATORY_WEIGHT_KEYS = {
        'licensing_keywords': 'regulatory_licensing',
        'compliance_keywords': 'regulatory_compliance',
        'financial_keywords': 'regulatory_financial',
        'authorities_keywords': 'regulatory_authorities',
        'violations_keywords': 'regulatory_violations',
        'aml_keywords': 'regulatory_aml',
    }

    # Question words that give short messages a small casual bonus.
    _CASUAL_QUESTION_WORDS = ('how', 'what', 'when', 'where', 'كيف', 'ماذا', 'متى', 'أين')
    # Question words that let an otherwise unrecognised question default to casual.
    _FALLBACK_QUESTION_WORDS = ('what', 'how', 'when', 'where', 'why',
                                'ماذا', 'كيف', 'متى', 'أين', 'لماذا')
    # Latin and Arabic question marks.
    _QUESTION_MARKS = '?؟'
    # Punctuation that marks a message as more than a bare informal fragment.
    _SENTENCE_MARKS = '?!.؟'
    # Tie-break when categories reach the same score. Greeting hits always
    # score a flat 1.0, so without this "Hi, what are the licensing
    # requirements?" would be routed to small talk instead of a regulatory answer.
    _TIE_BREAK_PRIORITY = {'regulatory': 2, 'greeting': 1, 'casual': 0}

    def __init__(self):
        self.setup_patterns()
        self.setup_weights()

    def setup_patterns(self):
        """Build the regex pattern tables used by the scoring methods."""
        # === GREETING PATTERNS ===
        # Look-arounds instead of (^|\s)...(\s|$) so that greetings followed by
        # punctuation ("Hello!", "hi,", "مرحبا؟") are still recognised.
        self.greeting_patterns = {
            'arabic': [
                r'(?<!\w)(مرحبا|أهلا|السلام عليكم|سلام عليكم|اهلا وسهلا|مرحباً|أهلاً)(?!\w)',
                r'(?<!\w)(كيف حالك|كيف الحال|كيف الأحوال|شلونك|كيفك)(?!\w)',
                r'(?<!\w)(صباح الخير|مساء الخير|تحية طيبة)(?!\w)',
                r'(?<!\w)(هلا|هالو|هلو)(?!\w)',
            ],
            'english': [
                r'(?<!\w)(hello|hi|hey|greetings|good morning|good afternoon|good evening)(?!\w)',
                r'(?<!\w)(how are you|how do you do|what\'s up|how\'s it going)(?!\w)',
                r'(?<!\w)(hola|bonjour|guten tag)(?!\w)',
            ],
        }

        # === CASUAL PATTERNS ===
        self.casual_patterns = {
            'time_queries': [
                r'(\bwhat\b|ما|كم).*(\btime\b|وقت|ساعة)',
                r'(\bwhat time\b|كم الساعة|ما الوقت)',
                r'(\bcurrent time\b|الوقت الحالي)',
                r'(\btime\b.*\bnow\b|الوقت.*الآن)',
            ],
            'weather_queries': [
                # 'جو' only as a standalone word or with the article 'ال';
                # as a bare substring it matches 'يجوز' ("is it permissible").
                r'(\bweather|طقس|الجو(?!\w)|(?<!\w)جو(?!\w))',
                r'(\btemperature|درجة حرارة)',
                r'(\b(rain|snow|sunny)|مطر|شمس|غائم)',
                r'(\bhow\b.*\bweather|كيف.*الطقس)',
            ],
            'personal_inquiries': [
                r'(\bhow\b.*day\b|كيف.*اليوم)',
                r'(\bhow\b.*\bdoing\b|كيف.*حالك)',
                r'(\bhow\b.*\bfeeling|كيف.*شعورك)',
                r'(\bwhat\b.*\bup\b|شو الأخبار)',
                r'(\bhow\b.*\bgoing\b|كيف.*الأمور)',
            ],
            'general_questions': [
                r'(\btell me about yourself\b|أخبرني عن نفسك)',
                r'(\bwho are you\b|من أنت)',
                r'(\bwhat\b.*\byou do\b|ماذا تفعل)',
                r'(\bcan you help\b|تقدر تساعد)',
                r'(\bwhat\b.*\bcapabilities\b|ما قدراتك)',
            ],
            'social_interactions': [
                r'(\bthank.*you|شكرا|شكراً|تسلم)',
                r'(\byou\'re welcome\b|عفواً|لا شكر على واجب)',
                r'(\bgoodbye\b|\bbye\b|وداع|باي|مع السلامة)',
                r'(\bsee you\b|أراك لاحقاً)',
                r'(\bnice\b.*\btalk|كان من الجميل)',
            ],
            'casual_requests': [
                r'(\btell\b.*\bjoke|احك نكتة)',
                r'(\bhow\b.*\bweekend|كيف.*الأسبوع)',
                r'(\bfavorite\b.*\bcolor|اللون المفضل)',
                r'(\bwhat\b.*\bthink\b|ما رأيك)',
            ],
        }

        # === REGULATORY PATTERNS ===
        # English terms carry a leading word boundary so inflected forms
        # ("licensed", "banks") still match but embedded ones ("refund") do not;
        # short acronyms are bounded on both sides.
        self.regulatory_patterns = {
            'licensing_keywords': [
                r'(ترخيص|\blicense|\blicensing|\bpermit|تصريح)',
                r'(تسجيل|\bregistration|\bregister)',
                r'(موافقة|\bapproval|اعتماد)',
            ],
            'compliance_keywords': [
                r'(امتثال|\bcompliance|\bcomply)',
                r'(متطلبات|\brequirements|شروط)',
                r'(ضوابط|\bregulations|أنظمة|لوائح)',
                r'(قوانين|\blaws\b|\blegislation)',
            ],
            'financial_keywords': [
                r'(مصرف|بنك|\bbank|\bbanking|مصرفي)',
                r'(ائتمان|\bcredit|قرض|\bloan)',
                r'(استثمار|\binvestment|توريق)',
                r'(أسواق المال|\bcapital markets|\bsecurities)',
                r'(صندوق|\bfund|محفظة|\bportfolio)',
            ],
            'authorities_keywords': [
                r'(البنك المركزي|\bcentral bank|\bcbk\b)',
                r'(هيئة أسواق المال|\bcapital markets authority|\bcma\b)',
                r'(مؤسسة النقد|\bmonetary authority)',
            ],
            'violations_keywords': [
                r'(مخالفة|\bviolation|\binfringement)',
                r'(عقوبة|\bpenalty|\bfine[sd]?\b|جزاء)',
                r'(تأديب|\bdisciplinary|تأديبي)',
                r'(محاكمة|\bhearing|جلسة)',
            ],
            'aml_keywords': [
                r'(غسل الأموال|\bmoney laundering|\baml\b)',
                r'(مشبوه|\bsuspicious|مشكوك)',
                r'(\bkyc\b|\bknow your customer|اعرف عميلك)',
                r'(بيانات العميل|\bcustomer data)',
            ],
        }

        # === CONTEXT CLUES ===
        self.context_indicators = {
            'casual_indicators': [
                # Whole leading word only: 'بس' must not fire on 'بسم الله',
                # nor 'just' on 'justice'.
                r'^(just|فقط|بس)(?!\w)',
                r'(\bcurious|فضول)',
                r'(\bwondering|أتساءل)',
                r'(\bby the way\b|بالمناسبة)',
                r'(\bquick question|سؤال سريع)',
            ],
            'regulatory_indicators': [
                r'(\baccording to\b|وفقاً ل|حسب)',
                r'(\barticle|مادة|فقرة)',
                r'(\bsection|قسم|باب)',
                r'(\bregulation number|رقم اللائحة)',
                r'(\bwhat are the\b|ما هي)',
                r'(\brequirements for\b|متطلبات)',
            ],
        }

    def setup_weights(self):
        """Set the score contributed by each pattern category / context clue."""
        self.weights = {
            'greeting': 1.0,
            'casual_time': 0.9,
            'casual_weather': 0.9,
            'casual_personal': 0.8,
            'casual_general': 0.7,
            'casual_social': 0.8,
            'casual_requests': 0.7,
            'regulatory_licensing': 0.9,
            'regulatory_compliance': 0.9,
            'regulatory_financial': 0.8,
            'regulatory_authorities': 1.0,
            'regulatory_violations': 0.9,
            'regulatory_aml': 0.9,
            'context_casual': 0.3,
            'context_regulatory': 0.4,
        }

    def classify_message(self, message: str) -> 'ClassificationResult':
        """
        Classify *message* as 'greeting', 'casual' or 'regulatory'.

        Args:
            message: Raw user message; case and surrounding whitespace are ignored.

        Returns:
            A ClassificationResult holding the winning type, its score, the
            confidence band, a reasoning string, and the other categories that
            scored above 0.1 (highest first, never including the winner).
        """
        message_lower = message.lower().strip()

        scores = {
            'greeting': self._score_greeting(message_lower),
            'casual': self._score_casual(message_lower),
            'regulatory': self._score_regulatory(message_lower),
        }

        # Highest score wins; equal scores are resolved by _TIE_BREAK_PRIORITY.
        best_type = max(scores, key=lambda t: (scores[t], self._TIE_BREAK_PRIORITY.get(t, -1)))
        best_score = scores[best_type]
        confidence_level = self._get_confidence_level(best_score)
        reasoning = self._generate_reasoning(message_lower, scores, best_type)

        # Ambiguous messages: fall back on surface characteristics.
        if confidence_level == ClassificationConfidence.VERY_LOW:
            is_question = any(mark in message for mark in self._QUESTION_MARKS)
            if len(message.split()) <= 3 and not any(mark in message for mark in self._SENTENCE_MARKS):
                # Short message without punctuation - likely casual
                best_type = 'casual'
                best_score = 0.4
                confidence_level = ClassificationConfidence.LOW
                reasoning = "Short informal message defaulted to casual"
            elif is_question and self._contains_word(message_lower, self._FALLBACK_QUESTION_WORDS):
                best_type = 'casual'
                best_score = 0.4
                confidence_level = ClassificationConfidence.LOW
                reasoning = "Question format defaulted to casual conversation"
            elif best_score == 0.0:
                # Nothing matched at all, so max() above picked a category
                # arbitrarily; do not report a greeting that was never seen.
                best_type = 'casual'
                reasoning = "No classification patterns matched; defaulted to casual"

        # Built after any fallback so the alternatives reflect the final choice.
        alternatives = sorted(
            ((t, s) for t, s in scores.items() if t != best_type and s > 0.1),
            key=lambda pair: pair[1],
            reverse=True,
        )

        return ClassificationResult(
            message_type=best_type,
            confidence=best_score,
            confidence_level=confidence_level,
            reasoning=reasoning,
            alternative_types=alternatives,
        )

    @staticmethod
    def _contains_word(text: str, words: Tuple[str, ...]) -> bool:
        """Return True if any of *words* occurs in *text*.

        ASCII words must appear as whole words ("how" does not match
        "however"); non-ASCII (Arabic) words are matched as substrings so that
        attached clitics ("وكيف") still count.
        """
        for word in words:
            if word.isascii():
                if re.search(r'\b' + re.escape(word) + r'\b', text):
                    return True
            elif word in text:
                return True
        return False

    def _score_greeting(self, message: str) -> float:
        """Return the greeting weight if any greeting pattern matches, else 0.0."""
        score = 0.0
        for patterns in self.greeting_patterns.values():
            if any(re.search(pattern, message, re.IGNORECASE) for pattern in patterns):
                score = max(score, self.weights['greeting'])
        return min(score, 1.0)

    def _score_casual(self, message: str) -> float:
        """Score casual-conversation likelihood in [0, 1].

        Best matching category weight, plus a bonus per casual context clue,
        plus 0.2 for short (<= 6 words) messages containing a question word.
        """
        score = 0.0
        for category, patterns in self.casual_patterns.items():
            weight_key = self._CASUAL_WEIGHT_KEYS.get(category, f'casual_{category.split("_")[0]}')
            category_weight = self.weights.get(weight_key, 0.7)
            if any(re.search(pattern, message, re.IGNORECASE) for pattern in patterns):
                score = max(score, category_weight)

        for pattern in self.context_indicators['casual_indicators']:
            if re.search(pattern, message, re.IGNORECASE):
                score += self.weights['context_casual']

        if self._contains_word(message, self._CASUAL_QUESTION_WORDS) and len(message.split()) <= 6:
            score += 0.2

        return min(score, 1.0)

    def _score_regulatory(self, message: str) -> float:
        """Score regulatory-query likelihood in [0, 1].

        Best matching category weight plus a bonus per regulatory context
        clue; messages of three words or fewer are scaled by 0.7.
        """
        score = 0.0
        for category, patterns in self.regulatory_patterns.items():
            weight_key = self._REGULATORY_WEIGHT_KEYS.get(category, f'regulatory_{category.split("_")[0]}')
            category_weight = self.weights.get(weight_key, 0.8)
            if any(re.search(pattern, message, re.IGNORECASE) for pattern in patterns):
                score = max(score, category_weight)

        for pattern in self.context_indicators['regulatory_indicators']:
            if re.search(pattern, message, re.IGNORECASE):
                score += self.weights['context_regulatory']

        # Very short messages rarely carry a full regulatory question.
        if len(message.split()) <= 3:
            score *= 0.7

        return min(score, 1.0)

    def _get_confidence_level(self, score: float) -> 'ClassificationConfidence':
        """Map a score to its band: >=0.8 HIGH, >=0.5 MEDIUM, >=0.3 LOW, else VERY_LOW."""
        if score >= 0.8:
            return ClassificationConfidence.HIGH
        elif score >= 0.5:
            return ClassificationConfidence.MEDIUM
        elif score >= 0.3:
            return ClassificationConfidence.LOW
        else:
            return ClassificationConfidence.VERY_LOW

    def _generate_reasoning(self, message: str, scores: Dict[str, float], best_type: str) -> str:
        """Build a short human-readable explanation for choosing *best_type*."""
        reasoning_parts = []

        if best_type == 'greeting':
            reasoning_parts.append("Contains greeting patterns")
        elif best_type == 'casual':
            if self._contains_word(message, ('time', 'وقت', 'ساعة')):
                reasoning_parts.append("Time-related inquiry")
            elif self._contains_word(message, ('weather', 'طقس')):
                reasoning_parts.append("Weather inquiry")
            elif self._contains_word(message, ('how', 'كيف')):
                reasoning_parts.append("Personal/casual question")
            else:
                reasoning_parts.append("General casual conversation")
        elif best_type == 'regulatory':
            if self._contains_word(message, ('bank', 'banks', 'banking', 'مصرف', 'بنك')):
                reasoning_parts.append("Banking-related query")
            elif self._contains_word(message, ('license', 'licenses', 'licensing', 'ترخيص')):
                reasoning_parts.append("Licensing inquiry")
            elif self._contains_word(message, ('compliance', 'امتثال')):
                reasoning_parts.append("Compliance question")
            else:
                reasoning_parts.append("Regulatory content detected")

        if scores[best_type] < 0.5:
            reasoning_parts.append("(Low confidence - ambiguous)")

        return "; ".join(reasoning_parts)
# Manual smoke test: run this module directly to print how a set of sample
# English and Arabic messages are classified.
if __name__ == "__main__":
    demo_classifier = EnhancedMessageClassifier()
    sample_messages = (
        "Hello",
        "How is the day?",
        "What is the time right now?",
        "How are you doing today?",
        "What are the banking license requirements?",
        "Tell me about compliance regulations",
        "مرحبا",
        "كيف حالك اليوم؟",
        "ما الوقت الآن؟",
        "ما هي متطلبات الترخيص المصرفي؟",
    )

    print("Enhanced Message Classification Test Results:")
    print("=" * 60)

    for sample in sample_messages:
        outcome = demo_classifier.classify_message(sample)
        print(f"\nMessage: \"{sample}\"")
        print(f"Classification: {outcome.message_type}")
        print(f"Confidence: {outcome.confidence:.2f} ({outcome.confidence_level.value})")
        print(f"Reasoning: {outcome.reasoning}")
        # Show at most the two strongest runner-up categories.
        shown = outcome.alternative_types[:2]
        if shown:
            rendered = ", ".join(f"{kind}({score:.2f})" for kind, score in shown)
            print(f"Alternatives: {rendered}")