research14 committed on
Commit
3a6a9b6
1 Parent(s): dce2228

Added run_llm

Files changed (1)
  1. run_llm.py +238 -0
run_llm.py ADDED
@@ -0,0 +1,238 @@
import os
import sys
import json
import time
import openai
import pickle
import argparse

import torch
from tqdm import tqdm

from fastchat.model import load_model, get_conversation_template, add_model_args


# Read the OpenAI API key from the environment rather than hard-coding it.
openai.api_key = os.environ.get("OPENAI_API_KEY")

# determinant vs. determiner
# https://wikidiff.com/determiner/determinant
ents_prompt = [
    'Noun',
    'Verb',
    'Adjective',
    'Adverb',
    'Preposition/Subord',
    'Coordinating Conjunction',
    # 'Cardinal Number',
    'Determiner',
    'Noun Phrase',
    'Verb Phrase',
    'Adjective Phrase',
    'Adverb Phrase',
    'Preposition Phrase',
    'Conjunction Phrase',
    'Coordinate Phrase',
    'Quantitative Phrase',
    'Complex Nominal',
    'Clause',
    'Dependent Clause',
    'Fragment Clause',
    'T-unit',
    'Complex T-unit',
    # 'Fragment T-unit',
]
ents = ['NN', 'VB', 'JJ', 'RB', 'IN', 'CC', 'DT', 'NP', 'VP', 'ADJP', 'ADVP', 'PP', 'CONJP', 'CP', 'QP', 'CN', 'C', 'DC', 'FC', 'T', 'CT']

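# Sanity check (editor's addition): ents and ents_prompt must stay
# index-aligned, since ents_prompt[eid] is paired with ents[eid] below.
assert len(ents) == len(ents_prompt)
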
model_mapping = {
    # 'gpt3': 'gpt-3',
    'gpt3.5': 'gpt-3.5-turbo-0613',
    'vicuna-7b': 'lmsys/vicuna-7b-v1.3',
    'vicuna-13b': 'lmsys/vicuna-13b-v1.3',
    'vicuna-33b': 'lmsys/vicuna-33b-v1.3',
    'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
    # 'llama2': 'meta-llama/Llama-2-7b-chat-hf',
    'llama-7b': '/data/jiali/llama/hf/7B',
    'llama-13b': '/data/jiali/llama/hf/13B',
    'llama-30b': '/data/jiali/llama/hf/30B',
    'llama-65b': '/data/jiali/llama/hf/65B',
    'alpaca': '/data/jiali/alpaca-7B',
    # 'koala-7b': 'koala-7b',
    # 'koala-13b': 'koala-13b',
}

# Pre-create the output directories for every model and entity type.
for m in model_mapping.keys():
    for ent in ents:
        os.makedirs(f'result/openai_result/{m}/ptb/per_ent/{ent}', exist_ok=True)
    os.makedirs(f'result/structured_prompt/{m}/ptb', exist_ok=True)

# s = int(sys.argv[1])
# e = int(sys.argv[2])

# Slice of the 1k uniformly sampled PTB sentence ids to process.
s = 0
e = 1000
with open('ptb_corpus/sample_uniform_1k_2.txt', 'r') as f:
    selected_idx = f.readlines()
    selected_idx = [int(i.strip()) for i in selected_idx][s:e]


ptb = []
with open('./ptb_corpus/ptb.jsonl', 'r') as f:
    for l in f:
        ptb.append(json.loads(l))

## Prompt 1
template_all = '''Please output the <Noun, Verb, Adjective, Adverb, Preposition/Subord, Coordinating Conjunction, Cardinal Number, Determiner, Noun Phrase, Verb Phrase, Adjective Phrase, Adverb Phrase, Preposition Phrase, Conjunction Phrase, Coordinate Phrase, Quantitative Phrase, Complex Nominal, Clause, Dependent Clause, Fragment Clause, T-unit, Complex T-unit, Fragment T-unit> in the following sentence without any additional text in json format: "{}"'''
template_single = '''Please output any <{}> in the following sentence one per line without any additional text: "{}"'''

## Prompt 2
with open('ptb_corpus/structured_prompting_demonstration_42.txt', 'r') as f:
    demonstration = f.read()

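# For illustration (editor's note, not in the original script), a rendered
# template_single prompt for ents_prompt[0] = 'Noun' and a sample sentence:
#   Please output any <Noun> in the following sentence one per line without any additional text: "The cat sat on the mat."
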
def para(m):
    """Return the total number of parameters in model m."""
    c = 0
    for n, p in m.named_parameters():
        c += p.numel()
    return c

def main(args=None):

    if 'gpt3' in args.model:
        # OpenAI models are queried over the API; no local weights to load.
        pass

    else:
        path = model_mapping[args.model]
        model, tokenizer = load_model(
            path,
            args.device,
            args.num_gpus,
            args.max_gpu_memory,
            args.load_8bit,
            args.cpu_offloading,
            revision=args.revision,
            debug=args.debug,
        )

    if args.prompt == 1:
        for gid in tqdm(selected_idx, desc='Query'):
            text = ptb[gid]['text']

            for eid, ent in enumerate(ents):
                # if os.path.exists(f'result/openai_result/{args.model}/ptb/per_ent/{ent}/{gid}.pkl') or \
                #     os.path.exists(f'result/openai_result/{args.model}/ptb/per_ent/{ent}/{gid}.txt'):
                #     print(gid, ent, 'skip')
                #     continue

                ## Get prompt
                msg = template_single.format(ents_prompt[eid], text)

                if 'gpt' in args.model:
                    prompt = msg

                elif 'vicuna' in args.model or 'alpaca' in args.model or 'fastchat-t5' in args.model:
                    conv = get_conversation_template(args.model)
                    conv.append_message(conv.roles[0], msg)
                    conv.append_message(conv.roles[1], None)
                    conv.system = ''
                    prompt = conv.get_prompt().strip()

                elif 'llama-' in args.model:
                    prompt = '### Human: ' + msg + ' ### Assistant:'

                else:
                    # Fail loudly instead of silently reusing a stale prompt
                    # from the previous iteration.
                    raise ValueError(f'Unknown model: {args.model}')

                ## Run
                if 'gpt3' in args.model:
                    outputs = gpt3(prompt)

                else:
                    outputs = fastchat(prompt, model, tokenizer)

                with open(f'result/openai_result/{args.model}/ptb/per_ent/{ent}/{gid}.txt', 'w') as f:
                    f.write(outputs)

    if args.prompt == 2:
        for gid in tqdm(selected_idx, desc='Query'):
            text = ptb[gid]['text']

            if os.path.exists(f'result/structured_prompt/{args.model}/ptb/{gid}.pkl') or \
                os.path.exists(f'result/structured_prompt/{args.model}/ptb/{gid}.txt'):
                print(gid, 'skip')
                continue

            prompt = demonstration + '\n' + text

            if 'gpt3' in args.model:
                outputs = gpt3(prompt)

            else:
                outputs = fastchat(prompt, model, tokenizer)

            with open(f'result/structured_prompt/{args.model}/ptb/{gid}.txt', 'w') as f:
                f.write(outputs)

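# A minimal post-processing sketch (editor's addition, not called anywhere):
# it shows how the per-entity .txt outputs written above could be read back
# into a {tag: [lines]} dict for one sentence id. The path layout matches the
# writes in main(); everything else is an assumption.
def read_per_ent_outputs(model_name, gid):
    results = {}
    for ent in ents:
        path = f'result/openai_result/{model_name}/ptb/per_ent/{ent}/{gid}.txt'
        if os.path.exists(path):
            with open(path, 'r') as f:
                results[ent] = [line.strip() for line in f if line.strip()]
    return results
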
def fastchat(prompt, model, tokenizer):
    """Generate a completion for prompt with a locally loaded model.

    Relies on the module-level `args` for sampling hyperparameters.
    """
    input_ids = tokenizer([prompt]).input_ids
    output_ids = model.generate(
        torch.as_tensor(input_ids).cuda(),
        do_sample=True,
        temperature=args.temperature,
        repetition_penalty=args.repetition_penalty,
        max_new_tokens=args.max_new_tokens,
    )

    if model.config.is_encoder_decoder:
        output_ids = output_ids[0]
    else:
        # Strip the echoed prompt tokens for decoder-only models.
        output_ids = output_ids[0][len(input_ids[0]):]
    outputs = tokenizer.decode(
        output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
    )

    # print('Empty system message')
    # print(f"{conv.roles[0]}: {msg}")
    # print(f"{conv.roles[1]}: {outputs}")

    return outputs

def gpt3(prompt):
    try:
        response = openai.ChatCompletion.create(
            model=model_mapping[args.model], messages=[{"role": "user", "content": prompt}])

        # Return the completion text itself; callers write it straight to disk.
        return response['choices'][0]['message']['content']

    except Exception as err:
        print('Error')
        print(err)

        # time.sleep(1)
        raise

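# A hedged retry sketch (editor's addition): the commented-out time.sleep(1)
# above suggests rate-limit handling was planned. One simple option is
# exponential backoff around gpt3(); retries/base_delay are assumptions.
def gpt3_with_retry(prompt, retries=3, base_delay=1.0):
    for attempt in range(retries):
        try:
            return gpt3(prompt)
        except Exception:
            if attempt == retries - 1:
                raise
            time.sleep(base_delay * (2 ** attempt))
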
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_model_args(parser)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--message", type=str, default="Hello! Who are you?")
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int, default=1)
    parser.add_argument("--model", required=True, type=str)
    parser.add_argument("--prompt", required=True, type=int)
    args = parser.parse_args()

    # Reset default repetition penalty for T5 models.
    if "t5" in args.model and args.repetition_penalty == 1.0:
        args.repetition_penalty = 1.2

    main(args)
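
# Example invocations (editor's note; model keys come from model_mapping above):
#   python run_llm.py --model vicuna-7b --prompt 1
#   OPENAI_API_KEY=... python run_llm.py --model gpt3.5 --prompt 2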