# NOTE(review): the original paste began with a web-page artifact
# ("Spaces: Runtime error") that is not part of the program — replaced
# with this comment so the file parses.
import collections
import os
import pickle
from argparse import Namespace

import numpy as np
import torch
from PIL import Image
from torch import cosine_similarity
from transformers import AutoTokenizer, AutoModel
def download_models():
    """Load the Luotuo-BERT embedding model.

    The transformers hub downloads the weights automatically on first use.
    """
    # Configuration object expected by the luotuo-bert remote code.
    model_args = Namespace(
        do_mlm=None,
        pooler_type="cls",
        temp=0.05,
        mlp_only_train=False,
        init_embeddings_model=None,
    )
    return AutoModel.from_pretrained(
        "silk-road/luotuo-bert",
        trust_remote_code=True,
        model_args=model_args,
    )
class Text:
    """Similarity-retrieval helper over a directory of dialogue text files.

    Extracts lines from raw scripts, pre-computes Luotuo-BERT sentence
    embeddings, and answers text->text and text->image nearest-neighbour
    queries by cosine similarity.
    """

    # Width of the pooler output produced by luotuo-bert.
    # NOTE(review): assumed constant for this model — confirm if the
    # checkpoint changes.
    EMBED_DIM = 1536

    def __init__(self, text_dir, model, num_steps, text_image_pkl_path=None,
                 dict_text_pkl_path=None, pkl_path=None, dict_path=None,
                 image_path=None, maps_path=None):
        self.text_dir = text_dir          # directory of raw dialogue .txt files
        self.model = model                # pre-loaded embedding model (see download_models)
        self.num_steps = num_steps        # max characters kept per text before embedding
        self.pkl_path = pkl_path          # pickle: text -> embedding
        self.dict_path = dict_path        # raw file alternating text / image-name lines
        self.image_path = image_path      # directory containing the .jpg images
        self.maps_path = maps_path        # pickle: list of {titles,id,text} records
        self.dict_text_pkl_path = dict_text_pkl_path    # pickle: image-key text -> embedding
        self.text_image_pkl_path = text_image_pkl_path  # pickle: text -> image name
        self._tokenizer = None            # created lazily in get_embedding()

    def get_embedding(self, texts):
        """Embed a string or list of strings.

        Each text is truncated to ``self.num_steps`` characters. Returns a
        single 1-D tensor for a single input, otherwise a batch tensor.
        """
        # Fix: reuse the instance model and a cached tokenizer; the original
        # re-downloaded both the tokenizer and the whole model on every call.
        tokenizer = getattr(self, "_tokenizer", None)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert")
            self._tokenizer = tokenizer
        texts = texts if isinstance(texts, list) else [texts]
        # Character-level truncation before tokenization.
        texts = [t[:self.num_steps] for t in texts]
        inputs = tokenizer(texts, padding=True, truncation=False,
                           return_tensors="pt")
        with torch.no_grad():
            embeddings = self.model(**inputs, output_hidden_states=True,
                                    return_dict=True,
                                    sent_emb=True).pooler_output
        return embeddings[0] if len(texts) == 1 else embeddings

    def read_text(self, save_embeddings=False, save_maps=False):
        """Extract unique dialogue lines from ``self.text_dir``.

        Returns ``(text_embeddings, data)``: a mapping text -> embedding
        (populated when ``save_embeddings``) and a list of
        ``{"titles", "id", "text"}`` records (populated when ``save_maps``).
        Each is also pickled to ``self.pkl_path`` / ``self.maps_path`` when
        its flag is set.
        """
        text_embeddings = collections.defaultdict()
        seen = set()  # dedupe texts so embeddings and maps stay aligned; set is O(1) vs the original list scan
        data = []
        texts = []
        entry_id = 0  # renamed from `id`, which shadowed the builtin
        for file_name in os.listdir(self.text_dir):
            with open(self.text_dir + '/' + file_name, 'r') as fr:
                for line in fr.readlines():
                    # Lines look like "speaker：「text」"; prefer the
                    # fullwidth colon, fall back to ASCII. (The original
                    # condition was corrupted to compare ':' with ':'.)
                    ch = '：' if '：' in line else ':'
                    if '旁白' in line:  # narration line: keep the whole remainder
                        text = line.strip().split(ch)[1].strip()
                    else:  # dialogue line: strip the surrounding 「」 quotes
                        text = ''.join(list(line.strip().split(ch)[1])[1:-1])
                    if text in seen:
                        continue
                    seen.add(text)
                    if save_maps:
                        # Fix: the original appended an (empty) record even
                        # when save_maps was false.
                        category = collections.defaultdict(str)
                        category["titles"] = file_name.split('.')[0]
                        category["id"] = str(entry_id)
                        category["text"] = text
                        entry_id += 1
                        data.append(dict(category))
                    texts.append(text)
        embeddings = self.get_embedding(texts)
        if save_embeddings:
            # Fix: reuse the batch embeddings computed above; the original
            # discarded them and re-embedded every text one at a time.
            for text, embed in zip(texts, embeddings):
                text_embeddings[text] = embed
            self.store(self.pkl_path, text_embeddings)
        if save_maps:
            self.store(self.maps_path, data)
        return text_embeddings, data

    def load(self, load_pkl=False, load_maps=False, load_dict_text=False,
             load_text_image=False):
        """Unpickle one of the pre-computed files, selected by flag.

        Returns None (after printing a notice) when no matching path is set.
        """
        if self.pkl_path and load_pkl:
            path = self.pkl_path
        elif self.maps_path and load_maps:
            path = self.maps_path
        elif self.dict_text_pkl_path and load_dict_text:
            path = self.dict_text_pkl_path
        elif self.text_image_pkl_path and load_text_image:
            path = self.text_image_pkl_path
        else:
            print("No pkl_path")
            return None
        with open(path, 'rb') as f:
            return pickle.load(f)

    def get_cosine_similarity(self, texts, get_image=False, get_texts=False):
        """Cosine similarity between ``texts[0]`` (the query) and candidates.

        Candidate embeddings come from a pre-computed pickle when
        ``get_image`` / ``get_texts`` is set; otherwise ``texts[1:]`` are
        embedded on the fly, so the query itself is never re-embedded twice.
        """
        if get_image:
            pkl = self.load(load_dict_text=True)
        elif get_texts:
            pkl = self.load(load_pkl=True)
        else:
            pkl = {}
            embeddings = self.get_embedding(texts[1:]).reshape(-1, self.EMBED_DIM)
            # Fix: pair each candidate with its OWN embedding. The original
            # zipped from texts[0], shifting every pair by one (the query got
            # the embedding of texts[1], the last candidate got none).
            for text, embed in zip(texts[1:], embeddings):
                pkl[text] = embed
        query_embedding = self.get_embedding(texts[0]).reshape(1, -1)
        texts_embeddings = np.array(
            [value.numpy().reshape(-1, self.EMBED_DIM) for value in pkl.values()]
        ).squeeze(1)
        return cosine_similarity(query_embedding, torch.from_numpy(texts_embeddings))

    def store(self, path, data):
        """Pickle ``data`` to ``path`` (truncating any existing file)."""
        with open(path, 'wb+') as f:
            pickle.dump(data, f)

    def text_to_image(self, text, save_dict_text=False):
        """Return the PIL image whose caption text best matches ``text``.

        When ``save_dict_text`` is set, first (re)build the text->image-name
        pickle and its embedding pickle from ``self.dict_path`` (a file of
        alternating text lines and image-name lines). Prints a notice and
        returns None when the image or paths are missing.
        """
        if save_dict_text:
            with open(self.dict_path, 'r') as f:
                data = f.readlines()
            # Even lines are caption texts, odd lines the image names.
            text_image = {t.strip(): img.strip()
                          for t, img in zip(data[::2], data[1::2])}
            self.store(self.text_image_pkl_path, text_image)
            keys = list(text_image.keys())
            embeddings = self.get_embedding(keys)
            self.store(self.dict_text_pkl_path, dict(zip(keys, embeddings)))
        if self.dict_path and self.image_path:
            text_image = self.load(load_text_image=True)
            # Query first, candidates after — the layout
            # get_cosine_similarity expects.
            keys = [text] + list(text_image.keys())
            query_similarity = self.get_cosine_similarity(keys, get_image=True)
            key_index = query_similarity.argmax(dim=0).item()
            best_key = list(text_image.keys())[key_index]
            image = text_image[best_key] + '.jpg'
            if image in os.listdir(self.image_path):
                return Image.open(self.image_path + '/' + image)
            print("Image doesn't exist")
        else:
            print("No path")

    def text_to_text(self, text):
        """Return the stored text most similar to ``text``."""
        pkl = self.load(load_pkl=True)
        candidates = [text] + list(pkl.keys())
        similarity = self.get_cosine_similarity(candidates, get_texts=True)
        key_index = similarity.argmax(dim=0).item()
        return list(pkl.keys())[key_index]
# Example usage (kept commented out):
# if __name__ == '__main__':
#     pkl_path = './pkl/texts.pkl'
#     maps_path = './pkl/maps.pkl'
#     text_image_pkl_path = './pkl/text_image.pkl'
#     dict_path = "../characters/haruhi/text_image_dict.txt"
#     dict_text_pkl_path = './pkl/dict_text.pkl'
#     image_path = "../characters/haruhi/images"
#     text_dir = "../characters/haruhi/texts"
#     model = download_models()
#     text = Text(text_dir, text_image_pkl_path=text_image_pkl_path, maps_path=maps_path,
#                 dict_text_pkl_path=dict_text_pkl_path, model=model, num_steps=50,
#                 pkl_path=pkl_path, dict_path=dict_path, image_path=image_path)
#     text.read_text(save_maps=True, save_embeddings=True)
#     data = text.load(load_pkl=True)
#     sub_text = "你好!"
#     image = text.text_to_image(sub_text)
#     print(image)
#     sub_texts = ["hello", "你好"]
#     print(text.get_cosine_similarity(sub_texts))
#     value = text.text_to_text(sub_text)
#     print(value)