FoodDesert committed on
Commit 90290aa
1 Parent(s): e2d3b05

Upload 2 files

Files changed (2)
  1. app.py +105 -19
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,5 +1,6 @@
import gradio as gr
from sklearn.metrics.pairwise import cosine_similarity
+ from scipy.sparse import csr_matrix
import numpy as np
from joblib import load
import h5py
@@ -11,6 +12,8 @@ import compress_fasttext
from collections import OrderedDict
from lark import Lark
from lark import Token
+ from lark.exceptions import ParseError
+



@@ -69,12 +72,12 @@ You can read more about TF-IDF on its [Wikipedia page](https://en.wikipedia.org/

## How does the tag corrector work?

- We collected the tag sets from over 4 million e621 posts, treating the tag set from each image as an individual document.
+ We collect the tag sets from over 4 million e621 posts, treating the tag set from each image as an individual document.
We then randomly replace about 10% of the tags in each document with a randomly selected alias from e621's list of aliases for the tag
(e.g. "canine" gets replaced with one of {k9,canines,mongrel,cannine,cnaine,feral_canine,anthro_canine}).
We then train a FastText (https://fasttext.cc/) model on the documents. The result of this training is a function that maps arbitrary words to vectors such that
the vector for a tag and the vectors for its aliases are all close together (because the model has seen them in similar contexts).
- Since the lists of aliases contain misspellings and rephrasings of tags, the model should be robust to these kinds of problems.
+ Since the lists of aliases contain misspellings and rephrasings of tags, the model should be robust to these kinds of problems as long as they are not too dissimilar from the alias lists.
"""


@@ -92,6 +95,9 @@ plain: /([^,\\\[\]():|]|\\.)+/
parser = Lark(grammar, start='start')


+ special_tags = ["score:0", "score:1", "score:2", "score:3", "score:4", "score:5", "score:6", "score:7", "score:8", "score:9"]
+
+
# Function to extract tags
def extract_tags(tree):
tags = []
@@ -107,21 +113,43 @@ def extract_tags(tree):


# Load the model and data once at startup
- with h5py.File('complete_artist_data.hdf5', 'r') as f:
- # Deserialize the vectorizer
+ with h5py.File('pca_reduced_artist_data.hdf5', 'r') as f:
vectorizer_bytes = f['vectorizer'][()].tobytes()
+ # Use io.BytesIO to convert bytes back to a file-like object for joblib to load
vectorizer_buffer = BytesIO(vectorizer_bytes)
vectorizer = load(vectorizer_buffer)

- # Load X_artist
- X_artist = f['X_artist'][:]
-
- # Load artist names and decode to strings
+ # Assuming you've saved the PCA mean, components, and the transformed X_artist matrix in the file
+ pca_mean = f['pca_mean'][:]
+ pca_components = f['pca_components'][:]
+ X_artist_reduced = f['X_artist_reduced'][:]
artist_names = [name.decode() for name in f['artist_names'][:]]
+ # Recreate PCA transformation (not the exact PCA object but its transformation ability)
+ def pca_transform(X):
+ return (X - pca_mean) @ pca_components.T
+
+
+ with h5py.File('conditional_tag_probabilities_matrix.h5', 'r') as f:
+ # Reconstruct the sparse co-occurrence matrix
+ conditional_co_occurrence_matrix = csr_matrix(
+ (f['co_occurrence_data'][:], f['co_occurrence_indices'][:], f['co_occurrence_indptr'][:]),
+ shape=f['co_occurrence_shape'][:]
+ )
+
+ # Reconstruct the vocabulary
+ conditional_words = f['vocabulary_words'][:]
+ conditional_indices = f['vocabulary_indices'][:]
+ conditional_vocabulary = {key.decode('utf-8'): value for key, value in zip(conditional_words, conditional_indices)}
+
+ # Load the document count
+ conditional_doc_count = f['doc_count'][()]
+ conditional_smoothing = 100. / conditional_doc_count
+

def clean_tag(tag):
return ''.join(char for char in tag if ord(char) < 128)

+
#Normally returns tag to aliases, but when reverse=True, returns alias to tags
def build_aliases_dict(filename, reverse=False):
aliases_dict = {}
@@ -138,7 +166,52 @@ def build_aliases_dict(filename, reverse=False):
return aliases_dict


- def find_similar_tags(test_tags):
+ #Imagine we are adding smoothing_value to the number of times word_j occurs in each document for smoothing.
+ #Note the intention is that sum_i(P(word_i|word_j)) =(approx) # of words in a document rather than 1.
+ def conditional_probability(word_i, word_j, co_occurrence_matrix, vocabulary, doc_count, smoothing_value=0.01):
+ word_i_index = vocabulary.get(word_i)
+ word_j_index = vocabulary.get(word_j)
+
+ if word_i_index is not None and word_j_index is not None:
+ # Directly access the sparse matrix elements
+ word_j_count = co_occurrence_matrix[word_j_index, word_j_index]
+ smoothed_word_j_count = word_j_count + (smoothing_value * doc_count)
+
+ word_i_count = co_occurrence_matrix[word_i_index, word_i_index]
+
+ co_occurrence_count = co_occurrence_matrix[word_i_index, word_j_index]
+ smoothed_co_occurrence_count = co_occurrence_count + (smoothing_value * word_i_count)
+
+ # Calculate the conditional probability with smoothing
+ conditional_prob = smoothed_co_occurrence_count / smoothed_word_j_count
+
+ return conditional_prob
+ elif word_i_index is None:
+ return 0
+ else:
+ return None
+
+
+ #geometric_mean_given_words(target_word, context_words, conditional_co_occurrence_matrix, conditioanl_vocabulary, conditional_doc_count, smoothing_value=conditional_smoothing):
+ def geometric_mean_given_words(target_word, context_words, co_occurrence_matrix, vocabulary, doc_count, smoothing_value=0.01):
+ probabilities = []
+
+ # Collect the conditional probabilities of the target word given each context word, ignoring None values
+ for context_word in context_words:
+ prob = conditional_probability(target_word, context_word, co_occurrence_matrix, vocabulary, doc_count, smoothing_value)
+ if prob is not None:
+ probabilities.append(prob)
+
+ # Compute the geometric mean of the probabilities, avoiding division by zero
+ if probabilities: # Check if the list is not empty
+ geometric_mean = np.prod(probabilities) ** (1.0 / len(probabilities))
+ else:
+ geometric_mean = 0.5 # Or assign some default value if all probabilities are None
+
+ return geometric_mean
+
+
+ def find_similar_tags(test_tags, similarity_weight):

#Initialize stuff
if not hasattr(find_similar_tags, "fasttext_small_model"):
@@ -149,12 +222,16 @@ def find_similar_tags(test_tags):
if not hasattr(find_similar_tags, "alias2tags"):
find_similar_tags.alias2tags = build_aliases_dict(tag_aliases_file, reverse=True)

-
+ transformed_tags = [tag.replace(' ', '_') for tag in test_tags]
+
# Find similar tags and prepare data for dataframe.
results_data = []
for tag in test_tags:
+ if tag in special_tags:
+ continue
+
modified_tag_for_search = tag.replace(' ','_')
- similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search)
+ similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search, topn = 100)
result, seen = [], set()

if modified_tag_for_search in find_similar_tags.tag2aliases:
@@ -176,7 +253,15 @@ def find_similar_tags(test_tags):
result.append((similar_tag.replace('_', ' '), round(similarity, 3)))
seen.add(similar_tag)

+ #Adjust score based on context
+ for i in range(len(result)):
+ word, score = result[i] # Unpack the tuple
+ geometric_mean = geometric_mean_given_words(word.replace(' ','_'), [context_tag for context_tag in transformed_tags if context_tag != word and context_tag != tag], conditional_co_occurrence_matrix, conditional_vocabulary, conditional_doc_count, smoothing_value=conditional_smoothing)
+ adjusted_score = (similarity_weight * geometric_mean) + ((1-similarity_weight)*score) # Apply the adjustment function
+ result[i] = (word, adjusted_score) # Update the tuple with the adjusted score
+
# Append tag and formatted similar tags to results_data
+ result = sorted(result, key=lambda x: x[1], reverse=True)[:10]
first_entry_for_tag = True
for word, sim in result:
if first_entry_for_tag:
@@ -191,7 +276,7 @@ def find_similar_tags(test_tags):

return results_data # Return list of lists for Dataframe

- def find_similar_artists(new_tags_string, top_n):
+ def find_similar_artists(new_tags_string, top_n, similarity_weight):
try:
new_tags_string = new_tags_string.lower()
# Parse the prompt
@@ -201,17 +286,17 @@ def find_similar_artists(new_tags_string, top_n):
new_image_tags = [tag.replace('_', ' ').strip() for tag in new_image_tags]

###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
- unseen_tags_data = find_similar_tags(new_image_tags)
+ unseen_tags_data = find_similar_tags(new_image_tags, similarity_weight)
+
+ X_new_image_transformed = pca_transform(vectorizer.transform([','.join(new_image_tags)]))
+ similarities = cosine_similarity(np.asarray(X_new_image_transformed), np.asarray(X_artist_reduced))[0]

- X_new_image = vectorizer.transform([','.join(new_image_tags)])
- similarities = cosine_similarity(X_new_image, X_artist)[0]
-
top_artist_indices = np.argsort(similarities)[-top_n:][::-1]
top_artists = [(artist_names[i], similarities[i]) for i in top_artist_indices]
-
+
top_artists_str = "\n".join([f"{rank+1}. {artist[3:]} ({score:.4f})" for rank, (artist, score) in enumerate(top_artists)])
dynamic_prompts_formatted_artists = "{" + "|".join([artist for artist, _ in top_artists]) + "}"
-
+
return unseen_tags_data, top_artists_str, dynamic_prompts_formatted_artists
except ParseError as e:
return [], "Parse Error: Check for mismatched parentheses or something", ""
@@ -221,7 +306,8 @@ iface = gr.Interface(
fn=find_similar_artists,
inputs=[
gr.Textbox(label="Enter image tags", placeholder="e.g. fox, outside, detailed background, ..."),
- gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
+ gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists"),
+ gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Similarity weight")
],
outputs=[
gr.Dataframe(label="Unseen Tags", headers=["Tag", "Similar Tags", "Similarity"]),
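The docstring in app.py describes the tag-corrector training only in prose. Below is a minimal sketch of the alias-corruption step it describes, assuming gensim's FastText API and hypothetical `tag_documents` / `aliases` inputs; the actual training script is not part of this commit.

```python
import random
from gensim.models import FastText

def corrupt_documents(tag_documents, aliases, p=0.1, seed=0):
    """Randomly swap ~10% of tags for one of their e621 aliases.

    tag_documents: list of tag lists, one per post (hypothetical input).
    aliases: dict mapping a tag to its list of aliases (hypothetical input).
    """
    rng = random.Random(seed)
    corrupted = []
    for doc in tag_documents:
        new_doc = []
        for tag in doc:
            if aliases.get(tag) and rng.random() < p:
                new_doc.append(rng.choice(aliases[tag]))
            else:
                new_doc.append(tag)
        corrupted.append(new_doc)
    return corrupted

# Training on the corrupted documents pushes a tag and its aliases toward
# nearby vectors, which is what most_similar() relies on in app.py, e.g.:
# model = FastText(sentences=corrupt_documents(tag_documents, aliases), vector_size=100)
```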
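app.py now loads `pca_mean`, `pca_components`, and `X_artist_reduced` instead of the full `X_artist` matrix and re-applies the projection with `pca_transform`. A sketch of how such a file could be produced with scikit-learn's PCA, under the assumption that the export looks like this (the export script is not included in this commit, and the real file also stores the vectorizer and artist names):

```python
import numpy as np
import h5py
from sklearn.decomposition import PCA

# Stand-in for the dense artist/tag TF-IDF matrix (hypothetical shape).
X_artist = np.random.rand(50, 200)

pca = PCA(n_components=32)
X_artist_reduced = pca.fit_transform(X_artist)

with h5py.File('pca_reduced_artist_data.hdf5', 'w') as f:
    f.create_dataset('pca_mean', data=pca.mean_)
    f.create_dataset('pca_components', data=pca.components_)
    f.create_dataset('X_artist_reduced', data=X_artist_reduced)

# With whiten=False, sklearn's transform is (X - mean_) @ components_.T,
# which is exactly what pca_transform() in app.py recomputes.
assert np.allclose(pca.transform(X_artist), (X_artist - pca.mean_) @ pca.components_.T)
```

Note that vectorizer.transform() returns a sparse matrix; subtracting the dense mean densifies it into a numpy matrix, which is why app.py wraps the result in np.asarray() before calling cosine_similarity().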
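The smoothing comment above conditional_probability can be read as: add `smoothing_value` virtual occurrences of word_j to every document, so a candidate tag that never co-occurs with a context tag still gets a small nonzero probability instead of zeroing out the geometric mean. A toy check of the formula with made-up counts (diagonal entries are per-tag document counts):

```python
import numpy as np
from scipy.sparse import csr_matrix

co = csr_matrix(np.array([[40, 10],
                          [10, 25]]))   # made-up co-occurrence counts
doc_count = 100
s = 100. / doc_count                     # mirrors conditional_smoothing in app.py

# P(tag_0 | tag_1) = (C[0,1] + s*C[0,0]) / (C[1,1] + s*doc_count)
p = (co[0, 1] + s * co[0, 0]) / (co[1, 1] + s * doc_count)
print(round(p, 3))  # 0.4 with these numbers

# The Similarity weight slider then blends this kind of context score with the
# raw FastText similarity: adjusted = w * geometric_mean + (1 - w) * score.
```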
requirements.txt CHANGED
@@ -5,3 +5,4 @@ h5py==3.8.0
joblib==1.2.0
compress-fasttext
lark-parser
+ scipy
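scipy is added for scipy.sparse.csr_matrix, which app.py uses to rebuild the conditional co-occurrence matrix from conditional_tag_probabilities_matrix.h5. Below is a sketch of the matching write side, assuming the dataset names app.py reads; the script that actually produces the file is not part of this commit.

```python
import h5py
import numpy as np
from scipy.sparse import csr_matrix

def save_co_occurrence(path, matrix, vocabulary, doc_count):
    # matrix: tag x tag co-occurrence counts; vocabulary: dict tag -> row index.
    matrix = csr_matrix(matrix)
    words = list(vocabulary)
    with h5py.File(path, 'w') as f:
        # Store the CSR components so the matrix can be rebuilt without densifying it.
        f.create_dataset('co_occurrence_data', data=matrix.data)
        f.create_dataset('co_occurrence_indices', data=matrix.indices)
        f.create_dataset('co_occurrence_indptr', data=matrix.indptr)
        f.create_dataset('co_occurrence_shape', data=np.array(matrix.shape))
        f.create_dataset('vocabulary_words', data=np.array([w.encode('utf-8') for w in words]))
        f.create_dataset('vocabulary_indices', data=np.array([vocabulary[w] for w in words]))
        f.create_dataset('doc_count', data=doc_count)
```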