Spaces:
Sleeping
Sleeping
from typing import Iterable, Optional, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import torch
from sklearn.cluster import KMeans
from transformers import (
    AutoModel,
    AutoModelForMaskedLM,
    AutoTokenizer,
    BertModel,
    BertTokenizer,
    DistilBertModel,
    DistilBertTokenizer,
    RobertaModel,
    RobertaTokenizer,
)
# global variables

# Encoder names offered in the UI; update_models() loads the matching
# tokenizer/model pair for whichever one is selected.
encoder_options = [
    'distilbert-base-uncased',
    'bert-base-uncased',
    'bert-base-cased',
    'roberta-base',
    'xlm-roberta-base',
]

# Set by update_models(); both stay None until an encoder has been chosen.
tokenizer = None
model = None

# Genre names from the "genre" column of ./all_genres.csv. Missing cells are
# dropped and surrounding whitespace stripped so that only non-empty strings
# are ever handed to the tokenizer.
# NOTE(review): iteration order of a set of str depends on the per-process
# hash seed, so the row order (and k-means output) can differ between runs.
genres = set(
    pd.read_csv("./all_genres.csv")["genre"].dropna().astype(str).str.strip()
)
genres.discard("")
def update_models(current_encoder: str) -> None:
    """Load the tokenizer and encoder named ``current_encoder`` into globals.

    Rebinds the module-level ``tokenizer`` and ``model`` that
    ``embed_string`` uses.

    Args:
        current_encoder: one of the names listed in ``encoder_options``.

    Raises:
        ValueError: if ``current_encoder`` is not a supported encoder name.
    """
    global model, tokenizer
    # name -> (tokenizer class, model class). Every model class must be a
    # bare encoder whose output exposes ``last_hidden_state``, because
    # embed_string reads that attribute.
    loaders = {
        'distilbert-base-uncased': (DistilBertTokenizer, DistilBertModel),
        'bert-base-uncased': (BertTokenizer, BertModel),
        'bert-base-cased': (BertTokenizer, BertModel),
        'roberta-base': (RobertaTokenizer, RobertaModel),
        # AutoModelForMaskedLM returns a MaskedLMOutput (logits), which has
        # no ``last_hidden_state``; AutoModel loads the bare encoder instead.
        'xlm-roberta-base': (AutoTokenizer, AutoModel),
    }
    try:
        tokenizer_cls, model_cls = loaders[current_encoder]
    except KeyError:
        # Previously an unknown name silently kept the old model loaded.
        raise ValueError(
            f"Unsupported encoder: {current_encoder!r}"
        ) from None
    tokenizer = tokenizer_cls.from_pretrained(current_encoder)
    model = model_cls.from_pretrained(current_encoder)
def embed_string() -> np.ndarray:
    """Embed every genre in the module-level ``genres`` set.

    Each genre is tokenized on its own and passed through the currently
    loaded ``model``; the hidden state at the final token position is used
    as that genre's vector.

    Returns:
        float64 array with one row per genre, in the iteration order of
        ``genres`` (an empty ``(0, 0)`` array when there are no genres).
        Callers pairing rows with names must iterate ``genres`` again
        without modifying it in between.

    Raises:
        RuntimeError: if no encoder has been loaded via ``update_models``.
    """
    if tokenizer is None or model is None:
        raise RuntimeError(
            "No encoder loaded: select an encoder before running."
        )
    vectors = []
    # Inference only: don't build an autograd graph for every forward pass.
    with torch.no_grad():
        for text in genres:
            encoded_input = tokenizer(text, return_tensors='pt')
            hidden = model(**encoded_input).last_hidden_state
            # NOTE(review): with the tokenizer's default special tokens the
            # final position is presumably the end/separator token
            # ([SEP] or </s>); index 0 or mean pooling are more common
            # sentence embeddings -- confirm this choice is intended.
            vectors.append(hidden[:, -1, :].flatten().cpu().numpy())
    if not vectors:
        return np.empty((0, 0))
    # The original filled an np.zeros buffer, i.e. float64; keep that dtype.
    return np.stack(vectors).astype(np.float64)
def gen_clusters(
    input_strs: np.ndarray,
    num_clusters: int,
    random_state: Optional[int] = None,
) -> Tuple[KMeans, np.ndarray, float]:
    """Cluster embedding vectors with k-means.

    Args:
        input_strs: 2-D array-like of embeddings, one row per item.
        num_clusters: number of k-means clusters.
        random_state: optional seed forwarded to ``KMeans`` for
            reproducible results; ``None`` keeps the previous behaviour.

    Returns:
        ``(fitted KMeans, predicted labels, error)`` where ``error`` is the
        sum of Euclidean distances from each point to its assigned centre.
        ``input_strs`` is not modified.
    """
    points = np.asarray(input_strs, dtype=float)
    clustering_algo = KMeans(
        n_clusters=int(num_clusters), random_state=random_state
    )
    predicted_labels = clustering_algo.fit_predict(points)
    # The old loop called np.square(center, row), which used `row` as the
    # ufunc's `out` buffer: it overwrote the caller's data and measured only
    # the centre's norm. Use the true point-to-centre distance instead.
    assigned_centers = clustering_algo.cluster_centers_[predicted_labels]
    distances = np.linalg.norm(points - assigned_centers, axis=1)
    cluster_error = float(distances.sum())
    return clustering_algo, predicted_labels, cluster_error
def view_clusters(
    predicted_clusters: np.ndarray,
    names: Optional[Iterable[str]] = None,
) -> pd.DataFrame:
    """Lay out cluster members as a table with one column per cluster.

    Args:
        predicted_clusters: integer cluster label for each item.
        names: item names aligned with ``predicted_clusters``. Defaults to
            the module-level ``genres`` set, whose iteration order matches
            the rows of ``embed_string`` as long as it is not modified in
            between.

    Returns:
        DataFrame with columns ``cluster_0`` .. ``cluster_<max label>``.
        Each column lists its members in input order, padded with ``''`` to
        the size of the largest cluster. A label with no members gives an
        all-``''`` column, and empty input gives an empty DataFrame.

    Raises:
        ValueError: if ``names`` and ``predicted_clusters`` differ in length.
    """
    labels = [int(label) for label in predicted_clusters]
    members = list(genres if names is None else names)
    if len(members) != len(labels):
        # zip() would otherwise silently drop the unmatched tail.
        raise ValueError(
            f"got {len(labels)} cluster labels for {len(members)} names"
        )
    if not labels:
        return pd.DataFrame()

    mappings = {}
    for label, member in zip(labels, members):
        mappings.setdefault(label, []).append(member)

    max_len = max(len(group) for group in mappings.values())
    columns = {}
    for i in range(max(labels) + 1):
        # .get(): k-means can leave a label unused; indexing raised KeyError.
        group = mappings.get(i, [])
        columns[f"cluster_{i}"] = group + [''] * (max_len - len(group))
    return pd.DataFrame(columns)
def add_new_genre(
    new_genre: str = "",
    num_clusters: int = 5,
) -> Tuple[pd.DataFrame, float]:
    """Optionally add a genre, then re-embed and re-cluster all genres.

    Args:
        new_genre: genre to add to the module-level ``genres`` set. ``None``,
            empty, or whitespace-only input adds nothing; otherwise the
            stripped text is added.
        num_clusters: number of k-means clusters.

    Returns:
        ``(clusters_df, error)``: the per-cluster table from
        ``view_clusters`` and the clustering error from ``gen_clusters``.
    """
    cleaned = (new_genre or "").strip()
    if cleaned:
        # Mutating the set in place needs no `global` declaration.
        genres.add(cleaned)
    embedded_genres = embed_string()
    _, predicted_labels, error = gen_clusters(
        embedded_genres, int(num_clusters)
    )
    output_df = view_clusters(predicted_labels)
    return output_df, error
def _build_demo() -> gr.Blocks:
    """Assemble the Gradio UI and wire its events.

    Components are created in this order: encoder radio, genre textbox,
    cluster-count number, result table, error number, run button.
    """
    with gr.Blocks() as app:
        encoder_picker = gr.Radio(encoder_options, label="Encoder")
        # Swap the loaded tokenizer/model whenever the selection changes.
        encoder_picker.change(fn=update_models, inputs=encoder_picker)

        genre_box = gr.Textbox(value="", label="New Genre")
        clusters_box = gr.Number(value=5, precision=0, label="Clusters")

        result_table = gr.DataFrame()
        error_box = gr.Number(label="Clustering Error", interactive=False)

        run_button = gr.Button(value="Run")
        run_button.click(
            fn=add_new_genre,
            inputs=[genre_box, clusters_box],
            outputs=[result_table, error_box],
        )
    return app


if __name__ == "__main__":
    demo = _build_demo()
    demo.launch()