GPT2 / app.py
BigSalmon's picture
Update app.py
12b2310
raw
history blame
9.01 kB
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
import torch
# Seed prompt pre-filled into the "Enter sentence" text area: one informal
# sentence, its Lincoln-style rewrite (alternative phrasings appear as
# "( a / b / ... )" groups), and an open "informal english: " slot that the
# user completes with their own sentence.
first = (
    "informal english: corn fields are all across illinois, visible once you leave chicago.\n"
    "Translated into the Style of Abraham Lincoln: corn fields ( permeate illinois / "
    "span the state of illinois / ( occupy / persist in ) all corners of illinois / "
    "line the horizon of illinois / envelop the landscape of illinois ), "
    "manifesting themselves visibly as one ventures beyond chicago.\n"
    "\n"
    "informal english: "
)
@st.cache(allow_output_mutation=True)
def get_model(model_name="BigSalmon/InformalToFormalLincoln85Paraphrase", baseline_name="gpt2"):
    """Load the main paraphrase checkpoint and a baseline model, plus their tokenizers.

    Streamlit caches the result, so the weights are loaded once and reused
    across script reruns. ``allow_output_mutation=True`` turns off Streamlit's
    check that the cached return value is not mutated, which would otherwise
    hash the model objects.

    Args:
        model_name: Hugging Face id of the main informal->formal checkpoint.
            Earlier revisions of this app used other BigSalmon checkpoints
            (InformalToFormalLincoln*, GPTNeo*, Points*, ...).
        baseline_name: Hugging Face id of the comparison model.

    Returns:
        ``(model, model2, tokenizer, tokenizer2)``: the main model, the baseline
        model, the main tokenizer and the baseline tokenizer, in that order.
    """
    # NOTE(review): st.cache is deprecated in newer Streamlit releases in favour
    # of st.cache_resource; switch once the deployed Streamlit version allows it.
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer2 = AutoTokenizer.from_pretrained(baseline_name)
    model2 = AutoModelForCausalLM.from_pretrained(baseline_name)
    return model, model2, tokenizer, tokenizer2
# Main model, baseline model and their tokenizers (served from the Streamlit cache after the first run).
model, model2, tokenizer, tokenizer2 = get_model()

# NOTE(review): this link points at the 82Paraphrase model card while
# get_model() loads 85Paraphrase -- confirm the prompt templates still match.
st.text("For Prompt Templates: https://huggingface.co/BigSalmon/InformalToFormalLincoln82Paraphrase")

# Generation controls. No explicit value is given, so each slider starts at its minimum.
temp = st.sidebar.slider("Temperature", min_value=0.7, max_value=1.5)
number_of_outputs = st.sidebar.slider("Number of Outputs", min_value=5, max_value=50)
lengths = st.sidebar.slider("Length", min_value=3, max_value=500)
# Whitespace-separated words that the generate functions ban from the output.
bad_words = st.text_input("Words You Do Not Want Generated", value=" core lemon height time ")
# How many top next-token candidates the "Log Probs" buttons display.
logs_outputs = st.sidebar.slider("Logit Outputs", min_value=50, max_value=300)
def run_generate(text, bad_words):
    """Sample ``number_of_outputs`` continuations of *text* from ``model``.

    Sampling settings come from the module-level widget values ``temp``,
    ``lengths`` and ``number_of_outputs``. ``min_length`` and ``max_length``
    are both set to prompt length + ``lengths``, so every sample has exactly
    ``lengths`` new tokens. Sampling uses top-k = 50.

    Args:
        text: Prompt to continue.
        bad_words: Whitespace-separated words that must not be generated. Each
            word is banned in its space-prefixed form (" word"), presumably
            because that is how mid-sentence words tokenize for this model.

    Returns:
        A list of decoded continuations, with the prompt itself removed.
    """
    input_ids = tokenizer.encode(text, return_tensors='pt')
    prompt_len = input_ids.shape[-1]
    bad_word_ids = [tokenizer(" " + word).input_ids for word in bad_words.split()]
    sample_outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=prompt_len + lengths,
        min_length=prompt_len + lengths,
        top_k=50,
        temperature=temp,
        num_return_sequences=number_of_outputs,
        # generate() raises ValueError on an empty bad_words_ids list, so pass
        # None when the user left the "bad words" field blank.
        bad_words_ids=bad_word_ids or None,
    )
    # Decode only the newly generated tokens. Removing the prompt with
    # str.replace fails whenever decode() does not reproduce the prompt text
    # exactly (e.g. tokenization-space clean-up).
    return [tokenizer.decode(sequence[prompt_len:]) for sequence in sample_outputs]
def BestProbs5(prompt):
    """Show the most likely next tokens for *prompt* and sample continuations of each.

    The prompt is stripped and scored once by ``model``. The
    ``number_of_outputs`` highest-scoring next tokens are then taken. For each
    one, the prompt extended by that token is written to the page, followed by
    the list of continuations that ``run_generate`` samples for it.

    Args:
        prompt: Text whose next token is predicted.
    """
    prompt = prompt.strip()
    token_ids = tokenizer.encode(prompt)
    if not token_ids:
        # The model cannot be run on an empty sequence.
        st.write("Enter some text to predict the next word for.")
        return
    # A single scoring pass: skip autograd bookkeeping and the unused KV cache.
    with torch.no_grad():
        logits = model(torch.tensor([token_ids]), use_cache=False, return_dict=True).logits
    # Ranking raw logits gives the same order as ranking softmax probabilities.
    top_ids = logits[0, -1].topk(number_of_outputs).indices
    for token_id in top_ids:
        extended = prompt + tokenizer.decode([token_id.item()])
        st.write(extended)
        # NOTE(review): "hey" is a hard-coded banned word here, whereas the
        # Submit button uses the user's "Words You Do Not Want Generated"
        # field -- confirm this is intentional.
        st.write(run_generate(extended, "hey"))
def run_generate2(text, bad_words):
    """Sample ``number_of_outputs`` continuations of *text* from the baseline ``model2``.

    Uses ``tokenizer2``/``model2`` with the same sampling settings as the main
    model: module-level ``temp``, ``lengths`` and ``number_of_outputs``, and
    top-k = 50. ``min_length`` and ``max_length`` are both set to prompt
    length + ``lengths``, so every sample has exactly ``lengths`` new tokens.

    Args:
        text: Prompt to continue.
        bad_words: Whitespace-separated words that must not be generated. Each
            word is banned in its space-prefixed form (" word"), presumably
            because that is how mid-sentence words tokenize for this model.

    Returns:
        A list of decoded continuations, with the prompt itself removed.
    """
    input_ids = tokenizer2.encode(text, return_tensors='pt')
    prompt_len = input_ids.shape[-1]
    bad_word_ids = [tokenizer2(" " + word).input_ids for word in bad_words.split()]
    sample_outputs = model2.generate(
        input_ids,
        do_sample=True,
        max_length=prompt_len + lengths,
        min_length=prompt_len + lengths,
        top_k=50,
        temperature=temp,
        num_return_sequences=number_of_outputs,
        # generate() raises ValueError on an empty bad_words_ids list, so pass
        # None when the user left the "bad words" field blank.
        bad_words_ids=bad_word_ids or None,
    )
    # Decode only the newly generated tokens. Removing the prompt with
    # str.replace fails whenever decode() does not reproduce the prompt text
    # exactly (e.g. tokenization-space clean-up).
    return [tokenizer2.decode(sequence[prompt_len:]) for sequence in sample_outputs]
def _prefix_format_text(sentence):
    """Return *sentence* as a ``<Prefix> ... <Suffix> ... <Middle>`` prompt, or None if it has no ``[MASK]``.

    Text before the first ``[MASK]`` becomes the prefix and text after it
    becomes the suffix. Runs of whitespace in each part collapse to single
    spaces. ``[MASK]`` may be attached to punctuation (e.g. ``"[MASK]."``).
    """
    before, marker, after = sentence.partition("[MASK]")
    if not marker:
        return None
    return (
        "<Prefix> " + " ".join(before.split()) + " <Prefix> "
        + "<Suffix> " + " ".join(after.split()) + " <Suffix>" + " <Middle>"
    )


def prefix_format(sentence):
    """Write the ``<Prefix>/<Suffix>/<Middle>`` form of *sentence* to the page.

    If *sentence* contains no ``[MASK]``, a hint asking the user to add one is
    written instead.

    Args:
        sentence: User text containing a ``[MASK]`` placeholder.
    """
    output = _prefix_format_text(sentence)
    if output is None:
        st.write("Add [MASK] to sentence")
    else:
        st.write(output)
def _top_next_tokens(lm, tok, prompt, k):
    """Return the *k* highest-scoring next tokens of *lm* after *prompt*, decoded with *tok*.

    Returns an empty list when *prompt* encodes to no tokens (e.g. empty
    text), because the model cannot be run on an empty sequence.
    """
    token_ids = tok.encode(prompt)
    if not token_ids:
        return []
    # A single scoring pass: skip autograd bookkeeping and the unused KV cache.
    with torch.no_grad():
        logits = lm(torch.tensor([token_ids]), use_cache=False, return_dict=True).logits
    # Ranking raw logits gives the same order as ranking softmax probabilities.
    top_ids = logits[0, -1].topk(k).indices
    return [tok.decode([token_id.item()]) for token_id in top_ids]


# Input form. Streamlit reruns the script on submit, and exactly one of the
# submit_button* flags below is True on that run.
with st.form(key='my_form'):
    text = st.text_area(label='Enter sentence', value=first)
    submit_button = st.form_submit_button(label='Submit')
    submit_button2 = st.form_submit_button(label='Submit Log Probs')
    submit_button3 = st.form_submit_button(label='Submit Other Model')
    submit_button4 = st.form_submit_button(label='Submit Log Probs Other Model')
    submit_button5 = st.form_submit_button(label='Most Prob')
    submit_button6 = st.form_submit_button(label='Turn Sentence with [MASK] into <Prefix> Format')

if submit_button:
    translated_text = run_generate(text, bad_words)
    st.write(translated_text if translated_text else "No translation found")
if submit_button2:
    # Top next-token candidates from the main model.
    st.write(_top_next_tokens(model, tokenizer, text, logs_outputs))
if submit_button3:
    translated_text = run_generate2(text, bad_words)
    st.write(translated_text if translated_text else "No translation found")
if submit_button4:
    # Top next-token candidates from the baseline model.
    st.write(_top_next_tokens(model2, tokenizer2, text, logs_outputs))
if submit_button5:
    BestProbs5(text)
if submit_button6:
    prefix_format(text)