Spaces:
Sleeping
Sleeping
hoduyquocbao
commited on
Commit
•
999b0b0
1
Parent(s):
309b446
join code to app
Browse files
app.py
CHANGED
@@ -6,12 +6,14 @@ from typing import Iterator, List, Tuple, Dict, Any
|
|
6 |
import gradio as gr
|
7 |
import spaces
|
8 |
import torch
|
9 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
|
10 |
from bs4 import BeautifulSoup
|
11 |
import requests
|
12 |
import json
|
13 |
from functools import lru_cache
|
14 |
-
from
|
|
|
|
|
15 |
|
16 |
# ---------------------------- Cấu Hình ---------------------------- #
|
17 |
|
@@ -382,6 +384,181 @@ chat_interface = gr.ChatInterface(
|
|
382 |
theme="default", # Có thể thay đổi theme để giao diện đẹp hơn
|
383 |
)
|
384 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
# Tạo giao diện chính của Gradio với CSS tùy chỉnh
|
386 |
with gr.Blocks(css="""
|
387 |
.gradio-container {
|
|
|
6 |
import gradio as gr
|
7 |
import spaces
|
8 |
import torch
|
9 |
+
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling, TrainerCallback,AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
|
10 |
from bs4 import BeautifulSoup
|
11 |
import requests
|
12 |
import json
|
13 |
from functools import lru_cache
|
14 |
+
from datasets import load_dataset
|
15 |
+
from peft import LoraConfig, get_peft_model
|
16 |
+
import time
|
17 |
|
18 |
# ---------------------------- Cấu Hình ---------------------------- #
|
19 |
|
|
|
384 |
theme="default", # Có thể thay đổi theme để giao diện đẹp hơn
|
385 |
)
|
386 |
|
387 |
+
|
388 |
+
# Đường dẫn lưu checkpoint
|
389 |
+
CHECKPOINT_DIR = "./checkpoints"
|
390 |
+
if not os.path.exists(CHECKPOINT_DIR):
|
391 |
+
os.makedirs(CHECKPOINT_DIR)
|
392 |
+
|
393 |
+
# Tải Dataset (CPU)
|
394 |
+
dataset = load_dataset('vntc/wiki-mini-corpus')
|
395 |
+
|
396 |
+
# Chia Dataset thành train và validation (CPU)
|
397 |
+
split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
|
398 |
+
train_dataset = split_dataset['train']
|
399 |
+
validation_dataset = split_dataset['test']
|
400 |
+
|
401 |
+
# Tiền Xử Lý Văn Bản (CPU)
|
402 |
+
def preprocess_function(examples):
|
403 |
+
passages = [passage.lower().strip() for passage in examples['passage']]
|
404 |
+
return {'passage': passages}
|
405 |
+
|
406 |
+
processed_train = train_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
|
407 |
+
processed_validation = validation_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
|
408 |
+
|
409 |
+
# Tokenization (CPU)
|
410 |
+
model_name = "meta-llama/Llama-3.2-3B-Instruct"
|
411 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
412 |
+
|
413 |
+
# Đảm bảo tokenizer có pad_token
|
414 |
+
if tokenizer.pad_token is None:
|
415 |
+
tokenizer.pad_token = tokenizer.eos_token
|
416 |
+
|
417 |
+
def tokenize_function(examples):
|
418 |
+
return tokenizer(
|
419 |
+
examples['passage'],
|
420 |
+
padding='max_length',
|
421 |
+
truncation=True,
|
422 |
+
max_length=512,
|
423 |
+
)
|
424 |
+
|
425 |
+
tokenized_train = processed_train.map(tokenize_function, batched=True)
|
426 |
+
tokenized_validation = processed_validation.map(tokenize_function, batched=True)
|
427 |
+
|
428 |
+
# Thêm trường 'labels' (CPU)
|
429 |
+
def add_labels(examples):
|
430 |
+
examples['labels'] = examples['input_ids'].copy()
|
431 |
+
return examples
|
432 |
+
|
433 |
+
tokenized_train = tokenized_train.map(add_labels, batched=True)
|
434 |
+
tokenized_validation = tokenized_validation.map(add_labels, batched=True)
|
435 |
+
|
436 |
+
# Loại bỏ các cột không cần thiết (CPU)
|
437 |
+
tokenized_train = tokenized_train.remove_columns(['passage'])
|
438 |
+
tokenized_validation = tokenized_validation.remove_columns(['passage'])
|
439 |
+
|
440 |
+
# Định dạng dữ liệu cho PyTorch (CPU)
|
441 |
+
tokenized_train.set_format('torch')
|
442 |
+
tokenized_validation.set_format('torch')
|
443 |
+
|
444 |
+
# Tạo DatasetDict (CPU)
|
445 |
+
final_dataset = {
|
446 |
+
'train': tokenized_train,
|
447 |
+
'validation': tokenized_validation
|
448 |
+
}
|
449 |
+
|
450 |
+
# Định Nghĩa TrainerCallback để Lưu Checkpoint Nhanh Hơn
|
451 |
+
class SaveCheckpointCallback(TrainerCallback):
|
452 |
+
def on_step_end(self, args, state, control, **kwargs):
|
453 |
+
if state.global_step % args.save_steps == 0 and state.global_step != 0:
|
454 |
+
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
|
455 |
+
print(f"Lưu checkpoint tại: {checkpoint_path}")
|
456 |
+
trainer = kwargs['trainer'] # Truy cập trainer từ kwargs
|
457 |
+
trainer.save_model(checkpoint_path)
|
458 |
+
return control # Trả về đối tượng control hiện tại
|
459 |
+
|
460 |
+
# Định Nghĩa Hàm Huấn Luyện với Decorator @spaces.GPU
|
461 |
+
@spaces.GPU(duration=15, queue=False)
|
462 |
+
def run_training():
|
463 |
+
"""
|
464 |
+
Hàm huấn luyện mô hình sử dụng GPU với thời gian hạn chế.
|
465 |
+
"""
|
466 |
+
# Tải và Cấu Hình Mô Hình với LoRA (GPU)
|
467 |
+
model = AutoModelForCausalLM.from_pretrained(
|
468 |
+
model_name,
|
469 |
+
device_map="auto",
|
470 |
+
torch_dtype=torch.float16,
|
471 |
+
load_in_8bit=False
|
472 |
+
)
|
473 |
+
|
474 |
+
lora_config = LoraConfig(
|
475 |
+
r=8,
|
476 |
+
lora_alpha=32,
|
477 |
+
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
478 |
+
lora_dropout=0.1,
|
479 |
+
bias="none",
|
480 |
+
)
|
481 |
+
|
482 |
+
model = get_peft_model(model, lora_config)
|
483 |
+
print(model)
|
484 |
+
|
485 |
+
# Cấu Hình TrainingArguments (GPU)
|
486 |
+
training_args = TrainingArguments(
|
487 |
+
output_dir=CHECKPOINT_DIR,
|
488 |
+
per_device_train_batch_size=4,
|
489 |
+
per_device_eval_batch_size=4,
|
490 |
+
gradient_accumulation_steps=8,
|
491 |
+
num_train_epochs=3,
|
492 |
+
max_steps=50, # Đặt max_steps tại đây
|
493 |
+
learning_rate=3e-4,
|
494 |
+
weight_decay=0.01,
|
495 |
+
logging_steps=5, # Giảm số bước logging để theo dõi thường xuyên hơn
|
496 |
+
eval_strategy="steps", # Đánh giá sau mỗi vài bước
|
497 |
+
eval_steps=5, # Đánh giá sau mỗi 50 bước
|
498 |
+
save_strategy="steps", # Lưu checkpoint sau mỗi vài bước
|
499 |
+
save_steps=5, # Lưu checkpoint sau mỗi 50 bước
|
500 |
+
save_total_limit=5, # Giới hạn số lượng checkpoint lưu trữ
|
501 |
+
fp16=True,
|
502 |
+
report_to="none",
|
503 |
+
load_best_model_at_end=True,
|
504 |
+
)
|
505 |
+
|
506 |
+
# Data Collator (GPU)
|
507 |
+
data_collator = DataCollatorForLanguageModeling(
|
508 |
+
tokenizer=tokenizer,
|
509 |
+
mlm=False, # Vì bạn đang thực hiện Causal LM
|
510 |
+
pad_to_multiple_of=8
|
511 |
+
)
|
512 |
+
|
513 |
+
# Tạo Trainer (GPU)
|
514 |
+
trainer = Trainer(
|
515 |
+
model=model,
|
516 |
+
args=training_args,
|
517 |
+
train_dataset=final_dataset['train'],
|
518 |
+
eval_dataset=final_dataset['validation'],
|
519 |
+
tokenizer=tokenizer,
|
520 |
+
data_collator=data_collator,
|
521 |
+
callbacks=[SaveCheckpointCallback()], # Thêm callback
|
522 |
+
)
|
523 |
+
|
524 |
+
# Kiểm tra nếu có checkpoint
|
525 |
+
checkpoints = [os.path.join(CHECKPOINT_DIR, d) for d in os.listdir(CHECKPOINT_DIR) if d.startswith('checkpoint')]
|
526 |
+
if checkpoints:
|
527 |
+
latest_checkpoint = max(checkpoints, key=os.path.getctime)
|
528 |
+
print(f"Đang tiếp tục huấn luyện từ checkpoint: {latest_checkpoint}")
|
529 |
+
trainer.train(resume_from_checkpoint=latest_checkpoint)
|
530 |
+
else:
|
531 |
+
trainer.train()
|
532 |
+
|
533 |
+
# Lưu checkpoint sau khi huấn luyện
|
534 |
+
trainer.save_model(CHECKPOINT_DIR)
|
535 |
+
return "Huấn luyện hoàn tất hoặc đã tiếp tục từ checkpoint."
|
536 |
+
|
537 |
+
# Hàm Tự Động Hóa Việc Gọi Lặp Lại Hàm Huấn Luyện
|
538 |
+
def continuous_training(total_steps=300, steps_per_call=5):
|
539 |
+
"""
|
540 |
+
Hàm tự động gọi lại `run_training` để hoàn thành quá trình huấn luyện.
|
541 |
+
|
542 |
+
Args:
|
543 |
+
total_steps (int): Tổng số bước huấn luyện mong muốn.
|
544 |
+
steps_per_call (int): Số bước huấn luyện mỗi lần gọi hàm.
|
545 |
+
"""
|
546 |
+
steps_done = 0
|
547 |
+
while steps_done < total_steps:
|
548 |
+
print(f"Bắt đầu huấn luyện cho {steps_per_call} bước.")
|
549 |
+
result = run_training()
|
550 |
+
print(result)
|
551 |
+
steps_done += steps_per_call
|
552 |
+
print(f"Đã huấn luyện {steps_done} / {total_steps} bước.")
|
553 |
+
|
554 |
+
# Kiểm tra nếu đã đạt số bước mong muốn
|
555 |
+
if steps_done >= total_steps:
|
556 |
+
print("Đã hoàn thành toàn bộ quá trình huấn luyện.")
|
557 |
+
break
|
558 |
+
|
559 |
+
# Chờ một khoảng thời gian trước khi gọi lại (tùy thuộc vào yêu cầu của hệ thống)
|
560 |
+
time.sleep(2) # Thời gian chờ có thể điều chỉnh
|
561 |
+
|
562 |
# Tạo giao diện chính của Gradio với CSS tùy chỉnh
|
563 |
with gr.Blocks(css="""
|
564 |
.gradio-container {
|