Spaces:
Sleeping
Sleeping
hoduyquocbao
commited on
Commit
•
07f44f8
1
Parent(s):
46030b6
fix errors GUI
Browse files
app.py
CHANGED
@@ -1,8 +1,7 @@
|
|
|
|
1 |
import os
|
2 |
from threading import Thread
|
3 |
from typing import Iterator, List, Tuple, Dict, Any
|
4 |
-
import uuid
|
5 |
-
import json
|
6 |
|
7 |
import gradio as gr
|
8 |
import spaces
|
@@ -10,8 +9,8 @@ import torch
|
|
10 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
|
11 |
from bs4 import BeautifulSoup
|
12 |
import requests
|
|
|
13 |
from functools import lru_cache
|
14 |
-
from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets
|
15 |
|
16 |
# ---------------------------- Cấu Hình ---------------------------- #
|
17 |
|
@@ -30,7 +29,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096")) # Đ
|
|
30 |
# Xác định thiết bị sử dụng (GPU nếu có, ngược lại CPU)
|
31 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
32 |
|
33 |
-
model_id = "
|
34 |
tokenizer = AutoTokenizer.from_pretrained(model_id) # Tải tokenizer từ Hugging Face
|
35 |
model = AutoModelForCausalLM.from_pretrained(
|
36 |
model_id,
|
@@ -43,80 +42,6 @@ model.eval() # Đặt mô hình ở chế độ đánh giá
|
|
43 |
# Khởi tạo pipeline phân tích tâm lý
|
44 |
sentiment_pipeline = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
|
45 |
|
46 |
-
# ---------------------------- Thiết lập Bộ nhớ Sử dụng Huggingface Datasets ---------------------------- #
|
47 |
-
|
48 |
-
HF_TOKEN = os.getenv("HF_TOKEN") # Đảm bảo bạn đã set biến môi trường này
|
49 |
-
HF_DATASET = os.getenv("HF_DATASET") # "your_username/chat_memory" Thay đổi theo tên của bạn
|
50 |
-
|
51 |
-
def initialize_dataset():
|
52 |
-
"""
|
53 |
-
Khởi tạo Dataset trên Huggingface Hub nếu chưa tồn tại.
|
54 |
-
"""
|
55 |
-
try:
|
56 |
-
dataset = load_dataset(HF_DATASET, use_auth_token=HF_TOKEN)
|
57 |
-
print("Dataset đã tồn tại trên Huggingface Hub.")
|
58 |
-
except Exception as e:
|
59 |
-
print(f"Dataset chưa tồn tại. Tạo mới Dataset: {e}")
|
60 |
-
# Tạo Dataset mới nếu chưa tồn tại
|
61 |
-
dataset = DatasetDict({
|
62 |
-
"conversations": Dataset.from_dict({
|
63 |
-
"user_id": [],
|
64 |
-
"messages": []
|
65 |
-
})
|
66 |
-
})
|
67 |
-
try:
|
68 |
-
dataset.push_to_hub(HF_DATASET, private=True, token=HF_TOKEN)
|
69 |
-
print("Dataset mới đã được tạo và đẩy lên Huggingface Hub.")
|
70 |
-
except Exception as push_e:
|
71 |
-
print(f"Lỗi khi đẩy Dataset lên Hub: {push_e}")
|
72 |
-
|
73 |
-
def save_conversation(user_id: str, messages: List[Tuple[str, str]]):
|
74 |
-
"""
|
75 |
-
Lưu cuộc hội thoại của người dùng vào Dataset.
|
76 |
-
"""
|
77 |
-
try:
|
78 |
-
dataset = load_dataset(HF_DATASET, split="conversations", use_auth_token=HF_TOKEN)
|
79 |
-
except Exception as e:
|
80 |
-
print(f"Lỗi khi tải Dataset: {e}")
|
81 |
-
return
|
82 |
-
|
83 |
-
# Chuyển đổi cuộc hội thoại thành định dạng JSON
|
84 |
-
messages_json = json.dumps(messages)
|
85 |
-
new_entry = {
|
86 |
-
"user_id": user_id,
|
87 |
-
"messages": messages_json
|
88 |
-
}
|
89 |
-
# Tạo Dataset từ entry mới
|
90 |
-
new_dataset = Dataset.from_dict(new_entry)
|
91 |
-
# Kết hợp với Dataset hiện tại
|
92 |
-
try:
|
93 |
-
updated_dataset = concatenate_datasets([dataset, new_dataset])
|
94 |
-
# Đẩy lên Hub
|
95 |
-
updated_dataset.push_to_hub(HF_DATASET, split="conversations", token=HF_TOKEN)
|
96 |
-
print(f"Cuộc hội thoại của user_id {user_id} đã được lưu.")
|
97 |
-
except Exception as e:
|
98 |
-
print(f"Lỗi khi đẩy Dataset lên Hub: {e}")
|
99 |
-
|
100 |
-
def load_conversation(user_id: str) -> List[Tuple[str, str]]:
|
101 |
-
"""
|
102 |
-
Truy xuất cuộc hội thoại của người dùng từ Dataset.
|
103 |
-
"""
|
104 |
-
try:
|
105 |
-
dataset = load_dataset(HF_DATASET, split="conversations", use_auth_token=HF_TOKEN)
|
106 |
-
except Exception as e:
|
107 |
-
print(f"Lỗi khi tải Dataset: {e}")
|
108 |
-
return []
|
109 |
-
|
110 |
-
# Tìm entry theo user_id
|
111 |
-
user_data = dataset.filter(lambda x: x["user_id"] == user_id)
|
112 |
-
if len(user_data) == 0:
|
113 |
-
return []
|
114 |
-
messages_json = user_data["messages"][0]
|
115 |
-
return json.loads(messages_json)
|
116 |
-
|
117 |
-
# Khởi tạo Dataset
|
118 |
-
initialize_dataset()
|
119 |
-
|
120 |
# ---------------------------- Định Nghĩa Hàm ---------------------------- #
|
121 |
|
122 |
@lru_cache(maxsize=128)
|
@@ -179,7 +104,7 @@ def summarize_text(text: str, max_length: int = 150) -> str:
|
|
179 |
]
|
180 |
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
|
181 |
input_ids = input_ids.to(device)
|
182 |
-
|
183 |
summary_streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
|
184 |
summary_kwargs = {
|
185 |
"input_ids": input_ids,
|
@@ -191,7 +116,7 @@ def summarize_text(text: str, max_length: int = 150) -> str:
|
|
191 |
}
|
192 |
t = Thread(target=model.generate, kwargs=summary_kwargs)
|
193 |
t.start()
|
194 |
-
|
195 |
summary = ""
|
196 |
for new_text in summary_streamer:
|
197 |
summary += new_text
|
@@ -204,43 +129,26 @@ def analyze_sentiment(text: str) -> str:
|
|
204 |
score = result[0]['score']
|
205 |
return f"🟢 **Tâm lý**: {sentiment} (Điểm: {score:.2f})"
|
206 |
|
207 |
-
def generate_response(prompt: str, chat_history: List[Tuple[str, str]],
|
208 |
"""
|
209 |
Tạo phản hồi sử dụng mô hình Llama cục bộ theo chế độ streaming.
|
210 |
"""
|
211 |
-
#
|
212 |
-
conversation =
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
{"role": "user", "content": user_msg},
|
218 |
-
{"role": "assistant", "content": assistant_msg},
|
219 |
])
|
220 |
-
|
221 |
-
|
222 |
-
# Kiểm tra độ dài và sử dụng bản tóm tắt nếu cần
|
223 |
-
if len(conversation_formatted) > 50: # Giới hạn số lượng tin nhắn, điều chỉnh tùy nhu cầu
|
224 |
-
summary = summarize_text(" ".join([msg["content"] for msg in conversation_formatted]))
|
225 |
-
# Lưu bản tóm tắt vào Dataset
|
226 |
-
new_messages = [("system", summary)]
|
227 |
-
save_conversation(user_id, new_messages)
|
228 |
-
# Giữ lại phần mới nhất
|
229 |
-
conversation_formatted = [{"role": "system", "content": summary}] + conversation_formatted[-25:]
|
230 |
-
|
231 |
# Chuẩn bị input_ids từ tokenizer
|
232 |
-
input_ids = tokenizer.apply_chat_template(
|
233 |
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
|
234 |
-
|
235 |
-
|
236 |
-
conversation_formatted = [{"role": "system", "content": summary}] + conversation_formatted[-(MAX_INPUT_TOKEN_LENGTH // 2):]
|
237 |
-
input_ids = tokenizer.apply_chat_template(conversation_formatted, add_generation_prompt=True, return_tensors="pt")
|
238 |
-
# Lưu lại bản tóm tắt
|
239 |
-
new_messages = [("system", summary)]
|
240 |
-
save_conversation(user_id, new_messages)
|
241 |
-
|
242 |
input_ids = input_ids.to(device) # Di chuyển input tới thiết bị
|
243 |
-
|
244 |
# Khởi tạo streamer để nhận văn bản được tạo ra theo thời gian thực
|
245 |
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
|
246 |
generate_kwargs = {
|
@@ -256,22 +164,12 @@ def generate_response(prompt: str, chat_history: List[Tuple[str, str]], user_id:
|
|
256 |
}
|
257 |
t = Thread(target=model.generate, kwargs=generate_kwargs) # Tạo luồng để sinh văn bản
|
258 |
t.start()
|
259 |
-
|
260 |
# Stream văn bản được tạo ra
|
261 |
outputs = []
|
262 |
for text in streamer:
|
263 |
outputs.append(text)
|
264 |
-
|
265 |
-
conversation_formatted[-1]["content"] += text
|
266 |
-
# Convert to list of tuples
|
267 |
-
display_history = []
|
268 |
-
for msg in conversation_formatted:
|
269 |
-
display_history.append((msg["role"], msg["content"]))
|
270 |
-
yield display_history
|
271 |
-
|
272 |
-
# Lưu phản hồi vào Dataset
|
273 |
-
response = "".join(outputs)
|
274 |
-
save_conversation(user_id, [(prompt, response)])
|
275 |
|
276 |
@lru_cache(maxsize=128)
|
277 |
def process_query(query: str) -> Dict[str, Any]:
|
@@ -283,10 +181,9 @@ def process_query(query: str) -> Dict[str, Any]:
|
|
283 |
general_query_keywords = ["giải thích", "mô tả", "nói cho tôi biết về", "cái gì là", "cách nào"]
|
284 |
summarize_keywords = ["tóm tắt", "tóm lại", "khái quát", "ngắn gọn"]
|
285 |
sentiment_keywords = ["cảm xúc", "tâm trạng", "tâm lý", "phân tích cảm xúc"]
|
286 |
-
topic_keywords = ["chủ đề", "bàn về", "về"]
|
287 |
|
288 |
query_lower = query.lower() # Chuyển truy vấn thành chữ thường để so sánh
|
289 |
-
|
290 |
if any(keyword in query_lower for keyword in web_search_keywords):
|
291 |
function_name = "web_search"
|
292 |
arguments = {"query": query}
|
@@ -296,28 +193,25 @@ def process_query(query: str) -> Dict[str, Any]:
|
|
296 |
elif any(keyword in query_lower for keyword in sentiment_keywords):
|
297 |
function_name = "sentiment_analysis"
|
298 |
arguments = {"prompt": query}
|
299 |
-
elif any(keyword in query_lower for keyword in topic_keywords):
|
300 |
-
function_name = "new_topic"
|
301 |
-
arguments = {"topic": query}
|
302 |
elif any(keyword in query_lower for keyword in general_query_keywords):
|
303 |
function_name = "general_query"
|
304 |
arguments = {"prompt": query}
|
305 |
else:
|
306 |
function_name = "hard_query"
|
307 |
arguments = {"prompt": query}
|
308 |
-
|
309 |
return {
|
310 |
"name": function_name,
|
311 |
"arguments": arguments
|
312 |
}
|
313 |
|
314 |
-
def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]],
|
315 |
"""
|
316 |
Thực thi hàm phù hợp dựa trên lời gọi hàm.
|
317 |
"""
|
318 |
function_name = function_call["name"]
|
319 |
arguments = function_call["arguments"]
|
320 |
-
|
321 |
if function_name == "web_search":
|
322 |
query = arguments["query"]
|
323 |
yield "🔍 Đang thực hiện tìm kiếm trên web..."
|
@@ -329,10 +223,10 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
|
|
329 |
web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 **Mô tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."])
|
330 |
if not web_summary:
|
331 |
web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm."
|
332 |
-
|
333 |
# Trả về kết quả tìm kiếm cho người dùng
|
334 |
yield "📄 **Kết quả tìm kiếm:**\n" + web_summary
|
335 |
-
|
336 |
elif function_name == "summarize_query":
|
337 |
# Khi người dùng yêu cầu tóm tắt, hệ thống sẽ thực hiện tìm kiếm và sau đó tóm tắt kết quả
|
338 |
query = arguments["prompt"]
|
@@ -349,22 +243,14 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
|
|
349 |
# Tóm tắt nội dung đã lấy
|
350 |
yield "📝 Đang tóm tắt thông tin..."
|
351 |
summary = summarize_text(combined_text)
|
352 |
-
# Lưu tóm tắt vào Dataset
|
353 |
-
save_conversation(user_id, [("tóm tắt", summary)])
|
354 |
yield "📄 **Tóm tắt:**\n" + summary
|
355 |
-
|
356 |
elif function_name == "sentiment_analysis":
|
357 |
prompt_text = arguments["prompt"]
|
358 |
yield "📊 Đang phân tích tâm lý..."
|
359 |
sentiment = analyze_sentiment(prompt_text)
|
360 |
yield sentiment
|
361 |
-
|
362 |
-
elif function_name == "new_topic":
|
363 |
-
topic = arguments["topic"]
|
364 |
-
# Lưu chủ đề mới vào Dataset
|
365 |
-
save_conversation(user_id, [("chủ đề", f"Chủ đề mới: {topic}")])
|
366 |
-
yield f"🆕 Đã chuyển sang chủ đề mới: {topic}"
|
367 |
-
|
368 |
elif function_name in ["general_query", "hard_query"]:
|
369 |
prompt_text = arguments["prompt"]
|
370 |
yield "🤖 Đang tạo phản hồi..."
|
@@ -372,7 +258,6 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
|
|
372 |
response_generator = generate_response(
|
373 |
prompt=prompt_text,
|
374 |
chat_history=chat_history,
|
375 |
-
user_id=user_id,
|
376 |
max_new_tokens=max_new_tokens,
|
377 |
temperature=temperature,
|
378 |
top_p=top_p,
|
@@ -381,77 +266,57 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
|
|
381 |
)
|
382 |
for response in response_generator:
|
383 |
yield response
|
384 |
-
|
385 |
else:
|
386 |
yield "⚠️ Lời gọi hàm không được nhận dạng."
|
387 |
|
388 |
# ---------------------------- Giao Diện Gradio ---------------------------- #
|
389 |
|
390 |
-
def get_user_id(user_state: gr.State) -> str:
|
391 |
-
"""
|
392 |
-
Tạo hoặc lấy user_id từ trạng thái của Gradio.
|
393 |
-
"""
|
394 |
-
if user_state.value is None:
|
395 |
-
user_state.value = str(uuid.uuid4())
|
396 |
-
return user_state.value
|
397 |
-
|
398 |
@spaces.GPU(duration=15, queue=False)
|
399 |
def generate(
|
400 |
message: str,
|
401 |
chat_history: List[Tuple[str, str]],
|
402 |
-
user_state: gr.State, # Trạng thái để lưu trữ user_id
|
403 |
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
|
404 |
temperature: float = 0.6,
|
405 |
top_p: float = 0.9,
|
406 |
top_k: int = 50,
|
407 |
repetition_penalty: float = 1.2,
|
408 |
-
) -> Iterator[
|
409 |
"""
|
410 |
Hàm chính để xử lý đầu vào của người dùng và tạo phản hồi.
|
411 |
"""
|
412 |
# Thông báo về việc phân tích đầu vào
|
413 |
-
yield
|
414 |
-
|
415 |
-
# Lấy user_id từ trạng thái
|
416 |
-
user_id = get_user_id(user_state)
|
417 |
-
|
418 |
# Xác định hàm nào sẽ đư��c gọi dựa trên tin nhắn của người dùng
|
419 |
function_call = process_query(message)
|
420 |
-
|
421 |
# Thông báo về hàm được chọn
|
422 |
if function_call["name"] == "web_search":
|
423 |
-
yield
|
424 |
elif function_call["name"] == "summarize_query":
|
425 |
-
yield
|
426 |
elif function_call["name"] == "sentiment_analysis":
|
427 |
-
yield
|
428 |
-
elif function_call["name"] == "new_topic":
|
429 |
-
yield [("system", "🛠️ Đã chọn chức năng: Chủ đề mới.")]
|
430 |
elif function_call["name"] in ["general_query", "hard_query"]:
|
431 |
-
yield
|
432 |
else:
|
433 |
-
yield
|
434 |
-
|
435 |
# Xử lý lời gọi hàm và sinh phản hồi tương ứng
|
436 |
response_iterator = handle_functions(
|
437 |
function_call=function_call,
|
438 |
prompt=message,
|
439 |
chat_history=chat_history,
|
440 |
-
user_id=user_id, # Sử dụng user_id để quản lý dữ liệu theo người dùng
|
441 |
max_new_tokens=max_new_tokens,
|
442 |
temperature=temperature,
|
443 |
top_p=top_p,
|
444 |
top_k=top_k,
|
445 |
repetition_penalty=repetition_penalty
|
446 |
)
|
447 |
-
|
448 |
-
# Start with the existing chat history
|
449 |
-
updated_chat_history = chat_history.copy()
|
450 |
-
|
451 |
for response in response_iterator:
|
452 |
-
|
453 |
-
updated_chat_history.append(("assistant", response))
|
454 |
-
yield updated_chat_history
|
455 |
|
456 |
# Định nghĩa các ví dụ để hướng dẫn người dùng
|
457 |
EXAMPLES = [
|
@@ -467,6 +332,54 @@ EXAMPLES = [
|
|
467 |
]
|
468 |
|
469 |
# Cấu hình giao diện trò chuyện của Gradio với giao diện đẹp mắt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
470 |
with gr.Blocks(css="""
|
471 |
.gradio-container {
|
472 |
background-color: #f0f2f5; /* Màu nền nhẹ nhàng */
|
@@ -489,86 +402,8 @@ with gr.Blocks(css="""
|
|
489 |
}
|
490 |
""", fill_height=True) as demo:
|
491 |
gr.Markdown(DESCRIPTION) # Hiển thị mô tả
|
492 |
-
|
493 |
-
#
|
494 |
-
user_state = gr.State(None)
|
495 |
-
|
496 |
-
# Nút nhân bản không gian
|
497 |
-
gr.DuplicateButton(value="Nhân bản Không gian để sử dụng riêng tư", elem_id="duplicate-button")
|
498 |
-
|
499 |
-
# Chat Interface
|
500 |
-
with gr.Row():
|
501 |
-
chatbot = gr.Chatbot()
|
502 |
-
|
503 |
-
with gr.Row():
|
504 |
-
with gr.Column():
|
505 |
-
message = gr.Textbox(
|
506 |
-
label="Bạn:",
|
507 |
-
placeholder="Nhập tin nhắn của bạn tại đây...",
|
508 |
-
)
|
509 |
-
submit = gr.Button("Gửi")
|
510 |
-
with gr.Column():
|
511 |
-
# Các thanh trượt cho tham số
|
512 |
-
max_new_tokens = gr.Slider(
|
513 |
-
label="Số token mới tối đa",
|
514 |
-
minimum=1,
|
515 |
-
maximum=MAX_MAX_NEW_TOKENS,
|
516 |
-
step=1,
|
517 |
-
value=DEFAULT_MAX_NEW_TOKENS,
|
518 |
-
)
|
519 |
-
temperature = gr.Slider(
|
520 |
-
label="Nhiệt độ",
|
521 |
-
minimum=0.1,
|
522 |
-
maximum=4.0,
|
523 |
-
step=0.1,
|
524 |
-
value=0.6,
|
525 |
-
)
|
526 |
-
top_p = gr.Slider(
|
527 |
-
label="Top-p (nucleus sampling)",
|
528 |
-
minimum=0.05,
|
529 |
-
maximum=1.0,
|
530 |
-
step=0.05,
|
531 |
-
value=0.9,
|
532 |
-
)
|
533 |
-
top_k = gr.Slider(
|
534 |
-
label="Top-k",
|
535 |
-
minimum=1,
|
536 |
-
maximum=1000,
|
537 |
-
step=1,
|
538 |
-
value=50,
|
539 |
-
)
|
540 |
-
repetition_penalty = gr.Slider(
|
541 |
-
label="Hình phạt sự lặp lại",
|
542 |
-
minimum=1.0,
|
543 |
-
maximum=2.0,
|
544 |
-
step=0.05,
|
545 |
-
value=1.2,
|
546 |
-
)
|
547 |
-
|
548 |
-
# Kết nối nút gửi với hàm generate
|
549 |
-
submit.click(
|
550 |
-
generate,
|
551 |
-
inputs=[
|
552 |
-
message,
|
553 |
-
chatbot,
|
554 |
-
user_state,
|
555 |
-
max_new_tokens,
|
556 |
-
temperature,
|
557 |
-
top_p,
|
558 |
-
top_k,
|
559 |
-
repetition_penalty,
|
560 |
-
],
|
561 |
-
outputs=chatbot,
|
562 |
-
)
|
563 |
-
|
564 |
-
# Thêm các ví dụ
|
565 |
-
gr.Examples(
|
566 |
-
examples=EXAMPLES,
|
567 |
-
inputs=[message],
|
568 |
-
outputs=[chatbot],
|
569 |
-
fn=lambda x: x, # Function to populate the message box with the example
|
570 |
-
)
|
571 |
|
572 |
-
# Khởi chạy ứng dụng Gradio
|
573 |
if __name__ == "__main__":
|
574 |
-
demo.queue(max_size=20).launch(
|
|
|
1 |
+
|
2 |
import os
|
3 |
from threading import Thread
|
4 |
from typing import Iterator, List, Tuple, Dict, Any
|
|
|
|
|
5 |
|
6 |
import gradio as gr
|
7 |
import spaces
|
|
|
9 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
|
10 |
from bs4 import BeautifulSoup
|
11 |
import requests
|
12 |
+
import json
|
13 |
from functools import lru_cache
|
|
|
14 |
|
15 |
# ---------------------------- Cấu Hình ---------------------------- #
|
16 |
|
|
|
29 |
# Xác định thiết bị sử dụng (GPU nếu có, ngược lại CPU)
|
30 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
31 |
|
32 |
+
model_id = "meta-llama/Llama-3.2-3B-Instruct" # ID mô hình, đảm bảo đây là ID mô hình đúng
|
33 |
tokenizer = AutoTokenizer.from_pretrained(model_id) # Tải tokenizer từ Hugging Face
|
34 |
model = AutoModelForCausalLM.from_pretrained(
|
35 |
model_id,
|
|
|
42 |
# Khởi tạo pipeline phân tích tâm lý
|
43 |
sentiment_pipeline = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
# ---------------------------- Định Nghĩa Hàm ---------------------------- #
|
46 |
|
47 |
@lru_cache(maxsize=128)
|
|
|
104 |
]
|
105 |
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
|
106 |
input_ids = input_ids.to(device)
|
107 |
+
|
108 |
summary_streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
|
109 |
summary_kwargs = {
|
110 |
"input_ids": input_ids,
|
|
|
116 |
}
|
117 |
t = Thread(target=model.generate, kwargs=summary_kwargs)
|
118 |
t.start()
|
119 |
+
|
120 |
summary = ""
|
121 |
for new_text in summary_streamer:
|
122 |
summary += new_text
|
|
|
129 |
score = result[0]['score']
|
130 |
return f"🟢 **Tâm lý**: {sentiment} (Điểm: {score:.2f})"
|
131 |
|
132 |
+
def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
|
133 |
"""
|
134 |
Tạo phản hồi sử dụng mô hình Llama cục bộ theo chế độ streaming.
|
135 |
"""
|
136 |
+
# Xây dựng lịch sử cuộc trò chuyện
|
137 |
+
conversation = []
|
138 |
+
for user, assistant in chat_history:
|
139 |
+
conversation.extend([
|
140 |
+
{"role": "user", "content": user},
|
141 |
+
{"role": "assistant", "content": assistant},
|
|
|
|
|
142 |
])
|
143 |
+
conversation.append({"role": "user", "content": prompt}) # Thêm tin nhắn của người dùng
|
144 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
# Chuẩn bị input_ids từ tokenizer
|
146 |
+
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
|
147 |
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
|
148 |
+
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] # Cắt input nếu quá dài
|
149 |
+
gr.Warning(f"Đã cắt bỏ phần cuộc trò chuyện vì vượt quá {MAX_INPUT_TOKEN_LENGTH} token.")
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
input_ids = input_ids.to(device) # Di chuyển input tới thiết bị
|
151 |
+
|
152 |
# Khởi tạo streamer để nhận văn bản được tạo ra theo thời gian thực
|
153 |
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
|
154 |
generate_kwargs = {
|
|
|
164 |
}
|
165 |
t = Thread(target=model.generate, kwargs=generate_kwargs) # Tạo luồng để sinh văn bản
|
166 |
t.start()
|
167 |
+
|
168 |
# Stream văn bản được tạo ra
|
169 |
outputs = []
|
170 |
for text in streamer:
|
171 |
outputs.append(text)
|
172 |
+
yield "".join(outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
@lru_cache(maxsize=128)
|
175 |
def process_query(query: str) -> Dict[str, Any]:
|
|
|
181 |
general_query_keywords = ["giải thích", "mô tả", "nói cho tôi biết về", "cái gì là", "cách nào"]
|
182 |
summarize_keywords = ["tóm tắt", "tóm lại", "khái quát", "ngắn gọn"]
|
183 |
sentiment_keywords = ["cảm xúc", "tâm trạng", "tâm lý", "phân tích cảm xúc"]
|
|
|
184 |
|
185 |
query_lower = query.lower() # Chuyển truy vấn thành chữ thường để so sánh
|
186 |
+
|
187 |
if any(keyword in query_lower for keyword in web_search_keywords):
|
188 |
function_name = "web_search"
|
189 |
arguments = {"query": query}
|
|
|
193 |
elif any(keyword in query_lower for keyword in sentiment_keywords):
|
194 |
function_name = "sentiment_analysis"
|
195 |
arguments = {"prompt": query}
|
|
|
|
|
|
|
196 |
elif any(keyword in query_lower for keyword in general_query_keywords):
|
197 |
function_name = "general_query"
|
198 |
arguments = {"prompt": query}
|
199 |
else:
|
200 |
function_name = "hard_query"
|
201 |
arguments = {"prompt": query}
|
202 |
+
|
203 |
return {
|
204 |
"name": function_name,
|
205 |
"arguments": arguments
|
206 |
}
|
207 |
|
208 |
+
def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
|
209 |
"""
|
210 |
Thực thi hàm phù hợp dựa trên lời gọi hàm.
|
211 |
"""
|
212 |
function_name = function_call["name"]
|
213 |
arguments = function_call["arguments"]
|
214 |
+
|
215 |
if function_name == "web_search":
|
216 |
query = arguments["query"]
|
217 |
yield "🔍 Đang thực hiện tìm kiếm trên web..."
|
|
|
223 |
web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 **Mô tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."])
|
224 |
if not web_summary:
|
225 |
web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm."
|
226 |
+
|
227 |
# Trả về kết quả tìm kiếm cho người dùng
|
228 |
yield "📄 **Kết quả tìm kiếm:**\n" + web_summary
|
229 |
+
|
230 |
elif function_name == "summarize_query":
|
231 |
# Khi người dùng yêu cầu tóm tắt, hệ thống sẽ thực hiện tìm kiếm và sau đó tóm tắt kết quả
|
232 |
query = arguments["prompt"]
|
|
|
243 |
# Tóm tắt nội dung đã lấy
|
244 |
yield "📝 Đang tóm tắt thông tin..."
|
245 |
summary = summarize_text(combined_text)
|
|
|
|
|
246 |
yield "📄 **Tóm tắt:**\n" + summary
|
247 |
+
|
248 |
elif function_name == "sentiment_analysis":
|
249 |
prompt_text = arguments["prompt"]
|
250 |
yield "📊 Đang phân tích tâm lý..."
|
251 |
sentiment = analyze_sentiment(prompt_text)
|
252 |
yield sentiment
|
253 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
254 |
elif function_name in ["general_query", "hard_query"]:
|
255 |
prompt_text = arguments["prompt"]
|
256 |
yield "🤖 Đang tạo phản hồi..."
|
|
|
258 |
response_generator = generate_response(
|
259 |
prompt=prompt_text,
|
260 |
chat_history=chat_history,
|
|
|
261 |
max_new_tokens=max_new_tokens,
|
262 |
temperature=temperature,
|
263 |
top_p=top_p,
|
|
|
266 |
)
|
267 |
for response in response_generator:
|
268 |
yield response
|
269 |
+
|
270 |
else:
|
271 |
yield "⚠️ Lời gọi hàm không được nhận dạng."
|
272 |
|
273 |
# ---------------------------- Giao Diện Gradio ---------------------------- #
|
274 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
@spaces.GPU(duration=15, queue=False)
|
276 |
def generate(
|
277 |
message: str,
|
278 |
chat_history: List[Tuple[str, str]],
|
|
|
279 |
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
|
280 |
temperature: float = 0.6,
|
281 |
top_p: float = 0.9,
|
282 |
top_k: int = 50,
|
283 |
repetition_penalty: float = 1.2,
|
284 |
+
) -> Iterator[str]:
|
285 |
"""
|
286 |
Hàm chính để xử lý đầu vào của người dùng và tạo phản hồi.
|
287 |
"""
|
288 |
# Thông báo về việc phân tích đầu vào
|
289 |
+
yield "🔍 Đang phân tích truy vấn của bạn..."
|
290 |
+
|
|
|
|
|
|
|
291 |
# Xác định hàm nào sẽ đư��c gọi dựa trên tin nhắn của người dùng
|
292 |
function_call = process_query(message)
|
293 |
+
|
294 |
# Thông báo về hàm được chọn
|
295 |
if function_call["name"] == "web_search":
|
296 |
+
yield "🛠️ Đã chọn chức năng: Tìm kiếm trên web."
|
297 |
elif function_call["name"] == "summarize_query":
|
298 |
+
yield "🛠️ Đã chọn chức năng: Tóm tắt văn bản."
|
299 |
elif function_call["name"] == "sentiment_analysis":
|
300 |
+
yield "🛠️ Đã chọn chức năng: Phân tích tâm lý."
|
|
|
|
|
301 |
elif function_call["name"] in ["general_query", "hard_query"]:
|
302 |
+
yield "🛠️ Đã chọn chức năng: Trả lời câu hỏi."
|
303 |
else:
|
304 |
+
yield "⚠️ Không thể xác định chức năng phù hợp."
|
305 |
+
|
306 |
# Xử lý lời gọi hàm và sinh phản hồi tương ứng
|
307 |
response_iterator = handle_functions(
|
308 |
function_call=function_call,
|
309 |
prompt=message,
|
310 |
chat_history=chat_history,
|
|
|
311 |
max_new_tokens=max_new_tokens,
|
312 |
temperature=temperature,
|
313 |
top_p=top_p,
|
314 |
top_k=top_k,
|
315 |
repetition_penalty=repetition_penalty
|
316 |
)
|
317 |
+
|
|
|
|
|
|
|
318 |
for response in response_iterator:
|
319 |
+
yield response
|
|
|
|
|
320 |
|
321 |
# Định nghĩa các ví dụ để hướng dẫn người dùng
|
322 |
EXAMPLES = [
|
|
|
332 |
]
|
333 |
|
334 |
# Cấu hình giao diện trò chuyện của Gradio với giao diện đẹp mắt
|
335 |
+
chat_interface = gr.ChatInterface(
|
336 |
+
fn=generate, # Hàm được gọi khi có tương tác từ người dùng
|
337 |
+
additional_inputs=[
|
338 |
+
gr.Slider(
|
339 |
+
label="Số token mới tối đa",
|
340 |
+
minimum=1,
|
341 |
+
maximum=MAX_MAX_NEW_TOKENS,
|
342 |
+
step=1,
|
343 |
+
value=DEFAULT_MAX_NEW_TOKENS,
|
344 |
+
),
|
345 |
+
gr.Slider(
|
346 |
+
label="Nhiệt độ",
|
347 |
+
minimum=0.1,
|
348 |
+
maximum=4.0,
|
349 |
+
step=0.1,
|
350 |
+
value=0.6,
|
351 |
+
),
|
352 |
+
gr.Slider(
|
353 |
+
label="Top-p (nucleus sampling)",
|
354 |
+
minimum=0.05,
|
355 |
+
maximum=1.0,
|
356 |
+
step=0.05,
|
357 |
+
value=0.9,
|
358 |
+
),
|
359 |
+
gr.Slider(
|
360 |
+
label="Top-k",
|
361 |
+
minimum=1,
|
362 |
+
maximum=1000,
|
363 |
+
step=1,
|
364 |
+
value=50,
|
365 |
+
),
|
366 |
+
gr.Slider(
|
367 |
+
label="Hình phạt sự lặp lại",
|
368 |
+
minimum=1.0,
|
369 |
+
maximum=2.0,
|
370 |
+
step=0.05,
|
371 |
+
value=1.2,
|
372 |
+
),
|
373 |
+
],
|
374 |
+
stop_btn=None, # Không có nút dừng
|
375 |
+
examples=EXAMPLES, # Các ví dụ được hiển thị cho người dùng
|
376 |
+
cache_examples=False, # Không lưu bộ nhớ cache cho các ví dụ
|
377 |
+
title="🤖 OpenGPT-4o Chatbot",
|
378 |
+
description="Một trợ lý AI mạnh mẽ sử dụng mô hình Llama-3.2 cục bộ với các chức năng tìm kiếm web, tóm tắt văn bản và phân tích tâm lý.",
|
379 |
+
theme="default", # Có thể thay đổi theme để giao diện đẹp hơn
|
380 |
+
)
|
381 |
+
|
382 |
+
# Tạo giao diện chính của Gradio với CSS tùy chỉnh
|
383 |
with gr.Blocks(css="""
|
384 |
.gradio-container {
|
385 |
background-color: #f0f2f5; /* Màu nền nhẹ nhàng */
|
|
|
402 |
}
|
403 |
""", fill_height=True) as demo:
|
404 |
gr.Markdown(DESCRIPTION) # Hiển thị mô tả
|
405 |
+
gr.DuplicateButton(value="Nhân bản Không gian để sử dụng riêng tư", elem_id="duplicate-button") # Nút nhân bản không gian
|
406 |
+
chat_interface.render() # Hiển thị giao diện trò chuyện
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
407 |
|
|
|
408 |
if __name__ == "__main__":
|
409 |
+
demo.queue(max_size=20).launch() # Khởi chạy ứng dụng Gradio với hàng đợi kích thước tối đa là 20
|