"""Streamlit demo: rewrite informal English in a formal "Lincoln" style.

The page loads a fine-tuned causal LM plus a base GPT-2 model and offers
several actions on the text entered in the form: sampled continuations from
either model, the top next-token candidates from either model, a "most probable
next token, then continue" explorer, and a helper that lays out a sentence
containing ``[MASK]`` around the mask position.
"""
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
import torch

first = """informal english: corn fields are all across illinois, visible once you leave chicago.\nTranslated into the Style of Abraham Lincoln: corn fields ( permeate illinois / span the state of illinois / ( occupy / persist in ) all corners of illinois / line the horizon of illinois / envelop the landscape of illinois ), manifesting themselves visibly as one ventures beyond chicago.\n\ninformal english: """

# Prefer st.cache_resource (the replacement for the deprecated st.cache);
# fall back to the legacy decorator on Streamlit versions that predate it.
# Models are mutable objects, hence allow_output_mutation on the legacy path.
_cache_models = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_models
def get_model(
    model_name="BigSalmon/InformalToFormalLincoln85Paraphrase",
    base_model_name="gpt2",
):
    """Load the fine-tuned model and a base model, each with its tokenizer.

    Earlier BigSalmon InformalToFormalLincoln* / Points* checkpoints were tried
    here; pass a different ``model_name`` to switch checkpoints.

    Returns:
        ``(model, model2, tokenizer, tokenizer2)`` where ``model``/``tokenizer``
        come from ``model_name`` and ``model2``/``tokenizer2`` from
        ``base_model_name``.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer2 = AutoTokenizer.from_pretrained(base_model_name)
    model2 = AutoModelForCausalLM.from_pretrained(base_model_name)
    return model, model2, tokenizer, tokenizer2


model, model2, tokenizer, tokenizer2 = get_model()

st.text('''For Prompt Templates:
https://huggingface.co/BigSalmon/InformalToFormalLincoln82Paraphrase''')

temp = st.sidebar.slider("Temperature", 0.7, 1.5)
number_of_outputs = st.sidebar.slider("Number of Outputs", 5, 50)
lengths = st.sidebar.slider("Length", 3, 500)
bad_words = st.text_input("Words You Do Not Want Generated", " core lemon height time ")
logs_outputs = st.sidebar.slider("Logit Outputs", 50, 300)


def _bad_words_to_ids(bad_words, tok):
    """Convert a whitespace-separated word list into ``bad_words_ids``.

    Each word is prefixed with a space so it matches the token form a word
    takes mid-sentence. Returns ``None`` when there are no words, because
    ``generate`` rejects an empty ``bad_words_ids`` list with a ValueError.
    """
    ids = [tok(" " + word).input_ids for word in bad_words.split()]
    return ids or None


def _sample_continuations(text, bad_words, lm, tok):
    """Sample ``number_of_outputs`` continuations of *text* using *lm*/*tok*.

    Reads the sidebar globals ``lengths``, ``temp`` and ``number_of_outputs``.
    Each continuation is forced to exactly ``lengths`` new tokens
    (min_length == max_length).

    Returns:
        A list of decoded continuations with the prompt removed, or an empty
        list when *text* encodes to no tokens.
    """
    encoded = tok(text, return_tensors="pt")
    prompt_len = encoded.input_ids.shape[-1]
    if prompt_len == 0:
        # The model cannot generate from a zero-length input.
        return []
    sample_outputs = lm.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        do_sample=True,
        max_length=prompt_len + lengths,
        min_length=prompt_len + lengths,
        top_k=50,
        temperature=temp,
        num_return_sequences=number_of_outputs,
        bad_words_ids=_bad_words_to_ids(bad_words, tok),
    )
    # Decode only the newly generated tokens: string-replacing the prompt in
    # the full decode breaks whenever detokenization doesn't round-trip.
    return [tok.decode(seq[prompt_len:]) for seq in sample_outputs]


def _top_next_tokens(text, lm, tok, k):
    """Return the *k* highest-logit next tokens after *text*, best first.

    Returns an empty list when *text* encodes to no tokens.
    """
    token_ids = tok.encode(text)
    if not token_ids:
        return []
    with torch.no_grad():  # inference only; don't build an autograd graph
        logits = lm(torch.tensor([token_ids]), return_dict=False)[0]
    top_indices = logits[0, -1].topk(k).indices
    return [tok.decode([idx.item()]) for idx in top_indices]


def run_generate(text, bad_words):
    """Sample continuations of *text* from the fine-tuned model."""
    return _sample_continuations(text, bad_words, model, tokenizer)


def BestProbs5(prompt):
    """For each top next token after *prompt*, show prompt+token and samples.

    Uses the fine-tuned model; the number of candidate tokens and of samples
    per candidate both come from the ``number_of_outputs`` slider.
    """
    prompt = prompt.strip()
    for word in _top_next_tokens(prompt, model, tokenizer, number_of_outputs):
        extended = prompt + word
        st.write(extended)
        st.write(run_generate(extended, "hey"))


def run_generate2(text, bad_words):
    """Sample continuations of *text* from the base model."""
    return _sample_continuations(text, bad_words, model2, tokenizer2)


def prefix_format(sentence):
    """Write the text before and after the first ``[MASK]`` in *sentence*.

    Whitespace on each side is normalized to single spaces. ``[MASK]`` may be
    attached to punctuation (e.g. ``"[MASK]."``); the split is on the first
    occurrence of the substring, not on whitespace-separated words.
    """
    before, mask, after = sentence.partition("[MASK]")
    if not mask:
        st.write("Add [MASK] to sentence")
        return
    output = (" " + ' '.join(before.split()) + " " + " " + ' '.join(after.split()) + " " + " ")
    st.write(output)


with st.form(key='my_form'):
    text = st.text_area(label='Enter sentence', value=first)
    submit_button = st.form_submit_button(label='Submit')
    submit_button2 = st.form_submit_button(label='Submit Log Probs')
    submit_button3 = st.form_submit_button(label='Submit Other Model')
    submit_button4 = st.form_submit_button(label='Submit Log Probs Other Model')
    submit_button5 = st.form_submit_button(label='Most Prob')
    submit_button6 = st.form_submit_button(label='Turn Sentence with [MASK] into Format')

if submit_button:
    translated_text = run_generate(text, bad_words)
    st.write(translated_text if translated_text else "No translation found")
if submit_button2:
    st.write(_top_next_tokens(str(text), model, tokenizer, logs_outputs))
if submit_button3:
    translated_text = run_generate2(text, bad_words)
    st.write(translated_text if translated_text else "No translation found")
if submit_button4:
    st.write(_top_next_tokens(str(text), model2, tokenizer2, logs_outputs))
if submit_button5:
    BestProbs5(text)
if submit_button6:
    prefix_format(str(text))