hoduyquocbao commited on
Commit
a9907f7
1 Parent(s): bf64382

new version update

Browse files
Files changed (1) hide show
  1. app.py +106 -30
app.py CHANGED
@@ -1,40 +1,71 @@
 
1
  import os
 
2
  from threading import Thread
3
- from typing import Iterator
4
 
 
5
  import gradio as gr
6
  import spaces
7
  import torch
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
 
 
10
  DESCRIPTION = """\
11
- # Llama 3.2 3B Instruct
12
 
13
- Llama 3.2 3B is Meta's latest iteration of open LLMs.
14
- This is a demo of [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), fine-tuned for instruction following.
15
- For more details, please check [our post](https://huggingface.co/blog/llama32).
16
  """
17
 
18
- MAX_MAX_NEW_TOKENS = 2048
19
- DEFAULT_MAX_NEW_TOKENS = 1024
20
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
21
 
 
22
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
 
 
24
  model_id = "nltpt/Llama-3.2-3B-Instruct"
25
  tokenizer = AutoTokenizer.from_pretrained(model_id)
26
  model = AutoModelForCausalLM.from_pretrained(
27
  model_id,
28
- device_map="auto",
29
- torch_dtype=torch.bfloat16,
30
  )
31
- model.eval()
32
 
 
 
 
 
 
 
 
 
 
33
 
34
- @spaces.GPU(duration=90)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def generate(
36
  message: str,
37
- chat_history: list[tuple[str, str]],
38
  max_new_tokens: int = 1024,
39
  temperature: float = 0.6,
40
  top_p: float = 0.9,
@@ -42,6 +73,8 @@ def generate(
42
  repetition_penalty: float = 1.2,
43
  ) -> Iterator[str]:
44
  conversation = []
 
 
45
  for user, assistant in chat_history:
46
  conversation.extend(
47
  [
@@ -49,15 +82,24 @@ def generate(
49
  {"role": "assistant", "content": assistant},
50
  ]
51
  )
 
52
  conversation.append({"role": "user", "content": message})
53
 
 
54
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
 
 
55
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
56
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
57
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
 
58
  input_ids = input_ids.to(model.device)
59
 
 
60
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
 
 
61
  generate_kwargs = dict(
62
  {"input_ids": input_ids},
63
  streamer=streamer,
@@ -69,27 +111,57 @@ def generate(
69
  num_beams=1,
70
  repetition_penalty=repetition_penalty,
71
  )
 
 
72
  t = Thread(target=model.generate, kwargs=generate_kwargs)
73
  t.start()
74
 
 
75
  outputs = []
 
76
  for text in streamer:
77
  outputs.append(text)
78
- yield "".join(outputs)
79
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
 
81
  chat_interface = gr.ChatInterface(
82
  fn=generate,
83
  additional_inputs=[
84
  gr.Slider(
85
- label="Max new tokens",
86
  minimum=1,
87
  maximum=MAX_MAX_NEW_TOKENS,
88
  step=1,
89
  value=DEFAULT_MAX_NEW_TOKENS,
90
  ),
91
  gr.Slider(
92
- label="Temperature",
93
  minimum=0.1,
94
  maximum=4.0,
95
  step=0.1,
@@ -110,28 +182,32 @@ chat_interface = gr.ChatInterface(
110
  value=50,
111
  ),
112
  gr.Slider(
113
- label="Repetition penalty",
114
  minimum=1.0,
115
  maximum=2.0,
116
  step=0.05,
117
  value=1.2,
118
  ),
119
  ],
120
- stop_btn=None,
121
  examples=[
122
- ["Hello there! How are you doing?"],
123
- ["Can you explain briefly to me what is the Python programming language?"],
124
- ["Explain the plot of Cinderella in a sentence."],
125
- ["How many hours does it take a man to eat a Helicopter?"],
126
- ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
 
 
127
  ],
128
- cache_examples=False,
129
  )
130
 
 
131
  with gr.Blocks(css="style.css", fill_height=True) as demo:
132
- gr.Markdown(DESCRIPTION)
133
- gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
134
- chat_interface.render()
135
 
 
136
  if __name__ == "__main__":
137
- demo.queue(max_size=20).launch()
 
1
+ # Import các thư viện cần thiết
2
  import os
3
+ import json
4
  from threading import Thread
5
+ from typing import Iterator, List, Tuple
6
 
7
+ # Import thư viện Gradio và các mô-đun khác
8
  import gradio as gr
9
  import spaces
10
  import torch
11
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
12
 
13
+ # Mô tả chung về mô hình và phiên bản Llama
14
  DESCRIPTION = """\
15
+ # Llama 3.2 3B Instruct với Gọi Công Cụ Tiên Tiến
16
 
17
+ Llama 3.2 3B phiên bản mới nhất của LLM từ Meta, được tinh chỉnh để theo dõi hướng dẫn và hỗ trợ gọi công cụ.
18
+ Đây bản demo của [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct).
19
+ Để biết thêm chi tiết, hãy xem [bài đăng của chúng tôi](https://huggingface.co/blog/llama32).
20
  """
21
 
22
+ # Các thiết lập thông số tối đa
23
+ MAX_MAX_NEW_TOKENS = 2048 # Số token tối đa cho đầu ra mới
24
+ DEFAULT_MAX_NEW_TOKENS = 1024 # Số token mặc định cho đầu ra mới
25
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096")) # Lấy giá trị chiều dài token đầu vào từ biến môi trường
26
 
27
+ # Kiểm tra thiết bị có hỗ trợ GPU không, nếu không thì sử dụng CPU
28
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
29
 
30
+ # Định danh mô hình và tải mô hình cùng tokenizer
31
  model_id = "nltpt/Llama-3.2-3B-Instruct"
32
  tokenizer = AutoTokenizer.from_pretrained(model_id)
33
  model = AutoModelForCausalLM.from_pretrained(
34
  model_id,
35
+ device_map="auto", # Tự động ánh xạ thiết bị
36
+ torch_dtype=torch.bfloat16, # Sử dụng kiểu dữ liệu bfloat16
37
  )
38
+ model.eval() # Đặt mô hình vào chế độ đánh giá (evaluation mode)
39
 
40
+ # Định nghĩa các chức năng có thể được mô hình gọi
41
+ def get_weather(city: str, metric: str = "celsius") -> str:
42
+ # Ở đây bạn có thể tích hợp với API thời tiết thực tế
43
+ # Ví dụ tĩnh:
44
+ weather_data = {
45
+ "San Francisco": "25 C",
46
+ "Seattle": "18 C"
47
+ }
48
+ return weather_data.get(city, "Không có dữ liệu")
49
 
50
+ def get_user_info(user_id: int, special: str = "none") -> str:
51
+ # Ở đây bạn có thể truy xuất thông tin từ cơ sở dữ liệu
52
+ # Ví dụ tĩnh:
53
+ user_data = {
54
+ 7890: {"name": "Nguyễn Văn A", "special": special}
55
+ }
56
+ user = user_data.get(user_id, {"name": "Không xác định", "special": "none"})
57
+ return f"Tên người dùng: {user['name']}, Yêu cầu đặc biệt: {user['special']}"
58
+
59
+ # Từ điển chứa các chức năng có thể gọi
60
+ AVAILABLE_FUNCTIONS = {
61
+ "get_weather": get_weather,
62
+ "get_user_info": get_user_info
63
+ }
64
+
65
+ @spaces.GPU(duration=90) # Chỉ định hàm này chạy trên GPU trong tối đa 90 giây
66
  def generate(
67
  message: str,
68
+ chat_history: List[Tuple[str, str]],
69
  max_new_tokens: int = 1024,
70
  temperature: float = 0.6,
71
  top_p: float = 0.9,
 
73
  repetition_penalty: float = 1.2,
74
  ) -> Iterator[str]:
75
  conversation = []
76
+
77
+ # Duyệt qua lịch sử trò chuyện để xây dựng lại cuộc hội thoại
78
  for user, assistant in chat_history:
79
  conversation.extend(
80
  [
 
82
  {"role": "assistant", "content": assistant},
83
  ]
84
  )
85
+ # Thêm tin nhắn mới của người dùng vào cuộc hội thoại
86
  conversation.append({"role": "user", "content": message})
87
 
88
+ # Áp dụng mẫu hội thoại và chuyển thành tensor
89
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
90
+
91
+ # Kiểm tra và cắt bớt chuỗi đầu vào nếu vượt quá chiều dài tối đa
92
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
93
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
94
+ gr.Warning(f"Đã cắt bớt đầu vào từ cuộc hội thoại vì vượt quá {MAX_INPUT_TOKEN_LENGTH} tokens.")
95
+
96
+ # Chuyển tensor đến thiết bị của mô hình
97
  input_ids = input_ids.to(model.device)
98
 
99
+ # Khởi tạo Streamer để lấy đầu ra theo từng phần (real-time)
100
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
101
+
102
+ # Thiết lập các tham số cho quá trình sinh đầu ra
103
  generate_kwargs = dict(
104
  {"input_ids": input_ids},
105
  streamer=streamer,
 
111
  num_beams=1,
112
  repetition_penalty=repetition_penalty,
113
  )
114
+
115
+ # Tạo một luồng để chạy quá trình sinh đầu ra
116
  t = Thread(target=model.generate, kwargs=generate_kwargs)
117
  t.start()
118
 
119
+ # Trả về từng phần đầu ra khi chúng được sinh ra
120
  outputs = []
121
+ assistant_response = ""
122
  for text in streamer:
123
  outputs.append(text)
124
+ assistant_response = "".join(outputs)
125
+ # Kiểm tra xem mô hình có trả về cuộc gọi chức năng không
126
+ if "[get_weather" in assistant_response or "[get_user_info" in assistant_response:
127
+ try:
128
+ # Trích xuất phần gọi chức năng từ phản hồi
129
+ start = assistant_response.index('[')
130
+ end = assistant_response.index(']') + 1
131
+ func_calls_str = assistant_response[start:end]
132
+ func_calls = json.loads(func_calls_str.replace("'", '"'))
133
+
134
+ results = []
135
+ for call in func_calls:
136
+ func_name = list(call.keys())[0]
137
+ params = call[func_name]
138
+ if isinstance(params, dict):
139
+ result = AVAILABLE_FUNCTIONS[func_name](**params)
140
+ else:
141
+ result = AVAILABLE_FUNCTIONS[func_name]()
142
+ results.append(result)
143
+
144
+ # Gộp kết quả và thêm vào phản hồi của trợ lý
145
+ assistant_response = assistant_response[:start] + " ".join(results) + assistant_response[end:]
146
+ yield assistant_response
147
+ except Exception as e:
148
+ yield f"Đã xảy ra lỗi khi xử lý cuộc gọi chức năng: {str(e)}"
149
+ else:
150
+ yield assistant_response
151
 
152
+ # Tạo giao diện chat với Gradio
153
  chat_interface = gr.ChatInterface(
154
  fn=generate,
155
  additional_inputs=[
156
  gr.Slider(
157
+ label="Số token mới tối đa",
158
  minimum=1,
159
  maximum=MAX_MAX_NEW_TOKENS,
160
  step=1,
161
  value=DEFAULT_MAX_NEW_TOKENS,
162
  ),
163
  gr.Slider(
164
+ label="Nhiệt độ (Temperature)",
165
  minimum=0.1,
166
  maximum=4.0,
167
  step=0.1,
 
182
  value=50,
183
  ),
184
  gr.Slider(
185
+ label="Hình phạt lặp lại (Repetition penalty)",
186
  minimum=1.0,
187
  maximum=2.0,
188
  step=0.05,
189
  value=1.2,
190
  ),
191
  ],
192
+ stop_btn=None, # Không có nút dừng
193
  examples=[
194
+ ["Xin chào! Bạn khỏe không?"],
195
+ ["Bạn thể giải thích ngắn gọn về ngôn ngữ lập trình Python không?"],
196
+ ["Giải thích cốt truyện của Lọ Lem trong một câu."],
197
+ ["Mất bao nhiêu giờ để một người ăn một chiếc trực thăng?"],
198
+ ["Viết một bài báo 100 từ về 'Lợi ích của nguồn mở trong nghiên cứu AI'"],
199
+ ["What is the weather in SF and Seattle?"],
200
+ ["Can you retrieve the details for the user with the ID 7890, who has black as their special request?"]
201
  ],
202
+ cache_examples=False, # Không lưu trữ các ví dụ
203
  )
204
 
205
+ # Tạo bố cục giao diện với Gradio
206
  with gr.Blocks(css="style.css", fill_height=True) as demo:
207
+ gr.Markdown(DESCRIPTION) # Hiển thị phần mô tả
208
+ gr.DuplicateButton(value="Tạo bản sao cho sử dụng cá nhân", elem_id="duplicate-button")
209
+ chat_interface.render() # Hiển thị giao diện chat
210
 
211
+ # Khởi chạy ứng dụng khi chạy trực tiếp tệp này
212
  if __name__ == "__main__":
213
+ demo.queue(max_size=20).launch()