# Source: 4at-consulting-chatbot/scripts/evaluate_rag_system.py
# Last commit: db4c200 (verified), "Update scripts/evaluate_rag_system.py", by Ahambrahmasmi.
# ✅ STEP 1: Fix evaluate_rag_system.py
# This script helps you evaluate how well the chatbot retrieves context for questions
# about the 4AT knowledge base. It runs sample queries through the retriever and
# prints the retrieved chunks (it does not generate final answers).
import os
import json
import time
import pickle
from llama_index.core import VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import QueryBundle
from custom_retriever import CustomRetriever
# ✅ Load vector DB and document_dict
print("Loading Chroma vector store and document dict...")
db_path = "data/chroma-db-all_sources"
dict_path = os.path.join(db_path, "document_dict_all_sources.pkl")
# Load Chroma vector store
import chromadb
client = chromadb.PersistentClient(path=db_path)
collection = client.get_or_create_collection(name="chroma-db-all_sources")
# get_or_create_collection silently creates a new, empty collection when the
# name or path does not match the persisted DB. Retrieval would then have
# nothing to return, so make that situation visible instead of evaluating
# against an empty store.
if collection.count() == 0:
    print(
        f"⚠️ Warning: Chroma collection '{collection.name}' at '{db_path}' "
        "is empty; retrieval results will be empty."
    )
vector_store = ChromaVectorStore(chroma_collection=collection)
# Embedding model used to embed the evaluation queries
# (Cohere embed-english-v3.0 with the "search_query" input type).
# NOTE(review): if COHERE_API_KEY is unset this passes api_key=None; confirm
# how the Cohere client handles a missing key.
cohere_api_key = os.environ.get("COHERE_API_KEY")
embed_model = CohereEmbedding(
    model_name="embed-english-v3.0",
    input_type="search_query",
    api_key=cohere_api_key,
)
# Load the pickled document dictionary that is handed to CustomRetriever.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point dict_path at trusted, locally produced files.
with open(dict_path, mode="rb") as pickled_dict_file:
    document_dict = pickle.load(pickled_dict_file)
# Build an index on top of the existing Chroma store, then wrap its top-10
# vector retriever in the project's CustomRetriever (which also receives the
# document dictionary).
# NOTE(review): the SentenceSplitter transformation appears to matter only when
# documents are inserted through this index, not for retrieval — confirm.
splitter = SentenceSplitter(chunk_size=800, chunk_overlap=0)
index = VectorStoreIndex.from_vector_store(
    vector_store=vector_store,
    embed_model=embed_model,
    transformations=[splitter],
)
top_k_retriever = VectorIndexRetriever(index=index, similarity_top_k=10)
retriever = CustomRetriever(
    vector_retriever=top_k_retriever,
    document_dict=document_dict,
)
# ✅ Sample evaluation questions, one per line, asked in the order listed.
sample_queries = """\
Who are the founders of 4AT?
What services does 4AT provide?
Where are the office locations of 4AT?
Tell me about 4AT Academy
What industries does 4AT serve?
How many successful projects has 4AT delivered?
What is 4AT's mission?
Does 4AT offer CFO services?
Where is 4AT headquartered?
""".splitlines()
# Maximum number of characters of each retrieved chunk to print.
PREVIEW_CHARS = 400

print("\n\n================ Evaluation Start ================")
for query in sample_queries:
    print(f"\n\n🔍 Query: {query}")
    bundle = QueryBundle(query_str=query)
    # Per-query boundary: report any failure and continue with the next query
    # so that one bad query does not abort the whole evaluation run.
    try:
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # (e.g. on clock adjustments) and is not meant for measuring durations.
        start = time.perf_counter()
        # Use the public retrieve() entry point rather than the private
        # _retrieve hook, so the evaluation goes through the same code path a
        # query engine would use.
        # NOTE(review): assumes CustomRetriever subclasses llama_index's
        # BaseRetriever, which provides retrieve() — confirm.
        results = retriever.retrieve(bundle)
        duration = time.perf_counter() - start
        print(f"📄 Retrieved {len(results)} nodes in {duration:.2f}s:")
        for rank, hit in enumerate(results, start=1):
            # Similarity scores are what a retrieval evaluation needs to see;
            # a retriever may leave them unset, hence the None guard.
            score = "n/a" if hit.score is None else f"{hit.score:.4f}"
            print(f"\n--- Node {rank} (score: {score}) ---")
            text = hit.node.text
            ellipsis = "..." if len(text) > PREVIEW_CHARS else ""
            print(text[:PREVIEW_CHARS] + ellipsis)
    except Exception as e:
        print(f"❌ Error: {e}")
print("\n\n✅ Evaluation complete.")