from __future__ import annotations

import asyncio
import json
from typing import Any, Dict, List, Optional
from uuid import UUID

from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.schema import AgentAction, AgentFinish
from langchain.schema.output import LLMResult


def dumps(obj: Dict) -> str:
    """Serialize an event dict to JSON without escaping non-ASCII text."""
    return json.dumps(obj, ensure_ascii=False)


class Status:
    """Lifecycle states attached to every event pushed onto the queue."""

    start: int = 1
    running: int = 2
    complete: int = 3
    agent_action: int = 4
    agent_finish: int = 5
    error: int = 6
    tool_finish: int = 7
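

# Every handler below serializes `cur_tool` and pushes it onto the queue, so a
# consumer sees a stream of JSON events. An illustrative payload (all field
# values here are made up for this example):
#
#   {"tool_name": "calculator", "input_str": "3 * 7", "output_str": "",
#    "status": 4, "run_id": "...", "llm_token": "", "final_answer": "", "error": ""}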


class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
    """Streams agent, LLM, and tool events as JSON strings through an asyncio queue."""

    def __init__(self):
        super().__init__()
        # Recreate the queue and done-event so every instance starts clean.
        # Note that `done` is never set inside this class; the caller is
        # expected to set it once the agent task completes.
        self.queue = asyncio.Queue()
        self.done = asyncio.Event()
        self.cur_tool = {}
        self.out = True  # whether LLM tokens are currently forwarded to the queue

    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
                            metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
        # Some models cannot stop generating after the tool input on their own,
        # so truncate it here at the first stop word: e.g.
        # 'Shanghai\nObservation: ...' becomes 'Shanghai'.
        stop_words = ["Observation:", "Thought", "\"", "(", "\n", "\t"]
        for stop_word in stop_words:
            index = input_str.find(stop_word)
            if index != -1:
                input_str = input_str[:index]
                break
        self.cur_tool = {
            "tool_name": serialized["name"],
            "input_str": input_str,
            "output_str": "",
            "status": Status.agent_action,
            "run_id": run_id.hex,
            "llm_token": "",
            "final_answer": "",
            "error": "",
        }
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
                          tags: List[str] | None = None, **kwargs: Any) -> None:
        self.out = True  # re-enable token streaming muted by on_llm_new_token
        self.cur_tool.update(
            status=Status.tool_finish,
            output_str=output.replace("Answer:", ""),
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
                            **kwargs: Any) -> None:
        self.cur_tool.update(
            status=Status.error,
            error=str(error),
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Once the model starts emitting an action block, forward only the text
        # before the marker and mute streaming until the next tool call ends,
        # to avoid echoing the ReAct scaffolding to the client.
        special_tokens = ["Action", "<|observation|>"]
        for stoken in special_tokens:
            if stoken in token:
                before_action = token.split(stoken)[0]
                self.cur_tool.update(
                    status=Status.running,
                    llm_token=before_action + "\n",
                )
                self.queue.put_nowait(dumps(self.cur_tool))
                self.out = False
                break

        if token and self.out:
            self.cur_tool.update(
                status=Status.running,
                llm_token=token,
            )
            self.queue.put_nowait(dumps(self.cur_tool))
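
    # Illustrative trace (tokens are made up): for the stream
    #   ["The answer", " is 42.", "Now, Action: search"]
    # the first two tokens are forwarded as Status.running events; the third
    # emits only "Now, \n" (the text before "Action") and then mutes streaming
    # (self.out = False) until on_tool_end resets it.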

    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        self.cur_tool.update(
            status=Status.start,
            llm_token="",
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        self.cur_tool.update(
            status=Status.start,
            llm_token="",
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        self.cur_tool.update(
            status=Status.complete,
            llm_token="\n",
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
        self.cur_tool.update(
            status=Status.error,
            error=str(error),
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_agent_finish(
        self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        # Emit the final answer, then reset the per-run state.
        self.cur_tool.update(
            status=Status.agent_finish,
            final_answer=finish.return_values["output"],
        )
        self.queue.put_nowait(dumps(self.cur_tool))
        self.cur_tool = {}
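

# --- Usage sketch (illustrative only) ---
# A minimal demo of how a consumer drains the handler through the inherited
# `aiter()`. The `fake_agent` coroutine below stands in for a real agent run;
# in practice you would pass `callbacks=[handler]` to an agent call and set
# `handler.done` when that task finishes. Everything in `_demo` is an
# assumption for demonstration, not part of the handler itself.


async def _demo() -> None:
    handler = CustomAsyncIteratorCallbackHandler()

    async def fake_agent() -> None:
        # Simulate the callbacks a real agent run would fire.
        await handler.on_llm_start({}, ["hi"])
        await handler.on_llm_new_token("Hello")
        await handler.on_llm_end(LLMResult(generations=[]))
        await asyncio.sleep(0.1)  # let the consumer drain the queue first
        handler.done.set()  # the caller, not the handler, signals completion

    task = asyncio.create_task(fake_agent())
    async for event_json in handler.aiter():
        event = json.loads(event_json)
        print(event["status"], repr(event.get("llm_token", "")))
    await task


if __name__ == "__main__":
    asyncio.run(_demo())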