a100kh committed
Commit 529989d
1 Parent(s): 264a139
api_endpoints.json ADDED
@@ -0,0 +1,185 @@
+{
+    "claude-3-5-sonnet-20240620": {
+        "model_name": "claude-3-5-sonnet-20240620",
+        "api_type": "anthropic",
+        "anony_only": false,
+        "recommended_config": {
+            "temperature": 0.7,
+            "top_p": 1.0
+        },
+        "text-arena": true,
+        "vision-arena": false
+    },
+    "command-r-plus": {
+        "model_name": "command-r-plus",
+        "api_type": "cohere",
+        "anony_only": false,
+        "recommended_config": {
+            "temperature": 0.7,
+            "top_p": 1.0
+        },
+        "text-arena": true,
+        "vision-arena": false
+    },
+    "deepseek-chat": {
+        "model_name": "deepseek-chat",
+        "api_type": "openai-custom-deepinfra",
+        "api_base": "https://api.deepseek.com/v1",
+        "env_api_key": "DEEPSEEK_API_KEY",
+        "anony_only": false,
+        "recommended_config": {
+            "temperature": 0.7,
+            "top_p": 1.0
+        },
+        "text-arena": true,
+        "vision-arena": false
+    },
+    "mistral-large-latest": {
+        "model_name": "mistral-large-latest",
+        "api_type": "mistral",
+        "anony_only": false,
+        "recommended_config": {
+            "temperature": 0.7,
+            "top_p": 1.0
+        },
+        "text-arena": true,
+        "vision-arena": false
+    },
+    "Qwen/Qwen2.5-72B-Instruct": {
+        "model_name": "Qwen/Qwen2.5-72B-Instruct",
+        "api_type": "openai-custom-deepinfra",
+        "api_base": "https://api.deepinfra.com/v1/openai",
+        "env_api_key": "DEEPINFRA_API_KEY",
+        "anony_only": false,
+        "recommended_config": {
+            "temperature": 0.7,
+            "top_p": 1.0
+        },
+        "text-arena": true,
+        "vision-arena": false
+    },
+    "google/gemma-2-27b-it": {
+        "model_name": "google/gemma-2-27b-it",
+        "api_type": "openai-custom-deepinfra",
+        "api_base": "https://api.deepinfra.com/v1/openai",
+        "env_api_key": "DEEPINFRA_API_KEY",
+        "anony_only": false,
+        "recommended_config": {
+            "temperature": 0.7,
+            "top_p": 1.0
+        },
+        "text-arena": true,
+        "vision-arena": false
+    },
+    "gemini-1.5-flash-latest": {
+        "model_name": "gemini-1.5-flash-latest",
+        "api_type": "gemini",
+        "anony_only": false,
+        "recommended_config": {
+            "temperature": 0.7,
+            "top_p": 1.0
+        },
+        "text-arena": true,
+        "vision-arena": false
+    },
+    "gemini-1.5-pro-latest": {
+        "model_name": "gemini-1.5-pro-latest",
+        "api_type": "gemini",
+        "anony_only": false,
+        "recommended_config": {
+            "temperature": 0.7,
+            "top_p": 1.0
+        },
+        "text-arena": true,
+        "vision-arena": false
+    },
+    "gpt-4-turbo-2024-04-09": {
+        "model_name": "gpt-4-turbo-2024-04-09",
+        "api_type": "openai",
+        "api_base": "https://api.openai.com/v1",
+        "anony_only": false,
+        "recommended_config": {
+            "temperature": 0.7,
+            "top_p": 1.0
+        },
+        "text-arena": true,
+        "vision-arena": false
+    },
+    "gpt-4o-mini-2024-07-18": {
+        "model_name": "gpt-4o-mini-2024-07-18",
+        "api_type": "openai",
+        "api_base": "https://api.openai.com/v1",
+        "anony_only": false,
+        "recommended_config": {
+            "temperature": 0.7,
+            "top_p": 1.0
+        },
+        "text-arena": true,
+        "vision-arena": false
+    },
+    "tokyotech-llm-Llama-3.1-Swallow-8B-Instruct-v0.1-Q8_0": {
+        "model_name": "tokyotech-llm-Llama-3.1-Swallow-8B-Instruct-v0.1-Q8_0",
+        "api_type": "openai-llama3.1",
+        "api_base": "http://localhost:8010/v1",
+        "api_key": "12345",
+        "anony_only": false,
+        "recommended_config": {
+            "temperature": 0.7,
+            "top_p": 1.0
+        },
+        "text-arena": true,
+        "vision-arena": false
+    },
+    "cyberagent/calm3-22b-chat-BitsAndBytes": {
+        "model_name": "cyberagent/calm3-22b-chat",
+        "api_type": "openai-custom-calm",
+        "api_base": "http://localhost:8011/v1",
+        "api_key": "12345",
+        "anony_only": false,
+        "recommended_config": {
+            "temperature": 0.7,
+            "top_p": 1.0
+        },
+        "text-arena": true,
+        "vision-arena": false
+    },
+    "weblab-GENIAC/Tanuki-8B-dpo-v1.0-BitsAndBytes": {
+        "model_name": "weblab-GENIAC/Tanuki-8B-dpo-v1.0",
+        "api_type": "openai-custom-tanuki",
+        "api_base": "http://localhost:8012/v1",
+        "api_key": "12345",
+        "anony_only": false,
+        "recommended_config": {
+            "temperature": 0.7,
+            "top_p": 1.0
+        },
+        "text-arena": true,
+        "vision-arena": false
+    },
+    "llm-jp-3-13b-instruct-Q8_0.gguf": {
+        "model_name": "llm-jp-3-13b-instruct-Q8_0.gguf",
+        "api_type": "openai-llmjp3",
+        "api_base": "http://localhost:8016/v1",
+        "api_key": "12345",
+        "anony_only": false,
+        "recommended_config": {
+            "temperature": 0.7,
+            "top_p": 1.0
+        },
+        "text-arena": true,
+        "vision-arena": false
+    },
+    "tokyotech-llm/Llama-3.1-Swallow-70B-Instruct-v0.1-BitsAndBytes": {
+        "model_name": "tokyotech-llm/Llama-3.1-Swallow-70B-Instruct-v0.1",
+        "api_type": "openai-llama3.1",
+        "api_base": "http://localhost:8019/v1",
+        "api_key": "12345",
+        "anony_only": false,
+        "recommended_config": {
+            "temperature": 0.7,
+            "top_p": 1.0
+        },
+        "text-arena": true,
+        "vision-arena": false
+    }
+}
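Note: each entry above is read by the demo server at startup. Credentials come either from a literal "api_key" field (used here for the localhost endpoints) or from the environment variable named in "env_api_key"; entries with neither field fall back to their SDK's default variable (e.g. ANTHROPIC_API_KEY, MISTRAL_API_KEY, GEMINI_API_KEY). A minimal sketch of that lookup, mirroring the logic in serve/api_provider.py below — the resolve_api_key helper is illustrative, not part of this commit:

import json
import os

def resolve_api_key(model_api_dict: dict) -> str:
    # A literal "api_key" wins; otherwise "env_api_key" names the variable to read,
    # matching the openai-custom branch of get_api_provider_stream_iter().
    if "api_key" in model_api_dict:
        return model_api_dict["api_key"]
    return os.environ[model_api_dict["env_api_key"]]

with open("api_endpoints.json") as f:
    endpoints = json.load(f)

key = resolve_api_key(endpoints["deepseek-chat"])  # reads DEEPSEEK_API_KEY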
app copy.py ADDED
@@ -0,0 +1,64 @@
+import gradio as gr
+from huggingface_hub import InferenceClient
+
+"""
+For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
+"""
+client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+
+
+def respond(
+    message,
+    history: list[tuple[str, str]],
+    system_message,
+    max_tokens,
+    temperature,
+    top_p,
+):
+    messages = [{"role": "system", "content": system_message}]
+
+    for val in history:
+        if val[0]:
+            messages.append({"role": "user", "content": val[0]})
+        if val[1]:
+            messages.append({"role": "assistant", "content": val[1]})
+
+    messages.append({"role": "user", "content": message})
+
+    response = ""
+
+    for message in client.chat_completion(
+        messages,
+        max_tokens=max_tokens,
+        stream=True,
+        temperature=temperature,
+        top_p=top_p,
+    ):
+        token = message.choices[0].delta.content
+
+        response += token
+        yield response
+
+
+"""
+For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
+"""
+demo = gr.ChatInterface(
+    respond,
+    additional_inputs=[
+        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(
+            minimum=0.1,
+            maximum=1.0,
+            value=0.95,
+            step=0.05,
+            label="Top-p (nucleus sampling)",
+        ),
+    ],
+)
+
+
+if __name__ == "__main__":
+    demo.launch()
app.py CHANGED
@@ -1,64 +1,367 @@
-import gradio as gr
-from huggingface_hub import InferenceClient
-
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
-
-
-if __name__ == "__main__":
-    demo.launch()
+"""
+The gradio demo server with multiple tabs.
+It supports chatting with a single model or chatting with two models side-by-side.
+"""
+
+import argparse
+from typing import List
+
+import gradio as gr
+
+from serve.gradio_block_arena_anony import (
+    build_side_by_side_ui_anony,
+    load_demo_side_by_side_anony,
+    set_global_vars_anony,
+)
+from serve.gradio_block_arena_named import (
+    build_side_by_side_ui_named,
+    load_demo_side_by_side_named,
+    set_global_vars_named,
+)
+from serve.gradio_block_arena_vision import (
+    build_single_vision_language_model_ui,
+)
+from serve.gradio_block_arena_vision_anony import (
+    build_side_by_side_vision_ui_anony,
+    load_demo_side_by_side_vision_anony,
+)
+from serve.gradio_block_arena_vision_named import (
+    build_side_by_side_vision_ui_named,
+    load_demo_side_by_side_vision_named,
+)
+from serve.gradio_global_state import Context
+
+from serve.gradio_web_server import (
+    set_global_vars,
+    block_css,
+    build_single_model_ui,
+    get_model_list,
+    load_demo_single,
+    get_ip,
+)
+from serve.utils import (
+    build_logger,
+    get_window_url_params_js,
+    get_window_url_params_with_tos_js,
+    alert_js,
+    parse_gradio_auth_creds,
+)
+
+logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
+
+
+def load_demo(context: Context, request: gr.Request):
+    ip = get_ip(request)
+    logger.info(f"load_demo. ip: {ip}. params: {request.query_params}")
+
+    inner_selected = 0
+    if "arena" in request.query_params:
+        inner_selected = 0
+    elif "vision" in request.query_params:
+        inner_selected = 0
+    elif "compare" in request.query_params:
+        inner_selected = 1
+    elif "direct" in request.query_params or "model" in request.query_params:
+        inner_selected = 2
+    elif "leaderboard" in request.query_params:
+        inner_selected = 3
+    elif "about" in request.query_params:
+        inner_selected = 4
+
+    if args.model_list_mode == "reload":
+        context.text_models, context.all_text_models = get_model_list(
+            args.controller_url,
+            args.register_api_endpoint_file,
+            vision_arena=False,
+        )
+
+        context.vision_models, context.all_vision_models = get_model_list(
+            args.controller_url,
+            args.register_api_endpoint_file,
+            vision_arena=True,
+        )
+
+    # Text models
+    if args.vision_arena:
+        side_by_side_anony_updates = load_demo_side_by_side_vision_anony()
+
+        side_by_side_named_updates = load_demo_side_by_side_vision_named(
+            context,
+        )
+
+        direct_chat_updates = load_demo_single(context, request.query_params)
+    else:
+        direct_chat_updates = load_demo_single(context, request.query_params)
+        side_by_side_anony_updates = load_demo_side_by_side_anony(
+            context.all_text_models, request.query_params
+        )
+        side_by_side_named_updates = load_demo_side_by_side_named(
+            context.text_models, request.query_params
+        )
+
+    tabs_list = (
+        [gr.Tabs(selected=inner_selected)]
+        + side_by_side_anony_updates
+        + side_by_side_named_updates
+        + direct_chat_updates
+    )
+
+    return tabs_list
+
+
+def build_demo(
+    context: Context, elo_results_file: str, leaderboard_table_file, arena_hard_table
+):
+    if args.show_terms_of_use:
+        load_js = get_window_url_params_with_tos_js
+    else:
+        load_js = get_window_url_params_js
+
+    head_js = """
+<script src="https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.min.js"></script>
+"""
+    if args.ga_id is not None:
+        head_js += f"""
+<script async src="https://www.googletagmanager.com/gtag/js?id={args.ga_id}"></script>
+<script>
+window.dataLayer = window.dataLayer || [];
+function gtag(){{dataLayer.push(arguments);}}
+gtag('js', new Date());
+
+gtag('config', '{args.ga_id}');
+window.__gradio_mode__ = "app";
+</script>
+"""
+
+    # head_js = """"""
+    text_size = gr.themes.sizes.text_lg
+    with gr.Blocks(
+        title="Chatbot Arena 日本語版α",
+        theme=gr.themes.Default(text_size=text_size),
+        css=block_css,
+        head=head_js,
+    ) as demo:
+        with gr.Tabs() as inner_tabs:
+            if args.vision_arena:
+                with gr.Tab("⚔️ Arena (battle)", id=0) as arena_tab:
+                    arena_tab.select(None, None, None, js=load_js)
+                    side_by_side_anony_list = build_side_by_side_vision_ui_anony(
+                        context,
+                        random_questions=args.random_questions,
+                    )
+                with gr.Tab("⚔️ Arena (side-by-side)", id=1) as side_by_side_tab:
+                    side_by_side_tab.select(None, None, None, js=alert_js)
+                    side_by_side_named_list = build_side_by_side_vision_ui_named(
+                        context, random_questions=args.random_questions
+                    )
+
+                with gr.Tab("💬 Direct Chat", id=2) as direct_tab:
+                    direct_tab.select(None, None, None, js=alert_js)
+                    single_model_list = build_single_vision_language_model_ui(
+                        context,
+                        add_promotion_links=True,
+                        random_questions=args.random_questions,
+                    )
+
+            else:
+                with gr.Tab("⚔️ Arena (battle)", id=0) as arena_tab:
+                    arena_tab.select(None, None, None, js=load_js)
+                    side_by_side_anony_list = build_side_by_side_ui_anony(
+                        context.all_text_models
+                    )
+
+                with gr.Tab("⚔️ Arena (side-by-side)", id=1) as side_by_side_tab:
+                    side_by_side_tab.select(None, None, None, js=alert_js)
+                    side_by_side_named_list = build_side_by_side_ui_named(
+                        context.text_models
+                    )
+
+                with gr.Tab("💬 Direct Chat", id=2) as direct_tab:
+                    direct_tab.select(None, None, None, js=alert_js)
+                    single_model_list = build_single_model_ui(
+                        context.text_models, add_promotion_links=True
+                    )
+
+            demo_tabs = (
+                [inner_tabs]
+                + side_by_side_anony_list
+                + side_by_side_named_list
+                + single_model_list
+            )
+
+            # if elo_results_file:
+            #     with gr.Tab("🏆 Leaderboard", id=3):
+            #         build_leaderboard_tab(
+            #             elo_results_file,
+            #             leaderboard_table_file,
+            #             arena_hard_table,
+            #             show_plot=True,
+            #         )
+
+            # with gr.Tab("ℹ️ About Us", id=4):
+            #     about = build_about()
+
+        context_state = gr.State(context)
+        url_params = gr.JSON(visible=False)
+
+        if args.model_list_mode not in ["once", "reload"]:
+            raise ValueError(
+                f"Unknown model list mode: {args.model_list_mode}")
+
+        demo.load(
+            load_demo,
+            [context_state],
+            demo_tabs,
+            js=load_js,
+        )
+
+    return demo
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int)
+    parser.add_argument(
+        "--share",
+        action="store_true",
+        help="Whether to generate a public, shareable link",
+    )
+    parser.add_argument(
+        "--controller-url",
+        type=str,
+        default="http://localhost:21001",
+        help="The address of the controller",
+    )
+    parser.add_argument(
+        "--concurrency-count",
+        type=int,
+        default=10,
+        help="The concurrency count of the gradio queue",
+    )
+    parser.add_argument(
+        "--model-list-mode",
+        type=str,
+        default="once",
+        choices=["once", "reload"],
+        help="Whether to load the model list once or reload the model list every time.",
+    )
+    parser.add_argument(
+        "--moderate",
+        action="store_true",
+        help="Enable content moderation to block unsafe inputs",
+    )
+    parser.add_argument(
+        "--show-terms-of-use",
+        action="store_true",
+        help="Shows term of use before loading the demo",
+    )
+    parser.add_argument(
+        "--vision-arena", action="store_true", help="Show tabs for vision arena."
+    )
+    parser.add_argument(
+        "--random-questions", type=str, help="Load random questions from a JSON file"
+    )
+    parser.add_argument(
+        "--register-api-endpoint-file",
+        type=str,
+        help="Register API-based model endpoints from a JSON file",
+        default="api_endpoints.json",
+    )
+    parser.add_argument(
+        "--gradio-auth-path",
+        type=str,
+        help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
+        default=None,
+    )
+    parser.add_argument(
+        "--elo-results-file", type=str, help="Load leaderboard results and plots"
+    )
+    parser.add_argument(
+        "--leaderboard-table-file", type=str, help="Load leaderboard results and plots"
+    )
+    parser.add_argument(
+        "--arena-hard-table", type=str, help="Load leaderboard results and plots"
+    )
+    parser.add_argument(
+        "--gradio-root-path",
+        type=str,
+        help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix",
+    )
+    parser.add_argument(
+        "--ga-id",
+        type=str,
+        help="the Google Analytics ID",
+        default=None,
+    )
+    parser.add_argument(
+        "--use-remote-storage",
+        action="store_true",
+        default=False,
+        help="Uploads image files to google cloud storage if set to true",
+    )
+    parser.add_argument(
+        "--password",
+        type=str,
+        help="Set the password for the gradio web server",
+    )
+    args = parser.parse_args()
+    logger.info(f"args: {args}")
+
+    # Set global variables
+    set_global_vars(args.controller_url, args.moderate,
+                    args.use_remote_storage)
+    set_global_vars_named(args.moderate)
+    set_global_vars_anony(args.moderate)
+    text_models, all_text_models = get_model_list(
+        args.controller_url,
+        args.register_api_endpoint_file,
+        vision_arena=False,
+    )
+
+    vision_models, all_vision_models = get_model_list(
+        args.controller_url,
+        args.register_api_endpoint_file,
+        vision_arena=True,
+    )
+
+    models = text_models + [
+        model for model in vision_models if model not in text_models
+    ]
+    all_models = all_text_models + [
+        model for model in all_vision_models if model not in all_text_models
+    ]
+    context = Context(
+        text_models,
+        all_text_models,
+        vision_models,
+        all_vision_models,
+        models,
+        all_models,
+    )
+
+    # Set authorization credentials
+    auth = None
+    if args.gradio_auth_path is not None:
+        auth = parse_gradio_auth_creds(args.gradio_auth_path)
+
+    # Launch the demo
+    demo = build_demo(
+        context,
+        args.elo_results_file,
+        args.leaderboard_table_file,
+        args.arena_hard_table,
+    )
+    demo.queue(
+        default_concurrency_limit=args.concurrency_count,
+        status_update_rate=10,
+        api_open=False,
+    ).launch(
+        server_name=args.host,
+        server_port=args.port,
+        share=args.share,
+        max_threads=200,
+        auth=auth,
+        root_path=args.gradio_root_path,
+        show_api=False,
+    )
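Usage note: with this change app.py becomes the multi-tab arena server rather than the single ChatInterface demo. Under the argparse defaults above, a local launch might look like the following (the endpoint-file flag is shown only for clarity, since it already defaults to api_endpoints.json):

python app.py --port 7860 --model-list-mode reload --register-api-endpoint-file api_endpoints.json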
requirements.txt CHANGED
@@ -1 +1,7 @@
-huggingface_hub==0.25.2
+huggingface_hub==0.25.2
+openai==1.52.0
+google-generativeai==0.8.3
+mistralai==1.1.0
+cohere==5.11.1
+anthropic==0.36.2
+
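These pins back the providers wired up in serve/api_provider.py below (openai, google-generativeai, mistralai, cohere, anthropic); huggingface_hub keeps the pin from the original Space template.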
serve/api_provider.py ADDED
@@ -0,0 +1,1361 @@
1
+ """Call API providers."""
2
+
3
+ import json
4
+ import os
5
+ import random
6
+ import re
7
+ from typing import Optional
8
+ import time
9
+
10
+ import requests
11
+
12
+ from .utils import build_logger
13
+
14
+
15
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
16
+
17
+
18
+ def get_api_provider_stream_iter(
19
+ conv,
20
+ model_name,
21
+ model_api_dict,
22
+ temperature,
23
+ top_p,
24
+ max_new_tokens,
25
+ state,
26
+ ):
27
+ if model_api_dict["api_type"] == "openai":
28
+ if model_api_dict.get("vision-arena", False):
29
+ prompt = conv.to_openai_vision_api_messages()
30
+ else:
31
+ prompt = conv.to_openai_api_messages()
32
+ stream_iter = openai_api_stream_iter(
33
+ model_api_dict["model_name"],
34
+ prompt,
35
+ temperature,
36
+ top_p,
37
+ max_new_tokens,
38
+ api_base=model_api_dict["api_base"],
39
+ # api_key=model_api_dict["api_key"],
40
+ )
41
+
42
+ elif model_api_dict["api_type"].find("openai-custom") >= 0:
43
+ if conv.get_system_message() == "":
44
+ if model_api_dict["api_type"] == "openai-custom-tanuki":
45
+ conv.set_system_message('以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。')
46
+ elif model_api_dict["api_type"] == "openai-custom-calm":
47
+ conv.set_system_message('あなたは親切なAIアシスタントです。')
48
+ elif model_api_dict["api_type"] == "openai-custom-deepinfra":
49
+ conv.set_system_message(
50
+ 'あなたは親切な日本語のアシスタントです。')
51
+
52
+ if "api_key" in model_api_dict:
53
+ api_key = model_api_dict["api_key"]
54
+ else:
55
+ api_key = os.environ[model_api_dict["env_api_key"]]
56
+
57
+ messages = conv.to_openai_api_messages()
58
+ stream_iter = openai_api_stream_iter(
59
+ model_api_dict["model_name"],
60
+ messages,
61
+ temperature,
62
+ top_p,
63
+ max_new_tokens,
64
+ api_base=model_api_dict["api_base"],
65
+ api_key=api_key,
66
+ # api_key=os.environ[model_api_dict["env_api_key"]],
67
+ # api_key=model_api_dict["api_key"],
68
+ )
69
+ elif model_api_dict["api_type"] == "openai-llama3.1":
70
+ if conv.get_system_message() == "":
71
+ conv.set_system_message('あなたは誠実で優秀な日本人のアシスタントです。')
72
+
73
+ messages = conv.to_openai_api_messages()
74
+ stream_iter = openai_api_stream_iter(
75
+ model_api_dict["model_name"],
76
+ messages,
77
+ temperature,
78
+ top_p,
79
+ max_new_tokens,
80
+ api_base=model_api_dict["api_base"],
81
+ api_key=model_api_dict["api_key"],
82
+ stop="<|im_end|>",
83
+ )
84
+ elif model_api_dict["api_type"] == "openai-llmjp3":
85
+ if conv.get_system_message() == "":
86
+ conv.set_system_message('以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。')
87
+
88
+ messages = conv.to_openai_api_messages()
89
+ stream_iter = openai_api_stream_iter(
90
+ model_api_dict["model_name"],
91
+ messages,
92
+ temperature,
93
+ top_p,
94
+ max_new_tokens,
95
+ api_base=model_api_dict["api_base"],
96
+ api_key=model_api_dict["api_key"],
97
+ stop="<|im_end|>",
98
+ )
99
+ elif model_api_dict["api_type"] == "openai_no_stream":
100
+ prompt = conv.to_openai_api_messages()
101
+ stream_iter = openai_api_stream_iter(
102
+ model_api_dict["model_name"],
103
+ prompt,
104
+ temperature,
105
+ top_p,
106
+ max_new_tokens,
107
+ api_base=model_api_dict["api_base"],
108
+ # api_key=model_api_dict["api_key"],
109
+ stream=False,
110
+ )
111
+ elif model_api_dict["api_type"] == "openai_o1":
112
+ prompt = conv.to_openai_api_messages()
113
+ stream_iter = openai_api_stream_iter(
114
+ model_api_dict["model_name"],
115
+ prompt,
116
+ temperature,
117
+ top_p,
118
+ max_new_tokens,
119
+ api_base=model_api_dict["api_base"],
120
+ api_key=model_api_dict["api_key"],
121
+ is_o1=True,
122
+ )
123
+ elif model_api_dict["api_type"] == "openai_assistant":
124
+ last_prompt = conv.messages[-2][1]
125
+ stream_iter = openai_assistant_api_stream_iter(
126
+ state,
127
+ last_prompt,
128
+ assistant_id=model_api_dict["assistant_id"],
129
+ api_key=model_api_dict["api_key"],
130
+ )
131
+ elif model_api_dict["api_type"] == "anthropic":
132
+ if model_api_dict.get("vision-arena", False):
133
+ prompt = conv.to_anthropic_vision_api_messages()
134
+ else:
135
+ prompt = conv.to_openai_api_messages()
136
+ stream_iter = anthropic_api_stream_iter(
137
+ model_name, prompt, temperature, top_p, max_new_tokens
138
+ )
139
+ elif model_api_dict["api_type"] == "anthropic_message":
140
+ if model_api_dict.get("vision-arena", False):
141
+ prompt = conv.to_anthropic_vision_api_messages()
142
+ else:
143
+ prompt = conv.to_openai_api_messages()
144
+ stream_iter = anthropic_message_api_stream_iter(
145
+ model_api_dict["model_name"], prompt, temperature, top_p, max_new_tokens
146
+ )
147
+ elif model_api_dict["api_type"] == "anthropic_message_vertex":
148
+ if model_api_dict.get("vision-arena", False):
149
+ prompt = conv.to_anthropic_vision_api_messages()
150
+ else:
151
+ prompt = conv.to_openai_api_messages()
152
+ stream_iter = anthropic_message_api_stream_iter(
153
+ model_api_dict["model_name"],
154
+ prompt,
155
+ temperature,
156
+ top_p,
157
+ max_new_tokens,
158
+ vertex_ai=True,
159
+ )
160
+ elif model_api_dict["api_type"] == "gemini":
161
+ prompt = conv.to_gemini_api_messages()
162
+ stream_iter = gemini_api_stream_iter(
163
+ model_api_dict["model_name"],
164
+ prompt,
165
+ temperature,
166
+ top_p,
167
+ max_new_tokens,
168
+ # api_key=model_api_dict["api_key"],
169
+ )
170
+ elif model_api_dict["api_type"] == "gemini_no_stream":
171
+ prompt = conv.to_gemini_api_messages()
172
+ stream_iter = gemini_api_stream_iter(
173
+ model_api_dict["model_name"],
174
+ prompt,
175
+ temperature,
176
+ top_p,
177
+ max_new_tokens,
178
+ # api_key=model_api_dict["api_key"],
179
+ use_stream=False,
180
+ )
181
+ elif model_api_dict["api_type"] == "bard":
182
+ prompt = conv.to_openai_api_messages()
183
+ stream_iter = gemini_api_stream_iter(
184
+ model_api_dict["model_name"],
185
+ prompt,
186
+ None, # use Bard's default temperature
187
+ None, # use Bard's default top_p
188
+ max_new_tokens,
189
+ api_key=(model_api_dict["api_key"] or os.environ["BARD_API_KEY"]),
190
+ use_stream=False,
191
+ )
192
+ elif model_api_dict["api_type"] == "mistral":
193
+ if model_api_dict.get("vision-arena", False):
194
+ prompt = conv.to_openai_vision_api_messages(is_mistral=True)
195
+ else:
196
+ prompt = conv.to_openai_api_messages()
197
+ stream_iter = mistral_api_stream_iter(
198
+ model_api_dict["model_name"],
199
+ prompt,
200
+ temperature,
201
+ top_p,
202
+ max_new_tokens,
203
+ api_key=None,
204
+ )
205
+ elif model_api_dict["api_type"] == "nvidia":
206
+ prompt = conv.to_openai_api_messages()
207
+ stream_iter = nvidia_api_stream_iter(
208
+ model_name,
209
+ prompt,
210
+ temperature,
211
+ top_p,
212
+ max_new_tokens,
213
+ model_api_dict["api_base"],
214
+ model_api_dict["api_key"],
215
+ )
216
+ elif model_api_dict["api_type"] == "ai2":
217
+ prompt = conv.to_openai_api_messages()
218
+ stream_iter = ai2_api_stream_iter(
219
+ model_name,
220
+ model_api_dict["model_name"],
221
+ prompt,
222
+ temperature,
223
+ top_p,
224
+ max_new_tokens,
225
+ api_base=model_api_dict["api_base"],
226
+ api_key=model_api_dict["api_key"],
227
+ )
228
+ elif model_api_dict["api_type"] == "vertex":
229
+ prompt = conv.to_vertex_api_messages()
230
+ stream_iter = vertex_api_stream_iter(
231
+ model_name, prompt, temperature, top_p, max_new_tokens
232
+ )
233
+ elif model_api_dict["api_type"] == "yandexgpt":
234
+ # note: top_p parameter is unused by yandexgpt
235
+
236
+ messages = []
237
+ if conv.system_message:
238
+ messages.append({"role": "system", "text": conv.system_message})
239
+ messages += [
240
+ {"role": role, "text": text}
241
+ for role, text in conv.messages
242
+ if text is not None
243
+ ]
244
+
245
+ fixed_temperature = model_api_dict.get("fixed_temperature")
246
+ if fixed_temperature is not None:
247
+ temperature = fixed_temperature
248
+
249
+ stream_iter = yandexgpt_api_stream_iter(
250
+ model_name=model_api_dict["model_name"],
251
+ messages=messages,
252
+ temperature=temperature,
253
+ max_tokens=max_new_tokens,
254
+ api_base=model_api_dict["api_base"],
255
+ api_key=model_api_dict.get("api_key"),
256
+ folder_id=model_api_dict.get("folder_id"),
257
+ )
258
+ elif model_api_dict["api_type"] == "cohere":
259
+ messages = conv.to_openai_api_messages()
260
+ stream_iter = cohere_api_stream_iter(
261
+ client_name=model_api_dict.get("client_name", "FastChat"),
262
+ model_id=model_api_dict["model_name"],
263
+ messages=messages,
264
+ temperature=temperature,
265
+ top_p=top_p,
266
+ max_new_tokens=max_new_tokens,
267
+ # api_base=model_api_dict["api_base"],
268
+ # api_key=model_api_dict["api_key"],
269
+ )
270
+ elif model_api_dict["api_type"] == "reka":
271
+ messages = conv.to_reka_api_messages()
272
+ stream_iter = reka_api_stream_iter(
273
+ model_name=model_api_dict["model_name"],
274
+ messages=messages,
275
+ temperature=temperature,
276
+ top_p=top_p,
277
+ max_new_tokens=max_new_tokens,
278
+ api_base=model_api_dict["api_base"],
279
+ api_key=model_api_dict["api_key"],
280
+ )
281
+ elif model_api_dict["api_type"] == "column":
282
+ if model_api_dict.get("vision-arena", False):
283
+ prompt = conv.to_openai_vision_api_messages()
284
+ else:
285
+ prompt = conv.to_openai_api_messages()
286
+ stream_iter = column_api_stream_iter(
287
+ model_name=model_api_dict["model_name"],
288
+ messages=prompt,
289
+ temperature=temperature,
290
+ top_p=top_p,
291
+ max_new_tokens=max_new_tokens,
292
+ api_base=model_api_dict["api_base"],
293
+ api_key=model_api_dict["api_key"],
294
+ )
295
+ elif model_api_dict["api_type"] == "metagen":
296
+ prompt = conv.to_metagen_api_messages()
297
+ stream_iter = metagen_api_stream_iter(
298
+ model_api_dict["model_name"],
299
+ prompt,
300
+ temperature,
301
+ top_p,
302
+ max_new_tokens,
303
+ api_base=model_api_dict["api_base"],
304
+ api_key=model_api_dict["api_key"],
305
+ )
306
+ else:
307
+ raise NotImplementedError()
308
+
309
+ return stream_iter
310
+
311
+
312
+ def openai_api_stream_iter(
313
+ model_name,
314
+ messages,
315
+ temperature,
316
+ top_p,
317
+ max_new_tokens,
318
+ api_base=None,
319
+ api_key=None,
320
+ stream=True,
321
+ is_o1=False,
322
+ stop="dummy_stop_token123456789",
323
+ ):
324
+ import openai
325
+
326
+ api_key = api_key or os.environ["OPENAI_API_KEY"]
327
+
328
+ if "azure" in model_name:
329
+ client = openai.AzureOpenAI(
330
+ api_version="2023-07-01-preview",
331
+ azure_endpoint=api_base or "https://api.openai.com/v1",
332
+ api_key=api_key,
333
+ )
334
+ else:
335
+ client = openai.OpenAI(
336
+ base_url=api_base or "https://api.openai.com/v1",
337
+ api_key=api_key,
338
+ timeout=180,
339
+ )
340
+
341
+ # Make requests for logging
342
+ text_messages = []
343
+ for message in messages:
344
+ if type(message["content"]) == str: # text-only model
345
+ text_messages.append(message)
346
+ else: # vision model
347
+ filtered_content_list = [
348
+ content for content in message["content"] if content["type"] == "text"
349
+ ]
350
+ text_messages.append(
351
+ {"role": message["role"], "content": filtered_content_list}
352
+ )
353
+
354
+ gen_params = {
355
+ "model": model_name,
356
+ "prompt": text_messages,
357
+ "temperature": temperature,
358
+ "top_p": top_p,
359
+ "max_new_tokens": max_new_tokens,
360
+ }
361
+ logger.info(f"==== request ====\n{gen_params}")
362
+
363
+ if stream and not is_o1:
364
+ res = client.chat.completions.create(
365
+ model=model_name,
366
+ messages=messages,
367
+ temperature=temperature,
368
+ max_tokens=max_new_tokens,
369
+ stream=True,
370
+ stop=stop,
371
+ )
372
+ text = ""
373
+ for chunk in res:
374
+ if len(chunk.choices) > 0:
375
+ text += chunk.choices[0].delta.content or ""
376
+ data = {
377
+ "text": text,
378
+ "error_code": 0,
379
+ }
380
+ yield data
381
+ else:
382
+ if is_o1:
383
+ res = client.chat.completions.create(
384
+ model=model_name,
385
+ messages=messages,
386
+ temperature=1.0,
387
+ stream=False,
388
+ )
389
+ else:
390
+ res = client.chat.completions.create(
391
+ model=model_name,
392
+ messages=messages,
393
+ temperature=temperature,
394
+ max_tokens=max_new_tokens,
395
+ stream=False,
396
+ )
397
+ text = res.choices[0].message.content
398
+ pos = 0
399
+ while pos < len(text):
400
+ # simulate token streaming
401
+ pos += 2
402
+ time.sleep(0.001)
403
+ data = {
404
+ "text": text[:pos],
405
+ "error_code": 0,
406
+ }
407
+ yield data
408
+
409
+
410
+ def column_api_stream_iter(
411
+ model_name,
412
+ messages,
413
+ temperature,
414
+ top_p,
415
+ max_new_tokens,
416
+ api_base=None,
417
+ api_key=None,
418
+ ):
419
+ try:
420
+ messages_no_img = []
421
+ for msg in messages:
422
+ msg_no_img = msg.copy()
423
+ msg_no_img.pop("attachment", None)
424
+ messages_no_img.append(msg_no_img)
425
+
426
+ gen_params = {
427
+ "model": model_name,
428
+ "messages": messages_no_img,
429
+ "temperature": temperature,
430
+ "top_p": top_p,
431
+ "max_new_tokens": max_new_tokens,
432
+ "seed": 42,
433
+ }
434
+ logger.info(f"==== request ====\n{gen_params}")
435
+
436
+ gen_params["messages"] = messages
437
+ gen_params["stream"] = True
438
+
439
+ # payload.pop("model")
440
+
441
+ # try 3 times
442
+ for i in range(3):
443
+ try:
444
+ response = requests.post(
445
+ api_base, json=gen_params, stream=True, timeout=30
446
+ )
447
+ break
448
+ except Exception as e:
449
+ logger.error(f"==== error ====\n{e}")
450
+ if i == 2:
451
+ yield {
452
+ "text": f"**API REQUEST ERROR** Reason: API timeout. please try again later.",
453
+ "error_code": 1,
454
+ }
455
+ return
456
+
457
+ text = ""
458
+ for line in response.iter_lines():
459
+ if line:
460
+ data = line.decode("utf-8")
461
+ if data.startswith("data:"):
462
+ data = json.loads(data[6:])["message"]
463
+ text += data
464
+ yield {"text": text, "error_code": 0}
465
+
466
+ except Exception as e:
467
+ logger.error(f"==== error ====\n{e}")
468
+ yield {
469
+ "text": f"**API REQUEST ERROR** Reason: Unknown.",
470
+ "error_code": 1,
471
+ }
472
+
473
+
474
+ def upload_openai_file_to_gcs(file_id):
475
+ import openai
476
+ from google.cloud import storage
477
+
478
+ storage_client = storage.Client()
479
+
480
+ file = openai.files.content(file_id)
481
+ # upload file to GCS
482
+ bucket = storage_client.get_bucket("arena_user_content")
483
+ blob = bucket.blob(f"{file_id}")
484
+ blob.upload_from_string(file.read())
485
+ blob.make_public()
486
+ return blob.public_url
487
+
488
+
489
+ def openai_assistant_api_stream_iter(
490
+ state,
491
+ prompt,
492
+ assistant_id,
493
+ api_key=None,
494
+ ):
495
+ import openai
496
+ import base64
497
+
498
+ api_key = api_key or os.environ["OPENAI_API_KEY"]
499
+ client = openai.OpenAI(
500
+ base_url="https://api.openai.com/v1", api_key=api_key)
501
+
502
+ if state.oai_thread_id is None:
503
+ logger.info("==== create thread ====")
504
+ thread = client.beta.threads.create()
505
+ state.oai_thread_id = thread.id
506
+ logger.info(f"==== thread_id ====\n{state.oai_thread_id}")
507
+ thread_message = client.beta.threads.messages.with_raw_response.create(
508
+ state.oai_thread_id,
509
+ role="user",
510
+ content=prompt,
511
+ timeout=3,
512
+ )
513
+ # logger.info(f"header {thread_message.headers}")
514
+ thread_message = thread_message.parse()
515
+ # Make requests
516
+ gen_params = {
517
+ "assistant_id": assistant_id,
518
+ "thread_id": state.oai_thread_id,
519
+ "message": prompt,
520
+ }
521
+ logger.info(f"==== request ====\n{gen_params}")
522
+
523
+ res = requests.post(
524
+ f"https://api.openai.com/v1/threads/{state.oai_thread_id}/runs",
525
+ headers={
526
+ "Authorization": f"Bearer {api_key}",
527
+ "Content-Type": "application/json",
528
+ "OpenAI-Beta": "assistants=v1",
529
+ },
530
+ json={"assistant_id": assistant_id, "stream": True},
531
+ timeout=30,
532
+ stream=True,
533
+ )
534
+
535
+ list_of_text = []
536
+ list_of_raw_text = []
537
+ offset_idx = 0
538
+ full_ret_text = ""
539
+ idx_mapping = {}
540
+ cur_offset = 0
541
+ for line in res.iter_lines():
542
+ if not line:
543
+ continue
544
+ data = line.decode("utf-8")
545
+ # logger.info("data:", data)
546
+ if data.endswith("[DONE]"):
547
+ break
548
+ if data.startswith("event"):
549
+ event = data.split(":")[1].strip()
550
+ if event == "thread.message.completed":
551
+ offset_idx += len(list_of_text)
552
+ continue
553
+ data = json.loads(data[6:])
554
+
555
+ if data.get("status") == "failed":
556
+ yield {
557
+ "text": f"**API REQUEST ERROR** Reason: {data['last_error']['message']}",
558
+ "error_code": 1,
559
+ }
560
+ return
561
+
562
+ if data.get("status") == "completed":
563
+ logger.info(f"[debug]: {data}")
564
+
565
+ if data["object"] != "thread.message.delta":
566
+ continue
567
+
568
+ for delta in data["delta"]["content"]:
569
+ text_index = delta["index"] + offset_idx
570
+ if len(list_of_text) <= text_index:
571
+ list_of_text.append("")
572
+ list_of_raw_text.append("")
573
+
574
+ text = list_of_text[text_index]
575
+ raw_text = list_of_raw_text[text_index]
576
+
577
+ if delta["type"] == "text":
578
+ # text, url_citation or file_path
579
+ content = delta["text"]
580
+ if "annotations" in content and len(content["annotations"]) > 0:
581
+ annotations = content["annotations"]
582
+
583
+ raw_text_copy = text
584
+ for anno in annotations:
585
+ if anno["type"] == "url_citation":
586
+ pattern = r"【\d+†source】"
587
+ matches = re.findall(pattern, content["value"])
588
+ if len(matches) > 0:
589
+ for match in matches:
590
+ print(match)
591
+ if match not in idx_mapping:
592
+ idx_mapping[match] = len(
593
+ idx_mapping) + 1
594
+ citation_number = idx_mapping[match]
595
+
596
+ start_idx = anno["start_index"] + cur_offset
597
+ end_idx = anno["end_index"] + cur_offset
598
+ url = anno["url_citation"]["url"]
599
+
600
+ citation = f" [[{citation_number}]]({url})"
601
+ raw_text_copy = (
602
+ raw_text_copy[:start_idx]
603
+ + citation
604
+ + raw_text_copy[end_idx:]
605
+ )
606
+ cur_offset += len(citation) - (end_idx - start_idx)
607
+ elif anno["type"] == "file_path":
608
+ file_public_url = upload_openai_file_to_gcs(
609
+ anno["file_path"]["file_id"]
610
+ )
611
+ raw_text_copy = raw_text_copy.replace(
612
+ anno["text"], f"{file_public_url}"
613
+ )
614
+ text = raw_text_copy
615
+ else:
616
+ text_content = content["value"]
617
+ text += text_content
618
+ elif delta["type"] == "image_file":
619
+ image_public_url = upload_openai_file_to_gcs(
620
+ delta["image_file"]["file_id"]
621
+ )
622
+ text += f"![image]({image_public_url})"
623
+
624
+ list_of_text[text_index] = text
625
+ list_of_raw_text[text_index] = raw_text
626
+
627
+ full_ret_text = "\n".join(list_of_text)
628
+ yield {"text": full_ret_text, "error_code": 0}
629
+
630
+
631
+ def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens):
632
+ import anthropic
633
+
634
+ c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
635
+
636
+ # Make requests
637
+ gen_params = {
638
+ "model": model_name,
639
+ "prompt": prompt,
640
+ "temperature": temperature,
641
+ "top_p": top_p,
642
+ "max_new_tokens": max_new_tokens,
643
+ }
644
+ logger.info(f"==== request ====\n{gen_params}")
645
+
646
+ res = c.messages.create(
647
+ # res = c.completions.create(
648
+ # prompt=prompt,
649
+ messages=prompt,
650
+ # stop_sequences=[anthropic.HUMAN_PROMPT],
651
+ # max_tokens_to_sample=max_new_tokens,
652
+ max_tokens=max_new_tokens,
653
+ temperature=temperature,
654
+ top_p=top_p,
655
+ model=model_name,
656
+ stream=True,
657
+ )
658
+ text = ""
659
+ text = ""
660
+ for chunk in res:
661
+
662
+ if hasattr(chunk, 'delta'):
663
+ if hasattr(chunk.delta, 'text'):
664
+ if chunk.delta.text is not None:
665
+ if isinstance(chunk.delta.text, str):
666
+ text += chunk.delta.text
667
+ elif isinstance(chunk.delta.text, list):
668
+ text += ''.join(chunk.delta.text)
669
+ elif hasattr(chunk, 'message') and chunk.message.content is not None:
670
+ if isinstance(chunk.message.content, str):
671
+ text += chunk.message.content
672
+ elif isinstance(chunk.message.content, list):
673
+ text += ''.join(chunk.message.content)
674
+ else:
675
+ print(chunk)
676
+ continue
677
+
678
+ data = {
679
+ "text": text,
680
+ "error_code": 0,
681
+ }
682
+ yield data
683
+ # for chunk in res:
684
+ # text += chunk.completion
685
+ # text += chunk.text_stream
686
+ # data = {
687
+ # "text": text,
688
+ # "error_code": 0,
689
+ # }
690
+ # yield data
691
+
692
+
693
+ def anthropic_message_api_stream_iter(
694
+ model_name,
695
+ messages,
696
+ temperature,
697
+ top_p,
698
+ max_new_tokens,
699
+ vertex_ai=False,
700
+ ):
701
+ import anthropic
702
+
703
+ if vertex_ai:
704
+ client = anthropic.AnthropicVertex(
705
+ region=os.environ["GCP_LOCATION"],
706
+ project_id=os.environ["GCP_PROJECT_ID"],
707
+ max_retries=5,
708
+ )
709
+ else:
710
+ client = anthropic.Anthropic(
711
+ api_key=os.environ["ANTHROPIC_API_KEY"],
712
+ max_retries=5,
713
+ )
714
+
715
+ text_messages = []
716
+ for message in messages:
717
+ if type(message["content"]) == str: # text-only model
718
+ text_messages.append(message)
719
+ else: # vision model
720
+ filtered_content_list = [
721
+ content for content in message["content"] if content["type"] == "text"
722
+ ]
723
+ text_messages.append(
724
+ {"role": message["role"], "content": filtered_content_list}
725
+ )
726
+
727
+ # Make requests for logging
728
+ gen_params = {
729
+ "model": model_name,
730
+ "prompt": text_messages,
731
+ "temperature": temperature,
732
+ "top_p": top_p,
733
+ "max_new_tokens": max_new_tokens,
734
+ }
735
+ logger.info(f"==== request ====\n{gen_params}")
736
+
737
+ system_prompt = ""
738
+ if messages[0]["role"] == "system":
739
+ if type(messages[0]["content"]) == dict:
740
+ system_prompt = messages[0]["content"]["text"]
741
+ elif type(messages[0]["content"]) == str:
742
+ system_prompt = messages[0]["content"]
743
+ # remove system prompt
744
+ messages = messages[1:]
745
+
746
+ text = ""
747
+ with client.messages.stream(
748
+ temperature=temperature,
749
+ top_p=top_p,
750
+ max_tokens=max_new_tokens,
751
+ messages=messages,
752
+ model=model_name,
753
+ system=system_prompt,
754
+ ) as stream:
755
+ for chunk in stream.text_stream:
756
+ text += chunk
757
+ data = {
758
+ "text": text,
759
+ "error_code": 0,
760
+ }
761
+ yield data
762
+
763
+
764
+ def gemini_api_stream_iter(
765
+ model_name,
766
+ messages,
767
+ temperature,
768
+ top_p,
769
+ max_new_tokens,
770
+ api_key=None,
771
+ use_stream=True,
772
+ ):
773
+ import google.generativeai as genai # pip install google-generativeai
774
+
775
+ if api_key is None:
776
+ api_key = os.environ["GEMINI_API_KEY"]
777
+ genai.configure(api_key=api_key)
778
+
779
+ generation_config = {
780
+ "temperature": temperature,
781
+ "max_output_tokens": max_new_tokens,
782
+ "top_p": top_p,
783
+ }
784
+ params = {
785
+ "model": model_name,
786
+ "prompt": messages,
787
+ }
788
+ params.update(generation_config)
789
+ logger.info(f"==== request ====\n{params}")
790
+
791
+ safety_settings = [
792
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
793
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
794
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
795
+ {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
796
+ ]
797
+
798
+ history = []
799
+ system_prompt = None
800
+ for message in messages[:-1]:
801
+ if message["role"] == "system":
802
+ system_prompt = message["content"]
803
+ continue
804
+ history.append({"role": message["role"], "parts": message["content"]})
805
+
806
+ model = genai.GenerativeModel(
807
+ model_name=model_name,
808
+ system_instruction=system_prompt,
809
+ generation_config=generation_config,
810
+ safety_settings=safety_settings,
811
+ )
812
+ convo = model.start_chat(history=history)
813
+
814
+ if use_stream:
815
+ response = convo.send_message(messages[-1]["content"], stream=True)
816
+ try:
817
+ text = ""
818
+ for chunk in response:
819
+ text += chunk.candidates[0].content.parts[0].text
820
+ data = {
821
+ "text": text,
822
+ "error_code": 0,
823
+ }
824
+ yield data
825
+ except Exception as e:
826
+ logger.error(f"==== error ====\n{e}")
827
+ reason = chunk.candidates
828
+ yield {
829
+ "text": f"**API REQUEST ERROR** Reason: {reason}.",
830
+ "error_code": 1,
831
+ }
832
+ else:
833
+ try:
834
+ response = convo.send_message(
835
+ messages[-1]["content"], stream=False)
836
+ text = response.candidates[0].content.parts[0].text
837
+ pos = 0
838
+ while pos < len(text):
839
+ # simulate token streaming
840
+ pos += 5
841
+ time.sleep(0.001)
842
+ data = {
843
+ "text": text[:pos],
844
+ "error_code": 0,
845
+ }
846
+ yield data
847
+ except Exception as e:
848
+ logger.error(f"==== error ====\n{e}")
849
+ yield {
850
+ "text": f"**API REQUEST ERROR** Reason: {e}.",
851
+ "error_code": 1,
852
+ }
853
+
854
+
855
+ def ai2_api_stream_iter(
856
+ model_name,
857
+ model_id,
858
+ messages,
859
+ temperature,
860
+ top_p,
861
+ max_new_tokens,
862
+ api_key=None,
863
+ api_base=None,
864
+ ):
865
+ # get keys and needed values
866
+ ai2_key = api_key or os.environ.get("AI2_API_KEY")
867
+ api_base = api_base or "https://inferd.allen.ai/api/v1/infer"
868
+
869
+ # Make requests
870
+ gen_params = {
871
+ "model": model_name,
872
+ "prompt": messages,
873
+ "temperature": temperature,
874
+ "top_p": top_p,
875
+ "max_new_tokens": max_new_tokens,
876
+ }
877
+ logger.info(f"==== request ====\n{gen_params}")
878
+
879
+ # AI2 uses vLLM, which requires that `top_p` be 1.0 for greedy sampling:
880
+ # https://github.com/vllm-project/vllm/blob/v0.1.7/vllm/sampling_params.py#L156-L157
881
+ if temperature == 0.0 and top_p < 1.0:
882
+ raise ValueError("top_p must be 1 when temperature is 0.0")
883
+
884
+ res = requests.post(
885
+ api_base,
886
+ stream=True,
887
+ headers={"Authorization": f"Bearer {ai2_key}"},
888
+ json={
889
+ "model_id": model_id,
890
+ # This input format is specific to the Tulu2 model. Other models
891
+ # may require different input formats. See the model's schema
892
+ # documentation on InferD for more information.
893
+ "input": {
894
+ "messages": messages,
895
+ "opts": {
896
+ "max_tokens": max_new_tokens,
897
+ "temperature": temperature,
898
+ "top_p": top_p,
899
+ "logprobs": 1, # increase for more choices
900
+ },
901
+ },
902
+ },
903
+ timeout=5,
904
+ )
905
+
906
+ if res.status_code != 200:
907
+ logger.error(f"unexpected response ({res.status_code}): {res.text}")
908
+ raise ValueError("unexpected response from InferD", res)
909
+
910
+ text = ""
911
+ for line in res.iter_lines():
912
+ if line:
913
+ part = json.loads(line)
914
+ if "result" in part and "output" in part["result"]:
915
+ for t in part["result"]["output"]["text"]:
916
+ text += t
917
+ else:
918
+ logger.error(f"unexpected part: {part}")
919
+ raise ValueError("empty result in InferD response")
920
+
921
+ data = {
922
+ "text": text,
923
+ "error_code": 0,
924
+ }
925
+ yield data
926
+
927
+
928
+ def mistral_api_stream_iter(
929
+ model_name, messages, temperature, top_p, max_new_tokens, api_key=None
930
+ ):
931
+ # from mistralai.client import MistralClient
932
+ # from mistralai.models.chat_completion import ChatMessage
933
+ from mistralai import Mistral
934
+
935
+ if api_key is None:
936
+ api_key = os.environ["MISTRAL_API_KEY"]
937
+
938
+ client = Mistral(api_key=api_key)
939
+
940
+ # Make requests for logging
941
+ text_messages = []
942
+ for message in messages:
943
+ if type(message["content"]) == str: # text-only model
944
+ text_messages.append(message)
945
+ else: # vision model
946
+ filtered_content_list = [
947
+ content for content in message["content"] if content["type"] == "text"
948
+ ]
949
+ text_messages.append(
950
+ {"role": message["role"], "content": filtered_content_list}
951
+ )
952
+
953
+ # Make requests
954
+ gen_params = {
955
+ "model": model_name,
956
+ "prompt": text_messages,
957
+ "temperature": temperature,
958
+ "top_p": top_p,
959
+ "max_new_tokens": max_new_tokens,
960
+ }
961
+ logger.info(f"==== request ====\n{gen_params}")
962
+
963
+ # new_messages = [
964
+ # ChatMessage(role=message["role"], content=message["content"])
965
+ # for message in messages
966
+ # ]
967
+
968
+ res = client.chat.stream(
969
+ model=model_name,
970
+ temperature=temperature,
971
+ messages=messages,
972
+ max_tokens=max_new_tokens,
973
+ top_p=top_p,
974
+ )
975
+
976
+ text = ""
977
+ for chunk in res:
978
+ if chunk.data.choices[0].delta.content is not None:
979
+ text += chunk.data.choices[0].delta.content
980
+ data = {
981
+ "text": text,
982
+ "error_code": 0,
983
+ }
984
+ yield data
985
+
986
+
987
+ def nvidia_api_stream_iter(
988
+ model_name, messages, temp, top_p, max_tokens, api_base, api_key=None
989
+ ):
990
+ model_2_api = {
991
+ "nemotron-4-340b": "/b0fcd392-e905-4ab4-8eb9-aeae95c30b37",
992
+ }
993
+ api_base += model_2_api[model_name]
994
+
995
+ api_key = api_key or os.environ["NVIDIA_API_KEY"]
996
+ headers = {
997
+ "Authorization": f"Bearer {api_key}",
998
+ "accept": "text/event-stream",
999
+ "content-type": "application/json",
1000
+ }
1001
+ # nvidia api does not accept 0 temperature
1002
+ if temp == 0.0:
1003
+ temp = 0.000001
1004
+
1005
+ payload = {
1006
+ "model": model_name,
1007
+ "messages": messages,
1008
+ "temperature": temp,
1009
+ "top_p": top_p,
1010
+ "max_tokens": max_tokens,
1011
+ "seed": 42,
1012
+ "stream": True,
1013
+ }
1014
+ logger.info(f"==== request ====\n{payload}")
1015
+
1016
+ # payload.pop("model")
1017
+
1018
+ # try 3 times
1019
+ for i in range(3):
1020
+ try:
1021
+ response = requests.post(
1022
+ api_base, headers=headers, json=payload, stream=True, timeout=3
1023
+ )
1024
+ break
1025
+ except Exception as e:
1026
+ logger.error(f"==== error ====\n{e}")
1027
+ if i == 2:
1028
+ yield {
1029
+ "text": f"**API REQUEST ERROR** Reason: API timeout. please try again later.",
1030
+ "error_code": 1,
1031
+ }
1032
+ return
1033
+
1034
+ text = ""
1035
+ for line in response.iter_lines():
1036
+ if line:
1037
+ data = line.decode("utf-8")
1038
+ if data.endswith("[DONE]"):
1039
+ break
1040
+ data = json.loads(data[6:])["choices"][0]["delta"]["content"]
1041
+ text += data
1042
+ yield {"text": text, "error_code": 0}
1043
+
1044
+
1045
+ def yandexgpt_api_stream_iter(
1046
+ model_name, messages, temperature, max_tokens, api_base, api_key, folder_id
1047
+ ):
1048
+ api_key = api_key or os.environ["YANDEXGPT_API_KEY"]
1049
+ headers = {
1050
+ "Authorization": f"Api-Key {api_key}",
1051
+ "content-type": "application/json",
1052
+ }
1053
+
1054
+ payload = {
1055
+ "modelUri": f"gpt://{folder_id}/{model_name}",
1056
+ "completionOptions": {
1057
+ "temperature": temperature,
1058
+ "max_tokens": max_tokens,
1059
+ "stream": True,
1060
+ },
1061
+ "messages": messages,
1062
+ }
1063
+ logger.info(f"==== request ====\n{payload}")
1064
+
1065
+ # https://llm.api.cloud.yandex.net/foundationModels/v1/completion
1066
+ response = requests.post(
1067
+ api_base, headers=headers, json=payload, stream=True, timeout=60
1068
+ )
1069
+ text = ""
1070
+ for line in response.iter_lines():
1071
+ if line:
1072
+ data = json.loads(line.decode("utf-8"))
1073
+ data = data["result"]
1074
+ top_alternative = data["alternatives"][0]
1075
+ text = top_alternative["message"]["text"]
1076
+ yield {"text": text, "error_code": 0}
1077
+
1078
+ status = top_alternative["status"]
1079
+ if status in (
1080
+ "ALTERNATIVE_STATUS_FINAL",
1081
+ "ALTERNATIVE_STATUS_TRUNCATED_FINAL",
1082
+ ):
1083
+ break
1084
+
1085
+
1086
+ def cohere_api_stream_iter(
1087
+ client_name: str,
1088
+ model_id: str,
1089
+ messages: list,
1090
+ temperature: Optional[
1091
+ float
1092
+ ] = None, # The SDK or API handles None for all parameters following
1093
+ top_p: Optional[float] = None,
1094
+ max_new_tokens: Optional[int] = None,
1095
+ api_key: Optional[str] = None, # default is env var CO_API_KEY
1096
+ api_base: Optional[str] = None,
1097
+ ):
1098
+ import cohere
1099
+ if api_key is None:
1100
+ api_key = os.environ["COHERE_API_KEY"]
1101
+
1102
+ OPENAI_TO_COHERE_ROLE_MAP = {
1103
+ "user": "User",
1104
+ "assistant": "Chatbot",
1105
+ "system": "System",
1106
+ }
1107
+
1108
+ # client = cohere.ClientV2(
1109
+ client = cohere.Client(
1110
+ api_key=api_key,
1111
+ # base_url=api_base,
1112
+ # client_name=client_name,
1113
+ )
1114
+
1115
+ # prepare and log requests
1116
+ chat_history = [
1117
+ dict(
1118
+ role=OPENAI_TO_COHERE_ROLE_MAP[message["role"]
1119
+ ], message=message["content"]
1120
+ )
1121
+ for message in messages[:-1]
1122
+ ]
1123
+ actual_prompt = messages[-1]["content"]
1124
+
1125
+ gen_params = {
1126
+ "model": model_id,
1127
+ "messages": messages,
1128
+ "chat_history": chat_history,
1129
+ "prompt": actual_prompt,
1130
+ "temperature": temperature,
1131
+ "top_p": top_p,
1132
+ "max_new_tokens": max_new_tokens,
1133
+ }
1134
+ logger.info(f"==== request ====\n{gen_params}")
1135
+
1136
+ # make request and stream response
1137
+ res = client.chat_stream(
1138
+ # messages=messages,
1139
+ message=actual_prompt,
1140
+ chat_history=chat_history,
1141
+ model=model_id,
1142
+ temperature=temperature,
1143
+ max_tokens=max_new_tokens,
1144
+ p=top_p,
1145
+ )
1146
+ try:
1147
+ text = ""
1148
+ for streaming_item in res:
1149
+ if streaming_item.event_type == "text-generation":
1150
+ text += streaming_item.text
1151
+ yield {"text": text, "error_code": 0}
1152
+ except cohere.core.ApiError as e:
1153
+ logger.error(f"==== error from cohere api: {e} ====")
1154
+ yield {
1155
+ "text": f"**API REQUEST ERROR** Reason: {e}",
1156
+ "error_code": 1,
1157
+ }
1158
+
1159
+
1160
+ def vertex_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens):
1161
+ import vertexai
1162
+ from vertexai import generative_models
1163
+ from vertexai.generative_models import (
1164
+ GenerationConfig,
1165
+ GenerativeModel,
1166
+ Image,
1167
+ )
1168
+
1169
+ project_id = os.environ.get("GCP_PROJECT_ID", None)
1170
+ location = os.environ.get("GCP_LOCATION", None)
1171
+ vertexai.init(project=project_id, location=location)
1172
+
1173
+ text_messages = []
1174
+ for message in messages:
1175
+ if type(message) == str:
1176
+ text_messages.append(message)
1177
+
1178
+ gen_params = {
1179
+ "model": model_name,
1180
+ "prompt": text_messages,
1181
+ "temperature": temperature,
1182
+ "top_p": top_p,
1183
+ "max_new_tokens": max_new_tokens,
1184
+ }
1185
+ logger.info(f"==== request ====\n{gen_params}")
1186
+
1187
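+ # relax all Vertex safety categories to BLOCK_NONE so responses are not filtered by the platform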
+ safety_settings = [
1188
+ generative_models.SafetySetting(
1189
+ category=generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,
1190
+ threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
1191
+ ),
1192
+ generative_models.SafetySetting(
1193
+ category=generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
1194
+ threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
1195
+ ),
1196
+ generative_models.SafetySetting(
1197
+ category=generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
1198
+ threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
1199
+ ),
1200
+ generative_models.SafetySetting(
1201
+ category=generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
1202
+ threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
1203
+ ),
1204
+ ]
1205
+ generator = GenerativeModel(model_name).generate_content(
1206
+ messages,
1207
+ stream=True,
1208
+ generation_config=GenerationConfig(
1209
+ top_p=top_p, max_output_tokens=max_new_tokens, temperature=temperature
1210
+ ),
1211
+ safety_settings=safety_settings,
1212
+ )
1213
+
1214
+ ret = ""
1215
+ for chunk in generator:
1216
+ # NOTE(chris): This may be a vertex api error, below is HOTFIX: https://github.com/googleapis/python-aiplatform/issues/3129
1217
+ ret += chunk.candidates[0].content.parts[0]._raw_part.text
1218
+ # ret += chunk.text
1219
+ data = {
1220
+ "text": ret,
1221
+ "error_code": 0,
1222
+ }
1223
+ yield data
1224
+
1225
+
1226
+ def reka_api_stream_iter(
1227
+ model_name: str,
1228
+ messages: list,
1229
+ temperature: Optional[float] = None,  # the SDK or API handles None for this and all following parameters
1232
+ top_p: Optional[float] = None,
1233
+ max_new_tokens: Optional[int] = None,
1234
+ api_key: Optional[str] = None,  # defaults to env var REKA_API_KEY
1235
+ api_base: Optional[str] = None,
1236
+ ):
1237
+ from reka.client import Reka
1238
+ from reka import TypedText
1239
+
1240
+ api_key = api_key or os.environ["REKA_API_KEY"]
1241
+
1242
+ client = Reka(api_key=api_key)
1243
+
1244
+ use_search_engine = False
1245
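+ # a "-online" suffix selects the search-augmented variant: strip it and enable the search engine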
+ if "-online" in model_name:
1246
+ model_name = model_name.replace("-online", "")
1247
+ use_search_engine = True
1248
+ request = {
1249
+ "model_name": model_name,
1250
+ "conversation_history": messages,
1251
+ "temperature": temperature,
1252
+ "request_output_len": max_new_tokens,
1253
+ "runtime_top_p": top_p,
1254
+ "stream": True,
1255
+ "use_search_engine": use_search_engine,
1256
+ }
1257
+
1258
+ # Make requests for logging
1259
+ text_messages = []
1260
+ for turn in messages:
1261
+ for message in turn.content:
1262
+ if isinstance(message, TypedText):
1263
+ text_messages.append(
1264
+ {"type": message.type, "text": message.text})
1265
+ logged_request = dict(request)
1266
+ logged_request["conversation_history"] = text_messages
1267
+
1268
+ logger.info(f"==== request ====\n{logged_request}")
1269
+
1270
+ response = client.chat.create_stream(
1271
+ messages=messages,
1272
+ max_tokens=max_new_tokens,
1273
+ top_p=top_p,
1274
+ model=model_name,
1275
+ )
1276
+
1277
+ for chunk in response:
1278
+ try:
1279
+ yield {"text": chunk.responses[0].chunk.content, "error_code": 0}
1280
+ except Exception:
1281
+ yield {
1282
+ "text": "**API REQUEST ERROR**",
1283
+ "error_code": 1,
1284
+ }
1285
+
1286
+
1287
+ def metagen_api_stream_iter(
1288
+ model_name,
1289
+ messages,
1290
+ temperature,
1291
+ top_p,
1292
+ max_new_tokens,
1293
+ api_key,
1294
+ api_base,
1295
+ ):
1296
+ try:
1297
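+ # build a text-only copy of the messages for logging; image content parts are filtered out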
+ text_messages = []
1298
+ for message in messages:
1299
+ if type(message["content"]) == str: # text-only model
1300
+ text_messages.append(message)
1301
+ else: # vision model
1302
+ filtered_content_list = [
1303
+ content
1304
+ for content in message["content"]
1305
+ if content["type"] == "text"
1306
+ ]
1307
+ text_messages.append(
1308
+ {"role": message["role"], "content": filtered_content_list}
1309
+ )
1310
+ gen_params = {
1311
+ "model": model_name,
1312
+ "prompt": text_messages,
1313
+ "temperature": temperature,
1314
+ "top_p": top_p,
1315
+ "max_new_tokens": max_new_tokens,
1316
+ }
1317
+ logger.info(f"==== request ====\n{gen_params}")
1318
+
1319
+ res = requests.post(
1320
+ f"{api_base}/chat_stream_completions?access_token={api_key}",
1321
+ stream=True,
1322
+ headers={"Content-Type": "application/json"},
1323
+ json={
1324
+ "model": model_name,
1325
+ "chunks_delimited": True,
1326
+ "messages": messages,
1327
+ "options": {
1328
+ "max_tokens": max_new_tokens,
1329
+ "generation_algorithm": "top_p",
1330
+ "top_p": top_p,
1331
+ "temperature": temperature,
1332
+ },
1333
+ },
1334
+ timeout=30,
1335
+ )
1336
+
1337
+ if res.status_code != 200:
1338
+ logger.error(
1339
+ f"Unexpected response ({res.status_code}): {res.text}")
1340
+ yield {
1341
+ "text": f"**API REQUEST ERROR** Reason: status {res.status_code}.",
1342
+ "error_code": 1,
1343
+ }
+ return  # stop here; there is nothing to stream after a failed request
1344
+
1345
+ text = ""
1346
+ for line in res.iter_lines():
1347
+ if line:
1348
+ part = json.loads(line.decode("utf-8"))
1349
+ if "text" in part:
1350
+ text += part["text"]
1351
+ data = {
1352
+ "text": text,
1353
+ "error_code": 0,
1354
+ }
1355
+ yield data
1356
+ except Exception as e:
1357
+ logger.error(f"==== error ====\n{e}")
1358
+ yield {
1359
+ "text": f"**API REQUEST ERROR** Reason: Unknown.",
1360
+ "error_code": 1,
1361
+ }
serve/constants.py ADDED
@@ -0,0 +1,82 @@
1
+ """
2
+ Global constants.
3
+ """
4
+
5
+ from enum import IntEnum
6
+ import os
7
+
8
+ REPO_PATH = os.path.dirname(os.path.dirname(__file__))
9
+
10
+ # Survey link URL (to be removed)
11
+ SURVEY_LINK = """"""
12
+ # SURVEY_LINK = ""
13
+
14
+ # For the gradio web server
15
+ SERVER_ERROR_MSG = (
16
+ "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
17
+ )
18
+ TEXT_MODERATION_MSG = (
19
+ "$MODERATION$ YOUR TEXT VIOLATES OUR CONTENT MODERATION GUIDELINES."
20
+ )
21
+ IMAGE_MODERATION_MSG = (
22
+ "$MODERATION$ YOUR IMAGE VIOLATES OUR CONTENT MODERATION GUIDELINES."
23
+ )
24
+ MODERATION_MSG = "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES."
25
+ CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION."
26
+ INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE."
27
+ SLOW_MODEL_MSG = (
28
+ "⚠️ Models are thinking. Please stay patient as it may take over a minute."
29
+ )
30
+ RATE_LIMIT_MSG = "**RATE LIMIT OF THIS MODEL IS REACHED. PLEASE COME BACK LATER OR USE <span style='color: red; font-weight: bold;'>[BATTLE MODE](https://lmarena.ai)</span> (the 1st tab).**"
31
+ # Maximum input length
32
+ INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 12000))
33
+ BLIND_MODE_INPUT_CHAR_LEN_LIMIT = int(
34
+ os.getenv("FASTCHAT_BLIND_MODE_INPUT_CHAR_LEN_LIMIT", 30000)
35
+ )
36
+ # Maximum conversation turns
37
+ CONVERSATION_TURN_LIMIT = 50
38
+ # Session expiration time
39
+ SESSION_EXPIRATION_TIME = 3600
40
+ # The output dir of log files
41
+ LOGDIR = os.getenv("LOGDIR", ".")
42
+ # CPU Instruction Set Architecture
43
+ CPU_ISA = os.getenv("CPU_ISA")
44
+
45
+
46
+ # For the controller and workers (could be overwritten through ENV variables.)
47
+ CONTROLLER_HEART_BEAT_EXPIRATION = int(
48
+ os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90)
49
+ )
50
+ WORKER_HEART_BEAT_INTERVAL = int(
51
+ os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 45))
52
+ WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100))
53
+ WORKER_API_EMBEDDING_BATCH_SIZE = int(
54
+ os.getenv("FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE", 4)
55
+ )
56
+
57
+
58
+ class ErrorCode(IntEnum):
59
+ """
60
+ https://platform.openai.com/docs/guides/error-codes/api-errors
61
+ """
62
+
63
+ VALIDATION_TYPE_ERROR = 40001
64
+
65
+ INVALID_AUTH_KEY = 40101
66
+ INCORRECT_AUTH_KEY = 40102
67
+ NO_PERMISSION = 40103
68
+
69
+ INVALID_MODEL = 40301
70
+ PARAM_OUT_OF_RANGE = 40302
71
+ CONTEXT_OVERFLOW = 40303
72
+
73
+ RATE_LIMIT = 42901
74
+ QUOTA_EXCEEDED = 42902
75
+ ENGINE_OVERLOADED = 42903
76
+
77
+ INTERNAL_ERROR = 50001
78
+ CUDA_OUT_OF_MEMORY = 50002
79
+ GRADIO_REQUEST_ERROR = 50003
80
+ GRADIO_STREAM_UNKNOWN_ERROR = 50004
81
+ CONTROLLER_NO_WORKER = 50005
82
+ CONTROLLER_WORKER_TIMEOUT = 50006
serve/conversation.py ADDED
The diff for this file is too large to render. See raw diff
 
serve/gradio_block_arena_anony.py ADDED
@@ -0,0 +1,654 @@
1
+ """
2
+ Chatbot Arena (battle) tab.
3
+ Users chat with two anonymous models.
4
+ """
5
+
6
+ import json
7
+ import time
8
+ import re
9
+
10
+ import gradio as gr
11
+ import numpy as np
12
+
13
+ from .constants import (
14
+ MODERATION_MSG,
15
+ CONVERSATION_LIMIT_MSG,
16
+ SLOW_MODEL_MSG,
17
+ BLIND_MODE_INPUT_CHAR_LEN_LIMIT,
18
+ CONVERSATION_TURN_LIMIT,
19
+ SURVEY_LINK,
20
+ )
21
+ from .gradio_block_arena_named import flash_buttons
22
+ from .gradio_web_server import (
23
+ State,
24
+ bot_response,
25
+ get_conv_log_filename,
26
+ no_change_btn,
27
+ enable_btn,
28
+ disable_btn,
29
+ invisible_btn,
30
+ enable_text,
31
+ disable_text,
32
+ acknowledgment_md,
33
+ get_ip,
34
+ get_model_description_md,
35
+ )
36
+ from .remote_logger import get_remote_logger
37
+ from .utils import (
38
+ build_logger,
39
+ moderation_filter,
40
+ )
41
+
42
+ logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
43
+
44
+ num_sides = 2
45
+ enable_moderation = False
46
+ anony_names = ["", ""]
47
+ models = []
48
+
49
+
50
+ def set_global_vars_anony(enable_moderation_):
51
+ global enable_moderation
52
+ enable_moderation = enable_moderation_
53
+
54
+
55
+ def load_demo_side_by_side_anony(models_, url_params):
56
+ global models
57
+ models = models_
58
+
59
+ states = [None] * num_sides
60
+ selector_updates = [
61
+ gr.Markdown(visible=True),
62
+ gr.Markdown(visible=True),
63
+ ]
64
+
65
+ return states + selector_updates
66
+
67
+
68
+ def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
69
+ with open(get_conv_log_filename(), "a") as fout:
70
+ data = {
71
+ "tstamp": round(time.time(), 4),
72
+ "type": vote_type,
73
+ "models": [x for x in model_selectors],
74
+ "states": [x.dict() for x in states],
75
+ "ip": get_ip(request),
76
+ }
77
+ fout.write(json.dumps(data) + "\n")
78
+ get_remote_logger().log(data)
79
+
80
+ gr.Info(
81
+ "🎉 Thanks for voting! Your vote shapes the leaderboard, please vote RESPONSIBLY."
82
+ )
83
+ if ":" not in model_selectors[0]:
84
+ for i in range(5):
85
+ names = (
86
+ "### Model A: " + states[0].model_name,
87
+ "### Model B: " + states[1].model_name,
88
+ )
89
+ # yield names + ("",) + (disable_btn,) * 4
90
+ yield names + (disable_text,) + (disable_btn,) * 5
91
+ time.sleep(0.1)
92
+ else:
93
+ names = (
94
+ "### Model A: " + states[0].model_name,
95
+ "### Model B: " + states[1].model_name,
96
+ )
97
+ # yield names + ("",) + (disable_btn,) * 4
98
+ yield names + (disable_text,) + (disable_btn,) * 5
99
+
100
+
101
+ def leftvote_last_response(
102
+ state0, state1, model_selector0, model_selector1, request: gr.Request
103
+ ):
104
+ logger.info(f"leftvote (anony). ip: {get_ip(request)}")
105
+ for x in vote_last_response(
106
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
107
+ ):
108
+ yield x
109
+
110
+
111
+ def rightvote_last_response(
112
+ state0, state1, model_selector0, model_selector1, request: gr.Request
113
+ ):
114
+ logger.info(f"rightvote (anony). ip: {get_ip(request)}")
115
+ for x in vote_last_response(
116
+ [state0, state1], "rightvote", [
117
+ model_selector0, model_selector1], request
118
+ ):
119
+ yield x
120
+
121
+
122
+ def tievote_last_response(
123
+ state0, state1, model_selector0, model_selector1, request: gr.Request
124
+ ):
125
+ logger.info(f"tievote (anony). ip: {get_ip(request)}")
126
+ for x in vote_last_response(
127
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
128
+ ):
129
+ yield x
130
+
131
+
132
+ def bothbad_vote_last_response(
133
+ state0, state1, model_selector0, model_selector1, request: gr.Request
134
+ ):
135
+ logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
136
+ for x in vote_last_response(
137
+ [state0, state1], "bothbad_vote", [
138
+ model_selector0, model_selector1], request
139
+ ):
140
+ yield x
141
+
142
+
143
+ def regenerate(state0, state1, request: gr.Request):
144
+ logger.info(f"regenerate (anony). ip: {get_ip(request)}")
145
+ states = [state0, state1]
146
+ if state0.regen_support and state1.regen_support:
147
+ for i in range(num_sides):
148
+ states[i].conv.update_last_message(None)
149
+ return (
150
+ states + [x.to_gradio_chatbot() for x in states] +
151
+ [""] + [disable_btn] * 6
152
+ )
153
+ states[0].skip_next = True
154
+ states[1].skip_next = True
155
+ return states + [x.to_gradio_chatbot() for x in states] + [""] + [no_change_btn] * 6
156
+
157
+
158
+ def clear_history(request: gr.Request):
159
+ logger.info(f"clear_history (anony). ip: {get_ip(request)}")
160
+ return (
161
+ [None] * num_sides
162
+ + [None] * num_sides
163
+ + anony_names
164
+ + [enable_text]
165
+ + [invisible_btn] * 4
166
+ + [disable_btn] * 2
167
+ + [""]
168
+ + [enable_btn]
169
+ )
170
+
171
+
172
+ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request):
173
+ logger.info(f"share (anony). ip: {get_ip(request)}")
174
+ if state0 is not None and state1 is not None:
175
+ vote_last_response(
176
+ [state0, state1], "share", [model_selector0, model_selector1], request
177
+ )
178
+
179
+
180
+ SAMPLING_WEIGHTS = {}
181
+
182
+ # target model sampling weights will be boosted.
183
+ BATTLE_TARGETS = {}
184
+
185
+ BATTLE_STRICT_TARGETS = {}
186
+
187
+ ANON_MODELS = []
188
+
189
+ SAMPLING_BOOST_MODELS = []
190
+
191
+ # outage models won't be sampled.
192
+ OUTAGE_MODELS = []
193
+
194
+
195
+ def get_sample_weight(model, outage_models, sampling_weights, sampling_boost_models=[]):
196
+ if model in outage_models:
197
+ return 0
198
+ weight = sampling_weights.get(model, 0)
199
+ if model in sampling_boost_models:
200
+ weight *= 5
201
+ return weight
202
+
203
+
204
+ def is_model_match_pattern(model, patterns):
205
+ flag = False
206
+ for pattern in patterns:
207
+ pattern = pattern.replace("*", ".*")
208
+ if re.match(pattern, model) is not None:
209
+ flag = True
210
+ break
211
+ return flag
212
+
213
+
214
+ def get_battle_pair(
215
+ models, battle_targets, outage_models, sampling_weights, sampling_boost_models
216
+ ):
217
+ print("models", models)
218
+ print("battle_targets", battle_targets)
219
+ print("outage_models", outage_models)
220
+ print("sampling_weights", sampling_weights)
221
+ print("sampling_boost_models", sampling_boost_models)
222
+ if len(models) == 1:
223
+ return models[0], models[0]
224
+
225
+ model_weights = []
226
+ for model in models:
227
+ weight = get_sample_weight(
228
+ model, outage_models, sampling_weights, sampling_boost_models
229
+ )
230
+ if weight == 0:
231
+ weight += 0.01
232
+ model_weights.append(weight)
233
+ total_weight = np.sum(model_weights)
234
+ print("model_weights", model_weights)
235
+ print("total_weight", total_weight)
236
+ model_weights = model_weights / total_weight
237
+
238
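+ # sample the first model in proportion to its (possibly boosted) weight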
+ chosen_idx = np.random.choice(len(models), p=model_weights)
239
+ chosen_model = models[chosen_idx]
240
+ # for p, w in zip(models, model_weights):
241
+ # print(p, w)
242
+
243
+ rival_models = []
244
+ rival_weights = []
245
+ for model in models:
246
+ if model == chosen_model:
247
+ continue
248
+ if model in ANON_MODELS and chosen_model in ANON_MODELS:
249
+ continue
250
+ if chosen_model in BATTLE_STRICT_TARGETS:
251
+ if not is_model_match_pattern(model, BATTLE_STRICT_TARGETS[chosen_model]):
252
+ continue
253
+ if model in BATTLE_STRICT_TARGETS:
254
+ if not is_model_match_pattern(chosen_model, BATTLE_STRICT_TARGETS[model]):
255
+ continue
256
+ weight = get_sample_weight(model, outage_models, sampling_weights)
257
+ if (
258
+ weight != 0
259
+ and chosen_model in battle_targets
260
+ and model in battle_targets[chosen_model]
261
+ ):
262
+ # boost battle targets to half of the total weight, split evenly among them
263
+ weight = 0.5 * total_weight / len(battle_targets[chosen_model])
264
+ rival_models.append(model)
265
+ if weight == 0:
266
+ weight += 0.01
267
+ rival_weights.append(weight)
268
+ # for p, w in zip(rival_models, rival_weights):
269
+ # print(p, w)
270
+ rival_weights = rival_weights / np.sum(rival_weights)
271
+ rival_idx = np.random.choice(len(rival_models), p=rival_weights)
272
+ rival_model = rival_models[rival_idx]
273
+
274
+ swap = np.random.randint(2)
275
+ if swap == 0:
276
+ return chosen_model, rival_model
277
+ else:
278
+ return rival_model, chosen_model
279
+
280
+
281
+ def add_text(
282
+ state0, state1, model_selector0, model_selector1, text, request: gr.Request
283
+ ):
284
+ ip = get_ip(request)
285
+ logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
286
+ states = [state0, state1]
287
+ model_selectors = [model_selector0, model_selector1]
288
+
289
+ # Init states if necessary
290
+ if states[0] is None:
291
+ assert states[1] is None
292
+
293
+ model_left, model_right = get_battle_pair(
294
+ models,
295
+ BATTLE_TARGETS,
296
+ OUTAGE_MODELS,
297
+ SAMPLING_WEIGHTS,
298
+ SAMPLING_BOOST_MODELS,
299
+ )
300
+ states = [
301
+ State(model_left),
302
+ State(model_right),
303
+ ]
304
+
305
+ if len(text) <= 0:
306
+ for i in range(num_sides):
307
+ states[i].skip_next = True
308
+ return (
309
+ states
310
+ + [x.to_gradio_chatbot() for x in states]
311
+ + ["", None]
312
+ + [
313
+ no_change_btn,
314
+ ]
315
+ * 6
316
+ + [""]
317
+ )
318
+
319
+ model_list = [states[i].model_name for i in range(num_sides)]
320
+ # turn on moderation in battle mode
321
+ all_conv_text_left = states[0].conv.get_prompt()
322
+ all_conv_text_right = states[1].conv.get_prompt()
323
+ all_conv_text = (
324
+ all_conv_text_left[-1000:] +
325
+ all_conv_text_right[-1000:] + "\nuser: " + text
326
+ )
327
+ flagged = moderation_filter(all_conv_text, model_list, do_moderation=True)
328
+ if flagged:
329
+ logger.info(f"violate moderation (anony). ip: {ip}. text: {text}")
330
+ # overwrite the original text
331
+ text = MODERATION_MSG
332
+
333
+ conv = states[0].conv
334
+ if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
335
+ logger.info(
336
+ f"conversation turn limit. ip: {get_ip(request)}. text: {text}")
337
+ for i in range(num_sides):
338
+ states[i].skip_next = True
339
+ return (
340
+ states
341
+ + [x.to_gradio_chatbot() for x in states]
342
+ + [CONVERSATION_LIMIT_MSG]
343
+ + [
344
+ no_change_btn,
345
+ ]
346
+ * 6
347
+ + [""]
348
+ )
349
+
350
+ text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
351
+ for i in range(num_sides):
352
+ states[i].conv.append_message(states[i].conv.roles[0], text)
353
+ states[i].conv.append_message(states[i].conv.roles[1], None)
354
+ states[i].skip_next = False
355
+
356
+ hint_msg = ""
357
+ for i in range(num_sides):
358
+ if "deluxe" in states[i].model_name:
359
+ hint_msg = SLOW_MODEL_MSG
360
+ return (
361
+ states
362
+ + [x.to_gradio_chatbot() for x in states]
363
+ + [""]
364
+ + [
365
+ disable_btn,
366
+ ]
367
+ * 6
368
+ + [hint_msg]
369
+ )
370
+
371
+
372
+ def bot_response_multi(
373
+ state0,
374
+ state1,
375
+ temperature,
376
+ top_p,
377
+ max_new_tokens,
378
+ request: gr.Request,
379
+ ):
380
+ logger.info(f"bot_response_multi (anony). ip: {get_ip(request)}")
381
+
382
+ if state0 is None or state0.skip_next:
383
+ # This generate call is skipped due to invalid inputs
384
+ yield (
385
+ state0,
386
+ state1,
387
+ state0.to_gradio_chatbot(),
388
+ state1.to_gradio_chatbot(),
389
+ ) + (no_change_btn,) * 6
390
+ return
391
+
392
+ states = [state0, state1]
393
+ gen = []
394
+ for i in range(num_sides):
395
+ gen.append(
396
+ bot_response(
397
+ states[i],
398
+ temperature,
399
+ top_p,
400
+ max_new_tokens,
401
+ request,
402
+ apply_rate_limit=False,
403
+ use_recommended_config=True,
404
+ )
405
+ )
406
+
407
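+ # models that stream in large chunks get a higher tokens-per-yield so the UI updates less often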
+ model_tpy = []
408
+ for i in range(num_sides):
409
+ token_per_yield = 1
410
+ if states[i].model_name in [
411
+ "gemini-pro",
412
+ "gemma-1.1-2b-it",
413
+ "gemma-1.1-7b-it",
414
+ "phi-3-mini-4k-instruct",
415
+ "phi-3-mini-128k-instruct",
416
+ "snowflake-arctic-instruct",
417
+ ]:
418
+ token_per_yield = 30
419
+ elif states[i].model_name in [
420
+ "qwen-max-0428",
421
+ "qwen-vl-max-0809",
422
+ "qwen1.5-110b-chat",
423
+ "llava-v1.6-34b",
424
+ ]:
425
+ token_per_yield = 7
426
+ elif states[i].model_name in [
427
+ "qwen2.5-72b-instruct",
428
+ "qwen2-72b-instruct",
429
+ "qwen-plus-0828",
430
+ "qwen-max-0919",
431
+ "llama-3.1-405b-instruct-bf16",
432
+ ]:
433
+ token_per_yield = 4
434
+ model_tpy.append(token_per_yield)
435
+
436
+ chatbots = [None] * num_sides
437
+ iters = 0
438
+ while True:
439
+ stop = True
440
+ iters += 1
441
+ for i in range(num_sides):
442
+ try:
443
+ # yield fewer times if chunk size is larger
444
+ if model_tpy[i] == 1 or (iters % model_tpy[i] == 1 or iters < 3):
445
+ ret = next(gen[i])
446
+ states[i], chatbots[i] = ret[0], ret[1]
447
+ stop = False
448
+ except StopIteration:
449
+ pass
450
+ yield states + chatbots + [disable_btn] * 6
451
+ if stop:
452
+ break
453
+
454
+
455
+ def build_side_by_side_ui_anony(models):
456
+ notice_markdown = f"""
457
+ # ⚔️ Chatbot Arena Japanese Version (alpha)
458
+
459
+ {SURVEY_LINK}
460
+
461
+ ## 📣 News
462
+ - Evaluation results are available [here](https://huggingface.co/datasets/kanhatakeyama/chatbot-arena-ja-elo-rating).
463
+ ## 👇 Chat now!
464
+ """
465
+
466
+ states = [gr.State() for _ in range(num_sides)]
467
+ model_selectors = [None] * num_sides
468
+ chatbots = [None] * num_sides
469
+
470
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
471
+
472
+ with gr.Group(elem_id="share-region-anony"):
473
+ with gr.Accordion(
474
+ f"🔍 Expand to see the descriptions of {len(models)} models", open=False
475
+ ):
476
+ model_description_md = get_model_description_md(models)
477
+ gr.Markdown(model_description_md,
478
+ elem_id="model_description_markdown")
479
+ with gr.Row():
480
+ for i in range(num_sides):
481
+ label = "Model A" if i == 0 else "Model B"
482
+ with gr.Column():
483
+ chatbots[i] = gr.Chatbot(
484
+ label=label,
485
+ elem_id="chatbot",
486
+ height=650,
487
+ show_copy_button=True,
488
+ )
489
+
490
+ with gr.Row():
491
+ for i in range(num_sides):
492
+ with gr.Column():
493
+ model_selectors[i] = gr.Markdown(
494
+ anony_names[i], elem_id="model_selector_md"
495
+ )
496
+ with gr.Row():
497
+ slow_warning = gr.Markdown("")
498
+
499
+ with gr.Row():
500
+ leftvote_btn = gr.Button(
501
+ value="👈 A is better", visible=False, interactive=False
502
+ )
503
+ rightvote_btn = gr.Button(
504
+ value="👉 B is better", visible=False, interactive=False
505
+ )
506
+ tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
507
+ bothbad_btn = gr.Button(
508
+ value="👎 Both are bad", visible=False, interactive=False
509
+ )
510
+
511
+ with gr.Row():
512
+ textbox = gr.Textbox(
513
+ show_label=False,
514
+ placeholder="👉 Enter your prompt and press ENTER",
515
+ elem_id="input_box",
516
+ )
517
+ send_btn = gr.Button(value="Send", variant="primary", scale=0)
518
+
519
+ with gr.Row() as button_row:
520
+ clear_btn = gr.Button(value="🎲 New Round", interactive=False)
521
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
522
+ share_btn = gr.Button(value="📷 Share")
523
+
524
+ with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
525
+ temperature = gr.Slider(
526
+ minimum=0.0,
527
+ maximum=1.0,
528
+ value=0.7,
529
+ step=0.1,
530
+ interactive=True,
531
+ label="Temperature",
532
+ )
533
+ top_p = gr.Slider(
534
+ minimum=0.0,
535
+ maximum=1.0,
536
+ value=1.0,
537
+ step=0.1,
538
+ interactive=True,
539
+ label="Top P",
540
+ )
541
+ max_output_tokens = gr.Slider(
542
+ minimum=16,
543
+ maximum=2048,
544
+ value=2000,
545
+ step=64,
546
+ interactive=True,
547
+ label="Max output tokens",
548
+ )
549
+
550
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
551
+
552
+ # Register listeners
553
+ btn_list = [
554
+ leftvote_btn,
555
+ rightvote_btn,
556
+ tie_btn,
557
+ bothbad_btn,
558
+ regenerate_btn,
559
+ clear_btn,
560
+ ]
561
+ leftvote_btn.click(
562
+ leftvote_last_response,
563
+ states + model_selectors,
564
+ model_selectors
565
+ + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, send_btn],
566
+ )
567
+ rightvote_btn.click(
568
+ rightvote_last_response,
569
+ states + model_selectors,
570
+ model_selectors
571
+ + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, send_btn],
572
+ )
573
+ tie_btn.click(
574
+ tievote_last_response,
575
+ states + model_selectors,
576
+ model_selectors
577
+ + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, send_btn],
578
+ )
579
+ bothbad_btn.click(
580
+ bothbad_vote_last_response,
581
+ states + model_selectors,
582
+ model_selectors
583
+ + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, send_btn],
584
+ )
585
+ regenerate_btn.click(
586
+ regenerate, states, states + chatbots + [textbox] + btn_list
587
+ ).then(
588
+ bot_response_multi,
589
+ states + [temperature, top_p, max_output_tokens],
590
+ states + chatbots + btn_list,
591
+ ).then(
592
+ flash_buttons, [], btn_list
593
+ )
594
+ clear_btn.click(
595
+ clear_history,
596
+ None,
597
+ states
598
+ + chatbots
599
+ + model_selectors
600
+ + [textbox]
601
+ + btn_list
602
+ + [slow_warning]
603
+ + [send_btn],
604
+ )
605
+
606
+ share_js = """
607
+ function (a, b, c, d) {
608
+ const captureElement = document.querySelector('#share-region-anony');
609
+ html2canvas(captureElement)
610
+ .then(canvas => {
611
+ canvas.style.display = 'none'
612
+ document.body.appendChild(canvas)
613
+ return canvas
614
+ })
615
+ .then(canvas => {
616
+ const image = canvas.toDataURL('image/png')
617
+ const a = document.createElement('a')
618
+ a.setAttribute('download', 'chatbot-arena.png')
619
+ a.setAttribute('href', image)
620
+ a.click()
621
+ canvas.remove()
622
+ });
623
+ return [a, b, c, d];
624
+ }
625
+ """
626
+ share_btn.click(share_click, states + model_selectors, [], js=share_js)
627
+
628
+ textbox.submit(
629
+ add_text,
630
+ states + model_selectors + [textbox],
631
+ states + chatbots + [textbox] + btn_list + [slow_warning],
632
+ ).then(
633
+ bot_response_multi,
634
+ states + [temperature, top_p, max_output_tokens],
635
+ states + chatbots + btn_list,
636
+ ).then(
637
+ flash_buttons,
638
+ [],
639
+ btn_list,
640
+ )
641
+
642
+ send_btn.click(
643
+ add_text,
644
+ states + model_selectors + [textbox],
645
+ states + chatbots + [textbox] + btn_list,
646
+ ).then(
647
+ bot_response_multi,
648
+ states + [temperature, top_p, max_output_tokens],
649
+ states + chatbots + btn_list,
650
+ ).then(
651
+ flash_buttons, [], btn_list
652
+ )
653
+
654
+ return states + model_selectors
serve/gradio_block_arena_named.py ADDED
@@ -0,0 +1,510 @@
1
+ """
2
+ Chatbot Arena (side-by-side) tab.
3
+ Users chat with two chosen models.
4
+ """
5
+
6
+ import json
7
+ import time
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+
12
+ from .constants import (
13
+ MODERATION_MSG,
14
+ CONVERSATION_LIMIT_MSG,
15
+ INPUT_CHAR_LEN_LIMIT,
16
+ CONVERSATION_TURN_LIMIT,
17
+ SURVEY_LINK,
18
+ )
19
+ from .gradio_web_server import (
20
+ State,
21
+ bot_response,
22
+ get_conv_log_filename,
23
+ no_change_btn,
24
+ enable_btn,
25
+ disable_btn,
26
+ invisible_btn,
27
+ acknowledgment_md,
28
+ get_ip,
29
+ get_model_description_md,
30
+ )
31
+ from .remote_logger import get_remote_logger
32
+ from .utils import (
33
+ build_logger,
34
+ moderation_filter,
35
+ )
36
+
37
+ logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
38
+
39
+ num_sides = 2
40
+ enable_moderation = False
41
+
42
+
43
+ def set_global_vars_named(enable_moderation_):
44
+ global enable_moderation
45
+ enable_moderation = enable_moderation_
46
+
47
+
48
+ def load_demo_side_by_side_named(models, url_params):
49
+ states = [None] * num_sides
50
+
51
+ model_left = models[0] if len(models) > 0 else ""
52
+ if len(models) > 1:
53
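+ # bias the default right-hand model toward the first few entries of the model list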
+ weights = ([8] * 4 + [4] * 8 + [1] * 64)[: len(models) - 1]
54
+ weights = weights / np.sum(weights)
55
+ model_right = np.random.choice(models[1:], p=weights)
56
+ else:
57
+ model_right = model_left
58
+
59
+ selector_updates = [
60
+ gr.Dropdown(choices=models, value=model_left, visible=True),
61
+ gr.Dropdown(choices=models, value=model_right, visible=True),
62
+ ]
63
+
64
+ return states + selector_updates
65
+
66
+
67
+ def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
68
+ with open(get_conv_log_filename(), "a") as fout:
69
+ data = {
70
+ "tstamp": round(time.time(), 4),
71
+ "type": vote_type,
72
+ "models": [x for x in model_selectors],
73
+ "states": [x.dict() for x in states],
74
+ "ip": get_ip(request),
75
+ }
76
+ fout.write(json.dumps(data) + "\n")
77
+ get_remote_logger().log(data)
78
+
79
+
80
+ def leftvote_last_response(
81
+ state0, state1, model_selector0, model_selector1, request: gr.Request
82
+ ):
83
+ logger.info(f"leftvote (named). ip: {get_ip(request)}")
84
+ vote_last_response(
85
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
86
+ )
87
+ return ("",) + (disable_btn,) * 4
88
+
89
+
90
+ def rightvote_last_response(
91
+ state0, state1, model_selector0, model_selector1, request: gr.Request
92
+ ):
93
+ logger.info(f"rightvote (named). ip: {get_ip(request)}")
94
+ vote_last_response(
95
+ [state0, state1], "rightvote", [
96
+ model_selector0, model_selector1], request
97
+ )
98
+ return ("",) + (disable_btn,) * 4
99
+
100
+
101
+ def tievote_last_response(
102
+ state0, state1, model_selector0, model_selector1, request: gr.Request
103
+ ):
104
+ logger.info(f"tievote (named). ip: {get_ip(request)}")
105
+ vote_last_response(
106
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
107
+ )
108
+ return ("",) + (disable_btn,) * 4
109
+
110
+
111
+ def bothbad_vote_last_response(
112
+ state0, state1, model_selector0, model_selector1, request: gr.Request
113
+ ):
114
+ logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
115
+ vote_last_response(
116
+ [state0, state1], "bothbad_vote", [
117
+ model_selector0, model_selector1], request
118
+ )
119
+ return ("",) + (disable_btn,) * 4
120
+
121
+
122
+ def regenerate(state0, state1, request: gr.Request):
123
+ logger.info(f"regenerate (named). ip: {get_ip(request)}")
124
+ states = [state0, state1]
125
+ if state0.regen_support and state1.regen_support:
126
+ for i in range(num_sides):
127
+ states[i].conv.update_last_message(None)
128
+ return (
129
+ states + [x.to_gradio_chatbot() for x in states] +
130
+ [""] + [disable_btn] * 6
131
+ )
132
+ states[0].skip_next = True
133
+ states[1].skip_next = True
134
+ return states + [x.to_gradio_chatbot() for x in states] + [""] + [no_change_btn] * 6
135
+
136
+
137
+ def clear_history(request: gr.Request):
138
+ logger.info(f"clear_history (named). ip: {get_ip(request)}")
139
+ return (
140
+ [None] * num_sides
141
+ + [None] * num_sides
142
+ + [""]
143
+ + [invisible_btn] * 4
144
+ + [disable_btn] * 2
145
+ )
146
+
147
+
148
+ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request):
149
+ logger.info(f"share (named). ip: {get_ip(request)}")
150
+ if state0 is not None and state1 is not None:
151
+ vote_last_response(
152
+ [state0, state1], "share", [model_selector0, model_selector1], request
153
+ )
154
+
155
+
156
+ def add_text(
157
+ state0, state1, model_selector0, model_selector1, text, request: gr.Request
158
+ ):
159
+ ip = get_ip(request)
160
+ logger.info(f"add_text (named). ip: {ip}. len: {len(text)}")
161
+ states = [state0, state1]
162
+ model_selectors = [model_selector0, model_selector1]
163
+
164
+ # Init states if necessary
165
+ for i in range(num_sides):
166
+ if states[i] is None:
167
+ states[i] = State(model_selectors[i])
168
+
169
+ if len(text) <= 0:
170
+ for i in range(num_sides):
171
+ states[i].skip_next = True
172
+ return (
173
+ states
174
+ + [x.to_gradio_chatbot() for x in states]
175
+ + ["", None]
176
+ + [
177
+ no_change_btn,
178
+ ]
179
+ * 6
180
+ )
181
+
182
+ model_list = [states[i].model_name for i in range(num_sides)]
183
+ all_conv_text_left = states[0].conv.get_prompt()
184
+ all_conv_text_right = states[1].conv.get_prompt()
185
+ all_conv_text = (
186
+ all_conv_text_left[-1000:] +
187
+ all_conv_text_right[-1000:] + "\nuser: " + text
188
+ )
189
+ flagged = moderation_filter(all_conv_text, model_list)
190
+ if flagged:
191
+ logger.info(f"violate moderation (named). ip: {ip}. text: {text}")
192
+ # overwrite the original text
193
+ text = MODERATION_MSG
194
+
195
+ conv = states[0].conv
196
+ if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
197
+ logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
198
+ for i in range(num_sides):
199
+ states[i].skip_next = True
200
+ return (
201
+ states
202
+ + [x.to_gradio_chatbot() for x in states]
203
+ + [CONVERSATION_LIMIT_MSG]
204
+ + [
205
+ no_change_btn,
206
+ ]
207
+ * 6
208
+ )
209
+
210
+ text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
211
+ for i in range(num_sides):
212
+ states[i].conv.append_message(states[i].conv.roles[0], text)
213
+ states[i].conv.append_message(states[i].conv.roles[1], None)
214
+ states[i].skip_next = False
215
+
216
+ return (
217
+ states
218
+ + [x.to_gradio_chatbot() for x in states]
219
+ + [""]
220
+ + [
221
+ disable_btn,
222
+ ]
223
+ * 6
224
+ )
225
+
226
+
227
+ def bot_response_multi(
228
+ state0,
229
+ state1,
230
+ temperature,
231
+ top_p,
232
+ max_new_tokens,
233
+ request: gr.Request,
234
+ ):
235
+ logger.info(f"bot_response_multi (named). ip: {get_ip(request)}")
236
+
237
+ if state0.skip_next:
238
+ # This generate call is skipped due to invalid inputs
239
+ yield (
240
+ state0,
241
+ state1,
242
+ state0.to_gradio_chatbot(),
243
+ state1.to_gradio_chatbot(),
244
+ ) + (no_change_btn,) * 6
245
+ return
246
+
247
+ states = [state0, state1]
248
+ gen = []
249
+ for i in range(num_sides):
250
+ gen.append(
251
+ bot_response(
252
+ states[i],
253
+ temperature,
254
+ top_p,
255
+ max_new_tokens,
256
+ request,
257
+ )
258
+ )
259
+
260
+ model_tpy = []
261
+ for i in range(num_sides):
262
+ token_per_yield = 1
263
+ if states[i].model_name in [
264
+ "gemini-pro",
265
+ "gemma-1.1-2b-it",
266
+ "gemma-1.1-7b-it",
267
+ "phi-3-mini-4k-instruct",
268
+ "phi-3-mini-128k-instruct",
269
+ "snowflake-arctic-instruct",
270
+ ]:
271
+ token_per_yield = 30
272
+ elif states[i].model_name in [
273
+ "qwen-max-0428",
274
+ "qwen-vl-max-0809",
275
+ "qwen1.5-110b-chat",
276
+ ]:
277
+ token_per_yield = 7
278
+ elif states[i].model_name in [
279
+ "qwen2.5-72b-instruct",
280
+ "qwen2-72b-instruct",
281
+ "qwen-plus-0828",
282
+ "qwen-max-0919",
283
+ "llama-3.1-405b-instruct-bf16",
284
+ ]:
285
+ token_per_yield = 4
286
+ model_tpy.append(token_per_yield)
287
+
288
+ chatbots = [None] * num_sides
289
+ iters = 0
290
+ while True:
291
+ stop = True
292
+ iters += 1
293
+ for i in range(num_sides):
294
+ try:
295
+ # yield fewer times if chunk size is larger
296
+ if model_tpy[i] == 1 or (iters % model_tpy[i] == 1 or iters < 3):
297
+ ret = next(gen[i])
298
+ states[i], chatbots[i] = ret[0], ret[1]
299
+ stop = False
300
+ except StopIteration:
301
+ pass
302
+ yield states + chatbots + [disable_btn] * 6
303
+ if stop:
304
+ break
305
+
306
+
307
+ def flash_buttons():
308
+ btn_updates = [
309
+ [disable_btn] * 4 + [enable_btn] * 2,
310
+ [enable_btn] * 6,
311
+ ]
312
+ for i in range(4):
313
+ yield btn_updates[i % 2]
314
+ time.sleep(0.3)
315
+
316
+
317
+ def build_side_by_side_ui_named(models):
318
+ notice_markdown = f"""
319
+ # ⚔️ Chatbot Arena (formerly LMSYS): Free AI Chat to Compare & Test Best AI Chatbots
320
+
321
+ {SURVEY_LINK}
322
+
323
+ ## 📜 How It Works
324
+ - Ask any question to two chosen models (e.g., ChatGPT, Gemini, Claude, Llama) and vote for the better one!
325
+ - You can chat for multiple turns until you identify a winner.
326
+
327
+ ## 👇 Choose two models to compare
328
+ """
329
+
330
+ states = [gr.State() for _ in range(num_sides)]
331
+ model_selectors = [None] * num_sides
332
+ chatbots = [None] * num_sides
333
+
334
+ notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")
335
+
336
+ with gr.Group(elem_id="share-region-named"):
337
+ with gr.Row():
338
+ for i in range(num_sides):
339
+ with gr.Column():
340
+ model_selectors[i] = gr.Dropdown(
341
+ choices=models,
342
+ value=models[i] if len(models) > i else "",
343
+ interactive=True,
344
+ show_label=False,
345
+ container=False,
346
+ )
347
+ with gr.Row():
348
+ with gr.Accordion(
349
+ f"🔍 Expand to see the descriptions of {len(models)} models", open=False
350
+ ):
351
+ model_description_md = get_model_description_md(models)
352
+ gr.Markdown(model_description_md,
353
+ elem_id="model_description_markdown")
354
+
355
+ with gr.Row():
356
+ for i in range(num_sides):
357
+ label = "Model A" if i == 0 else "Model B"
358
+ with gr.Column():
359
+ chatbots[i] = gr.Chatbot(
360
+ label=label,
361
+ elem_id=f"chatbot",
362
+ height=650,
363
+ show_copy_button=True,
364
+ )
365
+
366
+ with gr.Row():
367
+ leftvote_btn = gr.Button(
368
+ value="👈 A is better", visible=False, interactive=False
369
+ )
370
+ rightvote_btn = gr.Button(
371
+ value="👉 B is better", visible=False, interactive=False
372
+ )
373
+ tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
374
+ bothbad_btn = gr.Button(
375
+ value="👎 Both are bad", visible=False, interactive=False
376
+ )
377
+
378
+ with gr.Row():
379
+ textbox = gr.Textbox(
380
+ show_label=False,
381
+ placeholder="👉 Enter your prompt and press ENTER",
382
+ elem_id="input_box",
383
+ )
384
+ send_btn = gr.Button(value="Send", variant="primary", scale=0)
385
+
386
+ with gr.Row() as button_row:
387
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
388
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
389
+ share_btn = gr.Button(value="📷 Share")
390
+
391
+ with gr.Accordion("Parameters", open=False) as parameter_row:
392
+ temperature = gr.Slider(
393
+ minimum=0.0,
394
+ maximum=1.0,
395
+ value=0.7,
396
+ step=0.1,
397
+ interactive=True,
398
+ label="Temperature",
399
+ )
400
+ top_p = gr.Slider(
401
+ minimum=0.0,
402
+ maximum=1.0,
403
+ value=1.0,
404
+ step=0.1,
405
+ interactive=True,
406
+ label="Top P",
407
+ )
408
+ max_output_tokens = gr.Slider(
409
+ minimum=16,
410
+ maximum=2048,
411
+ value=1024,
412
+ step=64,
413
+ interactive=True,
414
+ label="Max output tokens",
415
+ )
416
+
417
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
418
+
419
+ # Register listeners
420
+ btn_list = [
421
+ leftvote_btn,
422
+ rightvote_btn,
423
+ tie_btn,
424
+ bothbad_btn,
425
+ regenerate_btn,
426
+ clear_btn,
427
+ ]
428
+ leftvote_btn.click(
429
+ leftvote_last_response,
430
+ states + model_selectors,
431
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
432
+ )
433
+ rightvote_btn.click(
434
+ rightvote_last_response,
435
+ states + model_selectors,
436
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
437
+ )
438
+ tie_btn.click(
439
+ tievote_last_response,
440
+ states + model_selectors,
441
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
442
+ )
443
+ bothbad_btn.click(
444
+ bothbad_vote_last_response,
445
+ states + model_selectors,
446
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
447
+ )
448
+ regenerate_btn.click(
449
+ regenerate, states, states + chatbots + [textbox] + btn_list
450
+ ).then(
451
+ bot_response_multi,
452
+ states + [temperature, top_p, max_output_tokens],
453
+ states + chatbots + btn_list,
454
+ ).then(
455
+ flash_buttons, [], btn_list
456
+ )
457
+ clear_btn.click(clear_history, None, states +
458
+ chatbots + [textbox] + btn_list)
459
+
460
+ share_js = """
461
+ function (a, b, c, d) {
462
+ const captureElement = document.querySelector('#share-region-named');
463
+ html2canvas(captureElement)
464
+ .then(canvas => {
465
+ canvas.style.display = 'none'
466
+ document.body.appendChild(canvas)
467
+ return canvas
468
+ })
469
+ .then(canvas => {
470
+ const image = canvas.toDataURL('image/png')
471
+ const a = document.createElement('a')
472
+ a.setAttribute('download', 'chatbot-arena.png')
473
+ a.setAttribute('href', image)
474
+ a.click()
475
+ canvas.remove()
476
+ });
477
+ return [a, b, c, d];
478
+ }
479
+ """
480
+ share_btn.click(share_click, states + model_selectors, [], js=share_js)
481
+
482
+ for i in range(num_sides):
483
+ model_selectors[i].change(
484
+ clear_history, None, states + chatbots + [textbox] + btn_list
485
+ )
486
+
487
+ textbox.submit(
488
+ add_text,
489
+ states + model_selectors + [textbox],
490
+ states + chatbots + [textbox] + btn_list,
491
+ ).then(
492
+ bot_response_multi,
493
+ states + [temperature, top_p, max_output_tokens],
494
+ states + chatbots + btn_list,
495
+ ).then(
496
+ flash_buttons, [], btn_list
497
+ )
498
+ send_btn.click(
499
+ add_text,
500
+ states + model_selectors + [textbox],
501
+ states + chatbots + [textbox] + btn_list,
502
+ ).then(
503
+ bot_response_multi,
504
+ states + [temperature, top_p, max_output_tokens],
505
+ states + chatbots + btn_list,
506
+ ).then(
507
+ flash_buttons, [], btn_list
508
+ )
509
+
510
+ return states + model_selectors
serve/gradio_block_arena_vision.py ADDED
@@ -0,0 +1,508 @@
1
+ """
2
+ The gradio demo server for chatting with a large multimodal model.
3
+
4
+ Usage:
5
+ python3 -m fastchat.serve.controller
6
+ python3 -m fastchat.serve.sglang_worker --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf
7
+ python3 -m fastchat.serve.gradio_web_server_multi --share --vision-arena
8
+ """
9
+
10
+ import json
11
+ import os
12
+ import time
13
+ from typing import List, Union
14
+
15
+ import gradio as gr
16
+ from gradio.data_classes import FileData
17
+ import numpy as np
18
+
19
+ from .constants import (
20
+ TEXT_MODERATION_MSG,
21
+ IMAGE_MODERATION_MSG,
22
+ MODERATION_MSG,
23
+ CONVERSATION_LIMIT_MSG,
24
+ INPUT_CHAR_LEN_LIMIT,
25
+ CONVERSATION_TURN_LIMIT,
26
+ SURVEY_LINK,
27
+ )
28
+ # from fastchat.model.model_adapter import (
29
+ # get_conversation_template,
30
+ # )
31
+ from .gradio_global_state import Context
32
+ from .gradio_web_server import (
33
+ get_model_description_md,
34
+ acknowledgment_md,
35
+ bot_response,
36
+ get_ip,
37
+ disable_btn,
38
+ State,
39
+ get_conv_log_filename,
40
+ get_remote_logger,
41
+ )
42
+ # from fastchat.serve.vision.image import ImageFormat, Image
43
+ from .utils import (
44
+ build_logger,
45
+ moderation_filter,
46
+ image_moderation_filter,
47
+ )
48
+
49
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
50
+
51
+ no_change_btn = gr.Button()
52
+ enable_btn = gr.Button(interactive=True, visible=True)
53
+ disable_btn = gr.Button(interactive=False)
54
+ invisible_btn = gr.Button(interactive=False, visible=False)
55
+ visible_image_column = gr.Image(visible=True)
56
+ invisible_image_column = gr.Image(visible=False)
57
+ enable_multimodal = gr.MultimodalTextbox(
58
+ interactive=True, visible=True, placeholder="Enter your prompt or add image here"
59
+ )
60
+ invisible_text = gr.Textbox(visible=False, value="", interactive=False)
61
+ visible_text = gr.Textbox(
62
+ visible=True,
63
+ value="",
64
+ interactive=True,
65
+ placeholder="👉 Enter your prompt and press ENTER",
66
+ )
67
+ disable_multimodal = gr.MultimodalTextbox(
68
+ visible=False, value=None, interactive=False)
69
+
70
+
71
+ def get_vqa_sample():
72
+ random_sample = np.random.choice(vqa_samples)
73
+ question, path = random_sample["question"], random_sample["path"]
74
+ res = {"text": "", "files": [path]}
75
+ return (res, path)
76
+
77
+
78
+ def set_visible_image(textbox):
79
+ images = textbox["files"]
80
+ if len(images) == 0:
81
+ return invisible_image_column
82
+ elif len(images) > 1:
83
+ gr.Warning(
84
+ "We only support single image conversations. Please start a new round if you would like to chat using this image."
85
+ )
86
+
87
+ return visible_image_column
88
+
89
+
90
+ def set_invisible_image():
91
+ return invisible_image_column
92
+
93
+
94
+ def add_image(textbox):
95
+ images = textbox["files"]
96
+ if len(images) == 0:
97
+ return None
98
+
99
+ return images[0]
100
+
101
+
102
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
103
+ filename = get_conv_log_filename(state.is_vision, state.has_csam_image)
104
+ with open(filename, "a") as fout:
105
+ data = {
106
+ "tstamp": round(time.time(), 4),
107
+ "type": vote_type,
108
+ "model": model_selector,
109
+ "state": state.dict(),
110
+ "ip": get_ip(request),
111
+ }
112
+ fout.write(json.dumps(data) + "\n")
113
+ get_remote_logger().log(data)
114
+
115
+
116
+ def upvote_last_response(state, model_selector, request: gr.Request):
117
+ ip = get_ip(request)
118
+ logger.info(f"upvote. ip: {ip}")
119
+ vote_last_response(state, "upvote", model_selector, request)
120
+ return (None,) + (disable_btn,) * 3
121
+
122
+
123
+ def downvote_last_response(state, model_selector, request: gr.Request):
124
+ ip = get_ip(request)
125
+ logger.info(f"downvote. ip: {ip}")
126
+ vote_last_response(state, "downvote", model_selector, request)
127
+ return (None,) + (disable_btn,) * 3
128
+
129
+
130
+ def flag_last_response(state, model_selector, request: gr.Request):
131
+ ip = get_ip(request)
132
+ logger.info(f"flag. ip: {ip}")
133
+ vote_last_response(state, "flag", model_selector, request)
134
+ return (None,) + (disable_btn,) * 3
135
+
136
+
137
+ def regenerate(state, request: gr.Request):
138
+ ip = get_ip(request)
139
+ logger.info(f"regenerate. ip: {ip}")
140
+ if not state.regen_support:
141
+ state.skip_next = True
142
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
143
+ state.conv.update_last_message(None)
144
+ return (state, state.to_gradio_chatbot(), None) + (disable_btn,) * 5
145
+
146
+
147
+ def clear_history(request: gr.Request):
148
+ ip = get_ip(request)
149
+ logger.info(f"clear_history. ip: {ip}")
150
+ state = None
151
+ return (state, [], enable_multimodal, invisible_text, invisible_btn) + (
152
+ disable_btn,
153
+ ) * 5
154
+
155
+
156
+ def clear_history_example(request: gr.Request):
157
+ ip = get_ip(request)
158
+ logger.info(f"clear_history_example. ip: {ip}")
159
+ state = None
160
+ return (state, [], enable_multimodal, invisible_text, invisible_btn) + (
161
+ disable_btn,
162
+ ) * 5
163
+
164
+
165
+ # TODO(Chris): At some point, we would like this to be a live-reporting feature.
166
+ def report_csam_image(state, image):
167
+ pass
168
+
169
+
170
+ def _prepare_text_with_image(state, text, images, csam_flag):
171
+ if len(images) > 0:
172
+ if len(state.conv.get_images()) > 0:
173
+ # reset convo with new image
174
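+ # NOTE: get_conversation_template is not imported in this file (its import above is commented out), so this reset path would raise NameError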
+ state.conv = get_conversation_template(state.model_name)
175
+
176
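+ # pack (text, [image]) into a tuple, the format the downstream conversation code expects for vision inputs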
+ text = text, [images[0]]
177
+
178
+ return text
179
+
180
+
181
+ # NOTE(chris): take multiple images later on
182
+ def convert_images_to_conversation_format(images):
183
+ import base64
184
+
185
+ MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB = 5 / 1.5
186
+ conv_images = []
187
+ if len(images) > 0:
188
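+ # NOTE: Image refers to the fastchat.serve.vision helper class, whose import above is commented out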
+ conv_image = Image(url=images[0])
189
+ conv_image.to_conversation_format(MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB)
190
+ conv_images.append(conv_image)
191
+
192
+ return conv_images
193
+
194
+
195
+ def moderate_input(state, text, all_conv_text, model_list, images, ip):
196
+ text_flagged = moderation_filter(all_conv_text, model_list)
197
+ # flagged = moderation_filter(text, [state.model_name])
198
+ nsfw_flagged, csam_flagged = False, False
199
+ if len(images) > 0:
200
+ nsfw_flagged, csam_flagged = image_moderation_filter(images[0])
201
+
202
+ image_flagged = nsfw_flagged or csam_flagged
203
+ if text_flagged or image_flagged:
204
+ logger.info(f"violate moderation. ip: {ip}. text: {all_conv_text}")
205
+ if text_flagged and not image_flagged:
206
+ # overwrite the original text
207
+ text = TEXT_MODERATION_MSG
208
+ elif not text_flagged and image_flagged:
209
+ text = IMAGE_MODERATION_MSG
210
+ elif text_flagged and image_flagged:
211
+ text = MODERATION_MSG
212
+
213
+ if csam_flagged:
214
+ state.has_csam_image = True
215
+ report_csam_image(state, images[0])
216
+
217
+ return text, image_flagged, csam_flagged
218
+
219
+
220
+ def add_text(
221
+ state,
222
+ model_selector,
223
+ chat_input: Union[str, dict],
224
+ context: Context,
225
+ request: gr.Request,
226
+ ):
227
+ if isinstance(chat_input, dict):
228
+ text, images = chat_input["text"], chat_input["files"]
229
+ else:
230
+ text, images = chat_input, []
231
+
232
+ if (
233
+ len(images) > 0
234
+ and model_selector in context.text_models
235
+ and model_selector not in context.vision_models
236
+ ):
237
+ gr.Warning(f"{model_selector} is a text-only model. Image is ignored.")
238
+ images = []
239
+
240
+ ip = get_ip(request)
241
+ logger.info(f"add_text. ip: {ip}. len: {len(text)}")
242
+
243
+ if state is None:
244
+ if len(images) == 0:
245
+ state = State(model_selector, is_vision=False)
246
+ else:
247
+ state = State(model_selector, is_vision=True)
248
+
249
+ if len(text) <= 0:
250
+ state.skip_next = True
251
+ return (state, state.to_gradio_chatbot(), None, "", no_change_btn) + (
252
+ no_change_btn,
253
+ ) * 5
254
+
255
+ all_conv_text = state.conv.get_prompt()
256
+ all_conv_text = all_conv_text[-2000:] + "\nuser: " + text
257
+
258
+ images = convert_images_to_conversation_format(images)
259
+
260
+ text, image_flagged, csam_flag = moderate_input(
261
+ state, text, all_conv_text, [state.model_name], images, ip
262
+ )
263
+
264
+ if image_flagged:
265
+ logger.info(f"image flagged. ip: {ip}. text: {text}")
266
+ state.skip_next = True
267
+ return (
268
+ state,
269
+ state.to_gradio_chatbot(),
270
+ {"text": IMAGE_MODERATION_MSG},
271
+ "",
272
+ no_change_btn,
273
+ ) + (no_change_btn,) * 5
274
+
275
+ if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
276
+ logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
277
+ state.skip_next = True
278
+ return (
279
+ state,
280
+ state.to_gradio_chatbot(),
281
+ {"text": CONVERSATION_LIMIT_MSG},
282
+ "",
283
+ no_change_btn,
284
+ ) + (no_change_btn,) * 5
285
+
286
+ text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
287
+ text = _prepare_text_with_image(state, text, images, csam_flag=csam_flag)
288
+ state.conv.append_message(state.conv.roles[0], text)
289
+ state.conv.append_message(state.conv.roles[1], None)
290
+ return (
291
+ state,
292
+ state.to_gradio_chatbot(),
293
+ disable_multimodal,
294
+ visible_text,
295
+ enable_btn,
296
+ ) + (disable_btn,) * 5
297
+
298
+
299
+ def build_single_vision_language_model_ui(
300
+ context: Context, add_promotion_links=False, random_questions=None
301
+ ):
302
+ promotion = (
303
+ f"""
304
+ [Blog](https://blog.lmarena.ai/blog/2023/arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2403.04132) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/6GXcFg3TH8) | [Kaggle Competition](https://www.kaggle.com/competitions/lmsys-chatbot-arena)
305
+
306
+ {SURVEY_LINK}
307
+
308
+ **❗️ For research purposes, we log user prompts and images, and may release this data to the public in the future. Please do not upload any confidential or personal information.**
309
+
310
+ Note: You can only chat with <span style='color: #DE3163; font-weight: bold'>one image per conversation</span>. You can upload images smaller than 15MB. Click the "Random Example" button to chat with a random image."""
311
+ if add_promotion_links
312
+ else ""
313
+ )
314
+
315
+ notice_markdown = f"""
316
+ # 🏔️ Chatbot Arena (formerly LMSYS): Free AI Chat to Compare & Test Best AI Chatbots
317
+ {promotion}
318
+ """
319
+
320
+ state = gr.State()
321
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
322
+ vision_not_in_text_models = [
323
+ model for model in context.vision_models if model not in context.text_models
324
+ ]
325
+ text_and_vision_models = context.text_models + vision_not_in_text_models
326
+ context_state = gr.State(context)
327
+
328
+ with gr.Group():
329
+ with gr.Row(elem_id="model_selector_row"):
330
+ model_selector = gr.Dropdown(
331
+ choices=text_and_vision_models,
332
+ value=text_and_vision_models[0]
333
+ if len(text_and_vision_models) > 0
334
+ else "",
335
+ interactive=True,
336
+ show_label=False,
337
+ container=False,
338
+ )
339
+
340
+ with gr.Accordion(
341
+ f"🔍 Expand to see the descriptions of {len(text_and_vision_models)} models",
342
+ open=False,
343
+ ):
344
+ model_description_md = get_model_description_md(
345
+ text_and_vision_models)
346
+ gr.Markdown(model_description_md,
347
+ elem_id="model_description_markdown")
348
+
349
+ with gr.Row():
350
+ with gr.Column(scale=2, visible=False) as image_column:
351
+ imagebox = gr.Image(
352
+ type="pil",
353
+ show_label=False,
354
+ interactive=False,
355
+ )
356
+ with gr.Column(scale=8):
357
+ chatbot = gr.Chatbot(
358
+ elem_id="chatbot",
359
+ label="Scroll down and start chatting",
360
+ height=650,
361
+ show_copy_button=True,
362
+ )
363
+
364
+ with gr.Row():
365
+ textbox = gr.Textbox(
366
+ show_label=False,
367
+ placeholder="👉 Enter your prompt and press ENTER",
368
+ elem_id="input_box",
369
+ visible=False,
370
+ )
371
+
372
+ send_btn = gr.Button(
373
+ value="Send", variant="primary", scale=0, visible=False, interactive=False
374
+ )
375
+
376
+ multimodal_textbox = gr.MultimodalTextbox(
377
+ file_types=["image"],
378
+ show_label=False,
379
+ placeholder="Enter your prompt or add image here",
380
+ container=True,
381
+ elem_id="input_box",
382
+ )
383
+
384
+ with gr.Row(elem_id="buttons"):
385
+ if random_questions:
386
+ global vqa_samples
387
+ with open(random_questions, "r") as f:
388
+ vqa_samples = json.load(f)
389
+ random_btn = gr.Button(value="🎲 Random Example", interactive=True)
390
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
391
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
392
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
393
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
394
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
395
+
396
+ with gr.Accordion("Parameters", open=False) as parameter_row:
397
+ temperature = gr.Slider(
398
+ minimum=0.0,
399
+ maximum=1.0,
400
+ value=0.7,
401
+ step=0.1,
402
+ interactive=True,
403
+ label="Temperature",
404
+ )
405
+ top_p = gr.Slider(
406
+ minimum=0.0,
407
+ maximum=1.0,
408
+ value=0.7,
409
+ step=0.1,
410
+ interactive=True,
411
+ label="Top P",
412
+ )
413
+ max_output_tokens = gr.Slider(
414
+ minimum=0,
415
+ maximum=2048,
416
+ value=1024,
417
+ step=64,
418
+ interactive=True,
419
+ label="Max output tokens",
420
+ )
421
+
422
+ if add_promotion_links:
423
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
424
+
425
+ # Register listeners
426
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
427
+ upvote_btn.click(
428
+ upvote_last_response,
429
+ [state, model_selector],
430
+ [textbox, upvote_btn, downvote_btn, flag_btn],
431
+ )
432
+ downvote_btn.click(
433
+ downvote_last_response,
434
+ [state, model_selector],
435
+ [textbox, upvote_btn, downvote_btn, flag_btn],
436
+ )
437
+ flag_btn.click(
438
+ flag_last_response,
439
+ [state, model_selector],
440
+ [textbox, upvote_btn, downvote_btn, flag_btn],
441
+ )
442
+ regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
443
+ bot_response,
444
+ [state, temperature, top_p, max_output_tokens],
445
+ [state, chatbot] + btn_list,
446
+ )
447
+ clear_btn.click(
448
+ clear_history,
449
+ None,
450
+ [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
451
+ )
452
+
453
+ model_selector.change(
454
+ clear_history,
455
+ None,
456
+ [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
457
+ ).then(set_visible_image, [multimodal_textbox], [image_column])
458
+
459
+ multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then(
460
+ set_visible_image, [multimodal_textbox], [image_column]
461
+ ).then(
462
+ clear_history_example,
463
+ None,
464
+ [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
465
+ )
466
+
467
+ multimodal_textbox.submit(
468
+ add_text,
469
+ [state, model_selector, multimodal_textbox, context_state],
470
+ [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
471
+ ).then(set_invisible_image, [], [image_column]).then(
472
+ bot_response,
473
+ [state, temperature, top_p, max_output_tokens],
474
+ [state, chatbot] + btn_list,
475
+ )
476
+
477
+ textbox.submit(
478
+ add_text,
479
+ [state, model_selector, textbox, context_state],
480
+ [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
481
+ ).then(set_invisible_image, [], [image_column]).then(
482
+ bot_response,
483
+ [state, temperature, top_p, max_output_tokens],
484
+ [state, chatbot] + btn_list,
485
+ )
486
+
487
+ send_btn.click(
488
+ add_text,
489
+ [state, model_selector, textbox, context_state],
490
+ [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
491
+ ).then(set_invisible_image, [], [image_column]).then(
492
+ bot_response,
493
+ [state, temperature, top_p, max_output_tokens],
494
+ [state, chatbot] + btn_list,
495
+ )
496
+
497
+ if random_questions:
498
+ random_btn.click(
499
+ get_vqa_sample, # First, get the VQA sample
500
+ [], # No inputs are passed; vqa_samples is read from the module-level global
501
+ [multimodal_textbox, imagebox], # Outputs are textbox and imagebox
502
+ ).then(set_visible_image, [multimodal_textbox], [image_column]).then(
503
+ clear_history_example,
504
+ None,
505
+ [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
506
+ )
507
+
508
+ return [state, model_selector]
serve/gradio_block_arena_vision_anony.py ADDED
@@ -0,0 +1,682 @@
1
+ """
2
+ Chatbot Arena (battle) tab.
3
+ Users chat with two anonymous models.
4
+ """
5
+
6
+ import json
7
+ import time
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+ from typing import Union
12
+
13
+ from .constants import (
14
+ TEXT_MODERATION_MSG,
15
+ IMAGE_MODERATION_MSG,
16
+ MODERATION_MSG,
17
+ CONVERSATION_LIMIT_MSG,
18
+ SLOW_MODEL_MSG,
19
+ BLIND_MODE_INPUT_CHAR_LEN_LIMIT,
20
+ CONVERSATION_TURN_LIMIT,
21
+ SURVEY_LINK,
22
+ )
23
+ from .gradio_block_arena_named import flash_buttons
24
+ from .gradio_web_server import (
25
+ State,
26
+ bot_response,
27
+ get_conv_log_filename,
28
+ no_change_btn,
29
+ enable_btn,
30
+ disable_btn,
31
+ invisible_btn,
32
+ acknowledgment_md,
33
+ get_ip,
34
+ get_model_description_md,
35
+ disable_text,
36
+ enable_text,
37
+ )
38
+ from .gradio_block_arena_anony import (
39
+ flash_buttons,
40
+ vote_last_response,
41
+ leftvote_last_response,
42
+ rightvote_last_response,
43
+ tievote_last_response,
44
+ bothbad_vote_last_response,
45
+ regenerate,
46
+ clear_history,
47
+ share_click,
48
+ bot_response_multi,
49
+ set_global_vars_anony,
50
+ load_demo_side_by_side_anony,
51
+ get_sample_weight,
52
+ get_battle_pair,
53
+ SAMPLING_WEIGHTS,
54
+ BATTLE_TARGETS,
55
+ SAMPLING_BOOST_MODELS,
56
+ OUTAGE_MODELS,
57
+ )
58
+ from .gradio_block_arena_vision import (
59
+ set_invisible_image,
60
+ set_visible_image,
61
+ add_image,
62
+ moderate_input,
63
+ enable_multimodal,
64
+ _prepare_text_with_image,
65
+ convert_images_to_conversation_format,
66
+ invisible_text,
67
+ visible_text,
68
+ disable_multimodal,
69
+ )
70
+ from .gradio_global_state import Context
71
+ from .remote_logger import get_remote_logger
72
+ from .utils import (
73
+ build_logger,
74
+ moderation_filter,
75
+ image_moderation_filter,
76
+ )
77
+
78
+ logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
79
+
80
+ num_sides = 2
81
+ enable_moderation = False
82
+ anony_names = ["", ""]
83
+ text_models = []
84
+ vl_models = []
85
+
86
+ # TODO(chris): fix sampling weights
87
+ VISION_SAMPLING_WEIGHTS = {}
88
+
89
+ # TODO(chris): Find battle targets that make sense
90
+ VISION_BATTLE_TARGETS = {}
91
+
92
+ # TODO(chris): Fill out models that require sampling boost
93
+ VISION_SAMPLING_BOOST_MODELS = []
94
+
95
+ # outage models won't be sampled.
96
+ VISION_OUTAGE_MODELS = []
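+
+ # A hypothetical sketch (placeholder names, not real models) of how these knobs
+ # are typically filled in once vision battles are tuned; the weights and boost
+ # list feed get_battle_pair / get_sample_weight imported above:
+ # VISION_SAMPLING_WEIGHTS = {"vision-model-a": 4, "vision-model-b": 1}
+ # VISION_BATTLE_TARGETS = {"vision-model-a": {"vision-model-b"}}
+ # VISION_SAMPLING_BOOST_MODELS = ["vision-model-a"]
+ # VISION_OUTAGE_MODELS = ["vision-model-b"]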
97
+
98
+
99
+ def get_vqa_sample():
100
+ random_sample = np.random.choice(vqa_samples)
101
+ question, path = random_sample["question"], random_sample["path"]
102
+ res = {"text": "", "files": [path]}
103
+ return (res, path)
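+ # res matches gr.MultimodalTextbox's value format ({"text": ..., "files": [...]});
+ # the bare path is returned separately to populate the gr.Image preview.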
104
+
105
+
106
+ def load_demo_side_by_side_vision_anony():
107
+ states = [None] * num_sides
108
+ selector_updates = [
109
+ gr.Markdown(visible=True),
110
+ gr.Markdown(visible=True),
111
+ ]
112
+
113
+ return states + selector_updates
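+ # In the anony tab the "selectors" are read-only Markdown labels that stay
+ # blank until a vote reveals the model names, unlike the named tab's dropdowns.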
114
+
115
+
116
+ def clear_history_example(request: gr.Request):
117
+ logger.info(f"clear_history_example (anony). ip: {get_ip(request)}")
118
+ return (
119
+ [None] * num_sides
120
+ + [None] * num_sides
121
+ + anony_names
122
+ + [enable_multimodal, invisible_text, invisible_btn]
123
+ + [invisible_btn] * 4
124
+ + [disable_btn] * 2
125
+ + [enable_btn]
126
+ )
127
+
128
+
129
+ def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
130
+ filename = get_conv_log_filename(
131
+ states[0].is_vision, states[0].has_csam_image)
132
+
133
+ with open(filename, "a") as fout:
134
+ data = {
135
+ "tstamp": round(time.time(), 4),
136
+ "type": vote_type,
137
+ "models": [x for x in model_selectors],
138
+ "states": [x.dict() for x in states],
139
+ "ip": get_ip(request),
140
+ }
141
+ fout.write(json.dumps(data) + "\n")
142
+ get_remote_logger().log(data)
143
+
144
+ gr.Info(
145
+ "🎉 Thanks for voting! Your vote shapes the leaderboard, please vote RESPONSIBLY."
146
+ )
147
+
148
+ model_name_1 = states[0].model_name
149
+ model_name_2 = states[1].model_name
150
+ model_name_map = {}
151
+
152
+ if model_name_1 in model_name_map:
153
+ model_name_1 = model_name_map[model_name_1]
154
+ if model_name_2 in model_name_map:
155
+ model_name_2 = model_name_map[model_name_2]
156
+
157
+ if ":" not in model_selectors[0]:
158
+ for i in range(5):
159
+ names = (
160
+ "### Model A: " + model_name_1,
161
+ "### Model B: " + model_name_2,
162
+ )
163
+ yield names + (disable_text,) + (disable_btn,) * 4
164
+ time.sleep(0.1)
165
+ else:
166
+ names = (
167
+ "### Model A: " + model_name_1,
168
+ "### Model B: " + model_name_2,
169
+ )
170
+ yield names + (disable_text,) + (disable_btn,) * 4
171
+
172
+
173
+ def leftvote_last_response(
174
+ state0, state1, model_selector0, model_selector1, request: gr.Request
175
+ ):
176
+ logger.info(f"leftvote (anony). ip: {get_ip(request)}")
177
+ for x in vote_last_response(
178
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
179
+ ):
180
+ yield x
181
+
182
+
183
+ def rightvote_last_response(
184
+ state0, state1, model_selector0, model_selector1, request: gr.Request
185
+ ):
186
+ logger.info(f"rightvote (anony). ip: {get_ip(request)}")
187
+ for x in vote_last_response(
188
+ [state0, state1], "rightvote", [
189
+ model_selector0, model_selector1], request
190
+ ):
191
+ yield x
192
+
193
+
194
+ def tievote_last_response(
195
+ state0, state1, model_selector0, model_selector1, request: gr.Request
196
+ ):
197
+ logger.info(f"tievote (anony). ip: {get_ip(request)}")
198
+ for x in vote_last_response(
199
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
200
+ ):
201
+ yield x
202
+
203
+
204
+ def bothbad_vote_last_response(
205
+ state0, state1, model_selector0, model_selector1, request: gr.Request
206
+ ):
207
+ logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
208
+ for x in vote_last_response(
209
+ [state0, state1], "bothbad_vote", [
210
+ model_selector0, model_selector1], request
211
+ ):
212
+ yield x
213
+
214
+
215
+ def regenerate(state0, state1, request: gr.Request):
216
+ logger.info(f"regenerate (anony). ip: {get_ip(request)}")
217
+ states = [state0, state1]
218
+ if state0.regen_support and state1.regen_support:
219
+ for i in range(num_sides):
220
+ states[i].conv.update_last_message(None)
221
+ return (
222
+ states
223
+ + [x.to_gradio_chatbot() for x in states]
224
+ + [None]
225
+ + [disable_btn] * 6
226
+ )
227
+ states[0].skip_next = True
228
+ states[1].skip_next = True
229
+ return (
230
+ states + [x.to_gradio_chatbot() for x in states] +
231
+ [None] + [no_change_btn] * 6
232
+ )
233
+
234
+
235
+ def clear_history(request: gr.Request):
236
+ logger.info(f"clear_history (anony). ip: {get_ip(request)}")
237
+ return (
238
+ [None] * num_sides
239
+ + [None] * num_sides
240
+ + anony_names
241
+ + [enable_multimodal, invisible_text, invisible_btn]
242
+ + [invisible_btn] * 4
243
+ + [disable_btn] * 2
244
+ + [enable_btn]
245
+ + [""]
246
+ )
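+ # The tuple order must mirror the outputs wired to clear_btn.click below:
+ # states, chatbots, model-selector labels, the three input widgets, the six
+ # buttons, the random-example button, and the slow-model hint.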
247
+
248
+
249
+ def add_text(
250
+ state0,
251
+ state1,
252
+ model_selector0,
253
+ model_selector1,
254
+ chat_input: Union[str, dict],
255
+ context: Context,
256
+ request: gr.Request,
257
+ ):
258
+ if isinstance(chat_input, dict):
259
+ text, images = chat_input["text"], chat_input["files"]
260
+ else:
261
+ text = chat_input
262
+ images = []
263
+
264
+ ip = get_ip(request)
265
+ logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
266
+ states = [state0, state1]
267
+ model_selectors = [model_selector0, model_selector1]
268
+
269
+ # Init states if necessary
270
+ if states[0] is None:
271
+ assert states[1] is None
272
+
273
+ if len(images) > 0:
274
+ model_left, model_right = get_battle_pair(
275
+ context.all_vision_models,
276
+ VISION_BATTLE_TARGETS,
277
+ VISION_OUTAGE_MODELS,
278
+ VISION_SAMPLING_WEIGHTS,
279
+ VISION_SAMPLING_BOOST_MODELS,
280
+ )
281
+ states = [
282
+ State(model_left, is_vision=True),
283
+ State(model_right, is_vision=True),
284
+ ]
285
+ else:
286
+ model_left, model_right = get_battle_pair(
287
+ context.all_text_models,
288
+ BATTLE_TARGETS,
289
+ OUTAGE_MODELS,
290
+ SAMPLING_WEIGHTS,
291
+ SAMPLING_BOOST_MODELS,
292
+ )
293
+
294
+ states = [
295
+ State(model_left, is_vision=False),
296
+ State(model_right, is_vision=False),
297
+ ]
298
+
299
+ if len(text) <= 0:
300
+ for i in range(num_sides):
301
+ states[i].skip_next = True
302
+ return (
303
+ states
304
+ + [x.to_gradio_chatbot() for x in states]
305
+ + [None, "", no_change_btn]
306
+ + [
307
+ no_change_btn,
308
+ ]
309
+ * 7
310
+ + [""]
311
+ )
312
+
313
+ model_list = [states[i].model_name for i in range(num_sides)]
314
+
315
+ images = convert_images_to_conversation_format(images)
316
+
317
+ text, image_flagged, csam_flag = moderate_input(
318
+ state0, text, text, model_list, images, ip
319
+ )
320
+
321
+ conv = states[0].conv
322
+ if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
323
+ logger.info(
324
+ f"conversation turn limit. ip: {get_ip(request)}. text: {text}")
325
+ for i in range(num_sides):
326
+ states[i].skip_next = True
327
+ return (
328
+ states
329
+ + [x.to_gradio_chatbot() for x in states]
330
+ + [{"text": CONVERSATION_LIMIT_MSG}, "", no_change_btn]
331
+ + [
332
+ no_change_btn,
333
+ ]
334
+ * 7
335
+ + [""]
336
+ )
337
+
338
+ if image_flagged:
339
+ logger.info(f"image flagged. ip: {ip}. text: {text}")
340
+ for i in range(num_sides):
341
+ states[i].skip_next = True
342
+ return (
343
+ states
344
+ + [x.to_gradio_chatbot() for x in states]
345
+ + [
346
+ {
347
+ "text": IMAGE_MODERATION_MSG
348
+ + " PLEASE CLICK 🎲 NEW ROUND TO START A NEW CONVERSATION."
349
+ },
350
+ "",
351
+ no_change_btn,
352
+ ]
353
+ + [no_change_btn] * 7
354
+ + [""]
355
+ )
356
+
357
+ text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
358
+ for i in range(num_sides):
359
+ post_processed_text = _prepare_text_with_image(
360
+ states[i], text, images, csam_flag=csam_flag
361
+ )
362
+ states[i].conv.append_message(
363
+ states[i].conv.roles[0], post_processed_text)
364
+ states[i].conv.append_message(states[i].conv.roles[1], None)
365
+ states[i].skip_next = False
366
+
367
+ hint_msg = ""
368
+ for i in range(num_sides):
369
+ if "deluxe" in states[i].model_name:
370
+ hint_msg = SLOW_MODEL_MSG
371
+ return (
372
+ states
373
+ + [x.to_gradio_chatbot() for x in states]
374
+ + [disable_multimodal, visible_text, enable_btn]
375
+ + [
376
+ disable_btn,
377
+ ]
378
+ * 7
379
+ + [hint_msg]
380
+ )
381
+
382
+
383
+ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
384
+ notice_markdown = f"""
385
+ # ⚔️ Chatbot Arena (formerly LMSYS): Free AI Chat to Compare & Test Best AI Chatbots
386
+
387
+ {SURVEY_LINK}
388
+
389
+ ## 📜 How It Works
390
+ - **Blind Test**: Ask any question to two anonymous AI chatbots (ChatGPT, Gemini, Claude, Llama, and more).
391
+ - **Vote for the Best**: Choose the best response. You can keep chatting until you find a winner.
392
+ - **Play Fair**: If a model reveals its identity, your vote won't count.
393
+
394
+ **NEW** Image Support: <span style='color: #DE3163; font-weight: bold'>Upload an image</span> to unlock the multimodal arena!
395
+
396
+ ## 🏆 Chatbot Arena LLM [Leaderboard](https://lmarena.ai/leaderboard)
397
+ - Backed by **1,000,000+** community votes, our platform ranks the best LLMs and AI chatbots. Explore the top AI models on our LLM [leaderboard](https://lmarena.ai/leaderboard)!
398
+
399
+ ## 👇 Chat now!
400
+ """
401
+
402
+ states = [gr.State() for _ in range(num_sides)]
403
+ model_selectors = [None] * num_sides
404
+ chatbots = [None] * num_sides
405
+ context_state = gr.State(context)
406
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
407
+ text_and_vision_models = context.models
408
+
409
+ with gr.Row():
410
+ with gr.Column(scale=2, visible=False) as image_column:
411
+ imagebox = gr.Image(
412
+ type="pil",
413
+ show_label=False,
414
+ interactive=False,
415
+ )
416
+
417
+ with gr.Column(scale=5):
418
+ with gr.Group(elem_id="share-region-anony"):
419
+ with gr.Accordion(
420
+ f"🔍 Expand to see the descriptions of {len(text_and_vision_models)} models",
421
+ open=False,
422
+ ):
423
+ model_description_md = get_model_description_md(
424
+ text_and_vision_models
425
+ )
426
+ gr.Markdown(
427
+ model_description_md, elem_id="model_description_markdown"
428
+ )
429
+
430
+ with gr.Row():
431
+ for i in range(num_sides):
432
+ label = "Model A" if i == 0 else "Model B"
433
+ with gr.Column():
434
+ chatbots[i] = gr.Chatbot(
435
+ label=label,
436
+ elem_id="chatbot",
437
+ height=650,
438
+ show_copy_button=True,
439
+ )
440
+
441
+ with gr.Row():
442
+ for i in range(num_sides):
443
+ with gr.Column():
444
+ model_selectors[i] = gr.Markdown(
445
+ anony_names[i], elem_id="model_selector_md"
446
+ )
447
+ with gr.Row():
448
+ slow_warning = gr.Markdown("", elem_id="notice_markdown")
449
+
450
+ with gr.Row():
451
+ leftvote_btn = gr.Button(
452
+ value="👈 A is better", visible=False, interactive=False
453
+ )
454
+ rightvote_btn = gr.Button(
455
+ value="👉 B is better", visible=False, interactive=False
456
+ )
457
+ tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
458
+ bothbad_btn = gr.Button(
459
+ value="👎 Both are bad", visible=False, interactive=False
460
+ )
461
+
462
+ with gr.Row():
463
+ textbox = gr.Textbox(
464
+ show_label=False,
465
+ placeholder="👉 Enter your prompt and press ENTER",
466
+ elem_id="input_box",
467
+ visible=False,
468
+ scale=3,
469
+ )
470
+
471
+ multimodal_textbox = gr.MultimodalTextbox(
472
+ file_types=["image"],
473
+ show_label=False,
474
+ container=True,
475
+ placeholder="Enter your prompt or add image here",
476
+ elem_id="input_box",
477
+ scale=3,
478
+ )
479
+ send_btn = gr.Button(
480
+ value="Send", variant="primary", scale=1, visible=False, interactive=False
481
+ )
482
+
483
+ with gr.Row() as button_row:
484
+ if random_questions:
485
+ global vqa_samples
486
+ with open(random_questions, "r") as f:
487
+ vqa_samples = json.load(f)
488
+ random_btn = gr.Button(value="🔮 Random Image", interactive=True)
489
+ clear_btn = gr.Button(value="🎲 New Round", interactive=False)
490
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
491
+ share_btn = gr.Button(value="📷 Share")
492
+
493
+ with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
494
+ temperature = gr.Slider(
495
+ minimum=0.0,
496
+ maximum=1.0,
497
+ value=0.7,
498
+ step=0.1,
499
+ interactive=True,
500
+ label="Temperature",
501
+ )
502
+ top_p = gr.Slider(
503
+ minimum=0.0,
504
+ maximum=1.0,
505
+ value=1.0,
506
+ step=0.1,
507
+ interactive=True,
508
+ label="Top P",
509
+ )
510
+ max_output_tokens = gr.Slider(
511
+ minimum=16,
512
+ maximum=2048,
513
+ value=2000,
514
+ step=64,
515
+ interactive=True,
516
+ label="Max output tokens",
517
+ )
518
+
519
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
520
+
521
+ # Register listeners
522
+ btn_list = [
523
+ leftvote_btn,
524
+ rightvote_btn,
525
+ tie_btn,
526
+ bothbad_btn,
527
+ regenerate_btn,
528
+ clear_btn,
529
+ ]
530
+ leftvote_btn.click(
531
+ leftvote_last_response,
532
+ states + model_selectors,
533
+ model_selectors + [textbox, leftvote_btn,
534
+ rightvote_btn, tie_btn, bothbad_btn],
535
+ )
536
+ rightvote_btn.click(
537
+ rightvote_last_response,
538
+ states + model_selectors,
539
+ model_selectors + [textbox, leftvote_btn,
540
+ rightvote_btn, tie_btn, bothbad_btn],
541
+ )
542
+ tie_btn.click(
543
+ tievote_last_response,
544
+ states + model_selectors,
545
+ model_selectors + [textbox, leftvote_btn,
546
+ rightvote_btn, tie_btn, bothbad_btn],
547
+ )
548
+ bothbad_btn.click(
549
+ bothbad_vote_last_response,
550
+ states + model_selectors,
551
+ model_selectors + [textbox, leftvote_btn,
552
+ rightvote_btn, tie_btn, bothbad_btn],
553
+ )
554
+ regenerate_btn.click(
555
+ regenerate, states, states + chatbots + [textbox] + btn_list
556
+ ).then(
557
+ bot_response_multi,
558
+ states + [temperature, top_p, max_output_tokens],
559
+ states + chatbots + btn_list,
560
+ ).then(
561
+ flash_buttons, [], btn_list
562
+ )
563
+ clear_btn.click(
564
+ clear_history,
565
+ None,
566
+ states
567
+ + chatbots
568
+ + model_selectors
569
+ + [multimodal_textbox, textbox, send_btn]
570
+ + btn_list
571
+ + [random_btn]
572
+ + [slow_warning],
573
+ )
574
+
575
+ share_js = """
576
+ function (a, b, c, d) {
577
+ const captureElement = document.querySelector('#share-region-anony');
578
+ html2canvas(captureElement)
579
+ .then(canvas => {
580
+ canvas.style.display = 'none'
581
+ document.body.appendChild(canvas)
582
+ return canvas
583
+ })
584
+ .then(canvas => {
585
+ const image = canvas.toDataURL('image/png')
586
+ const a = document.createElement('a')
587
+ a.setAttribute('download', 'chatbot-arena.png')
588
+ a.setAttribute('href', image)
589
+ a.click()
590
+ canvas.remove()
591
+ });
592
+ return [a, b, c, d];
593
+ }
594
+ """
595
+ share_btn.click(share_click, states + model_selectors, [], js=share_js)
596
+
597
+ multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then(
598
+ set_visible_image, [multimodal_textbox], [image_column]
599
+ ).then(
600
+ clear_history_example,
601
+ None,
602
+ states
603
+ + chatbots
604
+ + model_selectors
605
+ + [multimodal_textbox, textbox, send_btn]
606
+ + btn_list,
607
+ )
608
+
609
+ multimodal_textbox.submit(
610
+ add_text,
611
+ states + model_selectors + [multimodal_textbox, context_state],
612
+ states
613
+ + chatbots
614
+ + [multimodal_textbox, textbox, send_btn]
615
+ + btn_list
616
+ + [random_btn]
617
+ + [slow_warning],
618
+ ).then(set_invisible_image, [], [image_column]).then(
619
+ bot_response_multi,
620
+ states + [temperature, top_p, max_output_tokens],
621
+ states + chatbots + btn_list,
622
+ ).then(
623
+ flash_buttons,
624
+ [],
625
+ btn_list,
626
+ )
627
+
628
+ textbox.submit(
629
+ add_text,
630
+ states + model_selectors + [textbox, context_state],
631
+ states
632
+ + chatbots
633
+ + [multimodal_textbox, textbox, send_btn]
634
+ + btn_list
635
+ + [random_btn]
636
+ + [slow_warning],
637
+ ).then(
638
+ bot_response_multi,
639
+ states + [temperature, top_p, max_output_tokens],
640
+ states + chatbots + btn_list,
641
+ ).then(
642
+ flash_buttons,
643
+ [],
644
+ btn_list,
645
+ )
646
+
647
+ send_btn.click(
648
+ add_text,
649
+ states + model_selectors + [textbox, context_state],
650
+ states
651
+ + chatbots
652
+ + [multimodal_textbox, textbox, send_btn]
653
+ + btn_list
654
+ + [random_btn]
655
+ + [slow_warning],
656
+ ).then(
657
+ bot_response_multi,
658
+ states + [temperature, top_p, max_output_tokens],
659
+ states + chatbots + btn_list,
660
+ ).then(
661
+ flash_buttons,
662
+ [],
663
+ btn_list,
664
+ )
665
+
666
+ if random_questions:
667
+ random_btn.click(
668
+ get_vqa_sample, # First, get the VQA sample
669
+ [], # No inputs are passed; vqa_samples is read from the module-level global
670
+ [multimodal_textbox, imagebox], # Outputs are textbox and imagebox
671
+ ).then(set_visible_image, [multimodal_textbox], [image_column]).then(
672
+ clear_history_example,
673
+ None,
674
+ states
675
+ + chatbots
676
+ + model_selectors
677
+ + [multimodal_textbox, textbox, send_btn]
678
+ + btn_list
679
+ + [random_btn],
680
+ )
681
+
682
+ return states + model_selectors
serve/gradio_block_arena_vision_named.py ADDED
@@ -0,0 +1,583 @@
1
+ """
2
+ Multimodal Chatbot Arena (side-by-side) tab.
3
+ Users chat with two chosen models.
4
+ """
5
+
6
+ import json
7
+ import os
8
+ import time
9
+ from typing import List, Union
10
+
11
+ import gradio as gr
12
+ import numpy as np
13
+
14
+ from .constants import (
15
+ TEXT_MODERATION_MSG,
16
+ IMAGE_MODERATION_MSG,
17
+ MODERATION_MSG,
18
+ CONVERSATION_LIMIT_MSG,
19
+ SLOW_MODEL_MSG,
20
+ INPUT_CHAR_LEN_LIMIT,
21
+ CONVERSATION_TURN_LIMIT,
22
+ SURVEY_LINK,
23
+ )
24
+ from .gradio_block_arena_named import (
25
+ flash_buttons,
26
+ share_click,
27
+ bot_response_multi,
28
+ )
29
+ from .gradio_block_arena_vision import (
30
+ get_vqa_sample,
31
+ set_invisible_image,
32
+ set_visible_image,
33
+ add_image,
34
+ moderate_input,
35
+ _prepare_text_with_image,
36
+ convert_images_to_conversation_format,
37
+ enable_multimodal,
38
+ disable_multimodal,
39
+ invisible_text,
40
+ invisible_btn,
41
+ visible_text,
42
+ )
43
+ from .gradio_global_state import Context
44
+ from .gradio_web_server import (
45
+ State,
46
+ bot_response,
47
+ get_conv_log_filename,
48
+ no_change_btn,
49
+ enable_btn,
50
+ disable_btn,
51
+ invisible_btn,
52
+ acknowledgment_md,
53
+ get_ip,
54
+ get_model_description_md,
55
+ enable_text,
56
+ )
57
+ from .remote_logger import get_remote_logger
58
+ from .utils import (
59
+ build_logger,
60
+ moderation_filter,
61
+ image_moderation_filter,
62
+ )
63
+
64
+
65
+ logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
66
+
67
+ num_sides = 2
68
+ enable_moderation = False
69
+
70
+
71
+ def load_demo_side_by_side_vision_named(context: Context):
72
+ states = [None] * num_sides
73
+
74
+ # default to the text models
75
+ models = context.text_models
76
+
77
+ model_left = models[0] if len(models) > 0 else ""
78
+ if len(models) > 1:
79
+ weights = ([1] * 128)[: len(models) - 1]
80
+ weights = weights / np.sum(weights)
81
+ model_right = np.random.choice(models[1:], p=weights)
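+ # With all-ones weights the normalized p is uniform, so this is simply a
+ # uniform draw over the remaining models for the right-hand side.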
82
+ else:
83
+ model_right = model_left
84
+
85
+ all_models = context.models
86
+ selector_updates = [
87
+ gr.Dropdown(choices=all_models, value=model_left, visible=True),
88
+ gr.Dropdown(choices=all_models, value=model_right, visible=True),
89
+ ]
90
+
91
+ return states + selector_updates
92
+
93
+
94
+ def clear_history_example(request: gr.Request):
95
+ logger.info(f"clear_history_example (named). ip: {get_ip(request)}")
96
+ return (
97
+ [None] * num_sides
98
+ + [None] * num_sides
99
+ + [enable_multimodal, invisible_text, invisible_btn]
100
+ + [invisible_btn] * 4
101
+ + [disable_btn] * 2
102
+ )
103
+
104
+
105
+ def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
106
+ filename = get_conv_log_filename(
107
+ states[0].is_vision, states[0].has_csam_image)
108
+ with open(filename, "a") as fout:
109
+ data = {
110
+ "tstamp": round(time.time(), 4),
111
+ "type": vote_type,
112
+ "models": [x for x in model_selectors],
113
+ "states": [x.dict() for x in states],
114
+ "ip": get_ip(request),
115
+ }
116
+ fout.write(json.dumps(data) + "\n")
117
+ get_remote_logger().log(data)
118
+
119
+
120
+ def leftvote_last_response(
121
+ state0, state1, model_selector0, model_selector1, request: gr.Request
122
+ ):
123
+ logger.info(f"leftvote (named). ip: {get_ip(request)}")
124
+ vote_last_response(
125
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
126
+ )
127
+ return (None,) + (disable_btn,) * 4
128
+
129
+
130
+ def rightvote_last_response(
131
+ state0, state1, model_selector0, model_selector1, request: gr.Request
132
+ ):
133
+ logger.info(f"rightvote (named). ip: {get_ip(request)}")
134
+ vote_last_response(
135
+ [state0, state1], "rightvote", [
136
+ model_selector0, model_selector1], request
137
+ )
138
+ return (None,) + (disable_btn,) * 4
139
+
140
+
141
+ def tievote_last_response(
142
+ state0, state1, model_selector0, model_selector1, request: gr.Request
143
+ ):
144
+ logger.info(f"tievote (named). ip: {get_ip(request)}")
145
+ vote_last_response(
146
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
147
+ )
148
+ return (None,) + (disable_btn,) * 4
149
+
150
+
151
+ def bothbad_vote_last_response(
152
+ state0, state1, model_selector0, model_selector1, request: gr.Request
153
+ ):
154
+ logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
155
+ vote_last_response(
156
+ [state0, state1], "bothbad_vote", [
157
+ model_selector0, model_selector1], request
158
+ )
159
+ return (None,) + (disable_btn,) * 4
160
+
161
+
162
+ def regenerate(state0, state1, request: gr.Request):
163
+ logger.info(f"regenerate (named). ip: {get_ip(request)}")
164
+ states = [state0, state1]
165
+ if state0.regen_support and state1.regen_support:
166
+ for i in range(num_sides):
167
+ states[i].conv.update_last_message(None)
168
+ return (
169
+ states
170
+ + [x.to_gradio_chatbot() for x in states]
171
+ + [None]
172
+ + [disable_btn] * 6
173
+ )
174
+ states[0].skip_next = True
175
+ states[1].skip_next = True
176
+ return (
177
+ states + [x.to_gradio_chatbot() for x in states] +
178
+ [None] + [no_change_btn] * 6
179
+ )
180
+
181
+
182
+ def clear_history(request: gr.Request):
183
+ logger.info(f"clear_history (named). ip: {get_ip(request)}")
184
+ return (
185
+ [None] * num_sides
186
+ + [None] * num_sides
187
+ + [enable_multimodal, invisible_text, invisible_btn]
188
+ + [invisible_btn] * 4
189
+ + [disable_btn] * 2
190
+ )
191
+
192
+
193
+ def add_text(
194
+ state0,
195
+ state1,
196
+ model_selector0,
197
+ model_selector1,
198
+ chat_input: Union[str, dict],
199
+ context: Context,
200
+ request: gr.Request,
201
+ ):
202
+ if isinstance(chat_input, dict):
203
+ text, images = chat_input["text"], chat_input["files"]
204
+ else:
205
+ text, images = chat_input, []
206
+
207
+ if len(images) > 0:
208
+ if (
209
+ model_selector0 in context.text_models
210
+ and model_selector0 not in context.vision_models
211
+ ):
212
+ gr.Warning(
213
+ f"{model_selector0} is a text-only model. Image is ignored.")
214
+ images = []
215
+ if (
216
+ model_selector1 in context.text_models
217
+ and model_selector1 not in context.vision_models
218
+ ):
219
+ gr.Warning(
220
+ f"{model_selector1} is a text-only model. Image is ignored.")
221
+ images = []
222
+
223
+ ip = get_ip(request)
224
+ logger.info(f"add_text (named). ip: {ip}. len: {len(text)}")
225
+ states = [state0, state1]
226
+ model_selectors = [model_selector0, model_selector1]
227
+
228
+ # Init states if necessary
229
+ for i in range(num_sides):
230
+ if states[i] is None and len(images) == 0:
231
+ states[i] = State(model_selectors[i], is_vision=False)
232
+ elif states[i] is None and len(images) > 0:
233
+ states[i] = State(model_selectors[i], is_vision=True)
234
+
235
+ if len(text) <= 0:
236
+ for i in range(num_sides):
237
+ states[i].skip_next = True
238
+ return (
239
+ states
240
+ + [x.to_gradio_chatbot() for x in states]
241
+ + [None, "", no_change_btn]
242
+ + [
243
+ no_change_btn,
244
+ ]
245
+ * 6
246
+ )
247
+
248
+ model_list = [states[i].model_name for i in range(num_sides)]
249
+ all_conv_text_left = states[0].conv.get_prompt()
250
+ all_conv_text_right = states[1].conv.get_prompt()
251
+ all_conv_text = (
252
+ all_conv_text_left[-1000:] +
253
+ all_conv_text_right[-1000:] + "\nuser: " + text
254
+ )
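+ # Moderation sees the last 1000 characters of each side's history plus the
+ # new prompt, keeping the payload bounded on long conversations.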
255
+
256
+ images = convert_images_to_conversation_format(images)
257
+
258
+ text, image_flagged, csam_flag = moderate_input(
259
+ state0, text, all_conv_text, model_list, images, ip
260
+ )
261
+
262
+ conv = states[0].conv
263
+ if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
264
+ logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
265
+ for i in range(num_sides):
266
+ states[i].skip_next = True
267
+ return (
268
+ states
269
+ + [x.to_gradio_chatbot() for x in states]
270
+ + [{"text": CONVERSATION_LIMIT_MSG}, "", no_change_btn]
271
+ + [
272
+ no_change_btn,
273
+ ]
274
+ * 6
275
+ )
276
+
277
+ if image_flagged:
278
+ logger.info(f"image flagged. ip: {ip}. text: {text}")
279
+ for i in range(num_sides):
280
+ states[i].skip_next = True
281
+ return (
282
+ states
283
+ + [x.to_gradio_chatbot() for x in states]
284
+ + [{"text": IMAGE_MODERATION_MSG}, "", no_change_btn]
285
+ + [
286
+ no_change_btn,
287
+ ]
288
+ * 6
289
+ )
290
+
291
+ text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
292
+ for i in range(num_sides):
293
+ post_processed_text = _prepare_text_with_image(
294
+ states[i], text, images, csam_flag=csam_flag
295
+ )
296
+ states[i].conv.append_message(
297
+ states[i].conv.roles[0], post_processed_text)
298
+ states[i].conv.append_message(states[i].conv.roles[1], None)
299
+ states[i].skip_next = False
300
+
301
+ return (
302
+ states
303
+ + [x.to_gradio_chatbot() for x in states]
304
+ + [disable_multimodal, visible_text, enable_btn]
305
+ + [
306
+ disable_btn,
307
+ ]
308
+ * 6
309
+ )
310
+
311
+
312
+ def build_side_by_side_vision_ui_named(context: Context, random_questions=None):
313
+ notice_markdown = f"""
314
+ # ⚔️ Chatbot Arena (formerly LMSYS): Free AI Chat to Compare & Test Best AI Chatbots
315
+
316
+ {SURVEY_LINK}
317
+
318
+ ## 📜 How It Works
319
+ - Ask any question to two chosen models (e.g., ChatGPT, Gemini, Claude, Llama) and vote for the better one!
320
+ - You can chat for multiple turns until you identify a winner.
321
+
322
+ Note: You can only chat with <span style='color: #DE3163; font-weight: bold'>one image per conversation</span>. You can upload images smaller than 15MB. Click the "Random Example" button to chat with a random image.
323
+
324
+ **❗️ For research purposes, we log user prompts and images, and may release this data to the public in the future. Please do not upload any confidential or personal information.**
325
+
326
+ ## 🤖 Choose two models to compare
327
+ """
328
+
329
+ states = [gr.State() for _ in range(num_sides)]
330
+ model_selectors = [None] * num_sides
331
+ chatbots = [None] * num_sides
332
+
333
+ notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")
334
+
335
+ text_and_vision_models = context.models
336
+ context_state = gr.State(context)
337
+
338
+ with gr.Row():
339
+ with gr.Column(scale=2, visible=False) as image_column:
340
+ imagebox = gr.Image(
341
+ type="pil",
342
+ show_label=False,
343
+ interactive=False,
344
+ )
345
+
346
+ with gr.Column(scale=5):
347
+ with gr.Group(elem_id="share-region-named"):
348
+ with gr.Accordion(
349
+ f"🔍 Expand to see the descriptions of {len(text_and_vision_models)} models",
350
+ open=False,
351
+ ):
352
+ model_description_md = get_model_description_md(
353
+ text_and_vision_models
354
+ )
355
+ gr.Markdown(
356
+ model_description_md, elem_id="model_description_markdown"
357
+ )
358
+
359
+ with gr.Row():
360
+ for i in range(num_sides):
361
+ with gr.Column():
362
+ model_selectors[i] = gr.Dropdown(
363
+ choices=text_and_vision_models,
364
+ value=text_and_vision_models[i]
365
+ if len(text_and_vision_models) > i
366
+ else "",
367
+ interactive=True,
368
+ show_label=False,
369
+ container=False,
370
+ )
371
+
372
+ with gr.Row():
373
+ for i in range(num_sides):
374
+ label = "Model A" if i == 0 else "Model B"
375
+ with gr.Column():
376
+ chatbots[i] = gr.Chatbot(
377
+ label=label,
378
+ elem_id=f"chatbot",
379
+ height=650,
380
+ show_copy_button=True,
381
+ )
382
+
383
+ with gr.Row():
384
+ leftvote_btn = gr.Button(
385
+ value="👈 A is better", visible=False, interactive=False
386
+ )
387
+ rightvote_btn = gr.Button(
388
+ value="👉 B is better", visible=False, interactive=False
389
+ )
390
+ tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
391
+ bothbad_btn = gr.Button(
392
+ value="👎 Both are bad", visible=False, interactive=False
393
+ )
394
+
395
+ with gr.Row():
396
+ textbox = gr.Textbox(
397
+ show_label=False,
398
+ placeholder="👉 Enter your prompt and press ENTER",
399
+ elem_id="input_box",
400
+ visible=False,
401
+ )
402
+
403
+ send_btn = gr.Button(
404
+ value="Send", variant="primary", scale=0, visible=False, interactive=False
405
+ )
406
+
407
+ multimodal_textbox = gr.MultimodalTextbox(
408
+ file_types=["image"],
409
+ show_label=False,
410
+ placeholder="Enter your prompt or add image here",
411
+ container=True,
412
+ elem_id="input_box",
413
+ )
414
+
415
+ with gr.Row() as button_row:
416
+ if random_questions:
417
+ global vqa_samples
418
+ with open(random_questions, "r") as f:
419
+ vqa_samples = json.load(f)
420
+ random_btn = gr.Button(value="🎲 Random Example", interactive=True)
421
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
422
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
423
+ share_btn = gr.Button(value="📷 Share")
424
+
425
+ with gr.Accordion("Parameters", open=False) as parameter_row:
426
+ temperature = gr.Slider(
427
+ minimum=0.0,
428
+ maximum=1.0,
429
+ value=0.7,
430
+ step=0.1,
431
+ interactive=True,
432
+ label="Temperature",
433
+ )
434
+ top_p = gr.Slider(
435
+ minimum=0.0,
436
+ maximum=1.0,
437
+ value=1.0,
438
+ step=0.1,
439
+ interactive=True,
440
+ label="Top P",
441
+ )
442
+ max_output_tokens = gr.Slider(
443
+ minimum=16,
444
+ maximum=2048,
445
+ value=1024,
446
+ step=64,
447
+ interactive=True,
448
+ label="Max output tokens",
449
+ )
450
+
451
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
452
+
453
+ # Register listeners
454
+ btn_list = [
455
+ leftvote_btn,
456
+ rightvote_btn,
457
+ tie_btn,
458
+ bothbad_btn,
459
+ regenerate_btn,
460
+ clear_btn,
461
+ ]
462
+ leftvote_btn.click(
463
+ leftvote_last_response,
464
+ states + model_selectors,
465
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
466
+ )
467
+ rightvote_btn.click(
468
+ rightvote_last_response,
469
+ states + model_selectors,
470
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
471
+ )
472
+ tie_btn.click(
473
+ tievote_last_response,
474
+ states + model_selectors,
475
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
476
+ )
477
+ bothbad_btn.click(
478
+ bothbad_vote_last_response,
479
+ states + model_selectors,
480
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
481
+ )
482
+ regenerate_btn.click(
483
+ regenerate, states, states + chatbots + [textbox] + btn_list
484
+ ).then(
485
+ bot_response_multi,
486
+ states + [temperature, top_p, max_output_tokens],
487
+ states + chatbots + btn_list,
488
+ ).then(
489
+ flash_buttons, [], btn_list
490
+ )
491
+ clear_btn.click(
492
+ clear_history,
493
+ None,
494
+ states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
495
+ )
496
+
497
+ share_js = """
498
+ function (a, b, c, d) {
499
+ const captureElement = document.querySelector('#share-region-named');
500
+ html2canvas(captureElement)
501
+ .then(canvas => {
502
+ canvas.style.display = 'none'
503
+ document.body.appendChild(canvas)
504
+ return canvas
505
+ })
506
+ .then(canvas => {
507
+ const image = canvas.toDataURL('image/png')
508
+ const a = document.createElement('a')
509
+ a.setAttribute('download', 'chatbot-arena.png')
510
+ a.setAttribute('href', image)
511
+ a.click()
512
+ canvas.remove()
513
+ });
514
+ return [a, b, c, d];
515
+ }
516
+ """
517
+ share_btn.click(share_click, states + model_selectors, [], js=share_js)
518
+
519
+ for i in range(num_sides):
520
+ model_selectors[i].change(
521
+ clear_history,
522
+ None,
523
+ states + chatbots + [multimodal_textbox,
524
+ textbox, send_btn] + btn_list,
525
+ ).then(set_visible_image, [multimodal_textbox], [image_column])
526
+
527
+ multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then(
528
+ set_visible_image, [multimodal_textbox], [image_column]
529
+ ).then(
530
+ clear_history_example,
531
+ None,
532
+ states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
533
+ )
534
+
535
+ multimodal_textbox.submit(
536
+ add_text,
537
+ states + model_selectors + [multimodal_textbox, context_state],
538
+ states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
539
+ ).then(set_invisible_image, [], [image_column]).then(
540
+ bot_response_multi,
541
+ states + [temperature, top_p, max_output_tokens],
542
+ states + chatbots + btn_list,
543
+ ).then(
544
+ flash_buttons, [], btn_list
545
+ )
546
+
547
+ textbox.submit(
548
+ add_text,
549
+ states + model_selectors + [textbox, context_state],
550
+ states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
551
+ ).then(set_invisible_image, [], [image_column]).then(
552
+ bot_response_multi,
553
+ states + [temperature, top_p, max_output_tokens],
554
+ states + chatbots + btn_list,
555
+ ).then(
556
+ flash_buttons, [], btn_list
557
+ )
558
+
559
+ send_btn.click(
560
+ add_text,
561
+ states + model_selectors + [textbox, context_state],
562
+ states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
563
+ ).then(set_invisible_image, [], [image_column]).then(
564
+ bot_response_multi,
565
+ states + [temperature, top_p, max_output_tokens],
566
+ states + chatbots + btn_list,
567
+ ).then(
568
+ flash_buttons, [], btn_list
569
+ )
570
+
571
+ if random_questions:
572
+ random_btn.click(
573
+ get_vqa_sample, # First, get the VQA sample
574
+ [], # No inputs are passed; vqa_samples is read from the module-level global
575
+ [multimodal_textbox, imagebox], # Outputs are textbox and imagebox
576
+ ).then(set_visible_image, [multimodal_textbox], [image_column]).then(
577
+ clear_history_example,
578
+ None,
579
+ states + chatbots + [multimodal_textbox,
580
+ textbox, send_btn] + btn_list,
581
+ )
582
+
583
+ return states + model_selectors
serve/gradio_global_state.py ADDED
@@ -0,0 +1,12 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+
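+ # Shared, read-only model catalog passed to each tab via gr.State. The "all_*"
+ # lists include anony-only models; the unprefixed lists contain only the
+ # publicly visible ones (see get_model_list in gradio_web_server).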
5
+ @dataclass
6
+ class Context:
7
+ text_models: List[str] = field(default_factory=list)
8
+ all_text_models: List[str] = field(default_factory=list)
9
+ vision_models: List[str] = field(default_factory=list)
10
+ all_vision_models: List[str] = field(default_factory=list)
11
+ models: List[str] = field(default_factory=list)
12
+ all_models: List[str] = field(default_factory=list)
serve/gradio_web_server.py ADDED
@@ -0,0 +1,1051 @@
1
+ """
2
+ The gradio demo server for chatting with a single model.
3
+ """
4
+
5
+ import argparse
6
+ import datetime
7
+ import json
8
+ import os
9
+ import time
10
+ import uuid
11
+ from typing import List
12
+
13
+ import gradio as gr
14
+ import requests
15
+
16
+ from .conversation import Conversation
17
+ from .constants import (
18
+ LOGDIR,
19
+ WORKER_API_TIMEOUT,
20
+ ErrorCode,
21
+ MODERATION_MSG,
22
+ CONVERSATION_LIMIT_MSG,
23
+ RATE_LIMIT_MSG,
24
+ SERVER_ERROR_MSG,
25
+ INPUT_CHAR_LEN_LIMIT,
26
+ CONVERSATION_TURN_LIMIT,
27
+ SESSION_EXPIRATION_TIME,
28
+ SURVEY_LINK,
29
+ )
30
+ # from .model.model_adapter import (
31
+ # get_conversation_template,
32
+ # )
33
+ # from .model.model_registry import get_model_info, model_info
34
+ from .api_provider import get_api_provider_stream_iter
35
+ from .gradio_global_state import Context
36
+ from .remote_logger import get_remote_logger
37
+ from .utils import (
38
+ build_logger,
39
+ get_window_url_params_js,
40
+ get_window_url_params_with_tos_js,
41
+ moderation_filter,
42
+ parse_gradio_auth_creds,
43
+ load_image,
44
+ )
45
+
46
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
47
+
48
+ headers = {"User-Agent": "FastChat Client"}
49
+
50
+ no_change_btn = gr.Button()
51
+ enable_btn = gr.Button(interactive=True, visible=True)
52
+ disable_btn = gr.Button(interactive=False)
53
+ invisible_btn = gr.Button(interactive=False, visible=False)
54
+ enable_text = gr.Textbox(
55
+ interactive=True, visible=True, placeholder="👉 Enter your prompt and press ENTER"
56
+ )
57
+ disable_text = gr.Textbox(
58
+ interactive=False,
59
+ visible=True,
60
+ placeholder='Press "🎲 New Round" to start over👇 (Note: Your vote shapes the leaderboard, please vote RESPONSIBLY!)',
61
+ )
62
+
63
+ controller_url = None
64
+ enable_moderation = False
65
+ use_remote_storage = False
66
+
67
+ acknowledgment_md = """
68
+ ### Terms of Service
69
+ Users must agree to the following terms before using this service:
70
+
71
+ - Do not use the service for illegal, harmful, violent, racist, or sexual purposes.
72
+ - Do not upload any personal information.
73
+ - Conversation data collected through this service may be used to develop future large language models and may also be released, after appropriate masking, under a Creative Commons Attribution (CC-BY) or similar license.
74
+
75
+ """
76
+
77
+ # JSON file format of API-based models:
78
+ # {
79
+ # "gpt-3.5-turbo": {
80
+ # "model_name": "gpt-3.5-turbo",
81
+ # "api_type": "openai",
82
+ # "api_base": "https://api.openai.com/v1",
83
+ # "api_key": "sk-******",
84
+ # "anony_only": false
85
+ # }
86
+ # }
87
+ #
88
+ # - "api_type" can be one of the following: openai, anthropic, gemini, or mistral. For custom APIs, add a new type and implement it accordingly.
89
+ # - "anony_only" indicates whether to display this model in anonymous mode only.
90
+
91
+ api_endpoint_info = {}
92
+
93
+
94
+ class State:
95
+ def __init__(self, model_name, is_vision=False):
96
+ # self.conv = get_conversation_template(model_name)
97
+ self.conv = Conversation(model_name)
98
+ self.conv_id = uuid.uuid4().hex
99
+ self.skip_next = False
100
+ self.model_name = model_name
101
+ self.oai_thread_id = None
102
+ self.is_vision = is_vision
103
+
104
+ # NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes.
105
+ self.has_csam_image = False
106
+
107
+ self.regen_support = True
108
+ if "browsing" in model_name:
109
+ self.regen_support = False
110
+ self.init_system_prompt(self.conv, is_vision)
111
+
112
+ def init_system_prompt(self, conv, is_vision):
113
+ system_prompt = conv.get_system_message(is_vision)
114
+ if len(system_prompt) == 0:
115
+ return
116
+ current_date = datetime.datetime.now().strftime("%Y-%m-%d")
117
+ system_prompt = system_prompt.replace(
118
+ "{{currentDateTime}}", current_date)
119
+
120
+ current_date_v2 = datetime.datetime.now().strftime("%d %b %Y")
121
+ system_prompt = system_prompt.replace(
122
+ "{{currentDateTimev2}}", current_date_v2)
123
+
124
+ current_date_v3 = datetime.datetime.now().strftime("%B %Y")
125
+ system_prompt = system_prompt.replace(
126
+ "{{currentDateTimev3}}", current_date_v3)
127
+ conv.set_system_message(system_prompt)
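+ # The three {{currentDateTime*}} placeholders cover the date formats used by
+ # different providers' default system prompts, e.g. "2024-07-18",
+ # "18 Jul 2024", and "July 2024".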
128
+
129
+ def to_gradio_chatbot(self):
130
+ return self.conv.to_gradio_chatbot()
131
+
132
+ def dict(self):
133
+ base = self.conv.dict()
134
+ base.update(
135
+ {
136
+ "conv_id": self.conv_id,
137
+ "model_name": self.model_name,
138
+ }
139
+ )
140
+
141
+ if self.is_vision:
142
+ base.update({"has_csam_image": self.has_csam_image})
143
+ return base
144
+
145
+
146
+ def set_global_vars(
147
+ controller_url_,
148
+ enable_moderation_,
149
+ use_remote_storage_,
150
+ ):
151
+ global controller_url, enable_moderation, use_remote_storage
152
+ controller_url = controller_url_
153
+ enable_moderation = enable_moderation_
154
+ use_remote_storage = use_remote_storage_
155
+
156
+
157
+ def get_conv_log_filename(is_vision=False, has_csam_image=False):
158
+ t = datetime.datetime.now()
159
+ conv_log_filename = f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json"
160
+ if is_vision and not has_csam_image:
161
+ name = os.path.join(LOGDIR, f"vision-tmp-{conv_log_filename}")
162
+ elif is_vision and has_csam_image:
163
+ name = os.path.join(LOGDIR, f"vision-csam-{conv_log_filename}")
164
+ else:
165
+ name = os.path.join(LOGDIR, conv_log_filename)
166
+
167
+ return name
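+ # Vision logs are routed to separate files: "vision-tmp-" for ordinary image
+ # chats and "vision-csam-" to quarantine conversations whose image tripped
+ # the CSAM detector.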
168
+
169
+
170
+ def get_model_list(controller_url, register_api_endpoint_file, vision_arena):
171
+ global api_endpoint_info
172
+
173
+ # Add models from the controller
174
+ if controller_url:
175
+ ret = requests.post(controller_url + "/refresh_all_workers")
176
+ assert ret.status_code == 200
177
+
178
+ if vision_arena:
179
+ ret = requests.post(controller_url + "/list_multimodal_models")
180
+ models = ret.json()["models"]
181
+ else:
182
+ ret = requests.post(controller_url + "/list_language_models")
183
+ models = ret.json()["models"]
184
+ else:
185
+ models = []
186
+
187
+ # Add models from the API providers
188
+ if register_api_endpoint_file:
189
+ api_endpoint_info = json.load(open(register_api_endpoint_file))
190
+ for mdl, mdl_dict in api_endpoint_info.items():
191
+ mdl_vision = mdl_dict.get("vision-arena", False)
192
+ mdl_text = mdl_dict.get("text-arena", True)
193
+ if vision_arena and mdl_vision:
194
+ models.append(mdl)
195
+ if not vision_arena and mdl_text:
196
+ models.append(mdl)
197
+
198
+ # Remove anonymous models
199
+ models = list(set(models))
200
+ visible_models = models.copy()
201
+ for mdl in models:
202
+ if mdl not in api_endpoint_info:
203
+ continue
204
+ mdl_dict = api_endpoint_info[mdl]
205
+ if mdl_dict["anony_only"]:
206
+ visible_models.remove(mdl)
207
+
208
+ # Sort models (the registry-based priority map is stubbed out here, so this falls back to sorting by model name)
209
+ model_info = list(range(len(models)))
210
+ priority = {k: f"___{i:03d}" for i, k in enumerate(model_info)}
211
+ models.sort(key=lambda x: priority.get(x, x))
212
+ visible_models.sort(key=lambda x: priority.get(x, x))
213
+ logger.info(f"All models: {models}")
214
+ logger.info(f"Visible models: {visible_models}")
215
+ return visible_models, models
216
+
217
+
218
+ def load_demo_single(context: Context, query_params):
219
+ # default to text models
220
+ models = context.text_models
221
+
222
+ selected_model = models[0] if len(models) > 0 else ""
223
+ if "model" in query_params:
224
+ model = query_params["model"]
225
+ if model in models:
226
+ selected_model = model
227
+
228
+ all_models = context.models
229
+
230
+ dropdown_update = gr.Dropdown(
231
+ choices=all_models, value=selected_model, visible=True
232
+ )
233
+ state = None
234
+ return [state, dropdown_update]
235
+
236
+
237
+ def load_demo(url_params, request: gr.Request):
238
+ global models
239
+
240
+ ip = get_ip(request)
241
+ logger.info(f"load_demo. ip: {ip}. params: {url_params}")
242
+
243
+ if args.model_list_mode == "reload":
244
+ models, all_models = get_model_list(
245
+ controller_url, args.register_api_endpoint_file, vision_arena=False
246
+ )
247
+
248
+ return load_demo_single(models, url_params)
249
+
250
+
251
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
252
+ filename = get_conv_log_filename()
253
+ if "llava" in model_selector:
254
+ filename = filename.replace("2024", "vision-tmp-2024")
255
+
256
+ with open(filename, "a") as fout:
257
+ data = {
258
+ "tstamp": round(time.time(), 4),
259
+ "type": vote_type,
260
+ "model": model_selector,
261
+ "state": state.dict(),
262
+ "ip": get_ip(request),
263
+ }
264
+ fout.write(json.dumps(data) + "\n")
265
+ get_remote_logger().log(data)
266
+
267
+
268
+ def upvote_last_response(state, model_selector, request: gr.Request):
269
+ ip = get_ip(request)
270
+ logger.info(f"upvote. ip: {ip}")
271
+ vote_last_response(state, "upvote", model_selector, request)
272
+ return ("",) + (disable_btn,) * 3
273
+
274
+
275
+ def downvote_last_response(state, model_selector, request: gr.Request):
276
+ ip = get_ip(request)
277
+ logger.info(f"downvote. ip: {ip}")
278
+ vote_last_response(state, "downvote", model_selector, request)
279
+ return ("",) + (disable_btn,) * 3
280
+
281
+
282
+ def flag_last_response(state, model_selector, request: gr.Request):
283
+ ip = get_ip(request)
284
+ logger.info(f"flag. ip: {ip}")
285
+ vote_last_response(state, "flag", model_selector, request)
286
+ return ("",) + (disable_btn,) * 3
287
+
288
+
289
+ def regenerate(state, request: gr.Request):
290
+ ip = get_ip(request)
291
+ logger.info(f"regenerate. ip: {ip}")
292
+ if not state.regen_support:
293
+ state.skip_next = True
294
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
295
+ state.conv.update_last_message(None)
296
+ return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
297
+
298
+
299
+ def clear_history(request: gr.Request):
300
+ ip = get_ip(request)
301
+ logger.info(f"clear_history. ip: {ip}")
302
+ state = None
303
+ return (state, [], "") + (disable_btn,) * 5
304
+
305
+
306
+ def get_ip(request: gr.Request):
307
+ if "cf-connecting-ip" in request.headers:
308
+ ip = request.headers["cf-connecting-ip"]
309
+ elif "x-forwarded-for" in request.headers:
310
+ ip = request.headers["x-forwarded-for"]
311
+ if "," in ip:
312
+ ip = ip.split(",")[0]
313
+ else:
314
+ ip = request.client.host
315
+ return ip
316
+
317
+
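The precedence in get_ip is Cloudflare's cf-connecting-ip header, then the first (client-most) hop of x-forwarded-for, then the raw socket peer. A small illustration using reserved documentation addresses:

    # Illustrative only; 203.0.113.0/24 and 198.51.100.0/24 are documentation ranges.
    xff = "203.0.113.7, 198.51.100.2"
    client_ip = xff.split(",")[0]  # -> "203.0.113.7", mirroring get_ip above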
318
+ def add_text(state, model_selector, text, request: gr.Request):
319
+ ip = get_ip(request)
320
+ logger.info(f"add_text. ip: {ip}. len: {len(text)}")
321
+
322
+ if state is None:
323
+ state = State(model_selector)
324
+
325
+ if len(text) <= 0:
326
+ state.skip_next = True
327
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
328
+
329
+ all_conv_text = state.conv.get_prompt()
330
+ all_conv_text = all_conv_text[-2000:] + "\nuser: " + text
331
+ flagged = moderation_filter(all_conv_text, [state.model_name])
332
+ # flagged = moderation_filter(text, [state.model_name])
333
+ if flagged:
334
+ logger.info(f"violate moderation. ip: {ip}. text: {text}")
335
+ # overwrite the original text
336
+ text = MODERATION_MSG
337
+
338
+ if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
339
+ logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
340
+ state.skip_next = True
341
+ return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG, None) + (
342
+ no_change_btn,
343
+ ) * 5
344
+
345
+ text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
346
+ state.conv.append_message(state.conv.roles[0], text)
347
+ state.conv.append_message(state.conv.roles[1], None)
348
+ return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
349
+
350
+
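Note the moderation window in add_text above: only the last 2000 characters of the running prompt, plus the new user turn, are sent to the filter. A quick check of the arithmetic:

    window = ("x" * 5000)[-2000:] + "\nuser: " + "hello"
    assert len(window) == 2000 + len("\nuser: ") + len("hello")  # 2012 characters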
351
+ def model_worker_stream_iter(
352
+ conv,
353
+ model_name,
354
+ worker_addr,
355
+ prompt,
356
+ temperature,
357
+ repetition_penalty,
358
+ top_p,
359
+ max_new_tokens,
360
+ images,
361
+ ):
362
+ # Make requests
363
+ gen_params = {
364
+ "model": model_name,
365
+ "prompt": prompt,
366
+ "temperature": temperature,
367
+ "repetition_penalty": repetition_penalty,
368
+ "top_p": top_p,
369
+ "max_new_tokens": max_new_tokens,
370
+ "stop": conv.stop_str,
371
+ "stop_token_ids": conv.stop_token_ids,
372
+ "echo": False,
373
+ }
374
+
375
+ logger.info(f"==== request ====\n{gen_params}")
376
+
377
+ if len(images) > 0:
378
+ gen_params["images"] = images
379
+
380
+ # Stream output
381
+ response = requests.post(
382
+ worker_addr + "/worker_generate_stream",
383
+ headers=headers,
384
+ json=gen_params,
385
+ stream=True,
386
+ timeout=WORKER_API_TIMEOUT,
387
+ )
388
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
389
+ if chunk:
390
+ data = json.loads(chunk.decode())
391
+ yield data
392
+
393
+
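A minimal consumption sketch for the iterator above; the worker address and model name are hypothetical, and conv is a live Conversation from this module:

    # Each yielded dict carries the full text so far plus an error_code,
    # which is exactly how bot_response below consumes it.
    for data in model_worker_stream_iter(
        conv, "vicuna-7b", "http://localhost:21002", conv.get_prompt(),
        temperature=0.7, repetition_penalty=1.0, top_p=1.0,
        max_new_tokens=512, images=[],
    ):
        if data["error_code"] != 0:
            break
        print(data["text"])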
394
+ def is_limit_reached(model_name, ip):
395
+ monitor_url = "http://localhost:9090"
396
+ try:
397
+ ret = requests.get(
398
+ f"{monitor_url}/is_limit_reached?model={model_name}&user_id={ip}", timeout=1
399
+ )
400
+ obj = ret.json()
401
+ return obj
402
+ except Exception as e:
403
+ logger.info(f"monitor error: {e}")
404
+ return None
405
+
406
+
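The monitor's reply shape is not documented here; the keys below are inferred from how bot_response consumes the result:

    # Hypothetical reply: {"is_limit_reached": true, "reason": "..."}
    ret = is_limit_reached("gpt-4o-mini-2024-07-18", "203.0.113.7")
    if ret is not None and ret["is_limit_reached"]:
        print(ret["reason"])  # surfaced to the user via RATE_LIMIT_MSG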
407
+ def bot_response(
408
+ state,
409
+ temperature,
410
+ top_p,
411
+ max_new_tokens,
412
+ request: gr.Request,
413
+ apply_rate_limit=True,
414
+ use_recommended_config=False,
415
+ ):
416
+ ip = get_ip(request)
417
+ logger.info(f"bot_response. ip: {ip}")
418
+ start_tstamp = time.time()
419
+ temperature = float(temperature)
420
+ top_p = float(top_p)
421
+ max_new_tokens = int(max_new_tokens)
422
+
423
+ if state.skip_next:
424
+ # This generate call is skipped due to invalid inputs
425
+ state.skip_next = False
426
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
427
+ return
428
+
429
+ if apply_rate_limit:
430
+ ret = is_limit_reached(state.model_name, ip)
431
+ if ret is not None and ret["is_limit_reached"]:
432
+ error_msg = RATE_LIMIT_MSG + "\n\n" + ret["reason"]
433
+ logger.info(
434
+ f"rate limit reached. ip: {ip}. error_msg: {ret['reason']}")
435
+ state.conv.update_last_message(error_msg)
436
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
437
+ return
438
+
439
+ conv, model_name = state.conv, state.model_name
440
+ model_api_dict = (
441
+ api_endpoint_info[model_name] if model_name in api_endpoint_info else None
442
+ )
443
+ images = conv.get_images()
444
+
445
+ if model_api_dict is None:
446
+ # Query worker address
447
+ ret = requests.post(
448
+ controller_url + "/get_worker_address", json={"model": model_name}
449
+ )
450
+ worker_addr = ret.json()["address"]
451
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
452
+
453
+ # No available worker
454
+ if worker_addr == "":
455
+ conv.update_last_message(SERVER_ERROR_MSG)
456
+ yield (
457
+ state,
458
+ state.to_gradio_chatbot(),
459
+ disable_btn,
460
+ disable_btn,
461
+ disable_btn,
462
+ enable_btn,
463
+ enable_btn,
464
+ )
465
+ return
466
+
467
+ # Construct prompt.
468
+     # Build the prompt here, before the streaming cursor ("▌") is appended to the conversation.
469
+ prompt = conv.get_prompt()
470
+ # Set repetition_penalty
471
+ if "t5" in model_name:
472
+ repetition_penalty = 1.2
473
+ else:
474
+ repetition_penalty = 1.0
475
+
476
+ stream_iter = model_worker_stream_iter(
477
+ conv,
478
+ model_name,
479
+ worker_addr,
480
+ prompt,
481
+ temperature,
482
+ repetition_penalty,
483
+ top_p,
484
+ max_new_tokens,
485
+ images,
486
+ )
487
+ else:
488
+ # Remove system prompt for API-based models unless specified
489
+ custom_system_prompt = model_api_dict.get(
490
+ "custom_system_prompt", False)
491
+ if not custom_system_prompt:
492
+ conv.set_system_message("")
493
+
494
+ if use_recommended_config:
495
+ recommended_config = model_api_dict.get("recommended_config", None)
496
+ if recommended_config is not None:
497
+ temperature = recommended_config.get(
498
+ "temperature", temperature)
499
+ top_p = recommended_config.get("top_p", top_p)
500
+ max_new_tokens = recommended_config.get(
501
+ "max_new_tokens", max_new_tokens
502
+ )
503
+
504
+ stream_iter = get_api_provider_stream_iter(
505
+ conv,
506
+ model_name,
507
+ model_api_dict,
508
+ temperature,
509
+ top_p,
510
+ max_new_tokens,
511
+ state,
512
+ )
513
+
514
+ html_code = ' <span class="cursor"></span> '
515
+
516
+ # conv.update_last_message("▌")
517
+ conv.update_last_message(html_code)
518
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
519
+
520
+ try:
521
+ data = {"text": ""}
522
+ for i, data in enumerate(stream_iter):
523
+ if data["error_code"] == 0:
524
+ output = data["text"].strip()
525
+ conv.update_last_message(output + "▌")
526
+ # conv.update_last_message(output + html_code)
527
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
528
+ else:
529
+ output = data["text"] + \
530
+ f"\n\n(error_code: {data['error_code']})"
531
+ conv.update_last_message(output)
532
+ yield (state, state.to_gradio_chatbot()) + (
533
+ disable_btn,
534
+ disable_btn,
535
+ disable_btn,
536
+ enable_btn,
537
+ enable_btn,
538
+ )
539
+ return
540
+ output = data["text"].strip()
541
+ conv.update_last_message(output)
542
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
543
+ except requests.exceptions.RequestException as e:
544
+ conv.update_last_message(
545
+ f"{SERVER_ERROR_MSG}\n\n"
546
+ f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
547
+ )
548
+ yield (state, state.to_gradio_chatbot()) + (
549
+ disable_btn,
550
+ disable_btn,
551
+ disable_btn,
552
+ enable_btn,
553
+ enable_btn,
554
+ )
555
+ return
556
+ except Exception as e:
557
+ conv.update_last_message(
558
+ f"{SERVER_ERROR_MSG}\n\n"
559
+ f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
560
+ )
561
+ yield (state, state.to_gradio_chatbot()) + (
562
+ disable_btn,
563
+ disable_btn,
564
+ disable_btn,
565
+ enable_btn,
566
+ enable_btn,
567
+ )
568
+ return
569
+
570
+ finish_tstamp = time.time()
571
+ logger.info(f"{output}")
572
+
573
+ conv.save_new_images(
574
+ has_csam_images=state.has_csam_image, use_remote_storage=use_remote_storage
575
+ )
576
+
577
+ filename = get_conv_log_filename(
578
+ is_vision=state.is_vision, has_csam_image=state.has_csam_image
579
+ )
580
+
581
+ with open(filename, "a") as fout:
582
+ data = {
583
+ "tstamp": round(finish_tstamp, 4),
584
+ "type": "chat",
585
+ "model": model_name,
586
+ "gen_params": {
587
+ "temperature": temperature,
588
+ "top_p": top_p,
589
+ "max_new_tokens": max_new_tokens,
590
+ },
591
+ "start": round(start_tstamp, 4),
592
+ "finish": round(finish_tstamp, 4),
593
+ "state": state.dict(),
594
+ "ip": get_ip(request),
595
+ }
596
+ fout.write(json.dumps(data) + "\n")
597
+ get_remote_logger().log(data)
598
+
599
+
600
+ block_css = """
601
+ .prose {
602
+ font-size: 105% !important;
603
+ }
604
+
605
+ #arena_leaderboard_dataframe table {
606
+ font-size: 105%;
607
+ }
608
+ #full_leaderboard_dataframe table {
609
+ font-size: 105%;
610
+ }
611
+
612
+ .tab-nav button {
613
+ font-size: 18px;
614
+ }
615
+
616
+ .chatbot h1 {
617
+ font-size: 130%;
618
+ }
619
+ .chatbot h2 {
620
+ font-size: 120%;
621
+ }
622
+ .chatbot h3 {
623
+ font-size: 110%;
624
+ }
625
+
626
+ #chatbot .prose {
627
+ font-size: 90% !important;
628
+ }
629
+
630
+ .sponsor-image-about img {
631
+ margin: 0 20px;
632
+ margin-top: 20px;
633
+ height: 40px;
634
+ max-height: 100%;
635
+ width: auto;
636
+ float: left;
637
+ }
638
+
639
+ .cursor {
640
+ display: inline-block;
641
+ width: 7px;
642
+ height: 1em;
643
+ background-color: black;
644
+ vertical-align: middle;
645
+ animation: blink 1s infinite;
646
+ }
647
+
648
+ .dark .cursor {
649
+ display: inline-block;
650
+ width: 7px;
651
+ height: 1em;
652
+ background-color: white;
653
+ vertical-align: middle;
654
+ animation: blink 1s infinite;
655
+ }
656
+
657
+ @keyframes blink {
658
+ 0%, 50% { opacity: 1; }
659
+ 50.1%, 100% { opacity: 0; }
660
+ }
661
+
662
+ .app {
663
+ max-width: 100% !important;
664
+ padding-left: 5% !important;
665
+ padding-right: 5% !important;
666
+ }
667
+
668
+ a {
669
+     color: #1976D2; /* link color: a medium blue */
670
+ text-decoration: none; /* Removes underline from links */
671
+ }
672
+ a:hover {
673
+     color: #63A4FF; /* lighter blue on hover */
674
+ text-decoration: underline; /* Adds underline on hover */
675
+ }
676
+
677
+ .block {
678
+ overflow-y: hidden !important;
679
+ }
680
+ """
681
+
682
+
683
+ # block_css = """
684
+ # #notice_markdown .prose {
685
+ # font-size: 110% !important;
686
+ # }
687
+ # #notice_markdown th {
688
+ # display: none;
689
+ # }
690
+ # #notice_markdown td {
691
+ # padding-top: 6px;
692
+ # padding-bottom: 6px;
693
+ # }
694
+ # #arena_leaderboard_dataframe table {
695
+ # font-size: 110%;
696
+ # }
697
+ # #full_leaderboard_dataframe table {
698
+ # font-size: 110%;
699
+ # }
700
+ # #model_description_markdown {
701
+ # font-size: 110% !important;
702
+ # }
703
+ # #leaderboard_markdown .prose {
704
+ # font-size: 110% !important;
705
+ # }
706
+ # #leaderboard_markdown td {
707
+ # padding-top: 6px;
708
+ # padding-bottom: 6px;
709
+ # }
710
+ # #leaderboard_dataframe td {
711
+ # line-height: 0.1em;
712
+ # }
713
+ # #about_markdown .prose {
714
+ # font-size: 110% !important;
715
+ # }
716
+ # #ack_markdown .prose {
717
+ # font-size: 110% !important;
718
+ # }
719
+ # #chatbot .prose {
720
+ # font-size: 105% !important;
721
+ # }
722
+ # .sponsor-image-about img {
723
+ # margin: 0 20px;
724
+ # margin-top: 20px;
725
+ # height: 40px;
726
+ # max-height: 100%;
727
+ # width: auto;
728
+ # float: left;
729
+ # }
730
+
731
+ # body {
732
+ # --body-text-size: 14px;
733
+ # }
734
+
735
+ # .chatbot h1, h2, h3 {
736
+ # margin-top: 8px; /* Adjust the value as needed */
737
+ # margin-bottom: 0px; /* Adjust the value as needed */
738
+ # padding-bottom: 0px;
739
+ # }
740
+
741
+ # .chatbot h1 {
742
+ # font-size: 130%;
743
+ # }
744
+ # .chatbot h2 {
745
+ # font-size: 120%;
746
+ # }
747
+ # .chatbot h3 {
748
+ # font-size: 110%;
749
+ # }
750
+ # .chatbot p:not(:first-child) {
751
+ # margin-top: 8px;
752
+ # }
753
+
754
+ # .typing {
755
+ # display: inline-block;
756
+ # }
757
+
758
+ # """
759
+
760
+
761
+ def get_model_description_md(models):
762
+     # Descriptions are intentionally disabled in this deployment: the early
+     # return below short-circuits the table-building code that follows,
+     # which is kept for when descriptions are re-enabled.
+     model_description_md = """
763
+ | | | |
764
+ | ---- | ---- | ---- |
765
+ """
766
+     return ""
767
+ ct = 0
768
+ visited = set()
769
+ for i, name in enumerate(models):
770
+ # minfo = ""
771
+ minfo = get_model_info(name)
772
+ if minfo.simple_name in visited:
773
+ continue
774
+ visited.add(minfo.simple_name)
775
+ one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
776
+
777
+ if ct % 3 == 0:
778
+ model_description_md += "|"
779
+ model_description_md += f" {one_model_md} |"
780
+ if ct % 3 == 2:
781
+ model_description_md += "\n"
782
+ ct += 1
783
+ return model_description_md
784
+
785
+
786
+ def build_about():
787
+ about_markdown = """
788
+ # About Us
789
+
790
+ """
791
+ gr.Markdown(about_markdown, elem_id="about_markdown")
792
+
793
+
794
+ def build_single_model_ui(models, add_promotion_links=False):
795
+ promotion = (
796
+ f"""
797
+
798
+ {SURVEY_LINK}
799
+
800
+ ## 👇 Choose any model to chat
801
+ """
802
+ if add_promotion_links
803
+ else ""
804
+ )
805
+
806
+ notice_markdown = f"""
807
+ # 🏔️ Chatbot Arena
808
+ {promotion}
809
+ """
810
+
811
+ state = gr.State()
812
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
813
+
814
+ with gr.Group(elem_id="share-region-named"):
815
+ with gr.Row(elem_id="model_selector_row"):
816
+ model_selector = gr.Dropdown(
817
+ choices=models,
818
+ value=models[0] if len(models) > 0 else "",
819
+ interactive=True,
820
+ show_label=False,
821
+ container=False,
822
+ )
823
+ with gr.Row():
824
+ with gr.Accordion(
825
+ f"🔍 Expand to see the descriptions of {len(models)} models",
826
+ open=False,
827
+ ):
828
+ model_description_md = get_model_description_md(models)
829
+ gr.Markdown(model_description_md,
830
+ elem_id="model_description_markdown")
831
+
832
+ chatbot = gr.Chatbot(
833
+ elem_id="chatbot",
834
+ label="Scroll down and start chatting",
835
+ height=650,
836
+ show_copy_button=True,
837
+ )
838
+ with gr.Row():
839
+ textbox = gr.Textbox(
840
+ show_label=False,
841
+ placeholder="👉 Enter your prompt and press ENTER",
842
+ elem_id="input_box",
843
+ )
844
+ send_btn = gr.Button(value="Send", variant="primary", scale=0)
845
+
846
+ with gr.Row() as button_row:
847
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
848
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
849
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
850
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
851
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
852
+
853
+ with gr.Accordion("Parameters", open=False) as parameter_row:
854
+ temperature = gr.Slider(
855
+ minimum=0.0,
856
+ maximum=1.0,
857
+ value=0.7,
858
+ step=0.1,
859
+ interactive=True,
860
+ label="Temperature",
861
+ )
862
+ top_p = gr.Slider(
863
+ minimum=0.0,
864
+ maximum=1.0,
865
+ value=1.0,
866
+ step=0.1,
867
+ interactive=True,
868
+ label="Top P",
869
+ )
870
+ max_output_tokens = gr.Slider(
871
+ minimum=16,
872
+ maximum=2048,
873
+ value=1024,
874
+ step=64,
875
+ interactive=True,
876
+ label="Max output tokens",
877
+ )
878
+
879
+ if add_promotion_links:
880
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
881
+
882
+ # Register listeners
883
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
884
+ upvote_btn.click(
885
+ upvote_last_response,
886
+ [state, model_selector],
887
+ [textbox, upvote_btn, downvote_btn, flag_btn],
888
+ )
889
+ downvote_btn.click(
890
+ downvote_last_response,
891
+ [state, model_selector],
892
+ [textbox, upvote_btn, downvote_btn, flag_btn],
893
+ )
894
+ flag_btn.click(
895
+ flag_last_response,
896
+ [state, model_selector],
897
+ [textbox, upvote_btn, downvote_btn, flag_btn],
898
+ )
899
+ regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
900
+ bot_response,
901
+ [state, temperature, top_p, max_output_tokens],
902
+ [state, chatbot] + btn_list,
903
+ )
904
+ clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
905
+
906
+ model_selector.change(clear_history, None, [
907
+ state, chatbot, textbox] + btn_list)
908
+
909
+ textbox.submit(
910
+ add_text,
911
+ [state, model_selector, textbox],
912
+ [state, chatbot, textbox] + btn_list,
913
+ ).then(
914
+ bot_response,
915
+ [state, temperature, top_p, max_output_tokens],
916
+ [state, chatbot] + btn_list,
917
+ )
918
+ send_btn.click(
919
+ add_text,
920
+ [state, model_selector, textbox],
921
+ [state, chatbot, textbox] + btn_list,
922
+ ).then(
923
+ bot_response,
924
+ [state, temperature, top_p, max_output_tokens],
925
+ [state, chatbot] + btn_list,
926
+ )
927
+
928
+ return [state, model_selector]
929
+
930
+
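One convention worth calling out in the wiring above: every handler returns or yields its UI updates followed by exactly five button states, in the fixed order of btn_list.

    # (upvote, downvote, flag, regenerate, clear) -- a sketch of the pattern:
    streaming_updates = (disable_btn,) * 5   # while tokens are arriving
    finished_updates = (enable_btn,) * 5     # once a response completes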
931
+ def build_demo(models):
932
+ with gr.Blocks(
933
+ title="Chatbot Arena (formerly LMSYS): Free AI Chat to Compare & Test Best AI Chatbots",
934
+ theme=gr.themes.Default(),
935
+ css=block_css,
936
+ ) as demo:
937
+ url_params = gr.JSON(visible=False)
938
+
939
+ state, model_selector = build_single_model_ui(models)
940
+
941
+ if args.model_list_mode not in ["once", "reload"]:
942
+ raise ValueError(
943
+ f"Unknown model list mode: {args.model_list_mode}")
944
+
945
+ if args.show_terms_of_use:
946
+ load_js = get_window_url_params_with_tos_js
947
+ else:
948
+ load_js = get_window_url_params_js
949
+
950
+ demo.load(
951
+ load_demo,
952
+ [url_params],
953
+ [
954
+ state,
955
+ model_selector,
956
+ ],
957
+ js=load_js,
958
+ )
959
+
960
+ return demo
961
+
962
+
963
+ if __name__ == "__main__":
964
+ parser = argparse.ArgumentParser()
965
+ parser.add_argument("--host", type=str, default="0.0.0.0")
966
+ parser.add_argument("--port", type=int)
967
+ parser.add_argument(
968
+ "--share",
969
+ action="store_true",
970
+ help="Whether to generate a public, shareable link",
971
+ )
972
+ parser.add_argument(
973
+ "--controller-url",
974
+ type=str,
975
+ default="",
976
+ # default="http://localhost:21001",
977
+ help="The address of the controller",
978
+ )
979
+ parser.add_argument(
980
+ "--concurrency-count",
981
+ type=int,
982
+ default=10,
983
+ help="The concurrency count of the gradio queue",
984
+ )
985
+ parser.add_argument(
986
+ "--model-list-mode",
987
+ type=str,
988
+ default="once",
989
+ choices=["once", "reload"],
990
+ help="Whether to load the model list once or reload the model list every time",
991
+ )
992
+ parser.add_argument(
993
+ "--moderate",
994
+ action="store_true",
995
+ help="Enable content moderation to block unsafe inputs",
996
+ )
997
+ parser.add_argument(
998
+ "--show-terms-of-use",
999
+ action="store_true",
1000
+ help="Shows term of use before loading the demo",
1001
+ )
1002
+ parser.add_argument(
1003
+ "--register-api-endpoint-file",
1004
+ type=str,
1005
+ help="Register API-based model endpoints from a JSON file",
1006
+ )
1007
+ parser.add_argument(
1008
+ "--gradio-auth-path",
1009
+ type=str,
1010
+ help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
1011
+ )
1012
+ parser.add_argument(
1013
+ "--gradio-root-path",
1014
+ type=str,
1015
+ help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix",
1016
+ )
1017
+ parser.add_argument(
1018
+ "--use-remote-storage",
1019
+ action="store_true",
1020
+ default=False,
1021
+ help="Uploads image files to google cloud storage if set to true",
1022
+ )
1023
+ args = parser.parse_args()
1024
+ logger.info(f"args: {args}")
1025
+
1026
+ # Set global variables
1027
+ set_global_vars(args.controller_url, args.moderate,
1028
+ args.use_remote_storage)
1029
+ models, all_models = get_model_list(
1030
+ args.controller_url, args.register_api_endpoint_file, vision_arena=False
1031
+ )
1032
+
1033
+ # Set authorization credentials
1034
+ auth = None
1035
+ if args.gradio_auth_path is not None:
1036
+ auth = parse_gradio_auth_creds(args.gradio_auth_path)
1037
+
1038
+ # Launch the demo
1039
+ demo = build_demo(models)
1040
+ demo.queue(
1041
+ default_concurrency_limit=args.concurrency_count,
1042
+ status_update_rate=10,
1043
+ api_open=False,
1044
+ ).launch(
1045
+ server_name=args.host,
1046
+ server_port=args.port,
1047
+ share=args.share,
1048
+ max_threads=200,
1049
+ auth=auth,
1050
+ root_path=args.gradio_root_path,
1051
+ )
serve/remote_logger.py ADDED
@@ -0,0 +1,59 @@
1
+ # A JSON logger that sends data to a remote endpoint.
2
+ # A background daemon thread drains an in-memory queue and POSTs each record.
3
+ import os
4
+ import json
5
+ import requests
6
+ import threading
7
+ import queue
8
+ import logging
9
+
10
+ _global_logger = None
11
+
12
+
13
+ def get_remote_logger():
14
+ global _global_logger
15
+ if _global_logger is None:
16
+ if url := os.environ.get("REMOTE_LOGGER_URL"):
17
+ logging.info(f"Remote logger enabled, sending data to {url}")
18
+ _global_logger = RemoteLogger(url=url)
19
+ else:
20
+ _global_logger = EmptyLogger()
21
+ return _global_logger
22
+
23
+
24
+ class EmptyLogger:
25
+ """Dummy logger that does nothing."""
26
+
27
+ def __init__(self):
28
+ pass
29
+
30
+ def log(self, _data: dict):
31
+ pass
32
+
33
+
34
+ class RemoteLogger:
35
+ """A JSON logger that sends data to remote endpoint."""
36
+
37
+ def __init__(self, url: str):
38
+ self.url = url
39
+
40
+ self.logs = queue.Queue()
41
+ self.thread = threading.Thread(target=self._send_logs, daemon=True)
42
+ self.thread.start()
43
+
44
+ def log(self, data: dict):
45
+ self.logs.put_nowait(data)
46
+
47
+ def _send_logs(self):
48
+ while True:
49
+ data = self.logs.get()
50
+
51
+             # Keep only top-level fields: serialize any nested dict/list/tuple into a JSON string
52
+ for key, value in data.items():
53
+ if isinstance(value, (dict, list, tuple)):
54
+ data[key] = json.dumps(value, ensure_ascii=False)
55
+
56
+ try:
57
+ requests.post(self.url, json=data)
58
+ except Exception:
59
+ logging.exception("Failed to send logs to remote endpoint")
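A hedged usage sketch for the logger above. The endpoint URL is hypothetical and must be set before the first get_remote_logger() call, because the instance is cached in a module-level global:

    import os
    os.environ["REMOTE_LOGGER_URL"] = "https://logs.example.com/ingest"  # hypothetical
    get_remote_logger().log({"type": "chat", "state": {"messages": []}})
    # The nested "state" dict is JSON-stringified by _send_logs before POSTing;
    # without REMOTE_LOGGER_URL the same call is a silent no-op (EmptyLogger).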
serve/utils.py ADDED
@@ -0,0 +1,492 @@
1
+ """
2
+ Common utilities.
3
+ """
4
+ from asyncio import AbstractEventLoop
5
+ from io import BytesIO
6
+ import base64
7
+ import json
8
+ import logging
9
+ import logging.handlers
10
+ import os
11
+ import platform
12
+ import sys
13
+ import time
14
+ from typing import AsyncGenerator, Generator
15
+ import warnings
16
+
17
+ import requests
18
+
19
+ from .constants import LOGDIR
20
+
21
+
22
+ handler = None
23
+ visited_loggers = set()
24
+
25
+
26
+ def build_logger(logger_name, logger_filename):
27
+ global handler
28
+
29
+ formatter = logging.Formatter(
30
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
31
+ datefmt="%Y-%m-%d %H:%M:%S",
32
+ )
33
+
34
+ # Set the format of root handlers
35
+ if not logging.getLogger().handlers:
36
+         if sys.version_info >= (3, 9):
37
+             # Python >= 3.9: pass encoding so logs are written as UTF-8
+             # (this matters mostly on Windows)
38
+             logging.basicConfig(level=logging.INFO, encoding="utf-8")
39
+ else:
40
+ if platform.system() == "Windows":
41
+ warnings.warn(
42
+ "If you are running on Windows, "
43
+ "we recommend you use Python >= 3.9 for UTF-8 encoding."
44
+ )
45
+ logging.basicConfig(level=logging.INFO)
46
+ logging.getLogger().handlers[0].setFormatter(formatter)
47
+
48
+ # Redirect stdout and stderr to loggers
49
+ stdout_logger = logging.getLogger("stdout")
50
+ stdout_logger.setLevel(logging.INFO)
51
+ sl = StreamToLogger(stdout_logger, logging.INFO)
52
+ sys.stdout = sl
53
+
54
+ stderr_logger = logging.getLogger("stderr")
55
+ stderr_logger.setLevel(logging.ERROR)
56
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
57
+ sys.stderr = sl
58
+
59
+ # Get logger
60
+ logger = logging.getLogger(logger_name)
61
+ logger.setLevel(logging.INFO)
62
+
63
+ # Avoid httpx flooding POST logs
64
+ logging.getLogger("httpx").setLevel(logging.WARNING)
65
+
66
+ # if LOGDIR is empty, then don't try output log to local file
67
+ if LOGDIR != "":
68
+ os.makedirs(LOGDIR, exist_ok=True)
69
+ filename = os.path.join(LOGDIR, logger_filename)
70
+ handler = logging.handlers.TimedRotatingFileHandler(
71
+ filename, when="D", utc=True, encoding="utf-8"
72
+ )
73
+ handler.setFormatter(formatter)
74
+
75
+ for l in [stdout_logger, stderr_logger, logger]:
76
+ if l in visited_loggers:
77
+ continue
78
+ visited_loggers.add(l)
79
+ l.addHandler(handler)
80
+
81
+ return logger
82
+
83
+
84
+ class StreamToLogger(object):
85
+ """
86
+ Fake file-like stream object that redirects writes to a logger instance.
87
+ """
88
+
89
+ def __init__(self, logger, log_level=logging.INFO):
90
+ self.terminal = sys.stdout
91
+ self.logger = logger
92
+ self.log_level = log_level
93
+ self.linebuf = ""
94
+
95
+ def __getattr__(self, attr):
96
+ return getattr(self.terminal, attr)
97
+
98
+ def write(self, buf):
99
+ temp_linebuf = self.linebuf + buf
100
+ self.linebuf = ""
101
+ for line in temp_linebuf.splitlines(True):
102
+ # From the io.TextIOWrapper docs:
103
+ # On output, if newline is None, any '\n' characters written
104
+ # are translated to the system default line separator.
105
+ # By default sys.stdout.write() expects '\n' newlines and then
106
+ # translates them so this is still cross platform.
107
+ if line[-1] == "\n":
108
+ encoded_message = line.encode(
109
+ "utf-8", "ignore").decode("utf-8")
110
+ self.logger.log(self.log_level, encoded_message.rstrip())
111
+ else:
112
+ self.linebuf += line
113
+
114
+ def flush(self):
115
+ if self.linebuf != "":
116
+ encoded_message = self.linebuf.encode(
117
+ "utf-8", "ignore").decode("utf-8")
118
+ self.logger.log(self.log_level, encoded_message.rstrip())
119
+ self.linebuf = ""
120
+
121
+
122
+ def disable_torch_init():
123
+ """
124
+ Disable the redundant torch default initialization to accelerate model creation.
125
+ """
126
+ import torch
127
+
128
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
129
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
130
+
131
+
132
+ def get_gpu_memory(max_gpus=None):
133
+ """Get available memory for each GPU."""
134
+ import torch
135
+
136
+ gpu_memory = []
137
+ num_gpus = (
138
+ torch.cuda.device_count()
139
+ if max_gpus is None
140
+ else min(max_gpus, torch.cuda.device_count())
141
+ )
142
+
143
+ for gpu_id in range(num_gpus):
144
+ with torch.cuda.device(gpu_id):
145
+ device = torch.cuda.current_device()
146
+ gpu_properties = torch.cuda.get_device_properties(device)
147
+ total_memory = gpu_properties.total_memory / (1024**3)
148
+ allocated_memory = torch.cuda.memory_allocated() / (1024**3)
149
+ available_memory = total_memory - allocated_memory
150
+ gpu_memory.append(available_memory)
151
+ return gpu_memory
152
+
153
+
154
+ def oai_moderation(text, custom_thresholds=None):
155
+ """
156
+ Check whether the text violates OpenAI moderation API.
157
+ """
158
+ import openai
159
+
160
+ client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
161
+
162
+ # default to true to be conservative
163
+ flagged = True
164
+ MAX_RETRY = 3
165
+ for _ in range(MAX_RETRY):
166
+ try:
167
+ res = client.moderations.create(input=text)
168
+ flagged = res.results[0].flagged
169
+ if custom_thresholds is not None:
170
+ for category, threshold in custom_thresholds.items():
171
+ if getattr(res.results[0].category_scores, category) > threshold:
172
+ flagged = True
173
+ break
174
+ except (openai.OpenAIError, KeyError, IndexError) as e:
175
+ print(f"MODERATION ERROR: {e}\nInput: {text}")
176
+ return flagged
177
+
178
+
179
+ def moderation_filter(text, model_list, do_moderation=False):
180
+ # Apply moderation for below models
181
+ MODEL_KEYWORDS = [
182
+ "claude",
183
+ "gpt",
184
+ "bard",
185
+ "mistral-large",
186
+ "command-r",
187
+ "dbrx",
188
+ "gemini",
189
+ "reka",
190
+ "eureka",
191
+ ]
192
+
193
+ custom_thresholds = {"sexual": 0.3}
194
+ # set a stricter threshold for claude
195
+ for model in model_list:
196
+ if "claude" in model:
197
+ custom_thresholds = {"sexual": 0.2}
198
+
199
+ for keyword in MODEL_KEYWORDS:
200
+ for model in model_list:
201
+ if keyword in model:
202
+ do_moderation = True
203
+ break
204
+
205
+ if do_moderation:
206
+ return oai_moderation(text, custom_thresholds)
207
+ return False
208
+
209
+
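In short: the OpenAI moderation call only fires when one of the keyword models is in play (or do_moderation is forced), and any Claude model tightens the "sexual" threshold from 0.3 to 0.2. A sketch, assuming OPENAI_API_KEY is set:

    flagged = moderation_filter("some user text", ["claude-3-5-sonnet-20240620"])
    # -> oai_moderation("some user text", {"sexual": 0.2})
    flagged = moderation_filter("some user text", ["my-local-model"])
    # -> False, without calling the API: no keyword matched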
210
+ def clean_flant5_ckpt(ckpt_path):
211
+ """
212
+ Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings,
213
+ Use this function to make sure it can be correctly loaded.
214
+ """
215
+ import torch
216
+
217
+ index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
218
+ index_json = json.load(open(index_file, "r"))
219
+
220
+ weightmap = index_json["weight_map"]
221
+
222
+ share_weight_file = weightmap["shared.weight"]
223
+ share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[
224
+ "shared.weight"
225
+ ]
226
+
227
+ for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]:
228
+ weight_file = weightmap[weight_name]
229
+ weight = torch.load(os.path.join(ckpt_path, weight_file))
230
+ weight[weight_name] = share_weight
231
+ torch.save(weight, os.path.join(ckpt_path, weight_file))
232
+
233
+
234
+ def pretty_print_semaphore(semaphore):
235
+ """Print a semaphore in better format."""
236
+ if semaphore is None:
237
+ return "None"
238
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
239
+
240
+
241
+ """A javascript function to get url parameters for the gradio web server."""
242
+ get_window_url_params_js = """
243
+ function() {
244
+ const params = new URLSearchParams(window.location.search);
245
+ url_params = Object.fromEntries(params);
246
+ console.log("url_params", url_params);
247
+ return url_params;
248
+ }
249
+ """
250
+
251
+ get_window_url_params_with_tos_js = """
252
+ function() {
253
+ const params = new URLSearchParams(window.location.search);
254
+ const url_params = Object.fromEntries(params);
255
+ console.log("url_params", url_params);
256
+
257
+ const urlContainsLeaderboard = Object.keys(url_params).some(key => key.toLowerCase().includes("leaderboard"));
258
+ const msg = "Users of this website are required to agree to the following terms:\\n\\nThe service is a research preview. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\\nPlease do not upload any private information.\\nThe service collects user dialogue data, including both text and images, and reserves the right to use it for future AI development and distribute it under a Creative Commons Attribution (CC-BY) or a similar license.";
259
+ if (!urlContainsLeaderboard) {
260
+ if (window.alerted_before) return;
261
+ alert(msg);
262
+ window.alerted_before = true;
263
+ }
264
+ return url_params;
265
+ }
266
+ """
267
+
268
+ alert_js = """
269
+ () => {
270
+ if (window.alerted_before) return;
271
+ const msg = "Users of this website are required to agree to the following terms:\\n\\nThe service is a research preview. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\\nPlease do not upload any private information.\\nThe service collects user dialogue data, including both text and images, and reserves the right to use if for future AI development and distribute it under a Creative Commons Attribution (CC-BY) or a similar license.";
272
+ alert(msg);
273
+ window.alerted_before = true;
274
+ }
275
+ """
276
+
277
+
278
+ def iter_over_async(
279
+ async_gen: AsyncGenerator, event_loop: AbstractEventLoop
280
+ ) -> Generator:
281
+ """
282
+ Convert async generator to sync generator
283
+
284
+ :param async_gen: the AsyncGenerator to convert
285
+ :param event_loop: the event loop to run on
286
+ :returns: Sync generator
287
+ """
288
+ ait = async_gen.__aiter__()
289
+
290
+ async def get_next():
291
+ try:
292
+ obj = await ait.__anext__()
293
+ return False, obj
294
+ except StopAsyncIteration:
295
+ return True, None
296
+
297
+ while True:
298
+ done, obj = event_loop.run_until_complete(get_next())
299
+ if done:
300
+ break
301
+ yield obj
302
+
303
+
304
+ def detect_language(text: str) -> str:
305
+     # Stubbed to Japanese for now; the early return disables the
+     # polyglot-based detection below.
306
+     return "ja"
307
+     """Detect the language of a string."""
308
+ import polyglot # pip3 install polyglot pyicu pycld2
309
+ from polyglot.detect import Detector
310
+ from polyglot.detect.base import logger as polyglot_logger
311
+ import pycld2
312
+
313
+ polyglot_logger.setLevel("ERROR")
314
+
315
+ try:
316
+ lang_code = Detector(text).language.name
317
+ except (pycld2.error, polyglot.detect.base.UnknownLanguage):
318
+ lang_code = "unknown"
319
+ return lang_code
320
+
321
+
322
+ def parse_gradio_auth_creds(filename: str):
323
+ """Parse a username:password file for gradio authorization."""
324
+ gradio_auth_creds = []
325
+ with open(filename, "r", encoding="utf8") as file:
326
+ for line in file.readlines():
327
+ gradio_auth_creds += [x.strip()
328
+ for x in line.split(",") if x.strip()]
329
+ if gradio_auth_creds:
330
+ auth = [tuple(cred.split(":")) for cred in gradio_auth_creds]
331
+ else:
332
+ auth = None
333
+ return auth
334
+
335
+
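For reference, the expected file format for --gradio-auth-path (the users below are made up):

    # contents of e.g. creds.txt:
    #   u1:p1,u2:p2
    #   u3:p3
    auth = parse_gradio_auth_creds("creds.txt")
    # -> [("u1", "p1"), ("u2", "p2"), ("u3", "p3")]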
336
+ def is_partial_stop(output: str, stop_str: str):
337
+ """Check whether the output contains a partial stop str."""
338
+ for i in range(0, min(len(output), len(stop_str))):
339
+ if stop_str.startswith(output[-i:]):
340
+ return True
341
+ return False
342
+
343
+
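Two quick examples of the partial-stop check, which streaming code uses to hold back output that might be the start of a stop string:

    assert is_partial_stop("Hello wor", "world")       # "wor" is a prefix of "world"
    assert not is_partial_stop("Hello", "world")       # no suffix of "Hello" starts "world"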
344
+ def run_cmd(cmd: str):
345
+ """Run a bash command."""
346
+ print(cmd)
347
+ return os.system(cmd)
348
+
349
+
350
+ def is_sentence_complete(output: str):
351
+ """Check whether the output is a complete sentence."""
352
+ end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”")
353
+ return output.endswith(end_symbols)
354
+
355
+
356
+ # Models don't use the same configuration key for determining the maximum
357
+ # sequence length. Store them here so we can sanely check them.
358
+ # NOTE: The ordering here is important. Some models have two of these and we
359
+ # have a preference for which value gets used.
360
+ SEQUENCE_LENGTH_KEYS = [
361
+ "max_position_embeddings",
362
+ "max_sequence_length",
363
+ "seq_length",
364
+ "max_seq_len",
365
+ "model_max_length",
366
+ ]
367
+
368
+
369
+ def get_context_length(config):
370
+ """Get the context length of a model from a huggingface model config."""
371
+ rope_scaling = getattr(config, "rope_scaling", None)
372
+ if rope_scaling:
373
+ rope_scaling_factor = config.rope_scaling["factor"]
374
+ else:
375
+ rope_scaling_factor = 1
376
+
377
+ for key in SEQUENCE_LENGTH_KEYS:
378
+ val = getattr(config, key, None)
379
+ if val is not None:
380
+ return int(rope_scaling_factor * val)
381
+ return 2048
382
+
383
+
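A worked example with made-up config values: the RoPE scaling factor multiplies the first sequence-length key found, so factor 4 on a 2048-token base yields 8192.

    from types import SimpleNamespace
    cfg = SimpleNamespace(rope_scaling={"factor": 4}, max_position_embeddings=2048)
    assert get_context_length(cfg) == 8192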
384
+ def str_to_torch_dtype(dtype: str):
385
+ import torch
386
+
387
+ if dtype is None:
388
+ return None
389
+ elif dtype == "float32":
390
+ return torch.float32
391
+ elif dtype == "float16":
392
+ return torch.float16
393
+ elif dtype == "bfloat16":
394
+ return torch.bfloat16
395
+ else:
396
+ raise ValueError(f"Unrecognized dtype: {dtype}")
397
+
398
+
399
+ def load_image(image_file):
400
+ from PIL import Image
401
+ import requests
402
+
403
+ image = None
404
+
405
+ if image_file.startswith("http://") or image_file.startswith("https://"):
406
+ timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
407
+ response = requests.get(image_file, timeout=timeout)
408
+ image = Image.open(BytesIO(response.content))
409
+ elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
410
+ image = Image.open(image_file)
411
+ elif image_file.startswith("data:"):
412
+ image_file = image_file.split(",")[1]
413
+ image = Image.open(BytesIO(base64.b64decode(image_file)))
414
+ else:
415
+ image = Image.open(BytesIO(base64.b64decode(image_file)))
416
+
417
+ return image
418
+
419
+
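load_image accepts four input shapes; every value below is hypothetical:

    img = load_image("https://example.com/cat.png")      # URL, fetched with REQUEST_TIMEOUT
    img = load_image("photos/cat.png")                   # local path, matched by extension
    img = load_image("data:image/png;base64,iVBOR...")   # data URI (payload truncated here)
    img = load_image("iVBOR...")                         # bare base64 fallback (truncated)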
420
+ def upload_image_file_to_gcs(image, filename):
421
+ from google.cloud import storage
422
+ import io
423
+
424
+ storage_client = storage.Client()
425
+ # upload file to GCS
426
+ bucket = storage_client.get_bucket("arena_service_data")
427
+
428
+ blob = bucket.blob(f"{filename}")
429
+ if not blob.exists():
430
+ buffer = io.BytesIO()
431
+ image.save(buffer, format="PNG")
432
+ buffer.seek(0)
433
+ blob.upload_from_file(buffer, content_type="image/png")
434
+
435
+ return blob.public_url
436
+
437
+
438
+ def get_image_file_from_gcs(filename):
439
+ from google.cloud import storage
440
+
441
+ storage_client = storage.Client()
442
+ bucket = storage_client.get_bucket("arena_service_data")
443
+ blob = bucket.blob(f"{filename}")
444
+ contents = blob.download_as_bytes()
445
+
446
+ return contents
447
+
448
+
449
+ def image_moderation_request(image_bytes, endpoint, api_key):
450
+ headers = {"Content-Type": "image/jpeg",
451
+ "Ocp-Apim-Subscription-Key": api_key}
452
+
453
+ MAX_RETRIES = 3
454
+ for _ in range(MAX_RETRIES):
455
+ response = requests.post(
456
+ endpoint, headers=headers, data=image_bytes).json()
457
+ try:
458
+ if response["Status"]["Code"] == 3000:
459
+ break
460
+         except (KeyError, TypeError):  # malformed/error response; retry after a short pause
461
+ time.sleep(0.5)
462
+ return response
463
+
464
+
465
+ def image_moderation_provider(image, api_type):
466
+ if api_type == "nsfw":
467
+ endpoint = os.environ["AZURE_IMG_MODERATION_ENDPOINT"]
468
+ api_key = os.environ["AZURE_IMG_MODERATION_API_KEY"]
469
+ response = image_moderation_request(image, endpoint, api_key)
470
+ print(response)
471
+ return response["IsImageAdultClassified"]
472
+ elif api_type == "csam":
473
+ endpoint = (
474
+ "https://api.microsoftmoderator.com/photodna/v1.0/Match?enhance=false"
475
+ )
476
+ api_key = os.environ["PHOTODNA_API_KEY"]
477
+ response = image_moderation_request(image, endpoint, api_key)
478
+ return response["IsMatch"]
479
+
480
+
481
+ def image_moderation_filter(image):
482
+ print(f"moderating image")
483
+
484
+ image_bytes = base64.b64decode(image.base64_str)
485
+
486
+ nsfw_flagged = image_moderation_provider(image_bytes, "nsfw")
487
+ csam_flagged = False
488
+
489
+ if nsfw_flagged:
490
+ csam_flagged = image_moderation_provider(image_bytes, "csam")
491
+
492
+ return nsfw_flagged, csam_flagged