awinml's picture
Upload 2 files
6627aee
raw
history blame
11.3 kB
import openai
import streamlit_scrollable_textbox as stx
import pinecone
import streamlit as st
st.set_page_config(layout="wide") # isort: split
from utils import (
clean_entities,
create_dense_embeddings,
create_sparse_embeddings,
extract_entities,
format_query,
generate_flant5_prompt_instruct_chunk_context,
generate_flant5_prompt_instruct_complete_context,
generate_flant5_prompt_instruct_chunk_context_single,
generate_flant5_prompt_summ_chunk_context_single,
generate_flant5_prompt_summ_chunk_context,
generate_text_flan_t5,
generate_gpt_prompt,
generate_gpt_j_two_shot_prompt_1,
generate_gpt_j_two_shot_prompt_2,
get_context_list_prompt,
get_data,
get_flan_t5_model,
get_mpnet_embedding_model,
get_sgpt_embedding_model,
get_spacy_model,
get_splade_sparse_embedding_model,
get_t5_model,
gpt_model,
hybrid_score_norm,
query_pinecone,
query_pinecone_sparse,
retrieve_transcript,
save_key,
sentence_id_combine,
text_lookup,
)
# ---------------------------------------------------------------------------
# Page header and question inputs (left column).
# ---------------------------------------------------------------------------
st.title("Abstractive Question Answering")
st.write(
    "The app uses the quarterly earnings call transcripts for 10 companies (Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, Micron) for the years 2016 to 2020."
)

# Two equally weighted columns: col1 holds the inputs, col2 the model prompt
# and the generated answer.
col1, col2 = st.columns([3, 3], gap="medium")

spacy_model = get_spacy_model()

# Example question pre-filled in the text area.
default_query = "What was discussed regarding Wearables revenue performance?"

with col1:
    st.subheader("Question")
    query_text = st.text_area("Input Query", value=default_query)

# Entities pulled from the question are turned into the values used below as
# the initially selected positions of the Company / Quarter / Year selectors.
company_ent, quarter_ent, year_ent = extract_entities(query_text, spacy_model)
ticker_index, quarter_index, year_index = clean_entities(
    company_ent, quarter_ent, year_ent
)

years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
quarter_choice = ["Q1", "Q2", "Q3", "Q4", "All"]
ticker_choice = [
    "AAPL",
    "CSCO",
    "MSFT",
    "ASML",
    "NVDA",
    "GOOGL",
    "MU",
    "INTC",
    "AMZN",
    "AMD",
]

# Hardcoding the defaults for a question without metadata: while the example
# question is still in the box, every selector starts on its first option
# (index 0, the selectbox default) instead of the extracted indices.
use_first_option = query_text == default_query

with col1:
    year = st.selectbox(
        "Year", years_choice, index=0 if use_first_option else year_index
    )
    quarter = st.selectbox(
        "Quarter", quarter_choice, index=0 if use_first_option else quarter_index
    )
    participant_type = st.selectbox("Speaker", ["Company Speaker", "Analyst"])
    ticker = st.selectbox(
        "Company", ticker_choice, index=0 if use_first_option else ticker_index
    )
# Sidebar: number of results plus the encoder (retrieval) and decoder
# (generation) model choices.
encoder_models_choice = ["MPNET", "SGPT", "Hybrid MPNET - SPLADE"]
decoder_models_choice = ["GPT3 - (text-davinci-003)", "T5", "FLAN-T5", "GPT-J"]

with st.sidebar:
    st.subheader("Select Options:")

    # number_input is given min 1, max 15 and a default of 5; the result is
    # cast to int before it is used as the query size.
    num_results = int(
        st.number_input("Number of Results to query", 1, 15, value=5)
    )

    # Choose encoder model
    encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)

    # Choose decoder model
    decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
# Per-encoder retrieval settings, keyed by the label shown in the encoder
# selectbox. Each entry is:
#   (name of the secret holding the Pinecone API key,
#    Pinecone environment,
#    Pinecone index name,
#    loader for the dense embedding model)
_retriever_settings = {
    "MPNET": (
        "pinecone_mpnet",
        "us-east1-gcp",
        "week2-all-mpnet-base",
        get_mpnet_embedding_model,
    ),
    "SGPT": (
        "pinecone_sgpt",
        "us-east1-gcp",
        "week2-sgpt-125m",
        get_sgpt_embedding_model,
    ),
    "Hybrid MPNET - SPLADE": (
        "pinecone_hybrid_splade_mpnet",
        "us-central1-gcp",
        "splade-mpnet",
        get_mpnet_embedding_model,
    ),
}

_selected_settings = _retriever_settings.get(encoder_model)
if _selected_settings is not None:
    (
        _secret_name,
        _pinecone_environment,
        pinecone_index_name,
        _load_dense_model,
    ) = _selected_settings

    # Connect to pinecone environment. The secret is looked up only for the
    # encoder that was actually selected.
    pinecone.init(
        api_key=st.secrets[_secret_name], environment=_pinecone_environment
    )
    pinecone_index = pinecone.Index(pinecone_index_name)
    retriever_model = _load_dense_model()

# The hybrid encoder additionally needs the SPLADE sparse model and tokenizer.
if encoder_model == "Hybrid MPNET - SPLADE":
    (
        sparse_retriever_model,
        sparse_retriever_tokenizer,
    ) = get_splade_sparse_embedding_model()
# Sidebar: retrieval tuning knobs.
with st.sidebar:
    # Sentence window (0-10, default 1); later passed as `lag` when the
    # retrieved sentences are combined into context passages.
    window = int(st.number_input("Sentence Window Size", 0, 10, value=1))

    # Similarity score threshold (default 0.25, adjusted in 0.05 steps and
    # shown with two decimals); later passed to the Pinecone query helpers.
    threshold = float(
        st.number_input(
            label="Similarity Score Threshold",
            step=0.05,
            format="%.2f",
            value=0.25,
        )
    )
# ---------------------------------------------------------------------------
# Retrieval: embed the question, query Pinecone, build the context passages.
# ---------------------------------------------------------------------------
data = get_data()

# Every encoder choice needs the dense query embedding.
dense_query_embedding = create_dense_embeddings(query_text, retriever_model)

# FIX: this used to compare against "Hybrid SGPT - SPLADE", a label the
# encoder selectbox never offers, so the sparse-dense query below was
# unreachable and the hybrid index was always queried with the dense vector
# only. The label now matches the selectbox option and the branch that loads
# the SPLADE model and tokenizer.
if encoder_model == "Hybrid MPNET - SPLADE":
    sparse_query_embedding = create_sparse_embeddings(
        query_text, sparse_retriever_model, sparse_retriever_tokenizer
    )
    # Third argument is the weighting passed to hybrid_score_norm, kept at
    # the original value of 0.
    # NOTE(review): confirm in utils how this value balances dense vs sparse.
    dense_query_embedding, sparse_query_embedding = hybrid_score_norm(
        dense_query_embedding, sparse_query_embedding, 0
    )
    query_results = query_pinecone_sparse(
        dense_query_embedding,
        sparse_query_embedding,
        num_results,
        pinecone_index,
        year,
        quarter,
        ticker,
        participant_type,
        threshold,
    )
else:
    query_results = query_pinecone(
        dense_query_embedding,
        num_results,
        pinecone_index,
        year,
        quarter,
        ticker,
        participant_type,
        threshold,
    )

# Up to a threshold of 0.90 the hits are combined with neighbouring sentences
# from the transcript data (window size = `window`); above that the query
# results are only formatted.
if threshold <= 0.90:
    context_list = sentence_id_combine(data, query_results, lag=window)
else:
    context_list = format_query(query_results)
# ---------------------------------------------------------------------------
# Generation: build the prompt for the selected decoder, let the user review
# it in a form in the right column, and show the answer once submitted.
# ---------------------------------------------------------------------------
if decoder_model == "GPT3 - (text-davinci-003)":
    prompt = generate_gpt_prompt(query_text, context_list)
    with col2:
        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=270
            )
            openai_key = st.text_input(
                "Enter OpenAI key",
                value="",
                type="password",
            )
            submitted = st.form_submit_button("Submit")
        if submitted:
            # The key entered in the form is passed through save_key and then
            # installed on the openai module before the model is called.
            api_key = save_key(openai_key)
            openai.api_key = api_key
            generated_text = gpt_model(edited_prompt)
            st.subheader("Answer:")
            st.write(generated_text)

elif decoder_model == "T5":
    prompt = generate_flant5_prompt_instruct_complete_context(
        query_text, context_list
    )
    t5_pipeline = get_t5_model()
    with col2:
        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=270
            )
            # The context passages are re-derived from the (possibly edited)
            # prompt, so edits affect both the summaries and the retrieved
            # text shown later.
            context_list = get_context_list_prompt(edited_prompt)
            submitted = st.form_submit_button("Submit")
        if submitted:
            # One summary per context passage.
            output_text = [
                t5_pipeline(context_text)[0]["summary_text"]
                for context_text in context_list
            ]
            st.subheader("Answer:")
            for text in output_text:
                st.markdown(f"- {text}")

elif decoder_model == "FLAN-T5":
    flan_t5_model, flan_t5_tokenizer = get_flan_t5_model()

    # Builders for the per-passage prompts of the two chunkwise modes.
    chunk_prompt_builders = {
        "Chunkwise QA": generate_flant5_prompt_instruct_chunk_context_single,
        "Chunkwise Summarize": generate_flant5_prompt_summ_chunk_context_single,
    }

    with col2:
        prompt_type = st.selectbox(
            "Select prompt type", ["Complete Text QA", "Chunkwise QA", "Chunkwise Summarize"]
        )
        if prompt_type == "Complete Text QA":
            prompt = generate_flant5_prompt_instruct_complete_context(
                query_text, context_list
            )
        elif prompt_type == "Chunkwise QA":
            st.write("The following prompt is not editable.")
            prompt = generate_flant5_prompt_instruct_chunk_context(
                query_text, context_list
            )
        elif prompt_type == "Chunkwise Summarize":
            st.write("The following prompt is not editable.")
            prompt = generate_flant5_prompt_summ_chunk_context(
                query_text, context_list
            )
        else:
            prompt = ""

        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=270
            )
            submitted = st.form_submit_button("Submit")

        if submitted:
            if prompt_type == "Complete Text QA":
                # FIX: pass the prompt as edited in the text area. The
                # unedited `prompt` was sent before, so user edits to the one
                # editable prompt type were silently ignored.
                output_text_string = generate_text_flan_t5(
                    flan_t5_model, flan_t5_tokenizer, edited_prompt
                )
                st.subheader("Answer:")
                st.write(output_text_string)
            elif prompt_type in chunk_prompt_builders:
                # Chunkwise modes rebuild a prompt per context passage; the
                # prompt shown in the form is display-only for these modes.
                build_chunk_prompt = chunk_prompt_builders[prompt_type]
                output_text = [
                    generate_text_flan_t5(
                        flan_t5_model,
                        flan_t5_tokenizer,
                        build_chunk_prompt(query_text, context_text),
                    )
                    for context_text in context_list
                ]
                st.subheader("Answer:")
                for text in output_text:
                    # Outputs containing "(iii)" are not shown.
                    # NOTE(review): presumably "(iii)" is the "cannot answer"
                    # option in the prompt template - confirm in utils.
                    if "(iii)" not in text:
                        st.markdown(f"- {text}")
if decoder_model == "GPT-J":
    # NVDA, INTC and AMZN get the second two-shot prompt; every other ticker
    # (including AAPL and AMD) gets the first one.
    if ticker in ("NVDA", "INTC", "AMZN"):
        prompt = generate_gpt_j_two_shot_prompt_2(query_text, context_list)
    else:
        prompt = generate_gpt_j_two_shot_prompt_1(query_text, context_list)

    # Only the prompt is displayed; no model call follows the submit button.
    with col2, st.form("my_form"):
        edited_prompt = st.text_area(
            label="Model Prompt", value=prompt, height=270
        )
        st.write(
            "The app currently just shows the prompt. The app does not load the model due to memory limitations."
        )
        submitted = st.form_submit_button("Submit")
# Left column, below the inputs: the passages that were fed to the prompt,
# one bullet per passage.
with col1, st.expander("See Retrieved Text"):
    for passage in context_list:
        st.markdown(f"- {passage}")

# Transcript text for the selected year / quarter / company.
file_text = retrieve_transcript(data, year, quarter, ticker)

# Left column: the transcript in a scrollable, borderless text box.
with col1, st.expander("See Transcript"):
    stx.scrollableTextbox(
        file_text, height=700, border=False, fontFamily="Helvetica"
    )