Spaces:
Build error
Build error
Upload app.py
Browse files
app.py
CHANGED
|
@@ -59,11 +59,16 @@ def save_key(api_key):
|
|
| 59 |
return api_key
|
| 60 |
|
| 61 |
|
| 62 |
-
def query_pinecone(query, top_k, model, index, threshold=0.5):
|
| 63 |
# generate embeddings for the query
|
| 64 |
xq = model.encode([query]).tolist()
|
| 65 |
# search pinecone index for context passage with the answer
|
| 66 |
-
xc = index.query(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
# filter the context passages based on the score threshold
|
| 68 |
filtered_matches = []
|
| 69 |
for match in xc["matches"]:
|
|
@@ -137,6 +142,27 @@ st.write(
|
|
| 137 |
|
| 138 |
query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")
|
| 139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
num_results = int(st.number_input("Number of Results to query", 1, 5, value=3))
|
| 141 |
|
| 142 |
|
|
@@ -180,7 +206,14 @@ threshold = float(
|
|
| 180 |
data = get_data()
|
| 181 |
|
| 182 |
query_results = query_pinecone(
|
| 183 |
-
query_text,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
)
|
| 185 |
|
| 186 |
if threshold <= 0.60:
|
|
|
|
| 59 |
return api_key
|
| 60 |
|
| 61 |
|
| 62 |
+
def query_pinecone(query, top_k, model, index, year, quarter, ticker, threshold=0.5):
|
| 63 |
# generate embeddings for the query
|
| 64 |
xq = model.encode([query]).tolist()
|
| 65 |
# search pinecone index for context passage with the answer
|
| 66 |
+
xc = index.query(
|
| 67 |
+
xq,
|
| 68 |
+
top_k=top_k,
|
| 69 |
+
filter={"year": year, "quarter": quarter, "ticker": ticker},
|
| 70 |
+
include_metadata=True,
|
| 71 |
+
)
|
| 72 |
# filter the context passages based on the score threshold
|
| 73 |
filtered_matches = []
|
| 74 |
for match in xc["matches"]:
|
|
|
|
| 142 |
|
| 143 |
query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")
|
| 144 |
|
| 145 |
+
years_choice = ["2016", "2017", "2018", "2019", "2020"]
|
| 146 |
+
|
| 147 |
+
year = st.selectbox("Year", years_choice)
|
| 148 |
+
|
| 149 |
+
quarter = st.selectbox("Quarter", ["Q1", "Q2", "Q3", "Q4"])
|
| 150 |
+
|
| 151 |
+
ticker_choice = [
|
| 152 |
+
"AAPL",
|
| 153 |
+
"CSCO",
|
| 154 |
+
"MSFT",
|
| 155 |
+
"ASML",
|
| 156 |
+
"NVDA",
|
| 157 |
+
"GOOGL",
|
| 158 |
+
"MU",
|
| 159 |
+
"INTC",
|
| 160 |
+
"AMZN",
|
| 161 |
+
"AMD",
|
| 162 |
+
]
|
| 163 |
+
|
| 164 |
+
ticker = st.selectbox("Company", ticker_choice)
|
| 165 |
+
|
| 166 |
num_results = int(st.number_input("Number of Results to query", 1, 5, value=3))
|
| 167 |
|
| 168 |
|
|
|
|
| 206 |
data = get_data()
|
| 207 |
|
| 208 |
query_results = query_pinecone(
|
| 209 |
+
query_text,
|
| 210 |
+
num_results,
|
| 211 |
+
retriever_model,
|
| 212 |
+
pinecone_index,
|
| 213 |
+
year,
|
| 214 |
+
quarter,
|
| 215 |
+
ticker,
|
| 216 |
+
threshold,
|
| 217 |
)
|
| 218 |
|
| 219 |
if threshold <= 0.60:
|