from toolbox import CatchException, update_ui, promote_file_to_downloadzone
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
import json

def fetch_items(list_of_items, batch_size):
    for i in range(0, len(list_of_items), batch_size):
        yield list_of_items[i:i + batch_size]

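# A quick illustration of the batching helper above (not part of the plugin logic):
#   list(fetch_items([1, 2, 3, 4, 5], batch_size=2))  ->  [[1, 2], [3, 4], [5]]
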
def string_to_options(arguments):
    import argparse
    import shlex

    # Create an argparse.ArgumentParser instance
    parser = argparse.ArgumentParser()

    # Add command-line arguments
    parser.add_argument("--llm_to_learn", type=str, help="LLM model to learn", default="gpt-3.5-turbo")
    parser.add_argument("--prompt_prefix", type=str, help="Prompt prefix", default='')
    parser.add_argument("--system_prompt", type=str, help="System prompt", default='')
    parser.add_argument("--batch", type=int, help="System prompt", default=50)
    parser.add_argument("--pre_seq_len", type=int, help="pre_seq_len", default=50)
    parser.add_argument("--learning_rate", type=float, help="learning_rate", default=2e-2)
    parser.add_argument("--num_gpus", type=int, help="num_gpus", default=1)
    parser.add_argument("--json_dataset", type=str, help="json_dataset", default="")
    parser.add_argument("--ptuning_directory", type=str, help="ptuning_directory", default="")



    # Parse the arguments
    args = parser.parse_args(shlex.split(arguments))

    return args

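# Illustrative use of string_to_options (values here are placeholders):
#   opts = string_to_options("--llm_to_learn gpt-3.5-turbo --batch 10")
#   opts.batch         ->  10
#   opts.llm_to_learn  ->  'gpt-3.5-turbo'
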
@CatchException
def 微调数据集生成(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
    """
    txt             输入栏用户输入的文本,例如需要翻译的一段话,再例如一个包含了待处理文件的路径
    llm_kwargs      gpt模型参数,如温度和top_p等,一般原样传递下去就行
    plugin_kwargs   插件模型的参数
    chatbot         聊天显示框的句柄,用于显示给用户
    history         聊天历史,前情提要
    system_prompt   给gpt的静默提醒
    user_request    当前用户的请求信息(IP地址等)
    """
    history = []    # Clear the history to avoid overflowing the input
    chatbot.append(("这是什么功能?", "[Local Message] 微调数据集生成"))
    if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
    args = plugin_kwargs.get("advanced_arg", None)
    if args is None: 
        chatbot.append(("没给定指令", "退出"))
        yield from update_ui(chatbot=chatbot, history=history); return
    else:
        arguments = string_to_options(arguments=args)

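    # Assumed input format (not validated here): `txt` is a path to a JSONL file
    # with one JSON object per line, each holding the source text under "content",
    # e.g. {"content": "some passage to learn from"}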
    dat = []
    with open(txt, 'r', encoding='utf8') as f:
        for line in f:
            json_dat = json.loads(line)
            dat.append(json_dat["content"])

    llm_kwargs['llm_model'] = arguments.llm_to_learn
    for batch in fetch_items(dat, arguments.batch):
        res = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
            inputs_array=[f"{arguments.prompt_prefix}\n\n{b}" for b in batch],
            inputs_show_user_array=["Show Nothing" for _ in batch],
            llm_kwargs=llm_kwargs,
            chatbot=chatbot,
            history_array=[[] for _ in batch],
            sys_prompt_array=[arguments.system_prompt for _ in batch],
            max_workers=10  # Maximum number of parallel requests allowed by OpenAI
        )
    
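        # `res` interleaves prompts and replies ([in_0, out_0, in_1, out_1, ...]),
        # so res[1::2] extracts just the model replies for this batch; each reply is
        # appended as one JSONL line in the content/summary format that the
        # ChatGLM ptuning scripts below expect.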
        with open(txt+'.generated.json', 'a+', encoding='utf8') as f:
            for b, r in zip(batch, res[1::2]):
                f.write(json.dumps({"content":b, "summary":r}, ensure_ascii=False)+'\n')

    promote_file_to_downloadzone(txt+'.generated.json', rename_file='generated.json', chatbot=chatbot)
    return

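# Illustrative advanced_arg string for 微调数据集生成 (all values are placeholders):
#   --llm_to_learn gpt-3.5-turbo --prompt_prefix "请总结以下内容:" --batch 50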


@CatchException
def 启动微调(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
    """
    txt             输入栏用户输入的文本,例如需要翻译的一段话,再例如一个包含了待处理文件的路径
    llm_kwargs      gpt模型参数,如温度和top_p等,一般原样传递下去就行
    plugin_kwargs   插件模型的参数
    chatbot         聊天显示框的句柄,用于显示给用户
    history         聊天历史,前情提要
    system_prompt   给gpt的静默提醒
    user_request    当前用户的请求信息(IP地址等)
    """
    import subprocess
    history = []    # Clear the history to avoid overflowing the input
    chatbot.append(("这是什么功能?", "[Local Message] 启动微调"))
    if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
    args = plugin_kwargs.get("advanced_arg", None)
    if args is None:
        chatbot.append(("没给定指令", "退出"))
        yield from update_ui(chatbot=chatbot, history=history); return
    else:
        arguments = string_to_options(arguments=args)

    pre_seq_len = arguments.pre_seq_len                 # e.g. 128
    learning_rate = arguments.learning_rate             # e.g. 2e-2
    num_gpus = arguments.num_gpus                       # e.g. 1
    json_dataset = arguments.json_dataset               # e.g. 't_code.json'
    ptuning_directory = arguments.ptuning_directory     # e.g. '/home/hmp/ChatGLM2-6B/ptuning'

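    # The command below follows the P-Tuning v2 recipe shipped with the ChatGLM2-6B
    # repository (ptuning/train.sh): torchrun launches main.py on `num_gpus` local
    # processes, reading the dataset from AdvertiseGen/ inside `ptuning_directory`.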
    command = f"torchrun --standalone --nnodes=1 --nproc-per-node={num_gpus} main.py \
        --do_train \
        --train_file AdvertiseGen/{json_dataset} \
        --validation_file AdvertiseGen/{json_dataset} \
        --preprocessing_num_workers 20 \
        --prompt_column content \
        --response_column summary \
        --overwrite_cache \
        --model_name_or_path THUDM/chatglm2-6b \
        --output_dir output/clothgen-chatglm2-6b-pt-{pre_seq_len}-{learning_rate} \
        --overwrite_output_dir \
        --max_source_length 256 \
        --max_target_length 256 \
        --per_device_train_batch_size 1 \
        --per_device_eval_batch_size 1 \
        --gradient_accumulation_steps 16 \
        --predict_with_generate \
        --max_steps 100 \
        --logging_steps 10 \
        --save_steps 20 \
        --learning_rate {learning_rate} \
        --pre_seq_len {pre_seq_len} \
        --quantization_bit 4"

    process = subprocess.Popen(command, shell=True, cwd=ptuning_directory)
    try:
        # Wait up to 24 hours for the training process to finish
        process.communicate(timeout=3600*24)
    except subprocess.TimeoutExpired:
        process.kill()
    return
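
# Illustrative advanced_arg string for 启动微调 (paths and values are placeholders):
#   --pre_seq_len 128 --learning_rate 2e-2 --num_gpus 1 \
#   --json_dataset t_code.json --ptuning_directory /path/to/ChatGLM2-6B/ptuning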