[update]add model
- examples/exercises/chinese_porn_novel/1.prepare_data.py +75 -0
- examples/exercises/chinese_porn_novel/2.train_model.py +203 -0
- examples/exercises/chinese_porn_novel/3.test_model.py +88 -0
- examples/exercises/chinese_porn_novel/README.md +8 -0
- examples/exercises/chinese_porn_novel/run.sh +146 -0
- examples/exercises/chinese_porn_novel/stop.sh +5 -0
- main.py +5 -1
examples/exercises/chinese_porn_novel/1.prepare_data.py
ADDED
@@ -0,0 +1,75 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import argparse
+import json
+import os
+import random
+import sys
+
+pwd = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(os.path.join(pwd, '../../../'))
+
+from datasets import Dataset, DatasetDict, IterableDataset, load_dataset
+from tqdm import tqdm
+
+from project_settings import project_path
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--dataset_path", default="qgyd2021/h_novel", type=str)
+    # parser.add_argument("--dataset_name", default="ltxsba_500m", type=str)
+    parser.add_argument("--dataset_name", default="ltxsba_5gb", type=str)
+    parser.add_argument("--dataset_split", default="train", type=str)
+    parser.add_argument(
+        "--dataset_cache_dir",
+        default=(project_path / "hub_datasets").as_posix(),
+        type=str
+    )
+    parser.add_argument("--train_subset", default="train.jsonl", type=str)
+    parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = get_args()
+
+    dataset_dict = load_dataset(
+        path=args.dataset_path,
+        name=args.dataset_name,
+        # split=args.dataset_split,
+        cache_dir=args.dataset_cache_dir,
+        streaming=True,
+    )
+
+    train_dataset = dataset_dict["train"]
+
+    with open(args.train_subset, "w", encoding="utf-8") as ftrain, \
+            open(args.valid_subset, "w", encoding="utf-8") as fvalid:
+        for sample in tqdm(train_dataset):
+            # print(sample)
+
+            source = sample["source"]
+            idx = sample["idx"]
+            filename = sample["filename"]
+            novel_name = sample["novel_name"]
+            row_idx = sample["row_idx"]
+            text = sample["text"]
+
+            row = {
+                "text": text
+            }
+            row = json.dumps(row, ensure_ascii=False)
+
+            if random.random() < 0.95:
+                ftrain.write("{}\n".format(row))
+            else:
+                fvalid.write("{}\n".format(row))
+
+    return
+
+
+if __name__ == '__main__':
+    main()
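The script above streams the `qgyd2021/h_novel` dataset and writes one JSON object per line, keeping only the `text` field and routing roughly 95% of the samples to `train.jsonl` and the rest to `valid.jsonl`. A minimal sketch (not part of the commit; the file names are the script defaults and assume the current working directory) of reading the emitted files back:

```python
import json

# Read back a JSONL file produced by 1.prepare_data.py; each line is a
# JSON object with a single "text" field.
with open("train.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        row = json.loads(line)
        print(row["text"][:50])  # preview the first 50 characters of one sample
        break
```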
examples/exercises/chinese_porn_novel/2.train_model.py
ADDED
@@ -0,0 +1,203 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+"""
+Reference links:
+https://www.thepythoncode.com/article/pretraining-bert-huggingface-transformers-in-python
+https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py
+
+"""
+import argparse
+from itertools import chain
+import os
+from pathlib import Path
+import platform
+
+from datasets import Dataset, DatasetDict, IterableDataset, load_dataset
+import torch
+from transformers.data.data_collator import DataCollatorForLanguageModeling
+from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+from transformers.models.bert.tokenization_bert import BertTokenizer
+from transformers.trainer import Trainer
+from transformers.training_args import TrainingArguments
+
+from project_settings import project_path
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        default=(project_path / "pretrained_models/gpt2-chinese-cluecorpussmall").as_posix(),
+        type=str
+    )
+
+    parser.add_argument("--train_subset", default="train.jsonl", type=str)
+    parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
+
+    parser.add_argument("--output_dir", default="serialization_dir", type=str)
+    parser.add_argument("--overwrite_output_dir", action="store_true")
+    parser.add_argument("--evaluation_strategy", default="no", choices=["no", "steps", "epoch"], type=str)
+    parser.add_argument("--per_device_train_batch_size", default=8, type=int)
+    parser.add_argument("--gradient_accumulation_steps", default=4, type=int)
+    parser.add_argument("--learning_rate", default=1e-5, type=float)
+    parser.add_argument("--weight_decay", default=0, type=float)
+    parser.add_argument("--max_grad_norm", default=1.0, type=float)
+    parser.add_argument("--num_train_epochs", default=3.0, type=float)
+    parser.add_argument("--max_steps", default=-1, type=int)
+    parser.add_argument("--lr_scheduler_type", default="cosine", type=str)
+    parser.add_argument("--warmup_ratio", default=0.0, type=float)
+    parser.add_argument("--warmup_steps", default=3000, type=int)
+    parser.add_argument("--logging_steps", default=300, type=int)
+    parser.add_argument("--save_strategy", default="steps", type=str)
+    parser.add_argument("--save_steps", default=500, type=int)
+    parser.add_argument("--save_total_limit", default=3, type=int)
+    parser.add_argument("--no_cuda", action="store_true")
+    parser.add_argument("--seed", default=3407, type=int, help="https://arxiv.org/abs/2109.08203")
+    # parser.add_argument("--fp16", action="store_true")
+    parser.add_argument("--fp16", action="store_false")
+    parser.add_argument("--half_precision_backend", default="auto", type=str)
+    parser.add_argument("--dataloader_num_workers", default=5, type=int)
+    parser.add_argument("--disable_tqdm", action="store_false")
+    parser.add_argument("--remove_unused_columns", action="store_false")
+    # parser.add_argument("--deepspeed", default="ds_z3_config.json", type=str)
+    parser.add_argument("--deepspeed", default=None, type=str)
+    parser.add_argument("--optim", default="adamw_hf", type=str)
+    parser.add_argument("--report_to", default="tensorboard", type=str)
+    parser.add_argument("--resume_from_checkpoint", default=None, type=str)
+    # parser.add_argument("--gradient_checkpointing", action="store_true")
+    parser.add_argument("--gradient_checkpointing", action="store_false")
+
+    parser.add_argument("--truncate_longer_samples", action="store_true")
+    # parser.add_argument("--truncate_longer_samples", action="store_false")
+    parser.add_argument("--max_seq_length", default=1024, type=int)
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = get_args()
+
+    # dataset
+    dataset_dict = DatasetDict()
+    train_data_files = [args.train_subset]
+    dataset_dict["train"] = load_dataset(
+        path="json", data_files=[str(file) for file in train_data_files]
+    )["train"]
+    valid_data_files = [args.valid_subset]
+    dataset_dict["valid"] = load_dataset(
+        path="json", data_files=[str(file) for file in valid_data_files]
+    )["train"]
+
+    print(dataset_dict)
+
+    # model
+    tokenizer = BertTokenizer.from_pretrained(args.pretrained_model_name_or_path)
+    model = GPT2LMHeadModel.from_pretrained(args.pretrained_model_name_or_path)
+
+    def encode_with_truncation(examples):
+        outputs = tokenizer.__call__(examples['text'],
+                                     truncation=True,
+                                     padding='max_length',
+                                     max_length=args.max_seq_length,
+                                     return_special_tokens_mask=True)
+        return outputs
+
+    def encode_without_truncation(examples):
+        outputs = tokenizer.__call__(examples['text'],
+                                     return_special_tokens_mask=True)
+        return outputs
+
+    def group_texts(examples):
+        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
+        total_length = len(concatenated_examples[list(examples.keys())[0]])
+        if total_length >= args.max_seq_length:
+            total_length = (total_length // args.max_seq_length) * args.max_seq_length
+
+        result = {
+            k: [t[i: i + args.max_seq_length] for i in range(0, total_length, args.max_seq_length)]
+            for k, t in concatenated_examples.items()
+        }
+        return result
+
+    if args.truncate_longer_samples:
+        dataset_dict = dataset_dict.map(
+            encode_with_truncation,
+            batched=True,
+            drop_last_batch=True,
+            keep_in_memory=False,
+            # num_proc=None if platform.system() == 'Windows' else os.cpu_count() // 2,
+            num_proc=None,
+        )
+        dataset_dict.set_format(type="torch", columns=["input_ids", "attention_mask"])
+    else:
+        dataset_dict = dataset_dict.map(
+            encode_without_truncation,
+            batched=True,
+            drop_last_batch=True,
+            keep_in_memory=False,
+            # num_proc=None if platform.system() == 'Windows' else os.cpu_count() // 2,
+            num_proc=None,
+        )
+        dataset_dict.set_format(type="torch", columns=["input_ids", "attention_mask"])
+
+        dataset_dict = dataset_dict.map(
+            group_texts,
+            batched=True,
+            drop_last_batch=True,
+            keep_in_memory=False,
+            # num_proc=None if platform.system() == 'Windows' else os.cpu_count() // 2,
+            num_proc=None,
+        )
+        dataset_dict.set_format("torch")
+
+    data_collator = DataCollatorForLanguageModeling(
+        tokenizer=tokenizer, mlm=False
+    )
+
+    training_args = TrainingArguments(
+        output_dir=args.output_dir,
+        overwrite_output_dir=args.overwrite_output_dir,
+        evaluation_strategy=args.evaluation_strategy,
+        per_device_train_batch_size=args.per_device_train_batch_size,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        learning_rate=args.learning_rate,
+        num_train_epochs=args.num_train_epochs,
+        max_steps=args.max_steps,
+        lr_scheduler_type=args.lr_scheduler_type,
+        warmup_steps=args.warmup_steps,
+        logging_steps=args.logging_steps,
+        save_steps=args.save_steps,
+        save_total_limit=args.save_total_limit,
+        no_cuda=args.no_cuda,
+        fp16=args.fp16,
+        half_precision_backend=args.half_precision_backend,
+        # deepspeed=args.deepspeed,
+        report_to=args.report_to,
+        resume_from_checkpoint=args.resume_from_checkpoint,
+        gradient_checkpointing=args.gradient_checkpointing,
+    )
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        data_collator=data_collator,
+        train_dataset=dataset_dict["train"],
+    )
+    train_result = trainer.train()
+
+    # save the best checkpoint
+    final_save_path = os.path.join(training_args.output_dir, "final")
+    trainer.save_model(final_save_path)  # Saves the tokenizer too
+    # save training metrics
+    metrics = train_result.metrics
+    trainer.log_metrics("train", metrics)
+    trainer.save_metrics("train", metrics)
+    trainer.save_state()
+
+    tokenizer.save_pretrained(final_save_path)
+    return
+
+
+if __name__ == '__main__':
+    main()
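When `--truncate_longer_samples` is off, `group_texts` concatenates the tokenized samples of a batch and re-cuts them into fixed-length blocks, dropping any tail shorter than `max_seq_length`. A toy illustration of that chunking (not part of the commit; the tiny block size and fake token ids are for demonstration only):

```python
from itertools import chain

max_seq_length = 4  # toy value; the training script defaults to 1024
examples = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]}

# Concatenate all samples, drop the incomplete tail, and cut into fixed blocks.
concatenated = {k: list(chain(*v)) for k, v in examples.items()}
total_length = (len(concatenated["input_ids"]) // max_seq_length) * max_seq_length
blocks = {
    k: [t[i: i + max_seq_length] for i in range(0, total_length, max_seq_length)]
    for k, t in concatenated.items()
}
print(blocks)  # {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]]}
```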
examples/exercises/chinese_porn_novel/3.test_model.py
ADDED
@@ -0,0 +1,88 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import argparse
+import os
+import sys
+
+pwd = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(os.path.join(pwd, '../../../'))
+
+import torch
+from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+from transformers.models.bert.tokenization_bert import BertTokenizer
+
+from project_settings import project_path
+
+
+def get_args():
+    """
+    python3 3.test_model.py \
+    --repetition_penalty 1.2 \
+    --trained_model_path /data/tianxing/PycharmProjects/Transformers/trained_models/gpt2_chinese_h_novel
+
+    python3 3.test_model.py \
+    --trained_model_path /data/tianxing/PycharmProjects/Transformers/trained_models/gpt2_chinese_h_novel
+
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--trained_model_path',
+        default=(project_path / "pretrained_models/gpt2-chinese-cluecorpussmall").as_posix(),
+        type=str,
+    )
+    parser.add_argument('--device', default='auto', type=str)
+
+    parser.add_argument('--max_new_tokens', default=512, type=int)
+    parser.add_argument('--top_p', default=0.85, type=float)
+    parser.add_argument('--temperature', default=0.35, type=float)
+    parser.add_argument('--repetition_penalty', default=1.2, type=float)
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = get_args()
+
+    if args.device == 'auto':
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    else:
+        device = args.device
+
+    # pretrained model
+    tokenizer = BertTokenizer.from_pretrained(args.trained_model_path)
+    model = GPT2LMHeadModel.from_pretrained(args.trained_model_path)
+
+    model.eval()
+    model = model.to(device)
+
+    while True:
+        text = input('prefix: ')
+
+        if text == "Quit":
+            break
+        text = '{}'.format(text)
+        input_ids = tokenizer(text, return_tensors="pt").input_ids
+        input_ids = input_ids[:, :-1]
+        # print(input_ids)
+        # print(type(input_ids))
+        input_ids = input_ids.to(device)
+
+        outputs = model.generate(input_ids,
+                                 max_new_tokens=args.max_new_tokens,
+                                 do_sample=True,
+                                 top_p=args.top_p,
+                                 temperature=args.temperature,
+                                 repetition_penalty=args.repetition_penalty,
+                                 eos_token_id=tokenizer.sep_token_id,
+                                 pad_token_id=tokenizer.pad_token_id
+                                 )
+        rets = tokenizer.batch_decode(outputs)
+        output = rets[0].replace(" ", "").replace("[CLS]", "").replace("[SEP]", "")
+        print("{}".format(output))
+
+    return
+
+
+if __name__ == '__main__':
+    main()
examples/exercises/chinese_porn_novel/README.md
ADDED
@@ -0,0 +1,8 @@
+## Pretraining a GPT model
+
+```text
+Reference links:
+https://huggingface.co/docs/transformers/model_doc/openai-gpt
+https://huggingface.co/learn/nlp-course/chapter7/6
+
+```
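Once stage 2 of `run.sh` has collected the checkpoint into `trained_models/gpt2_chinese_h_novel`, it can be loaded for generation the same way `3.test_model.py` does. A hedged sketch, assuming that path exists relative to the repository root:

```python
import torch
from transformers.models.bert.tokenization_bert import BertTokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

# Assumed local path: where run.sh stage 2 copies the final checkpoint files.
model_path = "trained_models/gpt2_chinese_h_novel"
tokenizer = BertTokenizer.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained(model_path).eval()

# Drop the trailing [SEP] so generation continues the prefix.
input_ids = tokenizer("示例前缀", return_tensors="pt").input_ids[:, :-1]
with torch.no_grad():
    outputs = model.generate(
        input_ids,
        max_new_tokens=64,
        do_sample=True,
        top_p=0.85,
        temperature=0.35,
        repetition_penalty=1.2,
        eos_token_id=tokenizer.sep_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
print(tokenizer.batch_decode(outputs)[0].replace(" ", ""))
```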
examples/exercises/chinese_porn_novel/run.sh
ADDED
@@ -0,0 +1,146 @@
+#!/usr/bin/env bash
+
+# nohup sh run.sh --stage 0 --stop_stage 1 --system_version centos &
+# sh run.sh --stage 0 --stop_stage 1 --system_version windows
+# sh run.sh --stage 0 --stop_stage 0 --system_version centos
+# sh run.sh --stage 2 --stop_stage 2 --system_version centos --checkpoint_name final
+# sh run.sh --stage -1 --stop_stage 1
+
+# bitsandbytes
+export LD_LIBRARY_PATH="/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
+
+# params
+system_version="windows";
+verbose=true;
+stage=0 # start from 0 if you need to start from data preparation
+stop_stage=5
+
+pretrained_model_name=gpt2-chinese-cluecorpussmall
+
+train_subset=train.jsonl
+valid_subset=valid.jsonl
+
+final_model_name=gpt2_chinese_h_novel
+
+checkpoint_name=final
+
+# parse options
+while true; do
+  [ -z "${1:-}" ] && break;  # break if there are no arguments
+  case "$1" in
+    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+      old_value="(eval echo \\$$name)";
+      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+        was_bool=true;
+      else
+        was_bool=false;
+      fi
+
+      # Set the variable to the right value-- the escaped quotes make it work if
+      # the option had spaces, like --cmd "queue.pl -sync y"
+      eval "${name}=\"$2\"";
+
+      # Check that Boolean-valued arguments are really Boolean.
+      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+        exit 1;
+      fi
+      shift 2;
+      ;;
+
+    *) break;
+  esac
+done
+
+$verbose && echo "system_version: ${system_version}"
+
+work_dir="$(pwd)"
+file_dir="$(pwd)/file_dir"
+pretrained_models_dir="${work_dir}/../../../pretrained_models";
+serialization_dir="${file_dir}/serialization_dir"
+final_model_dir="${work_dir}/../../../trained_models/${final_model_name}";
+
+mkdir -p "${file_dir}"
+mkdir -p "${pretrained_models_dir}"
+mkdir -p "${serialization_dir}"
+mkdir -p "${final_model_dir}"
+
+
+export PYTHONPATH="${work_dir}/../../.."
+
+
+if [ $system_version == "windows" ]; then
+  alias python3='C:/Users/tianx/PycharmProjects/virtualenv/Transformers/Scripts/python.exe'
+elif [ $system_version == "centos" ]; then
+  # conda activate Transformers
+  alias python3='/usr/local/miniconda3/envs/Transformers/bin/python3'
+elif [ $system_version == "ubuntu" ]; then
+  # conda activate Transformers
+  alias python3='/usr/local/miniconda3/envs/Transformers/bin/python3'
+fi
+
+
+declare -A pretrained_model_dict
+pretrained_model_dict=(
+  ["gpt2-chinese-cluecorpussmall"]="https://huggingface.co/uer/gpt2-chinese-cluecorpussmall"
+  ["gpt2"]="https://huggingface.co/gpt2"
+  ["japanese-gpt2-medium"]="https://huggingface.co/rinna/japanese-gpt2-medium"
+
+)
+pretrained_model_dir="${pretrained_models_dir}/${pretrained_model_name}"
+
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+  $verbose && echo "stage -1: download pretrained model"
+  cd "${file_dir}" || exit 1;
+
+  if [ ! -d "${pretrained_model_dir}" ]; then
+    cd "${pretrained_models_dir}" || exit 1;
+
+    repository_url="${pretrained_model_dict[${pretrained_model_name}]}"
+    git clone "${repository_url}"
+
+    cd "${pretrained_model_dir}" || exit 1;
+    rm flax_model.msgpack && rm pytorch_model.bin && rm tf_model.h5
+    wget "${repository_url}/resolve/main/pytorch_model.bin"
+  fi
+fi
+
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+  $verbose && echo "stage 0: prepare data"
+  cd "${work_dir}" || exit 1;
+
+  python3 1.prepare_data.py \
+    --train_subset "${file_dir}/${train_subset}" \
+    --valid_subset "${file_dir}/${valid_subset}" \
+
+fi
+
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  $verbose && echo "stage 1: train model"
+  cd "${work_dir}" || exit 1;
+
+  python3 2.train_model.py \
+    --train_subset "${file_dir}/${train_subset}" \
+    --valid_subset "${file_dir}/${valid_subset}" \
+    --pretrained_model_name_or_path "${pretrained_models_dir}/${pretrained_model_name}" \
+    --output_dir "${serialization_dir}"
+
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  $verbose && echo "stage 2: collect files"
+  cd "${work_dir}" || exit 1;
+
+  cp "${serialization_dir}/${checkpoint_name}/pytorch_model.bin" "${final_model_dir}/pytorch_model.bin"
+
+  cp "${pretrained_models_dir}/${pretrained_model_name}/config.json" "${final_model_dir}/config.json"
+  cp "${pretrained_models_dir}/${pretrained_model_name}/special_tokens_map.json" "${final_model_dir}/special_tokens_map.json"
+  cp "${pretrained_models_dir}/${pretrained_model_name}/tokenizer_config.json" "${final_model_dir}/tokenizer_config.json"
+  cp "${pretrained_models_dir}/${pretrained_model_name}/vocab.txt" "${final_model_dir}/vocab.txt"
+
+fi
examples/exercises/chinese_porn_novel/stop.sh
ADDED
@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+
+kill -9 `ps -aef | grep 'run.sh' | grep -v grep | awk '{print $2}' | sed 's/\n/ /'`
+
+kill -9 `ps -aef | grep 'Transformers/bin/python3' | grep -v grep | awk '{print $2}' | sed 's/\n/ /'`
main.py
CHANGED
@@ -117,7 +117,9 @@ def main():
         yield output
 
     model_name_choices = ["trained_models/lib_service_4chan"] \
-        if platform.system() == "Windows" else
+        if platform.system() == "Windows" else \
+        ["qgyd2021/lib_service_4chan", "qgyd2021/chinese_chitchat", "qgyd2021/chinese_porn_novel"]
+
     demo = gr.Interface(
         fn=fn_stream,
         inputs=[
@@ -133,6 +135,8 @@ def main():
         examples=[
             ["怎样擦屁股才能擦的干净", 512, 0.75, 0.35, 1.2, "qgyd2021/lib_service_4chan", True],
             ["你好", 512, 0.75, 0.35, 1.2, "qgyd2021/chinese_chitchat", True],
+            ["白洁走到床边并脱去内衣, 一双硕大的", 512, 0.75, 0.35, 1.2, "qgyd2021/chinese_porn_novel", False],
+            ["男人走进房间, 上床, 压上", 512, 0.75, 0.35, 1.2, "qgyd2021/chinese_porn_novel", False],
         ],
         cache_examples=False,
         examples_per_page=50,