# Source: new_recommender_system_nlp2 / intent_classifier.py
# Author: bharatverse11 — commit 5837b53 ("Update intent_classifier.py", verified)
"""
FILE: intent_classifier.py (ENHANCED VERSION)
PURPOSE:
- Detect user intent from natural-language text using semantic similarity
- Route each query to one of the domains used by the expanded dataset: FOOD / HERITAGE / TRAVEL / CULTURAL / ARCHITECTURE / NATURE
- Fall back to GENERAL when no domain is a close enough match
- Use SentenceTransformer embeddings for robust understanding beyond keyword matching
"""
from sentence_transformers import SentenceTransformer, util
import torch
# One natural-language "prototype" description per dataset domain. A query is
# assigned to the domain whose prototype it is most semantically similar to,
# so each text lists the kinds of places/things typical for that domain.
# Insertion order matters: it fixes the order of the domain list and of the
# pre-computed prototype embeddings.
DOMAIN_PROTOTYPES = dict(
    # Cuisine, dishes and eating out.
    food=(
        "I want to eat delicious food, snacks, dishes, sweets, desserts, "
        "breakfast, lunch, dinner, restaurants, street food, traditional "
        "cuisine, beverages. biryani, dosa, idli, thali, spicy food, local "
        "recipes, sweets like mysore pak, halwa."
    ),
    # Historical sites and monuments.
    heritage=(
        "I want to visit historical sites, ancient monuments, forts, palaces, "
        "museums, archaeological ruins, tombs, heritage buildings, UNESCO sites, "
        "historical architecture. temples like tirupati, meenakshi, forts like "
        "golconda, ancient caves."
    ),
    # Trips, hill stations, treks and offbeat destinations.
    travel=(
        "I want to travel and explore places like hill stations, valleys, "
        "mountains, passes, trekking spots, adventure destinations, offbeat "
        "locations, hidden gems, scenic viewpoints, roads. beaches like "
        "kothapatnam, varkala, hill stations like ooty, munnar."
    ),
    # Natural landscapes and wildlife.
    nature=(
        "I want to experience nature through waterfalls, lakes, rivers, forests, "
        "wildlife sanctuaries, national parks, caves, islands, beaches, natural "
        "landscapes, gardens. waterfalls, tiger reserves, sprawling lakes, "
        "botanical gardens."
    ),
    # Festivals, performing arts, crafts and local customs.
    cultural=(
        "I want to experience culture through festivals, traditional events, "
        "art forms, folk performances, local customs, tribal culture, villages, "
        "markets, handlooms, crafts. dance forms, music festivals, handicraft "
        "markets, silk saree weaving."
    ),
    # Buildings, styles and engineering.
    architecture=(
        "I want to see beautiful architecture, design, structures, buildings, "
        "temples, churches, mosques, monasteries, modern architecture, "
        "engineering marvels. dravidian style, mughal architecture, intricate "
        "carvings, massive domes."
    ),
)
class IntentClassifier:
    """Zero-shot intent classifier based on sentence-embedding similarity.

    Each domain is represented by one natural-language prototype text. A query
    is embedded with the same SentenceTransformer model and assigned to the
    domain whose prototype embedding has the highest cosine similarity. Weak
    matches (and blank queries) fall back to ``"general"``.
    """

    def __init__(self, model_name="all-MiniLM-L6-v2", prototypes=None):
        """Load the embedding model and pre-compute the prototype embeddings.

        Args:
            model_name: SentenceTransformer model identifier to load.
            prototypes: Optional mapping of domain name -> prototype text.
                Defaults to the module-level ``DOMAIN_PROTOTYPES``.

        Raises:
            ValueError: If ``prototypes`` is empty, since there would be no
                domain to match against.
        """
        if prototypes is None:
            prototypes = DOMAIN_PROTOTYPES
        if not prototypes:
            raise ValueError("prototypes must contain at least one domain")

        print(f"🧠 Loading Intent Classifier Model: {model_name}...")
        self.model = SentenceTransformer(model_name)
        # Encode the prototypes once up front so that each query only needs a
        # single encode() call at prediction time. self.domains[i] is the
        # domain name for row i of self.prototype_embeddings.
        self.domains = list(prototypes.keys())
        self.prototypes = list(prototypes.values())
        self.prototype_embeddings = self.model.encode(self.prototypes, convert_to_tensor=True)
        print("✅ Intent Classifier Ready")

    def predict_intent(self, query: str, threshold: float = 0.25) -> str:
        """Predict the domain of *query* by similarity to the domain prototypes.

        Args:
            query: Free-text user query.
            threshold: Minimum cosine similarity the best-matching prototype
                must reach for its domain to be returned.

        Returns:
            The best-matching domain name, or ``"general"`` when the query is
            empty/blank or the highest similarity is below ``threshold``.
        """
        # A blank query carries no intent; don't let the embedding of an empty
        # string be routed to whichever domain happens to be closest.
        if not query or not query.strip():
            return "general"

        query_embedding = self.model.encode(query, convert_to_tensor=True)
        # Similarity matrix of the (single) query against every prototype;
        # row 0 holds one score per domain, aligned with self.domains.
        cosine_scores = util.cos_sim(query_embedding, self.prototype_embeddings)[0]

        best_score, best_index = torch.max(cosine_scores, dim=0)
        best_score = best_score.item()
        # Convert the 0-d index tensor to a plain int before list indexing.
        best_domain = self.domains[int(best_index.item())]
        # Debug output (can be uncommented for testing)
        # print(f"DEBUG: Query='{query}' | Best Match='{best_domain}' ({best_score:.4f})")

        if best_score < threshold:
            return "general"
        return best_domain
# Backwards-compatible functional entry point. The classifier (and its model)
# is created lazily on the first call and then reused by every later call.
_shared_classifier = None


def classify_intent(query: str):
    """Classify *query* with a shared, lazily constructed IntentClassifier.

    Returns whatever ``IntentClassifier.predict_intent`` returns for the query
    with its default threshold.
    """
    global _shared_classifier
    instance = _shared_classifier
    if instance is None:
        instance = IntentClassifier()
        _shared_classifier = instance
    return instance.predict_intent(query)
# Enhanced test suite: run this file directly to print per-query results and
# overall accuracy against hand-labelled expectations.
if __name__ == "__main__":
    classifier = IntentClassifier()
    # (query, expected domain) pairs; the last entry is expected to fall back
    # to "general".
    test_queries = [
        # Food queries
        ("spicy masala dosa bangalore", "food"),
        ("sweet Indian dessert", "food"),
        ("iconic restaurant butter chicken", "food"),
        # Heritage queries
        ("ancient fort in rajasthan", "heritage"),
        ("historical monuments", "heritage"),
        ("mughal palace", "heritage"),
        # Travel queries
        ("valley of flowers nagaland", "travel"),
        ("hill station honeymoon", "travel"),
        ("trekking adventure ladakh", "travel"),
        ("mountain pass", "travel"),
        # Nature queries
        ("hidden waterfall meghalaya", "nature"),
        ("wildlife sanctuary", "nature"),
        ("national park tigers", "nature"),
        ("beautiful lake", "nature"),
        # Cultural queries
        ("traditional festival", "cultural"),
        ("tribal village", "cultural"),
        ("folk art performance", "cultural"),
        # Architecture queries
        ("temple architecture", "architecture"),
        ("beautiful building design", "architecture"),
        # General/Ambiguous
        ("random gibberish text", "general"),
    ]
    print("\n" + "="*60)
    print("INTENT CLASSIFIER TEST RESULTS")
    print("="*60)
    correct = 0
    total = len(test_queries)
    for query, expected in test_queries:
        predicted = classifier.predict_intent(query)
        is_correct = predicted == expected
        correct += int(is_correct)
        # Pass marker was previously mojibake ("βœ…"); use the real emoji.
        status = "✅" if is_correct else "❌"
        print(f"{status} Query: '{query}'")
        print(f" Expected: {expected} | Predicted: {predicted}\n")
    print("="*60)
    print(f"ACCURACY: {correct}/{total} ({round(correct/total * 100, 1)}%)")
    print("="*60)