miiiciiii committed
Commit 54e92ac
1 Parent(s): fc8e31a

Create app.py

Files changed (1):
app.py +381 -0
app.py ADDED
import os
import sys
import warnings

# Suppress specific warnings
warnings.filterwarnings("ignore", message="This sequence already has </s>.")

# Append path for module imports
scripts_path = os.path.abspath(os.path.join('..', 'scripts'))
sys.path.append(scripts_path)

# Standard library imports
import json
import random
import re
import string
from typing import List, Dict

# Third-party imports
import numpy as np
import pandas as pd
import torch
import nltk
from dateutil.parser import parse
from nltk.stem import PorterStemmer
from nltk.corpus import stopwords, wordnet as wn
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# from textdistance import levenshtein
from rapidfuzz import fuzz
from rapidfuzz.distance import Levenshtein as levenshtein

from sense2vec import Sense2Vec
from transformers import T5ForConditionalGeneration, T5Tokenizer
from sentence_transformers import SentenceTransformer

# Download necessary NLTK data
nltk.download('omw-1.4')
nltk.download('stopwords')
nltk.download('punkt')
nltk.download('brown')
nltk.download('wordnet')

# Initialize models
t5ag_model = T5ForConditionalGeneration.from_pretrained("miiiciiii/I-Comprehend_ag")
t5ag_tokenizer = T5Tokenizer.from_pretrained("miiiciiii/I-Comprehend_ag", legacy=False)
t5qg_model = T5ForConditionalGeneration.from_pretrained("miiiciiii/I-Comprehend_qg")
t5qg_tokenizer = T5Tokenizer.from_pretrained("miiiciiii/I-Comprehend_qg", legacy=False)

# S2V_MODEL_PATH was used but never defined in the original file; the default
# below is an assumption — point it at an extracted Sense2Vec archive (e.g. the
# "s2v_old" directory from the s2v_reddit_2015_md release).
S2V_MODEL_PATH = os.environ.get("S2V_MODEL_PATH", "s2v_old")
s2v = Sense2Vec().from_disk(S2V_MODEL_PATH)
sentence_transformer_model = SentenceTransformer("sentence-transformers/LaBSE")

def answer_question(question, context):
    """Generate an answer for a given question and context."""
    input_text = f"question: {question} context: {context}"
    input_ids = t5ag_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)

    with torch.no_grad():
        # max_length and max_new_tokens conflict when passed together;
        # keep max_new_tokens as the single length control.
        output = t5ag_model.generate(input_ids, num_return_sequences=1, max_new_tokens=200)

    return t5ag_tokenizer.decode(output[0], skip_special_tokens=True).capitalize()

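# Illustrative call (invented context, added for documentation):
#   answer_question("Who wrote Hamlet?",
#                   "Hamlet is a tragedy written by William Shakespeare.")
# Note that str.capitalize() lowercases everything after the first character,
# so a multi-word answer comes back as e.g. "William shakespeare".
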
def get_passage(passage):
    """Return a random context from the dataset (expects a DataFrame with a 'context' column)."""
    return passage.sample(n=1)['context'].values[0]

def get_question(context, answer, model, tokenizer):
    """Generate a question for the given answer and context."""
    answer_span = context.replace(answer, f"<hl>{answer}<hl>", 1) + "</s>"
    inputs = tokenizer(answer_span, return_tensors="pt")
    question = model.generate(input_ids=inputs.input_ids, max_length=50)[0]

    return tokenizer.decode(question, skip_special_tokens=True)

def get_keywords(passage):
    """Extract keywords from the passage, ranked by TF-IDF score."""
    try:
        vectorizer = TfidfVectorizer(stop_words='english')
        tfidf_matrix = vectorizer.fit_transform([passage])
        feature_names = vectorizer.get_feature_names_out()
        tfidf_scores = tfidf_matrix.toarray().flatten()  # type: ignore
        word_scores = dict(zip(feature_names, tfidf_scores))
        sorted_words = sorted(word_scores.items(), key=lambda x: x[1], reverse=True)
        keywords = [word for word, score in sorted_words]
        return keywords
    except Exception as e:
        print(f"Error extracting keywords: {e}")
        return []

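# Note (observation added in review, not in the original): with a single
# document the IDF factor is identical for every term, so this ranking
# effectively orders words by within-passage term frequency.
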
def classify_question_type(question: str) -> str:
    """
    Classify the type of question as literal, evaluative, or inferential.

    Parameters:
        question (str): The question to classify.

    Returns:
        str: The type of the question ('literal', 'evaluative', or
        'inferential'), or 'unknown' if no pattern matches.
    """
    # Define keywords or patterns for each question type
    literal_keywords = [
        'what', 'when', 'where', 'who', 'how many', 'how much',
        'which', 'name', 'list', 'identify', 'define', 'describe',
        'state', 'mention'
    ]

    evaluative_keywords = [
        'evaluate', 'justify', 'explain why', 'assess', 'critique',
        'discuss', 'judge', 'opinion', 'argue', 'agree or disagree',
        'defend', 'support your answer', 'weigh the pros and cons',
        'compare', 'contrast'
    ]

    inferential_keywords = [
        'why', 'how', 'what if', 'predict', 'suggest', 'imply',
        'conclude', 'infer', 'reason', 'what might', 'what could',
        'what would happen if', 'speculate', 'deduce', 'interpret',
        'hypothesize', 'assume'
    ]

    question_lower = question.lower()

    # Check for literal question keywords
    if any(keyword in question_lower for keyword in literal_keywords):
        return 'literal'

    # Check for evaluative question keywords
    if any(keyword in question_lower for keyword in evaluative_keywords):
        return 'evaluative'

    # Check for inferential question keywords
    if any(keyword in question_lower for keyword in inferential_keywords):
        return 'inferential'

    # Default to 'unknown' if no pattern matches
    return 'unknown'

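# Illustrative behaviour (added note): the literal check runs first and uses
# plain substring matching, so "What might happen next?" classifies as
# 'literal' via 'what' even though 'what might' is in the inferential list.
# Reordering the checks or using word-boundary regexes would change this.
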
def filter_same_sense_words(original, wordlist):
    """Keep only Sense2Vec (key, score) pairs whose sense tag matches the original key's sense."""
    try:
        base_sense = original.split('|')[1]  # Ensure there is a sense part
    except IndexError:
        print(f"Warning: The original phrase '{original}' does not have a sense part.")
        return wordlist  # Return all words if the sense part is missing

    return [word[0].split('|')[0].replace("_", " ").title().strip()
            for word in wordlist if word[0].split('|')[1] == base_sense]

def extract_similar_keywords(input_phrases, topn=5):
    """Call get_distractors and extract only the similar_keywords values."""
    distractors_result = get_distractors(input_phrases, topn)
    similar_keywords_list = [result["similar_keywords"] for result in distractors_result]
    return similar_keywords_list

def get_max_similarity_score(wordlist, word):
    """Return the maximum normalized Levenshtein similarity between the word and a list of words."""
    return max(levenshtein.normalized_similarity(word.lower(), each.lower()) for each in wordlist)

def mmr(doc_embedding, word_embeddings, words, top_n, lambda_param):
    """Maximal Marginal Relevance (MMR) for keyword extraction."""
    try:
        word_doc_similarity = cosine_similarity(word_embeddings, doc_embedding)
        word_similarity = cosine_similarity(word_embeddings)

        keywords_idx = [np.argmax(word_doc_similarity)]
        candidates_idx = [i for i in range(len(words)) if i != keywords_idx[0]]

        for _ in range(top_n - 1):
            candidate_similarities = word_doc_similarity[candidates_idx, :]
            target_similarities = np.max(word_similarity[candidates_idx][:, keywords_idx], axis=1)

            # Renamed from `mmr` to avoid shadowing the enclosing function.
            mmr_scores = (lambda_param * candidate_similarities) - ((1 - lambda_param) * target_similarities.reshape(-1, 1))
            mmr_idx = candidates_idx[np.argmax(mmr_scores)]

            keywords_idx.append(mmr_idx)
            candidates_idx.remove(mmr_idx)

        return [words[idx] for idx in keywords_idx]
    except Exception as e:
        print(f"Error in MMR: {e}")
        return []

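# Usage sketch (assumed shapes, not from the original file): lambda_param
# trades relevance to the document against redundancy among picked keywords.
#   doc = sentence_transformer_model.encode(["some passage"])   # (1, dim)
#   words = ["glacier", "ice", "retreat"]
#   embs = sentence_transformer_model.encode(words)             # (3, dim)
#   top = mmr(doc, embs, words, top_n=2, lambda_param=0.7)
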
def format_phrase(phrase):
    """Format a phrase for Sense2Vec lookup: spaces become underscores, with a default |n sense tag."""
    return phrase.replace(" ", "_") + "|n"

def is_valid_distractor(distractor, input_phrase):
    """Check that a distractor is purely alphabetic and between 1 and 4 words long."""
    # Note: this regex also rejects hyphenated and numeric candidates.
    if not re.match(r'^[a-zA-Z\s]+$', distractor):
        return False

    word_count = len(distractor.split())
    if word_count < 1 or word_count > 4:
        return False

    return True

def filter_distractors(input_phrase, similar_keywords, topn):
    """Filter distractors to ensure they match the input's word count, aren't identical
    to the input, and aren't too similar to each other or the input (by Porter stem)."""
    word_count = len(input_phrase.split())
    filtered_keywords = []
    stemmer = PorterStemmer()
    input_stem = stemmer.stem(input_phrase.lower())

    for keyword in similar_keywords:
        keyword_stem = stemmer.stem(keyword.lower())

        if (len(keyword.split()) == word_count and
                keyword.lower() != input_phrase.lower() and
                keyword_stem != input_stem and
                is_valid_distractor(keyword, input_phrase)):

            if all(stemmer.stem(kw.lower()) != keyword_stem for kw in filtered_keywords):
                filtered_keywords.append(keyword)

            if len(filtered_keywords) == topn:
                break

    return filtered_keywords

def get_distractors(input_phrases, topn=5):
    """Find similar keywords for a list of input phrases using Sense2Vec and WordNet."""
    result_list = []

    for phrase in input_phrases:
        formatted_phrase = format_phrase(phrase)

        # Check if the phrase exists in the Sense2Vec model
        if formatted_phrase in s2v:
            # Get similar phrases from Sense2Vec (extra candidates for later filtering)
            similar_phrases = s2v.most_similar(formatted_phrase, n=topn * 2)
            similar_keywords = [item[0].split("|")[0].replace("_", " ") for item in similar_phrases]
        else:
            # List similar keys that might exist in the model for exploration
            print(f"'{formatted_phrase}' not found in the model. Exploring similar available keys...")
            available_keys = [key for key in s2v.keys() if phrase.split()[0] in key or phrase.split()[-1] in key]
            print(f"Available keys related to '{phrase}': {available_keys}")

            # Use WordNet to find synonyms if available keys are empty
            if not available_keys:
                print(f"No close match in the model for '{phrase}'. Trying WordNet for synonyms...")
                synonyms = set()
                for syn in wn.synsets(phrase.replace(" ", "_")):
                    for lemma in syn.lemmas():
                        synonyms.add(lemma.name().replace("_", " "))
                similar_keywords = list(synonyms)[:topn * 2] if synonyms else ["No match found"]
            else:
                # Provide available keys as similar suggestions
                similar_keywords = [key.split("|")[0].replace("_", " ") for key in available_keys[:topn * 2]]

        # Filter distractors to match word count, avoid identical or stem-similar words, and check format
        final_distractors = filter_distractors(phrase, similar_keywords, topn)
        # The original also called filter_same_sense_words(phrase, final_distractors)
        # here, but at this point the candidates are plain strings with no "|sense"
        # tag, so that call always warned and returned its input unchanged; it has
        # been dropped as a no-op.

        result_list.append({
            "phrase": phrase,
            "similar_keywords": final_distractors
        })

    return result_list

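# Illustrative call (invented phrase; actual output depends on the Sense2Vec
# vectors loaded at startup):
#   get_distractors(["natural language"], topn=3)
#   -> [{"phrase": "natural language", "similar_keywords": [...]}]
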
def get_mca_questions(context, qg_model, qg_tokenizer, sentence_transformer_model, num_questions=5, max_attempts=2) -> List[Dict]:
    """
    Generate multiple-choice questions for a given context.

    Parameters:
        context (str): The context from which questions are generated.
        qg_model (T5ForConditionalGeneration): The question generation model.
        qg_tokenizer (T5Tokenizer): The tokenizer for the question generation model.
        sentence_transformer_model (SentenceTransformer): The sentence transformer model for embeddings.
        num_questions (int): The number of questions to generate.
        max_attempts (int): The maximum number of attempts to generate questions.

    Returns:
        list: A list of dictionaries with questions and their corresponding distractors.
    """
    output_list = []

    imp_keywords = get_keywords(context)
    print(f"[DEBUG] Length: {len(imp_keywords)}, Extracted keywords: {imp_keywords}")

    generated_questions = set()
    generated_answers = set()
    attempts = 0

    while len(output_list) < num_questions and attempts < max_attempts:
        attempts += 1

        for keyword in imp_keywords:
            if len(output_list) >= num_questions:
                break

            question = get_question(context, keyword, qg_model, qg_tokenizer)
            print(f"[DEBUG] Generated question: '{question}' for keyword: '{keyword}'")

            # Encode the new question
            new_question_embedding = sentence_transformer_model.encode(question, convert_to_tensor=True)
            is_similar = False

            # Check similarity with existing questions
            for generated_q in generated_questions:
                existing_question_embedding = sentence_transformer_model.encode(generated_q, convert_to_tensor=True)
                similarity = cosine_similarity(new_question_embedding.unsqueeze(0), existing_question_embedding.unsqueeze(0))[0][0]

                if similarity > 0.8:
                    is_similar = True
                    print(f"[DEBUG] Question '{question}' is too similar to an existing question, skipping.")
                    break

            if is_similar:
                continue

            # Generate and check answer
            t5_answer = answer_question(question, context)
            print(f"[DEBUG] Generated answer: '{t5_answer}' for question: '{question}'")

            # Skip answers longer than 3 words
            if len(t5_answer.split()) > 3:
                print(f"[DEBUG] Answer '{t5_answer}' is too long, skipping.")
                continue

            if t5_answer in generated_answers:
                print(f"[DEBUG] Answer '{t5_answer}' has already been generated, skipping question.")
                continue

            generated_questions.add(question)
            generated_answers.add(t5_answer)

            # Generate distractors
            distractors = extract_similar_keywords([t5_answer], topn=5)[0]
            print(f"[DEBUG] Distractors: {distractors} (count: {len(distractors)})")

            # Remove any distractor that is the same as the correct answer
            distractors = [d for d in distractors if d.lower() != t5_answer.lower()]
            print(f"[DEBUG] Filtered distractors (without answer): {distractors}")

            # Ensure there are exactly 3 distractors, padding with random
            # keywords from imp_keywords if needed
            while len(distractors) < 3:
                random_keyword = random.choice(imp_keywords)
                # Ensure the random keyword isn't the same as the answer or already a distractor
                if random_keyword.lower() != t5_answer.lower() and random_keyword not in distractors:
                    distractors.append(random_keyword)

            # Limit to 3 distractors
            distractors = distractors[:3]

            print(f"[DEBUG] Final distractors: {distractors} for question: '{question}'")

            choices = distractors + [t5_answer]
            choices = [item.title() for item in choices]
            random.shuffle(choices)
            print(f"[DEBUG] Options: {choices} for answer: '{t5_answer}'")

            # Classify question type
            question_type = classify_question_type(question)

            output_list.append({
                'answer': t5_answer,
                'answer_length': len(t5_answer),
                'choices': choices,
                'passage': context,
                'passage_length': len(context),
                'question': question,
                'question_length': len(question),
                'question_type': question_type
            })

        print(f"[DEBUG] Generated {len(output_list)} questions so far after {attempts} attempts")

    return output_list[:num_questions]
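
# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; the original file ends without
# an entry point, and the sample passage below is invented):
if __name__ == "__main__":
    sample_context = (
        "The Nile is the longest river in Africa. It flows north through "
        "eleven countries before emptying into the Mediterranean Sea."
    )
    mcqs = get_mca_questions(
        sample_context,
        t5qg_model,
        t5qg_tokenizer,
        sentence_transformer_model,
        num_questions=2,
    )
    print(json.dumps(mcqs, indent=2))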