Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from transformers import pipeline
|
3 |
+
import random
|
4 |
+
import nltk
|
5 |
+
|
6 |
+
# Sentence-tokenizer data needed by nltk.sent_tokenize. NLTK >= 3.9 loads the
# "punkt_tab" resource instead of the pickled "punkt" models, so fetch both to
# work regardless of the installed NLTK version.
for _nltk_resource in ("punkt", "punkt_tab"):
    nltk.download(_nltk_resource)
|
7 |
+
|
8 |
+
# Maps the user-friendly model labels shown in the UI to the Hugging Face
# model ids that are passed to the fill-mask pipeline.
model_names = {
    "BERT base": "google-bert/bert-base-cased",
    "DistilBERT base": "distilbert/distilbert-base-cased",
    "RoBERTa base": "FacebookAI/roberta-base",
    "BERT finetuned on a dataset for mask filling": "emma7897/bert_one",
    "DistilBERT finetuned on a dataset for mask filling": "emma7897/distilbert_one",
    "BERT finetuned on a dataset of stories for children": "emma7897/bert_two",
    "DistilBERT finetuned on a dataset of stories for children": "emma7897/distilbert_two",
}

# Ready-made paragraphs offered as examples; every sentence that has a blank
# uses the literal placeholder "[MASK]".
sample_paragraphs = [
    "Once upon a time, in a faraway land, there lived a beautiful princess named [MASK]. She was known throughout the kingdom for her [MASK] and immense bravery. One day, while exploring the large forest, she stumbled upon a [MASK] hidden amongst the trees. Curiosity piqued, she ventured inside and discovered a [MASK] filled with treasures beyond imagination. Little did she know, her adventures were just beginning.",
    "In the city of [MASK], where the streets were always very crowded and the skyscrapers reached for the sky, there was a tall detective named Sam. With a keen eye for detail and a knack for solving mysteries, Sam was the best in the business. When horrific crime shook the city to its core, Sam was called to travel to [MASK]. With determination and a trusty [MASK] by his side, Sam set out to uncover the truth.",
    "On a remote island in the middle of the [MASK], there stood a blue lighthouse overlooking the turbulent waters. Inside, a keeper tended to the beacon, guiding [MASK] safely to shore. One stormy night, as the waves crashed against the rocks and the wind howled through the [MASK], a ship appeared on the horizon, its sails tattered and its crew in desperate need of help. With nerves of [MASK] and a steady hand, the lighthouse keeper sprang into action, signaling the way to safety.",
    "In a whimsical village nestled in the [MASK] countryside, there lived an inventor named Zoey. Day and night, Zoey toiled away in her workshop, creating [MASK] that defied imagination. There was no limit to Zoey's creativity. But when a problem threatened to disrupt the peace of the village, Zoey knew it was time to put her [MASK] to the test. With gears whirring and steam hissing, Zoey set out to save the day.",
    "Meet Emma, a spirited young soul with [MASK] dreams. Emma's eyes sparkle with determination as she envisions herself soaring among the stars as an aspiring [MASK]. She spends her days devouring books about [MASK]. When Emma is not gazing at the stars, you can find her drawing pictures of [MASK].",
    "Hello! I would like to introduce you to my best friend, [MASK].",
]

# Labels that may be paired with an example paragraph. Derived from
# model_names (insertion order preserved) so the two can never drift apart.
example_models = list(model_names)

# One [model label, paragraph] row per sample paragraph; the model for each
# row is picked at random when the module is loaded.
examples = [[random.choice(example_models), paragraph] for paragraph in sample_paragraphs]
|
40 |
+
|
41 |
+
# Fill-mask pipelines that have already been loaded, keyed by Hugging Face
# model id, so a model is instantiated once per process rather than on every
# request.
_FILL_MASK_PIPELINES = {}


def _get_fill_mask(model_name):
    """Return the fill-mask pipeline for *model_name*, loading it on first use."""
    if model_name not in _FILL_MASK_PIPELINES:
        _FILL_MASK_PIPELINES[model_name] = pipeline("fill-mask", model=model_name)
    return _FILL_MASK_PIPELINES[model_name]


def _fill_sentence(fill_mask, sentence, mask_token):
    """Fill every "[MASK]" in *sentence*, left to right, and return HTML.

    Each chosen word is wrapped in <mark>...</mark>. The surrounding text and
    the chosen words are HTML-escaped because the result is rendered by a
    gr.HTML component. The model itself only ever sees plain text: earlier
    fills are passed on as ordinary words, without markup.
    """
    import html

    pieces = sentence.split("[MASK]")
    plain_so_far = pieces[0]
    markup = [html.escape(pieces[0], quote=False)]
    for position in range(1, len(pieces)):
        # Text after the blank currently being filled; later blanks keep the
        # literal "[MASK]" placeholder.
        remainder = "[MASK]".join(pieces[position:])
        predictions = fill_mask(plain_so_far + mask_token + remainder, top_k=10)
        # When the input holds several mask tokens the pipeline returns one
        # candidate list per mask; the first list belongs to the first mask,
        # which is the one being filled now.
        if predictions and isinstance(predictions[0], list):
            predictions = predictions[0]
        # Some tokenizers (e.g. RoBERTa) return tokens with a leading space.
        chosen = random.choice([p["token_str"] for p in predictions]).strip()
        plain_so_far += chosen + pieces[position]
        markup.append(f"<mark>{html.escape(chosen, quote=False)}</mark>")
        markup.append(html.escape(pieces[position], quote=False))
    return "".join(markup)


def textGenerator(model, userInput):
    """Replace each "[MASK]" in *userInput* with a word proposed by *model*.

    Args:
        model: One of the user-friendly labels that are keys of model_names.
        userInput: Paragraph in which every blank is written as "[MASK]".

    Returns:
        The paragraph as an HTML string: sentences joined by single spaces,
        with every filled-in word wrapped in <mark>...</mark>. Each word is
        picked at random from the model's top 10 candidates. Empty or
        whitespace-only input yields an empty string.

    Raises:
        gr.Error: If *model* is not a key of model_names (for example when
            no model has been selected).
    """
    if model not in model_names:
        raise gr.Error("Please select one of the listed models.")
    if not userInput or not userInput.strip():
        return ""

    fill_mask = _get_fill_mask(model_names[model])
    # Ask the tokenizer for its own mask token ("[MASK]" for BERT-style
    # models, "<mask>" for RoBERTa) instead of special-casing model ids.
    mask_token = fill_mask.tokenizer.mask_token

    return " ".join(
        _fill_sentence(fill_mask, sentence, mask_token)
        for sentence in nltk.sent_tokenize(userInput)
    )
|
70 |
+
|
71 |
+
# Input widgets, listed in the same order as textGenerator's parameters.
_model_choice = gr.Radio(
    list(model_names),
    label="LLM",
    info="Which LLM would you like to use?",
)
_paragraph_box = gr.Textbox(
    label="User Input",
    info="Please enter a paragraph. Replace words that you want the LLM to fill in with [MASK]. Note: there is a limit of one [MASK] per sentence.",
)

# The generated text carries <mark> tags, so it is shown in an HTML component.
screen = gr.Interface(
    fn=textGenerator,
    inputs=[_model_choice, _paragraph_box],
    outputs=gr.HTML(label="Processed Text"),
    examples=examples,
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    screen.launch()
|