|
|
""" |
|
|
FILE: intent_classifier.py (ENHANCED VERSION) |
|
|
|
|
|
PURPOSE: |
|
|
- Detect user intent from natural-language text using semantic similarity


- Route each query to one of the domains of the expanded dataset: FOOD / HERITAGE / TRAVEL / CULTURAL / ARCHITECTURE / NATURE, falling back to GENERAL when no domain is similar enough


- Use SentenceTransformer embeddings for robust understanding beyond keyword matching
|
|
""" |
|
|
|
|
|
from sentence_transformers import SentenceTransformer, util |
|
|
import torch |
|
|
|
|
|
|
|
|
# One natural-language "prototype" description per intent domain.
# IntentClassifier embeds these strings once and labels a query with the key
# whose prototype is most similar to it. Key order here is the order of
# IntentClassifier.domains. Long descriptions are split with implicit string
# concatenation; each value is still one single string.
DOMAIN_PROTOTYPES = {
    "food": (
        "I want to eat delicious food, snacks, dishes, sweets, desserts, "
        "breakfast, lunch, dinner, restaurants, street food, traditional "
        "cuisine, beverages. biryani, dosa, idli, thali, spicy food, local "
        "recipes, sweets like mysore pak, halwa."
    ),
    "heritage": (
        "I want to visit historical sites, ancient monuments, forts, palaces, "
        "museums, archaeological ruins, tombs, heritage buildings, UNESCO "
        "sites, historical architecture. temples like tirupati, meenakshi, "
        "forts like golconda, ancient caves."
    ),
    "travel": (
        "I want to travel and explore places like hill stations, valleys, "
        "mountains, passes, trekking spots, adventure destinations, offbeat "
        "locations, hidden gems, scenic viewpoints, roads. beaches like "
        "kothapatnam, varkala, hill stations like ooty, munnar."
    ),
    "nature": (
        "I want to experience nature through waterfalls, lakes, rivers, "
        "forests, wildlife sanctuaries, national parks, caves, islands, "
        "beaches, natural landscapes, gardens. waterfalls, tiger reserves, "
        "sprawling lakes, botanical gardens."
    ),
    "cultural": (
        "I want to experience culture through festivals, traditional events, "
        "art forms, folk performances, local customs, tribal culture, "
        "villages, markets, handlooms, crafts. dance forms, music festivals, "
        "handicraft markets, silk saree weaving."
    ),
    "architecture": (
        "I want to see beautiful architecture, design, structures, buildings, "
        "temples, churches, mosques, monasteries, modern architecture, "
        "engineering marvels. dravidian style, mughal architecture, intricate "
        "carvings, massive domes."
    ),
}
|
|
|
|
|
class IntentClassifier:
    """Route free-form text to an intent domain by sentence-embedding similarity.

    Each domain in ``DOMAIN_PROTOTYPES`` is represented by one prototype
    description. The prototypes are embedded once at construction time. A
    query is labelled with the domain whose prototype embedding has the
    highest cosine similarity to the query embedding.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """Load the embedding model and pre-compute the prototype embeddings.

        Args:
            model_name: Model name or path, passed straight to
                ``SentenceTransformer``.
        """
        # Emoji are written as \N{...} escapes so the source stays pure ASCII.
        # The literal emoji in this file had been mis-encoded, which left
        # unterminated string literals (a SyntaxError).
        print(f"\N{BRAIN} Loading Intent Classifier Model: {model_name}...")
        self.model = SentenceTransformer(model_name)

        # self.domains[i] is the label for self.prototype_embeddings[i]. Both
        # follow the insertion order of DOMAIN_PROTOTYPES.
        self.domains = list(DOMAIN_PROTOTYPES.keys())
        self.prototypes = list(DOMAIN_PROTOTYPES.values())

        # Encode the prototypes once here, so each prediction only has to
        # embed the query itself.
        self.prototype_embeddings = self.model.encode(
            self.prototypes, convert_to_tensor=True
        )

        print("\N{WHITE HEAVY CHECK MARK} Intent Classifier Ready")

    def predict_intent(self, query: str, threshold: float = 0.25) -> str:
        """Predict the intent domain of *query* by similarity to the prototypes.

        Args:
            query: Free-form user text.
            threshold: Minimum cosine similarity the best-matching prototype
                must reach. A lower best score is treated as out of domain.

        Returns:
            One of the ``DOMAIN_PROTOTYPES`` keys. Returns ``"general"`` when
            the query is empty, blank or ``None``, or when the best similarity
            score is below ``threshold``.
        """
        # An empty or whitespace-only query carries no intent. Skip the
        # (comparatively expensive) model call and use the fallback label.
        if not query or (isinstance(query, str) and not query.strip()):
            return "general"

        query_embedding = self.model.encode(query, convert_to_tensor=True)

        # cos_sim returns a similarity matrix; row 0 holds this single
        # query's score against every prototype.
        cosine_scores = util.cos_sim(query_embedding, self.prototype_embeddings)[0]

        best_score, best_index = torch.max(cosine_scores, dim=0)
        if best_score.item() < threshold:
            return "general"

        # .item() converts the 0-d index tensor to a plain int for list indexing.
        return self.domains[best_index.item()]
|
|
|
|
|
|
|
|
# Classifier instance shared by every classify_intent() call. It is created on
# first use and then reused, so the model is not reloaded on every call.
# NOTE: the lazy initialisation below is not lock-protected.
_shared_classifier = None


def classify_intent(query: str):
    """Return the intent label for *query* using a shared IntentClassifier.

    On the first call, an IntentClassifier (and its model) is constructed and
    stored in the module-level ``_shared_classifier``. Later calls reuse it.
    Prediction uses ``IntentClassifier.predict_intent`` with its default
    threshold.
    """
    global _shared_classifier

    classifier = _shared_classifier
    if classifier is None:
        # First call: pay the model-loading cost and cache the instance.
        classifier = IntentClassifier()
        _shared_classifier = classifier

    return classifier.predict_intent(query)
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: classify a small labelled query set and print per-query
    # results plus the overall accuracy.
    classifier = IntentClassifier()

    # (query, expected_label) pairs; "general" marks out-of-domain input.
    test_queries = [
        # food
        ("spicy masala dosa bangalore", "food"),
        ("sweet Indian dessert", "food"),
        ("iconic restaurant butter chicken", "food"),
        # heritage
        ("ancient fort in rajasthan", "heritage"),
        ("historical monuments", "heritage"),
        ("mughal palace", "heritage"),
        # travel
        ("valley of flowers nagaland", "travel"),
        ("hill station honeymoon", "travel"),
        ("trekking adventure ladakh", "travel"),
        ("mountain pass", "travel"),
        # nature
        ("hidden waterfall meghalaya", "nature"),
        ("wildlife sanctuary", "nature"),
        ("national park tigers", "nature"),
        ("beautiful lake", "nature"),
        # cultural
        ("traditional festival", "cultural"),
        ("tribal village", "cultural"),
        ("folk art performance", "cultural"),
        # architecture
        ("temple architecture", "architecture"),
        ("beautiful building design", "architecture"),
        # out of domain
        ("random gibberish text", "general"),
    ]

    separator = "=" * 60

    print("\n" + separator)
    print("INTENT CLASSIFIER TEST RESULTS")
    print(separator)

    correct = 0
    total = len(test_queries)

    for query, expected in test_queries:
        predicted = classifier.predict_intent(query)
        is_correct = predicted == expected
        correct += is_correct  # bool counts as 0/1

        # \N{...} escapes keep the source ASCII-only. The literal emoji here
        # had been mis-encoded into an unterminated string (a SyntaxError).
        status = "\N{WHITE HEAVY CHECK MARK}" if is_correct else "\N{CROSS MARK}"

        print(f"{status} Query: '{query}'")
        print(f" Expected: {expected} | Predicted: {predicted}\n")

    print(separator)
    print(f"ACCURACY: {correct}/{total} ({round(correct/total * 100, 1)}%)")
    print(separator)