import json
import os
import pickle

import chromadb
from dotenv import load_dotenv
from llama_index.core import Document, StorageContext, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore

from custom_retriever import CustomRetriever
from utils import init_mongo_db

# ✅ Load environment variables from the .env file
load_dotenv()


def create_docs(input_file: str) -> list[Document]:
    """Read a JSONL file and convert each record to a LlamaIndex Document."""
    documents = []
    with open(input_file, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            documents.append(
                Document(
                    doc_id=data["doc_id"],
                    text=data["content"],
                    metadata={
                        "url": data["metadata"]["url"],
                        "title": data["metadata"]["name"],
                        "tokens": data["metadata"]["tokens"],
                        "retrieve_doc": data["metadata"]["retrieve_doc"],
                        "source": data["metadata"]["source"],
                    },
                    # The URL stays visible to the LLM but is excluded from
                    # embeddings; the title is embedded but hidden from the LLM;
                    # the remaining fields are bookkeeping hidden from both.
                    excluded_llm_metadata_keys=["title", "tokens", "retrieve_doc", "source"],
                    excluded_embed_metadata_keys=["url", "tokens", "retrieve_doc", "source"],
                )
            )
    return documents
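
# For reference, a sketch of the JSONL record shape create_docs() expects.
# The field names come from the keys read above; the values are illustrative:
#
#   {"doc_id": "doc-001",
#    "content": "Full page text ...",
#    "metadata": {"url": "https://example.com/page", "name": "Page title",
#                 "tokens": 512, "retrieve_doc": true, "source": "4at_website"}}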

def setup_database(
    db_collection_name: str,
    dict_file_name: str,
    input_data_file: str | None = None,
) -> CustomRetriever:
    """Create or load the Chroma vector DB and build the custom retriever."""
    db_path = f"data/{db_collection_name}"
    db = chromadb.PersistentClient(path=db_path)
    chroma_collection = db.get_or_create_collection(name=db_collection_name)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)

    cohere_api_key = os.environ.get("COHERE_API_KEY")
    if not cohere_api_key:
        raise ValueError("❌ Missing COHERE_API_KEY in .env")
    # Cohere v3 models distinguish document- and query-side embeddings; the
    # input_type here tags the texts being indexed, so "search_document" is
    # used (llama_index's CohereEmbedding embeds queries as "search_query").
    embed_model = CohereEmbedding(
        api_key=cohere_api_key,
        model_name="embed-english-v3.0",
        input_type="search_document",
    )

    document_dict = {}
    if chroma_collection.count() == 0:
        if not input_data_file or not os.path.exists(input_data_file):
            raise FileNotFoundError(f"❌ Missing: {input_data_file}")
        print(f"🧠 Building vector DB from: {input_data_file}")
        documents = create_docs(input_data_file)
        # Route the embedded nodes into Chroma via a StorageContext; passing
        # vector_store= directly to from_documents() would leave them in the
        # default in-memory store and the Chroma collection empty.
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)],
            embed_model=embed_model,
            show_progress=True,
        )
        os.makedirs(db_path, exist_ok=True)
        document_dict = {doc.doc_id: doc for doc in documents}
        with open(f"{db_path}/{dict_file_name}", "wb") as f:
            pickle.dump(document_dict, f)
        print(f"✅ Vector DB + document dict saved in '{db_path}'")
    else:
        print(f"♻️ Loading existing DB from: {db_path}")
        index = VectorStoreIndex.from_vector_store(
            vector_store=vector_store,
            embed_model=embed_model,
            transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)],
        )
        with open(f"{db_path}/{dict_file_name}", "rb") as f:
            document_dict = pickle.load(f)
        print("✅ Document dict loaded successfully")

    vector_retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=15,
        embed_model=embed_model,
    )
    return CustomRetriever(vector_retriever, document_dict)

# 🧠 Load the retriever for the entire 4AT knowledge base
custom_retriever_all_sources: CustomRetriever = setup_database(
    db_collection_name="chroma-db-all_sources",
    dict_file_name="document_dict_all_sources.pkl",
    input_data_file="data/4at_content.jsonl",
)
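
# A minimal usage sketch, assuming CustomRetriever follows the standard
# llama_index BaseRetriever interface and exposes .retrieve():
#
#   nodes = custom_retriever_all_sources.retrieve("What services does 4AT offer?")
#   for node in nodes:
#       print(node.score, node.metadata["url"])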

# UI toggle filters (all currently backed by the same retriever)
AVAILABLE_SOURCES_UI = [
"4AT Website",
"4AT Blog",
"4AT Case Studies",
"4AT Services",
"4AT AI Solutions",
]
AVAILABLE_SOURCES = [
"4at_website",
"4at_blog",
"4at_case_studies",
"4at_services",
"4at_ai_solutions",
]
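
# The two lists above are positionally parallel. If a UI-label-to-source-key
# mapping is needed, a sketch (hypothetical helper, not used elsewhere here):
#
#   SOURCE_UI_TO_KEY = dict(zip(AVAILABLE_SOURCES_UI, AVAILABLE_SOURCES))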

# Optional Mongo logging
CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", 64))
MONGODB_URI = os.getenv("MONGODB_URI")
if MONGODB_URI:
    mongo_db = init_mongo_db(uri=MONGODB_URI, db_name="4at-data")
else:
    mongo_db = None
    print("⚠️ MONGODB_URI not set; skipping Mongo DB logging")

__all__ = [
"custom_retriever_all_sources",
"mongo_db",
"CONCURRENCY_COUNT",
"AVAILABLE_SOURCES_UI",
"AVAILABLE_SOURCES",
]