Spaces:
Sleeping
Sleeping
hoduyquocbao
commited on
Commit
•
00430c0
1
Parent(s):
5121e98
update new feature datasets
Browse files
app.py
CHANGED
@@ -1,6 +1,9 @@
|
|
|
|
1 |
import os
|
2 |
from threading import Thread
|
3 |
from typing import Iterator, List, Tuple, Dict, Any
|
|
|
|
|
4 |
|
5 |
import gradio as gr
|
6 |
import spaces
|
@@ -8,8 +11,9 @@ import torch
|
|
8 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
|
9 |
from bs4 import BeautifulSoup
|
10 |
import requests
|
11 |
-
import json
|
12 |
from functools import lru_cache
|
|
|
|
|
13 |
|
14 |
# ---------------------------- Cấu Hình ---------------------------- #
|
15 |
|
@@ -41,6 +45,60 @@ model.eval() # Đặt mô hình ở chế độ đánh giá
|
|
41 |
# Khởi tạo pipeline phân tích tâm lý
|
42 |
sentiment_pipeline = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
# ---------------------------- Định Nghĩa Hàm ---------------------------- #
|
45 |
|
46 |
@lru_cache(maxsize=128)
|
@@ -103,7 +161,7 @@ def summarize_text(text: str, max_length: int = 150) -> str:
|
|
103 |
]
|
104 |
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
|
105 |
input_ids = input_ids.to(device)
|
106 |
-
|
107 |
summary_streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
|
108 |
summary_kwargs = {
|
109 |
"input_ids": input_ids,
|
@@ -115,7 +173,7 @@ def summarize_text(text: str, max_length: int = 150) -> str:
|
|
115 |
}
|
116 |
t = Thread(target=model.generate, kwargs=summary_kwargs)
|
117 |
t.start()
|
118 |
-
|
119 |
summary = ""
|
120 |
for new_text in summary_streamer:
|
121 |
summary += new_text
|
@@ -128,26 +186,43 @@ def analyze_sentiment(text: str) -> str:
|
|
128 |
score = result[0]['score']
|
129 |
return f"🟢 **Tâm lý**: {sentiment} (Điểm: {score:.2f})"
|
130 |
|
131 |
-
def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
|
132 |
"""
|
133 |
Tạo phản hồi sử dụng mô hình Llama cục bộ theo chế độ streaming.
|
134 |
"""
|
135 |
-
#
|
136 |
-
conversation =
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
|
|
|
|
141 |
])
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
# Chuẩn bị input_ids từ tokenizer
|
145 |
-
input_ids = tokenizer.apply_chat_template(
|
146 |
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
|
147 |
-
|
148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
input_ids = input_ids.to(device) # Di chuyển input tới thiết bị
|
150 |
-
|
151 |
# Khởi tạo streamer để nhận văn bản được tạo ra theo thời gian thực
|
152 |
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
|
153 |
generate_kwargs = {
|
@@ -163,13 +238,17 @@ def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_
|
|
163 |
}
|
164 |
t = Thread(target=model.generate, kwargs=generate_kwargs) # Tạo luồng để sinh văn bản
|
165 |
t.start()
|
166 |
-
|
167 |
# Stream văn bản được tạo ra
|
168 |
outputs = []
|
169 |
for text in streamer:
|
170 |
outputs.append(text)
|
171 |
yield "".join(outputs)
|
172 |
|
|
|
|
|
|
|
|
|
173 |
@lru_cache(maxsize=128)
|
174 |
def process_query(query: str) -> Dict[str, Any]:
|
175 |
"""
|
@@ -180,9 +259,10 @@ def process_query(query: str) -> Dict[str, Any]:
|
|
180 |
general_query_keywords = ["giải thích", "mô tả", "nói cho tôi biết về", "cái gì là", "cách nào"]
|
181 |
summarize_keywords = ["tóm tắt", "tóm lại", "khái quát", "ngắn gọn"]
|
182 |
sentiment_keywords = ["cảm xúc", "tâm trạng", "tâm lý", "phân tích cảm xúc"]
|
|
|
183 |
|
184 |
query_lower = query.lower() # Chuyển truy vấn thành chữ thường để so sánh
|
185 |
-
|
186 |
if any(keyword in query_lower for keyword in web_search_keywords):
|
187 |
function_name = "web_search"
|
188 |
arguments = {"query": query}
|
@@ -192,25 +272,28 @@ def process_query(query: str) -> Dict[str, Any]:
|
|
192 |
elif any(keyword in query_lower for keyword in sentiment_keywords):
|
193 |
function_name = "sentiment_analysis"
|
194 |
arguments = {"prompt": query}
|
|
|
|
|
|
|
195 |
elif any(keyword in query_lower for keyword in general_query_keywords):
|
196 |
function_name = "general_query"
|
197 |
arguments = {"prompt": query}
|
198 |
else:
|
199 |
function_name = "hard_query"
|
200 |
arguments = {"prompt": query}
|
201 |
-
|
202 |
return {
|
203 |
"name": function_name,
|
204 |
"arguments": arguments
|
205 |
}
|
206 |
|
207 |
-
def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
|
208 |
"""
|
209 |
Thực thi hàm phù hợp dựa trên lời gọi hàm.
|
210 |
"""
|
211 |
function_name = function_call["name"]
|
212 |
arguments = function_call["arguments"]
|
213 |
-
|
214 |
if function_name == "web_search":
|
215 |
query = arguments["query"]
|
216 |
yield "🔍 Đang thực hiện tìm kiếm trên web..."
|
@@ -222,10 +305,10 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
|
|
222 |
web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 **Mô tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."])
|
223 |
if not web_summary:
|
224 |
web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm."
|
225 |
-
|
226 |
# Trả về kết quả tìm kiếm cho người dùng
|
227 |
yield "📄 **Kết quả tìm kiếm:**\n" + web_summary
|
228 |
-
|
229 |
elif function_name == "summarize_query":
|
230 |
# Khi người dùng yêu cầu tóm tắt, hệ thống sẽ thực hiện tìm kiếm và sau đó tóm tắt kết quả
|
231 |
query = arguments["prompt"]
|
@@ -242,14 +325,22 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
|
|
242 |
# Tóm tắt nội dung đã lấy
|
243 |
yield "📝 Đang tóm tắt thông tin..."
|
244 |
summary = summarize_text(combined_text)
|
|
|
|
|
245 |
yield "📄 **Tóm tắt:**\n" + summary
|
246 |
-
|
247 |
elif function_name == "sentiment_analysis":
|
248 |
prompt_text = arguments["prompt"]
|
249 |
yield "📊 Đang phân tích tâm lý..."
|
250 |
sentiment = analyze_sentiment(prompt_text)
|
251 |
yield sentiment
|
252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
elif function_name in ["general_query", "hard_query"]:
|
254 |
prompt_text = arguments["prompt"]
|
255 |
yield "🤖 Đang tạo phản hồi..."
|
@@ -257,6 +348,7 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
|
|
257 |
response_generator = generate_response(
|
258 |
prompt=prompt_text,
|
259 |
chat_history=chat_history,
|
|
|
260 |
max_new_tokens=max_new_tokens,
|
261 |
temperature=temperature,
|
262 |
top_p=top_p,
|
@@ -265,12 +357,23 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
|
|
265 |
)
|
266 |
for response in response_generator:
|
267 |
yield response
|
268 |
-
|
269 |
else:
|
270 |
yield "⚠️ Lời gọi hàm không được nhận dạng."
|
271 |
|
272 |
# ---------------------------- Giao Diện Gradio ---------------------------- #
|
273 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
@spaces.GPU(duration=15, queue=False)
|
275 |
def generate(
|
276 |
message: str,
|
@@ -286,10 +389,13 @@ def generate(
|
|
286 |
"""
|
287 |
# Thông báo về việc phân tích đầu vào
|
288 |
yield "🔍 Đang phân tích truy vấn của bạn..."
|
289 |
-
|
|
|
|
|
|
|
290 |
# Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng
|
291 |
function_call = process_query(message)
|
292 |
-
|
293 |
# Thông báo về hàm được chọn
|
294 |
if function_call["name"] == "web_search":
|
295 |
yield "🛠️ Đã chọn chức năng: Tìm kiếm trên web."
|
@@ -297,23 +403,27 @@ def generate(
|
|
297 |
yield "🛠️ Đã chọn chức năng: Tóm tắt văn bản."
|
298 |
elif function_call["name"] == "sentiment_analysis":
|
299 |
yield "🛠️ Đã chọn chức năng: Phân tích tâm lý."
|
|
|
|
|
|
|
300 |
elif function_call["name"] in ["general_query", "hard_query"]:
|
301 |
yield "🛠️ Đã chọn chức năng: Trả lời câu hỏi."
|
302 |
else:
|
303 |
yield "⚠️ Không thể xác định chức năng phù hợp."
|
304 |
-
|
305 |
# Xử lý lời gọi hàm và sinh phản hồi tương ứng
|
306 |
response_iterator = handle_functions(
|
307 |
function_call=function_call,
|
308 |
prompt=message,
|
309 |
chat_history=chat_history,
|
|
|
310 |
max_new_tokens=max_new_tokens,
|
311 |
temperature=temperature,
|
312 |
top_p=top_p,
|
313 |
top_k=top_k,
|
314 |
repetition_penalty=repetition_penalty
|
315 |
)
|
316 |
-
|
317 |
for response in response_iterator:
|
318 |
yield response
|
319 |
|
|
|
1 |
+
|
2 |
import os
|
3 |
from threading import Thread
|
4 |
from typing import Iterator, List, Tuple, Dict, Any
|
5 |
+
import uuid
|
6 |
+
import json
|
7 |
|
8 |
import gradio as gr
|
9 |
import spaces
|
|
|
11 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
|
12 |
from bs4 import BeautifulSoup
|
13 |
import requests
|
|
|
14 |
from functools import lru_cache
|
15 |
+
# from huggingface_hub import HfApi, HfFolder
|
16 |
+
from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets
|
17 |
|
18 |
# ---------------------------- Cấu Hình ---------------------------- #
|
19 |
|
|
|
45 |
# Khởi tạo pipeline phân tích tâm lý
|
46 |
sentiment_pipeline = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
|
47 |
|
48 |
+
# ---------------------------- Thiết lập Bộ nhớ Sử dụng Huggingface Datasets ---------------------------- #
|
49 |
+
|
50 |
+
HF_DATASET = os.getenv("HF_DATASET") # Đảm bảo bạn đã set biến môi trường này "your_username/chat_memory" Thay đổi theo tên của bạn
|
51 |
+
|
52 |
+
def initialize_dataset():
|
53 |
+
"""
|
54 |
+
Khởi tạo Dataset trên Huggingface Hub nếu chưa tồn tại.
|
55 |
+
"""
|
56 |
+
try:
|
57 |
+
dataset = load_dataset(HF_DATASET)
|
58 |
+
except:
|
59 |
+
# Tạo Dataset mới nếu chưa tồn tại
|
60 |
+
dataset = DatasetDict({
|
61 |
+
"conversations": Dataset.from_dict({
|
62 |
+
"user_id": [],
|
63 |
+
"messages": []
|
64 |
+
})
|
65 |
+
})
|
66 |
+
dataset.push_to_hub(HF_DATASET, private=True)
|
67 |
+
return dataset
|
68 |
+
|
69 |
+
def save_conversation(user_id: str, messages: List[Tuple[str, str]]):
|
70 |
+
"""
|
71 |
+
Lưu cuộc hội thoại của người dùng vào Dataset.
|
72 |
+
"""
|
73 |
+
dataset = load_dataset(HF_DATASET)
|
74 |
+
# Chuyển đổi cuộc hội thoại thành định dạng JSON
|
75 |
+
messages_json = json.dumps(messages)
|
76 |
+
new_entry = {
|
77 |
+
"user_id": user_id,
|
78 |
+
"messages": messages_json
|
79 |
+
}
|
80 |
+
# Tạo Dataset từ entry mới
|
81 |
+
new_dataset = Dataset.from_dict(new_entry)
|
82 |
+
# Kết hợp với Dataset hiện tại
|
83 |
+
updated_dataset = concatenate_datasets([dataset["conversations"], new_dataset])
|
84 |
+
# Đẩy lên Hub
|
85 |
+
updated_dataset.push_to_hub(HF_DATASET, private=True)
|
86 |
+
|
87 |
+
def load_conversation(user_id: str) -> List[Tuple[str, str]]:
|
88 |
+
"""
|
89 |
+
Truy xuất cuộc hội thoại của người dùng từ Dataset.
|
90 |
+
"""
|
91 |
+
dataset = load_dataset(HF_DATASET)
|
92 |
+
# Tìm entry theo user_id
|
93 |
+
user_data = dataset["conversations"].filter(lambda x: x["user_id"] == user_id)
|
94 |
+
if len(user_data) == 0:
|
95 |
+
return []
|
96 |
+
messages_json = user_data["messages"][0]
|
97 |
+
return json.loads(messages_json)
|
98 |
+
|
99 |
+
# Khởi tạo Dataset
|
100 |
+
initialize_dataset()
|
101 |
+
|
102 |
# ---------------------------- Định Nghĩa Hàm ---------------------------- #
|
103 |
|
104 |
@lru_cache(maxsize=128)
|
|
|
161 |
]
|
162 |
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
|
163 |
input_ids = input_ids.to(device)
|
164 |
+
|
165 |
summary_streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
|
166 |
summary_kwargs = {
|
167 |
"input_ids": input_ids,
|
|
|
173 |
}
|
174 |
t = Thread(target=model.generate, kwargs=summary_kwargs)
|
175 |
t.start()
|
176 |
+
|
177 |
summary = ""
|
178 |
for new_text in summary_streamer:
|
179 |
summary += new_text
|
|
|
186 |
score = result[0]['score']
|
187 |
return f"🟢 **Tâm lý**: {sentiment} (Điểm: {score:.2f})"
|
188 |
|
189 |
+
def generate_response(prompt: str, chat_history: List[Tuple[str, str]], user_id: str, max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
|
190 |
"""
|
191 |
Tạo phản hồi sử dụng mô hình Llama cục bộ theo chế độ streaming.
|
192 |
"""
|
193 |
+
# Lấy lịch sử từ Dataset
|
194 |
+
conversation = load_conversation(user_id)
|
195 |
+
# Chuyển đổi lịch sử thành định dạng mà mô hình hiểu
|
196 |
+
conversation_formatted = []
|
197 |
+
for user_msg, assistant_msg in conversation:
|
198 |
+
conversation_formatted.extend([
|
199 |
+
{"role": "user", "content": user_msg},
|
200 |
+
{"role": "assistant", "content": assistant_msg},
|
201 |
])
|
202 |
+
conversation_formatted.append({"role": "user", "content": prompt}) # Thêm tin nhắn của ngư��i dùng
|
203 |
+
|
204 |
+
# Kiểm tra độ dài và sử dụng bản tóm tắt nếu cần
|
205 |
+
if len(conversation_formatted) > 50: # Giới hạn số lượng tin nhắn, điều chỉnh tùy nhu cầu
|
206 |
+
summary = summarize_text(" ".join([msg["content"] for msg in conversation_formatted]))
|
207 |
+
# Lưu bản tóm tắt vào Dataset
|
208 |
+
new_messages = [("system", summary)]
|
209 |
+
save_conversation(user_id, new_messages)
|
210 |
+
# Giữ lại phần mới nhất
|
211 |
+
conversation_formatted = [{"role": "system", "content": summary}] + conversation_formatted[-25:]
|
212 |
|
213 |
# Chuẩn bị input_ids từ tokenizer
|
214 |
+
input_ids = tokenizer.apply_chat_template(conversation_formatted, add_generation_prompt=True, return_tensors="pt")
|
215 |
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
|
216 |
+
# Sử dụng bản tóm tắt từ bộ nhớ
|
217 |
+
summary = summarize_text(" ".join([msg["content"] for msg in conversation_formatted]))
|
218 |
+
conversation_formatted = [{"role": "system", "content": summary}] + conversation_formatted[-(MAX_INPUT_TOKEN_LENGTH // 2):]
|
219 |
+
input_ids = tokenizer.apply_chat_template(conversation_formatted, add_generation_prompt=True, return_tensors="pt")
|
220 |
+
# Lưu lại bản tóm tắt
|
221 |
+
new_messages = [("system", summary)]
|
222 |
+
save_conversation(user_id, new_messages)
|
223 |
+
|
224 |
input_ids = input_ids.to(device) # Di chuyển input tới thiết bị
|
225 |
+
|
226 |
# Khởi tạo streamer để nhận văn bản được tạo ra theo thời gian thực
|
227 |
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
|
228 |
generate_kwargs = {
|
|
|
238 |
}
|
239 |
t = Thread(target=model.generate, kwargs=generate_kwargs) # Tạo luồng để sinh văn bản
|
240 |
t.start()
|
241 |
+
|
242 |
# Stream văn bản được tạo ra
|
243 |
outputs = []
|
244 |
for text in streamer:
|
245 |
outputs.append(text)
|
246 |
yield "".join(outputs)
|
247 |
|
248 |
+
# Lưu phản hồi vào Dataset
|
249 |
+
response = "".join(outputs)
|
250 |
+
save_conversation(user_id, [(prompt, response)])
|
251 |
+
|
252 |
@lru_cache(maxsize=128)
|
253 |
def process_query(query: str) -> Dict[str, Any]:
|
254 |
"""
|
|
|
259 |
general_query_keywords = ["giải thích", "mô tả", "nói cho tôi biết về", "cái gì là", "cách nào"]
|
260 |
summarize_keywords = ["tóm tắt", "tóm lại", "khái quát", "ngắn gọn"]
|
261 |
sentiment_keywords = ["cảm xúc", "tâm trạng", "tâm lý", "phân tích cảm xúc"]
|
262 |
+
topic_keywords = ["chủ đề", "bàn về", "về"]
|
263 |
|
264 |
query_lower = query.lower() # Chuyển truy vấn thành chữ thường để so sánh
|
265 |
+
|
266 |
if any(keyword in query_lower for keyword in web_search_keywords):
|
267 |
function_name = "web_search"
|
268 |
arguments = {"query": query}
|
|
|
272 |
elif any(keyword in query_lower for keyword in sentiment_keywords):
|
273 |
function_name = "sentiment_analysis"
|
274 |
arguments = {"prompt": query}
|
275 |
+
elif any(keyword in query_lower for keyword in topic_keywords):
|
276 |
+
function_name = "new_topic"
|
277 |
+
arguments = {"topic": query}
|
278 |
elif any(keyword in query_lower for keyword in general_query_keywords):
|
279 |
function_name = "general_query"
|
280 |
arguments = {"prompt": query}
|
281 |
else:
|
282 |
function_name = "hard_query"
|
283 |
arguments = {"prompt": query}
|
284 |
+
|
285 |
return {
|
286 |
"name": function_name,
|
287 |
"arguments": arguments
|
288 |
}
|
289 |
|
290 |
+
def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], user_id: str, max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
|
291 |
"""
|
292 |
Thực thi hàm phù hợp dựa trên lời gọi hàm.
|
293 |
"""
|
294 |
function_name = function_call["name"]
|
295 |
arguments = function_call["arguments"]
|
296 |
+
|
297 |
if function_name == "web_search":
|
298 |
query = arguments["query"]
|
299 |
yield "🔍 Đang thực hiện tìm kiếm trên web..."
|
|
|
305 |
web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 **Mô tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."])
|
306 |
if not web_summary:
|
307 |
web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm."
|
308 |
+
|
309 |
# Trả về kết quả tìm kiếm cho người dùng
|
310 |
yield "📄 **Kết quả tìm kiếm:**\n" + web_summary
|
311 |
+
|
312 |
elif function_name == "summarize_query":
|
313 |
# Khi người dùng yêu cầu tóm tắt, hệ thống sẽ thực hiện tìm kiếm và sau đó tóm tắt kết quả
|
314 |
query = arguments["prompt"]
|
|
|
325 |
# Tóm tắt nội dung đã lấy
|
326 |
yield "📝 Đang tóm tắt thông tin..."
|
327 |
summary = summarize_text(combined_text)
|
328 |
+
# Lưu tóm tắt vào Dataset
|
329 |
+
save_conversation(user_id, [("tóm tắt", summary)])
|
330 |
yield "📄 **Tóm tắt:**\n" + summary
|
331 |
+
|
332 |
elif function_name == "sentiment_analysis":
|
333 |
prompt_text = arguments["prompt"]
|
334 |
yield "📊 Đang phân tích tâm lý..."
|
335 |
sentiment = analyze_sentiment(prompt_text)
|
336 |
yield sentiment
|
337 |
+
|
338 |
+
elif function_name == "new_topic":
|
339 |
+
topic = arguments["topic"]
|
340 |
+
# Lưu chủ đề mới vào Dataset
|
341 |
+
save_conversation(user_id, [("chủ đề", f"Chủ đề mới: {topic}")])
|
342 |
+
yield f"🆕 Đã chuyển sang chủ đề mới: {topic}"
|
343 |
+
|
344 |
elif function_name in ["general_query", "hard_query"]:
|
345 |
prompt_text = arguments["prompt"]
|
346 |
yield "🤖 Đang tạo phản hồi..."
|
|
|
348 |
response_generator = generate_response(
|
349 |
prompt=prompt_text,
|
350 |
chat_history=chat_history,
|
351 |
+
user_id=user_id,
|
352 |
max_new_tokens=max_new_tokens,
|
353 |
temperature=temperature,
|
354 |
top_p=top_p,
|
|
|
357 |
)
|
358 |
for response in response_generator:
|
359 |
yield response
|
360 |
+
|
361 |
else:
|
362 |
yield "⚠️ Lời gọi hàm không được nhận dạng."
|
363 |
|
364 |
# ---------------------------- Giao Diện Gradio ---------------------------- #
|
365 |
|
366 |
+
def get_user_id():
|
367 |
+
"""
|
368 |
+
Tạo hoặc lấy user_id từ session state của Gradio.
|
369 |
+
Sử dụng cookie hoặc thông tin định danh tạm thời.
|
370 |
+
"""
|
371 |
+
# Gradio hiện không hỗ trợ session state natively, cần sử dụng workaround
|
372 |
+
# Dưới đây là cách tạo user_id tạm thời cho mỗi phiên
|
373 |
+
if "user_id" not in gr.get_session_state():
|
374 |
+
gr.get_session_state()["user_id"] = str(uuid.uuid4())
|
375 |
+
return gr.get_session_state()["user_id"]
|
376 |
+
|
377 |
@spaces.GPU(duration=15, queue=False)
|
378 |
def generate(
|
379 |
message: str,
|
|
|
389 |
"""
|
390 |
# Thông báo về việc phân tích đầu vào
|
391 |
yield "🔍 Đang phân tích truy vấn của bạn..."
|
392 |
+
|
393 |
+
# Lấy user_id từ session
|
394 |
+
user_id = get_user_id()
|
395 |
+
|
396 |
# Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng
|
397 |
function_call = process_query(message)
|
398 |
+
|
399 |
# Thông báo về hàm được chọn
|
400 |
if function_call["name"] == "web_search":
|
401 |
yield "🛠️ Đã chọn chức năng: Tìm kiếm trên web."
|
|
|
403 |
yield "🛠️ Đã chọn chức năng: Tóm tắt văn bản."
|
404 |
elif function_call["name"] == "sentiment_analysis":
|
405 |
yield "🛠️ Đã chọn chức năng: Phân tích tâm lý."
|
406 |
+
elif function_call["name"] == "new_topic":
|
407 |
+
yield "🛠️ Đã chọn chức năng: Chủ đề mới."
|
408 |
+
elif function_call["name"] in ["general_query", "hard
|
409 |
elif function_call["name"] in ["general_query", "hard_query"]:
|
410 |
yield "🛠️ Đã chọn chức năng: Trả lời câu hỏi."
|
411 |
else:
|
412 |
yield "⚠️ Không thể xác định chức năng phù hợp."
|
413 |
+
|
414 |
# Xử lý lời gọi hàm và sinh phản hồi tương ứng
|
415 |
response_iterator = handle_functions(
|
416 |
function_call=function_call,
|
417 |
prompt=message,
|
418 |
chat_history=chat_history,
|
419 |
+
user_id=user_id, # Sử dụng user_id để quản lý dữ liệu theo người dùng
|
420 |
max_new_tokens=max_new_tokens,
|
421 |
temperature=temperature,
|
422 |
top_p=top_p,
|
423 |
top_k=top_k,
|
424 |
repetition_penalty=repetition_penalty
|
425 |
)
|
426 |
+
|
427 |
for response in response_iterator:
|
428 |
yield response
|
429 |
|