Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -55,32 +55,36 @@ class EmbeddingBackend:
|
|
| 55 |
return np.asarray(emb)
|
| 56 |
|
| 57 |
# --------------------------------------------------
|
| 58 |
-
# Hybrid (exact
|
| 59 |
# --------------------------------------------------
|
| 60 |
class HybridIndex:
|
| 61 |
def __init__(self, df: pd.DataFrame, text_col: str, backend: EmbeddingBackend):
|
| 62 |
-
self.df
|
| 63 |
-
self.text_col
|
| 64 |
-
self.backend
|
| 65 |
-
self.
|
| 66 |
-
self.
|
|
|
|
| 67 |
|
| 68 |
# ---------- exact match ----------
|
| 69 |
def _build_ac(self):
|
| 70 |
tree = KeywordTree(case_insensitive=True)
|
| 71 |
for i, passage in self.df[self.text_col].astype(str).items():
|
| 72 |
-
tree.add(passage
|
|
|
|
| 73 |
tree.finalize()
|
| 74 |
return tree
|
| 75 |
|
| 76 |
def exact_hits(self, query: str) -> List[int]:
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
# ---------- semantic ----------
|
| 80 |
def _build_emb(self):
|
| 81 |
docs = self.df[self.text_col].astype(str).tolist()
|
| 82 |
emb = self.backend.encode_corpus(docs)
|
| 83 |
-
# Normalise for cosine similarity via dot‑product
|
| 84 |
emb_norm = emb / np.linalg.norm(emb, axis=1, keepdims=True)
|
| 85 |
return emb_norm.astype(np.float32)
|
| 86 |
|
|
|
|
| 55 |
return np.asarray(emb)
|
| 56 |
|
| 57 |
# --------------------------------------------------
|
| 58 |
+
# Hybrid (exact → semantic) index
|
| 59 |
# --------------------------------------------------
|
| 60 |
class HybridIndex:
|
| 61 |
def __init__(self, df: pd.DataFrame, text_col: str, backend: EmbeddingBackend):
|
| 62 |
+
self.df = df
|
| 63 |
+
self.text_col = text_col
|
| 64 |
+
self.backend = backend
|
| 65 |
+
self.text_to_rows = defaultdict(list) # passage → [row ids]
|
| 66 |
+
self.ac_tree = self._build_ac()
|
| 67 |
+
self.embeddings = self._build_emb()
|
| 68 |
|
| 69 |
# ---------- exact match ----------
|
| 70 |
def _build_ac(self):
|
| 71 |
tree = KeywordTree(case_insensitive=True)
|
| 72 |
for i, passage in self.df[self.text_col].astype(str).items():
|
| 73 |
+
tree.add(passage)
|
| 74 |
+
self.text_to_rows[passage].append(i)
|
| 75 |
tree.finalize()
|
| 76 |
return tree
|
| 77 |
|
| 78 |
def exact_hits(self, query: str) -> List[int]:
|
| 79 |
+
rows = set()
|
| 80 |
+
for keyword, _ in self.ac_tree.search_all(query):
|
| 81 |
+
rows.update(self.text_to_rows[keyword])
|
| 82 |
+
return list(rows)
|
| 83 |
|
| 84 |
# ---------- semantic ----------
|
| 85 |
def _build_emb(self):
|
| 86 |
docs = self.df[self.text_col].astype(str).tolist()
|
| 87 |
emb = self.backend.encode_corpus(docs)
|
|
|
|
| 88 |
emb_norm = emb / np.linalg.norm(emb, axis=1, keepdims=True)
|
| 89 |
return emb_norm.astype(np.float32)
|
| 90 |
|