Spaces (Running)
a100kh committed on
Commit • 529989d
Parent(s): 264a139
com
- api_endpoints.json +185 -0
- app copy.py +64 -0
- app.py +352 -49
- requirements.txt +7 -1
- serve/api_provider.py +1361 -0
- serve/constants.py +82 -0
- serve/conversation.py +0 -0
- serve/gradio_block_arena_anony.py +654 -0
- serve/gradio_block_arena_named.py +510 -0
- serve/gradio_block_arena_vision.py +508 -0
- serve/gradio_block_arena_vision_anony.py +682 -0
- serve/gradio_block_arena_vision_named.py +583 -0
- serve/gradio_global_state.py +12 -0
- serve/gradio_web_server.py +1051 -0
- serve/remote_logger.py +59 -0
- serve/utils.py +492 -0
api_endpoints.json
ADDED
@@ -0,0 +1,185 @@
{
    "claude-3-5-sonnet-20240620": {
        "model_name": "claude-3-5-sonnet-20240620",
        "api_type": "anthropic",
        "anony_only": false,
        "recommended_config": {
            "temperature": 0.7,
            "top_p": 1.0
        },
        "text-arena": true,
        "vision-arena": false
    },
    "command-r-plus": {
        "model_name": "command-r-plus",
        "api_type": "cohere",
        "anony_only": false,
        "recommended_config": {
            "temperature": 0.7,
            "top_p": 1.0
        },
        "text-arena": true,
        "vision-arena": false
    },
    "deepseek-chat": {
        "model_name": "deepseek-chat",
        "api_type": "openai-custom-deepinfra",
        "api_base": "https://api.deepseek.com/v1",
        "env_api_key": "DEEPSEEK_API_KEY",
        "anony_only": false,
        "recommended_config": {
            "temperature": 0.7,
            "top_p": 1.0
        },
        "text-arena": true,
        "vision-arena": false
    },
    "mistral-large-latest": {
        "model_name": "mistral-large-latest",
        "api_type": "mistral",
        "anony_only": false,
        "recommended_config": {
            "temperature": 0.7,
            "top_p": 1.0
        },
        "text-arena": true,
        "vision-arena": false
    },
    "Qwen/Qwen2.5-72B-Instruct": {
        "model_name": "Qwen/Qwen2.5-72B-Instruct",
        "api_type": "openai-custom-deepinfra",
        "api_base": "https://api.deepinfra.com/v1/openai",
        "env_api_key": "DEEPINFRA_API_KEY",
        "anony_only": false,
        "recommended_config": {
            "temperature": 0.7,
            "top_p": 1.0
        },
        "text-arena": true,
        "vision-arena": false
    },
    "google/gemma-2-27b-it": {
        "model_name": "google/gemma-2-27b-it",
        "api_type": "openai-custom-deepinfra",
        "api_base": "https://api.deepinfra.com/v1/openai",
        "env_api_key": "DEEPINFRA_API_KEY",
        "anony_only": false,
        "recommended_config": {
            "temperature": 0.7,
            "top_p": 1.0
        },
        "text-arena": true,
        "vision-arena": false
    },
    "gemini-1.5-flash-latest": {
        "model_name": "gemini-1.5-flash-latest",
        "api_type": "gemini",
        "anony_only": false,
        "recommended_config": {
            "temperature": 0.7,
            "top_p": 1.0
        },
        "text-arena": true,
        "vision-arena": false
    },
    "gemini-1.5-pro-latest": {
        "model_name": "gemini-1.5-pro-latest",
        "api_type": "gemini",
        "anony_only": false,
        "recommended_config": {
            "temperature": 0.7,
            "top_p": 1.0
        },
        "text-arena": true,
        "vision-arena": false
    },
    "gpt-4-turbo-2024-04-09": {
        "model_name": "gpt-4-turbo-2024-04-09",
        "api_type": "openai",
        "api_base": "https://api.openai.com/v1",
        "anony_only": false,
        "recommended_config": {
            "temperature": 0.7,
            "top_p": 1.0
        },
        "text-arena": true,
        "vision-arena": false
    },
    "gpt-4o-mini-2024-07-18": {
        "model_name": "gpt-4o-mini-2024-07-18",
        "api_type": "openai",
        "api_base": "https://api.openai.com/v1",
        "anony_only": false,
        "recommended_config": {
            "temperature": 0.7,
            "top_p": 1.0
        },
        "text-arena": true,
        "vision-arena": false
    },
    "tokyotech-llm-Llama-3.1-Swallow-8B-Instruct-v0.1-Q8_0": {
        "model_name": "tokyotech-llm-Llama-3.1-Swallow-8B-Instruct-v0.1-Q8_0",
        "api_type": "openai-llama3.1",
        "api_base": "http://localhost:8010/v1",
        "api_key": "12345",
        "anony_only": false,
        "recommended_config": {
            "temperature": 0.7,
            "top_p": 1.0
        },
        "text-arena": true,
        "vision-arena": false
    },
    "cyberagent/calm3-22b-chat-BitsAndBytes": {
        "model_name": "cyberagent/calm3-22b-chat",
        "api_type": "openai-custom-calm",
        "api_base": "http://localhost:8011/v1",
        "api_key": "12345",
        "anony_only": false,
        "recommended_config": {
            "temperature": 0.7,
            "top_p": 1.0
        },
        "text-arena": true,
        "vision-arena": false
    },
    "weblab-GENIAC/Tanuki-8B-dpo-v1.0-BitsAndBytes": {
        "model_name": "weblab-GENIAC/Tanuki-8B-dpo-v1.0",
        "api_type": "openai-custom-tanuki",
        "api_base": "http://localhost:8012/v1",
        "api_key": "12345",
        "anony_only": false,
        "recommended_config": {
            "temperature": 0.7,
            "top_p": 1.0
        },
        "text-arena": true,
        "vision-arena": false
    },
    "llm-jp-3-13b-instruct-Q8_0.gguf": {
        "model_name": "llm-jp-3-13b-instruct-Q8_0.gguf",
        "api_type": "openai-llmjp3",
        "api_base": "http://localhost:8016/v1",
        "api_key": "12345",
        "anony_only": false,
        "recommended_config": {
            "temperature": 0.7,
            "top_p": 1.0
        },
        "text-arena": true,
        "vision-arena": false
    },
    "tokyotech-llm/Llama-3.1-Swallow-70B-Instruct-v0.1-BitsAndBytes": {
        "model_name": "tokyotech-llm/Llama-3.1-Swallow-70B-Instruct-v0.1",
        "api_type": "openai-llama3.1",
        "api_base": "http://localhost:8019/v1",
        "api_key": "12345",
        "anony_only": false,
        "recommended_config": {
            "temperature": 0.7,
            "top_p": 1.0
        },
        "text-arena": true,
        "vision-arena": false
    }
}
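Each entry maps an arena display name to a backend: "api_type" selects the provider branch in serve/api_provider.py, and the credential comes either inline ("api_key", used by the local servers on localhost) or from an environment variable ("env_api_key"). A minimal sketch of loading such a registry and resolving keys; the load_endpoints helper is illustrative, not part of this commit:

import json
import os

def load_endpoints(path="api_endpoints.json"):
    # Hypothetical helper: read the registry and resolve each model's API key.
    with open(path) as f:
        registry = json.load(f)
    for name, cfg in registry.items():
        if "api_key" in cfg:
            cfg["resolved_api_key"] = cfg["api_key"]  # inline key (local servers)
        elif "env_api_key" in cfg:
            cfg["resolved_api_key"] = os.environ.get(cfg["env_api_key"])
        else:
            cfg["resolved_api_key"] = None  # provider SDK reads its own env var
    return registry

endpoints = load_endpoints()
print(sorted(endpoints))  # the 15 model names registered above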
app copy.py
ADDED
@@ -0,0 +1,64 @@
import gradio as gr
from huggingface_hub import InferenceClient

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    messages = [{"role": "system", "content": system_message}]

    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    response = ""

    for message in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = message.choices[0].delta.content

        if token is not None:  # guard: the final streamed chunk may carry no content
            response += token
        yield response


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch()
app.py
CHANGED
@@ -1,64 +1,367 @@
(removed: lines 1-64 of the previous app.py, the single-model ChatInterface template identical to `app copy.py` above; replaced by the multi-tab demo server below)
"""
The gradio demo server with multiple tabs.
It supports chatting with a single model or chatting with two models side-by-side.
"""

import argparse
from typing import List

import gradio as gr

from serve.gradio_block_arena_anony import (
    build_side_by_side_ui_anony,
    load_demo_side_by_side_anony,
    set_global_vars_anony,
)
from serve.gradio_block_arena_named import (
    build_side_by_side_ui_named,
    load_demo_side_by_side_named,
    set_global_vars_named,
)
from serve.gradio_block_arena_vision import (
    build_single_vision_language_model_ui,
)
from serve.gradio_block_arena_vision_anony import (
    build_side_by_side_vision_ui_anony,
    load_demo_side_by_side_vision_anony,
)
from serve.gradio_block_arena_vision_named import (
    build_side_by_side_vision_ui_named,
    load_demo_side_by_side_vision_named,
)
from serve.gradio_global_state import Context

from serve.gradio_web_server import (
    set_global_vars,
    block_css,
    build_single_model_ui,
    get_model_list,
    load_demo_single,
    get_ip,
)
from serve.utils import (
    build_logger,
    get_window_url_params_js,
    get_window_url_params_with_tos_js,
    alert_js,
    parse_gradio_auth_creds,
)

logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")


def load_demo(context: Context, request: gr.Request):
    ip = get_ip(request)
    logger.info(f"load_demo. ip: {ip}. params: {request.query_params}")

    inner_selected = 0
    if "arena" in request.query_params:
        inner_selected = 0
    elif "vision" in request.query_params:
        inner_selected = 0
    elif "compare" in request.query_params:
        inner_selected = 1
    elif "direct" in request.query_params or "model" in request.query_params:
        inner_selected = 2
    elif "leaderboard" in request.query_params:
        inner_selected = 3
    elif "about" in request.query_params:
        inner_selected = 4

    if args.model_list_mode == "reload":
        context.text_models, context.all_text_models = get_model_list(
            args.controller_url,
            args.register_api_endpoint_file,
            vision_arena=False,
        )

        context.vision_models, context.all_vision_models = get_model_list(
            args.controller_url,
            args.register_api_endpoint_file,
            vision_arena=True,
        )

    # Text models
    if args.vision_arena:
        side_by_side_anony_updates = load_demo_side_by_side_vision_anony()

        side_by_side_named_updates = load_demo_side_by_side_vision_named(
            context,
        )

        direct_chat_updates = load_demo_single(context, request.query_params)
    else:
        direct_chat_updates = load_demo_single(context, request.query_params)
        side_by_side_anony_updates = load_demo_side_by_side_anony(
            context.all_text_models, request.query_params
        )
        side_by_side_named_updates = load_demo_side_by_side_named(
            context.text_models, request.query_params
        )

    tabs_list = (
        [gr.Tabs(selected=inner_selected)]
        + side_by_side_anony_updates
        + side_by_side_named_updates
        + direct_chat_updates
    )

    return tabs_list


def build_demo(
    context: Context, elo_results_file: str, leaderboard_table_file, arena_hard_table
):
    if args.show_terms_of_use:
        load_js = get_window_url_params_with_tos_js
    else:
        load_js = get_window_url_params_js

    head_js = """
<script src="https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.min.js"></script>
"""
    if args.ga_id is not None:
        head_js += f"""
<script async src="https://www.googletagmanager.com/gtag/js?id={args.ga_id}"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){{dataLayer.push(arguments);}}
gtag('js', new Date());

gtag('config', '{args.ga_id}');
window.__gradio_mode__ = "app";
</script>
"""

    # head_js = """"""
    text_size = gr.themes.sizes.text_lg
    with gr.Blocks(
        title="Chatbot Arena 日本語版α",
        theme=gr.themes.Default(text_size=text_size),
        css=block_css,
        head=head_js,
    ) as demo:
        with gr.Tabs() as inner_tabs:
            if args.vision_arena:
                with gr.Tab("⚔️ Arena (battle)", id=0) as arena_tab:
                    arena_tab.select(None, None, None, js=load_js)
                    side_by_side_anony_list = build_side_by_side_vision_ui_anony(
                        context,
                        random_questions=args.random_questions,
                    )
                with gr.Tab("⚔️ Arena (side-by-side)", id=1) as side_by_side_tab:
                    side_by_side_tab.select(None, None, None, js=alert_js)
                    side_by_side_named_list = build_side_by_side_vision_ui_named(
                        context, random_questions=args.random_questions
                    )

                with gr.Tab("💬 Direct Chat", id=2) as direct_tab:
                    direct_tab.select(None, None, None, js=alert_js)
                    single_model_list = build_single_vision_language_model_ui(
                        context,
                        add_promotion_links=True,
                        random_questions=args.random_questions,
                    )

            else:
                with gr.Tab("⚔️ Arena (battle)", id=0) as arena_tab:
                    arena_tab.select(None, None, None, js=load_js)
                    side_by_side_anony_list = build_side_by_side_ui_anony(
                        context.all_text_models
                    )

                with gr.Tab("⚔️ Arena (side-by-side)", id=1) as side_by_side_tab:
                    side_by_side_tab.select(None, None, None, js=alert_js)
                    side_by_side_named_list = build_side_by_side_ui_named(
                        context.text_models
                    )

                with gr.Tab("💬 Direct Chat", id=2) as direct_tab:
                    direct_tab.select(None, None, None, js=alert_js)
                    single_model_list = build_single_model_ui(
                        context.text_models, add_promotion_links=True
                    )

        demo_tabs = (
            [inner_tabs]
            + side_by_side_anony_list
            + side_by_side_named_list
            + single_model_list
        )

        # if elo_results_file:
        #     with gr.Tab("🏆 Leaderboard", id=3):
        #         build_leaderboard_tab(
        #             elo_results_file,
        #             leaderboard_table_file,
        #             arena_hard_table,
        #             show_plot=True,
        #         )

        # with gr.Tab("ℹ️ About Us", id=4):
        #     about = build_about()

        context_state = gr.State(context)
        url_params = gr.JSON(visible=False)

        if args.model_list_mode not in ["once", "reload"]:
            raise ValueError(
                f"Unknown model list mode: {args.model_list_mode}")

        demo.load(
            load_demo,
            [context_state],
            demo_tabs,
            js=load_js,
        )

    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument(
        "--share",
        action="store_true",
        help="Whether to generate a public, shareable link",
    )
    parser.add_argument(
        "--controller-url",
        type=str,
        default="http://localhost:21001",
        help="The address of the controller",
    )
    parser.add_argument(
        "--concurrency-count",
        type=int,
        default=10,
        help="The concurrency count of the gradio queue",
    )
    parser.add_argument(
        "--model-list-mode",
        type=str,
        default="once",
        choices=["once", "reload"],
        help="Whether to load the model list once or reload the model list every time.",
    )
    parser.add_argument(
        "--moderate",
        action="store_true",
        help="Enable content moderation to block unsafe inputs",
    )
    parser.add_argument(
        "--show-terms-of-use",
        action="store_true",
        help="Shows term of use before loading the demo",
    )
    parser.add_argument(
        "--vision-arena", action="store_true", help="Show tabs for vision arena."
    )
    parser.add_argument(
        "--random-questions", type=str, help="Load random questions from a JSON file"
    )
    parser.add_argument(
        "--register-api-endpoint-file",
        type=str,
        help="Register API-based model endpoints from a JSON file",
        default="api_endpoints.json",
    )
    parser.add_argument(
        "--gradio-auth-path",
        type=str,
        help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
        default=None,
    )
    parser.add_argument(
        "--elo-results-file", type=str, help="Load leaderboard results and plots"
    )
    parser.add_argument(
        "--leaderboard-table-file", type=str, help="Load leaderboard results and plots"
    )
    parser.add_argument(
        "--arena-hard-table", type=str, help="Load leaderboard results and plots"
    )
    parser.add_argument(
        "--gradio-root-path",
        type=str,
        help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix",
    )
    parser.add_argument(
        "--ga-id",
        type=str,
        help="the Google Analytics ID",
        default=None,
    )
    parser.add_argument(
        "--use-remote-storage",
        action="store_true",
        default=False,
        help="Uploads image files to google cloud storage if set to true",
    )
    parser.add_argument(
        "--password",
        type=str,
        help="Set the password for the gradio web server",
    )
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # Set global variables
    set_global_vars(args.controller_url, args.moderate,
                    args.use_remote_storage)
    set_global_vars_named(args.moderate)
    set_global_vars_anony(args.moderate)
    text_models, all_text_models = get_model_list(
        args.controller_url,
        args.register_api_endpoint_file,
        vision_arena=False,
    )

    vision_models, all_vision_models = get_model_list(
        args.controller_url,
        args.register_api_endpoint_file,
        vision_arena=True,
    )

    models = text_models + [
        model for model in vision_models if model not in text_models
    ]
    all_models = all_text_models + [
        model for model in all_vision_models if model not in all_text_models
    ]
    context = Context(
        text_models,
        all_text_models,
        vision_models,
        all_vision_models,
        models,
        all_models,
    )

    # Set authorization credentials
    auth = None
    if args.gradio_auth_path is not None:
        auth = parse_gradio_auth_creds(args.gradio_auth_path)

    # Launch the demo
    demo = build_demo(
        context,
        args.elo_results_file,
        args.leaderboard_table_file,
        args.arena_hard_table,
    )
    demo.queue(
        default_concurrency_limit=args.concurrency_count,
        status_update_rate=10,
        api_open=False,
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        max_threads=200,
        auth=auth,
        root_path=args.gradio_root_path,
        show_api=False,
    )
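Since --register-api-endpoint-file defaults to api_endpoints.json, a plausible local launch (the port value is illustrative, not from the commit) is:

python app.py --model-list-mode reload --port 7860

--model-list-mode reload makes load_demo() re-read the endpoint file on every page load, which is the branch guarded by args.model_list_mode == "reload" above.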
requirements.txt
CHANGED
@@ -1 +1,7 @@
huggingface_hub==0.25.2
openai==1.52.0
google-generativeai==0.8.3
mistralai==1.1.0
cohere==5.11.1
anthropic==0.36.2
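These pins cover the provider SDKs that serve/api_provider.py imports lazily inside each stream iterator (import openai, import anthropic, import google.generativeai, and so on), so the install only needs to make them importable:

pip install -r requirements.txt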
serve/api_provider.py
ADDED
@@ -0,0 +1,1361 @@
"""Call API providers."""

import json
import os
import random
import re
from typing import Optional
import time

import requests

from .utils import build_logger


logger = build_logger("gradio_web_server", "gradio_web_server.log")

def get_api_provider_stream_iter(
    conv,
    model_name,
    model_api_dict,
    temperature,
    top_p,
    max_new_tokens,
    state,
):
    if model_api_dict["api_type"] == "openai":
        if model_api_dict.get("vision-arena", False):
            prompt = conv.to_openai_vision_api_messages()
        else:
            prompt = conv.to_openai_api_messages()
        stream_iter = openai_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            # api_key=model_api_dict["api_key"],
        )

    elif model_api_dict["api_type"].find("openai-custom") >= 0:
        if conv.get_system_message() == "":
            if model_api_dict["api_type"] == "openai-custom-tanuki":
                conv.set_system_message('以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。')
            elif model_api_dict["api_type"] == "openai-custom-calm":
                conv.set_system_message('あなたは親切なAIアシスタントです。')
            elif model_api_dict["api_type"] == "openai-custom-deepinfra":
                conv.set_system_message(
                    'あなたは親切な日本語のアシスタントです。')

        if "api_key" in model_api_dict:
            api_key = model_api_dict["api_key"]
        else:
            api_key = os.environ[model_api_dict["env_api_key"]]

        messages = conv.to_openai_api_messages()
        stream_iter = openai_api_stream_iter(
            model_api_dict["model_name"],
            messages,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=api_key,
            # api_key=os.environ[model_api_dict["env_api_key"]],
            # api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "openai-llama3.1":
        if conv.get_system_message() == "":
            conv.set_system_message('あなたは誠実で優秀な日本人のアシスタントです。')

        messages = conv.to_openai_api_messages()
        stream_iter = openai_api_stream_iter(
            model_api_dict["model_name"],
            messages,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
            stop="<|im_end|>",
        )
    elif model_api_dict["api_type"] == "openai-llmjp3":
        if conv.get_system_message() == "":
            conv.set_system_message('以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。')

        messages = conv.to_openai_api_messages()
        stream_iter = openai_api_stream_iter(
            model_api_dict["model_name"],
            messages,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
            stop="<|im_end|>",
        )
    elif model_api_dict["api_type"] == "openai_no_stream":
        prompt = conv.to_openai_api_messages()
        stream_iter = openai_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            # api_key=model_api_dict["api_key"],
            stream=False,
        )
    elif model_api_dict["api_type"] == "openai_o1":
        prompt = conv.to_openai_api_messages()
        stream_iter = openai_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
            is_o1=True,
        )
    elif model_api_dict["api_type"] == "openai_assistant":
        last_prompt = conv.messages[-2][1]
        stream_iter = openai_assistant_api_stream_iter(
            state,
            last_prompt,
            assistant_id=model_api_dict["assistant_id"],
            api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "anthropic":
        if model_api_dict.get("vision-arena", False):
            prompt = conv.to_anthropic_vision_api_messages()
        else:
            prompt = conv.to_openai_api_messages()
        stream_iter = anthropic_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif model_api_dict["api_type"] == "anthropic_message":
        if model_api_dict.get("vision-arena", False):
            prompt = conv.to_anthropic_vision_api_messages()
        else:
            prompt = conv.to_openai_api_messages()
        stream_iter = anthropic_message_api_stream_iter(
            model_api_dict["model_name"], prompt, temperature, top_p, max_new_tokens
        )
    elif model_api_dict["api_type"] == "anthropic_message_vertex":
        if model_api_dict.get("vision-arena", False):
            prompt = conv.to_anthropic_vision_api_messages()
        else:
            prompt = conv.to_openai_api_messages()
        stream_iter = anthropic_message_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            vertex_ai=True,
        )
    elif model_api_dict["api_type"] == "gemini":
        prompt = conv.to_gemini_api_messages()
        stream_iter = gemini_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            # api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "gemini_no_stream":
        prompt = conv.to_gemini_api_messages()
        stream_iter = gemini_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            # api_key=model_api_dict["api_key"],
            use_stream=False,
        )
    elif model_api_dict["api_type"] == "bard":
        prompt = conv.to_openai_api_messages()
        stream_iter = gemini_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            None,  # use Bard's default temperature
            None,  # use Bard's default top_p
            max_new_tokens,
            api_key=(model_api_dict["api_key"] or os.environ["BARD_API_KEY"]),
            use_stream=False,
        )
    elif model_api_dict["api_type"] == "mistral":
        if model_api_dict.get("vision-arena", False):
            prompt = conv.to_openai_vision_api_messages(is_mistral=True)
        else:
            prompt = conv.to_openai_api_messages()
        stream_iter = mistral_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_key=None,
        )
    elif model_api_dict["api_type"] == "nvidia":
        prompt = conv.to_openai_api_messages()
        stream_iter = nvidia_api_stream_iter(
            model_name,
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            model_api_dict["api_base"],
            model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "ai2":
        prompt = conv.to_openai_api_messages()
        stream_iter = ai2_api_stream_iter(
            model_name,
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "vertex":
        prompt = conv.to_vertex_api_messages()
        stream_iter = vertex_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif model_api_dict["api_type"] == "yandexgpt":
        # note: top_p parameter is unused by yandexgpt

        messages = []
        if conv.system_message:
            messages.append({"role": "system", "text": conv.system_message})
        messages += [
            {"role": role, "text": text}
            for role, text in conv.messages
            if text is not None
        ]

        fixed_temperature = model_api_dict.get("fixed_temperature")
        if fixed_temperature is not None:
            temperature = fixed_temperature

        stream_iter = yandexgpt_api_stream_iter(
            model_name=model_api_dict["model_name"],
            messages=messages,
            temperature=temperature,
            max_tokens=max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict.get("api_key"),
            folder_id=model_api_dict.get("folder_id"),
        )
    elif model_api_dict["api_type"] == "cohere":
        messages = conv.to_openai_api_messages()
        stream_iter = cohere_api_stream_iter(
            client_name=model_api_dict.get("client_name", "FastChat"),
            model_id=model_api_dict["model_name"],
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            # api_base=model_api_dict["api_base"],
            # api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "reka":
        messages = conv.to_reka_api_messages()
        stream_iter = reka_api_stream_iter(
            model_name=model_api_dict["model_name"],
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "column":
        if model_api_dict.get("vision-arena", False):
            prompt = conv.to_openai_vision_api_messages()
        else:
            prompt = conv.to_openai_api_messages()
        stream_iter = column_api_stream_iter(
            model_name=model_api_dict["model_name"],
            messages=prompt,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "metagen":
        prompt = conv.to_metagen_api_messages()
        stream_iter = metagen_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )
    else:
        raise NotImplementedError()

    return stream_iter

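Every branch returns a generator that yields dicts with a cumulative "text" field and an "error_code" (0 on success). A minimal consumer sketch, assuming a conv object and one entry from api_endpoints.json; the variable names are illustrative, not part of the commit:

# Illustrative consumer of the dispatcher's stream (not part of the commit).
stream_iter = get_api_provider_stream_iter(
    conv,                      # a conversation object from serve/conversation.py
    "gpt-4o-mini-2024-07-18",
    model_api_dict,            # the matching entry from api_endpoints.json
    temperature=0.7,
    top_p=1.0,
    max_new_tokens=1024,
    state=None,                # only used by the openai_assistant branch
)
final_text = ""
for data in stream_iter:
    if data["error_code"] != 0:
        raise RuntimeError(data["text"])   # provider error message
    final_text = data["text"]              # cumulative text so far
print(final_text)
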
def openai_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_base=None,
    api_key=None,
    stream=True,
    is_o1=False,
    stop="dummy_stop_token123456789",
):
    import openai

    api_key = api_key or os.environ["OPENAI_API_KEY"]

    if "azure" in model_name:
        client = openai.AzureOpenAI(
            api_version="2023-07-01-preview",
            azure_endpoint=api_base or "https://api.openai.com/v1",
            api_key=api_key,
        )
    else:
        client = openai.OpenAI(
            base_url=api_base or "https://api.openai.com/v1",
            api_key=api_key,
            timeout=180,
        )

    # Make requests for logging
    text_messages = []
    for message in messages:
        if type(message["content"]) == str:  # text-only model
            text_messages.append(message)
        else:  # vision model
            filtered_content_list = [
                content for content in message["content"] if content["type"] == "text"
            ]
            text_messages.append(
                {"role": message["role"], "content": filtered_content_list}
            )

    gen_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    if stream and not is_o1:
        res = client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=temperature,
            max_tokens=max_new_tokens,
            stream=True,
            stop=stop,
        )
        text = ""
        for chunk in res:
            if len(chunk.choices) > 0:
                text += chunk.choices[0].delta.content or ""
                data = {
                    "text": text,
                    "error_code": 0,
                }
                yield data
    else:
        if is_o1:
            res = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=1.0,
                stream=False,
            )
        else:
            res = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_new_tokens,
                stream=False,
            )
        text = res.choices[0].message.content
        pos = 0
        while pos < len(text):
            # simulate token streaming
            pos += 2
            time.sleep(0.001)
            data = {
                "text": text[:pos],
                "error_code": 0,
            }
            yield data

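Note the design choice in the non-streaming path: the full completion is fetched in one blocking call and then replayed two characters at a time with a 1 ms sleep, so downstream UI code can treat every provider as a stream. o1-style requests additionally pin temperature to 1.0 and omit max_tokens, presumably because those endpoints reject the usual sampling parameters.
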
def column_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_base=None,
    api_key=None,
):
    try:
        messages_no_img = []
        for msg in messages:
            msg_no_img = msg.copy()
            msg_no_img.pop("attachment", None)
            messages_no_img.append(msg_no_img)

        gen_params = {
            "model": model_name,
            "messages": messages_no_img,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_new_tokens,
            "seed": 42,
        }
        logger.info(f"==== request ====\n{gen_params}")

        gen_params["messages"] = messages
        gen_params["stream"] = True

        # payload.pop("model")

        # try 3 times
        for i in range(3):
            try:
                response = requests.post(
                    api_base, json=gen_params, stream=True, timeout=30
                )
                break
            except Exception as e:
                logger.error(f"==== error ====\n{e}")
                if i == 2:
                    yield {
                        "text": "**API REQUEST ERROR** Reason: API timeout. please try again later.",
                        "error_code": 1,
                    }
                    return

        text = ""
        for line in response.iter_lines():
            if line:
                data = line.decode("utf-8")
                if data.startswith("data:"):
                    data = json.loads(data[6:])["message"]
                    text += data
                    yield {"text": text, "error_code": 0}

    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        yield {
            "text": "**API REQUEST ERROR** Reason: Unknown.",
            "error_code": 1,
        }

def upload_openai_file_to_gcs(file_id):
    import openai
    from google.cloud import storage

    storage_client = storage.Client()

    file = openai.files.content(file_id)
    # upload file to GCS
    bucket = storage_client.get_bucket("arena_user_content")
    blob = bucket.blob(f"{file_id}")
    blob.upload_from_string(file.read())
    blob.make_public()
    return blob.public_url

def openai_assistant_api_stream_iter(
    state,
    prompt,
    assistant_id,
    api_key=None,
):
    import openai
    import base64

    api_key = api_key or os.environ["OPENAI_API_KEY"]
    client = openai.OpenAI(
        base_url="https://api.openai.com/v1", api_key=api_key)

    if state.oai_thread_id is None:
        logger.info("==== create thread ====")
        thread = client.beta.threads.create()
        state.oai_thread_id = thread.id
    logger.info(f"==== thread_id ====\n{state.oai_thread_id}")
    thread_message = client.beta.threads.messages.with_raw_response.create(
        state.oai_thread_id,
        role="user",
        content=prompt,
        timeout=3,
    )
    # logger.info(f"header {thread_message.headers}")
    thread_message = thread_message.parse()
    # Make requests
    gen_params = {
        "assistant_id": assistant_id,
        "thread_id": state.oai_thread_id,
        "message": prompt,
    }
    logger.info(f"==== request ====\n{gen_params}")

    res = requests.post(
        f"https://api.openai.com/v1/threads/{state.oai_thread_id}/runs",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
            "OpenAI-Beta": "assistants=v1",
        },
        json={"assistant_id": assistant_id, "stream": True},
        timeout=30,
        stream=True,
    )

    list_of_text = []
    list_of_raw_text = []
    offset_idx = 0
    full_ret_text = ""
    idx_mapping = {}
    cur_offset = 0
    for line in res.iter_lines():
        if not line:
            continue
        data = line.decode("utf-8")
        # logger.info("data:", data)
        if data.endswith("[DONE]"):
            break
        if data.startswith("event"):
            event = data.split(":")[1].strip()
            if event == "thread.message.completed":
                offset_idx += len(list_of_text)
            continue
        data = json.loads(data[6:])

        if data.get("status") == "failed":
            yield {
                "text": f"**API REQUEST ERROR** Reason: {data['last_error']['message']}",
                "error_code": 1,
            }
            return

        if data.get("status") == "completed":
            logger.info(f"[debug]: {data}")

        if data["object"] != "thread.message.delta":
            continue

        for delta in data["delta"]["content"]:
            text_index = delta["index"] + offset_idx
            if len(list_of_text) <= text_index:
                list_of_text.append("")
                list_of_raw_text.append("")

            text = list_of_text[text_index]
            raw_text = list_of_raw_text[text_index]

            if delta["type"] == "text":
                # text, url_citation or file_path
                content = delta["text"]
                if "annotations" in content and len(content["annotations"]) > 0:
                    annotations = content["annotations"]

                    raw_text_copy = text
                    for anno in annotations:
                        if anno["type"] == "url_citation":
                            pattern = r"【\d+†source】"
                            matches = re.findall(pattern, content["value"])
                            if len(matches) > 0:
                                for match in matches:
                                    print(match)
                                    if match not in idx_mapping:
                                        idx_mapping[match] = len(
                                            idx_mapping) + 1
                                    citation_number = idx_mapping[match]

                            start_idx = anno["start_index"] + cur_offset
                            end_idx = anno["end_index"] + cur_offset
                            url = anno["url_citation"]["url"]

                            citation = f" [[{citation_number}]]({url})"
                            raw_text_copy = (
                                raw_text_copy[:start_idx]
                                + citation
                                + raw_text_copy[end_idx:]
                            )
                            cur_offset += len(citation) - (end_idx - start_idx)
                        elif anno["type"] == "file_path":
                            file_public_url = upload_openai_file_to_gcs(
                                anno["file_path"]["file_id"]
                            )
                            raw_text_copy = raw_text_copy.replace(
                                anno["text"], f"{file_public_url}"
                            )
                    text = raw_text_copy
                else:
                    text_content = content["value"]
                    text += text_content
            elif delta["type"] == "image_file":
                image_public_url = upload_openai_file_to_gcs(
                    delta["image_file"]["file_id"]
                )
                text += f"![image]({image_public_url})"

            list_of_text[text_index] = text
            list_of_raw_text[text_index] = raw_text

        full_ret_text = "\n".join(list_of_text)
        yield {"text": full_ret_text, "error_code": 0}

def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens):
    import anthropic

    c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

    # Make requests
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    res = c.messages.create(
        # res = c.completions.create(
        #     prompt=prompt,
        messages=prompt,
        # stop_sequences=[anthropic.HUMAN_PROMPT],
        # max_tokens_to_sample=max_new_tokens,
        max_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        model=model_name,
        stream=True,
    )
    text = ""
    for chunk in res:

        if hasattr(chunk, 'delta'):
            if hasattr(chunk.delta, 'text'):
                if chunk.delta.text is not None:
                    if isinstance(chunk.delta.text, str):
                        text += chunk.delta.text
                    elif isinstance(chunk.delta.text, list):
                        text += ''.join(chunk.delta.text)
        elif hasattr(chunk, 'message') and chunk.message.content is not None:
            if isinstance(chunk.message.content, str):
                text += chunk.message.content
            elif isinstance(chunk.message.content, list):
                text += ''.join(chunk.message.content)
        else:
            print(chunk)
            continue

        data = {
            "text": text,
            "error_code": 0,
        }
        yield data
    # for chunk in res:
    #     text += chunk.completion
    #     text += chunk.text_stream
    #     data = {
    #         "text": text,
    #         "error_code": 0,
    #     }
    #     yield data

def anthropic_message_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    vertex_ai=False,
):
    import anthropic

    if vertex_ai:
        client = anthropic.AnthropicVertex(
            region=os.environ["GCP_LOCATION"],
            project_id=os.environ["GCP_PROJECT_ID"],
            max_retries=5,
        )
    else:
        client = anthropic.Anthropic(
            api_key=os.environ["ANTHROPIC_API_KEY"],
            max_retries=5,
        )

    text_messages = []
    for message in messages:
        if type(message["content"]) == str:  # text-only model
            text_messages.append(message)
        else:  # vision model
            filtered_content_list = [
                content for content in message["content"] if content["type"] == "text"
            ]
            text_messages.append(
                {"role": message["role"], "content": filtered_content_list}
            )

    # Make requests for logging
    gen_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    system_prompt = ""
    if messages[0]["role"] == "system":
        if type(messages[0]["content"]) == dict:
            system_prompt = messages[0]["content"]["text"]
        elif type(messages[0]["content"]) == str:
            system_prompt = messages[0]["content"]
        # remove system prompt
        messages = messages[1:]

    text = ""
    with client.messages.stream(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_new_tokens,
        messages=messages,
        model=model_name,
        system=system_prompt,
    ) as stream:
        for chunk in stream.text_stream:
            text += chunk
            data = {
                "text": text,
                "error_code": 0,
            }
            yield data

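Unlike the OpenAI-compatible branches, the Anthropic Messages API takes the system prompt as a separate system= argument, so it is peeled off the front of the message list before streaming; stream.text_stream then yields plain text deltas that are accumulated into the same cumulative {text, error_code} dicts as every other provider.
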
763 |
+
|
def gemini_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key=None,
    use_stream=True,
):
    import google.generativeai as genai  # pip install google-generativeai

    if api_key is None:
        api_key = os.environ["GEMINI_API_KEY"]
    genai.configure(api_key=api_key)

    generation_config = {
        "temperature": temperature,
        "max_output_tokens": max_new_tokens,
        "top_p": top_p,
    }
    params = {
        "model": model_name,
        "prompt": messages,
    }
    params.update(generation_config)
    logger.info(f"==== request ====\n{params}")

    safety_settings = [
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
    ]

    history = []
    system_prompt = None
    for message in messages[:-1]:
        if message["role"] == "system":
            system_prompt = message["content"]
            continue
        history.append({"role": message["role"], "parts": message["content"]})

    model = genai.GenerativeModel(
        model_name=model_name,
        system_instruction=system_prompt,
        generation_config=generation_config,
        safety_settings=safety_settings,
    )
    convo = model.start_chat(history=history)

    if use_stream:
        response = convo.send_message(messages[-1]["content"], stream=True)
        try:
            text = ""
            for chunk in response:
                text += chunk.candidates[0].content.parts[0].text
                data = {
                    "text": text,
                    "error_code": 0,
                }
                yield data
        except Exception as e:
            logger.error(f"==== error ====\n{e}")
            # report the block reason from the last received chunk, if any
            reason = chunk.candidates if "chunk" in locals() else e
            yield {
                "text": f"**API REQUEST ERROR** Reason: {reason}.",
                "error_code": 1,
            }
    else:
        try:
            response = convo.send_message(messages[-1]["content"], stream=False)
            text = response.candidates[0].content.parts[0].text
            pos = 0
            while pos < len(text):
                # simulate token streaming
                pos += 5
                time.sleep(0.001)
                data = {
                    "text": text[:pos],
                    "error_code": 0,
                }
                yield data
        except Exception as e:
            logger.error(f"==== error ====\n{e}")
            yield {
                "text": f"**API REQUEST ERROR** Reason: {e}.",
                "error_code": 1,
            }

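When `use_stream=False`, the function above fakes incremental output by replaying the finished reply five characters at a time with a tiny sleep. The same trick in isolation; the step size and delay are arbitrary choices:

# Illustrative sketch, not part of this file.
import time

def simulate_stream(text, step=5, delay=0.001):
    """Yield growing prefixes of a completed reply to mimic token streaming."""
    pos = 0
    while pos < len(text):
        pos += step
        time.sleep(delay)
        yield text[:pos]
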
def ai2_api_stream_iter(
    model_name,
    model_id,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key=None,
    api_base=None,
):
    # get keys and needed values
    ai2_key = api_key or os.environ.get("AI2_API_KEY")
    api_base = api_base or "https://inferd.allen.ai/api/v1/infer"

    # Make requests
    gen_params = {
        "model": model_name,
        "prompt": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # AI2 uses vLLM, which requires that `top_p` be 1.0 for greedy sampling:
    # https://github.com/vllm-project/vllm/blob/v0.1.7/vllm/sampling_params.py#L156-L157
    if temperature == 0.0 and top_p < 1.0:
        raise ValueError("top_p must be 1 when temperature is 0.0")

    res = requests.post(
        api_base,
        stream=True,
        headers={"Authorization": f"Bearer {ai2_key}"},
        json={
            "model_id": model_id,
            # This input format is specific to the Tulu2 model. Other models
            # may require different input formats. See the model's schema
            # documentation on InferD for more information.
            "input": {
                "messages": messages,
                "opts": {
                    "max_tokens": max_new_tokens,
                    "temperature": temperature,
                    "top_p": top_p,
                    "logprobs": 1,  # increase for more choices
                },
            },
        },
        timeout=5,
    )

    if res.status_code != 200:
        logger.error(f"unexpected response ({res.status_code}): {res.text}")
        raise ValueError("unexpected response from InferD", res)

    text = ""
    for line in res.iter_lines():
        if line:
            part = json.loads(line)
            if "result" in part and "output" in part["result"]:
                for t in part["result"]["output"]["text"]:
                    text += t
            else:
                logger.error(f"unexpected part: {part}")
                raise ValueError("empty result in InferD response")

            data = {
                "text": text,
                "error_code": 0,
            }
            yield data

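The InferD endpoint streams newline-delimited JSON, which `requests` exposes line by line via `iter_lines`. A stripped-down sketch of that parsing pattern; the URL, payload, and token are placeholders:

# Illustrative sketch, not part of this file.
import json
import requests

def stream_ndjson(url, payload, token):
    """Yield parsed objects from a newline-delimited JSON response."""
    res = requests.post(
        url,
        json=payload,
        headers={"Authorization": f"Bearer {token}"},
        stream=True,
        timeout=30,
    )
    res.raise_for_status()
    for line in res.iter_lines():
        if line:  # skip keep-alive blank lines
            yield json.loads(line)
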
def mistral_api_stream_iter(
    model_name, messages, temperature, top_p, max_new_tokens, api_key=None
):
    # from mistralai.client import MistralClient
    # from mistralai.models.chat_completion import ChatMessage
    from mistralai import Mistral

    if api_key is None:
        api_key = os.environ["MISTRAL_API_KEY"]

    client = Mistral(api_key=api_key)

    # Make requests for logging
    text_messages = []
    for message in messages:
        if type(message["content"]) == str:  # text-only model
            text_messages.append(message)
        else:  # vision model
            filtered_content_list = [
                content for content in message["content"] if content["type"] == "text"
            ]
            text_messages.append(
                {"role": message["role"], "content": filtered_content_list}
            )

    # Make requests
    gen_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # new_messages = [
    #     ChatMessage(role=message["role"], content=message["content"])
    #     for message in messages
    # ]

    res = client.chat.stream(
        model=model_name,
        temperature=temperature,
        messages=messages,
        max_tokens=max_new_tokens,
        top_p=top_p,
    )

    text = ""
    for chunk in res:
        if chunk.data.choices[0].delta.content is not None:
            text += chunk.data.choices[0].delta.content
            data = {
                "text": text,
                "error_code": 0,
            }
            yield data

def nvidia_api_stream_iter(
    model_name, messages, temp, top_p, max_tokens, api_base, api_key=None
):
    model_2_api = {
        "nemotron-4-340b": "/b0fcd392-e905-4ab4-8eb9-aeae95c30b37",
    }
    api_base += model_2_api[model_name]

    api_key = api_key or os.environ["NVIDIA_API_KEY"]
    headers = {
        "Authorization": f"Bearer {api_key}",
        "accept": "text/event-stream",
        "content-type": "application/json",
    }
    # nvidia api does not accept 0 temperature
    if temp == 0.0:
        temp = 0.000001

    payload = {
        "model": model_name,
        "messages": messages,
        "temperature": temp,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "seed": 42,
        "stream": True,
    }
    logger.info(f"==== request ====\n{payload}")

    # payload.pop("model")

    # try 3 times
    for i in range(3):
        try:
            response = requests.post(
                api_base, headers=headers, json=payload, stream=True, timeout=3
            )
            break
        except Exception as e:
            logger.error(f"==== error ====\n{e}")
            if i == 2:
                yield {
                    "text": "**API REQUEST ERROR** Reason: API timeout. Please try again later.",
                    "error_code": 1,
                }
                return

    text = ""
    for line in response.iter_lines():
        if line:
            data = line.decode("utf-8")
            if data.endswith("[DONE]"):
                break
            data = json.loads(data[6:])["choices"][0]["delta"]["content"]
            text += data
            yield {"text": text, "error_code": 0}

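The NVIDIA endpoint answers with server-sent events, so each payload arrives as a line of the form `data: {...}` and the stream terminates with `data: [DONE]`; the `data[6:]` slice above strips that six-character prefix. Spelled out as a helper, as a sketch that assumes well-formed SSE lines:

# Illustrative sketch, not part of this file.
import json

def parse_sse_line(raw_line):
    """Return the JSON payload of one SSE line, or None at [DONE]."""
    line = raw_line.decode("utf-8")
    if line.endswith("[DONE]"):
        return None
    assert line.startswith("data: ")
    return json.loads(line[len("data: "):])  # len("data: ") == 6
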
def yandexgpt_api_stream_iter(
    model_name, messages, temperature, max_tokens, api_base, api_key, folder_id
):
    api_key = api_key or os.environ["YANDEXGPT_API_KEY"]
    headers = {
        "Authorization": f"Api-Key {api_key}",
        "content-type": "application/json",
    }

    payload = {
        "modelUri": f"gpt://{folder_id}/{model_name}",
        "completionOptions": {
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": True,
        },
        "messages": messages,
    }
    logger.info(f"==== request ====\n{payload}")

    # https://llm.api.cloud.yandex.net/foundationModels/v1/completion
    response = requests.post(
        api_base, headers=headers, json=payload, stream=True, timeout=60
    )
    text = ""
    for line in response.iter_lines():
        if line:
            data = json.loads(line.decode("utf-8"))
            data = data["result"]
            top_alternative = data["alternatives"][0]
            text = top_alternative["message"]["text"]
            yield {"text": text, "error_code": 0}

            status = top_alternative["status"]
            if status in (
                "ALTERNATIVE_STATUS_FINAL",
                "ALTERNATIVE_STATUS_TRUNCATED_FINAL",
            ):
                break

def cohere_api_stream_iter(
    client_name: str,
    model_id: str,
    messages: list,
    temperature: Optional[float] = None,  # The SDK or API handles None for all parameters following
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    api_key: Optional[str] = None,  # default is env var CO_API_KEY
    api_base: Optional[str] = None,
):
    import cohere

    if api_key is None:
        api_key = os.environ["COHERE_API_KEY"]

    OPENAI_TO_COHERE_ROLE_MAP = {
        "user": "User",
        "assistant": "Chatbot",
        "system": "System",
    }

    # client = cohere.ClientV2(
    client = cohere.Client(
        api_key=api_key,
        # base_url=api_base,
        # client_name=client_name,
    )

    # prepare and log requests
    chat_history = [
        dict(
            role=OPENAI_TO_COHERE_ROLE_MAP[message["role"]],
            message=message["content"],
        )
        for message in messages[:-1]
    ]
    actual_prompt = messages[-1]["content"]

    gen_params = {
        "model": model_id,
        "messages": messages,
        "chat_history": chat_history,
        "prompt": actual_prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # make request and stream response
    res = client.chat_stream(
        # messages=messages,
        message=actual_prompt,
        chat_history=chat_history,
        model=model_id,
        temperature=temperature,
        max_tokens=max_new_tokens,
        p=top_p,
    )
    try:
        text = ""
        for streaming_item in res:
            if streaming_item.event_type == "text-generation":
                text += streaming_item.text
                yield {"text": text, "error_code": 0}
    except cohere.core.ApiError as e:
        logger.error(f"==== error from cohere api: {e} ====")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {e}",
            "error_code": 1,
        }

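Cohere's v1 chat API takes the final user turn as `message` and all earlier turns as `chat_history` with `User`/`Chatbot`/`System` roles, which is what the mapping above implements. A self-contained illustration of that conversion; the helper name and sample messages are illustrative:

# Illustrative sketch, not part of this file.
ROLE_MAP = {"user": "User", "assistant": "Chatbot", "system": "System"}

def to_cohere_history(messages):
    """Split OpenAI-style messages into (chat_history, final prompt)."""
    history = [
        {"role": ROLE_MAP[m["role"]], "message": m["content"]}
        for m in messages[:-1]
    ]
    return history, messages[-1]["content"]

history, prompt = to_cohere_history([
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "Summarize our chat."},
])
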
def vertex_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens):
    import vertexai
    from vertexai import generative_models
    from vertexai.generative_models import (
        GenerationConfig,
        GenerativeModel,
        Image,
    )

    project_id = os.environ.get("GCP_PROJECT_ID", None)
    location = os.environ.get("GCP_LOCATION", None)
    vertexai.init(project=project_id, location=location)

    text_messages = []
    for message in messages:
        if type(message) == str:
            text_messages.append(message)

    gen_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    safety_settings = [
        generative_models.SafetySetting(
            category=generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
        ),
        generative_models.SafetySetting(
            category=generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
        ),
        generative_models.SafetySetting(
            category=generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
        ),
        generative_models.SafetySetting(
            category=generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
        ),
    ]
    generator = GenerativeModel(model_name).generate_content(
        messages,
        stream=True,
        generation_config=GenerationConfig(
            top_p=top_p, max_output_tokens=max_new_tokens, temperature=temperature
        ),
        safety_settings=safety_settings,
    )

    ret = ""
    for chunk in generator:
        # NOTE(chris): This may be a vertex api error, below is HOTFIX:
        # https://github.com/googleapis/python-aiplatform/issues/3129
        ret += chunk.candidates[0].content.parts[0]._raw_part.text
        # ret += chunk.text
        data = {
            "text": ret,
            "error_code": 0,
        }
        yield data

def reka_api_stream_iter(
    model_name: str,
    messages: list,
    temperature: Optional[float] = None,  # The SDK or API handles None for all parameters following
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    api_key: Optional[str] = None,  # default is env var REKA_API_KEY
    api_base: Optional[str] = None,
):
    from reka.client import Reka
    from reka import TypedText

    api_key = api_key or os.environ["REKA_API_KEY"]

    client = Reka(api_key=api_key)

    use_search_engine = False
    if "-online" in model_name:
        model_name = model_name.replace("-online", "")
        use_search_engine = True
    request = {
        "model_name": model_name,
        "conversation_history": messages,
        "temperature": temperature,
        "request_output_len": max_new_tokens,
        "runtime_top_p": top_p,
        "stream": True,
        "use_search_engine": use_search_engine,
    }

    # Make requests for logging
    text_messages = []
    for turn in messages:
        for message in turn.content:
            if isinstance(message, TypedText):
                text_messages.append({"type": message.type, "text": message.text})
    logged_request = dict(request)
    logged_request["conversation_history"] = text_messages

    logger.info(f"==== request ====\n{logged_request}")

    response = client.chat.create_stream(
        messages=messages,
        max_tokens=max_new_tokens,
        top_p=top_p,
        model=model_name,
    )

    for chunk in response:
        try:
            yield {"text": chunk.responses[0].chunk.content, "error_code": 0}
        except Exception:
            yield {
                "text": "**API REQUEST ERROR**",
                "error_code": 1,
            }

def metagen_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key,
    api_base,
):
    try:
        text_messages = []
        for message in messages:
            if type(message["content"]) == str:  # text-only model
                text_messages.append(message)
            else:  # vision model
                filtered_content_list = [
                    content
                    for content in message["content"]
                    if content["type"] == "text"
                ]
                text_messages.append(
                    {"role": message["role"], "content": filtered_content_list}
                )
        gen_params = {
            "model": model_name,
            "prompt": text_messages,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_new_tokens,
        }
        logger.info(f"==== request ====\n{gen_params}")

        res = requests.post(
            f"{api_base}/chat_stream_completions?access_token={api_key}",
            stream=True,
            headers={"Content-Type": "application/json"},
            json={
                "model": model_name,
                "chunks_delimited": True,
                "messages": messages,
                "options": {
                    "max_tokens": max_new_tokens,
                    "generation_algorithm": "top_p",
                    "top_p": top_p,
                    "temperature": temperature,
                },
            },
            timeout=30,
        )

        if res.status_code != 200:
            logger.error(f"Unexpected response ({res.status_code}): {res.text}")
            yield {
                "text": "**API REQUEST ERROR** Reason: Unknown.",
                "error_code": 1,
            }

        text = ""
        for line in res.iter_lines():
            if line:
                part = json.loads(line.decode("utf-8"))
                if "text" in part:
                    text += part["text"]
                    data = {
                        "text": text,
                        "error_code": 0,
                    }
                    yield data
    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        yield {
            "text": "**API REQUEST ERROR** Reason: Unknown.",
            "error_code": 1,
        }
serve/constants.py
ADDED
@@ -0,0 +1,82 @@
"""
Global constants.
"""

from enum import IntEnum
import os

REPO_PATH = os.path.dirname(os.path.dirname(__file__))

# Survey Link URL (to be removed) #00729c
SURVEY_LINK = """"""
# SURVEY_LINK = ""

# For the gradio web server
SERVER_ERROR_MSG = (
    "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
)
TEXT_MODERATION_MSG = (
    "$MODERATION$ YOUR TEXT VIOLATES OUR CONTENT MODERATION GUIDELINES."
)
IMAGE_MODERATION_MSG = (
    "$MODERATION$ YOUR IMAGE VIOLATES OUR CONTENT MODERATION GUIDELINES."
)
MODERATION_MSG = "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES."
CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION."
INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE."
SLOW_MODEL_MSG = (
    "⚠️ Models are thinking. Please stay patient as it may take over a minute."
)
RATE_LIMIT_MSG = "**RATE LIMIT OF THIS MODEL IS REACHED. PLEASE COME BACK LATER OR USE <span style='color: red; font-weight: bold;'>[BATTLE MODE](https://lmarena.ai)</span> (the 1st tab).**"
# Maximum input length
INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 12000))
BLIND_MODE_INPUT_CHAR_LEN_LIMIT = int(
    os.getenv("FASTCHAT_BLIND_MODE_INPUT_CHAR_LEN_LIMIT", 30000)
)
# Maximum conversation turns
CONVERSATION_TURN_LIMIT = 50
# Session expiration time
SESSION_EXPIRATION_TIME = 3600
# The output dir of log files
LOGDIR = os.getenv("LOGDIR", ".")
# CPU Instruction Set Architecture
CPU_ISA = os.getenv("CPU_ISA")


# For the controller and workers (could be overwritten through ENV variables.)
CONTROLLER_HEART_BEAT_EXPIRATION = int(
    os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90)
)
WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 45))
WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100))
WORKER_API_EMBEDDING_BATCH_SIZE = int(
    os.getenv("FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE", 4)
)


class ErrorCode(IntEnum):
    """
    https://platform.openai.com/docs/guides/error-codes/api-errors
    """

    VALIDATION_TYPE_ERROR = 40001

    INVALID_AUTH_KEY = 40101
    INCORRECT_AUTH_KEY = 40102
    NO_PERMISSION = 40103

    INVALID_MODEL = 40301
    PARAM_OUT_OF_RANGE = 40302
    CONTEXT_OVERFLOW = 40303

    RATE_LIMIT = 42901
    QUOTA_EXCEEDED = 42902
    ENGINE_OVERLOADED = 42903

    INTERNAL_ERROR = 50001
    CUDA_OUT_OF_MEMORY = 50002
    GRADIO_REQUEST_ERROR = 50003
    GRADIO_STREAM_UNKNOWN_ERROR = 50004
    CONTROLLER_NO_WORKER = 50005
    CONTROLLER_WORKER_TIMEOUT = 50006
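The error codes nest under their HTTP status class (400xx, 401xx, 403xx, 429xx, 500xx), so integer division by 100 recovers the owning status. A quick illustration; the import path assumes this repo's `serve` package layout:

# Illustrative sketch, not part of this file.
from serve.constants import ErrorCode

code = ErrorCode.RATE_LIMIT
assert int(code) == 42901
assert int(code) // 100 == 429  # the HTTP status class it belongs to
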
serve/conversation.py
ADDED
The diff for this file is too large to render.
serve/gradio_block_arena_anony.py
ADDED
@@ -0,0 +1,654 @@
"""
Chatbot Arena (battle) tab.
Users chat with two anonymous models.
"""

import json
import time
import re

import gradio as gr
import numpy as np

from .constants import (
    MODERATION_MSG,
    CONVERSATION_LIMIT_MSG,
    SLOW_MODEL_MSG,
    BLIND_MODE_INPUT_CHAR_LEN_LIMIT,
    CONVERSATION_TURN_LIMIT,
    SURVEY_LINK,
)
from .gradio_block_arena_named import flash_buttons
from .gradio_web_server import (
    State,
    bot_response,
    get_conv_log_filename,
    no_change_btn,
    enable_btn,
    disable_btn,
    invisible_btn,
    enable_text,
    disable_text,
    acknowledgment_md,
    get_ip,
    get_model_description_md,
)
from .remote_logger import get_remote_logger
from .utils import (
    build_logger,
    moderation_filter,
)

logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")

num_sides = 2
enable_moderation = False
anony_names = ["", ""]
models = []


def set_global_vars_anony(enable_moderation_):
    global enable_moderation
    enable_moderation = enable_moderation_


def load_demo_side_by_side_anony(models_, url_params):
    global models
    models = models_

    states = [None] * num_sides
    selector_updates = [
        gr.Markdown(visible=True),
        gr.Markdown(visible=True),
    ]

    return states + selector_updates

def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": [x for x in model_selectors],
            "states": [x.dict() for x in states],
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    get_remote_logger().log(data)

    gr.Info(
        "🎉 Thanks for voting! Your vote shapes the leaderboard, please vote RESPONSIBLY."
    )
    if ":" not in model_selectors[0]:
        for i in range(5):
            names = (
                "### Model A: " + states[0].model_name,
                "### Model B: " + states[1].model_name,
            )
            # yield names + ("",) + (disable_btn,) * 4
            yield names + (disable_text,) + (disable_btn,) * 5
            time.sleep(0.1)
    else:
        names = (
            "### Model A: " + states[0].model_name,
            "### Model B: " + states[1].model_name,
        )
        # yield names + ("",) + (disable_btn,) * 4
        yield names + (disable_text,) + (disable_btn,) * 5


def leftvote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"leftvote (anony). ip: {get_ip(request)}")
    for x in vote_last_response(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    ):
        yield x


def rightvote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"rightvote (anony). ip: {get_ip(request)}")
    for x in vote_last_response(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    ):
        yield x


def tievote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"tievote (anony). ip: {get_ip(request)}")
    for x in vote_last_response(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    ):
        yield x


def bothbad_vote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
    for x in vote_last_response(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    ):
        yield x


def regenerate(state0, state1, request: gr.Request):
    logger.info(f"regenerate (anony). ip: {get_ip(request)}")
    states = [state0, state1]
    if state0.regen_support and state1.regen_support:
        for i in range(num_sides):
            states[i].conv.update_last_message(None)
        return (
            states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6
        )
    states[0].skip_next = True
    states[1].skip_next = True
    return states + [x.to_gradio_chatbot() for x in states] + [""] + [no_change_btn] * 6


def clear_history(request: gr.Request):
    logger.info(f"clear_history (anony). ip: {get_ip(request)}")
    return (
        [None] * num_sides
        + [None] * num_sides
        + anony_names
        + [enable_text]
        + [invisible_btn] * 4
        + [disable_btn] * 2
        + [""]
        + [enable_btn]
    )


def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request):
    logger.info(f"share (anony). ip: {get_ip(request)}")
    if state0 is not None and state1 is not None:
        vote_last_response(
            [state0, state1], "share", [model_selector0, model_selector1], request
        )


SAMPLING_WEIGHTS = {}

# target model sampling weights will be boosted.
BATTLE_TARGETS = {}

BATTLE_STRICT_TARGETS = {}

ANON_MODELS = []

SAMPLING_BOOST_MODELS = []

# outage models won't be sampled.
OUTAGE_MODELS = []


def get_sample_weight(model, outage_models, sampling_weights, sampling_boost_models=[]):
    if model in outage_models:
        return 0
    weight = sampling_weights.get(model, 0)
    if model in sampling_boost_models:
        weight *= 5
    return weight


def is_model_match_pattern(model, patterns):
    flag = False
    for pattern in patterns:
        pattern = pattern.replace("*", ".*")
        if re.match(pattern, model) is not None:
            flag = True
            break
    return flag

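`BATTLE_STRICT_TARGETS` values are glob-style patterns: `is_model_match_pattern` rewrites each `*` to the regex `.*` and matches from the start of the model name via `re.match`. For example, with illustrative model names:

# Illustrative sketch, not part of this file.
assert is_model_match_pattern("gpt-4o-mini", ["gpt-4*"]) is True
assert is_model_match_pattern("claude-3-opus", ["gpt-4*"]) is False
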
def get_battle_pair(
    models, battle_targets, outage_models, sampling_weights, sampling_boost_models
):
    print("models", models)
    print("battle_targets", battle_targets)
    print("outage_models", outage_models)
    print("sampling_weights", sampling_weights)
    print("sampling_boost_models", sampling_boost_models)
    if len(models) == 1:
        return models[0], models[0]

    model_weights = []
    for model in models:
        weight = get_sample_weight(
            model, outage_models, sampling_weights, sampling_boost_models
        )
        if weight == 0:
            weight += 0.01
        model_weights.append(weight)
    total_weight = np.sum(model_weights)
    print("model_weights", model_weights)
    print("total_weight", total_weight)
    model_weights = model_weights / total_weight

    chosen_idx = np.random.choice(len(models), p=model_weights)
    chosen_model = models[chosen_idx]
    # for p, w in zip(models, model_weights):
    #     print(p, w)

    rival_models = []
    rival_weights = []
    for model in models:
        if model == chosen_model:
            continue
        if model in ANON_MODELS and chosen_model in ANON_MODELS:
            continue
        if chosen_model in BATTLE_STRICT_TARGETS:
            if not is_model_match_pattern(model, BATTLE_STRICT_TARGETS[chosen_model]):
                continue
        if model in BATTLE_STRICT_TARGETS:
            if not is_model_match_pattern(chosen_model, BATTLE_STRICT_TARGETS[model]):
                continue
        weight = get_sample_weight(model, outage_models, sampling_weights)
        if (
            weight != 0
            and chosen_model in battle_targets
            and model in battle_targets[chosen_model]
        ):
            # boost this rival's weight so targeted pairings are sampled more often
            weight = 0.5 * total_weight / len(battle_targets[chosen_model])
        rival_models.append(model)
        if weight == 0:
            weight += 0.01
        rival_weights.append(weight)
    # for p, w in zip(rival_models, rival_weights):
    #     print(p, w)
    rival_weights = rival_weights / np.sum(rival_weights)
    rival_idx = np.random.choice(len(rival_models), p=rival_weights)
    rival_model = rival_models[rival_idx]

    swap = np.random.randint(2)
    if swap == 0:
        return chosen_model, rival_model
    else:
        return rival_model, chosen_model

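Stripped of the targeting and outage rules, the pairing logic above reduces to two weighted draws without replacement plus a coin flip for which side is Model A. A minimal sketch of that core, assuming at least two models with positive weights:

# Illustrative sketch, not part of this file.
import numpy as np

def sample_pair(models, weights):
    """Pick two distinct models with probability proportional to weight."""
    p = np.asarray(weights, dtype=float)
    p = p / p.sum()
    first = np.random.choice(len(models), p=p)
    p[first] = 0.0  # the rival is drawn from the remaining models
    p = p / p.sum()
    second = np.random.choice(len(models), p=p)
    pair = [models[first], models[second]]
    np.random.shuffle(pair)  # randomize which side is shown as Model A
    return pair[0], pair[1]
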
def add_text(
    state0, state1, model_selector0, model_selector1, text, request: gr.Request
):
    ip = get_ip(request)
    logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
    states = [state0, state1]
    model_selectors = [model_selector0, model_selector1]

    # Init states if necessary
    if states[0] is None:
        assert states[1] is None

        model_left, model_right = get_battle_pair(
            models,
            BATTLE_TARGETS,
            OUTAGE_MODELS,
            SAMPLING_WEIGHTS,
            SAMPLING_BOOST_MODELS,
        )
        states = [
            State(model_left),
            State(model_right),
        ]

    if len(text) <= 0:
        for i in range(num_sides):
            states[i].skip_next = True
        return (
            states
            + [x.to_gradio_chatbot() for x in states]
            + ["", None]
            + [
                no_change_btn,
            ]
            * 6
            + [""]
        )

    model_list = [states[i].model_name for i in range(num_sides)]
    # turn on moderation in battle mode
    all_conv_text_left = states[0].conv.get_prompt()
    all_conv_text_right = states[1].conv.get_prompt()
    all_conv_text = (
        all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text
    )
    flagged = moderation_filter(all_conv_text, model_list, do_moderation=True)
    if flagged:
        logger.info(f"violate moderation (anony). ip: {ip}. text: {text}")
        # overwrite the original text
        text = MODERATION_MSG

    conv = states[0].conv
    if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
        logger.info(f"conversation turn limit. ip: {get_ip(request)}. text: {text}")
        for i in range(num_sides):
            states[i].skip_next = True
        return (
            states
            + [x.to_gradio_chatbot() for x in states]
            + [CONVERSATION_LIMIT_MSG]
            + [
                no_change_btn,
            ]
            * 6
            + [""]
        )

    text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
    for i in range(num_sides):
        states[i].conv.append_message(states[i].conv.roles[0], text)
        states[i].conv.append_message(states[i].conv.roles[1], None)
        states[i].skip_next = False

    hint_msg = ""
    for i in range(num_sides):
        if "deluxe" in states[i].model_name:
            hint_msg = SLOW_MODEL_MSG
    return (
        states
        + [x.to_gradio_chatbot() for x in states]
        + [""]
        + [
            disable_btn,
        ]
        * 6
        + [hint_msg]
    )

def bot_response_multi(
    state0,
    state1,
    temperature,
    top_p,
    max_new_tokens,
    request: gr.Request,
):
    logger.info(f"bot_response_multi (anony). ip: {get_ip(request)}")

    if state0 is None or state0.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (
            state0,
            state1,
            state0.to_gradio_chatbot(),
            state1.to_gradio_chatbot(),
        ) + (no_change_btn,) * 6
        return

    states = [state0, state1]
    gen = []
    for i in range(num_sides):
        gen.append(
            bot_response(
                states[i],
                temperature,
                top_p,
                max_new_tokens,
                request,
                apply_rate_limit=False,
                use_recommended_config=True,
            )
        )

    model_tpy = []
    for i in range(num_sides):
        token_per_yield = 1
        if states[i].model_name in [
            "gemini-pro",
            "gemma-1.1-2b-it",
            "gemma-1.1-7b-it",
            "phi-3-mini-4k-instruct",
            "phi-3-mini-128k-instruct",
            "snowflake-arctic-instruct",
        ]:
            token_per_yield = 30
        elif states[i].model_name in [
            "qwen-max-0428",
            "qwen-vl-max-0809",
            "qwen1.5-110b-chat",
            "llava-v1.6-34b",
        ]:
            token_per_yield = 7
        elif states[i].model_name in [
            "qwen2.5-72b-instruct",
            "qwen2-72b-instruct",
            "qwen-plus-0828",
            "qwen-max-0919",
            "llama-3.1-405b-instruct-bf16",
        ]:
            token_per_yield = 4
        model_tpy.append(token_per_yield)

    chatbots = [None] * num_sides
    iters = 0
    while True:
        stop = True
        iters += 1
        for i in range(num_sides):
            try:
                # yield fewer times if chunk size is larger
                if model_tpy[i] == 1 or (iters % model_tpy[i] == 1 or iters < 3):
                    ret = next(gen[i])
                    states[i], chatbots[i] = ret[0], ret[1]
                stop = False
            except StopIteration:
                pass
        yield states + chatbots + [disable_btn] * 6
        if stop:
            break

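`bot_response_multi` advances both model generators in lockstep so the two chat panels update together, and only stops once `next()` has raised `StopIteration` on every side. The interleaving pattern in isolation:

# Illustrative sketch, not part of this file.
def interleave(generators):
    """Yield one item per still-active generator, round-robin, until all finish."""
    active = list(generators)
    while active:
        for gen in list(active):
            try:
                yield next(gen)
            except StopIteration:
                active.remove(gen)
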
def build_side_by_side_ui_anony(models):
    notice_markdown = f"""
# ⚔️ Chatbot Arena 日本語版(α版)

{SURVEY_LINK}

## 📣 News
- 評価結果はこちら: [here](https://huggingface.co/datasets/kanhatakeyama/chatbot-arena-ja-elo-rating).
## 👇 Chat now!
"""

    states = [gr.State() for _ in range(num_sides)]
    model_selectors = [None] * num_sides
    chatbots = [None] * num_sides

    gr.Markdown(notice_markdown, elem_id="notice_markdown")

    with gr.Group(elem_id="share-region-anony"):
        with gr.Accordion(
            f"🔍 Expand to see the descriptions of {len(models)} models", open=False
        ):
            model_description_md = get_model_description_md(models)
            gr.Markdown(model_description_md, elem_id="model_description_markdown")
        with gr.Row():
            for i in range(num_sides):
                label = "Model A" if i == 0 else "Model B"
                with gr.Column():
                    chatbots[i] = gr.Chatbot(
                        label=label,
                        elem_id="chatbot",
                        height=650,
                        show_copy_button=True,
                    )

        with gr.Row():
            for i in range(num_sides):
                with gr.Column():
                    model_selectors[i] = gr.Markdown(
                        anony_names[i], elem_id="model_selector_md"
                    )
        with gr.Row():
            slow_warning = gr.Markdown("")

    with gr.Row():
        leftvote_btn = gr.Button(
            value="👈 A is better", visible=False, interactive=False
        )
        rightvote_btn = gr.Button(
            value="👉 B is better", visible=False, interactive=False
        )
        tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
        bothbad_btn = gr.Button(
            value="👎 Both are bad", visible=False, interactive=False
        )

    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your prompt and press ENTER",
            elem_id="input_box",
        )
        send_btn = gr.Button(value="Send", variant="primary", scale=0)

    with gr.Row() as button_row:
        clear_btn = gr.Button(value="🎲 New Round", interactive=False)
        regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
        share_btn = gr.Button(value="📷 Share")

    with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=1.0,
            step=0.1,
            interactive=True,
            label="Top P",
        )
        max_output_tokens = gr.Slider(
            minimum=16,
            maximum=2048,
            value=2000,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    gr.Markdown(acknowledgment_md, elem_id="ack_markdown")

    # Register listeners
    btn_list = [
        leftvote_btn,
        rightvote_btn,
        tie_btn,
        bothbad_btn,
        regenerate_btn,
        clear_btn,
    ]
    leftvote_btn.click(
        leftvote_last_response,
        states + model_selectors,
        model_selectors
        + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, send_btn],
    )
    rightvote_btn.click(
        rightvote_last_response,
        states + model_selectors,
        model_selectors
        + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, send_btn],
    )
    tie_btn.click(
        tievote_last_response,
        states + model_selectors,
        model_selectors
        + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, send_btn],
    )
    bothbad_btn.click(
        bothbad_vote_last_response,
        states + model_selectors,
        model_selectors
        + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, send_btn],
    )
    regenerate_btn.click(
        regenerate, states, states + chatbots + [textbox] + btn_list
    ).then(
        bot_response_multi,
        states + [temperature, top_p, max_output_tokens],
        states + chatbots + btn_list,
    ).then(
        flash_buttons, [], btn_list
    )
    clear_btn.click(
        clear_history,
        None,
        states
        + chatbots
        + model_selectors
        + [textbox]
        + btn_list
        + [slow_warning]
        + [send_btn],
    )

    share_js = """
function (a, b, c, d) {
    const captureElement = document.querySelector('#share-region-anony');
    html2canvas(captureElement)
        .then(canvas => {
            canvas.style.display = 'none'
            document.body.appendChild(canvas)
            return canvas
        })
        .then(canvas => {
            const image = canvas.toDataURL('image/png')
            const a = document.createElement('a')
            a.setAttribute('download', 'chatbot-arena.png')
            a.setAttribute('href', image)
            a.click()
            canvas.remove()
        });
    return [a, b, c, d];
}
"""
    share_btn.click(share_click, states + model_selectors, [], js=share_js)

    textbox.submit(
        add_text,
        states + model_selectors + [textbox],
        states + chatbots + [textbox] + btn_list + [slow_warning],
    ).then(
        bot_response_multi,
        states + [temperature, top_p, max_output_tokens],
        states + chatbots + btn_list,
    ).then(
        flash_buttons,
        [],
        btn_list,
    )

    send_btn.click(
        add_text,
        states + model_selectors + [textbox],
        states + chatbots + [textbox] + btn_list,
    ).then(
        bot_response_multi,
        states + [temperature, top_p, max_output_tokens],
        states + chatbots + btn_list,
    ).then(
        flash_buttons, [], btn_list
    )

    return states + model_selectors
serve/gradio_block_arena_named.py
ADDED
@@ -0,0 +1,510 @@
"""
Chatbot Arena (side-by-side) tab.
Users chat with two chosen models.
"""

import json
import time

import gradio as gr
import numpy as np

from .constants import (
    MODERATION_MSG,
    CONVERSATION_LIMIT_MSG,
    INPUT_CHAR_LEN_LIMIT,
    CONVERSATION_TURN_LIMIT,
    SURVEY_LINK,
)
from .gradio_web_server import (
    State,
    bot_response,
    get_conv_log_filename,
    no_change_btn,
    enable_btn,
    disable_btn,
    invisible_btn,
    acknowledgment_md,
    get_ip,
    get_model_description_md,
)
from .remote_logger import get_remote_logger
from .utils import (
    build_logger,
    moderation_filter,
)

logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")

num_sides = 2
enable_moderation = False


def set_global_vars_named(enable_moderation_):
    global enable_moderation
    enable_moderation = enable_moderation_


def load_demo_side_by_side_named(models, url_params):
    states = [None] * num_sides

    model_left = models[0] if len(models) > 0 else ""
    if len(models) > 1:
        weights = ([8] * 4 + [4] * 8 + [1] * 64)[: len(models) - 1]
        weights = weights / np.sum(weights)
        model_right = np.random.choice(models[1:], p=weights)
    else:
        model_right = model_left

    selector_updates = [
        gr.Dropdown(choices=models, value=model_left, visible=True),
        gr.Dropdown(choices=models, value=model_right, visible=True),
    ]

    return states + selector_updates

def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": [x for x in model_selectors],
            "states": [x.dict() for x in states],
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    get_remote_logger().log(data)


def leftvote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"leftvote (named). ip: {get_ip(request)}")
    vote_last_response(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    return ("",) + (disable_btn,) * 4


def rightvote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"rightvote (named). ip: {get_ip(request)}")
    vote_last_response(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    return ("",) + (disable_btn,) * 4


def tievote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"tievote (named). ip: {get_ip(request)}")
    vote_last_response(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    return ("",) + (disable_btn,) * 4


def bothbad_vote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
    vote_last_response(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    return ("",) + (disable_btn,) * 4


def regenerate(state0, state1, request: gr.Request):
    logger.info(f"regenerate (named). ip: {get_ip(request)}")
    states = [state0, state1]
    if state0.regen_support and state1.regen_support:
        for i in range(num_sides):
            states[i].conv.update_last_message(None)
        return (
            states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6
        )
    states[0].skip_next = True
    states[1].skip_next = True
    return states + [x.to_gradio_chatbot() for x in states] + [""] + [no_change_btn] * 6


def clear_history(request: gr.Request):
    logger.info(f"clear_history (named). ip: {get_ip(request)}")
    return (
        [None] * num_sides
        + [None] * num_sides
        + [""]
        + [invisible_btn] * 4
        + [disable_btn] * 2
    )


def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request):
    logger.info(f"share (named). ip: {get_ip(request)}")
    if state0 is not None and state1 is not None:
        vote_last_response(
            [state0, state1], "share", [model_selector0, model_selector1], request
        )

def add_text(
    state0, state1, model_selector0, model_selector1, text, request: gr.Request
):
    ip = get_ip(request)
    logger.info(f"add_text (named). ip: {ip}. len: {len(text)}")
    states = [state0, state1]
    model_selectors = [model_selector0, model_selector1]

    # Init states if necessary
    for i in range(num_sides):
        if states[i] is None:
            states[i] = State(model_selectors[i])

    if len(text) <= 0:
        for i in range(num_sides):
            states[i].skip_next = True
        return (
            states
            + [x.to_gradio_chatbot() for x in states]
            + ["", None]
            + [
                no_change_btn,
            ]
            * 6
        )

    model_list = [states[i].model_name for i in range(num_sides)]
    all_conv_text_left = states[0].conv.get_prompt()
    all_conv_text_right = states[1].conv.get_prompt()
    all_conv_text = (
        all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text
    )
    flagged = moderation_filter(all_conv_text, model_list)
    if flagged:
        logger.info(f"violate moderation (named). ip: {ip}. text: {text}")
        # overwrite the original text
        text = MODERATION_MSG

    conv = states[0].conv
    if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
        logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
        for i in range(num_sides):
            states[i].skip_next = True
        return (
            states
            + [x.to_gradio_chatbot() for x in states]
            + [CONVERSATION_LIMIT_MSG]
            + [
                no_change_btn,
            ]
            * 6
        )

    text = text[:INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
    for i in range(num_sides):
        states[i].conv.append_message(states[i].conv.roles[0], text)
        states[i].conv.append_message(states[i].conv.roles[1], None)
        states[i].skip_next = False

    return (
        states
        + [x.to_gradio_chatbot() for x in states]
        + [""]
        + [
            disable_btn,
        ]
        * 6
    )

227 |
+
def bot_response_multi(
|
228 |
+
state0,
|
229 |
+
state1,
|
230 |
+
temperature,
|
231 |
+
top_p,
|
232 |
+
max_new_tokens,
|
233 |
+
request: gr.Request,
|
234 |
+
):
|
235 |
+
logger.info(f"bot_response_multi (named). ip: {get_ip(request)}")
|
236 |
+
|
237 |
+
if state0.skip_next:
|
238 |
+
# This generate call is skipped due to invalid inputs
|
239 |
+
yield (
|
240 |
+
state0,
|
241 |
+
state1,
|
242 |
+
state0.to_gradio_chatbot(),
|
243 |
+
state1.to_gradio_chatbot(),
|
244 |
+
) + (no_change_btn,) * 6
|
245 |
+
return
|
246 |
+
|
247 |
+
states = [state0, state1]
|
248 |
+
gen = []
|
249 |
+
for i in range(num_sides):
|
250 |
+
gen.append(
|
251 |
+
bot_response(
|
252 |
+
states[i],
|
253 |
+
temperature,
|
254 |
+
top_p,
|
255 |
+
max_new_tokens,
|
256 |
+
request,
|
257 |
+
)
|
258 |
+
)
|
259 |
+
|
260 |
+
model_tpy = []
|
261 |
+
for i in range(num_sides):
|
262 |
+
token_per_yield = 1
|
263 |
+
if states[i].model_name in [
|
264 |
+
"gemini-pro",
|
265 |
+
"gemma-1.1-2b-it",
|
266 |
+
"gemma-1.1-7b-it",
|
267 |
+
"phi-3-mini-4k-instruct",
|
268 |
+
"phi-3-mini-128k-instruct",
|
269 |
+
"snowflake-arctic-instruct",
|
270 |
+
]:
|
271 |
+
token_per_yield = 30
|
272 |
+
elif states[i].model_name in [
|
273 |
+
"qwen-max-0428",
|
274 |
+
"qwen-vl-max-0809",
|
275 |
+
"qwen1.5-110b-chat",
|
276 |
+
]:
|
277 |
+
token_per_yield = 7
|
278 |
+
elif states[i].model_name in [
|
279 |
+
"qwen2.5-72b-instruct",
|
280 |
+
"qwen2-72b-instruct",
|
281 |
+
"qwen-plus-0828",
|
282 |
+
"qwen-max-0919",
|
283 |
+
"llama-3.1-405b-instruct-bf16",
|
284 |
+
]:
|
285 |
+
token_per_yield = 4
|
286 |
+
model_tpy.append(token_per_yield)
|
287 |
+
|
288 |
+
chatbots = [None] * num_sides
|
289 |
+
iters = 0
|
290 |
+
while True:
|
291 |
+
stop = True
|
292 |
+
iters += 1
|
293 |
+
for i in range(num_sides):
|
294 |
+
try:
|
295 |
+
# yield fewer times if chunk size is larger
|
296 |
+
if model_tpy[i] == 1 or (iters % model_tpy[i] == 1 or iters < 3):
|
297 |
+
ret = next(gen[i])
|
298 |
+
states[i], chatbots[i] = ret[0], ret[1]
|
299 |
+
stop = False
|
300 |
+
except StopIteration:
|
301 |
+
pass
|
302 |
+
yield states + chatbots + [disable_btn] * 6
|
303 |
+
if stop:
|
304 |
+
break
|
305 |
+
|
306 |
+
|
307 |
+
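The token_per_yield table above throttles UI refreshes rather than generation: a side whose backend streams in large chunks is advanced on fewer loop passes, while the other side still yields every pass. A standalone sketch of the same interleaving pattern, illustrative only and not part of this commit:

    # Round-robin two streams; advance the second only on every `tpy_b`-th pass
    # plus a short warm-up, mirroring `iters % tpy == 1 or iters < 3` above.
    def interleave(gen_a, gen_b, tpy_b=3):
        iters, exhausted = 0, False
        while not exhausted:
            exhausted = True
            iters += 1
            for gen, tpy in ((gen_a, 1), (gen_b, tpy_b)):
                try:
                    if tpy == 1 or iters % tpy == 1 or iters < 3:
                        yield next(gen)
                    exhausted = False  # this stream is not proven empty yet
                except StopIteration:
                    pass

    print(list(interleave(iter("ab"), iter("xyz"))))  # ['a', 'x', 'b', 'y', 'z']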
def flash_buttons():
    btn_updates = [
        [disable_btn] * 4 + [enable_btn] * 2,
        [enable_btn] * 6,
    ]
    for i in range(4):
        yield btn_updates[i % 2]
        time.sleep(0.3)


def build_side_by_side_ui_named(models):
    notice_markdown = f"""
# ⚔️ Chatbot Arena (formerly LMSYS): Free AI Chat to Compare & Test Best AI Chatbots

{SURVEY_LINK}

## 📜 How It Works
- Ask any question to two chosen models (e.g., ChatGPT, Gemini, Claude, Llama) and vote for the better one!
- You can chat for multiple turns until you identify a winner.

## 👇 Choose two models to compare
"""

    states = [gr.State() for _ in range(num_sides)]
    model_selectors = [None] * num_sides
    chatbots = [None] * num_sides

    notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")

    with gr.Group(elem_id="share-region-named"):
        with gr.Row():
            for i in range(num_sides):
                with gr.Column():
                    model_selectors[i] = gr.Dropdown(
                        choices=models,
                        value=models[i] if len(models) > i else "",
                        interactive=True,
                        show_label=False,
                        container=False,
                    )
        with gr.Row():
            with gr.Accordion(
                f"🔍 Expand to see the descriptions of {len(models)} models", open=False
            ):
                model_description_md = get_model_description_md(models)
                gr.Markdown(model_description_md, elem_id="model_description_markdown")

        with gr.Row():
            for i in range(num_sides):
                label = "Model A" if i == 0 else "Model B"
                with gr.Column():
                    chatbots[i] = gr.Chatbot(
                        label=label,
                        elem_id="chatbot",
                        height=650,
                        show_copy_button=True,
                    )

    with gr.Row():
        leftvote_btn = gr.Button(
            value="👈 A is better", visible=False, interactive=False
        )
        rightvote_btn = gr.Button(
            value="👉 B is better", visible=False, interactive=False
        )
        tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
        bothbad_btn = gr.Button(
            value="👎 Both are bad", visible=False, interactive=False
        )

    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your prompt and press ENTER",
            elem_id="input_box",
        )
        send_btn = gr.Button(value="Send", variant="primary", scale=0)

    with gr.Row() as button_row:
        clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
        regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
        share_btn = gr.Button(value="📷 Share")

    with gr.Accordion("Parameters", open=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=1.0,
            step=0.1,
            interactive=True,
            label="Top P",
        )
        max_output_tokens = gr.Slider(
            minimum=16,
            maximum=2048,
            value=1024,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    gr.Markdown(acknowledgment_md, elem_id="ack_markdown")

    # Register listeners
    btn_list = [
        leftvote_btn,
        rightvote_btn,
        tie_btn,
        bothbad_btn,
        regenerate_btn,
        clear_btn,
    ]
    leftvote_btn.click(
        leftvote_last_response,
        states + model_selectors,
        [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
    )
    rightvote_btn.click(
        rightvote_last_response,
        states + model_selectors,
        [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
    )
    tie_btn.click(
        tievote_last_response,
        states + model_selectors,
        [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
    )
    bothbad_btn.click(
        bothbad_vote_last_response,
        states + model_selectors,
        [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
    )
    regenerate_btn.click(
        regenerate, states, states + chatbots + [textbox] + btn_list
    ).then(
        bot_response_multi,
        states + [temperature, top_p, max_output_tokens],
        states + chatbots + btn_list,
    ).then(
        flash_buttons, [], btn_list
    )
    clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list)

    share_js = """
function (a, b, c, d) {
    const captureElement = document.querySelector('#share-region-named');
    html2canvas(captureElement)
        .then(canvas => {
            canvas.style.display = 'none'
            document.body.appendChild(canvas)
            return canvas
        })
        .then(canvas => {
            const image = canvas.toDataURL('image/png')
            const a = document.createElement('a')
            a.setAttribute('download', 'chatbot-arena.png')
            a.setAttribute('href', image)
            a.click()
            canvas.remove()
        });
    return [a, b, c, d];
}
"""
    share_btn.click(share_click, states + model_selectors, [], js=share_js)

    for i in range(num_sides):
        model_selectors[i].change(
            clear_history, None, states + chatbots + [textbox] + btn_list
        )

    textbox.submit(
        add_text,
        states + model_selectors + [textbox],
        states + chatbots + [textbox] + btn_list,
    ).then(
        bot_response_multi,
        states + [temperature, top_p, max_output_tokens],
        states + chatbots + btn_list,
    ).then(
        flash_buttons, [], btn_list
    )
    send_btn.click(
        add_text,
        states + model_selectors + [textbox],
        states + chatbots + [textbox] + btn_list,
    ).then(
        bot_response_multi,
        states + [temperature, top_p, max_output_tokens],
        states + chatbots + btn_list,
    ).then(
        flash_buttons, [], btn_list
    )

    return states + model_selectors
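The builder only constructs components and registers listeners; mounting happens in the app entry point (app.py in this commit, not reproduced here). A minimal sketch of that wiring, assuming the serve/ package layout above and a hand-picked model list (names taken from api_endpoints.json):

    import gradio as gr
    from serve.gradio_block_arena_named import build_side_by_side_ui_named

    models = ["claude-3-5-sonnet-20240620", "command-r-plus"]
    with gr.Blocks(title="Chatbot Arena") as demo:
        build_side_by_side_ui_named(models)
    demo.launch()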
serve/gradio_block_arena_vision.py
ADDED
@@ -0,0 +1,508 @@
"""
The gradio demo server for chatting with a large multimodal model.

Usage:
python3 -m fastchat.serve.controller
python3 -m fastchat.serve.sglang_worker --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf
python3 -m fastchat.serve.gradio_web_server_multi --share --vision-arena
"""

import json
import os
import time
from typing import List, Union

import gradio as gr
from gradio.data_classes import FileData
import numpy as np

from .constants import (
    TEXT_MODERATION_MSG,
    IMAGE_MODERATION_MSG,
    MODERATION_MSG,
    CONVERSATION_LIMIT_MSG,
    INPUT_CHAR_LEN_LIMIT,
    CONVERSATION_TURN_LIMIT,
    SURVEY_LINK,
)
# from fastchat.model.model_adapter import (
#     get_conversation_template,
# )
from .gradio_global_state import Context
from .gradio_web_server import (
    get_model_description_md,
    acknowledgment_md,
    bot_response,
    get_ip,
    disable_btn,
    State,
    get_conv_log_filename,
    get_remote_logger,
)
# from fastchat.serve.vision.image import ImageFormat, Image
from .utils import (
    build_logger,
    moderation_filter,
    image_moderation_filter,
)

logger = build_logger("gradio_web_server", "gradio_web_server.log")

no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True, visible=True)
disable_btn = gr.Button(interactive=False)
invisible_btn = gr.Button(interactive=False, visible=False)
visible_image_column = gr.Image(visible=True)
invisible_image_column = gr.Image(visible=False)
enable_multimodal = gr.MultimodalTextbox(
    interactive=True, visible=True, placeholder="Enter your prompt or add image here"
)
invisible_text = gr.Textbox(visible=False, value="", interactive=False)
visible_text = gr.Textbox(
    visible=True,
    value="",
    interactive=True,
    placeholder="👉 Enter your prompt and press ENTER",
)
disable_multimodal = gr.MultimodalTextbox(visible=False, value=None, interactive=False)


def get_vqa_sample():
    random_sample = np.random.choice(vqa_samples)
    question, path = random_sample["question"], random_sample["path"]
    res = {"text": "", "files": [path]}
    return (res, path)
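get_vqa_sample draws from the module-level vqa_samples list, which build_single_vision_language_model_ui further down loads from the random_questions JSON file. The format this implies is a list of objects with "question" and "path" keys; the file names below are invented for illustration:

    [
      {"question": "What is unusual about this image?", "path": "examples/sample_1.jpg"},
      {"question": "How many animals are visible?", "path": "examples/sample_2.jpg"}
    ]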
def set_visible_image(textbox):
    images = textbox["files"]
    if len(images) == 0:
        return invisible_image_column
    elif len(images) > 1:
        gr.Warning(
            "We only support single image conversations. Please start a new round if you would like to chat using this image."
        )

    return visible_image_column


def set_invisible_image():
    return invisible_image_column


def add_image(textbox):
    images = textbox["files"]
    if len(images) == 0:
        return None

    return images[0]


def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    filename = get_conv_log_filename(state.is_vision, state.has_csam_image)
    with open(filename, "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    get_remote_logger().log(data)
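Each vote is appended as one JSON object per line, so the log can be replayed with a line-by-line reader. An illustrative record follows (values invented for the example; the "state" payload is whatever State.dict() serializes):

    {"tstamp": 1727450000.1234, "type": "upvote", "model": "claude-3-5-sonnet-20240620", "state": {"...": "..."}, "ip": "203.0.113.7"}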
def upvote_last_response(state, model_selector, request: gr.Request):
    ip = get_ip(request)
    logger.info(f"upvote. ip: {ip}")
    vote_last_response(state, "upvote", model_selector, request)
    return (None,) + (disable_btn,) * 3


def downvote_last_response(state, model_selector, request: gr.Request):
    ip = get_ip(request)
    logger.info(f"downvote. ip: {ip}")
    vote_last_response(state, "downvote", model_selector, request)
    return (None,) + (disable_btn,) * 3


def flag_last_response(state, model_selector, request: gr.Request):
    ip = get_ip(request)
    logger.info(f"flag. ip: {ip}")
    vote_last_response(state, "flag", model_selector, request)
    return (None,) + (disable_btn,) * 3


def regenerate(state, request: gr.Request):
    ip = get_ip(request)
    logger.info(f"regenerate. ip: {ip}")
    if not state.regen_support:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
    state.conv.update_last_message(None)
    return (state, state.to_gradio_chatbot(), None) + (disable_btn,) * 5


def clear_history(request: gr.Request):
    ip = get_ip(request)
    logger.info(f"clear_history. ip: {ip}")
    state = None
    return (state, [], enable_multimodal, invisible_text, invisible_btn) + (
        disable_btn,
    ) * 5


def clear_history_example(request: gr.Request):
    ip = get_ip(request)
    logger.info(f"clear_history_example. ip: {ip}")
    state = None
    return (state, [], enable_multimodal, invisible_text, invisible_btn) + (
        disable_btn,
    ) * 5


# TODO(Chris): At some point, we would like this to be a live-reporting feature.
def report_csam_image(state, image):
    pass


def _prepare_text_with_image(state, text, images, csam_flag):
    if len(images) > 0:
        if len(state.conv.get_images()) > 0:
            # reset convo with new image
            state.conv = get_conversation_template(state.model_name)

        # pack the prompt and the single image into a (text, images) tuple
        text = text, [images[0]]

    return text


# NOTE(chris): take multiple images later on
def convert_images_to_conversation_format(images):
    import base64

    MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB = 5 / 1.5
    conv_images = []
    if len(images) > 0:
        conv_image = Image(url=images[0])
        conv_image.to_conversation_format(MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB)
        conv_images.append(conv_image)

    return conv_images


def moderate_input(state, text, all_conv_text, model_list, images, ip):
    text_flagged = moderation_filter(all_conv_text, model_list)
    # flagged = moderation_filter(text, [state.model_name])
    nsfw_flagged, csam_flagged = False, False
    if len(images) > 0:
        nsfw_flagged, csam_flagged = image_moderation_filter(images[0])

    image_flagged = nsfw_flagged or csam_flagged
    if text_flagged or image_flagged:
        logger.info(f"violate moderation. ip: {ip}. text: {all_conv_text}")
        if text_flagged and not image_flagged:
            # overwrite the original text
            text = TEXT_MODERATION_MSG
        elif not text_flagged and image_flagged:
            text = IMAGE_MODERATION_MSG
        elif text_flagged and image_flagged:
            text = MODERATION_MSG

    if csam_flagged:
        state.has_csam_image = True
        report_csam_image(state, images[0])

    return text, image_flagged, csam_flagged


def add_text(
    state,
    model_selector,
    chat_input: Union[str, dict],
    context: Context,
    request: gr.Request,
):
    if isinstance(chat_input, dict):
        text, images = chat_input["text"], chat_input["files"]
    else:
        text, images = chat_input, []

    if (
        len(images) > 0
        and model_selector in context.text_models
        and model_selector not in context.vision_models
    ):
        gr.Warning(f"{model_selector} is a text-only model. Image is ignored.")
        images = []

    ip = get_ip(request)
    logger.info(f"add_text. ip: {ip}. len: {len(text)}")

    if state is None:
        if len(images) == 0:
            state = State(model_selector, is_vision=False)
        else:
            state = State(model_selector, is_vision=True)

    if len(text) <= 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), None, "", no_change_btn) + (
            no_change_btn,
        ) * 5

    all_conv_text = state.conv.get_prompt()
    all_conv_text = all_conv_text[-2000:] + "\nuser: " + text

    images = convert_images_to_conversation_format(images)

    text, image_flagged, csam_flag = moderate_input(
        state, text, all_conv_text, [state.model_name], images, ip
    )

    if image_flagged:
        logger.info(f"image flagged. ip: {ip}. text: {text}")
        state.skip_next = True
        return (
            state,
            state.to_gradio_chatbot(),
            {"text": IMAGE_MODERATION_MSG},
            "",
            no_change_btn,
        ) + (no_change_btn,) * 5

    if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
        logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
        state.skip_next = True
        return (
            state,
            state.to_gradio_chatbot(),
            {"text": CONVERSATION_LIMIT_MSG},
            "",
            no_change_btn,
        ) + (no_change_btn,) * 5

    text = text[:INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
    text = _prepare_text_with_image(state, text, images, csam_flag=csam_flag)
    state.conv.append_message(state.conv.roles[0], text)
    state.conv.append_message(state.conv.roles[1], None)
    return (
        state,
        state.to_gradio_chatbot(),
        disable_multimodal,
        visible_text,
        enable_btn,
    ) + (disable_btn,) * 5
def build_single_vision_language_model_ui(
    context: Context, add_promotion_links=False, random_questions=None
):
    promotion = (
        f"""
[Blog](https://blog.lmarena.ai/blog/2023/arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2403.04132) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/6GXcFg3TH8) | [Kaggle Competition](https://www.kaggle.com/competitions/lmsys-chatbot-arena)

{SURVEY_LINK}

**❗️ For research purposes, we log user prompts and images, and may release this data to the public in the future. Please do not upload any confidential or personal information.**

Note: You can only chat with <span style='color: #DE3163; font-weight: bold'>one image per conversation</span>. You can upload images less than 15MB. Click the "Random Example" button to chat with a random image."""
        if add_promotion_links
        else ""
    )

    notice_markdown = f"""
# 🏔️ Chatbot Arena (formerly LMSYS): Free AI Chat to Compare & Test Best AI Chatbots
{promotion}
"""

    state = gr.State()
    gr.Markdown(notice_markdown, elem_id="notice_markdown")
    vision_not_in_text_models = [
        model for model in context.vision_models if model not in context.text_models
    ]
    text_and_vision_models = context.text_models + vision_not_in_text_models
    context_state = gr.State(context)

    with gr.Group():
        with gr.Row(elem_id="model_selector_row"):
            model_selector = gr.Dropdown(
                choices=text_and_vision_models,
                value=text_and_vision_models[0]
                if len(text_and_vision_models) > 0
                else "",
                interactive=True,
                show_label=False,
                container=False,
            )

        with gr.Accordion(
            f"🔍 Expand to see the descriptions of {len(text_and_vision_models)} models",
            open=False,
        ):
            model_description_md = get_model_description_md(text_and_vision_models)
            gr.Markdown(model_description_md, elem_id="model_description_markdown")

    with gr.Row():
        with gr.Column(scale=2, visible=False) as image_column:
            imagebox = gr.Image(
                type="pil",
                show_label=False,
                interactive=False,
            )
        with gr.Column(scale=8):
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                label="Scroll down and start chatting",
                height=650,
                show_copy_button=True,
            )

    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your prompt and press ENTER",
            elem_id="input_box",
            visible=False,
        )

        send_btn = gr.Button(
            value="Send", variant="primary", scale=0, visible=False, interactive=False
        )

        multimodal_textbox = gr.MultimodalTextbox(
            file_types=["image"],
            show_label=False,
            placeholder="Enter your prompt or add image here",
            container=True,
            elem_id="input_box",
        )

    with gr.Row(elem_id="buttons"):
        if random_questions:
            global vqa_samples
            with open(random_questions, "r") as f:
                vqa_samples = json.load(f)
            random_btn = gr.Button(value="🎲 Random Example", interactive=True)
        upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
        downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
        flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
        regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
        clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

    with gr.Accordion("Parameters", open=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Top P",
        )
        max_output_tokens = gr.Slider(
            minimum=0,
            maximum=2048,
            value=1024,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    if add_promotion_links:
        gr.Markdown(acknowledgment_md, elem_id="ack_markdown")

    # Register listeners
    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
    upvote_btn.click(
        upvote_last_response,
        [state, model_selector],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    downvote_btn.click(
        downvote_last_response,
        [state, model_selector],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    flag_btn.click(
        flag_last_response,
        [state, model_selector],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(
        clear_history,
        None,
        [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
    )

    model_selector.change(
        clear_history,
        None,
        [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
    ).then(set_visible_image, [multimodal_textbox], [image_column])

    multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then(
        set_visible_image, [multimodal_textbox], [image_column]
    ).then(
        clear_history_example,
        None,
        [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
    )

    multimodal_textbox.submit(
        add_text,
        [state, model_selector, multimodal_textbox, context_state],
        [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
    ).then(set_invisible_image, [], [image_column]).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    textbox.submit(
        add_text,
        [state, model_selector, textbox, context_state],
        [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
    ).then(set_invisible_image, [], [image_column]).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    send_btn.click(
        add_text,
        [state, model_selector, textbox, context_state],
        [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
    ).then(set_invisible_image, [], [image_column]).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    if random_questions:
        random_btn.click(
            get_vqa_sample,  # First, get the VQA sample
            [],  # Pass the path to the VQA samples
            [multimodal_textbox, imagebox],  # Outputs are textbox and imagebox
        ).then(set_visible_image, [multimodal_textbox], [image_column]).then(
            clear_history_example,
            None,
            [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
        )

    return [state, model_selector]
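The builders in this file and the two battle tabs below read several model lists off the shared Context object (defined in serve/gradio_global_state.py, which this section does not reproduce). A sketch of the fields they rely on, inferred from attribute access rather than from the actual class definition:

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class ContextSketch:  # hypothetical stand-in for serve.gradio_global_state.Context
        text_models: List[str] = field(default_factory=list)        # text-arena models
        vision_models: List[str] = field(default_factory=list)      # vision-arena models
        models: List[str] = field(default_factory=list)             # union shown in dropdowns
        all_text_models: List[str] = field(default_factory=list)    # text pool for anony battles
        all_vision_models: List[str] = field(default_factory=list)  # vision pool for anony battles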
serve/gradio_block_arena_vision_anony.py
ADDED
@@ -0,0 +1,682 @@
"""
Chatbot Arena (battle) tab.
Users chat with two anonymous models.
"""

import json
import time

import gradio as gr
import numpy as np
from typing import Union

from .constants import (
    TEXT_MODERATION_MSG,
    IMAGE_MODERATION_MSG,
    MODERATION_MSG,
    CONVERSATION_LIMIT_MSG,
    SLOW_MODEL_MSG,
    BLIND_MODE_INPUT_CHAR_LEN_LIMIT,
    CONVERSATION_TURN_LIMIT,
    SURVEY_LINK,
)
from .gradio_block_arena_named import flash_buttons
from .gradio_web_server import (
    State,
    bot_response,
    get_conv_log_filename,
    no_change_btn,
    enable_btn,
    disable_btn,
    invisible_btn,
    acknowledgment_md,
    get_ip,
    get_model_description_md,
    disable_text,
    enable_text,
)
from .gradio_block_arena_anony import (
    flash_buttons,
    vote_last_response,
    leftvote_last_response,
    rightvote_last_response,
    tievote_last_response,
    bothbad_vote_last_response,
    regenerate,
    clear_history,
    share_click,
    bot_response_multi,
    set_global_vars_anony,
    load_demo_side_by_side_anony,
    get_sample_weight,
    get_battle_pair,
    SAMPLING_WEIGHTS,
    BATTLE_TARGETS,
    SAMPLING_BOOST_MODELS,
    OUTAGE_MODELS,
)
from .gradio_block_arena_vision import (
    set_invisible_image,
    set_visible_image,
    add_image,
    moderate_input,
    enable_multimodal,
    _prepare_text_with_image,
    convert_images_to_conversation_format,
    invisible_text,
    visible_text,
    disable_multimodal,
)
from .gradio_global_state import Context
from .remote_logger import get_remote_logger
from .utils import (
    build_logger,
    moderation_filter,
    image_moderation_filter,
)

logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")

num_sides = 2
enable_moderation = False
anony_names = ["", ""]
text_models = []
vl_models = []

# TODO(chris): fix sampling weights
VISION_SAMPLING_WEIGHTS = {}

# TODO(chris): Find battle targets that make sense
VISION_BATTLE_TARGETS = {}

# TODO(chris): Fill out models that require sampling boost
VISION_SAMPLING_BOOST_MODELS = []

# outage models won't be sampled.
VISION_OUTAGE_MODELS = []


def get_vqa_sample():
    random_sample = np.random.choice(vqa_samples)
    question, path = random_sample["question"], random_sample["path"]
    res = {"text": "", "files": [path]}
    return (res, path)


def load_demo_side_by_side_vision_anony():
    states = [None] * num_sides
    selector_updates = [
        gr.Markdown(visible=True),
        gr.Markdown(visible=True),
    ]

    return states + selector_updates


def clear_history_example(request: gr.Request):
    logger.info(f"clear_history_example (anony). ip: {get_ip(request)}")
    return (
        [None] * num_sides
        + [None] * num_sides
        + anony_names
        + [enable_multimodal, invisible_text, invisible_btn]
        + [invisible_btn] * 4
        + [disable_btn] * 2
        + [enable_btn]
    )


def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
    filename = get_conv_log_filename(states[0].is_vision, states[0].has_csam_image)

    with open(filename, "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": [x for x in model_selectors],
            "states": [x.dict() for x in states],
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    get_remote_logger().log(data)

    gr.Info(
        "🎉 Thanks for voting! Your vote shapes the leaderboard, please vote RESPONSIBLY."
    )

    model_name_1 = states[0].model_name
    model_name_2 = states[1].model_name
    model_name_map = {}

    if model_name_1 in model_name_map:
        model_name_1 = model_name_map[model_name_1]
    if model_name_2 in model_name_map:
        model_name_2 = model_name_map[model_name_2]

    if ":" not in model_selectors[0]:
        for i in range(5):
            names = (
                "### Model A: " + model_name_1,
                "### Model B: " + model_name_2,
            )
            yield names + (disable_text,) + (disable_btn,) * 4
            time.sleep(0.1)
    else:
        names = (
            "### Model A: " + model_name_1,
            "### Model B: " + model_name_2,
        )
        yield names + (disable_text,) + (disable_btn,) * 4


def leftvote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"leftvote (anony). ip: {get_ip(request)}")
    for x in vote_last_response(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    ):
        yield x


def rightvote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"rightvote (anony). ip: {get_ip(request)}")
    for x in vote_last_response(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    ):
        yield x


def tievote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"tievote (anony). ip: {get_ip(request)}")
    for x in vote_last_response(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    ):
        yield x


def bothbad_vote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
    for x in vote_last_response(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    ):
        yield x


def regenerate(state0, state1, request: gr.Request):
    logger.info(f"regenerate (anony). ip: {get_ip(request)}")
    states = [state0, state1]
    if state0.regen_support and state1.regen_support:
        for i in range(num_sides):
            states[i].conv.update_last_message(None)
        return (
            states
            + [x.to_gradio_chatbot() for x in states]
            + [None]
            + [disable_btn] * 6
        )
    states[0].skip_next = True
    states[1].skip_next = True
    return (
        states + [x.to_gradio_chatbot() for x in states] + [None] + [no_change_btn] * 6
    )


def clear_history(request: gr.Request):
    logger.info(f"clear_history (anony). ip: {get_ip(request)}")
    return (
        [None] * num_sides
        + [None] * num_sides
        + anony_names
        + [enable_multimodal, invisible_text, invisible_btn]
        + [invisible_btn] * 4
        + [disable_btn] * 2
        + [enable_btn]
        + [""]
    )
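These handlers return long positional tuples that must line up one-to-one with the outputs lists registered in build_side_by_side_vision_ui_anony further down. For orientation, clear_history's 17 values map roughly as follows (ordering inferred from the clear_btn.click registration below):

    # [None] * num_sides            -> states (reset)
    # [None] * num_sides            -> chatbots (cleared)
    # anony_names                   -> model_selectors (identities hidden again)
    # enable_multimodal,
    # invisible_text, invisible_btn -> multimodal_textbox, textbox, send_btn
    # [invisible_btn] * 4
    # + [disable_btn] * 2           -> btn_list (4 vote buttons, regenerate, clear)
    # [enable_btn]                  -> random_btn
    # [""]                          -> slow_warning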
def add_text(
    state0,
    state1,
    model_selector0,
    model_selector1,
    chat_input: Union[str, dict],
    context: Context,
    request: gr.Request,
):
    if isinstance(chat_input, dict):
        text, images = chat_input["text"], chat_input["files"]
    else:
        text = chat_input
        images = []

    ip = get_ip(request)
    logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
    states = [state0, state1]
    model_selectors = [model_selector0, model_selector1]

    # Init states if necessary
    if states[0] is None:
        assert states[1] is None

        if len(images) > 0:
            model_left, model_right = get_battle_pair(
                context.all_vision_models,
                VISION_BATTLE_TARGETS,
                VISION_OUTAGE_MODELS,
                VISION_SAMPLING_WEIGHTS,
                VISION_SAMPLING_BOOST_MODELS,
            )
            states = [
                State(model_left, is_vision=True),
                State(model_right, is_vision=True),
            ]
        else:
            model_left, model_right = get_battle_pair(
                context.all_text_models,
                BATTLE_TARGETS,
                OUTAGE_MODELS,
                SAMPLING_WEIGHTS,
                SAMPLING_BOOST_MODELS,
            )

            states = [
                State(model_left, is_vision=False),
                State(model_right, is_vision=False),
            ]

    if len(text) <= 0:
        for i in range(num_sides):
            states[i].skip_next = True
        return (
            states
            + [x.to_gradio_chatbot() for x in states]
            + [None, "", no_change_btn]
            + [
                no_change_btn,
            ]
            * 7
            + [""]
        )

    model_list = [states[i].model_name for i in range(num_sides)]

    images = convert_images_to_conversation_format(images)

    text, image_flagged, csam_flag = moderate_input(
        state0, text, text, model_list, images, ip
    )

    conv = states[0].conv
    if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
        logger.info(f"conversation turn limit. ip: {get_ip(request)}. text: {text}")
        for i in range(num_sides):
            states[i].skip_next = True
        return (
            states
            + [x.to_gradio_chatbot() for x in states]
            + [{"text": CONVERSATION_LIMIT_MSG}, "", no_change_btn]
            + [
                no_change_btn,
            ]
            * 7
            + [""]
        )

    if image_flagged:
        logger.info(f"image flagged. ip: {ip}. text: {text}")
        for i in range(num_sides):
            states[i].skip_next = True
        return (
            states
            + [x.to_gradio_chatbot() for x in states]
            + [
                {
                    "text": IMAGE_MODERATION_MSG
                    + " PLEASE CLICK 🎲 NEW ROUND TO START A NEW CONVERSATION."
                },
                "",
                no_change_btn,
            ]
            + [no_change_btn] * 7
            + [""]
        )

    text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
    for i in range(num_sides):
        post_processed_text = _prepare_text_with_image(
            states[i], text, images, csam_flag=csam_flag
        )
        states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
        states[i].conv.append_message(states[i].conv.roles[1], None)
        states[i].skip_next = False

    hint_msg = ""
    for i in range(num_sides):
        if "deluxe" in states[i].model_name:
            hint_msg = SLOW_MODEL_MSG
    return (
        states
        + [x.to_gradio_chatbot() for x in states]
        + [disable_multimodal, visible_text, enable_btn]
        + [
            disable_btn,
        ]
        * 7
        + [hint_msg]
    )


def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
    notice_markdown = f"""
# ⚔️ Chatbot Arena (formerly LMSYS): Free AI Chat to Compare & Test Best AI Chatbots

{SURVEY_LINK}

## 📜 How It Works
- **Blind Test**: Ask any question to two anonymous AI chatbots (ChatGPT, Gemini, Claude, Llama, and more).
- **Vote for the Best**: Choose the best response. You can keep chatting until you find a winner.
- **Play Fair**: If AI identity reveals, your vote won't count.

**NEW** Image Support: <span style='color: #DE3163; font-weight: bold'>Upload an image</span> to unlock the multimodal arena!

## 🏆 Chatbot Arena LLM [Leaderboard](https://lmarena.ai/leaderboard)
- Backed by over **1,000,000+** community votes, our platform ranks the best LLM and AI chatbots. Explore the top AI models on our LLM [leaderboard](https://lmarena.ai/leaderboard)!

## 👇 Chat now!
"""

    states = [gr.State() for _ in range(num_sides)]
    model_selectors = [None] * num_sides
    chatbots = [None] * num_sides
    context_state = gr.State(context)
    gr.Markdown(notice_markdown, elem_id="notice_markdown")
    text_and_vision_models = context.models

    with gr.Row():
        with gr.Column(scale=2, visible=False) as image_column:
            imagebox = gr.Image(
                type="pil",
                show_label=False,
                interactive=False,
            )

        with gr.Column(scale=5):
            with gr.Group(elem_id="share-region-anony"):
                with gr.Accordion(
                    f"🔍 Expand to see the descriptions of {len(text_and_vision_models)} models",
                    open=False,
                ):
                    model_description_md = get_model_description_md(
                        text_and_vision_models
                    )
                    gr.Markdown(
                        model_description_md, elem_id="model_description_markdown"
                    )

                with gr.Row():
                    for i in range(num_sides):
                        label = "Model A" if i == 0 else "Model B"
                        with gr.Column():
                            chatbots[i] = gr.Chatbot(
                                label=label,
                                elem_id="chatbot",
                                height=650,
                                show_copy_button=True,
                            )

                with gr.Row():
                    for i in range(num_sides):
                        with gr.Column():
                            model_selectors[i] = gr.Markdown(
                                anony_names[i], elem_id="model_selector_md"
                            )
            with gr.Row():
                slow_warning = gr.Markdown("", elem_id="notice_markdown")

    with gr.Row():
        leftvote_btn = gr.Button(
            value="👈 A is better", visible=False, interactive=False
        )
        rightvote_btn = gr.Button(
            value="👉 B is better", visible=False, interactive=False
        )
        tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
        bothbad_btn = gr.Button(
            value="👎 Both are bad", visible=False, interactive=False
        )

    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your prompt and press ENTER",
            elem_id="input_box",
            visible=False,
            scale=3,
        )

        multimodal_textbox = gr.MultimodalTextbox(
            file_types=["image"],
            show_label=False,
            container=True,
            placeholder="Enter your prompt or add image here",
            elem_id="input_box",
            scale=3,
        )
        send_btn = gr.Button(
            value="Send", variant="primary", scale=1, visible=False, interactive=False
        )

    with gr.Row() as button_row:
        if random_questions:
            global vqa_samples
            with open(random_questions, "r") as f:
                vqa_samples = json.load(f)
            random_btn = gr.Button(value="🔮 Random Image", interactive=True)
        clear_btn = gr.Button(value="🎲 New Round", interactive=False)
        regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
        share_btn = gr.Button(value="📷 Share")

    with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=1.0,
            step=0.1,
            interactive=True,
            label="Top P",
        )
        max_output_tokens = gr.Slider(
            minimum=16,
            maximum=2048,
            value=2000,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    gr.Markdown(acknowledgment_md, elem_id="ack_markdown")

    # Register listeners
    btn_list = [
        leftvote_btn,
        rightvote_btn,
        tie_btn,
        bothbad_btn,
        regenerate_btn,
        clear_btn,
    ]
    leftvote_btn.click(
        leftvote_last_response,
        states + model_selectors,
        model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
    )
    rightvote_btn.click(
        rightvote_last_response,
        states + model_selectors,
        model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
    )
    tie_btn.click(
        tievote_last_response,
        states + model_selectors,
        model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
    )
    bothbad_btn.click(
        bothbad_vote_last_response,
        states + model_selectors,
        model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
    )
    regenerate_btn.click(
        regenerate, states, states + chatbots + [textbox] + btn_list
    ).then(
        bot_response_multi,
        states + [temperature, top_p, max_output_tokens],
        states + chatbots + btn_list,
    ).then(
        flash_buttons, [], btn_list
    )
    clear_btn.click(
        clear_history,
        None,
        states
        + chatbots
        + model_selectors
        + [multimodal_textbox, textbox, send_btn]
        + btn_list
        + [random_btn]
        + [slow_warning],
    )

    share_js = """
function (a, b, c, d) {
    const captureElement = document.querySelector('#share-region-anony');
    html2canvas(captureElement)
        .then(canvas => {
            canvas.style.display = 'none'
            document.body.appendChild(canvas)
            return canvas
        })
        .then(canvas => {
            const image = canvas.toDataURL('image/png')
            const a = document.createElement('a')
            a.setAttribute('download', 'chatbot-arena.png')
            a.setAttribute('href', image)
            a.click()
            canvas.remove()
        });
    return [a, b, c, d];
}
"""
    share_btn.click(share_click, states + model_selectors, [], js=share_js)

    multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then(
        set_visible_image, [multimodal_textbox], [image_column]
    ).then(
        clear_history_example,
        None,
        states
        + chatbots
        + model_selectors
        + [multimodal_textbox, textbox, send_btn]
        + btn_list,
    )

    multimodal_textbox.submit(
        add_text,
        states + model_selectors + [multimodal_textbox, context_state],
        states
        + chatbots
        + [multimodal_textbox, textbox, send_btn]
        + btn_list
        + [random_btn]
        + [slow_warning],
    ).then(set_invisible_image, [], [image_column]).then(
        bot_response_multi,
        states + [temperature, top_p, max_output_tokens],
        states + chatbots + btn_list,
    ).then(
        flash_buttons,
        [],
        btn_list,
    )

    textbox.submit(
        add_text,
        states + model_selectors + [textbox, context_state],
        states
        + chatbots
        + [multimodal_textbox, textbox, send_btn]
        + btn_list
        + [random_btn]
        + [slow_warning],
    ).then(
        bot_response_multi,
        states + [temperature, top_p, max_output_tokens],
        states + chatbots + btn_list,
    ).then(
        flash_buttons,
        [],
        btn_list,
    )

    send_btn.click(
        add_text,
        states + model_selectors + [textbox, context_state],
        states
        + chatbots
        + [multimodal_textbox, textbox, send_btn]
        + btn_list
        + [random_btn]
        + [slow_warning],
    ).then(
        bot_response_multi,
        states + [temperature, top_p, max_output_tokens],
        states + chatbots + btn_list,
    ).then(
        flash_buttons,
        [],
        btn_list,
    )

    if random_questions:
        random_btn.click(
            get_vqa_sample,  # First, get the VQA sample
            [],  # Pass the path to the VQA samples
            [multimodal_textbox, imagebox],  # Outputs are textbox and imagebox
        ).then(set_visible_image, [multimodal_textbox], [image_column]).then(
            clear_history_example,
            None,
            states
            + chatbots
            + model_selectors
            + [multimodal_textbox, textbox, send_btn]
            + btn_list
            + [random_btn],
        )

    return states + model_selectors
serve/gradio_block_arena_vision_named.py
ADDED
@@ -0,0 +1,583 @@
"""
Multimodal Chatbot Arena (side-by-side) tab.
Users chat with two chosen models.
"""

import json
import os
import time
from typing import List, Union

import gradio as gr
import numpy as np

from .constants import (
    TEXT_MODERATION_MSG,
    IMAGE_MODERATION_MSG,
    MODERATION_MSG,
    CONVERSATION_LIMIT_MSG,
    SLOW_MODEL_MSG,
    INPUT_CHAR_LEN_LIMIT,
    CONVERSATION_TURN_LIMIT,
    SURVEY_LINK,
)
from .gradio_block_arena_named import (
    flash_buttons,
    share_click,
    bot_response_multi,
)
from .gradio_block_arena_vision import (
    get_vqa_sample,
    set_invisible_image,
    set_visible_image,
    add_image,
    moderate_input,
    _prepare_text_with_image,
    convert_images_to_conversation_format,
    enable_multimodal,
    disable_multimodal,
    invisible_text,
    invisible_btn,
    visible_text,
)
from .gradio_global_state import Context
from .gradio_web_server import (
    State,
    bot_response,
    get_conv_log_filename,
    no_change_btn,
    enable_btn,
    disable_btn,
    invisible_btn,
    acknowledgment_md,
    get_ip,
    get_model_description_md,
    enable_text,
)
from .remote_logger import get_remote_logger
from .utils import (
    build_logger,
    moderation_filter,
    image_moderation_filter,
)


logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")

num_sides = 2
enable_moderation = False


def load_demo_side_by_side_vision_named(context: Context):
    states = [None] * num_sides

    # default to the text models
    models = context.text_models

    model_left = models[0] if len(models) > 0 else ""
    if len(models) > 1:
        weights = ([1] * 128)[: len(models) - 1]
        weights = weights / np.sum(weights)
        model_right = np.random.choice(models[1:], p=weights)
    else:
        model_right = model_left

    all_models = context.models
    selector_updates = [
        gr.Dropdown(choices=all_models, value=model_left, visible=True),
        gr.Dropdown(choices=all_models, value=model_right, visible=True),
    ]

    return states + selector_updates


def clear_history_example(request: gr.Request):
    logger.info(f"clear_history_example (named). ip: {get_ip(request)}")
    return (
        [None] * num_sides
        + [None] * num_sides
        + [enable_multimodal, invisible_text, invisible_btn]
        + [invisible_btn] * 4
        + [disable_btn] * 2
    )


def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
    filename = get_conv_log_filename(states[0].is_vision, states[0].has_csam_image)
    with open(filename, "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": [x for x in model_selectors],
            "states": [x.dict() for x in states],
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    get_remote_logger().log(data)


def leftvote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"leftvote (named). ip: {get_ip(request)}")
    vote_last_response(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    return (None,) + (disable_btn,) * 4


def rightvote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"rightvote (named). ip: {get_ip(request)}")
    vote_last_response(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    return (None,) + (disable_btn,) * 4


def tievote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"tievote (named). ip: {get_ip(request)}")
    vote_last_response(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    return (None,) + (disable_btn,) * 4


def bothbad_vote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
    vote_last_response(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    return (None,) + (disable_btn,) * 4


def regenerate(state0, state1, request: gr.Request):
    logger.info(f"regenerate (named). ip: {get_ip(request)}")
    states = [state0, state1]
    if state0.regen_support and state1.regen_support:
        for i in range(num_sides):
            states[i].conv.update_last_message(None)
        return (
            states
            + [x.to_gradio_chatbot() for x in states]
            + [None]
            + [disable_btn] * 6
        )
    states[0].skip_next = True
    states[1].skip_next = True
    return (
        states + [x.to_gradio_chatbot() for x in states] + [None] + [no_change_btn] * 6
    )


def clear_history(request: gr.Request):
    logger.info(f"clear_history (named). ip: {get_ip(request)}")
    return (
        [None] * num_sides
        + [None] * num_sides
        + [enable_multimodal, invisible_text, invisible_btn]
        + [invisible_btn] * 4
        + [disable_btn] * 2
    )


def add_text(
    state0,
    state1,
    model_selector0,
    model_selector1,
    chat_input: Union[str, dict],
    context: Context,
    request: gr.Request,
):
    if isinstance(chat_input, dict):
        text, images = chat_input["text"], chat_input["files"]
    else:
        text, images = chat_input, []

    if len(images) > 0:
        if (
            model_selector0 in context.text_models
            and model_selector0 not in context.vision_models
        ):
            gr.Warning(f"{model_selector0} is a text-only model. Image is ignored.")
            images = []
        if (
            model_selector1 in context.text_models
            and model_selector1 not in context.vision_models
        ):
            gr.Warning(f"{model_selector1} is a text-only model. Image is ignored.")
            images = []

    ip = get_ip(request)
    logger.info(f"add_text (named). ip: {ip}. len: {len(text)}")
    states = [state0, state1]
    model_selectors = [model_selector0, model_selector1]

    # Init states if necessary
    for i in range(num_sides):
        if states[i] is None and len(images) == 0:
            states[i] = State(model_selectors[i], is_vision=False)
        elif states[i] is None and len(images) > 0:
            states[i] = State(model_selectors[i], is_vision=True)

    if len(text) <= 0:
        for i in range(num_sides):
            states[i].skip_next = True
        return (
            states
            + [x.to_gradio_chatbot() for x in states]
            + [None, "", no_change_btn]
            + [no_change_btn] * 6
        )

    model_list = [states[i].model_name for i in range(num_sides)]
    all_conv_text_left = states[0].conv.get_prompt()
    all_conv_text_right = states[1].conv.get_prompt()
    all_conv_text = (
        all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text
    )

    images = convert_images_to_conversation_format(images)

    text, image_flagged, csam_flag = moderate_input(
        state0, text, all_conv_text, model_list, images, ip
    )

    conv = states[0].conv
    if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
        logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
        for i in range(num_sides):
            states[i].skip_next = True
        return (
            states
            + [x.to_gradio_chatbot() for x in states]
            + [{"text": CONVERSATION_LIMIT_MSG}, "", no_change_btn]
            + [no_change_btn] * 6
        )

    if image_flagged:
        logger.info(f"image flagged. ip: {ip}. text: {text}")
        for i in range(num_sides):
            states[i].skip_next = True
        return (
            states
            + [x.to_gradio_chatbot() for x in states]
            + [{"text": IMAGE_MODERATION_MSG}, "", no_change_btn]
            + [no_change_btn] * 6
        )

    text = text[:INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
    for i in range(num_sides):
        post_processed_text = _prepare_text_with_image(
            states[i], text, images, csam_flag=csam_flag
        )
        states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
        states[i].conv.append_message(states[i].conv.roles[1], None)
        states[i].skip_next = False

    return (
        states
        + [x.to_gradio_chatbot() for x in states]
        + [disable_multimodal, visible_text, enable_btn]
        + [disable_btn] * 6
    )


def build_side_by_side_vision_ui_named(context: Context, random_questions=None):
    notice_markdown = f"""
# ⚔️ Chatbot Arena (formerly LMSYS): Free AI Chat to Compare & Test Best AI Chatbots

{SURVEY_LINK}

## 📜 How It Works
- Ask any question to two chosen models (e.g., ChatGPT, Gemini, Claude, Llama) and vote for the better one!
- You can chat for multiple turns until you identify a winner.

Note: You can only chat with <span style='color: #DE3163; font-weight: bold'>one image per conversation</span>. You can upload images smaller than 15MB. Click the "Random Example" button to chat with a random image.

**❗️ For research purposes, we log user prompts and images, and may release this data to the public in the future. Please do not upload any confidential or personal information.**

## 🤖 Choose two models to compare
"""

    states = [gr.State() for _ in range(num_sides)]
    model_selectors = [None] * num_sides
    chatbots = [None] * num_sides

    notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")

    text_and_vision_models = context.models
    context_state = gr.State(context)

    with gr.Row():
        with gr.Column(scale=2, visible=False) as image_column:
            imagebox = gr.Image(
                type="pil",
                show_label=False,
                interactive=False,
            )

        with gr.Column(scale=5):
            # The elem_id must match the "#share-region-named" selector used in
            # share_js below, or the share screenshot cannot find this region.
            with gr.Group(elem_id="share-region-named"):
                with gr.Accordion(
                    f"🔍 Expand to see the descriptions of {len(text_and_vision_models)} models",
                    open=False,
                ):
                    model_description_md = get_model_description_md(
                        text_and_vision_models
                    )
                    gr.Markdown(
                        model_description_md, elem_id="model_description_markdown"
                    )

                with gr.Row():
                    for i in range(num_sides):
                        with gr.Column():
                            model_selectors[i] = gr.Dropdown(
                                choices=text_and_vision_models,
                                value=text_and_vision_models[i]
                                if len(text_and_vision_models) > i
                                else "",
                                interactive=True,
                                show_label=False,
                                container=False,
                            )

                with gr.Row():
                    for i in range(num_sides):
                        label = "Model A" if i == 0 else "Model B"
                        with gr.Column():
                            chatbots[i] = gr.Chatbot(
                                label=label,
                                elem_id="chatbot",
                                height=650,
                                show_copy_button=True,
                            )

    with gr.Row():
        leftvote_btn = gr.Button(
            value="👈 A is better", visible=False, interactive=False
        )
        rightvote_btn = gr.Button(
            value="👉 B is better", visible=False, interactive=False
        )
        tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
        bothbad_btn = gr.Button(
            value="👎 Both are bad", visible=False, interactive=False
        )

    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your prompt and press ENTER",
            elem_id="input_box",
            visible=False,
        )

        send_btn = gr.Button(
            value="Send", variant="primary", scale=0, visible=False, interactive=False
        )

        multimodal_textbox = gr.MultimodalTextbox(
            file_types=["image"],
            show_label=False,
            placeholder="Enter your prompt or add image here",
            container=True,
            elem_id="input_box",
        )

    with gr.Row() as button_row:
        if random_questions:
            global vqa_samples
            with open(random_questions, "r") as f:
                vqa_samples = json.load(f)
            random_btn = gr.Button(value="🎲 Random Example", interactive=True)
        clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
        regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
        share_btn = gr.Button(value="📷 Share")

    with gr.Accordion("Parameters", open=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=1.0,
            step=0.1,
            interactive=True,
            label="Top P",
        )
        max_output_tokens = gr.Slider(
            minimum=16,
            maximum=2048,
            value=1024,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    gr.Markdown(acknowledgment_md, elem_id="ack_markdown")

    # Register listeners
    btn_list = [
        leftvote_btn,
        rightvote_btn,
        tie_btn,
        bothbad_btn,
        regenerate_btn,
        clear_btn,
    ]
    leftvote_btn.click(
        leftvote_last_response,
        states + model_selectors,
        [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
    )
    rightvote_btn.click(
        rightvote_last_response,
        states + model_selectors,
        [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
    )
    tie_btn.click(
        tievote_last_response,
        states + model_selectors,
        [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
    )
    bothbad_btn.click(
        bothbad_vote_last_response,
        states + model_selectors,
        [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
    )
    regenerate_btn.click(
        regenerate, states, states + chatbots + [textbox] + btn_list
    ).then(
        bot_response_multi,
        states + [temperature, top_p, max_output_tokens],
        states + chatbots + btn_list,
    ).then(
        flash_buttons, [], btn_list
    )
    clear_btn.click(
        clear_history,
        None,
        states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
    )

    share_js = """
function (a, b, c, d) {
    const captureElement = document.querySelector('#share-region-named');
    html2canvas(captureElement)
        .then(canvas => {
            canvas.style.display = 'none'
            document.body.appendChild(canvas)
            return canvas
        })
        .then(canvas => {
            const image = canvas.toDataURL('image/png')
            const a = document.createElement('a')
            a.setAttribute('download', 'chatbot-arena.png')
            a.setAttribute('href', image)
            a.click()
            canvas.remove()
        });
    return [a, b, c, d];
}
"""
    share_btn.click(share_click, states + model_selectors, [], js=share_js)

    for i in range(num_sides):
        model_selectors[i].change(
            clear_history,
            None,
            states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
        ).then(set_visible_image, [multimodal_textbox], [image_column])

    multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then(
        set_visible_image, [multimodal_textbox], [image_column]
    ).then(
        clear_history_example,
        None,
        states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
    )

    multimodal_textbox.submit(
        add_text,
        states + model_selectors + [multimodal_textbox, context_state],
        states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
    ).then(set_invisible_image, [], [image_column]).then(
        bot_response_multi,
        states + [temperature, top_p, max_output_tokens],
        states + chatbots + btn_list,
    ).then(
        flash_buttons, [], btn_list
    )

    textbox.submit(
        add_text,
        states + model_selectors + [textbox, context_state],
        states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
    ).then(set_invisible_image, [], [image_column]).then(
        bot_response_multi,
        states + [temperature, top_p, max_output_tokens],
        states + chatbots + btn_list,
    ).then(
        flash_buttons, [], btn_list
    )

    send_btn.click(
        add_text,
        states + model_selectors + [textbox, context_state],
        states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
    ).then(set_invisible_image, [], [image_column]).then(
        bot_response_multi,
        states + [temperature, top_p, max_output_tokens],
        states + chatbots + btn_list,
    ).then(
        flash_buttons, [], btn_list
    )

    if random_questions:
        random_btn.click(
            get_vqa_sample,  # First, get the VQA sample
            [],  # Pass the path to the VQA samples
            [multimodal_textbox, imagebox],  # Outputs are textbox and imagebox
        ).then(set_visible_image, [multimodal_textbox], [image_column]).then(
            clear_history_example,
            None,
            states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
        )

    return states + model_selectors
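For orientation, here is a minimal sketch of how this builder could be wired into a Gradio app. The top-level script and the model names are assumptions for illustration, not part of this commit; the real wiring lives in app.py, and Context is defined in serve/gradio_global_state.py, shown next.

import gradio as gr

from serve.gradio_block_arena_vision_named import (
    build_side_by_side_vision_ui_named,
    load_demo_side_by_side_vision_named,
)
from serve.gradio_global_state import Context

# Placeholder model names; the real lists come from api_endpoints.json
# and/or a running controller.
context = Context(text_models=["model-a", "model-b"], models=["model-a", "model-b"])

with gr.Blocks() as demo:
    # Returns the two gr.State slots plus the two model-selector dropdowns.
    components = build_side_by_side_vision_ui_named(context)
    # On page load, reset the states and populate the dropdowns
    # (left model fixed, right model sampled).
    demo.load(lambda: load_demo_side_by_side_vision_named(context), None, components)

if __name__ == "__main__":
    demo.launch()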
serve/gradio_global_state.py
ADDED
@@ -0,0 +1,12 @@
from dataclasses import dataclass, field
from typing import List


@dataclass
class Context:
    text_models: List[str] = field(default_factory=list)
    all_text_models: List[str] = field(default_factory=list)
    vision_models: List[str] = field(default_factory=list)
    all_vision_models: List[str] = field(default_factory=list)
    models: List[str] = field(default_factory=list)
    all_models: List[str] = field(default_factory=list)
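A hedged sketch of how these fields might be populated from the endpoint registry. It assumes you run from the repository root where api_endpoints.json lives; get_model_list is defined in serve/gradio_web_server.py below, and the commit's actual app-level wiring in app.py is not shown in this excerpt.

from serve.gradio_global_state import Context
from serve.gradio_web_server import get_model_list

# With no controller URL, get_model_list reads only the API endpoint file.
text_models, all_text_models = get_model_list(
    None, "api_endpoints.json", vision_arena=False
)
vision_models, all_vision_models = get_model_list(
    None, "api_endpoints.json", vision_arena=True
)
context = Context(
    text_models=text_models,
    all_text_models=all_text_models,
    vision_models=vision_models,
    all_vision_models=all_vision_models,
    # Combined lists, de-duplicated while preserving order.
    models=list(dict.fromkeys(text_models + vision_models)),
    all_models=list(dict.fromkeys(all_text_models + all_vision_models)),
)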
serve/gradio_web_server.py
ADDED
@@ -0,0 +1,1051 @@
1 |
+
"""
|
2 |
+
The gradio demo server for chatting with a single model.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import datetime
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import time
|
10 |
+
import uuid
|
11 |
+
from typing import List
|
12 |
+
|
13 |
+
import gradio as gr
|
14 |
+
import requests
|
15 |
+
|
16 |
+
from .conversation import Conversation
|
17 |
+
from .constants import (
|
18 |
+
LOGDIR,
|
19 |
+
WORKER_API_TIMEOUT,
|
20 |
+
ErrorCode,
|
21 |
+
MODERATION_MSG,
|
22 |
+
CONVERSATION_LIMIT_MSG,
|
23 |
+
RATE_LIMIT_MSG,
|
24 |
+
SERVER_ERROR_MSG,
|
25 |
+
INPUT_CHAR_LEN_LIMIT,
|
26 |
+
CONVERSATION_TURN_LIMIT,
|
27 |
+
SESSION_EXPIRATION_TIME,
|
28 |
+
SURVEY_LINK,
|
29 |
+
)
|
30 |
+
# from .model.model_adapter import (
|
31 |
+
# get_conversation_template,
|
32 |
+
# )
|
33 |
+
# from .model.model_registry import get_model_info, model_info
|
34 |
+
from .api_provider import get_api_provider_stream_iter
|
35 |
+
from .gradio_global_state import Context
|
36 |
+
from serve.remote_logger import get_remote_logger
|
37 |
+
from .utils import (
|
38 |
+
build_logger,
|
39 |
+
get_window_url_params_js,
|
40 |
+
get_window_url_params_with_tos_js,
|
41 |
+
moderation_filter,
|
42 |
+
parse_gradio_auth_creds,
|
43 |
+
load_image,
|
44 |
+
)
|
45 |
+
|
46 |
+
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
47 |
+
|
48 |
+
headers = {"User-Agent": "FastChat Client"}
|
49 |
+
|
50 |
+
no_change_btn = gr.Button()
|
51 |
+
enable_btn = gr.Button(interactive=True, visible=True)
|
52 |
+
disable_btn = gr.Button(interactive=False)
|
53 |
+
invisible_btn = gr.Button(interactive=False, visible=False)
|
54 |
+
enable_text = gr.Textbox(
|
55 |
+
interactive=True, visible=True, placeholder="👉 Enter your prompt and press ENTER"
|
56 |
+
)
|
57 |
+
disable_text = gr.Textbox(
|
58 |
+
interactive=False,
|
59 |
+
visible=True,
|
60 |
+
placeholder='Press "🎲 New Round" to start over👇 (Note: Your vote shapes the leaderboard, please vote RESPONSIBLY!)',
|
61 |
+
)
|
62 |
+
|
63 |
+
controller_url = None
|
64 |
+
enable_moderation = False
|
65 |
+
use_remote_storage = False
|
66 |
+
|
67 |
+
acknowledgment_md = """
|
68 |
+
### Terms of Service
|
69 |
+
ユーザーは、サービスを利用する前に以下の条件に同意する必要があります:
|
70 |
+
|
71 |
+
- 違法、有害、暴力、人種差別、または性的目的で使用しないでください。
|
72 |
+
- 個人情報をアップロードしないでください。
|
73 |
+
- このサービスで収集された対話データは今後の大規模言語モデルの開発のほか、適切なマスキング処理を施した上で、クリエイティブ コモンズ アトリビューション (CC-BY) または同様のライセンスの下で配布される可能性があります。
|
74 |
+
|
75 |
+
"""
|
76 |
+
|
77 |
+
# JSON file format of API-based models:
|
78 |
+
# {
|
79 |
+
# "gpt-3.5-turbo": {
|
80 |
+
# "model_name": "gpt-3.5-turbo",
|
81 |
+
# "api_type": "openai",
|
82 |
+
# "api_base": "https://api.openai.com/v1",
|
83 |
+
# "api_key": "sk-******",
|
84 |
+
# "anony_only": false
|
85 |
+
# }
|
86 |
+
# }
|
87 |
+
#
|
88 |
+
# - "api_type" can be one of the following: openai, anthropic, gemini, or mistral. For custom APIs, add a new type and implement it accordingly.
|
89 |
+
# - "anony_only" indicates whether to display this model in anonymous mode only.
|
90 |
+
|
91 |
+
api_endpoint_info = {}
|
92 |
+
|
93 |
+
|
94 |
+
class State:
|
95 |
+
def __init__(self, model_name, is_vision=False):
|
96 |
+
# self.conv = get_conversation_template(model_name)
|
97 |
+
self.conv = Conversation(model_name)
|
98 |
+
self.conv_id = uuid.uuid4().hex
|
99 |
+
self.skip_next = False
|
100 |
+
self.model_name = model_name
|
101 |
+
self.oai_thread_id = None
|
102 |
+
self.is_vision = is_vision
|
103 |
+
|
104 |
+
# NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes.
|
105 |
+
self.has_csam_image = False
|
106 |
+
|
107 |
+
self.regen_support = True
|
108 |
+
if "browsing" in model_name:
|
109 |
+
self.regen_support = False
|
110 |
+
self.init_system_prompt(self.conv, is_vision)
|
111 |
+
|
112 |
+
def init_system_prompt(self, conv, is_vision):
|
113 |
+
system_prompt = conv.get_system_message(is_vision)
|
114 |
+
if len(system_prompt) == 0:
|
115 |
+
return
|
116 |
+
current_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
117 |
+
system_prompt = system_prompt.replace(
|
118 |
+
"{{currentDateTime}}", current_date)
|
119 |
+
|
120 |
+
current_date_v2 = datetime.datetime.now().strftime("%d %b %Y")
|
121 |
+
system_prompt = system_prompt.replace(
|
122 |
+
"{{currentDateTimev2}}", current_date_v2)
|
123 |
+
|
124 |
+
current_date_v3 = datetime.datetime.now().strftime("%B %Y")
|
125 |
+
system_prompt = system_prompt.replace(
|
126 |
+
"{{currentDateTimev3}}", current_date_v3)
|
127 |
+
conv.set_system_message(system_prompt)
|
128 |
+
|
129 |
+
def to_gradio_chatbot(self):
|
130 |
+
return self.conv.to_gradio_chatbot()
|
131 |
+
|
132 |
+
def dict(self):
|
133 |
+
base = self.conv.dict()
|
134 |
+
base.update(
|
135 |
+
{
|
136 |
+
"conv_id": self.conv_id,
|
137 |
+
"model_name": self.model_name,
|
138 |
+
}
|
139 |
+
)
|
140 |
+
|
141 |
+
if self.is_vision:
|
142 |
+
base.update({"has_csam_image": self.has_csam_image})
|
143 |
+
return base
|
144 |
+
|
145 |
+
|
146 |
+
def set_global_vars(
|
147 |
+
controller_url_,
|
148 |
+
enable_moderation_,
|
149 |
+
use_remote_storage_,
|
150 |
+
):
|
151 |
+
global controller_url, enable_moderation, use_remote_storage
|
152 |
+
controller_url = controller_url_
|
153 |
+
enable_moderation = enable_moderation_
|
154 |
+
use_remote_storage = use_remote_storage_
|
155 |
+
|
156 |
+
|
157 |
+
def get_conv_log_filename(is_vision=False, has_csam_image=False):
|
158 |
+
t = datetime.datetime.now()
|
159 |
+
conv_log_filename = f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json"
|
160 |
+
if is_vision and not has_csam_image:
|
161 |
+
name = os.path.join(LOGDIR, f"vision-tmp-{conv_log_filename}")
|
162 |
+
elif is_vision and has_csam_image:
|
163 |
+
name = os.path.join(LOGDIR, f"vision-csam-{conv_log_filename}")
|
164 |
+
else:
|
165 |
+
name = os.path.join(LOGDIR, conv_log_filename)
|
166 |
+
|
167 |
+
return name
|
168 |
+
|
169 |
+
|
170 |
+
def get_model_list(controller_url, register_api_endpoint_file, vision_arena):
|
171 |
+
global api_endpoint_info
|
172 |
+
|
173 |
+
# Add models from the controller
|
174 |
+
if controller_url:
|
175 |
+
ret = requests.post(controller_url + "/refresh_all_workers")
|
176 |
+
assert ret.status_code == 200
|
177 |
+
|
178 |
+
if vision_arena:
|
179 |
+
ret = requests.post(controller_url + "/list_multimodal_models")
|
180 |
+
models = ret.json()["models"]
|
181 |
+
else:
|
182 |
+
ret = requests.post(controller_url + "/list_language_models")
|
183 |
+
models = ret.json()["models"]
|
184 |
+
else:
|
185 |
+
models = []
|
186 |
+
|
187 |
+
# Add models from the API providers
|
188 |
+
if register_api_endpoint_file:
|
189 |
+
api_endpoint_info = json.load(open(register_api_endpoint_file))
|
190 |
+
for mdl, mdl_dict in api_endpoint_info.items():
|
191 |
+
mdl_vision = mdl_dict.get("vision-arena", False)
|
192 |
+
mdl_text = mdl_dict.get("text-arena", True)
|
193 |
+
if vision_arena and mdl_vision:
|
194 |
+
models.append(mdl)
|
195 |
+
if not vision_arena and mdl_text:
|
196 |
+
models.append(mdl)
|
197 |
+
|
198 |
+
# Remove anonymous models
|
199 |
+
models = list(set(models))
|
200 |
+
visible_models = models.copy()
|
201 |
+
for mdl in models:
|
202 |
+
if mdl not in api_endpoint_info:
|
203 |
+
continue
|
204 |
+
mdl_dict = api_endpoint_info[mdl]
|
205 |
+
if mdl_dict["anony_only"]:
|
206 |
+
visible_models.remove(mdl)
|
207 |
+
|
208 |
+
# Sort models and add descriptions
|
209 |
+
model_info = list(range(len(models)))
|
210 |
+
priority = {k: f"___{i:03d}" for i, k in enumerate(model_info)}
|
211 |
+
models.sort(key=lambda x: priority.get(x, x))
|
212 |
+
visible_models.sort(key=lambda x: priority.get(x, x))
|
213 |
+
logger.info(f"All models: {models}")
|
214 |
+
logger.info(f"Visible models: {visible_models}")
|
215 |
+
return visible_models, models
|
216 |
+
|
217 |
+
|
218 |
+
def load_demo_single(context: Context, query_params):
|
219 |
+
# default to text models
|
220 |
+
models = context.text_models
|
221 |
+
|
222 |
+
selected_model = models[0] if len(models) > 0 else ""
|
223 |
+
if "model" in query_params:
|
224 |
+
model = query_params["model"]
|
225 |
+
if model in models:
|
226 |
+
selected_model = model
|
227 |
+
|
228 |
+
all_models = context.models
|
229 |
+
|
230 |
+
dropdown_update = gr.Dropdown(
|
231 |
+
choices=all_models, value=selected_model, visible=True
|
232 |
+
)
|
233 |
+
state = None
|
234 |
+
return [state, dropdown_update]
|
235 |
+
|
236 |
+
|
237 |
+
def load_demo(url_params, request: gr.Request):
|
238 |
+
global models
|
239 |
+
|
240 |
+
ip = get_ip(request)
|
241 |
+
logger.info(f"load_demo. ip: {ip}. params: {url_params}")
|
242 |
+
|
243 |
+
if args.model_list_mode == "reload":
|
244 |
+
models, all_models = get_model_list(
|
245 |
+
controller_url, args.register_api_endpoint_file, vision_arena=False
|
246 |
+
)
|
247 |
+
|
248 |
+
return load_demo_single(models, url_params)
|
249 |
+
|
250 |
+
|
251 |
+
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
252 |
+
filename = get_conv_log_filename()
|
253 |
+
if "llava" in model_selector:
|
254 |
+
filename = filename.replace("2024", "vision-tmp-2024")
|
255 |
+
|
256 |
+
with open(filename, "a") as fout:
|
257 |
+
data = {
|
258 |
+
"tstamp": round(time.time(), 4),
|
259 |
+
"type": vote_type,
|
260 |
+
"model": model_selector,
|
261 |
+
"state": state.dict(),
|
262 |
+
"ip": get_ip(request),
|
263 |
+
}
|
264 |
+
fout.write(json.dumps(data) + "\n")
|
265 |
+
get_remote_logger().log(data)
|
266 |
+
|
267 |
+
|
268 |
+
def upvote_last_response(state, model_selector, request: gr.Request):
|
269 |
+
ip = get_ip(request)
|
270 |
+
logger.info(f"upvote. ip: {ip}")
|
271 |
+
vote_last_response(state, "upvote", model_selector, request)
|
272 |
+
return ("",) + (disable_btn,) * 3
|
273 |
+
|
274 |
+
|
275 |
+
def downvote_last_response(state, model_selector, request: gr.Request):
|
276 |
+
ip = get_ip(request)
|
277 |
+
logger.info(f"downvote. ip: {ip}")
|
278 |
+
vote_last_response(state, "downvote", model_selector, request)
|
279 |
+
return ("",) + (disable_btn,) * 3
|
280 |
+
|
281 |
+
|
282 |
+
def flag_last_response(state, model_selector, request: gr.Request):
|
283 |
+
ip = get_ip(request)
|
284 |
+
logger.info(f"flag. ip: {ip}")
|
285 |
+
vote_last_response(state, "flag", model_selector, request)
|
286 |
+
return ("",) + (disable_btn,) * 3
|
287 |
+
|
288 |
+
|
289 |
+
def regenerate(state, request: gr.Request):
|
290 |
+
ip = get_ip(request)
|
291 |
+
logger.info(f"regenerate. ip: {ip}")
|
292 |
+
if not state.regen_support:
|
293 |
+
state.skip_next = True
|
294 |
+
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
|
295 |
+
state.conv.update_last_message(None)
|
296 |
+
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
297 |
+
|
298 |
+
|
299 |
+
def clear_history(request: gr.Request):
|
300 |
+
ip = get_ip(request)
|
301 |
+
logger.info(f"clear_history. ip: {ip}")
|
302 |
+
state = None
|
303 |
+
return (state, [], "") + (disable_btn,) * 5
|
304 |
+
|
305 |
+
|
306 |
+
def get_ip(request: gr.Request):
|
307 |
+
if "cf-connecting-ip" in request.headers:
|
308 |
+
ip = request.headers["cf-connecting-ip"]
|
309 |
+
elif "x-forwarded-for" in request.headers:
|
310 |
+
ip = request.headers["x-forwarded-for"]
|
311 |
+
if "," in ip:
|
312 |
+
ip = ip.split(",")[0]
|
313 |
+
else:
|
314 |
+
ip = request.client.host
|
315 |
+
return ip
|
316 |
+
|
317 |
+
|
318 |
+
def add_text(state, model_selector, text, request: gr.Request):
|
319 |
+
ip = get_ip(request)
|
320 |
+
logger.info(f"add_text. ip: {ip}. len: {len(text)}")
|
321 |
+
|
322 |
+
if state is None:
|
323 |
+
state = State(model_selector)
|
324 |
+
|
325 |
+
if len(text) <= 0:
|
326 |
+
state.skip_next = True
|
327 |
+
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
|
328 |
+
|
329 |
+
all_conv_text = state.conv.get_prompt()
|
330 |
+
all_conv_text = all_conv_text[-2000:] + "\nuser: " + text
|
331 |
+
flagged = moderation_filter(all_conv_text, [state.model_name])
|
332 |
+
# flagged = moderation_filter(text, [state.model_name])
|
333 |
+
if flagged:
|
334 |
+
logger.info(f"violate moderation. ip: {ip}. text: {text}")
|
335 |
+
# overwrite the original text
|
336 |
+
text = MODERATION_MSG
|
337 |
+
|
338 |
+
if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
|
339 |
+
logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
|
340 |
+
state.skip_next = True
|
341 |
+
return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG, None) + (
|
342 |
+
no_change_btn,
|
343 |
+
) * 5
|
344 |
+
|
345 |
+
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
|
346 |
+
state.conv.append_message(state.conv.roles[0], text)
|
347 |
+
state.conv.append_message(state.conv.roles[1], None)
|
348 |
+
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
349 |
+
|
350 |
+
|
351 |
+
def model_worker_stream_iter(
|
352 |
+
conv,
|
353 |
+
model_name,
|
354 |
+
worker_addr,
|
355 |
+
prompt,
|
356 |
+
temperature,
|
357 |
+
repetition_penalty,
|
358 |
+
top_p,
|
359 |
+
max_new_tokens,
|
360 |
+
images,
|
361 |
+
):
|
362 |
+
# Make requests
|
363 |
+
gen_params = {
|
364 |
+
"model": model_name,
|
365 |
+
"prompt": prompt,
|
366 |
+
"temperature": temperature,
|
367 |
+
"repetition_penalty": repetition_penalty,
|
368 |
+
"top_p": top_p,
|
369 |
+
"max_new_tokens": max_new_tokens,
|
370 |
+
"stop": conv.stop_str,
|
371 |
+
"stop_token_ids": conv.stop_token_ids,
|
372 |
+
"echo": False,
|
373 |
+
}
|
374 |
+
|
375 |
+
logger.info(f"==== request ====\n{gen_params}")
|
376 |
+
|
377 |
+
if len(images) > 0:
|
378 |
+
gen_params["images"] = images
|
379 |
+
|
380 |
+
# Stream output
|
381 |
+
response = requests.post(
|
382 |
+
worker_addr + "/worker_generate_stream",
|
383 |
+
headers=headers,
|
384 |
+
json=gen_params,
|
385 |
+
stream=True,
|
386 |
+
timeout=WORKER_API_TIMEOUT,
|
387 |
+
)
|
388 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
389 |
+
if chunk:
|
390 |
+
data = json.loads(chunk.decode())
|
391 |
+
yield data
|
392 |
+
|
393 |
+
|
394 |
+
def is_limit_reached(model_name, ip):
|
395 |
+
monitor_url = "http://localhost:9090"
|
396 |
+
try:
|
397 |
+
ret = requests.get(
|
398 |
+
f"{monitor_url}/is_limit_reached?model={model_name}&user_id={ip}", timeout=1
|
399 |
+
)
|
400 |
+
obj = ret.json()
|
401 |
+
return obj
|
402 |
+
except Exception as e:
|
403 |
+
logger.info(f"monitor error: {e}")
|
404 |
+
return None
|
405 |
+
|
406 |
+
|
407 |
+
def bot_response(
|
408 |
+
state,
|
409 |
+
temperature,
|
410 |
+
top_p,
|
411 |
+
max_new_tokens,
|
412 |
+
request: gr.Request,
|
413 |
+
apply_rate_limit=True,
|
414 |
+
use_recommended_config=False,
|
415 |
+
):
|
416 |
+
ip = get_ip(request)
|
417 |
+
logger.info(f"bot_response. ip: {ip}")
|
418 |
+
start_tstamp = time.time()
|
419 |
+
temperature = float(temperature)
|
420 |
+
top_p = float(top_p)
|
421 |
+
max_new_tokens = int(max_new_tokens)
|
422 |
+
|
423 |
+
if state.skip_next:
|
424 |
+
# This generate call is skipped due to invalid inputs
|
425 |
+
state.skip_next = False
|
426 |
+
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
427 |
+
return
|
428 |
+
|
429 |
+
if apply_rate_limit:
|
430 |
+
ret = is_limit_reached(state.model_name, ip)
|
431 |
+
if ret is not None and ret["is_limit_reached"]:
|
432 |
+
error_msg = RATE_LIMIT_MSG + "\n\n" + ret["reason"]
|
433 |
+
logger.info(
|
434 |
+
f"rate limit reached. ip: {ip}. error_msg: {ret['reason']}")
|
435 |
+
state.conv.update_last_message(error_msg)
|
436 |
+
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
437 |
+
return
|
438 |
+
|
439 |
+
conv, model_name = state.conv, state.model_name
|
440 |
+
model_api_dict = (
|
441 |
+
api_endpoint_info[model_name] if model_name in api_endpoint_info else None
|
442 |
+
)
|
443 |
+
images = conv.get_images()
|
444 |
+
|
445 |
+
if model_api_dict is None:
|
446 |
+
# Query worker address
|
447 |
+
ret = requests.post(
|
448 |
+
controller_url + "/get_worker_address", json={"model": model_name}
|
449 |
+
)
|
450 |
+
worker_addr = ret.json()["address"]
|
451 |
+
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
|
452 |
+
|
453 |
+
# No available worker
|
454 |
+
if worker_addr == "":
|
455 |
+
conv.update_last_message(SERVER_ERROR_MSG)
|
456 |
+
yield (
|
457 |
+
state,
|
458 |
+
state.to_gradio_chatbot(),
|
459 |
+
disable_btn,
|
460 |
+
disable_btn,
|
461 |
+
disable_btn,
|
462 |
+
enable_btn,
|
463 |
+
enable_btn,
|
464 |
+
)
|
465 |
+
return
|
466 |
+
|
467 |
+
# Construct prompt.
|
468 |
+
# We need to call it here, so it will not be affected by "▌".
|
469 |
+
prompt = conv.get_prompt()
|
470 |
+
# Set repetition_penalty
|
471 |
+
if "t5" in model_name:
|
472 |
+
repetition_penalty = 1.2
|
473 |
+
else:
|
474 |
+
repetition_penalty = 1.0
|
475 |
+
|
476 |
+
stream_iter = model_worker_stream_iter(
|
477 |
+
conv,
|
478 |
+
model_name,
|
479 |
+
worker_addr,
|
480 |
+
prompt,
|
481 |
+
temperature,
|
482 |
+
repetition_penalty,
|
483 |
+
top_p,
|
484 |
+
max_new_tokens,
|
485 |
+
images,
|
486 |
+
)
|
487 |
+
else:
|
488 |
+
# Remove system prompt for API-based models unless specified
|
489 |
+
custom_system_prompt = model_api_dict.get(
|
490 |
+
"custom_system_prompt", False)
|
491 |
+
if not custom_system_prompt:
|
492 |
+
conv.set_system_message("")
|
493 |
+
|
494 |
+
if use_recommended_config:
|
495 |
+
recommended_config = model_api_dict.get("recommended_config", None)
|
496 |
+
if recommended_config is not None:
|
497 |
+
temperature = recommended_config.get(
|
498 |
+
"temperature", temperature)
|
499 |
+
top_p = recommended_config.get("top_p", top_p)
|
500 |
+
max_new_tokens = recommended_config.get(
|
501 |
+
"max_new_tokens", max_new_tokens
|
502 |
+
)
|
503 |
+
|
504 |
+
stream_iter = get_api_provider_stream_iter(
|
505 |
+
conv,
|
506 |
+
model_name,
|
507 |
+
model_api_dict,
|
508 |
+
temperature,
|
509 |
+
top_p,
|
510 |
+
max_new_tokens,
|
511 |
+
state,
|
512 |
+
)
|
513 |
+
|
514 |
+
html_code = ' <span class="cursor"></span> '
|
515 |
+
|
516 |
+
# conv.update_last_message("▌")
|
517 |
+
conv.update_last_message(html_code)
|
518 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
519 |
+
|
520 |
+
try:
|
521 |
+
data = {"text": ""}
|
522 |
+
for i, data in enumerate(stream_iter):
|
523 |
+
if data["error_code"] == 0:
|
524 |
+
output = data["text"].strip()
|
525 |
+
conv.update_last_message(output + "▌")
|
526 |
+
# conv.update_last_message(output + html_code)
|
527 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
528 |
+
else:
|
529 |
+
output = data["text"] + \
|
530 |
+
f"\n\n(error_code: {data['error_code']})"
|
531 |
+
conv.update_last_message(output)
|
532 |
+
yield (state, state.to_gradio_chatbot()) + (
|
533 |
+
disable_btn,
|
534 |
+
disable_btn,
|
535 |
+
disable_btn,
|
536 |
+
enable_btn,
|
537 |
+
enable_btn,
|
538 |
+
)
|
539 |
+
return
|
540 |
+
output = data["text"].strip()
|
541 |
+
conv.update_last_message(output)
|
542 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
543 |
+
except requests.exceptions.RequestException as e:
|
544 |
+
conv.update_last_message(
|
545 |
+
f"{SERVER_ERROR_MSG}\n\n"
|
546 |
+
f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
|
547 |
+
)
|
548 |
+
yield (state, state.to_gradio_chatbot()) + (
|
549 |
+
disable_btn,
|
550 |
+
disable_btn,
|
551 |
+
disable_btn,
|
552 |
+
enable_btn,
|
553 |
+
enable_btn,
|
554 |
+
)
|
555 |
+
return
|
556 |
+
except Exception as e:
|
557 |
+
conv.update_last_message(
|
558 |
+
f"{SERVER_ERROR_MSG}\n\n"
|
559 |
+
f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
|
560 |
+
)
|
561 |
+
yield (state, state.to_gradio_chatbot()) + (
|
562 |
+
disable_btn,
|
563 |
+
disable_btn,
|
564 |
+
disable_btn,
|
565 |
+
enable_btn,
|
566 |
+
enable_btn,
|
567 |
+
)
|
568 |
+
return
|
569 |
+
|
570 |
+
finish_tstamp = time.time()
|
571 |
+
logger.info(f"{output}")
|
572 |
+
|
573 |
+
conv.save_new_images(
|
574 |
+
has_csam_images=state.has_csam_image, use_remote_storage=use_remote_storage
|
575 |
+
)
|
576 |
+
|
577 |
+
filename = get_conv_log_filename(
|
578 |
+
is_vision=state.is_vision, has_csam_image=state.has_csam_image
|
579 |
+
)
|
580 |
+
|
581 |
+
with open(filename, "a") as fout:
|
582 |
+
data = {
|
583 |
+
"tstamp": round(finish_tstamp, 4),
|
584 |
+
"type": "chat",
|
585 |
+
"model": model_name,
|
586 |
+
"gen_params": {
|
587 |
+
"temperature": temperature,
|
588 |
+
"top_p": top_p,
|
589 |
+
"max_new_tokens": max_new_tokens,
|
590 |
+
},
|
591 |
+
"start": round(start_tstamp, 4),
|
592 |
+
"finish": round(finish_tstamp, 4),
|
593 |
+
"state": state.dict(),
|
594 |
+
"ip": get_ip(request),
|
595 |
+
}
|
596 |
+
fout.write(json.dumps(data) + "\n")
|
597 |
+
get_remote_logger().log(data)
|
598 |
+
|
599 |
+
|
600 |
+
block_css = """
|
601 |
+
.prose {
|
602 |
+
font-size: 105% !important;
|
603 |
+
}
|
604 |
+
|
605 |
+
#arena_leaderboard_dataframe table {
|
606 |
+
font-size: 105%;
|
607 |
+
}
|
608 |
+
#full_leaderboard_dataframe table {
|
609 |
+
font-size: 105%;
|
610 |
+
}
|
611 |
+
|
612 |
+
.tab-nav button {
|
613 |
+
font-size: 18px;
|
614 |
+
}
|
615 |
+
|
616 |
+
.chatbot h1 {
|
617 |
+
font-size: 130%;
|
618 |
+
}
|
619 |
+
.chatbot h2 {
|
620 |
+
font-size: 120%;
|
621 |
+
}
|
622 |
+
.chatbot h3 {
|
623 |
+
font-size: 110%;
|
624 |
+
}
|
625 |
+
|
626 |
+
#chatbot .prose {
|
627 |
+
font-size: 90% !important;
|
628 |
+
}
|
629 |
+
|
630 |
+
.sponsor-image-about img {
|
631 |
+
margin: 0 20px;
|
632 |
+
margin-top: 20px;
|
633 |
+
height: 40px;
|
634 |
+
max-height: 100%;
|
635 |
+
width: auto;
|
636 |
+
float: left;
|
637 |
+
}
|
638 |
+
|
639 |
+
.cursor {
|
640 |
+
display: inline-block;
|
641 |
+
width: 7px;
|
642 |
+
height: 1em;
|
643 |
+
background-color: black;
|
644 |
+
vertical-align: middle;
|
645 |
+
animation: blink 1s infinite;
|
646 |
+
}
|
647 |
+
|
648 |
+
.dark .cursor {
|
649 |
+
display: inline-block;
|
650 |
+
width: 7px;
|
651 |
+
height: 1em;
|
652 |
+
background-color: white;
|
653 |
+
vertical-align: middle;
|
654 |
+
animation: blink 1s infinite;
|
655 |
+
}
|
656 |
+
|
657 |
+
@keyframes blink {
|
658 |
+
0%, 50% { opacity: 1; }
|
659 |
+
50.1%, 100% { opacity: 0; }
|
660 |
+
}
|
661 |
+
|
662 |
+
.app {
|
663 |
+
max-width: 100% !important;
|
664 |
+
padding-left: 5% !important;
|
665 |
+
padding-right: 5% !important;
|
666 |
+
}
|
667 |
+
|
668 |
+
a {
|
669 |
+
color: #1976D2; /* Your current link color, a shade of blue */
|
670 |
+
text-decoration: none; /* Removes underline from links */
|
671 |
+
}
|
672 |
+
a:hover {
|
673 |
+
color: #63A4FF; /* This can be any color you choose for hover */
|
674 |
+
text-decoration: underline; /* Adds underline on hover */
|
675 |
+
}
|
676 |
+
|
677 |
+
.block {
|
678 |
+
overflow-y: hidden !important;
|
679 |
+
}
|
680 |
+
"""
|
681 |
+
|
682 |
+
|
683 |
+
# block_css = """
|
684 |
+
# #notice_markdown .prose {
|
685 |
+
# font-size: 110% !important;
|
686 |
+
# }
|
687 |
+
# #notice_markdown th {
|
688 |
+
# display: none;
|
689 |
+
# }
|
690 |
+
# #notice_markdown td {
|
691 |
+
# padding-top: 6px;
|
692 |
+
# padding-bottom: 6px;
|
693 |
+
# }
|
694 |
+
# #arena_leaderboard_dataframe table {
|
695 |
+
# font-size: 110%;
|
696 |
+
# }
|
697 |
+
# #full_leaderboard_dataframe table {
|
698 |
+
# font-size: 110%;
|
699 |
+
# }
|
700 |
+
# #model_description_markdown {
|
701 |
+
# font-size: 110% !important;
|
702 |
+
# }
|
703 |
+
# #leaderboard_markdown .prose {
|
704 |
+
# font-size: 110% !important;
|
705 |
+
# }
|
706 |
+
# #leaderboard_markdown td {
|
707 |
+
# padding-top: 6px;
|
708 |
+
# padding-bottom: 6px;
|
709 |
+
# }
|
710 |
+
# #leaderboard_dataframe td {
|
711 |
+
# line-height: 0.1em;
|
712 |
+
# }
|
713 |
+
# #about_markdown .prose {
|
714 |
+
# font-size: 110% !important;
|
715 |
+
# }
|
716 |
+
# #ack_markdown .prose {
|
717 |
+
# font-size: 110% !important;
|
718 |
+
# }
|
719 |
+
# #chatbot .prose {
|
720 |
+
# font-size: 105% !important;
|
721 |
+
# }
|
722 |
+
# .sponsor-image-about img {
|
723 |
+
# margin: 0 20px;
|
724 |
+
# margin-top: 20px;
|
725 |
+
# height: 40px;
|
726 |
+
# max-height: 100%;
|
727 |
+
# width: auto;
|
728 |
+
# float: left;
|
729 |
+
# }
|
730 |
+
|
731 |
+
# body {
|
732 |
+
# --body-text-size: 14px;
|
733 |
+
# }
|
734 |
+
|
735 |
+
# .chatbot h1, h2, h3 {
|
736 |
+
# margin-top: 8px; /* Adjust the value as needed */
|
737 |
+
# margin-bottom: 0px; /* Adjust the value as needed */
|
738 |
+
# padding-bottom: 0px;
|
739 |
+
# }
|
740 |
+
|
741 |
+
# .chatbot h1 {
|
742 |
+
# font-size: 130%;
|
743 |
+
# }
|
744 |
+
# .chatbot h2 {
|
745 |
+
# font-size: 120%;
|
746 |
+
# }
|
747 |
+
# .chatbot h3 {
|
748 |
+
# font-size: 110%;
|
749 |
+
# }
|
750 |
+
# .chatbot p:not(:first-child) {
|
751 |
+
# margin-top: 8px;
|
752 |
+
# }
|
753 |
+
|
754 |
+
# .typing {
|
755 |
+
# display: inline-block;
|
756 |
+
# }
|
757 |
+
|
758 |
+
# """
|
759 |
+
|
760 |
+
|
761 |
+
def get_model_description_md(models):
|
762 |
+
model_description_md = """
|
763 |
+
| | | |
|
764 |
+
| ---- | ---- | ---- |
|
765 |
+
"""
|
766 |
+
return ""
|
767 |
+
ct = 0
|
768 |
+
visited = set()
|
769 |
+
for i, name in enumerate(models):
|
770 |
+
# minfo = ""
|
771 |
+
minfo = get_model_info(name)
|
772 |
+
if minfo.simple_name in visited:
|
773 |
+
continue
|
774 |
+
visited.add(minfo.simple_name)
|
775 |
+
one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
|
776 |
+
|
777 |
+
if ct % 3 == 0:
|
778 |
+
model_description_md += "|"
|
779 |
+
model_description_md += f" {one_model_md} |"
|
780 |
+
if ct % 3 == 2:
|
781 |
+
model_description_md += "\n"
|
782 |
+
ct += 1
|
783 |
+
return model_description_md
|
784 |
+
|
785 |
+
|
786 |
+
def build_about():
|
787 |
+
about_markdown = """
|
788 |
+
# About Us
|
789 |
+
|
790 |
+
"""
|
791 |
+
gr.Markdown(about_markdown, elem_id="about_markdown")
|
792 |
+
|
793 |
+
|
794 |
+
def build_single_model_ui(models, add_promotion_links=False):
|
795 |
+
promotion = (
|
796 |
+
f"""
|
797 |
+
|
798 |
+
{SURVEY_LINK}
|
799 |
+
|
800 |
+
## 👇 Choose any model to chat
|
801 |
+
"""
|
802 |
+
if add_promotion_links
|
803 |
+
else ""
|
804 |
+
)
|
805 |
+
|
806 |
+
notice_markdown = f"""
|
807 |
+
# 🏔️ Chatbot Arena
|
808 |
+
{promotion}
|
809 |
+
"""
|
810 |
+
|
811 |
+
state = gr.State()
|
812 |
+
gr.Markdown(notice_markdown, elem_id="notice_markdown")
|
813 |
+
|
814 |
+
with gr.Group(elem_id="share-region-named"):
|
815 |
+
with gr.Row(elem_id="model_selector_row"):
|
816 |
+
model_selector = gr.Dropdown(
|
817 |
+
choices=models,
|
818 |
+
value=models[0] if len(models) > 0 else "",
|
819 |
+
interactive=True,
|
820 |
+
show_label=False,
|
821 |
+
container=False,
|
822 |
+
)
|
823 |
+
with gr.Row():
|
824 |
+
with gr.Accordion(
|
825 |
+
f"🔍 Expand to see the descriptions of {len(models)} models",
|
826 |
+
open=False,
|
827 |
+
):
|
828 |
+
model_description_md = get_model_description_md(models)
|
829 |
+
gr.Markdown(model_description_md,
|
830 |
+
elem_id="model_description_markdown")
|
831 |
+
|
832 |
+
chatbot = gr.Chatbot(
|
833 |
+
elem_id="chatbot",
|
834 |
+
label="Scroll down and start chatting",
|
835 |
+
height=650,
|
836 |
+
show_copy_button=True,
|
837 |
+
)
|
838 |
+
with gr.Row():
|
839 |
+
textbox = gr.Textbox(
|
840 |
+
show_label=False,
|
841 |
+
placeholder="👉 Enter your prompt and press ENTER",
|
842 |
+
elem_id="input_box",
|
843 |
+
)
|
844 |
+
send_btn = gr.Button(value="Send", variant="primary", scale=0)
|
845 |
+
|
846 |
+
with gr.Row() as button_row:
|
847 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
848 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
849 |
+
flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
|
850 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
851 |
+
clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
|
852 |
+
|
853 |
+
with gr.Accordion("Parameters", open=False) as parameter_row:
|
854 |
+
temperature = gr.Slider(
|
855 |
+
minimum=0.0,
|
856 |
+
maximum=1.0,
|
857 |
+
value=0.7,
|
858 |
+
step=0.1,
|
859 |
+
interactive=True,
|
860 |
+
label="Temperature",
|
861 |
+
)
|
862 |
+
top_p = gr.Slider(
|
863 |
+
minimum=0.0,
|
864 |
+
maximum=1.0,
|
865 |
+
value=1.0,
|
866 |
+
step=0.1,
|
867 |
+
interactive=True,
|
868 |
+
label="Top P",
|
869 |
+
)
|
870 |
+
max_output_tokens = gr.Slider(
|
871 |
+
minimum=16,
|
872 |
+
maximum=2048,
|
873 |
+
value=1024,
|
874 |
+
step=64,
|
875 |
+
interactive=True,
|
876 |
+
label="Max output tokens",
|
877 |
+
)
|
878 |
+
|
879 |
+
if add_promotion_links:
|
880 |
+
gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
|
881 |
+
|
882 |
+
# Register listeners
|
883 |
+
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
|
884 |
+
upvote_btn.click(
|
885 |
+
upvote_last_response,
|
886 |
+
[state, model_selector],
|
887 |
+
[textbox, upvote_btn, downvote_btn, flag_btn],
|
888 |
+
)
|
889 |
+
downvote_btn.click(
|
890 |
+
downvote_last_response,
|
891 |
+
[state, model_selector],
|
892 |
+
[textbox, upvote_btn, downvote_btn, flag_btn],
|
893 |
+
)
|
894 |
+
flag_btn.click(
|
895 |
+
flag_last_response,
|
896 |
+
[state, model_selector],
|
+        [textbox, upvote_btn, downvote_btn, flag_btn],
+    )
+    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
+        bot_response,
+        [state, temperature, top_p, max_output_tokens],
+        [state, chatbot] + btn_list,
+    )
+    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
+
+    model_selector.change(clear_history, None, [
+                          state, chatbot, textbox] + btn_list)
+
+    textbox.submit(
+        add_text,
+        [state, model_selector, textbox],
+        [state, chatbot, textbox] + btn_list,
+    ).then(
+        bot_response,
+        [state, temperature, top_p, max_output_tokens],
+        [state, chatbot] + btn_list,
+    )
+    send_btn.click(
+        add_text,
+        [state, model_selector, textbox],
+        [state, chatbot, textbox] + btn_list,
+    ).then(
+        bot_response,
+        [state, temperature, top_p, max_output_tokens],
+        [state, chatbot] + btn_list,
+    )
+
+    return [state, model_selector]
+
+
+def build_demo(models):
+    with gr.Blocks(
+        title="Chatbot Arena (formerly LMSYS): Free AI Chat to Compare & Test Best AI Chatbots",
+        theme=gr.themes.Default(),
+        css=block_css,
+    ) as demo:
+        url_params = gr.JSON(visible=False)
+
+        state, model_selector = build_single_model_ui(models)
+
+        if args.model_list_mode not in ["once", "reload"]:
+            raise ValueError(
+                f"Unknown model list mode: {args.model_list_mode}")
+
+        if args.show_terms_of_use:
+            load_js = get_window_url_params_with_tos_js
+        else:
+            load_js = get_window_url_params_js
+
+        demo.load(
+            load_demo,
+            [url_params],
+            [
+                state,
+                model_selector,
+            ],
+            js=load_js,
+        )
+
+    return demo
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int)
+    parser.add_argument(
+        "--share",
+        action="store_true",
+        help="Whether to generate a public, shareable link",
+    )
+    parser.add_argument(
+        "--controller-url",
+        type=str,
+        default="",
+        # default="http://localhost:21001",
+        help="The address of the controller",
+    )
+    parser.add_argument(
+        "--concurrency-count",
+        type=int,
+        default=10,
+        help="The concurrency count of the gradio queue",
+    )
+    parser.add_argument(
+        "--model-list-mode",
+        type=str,
+        default="once",
+        choices=["once", "reload"],
+        help="Whether to load the model list once or reload the model list every time",
+    )
+    parser.add_argument(
+        "--moderate",
+        action="store_true",
+        help="Enable content moderation to block unsafe inputs",
+    )
+    parser.add_argument(
+        "--show-terms-of-use",
+        action="store_true",
+        help="Shows terms of use before loading the demo",
+    )
+    parser.add_argument(
+        "--register-api-endpoint-file",
+        type=str,
+        help="Register API-based model endpoints from a JSON file",
+    )
+    parser.add_argument(
+        "--gradio-auth-path",
+        type=str,
+        help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
+    )
+    parser.add_argument(
+        "--gradio-root-path",
+        type=str,
+        help="Sets the gradio root path, e.g. /abc/def. Useful when running behind a reverse proxy or at a custom URL path prefix",
+    )
+    parser.add_argument(
+        "--use-remote-storage",
+        action="store_true",
+        default=False,
+        help="Uploads image files to Google Cloud Storage if set to true",
+    )
+    args = parser.parse_args()
+    logger.info(f"args: {args}")
+
+    # Set global variables
+    set_global_vars(args.controller_url, args.moderate,
+                    args.use_remote_storage)
+    models, all_models = get_model_list(
+        args.controller_url, args.register_api_endpoint_file, vision_arena=False
+    )
+
+    # Set authorization credentials
+    auth = None
+    if args.gradio_auth_path is not None:
+        auth = parse_gradio_auth_creds(args.gradio_auth_path)
+
+    # Launch the demo
+    demo = build_demo(models)
+    demo.queue(
+        default_concurrency_limit=args.concurrency_count,
+        status_update_rate=10,
+        api_open=False,
+    ).launch(
+        server_name=args.host,
+        server_port=args.port,
+        share=args.share,
+        max_threads=200,
+        auth=auth,
+        root_path=args.gradio_root_path,
+    )
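The handler wiring above follows Gradio's .then() chaining pattern: the first callback applies the user's input to the UI immediately, and the chained callback fills in the model reply afterward. A minimal, self-contained sketch of the same pattern, with hypothetical echo handlers standing in for this app's add_text and bot_response:

import gradio as gr

def add_text(history, text):
    # First step: append the user turn and clear the textbox right away.
    return history + [[text, None]], ""

def bot_response(history):
    # Second step, scheduled via .then(): fill in the assistant turn.
    history[-1][1] = f"echo: {history[-1][0]}"
    return history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    textbox = gr.Textbox()
    textbox.submit(add_text, [chatbot, textbox], [chatbot, textbox]).then(
        bot_response, chatbot, chatbot
    )

demo.launch()

Because the two steps are chained rather than merged, the textbox clears instantly even while the response is still being produced.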
serve/remote_logger.py
ADDED
@@ -0,0 +1,59 @@
+# A JSON logger that sends data to a remote endpoint.
+# Architecturally, it hosts a background thread that ships queued logs to that endpoint.
+import os
+import json
+import requests
+import threading
+import queue
+import logging
+
+_global_logger = None
+
+
+def get_remote_logger():
+    global _global_logger
+    if _global_logger is None:
+        if url := os.environ.get("REMOTE_LOGGER_URL"):
+            logging.info(f"Remote logger enabled, sending data to {url}")
+            _global_logger = RemoteLogger(url=url)
+        else:
+            _global_logger = EmptyLogger()
+    return _global_logger
+
+
+class EmptyLogger:
+    """Dummy logger that does nothing."""
+
+    def __init__(self):
+        pass
+
+    def log(self, _data: dict):
+        pass
+
+
+class RemoteLogger:
+    """A JSON logger that sends data to a remote endpoint."""
+
+    def __init__(self, url: str):
+        self.url = url
+
+        self.logs = queue.Queue()
+        self.thread = threading.Thread(target=self._send_logs, daemon=True)
+        self.thread.start()
+
+    def log(self, data: dict):
+        self.logs.put_nowait(data)
+
+    def _send_logs(self):
+        while True:
+            data = self.logs.get()
+
+            # Keep only top-level fields: turn any nested dict/list/tuple into a JSON string
+            for key, value in data.items():
+                if isinstance(value, (dict, list, tuple)):
+                    data[key] = json.dumps(value, ensure_ascii=False)
+
+            try:
+                requests.post(self.url, json=data)
+            except Exception:
+                logging.exception("Failed to send logs to remote endpoint")
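A minimal usage sketch of the logger above, assuming the serve/ directory is importable as a package and a hypothetical collector is listening at http://localhost:9000/logs (URL and payload are illustrative only):

import os

# Must be set before the first get_remote_logger() call; otherwise an
# EmptyLogger singleton is cached and remote logging stays disabled.
os.environ["REMOTE_LOGGER_URL"] = "http://localhost:9000/logs"  # hypothetical endpoint

from serve.remote_logger import get_remote_logger

logger = get_remote_logger()  # returns the RemoteLogger singleton
logger.log({"type": "vote", "detail": {"model": "demo", "score": 1}})
# log() only enqueues; the daemon thread POSTs in the background, and the
# nested "detail" dict is flattened to a JSON string before sending.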
serve/utils.py
ADDED
@@ -0,0 +1,492 @@
+"""
+Common utilities.
+"""
+from asyncio import AbstractEventLoop
+from io import BytesIO
+import base64
+import json
+import logging
+import logging.handlers
+import os
+import platform
+import sys
+import time
+from typing import AsyncGenerator, Generator
+import warnings
+
+import requests
+
+from .constants import LOGDIR
+
+
+handler = None
+visited_loggers = set()
+
+
+def build_logger(logger_name, logger_filename):
+    global handler
+
+    formatter = logging.Formatter(
+        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    # Set the format of root handlers
+    if not logging.getLogger().handlers:
+        if sys.version_info[1] >= 9:
+            # This is for Windows
+            logging.basicConfig(level=logging.INFO, encoding="utf-8")
+        else:
+            if platform.system() == "Windows":
+                warnings.warn(
+                    "If you are running on Windows, "
+                    "we recommend you use Python >= 3.9 for UTF-8 encoding."
+                )
+            logging.basicConfig(level=logging.INFO)
+    logging.getLogger().handlers[0].setFormatter(formatter)
+
+    # Redirect stdout and stderr to loggers
+    stdout_logger = logging.getLogger("stdout")
+    stdout_logger.setLevel(logging.INFO)
+    sl = StreamToLogger(stdout_logger, logging.INFO)
+    sys.stdout = sl
+
+    stderr_logger = logging.getLogger("stderr")
+    stderr_logger.setLevel(logging.ERROR)
+    sl = StreamToLogger(stderr_logger, logging.ERROR)
+    sys.stderr = sl
+
+    # Get logger
+    logger = logging.getLogger(logger_name)
+    logger.setLevel(logging.INFO)
+
+    # Avoid httpx flooding POST logs
+    logging.getLogger("httpx").setLevel(logging.WARNING)
+
+    # If LOGDIR is empty, don't write the log to a local file
+    if LOGDIR != "":
+        os.makedirs(LOGDIR, exist_ok=True)
+        filename = os.path.join(LOGDIR, logger_filename)
+        handler = logging.handlers.TimedRotatingFileHandler(
+            filename, when="D", utc=True, encoding="utf-8"
+        )
+        handler.setFormatter(formatter)
+
+        for l in [stdout_logger, stderr_logger, logger]:
+            if l in visited_loggers:
+                continue
+            visited_loggers.add(l)
+            l.addHandler(handler)
+
+    return logger
+
+
+class StreamToLogger(object):
+    """
+    Fake file-like stream object that redirects writes to a logger instance.
+    """
+
+    def __init__(self, logger, log_level=logging.INFO):
+        self.terminal = sys.stdout
+        self.logger = logger
+        self.log_level = log_level
+        self.linebuf = ""
+
+    def __getattr__(self, attr):
+        return getattr(self.terminal, attr)
+
+    def write(self, buf):
+        temp_linebuf = self.linebuf + buf
+        self.linebuf = ""
+        for line in temp_linebuf.splitlines(True):
+            # From the io.TextIOWrapper docs:
+            # On output, if newline is None, any '\n' characters written
+            # are translated to the system default line separator.
+            # By default sys.stdout.write() expects '\n' newlines and then
+            # translates them so this is still cross platform.
+            if line[-1] == "\n":
+                encoded_message = line.encode(
+                    "utf-8", "ignore").decode("utf-8")
+                self.logger.log(self.log_level, encoded_message.rstrip())
+            else:
+                self.linebuf += line
+
+    def flush(self):
+        if self.linebuf != "":
+            encoded_message = self.linebuf.encode(
+                "utf-8", "ignore").decode("utf-8")
+            self.logger.log(self.log_level, encoded_message.rstrip())
+        self.linebuf = ""
+
+
+def disable_torch_init():
+    """
+    Disable the redundant torch default initialization to accelerate model creation.
+    """
+    import torch
+
+    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def get_gpu_memory(max_gpus=None):
+    """Get available memory for each GPU."""
+    import torch
+
+    gpu_memory = []
+    num_gpus = (
+        torch.cuda.device_count()
+        if max_gpus is None
+        else min(max_gpus, torch.cuda.device_count())
+    )
+
+    for gpu_id in range(num_gpus):
+        with torch.cuda.device(gpu_id):
+            device = torch.cuda.current_device()
+            gpu_properties = torch.cuda.get_device_properties(device)
+            total_memory = gpu_properties.total_memory / (1024**3)
+            allocated_memory = torch.cuda.memory_allocated() / (1024**3)
+            available_memory = total_memory - allocated_memory
+            gpu_memory.append(available_memory)
+    return gpu_memory
+
+
+def oai_moderation(text, custom_thresholds=None):
+    """
+    Check whether the text violates the OpenAI moderation API.
+    """
+    import openai
+
+    client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+
+    # Default to True to be conservative
+    flagged = True
+    MAX_RETRY = 3
+    for _ in range(MAX_RETRY):
+        try:
+            res = client.moderations.create(input=text)
+            flagged = res.results[0].flagged
+            if custom_thresholds is not None:
+                for category, threshold in custom_thresholds.items():
+                    if getattr(res.results[0].category_scores, category) > threshold:
+                        flagged = True
+            break
+        except (openai.OpenAIError, KeyError, IndexError) as e:
+            print(f"MODERATION ERROR: {e}\nInput: {text}")
+    return flagged
+
+
+def moderation_filter(text, model_list, do_moderation=False):
+    # Apply moderation for the models below
+    MODEL_KEYWORDS = [
+        "claude",
+        "gpt",
+        "bard",
+        "mistral-large",
+        "command-r",
+        "dbrx",
+        "gemini",
+        "reka",
+        "eureka",
+    ]
+
+    custom_thresholds = {"sexual": 0.3}
+    # Set a stricter threshold for Claude
+    for model in model_list:
+        if "claude" in model:
+            custom_thresholds = {"sexual": 0.2}
+
+    for keyword in MODEL_KEYWORDS:
+        for model in model_list:
+            if keyword in model:
+                do_moderation = True
+                break
+
+    if do_moderation:
+        return oai_moderation(text, custom_thresholds)
+    return False
+
+
+def clean_flant5_ckpt(ckpt_path):
+    """
+    Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings;
+    use this function to make sure it can be correctly loaded.
+    """
+    import torch
+
+    index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
+    index_json = json.load(open(index_file, "r"))
+
+    weightmap = index_json["weight_map"]
+
+    share_weight_file = weightmap["shared.weight"]
+    share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[
+        "shared.weight"
+    ]
+
+    for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]:
+        weight_file = weightmap[weight_name]
+        weight = torch.load(os.path.join(ckpt_path, weight_file))
+        weight[weight_name] = share_weight
+        torch.save(weight, os.path.join(ckpt_path, weight_file))
+
+
+def pretty_print_semaphore(semaphore):
+    """Print a semaphore in better format."""
+    if semaphore is None:
+        return "None"
+    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
+
+
+"""A javascript function to get url parameters for the gradio web server."""
+get_window_url_params_js = """
+function() {
+    const params = new URLSearchParams(window.location.search);
+    url_params = Object.fromEntries(params);
+    console.log("url_params", url_params);
+    return url_params;
+}
+"""
+
+get_window_url_params_with_tos_js = """
+function() {
+    const params = new URLSearchParams(window.location.search);
+    const url_params = Object.fromEntries(params);
+    console.log("url_params", url_params);
+
+    const urlContainsLeaderboard = Object.keys(url_params).some(key => key.toLowerCase().includes("leaderboard"));
+    const msg = "Users of this website are required to agree to the following terms:\\n\\nThe service is a research preview. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\\nPlease do not upload any private information.\\nThe service collects user dialogue data, including both text and images, and reserves the right to use it for future AI development and distribute it under a Creative Commons Attribution (CC-BY) or a similar license.";
+    if (!urlContainsLeaderboard) {
+        if (window.alerted_before) return;
+        alert(msg);
+        window.alerted_before = true;
+    }
+    return url_params;
+}
+"""
+
+alert_js = """
+() => {
+    if (window.alerted_before) return;
+    const msg = "Users of this website are required to agree to the following terms:\\n\\nThe service is a research preview. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\\nPlease do not upload any private information.\\nThe service collects user dialogue data, including both text and images, and reserves the right to use it for future AI development and distribute it under a Creative Commons Attribution (CC-BY) or a similar license.";
+    alert(msg);
+    window.alerted_before = true;
+}
+"""
+
+
+def iter_over_async(
+    async_gen: AsyncGenerator, event_loop: AbstractEventLoop
+) -> Generator:
+    """
+    Convert an async generator to a sync generator.
+
+    :param async_gen: the AsyncGenerator to convert
+    :param event_loop: the event loop to run on
+    :returns: Sync generator
+    """
+    ait = async_gen.__aiter__()
+
+    async def get_next():
+        try:
+            obj = await ait.__anext__()
+            return False, obj
+        except StopAsyncIteration:
+            return True, None
+
+    while True:
+        done, obj = event_loop.run_until_complete(get_next())
+        if done:
+            break
+        yield obj
+
+
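iter_over_async above pumps an async generator one item at a time through run_until_complete, so a synchronous caller (such as a Gradio handler) can consume it. A minimal sketch with a hypothetical async counter:

import asyncio

from serve.utils import iter_over_async  # assumes serve/ is importable as a package

async def count(n):
    for i in range(n):
        await asyncio.sleep(0)  # yield control, as a real stream would
        yield i

loop = asyncio.new_event_loop()
for value in iter_over_async(count(3), loop):
    print(value)  # prints 0, 1, 2 synchronously
loop.close()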
+def detect_language(text: str) -> str:
+    # For now, always treat input as Japanese
+    return "ja"
+    # NOTE: the early return above makes the polyglot-based detection below unreachable.
+    """Detect the language of a string."""
+    import polyglot  # pip3 install polyglot pyicu pycld2
+    from polyglot.detect import Detector
+    from polyglot.detect.base import logger as polyglot_logger
+    import pycld2
+
+    polyglot_logger.setLevel("ERROR")
+
+    try:
+        lang_code = Detector(text).language.name
+    except (pycld2.error, polyglot.detect.base.UnknownLanguage):
+        lang_code = "unknown"
+    return lang_code
+
+
+def parse_gradio_auth_creds(filename: str):
+    """Parse a username:password file for gradio authorization."""
+    gradio_auth_creds = []
+    with open(filename, "r", encoding="utf8") as file:
+        for line in file.readlines():
+            gradio_auth_creds += [x.strip()
+                                  for x in line.split(",") if x.strip()]
+    if gradio_auth_creds:
+        auth = [tuple(cred.split(":")) for cred in gradio_auth_creds]
+    else:
+        auth = None
+    return auth
+
+
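A minimal sketch of the credentials file parse_gradio_auth_creds expects (hypothetical users written to a temp file); entries may be comma-separated on one line or split across lines:

import tempfile

from serve.utils import parse_gradio_auth_creds  # assumes serve/ is importable

with tempfile.NamedTemporaryFile("w", delete=False, encoding="utf8") as f:
    f.write("u1:p1,u2:p2\nu3:p3\n")  # placeholder credentials
    path = f.name

print(parse_gradio_auth_creds(path))
# -> [('u1', 'p1'), ('u2', 'p2'), ('u3', 'p3')], suitable for demo.launch(auth=...)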
+def is_partial_stop(output: str, stop_str: str):
+    """Check whether the output contains a partial stop str."""
+    for i in range(0, min(len(output), len(stop_str))):
+        if stop_str.startswith(output[-i:]):
+            return True
+    return False
+
+
+def run_cmd(cmd: str):
+    """Run a bash command."""
+    print(cmd)
+    return os.system(cmd)
+
+
+def is_sentence_complete(output: str):
+    """Check whether the output is a complete sentence."""
+    end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”")
+    return output.endswith(end_symbols)
+
+
+# Models don't use the same configuration key for determining the maximum
+# sequence length. Store them here so we can sanely check them.
+# NOTE: The ordering here is important. Some models have two of these and we
+# have a preference for which value gets used.
+SEQUENCE_LENGTH_KEYS = [
+    "max_position_embeddings",
+    "max_sequence_length",
+    "seq_length",
+    "max_seq_len",
+    "model_max_length",
+]
+
+
+def get_context_length(config):
+    """Get the context length of a model from a huggingface model config."""
+    rope_scaling = getattr(config, "rope_scaling", None)
+    if rope_scaling:
+        rope_scaling_factor = config.rope_scaling["factor"]
+    else:
+        rope_scaling_factor = 1
+
+    for key in SEQUENCE_LENGTH_KEYS:
+        val = getattr(config, key, None)
+        if val is not None:
+            return int(rope_scaling_factor * val)
+    return 2048
+
+
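To see how the rope_scaling factor interacts with the key lookup above, here is a worked example with a hypothetical config object (any attribute holder works, not just a HuggingFace config):

from types import SimpleNamespace

from serve.utils import get_context_length  # assumes serve/ is importable

config = SimpleNamespace(
    rope_scaling={"factor": 2.0},   # doubles the base window
    max_position_embeddings=4096,   # first matching SEQUENCE_LENGTH_KEYS entry
)
print(get_context_length(config))   # int(2.0 * 4096) -> 8192

bare = SimpleNamespace()            # no known keys at all
print(get_context_length(bare))     # falls back to 2048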
+def str_to_torch_dtype(dtype: str):
+    import torch
+
+    if dtype is None:
+        return None
+    elif dtype == "float32":
+        return torch.float32
+    elif dtype == "float16":
+        return torch.float16
+    elif dtype == "bfloat16":
+        return torch.bfloat16
+    else:
+        raise ValueError(f"Unrecognized dtype: {dtype}")
+
+
+def load_image(image_file):
+    from PIL import Image
+    import requests
+
+    image = None
+
+    if image_file.startswith("http://") or image_file.startswith("https://"):
+        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
+        response = requests.get(image_file, timeout=timeout)
+        image = Image.open(BytesIO(response.content))
+    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
+        image = Image.open(image_file)
+    elif image_file.startswith("data:"):
+        image_file = image_file.split(",")[1]
+        image = Image.open(BytesIO(base64.b64decode(image_file)))
+    else:
+        image = Image.open(BytesIO(base64.b64decode(image_file)))
+
+    return image
+
+
+def upload_image_file_to_gcs(image, filename):
+    from google.cloud import storage
+    import io
+
+    storage_client = storage.Client()
+    # Upload the file to GCS
+    bucket = storage_client.get_bucket("arena_service_data")
+
+    blob = bucket.blob(f"{filename}")
+    if not blob.exists():
+        buffer = io.BytesIO()
+        image.save(buffer, format="PNG")
+        buffer.seek(0)
+        blob.upload_from_file(buffer, content_type="image/png")
+
+    return blob.public_url
+
+
+def get_image_file_from_gcs(filename):
+    from google.cloud import storage
+
+    storage_client = storage.Client()
+    bucket = storage_client.get_bucket("arena_service_data")
+    blob = bucket.blob(f"{filename}")
+    contents = blob.download_as_bytes()
+
+    return contents
+
+
+def image_moderation_request(image_bytes, endpoint, api_key):
+    headers = {"Content-Type": "image/jpeg",
+               "Ocp-Apim-Subscription-Key": api_key}
+
+    MAX_RETRIES = 3
+    for _ in range(MAX_RETRIES):
+        response = requests.post(
+            endpoint, headers=headers, data=image_bytes).json()
+        try:
+            if response["Status"]["Code"] == 3000:
+                break
+        except Exception:
+            time.sleep(0.5)
+    return response
+
+
+def image_moderation_provider(image, api_type):
+    if api_type == "nsfw":
+        endpoint = os.environ["AZURE_IMG_MODERATION_ENDPOINT"]
+        api_key = os.environ["AZURE_IMG_MODERATION_API_KEY"]
+        response = image_moderation_request(image, endpoint, api_key)
+        print(response)
+        return response["IsImageAdultClassified"]
+    elif api_type == "csam":
+        endpoint = (
+            "https://api.microsoftmoderator.com/photodna/v1.0/Match?enhance=false"
+        )
+        api_key = os.environ["PHOTODNA_API_KEY"]
+        response = image_moderation_request(image, endpoint, api_key)
+        return response["IsMatch"]
+
+
+def image_moderation_filter(image):
+    print("moderating image")
+
+    image_bytes = base64.b64decode(image.base64_str)
+
+    nsfw_flagged = image_moderation_provider(image_bytes, "nsfw")
+    csam_flagged = False
+
+    if nsfw_flagged:
+        csam_flagged = image_moderation_provider(image_bytes, "csam")
+
+    return nsfw_flagged, csam_flagged
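A minimal, self-contained sketch of load_image's two base64 branches, using a tiny in-memory PNG so no network or files are needed:

import base64
import io

from PIL import Image

from serve.utils import load_image  # assumes serve/ is importable

buf = io.BytesIO()
Image.new("RGB", (4, 4), "red").save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode()

img = load_image(b64)                              # raw-base64 fallback branch
img2 = load_image("data:image/png;base64," + b64)  # data: URI branch
print(img.size, img2.size)                         # (4, 4) (4, 4)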