import json
import os
import time

import torch
from datasets import Dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainerCallback
from trl import SFTTrainer, SFTConfig

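# Callback that prints the step, most recent loss, and elapsed time at each logging step.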
class CustomCallback(TrainerCallback):
    def __init__(self):
        self.start_time = time.time()

    def on_train_begin(self, args, state, control, **kwargs):
        print("Training has begun!")

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.logging_steps == 0:
            elapsed_time = time.time() - self.start_time
            if state.log_history:
                loss = state.log_history[-1].get('loss', 0)
                print(f"Step: {state.global_step}, Loss: {loss:.4f}, Time: {elapsed_time:.2f}s")
            else:
                print(f"Step: {state.global_step}, Loss: N/A, Time: {elapsed_time:.2f}s")

    def on_train_end(self, args, state, control, **kwargs):
        print("Training has ended!")

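# Restrict the run to a single GPU (index 7 on this machine).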
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

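# Load the tokenizer and reuse the EOS token for padding.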
model_path = "google/gemma-2-9b-it"
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

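# QLoRA-style 4-bit NF4 quantization with double quantization and fp16 compute.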
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

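# Load the quantized base model; device_map="auto" places it on the visible GPU.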
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
)

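# Prepare the 4-bit model for training and attach LoRA adapters to the attention and MLP projections.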
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=6,
    lora_alpha=8,
    lora_dropout=0.05,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

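# Load the English-Korean parallel data (a JSON file of translation pairs).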
with open('en_ko_data', 'r', encoding='utf-8') as f:
    data = json.load(f)

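# Format each pair with Gemma's chat turn markers. Note that the tokenizer may prepend
# its own <bos> at tokenization time, so the explicit <bos> here can end up duplicated.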
def generate_prompt(en_text, ko_text):
    return f"""<bos><start_of_turn>user
Please translate the following English colloquial expression into Korean:

{en_text}<end_of_turn>
<start_of_turn>model
{ko_text}<end_of_turn><eos>"""

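# Build a Hugging Face Dataset with one formatted prompt per example;
# the first top-level key of the JSON holds the list of translation pairs.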
key = list(data.keys())[0]
dataset = [{"text": generate_prompt(item['en_original'], item['ko'])} for item in data[key]]
dataset = Dataset.from_list(dataset)

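# SFT hyperparameters. SFTConfig extends TrainingArguments and, in recent trl releases,
# also carries the SFT-specific options (dataset_text_field, max_seq_length).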
training_args = SFTConfig(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    save_steps=100,
    logging_steps=1,
    learning_rate=2e-4,
    weight_decay=0.01,
    fp16=True,
    optim="paged_adamw_8bit",
    dataset_text_field="text",
    max_seq_length=512,
)

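# SFTTrainer tokenizes and truncates the "text" field according to the config above.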
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=training_args,
    tokenizer=tokenizer,  # renamed to processing_class in newer trl releases
)

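# Register the progress callback and run fine-tuning.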
trainer.add_callback(CustomCallback())

trainer.train()

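# Save the trained LoRA adapter weights and config.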
trainer.save_model("./gemma2_9b_ko_translator")