Nathan Luskey committed on
Commit eb1bccb
1 Parent(s): e811c53

Created a basic demo

Files changed (4)
  1. .gitignore +163 -0
  2. all_genres.csv +21 -0
  3. app.py +136 -0
  4. genre_clustering.ipynb +220 -0
.gitignore ADDED
@@ -0,0 +1,163 @@
+ # My stuff
+ raw_data/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
all_genres.csv ADDED
@@ -0,0 +1,21 @@
+ genre
+ History
+ Adventure
+ Fantasy
+ Action
+ Crime
+ Family
+ Documentary
+ Foreign
+ Romance
+ Drama
+ Science Fiction
+ Music
+ Western
+ Comedy
+ Thriller
+ War
+ Mystery
+ Horror
+ TV Movie
+ Animation
app.py ADDED
@@ -0,0 +1,136 @@
+ from transformers import DistilBertTokenizer, DistilBertModel, \
+     BertTokenizer, BertModel, \
+     RobertaTokenizer, RobertaModel, \
+     AutoTokenizer, AutoModel
+ import gradio as gr
+ import pandas as pd
+ import numpy as np
+ from typing import Tuple
+ from sklearn.cluster import KMeans
+
+ # global variables
+ encoder_options = [
+     'distilbert-base-uncased',
+     'bert-base-uncased',
+     'bert-base-cased',
+     'roberta-base',
+     'xlm-roberta-base',
+ ]
+
+ tokenizer = None
+ model = None
+
+ genres = pd.read_csv("./all_genres.csv")
+ genres = set(genres["genre"].to_list())
+
+
+ def update_models(current_encoder: str) -> None:
+     global model, tokenizer
+     if current_encoder == 'distilbert-base-uncased':
+         tokenizer = DistilBertTokenizer.from_pretrained(
+             'distilbert-base-uncased'
+         )
+         model = DistilBertModel.from_pretrained('distilbert-base-uncased')
+     elif current_encoder == 'bert-base-uncased':
+         tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+         model = BertModel.from_pretrained('bert-base-uncased')
+     elif current_encoder == 'bert-base-cased':
+         tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+         model = BertModel.from_pretrained('bert-base-cased')
+     elif current_encoder == 'roberta-base':
+         tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
+         model = RobertaModel.from_pretrained('roberta-base')
+     elif current_encoder == 'xlm-roberta-base':
+         tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
+         # use the base encoder so the output exposes last_hidden_state
+         model = AutoModel.from_pretrained('xlm-roberta-base')
+
+
+ def embed_string() -> np.ndarray:
+     output = []
+     for text in genres:
+         encoded_input = tokenizer(text, return_tensors='pt')
+         # forward pass
+         new_output = model(**encoded_input)
+         to_append = new_output.last_hidden_state
+         # keep the hidden state of the final token
+         to_append = to_append[:, -1, :]
+         to_append = to_append.flatten().detach().cpu().numpy()
+         output.append(to_append)
+     np_output = np.zeros((len(output), output[0].shape[0]))
+     for i, vector in enumerate(output):
+         np_output[i, :] = vector
+     return np_output
+
+
+ def gen_clusters(
+     input_strs: np.ndarray,
+     num_clusters: int
+ ) -> Tuple[KMeans, np.ndarray, float]:
+     clustering_algo = KMeans(n_clusters=num_clusters)
+     predicted_labels = clustering_algo.fit_predict(input_strs)
+
+     # total Euclidean distance from each point to its assigned centroid
+     cluster_error = 0.0
+     for i, predicted_label in enumerate(predicted_labels):
+         predicted_center = clustering_algo.cluster_centers_[predicted_label, :]
+         new_error = np.sqrt(np.sum(np.square(predicted_center - input_strs[i])))
+         cluster_error += new_error
+
+     return clustering_algo, predicted_labels, cluster_error
+
+
+ def view_clusters(predicted_clusters: np.ndarray) -> pd.DataFrame:
+     mappings = dict()
+     # set iteration order of genres matches embed_string's within a run,
+     # so labels stay aligned with genre names
+     for predicted_cluster, movie in zip(predicted_clusters, genres):
+         curr_mapping = mappings.get(predicted_cluster, [])
+         curr_mapping.append(movie)
+         mappings[predicted_cluster] = curr_mapping
+
+     output_df = pd.DataFrame()
+     max_len = max([len(x) for x in mappings.values()])
+     max_cluster = max(predicted_clusters)
+
+     for i in range(max_cluster + 1):
+         new_column_name = f"cluster_{i}"
+         new_column_data = mappings[i]
+         # pad shorter clusters so every column has the same length
+         new_column_data.extend([''] * (max_len - len(new_column_data)))
+         output_df[new_column_name] = new_column_data
+
+     return output_df
+
+
+ def add_new_genre(
+     new_genre: str = "",
+     num_clusters: int = 5,
+ ) -> Tuple[pd.DataFrame, float]:
+     global genres
+     if new_genre != "":
+         genres.add(new_genre)
+     embedded_genres = embed_string()
+
+     _, predicted_labels, error = gen_clusters(embedded_genres, num_clusters)
+     output_df = view_clusters(predicted_labels)
+     return output_df, error
+
+
+ if __name__ == "__main__":
+     with gr.Blocks() as demo:
+         current_encoder = gr.Radio(encoder_options, label="Encoder")
+         current_encoder.change(fn=update_models, inputs=current_encoder)
+         new_genre_input = gr.Textbox(value="", label="New Genre")
+         num_clusters_input = gr.Number(
+             value=5,
+             precision=0,
+             label="Clusters"
+         )
+
+         output_clustering = gr.DataFrame()
+         output_error = gr.Number(label="Clustering Error", interactive=False)
+
+         encode_button = gr.Button(value="Run")
+         encode_button.click(
+             fn=add_new_genre,
+             inputs=[new_genre_input, num_clusters_input],
+             outputs=[output_clustering, output_error],
+         )
+
+     demo.launch()
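
A quick way to sanity-check app.py without launching the Gradio UI is to import it and drive the same functions directly. The sketch below is not part of the commit: it assumes the file above is saved as app.py next to all_genres.csv, that transformers, torch, scikit-learn, gradio, and pandas are installed, and "Biography" is just an example genre.

    import app  # the module added in this commit

    app.update_models('distilbert-base-uncased')  # load tokenizer + model once
    df, error = app.add_new_genre("Biography", num_clusters=4)
    print(df)
    print(f"clustering error: {error:.2f}")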
genre_clustering.ipynb ADDED
@@ -0,0 +1,220 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "/Users/nathanluskey/opt/anaconda3/envs/ml_env/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+       "  from .autonotebook import tqdm as notebook_tqdm\n"
+      ]
+     }
+    ],
+    "source": [
+     "from transformers import DistilBertTokenizer, DistilBertModel, \\\n",
+     "    BertTokenizer, BertModel, \\\n",
+     "    RobertaTokenizer, RobertaModel, \\\n",
+     "    AutoTokenizer, AutoModel\n",
+     "import gradio as gr\n",
+     "import pandas as pd\n",
+     "import numpy as np\n",
+     "import torch\n",
+     "from typing import List, Tuple\n",
+     "from sklearn.cluster import KMeans"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# global variables\n",
+     "encoder_options = [\n",
+     "    'distilbert-base-uncased',\n",
+     "    'bert-base-uncased',\n",
+     "    'bert-base-cased',\n",
+     "    'roberta-base',\n",
+     "    'xlm-roberta-base',\n",
+     "]\n",
+     "\n",
+     "current_encoder = encoder_options[0]\n",
+     "tokenizer = None\n",
+     "model = None\n",
+     "\n",
+     "genres = pd.read_csv(\"./all_genres.csv\")\n",
+     "genres = genres[\"genre\"].to_list()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.weight']\n",
+       "- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+       "- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+      ]
+     }
+    ],
+    "source": [
+     "if current_encoder == 'distilbert-base-uncased':\n",
+     "    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')\n",
+     "    model = DistilBertModel.from_pretrained('distilbert-base-uncased')\n",
+     "elif current_encoder == 'bert-base-uncased':\n",
+     "    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
+     "    model = BertModel.from_pretrained('bert-base-uncased')\n",
+     "elif current_encoder == 'bert-base-cased':\n",
+     "    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n",
+     "    model = BertModel.from_pretrained('bert-base-cased')\n",
+     "elif current_encoder == 'roberta-base':\n",
+     "    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')\n",
+     "    model = RobertaModel.from_pretrained('roberta-base')\n",
+     "elif current_encoder == 'xlm-roberta-base':\n",
+     "    tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')\n",
+     "    # use the base encoder so the output exposes last_hidden_state\n",
+     "    model = AutoModel.from_pretrained('xlm-roberta-base')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 10,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def embed_string() -> np.ndarray:\n",
+     "    output = []\n",
+     "    for text in genres:\n",
+     "        encoded_input = tokenizer(text, return_tensors='pt')\n",
+     "        # forward pass\n",
+     "        new_output = model(**encoded_input)\n",
+     "        to_append = new_output.last_hidden_state\n",
+     "        to_append = to_append[:, -1, :]  # take the last token's hidden state\n",
+     "        to_append = to_append.flatten().detach().cpu().numpy()\n",
+     "        output.append(to_append)\n",
+     "    np_output = np.zeros((len(output), output[0].shape[0]))\n",
+     "    for i, vector in enumerate(output):\n",
+     "        np_output[i, :] = vector\n",
+     "    return np_output"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 5,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def gen_clusters(input_strs: np.ndarray, num_clusters: int) -> Tuple[KMeans, np.ndarray, float]:\n",
+     "    clustering_algo = KMeans(n_clusters=num_clusters)\n",
+     "    predicted_labels = clustering_algo.fit_predict(input_strs)\n",
+     "\n",
+     "    # total Euclidean distance from each point to its assigned centroid\n",
+     "    cluster_error = 0.0\n",
+     "    for i, predicted_label in enumerate(predicted_labels):\n",
+     "        predicted_center = clustering_algo.cluster_centers_[predicted_label, :]\n",
+     "        new_error = np.sqrt(np.sum(np.square(predicted_center - input_strs[i])))\n",
+     "        cluster_error += new_error\n",
+     "\n",
+     "    return clustering_algo, predicted_labels, cluster_error"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 16,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def view_clusters(predicted_clusters: np.ndarray) -> pd.DataFrame:\n",
+     "    mappings = dict()\n",
+     "    for predicted_cluster, movie in zip(predicted_clusters, genres):\n",
+     "        curr_mapping = mappings.get(predicted_cluster, [])\n",
+     "        curr_mapping.append(movie)\n",
+     "        mappings[predicted_cluster] = curr_mapping\n",
+     "\n",
+     "    output_df = pd.DataFrame()\n",
+     "    max_len = max([len(x) for x in mappings.values()])\n",
+     "    max_cluster = max(predicted_clusters)\n",
+     "\n",
+     "    for i in range(max_cluster + 1):\n",
+     "        new_column_name = f\"cluster_{i}\"\n",
+     "        new_column_data = mappings[i]\n",
+     "        # pad shorter clusters so every column has the same length\n",
+     "        new_column_data.extend([''] * (max_len - len(new_column_data)))\n",
+     "        output_df[new_column_name] = new_column_data\n",
+     "\n",
+     "    return output_df"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 17,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def add_new_genre(clustering_algo: KMeans, new_genre: str, recompute: bool = False) -> pd.DataFrame:\n",
+     "    global genres\n",
+     "    genres.append(new_genre)\n",
+     "    embedded_genres = embed_string()\n",
+     "    if recompute:\n",
+     "        clustering_algo, predicted_labels, _ = gen_clusters(embedded_genres, 5)\n",
+     "    else:\n",
+     "        # reuse the previously fitted centroids for the new embeddings\n",
+     "        predicted_labels = clustering_algo.predict(embedded_genres)\n",
+     "\n",
+     "    output_df = view_clusters(predicted_labels)\n",
+     "    return output_df"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 18,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "embedded_genres = embed_string()\n",
+     "clustering_algo, predicted_labels, cluster_error = gen_clusters(embedded_genres, 5)\n",
+     "output_df = view_clusters(predicted_labels)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3.10.6 ('ml_env')",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.10.6"
+   },
+   "orig_nbformat": 4,
+   "vscode": {
+    "interpreter": {
+     "hash": "2434bee09bcd67f653a1f2d2df1f4f18cabf9d6c39b42950acaa6ef605d590bc"
+    }
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
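
The notebook's add_new_genre(..., recompute=False) path leans on KMeans.predict to assign new embeddings to the centroids fitted earlier, instead of refitting from scratch. A self-contained sketch of that behavior, with random vectors standing in for genre embeddings (768 matches the BERT-family hidden size but is otherwise arbitrary):

    import numpy as np
    from sklearn.cluster import KMeans

    rng = np.random.default_rng(0)
    embeddings = rng.normal(size=(20, 768))    # stand-ins for 20 genre embeddings
    kmeans = KMeans(n_clusters=5, n_init=10).fit(embeddings)

    new_embedding = rng.normal(size=(1, 768))  # a newly added genre
    label = kmeans.predict(new_embedding)      # nearest fitted centroid, no refit
    print(int(label[0]))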