Spaces:
Sleeping
Sleeping
File size: 4,660 Bytes
eb1bccb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from typing import Tuple

import gradio as gr
import numpy as np
import pandas as pd
import torch
from sklearn.cluster import KMeans
from transformers import (
    AutoModel,
    AutoModelForMaskedLM,
    AutoTokenizer,
    BertModel,
    BertTokenizer,
    DistilBertModel,
    DistilBertTokenizer,
    RobertaModel,
    RobertaTokenizer,
)
# ---- module-level state ----

# Encoder checkpoints selectable from the UI radio button.
encoder_options = [
    'distilbert-base-uncased',
    'bert-base-uncased',
    'bert-base-cased',
    'roberta-base',
    'xlm-roberta-base',
]

# Populated by update_models() once the user picks an encoder;
# None until then.
tokenizer = None
model = None

# Unique genre names, loaded once at startup from the bundled CSV.
genres = set(pd.read_csv("./all_genres.csv")["genre"].to_list())
def update_models(current_encoder: str) -> None:
    """Load the tokenizer/model pair for the chosen encoder into the globals.

    Args:
        current_encoder: one of the names listed in ``encoder_options``.
            An unknown name leaves the globals untouched.
    """
    global model, tokenizer
    # Dispatch table: encoder name -> (tokenizer class, model class).
    # BUG FIX: xlm-roberta-base previously loaded AutoModelForMaskedLM,
    # whose output object has no ``last_hidden_state`` attribute in the
    # form embed_string() reads; the base AutoModel exposes the encoder
    # hidden states that embed_string() needs.
    loaders = {
        'distilbert-base-uncased': (DistilBertTokenizer, DistilBertModel),
        'bert-base-uncased': (BertTokenizer, BertModel),
        'bert-base-cased': (BertTokenizer, BertModel),
        'roberta-base': (RobertaTokenizer, RobertaModel),
        'xlm-roberta-base': (AutoTokenizer, AutoModel),
    }
    if current_encoder in loaders:
        tokenizer_cls, model_cls = loaders[current_encoder]
        tokenizer = tokenizer_cls.from_pretrained(current_encoder)
        model = model_cls.from_pretrained(current_encoder)
def embed_string() -> np.ndarray:
    """Embed every genre string with the currently selected global model.

    Returns:
        float64 array of shape (num_genres, hidden_dim); rows follow the
        iteration order of the global ``genres`` set.

    Requires ``update_models`` to have been called first, otherwise the
    global ``tokenizer``/``model`` are still None.
    """
    vectors = []
    # Inference only: no_grad skips autograd bookkeeping and saves memory.
    with torch.no_grad():
        for text in genres:
            encoded_input = tokenizer(text, return_tensors='pt')
            output = model(**encoded_input)
            # Use the final token's hidden state as the sentence embedding.
            # NOTE(review): for BERT-style models index -1 is the [SEP]
            # token; [CLS] (index 0) or mean pooling may be preferable —
            # confirm this is intentional.
            last_token = output.last_hidden_state[:, -1, :]
            vectors.append(last_token.flatten().cpu().numpy())
    # Stack the 1-D vectors into an (n, d) matrix; dtype=float keeps the
    # float64 output the original zeros-and-fill loop produced.
    return np.array(vectors, dtype=float)
def gen_clusters(
    input_strs: np.ndarray,
    num_clusters: int
) -> Tuple[KMeans, np.ndarray, float]:
    """Cluster the embedded strings with KMeans.

    Args:
        input_strs: (n_samples, dim) embedding matrix.
        num_clusters: number of clusters to fit.

    Returns:
        Tuple of (fitted KMeans model, predicted label per sample, sum of
        Euclidean distances from each sample to its assigned center).
    """
    clustering_algo = KMeans(n_clusters=num_clusters)
    predicted_labels = clustering_algo.fit_predict(input_strs)
    cluster_error = 0.0
    for i, predicted_label in enumerate(predicted_labels):
        predicted_center = clustering_algo.cluster_centers_[predicted_label, :]
        # BUG FIX: the original called np.square(predicted_center,
        # input_strs[i]) — np.square's second positional argument is the
        # ``out`` parameter, so this computed the wrong value AND silently
        # overwrote the caller's row with the squared center. The intended
        # quantity is the Euclidean distance to the assigned center.
        cluster_error += np.linalg.norm(predicted_center - input_strs[i])
    return clustering_algo, predicted_labels, cluster_error
def view_clusters(predicted_clusters: np.ndarray) -> pd.DataFrame:
    """Arrange genres into one DataFrame column per predicted cluster.

    Args:
        predicted_clusters: cluster label per genre, aligned with the
            iteration order of the global ``genres`` set (the same order
            embed_string() produced the rows in).

    Returns:
        DataFrame with columns ``cluster_0`` .. ``cluster_k``; shorter
        columns are padded with '' to the longest cluster's length.
    """
    # Group genre names by their assigned cluster id.
    mappings = {}
    for cluster_id, genre in zip(predicted_clusters, genres):
        mappings.setdefault(cluster_id, []).append(genre)
    max_len = max(len(members) for members in mappings.values())
    output_df = pd.DataFrame()
    for i in range(max(predicted_clusters) + 1):
        # BUG FIX: .get guards against an empty KMeans cluster, which the
        # original ``mappings[i]`` lookup raised KeyError on.
        column = mappings.get(i, [])
        # Pad so every column has the same length for DataFrame assembly.
        column = column + [''] * (max_len - len(column))
        output_df[f"cluster_{i}"] = column
    return output_df
def add_new_genre(
    new_genre: str = "",
    num_clusters: int = 5,
) -> Tuple[pd.DataFrame, float]:
    """Optionally add a genre, then re-embed and re-cluster everything.

    Args:
        new_genre: genre name to add to the global set; ignored when "".
        num_clusters: number of KMeans clusters to fit.

    Returns:
        Tuple of (cluster membership table, total clustering error) —
        the original annotation claimed only a DataFrame, but two values
        are returned (and wired to two Gradio outputs).
    """
    global genres
    if new_genre != "":
        genres.add(new_genre)
    embedded_genres = embed_string()
    # gen_clusters returns (model, labels, error): the second element is
    # the per-sample labels — the original misnamed it "cluster_centers".
    _, predicted_labels, error = gen_clusters(embedded_genres, num_clusters)
    output_df = view_clusters(predicted_labels)
    return output_df, error
if __name__ == "__main__":
    # BUG FIX: the globals ``model``/``tokenizer`` stayed None until the
    # radio fired a .change event, so pressing Run before picking an
    # encoder crashed inside embed_string(). Default to the first encoder
    # and load it once at startup.
    default_encoder = encoder_options[0]
    update_models(default_encoder)
    with gr.Blocks() as demo:
        # Switching the encoder reloads the global model/tokenizer pair.
        current_encoder = gr.Radio(
            encoder_options,
            value=default_encoder,
            label="Encoder",
        )
        current_encoder.change(fn=update_models, inputs=current_encoder)
        new_genre_input = gr.Textbox(value="", label="New Genre")
        num_clusters_input = gr.Number(
            value=5,
            precision=0,  # integer-only input
            label="Clusters"
        )
        output_clustering = gr.DataFrame()
        output_error = gr.Number(label="Clustering Error", interactive=False)
        encode_button = gr.Button(value="Run")
        encode_button.click(
            fn=add_new_genre,
            inputs=[new_genre_input, num_clusters_input],
            outputs=[output_clustering, output_error],
        )
        demo.launch()
|