Spaces:
Running
FoodDesert committed on
Commit • 90290aa
1 Parent(s): e2d3b05
Upload 2 files
- app.py +105 -19
- requirements.txt +1 -0
app.py
CHANGED
@@ -1,5 +1,6 @@
 import gradio as gr
 from sklearn.metrics.pairwise import cosine_similarity
+from scipy.sparse import csr_matrix
 import numpy as np
 from joblib import load
 import h5py
@@ -11,6 +12,8 @@ import compress_fasttext
 from collections import OrderedDict
 from lark import Lark
 from lark import Token
+from lark.exceptions import ParseError
+
 
 
 
@@ -69,12 +72,12 @@ You can read more about TF-IDF on its [Wikipedia page](https://en.wikipedia.org/
 
 ## How does the tag corrector work?
 
-We
+We collect the tag sets from over 4 million e621 posts, treating the tag set from each image as an individual document.
 We then randomly replace about 10% of the tags in each document with a randomly selected alias from e621's list of aliases for the tag
 (e.g. "canine" gets replaced with one of {k9,canines,mongrel,cannine,cnaine,feral_canine,anthro_canine}).
 We then train a FastText (https://fasttext.cc/) model on the documents. The result of this training is a function that maps arbitrary words to vectors such that
 the vector for a tag and the vectors for its aliases are all close together (because the model has seen them in similar contexts).
-Since the lists of aliases contain misspellings and rephrasings of tags, the model should be robust to these kinds of problems.
+Since the lists of aliases contain misspellings and rephrasings of tags, the model should be robust to these kinds of problems as long as they are not too dissimilar from the alias lists.
 """
 
 
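The description in that hunk is the whole training trick, so a minimal sketch may help readers who want to reproduce it. Everything below is illustrative only: the function name, the alias table, and the 10% rate as a constant are assumptions standing in for the Space's real training pipeline, which is not part of this commit.

import random

# Hypothetical sketch of the augmentation described above: each post's tag set
# is one document, and roughly 10% of its tags are swapped for a randomly
# chosen alias before the FastText model is trained on the documents.
def augment_document(tags, aliases, replace_prob=0.10):
    augmented = []
    for tag in tags:
        if tag in aliases and random.random() < replace_prob:
            augmented.append(random.choice(aliases[tag]))  # e.g. "canine" -> "k9"
        else:
            augmented.append(tag)
    return augmented

aliases = {"canine": ["k9", "canines", "mongrel", "cannine", "cnaine", "feral_canine", "anthro_canine"]}
print(augment_document(["canine", "outside", "detailed_background"], aliases))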
@@ -92,6 +95,9 @@ plain: /([^,\\\[\]():|]|\\.)+/
 parser = Lark(grammar, start='start')
 
 
+special_tags = ["score:0", "score:1", "score:2", "score:3", "score:4", "score:5", "score:6", "score:7", "score:8", "score:9"]
+
+
 # Function to extract tags
 def extract_tags(tree):
     tags = []
@@ -107,21 +113,43 @@ def extract_tags(tree):
 
 
 # Load the model and data once at startup
-with h5py.File('
-# Deserialize the vectorizer
+with h5py.File('pca_reduced_artist_data.hdf5', 'r') as f:
     vectorizer_bytes = f['vectorizer'][()].tobytes()
+    # Use io.BytesIO to convert bytes back to a file-like object for joblib to load
     vectorizer_buffer = BytesIO(vectorizer_bytes)
     vectorizer = load(vectorizer_buffer)
 
-    #
-
-
-
+    # Assuming you've saved the PCA mean, components, and the transformed X_artist matrix in the file
+    pca_mean = f['pca_mean'][:]
+    pca_components = f['pca_components'][:]
+    X_artist_reduced = f['X_artist_reduced'][:]
     artist_names = [name.decode() for name in f['artist_names'][:]]
+    # Recreate PCA transformation (not the exact PCA object but its transformation ability)
+    def pca_transform(X):
+        return (X - pca_mean) @ pca_components.T
+
+
+with h5py.File('conditional_tag_probabilities_matrix.h5', 'r') as f:
+    # Reconstruct the sparse co-occurrence matrix
+    conditional_co_occurrence_matrix = csr_matrix(
+        (f['co_occurrence_data'][:], f['co_occurrence_indices'][:], f['co_occurrence_indptr'][:]),
+        shape=f['co_occurrence_shape'][:]
+    )
+
+    # Reconstruct the vocabulary
+    conditional_words = f['vocabulary_words'][:]
+    conditional_indices = f['vocabulary_indices'][:]
+    conditional_vocabulary = {key.decode('utf-8'): value for key, value in zip(conditional_words, conditional_indices)}
+
+    # Load the document count
+    conditional_doc_count = f['doc_count'][()]
+    conditional_smoothing = 100. / conditional_doc_count
+
 
 def clean_tag(tag):
     return ''.join(char for char in tag if ord(char) < 128)
 
+
 #Normally returns tag to aliases, but when reverse=True, returns alias to tags
 def build_aliases_dict(filename, reverse=False):
     aliases_dict = {}
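The hunk above rebuilds a scipy csr_matrix from its raw data/indices/indptr/shape arrays stored in HDF5. For anyone regenerating conditional_tag_probabilities_matrix.h5, here is a minimal round trip of that pattern; the toy matrix and the example file name are assumptions, and only the dataset names mirror the commit.

import h5py
import numpy as np
from scipy.sparse import csr_matrix

m = csr_matrix(np.array([[1, 0, 2], [0, 3, 0]]))

# Write the CSR components as plain arrays (illustrative file name)
with h5py.File('example_sparse.h5', 'w') as f:
    f.create_dataset('co_occurrence_data', data=m.data)
    f.create_dataset('co_occurrence_indices', data=m.indices)
    f.create_dataset('co_occurrence_indptr', data=m.indptr)
    f.create_dataset('co_occurrence_shape', data=m.shape)

# Read them back and reassemble the matrix, as the startup code above does
with h5py.File('example_sparse.h5', 'r') as f:
    m2 = csr_matrix(
        (f['co_occurrence_data'][:], f['co_occurrence_indices'][:], f['co_occurrence_indptr'][:]),
        shape=f['co_occurrence_shape'][:]
    )

assert (m != m2).nnz == 0  # same contents after the round trip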
@@ -138,7 +166,52 @@ def build_aliases_dict(filename, reverse=False):
     return aliases_dict
 
 
-def find_similar_tags(test_tags):
+#Imagine we are adding smoothing_value to the number of times word_j occurs in each document for smoothing.
+#Note the intention is that sum_i(P(word_i|word_j)) =(approx) # of words in a document rather than 1.
+def conditional_probability(word_i, word_j, co_occurrence_matrix, vocabulary, doc_count, smoothing_value=0.01):
+    word_i_index = vocabulary.get(word_i)
+    word_j_index = vocabulary.get(word_j)
+
+    if word_i_index is not None and word_j_index is not None:
+        # Directly access the sparse matrix elements
+        word_j_count = co_occurrence_matrix[word_j_index, word_j_index]
+        smoothed_word_j_count = word_j_count + (smoothing_value * doc_count)
+
+        word_i_count = co_occurrence_matrix[word_i_index, word_i_index]
+
+        co_occurrence_count = co_occurrence_matrix[word_i_index, word_j_index]
+        smoothed_co_occurrence_count = co_occurrence_count + (smoothing_value * word_i_count)
+
+        # Calculate the conditional probability with smoothing
+        conditional_prob = smoothed_co_occurrence_count / smoothed_word_j_count
+
+        return conditional_prob
+    elif word_i_index is None:
+        return 0
+    else:
+        return None
+
+
+#geometric_mean_given_words(target_word, context_words, conditional_co_occurrence_matrix, conditional_vocabulary, conditional_doc_count, smoothing_value=conditional_smoothing):
+def geometric_mean_given_words(target_word, context_words, co_occurrence_matrix, vocabulary, doc_count, smoothing_value=0.01):
+    probabilities = []
+
+    # Collect the conditional probabilities of the target word given each context word, ignoring None values
+    for context_word in context_words:
+        prob = conditional_probability(target_word, context_word, co_occurrence_matrix, vocabulary, doc_count, smoothing_value)
+        if prob is not None:
+            probabilities.append(prob)
+
+    # Compute the geometric mean of the probabilities, avoiding division by zero
+    if probabilities: # Check if the list is not empty
+        geometric_mean = np.prod(probabilities) ** (1.0 / len(probabilities))
+    else:
+        geometric_mean = 0.5 # Or assign some default value if all probabilities are None
+
+    return geometric_mean
+
+
+def find_similar_tags(test_tags, similarity_weight):
 
     #Initialize stuff
     if not hasattr(find_similar_tags, "fasttext_small_model"):
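To make the smoothing in conditional_probability concrete, here is a toy walk-through; every count below is invented for illustration, and the diagonal of the co-occurrence matrix is assumed to hold per-word document counts, as the code above implies.

# Invented toy counts, following conditional_probability above
doc_count = 1000
smoothing_value = 0.01
word_j_count = 200          # documents containing word_j (diagonal entry)
word_i_count = 50           # documents containing word_i (diagonal entry)
co_occurrence_count = 20    # documents containing both

smoothed_word_j_count = word_j_count + smoothing_value * doc_count              # 200 + 10 = 210
smoothed_co_occurrence = co_occurrence_count + smoothing_value * word_i_count   # 20 + 0.5 = 20.5
p_i_given_j = smoothed_co_occurrence / smoothed_word_j_count                    # ~0.098

# geometric_mean_given_words then takes the geometric mean of such probabilities
# over the context tags, e.g. with a second context tag contributing 0.2:
geometric_mean = (p_i_given_j * 0.2) ** (1 / 2)                                  # ~0.14

Since the app passes smoothing_value=conditional_smoothing, which is set to 100. / conditional_doc_count at load time, the added mass in the denominator works out to a flat 100 pseudo-documents regardless of corpus size.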
@@ -149,12 +222,16 @@ def find_similar_tags(test_tags):
     if not hasattr(find_similar_tags, "alias2tags"):
         find_similar_tags.alias2tags = build_aliases_dict(tag_aliases_file, reverse=True)
 
-
+    transformed_tags = [tag.replace(' ', '_') for tag in test_tags]
+
     # Find similar tags and prepare data for dataframe.
     results_data = []
     for tag in test_tags:
+        if tag in special_tags:
+            continue
+
         modified_tag_for_search = tag.replace(' ','_')
-        similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search)
+        similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search, topn = 100)
         result, seen = [], set()
 
         if modified_tag_for_search in find_similar_tags.tag2aliases:
@@ -176,7 +253,15 @@ def find_similar_tags(test_tags):
             result.append((similar_tag.replace('_', ' '), round(similarity, 3)))
             seen.add(similar_tag)
 
+        #Adjust score based on context
+        for i in range(len(result)):
+            word, score = result[i] # Unpack the tuple
+            geometric_mean = geometric_mean_given_words(word.replace(' ','_'), [context_tag for context_tag in transformed_tags if context_tag != word and context_tag != tag], conditional_co_occurrence_matrix, conditional_vocabulary, conditional_doc_count, smoothing_value=conditional_smoothing)
+            adjusted_score = (similarity_weight * geometric_mean) + ((1-similarity_weight)*score) # Apply the adjustment function
+            result[i] = (word, adjusted_score) # Update the tuple with the adjusted score
+
         # Append tag and formatted similar tags to results_data
+        result = sorted(result, key=lambda x: x[1], reverse=True)[:10]
         first_entry_for_tag = True
         for word, sim in result:
             if first_entry_for_tag:
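The adjusted score above is a plain linear blend controlled by the new "Similarity weight" slider. With invented numbers:

similarity_weight = 0.5
score = 0.8            # FastText similarity of the candidate tag
geometric_mean = 0.1   # contextual probability given the other prompt tags
adjusted_score = (similarity_weight * geometric_mean) + ((1 - similarity_weight) * score)  # 0.45

So sliding the weight toward 1 favors candidates that actually co-occur with the rest of the prompt, while 0 falls back to raw embedding similarity; only the top 10 blended scores are kept per tag.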
@@ -191,7 +276,7 @@ def find_similar_tags(test_tags):
 
     return results_data # Return list of lists for Dataframe
 
-def find_similar_artists(new_tags_string, top_n):
+def find_similar_artists(new_tags_string, top_n, similarity_weight):
     try:
         new_tags_string = new_tags_string.lower()
         # Parse the prompt
@@ -201,17 +286,17 @@ def find_similar_artists(new_tags_string, top_n):
         new_image_tags = [tag.replace('_', ' ').strip() for tag in new_image_tags]
 
         ###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
-        unseen_tags_data = find_similar_tags(new_image_tags)
+        unseen_tags_data = find_similar_tags(new_image_tags, similarity_weight)
+
+        X_new_image_transformed = pca_transform(vectorizer.transform([','.join(new_image_tags)]))
+        similarities = cosine_similarity(np.asarray(X_new_image_transformed), np.asarray(X_artist_reduced))[0]
 
-        X_new_image = vectorizer.transform([','.join(new_image_tags)])
-        similarities = cosine_similarity(X_new_image, X_artist)[0]
-
         top_artist_indices = np.argsort(similarities)[-top_n:][::-1]
         top_artists = [(artist_names[i], similarities[i]) for i in top_artist_indices]
-
+
         top_artists_str = "\n".join([f"{rank+1}. {artist[3:]} ({score:.4f})" for rank, (artist, score) in enumerate(top_artists)])
         dynamic_prompts_formatted_artists = "{" + "|".join([artist for artist, _ in top_artists]) + "}"
-
+
         return unseen_tags_data, top_artists_str, dynamic_prompts_formatted_artists
     except ParseError as e:
         return [], "Parse Error: Check for mismatched parentheses or something", ""
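One step in the hunk above deserves a note: pca_transform hand-rolls the projection instead of unpickling a PCA object, which works because scikit-learn's PCA.transform (with whiten=False) is just (X - mean_) @ components_.T. A quick self-contained check on random data, purely illustrative:

import numpy as np
from sklearn.decomposition import PCA

X = np.random.rand(100, 20)
pca = PCA(n_components=5).fit(X)

# Same projection the commit applies to the prompt's TF-IDF vector
manual = (X - pca.mean_) @ pca.components_.T
assert np.allclose(manual, pca.transform(X))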
@@ -221,7 +306,8 @@ iface = gr.Interface(
     fn=find_similar_artists,
     inputs=[
         gr.Textbox(label="Enter image tags", placeholder="e.g. fox, outside, detailed background, ..."),
-        gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
+        gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists"),
+        gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Similarity weight")
     ],
     outputs=[
         gr.Dataframe(label="Unseen Tags", headers=["Tag", "Similar Tags", "Similarity"]),
requirements.txt
CHANGED
@@ -5,3 +5,4 @@ h5py==3.8.0
 joblib==1.2.0
 compress-fasttext
 lark-parser
+scipy