"""
This file is a modified version, for ChatGLM3-6B, of the original glm3_agent.py file from the langchain repo.
"""
from __future__ import annotations
import json
import logging
from typing import Any, List, Sequence, Tuple, Optional, Union
from pydantic.schema import model_schema
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents.agent import Agent
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain.agents.agent import AgentOutputParser
from langchain.output_parsers import OutputFixingParser
from langchain.pydantic_v1 import Field
from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate
from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool
# Human turn template: the user's input followed by the agent's running scratchpad.
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
class StructuredChatOutputParserWithRetries(AgentOutputParser):
    """Output parser with retries for the structured chat agent.

    Rewrites ChatGLM3-6B's raw tool-call output into the canonical
    structured-chat ``Action:`` JSON block, then delegates parsing to
    ``output_fixing_parser`` when configured, otherwise ``base_parser``.
    """

    # Parser that understands the standard structured-chat "Action:" JSON block.
    base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
    # Optional LLM-backed parser that attempts to fix malformed output before parsing.
    output_fixing_parser: Optional[OutputFixingParser] = None

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Parse raw LLM text into an ``AgentAction`` or ``AgentFinish``.

        Args:
            text: Raw model output, possibly containing a ChatGLM3-style
                tool call of the form ``<tool>\\n```...tool_call(k='v', ...)```"``.

        Raises:
            OutputParserException: If the reformatted text still cannot be
                parsed by the underlying parser(s).
        """
        # Truncate at the first special token so trailing chatter after the
        # action (or the observation marker) is ignored.
        special_tokens = ["Action:", "<|observation|>"]
        first_index = min(
            text.find(token) if token in text else len(text)
            for token in special_tokens
        )
        text = text[:first_index]

        if "tool_call" in text:
            # Tool-call shape: "<tool name>\n```...tool_call(key='value', ...)```"
            action_end = text.find("```")
            action = text[:action_end].strip()
            params_str_start = text.find("(") + 1
            params_str_end = text.rfind(")")
            params_str = text[params_str_start:params_str_end]
            # Split each pair on the FIRST "=" only, so values that themselves
            # contain "=" are preserved intact (plain split("=") dropped them).
            # NOTE(review): values containing a literal "," are still split
            # incorrectly; a full fix would require real argument parsing.
            params_pairs = [
                param.split("=", 1) for param in params_str.split(",") if "=" in param
            ]
            params = {key.strip(): value.strip().strip("'\"") for key, value in params_pairs}
            action_json = {
                "action": action,
                "action_input": params
            }
        else:
            # No tool call present: treat the whole text as the final answer.
            action_json = {
                "action": "Final Answer",
                "action_input": text
            }

        # Re-emit in the canonical structured-chat format expected downstream.
        action_str = f"""
Action:
```
{json.dumps(action_json, ensure_ascii=False)}
```"""
        try:
            if self.output_fixing_parser is not None:
                parsed_obj: Union[
                    AgentAction, AgentFinish
                ] = self.output_fixing_parser.parse(action_str)
            else:
                parsed_obj = self.base_parser.parse(action_str)
            return parsed_obj
        except Exception as e:
            raise OutputParserException(f"Could not parse LLM output: {text}") from e

    @property
    def _type(self) -> str:
        """Parser type identifier; a property to match BaseOutputParser's contract."""
        return "structured_chat_ChatGLM3_6b_with_retries"
class StructuredGLM3ChatAgent(Agent):
    """Structured chat agent tailored to ChatGLM3-6B's prompt and tool format.

    Restores the ``@property`` / ``@classmethod`` decorators required by the
    langchain ``Agent`` base class: the prefixes and stop sequence are read as
    attributes by the agent loop, and the constructors are invoked directly on
    the class (see ``initialize_glm3_agent``), so without the decorators every
    call site breaks.
    """

    # Output parser for the agent.
    output_parser: AgentOutputParser = Field(
        default_factory=StructuredChatOutputParserWithRetries
    )

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the ChatGLM3-6B observation with."""
        return "Observation:"

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the llm call with."""
        return "Thought:"

    def _construct_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> str:
        """Build the scratchpad text from prior (action, observation) steps."""
        agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
        if not isinstance(agent_scratchpad, str):
            raise ValueError("agent_scratchpad should be of type string.")
        if agent_scratchpad:
            return (
                f"This was your previous work "
                f"(but I haven't seen any of it! I only see what "
                f"you return as final answer):\n{agent_scratchpad}"
            )
        else:
            return agent_scratchpad

    @classmethod
    def _get_default_output_parser(
        cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
    ) -> AgentOutputParser:
        """Return the default output parser for this agent.

        NOTE(review): ``llm`` is forwarded to the parser's constructor, which
        declares no ``llm`` field — confirm pydantic accepts/ignores it.
        """
        return StructuredChatOutputParserWithRetries(llm=llm)

    @property
    def _stop(self) -> List[str]:
        """Stop sequences that end a ChatGLM3-6B generation turn."""
        return ["<|observation|>"]

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prompt: Optional[str] = None,
        input_variables: Optional[List[str]] = None,
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        """Render the system prompt template describing the available tools.

        Args:
            tools: Tools exposed to the agent.
            prompt: Template string with ``{tool_names}``, ``{tools}``,
                ``{history}``, ``{input}`` and ``{agent_scratchpad}`` slots.
                NOTE(review): a ``None`` prompt raises AttributeError below —
                callers must supply one.
            input_variables: Prompt input variable names; defaults to
                ``["input", "agent_scratchpad"]``.
            memory_prompts: Extra prompt templates appended after the system
                message (e.g. chat history placeholders).
        """
        tools_json = []
        tool_names = []
        for tool in tools:
            # Build a simplified JSON description of each tool from its schema.
            tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
            simplified_config_langchain = {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool_schema.get("properties", {})
            }
            tools_json.append(simplified_config_langchain)
            tool_names.append(tool.name)
        formatted_tools = "\n".join([
            f"{tool['name']}: {tool['description']}, args: {tool['parameters']}"
            for tool in tools_json
        ])
        # Escape braces/quotes so the tool listing survives template formatting.
        formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}")
        template = prompt.format(tool_names=tool_names,
                                 tools=formatted_tools,
                                 history="None",
                                 input="{input}",
                                 agent_scratchpad="{agent_scratchpad}")
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        _memory_prompts = memory_prompts or []
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *_memory_prompts,
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        prompt: Optional[str] = None,
        callback_manager: Optional[BaseCallbackManager] = None,
        output_parser: Optional[AgentOutputParser] = None,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        input_variables: Optional[List[str]] = None,
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        prompt = cls.create_prompt(
            tools,
            prompt=prompt,
            input_variables=input_variables,
            memory_prompts=memory_prompts,
        )
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        _output_parser = output_parser or cls._get_default_output_parser(llm=llm)
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )

    @property
    def _agent_type(self) -> str:
        """This agent is not serializable; accessing the type always raises."""
        raise ValueError
def initialize_glm3_agent(
    tools: Sequence[BaseTool],
    llm: BaseLanguageModel,
    prompt: str = None,
    memory: Optional[ConversationBufferWindowMemory] = None,
    agent_kwargs: Optional[dict] = None,
    *,
    tags: Optional[Sequence[str]] = None,
    **kwargs: Any,
) -> AgentExecutor:
    """Build an ``AgentExecutor`` around a ChatGLM3-6B structured chat agent.

    Args:
        tools: Tools the agent may invoke.
        llm: Language model driving the agent.
        prompt: System prompt template forwarded to the agent constructor.
        memory: Optional conversation memory attached to the executor.
        agent_kwargs: Extra keyword arguments for ``from_llm_and_tools``.
        tags: Optional tags attached to the executor (keyword-only).
        **kwargs: Additional ``AgentExecutor.from_agent_and_tools`` options.
    """
    executor_tags = list(tags) if tags else []
    extra_agent_kwargs = agent_kwargs if agent_kwargs else {}
    glm3_agent = StructuredGLM3ChatAgent.from_llm_and_tools(
        llm=llm,
        tools=tools,
        prompt=prompt,
        **extra_agent_kwargs
    )
    return AgentExecutor.from_agent_and_tools(
        agent=glm3_agent,
        tools=tools,
        memory=memory,
        tags=executor_tags,
        **kwargs,
    )