import faiss
import os
import pickle
import re
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer, CrossEncoder
from rank_bm25 import BM25Okapi

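# Feature flags: USE_HNSW selects an approximate HNSW index instead of exact flat search;
# USE_RERANKER enables cross-encoder reranking of the fused candidates.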
USE_HNSW = True
USE_RERANKER = True

CHUNK_SIZE = 800
CHUNK_OVERLAP = 200

DB_FILE_INDEX = "vector.index"
DB_FILE_META = "metadata.pkl"
DB_FILE_BM25 = "bm25.pkl"

index = None
documents = []
metadata = []
bm25 = None

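# Dense embedding model for retrieval and a cross-encoder for reranking;
# both are fetched from the Hugging Face hub and cached on first use.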
embedder = SentenceTransformer("all-MiniLM-L6-v2")
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

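# Split text into roughly CHUNK_SIZE-character chunks on sentence boundaries,
# carrying the last CHUNK_OVERLAP characters forward into the next chunk.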
def chunk_text(text):
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks, current = [], ""
    for s in sentences:
        if len(current) + len(s) > CHUNK_SIZE and current:
            chunks.append(current.strip())
            overlap = max(0, len(current) - CHUNK_OVERLAP)
            current = current[overlap:] + " " + s
        else:
            current += " " + s if current else s
    if current.strip():
        chunks.append(current.strip())
    return chunks

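# Persist the FAISS index, chunk texts/metadata, and BM25 model to disk.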
def save_db():
    if index:
        faiss.write_index(index, DB_FILE_INDEX)
    if documents:
        with open(DB_FILE_META, "wb") as f:
            pickle.dump({"documents": documents, "metadata": metadata}, f)
    if bm25:
        with open(DB_FILE_BM25, "wb") as f:
            pickle.dump(bm25, f)

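# Restore persisted state; if only the index/metadata files exist, rebuild
# the BM25 model from the stored documents and persist it.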
def load_db():
    global index, documents, metadata, bm25
    if os.path.exists(DB_FILE_INDEX) and os.path.exists(DB_FILE_META):
        index = faiss.read_index(DB_FILE_INDEX)
        with open(DB_FILE_META, "rb") as f:
            data = pickle.load(f)
            documents = data["documents"]
            metadata = data["metadata"]
    
    if os.path.exists(DB_FILE_BM25):
        with open(DB_FILE_BM25, "rb") as f:
            bm25 = pickle.load(f)
    elif documents:
        # Auto-backfill if documents exist but BM25 is missing
        print("Backfilling BM25 index on first load...")
        tokenized_corpus = [doc.split(" ") for doc in documents]
        bm25 = BM25Okapi(tokenized_corpus)
        with open(DB_FILE_BM25, "wb") as f:
            pickle.dump(bm25, f)

load_db()

def clear_database():
    global index, documents, metadata, bm25
    index = None
    documents = []
    metadata = []
    bm25 = None
    if os.path.exists(DB_FILE_INDEX):
        os.remove(DB_FILE_INDEX)
    if os.path.exists(DB_FILE_META):
        os.remove(DB_FILE_META)
    if os.path.exists(DB_FILE_BM25):
        os.remove(DB_FILE_BM25)

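# Ingest uploaded PDF/TXT files: extract text (PDFs page by page via pymupdf4llm),
# chunk it, embed the chunks into FAISS, and rebuild the BM25 index.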
def ingest_documents(files):
    global index, documents, metadata, bm25
    texts, meta = [], []

    for file in files:
        if file.filename.endswith(".pdf"):
            # Save temp file for pymupdf4llm
            import tempfile
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                tmp.write(file.file.read())
                tmp_path = tmp.name
            
            try:
                # Use pymupdf4llm to extract markdown with tables
                import pymupdf4llm
                # Get list of dicts: [{'text': '...', 'metadata': {'page': 1, ...}}]
                pages_data = pymupdf4llm.to_markdown(tmp_path, page_chunks=True)
                
                for page_obj in pages_data:
                    p_text = page_obj["text"]
                    p_num = page_obj["metadata"].get("page", "N/A")
                    
                    # Chunk within the page to preserve page context
                    for chunk in chunk_text(p_text):
                        texts.append(chunk)
                        meta.append({"source": file.filename, "page": p_num})
            finally:
                os.remove(tmp_path)

        elif file.filename.endswith(".txt"):
            content = file.file.read().decode("utf-8", errors="ignore")
            for chunk in chunk_text(content):
                texts.append(chunk)
                meta.append({"source": file.filename, "page": "N/A"})

    if not texts:
        raise ValueError("No readable text found (OCR needed for scanned PDFs).")

    embeddings = embedder.encode(texts, convert_to_numpy=True, normalize_embeddings=True)

    if index is None:
        dim = embeddings.shape[1]
        if USE_HNSW:
            # HNSW is approximate; it defaults to L2 distance, which ranks identically
            # to cosine similarity on normalized embeddings.
            index = faiss.IndexHNSWFlat(dim, 32)
            index.hnsw.efConstruction = 200
            index.hnsw.efSearch = 64
        else:
            index = faiss.IndexFlatIP(dim)

    index.add(embeddings)
    documents.extend(texts)
    metadata.extend(meta)
    
    # Update BM25
    tokenized_corpus = [doc.split(" ") for doc in documents]
    bm25 = BM25Okapi(tokenized_corpus)
    
    save_db()
    return len(documents)

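# Hybrid retrieval: dense vector search and BM25 keyword search are merged with
# Reciprocal Rank Fusion, then optionally reranked by the cross-encoder;
# returns the top 5 chunks with their text, metadata, and scores.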
def search_knowledge(query, top_k=8):
    if index is None:
        return []

    # 1. Vector Search
    qvec = embedder.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    scores, indices = index.search(qvec, top_k)
    
    vector_results = {}
    for i, (idx, score) in enumerate(zip(indices[0], scores[0])):
        if idx == -1: continue
        vector_results[idx] = i  # Store rank (0-based)

    # 2. Keyword Search (BM25)
    bm25_results = {}
    if bm25:
        tokenized_query = query.split(" ")
        bm25_scores = bm25.get_scores(tokenized_query)
        # Get top_k indices
        top_n = sorted(range(len(bm25_scores)), key=lambda i: bm25_scores[i], reverse=True)[:top_k]
        for i, idx in enumerate(top_n):
            bm25_results[idx] = i  # Store rank

    # 3. Reciprocal Rank Fusion (RRF)
    # score = 1 / (k + rank)
    k = 60
    candidates_idx = set(vector_results.keys()) | set(bm25_results.keys())
    merged_candidates = []

    for idx in candidates_idx:
        v_rank = vector_results.get(idx, float('inf'))
        b_rank = bm25_results.get(idx, float('inf'))
        
        rrf_score = (1 / (k + v_rank)) + (1 / (k + b_rank))
        
        merged_candidates.append({
            "text": documents[idx],
            "metadata": metadata[idx],
            "score": rrf_score,  # This is RRF score, not cosine/BM25 score
            "vector_rank": v_rank if v_rank != float('inf') else None,
            "bm25_rank": b_rank if b_rank != float('inf') else None
        })

    # Sort by RRF score
    merged_candidates.sort(key=lambda x: x["score"], reverse=True)
    
    # 4. Rerank Top Candidates
    candidates = merged_candidates[:10] # Take top 10 for reranking

    if USE_RERANKER and candidates:
        pairs = [(query, c["text"]) for c in candidates]
        rerank_scores = reranker.predict(pairs)
        for c, rs in zip(candidates, rerank_scores):
            c["rerank"] = float(rs)
        candidates.sort(key=lambda x: x["rerank"], reverse=True)

    return candidates[:5]

def get_all_chunks(limit=80):
    return [{"text": t, "metadata": m} for t, m in zip(documents[:limit], metadata[:limit])]
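
# Minimal usage sketch (illustrative only): assumes this module is called from a web
# handler where `uploads` is a list of objects exposing `.filename` and `.file`,
# e.g. FastAPI UploadFile instances. The names and query below are hypothetical,
# not part of this module.
#
#   ingest_documents(uploads)
#   hits = search_knowledge("termination clause")
#   for h in hits:
#       print(h["metadata"]["source"], h["metadata"]["page"], h["score"], h["text"][:80])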