research14 committed on
Commit
45a7f65
Parent: 93184b1

Made test file for run_llm

Files changed (2)
  1. run_llm.py +10 -30
  2. run_llm2.py +468 -0
run_llm.py CHANGED
@@ -158,8 +158,15 @@ def main(args=None):
         whitelist_ids_parse = [tokenizer.encode(word)[1] for word in syntags]
         bad_words_ids_parse = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_parse]

+    if args.prompt == 1:
+        strategy1_qa(model, text, gid_list, tokenizer)
+    if args.prompt == 2:
+        strategy2_instruction(model, text, gid_list, tokenizer)
+    if args.prompt == 3:
+        strategy3_structured_prompt(model, text, gid_list, tokenizer, bad_words_ids_pos, bad_words_ids_bio, bad_words_ids_chunk, bad_words_ids_parse)

-    if args.prompt == 1:
+
+    def strategy1_qa(model, text, gid_list, tokenizer):
         for gid in tqdm(gid_list, desc='Query'):
             text = ptb[gid]['text']

@@ -213,8 +220,7 @@ def main(args=None):
                 with open(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.txt', 'w') as f:
                     f.write(outputs)

-
-    if args.prompt == 2:
+    def strategy2_instruction(model, text, gid_list, tokenizer):
         for gid in tqdm(gid_list, desc='Query'):
             text = ptb[gid]['text']

@@ -298,8 +304,7 @@ def main(args=None):
                     f.write(outputs)


-
-    if args.prompt == 3:
+    def strategy3_structured_prompt(model, text, gid_list, tokenizer, bad_words_ids_pos, bad_words_ids_bio, bad_words_ids_chunk, bad_words_ids_parse):
         for gid in tqdm(gid_list, desc='Query'):
             text = ptb[gid]['text']
             tokens = ptb[gid]['tokens']

@@ -446,31 +451,6 @@ def gpt3(prompt):

     return None

-def run_llm_interface(model_path, prompt, sentence):
-    import argparse
-    from run_llm import main
-
-    # Construct arguments
-    args = argparse.Namespace(
-        model_path=model_path,
-        temperature=0.7,
-        repetition_penalty=1.0,
-        max_new_tokens=512,
-        debug=False,
-        message="Hello! Who are you?",
-        start=0,
-        end=1000,
-        prompt=prompt,
-    )
-
-    # Run the main function
-    # For simplicity, assuming prompt values 1, 2, and 3 correspond to different strategies
-    # You may need to adjust this based on your actual logic
-    main(args=args)
-
-    # Return dummy values for now, replace with actual outputs
-    return "Strategy 1 Output", "Strategy 2 Output", "Strategy 3 Output"
-

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
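Note on the refactor above: as committed, main() dispatches to strategy1_qa / strategy2_instruction / strategy3_structured_prompt before the def statements that create them execute, and it passes text, which is only assigned inside each strategy's per-sentence loop, so the dispatch would raise a NameError at runtime. A minimal sketch of the intended shape (a hypothetical reordering; the strategy bodies stay as in the diff):

    def strategy1_qa(model, gid_list, tokenizer):
        # Each strategy looks up its own sentence text from the shared ptb corpus.
        for gid in tqdm(gid_list, desc='Query'):
            text = ptb[gid]['text']
            # ... per-entity querying, unchanged from the body above ...

    def main(args=None):
        gid_list = selected_idx[args.start:args.end]
        # ... model loading and bad_words_ids setup as before ...
        if args.prompt == 1:
            strategy1_qa(model, gid_list, tokenizer)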
run_llm2.py ADDED
@@ -0,0 +1,468 @@
+import os
+import sys
+import json
+import time
+import openai
+import pickle
+import argparse
+import requests
+from tqdm import tqdm
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer
+
+from fastchat.model import load_model, get_conversation_template, add_model_args
+
+from nltk.tag.mapping import _UNIVERSAL_TAGS
+
+import gradio as gr
+
+uni_tags = list(_UNIVERSAL_TAGS)
+uni_tags[-1] = 'PUNC'
+
+bio_tags = ['B', 'I', 'O']
+chunk_tags = ['ADJP', 'ADVP', 'CONJP', 'INTJ', 'LST', 'NP', 'O', 'PP', 'PRT', 'SBAR', 'UCP', 'VP']
+
+syntags = ['NP', 'S', 'VP', 'ADJP', 'ADVP', 'SBAR', 'TOP', 'PP', 'POS', 'NAC', "''", 'SINV', 'PRN', 'QP', 'WHNP', 'RB', 'FRAG',
+           'WHADVP', 'NX', 'PRT', 'VBZ', 'VBP', 'MD', 'NN', 'WHPP', 'SQ', 'SBARQ', 'LST', 'INTJ', 'X', 'UCP', 'CONJP', 'NNP', 'CD', 'JJ',
+           'VBD', 'WHADJP', 'PRP', 'RRC', 'NNS', 'SYM', 'CC']
+
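For context, nltk's _UNIVERSAL_TAGS is the 12-tag universal POS inventory, in the same order as ents_prompt_uni_tags further down; the final '.' tag is renamed to 'PUNC' so it has a word-like name the tokenizer whitelist can encode. A quick check (tag order as shipped in nltk.tag.mapping):

    from nltk.tag.mapping import _UNIVERSAL_TAGS
    print(list(_UNIVERSAL_TAGS))
    # ['VERB', 'NOUN', 'PRON', 'ADJ', 'ADV', 'ADP', 'CONJ', 'DET', 'NUM', 'PRT', 'X', '.']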
+openai.api_key = "sk-REDACTED"
+
+
+# determinant vs. determiner
+# https://wikidiff.com/determiner/determinant
+ents_prompt = ['Noun', 'Verb', 'Adjective', 'Adverb', 'Preposition/Subord', 'Coordinating Conjunction',  # 'Cardinal Number',
+               'Determiner',
+               'Noun Phrase', 'Verb Phrase', 'Adjective Phrase', 'Adverb Phrase', 'Preposition Phrase', 'Conjunction Phrase', 'Coordinate Phrase', 'Quantitave Phrase', 'Complex Nominal',
+               'Clause', 'Dependent Clause', 'Fragment Clause', 'T-unit', 'Complex T-unit',  # 'Fragment T-unit',
+               ][7:]
+ents = ['NN', 'VB', 'JJ', 'RB', 'IN', 'CC', 'DT', 'NP', 'VP', 'ADJP', 'ADVP', 'PP', 'CONJP', 'CP', 'QP', 'CN', 'C', 'DC', 'FC', 'T', 'CT'][7:]
+
+
+ents_prompt_uni_tags = ['Verb', 'Noun', 'Pronoun', 'Adjective', 'Adverb', 'Preposition and Postposition', 'Coordinating Conjunction',
+                        'Determiner', 'Cardinal Number', 'Particles or other function words',
+                        'Words that cannot be assigned a POS tag', 'Punctuation']
+
+ents = uni_tags + ents
+ents_prompt = ents_prompt_uni_tags + ents_prompt
+
+for i, j in zip(ents, ents_prompt):
+    print(i, j)
+# raise
+
+
+model_mapping = {
+    # 'gpt3': 'gpt-3',
+    'gpt3.5': 'gpt-3.5-turbo-0613',
+    'vicuna-7b': 'lmsys/vicuna-7b-v1.3',
+    'vicuna-13b': 'lmsys/vicuna-13b-v1.3',
+    'vicuna-33b': 'lmsys/vicuna-33b-v1.3',
+    'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
+    # 'llama2-7b': 'meta-llama/Llama-2-7b-hf',
+    # 'llama2-13b': 'meta-llama/Llama-2-13b-hf',
+    # 'llama2-70b': 'meta-llama/Llama-2-70b-hf',
+    'llama-7b': './llama/hf/7B',
+    'llama-13b': './llama/hf/13B',
+    'llama-30b': './llama/hf/30B',
+    # 'llama-65b': './llama/hf/65B',
+    'alpaca': './alpaca-7B',
+    # 'koala-7b': 'koala-7b',
+    # 'koala-13b': 'koala-13b',
+}
+
+for m in model_mapping.keys():
+    for eid, ent in enumerate(ents):
+        os.makedirs(f'result/prompt1_qa/{m}/ptb/per_ent/{ent}', exist_ok=True)
+
+    os.makedirs(f'result/prompt2_instruction/pos_tagging/{m}/ptb', exist_ok=True)
+    os.makedirs(f'result/prompt2_instruction/chunking/{m}/ptb', exist_ok=True)
+    os.makedirs(f'result/prompt2_instruction/parsing/{m}/ptb', exist_ok=True)
+
+    os.makedirs(f'result/prompt3_structured_prompt/pos_tagging/{m}/ptb', exist_ok=True)
+    os.makedirs(f'result/prompt3_structured_prompt/chunking/{m}/ptb', exist_ok=True)
+    os.makedirs(f'result/prompt3_structured_prompt/parsing/{m}/ptb', exist_ok=True)
+
+
+#s = int(sys.argv[1])
+#e = int(sys.argv[2])
+
+#s = 0
+#e = 1000
+with open('sample_uniform_1k_2.txt', 'r') as f:
+    selected_idx = f.readlines()
+selected_idx = [int(i.strip()) for i in selected_idx]#[s:e]
+
+
+ptb = []
+with open('ptb.jsonl', 'r') as f:
+    for l in f:
+        ptb.append(json.loads(l))
+
+
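Each line of ptb.jsonl is one JSON sentence record. From the fields read later (ptb[gid]['text'], ptb[gid]['tokens'], ptb[gid]['uni_poss']), a record presumably has this shape (illustrative values, not taken from the corpus):

    {'text': 'The cat sat .',
     'tokens': ['The', 'cat', 'sat', '.'],
     'uni_poss': ['DET', 'NOUN', 'VERB', 'PUNC']}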
+## Prompt 1
+template_all = '''Please output the <Noun, Verb, Adjective, Adverb, Preposition/Subord, Coordinating Conjunction, Cardinal Number, Determiner, Noun Phrase, Verb Phrase, Adjective Phrase, Adverb Phrase, Preposition Phrase, Conjunction Phrase, Coordinate Phrase, Quantitave Phrase, Complex Nominal, Clause, Dependent Clause, Fragment Clause, T-unit, Complex T-unit, Fragment T-unit> in the following sentence without any additional text in json format: "{}"'''
+template_single = '''Please output any <{}> in the following sentence one per line without any additional text: "{}"'''
+
+## Prompt 2
+prompt2_pos = '''Please pos tag the following sentence using Universal POS tag set without generating any additional text: {}'''
+prompt2_chunk = '''Please do sentence chunking for the following sentence as in CoNLL 2000 shared task without generating any addtional text: {}'''
+prompt2_parse = '''Generate textual representation of the constituency parse tree of the following sentence using Penn TreeBank tag set without outputing any additional text: {}'''
+
+prompt2_chunk = '''Please chunk the following sentence in CoNLL 2000 format with BIO tags without outputing any additional text: {}'''
+
+## Prompt 3
+with open('demonstration_3_42_pos.txt', 'r') as f:
+    demon_pos = f.read()
+with open('demonstration_3_42_chunk.txt', 'r') as f:
+    demon_chunk = f.read()
+with open('demonstration_3_42_parse.txt', 'r') as f:
+    demon_parse = f.read()
+
+
+def para(m):
+    c = 0
+    for n, p in m.named_parameters():
+        c += p.numel()
+    return c
+
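Two details in the block above are easy to miss: prompt2_chunk is assigned twice, so only the second (BIO-tagged CoNLL 2000) wording is ever sent to a model, and para() is a parameter-count helper that nothing in the script calls. A hypothetical use, assuming a model loaded via load_model:

    n_params = para(model)  # sums p.numel() over named_parameters()
    print(f'{n_params / 1e9:.2f}B parameters')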
+def main(args=None):
+
+    gid_list = selected_idx[args.start:args.end]
+
+
+    if 'gpt3' in args.model_path:
+        pass
+
+    else:
+        path = model_mapping[args.model_path]
+        model, tokenizer = load_model(
+            path,
+            args.device,
+            args.num_gpus,
+            args.max_gpu_memory,
+            args.load_8bit,
+            args.cpu_offloading,
+            revision=args.revision,
+            debug=args.debug,
+        )
+
+        whitelist_ids_pos = [tokenizer.encode(word)[1] for word in uni_tags]
+        bad_words_ids_pos = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_pos]
+
+        whitelist_ids_bio = [tokenizer.encode(word)[1] for word in bio_tags]
+        bad_words_ids_bio = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_bio]
+
+        whitelist_ids_chunk = [tokenizer.encode(word)[1] for word in chunk_tags]
+        bad_words_ids_chunk = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_chunk]
+
+        whitelist_ids_parse = [tokenizer.encode(word)[1] for word in syntags]
+        bad_words_ids_parse = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_parse]
+
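The four whitelist/bad_words pairs above implement constrained decoding: tokenizer.encode(word)[1] takes the first real token id after the BOS token that LLaMA-style tokenizers prepend, and every other id in the vocabulary becomes a banned single-token sequence, so generate(..., bad_words_ids=...) can emit only tag tokens. A minimal sketch of the same idea (the checkpoint name is just one of the entries in model_mapping):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('lmsys/vicuna-7b-v1.3')
    whitelist = {tok.encode(tag)[1] for tag in ['NP', 'VP', 'PP']}  # index 1 skips BOS
    # bad_words_ids expects a list of token-id lists, one per banned sequence
    bad_words = [[i] for i in range(tok.vocab_size) if i not in whitelist]

Using a set for the whitelist keeps the scan over the vocabulary cheap; the script's list version does the same membership test, just at O(n) per id.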
+
+    if args.prompt == 1:
+        for gid in tqdm(gid_list, desc='Query'):
+            text = ptb[gid]['text']
+
+            for eid, ent in enumerate(ents):
+                os.makedirs(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}', exist_ok=True)
+
+                if ent == 'NOUN' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/NOUN'):
+                    os.system(f'ln -sT ./NN result/prompt1_qa/{args.model_path}/ptb/per_ent/NOUN')
+                if ent == 'VERB' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/VERB'):
+                    os.system(f'ln -sT ./VB result/prompt1_qa/{args.model_path}/ptb/per_ent/VERB')
+                if ent == 'ADJ' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/ADJ'):
+                    os.system(f'ln -sT ./JJ result/prompt1_qa/{args.model_path}/ptb/per_ent/ADJ')
+                if ent == 'ADV' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/ADV'):
+                    os.system(f'ln -sT ./RB result/prompt1_qa/{args.model_path}/ptb/per_ent/ADV')
+                if ent == 'CONJ' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/CONJ'):
+                    os.system(f'ln -sT ./CC result/prompt1_qa/{args.model_path}/ptb/per_ent/CONJ')
+                if ent == 'DET' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/DET'):
+                    os.system(f'ln -sT ./DT result/prompt1_qa/{args.model_path}/ptb/per_ent/DET')
+                if ent == 'ADP' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/ADP'):
+                    os.system(f'ln -sT ./IN result/prompt1_qa/{args.model_path}/ptb/per_ent/ADP')
+
+                if os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.txt'):
+                    print(gid, ent, 'skip')
+                    continue
+
+
+                ## Get prompt
+                msg = template_single.format(ents_prompt[eid], text)
+
+                ## Run
+                if 'gpt3' in args.model_path:
+                    if os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.pkl'):
+                        print('Found cache')
+                        with open(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.pkl', 'rb') as f:
+                            outputs = pickle.load(f)
+                        outputs = outputs['choices'][0]['message']['content']
+                    else:
+                        outputs = gpt3(msg)
+                        if outputs is None:
+                            continue
+                        time.sleep(0.2)
+
+                else:
+                    conv = get_conversation_template(args.model_path)
+                    conv.append_message(conv.roles[0], msg)
+                    conv.append_message(conv.roles[1], None)
+                    conv.system = ''
+                    prompt = conv.get_prompt().strip()
+                    outputs = fastchat(prompt, model, tokenizer)
+
+                with open(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.txt', 'w') as f:
+                    f.write(outputs)
+
+
+    if args.prompt == 2:
+        for gid in tqdm(gid_list, desc='Query'):
+            text = ptb[gid]['text']
+
+            ## POS tagging
+            if os.path.exists(f'result/prompt2_instruction/pos_tagging/{args.model_path}/ptb/{gid}.txt'):
+                print(gid, 'skip')
+
+            else:
+                msg = prompt2_pos.format(text)
+
+                if 'gpt3' in args.model_path:
+                    outputs = gpt3(msg)
+                    if outputs is None:
+                        continue
+                    time.sleep(0.2)
+
+                else:
+                    conv = get_conversation_template(args.model_path)
+                    conv.append_message(conv.roles[0], msg)
+                    conv.append_message(conv.roles[1], None)
+                    conv.system = ''
+                    prompt = conv.get_prompt()
+
+                    outputs = fastchat(prompt, model, tokenizer)
+
+                with open(f'result/prompt2_instruction/pos_tagging/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+                    f.write(outputs)
+
+
+            ## Sentence chunking
+            if os.path.exists(f'result/prompt2_instruction/chunking/{args.model_path}/ptb/{gid}.txt'):
+                print(gid, 'skip')
+            if False:
+                pass
+            else:
+                msg = prompt2_chunk.format(text)
+
+                if 'gpt3' in args.model_path:
+                    outputs = gpt3(msg)
+                    if outputs is None:
+                        continue
+                    time.sleep(0.2)
+
+                else:
+                    conv = get_conversation_template(args.model_path)
+                    conv.append_message(conv.roles[0], msg)
+                    conv.append_message(conv.roles[1], None)
+                    conv.system = ''
+                    prompt = conv.get_prompt()
+
+                    outputs = fastchat(prompt, model, tokenizer)
+
+                print(args.model_path, gid, outputs)
+                with open(f'result/prompt2_instruction/chunking/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+                    f.write(outputs)
+
+
+            ## Parsing
+            if os.path.exists(f'result/prompt2_instruction/parsing/{args.model_path}/ptb/{gid}.txt'):
+                print(gid, 'skip')
+
+            else:
+                msg = prompt2_parse.format(text)
+
+                if 'gpt3' in args.model_path:
+                    outputs = gpt3(msg)
+                    if outputs is None:
+                        continue
+                    time.sleep(0.2)
+
+                else:
+                    conv = get_conversation_template(args.model_path)
+                    conv.append_message(conv.roles[0], msg)
+                    conv.append_message(conv.roles[1], None)
+                    conv.system = ''
+                    prompt = conv.get_prompt()
+
+                    outputs = fastchat(prompt, model, tokenizer)
+
+                with open(f'result/prompt2_instruction/parsing/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+                    f.write(outputs)
+
+
+
+    if args.prompt == 3:
+        for gid in tqdm(gid_list, desc='Query'):
+            text = ptb[gid]['text']
+            tokens = ptb[gid]['tokens']
+            poss = ptb[gid]['uni_poss']
+
+            ## POS tagging
+            if os.path.exists(f'result/prompt3_structured_prompt/pos_tagging/{args.model_path}/ptb/{gid}.txt'):
+                print(gid, 'skip')
+                continue
+
+            prompt = demon_pos + '\n' + 'C: ' + text + '\n' + 'T: '
+
+            if 'gpt3' in args.model_path:
+                outputs = gpt3(prompt)
+                if outputs is None:
+                    continue
+                time.sleep(0.2)
+
+            else:
+                pred_poss = []
+                for _tok, _pos in zip(tokens, poss):
+                    prompt = prompt + ' ' + _tok + '_'
+                    outputs = structured_prompt(prompt, model, tokenizer, bad_words_ids_pos)
+                    prompt = prompt + outputs
+                    pred_poss.append(outputs)
+
+                outputs = ' '.join(pred_poss)
+            with open(f'result/prompt3_structured_prompt/pos_tagging/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+                f.write(outputs)
+
+
+            ## Chunking
+            if os.path.exists(f'result/prompt3_structured_prompt/chunking/{args.model_path}/ptb/{gid}.txt'):
+                print(gid, 'skip')
+                continue
+
+            prompt = demon_chunk + '\n' + 'C: ' + text + '\n' + 'T: '
+
+            if 'gpt3' in args.model_path:
+                outputs = gpt3(prompt)
+                print(outputs)
+                if outputs is None:
+                    continue
+                time.sleep(0.2)
+
+            else:
+                pred_chunk = []
+                for _tok, _pos in zip(tokens, poss):
+                    prompt = prompt + ' ' + _tok + '_'
+
+                    # Generate BIO
+                    outputs_bio = structured_prompt(prompt, model, tokenizer, bad_words_ids_bio)
+                    prompt = prompt + outputs_bio + '-'
+
+                    # Generate tag
+                    outputs_chunk = structured_prompt(prompt, model, tokenizer, bad_words_ids_chunk)
+                    prompt = prompt + outputs_chunk
+
+                    pred_chunk.append((outputs_bio + '-' + outputs_chunk))
+
+                outputs = ' '.join(pred_chunk)
+
+            with open(f'result/prompt3_structured_prompt/chunking/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+                f.write(outputs)
+
+            ## Parsing
+            if os.path.exists(f'result/prompt3_structured_prompt/parsing/{args.model_path}/ptb/{gid}.txt'):
+                print(gid, 'skip')
+                continue
+
+            prompt = demon_parse + '\n' + 'C: ' + text + '\n' + 'T: '
+
+            if 'gpt3' in args.model_path:
+                outputs = gpt3(prompt)
+                if outputs is None:
+                    continue
+                time.sleep(0.2)
+
+            else:
+                pred_syn = []
+                for _tok, _pos in zip(tokens, poss):
+                    prompt = prompt + _tok + '_'
+                    outputs = structured_prompt(prompt, model, tokenizer, bad_words_ids_parse)
+                    pred_syn.append(outputs)
+
+                with open(f'result/prompt3_structured_prompt/parsing/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+                    f.write(' '.join(pred_syn))
+
+
+def structured_prompt(prompt, model, tokenizer, bad_words_ids):
+    input_ids = tokenizer([prompt]).input_ids
+    output_ids = model.generate(
+        torch.as_tensor(input_ids).cuda(),
+        max_new_tokens=1,
+        bad_words_ids=bad_words_ids,
+    )
+
+    if model.config.is_encoder_decoder:
+        output_ids = output_ids[0]
+    else:
+        output_ids = output_ids[0][len(input_ids[0]):]
+    outputs = tokenizer.decode(
+        output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
+    )
+
+    return outputs
+
+
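structured_prompt() is the core of the prompt-3 strategy: with max_new_tokens=1 and everything outside the whitelist banned, each call is forced to return exactly one tag token. The loops above drive it incrementally: append 'token_' to the transcript, ask for one tag, commit the tag, repeat. A condensed sketch of that loop (hypothetical three-word sentence; demon_pos and bad_words_ids_pos as defined earlier):

    sentence = ['The', 'cat', 'sat']
    prompt = demon_pos + '\nC: The cat sat\nT: '
    tags = []
    for tok in sentence:
        prompt += ' ' + tok + '_'   # cue the next tag slot
        tag = structured_prompt(prompt, model, tokenizer, bad_words_ids_pos)
        prompt += tag               # commit the prediction before the next token
        tags.append(tag)
    # tags holds one whitelisted POS tag per input token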
+def fastchat(prompt, model, tokenizer):
+    input_ids = tokenizer([prompt]).input_ids
+    output_ids = model.generate(
+        torch.as_tensor(input_ids).cuda(),
+        do_sample=True,
+        temperature=args.temperature,
+        repetition_penalty=args.repetition_penalty,
+        max_new_tokens=args.max_new_tokens,
+    )
+
+    if model.config.is_encoder_decoder:
+        output_ids = output_ids[0]
+    else:
+        output_ids = output_ids[0][len(input_ids[0]):]
+    outputs = tokenizer.decode(
+        output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
+    )
+
+    #print('Empty system message')
+    #print(f"{conv.roles[0]}: {msg}")
+    #print(f"{conv.roles[1]}: {outputs}")
+
+    return outputs
+
+
+def gpt3(prompt):
+    try:
+        response = openai.ChatCompletion.create(
+            model=model_mapping[args.model_path], messages=[{"role": "user", "content": prompt}])
+
+        return response['choices'][0]['message']['content']
+
+    except Exception as err:
+        print('Error')
+        print(err)
+
+        return None
+
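gpt3() deliberately swallows API errors and returns None; every caller checks for None and skips that sentence, and because outputs are cached as per-gid files, failed ids can simply be re-run. If transient rate-limit errors become a problem, a thin retry wrapper is one option (a sketch, not part of the commit):

    def gpt3_with_retry(prompt, retries=3):
        for attempt in range(retries):
            out = gpt3(prompt)
            if out is not None:
                return out
            time.sleep(2 ** attempt)  # back off 1s, 2s, 4s
        return None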
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    add_model_args(parser)
+    parser.add_argument("--temperature", type=float, default=0.7)
+    parser.add_argument("--repetition_penalty", type=float, default=1.0)
+    parser.add_argument("--max-new-tokens", type=int, default=512)
+    parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--message", type=str, default="Hello! Who are you?")
+    parser.add_argument("--start", type=int, default=0)
+    parser.add_argument("--end", type=int, default=1000)
+    parser.add_argument("--prompt", required=True, type=int, default=None)
+    # parser.add_argument("--system_msg", required=True, type=str, default='default_system_msg')
+    args = parser.parse_args()
+
+    # Reset default repetition penalty for T5 models.
+    if "t5" in args.model_path and args.repetition_penalty == 1.0:
+        args.repetition_penalty = 1.2
+
+    main(args)
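For reference, the script is launched once per model and prompt strategy. --model-path takes one of the short keys of model_mapping (fastchat's add_model_args supplies that flag along with device and GPU options), and --prompt selects the strategy. A typical invocation might look like:

    python run_llm2.py --model-path vicuna-7b --prompt 2 --start 0 --end 100

With the gpt3.5 key the local-model branch is skipped entirely, so only the OpenAI code path (and no GPU) is exercised.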