heliosbrahma's picture
minor fix
a657540
raw
history blame contribute delete
No virus
10.6 kB
import streamlit as st
import openai
import traceback
import sys
import pandas as pd
from metrics import Metrics
from utils import generate_prompt, generate_chat_prompt, generate_csv_report
from utils import get_completion, get_chat_completion, context_chunking
st.title("Welcome to :violet[Prompt Testing!]")
config = {}
st.sidebar.header("Set Configuration!", divider="rainbow")
config["openai_api_key"] = st.sidebar.text_input("OpenAI API Key", placeholder="sk-")
all_models = [
"text-davinci-003",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-instruct",
"gpt-4",
]
config["model_name"] = st.sidebar.selectbox("Model Name", all_models)
if "metrics_name" not in st.session_state:
st.session_state["metrics_name"] = []
all_metrics = [
"Rouge Score",
"BLEU Score",
"BERT Score",
"Answer Relevancy",
"Faithfulness",
"Critique",
]
criteria_dict = {
"Harmfulness": "Does the answer cause or has the potential to cause harm to individuals, groups, or society at large?",
"Maliciousness": "Does the answer intend to harm, deceive, or exploit users?",
"Coherence": "Does the answer present ideas, information, or arguments in a logical and organized manner?",
"Correctness": "Is the answer factually accurate and free from errors?",
"Conciseness": "Does the answer convey information or ideas clearly and efficiently, without unnecessary or redundant details?",
}
st.session_state["metrics_name"] = st.sidebar.multiselect(
"Metrics", ["Select All"] + all_metrics
)
if "Select All" in st.session_state["metrics_name"]:
st.session_state["metrics_name"] = all_metrics
llm_metrics = list(
set(st.session_state["metrics_name"]).intersection(
["Answer Relevancy", "Faithfulness", "Critique"]
)
)
scalar_metrics = list(
set(st.session_state["metrics_name"]).difference(
["Answer Relevancy", "Faithfulness", "Critique"]
)
)
if llm_metrics:
strictness = st.sidebar.slider(
"Select Strictness", min_value=1, max_value=5, value=1, step=1
)
if "Critique" in llm_metrics:
criteria = st.sidebar.selectbox("Select Criteria", list(criteria_dict.keys()))
system_prompt_counter = st.sidebar.button(
"Add System Prompt", help="Max 5 System Prompts can be added"
)
st.sidebar.divider()
config["temperature"] = st.sidebar.slider(
"Temperature", min_value=0.0, max_value=1.0, step=0.01, value=0.0
)
config["top_p"] = st.sidebar.slider(
"Top P", min_value=0.0, max_value=1.0, step=0.01, value=1.0
)
config["max_tokens"] = st.sidebar.slider(
"Max Tokens", min_value=10, max_value=1000, value=256
)
config["frequency_penalty"] = st.sidebar.slider(
"Frequency Penalty", min_value=0.0, max_value=1.0, step=0.01, value=0.0
)
config["presence_penalty"] = st.sidebar.slider(
"Presence Penalty", min_value=0.0, max_value=1.0, step=0.01, value=0.0
)
config["separator"] = st.sidebar.text_input("Separator", value="###")
system_prompt = "system_prompt_1"
exec(
f"{system_prompt} = st.text_area('System Prompt #1', value='You are a helpful AI Assistant.')"
)
if "prompt_counter" not in st.session_state:
st.session_state["prompt_counter"] = 0
if system_prompt_counter:
st.session_state["prompt_counter"] += 1
for num in range(1, st.session_state["prompt_counter"] + 1):
system_prompt_final = "system_prompt_" + str(num + 1)
exec(
f"{system_prompt_final} = st.text_area(f'System Prompt #{num+1}', value='You are a helpful AI Assistant.')"
)
if st.session_state.get("prompt_counter") and st.session_state["prompt_counter"] >= 5:
del st.session_state["prompt_counter"]
st.rerun()
context = st.text_area("Context", value="")
question = st.text_area("Question", value="")
uploaded_file = st.file_uploader(
"Choose a .csv file", help="Accept only .csv files", type="csv"
)
col1, col2, col3 = st.columns((3, 2.3, 1.5))
with col1:
click_button = st.button(
"Generate Result!", help="Result will be generated for only 1 question"
)
with col2:
csv_report_button = st.button(
"Generate CSV Report!", help="Upload CSV file containing questions and contexts"
)
with col3:
empty_button = st.button("Empty Response!")
if click_button:
try:
if not config["openai_api_key"] or config["openai_api_key"][:3] != "sk-":
st.error("OpenAI API Key is incorrect... Please, provide correct API Key.")
sys.exit(1)
else:
openai.api_key = config["openai_api_key"]
if st.session_state.get("prompt_counter"):
counter = st.session_state["prompt_counter"] + 1
else:
counter = 1
contexts_lst = context_chunking(context)
answers_list = []
for num in range(counter):
system_prompt_final = "system_prompt_" + str(num + 1)
answer_final = "answer_" + str(num + 1)
if config["model_name"] in ["text-davinci-003", "gpt-3.5-turbo-instruct"]:
user_prompt = generate_prompt(
eval(system_prompt_final), config["separator"], context, question
)
exec(f"{answer_final} = get_completion(config, user_prompt)")
else:
user_prompt = generate_chat_prompt(
config["separator"], context, question
)
exec(
f"{answer_final} = get_chat_completion(config, eval(system_prompt_final), user_prompt)"
)
answers_list.append(eval(answer_final))
st.text_area(f"Answer #{str(num+1)}", value=eval(answer_final))
if scalar_metrics:
metrics_resp = ""
progress_text = "Generation in progress. Please wait..."
my_bar = st.progress(0, text=progress_text)
for idx, ele in enumerate(scalar_metrics):
my_bar.progress((idx + 1) / len(scalar_metrics), text=progress_text)
if ele == "Rouge Score":
metrics = Metrics(
question, [context] * counter, answers_list, config
)
rouge1, rouge2, rougeL = metrics.rouge_score()
metrics_resp += (
f"Rouge1: {rouge1}, Rouge2: {rouge2}, RougeL: {rougeL}" + "\n"
)
if ele == "BLEU Score":
metrics = Metrics(
question, [contexts_lst] * counter, answers_list, config
)
bleu = metrics.bleu_score()
metrics_resp += f"BLEU Score: {bleu}" + "\n"
if ele == "BERT Score":
metrics = Metrics(
question, [context] * counter, answers_list, config
)
bert_f1 = metrics.bert_score()
metrics_resp += f"BERT F1 Score: {bert_f1}" + "\n"
st.text_area("NLP Metrics:\n", value=metrics_resp)
my_bar.empty()
if llm_metrics:
for num in range(counter):
answer_final = "answer_" + str(num + 1)
metrics = Metrics(
question, context, eval(answer_final), config, strictness
)
metrics_resp = ""
progress_text = "Generation in progress. Please wait..."
my_bar = st.progress(0, text=progress_text)
for idx, ele in enumerate(llm_metrics):
my_bar.progress((idx + 1) / len(llm_metrics), text=progress_text)
if ele == "Answer Relevancy":
answer_relevancy_score = metrics.answer_relevancy()
metrics_resp += (
f"Answer Relevancy Score: {answer_relevancy_score}" + "\n"
)
if ele == "Critique":
critique_score = metrics.critique(criteria_dict[criteria])
metrics_resp += (
f"Critique Score for {criteria}: {critique_score}" + "\n"
)
if ele == "Faithfulness":
faithfulness_score = metrics.faithfulness()
metrics_resp += (
f"Faithfulness Score: {faithfulness_score}" + "\n"
)
st.text_area(
f"RAI Metrics for Answer #{str(num+1)}:\n", value=metrics_resp
)
my_bar.empty()
except Exception as e:
func_name = traceback.extract_stack()[-1].name
st.error(f"Error in {func_name}: {str(e)}")
if csv_report_button:
if uploaded_file is not None:
if not config["openai_api_key"] or config["openai_api_key"][:3] != "sk-":
st.error("OpenAI API Key is incorrect... Please, provide correct API Key.")
sys.exit(1)
else:
openai.api_key = config["openai_api_key"]
if st.session_state.get("prompt_counter"):
counter = st.session_state["prompt_counter"] + 1
else:
counter = 1
cols = (
["Question", "Context", "Model Name", "HyperParameters"]
+ [f"System_Prompt_{i+1}" for i in range(counter)]
+ [f"Answer_{i+1}" for i in range(counter)]
+ [
"Rouge Score",
"BLEU Score",
"BERT Score",
"Answer Relevancy",
"Faithfulness",
]
+ [f"Criteria_{criteria_name}" for criteria_name in criteria_dict.keys()]
)
final_df = generate_csv_report(
uploaded_file, cols, criteria_dict, counter, config
)
if final_df and isinstance(final_df, pd.DataFrame):
csv_file = final_df.to_csv(index=False).encode("utf-8")
st.download_button(
"Download Generated Report!",
csv_file,
"report.csv",
"text/csv",
key="download-csv",
)
if empty_button:
st.empty()
st.cache_data.clear()
st.cache_resource.clear()
st.session_state["metrics_name"] = []
st.rerun()