Camila Salinas Camacho committed
Commit 081b46f
1 Parent(s): 98e913c

Update app.py

Files changed (1)
  1. app.py +22 -243
app.py CHANGED
@@ -1,246 +1,25 @@
 import gradio as gr
-import os
-import sys
-import json
-import time
-import openai
-import pickle
-import argparse
-import requests
-from tqdm import tqdm
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer
-
-from fastchat.model import load_model, get_conversation_template, add_model_args
-
-
-openai.api_key = "sk-[REDACTED]"
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-iface.launch()
-
-# determinant vs. determiner
-# https://wikidiff.com/determiner/determinant
-ents_prompt = [
-    'Noun',
-    'Verb',
-    'Adjective',
-    'Adverb',
-    'Preposition/Subord',
-    'Coordinating Conjunction',
-    # 'Cardinal Number',
-    'Determiner',
-    'Noun Phrase',
-    'Verb Phrase',
-    'Adjective Phrase',
-    'Adverb Phrase',
-    'Preposition Phrase',
-    'Conjunction Phrase',
-    'Coordinate Phrase',
-    'Quantitave Phrase',
-    'Complex Nominal',
-    'Clause',
-    'Dependent Clause',
-    'Fragment Clause',
-    'T-unit',
-    'Complex T-unit',
-    # 'Fragment T-unit',
-]
-ents = ['NN', 'VB', 'JJ', 'RB', 'IN', 'CC', 'DT', 'NP', 'VP', 'ADJP', 'ADVP', 'PP', 'CONJP', 'CP', 'QP', 'CN', 'C', 'DC', 'FC', 'T', 'CT']
-
-
-model_mapping = {
-    # 'gpt3': 'gpt-3',
-    'gpt3.5': 'gpt-3.5-turbo-0613',
-    'vicuna-7b': 'lmsys/vicuna-7b-v1.3',
-    'vicuna-13b': 'lmsys/vicuna-13b-v1.3',
-    'vicuna-33b': 'lmsys/vicuna-33b-v1.3',
-    'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
-    # 'llama2': 'meta-llama/Llama-2-7b-chat-hf',
-    'llama-7b': '/data/jiali/llama/hf/7B',
-    'llama-13b': '/data/jiali/llama/hf/13B',
-    'llama-30b': '/data/jiali/llama/hf/30B',
-    'llama-65b': '/data/jiali/llama/hf/65B',
-    'alpaca': '/data/jiali/alpaca-7B',
-    # 'koala-7b': 'koala-7b',
-    # 'koala-13b': 'koala-13b',
-}
-
-for m in model_mapping.keys():
-    for eid, ent in enumerate(ents):
-        os.makedirs(f'result/openai_result/{m}/ptb/per_ent/{ent}', exist_ok=True)
-    os.makedirs(f'result/structured_prompt/{m}/ptb', exist_ok=True)
-
-
-# s = int(sys.argv[1])
-# e = int(sys.argv[2])
-
-s = 0
-e = 1000
-with open('ptb_corpus/sample_uniform_1k_2.txt', 'r') as f:
-    selected_idx = f.readlines()
-    selected_idx = [int(i.strip()) for i in selected_idx][s:e]
-
-
-ptb = []
-with open('./ptb_corpus/ptb.jsonl', 'r') as f:
-    for l in f:
-        ptb.append(json.loads(l))
-
-
-## Prompt 1
-template_all = '''Please output the <Noun, Verb, Adjective, Adverb, Preposition/Subord, Coordinating Conjunction, Cardinal Number, Determiner, Noun Phrase, Verb Phrase, Adjective Phrase, Adverb Phrase, Preposition Phrase, Conjunction Phrase, Coordinate Phrase, Quantitave Phrase, Complex Nominal, Clause, Dependent Clause, Fragment Clause, T-unit, Complex T-unit, Fragment T-unit> in the following sentence without any additional text in json format: "{}"'''
-template_single = '''Please output any <{}> in the following sentence one per line without any additional text: "{}"'''
-
-## Prompt 2
-with open('ptb_corpus/structured_prompting_demonstration_42.txt', 'r') as f:
-    demonstration = f.read()
-
-
-def para(m):
-    c = 0
-    for n, p in m.named_parameters():
-        c += p.numel()
-    return c
-
-def main(args=None):
-
-    if 'gpt3' in args.model:
-        pass
-
-    else:
-        path = model_mapping[args.model]
-        model, tokenizer = load_model(
-            path,
-            args.device,
-            args.num_gpus,
-            args.max_gpu_memory,
-            args.load_8bit,
-            args.cpu_offloading,
-            revision=args.revision,
-            debug=args.debug,
-        )
-
-    if args.prompt == 1:
-        for gid in tqdm(selected_idx, desc='Query'):
-            text = ptb[gid]['text']
-
-            for eid, ent in enumerate(ents):
-                # if os.path.exists(f'result/openai_result/{args.model}/ptb/per_ent/{ent}/{gid}.pkl') or \
-                #         os.path.exists(f'result/openai_result/{args.model}/ptb/per_ent/{ent}/{gid}.txt'):
-                #     print(gid, ent, 'skip')
-                #     continue
-
-                ## Get prompt
-                msg = template_single.format(ents_prompt[eid], text)
-
-                if 'gpt' in args.model:
-                    prompt = msg
-
-                elif 'vicuna' in args.model or 'alpaca' in args.model or 'fastchat-t5' in args.model:
-                    conv = get_conversation_template(args.model)
-                    conv.append_message(conv.roles[0], msg)
-                    conv.append_message(conv.roles[1], None)
-                    conv.system = ''
-                    prompt = conv.get_prompt().strip()
-
-                elif 'llama-' in args.model:
-                    prompt = '### Human: ' + msg + ' ### Assistant:'
-
-
-                ## Run
-                if 'gpt3' in args.model:
-                    outputs = gpt3(prompt)
-
-                else:
-                    outputs = fastchat(prompt, model, tokenizer)
-
-                with open(f'result/openai_result/{args.model}/ptb/per_ent/{ent}/{gid}.txt', 'w') as f:
-                    f.write(outputs)
-
-
-    if args.prompt == 2:
-        for gid in tqdm(selected_idx, desc='Query'):
-            text = ptb[gid]['text']
-
-            if os.path.exists(f'result/structured_prompt/{args.model}/ptb/{gid}.pkl') or \
-                    os.path.exists(f'result/structured_prompt/{args.model}/ptb/{gid}.txt'):
-                print(gid, 'skip')
-                continue
-
-            prompt = demonstration + '\n' + text
-
-            if 'gpt3' in args.model:
-                outputs = gpt3(prompt)
-
-            else:
-                outputs = fastchat(prompt, model, tokenizer)
-
-            with open(f'result/structured_prompt/{args.model}/ptb/{gid}.txt', 'w') as f:
-                f.write(outputs)
-
-
-def fastchat(prompt, model, tokenizer):
-    input_ids = tokenizer([prompt]).input_ids
-    output_ids = model.generate(
-        torch.as_tensor(input_ids).cuda(),
-        do_sample=True,
-        temperature=args.temperature,
-        repetition_penalty=args.repetition_penalty,
-        max_new_tokens=args.max_new_tokens,
-    )
-
-    if model.config.is_encoder_decoder:
-        output_ids = output_ids[0]
-    else:
-        output_ids = output_ids[0][len(input_ids[0]):]
-    outputs = tokenizer.decode(
-        output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
-    )
-
-    # print('Empty system message')
-    # print(f"{conv.roles[0]}: {msg}")
-    # print(f"{conv.roles[1]}: {outputs}")
-
-    return outputs
-
-
-def gpt3(prompt):
-    try:
-        response = openai.ChatCompletion.create(
-            model=args.model, messages=[{"role": "user", "content": prompt}])
-
-        return response
-
-    except Exception as err:
-        print('Error')
-        print(err)
-
-        # time.sleep(1)
-        raise
-
+import subprocess
+from gradio.mix import Parallel
+
+def qa_prompting(model):
+    # Call your `run_llm.py` script for QA-Based Prompting with the selected model
+    output = subprocess.check_output([sys.executable, "run_llm.py", "--model", model, ...], text=True)
+    return output
+
+def strategy_1_interface():
+    model_names = ["ChatGPT", "LLaMA", "Vicuna", "Alpaca", "Flan-T5"]
+    interfaces = []
+    for model_name in model_names:
+        interfaces.append(gr.Interface(
+            fn=qa_prompting,
+            inputs=gr.inputs.Textbox(label=f"{model_name} Input"),
+            outputs=gr.outputs.Textbox(label=f"{model_name} Output"),
+            title=f"Strategy 1 - QA-Based Prompting: {model_name}",
+        ))
+
+    return Parallel(*interfaces)
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    add_model_args(parser)
-    parser.add_argument("--temperature", type=float, default=0.7)
-    parser.add_argument("--repetition_penalty", type=float, default=1.0)
-    parser.add_argument("--max-new-tokens", type=int, default=512)
-    parser.add_argument("--debug", action="store_true")
-    parser.add_argument("--message", type=str, default="Hello! Who are you?")
-    parser.add_argument("--start", type=int, default=0)
-    parser.add_argument("--end", type=int, default=1)
-    parser.add_argument("--model", required=True, type=str, default=None)
-    parser.add_argument("--prompt", required=True, type=int, default=None)
-    args = parser.parse_args()
-
-    # Reset default repetition penalty for T5 models.
-    if "t5" in args.model and args.repetition_penalty == 1.0:
-        args.repetition_penalty = 1.2
-
-    main(args)
-
-
+    iface = strategy_1_interface()
+    iface.launch()
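
As committed, the added code will not run: `sys.executable` is referenced without `import sys`, and the literal `...` left in the argument list would make `subprocess.check_output` fail. Below is a minimal runnable sketch of the same app with those two issues fixed. It assumes a Gradio 3.x-era install (`gradio.mix.Parallel` and `gr.inputs`/`gr.outputs` were removed in later releases) and that `run_llm.py` exposes the removed script's CLI flags (`--model`, `--prompt`, `--message`); the `--prompt 1` value and the model-name strings are illustrative assumptions, not confirmed by the commit.

# Sketch only: assumes run_llm.py accepts --model/--prompt/--message as in the
# removed script's argparse setup, and a Gradio version that still ships
# gradio.mix.
import sys
import functools
import subprocess

import gradio as gr
from gradio.mix import Parallel


def qa_prompting(model_name, message):
    # Shell out to run_llm.py once per request; "--prompt 1" is assumed to
    # select the QA-based prompting strategy.
    return subprocess.check_output(
        [sys.executable, "run_llm.py",
         "--model", model_name,
         "--prompt", "1",
         "--message", message],
        text=True,
    )


def strategy_1_interface():
    model_names = ["ChatGPT", "LLaMA", "Vicuna", "Alpaca", "Flan-T5"]
    interfaces = []
    for model_name in model_names:
        interfaces.append(gr.Interface(
            # Bind the model name up front so the textbox value arrives as the
            # `message` argument instead of overwriting --model.
            fn=functools.partial(qa_prompting, model_name),
            inputs=gr.inputs.Textbox(label=f"{model_name} Input"),
            outputs=gr.outputs.Textbox(label=f"{model_name} Output"),
            title=f"Strategy 1 - QA-Based Prompting: {model_name}",
        ))
    return Parallel(*interfaces)


if __name__ == "__main__":
    strategy_1_interface().launch()

On Gradio 4+, where `gradio.mix` is no longer available, the same side-by-side layout would instead be rebuilt with `gr.Blocks`, one column per model.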