# NOTE: removed non-Python extraction residue ("Spaces: Running Running" hosting-UI status banner).
import pandas as pd | |
import numpy as np | |
from ast import literal_eval | |
import yake | |
import spacy | |
from sklearn.metrics.pairwise import cosine_similarity | |
from sentence_transformers import SentenceTransformer | |
import os | |
class ScriptMatcher:
    """Rank series from a reference dataset by similarity to a new synopsis.

    Pipeline: spaCy masks named entities, YAKE extracts keywords, a
    sentence-transformer embeds the result, and cosine similarity against
    precomputed synopsis/plot embeddings produces the ranking.
    """

    def __init__(self, data_path=None, model_name='paraphrase-mpnet-base-v2', dataframe=None):
        """
        Initialize the ScriptMatcher object.

        Parameters:
            data_path (str, optional): Path to a CSV dataset file.
            model_name (str): Name of the sentence transformer model.
                Default is 'paraphrase-mpnet-base-v2'.
            dataframe (pd.DataFrame, optional): Pre-loaded dataset; takes
                precedence over data_path when both are given.
        """
        if data_path is not None:
            self.dataset = pd.read_csv(data_path)
        if dataframe is not None:
            self.dataset = dataframe
        self.model = SentenceTransformer(model_name)
        # n=1: single-word keywords; dedupLim=0.9: aggressive near-duplicate filtering.
        self.kw_extractor = yake.KeywordExtractor("en", n=1, dedupLim=0.9)
        self.k_dataset = pd.read_csv('models/Similarity_K_Dataset/K_Dataset.csv')
        # spaCy entity labels filtered out of extracted keywords.
        self._ent_type = ["PERSON", "NORP", "FAC", "ORG", "GPE", "LOC", "PRODUCT", "EVENT", "WORK", "ART", "LAW",
                          "LANGUAGE", "DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]
        # BUG FIX: the variable names and .npy filenames were cross-wired
        # (synopsis variable loaded plot_embeddings.npy and vice versa),
        # silently swapping the 0.75/0.25 weights in find_similar_series.
        self.embeddings_synopsis_list = np.load("models/Similarity_K_Dataset/synopsis_embeddings.npy")
        self.plot_embedding_list = np.load("models/Similarity_K_Dataset/plot_embeddings.npy")
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            # spacy.load raises OSError when the model package is missing;
            # the previous bare `except:` also swallowed KeyboardInterrupt/SystemExit.
            print("Downloading spaCy NLP model...")
            os.system(
                "pip install https://huggingface.co/spacy/en_core_web_sm/resolve/main/en_core_web_sm-any-py3-none-any.whl")
            self.nlp = spacy.load("en_core_web_sm")

    def extract_keywords(self, text):
        """
        Extract keywords from a given text using the YAKE keyword extraction algorithm.

        Parameters:
            text (str): Text from which to extract keywords.

        Returns:
            str: A string of extracted keywords joined by spaces, with any
                keyword that equals a spaCy entity label filtered out.
        """
        extracted_keywords = self.kw_extractor.extract_keywords(text)
        return " ".join(kw for kw, _score in extracted_keywords if kw not in self._ent_type)

    def preprocess_text(self, text):
        """
        Process a given text to replace named entities and extract keywords.

        Parameters:
            text (str): The text to process.

        Returns:
            str: Keywords extracted from the text after each named entity
                has been replaced by its label, e.g. "<PERSON>".
        """
        doc = self.nlp(text)
        replaced_text = text
        for token in doc:
            # MISC entities and non-entity tokens are left untouched.
            if token.ent_type_ not in ("", "MISC"):
                replaced_text = replaced_text.replace(token.text, f"<{token.ent_type_}>")
        return self.extract_keywords(replaced_text)

    def find_similar_series(self, new_synopsis, genres_keywords, k=5):
        """
        Find series similar to a new synopsis.

        Parameters:
            new_synopsis (str): The synopsis to compare.
            genres_keywords (Iterable[str]): Genre keywords to prepend to the
                synopsis keywords before embedding.
            k (int): The number of similar series to return. Default 5.

        Returns:
            list[dict]: Records with "Series", "Genre" and "Score" keys for
                the k closest series, best match first.
        """
        processed_synopsis = self.preprocess_text(new_synopsis)
        genre_keywords = " ".join(genres_keywords)
        # BUG FIX: keywords were previously re-extracted from text that was
        # already the output of extract_keywords, and the two parts were
        # concatenated with no separating space.
        synopsis_sentence = f"{genre_keywords} {processed_synopsis}".strip()
        synopsis_embedding = self.model.encode([synopsis_sentence])
        # Weighted blend: synopsis similarity dominates (0.75) over plot (0.25).
        similarity_matrix = (0.75 * cosine_similarity(synopsis_embedding, self.embeddings_synopsis_list)
                             + 0.25 * cosine_similarity(synopsis_embedding, self.plot_embedding_list))
        top_k_indices = similarity_matrix.argsort()[0, -k:][::-1]
        # .copy() avoids pandas SettingWithCopyWarning when adding the Score column.
        closest_series = self.k_dataset.iloc[top_k_indices].copy()
        closest_series["Score"] = similarity_matrix[0, top_k_indices]
        return closest_series[["Series", "Genre", "Score"]].to_dict(orient='records')