File size: 8,568 Bytes
a9907f7
bf64382
a9907f7
bf64382
a9907f7
bf64382
a9907f7
fb114fa
bf64382
 
 
fb114fa
a9907f7
bf64382
a9907f7
bf64382
a9907f7
 
 
fb114fa
 
a9907f7
 
 
 
b04fed6
a9907f7
bf64382
 
a9907f7
bf64382
 
 
 
a9907f7
 
b04fed6
a9907f7
fb114fa
a9907f7
 
 
 
 
 
 
 
 
fb114fa
a9907f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e08a440
bf64382
 
a9907f7
bf64382
 
 
 
 
 
 
a9907f7
 
bf64382
 
 
 
 
 
 
a9907f7
bf64382
fb114fa
a9907f7
bf64382
a9907f7
 
bf64382
 
a9907f7
 
 
bf64382
b04fed6
a9907f7
bf64382
a9907f7
 
bf64382
 
 
 
 
fb114fa
bf64382
 
 
 
 
a9907f7
 
bf64382
 
fb114fa
a9907f7
bf64382
a9907f7
bf64382
 
a9907f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf64382
a9907f7
bf64382
 
fb114fa
 
a9907f7
bf64382
 
 
 
 
 
a9907f7
fb114fa
bf64382
 
 
 
 
 
 
fb114fa
 
bf64382
 
 
 
 
 
 
 
fb114fa
bf64382
a9907f7
bf64382
 
 
 
 
 
a9907f7
bf64382
a9907f7
 
 
 
 
5839892
 
fb114fa
a9907f7
fb114fa
 
a9907f7
bf64382
a9907f7
 
 
fb114fa
a9907f7
fb114fa
a9907f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
# Import các thư viện cần thiết
import os
import json
from threading import Thread
from typing import Iterator, List, Tuple

# Import thư viện Gradio và các mô-đun khác
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Mô tả chung về mô hình và phiên bản Llama
DESCRIPTION = """\
# Llama 3.2 3B Instruct với Gọi Công Cụ Tiên Tiến

Llama 3.2 3B là phiên bản mới nhất của LLM từ Meta, được tinh chỉnh để theo dõi hướng dẫn và hỗ trợ gọi công cụ.
Đây là bản demo của [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct).
Để biết thêm chi tiết, hãy xem [bài đăng của chúng tôi](https://huggingface.co/blog/llama32).
"""

# Các thiết lập thông số tối đa
MAX_MAX_NEW_TOKENS = 2048  # Số token tối đa cho đầu ra mới
DEFAULT_MAX_NEW_TOKENS = 1024  # Số token mặc định cho đầu ra mới
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))  # Lấy giá trị chiều dài token đầu vào từ biến môi trường

# Kiểm tra thiết bị có hỗ trợ GPU không, nếu không thì sử dụng CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Định danh mô hình và tải mô hình cùng tokenizer
model_id = "nltpt/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # Tự động ánh xạ thiết bị
    torch_dtype=torch.bfloat16,  # Sử dụng kiểu dữ liệu bfloat16
)
model.eval()  # Đặt mô hình vào chế độ đánh giá (evaluation mode)

# Định nghĩa các chức năng có thể được mô hình gọi
def get_weather(city: str, metric: str = "celsius") -> str:
    # Ở đây bạn có thể tích hợp với API thời tiết thực tế
    # Ví dụ tĩnh:
    weather_data = {
        "San Francisco": "25 C",
        "Seattle": "18 C"
    }
    return weather_data.get(city, "Không có dữ liệu")

def get_user_info(user_id: int, special: str = "none") -> str:
    # Ở đây bạn có thể truy xuất thông tin từ cơ sở dữ liệu
    # Ví dụ tĩnh:
    user_data = {
        7890: {"name": "Nguyễn Văn A", "special": special}
    }
    user = user_data.get(user_id, {"name": "Không xác định", "special": "none"})
    return f"Tên người dùng: {user['name']}, Yêu cầu đặc biệt: {user['special']}"

# Từ điển chứa các chức năng có thể gọi
AVAILABLE_FUNCTIONS = {
    "get_weather": get_weather,
    "get_user_info": get_user_info
}

@spaces.GPU(duration=10)  # Chỉ định hàm này chạy trên GPU trong tối đa 90 giây
def generate(
    message: str,
    chat_history: List[Tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    conversation = []
    
    # Duyệt qua lịch sử trò chuyện để xây dựng lại cuộc hội thoại
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    # Thêm tin nhắn mới của người dùng vào cuộc hội thoại
    conversation.append({"role": "user", "content": message})

    # Áp dụng mẫu hội thoại và chuyển thành tensor
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    
    # Kiểm tra và cắt bớt chuỗi đầu vào nếu vượt quá chiều dài tối đa
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Đã cắt bớt đầu vào từ cuộc hội thoại vì vượt quá {MAX_INPUT_TOKEN_LENGTH} tokens.")
    
    # Chuyển tensor đến thiết bị của mô hình
    input_ids = input_ids.to(model.device)

    # Khởi tạo Streamer để lấy đầu ra theo từng phần (real-time)
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    
    # Thiết lập các tham số cho quá trình sinh đầu ra
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    
    # Tạo một luồng để chạy quá trình sinh đầu ra
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Trả về từng phần đầu ra khi chúng được sinh ra
    outputs = []
    assistant_response = ""
    for text in streamer:
        outputs.append(text)
        assistant_response = "".join(outputs)
        # Kiểm tra xem mô hình có trả về cuộc gọi chức năng không
        if "[get_weather" in assistant_response or "[get_user_info" in assistant_response:
            try:
                # Trích xuất phần gọi chức năng từ phản hồi
                start = assistant_response.index('[')
                end = assistant_response.index(']') + 1
                func_calls_str = assistant_response[start:end]
                func_calls = json.loads(func_calls_str.replace("'", '"'))
                
                results = []
                for call in func_calls:
                    func_name = list(call.keys())[0]
                    params = call[func_name]
                    if isinstance(params, dict):
                        result = AVAILABLE_FUNCTIONS[func_name](**params)
                    else:
                        result = AVAILABLE_FUNCTIONS[func_name]()
                    results.append(result)
                
                # Gộp kết quả và thêm vào phản hồi của trợ lý
                assistant_response = assistant_response[:start] + " ".join(results) + assistant_response[end:]
                yield assistant_response
            except Exception as e:
                yield f"Đã xảy ra lỗi khi xử lý cuộc gọi chức năng: {str(e)}"
        else:
            yield assistant_response

# Tạo giao diện chat với Gradio
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Số token mới tối đa",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Nhiệt độ (Temperature)",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Hình phạt lặp lại (Repetition penalty)",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,  # Không có nút dừng
    examples=[
        ["Xin chào! Bạn có khỏe không?"],
        ["Bạn có thể giải thích ngắn gọn về ngôn ngữ lập trình Python không?"],
        ["Giải thích cốt truyện của Cô bé Lọ Lem trong một câu."],
        ["Mất bao nhiêu giờ để một người ăn một chiếc trực thăng?"],
        ["Viết một bài báo 100 từ về 'Lợi ích của mã nguồn mở trong nghiên cứu AI'"],
        ["Thời tiết ở San Francisco thế nào"],
        
    ],
    cache_examples=False,  # Không lưu trữ các ví dụ
)

# Tạo bố cục giao diện với Gradio
with gr.Blocks(css="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)  # Hiển thị phần mô tả
    gr.DuplicateButton(value="Tạo bản sao cho sử dụng cá nhân", elem_id="duplicate-button")
    chat_interface.render()  # Hiển thị giao diện chat

# Khởi chạy ứng dụng khi chạy trực tiếp tệp này
if __name__ == "__main__":
    demo.queue(max_size=20).launch()