|
import datetime
import json
import logging
from collections import defaultdict
from collections.abc import Iterator, ValuesView
from json import JSONDecodeError
from typing import Optional

from pydantic import BaseModel, ConfigDict

from constants import HIDDEN_VALUE
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import (
    CustomConfiguration,
    ModelSettings,
    SystemConfiguration,
    SystemConfigurationStatus,
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import FetchFrom, ModelType
from core.model_runtime.entities.provider_entities import (
    ConfigurateMethod,
    CredentialFormSchema,
    FormType,
    ProviderEntity,
)
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from extensions.ext_database import db
from models.provider import (
    LoadBalancingModelConfig,
    Provider,
    ProviderModel,
    ProviderModelSetting,
    ProviderType,
    TenantPreferredModelProvider,
)
|
|
|
logger = logging.getLogger(__name__)

# Maps provider name -> that provider's configurate methods as first observed, i.e. before
# ProviderConfiguration.__init__ may append PREDEFINED_MODEL to the provider entity.
# Module-level and shared by every ProviderConfiguration instance; nothing in this module clears it.
original_provider_configurate_methods: "dict[str, list[ConfigurateMethod]]" = {}
|
|
|
|
|
class ProviderConfiguration(BaseModel):
    """
    Model class for provider configuration.

    Holds one model provider's configuration for one tenant: the provider schema,
    the preferred and the actually used provider type (system / custom), the system
    (quota based) configuration, the tenant's custom credentials and per-model
    settings. Several methods read or write the related database rows
    (``Provider``, ``ProviderModel``, ``ProviderModelSetting``,
    ``TenantPreferredModelProvider``) and invalidate credential caches.
    """

    tenant_id: str
    provider: ProviderEntity
    preferred_provider_type: ProviderType
    using_provider_type: ProviderType
    system_configuration: SystemConfiguration
    custom_configuration: CustomConfiguration
    model_settings: list[ModelSettings]

    # pydantic configs: allow field names starting with "model_" (e.g. ``model_settings``).
    model_config = ConfigDict(protected_namespaces=())

    def __init__(self, **data):
        super().__init__(**data)

        # Snapshot the provider's configurate methods the first time this provider is seen.
        # A copy is stored because the provider's own list may be appended to below.
        if self.provider.provider not in original_provider_configurate_methods:
            original_provider_configurate_methods[self.provider.provider] = list(self.provider.configurate_methods)

        # A provider that originally supports only customizable models also advertises
        # PREDEFINED_MODEL when some quota restricts it to a fixed list of models; those
        # restricted models are listed as predefined in ``_get_system_provider_models``.
        if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
            if (
                any(
                    len(quota_configuration.restrict_models) > 0
                    for quota_configuration in self.system_configuration.quota_configurations
                )
                and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
            ):
                self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)

    @staticmethod
    def _naive_utc_now() -> datetime.datetime:
        """Return the current UTC time with tzinfo stripped, the form written to ``updated_at`` columns."""
        return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

    def _system_credentials_copy(self) -> dict:
        """
        Return a shallow copy of the system credentials, or an empty dict when none are set.

        A copy is returned so callers can add keys (e.g. ``base_model_name``) without
        mutating the shared system configuration.
        """
        credentials = self.system_configuration.credentials
        return credentials.copy() if credentials else {}

    def _delete_credentials_cache(self, identity_id, cache_type: ProviderCredentialsCacheType) -> None:
        """
        Delete the cached credentials of one provider / provider-model record.

        :param identity_id: id of the Provider or ProviderModel record
        :param cache_type: which credentials cache the record belongs to
        """
        ProviderCredentialsCache(tenant_id=self.tenant_id, identity_id=identity_id, cache_type=cache_type).delete()

    def _restore_hidden_secrets(
        self, credentials: dict, original_credentials: dict, secret_variables: list[str]
    ) -> None:
        """
        Replace, in place, secret fields submitted as HIDDEN_VALUE with the decrypted stored value.

        The UI sends HIDDEN_VALUE for a secret field the user did not change, so the stored
        (encrypted) value is decrypted and used instead.

        :param credentials: submitted credentials; modified in place
        :param original_credentials: stored credentials whose secret fields are encrypted
        :param secret_variables: names of secret-input fields
        """
        for key, value in credentials.items():
            if key in secret_variables and value == HIDDEN_VALUE and key in original_credentials:
                credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

    def _encrypt_secrets(self, credentials: dict, secret_variables: list[str]) -> None:
        """
        Encrypt, in place, every secret field of ``credentials`` with ``encrypter.encrypt_token`` for this tenant.

        :param credentials: credentials to encrypt; modified in place
        :param secret_variables: names of secret-input fields
        """
        for key, value in credentials.items():
            if key in secret_variables:
                credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

    def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
        """
        Get current credentials.

        In system mode the system credentials are returned (copied), with ``base_model_name``
        added when the current quota restricts this model to a base model. In custom mode the
        model-level credentials are preferred, falling back to provider-level credentials.

        :param model_type: model type
        :param model: model name
        :return: credentials, or None in custom mode when nothing is configured
        :raises ValueError: if the model is disabled in the model settings
        """
        for model_setting in self.model_settings:
            if model_setting.model_type == model_type and model_setting.model == model:
                if not model_setting.enabled:
                    raise ValueError(f"Model {model} is disabled.")

        if self.using_provider_type == ProviderType.SYSTEM:
            restrict_models = []
            for quota_configuration in self.system_configuration.quota_configurations:
                if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                    continue

                restrict_models = quota_configuration.restrict_models

            # System credentials may be unset; start from an empty dict rather than crash.
            copy_credentials = self._system_credentials_copy()
            for restrict_model in restrict_models:
                if (
                    restrict_model.model_type == model_type
                    and restrict_model.model == model
                    and restrict_model.base_model_name
                ):
                    copy_credentials["base_model_name"] = restrict_model.base_model_name

            return copy_credentials

        credentials = None
        for model_configuration in self.custom_configuration.models:
            if model_configuration.model_type == model_type and model_configuration.model == model:
                credentials = model_configuration.credentials
                break

        if not credentials and self.custom_configuration.provider:
            credentials = self.custom_configuration.provider.credentials

        return credentials

    def get_system_configuration_status(self) -> SystemConfigurationStatus:
        """
        Get system configuration status.

        :return: UNSUPPORTED when the system configuration is disabled or has no quota
            configuration for the current quota type; ACTIVE when that quota is valid;
            QUOTA_EXCEEDED otherwise
        """
        if self.system_configuration.enabled is False:
            return SystemConfigurationStatus.UNSUPPORTED

        current_quota_type = self.system_configuration.current_quota_type
        current_quota_configuration = next(
            (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
        )
        # No quota matches the current quota type, so system mode cannot be used.
        if current_quota_configuration is None:
            return SystemConfigurationStatus.UNSUPPORTED

        return (
            SystemConfigurationStatus.ACTIVE
            if current_quota_configuration.is_valid
            else SystemConfigurationStatus.QUOTA_EXCEEDED
        )

    def is_custom_configuration_available(self) -> bool:
        """
        Check custom configuration available.

        :return: True if provider-level custom credentials or any custom model is configured
        """
        return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0

    def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Get custom credentials.

        :param obfuscated: obfuscated secret data in credentials
        :return: provider-level custom credentials, or None when not configured
        """
        if self.custom_configuration.provider is None:
            return None

        credentials = self.custom_configuration.provider.credentials
        if not obfuscated:
            return credentials

        # Obfuscate provider credentials
        return self.obfuscated_credentials(
            credentials=credentials,
            credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
            if self.provider.provider_credential_schema
            else [],
        )

    def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
        """
        Validate custom credentials.

        Secret fields submitted as HIDDEN_VALUE are replaced by the stored value, the result
        is validated by the model runtime, and the secret fields of the validated credentials
        are encrypted.

        :param credentials: provider credentials; not modified
        :return: (existing custom Provider record or None, validated credentials with secrets encrypted)
        """
        # get provider
        provider_record = (
            db.session.query(Provider)
            .filter(
                Provider.tenant_id == self.tenant_id,
                Provider.provider_name == self.provider.provider,
                Provider.provider_type == ProviderType.CUSTOM.value,
            )
            .first()
        )

        # Get provider credential secret variables
        provider_credential_secret_variables = self.extract_secret_variables(
            self.provider.provider_credential_schema.credential_form_schemas
            if self.provider.provider_credential_schema
            else []
        )

        # Work on a copy so the caller's dict is not rewritten in place.
        credentials = dict(credentials)

        if provider_record:
            try:
                if provider_record.encrypted_config:
                    # A non-JSON config is treated as a bare value of the ``openai_api_key``
                    # field (presumably a legacy storage format).
                    if not provider_record.encrypted_config.startswith("{"):
                        original_credentials = {"openai_api_key": provider_record.encrypted_config}
                    else:
                        original_credentials = json.loads(provider_record.encrypted_config)
                else:
                    original_credentials = {}
            except JSONDecodeError:
                original_credentials = {}

            self._restore_hidden_secrets(credentials, original_credentials, provider_credential_secret_variables)

        credentials = model_provider_factory.provider_credentials_validate(
            provider=self.provider.provider, credentials=credentials
        )

        self._encrypt_secrets(credentials, provider_credential_secret_variables)

        return provider_record, credentials

    def add_or_update_custom_credentials(self, credentials: dict) -> None:
        """
        Add or update custom provider credentials.

        Validates and encrypts the credentials, upserts the custom Provider record,
        invalidates its credentials cache and switches the preferred provider type to custom.

        :param credentials: provider credentials
        :return:
        """
        # validate custom provider config
        provider_record, credentials = self.custom_credentials_validate(credentials)

        # save provider
        if provider_record:
            provider_record.encrypted_config = json.dumps(credentials)
            provider_record.is_valid = True
            provider_record.updated_at = self._naive_utc_now()
            db.session.commit()
        else:
            provider_record = Provider(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                provider_type=ProviderType.CUSTOM.value,
                encrypted_config=json.dumps(credentials),
                is_valid=True,
            )
            db.session.add(provider_record)
            db.session.commit()

        self._delete_credentials_cache(provider_record.id, ProviderCredentialsCacheType.PROVIDER)

        self.switch_preferred_provider_type(ProviderType.CUSTOM)

    def delete_custom_credentials(self) -> None:
        """
        Delete custom provider credentials.

        If a custom Provider record exists, switches the preferred provider type to system,
        deletes the record and invalidates its credentials cache.

        :return:
        """
        # get provider
        provider_record = (
            db.session.query(Provider)
            .filter(
                Provider.tenant_id == self.tenant_id,
                Provider.provider_name == self.provider.provider,
                Provider.provider_type == ProviderType.CUSTOM.value,
            )
            .first()
        )

        # delete provider
        if provider_record:
            self.switch_preferred_provider_type(ProviderType.SYSTEM)

            # Read the id before the row is deleted; it is still needed for the cache key.
            provider_record_id = provider_record.id
            db.session.delete(provider_record)
            db.session.commit()

            self._delete_credentials_cache(provider_record_id, ProviderCredentialsCacheType.PROVIDER)

    def get_custom_model_credentials(
        self, model_type: ModelType, model: str, obfuscated: bool = False
    ) -> Optional[dict]:
        """
        Get custom model credentials.

        :param model_type: model type
        :param model: model name
        :param obfuscated: obfuscated secret data in credentials
        :return: the model's custom credentials, or None when the model is not configured
        """
        if not self.custom_configuration.models:
            return None

        for model_configuration in self.custom_configuration.models:
            if model_configuration.model_type == model_type and model_configuration.model == model:
                credentials = model_configuration.credentials
                if not obfuscated:
                    return credentials

                # Obfuscate credentials
                return self.obfuscated_credentials(
                    credentials=credentials,
                    credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
                    if self.provider.model_credential_schema
                    else [],
                )

        return None

    def custom_model_credentials_validate(
        self, model_type: ModelType, model: str, credentials: dict
    ) -> tuple[Optional[ProviderModel], dict]:
        """
        Validate custom model credentials.

        Secret fields submitted as HIDDEN_VALUE are replaced by the stored value, the result
        is validated by the model runtime, and the secret fields of the validated credentials
        are encrypted.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials; not modified
        :return: (existing ProviderModel record or None, validated credentials with secrets encrypted)
        """
        # get provider model
        provider_model_record = (
            db.session.query(ProviderModel)
            .filter(
                ProviderModel.tenant_id == self.tenant_id,
                ProviderModel.provider_name == self.provider.provider,
                ProviderModel.model_name == model,
                ProviderModel.model_type == model_type.to_origin_model_type(),
            )
            .first()
        )

        # Get provider credential secret variables
        provider_credential_secret_variables = self.extract_secret_variables(
            self.provider.model_credential_schema.credential_form_schemas
            if self.provider.model_credential_schema
            else []
        )

        # Work on a copy so the caller's dict is not rewritten in place.
        credentials = dict(credentials)

        if provider_model_record:
            try:
                original_credentials = (
                    json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
                )
            except JSONDecodeError:
                original_credentials = {}

            self._restore_hidden_secrets(credentials, original_credentials, provider_credential_secret_variables)

        credentials = model_provider_factory.model_credentials_validate(
            provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
        )

        self._encrypt_secrets(credentials, provider_credential_secret_variables)

        return provider_model_record, credentials

    def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
        """
        Add or update custom model credentials.

        Validates and encrypts the credentials, upserts the ProviderModel record and
        invalidates its credentials cache.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials
        :return:
        """
        # validate custom model config
        provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)

        # save provider model
        if provider_model_record:
            provider_model_record.encrypted_config = json.dumps(credentials)
            provider_model_record.is_valid = True
            provider_model_record.updated_at = self._naive_utc_now()
            db.session.commit()
        else:
            provider_model_record = ProviderModel(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_name=model,
                model_type=model_type.to_origin_model_type(),
                encrypted_config=json.dumps(credentials),
                is_valid=True,
            )
            db.session.add(provider_model_record)
            db.session.commit()

        self._delete_credentials_cache(provider_model_record.id, ProviderCredentialsCacheType.MODEL)

    def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
        """
        Delete custom model credentials.

        :param model_type: model type
        :param model: model name
        :return:
        """
        # get provider model
        provider_model_record = (
            db.session.query(ProviderModel)
            .filter(
                ProviderModel.tenant_id == self.tenant_id,
                ProviderModel.provider_name == self.provider.provider,
                ProviderModel.model_name == model,
                ProviderModel.model_type == model_type.to_origin_model_type(),
            )
            .first()
        )

        # delete provider model
        if provider_model_record:
            # Read the id before the row is deleted; it is still needed for the cache key.
            provider_model_record_id = provider_model_record.id
            db.session.delete(provider_model_record)
            db.session.commit()

            self._delete_credentials_cache(provider_model_record_id, ProviderCredentialsCacheType.MODEL)

    def _upsert_provider_model_setting(self, model_type: ModelType, model: str, **fields) -> ProviderModelSetting:
        """
        Write ``fields`` to this model's ProviderModelSetting row, creating the row if missing, then commit.

        An existing row also gets ``updated_at`` refreshed; a new row is created with only
        its identifying columns plus ``fields``.

        :param model_type: model type
        :param model: model name
        :param fields: column values to write, e.g. ``enabled=True``
        :return: the updated or newly created setting
        """
        model_setting = self.get_provider_model_setting(model_type, model)

        if model_setting:
            for name, value in fields.items():
                setattr(model_setting, name, value)
            model_setting.updated_at = self._naive_utc_now()
        else:
            model_setting = ProviderModelSetting(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_type=model_type.to_origin_model_type(),
                model_name=model,
                **fields,
            )
            db.session.add(model_setting)

        db.session.commit()

        return model_setting

    def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Enable model.

        :param model_type: model type
        :param model: model name
        :return: the updated or newly created model setting
        """
        return self._upsert_provider_model_setting(model_type, model, enabled=True)

    def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Disable model.

        :param model_type: model type
        :param model: model name
        :return: the updated or newly created model setting
        """
        return self._upsert_provider_model_setting(model_type, model, enabled=False)

    def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
        """
        Get provider model setting.

        :param model_type: model type
        :param model: model name
        :return: the tenant's setting row for this model, or None
        """
        return (
            db.session.query(ProviderModelSetting)
            .filter(
                ProviderModelSetting.tenant_id == self.tenant_id,
                ProviderModelSetting.provider_name == self.provider.provider,
                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
                ProviderModelSetting.model_name == model,
            )
            .first()
        )

    def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Enable model load balancing.

        :param model_type: model type
        :param model: model name
        :return: the updated or newly created model setting
        :raises ValueError: if the model has at most one load balancing config
        """
        load_balancing_config_count = (
            db.session.query(LoadBalancingModelConfig)
            .filter(
                LoadBalancingModelConfig.tenant_id == self.tenant_id,
                LoadBalancingModelConfig.provider_name == self.provider.provider,
                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model,
            )
            .count()
        )

        if load_balancing_config_count <= 1:
            raise ValueError("Model load balancing configuration must be more than 1.")

        return self._upsert_provider_model_setting(model_type, model, load_balancing_enabled=True)

    def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Disable model load balancing.

        :param model_type: model type
        :param model: model name
        :return: the updated or newly created model setting
        """
        return self._upsert_provider_model_setting(model_type, model, load_balancing_enabled=False)

    def get_provider_instance(self) -> ModelProvider:
        """
        Get provider instance.

        :return: the model runtime provider instance for this provider
        """
        return model_provider_factory.get_provider_instance(self.provider.provider)

    def get_model_type_instance(self, model_type: ModelType) -> AIModel:
        """
        Get current model type instance.

        :param model_type: model type
        :return: the provider's model instance for ``model_type``
        """
        # Get provider instance
        provider_instance = self.get_provider_instance()

        # Get model instance of the given model type
        return provider_instance.get_model_instance(model_type)

    def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
        """
        Switch preferred provider type.

        Upserts the tenant's TenantPreferredModelProvider row and commits. Does nothing if the
        type is already preferred, or if switching to system while the system configuration
        is disabled. The in-memory ``preferred_provider_type`` is not updated.

        :param provider_type: provider type to prefer
        :return:
        """
        if provider_type == self.preferred_provider_type:
            return

        if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
            return

        # get preferred provider
        preferred_model_provider = (
            db.session.query(TenantPreferredModelProvider)
            .filter(
                TenantPreferredModelProvider.tenant_id == self.tenant_id,
                TenantPreferredModelProvider.provider_name == self.provider.provider,
            )
            .first()
        )

        if preferred_model_provider:
            preferred_model_provider.preferred_provider_type = provider_type.value
        else:
            preferred_model_provider = TenantPreferredModelProvider(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                preferred_provider_type=provider_type.value,
            )
            db.session.add(preferred_model_provider)

        db.session.commit()

    def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
        """
        Extract secret input form variables.

        :param credential_form_schemas: credential form schemas
        :return: variable names of the SECRET_INPUT fields, in schema order
        """
        return [
            credential_form_schema.variable
            for credential_form_schema in credential_form_schemas
            if credential_form_schema.type == FormType.SECRET_INPUT
        ]

    def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
        """
        Obfuscated credentials.

        :param credentials: credentials; not modified
        :param credential_form_schemas: credential form schemas
        :return: a new dict in which secret fields are replaced by ``encrypter.obfuscated_token`` output
        """
        # Get provider credential secret variables
        credential_secret_variables = self.extract_secret_variables(credential_form_schemas)

        # Obfuscate secret fields only; other fields are copied as-is.
        return {
            key: encrypter.obfuscated_token(value) if key in credential_secret_variables else value
            for key, value in credentials.items()
        }

    def get_provider_model(
        self, model_type: ModelType, model: str, only_active: bool = False
    ) -> Optional[ModelWithProviderEntity]:
        """
        Get provider model.

        :param model_type: model type
        :param model: model name
        :param only_active: return active model only
        :return: the first matching model, or None
        """
        provider_models = self.get_provider_models(model_type, only_active)

        return next((provider_model for provider_model in provider_models if provider_model.model == model), None)

    def get_provider_models(
        self, model_type: Optional[ModelType] = None, only_active: bool = False
    ) -> list[ModelWithProviderEntity]:
        """
        Get provider models.

        :param model_type: model type; all of the provider's supported types when omitted
        :param only_active: only active models
        :return: models sorted by model type value
        """
        provider_instance = self.get_provider_instance()

        model_types = [model_type] if model_type else provider_instance.get_provider_schema().supported_model_types

        # Group model settings by model type and model
        model_setting_map = defaultdict(dict)
        for model_setting in self.model_settings:
            model_setting_map[model_setting.model_type][model_setting.model] = model_setting

        if self.using_provider_type == ProviderType.SYSTEM:
            provider_models = self._get_system_provider_models(
                model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
            )
        else:
            provider_models = self._get_custom_provider_models(
                model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
            )

        if only_active:
            provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]

        # resort provider_models
        return sorted(provider_models, key=lambda x: x.model_type.value)

    def _get_system_provider_models(
        self,
        model_types: list[ModelType],
        provider_instance: ModelProvider,
        model_setting_map: dict[ModelType, dict[str, ModelSettings]],
    ) -> list[ModelWithProviderEntity]:
        """
        Get system provider models.

        Lists the predefined models of each requested type. For the current quota type, if it
        restricts models and the provider originally supports only customizable models, the
        restricted models are built from the system credentials and added. Afterwards LLMs not
        in the restricted list are marked NO_PERMISSION and, when the quota is invalid, every
        other model is marked QUOTA_EXCEEDED. Nothing is marked when the quota has no
        restrictions.

        :param model_types: model types
        :param provider_instance: provider instance
        :param model_setting_map: model setting map
        :return:
        """
        provider_models = []
        for model_type in model_types:
            for m in provider_instance.models(model_type):
                status = ModelStatus.ACTIVE
                if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
                    model_setting = model_setting_map[m.model_type][m.model]
                    if model_setting.enabled is False:
                        status = ModelStatus.DISABLED

                provider_models.append(
                    ModelWithProviderEntity(
                        model=m.model,
                        label=m.label,
                        model_type=m.model_type,
                        features=m.features,
                        fetch_from=m.fetch_from,
                        model_properties=m.model_properties,
                        deprecated=m.deprecated,
                        provider=SimpleModelProviderEntity(self.provider),
                        status=status,
                    )
                )

        if self.provider.provider not in original_provider_configurate_methods:
            original_provider_configurate_methods[self.provider.provider] = list(
                provider_instance.get_provider_schema().configurate_methods
            )

        # Only providers that originally support nothing but customizable models get their
        # restricted models materialized from the system credentials.
        should_use_custom_model = original_provider_configurate_methods[self.provider.provider] == [
            ConfigurateMethod.CUSTOMIZABLE_MODEL
        ]

        for quota_configuration in self.system_configuration.quota_configurations:
            if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                continue

            restrict_models = quota_configuration.restrict_models
            if len(restrict_models) == 0:
                break

            if should_use_custom_model:
                for restrict_model in restrict_models:
                    # System credentials may be unset; an empty dict lets the schema lookup
                    # fail inside the try below instead of aborting the whole listing.
                    copy_credentials = self._system_credentials_copy()
                    if restrict_model.base_model_name:
                        copy_credentials["base_model_name"] = restrict_model.base_model_name

                    try:
                        custom_model_schema = provider_instance.get_model_instance(
                            restrict_model.model_type
                        ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
                    except Exception as ex:
                        logger.warning("get custom model schema failed, %s", ex)
                        continue

                    if not custom_model_schema:
                        continue

                    if custom_model_schema.model_type not in model_types:
                        continue

                    status = ModelStatus.ACTIVE
                    if (
                        custom_model_schema.model_type in model_setting_map
                        and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
                    ):
                        model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
                        if model_setting.enabled is False:
                            status = ModelStatus.DISABLED

                    provider_models.append(
                        ModelWithProviderEntity(
                            model=custom_model_schema.model,
                            label=custom_model_schema.label,
                            model_type=custom_model_schema.model_type,
                            features=custom_model_schema.features,
                            fetch_from=FetchFrom.PREDEFINED_MODEL,
                            model_properties=custom_model_schema.model_properties,
                            deprecated=custom_model_schema.deprecated,
                            provider=SimpleModelProviderEntity(self.provider),
                            status=status,
                        )
                    )

            # if llm name not in restricted llm list, mark it as not permitted
            restrict_model_names = {rm.model for rm in restrict_models}
            for m in provider_models:
                if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
                    m.status = ModelStatus.NO_PERMISSION
                elif not quota_configuration.is_valid:
                    m.status = ModelStatus.QUOTA_EXCEEDED

        return provider_models

    def _get_custom_provider_models(
        self,
        model_types: list[ModelType],
        provider_instance: ModelProvider,
        model_setting_map: dict[ModelType, dict[str, ModelSettings]],
    ) -> list[ModelWithProviderEntity]:
        """
        Get custom provider models.

        Lists predefined models of the supported requested types (NO_CONFIGURE when no
        provider-level credentials exist), then appends the tenant's custom models whose
        schema can be built from their credentials. Load balancing is reported enabled when
        a model setting has more than one load balancing config.

        :param model_types: model types
        :param provider_instance: provider instance
        :param model_setting_map: model setting map
        :return:
        """
        provider_models = []

        credentials = None
        if self.custom_configuration.provider:
            credentials = self.custom_configuration.provider.credentials

        for model_type in model_types:
            if model_type not in self.provider.supported_model_types:
                continue

            models = provider_instance.models(model_type)
            for m in models:
                status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
                load_balancing_enabled = False
                if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
                    model_setting = model_setting_map[m.model_type][m.model]
                    if model_setting.enabled is False:
                        status = ModelStatus.DISABLED

                    if len(model_setting.load_balancing_configs) > 1:
                        load_balancing_enabled = True

                provider_models.append(
                    ModelWithProviderEntity(
                        model=m.model,
                        label=m.label,
                        model_type=m.model_type,
                        features=m.features,
                        fetch_from=m.fetch_from,
                        model_properties=m.model_properties,
                        deprecated=m.deprecated,
                        provider=SimpleModelProviderEntity(self.provider),
                        status=status,
                        load_balancing_enabled=load_balancing_enabled,
                    )
                )

        # custom models
        for model_configuration in self.custom_configuration.models:
            if model_configuration.model_type not in model_types:
                continue

            try:
                custom_model_schema = provider_instance.get_model_instance(
                    model_configuration.model_type
                ).get_customizable_model_schema_from_credentials(
                    model_configuration.model, model_configuration.credentials
                )
            except Exception as ex:
                logger.warning("get custom model schema failed, %s", ex)
                continue

            if not custom_model_schema:
                continue

            status = ModelStatus.ACTIVE
            load_balancing_enabled = False
            if (
                custom_model_schema.model_type in model_setting_map
                and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
            ):
                model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
                if model_setting.enabled is False:
                    status = ModelStatus.DISABLED

                if len(model_setting.load_balancing_configs) > 1:
                    load_balancing_enabled = True

            provider_models.append(
                ModelWithProviderEntity(
                    model=custom_model_schema.model,
                    label=custom_model_schema.label,
                    model_type=custom_model_schema.model_type,
                    features=custom_model_schema.features,
                    fetch_from=custom_model_schema.fetch_from,
                    model_properties=custom_model_schema.model_properties,
                    deprecated=custom_model_schema.deprecated,
                    provider=SimpleModelProviderEntity(self.provider),
                    status=status,
                    load_balancing_enabled=load_balancing_enabled,
                )
            )

        return provider_models
|
|
|
|
|
class ProviderConfigurations(BaseModel):
    """
    Model class for provider configuration dict.

    Maps provider name to ProviderConfiguration for one tenant and offers a small
    dict-like interface: item get/set, iteration over provider names, ``values()``
    and ``get()``.
    """

    tenant_id: str
    configurations: dict[str, ProviderConfiguration] = {}

    def __init__(self, tenant_id: str):
        super().__init__(tenant_id=tenant_id)

    def get_models(
        self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
    ) -> list[ModelWithProviderEntity]:
        """
        Get available models.

        If preferred provider type is `system`:
          Get the current **system mode** if provider supported,
          if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
          If there is no model configured in custom mode, it is treated as no_configure.
        system > custom > no_configure

        If preferred provider type is `custom`:
          If custom credentials are configured, it is treated as custom mode.
          Otherwise, get the current **system mode** if supported,
          If all system modes are not available (no quota), it is treated as no_configure.
        custom > system > no_configure

        If real mode is `system`, use system credentials to get models,
          paid quotas > provider free quotas > system free quotas
          include pre-defined models (exclude GPT-4, status marked as `no_permission`).
        If real mode is `custom`, use workspace custom credentials to get models,
          include pre-defined models, custom models (manually appended).
        If real mode is `no_configure`, only return pre-defined models from `model runtime`.
          (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
        model status marked as `active` is available.

        :param provider: provider name; all providers when omitted
        :param model_type: model type
        :param only_active: only active models
        :return: models of every matching provider, concatenated in configuration order
        """
        all_models = []
        for provider_configuration in self.values():
            if provider and provider_configuration.provider.provider != provider:
                continue

            all_models.extend(provider_configuration.get_provider_models(model_type, only_active))

        return all_models

    def to_list(self) -> list[ProviderConfiguration]:
        """
        Convert to list.

        :return: the provider configurations, in insertion order
        """
        return list(self.values())

    def __getitem__(self, key):
        return self.configurations[key]

    def __setitem__(self, key, value):
        self.configurations[key] = value

    def __iter__(self):
        # Iterates provider names (the dict keys), overriding pydantic's field iteration.
        return iter(self.configurations)

    def values(self) -> ValuesView[ProviderConfiguration]:
        # A live view of the dict values; it is iterable but not an iterator.
        return self.configurations.values()

    def get(self, key, default=None):
        return self.configurations.get(key, default)
|
|
|
|
|
class ProviderModelBundle(BaseModel):
    """
    Provider model bundle.

    Bundles a provider configuration together with the provider runtime instance
    and a model-type instance.
    """

    # ``arbitrary_types_allowed`` lets pydantic accept field types it cannot build a
    # schema for (presumably ModelProvider / AIModel — confirm); ``protected_namespaces=()``
    # permits the ``model_type_instance`` field name.
    model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)

    configuration: ProviderConfiguration
    provider_instance: ModelProvider
    model_type_instance: AIModel
|
|