dutch-questgen / app.py
Michelvh's picture
Update app.py
ee3fab3
raw
history blame
3.94 kB
import gradio as gr
from transformers import T5ForConditionalGeneration, T5TokenizerFast
import nltk
from nltk import tokenize
# Sentence-tokenizer data for nltk's sent_tokenize (used by chunk_text and
# create_frames). NLTK 3.9+ loads the "punkt_tab" resource instead of the
# pickled "punkt" one, so fetch both to work across NLTK versions.
nltk.download('punkt')
nltk.download('punkt_tab')
# Tokenizer of the base Dutch T5 checkpoint, extended with the "<sep>" marker
# that hf_run_model splits the generated text on.
checkpoint = "yhavinga/t5-base-dutch"
tokenizer = T5TokenizerFast.from_pretrained(checkpoint)
tokenizer.sep_token = '<sep>'
tokenizer.add_tokens(['<sep>'])
# Question-generation model (presumably fine-tuned from the base checkpoint above).
# NOTE(review): it is loaded from a different repo than the tokenizer; this
# assumes its vocabulary already contains the added "<sep>" token — confirm.
hfmodel = T5ForConditionalGeneration.from_pretrained("Michelvh/t5-end2end-questions-generation-dutch")
def hf_run_model(input_string, **generator_args):
    """Generate questions for one piece of text with the question-generation model.

    Args:
        input_string: The source text to generate questions from.
        **generator_args: Optional keyword arguments forwarded to
            ``hfmodel.generate``; they override the defaults below.

    Returns:
        A list with one entry per returned sequence; each entry is that
        decoded sequence split on ``"<sep>"`` (a list of strings).
    """
    # App defaults, with caller-supplied arguments taking precedence
    # (previously the caller's arguments were overwritten and ignored).
    generation_kwargs = {
        "max_length": 256,
        "num_beams": 4,
        "length_penalty": 1.5,
        "no_repeat_ngram_size": 3,
        "early_stopping": True,
        "num_return_sequences": 1,
        **generator_args,
    }
    # NOTE(review): T5 tokenizers normally append </s> themselves; this explicit
    # marker presumably mirrors the fine-tuning input format — confirm before removing.
    input_ids = tokenizer.encode(input_string + " </s>", return_tensors="pt")
    res = hfmodel.generate(input_ids, **generation_kwargs)
    # NOTE(review): "<sep>" is also registered as the tokenizer's sep_token;
    # confirm skip_special_tokens=True does not strip it before the split below.
    output = tokenizer.batch_decode(res, skip_special_tokens=True)
    return [item.split("<sep>") for item in output]
def chunk_text(text, framesize=5, language="dutch"):
    """Split *text* into sliding windows of ``framesize`` consecutive sentences.

    Windows advance one sentence at a time. A text with fewer than
    ``framesize`` sentences yields a single window containing the whole text.

    Args:
        text: Text to split into sentences with NLTK's ``sent_tokenize``.
        framesize: Number of sentences per window; must be at least 1.
        language: Punkt model passed to ``sent_tokenize`` (the app handles Dutch text).

    Returns:
        A list of windows, each the sentences joined with single spaces;
        empty when *text* contains no sentences.

    Raises:
        ValueError: If ``framesize`` is smaller than 1.
    """
    if framesize < 1:
        raise ValueError("framesize must be at least 1")
    sentences = tokenize.sent_tokenize(text, language=language)
    if not sentences:
        return []
    # Short texts used to produce no windows at all, dropping the input.
    if len(sentences) < framesize:
        return [" ".join(sentences)]
    return [
        " ".join(sentences[start:start + framesize])
        for start in range(len(sentences) - framesize + 1)
    ]
def flatten(l):
    """Return one list holding the items of every sub-iterable of *l*, in order."""
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat
def run_model_with_frames(text, framesize=4, overlap=3, progress=gr.Progress()):
    """Generate a de-duplicated list of questions for *text*, frame by frame.

    The text is split into overlapping frames of sentences by
    ``create_frames``. The model is run on each frame. Every generated
    question is stripped, given a trailing question mark and kept once.

    Args:
        text: The input text.
        framesize: Number of sentences per frame (labelled "Batch size" in the UI).
        overlap: Number of sentences shared by consecutive frames; must be
            smaller than ``framesize``.
        progress: Gradio progress tracker; Gradio injects it when this is
            called from the UI.

    Returns:
        The questions, each followed by a newline, in the order they were
        first generated. Returns an error message string instead when the
        frame parameters are invalid.
    """
    # overlap == framesize means a step of 0 between frames, which never
    # advances, so it is rejected together with larger overlaps.
    if overlap >= framesize:
        return "Overlap should be smaller than batch size"
    if framesize < 1 or overlap < 0:
        return "Batch size should be at least 1 and overlap should not be negative"
    frames = create_frames(text, framesize, overlap)
    total_steps = len(frames)
    progress((0, total_steps), desc="Starting...")
    # A dict drops duplicates while keeping first-generation order
    # (a set would give an arbitrary, run-dependent order).
    collected = {}
    for counter, frame in enumerate(frames, start=1):
        for question in flatten(hf_run_model(frame)):
            question = question.strip()
            # Pieces that are empty after stripping (e.g. the text after a
            # trailing "<sep>") would otherwise become a lone "?".
            if question:
                collected[ensure_questionmark(question)] = None
        progress((counter, total_steps), desc="Generating...")
    progress((total_steps, total_steps), desc="Done")
    return "".join(question + "\n" for question in collected)
def create_frames(text, framesize=4, overlap=3, language="dutch"):
    """Split *text* into overlapping frames of consecutive sentences.

    Consecutive frames start ``framesize - overlap`` sentences apart. The last
    frame is always the final ``framesize`` sentences (or the whole text when
    it is shorter), so every sentence is part of at least one frame.

    Args:
        text: Text to split into sentences with NLTK's ``sent_tokenize``.
        framesize: Number of sentences per frame; must be at least 1.
        overlap: Number of sentences shared by consecutive frames; must
            satisfy ``0 <= overlap < framesize``.
        language: Punkt model passed to ``sent_tokenize`` (the app handles Dutch text).

    Returns:
        A list of frames, each the sentences joined with single spaces;
        empty when *text* contains no sentences.

    Raises:
        ValueError: If ``framesize`` or ``overlap`` is out of range.
    """
    # A step of <= 0 would never advance the start index (endless loop), and
    # a negative overlap would silently skip sentences between frames.
    if framesize < 1:
        raise ValueError("framesize must be at least 1")
    if not 0 <= overlap < framesize:
        raise ValueError("overlap must be >= 0 and smaller than framesize")
    sentences = tokenize.sent_tokenize(text, language=language)
    stepsize = framesize - overlap
    frames = []
    index = 0
    while index < len(sentences):
        if index + framesize >= len(sentences):
            # Final frame: anchor it to the end so it is as full as possible.
            frames.append(" ".join(sentences[-framesize:]))
            break
        frames.append(" ".join(sentences[index:index + framesize]))
        index += stepsize
    return frames
def ensure_questionmark(question):
    """Return *question*, appending "?" unless it already ends with one."""
    return question if question.endswith("?") else question + "?"
description = """
# Dutch question generator
Input some Dutch text and click the button to generate some questions!
The model is currently set up to generate as many questions, but this
can take a couple of minutes so have some patience ;)
The optimal text lenght is probably around 8-10 lines. Longer text
will obviously take longer. Please keep in mind that this is a work in
progress and might still be a little bit buggy."""
# Gradio page: input text, frame parameters, output box, and a button that runs
# run_model_with_frames (also exposed under the "generate_questions" API name).
with gr.Blocks() as iface:
    gr.Markdown(description)
    context = gr.Textbox(label="Input text")
    frame_size = gr.Number(value=5, label="Batch size", info="Size of the subparts that are used to generate questions. Increase to speed up the generation", precision=0)
    # run_model_with_frames asks for an overlap smaller than the batch size,
    # so the help text must say "smaller" too (it previously said "bigger").
    overlap = gr.Number(value=4, label="Overlap", info="Overlap between batches. Should be smaller than batch size. Decrease to speed up generation", precision=0)
    questions = gr.Textbox(label="Questions")
    generate_btn = gr.Button("Generate questions")
    generate_btn.click(fn=run_model_with_frames, inputs=[context, frame_size, overlap], outputs=questions, api_name="generate_questions")
iface.launch()