Spaces:
Runtime error
Runtime error
add cross-encoder
Browse files
- app.py +5 -4
- backend/semantic_search.py +12 -3
app.py
CHANGED
|
@@ -34,7 +34,7 @@ def add_text(history, text):
|
|
| 34 |
return history, gr.Textbox(value="", interactive=False)
|
| 35 |
|
| 36 |
|
| 37 |
-
def bot(history, api_kind):
|
| 38 |
query = history[-1][0]
|
| 39 |
|
| 40 |
if not query:
|
|
@@ -44,7 +44,7 @@ def bot(history, api_kind):
|
|
| 44 |
# Retrieve documents relevant to query
|
| 45 |
document_start = perf_counter()
|
| 46 |
|
| 47 |
-
documents = retrieve(query, TOP_K)
|
| 48 |
|
| 49 |
document_time = perf_counter() - document_start
|
| 50 |
logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
|
|
@@ -86,12 +86,13 @@ with gr.Blocks() as demo:
|
|
| 86 |
)
|
| 87 |
txt_btn = gr.Button(value="Submit text", scale=1)
|
| 88 |
|
| 89 |
-
api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")
|
|
|
|
| 90 |
|
| 91 |
prompt_html = gr.HTML()
|
| 92 |
# Turn off interactivity while generating if you click
|
| 93 |
txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
| 94 |
-
bot, [chatbot, api_kind], [chatbot, prompt_html])
|
| 95 |
|
| 96 |
# Turn it back on
|
| 97 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
|
|
|
| 34 |
return history, gr.Textbox(value="", interactive=False)
|
| 35 |
|
| 36 |
|
| 37 |
+
def bot(history, api_kind, with_cross_encoder):
|
| 38 |
query = history[-1][0]
|
| 39 |
|
| 40 |
if not query:
|
|
|
|
| 44 |
# Retrieve documents relevant to query
|
| 45 |
document_start = perf_counter()
|
| 46 |
|
| 47 |
+
documents = retrieve(query, TOP_K, with_cross_encoder)
|
| 48 |
|
| 49 |
document_time = perf_counter() - document_start
|
| 50 |
logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
|
|
|
|
| 86 |
)
|
| 87 |
txt_btn = gr.Button(value="Submit text", scale=1)
|
| 88 |
|
| 89 |
+
api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")
|
| 90 |
+
cross_encoder = gr.Checkbox(label="Cross-encoder")
|
| 91 |
|
| 92 |
prompt_html = gr.HTML()
|
| 93 |
# Turn off interactivity while generating if you click
|
| 94 |
txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
| 95 |
+
bot, [chatbot, api_kind, cross_encoder], [chatbot, prompt_html])
|
| 96 |
|
| 97 |
# Turn it back on
|
| 98 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
backend/semantic_search.py
CHANGED
|
@@ -2,6 +2,7 @@ import lancedb
|
|
| 2 |
import os
|
| 3 |
import gradio as gr
|
| 4 |
from sentence_transformers import SentenceTransformer
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
db = lancedb.connect(".lancedb")
|
|
@@ -12,13 +13,21 @@ TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
|
|
| 12 |
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
|
| 13 |
|
| 14 |
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
-
def retrieve(query, k):
|
| 18 |
query_vec = retriever.encode(query)
|
| 19 |
try:
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
return documents
|
| 24 |
|
|
|
|
| 2 |
import os
|
| 3 |
import gradio as gr
|
| 4 |
from sentence_transformers import SentenceTransformer
|
| 5 |
+
from sentence_transformers import CrossEncoder
|
| 6 |
|
| 7 |
|
| 8 |
db = lancedb.connect(".lancedb")
|
|
|
|
| 13 |
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
|
| 14 |
|
| 15 |
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
|
| 16 |
+
cross_encoder = CrossEncoder(os.getenv("RERANK_MODEL"), max_length=512)
|
| 17 |
|
| 18 |
|
| 19 |
+
def retrieve(query, k, with_cross_encoder=False):
|
| 20 |
query_vec = retriever.encode(query)
|
| 21 |
try:
|
| 22 |
+
if not with_cross_encoder:
|
| 23 |
+
documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
|
| 24 |
+
documents = [doc[TEXT_COLUMN] for doc in documents]
|
| 25 |
+
else:
|
| 26 |
+
documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k * 2).to_list()
|
| 27 |
+
scores = cross_encoder.predict([(query, doc[TEXT_COLUMN]) for doc in documents])
|
| 28 |
+
indexed_arr = [(elem, index) for index, elem in enumerate(scores)]
|
| 29 |
+
sorted_arr = sorted(indexed_arr, key=lambda x: x[0], reverse=True)
|
| 30 |
+
documents = [documents[index][TEXT_COLUMN] for _, index in sorted_arr[:k]]
|
| 31 |
|
| 32 |
return documents
|
| 33 |
|