hoduyquocbao commited on
Commit
8d8f4b0
1 Parent(s): c3f15f3
Files changed (1) hide show
  1. app.py +254 -209
app.py CHANGED
@@ -1,257 +1,302 @@
1
  import os
2
- import time
3
- import requests
4
- import random
5
  from threading import Thread
6
- from typing import List, Dict, Union
7
- import torch
8
  import gradio as gr
9
- from bs4 import BeautifulSoup
 
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
11
- from functools import lru_cache
12
- import re
13
- import io
14
  import json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Model Loading (Done once at startup)
17
- MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
18
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 
19
  model = AutoModelForCausalLM.from_pretrained(
20
- MODEL_ID,
21
- torch_dtype=torch.float16,
22
  device_map="auto",
23
- low_cpu_mem_usage=True
24
- ).eval()
25
-
26
- # Path to example texts (Updated to remove image and video examples)
27
- examples_path = os.path.dirname(__file__)
28
- EXAMPLES = [
29
- [
30
- {
31
- "text": "What is Friction? Explain in Detail.",
32
- }
33
- ],
34
- [
35
- {
36
- "text": "Write me a Python function to generate unique passwords.",
37
- }
38
- ],
39
- [
40
- {
41
- "text": "What's the latest price of Bitcoin?",
42
- }
43
- ],
44
- [
45
- {
46
- "text": "Search and give me list of spaces trending on HuggingFace.",
47
- }
48
- ],
49
- [
50
- {
51
- "text": "Create a Beautiful Picture of Eiffel at Night.",
52
- }
53
- ],
54
- [
55
- {
56
- "text": "What unusual happens in this video.",
57
- "files": [f"{examples_path}/example_video/accident.gif"],
58
- }
59
- ],
60
- # Removed other image and video related examples
61
- ]
62
 
63
- # Set bot avatar image
64
- BOT_AVATAR = "OpenAI_logo.png"
65
 
66
- # Perform a Google search and return the results
67
- @lru_cache(maxsize=128)
68
- def extract_text_from_webpage(html_content):
69
- """Extracts visible text from HTML content using BeautifulSoup."""
70
  soup = BeautifulSoup(html_content, "html.parser")
 
71
  for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
72
  tag.extract()
 
73
  visible_text = soup.get_text(separator=' ', strip=True)
74
  return visible_text
75
 
76
- def search(query):
 
77
  term = query
78
  all_results = []
79
- max_chars_per_page = 8000
 
 
 
80
  with requests.Session() as session:
81
- resp = session.get(
82
- url="https://www.google.com/search",
83
- headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"},
84
- params={"q": term, "num": 4},
85
- timeout=5,
86
- verify=False,
87
- )
88
- resp.raise_for_status()
89
- soup = BeautifulSoup(resp.text, "html.parser")
90
- result_block = soup.find_all("div", attrs={"class": "g"})
91
- for result in result_block:
92
- link = result.find("a", href=True)
93
- if link and 'href' in link.attrs:
94
- link = link["href"]
95
- try:
96
- webpage = session.get(link, headers={"User-Agent": "Mozilla/5.0"}, timeout=5, verify=False)
97
- webpage.raise_for_status()
98
- visible_text = extract_text_from_webpage(webpage.text)
99
- if len(visible_text) > max_chars_per_page:
100
- visible_text = visible_text[:max_chars_per_page]
101
- all_results.append({"link": link, "text": visible_text})
102
- except requests.exceptions.RequestException:
103
- all_results.append({"link": link, "text": None})
 
 
 
 
 
 
 
 
104
  return all_results
105
 
106
- def generate_response(prompt, chat_history):
107
- # Construct the conversation history
108
- conversation = ""
 
 
 
109
  for user, assistant in chat_history:
110
- conversation += f"User: {user}\nAssistant: {assistant}\n"
111
- conversation += f"User: {prompt}\nAssistant:"
112
-
113
- inputs = tokenizer(conversation, return_tensors="pt").to("cuda")
114
-
115
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
116
- thread = Thread(target=model.generate, args=(inputs.input_ids,), kwargs={
117
- "max_new_tokens": 512,
 
 
 
 
 
 
 
 
 
 
 
118
  "do_sample": True,
119
- "top_p": 0.95,
120
- "temperature": 0.8,
121
- "streamer": streamer
122
- })
123
- thread.start()
124
-
125
- response = ""
126
- for new_text in streamer:
127
- response += new_text
128
- yield response
 
 
 
 
129
 
130
  @lru_cache(maxsize=128)
131
- def process_query(query, chat_history):
132
- # Here you can implement logic to decide between web_search, general_query, or hard_query
133
- # For simplicity, let's assume all queries go through general_query
134
- # You can expand this with your own logic
135
- functions_metadata = [
136
- {
137
- "type": "function",
138
- "function": {
139
- "name": "web_search",
140
- "description": "Search query on google and find latest information.",
141
- "parameters": {
142
- "type": "object",
143
- "properties": {
144
- "query": {"type": "string", "description": "Web search query"}
145
- },
146
- "required": ["query"]
147
- }
148
- }
149
- },
150
- {
151
- "type": "function",
152
- "function": {
153
- "name": "general_query",
154
- "description": "Reply general query with LLM.",
155
- "parameters": {
156
- "type": "object",
157
- "properties": {
158
- "prompt": {"type": "string", "description": "A detailed prompt"}
159
- },
160
- "required": ["prompt"]
161
- }
162
- }
163
- },
164
- {
165
- "type": "function",
166
- "function": {
167
- "name": "hard_query",
168
- "description": "Reply tough query using powerful LLM.",
169
- "parameters": {
170
- "type": "object",
171
- "properties": {
172
- "prompt": {"type": "string", "description": "A detailed prompt"}
173
- },
174
- "required": ["prompt"]
175
- }
176
- }
177
- },
178
- ]
179
-
180
- # Example logic to choose function (you can customize this)
181
- if "search" in query.lower():
182
  function_name = "web_search"
183
- elif "explain" in query.lower() or "detail" in query.lower():
 
184
  function_name = "general_query"
 
185
  else:
186
  function_name = "hard_query"
187
-
 
188
  return {
189
  "name": function_name,
190
- "arguments": {
191
- "query" if function_name == "web_search" else "prompt": query
192
- }
193
  }
194
 
195
- def handle_functions(function_call, chat_history):
 
 
 
196
  function_name = function_call["name"]
197
  arguments = function_call["arguments"]
198
-
199
  if function_name == "web_search":
200
  query = arguments["query"]
 
201
  web_results = search(query)
202
- web_summary = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results if res["text"]])
203
- # Append web results to chat history or pass to the model as context
204
- # Here we directly return the summarized web results
205
- return f"Here are the search results:\n{web_summary}"
206
-
 
 
 
 
 
 
207
  elif function_name in ["general_query", "hard_query"]:
208
- prompt = arguments["prompt"]
209
- # Generate response using the local model
210
- response_generator = generate_response(prompt, chat_history)
211
- return response_generator
212
-
 
 
 
 
 
 
 
 
 
 
213
  else:
214
- return "Function not recognized."
215
-
216
- def model_inference(user_prompt, chat_history):
217
- prompt = user_prompt["text"]
218
 
219
- # Determine which function to call
220
- function_call = process_query(prompt, chat_history)
221
 
222
- if function_call["name"] == "web_search":
223
- yield "Performing web search..."
224
- result = handle_functions(function_call, chat_history)
225
- yield result
226
-
227
- elif function_call["name"] in ["general_query", "hard_query"]:
228
- yield "Generating response..."
229
- response_generator = handle_functions(function_call, chat_history)
230
- for response in response_generator:
231
- yield response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
- else:
234
- yield "Invalid function call."
 
 
 
 
 
 
 
 
235
 
236
- # Create a chatbot interface
237
- chatbot = gr.Chatbot(
238
- label="OpenGPT-4o",
239
- avatar_images=[None, BOT_AVATAR],
240
- show_copy_button=True,
241
- layout="panel",
242
- height=400,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  )
244
- input_box = gr.Textbox(label="Prompt")
245
 
246
- iface = gr.Interface(
247
- fn=model_inference,
248
- inputs=[input_box, chatbot],
249
- outputs=chatbot,
250
- live=True,
251
- examples=EXAMPLES,
252
- title="OpenGPT-4o Chatbot",
253
- description="A powerful AI assistant using local Llama-3.2 model.",
254
- )
255
 
256
  if __name__ == "__main__":
257
- iface.launch()
 
1
  import os
 
 
 
2
  from threading import Thread
3
+ from typing import Iterator, List, Tuple, Dict, Any
4
+
5
  import gradio as gr
6
+ import spaces
7
+ import torch
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
+ from bs4 import BeautifulSoup
10
+ import requests
 
11
  import json
12
+ from functools import lru_cache
13
+
14
+ # ---------------------------- Cấu Hình ---------------------------- #
15
+
16
+ DESCRIPTION = """\
17
+ # Llama 3.2 3B Instruct với Chức Năng Nâng Cao
18
+
19
+ Llama 3.2 3B là phiên bản mới nhất của Meta về các mô hình ngôn ngữ mở.
20
+ Demo này giới thiệu [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), được tinh chỉnh để theo dõi hướng dẫn.
21
+ Để biết thêm chi tiết, vui lòng xem [bài viết của chúng tôi](https://huggingface.co/blog/llama32).
22
+ """
23
+
24
+ MAX_MAX_NEW_TOKENS = 2048 # Số token tối đa có thể tạo ra
25
+ DEFAULT_MAX_NEW_TOKENS = 1024 # Số token tạo ra mặc định
26
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096")) # Độ dài token tối đa cho đầu vào
27
 
28
+ # Xác định thiết bị sử dụng (GPU nếu có, ngược lại CPU)
29
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
30
+
31
+ model_id = "nltpt/Llama-3.2-3B-Instruct" # ID mô hình, đảm bảo đây là ID mô hình đúng
32
+ tokenizer = AutoTokenizer.from_pretrained(model_id) # Tải tokenizer từ Hugging Face
33
  model = AutoModelForCausalLM.from_pretrained(
34
+ model_id,
 
35
  device_map="auto",
36
+ torch_dtype=torch.bfloat16, # Sử dụng dtype phù hợp để tiết kiệm bộ nhớ
37
+ )
38
+ model.to(device) # Di chuyển mô hình tới thiết bị đã chọn
39
+ model.eval() # Đặt hình chế độ đánh giá
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ # ---------------------------- Định Nghĩa Hàm ---------------------------- #
 
42
 
43
+ @lru_cache(maxsize=128)
44
+ def extract_text_from_webpage(html_content: str) -> str:
45
+ """Trích xuất văn bản hiển thị từ nội dung HTML sử dụng BeautifulSoup."""
 
46
  soup = BeautifulSoup(html_content, "html.parser")
47
+ # Loại bỏ các thẻ không hiển thị như script, style, header, footer, nav, form, svg
48
  for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
49
  tag.extract()
50
+ # Trích xuất văn bản hiển thị, tách bằng dấu cách và loại bỏ khoảng trắng thừa
51
  visible_text = soup.get_text(separator=' ', strip=True)
52
  return visible_text
53
 
54
+ def search(query: str) -> List[Dict[str, Any]]:
55
+ """Thực hiện tìm kiếm trên Google và trả về kết quả."""
56
  term = query
57
  all_results = []
58
+ max_chars_per_page = 8000 # Số ký tự tối đa mỗi trang
59
+ headers = {
60
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
61
+ }
62
  with requests.Session() as session:
63
+ try:
64
+ resp = session.get(
65
+ url="https://www.google.com/search",
66
+ headers=headers,
67
+ params={"q": term, "num": 4}, # Tìm kiếm với 4 kết quả mỗi trang
68
+ timeout=5,
69
+ verify=False, # Bỏ qua xác minh SSL
70
+ )
71
+ resp.raise_for_status() # Kiểm tra phản hồi HTTP
72
+ soup = BeautifulSoup(resp.text, "html.parser")
73
+ result_blocks = soup.find_all("div", attrs={"class": "g"}) # Tìm tất cả các khối kết quả
74
+ for result in result_blocks:
75
+ link_tag = result.find("a", href=True) # Tìm thẻ liên kết
76
+ if link_tag and 'href' in link_tag.attrs:
77
+ link = link_tag["href"]
78
+ try:
79
+ webpage = session.get(
80
+ link,
81
+ headers=headers,
82
+ timeout=5,
83
+ verify=False
84
+ )
85
+ webpage.raise_for_status()
86
+ visible_text = extract_text_from_webpage(webpage.text)
87
+ if len(visible_text) > max_chars_per_page:
88
+ visible_text = visible_text[:max_chars_per_page] # Cắt văn bản nếu quá dài
89
+ all_results.append({"link": link, "text": visible_text})
90
+ except requests.exceptions.RequestException:
91
+ all_results.append({"link": link, "text": "Không thể lấy nội dung."})
92
+ except requests.exceptions.RequestException as e:
93
+ all_results.append({"link": "N/A", "text": "Không thể thực hiện tìm kiếm."})
94
  return all_results
95
 
96
+ def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
97
+ """
98
+ Tạo phản hồi sử dụng mô hình Llama cục bộ theo chế độ streaming.
99
+ """
100
+ # Xây dựng lịch sử cuộc trò chuyện
101
+ conversation = []
102
  for user, assistant in chat_history:
103
+ conversation.extend([
104
+ {"role": "user", "content": user},
105
+ {"role": "assistant", "content": assistant},
106
+ ])
107
+ conversation.append({"role": "user", "content": prompt}) # Thêm tin nhắn của người dùng
108
+
109
+ # Chuẩn bị input_ids từ tokenizer
110
+ input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
111
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
112
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] # Cắt input nếu quá dài
113
+ gr.Warning(f"Đã cắt bỏ phần cuộc trò chuyện vì vượt quá {MAX_INPUT_TOKEN_LENGTH} token.")
114
+ input_ids = input_ids.to(device) # Di chuyển input tới thiết bị
115
+
116
+ # Khởi tạo streamer để nhận văn bản được tạo ra theo thời gian thực
117
+ streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
118
+ generate_kwargs = {
119
+ "input_ids": input_ids,
120
+ "streamer": streamer,
121
+ "max_new_tokens": max_new_tokens,
122
  "do_sample": True,
123
+ "top_p": top_p,
124
+ "top_k": top_k,
125
+ "temperature": temperature,
126
+ "num_beams": 1,
127
+ "repetition_penalty": repetition_penalty,
128
+ }
129
+ t = Thread(target=model.generate, kwargs=generate_kwargs) # Tạo luồng để sinh văn bản
130
+ t.start()
131
+
132
+ # Stream văn bản được tạo ra
133
+ outputs = []
134
+ for text in streamer:
135
+ outputs.append(text)
136
+ yield "".join(outputs)
137
 
138
  @lru_cache(maxsize=128)
139
+ def process_query(query: str) -> Dict[str, Any]:
140
+ """
141
+ Xác định hàm nào sẽ được gọi dựa trên truy vấn của người dùng.
142
+ """
143
+ # Định nghĩa các từ khóa hoặc mẫu để xác định hàm
144
+ web_search_keywords = ["tìm kiếm", "tìm", "tra cứu", "google", "lookup"]
145
+ general_query_keywords = ["giải thích", "mô tả", "nói cho tôi biết về", "cái gì là", "cách nào"]
146
+ # Bất kỳ truy vấn nào khác sẽ được xử lý như hard_query
147
+
148
+ query_lower = query.lower() # Chuyển truy vấn thành chữ thường để so sánh
149
+
150
+ if any(keyword in query_lower for keyword in web_search_keywords):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  function_name = "web_search"
152
+ arguments = {"query": query}
153
+ elif any(keyword in query_lower for keyword in general_query_keywords):
154
  function_name = "general_query"
155
+ arguments = {"prompt": query}
156
  else:
157
  function_name = "hard_query"
158
+ arguments = {"prompt": query}
159
+
160
  return {
161
  "name": function_name,
162
+ "arguments": arguments
 
 
163
  }
164
 
165
+ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
166
+ """
167
+ Thực thi hàm phù hợp dựa trên lời gọi hàm.
168
+ """
169
  function_name = function_call["name"]
170
  arguments = function_call["arguments"]
171
+
172
  if function_name == "web_search":
173
  query = arguments["query"]
174
+ yield "🔍 Đang thực hiện tìm kiếm trên web..."
175
  web_results = search(query)
176
+ if not web_results:
177
+ yield "⚠️ Không tìm thấy kết quả."
178
+ return
179
+ # Tóm tắt kết quả tìm kiếm
180
+ web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 **Mô tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."])
181
+ if not web_summary:
182
+ web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm."
183
+
184
+ # Trả về kết quả tìm kiếm cho người dùng
185
+ yield "📄 **Kết quả tìm kiếm:**\n" + web_summary
186
+
187
  elif function_name in ["general_query", "hard_query"]:
188
+ prompt_text = arguments["prompt"]
189
+ yield "🤖 Đang tạo phản hồi..."
190
+ # Tạo phản hồi sử dụng mô hình Llama
191
+ response_generator = generate_response(
192
+ prompt=prompt_text,
193
+ chat_history=chat_history,
194
+ max_new_tokens=max_new_tokens,
195
+ temperature=temperature,
196
+ top_p=top_p,
197
+ top_k=top_k,
198
+ repetition_penalty=repetition_penalty
199
+ )
200
+ for response in response_generator:
201
+ yield response
202
+
203
  else:
204
+ yield "⚠️ Lời gọi hàm không được nhận dạng."
 
 
 
205
 
206
+ # ---------------------------- Giao Diện Gradio ---------------------------- #
 
207
 
208
+ @spaces.GPU(duration=60, queue=False)
209
+ def generate(
210
+ message: str,
211
+ chat_history: List[Tuple[str, str]],
212
+ max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
213
+ temperature: float = 0.6,
214
+ top_p: float = 0.9,
215
+ top_k: int = 50,
216
+ repetition_penalty: float = 1.2,
217
+ ) -> Iterator[str]:
218
+ """
219
+ Hàm chính để xử lý đầu vào của người dùng và tạo phản hồi.
220
+ """
221
+ # Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng
222
+ function_call = process_query(message)
223
+
224
+ # Xử lý lời gọi hàm và sinh phản hồi tương ứng
225
+ response_iterator = handle_functions(
226
+ function_call=function_call,
227
+ prompt=message,
228
+ chat_history=chat_history,
229
+ max_new_tokens=max_new_tokens,
230
+ temperature=temperature,
231
+ top_p=top_p,
232
+ top_k=top_k,
233
+ repetition_penalty=repetition_penalty
234
+ )
235
+
236
+ for response in response_iterator:
237
+ yield response
238
 
239
+ # Định nghĩa các ví dụ để hướng dẫn người dùng
240
+ EXAMPLES = [
241
+ ["Xin chào! Bạn khỏe không?"],
242
+ ["Bạn có thể giải thích ngắn gọn về ngôn ngữ lập trình Python không?"],
243
+ ["Giải thích cốt truyện của Cô bé Lọ Lem trong một câu."],
244
+ ["Một người đàn ông cần bao nhiêu giờ để ăn một chiếc máy bay trực thăng?"],
245
+ ["Viết một bài báo 100 từ về 'Lợi ích của mã nguồn mở trong nghiên cứu AI'"],
246
+ ["Tìm và cung cấp cho tôi tin tức mới nhất về năng lượng tái tạo."],
247
+ ["Tìm thông tin về Rạn san hô Great Barrier Reef."],
248
+ ]
249
 
250
+ # Cấu hình giao diện trò chuyện của Gradio
251
+ chat_interface = gr.ChatInterface(
252
+ fn=generate, # Hàm được gọi khi có tương tác từ người dùng
253
+ additional_inputs=[
254
+ gr.Slider(
255
+ label="Số token mới tối đa",
256
+ minimum=1,
257
+ maximum=MAX_MAX_NEW_TOKENS,
258
+ step=1,
259
+ value=DEFAULT_MAX_NEW_TOKENS,
260
+ ),
261
+ gr.Slider(
262
+ label="Nhiệt độ",
263
+ minimum=0.1,
264
+ maximum=4.0,
265
+ step=0.1,
266
+ value=0.6,
267
+ ),
268
+ gr.Slider(
269
+ label="Top-p (nucleus sampling)",
270
+ minimum=0.05,
271
+ maximum=1.0,
272
+ step=0.05,
273
+ value=0.9,
274
+ ),
275
+ gr.Slider(
276
+ label="Top-k",
277
+ minimum=1,
278
+ maximum=1000,
279
+ step=1,
280
+ value=50,
281
+ ),
282
+ gr.Slider(
283
+ label="Hình phạt sự lặp lại",
284
+ minimum=1.0,
285
+ maximum=2.0,
286
+ step=0.05,
287
+ value=1.2,
288
+ ),
289
+ ],
290
+ stop_btn=None, # Không có nút dừng
291
+ examples=EXAMPLES, # Các ví dụ được hiển thị cho người dùng
292
+ cache_examples=False, # Không lưu bộ nhớ cache cho các ví dụ
293
  )
 
294
 
295
+ # Tạo giao diện chính của Gradio
296
+ with gr.Blocks(css="style.css", fill_height=True) as demo:
297
+ gr.Markdown(DESCRIPTION) # Hiển thị mô tả
298
+ gr.DuplicateButton(value="Nhân bản Không gian để sử dụng riêng tư", elem_id="duplicate-button") # Nút nhân bản không gian
299
+ chat_interface.render() # Hiển thị giao diện trò chuyện
 
 
 
 
300
 
301
  if __name__ == "__main__":
302
+ demo.queue(max_size=20).launch() # Khởi chạy ứng dụng Gradio với hàng đợi kích thước tối đa là 20