|
import json |
|
import logging |
|
from typing import Optional |
|
|
|
from httpx import get |
|
|
|
from core.model_runtime.utils.encoders import jsonable_encoder |
|
from core.tools.entities.api_entities import UserTool, UserToolProvider |
|
from core.tools.entities.common_entities import I18nObject |
|
from core.tools.entities.tool_bundle import ApiToolBundle |
|
from core.tools.entities.tool_entities import ( |
|
ApiProviderAuthType, |
|
ApiProviderSchemaType, |
|
ToolCredentialsOption, |
|
ToolProviderCredentials, |
|
) |
|
from core.tools.provider.api_tool_provider import ApiToolProviderController |
|
from core.tools.tool_label_manager import ToolLabelManager |
|
from core.tools.tool_manager import ToolManager |
|
from core.tools.utils.configuration import ToolConfigurationManager |
|
from core.tools.utils.parser import ApiBasedToolSchemaParser |
|
from extensions.ext_database import db |
|
from models.tools import ApiToolProvider |
|
from services.tools.tools_transform_service import ToolTransformService |
|
|
|
# Module-level logger for this service module.
logger = logging.getLogger(__name__)
|
|
|
|
|
class ApiToolManageService:
    """
    Service-layer operations for tenant-scoped API tool providers, i.e. providers
    whose tools are generated from an API schema parsed by ApiBasedToolSchemaParser.
    """

    @staticmethod
    def parser_api_schema(schema: str) -> dict:
        """
        Parse an API schema and describe it for the caller.

        :param schema: raw schema text
        :return: a JSON-encodable dict with keys ``schema_type``,
            ``parameters_schema`` (the parsed tool bundles), ``credentials_schema``
            (the credential form fields an API provider accepts) and ``warning``
            (warnings collected by the parser)
        :raises ValueError: if the schema cannot be parsed
        """
        # Passed to the parser so it can record warnings; returned under "warning".
        warnings: dict = {}
        try:
            tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
        except Exception as e:
            # Raised directly (not re-caught below) so the message is not prefixed twice.
            raise ValueError(f"invalid schema: {str(e)}") from e

        try:
            credentials_schema = [
                ToolProviderCredentials(
                    name="auth_type",
                    type=ToolProviderCredentials.CredentialsType.SELECT,
                    required=True,
                    default="none",
                    options=[
                        ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
                        ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
                    ],
                    placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
                ),
                ToolProviderCredentials(
                    name="api_key_header",
                    type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
                    required=False,
                    placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"),
                    default="api_key",
                    help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
                ),
                ToolProviderCredentials(
                    name="api_key_value",
                    type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
                    required=False,
                    placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
                    default="",
                ),
            ]

            return jsonable_encoder(
                {
                    "schema_type": schema_type,
                    "parameters_schema": tool_bundles,
                    "credentials_schema": credentials_schema,
                    "warning": warnings,
                }
            )
        except Exception as e:
            # Keep the historical contract: any failure surfaces as ValueError.
            raise ValueError(f"invalid schema: {str(e)}") from e

    @staticmethod
    def convert_schema_to_tool_bundles(
        schema: str, extra_info: Optional[dict] = None
    ) -> tuple[list[ApiToolBundle], str]:
        """
        Convert a schema into tool bundles.

        :param schema: raw schema text
        :param extra_info: optional dict handed to the parser; callers in this
            class read ``extra_info["description"]`` afterwards
        :return: a tuple of (tool bundles, detected schema type)
        :raises ValueError: if the schema cannot be parsed
        """
        try:
            return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
        except Exception as e:
            raise ValueError(f"invalid schema: {str(e)}") from e

    @staticmethod
    def create_api_tool_provider(
        user_id: str,
        tenant_id: str,
        provider_name: str,
        icon: dict,
        credentials: dict,
        schema_type: str,
        schema: str,
        privacy_policy: str,
        custom_disclaimer: str,
        labels: list[str],
    ):
        """
        Create and persist a new API tool provider for a tenant.

        :raises ValueError: on an unknown schema type, a duplicate provider name,
            an unparsable schema, more than 100 APIs, or missing ``auth_type``
        :return: ``{"result": "success"}``
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f"invalid schema type {schema_type}")

        # Provider names must be unique per tenant.
        existing: ApiToolProvider = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == provider_name,
            )
            .first()
        )
        if existing is not None:
            raise ValueError(f"provider {provider_name} already exists")

        # The parser fills extra_info (e.g. "description") while parsing; the
        # stored schema_type is the one detected by the parser.
        extra_info = {}
        tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)

        if len(tool_bundles) > 100:
            raise ValueError("the number of apis should be less than 100")

        db_provider = ApiToolProvider(
            tenant_id=tenant_id,
            user_id=user_id,
            name=provider_name,
            icon=json.dumps(icon),
            schema=schema,
            description=extra_info.get("description", ""),
            schema_type_str=schema_type,
            tools_str=json.dumps(jsonable_encoder(tool_bundles)),
            # Placeholder; replaced with the encrypted credentials below before saving.
            credentials_str={},
            privacy_policy=privacy_policy,
            custom_disclaimer=custom_disclaimer,
        )

        if "auth_type" not in credentials:
            raise ValueError("auth_type is required")

        auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])

        provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
        provider_controller.load_bundled_tools(tool_bundles)

        # Credentials are stored encrypted.
        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
        encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
        db_provider.credentials_str = json.dumps(encrypted_credentials)

        db.session.add(db_provider)
        db.session.commit()

        ToolLabelManager.update_tool_labels(provider_controller, labels)

        return {"result": "success"}

    @staticmethod
    def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str):
        """
        Download a schema from ``url`` and check that it parses.

        :return: ``{"schema": <downloaded text>}``
        :raises ValueError: if the download fails, returns a non-200 status,
            or the content is not a valid schema
        """
        headers = {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko)"
            " Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
            "Accept": "*/*",
        }

        try:
            # NOTE(review): `url` is fetched directly; if it is user-supplied,
            # consider routing through an SSRF-protected client.
            response = get(url, headers=headers, timeout=10)
            if response.status_code != 200:
                raise ValueError(f"Got status code {response.status_code}")
            schema = response.text

            # Validation only: raises if the content is not a parsable schema.
            ApiToolManageService.parser_api_schema(schema)
        except Exception as e:
            logger.exception("parse api schema error: %s", str(e))
            raise ValueError("invalid schema, please check the url you provided") from e

        return {"schema": schema}

    @staticmethod
    def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
        """
        List the tools of one API tool provider, each annotated with the provider's labels.

        :param provider: the provider name
        :raises ValueError: if the tenant has no provider with that name
        """
        # Kept separate from the `provider` name parameter so error messages can report it.
        db_provider: ApiToolProvider = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == provider,
            )
            .first()
        )

        if db_provider is None:
            raise ValueError(f"you have not added provider {provider}")

        controller = ToolTransformService.api_provider_to_controller(db_provider=db_provider)
        labels = ToolLabelManager.get_tool_labels(controller)

        return [
            ToolTransformService.tool_to_user_tool(
                tool_bundle,
                labels=labels,
            )
            for tool_bundle in db_provider.tools
        ]

    @staticmethod
    def update_api_tool_provider(
        user_id: str,
        tenant_id: str,
        provider_name: str,
        original_provider: str,
        icon: dict,
        credentials: dict,
        schema_type: str,
        schema: str,
        privacy_policy: str,
        custom_disclaimer: str,
        labels: list[str],
    ):
        """
        Update an existing API tool provider (looked up by ``original_provider``).

        Credential values equal to their masked form are treated as unchanged and
        replaced by the stored (decrypted) originals before re-encryption.

        :raises ValueError: on an unknown schema type, a missing provider, an
            unparsable schema, or missing ``auth_type``
        :return: ``{"result": "success"}``
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f"invalid schema type {schema_type}")

        provider: ApiToolProvider = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == original_provider,
            )
            .first()
        )

        if provider is None:
            raise ValueError(f"api provider {original_provider} does not exists")

        # The parser fills extra_info (e.g. "description") while parsing.
        extra_info = {}
        tool_bundles, _ = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)

        # Validate before touching the ORM object so a rejected request leaves it clean.
        if "auth_type" not in credentials:
            raise ValueError("auth_type is required")

        auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])

        provider.name = provider_name
        provider.icon = json.dumps(icon)
        provider.schema = schema
        provider.description = extra_info.get("description", "")
        # NOTE(review): create_api_tool_provider stores the parser-detected type,
        # while update always stores OPENAPI; confirm this difference is intended.
        provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
        provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
        provider.privacy_policy = privacy_policy
        provider.custom_disclaimer = custom_disclaimer

        provider_controller = ApiToolProviderController.from_db(provider, auth_type)
        provider_controller.load_bundled_tools(tool_bundles)

        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)

        # Values the client sent back still masked are restored from the stored credentials.
        original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
        masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)

        for name, value in credentials.items():
            if name in masked_credentials and value == masked_credentials[name]:
                credentials[name] = original_credentials[name]

        credentials = tool_configuration.encrypt_tool_credentials(credentials)
        provider.credentials_str = json.dumps(credentials)

        db.session.add(provider)
        db.session.commit()

        tool_configuration.delete_tool_credentials_cache()

        ToolLabelManager.update_tool_labels(provider_controller, labels)

        return {"result": "success"}

    @staticmethod
    def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
        """
        Delete a tenant's API tool provider by name.

        :raises ValueError: if the tenant has no provider with that name
        :return: ``{"result": "success"}``
        """
        provider: ApiToolProvider = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == provider_name,
            )
            .first()
        )

        if provider is None:
            raise ValueError(f"you have not added provider {provider_name}")

        db.session.delete(provider)
        db.session.commit()

        return {"result": "success"}

    @staticmethod
    def get_api_tool_provider(user_id: str, tenant_id: str, provider: str):
        """
        Get an API tool provider; delegates to ToolManager.user_get_api_provider.
        """
        return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)

    @staticmethod
    def test_api_tool_preview(
        tenant_id: str,
        provider_name: str,
        tool_name: str,
        credentials: dict,
        parameters: dict,
        schema_type: str,
        schema: str,
    ):
        """
        Invoke one tool from a schema before (or without) saving the provider.

        If a provider with ``provider_name`` already exists, credential values
        equal to their masked form are replaced by that provider's stored secrets.

        :raises ValueError: on an unknown schema type, an unparsable schema,
            an unknown tool name, or missing ``auth_type``
        :return: ``{"result": ...}`` on success, ``{"error": ...}`` if the call fails
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f"invalid schema type {schema_type}")

        try:
            tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
        except Exception as e:
            raise ValueError("invalid schema") from e

        tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None)
        if tool_bundle is None:
            raise ValueError(f"invalid tool name {tool_name}")

        db_provider: ApiToolProvider = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == provider_name,
            )
            .first()
        )

        if not db_provider:
            # Unsaved provider: build a transient (never added to the session) one
            # just to construct a controller.
            db_provider = ApiToolProvider(
                tenant_id="",
                user_id="",
                name="",
                icon="",
                schema=schema,
                description="",
                schema_type_str=ApiProviderSchemaType.OPENAPI.value,
                tools_str=json.dumps(jsonable_encoder(tool_bundles)),
                credentials_str=json.dumps(credentials),
            )

        if "auth_type" not in credentials:
            raise ValueError("auth_type is required")

        auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])

        provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
        provider_controller.load_bundled_tools(tool_bundles)

        # A saved provider has an id: restore any still-masked values from its
        # stored credentials (same approach as update_api_tool_provider).
        if db_provider.id:
            tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
            original_credentials = tool_configuration.decrypt_tool_credentials(db_provider.credentials)
            masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
            for name, value in credentials.items():
                if name in masked_credentials and value == masked_credentials[name]:
                    credentials[name] = original_credentials[name]

        try:
            provider_controller.validate_credentials_format(credentials)

            tool = provider_controller.get_tool(tool_name)
            tool = tool.fork_tool_runtime(
                runtime={
                    "credentials": credentials,
                    "tenant_id": tenant_id,
                }
            )
            result = tool.validate_credentials(credentials, parameters)
        except Exception as e:
            # Deliberately reported to the caller rather than raised.
            return {"error": str(e)}

        return {"result": result or "empty response"}

    @staticmethod
    def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
        """
        List every API tool provider of a tenant, each with its labels and tools.
        """
        db_providers: list[ApiToolProvider] = (
            db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
        )

        result: list[UserToolProvider] = []

        for provider in db_providers:
            provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
            labels = ToolLabelManager.get_tool_labels(provider_controller)
            user_provider = ToolTransformService.api_provider_to_user_provider(
                provider_controller, db_provider=provider, decrypt_credentials=True
            )
            user_provider.labels = labels

            ToolTransformService.repack_provider(user_provider)

            tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)

            for tool in tools:
                user_provider.tools.append(
                    ToolTransformService.tool_to_user_tool(
                        tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
                    )
                )

            result.append(user_provider)

        return result
|
|