Siddharth63 commited on
Commit
6086a80
·
verified ·
1 Parent(s): 8b74da9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -27,24 +27,32 @@ class EmbeddingBackend:
27
  def __init__(self, repo: str):
28
  self.repo = repo
29
  if repo == "BAAI/bge-small-en-v1.5":
30
- # FlagEmbedding back‑end (BGE)
31
  self.model = FlagModel(
32
  repo,
33
- query_instruction_for_retrieval="Generate a representation for this sentence to retrieve related articles:",
34
  use_fp16=True,
35
  )
36
- self.encode_docs = self.model.encode
37
  self.encode_query = lambda q: self.model.encode_queries([q])[0]
38
  else:
39
- # SentenceTransformer back‑ends
40
- self.model = SentenceTransformer(repo, trust_remote_code=True)
 
 
 
 
41
  if "Qwen3" in repo:
42
  self.encode_query = lambda q: self.model.encode(q, prompt_name="query")
43
  elif "stella" in repo:
44
  self.encode_query = lambda q: self.model.encode(q, prompt_name="s2p_query")
45
  else:
46
  self.encode_query = lambda q: self.model.encode(q)
47
- self.encode_docs = lambda docs: self.model.encode(docs)
 
 
 
 
 
48
 
49
  # Convenience wrappers that return *numpy* arrays
50
  def encode_corpus(self, passages: List[str]) -> np.ndarray:
 
27
  def __init__(self, repo: str):
28
  self.repo = repo
29
  if repo == "BAAI/bge-small-en-v1.5":
 
30
  self.model = FlagModel(
31
  repo,
32
+ query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
33
  use_fp16=True,
34
  )
35
+ self.encode_docs = lambda docs: self.model.encode(docs, batch_size=BATCH_SIZE)
36
  self.encode_query = lambda q: self.model.encode_queries([q])[0]
37
  else:
38
+ model_kwargs = {}
39
+ if "Qwen3" in repo and not os.getenv("QWEN_USE_FLASH"):
40
+ model_kwargs["attn_implementation"] = "eager" # lower‑mem CPU path
41
+ self.model = SentenceTransformer(repo, trust_remote_code=True, model_kwargs=model_kwargs)
42
+
43
+ # Custom token truncation handled externally
44
  if "Qwen3" in repo:
45
  self.encode_query = lambda q: self.model.encode(q, prompt_name="query")
46
  elif "stella" in repo:
47
  self.encode_query = lambda q: self.model.encode(q, prompt_name="s2p_query")
48
  else:
49
  self.encode_query = lambda q: self.model.encode(q)
50
+
51
+ self.encode_docs = lambda docs: self.model.encode(
52
+ docs,
53
+ batch_size=BATCH_SIZE,
54
+ normalize_embeddings=False,
55
+ )
56
 
57
  # Convenience wrappers that return *numpy* arrays
58
  def encode_corpus(self, passages: List[str]) -> np.ndarray: