# Aidan-Bench / app.py
# Last commit by Whiteshadow12: "x" (187c8cf)
import streamlit as st
from main import benchmark_model_multithreaded, benchmark_model_sequential
from prompts import questions as predefined_questions
import requests
import pandas as pd
# One label serves as both the browser-tab title and the on-page heading.
# set_page_config must be the first Streamlit call in the script.
_PAGE_TITLE = "Aidan Bench - Generator"
st.set_page_config(page_title=_PAGE_TITLE)
st.title(_PAGE_TITLE)
# Credentials: the user types both keys, then presses "Confirm API Keys" to
# copy them into st.session_state, which is where the rest of the app reads them.
st.warning("Please keep your API keys secure and confidential. This app does not store or log your API keys.")
for _state_key in ("open_router_key", "openai_api_key"):
    # Seed an empty string on first load so the inputs below have a value.
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""
open_router_key = st.text_input(
    "Enter your Open Router API Key:",
    type="password",
    value=st.session_state.open_router_key,
)
openai_api_key = st.text_input(
    "Enter your OpenAI API Key:",
    type="password",
    value=st.session_state.openai_api_key,
)
if st.button("Confirm API Keys"):
    both_keys_given = bool(open_router_key and openai_api_key)
    if not both_keys_given:
        st.warning("Please enter both API keys.")
    else:
        st.session_state.open_router_key = open_router_key
        st.session_state.openai_api_key = openai_api_key
        st.success("API keys confirmed!")
# --- Helpers for the main app body -------------------------------------------

# OpenRouter endpoint listing the models it can route to.
_OPENROUTER_MODELS_URL = "https://openrouter.ai/api/v1/models"
# Seconds to wait for the model list. Without a timeout, a stalled connection
# would hang this Streamlit script run indefinitely.
_MODELS_REQUEST_TIMEOUT = 30
# OpenRouter pricing values are treated as per-token prices and scaled to a
# per-million-token figure for display.
_TOKENS_PER_MILLION = 1_000_000


def _fetch_openrouter_models():
    """Download the OpenRouter model list, sorted alphabetically by model id.

    Returns:
        The list found under the ``"data"`` key of the JSON response.

    Raises:
        requests.exceptions.RequestException: on network errors, timeouts or
            a non-2xx HTTP status.
        ValueError: if the response body is not valid JSON.
        KeyError: if the JSON has no ``"data"`` key or a model lacks ``"id"``.
    """
    response = requests.get(_OPENROUTER_MODELS_URL, timeout=_MODELS_REQUEST_TIMEOUT)
    response.raise_for_status()  # Turn 4xx/5xx responses into exceptions.
    models = response.json()["data"]
    models.sort(key=lambda model: model["id"])
    return models


def _price_per_million(pricing, field):
    """Return ``pricing[field]`` scaled to a per-million-token price.

    A missing field counts as 0. Returns None when the value cannot be
    converted to a float (e.g. ``None`` or a non-numeric string), so the caller
    can show "N/A" instead of crashing the page.
    """
    try:
        return float(pricing.get(field, 0)) * _TOKENS_PER_MILLION
    except (TypeError, ValueError):
        return None


def _show_pricing(model):
    """Render prompt/completion pricing for one OpenRouter model entry.

    Args:
        model: A model dict from the OpenRouter list, or None if the lookup
            failed. Shows "**Pricing:** N/A" when there is no model or its
            prices are not numeric.
    """
    if not model:
        st.write("**Pricing:** N/A")
        return
    # ``or {}`` also covers an explicit null "pricing" value in the payload.
    pricing_info = model.get("pricing") or {}
    prompt_price = _price_per_million(pricing_info, "prompt")
    completion_price = _price_per_million(pricing_info, "completion")
    if prompt_price is None or completion_price is None:
        st.write("**Pricing:** N/A")
        return
    st.write(f"**Prompt Pricing:** ${prompt_price:.2f}/Million tokens (if applicable)")
    st.write(f"**Completion Pricing:** ${completion_price:.2f}/Million tokens")


def _build_results_table(results, model_name, judge_model_name):
    """Flatten benchmark results into one table row per generated answer.

    Reads the "question", "answers", "coherence_score" and "novelty_score" keys
    from each result. The result schema itself is defined by the
    ``benchmark_model_*`` functions in ``main`` (not visible here). The scores
    are copied onto every answer row of their question.
    """
    rows = []
    for result in results:
        for answer in result["answers"]:
            rows.append({
                "Question": result["question"],
                "Answer": answer,
                "Contestant Model": model_name,
                "Judge Model": judge_model_name,
                "Coherence Score": result["coherence_score"],
                "Novelty Score": result["novelty_score"],
            })
    return rows


# --- Main app body: only runs once both API keys have been confirmed ---------
if st.session_state.open_router_key and st.session_state.openai_api_key:
    # Fetch models from the OpenRouter API. A malformed payload (non-JSON body,
    # missing "data"/"id") is reported like a network failure instead of
    # crashing the page with a traceback.
    try:
        all_models = _fetch_openrouter_models()
    except (requests.exceptions.RequestException, ValueError, KeyError) as e:
        st.error(f"Error fetching models from OpenRouter API: {e}")
        all_models = []  # Leads to the "No models available" stop below.

    # Lookup tables. all_models is already sorted, so dict order is sorted too.
    models_by_id = {model["id"]: model for model in all_models}
    model_names = list(models_by_id)
    # Judge candidates: every model id containing "gpt".
    judge_models = sorted(model_id for model_id in models_by_id if "gpt" in model_id)

    # Model Selection. st.stop() ends this script run, so nothing below
    # executes without a selectable model.
    if not model_names:
        st.error("No models available. Please check your API connection.")
        st.stop()
    model_name = st.selectbox("Select a Contestant Model", model_names)
    _show_pricing(models_by_id.get(model_name))

    # Judge Model Selection
    if not judge_models:
        st.error("No judge models available. Please check your API connection.")
        st.stop()
    judge_model_name = st.selectbox("Select a Judge Model", judge_models)
    _show_pricing(models_by_id.get(judge_model_name))

    # User-entered questions persist across reruns in session state.
    if "user_questions" not in st.session_state:
        st.session_state.user_questions = []

    # Threshold Sliders
    st.sidebar.subheader("Threshold Sliders")
    coherence_threshold = st.sidebar.slider("Coherence Threshold (0-5):", 0, 5, 3)
    novelty_threshold = st.sidebar.slider("Novelty Threshold (0-1):", 0.0, 1.0, 0.1)
    st.sidebar.subheader("Temp Sliders")
    temp_threshold = st.sidebar.slider("Temperature (0-2):", 0.0, 2.0, 1.0)
    top_p = st.sidebar.slider("Top P (0-1):", 0.0, 1.0, 1.0)

    # Workflow Selection
    workflow = st.radio("Select Workflow:", ["Use Predefined Questions", "Use User-Defined Questions"])

    if workflow == "Use Predefined Questions":
        st.header("Question Selection")
        # All predefined questions are selected by default.
        selected_questions = st.multiselect(
            "Select questions to benchmark:",
            options=predefined_questions,
            default=predefined_questions,
        )
    else:  # "Use User-Defined Questions"
        st.header("Question Input")
        new_question = st.text_input("Enter a new question:")
        if st.button("Add Question"):
            new_question = new_question.strip()  # Remove leading/trailing whitespace
            if new_question and new_question not in st.session_state.user_questions:
                st.session_state.user_questions.append(new_question)
                st.success(f"Question '{new_question}' added successfully.")
            else:
                # Also reached for an empty/blank input, matching the message text.
                st.warning("Question already exists or is empty!")
        # Built after the add step so a just-added question appears immediately.
        selected_questions = st.multiselect(
            "Select your custom questions:",
            options=st.session_state.user_questions,
            default=st.session_state.user_questions,
        )

    st.write("Selected Questions:", selected_questions)

    # Choose execution mode. The thread-pool size is only relevant when multithreaded.
    execution_mode = st.radio("Execution Mode:", ["Sequential", "Multithreaded"])
    if execution_mode == "Multithreaded":
        max_threads = st.slider("Maximum Number of Threads:", 1, 10, 4)
    else:
        max_threads = None

    # Benchmark Execution
    if st.button("Start Benchmark"):
        if not selected_questions:
            st.warning("Please select at least one question.")
        else:
            # NOTE: Streamlit reruns the whole script on every button click, and
            # on that rerun "Start Benchmark" is False. This button therefore
            # never reads True in the same run as the benchmark. Clicking it only
            # requests a rerun, which skips this branch, so results from the
            # interrupted run are not shown. The former "partial results" branch
            # that tested it was unreachable and has been removed.
            st.button("Stop Benchmark")

            if execution_mode == "Sequential":
                question_results = benchmark_model_sequential(
                    model_name,
                    selected_questions,
                    st.session_state.open_router_key,
                    st.session_state.openai_api_key,
                    judge_model_name,
                    coherence_threshold,
                    novelty_threshold,
                    temp_threshold,
                    top_p,
                )
            else:  # Multithreaded
                question_results = benchmark_model_multithreaded(
                    model_name,
                    selected_questions,
                    st.session_state.open_router_key,
                    st.session_state.openai_api_key,
                    max_threads,
                    judge_model_name,
                    coherence_threshold,
                    novelty_threshold,
                    temp_threshold,
                    top_p,
                )

            # Display results in a table
            st.write("Results:")
            results_table = _build_results_table(question_results, model_name, judge_model_name)
            st.table(results_table)

            # NOTE(review): results are not stored in st.session_state. Any
            # later interaction clears them from the page, including the
            # download click, which reruns the app by default.
            csv = pd.DataFrame(results_table).to_csv(index=False).encode("utf-8")
            st.download_button(
                label="Export Results as CSV",
                data=csv,
                file_name="benchmark_results.csv",
                mime="text/csv",
            )
            st.success("Benchmark completed!")
else:
    st.warning("Please confirm your API keys first.")