Camila Salinas Camacho committed
Commit
dce2228
1 Parent(s): 19b9274

Delete run_llm.py

Files changed (1)
  1. run_llm.py +0 -238
run_llm.py DELETED
@@ -1,238 +0,0 @@
- import os
- import sys
- import json
- import time
- import openai
- import pickle
- import argparse
- import requests
- from tqdm import tqdm
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer
-
- from fastchat.model import load_model, get_conversation_template, add_model_args
-
-
- # Read the OpenAI API key from the environment rather than hard-coding a
- # secret in the source file.
- openai.api_key = os.getenv("OPENAI_API_KEY")
-
-
- # determinant vs. determiner
- # https://wikidiff.com/determiner/determinant
- ents_prompt = [
-     'Noun',
-     'Verb',
-     'Adjective',
-     'Adverb',
-     'Preposition/Subord',
-     'Coordinating Conjunction',
-     # 'Cardinal Number',
-     'Determiner',
-     'Noun Phrase',
-     'Verb Phrase',
-     'Adjective Phrase',
-     'Adverb Phrase',
-     'Preposition Phrase',
-     'Conjunction Phrase',
-     'Coordinate Phrase',
-     'Quantitative Phrase',
-     'Complex Nominal',
-     'Clause',
-     'Dependent Clause',
-     'Fragment Clause',
-     'T-unit',
-     'Complex T-unit',
-     # 'Fragment T-unit',
- ]
- ents = ['NN', 'VB', 'JJ', 'RB', 'IN', 'CC', 'DT', 'NP', 'VP', 'ADJP', 'ADVP', 'PP', 'CONJP', 'CP', 'QP', 'CN', 'C', 'DC', 'FC', 'T', 'CT']
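- # The short tags in ents parallel ents_prompt one-to-one
- # (NN=Noun, VB=Verb, ..., CT=Complex T-unit), so eid indexes both lists.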
-
-
- model_mapping = {
-     # 'gpt3': 'gpt-3',
-     'gpt3.5': 'gpt-3.5-turbo-0613',
-     'vicuna-7b': 'lmsys/vicuna-7b-v1.3',
-     'vicuna-13b': 'lmsys/vicuna-13b-v1.3',
-     'vicuna-33b': 'lmsys/vicuna-33b-v1.3',
-     'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
-     # 'llama2': 'meta-llama/Llama-2-7b-chat-hf',
-     'llama-7b': '/data/jiali/llama/hf/7B',
-     'llama-13b': '/data/jiali/llama/hf/13B',
-     'llama-30b': '/data/jiali/llama/hf/30B',
-     'llama-65b': '/data/jiali/llama/hf/65B',
-     'alpaca': '/data/jiali/alpaca-7B',
-     # 'koala-7b': 'koala-7b',
-     # 'koala-13b': 'koala-13b',
- }
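- # NOTE: the gpt3.5 entry is an OpenAI API model name and the vicuna/fastchat
- # entries are Hugging Face Hub IDs, while the llama/alpaca entries are local
- # checkpoint paths from the original setup; point them at your own checkpoints.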
-
- # Pre-create every output directory so the file writes below cannot fail.
- for m in model_mapping.keys():
-     for eid, ent in enumerate(ents):
-         os.makedirs(f'result/openai_result/{m}/ptb/per_ent/{ent}', exist_ok=True)
-     os.makedirs(f'result/structured_prompt/{m}/ptb', exist_ok=True)
-
-
- # Index range into the sampled sentence list; the commented lines show the
- # original sys.argv override (note the --start/--end flags below are parsed
- # but never applied here).
- # s = int(sys.argv[1])
- # e = int(sys.argv[2])
- s = 0
- e = 1000
- with open('ptb_corpus/sample_uniform_1k_2.txt', 'r') as f:
-     selected_idx = f.readlines()
- selected_idx = [int(i.strip()) for i in selected_idx][s:e]
-
-
- # Load the PTB sentences, one JSON object per line.
- ptb = []
- with open('./ptb_corpus/ptb.jsonl', 'r') as f:
-     for l in f:
-         ptb.append(json.loads(l))
-
-
- ## Prompt 1
- template_all = '''Please output the <Noun, Verb, Adjective, Adverb, Preposition/Subord, Coordinating Conjunction, Cardinal Number, Determiner, Noun Phrase, Verb Phrase, Adjective Phrase, Adverb Phrase, Preposition Phrase, Conjunction Phrase, Coordinate Phrase, Quantitative Phrase, Complex Nominal, Clause, Dependent Clause, Fragment Clause, T-unit, Complex T-unit, Fragment T-unit> in the following sentence without any additional text in json format: "{}"'''
- template_single = '''Please output any <{}> in the following sentence one per line without any additional text: "{}"'''
-
- ## Prompt 2
- with open('ptb_corpus/structured_prompting_demonstration_42.txt', 'r') as f:
-     demonstration = f.read()
-
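- # Prompt 1 asks for one constituent type at a time via template_single
- # (template_all is defined but never used below); Prompt 2 prepends a fixed
- # few-shot demonstration and requests the full structured parse in one query.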
-
- def para(m):
-     """Return the total parameter count of a model (helper, currently unused)."""
-     c = 0
-     for n, p in m.named_parameters():
-         c += p.numel()
-     return c
-
- def main(args):
-
-     if 'gpt3' in args.model:
-         # OpenAI models are queried over the API; nothing to load locally.
-         pass
-
-     else:
-         path = model_mapping[args.model]
-         model, tokenizer = load_model(
-             path,
-             args.device,
-             args.num_gpus,
-             args.max_gpu_memory,
-             args.load_8bit,
-             args.cpu_offloading,
-             revision=args.revision,
-             debug=args.debug,
-         )
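- # load_model and the device/GPU/8-bit flags used above come from FastChat's
- # model utilities; add_model_args registers them on the parser in __main__.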
-
-     if args.prompt == 1:
-         for gid in tqdm(selected_idx, desc='Query'):
-             text = ptb[gid]['text']
-
-             for eid, ent in enumerate(ents):
-                 # if os.path.exists(f'result/openai_result/{args.model}/ptb/per_ent/{ent}/{gid}.pkl') or \
-                 #    os.path.exists(f'result/openai_result/{args.model}/ptb/per_ent/{ent}/{gid}.txt'):
-                 #     print(gid, ent, 'skip')
-                 #     continue
-
-                 ## Get prompt
-                 msg = template_single.format(ents_prompt[eid], text)
-
-                 if 'gpt' in args.model:
-                     prompt = msg
-
-                 elif 'vicuna' in args.model or 'alpaca' in args.model or 'fastchat-t5' in args.model:
-                     conv = get_conversation_template(args.model)
-                     conv.append_message(conv.roles[0], msg)
-                     conv.append_message(conv.roles[1], None)
-                     conv.system = ''
-                     prompt = conv.get_prompt().strip()
-
-                 elif 'llama-' in args.model:
-                     prompt = '### Human: ' + msg + ' ### Assistant:'
-
-                 ## Run
-                 if 'gpt3' in args.model:
-                     outputs = gpt3(prompt)
-                 else:
-                     outputs = fastchat(prompt, model, tokenizer)
-
-                 with open(f'result/openai_result/{args.model}/ptb/per_ent/{ent}/{gid}.txt', 'w') as f:
-                     f.write(outputs)
-
-     if args.prompt == 2:
-         for gid in tqdm(selected_idx, desc='Query'):
-             text = ptb[gid]['text']
-
-             # Skip sentences that already have a saved result.
-             if os.path.exists(f'result/structured_prompt/{args.model}/ptb/{gid}.pkl') or \
-                os.path.exists(f'result/structured_prompt/{args.model}/ptb/{gid}.txt'):
-                 print(gid, 'skip')
-                 continue
-
-             prompt = demonstration + '\n' + text
-
-             if 'gpt3' in args.model:
-                 outputs = gpt3(prompt)
-             else:
-                 outputs = fastchat(prompt, model, tokenizer)
-
-             with open(f'result/structured_prompt/{args.model}/ptb/{gid}.txt', 'w') as f:
-                 f.write(outputs)
-
- def fastchat(prompt, model, tokenizer):
-     # Sampling settings are read from the module-level args parsed in __main__.
-     input_ids = tokenizer([prompt]).input_ids
-     output_ids = model.generate(
-         torch.as_tensor(input_ids).cuda(),
-         do_sample=True,
-         temperature=args.temperature,
-         repetition_penalty=args.repetition_penalty,
-         max_new_tokens=args.max_new_tokens,
-     )
-
-     # Encoder-decoder models (e.g. fastchat-t5) return only new tokens;
-     # decoder-only models echo the prompt, so strip it before decoding.
-     if model.config.is_encoder_decoder:
-         output_ids = output_ids[0]
-     else:
-         output_ids = output_ids[0][len(input_ids[0]):]
-     outputs = tokenizer.decode(
-         output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
-     )
-
-     return outputs
-
-
- def gpt3(prompt):
-     try:
-         response = openai.ChatCompletion.create(
-             # Map the CLI alias (e.g. 'gpt3.5') to the real API model name.
-             model=model_mapping[args.model],
-             messages=[{"role": "user", "content": prompt}])
-
-         # Return only the completion text so it can be written to disk,
-         # matching what fastchat() returns.
-         return response['choices'][0]['message']['content']
-
-     except Exception as err:
-         print('Error')
-         print(err)
-
-         # time.sleep(1)
-         raise
-
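- # A failed API call is logged and re-raised, aborting the sweep; the
- # commented-out time.sleep marks where a retry/backoff could go instead.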
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     add_model_args(parser)
-     parser.add_argument("--temperature", type=float, default=0.7)
-     parser.add_argument("--repetition_penalty", type=float, default=1.0)
-     parser.add_argument("--max-new-tokens", type=int, default=512)
-     parser.add_argument("--debug", action="store_true")
-     parser.add_argument("--message", type=str, default="Hello! Who are you?")
-     parser.add_argument("--start", type=int, default=0)
-     parser.add_argument("--end", type=int, default=1)
-     parser.add_argument("--model", required=True, type=str, default=None)
-     parser.add_argument("--prompt", required=True, type=int, default=None)
-     args = parser.parse_args()
-
-     # Reset default repetition penalty for T5 models.
-     if "t5" in args.model and args.repetition_penalty == 1.0:
-         args.repetition_penalty = 1.2
-
-     main(args)