|
import functools
import json
import logging
import random
import re
import secrets
import string
import subprocess
import time
import uuid
from collections.abc import Generator
from datetime import datetime
from hashlib import sha256
from typing import Any, Optional, Union
from zoneinfo import available_timezones

from flask import Response, stream_with_context
from flask_restful import fields

from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.file import helpers as file_helpers
from extensions.ext_redis import redis_client
from models.account import Account
|
|
|
|
|
def run(script: str) -> tuple[int, str]:
    """Run *script* in bash after sourcing ``/root/.bashrc``.

    Returns ``(exit_status, output)`` in the same shape as
    ``subprocess.getstatusoutput``: stdout and stderr merged, with a single
    trailing newline stripped.

    bash is invoked explicitly because ``source`` is a bash builtin. Under a
    plain ``/bin/sh`` such as dash it is "not found", and the ``&&`` chain
    would then never run *script*.

    NOTE(security): *script* is interpreted by the shell verbatim; never pass
    untrusted input here.
    """
    completed = subprocess.run(
        ["bash", "-c", "source /root/.bashrc && " + script],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    output = completed.stdout
    if output.endswith("\n"):
        output = output[:-1]
    return completed.returncode, output
|
|
|
|
|
class AppIconUrlField(fields.Raw):
    """Field that renders a signed URL for an app icon stored as an uploaded image.

    Yields ``None`` when there is no object or when its icon is not of image type.
    """

    def output(self, key, obj):
        if obj is None:
            return None

        # Imported inside the method, presumably to avoid an import cycle with
        # models.model. TODO confirm.
        from models.model import IconType

        icon_is_image = obj.icon_type == IconType.IMAGE.value
        return file_helpers.get_signed_file_url(obj.icon) if icon_is_image else None
|
|
|
|
|
class TimestampField(fields.Raw):
    """Field that serializes a value exposing ``.timestamp()`` as integer seconds."""

    def format(self, value) -> int:
        seconds = value.timestamp()
        # int() truncates toward zero, dropping any fractional seconds.
        return int(seconds)
|
|
|
|
|
def email(email):
    """Validate an email address and return it unchanged.

    Raises:
        ValueError: if *email* does not match the accepted address pattern.
    """
    pattern = r"[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}"

    # fullmatch instead of match + "$": "$" also matches just before a trailing
    # newline, which let values like "a@b.com\n" through.
    if re.fullmatch(pattern, email) is not None:
        return email

    raise ValueError(f"{email} is not a valid email.")
|
|
|
|
|
def uuid_value(value):
    """Normalize *value* to canonical hyphenated UUID text.

    The empty string is passed through untouched.

    Raises:
        ValueError: if *value* cannot be parsed as a UUID.
    """
    if value == "":
        return str(value)

    try:
        parsed = uuid.UUID(value)
    except ValueError:
        raise ValueError(f"{value} is not a valid uuid.")
    return str(parsed)
|
|
|
|
|
def alphanumeric(value: str):
    """Return *value* if it consists solely of ASCII letters, digits and underscores.

    Raises:
        ValueError: if *value* is empty or contains any other character.
    """
    # fullmatch instead of match + "$": "$" also matches just before a trailing
    # newline, which let values like "abc\n" through.
    if re.fullmatch(r"[a-zA-Z0-9_]+", value):
        return value

    raise ValueError(f"{value} is not a valid alphanumeric value")
|
|
|
|
|
def timestamp_value(timestamp):
    """Parse *timestamp* as a non-negative integer and return it.

    Raises:
        ValueError: if *timestamp* cannot be converted with ``int()`` (including
            ``None`` and other wrong types, which ``int`` reports as TypeError)
            or is negative.
    """
    error = "{timestamp} is not a valid timestamp.".format(timestamp=timestamp)
    try:
        int_timestamp = int(timestamp)
    except (TypeError, ValueError):
        # Normalize TypeError too, matching _get_float, so callers only handle ValueError.
        raise ValueError(error) from None

    if int_timestamp < 0:
        raise ValueError(error)
    return int_timestamp
|
|
|
|
|
class StrLen:
    """Callable validator that rejects values longer than ``max_length``.

    The limit is inclusive: a value of exactly ``max_length`` is accepted and
    returned unchanged. ``argument`` names the field in the error message.
    """

    def __init__(self, max_length, argument="argument"):
        self.max_length = max_length
        self.argument = argument

    def __call__(self, value):
        """Return *value* unchanged, or raise ValueError if ``len(value)`` exceeds the limit."""
        if len(value) > self.max_length:
            error = "Invalid {arg}: {val}. {arg} cannot exceed length {length}".format(
                arg=self.argument, val=value, length=self.max_length
            )
            raise ValueError(error)
        return value
|
|
|
|
|
class FloatRange:
    """Callable validator that coerces input to ``float`` and requires ``low <= value <= high``.

    Both bounds are inclusive. On success the coerced float is returned;
    otherwise ``ValueError`` is raised (conversion failures included, via
    ``_get_float``).
    """

    def __init__(self, low, high, argument="argument"):
        self.low = low
        self.high = high
        self.argument = argument

    def __call__(self, value):
        value = _get_float(value)
        # Negated chained comparison rather than `value < low or value > high`:
        # NaN compares False against everything, so the latter would accept NaN.
        if not (self.low <= value <= self.high):
            error = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}".format(
                arg=self.argument, val=value, lo=self.low, hi=self.high
            )
            raise ValueError(error)

        return value
|
|
|
|
|
class DatetimeString:
    """Callable validator ensuring a string parses with a ``strptime`` format.

    The input string is returned as-is; the parsed datetime is discarded.
    """

    def __init__(self, format, argument="argument"):
        self.format = format
        self.argument = argument

    def __call__(self, value):
        try:
            datetime.strptime(value, self.format)
        except ValueError:
            message = "Invalid {arg}: {val}. {arg} must be conform to the format {format}".format(
                arg=self.argument, val=value, format=self.format
            )
            raise ValueError(message)
        else:
            return value
|
|
|
|
|
def _get_float(value):
    """Coerce *value* to ``float``, reporting any conversion failure as ValueError."""
    try:
        result = float(value)
    except (TypeError, ValueError):
        raise ValueError("{} is not a valid float".format(value))
    return result
|
|
|
|
|
@functools.cache
def _valid_timezones() -> frozenset[str]:
    """Return the IANA keys on the time zone path, computed once per process.

    ``zoneinfo.available_timezones()`` rescans the time zone path on every call.
    Tz data installed after the first call is therefore not seen until restart.
    """
    return frozenset(available_timezones())


def timezone(timezone_string):
    """Return *timezone_string* if it is a known IANA time zone key.

    Raises:
        ValueError: if the value is empty/falsy or not an available time zone.
    """
    if timezone_string and timezone_string in _valid_timezones():
        return timezone_string

    raise ValueError(f"{timezone_string} is not a valid timezone.")
|
|
|
|
|
def generate_string(n):
    """Return a random string of *n* ASCII letters and digits.

    Characters are drawn with ``secrets.choice`` (a CSPRNG) rather than
    ``random``, whose output is predictable. This matters if callers use the
    result as a credential (e.g. API keys); TODO confirm call sites.
    """
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(n))
|
|
|
|
|
def extract_remote_ip(request) -> str:
    """Return a best-effort client IP address for *request*.

    Sources are tried in this order:
    1. the ``CF-Connecting-IP`` header;
    2. the left-most address of ``X-Forwarded-For``;
    3. ``request.remote_addr``.

    NOTE(security): both headers can be set by the client unless a trusted
    proxy overwrites them. Do not rely on this value for access control.
    """
    cf_ip = request.headers.get("CF-Connecting-IP")
    if cf_ip:
        return cf_ip

    forwarded_for = request.headers.getlist("X-Forwarded-For")
    if forwarded_for:
        # A proxy chain is usually a single comma-separated value
        # ("client, proxy1, proxy2"); the originating client is the first entry.
        client_ip = forwarded_for[0].split(",")[0].strip()
        if client_ip:
            return client_ip

    return request.remote_addr
|
|
|
|
|
def generate_text_hash(text: str) -> str:
    """Return the hex SHA-256 of ``str(text)`` followed by the literal suffix ``"None"``.

    The suffix is part of the hashed bytes. Presumably previously stored
    hashes depend on it, so do not drop it without a migration. TODO confirm.
    """
    digest = sha256()
    digest.update(str(text).encode())
    digest.update(b"None")
    return digest.hexdigest()
|
|
|
|
|
def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response:
    """Wrap a handler result in a Flask ``Response``.

    Return value depends on the type of *response*:
    - a ``dict`` becomes one JSON body (``application/json``);
    - anything else is iterated and streamed as ``text/event-stream``, wrapped
      with ``stream_with_context``.
    """
    if not isinstance(response, dict):

        def generate() -> Generator:
            yield from response

        streamed_body = stream_with_context(generate())
        return Response(streamed_body, status=200, mimetype="text/event-stream")

    json_body = json.dumps(response)
    return Response(response=json_body, status=200, mimetype="application/json")
|
|
|
|
|
class TokenManager:
    """Issue, look up and revoke expiring tokens stored in Redis.

    Two kinds of keys are written:
    - ``{token_type}:token:{token}`` holds the token's JSON payload;
    - ``{token_type}:account:{account_id}`` holds the account's current token of
      that type, so issuing a new one revokes the previous one.
    """

    @classmethod
    def generate_token(
        cls,
        token_type: str,
        account: Optional[Account] = None,
        email: Optional[str] = None,
        additional_data: Optional[dict] = None,
    ) -> str:
        """Create and store a new token of *token_type*, returning its value.

        The lifetime comes from the config setting
        ``{TOKEN_TYPE}_TOKEN_EXPIRY_MINUTES``. When *account* is given, any
        previous token of the same type for that account is revoked first.

        Raises:
            ValueError: if neither *account* nor *email* is provided, or if no
                expiry is configured for *token_type*.
        """
        if account is None and email is None:
            raise ValueError("Account or email must be provided")

        # Resolve the expiry before touching Redis, so a misconfigured
        # token_type does not revoke the old token and then fail.
        expiry_minutes = cls._get_expiry_minutes(token_type)

        account_id = account.id if account else None
        account_email = account.email if account else email

        if account_id:
            old_token = cls._get_current_token_for_account(account_id, token_type)
            if old_token:
                if isinstance(old_token, bytes):
                    old_token = old_token.decode("utf-8")
                cls.revoke_token(old_token, token_type)

        token = str(uuid.uuid4())
        token_data = {"account_id": account_id, "email": account_email, "token_type": token_type}
        if additional_data:
            token_data.update(additional_data)

        token_key = cls._get_token_key(token, token_type)
        redis_client.setex(token_key, int(expiry_minutes * 60), json.dumps(token_data))

        if account_id:
            cls._set_current_token_for_account(account_id, token, token_type, expiry_minutes)

        return token

    @classmethod
    def _get_expiry_minutes(cls, token_type: str) -> Union[int, float]:
        """Read ``{TOKEN_TYPE}_TOKEN_EXPIRY_MINUTES`` from the app config.

        Raises:
            ValueError: if the setting is missing or ``None``.
        """
        setting = f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES"
        # getattr avoids serialising the whole config (model_dump) on every call.
        expiry_minutes = getattr(dify_config, setting, None)
        if expiry_minutes is None:
            raise ValueError(f"{setting} is not configured")
        return expiry_minutes

    @classmethod
    def _get_token_key(cls, token: str, token_type: str) -> str:
        """Redis key holding the payload of *token*."""
        return f"{token_type}:token:{token}"

    @classmethod
    def revoke_token(cls, token: str, token_type: str):
        """Delete the stored payload of *token*; a no-op if it is already gone."""
        token_key = cls._get_token_key(token, token_type)
        redis_client.delete(token_key)

    @classmethod
    def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]:
        """Return the decoded payload for *token*, or ``None`` if missing or expired."""
        key = cls._get_token_key(token, token_type)
        token_data_json = redis_client.get(key)
        if token_data_json is None:
            logging.warning("%s token %s not found with key %s", token_type, token, key)
            return None
        return json.loads(token_data_json)

    @classmethod
    def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[Union[str, bytes]]:
        """Return the account's current token of *token_type*, or ``None``.

        The value may come back as ``bytes``, depending on the Redis client's
        decoding settings; ``generate_token`` handles both.
        """
        key = cls._get_account_token_key(account_id, token_type)
        return redis_client.get(key)

    @classmethod
    def _set_current_token_for_account(
        cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float]
    ):
        """Record *token* as the account's current token for *expiry_minutes* minutes.

        Uses the same lifetime as the token payload, so the pointer does not
        outlive the token it points to.
        """
        key = cls._get_account_token_key(account_id, token_type)
        expiry_time = int(expiry_minutes * 60)
        redis_client.setex(key, expiry_time, token)

    @classmethod
    def _get_account_token_key(cls, account_id: str, token_type: str) -> str:
        """Redis key holding the account's current token of *token_type*."""
        return f"{token_type}:account:{account_id}"
|
|
|
|
|
class RateLimiter:
    """Sliding-window attempt limiter backed by one Redis sorted set per email.

    Each attempt is stored as a set member scored by its Unix time (seconds).
    Entries older than ``time_window`` seconds are pruned before counting.
    """

    def __init__(self, prefix: str, max_attempts: int, time_window: int):
        self.prefix = prefix
        self.max_attempts = max_attempts
        self.time_window = time_window

    def _get_key(self, email: str) -> str:
        """Redis key of the sorted set tracking attempts for *email*."""
        return f"{self.prefix}:{email}"

    def is_rate_limited(self, email: str) -> bool:
        """Return True if *email* has at least ``max_attempts`` attempts in the window."""
        key = self._get_key(email)
        window_start_time = int(time.time()) - self.time_window

        redis_client.zremrangebyscore(key, "-inf", window_start_time)
        attempts = redis_client.zcard(key)

        return bool(attempts) and int(attempts) >= self.max_attempts

    def increment_rate_limit(self, email: str):
        """Record one attempt for *email* at the current time."""
        key = self._get_key(email)
        current_time = int(time.time())

        # Members must be unique per attempt. Using the timestamp alone made
        # attempts within the same second overwrite each other and undercount.
        member = f"{current_time}:{uuid.uuid4().hex}"
        redis_client.zadd(key, {member: current_time})
        # Expire the whole set eventually so idle keys do not linger forever.
        redis_client.expire(key, self.time_window * 2)
|
|