File size: 24,830 Bytes
07f44f8
bf64382
 
8d8f4b0
 
e51a5b0
8d8f4b0
 
999b0b0
8d8f4b0
 
07f44f8
8d8f4b0
999b0b0
 
 
8d8f4b0
 
 
 
 
 
 
 
 
 
 
 
 
a0546fe
fb114fa
8d8f4b0
 
 
07f44f8
8d8f4b0
e51a5b0
8d8f4b0
c3f15f3
8d8f4b0
 
 
 
bf64382
b09c08f
 
 
8d8f4b0
fb114fa
8d8f4b0
 
 
c56e31d
8d8f4b0
c56e31d
 
8d8f4b0
e51a5b0
c56e31d
 
8d8f4b0
 
c56e31d
 
8d8f4b0
 
 
 
c56e31d
8d8f4b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c56e31d
 
d4f4893
 
 
 
 
 
 
07f44f8
d4f4893
 
 
 
 
 
 
 
 
 
 
07f44f8
d4f4893
 
 
 
 
b09c08f
 
 
 
 
 
 
07f44f8
8d8f4b0
 
 
07f44f8
 
 
 
 
 
8d8f4b0
07f44f8
 
8d8f4b0
07f44f8
8d8f4b0
07f44f8
 
8d8f4b0
07f44f8
8d8f4b0
 
 
 
 
 
e51a5b0
8d8f4b0
 
 
 
 
 
 
 
07f44f8
8d8f4b0
 
 
 
07f44f8
00430c0
b09c08f
8d8f4b0
 
 
 
 
 
 
d4f4893
b09c08f
d4f4893
8d8f4b0
07f44f8
8d8f4b0
e51a5b0
8d8f4b0
d4f4893
 
 
b09c08f
 
 
 
 
 
e51a5b0
 
8d8f4b0
07f44f8
e51a5b0
 
8d8f4b0
e51a5b0
c56e31d
07f44f8
8d8f4b0
 
 
e51a5b0
 
07f44f8
e51a5b0
 
8d8f4b0
e51a5b0
8d8f4b0
 
 
 
 
 
 
07f44f8
8d8f4b0
 
07f44f8
d4f4893
5121e98
 
 
 
 
 
 
 
 
 
 
 
 
d4f4893
5121e98
d4f4893
07f44f8
b09c08f
 
 
 
 
07f44f8
e51a5b0
8d8f4b0
 
 
 
 
 
 
 
 
 
 
 
 
 
07f44f8
e51a5b0
8d8f4b0
c56e31d
8d8f4b0
b04fed6
b904b20
8d8f4b0
 
 
 
 
 
 
 
07f44f8
8d8f4b0
 
 
d4f4893
07f44f8
a0546fe
07f44f8
8d8f4b0
 
07f44f8
d4f4893
 
07f44f8
d4f4893
07f44f8
b09c08f
9c3f080
07f44f8
d4f4893
07f44f8
d4f4893
07f44f8
 
8d8f4b0
 
 
 
 
 
 
 
 
 
 
07f44f8
8d8f4b0
07f44f8
46030b6
8d8f4b0
 
 
 
 
 
 
 
 
d4f4893
b09c08f
8d8f4b0
bf64382
d4f4893
07f44f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
999b0b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b904b20
999b0b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07f44f8
d4f4893
 
b09c08f
d4f4893
 
b09c08f
d4f4893
 
b09c08f
 
d4f4893
 
b09c08f
 
 
 
 
 
 
d4f4893
 
8d8f4b0
07f44f8
 
fb114fa
 
b904b20
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589

import os
from threading import Thread
from typing import Iterator, List, Tuple, Dict, Any

import gradio as gr
import spaces
import torch
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling, TrainerCallback,AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
from bs4 import BeautifulSoup
import requests
import json
from functools import lru_cache
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
import time

# ---------------------------- Cấu Hình ---------------------------- #

DESCRIPTION = """\
# Llama 3.2 3B Instruct với Chức Năng Nâng Cao

Llama 3.2 3B là phiên bản mới nhất của Meta về các mô hình ngôn ngữ mở.
Demo này giới thiệu [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), được tinh chỉnh để theo dõi hướng dẫn.
Để biết thêm chi tiết, vui lòng xem [bài viết của chúng tôi](https://huggingface.co/blog/llama32).
"""

MAX_MAX_NEW_TOKENS = 2048  # Số token tối đa có thể tạo ra
DEFAULT_MAX_NEW_TOKENS = 1024  # Số token tạo ra mặc định
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "128000"))  # Độ dài token tối đa cho đầu vào

# Xác định thiết bị sử dụng (GPU nếu có, ngược lại CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "meta-llama/Llama-3.2-3B-Instruct"  # ID mô hình, đảm bảo đây là ID mô hình đúng
tokenizer = AutoTokenizer.from_pretrained(model_id)  # Tải tokenizer từ Hugging Face
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,  # Sử dụng dtype phù hợp để tiết kiệm bộ nhớ
)
model.to(device)  # Di chuyển mô hình tới thiết bị đã chọn
model.eval()  # Đặt mô hình ở chế độ đánh giá

# Khởi tạo pipeline phân tích tâm lý
sentiment_pipeline = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")

# ---------------------------- Định Nghĩa Hàm ---------------------------- #

@lru_cache(maxsize=128)
def extract_text_from_webpage(html_content: str) -> str:
    """Trích xuất văn bản hiển thị từ nội dung HTML sử dụng BeautifulSoup."""
    soup = BeautifulSoup(html_content, "html.parser")
    # Loại bỏ các thẻ không hiển thị như script, style, header, footer, nav, form, svg
    for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
        tag.extract()
    # Trích xuất văn bản hiển thị, tách bằng dấu cách và loại bỏ khoảng trắng thừa
    visible_text = soup.get_text(separator=' ', strip=True)
    return visible_text

def search(query: str) -> List[Dict[str, Any]]:
    """Thực hiện tìm kiếm trên Google và trả về kết quả."""
    term = query
    all_results = []
    max_chars_per_page = 8000  # Số ký tự tối đa mỗi trang
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
    }
    with requests.Session() as session:
        try:
            resp = session.get(
                url="https://www.google.com/search",
                headers=headers,
                params={"q": term, "num": 4},  # Tìm kiếm với 4 kết quả mỗi trang
                timeout=5,
                verify=False,  # Bỏ qua xác minh SSL
            )
            resp.raise_for_status()  # Kiểm tra phản hồi HTTP
            soup = BeautifulSoup(resp.text, "html.parser")
            result_blocks = soup.find_all("div", attrs={"class": "g"})  # Tìm tất cả các khối kết quả
            for result in result_blocks:
                link_tag = result.find("a", href=True)  # Tìm thẻ liên kết
                if link_tag and 'href' in link_tag.attrs:
                    link = link_tag["href"]
                    try:
                        webpage = session.get(
                            link,
                            headers=headers,
                            timeout=5,
                            verify=False
                        )
                        webpage.raise_for_status()
                        visible_text = extract_text_from_webpage(webpage.text)
                        if len(visible_text) > max_chars_per_page:
                            visible_text = visible_text[:max_chars_per_page]  # Cắt văn bản nếu quá dài
                        all_results.append({"link": link, "text": visible_text})
                    except requests.exceptions.RequestException:
                        all_results.append({"link": link, "text": "Không thể lấy nội dung."})
        except requests.exceptions.RequestException as e:
            all_results.append({"link": "N/A", "text": "Không thể thực hiện tìm kiếm."})
    return all_results

def summarize_text(text: str, max_length: int = 150) -> str:
    """Tóm tắt văn bản sử dụng mô hình Llama."""
    conversation = [
        {"role": "user", "content": f"Hãy tóm tắt đoạn văn sau: {text}"}
    ]
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    input_ids = input_ids.to(device)
    
    summary_streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    summary_kwargs = {
        "input_ids": input_ids,
        "streamer": summary_streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": 0.95,
        "temperature": 0.7,
    }
    t = Thread(target=model.generate, kwargs=summary_kwargs)
    t.start()
    
    summary = ""
    for new_text in summary_streamer:
        summary += new_text
    return summary

def analyze_sentiment(text: str) -> str:
    """Phân tích tâm lý của văn bản sử dụng mô hình."""
    result = sentiment_pipeline(text)
    sentiment = result[0]['label']
    score = result[0]['score']
    return f"🟢 **Tâm lý**: {sentiment} (Điểm: {score:.2f})"

def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
    """
    Tạo phản hồi sử dụng mô hình Llama cục bộ theo chế độ streaming.
    """
    # Xây dựng lịch sử cuộc trò chuyện
    conversation = []
    for user, assistant in chat_history:
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ])
    conversation.append({"role": "user", "content": prompt})  # Thêm tin nhắn của người dùng
    
    # Chuẩn bị input_ids từ tokenizer
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]  # Cắt input nếu quá dài
        gr.Warning(f"Đã cắt bỏ phần cuộc trò chuyện vì vượt quá {MAX_INPUT_TOKEN_LENGTH} token.")
    input_ids = input_ids.to(device)  # Di chuyển input tới thiết bị
    
    # Khởi tạo streamer để nhận văn bản được tạo ra theo thời gian thực
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    t = Thread(target=model.generate, kwargs=generate_kwargs)  # Tạo luồng để sinh văn bản
    t.start()
    
    # Stream văn bản được tạo ra
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

@lru_cache(maxsize=128)
def process_query(query: str) -> Dict[str, Any]:
    """
    Xác định hàm nào sẽ được gọi dựa trên truy vấn của người dùng.
    """
    # Định nghĩa các từ khóa hoặc mẫu để xác định hàm
    web_search_keywords = ["tìm kiếm", "tìm", "tra cứu", "google", "lookup"]
    general_query_keywords = ["giải thích", "mô tả", "nói cho tôi biết về", "cái gì là", "cách nào"]
    summarize_keywords = ["tóm tắt", "tóm lại", "khái quát", "ngắn gọn"]
    sentiment_keywords = ["cảm xúc", "tâm trạng", "tâm lý", "phân tích cảm xúc"]

    query_lower = query.lower()  # Chuyển truy vấn thành chữ thường để so sánh
    
    if any(keyword in query_lower for keyword in web_search_keywords):
        function_name = "web_search"
        arguments = {"query": query}
    elif any(keyword in query_lower for keyword in summarize_keywords):
        function_name = "summarize_query"
        arguments = {"prompt": query}
    elif any(keyword in query_lower for keyword in sentiment_keywords):
        function_name = "sentiment_analysis"
        arguments = {"prompt": query}
    elif any(keyword in query_lower for keyword in general_query_keywords):
        function_name = "general_query"
        arguments = {"prompt": query}
    else:
        function_name = "hard_query"
        arguments = {"prompt": query}
    
    return {
        "name": function_name,
        "arguments": arguments
    }

def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
    """
    Thực thi hàm phù hợp dựa trên lời gọi hàm.
    """
    function_name = function_call["name"]
    arguments = function_call["arguments"]
    
    if function_name == "web_search":
        query = arguments["query"]
        yield "🔍 Đang thực hiện tìm kiếm trên web..."
        web_results = search(query)
        if not web_results:
            yield "⚠️ Không tìm thấy kết quả."
            return
        # Tóm tắt kết quả tìm kiếm
        web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 **Mô tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."])
        if not web_summary:
            web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm."
        
        # Trả về kết quả tìm kiếm cho người dùng
        yield "📄 **Kết quả tìm kiếm:**\n" + web_summary
    
    elif function_name == "summarize_query":
        # Khi người dùng yêu cầu tóm tắt, hệ thống sẽ thực hiện tìm kiếm và sau đó tóm tắt kết quả
        query = arguments["prompt"]
        yield "🔍 Đang thực hiện tìm kiếm để tóm tắt..."
        web_results = search(query)
        if not web_results:
            yield "⚠️ Không tìm thấy kết quả để tóm tắt."
            return
        # Lấy nội dung từ kết quả tìm kiếm để tóm tắt
        combined_text = ' '.join([res['text'] for res in web_results if res['text'] != "Không thể lấy nội dung."])
        if not combined_text:
            yield "⚠️ Không có nội dung để tóm tắt."
            return
        # Tóm tắt nội dung đã lấy
        yield "📝 Đang tóm tắt thông tin..."
        summary = summarize_text(combined_text)
        yield "📄 **Tóm tắt:**\n" + summary
    
    elif function_name == "sentiment_analysis":
        prompt_text = arguments["prompt"]
        yield "📊 Đang phân tích tâm lý..."
        sentiment = analyze_sentiment(prompt_text)
        yield sentiment
    
    elif function_name in ["general_query", "hard_query"]:
        prompt_text = arguments["prompt"]
        yield "🤖 Đang tạo phản hồi..."
        # Tạo phản hồi sử dụng mô hình Llama
        response_generator = generate_response(
            prompt=prompt_text,
            chat_history=chat_history,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty
        )
        for response in response_generator:
            yield response
    
    else:
        yield "⚠️ Lời gọi hàm không được nhận dạng."

# ---------------------------- Giao Diện Gradio ---------------------------- #

@spaces.GPU(duration=30, queue=False)
def generate(
    message: str,
    chat_history: List[Tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """
    Hàm chính để xử lý đầu vào của người dùng và tạo phản hồi.
    """
    # Thông báo về việc phân tích đầu vào
    yield "🔍 Đang phân tích truy vấn của bạn..."

    
    # Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng
    function_call = process_query(message)
    
    # Thông báo về hàm được chọn
    if function_call["name"] == "web_search":
        yield "🛠️ Đã chọn chức năng: Tìm kiếm trên web."
    elif function_call["name"] == "summarize_query":
        yield "🛠️ Đã chọn chức năng: Tóm tắt văn bản."
    elif function_call["name"] == "sentiment_analysis":
        continuous_training(total_steps=300, steps_per_call=50)
        yield "🛠️ Đã chọn chức năng: Phân tích tâm lý."
    elif function_call["name"] in ["general_query", "hard_query"]:
        yield "🛠️ Đã chọn chức năng: Trả lời câu hỏi."
    else:
        yield "⚠️ Không thể xác định chức năng phù hợp."
    
    # Xử lý lời gọi hàm và sinh phản hồi tương ứng
    response_iterator = handle_functions(
        function_call=function_call,
        prompt=message,
        chat_history=chat_history,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty
    )
    
    for response in response_iterator:
        yield response

# Định nghĩa các ví dụ để hướng dẫn người dùng
EXAMPLES = [
    ["Xin chào! Bạn khỏe không?"],
    ["Bạn có thể giải thích ngắn gọn về ngôn ngữ lập trình Python không?"],
    ["Giải thích cốt truyện của Cô bé Lọ Lem trong một câu."],
    ["Một người đàn ông cần bao nhiêu giờ để ăn một chiếc máy bay trực thăng?"],
    ["Viết một bài báo 100 từ về 'Lợi ích của mã nguồn mở trong nghiên cứu AI'"],
    ["Tìm và cung cấp cho tôi tin tức mới nhất về năng lượng tái tạo."],
    ["Tìm thông tin về Rạn san hô Great Barrier Reef."],
    ["Tóm tắt nội dung về trí tuệ nhân tạo."],
    ["Phân tích tâm lý của đoạn văn sau: Tôi rất vui khi được gặp bạn hôm nay!"],
]

# Cấu hình giao diện trò chuyện của Gradio với giao diện đẹp mắt
chat_interface = gr.ChatInterface(
    fn=generate,  # Hàm được gọi khi có tương tác từ người dùng
    additional_inputs=[
        gr.Slider(
            label="Số token mới tối đa",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Nhiệt độ",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Hình phạt sự lặp lại",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,  # Không có nút dừng
    examples=EXAMPLES,  # Các ví dụ được hiển thị cho người dùng
    cache_examples=False,  # Không lưu bộ nhớ cache cho các ví dụ
    title="🤖 OpenGPT-4o Chatbot",
    description="Một trợ lý AI mạnh mẽ sử dụng mô hình Llama-3.2 cục bộ với các chức năng tìm kiếm web, tóm tắt văn bản và phân tích tâm lý.",
    theme="default",  # Có thể thay đổi theme để giao diện đẹp hơn
)


# Đường dẫn lưu checkpoint
CHECKPOINT_DIR = "./checkpoints"
if not os.path.exists(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)

# Tải Dataset (CPU)
dataset = load_dataset('vntc/wiki-mini-corpus')

# Chia Dataset thành train và validation (CPU)
split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
train_dataset = split_dataset['train']
validation_dataset = split_dataset['test']

# Tiền Xử Lý Văn Bản (CPU)
def preprocess_function(examples):
    passages = [passage.lower().strip() for passage in examples['passage']]
    return {'passage': passages}

processed_train = train_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
processed_validation = validation_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])

# Tokenization (CPU)
model_name = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Đảm bảo tokenizer có pad_token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    return tokenizer(
        examples['passage'],
        padding='max_length',
        truncation=True,
        max_length=512,
    )

tokenized_train = processed_train.map(tokenize_function, batched=True)
tokenized_validation = processed_validation.map(tokenize_function, batched=True)

# Thêm trường 'labels' (CPU)
def add_labels(examples):
    examples['labels'] = examples['input_ids'].copy()
    return examples

tokenized_train = tokenized_train.map(add_labels, batched=True)
tokenized_validation = tokenized_validation.map(add_labels, batched=True)

# Loại bỏ các cột không cần thiết (CPU)
tokenized_train = tokenized_train.remove_columns(['passage'])
tokenized_validation = tokenized_validation.remove_columns(['passage'])

# Định dạng dữ liệu cho PyTorch (CPU)
tokenized_train.set_format('torch')
tokenized_validation.set_format('torch')

# Tạo DatasetDict (CPU)
final_dataset = {
    'train': tokenized_train,
    'validation': tokenized_validation
}

# Định Nghĩa TrainerCallback để Lưu Checkpoint Nhanh Hơn
class SaveCheckpointCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.save_steps == 0 and state.global_step != 0:
            checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
            print(f"Lưu checkpoint tại: {checkpoint_path}")
            trainer = kwargs['trainer']  # Truy cập trainer từ kwargs
            trainer.save_model(checkpoint_path)
        return control  # Trả về đối tượng control hiện tại

# Định Nghĩa Hàm Huấn Luyện với Decorator @spaces.GPU
@spaces.GPU(duration=30, queue=False)
def run_training():
    """
    Hàm huấn luyện mô hình sử dụng GPU với thời gian hạn chế.
    """
    # Tải và Cấu Hình Mô Hình với LoRA (GPU)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        load_in_8bit=False
    )
    
    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        lora_dropout=0.1,
        bias="none",
    )
    
    model = get_peft_model(model, lora_config)
    print(model)
    
    # Cấu Hình TrainingArguments (GPU)
    training_args = TrainingArguments(
        output_dir=CHECKPOINT_DIR,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        max_steps=50,  # Đặt max_steps tại đây
        learning_rate=3e-4,
        weight_decay=0.01,
        logging_steps=5,  # Giảm số bước logging để theo dõi thường xuyên hơn
        eval_strategy="steps",  # Đánh giá sau mỗi vài bước
        eval_steps=5,  # Đánh giá sau mỗi 50 bước
        save_strategy="steps",  # Lưu checkpoint sau mỗi vài bước
        save_steps=5,  # Lưu checkpoint sau mỗi 50 bước
        save_total_limit=5,  # Giới hạn số lượng checkpoint lưu trữ
        fp16=True,
        report_to="none",
        load_best_model_at_end=True,
    )
    
    # Data Collator (GPU)
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, 
        mlm=False,  # Vì bạn đang thực hiện Causal LM
        pad_to_multiple_of=8
    )
    
    # Tạo Trainer (GPU)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=final_dataset['train'],
        eval_dataset=final_dataset['validation'],
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[SaveCheckpointCallback()],  # Thêm callback
    )
    
    # Kiểm tra nếu có checkpoint
    checkpoints = [os.path.join(CHECKPOINT_DIR, d) for d in os.listdir(CHECKPOINT_DIR) if d.startswith('checkpoint')]
    if checkpoints:
        latest_checkpoint = max(checkpoints, key=os.path.getctime)
        print(f"Đang tiếp tục huấn luyện từ checkpoint: {latest_checkpoint}")
        trainer.train(resume_from_checkpoint=latest_checkpoint)
    else:
        trainer.train()
    
    # Lưu checkpoint sau khi huấn luyện
    trainer.save_model(CHECKPOINT_DIR)
    return "Huấn luyện hoàn tất hoặc đã tiếp tục từ checkpoint."

# Hàm Tự Động Hóa Việc Gọi Lặp Lại Hàm Huấn Luyện
def continuous_training(total_steps=300, steps_per_call=5):
    """
    Hàm tự động gọi lại `run_training` để hoàn thành quá trình huấn luyện.
    
    Args:
        total_steps (int): Tổng số bước huấn luyện mong muốn.
        steps_per_call (int): Số bước huấn luyện mỗi lần gọi hàm.
    """
    steps_done = 0
    while steps_done < total_steps:
        print(f"Bắt đầu huấn luyện cho {steps_per_call} bước.")
        result = run_training()
        print(result)
        steps_done += steps_per_call
        print(f"Đã huấn luyện {steps_done} / {total_steps} bước.")
        
        # Kiểm tra nếu đã đạt số bước mong muốn
        if steps_done >= total_steps:
            print("Đã hoàn thành toàn bộ quá trình huấn luyện.")
            break
        
        # Chờ một khoảng thời gian trước khi gọi lại (tùy thuộc vào yêu cầu của hệ thống)
        time.sleep(2)  # Thời gian chờ có thể điều chỉnh

# Tạo giao diện chính của Gradio với CSS tùy chỉnh
with gr.Blocks(css="""
    .gradio-container {
        background-color: #f0f2f5;  /* Màu nền nhẹ nhàng */
    }
    .gradio-container h1 {
        color: #4a90e2;  /* Màu xanh dương cho tiêu đề */
    }
    .gradio-container .gr-button {
        background-color: #4a90e2;  /* Màu xanh dương cho nút */
        color: white;  /* Màu chữ trắng trên nút */
    }
    .gradio-container .gr-slider__label {
        color: #333333;  /* Màu chữ đen cho nhãn slider */
    }
    .gradio-container .gr-chatbot {
        border: 2px solid #4a90e2;  /* Viền xanh dương cho chatbot */
        border-radius: 10px;  /* Bo góc viền chatbot */
        padding: 10px;  /* Khoảng cách bên trong chatbot */
        background-color: #ffffff;  /* Màu nền trắng cho chatbot */
    }
""", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)  # Hiển thị mô tả
    gr.DuplicateButton(value="Nhân bản Không gian để sử dụng riêng tư", elem_id="duplicate-button")  # Nút nhân bản không gian
    chat_interface.render()  # Hiển thị giao diện trò chuyện

if __name__ == "__main__":
    demo.queue(max_size=30).launch()  # Khởi chạy ứng dụng Gradio với hàng đợi kích thước tối đa là 20