import streamlit as st import pandas as pd import random from config import initialize from utils.firebase_utils import save_vote from utils.openai_utils import generate_rewrite from utils.data_utils import load_data, get_random_review, read_prompt from utils.gemma_utils import get_gemma_response models = ["gpt-3.5-turbo", "gpt-4", "gpt-4o", "gpt-4o-mini", "princeton-nlp/gemma-2-9b-it-SimPO"] def local_css(file_name): with open(file_name) as f: css = f.read() st.markdown(f"", unsafe_allow_html=True) def main(): st.set_page_config(page_title="AI Арена", layout="wide") local_css("styles.css") st.markdown("

⚔️ AI Арена - Битва Моделей

", unsafe_allow_html=True) st.markdown("

🎲 Нажми кнопку ниже, чтобы загрузить оригинальный отзыв и начать битву!

", unsafe_allow_html=True) db = initialize() data_files = ['data/final_data_ozon_with_names.xlsx', 'data/final_data_wb_with_names.xlsx'] data = load_data(data_files) # Initialize session state variables if 'original_review' not in st.session_state: st.session_state.original_review = '' if 'user_name' not in st.session_state: st.session_state.user_name = '' if 'product_name' not in st.session_state: st.session_state.product_name = '' if 'model_a' not in st.session_state: st.session_state.model_a = '' if 'model_b' not in st.session_state: st.session_state.model_b = '' if 'rewrite_a' not in st.session_state: st.session_state.rewrite_a = '' if 'rewrite_b' not in st.session_state: st.session_state.rewrite_b = '' if 'vote_submitted' not in st.session_state: st.session_state.vote_submitted = False if 'used_models' not in st.session_state: st.session_state.used_models = [] if 'battle_started' not in st.session_state: st.session_state.battle_started = False if 'rewrites_generated' not in st.session_state: st.session_state.rewrites_generated = False instruction_gpt = read_prompt('prompts/gpt_prompt.txt') instruction_gemma = read_prompt('prompts/gemma_prompt.txt') if st.button("🚀 Начать битву"): random_row = get_random_review(data) st.session_state.original_review = random_row['body'] st.session_state.user_name = random_row['user_name'] st.session_state.product_name = random_row['names'] if pd.isna(st.session_state.user_name) or st.session_state.user_name.strip() == '': st.session_state.user_name = 'Пользователь' if pd.isna(st.session_state.product_name) or st.session_state.product_name.strip() == '': st.session_state.product_name = 'не найдено' st.session_state.used_models = [] selected_models = random.sample(models, 2) st.session_state.model_a = selected_models[0] st.session_state.model_b = selected_models[1] st.session_state.used_models.extend(selected_models) st.session_state.rewrite_a = '' st.session_state.rewrite_b = '' st.session_state.vote_submitted = False st.session_state.battle_started = True st.session_state.rewrites_generated = False if st.session_state.battle_started: st.markdown(f"

📝 Оригинальный отзыв от {st.session_state.user_name}:
Наименование товара: {st.session_state.product_name}

", unsafe_allow_html=True) st.markdown( f"
{st.session_state.original_review}
", unsafe_allow_html=True ) if st.button("⚔️ Сгенерировать рерайты", key='generate_rewrites'): with st.spinner("Генерация рерайтов..."): def generate_rewrite_for_model(model_name, text): if "gemma" in model_name.lower(): instruction = instruction_gemma.format(user_text=text, product_name=st.session_state.product_name) return get_gemma_response(model_name, instruction) else: instruction = instruction_gpt.format(user_name=st.session_state.user_name, product_name=st.session_state.product_name) return generate_rewrite(model_name, instruction, text) st.session_state.rewrite_a = generate_rewrite_for_model( st.session_state.model_a, st.session_state.original_review) st.session_state.rewrite_b = generate_rewrite_for_model( st.session_state.model_b, st.session_state.original_review) st.session_state.rewrites_generated = True if st.session_state.rewrites_generated: st.markdown("""

🗳️ Проголосуй за лучший рерайт!

Выбери рерайт, который соответствует данным критериям:

  1. Сохранен смысл оригинального отзыва, но написан другими словами.
  2. Рерайт соответствует товару, о котором говорится в оригинале.
  3. Рерайт соответствует полу, как в оригинале.
  4. Написан повседневным языком и не похож на сгенерированный.
  5. Имеет структуру Отзыв, Преимущества, Недостатки.
""", unsafe_allow_html=True) cols = st.columns(2) with cols[0]: st.markdown("

Model A

", unsafe_allow_html=True) st.markdown( f"
{st.session_state.rewrite_a}
", unsafe_allow_html=True ) st.markdown("", unsafe_allow_html=True) if not st.session_state.vote_submitted: if st.button("👍 Выбрать Model A", key='choose_model_a'): st.session_state.vote_submitted = True st.success("Вы выбрали Model A") save_vote(db, 'left', models) with cols[1]: st.markdown("

Model B

", unsafe_allow_html=True) st.markdown( f"
{st.session_state.rewrite_b}
", unsafe_allow_html=True ) st.markdown("", unsafe_allow_html=True) if not st.session_state.vote_submitted: if st.button("👍 Выбрать Model B", key='choose_model_b'): st.session_state.vote_submitted = True st.success("Вы выбрали Model B") save_vote(db, 'right', models) if not st.session_state.vote_submitted: if st.button("😕 Ни один рерайт не понравился, получить другие рерайты", key='regenerate_rewrites'): available_models = [model for model in models if model not in st.session_state.used_models] if len(available_models) < 2: st.session_state.used_models = [] available_models = models.copy() selected_models = random.sample(available_models, 2) st.session_state.model_a = selected_models[0] st.session_state.model_b = selected_models[1] st.session_state.used_models.extend(selected_models) st.session_state.rewrite_a = '' st.session_state.rewrite_b = '' st.session_state.vote_submitted = False st.session_state.rewrites_generated = False with st.spinner("Генерация новых рерайтов..."): def generate_rewrite_for_model(model_name, text): if "gemma" in model_name.lower(): instruction = instruction_gemma.format(user_text=text, product_name=st.session_state.product_name) return get_gemma_response(model_name, instruction) else: instruction = instruction_gpt.format(user_name=st.session_state.user_name, product_name=st.session_state.product_name) return generate_rewrite(model_name, instruction, text) st.session_state.rewrite_a = generate_rewrite_for_model( st.session_state.model_a, st.session_state.original_review) st.session_state.rewrite_b = generate_rewrite_for_model( st.session_state.model_b, st.session_state.original_review) st.session_state.rewrites_generated = True st.rerun() if __name__ == '__main__': main()