Spaces:
Sleeping
Sleeping
import os | |
from threading import Thread | |
from typing import Iterator, List, Tuple, Dict, Any | |
import gradio as gr | |
import spaces | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline | |
from bs4 import BeautifulSoup | |
import requests | |
import json | |
from functools import lru_cache | |
from checkpoint import continuous_training | |
# ---------------------------- Cấu Hình ---------------------------- # | |
DESCRIPTION = """\ | |
# Llama 3.2 3B Instruct với Chức Năng Nâng Cao | |
Llama 3.2 3B là phiên bản mới nhất của Meta về các mô hình ngôn ngữ mở. | |
Demo này giới thiệu [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), được tinh chỉnh để theo dõi hướng dẫn. | |
Để biết thêm chi tiết, vui lòng xem [bài viết của chúng tôi](https://huggingface.co/blog/llama32). | |
""" | |
MAX_MAX_NEW_TOKENS = 2048 # Số token tối đa có thể tạo ra | |
DEFAULT_MAX_NEW_TOKENS = 1024 # Số token tạo ra mặc định | |
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "128000")) # Độ dài token tối đa cho đầu vào | |
# Xác định thiết bị sử dụng (GPU nếu có, ngược lại CPU) | |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
model_id = "meta-llama/Llama-3.2-3B-Instruct" # ID mô hình, đảm bảo đây là ID mô hình đúng | |
tokenizer = AutoTokenizer.from_pretrained(model_id) # Tải tokenizer từ Hugging Face | |
model = AutoModelForCausalLM.from_pretrained( | |
model_id, | |
device_map="auto", | |
torch_dtype=torch.bfloat16, # Sử dụng dtype phù hợp để tiết kiệm bộ nhớ | |
) | |
model.to(device) # Di chuyển mô hình tới thiết bị đã chọn | |
model.eval() # Đặt mô hình ở chế độ đánh giá | |
# Khởi tạo pipeline phân tích tâm lý | |
sentiment_pipeline = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment") | |
# ---------------------------- Định Nghĩa Hàm ---------------------------- # | |
def extract_text_from_webpage(html_content: str) -> str: | |
"""Trích xuất văn bản hiển thị từ nội dung HTML sử dụng BeautifulSoup.""" | |
soup = BeautifulSoup(html_content, "html.parser") | |
# Loại bỏ các thẻ không hiển thị như script, style, header, footer, nav, form, svg | |
for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]): | |
tag.extract() | |
# Trích xuất văn bản hiển thị, tách bằng dấu cách và loại bỏ khoảng trắng thừa | |
visible_text = soup.get_text(separator=' ', strip=True) | |
return visible_text | |
def search(query: str) -> List[Dict[str, Any]]: | |
"""Thực hiện tìm kiếm trên Google và trả về kết quả.""" | |
term = query | |
all_results = [] | |
max_chars_per_page = 8000 # Số ký tự tối đa mỗi trang | |
headers = { | |
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0" | |
} | |
with requests.Session() as session: | |
try: | |
resp = session.get( | |
url="https://www.google.com/search", | |
headers=headers, | |
params={"q": term, "num": 4}, # Tìm kiếm với 4 kết quả mỗi trang | |
timeout=5, | |
verify=False, # Bỏ qua xác minh SSL | |
) | |
resp.raise_for_status() # Kiểm tra phản hồi HTTP | |
soup = BeautifulSoup(resp.text, "html.parser") | |
result_blocks = soup.find_all("div", attrs={"class": "g"}) # Tìm tất cả các khối kết quả | |
for result in result_blocks: | |
link_tag = result.find("a", href=True) # Tìm thẻ liên kết | |
if link_tag and 'href' in link_tag.attrs: | |
link = link_tag["href"] | |
try: | |
webpage = session.get( | |
link, | |
headers=headers, | |
timeout=5, | |
verify=False | |
) | |
webpage.raise_for_status() | |
visible_text = extract_text_from_webpage(webpage.text) | |
if len(visible_text) > max_chars_per_page: | |
visible_text = visible_text[:max_chars_per_page] # Cắt văn bản nếu quá dài | |
all_results.append({"link": link, "text": visible_text}) | |
except requests.exceptions.RequestException: | |
all_results.append({"link": link, "text": "Không thể lấy nội dung."}) | |
except requests.exceptions.RequestException as e: | |
all_results.append({"link": "N/A", "text": "Không thể thực hiện tìm kiếm."}) | |
return all_results | |
def summarize_text(text: str, max_length: int = 150) -> str: | |
"""Tóm tắt văn bản sử dụng mô hình Llama.""" | |
conversation = [ | |
{"role": "user", "content": f"Hãy tóm tắt đoạn văn sau: {text}"} | |
] | |
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt") | |
input_ids = input_ids.to(device) | |
summary_streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True) | |
summary_kwargs = { | |
"input_ids": input_ids, | |
"streamer": summary_streamer, | |
"max_new_tokens": max_length, | |
"do_sample": True, | |
"top_p": 0.95, | |
"temperature": 0.7, | |
} | |
t = Thread(target=model.generate, kwargs=summary_kwargs) | |
t.start() | |
summary = "" | |
for new_text in summary_streamer: | |
summary += new_text | |
return summary | |
def analyze_sentiment(text: str) -> str: | |
"""Phân tích tâm lý của văn bản sử dụng mô hình.""" | |
result = sentiment_pipeline(text) | |
sentiment = result[0]['label'] | |
score = result[0]['score'] | |
return f"🟢 **Tâm lý**: {sentiment} (Điểm: {score:.2f})" | |
def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]: | |
""" | |
Tạo phản hồi sử dụng mô hình Llama cục bộ theo chế độ streaming. | |
""" | |
# Xây dựng lịch sử cuộc trò chuyện | |
conversation = [] | |
for user, assistant in chat_history: | |
conversation.extend([ | |
{"role": "user", "content": user}, | |
{"role": "assistant", "content": assistant}, | |
]) | |
conversation.append({"role": "user", "content": prompt}) # Thêm tin nhắn của người dùng | |
# Chuẩn bị input_ids từ tokenizer | |
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt") | |
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH: | |
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] # Cắt input nếu quá dài | |
gr.Warning(f"Đã cắt bỏ phần cuộc trò chuyện vì vượt quá {MAX_INPUT_TOKEN_LENGTH} token.") | |
input_ids = input_ids.to(device) # Di chuyển input tới thiết bị | |
# Khởi tạo streamer để nhận văn bản được tạo ra theo thời gian thực | |
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True) | |
generate_kwargs = { | |
"input_ids": input_ids, | |
"streamer": streamer, | |
"max_new_tokens": max_new_tokens, | |
"do_sample": True, | |
"top_p": top_p, | |
"top_k": top_k, | |
"temperature": temperature, | |
"num_beams": 1, | |
"repetition_penalty": repetition_penalty, | |
} | |
t = Thread(target=model.generate, kwargs=generate_kwargs) # Tạo luồng để sinh văn bản | |
t.start() | |
# Stream văn bản được tạo ra | |
outputs = [] | |
for text in streamer: | |
outputs.append(text) | |
yield "".join(outputs) | |
def process_query(query: str) -> Dict[str, Any]: | |
""" | |
Xác định hàm nào sẽ được gọi dựa trên truy vấn của người dùng. | |
""" | |
# Định nghĩa các từ khóa hoặc mẫu để xác định hàm | |
web_search_keywords = ["tìm kiếm", "tìm", "tra cứu", "google", "lookup"] | |
general_query_keywords = ["giải thích", "mô tả", "nói cho tôi biết về", "cái gì là", "cách nào"] | |
summarize_keywords = ["tóm tắt", "tóm lại", "khái quát", "ngắn gọn"] | |
sentiment_keywords = ["cảm xúc", "tâm trạng", "tâm lý", "phân tích cảm xúc"] | |
query_lower = query.lower() # Chuyển truy vấn thành chữ thường để so sánh | |
if any(keyword in query_lower for keyword in web_search_keywords): | |
function_name = "web_search" | |
arguments = {"query": query} | |
elif any(keyword in query_lower for keyword in summarize_keywords): | |
function_name = "summarize_query" | |
arguments = {"prompt": query} | |
elif any(keyword in query_lower for keyword in sentiment_keywords): | |
function_name = "sentiment_analysis" | |
arguments = {"prompt": query} | |
elif any(keyword in query_lower for keyword in general_query_keywords): | |
function_name = "general_query" | |
arguments = {"prompt": query} | |
else: | |
function_name = "hard_query" | |
arguments = {"prompt": query} | |
return { | |
"name": function_name, | |
"arguments": arguments | |
} | |
def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]: | |
""" | |
Thực thi hàm phù hợp dựa trên lời gọi hàm. | |
""" | |
function_name = function_call["name"] | |
arguments = function_call["arguments"] | |
if function_name == "web_search": | |
query = arguments["query"] | |
yield "🔍 Đang thực hiện tìm kiếm trên web..." | |
web_results = search(query) | |
if not web_results: | |
yield "⚠️ Không tìm thấy kết quả." | |
return | |
# Tóm tắt kết quả tìm kiếm | |
web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 **Mô tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."]) | |
if not web_summary: | |
web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm." | |
# Trả về kết quả tìm kiếm cho người dùng | |
yield "📄 **Kết quả tìm kiếm:**\n" + web_summary | |
elif function_name == "summarize_query": | |
# Khi người dùng yêu cầu tóm tắt, hệ thống sẽ thực hiện tìm kiếm và sau đó tóm tắt kết quả | |
query = arguments["prompt"] | |
yield "🔍 Đang thực hiện tìm kiếm để tóm tắt..." | |
web_results = search(query) | |
if not web_results: | |
yield "⚠️ Không tìm thấy kết quả để tóm tắt." | |
return | |
# Lấy nội dung từ kết quả tìm kiếm để tóm tắt | |
combined_text = ' '.join([res['text'] for res in web_results if res['text'] != "Không thể lấy nội dung."]) | |
if not combined_text: | |
yield "⚠️ Không có nội dung để tóm tắt." | |
return | |
# Tóm tắt nội dung đã lấy | |
yield "📝 Đang tóm tắt thông tin..." | |
summary = summarize_text(combined_text) | |
yield "📄 **Tóm tắt:**\n" + summary | |
elif function_name == "sentiment_analysis": | |
prompt_text = arguments["prompt"] | |
yield "📊 Đang phân tích tâm lý..." | |
sentiment = analyze_sentiment(prompt_text) | |
yield sentiment | |
elif function_name in ["general_query", "hard_query"]: | |
prompt_text = arguments["prompt"] | |
yield "🤖 Đang tạo phản hồi..." | |
# Tạo phản hồi sử dụng mô hình Llama | |
response_generator = generate_response( | |
prompt=prompt_text, | |
chat_history=chat_history, | |
max_new_tokens=max_new_tokens, | |
temperature=temperature, | |
top_p=top_p, | |
top_k=top_k, | |
repetition_penalty=repetition_penalty | |
) | |
for response in response_generator: | |
yield response | |
else: | |
yield "⚠️ Lời gọi hàm không được nhận dạng." | |
# ---------------------------- Giao Diện Gradio ---------------------------- # | |
def generate( | |
message: str, | |
chat_history: List[Tuple[str, str]], | |
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS, | |
temperature: float = 0.6, | |
top_p: float = 0.9, | |
top_k: int = 50, | |
repetition_penalty: float = 1.2, | |
) -> Iterator[str]: | |
""" | |
Hàm chính để xử lý đầu vào của người dùng và tạo phản hồi. | |
""" | |
# Thông báo về việc phân tích đầu vào | |
yield "🔍 Đang phân tích truy vấn của bạn..." | |
# Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng | |
function_call = process_query(message) | |
# Thông báo về hàm được chọn | |
if function_call["name"] == "web_search": | |
yield "🛠️ Đã chọn chức năng: Tìm kiếm trên web." | |
elif function_call["name"] == "summarize_query": | |
yield "🛠️ Đã chọn chức năng: Tóm tắt văn bản." | |
elif function_call["name"] == "sentiment_analysis": | |
continuous_training(total_steps=300, steps_per_call=50) | |
yield "🛠️ Đã chọn chức năng: Phân tích tâm lý." | |
elif function_call["name"] in ["general_query", "hard_query"]: | |
yield "🛠️ Đã chọn chức năng: Trả lời câu hỏi." | |
else: | |
yield "⚠️ Không thể xác định chức năng phù hợp." | |
# Xử lý lời gọi hàm và sinh phản hồi tương ứng | |
response_iterator = handle_functions( | |
function_call=function_call, | |
prompt=message, | |
chat_history=chat_history, | |
max_new_tokens=max_new_tokens, | |
temperature=temperature, | |
top_p=top_p, | |
top_k=top_k, | |
repetition_penalty=repetition_penalty | |
) | |
for response in response_iterator: | |
yield response | |
# Định nghĩa các ví dụ để hướng dẫn người dùng | |
EXAMPLES = [ | |
["Xin chào! Bạn khỏe không?"], | |
["Bạn có thể giải thích ngắn gọn về ngôn ngữ lập trình Python không?"], | |
["Giải thích cốt truyện của Cô bé Lọ Lem trong một câu."], | |
["Một người đàn ông cần bao nhiêu giờ để ăn một chiếc máy bay trực thăng?"], | |
["Viết một bài báo 100 từ về 'Lợi ích của mã nguồn mở trong nghiên cứu AI'"], | |
["Tìm và cung cấp cho tôi tin tức mới nhất về năng lượng tái tạo."], | |
["Tìm thông tin về Rạn san hô Great Barrier Reef."], | |
["Tóm tắt nội dung về trí tuệ nhân tạo."], | |
["Phân tích tâm lý của đoạn văn sau: Tôi rất vui khi được gặp bạn hôm nay!"], | |
] | |
# Cấu hình giao diện trò chuyện của Gradio với giao diện đẹp mắt | |
chat_interface = gr.ChatInterface( | |
fn=generate, # Hàm được gọi khi có tương tác từ người dùng | |
additional_inputs=[ | |
gr.Slider( | |
label="Số token mới tối đa", | |
minimum=1, | |
maximum=MAX_MAX_NEW_TOKENS, | |
step=1, | |
value=DEFAULT_MAX_NEW_TOKENS, | |
), | |
gr.Slider( | |
label="Nhiệt độ", | |
minimum=0.1, | |
maximum=4.0, | |
step=0.1, | |
value=0.6, | |
), | |
gr.Slider( | |
label="Top-p (nucleus sampling)", | |
minimum=0.05, | |
maximum=1.0, | |
step=0.05, | |
value=0.9, | |
), | |
gr.Slider( | |
label="Top-k", | |
minimum=1, | |
maximum=1000, | |
step=1, | |
value=50, | |
), | |
gr.Slider( | |
label="Hình phạt sự lặp lại", | |
minimum=1.0, | |
maximum=2.0, | |
step=0.05, | |
value=1.2, | |
), | |
], | |
stop_btn=None, # Không có nút dừng | |
examples=EXAMPLES, # Các ví dụ được hiển thị cho người dùng | |
cache_examples=False, # Không lưu bộ nhớ cache cho các ví dụ | |
title="🤖 OpenGPT-4o Chatbot", | |
description="Một trợ lý AI mạnh mẽ sử dụng mô hình Llama-3.2 cục bộ với các chức năng tìm kiếm web, tóm tắt văn bản và phân tích tâm lý.", | |
theme="default", # Có thể thay đổi theme để giao diện đẹp hơn | |
) | |
# Tạo giao diện chính của Gradio với CSS tùy chỉnh | |
with gr.Blocks(css=""" | |
.gradio-container { | |
background-color: #f0f2f5; /* Màu nền nhẹ nhàng */ | |
} | |
.gradio-container h1 { | |
color: #4a90e2; /* Màu xanh dương cho tiêu đề */ | |
} | |
.gradio-container .gr-button { | |
background-color: #4a90e2; /* Màu xanh dương cho nút */ | |
color: white; /* Màu chữ trắng trên nút */ | |
} | |
.gradio-container .gr-slider__label { | |
color: #333333; /* Màu chữ đen cho nhãn slider */ | |
} | |
.gradio-container .gr-chatbot { | |
border: 2px solid #4a90e2; /* Viền xanh dương cho chatbot */ | |
border-radius: 10px; /* Bo góc viền chatbot */ | |
padding: 10px; /* Khoảng cách bên trong chatbot */ | |
background-color: #ffffff; /* Màu nền trắng cho chatbot */ | |
} | |
""", fill_height=True) as demo: | |
gr.Markdown(DESCRIPTION) # Hiển thị mô tả | |
gr.DuplicateButton(value="Nhân bản Không gian để sử dụng riêng tư", elem_id="duplicate-button") # Nút nhân bản không gian | |
chat_interface.render() # Hiển thị giao diện trò chuyện | |
if __name__ == "__main__": | |
demo.queue(max_size=20).launch() # Khởi chạy ứng dụng Gradio với hàng đợi kích thước tối đa là 20 |