Spaces:
Build error
Build error
Upload app.py
Browse files
app.py
CHANGED
|
@@ -59,11 +59,17 @@ def save_key(api_key):
|
|
| 59 |
return api_key
|
| 60 |
|
| 61 |
|
| 62 |
-
def query_pinecone(query, top_k, model, index):
|
| 63 |
# generate embeddings for the query
|
| 64 |
xq = model.encode([query]).tolist()
|
| 65 |
# search pinecone index for context passage with the answer
|
| 66 |
xc = index.query(xq, top_k=top_k, include_metadata=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
return xc
|
| 68 |
|
| 69 |
|
|
@@ -127,19 +133,19 @@ st.title("Abstractive Question Answering - APPL")
|
|
| 127 |
|
| 128 |
query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")
|
| 129 |
|
| 130 |
-
num_results = int(st.number_input("Number of Results to query", 1, 5, value=
|
| 131 |
|
| 132 |
|
| 133 |
# Choose encoder model
|
| 134 |
|
| 135 |
-
encoder_models_choice = ["
|
| 136 |
|
| 137 |
encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
|
| 138 |
|
| 139 |
|
| 140 |
# Choose decoder model
|
| 141 |
|
| 142 |
-
decoder_models_choice = ["GPT3 (QA_davinci)", "GPT3 (
|
| 143 |
|
| 144 |
decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
|
| 145 |
|
|
@@ -163,23 +169,33 @@ elif encoder_model == "SGPT":
|
|
| 163 |
retriever_model = get_sgpt_embedding_model()
|
| 164 |
|
| 165 |
|
| 166 |
-
query_results = query_pinecone(query_text, num_results, retriever_model, pinecone_index)
|
| 167 |
-
|
| 168 |
window = int(st.number_input("Sentence Window Size", 1, 3, value=1))
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
data = get_data()
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
|
| 176 |
st.subheader("Answer:")
|
| 177 |
|
| 178 |
|
| 179 |
-
if decoder_model == "GPT3 (
|
| 180 |
openai_key = st.text_input(
|
| 181 |
"Enter OpenAI key",
|
| 182 |
-
value="sk-
|
| 183 |
type="password",
|
| 184 |
)
|
| 185 |
api_key = save_key(openai_key)
|
|
@@ -193,7 +209,7 @@ if decoder_model == "GPT3 (text_davinci)":
|
|
| 193 |
elif decoder_model == "GPT3 (QA_davinci)":
|
| 194 |
openai_key = st.text_input(
|
| 195 |
"Enter OpenAI key",
|
| 196 |
-
value="sk-
|
| 197 |
type="password",
|
| 198 |
)
|
| 199 |
api_key = save_key(openai_key)
|
|
|
|
| 59 |
return api_key
|
| 60 |
|
| 61 |
|
| 62 |
+
def query_pinecone(query, top_k, model, index, threshold=0.5):
    """Query a Pinecone index for passages matching ``query``.

    The query text is embedded with ``model``, up to ``top_k`` matches are
    requested from ``index`` (metadata included), and every match whose
    score is below ``threshold`` is dropped.

    Args:
        query: Query text to embed.
        top_k: Number of matches requested from the index.
        model: Encoder exposing ``encode(list_of_texts)``; the result must
            provide ``tolist()``.
        index: Object exposing
            ``query(vector, top_k=..., include_metadata=True)``.
        threshold: Minimum ``score`` (inclusive) a match needs to be kept.

    Returns:
        The response returned by ``index.query``, with its ``"matches"``
        entry replaced by the filtered list.
    """
    # generate embeddings for the query
    xq = model.encode([query]).tolist()
    # search pinecone index for context passage with the answer
    xc = index.query(xq, top_k=top_k, include_metadata=True)
    # filter the context passages based on the score threshold
    xc["matches"] = [
        match for match in xc["matches"] if match["score"] >= threshold
    ]
    return xc
|
| 74 |
|
| 75 |
|
|
|
|
| 133 |
|
| 134 |
# Question to answer, pre-filled with a sample query.
query_text = st.text_input(label="Input Query", value="Who is the CEO of Apple?")

# Number of matches to request: between 1 and 5, default 3, coerced to int.
num_results = int(
    st.number_input("Number of Results to query", min_value=1, max_value=5, value=3)
)

# Choose encoder model
encoder_models_choice = ["SGPT", "MPNET"]
encoder_model = st.selectbox(
    label="Select Encoder Model", options=encoder_models_choice
)

# Choose decoder model
decoder_models_choice = [
    "GPT3 (QA_davinci)",
    "GPT3 (summary_davinci)",
    "T5",
    "FLAN-T5",
]
decoder_model = st.selectbox(
    label="Select Decoder Model", options=decoder_models_choice
)
|
| 151 |
|
|
|
|
| 169 |
retriever_model = get_sgpt_embedding_model()
|
| 170 |
|
| 171 |
|
|
|
|
|
|
|
| 172 |
# Sentence window size (1 to 3, default 1); passed on as ``lag`` below.
window = int(
    st.number_input("Sentence Window Size", min_value=1, max_value=3, value=1)
)

# Score threshold handed to query_pinecone (default 0.55, 0.05 steps).
threshold = float(
    st.number_input(
        "Similarity Score Threshold", value=0.55, step=0.05, format="%.2f"
    )
)

data = get_data()

query_results = query_pinecone(
    query=query_text,
    top_k=num_results,
    model=retriever_model,
    index=pinecone_index,
    threshold=threshold,
)

# Thresholds up to 0.65 build the context with sentence_id_combine and the
# chosen window; higher thresholds use format_query on the matches alone.
context_list = (
    sentence_id_combine(data, query_results, lag=window)
    if threshold <= 0.65
    else format_query(query_results)
)
|
| 190 |
|
| 191 |
|
| 192 |
st.subheader("Answer:")
|
| 193 |
|
| 194 |
|
| 195 |
+
if decoder_model == "GPT3 (summary_davinci)":
|
| 196 |
# SECURITY: never ship a real secret as the widget default. The default is
# committed with the source and pre-fills the field for every visitor.
# NOTE(review): the key that was hard-coded here is exposed and should be
# revoked.
openai_key = st.text_input(
    "Enter OpenAI key",
    value="",
    type="password",
)
api_key = save_key(openai_key)
|
|
|
|
| 209 |
elif decoder_model == "GPT3 (QA_davinci)":
|
| 210 |
# SECURITY: never ship a real secret as the widget default. The default is
# committed with the source and pre-fills the field for every visitor.
# NOTE(review): the key that was hard-coded here is exposed and should be
# revoked.
openai_key = st.text_input(
    "Enter OpenAI key",
    value="",
    type="password",
)
api_key = save_key(openai_key)
|