Aratako committed
Commit a12df0d
1 Parent(s): 95bb113

Upload 2 files

Files changed (2)
  1. app.py +198 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,198 @@
+ import json
+ import os
+ import uuid
+ from datetime import datetime
+ from pathlib import Path
+ from zoneinfo import ZoneInfo
+
+ import gradio as gr
+ from huggingface_hub import CommitScheduler
+ from openai import OpenAI
+ from transformers import AutoTokenizer
+
+ openai_api_key = os.getenv("api_key")
+ openai_api_base = os.getenv("api_url")
+ hf_token = os.getenv("hf_token")
+ model_name = "Aratako/calm3-22b-RP-v2"
+
+ client = OpenAI(
+     api_key=openai_api_key,
+     base_url=openai_api_base,
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained("Aratako/calm3-22b-RP-v2")
+
+ # Define the file to save the data to. Use a UUID to make sure not to overwrite existing data from a previous run.
+ log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
+ log_folder = log_file.parent
+
+ # Schedule regular uploads. The remote repo and local folder are created if they don't already exist.
+ scheduler = CommitScheduler(
+     repo_id="Aratako/calm3-22b-RP-v2-logs",  # Replace with your actual repo ID
+     repo_type="dataset",
+     folder_path=log_folder,
+     path_in_repo="data",
+     every=60,  # Upload every 60 minutes
+     token=hf_token,
+ )
+
+
+ def save_chat_logs(messages, response):
+     """
+     Save conversation data in a JSON Lines file.
+     """
+     with scheduler.lock:
+         entry = {
+             "timestamp": datetime.now(ZoneInfo("Asia/Tokyo")).isoformat(),
+             "messages": messages,
+             "response": response,
+         }
+         with log_file.open("a") as f:
+             f.write(json.dumps(entry, ensure_ascii=False) + "\n")
+
+
+ def count_tokens(messages):
+     return sum(len(tokenizer.encode(msg["content"])) for msg in messages)
+
+
+ def trim_messages(messages, max_tokens):
+     # Always keep the system message
+     system_message = messages[0]
+     trimmed_messages = [system_message]
+     current_tokens = count_tokens([system_message])
+
+     # Add messages back in reverse order, starting from the newest
+     for message in reversed(messages[1:]):
+         message_tokens = count_tokens([message])
+         if current_tokens + message_tokens <= max_tokens:
+             trimmed_messages.insert(1, message)  # Insert right after the system message
+             current_tokens += message_tokens
+         else:
+             break  # Stop once the token limit would be exceeded
+
+     return trimmed_messages
+
+
+ def respond(
+     message,
+     history: list[tuple[str, str]],
+     system_message,
+     max_tokens,
+     temperature,
+     top_p,
+ ):
+     try:
+         messages = [{"role": "system", "content": system_message}]
+
+         for val in history:
+             if val[0]:
+                 messages.append({"role": "user", "content": val[0]})
+             if val[1]:
+                 messages.append({"role": "assistant", "content": val[1]})
+
+         messages.append({"role": "user", "content": message})
+         # Trim the messages so that prompt plus completion fit within 8192 tokens
+         max_input_tokens = 8192 - max_tokens
+         trimmed_messages = trim_messages(messages, max_input_tokens)
+
+         response = ""
+         for chunk in client.chat.completions.create(
+             model=model_name,
+             messages=trimmed_messages,
+             max_tokens=max_tokens,
+             stream=True,
+             temperature=temperature,
+             top_p=top_p,
+         ):
+             token = chunk.choices[0].delta.content
+             if token is not None:
+                 response += token
+                 yield response
+
+         # Save the conversation after the full response is generated
+         save_chat_logs(trimmed_messages, response)
+
+     except Exception as e:
+         yield f"An error occurred: {str(e)}"
+
+
+ """
+ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
+ """
+
+ description = """
+ ### This is a demo of [Aratako/calm3-22b-RP-v2](https://huggingface.co/Aratako/calm3-22b-RP-v2).
+ - Conversation logs are collected. The data is stored in a private repository and is not shared with any third party other than me (Aratako).
+ - The collected logs, and data synthesized from them, may be used to train future models. Even in that case, data originating from this demo will not be made public.
+ - If you enjoy role-play in particular scenarios or situations, using the demo actively may help future models get better at role-playing those scenarios.
+ - Use the chatbot below **only if you agree to the conditions above**.
+ """
+
+
+ HEADER = description
+ FOOTER = """### Note
+ - The context length is limited to 8192 tokens. If it is exceeded, the oldest turns are removed first."""
+
+
+ def run():
+     chatbot = gr.Chatbot(
+         elem_id="chatbot",
+         scale=1,
+         show_copy_button=True,
+         height="70%",
+         layout="panel",
+     )
+     with gr.Blocks(fill_height=True) as demo:
+         gr.Markdown(HEADER)
+         chat_interface = gr.ChatInterface(
+             fn=respond,
+             stop_btn="Stop Generation",
+             cache_examples=False,
+             multimodal=False,
+             chatbot=chatbot,
+             additional_inputs_accordion=gr.Accordion(
+                 label="Parameters", open=False, render=False
+             ),
+             additional_inputs=[
+                 gr.Textbox(
+                     value="Write your role-play setup here.",
+                     label="System message (role-play setup)",
+                     render=False,
+                 ),
+                 gr.Slider(
+                     minimum=1,
+                     maximum=4096,
+                     step=1,
+                     value=1024,
+                     label="Max tokens",
+                     visible=True,
+                     render=False,
+                 ),
+                 gr.Slider(
+                     minimum=0,
+                     maximum=1,
+                     step=0.1,
+                     value=0.7,
+                     label="Temperature",
+                     visible=True,
+                     render=False,
+                 ),
+                 gr.Slider(
+                     minimum=0,
+                     maximum=1,
+                     step=0.1,
+                     value=0.9,
+                     label="Top-p",
+                     visible=True,
+                     render=False,
+                 ),
+             ],
+             analytics_enabled=False,
+         )
+         gr.Markdown(FOOTER)
+     demo.queue(max_size=256, api_open=True)
+     demo.launch(share=False, quiet=True)
+
+
+ if __name__ == "__main__":
+     run()
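
For reference, the context-trimming strategy in `trim_messages` can be tried in isolation. The sketch below is a minimal stand-in: the token counter is a hypothetical whitespace-split stub rather than the real calm3 tokenizer, but the trimming logic mirrors the function above.

```python
# Minimal, self-contained sketch of the trimming behavior.
# count_tokens is a hypothetical stub (whitespace split) so no
# model download is needed; app.py uses the real tokenizer.
def count_tokens(messages):
    return sum(len(msg["content"].split()) for msg in messages)


def trim_messages(messages, max_tokens):
    system_message = messages[0]  # the system turn is always kept
    trimmed = [system_message]
    current = count_tokens([system_message])
    for message in reversed(messages[1:]):  # walk the history newest-first
        needed = count_tokens([message])
        if current + needed <= max_tokens:
            trimmed.insert(1, message)  # re-insert right after the system turn
            current += needed
        else:
            break
    return trimmed


messages = [
    {"role": "system", "content": "You are a role-play partner."},     # 5 stub tokens
    {"role": "user", "content": "Hello there"},                        # 2
    {"role": "assistant", "content": "Hi, shall we begin the scene?"}, # 6
    {"role": "user", "content": "Yes please"},                         # 2
]

# A budget of 13 keeps the system turn plus the newest exchange,
# dropping the oldest user turn first.
print([m["role"] for m in trim_messages(messages, 13)])
# -> ['system', 'assistant', 'user']
```

Inserting survivors at index 1 guarantees the system message stays first while the remaining turns keep their chronological order.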
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ huggingface_hub
+ openai
+ transformers
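
Note that app.py also imports gradio, which is not listed here. On Hugging Face Spaces the Gradio library is supplied by the Space runtime itself (pinned via `sdk_version` in the Space's README metadata), so requirements.txt only needs to list the extra dependencies.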