Spaces:
Runtime error
Runtime error
SirinootKK
commited on
Commit
Β·
0375104
1
Parent(s):
3f8e821
fixed app.py
Browse files
app.py
CHANGED
|
@@ -184,15 +184,18 @@ class Chatbot:
|
|
| 184 |
# Answer = self.model_pipeline(message, context)
|
| 185 |
# return Answer
|
| 186 |
def predict_semantic_search(self, message):
|
|
|
|
| 187 |
message = message.strip()
|
| 188 |
-
query_embedding = self.embedding_model.encode(
|
| 189 |
-
|
| 190 |
-
hits = util.semantic_search(query_embedding
|
| 191 |
hit = hits[0][0]
|
| 192 |
context = self.df['Context'][hit['corpus_id']]
|
| 193 |
Answer = self.model_pipeline(message, context)
|
| 194 |
return Answer
|
| 195 |
|
|
|
|
|
|
|
| 196 |
def predict_without_faiss(self,message):
|
| 197 |
MostSimilarContext = ""
|
| 198 |
min_distance = 1000
|
|
@@ -212,6 +215,8 @@ class Chatbot:
|
|
| 212 |
return Answer
|
| 213 |
|
| 214 |
bot = ChatbotModel()
|
|
|
|
|
|
|
| 215 |
|
| 216 |
"""#Gradio"""
|
| 217 |
|
|
|
|
| 184 |
# Answer = self.model_pipeline(message, context)
|
| 185 |
# return Answer
|
| 186 |
def predict_semantic_search(self, message):
|
| 187 |
+
corpus_embeddings = bot._chatbot.prepare_sentences_vector(bot._chatbot.get_embeddings(bot._chatbot.df['Context']))
|
| 188 |
message = message.strip()
|
| 189 |
+
query_embedding = self.embedding_model.encode(message, convert_to_tensor=True)
|
| 190 |
+
query_embedding = query_embedding.to('cuda')
|
| 191 |
+
hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=1)
|
| 192 |
hit = hits[0][0]
|
| 193 |
context = self.df['Context'][hit['corpus_id']]
|
| 194 |
Answer = self.model_pipeline(message, context)
|
| 195 |
return Answer
|
| 196 |
|
| 197 |
+
|
| 198 |
+
|
| 199 |
def predict_without_faiss(self,message):
|
| 200 |
MostSimilarContext = ""
|
| 201 |
min_distance = 1000
|
|
|
|
| 215 |
return Answer
|
| 216 |
|
| 217 |
bot = ChatbotModel()
|
| 218 |
+
corpus_embeddings = bot._chatbot.get_embeddings(bot._chatbot.df['Context'])
|
| 219 |
+
corpus_embeddings = bot._chatbot.prepare_sentences_vector(corpus_embeddings)
|
| 220 |
|
| 221 |
"""#Gradio"""
|
| 222 |
|