hoduyquocbao commited on
Commit
999b0b0
1 Parent(s): 309b446

join code to app

Browse files
Files changed (1) hide show
  1. app.py +179 -2
app.py CHANGED
@@ -6,12 +6,14 @@ from typing import Iterator, List, Tuple, Dict, Any
6
  import gradio as gr
7
  import spaces
8
  import torch
9
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
10
  from bs4 import BeautifulSoup
11
  import requests
12
  import json
13
  from functools import lru_cache
14
- from checkpoint import continuous_training
 
 
15
 
16
  # ---------------------------- Cấu Hình ---------------------------- #
17
 
@@ -382,6 +384,181 @@ chat_interface = gr.ChatInterface(
382
  theme="default", # Có thể thay đổi theme để giao diện đẹp hơn
383
  )
384
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
  # Tạo giao diện chính của Gradio với CSS tùy chỉnh
386
  with gr.Blocks(css="""
387
  .gradio-container {
 
6
  import gradio as gr
7
  import spaces
8
  import torch
9
+ from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling, TrainerCallback,AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
10
  from bs4 import BeautifulSoup
11
  import requests
12
  import json
13
  from functools import lru_cache
14
+ from datasets import load_dataset
15
+ from peft import LoraConfig, get_peft_model
16
+ import time
17
 
18
  # ---------------------------- Cấu Hình ---------------------------- #
19
 
 
384
  theme="default", # Có thể thay đổi theme để giao diện đẹp hơn
385
  )
386
 
387
+
388
+ # Đường dẫn lưu checkpoint
389
+ CHECKPOINT_DIR = "./checkpoints"
390
+ if not os.path.exists(CHECKPOINT_DIR):
391
+ os.makedirs(CHECKPOINT_DIR)
392
+
393
+ # Tải Dataset (CPU)
394
+ dataset = load_dataset('vntc/wiki-mini-corpus')
395
+
396
+ # Chia Dataset thành train và validation (CPU)
397
+ split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
398
+ train_dataset = split_dataset['train']
399
+ validation_dataset = split_dataset['test']
400
+
401
+ # Tiền Xử Lý Văn Bản (CPU)
402
+ def preprocess_function(examples):
403
+ passages = [passage.lower().strip() for passage in examples['passage']]
404
+ return {'passage': passages}
405
+
406
+ processed_train = train_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
407
+ processed_validation = validation_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
408
+
409
+ # Tokenization (CPU)
410
+ model_name = "meta-llama/Llama-3.2-3B-Instruct"
411
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
412
+
413
+ # Đảm bảo tokenizer có pad_token
414
+ if tokenizer.pad_token is None:
415
+ tokenizer.pad_token = tokenizer.eos_token
416
+
417
+ def tokenize_function(examples):
418
+ return tokenizer(
419
+ examples['passage'],
420
+ padding='max_length',
421
+ truncation=True,
422
+ max_length=512,
423
+ )
424
+
425
+ tokenized_train = processed_train.map(tokenize_function, batched=True)
426
+ tokenized_validation = processed_validation.map(tokenize_function, batched=True)
427
+
428
+ # Thêm trường 'labels' (CPU)
429
+ def add_labels(examples):
430
+ examples['labels'] = examples['input_ids'].copy()
431
+ return examples
432
+
433
+ tokenized_train = tokenized_train.map(add_labels, batched=True)
434
+ tokenized_validation = tokenized_validation.map(add_labels, batched=True)
435
+
436
+ # Loại bỏ các cột không cần thiết (CPU)
437
+ tokenized_train = tokenized_train.remove_columns(['passage'])
438
+ tokenized_validation = tokenized_validation.remove_columns(['passage'])
439
+
440
+ # Định dạng dữ liệu cho PyTorch (CPU)
441
+ tokenized_train.set_format('torch')
442
+ tokenized_validation.set_format('torch')
443
+
444
+ # Tạo DatasetDict (CPU)
445
+ final_dataset = {
446
+ 'train': tokenized_train,
447
+ 'validation': tokenized_validation
448
+ }
449
+
450
+ # Định Nghĩa TrainerCallback để Lưu Checkpoint Nhanh Hơn
451
+ class SaveCheckpointCallback(TrainerCallback):
452
+ def on_step_end(self, args, state, control, **kwargs):
453
+ if state.global_step % args.save_steps == 0 and state.global_step != 0:
454
+ checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
455
+ print(f"Lưu checkpoint tại: {checkpoint_path}")
456
+ trainer = kwargs['trainer'] # Truy cập trainer từ kwargs
457
+ trainer.save_model(checkpoint_path)
458
+ return control # Trả về đối tượng control hiện tại
459
+
460
+ # Định Nghĩa Hàm Huấn Luyện với Decorator @spaces.GPU
461
+ @spaces.GPU(duration=15, queue=False)
462
+ def run_training():
463
+ """
464
+ Hàm huấn luyện mô hình sử dụng GPU với thời gian hạn chế.
465
+ """
466
+ # Tải và Cấu Hình Mô Hình với LoRA (GPU)
467
+ model = AutoModelForCausalLM.from_pretrained(
468
+ model_name,
469
+ device_map="auto",
470
+ torch_dtype=torch.float16,
471
+ load_in_8bit=False
472
+ )
473
+
474
+ lora_config = LoraConfig(
475
+ r=8,
476
+ lora_alpha=32,
477
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
478
+ lora_dropout=0.1,
479
+ bias="none",
480
+ )
481
+
482
+ model = get_peft_model(model, lora_config)
483
+ print(model)
484
+
485
+ # Cấu Hình TrainingArguments (GPU)
486
+ training_args = TrainingArguments(
487
+ output_dir=CHECKPOINT_DIR,
488
+ per_device_train_batch_size=4,
489
+ per_device_eval_batch_size=4,
490
+ gradient_accumulation_steps=8,
491
+ num_train_epochs=3,
492
+ max_steps=50, # Đặt max_steps tại đây
493
+ learning_rate=3e-4,
494
+ weight_decay=0.01,
495
+ logging_steps=5, # Giảm số bước logging để theo dõi thường xuyên hơn
496
+ eval_strategy="steps", # Đánh giá sau mỗi vài bước
497
+ eval_steps=5, # Đánh giá sau mỗi 50 bước
498
+ save_strategy="steps", # Lưu checkpoint sau mỗi vài bước
499
+ save_steps=5, # Lưu checkpoint sau mỗi 50 bước
500
+ save_total_limit=5, # Giới hạn số lượng checkpoint lưu trữ
501
+ fp16=True,
502
+ report_to="none",
503
+ load_best_model_at_end=True,
504
+ )
505
+
506
+ # Data Collator (GPU)
507
+ data_collator = DataCollatorForLanguageModeling(
508
+ tokenizer=tokenizer,
509
+ mlm=False, # Vì bạn đang thực hiện Causal LM
510
+ pad_to_multiple_of=8
511
+ )
512
+
513
+ # Tạo Trainer (GPU)
514
+ trainer = Trainer(
515
+ model=model,
516
+ args=training_args,
517
+ train_dataset=final_dataset['train'],
518
+ eval_dataset=final_dataset['validation'],
519
+ tokenizer=tokenizer,
520
+ data_collator=data_collator,
521
+ callbacks=[SaveCheckpointCallback()], # Thêm callback
522
+ )
523
+
524
+ # Kiểm tra nếu có checkpoint
525
+ checkpoints = [os.path.join(CHECKPOINT_DIR, d) for d in os.listdir(CHECKPOINT_DIR) if d.startswith('checkpoint')]
526
+ if checkpoints:
527
+ latest_checkpoint = max(checkpoints, key=os.path.getctime)
528
+ print(f"Đang tiếp tục huấn luyện từ checkpoint: {latest_checkpoint}")
529
+ trainer.train(resume_from_checkpoint=latest_checkpoint)
530
+ else:
531
+ trainer.train()
532
+
533
+ # Lưu checkpoint sau khi huấn luyện
534
+ trainer.save_model(CHECKPOINT_DIR)
535
+ return "Huấn luyện hoàn tất hoặc đã tiếp tục từ checkpoint."
536
+
537
+ # Hàm Tự Động Hóa Việc Gọi Lặp Lại Hàm Huấn Luyện
538
+ def continuous_training(total_steps=300, steps_per_call=5):
539
+ """
540
+ Hàm tự động gọi lại `run_training` để hoàn thành quá trình huấn luyện.
541
+
542
+ Args:
543
+ total_steps (int): Tổng số bước huấn luyện mong muốn.
544
+ steps_per_call (int): Số bước huấn luyện mỗi lần gọi hàm.
545
+ """
546
+ steps_done = 0
547
+ while steps_done < total_steps:
548
+ print(f"Bắt đầu huấn luyện cho {steps_per_call} bước.")
549
+ result = run_training()
550
+ print(result)
551
+ steps_done += steps_per_call
552
+ print(f"Đã huấn luyện {steps_done} / {total_steps} bước.")
553
+
554
+ # Kiểm tra nếu đã đạt số bước mong muốn
555
+ if steps_done >= total_steps:
556
+ print("Đã hoàn thành toàn bộ quá trình huấn luyện.")
557
+ break
558
+
559
+ # Chờ một khoảng thời gian trước khi gọi lại (tùy thuộc vào yêu cầu của hệ thống)
560
+ time.sleep(2) # Thời gian chờ có thể điều chỉnh
561
+
562
  # Tạo giao diện chính của Gradio với CSS tùy chỉnh
563
  with gr.Blocks(css="""
564
  .gradio-container {