Spaces:
Running
Running
Upload 2 files
Browse files- app.py +198 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import uuid
|
4 |
+
from datetime import datetime
|
5 |
+
from pathlib import Path
|
6 |
+
from zoneinfo import ZoneInfo
|
7 |
+
|
8 |
+
import gradio as gr
|
9 |
+
from huggingface_hub import CommitScheduler
|
10 |
+
from openai import OpenAI
|
11 |
+
from transformers import AutoTokenizer
|
12 |
+
|
13 |
+
# --- Configuration & shared clients (module-level side effects) ---

# Inference endpoint credentials are injected via Space secrets / env vars;
# any of these may be None if the secret is not configured.
openai_api_key = os.getenv("api_key")
openai_api_base = os.getenv("api_url")
hf_token = os.getenv("hf_token")
model_name = "Aratako/calm3-22b-RP-v2"

# OpenAI-compatible client pointed at the self-hosted endpoint serving the model.
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

# Tokenizer is only used for token counting when trimming conversation history.
# Fix: reuse `model_name` instead of duplicating the repo id string, so the
# model and tokenizer can never drift apart.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define the file where to save the data. Use UUID to make sure not to
# overwrite existing data from a previous run.
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
log_folder = log_file.parent

# Schedule regular uploads. Remote repo and local folder are created if they
# don't already exist.
scheduler = CommitScheduler(
    repo_id="Aratako/calm3-22b-RP-v2-logs",  # Replace with your actual repo ID
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=60,  # Upload every 60 minutes
    token=hf_token,
)
|
38 |
+
|
39 |
+
|
40 |
+
def save_chat_logs(messages, response):
    """Append one conversation turn to the JSON Lines log file.

    Args:
        messages: The chat messages (role/content dicts) sent to the model.
        response: The assistant's full generated reply for this turn.
    """
    # Hold the CommitScheduler lock so the background upload thread never
    # snapshots a half-written line.
    with scheduler.lock:
        entry = {
            # Timestamps are recorded in JST since the demo targets Japanese users.
            "timestamp": datetime.now(ZoneInfo("Asia/Tokyo")).isoformat(),
            "messages": messages,
            "response": response,
        }
        # Fix: force UTF-8 — the entries contain Japanese text written with
        # ensure_ascii=False, which would break under a non-UTF-8 locale default.
        with log_file.open("a", encoding="utf-8") as f:
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
52 |
+
|
53 |
+
|
54 |
+
def count_tokens(messages):
    """Return the total token count across all message contents."""
    total = 0
    for msg in messages:
        total += len(tokenizer.encode(msg["content"]))
    return total
|
56 |
+
|
57 |
+
|
58 |
+
def trim_messages(messages, max_tokens):
    """Drop oldest non-system messages until the conversation fits the budget.

    The first message (assumed to be the system prompt) is always kept.
    Remaining messages are kept newest-first until adding one more would
    exceed `max_tokens`; chronological order is preserved in the result.

    Args:
        messages: Chat messages; messages[0] must be the system message.
        max_tokens: Token budget for the returned list (system message included).

    Returns:
        A new list of messages whose total token count is within the budget.
    """
    # Robustness: nothing to trim (the original raised IndexError here).
    if not messages:
        return []

    system_message = messages[0]
    # Remaining budget after reserving room for the system message.
    budget = max_tokens - count_tokens([system_message])

    # Walk from the newest message backwards, collecting what fits.
    # Fix: collect-then-reverse replaces the original's repeated
    # list.insert(1, ...), which was O(n^2) in the number of kept messages.
    kept_newest_first = []
    for message in reversed(messages[1:]):
        message_tokens = count_tokens([message])
        if message_tokens > budget:
            break  # Stop at the first (oldest-seen) message over budget.
        kept_newest_first.append(message)
        budget -= message_tokens

    # Restore chronological order after the system message.
    return [system_message] + kept_newest_first[::-1]
|
74 |
+
|
75 |
+
|
76 |
+
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for the Gradio ChatInterface.

    Yields the accumulated response text as chunks arrive; once the stream
    completes, the trimmed prompt and full response are written to the log.
    Any failure is surfaced to the UI as an error string instead of raising.
    """
    try:
        # Rebuild the full message list: system prompt, prior turns, new input.
        messages = [{"role": "system", "content": system_message}]
        for user_turn, assistant_turn in history:
            if user_turn:
                messages.append({"role": "user", "content": user_turn})
            if assistant_turn:
                messages.append({"role": "assistant", "content": assistant_turn})
        messages.append({"role": "user", "content": message})

        # Trim the prompt so prompt + generation fits the 8192-token context.
        max_input_tokens = 8192 - max_tokens
        trimmed_messages = trim_messages(messages, max_input_tokens)

        stream = client.chat.completions.create(
            model=model_name,
            messages=trimmed_messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        )

        response = ""
        for chunk in stream:
            delta = chunk.choices[0].delta.content
            if delta is None:
                continue
            response += delta
            yield response

        # Save conversation after the full response is generated.
        save_chat_logs(trimmed_messages, response)

    except Exception as e:
        # UI boundary: show the error to the user rather than crashing the app.
        yield f"エラーが発生しました: {str(e)}"
|
117 |
+
|
118 |
+
|
119 |
+
"""
|
120 |
+
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
|
121 |
+
"""
|
122 |
+
|
123 |
+
description = """
|
124 |
+
### [Aratako/calm3-22b-RP-v2](https://huggingface.co/Aratako/calm3-22b-RP-v2)のデモです。
|
125 |
+
- 対話ログを収集します。ただしデータは非公開リポジトリに保存され、私(Aratako)以外の第三者に公開はしません。
|
126 |
+
- 収集したログやそれを元に合成したデータを今後開発するモデルの学習に利用する可能性があります。ただしその際も本デモが元となるデータは公開しません。
|
127 |
+
- 特定のシチュエーションや状況での対話を好む方は、積極的に使ってもらうと今後開発するモデルがその状況のRPが上手くなる可能性もあります。
|
128 |
+
- **上記の条件に同意する場合のみ**、以下のChatbotを利用してください。
|
129 |
+
"""
|
130 |
+
|
131 |
+
|
132 |
+
HEADER = description
|
133 |
+
FOOTER = """### 注意
|
134 |
+
- コンテクスト長は8192までです。超えた場合、古い対話から順番に削除されます。"""
|
135 |
+
|
136 |
+
|
137 |
+
def run():
    """Construct the Gradio UI and launch the demo server (blocking)."""
    # Chatbot widget handed to the ChatInterface below.
    chatbot = gr.Chatbot(
        elem_id="chatbot",
        scale=1,
        show_copy_button=True,
        height="70%",
        layout="panel",
    )
    with gr.Blocks(fill_height=True) as demo:
        gr.Markdown(HEADER)
        chat_interface = gr.ChatInterface(
            fn=respond,  # streaming generator defined above
            stop_btn="Stop Generation",
            cache_examples=False,
            multimodal=False,
            chatbot=chatbot,
            # render=False on the accordion and inputs defers rendering so
            # ChatInterface can place them inside this Blocks layout itself.
            additional_inputs_accordion=gr.Accordion(
                label="Parameters", open=False, render=False
            ),
            # Order matters: these map positionally onto respond()'s
            # (system_message, max_tokens, temperature, top_p) parameters.
            additional_inputs=[
                gr.Textbox(
                    value="ここにロールプレイの設定を書いてください。",
                    label="システムメッセージ (ロールプレイの設定)",
                    render=False,
                ),
                gr.Slider(
                    minimum=1,
                    maximum=4096,
                    step=1,
                    value=1024,
                    label="Max tokens",
                    visible=True,
                    render=False,
                ),
                gr.Slider(
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.7,
                    label="Temperature",
                    visible=True,
                    render=False,
                ),
                gr.Slider(
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.9,
                    label="Top-p",
                    visible=True,
                    render=False,
                ),
            ],
            analytics_enabled=False,
        )
        gr.Markdown(FOOTER)
    # Queue incoming requests (up to 256 waiting) and serve without a share link.
    demo.queue(max_size=256, api_open=True)
    demo.launch(share=False, quiet=True)
|
195 |
+
|
196 |
+
|
197 |
+
if __name__ == "__main__":
|
198 |
+
run()
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
huggingface_hub
|
2 |
+
openai
|
3 |
+
transformers
|