Spaces:
Sleeping
Sleeping
nrjvarshney
commited on
Commit
•
555714b
1
Parent(s):
9aeb0e3
Adding app file
Browse files
app.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import wikipedia
|
2 |
+
import transformers
|
3 |
+
import spacy
|
4 |
+
from transformers import AutoModelWithLMHead, AutoTokenizer
|
5 |
+
import random
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
|
9 |
+
model = AutoModelWithLMHead.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
|
10 |
+
nlp = spacy.load("en_core_web_sm")
|
11 |
+
|
12 |
+
def get_question(answer, context, max_length=64):
|
13 |
+
input_text = "answer: %s context: %s </s>" % (answer, context)
|
14 |
+
features = tokenizer([input_text], return_tensors='pt')
|
15 |
+
|
16 |
+
output = model.generate(input_ids=features['input_ids'],
|
17 |
+
attention_mask=features['attention_mask'],
|
18 |
+
max_length=max_length)
|
19 |
+
|
20 |
+
return tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
21 |
+
|
22 |
+
import gradio as gr
|
23 |
+
|
24 |
+
def greet(topic):
|
25 |
+
print("Entered topic: ", topic)
|
26 |
+
topics = wikipedia.search(topic)
|
27 |
+
random.shuffle(topics)
|
28 |
+
for topic in topics:
|
29 |
+
try:
|
30 |
+
summary = wikipedia.summary(topic)
|
31 |
+
except wikipedia.DisambiguationError as e:
|
32 |
+
# print(e.options)
|
33 |
+
s = random.choice(e.options)
|
34 |
+
summary = wikipedia.summary(s)
|
35 |
+
except wikipedia.PageError as e:
|
36 |
+
continue
|
37 |
+
break
|
38 |
+
print("Selected topic: ", topic)
|
39 |
+
print("Summary: ", summary)
|
40 |
+
summary = summary.replace("\n", "")
|
41 |
+
doc = nlp(summary)
|
42 |
+
|
43 |
+
answers = doc.ents
|
44 |
+
filtered_answers = []
|
45 |
+
for answer in answers:
|
46 |
+
if(answer.text in topic or topic in answer.text):
|
47 |
+
pass
|
48 |
+
else:
|
49 |
+
filtered_answers.append(answer)
|
50 |
+
|
51 |
+
answer_1 = random.choice(filtered_answers)
|
52 |
+
question_1 = get_question(answer_1, summary)
|
53 |
+
question_1 = question_1[9:]
|
54 |
+
print("Question: ", question_1)
|
55 |
+
print("Answer: ", answer_1)
|
56 |
+
return [question_1, gr.update(visible=True), gr.update(value=answer_1, visible=False)]
|
57 |
+
|
58 |
+
|
59 |
+
def get_answer(input_answer, gold_answer):
|
60 |
+
print("Entered Answer: ", input_answer)
|
61 |
+
return gr.update(value=gold_answer, visible=True)
|
62 |
+
|
63 |
+
|
64 |
+
with gr.Blocks() as demo:
|
65 |
+
# with gr.Row():
|
66 |
+
topic = gr.Textbox(label="Topic")
|
67 |
+
greet_btn = gr.Button("Ask a Question")
|
68 |
+
question = gr.Textbox(label="Question")
|
69 |
+
input_answer = gr.Textbox(label="Your Answer", visible=False)
|
70 |
+
answer_btn = gr.Button("Show Answer")
|
71 |
+
gold_answer = gr.Textbox(label="Correct Answer", visible=False)
|
72 |
+
greet_btn.click(fn=greet, inputs=topic, outputs=[question, input_answer, gold_answer])
|
73 |
+
|
74 |
+
# with gr.Row():
|
75 |
+
|
76 |
+
answer_btn.click(fn=get_answer, inputs=[input_answer,gold_answer], outputs=gold_answer)
|
77 |
+
|
78 |
+
demo.launch()
|
79 |
+
# demo.launch(share=True)
|