research14 committed on
Commit 5e8be56
1 Parent(s): 45a7f65

Applied edits

Files changed (2):
  1. app.py +41 -28
  2. run_llm.py +0 -371
app.py CHANGED
@@ -1,38 +1,51 @@
-# app.py
-
 import gradio as gr
-from run_llm import run_llm_interface
+from run_llm import template_all, prompt2_pos, prompt2_chunk, prompt2_parse, demon_pos, demon_chunk, demon_parse, model_mapping
 
-theme = gr.themes.Soft()
-
-# 3 inputs:
-#  - An input text which will be a random string
-#  - First dropdown to select the task (POS, Chunking, Parsing)
-#  - Second dropdown to select the model type
-# use run_llm.py to feed the models and then output 3 results in 3 output boxes, one for each strategy (strategy 1, 2 and 3)
-
-# Define example instructions for testing
-#instruction_examples = [
-#    ["Describe the origin of the universe"],
-#    ["Explain the concept of artificial intelligence"],
-#    ["Describe the most common types of cancer"],
-#]
-
-with gr.Interface(
-    fn=run_llm_interface,
+# Function to process text based on the selected model and task
+def process_text(model_name, task, text):
+    # Define the prompt for each strategy based on the task
+    strategy_prompts = {
+        'Strategy 1': template_all.format(text),
+        'Strategy 2': {
+            'POS': prompt2_pos.format(text),
+            'Chunking': prompt2_chunk.format(text),
+            'Parsing': prompt2_parse.format(text),
+        }.get(task, "Invalid Task Selection for Strategy 2"),
+        'Strategy 3': {
+            'POS': demon_pos,
+            'Chunking': demon_chunk,
+            'Parsing': demon_parse,
+        }.get(task, "Invalid Task Selection for Strategy 3"),
+    }
+
+    # TODO: feed each prompt to the model selected by model_name and collect
+    # its real output. Until then, return the prompts themselves so that each
+    # of the three output boxes below receives a value.
+    return (
+        strategy_prompts['Strategy 1'],
+        strategy_prompts['Strategy 2'],
+        strategy_prompts['Strategy 3'],
+    )
+
+# Dropdown options for model and task
+model_options = list(model_mapping.keys())
+task_options = ['POS', 'Chunking', 'Parsing']
+
+# Gradio interface
+iface = gr.Interface(
+    fn=process_text,
     inputs=[
-        gr.Dropdown(['gpt3.5', 'vicuna-7b', 'vicuna-13b', 'fastchat-t5', 'llama-7b', 'llama-13b', 'llama-30b', 'alpaca'], label="Select Model", default='gpt3.5', key="model_path"),
-        gr.Dropdown(['POS Tagging', 'Chunking', 'Parsing'], label="Select Task", default='POS Tagging', key="prompt"),
-        gr.Textbox("", label="Enter Sentence", key="sentence", placeholder="Enter a sentence..."),
+        gr.Dropdown(model_options, label="Select Model"),
+        gr.Dropdown(task_options, label="Select Task"),
+        gr.Textbox(label="Input Text", placeholder="Enter the text to process..."),
     ],
     outputs=[
-        gr.Textbox("", label="Strategy 1 Output", key="output_1", readonly=True),
-        gr.Textbox("", label="Strategy 2 Output", key="output_2", readonly=True),
-        gr.Textbox("", label="Strategy 3 Output", key="output_3", readonly=True),
+        gr.Textbox(label="Strategy 1 QA Result"),
+        gr.Textbox(label="Strategy 2 Instruction Result"),
+        gr.Textbox(label="Strategy 3 Structured Prompting Result"),
     ],
-    #examples=instruction_examples,
     live=False,
-    title="LLM Evaluator with Linguistic Scrutiny",
-    theme=theme
-) as iface:
-    iface.launch()
+)
+
+iface.launch()
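As committed, process_text returns its prompts unmodified. A minimal sketch of one way to back it with a real model, mirroring the gpt3() helper this commit deletes from run_llm.py below (legacy openai<1.0 client; query_model and the default model id are illustrative assumptions, not part of the repo):

    import openai

    def query_model(prompt: str, model: str = "gpt-3.5-turbo") -> str:
        # Hypothetical helper: send one strategy prompt as a single user
        # message and return the assistant's reply text.
        response = openai.ChatCompletion.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
        )
        return response['choices'][0]['message']['content']

process_text could then return query_model(p) for each strategy prompt p instead of the placeholders.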
run_llm.py CHANGED
@@ -70,24 +70,6 @@ model_mapping = {
     # 'koala-13b': 'koala-13b',
 }
 
-for m in model_mapping.keys():
-    for eid, ent in enumerate(ents):
-        os.makedirs(f'result/prompt1_qa/{m}/ptb/per_ent/{ent}', exist_ok=True)
-
-    os.makedirs(f'result/prompt2_instruction/pos_tagging/{m}/ptb', exist_ok=True)
-    os.makedirs(f'result/prompt2_instruction/chunking/{m}/ptb', exist_ok=True)
-    os.makedirs(f'result/prompt2_instruction/parsing/{m}/ptb', exist_ok=True)
-
-    os.makedirs(f'result/prompt3_structured_prompt/pos_tagging/{m}/ptb', exist_ok=True)
-    os.makedirs(f'result/prompt3_structured_prompt/chunking/{m}/ptb', exist_ok=True)
-    os.makedirs(f'result/prompt3_structured_prompt/parsing/{m}/ptb', exist_ok=True)
-
-
-#s = int(sys.argv[1])
-#e = int(sys.argv[2])
-
-#s = 0
-#e = 1000
 with open('sample_uniform_1k_2.txt', 'r') as f:
     selected_idx = f.readlines()
 selected_idx = [int(i.strip()) for i in selected_idx]#[s:e]
@@ -118,356 +100,3 @@ with open('demonstration_3_42_chunk.txt', 'r') as f:
 with open('demonstration_3_42_parse.txt', 'r') as f:
     demon_parse = f.read()
 
-
-def para(m):
-    c = 0
-    for n, p in m.named_parameters():
-        c += p.numel()
-    return c
-
-def main(args=None):
-
-    gid_list = selected_idx[args.start:args.end]
-
-
-    if 'gpt3' in args.model_path:
-        pass
-
-    else:
-        path = model_mapping[args.model_path]
-        model, tokenizer = load_model(
-            path,
-            args.device,
-            args.num_gpus,
-            args.max_gpu_memory,
-            args.load_8bit,
-            args.cpu_offloading,
-            revision=args.revision,
-            debug=args.debug,
-        )
-
-        whitelist_ids_pos = [tokenizer.encode(word)[1] for word in uni_tags]
-        bad_words_ids_pos = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_pos]
-
-        whitelist_ids_bio = [tokenizer.encode(word)[1] for word in bio_tags]
-        bad_words_ids_bio = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_bio]
-
-        whitelist_ids_chunk = [tokenizer.encode(word)[1] for word in chunk_tags]
-        bad_words_ids_chunk = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_chunk]
-
-        whitelist_ids_parse = [tokenizer.encode(word)[1] for word in syntags]
-        bad_words_ids_parse = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_parse]
-
-    if args.prompt == 1:
-        strategy1_qa(model, text, gid_list, tokenizer)
-    if args.prompt == 2:
-        strategy2_instruction(model, text, gid_list, tokenizer)
-    if args.prompt == 3:
-        strategy3_structured_prompt(model, text, gid_list, tokenizer, bad_words_ids_pos, bad_words_ids_bio, bad_words_ids_chunk, bad_words_ids_parse)
-
-
-def strategy1_qa(model, text, gid_list, tokenizer):
-    for gid in tqdm(gid_list, desc='Query'):
-        text = ptb[gid]['text']
-
-        for eid, ent in enumerate(ents):
-            os.makedirs(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}', exist_ok=True)
-
-            if ent == 'NOUN' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/NOUN'):
-                os.system(f'ln -sT ./NN result/prompt1_qa/{args.model_path}/ptb/per_ent/NOUN')
-            if ent == 'VERB' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/VERB'):
-                os.system(f'ln -sT ./VB result/prompt1_qa/{args.model_path}/ptb/per_ent/VERB')
-            if ent == 'ADJ' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/ADJ'):
-                os.system(f'ln -sT ./JJ result/prompt1_qa/{args.model_path}/ptb/per_ent/ADJ')
-            if ent == 'ADV' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/ADV'):
-                os.system(f'ln -sT ./RB result/prompt1_qa/{args.model_path}/ptb/per_ent/ADV')
-            if ent == 'CONJ' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/CONJ'):
-                os.system(f'ln -sT ./CC result/prompt1_qa/{args.model_path}/ptb/per_ent/CONJ')
-            if ent == 'DET' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/DET'):
-                os.system(f'ln -sT ./DT result/prompt1_qa/{args.model_path}/ptb/per_ent/DET')
-            if ent == 'ADP' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/ADP'):
-                os.system(f'ln -sT ./DT result/prompt1_qa/{args.model_path}/ptb/per_ent/IN')
-
-            if os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.txt'):
-                print(gid, ent, 'skip')
-                continue
-
-
-            ## Get prompt
-            msg = template_single.format(ents_prompt[eid], text)
-
-            ## Run
-            if 'gpt3' in args.model_path:
-                if os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.pkl'):
-                    print('Found cache')
-                    with open(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.pkl', 'rb') as f:
-                        outputs = pickle.load(f)
-                    outputs = outputs['choices'][0]['message']['content']
-                else:
-                    outputs = gpt3(msg)
-                    if outputs is None:
-                        continue
-                    time.sleep(0.2)
-
-            else:
-                conv = get_conversation_template(args.model_path)
-                conv.append_message(conv.roles[0], msg)
-                conv.append_message(conv.roles[1], None)
-                conv.system = ''
-                prompt = conv.get_prompt().strip()
-                outputs = fastchat(prompt, model, tokenizer)
-
-            with open(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.txt', 'w') as f:
-                f.write(outputs)
-
-def strategy2_instruction(model, text, gid_list, tokenizer):
-    for gid in tqdm(gid_list, desc='Query'):
-        text = ptb[gid]['text']
-
-        ## POS tagging
-        if os.path.exists(f'result/prompt2_instruction/pos_tagging/{args.model_path}/ptb/{gid}.txt'):
-            print(gid, 'skip')
-
-        else:
-            msg = prompt2_pos.format(text)
-
-            if 'gpt3' in args.model_path:
-                outputs = gpt3(msg)
-                if outputs is None:
-                    continue
-                time.sleep(0.2)
-
-            else:
-                conv = get_conversation_template(args.model_path)
-                conv.append_message(conv.roles[0], msg)
-                conv.append_message(conv.roles[1], None)
-                conv.system = ''
-                prompt = conv.get_prompt()
-
-                outputs = fastchat(prompt, model, tokenizer)
-
-            with open(f'result/prompt2_instruction/pos_tagging/{args.model_path}/ptb/{gid}.txt', 'w') as f:
-                f.write(outputs)
-
-
-        ## Sentence chunking
-        if os.path.exists(f'result/prompt2_instruction/chunking/{args.model_path}/ptb/{gid}.txt'):
-            print(gid, 'skip')
-        if False:
-            pass
-        else:
-            msg = prompt2_chunk.format(text)
-
-            if 'gpt3' in args.model_path:
-                outputs = gpt3(msg)
-                if outputs is None:
-                    continue
-                time.sleep(0.2)
-
-            else:
-                conv = get_conversation_template(args.model_path)
-                conv.append_message(conv.roles[0], msg)
-                conv.append_message(conv.roles[1], None)
-                conv.system = ''
-                prompt = conv.get_prompt()
-
-                outputs = fastchat(prompt, model, tokenizer)
-
-            print(args.model_path, gid, outputs)
-            with open(f'result/prompt2_instruction/chunking/{args.model_path}/ptb/{gid}.txt', 'w') as f:
-                f.write(outputs)
-
-
-        ## Parsing
-        if os.path.exists(f'result/prompt2_instruction/parsing/{args.model_path}/ptb/{gid}.txt'):
-            print(gid, 'skip')
-
-        else:
-            msg = prompt2_parse.format(text)
-
-            if 'gpt3' in args.model_path:
-                outputs = gpt3(msg)
-                if outputs is None:
-                    continue
-                time.sleep(0.2)
-
-            else:
-                conv = get_conversation_template(args.model_path)
-                conv.append_message(conv.roles[0], msg)
-                conv.append_message(conv.roles[1], None)
-                conv.system = ''
-                prompt = conv.get_prompt()
-
-                outputs = fastchat(prompt, model, tokenizer)
-
-            with open(f'result/prompt2_instruction/parsing/{args.model_path}/ptb/{gid}.txt', 'w') as f:
-                f.write(outputs)
-
-
-def strategy3_structured_prompt(model, text, gid_list, tokenizer, bad_words_ids_pos, bad_words_ids_bio, bad_words_ids_chunk, bad_words_ids_parse):
-    for gid in tqdm(gid_list, desc='Query'):
-        text = ptb[gid]['text']
-        tokens = ptb[gid]['tokens']
-        poss = ptb[gid]['uni_poss']
-
-        ## POS tagging
-        if os.path.exists(f'result/prompt3_structured_prompt/pos_tagging/{args.model_path}/ptb/{gid}.txt'):
-            print(gid, 'skip')
-            continue
-
-        prompt = demon_pos + '\n' + 'C: ' + text + '\n' + 'T: '
-
-        if 'gpt3' in args.model_path:
-            outputs = gpt3(prompt)
-            if outputs is None:
-                continue
-            time.sleep(0.2)
-
-        else:
-            pred_poss = []
-            for _tok, _pos in zip(tokens, poss):
-                prompt = prompt + ' ' + _tok + '_'
-                outputs = structured_prompt(prompt, model, tokenizer, bad_words_ids_pos)
-                prompt = prompt + outputs
-                pred_poss.append(outputs)
-
-            outputs = ' '.join(pred_poss)
-        with open(f'result/prompt3_structured_prompt/pos_tagging/{args.model_path}/ptb/{gid}.txt', 'w') as f:
-            f.write(outputs)
-
-
-        ## Chunking
-        if os.path.exists(f'result/prompt3_structured_prompt/chunking/{args.model_path}/ptb/{gid}.txt'):
-            print(gid, 'skip')
-            continue
-
-        prompt = demon_chunk + '\n' + 'C: ' + text + '\n' + 'T: '
-
-        if 'gpt3' in args.model_path:
-            outputs = gpt3(prompt)
-            print(outputs)
-            if outputs is None:
-                continue
-            time.sleep(0.2)
-
-        else:
-            pred_chunk = []
-            for _tok, _pos in zip(tokens, poss):
-                prompt = prompt + ' ' + _tok + '_'
-
-                # Generate BIO
-                outputs_bio = structured_prompt(prompt, model, tokenizer, bad_words_ids_bio)
-                prompt = prompt + outputs_bio + '-'
-
-                # Generate tag
-                outputs_chunk = structured_prompt(prompt, model, tokenizer, bad_words_ids_chunk)
-                prompt = prompt + outputs_chunk
-
-                pred_chunk.append((outputs_bio + '-' + outputs_chunk))
-
-            outputs = ' '.join(pred_chunk)
-
-        with open(f'result/prompt3_structured_prompt/chunking/{args.model_path}/ptb/{gid}.txt', 'w') as f:
-            f.write(outputs)
-
-        ## Parsing
-        if os.path.exists(f'result/prompt3_structured_prompt/parsing/{args.model_path}/ptb/{gid}.txt'):
-            print(gid, 'skip')
-            continue
-
-        prompt = demon_parse + '\n' + 'C: ' + text + '\n' + 'T: '
-
-        if 'gpt3' in args.model_path:
-            outputs = gpt3(prompt)
-            if outputs is None:
-                continue
-            time.sleep(0.2)
-
-        else:
-            pred_syn = []
-            for _tok, _pos in zip(tokens, poss):
-                prompt = prompt + _tok + '_'
-                outputs = structured_prompt(prompt, model, tokenizer, bad_words_ids_parse)
-                pred_syn.append(outputs)
-
-        with open(f'result/prompt3_structured_prompt/parsing/{args.model_path}/ptb/{gid}.txt', 'w') as f:
-            f.write(' '.join(pred_syn))
-
-
-def structured_prompt(prompt, model, tokenizer, bad_words_ids):
-    input_ids = tokenizer([prompt]).input_ids
-    output_ids = model.generate(
-        torch.as_tensor(input_ids).cuda(),
-        max_new_tokens=1,
-        bad_words_ids=bad_words_ids,
-    )
-
-    if model.config.is_encoder_decoder:
-        output_ids = output_ids[0]
-    else:
-        output_ids = output_ids[0][len(input_ids[0]) :]
-    outputs = tokenizer.decode(
-        output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
-    )
-
-    return outputs
-
-
-def fastchat(prompt, model, tokenizer):
-    input_ids = tokenizer([prompt]).input_ids
-    output_ids = model.generate(
-        torch.as_tensor(input_ids).cuda(),
-        do_sample=True,
-        temperature=args.temperature,
-        repetition_penalty=args.repetition_penalty,
-        max_new_tokens=args.max_new_tokens,
-    )
-
-    if model.config.is_encoder_decoder:
-        output_ids = output_ids[0]
-    else:
-        output_ids = output_ids[0][len(input_ids[0]) :]
-    outputs = tokenizer.decode(
-        output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
-    )
-
-    #print('Empty system message')
-    #print(f"{conv.roles[0]}: {msg}")
-    #print(f"{conv.roles[1]}: {outputs}")
-
-    return outputs
-
-
-def gpt3(prompt):
-    try:
-        response = openai.ChatCompletion.create(
-            model=model_mapping[args.model_path], messages=[{"role": "user", "content": prompt}])
-
-        return response['choices'][0]['message']['content']
-
-    except Exception as err:
-        print('Error')
-        print(err)
-
-        return None
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    add_model_args(parser)
-    parser.add_argument("--temperature", type=float, default=0.7)
-    parser.add_argument("--repetition_penalty", type=float, default=1.0)
-    parser.add_argument("--max-new-tokens", type=int, default=512)
-    parser.add_argument("--debug", action="store_true")
-    parser.add_argument("--message", type=str, default="Hello! Who are you?")
-    parser.add_argument("--start", type=int, default=0)
-    parser.add_argument("--end", type=int, default=1000)
-    parser.add_argument("--prompt", required=True, type=int, default=None)
-    # parser.add_argument("--system_msg", required=True, type=str, default='default_system_msg')
-    args = parser.parse_args()
-
-    # Reset default repetition penalty for T5 models.
-    if "t5" in args.model_path and args.repetition_penalty == 1.0:
-        args.repetition_penalty = 1.2
-
-    main(args)
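The removed strategy3_structured_prompt constrained every generation step to a tag vocabulary: each token id outside a whitelist is passed to generate() as a one-token bad word, so the model can only emit a tag. A self-contained sketch of the technique; "gpt2" and the tag list stand in for the original model and uni_tags, and encode(t)[0] replaces the original encode(word)[1], which skipped the BOS token that the LLaMA tokenizer prepends:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    tags = ["NOUN", "VERB", "DET"]  # stand-in for uni_tags
    # First sub-token id per tag; the original used [1] to skip LLaMA's BOS.
    whitelist = {tokenizer.encode(t)[0] for t in tags}
    # Every id outside the whitelist is banned for the single generated step.
    bad_words_ids = [[i] for i in range(tokenizer.vocab_size) if i not in whitelist]

    prompt = "C: The dog barks\nT: The_"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    output_ids = model.generate(input_ids, max_new_tokens=1, bad_words_ids=bad_words_ids)
    print(tokenizer.decode(output_ids[0][input_ids.shape[1]:]))  # one whitelisted token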
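The removed non-GPT-3 paths built each prompt with FastChat's conversation templates before calling generate(). A minimal sketch of that pattern, with "vicuna-7b" as an example template name:

    from fastchat.model import get_conversation_template

    conv = get_conversation_template("vicuna-7b")
    conv.system = ''  # the removed helpers blanked the system message
    conv.append_message(conv.roles[0], "POS tag the following sentence: The dog barks .")
    conv.append_message(conv.roles[1], None)  # leave the assistant turn open
    prompt = conv.get_prompt().strip()
    # prompt is now ready to pass to something like fastchat(prompt, model, tokenizer)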