Spaces:
Sleeping
Sleeping
hoduyquocbao
commited on
Commit
•
e51a5b0
1
Parent(s):
07ee95b
change new logics
Browse files
app.py
CHANGED
@@ -1,88 +1,101 @@
|
|
1 |
-
# Import các thư viện cần thiết
|
2 |
import os
|
3 |
-
import
|
|
|
|
|
4 |
from threading import Thread
|
5 |
-
from typing import
|
6 |
-
|
7 |
-
import gradio as gr
|
8 |
-
import spaces
|
9 |
import torch
|
10 |
-
import
|
11 |
-
from bs4 import BeautifulSoup
|
12 |
-
from functools import lru_cache # Thêm lru_cache để cache kết quả hàm
|
13 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
|
|
|
|
|
|
|
|
14 |
|
15 |
-
#
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
# Thiết lập các thông số tối đa
|
25 |
-
MAX_MAX_NEW_TOKENS = 2048 # Số token tối đa cho đầu ra mới
|
26 |
-
DEFAULT_MAX_NEW_TOKENS = 1024 # Số token mặc định cho đầu ra mới
|
27 |
-
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096")) # Lấy giá trị chiều dài token đầu vào từ biến môi trường
|
28 |
|
29 |
-
#
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
-
#
|
33 |
-
|
34 |
-
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
35 |
-
model = AutoModelForCausalLM.from_pretrained(
|
36 |
-
model_id,
|
37 |
-
device_map="auto", # Tự động ánh xạ thiết bị
|
38 |
-
torch_dtype=torch.bfloat16, # Sử dụng kiểu dữ liệu bfloat16
|
39 |
-
)
|
40 |
-
model.eval() # Đặt mô hình vào chế độ đánh giá (evaluation mode)
|
41 |
|
42 |
-
#
|
43 |
-
@lru_cache(maxsize=128)
|
44 |
def extract_text_from_webpage(html_content):
|
45 |
-
"""
|
46 |
soup = BeautifulSoup(html_content, "html.parser")
|
47 |
-
# Loại bỏ các thẻ không hiển thị như script, style, header, footer, nav, form, svg
|
48 |
for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
|
49 |
tag.extract()
|
50 |
-
|
51 |
-
visible_text = soup.get_text(strip=True)
|
52 |
return visible_text
|
53 |
|
54 |
-
# Hàm thực hiện tìm kiếm trên Google và trả về kết quả
|
55 |
def search(query):
|
56 |
term = query
|
57 |
all_results = []
|
58 |
-
max_chars_per_page = 8000
|
59 |
with requests.Session() as session:
|
60 |
-
# Thực hiện yêu cầu tìm kiếm trên Google
|
61 |
resp = session.get(
|
62 |
url="https://www.google.com/search",
|
63 |
headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"},
|
64 |
-
params={"q": term, "num": 4
|
65 |
timeout=5,
|
66 |
-
verify=False,
|
67 |
)
|
68 |
-
resp.raise_for_status()
|
69 |
soup = BeautifulSoup(resp.text, "html.parser")
|
70 |
-
result_block = soup.find_all("div", attrs={"class": "g"})
|
71 |
for result in result_block:
|
72 |
link = result.find("a", href=True)
|
73 |
-
if link:
|
74 |
link = link["href"]
|
75 |
try:
|
76 |
-
|
77 |
-
webpage = session.get(
|
78 |
-
link,
|
79 |
-
headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"},
|
80 |
-
timeout=5,
|
81 |
-
verify=False
|
82 |
-
)
|
83 |
webpage.raise_for_status()
|
84 |
visible_text = extract_text_from_webpage(webpage.text)
|
85 |
-
# Cắt bớt văn bản nếu vượt quá giới hạn
|
86 |
if len(visible_text) > max_chars_per_page:
|
87 |
visible_text = visible_text[:max_chars_per_page]
|
88 |
all_results.append({"link": link, "text": visible_text})
|
@@ -90,243 +103,155 @@ def search(query):
|
|
90 |
all_results.append({"link": link, "text": None})
|
91 |
return all_results
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
"
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
}
|
105 |
-
}
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
"
|
118 |
}
|
119 |
-
}
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
"
|
132 |
}
|
133 |
-
}
|
134 |
-
|
135 |
-
|
136 |
-
},
|
137 |
-
]
|
138 |
-
|
139 |
-
# Prompt role system với các hướng dẫn về khả năng và công cụ
|
140 |
-
SYSTEM_PROMPT = f"""\
|
141 |
-
Bạn là một trợ lý thông minh và hữu ích với khả năng sử dụng các công cụ sau đây để hỗ trợ người dùng:
|
142 |
-
|
143 |
-
{json.dumps(functions_metadata, indent=4, ensure_ascii=False)}
|
144 |
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
-
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
-
|
|
|
|
|
150 |
|
151 |
-
""
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
max_new_tokens: int = 1024,
|
159 |
-
temperature: float = 0.6,
|
160 |
-
top_p: float = 0.9,
|
161 |
-
top_k: int = 50,
|
162 |
-
repetition_penalty: float = 1.2,
|
163 |
-
) -> Iterator[str]:
|
164 |
-
conversation = []
|
165 |
-
func_caller = []
|
166 |
|
167 |
-
|
168 |
-
|
169 |
-
conversation.extend(
|
170 |
-
[
|
171 |
-
{"role": "user", "content": user},
|
172 |
-
{"role": "assistant", "content": assistant},
|
173 |
-
]
|
174 |
-
)
|
175 |
-
|
176 |
-
# Thêm tin nhắn mới của người dùng vào cuộc hội thoại
|
177 |
-
conversation.append({"role": "user", "content": message})
|
178 |
|
179 |
-
|
180 |
-
|
181 |
|
182 |
-
#
|
183 |
-
|
184 |
-
|
185 |
-
# Kiểm tra và cắt bớt chuỗi đầu vào nếu vượt quá chiều dài tối đa
|
186 |
-
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
|
187 |
-
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
|
188 |
-
gr.Warning(f"Đã cắt bớt đầu vào từ cuộc hội thoại vì vượt quá {MAX_INPUT_TOKEN_LENGTH} tokens.")
|
189 |
-
|
190 |
-
# Chuyển tensor đến thiết bị của mô hình
|
191 |
-
input_ids = input_ids.to(model.device)
|
192 |
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
generate_kwargs = dict(
|
198 |
-
{"input_ids": input_ids},
|
199 |
-
streamer=streamer,
|
200 |
-
max_new_tokens=max_new_tokens,
|
201 |
-
do_sample=True,
|
202 |
-
top_p=top_p,
|
203 |
-
top_k=top_k,
|
204 |
-
temperature=temperature,
|
205 |
-
num_beams=1,
|
206 |
-
repetition_penalty=repetition_penalty,
|
207 |
-
)
|
208 |
-
|
209 |
-
# Tạo một luồng để chạy quá trình sinh đầu ra
|
210 |
-
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
211 |
-
t.start()
|
212 |
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
response
|
217 |
-
|
218 |
|
219 |
-
|
220 |
-
|
221 |
-
try:
|
222 |
-
# Phân tích phản hồi JSON
|
223 |
-
json_data = json.loads(response)
|
224 |
-
|
225 |
-
# Xử lý các hàm được gọi
|
226 |
-
for func in json_data:
|
227 |
-
func_name = func.get("name")
|
228 |
-
arguments = func.get("arguments", {})
|
229 |
-
|
230 |
-
if func_name == "web_search":
|
231 |
-
query = arguments.get("query")
|
232 |
-
if query:
|
233 |
-
gr.Info("Đang tìm kiếm trên web...")
|
234 |
-
yield "Đang tìm kiếm trên web..."
|
235 |
-
web_results = search(query)
|
236 |
-
|
237 |
-
gr.Info("Đang trích xuất thông tin liên quan...")
|
238 |
-
yield "Đang trích xuất thông tin liên quan..."
|
239 |
-
web_content = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results])
|
240 |
-
|
241 |
-
# Xây dựng lại cuộc hội thoại với kết quả tìm kiếm
|
242 |
-
conversation_updated = conversation.copy()
|
243 |
-
conversation_updated.append({"role": "user", "content": message})
|
244 |
-
conversation_updated.append({"role": "system", "content": f"[WEB RESULTS] {web_content}"})
|
245 |
-
|
246 |
-
# Gửi lại cuộc hội thoại để mô hình trả lời dựa trên kết quả tìm kiếm
|
247 |
-
input_ids_updated = tokenizer.apply_chat_template(conversation_updated, add_generation_prompt=True, return_tensors="pt")
|
248 |
-
input_ids_updated = input_ids_updated.to(model.device)
|
249 |
-
|
250 |
-
streamer_updated = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
|
251 |
-
generate_kwargs_updated = dict(
|
252 |
-
{"input_ids": input_ids_updated},
|
253 |
-
streamer=streamer_updated,
|
254 |
-
max_new_tokens=max_new_tokens,
|
255 |
-
do_sample=True,
|
256 |
-
top_p=top_p,
|
257 |
-
top_k=top_k,
|
258 |
-
temperature=temperature,
|
259 |
-
num_beams=1,
|
260 |
-
repetition_penalty=repetition_penalty,
|
261 |
-
)
|
262 |
-
|
263 |
-
t_updated = Thread(target=model.generate, kwargs=generate_kwargs_updated)
|
264 |
-
t_updated.start()
|
265 |
-
|
266 |
-
for text_updated in streamer_updated:
|
267 |
-
yield text_updated
|
268 |
-
|
269 |
-
except json.JSONDecodeError:
|
270 |
-
# Nếu phản hồi không phải là JSON hợp lệ, tiếp tục trả lời bình thường
|
271 |
-
continue
|
272 |
|
273 |
-
#
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
maximum=MAX_MAX_NEW_TOKENS,
|
281 |
-
step=1,
|
282 |
-
value=DEFAULT_MAX_NEW_TOKENS,
|
283 |
-
),
|
284 |
-
gr.Slider(
|
285 |
-
label="Nhiệt độ (Temperature)",
|
286 |
-
minimum=0.1,
|
287 |
-
maximum=4.0,
|
288 |
-
step=0.1,
|
289 |
-
value=0.6,
|
290 |
-
),
|
291 |
-
gr.Slider(
|
292 |
-
label="Top-p (nucleus sampling)",
|
293 |
-
minimum=0.05,
|
294 |
-
maximum=1.0,
|
295 |
-
step=0.05,
|
296 |
-
value=0.9,
|
297 |
-
),
|
298 |
-
gr.Slider(
|
299 |
-
label="Top-k",
|
300 |
-
minimum=1,
|
301 |
-
maximum=1000,
|
302 |
-
step=1,
|
303 |
-
value=50,
|
304 |
-
),
|
305 |
-
gr.Slider(
|
306 |
-
label="Hình phạt lặp lại (Repetition penalty)",
|
307 |
-
minimum=1.0,
|
308 |
-
maximum=2.0,
|
309 |
-
step=0.05,
|
310 |
-
value=1.2,
|
311 |
-
),
|
312 |
-
],
|
313 |
-
stop_btn=None, # Không có nút dừng
|
314 |
-
examples=[
|
315 |
-
["Xin chào! Bạn có khỏe không?"],
|
316 |
-
["Bạn có thể giải thích ngắn gọn về ngôn ngữ lập trình Python không?"],
|
317 |
-
["Giải thích cốt truyện của Cô bé Lọ Lem trong một câu."],
|
318 |
-
["Mất bao nhiêu giờ để một người ăn một chiếc trực thăng?"],
|
319 |
-
["Viết một bài báo 100 từ về 'Lợi ích của mã nguồn mở trong nghiên cứu AI'"],
|
320 |
-
],
|
321 |
-
cache_examples=False, # Không lưu trữ các ví dụ
|
322 |
)
|
|
|
323 |
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
|
|
|
|
|
|
|
|
329 |
|
330 |
-
# Hàm chính để khởi chạy ứng dụng
|
331 |
if __name__ == "__main__":
|
332 |
-
|
|
|
|
|
1 |
import os
|
2 |
+
import time
|
3 |
+
import requests
|
4 |
+
import random
|
5 |
from threading import Thread
|
6 |
+
from typing import List, Dict, Union
|
|
|
|
|
|
|
7 |
import torch
|
8 |
+
import gradio as gr
|
9 |
+
from bs4 import BeautifulSoup
|
|
|
10 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
11 |
+
from functools import lru_cache
|
12 |
+
import re
|
13 |
+
import io
|
14 |
+
import json
|
15 |
|
16 |
+
# Model Loading (Done once at startup)
|
17 |
+
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
|
18 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
19 |
+
model = AutoModelForCausalLM.from_pretrained(
|
20 |
+
MODEL_ID,
|
21 |
+
torch_dtype=torch.float16,
|
22 |
+
device_map="cuda",
|
23 |
+
low_cpu_mem_usage=True
|
24 |
+
).eval()
|
|
|
|
|
|
|
|
|
25 |
|
26 |
+
# Path to example texts (Updated to remove image and video examples)
|
27 |
+
examples_path = os.path.dirname(__file__)
|
28 |
+
EXAMPLES = [
|
29 |
+
[
|
30 |
+
{
|
31 |
+
"text": "What is Friction? Explain in Detail.",
|
32 |
+
}
|
33 |
+
],
|
34 |
+
[
|
35 |
+
{
|
36 |
+
"text": "Write me a Python function to generate unique passwords.",
|
37 |
+
}
|
38 |
+
],
|
39 |
+
[
|
40 |
+
{
|
41 |
+
"text": "What's the latest price of Bitcoin?",
|
42 |
+
}
|
43 |
+
],
|
44 |
+
[
|
45 |
+
{
|
46 |
+
"text": "Search and give me list of spaces trending on HuggingFace.",
|
47 |
+
}
|
48 |
+
],
|
49 |
+
[
|
50 |
+
{
|
51 |
+
"text": "Create a Beautiful Picture of Eiffel at Night.",
|
52 |
+
}
|
53 |
+
],
|
54 |
+
[
|
55 |
+
{
|
56 |
+
"text": "What unusual happens in this video.",
|
57 |
+
"files": [f"{examples_path}/example_video/accident.gif"],
|
58 |
+
}
|
59 |
+
],
|
60 |
+
# Removed other image and video related examples
|
61 |
+
]
|
62 |
|
63 |
+
# Set bot avatar image
|
64 |
+
BOT_AVATAR = "OpenAI_logo.png"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
+
# Perform a Google search and return the results
|
67 |
+
@lru_cache(maxsize=128)
|
68 |
def extract_text_from_webpage(html_content):
|
69 |
+
"""Extracts visible text from HTML content using BeautifulSoup."""
|
70 |
soup = BeautifulSoup(html_content, "html.parser")
|
|
|
71 |
for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
|
72 |
tag.extract()
|
73 |
+
visible_text = soup.get_text(separator=' ', strip=True)
|
|
|
74 |
return visible_text
|
75 |
|
|
|
76 |
def search(query):
|
77 |
term = query
|
78 |
all_results = []
|
79 |
+
max_chars_per_page = 8000
|
80 |
with requests.Session() as session:
|
|
|
81 |
resp = session.get(
|
82 |
url="https://www.google.com/search",
|
83 |
headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"},
|
84 |
+
params={"q": term, "num": 4},
|
85 |
timeout=5,
|
86 |
+
verify=False,
|
87 |
)
|
88 |
+
resp.raise_for_status()
|
89 |
soup = BeautifulSoup(resp.text, "html.parser")
|
90 |
+
result_block = soup.find_all("div", attrs={"class": "g"})
|
91 |
for result in result_block:
|
92 |
link = result.find("a", href=True)
|
93 |
+
if link and 'href' in link.attrs:
|
94 |
link = link["href"]
|
95 |
try:
|
96 |
+
webpage = session.get(link, headers={"User-Agent": "Mozilla/5.0"}, timeout=5, verify=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
webpage.raise_for_status()
|
98 |
visible_text = extract_text_from_webpage(webpage.text)
|
|
|
99 |
if len(visible_text) > max_chars_per_page:
|
100 |
visible_text = visible_text[:max_chars_per_page]
|
101 |
all_results.append({"link": link, "text": visible_text})
|
|
|
103 |
all_results.append({"link": link, "text": None})
|
104 |
return all_results
|
105 |
|
106 |
+
def generate_response(prompt, chat_history):
|
107 |
+
# Construct the conversation history
|
108 |
+
conversation = ""
|
109 |
+
for user, assistant in chat_history:
|
110 |
+
conversation += f"User: {user}\nAssistant: {assistant}\n"
|
111 |
+
conversation += f"User: {prompt}\nAssistant:"
|
112 |
+
|
113 |
+
inputs = tokenizer(conversation, return_tensors="pt").to("cuda")
|
114 |
+
|
115 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
116 |
+
thread = Thread(target=model.generate, args=(inputs.input_ids,), kwargs={
|
117 |
+
"max_new_tokens": 512,
|
118 |
+
"do_sample": True,
|
119 |
+
"top_p": 0.95,
|
120 |
+
"temperature": 0.8,
|
121 |
+
"streamer": streamer
|
122 |
+
})
|
123 |
+
thread.start()
|
124 |
+
|
125 |
+
response = ""
|
126 |
+
for new_text in streamer:
|
127 |
+
response += new_text
|
128 |
+
yield response
|
129 |
+
|
130 |
+
@lru_cache(maxsize=128)
|
131 |
+
def process_query(query, chat_history):
|
132 |
+
# Here you can implement logic to decide between web_search, general_query, or hard_query
|
133 |
+
# For simplicity, let's assume all queries go through general_query
|
134 |
+
# You can expand this with your own logic
|
135 |
+
functions_metadata = [
|
136 |
+
{
|
137 |
+
"type": "function",
|
138 |
+
"function": {
|
139 |
+
"name": "web_search",
|
140 |
+
"description": "Search query on google and find latest information.",
|
141 |
+
"parameters": {
|
142 |
+
"type": "object",
|
143 |
+
"properties": {
|
144 |
+
"query": {"type": "string", "description": "Web search query"}
|
145 |
+
},
|
146 |
+
"required": ["query"]
|
147 |
}
|
148 |
+
}
|
149 |
+
},
|
150 |
+
{
|
151 |
+
"type": "function",
|
152 |
+
"function": {
|
153 |
+
"name": "general_query",
|
154 |
+
"description": "Reply general query with LLM.",
|
155 |
+
"parameters": {
|
156 |
+
"type": "object",
|
157 |
+
"properties": {
|
158 |
+
"prompt": {"type": "string", "description": "A detailed prompt"}
|
159 |
+
},
|
160 |
+
"required": ["prompt"]
|
161 |
}
|
162 |
+
}
|
163 |
+
},
|
164 |
+
{
|
165 |
+
"type": "function",
|
166 |
+
"function": {
|
167 |
+
"name": "hard_query",
|
168 |
+
"description": "Reply tough query using powerful LLM.",
|
169 |
+
"parameters": {
|
170 |
+
"type": "object",
|
171 |
+
"properties": {
|
172 |
+
"prompt": {"type": "string", "description": "A detailed prompt"}
|
173 |
+
},
|
174 |
+
"required": ["prompt"]
|
175 |
}
|
176 |
+
}
|
177 |
+
},
|
178 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
+
# Example logic to choose function (you can customize this)
|
181 |
+
if "search" in query.lower():
|
182 |
+
function_name = "web_search"
|
183 |
+
elif "explain" in query.lower() or "detail" in query.lower():
|
184 |
+
function_name = "general_query"
|
185 |
+
else:
|
186 |
+
function_name = "hard_query"
|
187 |
|
188 |
+
return {
|
189 |
+
"name": function_name,
|
190 |
+
"arguments": {
|
191 |
+
"query" if function_name == "web_search" else "prompt": query
|
192 |
+
}
|
193 |
+
}
|
194 |
|
195 |
+
def handle_functions(function_call, chat_history):
|
196 |
+
function_name = function_call["name"]
|
197 |
+
arguments = function_call["arguments"]
|
198 |
|
199 |
+
if function_name == "web_search":
|
200 |
+
query = arguments["query"]
|
201 |
+
web_results = search(query)
|
202 |
+
web_summary = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results if res["text"]])
|
203 |
+
# Append web results to chat history or pass to the model as context
|
204 |
+
# Here we directly return the summarized web results
|
205 |
+
return f"Here are the search results:\n{web_summary}"
|
206 |
|
207 |
+
elif function_name in ["general_query", "hard_query"]:
|
208 |
+
prompt = arguments["prompt"]
|
209 |
+
# Generate response using the local model
|
210 |
+
response_generator = generate_response(prompt, chat_history)
|
211 |
+
return response_generator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
|
213 |
+
else:
|
214 |
+
return "Function not recognized."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
|
216 |
+
def model_inference(user_prompt, chat_history):
|
217 |
+
prompt = user_prompt["text"]
|
218 |
|
219 |
+
# Determine which function to call
|
220 |
+
function_call = process_query(prompt, chat_history)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
|
222 |
+
if function_call["name"] == "web_search":
|
223 |
+
yield "Performing web search..."
|
224 |
+
result = handle_functions(function_call, chat_history)
|
225 |
+
yield result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
|
227 |
+
elif function_call["name"] in ["general_query", "hard_query"]:
|
228 |
+
yield "Generating response..."
|
229 |
+
response_generator = handle_functions(function_call, chat_history)
|
230 |
+
for response in response_generator:
|
231 |
+
yield response
|
232 |
|
233 |
+
else:
|
234 |
+
yield "Invalid function call."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
|
236 |
+
# Create a chatbot interface
|
237 |
+
chatbot = gr.Chatbot(
|
238 |
+
label="OpenGPT-4o",
|
239 |
+
avatar_images=[None, BOT_AVATAR],
|
240 |
+
show_copy_button=True,
|
241 |
+
layout="panel",
|
242 |
+
height=400,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
)
|
244 |
+
input_box = gr.Textbox(label="Prompt")
|
245 |
|
246 |
+
iface = gr.Interface(
|
247 |
+
fn=model_inference,
|
248 |
+
inputs=[input_box, chatbot],
|
249 |
+
outputs=chatbot,
|
250 |
+
live=True,
|
251 |
+
examples=EXAMPLES,
|
252 |
+
title="OpenGPT-4o Chatbot",
|
253 |
+
description="A powerful AI assistant using local Llama-3.2 model.",
|
254 |
+
)
|
255 |
|
|
|
256 |
if __name__ == "__main__":
|
257 |
+
iface.launch()
|