Camila Salinas Camacho committed on
Commit
98e913c
1 Parent(s): e8e247e
Files changed (1)
  1. app.py +240 -1
app.py CHANGED
@@ -1,7 +1,246 @@
import gradio as gr
+ import os
+ import sys
+ import json
+ import time
+ import openai
+ import pickle
+ import argparse
+ import requests
+ from tqdm import tqdm
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer
+
+ from fastchat.model import load_model, get_conversation_template, add_model_args
+
+
+ # Read the API key from the environment; never hardcode secrets in source.
+ openai.api_key = os.environ.get("OPENAI_API_KEY")

def greet(name):
    return "Hello " + name + "!!"

iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ iface.launch()
+
+ # determinant vs. determiner
+ # https://wikidiff.com/determiner/determinant
+ ents_prompt = [
+     'Noun',
+     'Verb',
+     'Adjective',
+     'Adverb',
+     'Preposition/Subord',
+     'Coordinating Conjunction',
+     # 'Cardinal Number',
+     'Determiner',
+     'Noun Phrase',
+     'Verb Phrase',
+     'Adjective Phrase',
+     'Adverb Phrase',
+     'Preposition Phrase',
+     'Conjunction Phrase',
+     'Coordinate Phrase',
+     'Quantitative Phrase',
+     'Complex Nominal',
+     'Clause',
+     'Dependent Clause',
+     'Fragment Clause',
+     'T-unit',
+     'Complex T-unit',
+     # 'Fragment T-unit',
+ ]
+ ents = ['NN', 'VB', 'JJ', 'RB', 'IN', 'CC', 'DT', 'NP', 'VP', 'ADJP', 'ADVP', 'PP', 'CONJP', 'CP', 'QP', 'CN', 'C', 'DC', 'FC', 'T', 'CT']
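+ # Each tag in `ents` pairs one-to-one with the label at the same index in
+ # `ents_prompt`: the tag names the per-entity output directory, the label
+ # fills the query template.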
+
+
+ # Short CLI names mapped to Hugging Face model ids or local checkpoint paths.
+ model_mapping = {
+     # 'gpt3': 'gpt-3',
+     'gpt3.5': 'gpt-3.5-turbo-0613',
+     'vicuna-7b': 'lmsys/vicuna-7b-v1.3',
+     'vicuna-13b': 'lmsys/vicuna-13b-v1.3',
+     'vicuna-33b': 'lmsys/vicuna-33b-v1.3',
+     'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
+     # 'llama2': 'meta-llama/Llama-2-7b-chat-hf',
+     'llama-7b': '/data/jiali/llama/hf/7B',
+     'llama-13b': '/data/jiali/llama/hf/13B',
+     'llama-30b': '/data/jiali/llama/hf/30B',
+     'llama-65b': '/data/jiali/llama/hf/65B',
+     'alpaca': '/data/jiali/alpaca-7B',
+     # 'koala-7b': 'koala-7b',
+     # 'koala-13b': 'koala-13b',
+ }
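+
+ # Pre-create every output directory so result files can be written directly.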
+ for m in model_mapping.keys():
+     for eid, ent in enumerate(ents):
+         os.makedirs(f'result/openai_result/{m}/ptb/per_ent/{ent}', exist_ok=True)
+     os.makedirs(f'result/structured_prompt/{m}/ptb', exist_ok=True)
+
+
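+ # Evaluation slice: the first 1000 sampled sentence ids, which index into the
+ # PTB corpus loaded below (one JSON object per line).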
+ # s = int(sys.argv[1])
+ # e = int(sys.argv[2])
+
+ s = 0
+ e = 1000
+ with open('ptb_corpus/sample_uniform_1k_2.txt', 'r') as f:
+     selected_idx = f.readlines()
+     selected_idx = [int(i.strip()) for i in selected_idx][s:e]
+
+
+ ptb = []
+ with open('./ptb_corpus/ptb.jsonl', 'r') as f:
+     for l in f:
+         ptb.append(json.loads(l))
+
+
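+ # Two prompting setups: Prompt 1 fills per-category templates; Prompt 2
+ # prepends a structured-prompting demonstration read from disk.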
+ ## Prompt 1
+ template_all = '''Please output the <Noun, Verb, Adjective, Adverb, Preposition/Subord, Coordinating Conjunction, Cardinal Number, Determiner, Noun Phrase, Verb Phrase, Adjective Phrase, Adverb Phrase, Preposition Phrase, Conjunction Phrase, Coordinate Phrase, Quantitative Phrase, Complex Nominal, Clause, Dependent Clause, Fragment Clause, T-unit, Complex T-unit, Fragment T-unit> in the following sentence without any additional text in json format: "{}"'''
+ template_single = '''Please output any <{}> in the following sentence one per line without any additional text: "{}"'''
+
+ ## Prompt 2
+ with open('ptb_corpus/structured_prompting_demonstration_42.txt', 'r') as f:
+     demonstration = f.read()
+
+
+ def para(m):
+     # Count the total number of parameters in a model.
+     c = 0
+     for n, p in m.named_parameters():
+         c += p.numel()
+     return c
+
+ def main(args=None):
+
+     if 'gpt3' in args.model:
+         # OpenAI models are called through the API; nothing to load locally.
+         pass
+
+     else:
+         path = model_mapping[args.model]
+         model, tokenizer = load_model(
+             path,
+             args.device,
+             args.num_gpus,
+             args.max_gpu_memory,
+             args.load_8bit,
+             args.cpu_offloading,
+             revision=args.revision,
+             debug=args.debug,
+         )
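+
+     # Prompt 1: query each syntactic category separately with template_single.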
+     if args.prompt == 1:
+         for gid in tqdm(selected_idx, desc='Query'):
+             text = ptb[gid]['text']
+
+             for eid, ent in enumerate(ents):
+                 # if os.path.exists(f'result/openai_result/{args.model}/ptb/per_ent/{ent}/{gid}.pkl') or \
+                 #         os.path.exists(f'result/openai_result/{args.model}/ptb/per_ent/{ent}/{gid}.txt'):
+                 #     print(gid, ent, 'skip')
+                 #     continue
+
+                 ## Get prompt
+                 msg = template_single.format(ents_prompt[eid], text)
+
+                 if 'gpt' in args.model:
+                     prompt = msg
+
+                 elif 'vicuna' in args.model or 'alpaca' in args.model or 'fastchat-t5' in args.model:
+                     conv = get_conversation_template(args.model)
+                     conv.append_message(conv.roles[0], msg)
+                     conv.append_message(conv.roles[1], None)
+                     conv.system = ''
+                     prompt = conv.get_prompt().strip()
+
+                 elif 'llama-' in args.model:
+                     prompt = '### Human: ' + msg + ' ### Assistant:'
+
+                 ## Run
+                 if 'gpt3' in args.model:
+                     outputs = gpt3(prompt)
+
+                 else:
+                     outputs = fastchat(prompt, model, tokenizer)
+
+                 with open(f'result/openai_result/{args.model}/ptb/per_ent/{ent}/{gid}.txt', 'w') as f:
+                     f.write(outputs)
+
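+     # Prompt 2: prepend the structured-prompting demonstration to each sentence.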
+     if args.prompt == 2:
+         for gid in tqdm(selected_idx, desc='Query'):
+             text = ptb[gid]['text']
+
+             if os.path.exists(f'result/structured_prompt/{args.model}/ptb/{gid}.pkl') or \
+                     os.path.exists(f'result/structured_prompt/{args.model}/ptb/{gid}.txt'):
+                 print(gid, 'skip')
+                 continue
+
+             prompt = demonstration + '\n' + text
+
+             if 'gpt3' in args.model:
+                 outputs = gpt3(prompt)
+
+             else:
+                 outputs = fastchat(prompt, model, tokenizer)
+
+             with open(f'result/structured_prompt/{args.model}/ptb/{gid}.txt', 'w') as f:
+                 f.write(outputs)
+
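+
+ # Note: fastchat() reads its sampling settings from the module-level `args`
+ # namespace created by parse_args() in the __main__ block below.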
+ def fastchat(prompt, model, tokenizer):
+     input_ids = tokenizer([prompt]).input_ids
+     output_ids = model.generate(
+         torch.as_tensor(input_ids).cuda(),
+         do_sample=True,
+         temperature=args.temperature,
+         repetition_penalty=args.repetition_penalty,
+         max_new_tokens=args.max_new_tokens,
+     )
+
+     if model.config.is_encoder_decoder:
+         output_ids = output_ids[0]
+     else:
+         # Strip the echoed prompt tokens for decoder-only models.
+         output_ids = output_ids[0][len(input_ids[0]):]
+     outputs = tokenizer.decode(
+         output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
+     )
+
+     # print('Empty system message')
+     # print(f"{conv.roles[0]}: {msg}")
+     # print(f"{conv.roles[1]}: {outputs}")
+
+     return outputs
+
+
+ def gpt3(prompt):
+     try:
+         # args.model holds the short name (e.g. 'gpt3.5'); the API needs the mapped id.
+         response = openai.ChatCompletion.create(
+             model=model_mapping[args.model], messages=[{"role": "user", "content": prompt}])
+
+         # Return the message text so callers can write it straight to a file.
+         return response['choices'][0]['message']['content']
+
+     except Exception as err:
+         print('Error')
+         print(err)
+
+         # time.sleep(1)
+         raise
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     add_model_args(parser)
+     parser.add_argument("--temperature", type=float, default=0.7)
+     parser.add_argument("--repetition_penalty", type=float, default=1.0)
+     parser.add_argument("--max-new-tokens", type=int, default=512)
+     parser.add_argument("--debug", action="store_true")
+     parser.add_argument("--message", type=str, default="Hello! Who are you?")
+     parser.add_argument("--start", type=int, default=0)
+     parser.add_argument("--end", type=int, default=1)
+     parser.add_argument("--model", required=True, type=str, default=None)
+     parser.add_argument("--prompt", required=True, type=int, default=None)
+     args = parser.parse_args()
+
+     # Reset default repetition penalty for T5 models.
+     if "t5" in args.model and args.repetition_penalty == 1.0:
+         args.repetition_penalty = 1.2
+
+     main(args)
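+
+ # Example usage (both flags are required):
+ #   python app.py --model vicuna-7b --prompt 1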