Soonchan committed
Commit 0107eb2
1 Parent(s): 7838f85

Create ft.py

Files changed (1)
  1. ft.py +116 -0
ft.py ADDED
@@ -0,0 +1,116 @@
# Fine-tuning code: QLoRA supervised fine-tuning of Gemma-2-9B for English -> Korean translation

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainerCallback, TrainingArguments
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from datasets import Dataset
import json
from trl import SFTTrainer, SFTConfig
import time
import os

# Simple progress callback: prints step, latest loss, and elapsed time.
class CustomCallback(TrainerCallback):
    def __init__(self):
        self.start_time = time.time()

    def on_train_begin(self, args, state, control, **kwargs):
        print("Training has begun!")

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.logging_steps == 0:
            elapsed_time = time.time() - self.start_time
            if state.log_history:
                loss = state.log_history[-1].get('loss', 0)
                print(f"Step: {state.global_step}, Loss: {loss:.4f}, Time: {elapsed_time:.2f}s")
            else:
                print(f"Step: {state.global_step}, Loss: N/A, Time: {elapsed_time:.2f}s")

    def on_train_end(self, args, state, control, **kwargs):
        print("Training has ended!")

# Pin the run to a single GPU; device_map="auto" below places the model on it.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


model_path = "google/gemma-2-9b-it"  # Hugging Face model ID for the instruction-tuned Gemma-2-9B
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 quantization (QLoRA)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Load the base model in 4-bit
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
)

# PEFT: LoRA adapters on the attention and MLP projection layers
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=6,
    lora_alpha=8,
    lora_dropout=0.05,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# data: English-Korean pairs stored as JSON
with open('en_ko_data', 'r', encoding='utf-8') as f:
    data = json.load(f)

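# Expected layout of 'en_ko_data' (illustrative sketch only; the top-level key name
# is arbitrary, since only the first key is read below):
# {
#   "some_split": [
#     {"en_original": "...", "ko": "..."},
#     ...
#   ]
# }
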
# prompt: Gemma chat template for one training example
def generate_prompt(en_text, ko_text):
    return f"""<bos><start_of_turn>user
Please translate the following English colloquial expression into Korean:

{en_text}<end_of_turn>
<start_of_turn>model
{ko_text}<end_of_turn><eos>"""


key = list(data.keys())[0]
dataset = [{"text": generate_prompt(item['en_original'], item['ko'])} for item in data[key]]
dataset = Dataset.from_list(dataset)

# training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch size of 16
    save_steps=100,
    logging_steps=1,
    learning_rate=2e-4,
    weight_decay=0.01,
    fp16=True,
    optim="paged_adamw_8bit",
)

# SFTTrainer (note: dataset_text_field / max_seq_length are passed directly here,
# which assumes an older trl release; newer trl versions expect them in SFTConfig)
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=training_args,
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=512,
)


trainer.add_callback(CustomCallback())


# train
trainer.train()

# save the fine-tuned LoRA adapter
trainer.save_model("./gemma2_9b_ko_translator")
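
For reference, the sketch below shows one way to load the saved adapter for inference. It is an illustrative usage example, not part of ft.py; it assumes the base model ID and the ./gemma2_9b_ko_translator output directory used above, and a hypothetical input sentence.

# Illustrative inference sketch (assumes ft.py has run and saved the adapter).
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

base_id = "google/gemma-2-9b-it"           # same base model as in ft.py
adapter_dir = "./gemma2_9b_ko_translator"  # output of trainer.save_model(...)

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                bnb_4bit_compute_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(base_id)
base_model = AutoModelForCausalLM.from_pretrained(base_id, quantization_config=bnb_config,
                                                  device_map="auto", torch_dtype=torch.float16)
model = PeftModel.from_pretrained(base_model, adapter_dir)
model.eval()

# Reuse the training prompt template, leaving the model turn open for generation.
prompt = ("<bos><start_of_turn>user\n"
          "Please translate the following English colloquial expression into Korean:\n\n"
          "break a leg<end_of_turn>\n"   # hypothetical input sentence
          "<start_of_turn>model\n")
inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))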