# Thread-safe LRU cache pools for embedding models, keyed by (model, device).
import threading
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, List, Tuple, Union

from langchain.embeddings.base import Embeddings
from langchain.vectorstores.faiss import FAISS

from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
                     logger, log_verbose)
from server.utils import embedding_device, get_model_path, list_online_embed_models
class ThreadSafeObject:
    """Wrap an arbitrary object with a reentrant lock and a "loaded" event.

    Used by CachePool to guard cached resources (e.g. embedding models)
    against concurrent access, and to let readers block until the wrapped
    object has finished initializing.

    Fixes vs. original: restores the decorators the surrounding code relies
    on — ``key`` and ``obj`` are accessed as attributes (see the pool's
    ``move_to_end(self.key)`` and ``item.obj = embeddings`` call sites) and
    ``acquire`` is used as ``with item.acquire(...)``, so it must be a
    ``@contextmanager``. Without them, ``key``/``acquire`` returned bound
    methods/generators and the second ``def obj`` silently shadowed the first.
    """

    def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
        # The wrapped resource; may be filled in later via the `obj` setter.
        self._obj = obj
        # Cache key under which this wrapper is stored in `pool`.
        self._key = key
        # Owning pool, used to refresh LRU order on access (may be None).
        self._pool = pool
        # RLock so the holding thread may re-enter acquire().
        self._lock = threading.RLock()
        # Set once the wrapped object is fully initialized.
        self._loaded = threading.Event()

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}>"

    @property
    def key(self):
        """Cache key of this wrapper (read-only)."""
        return self._key

    @contextmanager
    def acquire(self, owner: str = "", msg: str = ""):
        """Context manager: hold the lock and yield the wrapped object.

        While held, bumps this entry to the most-recently-used end of the
        owning pool's cache. `owner`/`msg` only affect log output.
        """
        owner = owner or f"thread {threading.get_native_id()}"
        try:
            self._lock.acquire()
            if self._pool is not None:
                self._pool._cache.move_to_end(self.key)
            if log_verbose:
                logger.info(f"{owner} 开始操作:{self.key}。{msg}")
            yield self._obj
        finally:
            if log_verbose:
                logger.info(f"{owner} 结束操作:{self.key}。{msg}")
            self._lock.release()

    def start_loading(self):
        # Mark the wrapped object as not ready; readers block in
        # wait_for_loading() until finish_loading() is called.
        self._loaded.clear()

    def finish_loading(self):
        # Signal that the wrapped object is fully initialized.
        self._loaded.set()

    def wait_for_loading(self):
        # Block until finish_loading() has been called.
        self._loaded.wait()

    @property
    def obj(self):
        """The wrapped object (None until loading completes)."""
        return self._obj

    @obj.setter
    def obj(self, val: Any):
        self._obj = val
class CachePool:
    """An LRU cache of ThreadSafeObject wrappers keyed by name.

    ``cache_num <= 0`` means unbounded; otherwise least-recently-used
    entries are evicted once the cache grows past ``cache_num``.
    """

    def __init__(self, cache_num: int = -1):
        self._cache_num = cache_num
        # OrderedDict gives us LRU order via move_to_end()/popitem(last=False).
        self._cache = OrderedDict()
        # Guards check-then-create sequences in subclasses (see load_embeddings).
        self.atomic = threading.RLock()

    def keys(self) -> List[str]:
        """Current cache keys, oldest (LRU) first."""
        return list(self._cache.keys())

    def _check_count(self):
        # Evict least-recently-used entries until within the size limit.
        if isinstance(self._cache_num, int) and self._cache_num > 0:
            while len(self._cache) > self._cache_num:
                self._cache.popitem(last=False)

    def get(self, key: str) -> "ThreadSafeObject":
        """Return the cached wrapper for `key`, blocking until it is loaded.

        Returns None (implicitly) when `key` is absent.
        """
        if cache := self._cache.get(key):
            cache.wait_for_loading()
            return cache

    def set(self, key: str, obj: "ThreadSafeObject") -> "ThreadSafeObject":
        """Insert `obj` under `key`, evicting LRU entries if over capacity."""
        self._cache[key] = obj
        self._check_count()
        return obj

    def pop(self, key: str = None):
        """Remove and return an entry.

        With `key` given, returns the wrapper for that key (or None).
        With `key=None`, evicts the LRU entry and returns its
        ``(key, wrapper)`` pair — note the tuple, not a bare wrapper
        (the original's ``-> ThreadSafeObject`` hint was wrong for this case).
        """
        if key is None:
            return self._cache.popitem(last=False)
        else:
            return self._cache.pop(key, None)

    def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
        """Lock and return the cached object for `key` via its context manager.

        Raises RuntimeError if `key` is not cached.
        """
        cache = self.get(key)
        if cache is None:
            raise RuntimeError(f"请求的资源 {key} 不存在")
        elif isinstance(cache, ThreadSafeObject):
            # Refresh LRU order on explicit acquisition.
            self._cache.move_to_end(key)
            return cache.acquire(owner=owner, msg=msg)
        else:
            return cache

    def load_kb_embeddings(
            self,
            kb_name: str,
            embed_device: str = None,
            default_embed_model: str = None,
    ) -> "Embeddings":
        """Return the Embeddings object configured for knowledge base `kb_name`.

        Fix: defaults are now resolved at call time. The original evaluated
        ``embedding_device()`` and ``EMBEDDING_MODEL`` in the parameter
        defaults, binding them once at import time, so later device/config
        changes were never picked up.
        """
        from server.db.repository.knowledge_base_repository import get_kb_detail
        from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter

        embed_device = embed_device or embedding_device()
        default_embed_model = default_embed_model or EMBEDDING_MODEL
        # `or {}` guards against a missing KB record — presumably get_kb_detail
        # can return None/empty for an unknown kb_name; TODO confirm.
        kb_detail = get_kb_detail(kb_name) or {}
        embed_model = kb_detail.get("embed_model", default_embed_model)
        if embed_model in list_online_embed_models():
            # Online (API-backed) embedding model: no local loading needed.
            return EmbeddingsFunAdapter(embed_model)
        else:
            return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
class EmbeddingsPool(CachePool):
    """Cache of loaded embedding models, keyed by (model, device)."""

    def load_embeddings(self, model: str = None, device: str = None) -> "Embeddings":
        """Return a (possibly cached) Embeddings instance for `model` on `device`.

        Thread-safe double-checked creation: `self.atomic` serializes the
        cache-miss check, then the first caller loads the model while holding
        the item's own lock; concurrent callers block in get() /
        wait_for_loading() until loading finishes.

        Fix: the original did ``device = embedding_device()`` unconditionally,
        silently discarding the caller-supplied `device` argument.
        """
        self.atomic.acquire()
        model = model or EMBEDDING_MODEL
        # BUG FIX: honor the caller's device; only fall back when not given.
        device = device or embedding_device()
        key = (model, device)
        if not self.get(key):
            item = ThreadSafeObject(key, pool=self)
            self.set(key, item)
            with item.acquire(msg="初始化"):
                # Release the pool-wide lock before the slow load so other
                # keys can be served concurrently.
                self.atomic.release()
                try:
                    if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
                        from langchain.embeddings.openai import OpenAIEmbeddings
                        embeddings = OpenAIEmbeddings(model=model,
                                                      openai_api_key=get_model_path(model),
                                                      chunk_size=CHUNK_SIZE)
                    elif 'bge-' in model:
                        from langchain.embeddings import HuggingFaceBgeEmbeddings
                        if 'zh' in model:
                            # for chinese model
                            query_instruction = "为这个句子生成表示以用于检索相关文章:"
                        elif 'en' in model:
                            # for english model
                            query_instruction = "Represent this sentence for searching relevant passages:"
                        else:
                            # maybe ReRanker or else, just use empty string instead
                            query_instruction = ""
                        embeddings = HuggingFaceBgeEmbeddings(model_name=get_model_path(model),
                                                              model_kwargs={'device': device},
                                                              query_instruction=query_instruction)
                        if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
                            embeddings.query_instruction = ""
                    else:
                        from langchain.embeddings.huggingface import HuggingFaceEmbeddings
                        embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model),
                                                           model_kwargs={'device': device})
                except Exception:
                    # Drop the half-initialized entry so later callers retry
                    # instead of blocking forever in wait_for_loading().
                    self.pop(key)
                    raise
                item.obj = embeddings
                item.finish_loading()
        else:
            self.atomic.release()
        return self.get(key).obj
# Module-level singleton: keeps at most one embedding model resident at a time.
embeddings_pool = EmbeddingsPool(cache_num=1)