Spaces:
Sleeping
Sleeping
import json
import os
from pathlib import Path

import numpy as np
import streamlit as st
import torch
from sklearn.metrics.pairwise import cosine_similarity

from utils import ModelWrapper
# Directory holding the plain-text documents that can be searched.
DOCUMENTS_DIR = Path("./documents")
# JSON file used as a tiny "vector database": {file name: embedding as nested list}.
EMBEDDINGS_FILE = Path("./vectors/embeddings.json")
# Set to True to force the document embeddings to be recomputed on the next
# question. They are also recomputed automatically when EMBEDDINGS_FILE is missing.
REBUILD_EMBEDDINGS = False
# Minimum cosine similarity for a document to be considered relevant.
SIMILARITY_THRESHOLD = 0.35


@st.cache_resource(show_spinner=False)
def load_models() -> ModelWrapper:
    """Create the ModelWrapper once and reuse it across Streamlit reruns.

    Streamlit re-executes the whole script on every interaction; without the
    cache the models would be re-created for every chat message.
    """
    return ModelWrapper()


def build_document_embeddings(model_loader) -> None:
    """Embed every file in DOCUMENTS_DIR and write the result to EMBEDDINGS_FILE.

    The embeddings are stored as JSON, keyed by file name.
    """
    document_embeddings = {}
    # Sorted so the stored order (and therefore tie-breaking) is deterministic.
    for path in sorted(DOCUMENTS_DIR.iterdir()):
        if not path.is_file():
            continue
        text = path.read_text(encoding="utf-8")
        document_embeddings[path.name] = model_loader.get_embeddings(text, 1).tolist()
    EMBEDDINGS_FILE.parent.mkdir(parents=True, exist_ok=True)
    with EMBEDDINGS_FILE.open("w", encoding="utf-8") as outfile:
        outfile.write(json.dumps(document_embeddings, indent=4))


def load_document_embeddings(model_loader) -> dict:
    """Return the stored document embeddings, building them first if needed."""
    if REBUILD_EMBEDDINGS or not EMBEDDINGS_FILE.exists():
        with st.spinner("Please wait for computing the embeddings"):
            build_document_embeddings(model_loader)
    with EMBEDDINGS_FILE.open("r", encoding="utf-8") as embeddings_file:
        return json.load(embeddings_file)


def find_most_relevant_document(question_embeddings, document_embeddings: dict):
    """Linear search for the document most similar to the question.

    Returns:
        A ``(file name, similarity)`` pair. The name is ``None`` and the
        similarity is ``-1.0`` when no document scores above ``-1``
        (e.g. when there are no documents at all).

    Raises:
        ValueError: if a similarity result is not a single value.
    """
    best_document = None
    best_similarity = -1.0
    for document, embedding in document_embeddings.items():
        # cosine_similarity returns an array; .item() turns the single entry
        # into a plain float so it compares and prints as a scalar.
        similarity = cosine_similarity(question_embeddings, embedding).item()
        if similarity > best_similarity:
            best_document = document
            best_similarity = similarity
    return best_document, best_similarity


def extract_answer(model_loader, question: str, context: str) -> str:
    """Run the QA model on (question, context) and decode the predicted span.

    Returns an empty string when the predicted span contains no tokens
    (for instance when the predicted end index precedes the start index).
    """
    # NOTE(review): no truncation is requested here, so a context longer than
    # the model's input limit would presumably fail - confirm against ModelWrapper.
    inputs = model_loader.tokenizer(question, context, return_tensors="pt")
    with torch.no_grad():
        outputs = model_loader.model_qa(**inputs)
    answer_start_index = outputs.start_logits.argmax()
    answer_end_index = outputs.end_logits.argmax()
    answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
    return model_loader.tokenizer.decode(answer_tokens, skip_special_tokens=True)


def main() -> None:
    """Render the chat UI and answer one question per script run."""
    st.title('HRA Document QA')
    with st.spinner("Please wait for loading the models"):
        model_loader = load_models()

    with st.chat_message("assistant"):
        # NOTE(review): the "π" looks like a mis-decoded emoji - confirm the
        # intended character before changing this user-facing string.
        st.write("Hello π I am an HRA chatbot~")
        st.write("I know everything about the leadership of HRA.")
        st.write("Please ask your questions about the leadership of HRA. For example, you can ask 'Where did Robert Kauffman graduate?', 'What's the position for Fred Danback?' ")

    question = st.chat_input("Please ask me some questions about the leadership of HRA:")
    if not question:
        return

    with st.chat_message("assistant"):
        st.write("You asked a question:")
    with st.chat_message("user"):
        st.write(question)

    # Retrieval: embed the question and pick the closest stored document.
    question_embeddings = model_loader.get_embeddings(question, 0)
    document_embeddings = load_document_embeddings(model_loader)
    most_relevant_document, max_similarity = find_most_relevant_document(
        question_embeddings, document_embeddings
    )

    is_relevant = max_similarity >= SIMILARITY_THRESHOLD
    with st.chat_message("assistant"):
        if is_relevant:
            st.write("The most relevant document is:")
            st.write(most_relevant_document)
            st.write("And the cosine similarity is:" + str(max_similarity))
        else:
            st.write("Sorry we can't find relevant document")
    if not is_relevant:
        return

    # Reading: extract the answer span from the selected document.
    context = (DOCUMENTS_DIR / most_relevant_document).read_text(encoding="utf-8")
    predict_answer = extract_answer(model_loader, question, context)
    with st.chat_message("assistant"):
        st.write("Answer:")
        # Fall back to showing the whole document when no span was extracted.
        st.write(predict_answer if predict_answer else context)


# Streamlit executes the script as "__main__", so this runs on every rerun.
if __name__ == "__main__":
    main()