import os
import sys
import json
import time
import openai
import pickle
import argparse
import requests
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer
from fastchat.model import load_model, get_conversation_template, add_model_args
from nltk.tag.mapping import _UNIVERSAL_TAGS

import gradio as gr
# Tag inventories: NLTK's universal POS tags (with 'PUNC' substituted for the
# final '.' tag), BIO prefixes, CoNLL-2000 chunk types, and Penn Treebank
# constituency labels. These also serve as decoding whitelists for prompt 3.
uni_tags = list(_UNIVERSAL_TAGS)
uni_tags[-1] = 'PUNC'

bio_tags = ['B', 'I', 'O']

chunk_tags = ['ADJP', 'ADVP', 'CONJP', 'INTJ', 'LST', 'NP', 'O', 'PP', 'PRT', 'SBAR', 'UCP', 'VP']

syntags = ['NP', 'S', 'VP', 'ADJP', 'ADVP', 'SBAR', 'TOP', 'PP', 'POS', 'NAC', "''", 'SINV', 'PRN', 'QP', 'WHNP', 'RB', 'FRAG',
           'WHADVP', 'NX', 'PRT', 'VBZ', 'VBP', 'MD', 'NN', 'WHPP', 'SQ', 'SBARQ', 'LST', 'INTJ', 'X', 'UCP', 'CONJP', 'NNP', 'CD', 'JJ',
           'VBD', 'WHADJP', 'PRP', 'RRC', 'NNS', 'SYM', 'CC']
# Read the API key from the environment rather than hard-coding it in the source.
openai.api_key = os.environ.get("OPENAI_API_KEY")
# determinant vs. determiner
# https://wikidiff.com/determiner/determinant
ents_prompt = ['Noun', 'Verb', 'Adjective', 'Adverb', 'Preposition/Subord', 'Coordinating Conjunction',  # 'Cardinal Number',
               'Determiner',
               'Noun Phrase', 'Verb Phrase', 'Adjective Phrase', 'Adverb Phrase', 'Preposition Phrase', 'Conjunction Phrase', 'Coordinate Phrase', 'Quantitative Phrase', 'Complex Nominal',
               'Clause', 'Dependent Clause', 'Fragment Clause', 'T-unit', 'Complex T-unit',  # 'Fragment T-unit',
               ][7:]
ents = ['NN', 'VB', 'JJ', 'RB', 'IN', 'CC', 'DT', 'NP', 'VP', 'ADJP', 'ADVP', 'PP', 'CONJP', 'CP', 'QP', 'CN', 'C', 'DC', 'FC', 'T', 'CT'][7:]

ents_prompt_uni_tags = ['Verb', 'Noun', 'Pronoun', 'Adjective', 'Adverb', 'Preposition and Postposition', 'Coordinating Conjunction',
                        'Determiner', 'Cardinal Number', 'Particles or other function words',
                        'Words that cannot be assigned a POS tag', 'Punctuation']

# The [7:] slices above drop the word-level entries, which are covered by the
# universal-tag lists instead; the two inventories are then concatenated.
ents = uni_tags + ents
ents_prompt = ents_prompt_uni_tags + ents_prompt
# Sanity check: print the tag/description pairs side by side.
for i, j in zip(ents, ents_prompt):
    print(i, j)
model_mapping = {
    # 'gpt3': 'gpt-3',
    'gpt3.5': 'gpt-3.5-turbo-0613',
    'vicuna-7b': 'lmsys/vicuna-7b-v1.3',
    'vicuna-13b': 'lmsys/vicuna-13b-v1.3',
    'vicuna-33b': 'lmsys/vicuna-33b-v1.3',
    'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
    # 'llama2-7b': 'meta-llama/Llama-2-7b-hf',
    # 'llama2-13b': 'meta-llama/Llama-2-13b-hf',
    # 'llama2-70b': 'meta-llama/Llama-2-70b-hf',
    'llama-7b': './llama/hf/7B',
    'llama-13b': './llama/hf/13B',
    'llama-30b': './llama/hf/30B',
    # 'llama-65b': './llama/hf/65B',
    'alpaca': './alpaca-7B',
    # 'koala-7b': 'koala-7b',
    # 'koala-13b': 'koala-13b',
}
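# The 'gpt3*' entries are served through the OpenAI API; everything else is
# loaded locally (Hugging Face hub IDs or local checkpoint paths) via FastChat.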
for m in model_mapping.keys():
    for ent in ents:
        os.makedirs(f'result/prompt1_qa/{m}/ptb/per_ent/{ent}', exist_ok=True)
    os.makedirs(f'result/prompt2_instruction/pos_tagging/{m}/ptb', exist_ok=True)
    os.makedirs(f'result/prompt2_instruction/chunking/{m}/ptb', exist_ok=True)
    os.makedirs(f'result/prompt2_instruction/parsing/{m}/ptb', exist_ok=True)
    os.makedirs(f'result/prompt3_structured_prompt/pos_tagging/{m}/ptb', exist_ok=True)
    os.makedirs(f'result/prompt3_structured_prompt/chunking/{m}/ptb', exist_ok=True)
    os.makedirs(f'result/prompt3_structured_prompt/parsing/{m}/ptb', exist_ok=True)
with open('sample_uniform_1k_2.txt', 'r') as f:
    selected_idx = [int(i.strip()) for i in f.readlines()]

ptb = []
with open('ptb.jsonl', 'r') as f:
    for l in f:
        ptb.append(json.loads(l))
## Prompt 1
template_all = '''Please output the <Noun, Verb, Adjective, Adverb, Preposition/Subord, Coordinating Conjunction, Cardinal Number, Determiner, Noun Phrase, Verb Phrase, Adjective Phrase, Adverb Phrase, Preposition Phrase, Conjunction Phrase, Coordinate Phrase, Quantitative Phrase, Complex Nominal, Clause, Dependent Clause, Fragment Clause, T-unit, Complex T-unit, Fragment T-unit> in the following sentence without any additional text in json format: "{}"'''
template_single = '''Please output any <{}> in the following sentence one per line without any additional text: "{}"'''
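# For example (with a hypothetical sentence), template_single.format('Noun Phrase', 'The cat sat on the mat.') yields:
#   Please output any <Noun Phrase> in the following sentence one per line without any additional text: "The cat sat on the mat."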
## Prompt 2
prompt2_pos = '''Please pos tag the following sentence using Universal POS tag set without generating any additional text: {}'''
# An earlier instruction-only chunking prompt was superseded by this BIO-format one.
prompt2_chunk = '''Please chunk the following sentence in CoNLL 2000 format with BIO tags without outputting any additional text: {}'''
prompt2_parse = '''Generate textual representation of the constituency parse tree of the following sentence using Penn TreeBank tag set without outputting any additional text: {}'''
## Prompt 3
with open('demonstration_3_42_pos.txt', 'r') as f:
    demon_pos = f.read()
with open('demonstration_3_42_chunk.txt', 'r') as f:
    demon_chunk = f.read()
with open('demonstration_3_42_parse.txt', 'r') as f:
    demon_parse = f.read()
def para(m):
    """Return the total number of parameters in model m."""
    c = 0
    for n, p in m.named_parameters():
        c += p.numel()
    return c
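# Usage (hypothetical): print(f'{para(model) / 1e9:.2f}B parameters')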
def main(args=None):
    gid_list = selected_idx[args.start:args.end]

    if 'gpt3' in args.model_path:
        # OpenAI models are queried through the API; nothing is loaded locally,
        # and the constrained-decoding blacklists below do not apply.
        model = tokenizer = None
        bad_words_ids_pos = bad_words_ids_bio = bad_words_ids_chunk = bad_words_ids_parse = None
    else:
        path = model_mapping[args.model_path]
        model, tokenizer = load_model(
            path,
            args.device,
            args.num_gpus,
            args.max_gpu_memory,
            args.load_8bit,
            args.cpu_offloading,
            revision=args.revision,
            debug=args.debug,
        )

        # For each tag set, blacklist every vocabulary id except the tag tokens,
        # so prompt 3 can only generate valid tags. encode(word)[1] assumes the
        # tokenizer prepends a single special token (e.g. BOS) at index 0.
        whitelist_ids_pos = set(tokenizer.encode(word)[1] for word in uni_tags)
        bad_words_ids_pos = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_pos]

        whitelist_ids_bio = set(tokenizer.encode(word)[1] for word in bio_tags)
        bad_words_ids_bio = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_bio]

        whitelist_ids_chunk = set(tokenizer.encode(word)[1] for word in chunk_tags)
        bad_words_ids_chunk = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_chunk]

        whitelist_ids_parse = set(tokenizer.encode(word)[1] for word in syntags)
        bad_words_ids_parse = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_parse]

    if args.prompt == 1:
        strategy1_qa(model, gid_list, tokenizer)
    elif args.prompt == 2:
        strategy2_instruction(model, gid_list, tokenizer)
    elif args.prompt == 3:
        strategy3_structured_prompt(model, gid_list, tokenizer,
                                    bad_words_ids_pos, bad_words_ids_bio,
                                    bad_words_ids_chunk, bad_words_ids_parse)
def strategy1_qa(model, gid_list, tokenizer):
    for gid in tqdm(gid_list, desc='Query'):
        text = ptb[gid]['text']

        for eid, ent in enumerate(ents):
            os.makedirs(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}', exist_ok=True)

            # Universal-tag aliases point at their PTB-tag result directories.
            # (The original unrolled if-chain linked ADP to ./DT, which looks
            # like a copy-paste slip; ADP corresponds to IN.)
            alias_map = {'NOUN': 'NN', 'VERB': 'VB', 'ADJ': 'JJ', 'ADV': 'RB',
                         'CONJ': 'CC', 'DET': 'DT', 'ADP': 'IN'}
            if ent in alias_map and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}'):
                os.system(f'ln -sT ./{alias_map[ent]} result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}')

            if os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.txt'):
                print(gid, ent, 'skip')
                continue

            ## Get prompt
            msg = template_single.format(ents_prompt[eid], text)

            ## Run
            if 'gpt3' in args.model_path:
                if os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.pkl'):
                    print('Found cache')
                    with open(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.pkl', 'rb') as f:
                        outputs = pickle.load(f)
                    outputs = outputs['choices'][0]['message']['content']
                else:
                    outputs = gpt3(msg)
                    if outputs is None:
                        continue
                    time.sleep(0.2)
            else:
                conv = get_conversation_template(args.model_path)
                conv.append_message(conv.roles[0], msg)
                conv.append_message(conv.roles[1], None)
                conv.system = ''
                prompt = conv.get_prompt().strip()
                outputs = fastchat(prompt, model, tokenizer)

            with open(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.txt', 'w') as f:
                f.write(outputs)
def strategy2_instruction(model, gid_list, tokenizer):
    for gid in tqdm(gid_list, desc='Query'):
        text = ptb[gid]['text']

        ## POS tagging
        if os.path.exists(f'result/prompt2_instruction/pos_tagging/{args.model_path}/ptb/{gid}.txt'):
            print(gid, 'skip')
        else:
            msg = prompt2_pos.format(text)

            if 'gpt3' in args.model_path:
                outputs = gpt3(msg)
                if outputs is None:
                    continue
                time.sleep(0.2)
            else:
                conv = get_conversation_template(args.model_path)
                conv.append_message(conv.roles[0], msg)
                conv.append_message(conv.roles[1], None)
                conv.system = ''
                prompt = conv.get_prompt()
                outputs = fastchat(prompt, model, tokenizer)

            with open(f'result/prompt2_instruction/pos_tagging/{args.model_path}/ptb/{gid}.txt', 'w') as f:
                f.write(outputs)

        ## Sentence chunking
        if os.path.exists(f'result/prompt2_instruction/chunking/{args.model_path}/ptb/{gid}.txt'):
            print(gid, 'skip')
        else:
            msg = prompt2_chunk.format(text)

            if 'gpt3' in args.model_path:
                outputs = gpt3(msg)
                if outputs is None:
                    continue
                time.sleep(0.2)
            else:
                conv = get_conversation_template(args.model_path)
                conv.append_message(conv.roles[0], msg)
                conv.append_message(conv.roles[1], None)
                conv.system = ''
                prompt = conv.get_prompt()
                outputs = fastchat(prompt, model, tokenizer)

            with open(f'result/prompt2_instruction/chunking/{args.model_path}/ptb/{gid}.txt', 'w') as f:
                f.write(outputs)

        ## Parsing
        if os.path.exists(f'result/prompt2_instruction/parsing/{args.model_path}/ptb/{gid}.txt'):
            print(gid, 'skip')
        else:
            msg = prompt2_parse.format(text)

            if 'gpt3' in args.model_path:
                outputs = gpt3(msg)
                if outputs is None:
                    continue
                time.sleep(0.2)
            else:
                conv = get_conversation_template(args.model_path)
                conv.append_message(conv.roles[0], msg)
                conv.append_message(conv.roles[1], None)
                conv.system = ''
                prompt = conv.get_prompt()
                outputs = fastchat(prompt, model, tokenizer)

            with open(f'result/prompt2_instruction/parsing/{args.model_path}/ptb/{gid}.txt', 'w') as f:
                f.write(outputs)
def strategy3_structured_prompt(model, gid_list, tokenizer, bad_words_ids_pos, bad_words_ids_bio, bad_words_ids_chunk, bad_words_ids_parse):
    for gid in tqdm(gid_list, desc='Query'):
        text = ptb[gid]['text']
        tokens = ptb[gid]['tokens']
        poss = ptb[gid]['uni_poss']

        ## POS tagging
        if os.path.exists(f'result/prompt3_structured_prompt/pos_tagging/{args.model_path}/ptb/{gid}.txt'):
            print(gid, 'skip')
            continue

        prompt = demon_pos + '\n' + 'C: ' + text + '\n' + 'T: '

        if 'gpt3' in args.model_path:
            outputs = gpt3(prompt)
            if outputs is None:
                continue
            time.sleep(0.2)
        else:
            # Decode one tag per token: append "token_" to the prompt, let the
            # model emit a single whitelisted POS tag, then feed the tag back in.
            pred_poss = []
            for _tok, _pos in zip(tokens, poss):
                prompt = prompt + ' ' + _tok + '_'
                outputs = structured_prompt(prompt, model, tokenizer, bad_words_ids_pos)
                prompt = prompt + outputs
                pred_poss.append(outputs)
            outputs = ' '.join(pred_poss)

        with open(f'result/prompt3_structured_prompt/pos_tagging/{args.model_path}/ptb/{gid}.txt', 'w') as f:
            f.write(outputs)

        ## Chunking
        if os.path.exists(f'result/prompt3_structured_prompt/chunking/{args.model_path}/ptb/{gid}.txt'):
            print(gid, 'skip')
            continue

        prompt = demon_chunk + '\n' + 'C: ' + text + '\n' + 'T: '

        if 'gpt3' in args.model_path:
            outputs = gpt3(prompt)
            if outputs is None:
                continue
            time.sleep(0.2)
        else:
            pred_chunk = []
            for _tok, _pos in zip(tokens, poss):
                prompt = prompt + ' ' + _tok + '_'
                # Generate the BIO prefix, then the chunk type, each constrained
                # to its own whitelist.
                outputs_bio = structured_prompt(prompt, model, tokenizer, bad_words_ids_bio)
                prompt = prompt + outputs_bio + '-'
                outputs_chunk = structured_prompt(prompt, model, tokenizer, bad_words_ids_chunk)
                prompt = prompt + outputs_chunk
                pred_chunk.append(outputs_bio + '-' + outputs_chunk)
            outputs = ' '.join(pred_chunk)

        with open(f'result/prompt3_structured_prompt/chunking/{args.model_path}/ptb/{gid}.txt', 'w') as f:
            f.write(outputs)

        ## Parsing
        if os.path.exists(f'result/prompt3_structured_prompt/parsing/{args.model_path}/ptb/{gid}.txt'):
            print(gid, 'skip')
            continue

        prompt = demon_parse + '\n' + 'C: ' + text + '\n' + 'T: '

        if 'gpt3' in args.model_path:
            outputs = gpt3(prompt)
            if outputs is None:
                continue
            time.sleep(0.2)
        else:
            pred_syn = []
            for _tok, _pos in zip(tokens, poss):
                prompt = prompt + _tok + '_'
                outputs = structured_prompt(prompt, model, tokenizer, bad_words_ids_parse)
                pred_syn.append(outputs)
            # Join here so the write below works on the gpt3 path too (the
            # original wrote ' '.join(pred_syn), which is undefined for gpt3).
            outputs = ' '.join(pred_syn)

        with open(f'result/prompt3_structured_prompt/parsing/{args.model_path}/ptb/{gid}.txt', 'w') as f:
            f.write(outputs)
def structured_prompt(prompt, model, tokenizer, bad_words_ids):
    """Generate exactly one token, constrained by bad_words_ids to a tag whitelist."""
    input_ids = tokenizer([prompt]).input_ids
    output_ids = model.generate(
        torch.as_tensor(input_ids).cuda(),
        max_new_tokens=1,
        bad_words_ids=bad_words_ids,
    )
    if model.config.is_encoder_decoder:
        output_ids = output_ids[0]
    else:
        output_ids = output_ids[0][len(input_ids[0]):]
    outputs = tokenizer.decode(
        output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
    )
    return outputs
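# For instance (hypothetical prompt), with bad_words_ids_bio the model can only
# emit 'B', 'I', or 'O' as its single new token:
#   tag = structured_prompt(demon_chunk + '\nC: The cat sat.\nT:  The_', model, tokenizer, bad_words_ids_bio)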
def fastchat(prompt, model, tokenizer):
    input_ids = tokenizer([prompt]).input_ids
    output_ids = model.generate(
        torch.as_tensor(input_ids).cuda(),
        do_sample=True,
        temperature=args.temperature,
        repetition_penalty=args.repetition_penalty,
        max_new_tokens=args.max_new_tokens,
    )
    if model.config.is_encoder_decoder:
        output_ids = output_ids[0]
    else:
        output_ids = output_ids[0][len(input_ids[0]):]
    outputs = tokenizer.decode(
        output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
    )
    return outputs
def gpt3(prompt):
    # Uses the pre-1.0 openai SDK interface (openai.ChatCompletion).
    try:
        response = openai.ChatCompletion.create(
            model=model_mapping[args.model_path],
            messages=[{"role": "user", "content": prompt}])
        return response['choices'][0]['message']['content']
    except Exception as err:
        print('Error')
        print(err)
        return None
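# A slightly more robust variant (a sketch, not used above): retry with
# exponential backoff on transient API errors, assuming the same pre-1.0
# openai.ChatCompletion interface as gpt3().
def gpt3_with_retry(prompt, retries=3, backoff=2.0):
    for attempt in range(retries):
        try:
            response = openai.ChatCompletion.create(
                model=model_mapping[args.model_path],
                messages=[{"role": "user", "content": prompt}])
            return response['choices'][0]['message']['content']
        except Exception as err:
            print(f'Error (attempt {attempt + 1}/{retries}): {err}')
            time.sleep(backoff ** attempt)
    return None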
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_model_args(parser)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--message", type=str, default="Hello! Who are you?")
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int, default=1000)
    parser.add_argument("--prompt", required=True, type=int, choices=[1, 2, 3])
    # parser.add_argument("--system_msg", required=True, type=str, default='default_system_msg')
    args = parser.parse_args()

    # Reset default repetition penalty for T5 models.
    if "t5" in args.model_path and args.repetition_penalty == 1.0:
        args.repetition_penalty = 1.2

    main(args)
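# Example invocations (assuming this script is saved as run.py; the model
# selection flags come from fastchat's add_model_args):
#   python run.py --model-path vicuna-7b --prompt 3 --start 0 --end 1000
#   python run.py --model-path gpt3.5 --prompt 1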