Spaces:
Sleeping
Sleeping
# coding=utf-8 | |
from sentence_transformers import SentenceTransformer | |
from .utils import load_json | |
import faiss | |
import logging | |
import os | |
import re | |
import torch | |
# Logger for this module; handlers and levels are configured by the application.
logger = logging.getLogger(name=__name__)
class RetrieveDialog:
    """Embedding-based retrieval of example dialogs for a single role.

    At construction time every dialog in ``raw_dialog_list`` is cleaned up:
    trailing utterances not spoken by ``role_name`` are dropped, and other
    speakers are anonymised to ``user0``, ``user1``, ... The dialog's context
    is then embedded with a SentenceTransformer model and stored in a flat L2
    faiss index. ``get_retrieve_res`` returns the stored dialogs whose
    context is closest to a query dialog.

    Utterances are strings of the form ``"<speaker>:<text>"``. The separator
    may be an ASCII ``:`` or a full-width ``：``.
    """

    # Separator between speaker name and utterance text: ASCII ':' or the
    # full-width '：' used in Chinese text.
    _SPEAKER_SEP = re.compile(r"[:：]")

    def __init__(self,
                 role_name,
                 raw_dialog_list: list = None,
                 retrieve_num=20,
                 min_mean_role_utter_length=10,
                 model_name_or_path="BAAI/bge-large-zh-v1.5"):
        """Build the dialog embedding index.

        Args:
            role_name: Speaker name of the role whose replies are retrieved.
            raw_dialog_list: Dialogs to index. Each dialog is a list of
                ``"<speaker>:<text>"`` strings. Must be non-empty.
            retrieve_num: Default number of results for ``get_retrieve_res``.
                The length filter is skipped if it would leave fewer than
                this many dialogs.
            min_mean_role_utter_length: Dialogs whose mean role-utterance
                length is below this are dropped (subject to the fallback
                above).
            model_name_or_path: SentenceTransformer model used for the
                embeddings (hub id or local path).

        Raises:
            ValueError: If ``raw_dialog_list`` is empty or ``None``, or if no
                usable dialog remains after preprocessing.
        """
        # Explicit check rather than ``assert`` so it also runs under ``python -O``.
        if not raw_dialog_list:
            raise ValueError("raw_dialog_list must be a non-empty list of dialogs")
        if torch.cuda.is_available():
            gpu_id = 0
            torch.cuda.set_device(gpu_id)
        self.role_name = role_name
        self.min_mean_role_utter_length = min_mean_role_utter_length
        self.retrieve_num = retrieve_num
        self.emb_model = SentenceTransformer(model_name_or_path)
        self.dialogs, self.context_index = self._get_emb_base_by_list(raw_dialog_list)
        logger.info(f"dialog db num: {len(self.dialogs)}")
        logger.info("RetrieveDialog init success.")

    @staticmethod
    def _split_speaker(utter):
        """Split ``utter`` at its first separator and return ``(speaker, text)``, or ``None`` if there is no separator."""
        parts = RetrieveDialog._SPEAKER_SEP.split(utter, maxsplit=1)
        if len(parts) != 2:
            return None
        return parts[0], parts[1]

    @staticmethod
    def dialog_preprocess(dialog: list, role_name):
        """Anonymise non-role speakers and collect role utterance lengths.

        Every speaker other than ``role_name`` is renamed to ``user<i>``,
        where ``i`` is the order of first appearance. This keeps concrete
        names from influencing retrieval.

        Args:
            dialog: Utterances of the form ``"<speaker>:<text>"``.
            role_name: Speaker whose utterances are kept verbatim.

        Returns:
            A tuple ``(dialog_new, user_names, role_utter_length)``:
            the rewritten utterances, the original non-role speaker names in
            order of first appearance, and the text length of each
            ``role_name`` utterance. Returns ``(None, None, None)`` (and logs
            an error) if any utterance has no speaker separator.
        """
        dialog_new = []
        user_names = []
        role_utter_length = []
        for utter in dialog:
            split = RetrieveDialog._split_speaker(utter)
            if split is None:
                logger.error(f"utter:{utter} can't find user_name.")
                return None, None, None
            user_name, utter_txt = split
            if user_name != role_name:
                if user_name not in user_names:
                    user_names.append(user_name)
                index = user_names.index(user_name)
                # The speaker name is the utterance prefix, so the first
                # occurrence is the one being replaced.
                utter = utter.replace(user_name, f"user{index}", 1)
            else:
                role_utter_length.append(len(utter_txt))
            dialog_new.append(utter)
        return dialog_new, user_names, role_utter_length

    def _trim_trailing_non_role(self, raw_dialog):
        """Drop trailing utterances whose speaker is not exactly ``self.role_name``."""
        end = len(raw_dialog)
        while end > 0:
            split = self._split_speaker(raw_dialog[end - 1])
            # Compare the exact speaker name. A prefix test would mistake
            # e.g. "张三丰" for the role "张三".
            if split is not None and split[0] == self.role_name:
                break
            end -= 1
        return raw_dialog[:end]

    def _get_emb_base_by_list(self, raw_dialog_list):
        """Preprocess, filter, de-duplicate and embed ``raw_dialog_list``.

        Returns:
            ``(dialogs, index)``: the kept raw dialogs, and a faiss
            ``IndexFlatL2`` whose i-th vector is the normalised embedding of
            the i-th dialog's context. The context is every utterance except
            the final role reply.

        Raises:
            ValueError: If no dialog survives preprocessing.
        """
        logger.info(f"raw dialog db num: {len(raw_dialog_list)}")
        # Dialogs that also pass the mean-length filter.
        new_raw_dialog_list = []
        context_list = []
        # All valid, de-duplicated dialogs regardless of length. This is the
        # fallback in case the length filter removes (almost) everything.
        new_raw_dialog_list_total = []
        context_list_total = []
        seen = set()
        for raw_dialog in raw_dialog_list:
            if not raw_dialog:
                continue
            raw_dialog = self._trim_trailing_non_role(raw_dialog)
            new_dialog, _, role_utter_length = self.dialog_preprocess(raw_dialog, self.role_name)
            if not new_dialog or not role_utter_length:
                continue
            # A set of tuples replaces the O(n) ``in list`` check on every dialog.
            key = tuple(raw_dialog)
            if key in seen:
                continue
            seen.add(key)
            # The final role reply is the answer. Embed only the context
            # before it, or the whole dialog if it is a single utterance.
            context = "\n".join(new_dialog if len(new_dialog) < 2 else new_dialog[:-1])
            new_raw_dialog_list_total.append(raw_dialog)
            context_list_total.append(context)
            # Length filter: skip dialogs whose role replies are too short on average.
            role_length_mean = sum(role_utter_length) / len(role_utter_length)
            if role_length_mean < self.min_mean_role_utter_length:
                continue
            new_raw_dialog_list.append(raw_dialog)
            context_list.append(context)
        logger.debug(f"new_raw_dialog num: {len(new_raw_dialog_list)}")
        # Too few samples after length filtering: fall back to the unfiltered set.
        if len(new_raw_dialog_list) < self.retrieve_num:
            new_raw_dialog_list = new_raw_dialog_list_total
            context_list = context_list_total
        if not context_list:
            raise ValueError(
                f"no usable dialog for role {self.role_name!r} in raw_dialog_list")
        # Dialog vector store.
        context_vectors = self.emb_model.encode(context_list, normalize_embeddings=True)
        context_index = faiss.IndexFlatL2(context_vectors.shape[1])
        context_index.add(context_vectors)
        return new_raw_dialog_list, context_index

    def get_retrieve_res(self, dialog: list, retrieve_num: int = None):
        """Return the indexed dialogs most similar to ``dialog``.

        Args:
            dialog: Query dialog as ``"<speaker>:<text>"`` strings.
            retrieve_num: Maximum number of results. Defaults to
                ``self.retrieve_num``. It is capped at the index size.

        Returns:
            A list of ``(distance_str, raw_dialog)`` pairs in the order
            returned by the faiss search, nearest first. The list is empty
            if ``dialog`` cannot be parsed or ``retrieve_num <= 0``.
        """
        logger.debug(f"dialog: {dialog}")
        if retrieve_num is None:
            retrieve_num = self.retrieve_num
        # Anonymise user names in the query too, matching the indexed contexts.
        dialog, _, _ = self.dialog_preprocess(dialog, self.role_name)
        if dialog is None:
            logger.warning("query dialog could not be parsed; returning no results")
            return []
        k = min(retrieve_num, len(self.dialogs))
        if k <= 0:
            return []
        dialog_vector = self.emb_model.encode(["\n".join(dialog)], normalize_embeddings=True)
        distances, indices = self.context_index.search(dialog_vector, k)
        simi_dialog_results = [
            (str(distance), self.dialogs[index])
            for distance, index in zip(distances[0], indices[0])
        ]
        logger.debug(f"dialog retrieve res: {simi_dialog_results}")
        return simi_dialog_results