Spaces:
Runtime error
Runtime error
File size: 13,034 Bytes
658b022 ef0bdc3 b4075da cc121a0 61e66c3 658b022 ef0bdc3 61e66c3 ef0bdc3 61e66c3 ef0bdc3 cc121a0 ef0bdc3 61e66c3 cc121a0 ef0bdc3 61e66c3 ef0bdc3 61e66c3 ef0bdc3 61e66c3 ef0bdc3 658b022 61e66c3 658b022 61e66c3 ef0bdc3 61e66c3 ef0bdc3 61e66c3 658b022 ef0bdc3 61e66c3 b4075da ef09c1c 3490993 b4075da 658b022 b4075da 3490993 b4075da 658b022 61e66c3 b4075da 61e66c3 658b022 b4075da ef0bdc3 b4075da ef09c1c b4075da ef09c1c b4075da ef09c1c ef0bdc3 b4075da ef09c1c b4075da ef09c1c b4075da ef09c1c b4075da ef09c1c b4075da 658b022 ef0bdc3 b4075da ef0bdc3 658b022 ef0bdc3 b4075da ef09c1c b4075da ef09c1c b4075da ef0bdc3 b4075da ef0bdc3 b4075da ef0bdc3 b4075da ef0bdc3 b4075da ef0bdc3 61e66c3 ef09c1c 19c6000 ed3bf5d ef0cf07 19c6000 ef09c1c 19c6000 ef09c1c 19c6000 ef0cf07 ef09c1c 61e66c3 658b022 ef0bdc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 |
import gradio as gr
from transformers import AutoTokenizer
from transformers import pipeline
from utils import format_moves
import pandas as pd
import tensorflow as tf
import json
# Base checkpoint; only its tokenizer is loaded from here.
model_checkpoint = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Load the fine-tuned model into a text-generation pipeline that reuses
# the tokenizer created above.
generate = pipeline(
    task="text-generation",
    model="arjunpatel/distilgpt2-finetuned-pokemon-moves",
    tokenizer=tokenizer,
)

# Phrase prepended to every move name before it is sent to the pipeline.
seed_text = "This move is called "

# Seed TensorFlow's global random generator.
tf.random.set_seed(0)
def update_history(df, move_name, move_desc, generation, parameters):
    """Return the history table with one newly generated move appended.

    Args:
        df: Existing history table; it is concatenated with the new row.
        move_name: Name of the move as entered by the user.
        move_desc: Generated text. Everything up to and including the first
            newline (the seed-phrase line) is discarded before storing.
        generation: Label of the decoding strategy that produced the move.
        parameters: JSON-serialisable description of the settings used; it
            is stored as a JSON string.

    Returns:
        A new DataFrame holding the rows of ``df`` followed by the new row,
        re-indexed 0..n-1.
    """
    # Keep only what follows the first newline; this is "" when the text
    # has no newline at all (same result as split/[1:]/join).
    description = move_desc.partition("\n")[2]
    new_row = pd.DataFrame([{
        "Move Name": move_name,
        "Move Description": description,
        "Generation Type": generation,
        "Parameters": json.dumps(parameters),
    }])
    # Without ignore_index every appended row keeps label 0, leaving the
    # history with duplicate index labels.
    return pd.concat([df, new_row], ignore_index=True)
def create_move(move, history):
    """Generate a move description without overriding any decoding option.

    Args:
        move: Move name typed by the user; appended to the seed phrase.
        history: Current history table.

    Returns:
        A ``(description, history)`` pair: the formatted generated text and
        the history table extended with a "baseline" entry.
    """
    prompt = seed_text + move
    raw_output = generate(prompt, num_return_sequences=1)
    description = format_moves(raw_output)
    updated_history = update_history(history, move, description,
                                     "baseline", "None")
    return description, updated_history
def create_greedy_search_move(move, history):
    """Generate a move description with sampling switched off.

    Args:
        move: Move name typed by the user; appended to the seed phrase.
        history: Current history table.

    Returns:
        A ``(description, history)`` pair: the formatted generated text and
        the history table extended with a "greedy" entry.
    """
    prompt = seed_text + move
    raw_output = generate(prompt, do_sample=False)
    description = format_moves(raw_output)
    updated_history = update_history(history, move, description,
                                     "greedy", "None")
    return description, updated_history
def create_beam_search_move(move, num_beams, history):
    """Generate a move description using beam search.

    Args:
        move: Move name typed by the user; appended to the seed phrase.
        num_beams: Number of beams to search with.
        history: Current history table.

    Returns:
        A ``(description, history)`` pair: the formatted generated text and
        the history table extended with a "beam" entry that records the
        beam count actually used.
    """
    # NOTE(review): the UI slider may hand over a float (e.g. 3.0) --
    # coerce so generation receives an integer beam count.
    beam_count = int(num_beams)
    generated_move = format_moves(generate(
        seed_text + move,
        num_beams=beam_count,
        num_return_sequences=1,
        do_sample=False,
        early_stopping=True,
    ))
    # Log the real setting; this was previously hard-coded to 2, so the
    # history was wrong for every other slider position.
    return generated_move, update_history(history, move, generated_move,
                                          "beam", {"num_beams": beam_count})
def create_sampling_search_move(move, do_sample, temperature, history):
    """Generate a move description with optional sampling and temperature.

    Args:
        move: Move name typed by the user; appended to the seed phrase.
        do_sample: Whether to sample instead of decoding deterministically.
        temperature: Temperature value; converted to ``float`` before use.
        history: Current history table.

    Returns:
        A ``(description, history)`` pair: the formatted generated text and
        the history table extended with a "temperature" entry recording
        ``do_sample`` and ``temperature``.
    """
    generated_move = format_moves(generate(
        seed_text + move,
        do_sample=do_sample,
        temperature=float(temperature),
        num_return_sequences=1,
        # The generation argument is spelled `top_k`; the previous `topk`
        # spelling is not a recognised option. 0 disables top-k filtering
        # so only temperature shapes the sampling distribution.
        top_k=0,
    ))
    return generated_move, update_history(history, move, generated_move,
                                          "temperature",
                                          {"do_sample": do_sample,
                                           "temperature": temperature})
def create_top_search_move(move, topk, topp, history):
    """Generate a move description using top-k and/or top-p sampling.

    Args:
        move: Move name typed by the user; appended to the seed phrase.
        topk: Size of the top-k candidate set (0 disables top-k).
        topp: Cumulative-probability threshold for top-p (1 disables it).
        history: Current history table.

    Returns:
        A ``(description, history)`` pair: the formatted generated text and
        the history table extended with a "top" entry recording both
        settings.
    """
    # NOTE(review): slider values may arrive as floats/ints interchangeably;
    # normalise them to the types the generation options expect.
    top_k = int(topk)
    top_p = float(topp)
    # The earlier `force_word_ids=...` argument was removed: it is not a
    # generation option (the documented name is `force_words_ids`), and per
    # the transformers generation docs that feature belongs to constrained
    # beam search rather than to sampling as used here.
    generated_move = format_moves(generate(
        seed_text + move,
        do_sample=True,
        num_return_sequences=1,
        top_k=top_k,
        top_p=top_p,
    ))
    return generated_move, update_history(history, move, generated_move,
                                          "top", {"top k": top_k,
                                                  "top p": top_p})
# Build the Gradio Blocks app: intro text, one tab per decoding strategy,
# a shared history table, an "about" section, and the button wiring.
# NOTE(review): the nesting of the `with` blocks below is not visible in
# this copy of the file (indentation is missing); restore it before running.
demo = gr.Blocks()
with demo:
gr.Markdown("<h1><center>What's that Pokemon Move?</center></h1>")
gr.Markdown(
"""This Gradio demo allows you to generate Pokemon Move descriptions given a name, and learn more about text
decoding methods in the process! Each tab aims to explain each generation methodology available for the
model. The dataframe below allows you to keep track of each move generated, to compare!""")
gr.Markdown("<h3> How does text generation work? <h3>")
gr.Markdown("""Roughly, text generation models accept an input sequence of words (or parts of words,
known as tokens).
These models then output a corresponding set of words or tokens. Given the input, the model
estimates the probability of another possible word or token appearing right after the given sequence. In
other words, the model estimates conditional probabilities and ranks them in order to generate sequences
. """)
gr.Markdown("Enter a two to three word Pokemon Move name of your imagination below, with each word capitalized!")
gr.Markdown("<h3> Move Generation <h3>")
# One tab per decoding strategy; each tab has a move-name input, a
# description output and a button (wired to its handler further down).
with gr.Tabs():
# Tab 1: baseline generation (handled by create_move).
with gr.TabItem("Standard Generation"):
gr.Markdown(
"""The default parameters for distilgpt2 work well to generate moves. Use this tab as
a baseline for your experiments.""")
with gr.Row():
text_input_baseline = gr.Textbox(label="Move",
placeholder="Type a two or three word move name here! Try \"Wonder "
"Shield\"!")
text_output_baseline = gr.Textbox(label="Move Description",
placeholder="Leave this blank!")
text_button_baseline = gr.Button("Create my move!")
# Tab 2: greedy decoding (handled by create_greedy_search_move).
with gr.TabItem("Greedy Search Decoding"):
gr.Markdown("""
Greedy search is a decoding method that relies on finding words that has the highest estimated
probability of following the sequence thus far.
Therefore, the model \"greedily\" grabs the highest
probability word and continues generating the sentence.
This has the side effect of finding sequences that are reasonable, but avoids sequences that are
less probable but way more interesting.
Try the other decoding methods to get sentences with more variety!
""")
with gr.Row():
text_input_greedy = gr.Textbox(label="Move")
text_output_greedy = gr.Textbox(label="Move Description")
text_button_greedy = gr.Button("Create my move!")
# Tab 3: beam search with a user-selectable beam count
# (handled by create_beam_search_move).
with gr.TabItem("Beam Search"):
gr.Markdown("""Beam search is an improvement on Greedy Search. Instead of directly grabbing the word that
maximizes probability, we conduct a search with B number of candidates. We then try to find the next word
that would most likely follow each beam, and we grab the top B candidates of that search. This may
eliminate one of the original beams we started with, and that's okay! That is how the algorithm decides
on an optimal candidate. Eventually, the beam sequence terminate or are eliminated due to being too
improbable.
Increasing the number of beams will increase model generation time, but also result in a more thorough
search. Decreasing the number of beams will decrease decoding time, but it may not find an optimal
sentence.
Play around with the num_beams parameter to experiment! """
)
with gr.Row():
# Beam count slider: integers 2..10, default 2.
num_beams = gr.Slider(minimum=2, maximum=10, value=2, step=1,
label="Number of Beams")
text_input_beam = gr.Textbox(label="Move")
text_output_beam = gr.Textbox(label="Move Description")
text_button_beam = gr.Button("Create my move!")
# Tab 4: sampling on/off plus temperature
# (handled by create_sampling_search_move).
with gr.TabItem("Sampling and Temperature Search"):
gr.Markdown(
"""Greedy Search and Beam Search were both good at finding sequences that are likely to follow our
input text, but when generating cool move descriptions, we want some more variety!
Instead of choosing the word or token that is most likely to follow a given sequence, we can instead
ask the model to sample across the probability distribution of likely words.
It's kind of like walking into the tall grass and finding a Pokemon encounter.
There are different encounter rates, which allow
for the most common mons to appear (looking at you, Zubat), but also account for surprise, like shinys!
We might even want to go further, though. We can rescale the probability distributions directly
instead, allowing for rare words to temporarily become more frequently. We do this using the
temperature parameter.
Turn the temperature up, and rare tokens become very likely! Cool down, and we approach more sensible
output.
Experiment with turning sampling on and off, and by varying temperature below!.
""")
with gr.Row():
# Temperature slider: 0.3..4.0 in steps of 0.1, default 1.0.
temperature = gr.Slider(minimum=0.3, maximum=4.0, value=1.0, step=0.1,
label="Temperature")
text_input_temp = gr.Textbox(label="Move")
with gr.Row():
# Checkbox value is forwarded as the handler's do_sample argument.
sample_boolean = gr.Checkbox(label="Enable Sampling?")
text_output_temp = gr.Textbox(label="Move Description")
text_button_temp = gr.Button("Create my move!")
# Tab 5: top-k / top-p sampling (handled by create_top_search_move).
with gr.TabItem("Top K and Top P Sampling"):
gr.Markdown(
"""When we want more control over the words we get to sample from, we turn to Top K and Top P
decoding methods!
The Top K sampling method selects the K most probable words given a sequence, and then samples from
that subset, rather than the whole vocabulary. This effectively cuts out low probability words.
Top P also reduces the available vocabulary to sample from, but instead of choosing the number of
words or tokens in advance, we sort the vocabulary from most to least likely word, and we
grab the smallest set of words that sum to P. This allows for the number of words we look at to
change while sampling, instead of being fixed.
We can even use both methods at the same time! To disable Top K, set it to 0 using the slider.
To disable Top P, set it to 1""")
with gr.Row():
# Top-k slider: 0..200 in steps of 5, default 0.
topk = gr.Slider(minimum=0, maximum=200, value=0, step=5,
label="Top K")
text_input_top = gr.Textbox(label="Move")
with gr.Row():
# Top-p slider: 0.10..1 in steps of 0.05, default 1.
topp = gr.Slider(minimum=0.10, maximum=1, value=1, step=0.05,
label="Top P")
text_output_top = gr.Textbox(label="Move Description")
text_button_top = gr.Button("Create my move!")
# Shared history table; every handler receives it as an input and returns
# an updated copy as an output.
with gr.Box():
gr.Markdown("<h3> Generation History <h3>")
# Displays a dataframe with the history of moves generated, with parameters
history = gr.Dataframe(headers=["Move Name", "Move Description", "Generation Type", "Parameters"])
# Static "about" section with credits and links.
with gr.Box():
gr.Markdown("<h3>How did you make this?<h3>")
gr.Markdown("""
Hi! My name is <a href =https://www.linkedin.com/in/arjunkirtipatel/>Arjun Patel</a> and I'm Lead Data Scientist
over at <a href =https://www.speeko.co>Speeko</a>.
Nice to meet you!
I collected the dataset from <a href =https://www.serebii.net>Serebii</a>, a news source and aggregator of
Pokemon info.
I then added a seed phrase "This move is called" just before each move in order to assist the model in
generation.
I then followed HuggingFace's handy language_modeling.ipynb for fine-tuning distillgpt2 on this tiny dataset,
and it surprisingly worked!
I learned all about text generation using the book <a href
=https://www.oreilly.com/library/view/natural-language-processing/9781098103231/> Natural Language Processing
with Transformers</a> by Lewis Tunstall, Leandro von Werra and Thomas Wolf, as well as <a href
=https://huggingface.co/blog/how-to-generate>this fantastic article</a> by Patrick von Platen. Thanks to all
of these folks for creating these learning materials, and thanks to the Hugging Face team for developing this
product! """)
# Event wiring: the order of `inputs` matches each handler's parameter
# order, and each handler's (description, history) return value fills
# the tab's output textbox and the shared history table.
text_button_baseline.click(create_move, inputs=[text_input_baseline, history],
outputs=[text_output_baseline, history])
text_button_greedy.click(create_greedy_search_move, inputs=[text_input_greedy, history],
outputs=[text_output_greedy, history])
text_button_temp.click(create_sampling_search_move, inputs=[text_input_temp, sample_boolean, temperature, history],
outputs=[text_output_temp, history])
text_button_beam.click(create_beam_search_move, inputs=[text_input_beam, num_beams, history],
outputs=[text_output_beam, history])
text_button_top.click(create_top_search_move, inputs=[text_input_top, topk, topp, history],
outputs=[text_output_top, history])
# Start the app; share=True asks Gradio to also create a public share link.
demo.launch(share=True)
|