|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json
import os
import pickle
import time
import traceback

import chromadb
from llama_index.core import VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import QueryBundle
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore

from custom_retriever import CustomRetriever
|
|
|
|
|
|
|
|
print("Loading Chroma vector store and document dict...")

# Location of the persisted Chroma database and of the pickled document dict
# stored inside the same directory.
db_path = "data/chroma-db-all_sources"
dict_path = os.path.join(db_path, "document_dict_all_sources.pkl")

# Number of nearest chunks the vector retriever returns for each query.
SIMILARITY_TOP_K = 10

# Opening a PersistentClient on a missing path would create a fresh, empty
# database there; check first so a wrong db_path fails loudly instead.
if not os.path.isdir(db_path):
    raise FileNotFoundError(f"Chroma database directory not found: {db_path}")

client = chromadb.PersistentClient(path=db_path)
# This script only reads an existing index: get_collection raises if the
# collection is missing, whereas get_or_create_collection would silently
# create an empty one and every query would then return zero nodes.
collection = client.get_collection(name="chroma-db-all_sources")
vector_store = ChromaVectorStore(chroma_collection=collection)

# Query-side embedding model. api_key may be None if COHERE_API_KEY is unset.
# NOTE(review): assumes the stored vectors were produced by the same
# embed-english-v3.0 model at ingestion time — confirm against the
# ingestion script.
embed_model = CohereEmbedding(
    api_key=os.environ.get("COHERE_API_KEY"),
    model_name="embed-english-v3.0",
    input_type="search_query",
)

# SECURITY: pickle.load can execute arbitrary code while unpickling; only
# load a document dict produced by this project's own ingestion step.
with open(dict_path, "rb") as f:
    document_dict = pickle.load(f)

index = VectorStoreIndex.from_vector_store(
    vector_store=vector_store,
    embed_model=embed_model,
    # NOTE(review): transformations only take effect when nodes are inserted;
    # presumably kept to mirror the ingestion chunking settings.
    transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)],
)

# CustomRetriever wraps the top-k vector retriever together with the loaded
# document dict (see custom_retriever for how the two are combined).
retriever = CustomRetriever(
    vector_retriever=VectorIndexRetriever(
        index=index,
        similarity_top_k=SIMILARITY_TOP_K,
    ),
    document_dict=document_dict,
)
|
|
|
|
|
|
|
|
# Hand-written evaluation questions about 4AT; each one is sent through the
# retriever by the evaluation loop below.
sample_queries = [
    "Who are the founders of 4AT?",
    "What services does 4AT provide?",
    "Where are the office locations of 4AT?",
    "Tell me about 4AT Academy",
    "What industries does 4AT serve?",
    "How many successful projects has 4AT delivered?",
    "What is 4AT's mission?",
    "Does 4AT offer CFO services?",
    "Where is 4AT headquartered?",
]
|
|
|
|
|
# Maximum number of characters of each retrieved node's text to print.
PREVIEW_CHARS = 400


def _preview(text, limit=PREVIEW_CHARS):
    """Return *text* cut to *limit* characters, with "..." appended if cut."""
    return text[:limit] + ("..." if len(text) > limit else "")


print("\n\n================ Evaluation Start ================")

for query in sample_queries:
    print(f"\n\n🔍 Query: {query}")
    bundle = QueryBundle(query_str=query)

    # Only the retrieval call is guarded; one failing query is reported and
    # the remaining queries are still evaluated.
    try:
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        # NOTE(review): this calls the private _retrieve hook directly. If
        # CustomRetriever subclasses llama_index's BaseRetriever, the public
        # retrieve() wrapper (callbacks/instrumentation) is bypassed —
        # consider retriever.retrieve(bundle).
        results = retriever._retrieve(bundle)
        duration = time.perf_counter() - start
    except Exception as e:
        print(f"❌ Error: {e}")
        # Keep the traceback on stdout so it stays next to its query.
        print(traceback.format_exc(), end="")
        continue

    print(f"📄 Retrieved {len(results)} nodes in {duration:.2f}s:")
    for rank, scored_node in enumerate(results, start=1):
        print(f"\n--- Node {rank} ---")
        print(_preview(scored_node.node.text))

print("\n\n✅ Evaluation complete.")
|
|
|