Spaces:
Running
Running
0.4
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .dockerignore +2 -1
- .env.example +13 -15
- .gitattributes +1 -0
- .github/workflows/deploy-to-hf-spaces.yml +4 -0
- CHANGELOG.md +61 -0
- Dockerfile +176 -177
- README.md +1 -13
- backend/open_webui/apps/audio/main.py +75 -2
- backend/open_webui/apps/images/main.py +14 -2
- backend/open_webui/apps/ollama/main.py +354 -151
- backend/open_webui/apps/openai/main.py +280 -122
- backend/open_webui/apps/retrieval/loaders/main.py +1 -1
- backend/open_webui/apps/retrieval/main.py +144 -20
- backend/open_webui/apps/retrieval/utils.py +76 -57
- backend/open_webui/apps/retrieval/vector/connector.py +8 -0
- backend/open_webui/apps/retrieval/vector/dbs/chroma.py +3 -1
- backend/open_webui/apps/retrieval/vector/dbs/opensearch.py +178 -0
- backend/open_webui/apps/retrieval/vector/dbs/pgvector.py +354 -0
- backend/open_webui/apps/retrieval/vector/dbs/qdrant.py +7 -2
- backend/open_webui/apps/retrieval/web/bing.py +73 -0
- backend/open_webui/apps/retrieval/web/jina_search.py +2 -4
- backend/open_webui/apps/retrieval/web/testdata/bing.json +58 -0
- backend/open_webui/apps/socket/main.py +2 -0
- backend/open_webui/apps/webui/main.py +45 -7
- backend/open_webui/apps/webui/models/auths.py +5 -0
- backend/open_webui/apps/webui/models/chats.py +13 -6
- backend/open_webui/apps/webui/models/groups.py +186 -0
- backend/open_webui/apps/webui/models/knowledge.py +76 -23
- backend/open_webui/apps/webui/models/models.py +103 -8
- backend/open_webui/apps/webui/models/prompts.py +58 -9
- backend/open_webui/apps/webui/models/tools.py +56 -4
- backend/open_webui/apps/webui/models/users.py +8 -0
- backend/open_webui/apps/webui/routers/auths.py +281 -3
- backend/open_webui/apps/webui/routers/chats.py +9 -5
- backend/open_webui/apps/webui/routers/groups.py +120 -0
- backend/open_webui/apps/webui/routers/knowledge.py +202 -78
- backend/open_webui/apps/webui/routers/models.py +120 -36
- backend/open_webui/apps/webui/routers/prompts.py +69 -7
- backend/open_webui/apps/webui/routers/tools.py +144 -80
- backend/open_webui/apps/webui/routers/users.py +41 -4
- backend/open_webui/apps/webui/utils.py +1 -1
- backend/open_webui/config.py +229 -35
- backend/open_webui/constants.py +2 -0
- backend/open_webui/env.py +393 -384
- backend/open_webui/main.py +0 -0
- backend/open_webui/migrations/versions/922e7a387820_add_group_table.py +85 -0
- backend/open_webui/storage/provider.py +4 -1
- backend/open_webui/utils/access_control.py +95 -0
- backend/open_webui/utils/pdf_generator.py +12 -12
- backend/open_webui/utils/security_headers.py +11 -0
.dockerignore
CHANGED
@@ -16,4 +16,5 @@ _old
|
|
16 |
uploads
|
17 |
.ipynb_checkpoints
|
18 |
**/*.db
|
19 |
-
_test
|
|
|
|
16 |
uploads
|
17 |
.ipynb_checkpoints
|
18 |
**/*.db
|
19 |
+
_test
|
20 |
+
backend/data/*
|
.env.example
CHANGED
@@ -1,15 +1,13 @@
|
|
1 |
-
# Ollama URL for the backend to connect
|
2 |
-
# The path '/ollama' will be redirected to the specified backend URL
|
3 |
-
OLLAMA_BASE_URL='http://localhost:11434'
|
4 |
-
|
5 |
-
OPENAI_API_BASE_URL=''
|
6 |
-
OPENAI_API_KEY=''
|
7 |
-
|
8 |
-
# AUTOMATIC1111_BASE_URL="http://localhost:7860"
|
9 |
-
|
10 |
-
# DO NOT TRACK
|
11 |
-
SCARF_NO_ANALYTICS=true
|
12 |
-
DO_NOT_TRACK=true
|
13 |
-
ANONYMIZED_TELEMETRY=false
|
14 |
-
|
15 |
-
GLOBAL_LOG_LEVEL="ERROR"
|
|
|
1 |
+
# Ollama URL for the backend to connect
|
2 |
+
# The path '/ollama' will be redirected to the specified backend URL
|
3 |
+
OLLAMA_BASE_URL='http://localhost:11434'
|
4 |
+
|
5 |
+
OPENAI_API_BASE_URL=''
|
6 |
+
OPENAI_API_KEY=''
|
7 |
+
|
8 |
+
# AUTOMATIC1111_BASE_URL="http://localhost:7860"
|
9 |
+
|
10 |
+
# DO NOT TRACK
|
11 |
+
SCARF_NO_ANALYTICS=true
|
12 |
+
DO_NOT_TRACK=true
|
13 |
+
ANONYMIZED_TELEMETRY=false
|
|
|
|
.gitattributes
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
*.sh text eol=lf
|
2 |
*.ttf filter=lfs diff=lfs merge=lfs -text
|
|
|
|
1 |
*.sh text eol=lf
|
2 |
*.ttf filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/deploy-to-hf-spaces.yml
CHANGED
@@ -28,6 +28,8 @@ jobs:
|
|
28 |
steps:
|
29 |
- name: Checkout repository
|
30 |
uses: actions/checkout@v4
|
|
|
|
|
31 |
|
32 |
- name: Remove git history
|
33 |
run: rm -rf .git
|
@@ -52,7 +54,9 @@ jobs:
|
|
52 |
- name: Set up Git and push to Space
|
53 |
run: |
|
54 |
git init --initial-branch=main
|
|
|
55 |
git lfs track "*.ttf"
|
|
|
56 |
rm demo.gif
|
57 |
git add .
|
58 |
git commit -m "GitHub deploy: ${{ github.sha }}"
|
|
|
28 |
steps:
|
29 |
- name: Checkout repository
|
30 |
uses: actions/checkout@v4
|
31 |
+
with:
|
32 |
+
lfs: true
|
33 |
|
34 |
- name: Remove git history
|
35 |
run: rm -rf .git
|
|
|
54 |
- name: Set up Git and push to Space
|
55 |
run: |
|
56 |
git init --initial-branch=main
|
57 |
+
git lfs install
|
58 |
git lfs track "*.ttf"
|
59 |
+
git lfs track "*.jpg"
|
60 |
rm demo.gif
|
61 |
git add .
|
62 |
git commit -m "GitHub deploy: ${{ github.sha }}"
|
CHANGELOG.md
CHANGED
@@ -5,10 +5,71 @@ All notable changes to this project will be documented in this file.
|
|
5 |
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
6 |
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
## [0.3.35] - 2024-10-26
|
9 |
|
10 |
### Added
|
11 |
|
|
|
12 |
- **📁 Robust File Handling**: Enhanced file input handling for chat. If the content extraction fails or is empty, users will now receive a clear warning, preventing silent failures and ensuring you always know what's happening with your uploads.
|
13 |
- **🌍 New Language Support**: Introduced Hungarian translations and updated French translations, expanding the platform's language accessibility for a more global user base.
|
14 |
|
|
|
5 |
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
6 |
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
7 |
|
8 |
+
## [0.4.1] - 2024-11-19
|
9 |
+
|
10 |
+
### Added
|
11 |
+
|
12 |
+
- **🛠️ Tool Descriptions on Hover**: When enabled, tool descriptions now appear upon hovering over the tool icon in the message input, giving you more context instantly and improving workflow fluidity.
|
13 |
+
|
14 |
+
### Fixed
|
15 |
+
|
16 |
+
- **🚫 Graceful Handling of Deleted Users**: Resolved an issue where deleted users caused models, knowledge, prompts, or tools to fail loading in the workspace, ensuring smoother operation and fewer interruptions.
|
17 |
+
- **🔗 Proxy Fix for HTTPS Models Endpoint**: Fixed issues with proxies affecting the secure `/api/v1/models/` endpoint, ensuring stable connections and reliable access.
|
18 |
+
- **🔒 API Key Creation**: Addressed a bug that previously prevented API keys from being created.
|
19 |
+
|
20 |
+
## [0.4.0] - 2024-11-19
|
21 |
+
|
22 |
+
### Added
|
23 |
+
|
24 |
+
- **👥 User Groups**: You can now create and manage user groups, making user organization seamless.
|
25 |
+
- **🔐 Group-Based Access Control**: Set granular access to models, knowledge, prompts, and tools based on user groups, allowing for more controlled and secure environments.
|
26 |
+
- **🛠️ Group-Based User Permissions**: Easily manage workspace permissions. Grant users the ability to upload files, delete, edit, or create temporary chats, as well as define their ability to create models, knowledge, prompts, and tools.
|
27 |
+
- **🔑 LDAP Support**: Newly introduced LDAP authentication adds robust security and scalability to user management.
|
28 |
+
- **🌐 Enhanced OpenAI-Compatible Connections**: Added prefix ID support to avoid model ID clashes, with explicit model ID support for APIs lacking '/models' endpoint support, ensuring smooth operation with custom setups.
|
29 |
+
- **🔐 Ollama API Key Support**: Now manage credentials for Ollama when set behind proxies, including the option to utilize prefix ID for proper distinction across multiple Ollama instances.
|
30 |
+
- **🔄 Connection Enable/Disable Toggle**: Easily enable or disable individual OpenAI and Ollama connections as needed.
|
31 |
+
- **🎨 Redesigned Model Workspace**: Freshly redesigned to improve usability for managing models across users and groups.
|
32 |
+
- **🎨 Redesigned Prompt Workspace**: A fresh UI to conveniently organize and manage prompts.
|
33 |
+
- **🧩 Sorted Functions Workspace**: Functions are now automatically categorized by type (Action, Filter, Pipe), streamlining management.
|
34 |
+
- **💻 Redesigned Collaborative Workspace**: Enhanced support for multiple users contributing to models, knowledge, prompts, or tools, improving collaboration.
|
35 |
+
- **🔧 Auto-Selected Tools in Model Editor**: Tools enabled through the model editor are now automatically selected, whereas previously it only gave users the option to enable the tool, reducing manual steps and enhancing efficiency.
|
36 |
+
- **🔔 Web Search & Tools Indicator**: A clear indication now shows when web search or tools are active, reducing confusion.
|
37 |
+
- **🔑 Toggle API Key Auth**: Tighten security by easily enabling or disabling API key authentication option for Open WebUI.
|
38 |
+
- **🗂️ Agentic Retrieval**: Improve RAG accuracy via smart pre-processing of chat history to determine the best queries before retrieval.
|
39 |
+
- **📁 Large Text as File Option**: Optionally convert large pasted text into a file upload, keeping the chat interface cleaner.
|
40 |
+
- **🗂️ Toggle Citations for Models**: Ability to disable citations has been introduced in the model editor.
|
41 |
+
- **🔍 User Settings Search**: Quickly search for settings fields, improving ease of use and navigation.
|
42 |
+
- **🗣️ Experimental SpeechT5 TTS**: Local SpeechT5 support added for improved text-to-speech capabilities.
|
43 |
+
- **🔄 Unified Reset for Models**: A one-click option has been introduced to reset and remove all models from the Admin Settings.
|
44 |
+
- **🛠️ Initial Setup Wizard**: The setup process now explicitly informs users that they are creating an admin account during the first-time setup, ensuring clarity. Previously, users encountered the login page right away without this distinction.
|
45 |
+
- **🌐 Enhanced Translations**: Several language translations, including Ukrainian, Norwegian, and Brazilian Portuguese, were refined for better localization.
|
46 |
+
|
47 |
+
### Fixed
|
48 |
+
|
49 |
+
- **🎥 YouTube Video Attachments**: Fixed issues preventing proper loading and attachment of YouTube videos as files.
|
50 |
+
- **🔄 Shared Chat Update**: Corrected issues where shared chats were not updating, improving collaboration consistency.
|
51 |
+
- **🔍 DuckDuckGo Rate Limit Fix**: Addressed issues with DuckDuckGo search integration, enhancing search stability and performance when operating within rate limits.
|
52 |
+
- **🧾 Citations Relevance Fix**: Adjusted the relevance percentage calculation for citations, so that Open WebUI properly reflect the accuracy of a retrieved document in RAG, ensuring users get clearer insights into sources.
|
53 |
+
- **🔑 Jina Search API Key Requirement**: Added the option to input an API key for Jina Search, ensuring smooth functionality as keys are now mandatory.
|
54 |
+
|
55 |
+
### Changed
|
56 |
+
|
57 |
+
- **🛠️ Functions Moved to Admin Panel**: As Functions operate as advanced plugins, they are now accessible from the Admin Panel instead of the workspace.
|
58 |
+
- **🛠️ Manage Ollama Connections**: The "Models" section in Admin Settings has been relocated to Admin Settings > "Connections" > Ollama Connections. You can now manage Ollama instances via a dedicated "Manage Ollama" modal from "Connections", streamlining the setup and configuration of Ollama models.
|
59 |
+
- **📊 Base Models in Admin Settings**: Admins can now find all base models, both connections or functions, in the "Models" Admin setting. Global model accessibility can be enabled or disabled here. Models are private by default, requiring explicit permission assignment for user access.
|
60 |
+
- **📌 Sticky Model Selection for New Chats**: The model chosen from a previous chat now persists when creating a new chat. If you click "New Chat" again from the new chat page, it will revert to your default model.
|
61 |
+
- **🎨 Design Refactoring**: Overall design refinements across the platform have been made, providing a more cohesive and polished user experience.
|
62 |
+
|
63 |
+
### Removed
|
64 |
+
|
65 |
+
- **📂 Model List Reordering**: Temporarily removed and will be reintroduced in upcoming user group settings improvements.
|
66 |
+
- **⚙️ Default Model Setting**: Removed the ability to set a default model for users, will be reintroduced with user group settings in the future.
|
67 |
+
|
68 |
## [0.3.35] - 2024-10-26
|
69 |
|
70 |
### Added
|
71 |
|
72 |
+
- **🌐 Translation Update**: Added translation labels in the SearchInput and CreateCollection components and updated Brazilian Portuguese translation (pt-BR)
|
73 |
- **📁 Robust File Handling**: Enhanced file input handling for chat. If the content extraction fails or is empty, users will now receive a clear warning, preventing silent failures and ensuring you always know what's happening with your uploads.
|
74 |
- **🌍 New Language Support**: Introduced Hungarian translations and updated French translations, expanding the platform's language accessibility for a more global user base.
|
75 |
|
Dockerfile
CHANGED
@@ -1,177 +1,176 @@
|
|
1 |
-
# syntax=docker/dockerfile:1
|
2 |
-
# Initialize device type args
|
3 |
-
# use build args in the docker build command with --build-arg="BUILDARG=true"
|
4 |
-
ARG USE_CUDA=false
|
5 |
-
ARG USE_OLLAMA=false
|
6 |
-
# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
|
7 |
-
ARG USE_CUDA_VER=cu121
|
8 |
-
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
|
9 |
-
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
|
10 |
-
# for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
|
11 |
-
# IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
|
12 |
-
ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
13 |
-
ARG USE_RERANKING_MODEL=""
|
14 |
-
|
15 |
-
# Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken
|
16 |
-
ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base"
|
17 |
-
|
18 |
-
ARG BUILD_HASH=dev-build
|
19 |
-
# Override at your own risk - non-root configurations are untested
|
20 |
-
ARG UID=0
|
21 |
-
ARG GID=0
|
22 |
-
|
23 |
-
######## WebUI frontend ########
|
24 |
-
FROM --platform=$BUILDPLATFORM node:22-alpine3.20 AS build
|
25 |
-
ARG BUILD_HASH
|
26 |
-
|
27 |
-
WORKDIR /app
|
28 |
-
|
29 |
-
COPY package.json package-lock.json ./
|
30 |
-
RUN npm ci
|
31 |
-
|
32 |
-
COPY . .
|
33 |
-
ENV APP_BUILD_HASH=${BUILD_HASH}
|
34 |
-
RUN npm run build
|
35 |
-
|
36 |
-
######## WebUI backend ########
|
37 |
-
FROM python:3.11-slim-bookworm AS base
|
38 |
-
|
39 |
-
# Use args
|
40 |
-
ARG USE_CUDA
|
41 |
-
ARG USE_OLLAMA
|
42 |
-
ARG USE_CUDA_VER
|
43 |
-
ARG USE_EMBEDDING_MODEL
|
44 |
-
ARG USE_RERANKING_MODEL
|
45 |
-
ARG UID
|
46 |
-
ARG GID
|
47 |
-
|
48 |
-
## Basis ##
|
49 |
-
ENV ENV=prod \
|
50 |
-
PORT=8080 \
|
51 |
-
# pass build args to the build
|
52 |
-
USE_OLLAMA_DOCKER=${USE_OLLAMA} \
|
53 |
-
USE_CUDA_DOCKER=${USE_CUDA} \
|
54 |
-
USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
|
55 |
-
USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
|
56 |
-
USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}
|
57 |
-
|
58 |
-
## Basis URL Config ##
|
59 |
-
ENV OLLAMA_BASE_URL="/ollama" \
|
60 |
-
OPENAI_API_BASE_URL=""
|
61 |
-
|
62 |
-
## API Key and Security Config ##
|
63 |
-
ENV OPENAI_API_KEY="" \
|
64 |
-
WEBUI_SECRET_KEY="" \
|
65 |
-
SCARF_NO_ANALYTICS=true \
|
66 |
-
DO_NOT_TRACK=true \
|
67 |
-
ANONYMIZED_TELEMETRY=false
|
68 |
-
|
69 |
-
#### Other models #########################################################
|
70 |
-
## whisper TTS model settings ##
|
71 |
-
ENV WHISPER_MODEL="base" \
|
72 |
-
WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
|
73 |
-
|
74 |
-
## RAG Embedding model settings ##
|
75 |
-
ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
|
76 |
-
RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
|
77 |
-
SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
|
78 |
-
|
79 |
-
## Tiktoken model settings ##
|
80 |
-
ENV TIKTOKEN_ENCODING_NAME="cl100k_base" \
|
81 |
-
TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken"
|
82 |
-
|
83 |
-
## Hugging Face download cache ##
|
84 |
-
ENV HF_HOME="/app/backend/data/cache/embedding/models"
|
85 |
-
|
86 |
-
## Torch Extensions ##
|
87 |
-
# ENV TORCH_EXTENSIONS_DIR="/.cache/torch_extensions"
|
88 |
-
|
89 |
-
#### Other models ##########################################################
|
90 |
-
|
91 |
-
WORKDIR /app/backend
|
92 |
-
|
93 |
-
ENV HOME=/root
|
94 |
-
# Create user and group if not root
|
95 |
-
RUN if [ $UID -ne 0 ]; then \
|
96 |
-
if [ $GID -ne 0 ]; then \
|
97 |
-
addgroup --gid $GID app; \
|
98 |
-
fi; \
|
99 |
-
adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \
|
100 |
-
fi
|
101 |
-
|
102 |
-
RUN mkdir -p $HOME/.cache/chroma
|
103 |
-
RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry_user_id
|
104 |
-
|
105 |
-
# Make sure the user has access to the app and root directory
|
106 |
-
RUN chown -R $UID:$GID /app $HOME
|
107 |
-
|
108 |
-
RUN if [ "$USE_OLLAMA" = "true" ]; then \
|
109 |
-
apt-get update && \
|
110 |
-
# Install pandoc and netcat
|
111 |
-
apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \
|
112 |
-
apt-get install -y --no-install-recommends gcc python3-dev && \
|
113 |
-
# for RAG OCR
|
114 |
-
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
|
115 |
-
# install helper tools
|
116 |
-
apt-get install -y --no-install-recommends curl jq && \
|
117 |
-
# install ollama
|
118 |
-
curl -fsSL https://ollama.com/install.sh | sh && \
|
119 |
-
# cleanup
|
120 |
-
rm -rf /var/lib/apt/lists/*; \
|
121 |
-
else \
|
122 |
-
apt-get update && \
|
123 |
-
# Install pandoc, netcat and gcc
|
124 |
-
apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \
|
125 |
-
apt-get install -y --no-install-recommends gcc python3-dev && \
|
126 |
-
# for RAG OCR
|
127 |
-
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
|
128 |
-
# cleanup
|
129 |
-
rm -rf /var/lib/apt/lists/*; \
|
130 |
-
fi
|
131 |
-
|
132 |
-
# install python dependencies
|
133 |
-
COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt
|
134 |
-
|
135 |
-
RUN pip3 install uv && \
|
136 |
-
if [ "$USE_CUDA" = "true" ]; then \
|
137 |
-
# If you use CUDA the whisper and embedding model will be downloaded on first use
|
138 |
-
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
|
139 |
-
uv pip install --system -r requirements.txt --no-cache-dir && \
|
140 |
-
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
|
141 |
-
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
|
142 |
-
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
|
143 |
-
else \
|
144 |
-
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
|
145 |
-
uv pip install --system -r requirements.txt --no-cache-dir && \
|
146 |
-
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
|
147 |
-
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
|
148 |
-
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
|
149 |
-
fi; \
|
150 |
-
chown -R $UID:$GID /app/backend/data/
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
# copy embedding weight from build
|
155 |
-
# RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
|
156 |
-
# COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx
|
157 |
-
|
158 |
-
# copy built frontend files
|
159 |
-
COPY --chown=$UID:$GID --from=build /app/build /app/build
|
160 |
-
COPY --chown=$UID:$GID --from=build /app/CHANGELOG.md /app/CHANGELOG.md
|
161 |
-
COPY --chown=$UID:$GID --from=build /app/package.json /app/package.json
|
162 |
-
|
163 |
-
# copy backend files
|
164 |
-
COPY --chown=$UID:$GID ./backend .
|
165 |
-
|
166 |
-
EXPOSE 8080
|
167 |
-
|
168 |
-
HEALTHCHECK CMD curl --silent --fail http://localhost:${PORT:-8080}/health | jq -ne 'input.status == true' || exit 1
|
169 |
-
|
170 |
-
USER $UID:$GID
|
171 |
-
|
172 |
-
ARG BUILD_HASH
|
173 |
-
ENV WEBUI_BUILD_VERSION=${BUILD_HASH}
|
174 |
-
ENV DOCKER=true
|
175 |
-
|
176 |
-
|
177 |
-
CMD [ "bash", "start.sh"]
|
|
|
1 |
+
# syntax=docker/dockerfile:1
|
2 |
+
# Initialize device type args
|
3 |
+
# use build args in the docker build command with --build-arg="BUILDARG=true"
|
4 |
+
ARG USE_CUDA=false
|
5 |
+
ARG USE_OLLAMA=false
|
6 |
+
# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
|
7 |
+
ARG USE_CUDA_VER=cu121
|
8 |
+
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
|
9 |
+
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
|
10 |
+
# for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
|
11 |
+
# IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
|
12 |
+
ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
13 |
+
ARG USE_RERANKING_MODEL=""
|
14 |
+
|
15 |
+
# Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken
|
16 |
+
ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base"
|
17 |
+
|
18 |
+
ARG BUILD_HASH=dev-build
|
19 |
+
# Override at your own risk - non-root configurations are untested
|
20 |
+
ARG UID=0
|
21 |
+
ARG GID=0
|
22 |
+
|
23 |
+
######## WebUI frontend ########
|
24 |
+
FROM --platform=$BUILDPLATFORM node:22-alpine3.20 AS build
|
25 |
+
ARG BUILD_HASH
|
26 |
+
|
27 |
+
WORKDIR /app
|
28 |
+
|
29 |
+
COPY package.json package-lock.json ./
|
30 |
+
RUN npm ci
|
31 |
+
|
32 |
+
COPY . .
|
33 |
+
ENV APP_BUILD_HASH=${BUILD_HASH}
|
34 |
+
RUN npm run build
|
35 |
+
|
36 |
+
######## WebUI backend ########
|
37 |
+
FROM python:3.11-slim-bookworm AS base
|
38 |
+
|
39 |
+
# Use args
|
40 |
+
ARG USE_CUDA
|
41 |
+
ARG USE_OLLAMA
|
42 |
+
ARG USE_CUDA_VER
|
43 |
+
ARG USE_EMBEDDING_MODEL
|
44 |
+
ARG USE_RERANKING_MODEL
|
45 |
+
ARG UID
|
46 |
+
ARG GID
|
47 |
+
|
48 |
+
## Basis ##
|
49 |
+
ENV ENV=prod \
|
50 |
+
PORT=8080 \
|
51 |
+
# pass build args to the build
|
52 |
+
USE_OLLAMA_DOCKER=${USE_OLLAMA} \
|
53 |
+
USE_CUDA_DOCKER=${USE_CUDA} \
|
54 |
+
USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
|
55 |
+
USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
|
56 |
+
USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}
|
57 |
+
|
58 |
+
## Basis URL Config ##
|
59 |
+
ENV OLLAMA_BASE_URL="/ollama" \
|
60 |
+
OPENAI_API_BASE_URL=""
|
61 |
+
|
62 |
+
## API Key and Security Config ##
|
63 |
+
ENV OPENAI_API_KEY="" \
|
64 |
+
WEBUI_SECRET_KEY="" \
|
65 |
+
SCARF_NO_ANALYTICS=true \
|
66 |
+
DO_NOT_TRACK=true \
|
67 |
+
ANONYMIZED_TELEMETRY=false
|
68 |
+
|
69 |
+
#### Other models #########################################################
|
70 |
+
## whisper TTS model settings ##
|
71 |
+
ENV WHISPER_MODEL="base" \
|
72 |
+
WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
|
73 |
+
|
74 |
+
## RAG Embedding model settings ##
|
75 |
+
ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
|
76 |
+
RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
|
77 |
+
SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
|
78 |
+
|
79 |
+
## Tiktoken model settings ##
|
80 |
+
ENV TIKTOKEN_ENCODING_NAME="cl100k_base" \
|
81 |
+
TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken"
|
82 |
+
|
83 |
+
## Hugging Face download cache ##
|
84 |
+
ENV HF_HOME="/app/backend/data/cache/embedding/models"
|
85 |
+
|
86 |
+
## Torch Extensions ##
|
87 |
+
# ENV TORCH_EXTENSIONS_DIR="/.cache/torch_extensions"
|
88 |
+
|
89 |
+
#### Other models ##########################################################
|
90 |
+
|
91 |
+
WORKDIR /app/backend
|
92 |
+
|
93 |
+
ENV HOME=/root
|
94 |
+
# Create user and group if not root
|
95 |
+
RUN if [ $UID -ne 0 ]; then \
|
96 |
+
if [ $GID -ne 0 ]; then \
|
97 |
+
addgroup --gid $GID app; \
|
98 |
+
fi; \
|
99 |
+
adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \
|
100 |
+
fi
|
101 |
+
|
102 |
+
RUN mkdir -p $HOME/.cache/chroma
|
103 |
+
RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry_user_id
|
104 |
+
|
105 |
+
# Make sure the user has access to the app and root directory
|
106 |
+
RUN chown -R $UID:$GID /app $HOME
|
107 |
+
|
108 |
+
RUN if [ "$USE_OLLAMA" = "true" ]; then \
|
109 |
+
apt-get update && \
|
110 |
+
# Install pandoc and netcat
|
111 |
+
apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \
|
112 |
+
apt-get install -y --no-install-recommends gcc python3-dev && \
|
113 |
+
# for RAG OCR
|
114 |
+
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
|
115 |
+
# install helper tools
|
116 |
+
apt-get install -y --no-install-recommends curl jq && \
|
117 |
+
# install ollama
|
118 |
+
curl -fsSL https://ollama.com/install.sh | sh && \
|
119 |
+
# cleanup
|
120 |
+
rm -rf /var/lib/apt/lists/*; \
|
121 |
+
else \
|
122 |
+
apt-get update && \
|
123 |
+
# Install pandoc, netcat and gcc
|
124 |
+
apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \
|
125 |
+
apt-get install -y --no-install-recommends gcc python3-dev && \
|
126 |
+
# for RAG OCR
|
127 |
+
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
|
128 |
+
# cleanup
|
129 |
+
rm -rf /var/lib/apt/lists/*; \
|
130 |
+
fi
|
131 |
+
|
132 |
+
# install python dependencies
|
133 |
+
COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt
|
134 |
+
|
135 |
+
RUN pip3 install uv && \
|
136 |
+
if [ "$USE_CUDA" = "true" ]; then \
|
137 |
+
# If you use CUDA the whisper and embedding model will be downloaded on first use
|
138 |
+
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
|
139 |
+
uv pip install --system -r requirements.txt --no-cache-dir && \
|
140 |
+
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
|
141 |
+
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
|
142 |
+
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
|
143 |
+
else \
|
144 |
+
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
|
145 |
+
uv pip install --system -r requirements.txt --no-cache-dir && \
|
146 |
+
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
|
147 |
+
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
|
148 |
+
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
|
149 |
+
fi; \
|
150 |
+
chown -R $UID:$GID /app/backend/data/
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
# copy embedding weight from build
|
155 |
+
# RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
|
156 |
+
# COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx
|
157 |
+
|
158 |
+
# copy built frontend files
|
159 |
+
COPY --chown=$UID:$GID --from=build /app/build /app/build
|
160 |
+
COPY --chown=$UID:$GID --from=build /app/CHANGELOG.md /app/CHANGELOG.md
|
161 |
+
COPY --chown=$UID:$GID --from=build /app/package.json /app/package.json
|
162 |
+
|
163 |
+
# copy backend files
|
164 |
+
COPY --chown=$UID:$GID ./backend .
|
165 |
+
|
166 |
+
EXPOSE 8080
|
167 |
+
|
168 |
+
HEALTHCHECK CMD curl --silent --fail http://localhost:${PORT:-8080}/health | jq -ne 'input.status == true' || exit 1
|
169 |
+
|
170 |
+
USER $UID:$GID
|
171 |
+
|
172 |
+
ARG BUILD_HASH
|
173 |
+
ENV WEBUI_BUILD_VERSION=${BUILD_HASH}
|
174 |
+
ENV DOCKER=true
|
175 |
+
|
176 |
+
CMD [ "bash", "start.sh"]
|
|
README.md
CHANGED
@@ -45,7 +45,7 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-
|
|
45 |
|
46 |
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
|
47 |
|
48 |
-
- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch` and `
|
49 |
|
50 |
- 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
|
51 |
|
@@ -195,18 +195,6 @@ docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --a
|
|
195 |
|
196 |
Discover upcoming features on our roadmap in the [Open WebUI Documentation](https://docs.openwebui.com/roadmap/).
|
197 |
|
198 |
-
## Supporters ✨
|
199 |
-
|
200 |
-
A big shoutout to our amazing supporters who's helping to make this project possible! 🙏
|
201 |
-
|
202 |
-
### Platinum Sponsors 🤍
|
203 |
-
|
204 |
-
- We're looking for Sponsors!
|
205 |
-
|
206 |
-
### Acknowledgments
|
207 |
-
|
208 |
-
Special thanks to [Prof. Lawrence Kim](https://www.lhkim.com/) and [Prof. Nick Vincent](https://www.nickmvincent.com/) for their invaluable support and guidance in shaping this project into a research endeavor. Grateful for your mentorship throughout the journey! 🙌
|
209 |
-
|
210 |
## License 📜
|
211 |
|
212 |
This project is licensed under the [MIT License](LICENSE) - see the [LICENSE](LICENSE) file for details. 📄
|
|
|
45 |
|
46 |
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
|
47 |
|
48 |
+
- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch`, `SearchApi` and `Bing` and inject the results directly into your chat experience.
|
49 |
|
50 |
- 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
|
51 |
|
|
|
195 |
|
196 |
Discover upcoming features on our roadmap in the [Open WebUI Documentation](https://docs.openwebui.com/roadmap/).
|
197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
## License 📜
|
199 |
|
200 |
This project is licensed under the [MIT License](LICENSE) - see the [LICENSE](LICENSE) file for details. 📄
|
backend/open_webui/apps/audio/main.py
CHANGED
@@ -32,7 +32,13 @@ from open_webui.config import (
|
|
32 |
)
|
33 |
|
34 |
from open_webui.constants import ERROR_MESSAGES
|
35 |
-
from open_webui.env import
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
|
37 |
from fastapi.middleware.cors import CORSMiddleware
|
38 |
from fastapi.responses import FileResponse
|
@@ -47,7 +53,12 @@ MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
|
|
47 |
log = logging.getLogger(__name__)
|
48 |
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
|
49 |
|
50 |
-
app = FastAPI(
|
|
|
|
|
|
|
|
|
|
|
51 |
app.add_middleware(
|
52 |
CORSMiddleware,
|
53 |
allow_origins=CORS_ALLOW_ORIGIN,
|
@@ -74,6 +85,10 @@ app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
|
|
74 |
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
|
75 |
app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
|
76 |
|
|
|
|
|
|
|
|
|
77 |
app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
|
78 |
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
|
79 |
|
@@ -231,6 +246,21 @@ async def update_audio_config(
|
|
231 |
}
|
232 |
|
233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
@app.post("/speech")
|
235 |
async def speech(request: Request, user=Depends(get_verified_user)):
|
236 |
body = await request.body()
|
@@ -248,6 +278,12 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
|
248 |
headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
|
249 |
headers["Content-Type"] = "application/json"
|
250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
try:
|
252 |
body = body.decode("utf-8")
|
253 |
body = json.loads(body)
|
@@ -391,6 +427,43 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
|
391 |
raise HTTPException(
|
392 |
status_code=500, detail=f"Error synthesizing speech - {response.reason}"
|
393 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
394 |
|
395 |
|
396 |
def transcribe(file_path):
|
|
|
32 |
)
|
33 |
|
34 |
from open_webui.constants import ERROR_MESSAGES
|
35 |
+
from open_webui.env import (
|
36 |
+
ENV,
|
37 |
+
SRC_LOG_LEVELS,
|
38 |
+
DEVICE_TYPE,
|
39 |
+
ENABLE_FORWARD_USER_INFO_HEADERS,
|
40 |
+
)
|
41 |
+
|
42 |
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
|
43 |
from fastapi.middleware.cors import CORSMiddleware
|
44 |
from fastapi.responses import FileResponse
|
|
|
53 |
log = logging.getLogger(__name__)
|
54 |
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
|
55 |
|
56 |
+
app = FastAPI(
|
57 |
+
docs_url="/docs" if ENV == "dev" else None,
|
58 |
+
openapi_url="/openapi.json" if ENV == "dev" else None,
|
59 |
+
redoc_url=None,
|
60 |
+
)
|
61 |
+
|
62 |
app.add_middleware(
|
63 |
CORSMiddleware,
|
64 |
allow_origins=CORS_ALLOW_ORIGIN,
|
|
|
85 |
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
|
86 |
app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
|
87 |
|
88 |
+
|
89 |
+
app.state.speech_synthesiser = None
|
90 |
+
app.state.speech_speaker_embeddings_dataset = None
|
91 |
+
|
92 |
app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
|
93 |
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
|
94 |
|
|
|
246 |
}
|
247 |
|
248 |
|
249 |
+
def load_speech_pipeline():
|
250 |
+
from transformers import pipeline
|
251 |
+
from datasets import load_dataset
|
252 |
+
|
253 |
+
if app.state.speech_synthesiser is None:
|
254 |
+
app.state.speech_synthesiser = pipeline(
|
255 |
+
"text-to-speech", "microsoft/speecht5_tts"
|
256 |
+
)
|
257 |
+
|
258 |
+
if app.state.speech_speaker_embeddings_dataset is None:
|
259 |
+
app.state.speech_speaker_embeddings_dataset = load_dataset(
|
260 |
+
"Matthijs/cmu-arctic-xvectors", split="validation"
|
261 |
+
)
|
262 |
+
|
263 |
+
|
264 |
@app.post("/speech")
|
265 |
async def speech(request: Request, user=Depends(get_verified_user)):
|
266 |
body = await request.body()
|
|
|
278 |
headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
|
279 |
headers["Content-Type"] = "application/json"
|
280 |
|
281 |
+
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
282 |
+
headers["X-OpenWebUI-User-Name"] = user.name
|
283 |
+
headers["X-OpenWebUI-User-Id"] = user.id
|
284 |
+
headers["X-OpenWebUI-User-Email"] = user.email
|
285 |
+
headers["X-OpenWebUI-User-Role"] = user.role
|
286 |
+
|
287 |
try:
|
288 |
body = body.decode("utf-8")
|
289 |
body = json.loads(body)
|
|
|
427 |
raise HTTPException(
|
428 |
status_code=500, detail=f"Error synthesizing speech - {response.reason}"
|
429 |
)
|
430 |
+
elif app.state.config.TTS_ENGINE == "transformers":
|
431 |
+
payload = None
|
432 |
+
try:
|
433 |
+
payload = json.loads(body.decode("utf-8"))
|
434 |
+
except Exception as e:
|
435 |
+
log.exception(e)
|
436 |
+
raise HTTPException(status_code=400, detail="Invalid JSON payload")
|
437 |
+
|
438 |
+
import torch
|
439 |
+
import soundfile as sf
|
440 |
+
|
441 |
+
load_speech_pipeline()
|
442 |
+
|
443 |
+
embeddings_dataset = app.state.speech_speaker_embeddings_dataset
|
444 |
+
|
445 |
+
speaker_index = 6799
|
446 |
+
try:
|
447 |
+
speaker_index = embeddings_dataset["filename"].index(
|
448 |
+
app.state.config.TTS_MODEL
|
449 |
+
)
|
450 |
+
except Exception:
|
451 |
+
pass
|
452 |
+
|
453 |
+
speaker_embedding = torch.tensor(
|
454 |
+
embeddings_dataset[speaker_index]["xvector"]
|
455 |
+
).unsqueeze(0)
|
456 |
+
|
457 |
+
speech = app.state.speech_synthesiser(
|
458 |
+
payload["input"],
|
459 |
+
forward_params={"speaker_embeddings": speaker_embedding},
|
460 |
+
)
|
461 |
+
|
462 |
+
sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
|
463 |
+
with open(file_body_path, "w") as f:
|
464 |
+
json.dump(json.loads(body.decode("utf-8")), f)
|
465 |
+
|
466 |
+
return FileResponse(file_path)
|
467 |
|
468 |
|
469 |
def transcribe(file_path):
|
backend/open_webui/apps/images/main.py
CHANGED
@@ -35,7 +35,8 @@ from open_webui.config import (
|
|
35 |
AppConfig,
|
36 |
)
|
37 |
from open_webui.constants import ERROR_MESSAGES
|
38 |
-
from open_webui.env import SRC_LOG_LEVELS
|
|
|
39 |
from fastapi import Depends, FastAPI, HTTPException, Request
|
40 |
from fastapi.middleware.cors import CORSMiddleware
|
41 |
from pydantic import BaseModel
|
@@ -47,7 +48,12 @@ log.setLevel(SRC_LOG_LEVELS["IMAGES"])
|
|
47 |
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
|
48 |
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
49 |
|
50 |
-
app = FastAPI(
|
|
|
|
|
|
|
|
|
|
|
51 |
app.add_middleware(
|
52 |
CORSMiddleware,
|
53 |
allow_origins=CORS_ALLOW_ORIGIN,
|
@@ -456,6 +462,12 @@ async def image_generations(
|
|
456 |
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
|
457 |
headers["Content-Type"] = "application/json"
|
458 |
|
|
|
|
|
|
|
|
|
|
|
|
|
459 |
data = {
|
460 |
"model": (
|
461 |
app.state.config.MODEL
|
|
|
35 |
AppConfig,
|
36 |
)
|
37 |
from open_webui.constants import ERROR_MESSAGES
|
38 |
+
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
|
39 |
+
|
40 |
from fastapi import Depends, FastAPI, HTTPException, Request
|
41 |
from fastapi.middleware.cors import CORSMiddleware
|
42 |
from pydantic import BaseModel
|
|
|
48 |
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
|
49 |
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
50 |
|
51 |
+
app = FastAPI(
|
52 |
+
docs_url="/docs" if ENV == "dev" else None,
|
53 |
+
openapi_url="/openapi.json" if ENV == "dev" else None,
|
54 |
+
redoc_url=None,
|
55 |
+
)
|
56 |
+
|
57 |
app.add_middleware(
|
58 |
CORSMiddleware,
|
59 |
allow_origins=CORS_ALLOW_ORIGIN,
|
|
|
462 |
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
|
463 |
headers["Content-Type"] = "application/json"
|
464 |
|
465 |
+
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
466 |
+
headers["X-OpenWebUI-User-Name"] = user.name
|
467 |
+
headers["X-OpenWebUI-User-Id"] = user.id
|
468 |
+
headers["X-OpenWebUI-User-Email"] = user.email
|
469 |
+
headers["X-OpenWebUI-User-Role"] = user.role
|
470 |
+
|
471 |
data = {
|
472 |
"model": (
|
473 |
app.state.config.MODEL
|
backend/open_webui/apps/ollama/main.py
CHANGED
@@ -13,18 +13,20 @@ import requests
|
|
13 |
from open_webui.apps.webui.models.models import Models
|
14 |
from open_webui.config import (
|
15 |
CORS_ALLOW_ORIGIN,
|
16 |
-
ENABLE_MODEL_FILTER,
|
17 |
ENABLE_OLLAMA_API,
|
18 |
-
MODEL_FILTER_LIST,
|
19 |
OLLAMA_BASE_URLS,
|
|
|
20 |
UPLOAD_DIR,
|
21 |
AppConfig,
|
22 |
)
|
23 |
-
from open_webui.env import
|
|
|
|
|
|
|
24 |
|
25 |
|
26 |
from open_webui.constants import ERROR_MESSAGES
|
27 |
-
from open_webui.env import SRC_LOG_LEVELS
|
28 |
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
|
29 |
from fastapi.middleware.cors import CORSMiddleware
|
30 |
from fastapi.responses import StreamingResponse
|
@@ -41,11 +43,18 @@ from open_webui.utils.payload import (
|
|
41 |
apply_model_system_prompt_to_body,
|
42 |
)
|
43 |
from open_webui.utils.utils import get_admin_user, get_verified_user
|
|
|
44 |
|
45 |
log = logging.getLogger(__name__)
|
46 |
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
47 |
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
app.add_middleware(
|
50 |
CORSMiddleware,
|
51 |
allow_origins=CORS_ALLOW_ORIGIN,
|
@@ -56,12 +65,9 @@ app.add_middleware(
|
|
56 |
|
57 |
app.state.config = AppConfig()
|
58 |
|
59 |
-
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
60 |
-
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
61 |
-
|
62 |
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
|
63 |
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
|
64 |
-
app.state.
|
65 |
|
66 |
|
67 |
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
|
@@ -69,60 +75,98 @@ app.state.MODELS = {}
|
|
69 |
# least connections, or least response time for better resource utilization and performance optimization.
|
70 |
|
71 |
|
72 |
-
@app.middleware("http")
|
73 |
-
async def check_url(request: Request, call_next):
|
74 |
-
if len(app.state.MODELS) == 0:
|
75 |
-
await get_all_models()
|
76 |
-
else:
|
77 |
-
pass
|
78 |
-
|
79 |
-
response = await call_next(request)
|
80 |
-
return response
|
81 |
-
|
82 |
-
|
83 |
@app.head("/")
|
84 |
@app.get("/")
|
85 |
async def get_status():
|
86 |
return {"status": True}
|
87 |
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
@app.get("/config")
|
90 |
async def get_config(user=Depends(get_admin_user)):
|
91 |
-
return {
|
|
|
|
|
|
|
|
|
92 |
|
93 |
|
94 |
class OllamaConfigForm(BaseModel):
|
95 |
-
|
|
|
|
|
96 |
|
97 |
|
98 |
@app.post("/config/update")
|
99 |
async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
|
100 |
-
app.state.config.ENABLE_OLLAMA_API = form_data.
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
@app.get("/urls")
|
105 |
-
async def get_ollama_api_urls(user=Depends(get_admin_user)):
|
106 |
-
return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
|
107 |
-
|
108 |
|
109 |
-
|
110 |
-
urls: list[str]
|
111 |
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
116 |
|
117 |
-
log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}")
|
118 |
-
return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
|
119 |
|
120 |
-
|
121 |
-
|
122 |
-
timeout = aiohttp.ClientTimeout(total=3)
|
123 |
try:
|
|
|
124 |
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
125 |
-
async with session.get(url) as response:
|
126 |
return await response.json()
|
127 |
except Exception as e:
|
128 |
# Handle connection error here
|
@@ -148,10 +192,18 @@ async def post_streaming_url(
|
|
148 |
session = aiohttp.ClientSession(
|
149 |
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
150 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
r = await session.post(
|
152 |
url,
|
153 |
data=payload,
|
154 |
-
headers=
|
155 |
)
|
156 |
r.raise_for_status()
|
157 |
|
@@ -194,29 +246,62 @@ def merge_models_lists(model_lists):
|
|
194 |
for idx, model_list in enumerate(model_lists):
|
195 |
if model_list is not None:
|
196 |
for model in model_list:
|
197 |
-
|
198 |
-
if
|
199 |
model["urls"] = [idx]
|
200 |
-
merged_models[
|
201 |
else:
|
202 |
-
merged_models[
|
203 |
|
204 |
return list(merged_models.values())
|
205 |
|
206 |
|
207 |
async def get_all_models():
|
208 |
log.info("get_all_models()")
|
209 |
-
|
210 |
if app.state.config.ENABLE_OLLAMA_API:
|
211 |
-
tasks = [
|
212 |
-
|
213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
responses = await asyncio.gather(*tasks)
|
215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
models = {
|
217 |
"models": merge_models_lists(
|
218 |
map(
|
219 |
-
lambda response: response
|
|
|
220 |
)
|
221 |
)
|
222 |
}
|
@@ -224,8 +309,6 @@ async def get_all_models():
|
|
224 |
else:
|
225 |
models = {"models": []}
|
226 |
|
227 |
-
app.state.MODELS = {model["model"]: model for model in models["models"]}
|
228 |
-
|
229 |
return models
|
230 |
|
231 |
|
@@ -234,29 +317,25 @@ async def get_all_models():
|
|
234 |
async def get_ollama_tags(
|
235 |
url_idx: Optional[int] = None, user=Depends(get_verified_user)
|
236 |
):
|
|
|
237 |
if url_idx is None:
|
238 |
models = await get_all_models()
|
239 |
-
|
240 |
-
if app.state.config.ENABLE_MODEL_FILTER:
|
241 |
-
if user.role == "user":
|
242 |
-
models["models"] = list(
|
243 |
-
filter(
|
244 |
-
lambda model: model["name"]
|
245 |
-
in app.state.config.MODEL_FILTER_LIST,
|
246 |
-
models["models"],
|
247 |
-
)
|
248 |
-
)
|
249 |
-
return models
|
250 |
-
return models
|
251 |
else:
|
252 |
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
254 |
r = None
|
255 |
try:
|
256 |
-
r = requests.request(method="GET", url=f"{url}/api/tags")
|
257 |
r.raise_for_status()
|
258 |
|
259 |
-
|
260 |
except Exception as e:
|
261 |
log.exception(e)
|
262 |
error_detail = "Open WebUI: Server Connection Error"
|
@@ -273,6 +352,20 @@ async def get_ollama_tags(
|
|
273 |
detail=error_detail,
|
274 |
)
|
275 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
|
277 |
@app.get("/api/version")
|
278 |
@app.get("/api/version/{url_idx}")
|
@@ -281,7 +374,10 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
|
|
281 |
if url_idx is None:
|
282 |
# returns lowest version
|
283 |
tasks = [
|
284 |
-
|
|
|
|
|
|
|
285 |
for url in app.state.config.OLLAMA_BASE_URLS
|
286 |
]
|
287 |
responses = await asyncio.gather(*tasks)
|
@@ -361,8 +457,11 @@ async def push_model(
|
|
361 |
user=Depends(get_admin_user),
|
362 |
):
|
363 |
if url_idx is None:
|
364 |
-
|
365 |
-
|
|
|
|
|
|
|
366 |
else:
|
367 |
raise HTTPException(
|
368 |
status_code=400,
|
@@ -411,8 +510,11 @@ async def copy_model(
|
|
411 |
user=Depends(get_admin_user),
|
412 |
):
|
413 |
if url_idx is None:
|
414 |
-
|
415 |
-
|
|
|
|
|
|
|
416 |
else:
|
417 |
raise HTTPException(
|
418 |
status_code=400,
|
@@ -421,10 +523,18 @@ async def copy_model(
|
|
421 |
|
422 |
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
423 |
log.info(f"url: {url}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
424 |
r = requests.request(
|
425 |
method="POST",
|
426 |
url=f"{url}/api/copy",
|
427 |
-
headers=
|
428 |
data=form_data.model_dump_json(exclude_none=True).encode(),
|
429 |
)
|
430 |
|
@@ -459,8 +569,11 @@ async def delete_model(
|
|
459 |
user=Depends(get_admin_user),
|
460 |
):
|
461 |
if url_idx is None:
|
462 |
-
|
463 |
-
|
|
|
|
|
|
|
464 |
else:
|
465 |
raise HTTPException(
|
466 |
status_code=400,
|
@@ -470,11 +583,18 @@ async def delete_model(
|
|
470 |
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
471 |
log.info(f"url: {url}")
|
472 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
473 |
r = requests.request(
|
474 |
method="DELETE",
|
475 |
url=f"{url}/api/delete",
|
476 |
-
headers={"Content-Type": "application/json"},
|
477 |
data=form_data.model_dump_json(exclude_none=True).encode(),
|
|
|
478 |
)
|
479 |
try:
|
480 |
r.raise_for_status()
|
@@ -501,20 +621,30 @@ async def delete_model(
|
|
501 |
|
502 |
@app.post("/api/show")
|
503 |
async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
|
504 |
-
|
|
|
|
|
|
|
505 |
raise HTTPException(
|
506 |
status_code=400,
|
507 |
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
|
508 |
)
|
509 |
|
510 |
-
url_idx = random.choice(
|
511 |
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
512 |
log.info(f"url: {url}")
|
513 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
514 |
r = requests.request(
|
515 |
method="POST",
|
516 |
url=f"{url}/api/show",
|
517 |
-
headers=
|
518 |
data=form_data.model_dump_json(exclude_none=True).encode(),
|
519 |
)
|
520 |
try:
|
@@ -570,23 +700,26 @@ async def generate_embeddings(
|
|
570 |
url_idx: Optional[int] = None,
|
571 |
user=Depends(get_verified_user),
|
572 |
):
|
573 |
-
return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
|
574 |
|
575 |
|
576 |
-
def generate_ollama_embeddings(
|
577 |
form_data: GenerateEmbeddingsForm,
|
578 |
url_idx: Optional[int] = None,
|
579 |
):
|
580 |
log.info(f"generate_ollama_embeddings {form_data}")
|
581 |
|
582 |
if url_idx is None:
|
|
|
|
|
|
|
583 |
model = form_data.model
|
584 |
|
585 |
if ":" not in model:
|
586 |
model = f"{model}:latest"
|
587 |
|
588 |
-
if model in
|
589 |
-
url_idx = random.choice(
|
590 |
else:
|
591 |
raise HTTPException(
|
592 |
status_code=400,
|
@@ -596,10 +729,17 @@ def generate_ollama_embeddings(
|
|
596 |
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
597 |
log.info(f"url: {url}")
|
598 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
599 |
r = requests.request(
|
600 |
method="POST",
|
601 |
url=f"{url}/api/embeddings",
|
602 |
-
headers=
|
603 |
data=form_data.model_dump_json(exclude_none=True).encode(),
|
604 |
)
|
605 |
try:
|
@@ -630,20 +770,23 @@ def generate_ollama_embeddings(
|
|
630 |
)
|
631 |
|
632 |
|
633 |
-
def generate_ollama_batch_embeddings(
|
634 |
form_data: GenerateEmbedForm,
|
635 |
url_idx: Optional[int] = None,
|
636 |
):
|
637 |
log.info(f"generate_ollama_batch_embeddings {form_data}")
|
638 |
|
639 |
if url_idx is None:
|
|
|
|
|
|
|
640 |
model = form_data.model
|
641 |
|
642 |
if ":" not in model:
|
643 |
model = f"{model}:latest"
|
644 |
|
645 |
-
if model in
|
646 |
-
url_idx = random.choice(
|
647 |
else:
|
648 |
raise HTTPException(
|
649 |
status_code=400,
|
@@ -653,10 +796,17 @@ def generate_ollama_batch_embeddings(
|
|
653 |
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
654 |
log.info(f"url: {url}")
|
655 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
656 |
r = requests.request(
|
657 |
method="POST",
|
658 |
url=f"{url}/api/embed",
|
659 |
-
headers=
|
660 |
data=form_data.model_dump_json(exclude_none=True).encode(),
|
661 |
)
|
662 |
try:
|
@@ -706,13 +856,16 @@ async def generate_completion(
|
|
706 |
user=Depends(get_verified_user),
|
707 |
):
|
708 |
if url_idx is None:
|
|
|
|
|
|
|
709 |
model = form_data.model
|
710 |
|
711 |
if ":" not in model:
|
712 |
model = f"{model}:latest"
|
713 |
|
714 |
-
if model in
|
715 |
-
url_idx = random.choice(
|
716 |
else:
|
717 |
raise HTTPException(
|
718 |
status_code=400,
|
@@ -720,6 +873,10 @@ async def generate_completion(
|
|
720 |
)
|
721 |
|
722 |
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
|
|
|
723 |
log.info(f"url: {url}")
|
724 |
|
725 |
return await post_streaming_url(
|
@@ -743,14 +900,17 @@ class GenerateChatCompletionForm(BaseModel):
|
|
743 |
keep_alive: Optional[Union[int, str]] = None
|
744 |
|
745 |
|
746 |
-
def get_ollama_url(url_idx: Optional[int], model: str):
|
747 |
if url_idx is None:
|
748 |
-
|
|
|
|
|
|
|
749 |
raise HTTPException(
|
750 |
status_code=400,
|
751 |
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
|
752 |
)
|
753 |
-
url_idx = random.choice(
|
754 |
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
755 |
return url
|
756 |
|
@@ -768,15 +928,7 @@ async def generate_chat_completion(
|
|
768 |
if "metadata" in payload:
|
769 |
del payload["metadata"]
|
770 |
|
771 |
-
model_id =
|
772 |
-
|
773 |
-
if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER:
|
774 |
-
if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
|
775 |
-
raise HTTPException(
|
776 |
-
status_code=403,
|
777 |
-
detail="Model not found",
|
778 |
-
)
|
779 |
-
|
780 |
model_info = Models.get_model_by_id(model_id)
|
781 |
|
782 |
if model_info:
|
@@ -794,13 +946,37 @@ async def generate_chat_completion(
|
|
794 |
)
|
795 |
payload = apply_model_system_prompt_to_body(params, payload, user)
|
796 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
797 |
if ":" not in payload["model"]:
|
798 |
payload["model"] = f"{payload['model']}:latest"
|
799 |
|
800 |
-
url = get_ollama_url(url_idx, payload["model"])
|
801 |
log.info(f"url: {url}")
|
802 |
log.debug(f"generate_chat_completion() - 2.payload = {payload}")
|
803 |
|
|
|
|
|
|
|
|
|
|
|
804 |
return await post_streaming_url(
|
805 |
f"{url}/api/chat",
|
806 |
json.dumps(payload),
|
@@ -817,7 +993,7 @@ class OpenAIChatMessageContent(BaseModel):
|
|
817 |
|
818 |
class OpenAIChatMessage(BaseModel):
|
819 |
role: str
|
820 |
-
content: Union[str, OpenAIChatMessageContent]
|
821 |
|
822 |
model_config = ConfigDict(extra="allow")
|
823 |
|
@@ -836,22 +1012,24 @@ async def generate_openai_chat_completion(
|
|
836 |
url_idx: Optional[int] = None,
|
837 |
user=Depends(get_verified_user),
|
838 |
):
|
839 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
840 |
payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
|
841 |
if "metadata" in payload:
|
842 |
del payload["metadata"]
|
843 |
|
844 |
model_id = completion_form.model
|
845 |
-
|
846 |
-
|
847 |
-
if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
|
848 |
-
raise HTTPException(
|
849 |
-
status_code=403,
|
850 |
-
detail="Model not found",
|
851 |
-
)
|
852 |
|
853 |
model_info = Models.get_model_by_id(model_id)
|
854 |
-
|
855 |
if model_info:
|
856 |
if model_info.base_model_id:
|
857 |
payload["model"] = model_info.base_model_id
|
@@ -862,12 +1040,36 @@ async def generate_openai_chat_completion(
|
|
862 |
payload = apply_model_params_to_body_openai(params, payload)
|
863 |
payload = apply_model_system_prompt_to_body(params, payload, user)
|
864 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
865 |
if ":" not in payload["model"]:
|
866 |
payload["model"] = f"{payload['model']}:latest"
|
867 |
|
868 |
-
url = get_ollama_url(url_idx, payload["model"])
|
869 |
log.info(f"url: {url}")
|
870 |
|
|
|
|
|
|
|
|
|
|
|
871 |
return await post_streaming_url(
|
872 |
f"{url}/v1/chat/completions",
|
873 |
json.dumps(payload),
|
@@ -881,21 +1083,29 @@ async def get_openai_models(
|
|
881 |
url_idx: Optional[int] = None,
|
882 |
user=Depends(get_verified_user),
|
883 |
):
|
|
|
|
|
884 |
if url_idx is None:
|
885 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
886 |
|
887 |
-
|
888 |
-
|
889 |
-
|
890 |
-
|
891 |
-
|
892 |
-
|
893 |
-
|
894 |
-
)
|
895 |
-
)
|
896 |
|
897 |
-
|
898 |
-
"data": [
|
899 |
{
|
900 |
"id": model["model"],
|
901 |
"object": "model",
|
@@ -903,31 +1113,7 @@ async def get_openai_models(
|
|
903 |
"owned_by": "openai",
|
904 |
}
|
905 |
for model in models["models"]
|
906 |
-
]
|
907 |
-
"object": "list",
|
908 |
-
}
|
909 |
-
|
910 |
-
else:
|
911 |
-
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
912 |
-
try:
|
913 |
-
r = requests.request(method="GET", url=f"{url}/api/tags")
|
914 |
-
r.raise_for_status()
|
915 |
-
|
916 |
-
models = r.json()
|
917 |
-
|
918 |
-
return {
|
919 |
-
"data": [
|
920 |
-
{
|
921 |
-
"id": model["model"],
|
922 |
-
"object": "model",
|
923 |
-
"created": int(time.time()),
|
924 |
-
"owned_by": "openai",
|
925 |
-
}
|
926 |
-
for model in models["models"]
|
927 |
-
],
|
928 |
-
"object": "list",
|
929 |
-
}
|
930 |
-
|
931 |
except Exception as e:
|
932 |
log.exception(e)
|
933 |
error_detail = "Open WebUI: Server Connection Error"
|
@@ -944,6 +1130,23 @@ async def get_openai_models(
|
|
944 |
detail=error_detail,
|
945 |
)
|
946 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
947 |
|
948 |
class UrlForm(BaseModel):
|
949 |
url: str
|
|
|
13 |
from open_webui.apps.webui.models.models import Models
|
14 |
from open_webui.config import (
|
15 |
CORS_ALLOW_ORIGIN,
|
|
|
16 |
ENABLE_OLLAMA_API,
|
|
|
17 |
OLLAMA_BASE_URLS,
|
18 |
+
OLLAMA_API_CONFIGS,
|
19 |
UPLOAD_DIR,
|
20 |
AppConfig,
|
21 |
)
|
22 |
+
from open_webui.env import (
|
23 |
+
AIOHTTP_CLIENT_TIMEOUT,
|
24 |
+
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
|
25 |
+
)
|
26 |
|
27 |
|
28 |
from open_webui.constants import ERROR_MESSAGES
|
29 |
+
from open_webui.env import ENV, SRC_LOG_LEVELS
|
30 |
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
|
31 |
from fastapi.middleware.cors import CORSMiddleware
|
32 |
from fastapi.responses import StreamingResponse
|
|
|
43 |
apply_model_system_prompt_to_body,
|
44 |
)
|
45 |
from open_webui.utils.utils import get_admin_user, get_verified_user
|
46 |
+
from open_webui.utils.access_control import has_access
|
47 |
|
48 |
log = logging.getLogger(__name__)
|
49 |
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
50 |
|
51 |
+
|
52 |
+
app = FastAPI(
|
53 |
+
docs_url="/docs" if ENV == "dev" else None,
|
54 |
+
openapi_url="/openapi.json" if ENV == "dev" else None,
|
55 |
+
redoc_url=None,
|
56 |
+
)
|
57 |
+
|
58 |
app.add_middleware(
|
59 |
CORSMiddleware,
|
60 |
allow_origins=CORS_ALLOW_ORIGIN,
|
|
|
65 |
|
66 |
app.state.config = AppConfig()
|
67 |
|
|
|
|
|
|
|
68 |
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
|
69 |
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
|
70 |
+
app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS
|
71 |
|
72 |
|
73 |
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
|
|
|
75 |
# least connections, or least response time for better resource utilization and performance optimization.
|
76 |
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
@app.head("/")
|
79 |
@app.get("/")
|
80 |
async def get_status():
|
81 |
return {"status": True}
|
82 |
|
83 |
|
84 |
+
class ConnectionVerificationForm(BaseModel):
|
85 |
+
url: str
|
86 |
+
key: Optional[str] = None
|
87 |
+
|
88 |
+
|
89 |
+
@app.post("/verify")
|
90 |
+
async def verify_connection(
|
91 |
+
form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
|
92 |
+
):
|
93 |
+
url = form_data.url
|
94 |
+
key = form_data.key
|
95 |
+
|
96 |
+
headers = {}
|
97 |
+
if key:
|
98 |
+
headers["Authorization"] = f"Bearer {key}"
|
99 |
+
|
100 |
+
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
101 |
+
async with aiohttp.ClientSession(timeout=timeout) as session:
|
102 |
+
try:
|
103 |
+
async with session.get(f"{url}/api/version", headers=headers) as r:
|
104 |
+
if r.status != 200:
|
105 |
+
# Extract response error details if available
|
106 |
+
error_detail = f"HTTP Error: {r.status}"
|
107 |
+
res = await r.json()
|
108 |
+
if "error" in res:
|
109 |
+
error_detail = f"External Error: {res['error']}"
|
110 |
+
raise Exception(error_detail)
|
111 |
+
|
112 |
+
response_data = await r.json()
|
113 |
+
return response_data
|
114 |
+
|
115 |
+
except aiohttp.ClientError as e:
|
116 |
+
# ClientError covers all aiohttp requests issues
|
117 |
+
log.exception(f"Client error: {str(e)}")
|
118 |
+
# Handle aiohttp-specific connection issues, timeout etc.
|
119 |
+
raise HTTPException(
|
120 |
+
status_code=500, detail="Open WebUI: Server Connection Error"
|
121 |
+
)
|
122 |
+
except Exception as e:
|
123 |
+
log.exception(f"Unexpected error: {e}")
|
124 |
+
# Generic error handler in case parsing JSON or other steps fail
|
125 |
+
error_detail = f"Unexpected error: {str(e)}"
|
126 |
+
raise HTTPException(status_code=500, detail=error_detail)
|
127 |
+
|
128 |
+
|
129 |
@app.get("/config")
|
130 |
async def get_config(user=Depends(get_admin_user)):
|
131 |
+
return {
|
132 |
+
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
|
133 |
+
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
|
134 |
+
"OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
|
135 |
+
}
|
136 |
|
137 |
|
138 |
class OllamaConfigForm(BaseModel):
|
139 |
+
ENABLE_OLLAMA_API: Optional[bool] = None
|
140 |
+
OLLAMA_BASE_URLS: list[str]
|
141 |
+
OLLAMA_API_CONFIGS: dict
|
142 |
|
143 |
|
144 |
@app.post("/config/update")
|
145 |
async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
|
146 |
+
app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
|
147 |
+
app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
+
app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
|
|
|
150 |
|
151 |
+
# Remove any extra configs
|
152 |
+
config_urls = app.state.config.OLLAMA_API_CONFIGS.keys()
|
153 |
+
for url in list(app.state.config.OLLAMA_BASE_URLS):
|
154 |
+
if url not in config_urls:
|
155 |
+
app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
|
156 |
|
157 |
+
return {
|
158 |
+
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
|
159 |
+
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
|
160 |
+
"OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
|
161 |
+
}
|
162 |
|
|
|
|
|
163 |
|
164 |
+
async def aiohttp_get(url, key=None):
|
165 |
+
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
|
|
166 |
try:
|
167 |
+
headers = {"Authorization": f"Bearer {key}"} if key else {}
|
168 |
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
169 |
+
async with session.get(url, headers=headers) as response:
|
170 |
return await response.json()
|
171 |
except Exception as e:
|
172 |
# Handle connection error here
|
|
|
192 |
session = aiohttp.ClientSession(
|
193 |
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
194 |
)
|
195 |
+
|
196 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
197 |
+
key = api_config.get("key", None)
|
198 |
+
|
199 |
+
headers = {"Content-Type": "application/json"}
|
200 |
+
if key:
|
201 |
+
headers["Authorization"] = f"Bearer {key}"
|
202 |
+
|
203 |
r = await session.post(
|
204 |
url,
|
205 |
data=payload,
|
206 |
+
headers=headers,
|
207 |
)
|
208 |
r.raise_for_status()
|
209 |
|
|
|
246 |
for idx, model_list in enumerate(model_lists):
|
247 |
if model_list is not None:
|
248 |
for model in model_list:
|
249 |
+
id = model["model"]
|
250 |
+
if id not in merged_models:
|
251 |
model["urls"] = [idx]
|
252 |
+
merged_models[id] = model
|
253 |
else:
|
254 |
+
merged_models[id]["urls"].append(idx)
|
255 |
|
256 |
return list(merged_models.values())
|
257 |
|
258 |
|
259 |
async def get_all_models():
|
260 |
log.info("get_all_models()")
|
|
|
261 |
if app.state.config.ENABLE_OLLAMA_API:
|
262 |
+
tasks = []
|
263 |
+
for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS):
|
264 |
+
if url not in app.state.config.OLLAMA_API_CONFIGS:
|
265 |
+
tasks.append(aiohttp_get(f"{url}/api/tags"))
|
266 |
+
else:
|
267 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
268 |
+
enable = api_config.get("enable", True)
|
269 |
+
key = api_config.get("key", None)
|
270 |
+
|
271 |
+
if enable:
|
272 |
+
tasks.append(aiohttp_get(f"{url}/api/tags", key))
|
273 |
+
else:
|
274 |
+
tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
|
275 |
+
|
276 |
responses = await asyncio.gather(*tasks)
|
277 |
|
278 |
+
for idx, response in enumerate(responses):
|
279 |
+
if response:
|
280 |
+
url = app.state.config.OLLAMA_BASE_URLS[idx]
|
281 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
282 |
+
|
283 |
+
prefix_id = api_config.get("prefix_id", None)
|
284 |
+
model_ids = api_config.get("model_ids", [])
|
285 |
+
|
286 |
+
if len(model_ids) != 0 and "models" in response:
|
287 |
+
response["models"] = list(
|
288 |
+
filter(
|
289 |
+
lambda model: model["model"] in model_ids,
|
290 |
+
response["models"],
|
291 |
+
)
|
292 |
+
)
|
293 |
+
|
294 |
+
if prefix_id:
|
295 |
+
for model in response.get("models", []):
|
296 |
+
model["model"] = f"{prefix_id}.{model['model']}"
|
297 |
+
|
298 |
+
print(responses)
|
299 |
+
|
300 |
models = {
|
301 |
"models": merge_models_lists(
|
302 |
map(
|
303 |
+
lambda response: response.get("models", []) if response else None,
|
304 |
+
responses,
|
305 |
)
|
306 |
)
|
307 |
}
|
|
|
309 |
else:
|
310 |
models = {"models": []}
|
311 |
|
|
|
|
|
312 |
return models
|
313 |
|
314 |
|
|
|
317 |
async def get_ollama_tags(
|
318 |
url_idx: Optional[int] = None, user=Depends(get_verified_user)
|
319 |
):
|
320 |
+
models = []
|
321 |
if url_idx is None:
|
322 |
models = await get_all_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
323 |
else:
|
324 |
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
325 |
|
326 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
327 |
+
key = api_config.get("key", None)
|
328 |
+
|
329 |
+
headers = {}
|
330 |
+
if key:
|
331 |
+
headers["Authorization"] = f"Bearer {key}"
|
332 |
+
|
333 |
r = None
|
334 |
try:
|
335 |
+
r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers)
|
336 |
r.raise_for_status()
|
337 |
|
338 |
+
models = r.json()
|
339 |
except Exception as e:
|
340 |
log.exception(e)
|
341 |
error_detail = "Open WebUI: Server Connection Error"
|
|
|
352 |
detail=error_detail,
|
353 |
)
|
354 |
|
355 |
+
if user.role == "user":
|
356 |
+
# Filter models based on user access control
|
357 |
+
filtered_models = []
|
358 |
+
for model in models.get("models", []):
|
359 |
+
model_info = Models.get_model_by_id(model["model"])
|
360 |
+
if model_info:
|
361 |
+
if user.id == model_info.user_id or has_access(
|
362 |
+
user.id, type="read", access_control=model_info.access_control
|
363 |
+
):
|
364 |
+
filtered_models.append(model)
|
365 |
+
models["models"] = filtered_models
|
366 |
+
|
367 |
+
return models
|
368 |
+
|
369 |
|
370 |
@app.get("/api/version")
|
371 |
@app.get("/api/version/{url_idx}")
|
|
|
374 |
if url_idx is None:
|
375 |
# returns lowest version
|
376 |
tasks = [
|
377 |
+
aiohttp_get(
|
378 |
+
f"{url}/api/version",
|
379 |
+
app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
|
380 |
+
)
|
381 |
for url in app.state.config.OLLAMA_BASE_URLS
|
382 |
]
|
383 |
responses = await asyncio.gather(*tasks)
|
|
|
457 |
user=Depends(get_admin_user),
|
458 |
):
|
459 |
if url_idx is None:
|
460 |
+
model_list = await get_all_models()
|
461 |
+
models = {model["model"]: model for model in model_list["models"]}
|
462 |
+
|
463 |
+
if form_data.name in models:
|
464 |
+
url_idx = models[form_data.name]["urls"][0]
|
465 |
else:
|
466 |
raise HTTPException(
|
467 |
status_code=400,
|
|
|
510 |
user=Depends(get_admin_user),
|
511 |
):
|
512 |
if url_idx is None:
|
513 |
+
model_list = await get_all_models()
|
514 |
+
models = {model["model"]: model for model in model_list["models"]}
|
515 |
+
|
516 |
+
if form_data.source in models:
|
517 |
+
url_idx = models[form_data.source]["urls"][0]
|
518 |
else:
|
519 |
raise HTTPException(
|
520 |
status_code=400,
|
|
|
523 |
|
524 |
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
525 |
log.info(f"url: {url}")
|
526 |
+
|
527 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
528 |
+
key = api_config.get("key", None)
|
529 |
+
|
530 |
+
headers = {"Content-Type": "application/json"}
|
531 |
+
if key:
|
532 |
+
headers["Authorization"] = f"Bearer {key}"
|
533 |
+
|
534 |
r = requests.request(
|
535 |
method="POST",
|
536 |
url=f"{url}/api/copy",
|
537 |
+
headers=headers,
|
538 |
data=form_data.model_dump_json(exclude_none=True).encode(),
|
539 |
)
|
540 |
|
|
|
569 |
user=Depends(get_admin_user),
|
570 |
):
|
571 |
if url_idx is None:
|
572 |
+
model_list = await get_all_models()
|
573 |
+
models = {model["model"]: model for model in model_list["models"]}
|
574 |
+
|
575 |
+
if form_data.name in models:
|
576 |
+
url_idx = models[form_data.name]["urls"][0]
|
577 |
else:
|
578 |
raise HTTPException(
|
579 |
status_code=400,
|
|
|
583 |
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
584 |
log.info(f"url: {url}")
|
585 |
|
586 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
587 |
+
key = api_config.get("key", None)
|
588 |
+
|
589 |
+
headers = {"Content-Type": "application/json"}
|
590 |
+
if key:
|
591 |
+
headers["Authorization"] = f"Bearer {key}"
|
592 |
+
|
593 |
r = requests.request(
|
594 |
method="DELETE",
|
595 |
url=f"{url}/api/delete",
|
|
|
596 |
data=form_data.model_dump_json(exclude_none=True).encode(),
|
597 |
+
headers=headers,
|
598 |
)
|
599 |
try:
|
600 |
r.raise_for_status()
|
|
|
621 |
|
622 |
@app.post("/api/show")
|
623 |
async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
|
624 |
+
model_list = await get_all_models()
|
625 |
+
models = {model["model"]: model for model in model_list["models"]}
|
626 |
+
|
627 |
+
if form_data.name not in models:
|
628 |
raise HTTPException(
|
629 |
status_code=400,
|
630 |
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
|
631 |
)
|
632 |
|
633 |
+
url_idx = random.choice(models[form_data.name]["urls"])
|
634 |
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
635 |
log.info(f"url: {url}")
|
636 |
|
637 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
638 |
+
key = api_config.get("key", None)
|
639 |
+
|
640 |
+
headers = {"Content-Type": "application/json"}
|
641 |
+
if key:
|
642 |
+
headers["Authorization"] = f"Bearer {key}"
|
643 |
+
|
644 |
r = requests.request(
|
645 |
method="POST",
|
646 |
url=f"{url}/api/show",
|
647 |
+
headers=headers,
|
648 |
data=form_data.model_dump_json(exclude_none=True).encode(),
|
649 |
)
|
650 |
try:
|
|
|
700 |
url_idx: Optional[int] = None,
|
701 |
user=Depends(get_verified_user),
|
702 |
):
|
703 |
+
return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
|
704 |
|
705 |
|
706 |
+
async def generate_ollama_embeddings(
|
707 |
form_data: GenerateEmbeddingsForm,
|
708 |
url_idx: Optional[int] = None,
|
709 |
):
|
710 |
log.info(f"generate_ollama_embeddings {form_data}")
|
711 |
|
712 |
if url_idx is None:
|
713 |
+
model_list = await get_all_models()
|
714 |
+
models = {model["model"]: model for model in model_list["models"]}
|
715 |
+
|
716 |
model = form_data.model
|
717 |
|
718 |
if ":" not in model:
|
719 |
model = f"{model}:latest"
|
720 |
|
721 |
+
if model in models:
|
722 |
+
url_idx = random.choice(models[model]["urls"])
|
723 |
else:
|
724 |
raise HTTPException(
|
725 |
status_code=400,
|
|
|
729 |
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
730 |
log.info(f"url: {url}")
|
731 |
|
732 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
733 |
+
key = api_config.get("key", None)
|
734 |
+
|
735 |
+
headers = {"Content-Type": "application/json"}
|
736 |
+
if key:
|
737 |
+
headers["Authorization"] = f"Bearer {key}"
|
738 |
+
|
739 |
r = requests.request(
|
740 |
method="POST",
|
741 |
url=f"{url}/api/embeddings",
|
742 |
+
headers=headers,
|
743 |
data=form_data.model_dump_json(exclude_none=True).encode(),
|
744 |
)
|
745 |
try:
|
|
|
770 |
)
|
771 |
|
772 |
|
773 |
+
async def generate_ollama_batch_embeddings(
|
774 |
form_data: GenerateEmbedForm,
|
775 |
url_idx: Optional[int] = None,
|
776 |
):
|
777 |
log.info(f"generate_ollama_batch_embeddings {form_data}")
|
778 |
|
779 |
if url_idx is None:
|
780 |
+
model_list = await get_all_models()
|
781 |
+
models = {model["model"]: model for model in model_list["models"]}
|
782 |
+
|
783 |
model = form_data.model
|
784 |
|
785 |
if ":" not in model:
|
786 |
model = f"{model}:latest"
|
787 |
|
788 |
+
if model in models:
|
789 |
+
url_idx = random.choice(models[model]["urls"])
|
790 |
else:
|
791 |
raise HTTPException(
|
792 |
status_code=400,
|
|
|
796 |
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
797 |
log.info(f"url: {url}")
|
798 |
|
799 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
800 |
+
key = api_config.get("key", None)
|
801 |
+
|
802 |
+
headers = {"Content-Type": "application/json"}
|
803 |
+
if key:
|
804 |
+
headers["Authorization"] = f"Bearer {key}"
|
805 |
+
|
806 |
r = requests.request(
|
807 |
method="POST",
|
808 |
url=f"{url}/api/embed",
|
809 |
+
headers=headers,
|
810 |
data=form_data.model_dump_json(exclude_none=True).encode(),
|
811 |
)
|
812 |
try:
|
|
|
856 |
user=Depends(get_verified_user),
|
857 |
):
|
858 |
if url_idx is None:
|
859 |
+
model_list = await get_all_models()
|
860 |
+
models = {model["model"]: model for model in model_list["models"]}
|
861 |
+
|
862 |
model = form_data.model
|
863 |
|
864 |
if ":" not in model:
|
865 |
model = f"{model}:latest"
|
866 |
|
867 |
+
if model in models:
|
868 |
+
url_idx = random.choice(models[model]["urls"])
|
869 |
else:
|
870 |
raise HTTPException(
|
871 |
status_code=400,
|
|
|
873 |
)
|
874 |
|
875 |
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
876 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
877 |
+
prefix_id = api_config.get("prefix_id", None)
|
878 |
+
if prefix_id:
|
879 |
+
form_data.model = form_data.model.replace(f"{prefix_id}.", "")
|
880 |
log.info(f"url: {url}")
|
881 |
|
882 |
return await post_streaming_url(
|
|
|
900 |
keep_alive: Optional[Union[int, str]] = None
|
901 |
|
902 |
|
903 |
+
async def get_ollama_url(url_idx: Optional[int], model: str):
|
904 |
if url_idx is None:
|
905 |
+
model_list = await get_all_models()
|
906 |
+
models = {model["model"]: model for model in model_list["models"]}
|
907 |
+
|
908 |
+
if model not in models:
|
909 |
raise HTTPException(
|
910 |
status_code=400,
|
911 |
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
|
912 |
)
|
913 |
+
url_idx = random.choice(models[model]["urls"])
|
914 |
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
915 |
return url
|
916 |
|
|
|
928 |
if "metadata" in payload:
|
929 |
del payload["metadata"]
|
930 |
|
931 |
+
model_id = payload["model"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
932 |
model_info = Models.get_model_by_id(model_id)
|
933 |
|
934 |
if model_info:
|
|
|
946 |
)
|
947 |
payload = apply_model_system_prompt_to_body(params, payload, user)
|
948 |
|
949 |
+
# Check if user has access to the model
|
950 |
+
if not bypass_filter and user.role == "user":
|
951 |
+
if not (
|
952 |
+
user.id == model_info.user_id
|
953 |
+
or has_access(
|
954 |
+
user.id, type="read", access_control=model_info.access_control
|
955 |
+
)
|
956 |
+
):
|
957 |
+
raise HTTPException(
|
958 |
+
status_code=403,
|
959 |
+
detail="Model not found",
|
960 |
+
)
|
961 |
+
elif not bypass_filter:
|
962 |
+
if user.role != "admin":
|
963 |
+
raise HTTPException(
|
964 |
+
status_code=403,
|
965 |
+
detail="Model not found",
|
966 |
+
)
|
967 |
+
|
968 |
if ":" not in payload["model"]:
|
969 |
payload["model"] = f"{payload['model']}:latest"
|
970 |
|
971 |
+
url = await get_ollama_url(url_idx, payload["model"])
|
972 |
log.info(f"url: {url}")
|
973 |
log.debug(f"generate_chat_completion() - 2.payload = {payload}")
|
974 |
|
975 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
976 |
+
prefix_id = api_config.get("prefix_id", None)
|
977 |
+
if prefix_id:
|
978 |
+
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
979 |
+
|
980 |
return await post_streaming_url(
|
981 |
f"{url}/api/chat",
|
982 |
json.dumps(payload),
|
|
|
993 |
|
994 |
class OpenAIChatMessage(BaseModel):
|
995 |
role: str
|
996 |
+
content: Union[str, list[OpenAIChatMessageContent]]
|
997 |
|
998 |
model_config = ConfigDict(extra="allow")
|
999 |
|
|
|
1012 |
url_idx: Optional[int] = None,
|
1013 |
user=Depends(get_verified_user),
|
1014 |
):
|
1015 |
+
try:
|
1016 |
+
completion_form = OpenAIChatCompletionForm(**form_data)
|
1017 |
+
except Exception as e:
|
1018 |
+
log.exception(e)
|
1019 |
+
raise HTTPException(
|
1020 |
+
status_code=400,
|
1021 |
+
detail=str(e),
|
1022 |
+
)
|
1023 |
+
|
1024 |
payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
|
1025 |
if "metadata" in payload:
|
1026 |
del payload["metadata"]
|
1027 |
|
1028 |
model_id = completion_form.model
|
1029 |
+
if ":" not in model_id:
|
1030 |
+
model_id = f"{model_id}:latest"
|
|
|
|
|
|
|
|
|
|
|
1031 |
|
1032 |
model_info = Models.get_model_by_id(model_id)
|
|
|
1033 |
if model_info:
|
1034 |
if model_info.base_model_id:
|
1035 |
payload["model"] = model_info.base_model_id
|
|
|
1040 |
payload = apply_model_params_to_body_openai(params, payload)
|
1041 |
payload = apply_model_system_prompt_to_body(params, payload, user)
|
1042 |
|
1043 |
+
# Check if user has access to the model
|
1044 |
+
if user.role == "user":
|
1045 |
+
if not (
|
1046 |
+
user.id == model_info.user_id
|
1047 |
+
or has_access(
|
1048 |
+
user.id, type="read", access_control=model_info.access_control
|
1049 |
+
)
|
1050 |
+
):
|
1051 |
+
raise HTTPException(
|
1052 |
+
status_code=403,
|
1053 |
+
detail="Model not found",
|
1054 |
+
)
|
1055 |
+
else:
|
1056 |
+
if user.role != "admin":
|
1057 |
+
raise HTTPException(
|
1058 |
+
status_code=403,
|
1059 |
+
detail="Model not found",
|
1060 |
+
)
|
1061 |
+
|
1062 |
if ":" not in payload["model"]:
|
1063 |
payload["model"] = f"{payload['model']}:latest"
|
1064 |
|
1065 |
+
url = await get_ollama_url(url_idx, payload["model"])
|
1066 |
log.info(f"url: {url}")
|
1067 |
|
1068 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
1069 |
+
prefix_id = api_config.get("prefix_id", None)
|
1070 |
+
if prefix_id:
|
1071 |
+
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
1072 |
+
|
1073 |
return await post_streaming_url(
|
1074 |
f"{url}/v1/chat/completions",
|
1075 |
json.dumps(payload),
|
|
|
1083 |
url_idx: Optional[int] = None,
|
1084 |
user=Depends(get_verified_user),
|
1085 |
):
|
1086 |
+
|
1087 |
+
models = []
|
1088 |
if url_idx is None:
|
1089 |
+
model_list = await get_all_models()
|
1090 |
+
models = [
|
1091 |
+
{
|
1092 |
+
"id": model["model"],
|
1093 |
+
"object": "model",
|
1094 |
+
"created": int(time.time()),
|
1095 |
+
"owned_by": "openai",
|
1096 |
+
}
|
1097 |
+
for model in model_list["models"]
|
1098 |
+
]
|
1099 |
|
1100 |
+
else:
|
1101 |
+
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
1102 |
+
try:
|
1103 |
+
r = requests.request(method="GET", url=f"{url}/api/tags")
|
1104 |
+
r.raise_for_status()
|
1105 |
+
|
1106 |
+
model_list = r.json()
|
|
|
|
|
1107 |
|
1108 |
+
models = [
|
|
|
1109 |
{
|
1110 |
"id": model["model"],
|
1111 |
"object": "model",
|
|
|
1113 |
"owned_by": "openai",
|
1114 |
}
|
1115 |
for model in models["models"]
|
1116 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1117 |
except Exception as e:
|
1118 |
log.exception(e)
|
1119 |
error_detail = "Open WebUI: Server Connection Error"
|
|
|
1130 |
detail=error_detail,
|
1131 |
)
|
1132 |
|
1133 |
+
if user.role == "user":
|
1134 |
+
# Filter models based on user access control
|
1135 |
+
filtered_models = []
|
1136 |
+
for model in models:
|
1137 |
+
model_info = Models.get_model_by_id(model["id"])
|
1138 |
+
if model_info:
|
1139 |
+
if user.id == model_info.user_id or has_access(
|
1140 |
+
user.id, type="read", access_control=model_info.access_control
|
1141 |
+
):
|
1142 |
+
filtered_models.append(model)
|
1143 |
+
models = filtered_models
|
1144 |
+
|
1145 |
+
return {
|
1146 |
+
"data": models,
|
1147 |
+
"object": "list",
|
1148 |
+
}
|
1149 |
+
|
1150 |
|
1151 |
class UrlForm(BaseModel):
|
1152 |
url: str
|
backend/open_webui/apps/openai/main.py
CHANGED
@@ -11,20 +11,20 @@ from open_webui.apps.webui.models.models import Models
|
|
11 |
from open_webui.config import (
|
12 |
CACHE_DIR,
|
13 |
CORS_ALLOW_ORIGIN,
|
14 |
-
ENABLE_MODEL_FILTER,
|
15 |
ENABLE_OPENAI_API,
|
16 |
-
MODEL_FILTER_LIST,
|
17 |
OPENAI_API_BASE_URLS,
|
18 |
OPENAI_API_KEYS,
|
|
|
19 |
AppConfig,
|
20 |
)
|
21 |
from open_webui.env import (
|
22 |
AIOHTTP_CLIENT_TIMEOUT,
|
23 |
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
|
|
|
24 |
)
|
25 |
|
26 |
from open_webui.constants import ERROR_MESSAGES
|
27 |
-
from open_webui.env import SRC_LOG_LEVELS
|
28 |
from fastapi import Depends, FastAPI, HTTPException, Request
|
29 |
from fastapi.middleware.cors import CORSMiddleware
|
30 |
from fastapi.responses import FileResponse, StreamingResponse
|
@@ -37,11 +37,20 @@ from open_webui.utils.payload import (
|
|
37 |
)
|
38 |
|
39 |
from open_webui.utils.utils import get_admin_user, get_verified_user
|
|
|
|
|
40 |
|
41 |
log = logging.getLogger(__name__)
|
42 |
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
|
43 |
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
app.add_middleware(
|
46 |
CORSMiddleware,
|
47 |
allow_origins=CORS_ALLOW_ORIGIN,
|
@@ -52,69 +61,66 @@ app.add_middleware(
|
|
52 |
|
53 |
app.state.config = AppConfig()
|
54 |
|
55 |
-
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
56 |
-
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
57 |
-
|
58 |
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
|
59 |
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
|
60 |
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
|
61 |
-
|
62 |
-
app.state.MODELS = {}
|
63 |
-
|
64 |
-
|
65 |
-
@app.middleware("http")
|
66 |
-
async def check_url(request: Request, call_next):
|
67 |
-
if len(app.state.MODELS) == 0:
|
68 |
-
await get_all_models()
|
69 |
-
|
70 |
-
response = await call_next(request)
|
71 |
-
return response
|
72 |
|
73 |
|
74 |
@app.get("/config")
|
75 |
async def get_config(user=Depends(get_admin_user)):
|
76 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
|
79 |
class OpenAIConfigForm(BaseModel):
|
80 |
-
|
|
|
|
|
|
|
81 |
|
82 |
|
83 |
@app.post("/config/update")
|
84 |
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
|
85 |
-
app.state.config.ENABLE_OPENAI_API = form_data.
|
86 |
-
return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
|
87 |
-
|
88 |
-
|
89 |
-
class UrlsUpdateForm(BaseModel):
|
90 |
-
urls: list[str]
|
91 |
|
|
|
|
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
|
109 |
-
|
110 |
-
async def get_openai_keys(user=Depends(get_admin_user)):
|
111 |
-
return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
|
112 |
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
118 |
|
119 |
|
120 |
@app.post("/audio/speech")
|
@@ -140,6 +146,11 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
|
140 |
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
|
141 |
headers["HTTP-Referer"] = "https://openwebui.com/"
|
142 |
headers["X-Title"] = "Open WebUI"
|
|
|
|
|
|
|
|
|
|
|
143 |
r = None
|
144 |
try:
|
145 |
r = requests.post(
|
@@ -181,10 +192,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
|
181 |
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
|
182 |
|
183 |
|
184 |
-
async def
|
185 |
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
186 |
try:
|
187 |
-
headers = {"Authorization": f"Bearer {key}"}
|
188 |
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
189 |
async with session.get(url, headers=headers) as response:
|
190 |
return await response.json()
|
@@ -239,12 +250,8 @@ def merge_models_lists(model_lists):
|
|
239 |
return merged_list
|
240 |
|
241 |
|
242 |
-
def
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
async def get_all_models_raw() -> list:
|
247 |
-
if is_openai_api_disabled():
|
248 |
return []
|
249 |
|
250 |
# Check if API KEYS length is same than API URLS length
|
@@ -260,33 +267,67 @@ async def get_all_models_raw() -> list:
|
|
260 |
else:
|
261 |
app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
|
262 |
|
263 |
-
tasks = [
|
264 |
-
|
265 |
-
|
266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
|
268 |
responses = await asyncio.gather(*tasks)
|
269 |
-
log.debug(f"get_all_models:responses() {responses}")
|
270 |
|
271 |
-
|
|
|
|
|
|
|
272 |
|
|
|
273 |
|
274 |
-
|
275 |
-
|
|
|
276 |
|
|
|
277 |
|
278 |
-
|
279 |
-
async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ...
|
280 |
|
281 |
|
282 |
-
async def get_all_models(
|
283 |
log.info("get_all_models()")
|
284 |
-
if is_openai_api_disabled():
|
285 |
-
return [] if raw else {"data": []}
|
286 |
|
287 |
-
|
288 |
-
|
289 |
-
|
|
|
290 |
|
291 |
def extract_data(response):
|
292 |
if response and "data" in response:
|
@@ -296,9 +337,7 @@ async def get_all_models(raw=False) -> dict[str, list] | list:
|
|
296 |
return None
|
297 |
|
298 |
models = {"data": merge_models_lists(map(extract_data, responses))}
|
299 |
-
|
300 |
log.debug(f"models: {models}")
|
301 |
-
app.state.MODELS = {model["id"]: model for model in models["data"]}
|
302 |
|
303 |
return models
|
304 |
|
@@ -306,18 +345,12 @@ async def get_all_models(raw=False) -> dict[str, list] | list:
|
|
306 |
@app.get("/models")
|
307 |
@app.get("/models/{url_idx}")
|
308 |
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
|
|
|
|
|
|
|
|
|
309 |
if url_idx is None:
|
310 |
models = await get_all_models()
|
311 |
-
if app.state.config.ENABLE_MODEL_FILTER:
|
312 |
-
if user.role == "user":
|
313 |
-
models["data"] = list(
|
314 |
-
filter(
|
315 |
-
lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
|
316 |
-
models["data"],
|
317 |
-
)
|
318 |
-
)
|
319 |
-
return models
|
320 |
-
return models
|
321 |
else:
|
322 |
url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
|
323 |
key = app.state.config.OPENAI_API_KEYS[url_idx]
|
@@ -326,56 +359,126 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
|
|
326 |
headers["Authorization"] = f"Bearer {key}"
|
327 |
headers["Content-Type"] = "application/json"
|
328 |
|
|
|
|
|
|
|
|
|
|
|
|
|
329 |
r = None
|
330 |
|
331 |
-
|
332 |
-
|
333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
|
335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
336 |
|
337 |
-
|
338 |
-
# Filter the response data
|
339 |
-
response_data["data"] = [
|
340 |
-
model
|
341 |
-
for model in response_data["data"]
|
342 |
-
if not any(
|
343 |
-
name in model["id"]
|
344 |
-
for name in [
|
345 |
-
"babbage",
|
346 |
-
"dall-e",
|
347 |
-
"davinci",
|
348 |
-
"embedding",
|
349 |
-
"tts",
|
350 |
-
"whisper",
|
351 |
-
]
|
352 |
-
)
|
353 |
-
]
|
354 |
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
362 |
if "error" in res:
|
363 |
-
error_detail = f"External: {res['error']}"
|
364 |
-
|
365 |
-
|
|
|
|
|
366 |
|
|
|
|
|
|
|
|
|
367 |
raise HTTPException(
|
368 |
-
status_code=
|
369 |
-
detail=error_detail,
|
370 |
)
|
|
|
|
|
|
|
|
|
|
|
371 |
|
372 |
|
373 |
@app.post("/chat/completions")
|
374 |
-
@app.post("/chat/completions/{url_idx}")
|
375 |
async def generate_chat_completion(
|
376 |
form_data: dict,
|
377 |
-
url_idx: Optional[int] = None,
|
378 |
user=Depends(get_verified_user),
|
|
|
379 |
):
|
380 |
idx = 0
|
381 |
payload = {**form_data}
|
@@ -386,6 +489,7 @@ async def generate_chat_completion(
|
|
386 |
model_id = form_data.get("model")
|
387 |
model_info = Models.get_model_by_id(model_id)
|
388 |
|
|
|
389 |
if model_info:
|
390 |
if model_info.base_model_id:
|
391 |
payload["model"] = model_info.base_model_id
|
@@ -394,9 +498,52 @@ async def generate_chat_completion(
|
|
394 |
payload = apply_model_params_to_body_openai(params, payload)
|
395 |
payload = apply_model_system_prompt_to_body(params, payload, user)
|
396 |
|
397 |
-
|
398 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
399 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
if "pipeline" in model and model.get("pipeline"):
|
401 |
payload["user"] = {
|
402 |
"name": user.name,
|
@@ -407,8 +554,9 @@ async def generate_chat_completion(
|
|
407 |
|
408 |
url = app.state.config.OPENAI_API_BASE_URLS[idx]
|
409 |
key = app.state.config.OPENAI_API_KEYS[idx]
|
410 |
-
is_o1 = payload["model"].lower().startswith("o1-")
|
411 |
|
|
|
|
|
412 |
# Change max_completion_tokens to max_tokens (Backward compatible)
|
413 |
if "api.openai.com" not in url and not is_o1:
|
414 |
if "max_completion_tokens" in payload:
|
@@ -437,6 +585,11 @@ async def generate_chat_completion(
|
|
437 |
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
|
438 |
headers["HTTP-Referer"] = "https://openwebui.com/"
|
439 |
headers["X-Title"] = "Open WebUI"
|
|
|
|
|
|
|
|
|
|
|
440 |
|
441 |
r = None
|
442 |
session = None
|
@@ -505,6 +658,11 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
|
505 |
headers = {}
|
506 |
headers["Authorization"] = f"Bearer {key}"
|
507 |
headers["Content-Type"] = "application/json"
|
|
|
|
|
|
|
|
|
|
|
508 |
|
509 |
r = None
|
510 |
session = None
|
|
|
11 |
from open_webui.config import (
|
12 |
CACHE_DIR,
|
13 |
CORS_ALLOW_ORIGIN,
|
|
|
14 |
ENABLE_OPENAI_API,
|
|
|
15 |
OPENAI_API_BASE_URLS,
|
16 |
OPENAI_API_KEYS,
|
17 |
+
OPENAI_API_CONFIGS,
|
18 |
AppConfig,
|
19 |
)
|
20 |
from open_webui.env import (
|
21 |
AIOHTTP_CLIENT_TIMEOUT,
|
22 |
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
|
23 |
+
ENABLE_FORWARD_USER_INFO_HEADERS,
|
24 |
)
|
25 |
|
26 |
from open_webui.constants import ERROR_MESSAGES
|
27 |
+
from open_webui.env import ENV, SRC_LOG_LEVELS
|
28 |
from fastapi import Depends, FastAPI, HTTPException, Request
|
29 |
from fastapi.middleware.cors import CORSMiddleware
|
30 |
from fastapi.responses import FileResponse, StreamingResponse
|
|
|
37 |
)
|
38 |
|
39 |
from open_webui.utils.utils import get_admin_user, get_verified_user
|
40 |
+
from open_webui.utils.access_control import has_access
|
41 |
+
|
42 |
|
43 |
log = logging.getLogger(__name__)
|
44 |
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
|
45 |
|
46 |
+
|
47 |
+
app = FastAPI(
|
48 |
+
docs_url="/docs" if ENV == "dev" else None,
|
49 |
+
openapi_url="/openapi.json" if ENV == "dev" else None,
|
50 |
+
redoc_url=None,
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
app.add_middleware(
|
55 |
CORSMiddleware,
|
56 |
allow_origins=CORS_ALLOW_ORIGIN,
|
|
|
61 |
|
62 |
app.state.config = AppConfig()
|
63 |
|
|
|
|
|
|
|
64 |
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
|
65 |
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
|
66 |
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
|
67 |
+
app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
|
70 |
@app.get("/config")
|
71 |
async def get_config(user=Depends(get_admin_user)):
|
72 |
+
return {
|
73 |
+
"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API,
|
74 |
+
"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS,
|
75 |
+
"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS,
|
76 |
+
"OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS,
|
77 |
+
}
|
78 |
|
79 |
|
80 |
class OpenAIConfigForm(BaseModel):
|
81 |
+
ENABLE_OPENAI_API: Optional[bool] = None
|
82 |
+
OPENAI_API_BASE_URLS: list[str]
|
83 |
+
OPENAI_API_KEYS: list[str]
|
84 |
+
OPENAI_API_CONFIGS: dict
|
85 |
|
86 |
|
87 |
@app.post("/config/update")
|
88 |
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
|
89 |
+
app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
+
app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS
|
92 |
+
app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS
|
93 |
|
94 |
+
# Check if API KEYS length is same than API URLS length
|
95 |
+
if len(app.state.config.OPENAI_API_KEYS) != len(
|
96 |
+
app.state.config.OPENAI_API_BASE_URLS
|
97 |
+
):
|
98 |
+
if len(app.state.config.OPENAI_API_KEYS) > len(
|
99 |
+
app.state.config.OPENAI_API_BASE_URLS
|
100 |
+
):
|
101 |
+
app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
|
102 |
+
: len(app.state.config.OPENAI_API_BASE_URLS)
|
103 |
+
]
|
104 |
+
else:
|
105 |
+
app.state.config.OPENAI_API_KEYS += [""] * (
|
106 |
+
len(app.state.config.OPENAI_API_BASE_URLS)
|
107 |
+
- len(app.state.config.OPENAI_API_KEYS)
|
108 |
+
)
|
109 |
|
110 |
+
app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS
|
|
|
|
|
111 |
|
112 |
+
# Remove any extra configs
|
113 |
+
config_urls = app.state.config.OPENAI_API_CONFIGS.keys()
|
114 |
+
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS):
|
115 |
+
if url not in config_urls:
|
116 |
+
app.state.config.OPENAI_API_CONFIGS.pop(url, None)
|
117 |
|
118 |
+
return {
|
119 |
+
"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API,
|
120 |
+
"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS,
|
121 |
+
"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS,
|
122 |
+
"OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS,
|
123 |
+
}
|
124 |
|
125 |
|
126 |
@app.post("/audio/speech")
|
|
|
146 |
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
|
147 |
headers["HTTP-Referer"] = "https://openwebui.com/"
|
148 |
headers["X-Title"] = "Open WebUI"
|
149 |
+
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
150 |
+
headers["X-OpenWebUI-User-Name"] = user.name
|
151 |
+
headers["X-OpenWebUI-User-Id"] = user.id
|
152 |
+
headers["X-OpenWebUI-User-Email"] = user.email
|
153 |
+
headers["X-OpenWebUI-User-Role"] = user.role
|
154 |
r = None
|
155 |
try:
|
156 |
r = requests.post(
|
|
|
192 |
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
|
193 |
|
194 |
|
195 |
+
async def aiohttp_get(url, key=None):
|
196 |
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
197 |
try:
|
198 |
+
headers = {"Authorization": f"Bearer {key}"} if key else {}
|
199 |
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
200 |
async with session.get(url, headers=headers) as response:
|
201 |
return await response.json()
|
|
|
250 |
return merged_list
|
251 |
|
252 |
|
253 |
+
async def get_all_models_responses() -> list:
|
254 |
+
if not app.state.config.ENABLE_OPENAI_API:
|
|
|
|
|
|
|
|
|
255 |
return []
|
256 |
|
257 |
# Check if API KEYS length is same than API URLS length
|
|
|
267 |
else:
|
268 |
app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
|
269 |
|
270 |
+
tasks = []
|
271 |
+
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS):
|
272 |
+
if url not in app.state.config.OPENAI_API_CONFIGS:
|
273 |
+
tasks.append(
|
274 |
+
aiohttp_get(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
|
275 |
+
)
|
276 |
+
else:
|
277 |
+
api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {})
|
278 |
+
|
279 |
+
enable = api_config.get("enable", True)
|
280 |
+
model_ids = api_config.get("model_ids", [])
|
281 |
+
|
282 |
+
if enable:
|
283 |
+
if len(model_ids) == 0:
|
284 |
+
tasks.append(
|
285 |
+
aiohttp_get(
|
286 |
+
f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]
|
287 |
+
)
|
288 |
+
)
|
289 |
+
else:
|
290 |
+
model_list = {
|
291 |
+
"object": "list",
|
292 |
+
"data": [
|
293 |
+
{
|
294 |
+
"id": model_id,
|
295 |
+
"name": model_id,
|
296 |
+
"owned_by": "openai",
|
297 |
+
"openai": {"id": model_id},
|
298 |
+
"urlIdx": idx,
|
299 |
+
}
|
300 |
+
for model_id in model_ids
|
301 |
+
],
|
302 |
+
}
|
303 |
+
|
304 |
+
tasks.append(asyncio.ensure_future(asyncio.sleep(0, model_list)))
|
305 |
|
306 |
responses = await asyncio.gather(*tasks)
|
|
|
307 |
|
308 |
+
for idx, response in enumerate(responses):
|
309 |
+
if response:
|
310 |
+
url = app.state.config.OPENAI_API_BASE_URLS[idx]
|
311 |
+
api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {})
|
312 |
|
313 |
+
prefix_id = api_config.get("prefix_id", None)
|
314 |
|
315 |
+
if prefix_id:
|
316 |
+
for model in response["data"]:
|
317 |
+
model["id"] = f"{prefix_id}.{model['id']}"
|
318 |
|
319 |
+
log.debug(f"get_all_models:responses() {responses}")
|
320 |
|
321 |
+
return responses
|
|
|
322 |
|
323 |
|
324 |
+
async def get_all_models() -> dict[str, list]:
|
325 |
log.info("get_all_models()")
|
|
|
|
|
326 |
|
327 |
+
if not app.state.config.ENABLE_OPENAI_API:
|
328 |
+
return {"data": []}
|
329 |
+
|
330 |
+
responses = await get_all_models_responses()
|
331 |
|
332 |
def extract_data(response):
|
333 |
if response and "data" in response:
|
|
|
337 |
return None
|
338 |
|
339 |
models = {"data": merge_models_lists(map(extract_data, responses))}
|
|
|
340 |
log.debug(f"models: {models}")
|
|
|
341 |
|
342 |
return models
|
343 |
|
|
|
345 |
@app.get("/models")
|
346 |
@app.get("/models/{url_idx}")
|
347 |
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
|
348 |
+
models = {
|
349 |
+
"data": [],
|
350 |
+
}
|
351 |
+
|
352 |
if url_idx is None:
|
353 |
models = await get_all_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
354 |
else:
|
355 |
url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
|
356 |
key = app.state.config.OPENAI_API_KEYS[url_idx]
|
|
|
359 |
headers["Authorization"] = f"Bearer {key}"
|
360 |
headers["Content-Type"] = "application/json"
|
361 |
|
362 |
+
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
363 |
+
headers["X-OpenWebUI-User-Name"] = user.name
|
364 |
+
headers["X-OpenWebUI-User-Id"] = user.id
|
365 |
+
headers["X-OpenWebUI-User-Email"] = user.email
|
366 |
+
headers["X-OpenWebUI-User-Role"] = user.role
|
367 |
+
|
368 |
r = None
|
369 |
|
370 |
+
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
371 |
+
async with aiohttp.ClientSession(timeout=timeout) as session:
|
372 |
+
try:
|
373 |
+
async with session.get(f"{url}/models", headers=headers) as r:
|
374 |
+
if r.status != 200:
|
375 |
+
# Extract response error details if available
|
376 |
+
error_detail = f"HTTP Error: {r.status}"
|
377 |
+
res = await r.json()
|
378 |
+
if "error" in res:
|
379 |
+
error_detail = f"External Error: {res['error']}"
|
380 |
+
raise Exception(error_detail)
|
381 |
+
|
382 |
+
response_data = await r.json()
|
383 |
+
|
384 |
+
# Check if we're calling OpenAI API based on the URL
|
385 |
+
if "api.openai.com" in url:
|
386 |
+
# Filter models according to the specified conditions
|
387 |
+
response_data["data"] = [
|
388 |
+
model
|
389 |
+
for model in response_data.get("data", [])
|
390 |
+
if not any(
|
391 |
+
name in model["id"]
|
392 |
+
for name in [
|
393 |
+
"babbage",
|
394 |
+
"dall-e",
|
395 |
+
"davinci",
|
396 |
+
"embedding",
|
397 |
+
"tts",
|
398 |
+
"whisper",
|
399 |
+
]
|
400 |
+
)
|
401 |
+
]
|
402 |
|
403 |
+
models = response_data
|
404 |
+
except aiohttp.ClientError as e:
|
405 |
+
# ClientError covers all aiohttp requests issues
|
406 |
+
log.exception(f"Client error: {str(e)}")
|
407 |
+
# Handle aiohttp-specific connection issues, timeout etc.
|
408 |
+
raise HTTPException(
|
409 |
+
status_code=500, detail="Open WebUI: Server Connection Error"
|
410 |
+
)
|
411 |
+
except Exception as e:
|
412 |
+
log.exception(f"Unexpected error: {e}")
|
413 |
+
# Generic error handler in case parsing JSON or other steps fail
|
414 |
+
error_detail = f"Unexpected error: {str(e)}"
|
415 |
+
raise HTTPException(status_code=500, detail=error_detail)
|
416 |
+
|
417 |
+
if user.role == "user":
|
418 |
+
# Filter models based on user access control
|
419 |
+
filtered_models = []
|
420 |
+
for model in models.get("data", []):
|
421 |
+
model_info = Models.get_model_by_id(model["id"])
|
422 |
+
if model_info:
|
423 |
+
if user.id == model_info.user_id or has_access(
|
424 |
+
user.id, type="read", access_control=model_info.access_control
|
425 |
+
):
|
426 |
+
filtered_models.append(model)
|
427 |
+
models["data"] = filtered_models
|
428 |
|
429 |
+
return models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
430 |
|
431 |
+
|
432 |
+
class ConnectionVerificationForm(BaseModel):
|
433 |
+
url: str
|
434 |
+
key: str
|
435 |
+
|
436 |
+
|
437 |
+
@app.post("/verify")
|
438 |
+
async def verify_connection(
|
439 |
+
form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
|
440 |
+
):
|
441 |
+
url = form_data.url
|
442 |
+
key = form_data.key
|
443 |
+
|
444 |
+
headers = {}
|
445 |
+
headers["Authorization"] = f"Bearer {key}"
|
446 |
+
headers["Content-Type"] = "application/json"
|
447 |
+
|
448 |
+
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
449 |
+
async with aiohttp.ClientSession(timeout=timeout) as session:
|
450 |
+
try:
|
451 |
+
async with session.get(f"{url}/models", headers=headers) as r:
|
452 |
+
if r.status != 200:
|
453 |
+
# Extract response error details if available
|
454 |
+
error_detail = f"HTTP Error: {r.status}"
|
455 |
+
res = await r.json()
|
456 |
if "error" in res:
|
457 |
+
error_detail = f"External Error: {res['error']}"
|
458 |
+
raise Exception(error_detail)
|
459 |
+
|
460 |
+
response_data = await r.json()
|
461 |
+
return response_data
|
462 |
|
463 |
+
except aiohttp.ClientError as e:
|
464 |
+
# ClientError covers all aiohttp requests issues
|
465 |
+
log.exception(f"Client error: {str(e)}")
|
466 |
+
# Handle aiohttp-specific connection issues, timeout etc.
|
467 |
raise HTTPException(
|
468 |
+
status_code=500, detail="Open WebUI: Server Connection Error"
|
|
|
469 |
)
|
470 |
+
except Exception as e:
|
471 |
+
log.exception(f"Unexpected error: {e}")
|
472 |
+
# Generic error handler in case parsing JSON or other steps fail
|
473 |
+
error_detail = f"Unexpected error: {str(e)}"
|
474 |
+
raise HTTPException(status_code=500, detail=error_detail)
|
475 |
|
476 |
|
477 |
@app.post("/chat/completions")
|
|
|
478 |
async def generate_chat_completion(
|
479 |
form_data: dict,
|
|
|
480 |
user=Depends(get_verified_user),
|
481 |
+
bypass_filter: Optional[bool] = False,
|
482 |
):
|
483 |
idx = 0
|
484 |
payload = {**form_data}
|
|
|
489 |
model_id = form_data.get("model")
|
490 |
model_info = Models.get_model_by_id(model_id)
|
491 |
|
492 |
+
# Check model info and override the payload
|
493 |
if model_info:
|
494 |
if model_info.base_model_id:
|
495 |
payload["model"] = model_info.base_model_id
|
|
|
498 |
payload = apply_model_params_to_body_openai(params, payload)
|
499 |
payload = apply_model_system_prompt_to_body(params, payload, user)
|
500 |
|
501 |
+
# Check if user has access to the model
|
502 |
+
if not bypass_filter and user.role == "user":
|
503 |
+
if not (
|
504 |
+
user.id == model_info.user_id
|
505 |
+
or has_access(
|
506 |
+
user.id, type="read", access_control=model_info.access_control
|
507 |
+
)
|
508 |
+
):
|
509 |
+
raise HTTPException(
|
510 |
+
status_code=403,
|
511 |
+
detail="Model not found",
|
512 |
+
)
|
513 |
+
elif not bypass_filter:
|
514 |
+
if user.role != "admin":
|
515 |
+
raise HTTPException(
|
516 |
+
status_code=403,
|
517 |
+
detail="Model not found",
|
518 |
+
)
|
519 |
+
|
520 |
+
# Attemp to get urlIdx from the model
|
521 |
+
models = await get_all_models()
|
522 |
|
523 |
+
# Find the model from the list
|
524 |
+
model = next(
|
525 |
+
(model for model in models["data"] if model["id"] == payload.get("model")),
|
526 |
+
None,
|
527 |
+
)
|
528 |
+
|
529 |
+
if model:
|
530 |
+
idx = model["urlIdx"]
|
531 |
+
else:
|
532 |
+
raise HTTPException(
|
533 |
+
status_code=404,
|
534 |
+
detail="Model not found",
|
535 |
+
)
|
536 |
+
|
537 |
+
# Get the API config for the model
|
538 |
+
api_config = app.state.config.OPENAI_API_CONFIGS.get(
|
539 |
+
app.state.config.OPENAI_API_BASE_URLS[idx], {}
|
540 |
+
)
|
541 |
+
prefix_id = api_config.get("prefix_id", None)
|
542 |
+
|
543 |
+
if prefix_id:
|
544 |
+
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
545 |
+
|
546 |
+
# Add user info to the payload if the model is a pipeline
|
547 |
if "pipeline" in model and model.get("pipeline"):
|
548 |
payload["user"] = {
|
549 |
"name": user.name,
|
|
|
554 |
|
555 |
url = app.state.config.OPENAI_API_BASE_URLS[idx]
|
556 |
key = app.state.config.OPENAI_API_KEYS[idx]
|
|
|
557 |
|
558 |
+
# Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
|
559 |
+
is_o1 = payload["model"].lower().startswith("o1-")
|
560 |
# Change max_completion_tokens to max_tokens (Backward compatible)
|
561 |
if "api.openai.com" not in url and not is_o1:
|
562 |
if "max_completion_tokens" in payload:
|
|
|
585 |
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
|
586 |
headers["HTTP-Referer"] = "https://openwebui.com/"
|
587 |
headers["X-Title"] = "Open WebUI"
|
588 |
+
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
589 |
+
headers["X-OpenWebUI-User-Name"] = user.name
|
590 |
+
headers["X-OpenWebUI-User-Id"] = user.id
|
591 |
+
headers["X-OpenWebUI-User-Email"] = user.email
|
592 |
+
headers["X-OpenWebUI-User-Role"] = user.role
|
593 |
|
594 |
r = None
|
595 |
session = None
|
|
|
658 |
headers = {}
|
659 |
headers["Authorization"] = f"Bearer {key}"
|
660 |
headers["Content-Type"] = "application/json"
|
661 |
+
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
662 |
+
headers["X-OpenWebUI-User-Name"] = user.name
|
663 |
+
headers["X-OpenWebUI-User-Id"] = user.id
|
664 |
+
headers["X-OpenWebUI-User-Email"] = user.email
|
665 |
+
headers["X-OpenWebUI-User-Role"] = user.role
|
666 |
|
667 |
r = None
|
668 |
session = None
|
backend/open_webui/apps/retrieval/loaders/main.py
CHANGED
@@ -159,7 +159,7 @@ class Loader:
|
|
159 |
elif file_ext in ["htm", "html"]:
|
160 |
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
|
161 |
elif file_ext == "md":
|
162 |
-
loader =
|
163 |
elif file_content_type == "application/epub+zip":
|
164 |
loader = UnstructuredEPubLoader(file_path)
|
165 |
elif (
|
|
|
159 |
elif file_ext in ["htm", "html"]:
|
160 |
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
|
161 |
elif file_ext == "md":
|
162 |
+
loader = TextLoader(file_path, autodetect_encoding=True)
|
163 |
elif file_content_type == "application/epub+zip":
|
164 |
loader = UnstructuredEPubLoader(file_path)
|
165 |
elif (
|
backend/open_webui/apps/retrieval/main.py
CHANGED
@@ -37,6 +37,7 @@ from open_webui.apps.retrieval.web.serper import search_serper
|
|
37 |
from open_webui.apps.retrieval.web.serply import search_serply
|
38 |
from open_webui.apps.retrieval.web.serpstack import search_serpstack
|
39 |
from open_webui.apps.retrieval.web.tavily import search_tavily
|
|
|
40 |
|
41 |
|
42 |
from open_webui.apps.retrieval.utils import (
|
@@ -74,6 +75,8 @@ from open_webui.config import (
|
|
74 |
RAG_FILE_MAX_SIZE,
|
75 |
RAG_OPENAI_API_BASE_URL,
|
76 |
RAG_OPENAI_API_KEY,
|
|
|
|
|
77 |
RAG_RELEVANCE_THRESHOLD,
|
78 |
RAG_RERANKING_MODEL,
|
79 |
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
@@ -85,6 +88,7 @@ from open_webui.config import (
|
|
85 |
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
86 |
RAG_WEB_SEARCH_ENGINE,
|
87 |
RAG_WEB_SEARCH_RESULT_COUNT,
|
|
|
88 |
SEARCHAPI_API_KEY,
|
89 |
SEARCHAPI_ENGINE,
|
90 |
SEARXNG_QUERY_URL,
|
@@ -93,13 +97,20 @@ from open_webui.config import (
|
|
93 |
SERPSTACK_API_KEY,
|
94 |
SERPSTACK_HTTPS,
|
95 |
TAVILY_API_KEY,
|
|
|
|
|
96 |
TIKA_SERVER_URL,
|
97 |
UPLOAD_DIR,
|
98 |
YOUTUBE_LOADER_LANGUAGE,
|
|
|
99 |
AppConfig,
|
100 |
)
|
101 |
from open_webui.constants import ERROR_MESSAGES
|
102 |
-
from open_webui.env import
|
|
|
|
|
|
|
|
|
103 |
from open_webui.utils.misc import (
|
104 |
calculate_sha256,
|
105 |
calculate_sha256_string,
|
@@ -118,7 +129,11 @@ from langchain_core.documents import Document
|
|
118 |
log = logging.getLogger(__name__)
|
119 |
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
120 |
|
121 |
-
app = FastAPI(
|
|
|
|
|
|
|
|
|
122 |
|
123 |
app.state.config = AppConfig()
|
124 |
|
@@ -150,6 +165,9 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
|
|
150 |
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
151 |
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
152 |
|
|
|
|
|
|
|
153 |
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
|
154 |
|
155 |
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
|
@@ -171,6 +189,10 @@ app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
|
|
171 |
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
|
172 |
app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
|
173 |
app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
|
|
|
|
|
|
|
|
|
174 |
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
|
175 |
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
|
176 |
|
@@ -182,11 +204,15 @@ def update_embedding_model(
|
|
182 |
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
|
183 |
from sentence_transformers import SentenceTransformer
|
184 |
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
|
|
|
|
|
|
|
|
190 |
else:
|
191 |
app.state.sentence_transformer_ef = None
|
192 |
|
@@ -240,8 +266,16 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
|
240 |
app.state.config.RAG_EMBEDDING_ENGINE,
|
241 |
app.state.config.RAG_EMBEDDING_MODEL,
|
242 |
app.state.sentence_transformer_ef,
|
243 |
-
|
244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
246 |
)
|
247 |
|
@@ -291,6 +325,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
|
|
291 |
"url": app.state.config.OPENAI_API_BASE_URL,
|
292 |
"key": app.state.config.OPENAI_API_KEY,
|
293 |
},
|
|
|
|
|
|
|
|
|
294 |
}
|
295 |
|
296 |
|
@@ -307,8 +345,14 @@ class OpenAIConfigForm(BaseModel):
|
|
307 |
key: str
|
308 |
|
309 |
|
|
|
|
|
|
|
|
|
|
|
310 |
class EmbeddingModelUpdateForm(BaseModel):
|
311 |
openai_config: Optional[OpenAIConfigForm] = None
|
|
|
312 |
embedding_engine: str
|
313 |
embedding_model: str
|
314 |
embedding_batch_size: Optional[int] = 1
|
@@ -329,6 +373,11 @@ async def update_embedding_config(
|
|
329 |
if form_data.openai_config is not None:
|
330 |
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
|
331 |
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
|
|
|
|
|
|
|
|
|
|
|
332 |
app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
|
333 |
|
334 |
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
|
@@ -337,8 +386,16 @@ async def update_embedding_config(
|
|
337 |
app.state.config.RAG_EMBEDDING_ENGINE,
|
338 |
app.state.config.RAG_EMBEDDING_MODEL,
|
339 |
app.state.sentence_transformer_ef,
|
340 |
-
|
341 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
343 |
)
|
344 |
|
@@ -351,6 +408,10 @@ async def update_embedding_config(
|
|
351 |
"url": app.state.config.OPENAI_API_BASE_URL,
|
352 |
"key": app.state.config.OPENAI_API_KEY,
|
353 |
},
|
|
|
|
|
|
|
|
|
354 |
}
|
355 |
except Exception as e:
|
356 |
log.exception(f"Problem updating embedding model: {e}")
|
@@ -411,7 +472,7 @@ async def get_rag_config(user=Depends(get_admin_user)):
|
|
411 |
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
|
412 |
},
|
413 |
"web": {
|
414 |
-
"
|
415 |
"search": {
|
416 |
"enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
|
417 |
"engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
|
@@ -426,6 +487,9 @@ async def get_rag_config(user=Depends(get_admin_user)):
|
|
426 |
"tavily_api_key": app.state.config.TAVILY_API_KEY,
|
427 |
"searchapi_api_key": app.state.config.SEARCHAPI_API_KEY,
|
428 |
"seaarchapi_engine": app.state.config.SEARCHAPI_ENGINE,
|
|
|
|
|
|
|
429 |
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
430 |
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
431 |
},
|
@@ -468,6 +532,9 @@ class WebSearchConfig(BaseModel):
|
|
468 |
tavily_api_key: Optional[str] = None
|
469 |
searchapi_api_key: Optional[str] = None
|
470 |
searchapi_engine: Optional[str] = None
|
|
|
|
|
|
|
471 |
result_count: Optional[int] = None
|
472 |
concurrent_requests: Optional[int] = None
|
473 |
|
@@ -514,6 +581,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
|
|
514 |
|
515 |
if form_data.web is not None:
|
516 |
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
|
|
517 |
form_data.web.web_loader_ssl_verification
|
518 |
)
|
519 |
|
@@ -534,6 +602,15 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
|
|
534 |
app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
|
535 |
app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key
|
536 |
app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
537 |
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
|
538 |
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
|
539 |
form_data.web.search.concurrent_requests
|
@@ -560,7 +637,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
|
|
560 |
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
|
561 |
},
|
562 |
"web": {
|
563 |
-
"
|
564 |
"search": {
|
565 |
"enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
|
566 |
"engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
|
@@ -575,6 +652,9 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
|
|
575 |
"serachapi_api_key": app.state.config.SEARCHAPI_API_KEY,
|
576 |
"searchapi_engine": app.state.config.SEARCHAPI_ENGINE,
|
577 |
"tavily_api_key": app.state.config.TAVILY_API_KEY,
|
|
|
|
|
|
|
578 |
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
579 |
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
580 |
},
|
@@ -636,6 +716,23 @@ async def update_query_settings(
|
|
636 |
####################################
|
637 |
|
638 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
639 |
def save_docs_to_vector_db(
|
640 |
docs,
|
641 |
collection_name,
|
@@ -644,7 +741,9 @@ def save_docs_to_vector_db(
|
|
644 |
split: bool = True,
|
645 |
add: bool = False,
|
646 |
) -> bool:
|
647 |
-
log.info(
|
|
|
|
|
648 |
|
649 |
# Check if entries with the same hash (metadata.hash) already exist
|
650 |
if metadata and "hash" in metadata:
|
@@ -726,8 +825,16 @@ def save_docs_to_vector_db(
|
|
726 |
app.state.config.RAG_EMBEDDING_ENGINE,
|
727 |
app.state.config.RAG_EMBEDDING_MODEL,
|
728 |
app.state.sentence_transformer_ef,
|
729 |
-
|
730 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
731 |
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
732 |
)
|
733 |
|
@@ -954,7 +1061,7 @@ def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_u
|
|
954 |
|
955 |
loader = YoutubeLoader.from_youtube_url(
|
956 |
form_data.url,
|
957 |
-
add_video_info=
|
958 |
language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
|
959 |
translation=app.state.YOUTUBE_LOADER_TRANSLATION,
|
960 |
)
|
@@ -1132,7 +1239,20 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
|
|
1132 |
else:
|
1133 |
raise Exception("No SEARCHAPI_API_KEY found in environment variables")
|
1134 |
elif engine == "jina":
|
1135 |
-
return search_jina(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1136 |
else:
|
1137 |
raise Exception("No search engine API key found in environment variables")
|
1138 |
|
@@ -1162,8 +1282,12 @@ def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
|
|
1162 |
|
1163 |
urls = [result.link for result in web_results]
|
1164 |
|
1165 |
-
loader = get_web_loader(
|
1166 |
-
|
|
|
|
|
|
|
|
|
1167 |
|
1168 |
save_docs_to_vector_db(docs, collection_name, overwrite=True)
|
1169 |
|
|
|
37 |
from open_webui.apps.retrieval.web.serply import search_serply
|
38 |
from open_webui.apps.retrieval.web.serpstack import search_serpstack
|
39 |
from open_webui.apps.retrieval.web.tavily import search_tavily
|
40 |
+
from open_webui.apps.retrieval.web.bing import search_bing
|
41 |
|
42 |
|
43 |
from open_webui.apps.retrieval.utils import (
|
|
|
75 |
RAG_FILE_MAX_SIZE,
|
76 |
RAG_OPENAI_API_BASE_URL,
|
77 |
RAG_OPENAI_API_KEY,
|
78 |
+
RAG_OLLAMA_BASE_URL,
|
79 |
+
RAG_OLLAMA_API_KEY,
|
80 |
RAG_RELEVANCE_THRESHOLD,
|
81 |
RAG_RERANKING_MODEL,
|
82 |
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
|
|
88 |
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
89 |
RAG_WEB_SEARCH_ENGINE,
|
90 |
RAG_WEB_SEARCH_RESULT_COUNT,
|
91 |
+
JINA_API_KEY,
|
92 |
SEARCHAPI_API_KEY,
|
93 |
SEARCHAPI_ENGINE,
|
94 |
SEARXNG_QUERY_URL,
|
|
|
97 |
SERPSTACK_API_KEY,
|
98 |
SERPSTACK_HTTPS,
|
99 |
TAVILY_API_KEY,
|
100 |
+
BING_SEARCH_V7_ENDPOINT,
|
101 |
+
BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
102 |
TIKA_SERVER_URL,
|
103 |
UPLOAD_DIR,
|
104 |
YOUTUBE_LOADER_LANGUAGE,
|
105 |
+
DEFAULT_LOCALE,
|
106 |
AppConfig,
|
107 |
)
|
108 |
from open_webui.constants import ERROR_MESSAGES
|
109 |
+
from open_webui.env import (
|
110 |
+
SRC_LOG_LEVELS,
|
111 |
+
DEVICE_TYPE,
|
112 |
+
DOCKER,
|
113 |
+
)
|
114 |
from open_webui.utils.misc import (
|
115 |
calculate_sha256,
|
116 |
calculate_sha256_string,
|
|
|
129 |
log = logging.getLogger(__name__)
|
130 |
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
131 |
|
132 |
+
app = FastAPI(
|
133 |
+
docs_url="/docs" if ENV == "dev" else None,
|
134 |
+
openapi_url="/openapi.json" if ENV == "dev" else None,
|
135 |
+
redoc_url=None,
|
136 |
+
)
|
137 |
|
138 |
app.state.config = AppConfig()
|
139 |
|
|
|
165 |
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
166 |
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
167 |
|
168 |
+
app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
|
169 |
+
app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
|
170 |
+
|
171 |
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
|
172 |
|
173 |
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
|
|
|
189 |
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
|
190 |
app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
|
191 |
app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
|
192 |
+
app.state.config.JINA_API_KEY = JINA_API_KEY
|
193 |
+
app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
|
194 |
+
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
|
195 |
+
|
196 |
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
|
197 |
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
|
198 |
|
|
|
204 |
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
|
205 |
from sentence_transformers import SentenceTransformer
|
206 |
|
207 |
+
try:
|
208 |
+
app.state.sentence_transformer_ef = SentenceTransformer(
|
209 |
+
get_model_path(embedding_model, auto_update),
|
210 |
+
device=DEVICE_TYPE,
|
211 |
+
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
212 |
+
)
|
213 |
+
except Exception as e:
|
214 |
+
log.debug(f"Error loading SentenceTransformer: {e}")
|
215 |
+
app.state.sentence_transformer_ef = None
|
216 |
else:
|
217 |
app.state.sentence_transformer_ef = None
|
218 |
|
|
|
266 |
app.state.config.RAG_EMBEDDING_ENGINE,
|
267 |
app.state.config.RAG_EMBEDDING_MODEL,
|
268 |
app.state.sentence_transformer_ef,
|
269 |
+
(
|
270 |
+
app.state.config.OPENAI_API_BASE_URL
|
271 |
+
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
272 |
+
else app.state.config.OLLAMA_BASE_URL
|
273 |
+
),
|
274 |
+
(
|
275 |
+
app.state.config.OPENAI_API_KEY
|
276 |
+
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
277 |
+
else app.state.config.OLLAMA_API_KEY
|
278 |
+
),
|
279 |
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
280 |
)
|
281 |
|
|
|
325 |
"url": app.state.config.OPENAI_API_BASE_URL,
|
326 |
"key": app.state.config.OPENAI_API_KEY,
|
327 |
},
|
328 |
+
"ollama_config": {
|
329 |
+
"url": app.state.config.OLLAMA_BASE_URL,
|
330 |
+
"key": app.state.config.OLLAMA_API_KEY,
|
331 |
+
},
|
332 |
}
|
333 |
|
334 |
|
|
|
345 |
key: str
|
346 |
|
347 |
|
348 |
+
class OllamaConfigForm(BaseModel):
|
349 |
+
url: str
|
350 |
+
key: str
|
351 |
+
|
352 |
+
|
353 |
class EmbeddingModelUpdateForm(BaseModel):
|
354 |
openai_config: Optional[OpenAIConfigForm] = None
|
355 |
+
ollama_config: Optional[OllamaConfigForm] = None
|
356 |
embedding_engine: str
|
357 |
embedding_model: str
|
358 |
embedding_batch_size: Optional[int] = 1
|
|
|
373 |
if form_data.openai_config is not None:
|
374 |
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
|
375 |
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
|
376 |
+
|
377 |
+
if form_data.ollama_config is not None:
|
378 |
+
app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url
|
379 |
+
app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key
|
380 |
+
|
381 |
app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
|
382 |
|
383 |
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
|
|
|
386 |
app.state.config.RAG_EMBEDDING_ENGINE,
|
387 |
app.state.config.RAG_EMBEDDING_MODEL,
|
388 |
app.state.sentence_transformer_ef,
|
389 |
+
(
|
390 |
+
app.state.config.OPENAI_API_BASE_URL
|
391 |
+
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
392 |
+
else app.state.config.OLLAMA_BASE_URL
|
393 |
+
),
|
394 |
+
(
|
395 |
+
app.state.config.OPENAI_API_KEY
|
396 |
+
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
397 |
+
else app.state.config.OLLAMA_API_KEY
|
398 |
+
),
|
399 |
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
400 |
)
|
401 |
|
|
|
408 |
"url": app.state.config.OPENAI_API_BASE_URL,
|
409 |
"key": app.state.config.OPENAI_API_KEY,
|
410 |
},
|
411 |
+
"ollama_config": {
|
412 |
+
"url": app.state.config.OLLAMA_BASE_URL,
|
413 |
+
"key": app.state.config.OLLAMA_API_KEY,
|
414 |
+
},
|
415 |
}
|
416 |
except Exception as e:
|
417 |
log.exception(f"Problem updating embedding model: {e}")
|
|
|
472 |
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
|
473 |
},
|
474 |
"web": {
|
475 |
+
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
476 |
"search": {
|
477 |
"enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
|
478 |
"engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
|
|
|
487 |
"tavily_api_key": app.state.config.TAVILY_API_KEY,
|
488 |
"searchapi_api_key": app.state.config.SEARCHAPI_API_KEY,
|
489 |
"seaarchapi_engine": app.state.config.SEARCHAPI_ENGINE,
|
490 |
+
"jina_api_key": app.state.config.JINA_API_KEY,
|
491 |
+
"bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT,
|
492 |
+
"bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
493 |
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
494 |
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
495 |
},
|
|
|
532 |
tavily_api_key: Optional[str] = None
|
533 |
searchapi_api_key: Optional[str] = None
|
534 |
searchapi_engine: Optional[str] = None
|
535 |
+
jina_api_key: Optional[str] = None
|
536 |
+
bing_search_v7_endpoint: Optional[str] = None
|
537 |
+
bing_search_v7_subscription_key: Optional[str] = None
|
538 |
result_count: Optional[int] = None
|
539 |
concurrent_requests: Optional[int] = None
|
540 |
|
|
|
581 |
|
582 |
if form_data.web is not None:
|
583 |
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
584 |
+
# Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False
|
585 |
form_data.web.web_loader_ssl_verification
|
586 |
)
|
587 |
|
|
|
602 |
app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
|
603 |
app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key
|
604 |
app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine
|
605 |
+
|
606 |
+
app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key
|
607 |
+
app.state.config.BING_SEARCH_V7_ENDPOINT = (
|
608 |
+
form_data.web.search.bing_search_v7_endpoint
|
609 |
+
)
|
610 |
+
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = (
|
611 |
+
form_data.web.search.bing_search_v7_subscription_key
|
612 |
+
)
|
613 |
+
|
614 |
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
|
615 |
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
|
616 |
form_data.web.search.concurrent_requests
|
|
|
637 |
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
|
638 |
},
|
639 |
"web": {
|
640 |
+
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
641 |
"search": {
|
642 |
"enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
|
643 |
"engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
|
|
|
652 |
"serachapi_api_key": app.state.config.SEARCHAPI_API_KEY,
|
653 |
"searchapi_engine": app.state.config.SEARCHAPI_ENGINE,
|
654 |
"tavily_api_key": app.state.config.TAVILY_API_KEY,
|
655 |
+
"jina_api_key": app.state.config.JINA_API_KEY,
|
656 |
+
"bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT,
|
657 |
+
"bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
658 |
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
659 |
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
660 |
},
|
|
|
716 |
####################################
|
717 |
|
718 |
|
719 |
+
def _get_docs_info(docs: list[Document]) -> str:
|
720 |
+
docs_info = set()
|
721 |
+
|
722 |
+
# Trying to select relevant metadata identifying the document.
|
723 |
+
for doc in docs:
|
724 |
+
metadata = getattr(doc, "metadata", {})
|
725 |
+
doc_name = metadata.get("name", "")
|
726 |
+
if not doc_name:
|
727 |
+
doc_name = metadata.get("title", "")
|
728 |
+
if not doc_name:
|
729 |
+
doc_name = metadata.get("source", "")
|
730 |
+
if doc_name:
|
731 |
+
docs_info.add(doc_name)
|
732 |
+
|
733 |
+
return ", ".join(docs_info)
|
734 |
+
|
735 |
+
|
736 |
def save_docs_to_vector_db(
|
737 |
docs,
|
738 |
collection_name,
|
|
|
741 |
split: bool = True,
|
742 |
add: bool = False,
|
743 |
) -> bool:
|
744 |
+
log.info(
|
745 |
+
f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
|
746 |
+
)
|
747 |
|
748 |
# Check if entries with the same hash (metadata.hash) already exist
|
749 |
if metadata and "hash" in metadata:
|
|
|
825 |
app.state.config.RAG_EMBEDDING_ENGINE,
|
826 |
app.state.config.RAG_EMBEDDING_MODEL,
|
827 |
app.state.sentence_transformer_ef,
|
828 |
+
(
|
829 |
+
app.state.config.OPENAI_API_BASE_URL
|
830 |
+
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
831 |
+
else app.state.config.OLLAMA_BASE_URL
|
832 |
+
),
|
833 |
+
(
|
834 |
+
app.state.config.OPENAI_API_KEY
|
835 |
+
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
836 |
+
else app.state.config.OLLAMA_API_KEY
|
837 |
+
),
|
838 |
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
839 |
)
|
840 |
|
|
|
1061 |
|
1062 |
loader = YoutubeLoader.from_youtube_url(
|
1063 |
form_data.url,
|
1064 |
+
add_video_info=False,
|
1065 |
language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
|
1066 |
translation=app.state.YOUTUBE_LOADER_TRANSLATION,
|
1067 |
)
|
|
|
1239 |
else:
|
1240 |
raise Exception("No SEARCHAPI_API_KEY found in environment variables")
|
1241 |
elif engine == "jina":
|
1242 |
+
return search_jina(
|
1243 |
+
app.state.config.JINA_API_KEY,
|
1244 |
+
query,
|
1245 |
+
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
1246 |
+
)
|
1247 |
+
elif engine == "bing":
|
1248 |
+
return search_bing(
|
1249 |
+
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
1250 |
+
app.state.config.BING_SEARCH_V7_ENDPOINT,
|
1251 |
+
str(DEFAULT_LOCALE),
|
1252 |
+
query,
|
1253 |
+
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
1254 |
+
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
1255 |
+
)
|
1256 |
else:
|
1257 |
raise Exception("No search engine API key found in environment variables")
|
1258 |
|
|
|
1282 |
|
1283 |
urls = [result.link for result in web_results]
|
1284 |
|
1285 |
+
loader = get_web_loader(
|
1286 |
+
urls,
|
1287 |
+
verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
1288 |
+
requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
1289 |
+
)
|
1290 |
+
docs = loader.aload()
|
1291 |
|
1292 |
save_docs_to_vector_db(docs, collection_name, overwrite=True)
|
1293 |
|
backend/open_webui/apps/retrieval/utils.py
CHANGED
@@ -3,6 +3,7 @@ import os
|
|
3 |
import uuid
|
4 |
from typing import Optional, Union
|
5 |
|
|
|
6 |
import requests
|
7 |
|
8 |
from huggingface_hub import snapshot_download
|
@@ -10,11 +11,6 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
|
|
10 |
from langchain_community.retrievers import BM25Retriever
|
11 |
from langchain_core.documents import Document
|
12 |
|
13 |
-
|
14 |
-
from open_webui.apps.ollama.main import (
|
15 |
-
GenerateEmbedForm,
|
16 |
-
generate_ollama_batch_embeddings,
|
17 |
-
)
|
18 |
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
|
19 |
from open_webui.utils.misc import get_last_user_message
|
20 |
|
@@ -76,7 +72,7 @@ def query_doc(
|
|
76 |
limit=k,
|
77 |
)
|
78 |
|
79 |
-
log.info(f"query_doc:result {result}")
|
80 |
return result
|
81 |
except Exception as e:
|
82 |
print(e)
|
@@ -127,7 +123,10 @@ def query_doc_with_hybrid_search(
|
|
127 |
"metadatas": [[d.metadata for d in result]],
|
128 |
}
|
129 |
|
130 |
-
log.info(
|
|
|
|
|
|
|
131 |
return result
|
132 |
except Exception as e:
|
133 |
raise e
|
@@ -178,35 +177,34 @@ def merge_and_sort_query_results(
|
|
178 |
|
179 |
def query_collection(
|
180 |
collection_names: list[str],
|
181 |
-
|
182 |
embedding_function,
|
183 |
k: int,
|
184 |
) -> dict:
|
185 |
-
|
186 |
results = []
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
|
204 |
return merge_and_sort_query_results(results, k=k)
|
205 |
|
206 |
|
207 |
def query_collection_with_hybrid_search(
|
208 |
collection_names: list[str],
|
209 |
-
|
210 |
embedding_function,
|
211 |
k: int,
|
212 |
reranking_function,
|
@@ -216,15 +214,16 @@ def query_collection_with_hybrid_search(
|
|
216 |
error = False
|
217 |
for collection_name in collection_names:
|
218 |
try:
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
|
|
228 |
except Exception as e:
|
229 |
log.exception(
|
230 |
"Error when querying the collection with " f"hybrid_search: {e}"
|
@@ -281,8 +280,8 @@ def get_embedding_function(
|
|
281 |
embedding_engine,
|
282 |
embedding_model,
|
283 |
embedding_function,
|
284 |
-
|
285 |
-
|
286 |
embedding_batch_size,
|
287 |
):
|
288 |
if embedding_engine == "":
|
@@ -292,8 +291,8 @@ def get_embedding_function(
|
|
292 |
engine=embedding_engine,
|
293 |
model=embedding_model,
|
294 |
text=query,
|
295 |
-
|
296 |
-
|
297 |
)
|
298 |
|
299 |
def generate_multiple(query, func):
|
@@ -310,15 +309,14 @@ def get_embedding_function(
|
|
310 |
|
311 |
def get_rag_context(
|
312 |
files,
|
313 |
-
|
314 |
embedding_function,
|
315 |
k,
|
316 |
reranking_function,
|
317 |
r,
|
318 |
hybrid_search,
|
319 |
):
|
320 |
-
log.debug(f"files: {files} {
|
321 |
-
query = get_last_user_message(messages)
|
322 |
|
323 |
extracted_collections = []
|
324 |
relevant_contexts = []
|
@@ -360,7 +358,7 @@ def get_rag_context(
|
|
360 |
try:
|
361 |
context = query_collection_with_hybrid_search(
|
362 |
collection_names=collection_names,
|
363 |
-
|
364 |
embedding_function=embedding_function,
|
365 |
k=k,
|
366 |
reranking_function=reranking_function,
|
@@ -375,7 +373,7 @@ def get_rag_context(
|
|
375 |
if (not hybrid_search) or (context is None):
|
376 |
context = query_collection(
|
377 |
collection_names=collection_names,
|
378 |
-
|
379 |
embedding_function=embedding_function,
|
380 |
k=k,
|
381 |
)
|
@@ -467,7 +465,7 @@ def get_model_path(model: str, update_model: bool = False):
|
|
467 |
|
468 |
|
469 |
def generate_openai_batch_embeddings(
|
470 |
-
model: str, texts: list[str],
|
471 |
) -> Optional[list[list[float]]]:
|
472 |
try:
|
473 |
r = requests.post(
|
@@ -489,29 +487,50 @@ def generate_openai_batch_embeddings(
|
|
489 |
return None
|
490 |
|
491 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
492 |
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
|
|
|
|
|
|
|
493 |
if engine == "ollama":
|
494 |
if isinstance(text, list):
|
495 |
embeddings = generate_ollama_batch_embeddings(
|
496 |
-
|
497 |
)
|
498 |
else:
|
499 |
embeddings = generate_ollama_batch_embeddings(
|
500 |
-
|
501 |
)
|
502 |
-
return (
|
503 |
-
embeddings["embeddings"][0]
|
504 |
-
if isinstance(text, str)
|
505 |
-
else embeddings["embeddings"]
|
506 |
-
)
|
507 |
elif engine == "openai":
|
508 |
-
key = kwargs.get("key", "")
|
509 |
-
url = kwargs.get("url", "https://api.openai.com/v1")
|
510 |
-
|
511 |
if isinstance(text, list):
|
512 |
-
embeddings = generate_openai_batch_embeddings(model, text,
|
513 |
else:
|
514 |
-
embeddings = generate_openai_batch_embeddings(model, [text],
|
515 |
|
516 |
return embeddings[0] if isinstance(text, str) else embeddings
|
517 |
|
|
|
3 |
import uuid
|
4 |
from typing import Optional, Union
|
5 |
|
6 |
+
import asyncio
|
7 |
import requests
|
8 |
|
9 |
from huggingface_hub import snapshot_download
|
|
|
11 |
from langchain_community.retrievers import BM25Retriever
|
12 |
from langchain_core.documents import Document
|
13 |
|
|
|
|
|
|
|
|
|
|
|
14 |
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
|
15 |
from open_webui.utils.misc import get_last_user_message
|
16 |
|
|
|
72 |
limit=k,
|
73 |
)
|
74 |
|
75 |
+
log.info(f"query_doc:result {result.ids} {result.metadatas}")
|
76 |
return result
|
77 |
except Exception as e:
|
78 |
print(e)
|
|
|
123 |
"metadatas": [[d.metadata for d in result]],
|
124 |
}
|
125 |
|
126 |
+
log.info(
|
127 |
+
"query_doc_with_hybrid_search:result "
|
128 |
+
+ f'{result["metadatas"]} {result["distances"]}'
|
129 |
+
)
|
130 |
return result
|
131 |
except Exception as e:
|
132 |
raise e
|
|
|
177 |
|
178 |
def query_collection(
|
179 |
collection_names: list[str],
|
180 |
+
queries: list[str],
|
181 |
embedding_function,
|
182 |
k: int,
|
183 |
) -> dict:
|
|
|
184 |
results = []
|
185 |
+
for query in queries:
|
186 |
+
query_embedding = embedding_function(query)
|
187 |
+
for collection_name in collection_names:
|
188 |
+
if collection_name:
|
189 |
+
try:
|
190 |
+
result = query_doc(
|
191 |
+
collection_name=collection_name,
|
192 |
+
k=k,
|
193 |
+
query_embedding=query_embedding,
|
194 |
+
)
|
195 |
+
if result is not None:
|
196 |
+
results.append(result.model_dump())
|
197 |
+
except Exception as e:
|
198 |
+
log.exception(f"Error when querying the collection: {e}")
|
199 |
+
else:
|
200 |
+
pass
|
201 |
|
202 |
return merge_and_sort_query_results(results, k=k)
|
203 |
|
204 |
|
205 |
def query_collection_with_hybrid_search(
|
206 |
collection_names: list[str],
|
207 |
+
queries: list[str],
|
208 |
embedding_function,
|
209 |
k: int,
|
210 |
reranking_function,
|
|
|
214 |
error = False
|
215 |
for collection_name in collection_names:
|
216 |
try:
|
217 |
+
for query in queries:
|
218 |
+
result = query_doc_with_hybrid_search(
|
219 |
+
collection_name=collection_name,
|
220 |
+
query=query,
|
221 |
+
embedding_function=embedding_function,
|
222 |
+
k=k,
|
223 |
+
reranking_function=reranking_function,
|
224 |
+
r=r,
|
225 |
+
)
|
226 |
+
results.append(result)
|
227 |
except Exception as e:
|
228 |
log.exception(
|
229 |
"Error when querying the collection with " f"hybrid_search: {e}"
|
|
|
280 |
embedding_engine,
|
281 |
embedding_model,
|
282 |
embedding_function,
|
283 |
+
url,
|
284 |
+
key,
|
285 |
embedding_batch_size,
|
286 |
):
|
287 |
if embedding_engine == "":
|
|
|
291 |
engine=embedding_engine,
|
292 |
model=embedding_model,
|
293 |
text=query,
|
294 |
+
url=url,
|
295 |
+
key=key,
|
296 |
)
|
297 |
|
298 |
def generate_multiple(query, func):
|
|
|
309 |
|
310 |
def get_rag_context(
|
311 |
files,
|
312 |
+
queries,
|
313 |
embedding_function,
|
314 |
k,
|
315 |
reranking_function,
|
316 |
r,
|
317 |
hybrid_search,
|
318 |
):
|
319 |
+
log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
|
|
|
320 |
|
321 |
extracted_collections = []
|
322 |
relevant_contexts = []
|
|
|
358 |
try:
|
359 |
context = query_collection_with_hybrid_search(
|
360 |
collection_names=collection_names,
|
361 |
+
queries=queries,
|
362 |
embedding_function=embedding_function,
|
363 |
k=k,
|
364 |
reranking_function=reranking_function,
|
|
|
373 |
if (not hybrid_search) or (context is None):
|
374 |
context = query_collection(
|
375 |
collection_names=collection_names,
|
376 |
+
queries=queries,
|
377 |
embedding_function=embedding_function,
|
378 |
k=k,
|
379 |
)
|
|
|
465 |
|
466 |
|
467 |
def generate_openai_batch_embeddings(
|
468 |
+
model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
|
469 |
) -> Optional[list[list[float]]]:
|
470 |
try:
|
471 |
r = requests.post(
|
|
|
487 |
return None
|
488 |
|
489 |
|
490 |
+
def generate_ollama_batch_embeddings(
|
491 |
+
model: str, texts: list[str], url: str, key: str
|
492 |
+
) -> Optional[list[list[float]]]:
|
493 |
+
try:
|
494 |
+
r = requests.post(
|
495 |
+
f"{url}/api/embed",
|
496 |
+
headers={
|
497 |
+
"Content-Type": "application/json",
|
498 |
+
"Authorization": f"Bearer {key}",
|
499 |
+
},
|
500 |
+
json={"input": texts, "model": model},
|
501 |
+
)
|
502 |
+
r.raise_for_status()
|
503 |
+
data = r.json()
|
504 |
+
|
505 |
+
print(data)
|
506 |
+
if "embeddings" in data:
|
507 |
+
return data["embeddings"]
|
508 |
+
else:
|
509 |
+
raise "Something went wrong :/"
|
510 |
+
except Exception as e:
|
511 |
+
print(e)
|
512 |
+
return None
|
513 |
+
|
514 |
+
|
515 |
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
|
516 |
+
url = kwargs.get("url", "")
|
517 |
+
key = kwargs.get("key", "")
|
518 |
+
|
519 |
if engine == "ollama":
|
520 |
if isinstance(text, list):
|
521 |
embeddings = generate_ollama_batch_embeddings(
|
522 |
+
**{"model": model, "texts": text, "url": url, "key": key}
|
523 |
)
|
524 |
else:
|
525 |
embeddings = generate_ollama_batch_embeddings(
|
526 |
+
**{"model": model, "texts": [text], "url": url, "key": key}
|
527 |
)
|
528 |
+
return embeddings[0] if isinstance(text, str) else embeddings
|
|
|
|
|
|
|
|
|
529 |
elif engine == "openai":
|
|
|
|
|
|
|
530 |
if isinstance(text, list):
|
531 |
+
embeddings = generate_openai_batch_embeddings(model, text, url, key)
|
532 |
else:
|
533 |
+
embeddings = generate_openai_batch_embeddings(model, [text], url, key)
|
534 |
|
535 |
return embeddings[0] if isinstance(text, str) else embeddings
|
536 |
|
backend/open_webui/apps/retrieval/vector/connector.py
CHANGED
@@ -8,6 +8,14 @@ elif VECTOR_DB == "qdrant":
|
|
8 |
from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
|
9 |
|
10 |
VECTOR_DB_CLIENT = QdrantClient()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
else:
|
12 |
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
|
13 |
|
|
|
8 |
from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
|
9 |
|
10 |
VECTOR_DB_CLIENT = QdrantClient()
|
11 |
+
elif VECTOR_DB == "opensearch":
|
12 |
+
from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient
|
13 |
+
|
14 |
+
VECTOR_DB_CLIENT = OpenSearchClient()
|
15 |
+
elif VECTOR_DB == "pgvector":
|
16 |
+
from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient
|
17 |
+
|
18 |
+
VECTOR_DB_CLIENT = PgvectorClient()
|
19 |
else:
|
20 |
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
|
21 |
|
backend/open_webui/apps/retrieval/vector/dbs/chroma.py
CHANGED
@@ -27,7 +27,9 @@ class ChromaClient:
|
|
27 |
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
|
28 |
settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
|
29 |
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
|
30 |
-
settings_dict["chroma_client_auth_credentials"] =
|
|
|
|
|
31 |
|
32 |
if CHROMA_HTTP_HOST != "":
|
33 |
self.client = chromadb.HttpClient(
|
|
|
27 |
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
|
28 |
settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
|
29 |
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
|
30 |
+
settings_dict["chroma_client_auth_credentials"] = (
|
31 |
+
CHROMA_CLIENT_AUTH_CREDENTIALS
|
32 |
+
)
|
33 |
|
34 |
if CHROMA_HTTP_HOST != "":
|
35 |
self.client = chromadb.HttpClient(
|
backend/open_webui/apps/retrieval/vector/dbs/opensearch.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from opensearchpy import OpenSearch
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
|
5 |
+
from open_webui.config import (
|
6 |
+
OPENSEARCH_URI,
|
7 |
+
OPENSEARCH_SSL,
|
8 |
+
OPENSEARCH_CERT_VERIFY,
|
9 |
+
OPENSEARCH_USERNAME,
|
10 |
+
OPENSEARCH_PASSWORD,
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
class OpenSearchClient:
|
15 |
+
def __init__(self):
|
16 |
+
self.index_prefix = "open_webui"
|
17 |
+
self.client = OpenSearch(
|
18 |
+
hosts=[OPENSEARCH_URI],
|
19 |
+
use_ssl=OPENSEARCH_SSL,
|
20 |
+
verify_certs=OPENSEARCH_CERT_VERIFY,
|
21 |
+
http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
|
22 |
+
)
|
23 |
+
|
24 |
+
def _result_to_get_result(self, result) -> GetResult:
|
25 |
+
ids = []
|
26 |
+
documents = []
|
27 |
+
metadatas = []
|
28 |
+
|
29 |
+
for hit in result["hits"]["hits"]:
|
30 |
+
ids.append(hit["_id"])
|
31 |
+
documents.append(hit["_source"].get("text"))
|
32 |
+
metadatas.append(hit["_source"].get("metadata"))
|
33 |
+
|
34 |
+
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
35 |
+
|
36 |
+
def _result_to_search_result(self, result) -> SearchResult:
|
37 |
+
ids = []
|
38 |
+
distances = []
|
39 |
+
documents = []
|
40 |
+
metadatas = []
|
41 |
+
|
42 |
+
for hit in result["hits"]["hits"]:
|
43 |
+
ids.append(hit["_id"])
|
44 |
+
distances.append(hit["_score"])
|
45 |
+
documents.append(hit["_source"].get("text"))
|
46 |
+
metadatas.append(hit["_source"].get("metadata"))
|
47 |
+
|
48 |
+
return SearchResult(
|
49 |
+
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
50 |
+
)
|
51 |
+
|
52 |
+
def _create_index(self, index_name: str, dimension: int):
|
53 |
+
body = {
|
54 |
+
"mappings": {
|
55 |
+
"properties": {
|
56 |
+
"id": {"type": "keyword"},
|
57 |
+
"vector": {
|
58 |
+
"type": "dense_vector",
|
59 |
+
"dims": dimension, # Adjust based on your vector dimensions
|
60 |
+
"index": true,
|
61 |
+
"similarity": "faiss",
|
62 |
+
"method": {
|
63 |
+
"name": "hnsw",
|
64 |
+
"space_type": "ip", # Use inner product to approximate cosine similarity
|
65 |
+
"engine": "faiss",
|
66 |
+
"ef_construction": 128,
|
67 |
+
"m": 16,
|
68 |
+
},
|
69 |
+
},
|
70 |
+
"text": {"type": "text"},
|
71 |
+
"metadata": {"type": "object"},
|
72 |
+
}
|
73 |
+
}
|
74 |
+
}
|
75 |
+
self.client.indices.create(index=f"{self.index_prefix}_{index_name}", body=body)
|
76 |
+
|
77 |
+
def _create_batches(self, items: list[VectorItem], batch_size=100):
|
78 |
+
for i in range(0, len(items), batch_size):
|
79 |
+
yield items[i : i + batch_size]
|
80 |
+
|
81 |
+
def has_collection(self, index_name: str) -> bool:
|
82 |
+
# has_collection here means has index.
|
83 |
+
# We are simply adapting to the norms of the other DBs.
|
84 |
+
return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}")
|
85 |
+
|
86 |
+
def delete_colleciton(self, index_name: str):
|
87 |
+
# delete_collection here means delete index.
|
88 |
+
# We are simply adapting to the norms of the other DBs.
|
89 |
+
self.client.indices.delete(index=f"{self.index_prefix}_{index_name}")
|
90 |
+
|
91 |
+
def search(
|
92 |
+
self, index_name: str, vectors: list[list[float]], limit: int
|
93 |
+
) -> Optional[SearchResult]:
|
94 |
+
query = {
|
95 |
+
"size": limit,
|
96 |
+
"_source": ["text", "metadata"],
|
97 |
+
"query": {
|
98 |
+
"script_score": {
|
99 |
+
"query": {"match_all": {}},
|
100 |
+
"script": {
|
101 |
+
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
|
102 |
+
"params": {
|
103 |
+
"vector": vectors[0]
|
104 |
+
}, # Assuming single query vector
|
105 |
+
},
|
106 |
+
}
|
107 |
+
},
|
108 |
+
}
|
109 |
+
|
110 |
+
result = self.client.search(
|
111 |
+
index=f"{self.index_prefix}_{index_name}", body=query
|
112 |
+
)
|
113 |
+
|
114 |
+
return self._result_to_search_result(result)
|
115 |
+
|
116 |
+
def get_or_create_index(self, index_name: str, dimension: int):
|
117 |
+
if not self.has_index(index_name):
|
118 |
+
self._create_index(index_name, dimension)
|
119 |
+
|
120 |
+
def get(self, index_name: str) -> Optional[GetResult]:
|
121 |
+
query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
|
122 |
+
|
123 |
+
result = self.client.search(
|
124 |
+
index=f"{self.index_prefix}_{index_name}", body=query
|
125 |
+
)
|
126 |
+
return self._result_to_get_result(result)
|
127 |
+
|
128 |
+
def insert(self, index_name: str, items: list[VectorItem]):
|
129 |
+
if not self.has_index(index_name):
|
130 |
+
self._create_index(index_name, dimension=len(items[0]["vector"]))
|
131 |
+
|
132 |
+
for batch in self._create_batches(items):
|
133 |
+
actions = [
|
134 |
+
{
|
135 |
+
"index": {
|
136 |
+
"_id": item["id"],
|
137 |
+
"_source": {
|
138 |
+
"vector": item["vector"],
|
139 |
+
"text": item["text"],
|
140 |
+
"metadata": item["metadata"],
|
141 |
+
},
|
142 |
+
}
|
143 |
+
}
|
144 |
+
for item in batch
|
145 |
+
]
|
146 |
+
self.client.bulk(actions)
|
147 |
+
|
148 |
+
def upsert(self, index_name: str, items: list[VectorItem]):
|
149 |
+
if not self.has_index(index_name):
|
150 |
+
self._create_index(index_name, dimension=len(items[0]["vector"]))
|
151 |
+
|
152 |
+
for batch in self._create_batches(items):
|
153 |
+
actions = [
|
154 |
+
{
|
155 |
+
"index": {
|
156 |
+
"_id": item["id"],
|
157 |
+
"_source": {
|
158 |
+
"vector": item["vector"],
|
159 |
+
"text": item["text"],
|
160 |
+
"metadata": item["metadata"],
|
161 |
+
},
|
162 |
+
}
|
163 |
+
}
|
164 |
+
for item in batch
|
165 |
+
]
|
166 |
+
self.client.bulk(actions)
|
167 |
+
|
168 |
+
def delete(self, index_name: str, ids: list[str]):
|
169 |
+
actions = [
|
170 |
+
{"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}}
|
171 |
+
for id in ids
|
172 |
+
]
|
173 |
+
self.client.bulk(body=actions)
|
174 |
+
|
175 |
+
def reset(self):
|
176 |
+
indices = self.client.indices.get(index=f"{self.index_prefix}_*")
|
177 |
+
for index in indices:
|
178 |
+
self.client.indices.delete(index=index)
|
backend/open_webui/apps/retrieval/vector/dbs/pgvector.py
ADDED
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, List, Dict, Any
|
2 |
+
from sqlalchemy import (
|
3 |
+
cast,
|
4 |
+
column,
|
5 |
+
create_engine,
|
6 |
+
Column,
|
7 |
+
Integer,
|
8 |
+
select,
|
9 |
+
text,
|
10 |
+
Text,
|
11 |
+
values,
|
12 |
+
)
|
13 |
+
from sqlalchemy.sql import true
|
14 |
+
from sqlalchemy.pool import NullPool
|
15 |
+
|
16 |
+
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
|
17 |
+
from sqlalchemy.dialects.postgresql import JSONB, array
|
18 |
+
from pgvector.sqlalchemy import Vector
|
19 |
+
from sqlalchemy.ext.mutable import MutableDict
|
20 |
+
|
21 |
+
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
22 |
+
from open_webui.config import PGVECTOR_DB_URL
|
23 |
+
|
24 |
+
VECTOR_LENGTH = 1536
|
25 |
+
Base = declarative_base()
|
26 |
+
|
27 |
+
|
28 |
+
class DocumentChunk(Base):
|
29 |
+
__tablename__ = "document_chunk"
|
30 |
+
|
31 |
+
id = Column(Text, primary_key=True)
|
32 |
+
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
|
33 |
+
collection_name = Column(Text, nullable=False)
|
34 |
+
text = Column(Text, nullable=True)
|
35 |
+
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
|
36 |
+
|
37 |
+
|
38 |
+
class PgvectorClient:
|
39 |
+
def __init__(self) -> None:
|
40 |
+
|
41 |
+
# if no pgvector uri, use the existing database connection
|
42 |
+
if not PGVECTOR_DB_URL:
|
43 |
+
from open_webui.apps.webui.internal.db import Session
|
44 |
+
|
45 |
+
self.session = Session
|
46 |
+
else:
|
47 |
+
engine = create_engine(
|
48 |
+
PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
|
49 |
+
)
|
50 |
+
SessionLocal = sessionmaker(
|
51 |
+
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
|
52 |
+
)
|
53 |
+
self.session = scoped_session(SessionLocal)
|
54 |
+
|
55 |
+
try:
|
56 |
+
# Ensure the pgvector extension is available
|
57 |
+
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
58 |
+
|
59 |
+
# Create the tables if they do not exist
|
60 |
+
# Base.metadata.create_all requires a bind (engine or connection)
|
61 |
+
# Get the connection from the session
|
62 |
+
connection = self.session.connection()
|
63 |
+
Base.metadata.create_all(bind=connection)
|
64 |
+
|
65 |
+
# Create an index on the vector column if it doesn't exist
|
66 |
+
self.session.execute(
|
67 |
+
text(
|
68 |
+
"CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
|
69 |
+
"ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
|
70 |
+
)
|
71 |
+
)
|
72 |
+
self.session.execute(
|
73 |
+
text(
|
74 |
+
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
|
75 |
+
"ON document_chunk (collection_name);"
|
76 |
+
)
|
77 |
+
)
|
78 |
+
self.session.commit()
|
79 |
+
print("Initialization complete.")
|
80 |
+
except Exception as e:
|
81 |
+
self.session.rollback()
|
82 |
+
print(f"Error during initialization: {e}")
|
83 |
+
raise
|
84 |
+
|
85 |
+
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
86 |
+
# Adjust vector to have length VECTOR_LENGTH
|
87 |
+
current_length = len(vector)
|
88 |
+
if current_length < VECTOR_LENGTH:
|
89 |
+
# Pad the vector with zeros
|
90 |
+
vector += [0.0] * (VECTOR_LENGTH - current_length)
|
91 |
+
elif current_length > VECTOR_LENGTH:
|
92 |
+
raise Exception(
|
93 |
+
f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
|
94 |
+
)
|
95 |
+
return vector
|
96 |
+
|
97 |
+
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
98 |
+
try:
|
99 |
+
new_items = []
|
100 |
+
for item in items:
|
101 |
+
vector = self.adjust_vector_length(item["vector"])
|
102 |
+
new_chunk = DocumentChunk(
|
103 |
+
id=item["id"],
|
104 |
+
vector=vector,
|
105 |
+
collection_name=collection_name,
|
106 |
+
text=item["text"],
|
107 |
+
vmetadata=item["metadata"],
|
108 |
+
)
|
109 |
+
new_items.append(new_chunk)
|
110 |
+
self.session.bulk_save_objects(new_items)
|
111 |
+
self.session.commit()
|
112 |
+
print(
|
113 |
+
f"Inserted {len(new_items)} items into collection '{collection_name}'."
|
114 |
+
)
|
115 |
+
except Exception as e:
|
116 |
+
self.session.rollback()
|
117 |
+
print(f"Error during insert: {e}")
|
118 |
+
raise
|
119 |
+
|
120 |
+
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
121 |
+
try:
|
122 |
+
for item in items:
|
123 |
+
vector = self.adjust_vector_length(item["vector"])
|
124 |
+
existing = (
|
125 |
+
self.session.query(DocumentChunk)
|
126 |
+
.filter(DocumentChunk.id == item["id"])
|
127 |
+
.first()
|
128 |
+
)
|
129 |
+
if existing:
|
130 |
+
existing.vector = vector
|
131 |
+
existing.text = item["text"]
|
132 |
+
existing.vmetadata = item["metadata"]
|
133 |
+
existing.collection_name = (
|
134 |
+
collection_name # Update collection_name if necessary
|
135 |
+
)
|
136 |
+
else:
|
137 |
+
new_chunk = DocumentChunk(
|
138 |
+
id=item["id"],
|
139 |
+
vector=vector,
|
140 |
+
collection_name=collection_name,
|
141 |
+
text=item["text"],
|
142 |
+
vmetadata=item["metadata"],
|
143 |
+
)
|
144 |
+
self.session.add(new_chunk)
|
145 |
+
self.session.commit()
|
146 |
+
print(f"Upserted {len(items)} items into collection '{collection_name}'.")
|
147 |
+
except Exception as e:
|
148 |
+
self.session.rollback()
|
149 |
+
print(f"Error during upsert: {e}")
|
150 |
+
raise
|
151 |
+
|
152 |
+
def search(
|
153 |
+
self,
|
154 |
+
collection_name: str,
|
155 |
+
vectors: List[List[float]],
|
156 |
+
limit: Optional[int] = None,
|
157 |
+
) -> Optional[SearchResult]:
|
158 |
+
try:
|
159 |
+
if not vectors:
|
160 |
+
return None
|
161 |
+
|
162 |
+
# Adjust query vectors to VECTOR_LENGTH
|
163 |
+
vectors = [self.adjust_vector_length(vector) for vector in vectors]
|
164 |
+
num_queries = len(vectors)
|
165 |
+
|
166 |
+
def vector_expr(vector):
|
167 |
+
return cast(array(vector), Vector(VECTOR_LENGTH))
|
168 |
+
|
169 |
+
# Create the values for query vectors
|
170 |
+
qid_col = column("qid", Integer)
|
171 |
+
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
|
172 |
+
query_vectors = (
|
173 |
+
values(qid_col, q_vector_col)
|
174 |
+
.data(
|
175 |
+
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
|
176 |
+
)
|
177 |
+
.alias("query_vectors")
|
178 |
+
)
|
179 |
+
|
180 |
+
# Build the lateral subquery for each query vector
|
181 |
+
subq = (
|
182 |
+
select(
|
183 |
+
DocumentChunk.id,
|
184 |
+
DocumentChunk.text,
|
185 |
+
DocumentChunk.vmetadata,
|
186 |
+
(
|
187 |
+
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
|
188 |
+
).label("distance"),
|
189 |
+
)
|
190 |
+
.where(DocumentChunk.collection_name == collection_name)
|
191 |
+
.order_by(
|
192 |
+
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
|
193 |
+
)
|
194 |
+
)
|
195 |
+
if limit is not None:
|
196 |
+
subq = subq.limit(limit)
|
197 |
+
subq = subq.lateral("result")
|
198 |
+
|
199 |
+
# Build the main query by joining query_vectors and the lateral subquery
|
200 |
+
stmt = (
|
201 |
+
select(
|
202 |
+
query_vectors.c.qid,
|
203 |
+
subq.c.id,
|
204 |
+
subq.c.text,
|
205 |
+
subq.c.vmetadata,
|
206 |
+
subq.c.distance,
|
207 |
+
)
|
208 |
+
.select_from(query_vectors)
|
209 |
+
.join(subq, true())
|
210 |
+
.order_by(query_vectors.c.qid, subq.c.distance)
|
211 |
+
)
|
212 |
+
|
213 |
+
result_proxy = self.session.execute(stmt)
|
214 |
+
results = result_proxy.all()
|
215 |
+
|
216 |
+
ids = [[] for _ in range(num_queries)]
|
217 |
+
distances = [[] for _ in range(num_queries)]
|
218 |
+
documents = [[] for _ in range(num_queries)]
|
219 |
+
metadatas = [[] for _ in range(num_queries)]
|
220 |
+
|
221 |
+
if not results:
|
222 |
+
return SearchResult(
|
223 |
+
ids=ids,
|
224 |
+
distances=distances,
|
225 |
+
documents=documents,
|
226 |
+
metadatas=metadatas,
|
227 |
+
)
|
228 |
+
|
229 |
+
for row in results:
|
230 |
+
qid = int(row.qid)
|
231 |
+
ids[qid].append(row.id)
|
232 |
+
distances[qid].append(row.distance)
|
233 |
+
documents[qid].append(row.text)
|
234 |
+
metadatas[qid].append(row.vmetadata)
|
235 |
+
|
236 |
+
return SearchResult(
|
237 |
+
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
238 |
+
)
|
239 |
+
except Exception as e:
|
240 |
+
print(f"Error during search: {e}")
|
241 |
+
return None
|
242 |
+
|
243 |
+
def query(
|
244 |
+
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
245 |
+
) -> Optional[GetResult]:
|
246 |
+
try:
|
247 |
+
query = self.session.query(DocumentChunk).filter(
|
248 |
+
DocumentChunk.collection_name == collection_name
|
249 |
+
)
|
250 |
+
|
251 |
+
for key, value in filter.items():
|
252 |
+
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
|
253 |
+
|
254 |
+
if limit is not None:
|
255 |
+
query = query.limit(limit)
|
256 |
+
|
257 |
+
results = query.all()
|
258 |
+
|
259 |
+
if not results:
|
260 |
+
return None
|
261 |
+
|
262 |
+
ids = [[result.id for result in results]]
|
263 |
+
documents = [[result.text for result in results]]
|
264 |
+
metadatas = [[result.vmetadata for result in results]]
|
265 |
+
|
266 |
+
return GetResult(
|
267 |
+
ids=ids,
|
268 |
+
documents=documents,
|
269 |
+
metadatas=metadatas,
|
270 |
+
)
|
271 |
+
except Exception as e:
|
272 |
+
print(f"Error during query: {e}")
|
273 |
+
return None
|
274 |
+
|
275 |
+
def get(
|
276 |
+
self, collection_name: str, limit: Optional[int] = None
|
277 |
+
) -> Optional[GetResult]:
|
278 |
+
try:
|
279 |
+
query = self.session.query(DocumentChunk).filter(
|
280 |
+
DocumentChunk.collection_name == collection_name
|
281 |
+
)
|
282 |
+
if limit is not None:
|
283 |
+
query = query.limit(limit)
|
284 |
+
|
285 |
+
results = query.all()
|
286 |
+
|
287 |
+
if not results:
|
288 |
+
return None
|
289 |
+
|
290 |
+
ids = [[result.id for result in results]]
|
291 |
+
documents = [[result.text for result in results]]
|
292 |
+
metadatas = [[result.vmetadata for result in results]]
|
293 |
+
|
294 |
+
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
295 |
+
except Exception as e:
|
296 |
+
print(f"Error during get: {e}")
|
297 |
+
return None
|
298 |
+
|
299 |
+
def delete(
|
300 |
+
self,
|
301 |
+
collection_name: str,
|
302 |
+
ids: Optional[List[str]] = None,
|
303 |
+
filter: Optional[Dict[str, Any]] = None,
|
304 |
+
) -> None:
|
305 |
+
try:
|
306 |
+
query = self.session.query(DocumentChunk).filter(
|
307 |
+
DocumentChunk.collection_name == collection_name
|
308 |
+
)
|
309 |
+
if ids:
|
310 |
+
query = query.filter(DocumentChunk.id.in_(ids))
|
311 |
+
if filter:
|
312 |
+
for key, value in filter.items():
|
313 |
+
query = query.filter(
|
314 |
+
DocumentChunk.vmetadata[key].astext == str(value)
|
315 |
+
)
|
316 |
+
deleted = query.delete(synchronize_session=False)
|
317 |
+
self.session.commit()
|
318 |
+
print(f"Deleted {deleted} items from collection '{collection_name}'.")
|
319 |
+
except Exception as e:
|
320 |
+
self.session.rollback()
|
321 |
+
print(f"Error during delete: {e}")
|
322 |
+
raise
|
323 |
+
|
324 |
+
def reset(self) -> None:
|
325 |
+
try:
|
326 |
+
deleted = self.session.query(DocumentChunk).delete()
|
327 |
+
self.session.commit()
|
328 |
+
print(
|
329 |
+
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
|
330 |
+
)
|
331 |
+
except Exception as e:
|
332 |
+
self.session.rollback()
|
333 |
+
print(f"Error during reset: {e}")
|
334 |
+
raise
|
335 |
+
|
336 |
+
def close(self) -> None:
|
337 |
+
pass
|
338 |
+
|
339 |
+
def has_collection(self, collection_name: str) -> bool:
|
340 |
+
try:
|
341 |
+
exists = (
|
342 |
+
self.session.query(DocumentChunk)
|
343 |
+
.filter(DocumentChunk.collection_name == collection_name)
|
344 |
+
.first()
|
345 |
+
is not None
|
346 |
+
)
|
347 |
+
return exists
|
348 |
+
except Exception as e:
|
349 |
+
print(f"Error checking collection existence: {e}")
|
350 |
+
return False
|
351 |
+
|
352 |
+
def delete_collection(self, collection_name: str) -> None:
|
353 |
+
self.delete(collection_name)
|
354 |
+
print(f"Collection '{collection_name}' deleted.")
|
backend/open_webui/apps/retrieval/vector/dbs/qdrant.py
CHANGED
@@ -5,7 +5,7 @@ from qdrant_client.http.models import PointStruct
|
|
5 |
from qdrant_client.models import models
|
6 |
|
7 |
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
8 |
-
from open_webui.config import QDRANT_URI
|
9 |
|
10 |
NO_LIMIT = 999999999
|
11 |
|
@@ -14,7 +14,12 @@ class QdrantClient:
|
|
14 |
def __init__(self):
|
15 |
self.collection_prefix = "open-webui"
|
16 |
self.QDRANT_URI = QDRANT_URI
|
17 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
def _result_to_get_result(self, points) -> GetResult:
|
20 |
ids = []
|
|
|
5 |
from qdrant_client.models import models
|
6 |
|
7 |
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
8 |
+
from open_webui.config import QDRANT_URI, QDRANT_API_KEY
|
9 |
|
10 |
NO_LIMIT = 999999999
|
11 |
|
|
|
14 |
def __init__(self):
|
15 |
self.collection_prefix = "open-webui"
|
16 |
self.QDRANT_URI = QDRANT_URI
|
17 |
+
self.QDRANT_API_KEY = QDRANT_API_KEY
|
18 |
+
self.client = (
|
19 |
+
Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
|
20 |
+
if self.QDRANT_URI
|
21 |
+
else None
|
22 |
+
)
|
23 |
|
24 |
def _result_to_get_result(self, points) -> GetResult:
|
25 |
ids = []
|
backend/open_webui/apps/retrieval/web/bing.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from pprint import pprint
|
4 |
+
from typing import Optional
|
5 |
+
import requests
|
6 |
+
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
|
7 |
+
from open_webui.env import SRC_LOG_LEVELS
|
8 |
+
import argparse
|
9 |
+
|
10 |
+
log = logging.getLogger(__name__)
|
11 |
+
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
12 |
+
"""
|
13 |
+
Documentation: https://docs.microsoft.com/en-us/bing/search-apis/bing-web-search/overview
|
14 |
+
"""
|
15 |
+
|
16 |
+
|
17 |
+
def search_bing(
|
18 |
+
subscription_key: str,
|
19 |
+
endpoint: str,
|
20 |
+
locale: str,
|
21 |
+
query: str,
|
22 |
+
count: int,
|
23 |
+
filter_list: Optional[list[str]] = None,
|
24 |
+
) -> list[SearchResult]:
|
25 |
+
mkt = locale
|
26 |
+
params = {"q": query, "mkt": mkt, "answerCount": count}
|
27 |
+
headers = {"Ocp-Apim-Subscription-Key": subscription_key}
|
28 |
+
|
29 |
+
try:
|
30 |
+
response = requests.get(endpoint, headers=headers, params=params)
|
31 |
+
response.raise_for_status()
|
32 |
+
json_response = response.json()
|
33 |
+
results = json_response.get("webPages", {}).get("value", [])
|
34 |
+
if filter_list:
|
35 |
+
results = get_filtered_results(results, filter_list)
|
36 |
+
return [
|
37 |
+
SearchResult(
|
38 |
+
link=result["url"],
|
39 |
+
title=result.get("name"),
|
40 |
+
snippet=result.get("snippet"),
|
41 |
+
)
|
42 |
+
for result in results
|
43 |
+
]
|
44 |
+
except Exception as ex:
|
45 |
+
log.error(f"Error: {ex}")
|
46 |
+
raise ex
|
47 |
+
|
48 |
+
|
49 |
+
def main():
|
50 |
+
parser = argparse.ArgumentParser(description="Search Bing from the command line.")
|
51 |
+
parser.add_argument(
|
52 |
+
"query",
|
53 |
+
type=str,
|
54 |
+
default="Top 10 international news today",
|
55 |
+
help="The search query.",
|
56 |
+
)
|
57 |
+
parser.add_argument(
|
58 |
+
"--count", type=int, default=10, help="Number of search results to return."
|
59 |
+
)
|
60 |
+
parser.add_argument(
|
61 |
+
"--filter", nargs="*", help="List of filters to apply to the search results."
|
62 |
+
)
|
63 |
+
parser.add_argument(
|
64 |
+
"--locale",
|
65 |
+
type=str,
|
66 |
+
default="en-US",
|
67 |
+
help="The locale to use for the search, maps to market in api",
|
68 |
+
)
|
69 |
+
|
70 |
+
args = parser.parse_args()
|
71 |
+
|
72 |
+
results = search_bing(args.locale, args.query, args.count, args.filter)
|
73 |
+
pprint(results)
|
backend/open_webui/apps/retrieval/web/jina_search.py
CHANGED
@@ -9,7 +9,7 @@ log = logging.getLogger(__name__)
|
|
9 |
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
10 |
|
11 |
|
12 |
-
def search_jina(query: str, count: int) -> list[SearchResult]:
|
13 |
"""
|
14 |
Search using Jina's Search API and return the results as a list of SearchResult objects.
|
15 |
Args:
|
@@ -20,9 +20,7 @@ def search_jina(query: str, count: int) -> list[SearchResult]:
|
|
20 |
list[SearchResult]: A list of search results
|
21 |
"""
|
22 |
jina_search_endpoint = "https://s.jina.ai/"
|
23 |
-
headers = {
|
24 |
-
"Accept": "application/json",
|
25 |
-
}
|
26 |
url = str(URL(jina_search_endpoint + query))
|
27 |
response = requests.get(url, headers=headers)
|
28 |
response.raise_for_status()
|
|
|
9 |
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
10 |
|
11 |
|
12 |
+
def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
|
13 |
"""
|
14 |
Search using Jina's Search API and return the results as a list of SearchResult objects.
|
15 |
Args:
|
|
|
20 |
list[SearchResult]: A list of search results
|
21 |
"""
|
22 |
jina_search_endpoint = "https://s.jina.ai/"
|
23 |
+
headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"}
|
|
|
|
|
24 |
url = str(URL(jina_search_endpoint + query))
|
25 |
response = requests.get(url, headers=headers)
|
26 |
response.raise_for_status()
|
backend/open_webui/apps/retrieval/web/testdata/bing.json
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_type": "SearchResponse",
|
3 |
+
"queryContext": {
|
4 |
+
"originalQuery": "Top 10 international results"
|
5 |
+
},
|
6 |
+
"webPages": {
|
7 |
+
"webSearchUrl": "https://www.bing.com/search?q=Top+10+international+results",
|
8 |
+
"totalEstimatedMatches": 687,
|
9 |
+
"value": [
|
10 |
+
{
|
11 |
+
"id": "https://api.bing.microsoft.com/api/v7/#WebPages.0",
|
12 |
+
"name": "2024 Mexican Grand Prix - F1 results and latest standings ... - PlanetF1",
|
13 |
+
"url": "https://www.planetf1.com/news/f1-results-2024-mexican-grand-prix-race-standings",
|
14 |
+
"datePublished": "2024-10-27T00:00:00.0000000",
|
15 |
+
"datePublishedFreshnessText": "1 day ago",
|
16 |
+
"isFamilyFriendly": true,
|
17 |
+
"displayUrl": "https://www.planetf1.com/news/f1-results-2024-mexican-grand-prix-race-standings",
|
18 |
+
"snippet": "Nico Hulkenberg and Pierre Gasly completed the top 10. A full report of the Mexican Grand Prix is available at the bottom of this article. F1 results – 2024 Mexican Grand Prix",
|
19 |
+
"dateLastCrawled": "2024-10-28T07:15:00.0000000Z",
|
20 |
+
"cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=916492551782&mkt=en-US&setlang=en-US&w=zBsfaAPyF2tUrHFHr_vFFdUm8sng4g34",
|
21 |
+
"language": "en",
|
22 |
+
"isNavigational": false,
|
23 |
+
"noCache": false
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"id": "https://api.bing.microsoft.com/api/v7/#WebPages.1",
|
27 |
+
"name": "F1 Results Today: HUGE Verstappen penalties cause major title change",
|
28 |
+
"url": "https://www.gpfans.com/en/f1-news/1033512/f1-results-today-mexican-grand-prix-huge-max-verstappen-penalties-cause-major-title-change/",
|
29 |
+
"datePublished": "2024-10-27T00:00:00.0000000",
|
30 |
+
"datePublishedFreshnessText": "1 day ago",
|
31 |
+
"isFamilyFriendly": true,
|
32 |
+
"displayUrl": "https://www.gpfans.com/en/f1-news/1033512/f1-results-today-mexican-grand-prix-huge-max...",
|
33 |
+
"snippet": "Elsewhere, Mercedes duo Lewis Hamilton and George Russell came home in P4 and P5 respectively. Meanwhile, the surprise package of the day were Haas, with both Kevin Magnussen and Nico Hulkenberg finishing inside the points.. READ MORE: RB star issues apology after red flag CRASH at Mexican GP Mexican Grand Prix 2024 results. 1. Carlos Sainz [Ferrari] 2. Lando Norris [McLaren] - +4.705",
|
34 |
+
"dateLastCrawled": "2024-10-28T06:06:00.0000000Z",
|
35 |
+
"cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=2840656522642&mkt=en-US&setlang=en-US&w=-Tbkwxnq52jZCvG7l3CtgcwT1vwAjIUD",
|
36 |
+
"language": "en",
|
37 |
+
"isNavigational": false,
|
38 |
+
"noCache": false
|
39 |
+
},
|
40 |
+
{
|
41 |
+
"id": "https://api.bing.microsoft.com/api/v7/#WebPages.2",
|
42 |
+
"name": "International Power Rankings: England flying, Kangaroos cruising, Fiji rise",
|
43 |
+
"url": "https://www.loverugbyleague.com/post/international-power-rankings-england-flying-kangaroos-cruising-fiji-rise",
|
44 |
+
"datePublished": "2024-10-28T00:00:00.0000000",
|
45 |
+
"datePublishedFreshnessText": "7 hours ago",
|
46 |
+
"isFamilyFriendly": true,
|
47 |
+
"displayUrl": "https://www.loverugbyleague.com/post/international-power-rankings-england-flying...",
|
48 |
+
"snippet": "LRL RECOMMENDS: England player ratings from first Test against Samoa as omnificent George Williams scores perfect 10. 2. Australia (Men) – SAME. The Kangaroos remain 2nd in our Power Rankings after their 22-10 win against New Zealand in Christchurch on Sunday. As was the case in their win against Tonga last week, Mal Meninga’s side weren ...",
|
49 |
+
"dateLastCrawled": "2024-10-28T07:09:00.0000000Z",
|
50 |
+
"cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=1535008462672&mkt=en-US&setlang=en-US&w=82ujhH4Kp0iuhCS7wh1xLUFYUeetaVVm",
|
51 |
+
"language": "en",
|
52 |
+
"isNavigational": false,
|
53 |
+
"noCache": false
|
54 |
+
}
|
55 |
+
],
|
56 |
+
"someResultsRemoved": true
|
57 |
+
}
|
58 |
+
}
|
backend/open_webui/apps/socket/main.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import asyncio
|
2 |
import socketio
|
3 |
import logging
|
|
|
1 |
+
# TODO: move socket to webui app
|
2 |
+
|
3 |
import asyncio
|
4 |
import socketio
|
5 |
import logging
|
backend/open_webui/apps/webui/main.py
CHANGED
@@ -12,6 +12,7 @@ from open_webui.apps.webui.routers import (
|
|
12 |
chats,
|
13 |
folders,
|
14 |
configs,
|
|
|
15 |
files,
|
16 |
functions,
|
17 |
memories,
|
@@ -34,6 +35,7 @@ from open_webui.config import (
|
|
34 |
ENABLE_LOGIN_FORM,
|
35 |
ENABLE_MESSAGE_RATING,
|
36 |
ENABLE_SIGNUP,
|
|
|
37 |
ENABLE_EVALUATION_ARENA_MODELS,
|
38 |
EVALUATION_ARENA_MODELS,
|
39 |
DEFAULT_ARENA_MODEL,
|
@@ -50,9 +52,22 @@ from open_webui.config import (
|
|
50 |
WEBHOOK_URL,
|
51 |
WEBUI_AUTH,
|
52 |
WEBUI_BANNERS,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
AppConfig,
|
54 |
)
|
55 |
from open_webui.env import (
|
|
|
56 |
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
57 |
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
58 |
)
|
@@ -72,7 +87,11 @@ from open_webui.utils.payload import (
|
|
72 |
|
73 |
from open_webui.utils.tools import get_tools
|
74 |
|
75 |
-
app = FastAPI(
|
|
|
|
|
|
|
|
|
76 |
|
77 |
log = logging.getLogger(__name__)
|
78 |
|
@@ -80,6 +99,8 @@ app.state.config = AppConfig()
|
|
80 |
|
81 |
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
|
82 |
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
|
|
|
|
|
83 |
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
|
84 |
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
|
85 |
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
|
@@ -92,6 +113,8 @@ app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
|
|
92 |
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
|
93 |
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
|
94 |
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
|
|
|
|
|
95 |
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
|
96 |
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
97 |
app.state.config.BANNERS = WEBUI_BANNERS
|
@@ -111,7 +134,19 @@ app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
|
|
111 |
app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
|
112 |
app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
|
113 |
|
114 |
-
app.state.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
app.state.TOOLS = {}
|
116 |
app.state.FUNCTIONS = {}
|
117 |
|
@@ -135,13 +170,15 @@ app.include_router(models.router, prefix="/models", tags=["models"])
|
|
135 |
app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
|
136 |
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
|
137 |
app.include_router(tools.router, prefix="/tools", tags=["tools"])
|
138 |
-
app.include_router(functions.router, prefix="/functions", tags=["functions"])
|
139 |
|
140 |
app.include_router(memories.router, prefix="/memories", tags=["memories"])
|
141 |
-
app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
|
142 |
-
|
143 |
app.include_router(folders.router, prefix="/folders", tags=["folders"])
|
|
|
|
|
144 |
app.include_router(files.router, prefix="/files", tags=["files"])
|
|
|
|
|
|
|
145 |
|
146 |
app.include_router(utils.router, prefix="/utils", tags=["utils"])
|
147 |
|
@@ -336,7 +373,7 @@ def get_function_params(function_module, form_data, user, extra_params=None):
|
|
336 |
return params
|
337 |
|
338 |
|
339 |
-
async def generate_function_chat_completion(form_data, user):
|
340 |
model_id = form_data.get("model")
|
341 |
model_info = Models.get_model_by_id(model_id)
|
342 |
|
@@ -372,6 +409,7 @@ async def generate_function_chat_completion(form_data, user):
|
|
372 |
"name": user.name,
|
373 |
"role": user.role,
|
374 |
},
|
|
|
375 |
}
|
376 |
extra_params["__tools__"] = get_tools(
|
377 |
app,
|
@@ -379,7 +417,7 @@ async def generate_function_chat_completion(form_data, user):
|
|
379 |
user,
|
380 |
{
|
381 |
**extra_params,
|
382 |
-
"__model__":
|
383 |
"__messages__": form_data["messages"],
|
384 |
"__files__": files,
|
385 |
},
|
|
|
12 |
chats,
|
13 |
folders,
|
14 |
configs,
|
15 |
+
groups,
|
16 |
files,
|
17 |
functions,
|
18 |
memories,
|
|
|
35 |
ENABLE_LOGIN_FORM,
|
36 |
ENABLE_MESSAGE_RATING,
|
37 |
ENABLE_SIGNUP,
|
38 |
+
ENABLE_API_KEY,
|
39 |
ENABLE_EVALUATION_ARENA_MODELS,
|
40 |
EVALUATION_ARENA_MODELS,
|
41 |
DEFAULT_ARENA_MODEL,
|
|
|
52 |
WEBHOOK_URL,
|
53 |
WEBUI_AUTH,
|
54 |
WEBUI_BANNERS,
|
55 |
+
ENABLE_LDAP,
|
56 |
+
LDAP_SERVER_LABEL,
|
57 |
+
LDAP_SERVER_HOST,
|
58 |
+
LDAP_SERVER_PORT,
|
59 |
+
LDAP_ATTRIBUTE_FOR_USERNAME,
|
60 |
+
LDAP_SEARCH_FILTERS,
|
61 |
+
LDAP_SEARCH_BASE,
|
62 |
+
LDAP_APP_DN,
|
63 |
+
LDAP_APP_PASSWORD,
|
64 |
+
LDAP_USE_TLS,
|
65 |
+
LDAP_CA_CERT_FILE,
|
66 |
+
LDAP_CIPHERS,
|
67 |
AppConfig,
|
68 |
)
|
69 |
from open_webui.env import (
|
70 |
+
ENV,
|
71 |
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
72 |
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
73 |
)
|
|
|
87 |
|
88 |
from open_webui.utils.tools import get_tools
|
89 |
|
90 |
+
app = FastAPI(
|
91 |
+
docs_url="/docs" if ENV == "dev" else None,
|
92 |
+
openapi_url="/openapi.json" if ENV == "dev" else None,
|
93 |
+
redoc_url=None,
|
94 |
+
)
|
95 |
|
96 |
log = logging.getLogger(__name__)
|
97 |
|
|
|
99 |
|
100 |
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
|
101 |
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
|
102 |
+
app.state.config.ENABLE_API_KEY = ENABLE_API_KEY
|
103 |
+
|
104 |
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
|
105 |
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
|
106 |
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
|
|
|
113 |
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
|
114 |
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
|
115 |
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
|
116 |
+
|
117 |
+
|
118 |
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
|
119 |
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
120 |
app.state.config.BANNERS = WEBUI_BANNERS
|
|
|
134 |
app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
|
135 |
app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
|
136 |
|
137 |
+
app.state.config.ENABLE_LDAP = ENABLE_LDAP
|
138 |
+
app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL
|
139 |
+
app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST
|
140 |
+
app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT
|
141 |
+
app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME
|
142 |
+
app.state.config.LDAP_APP_DN = LDAP_APP_DN
|
143 |
+
app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD
|
144 |
+
app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE
|
145 |
+
app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS
|
146 |
+
app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
|
147 |
+
app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
|
148 |
+
app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
|
149 |
+
|
150 |
app.state.TOOLS = {}
|
151 |
app.state.FUNCTIONS = {}
|
152 |
|
|
|
170 |
app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
|
171 |
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
|
172 |
app.include_router(tools.router, prefix="/tools", tags=["tools"])
|
|
|
173 |
|
174 |
app.include_router(memories.router, prefix="/memories", tags=["memories"])
|
|
|
|
|
175 |
app.include_router(folders.router, prefix="/folders", tags=["folders"])
|
176 |
+
|
177 |
+
app.include_router(groups.router, prefix="/groups", tags=["groups"])
|
178 |
app.include_router(files.router, prefix="/files", tags=["files"])
|
179 |
+
app.include_router(functions.router, prefix="/functions", tags=["functions"])
|
180 |
+
app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
|
181 |
+
|
182 |
|
183 |
app.include_router(utils.router, prefix="/utils", tags=["utils"])
|
184 |
|
|
|
373 |
return params
|
374 |
|
375 |
|
376 |
+
async def generate_function_chat_completion(form_data, user, models: dict = {}):
|
377 |
model_id = form_data.get("model")
|
378 |
model_info = Models.get_model_by_id(model_id)
|
379 |
|
|
|
409 |
"name": user.name,
|
410 |
"role": user.role,
|
411 |
},
|
412 |
+
"__metadata__": metadata,
|
413 |
}
|
414 |
extra_params["__tools__"] = get_tools(
|
415 |
app,
|
|
|
417 |
user,
|
418 |
{
|
419 |
**extra_params,
|
420 |
+
"__model__": models.get(form_data["model"], None),
|
421 |
"__messages__": form_data["messages"],
|
422 |
"__files__": files,
|
423 |
},
|
backend/open_webui/apps/webui/models/auths.py
CHANGED
@@ -64,6 +64,11 @@ class SigninForm(BaseModel):
|
|
64 |
password: str
|
65 |
|
66 |
|
|
|
|
|
|
|
|
|
|
|
67 |
class ProfileImageUrlForm(BaseModel):
|
68 |
profile_image_url: str
|
69 |
|
|
|
64 |
password: str
|
65 |
|
66 |
|
67 |
+
class LdapForm(BaseModel):
|
68 |
+
user: str
|
69 |
+
password: str
|
70 |
+
|
71 |
+
|
72 |
class ProfileImageUrlForm(BaseModel):
|
73 |
profile_image_url: str
|
74 |
|
backend/open_webui/apps/webui/models/chats.py
CHANGED
@@ -203,15 +203,22 @@ class ChatTable:
|
|
203 |
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
204 |
try:
|
205 |
with get_db() as db:
|
206 |
-
print("update_shared_chat_by_id")
|
207 |
chat = db.get(Chat, chat_id)
|
208 |
-
|
209 |
-
|
210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
db.commit()
|
212 |
-
db.refresh(
|
213 |
|
214 |
-
return
|
215 |
except Exception:
|
216 |
return None
|
217 |
|
|
|
203 |
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
204 |
try:
|
205 |
with get_db() as db:
|
|
|
206 |
chat = db.get(Chat, chat_id)
|
207 |
+
shared_chat = (
|
208 |
+
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first()
|
209 |
+
)
|
210 |
+
|
211 |
+
if shared_chat is None:
|
212 |
+
return self.insert_shared_chat_by_chat_id(chat_id)
|
213 |
+
|
214 |
+
shared_chat.title = chat.title
|
215 |
+
shared_chat.chat = chat.chat
|
216 |
+
|
217 |
+
shared_chat.updated_at = int(time.time())
|
218 |
db.commit()
|
219 |
+
db.refresh(shared_chat)
|
220 |
|
221 |
+
return ChatModel.model_validate(shared_chat)
|
222 |
except Exception:
|
223 |
return None
|
224 |
|
backend/open_webui/apps/webui/models/groups.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import time
|
4 |
+
from typing import Optional
|
5 |
+
import uuid
|
6 |
+
|
7 |
+
from open_webui.apps.webui.internal.db import Base, get_db
|
8 |
+
from open_webui.env import SRC_LOG_LEVELS
|
9 |
+
|
10 |
+
from open_webui.apps.webui.models.files import FileMetadataResponse
|
11 |
+
|
12 |
+
|
13 |
+
from pydantic import BaseModel, ConfigDict
|
14 |
+
from sqlalchemy import BigInteger, Column, String, Text, JSON, func
|
15 |
+
|
16 |
+
|
17 |
+
log = logging.getLogger(__name__)
|
18 |
+
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
19 |
+
|
20 |
+
####################
|
21 |
+
# UserGroup DB Schema
|
22 |
+
####################
|
23 |
+
|
24 |
+
|
25 |
+
class Group(Base):
|
26 |
+
__tablename__ = "group"
|
27 |
+
|
28 |
+
id = Column(Text, unique=True, primary_key=True)
|
29 |
+
user_id = Column(Text)
|
30 |
+
|
31 |
+
name = Column(Text)
|
32 |
+
description = Column(Text)
|
33 |
+
|
34 |
+
data = Column(JSON, nullable=True)
|
35 |
+
meta = Column(JSON, nullable=True)
|
36 |
+
|
37 |
+
permissions = Column(JSON, nullable=True)
|
38 |
+
user_ids = Column(JSON, nullable=True)
|
39 |
+
|
40 |
+
created_at = Column(BigInteger)
|
41 |
+
updated_at = Column(BigInteger)
|
42 |
+
|
43 |
+
|
44 |
+
class GroupModel(BaseModel):
|
45 |
+
model_config = ConfigDict(from_attributes=True)
|
46 |
+
id: str
|
47 |
+
user_id: str
|
48 |
+
|
49 |
+
name: str
|
50 |
+
description: str
|
51 |
+
|
52 |
+
data: Optional[dict] = None
|
53 |
+
meta: Optional[dict] = None
|
54 |
+
|
55 |
+
permissions: Optional[dict] = None
|
56 |
+
user_ids: list[str] = []
|
57 |
+
|
58 |
+
created_at: int # timestamp in epoch
|
59 |
+
updated_at: int # timestamp in epoch
|
60 |
+
|
61 |
+
|
62 |
+
####################
|
63 |
+
# Forms
|
64 |
+
####################
|
65 |
+
|
66 |
+
|
67 |
+
class GroupResponse(BaseModel):
|
68 |
+
id: str
|
69 |
+
user_id: str
|
70 |
+
name: str
|
71 |
+
description: str
|
72 |
+
permissions: Optional[dict] = None
|
73 |
+
data: Optional[dict] = None
|
74 |
+
meta: Optional[dict] = None
|
75 |
+
user_ids: list[str] = []
|
76 |
+
created_at: int # timestamp in epoch
|
77 |
+
updated_at: int # timestamp in epoch
|
78 |
+
|
79 |
+
|
80 |
+
class GroupForm(BaseModel):
|
81 |
+
name: str
|
82 |
+
description: str
|
83 |
+
|
84 |
+
|
85 |
+
class GroupUpdateForm(GroupForm):
|
86 |
+
permissions: Optional[dict] = None
|
87 |
+
user_ids: Optional[list[str]] = None
|
88 |
+
admin_ids: Optional[list[str]] = None
|
89 |
+
|
90 |
+
|
91 |
+
class GroupTable:
|
92 |
+
def insert_new_group(
|
93 |
+
self, user_id: str, form_data: GroupForm
|
94 |
+
) -> Optional[GroupModel]:
|
95 |
+
with get_db() as db:
|
96 |
+
group = GroupModel(
|
97 |
+
**{
|
98 |
+
**form_data.model_dump(),
|
99 |
+
"id": str(uuid.uuid4()),
|
100 |
+
"user_id": user_id,
|
101 |
+
"created_at": int(time.time()),
|
102 |
+
"updated_at": int(time.time()),
|
103 |
+
}
|
104 |
+
)
|
105 |
+
|
106 |
+
try:
|
107 |
+
result = Group(**group.model_dump())
|
108 |
+
db.add(result)
|
109 |
+
db.commit()
|
110 |
+
db.refresh(result)
|
111 |
+
if result:
|
112 |
+
return GroupModel.model_validate(result)
|
113 |
+
else:
|
114 |
+
return None
|
115 |
+
|
116 |
+
except Exception:
|
117 |
+
return None
|
118 |
+
|
119 |
+
def get_groups(self) -> list[GroupModel]:
|
120 |
+
with get_db() as db:
|
121 |
+
return [
|
122 |
+
GroupModel.model_validate(group)
|
123 |
+
for group in db.query(Group).order_by(Group.updated_at.desc()).all()
|
124 |
+
]
|
125 |
+
|
126 |
+
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
|
127 |
+
with get_db() as db:
|
128 |
+
return [
|
129 |
+
GroupModel.model_validate(group)
|
130 |
+
for group in db.query(Group)
|
131 |
+
.filter(
|
132 |
+
func.json_array_length(Group.user_ids) > 0
|
133 |
+
) # Ensure array exists
|
134 |
+
.filter(
|
135 |
+
Group.user_ids.cast(String).like(f'%"{user_id}"%')
|
136 |
+
) # String-based check
|
137 |
+
.order_by(Group.updated_at.desc())
|
138 |
+
.all()
|
139 |
+
]
|
140 |
+
|
141 |
+
def get_group_by_id(self, id: str) -> Optional[GroupModel]:
|
142 |
+
try:
|
143 |
+
with get_db() as db:
|
144 |
+
group = db.query(Group).filter_by(id=id).first()
|
145 |
+
return GroupModel.model_validate(group) if group else None
|
146 |
+
except Exception:
|
147 |
+
return None
|
148 |
+
|
149 |
+
def update_group_by_id(
|
150 |
+
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
|
151 |
+
) -> Optional[GroupModel]:
|
152 |
+
try:
|
153 |
+
with get_db() as db:
|
154 |
+
db.query(Group).filter_by(id=id).update(
|
155 |
+
{
|
156 |
+
**form_data.model_dump(exclude_none=True),
|
157 |
+
"updated_at": int(time.time()),
|
158 |
+
}
|
159 |
+
)
|
160 |
+
db.commit()
|
161 |
+
return self.get_group_by_id(id=id)
|
162 |
+
except Exception as e:
|
163 |
+
log.exception(e)
|
164 |
+
return None
|
165 |
+
|
166 |
+
def delete_group_by_id(self, id: str) -> bool:
|
167 |
+
try:
|
168 |
+
with get_db() as db:
|
169 |
+
db.query(Group).filter_by(id=id).delete()
|
170 |
+
db.commit()
|
171 |
+
return True
|
172 |
+
except Exception:
|
173 |
+
return False
|
174 |
+
|
175 |
+
def delete_all_groups(self) -> bool:
|
176 |
+
with get_db() as db:
|
177 |
+
try:
|
178 |
+
db.query(Group).delete()
|
179 |
+
db.commit()
|
180 |
+
|
181 |
+
return True
|
182 |
+
except Exception:
|
183 |
+
return False
|
184 |
+
|
185 |
+
|
186 |
+
Groups = GroupTable()
|
backend/open_webui/apps/webui/models/knowledge.py
CHANGED
@@ -8,11 +8,13 @@ from open_webui.apps.webui.internal.db import Base, get_db
|
|
8 |
from open_webui.env import SRC_LOG_LEVELS
|
9 |
|
10 |
from open_webui.apps.webui.models.files import FileMetadataResponse
|
|
|
11 |
|
12 |
|
13 |
from pydantic import BaseModel, ConfigDict
|
14 |
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
15 |
|
|
|
16 |
|
17 |
log = logging.getLogger(__name__)
|
18 |
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
@@ -34,6 +36,23 @@ class Knowledge(Base):
|
|
34 |
data = Column(JSON, nullable=True)
|
35 |
meta = Column(JSON, nullable=True)
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
created_at = Column(BigInteger)
|
38 |
updated_at = Column(BigInteger)
|
39 |
|
@@ -50,6 +69,8 @@ class KnowledgeModel(BaseModel):
|
|
50 |
data: Optional[dict] = None
|
51 |
meta: Optional[dict] = None
|
52 |
|
|
|
|
|
53 |
created_at: int # timestamp in epoch
|
54 |
updated_at: int # timestamp in epoch
|
55 |
|
@@ -59,15 +80,15 @@ class KnowledgeModel(BaseModel):
|
|
59 |
####################
|
60 |
|
61 |
|
62 |
-
class
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
created_at: int # timestamp in epoch
|
69 |
-
updated_at: int # timestamp in epoch
|
70 |
|
|
|
|
|
71 |
files: Optional[list[FileMetadataResponse | dict]] = None
|
72 |
|
73 |
|
@@ -75,12 +96,7 @@ class KnowledgeForm(BaseModel):
|
|
75 |
name: str
|
76 |
description: str
|
77 |
data: Optional[dict] = None
|
78 |
-
|
79 |
-
|
80 |
-
class KnowledgeUpdateForm(BaseModel):
|
81 |
-
name: Optional[str] = None
|
82 |
-
description: Optional[str] = None
|
83 |
-
data: Optional[dict] = None
|
84 |
|
85 |
|
86 |
class KnowledgeTable:
|
@@ -110,14 +126,33 @@ class KnowledgeTable:
|
|
110 |
except Exception:
|
111 |
return None
|
112 |
|
113 |
-
def
|
114 |
with get_db() as db:
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
.
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
|
123 |
try:
|
@@ -128,14 +163,32 @@ class KnowledgeTable:
|
|
128 |
return None
|
129 |
|
130 |
def update_knowledge_by_id(
|
131 |
-
self, id: str, form_data:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
) -> Optional[KnowledgeModel]:
|
133 |
try:
|
134 |
with get_db() as db:
|
135 |
knowledge = self.get_knowledge_by_id(id=id)
|
136 |
db.query(Knowledge).filter_by(id=id).update(
|
137 |
{
|
138 |
-
|
139 |
"updated_at": int(time.time()),
|
140 |
}
|
141 |
)
|
|
|
8 |
from open_webui.env import SRC_LOG_LEVELS
|
9 |
|
10 |
from open_webui.apps.webui.models.files import FileMetadataResponse
|
11 |
+
from open_webui.apps.webui.models.users import Users, UserResponse
|
12 |
|
13 |
|
14 |
from pydantic import BaseModel, ConfigDict
|
15 |
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
16 |
|
17 |
+
from open_webui.utils.access_control import has_access
|
18 |
|
19 |
log = logging.getLogger(__name__)
|
20 |
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
|
|
36 |
data = Column(JSON, nullable=True)
|
37 |
meta = Column(JSON, nullable=True)
|
38 |
|
39 |
+
access_control = Column(JSON, nullable=True) # Controls data access levels.
|
40 |
+
# Defines access control rules for this entry.
|
41 |
+
# - `None`: Public access, available to all users with the "user" role.
|
42 |
+
# - `{}`: Private access, restricted exclusively to the owner.
|
43 |
+
# - Custom permissions: Specific access control for reading and writing;
|
44 |
+
# Can specify group or user-level restrictions:
|
45 |
+
# {
|
46 |
+
# "read": {
|
47 |
+
# "group_ids": ["group_id1", "group_id2"],
|
48 |
+
# "user_ids": ["user_id1", "user_id2"]
|
49 |
+
# },
|
50 |
+
# "write": {
|
51 |
+
# "group_ids": ["group_id1", "group_id2"],
|
52 |
+
# "user_ids": ["user_id1", "user_id2"]
|
53 |
+
# }
|
54 |
+
# }
|
55 |
+
|
56 |
created_at = Column(BigInteger)
|
57 |
updated_at = Column(BigInteger)
|
58 |
|
|
|
69 |
data: Optional[dict] = None
|
70 |
meta: Optional[dict] = None
|
71 |
|
72 |
+
access_control: Optional[dict] = None
|
73 |
+
|
74 |
created_at: int # timestamp in epoch
|
75 |
updated_at: int # timestamp in epoch
|
76 |
|
|
|
80 |
####################
|
81 |
|
82 |
|
83 |
+
class KnowledgeUserModel(KnowledgeModel):
|
84 |
+
user: Optional[UserResponse] = None
|
85 |
+
|
86 |
+
|
87 |
+
class KnowledgeResponse(KnowledgeModel):
|
88 |
+
files: Optional[list[FileMetadataResponse | dict]] = None
|
|
|
|
|
89 |
|
90 |
+
|
91 |
+
class KnowledgeUserResponse(KnowledgeUserModel):
|
92 |
files: Optional[list[FileMetadataResponse | dict]] = None
|
93 |
|
94 |
|
|
|
96 |
name: str
|
97 |
description: str
|
98 |
data: Optional[dict] = None
|
99 |
+
access_control: Optional[dict] = None
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
|
102 |
class KnowledgeTable:
|
|
|
126 |
except Exception:
|
127 |
return None
|
128 |
|
129 |
+
def get_knowledge_bases(self) -> list[KnowledgeUserModel]:
|
130 |
with get_db() as db:
|
131 |
+
knowledge_bases = []
|
132 |
+
for knowledge in (
|
133 |
+
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
|
134 |
+
):
|
135 |
+
user = Users.get_user_by_id(knowledge.user_id)
|
136 |
+
knowledge_bases.append(
|
137 |
+
KnowledgeUserModel.model_validate(
|
138 |
+
{
|
139 |
+
**KnowledgeModel.model_validate(knowledge).model_dump(),
|
140 |
+
"user": user.model_dump() if user else None,
|
141 |
+
}
|
142 |
+
)
|
143 |
+
)
|
144 |
+
return knowledge_bases
|
145 |
+
|
146 |
+
def get_knowledge_bases_by_user_id(
|
147 |
+
self, user_id: str, permission: str = "write"
|
148 |
+
) -> list[KnowledgeUserModel]:
|
149 |
+
knowledge_bases = self.get_knowledge_bases()
|
150 |
+
return [
|
151 |
+
knowledge_base
|
152 |
+
for knowledge_base in knowledge_bases
|
153 |
+
if knowledge_base.user_id == user_id
|
154 |
+
or has_access(user_id, permission, knowledge_base.access_control)
|
155 |
+
]
|
156 |
|
157 |
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
|
158 |
try:
|
|
|
163 |
return None
|
164 |
|
165 |
def update_knowledge_by_id(
|
166 |
+
self, id: str, form_data: KnowledgeForm, overwrite: bool = False
|
167 |
+
) -> Optional[KnowledgeModel]:
|
168 |
+
try:
|
169 |
+
with get_db() as db:
|
170 |
+
knowledge = self.get_knowledge_by_id(id=id)
|
171 |
+
db.query(Knowledge).filter_by(id=id).update(
|
172 |
+
{
|
173 |
+
**form_data.model_dump(),
|
174 |
+
"updated_at": int(time.time()),
|
175 |
+
}
|
176 |
+
)
|
177 |
+
db.commit()
|
178 |
+
return self.get_knowledge_by_id(id=id)
|
179 |
+
except Exception as e:
|
180 |
+
log.exception(e)
|
181 |
+
return None
|
182 |
+
|
183 |
+
def update_knowledge_data_by_id(
|
184 |
+
self, id: str, data: dict
|
185 |
) -> Optional[KnowledgeModel]:
|
186 |
try:
|
187 |
with get_db() as db:
|
188 |
knowledge = self.get_knowledge_by_id(id=id)
|
189 |
db.query(Knowledge).filter_by(id=id).update(
|
190 |
{
|
191 |
+
"data": data,
|
192 |
"updated_at": int(time.time()),
|
193 |
}
|
194 |
)
|
backend/open_webui/apps/webui/models/models.py
CHANGED
@@ -4,8 +4,19 @@ from typing import Optional
|
|
4 |
|
5 |
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
|
6 |
from open_webui.env import SRC_LOG_LEVELS
|
|
|
|
|
|
|
|
|
7 |
from pydantic import BaseModel, ConfigDict
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
log = logging.getLogger(__name__)
|
11 |
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
@@ -67,6 +78,25 @@ class Model(Base):
|
|
67 |
Holds a JSON encoded blob of metadata, see `ModelMeta`.
|
68 |
"""
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
updated_at = Column(BigInteger)
|
71 |
created_at = Column(BigInteger)
|
72 |
|
@@ -80,6 +110,9 @@ class ModelModel(BaseModel):
|
|
80 |
params: ModelParams
|
81 |
meta: ModelMeta
|
82 |
|
|
|
|
|
|
|
83 |
updated_at: int # timestamp in epoch
|
84 |
created_at: int # timestamp in epoch
|
85 |
|
@@ -91,12 +124,12 @@ class ModelModel(BaseModel):
|
|
91 |
####################
|
92 |
|
93 |
|
94 |
-
class
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
|
101 |
|
102 |
class ModelForm(BaseModel):
|
@@ -105,6 +138,8 @@ class ModelForm(BaseModel):
|
|
105 |
name: str
|
106 |
meta: ModelMeta
|
107 |
params: ModelParams
|
|
|
|
|
108 |
|
109 |
|
110 |
class ModelsTable:
|
@@ -138,6 +173,39 @@ class ModelsTable:
|
|
138 |
with get_db() as db:
|
139 |
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
|
140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
|
142 |
try:
|
143 |
with get_db() as db:
|
@@ -146,6 +214,23 @@ class ModelsTable:
|
|
146 |
except Exception:
|
147 |
return None
|
148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
|
150 |
try:
|
151 |
with get_db() as db:
|
@@ -153,7 +238,7 @@ class ModelsTable:
|
|
153 |
result = (
|
154 |
db.query(Model)
|
155 |
.filter_by(id=id)
|
156 |
-
.update(model.model_dump(exclude={"id"}
|
157 |
)
|
158 |
db.commit()
|
159 |
|
@@ -175,5 +260,15 @@ class ModelsTable:
|
|
175 |
except Exception:
|
176 |
return False
|
177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
Models = ModelsTable()
|
|
|
4 |
|
5 |
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
|
6 |
from open_webui.env import SRC_LOG_LEVELS
|
7 |
+
|
8 |
+
from open_webui.apps.webui.models.users import Users, UserResponse
|
9 |
+
|
10 |
+
|
11 |
from pydantic import BaseModel, ConfigDict
|
12 |
+
|
13 |
+
from sqlalchemy import or_, and_, func
|
14 |
+
from sqlalchemy.dialects import postgresql, sqlite
|
15 |
+
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
|
16 |
+
|
17 |
+
|
18 |
+
from open_webui.utils.access_control import has_access
|
19 |
+
|
20 |
|
21 |
log = logging.getLogger(__name__)
|
22 |
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
|
|
78 |
Holds a JSON encoded blob of metadata, see `ModelMeta`.
|
79 |
"""
|
80 |
|
81 |
+
access_control = Column(JSON, nullable=True) # Controls data access levels.
|
82 |
+
# Defines access control rules for this entry.
|
83 |
+
# - `None`: Public access, available to all users with the "user" role.
|
84 |
+
# - `{}`: Private access, restricted exclusively to the owner.
|
85 |
+
# - Custom permissions: Specific access control for reading and writing;
|
86 |
+
# Can specify group or user-level restrictions:
|
87 |
+
# {
|
88 |
+
# "read": {
|
89 |
+
# "group_ids": ["group_id1", "group_id2"],
|
90 |
+
# "user_ids": ["user_id1", "user_id2"]
|
91 |
+
# },
|
92 |
+
# "write": {
|
93 |
+
# "group_ids": ["group_id1", "group_id2"],
|
94 |
+
# "user_ids": ["user_id1", "user_id2"]
|
95 |
+
# }
|
96 |
+
# }
|
97 |
+
|
98 |
+
is_active = Column(Boolean, default=True)
|
99 |
+
|
100 |
updated_at = Column(BigInteger)
|
101 |
created_at = Column(BigInteger)
|
102 |
|
|
|
110 |
params: ModelParams
|
111 |
meta: ModelMeta
|
112 |
|
113 |
+
access_control: Optional[dict] = None
|
114 |
+
|
115 |
+
is_active: bool
|
116 |
updated_at: int # timestamp in epoch
|
117 |
created_at: int # timestamp in epoch
|
118 |
|
|
|
124 |
####################
|
125 |
|
126 |
|
127 |
+
class ModelUserResponse(ModelModel):
|
128 |
+
user: Optional[UserResponse] = None
|
129 |
+
|
130 |
+
|
131 |
+
class ModelResponse(ModelModel):
|
132 |
+
pass
|
133 |
|
134 |
|
135 |
class ModelForm(BaseModel):
|
|
|
138 |
name: str
|
139 |
meta: ModelMeta
|
140 |
params: ModelParams
|
141 |
+
access_control: Optional[dict] = None
|
142 |
+
is_active: bool = True
|
143 |
|
144 |
|
145 |
class ModelsTable:
|
|
|
173 |
with get_db() as db:
|
174 |
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
|
175 |
|
176 |
+
def get_models(self) -> list[ModelUserResponse]:
|
177 |
+
with get_db() as db:
|
178 |
+
models = []
|
179 |
+
for model in db.query(Model).filter(Model.base_model_id != None).all():
|
180 |
+
user = Users.get_user_by_id(model.user_id)
|
181 |
+
models.append(
|
182 |
+
ModelUserResponse.model_validate(
|
183 |
+
{
|
184 |
+
**ModelModel.model_validate(model).model_dump(),
|
185 |
+
"user": user.model_dump() if user else None,
|
186 |
+
}
|
187 |
+
)
|
188 |
+
)
|
189 |
+
return models
|
190 |
+
|
191 |
+
def get_base_models(self) -> list[ModelModel]:
|
192 |
+
with get_db() as db:
|
193 |
+
return [
|
194 |
+
ModelModel.model_validate(model)
|
195 |
+
for model in db.query(Model).filter(Model.base_model_id == None).all()
|
196 |
+
]
|
197 |
+
|
198 |
+
def get_models_by_user_id(
|
199 |
+
self, user_id: str, permission: str = "write"
|
200 |
+
) -> list[ModelUserResponse]:
|
201 |
+
models = self.get_models()
|
202 |
+
return [
|
203 |
+
model
|
204 |
+
for model in models
|
205 |
+
if model.user_id == user_id
|
206 |
+
or has_access(user_id, permission, model.access_control)
|
207 |
+
]
|
208 |
+
|
209 |
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
|
210 |
try:
|
211 |
with get_db() as db:
|
|
|
214 |
except Exception:
|
215 |
return None
|
216 |
|
217 |
+
def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
|
218 |
+
with get_db() as db:
|
219 |
+
try:
|
220 |
+
is_active = db.query(Model).filter_by(id=id).first().is_active
|
221 |
+
|
222 |
+
db.query(Model).filter_by(id=id).update(
|
223 |
+
{
|
224 |
+
"is_active": not is_active,
|
225 |
+
"updated_at": int(time.time()),
|
226 |
+
}
|
227 |
+
)
|
228 |
+
db.commit()
|
229 |
+
|
230 |
+
return self.get_model_by_id(id)
|
231 |
+
except Exception:
|
232 |
+
return None
|
233 |
+
|
234 |
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
|
235 |
try:
|
236 |
with get_db() as db:
|
|
|
238 |
result = (
|
239 |
db.query(Model)
|
240 |
.filter_by(id=id)
|
241 |
+
.update(model.model_dump(exclude={"id"}))
|
242 |
)
|
243 |
db.commit()
|
244 |
|
|
|
260 |
except Exception:
|
261 |
return False
|
262 |
|
263 |
+
def delete_all_models(self) -> bool:
|
264 |
+
try:
|
265 |
+
with get_db() as db:
|
266 |
+
db.query(Model).delete()
|
267 |
+
db.commit()
|
268 |
+
|
269 |
+
return True
|
270 |
+
except Exception:
|
271 |
+
return False
|
272 |
+
|
273 |
|
274 |
Models = ModelsTable()
|
backend/open_webui/apps/webui/models/prompts.py
CHANGED
@@ -2,8 +2,12 @@ import time
|
|
2 |
from typing import Optional
|
3 |
|
4 |
from open_webui.apps.webui.internal.db import Base, get_db
|
|
|
|
|
5 |
from pydantic import BaseModel, ConfigDict
|
6 |
-
from sqlalchemy import BigInteger, Column, String, Text
|
|
|
|
|
7 |
|
8 |
####################
|
9 |
# Prompts DB Schema
|
@@ -19,6 +23,23 @@ class Prompt(Base):
|
|
19 |
content = Column(Text)
|
20 |
timestamp = Column(BigInteger)
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
class PromptModel(BaseModel):
|
24 |
command: str
|
@@ -27,6 +48,7 @@ class PromptModel(BaseModel):
|
|
27 |
content: str
|
28 |
timestamp: int # timestamp in epoch
|
29 |
|
|
|
30 |
model_config = ConfigDict(from_attributes=True)
|
31 |
|
32 |
|
@@ -35,10 +57,15 @@ class PromptModel(BaseModel):
|
|
35 |
####################
|
36 |
|
37 |
|
|
|
|
|
|
|
|
|
38 |
class PromptForm(BaseModel):
|
39 |
command: str
|
40 |
title: str
|
41 |
content: str
|
|
|
42 |
|
43 |
|
44 |
class PromptsTable:
|
@@ -48,16 +75,14 @@ class PromptsTable:
|
|
48 |
prompt = PromptModel(
|
49 |
**{
|
50 |
"user_id": user_id,
|
51 |
-
|
52 |
-
"title": form_data.title,
|
53 |
-
"content": form_data.content,
|
54 |
"timestamp": int(time.time()),
|
55 |
}
|
56 |
)
|
57 |
|
58 |
try:
|
59 |
with get_db() as db:
|
60 |
-
result = Prompt(**prompt.
|
61 |
db.add(result)
|
62 |
db.commit()
|
63 |
db.refresh(result)
|
@@ -76,11 +101,34 @@ class PromptsTable:
|
|
76 |
except Exception:
|
77 |
return None
|
78 |
|
79 |
-
def get_prompts(self) -> list[
|
80 |
with get_db() as db:
|
81 |
-
|
82 |
-
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
def update_prompt_by_command(
|
86 |
self, command: str, form_data: PromptForm
|
@@ -90,6 +138,7 @@ class PromptsTable:
|
|
90 |
prompt = db.query(Prompt).filter_by(command=command).first()
|
91 |
prompt.title = form_data.title
|
92 |
prompt.content = form_data.content
|
|
|
93 |
prompt.timestamp = int(time.time())
|
94 |
db.commit()
|
95 |
return PromptModel.model_validate(prompt)
|
|
|
2 |
from typing import Optional
|
3 |
|
4 |
from open_webui.apps.webui.internal.db import Base, get_db
|
5 |
+
from open_webui.apps.webui.models.users import Users, UserResponse
|
6 |
+
|
7 |
from pydantic import BaseModel, ConfigDict
|
8 |
+
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
9 |
+
|
10 |
+
from open_webui.utils.access_control import has_access
|
11 |
|
12 |
####################
|
13 |
# Prompts DB Schema
|
|
|
23 |
content = Column(Text)
|
24 |
timestamp = Column(BigInteger)
|
25 |
|
26 |
+
access_control = Column(JSON, nullable=True) # Controls data access levels.
|
27 |
+
# Defines access control rules for this entry.
|
28 |
+
# - `None`: Public access, available to all users with the "user" role.
|
29 |
+
# - `{}`: Private access, restricted exclusively to the owner.
|
30 |
+
# - Custom permissions: Specific access control for reading and writing;
|
31 |
+
# Can specify group or user-level restrictions:
|
32 |
+
# {
|
33 |
+
# "read": {
|
34 |
+
# "group_ids": ["group_id1", "group_id2"],
|
35 |
+
# "user_ids": ["user_id1", "user_id2"]
|
36 |
+
# },
|
37 |
+
# "write": {
|
38 |
+
# "group_ids": ["group_id1", "group_id2"],
|
39 |
+
# "user_ids": ["user_id1", "user_id2"]
|
40 |
+
# }
|
41 |
+
# }
|
42 |
+
|
43 |
|
44 |
class PromptModel(BaseModel):
|
45 |
command: str
|
|
|
48 |
content: str
|
49 |
timestamp: int # timestamp in epoch
|
50 |
|
51 |
+
access_control: Optional[dict] = None
|
52 |
model_config = ConfigDict(from_attributes=True)
|
53 |
|
54 |
|
|
|
57 |
####################
|
58 |
|
59 |
|
60 |
+
class PromptUserResponse(PromptModel):
|
61 |
+
user: Optional[UserResponse] = None
|
62 |
+
|
63 |
+
|
64 |
class PromptForm(BaseModel):
|
65 |
command: str
|
66 |
title: str
|
67 |
content: str
|
68 |
+
access_control: Optional[dict] = None
|
69 |
|
70 |
|
71 |
class PromptsTable:
|
|
|
75 |
prompt = PromptModel(
|
76 |
**{
|
77 |
"user_id": user_id,
|
78 |
+
**form_data.model_dump(),
|
|
|
|
|
79 |
"timestamp": int(time.time()),
|
80 |
}
|
81 |
)
|
82 |
|
83 |
try:
|
84 |
with get_db() as db:
|
85 |
+
result = Prompt(**prompt.model_dump())
|
86 |
db.add(result)
|
87 |
db.commit()
|
88 |
db.refresh(result)
|
|
|
101 |
except Exception:
|
102 |
return None
|
103 |
|
104 |
+
def get_prompts(self) -> list[PromptUserResponse]:
|
105 |
with get_db() as db:
|
106 |
+
prompts = []
|
107 |
+
|
108 |
+
for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all():
|
109 |
+
user = Users.get_user_by_id(prompt.user_id)
|
110 |
+
prompts.append(
|
111 |
+
PromptUserResponse.model_validate(
|
112 |
+
{
|
113 |
+
**PromptModel.model_validate(prompt).model_dump(),
|
114 |
+
"user": user.model_dump() if user else None,
|
115 |
+
}
|
116 |
+
)
|
117 |
+
)
|
118 |
+
|
119 |
+
return prompts
|
120 |
+
|
121 |
+
def get_prompts_by_user_id(
|
122 |
+
self, user_id: str, permission: str = "write"
|
123 |
+
) -> list[PromptUserResponse]:
|
124 |
+
prompts = self.get_prompts()
|
125 |
+
|
126 |
+
return [
|
127 |
+
prompt
|
128 |
+
for prompt in prompts
|
129 |
+
if prompt.user_id == user_id
|
130 |
+
or has_access(user_id, permission, prompt.access_control)
|
131 |
+
]
|
132 |
|
133 |
def update_prompt_by_command(
|
134 |
self, command: str, form_data: PromptForm
|
|
|
138 |
prompt = db.query(Prompt).filter_by(command=command).first()
|
139 |
prompt.title = form_data.title
|
140 |
prompt.content = form_data.content
|
141 |
+
prompt.access_control = form_data.access_control
|
142 |
prompt.timestamp = int(time.time())
|
143 |
db.commit()
|
144 |
return PromptModel.model_validate(prompt)
|
backend/open_webui/apps/webui/models/tools.py
CHANGED
@@ -3,10 +3,13 @@ import time
|
|
3 |
from typing import Optional
|
4 |
|
5 |
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
|
6 |
-
from open_webui.apps.webui.models.users import Users
|
7 |
from open_webui.env import SRC_LOG_LEVELS
|
8 |
from pydantic import BaseModel, ConfigDict
|
9 |
-
from sqlalchemy import BigInteger, Column, String, Text
|
|
|
|
|
|
|
10 |
|
11 |
log = logging.getLogger(__name__)
|
12 |
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
@@ -26,6 +29,24 @@ class Tool(Base):
|
|
26 |
specs = Column(JSONField)
|
27 |
meta = Column(JSONField)
|
28 |
valves = Column(JSONField)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
updated_at = Column(BigInteger)
|
30 |
created_at = Column(BigInteger)
|
31 |
|
@@ -42,6 +63,8 @@ class ToolModel(BaseModel):
|
|
42 |
content: str
|
43 |
specs: list[dict]
|
44 |
meta: ToolMeta
|
|
|
|
|
45 |
updated_at: int # timestamp in epoch
|
46 |
created_at: int # timestamp in epoch
|
47 |
|
@@ -58,15 +81,21 @@ class ToolResponse(BaseModel):
|
|
58 |
user_id: str
|
59 |
name: str
|
60 |
meta: ToolMeta
|
|
|
61 |
updated_at: int # timestamp in epoch
|
62 |
created_at: int # timestamp in epoch
|
63 |
|
64 |
|
|
|
|
|
|
|
|
|
65 |
class ToolForm(BaseModel):
|
66 |
id: str
|
67 |
name: str
|
68 |
content: str
|
69 |
meta: ToolMeta
|
|
|
70 |
|
71 |
|
72 |
class ToolValves(BaseModel):
|
@@ -109,9 +138,32 @@ class ToolsTable:
|
|
109 |
except Exception:
|
110 |
return None
|
111 |
|
112 |
-
def get_tools(self) -> list[
|
113 |
with get_db() as db:
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
|
117 |
try:
|
|
|
3 |
from typing import Optional
|
4 |
|
5 |
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
|
6 |
+
from open_webui.apps.webui.models.users import Users, UserResponse
|
7 |
from open_webui.env import SRC_LOG_LEVELS
|
8 |
from pydantic import BaseModel, ConfigDict
|
9 |
+
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
10 |
+
|
11 |
+
from open_webui.utils.access_control import has_access
|
12 |
+
|
13 |
|
14 |
log = logging.getLogger(__name__)
|
15 |
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
|
|
29 |
specs = Column(JSONField)
|
30 |
meta = Column(JSONField)
|
31 |
valves = Column(JSONField)
|
32 |
+
|
33 |
+
access_control = Column(JSON, nullable=True) # Controls data access levels.
|
34 |
+
# Defines access control rules for this entry.
|
35 |
+
# - `None`: Public access, available to all users with the "user" role.
|
36 |
+
# - `{}`: Private access, restricted exclusively to the owner.
|
37 |
+
# - Custom permissions: Specific access control for reading and writing;
|
38 |
+
# Can specify group or user-level restrictions:
|
39 |
+
# {
|
40 |
+
# "read": {
|
41 |
+
# "group_ids": ["group_id1", "group_id2"],
|
42 |
+
# "user_ids": ["user_id1", "user_id2"]
|
43 |
+
# },
|
44 |
+
# "write": {
|
45 |
+
# "group_ids": ["group_id1", "group_id2"],
|
46 |
+
# "user_ids": ["user_id1", "user_id2"]
|
47 |
+
# }
|
48 |
+
# }
|
49 |
+
|
50 |
updated_at = Column(BigInteger)
|
51 |
created_at = Column(BigInteger)
|
52 |
|
|
|
63 |
content: str
|
64 |
specs: list[dict]
|
65 |
meta: ToolMeta
|
66 |
+
access_control: Optional[dict] = None
|
67 |
+
|
68 |
updated_at: int # timestamp in epoch
|
69 |
created_at: int # timestamp in epoch
|
70 |
|
|
|
81 |
user_id: str
|
82 |
name: str
|
83 |
meta: ToolMeta
|
84 |
+
access_control: Optional[dict] = None
|
85 |
updated_at: int # timestamp in epoch
|
86 |
created_at: int # timestamp in epoch
|
87 |
|
88 |
|
89 |
+
class ToolUserResponse(ToolResponse):
|
90 |
+
user: Optional[UserResponse] = None
|
91 |
+
|
92 |
+
|
93 |
class ToolForm(BaseModel):
|
94 |
id: str
|
95 |
name: str
|
96 |
content: str
|
97 |
meta: ToolMeta
|
98 |
+
access_control: Optional[dict] = None
|
99 |
|
100 |
|
101 |
class ToolValves(BaseModel):
|
|
|
138 |
except Exception:
|
139 |
return None
|
140 |
|
141 |
+
def get_tools(self) -> list[ToolUserResponse]:
|
142 |
with get_db() as db:
|
143 |
+
tools = []
|
144 |
+
for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all():
|
145 |
+
user = Users.get_user_by_id(tool.user_id)
|
146 |
+
tools.append(
|
147 |
+
ToolUserResponse.model_validate(
|
148 |
+
{
|
149 |
+
**ToolModel.model_validate(tool).model_dump(),
|
150 |
+
"user": user.model_dump() if user else None,
|
151 |
+
}
|
152 |
+
)
|
153 |
+
)
|
154 |
+
return tools
|
155 |
+
|
156 |
+
def get_tools_by_user_id(
|
157 |
+
self, user_id: str, permission: str = "write"
|
158 |
+
) -> list[ToolUserResponse]:
|
159 |
+
tools = self.get_tools()
|
160 |
+
|
161 |
+
return [
|
162 |
+
tool
|
163 |
+
for tool in tools
|
164 |
+
if tool.user_id == user_id
|
165 |
+
or has_access(user_id, permission, tool.access_control)
|
166 |
+
]
|
167 |
|
168 |
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
|
169 |
try:
|
backend/open_webui/apps/webui/models/users.py
CHANGED
@@ -62,6 +62,14 @@ class UserModel(BaseModel):
|
|
62 |
####################
|
63 |
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
class UserRoleUpdateForm(BaseModel):
|
66 |
id: str
|
67 |
role: str
|
|
|
62 |
####################
|
63 |
|
64 |
|
65 |
+
class UserResponse(BaseModel):
|
66 |
+
id: str
|
67 |
+
name: str
|
68 |
+
email: str
|
69 |
+
role: str
|
70 |
+
profile_image_url: str
|
71 |
+
|
72 |
+
|
73 |
class UserRoleUpdateForm(BaseModel):
|
74 |
id: str
|
75 |
role: str
|
backend/open_webui/apps/webui/routers/auths.py
CHANGED
@@ -2,12 +2,14 @@ import re
|
|
2 |
import uuid
|
3 |
import time
|
4 |
import datetime
|
|
|
5 |
|
6 |
from open_webui.apps.webui.models.auths import (
|
7 |
AddUserForm,
|
8 |
ApiKey,
|
9 |
Auths,
|
10 |
Token,
|
|
|
11 |
SigninForm,
|
12 |
SigninResponse,
|
13 |
SignupForm,
|
@@ -16,13 +18,15 @@ from open_webui.apps.webui.models.auths import (
|
|
16 |
UserResponse,
|
17 |
)
|
18 |
from open_webui.apps.webui.models.users import Users
|
19 |
-
|
20 |
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
21 |
from open_webui.env import (
|
|
|
22 |
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
23 |
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
24 |
WEBUI_SESSION_COOKIE_SAME_SITE,
|
25 |
WEBUI_SESSION_COOKIE_SECURE,
|
|
|
26 |
)
|
27 |
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
28 |
from fastapi.responses import Response
|
@@ -37,10 +41,19 @@ from open_webui.utils.utils import (
|
|
37 |
get_password_hash,
|
38 |
)
|
39 |
from open_webui.utils.webhook import post_webhook
|
40 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
router = APIRouter()
|
43 |
|
|
|
|
|
|
|
44 |
############################
|
45 |
# GetSessionUser
|
46 |
############################
|
@@ -48,6 +61,7 @@ router = APIRouter()
|
|
48 |
|
49 |
class SessionUserResponse(Token, UserResponse):
|
50 |
expires_at: Optional[int] = None
|
|
|
51 |
|
52 |
|
53 |
@router.get("/", response_model=SessionUserResponse)
|
@@ -80,6 +94,10 @@ async def get_session_user(
|
|
80 |
secure=WEBUI_SESSION_COOKIE_SECURE,
|
81 |
)
|
82 |
|
|
|
|
|
|
|
|
|
83 |
return {
|
84 |
"token": token,
|
85 |
"token_type": "Bearer",
|
@@ -89,6 +107,7 @@ async def get_session_user(
|
|
89 |
"name": user.name,
|
90 |
"role": user.role,
|
91 |
"profile_image_url": user.profile_image_url,
|
|
|
92 |
}
|
93 |
|
94 |
|
@@ -137,6 +156,140 @@ async def update_password(
|
|
137 |
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
138 |
|
139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
############################
|
141 |
# SignIn
|
142 |
############################
|
@@ -211,6 +364,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|
211 |
secure=WEBUI_SESSION_COOKIE_SECURE,
|
212 |
)
|
213 |
|
|
|
|
|
|
|
|
|
214 |
return {
|
215 |
"token": token,
|
216 |
"token_type": "Bearer",
|
@@ -220,6 +377,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|
220 |
"name": user.name,
|
221 |
"role": user.role,
|
222 |
"profile_image_url": user.profile_image_url,
|
|
|
223 |
}
|
224 |
else:
|
225 |
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
@@ -260,6 +418,11 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|
260 |
if Users.get_num_users() == 0
|
261 |
else request.app.state.config.DEFAULT_USER_ROLE
|
262 |
)
|
|
|
|
|
|
|
|
|
|
|
263 |
hashed = get_password_hash(form_data.password)
|
264 |
user = Auths.insert_new_auth(
|
265 |
form_data.email.lower(),
|
@@ -307,6 +470,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|
307 |
},
|
308 |
)
|
309 |
|
|
|
|
|
|
|
|
|
310 |
return {
|
311 |
"token": token,
|
312 |
"token_type": "Bearer",
|
@@ -316,6 +483,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|
316 |
"name": user.name,
|
317 |
"role": user.role,
|
318 |
"profile_image_url": user.profile_image_url,
|
|
|
319 |
}
|
320 |
else:
|
321 |
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
|
@@ -413,6 +581,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
|
|
413 |
return {
|
414 |
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
|
415 |
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
|
|
|
416 |
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
|
417 |
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
|
418 |
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
|
@@ -423,6 +592,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
|
|
423 |
class AdminConfig(BaseModel):
|
424 |
SHOW_ADMIN_DETAILS: bool
|
425 |
ENABLE_SIGNUP: bool
|
|
|
426 |
DEFAULT_USER_ROLE: str
|
427 |
JWT_EXPIRES_IN: str
|
428 |
ENABLE_COMMUNITY_SHARING: bool
|
@@ -435,6 +605,7 @@ async def update_admin_config(
|
|
435 |
):
|
436 |
request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
|
437 |
request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
|
|
|
438 |
|
439 |
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
|
440 |
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
|
@@ -453,6 +624,7 @@ async def update_admin_config(
|
|
453 |
return {
|
454 |
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
|
455 |
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
|
|
|
456 |
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
|
457 |
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
|
458 |
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
|
@@ -460,6 +632,105 @@ async def update_admin_config(
|
|
460 |
}
|
461 |
|
462 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
463 |
############################
|
464 |
# API Key
|
465 |
############################
|
@@ -467,9 +738,16 @@ async def update_admin_config(
|
|
467 |
|
468 |
# create api key
|
469 |
@router.post("/api_key", response_model=ApiKey)
|
470 |
-
async def
|
|
|
|
|
|
|
|
|
|
|
|
|
471 |
api_key = create_api_key()
|
472 |
success = Users.update_user_api_key_by_id(user.id, api_key)
|
|
|
473 |
if success:
|
474 |
return {
|
475 |
"api_key": api_key,
|
|
|
2 |
import uuid
|
3 |
import time
|
4 |
import datetime
|
5 |
+
import logging
|
6 |
|
7 |
from open_webui.apps.webui.models.auths import (
|
8 |
AddUserForm,
|
9 |
ApiKey,
|
10 |
Auths,
|
11 |
Token,
|
12 |
+
LdapForm,
|
13 |
SigninForm,
|
14 |
SigninResponse,
|
15 |
SignupForm,
|
|
|
18 |
UserResponse,
|
19 |
)
|
20 |
from open_webui.apps.webui.models.users import Users
|
21 |
+
|
22 |
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
23 |
from open_webui.env import (
|
24 |
+
WEBUI_AUTH,
|
25 |
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
26 |
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
27 |
WEBUI_SESSION_COOKIE_SAME_SITE,
|
28 |
WEBUI_SESSION_COOKIE_SECURE,
|
29 |
+
SRC_LOG_LEVELS,
|
30 |
)
|
31 |
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
32 |
from fastapi.responses import Response
|
|
|
41 |
get_password_hash,
|
42 |
)
|
43 |
from open_webui.utils.webhook import post_webhook
|
44 |
+
from open_webui.utils.access_control import get_permissions
|
45 |
+
|
46 |
+
from typing import Optional, List
|
47 |
+
|
48 |
+
from ssl import CERT_REQUIRED, PROTOCOL_TLS
|
49 |
+
from ldap3 import Server, Connection, ALL, Tls
|
50 |
+
from ldap3.utils.conv import escape_filter_chars
|
51 |
|
52 |
router = APIRouter()
|
53 |
|
54 |
+
log = logging.getLogger(__name__)
|
55 |
+
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
56 |
+
|
57 |
############################
|
58 |
# GetSessionUser
|
59 |
############################
|
|
|
61 |
|
62 |
class SessionUserResponse(Token, UserResponse):
|
63 |
expires_at: Optional[int] = None
|
64 |
+
permissions: Optional[dict] = None
|
65 |
|
66 |
|
67 |
@router.get("/", response_model=SessionUserResponse)
|
|
|
94 |
secure=WEBUI_SESSION_COOKIE_SECURE,
|
95 |
)
|
96 |
|
97 |
+
user_permissions = get_permissions(
|
98 |
+
user.id, request.app.state.config.USER_PERMISSIONS
|
99 |
+
)
|
100 |
+
|
101 |
return {
|
102 |
"token": token,
|
103 |
"token_type": "Bearer",
|
|
|
107 |
"name": user.name,
|
108 |
"role": user.role,
|
109 |
"profile_image_url": user.profile_image_url,
|
110 |
+
"permissions": user_permissions,
|
111 |
}
|
112 |
|
113 |
|
|
|
156 |
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
157 |
|
158 |
|
159 |
+
############################
|
160 |
+
# LDAP Authentication
|
161 |
+
############################
|
162 |
+
@router.post("/ldap", response_model=SigninResponse)
|
163 |
+
async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
164 |
+
ENABLE_LDAP = request.app.state.config.ENABLE_LDAP
|
165 |
+
LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL
|
166 |
+
LDAP_SERVER_HOST = request.app.state.config.LDAP_SERVER_HOST
|
167 |
+
LDAP_SERVER_PORT = request.app.state.config.LDAP_SERVER_PORT
|
168 |
+
LDAP_ATTRIBUTE_FOR_USERNAME = request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME
|
169 |
+
LDAP_SEARCH_BASE = request.app.state.config.LDAP_SEARCH_BASE
|
170 |
+
LDAP_SEARCH_FILTERS = request.app.state.config.LDAP_SEARCH_FILTERS
|
171 |
+
LDAP_APP_DN = request.app.state.config.LDAP_APP_DN
|
172 |
+
LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD
|
173 |
+
LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS
|
174 |
+
LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE
|
175 |
+
LDAP_CIPHERS = (
|
176 |
+
request.app.state.config.LDAP_CIPHERS
|
177 |
+
if request.app.state.config.LDAP_CIPHERS
|
178 |
+
else "ALL"
|
179 |
+
)
|
180 |
+
|
181 |
+
if not ENABLE_LDAP:
|
182 |
+
raise HTTPException(400, detail="LDAP authentication is not enabled")
|
183 |
+
|
184 |
+
try:
|
185 |
+
tls = Tls(
|
186 |
+
validate=CERT_REQUIRED,
|
187 |
+
version=PROTOCOL_TLS,
|
188 |
+
ca_certs_file=LDAP_CA_CERT_FILE,
|
189 |
+
ciphers=LDAP_CIPHERS,
|
190 |
+
)
|
191 |
+
except Exception as e:
|
192 |
+
log.error(f"An error occurred on TLS: {str(e)}")
|
193 |
+
raise HTTPException(400, detail=str(e))
|
194 |
+
|
195 |
+
try:
|
196 |
+
server = Server(
|
197 |
+
host=LDAP_SERVER_HOST,
|
198 |
+
port=LDAP_SERVER_PORT,
|
199 |
+
get_info=ALL,
|
200 |
+
use_ssl=LDAP_USE_TLS,
|
201 |
+
tls=tls,
|
202 |
+
)
|
203 |
+
connection_app = Connection(
|
204 |
+
server,
|
205 |
+
LDAP_APP_DN,
|
206 |
+
LDAP_APP_PASSWORD,
|
207 |
+
auto_bind="NONE",
|
208 |
+
authentication="SIMPLE",
|
209 |
+
)
|
210 |
+
if not connection_app.bind():
|
211 |
+
raise HTTPException(400, detail="Application account bind failed")
|
212 |
+
|
213 |
+
search_success = connection_app.search(
|
214 |
+
search_base=LDAP_SEARCH_BASE,
|
215 |
+
search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})",
|
216 |
+
attributes=[f"{LDAP_ATTRIBUTE_FOR_USERNAME}", "mail", "cn"],
|
217 |
+
)
|
218 |
+
|
219 |
+
if not search_success:
|
220 |
+
raise HTTPException(400, detail="User not found in the LDAP server")
|
221 |
+
|
222 |
+
entry = connection_app.entries[0]
|
223 |
+
username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
|
224 |
+
mail = str(entry["mail"])
|
225 |
+
cn = str(entry["cn"])
|
226 |
+
user_dn = entry.entry_dn
|
227 |
+
|
228 |
+
if username == form_data.user.lower():
|
229 |
+
connection_user = Connection(
|
230 |
+
server,
|
231 |
+
user_dn,
|
232 |
+
form_data.password,
|
233 |
+
auto_bind="NONE",
|
234 |
+
authentication="SIMPLE",
|
235 |
+
)
|
236 |
+
if not connection_user.bind():
|
237 |
+
raise HTTPException(400, f"Authentication failed for {form_data.user}")
|
238 |
+
|
239 |
+
user = Users.get_user_by_email(mail)
|
240 |
+
if not user:
|
241 |
+
|
242 |
+
try:
|
243 |
+
hashed = get_password_hash(form_data.password)
|
244 |
+
user = Auths.insert_new_auth(mail, hashed, cn)
|
245 |
+
|
246 |
+
if not user:
|
247 |
+
raise HTTPException(
|
248 |
+
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR
|
249 |
+
)
|
250 |
+
|
251 |
+
except HTTPException:
|
252 |
+
raise
|
253 |
+
except Exception as err:
|
254 |
+
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
|
255 |
+
|
256 |
+
user = Auths.authenticate_user(mail, password=str(form_data.password))
|
257 |
+
|
258 |
+
if user:
|
259 |
+
token = create_token(
|
260 |
+
data={"id": user.id},
|
261 |
+
expires_delta=parse_duration(
|
262 |
+
request.app.state.config.JWT_EXPIRES_IN
|
263 |
+
),
|
264 |
+
)
|
265 |
+
|
266 |
+
# Set the cookie token
|
267 |
+
response.set_cookie(
|
268 |
+
key="token",
|
269 |
+
value=token,
|
270 |
+
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
271 |
+
)
|
272 |
+
|
273 |
+
return {
|
274 |
+
"token": token,
|
275 |
+
"token_type": "Bearer",
|
276 |
+
"id": user.id,
|
277 |
+
"email": user.email,
|
278 |
+
"name": user.name,
|
279 |
+
"role": user.role,
|
280 |
+
"profile_image_url": user.profile_image_url,
|
281 |
+
}
|
282 |
+
else:
|
283 |
+
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
284 |
+
else:
|
285 |
+
raise HTTPException(
|
286 |
+
400,
|
287 |
+
f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}",
|
288 |
+
)
|
289 |
+
except Exception as e:
|
290 |
+
raise HTTPException(400, detail=str(e))
|
291 |
+
|
292 |
+
|
293 |
############################
|
294 |
# SignIn
|
295 |
############################
|
|
|
364 |
secure=WEBUI_SESSION_COOKIE_SECURE,
|
365 |
)
|
366 |
|
367 |
+
user_permissions = get_permissions(
|
368 |
+
user.id, request.app.state.config.USER_PERMISSIONS
|
369 |
+
)
|
370 |
+
|
371 |
return {
|
372 |
"token": token,
|
373 |
"token_type": "Bearer",
|
|
|
377 |
"name": user.name,
|
378 |
"role": user.role,
|
379 |
"profile_image_url": user.profile_image_url,
|
380 |
+
"permissions": user_permissions,
|
381 |
}
|
382 |
else:
|
383 |
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
418 |
if Users.get_num_users() == 0
|
419 |
else request.app.state.config.DEFAULT_USER_ROLE
|
420 |
)
|
421 |
+
|
422 |
+
if Users.get_num_users() == 0:
|
423 |
+
# Disable signup after the first user is created
|
424 |
+
request.app.state.config.ENABLE_SIGNUP = False
|
425 |
+
|
426 |
hashed = get_password_hash(form_data.password)
|
427 |
user = Auths.insert_new_auth(
|
428 |
form_data.email.lower(),
|
|
|
470 |
},
|
471 |
)
|
472 |
|
473 |
+
user_permissions = get_permissions(
|
474 |
+
user.id, request.app.state.config.USER_PERMISSIONS
|
475 |
+
)
|
476 |
+
|
477 |
return {
|
478 |
"token": token,
|
479 |
"token_type": "Bearer",
|
|
|
483 |
"name": user.name,
|
484 |
"role": user.role,
|
485 |
"profile_image_url": user.profile_image_url,
|
486 |
+
"permissions": user_permissions,
|
487 |
}
|
488 |
else:
|
489 |
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
|
|
|
581 |
return {
|
582 |
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
|
583 |
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
|
584 |
+
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
|
585 |
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
|
586 |
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
|
587 |
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
|
|
|
592 |
class AdminConfig(BaseModel):
|
593 |
SHOW_ADMIN_DETAILS: bool
|
594 |
ENABLE_SIGNUP: bool
|
595 |
+
ENABLE_API_KEY: bool
|
596 |
DEFAULT_USER_ROLE: str
|
597 |
JWT_EXPIRES_IN: str
|
598 |
ENABLE_COMMUNITY_SHARING: bool
|
|
|
605 |
):
|
606 |
request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
|
607 |
request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
|
608 |
+
request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY
|
609 |
|
610 |
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
|
611 |
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
|
|
|
624 |
return {
|
625 |
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
|
626 |
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
|
627 |
+
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
|
628 |
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
|
629 |
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
|
630 |
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
|
|
|
632 |
}
|
633 |
|
634 |
|
635 |
+
class LdapServerConfig(BaseModel):
|
636 |
+
label: str
|
637 |
+
host: str
|
638 |
+
port: Optional[int] = None
|
639 |
+
attribute_for_username: str = "uid"
|
640 |
+
app_dn: str
|
641 |
+
app_dn_password: str
|
642 |
+
search_base: str
|
643 |
+
search_filters: str = ""
|
644 |
+
use_tls: bool = True
|
645 |
+
certificate_path: Optional[str] = None
|
646 |
+
ciphers: Optional[str] = "ALL"
|
647 |
+
|
648 |
+
|
649 |
+
@router.get("/admin/config/ldap/server", response_model=LdapServerConfig)
|
650 |
+
async def get_ldap_server(request: Request, user=Depends(get_admin_user)):
|
651 |
+
return {
|
652 |
+
"label": request.app.state.config.LDAP_SERVER_LABEL,
|
653 |
+
"host": request.app.state.config.LDAP_SERVER_HOST,
|
654 |
+
"port": request.app.state.config.LDAP_SERVER_PORT,
|
655 |
+
"attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME,
|
656 |
+
"app_dn": request.app.state.config.LDAP_APP_DN,
|
657 |
+
"app_dn_password": request.app.state.config.LDAP_APP_PASSWORD,
|
658 |
+
"search_base": request.app.state.config.LDAP_SEARCH_BASE,
|
659 |
+
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
|
660 |
+
"use_tls": request.app.state.config.LDAP_USE_TLS,
|
661 |
+
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
|
662 |
+
"ciphers": request.app.state.config.LDAP_CIPHERS,
|
663 |
+
}
|
664 |
+
|
665 |
+
|
666 |
+
@router.post("/admin/config/ldap/server")
|
667 |
+
async def update_ldap_server(
|
668 |
+
request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user)
|
669 |
+
):
|
670 |
+
required_fields = [
|
671 |
+
"label",
|
672 |
+
"host",
|
673 |
+
"attribute_for_username",
|
674 |
+
"app_dn",
|
675 |
+
"app_dn_password",
|
676 |
+
"search_base",
|
677 |
+
]
|
678 |
+
for key in required_fields:
|
679 |
+
value = getattr(form_data, key)
|
680 |
+
if not value:
|
681 |
+
raise HTTPException(400, detail=f"Required field {key} is empty")
|
682 |
+
|
683 |
+
if form_data.use_tls and not form_data.certificate_path:
|
684 |
+
raise HTTPException(
|
685 |
+
400, detail="TLS is enabled but certificate file path is missing"
|
686 |
+
)
|
687 |
+
|
688 |
+
request.app.state.config.LDAP_SERVER_LABEL = form_data.label
|
689 |
+
request.app.state.config.LDAP_SERVER_HOST = form_data.host
|
690 |
+
request.app.state.config.LDAP_SERVER_PORT = form_data.port
|
691 |
+
request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = (
|
692 |
+
form_data.attribute_for_username
|
693 |
+
)
|
694 |
+
request.app.state.config.LDAP_APP_DN = form_data.app_dn
|
695 |
+
request.app.state.config.LDAP_APP_PASSWORD = form_data.app_dn_password
|
696 |
+
request.app.state.config.LDAP_SEARCH_BASE = form_data.search_base
|
697 |
+
request.app.state.config.LDAP_SEARCH_FILTERS = form_data.search_filters
|
698 |
+
request.app.state.config.LDAP_USE_TLS = form_data.use_tls
|
699 |
+
request.app.state.config.LDAP_CA_CERT_FILE = form_data.certificate_path
|
700 |
+
request.app.state.config.LDAP_CIPHERS = form_data.ciphers
|
701 |
+
|
702 |
+
return {
|
703 |
+
"label": request.app.state.config.LDAP_SERVER_LABEL,
|
704 |
+
"host": request.app.state.config.LDAP_SERVER_HOST,
|
705 |
+
"port": request.app.state.config.LDAP_SERVER_PORT,
|
706 |
+
"attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME,
|
707 |
+
"app_dn": request.app.state.config.LDAP_APP_DN,
|
708 |
+
"app_dn_password": request.app.state.config.LDAP_APP_PASSWORD,
|
709 |
+
"search_base": request.app.state.config.LDAP_SEARCH_BASE,
|
710 |
+
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
|
711 |
+
"use_tls": request.app.state.config.LDAP_USE_TLS,
|
712 |
+
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
|
713 |
+
"ciphers": request.app.state.config.LDAP_CIPHERS,
|
714 |
+
}
|
715 |
+
|
716 |
+
|
717 |
+
@router.get("/admin/config/ldap")
|
718 |
+
async def get_ldap_config(request: Request, user=Depends(get_admin_user)):
|
719 |
+
return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
|
720 |
+
|
721 |
+
|
722 |
+
class LdapConfigForm(BaseModel):
|
723 |
+
enable_ldap: Optional[bool] = None
|
724 |
+
|
725 |
+
|
726 |
+
@router.post("/admin/config/ldap")
|
727 |
+
async def update_ldap_config(
|
728 |
+
request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user)
|
729 |
+
):
|
730 |
+
request.app.state.config.ENABLE_LDAP = form_data.enable_ldap
|
731 |
+
return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
|
732 |
+
|
733 |
+
|
734 |
############################
|
735 |
# API Key
|
736 |
############################
|
|
|
738 |
|
739 |
# create api key
|
740 |
@router.post("/api_key", response_model=ApiKey)
|
741 |
+
async def generate_api_key(request: Request, user=Depends(get_current_user)):
|
742 |
+
if not request.app.state.config.ENABLE_API_KEY:
|
743 |
+
raise HTTPException(
|
744 |
+
status.HTTP_403_FORBIDDEN,
|
745 |
+
detail=ERROR_MESSAGES.API_KEY_CREATION_NOT_ALLOWED,
|
746 |
+
)
|
747 |
+
|
748 |
api_key = create_api_key()
|
749 |
success = Users.update_user_api_key_by_id(user.id, api_key)
|
750 |
+
|
751 |
if success:
|
752 |
return {
|
753 |
"api_key": api_key,
|
backend/open_webui/apps/webui/routers/chats.py
CHANGED
@@ -17,7 +17,10 @@ from open_webui.constants import ERROR_MESSAGES
|
|
17 |
from open_webui.env import SRC_LOG_LEVELS
|
18 |
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
19 |
from pydantic import BaseModel
|
|
|
|
|
20 |
from open_webui.utils.utils import get_admin_user, get_verified_user
|
|
|
21 |
|
22 |
log = logging.getLogger(__name__)
|
23 |
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
@@ -50,9 +53,10 @@ async def get_session_user_chat_list(
|
|
50 |
|
51 |
@router.delete("/", response_model=bool)
|
52 |
async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
56 |
raise HTTPException(
|
57 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
58 |
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
@@ -385,8 +389,8 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
|
|
385 |
|
386 |
return result
|
387 |
else:
|
388 |
-
if not
|
389 |
-
"
|
390 |
):
|
391 |
raise HTTPException(
|
392 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
17 |
from open_webui.env import SRC_LOG_LEVELS
|
18 |
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
19 |
from pydantic import BaseModel
|
20 |
+
|
21 |
+
|
22 |
from open_webui.utils.utils import get_admin_user, get_verified_user
|
23 |
+
from open_webui.utils.access_control import has_permission
|
24 |
|
25 |
log = logging.getLogger(__name__)
|
26 |
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
|
|
53 |
|
54 |
@router.delete("/", response_model=bool)
|
55 |
async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
|
56 |
+
|
57 |
+
if user.role == "user" and not has_permission(
|
58 |
+
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
|
59 |
+
):
|
60 |
raise HTTPException(
|
61 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
62 |
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
|
389 |
|
390 |
return result
|
391 |
else:
|
392 |
+
if not has_permission(
|
393 |
+
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
|
394 |
):
|
395 |
raise HTTPException(
|
396 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
backend/open_webui/apps/webui/routers/groups.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
from open_webui.apps.webui.models.groups import (
|
6 |
+
Groups,
|
7 |
+
GroupForm,
|
8 |
+
GroupUpdateForm,
|
9 |
+
GroupResponse,
|
10 |
+
)
|
11 |
+
|
12 |
+
from open_webui.config import CACHE_DIR
|
13 |
+
from open_webui.constants import ERROR_MESSAGES
|
14 |
+
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
15 |
+
from open_webui.utils.utils import get_admin_user, get_verified_user
|
16 |
+
|
17 |
+
router = APIRouter()
|
18 |
+
|
19 |
+
############################
|
20 |
+
# GetFunctions
|
21 |
+
############################
|
22 |
+
|
23 |
+
|
24 |
+
@router.get("/", response_model=list[GroupResponse])
|
25 |
+
async def get_groups(user=Depends(get_verified_user)):
|
26 |
+
if user.role == "admin":
|
27 |
+
return Groups.get_groups()
|
28 |
+
else:
|
29 |
+
return Groups.get_groups_by_member_id(user.id)
|
30 |
+
|
31 |
+
|
32 |
+
############################
|
33 |
+
# CreateNewGroup
|
34 |
+
############################
|
35 |
+
|
36 |
+
|
37 |
+
@router.post("/create", response_model=Optional[GroupResponse])
|
38 |
+
async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)):
|
39 |
+
try:
|
40 |
+
group = Groups.insert_new_group(user.id, form_data)
|
41 |
+
if group:
|
42 |
+
return group
|
43 |
+
else:
|
44 |
+
raise HTTPException(
|
45 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
46 |
+
detail=ERROR_MESSAGES.DEFAULT("Error creating group"),
|
47 |
+
)
|
48 |
+
except Exception as e:
|
49 |
+
print(e)
|
50 |
+
raise HTTPException(
|
51 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
52 |
+
detail=ERROR_MESSAGES.DEFAULT(e),
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
+
############################
|
57 |
+
# GetGroupById
|
58 |
+
############################
|
59 |
+
|
60 |
+
|
61 |
+
@router.get("/id/{id}", response_model=Optional[GroupResponse])
|
62 |
+
async def get_group_by_id(id: str, user=Depends(get_admin_user)):
|
63 |
+
group = Groups.get_group_by_id(id)
|
64 |
+
if group:
|
65 |
+
return group
|
66 |
+
else:
|
67 |
+
raise HTTPException(
|
68 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
69 |
+
detail=ERROR_MESSAGES.NOT_FOUND,
|
70 |
+
)
|
71 |
+
|
72 |
+
|
73 |
+
############################
|
74 |
+
# UpdateGroupById
|
75 |
+
############################
|
76 |
+
|
77 |
+
|
78 |
+
@router.post("/id/{id}/update", response_model=Optional[GroupResponse])
|
79 |
+
async def update_group_by_id(
|
80 |
+
id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user)
|
81 |
+
):
|
82 |
+
try:
|
83 |
+
group = Groups.update_group_by_id(id, form_data)
|
84 |
+
if group:
|
85 |
+
return group
|
86 |
+
else:
|
87 |
+
raise HTTPException(
|
88 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
89 |
+
detail=ERROR_MESSAGES.DEFAULT("Error updating group"),
|
90 |
+
)
|
91 |
+
except Exception as e:
|
92 |
+
print(e)
|
93 |
+
raise HTTPException(
|
94 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
95 |
+
detail=ERROR_MESSAGES.DEFAULT(e),
|
96 |
+
)
|
97 |
+
|
98 |
+
|
99 |
+
############################
|
100 |
+
# DeleteGroupById
|
101 |
+
############################
|
102 |
+
|
103 |
+
|
104 |
+
@router.delete("/id/{id}/delete", response_model=bool)
|
105 |
+
async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
|
106 |
+
try:
|
107 |
+
result = Groups.delete_group_by_id(id)
|
108 |
+
if result:
|
109 |
+
return result
|
110 |
+
else:
|
111 |
+
raise HTTPException(
|
112 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
113 |
+
detail=ERROR_MESSAGES.DEFAULT("Error deleting group"),
|
114 |
+
)
|
115 |
+
except Exception as e:
|
116 |
+
print(e)
|
117 |
+
raise HTTPException(
|
118 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
119 |
+
detail=ERROR_MESSAGES.DEFAULT(e),
|
120 |
+
)
|
backend/open_webui/apps/webui/routers/knowledge.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1 |
import json
|
2 |
from typing import Optional, Union
|
3 |
from pydantic import BaseModel
|
4 |
-
from fastapi import APIRouter, Depends, HTTPException, status
|
5 |
import logging
|
6 |
|
7 |
from open_webui.apps.webui.models.knowledge import (
|
8 |
Knowledges,
|
9 |
-
KnowledgeUpdateForm,
|
10 |
KnowledgeForm,
|
11 |
KnowledgeResponse,
|
|
|
12 |
)
|
13 |
from open_webui.apps.webui.models.files import Files, FileModel
|
14 |
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
|
@@ -17,6 +17,9 @@ from open_webui.apps.retrieval.main import process_file, ProcessFileForm
|
|
17 |
|
18 |
from open_webui.constants import ERROR_MESSAGES
|
19 |
from open_webui.utils.utils import get_admin_user, get_verified_user
|
|
|
|
|
|
|
20 |
from open_webui.env import SRC_LOG_LEVELS
|
21 |
|
22 |
|
@@ -26,64 +29,98 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
|
26 |
router = APIRouter()
|
27 |
|
28 |
############################
|
29 |
-
#
|
30 |
############################
|
31 |
|
32 |
|
33 |
-
@router.get(
|
34 |
-
|
35 |
-
|
36 |
-
async def get_knowledge_items(
|
37 |
-
id: Optional[str] = None, user=Depends(get_verified_user)
|
38 |
-
):
|
39 |
-
if id:
|
40 |
-
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
else:
|
45 |
-
raise HTTPException(
|
46 |
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
47 |
-
detail=ERROR_MESSAGES.NOT_FOUND,
|
48 |
-
)
|
49 |
else:
|
50 |
-
knowledge_bases =
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
-
files
|
55 |
-
if
|
56 |
-
|
57 |
-
|
|
|
58 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
set(knowledge.data.get("file_ids", []))
|
64 |
-
- set([file.id for file in files])
|
65 |
)
|
66 |
-
if missing_files:
|
67 |
-
data = knowledge.data or {}
|
68 |
-
file_ids = data.get("file_ids", [])
|
69 |
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
|
|
72 |
|
73 |
-
|
74 |
-
Knowledges.update_knowledge_by_id(
|
75 |
-
id=knowledge.id, form_data=KnowledgeUpdateForm(data=data)
|
76 |
-
)
|
77 |
|
78 |
-
files = Files.get_file_metadatas_by_ids(file_ids)
|
79 |
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
)
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
|
89 |
############################
|
@@ -92,7 +129,17 @@ async def get_knowledge_items(
|
|
92 |
|
93 |
|
94 |
@router.post("/create", response_model=Optional[KnowledgeResponse])
|
95 |
-
async def create_new_knowledge(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
knowledge = Knowledges.insert_new_knowledge(user.id, form_data)
|
97 |
|
98 |
if knowledge:
|
@@ -118,13 +165,20 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|
118 |
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
119 |
|
120 |
if knowledge:
|
121 |
-
file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
|
122 |
-
files = Files.get_files_by_ids(file_ids)
|
123 |
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
else:
|
129 |
raise HTTPException(
|
130 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
@@ -140,11 +194,23 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|
140 |
@router.post("/{id}/update", response_model=Optional[KnowledgeFilesResponse])
|
141 |
async def update_knowledge_by_id(
|
142 |
id: str,
|
143 |
-
form_data:
|
144 |
-
user=Depends(
|
145 |
):
|
146 |
-
knowledge = Knowledges.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
|
|
|
148 |
if knowledge:
|
149 |
file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
|
150 |
files = Files.get_files_by_ids(file_ids)
|
@@ -173,9 +239,22 @@ class KnowledgeFileIdForm(BaseModel):
|
|
173 |
def add_file_to_knowledge_by_id(
|
174 |
id: str,
|
175 |
form_data: KnowledgeFileIdForm,
|
176 |
-
user=Depends(
|
177 |
):
|
178 |
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
file = Files.get_file_by_id(form_data.file_id)
|
180 |
if not file:
|
181 |
raise HTTPException(
|
@@ -206,9 +285,7 @@ def add_file_to_knowledge_by_id(
|
|
206 |
file_ids.append(form_data.file_id)
|
207 |
data["file_ids"] = file_ids
|
208 |
|
209 |
-
knowledge = Knowledges.
|
210 |
-
id=id, form_data=KnowledgeUpdateForm(data=data)
|
211 |
-
)
|
212 |
|
213 |
if knowledge:
|
214 |
files = Files.get_files_by_ids(file_ids)
|
@@ -238,9 +315,21 @@ def add_file_to_knowledge_by_id(
|
|
238 |
def update_file_from_knowledge_by_id(
|
239 |
id: str,
|
240 |
form_data: KnowledgeFileIdForm,
|
241 |
-
user=Depends(
|
242 |
):
|
243 |
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
file = Files.get_file_by_id(form_data.file_id)
|
245 |
if not file:
|
246 |
raise HTTPException(
|
@@ -288,9 +377,21 @@ def update_file_from_knowledge_by_id(
|
|
288 |
def remove_file_from_knowledge_by_id(
|
289 |
id: str,
|
290 |
form_data: KnowledgeFileIdForm,
|
291 |
-
user=Depends(
|
292 |
):
|
293 |
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
file = Files.get_file_by_id(form_data.file_id)
|
295 |
if not file:
|
296 |
raise HTTPException(
|
@@ -318,9 +419,7 @@ def remove_file_from_knowledge_by_id(
|
|
318 |
file_ids.remove(form_data.file_id)
|
319 |
data["file_ids"] = file_ids
|
320 |
|
321 |
-
knowledge = Knowledges.
|
322 |
-
id=id, form_data=KnowledgeUpdateForm(data=data)
|
323 |
-
)
|
324 |
|
325 |
if knowledge:
|
326 |
files = Files.get_files_by_ids(file_ids)
|
@@ -347,35 +446,60 @@ def remove_file_from_knowledge_by_id(
|
|
347 |
|
348 |
|
349 |
############################
|
350 |
-
#
|
351 |
############################
|
352 |
|
353 |
|
354 |
-
@router.
|
355 |
-
async def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
356 |
try:
|
357 |
VECTOR_DB_CLIENT.delete_collection(collection_name=id)
|
358 |
except Exception as e:
|
359 |
log.debug(e)
|
360 |
pass
|
361 |
-
|
362 |
-
|
363 |
-
id=id, form_data=KnowledgeUpdateForm(data={"file_ids": []})
|
364 |
-
)
|
365 |
-
return knowledge
|
366 |
|
367 |
|
368 |
############################
|
369 |
-
#
|
370 |
############################
|
371 |
|
372 |
|
373 |
-
@router.
|
374 |
-
async def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
375 |
try:
|
376 |
VECTOR_DB_CLIENT.delete_collection(collection_name=id)
|
377 |
except Exception as e:
|
378 |
log.debug(e)
|
379 |
pass
|
380 |
-
|
381 |
-
|
|
|
|
|
|
1 |
import json
|
2 |
from typing import Optional, Union
|
3 |
from pydantic import BaseModel
|
4 |
+
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
5 |
import logging
|
6 |
|
7 |
from open_webui.apps.webui.models.knowledge import (
|
8 |
Knowledges,
|
|
|
9 |
KnowledgeForm,
|
10 |
KnowledgeResponse,
|
11 |
+
KnowledgeUserResponse,
|
12 |
)
|
13 |
from open_webui.apps.webui.models.files import Files, FileModel
|
14 |
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
|
|
|
17 |
|
18 |
from open_webui.constants import ERROR_MESSAGES
|
19 |
from open_webui.utils.utils import get_admin_user, get_verified_user
|
20 |
+
from open_webui.utils.access_control import has_access, has_permission
|
21 |
+
|
22 |
+
|
23 |
from open_webui.env import SRC_LOG_LEVELS
|
24 |
|
25 |
|
|
|
29 |
router = APIRouter()
|
30 |
|
31 |
############################
|
32 |
+
# getKnowledgeBases
|
33 |
############################
|
34 |
|
35 |
|
36 |
+
@router.get("/", response_model=list[KnowledgeUserResponse])
|
37 |
+
async def get_knowledge(user=Depends(get_verified_user)):
|
38 |
+
knowledge_bases = []
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
+
if user.role == "admin":
|
41 |
+
knowledge_bases = Knowledges.get_knowledge_bases()
|
|
|
|
|
|
|
|
|
|
|
42 |
else:
|
43 |
+
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "read")
|
44 |
+
|
45 |
+
# Get files for each knowledge base
|
46 |
+
for knowledge_base in knowledge_bases:
|
47 |
+
files = []
|
48 |
+
if knowledge_base.data:
|
49 |
+
files = Files.get_file_metadatas_by_ids(
|
50 |
+
knowledge_base.data.get("file_ids", [])
|
51 |
+
)
|
52 |
|
53 |
+
# Check if all files exist
|
54 |
+
if len(files) != len(knowledge_base.data.get("file_ids", [])):
|
55 |
+
missing_files = list(
|
56 |
+
set(knowledge_base.data.get("file_ids", []))
|
57 |
+
- set([file.id for file in files])
|
58 |
)
|
59 |
+
if missing_files:
|
60 |
+
data = knowledge_base.data or {}
|
61 |
+
file_ids = data.get("file_ids", [])
|
62 |
+
|
63 |
+
for missing_file in missing_files:
|
64 |
+
file_ids.remove(missing_file)
|
65 |
|
66 |
+
data["file_ids"] = file_ids
|
67 |
+
Knowledges.update_knowledge_data_by_id(
|
68 |
+
id=knowledge_base.id, data=data
|
|
|
|
|
69 |
)
|
|
|
|
|
|
|
70 |
|
71 |
+
files = Files.get_file_metadatas_by_ids(file_ids)
|
72 |
+
|
73 |
+
knowledge_base = KnowledgeResponse(
|
74 |
+
**knowledge_base.model_dump(),
|
75 |
+
files=files,
|
76 |
+
)
|
77 |
|
78 |
+
return knowledge_bases
|
|
|
|
|
|
|
79 |
|
|
|
80 |
|
81 |
+
@router.get("/list", response_model=list[KnowledgeUserResponse])
|
82 |
+
async def get_knowledge_list(user=Depends(get_verified_user)):
|
83 |
+
knowledge_bases = []
|
84 |
+
|
85 |
+
if user.role == "admin":
|
86 |
+
knowledge_bases = Knowledges.get_knowledge_bases()
|
87 |
+
else:
|
88 |
+
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "write")
|
89 |
+
|
90 |
+
# Get files for each knowledge base
|
91 |
+
for knowledge_base in knowledge_bases:
|
92 |
+
files = []
|
93 |
+
if knowledge_base.data:
|
94 |
+
files = Files.get_file_metadatas_by_ids(
|
95 |
+
knowledge_base.data.get("file_ids", [])
|
96 |
)
|
97 |
+
|
98 |
+
# Check if all files exist
|
99 |
+
if len(files) != len(knowledge_base.data.get("file_ids", [])):
|
100 |
+
missing_files = list(
|
101 |
+
set(knowledge_base.data.get("file_ids", []))
|
102 |
+
- set([file.id for file in files])
|
103 |
+
)
|
104 |
+
if missing_files:
|
105 |
+
data = knowledge_base.data or {}
|
106 |
+
file_ids = data.get("file_ids", [])
|
107 |
+
|
108 |
+
for missing_file in missing_files:
|
109 |
+
file_ids.remove(missing_file)
|
110 |
+
|
111 |
+
data["file_ids"] = file_ids
|
112 |
+
Knowledges.update_knowledge_data_by_id(
|
113 |
+
id=knowledge_base.id, data=data
|
114 |
+
)
|
115 |
+
|
116 |
+
files = Files.get_file_metadatas_by_ids(file_ids)
|
117 |
+
|
118 |
+
knowledge_base = KnowledgeResponse(
|
119 |
+
**knowledge_base.model_dump(),
|
120 |
+
files=files,
|
121 |
+
)
|
122 |
+
|
123 |
+
return knowledge_bases
|
124 |
|
125 |
|
126 |
############################
|
|
|
129 |
|
130 |
|
131 |
@router.post("/create", response_model=Optional[KnowledgeResponse])
|
132 |
+
async def create_new_knowledge(
|
133 |
+
request: Request, form_data: KnowledgeForm, user=Depends(get_verified_user)
|
134 |
+
):
|
135 |
+
if user.role != "admin" and not has_permission(
|
136 |
+
user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS
|
137 |
+
):
|
138 |
+
raise HTTPException(
|
139 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
140 |
+
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
141 |
+
)
|
142 |
+
|
143 |
knowledge = Knowledges.insert_new_knowledge(user.id, form_data)
|
144 |
|
145 |
if knowledge:
|
|
|
165 |
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
166 |
|
167 |
if knowledge:
|
|
|
|
|
168 |
|
169 |
+
if (
|
170 |
+
user.role == "admin"
|
171 |
+
or knowledge.user_id == user.id
|
172 |
+
or has_access(user.id, "read", knowledge.access_control)
|
173 |
+
):
|
174 |
+
|
175 |
+
file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
|
176 |
+
files = Files.get_files_by_ids(file_ids)
|
177 |
+
|
178 |
+
return KnowledgeFilesResponse(
|
179 |
+
**knowledge.model_dump(),
|
180 |
+
files=files,
|
181 |
+
)
|
182 |
else:
|
183 |
raise HTTPException(
|
184 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
194 |
@router.post("/{id}/update", response_model=Optional[KnowledgeFilesResponse])
|
195 |
async def update_knowledge_by_id(
|
196 |
id: str,
|
197 |
+
form_data: KnowledgeForm,
|
198 |
+
user=Depends(get_verified_user),
|
199 |
):
|
200 |
+
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
201 |
+
if not knowledge:
|
202 |
+
raise HTTPException(
|
203 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
204 |
+
detail=ERROR_MESSAGES.NOT_FOUND,
|
205 |
+
)
|
206 |
+
|
207 |
+
if knowledge.user_id != user.id and user.role != "admin":
|
208 |
+
raise HTTPException(
|
209 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
210 |
+
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
211 |
+
)
|
212 |
|
213 |
+
knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
|
214 |
if knowledge:
|
215 |
file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
|
216 |
files = Files.get_files_by_ids(file_ids)
|
|
|
239 |
def add_file_to_knowledge_by_id(
|
240 |
id: str,
|
241 |
form_data: KnowledgeFileIdForm,
|
242 |
+
user=Depends(get_verified_user),
|
243 |
):
|
244 |
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
245 |
+
|
246 |
+
if not knowledge:
|
247 |
+
raise HTTPException(
|
248 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
249 |
+
detail=ERROR_MESSAGES.NOT_FOUND,
|
250 |
+
)
|
251 |
+
|
252 |
+
if knowledge.user_id != user.id and user.role != "admin":
|
253 |
+
raise HTTPException(
|
254 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
255 |
+
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
256 |
+
)
|
257 |
+
|
258 |
file = Files.get_file_by_id(form_data.file_id)
|
259 |
if not file:
|
260 |
raise HTTPException(
|
|
|
285 |
file_ids.append(form_data.file_id)
|
286 |
data["file_ids"] = file_ids
|
287 |
|
288 |
+
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
|
|
|
|
|
289 |
|
290 |
if knowledge:
|
291 |
files = Files.get_files_by_ids(file_ids)
|
|
|
315 |
def update_file_from_knowledge_by_id(
|
316 |
id: str,
|
317 |
form_data: KnowledgeFileIdForm,
|
318 |
+
user=Depends(get_verified_user),
|
319 |
):
|
320 |
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
321 |
+
if not knowledge:
|
322 |
+
raise HTTPException(
|
323 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
324 |
+
detail=ERROR_MESSAGES.NOT_FOUND,
|
325 |
+
)
|
326 |
+
|
327 |
+
if knowledge.user_id != user.id and user.role != "admin":
|
328 |
+
raise HTTPException(
|
329 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
330 |
+
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
331 |
+
)
|
332 |
+
|
333 |
file = Files.get_file_by_id(form_data.file_id)
|
334 |
if not file:
|
335 |
raise HTTPException(
|
|
|
377 |
def remove_file_from_knowledge_by_id(
|
378 |
id: str,
|
379 |
form_data: KnowledgeFileIdForm,
|
380 |
+
user=Depends(get_verified_user),
|
381 |
):
|
382 |
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
383 |
+
if not knowledge:
|
384 |
+
raise HTTPException(
|
385 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
386 |
+
detail=ERROR_MESSAGES.NOT_FOUND,
|
387 |
+
)
|
388 |
+
|
389 |
+
if knowledge.user_id != user.id and user.role != "admin":
|
390 |
+
raise HTTPException(
|
391 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
392 |
+
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
393 |
+
)
|
394 |
+
|
395 |
file = Files.get_file_by_id(form_data.file_id)
|
396 |
if not file:
|
397 |
raise HTTPException(
|
|
|
419 |
file_ids.remove(form_data.file_id)
|
420 |
data["file_ids"] = file_ids
|
421 |
|
422 |
+
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
|
|
|
|
|
423 |
|
424 |
if knowledge:
|
425 |
files = Files.get_files_by_ids(file_ids)
|
|
|
446 |
|
447 |
|
448 |
############################
|
449 |
+
# DeleteKnowledgeById
|
450 |
############################
|
451 |
|
452 |
|
453 |
+
@router.delete("/{id}/delete", response_model=bool)
|
454 |
+
async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
455 |
+
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
456 |
+
if not knowledge:
|
457 |
+
raise HTTPException(
|
458 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
459 |
+
detail=ERROR_MESSAGES.NOT_FOUND,
|
460 |
+
)
|
461 |
+
|
462 |
+
if knowledge.user_id != user.id and user.role != "admin":
|
463 |
+
raise HTTPException(
|
464 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
465 |
+
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
466 |
+
)
|
467 |
+
|
468 |
try:
|
469 |
VECTOR_DB_CLIENT.delete_collection(collection_name=id)
|
470 |
except Exception as e:
|
471 |
log.debug(e)
|
472 |
pass
|
473 |
+
result = Knowledges.delete_knowledge_by_id(id=id)
|
474 |
+
return result
|
|
|
|
|
|
|
475 |
|
476 |
|
477 |
############################
|
478 |
+
# ResetKnowledgeById
|
479 |
############################
|
480 |
|
481 |
|
482 |
+
@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse])
|
483 |
+
async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
484 |
+
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
485 |
+
if not knowledge:
|
486 |
+
raise HTTPException(
|
487 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
488 |
+
detail=ERROR_MESSAGES.NOT_FOUND,
|
489 |
+
)
|
490 |
+
|
491 |
+
if knowledge.user_id != user.id and user.role != "admin":
|
492 |
+
raise HTTPException(
|
493 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
494 |
+
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
495 |
+
)
|
496 |
+
|
497 |
try:
|
498 |
VECTOR_DB_CLIENT.delete_collection(collection_name=id)
|
499 |
except Exception as e:
|
500 |
log.debug(e)
|
501 |
pass
|
502 |
+
|
503 |
+
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
|
504 |
+
|
505 |
+
return knowledge
|
backend/open_webui/apps/webui/routers/models.py
CHANGED
@@ -4,53 +4,71 @@ from open_webui.apps.webui.models.models import (
|
|
4 |
ModelForm,
|
5 |
ModelModel,
|
6 |
ModelResponse,
|
|
|
7 |
Models,
|
8 |
)
|
9 |
from open_webui.constants import ERROR_MESSAGES
|
10 |
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
|
|
|
|
11 |
from open_webui.utils.utils import get_admin_user, get_verified_user
|
|
|
|
|
12 |
|
13 |
router = APIRouter()
|
14 |
|
|
|
15 |
###########################
|
16 |
-
#
|
17 |
###########################
|
18 |
|
19 |
|
20 |
-
@router.get("/", response_model=list[
|
21 |
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
|
22 |
-
if
|
23 |
-
|
24 |
-
if model:
|
25 |
-
return [model]
|
26 |
-
else:
|
27 |
-
raise HTTPException(
|
28 |
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
29 |
-
detail=ERROR_MESSAGES.NOT_FOUND,
|
30 |
-
)
|
31 |
else:
|
32 |
-
return Models.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
|
35 |
############################
|
36 |
-
#
|
37 |
############################
|
38 |
|
39 |
|
40 |
-
@router.post("/
|
41 |
-
async def
|
42 |
request: Request,
|
43 |
form_data: ModelForm,
|
44 |
-
user=Depends(
|
45 |
):
|
46 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
raise HTTPException(
|
48 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
49 |
detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
|
50 |
)
|
|
|
51 |
else:
|
52 |
model = Models.insert_new_model(form_data, user.id)
|
53 |
-
|
54 |
if model:
|
55 |
return model
|
56 |
else:
|
@@ -60,37 +78,84 @@ async def add_new_model(
|
|
60 |
)
|
61 |
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
############################
|
64 |
-
#
|
65 |
############################
|
66 |
|
67 |
|
68 |
-
@router.post("/
|
69 |
-
async def
|
70 |
-
request: Request,
|
71 |
-
id: str,
|
72 |
-
form_data: ModelForm,
|
73 |
-
user=Depends(get_admin_user),
|
74 |
-
):
|
75 |
model = Models.get_model_by_id(id)
|
76 |
if model:
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
82 |
if model:
|
83 |
return model
|
84 |
else:
|
85 |
raise HTTPException(
|
86 |
-
status_code=status.
|
87 |
-
detail=ERROR_MESSAGES.DEFAULT(),
|
88 |
)
|
89 |
else:
|
90 |
raise HTTPException(
|
91 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
92 |
-
detail=ERROR_MESSAGES.
|
93 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
|
96 |
############################
|
@@ -98,7 +163,26 @@ async def update_model_by_id(
|
|
98 |
############################
|
99 |
|
100 |
|
101 |
-
@router.delete("/delete", response_model=bool)
|
102 |
-
async def delete_model_by_id(id: str, user=Depends(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
result = Models.delete_model_by_id(id)
|
104 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
ModelForm,
|
5 |
ModelModel,
|
6 |
ModelResponse,
|
7 |
+
ModelUserResponse,
|
8 |
Models,
|
9 |
)
|
10 |
from open_webui.constants import ERROR_MESSAGES
|
11 |
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
12 |
+
|
13 |
+
|
14 |
from open_webui.utils.utils import get_admin_user, get_verified_user
|
15 |
+
from open_webui.utils.access_control import has_access, has_permission
|
16 |
+
|
17 |
|
18 |
router = APIRouter()
|
19 |
|
20 |
+
|
21 |
###########################
|
22 |
+
# GetModels
|
23 |
###########################
|
24 |
|
25 |
|
26 |
+
@router.get("/", response_model=list[ModelUserResponse])
|
27 |
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
|
28 |
+
if user.role == "admin":
|
29 |
+
return Models.get_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
else:
|
31 |
+
return Models.get_models_by_user_id(user.id)
|
32 |
+
|
33 |
+
|
34 |
+
###########################
|
35 |
+
# GetBaseModels
|
36 |
+
###########################
|
37 |
+
|
38 |
+
|
39 |
+
@router.get("/base", response_model=list[ModelResponse])
|
40 |
+
async def get_base_models(user=Depends(get_admin_user)):
|
41 |
+
return Models.get_base_models()
|
42 |
|
43 |
|
44 |
############################
|
45 |
+
# CreateNewModel
|
46 |
############################
|
47 |
|
48 |
|
49 |
+
@router.post("/create", response_model=Optional[ModelModel])
|
50 |
+
async def create_new_model(
|
51 |
request: Request,
|
52 |
form_data: ModelForm,
|
53 |
+
user=Depends(get_verified_user),
|
54 |
):
|
55 |
+
if user.role != "admin" and not has_permission(
|
56 |
+
user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS
|
57 |
+
):
|
58 |
+
raise HTTPException(
|
59 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
60 |
+
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
61 |
+
)
|
62 |
+
|
63 |
+
model = Models.get_model_by_id(form_data.id)
|
64 |
+
if model:
|
65 |
raise HTTPException(
|
66 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
67 |
detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
|
68 |
)
|
69 |
+
|
70 |
else:
|
71 |
model = Models.insert_new_model(form_data, user.id)
|
|
|
72 |
if model:
|
73 |
return model
|
74 |
else:
|
|
|
78 |
)
|
79 |
|
80 |
|
81 |
+
###########################
|
82 |
+
# GetModelById
|
83 |
+
###########################
|
84 |
+
|
85 |
+
|
86 |
+
@router.get("/id/{id}", response_model=Optional[ModelResponse])
|
87 |
+
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
|
88 |
+
model = Models.get_model_by_id(id)
|
89 |
+
if model:
|
90 |
+
if (
|
91 |
+
user.role == "admin"
|
92 |
+
or model.user_id == user.id
|
93 |
+
or has_access(user.id, "read", model.access_control)
|
94 |
+
):
|
95 |
+
return model
|
96 |
+
else:
|
97 |
+
raise HTTPException(
|
98 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
99 |
+
detail=ERROR_MESSAGES.NOT_FOUND,
|
100 |
+
)
|
101 |
+
|
102 |
+
|
103 |
############################
|
104 |
+
# ToggelModelById
|
105 |
############################
|
106 |
|
107 |
|
108 |
+
@router.post("/id/{id}/toggle", response_model=Optional[ModelResponse])
|
109 |
+
async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
|
|
|
|
|
|
|
|
|
|
|
110 |
model = Models.get_model_by_id(id)
|
111 |
if model:
|
112 |
+
if (
|
113 |
+
user.role == "admin"
|
114 |
+
or model.user_id == user.id
|
115 |
+
or has_access(user.id, "write", model.access_control)
|
116 |
+
):
|
117 |
+
model = Models.toggle_model_by_id(id)
|
118 |
+
|
119 |
if model:
|
120 |
return model
|
121 |
else:
|
122 |
raise HTTPException(
|
123 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
124 |
+
detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
|
125 |
)
|
126 |
else:
|
127 |
raise HTTPException(
|
128 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
129 |
+
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
130 |
)
|
131 |
+
else:
|
132 |
+
raise HTTPException(
|
133 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
134 |
+
detail=ERROR_MESSAGES.NOT_FOUND,
|
135 |
+
)
|
136 |
+
|
137 |
+
|
138 |
+
############################
|
139 |
+
# UpdateModelById
|
140 |
+
############################
|
141 |
+
|
142 |
+
|
143 |
+
@router.post("/id/{id}/update", response_model=Optional[ModelModel])
|
144 |
+
async def update_model_by_id(
|
145 |
+
id: str,
|
146 |
+
form_data: ModelForm,
|
147 |
+
user=Depends(get_verified_user),
|
148 |
+
):
|
149 |
+
model = Models.get_model_by_id(id)
|
150 |
+
|
151 |
+
if not model:
|
152 |
+
raise HTTPException(
|
153 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
154 |
+
detail=ERROR_MESSAGES.NOT_FOUND,
|
155 |
+
)
|
156 |
+
|
157 |
+
model = Models.update_model_by_id(id, form_data)
|
158 |
+
return model
|
159 |
|
160 |
|
161 |
############################
|
|
|
163 |
############################
|
164 |
|
165 |
|
166 |
+
@router.delete("/id/{id}/delete", response_model=bool)
|
167 |
+
async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
|
168 |
+
model = Models.get_model_by_id(id)
|
169 |
+
if not model:
|
170 |
+
raise HTTPException(
|
171 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
172 |
+
detail=ERROR_MESSAGES.NOT_FOUND,
|
173 |
+
)
|
174 |
+
|
175 |
+
if model.user_id != user.id and user.role != "admin":
|
176 |
+
raise HTTPException(
|
177 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
178 |
+
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
179 |
+
)
|
180 |
+
|
181 |
result = Models.delete_model_by_id(id)
|
182 |
return result
|
183 |
+
|
184 |
+
|
185 |
+
@router.delete("/delete/all", response_model=bool)
|
186 |
+
async def delete_all_models(user=Depends(get_admin_user)):
|
187 |
+
result = Models.delete_all_models()
|
188 |
+
return result
|
backend/open_webui/apps/webui/routers/prompts.py
CHANGED
@@ -1,9 +1,15 @@
|
|
1 |
from typing import Optional
|
2 |
|
3 |
-
from open_webui.apps.webui.models.prompts import
|
|
|
|
|
|
|
|
|
|
|
4 |
from open_webui.constants import ERROR_MESSAGES
|
5 |
-
from fastapi import APIRouter, Depends, HTTPException, status
|
6 |
from open_webui.utils.utils import get_admin_user, get_verified_user
|
|
|
7 |
|
8 |
router = APIRouter()
|
9 |
|
@@ -14,7 +20,22 @@ router = APIRouter()
|
|
14 |
|
15 |
@router.get("/", response_model=list[PromptModel])
|
16 |
async def get_prompts(user=Depends(get_verified_user)):
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
|
20 |
############################
|
@@ -23,7 +44,17 @@ async def get_prompts(user=Depends(get_verified_user)):
|
|
23 |
|
24 |
|
25 |
@router.post("/create", response_model=Optional[PromptModel])
|
26 |
-
async def create_new_prompt(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
prompt = Prompts.get_prompt_by_command(form_data.command)
|
28 |
if prompt is None:
|
29 |
prompt = Prompts.insert_new_prompt(user.id, form_data)
|
@@ -50,7 +81,12 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
|
|
50 |
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
51 |
|
52 |
if prompt:
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
54 |
else:
|
55 |
raise HTTPException(
|
56 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
@@ -67,8 +103,21 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
|
|
67 |
async def update_prompt_by_command(
|
68 |
command: str,
|
69 |
form_data: PromptForm,
|
70 |
-
user=Depends(
|
71 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
|
73 |
if prompt:
|
74 |
return prompt
|
@@ -85,6 +134,19 @@ async def update_prompt_by_command(
|
|
85 |
|
86 |
|
87 |
@router.delete("/command/{command}/delete", response_model=bool)
|
88 |
-
async def delete_prompt_by_command(command: str, user=Depends(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
result = Prompts.delete_prompt_by_command(f"/{command}")
|
90 |
return result
|
|
|
1 |
from typing import Optional
|
2 |
|
3 |
+
from open_webui.apps.webui.models.prompts import (
|
4 |
+
PromptForm,
|
5 |
+
PromptUserResponse,
|
6 |
+
PromptModel,
|
7 |
+
Prompts,
|
8 |
+
)
|
9 |
from open_webui.constants import ERROR_MESSAGES
|
10 |
+
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
11 |
from open_webui.utils.utils import get_admin_user, get_verified_user
|
12 |
+
from open_webui.utils.access_control import has_access, has_permission
|
13 |
|
14 |
router = APIRouter()
|
15 |
|
|
|
20 |
|
21 |
@router.get("/", response_model=list[PromptModel])
|
22 |
async def get_prompts(user=Depends(get_verified_user)):
|
23 |
+
if user.role == "admin":
|
24 |
+
prompts = Prompts.get_prompts()
|
25 |
+
else:
|
26 |
+
prompts = Prompts.get_prompts_by_user_id(user.id, "read")
|
27 |
+
|
28 |
+
return prompts
|
29 |
+
|
30 |
+
|
31 |
+
@router.get("/list", response_model=list[PromptUserResponse])
|
32 |
+
async def get_prompt_list(user=Depends(get_verified_user)):
|
33 |
+
if user.role == "admin":
|
34 |
+
prompts = Prompts.get_prompts()
|
35 |
+
else:
|
36 |
+
prompts = Prompts.get_prompts_by_user_id(user.id, "write")
|
37 |
+
|
38 |
+
return prompts
|
39 |
|
40 |
|
41 |
############################
|
|
|
44 |
|
45 |
|
46 |
@router.post("/create", response_model=Optional[PromptModel])
|
47 |
+
async def create_new_prompt(
|
48 |
+
request: Request, form_data: PromptForm, user=Depends(get_verified_user)
|
49 |
+
):
|
50 |
+
if user.role != "admin" and not has_permission(
|
51 |
+
user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS
|
52 |
+
):
|
53 |
+
raise HTTPException(
|
54 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
55 |
+
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
56 |
+
)
|
57 |
+
|
58 |
prompt = Prompts.get_prompt_by_command(form_data.command)
|
59 |
if prompt is None:
|
60 |
prompt = Prompts.insert_new_prompt(user.id, form_data)
|
|
|
81 |
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
82 |
|
83 |
if prompt:
|
84 |
+
if (
|
85 |
+
user.role == "admin"
|
86 |
+
or prompt.user_id == user.id
|
87 |
+
or has_access(user.id, "read", prompt.access_control)
|
88 |
+
):
|
89 |
+
return prompt
|
90 |
else:
|
91 |
raise HTTPException(
|
92 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
103 |
async def update_prompt_by_command(
|
104 |
command: str,
|
105 |
form_data: PromptForm,
|
106 |
+
user=Depends(get_verified_user),
|
107 |
):
|
108 |
+
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
109 |
+
if not prompt:
|
110 |
+
raise HTTPException(
|
111 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
112 |
+
detail=ERROR_MESSAGES.NOT_FOUND,
|
113 |
+
)
|
114 |
+
|
115 |
+
if prompt.user_id != user.id and user.role != "admin":
|
116 |
+
raise HTTPException(
|
117 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
118 |
+
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
119 |
+
)
|
120 |
+
|
121 |
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
|
122 |
if prompt:
|
123 |
return prompt
|
|
|
134 |
|
135 |
|
136 |
@router.delete("/command/{command}/delete", response_model=bool)
|
137 |
+
async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)):
|
138 |
+
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
139 |
+
if not prompt:
|
140 |
+
raise HTTPException(
|
141 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
142 |
+
detail=ERROR_MESSAGES.NOT_FOUND,
|
143 |
+
)
|
144 |
+
|
145 |
+
if prompt.user_id != user.id and user.role != "admin":
|
146 |
+
raise HTTPException(
|
147 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
148 |
+
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
149 |
+
)
|
150 |
+
|
151 |
result = Prompts.delete_prompt_by_command(f"/{command}")
|
152 |
return result
|
backend/open_webui/apps/webui/routers/tools.py
CHANGED
@@ -2,50 +2,82 @@ import os
|
|
2 |
from pathlib import Path
|
3 |
from typing import Optional
|
4 |
|
5 |
-
from open_webui.apps.webui.models.tools import
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
from open_webui.config import CACHE_DIR, DATA_DIR
|
8 |
from open_webui.constants import ERROR_MESSAGES
|
9 |
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
10 |
from open_webui.utils.tools import get_tools_specs
|
11 |
from open_webui.utils.utils import get_admin_user, get_verified_user
|
|
|
12 |
|
13 |
|
14 |
router = APIRouter()
|
15 |
|
16 |
############################
|
17 |
-
#
|
18 |
############################
|
19 |
|
20 |
|
21 |
-
@router.get("/", response_model=list[
|
22 |
-
async def
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
25 |
|
26 |
|
27 |
############################
|
28 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
############################
|
30 |
|
31 |
|
32 |
@router.get("/export", response_model=list[ToolModel])
|
33 |
-
async def
|
34 |
-
|
35 |
-
return
|
36 |
|
37 |
|
38 |
############################
|
39 |
-
#
|
40 |
############################
|
41 |
|
42 |
|
43 |
@router.post("/create", response_model=Optional[ToolResponse])
|
44 |
-
async def
|
45 |
request: Request,
|
46 |
form_data: ToolForm,
|
47 |
-
user=Depends(
|
48 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
if not form_data.id.isidentifier():
|
50 |
raise HTTPException(
|
51 |
status_code=status.HTTP_400_BAD_REQUEST,
|
@@ -54,30 +86,30 @@ async def create_new_toolkit(
|
|
54 |
|
55 |
form_data.id = form_data.id.lower()
|
56 |
|
57 |
-
|
58 |
-
if
|
59 |
try:
|
60 |
form_data.content = replace_imports(form_data.content)
|
61 |
-
|
62 |
form_data.id, content=form_data.content
|
63 |
)
|
64 |
form_data.meta.manifest = frontmatter
|
65 |
|
66 |
TOOLS = request.app.state.TOOLS
|
67 |
-
TOOLS[form_data.id] =
|
68 |
|
69 |
specs = get_tools_specs(TOOLS[form_data.id])
|
70 |
-
|
71 |
|
72 |
tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
|
73 |
tool_cache_dir.mkdir(parents=True, exist_ok=True)
|
74 |
|
75 |
-
if
|
76 |
-
return
|
77 |
else:
|
78 |
raise HTTPException(
|
79 |
status_code=status.HTTP_400_BAD_REQUEST,
|
80 |
-
detail=ERROR_MESSAGES.DEFAULT("Error creating
|
81 |
)
|
82 |
except Exception as e:
|
83 |
print(e)
|
@@ -93,16 +125,21 @@ async def create_new_toolkit(
|
|
93 |
|
94 |
|
95 |
############################
|
96 |
-
#
|
97 |
############################
|
98 |
|
99 |
|
100 |
@router.get("/id/{id}", response_model=Optional[ToolModel])
|
101 |
-
async def
|
102 |
-
|
103 |
-
|
104 |
-
if
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
106 |
else:
|
107 |
raise HTTPException(
|
108 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
@@ -111,26 +148,39 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
|
|
111 |
|
112 |
|
113 |
############################
|
114 |
-
#
|
115 |
############################
|
116 |
|
117 |
|
118 |
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
|
119 |
-
async def
|
120 |
request: Request,
|
121 |
id: str,
|
122 |
form_data: ToolForm,
|
123 |
-
user=Depends(
|
124 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
try:
|
126 |
form_data.content = replace_imports(form_data.content)
|
127 |
-
|
128 |
id, content=form_data.content
|
129 |
)
|
130 |
form_data.meta.manifest = frontmatter
|
131 |
|
132 |
TOOLS = request.app.state.TOOLS
|
133 |
-
TOOLS[id] =
|
134 |
|
135 |
specs = get_tools_specs(TOOLS[id])
|
136 |
|
@@ -140,14 +190,14 @@ async def update_toolkit_by_id(
|
|
140 |
}
|
141 |
|
142 |
print(updated)
|
143 |
-
|
144 |
|
145 |
-
if
|
146 |
-
return
|
147 |
else:
|
148 |
raise HTTPException(
|
149 |
status_code=status.HTTP_400_BAD_REQUEST,
|
150 |
-
detail=ERROR_MESSAGES.DEFAULT("Error updating
|
151 |
)
|
152 |
|
153 |
except Exception as e:
|
@@ -158,14 +208,28 @@ async def update_toolkit_by_id(
|
|
158 |
|
159 |
|
160 |
############################
|
161 |
-
#
|
162 |
############################
|
163 |
|
164 |
|
165 |
@router.delete("/id/{id}/delete", response_model=bool)
|
166 |
-
async def
|
167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
if result:
|
170 |
TOOLS = request.app.state.TOOLS
|
171 |
if id in TOOLS:
|
@@ -180,9 +244,9 @@ async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin
|
|
180 |
|
181 |
|
182 |
@router.get("/id/{id}/valves", response_model=Optional[dict])
|
183 |
-
async def
|
184 |
-
|
185 |
-
if
|
186 |
try:
|
187 |
valves = Tools.get_tool_valves_by_id(id)
|
188 |
return valves
|
@@ -204,19 +268,19 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
|
|
204 |
|
205 |
|
206 |
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
|
207 |
-
async def
|
208 |
-
request: Request, id: str, user=Depends(
|
209 |
):
|
210 |
-
|
211 |
-
if
|
212 |
if id in request.app.state.TOOLS:
|
213 |
-
|
214 |
else:
|
215 |
-
|
216 |
-
request.app.state.TOOLS[id] =
|
217 |
|
218 |
-
if hasattr(
|
219 |
-
Valves =
|
220 |
return Valves.schema()
|
221 |
return None
|
222 |
else:
|
@@ -232,19 +296,19 @@ async def get_toolkit_valves_spec_by_id(
|
|
232 |
|
233 |
|
234 |
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
|
235 |
-
async def
|
236 |
-
request: Request, id: str, form_data: dict, user=Depends(
|
237 |
):
|
238 |
-
|
239 |
-
if
|
240 |
if id in request.app.state.TOOLS:
|
241 |
-
|
242 |
else:
|
243 |
-
|
244 |
-
request.app.state.TOOLS[id] =
|
245 |
|
246 |
-
if hasattr(
|
247 |
-
Valves =
|
248 |
|
249 |
try:
|
250 |
form_data = {k: v for k, v in form_data.items() if v is not None}
|
@@ -276,9 +340,9 @@ async def update_toolkit_valves_by_id(
|
|
276 |
|
277 |
|
278 |
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
|
279 |
-
async def
|
280 |
-
|
281 |
-
if
|
282 |
try:
|
283 |
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
|
284 |
return user_valves
|
@@ -295,19 +359,19 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)
|
|
295 |
|
296 |
|
297 |
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
|
298 |
-
async def
|
299 |
request: Request, id: str, user=Depends(get_verified_user)
|
300 |
):
|
301 |
-
|
302 |
-
if
|
303 |
if id in request.app.state.TOOLS:
|
304 |
-
|
305 |
else:
|
306 |
-
|
307 |
-
request.app.state.TOOLS[id] =
|
308 |
|
309 |
-
if hasattr(
|
310 |
-
UserValves =
|
311 |
return UserValves.schema()
|
312 |
return None
|
313 |
else:
|
@@ -318,20 +382,20 @@ async def get_toolkit_user_valves_spec_by_id(
|
|
318 |
|
319 |
|
320 |
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
|
321 |
-
async def
|
322 |
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
|
323 |
):
|
324 |
-
|
325 |
|
326 |
-
if
|
327 |
if id in request.app.state.TOOLS:
|
328 |
-
|
329 |
else:
|
330 |
-
|
331 |
-
request.app.state.TOOLS[id] =
|
332 |
|
333 |
-
if hasattr(
|
334 |
-
UserValves =
|
335 |
|
336 |
try:
|
337 |
form_data = {k: v for k, v in form_data.items() if v is not None}
|
|
|
2 |
from pathlib import Path
|
3 |
from typing import Optional
|
4 |
|
5 |
+
from open_webui.apps.webui.models.tools import (
|
6 |
+
ToolForm,
|
7 |
+
ToolModel,
|
8 |
+
ToolResponse,
|
9 |
+
ToolUserResponse,
|
10 |
+
Tools,
|
11 |
+
)
|
12 |
+
from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports
|
13 |
from open_webui.config import CACHE_DIR, DATA_DIR
|
14 |
from open_webui.constants import ERROR_MESSAGES
|
15 |
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
16 |
from open_webui.utils.tools import get_tools_specs
|
17 |
from open_webui.utils.utils import get_admin_user, get_verified_user
|
18 |
+
from open_webui.utils.access_control import has_access, has_permission
|
19 |
|
20 |
|
21 |
router = APIRouter()
|
22 |
|
23 |
############################
|
24 |
+
# GetTools
|
25 |
############################
|
26 |
|
27 |
|
28 |
+
@router.get("/", response_model=list[ToolUserResponse])
|
29 |
+
async def get_tools(user=Depends(get_verified_user)):
|
30 |
+
if user.role == "admin":
|
31 |
+
tools = Tools.get_tools()
|
32 |
+
else:
|
33 |
+
tools = Tools.get_tools_by_user_id(user.id, "read")
|
34 |
+
return tools
|
35 |
|
36 |
|
37 |
############################
|
38 |
+
# GetToolList
|
39 |
+
############################
|
40 |
+
|
41 |
+
|
42 |
+
@router.get("/list", response_model=list[ToolUserResponse])
|
43 |
+
async def get_tool_list(user=Depends(get_verified_user)):
|
44 |
+
if user.role == "admin":
|
45 |
+
tools = Tools.get_tools()
|
46 |
+
else:
|
47 |
+
tools = Tools.get_tools_by_user_id(user.id, "write")
|
48 |
+
return tools
|
49 |
+
|
50 |
+
|
51 |
+
############################
|
52 |
+
# ExportTools
|
53 |
############################
|
54 |
|
55 |
|
56 |
@router.get("/export", response_model=list[ToolModel])
|
57 |
+
async def export_tools(user=Depends(get_admin_user)):
|
58 |
+
tools = Tools.get_tools()
|
59 |
+
return tools
|
60 |
|
61 |
|
62 |
############################
|
63 |
+
# CreateNewTools
|
64 |
############################
|
65 |
|
66 |
|
67 |
@router.post("/create", response_model=Optional[ToolResponse])
|
68 |
+
async def create_new_tools(
|
69 |
request: Request,
|
70 |
form_data: ToolForm,
|
71 |
+
user=Depends(get_verified_user),
|
72 |
):
|
73 |
+
if user.role != "admin" and not has_permission(
|
74 |
+
user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS
|
75 |
+
):
|
76 |
+
raise HTTPException(
|
77 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
78 |
+
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
79 |
+
)
|
80 |
+
|
81 |
if not form_data.id.isidentifier():
|
82 |
raise HTTPException(
|
83 |
status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
86 |
|
87 |
form_data.id = form_data.id.lower()
|
88 |
|
89 |
+
tools = Tools.get_tool_by_id(form_data.id)
|
90 |
+
if tools is None:
|
91 |
try:
|
92 |
form_data.content = replace_imports(form_data.content)
|
93 |
+
tools_module, frontmatter = load_tools_module_by_id(
|
94 |
form_data.id, content=form_data.content
|
95 |
)
|
96 |
form_data.meta.manifest = frontmatter
|
97 |
|
98 |
TOOLS = request.app.state.TOOLS
|
99 |
+
TOOLS[form_data.id] = tools_module
|
100 |
|
101 |
specs = get_tools_specs(TOOLS[form_data.id])
|
102 |
+
tools = Tools.insert_new_tool(user.id, form_data, specs)
|
103 |
|
104 |
tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
|
105 |
tool_cache_dir.mkdir(parents=True, exist_ok=True)
|
106 |
|
107 |
+
if tools:
|
108 |
+
return tools
|
109 |
else:
|
110 |
raise HTTPException(
|
111 |
status_code=status.HTTP_400_BAD_REQUEST,
|
112 |
+
detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
|
113 |
)
|
114 |
except Exception as e:
|
115 |
print(e)
|
|
|
125 |
|
126 |
|
127 |
############################
|
128 |
+
# GetToolsById
|
129 |
############################
|
130 |
|
131 |
|
132 |
@router.get("/id/{id}", response_model=Optional[ToolModel])
|
133 |
+
async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
|
134 |
+
tools = Tools.get_tool_by_id(id)
|
135 |
+
|
136 |
+
if tools:
|
137 |
+
if (
|
138 |
+
user.role == "admin"
|
139 |
+
or tools.user_id == user.id
|
140 |
+
or has_access(user.id, "read", tools.access_control)
|
141 |
+
):
|
142 |
+
return tools
|
143 |
else:
|
144 |
raise HTTPException(
|
145 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
148 |
|
149 |
|
150 |
############################
|
151 |
+
# UpdateToolsById
|
152 |
############################
|
153 |
|
154 |
|
155 |
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
|
156 |
+
async def update_tools_by_id(
|
157 |
request: Request,
|
158 |
id: str,
|
159 |
form_data: ToolForm,
|
160 |
+
user=Depends(get_verified_user),
|
161 |
):
|
162 |
+
tools = Tools.get_tool_by_id(id)
|
163 |
+
if not tools:
|
164 |
+
raise HTTPException(
|
165 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
166 |
+
detail=ERROR_MESSAGES.NOT_FOUND,
|
167 |
+
)
|
168 |
+
|
169 |
+
if tools.user_id != user.id and user.role != "admin":
|
170 |
+
raise HTTPException(
|
171 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
172 |
+
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
173 |
+
)
|
174 |
+
|
175 |
try:
|
176 |
form_data.content = replace_imports(form_data.content)
|
177 |
+
tools_module, frontmatter = load_tools_module_by_id(
|
178 |
id, content=form_data.content
|
179 |
)
|
180 |
form_data.meta.manifest = frontmatter
|
181 |
|
182 |
TOOLS = request.app.state.TOOLS
|
183 |
+
TOOLS[id] = tools_module
|
184 |
|
185 |
specs = get_tools_specs(TOOLS[id])
|
186 |
|
|
|
190 |
}
|
191 |
|
192 |
print(updated)
|
193 |
+
tools = Tools.update_tool_by_id(id, updated)
|
194 |
|
195 |
+
if tools:
|
196 |
+
return tools
|
197 |
else:
|
198 |
raise HTTPException(
|
199 |
status_code=status.HTTP_400_BAD_REQUEST,
|
200 |
+
detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
|
201 |
)
|
202 |
|
203 |
except Exception as e:
|
|
|
208 |
|
209 |
|
210 |
############################
|
211 |
+
# DeleteToolsById
|
212 |
############################
|
213 |
|
214 |
|
215 |
@router.delete("/id/{id}/delete", response_model=bool)
|
216 |
+
async def delete_tools_by_id(
|
217 |
+
request: Request, id: str, user=Depends(get_verified_user)
|
218 |
+
):
|
219 |
+
tools = Tools.get_tool_by_id(id)
|
220 |
+
if not tools:
|
221 |
+
raise HTTPException(
|
222 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
223 |
+
detail=ERROR_MESSAGES.NOT_FOUND,
|
224 |
+
)
|
225 |
|
226 |
+
if tools.user_id != user.id and user.role != "admin":
|
227 |
+
raise HTTPException(
|
228 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
229 |
+
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
230 |
+
)
|
231 |
+
|
232 |
+
result = Tools.delete_tool_by_id(id)
|
233 |
if result:
|
234 |
TOOLS = request.app.state.TOOLS
|
235 |
if id in TOOLS:
|
|
|
244 |
|
245 |
|
246 |
@router.get("/id/{id}/valves", response_model=Optional[dict])
|
247 |
+
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
|
248 |
+
tools = Tools.get_tool_by_id(id)
|
249 |
+
if tools:
|
250 |
try:
|
251 |
valves = Tools.get_tool_valves_by_id(id)
|
252 |
return valves
|
|
|
268 |
|
269 |
|
270 |
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
|
271 |
+
async def get_tools_valves_spec_by_id(
|
272 |
+
request: Request, id: str, user=Depends(get_verified_user)
|
273 |
):
|
274 |
+
tools = Tools.get_tool_by_id(id)
|
275 |
+
if tools:
|
276 |
if id in request.app.state.TOOLS:
|
277 |
+
tools_module = request.app.state.TOOLS[id]
|
278 |
else:
|
279 |
+
tools_module, _ = load_tools_module_by_id(id)
|
280 |
+
request.app.state.TOOLS[id] = tools_module
|
281 |
|
282 |
+
if hasattr(tools_module, "Valves"):
|
283 |
+
Valves = tools_module.Valves
|
284 |
return Valves.schema()
|
285 |
return None
|
286 |
else:
|
|
|
296 |
|
297 |
|
298 |
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
|
299 |
+
async def update_tools_valves_by_id(
|
300 |
+
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
|
301 |
):
|
302 |
+
tools = Tools.get_tool_by_id(id)
|
303 |
+
if tools:
|
304 |
if id in request.app.state.TOOLS:
|
305 |
+
tools_module = request.app.state.TOOLS[id]
|
306 |
else:
|
307 |
+
tools_module, _ = load_tools_module_by_id(id)
|
308 |
+
request.app.state.TOOLS[id] = tools_module
|
309 |
|
310 |
+
if hasattr(tools_module, "Valves"):
|
311 |
+
Valves = tools_module.Valves
|
312 |
|
313 |
try:
|
314 |
form_data = {k: v for k, v in form_data.items() if v is not None}
|
|
|
340 |
|
341 |
|
342 |
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
|
343 |
+
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
|
344 |
+
tools = Tools.get_tool_by_id(id)
|
345 |
+
if tools:
|
346 |
try:
|
347 |
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
|
348 |
return user_valves
|
|
|
359 |
|
360 |
|
361 |
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
|
362 |
+
async def get_tools_user_valves_spec_by_id(
|
363 |
request: Request, id: str, user=Depends(get_verified_user)
|
364 |
):
|
365 |
+
tools = Tools.get_tool_by_id(id)
|
366 |
+
if tools:
|
367 |
if id in request.app.state.TOOLS:
|
368 |
+
tools_module = request.app.state.TOOLS[id]
|
369 |
else:
|
370 |
+
tools_module, _ = load_tools_module_by_id(id)
|
371 |
+
request.app.state.TOOLS[id] = tools_module
|
372 |
|
373 |
+
if hasattr(tools_module, "UserValves"):
|
374 |
+
UserValves = tools_module.UserValves
|
375 |
return UserValves.schema()
|
376 |
return None
|
377 |
else:
|
|
|
382 |
|
383 |
|
384 |
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
|
385 |
+
async def update_tools_user_valves_by_id(
|
386 |
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
|
387 |
):
|
388 |
+
tools = Tools.get_tool_by_id(id)
|
389 |
|
390 |
+
if tools:
|
391 |
if id in request.app.state.TOOLS:
|
392 |
+
tools_module = request.app.state.TOOLS[id]
|
393 |
else:
|
394 |
+
tools_module, _ = load_tools_module_by_id(id)
|
395 |
+
request.app.state.TOOLS[id] = tools_module
|
396 |
|
397 |
+
if hasattr(tools_module, "UserValves"):
|
398 |
+
UserValves = tools_module.UserValves
|
399 |
|
400 |
try:
|
401 |
form_data = {k: v for k, v in form_data.items() if v is not None}
|
backend/open_webui/apps/webui/routers/users.py
CHANGED
@@ -31,21 +31,58 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
|
|
31 |
return Users.get_users(skip, limit)
|
32 |
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
############################
|
35 |
# User Permissions
|
36 |
############################
|
37 |
|
38 |
|
39 |
-
@router.get("/permissions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
|
41 |
return request.app.state.config.USER_PERMISSIONS
|
42 |
|
43 |
|
44 |
-
@router.post("/permissions
|
45 |
async def update_user_permissions(
|
46 |
-
request: Request, form_data:
|
47 |
):
|
48 |
-
request.app.state.config.USER_PERMISSIONS = form_data
|
49 |
return request.app.state.config.USER_PERMISSIONS
|
50 |
|
51 |
|
|
|
31 |
return Users.get_users(skip, limit)
|
32 |
|
33 |
|
34 |
+
############################
|
35 |
+
# User Groups
|
36 |
+
############################
|
37 |
+
|
38 |
+
|
39 |
+
@router.get("/groups")
|
40 |
+
async def get_user_groups(user=Depends(get_verified_user)):
|
41 |
+
return Users.get_user_groups(user.id)
|
42 |
+
|
43 |
+
|
44 |
############################
|
45 |
# User Permissions
|
46 |
############################
|
47 |
|
48 |
|
49 |
+
@router.get("/permissions")
|
50 |
+
async def get_user_permissisions(user=Depends(get_verified_user)):
|
51 |
+
return Users.get_user_groups(user.id)
|
52 |
+
|
53 |
+
|
54 |
+
############################
|
55 |
+
# User Default Permissions
|
56 |
+
############################
|
57 |
+
class WorkspacePermissions(BaseModel):
|
58 |
+
models: bool
|
59 |
+
knowledge: bool
|
60 |
+
prompts: bool
|
61 |
+
tools: bool
|
62 |
+
|
63 |
+
|
64 |
+
class ChatPermissions(BaseModel):
|
65 |
+
file_upload: bool
|
66 |
+
delete: bool
|
67 |
+
edit: bool
|
68 |
+
temporary: bool
|
69 |
+
|
70 |
+
|
71 |
+
class UserPermissions(BaseModel):
|
72 |
+
workspace: WorkspacePermissions
|
73 |
+
chat: ChatPermissions
|
74 |
+
|
75 |
+
|
76 |
+
@router.get("/default/permissions")
|
77 |
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
|
78 |
return request.app.state.config.USER_PERMISSIONS
|
79 |
|
80 |
|
81 |
+
@router.post("/default/permissions")
|
82 |
async def update_user_permissions(
|
83 |
+
request: Request, form_data: UserPermissions, user=Depends(get_admin_user)
|
84 |
):
|
85 |
+
request.app.state.config.USER_PERMISSIONS = form_data.model_dump()
|
86 |
return request.app.state.config.USER_PERMISSIONS
|
87 |
|
88 |
|
backend/open_webui/apps/webui/utils.py
CHANGED
@@ -63,7 +63,7 @@ def replace_imports(content):
|
|
63 |
return content
|
64 |
|
65 |
|
66 |
-
def
|
67 |
|
68 |
if content is None:
|
69 |
tool = Tools.get_tool_by_id(toolkit_id)
|
|
|
63 |
return content
|
64 |
|
65 |
|
66 |
+
def load_tools_module_by_id(toolkit_id, content=None):
|
67 |
|
68 |
if content is None:
|
69 |
tool = Tools.get_tool_by_id(toolkit_id)
|
backend/open_webui/config.py
CHANGED
@@ -20,6 +20,7 @@ from open_webui.env import (
|
|
20 |
WEBUI_FAVICON_URL,
|
21 |
WEBUI_NAME,
|
22 |
log,
|
|
|
23 |
)
|
24 |
from pydantic import BaseModel
|
25 |
from sqlalchemy import JSON, Column, DateTime, Integer, func
|
@@ -264,6 +265,13 @@ class AppConfig:
|
|
264 |
# WEBUI_AUTH (Required for security)
|
265 |
####################################
|
266 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
JWT_EXPIRES_IN = PersistentConfig(
|
268 |
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
|
269 |
)
|
@@ -606,6 +614,12 @@ OLLAMA_BASE_URLS = PersistentConfig(
|
|
606 |
"OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
|
607 |
)
|
608 |
|
|
|
|
|
|
|
|
|
|
|
|
|
609 |
####################################
|
610 |
# OPENAI_API
|
611 |
####################################
|
@@ -646,15 +660,20 @@ OPENAI_API_BASE_URLS = PersistentConfig(
|
|
646 |
"OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
|
647 |
)
|
648 |
|
649 |
-
|
|
|
|
|
|
|
|
|
650 |
|
|
|
|
|
651 |
try:
|
652 |
OPENAI_API_KEY = OPENAI_API_KEYS.value[
|
653 |
OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
|
654 |
]
|
655 |
except Exception:
|
656 |
pass
|
657 |
-
|
658 |
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
|
659 |
|
660 |
####################################
|
@@ -727,12 +746,36 @@ DEFAULT_USER_ROLE = PersistentConfig(
|
|
727 |
os.getenv("DEFAULT_USER_ROLE", "pending"),
|
728 |
)
|
729 |
|
730 |
-
|
731 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
732 |
)
|
733 |
|
734 |
-
|
735 |
-
os.environ.get("
|
736 |
)
|
737 |
|
738 |
USER_PERMISSIONS_CHAT_TEMPORARY = (
|
@@ -741,13 +784,20 @@ USER_PERMISSIONS_CHAT_TEMPORARY = (
|
|
741 |
|
742 |
USER_PERMISSIONS = PersistentConfig(
|
743 |
"USER_PERMISSIONS",
|
744 |
-
"
|
745 |
{
|
|
|
|
|
|
|
|
|
|
|
|
|
746 |
"chat": {
|
747 |
-
"
|
748 |
-
"
|
|
|
749 |
"temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
|
750 |
-
}
|
751 |
},
|
752 |
)
|
753 |
|
@@ -773,18 +823,6 @@ DEFAULT_ARENA_MODEL = {
|
|
773 |
},
|
774 |
}
|
775 |
|
776 |
-
ENABLE_MODEL_FILTER = PersistentConfig(
|
777 |
-
"ENABLE_MODEL_FILTER",
|
778 |
-
"model_filter.enable",
|
779 |
-
os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true",
|
780 |
-
)
|
781 |
-
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
|
782 |
-
MODEL_FILTER_LIST = PersistentConfig(
|
783 |
-
"MODEL_FILTER_LIST",
|
784 |
-
"model_filter.list",
|
785 |
-
[model.strip() for model in MODEL_FILTER_LIST.split(";")],
|
786 |
-
)
|
787 |
-
|
788 |
WEBHOOK_URL = PersistentConfig(
|
789 |
"WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
|
790 |
)
|
@@ -904,19 +942,55 @@ TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
|
|
904 |
os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""),
|
905 |
)
|
906 |
|
907 |
-
|
908 |
-
"
|
909 |
-
"task.
|
910 |
-
os.environ.get("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
911 |
)
|
912 |
|
913 |
|
914 |
-
|
915 |
-
"
|
916 |
-
"task.
|
917 |
-
os.environ.get("
|
918 |
)
|
919 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
920 |
|
921 |
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
|
922 |
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
|
@@ -956,6 +1030,21 @@ MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
|
|
956 |
|
957 |
# Qdrant
|
958 |
QDRANT_URI = os.environ.get("QDRANT_URI", None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
959 |
|
960 |
####################################
|
961 |
# Information Retrieval (RAG)
|
@@ -1035,11 +1124,11 @@ RAG_EMBEDDING_MODEL = PersistentConfig(
|
|
1035 |
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")
|
1036 |
|
1037 |
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
|
1038 |
-
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
|
1039 |
)
|
1040 |
|
1041 |
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
|
1042 |
-
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
|
1043 |
)
|
1044 |
|
1045 |
RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
|
@@ -1060,11 +1149,11 @@ if RAG_RERANKING_MODEL.value != "":
|
|
1060 |
log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
|
1061 |
|
1062 |
RAG_RERANKING_MODEL_AUTO_UPDATE = (
|
1063 |
-
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
|
1064 |
)
|
1065 |
|
1066 |
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
|
1067 |
-
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
|
1068 |
)
|
1069 |
|
1070 |
|
@@ -1129,6 +1218,19 @@ RAG_OPENAI_API_KEY = PersistentConfig(
|
|
1129 |
os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
|
1130 |
)
|
1131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1132 |
ENABLE_RAG_LOCAL_WEB_FETCH = (
|
1133 |
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
|
1134 |
)
|
@@ -1218,6 +1320,12 @@ TAVILY_API_KEY = PersistentConfig(
|
|
1218 |
os.getenv("TAVILY_API_KEY", ""),
|
1219 |
)
|
1220 |
|
|
|
|
|
|
|
|
|
|
|
|
|
1221 |
SEARCHAPI_API_KEY = PersistentConfig(
|
1222 |
"SEARCHAPI_API_KEY",
|
1223 |
"rag.web.search.searchapi_api_key",
|
@@ -1230,6 +1338,21 @@ SEARCHAPI_ENGINE = PersistentConfig(
|
|
1230 |
os.getenv("SEARCHAPI_ENGINE", ""),
|
1231 |
)
|
1232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1233 |
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
|
1234 |
"RAG_WEB_SEARCH_RESULT_COUNT",
|
1235 |
"rag.web.search.result_count",
|
@@ -1281,7 +1404,7 @@ AUTOMATIC1111_CFG_SCALE = PersistentConfig(
|
|
1281 |
|
1282 |
|
1283 |
AUTOMATIC1111_SAMPLER = PersistentConfig(
|
1284 |
-
"
|
1285 |
"image_generation.automatic1111.sampler",
|
1286 |
(
|
1287 |
os.environ.get("AUTOMATIC1111_SAMPLER")
|
@@ -1550,3 +1673,74 @@ AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT = PersistentConfig(
|
|
1550 |
"AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT", "audio-24khz-160kbitrate-mono-mp3"
|
1551 |
),
|
1552 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
WEBUI_FAVICON_URL,
|
21 |
WEBUI_NAME,
|
22 |
log,
|
23 |
+
DATABASE_URL,
|
24 |
)
|
25 |
from pydantic import BaseModel
|
26 |
from sqlalchemy import JSON, Column, DateTime, Integer, func
|
|
|
265 |
# WEBUI_AUTH (Required for security)
|
266 |
####################################
|
267 |
|
268 |
+
ENABLE_API_KEY = PersistentConfig(
|
269 |
+
"ENABLE_API_KEY",
|
270 |
+
"auth.api_key.enable",
|
271 |
+
os.environ.get("ENABLE_API_KEY", "True").lower() == "true",
|
272 |
+
)
|
273 |
+
|
274 |
+
|
275 |
JWT_EXPIRES_IN = PersistentConfig(
|
276 |
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
|
277 |
)
|
|
|
614 |
"OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
|
615 |
)
|
616 |
|
617 |
+
OLLAMA_API_CONFIGS = PersistentConfig(
|
618 |
+
"OLLAMA_API_CONFIGS",
|
619 |
+
"ollama.api_configs",
|
620 |
+
{},
|
621 |
+
)
|
622 |
+
|
623 |
####################################
|
624 |
# OPENAI_API
|
625 |
####################################
|
|
|
660 |
"OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
|
661 |
)
|
662 |
|
663 |
+
OPENAI_API_CONFIGS = PersistentConfig(
|
664 |
+
"OPENAI_API_CONFIGS",
|
665 |
+
"openai.api_configs",
|
666 |
+
{},
|
667 |
+
)
|
668 |
|
669 |
+
# Get the actual OpenAI API key based on the base URL
|
670 |
+
OPENAI_API_KEY = ""
|
671 |
try:
|
672 |
OPENAI_API_KEY = OPENAI_API_KEYS.value[
|
673 |
OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
|
674 |
]
|
675 |
except Exception:
|
676 |
pass
|
|
|
677 |
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
|
678 |
|
679 |
####################################
|
|
|
746 |
os.getenv("DEFAULT_USER_ROLE", "pending"),
|
747 |
)
|
748 |
|
749 |
+
|
750 |
+
USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = (
|
751 |
+
os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower()
|
752 |
+
== "true"
|
753 |
+
)
|
754 |
+
|
755 |
+
USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS = (
|
756 |
+
os.environ.get("USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS", "False").lower()
|
757 |
+
== "true"
|
758 |
+
)
|
759 |
+
|
760 |
+
USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS = (
|
761 |
+
os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS", "False").lower()
|
762 |
+
== "true"
|
763 |
+
)
|
764 |
+
|
765 |
+
USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS = (
|
766 |
+
os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true"
|
767 |
+
)
|
768 |
+
|
769 |
+
USER_PERMISSIONS_CHAT_FILE_UPLOAD = (
|
770 |
+
os.environ.get("USER_PERMISSIONS_CHAT_FILE_UPLOAD", "True").lower() == "true"
|
771 |
+
)
|
772 |
+
|
773 |
+
USER_PERMISSIONS_CHAT_DELETE = (
|
774 |
+
os.environ.get("USER_PERMISSIONS_CHAT_DELETE", "True").lower() == "true"
|
775 |
)
|
776 |
|
777 |
+
USER_PERMISSIONS_CHAT_EDIT = (
|
778 |
+
os.environ.get("USER_PERMISSIONS_CHAT_EDIT", "True").lower() == "true"
|
779 |
)
|
780 |
|
781 |
USER_PERMISSIONS_CHAT_TEMPORARY = (
|
|
|
784 |
|
785 |
USER_PERMISSIONS = PersistentConfig(
|
786 |
"USER_PERMISSIONS",
|
787 |
+
"user.permissions",
|
788 |
{
|
789 |
+
"workspace": {
|
790 |
+
"models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS,
|
791 |
+
"knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS,
|
792 |
+
"prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS,
|
793 |
+
"tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS,
|
794 |
+
},
|
795 |
"chat": {
|
796 |
+
"file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD,
|
797 |
+
"delete": USER_PERMISSIONS_CHAT_DELETE,
|
798 |
+
"edit": USER_PERMISSIONS_CHAT_EDIT,
|
799 |
"temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
|
800 |
+
},
|
801 |
},
|
802 |
)
|
803 |
|
|
|
823 |
},
|
824 |
}
|
825 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
826 |
WEBHOOK_URL = PersistentConfig(
|
827 |
"WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
|
828 |
)
|
|
|
942 |
os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""),
|
943 |
)
|
944 |
|
945 |
+
ENABLE_TAGS_GENERATION = PersistentConfig(
|
946 |
+
"ENABLE_TAGS_GENERATION",
|
947 |
+
"task.tags.enable",
|
948 |
+
os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true",
|
949 |
+
)
|
950 |
+
|
951 |
+
|
952 |
+
ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig(
|
953 |
+
"ENABLE_SEARCH_QUERY_GENERATION",
|
954 |
+
"task.query.search.enable",
|
955 |
+
os.environ.get("ENABLE_SEARCH_QUERY_GENERATION", "True").lower() == "true",
|
956 |
+
)
|
957 |
+
|
958 |
+
ENABLE_RETRIEVAL_QUERY_GENERATION = PersistentConfig(
|
959 |
+
"ENABLE_RETRIEVAL_QUERY_GENERATION",
|
960 |
+
"task.query.retrieval.enable",
|
961 |
+
os.environ.get("ENABLE_RETRIEVAL_QUERY_GENERATION", "True").lower() == "true",
|
962 |
)
|
963 |
|
964 |
|
965 |
+
QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
|
966 |
+
"QUERY_GENERATION_PROMPT_TEMPLATE",
|
967 |
+
"task.query.prompt_template",
|
968 |
+
os.environ.get("QUERY_GENERATION_PROMPT_TEMPLATE", ""),
|
969 |
)
|
970 |
|
971 |
+
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task:
|
972 |
+
Based on the chat history, determine whether a search is necessary, and if so, generate a 1-3 broad search queries to retrieve comprehensive and updated information. If no search is required, return an empty list.
|
973 |
+
|
974 |
+
### Guidelines:
|
975 |
+
- Respond exclusively with a JSON object.
|
976 |
+
- If a search query is needed, return an object like: { "queries": ["query1", "query2"] } where each query is distinct and concise.
|
977 |
+
- If no search query is necessary, output should be: { "queries": [] }
|
978 |
+
- Default to suggesting a search query to ensure accurate and updated information, unless it is definitively clear no search is required.
|
979 |
+
- Be concise, focusing strictly on composing search queries with no additional commentary or text.
|
980 |
+
- When in doubt, prefer to suggest a search for comprehensiveness.
|
981 |
+
- Today's date is: {{CURRENT_DATE}}
|
982 |
+
|
983 |
+
### Output:
|
984 |
+
JSON format: {
|
985 |
+
"queries": ["query1", "query2"]
|
986 |
+
}
|
987 |
+
|
988 |
+
### Chat History:
|
989 |
+
<chat_history>
|
990 |
+
{{MESSAGES:END:6}}
|
991 |
+
</chat_history>
|
992 |
+
"""
|
993 |
+
|
994 |
|
995 |
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
|
996 |
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
|
|
|
1030 |
|
1031 |
# Qdrant
|
1032 |
QDRANT_URI = os.environ.get("QDRANT_URI", None)
|
1033 |
+
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
|
1034 |
+
|
1035 |
+
# OpenSearch
|
1036 |
+
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
|
1037 |
+
OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", True)
|
1038 |
+
OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False)
|
1039 |
+
OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
|
1040 |
+
OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
|
1041 |
+
|
1042 |
+
# Pgvector
|
1043 |
+
PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", DATABASE_URL)
|
1044 |
+
if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):
|
1045 |
+
raise ValueError(
|
1046 |
+
"Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database."
|
1047 |
+
)
|
1048 |
|
1049 |
####################################
|
1050 |
# Information Retrieval (RAG)
|
|
|
1124 |
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")
|
1125 |
|
1126 |
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
|
1127 |
+
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true"
|
1128 |
)
|
1129 |
|
1130 |
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
|
1131 |
+
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true"
|
1132 |
)
|
1133 |
|
1134 |
RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
|
|
|
1149 |
log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
|
1150 |
|
1151 |
RAG_RERANKING_MODEL_AUTO_UPDATE = (
|
1152 |
+
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
|
1153 |
)
|
1154 |
|
1155 |
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
|
1156 |
+
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true"
|
1157 |
)
|
1158 |
|
1159 |
|
|
|
1218 |
os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
|
1219 |
)
|
1220 |
|
1221 |
+
RAG_OLLAMA_BASE_URL = PersistentConfig(
|
1222 |
+
"RAG_OLLAMA_BASE_URL",
|
1223 |
+
"rag.ollama.url",
|
1224 |
+
os.getenv("RAG_OLLAMA_BASE_URL", OLLAMA_BASE_URL),
|
1225 |
+
)
|
1226 |
+
|
1227 |
+
RAG_OLLAMA_API_KEY = PersistentConfig(
|
1228 |
+
"RAG_OLLAMA_API_KEY",
|
1229 |
+
"rag.ollama.key",
|
1230 |
+
os.getenv("RAG_OLLAMA_API_KEY", ""),
|
1231 |
+
)
|
1232 |
+
|
1233 |
+
|
1234 |
ENABLE_RAG_LOCAL_WEB_FETCH = (
|
1235 |
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
|
1236 |
)
|
|
|
1320 |
os.getenv("TAVILY_API_KEY", ""),
|
1321 |
)
|
1322 |
|
1323 |
+
JINA_API_KEY = PersistentConfig(
|
1324 |
+
"JINA_API_KEY",
|
1325 |
+
"rag.web.search.jina_api_key",
|
1326 |
+
os.getenv("JINA_API_KEY", ""),
|
1327 |
+
)
|
1328 |
+
|
1329 |
SEARCHAPI_API_KEY = PersistentConfig(
|
1330 |
"SEARCHAPI_API_KEY",
|
1331 |
"rag.web.search.searchapi_api_key",
|
|
|
1338 |
os.getenv("SEARCHAPI_ENGINE", ""),
|
1339 |
)
|
1340 |
|
1341 |
+
BING_SEARCH_V7_ENDPOINT = PersistentConfig(
|
1342 |
+
"BING_SEARCH_V7_ENDPOINT",
|
1343 |
+
"rag.web.search.bing_search_v7_endpoint",
|
1344 |
+
os.environ.get(
|
1345 |
+
"BING_SEARCH_V7_ENDPOINT", "https://api.bing.microsoft.com/v7.0/search"
|
1346 |
+
),
|
1347 |
+
)
|
1348 |
+
|
1349 |
+
BING_SEARCH_V7_SUBSCRIPTION_KEY = PersistentConfig(
|
1350 |
+
"BING_SEARCH_V7_SUBSCRIPTION_KEY",
|
1351 |
+
"rag.web.search.bing_search_v7_subscription_key",
|
1352 |
+
os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""),
|
1353 |
+
)
|
1354 |
+
|
1355 |
+
|
1356 |
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
|
1357 |
"RAG_WEB_SEARCH_RESULT_COUNT",
|
1358 |
"rag.web.search.result_count",
|
|
|
1404 |
|
1405 |
|
1406 |
AUTOMATIC1111_SAMPLER = PersistentConfig(
|
1407 |
+
"AUTOMATIC1111_SAMPLER",
|
1408 |
"image_generation.automatic1111.sampler",
|
1409 |
(
|
1410 |
os.environ.get("AUTOMATIC1111_SAMPLER")
|
|
|
1673 |
"AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT", "audio-24khz-160kbitrate-mono-mp3"
|
1674 |
),
|
1675 |
)
|
1676 |
+
|
1677 |
+
|
1678 |
+
####################################
|
1679 |
+
# LDAP
|
1680 |
+
####################################
|
1681 |
+
|
1682 |
+
ENABLE_LDAP = PersistentConfig(
|
1683 |
+
"ENABLE_LDAP",
|
1684 |
+
"ldap.enable",
|
1685 |
+
os.environ.get("ENABLE_LDAP", "false").lower() == "true",
|
1686 |
+
)
|
1687 |
+
|
1688 |
+
LDAP_SERVER_LABEL = PersistentConfig(
|
1689 |
+
"LDAP_SERVER_LABEL",
|
1690 |
+
"ldap.server.label",
|
1691 |
+
os.environ.get("LDAP_SERVER_LABEL", "LDAP Server"),
|
1692 |
+
)
|
1693 |
+
|
1694 |
+
LDAP_SERVER_HOST = PersistentConfig(
|
1695 |
+
"LDAP_SERVER_HOST",
|
1696 |
+
"ldap.server.host",
|
1697 |
+
os.environ.get("LDAP_SERVER_HOST", "localhost"),
|
1698 |
+
)
|
1699 |
+
|
1700 |
+
LDAP_SERVER_PORT = PersistentConfig(
|
1701 |
+
"LDAP_SERVER_PORT",
|
1702 |
+
"ldap.server.port",
|
1703 |
+
int(os.environ.get("LDAP_SERVER_PORT", "389")),
|
1704 |
+
)
|
1705 |
+
|
1706 |
+
LDAP_ATTRIBUTE_FOR_USERNAME = PersistentConfig(
|
1707 |
+
"LDAP_ATTRIBUTE_FOR_USERNAME",
|
1708 |
+
"ldap.server.attribute_for_username",
|
1709 |
+
os.environ.get("LDAP_ATTRIBUTE_FOR_USERNAME", "uid"),
|
1710 |
+
)
|
1711 |
+
|
1712 |
+
LDAP_APP_DN = PersistentConfig(
|
1713 |
+
"LDAP_APP_DN", "ldap.server.app_dn", os.environ.get("LDAP_APP_DN", "")
|
1714 |
+
)
|
1715 |
+
|
1716 |
+
LDAP_APP_PASSWORD = PersistentConfig(
|
1717 |
+
"LDAP_APP_PASSWORD",
|
1718 |
+
"ldap.server.app_password",
|
1719 |
+
os.environ.get("LDAP_APP_PASSWORD", ""),
|
1720 |
+
)
|
1721 |
+
|
1722 |
+
LDAP_SEARCH_BASE = PersistentConfig(
|
1723 |
+
"LDAP_SEARCH_BASE", "ldap.server.users_dn", os.environ.get("LDAP_SEARCH_BASE", "")
|
1724 |
+
)
|
1725 |
+
|
1726 |
+
LDAP_SEARCH_FILTERS = PersistentConfig(
|
1727 |
+
"LDAP_SEARCH_FILTER",
|
1728 |
+
"ldap.server.search_filter",
|
1729 |
+
os.environ.get("LDAP_SEARCH_FILTER", ""),
|
1730 |
+
)
|
1731 |
+
|
1732 |
+
LDAP_USE_TLS = PersistentConfig(
|
1733 |
+
"LDAP_USE_TLS",
|
1734 |
+
"ldap.server.use_tls",
|
1735 |
+
os.environ.get("LDAP_USE_TLS", "True").lower() == "true",
|
1736 |
+
)
|
1737 |
+
|
1738 |
+
LDAP_CA_CERT_FILE = PersistentConfig(
|
1739 |
+
"LDAP_CA_CERT_FILE",
|
1740 |
+
"ldap.server.ca_cert_file",
|
1741 |
+
os.environ.get("LDAP_CA_CERT_FILE", ""),
|
1742 |
+
)
|
1743 |
+
|
1744 |
+
LDAP_CIPHERS = PersistentConfig(
|
1745 |
+
"LDAP_CIPHERS", "ldap.server.ciphers", os.environ.get("LDAP_CIPHERS", "ALL")
|
1746 |
+
)
|
backend/open_webui/constants.py
CHANGED
@@ -62,6 +62,7 @@ class ERROR_MESSAGES(str, Enum):
|
|
62 |
NOT_FOUND = "We could not find what you're looking for :/"
|
63 |
USER_NOT_FOUND = "We could not find what you're looking for :/"
|
64 |
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
|
|
|
65 |
|
66 |
MALICIOUS = "Unusual activities detected, please try again in a few minutes."
|
67 |
|
@@ -75,6 +76,7 @@ class ERROR_MESSAGES(str, Enum):
|
|
75 |
OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
|
76 |
OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
|
77 |
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
|
|
|
78 |
|
79 |
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
|
80 |
|
|
|
62 |
NOT_FOUND = "We could not find what you're looking for :/"
|
63 |
USER_NOT_FOUND = "We could not find what you're looking for :/"
|
64 |
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
|
65 |
+
API_KEY_NOT_ALLOWED = "Use of API key is not enabled in the environment."
|
66 |
|
67 |
MALICIOUS = "Unusual activities detected, please try again in a few minutes."
|
68 |
|
|
|
76 |
OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
|
77 |
OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
|
78 |
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
|
79 |
+
API_KEY_CREATION_NOT_ALLOWED = "API key creation is not allowed in the environment."
|
80 |
|
81 |
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
|
82 |
|
backend/open_webui/env.py
CHANGED
@@ -1,384 +1,393 @@
|
|
1 |
-
import importlib.metadata
|
2 |
-
import json
|
3 |
-
import logging
|
4 |
-
import os
|
5 |
-
import pkgutil
|
6 |
-
import sys
|
7 |
-
import shutil
|
8 |
-
from pathlib import Path
|
9 |
-
|
10 |
-
import markdown
|
11 |
-
from bs4 import BeautifulSoup
|
12 |
-
from open_webui.constants import ERROR_MESSAGES
|
13 |
-
|
14 |
-
####################################
|
15 |
-
# Load .env file
|
16 |
-
####################################
|
17 |
-
|
18 |
-
OPEN_WEBUI_DIR = Path(__file__).parent # the path containing this file
|
19 |
-
print(OPEN_WEBUI_DIR)
|
20 |
-
|
21 |
-
BACKEND_DIR = OPEN_WEBUI_DIR.parent # the path containing this file
|
22 |
-
BASE_DIR = BACKEND_DIR.parent # the path containing the backend/
|
23 |
-
|
24 |
-
print(BACKEND_DIR)
|
25 |
-
print(BASE_DIR)
|
26 |
-
|
27 |
-
try:
|
28 |
-
from dotenv import find_dotenv, load_dotenv
|
29 |
-
|
30 |
-
load_dotenv(find_dotenv(str(BASE_DIR / ".env")))
|
31 |
-
except ImportError:
|
32 |
-
print("dotenv not installed, skipping...")
|
33 |
-
|
34 |
-
DOCKER = os.environ.get("DOCKER", "False").lower() == "true"
|
35 |
-
|
36 |
-
# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
|
37 |
-
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
|
38 |
-
|
39 |
-
if USE_CUDA.lower() == "true":
|
40 |
-
try:
|
41 |
-
import torch
|
42 |
-
|
43 |
-
assert torch.cuda.is_available(), "CUDA not available"
|
44 |
-
DEVICE_TYPE = "cuda"
|
45 |
-
except Exception as e:
|
46 |
-
cuda_error = (
|
47 |
-
"Error when testing CUDA but USE_CUDA_DOCKER is true. "
|
48 |
-
f"Resetting USE_CUDA_DOCKER to false: {e}"
|
49 |
-
)
|
50 |
-
os.environ["USE_CUDA_DOCKER"] = "false"
|
51 |
-
USE_CUDA = "false"
|
52 |
-
DEVICE_TYPE = "cpu"
|
53 |
-
else:
|
54 |
-
DEVICE_TYPE = "cpu"
|
55 |
-
|
56 |
-
|
57 |
-
####################################
|
58 |
-
# LOGGING
|
59 |
-
####################################
|
60 |
-
|
61 |
-
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
|
62 |
-
|
63 |
-
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "
|
64 |
-
if GLOBAL_LOG_LEVEL in log_levels:
|
65 |
-
logging.basicConfig(stream=sys.stdout, level=
|
66 |
-
else:
|
67 |
-
GLOBAL_LOG_LEVEL = "
|
68 |
-
|
69 |
-
log = logging.getLogger(__name__)
|
70 |
-
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
|
71 |
-
|
72 |
-
if "cuda_error" in locals():
|
73 |
-
log.exception(cuda_error)
|
74 |
-
|
75 |
-
log_sources = [
|
76 |
-
"AUDIO",
|
77 |
-
"COMFYUI",
|
78 |
-
"CONFIG",
|
79 |
-
"DB",
|
80 |
-
"IMAGES",
|
81 |
-
"MAIN",
|
82 |
-
"MODELS",
|
83 |
-
"OLLAMA",
|
84 |
-
"OPENAI",
|
85 |
-
"RAG",
|
86 |
-
"WEBHOOK",
|
87 |
-
"SOCKET",
|
88 |
-
]
|
89 |
-
|
90 |
-
SRC_LOG_LEVELS = {}
|
91 |
-
|
92 |
-
for source in log_sources:
|
93 |
-
log_env_var = source + "_LOG_LEVEL"
|
94 |
-
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
|
95 |
-
if SRC_LOG_LEVELS[source] not in log_levels:
|
96 |
-
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
|
97 |
-
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
|
98 |
-
|
99 |
-
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
|
100 |
-
|
101 |
-
|
102 |
-
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
|
103 |
-
if WEBUI_NAME != "Open WebUI":
|
104 |
-
WEBUI_NAME += " (Open WebUI)"
|
105 |
-
|
106 |
-
WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")
|
107 |
-
|
108 |
-
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
|
109 |
-
|
110 |
-
|
111 |
-
####################################
|
112 |
-
# ENV (dev,test,prod)
|
113 |
-
####################################
|
114 |
-
|
115 |
-
ENV = os.environ.get("ENV", "dev")
|
116 |
-
|
117 |
-
FROM_INIT_PY = os.environ.get("FROM_INIT_PY", "False").lower() == "true"
|
118 |
-
|
119 |
-
if FROM_INIT_PY:
|
120 |
-
PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")}
|
121 |
-
else:
|
122 |
-
try:
|
123 |
-
PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text())
|
124 |
-
except Exception:
|
125 |
-
PACKAGE_DATA = {"version": "0.0.0"}
|
126 |
-
|
127 |
-
|
128 |
-
VERSION = PACKAGE_DATA["version"]
|
129 |
-
|
130 |
-
|
131 |
-
# Function to parse each section
|
132 |
-
def parse_section(section):
|
133 |
-
items = []
|
134 |
-
for li in section.find_all("li"):
|
135 |
-
# Extract raw HTML string
|
136 |
-
raw_html = str(li)
|
137 |
-
|
138 |
-
# Extract text without HTML tags
|
139 |
-
text = li.get_text(separator=" ", strip=True)
|
140 |
-
|
141 |
-
# Split into title and content
|
142 |
-
parts = text.split(": ", 1)
|
143 |
-
title = parts[0].strip() if len(parts) > 1 else ""
|
144 |
-
content = parts[1].strip() if len(parts) > 1 else text
|
145 |
-
|
146 |
-
items.append({"title": title, "content": content, "raw": raw_html})
|
147 |
-
return items
|
148 |
-
|
149 |
-
|
150 |
-
try:
|
151 |
-
changelog_path = BASE_DIR / "CHANGELOG.md"
|
152 |
-
with open(str(changelog_path.absolute()), "r", encoding="utf8") as file:
|
153 |
-
changelog_content = file.read()
|
154 |
-
|
155 |
-
except Exception:
|
156 |
-
changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
|
157 |
-
|
158 |
-
|
159 |
-
# Convert markdown content to HTML
|
160 |
-
html_content = markdown.markdown(changelog_content)
|
161 |
-
|
162 |
-
# Parse the HTML content
|
163 |
-
soup = BeautifulSoup(html_content, "html.parser")
|
164 |
-
|
165 |
-
# Initialize JSON structure
|
166 |
-
changelog_json = {}
|
167 |
-
|
168 |
-
# Iterate over each version
|
169 |
-
for version in soup.find_all("h2"):
|
170 |
-
version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets
|
171 |
-
date = version.get_text().strip().split(" - ")[1]
|
172 |
-
|
173 |
-
version_data = {"date": date}
|
174 |
-
|
175 |
-
# Find the next sibling that is a h3 tag (section title)
|
176 |
-
current = version.find_next_sibling()
|
177 |
-
|
178 |
-
while current and current.name != "h2":
|
179 |
-
if current.name == "h3":
|
180 |
-
section_title = current.get_text().lower() # e.g., "added", "fixed"
|
181 |
-
section_items = parse_section(current.find_next_sibling("ul"))
|
182 |
-
version_data[section_title] = section_items
|
183 |
-
|
184 |
-
# Move to the next element
|
185 |
-
current = current.find_next_sibling()
|
186 |
-
|
187 |
-
changelog_json[version_number] = version_data
|
188 |
-
|
189 |
-
|
190 |
-
CHANGELOG = changelog_json
|
191 |
-
|
192 |
-
####################################
|
193 |
-
# SAFE_MODE
|
194 |
-
####################################
|
195 |
-
|
196 |
-
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
|
197 |
-
|
198 |
-
####################################
|
199 |
-
#
|
200 |
-
####################################
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
)
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
####################################
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib.metadata
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import pkgutil
|
6 |
+
import sys
|
7 |
+
import shutil
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import markdown
|
11 |
+
from bs4 import BeautifulSoup
|
12 |
+
from open_webui.constants import ERROR_MESSAGES
|
13 |
+
|
14 |
+
####################################
|
15 |
+
# Load .env file
|
16 |
+
####################################
|
17 |
+
|
18 |
+
OPEN_WEBUI_DIR = Path(__file__).parent # the path containing this file
|
19 |
+
print(OPEN_WEBUI_DIR)
|
20 |
+
|
21 |
+
BACKEND_DIR = OPEN_WEBUI_DIR.parent # the path containing this file
|
22 |
+
BASE_DIR = BACKEND_DIR.parent # the path containing the backend/
|
23 |
+
|
24 |
+
print(BACKEND_DIR)
|
25 |
+
print(BASE_DIR)
|
26 |
+
|
27 |
+
try:
|
28 |
+
from dotenv import find_dotenv, load_dotenv
|
29 |
+
|
30 |
+
load_dotenv(find_dotenv(str(BASE_DIR / ".env")))
|
31 |
+
except ImportError:
|
32 |
+
print("dotenv not installed, skipping...")
|
33 |
+
|
34 |
+
DOCKER = os.environ.get("DOCKER", "False").lower() == "true"
|
35 |
+
|
36 |
+
# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
|
37 |
+
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
|
38 |
+
|
39 |
+
if USE_CUDA.lower() == "true":
|
40 |
+
try:
|
41 |
+
import torch
|
42 |
+
|
43 |
+
assert torch.cuda.is_available(), "CUDA not available"
|
44 |
+
DEVICE_TYPE = "cuda"
|
45 |
+
except Exception as e:
|
46 |
+
cuda_error = (
|
47 |
+
"Error when testing CUDA but USE_CUDA_DOCKER is true. "
|
48 |
+
f"Resetting USE_CUDA_DOCKER to false: {e}"
|
49 |
+
)
|
50 |
+
os.environ["USE_CUDA_DOCKER"] = "false"
|
51 |
+
USE_CUDA = "false"
|
52 |
+
DEVICE_TYPE = "cpu"
|
53 |
+
else:
|
54 |
+
DEVICE_TYPE = "cpu"
|
55 |
+
|
56 |
+
|
57 |
+
####################################
|
58 |
+
# LOGGING
|
59 |
+
####################################
|
60 |
+
|
61 |
+
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
|
62 |
+
|
63 |
+
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
|
64 |
+
if GLOBAL_LOG_LEVEL in log_levels:
|
65 |
+
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
|
66 |
+
else:
|
67 |
+
GLOBAL_LOG_LEVEL = "INFO"
|
68 |
+
|
69 |
+
log = logging.getLogger(__name__)
|
70 |
+
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
|
71 |
+
|
72 |
+
if "cuda_error" in locals():
|
73 |
+
log.exception(cuda_error)
|
74 |
+
|
75 |
+
log_sources = [
|
76 |
+
"AUDIO",
|
77 |
+
"COMFYUI",
|
78 |
+
"CONFIG",
|
79 |
+
"DB",
|
80 |
+
"IMAGES",
|
81 |
+
"MAIN",
|
82 |
+
"MODELS",
|
83 |
+
"OLLAMA",
|
84 |
+
"OPENAI",
|
85 |
+
"RAG",
|
86 |
+
"WEBHOOK",
|
87 |
+
"SOCKET",
|
88 |
+
]
|
89 |
+
|
90 |
+
SRC_LOG_LEVELS = {}
|
91 |
+
|
92 |
+
for source in log_sources:
|
93 |
+
log_env_var = source + "_LOG_LEVEL"
|
94 |
+
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
|
95 |
+
if SRC_LOG_LEVELS[source] not in log_levels:
|
96 |
+
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
|
97 |
+
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
|
98 |
+
|
99 |
+
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
|
100 |
+
|
101 |
+
|
102 |
+
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
|
103 |
+
if WEBUI_NAME != "Open WebUI":
|
104 |
+
WEBUI_NAME += " (Open WebUI)"
|
105 |
+
|
106 |
+
WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")
|
107 |
+
|
108 |
+
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
|
109 |
+
|
110 |
+
|
111 |
+
####################################
|
112 |
+
# ENV (dev,test,prod)
|
113 |
+
####################################
|
114 |
+
|
115 |
+
ENV = os.environ.get("ENV", "dev")
|
116 |
+
|
117 |
+
FROM_INIT_PY = os.environ.get("FROM_INIT_PY", "False").lower() == "true"
|
118 |
+
|
119 |
+
if FROM_INIT_PY:
|
120 |
+
PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")}
|
121 |
+
else:
|
122 |
+
try:
|
123 |
+
PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text())
|
124 |
+
except Exception:
|
125 |
+
PACKAGE_DATA = {"version": "0.0.0"}
|
126 |
+
|
127 |
+
|
128 |
+
VERSION = PACKAGE_DATA["version"]
|
129 |
+
|
130 |
+
|
131 |
+
# Function to parse each section
|
132 |
+
def parse_section(section):
|
133 |
+
items = []
|
134 |
+
for li in section.find_all("li"):
|
135 |
+
# Extract raw HTML string
|
136 |
+
raw_html = str(li)
|
137 |
+
|
138 |
+
# Extract text without HTML tags
|
139 |
+
text = li.get_text(separator=" ", strip=True)
|
140 |
+
|
141 |
+
# Split into title and content
|
142 |
+
parts = text.split(": ", 1)
|
143 |
+
title = parts[0].strip() if len(parts) > 1 else ""
|
144 |
+
content = parts[1].strip() if len(parts) > 1 else text
|
145 |
+
|
146 |
+
items.append({"title": title, "content": content, "raw": raw_html})
|
147 |
+
return items
|
148 |
+
|
149 |
+
|
150 |
+
try:
|
151 |
+
changelog_path = BASE_DIR / "CHANGELOG.md"
|
152 |
+
with open(str(changelog_path.absolute()), "r", encoding="utf8") as file:
|
153 |
+
changelog_content = file.read()
|
154 |
+
|
155 |
+
except Exception:
|
156 |
+
changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
|
157 |
+
|
158 |
+
|
159 |
+
# Convert markdown content to HTML
|
160 |
+
html_content = markdown.markdown(changelog_content)
|
161 |
+
|
162 |
+
# Parse the HTML content
|
163 |
+
soup = BeautifulSoup(html_content, "html.parser")
|
164 |
+
|
165 |
+
# Initialize JSON structure
|
166 |
+
changelog_json = {}
|
167 |
+
|
168 |
+
# Iterate over each version
|
169 |
+
for version in soup.find_all("h2"):
|
170 |
+
version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets
|
171 |
+
date = version.get_text().strip().split(" - ")[1]
|
172 |
+
|
173 |
+
version_data = {"date": date}
|
174 |
+
|
175 |
+
# Find the next sibling that is a h3 tag (section title)
|
176 |
+
current = version.find_next_sibling()
|
177 |
+
|
178 |
+
while current and current.name != "h2":
|
179 |
+
if current.name == "h3":
|
180 |
+
section_title = current.get_text().lower() # e.g., "added", "fixed"
|
181 |
+
section_items = parse_section(current.find_next_sibling("ul"))
|
182 |
+
version_data[section_title] = section_items
|
183 |
+
|
184 |
+
# Move to the next element
|
185 |
+
current = current.find_next_sibling()
|
186 |
+
|
187 |
+
changelog_json[version_number] = version_data
|
188 |
+
|
189 |
+
|
190 |
+
CHANGELOG = changelog_json
|
191 |
+
|
192 |
+
####################################
|
193 |
+
# SAFE_MODE
|
194 |
+
####################################
|
195 |
+
|
196 |
+
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
|
197 |
+
|
198 |
+
####################################
|
199 |
+
# ENABLE_FORWARD_USER_INFO_HEADERS
|
200 |
+
####################################
|
201 |
+
|
202 |
+
ENABLE_FORWARD_USER_INFO_HEADERS = (
|
203 |
+
os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true"
|
204 |
+
)
|
205 |
+
|
206 |
+
|
207 |
+
####################################
|
208 |
+
# WEBUI_BUILD_HASH
|
209 |
+
####################################
|
210 |
+
|
211 |
+
WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
|
212 |
+
|
213 |
+
####################################
|
214 |
+
# DATA/FRONTEND BUILD DIR
|
215 |
+
####################################
|
216 |
+
|
217 |
+
DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
|
218 |
+
|
219 |
+
if FROM_INIT_PY:
|
220 |
+
NEW_DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")).resolve()
|
221 |
+
NEW_DATA_DIR.mkdir(parents=True, exist_ok=True)
|
222 |
+
|
223 |
+
# Check if the data directory exists in the package directory
|
224 |
+
if DATA_DIR.exists() and DATA_DIR != NEW_DATA_DIR:
|
225 |
+
log.info(f"Moving {DATA_DIR} to {NEW_DATA_DIR}")
|
226 |
+
for item in DATA_DIR.iterdir():
|
227 |
+
dest = NEW_DATA_DIR / item.name
|
228 |
+
if item.is_dir():
|
229 |
+
shutil.copytree(item, dest, dirs_exist_ok=True)
|
230 |
+
else:
|
231 |
+
shutil.copy2(item, dest)
|
232 |
+
|
233 |
+
# Zip the data directory
|
234 |
+
shutil.make_archive(DATA_DIR.parent / "open_webui_data", "zip", DATA_DIR)
|
235 |
+
|
236 |
+
# Remove the old data directory
|
237 |
+
shutil.rmtree(DATA_DIR)
|
238 |
+
|
239 |
+
DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data"))
|
240 |
+
|
241 |
+
|
242 |
+
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static"))
|
243 |
+
|
244 |
+
FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts"))
|
245 |
+
|
246 |
+
FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
|
247 |
+
|
248 |
+
if FROM_INIT_PY:
|
249 |
+
FRONTEND_BUILD_DIR = Path(
|
250 |
+
os.getenv("FRONTEND_BUILD_DIR", OPEN_WEBUI_DIR / "frontend")
|
251 |
+
).resolve()
|
252 |
+
|
253 |
+
|
254 |
+
####################################
|
255 |
+
# Database
|
256 |
+
####################################
|
257 |
+
|
258 |
+
# Check if the file exists
|
259 |
+
if os.path.exists(f"{DATA_DIR}/ollama.db"):
|
260 |
+
# Rename the file
|
261 |
+
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
|
262 |
+
log.info("Database migrated from Ollama-WebUI successfully.")
|
263 |
+
else:
|
264 |
+
pass
|
265 |
+
|
266 |
+
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
|
267 |
+
|
268 |
+
# Replace the postgres:// with postgresql://
|
269 |
+
if "postgres://" in DATABASE_URL:
|
270 |
+
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
|
271 |
+
|
272 |
+
DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", 0)
|
273 |
+
|
274 |
+
if DATABASE_POOL_SIZE == "":
|
275 |
+
DATABASE_POOL_SIZE = 0
|
276 |
+
else:
|
277 |
+
try:
|
278 |
+
DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE)
|
279 |
+
except Exception:
|
280 |
+
DATABASE_POOL_SIZE = 0
|
281 |
+
|
282 |
+
DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0)
|
283 |
+
|
284 |
+
if DATABASE_POOL_MAX_OVERFLOW == "":
|
285 |
+
DATABASE_POOL_MAX_OVERFLOW = 0
|
286 |
+
else:
|
287 |
+
try:
|
288 |
+
DATABASE_POOL_MAX_OVERFLOW = int(DATABASE_POOL_MAX_OVERFLOW)
|
289 |
+
except Exception:
|
290 |
+
DATABASE_POOL_MAX_OVERFLOW = 0
|
291 |
+
|
292 |
+
DATABASE_POOL_TIMEOUT = os.environ.get("DATABASE_POOL_TIMEOUT", 30)
|
293 |
+
|
294 |
+
if DATABASE_POOL_TIMEOUT == "":
|
295 |
+
DATABASE_POOL_TIMEOUT = 30
|
296 |
+
else:
|
297 |
+
try:
|
298 |
+
DATABASE_POOL_TIMEOUT = int(DATABASE_POOL_TIMEOUT)
|
299 |
+
except Exception:
|
300 |
+
DATABASE_POOL_TIMEOUT = 30
|
301 |
+
|
302 |
+
DATABASE_POOL_RECYCLE = os.environ.get("DATABASE_POOL_RECYCLE", 3600)
|
303 |
+
|
304 |
+
if DATABASE_POOL_RECYCLE == "":
|
305 |
+
DATABASE_POOL_RECYCLE = 3600
|
306 |
+
else:
|
307 |
+
try:
|
308 |
+
DATABASE_POOL_RECYCLE = int(DATABASE_POOL_RECYCLE)
|
309 |
+
except Exception:
|
310 |
+
DATABASE_POOL_RECYCLE = 3600
|
311 |
+
|
312 |
+
RESET_CONFIG_ON_START = (
|
313 |
+
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
|
314 |
+
)
|
315 |
+
|
316 |
+
####################################
|
317 |
+
# REDIS
|
318 |
+
####################################
|
319 |
+
|
320 |
+
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
|
321 |
+
|
322 |
+
####################################
|
323 |
+
# WEBUI_AUTH (Required for security)
|
324 |
+
####################################
|
325 |
+
|
326 |
+
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
|
327 |
+
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
|
328 |
+
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
|
329 |
+
)
|
330 |
+
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
|
331 |
+
|
332 |
+
|
333 |
+
####################################
|
334 |
+
# WEBUI_SECRET_KEY
|
335 |
+
####################################
|
336 |
+
|
337 |
+
WEBUI_SECRET_KEY = os.environ.get(
|
338 |
+
"WEBUI_SECRET_KEY",
|
339 |
+
os.environ.get(
|
340 |
+
"WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t"
|
341 |
+
), # DEPRECATED: remove at next major version
|
342 |
+
)
|
343 |
+
|
344 |
+
WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get(
|
345 |
+
"WEBUI_SESSION_COOKIE_SAME_SITE",
|
346 |
+
os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"),
|
347 |
+
)
|
348 |
+
|
349 |
+
WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
|
350 |
+
"WEBUI_SESSION_COOKIE_SECURE",
|
351 |
+
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true",
|
352 |
+
)
|
353 |
+
|
354 |
+
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
|
355 |
+
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
|
356 |
+
|
357 |
+
ENABLE_WEBSOCKET_SUPPORT = (
|
358 |
+
os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
|
359 |
+
)
|
360 |
+
|
361 |
+
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
|
362 |
+
|
363 |
+
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
|
364 |
+
|
365 |
+
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
|
366 |
+
|
367 |
+
if AIOHTTP_CLIENT_TIMEOUT == "":
|
368 |
+
AIOHTTP_CLIENT_TIMEOUT = None
|
369 |
+
else:
|
370 |
+
try:
|
371 |
+
AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
|
372 |
+
except Exception:
|
373 |
+
AIOHTTP_CLIENT_TIMEOUT = 300
|
374 |
+
|
375 |
+
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
|
376 |
+
"AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "3"
|
377 |
+
)
|
378 |
+
|
379 |
+
if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
|
380 |
+
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = None
|
381 |
+
else:
|
382 |
+
try:
|
383 |
+
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = int(
|
384 |
+
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
|
385 |
+
)
|
386 |
+
except Exception:
|
387 |
+
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 3
|
388 |
+
|
389 |
+
####################################
|
390 |
+
# OFFLINE_MODE
|
391 |
+
####################################
|
392 |
+
|
393 |
+
OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
|
backend/open_webui/main.py
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
backend/open_webui/migrations/versions/922e7a387820_add_group_table.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Add group table
|
2 |
+
|
3 |
+
Revision ID: 922e7a387820
|
4 |
+
Revises: 4ace53fd72c8
|
5 |
+
Create Date: 2024-11-14 03:00:00.000000
|
6 |
+
|
7 |
+
"""
|
8 |
+
|
9 |
+
from alembic import op
|
10 |
+
import sqlalchemy as sa
|
11 |
+
|
12 |
+
revision = "922e7a387820"
|
13 |
+
down_revision = "4ace53fd72c8"
|
14 |
+
branch_labels = None
|
15 |
+
depends_on = None
|
16 |
+
|
17 |
+
|
18 |
+
def upgrade():
|
19 |
+
op.create_table(
|
20 |
+
"group",
|
21 |
+
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
|
22 |
+
sa.Column("user_id", sa.Text(), nullable=True),
|
23 |
+
sa.Column("name", sa.Text(), nullable=True),
|
24 |
+
sa.Column("description", sa.Text(), nullable=True),
|
25 |
+
sa.Column("data", sa.JSON(), nullable=True),
|
26 |
+
sa.Column("meta", sa.JSON(), nullable=True),
|
27 |
+
sa.Column("permissions", sa.JSON(), nullable=True),
|
28 |
+
sa.Column("user_ids", sa.JSON(), nullable=True),
|
29 |
+
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
30 |
+
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
31 |
+
)
|
32 |
+
|
33 |
+
# Add 'access_control' column to 'model' table
|
34 |
+
op.add_column(
|
35 |
+
"model",
|
36 |
+
sa.Column("access_control", sa.JSON(), nullable=True),
|
37 |
+
)
|
38 |
+
|
39 |
+
# Add 'is_active' column to 'model' table
|
40 |
+
op.add_column(
|
41 |
+
"model",
|
42 |
+
sa.Column(
|
43 |
+
"is_active",
|
44 |
+
sa.Boolean(),
|
45 |
+
nullable=False,
|
46 |
+
server_default=sa.sql.expression.true(),
|
47 |
+
),
|
48 |
+
)
|
49 |
+
|
50 |
+
# Add 'access_control' column to 'knowledge' table
|
51 |
+
op.add_column(
|
52 |
+
"knowledge",
|
53 |
+
sa.Column("access_control", sa.JSON(), nullable=True),
|
54 |
+
)
|
55 |
+
|
56 |
+
# Add 'access_control' column to 'prompt' table
|
57 |
+
op.add_column(
|
58 |
+
"prompt",
|
59 |
+
sa.Column("access_control", sa.JSON(), nullable=True),
|
60 |
+
)
|
61 |
+
|
62 |
+
# Add 'access_control' column to 'tools' table
|
63 |
+
op.add_column(
|
64 |
+
"tool",
|
65 |
+
sa.Column("access_control", sa.JSON(), nullable=True),
|
66 |
+
)
|
67 |
+
|
68 |
+
|
69 |
+
def downgrade():
|
70 |
+
op.drop_table("group")
|
71 |
+
|
72 |
+
# Drop 'access_control' column from 'model' table
|
73 |
+
op.drop_column("model", "access_control")
|
74 |
+
|
75 |
+
# Drop 'is_active' column from 'model' table
|
76 |
+
op.drop_column("model", "is_active")
|
77 |
+
|
78 |
+
# Drop 'access_control' column from 'knowledge' table
|
79 |
+
op.drop_column("knowledge", "access_control")
|
80 |
+
|
81 |
+
# Drop 'access_control' column from 'prompt' table
|
82 |
+
op.drop_column("prompt", "access_control")
|
83 |
+
|
84 |
+
# Drop 'access_control' column from 'tools' table
|
85 |
+
op.drop_column("tool", "access_control")
|
backend/open_webui/storage/provider.py
CHANGED
@@ -51,7 +51,10 @@ class StorageProvider:
|
|
51 |
|
52 |
try:
|
53 |
self.s3_client.upload_file(file_path, self.bucket_name, filename)
|
54 |
-
return
|
|
|
|
|
|
|
55 |
except ClientError as e:
|
56 |
raise RuntimeError(f"Error uploading file to S3: {e}")
|
57 |
|
|
|
51 |
|
52 |
try:
|
53 |
self.s3_client.upload_file(file_path, self.bucket_name, filename)
|
54 |
+
return (
|
55 |
+
open(file_path, "rb").read(),
|
56 |
+
"s3://" + self.bucket_name + "/" + filename,
|
57 |
+
)
|
58 |
except ClientError as e:
|
59 |
raise RuntimeError(f"Error uploading file to S3: {e}")
|
60 |
|
backend/open_webui/utils/access_control.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union, List, Dict, Any
|
2 |
+
from open_webui.apps.webui.models.groups import Groups
|
3 |
+
import json
|
4 |
+
|
5 |
+
|
6 |
+
def get_permissions(
|
7 |
+
user_id: str,
|
8 |
+
default_permissions: Dict[str, Any],
|
9 |
+
) -> Dict[str, Any]:
|
10 |
+
"""
|
11 |
+
Get all permissions for a user by combining the permissions of all groups the user is a member of.
|
12 |
+
If a permission is defined in multiple groups, the most permissive value is used (True > False).
|
13 |
+
Permissions are nested in a dict with the permission key as the key and a boolean as the value.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def combine_permissions(
|
17 |
+
permissions: Dict[str, Any], group_permissions: Dict[str, Any]
|
18 |
+
) -> Dict[str, Any]:
|
19 |
+
"""Combine permissions from multiple groups by taking the most permissive value."""
|
20 |
+
for key, value in group_permissions.items():
|
21 |
+
if isinstance(value, dict):
|
22 |
+
if key not in permissions:
|
23 |
+
permissions[key] = {}
|
24 |
+
permissions[key] = combine_permissions(permissions[key], value)
|
25 |
+
else:
|
26 |
+
if key not in permissions:
|
27 |
+
permissions[key] = value
|
28 |
+
else:
|
29 |
+
permissions[key] = permissions[key] or value
|
30 |
+
return permissions
|
31 |
+
|
32 |
+
user_groups = Groups.get_groups_by_member_id(user_id)
|
33 |
+
|
34 |
+
# deep copy default permissions to avoid modifying the original dict
|
35 |
+
permissions = json.loads(json.dumps(default_permissions))
|
36 |
+
|
37 |
+
for group in user_groups:
|
38 |
+
group_permissions = group.permissions
|
39 |
+
permissions = combine_permissions(permissions, group_permissions)
|
40 |
+
|
41 |
+
return permissions
|
42 |
+
|
43 |
+
|
44 |
+
def has_permission(
|
45 |
+
user_id: str,
|
46 |
+
permission_key: str,
|
47 |
+
default_permissions: Dict[str, bool] = {},
|
48 |
+
) -> bool:
|
49 |
+
"""
|
50 |
+
Check if a user has a specific permission by checking the group permissions
|
51 |
+
and falls back to default permissions if not found in any group.
|
52 |
+
|
53 |
+
Permission keys can be hierarchical and separated by dots ('.').
|
54 |
+
"""
|
55 |
+
|
56 |
+
def get_permission(permissions: Dict[str, bool], keys: List[str]) -> bool:
|
57 |
+
"""Traverse permissions dict using a list of keys (from dot-split permission_key)."""
|
58 |
+
for key in keys:
|
59 |
+
if key not in permissions:
|
60 |
+
return False # If any part of the hierarchy is missing, deny access
|
61 |
+
permissions = permissions[key] # Go one level deeper
|
62 |
+
|
63 |
+
return bool(permissions) # Return the boolean at the final level
|
64 |
+
|
65 |
+
permission_hierarchy = permission_key.split(".")
|
66 |
+
|
67 |
+
# Retrieve user group permissions
|
68 |
+
user_groups = Groups.get_groups_by_member_id(user_id)
|
69 |
+
|
70 |
+
for group in user_groups:
|
71 |
+
group_permissions = group.permissions
|
72 |
+
if get_permission(group_permissions, permission_hierarchy):
|
73 |
+
return True
|
74 |
+
|
75 |
+
# Check default permissions afterwards if the group permissions don't allow it
|
76 |
+
return get_permission(default_permissions, permission_hierarchy)
|
77 |
+
|
78 |
+
|
79 |
+
def has_access(
|
80 |
+
user_id: str,
|
81 |
+
type: str = "write",
|
82 |
+
access_control: Optional[dict] = None,
|
83 |
+
) -> bool:
|
84 |
+
if access_control is None:
|
85 |
+
return type == "read"
|
86 |
+
|
87 |
+
user_groups = Groups.get_groups_by_member_id(user_id)
|
88 |
+
user_group_ids = [group.id for group in user_groups]
|
89 |
+
permission_access = access_control.get(type, {})
|
90 |
+
permitted_group_ids = permission_access.get("group_ids", [])
|
91 |
+
permitted_user_ids = permission_access.get("user_ids", [])
|
92 |
+
|
93 |
+
return user_id in permitted_user_ids or any(
|
94 |
+
group_id in permitted_group_ids for group_id in user_group_ids
|
95 |
+
)
|
backend/open_webui/utils/pdf_generator.py
CHANGED
@@ -54,18 +54,18 @@ class PDFGenerator:
|
|
54 |
html_content = markdown(content, extensions=["pymdownx.extra"])
|
55 |
|
56 |
html_message = f"""
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
"""
|
70 |
return html_message
|
71 |
|
|
|
54 |
html_content = markdown(content, extensions=["pymdownx.extra"])
|
55 |
|
56 |
html_message = f"""
|
57 |
+
<div> {date_str} </div>
|
58 |
+
<div class="message">
|
59 |
+
<div>
|
60 |
+
<h2>
|
61 |
+
<strong>{role.title()}</strong>
|
62 |
+
<span style="font-size: 12px; color: #888;">{model}</span>
|
63 |
+
</h2>
|
64 |
+
</div>
|
65 |
+
<pre class="markdown-section">
|
66 |
+
{content}
|
67 |
+
</pre>
|
68 |
+
</div>
|
69 |
"""
|
70 |
return html_message
|
71 |
|
backend/open_webui/utils/security_headers.py
CHANGED
@@ -20,6 +20,7 @@ def set_security_headers() -> Dict[str, str]:
|
|
20 |
This function reads specific environment variables and uses their values
|
21 |
to set corresponding security headers. The headers that can be set are:
|
22 |
- cache-control
|
|
|
23 |
- strict-transport-security
|
24 |
- referrer-policy
|
25 |
- x-content-type-options
|
@@ -38,6 +39,7 @@ def set_security_headers() -> Dict[str, str]:
|
|
38 |
header_setters = {
|
39 |
"CACHE_CONTROL": set_cache_control,
|
40 |
"HSTS": set_hsts,
|
|
|
41 |
"REFERRER_POLICY": set_referrer,
|
42 |
"XCONTENT_TYPE": set_xcontent_type,
|
43 |
"XDOWNLOAD_OPTIONS": set_xdownload_options,
|
@@ -73,6 +75,15 @@ def set_xframe(value: str):
|
|
73 |
return {"X-Frame-Options": value}
|
74 |
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
# Set Referrer-Policy response header
|
77 |
def set_referrer(value: str):
|
78 |
pattern = r"^(no-referrer|no-referrer-when-downgrade|origin|origin-when-cross-origin|same-origin|strict-origin|strict-origin-when-cross-origin|unsafe-url)$"
|
|
|
20 |
This function reads specific environment variables and uses their values
|
21 |
to set corresponding security headers. The headers that can be set are:
|
22 |
- cache-control
|
23 |
+
- permissions-policy
|
24 |
- strict-transport-security
|
25 |
- referrer-policy
|
26 |
- x-content-type-options
|
|
|
39 |
header_setters = {
|
40 |
"CACHE_CONTROL": set_cache_control,
|
41 |
"HSTS": set_hsts,
|
42 |
+
"PERMISSIONS_POLICY": set_permissions_policy,
|
43 |
"REFERRER_POLICY": set_referrer,
|
44 |
"XCONTENT_TYPE": set_xcontent_type,
|
45 |
"XDOWNLOAD_OPTIONS": set_xdownload_options,
|
|
|
75 |
return {"X-Frame-Options": value}
|
76 |
|
77 |
|
78 |
+
# Set Permissions-Policy response header
|
79 |
+
def set_permissions_policy(value: str):
|
80 |
+
pattern = r"^(?:(accelerometer|autoplay|camera|clipboard-read|clipboard-write|fullscreen|geolocation|gyroscope|magnetometer|microphone|midi|payment|picture-in-picture|sync-xhr|usb|xr-spatial-tracking)=\((self)?\),?)*$"
|
81 |
+
match = re.match(pattern, value, re.IGNORECASE)
|
82 |
+
if not match:
|
83 |
+
value = "none"
|
84 |
+
return {"Permissions-Policy": value}
|
85 |
+
|
86 |
+
|
87 |
# Set Referrer-Policy response header
|
88 |
def set_referrer(value: str):
|
89 |
pattern = r"^(no-referrer|no-referrer-when-downgrade|origin|origin-when-cross-origin|same-origin|strict-origin|strict-origin-when-cross-origin|unsafe-url)$"
|