Spaces:
Running
Running
"""
Generate and save the embeddings of a pre-defined list of icons.
Compare them with the embeddings of keywords to find the most relevant icons.
"""
import os | |
import pathlib | |
import sys | |
from typing import List, Tuple | |
import numpy as np | |
from sklearn.metrics.pairwise import cosine_similarity | |
from transformers import BertTokenizer, BertModel | |
# Add the parent and grandparent directories (relative to the current working
# directory) to the import path; presumably this is what lets `global_config`
# be found when the module is run from its own sub-directory.
sys.path.extend(['..', '../..'])

from global_config import GlobalConfig


# The tokenizer and the model are loaded once, at import time, and are shared
# by every call to `get_embeddings`.
tokenizer = BertTokenizer.from_pretrained(GlobalConfig.TINY_BERT_MODEL)
model = BertModel.from_pretrained(GlobalConfig.TINY_BERT_MODEL)
def get_icons_list() -> List[str]:
    """
    List the icons that are available for use.

    Every `*.png` entry found directly inside the directory
    `'../' + GlobalConfig.ICONS_DIR` contributes one item. The order of the items is
    whatever order the directory listing yields.

    :return: The icon file names with the `.png` suffix removed.
    """

    icons_dir = pathlib.Path('../' + GlobalConfig.ICONS_DIR)
    icon_names = []

    for png_path in icons_dir.glob('*.png'):
        # Keep only the final path component, without its extension
        icon_names.append(png_path.name.removesuffix('.png'))

    return icon_names
def get_embeddings(texts) -> np.ndarray:
    """
    Generate embeddings for a list of texts using a pre-trained language model.

    Each text is tokenized (truncated to at most 128 tokens), passed through the model, and
    reduced to a single vector by averaging the last hidden states of its real tokens.
    Padding positions are excluded from the average, so the embedding of a text does not
    depend on which other texts happen to share its batch.

    NOTE: Embeddings saved to disk with a different pooling should be regenerated
    (see `save_icons_embeddings`) so that they remain comparable with fresh ones.

    :param texts: A string or a list of strings to be converted into embeddings.
    :type texts: Union[str, List[str]]
    :return: A NumPy array with one row per input text.
    :rtype: numpy.ndarray

    This function does not validate `texts` itself; any input checking is left to the
    tokenizer.

    Example usage:
    >>> keyword = 'neural network'
    >>> file_names = ['neural_network_icon.png', 'data_analysis_icon.png', 'machine_learning.png']
    >>> keyword_embeddings = get_embeddings(keyword)
    >>> file_name_embeddings = get_embeddings(file_names)
    """

    # Imported locally because the rest of the module never references torch directly
    import torch

    inputs = tokenizer(texts, return_tensors='pt', padding=True, max_length=128, truncation=True)

    # Inference only: skip building the autograd graph
    with torch.no_grad():
        hidden_states = model(**inputs).last_hidden_state

        # 1 for real tokens, 0 for padding; broadcast over the hidden dimension
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden_states.dtype)
        token_sums = (hidden_states * mask).sum(dim=1)
        # Guard against a division by zero should a row contain no real tokens
        token_counts = mask.sum(dim=1).clamp(min=1)
        embeddings = token_sums / token_counts

    return embeddings.detach().numpy()
def save_icons_embeddings():
    """
    Compute the embeddings of the available icon names and persist them.

    Two NumPy files are written: the embeddings go to `GlobalConfig.EMBEDDINGS_FILE_NAME`
    and the matching icon names go to `GlobalConfig.ICONS_FILE_NAME`.
    """

    icon_names = get_icons_list()
    print(f'{len(icon_names)} icon files available...')

    icon_embeddings = get_embeddings(icon_names)
    print(f'file_name_embeddings.shape: {icon_embeddings.shape}')

    # Persist the embeddings
    np.save(file=GlobalConfig.EMBEDDINGS_FILE_NAME, arr=icon_embeddings)
    # Persist the names as well, so that a row index can be mapped back to its icon
    np.save(file=GlobalConfig.ICONS_FILE_NAME, arr=icon_names)
def load_saved_embeddings() -> Tuple[np.ndarray, np.ndarray]:
    """
    Read back the previously saved embeddings together with the icon file names.

    :return: A pair: the array loaded from `GlobalConfig.EMBEDDINGS_FILE_NAME`, followed by
     the array loaded from `GlobalConfig.ICONS_FILE_NAME`.
    """

    return (
        np.load(GlobalConfig.EMBEDDINGS_FILE_NAME),
        np.load(GlobalConfig.ICONS_FILE_NAME),
    )
def find_icons(keywords: List[str]) -> List[str]:
    """
    Find relevant icon file names for a list of keywords.

    For each keyword, the saved icon embedding with the highest cosine similarity to the
    keyword's embedding is selected.

    :param keywords: The list of one or more keywords. A single string is treated as a
     one-item list. An empty list yields an empty result.
    :return: A list with the most relevant icon file name for each keyword, in the same
     order as the keywords.
    """

    if isinstance(keywords, str):
        keywords = [keywords]

    # Nothing to look up; also avoids handing an empty batch to the tokenizer
    if not keywords:
        return []

    keyword_embeddings = get_embeddings(keywords)
    file_name_embeddings, file_names = load_saved_embeddings()

    # One row per keyword, one column per icon
    similarities = cosine_similarity(keyword_embeddings, file_name_embeddings)
    best_matches = np.argmax(similarities, axis=-1)

    # Return plain strings in a plain list, as the signature promises
    return [str(file_names[idx]) for idx in best_matches]
def main():
    """
    Demonstrate the module: rebuild the saved embeddings, then look up a few keywords.
    """

    # The saved embeddings need to be rebuilt whenever icons are added or removed
    save_icons_embeddings()

    sample_keywords = [
        'deep learning',
        'library',
        'universe',
        'brain',
        'cybersecurity',
        'gaming',
        '',
    ]
    matched_icons = find_icons(sample_keywords)
    print(f'The relevant icon files are: {matched_icons}')


if __name__ == '__main__':
    main()