import json
import os
import pickle

import chromadb
from dotenv import load_dotenv

from llama_index.core import Document, StorageContext, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore

from custom_retriever import CustomRetriever
from utils import init_mongo_db

load_dotenv()


def create_docs(input_file: str) -> list[Document]:
    """Read a JSONL file and convert each record to a LlamaIndex Document."""
    documents = []
    with open(input_file, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            documents.append(
                Document(
                    doc_id=data["doc_id"],
                    text=data["content"],
                    metadata={
                        "url": data["metadata"]["url"],
                        "title": data["metadata"]["name"],
                        "tokens": data["metadata"]["tokens"],
                        "retrieve_doc": data["metadata"]["retrieve_doc"],
                        "source": data["metadata"]["source"],
                    },
                    # Keep the URL visible to the LLM but out of the embedding,
                    # and the title in the embedding but hidden from the LLM.
                    excluded_llm_metadata_keys=["title", "tokens", "retrieve_doc", "source"],
                    excluded_embed_metadata_keys=["url", "tokens", "retrieve_doc", "source"],
                )
            )
    return documents
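

# For reference, create_docs expects one JSON object per line with the keys read
# above. A record of the following shape would parse (values are illustrative,
# not taken from the real dataset):
#
#   {"doc_id": "web-001",
#    "content": "4AT builds AI solutions ...",
#    "metadata": {"url": "https://example.com/page",
#                 "name": "Example page",
#                 "tokens": 512,
#                 "retrieve_doc": true,
#                 "source": "4at_website"}}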


def setup_database(db_collection_name, dict_file_name, input_data_file=None) -> CustomRetriever:
    """Create or load the Chroma vector DB and build the custom retriever."""
    db_path = f"data/{db_collection_name}"
    db = chromadb.PersistentClient(path=db_path)
    chroma_collection = db.get_or_create_collection(name=db_collection_name)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)

    cohere_api_key = os.environ.get("COHERE_API_KEY")
    if not cohere_api_key:
        raise ValueError("❌ Missing COHERE_API_KEY in .env")

    embed_model = CohereEmbedding(
        api_key=cohere_api_key,
        model_name="embed-english-v3.0",
        input_type="search_query",  # Cohere v3 models distinguish query vs. document input types
    )

    document_dict = {}
    if chroma_collection.count() == 0:
        if not input_data_file or not os.path.exists(input_data_file):
            raise FileNotFoundError(f"❌ Missing input file: {input_data_file}")

        print(f"🔧 Building vector DB from: {input_data_file}")
        documents = create_docs(input_data_file)

        # from_documents does not take a vector_store kwarg; the Chroma store
        # must be supplied through a StorageContext so the index persists there.
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)],
            embed_model=embed_model,
            show_progress=True,
        )

        os.makedirs(db_path, exist_ok=True)
        document_dict = {doc.doc_id: doc for doc in documents}
        with open(f"{db_path}/{dict_file_name}", "wb") as f:
            pickle.dump(document_dict, f)
        print(f"✅ Vector DB + document dict saved in '{db_path}'")

    else:
        print(f"♻️ Loading existing DB from: {db_path}")
        index = VectorStoreIndex.from_vector_store(
            vector_store=vector_store,
            embed_model=embed_model,
            transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)],
        )
        with open(f"{db_path}/{dict_file_name}", "rb") as f:
            document_dict = pickle.load(f)
        print("✅ Document dict loaded successfully")

    vector_retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=15,
        embed_model=embed_model,
    )

    return CustomRetriever(vector_retriever, document_dict)


custom_retriever_all_sources: CustomRetriever = setup_database(
    db_collection_name="chroma-db-all_sources",
    dict_file_name="document_dict_all_sources.pkl",
    input_data_file="data/4at_content.jsonl",
)
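
# Quick sanity check, assuming CustomRetriever follows the LlamaIndex
# BaseRetriever interface (a hypothetical query, not part of the pipeline):
#
#   nodes = custom_retriever_all_sources.retrieve("What services does 4AT offer?")
#   print(len(nodes), nodes[0].metadata.get("url"))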


AVAILABLE_SOURCES_UI = [
    "4AT Website",
    "4AT Blog",
    "4AT Case Studies",
    "4AT Services",
    "4AT AI Solutions",
]

AVAILABLE_SOURCES = [
    "4at_website",
    "4at_blog",
    "4at_case_studies",
    "4at_services",
    "4at_ai_solutions",
]
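
# The two lists are positionally aligned: AVAILABLE_SOURCES_UI[i] is the display
# label for the internal source id AVAILABLE_SOURCES[i]. If a lookup table were
# needed, it could be derived directly (illustrative, not used elsewhere here):
#
#   SOURCE_LABELS = dict(zip(AVAILABLE_SOURCES, AVAILABLE_SOURCES_UI))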


CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", 64))
MONGODB_URI = os.getenv("MONGODB_URI")

if MONGODB_URI:
    mongo_db = init_mongo_db(uri=MONGODB_URI, db_name="4at-data")
else:
    mongo_db = None
    print("⚠️ MONGODB_URI not set, skipping MongoDB logging")


__all__ = [
    "custom_retriever_all_sources",
    "mongo_db",
    "CONCURRENCY_COUNT",
    "AVAILABLE_SOURCES_UI",
    "AVAILABLE_SOURCES",
]