import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Few-shot prompt showing the informal -> Lincoln-style format the model was trained on.
first = """informal english: corn fields are all across illinois, visible once you leave chicago.\nTranslated into the Style of Abraham Lincoln: corn fields ( permeate illinois / span the state of illinois / ( occupy / persist in ) all corners of illinois / line the horizon of illinois / envelop the landscape of illinois ), manifesting themselves visibly as one ventures beyond chicago.\n\ninformal english: """
@st.cache(allow_output_mutation=True)
def get_model():
    # Fine-tuned informal-to-formal "Lincoln" paraphrase model, plus plain GPT-2
    # as the comparison checkpoint ("Other Model" in the UI).
    model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln85Paraphrase")
    tokenizer = AutoTokenizer.from_pretrained("BigSalmon/InformalToFormalLincoln85Paraphrase")
    tokenizer2 = AutoTokenizer.from_pretrained("gpt2")
    model2 = AutoModelForCausalLM.from_pretrained("gpt2")
    return model, model2, tokenizer, tokenizer2


model, model2, tokenizer, tokenizer2 = get_model()

st.text('''For Prompt Templates: https://huggingface.co/BigSalmon/InformalToFormalLincoln82Paraphrase''')

# Sidebar controls shared by every generation mode.
temp = st.sidebar.slider("Temperature", 0.7, 1.5)
number_of_outputs = st.sidebar.slider("Number of Outputs", 5, 50)
lengths = st.sidebar.slider("Length", 3, 500)
bad_words = st.text_input("Words You Do Not Want Generated", " core lemon height time ")
logs_outputs = st.sidebar.slider("Logit Outputs", 50, 300)
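

# One sampling helper serves both checkpoints: the (model, tokenizer) pair is passed
# in explicitly. This replaces the former run_generate2, which duplicated the body
# below verbatim and differed only in using the GPT-2 model and tokenizer.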
def run_generate(text, bad_words, model, tokenizer):
    input_ids = tokenizer.encode(text, return_tensors='pt')
    prompt_length = input_ids.shape[1]
    # Ban each listed word; the leading space matches how GPT-2 tokenizes
    # words that appear mid-sentence.
    bad_word_ids = [tokenizer(" " + bad_word).input_ids for bad_word in bad_words.split()]
    sample_outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=prompt_length + lengths,
        min_length=prompt_length + lengths,
        top_k=50,
        temperature=temp,
        num_return_sequences=number_of_outputs,
        bad_words_ids=bad_word_ids or None,
    )
    # Drop the prompt from each sample so only the continuation is returned.
    return [tokenizer.decode(output).replace(text, "") for output in sample_outputs]


def top_tokens(prompt, model, tokenizer, k):
    # Run one forward pass and return the k most likely next tokens.
    # This factors out the logit code that was repeated in BestProbs5 and in
    # both log-prob button handlers.
    with torch.no_grad():
        myinput = torch.tensor([tokenizer.encode(prompt)])
        logits, past_key_values = model(myinput, past_key_values=None, return_dict=False)
        best_logits, best_indices = logits[0, -1].topk(k)
        return [tokenizer.decode([idx.item()]) for idx in best_indices]


def BestProbs5(prompt):
    # Extend the prompt by each of its top next tokens, then sample full
    # continuations from each extended prompt.
    prompt = prompt.strip()
    for word in top_tokens(prompt, model, tokenizer, number_of_outputs):
        g = prompt + word
        st.write(g)
        st.write(run_generate(g, "hey", model, tokenizer))


def prefix_format(sentence):
    # Rewrite "a b [MASK] c d" into the model's fill-in-the-middle format:
    # "<Prefix> a b <Prefix> <Suffix> c d <Suffix> <Middle>".
    # [MASK] must appear as a standalone word.
    words = sentence.split()
    if "[MASK]" in words:
        mask_index = words.index("[MASK]")
        output = ("<Prefix> " + " ".join(words[:mask_index]) + " <Prefix> "
                  + "<Suffix> " + " ".join(words[mask_index + 1:]) + " <Suffix>"
                  + " <Middle>")
        st.write(output)
    else:
        st.write("Add [MASK] to sentence")


with st.form(key='my_form'):
    text = st.text_area(label='Enter sentence', value=first)
    submit_button = st.form_submit_button(label='Submit')
    submit_button2 = st.form_submit_button(label='Submit Log Probs')
    submit_button3 = st.form_submit_button(label='Submit Other Model')
    submit_button4 = st.form_submit_button(label='Submit Log Probs Other Model')
    submit_button5 = st.form_submit_button(label='Most Prob')
    submit_button6 = st.form_submit_button(label='Turn Sentence with [MASK] into <Prefix> Format')
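
# Each button maps to one of the helpers above; buttons 3 and 4 run the same
# actions against the plain GPT-2 comparison model.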
if submit_button:
    translated_text = run_generate(text, bad_words, model, tokenizer)
    st.write(translated_text if translated_text else "No translation found")
if submit_button2:
    # Show the most likely next tokens under the Lincoln model.
    st.write(top_tokens(str(text), model, tokenizer, logs_outputs))
if submit_button3:
    translated_text = run_generate(text, bad_words, model2, tokenizer2)
    st.write(translated_text if translated_text else "No translation found")
if submit_button4:
    st.write(top_tokens(str(text), model2, tokenizer2, logs_outputs))
if submit_button5:
    BestProbs5(text)
if submit_button6:
    prefix_format(str(text))
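
# Assuming this file is saved as app.py, launch it with: streamlit run app.py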