import logging
import math
import re
from datetime import datetime
from typing import Optional
import uuid

from open_webui.utils.misc import get_last_user_message, get_messages_content

from open_webui.env import SRC_LOG_LEVELS
from open_webui.config import DEFAULT_RAG_TEMPLATE


log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

# Prompt variables: {{prompt}}, {{prompt:start:N}}, {{prompt:end:N}} and
# {{prompt:middletruncate:N}}. Groups 1-3 capture N for the sized forms.
_PROMPT_VARIABLE_PATTERN = re.compile(
    r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}"
)

# Message-history variables: {{MESSAGES}}, {{MESSAGES:START:N}},
# {{MESSAGES:END:N}} and {{MESSAGES:MIDDLETRUNCATE:N}}. Groups 1-3 capture N.
_MESSAGES_VARIABLE_PATTERN = re.compile(
    r"{{MESSAGES}}|{{MESSAGES:START:(\d+)}}|{{MESSAGES:END:(\d+)}}|{{MESSAGES:MIDDLETRUNCATE:(\d+)}}"
)

# Both accepted spellings of the RAG context and query placeholders.
_RAG_PLACEHOLDER_PATTERN = re.compile(
    r"\[context\]|\{\{CONTEXT\}\}|\[query\]|\{\{QUERY\}\}"
)
_RAG_CONTEXT_PLACEHOLDERS = ("[context]", "{{CONTEXT}}")


def _tail(seq, count: int):
    """Return the last ``count`` items of a str or list.

    Unlike ``seq[-count:]``, this returns an empty result for ``count == 0``
    (``seq[-0:]`` would be the whole sequence).
    """
    return seq[max(len(seq) - count, 0) :]


def prompt_template(
    template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
) -> str:
    """Substitute date/time and user placeholders in ``template``.

    Replaces the following placeholders:

    - ``{{CURRENT_DATE}}``: YYYY-MM-DD.
    - ``{{CURRENT_TIME}}``: 12-hour clock with AM/PM.
    - ``{{CURRENT_DATETIME}}``: date and time separated by a space.
    - ``{{USER_NAME}}`` and ``{{USER_LOCATION}}``: a missing or empty value
      is rendered as ``"Unknown"``.

    The timestamp comes from ``datetime.now()``, i.e. the naive local time
    of the running process.
    """
    current_date = datetime.now()
    formatted_date = current_date.strftime("%Y-%m-%d")
    formatted_time = current_date.strftime("%I:%M:%S %p")

    template = template.replace("{{CURRENT_DATE}}", formatted_date)
    template = template.replace("{{CURRENT_TIME}}", formatted_time)
    template = template.replace(
        "{{CURRENT_DATETIME}}", f"{formatted_date} {formatted_time}"
    )
    template = template.replace("{{USER_NAME}}", user_name or "Unknown")
    template = template.replace("{{USER_LOCATION}}", user_location or "Unknown")
    return template


def replace_prompt_variable(template: str, prompt: Optional[str]) -> str:
    """Expand ``{{prompt}}`` variables in ``template`` with (parts of) ``prompt``.

    Supported forms:

    - ``{{prompt}}``: the whole prompt.
    - ``{{prompt:start:N}}``: the first N characters.
    - ``{{prompt:end:N}}``: the last N characters.
    - ``{{prompt:middletruncate:N}}``: the prompt unchanged if it is at most
      N characters long. Otherwise, the first ceil(N/2) and last floor(N/2)
      characters, joined by ``"..."``.
    """
    # NOTE(review): callers pass get_last_user_message(...), which may be
    # None when there is no user message — treat that as an empty prompt.
    prompt = prompt or ""

    def replacement_function(match: re.Match) -> str:
        full_match = match.group(0)
        start_length = match.group(1)
        end_length = match.group(2)
        middle_length = match.group(3)

        if full_match == "{{prompt}}":
            return prompt
        if start_length is not None:
            return prompt[: int(start_length)]
        if end_length is not None:
            return _tail(prompt, int(end_length))
        if middle_length is not None:
            limit = int(middle_length)
            if len(prompt) <= limit:
                return prompt
            start = prompt[: math.ceil(limit / 2)]
            end = _tail(prompt, math.floor(limit / 2))
            return f"{start}...{end}"
        return ""

    # A callable replacement keeps backslashes in the prompt literal.
    return _PROMPT_VARIABLE_PATTERN.sub(replacement_function, template)


def replace_messages_variable(template: str, messages: list[dict]) -> str:
    """Expand ``{{MESSAGES}}`` variables in ``template`` from the chat history.

    Supported forms:

    - ``{{MESSAGES}}``: all messages.
    - ``{{MESSAGES:START:N}}``: the first N messages.
    - ``{{MESSAGES:END:N}}``: the last N messages.
    - ``{{MESSAGES:MIDDLETRUNCATE:N}}``: all messages if there are at most N.
      Otherwise, the first floor(N/2) and last ceil(N/2) messages, with the
      two formatted halves joined by a newline.

    Each selected slice is rendered with ``get_messages_content``.
    """

    def replacement_function(match: re.Match) -> str:
        full_match = match.group(0)
        start_length = match.group(1)
        end_length = match.group(2)
        middle_length = match.group(3)

        if full_match == "{{MESSAGES}}":
            return get_messages_content(messages)
        if start_length is not None:
            return get_messages_content(messages[: int(start_length)])
        if end_length is not None:
            return get_messages_content(_tail(messages, int(end_length)))
        if middle_length is not None:
            limit = int(middle_length)
            if len(messages) <= limit:
                return get_messages_content(messages)
            # Keep exactly `limit` messages: floor(limit/2) from the start,
            # the remaining ceil(limit/2) from the end.
            half = limit // 2
            start_msgs = messages[:half]
            end_msgs = _tail(messages, limit - half)
            # Skip empty halves so small limits don't leave stray newlines.
            parts = [
                get_messages_content(chunk)
                for chunk in (start_msgs, end_msgs)
                if chunk
            ]
            return "\n".join(parts)
        return ""

    return _MESSAGES_VARIABLE_PATTERN.sub(replacement_function, template)


# {{prompt:middletruncate:8000}}
def rag_template(template: str, context: str, query: str):
    """Fill a RAG prompt template with retrieved ``context`` and the user ``query``.

    Both placeholder spellings are accepted: ``[context]`` / ``{{CONTEXT}}``
    and ``[query]`` / ``{{QUERY}}``. An empty (or None) template falls back
    to ``DEFAULT_RAG_TEMPLATE``.

    Substitution is a single pass over the template. Placeholder-looking
    text inside ``context`` or ``query`` is therefore inserted verbatim and
    never expanded again. Retrieved documents are untrusted, so they must
    not be able to pull the user query (or more context) into themselves.
    """
    if not template:
        template = DEFAULT_RAG_TEMPLATE

    if "[context]" not in template and "{{CONTEXT}}" not in template:
        log.debug(
            "WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
        )

    if "<context>" in context and "</context>" in context:
        log.debug(
            "WARNING: Potential prompt injection attack: the RAG "
            "context contains '<context>' and '</context>'. This might be "
            "nothing, or the user might be trying to hack something."
        )

    def replacement_function(match: re.Match) -> str:
        if match.group(0) in _RAG_CONTEXT_PLACEHOLDERS:
            return context
        return query

    return _RAG_PLACEHOLDER_PATTERN.sub(replacement_function, template)


def _apply_user_variables(template: str, user: Optional[dict]) -> str:
    """Run ``prompt_template`` with the ``name``/``location`` of ``user``, if given."""
    if not user:
        return prompt_template(template)
    return prompt_template(
        template, user_name=user.get("name"), user_location=user.get("location")
    )


def _chat_generation_template(
    template: str, messages: list[dict], user: Optional[dict]
) -> str:
    """Apply prompt, message-history and user variables from a chat history.

    The prompt variables are filled from the last user message.
    """
    prompt = get_last_user_message(messages)
    template = replace_prompt_variable(template, prompt)
    template = replace_messages_variable(template, messages)
    return _apply_user_variables(template, user)


def title_generation_template(
    template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
    """Render the chat-title generation prompt from the conversation."""
    return _chat_generation_template(template, messages, user)


def tags_generation_template(
    template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
    """Render the chat-tags generation prompt from the conversation."""
    return _chat_generation_template(template, messages, user)


def emoji_generation_template(
    template: str, prompt: str, user: Optional[dict] = None
) -> str:
    """Render the emoji generation prompt for a single ``prompt`` string."""
    template = replace_prompt_variable(template, prompt)
    return _apply_user_variables(template, user)


def query_generation_template(
    template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
    """Render the search-query generation prompt from the conversation."""
    return _chat_generation_template(template, messages, user)


def moa_response_generation_template(
    template: str, prompt: str, responses: list[str]
) -> str:
    """Render a mixture-of-agents prompt.

    Expands the ``{{prompt}}`` variables (same forms as
    ``replace_prompt_variable``). Then replaces ``{{responses}}`` with each
    response wrapped in triple double quotes, separated by blank lines.
    """
    template = replace_prompt_variable(template, prompt)
    formatted_responses = "\n\n".join(f'"""{response}"""' for response in responses)
    return template.replace("{{responses}}", formatted_responses)


def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
    """Insert the serialized tool specifications at ``{{TOOLS}}``."""
    return template.replace("{{TOOLS}}", tools_specs)