Spaces:
Build error
Build error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import warnings
|
4 |
+
|
5 |
+
# Ignore warnings whose message starts with the text below.
warnings.filterwarnings("ignore", message="This sequence already has </s>.")

# Make the sibling "scripts" directory importable. The relative path is
# resolved against the current working directory at import time.
scripts_path = os.path.abspath(os.path.join(os.pardir, 'scripts'))
sys.path.append(scripts_path)
|
11 |
+
|
12 |
+
|
13 |
+
# Standard library imports
|
14 |
+
import random
|
15 |
+
import string
|
16 |
+
|
17 |
+
# Third-party imports
|
18 |
+
import json
|
19 |
+
import numpy as np
|
20 |
+
import pandas as pd
|
21 |
+
import torch
|
22 |
+
import nltk
|
23 |
+
from dateutil.parser import parse
|
24 |
+
from nltk.stem import PorterStemmer
|
25 |
+
from nltk.corpus import stopwords, wordnet as wn
|
26 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
27 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
28 |
+
# from textdistance import levenshtein
|
29 |
+
from rapidfuzz import fuzz
|
30 |
+
from rapidfuzz.distance import Levenshtein as levenshtein
|
31 |
+
|
32 |
+
from sense2vec import Sense2Vec
|
33 |
+
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
34 |
+
from sentence_transformers import SentenceTransformer
|
35 |
+
|
36 |
+
# Fetch every NLTK resource this module needs, in the original order.
for _nltk_resource in ('omw-1.4', 'stopwords', 'punkt', 'brown', 'wordnet'):
    nltk.download(_nltk_resource)
del _nltk_resource  # keep the module namespace unchanged
|
42 |
+
|
43 |
+
from typing import List, Dict
|
44 |
+
import re
|
45 |
+
|
46 |
+
# Initialize models

# Location of the Sense2Vec vectors on disk. The name was previously used
# without ever being defined in this file, so importing the module raised
# NameError. It can be overridden with the S2V_MODEL_PATH environment variable.
# NOTE(review): the "s2v_old" fallback is an assumption about where the
# vectors are unpacked -- confirm the real location for this deployment.
S2V_MODEL_PATH = os.environ.get("S2V_MODEL_PATH", "s2v_old")

# T5 checkpoints: "_ag" is used by answer_question; the "_qg" pair is exposed
# for callers to pass to get_mca_questions / get_question.
t5ag_model = T5ForConditionalGeneration.from_pretrained("miiiciiii/I-Comprehend_ag")
t5ag_tokenizer = T5Tokenizer.from_pretrained("miiiciiii/I-Comprehend_ag", legacy=False)
t5qg_model = T5ForConditionalGeneration.from_pretrained("miiiciiii/I-Comprehend_qg")
t5qg_tokenizer = T5Tokenizer.from_pretrained("miiiciiii/I-Comprehend_qg", legacy=False)

# Sense2Vec vectors back the distractor lookup in get_distractors.
s2v = Sense2Vec().from_disk(S2V_MODEL_PATH)

# Sentence encoder exposed for callers to pass to get_mca_questions.
sentence_transformer_model = SentenceTransformer("sentence-transformers/LaBSE")
|
53 |
+
|
54 |
+
def answer_question(question, context):
    """Generate an answer for a given question and context.

    Parameters:
        question (str): The question to answer.
        context (str): The passage the answer should be drawn from.

    Returns:
        str: The decoded model output with ``str.capitalize()`` applied
        (first character upper-cased, the rest lower-cased).
    """
    input_text = f"question: {question} context: {context}"
    # Inputs longer than 512 tokens are truncated by the tokenizer.
    input_ids = t5ag_tokenizer.encode(
        input_text, return_tensors="pt", max_length=512, truncation=True
    )

    with torch.no_grad():
        # Only max_new_tokens is passed. The call used to pass max_length=512
        # as well; the two arguments both bound the output length and conflict
        # (a warning or an error, depending on the transformers version).
        output = t5ag_model.generate(
            input_ids, num_return_sequences=1, max_new_tokens=200
        )

    return t5ag_tokenizer.decode(output[0], skip_special_tokens=True).capitalize()
|
63 |
+
|
64 |
+
def get_passage(passage):
    """Return the 'context' value of one randomly sampled row of ``passage``.

    ``passage`` must support ``.sample(n=1)`` and have a 'context' column
    (presumably a pandas DataFrame -- confirm against the caller).
    """
    sampled_row = passage.sample(n=1)
    context_column = sampled_row['context']
    return context_column.values[0]
|
67 |
+
|
68 |
+
def get_question(context, answer, model, tokenizer):
    """Generate a question whose answer is ``answer`` within ``context``.

    The first occurrence of ``answer`` in ``context`` is wrapped in ``<hl>``
    markers and "</s>" is appended before tokenizing. If ``answer`` does not
    occur in ``context``, the text is passed through without a highlight.

    Returns:
        str: The decoded question with special tokens removed.
    """
    highlighted = context.replace(answer, f"<hl>{answer}<hl>", 1)
    encoded = tokenizer(highlighted + "</s>", return_tensors="pt")
    # Generation is capped at 50 tokens; only the first sequence is used.
    generated = model.generate(input_ids=encoded.input_ids, max_length=50)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
|
75 |
+
|
76 |
+
|
77 |
+
def get_keywords(passage):
    """Return the passage's TF-IDF terms, highest score first.

    English stop words are excluded. On any error the message is printed and
    an empty list is returned.
    """
    try:
        tfidf = TfidfVectorizer(stop_words='english')
        scores = tfidf.fit_transform([passage]).toarray().flatten()  # type: ignore
        terms = tfidf.get_feature_names_out()
        # sorted() is stable, so equally scored terms keep the vectorizer's order.
        ranked = sorted(zip(terms, scores), key=lambda pair: pair[1], reverse=True)
        return [term for term, _ in ranked]
    except Exception as e:
        print(f"Error extracting keywords: {e}")
        return []
|
91 |
+
|
92 |
+
def classify_question_type(question: str) -> str:
    """
    Classify the type of question as literal, evaluative, or inferential.

    Keywords are matched as whole words or phrases, case-insensitively, so
    "show" no longer counts as "how" and "statement" no longer counts as
    "state". When keywords from several categories match, the longest (most
    specific) one wins, which makes phrases such as "what would happen if"
    reachable even though they start with the literal keyword "what". Ties
    keep the original priority: literal, then evaluative, then inferential.

    Parameters:
    question (str): The question to classify.

    Returns:
    str: The type of the question ('literal', 'evaluative', 'inferential'),
    or 'unknown' when no keyword matches.
    """
    # Categories are listed in priority order (used only to break ties).
    keyword_groups = (
        ('literal', (
            'what', 'when', 'where', 'who', 'how many', 'how much',
            'which', 'name', 'list', 'identify', 'define', 'describe',
            'state', 'mention',
        )),
        ('evaluative', (
            'evaluate', 'justify', 'explain why', 'assess', 'critique',
            'discuss', 'judge', 'opinion', 'argue', 'agree or disagree',
            'defend', 'support your answer', 'weigh the pros and cons',
            'compare', 'contrast',
        )),
        ('inferential', (
            'why', 'how', 'what if', 'predict', 'suggest', 'imply',
            'conclude', 'infer', 'reason', 'what might', 'what could',
            'what would happen if', 'speculate', 'deduce', 'interpret',
            'hypothesize', 'assume',
        )),
    )

    # Collapse whitespace so multi-word phrases match across irregular spacing.
    question_lower = " ".join(question.lower().split())

    best_type = 'unknown'
    best_length = 0
    for question_type, keywords in keyword_groups:
        for keyword in keywords:
            # Only a strictly longer keyword can displace the current match.
            if len(keyword) <= best_length:
                continue
            if re.search(rf"\b{re.escape(keyword)}\b", question_lower):
                best_type = question_type
                best_length = len(keyword)

    return best_type
|
140 |
+
|
141 |
+
def filter_same_sense_words(original, wordlist):
    """Filter words that have the same sense as the original word.

    Parameters:
        original (str): A key of the form "word|SENSE".
        wordlist (list): Entries that are either "word|SENSE" strings or
            sequences whose first item is such a string (e.g. (key, score)).

    Returns:
        list: Title-cased words (underscores turned into spaces) whose sense
        equals that of ``original``. If ``original`` has no sense part a
        warning is printed and ``wordlist`` is returned unchanged.
    """
    try:
        base_sense = original.split('|')[1]  # Ensure there is a sense part
    except IndexError:
        print(f"Warning: The original phrase '{original}' does not have a sense part.")
        return wordlist  # Return all words if the sense part is missing

    same_sense = []
    for entry in wordlist:
        # Accept bare keys as well as (key, score) pairs.
        key = entry if isinstance(entry, str) else entry[0]
        parts = key.split('|')
        # Entries without a sense part cannot match; skip them instead of
        # raising IndexError as the previous comprehension did.
        if len(parts) < 2 or parts[1] != base_sense:
            continue
        same_sense.append(parts[0].replace("_", " ").title().strip())
    return same_sense
|
150 |
+
|
151 |
+
def extract_similar_keywords(input_phrases, topn=5):
    """Return only the "similar_keywords" list of each get_distractors result.

    The output has one list per input phrase, in input order.
    """
    return [
        entry["similar_keywords"]
        for entry in get_distractors(input_phrases, topn)
    ]
|
156 |
+
|
157 |
+
def get_max_similarity_score(wordlist, word):
    """Get the maximum similarity score between the word and a list of words.

    The comparison is case-insensitive and uses normalized Levenshtein
    similarity.

    Returns:
        The highest similarity found, or 0.0 when ``wordlist`` is empty
        (previously an empty list raised ValueError from ``max``).
    """
    target = word.lower()
    return max(
        (levenshtein.normalized_similarity(target, each.lower()) for each in wordlist),
        default=0.0,
    )
|
160 |
+
|
161 |
+
def mmr(doc_embedding, word_embeddings, words, top_n, lambda_param):
    """Maximal Marginal Relevance (MMR) for keyword extraction.

    Starts with the word most similar to the document, then repeatedly adds
    the candidate that maximizes
    ``lambda_param * sim(candidate, doc) - (1 - lambda_param) * max sim(candidate, selected)``.

    Parameters:
        doc_embedding: Document embedding; assumed to be a single-row 2-D
            array -- TODO confirm against callers.
        word_embeddings: One embedding row per entry of ``words``.
        words (list): Candidate words.
        top_n (int): Number of keywords wanted. At least one keyword is
            returned whenever ``words`` is non-empty.
        lambda_param (float): Weight of relevance versus diversity.

    Returns:
        list: Up to ``top_n`` selected words, or [] if an error occurs (the
        error is printed).
    """
    try:
        word_doc_similarity = cosine_similarity(word_embeddings, doc_embedding)
        word_similarity = cosine_similarity(word_embeddings)

        selected_idx = [np.argmax(word_doc_similarity)]
        candidates_idx = [i for i in range(len(words)) if i != selected_idx[0]]

        for _ in range(top_n - 1):
            # Stop once every word has been used. Previously argmax on the
            # empty candidate array raised, and the except below discarded
            # all keywords already selected.
            if not candidates_idx:
                break

            candidate_similarities = word_doc_similarity[candidates_idx, :]
            target_similarities = np.max(
                word_similarity[candidates_idx][:, selected_idx], axis=1
            )

            scores = (lambda_param * candidate_similarities) - (
                (1 - lambda_param) * target_similarities.reshape(-1, 1)
            )
            best_idx = candidates_idx[np.argmax(scores)]

            selected_idx.append(best_idx)
            candidates_idx.remove(best_idx)

        return [words[idx] for idx in selected_idx]
    except Exception as e:
        print(f"Error in MMR: {e}")
        return []
|
184 |
+
|
185 |
+
def format_phrase(phrase):
    """Turn a phrase into a key: spaces become underscores and "|n" is appended."""
    underscored = "_".join(phrase.split(" "))
    return f"{underscored}|n"
|
188 |
+
|
189 |
+
|
190 |
+
def is_valid_distractor(distractor, input_phrase):
    """Return True when the distractor is purely letters/whitespace and 1-4 words long.

    ``input_phrase`` is accepted for interface compatibility but is not used.
    """
    if re.match(r'^[a-zA-Z\s]+$', distractor) is None:
        return False
    return 1 <= len(distractor.split()) <= 4
|
200 |
+
|
201 |
+
def filter_distractors(input_phrase, similar_keywords, topn):
    """Select up to ``topn`` usable distractors from ``similar_keywords``.

    A keyword is kept when it has the same number of words as the input
    phrase, differs from it (case-insensitively and by Porter stem), passes
    is_valid_distractor, and does not share a stem with a keyword already kept.
    """
    stemmer = PorterStemmer()
    target_lower = input_phrase.lower()
    target_word_count = len(input_phrase.split())
    target_stem = stemmer.stem(target_lower)

    kept = []
    kept_stems = set()  # stems of everything in `kept`

    for candidate in similar_keywords:
        candidate_lower = candidate.lower()
        candidate_stem = stemmer.stem(candidate_lower)

        acceptable = (
            len(candidate.split()) == target_word_count
            and candidate_lower != target_lower
            and candidate_stem != target_stem
            and is_valid_distractor(candidate, input_phrase)
        )
        if acceptable and candidate_stem not in kept_stems:
            kept.append(candidate)
            kept_stems.add(candidate_stem)

        if len(kept) == topn:
            break

    return kept
|
224 |
+
|
225 |
+
|
226 |
+
def get_distractors(input_phrases, topn=5):
    """Find similar keywords for a list of input phrases using Sense2Vec and WordNet.

    For each phrase the lookup order is:
      1. Sense2Vec neighbours of the formatted key, if the key is in the model;
      2. otherwise, model keys containing the phrase's first or last word;
      3. otherwise, WordNet lemma names for the phrase.
    Candidates are then passed through filter_distractors.

    Parameters:
        input_phrases (list[str]): Phrases to find distractors for.
        topn (int): Maximum number of distractors per phrase.

    Returns:
        list[dict]: One ``{"phrase": ..., "similar_keywords": [...]}`` entry
        per input phrase; the list is empty when nothing suitable is found.
    """
    result_list = []

    for phrase in input_phrases:
        formatted_phrase = format_phrase(phrase)

        # Check if the phrase exists in the Sense2Vec model
        if formatted_phrase in s2v:
            # Ask for twice as many neighbours as needed; filtering drops some.
            similar_phrases = s2v.most_similar(formatted_phrase, n=topn * 2)
            similar_keywords = [item[0].split("|")[0].replace("_", " ") for item in similar_phrases]
        else:
            print(f"'{formatted_phrase}' not found in the model. Exploring similar available keys...")
            # Split once, outside the scan over the vocabulary. An empty or
            # whitespace-only phrase has no words and previously raised
            # IndexError here.
            tokens = phrase.split()
            if tokens:
                first_token, last_token = tokens[0], tokens[-1]
                available_keys = [
                    key for key in s2v.keys()
                    if first_token in key or last_token in key
                ]
            else:
                available_keys = []
            print(f"Available keys related to '{phrase}': {available_keys}")

            if not available_keys:
                # Use WordNet to find synonyms if no model key is related
                print(f"No close match in the model for '{phrase}'. Trying WordNet for synonyms...")
                synonyms = set()
                for syn in wn.synsets(phrase.replace(" ", "_")):
                    for lemma in syn.lemmas():
                        synonyms.add(lemma.name().replace("_", " "))
                # Sorted so the selection is reproducible across runs. No
                # placeholder is used when nothing is found: the former
                # "No match found" string could pass the filters and end up
                # as a distractor for three-word phrases.
                similar_keywords = sorted(synonyms)[:topn * 2]
            else:
                # Provide available keys as similar suggestions
                similar_keywords = [key.split("|")[0].replace("_", " ") for key in available_keys[:topn * 2]]

        # Filter distractors to match word count, avoid identical or stem-similar words, and check format
        final_distractors = filter_distractors(phrase, similar_keywords, topn)
        # NOTE(review): `phrase` has no "|sense" suffix here, so this call
        # appears to take filter_same_sense_words' warning path and return the
        # list unchanged -- confirm whether a sense-tagged key was intended.
        final_distractors = filter_same_sense_words(phrase, final_distractors)

        result_list.append({
            "phrase": phrase,
            "similar_keywords": final_distractors
        })

    return result_list
|
267 |
+
|
268 |
+
def _complete_distractors(distractors, answer, keyword_pool, required=3):
    """Return exactly ``required`` distractors, or None if that many cannot be found.

    Existing distractors are kept in order (dropping the answer and
    case-insensitive duplicates); missing ones are drawn at random from
    ``keyword_pool``.
    """
    seen = {answer.lower()}
    chosen = []
    for candidate in distractors:
        key = candidate.lower()
        if key not in seen:
            seen.add(key)
            chosen.append(candidate)

    if len(chosen) < required:
        # One pass over a shuffled copy keeps the fill-in random but always
        # terminates, unlike retrying random.choice until enough are found.
        spare = list(keyword_pool)
        random.shuffle(spare)
        for keyword in spare:
            if len(chosen) >= required:
                break
            key = keyword.lower()
            if key not in seen:
                seen.add(key)
                chosen.append(keyword)

    if len(chosen) < required:
        return None
    return chosen[:required]


def get_mca_questions(context, qg_model, qg_tokenizer, sentence_transformer_model, num_questions=5, max_attempts=2) -> List[Dict]:
    """
    Generate multiple-choice questions for a given context.

    For each TF-IDF keyword of the context a question is generated, rejected
    if it is too similar (cosine similarity > 0.8) to an accepted question,
    answered, and rejected if the answer is longer than three words or was
    already used. Accepted questions get three distractors plus the answer as
    shuffled, title-cased choices.

    Parameters:
    context (str): The context from which questions are generated.
    qg_model (T5ForConditionalGeneration): The question generation model.
    qg_tokenizer (T5Tokenizer): The tokenizer for the question generation model.
    sentence_transformer_model (SentenceTransformer): The sentence transformer model for embeddings.
    num_questions (int): The number of questions to generate.
    max_attempts (int): The maximum number of passes over the keyword list.

    Returns:
    list: A list of dictionaries with keys 'answer', 'answer_length',
    'choices', 'passage', 'passage_length', 'question', 'question_length'
    and 'question_type'. May hold fewer than ``num_questions`` entries.
    """
    output_list = []

    imp_keywords = get_keywords(context)
    print(f"[DEBUG] Length: {len(imp_keywords)}, Extracted keywords: {imp_keywords}")

    # Embeddings of accepted questions, cached so each question is encoded once
    # rather than on every comparison.
    accepted_embeddings = []
    generated_answers = set()
    attempts = 0

    while len(output_list) < num_questions and attempts < max_attempts:
        attempts += 1

        for keyword in imp_keywords:
            if len(output_list) >= num_questions:
                break

            question = get_question(context, keyword, qg_model, qg_tokenizer)
            print(f"[DEBUG] Generated question: '{question}' for keyword: '{keyword}'")

            # Encode the new question; moved to CPU because scikit-learn's
            # cosine_similarity operates on host arrays.
            new_question_embedding = sentence_transformer_model.encode(
                question, convert_to_tensor=True
            ).cpu().unsqueeze(0)

            # Check similarity with existing questions
            is_similar = any(
                cosine_similarity(new_question_embedding, existing)[0][0] > 0.8
                for existing in accepted_embeddings
            )
            if is_similar:
                print(f"[DEBUG] Question '{question}' is too similar to an existing question, skipping.")
                continue

            # Generate and check answer
            t5_answer = answer_question(question, context)
            print(f"[DEBUG] Generated answer: '{t5_answer}' for question: '{question}'")

            # Skip answers longer than 3 words
            if len(t5_answer.split()) > 3:
                print(f"[DEBUG] Answer '{t5_answer}' is too long, skipping.")
                continue

            if t5_answer in generated_answers:
                print(f"[DEBUG] Answer '{t5_answer}' has already been generated, skipping question.")
                continue

            accepted_embeddings.append(new_question_embedding)
            generated_answers.add(t5_answer)

            # Generate distractors
            distractors = extract_similar_keywords([t5_answer], topn=5)[0]
            print(f"list of distractors : {distractors}")
            print(f"length of distractors {len(distractors)}")
            print(f"type : {type(distractors)}")

            # Remove any distractor that is the same as the correct answer
            distractors = [d for d in distractors if d.lower() != t5_answer.lower()]
            print(f"Filtered distractors (without answer): {distractors}")

            # Ensure there are exactly 3 distractors, topping up from the
            # keyword list. If the passage has too few usable keywords the
            # question is skipped; the old fill loop never terminated.
            distractors = _complete_distractors(distractors, t5_answer, imp_keywords, required=3)
            if distractors is None:
                print(f"[DEBUG] Not enough distractors for question: '{question}', skipping.")
                continue

            print(f"[DEBUG] Final distractors: {distractors} for question: '{question}'")

            choices = distractors + [t5_answer]
            choices = [item.title() for item in choices]
            random.shuffle(choices)
            print(f"[DEBUG] Options: {choices} for answer: '{t5_answer}'")

            # Classify question type
            question_type = classify_question_type(question)

            # NOTE(review): 'answer' keeps answer_question's capitalization,
            # while 'choices' are title-cased, so a multi-word answer may not
            # equal any choice exactly -- confirm how consumers compare them.
            output_list.append({
                'answer': t5_answer,
                'answer_length': len(t5_answer),
                'choices': choices,
                'passage': context,
                'passage_length': len(context),
                'question': question,
                'question_length': len(question),
                'question_type': question_type
            })

        print(f"[DEBUG] Generated {len(output_list)} questions so far after {attempts} attempts")

    return output_list[:num_questions]
|