import os
import json
import pickle
from dotenv import load_dotenv
import chromadb

from llama_index.core import Document, StorageContext, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore

from custom_retriever import CustomRetriever
from utils import init_mongo_db

# ✅ Load environment variables from .env (COHERE_API_KEY, MONGODB_URI, ...)
load_dotenv()

def create_docs(input_file: str) -> list[Document]:
    """Read JSONL and convert to LlamaIndex Document objects."""
    documents = []
    with open(input_file, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            documents.append(
                Document(
                    doc_id=data["doc_id"],
                    text=data["content"],
                    metadata={
                        "url": data["metadata"]["url"],
                        "title": data["metadata"]["name"],
                        "tokens": data["metadata"]["tokens"],
                        "retrieve_doc": data["metadata"]["retrieve_doc"],
                        "source": data["metadata"]["source"],
                    },
                    excluded_llm_metadata_keys=["title", "tokens", "retrieve_doc", "source"],
                    excluded_embed_metadata_keys=["url", "tokens", "retrieve_doc", "source"],
                )
            )
    return documents
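
# For reference, a single JSONL record in the shape create_docs() expects (keys are
# taken from the accesses above; the values are illustrative placeholders only):
# {"doc_id": "example-001",
#  "content": "Full page text ...",
#  "metadata": {"url": "https://example.com/page", "name": "Example page",
#               "tokens": 812, "retrieve_doc": true, "source": "4at_website"}}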


def setup_database(db_collection_name: str, dict_file_name: str, input_data_file=None) -> CustomRetriever:
    """Create or load Chroma DB + build custom retriever."""
    db_path = f"data/{db_collection_name}"
    db = chromadb.PersistentClient(path=db_path)
    chroma_collection = db.get_or_create_collection(name=db_collection_name)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)

    cohere_api_key = os.environ.get("COHERE_API_KEY")
    if not cohere_api_key:
        raise ValueError("❌ Missing COHERE_API_KEY in .env")

    embed_model = CohereEmbedding(
        api_key=cohere_api_key,
        model_name="embed-english-v3.0",
        input_type="search_query",
    )
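    # NOTE (library-behavior assumption): Cohere's v3 embedding models distinguish
    # input_type="search_document" (corpus texts) from "search_query" (queries).
    # Depending on the installed llama-index-embeddings-cohere version, the
    # input_type set above may also be applied when embedding documents at index
    # time, so a separate embed model with input_type="search_document" can be
    # preferable for the indexing path.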

    document_dict = {}
    if chroma_collection.count() == 0:
        if not input_data_file or not os.path.exists(input_data_file):
            raise FileNotFoundError(f"❌ Missing: {input_data_file}")

        print(f"🧠 Building vector DB from: {input_data_file}")
        documents = create_docs(input_data_file)

        # Wrap the Chroma store in a StorageContext so from_documents persists
        # embeddings into Chroma rather than the default in-memory vector store.
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)],
            embed_model=embed_model,
            show_progress=True,
        )

        os.makedirs(db_path, exist_ok=True)
        document_dict = {doc.doc_id: doc for doc in documents}
        with open(f"{db_path}/{dict_file_name}", "wb") as f:
            pickle.dump(document_dict, f)
        print(f"βœ… Vector DB + document dict saved in '{db_path}'")

    else:
        print(f"♻️ Loading existing DB from: {db_path}")
        index = VectorStoreIndex.from_vector_store(
            vector_store=vector_store,
            embed_model=embed_model,
            transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)],
        )
        with open(f"{db_path}/{dict_file_name}", "rb") as f:
            document_dict = pickle.load(f)
        print("βœ… Document dict loaded successfully")

    vector_retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=15,
        embed_model=embed_model,
    )

    return CustomRetriever(vector_retriever, document_dict)


# 🧠 Build or load the retriever for the entire 4AT knowledge base
custom_retriever_all_sources: CustomRetriever = setup_database(
    db_collection_name="chroma-db-all_sources",
    dict_file_name="document_dict_all_sources.pkl",
    input_data_file="data/4at_content.jsonl"
)

# UI source filters (currently all map to the same retriever)
AVAILABLE_SOURCES_UI = [
    "4AT Website",
    "4AT Blog",
    "4AT Case Studies",
    "4AT Services",
    "4AT AI Solutions",
]

AVAILABLE_SOURCES = [
    "4at_website",
    "4at_blog",
    "4at_case_studies",
    "4at_services",
    "4at_ai_solutions",
]
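
# Assumption: the two lists above are index-aligned (one source key per UI label),
# so a UI selection can be translated to its key with a simple zip. The helper
# name below is illustrative and not part of the original module interface.
SOURCE_UI_TO_KEY = dict(zip(AVAILABLE_SOURCES_UI, AVAILABLE_SOURCES))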

# Optional Mongo logging
CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", 64))
MONGODB_URI = os.getenv("MONGODB_URI")
if MONGODB_URI:
    mongo_db = init_mongo_db(uri=MONGODB_URI, db_name="4at-data")
else:
    mongo_db = None
    print("⚠️ MONGODB_URI not set; skipping MongoDB logging")

__all__ = [
    "custom_retriever_all_sources",
    "mongo_db",
    "CONCURRENCY_COUNT",
    "AVAILABLE_SOURCES_UI",
    "AVAILABLE_SOURCES",
]
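

if __name__ == "__main__":
    # Minimal smoke-test sketch. Assumption: CustomRetriever follows LlamaIndex's
    # BaseRetriever-style interface and exposes retrieve(query: str) returning a
    # list of scored results; adjust to the actual API in custom_retriever.py.
    # The query string below is illustrative, not taken from the data.
    sample_query = "What AI solutions does 4AT offer?"
    results = list(custom_retriever_all_sources.retrieve(sample_query))
    print(f"Retrieved {len(results)} results for: {sample_query!r}")
    for rank, result in enumerate(results, start=1):
        # 'score' is optional depending on the result type returned.
        print(f"{rank}. score={getattr(result, 'score', 'n/a')}")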