""" |
|
|
Complete RAG Engine |
|
|
Integrates hybrid retrieval, GraphRAG, and Groq LLM for Ireland Q&A |
|
|
""" |
|
|
|
|
|
import hashlib
import time
from typing import Dict, Optional

from hybrid_retriever import HybridRetriever, RetrievalResult
from groq_llm import GroqLLM


class IrelandRAGEngine:
    """Complete RAG engine for Ireland knowledge base"""

    def __init__(
        self,
        chunks_file: str = "dataset/wikipedia_ireland/chunks.json",
        graphrag_index_file: str = "dataset/wikipedia_ireland/graphrag_index.json",
        groq_api_key: Optional[str] = None,
        groq_model: str = "llama-3.3-70b-versatile",
        use_cache: bool = True
    ):
        """Initialize RAG engine"""
        print("[INFO] Initializing Ireland RAG Engine...")

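        # Hybrid retriever: combines semantic and keyword search over the chunk collection.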
        self.retriever = HybridRetriever(
            chunks_file=chunks_file,
            graphrag_index_file=graphrag_index_file
        )

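        # Load pre-built indexes if available; otherwise build and persist them.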
        try:
            self.retriever.load_indexes()
        except Exception:
            print("[INFO] Pre-built indexes not found, building new ones...")
            self.retriever.build_semantic_index()
            self.retriever.build_keyword_index()
            self.retriever.save_indexes()

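        # Groq-hosted LLM used to generate answers with citations.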
        self.llm = GroqLLM(api_key=groq_api_key, model=groq_model)

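        # Simple in-memory response cache keyed by a hash of the normalized query.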
        self.use_cache = use_cache
        self.cache = {}
        self.cache_hits = 0
        self.cache_misses = 0

        print("[SUCCESS] RAG Engine ready!")

    def _hash_query(self, query: str) -> str:
        """Create a hash of the normalized query for caching (MD5 serves as a cache key here, not a security measure)."""
        return hashlib.md5(query.lower().strip().encode()).hexdigest()

    def answer_question(
        self,
        question: str,
        top_k: int = 5,
        semantic_weight: float = 0.7,
        keyword_weight: float = 0.3,
        use_community_context: bool = True,
        return_debug_info: bool = False
    ) -> Dict:
        """
        Answer a question about Ireland using GraphRAG

        Args:
            question: User's question
            top_k: Number of chunks to retrieve
            semantic_weight: Weight for semantic search (0-1)
            keyword_weight: Weight for keyword search (0-1)
            use_community_context: Whether to include community summaries
            return_debug_info: Whether to return detailed debug information

        Returns:
            Dict with answer, citations, and metadata
        """
        start_time = time.time()

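        # Serve a cached response when this (normalized) question was answered before.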
        query_hash = self._hash_query(question)
        if self.use_cache and query_hash in self.cache:
            self.cache_hits += 1
            cached_result = self.cache[query_hash].copy()
            cached_result['cached'] = True
            cached_result['response_time'] = time.time() - start_time
            return cached_result

        self.cache_misses += 1

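        # Step 1: hybrid retrieval (weighted blend of semantic and keyword scores).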
        retrieval_start = time.time()
        retrieved_chunks = self.retriever.hybrid_search(
            query=question,
            top_k=top_k,
            semantic_weight=semantic_weight,
            keyword_weight=keyword_weight
        )
        retrieval_time = time.time() - retrieval_start

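        # Step 2: format the retrieved chunks as context dicts for the LLM.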
        contexts = []
        for result in retrieved_chunks:
            context = {
                'text': result.text,
                'source_title': result.source_title,
                'source_url': result.source_url,
                'combined_score': result.combined_score,
                'semantic_score': result.semantic_score,
                'keyword_score': result.keyword_score,
                'community_id': result.community_id
            }
            contexts.append(context)

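        # Step 3: gather GraphRAG community summaries for broader topical context.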
        community_summaries = []
        if use_community_context:
            # Communities of the retrieved chunks; negative ids mean no community.
            communities = {result.community_id for result in retrieved_chunks if result.community_id >= 0}

            # Summarize at most two communities.
            for comm_id in list(communities)[:2]:
                comm_context = self.retriever.get_community_context(comm_id)
                if comm_context:
                    community_summaries.append({
                        'community_id': comm_id,
                        'num_chunks': comm_context.get('num_chunks', 0),
                        'top_entities': [e['entity'] for e in comm_context.get('top_entities', [])[:5]],
                        'sources': comm_context.get('sources', [])[:3]
                    })

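        # Step 4: generate the answer with inline citations.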
        generation_start = time.time()
        llm_result = self.llm.generate_with_citations(
            question=question,
            contexts=contexts,
            max_contexts=top_k
        )
        generation_time = time.time() - generation_start

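        # Assemble the response payload with a timing breakdown.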
        response = {
            'question': question,
            'answer': llm_result['answer'],
            'citations': llm_result['citations'],
            'num_contexts_used': llm_result['num_contexts_used'],
            'communities': community_summaries if use_community_context else [],
            'cached': False,
            'response_time': time.time() - start_time,
            'retrieval_time': retrieval_time,
            'generation_time': generation_time
        }

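        # Optionally attach per-chunk retrieval scores and cache stats for debugging.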
        if return_debug_info:
            total = self.cache_hits + self.cache_misses
            response['debug'] = {
                'retrieved_chunks': [
                    {
                        'rank': r.rank,
                        'source': r.source_title,
                        'semantic_score': f"{r.semantic_score:.3f}",
                        'keyword_score': f"{r.keyword_score:.3f}",
                        'combined_score': f"{r.combined_score:.3f}",
                        'community': r.community_id,
                        'text_preview': r.text[:150] + "..."
                    }
                    for r in retrieved_chunks
                ],
                'cache_stats': {
                    'hits': self.cache_hits,
                    'misses': self.cache_misses,
                    'hit_rate': f"{self.cache_hits / total * 100:.1f}%" if total > 0 else "0%"
                }
            }

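        # Cache the response (shallow copy; a debug block, if requested, is cached too).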
        if self.use_cache:
            self.cache[query_hash] = response.copy()

        return response

    def get_cache_stats(self) -> Dict:
        """Get cache statistics"""
        total_queries = self.cache_hits + self.cache_misses
        hit_rate = (self.cache_hits / total_queries * 100) if total_queries > 0 else 0

        return {
            'cache_size': len(self.cache),
            'cache_hits': self.cache_hits,
            'cache_misses': self.cache_misses,
            'total_queries': total_queries,
            'hit_rate': f"{hit_rate:.1f}%"
        }

    def clear_cache(self):
        """Clear the response cache"""
        self.cache.clear()
        self.cache_hits = 0
        self.cache_misses = 0
        print("[INFO] Cache cleared")

    def get_stats(self) -> Dict:
        """Get engine statistics"""
        return {
            'total_chunks': len(self.retriever.chunks),
            'total_communities': len(self.retriever.graphrag_index['communities']),
            'cache_stats': self.get_cache_stats()
        }


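# Smoke test: answer sample questions and print timing, citations, and cache stats.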
if __name__ == "__main__":
    engine = IrelandRAGEngine()

    questions = [
        "What is the capital of Ireland?",
        "When did Ireland join the European Union?",
        "Who is the current president of Ireland?",
        "What is the oldest university in Ireland?"
    ]

    for question in questions:
        print("\n" + "=" * 80)
        print(f"Question: {question}")
        print("=" * 80)

        result = engine.answer_question(question, top_k=5, return_debug_info=True)

        print(f"\nAnswer:\n{result['answer']}")
        print(f"\nResponse Time: {result['response_time']:.2f}s")
        print(f" - Retrieval: {result['retrieval_time']:.2f}s")
        print(f" - Generation: {result['generation_time']:.2f}s")

        print("\nCitations:")
        for cite in result['citations']:
            print(f" [{cite['id']}] {cite['source']} (score: {cite['relevance_score']:.3f})")

        if result.get('communities'):
            print("\nRelated Topics:")
            for comm in result['communities']:
                print(f" - {', '.join(comm['top_entities'][:3])}")

    print("\n" + "=" * 80)
    print("Cache Stats:", engine.get_cache_stats())
    print("=" * 80)