import streamlit as st
import pandas as pd
import random
from config import initialize
from utils.firebase_utils import save_vote
from utils.openai_utils import generate_rewrite
from utils.data_utils import load_data, get_random_review, read_prompt
from utils.gemma_utils import get_gemma_response
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4o", "gpt-4o-mini", "princeton-nlp/gemma-2-9b-it-SimPO"]
def local_css(file_name):
with open(file_name) as f:
css = f.read()
st.markdown(f"", unsafe_allow_html=True)
def main():
st.set_page_config(page_title="AI Арена", layout="wide")
local_css("styles.css")
st.markdown("
⚔️ AI Арена - Битва Моделей
", unsafe_allow_html=True)
st.markdown("🎲 Нажми кнопку ниже, чтобы загрузить оригинальный отзыв и начать битву!
", unsafe_allow_html=True)
db = initialize()
data_files = ['data/final_data_ozon_with_names.xlsx', 'data/final_data_wb_with_names.xlsx']
data = load_data(data_files)
# Initialize session state variables
if 'original_review' not in st.session_state:
st.session_state.original_review = ''
if 'user_name' not in st.session_state:
st.session_state.user_name = ''
if 'product_name' not in st.session_state:
st.session_state.product_name = ''
if 'model_a' not in st.session_state:
st.session_state.model_a = ''
if 'model_b' not in st.session_state:
st.session_state.model_b = ''
if 'rewrite_a' not in st.session_state:
st.session_state.rewrite_a = ''
if 'rewrite_b' not in st.session_state:
st.session_state.rewrite_b = ''
if 'vote_submitted' not in st.session_state:
st.session_state.vote_submitted = False
if 'used_models' not in st.session_state:
st.session_state.used_models = []
if 'battle_started' not in st.session_state:
st.session_state.battle_started = False
if 'rewrites_generated' not in st.session_state:
st.session_state.rewrites_generated = False
instruction_gpt = read_prompt('prompts/gpt_prompt.txt')
instruction_gemma = read_prompt('prompts/gemma_prompt.txt')
if st.button("🚀 Начать битву"):
random_row = get_random_review(data)
st.session_state.original_review = random_row['body']
st.session_state.user_name = random_row['user_name']
st.session_state.product_name = random_row['names']
if pd.isna(st.session_state.user_name) or st.session_state.user_name.strip() == '':
st.session_state.user_name = 'Пользователь'
if pd.isna(st.session_state.product_name) or st.session_state.product_name.strip() == '':
st.session_state.product_name = 'не найдено'
st.session_state.used_models = []
selected_models = random.sample(models, 2)
st.session_state.model_a = selected_models[0]
st.session_state.model_b = selected_models[1]
st.session_state.used_models.extend(selected_models)
st.session_state.rewrite_a = ''
st.session_state.rewrite_b = ''
st.session_state.vote_submitted = False
st.session_state.battle_started = True
st.session_state.rewrites_generated = False
if st.session_state.battle_started:
st.markdown(f"📝 Оригинальный отзыв от {st.session_state.user_name}:
Наименование товара: {st.session_state.product_name}
", unsafe_allow_html=True)
st.markdown(
f"{st.session_state.original_review}
",
unsafe_allow_html=True
)
if st.button("⚔️ Сгенерировать рерайты", key='generate_rewrites'):
with st.spinner("Генерация рерайтов..."):
def generate_rewrite_for_model(model_name, text):
if "gemma" in model_name.lower():
instruction = instruction_gemma.format(user_text=text, product_name=st.session_state.product_name)
return get_gemma_response(model_name, instruction)
else:
instruction = instruction_gpt.format(user_name=st.session_state.user_name, product_name=st.session_state.product_name)
return generate_rewrite(model_name, instruction, text)
st.session_state.rewrite_a = generate_rewrite_for_model(
st.session_state.model_a, st.session_state.original_review)
st.session_state.rewrite_b = generate_rewrite_for_model(
st.session_state.model_b, st.session_state.original_review)
st.session_state.rewrites_generated = True
if st.session_state.rewrites_generated:
st.markdown("""
🗳️ Проголосуй за лучший рерайт!
Выбери рерайт, который соответствует данным критериям:
- Сохранен смысл оригинального отзыва, но написан другими словами.
- Рерайт соответствует товару, о котором говорится в оригинале.
- Рерайт соответствует полу, как в оригинале.
- Написан повседневным языком и не похож на сгенерированный.
- Имеет структуру Отзыв, Преимущества, Недостатки.
""", unsafe_allow_html=True)
cols = st.columns(2)
with cols[0]:
st.markdown("Model A
", unsafe_allow_html=True)
st.markdown(
f"{st.session_state.rewrite_a}
",
unsafe_allow_html=True
)
st.markdown("", unsafe_allow_html=True)
if not st.session_state.vote_submitted:
if st.button("👍 Выбрать Model A", key='choose_model_a'):
st.session_state.vote_submitted = True
st.success("Вы выбрали Model A")
save_vote(db, 'left', models)
with cols[1]:
st.markdown("Model B
", unsafe_allow_html=True)
st.markdown(
f"{st.session_state.rewrite_b}
",
unsafe_allow_html=True
)
st.markdown("", unsafe_allow_html=True)
if not st.session_state.vote_submitted:
if st.button("👍 Выбрать Model B", key='choose_model_b'):
st.session_state.vote_submitted = True
st.success("Вы выбрали Model B")
save_vote(db, 'right', models)
if not st.session_state.vote_submitted:
if st.button("😕 Ни один рерайт не понравился, получить другие рерайты", key='regenerate_rewrites'):
available_models = [model for model in models if model not in st.session_state.used_models]
if len(available_models) < 2:
st.session_state.used_models = []
available_models = models.copy()
selected_models = random.sample(available_models, 2)
st.session_state.model_a = selected_models[0]
st.session_state.model_b = selected_models[1]
st.session_state.used_models.extend(selected_models)
st.session_state.rewrite_a = ''
st.session_state.rewrite_b = ''
st.session_state.vote_submitted = False
st.session_state.rewrites_generated = False
with st.spinner("Генерация новых рерайтов..."):
def generate_rewrite_for_model(model_name, text):
if "gemma" in model_name.lower():
instruction = instruction_gemma.format(user_text=text, product_name=st.session_state.product_name)
return get_gemma_response(model_name, instruction)
else:
instruction = instruction_gpt.format(user_name=st.session_state.user_name, product_name=st.session_state.product_name)
return generate_rewrite(model_name, instruction, text)
st.session_state.rewrite_a = generate_rewrite_for_model(
st.session_state.model_a, st.session_state.original_review)
st.session_state.rewrite_b = generate_rewrite_for_model(
st.session_state.model_b, st.session_state.original_review)
st.session_state.rewrites_generated = True
st.rerun()
if __name__ == '__main__':
main()