yutohub commited on
Commit
c624d50
1 Parent(s): da86125

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +307 -0
app.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ import time
5
+
6
+ import pandas as pd
7
+ import requests
8
+ import streamlit as st
9
+
10
+
11
+ # 環境変数
12
+ with open("models_info.json", "r") as json_file:
13
+ MODELS_INFO = json.load(json_file)
14
+ with open("test.csv", "r") as file:
15
+ QUESTION_DF = pd.read_csv(file)
16
+ MODELS = list(MODELS_INFO.keys())
17
+ NUM_QUESTION = 100
18
+
19
+
20
+ # ランキングを取得
21
+ @st.cache_data
22
+ def get_leaderboard():
23
+ try:
24
+ response = requests.get(os.environ['DARABASE_URL'])
25
+ response_data = response.json()
26
+ return response_data
27
+ except Exception as e:
28
+ print(f"An unexpected error occurred: {e}")
29
+ return "Error"
30
+
31
+ # リーダーボードを作成
32
+ @st.cache_data
33
+ def create_leaderboard_df():
34
+ # リーダーボードを取得
35
+ ranking = get_leaderboard()
36
+ # エラー処理
37
+ if ranking == "Error":
38
+ st.error("リーダーボードを取得できませんでした。")
39
+ print("リーダーボードを取得できませんでした。") # ログを表示
40
+ return pd.DataFrame()
41
+ else:
42
+ # データの初期化
43
+ ranks, model_names, ratings, organizations, licenses = [], [], [], [], []
44
+ # リーダーボードの作成
45
+ for i in range(len(ranking)):
46
+ ranks.append(i + 1)
47
+ model_names.append(MODELS_INFO[ranking[i]["model"]][0])
48
+ ratings.append(ranking[i]["rating"])
49
+ organizations.append(MODELS_INFO[ranking[i]["model"]][2])
50
+ licenses.append(MODELS_INFO[ranking[i]["model"]][1])
51
+ # データフレームを返す
52
+ return pd.DataFrame({
53
+ "ランク" : ranks,
54
+ "🤖 モデル" : model_names,
55
+ "⭐️ Eloレーティング" : ratings,
56
+ "🏢 組織" : organizations,
57
+ "📃 ライセンス" : licenses
58
+ })
59
+
60
+ # サーバーから回答を取得
61
+ def get_answer(model_name, question_id):
62
+ try:
63
+ params = {'modelName': model_name, 'questionId': question_id}
64
+ response = requests.get(os.environ['ANSWER_URL'], params=params)
65
+ response_data = response.json()
66
+ return response_data["answer"]
67
+ except Exception as e:
68
+ print(f"An unexpected error occurred: {e}")
69
+ return "Error"
70
+
71
+ # サーバーに回答を送信
72
+ def send_choice(question_id, model_a, model_b, winner, language):
73
+ # エラー処理 (データが入力されていない場合)
74
+ if not question_id or not model_a or not model_b or not winner or not language:
75
+ st.error("データが入力されていないため、回答を送信できませんでした。")
76
+ print("質問と回答を取得してください。") # ログを表示
77
+ return "Error"
78
+ try:
79
+ data = {
80
+ "question_id": question_id,
81
+ "model_a": model_a,
82
+ "model_b": model_b,
83
+ "winner": winner,
84
+ "language": language,
85
+ "tstamp": time.time(),
86
+ }
87
+ headers = {
88
+ 'Content-Type': 'application/json'
89
+ }
90
+ response = requests.post(os.environ['DARABASE_URL'], headers=headers, data=json.dumps(data))
91
+ response_data = response.text
92
+ return response_data
93
+ except Exception as e:
94
+ print(f"An unexpected error occurred: {e}")
95
+ return "Error"
96
+
97
+
98
+ ### Callback Functions ###
99
+ # ステートの初期化を行う
100
+ def handle_init_state():
101
+ if "chat_history_a" not in st.session_state:
102
+ st.session_state["chat_history_a"] = []
103
+ if "chat_history_b" not in st.session_state:
104
+ st.session_state["chat_history_b"] = []
105
+ if "question_id" not in st.session_state:
106
+ st.session_state["question_id"] = None
107
+ if "model_a" not in st.session_state:
108
+ st.session_state["model_a"] = None
109
+ if "model_b" not in st.session_state:
110
+ st.session_state["model_b"] = None
111
+ if "question" not in st.session_state:
112
+ st.session_state["question"] = None
113
+ # ボタンの状態を初期化
114
+ if "question_loaded" not in st.session_state:
115
+ st.session_state["question_loaded"] = False
116
+ # 送信を状態を初期化
117
+ if "answer_sent" not in st.session_state:
118
+ st.session_state["answer_sent"] = False
119
+
120
+ # 質問と回答を取得する
121
+ def handle_init_question():
122
+ # エラー処理
123
+ if st.session_state.question_loaded:
124
+ st.session_state.question_loaded = False
125
+ st.session_state.chat_history_a = []
126
+ st.session_state.chat_history_b = []
127
+ st.error("ボタンを連打しないでください。")
128
+ print("既に質問と回答を取得しています。") # ログを表示
129
+ else:
130
+ # ボタンの状態を更新
131
+ st.session_state.question_loaded = True
132
+ st.success("質問と回答を取得しています。しばらくお待ちください。")
133
+ # 質問を取得
134
+ st.session_state.question_id = random.randint(1, NUM_QUESTION)
135
+ st.session_state.question = QUESTION_DF["input"][st.session_state.question_id - 1]
136
+ st.session_state.chat_history_a.append({"role": "user", "content": st.session_state.question})
137
+ st.session_state.chat_history_b.append({"role": "user", "content": st.session_state.question})
138
+ # 回答を取得
139
+ random.shuffle(MODELS)
140
+ st.session_state.model_a = MODELS[0]
141
+ st.session_state.model_b = MODELS[1]
142
+ answer_a = get_answer(st.session_state.model_a, st.session_state.question_id)
143
+ answer_b = get_answer(st.session_state.model_b, st.session_state.question_id)
144
+ # チャット履歴を更新
145
+ st.session_state.chat_history_a.append({"role": "assistant", "content": answer_a})
146
+ st.session_state.chat_history_b.append({"role": "assistant", "content": answer_b})
147
+ st.success("質問と回答を取得しました。回答を選択してください。")
148
+ print("質問と回答を取得しました。") # ログを表示
149
+
150
+ # ユーザーの回答を送信する
151
+ def handle_send_choice(winner):
152
+ # エラー処理
153
+ if st.session_state.answer_sent:
154
+ st.error("既に回答を送信しています。")
155
+ print("既に回答を送信しています。") # ログを表示
156
+ else:
157
+ # ボタンの状態を更新
158
+ st.session_state.answer_sent = True
159
+ # ユーザーの回答を送信
160
+ response = send_choice(
161
+ question_id=st.session_state.question_id,
162
+ model_a=st.session_state.model_a,
163
+ model_b=st.session_state.model_b,
164
+ winner=winner,
165
+ language="Japanese"
166
+ )
167
+ # エラーが発生した場合
168
+ if response == "Error":
169
+ st.error("予期せぬエラーが発生しました。")
170
+ else:
171
+ st.success("選択肢は正常に送信されました。")
172
+ # 初期化
173
+ st.session_state.question_loaded = False
174
+
175
+
176
+ # 表示部分
177
+ def main():
178
+ # page config
179
+ st.set_page_config(
180
+ page_title="日本語チャットボットアリーナ",
181
+ page_icon="🏆",
182
+ layout="wide",
183
+ )
184
+
185
+ # ステートの初期化
186
+ handle_init_state()
187
+ # 説明を表示
188
+ st.markdown("# 🏆 日本語チャットボットアリーナ")
189
+ st.markdown("## 📖 説明")
190
+ st.markdown("| [Twitter](https://twitter.com/yutohub) | [GitHub](https://github.com/yutohub) | [ブログ](https://zenn.dev/yutohub) |")
191
+ st.markdown("日本語チャットボットアリーナは、日本語に対応しているLLMの評価のためのクラウドソーシングプラットフォームです。[LMSYS Chatbot Arena](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) を参考に、日本語に対応しているLLMのリーダーボードを作成することを目的としています。また、一部の質問と回答は、 [ELYZA-tasks-100](https://huggingface.co/elyza/ELYZA-tasks-100) を利用しています。")
192
+ st.markdown(""" > **注意事項:**
193
+ >
194
+ > 日本語チャットボットアリーナが提供する情報によって生じたいかなる損害についても、サービス提供者は一切の責任を負いません。
195
+ > 日本語チャットボットアリーナは開発中であり、予告なく停止または終了する可能性があります。
196
+ > また、ユーザーの回答を収集し、Creative Commons Attribution (CC-BY) または同様のライセンスの下で配布する権利を留保しています。
197
+ """)
198
+
199
+ # チャット履歴の表示部分
200
+ st.markdown("## ⚔️ チャットボットアリーナ ⚔️")
201
+ st.markdown(" 2つの匿名モデル (ChatGPT、Llama など) の回答を見て、より良いモデルに投票してください。")
202
+ with st.expander(f"🔍 展開するとアリーナに参加している {len(MODELS)} 個のモデルの一覧が表示されます。"):
203
+ st.write(MODELS)
204
+ model_a, model_b = st.columns([1, 1])
205
+ with model_a:
206
+ st.markdown("### モデル A")
207
+ if not st.session_state.chat_history_a:
208
+ st.markdown("質問を取得してください。")
209
+ else:
210
+ for message in st.session_state.chat_history_a:
211
+ with st.chat_message(message["role"]):
212
+ st.write(message["content"])
213
+ # 送信後に正解のモデルを表示する
214
+ if st.session_state.answer_sent:
215
+ with st.chat_message("assistant"):
216
+ st.markdown(f"`{st.session_state.model_a}` が回答しました、")
217
+ with model_b:
218
+ st.markdown("### モデル B")
219
+ if not st.session_state.chat_history_b:
220
+ st.markdown("質問を取得してください。")
221
+ else:
222
+ for message in st.session_state.chat_history_b:
223
+ with st.chat_message(message["role"]):
224
+ st.write(message["content"])
225
+ # 送信後に正解のモデルを表示する
226
+ if st.session_state.answer_sent:
227
+ with st.chat_message("assistant"):
228
+ st.markdown(f"`{st.session_state.model_b}` が回答しました。")
229
+ # 質問を取得する
230
+ load_question = st.button(
231
+ label="質問を取得",
232
+ on_click=handle_init_question,
233
+ # 回答済みの場合 or 質問を取得済の場合はボタンを無効化
234
+ disabled=st.session_state.answer_sent or st.session_state.question_loaded,
235
+ type="primary",
236
+ use_container_width=True
237
+ )
238
+ # 回答を送信する
239
+ choice_1, choice_2, choice_3, choice_4 = st.columns([1, 1, 1, 1])
240
+ with choice_1:
241
+ choice_1 = st.button(
242
+ label="👈 Aの方が良い",
243
+ on_click=handle_send_choice,
244
+ args=("model_a",),
245
+ disabled=not st.session_state.question_loaded,
246
+ use_container_width=True
247
+ )
248
+ with choice_2:
249
+ choice_2 = st.button(
250
+ label="👉 Bの方が良い",
251
+ on_click=handle_send_choice,
252
+ args=("model_b",),
253
+ disabled=not st.session_state.question_loaded,
254
+ use_container_width=True
255
+ )
256
+ with choice_3:
257
+ choice_3 = st.button(
258
+ label="🤝 どちらも良い",
259
+ on_click=handle_send_choice,
260
+ args=("tie",),
261
+ disabled=not st.session_state.question_loaded,
262
+ use_container_width=True
263
+ )
264
+ with choice_4:
265
+ choice_4 = st.button(
266
+ label="👎 どちらも悪い",
267
+ on_click=handle_send_choice,
268
+ args=("tie (bothbad)",),
269
+ disabled=not st.session_state.question_loaded,
270
+ use_container_width=True
271
+ )
272
+
273
+ # リーダーボードを表示する
274
+ st.markdown("## 🏆 リーダーボード")
275
+ st.markdown(f"合計で {len(MODELS)} 個のモデルがアリーナに参加しています。30 分毎にリーダーボードが更新されます。")
276
+ # 回答を送信した場合のみ表示する
277
+ if st.session_state.answer_sent:
278
+ # リーダーボードを取得
279
+ leaderboard = create_leaderboard_df()
280
+ st.dataframe(
281
+ data=leaderboard,
282
+ height=(len(MODELS) + 1) * 35 + 3,
283
+ use_container_width=True,
284
+ hide_index=True,
285
+ )
286
+ else:
287
+ st.markdown("""
288
+ > まずは、「⚔️ チャットボットアリーナ ⚔️」に回答を送信してください。
289
+ > 回答を送信すると、リーダーボードが表示されます。
290
+ """)
291
+
292
+ # 引用を表示する
293
+ st.markdown("## 📚 引用")
294
+ st.markdown("""
295
+ ```
296
+ @misc{elyzatasks100,
297
+ title={ELYZA-tasks-100: 日本語instructionモデル評価データセット},
298
+ url={https://huggingface.co/elyza/ELYZA-tasks-100},
299
+ author={Akira Sasaki and Masato Hirakawa and Shintaro Horie and Tomoaki Nakamura},
300
+ year={2023},
301
+ }
302
+ ```
303
+ """)
304
+
305
+
306
+ if __name__ == "__main__":
307
+ main()