Siddharth63 commited on
Commit
c6c48c5
·
verified ·
1 Parent(s): 43550eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -55,32 +55,36 @@ class EmbeddingBackend:
55
  return np.asarray(emb)
56
 
57
  # --------------------------------------------------
58
- # Hybrid (exact semantic) index
59
  # --------------------------------------------------
60
  class HybridIndex:
61
  def __init__(self, df: pd.DataFrame, text_col: str, backend: EmbeddingBackend):
62
- self.df = df
63
- self.text_col = text_col
64
- self.backend = backend
65
- self.ac_tree = self._build_ac()
66
- self.embeddings = self._build_emb()
 
67
 
68
  # ---------- exact match ----------
69
  def _build_ac(self):
70
  tree = KeywordTree(case_insensitive=True)
71
  for i, passage in self.df[self.text_col].astype(str).items():
72
- tree.add(passage, i)
 
73
  tree.finalize()
74
  return tree
75
 
76
  def exact_hits(self, query: str) -> List[int]:
77
- return list({m[1] for m in self.ac_tree.search_all(query)})
 
 
 
78
 
79
  # ---------- semantic ----------
80
  def _build_emb(self):
81
  docs = self.df[self.text_col].astype(str).tolist()
82
  emb = self.backend.encode_corpus(docs)
83
- # Normalise for cosine similarity via dot‑product
84
  emb_norm = emb / np.linalg.norm(emb, axis=1, keepdims=True)
85
  return emb_norm.astype(np.float32)
86
 
 
55
  return np.asarray(emb)
56
 
57
  # --------------------------------------------------
58
+ # Hybrid (exact semantic) index
59
  # --------------------------------------------------
60
  class HybridIndex:
61
  def __init__(self, df: pd.DataFrame, text_col: str, backend: EmbeddingBackend):
62
+ self.df = df
63
+ self.text_col = text_col
64
+ self.backend = backend
65
+ self.text_to_rows = defaultdict(list) # passage → [row ids]
66
+ self.ac_tree = self._build_ac()
67
+ self.embeddings = self._build_emb()
68
 
69
  # ---------- exact match ----------
70
  def _build_ac(self):
71
  tree = KeywordTree(case_insensitive=True)
72
  for i, passage in self.df[self.text_col].astype(str).items():
73
+ tree.add(passage)
74
+ self.text_to_rows[passage].append(i)
75
  tree.finalize()
76
  return tree
77
 
78
  def exact_hits(self, query: str) -> List[int]:
79
+ rows = set()
80
+ for keyword, _ in self.ac_tree.search_all(query):
81
+ rows.update(self.text_to_rows[keyword])
82
+ return list(rows)
83
 
84
  # ---------- semantic ----------
85
  def _build_emb(self):
86
  docs = self.df[self.text_col].astype(str).tolist()
87
  emb = self.backend.encode_corpus(docs)
 
88
  emb_norm = emb / np.linalg.norm(emb, axis=1, keepdims=True)
89
  return emb_norm.astype(np.float32)
90