Spaces:
Sleeping
Sleeping
hoduyquocbao
commited on
Commit
•
a9907f7
1
Parent(s):
bf64382
new version update
Browse files
app.py
CHANGED
@@ -1,40 +1,71 @@
|
|
|
|
1 |
import os
|
|
|
2 |
from threading import Thread
|
3 |
-
from typing import Iterator
|
4 |
|
|
|
5 |
import gradio as gr
|
6 |
import spaces
|
7 |
import torch
|
8 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
9 |
|
|
|
10 |
DESCRIPTION = """\
|
11 |
-
# Llama 3.2 3B Instruct
|
12 |
|
13 |
-
Llama 3.2 3B
|
14 |
-
|
15 |
-
|
16 |
"""
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
21 |
|
|
|
22 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
23 |
|
|
|
24 |
model_id = "nltpt/Llama-3.2-3B-Instruct"
|
25 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
26 |
model = AutoModelForCausalLM.from_pretrained(
|
27 |
model_id,
|
28 |
-
device_map="auto",
|
29 |
-
torch_dtype=torch.bfloat16,
|
30 |
)
|
31 |
-
model.eval()
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
def generate(
|
36 |
message: str,
|
37 |
-
chat_history:
|
38 |
max_new_tokens: int = 1024,
|
39 |
temperature: float = 0.6,
|
40 |
top_p: float = 0.9,
|
@@ -42,6 +73,8 @@ def generate(
|
|
42 |
repetition_penalty: float = 1.2,
|
43 |
) -> Iterator[str]:
|
44 |
conversation = []
|
|
|
|
|
45 |
for user, assistant in chat_history:
|
46 |
conversation.extend(
|
47 |
[
|
@@ -49,15 +82,24 @@ def generate(
|
|
49 |
{"role": "assistant", "content": assistant},
|
50 |
]
|
51 |
)
|
|
|
52 |
conversation.append({"role": "user", "content": message})
|
53 |
|
|
|
54 |
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
|
|
|
|
|
55 |
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
|
56 |
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
|
57 |
-
gr.Warning(f"
|
|
|
|
|
58 |
input_ids = input_ids.to(model.device)
|
59 |
|
|
|
60 |
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
|
|
|
|
|
61 |
generate_kwargs = dict(
|
62 |
{"input_ids": input_ids},
|
63 |
streamer=streamer,
|
@@ -69,27 +111,57 @@ def generate(
|
|
69 |
num_beams=1,
|
70 |
repetition_penalty=repetition_penalty,
|
71 |
)
|
|
|
|
|
72 |
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
73 |
t.start()
|
74 |
|
|
|
75 |
outputs = []
|
|
|
76 |
for text in streamer:
|
77 |
outputs.append(text)
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
|
|
81 |
chat_interface = gr.ChatInterface(
|
82 |
fn=generate,
|
83 |
additional_inputs=[
|
84 |
gr.Slider(
|
85 |
-
label="
|
86 |
minimum=1,
|
87 |
maximum=MAX_MAX_NEW_TOKENS,
|
88 |
step=1,
|
89 |
value=DEFAULT_MAX_NEW_TOKENS,
|
90 |
),
|
91 |
gr.Slider(
|
92 |
-
label="Temperature",
|
93 |
minimum=0.1,
|
94 |
maximum=4.0,
|
95 |
step=0.1,
|
@@ -110,28 +182,32 @@ chat_interface = gr.ChatInterface(
|
|
110 |
value=50,
|
111 |
),
|
112 |
gr.Slider(
|
113 |
-
label="Repetition penalty",
|
114 |
minimum=1.0,
|
115 |
maximum=2.0,
|
116 |
step=0.05,
|
117 |
value=1.2,
|
118 |
),
|
119 |
],
|
120 |
-
stop_btn=None,
|
121 |
examples=[
|
122 |
-
["
|
123 |
-
["
|
124 |
-
["
|
125 |
-
["
|
126 |
-
["
|
|
|
|
|
127 |
],
|
128 |
-
cache_examples=False,
|
129 |
)
|
130 |
|
|
|
131 |
with gr.Blocks(css="style.css", fill_height=True) as demo:
|
132 |
-
gr.Markdown(DESCRIPTION)
|
133 |
-
gr.DuplicateButton(value="
|
134 |
-
chat_interface.render()
|
135 |
|
|
|
136 |
if __name__ == "__main__":
|
137 |
-
demo.queue(max_size=20).launch()
|
|
|
1 |
+
# Import các thư viện cần thiết
|
2 |
import os
|
3 |
+
import json
|
4 |
from threading import Thread
|
5 |
+
from typing import Iterator, List, Tuple
|
6 |
|
7 |
+
# Import thư viện Gradio và các mô-đun khác
|
8 |
import gradio as gr
|
9 |
import spaces
|
10 |
import torch
|
11 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
12 |
|
13 |
+
# Mô tả chung về mô hình và phiên bản Llama
|
14 |
DESCRIPTION = """\
|
15 |
+
# Llama 3.2 3B Instruct với Gọi Công Cụ Tiên Tiến
|
16 |
|
17 |
+
Llama 3.2 3B là phiên bản mới nhất của LLM từ Meta, được tinh chỉnh để theo dõi hướng dẫn và hỗ trợ gọi công cụ.
|
18 |
+
Đây là bản demo của [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct).
|
19 |
+
Để biết thêm chi tiết, hãy xem [bài đăng của chúng tôi](https://huggingface.co/blog/llama32).
|
20 |
"""
|
21 |
|
22 |
+
# Các thiết lập thông số tối đa
|
23 |
+
MAX_MAX_NEW_TOKENS = 2048 # Số token tối đa cho đầu ra mới
|
24 |
+
DEFAULT_MAX_NEW_TOKENS = 1024 # Số token mặc định cho đầu ra mới
|
25 |
+
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096")) # Lấy giá trị chiều dài token đầu vào từ biến môi trường
|
26 |
|
27 |
+
# Kiểm tra thiết bị có hỗ trợ GPU không, nếu không thì sử dụng CPU
|
28 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
29 |
|
30 |
+
# Định danh mô hình và tải mô hình cùng tokenizer
|
31 |
model_id = "nltpt/Llama-3.2-3B-Instruct"
|
32 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
33 |
model = AutoModelForCausalLM.from_pretrained(
|
34 |
model_id,
|
35 |
+
device_map="auto", # Tự động ánh xạ thiết bị
|
36 |
+
torch_dtype=torch.bfloat16, # Sử dụng kiểu dữ liệu bfloat16
|
37 |
)
|
38 |
+
model.eval() # Đặt mô hình vào chế độ đánh giá (evaluation mode)
|
39 |
|
40 |
+
# Định nghĩa các chức năng có thể được mô hình gọi
|
41 |
+
def get_weather(city: str, metric: str = "celsius") -> str:
|
42 |
+
# Ở đây bạn có thể tích hợp với API thời tiết thực tế
|
43 |
+
# Ví dụ tĩnh:
|
44 |
+
weather_data = {
|
45 |
+
"San Francisco": "25 C",
|
46 |
+
"Seattle": "18 C"
|
47 |
+
}
|
48 |
+
return weather_data.get(city, "Không có dữ liệu")
|
49 |
|
50 |
+
def get_user_info(user_id: int, special: str = "none") -> str:
|
51 |
+
# Ở đây bạn có thể truy xuất thông tin từ cơ sở dữ liệu
|
52 |
+
# Ví dụ tĩnh:
|
53 |
+
user_data = {
|
54 |
+
7890: {"name": "Nguyễn Văn A", "special": special}
|
55 |
+
}
|
56 |
+
user = user_data.get(user_id, {"name": "Không xác định", "special": "none"})
|
57 |
+
return f"Tên người dùng: {user['name']}, Yêu cầu đặc biệt: {user['special']}"
|
58 |
+
|
59 |
+
# Từ điển chứa các chức năng có thể gọi
|
60 |
+
AVAILABLE_FUNCTIONS = {
|
61 |
+
"get_weather": get_weather,
|
62 |
+
"get_user_info": get_user_info
|
63 |
+
}
|
64 |
+
|
65 |
+
@spaces.GPU(duration=90) # Chỉ định hàm này chạy trên GPU trong tối đa 90 giây
|
66 |
def generate(
|
67 |
message: str,
|
68 |
+
chat_history: List[Tuple[str, str]],
|
69 |
max_new_tokens: int = 1024,
|
70 |
temperature: float = 0.6,
|
71 |
top_p: float = 0.9,
|
|
|
73 |
repetition_penalty: float = 1.2,
|
74 |
) -> Iterator[str]:
|
75 |
conversation = []
|
76 |
+
|
77 |
+
# Duyệt qua lịch sử trò chuyện để xây dựng lại cuộc hội thoại
|
78 |
for user, assistant in chat_history:
|
79 |
conversation.extend(
|
80 |
[
|
|
|
82 |
{"role": "assistant", "content": assistant},
|
83 |
]
|
84 |
)
|
85 |
+
# Thêm tin nhắn mới của người dùng vào cuộc hội thoại
|
86 |
conversation.append({"role": "user", "content": message})
|
87 |
|
88 |
+
# Áp dụng mẫu hội thoại và chuyển thành tensor
|
89 |
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
|
90 |
+
|
91 |
+
# Kiểm tra và cắt bớt chuỗi đầu vào nếu vượt quá chiều dài tối đa
|
92 |
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
|
93 |
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
|
94 |
+
gr.Warning(f"Đã cắt bớt đầu vào từ cuộc hội thoại vì vượt quá {MAX_INPUT_TOKEN_LENGTH} tokens.")
|
95 |
+
|
96 |
+
# Chuyển tensor đến thiết bị của mô hình
|
97 |
input_ids = input_ids.to(model.device)
|
98 |
|
99 |
+
# Khởi tạo Streamer để lấy đầu ra theo từng phần (real-time)
|
100 |
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
|
101 |
+
|
102 |
+
# Thiết lập các tham số cho quá trình sinh đầu ra
|
103 |
generate_kwargs = dict(
|
104 |
{"input_ids": input_ids},
|
105 |
streamer=streamer,
|
|
|
111 |
num_beams=1,
|
112 |
repetition_penalty=repetition_penalty,
|
113 |
)
|
114 |
+
|
115 |
+
# Tạo một luồng để chạy quá trình sinh đầu ra
|
116 |
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
117 |
t.start()
|
118 |
|
119 |
+
# Trả về từng phần đầu ra khi chúng được sinh ra
|
120 |
outputs = []
|
121 |
+
assistant_response = ""
|
122 |
for text in streamer:
|
123 |
outputs.append(text)
|
124 |
+
assistant_response = "".join(outputs)
|
125 |
+
# Kiểm tra xem mô hình có trả về cuộc gọi chức năng không
|
126 |
+
if "[get_weather" in assistant_response or "[get_user_info" in assistant_response:
|
127 |
+
try:
|
128 |
+
# Trích xuất phần gọi chức năng từ phản hồi
|
129 |
+
start = assistant_response.index('[')
|
130 |
+
end = assistant_response.index(']') + 1
|
131 |
+
func_calls_str = assistant_response[start:end]
|
132 |
+
func_calls = json.loads(func_calls_str.replace("'", '"'))
|
133 |
+
|
134 |
+
results = []
|
135 |
+
for call in func_calls:
|
136 |
+
func_name = list(call.keys())[0]
|
137 |
+
params = call[func_name]
|
138 |
+
if isinstance(params, dict):
|
139 |
+
result = AVAILABLE_FUNCTIONS[func_name](**params)
|
140 |
+
else:
|
141 |
+
result = AVAILABLE_FUNCTIONS[func_name]()
|
142 |
+
results.append(result)
|
143 |
+
|
144 |
+
# Gộp kết quả và thêm vào phản hồi của trợ lý
|
145 |
+
assistant_response = assistant_response[:start] + " ".join(results) + assistant_response[end:]
|
146 |
+
yield assistant_response
|
147 |
+
except Exception as e:
|
148 |
+
yield f"Đã xảy ra lỗi khi xử lý cuộc gọi chức năng: {str(e)}"
|
149 |
+
else:
|
150 |
+
yield assistant_response
|
151 |
|
152 |
+
# Tạo giao diện chat với Gradio
|
153 |
chat_interface = gr.ChatInterface(
|
154 |
fn=generate,
|
155 |
additional_inputs=[
|
156 |
gr.Slider(
|
157 |
+
label="Số token mới tối đa",
|
158 |
minimum=1,
|
159 |
maximum=MAX_MAX_NEW_TOKENS,
|
160 |
step=1,
|
161 |
value=DEFAULT_MAX_NEW_TOKENS,
|
162 |
),
|
163 |
gr.Slider(
|
164 |
+
label="Nhiệt độ (Temperature)",
|
165 |
minimum=0.1,
|
166 |
maximum=4.0,
|
167 |
step=0.1,
|
|
|
182 |
value=50,
|
183 |
),
|
184 |
gr.Slider(
|
185 |
+
label="Hình phạt lặp lại (Repetition penalty)",
|
186 |
minimum=1.0,
|
187 |
maximum=2.0,
|
188 |
step=0.05,
|
189 |
value=1.2,
|
190 |
),
|
191 |
],
|
192 |
+
stop_btn=None, # Không có nút dừng
|
193 |
examples=[
|
194 |
+
["Xin chào! Bạn có khỏe không?"],
|
195 |
+
["Bạn có thể giải thích ngắn gọn về ngôn ngữ lập trình Python không?"],
|
196 |
+
["Giải thích cốt truyện của Cô bé Lọ Lem trong một câu."],
|
197 |
+
["Mất bao nhiêu giờ để một người ăn một chiếc trực thăng?"],
|
198 |
+
["Viết một bài báo 100 từ về 'Lợi ích của mã nguồn mở trong nghiên cứu AI'"],
|
199 |
+
["What is the weather in SF and Seattle?"],
|
200 |
+
["Can you retrieve the details for the user with the ID 7890, who has black as their special request?"]
|
201 |
],
|
202 |
+
cache_examples=False, # Không lưu trữ các ví dụ
|
203 |
)
|
204 |
|
205 |
+
# Tạo bố cục giao diện với Gradio
|
206 |
with gr.Blocks(css="style.css", fill_height=True) as demo:
|
207 |
+
gr.Markdown(DESCRIPTION) # Hiển thị phần mô tả
|
208 |
+
gr.DuplicateButton(value="Tạo bản sao cho sử dụng cá nhân", elem_id="duplicate-button")
|
209 |
+
chat_interface.render() # Hiển thị giao diện chat
|
210 |
|
211 |
+
# Khởi chạy ứng dụng khi chạy trực tiếp tệp này
|
212 |
if __name__ == "__main__":
|
213 |
+
demo.queue(max_size=20).launch()
|