Spaces:
Sleeping
Sleeping
#--- | |
#- Author: Jaelin Lee | |
#- Date: Mar 23, 2024 | |
#- Description: Similarity search using BM25. Given the user's input, retrieve the most relevant entry from the knowledge base.
#- How it works: Lowercase and tokenize both the knowledge-base entries and the user input with NLTK's `word_tokenize`. Score every entry against the query with BM25 (a TF-IDF-style ranking function), take the index of the highest-scoring entry with `argmax()`, and return that entry from the knowledge base.
#--- | |
from rank_bm25 import BM25Okapi | |
import nltk | |
from nltk.tokenize import word_tokenize | |
# Make sure the NLTK tokenizer models behind `word_tokenize` are installed.
# Older NLTK releases load "punkt"; recent ones (3.8.2 and later) load
# "punkt_tab" instead, so both are checked. Only missing resources are
# downloaded, rather than contacting the NLTK server on every import.
for _nltk_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find("tokenizers/" + _nltk_resource)
    except LookupError:
        nltk.download(_nltk_resource)
class QuestionRetriever:
    """Pick the stored question that best matches a user query, per category.

    Each supported category has its own knowledge base: a list of questions
    loaded from a text file (one question per line). ``get_response`` ranks
    the questions of the requested category against the query with BM25 and
    returns the top-ranked one.
    """

    # Category label accepted by ``get_response`` -> name of the instance
    # attribute holding that category's question list.
    _CATEGORY_ATTRIBUTES = {
        "depression": "depression_questions",
        "adhd": "adhd_questions",
        "anxiety": "anxiety_questions",
        "social isolation": "social_isolation_questions",
        "cyberbullying": "cyberbullying_questions",
        "social media addiction": "social_media_addiction_questions",
    }

    def __init__(self, data_dir="data"):
        """Load every category's question list.

        Args:
            data_dir: Directory holding the question files. The default,
                ``"data"``, is a relative path and is therefore resolved
                against the current working directory.

        Raises:
            OSError: If a question file cannot be opened (e.g. it is missing).
        """

        def load(file_name):
            return self.load_questions_from_file(f"{data_dir}/{file_name}")

        self.depression_questions = load("depression_questions.txt")
        self.adhd_questions = load("adhd_questions.txt")
        self.anxiety_questions = load("anxiety_questions.txt")
        self.social_isolation_questions = load("social_isolation.txt")
        self.cyberbullying_questions = load("cyberbullying.txt")
        self.social_media_addiction_questions = load("socialmediaaddiction.txt")
        # category -> (snapshot of the questions the index was built from, BM25 index).
        # Lets repeated queries skip re-tokenizing the whole knowledge base.
        self._bm25_cache = {}

    def load_questions_from_file(self, filename):
        """Return the non-blank lines of *filename*, stripped of surrounding whitespace.

        The file is decoded as UTF-8; a leading byte-order mark, if present, is
        dropped (``utf-8-sig``). Blank lines are skipped so they are neither
        indexed as empty documents nor returned as a "question".
        """
        with open(filename, "r", encoding="utf-8-sig") as file:
            stripped = (line.strip() for line in file)
            return [question for question in stripped if question]

    def _get_bm25_index(self, category, knowledge_base):
        """Return a BM25 index over *knowledge_base*, rebuilt only if its contents changed."""
        snapshot = tuple(knowledge_base)
        cached = self._bm25_cache.get(category)
        # Compare contents, not identity, so in-place edits to the list still
        # trigger a rebuild.
        if cached is not None and cached[0] == snapshot:
            return cached[1]
        # Lowercase so matching is case-insensitive (the query is lowercased too).
        tokenized_docs = [word_tokenize(doc.lower()) for doc in snapshot]
        bm25 = BM25Okapi(tokenized_docs)
        self._bm25_cache[category] = (snapshot, bm25)
        return bm25

    def get_response(self, user_query, predicted_mental_category):
        """Return the question from the category's knowledge base that best matches *user_query*.

        Both the query and the stored questions are lowercased and tokenized
        with NLTK's ``word_tokenize`` before BM25 scoring.

        Args:
            user_query: Free-text user input.
            predicted_mental_category: One of the keys of
                ``_CATEGORY_ATTRIBUTES``, matched exactly (case-sensitive).

        Returns:
            The highest-scoring question, or ``None`` when the category is not
            recognized (a notice is printed) or its knowledge base is empty.
            Ties go to the earliest question in the file. For example, a query
            sharing no token with any question scores 0 everywhere and gets
            the first question.
        """
        attribute = self._CATEGORY_ATTRIBUTES.get(predicted_mental_category)
        if attribute is None:
            print("Sorry, I didn't understand that.")
            return None

        knowledge_base = getattr(self, attribute)
        if not knowledge_base:
            return None

        bm25 = self._get_bm25_index(predicted_mental_category, knowledge_base)
        tokenized_query = word_tokenize(user_query.lower())
        doc_scores = bm25.get_scores(tokenized_query)
        # argmax returns the first index among equal maxima.
        return knowledge_base[int(doc_scores.argmax())]
if __name__ == "__main__":
    # Manual smoke test: answer a single query typed on the command line,
    # always looking it up in the "cyberbullying" knowledge base.
    category = "cyberbullying"
    retriever = QuestionRetriever()
    query = input("User: ")
    answer = retriever.get_response(query, category)
    print("Chatbot:", answer)