Spaces:
Running
Running
File size: 6,174 Bytes
5e9cd1d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import threading
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, Iterator, List, Tuple, Union

from langchain.embeddings.base import Embeddings
from langchain.vectorstores.faiss import FAISS

from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
                     logger, log_verbose)
from server.utils import embedding_device, get_model_path, list_online_embed_models
class ThreadSafeObject:
    """Pair an arbitrary object with a re-entrant lock and a "loaded" event.

    The wrapped object is reached either through the ``obj`` property (no
    locking) or through ``acquire()``, which holds the lock for the duration
    of a ``with`` block. The "loaded" event starts out cleared; callers that
    fill the object in later signal completion with ``finish_loading()``.
    """

    def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
        """Store the payload and create the lock and the (cleared) loaded event.

        Args:
            key: Identifier of this entry inside ``pool``.
            obj: The wrapped object; may be filled in later via ``obj``.
            pool: Owning pool, if any. Only its ``_cache`` ordered mapping is
                touched, to refresh this entry's position on ``acquire()``.
        """
        self._obj = obj
        self._key = key
        self._pool = pool
        self._lock = threading.RLock()
        self._loaded = threading.Event()

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}>"

    @property
    def key(self):
        """Identifier this entry was created with."""
        return self._key

    @contextmanager
    def acquire(self, owner: str = "", msg: str = "") -> Iterator[Any]:
        """Hold this entry's lock for the duration of a ``with`` block.

        Args:
            owner: Label used in the verbose log lines; defaults to the
                native id of the calling thread.
            msg: Free-form text appended to the verbose log lines.

        Yields:
            The wrapped object.
        """
        owner = owner or f"thread {threading.get_native_id()}"
        # Take the lock before any try/finally so a failed or interrupted
        # acquire can never lead to releasing a lock this thread does not hold.
        with self._lock:
            if self._pool is not None:
                try:
                    # Moving to the end marks this entry as most recently used.
                    self._pool._cache.move_to_end(self.key)
                except KeyError:
                    # The entry has already been evicted from the pool. The
                    # wrapped object is still usable by whoever holds a
                    # reference, so only the ordering refresh is skipped.
                    pass
            if log_verbose:
                logger.info(f"{owner} 开始操作:{self.key}。{msg}")
            try:
                yield self._obj
            finally:
                if log_verbose:
                    logger.info(f"{owner} 结束操作:{self.key}。{msg}")

    def start_loading(self):
        """Clear the loaded event so ``wait_for_loading()`` blocks again."""
        self._loaded.clear()

    def finish_loading(self):
        """Set the loaded event, releasing every ``wait_for_loading()`` caller."""
        self._loaded.set()

    def wait_for_loading(self):
        """Block until ``finish_loading()`` has been called."""
        self._loaded.wait()

    @property
    def obj(self):
        """The wrapped object, returned without taking the lock."""
        return self._obj

    @obj.setter
    def obj(self, val: Any):
        self._obj = val
class CachePool:
    """Insertion-ordered cache of pool entries with an optional size cap.

    Entries live in an ``OrderedDict``; the front of the mapping is what gets
    dropped first when the cap is exceeded, and ``acquire()`` moves an entry
    to the back. ``atomic`` is a re-entrant lock exposed for callers that
    need to serialise a compound operation on the pool.
    """

    def __init__(self, cache_num: int = -1):
        """Create an empty pool.

        Args:
            cache_num: Maximum number of entries to keep. Only a positive
                int enables trimming; any other value leaves the pool
                unbounded.
        """
        self._cache_num = cache_num
        self._cache = OrderedDict()
        self.atomic = threading.RLock()

    def keys(self) -> List[str]:
        """Return the cached keys, front of the ordering first."""
        return [*self._cache]

    def _check_count(self):
        """Drop entries from the front until the pool fits its cap."""
        limit = self._cache_num
        if not isinstance(limit, int) or limit <= 0:
            return
        while len(self._cache) > limit:
            self._cache.popitem(last=False)

    def get(self, key: str) -> ThreadSafeObject:
        """Return the entry for ``key`` once it has finished loading.

        Blocks in the entry's ``wait_for_loading()``. Returns ``None`` when
        nothing truthy is stored under ``key``.
        """
        entry = self._cache.get(key)
        if not entry:
            return None
        entry.wait_for_loading()
        return entry

    def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject:
        """Store ``obj`` under ``key``, trim the pool, and return ``obj``."""
        self._cache[key] = obj
        self._check_count()
        return obj

    def pop(self, key: str = None) -> ThreadSafeObject:
        """Remove an entry from the pool.

        With a key, returns the removed entry, or ``None`` if the key is
        absent. Without a key, removes the front entry and returns the
        ``(key, entry)`` pair produced by ``OrderedDict.popitem``.
        """
        if key is not None:
            return self._cache.pop(key, None)
        return self._cache.popitem(last=False)

    def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
        """Look up ``key`` and hand back its ``acquire()`` context manager.

        Entries that are not ``ThreadSafeObject`` instances are returned
        as-is.

        Raises:
            RuntimeError: If ``key`` is not in the pool.
        """
        entry = self.get(key)
        if entry is None:
            raise RuntimeError(f"请求的资源 {key} 不存在")
        if not isinstance(entry, ThreadSafeObject):
            return entry
        # Move to the back of the ordering before handing out the lock.
        self._cache.move_to_end(key)
        return entry.acquire(owner=owner, msg=msg)

    def load_kb_embeddings(
            self,
            kb_name: str,
            embed_device: str = embedding_device(),
            default_embed_model: str = EMBEDDING_MODEL,
    ) -> Embeddings:
        """Return the embeddings object configured for a knowledge base.

        The model name is read from the knowledge base's ``embed_model``
        detail, falling back to ``default_embed_model``. Models listed by
        ``list_online_embed_models()`` are wrapped in
        ``EmbeddingsFunAdapter``; all others are loaded through the
        module-level ``embeddings_pool``.
        """
        # Imported here rather than at module top.
        # NOTE(review): presumably to avoid an import cycle -- confirm.
        from server.db.repository.knowledge_base_repository import get_kb_detail
        from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter

        embed_model = get_kb_detail(kb_name).get("embed_model", default_embed_model)
        if embed_model not in list_online_embed_models():
            return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
        return EmbeddingsFunAdapter(embed_model)
class EmbeddingsPool(CachePool):
    """Pool that builds embedding models on demand, keyed by ``(model, device)``."""

    def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
        """Return the embeddings object for ``model`` on ``device``, loading it once.

        Args:
            model: Embedding model name; defaults to ``EMBEDDING_MODEL``.
            device: Device to load local models on; defaults to
                ``embedding_device()``.

        Returns:
            The cached or freshly built embeddings object.

        Raises:
            Whatever the underlying embeddings constructor raises. A failed
            load is removed from the pool so a later call can retry.
        """
        model = model or EMBEDDING_MODEL
        # Honour the caller's device; fall back to the configured one.
        device = device or embedding_device()
        key = (model, device)

        self.atomic.acquire()
        atomic_held = True
        try:
            cached = self.get(key)
            # ``obj`` is None only for an entry whose load failed while this
            # thread was waiting on it; in that case load it again here.
            if cached is not None and cached.obj is not None:
                return cached.obj

            item = ThreadSafeObject(key, pool=self)
            self.set(key, item)
            with item.acquire(msg="初始化"):
                # The entry's own lock is held now, so the pool-wide lock can
                # be released while the (slow) model construction runs.
                self.atomic.release()
                atomic_held = False
                loaded = False
                try:
                    item.obj = self._build_embeddings(model, device)
                    loaded = True
                finally:
                    # Use the raw mapping here: CachePool.get() would wait on
                    # the very event that is only set just below.
                    if not loaded and self._cache.get(key) is item:
                        self.pop(key)
                    # Always wake waiters, even on failure, so that nobody
                    # blocks forever in wait_for_loading().
                    item.finish_loading()
            # Return from the local reference: the entry may already have
            # been evicted from the pool by another load.
            return item.obj
        finally:
            if atomic_held:
                self.atomic.release()

    @staticmethod
    def _build_embeddings(model: str, device: str) -> Embeddings:
        """Construct the embeddings object matching the model name."""
        if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
            from langchain.embeddings.openai import OpenAIEmbeddings
            return OpenAIEmbeddings(model=model,
                                    openai_api_key=get_model_path(model),
                                    chunk_size=CHUNK_SIZE)

        if 'bge-' in model:
            from langchain.embeddings import HuggingFaceBgeEmbeddings
            if 'zh' in model:
                # Chinese model
                query_instruction = "为这个句子生成表示以用于检索相关文章:"
            elif 'en' in model:
                # English model
                query_instruction = "Represent this sentence for searching relevant passages:"
            else:
                # Possibly a reranker or something else: use no instruction.
                query_instruction = ""
            embeddings = HuggingFaceBgeEmbeddings(model_name=get_model_path(model),
                                                  model_kwargs={'device': device},
                                                  query_instruction=query_instruction)
            if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
                embeddings.query_instruction = ""
            return embeddings

        from langchain.embeddings.huggingface import HuggingFaceEmbeddings
        return HuggingFaceEmbeddings(model_name=get_model_path(model),
                                     model_kwargs={'device': device})
# Module-level pool used by CachePool.load_kb_embeddings. cache_num=1 caps it at a
# single entry, so caching a second (model, device) key evicts the first.
embeddings_pool = EmbeddingsPool(cache_num=1)
|