# Hansimov, commit 93e2e2a
# :boom: [Fix] Permission denied when saving onnx model
import os
import numpy as np
import torch
from pathlib import Path
from typing import Union
from huggingface_hub import hf_hub_download
from numpy.linalg import norm
from onnxruntime import InferenceSession
from tclogger import logger
from transformers import AutoTokenizer, AutoModel
from configs.envs import ENVS
from configs.constants import AVAILABLE_MODELS
# Forward the Hugging Face settings from the project config into the process
# environment. os.environ only accepts str values, so each key is copied only
# when the configured value is non-empty; an unset token therefore neither
# raises TypeError nor overwrites a token already present in the environment.
# NOTE(review): huggingface_hub is imported above and appears to read
# HF_ENDPOINT at import time -- confirm this assignment still takes effect.
if ENVS["HF_ENDPOINT"]:
    os.environ["HF_ENDPOINT"] = ENVS["HF_ENDPOINT"]
if ENVS["HF_TOKEN"]:
    os.environ["HF_TOKEN"] = ENVS["HF_TOKEN"]
def cosine_similarity(a, b):
    """Return ``a @ b.T`` divided by the product of the norms of *a* and *b*.

    ``norm`` is called without an axis, so each input is reduced to a single
    scalar norm over all of its elements.
    """
    dot_product = a @ b.T
    scale = norm(a) * norm(b)
    return dot_product / scale
class JinaAIOnnxEmbedder:
    """Embed text with the quantized ONNX build of jina-embeddings-v2-base-zh.

    The ONNX file is fetched from the Hugging Face Hub when it is not already
    present locally, executed with onnxruntime, and the resulting token
    embeddings are mean-pooled using the attention mask.

    Reference:
    https://huggingface.co/jinaai/jina-embeddings-v2-base-zh/discussions/6#65bc55a854ab5eb7b6300893
    """

    def __init__(self):
        self.repo_name = "jinaai/jina-embeddings-v2-base-zh"
        self.download_model()
        self.load_model()

    def download_model(self):
        """Make sure the ONNX model file exists locally, downloading it if needed.

        Sets ``onnx_folder``, ``onnx_filename`` and ``onnx_path`` on the instance.
        """
        # The model lives in a ".cache" folder under Path(__file__).parents[2].
        self.onnx_folder = Path(__file__).parents[2] / ".cache"
        self.onnx_folder.mkdir(parents=True, exist_ok=True)
        self.onnx_filename = "onnx/model_quantized.onnx"
        self.onnx_path = self.onnx_folder / self.onnx_filename

        if not self.onnx_path.exists():
            logger.note("> Downloading ONNX model")
            # The existence check above relies on hf_hub_download placing the
            # file at onnx_folder / onnx_filename when local_dir is given.
            hf_hub_download(
                repo_id=self.repo_name,
                filename=self.onnx_filename,
                local_dir=self.onnx_folder,
                local_dir_use_symlinks=False,
            )
            logger.success(f"+ ONNX model downloaded: {self.onnx_path}")
        else:
            logger.success(f"+ ONNX model loaded: {self.onnx_path}")

    def load_model(self):
        """Load the tokenizer for ``repo_name`` and open the ONNX session."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.repo_name, trust_remote_code=True
        )
        # Pass a plain string: older onnxruntime releases only accept
        # str/bytes here and reject pathlib.Path objects.
        self.session = InferenceSession(str(self.onnx_path))

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings along dim 1, weighted by the attention mask.

        Args:
            model_output: torch tensor of token embeddings
                (assumed (batch, tokens, dim) -- TODO confirm).
            attention_mask: torch tensor that, after gaining a trailing axis,
                can be expanded to ``model_output``'s size.

        Returns:
            The masked sum over dim 1 divided by the mask sum over dim 1.
        """
        token_embeddings = model_output
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp so that an all-zero mask does not divide by zero.
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    def encode(self, text: str):
        """Tokenize *text*, run the ONNX session and return the pooled embedding.

        Returns:
            A torch tensor produced by ``mean_pooling`` over the session's
            ``last_hidden_state`` output.
        """
        tokenized = self.tokenizer(text, return_tensors="np")
        # Every tokenizer output is cast to int64 before being fed to the session.
        inputs = {
            name: np.array(tensor, dtype=np.int64)
            for name, tensor in tokenized.items()
        }
        outputs = self.session.run(
            output_names=["last_hidden_state"], input_feed=inputs
        )
        embeddings = self.mean_pooling(
            torch.from_numpy(outputs[0]), torch.from_numpy(inputs["attention_mask"])
        )
        return embeddings
class JinaAIEmbedder:
    """Embed text with a model loaded through transformers ``AutoModel``.

    Only names listed in ``AVAILABLE_MODELS`` are loaded; any other name falls
    back to ``AVAILABLE_MODELS[0]``.
    """

    def __init__(self, model_name: str = AVAILABLE_MODELS[0]):
        self.model_name = model_name
        self.load_model()

    def check_model_name(self):
        """Replace an unknown ``model_name`` with the default; always return True."""
        if self.model_name not in AVAILABLE_MODELS:
            self.model_name = AVAILABLE_MODELS[0]
        return True

    def load_model(self):
        """Validate ``model_name`` and load the matching model into ``self.model``."""
        self.check_model_name()
        self.model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)

    def switch_model(self, model_name: str):
        """Switch to *model_name*, reloading only if the effective model changes.

        An unknown name resolves to the default model, exactly as
        ``check_model_name`` would do. Resolving it before the comparison
        avoids reloading a model that is already loaded.
        """
        if model_name not in AVAILABLE_MODELS:
            model_name = AVAILABLE_MODELS[0]
        if model_name != self.model_name:
            self.model_name = model_name
            self.load_model()

    def encode(self, text: Union[str, list[str]]):
        """Encode one string or a list of strings with the loaded model.

        A single string is wrapped in a one-element list before being passed
        to ``self.model.encode``, whose result is returned unchanged.
        """
        if isinstance(text, str):
            text = [text]
        return self.model.encode(text)
if __name__ == "__main__":
    # Demo: embed an English and a Chinese sentence and compare them.
    # Use JinaAIEmbedder() instead to exercise the transformers backend.
    embedder = JinaAIOnnxEmbedder()
    texts = ["How is the weather today?", "今天天气怎么样?"]
    embeddings = [embedder.encode(text) for text in texts]
    logger.success(embeddings)
    similarity = cosine_similarity(embeddings[0], embeddings[1])
    print(similarity)

    # python -m transforms.embed