from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional

from pydantic import BaseModel

from core.extension.extensible import Extensible, ExtensionModule


class ModerationAction(Enum):
    DIRECT_OUTPUT = "direct_output"
    OVERRIDDEN = "overridden"


class ModerationInputsResult(BaseModel):
    flagged: bool = False
    action: ModerationAction
    preset_response: str = ""
    inputs: dict = {}
    query: str = ""


class ModerationOutputsResult(BaseModel):
    flagged: bool = False
    action: ModerationAction
    preset_response: str = ""
    text: str = ""


class Moderation(Extensible, ABC):
    """
    Abstract base class for moderation extensions.
    """

    module: ExtensionModule = ExtensionModule.MODERATION

    def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None:
        super().__init__(tenant_id, config)
        self.app_id = app_id

    @classmethod
    @abstractmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the ID of the workspace
        :param config: the form config data
        """
        raise NotImplementedError
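
    # Subclass implementations of validate_config typically delegate the shared shape
    # checks to _validate_inputs_and_outputs_config (defined below) before validating
    # their own keys; see the hedged sketch at the end of this module.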

    @abstractmethod
    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Moderation for inputs.

        Called after the user submits their inputs; performs a sensitive-content review
        of those inputs and returns the processed result.

        :param inputs: user inputs
        :param query: query string (required in chat apps)
        :return: the moderation result for the inputs
        """
        raise NotImplementedError
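
    # For illustration only (not part of the base-class contract): an implementation
    # that flags an input and wants the preset response returned directly might build
    #
    #     ModerationInputsResult(
    #         flagged=True,
    #         action=ModerationAction.DIRECT_OUTPUT,
    #         preset_response="Sorry, I can't help with that.",
    #         inputs=inputs,
    #         query=query,
    #     )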

    @abstractmethod
    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Moderation for outputs.

        When the LLM produces output, the frontend passes the content (possibly in
        segments) to this method for sensitive-content review; if the review fails,
        the output is shielded.

        :param text: LLM output content
        :return: the moderation result for the output
        """
        raise NotImplementedError

    @classmethod
    def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None:
        """
        Validate the inputs_config/outputs_config sections shared by moderation configs.

        :param config: the form config data
        :param is_preset_response_required: whether each enabled section must carry a preset_response
        """
        inputs_config = config.get("inputs_config")
        if not isinstance(inputs_config, dict):
            raise ValueError("inputs_config must be a dict")

        outputs_config = config.get("outputs_config")
        if not isinstance(outputs_config, dict):
            raise ValueError("outputs_config must be a dict")

        inputs_config_enabled = inputs_config.get("enabled")
        outputs_config_enabled = outputs_config.get("enabled")
        if not inputs_config_enabled and not outputs_config_enabled:
            raise ValueError("At least one of inputs_config or outputs_config must be enabled")

        if not is_preset_response_required:
            return

        if inputs_config_enabled:
            preset_response = inputs_config.get("preset_response")
            if not preset_response:
                raise ValueError("inputs_config.preset_response is required")
            if len(preset_response) > 100:
                raise ValueError("inputs_config.preset_response must not exceed 100 characters")

        if outputs_config_enabled:
            preset_response = outputs_config.get("preset_response")
            if not preset_response:
                raise ValueError("outputs_config.preset_response is required")
            if len(preset_response) > 100:
                raise ValueError("outputs_config.preset_response must not exceed 100 characters")
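
    # For reference, a config shape that passes the checks above (keys inferred from the
    # validation logic; the preset_response value is a placeholder):
    #
    #     {
    #         "inputs_config": {"enabled": True, "preset_response": "Sorry, I can't help with that."},
    #         "outputs_config": {"enabled": False},
    #     }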


class ModerationError(Exception):
    pass
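
# A minimal sketch of a concrete subclass, commented out because it is illustrative
# only: the class name KeywordModeration, the "keywords" config key, and the naive
# substring check are assumptions, not part of this module. It also assumes Extensible
# stores the config dict on self.config. It shows the intended wiring: validate_config
# reuses _validate_inputs_and_outputs_config, and both moderation hooks return the
# result models above with a ModerationAction.
#
# class KeywordModeration(Moderation):
#     @classmethod
#     def validate_config(cls, tenant_id: str, config: dict) -> None:
#         cls._validate_inputs_and_outputs_config(config, is_preset_response_required=True)
#         if not config.get("keywords"):
#             raise ValueError("keywords is required")
#
#     def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
#         config = self.config or {}
#         keywords = config.get("keywords", [])
#         values = [str(v) for v in inputs.values()] + [query]
#         flagged = any(keyword in value for keyword in keywords for value in values)
#         return ModerationInputsResult(
#             flagged=flagged,
#             action=ModerationAction.DIRECT_OUTPUT,
#             preset_response=config.get("inputs_config", {}).get("preset_response", ""),
#             inputs=inputs,
#             query=query,
#         )
#
#     def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
#         config = self.config or {}
#         keywords = config.get("keywords", [])
#         flagged = any(keyword in text for keyword in keywords)
#         return ModerationOutputsResult(
#             flagged=flagged,
#             action=ModerationAction.DIRECT_OUTPUT,
#             preset_response=config.get("outputs_config", {}).get("preset_response", ""),
#             text=text,
#         )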