Spaces:
Running
Running
File size: 5,641 Bytes
5e9cd1d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
from __future__ import annotations
from uuid import UUID
from langchain.callbacks import AsyncIteratorCallbackHandler
import json
import asyncio
from typing import Any, Dict, List, Optional
from langchain.schema import AgentFinish, AgentAction
from langchain.schema.output import LLMResult
def dumps(obj: Dict) -> str:
    """Serialize *obj* to a JSON string.

    Non-ASCII characters (e.g. Chinese text) are written as-is rather than
    escaped to ``\\uXXXX`` sequences; all other encoder options are defaults.
    """
    encoder = json.JSONEncoder(ensure_ascii=False)
    return encoder.encode(obj)
class Status:
    """Integer status codes stored under the ``"status"`` key of each streamed event."""

    # LLM / chat-model call has started (on_llm_start, on_chat_model_start).
    start: int = 1
    # A streamed LLM token is being forwarded (on_llm_new_token).
    running: int = 2
    # LLM call has finished (on_llm_end).
    complete: int = 3
    # A tool call has been started by the agent (on_tool_start).
    agent_action: int = 4
    # The agent produced its final answer (on_agent_finish).
    agent_finish: int = 5
    # A tool or LLM call raised (on_tool_error, on_llm_error).
    error: int = 6
    # A tool call returned its output (on_tool_end).
    tool_finish: int = 7
class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
    """Callback handler that streams agent, tool and LLM events as JSON strings.

    All callbacks share one mutable ``cur_tool`` dict describing the current
    step (tool name, tool input/output, latest LLM token, final answer, error
    text and a ``Status`` code). Each event updates that dict and pushes a
    JSON snapshot of it (via ``dumps``) onto ``self.queue``.
    """

    # Emulated stop sequences for tool input. Some LLMs do not stop on their
    # own and keep generating past the tool input (e.g. a made-up
    # "Observation:"), so the input is cut at the EARLIEST occurrence of any
    # of these markers -- where a model honouring stop sequences would stop.
    _TOOL_INPUT_STOP_WORDS = ("Observation:", "Thought", "\"", "(", "\n", "\t")

    # Once a streamed token contains one of these markers, the text before the
    # marker is emitted and further tokens are suppressed until a tool ends.
    _SPECIAL_TOKENS = ("Action", "<|observation|>")

    def __init__(self):
        super().__init__()
        # Every event's JSON snapshot is pushed onto this queue.
        self.queue = asyncio.Queue()
        # NOTE(review): not set anywhere in this class; presumably the caller
        # sets it when the run finishes -- verify.
        self.done = asyncio.Event()
        # State of the step currently being reported.
        self.cur_tool = {}
        # Whether streamed LLM tokens are currently forwarded to the queue.
        self.out = True

    def _emit(self, **updates: Any) -> None:
        """Merge *updates* into ``cur_tool`` and enqueue a JSON snapshot of it."""
        self.cur_tool.update(updates)
        self.queue.put_nowait(dumps(self.cur_tool))

    @classmethod
    def _truncate_tool_input(cls, input_str: str) -> str:
        """Return *input_str* cut at the earliest stop word, or unchanged if none occurs."""
        cut_points = [
            index
            for index in (input_str.find(word) for word in cls._TOOL_INPUT_STOP_WORDS)
            if index != -1
        ]
        return input_str[:min(cut_points)] if cut_points else input_str

    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
                            metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
        """Begin reporting a new tool call and emit it with ``Status.agent_action``.

        ``input_str`` is truncated at the earliest stop word, because some
        LLMs cannot truncate their own output.
        """
        # Replace (not update) cur_tool so every tool call starts from a clean record.
        self.cur_tool = {
            "tool_name": serialized["name"],
            "input_str": self._truncate_tool_input(input_str),
            "output_str": "",
            "status": Status.agent_action,
            "run_id": run_id.hex,
            "llm_token": "",
            "final_answer": "",
            "error": "",
        }
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
                          tags: List[str] | None = None, **kwargs: Any) -> None:
        """Record the tool output (minus any "Answer:" text) and emit ``Status.tool_finish``."""
        self.out = True  # reset output: resume forwarding streamed LLM tokens
        # str() guards against tools that return a non-string observation.
        self._emit(
            status=Status.tool_finish,
            output_str=str(output).replace("Answer:", ""),
        )

    async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
                            **kwargs: Any) -> None:
        """Emit the tool's error message with ``Status.error``."""
        self._emit(status=Status.error, error=str(error))

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Forward a streamed LLM token, stopping at the first special marker.

        If the token contains a special marker, only the text before it (plus
        a newline) is emitted and streaming is paused until ``on_tool_end``
        resets ``self.out``. This avoids echoing the agent's action text.
        """
        for special in self._SPECIAL_TOKENS:
            if special in token:
                self._emit(
                    status=Status.running,
                    llm_token=token.split(special)[0] + "\n",
                )
                self.out = False
                break
        if token and self.out:
            self._emit(status=Status.running, llm_token=token)

    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        """Emit ``Status.start`` with an empty token when an LLM call begins."""
        self._emit(status=Status.start, llm_token="")

    async def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Emit ``Status.start`` with an empty token when a chat-model call begins."""
        self._emit(status=Status.start, llm_token="")

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Emit ``Status.complete`` with a trailing newline token when an LLM call ends."""
        self._emit(status=Status.complete, llm_token="\n")

    async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
        """Emit the LLM's error message with ``Status.error``."""
        self._emit(status=Status.error, error=str(error))

    async def on_agent_finish(
        self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Emit the agent's final answer with ``Status.agent_finish``, then reset ``cur_tool``."""
        # Return the final answer.
        self._emit(
            status=Status.agent_finish,
            final_answer=finish.return_values["output"],
        )
        self.cur_tool = {}
|