""" Complete RAG Engine Integrates hybrid retrieval, GraphRAG, and Groq LLM for Ireland Q&A """ import json import time from typing import List, Dict, Optional from hybrid_retriever import HybridRetriever, RetrievalResult from groq_llm import GroqLLM import hashlib class IrelandRAGEngine: """Complete RAG engine for Ireland knowledge base""" def __init__( self, chunks_file: str = "dataset/wikipedia_ireland/chunks.json", graphrag_index_file: str = "dataset/wikipedia_ireland/graphrag_index.json", groq_api_key: Optional[str] = None, groq_model: str = "llama-3.3-70b-versatile", use_cache: bool = True ): """Initialize RAG engine""" print("[INFO] Initializing Ireland RAG Engine...") # Initialize retriever self.retriever = HybridRetriever( chunks_file=chunks_file, graphrag_index_file=graphrag_index_file ) # Try to load pre-built indexes, otherwise build them try: self.retriever.load_indexes() except: print("[INFO] Pre-built indexes not found, building new ones...") self.retriever.build_semantic_index() self.retriever.build_keyword_index() self.retriever.save_indexes() # Initialize LLM self.llm = GroqLLM(api_key=groq_api_key, model=groq_model) # Cache for instant responses self.use_cache = use_cache self.cache = {} self.cache_hits = 0 self.cache_misses = 0 print("[SUCCESS] RAG Engine ready!") def _hash_query(self, query: str) -> str: """Create hash of query for caching""" return hashlib.md5(query.lower().strip().encode()).hexdigest() def answer_question( self, question: str, top_k: int = 5, semantic_weight: float = 0.7, keyword_weight: float = 0.3, use_community_context: bool = True, return_debug_info: bool = False ) -> Dict: """ Answer a question about Ireland using GraphRAG Args: question: User's question top_k: Number of chunks to retrieve semantic_weight: Weight for semantic search (0-1) keyword_weight: Weight for keyword search (0-1) use_community_context: Whether to include community summaries return_debug_info: Whether to return detailed debug information Returns: Dict with answer, citations, and metadata """ start_time = time.time() # Check cache query_hash = self._hash_query(question) if self.use_cache and query_hash in self.cache: self.cache_hits += 1 cached_result = self.cache[query_hash].copy() cached_result['cached'] = True cached_result['response_time'] = time.time() - start_time return cached_result self.cache_misses += 1 # Step 1: Hybrid retrieval retrieval_start = time.time() retrieved_chunks = self.retriever.hybrid_search( query=question, top_k=top_k, semantic_weight=semantic_weight, keyword_weight=keyword_weight ) retrieval_time = time.time() - retrieval_start # Step 2: Prepare contexts for LLM contexts = [] for result in retrieved_chunks: context = { 'text': result.text, 'source_title': result.source_title, 'source_url': result.source_url, 'combined_score': result.combined_score, 'semantic_score': result.semantic_score, 'keyword_score': result.keyword_score, 'community_id': result.community_id } contexts.append(context) # Step 3: Add community context if enabled community_summaries = [] if use_community_context: # Get unique communities from results communities = set(result.community_id for result in retrieved_chunks if result.community_id >= 0) for comm_id in list(communities)[:2]: # Use top 2 communities comm_context = self.retriever.get_community_context(comm_id) if comm_context: community_summaries.append({ 'community_id': comm_id, 'num_chunks': comm_context.get('num_chunks', 0), 'top_entities': [e['entity'] for e in comm_context.get('top_entities', [])[:5]], 'sources': comm_context.get('sources', [])[:3] }) # Step 4: Generate answer with citations generation_start = time.time() llm_result = self.llm.generate_with_citations( question=question, contexts=contexts, max_contexts=top_k ) generation_time = time.time() - generation_start # Step 5: Build response response = { 'question': question, 'answer': llm_result['answer'], 'citations': llm_result['citations'], 'num_contexts_used': llm_result['num_contexts_used'], 'communities': community_summaries if use_community_context else [], 'cached': False, 'response_time': time.time() - start_time, 'retrieval_time': retrieval_time, 'generation_time': generation_time } # Add debug info if requested if return_debug_info: response['debug'] = { 'retrieved_chunks': [ { 'rank': r.rank, 'source': r.source_title, 'semantic_score': f"{r.semantic_score:.3f}", 'keyword_score': f"{r.keyword_score:.3f}", 'combined_score': f"{r.combined_score:.3f}", 'community': r.community_id, 'text_preview': r.text[:150] + "..." } for r in retrieved_chunks ], 'cache_stats': { 'hits': self.cache_hits, 'misses': self.cache_misses, 'hit_rate': f"{self.cache_hits / (self.cache_hits + self.cache_misses) * 100:.1f}%" if (self.cache_hits + self.cache_misses) > 0 else "0%" } } # Cache the response if self.use_cache: self.cache[query_hash] = response.copy() return response def get_cache_stats(self) -> Dict: """Get cache statistics""" total_queries = self.cache_hits + self.cache_misses hit_rate = (self.cache_hits / total_queries * 100) if total_queries > 0 else 0 return { 'cache_size': len(self.cache), 'cache_hits': self.cache_hits, 'cache_misses': self.cache_misses, 'total_queries': total_queries, 'hit_rate': f"{hit_rate:.1f}%" } def clear_cache(self): """Clear the response cache""" self.cache.clear() self.cache_hits = 0 self.cache_misses = 0 print("[INFO] Cache cleared") def get_stats(self) -> Dict: """Get engine statistics""" return { 'total_chunks': len(self.retriever.chunks), 'total_communities': len(self.retriever.graphrag_index['communities']), 'cache_stats': self.get_cache_stats() } if __name__ == "__main__": # Test RAG engine engine = IrelandRAGEngine() # Test questions questions = [ "What is the capital of Ireland?", "When did Ireland join the European Union?", "Who is the current president of Ireland?", "What is the oldest university in Ireland?" ] for question in questions: print("\n" + "=" * 80) print(f"Question: {question}") print("=" * 80) result = engine.answer_question(question, top_k=5, return_debug_info=True) print(f"\nAnswer:\n{result['answer']}") print(f"\nResponse Time: {result['response_time']:.2f}s") print(f" - Retrieval: {result['retrieval_time']:.2f}s") print(f" - Generation: {result['generation_time']:.2f}s") print(f"\nCitations:") for cite in result['citations']: print(f" [{cite['id']}] {cite['source']} (score: {cite['relevance_score']:.3f})") if result.get('communities'): print(f"\nRelated Topics:") for comm in result['communities']: print(f" - {', '.join(comm['top_entities'][:3])}") print("\n" + "=" * 80) print("Cache Stats:", engine.get_cache_stats()) print("=" * 80)