|
import json |
|
import logging |
|
import uuid |
|
from collections.abc import Mapping, Sequence |
|
from datetime import datetime, timezone |
|
from typing import Optional, Union, cast |
|
|
|
from core.agent.entities import AgentEntity, AgentToolEntity |
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager |
|
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig |
|
from core.app.apps.base_app_queue_manager import AppQueueManager |
|
from core.app.apps.base_app_runner import AppRunner |
|
from core.app.entities.app_invoke_entities import ( |
|
AgentChatAppGenerateEntity, |
|
ModelConfigWithCredentialsEntity, |
|
) |
|
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler |
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler |
|
from core.file import file_manager |
|
from core.memory.token_buffer_memory import TokenBufferMemory |
|
from core.model_manager import ModelInstance |
|
from core.model_runtime.entities import ( |
|
AssistantPromptMessage, |
|
LLMUsage, |
|
PromptMessage, |
|
PromptMessageContent, |
|
PromptMessageTool, |
|
SystemPromptMessage, |
|
TextPromptMessageContent, |
|
ToolPromptMessage, |
|
UserPromptMessage, |
|
) |
|
from core.model_runtime.entities.model_entities import ModelFeature |
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel |
|
from core.model_runtime.utils.encoders import jsonable_encoder |
|
from core.prompt.utils.extract_thread_messages import extract_thread_messages |
|
from core.tools.entities.tool_entities import ( |
|
ToolParameter, |
|
ToolRuntimeVariablePool, |
|
) |
|
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool |
|
from core.tools.tool.tool import Tool |
|
from core.tools.tool_manager import ToolManager |
|
from extensions.ext_database import db |
|
from factories import file_factory |
|
from models.model import Conversation, Message, MessageAgentThought, MessageFile |
|
from models.tools import ToolConversationVariables |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class BaseAgentRunner(AppRunner): |
|
def __init__(
    self,
    tenant_id: str,
    application_generate_entity: AgentChatAppGenerateEntity,
    conversation: Conversation,
    app_config: AgentChatAppConfig,
    model_config: ModelConfigWithCredentialsEntity,
    config: AgentEntity,
    queue_manager: AppQueueManager,
    message: Message,
    user_id: str,
    memory: Optional[TokenBufferMemory] = None,
    prompt_messages: Optional[list[PromptMessage]] = None,
    variables_pool: Optional[ToolRuntimeVariablePool] = None,
    db_variables: Optional[ToolConversationVariables] = None,
    model_instance: Optional[ModelInstance] = None,
) -> None:
    """
    Initialize the state shared by agent runners.

    Stores the invocation context, rebuilds the conversation history as prompt
    messages, prepares the dataset retriever tools, counts the agent thoughts
    already stored for ``message`` and inspects the model schema to decide
    whether streamed tool calls and file inputs are enabled.

    :param tenant_id: tenant id
    :param application_generate_entity: entity describing the current invocation
    :param conversation: conversation the message belongs to
    :param app_config: agent chat app configuration
    :param model_config: model configuration with credentials
    :param config: agent configuration
    :param queue_manager: queue manager handed to the dataset hit callback
    :param message: message currently being generated
    :param user_id: id of the invoking user
    :param memory: optional token buffer memory
    :param prompt_messages: optional prompt messages; only the system messages are kept as history prefix
    :param variables_pool: optional tool runtime variable pool
    :param db_variables: optional persisted tool conversation variables
    :param model_instance: model instance used to look up the model schema (required)
    :raises ValueError: if ``model_instance`` is not provided
    """
    # The ``None`` default is kept for signature compatibility, but the schema
    # lookup below cannot work without an instance, so fail early and clearly.
    if model_instance is None:
        raise ValueError("model_instance is required")

    self.tenant_id = tenant_id
    self.application_generate_entity = application_generate_entity
    self.conversation = conversation
    self.app_config = app_config
    self.model_config = model_config
    self.config = config
    self.queue_manager = queue_manager
    self.message = message
    self.user_id = user_id
    self.memory = memory
    # organize_agent_history reads self.message and self.tenant_id, so it must run after they are set
    self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
    self.variables_pool = variables_pool
    self.db_variables_pool = db_variables
    self.model_instance = model_instance

    self.agent_callback = DifyAgentCallbackHandler()

    hit_callback = DatasetIndexToolCallbackHandler(
        queue_manager=queue_manager,
        app_id=self.app_config.app_id,
        message_id=message.id,
        user_id=user_id,
        invoke_from=self.application_generate_entity.invoke_from,
    )
    self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
        tenant_id=tenant_id,
        dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
        retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
        return_resource=app_config.additional_features.show_retrieve_source,
        invoke_from=application_generate_entity.invoke_from,
        hit_callback=hit_callback,
    )

    # Number of agent thoughts already stored for this message; used to position new ones
    self.agent_thought_count = (
        db.session.query(MessageAgentThought)
        .filter(
            MessageAgentThought.message_id == self.message.id,
        )
        .count()
    )
    db.session.close()

    llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
    model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
    model_features = (model_schema.features or []) if model_schema else []

    self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in model_features

    # Files are only forwarded when the model schema advertises the vision feature
    self.files = application_generate_entity.files if ModelFeature.VISION in model_features else []
    self.query: Optional[str] = None
    self._current_thoughts: list[PromptMessage] = []
|
|
|
def _repack_app_generate_entity(
    self, app_generate_entity: AgentChatAppGenerateEntity
) -> AgentChatAppGenerateEntity:
    """
    Ensure the entity's simple prompt template is a string.

    :param app_generate_entity: entity to normalize; it is modified in place
    :return: the same entity object that was passed in
    """
    template_config = app_generate_entity.app_config.prompt_template
    if template_config.simple_prompt_template is None:
        # A missing template is replaced by an empty string
        template_config.simple_prompt_template = ""

    return app_generate_entity
|
|
|
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
    """
    Convert an agent tool config into a prompt message tool and its runtime.

    Only parameters whose form is ``LLM`` are exposed in the generated schema,
    and file-typed parameters (system files / file / files) are left out.

    :param tool: agent tool configuration
    :return: tuple of (prompt message tool, tool runtime entity)
    """
    tool_entity = ToolManager.get_agent_tool_runtime(
        tenant_id=self.tenant_id,
        app_id=self.app_config.app_id,
        agent_tool=tool,
        invoke_from=self.application_generate_entity.invoke_from,
    )
    tool_entity.load_variables(self.variables_pool)

    message_tool = PromptMessageTool(
        name=tool.tool_name,
        description=tool_entity.description.llm,
        parameters={
            "type": "object",
            "properties": {},
            "required": [],
        },
    )

    file_parameter_types = {
        ToolParameter.ToolParameterType.SYSTEM_FILES,
        ToolParameter.ToolParameterType.FILE,
        ToolParameter.ToolParameterType.FILES,
    }

    for parameter in tool_entity.get_all_runtime_parameters():
        if parameter.form != ToolParameter.ToolParameterForm.LLM:
            continue
        if parameter.type in file_parameter_types:
            continue

        parameter_schema = {
            "type": parameter.type.as_normal_type(),
            "description": parameter.llm_description or "",
        }

        if parameter.type == ToolParameter.ToolParameterType.SELECT:
            # ``options`` may be unset on a select parameter; treat that as "no choices"
            enum = [option.value for option in parameter.options or []]
            if enum:
                parameter_schema["enum"] = enum

        message_tool.parameters["properties"][parameter.name] = parameter_schema

        # Same de-duplication as update_prompt_message_tool, so a repeated name is listed once
        if parameter.required and parameter.name not in message_tool.parameters["required"]:
            message_tool.parameters["required"].append(parameter.name)

    return message_tool, tool_entity
|
|
|
def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
    """
    Build the prompt message tool describing a dataset retriever tool.

    Every runtime parameter of the tool is exposed as a string property.

    :param tool: dataset retriever tool
    :return: prompt message tool with the parameter schema filled in
    """
    prompt_tool = PromptMessageTool(
        name=tool.identity.name,
        description=tool.description.llm,
        parameters={
            "type": "object",
            "properties": {},
            "required": [],
        },
    )

    for runtime_parameter in tool.get_runtime_parameters():
        prompt_tool.parameters["properties"][runtime_parameter.name] = {
            "type": "string",
            "description": runtime_parameter.llm_description or "",
        }

        if not runtime_parameter.required:
            continue

        # List each required parameter name only once
        required_names = prompt_tool.parameters["required"]
        if runtime_parameter.name not in required_names:
            required_names.append(runtime_parameter.name)

    return prompt_tool
|
|
|
def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
    """
    Collect the tool runtimes and prompt message tools available to the agent.

    Agent tools that cannot be converted are skipped (and logged) so one broken
    tool does not abort the run; dataset retriever tools are appended afterwards.

    :return: tuple of (tool instances keyed by tool name, prompt message tools)
    """
    tool_instances = {}
    prompt_messages_tools = []

    agent_tools = self.app_config.agent.tools if self.app_config.agent else []
    for tool in agent_tools:
        try:
            prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
        except Exception:
            # Best effort: keep going without this tool, but leave a trace of why it was dropped
            logging.getLogger(__name__).warning(
                "Skipping agent tool %r: failed to convert it to a prompt message tool",
                getattr(tool, "tool_name", tool),
                exc_info=True,
            )
            continue

        tool_instances[tool.tool_name] = tool_entity
        prompt_messages_tools.append(prompt_tool)

    for dataset_tool in self.dataset_tools:
        prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
        prompt_messages_tools.append(prompt_tool)
        tool_instances[dataset_tool.identity.name] = dataset_tool

    return tool_instances, prompt_messages_tools
|
|
|
def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
    """
    Refresh a prompt message tool from the tool's current runtime parameters.

    Only parameters whose form is ``LLM`` are written, and file-typed
    parameters (system files / file / files) are left out. Existing entries in
    ``prompt_tool.parameters`` for the same names are overwritten.

    :param tool: tool whose runtime parameters are read
    :param prompt_tool: prompt message tool to update in place
    :return: the same ``prompt_tool`` object
    """
    file_parameter_types = {
        ToolParameter.ToolParameterType.SYSTEM_FILES,
        ToolParameter.ToolParameterType.FILE,
        ToolParameter.ToolParameterType.FILES,
    }

    for parameter in tool.get_runtime_parameters() or []:
        if parameter.form != ToolParameter.ToolParameterForm.LLM:
            continue
        if parameter.type in file_parameter_types:
            continue

        parameter_schema = {
            "type": parameter.type.as_normal_type(),
            "description": parameter.llm_description or "",
        }

        if parameter.type == ToolParameter.ToolParameterType.SELECT:
            # ``options`` may be unset on a select parameter; treat that as "no choices"
            enum = [option.value for option in parameter.options or []]
            if enum:
                parameter_schema["enum"] = enum

        prompt_tool.parameters["properties"][parameter.name] = parameter_schema

        # Avoid listing the same required parameter twice when the tool is refreshed repeatedly
        if parameter.required and parameter.name not in prompt_tool.parameters["required"]:
            prompt_tool.parameters["required"].append(parameter.name)

    return prompt_tool
|
|
|
def create_agent_thought(
    self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
) -> MessageAgentThought:
    """
    Insert a new, mostly empty agent thought row for a message.

    Token counts, prices and latency start at zero and the text fields other
    than ``message`` / ``tool`` / ``tool_input`` start empty. The running
    ``agent_thought_count`` is incremented after the row is committed.

    :param message_id: id of the message the thought belongs to
    :param message: message text stored on the thought
    :param tool_name: tool name(s) stored on the thought
    :param tool_input: tool input stored on the thought
    :param messages_ids: message file ids; stored as a JSON list, or "" when empty
    :return: the persisted (refreshed) agent thought
    """
    serialized_file_ids = json.dumps(messages_ids) if messages_ids else ""
    next_position = self.agent_thought_count + 1

    new_thought = MessageAgentThought(
        message_id=message_id,
        message_chain_id=None,
        thought="",
        tool=tool_name,
        tool_labels_str="{}",
        tool_meta_str="{}",
        tool_input=tool_input,
        message=message,
        message_token=0,
        message_unit_price=0,
        message_price_unit=0,
        message_files=serialized_file_ids,
        answer="",
        observation="",
        answer_token=0,
        answer_unit_price=0,
        answer_price_unit=0,
        tokens=0,
        total_price=0,
        position=next_position,
        currency="USD",
        latency=0,
        created_by_role="account",
        created_by=self.user_id,
    )

    session = db.session
    session.add(new_thought)
    session.commit()
    # Reload the row before the session is closed
    session.refresh(new_thought)
    session.close()

    self.agent_thought_count += 1

    return new_thought
|
|
|
def save_agent_thought(
    self,
    agent_thought: MessageAgentThought,
    tool_name: str,
    tool_input: Union[str, dict],
    thought: str,
    observation: Union[str, dict],
    tool_invoke_meta: Union[str, dict],
    answer: str,
    messages_ids: list[str],
    llm_usage: Optional[LLMUsage] = None,
) -> None:
    """
    Update a stored agent thought with the latest results.

    The row is re-read by id, every argument that is not ``None`` overwrites the
    matching column, dict values are stored as JSON text, labels are resolved for
    each tool name in the ``;``-separated ``tool`` column, and the change is
    committed.

    :param agent_thought: agent thought whose id identifies the row to update
    :param tool_name: tool name(s), ``;``-separated
    :param tool_input: tool input, as text or as a dict to be JSON-encoded
    :param thought: thought text
    :param observation: tool observation, as text or as a dict to be JSON-encoded
    :param tool_invoke_meta: tool invocation metadata, as text or as a dict to be JSON-encoded
    :param answer: answer text
    :param messages_ids: message file ids; stored as a JSON list when non-empty
    :param llm_usage: optional usage figures copied onto the row
    :raises ValueError: if no agent thought with the given id exists
    """

    def _as_text(value: Union[str, dict]) -> str:
        """Return dict values as JSON text and pass anything else through unchanged."""
        if not isinstance(value, dict):
            return value
        try:
            return json.dumps(value, ensure_ascii=False)
        except Exception:
            return json.dumps(value)

    agent_thought_id = agent_thought.id
    agent_thought = (
        db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought_id).first()
    )
    if agent_thought is None:
        # Without this check the first attribute assignment below fails with an opaque AttributeError
        db.session.close()
        raise ValueError(f"agent thought {agent_thought_id} not found")

    if thought is not None:
        agent_thought.thought = thought

    if tool_name is not None:
        agent_thought.tool = tool_name

    if tool_input is not None:
        agent_thought.tool_input = _as_text(tool_input)

    if observation is not None:
        agent_thought.observation = _as_text(observation)

    if answer is not None:
        agent_thought.answer = answer

    if messages_ids is not None and len(messages_ids) > 0:
        agent_thought.message_files = json.dumps(messages_ids)

    if llm_usage:
        agent_thought.message_token = llm_usage.prompt_tokens
        agent_thought.message_price_unit = llm_usage.prompt_price_unit
        agent_thought.message_unit_price = llm_usage.prompt_unit_price
        agent_thought.answer_token = llm_usage.completion_tokens
        agent_thought.answer_price_unit = llm_usage.completion_price_unit
        agent_thought.answer_unit_price = llm_usage.completion_unit_price
        agent_thought.tokens = llm_usage.total_tokens
        agent_thought.total_price = llm_usage.total_price

    # Fill in a label for every tool that does not have one yet
    labels = agent_thought.tool_labels or {}
    tools = agent_thought.tool.split(";") if agent_thought.tool else []
    for tool in tools:
        if not tool or tool in labels:
            continue
        tool_label = ToolManager.get_tool_label(tool)
        if tool_label:
            labels[tool] = tool_label.to_dict()
        else:
            # No label available: fall back to the raw tool name for both languages
            labels[tool] = {"en_US": tool, "zh_Hans": tool}

    agent_thought.tool_labels_str = json.dumps(labels)

    if tool_invoke_meta is not None:
        agent_thought.tool_meta_str = _as_text(tool_invoke_meta)

    db.session.commit()
    db.session.close()
|
|
|
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
    """
    Persist the tool variable pool for the current conversation.

    The stored row is looked up by the conversation id of the current message,
    then its ``updated_at`` and ``variables_str`` columns are rewritten and the
    change is committed.

    :param tool_variables: variable pool whose ``pool`` is serialized to JSON
    :param db_variables: not read; the row is always re-queried. Kept so existing callers keep working.
    :raises ValueError: if the conversation has no stored tool variables row
    """
    conversation_id = self.message.conversation_id
    stored_variables = (
        db.session.query(ToolConversationVariables)
        .filter(
            ToolConversationVariables.conversation_id == conversation_id,
        )
        .first()
    )
    if stored_variables is None:
        # Without this check the assignment below fails with an opaque AttributeError
        db.session.close()
        raise ValueError(f"tool conversation variables not found for conversation {conversation_id}")

    # Stored as a naive timestamp derived from the current UTC time
    stored_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
    stored_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
    db.session.commit()
    db.session.close()
|
|
|
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
    """
    Rebuild the conversation history as a list of prompt messages.

    The result starts with the system messages found in ``prompt_messages``.
    Then, for every earlier message of the current thread (the message being
    generated is skipped), it appends the user prompt followed by either the
    stored agent thoughts (assistant messages with tool calls and the matching
    tool responses) or, when there are no thoughts, the plain answer.

    :param prompt_messages: prompt messages from which the system messages are kept
    :return: history prompt messages in chronological order
    """
    result: list[PromptMessage] = [
        prompt_message for prompt_message in prompt_messages if isinstance(prompt_message, SystemPromptMessage)
    ]

    messages: list[Message] = (
        db.session.query(Message)
        .filter(
            Message.conversation_id == self.message.conversation_id,
        )
        .order_by(Message.created_at.desc())
        .all()
    )

    # The query is newest-first; reverse the extracted thread for replay
    messages = list(reversed(extract_thread_messages(messages)))

    for message in messages:
        if message.id == self.message.id:
            continue

        result.append(self.organize_agent_user_prompt(message))
        agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
        if not agent_thoughts:
            if message.answer:
                result.append(AssistantPromptMessage(content=message.answer))
            continue

        for agent_thought in agent_thoughts:
            if not agent_thought.tool:
                result.append(AssistantPromptMessage(content=agent_thought.thought))
                continue

            tools = agent_thought.tool.split(";")

            # Stored values may be missing, plain text, or valid JSON that is not an
            # object (e.g. "42"); only a JSON object can be looked up per tool name.
            try:
                tool_inputs = json.loads(agent_thought.tool_input)
            except (TypeError, ValueError):
                tool_inputs = None
            if not isinstance(tool_inputs, dict):
                tool_inputs = {tool: {} for tool in tools}

            try:
                tool_responses = json.loads(agent_thought.observation)
            except (TypeError, ValueError):
                tool_responses = None
            if not isinstance(tool_responses, dict):
                tool_responses = dict.fromkeys(tools, agent_thought.observation)

            tool_calls: list[AssistantPromptMessage.ToolCall] = []
            tool_call_response: list[ToolPromptMessage] = []
            for tool in tools:
                # Generate a fresh id that links the call to its response
                tool_call_id = str(uuid.uuid4())
                tool_calls.append(
                    AssistantPromptMessage.ToolCall(
                        id=tool_call_id,
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool,
                            arguments=json.dumps(tool_inputs.get(tool, {})),
                        ),
                    )
                )
                tool_call_response.append(
                    ToolPromptMessage(
                        content=tool_responses.get(tool, agent_thought.observation),
                        name=tool,
                        tool_call_id=tool_call_id,
                    )
                )

            result.extend(
                [
                    AssistantPromptMessage(
                        content=agent_thought.thought,
                        tool_calls=tool_calls,
                    ),
                    *tool_call_response,
                ]
            )

    db.session.close()

    return result
|
|
|
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
    """
    Build the user prompt message for a stored message.

    The prompt is the plain query text unless the message has files, a file
    upload config can be derived from its app model config, and file objects
    can be built from them; in that case the content is the query text followed
    by one content item per file.

    :param message: stored message to convert
    :return: user prompt message for ``message``
    """
    message_files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
    if not message_files:
        return UserPromptMessage(content=message.query)

    upload_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
    if upload_config:
        file_objs = file_factory.build_from_message_files(
            message_files=message_files, tenant_id=self.tenant_id, config=upload_config
        )
    else:
        file_objs = []

    if not file_objs:
        return UserPromptMessage(content=message.query)

    # Text first, then the files in the order they were built
    contents: list[PromptMessageContent] = [TextPromptMessageContent(data=message.query)]
    contents.extend(file_manager.to_prompt_message_content(file_obj) for file_obj in file_objs)

    return UserPromptMessage(content=contents)
|
|