Spaces:
Build error
Build error
Upload 16 files (#8)
Browse files
- Upload 16 files (654c92761e49f26e7c52e337cdd12c207ebeb3e5)
- app.py +90 -24
- utils/models.py +1 -1
app.py
CHANGED
|
@@ -22,9 +22,9 @@ from utils.models import (
|
|
| 22 |
get_data,
|
| 23 |
get_flan_alpaca_xl_model,
|
| 24 |
get_flan_t5_model,
|
|
|
|
| 25 |
get_mpnet_embedding_model,
|
| 26 |
get_sgpt_embedding_model,
|
| 27 |
-
get_instructor_embedding_model,
|
| 28 |
get_spacy_model,
|
| 29 |
get_splade_sparse_embedding_model,
|
| 30 |
get_t5_model,
|
|
@@ -248,7 +248,13 @@ with st.sidebar:
|
|
| 248 |
|
| 249 |
# Choose encoder model
|
| 250 |
|
| 251 |
-
encoder_models_choice = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
with st.sidebar:
|
| 253 |
encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
|
| 254 |
|
|
@@ -285,12 +291,32 @@ elif encoder_model == "SGPT":
|
|
| 285 |
elif encoder_model == "Instructor":
|
| 286 |
# Connect to pinecone environment
|
| 287 |
pinecone.init(
|
| 288 |
-
api_key=st.secrets["pinecone_instructor"],
|
|
|
|
| 289 |
)
|
| 290 |
pinecone_index_name = "week13-instructor-xl"
|
| 291 |
pinecone_index = pinecone.Index(pinecone_index_name)
|
| 292 |
retriever_model = get_instructor_embedding_model()
|
| 293 |
-
instruction =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
|
| 295 |
elif encoder_model == "Hybrid MPNET - SPLADE":
|
| 296 |
pinecone.init(
|
|
@@ -332,10 +358,15 @@ with st.sidebar:
|
|
| 332 |
data = get_data()
|
| 333 |
|
| 334 |
if document_type == "Single-Document":
|
| 335 |
-
if encoder_model
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
sparse_query_embedding = create_sparse_embeddings(
|
| 340 |
query_text, sparse_retriever_model, sparse_retriever_tokenizer
|
| 341 |
)
|
|
@@ -383,10 +414,18 @@ else:
|
|
| 383 |
# Multi-Document Retrieval
|
| 384 |
# Single Company
|
| 385 |
if multi_company_choice == "Single-Company":
|
| 386 |
-
if encoder_model
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
sparse_query_embedding = create_sparse_embeddings(
|
| 391 |
query_text, sparse_retriever_model, sparse_retriever_tokenizer
|
| 392 |
)
|
|
@@ -448,10 +487,18 @@ else:
|
|
| 448 |
multi_doc_context = generate_multi_doc_context(context_group)
|
| 449 |
# Companies Comparison
|
| 450 |
else:
|
| 451 |
-
if encoder_model
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 455 |
sparse_query_embedding = create_sparse_embeddings(
|
| 456 |
query_text, sparse_retriever_model, sparse_retriever_tokenizer
|
| 457 |
)
|
|
@@ -766,22 +813,41 @@ with tab2:
|
|
| 766 |
for year, quarter in year_quarter_list:
|
| 767 |
file_text = retrieve_transcript(data, year, quarter, ticker)
|
| 768 |
with st.expander(f"See Transcript - {quarter} {year}"):
|
| 769 |
-
st.subheader(
|
|
|
|
|
|
|
| 770 |
stx.scrollableTextbox(
|
| 771 |
-
file_text,
|
|
|
|
|
|
|
|
|
|
| 772 |
)
|
| 773 |
else:
|
| 774 |
for year, quarter in year_quarter_list:
|
| 775 |
-
file_text = retrieve_transcript(
|
|
|
|
|
|
|
| 776 |
with st.expander(f"See Transcript - {quarter} {year}"):
|
| 777 |
-
st.subheader(
|
|
|
|
|
|
|
| 778 |
stx.scrollableTextbox(
|
| 779 |
-
file_text,
|
|
|
|
|
|
|
|
|
|
| 780 |
)
|
| 781 |
for year, quarter in year_quarter_list:
|
| 782 |
-
file_text = retrieve_transcript(
|
|
|
|
|
|
|
| 783 |
with st.expander(f"See Transcript - {quarter} {year}"):
|
| 784 |
-
st.subheader(
|
|
|
|
|
|
|
| 785 |
stx.scrollableTextbox(
|
| 786 |
-
file_text,
|
|
|
|
|
|
|
|
|
|
| 787 |
)
|
|
|
|
| 22 |
get_data,
|
| 23 |
get_flan_alpaca_xl_model,
|
| 24 |
get_flan_t5_model,
|
| 25 |
+
get_instructor_embedding_model,
|
| 26 |
get_mpnet_embedding_model,
|
| 27 |
get_sgpt_embedding_model,
|
|
|
|
| 28 |
get_spacy_model,
|
| 29 |
get_splade_sparse_embedding_model,
|
| 30 |
get_t5_model,
|
|
|
|
| 248 |
|
| 249 |
# Choose encoder model
|
| 250 |
|
| 251 |
+
encoder_models_choice = [
|
| 252 |
+
"MPNET",
|
| 253 |
+
"Instructor",
|
| 254 |
+
"Hybrid Instructor - SPLADE",
|
| 255 |
+
"SGPT",
|
| 256 |
+
"Hybrid MPNET - SPLADE",
|
| 257 |
+
]
|
| 258 |
with st.sidebar:
|
| 259 |
encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
|
| 260 |
|
|
|
|
| 291 |
elif encoder_model == "Instructor":
|
| 292 |
# Connect to pinecone environment
|
| 293 |
pinecone.init(
|
| 294 |
+
api_key=st.secrets["pinecone_instructor"],
|
| 295 |
+
environment="us-west4-gcp-free",
|
| 296 |
)
|
| 297 |
pinecone_index_name = "week13-instructor-xl"
|
| 298 |
pinecone_index = pinecone.Index(pinecone_index_name)
|
| 299 |
retriever_model = get_instructor_embedding_model()
|
| 300 |
+
instruction = (
|
| 301 |
+
"Represent the financial question for retrieving supporting documents:"
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
elif encoder_model == "Hybrid Instructor - SPLADE":
|
| 305 |
+
# Connect to pinecone environment
|
| 306 |
+
pinecone.init(
|
| 307 |
+
api_key=st.secrets["pinecone_instructor_splade"],
|
| 308 |
+
environment="us-west4-gcp-free",
|
| 309 |
+
)
|
| 310 |
+
pinecone_index_name = "week13-splade-instructor-xl"
|
| 311 |
+
pinecone_index = pinecone.Index(pinecone_index_name)
|
| 312 |
+
retriever_model = get_instructor_embedding_model()
|
| 313 |
+
(
|
| 314 |
+
sparse_retriever_model,
|
| 315 |
+
sparse_retriever_tokenizer,
|
| 316 |
+
) = get_splade_sparse_embedding_model()
|
| 317 |
+
instruction = (
|
| 318 |
+
"Represent the financial question for retrieving supporting documents:"
|
| 319 |
+
)
|
| 320 |
|
| 321 |
elif encoder_model == "Hybrid MPNET - SPLADE":
|
| 322 |
pinecone.init(
|
|
|
|
| 358 |
data = get_data()
|
| 359 |
|
| 360 |
if document_type == "Single-Document":
|
| 361 |
+
if encoder_model in ["Hybrid SGPT - SPLADE", "Hybrid Instructor - SPLADE"]:
|
| 362 |
+
if encoder_model == "Hybrid Instructor - SPLADE":
|
| 363 |
+
dense_query_embedding = create_dense_embeddings(
|
| 364 |
+
query_text, retriever_model, instruction
|
| 365 |
+
)
|
| 366 |
+
else:
|
| 367 |
+
dense_query_embedding = create_dense_embeddings(
|
| 368 |
+
query_text, retriever_model
|
| 369 |
+
)
|
| 370 |
sparse_query_embedding = create_sparse_embeddings(
|
| 371 |
query_text, sparse_retriever_model, sparse_retriever_tokenizer
|
| 372 |
)
|
|
|
|
| 414 |
# Multi-Document Retrieval
|
| 415 |
# Single Company
|
| 416 |
if multi_company_choice == "Single-Company":
|
| 417 |
+
if encoder_model in [
|
| 418 |
+
"Hybrid SGPT - SPLADE",
|
| 419 |
+
"Hybrid Instructor - SPLADE",
|
| 420 |
+
]:
|
| 421 |
+
if encoder_model == "Hybrid Instructor - SPLADE":
|
| 422 |
+
dense_query_embedding = create_dense_embeddings(
|
| 423 |
+
query_text, retriever_model, instruction
|
| 424 |
+
)
|
| 425 |
+
else:
|
| 426 |
+
dense_query_embedding = create_dense_embeddings(
|
| 427 |
+
query_text, retriever_model
|
| 428 |
+
)
|
| 429 |
sparse_query_embedding = create_sparse_embeddings(
|
| 430 |
query_text, sparse_retriever_model, sparse_retriever_tokenizer
|
| 431 |
)
|
|
|
|
| 487 |
multi_doc_context = generate_multi_doc_context(context_group)
|
| 488 |
# Companies Comparison
|
| 489 |
else:
|
| 490 |
+
if encoder_model in [
|
| 491 |
+
"Hybrid SGPT - SPLADE",
|
| 492 |
+
"Hybrid Instructor - SPLADE",
|
| 493 |
+
]:
|
| 494 |
+
if encoder_model == "Hybrid Instructor - SPLADE":
|
| 495 |
+
dense_query_embedding = create_dense_embeddings(
|
| 496 |
+
query_text, retriever_model, instruction
|
| 497 |
+
)
|
| 498 |
+
else:
|
| 499 |
+
dense_query_embedding = create_dense_embeddings(
|
| 500 |
+
query_text, retriever_model
|
| 501 |
+
)
|
| 502 |
sparse_query_embedding = create_sparse_embeddings(
|
| 503 |
query_text, sparse_retriever_model, sparse_retriever_tokenizer
|
| 504 |
)
|
|
|
|
| 813 |
for year, quarter in year_quarter_list:
|
| 814 |
file_text = retrieve_transcript(data, year, quarter, ticker)
|
| 815 |
with st.expander(f"See Transcript - {quarter} {year}"):
|
| 816 |
+
st.subheader(
|
| 817 |
+
"Earnings Call Transcript - {quarter} {year}:"
|
| 818 |
+
)
|
| 819 |
stx.scrollableTextbox(
|
| 820 |
+
file_text,
|
| 821 |
+
height=700,
|
| 822 |
+
border=False,
|
| 823 |
+
fontFamily="Helvetica",
|
| 824 |
)
|
| 825 |
else:
|
| 826 |
for year, quarter in year_quarter_list:
|
| 827 |
+
file_text = retrieve_transcript(
|
| 828 |
+
data, year, quarter, ticker_first
|
| 829 |
+
)
|
| 830 |
with st.expander(f"See Transcript - {quarter} {year}"):
|
| 831 |
+
st.subheader(
|
| 832 |
+
"Earnings Call Transcript - {quarter} {year}:"
|
| 833 |
+
)
|
| 834 |
stx.scrollableTextbox(
|
| 835 |
+
file_text,
|
| 836 |
+
height=700,
|
| 837 |
+
border=False,
|
| 838 |
+
fontFamily="Helvetica",
|
| 839 |
)
|
| 840 |
for year, quarter in year_quarter_list:
|
| 841 |
+
file_text = retrieve_transcript(
|
| 842 |
+
data, year, quarter, ticker_second
|
| 843 |
+
)
|
| 844 |
with st.expander(f"See Transcript - {quarter} {year}"):
|
| 845 |
+
st.subheader(
|
| 846 |
+
"Earnings Call Transcript - {quarter} {year}:"
|
| 847 |
+
)
|
| 848 |
stx.scrollableTextbox(
|
| 849 |
+
file_text,
|
| 850 |
+
height=700,
|
| 851 |
+
border=False,
|
| 852 |
+
fontFamily="Helvetica",
|
| 853 |
)
|
utils/models.py
CHANGED
|
@@ -8,8 +8,8 @@ import spacy
|
|
| 8 |
import spacy_transformers
|
| 9 |
import streamlit_scrollable_textbox as stx
|
| 10 |
import torch
|
| 11 |
-
from sentence_transformers import SentenceTransformer
|
| 12 |
from InstructorEmbedding import INSTRUCTOR
|
|
|
|
| 13 |
from tqdm import tqdm
|
| 14 |
from transformers import (
|
| 15 |
AutoModelForMaskedLM,
|
|
|
|
| 8 |
import spacy_transformers
|
| 9 |
import streamlit_scrollable_textbox as stx
|
| 10 |
import torch
|
|
|
|
| 11 |
from InstructorEmbedding import INSTRUCTOR
|
| 12 |
+
from sentence_transformers import SentenceTransformer
|
| 13 |
from tqdm import tqdm
|
| 14 |
from transformers import (
|
| 15 |
AutoModelForMaskedLM,
|