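# NOTE(review): "Spaces: Paused" appears to be a hosting-page banner captured with this file; it is not part of the module.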
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections | |
import logging | |
import re | |
import logging | |
import traceback | |
from concurrent.futures import ThreadPoolExecutor | |
from dataclasses import dataclass | |
from typing import Any | |
from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT | |
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements | |
from rag.llm.chat_model import Base as CompletionLLM | |
import markdown_to_json | |
from functools import reduce | |
from rag.utils import num_tokens_from_string | |
@dataclass
class MindMapResult:
    """Unipartite Mind Graph result class definition.

    Declared as a dataclass so that it can be built with the keyword form
    ``MindMapResult(output=...)``; a plain class with only an annotation has
    no ``__init__`` accepting ``output`` and would raise ``TypeError``.
    """

    # Result payload produced by the extractor (a nested dict).
    output: dict
class MindMapExtractor: | |
_llm: CompletionLLM | |
_input_text_key: str | |
_mind_map_prompt: str | |
_on_error: ErrorHandlerFn | |
def __init__( | |
self, | |
llm_invoker: CompletionLLM, | |
prompt: str | None = None, | |
input_text_key: str | None = None, | |
on_error: ErrorHandlerFn | None = None, | |
): | |
"""Init method definition.""" | |
# TODO: streamline construction | |
self._llm = llm_invoker | |
self._input_text_key = input_text_key or "input_text" | |
self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT | |
self._on_error = on_error or (lambda _e, _s, _d: None) | |
def _key(self, k): | |
return re.sub(r"\*+", "", k) | |
def _be_children(self, obj: dict, keyset: set): | |
if isinstance(obj, str): | |
obj = [obj] | |
if isinstance(obj, list): | |
for i in obj: keyset.add(i) | |
return [{"id": re.sub(r"\*+", "", i), "children": []} for i in obj] | |
arr = [] | |
for k, v in obj.items(): | |
k = self._key(k) | |
if not k or k in keyset: continue | |
keyset.add(k) | |
arr.append({ | |
"id": k, | |
"children": self._be_children(v, keyset) | |
}) | |
return arr | |
def __call__( | |
self, sections: list[str], prompt_variables: dict[str, Any] | None = None | |
) -> MindMapResult: | |
"""Call method definition.""" | |
if prompt_variables is None: | |
prompt_variables = {} | |
try: | |
exe = ThreadPoolExecutor(max_workers=12) | |
threads = [] | |
token_count = max(self._llm.max_length * 0.8, self._llm.max_length-512) | |
texts = [] | |
res = [] | |
cnt = 0 | |
for i in range(len(sections)): | |
section_cnt = num_tokens_from_string(sections[i]) | |
if cnt + section_cnt >= token_count and texts: | |
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables)) | |
texts = [] | |
cnt = 0 | |
texts.append(sections[i]) | |
cnt += section_cnt | |
if texts: | |
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables)) | |
for i, _ in enumerate(threads): | |
res.append(_.result()) | |
merge_json = reduce(self._merge, res) | |
if len(merge_json.keys()) > 1: | |
keyset = set( | |
[re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict) and re.sub(r"\*+", "", k)]) | |
merge_json = {"id": "root", | |
"children": [{"id": self._key(k), "children": self._be_children(v, keyset)} for k, v in | |
merge_json.items() if isinstance(v, dict) and self._key(k)]} | |
else: | |
k = self._key(list(self._be_children.keys())[0]) | |
merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], set([k]))} | |
except Exception as e: | |
logging.exception("error mind graph") | |
self._on_error( | |
e, | |
traceback.format_exc(), None | |
) | |
merge_json = {"error": str(e)} | |
return MindMapResult(output=merge_json) | |
def _merge(self, d1, d2): | |
for k in d1: | |
if k in d2: | |
if isinstance(d1[k], dict) and isinstance(d2[k], dict): | |
self._merge(d1[k], d2[k]) | |
elif isinstance(d1[k], list) and isinstance(d2[k], list): | |
d2[k].extend(d1[k]) | |
else: | |
d2[k] = d1[k] | |
else: | |
d2[k] = d1[k] | |
return d2 | |
def _list_to_kv(self, data): | |
for key, value in data.items(): | |
if isinstance(value, dict): | |
self._list_to_kv(value) | |
elif isinstance(value, list): | |
new_value = {} | |
for i in range(len(value)): | |
if isinstance(value[i], list): | |
new_value[value[i - 1]] = value[i][0] | |
data[key] = new_value | |
else: | |
continue | |
return data | |
def _todict(self, layer:collections.OrderedDict): | |
to_ret = layer | |
if isinstance(layer, collections.OrderedDict): | |
to_ret = dict(layer) | |
try: | |
for key, value in to_ret.items(): | |
to_ret[key] = self._todict(value) | |
except AttributeError: | |
pass | |
return self._list_to_kv(to_ret) | |
def _process_document( | |
self, text: str, prompt_variables: dict[str, str] | |
) -> str: | |
variables = { | |
**prompt_variables, | |
self._input_text_key: text, | |
} | |
text = perform_variable_replacements(self._mind_map_prompt, variables=variables) | |
gen_conf = {"temperature": 0.5} | |
response = self._llm.chat(text, [], gen_conf) | |
response = re.sub(r"```[^\n]*", "", response) | |
print(response) | |
print("---------------------------------------------------\n", self._todict(markdown_to_json.dictify(response))) | |
return self._todict(markdown_to_json.dictify(response)) | |