|
from core.tools.entities.values import default_tool_label_name_list |
|
from core.tools.provider.api_tool_provider import ApiToolProviderController |
|
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController |
|
from core.tools.provider.tool_provider import ToolProviderController |
|
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController |
|
from extensions.ext_database import db |
|
from models.tools import ToolLabelBinding |
|
|
|
|
|
class ToolLabelManager:
    """
    Read and write the labels attached to tool providers.

    For API and workflow providers the labels are ToolLabelBinding rows reached
    through db.session; for builtin providers they are read from the controller.
    """

    @classmethod
    def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
        """
        Filter tool labels

        Keeps only the names found in default_tool_label_name_list and removes
        duplicates, preserving the order in which each name first appears.

        :param tool_labels: candidate label names
        :return: unique, recognised label names in input order
        """
        recognised = (label for label in tool_labels if label in default_tool_label_name_list)
        # dict.fromkeys de-duplicates while keeping insertion order; going
        # through set() would hand the labels back in an arbitrary order.
        return list(dict.fromkeys(recognised))

    @classmethod
    def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
        """
        Update tool labels

        Replaces the stored bindings of the provider with the filtered labels
        and commits the session.

        :param controller: API or workflow tool provider controller
        :param labels: new label names; they are passed through filter_tool_labels first
        :raises ValueError: if the controller is neither an API nor a workflow provider
        """
        labels = cls.filter_tool_labels(labels)

        if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
            raise ValueError("Unsupported tool type")
        provider_id = controller.provider_id

        # Drop the existing bindings before inserting the new set.
        # NOTE(review): rows are matched on tool_id only, whereas get_tool_labels
        # also filters on tool_type - confirm ids are unique across provider types.
        db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete()

        tool_type = controller.provider_type.value
        for label in labels:
            db.session.add(
                ToolLabelBinding(
                    tool_id=provider_id,
                    tool_type=tool_type,
                    label_name=label,
                )
            )

        db.session.commit()

    @classmethod
    def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
        """
        Get tool labels

        :param controller: tool provider controller
        :return: label names of the provider
        :raises ValueError: if the controller is not an API, workflow or builtin provider
        """
        if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
            provider_id = controller.provider_id
        elif isinstance(controller, BuiltinToolProviderController):
            # Builtin providers expose their labels on the controller; no query is issued.
            return controller.tool_labels
        else:
            raise ValueError("Unsupported tool type")

        # Only the label_name column is selected, so each result is a row
        # exposing .label_name rather than a full ToolLabelBinding instance.
        rows = (
            db.session.query(ToolLabelBinding.label_name)
            .filter(
                ToolLabelBinding.tool_id == provider_id,
                ToolLabelBinding.tool_type == controller.provider_type.value,
            )
            .all()
        )

        return [row.label_name for row in rows]

    @classmethod
    def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
        """
        Get tools labels

        Providers without any stored binding have no entry in the result.

        :param tool_providers: list of tool providers

        :return: dict of tool labels
            :key: tool id
            :value: list of tool labels
        :raises ValueError: if any provider is neither an API nor a workflow provider
        """
        if not tool_providers:
            return {}

        # Validate every provider up front so nothing is queried for a mixed list.
        if any(
            not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
            for controller in tool_providers
        ):
            raise ValueError("Unsupported tool type")

        provider_ids = [controller.provider_id for controller in tool_providers]

        bindings: list[ToolLabelBinding] = (
            db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
        )

        # Group label names by tool id in a single pass.
        tool_labels: dict[str, list[str]] = {}
        for binding in bindings:
            tool_labels.setdefault(binding.tool_id, []).append(binding.label_name)

        return tool_labels
|
|