This model was obtained by fine-tuning baichuan2-13b-base with the [Firefly](https://github.com/yangjianxin1/Firefly) project. The training data consists of roughly one million multi-turn dialogues: the moss data shared by the project, plus 20k school math samples.

For more details, see the project: [Firefly](https://github.com/yangjianxin1/Firefly)
CMMLU leaderboard:

| Model                     | CMMLU     |
|---------------------------|-----------|
| Baichuan2-13B-Chat        | 58.4      |
| **firefly-baichuan2-13b** | **56.83** |
| WeMix-LLaMA2-70B          | 56        |
| ChatGPT                   | 53.9      |
| InternLM-20B-Chat         | 52.2      |
| Baichuan-13B-Chat         | 50.7      |
| chinese-alpaca-2-13b      | 45.17     |
| LLaMA-2-70B-Chat          | 43.3      |
| openbuddy-llama2-13b-v8.1 | 41.66     |
| belle-llama2-13b          | 41.57     |
| ziya-llama-13b            | 39.9      |
| chinese-alpaca-plus-13b   | 39.9      |
| flagalpha-llama2-13b-chat | 39.20     |
| llama-2-13b-chat          | 38.65     |
| yayi-13b-llama2           | 30.73     |
| linly-llama2-13b          | 26.32     |
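Both inference scripts below assemble inputs in Firefly's flattened dialogue layout: `<bos> user_1 <eos> assistant_1 <eos> user_2 <eos> ...`, with the special tokens concatenated as ids rather than as text. A minimal sketch of that layout; the `build_prompt_ids` helper is illustrative only and not part of the project:

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    'YeungNLP/firefly-baichuan2-13b', trust_remote_code=True
)

def build_prompt_ids(turns):
    """turns: alternating user/assistant strings, ending with the new user turn."""
    # <bos>, then each turn followed by <eos>, all concatenated as ids
    ids = [tokenizer.bos_token_id]
    for turn in turns:
        ids += tokenizer(turn, add_special_tokens=False).input_ids
        ids.append(tokenizer.eos_token_id)
    return torch.tensor([ids], dtype=torch.long)

print(build_prompt_ids(['你好']).shape)  # torch.Size([1, N])
```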
Single-turn dialogue:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
"""
Single-turn dialogue; the script keeps no memory of previous turns.
"""


def main():
    model_name = 'YeungNLP/firefly-baichuan2-13b'

    max_new_tokens = 500
    top_p = 0.9
    temperature = 0.35
    repetition_penalty = 1.0
    device = 'cuda'
    # device_map='auto' already places the weights, so no explicit .to(device) is needed
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map='auto'
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        # llama does not support the fast tokenizer
        use_fast=False if model.config.model_type == 'llama' else True
    )
    # QWenTokenizer is special: pad_token_id, bos_token_id and eos_token_id are all None;
    # eod_id corresponds to the token <|endoftext|>
    if tokenizer.__class__.__name__ == 'QWenTokenizer':
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id

    text = input('User:')
    while True:
        text = text.strip()
        # chatglm uses its official prompt format
        if model.config.model_type == 'chatglm':
            text = '[Round 1]\n\n问:{}\n\n答:'.format(text)
            input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
        # otherwise wrap the input as <bos>text<eos>; the special tokens are concatenated
        # as ids rather than tokenized as text, for compatibility with qwen-7b, whose
        # tokenizer splits the eos_token string and cannot recover the eos_token_id
        else:
            input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
            bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(device)
            eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(device)
            input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,
                top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty,
                eos_token_id=tokenizer.eos_token_id
            )
        # keep only the newly generated tokens, then decode
        outputs = outputs.tolist()[0][len(input_ids[0]):]
        response = tokenizer.decode(outputs)
        response = response.strip().replace(tokenizer.eos_token, "").strip()
        print("Firefly:{}".format(response))
        text = input('User:')


if __name__ == '__main__':
    main()
```
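In float16 the 13B model needs roughly 26 GB of GPU memory (13B parameters x 2 bytes). If that is tight, it can optionally be loaded in 4-bit via bitsandbytes; a sketch, assuming the `bitsandbytes` package is installed (generation quality may degrade slightly):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Optional: quantize to 4-bit on load to cut GPU memory roughly 4x.
# Everything else in the scripts stays unchanged.
model = AutoModelForCausalLM.from_pretrained(
    'YeungNLP/firefly-baichuan2-13b',
    trust_remote_code=True,
    device_map='auto',
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16
    )
).eval()
```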
Multi-turn dialogue:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


def main():
    model_name = 'YeungNLP/firefly-baichuan2-13b'

    device = 'cuda'
    max_new_tokens = 500    # maximum number of tokens generated per turn
    history_max_len = 1000  # maximum number of history tokens the model sees
    top_p = 0.9
    temperature = 0.35
    repetition_penalty = 1.0

    # load the model; device_map='auto' already places the weights
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map='auto'
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        # llama does not support the fast tokenizer
        use_fast=False if model.config.model_type == 'llama' else True
    )
    # QWenTokenizer is special: pad_token_id, bos_token_id and eos_token_id are all None;
    # eod_id corresponds to the token <|endoftext|>
    if tokenizer.__class__.__name__ == 'QWenTokenizer':
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id

    # accumulate the entire conversation history here
    if model.config.model_type != 'chatglm':
        history_token_ids = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long)
    else:
        history_token_ids = torch.tensor([[]], dtype=torch.long)

    # start the conversation
    utterance_id = 0  # current turn number, required by chatglm's prompt format
    user_input = input('User:')
    while True:
        utterance_id += 1
        # chatglm uses its official prompt format
        if model.config.model_type == 'chatglm':
            user_input = '[Round {}]\n\n问:{}\n\n答:'.format(utterance_id, user_input)
            user_input_ids = tokenizer(user_input, return_tensors="pt", add_special_tokens=False).input_ids
        # firefly's prompt format: append <eos> as an id rather than tokenizing it,
        # for compatibility with qwen-7b, whose tokenizer splits the eos_token string
        # and cannot recover the eos_token_id
        else:
            input_ids = tokenizer(user_input, return_tensors="pt", add_special_tokens=False).input_ids
            eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long)
            user_input_ids = torch.concat([input_ids, eos_token_id], dim=1)
        history_token_ids = torch.concat((history_token_ids, user_input_ids), dim=1)
        # only the most recent history_max_len tokens are fed to the model
        model_input_ids = history_token_ids[:, -history_max_len:].to(device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=model_input_ids, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p,
                temperature=temperature, repetition_penalty=repetition_penalty, eos_token_id=tokenizer.eos_token_id
            )
        model_input_ids_len = model_input_ids.size(1)
        response_ids = outputs[:, model_input_ids_len:]
        # append the response to the history so later turns can see it
        history_token_ids = torch.concat((history_token_ids, response_ids.cpu()), dim=1)
        response = tokenizer.batch_decode(response_ids)
        print("Firefly:" + response[0].strip().replace(tokenizer.eos_token, ""))
        user_input = input('User:')


if __name__ == '__main__':
    main()
```
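Note how the multi-turn script bounds context length: the full history stays on the CPU, and only the most recent `history_max_len` tokens are sent to the model each turn via `history_token_ids[:, -history_max_len:]`. A tiny self-contained check of that sliding window:

```python
import torch

history_max_len = 1000

# pretend the accumulated conversation is 1200 tokens long
history_token_ids = torch.arange(1200, dtype=torch.long).unsqueeze(0)  # shape [1, 1200]

# only the newest 1000 tokens reach the model; the oldest 200 fall out of context
model_input_ids = history_token_ids[:, -history_max_len:]
print(model_input_ids.shape)   # torch.Size([1, 1000])
print(model_input_ids[0, 0])   # tensor(200): the first 200 tokens were dropped
```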