hoduyquocbao commited on
Commit
00430c0
1 Parent(s): 5121e98

update new feature datasets

Browse files
Files changed (1) hide show
  1. app.py +139 -29
app.py CHANGED
@@ -1,6 +1,9 @@
 
1
  import os
2
  from threading import Thread
3
  from typing import Iterator, List, Tuple, Dict, Any
 
 
4
 
5
  import gradio as gr
6
  import spaces
@@ -8,8 +11,9 @@ import torch
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
9
  from bs4 import BeautifulSoup
10
  import requests
11
- import json
12
  from functools import lru_cache
 
 
13
 
14
  # ---------------------------- Cấu Hình ---------------------------- #
15
 
@@ -41,6 +45,60 @@ model.eval() # Đặt mô hình ở chế độ đánh giá
41
  # Khởi tạo pipeline phân tích tâm lý
42
  sentiment_pipeline = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  # ---------------------------- Định Nghĩa Hàm ---------------------------- #
45
 
46
  @lru_cache(maxsize=128)
@@ -103,7 +161,7 @@ def summarize_text(text: str, max_length: int = 150) -> str:
103
  ]
104
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
105
  input_ids = input_ids.to(device)
106
-
107
  summary_streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
108
  summary_kwargs = {
109
  "input_ids": input_ids,
@@ -115,7 +173,7 @@ def summarize_text(text: str, max_length: int = 150) -> str:
115
  }
116
  t = Thread(target=model.generate, kwargs=summary_kwargs)
117
  t.start()
118
-
119
  summary = ""
120
  for new_text in summary_streamer:
121
  summary += new_text
@@ -128,26 +186,43 @@ def analyze_sentiment(text: str) -> str:
128
  score = result[0]['score']
129
  return f"🟢 **Tâm lý**: {sentiment} (Điểm: {score:.2f})"
130
 
131
- def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
132
  """
133
  Tạo phản hồi sử dụng mô hình Llama cục bộ theo chế độ streaming.
134
  """
135
- # Xây dựng lịch sử cuộc trò chuyện
136
- conversation = []
137
- for user, assistant in chat_history:
138
- conversation.extend([
139
- {"role": "user", "content": user},
140
- {"role": "assistant", "content": assistant},
 
 
141
  ])
142
- conversation.append({"role": "user", "content": prompt}) # Thêm tin nhắn của người dùng
 
 
 
 
 
 
 
 
 
143
 
144
  # Chuẩn bị input_ids từ tokenizer
145
- input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
146
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
147
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] # Cắt input nếu quá dài
148
- gr.Warning(f"Đã cắt bỏ phần cuộc trò chuyện vì vượt quá {MAX_INPUT_TOKEN_LENGTH} token.")
 
 
 
 
 
 
149
  input_ids = input_ids.to(device) # Di chuyển input tới thiết bị
150
-
151
  # Khởi tạo streamer để nhận văn bản được tạo ra theo thời gian thực
152
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
153
  generate_kwargs = {
@@ -163,13 +238,17 @@ def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_
163
  }
164
  t = Thread(target=model.generate, kwargs=generate_kwargs) # Tạo luồng để sinh văn bản
165
  t.start()
166
-
167
  # Stream văn bản được tạo ra
168
  outputs = []
169
  for text in streamer:
170
  outputs.append(text)
171
  yield "".join(outputs)
172
 
 
 
 
 
173
  @lru_cache(maxsize=128)
174
  def process_query(query: str) -> Dict[str, Any]:
175
  """
@@ -180,9 +259,10 @@ def process_query(query: str) -> Dict[str, Any]:
180
  general_query_keywords = ["giải thích", "mô tả", "nói cho tôi biết về", "cái gì là", "cách nào"]
181
  summarize_keywords = ["tóm tắt", "tóm lại", "khái quát", "ngắn gọn"]
182
  sentiment_keywords = ["cảm xúc", "tâm trạng", "tâm lý", "phân tích cảm xúc"]
 
183
 
184
  query_lower = query.lower() # Chuyển truy vấn thành chữ thường để so sánh
185
-
186
  if any(keyword in query_lower for keyword in web_search_keywords):
187
  function_name = "web_search"
188
  arguments = {"query": query}
@@ -192,25 +272,28 @@ def process_query(query: str) -> Dict[str, Any]:
192
  elif any(keyword in query_lower for keyword in sentiment_keywords):
193
  function_name = "sentiment_analysis"
194
  arguments = {"prompt": query}
 
 
 
195
  elif any(keyword in query_lower for keyword in general_query_keywords):
196
  function_name = "general_query"
197
  arguments = {"prompt": query}
198
  else:
199
  function_name = "hard_query"
200
  arguments = {"prompt": query}
201
-
202
  return {
203
  "name": function_name,
204
  "arguments": arguments
205
  }
206
 
207
- def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
208
  """
209
  Thực thi hàm phù hợp dựa trên lời gọi hàm.
210
  """
211
  function_name = function_call["name"]
212
  arguments = function_call["arguments"]
213
-
214
  if function_name == "web_search":
215
  query = arguments["query"]
216
  yield "🔍 Đang thực hiện tìm kiếm trên web..."
@@ -222,10 +305,10 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
222
  web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 **Mô tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."])
223
  if not web_summary:
224
  web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm."
225
-
226
  # Trả về kết quả tìm kiếm cho người dùng
227
  yield "📄 **Kết quả tìm kiếm:**\n" + web_summary
228
-
229
  elif function_name == "summarize_query":
230
  # Khi người dùng yêu cầu tóm tắt, hệ thống sẽ thực hiện tìm kiếm và sau đó tóm tắt kết quả
231
  query = arguments["prompt"]
@@ -242,14 +325,22 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
242
  # Tóm tắt nội dung đã lấy
243
  yield "📝 Đang tóm tắt thông tin..."
244
  summary = summarize_text(combined_text)
 
 
245
  yield "📄 **Tóm tắt:**\n" + summary
246
-
247
  elif function_name == "sentiment_analysis":
248
  prompt_text = arguments["prompt"]
249
  yield "📊 Đang phân tích tâm lý..."
250
  sentiment = analyze_sentiment(prompt_text)
251
  yield sentiment
252
-
 
 
 
 
 
 
253
  elif function_name in ["general_query", "hard_query"]:
254
  prompt_text = arguments["prompt"]
255
  yield "🤖 Đang tạo phản hồi..."
@@ -257,6 +348,7 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
257
  response_generator = generate_response(
258
  prompt=prompt_text,
259
  chat_history=chat_history,
 
260
  max_new_tokens=max_new_tokens,
261
  temperature=temperature,
262
  top_p=top_p,
@@ -265,12 +357,23 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
265
  )
266
  for response in response_generator:
267
  yield response
268
-
269
  else:
270
  yield "⚠️ Lời gọi hàm không được nhận dạng."
271
 
272
  # ---------------------------- Giao Diện Gradio ---------------------------- #
273
 
 
 
 
 
 
 
 
 
 
 
 
274
  @spaces.GPU(duration=15, queue=False)
275
  def generate(
276
  message: str,
@@ -286,10 +389,13 @@ def generate(
286
  """
287
  # Thông báo về việc phân tích đầu vào
288
  yield "🔍 Đang phân tích truy vấn của bạn..."
289
-
 
 
 
290
  # Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng
291
  function_call = process_query(message)
292
-
293
  # Thông báo về hàm được chọn
294
  if function_call["name"] == "web_search":
295
  yield "🛠️ Đã chọn chức năng: Tìm kiếm trên web."
@@ -297,23 +403,27 @@ def generate(
297
  yield "🛠️ Đã chọn chức năng: Tóm tắt văn bản."
298
  elif function_call["name"] == "sentiment_analysis":
299
  yield "🛠️ Đã chọn chức năng: Phân tích tâm lý."
 
 
 
300
  elif function_call["name"] in ["general_query", "hard_query"]:
301
  yield "🛠️ Đã chọn chức năng: Trả lời câu hỏi."
302
  else:
303
  yield "⚠️ Không thể xác định chức năng phù hợp."
304
-
305
  # Xử lý lời gọi hàm và sinh phản hồi tương ứng
306
  response_iterator = handle_functions(
307
  function_call=function_call,
308
  prompt=message,
309
  chat_history=chat_history,
 
310
  max_new_tokens=max_new_tokens,
311
  temperature=temperature,
312
  top_p=top_p,
313
  top_k=top_k,
314
  repetition_penalty=repetition_penalty
315
  )
316
-
317
  for response in response_iterator:
318
  yield response
319
 
 
1
+
2
  import os
3
  from threading import Thread
4
  from typing import Iterator, List, Tuple, Dict, Any
5
+ import uuid
6
+ import json
7
 
8
  import gradio as gr
9
  import spaces
 
11
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
12
  from bs4 import BeautifulSoup
13
  import requests
 
14
  from functools import lru_cache
15
+ # from huggingface_hub import HfApi, HfFolder
16
+ from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets
17
 
18
  # ---------------------------- Cấu Hình ---------------------------- #
19
 
 
45
  # Khởi tạo pipeline phân tích tâm lý
46
  sentiment_pipeline = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
47
 
48
+ # ---------------------------- Thiết lập Bộ nhớ Sử dụng Huggingface Datasets ---------------------------- #
49
+
50
+ HF_DATASET = os.getenv("HF_DATASET") # Đảm bảo bạn đã set biến môi trường này "your_username/chat_memory" Thay đổi theo tên của bạn
51
+
52
+ def initialize_dataset():
53
+ """
54
+ Khởi tạo Dataset trên Huggingface Hub nếu chưa tồn tại.
55
+ """
56
+ try:
57
+ dataset = load_dataset(HF_DATASET)
58
+ except:
59
+ # Tạo Dataset mới nếu chưa tồn tại
60
+ dataset = DatasetDict({
61
+ "conversations": Dataset.from_dict({
62
+ "user_id": [],
63
+ "messages": []
64
+ })
65
+ })
66
+ dataset.push_to_hub(HF_DATASET, private=True)
67
+ return dataset
68
+
69
+ def save_conversation(user_id: str, messages: List[Tuple[str, str]]):
70
+ """
71
+ Lưu cuộc hội thoại của người dùng vào Dataset.
72
+ """
73
+ dataset = load_dataset(HF_DATASET)
74
+ # Chuyển đổi cuộc hội thoại thành định dạng JSON
75
+ messages_json = json.dumps(messages)
76
+ new_entry = {
77
+ "user_id": user_id,
78
+ "messages": messages_json
79
+ }
80
+ # Tạo Dataset từ entry mới
81
+ new_dataset = Dataset.from_dict(new_entry)
82
+ # Kết hợp với Dataset hiện tại
83
+ updated_dataset = concatenate_datasets([dataset["conversations"], new_dataset])
84
+ # Đẩy lên Hub
85
+ updated_dataset.push_to_hub(HF_DATASET, private=True)
86
+
87
+ def load_conversation(user_id: str) -> List[Tuple[str, str]]:
88
+ """
89
+ Truy xuất cuộc hội thoại của người dùng từ Dataset.
90
+ """
91
+ dataset = load_dataset(HF_DATASET)
92
+ # Tìm entry theo user_id
93
+ user_data = dataset["conversations"].filter(lambda x: x["user_id"] == user_id)
94
+ if len(user_data) == 0:
95
+ return []
96
+ messages_json = user_data["messages"][0]
97
+ return json.loads(messages_json)
98
+
99
+ # Khởi tạo Dataset
100
+ initialize_dataset()
101
+
102
  # ---------------------------- Định Nghĩa Hàm ---------------------------- #
103
 
104
  @lru_cache(maxsize=128)
 
161
  ]
162
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
163
  input_ids = input_ids.to(device)
164
+
165
  summary_streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
166
  summary_kwargs = {
167
  "input_ids": input_ids,
 
173
  }
174
  t = Thread(target=model.generate, kwargs=summary_kwargs)
175
  t.start()
176
+
177
  summary = ""
178
  for new_text in summary_streamer:
179
  summary += new_text
 
186
  score = result[0]['score']
187
  return f"🟢 **Tâm lý**: {sentiment} (Điểm: {score:.2f})"
188
 
189
+ def generate_response(prompt: str, chat_history: List[Tuple[str, str]], user_id: str, max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
190
  """
191
  Tạo phản hồi sử dụng mô hình Llama cục bộ theo chế độ streaming.
192
  """
193
+ # Lấy lịch sử từ Dataset
194
+ conversation = load_conversation(user_id)
195
+ # Chuyển đổi lịch sử thành định dạng mà mô hình hiểu
196
+ conversation_formatted = []
197
+ for user_msg, assistant_msg in conversation:
198
+ conversation_formatted.extend([
199
+ {"role": "user", "content": user_msg},
200
+ {"role": "assistant", "content": assistant_msg},
201
  ])
202
+ conversation_formatted.append({"role": "user", "content": prompt}) # Thêm tin nhắn của ngư��i dùng
203
+
204
+ # Kiểm tra độ dài và sử dụng bản tóm tắt nếu cần
205
+ if len(conversation_formatted) > 50: # Giới hạn số lượng tin nhắn, điều chỉnh tùy nhu cầu
206
+ summary = summarize_text(" ".join([msg["content"] for msg in conversation_formatted]))
207
+ # Lưu bản tóm tắt vào Dataset
208
+ new_messages = [("system", summary)]
209
+ save_conversation(user_id, new_messages)
210
+ # Giữ lại phần mới nhất
211
+ conversation_formatted = [{"role": "system", "content": summary}] + conversation_formatted[-25:]
212
 
213
  # Chuẩn bị input_ids từ tokenizer
214
+ input_ids = tokenizer.apply_chat_template(conversation_formatted, add_generation_prompt=True, return_tensors="pt")
215
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
216
+ # Sử dụng bản tóm tắt từ bộ nhớ
217
+ summary = summarize_text(" ".join([msg["content"] for msg in conversation_formatted]))
218
+ conversation_formatted = [{"role": "system", "content": summary}] + conversation_formatted[-(MAX_INPUT_TOKEN_LENGTH // 2):]
219
+ input_ids = tokenizer.apply_chat_template(conversation_formatted, add_generation_prompt=True, return_tensors="pt")
220
+ # Lưu lại bản tóm tắt
221
+ new_messages = [("system", summary)]
222
+ save_conversation(user_id, new_messages)
223
+
224
  input_ids = input_ids.to(device) # Di chuyển input tới thiết bị
225
+
226
  # Khởi tạo streamer để nhận văn bản được tạo ra theo thời gian thực
227
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
228
  generate_kwargs = {
 
238
  }
239
  t = Thread(target=model.generate, kwargs=generate_kwargs) # Tạo luồng để sinh văn bản
240
  t.start()
241
+
242
  # Stream văn bản được tạo ra
243
  outputs = []
244
  for text in streamer:
245
  outputs.append(text)
246
  yield "".join(outputs)
247
 
248
+ # Lưu phản hồi vào Dataset
249
+ response = "".join(outputs)
250
+ save_conversation(user_id, [(prompt, response)])
251
+
252
  @lru_cache(maxsize=128)
253
  def process_query(query: str) -> Dict[str, Any]:
254
  """
 
259
  general_query_keywords = ["giải thích", "mô tả", "nói cho tôi biết về", "cái gì là", "cách nào"]
260
  summarize_keywords = ["tóm tắt", "tóm lại", "khái quát", "ngắn gọn"]
261
  sentiment_keywords = ["cảm xúc", "tâm trạng", "tâm lý", "phân tích cảm xúc"]
262
+ topic_keywords = ["chủ đề", "bàn về", "về"]
263
 
264
  query_lower = query.lower() # Chuyển truy vấn thành chữ thường để so sánh
265
+
266
  if any(keyword in query_lower for keyword in web_search_keywords):
267
  function_name = "web_search"
268
  arguments = {"query": query}
 
272
  elif any(keyword in query_lower for keyword in sentiment_keywords):
273
  function_name = "sentiment_analysis"
274
  arguments = {"prompt": query}
275
+ elif any(keyword in query_lower for keyword in topic_keywords):
276
+ function_name = "new_topic"
277
+ arguments = {"topic": query}
278
  elif any(keyword in query_lower for keyword in general_query_keywords):
279
  function_name = "general_query"
280
  arguments = {"prompt": query}
281
  else:
282
  function_name = "hard_query"
283
  arguments = {"prompt": query}
284
+
285
  return {
286
  "name": function_name,
287
  "arguments": arguments
288
  }
289
 
290
+ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], user_id: str, max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
291
  """
292
  Thực thi hàm phù hợp dựa trên lời gọi hàm.
293
  """
294
  function_name = function_call["name"]
295
  arguments = function_call["arguments"]
296
+
297
  if function_name == "web_search":
298
  query = arguments["query"]
299
  yield "🔍 Đang thực hiện tìm kiếm trên web..."
 
305
  web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 **Mô tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."])
306
  if not web_summary:
307
  web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm."
308
+
309
  # Trả về kết quả tìm kiếm cho người dùng
310
  yield "📄 **Kết quả tìm kiếm:**\n" + web_summary
311
+
312
  elif function_name == "summarize_query":
313
  # Khi người dùng yêu cầu tóm tắt, hệ thống sẽ thực hiện tìm kiếm và sau đó tóm tắt kết quả
314
  query = arguments["prompt"]
 
325
  # Tóm tắt nội dung đã lấy
326
  yield "📝 Đang tóm tắt thông tin..."
327
  summary = summarize_text(combined_text)
328
+ # Lưu tóm tắt vào Dataset
329
+ save_conversation(user_id, [("tóm tắt", summary)])
330
  yield "📄 **Tóm tắt:**\n" + summary
331
+
332
  elif function_name == "sentiment_analysis":
333
  prompt_text = arguments["prompt"]
334
  yield "📊 Đang phân tích tâm lý..."
335
  sentiment = analyze_sentiment(prompt_text)
336
  yield sentiment
337
+
338
+ elif function_name == "new_topic":
339
+ topic = arguments["topic"]
340
+ # Lưu chủ đề mới vào Dataset
341
+ save_conversation(user_id, [("chủ đề", f"Chủ đề mới: {topic}")])
342
+ yield f"🆕 Đã chuyển sang chủ đề mới: {topic}"
343
+
344
  elif function_name in ["general_query", "hard_query"]:
345
  prompt_text = arguments["prompt"]
346
  yield "🤖 Đang tạo phản hồi..."
 
348
  response_generator = generate_response(
349
  prompt=prompt_text,
350
  chat_history=chat_history,
351
+ user_id=user_id,
352
  max_new_tokens=max_new_tokens,
353
  temperature=temperature,
354
  top_p=top_p,
 
357
  )
358
  for response in response_generator:
359
  yield response
360
+
361
  else:
362
  yield "⚠️ Lời gọi hàm không được nhận dạng."
363
 
364
  # ---------------------------- Giao Diện Gradio ---------------------------- #
365
 
366
+ def get_user_id():
367
+ """
368
+ Tạo hoặc lấy user_id từ session state của Gradio.
369
+ Sử dụng cookie hoặc thông tin định danh tạm thời.
370
+ """
371
+ # Gradio hiện không hỗ trợ session state natively, cần sử dụng workaround
372
+ # Dưới đây là cách tạo user_id tạm thời cho mỗi phiên
373
+ if "user_id" not in gr.get_session_state():
374
+ gr.get_session_state()["user_id"] = str(uuid.uuid4())
375
+ return gr.get_session_state()["user_id"]
376
+
377
  @spaces.GPU(duration=15, queue=False)
378
  def generate(
379
  message: str,
 
389
  """
390
  # Thông báo về việc phân tích đầu vào
391
  yield "🔍 Đang phân tích truy vấn của bạn..."
392
+
393
+ # Lấy user_id từ session
394
+ user_id = get_user_id()
395
+
396
  # Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng
397
  function_call = process_query(message)
398
+
399
  # Thông báo về hàm được chọn
400
  if function_call["name"] == "web_search":
401
  yield "🛠️ Đã chọn chức năng: Tìm kiếm trên web."
 
403
  yield "🛠️ Đã chọn chức năng: Tóm tắt văn bản."
404
  elif function_call["name"] == "sentiment_analysis":
405
  yield "🛠️ Đã chọn chức năng: Phân tích tâm lý."
406
+ elif function_call["name"] == "new_topic":
407
+ yield "🛠️ Đã chọn chức năng: Chủ đề mới."
408
+ elif function_call["name"] in ["general_query", "hard
409
  elif function_call["name"] in ["general_query", "hard_query"]:
410
  yield "🛠️ Đã chọn chức năng: Trả lời câu hỏi."
411
  else:
412
  yield "⚠️ Không thể xác định chức năng phù hợp."
413
+
414
  # Xử lý lời gọi hàm và sinh phản hồi tương ứng
415
  response_iterator = handle_functions(
416
  function_call=function_call,
417
  prompt=message,
418
  chat_history=chat_history,
419
+ user_id=user_id, # Sử dụng user_id để quản lý dữ liệu theo người dùng
420
  max_new_tokens=max_new_tokens,
421
  temperature=temperature,
422
  top_p=top_p,
423
  top_k=top_k,
424
  repetition_penalty=repetition_penalty
425
  )
426
+
427
  for response in response_iterator:
428
  yield response
429