Spaces:
Build error
Build error
import os | |
import sys | |
import warnings | |
# Suppress specific warnings | |
warnings.filterwarnings("ignore", message="This sequence already has </s>.") | |
# Append path for module imports | |
scripts_path = os.path.abspath(os.path.join('..', 'scripts')) | |
sys.path.append(scripts_path) | |
# Standard library imports | |
import random | |
import string | |
# Third-party imports | |
import json | |
import numpy as np | |
import pandas as pd | |
import torch | |
import nltk | |
from dateutil.parser import parse | |
from nltk.stem import PorterStemmer | |
from nltk.corpus import stopwords, wordnet as wn | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.metrics.pairwise import cosine_similarity | |
# from textdistance import levenshtein | |
from rapidfuzz import fuzz | |
from rapidfuzz.distance import Levenshtein as levenshtein | |
from sense2vec import Sense2Vec | |
from transformers import T5ForConditionalGeneration, T5Tokenizer | |
from sentence_transformers import SentenceTransformer | |
# Download necessary NLTK data | |
nltk.download('omw-1.4') | |
nltk.download('stopwords') | |
nltk.download('punkt') | |
nltk.download('brown') | |
nltk.download('wordnet') | |
from typing import List, Dict | |
import re | |
# Initialize models
# Answer-generation and question-generation T5 checkpoints with their tokenizers.
t5ag_model = T5ForConditionalGeneration.from_pretrained("miiiciiii/I-Comprehend_ag")
t5ag_tokenizer = T5Tokenizer.from_pretrained("miiiciiii/I-Comprehend_ag", legacy=False)
t5qg_model = T5ForConditionalGeneration.from_pretrained("miiiciiii/I-Comprehend_qg")
t5qg_tokenizer = T5Tokenizer.from_pretrained("miiiciiii/I-Comprehend_qg", legacy=False)

# Directory holding the Sense2Vec vectors. The name was referenced below without
# ever being defined, which raised NameError at import time. It is now read from
# the S2V_MODEL_PATH environment variable.
# NOTE(review): the "s2v_old" fallback is an assumption about where the vectors
# are unpacked -- confirm against the deployment.
S2V_MODEL_PATH = os.environ.get("S2V_MODEL_PATH", "s2v_old")
s2v = Sense2Vec().from_disk(S2V_MODEL_PATH)

# Sentence encoder used to detect near-duplicate generated questions.
sentence_transformer_model = SentenceTransformer("sentence-transformers/LaBSE")
def answer_question(question, context):
    """Generate an answer for a given question and context.

    The prompt ``"question: <question> context: <context>"`` is encoded with
    the module-level ``t5ag_tokenizer`` (input truncated to 512 tokens) and
    fed to the module-level ``t5ag_model``; the first generated sequence is
    decoded and returned.

    Parameters:
        question (str): The question to answer.
        context (str): The passage the answer should be drawn from.

    Returns:
        str: The decoded answer without special tokens, passed through
            ``str.capitalize()`` (first character upper-cased, the rest
            lower-cased).
    """
    input_text = f"question: {question} context: {context}"
    input_ids = t5ag_tokenizer.encode(
        input_text, return_tensors="pt", max_length=512, truncation=True
    )
    with torch.no_grad():
        # Only `max_new_tokens` is given. The original also passed
        # `max_length=512`, but the two limits conflict: transformers documents
        # that `max_new_tokens` overrides `max_length` when both are set, so
        # the extra argument only produced a warning.
        output = t5ag_model.generate(
            input_ids, max_new_tokens=200, num_return_sequences=1
        )
    return t5ag_tokenizer.decode(output[0], skip_special_tokens=True).capitalize()
def get_passage(passage):
    """Pick one random row of the dataset and return its ``context`` value.

    Parameters:
        passage: A pandas-style table exposing ``sample`` and a ``context``
            column.

    Returns:
        The ``context`` cell of the single sampled row.
    """
    sampled_row = passage.sample(n=1)
    context_values = sampled_row['context'].values
    return context_values[0]
def get_question(context, answer, model, tokenizer):
    """Generate a question for the given answer and context.

    The first occurrence of ``answer`` inside ``context`` is wrapped in
    ``<hl>`` markers, ``"</s>"`` is appended, and the result is passed to the
    question-generation model.

    Parameters:
        context (str): The passage the question is about.
        answer (str): The answer span to highlight in the passage.
        model: Generation model exposing ``generate(input_ids=..., max_length=...)``.
        tokenizer: Tokenizer that is callable on the prompt and exposes ``decode``.

    Returns:
        str: The decoded question with special tokens removed.
    """
    if answer in context:
        # Exact (case-sensitive) hit: same highlighting as before.
        highlighted = context.replace(answer, f"<hl>{answer}<hl>", 1)
    else:
        # Answers coming from get_keywords() are lower-cased, so a capitalised
        # occurrence in the passage would otherwise never be highlighted and
        # the model would receive no answer marker at all. Fall back to a
        # case-insensitive search and keep the passage's own casing.
        match = re.search(re.escape(answer), context, flags=re.IGNORECASE)
        if match:
            highlighted = (
                context[:match.start()]
                + f"<hl>{match.group(0)}<hl>"
                + context[match.end():]
            )
        else:
            highlighted = context
    answer_span = highlighted + "</s>"
    inputs = tokenizer(answer_span, return_tensors="pt")
    question = model.generate(input_ids=inputs.input_ids, max_length=50)[0]
    return tokenizer.decode(question, skip_special_tokens=True)
def get_keywords(passage):
    """Rank the passage's terms by TF-IDF weight.

    The passage is vectorised as a single document with English stop words
    removed; terms are returned from highest to lowest weight, ties keeping
    the vectorizer's feature order. Any failure is printed and an empty list
    is returned instead.

    Parameters:
        passage (str): Text to extract keywords from.

    Returns:
        list: The terms ordered by descending weight, or ``[]`` on error.
    """
    try:
        tfidf = TfidfVectorizer(stop_words='english')
        weights = tfidf.fit_transform([passage]).toarray().flatten()  # type: ignore
        terms = tfidf.get_feature_names_out()
        # sorted() is stable, so equally weighted terms keep feature order.
        ranked = sorted(zip(terms, weights), key=lambda pair: pair[1], reverse=True)
        return [term for term, _ in ranked]
    except Exception as e:
        print(f"Error extracting keywords: {e}")
        return []
# Keyword groups in priority order (literal > evaluative > inferential).
_QUESTION_TYPE_KEYWORDS = (
    ('literal', (
        'what', 'when', 'where', 'who', 'how many', 'how much',
        'which', 'name', 'list', 'identify', 'define', 'describe',
        'state', 'mention',
    )),
    ('evaluative', (
        'evaluate', 'justify', 'explain why', 'assess', 'critique',
        'discuss', 'judge', 'opinion', 'argue', 'agree or disagree',
        'defend', 'support your answer', 'weigh the pros and cons',
        'compare', 'contrast',
    )),
    ('inferential', (
        'why', 'how', 'what if', 'predict', 'suggest', 'imply',
        'conclude', 'infer', 'reason', 'what might', 'what could',
        'what would happen if', 'speculate', 'deduce', 'interpret',
        'hypothesize', 'assume',
    )),
)


def _compile_keyword_pattern(keywords):
    """Compile one regex matching any of *keywords* as whole words."""
    alternatives = [
        r'\s+'.join(re.escape(part) for part in keyword.split())
        for keyword in keywords
    ]
    # Optional s/d/ed keeps simple inflections ("opinions", "compared")
    # matching, while the word boundaries stop hits inside unrelated words.
    return re.compile(r'\b(?:' + '|'.join(alternatives) + r')(?:s|d|ed)?\b')


def _build_question_type_matchers():
    """Return (label, pattern) pairs: every phrase matcher, then every word matcher."""
    phrase_matchers = []
    word_matchers = []
    for label, keywords in _QUESTION_TYPE_KEYWORDS:
        phrases = [keyword for keyword in keywords if ' ' in keyword]
        words = [keyword for keyword in keywords if ' ' not in keyword]
        if phrases:
            phrase_matchers.append((label, _compile_keyword_pattern(phrases)))
        if words:
            word_matchers.append((label, _compile_keyword_pattern(words)))
    return phrase_matchers + word_matchers


_QUESTION_TYPE_MATCHERS = _build_question_type_matchers()


def classify_question_type(question: str) -> str:
    """
    Classify the type of question as literal, evaluative, or inferential.

    Matching is case-insensitive and on whole words, so 'how' no longer
    fires inside 'show' nor 'who' inside 'whole'. Multi-word phrases are
    tested before single words: previously the literal keyword 'what' was
    checked first and made 'what if', 'what might', 'what could' and
    'what would happen if' unreachable. Within each pass the original
    priority (literal, evaluative, inferential) is kept.

    Parameters:
        question (str): The question to classify.
    Returns:
        str: The type of the question ('literal', 'evaluative',
            'inferential'), or 'unknown' when no keyword matches.
    """
    question_lower = question.lower()
    for label, pattern in _QUESTION_TYPE_MATCHERS:
        if pattern.search(question_lower):
            return label
    # Default to 'unknown' if no pattern matches
    return 'unknown'
def filter_same_sense_words(original, wordlist):
    """Filter words that have the same sense as the original word.

    Keys are expected in ``"phrase|SENSE"`` form. Entries of ``wordlist`` may
    be ``(key, score)`` pairs or bare key strings.

    Parameters:
        original (str): Reference key; its sense is the text between the
            first and second ``|``.
        wordlist (list): Candidate entries to filter.

    Returns:
        list: Title-cased phrases (underscores turned into spaces) of the
            entries whose sense equals the reference sense. If ``original``
            has no sense part a warning is printed and ``wordlist`` is
            returned unchanged.
    """
    try:
        base_sense = original.split('|')[1]  # Ensure there is a sense part
    except IndexError:
        print(f"Warning: The original phrase '{original}' does not have a sense part.")
        return wordlist  # Return all words if the sense part is missing

    same_sense = []
    for entry in wordlist:
        # Accept both (key, score) pairs and bare keys; indexing a bare string
        # with [0] would yield only its first character.
        key = entry if isinstance(entry, str) else entry[0]
        parts = key.split('|')
        if len(parts) < 2:
            # No sense tag to compare: skip instead of raising IndexError.
            continue
        if parts[1] == base_sense:
            same_sense.append(parts[0].replace("_", " ").title().strip())
    return same_sense
def extract_similar_keywords(input_phrases, topn=5):
    """Run get_distractors and keep just each result's ``similar_keywords``.

    Parameters:
        input_phrases (list): Phrases forwarded to ``get_distractors``.
        topn (int): Forwarded to ``get_distractors``.

    Returns:
        list: One ``similar_keywords`` value per result, in result order.
    """
    keyword_lists = []
    for entry in get_distractors(input_phrases, topn):
        keyword_lists.append(entry["similar_keywords"])
    return keyword_lists
def get_max_similarity_score(wordlist, word):
    """Get the maximum similarity score between the word and a list of words.

    Comparison is case-insensitive and uses the normalized Levenshtein
    similarity of the module-level ``levenshtein`` import.

    Parameters:
        wordlist (iterable of str): Words to compare against.
        word (str): The word being scored.

    Returns:
        float: The highest similarity found, or ``0.0`` when ``wordlist`` is
            empty (previously ``max()`` raised ValueError in that case).
    """
    target = word.lower()  # hoisted: identical for every comparison
    return max(
        (levenshtein.normalized_similarity(target, each.lower()) for each in wordlist),
        default=0.0,
    )
def mmr(doc_embedding, word_embeddings, words, top_n, lambda_param):
    """Maximal Marginal Relevance (MMR) for keyword extraction.

    Starts from the word most similar to the document, then repeatedly adds
    the candidate maximising
    ``lambda_param * sim(word, doc) - (1 - lambda_param) * max sim(word, chosen)``.

    Parameters:
        doc_embedding: Document embedding(s) accepted by ``cosine_similarity``
            (assumed to be a single row -- TODO confirm with callers).
        word_embeddings: One embedding row per entry of ``words``.
        words (list): Candidate keywords.
        top_n (int): Number of keywords wanted.
        lambda_param (float): Weight of relevance versus diversity.

    Returns:
        list: The selected words in selection order. At most ``len(words)``
            items are returned; any error is printed and yields ``[]``.
    """
    try:
        word_doc_similarity = cosine_similarity(word_embeddings, doc_embedding)
        word_similarity = cosine_similarity(word_embeddings)
        keywords_idx = [int(np.argmax(word_doc_similarity))]
        candidates_idx = [i for i in range(len(words)) if i != keywords_idx[0]]
        # Never try to pick more words than remain. Previously a top_n larger
        # than len(words) emptied the candidate pool, numpy raised, and the
        # handler below discarded every keyword already selected.
        extra_picks = min(top_n - 1, len(candidates_idx))
        for _ in range(extra_picks):
            candidate_similarities = word_doc_similarity[candidates_idx, :]
            target_similarities = np.max(word_similarity[candidates_idx][:, keywords_idx], axis=1)
            # Named mmr_scores so the local no longer shadows this function.
            mmr_scores = (lambda_param * candidate_similarities) - (
                (1 - lambda_param) * target_similarities.reshape(-1, 1)
            )
            mmr_idx = candidates_idx[int(np.argmax(mmr_scores))]
            keywords_idx.append(mmr_idx)
            candidates_idx.remove(mmr_idx)
        return [words[idx] for idx in keywords_idx]
    except Exception as e:
        print(f"Error in MMR: {e}")
        return []
def format_phrase(phrase):
    """Turn a phrase into a sense-tagged key: spaces become underscores and the default ``|n`` tag is appended."""
    underscored = "_".join(phrase.split(" "))
    return f"{underscored}|n"
def is_valid_distractor(distractor, input_phrase):
    """Accept a distractor made only of ASCII letters and whitespace with one to four words.

    Parameters:
        distractor (str): Candidate distractor text.
        input_phrase (str): Accepted for interface compatibility; not used
            by the check.

    Returns:
        bool: True when the candidate passes both the character and the
            word-count checks.
    """
    only_letters_and_spaces = re.fullmatch(r'[a-zA-Z\s]+', distractor) is not None
    if not only_letters_and_spaces:
        return False
    return 1 <= len(distractor.split()) <= 4
def filter_distractors(input_phrase, similar_keywords, topn):
    """Select up to ``topn`` usable distractors from ``similar_keywords``.

    A candidate is kept when it has as many words as ``input_phrase``,
    differs from it case-insensitively and by Porter stem, passes
    ``is_valid_distractor``, and its stem has not been kept already.

    Parameters:
        input_phrase (str): The phrase distractors are needed for.
        similar_keywords (iterable of str): Candidates in preference order.
        topn (int): Collection stops once this many candidates are kept.

    Returns:
        list: The accepted candidates in their original order.
    """
    stemmer = PorterStemmer()
    phrase_lower = input_phrase.lower()
    phrase_stem = stemmer.stem(phrase_lower)
    expected_words = len(input_phrase.split())

    accepted = []
    accepted_stems = set()  # stems of everything in `accepted`
    for candidate in similar_keywords:
        candidate_lower = candidate.lower()
        candidate_stem = stemmer.stem(candidate_lower)
        usable = (
            len(candidate.split()) == expected_words
            and candidate_lower != phrase_lower
            and candidate_stem != phrase_stem
            and is_valid_distractor(candidate, input_phrase)
        )
        if usable and candidate_stem not in accepted_stems:
            accepted.append(candidate)
            accepted_stems.add(candidate_stem)
        if len(accepted) == topn:
            break
    return accepted
def get_distractors(input_phrases, topn=5):
    """Find similar keywords for a list of input phrases using Sense2Vec and WordNet.

    For each phrase the module-level ``s2v`` model is queried with the key
    produced by ``format_phrase``. If the key is missing, model keys that
    contain the phrase's first or last word are used; if there are none,
    WordNet lemma names serve as candidates. Candidates are then passed
    through ``filter_distractors`` and ``filter_same_sense_words``.

    Parameters:
        input_phrases (iterable of str): Phrases to find distractors for.
        topn (int): Maximum number of distractors per phrase.

    Returns:
        list: One ``{"phrase": ..., "similar_keywords": [...]}`` dict per
            input phrase, in input order.
    """
    result_list = []
    pool_size = topn * 2  # over-fetch so the filters still leave enough
    for phrase in input_phrases:
        formatted_phrase = format_phrase(phrase)
        # Check if the phrase exists in the Sense2Vec model
        if formatted_phrase in s2v:
            # Get similar phrases from Sense2Vec
            similar_phrases = s2v.most_similar(formatted_phrase, n=pool_size)
            similar_keywords = [item[0].split("|")[0].replace("_", " ") for item in similar_phrases]
        else:
            # List similar keys that might exist in the model for exploration
            print(f"'{formatted_phrase}' not found in the model. Exploring similar available keys...")
            phrase_words = phrase.split()
            if phrase_words:
                first_word, last_word = phrase_words[0], phrase_words[-1]
                available_keys = [key for key in s2v.keys() if first_word in key or last_word in key]
            else:
                # Empty or blank phrase: indexing the split result used to
                # raise IndexError here.
                available_keys = []
            print(f"Available keys related to '{phrase}': {available_keys}")
            if available_keys:
                # Provide available keys as similar suggestions
                similar_keywords = [key.split("|")[0].replace("_", " ") for key in available_keys[:pool_size]]
            else:
                # Use WordNet to find synonyms if available keys are empty
                print(f"No close match in the model for '{phrase}'. Trying WordNet for synonyms...")
                synonyms = set()
                for syn in wn.synsets(phrase.replace(" ", "_")):
                    for lemma in syn.lemmas():
                        synonyms.add(lemma.name().replace("_", " "))
                # Sorted so the selection is reproducible between runs. An
                # empty list replaces the old "No match found" placeholder,
                # which could pass the filters and be returned as a real
                # distractor for three-word phrases.
                similar_keywords = sorted(synonyms)[:pool_size]
        # Filter distractors to match word count, avoid identical or stem-similar words, and check format
        final_distractors = filter_distractors(phrase, similar_keywords, topn)
        # NOTE(review): `phrase` normally carries no "|sense" suffix, in which
        # case this call prints a warning and returns the list unchanged.
        final_distractors = filter_same_sense_words(phrase, final_distractors)
        result_list.append({
            "phrase": phrase,
            "similar_keywords": final_distractors
        })
    return result_list
def _is_near_duplicate(embedding, known_embeddings, threshold):
    """Return True if *embedding* is more similar than *threshold* to any known embedding."""
    if not known_embeddings:
        return False
    similarities = cosine_similarity(embedding.reshape(1, -1), np.vstack(known_embeddings))
    return bool(similarities.max() > threshold)


def _pad_distractors(distractors, answer, keyword_pool, wanted):
    """Return at most *wanted* distractors, topped up with random unused keywords."""
    chosen = list(distractors[:wanted])
    missing = wanted - len(chosen)
    if missing <= 0:
        return chosen
    # Compare case-insensitively: choices are title-cased later, so "paris"
    # and "Paris" would end up as duplicate options.
    taken = {answer.lower()} | {item.lower() for item in chosen}
    eligible = []
    for keyword in keyword_pool:
        lowered = keyword.lower()
        if lowered not in taken:
            taken.add(lowered)
            eligible.append(keyword)
    # Sampling from the finite eligible pool always terminates, unlike
    # re-drawing with random.choice until enough distinct keywords appear.
    chosen.extend(random.sample(eligible, min(missing, len(eligible))))
    return chosen


def get_mca_questions(context, qg_model, qg_tokenizer, sentence_transformer_model, num_questions=5, max_attempts=2) -> List[Dict]:
    """
    Generate multiple-choice questions for a given context.

    Keywords from ``get_keywords`` are used one by one as answer hints for
    ``get_question``. A question is dropped when it is a near duplicate of an
    accepted one (cosine similarity above 0.8), when the answer produced by
    ``answer_question`` is longer than three words or already used, or when
    three distractors cannot be assembled. Distractors come from
    ``extract_similar_keywords`` and are padded with random passage keywords.

    Parameters:
        context (str): The context from which questions are generated.
        qg_model (T5ForConditionalGeneration): The question generation model.
        qg_tokenizer (T5Tokenizer): The tokenizer for the question generation model.
        sentence_transformer_model (SentenceTransformer): The sentence transformer model for embeddings.
        num_questions (int): The number of questions to generate.
        max_attempts (int): The maximum number of passes over the keyword list.
    Returns:
        list: At most ``num_questions`` dictionaries with the keys 'answer',
            'answer_length', 'choices', 'passage', 'passage_length',
            'question', 'question_length' and 'question_type'. The lengths
            are character counts.
    """
    similarity_threshold = 0.8
    max_answer_words = 3
    num_distractors = 3

    output_list = []
    imp_keywords = get_keywords(context)
    print(f"[DEBUG] Length: {len(imp_keywords)}, Extracted keywords: {imp_keywords}")
    # Embeddings of accepted questions, cached so each question is encoded
    # once instead of on every later comparison.
    accepted_embeddings = []
    generated_answers = set()
    attempts = 0
    while len(output_list) < num_questions and attempts < max_attempts:
        attempts += 1
        for keyword in imp_keywords:
            if len(output_list) >= num_questions:
                break
            question = get_question(context, keyword, qg_model, qg_tokenizer)
            print(f"[DEBUG] Generated question: '{question}' for keyword: '{keyword}'")
            # Encode the new question as a numpy vector (encode()'s default
            # output), so the sklearn similarity never receives device tensors.
            question_embedding = np.asarray(sentence_transformer_model.encode(question))
            # Check similarity with existing questions
            if _is_near_duplicate(question_embedding, accepted_embeddings, similarity_threshold):
                print(f"[DEBUG] Question '{question}' is too similar to an existing question, skipping.")
                continue
            # Generate and check answer
            t5_answer = answer_question(question, context)
            print(f"[DEBUG] Generated answer: '{t5_answer}' for question: '{question}'")
            # Skip answers longer than 3 words
            if len(t5_answer.split()) > max_answer_words:
                print(f"[DEBUG] Answer '{t5_answer}' is too long, skipping.")
                continue
            if t5_answer in generated_answers:
                print(f"[DEBUG] Answer '{t5_answer}' has already been generated, skipping question.")
                continue
            accepted_embeddings.append(question_embedding)
            generated_answers.add(t5_answer)
            # Generate distractors
            distractors = extract_similar_keywords([t5_answer], topn=5)[0]
            print(f"list of distractors : {distractors}")
            print(f"length of distractors {len(distractors)}")
            print(f"type : {type(distractors)}")
            # Remove any distractor that is the same as the correct answer
            distractors = [d for d in distractors if d.lower() != t5_answer.lower()]
            print(f"Filtered distractors (without answer): {distractors}")
            # Top up with random passage keywords and limit to 3 distractors
            distractors = _pad_distractors(distractors, t5_answer, imp_keywords, num_distractors)
            if len(distractors) < num_distractors:
                # The passage offers too few usable keywords; the original
                # padding loop never terminated in this situation.
                print(f"[DEBUG] Not enough distractors for question: '{question}', skipping.")
                continue
            print(f"[DEBUG] Final distractors: {distractors} for question: '{question}'")
            choices = distractors + [t5_answer]
            choices = [item.title() for item in choices]
            random.shuffle(choices)
            print(f"[DEBUG] Options: {choices} for answer: '{t5_answer}'")
            # Classify question type
            question_type = classify_question_type(question)
            output_list.append({
                'answer': t5_answer,
                'answer_length': len(t5_answer),
                'choices': choices,
                'passage': context,
                'passage_length': len(context),
                'question': question,
                'question_length': len(question),
                'question_type': question_type
            })
        print(f"[DEBUG] Generated {len(output_list)} questions so far after {attempts} attempts")
    return output_list[:num_questions]