|
import logging |
|
from typing import Optional |
|
|
|
from core.app.app_config.entities import AppConfig |
|
from core.moderation.base import ModerationAction, ModerationError |
|
from core.moderation.factory import ModerationFactory |
|
from core.ops.entities.trace_entity import TraceTaskName |
|
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask |
|
from core.ops.utils import measure_time |
|
|
|
# Module-scoped logger, keyed by this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
class InputModeration:
    """Apply an app's sensitive-word-avoidance moderation to user inputs and query."""

    def check(
        self,
        app_id: str,
        tenant_id: str,
        app_config: AppConfig,
        inputs: dict,
        query: str,
        message_id: str,
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> tuple[bool, dict, str]:
        """
        Run the configured sensitive-word-avoidance moderation over inputs and query.

        :param app_id: id of the app; forwarded to ModerationFactory
        :param tenant_id: id of the tenant; forwarded to ModerationFactory
        :param app_config: app config whose ``sensitive_word_avoidance`` entry
            supplies the moderation ``type`` and its ``config``
        :param inputs: user inputs to moderate
        :param query: user query to moderate
        :param message_id: message id attached to the moderation trace task
        :param trace_manager: optional trace manager; when truthy, a
            MODERATION_TRACE task carrying the result and timer is enqueued
        :return: ``(flagged, inputs, query)``. ``flagged`` is False when no
            moderation is configured or the result is not flagged. When flagged
            with the OVERRIDDEN action, ``inputs`` and ``query`` come from the
            moderation result; otherwise the passed-in values are returned.
        :raises ModerationError: when flagged with the DIRECT_OUTPUT action; it
            is constructed from the result's ``preset_response``.
        """
        avoidance = app_config.sensitive_word_avoidance
        # Moderation not configured for this app: nothing to check.
        if not avoidance:
            return False, inputs, query

        factory = ModerationFactory(
            name=avoidance.type,
            app_id=app_id,
            tenant_id=tenant_id,
            config=avoidance.config,
        )

        # Only the moderation call itself is wrapped, so the timer handed to the
        # trace task covers just that call.
        with measure_time() as timer:
            result = factory.moderation_for_inputs(inputs, query)

        if trace_manager:
            trace_task = TraceTask(
                TraceTaskName.MODERATION_TRACE,
                message_id=message_id,
                moderation_result=result,
                inputs=inputs,
                timer=timer,
            )
            trace_manager.add_trace_task(trace_task)

        if not result.flagged:
            return False, inputs, query

        action = result.action
        if action == ModerationAction.DIRECT_OUTPUT:
            # Abort processing; the caller is expected to surface the preset reply.
            raise ModerationError(result.preset_response)
        if action == ModerationAction.OVERRIDDEN:
            return True, result.inputs, result.query

        # Flagged with some other action: report it but keep the original values.
        return True, inputs, query
|
|