Spaces:
Sleeping
Sleeping
import os | |
from threading import Thread | |
from typing import Iterator, List, Tuple, Dict, Any | |
import gradio as gr | |
import spaces | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
from bs4 import BeautifulSoup | |
import requests | |
import json | |
from functools import lru_cache | |
# ---------------------------- Cấu Hình ---------------------------- # | |
DESCRIPTION = """\ | |
# Llama 3.2 3B Instruct với Chức Năng Nâng Cao | |
Llama 3.2 3B là phiên bản mới nhất của Meta về các mô hình ngôn ngữ mở. | |
Demo này giới thiệu [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), được tinh chỉnh để theo dõi hướng dẫn. | |
Để biết thêm chi tiết, vui lòng xem [bài viết của chúng tôi](https://huggingface.co/blog/llama32). | |
""" | |
MAX_MAX_NEW_TOKENS = 2048 # Số token tối đa có thể tạo ra | |
DEFAULT_MAX_NEW_TOKENS = 1024 # Số token tạo ra mặc định | |
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096")) # Độ dài token tối đa cho đầu vào | |
# Xác định thiết bị sử dụng (GPU nếu có, ngược lại CPU) | |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
model_id = "nltpt/Llama-3.2-3B-Instruct" # ID mô hình, đảm bảo đây là ID mô hình đúng | |
tokenizer = AutoTokenizer.from_pretrained(model_id) # Tải tokenizer từ Hugging Face | |
model = AutoModelForCausalLM.from_pretrained( | |
model_id, | |
device_map="auto", | |
torch_dtype=torch.bfloat16, # Sử dụng dtype phù hợp để tiết kiệm bộ nhớ | |
) | |
model.to(device) # Di chuyển mô hình tới thiết bị đã chọn | |
model.eval() # Đặt mô hình ở chế độ đánh giá | |
# ---------------------------- Định Nghĩa Hàm ---------------------------- # | |
def extract_text_from_webpage(html_content: str) -> str: | |
"""Trích xuất văn bản hiển thị từ nội dung HTML sử dụng BeautifulSoup.""" | |
soup = BeautifulSoup(html_content, "html.parser") | |
# Loại bỏ các thẻ không hiển thị như script, style, header, footer, nav, form, svg | |
for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]): | |
tag.extract() | |
# Trích xuất văn bản hiển thị, tách bằng dấu cách và loại bỏ khoảng trắng thừa | |
visible_text = soup.get_text(separator=' ', strip=True) | |
return visible_text | |
def search(query: str) -> List[Dict[str, Any]]: | |
"""Thực hiện tìm kiếm trên Google và trả về kết quả.""" | |
term = query | |
all_results = [] | |
max_chars_per_page = 8000 # Số ký tự tối đa mỗi trang | |
headers = { | |
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0" | |
} | |
with requests.Session() as session: | |
try: | |
resp = session.get( | |
url="https://www.google.com/search", | |
headers=headers, | |
params={"q": term, "num": 4}, # Tìm kiếm với 4 kết quả mỗi trang | |
timeout=5, | |
verify=False, # Bỏ qua xác minh SSL | |
) | |
resp.raise_for_status() # Kiểm tra phản hồi HTTP | |
soup = BeautifulSoup(resp.text, "html.parser") | |
result_blocks = soup.find_all("div", attrs={"class": "g"}) # Tìm tất cả các khối kết quả | |
for result in result_blocks: | |
link_tag = result.find("a", href=True) # Tìm thẻ liên kết | |
if link_tag and 'href' in link_tag.attrs: | |
link = link_tag["href"] | |
try: | |
webpage = session.get( | |
link, | |
headers=headers, | |
timeout=5, | |
verify=False | |
) | |
webpage.raise_for_status() | |
visible_text = extract_text_from_webpage(webpage.text) | |
if len(visible_text) > max_chars_per_page: | |
visible_text = visible_text[:max_chars_per_page] # Cắt văn bản nếu quá dài | |
all_results.append({"link": link, "text": visible_text}) | |
except requests.exceptions.RequestException: | |
all_results.append({"link": link, "text": "Không thể lấy nội dung."}) | |
except requests.exceptions.RequestException as e: | |
all_results.append({"link": "N/A", "text": "Không thể thực hiện tìm kiếm."}) | |
return all_results | |
def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]: | |
""" | |
Tạo phản hồi sử dụng mô hình Llama cục bộ theo chế độ streaming. | |
""" | |
# Xây dựng lịch sử cuộc trò chuyện | |
conversation = [] | |
for user, assistant in chat_history: | |
conversation.extend([ | |
{"role": "user", "content": user}, | |
{"role": "assistant", "content": assistant}, | |
]) | |
conversation.append({"role": "user", "content": prompt}) # Thêm tin nhắn của người dùng | |
# Chuẩn bị input_ids từ tokenizer | |
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt") | |
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH: | |
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] # Cắt input nếu quá dài | |
gr.Warning(f"Đã cắt bỏ phần cuộc trò chuyện vì vượt quá {MAX_INPUT_TOKEN_LENGTH} token.") | |
input_ids = input_ids.to(device) # Di chuyển input tới thiết bị | |
# Khởi tạo streamer để nhận văn bản được tạo ra theo thời gian thực | |
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True) | |
generate_kwargs = { | |
"input_ids": input_ids, | |
"streamer": streamer, | |
"max_new_tokens": max_new_tokens, | |
"do_sample": True, | |
"top_p": top_p, | |
"top_k": top_k, | |
"temperature": temperature, | |
"num_beams": 1, | |
"repetition_penalty": repetition_penalty, | |
} | |
t = Thread(target=model.generate, kwargs=generate_kwargs) # Tạo luồng để sinh văn bản | |
t.start() | |
# Stream văn bản được tạo ra | |
outputs = [] | |
for text in streamer: | |
outputs.append(text) | |
yield "".join(outputs) | |
def process_query(query: str) -> Dict[str, Any]: | |
""" | |
Xác định hàm nào sẽ được gọi dựa trên truy vấn của người dùng. | |
""" | |
# Định nghĩa các từ khóa hoặc mẫu để xác định hàm | |
web_search_keywords = ["tìm kiếm", "tìm", "tra cứu", "google", "lookup"] | |
general_query_keywords = ["giải thích", "mô tả", "nói cho tôi biết về", "cái gì là", "cách nào"] | |
# Bất kỳ truy vấn nào khác sẽ được xử lý như hard_query | |
query_lower = query.lower() # Chuyển truy vấn thành chữ thường để so sánh | |
if any(keyword in query_lower for keyword in web_search_keywords): | |
function_name = "web_search" | |
arguments = {"query": query} | |
elif any(keyword in query_lower for keyword in general_query_keywords): | |
function_name = "general_query" | |
arguments = {"prompt": query} | |
else: | |
function_name = "hard_query" | |
arguments = {"prompt": query} | |
return { | |
"name": function_name, | |
"arguments": arguments | |
} | |
def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]: | |
""" | |
Thực thi hàm phù hợp dựa trên lời gọi hàm. | |
""" | |
function_name = function_call["name"] | |
arguments = function_call["arguments"] | |
if function_name == "web_search": | |
query = arguments["query"] | |
yield "🔍 Đang thực hiện tìm kiếm trên web..." | |
web_results = search(query) | |
if not web_results: | |
yield "⚠️ Không tìm thấy kết quả." | |
return | |
# Tóm tắt kết quả tìm kiếm | |
web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 **Mô tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."]) | |
if not web_summary: | |
web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm." | |
# Trả về kết quả tìm kiếm cho người dùng | |
yield "📄 **Kết quả tìm kiếm:**\n" + web_summary | |
elif function_name in ["general_query", "hard_query"]: | |
prompt_text = arguments["prompt"] | |
yield "🤖 Đang tạo phản hồi..." | |
# Tạo phản hồi sử dụng mô hình Llama | |
response_generator = generate_response( | |
prompt=prompt_text, | |
chat_history=chat_history, | |
max_new_tokens=max_new_tokens, | |
temperature=temperature, | |
top_p=top_p, | |
top_k=top_k, | |
repetition_penalty=repetition_penalty | |
) | |
for response in response_generator: | |
yield response | |
else: | |
yield "⚠️ Lời gọi hàm không được nhận dạng." | |
# ---------------------------- Giao Diện Gradio ---------------------------- # | |
def generate( | |
message: str, | |
chat_history: List[Tuple[str, str]], | |
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS, | |
temperature: float = 0.6, | |
top_p: float = 0.9, | |
top_k: int = 50, | |
repetition_penalty: float = 1.2, | |
) -> Iterator[str]: | |
""" | |
Hàm chính để xử lý đầu vào của người dùng và tạo phản hồi. | |
""" | |
# Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng | |
function_call = process_query(message) | |
# Xử lý lời gọi hàm và sinh phản hồi tương ứng | |
response_iterator = handle_functions( | |
function_call=function_call, | |
prompt=message, | |
chat_history=chat_history, | |
max_new_tokens=max_new_tokens, | |
temperature=temperature, | |
top_p=top_p, | |
top_k=top_k, | |
repetition_penalty=repetition_penalty | |
) | |
for response in response_iterator: | |
yield response | |
# Định nghĩa các ví dụ để hướng dẫn người dùng | |
EXAMPLES = [ | |
["Xin chào! Bạn khỏe không?"], | |
["Bạn có thể giải thích ngắn gọn về ngôn ngữ lập trình Python không?"], | |
["Giải thích cốt truyện của Cô bé Lọ Lem trong một câu."], | |
["Một người đàn ông cần bao nhiêu giờ để ăn một chiếc máy bay trực thăng?"], | |
["Viết một bài báo 100 từ về 'Lợi ích của mã nguồn mở trong nghiên cứu AI'"], | |
["Tìm và cung cấp cho tôi tin tức mới nhất về năng lượng tái tạo."], | |
["Tìm thông tin về Rạn san hô Great Barrier Reef."], | |
] | |
# Cấu hình giao diện trò chuyện của Gradio | |
chat_interface = gr.ChatInterface( | |
fn=generate, # Hàm được gọi khi có tương tác từ người dùng | |
additional_inputs=[ | |
gr.Slider( | |
label="Số token mới tối đa", | |
minimum=1, | |
maximum=MAX_MAX_NEW_TOKENS, | |
step=1, | |
value=DEFAULT_MAX_NEW_TOKENS, | |
), | |
gr.Slider( | |
label="Nhiệt độ", | |
minimum=0.1, | |
maximum=4.0, | |
step=0.1, | |
value=0.6, | |
), | |
gr.Slider( | |
label="Top-p (nucleus sampling)", | |
minimum=0.05, | |
maximum=1.0, | |
step=0.05, | |
value=0.9, | |
), | |
gr.Slider( | |
label="Top-k", | |
minimum=1, | |
maximum=1000, | |
step=1, | |
value=50, | |
), | |
gr.Slider( | |
label="Hình phạt sự lặp lại", | |
minimum=1.0, | |
maximum=2.0, | |
step=0.05, | |
value=1.2, | |
), | |
], | |
stop_btn=None, # Không có nút dừng | |
examples=EXAMPLES, # Các ví dụ được hiển thị cho người dùng | |
cache_examples=False, # Không lưu bộ nhớ cache cho các ví dụ | |
) | |
# Tạo giao diện chính của Gradio | |
with gr.Blocks(css="style.css", fill_height=True) as demo: | |
gr.Markdown(DESCRIPTION) # Hiển thị mô tả | |
gr.DuplicateButton(value="Nhân bản Không gian để sử dụng riêng tư", elem_id="duplicate-button") # Nút nhân bản không gian | |
chat_interface.render() # Hiển thị giao diện trò chuyện | |
if __name__ == "__main__": | |
demo.queue(max_size=20).launch() # Khởi chạy ứng dụng Gradio với hàng đợi kích thước tối đa là 20 | |