coolmanx committed on
Commit
c29ae51
1 Parent(s): e784c01
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .dockerignore +2 -1
  2. .env.example +13 -15
  3. .gitattributes +1 -0
  4. .github/workflows/deploy-to-hf-spaces.yml +4 -0
  5. CHANGELOG.md +61 -0
  6. Dockerfile +176 -177
  7. README.md +1 -13
  8. backend/open_webui/apps/audio/main.py +75 -2
  9. backend/open_webui/apps/images/main.py +14 -2
  10. backend/open_webui/apps/ollama/main.py +354 -151
  11. backend/open_webui/apps/openai/main.py +280 -122
  12. backend/open_webui/apps/retrieval/loaders/main.py +1 -1
  13. backend/open_webui/apps/retrieval/main.py +144 -20
  14. backend/open_webui/apps/retrieval/utils.py +76 -57
  15. backend/open_webui/apps/retrieval/vector/connector.py +8 -0
  16. backend/open_webui/apps/retrieval/vector/dbs/chroma.py +3 -1
  17. backend/open_webui/apps/retrieval/vector/dbs/opensearch.py +178 -0
  18. backend/open_webui/apps/retrieval/vector/dbs/pgvector.py +354 -0
  19. backend/open_webui/apps/retrieval/vector/dbs/qdrant.py +7 -2
  20. backend/open_webui/apps/retrieval/web/bing.py +73 -0
  21. backend/open_webui/apps/retrieval/web/jina_search.py +2 -4
  22. backend/open_webui/apps/retrieval/web/testdata/bing.json +58 -0
  23. backend/open_webui/apps/socket/main.py +2 -0
  24. backend/open_webui/apps/webui/main.py +45 -7
  25. backend/open_webui/apps/webui/models/auths.py +5 -0
  26. backend/open_webui/apps/webui/models/chats.py +13 -6
  27. backend/open_webui/apps/webui/models/groups.py +186 -0
  28. backend/open_webui/apps/webui/models/knowledge.py +76 -23
  29. backend/open_webui/apps/webui/models/models.py +103 -8
  30. backend/open_webui/apps/webui/models/prompts.py +58 -9
  31. backend/open_webui/apps/webui/models/tools.py +56 -4
  32. backend/open_webui/apps/webui/models/users.py +8 -0
  33. backend/open_webui/apps/webui/routers/auths.py +281 -3
  34. backend/open_webui/apps/webui/routers/chats.py +9 -5
  35. backend/open_webui/apps/webui/routers/groups.py +120 -0
  36. backend/open_webui/apps/webui/routers/knowledge.py +202 -78
  37. backend/open_webui/apps/webui/routers/models.py +120 -36
  38. backend/open_webui/apps/webui/routers/prompts.py +69 -7
  39. backend/open_webui/apps/webui/routers/tools.py +144 -80
  40. backend/open_webui/apps/webui/routers/users.py +41 -4
  41. backend/open_webui/apps/webui/utils.py +1 -1
  42. backend/open_webui/config.py +229 -35
  43. backend/open_webui/constants.py +2 -0
  44. backend/open_webui/env.py +393 -384
  45. backend/open_webui/main.py +0 -0
  46. backend/open_webui/migrations/versions/922e7a387820_add_group_table.py +85 -0
  47. backend/open_webui/storage/provider.py +4 -1
  48. backend/open_webui/utils/access_control.py +95 -0
  49. backend/open_webui/utils/pdf_generator.py +12 -12
  50. backend/open_webui/utils/security_headers.py +11 -0
.dockerignore CHANGED
@@ -16,4 +16,5 @@ _old
 uploads
 .ipynb_checkpoints
 **/*.db
-_test
+_test
+backend/data/*
.env.example CHANGED
@@ -1,15 +1,13 @@
 # Ollama URL for the backend to connect
 # The path '/ollama' will be redirected to the specified backend URL
 OLLAMA_BASE_URL='http://localhost:11434'
 
 OPENAI_API_BASE_URL=''
 OPENAI_API_KEY=''
 
 # AUTOMATIC1111_BASE_URL="http://localhost:7860"
 
 # DO NOT TRACK
 SCARF_NO_ANALYTICS=true
 DO_NOT_TRACK=true
 ANONYMIZED_TELEMETRY=false
-
-GLOBAL_LOG_LEVEL="ERROR"
.gitattributes CHANGED
@@ -1,2 +1,3 @@
 *.sh text eol=lf
 *.ttf filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
.github/workflows/deploy-to-hf-spaces.yml CHANGED
@@ -28,6 +28,8 @@ jobs:
     steps:
       - name: Checkout repository
        uses: actions/checkout@v4
+        with:
+          lfs: true
 
       - name: Remove git history
        run: rm -rf .git
@@ -52,7 +54,9 @@ jobs:
       - name: Set up Git and push to Space
        run: |
          git init --initial-branch=main
+          git lfs install
          git lfs track "*.ttf"
+          git lfs track "*.jpg"
          rm demo.gif
          git add .
          git commit -m "GitHub deploy: ${{ github.sha }}"
CHANGELOG.md CHANGED
@@ -5,10 +5,71 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.4.1] - 2024-11-19
+
+### Added
+
+- **🛠️ Tool Descriptions on Hover**: When enabled, tool descriptions now appear upon hovering over the tool icon in the message input, giving you more context instantly and improving workflow fluidity.
+
+### Fixed
+
+- **🚫 Graceful Handling of Deleted Users**: Resolved an issue where deleted users caused models, knowledge, prompts, or tools to fail loading in the workspace, ensuring smoother operation and fewer interruptions.
+- **🔗 Proxy Fix for HTTPS Models Endpoint**: Fixed issues with proxies affecting the secure `/api/v1/models/` endpoint, ensuring stable connections and reliable access.
+- **🔒 API Key Creation**: Addressed a bug that previously prevented API keys from being created.
+
+## [0.4.0] - 2024-11-19
+
+### Added
+
+- **👥 User Groups**: You can now create and manage user groups, making user organization seamless.
+- **🔐 Group-Based Access Control**: Set granular access to models, knowledge, prompts, and tools based on user groups, allowing for more controlled and secure environments.
+- **🛠️ Group-Based User Permissions**: Easily manage workspace permissions. Grant users the ability to upload files, delete, edit, or create temporary chats, as well as define their ability to create models, knowledge, prompts, and tools.
+- **🔑 LDAP Support**: Newly introduced LDAP authentication adds robust security and scalability to user management.
+- **🌐 Enhanced OpenAI-Compatible Connections**: Added prefix ID support to avoid model ID clashes, with explicit model ID support for APIs lacking '/models' endpoint support, ensuring smooth operation with custom setups.
+- **🔐 Ollama API Key Support**: Now manage credentials for Ollama when set behind proxies, including the option to utilize prefix ID for proper distinction across multiple Ollama instances.
+- **🔄 Connection Enable/Disable Toggle**: Easily enable or disable individual OpenAI and Ollama connections as needed.
+- **🎨 Redesigned Model Workspace**: Freshly redesigned to improve usability for managing models across users and groups.
+- **🎨 Redesigned Prompt Workspace**: A fresh UI to conveniently organize and manage prompts.
+- **🧩 Sorted Functions Workspace**: Functions are now automatically categorized by type (Action, Filter, Pipe), streamlining management.
+- **💻 Redesigned Collaborative Workspace**: Enhanced support for multiple users contributing to models, knowledge, prompts, or tools, improving collaboration.
+- **🔧 Auto-Selected Tools in Model Editor**: Tools enabled through the model editor are now automatically selected, whereas previously it only gave users the option to enable the tool, reducing manual steps and enhancing efficiency.
+- **🔔 Web Search & Tools Indicator**: A clear indication now shows when web search or tools are active, reducing confusion.
+- **🔑 Toggle API Key Auth**: Tighten security by easily enabling or disabling API key authentication option for Open WebUI.
+- **🗂️ Agentic Retrieval**: Improve RAG accuracy via smart pre-processing of chat history to determine the best queries before retrieval.
+- **📁 Large Text as File Option**: Optionally convert large pasted text into a file upload, keeping the chat interface cleaner.
+- **🗂️ Toggle Citations for Models**: Ability to disable citations has been introduced in the model editor.
+- **🔍 User Settings Search**: Quickly search for settings fields, improving ease of use and navigation.
+- **🗣️ Experimental SpeechT5 TTS**: Local SpeechT5 support added for improved text-to-speech capabilities.
+- **🔄 Unified Reset for Models**: A one-click option has been introduced to reset and remove all models from the Admin Settings.
+- **🛠️ Initial Setup Wizard**: The setup process now explicitly informs users that they are creating an admin account during the first-time setup, ensuring clarity. Previously, users encountered the login page right away without this distinction.
+- **🌐 Enhanced Translations**: Several language translations, including Ukrainian, Norwegian, and Brazilian Portuguese, were refined for better localization.
+
+### Fixed
+
+- **🎥 YouTube Video Attachments**: Fixed issues preventing proper loading and attachment of YouTube videos as files.
+- **🔄 Shared Chat Update**: Corrected issues where shared chats were not updating, improving collaboration consistency.
+- **🔍 DuckDuckGo Rate Limit Fix**: Addressed issues with DuckDuckGo search integration, enhancing search stability and performance when operating within rate limits.
+- **🧾 Citations Relevance Fix**: Adjusted the relevance percentage calculation for citations, so that Open WebUI properly reflect the accuracy of a retrieved document in RAG, ensuring users get clearer insights into sources.
+- **🔑 Jina Search API Key Requirement**: Added the option to input an API key for Jina Search, ensuring smooth functionality as keys are now mandatory.
+
+### Changed
+
+- **🛠️ Functions Moved to Admin Panel**: As Functions operate as advanced plugins, they are now accessible from the Admin Panel instead of the workspace.
+- **🛠️ Manage Ollama Connections**: The "Models" section in Admin Settings has been relocated to Admin Settings > "Connections" > Ollama Connections. You can now manage Ollama instances via a dedicated "Manage Ollama" modal from "Connections", streamlining the setup and configuration of Ollama models.
+- **📊 Base Models in Admin Settings**: Admins can now find all base models, both connections or functions, in the "Models" Admin setting. Global model accessibility can be enabled or disabled here. Models are private by default, requiring explicit permission assignment for user access.
+- **📌 Sticky Model Selection for New Chats**: The model chosen from a previous chat now persists when creating a new chat. If you click "New Chat" again from the new chat page, it will revert to your default model.
+- **🎨 Design Refactoring**: Overall design refinements across the platform have been made, providing a more cohesive and polished user experience.
+
+### Removed
+
+- **📂 Model List Reordering**: Temporarily removed and will be reintroduced in upcoming user group settings improvements.
+- **⚙️ Default Model Setting**: Removed the ability to set a default model for users, will be reintroduced with user group settings in the future.
+
 ## [0.3.35] - 2024-10-26
 
 ### Added
 
+- **🌐 Translation Update**: Added translation labels in the SearchInput and CreateCollection components and updated Brazilian Portuguese translation (pt-BR)
 - **📁 Robust File Handling**: Enhanced file input handling for chat. If the content extraction fails or is empty, users will now receive a clear warning, preventing silent failures and ensuring you always know what's happening with your uploads.
 - **🌍 New Language Support**: Introduced Hungarian translations and updated French translations, expanding the platform's language accessibility for a more global user base.
 
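The 0.4.0 "Ollama API Key Support" and "Manage Ollama Connections" entries map onto the `backend/open_webui/apps/ollama/main.py` changes further down in this diff, which add an `OLLAMA_API_CONFIGS` setting keyed by base URL. As a rough illustration only: the endpoint path and the `key` field come from that diff, while the admin token value, the mount prefix `/ollama`, and the `prefix_id` field name are assumptions inferred from the "prefix ID" wording above.

```python
import requests

OPEN_WEBUI_URL = "http://localhost:3000"  # assumed local instance
ADMIN_TOKEN = "sk-..."                    # placeholder admin API key / JWT

payload = {
    "ENABLE_OLLAMA_API": True,
    "OLLAMA_BASE_URLS": ["http://localhost:11434", "https://ollama.example.com"],
    "OLLAMA_API_CONFIGS": {
        # Per-URL settings; "key" is forwarded as a Bearer token to the
        # proxied Ollama instance (see the diff below). "prefix_id" is an
        # assumption based on the changelog's "prefix ID" wording.
        "https://ollama.example.com": {"key": "my-proxy-key", "prefix_id": "remote"},
    },
}

# Assuming the Ollama sub-app is mounted under /ollama as in earlier releases.
r = requests.post(
    f"{OPEN_WEBUI_URL}/ollama/config/update",
    json=payload,
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)
print(r.json())
```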
Dockerfile CHANGED
@@ -1,177 +1,176 @@
 # syntax=docker/dockerfile:1
 # Initialize device type args
 # use build args in the docker build command with --build-arg="BUILDARG=true"
 ARG USE_CUDA=false
 ARG USE_OLLAMA=false
 # Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
 ARG USE_CUDA_VER=cu121
 # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
 # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
 # for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
 # IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
 ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
 ARG USE_RERANKING_MODEL=""
 
 # Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken
 ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base"
 
 ARG BUILD_HASH=dev-build
 # Override at your own risk - non-root configurations are untested
 ARG UID=0
 ARG GID=0
 
 ######## WebUI frontend ########
 FROM --platform=$BUILDPLATFORM node:22-alpine3.20 AS build
 ARG BUILD_HASH
 
 WORKDIR /app
 
 COPY package.json package-lock.json ./
 RUN npm ci
 
 COPY . .
 ENV APP_BUILD_HASH=${BUILD_HASH}
 RUN npm run build
 
 ######## WebUI backend ########
 FROM python:3.11-slim-bookworm AS base
 
 # Use args
 ARG USE_CUDA
 ARG USE_OLLAMA
 ARG USE_CUDA_VER
 ARG USE_EMBEDDING_MODEL
 ARG USE_RERANKING_MODEL
 ARG UID
 ARG GID
 
 ## Basis ##
 ENV ENV=prod \
     PORT=8080 \
     # pass build args to the build
     USE_OLLAMA_DOCKER=${USE_OLLAMA} \
     USE_CUDA_DOCKER=${USE_CUDA} \
     USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
     USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
     USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}
 
 ## Basis URL Config ##
 ENV OLLAMA_BASE_URL="/ollama" \
     OPENAI_API_BASE_URL=""
 
 ## API Key and Security Config ##
 ENV OPENAI_API_KEY="" \
     WEBUI_SECRET_KEY="" \
     SCARF_NO_ANALYTICS=true \
     DO_NOT_TRACK=true \
     ANONYMIZED_TELEMETRY=false
 
 #### Other models #########################################################
 ## whisper TTS model settings ##
 ENV WHISPER_MODEL="base" \
     WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
 
 ## RAG Embedding model settings ##
 ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
     RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
     SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
 
 ## Tiktoken model settings ##
 ENV TIKTOKEN_ENCODING_NAME="cl100k_base" \
     TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken"
 
 ## Hugging Face download cache ##
 ENV HF_HOME="/app/backend/data/cache/embedding/models"
 
 ## Torch Extensions ##
 # ENV TORCH_EXTENSIONS_DIR="/.cache/torch_extensions"
 
 #### Other models ##########################################################
 
 WORKDIR /app/backend
 
 ENV HOME=/root
 # Create user and group if not root
 RUN if [ $UID -ne 0 ]; then \
     if [ $GID -ne 0 ]; then \
     addgroup --gid $GID app; \
     fi; \
     adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \
     fi
 
 RUN mkdir -p $HOME/.cache/chroma
 RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry_user_id
 
 # Make sure the user has access to the app and root directory
 RUN chown -R $UID:$GID /app $HOME
 
 RUN if [ "$USE_OLLAMA" = "true" ]; then \
     apt-get update && \
     # Install pandoc and netcat
     apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \
     apt-get install -y --no-install-recommends gcc python3-dev && \
     # for RAG OCR
     apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
     # install helper tools
     apt-get install -y --no-install-recommends curl jq && \
     # install ollama
     curl -fsSL https://ollama.com/install.sh | sh && \
     # cleanup
     rm -rf /var/lib/apt/lists/*; \
     else \
     apt-get update && \
     # Install pandoc, netcat and gcc
     apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \
     apt-get install -y --no-install-recommends gcc python3-dev && \
     # for RAG OCR
     apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
     # cleanup
     rm -rf /var/lib/apt/lists/*; \
     fi
 
 # install python dependencies
 COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt
 
 RUN pip3 install uv && \
     if [ "$USE_CUDA" = "true" ]; then \
     # If you use CUDA the whisper and embedding model will be downloaded on first use
     pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
     uv pip install --system -r requirements.txt --no-cache-dir && \
     python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
     python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
     python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
     else \
     pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
     uv pip install --system -r requirements.txt --no-cache-dir && \
     python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
     python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
     python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
     fi; \
     chown -R $UID:$GID /app/backend/data/
 
 
 
 # copy embedding weight from build
 # RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
 # COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx
 
 # copy built frontend files
 COPY --chown=$UID:$GID --from=build /app/build /app/build
 COPY --chown=$UID:$GID --from=build /app/CHANGELOG.md /app/CHANGELOG.md
 COPY --chown=$UID:$GID --from=build /app/package.json /app/package.json
 
 # copy backend files
 COPY --chown=$UID:$GID ./backend .
 
 EXPOSE 8080
 
 HEALTHCHECK CMD curl --silent --fail http://localhost:${PORT:-8080}/health | jq -ne 'input.status == true' || exit 1
 
 USER $UID:$GID
 
 ARG BUILD_HASH
 ENV WEBUI_BUILD_VERSION=${BUILD_HASH}
 ENV DOCKER=true
-ENV GLOBAL_LOG_LEVEL="ERROR"
 
 CMD [ "bash", "start.sh"]
README.md CHANGED
@@ -45,7 +45,7 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-
 
 - 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
 
-- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch` and `SearchApi` and inject the results directly into your chat experience.
+- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch`, `SearchApi` and `Bing` and inject the results directly into your chat experience.
 
 - 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
 
@@ -195,18 +195,6 @@ docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --a
 
 Discover upcoming features on our roadmap in the [Open WebUI Documentation](https://docs.openwebui.com/roadmap/).
 
-## Supporters ✨
-
-A big shoutout to our amazing supporters who's helping to make this project possible! 🙏
-
-### Platinum Sponsors 🤍
-
-- We're looking for Sponsors!
-
-### Acknowledgments
-
-Special thanks to [Prof. Lawrence Kim](https://www.lhkim.com/) and [Prof. Nick Vincent](https://www.nickmvincent.com/) for their invaluable support and guidance in shaping this project into a research endeavor. Grateful for your mentorship throughout the journey! 🙌
-
 ## License 📜
 
 This project is licensed under the [MIT License](LICENSE) - see the [LICENSE](LICENSE) file for details. 📄
backend/open_webui/apps/audio/main.py CHANGED
@@ -32,7 +32,13 @@ from open_webui.config import (
 )
 
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE
+from open_webui.env import (
+    ENV,
+    SRC_LOG_LEVELS,
+    DEVICE_TYPE,
+    ENABLE_FORWARD_USER_INFO_HEADERS,
+)
+
 from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import FileResponse
@@ -47,7 +53,12 @@ MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["AUDIO"])
 
-app = FastAPI()
+app = FastAPI(
+    docs_url="/docs" if ENV == "dev" else None,
+    openapi_url="/openapi.json" if ENV == "dev" else None,
+    redoc_url=None,
+)
+
 app.add_middleware(
     CORSMiddleware,
     allow_origins=CORS_ALLOW_ORIGIN,
@@ -74,6 +85,10 @@ app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
 app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
 app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
 
+
+app.state.speech_synthesiser = None
+app.state.speech_speaker_embeddings_dataset = None
+
 app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
 app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
 
@@ -231,6 +246,21 @@ async def update_audio_config(
     }
 
 
+def load_speech_pipeline():
+    from transformers import pipeline
+    from datasets import load_dataset
+
+    if app.state.speech_synthesiser is None:
+        app.state.speech_synthesiser = pipeline(
+            "text-to-speech", "microsoft/speecht5_tts"
+        )
+
+    if app.state.speech_speaker_embeddings_dataset is None:
+        app.state.speech_speaker_embeddings_dataset = load_dataset(
+            "Matthijs/cmu-arctic-xvectors", split="validation"
+        )
+
+
 @app.post("/speech")
 async def speech(request: Request, user=Depends(get_verified_user)):
     body = await request.body()
@@ -248,6 +278,12 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
         headers["Content-Type"] = "application/json"
 
+        if ENABLE_FORWARD_USER_INFO_HEADERS:
+            headers["X-OpenWebUI-User-Name"] = user.name
+            headers["X-OpenWebUI-User-Id"] = user.id
+            headers["X-OpenWebUI-User-Email"] = user.email
+            headers["X-OpenWebUI-User-Role"] = user.role
+
         try:
             body = body.decode("utf-8")
             body = json.loads(body)
@@ -391,6 +427,43 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             raise HTTPException(
                 status_code=500, detail=f"Error synthesizing speech - {response.reason}"
             )
+    elif app.state.config.TTS_ENGINE == "transformers":
+        payload = None
+        try:
+            payload = json.loads(body.decode("utf-8"))
+        except Exception as e:
+            log.exception(e)
+            raise HTTPException(status_code=400, detail="Invalid JSON payload")
+
+        import torch
+        import soundfile as sf
+
+        load_speech_pipeline()
+
+        embeddings_dataset = app.state.speech_speaker_embeddings_dataset
+
+        speaker_index = 6799
+        try:
+            speaker_index = embeddings_dataset["filename"].index(
+                app.state.config.TTS_MODEL
+            )
+        except Exception:
+            pass
+
+        speaker_embedding = torch.tensor(
+            embeddings_dataset[speaker_index]["xvector"]
+        ).unsqueeze(0)
+
+        speech = app.state.speech_synthesiser(
+            payload["input"],
+            forward_params={"speaker_embeddings": speaker_embedding},
+        )
+
+        sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
+        with open(file_body_path, "w") as f:
+            json.dump(json.loads(body.decode("utf-8")), f)
+
+        return FileResponse(file_path)
 
 
 def transcribe(file_path):
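The new `transformers` branch above is essentially the stock SpeechT5 text-to-speech recipe. A minimal standalone sketch of the same flow follows; the model name, dataset, and default speaker index 6799 are taken from the diff above, while the input text and output filename are arbitrary placeholders.

```python
# Standalone sketch of the SpeechT5 path added in the diff above.
# Requires: transformers, datasets, torch, soundfile.
import torch
import soundfile as sf
from datasets import load_dataset
from transformers import pipeline

# Same pipeline and speaker-embedding dataset as load_speech_pipeline() above.
synthesiser = pipeline("text-to-speech", "microsoft/speecht5_tts")
embeddings = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")

# 6799 is the fallback speaker index used in the diff; the endpoint tries to
# look up TTS_MODEL in the dataset's "filename" column first.
speaker_embedding = torch.tensor(embeddings[6799]["xvector"]).unsqueeze(0)

speech = synthesiser(
    "Hello from Open WebUI.",  # placeholder input text
    forward_params={"speaker_embeddings": speaker_embedding},
)
sf.write("speech.wav", speech["audio"], samplerate=speech["sampling_rate"])
```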
backend/open_webui/apps/images/main.py CHANGED
@@ -35,7 +35,8 @@ from open_webui.config import (
     AppConfig,
 )
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import SRC_LOG_LEVELS
+from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
+
 from fastapi import Depends, FastAPI, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
@@ -47,7 +48,12 @@ log.setLevel(SRC_LOG_LEVELS["IMAGES"])
 IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
 IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
 
-app = FastAPI()
+app = FastAPI(
+    docs_url="/docs" if ENV == "dev" else None,
+    openapi_url="/openapi.json" if ENV == "dev" else None,
+    redoc_url=None,
+)
+
 app.add_middleware(
     CORSMiddleware,
     allow_origins=CORS_ALLOW_ORIGIN,
@@ -456,6 +462,12 @@ async def image_generations(
         headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
         headers["Content-Type"] = "application/json"
 
+        if ENABLE_FORWARD_USER_INFO_HEADERS:
+            headers["X-OpenWebUI-User-Name"] = user.name
+            headers["X-OpenWebUI-User-Id"] = user.id
+            headers["X-OpenWebUI-User-Email"] = user.email
+            headers["X-OpenWebUI-User-Role"] = user.role
+
         data = {
             "model": (
                 app.state.config.MODEL
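Both the audio and image apps now attach the requesting user's identity to outbound requests when `ENABLE_FORWARD_USER_INFO_HEADERS` is set. The sketch below is an illustrative downstream receiver only (not part of Open WebUI): the header names come from the diffs above, while the route path and the per-user handling are assumptions.

```python
# Hypothetical gateway that reads the headers Open WebUI forwards when
# ENABLE_FORWARD_USER_INFO_HEADERS is enabled.
from typing import Optional

from fastapi import FastAPI, Header

app = FastAPI()


@app.post("/v1/audio/speech")  # placeholder route for a proxied TTS backend
async def speech(
    x_openwebui_user_name: Optional[str] = Header(default=None),
    x_openwebui_user_id: Optional[str] = Header(default=None),
    x_openwebui_user_email: Optional[str] = Header(default=None),
    x_openwebui_user_role: Optional[str] = Header(default=None),
):
    # A downstream gateway could audit or rate-limit per user here.
    return {"user": x_openwebui_user_id, "role": x_openwebui_user_role}
```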
backend/open_webui/apps/ollama/main.py CHANGED
@@ -13,18 +13,20 @@ import requests
13
  from open_webui.apps.webui.models.models import Models
14
  from open_webui.config import (
15
  CORS_ALLOW_ORIGIN,
16
- ENABLE_MODEL_FILTER,
17
  ENABLE_OLLAMA_API,
18
- MODEL_FILTER_LIST,
19
  OLLAMA_BASE_URLS,
 
20
  UPLOAD_DIR,
21
  AppConfig,
22
  )
23
- from open_webui.env import AIOHTTP_CLIENT_TIMEOUT
 
 
 
24
 
25
 
26
  from open_webui.constants import ERROR_MESSAGES
27
- from open_webui.env import SRC_LOG_LEVELS
28
  from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
29
  from fastapi.middleware.cors import CORSMiddleware
30
  from fastapi.responses import StreamingResponse
@@ -41,11 +43,18 @@ from open_webui.utils.payload import (
41
  apply_model_system_prompt_to_body,
42
  )
43
  from open_webui.utils.utils import get_admin_user, get_verified_user
 
44
 
45
  log = logging.getLogger(__name__)
46
  log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
47
 
48
- app = FastAPI()
 
 
 
 
 
 
49
  app.add_middleware(
50
  CORSMiddleware,
51
  allow_origins=CORS_ALLOW_ORIGIN,
@@ -56,12 +65,9 @@ app.add_middleware(
56
 
57
  app.state.config = AppConfig()
58
 
59
- app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
60
- app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
61
-
62
  app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
63
  app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
64
- app.state.MODELS = {}
65
 
66
 
67
  # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
@@ -69,60 +75,98 @@ app.state.MODELS = {}
69
  # least connections, or least response time for better resource utilization and performance optimization.
70
 
71
 
72
- @app.middleware("http")
73
- async def check_url(request: Request, call_next):
74
- if len(app.state.MODELS) == 0:
75
- await get_all_models()
76
- else:
77
- pass
78
-
79
- response = await call_next(request)
80
- return response
81
-
82
-
83
  @app.head("/")
84
  @app.get("/")
85
  async def get_status():
86
  return {"status": True}
87
 
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  @app.get("/config")
90
  async def get_config(user=Depends(get_admin_user)):
91
- return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}
 
 
 
 
92
 
93
 
94
  class OllamaConfigForm(BaseModel):
95
- enable_ollama_api: Optional[bool] = None
 
 
96
 
97
 
98
  @app.post("/config/update")
99
  async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
100
- app.state.config.ENABLE_OLLAMA_API = form_data.enable_ollama_api
101
- return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}
102
-
103
-
104
- @app.get("/urls")
105
- async def get_ollama_api_urls(user=Depends(get_admin_user)):
106
- return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
107
-
108
 
109
- class UrlUpdateForm(BaseModel):
110
- urls: list[str]
111
 
 
 
 
 
 
112
 
113
- @app.post("/urls/update")
114
- async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
115
- app.state.config.OLLAMA_BASE_URLS = form_data.urls
 
 
116
 
117
- log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}")
118
- return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
119
 
120
-
121
- async def fetch_url(url):
122
- timeout = aiohttp.ClientTimeout(total=3)
123
  try:
 
124
  async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
125
- async with session.get(url) as response:
126
  return await response.json()
127
  except Exception as e:
128
  # Handle connection error here
@@ -148,10 +192,18 @@ async def post_streaming_url(
148
  session = aiohttp.ClientSession(
149
  trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
150
  )
 
 
 
 
 
 
 
 
151
  r = await session.post(
152
  url,
153
  data=payload,
154
- headers={"Content-Type": "application/json"},
155
  )
156
  r.raise_for_status()
157
 
@@ -194,29 +246,62 @@ def merge_models_lists(model_lists):
194
  for idx, model_list in enumerate(model_lists):
195
  if model_list is not None:
196
  for model in model_list:
197
- digest = model["digest"]
198
- if digest not in merged_models:
199
  model["urls"] = [idx]
200
- merged_models[digest] = model
201
  else:
202
- merged_models[digest]["urls"].append(idx)
203
 
204
  return list(merged_models.values())
205
 
206
 
207
  async def get_all_models():
208
  log.info("get_all_models()")
209
-
210
  if app.state.config.ENABLE_OLLAMA_API:
211
- tasks = [
212
- fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS
213
- ]
 
 
 
 
 
 
 
 
 
 
 
214
  responses = await asyncio.gather(*tasks)
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  models = {
217
  "models": merge_models_lists(
218
  map(
219
- lambda response: response["models"] if response else None, responses
 
220
  )
221
  )
222
  }
@@ -224,8 +309,6 @@ async def get_all_models():
224
  else:
225
  models = {"models": []}
226
 
227
- app.state.MODELS = {model["model"]: model for model in models["models"]}
228
-
229
  return models
230
 
231
 
@@ -234,29 +317,25 @@ async def get_all_models():
234
  async def get_ollama_tags(
235
  url_idx: Optional[int] = None, user=Depends(get_verified_user)
236
  ):
 
237
  if url_idx is None:
238
  models = await get_all_models()
239
-
240
- if app.state.config.ENABLE_MODEL_FILTER:
241
- if user.role == "user":
242
- models["models"] = list(
243
- filter(
244
- lambda model: model["name"]
245
- in app.state.config.MODEL_FILTER_LIST,
246
- models["models"],
247
- )
248
- )
249
- return models
250
- return models
251
  else:
252
  url = app.state.config.OLLAMA_BASE_URLS[url_idx]
253
 
 
 
 
 
 
 
 
254
  r = None
255
  try:
256
- r = requests.request(method="GET", url=f"{url}/api/tags")
257
  r.raise_for_status()
258
 
259
- return r.json()
260
  except Exception as e:
261
  log.exception(e)
262
  error_detail = "Open WebUI: Server Connection Error"
@@ -273,6 +352,20 @@ async def get_ollama_tags(
273
  detail=error_detail,
274
  )
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
  @app.get("/api/version")
278
  @app.get("/api/version/{url_idx}")
@@ -281,7 +374,10 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
281
  if url_idx is None:
282
  # returns lowest version
283
  tasks = [
284
- fetch_url(f"{url}/api/version")
 
 
 
285
  for url in app.state.config.OLLAMA_BASE_URLS
286
  ]
287
  responses = await asyncio.gather(*tasks)
@@ -361,8 +457,11 @@ async def push_model(
361
  user=Depends(get_admin_user),
362
  ):
363
  if url_idx is None:
364
- if form_data.name in app.state.MODELS:
365
- url_idx = app.state.MODELS[form_data.name]["urls"][0]
 
 
 
366
  else:
367
  raise HTTPException(
368
  status_code=400,
@@ -411,8 +510,11 @@ async def copy_model(
411
  user=Depends(get_admin_user),
412
  ):
413
  if url_idx is None:
414
- if form_data.source in app.state.MODELS:
415
- url_idx = app.state.MODELS[form_data.source]["urls"][0]
 
 
 
416
  else:
417
  raise HTTPException(
418
  status_code=400,
@@ -421,10 +523,18 @@ async def copy_model(
421
 
422
  url = app.state.config.OLLAMA_BASE_URLS[url_idx]
423
  log.info(f"url: {url}")
 
 
 
 
 
 
 
 
424
  r = requests.request(
425
  method="POST",
426
  url=f"{url}/api/copy",
427
- headers={"Content-Type": "application/json"},
428
  data=form_data.model_dump_json(exclude_none=True).encode(),
429
  )
430
 
@@ -459,8 +569,11 @@ async def delete_model(
459
  user=Depends(get_admin_user),
460
  ):
461
  if url_idx is None:
462
- if form_data.name in app.state.MODELS:
463
- url_idx = app.state.MODELS[form_data.name]["urls"][0]
 
 
 
464
  else:
465
  raise HTTPException(
466
  status_code=400,
@@ -470,11 +583,18 @@ async def delete_model(
470
  url = app.state.config.OLLAMA_BASE_URLS[url_idx]
471
  log.info(f"url: {url}")
472
 
 
 
 
 
 
 
 
473
  r = requests.request(
474
  method="DELETE",
475
  url=f"{url}/api/delete",
476
- headers={"Content-Type": "application/json"},
477
  data=form_data.model_dump_json(exclude_none=True).encode(),
 
478
  )
479
  try:
480
  r.raise_for_status()
@@ -501,20 +621,30 @@ async def delete_model(
501
 
502
  @app.post("/api/show")
503
  async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
504
- if form_data.name not in app.state.MODELS:
 
 
 
505
  raise HTTPException(
506
  status_code=400,
507
  detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
508
  )
509
 
510
- url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
511
  url = app.state.config.OLLAMA_BASE_URLS[url_idx]
512
  log.info(f"url: {url}")
513
 
 
 
 
 
 
 
 
514
  r = requests.request(
515
  method="POST",
516
  url=f"{url}/api/show",
517
- headers={"Content-Type": "application/json"},
518
  data=form_data.model_dump_json(exclude_none=True).encode(),
519
  )
520
  try:
@@ -570,23 +700,26 @@ async def generate_embeddings(
570
  url_idx: Optional[int] = None,
571
  user=Depends(get_verified_user),
572
  ):
573
- return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
574
 
575
 
576
- def generate_ollama_embeddings(
577
  form_data: GenerateEmbeddingsForm,
578
  url_idx: Optional[int] = None,
579
  ):
580
  log.info(f"generate_ollama_embeddings {form_data}")
581
 
582
  if url_idx is None:
 
 
 
583
  model = form_data.model
584
 
585
  if ":" not in model:
586
  model = f"{model}:latest"
587
 
588
- if model in app.state.MODELS:
589
- url_idx = random.choice(app.state.MODELS[model]["urls"])
590
  else:
591
  raise HTTPException(
592
  status_code=400,
@@ -596,10 +729,17 @@ def generate_ollama_embeddings(
596
  url = app.state.config.OLLAMA_BASE_URLS[url_idx]
597
  log.info(f"url: {url}")
598
 
 
 
 
 
 
 
 
599
  r = requests.request(
600
  method="POST",
601
  url=f"{url}/api/embeddings",
602
- headers={"Content-Type": "application/json"},
603
  data=form_data.model_dump_json(exclude_none=True).encode(),
604
  )
605
  try:
@@ -630,20 +770,23 @@ def generate_ollama_embeddings(
630
  )
631
 
632
 
633
- def generate_ollama_batch_embeddings(
634
  form_data: GenerateEmbedForm,
635
  url_idx: Optional[int] = None,
636
  ):
637
  log.info(f"generate_ollama_batch_embeddings {form_data}")
638
 
639
  if url_idx is None:
 
 
 
640
  model = form_data.model
641
 
642
  if ":" not in model:
643
  model = f"{model}:latest"
644
 
645
- if model in app.state.MODELS:
646
- url_idx = random.choice(app.state.MODELS[model]["urls"])
647
  else:
648
  raise HTTPException(
649
  status_code=400,
@@ -653,10 +796,17 @@ def generate_ollama_batch_embeddings(
653
  url = app.state.config.OLLAMA_BASE_URLS[url_idx]
654
  log.info(f"url: {url}")
655
 
 
 
 
 
 
 
 
656
  r = requests.request(
657
  method="POST",
658
  url=f"{url}/api/embed",
659
- headers={"Content-Type": "application/json"},
660
  data=form_data.model_dump_json(exclude_none=True).encode(),
661
  )
662
  try:
@@ -706,13 +856,16 @@ async def generate_completion(
706
  user=Depends(get_verified_user),
707
  ):
708
  if url_idx is None:
 
 
 
709
  model = form_data.model
710
 
711
  if ":" not in model:
712
  model = f"{model}:latest"
713
 
714
- if model in app.state.MODELS:
715
- url_idx = random.choice(app.state.MODELS[model]["urls"])
716
  else:
717
  raise HTTPException(
718
  status_code=400,
@@ -720,6 +873,10 @@ async def generate_completion(
720
  )
721
 
722
  url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 
 
 
 
723
  log.info(f"url: {url}")
724
 
725
  return await post_streaming_url(
@@ -743,14 +900,17 @@ class GenerateChatCompletionForm(BaseModel):
743
  keep_alive: Optional[Union[int, str]] = None
744
 
745
 
746
- def get_ollama_url(url_idx: Optional[int], model: str):
747
  if url_idx is None:
748
- if model not in app.state.MODELS:
 
 
 
749
  raise HTTPException(
750
  status_code=400,
751
  detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
752
  )
753
- url_idx = random.choice(app.state.MODELS[model]["urls"])
754
  url = app.state.config.OLLAMA_BASE_URLS[url_idx]
755
  return url
756
 
@@ -768,15 +928,7 @@ async def generate_chat_completion(
768
  if "metadata" in payload:
769
  del payload["metadata"]
770
 
771
- model_id = form_data.model
772
-
773
- if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER:
774
- if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
775
- raise HTTPException(
776
- status_code=403,
777
- detail="Model not found",
778
- )
779
-
780
  model_info = Models.get_model_by_id(model_id)
781
 
782
  if model_info:
@@ -794,13 +946,37 @@ async def generate_chat_completion(
794
  )
795
  payload = apply_model_system_prompt_to_body(params, payload, user)
796
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
797
  if ":" not in payload["model"]:
798
  payload["model"] = f"{payload['model']}:latest"
799
 
800
- url = get_ollama_url(url_idx, payload["model"])
801
  log.info(f"url: {url}")
802
  log.debug(f"generate_chat_completion() - 2.payload = {payload}")
803
 
 
 
 
 
 
804
  return await post_streaming_url(
805
  f"{url}/api/chat",
806
  json.dumps(payload),
@@ -817,7 +993,7 @@ class OpenAIChatMessageContent(BaseModel):
817
 
818
  class OpenAIChatMessage(BaseModel):
819
  role: str
820
- content: Union[str, OpenAIChatMessageContent]
821
 
822
  model_config = ConfigDict(extra="allow")
823
 
@@ -836,22 +1012,24 @@ async def generate_openai_chat_completion(
836
  url_idx: Optional[int] = None,
837
  user=Depends(get_verified_user),
838
  ):
839
- completion_form = OpenAIChatCompletionForm(**form_data)
 
 
 
 
 
 
 
 
840
  payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
841
  if "metadata" in payload:
842
  del payload["metadata"]
843
 
844
  model_id = completion_form.model
845
-
846
- if app.state.config.ENABLE_MODEL_FILTER:
847
- if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
848
- raise HTTPException(
849
- status_code=403,
850
- detail="Model not found",
851
- )
852
 
853
  model_info = Models.get_model_by_id(model_id)
854
-
855
  if model_info:
856
  if model_info.base_model_id:
857
  payload["model"] = model_info.base_model_id
@@ -862,12 +1040,36 @@ async def generate_openai_chat_completion(
862
  payload = apply_model_params_to_body_openai(params, payload)
863
  payload = apply_model_system_prompt_to_body(params, payload, user)
864
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
865
  if ":" not in payload["model"]:
866
  payload["model"] = f"{payload['model']}:latest"
867
 
868
- url = get_ollama_url(url_idx, payload["model"])
869
  log.info(f"url: {url}")
870
 
 
 
 
 
 
871
  return await post_streaming_url(
872
  f"{url}/v1/chat/completions",
873
  json.dumps(payload),
@@ -881,21 +1083,29 @@ async def get_openai_models(
881
  url_idx: Optional[int] = None,
882
  user=Depends(get_verified_user),
883
  ):
 
 
884
  if url_idx is None:
885
- models = await get_all_models()
 
 
 
 
 
 
 
 
 
886
 
887
- if app.state.config.ENABLE_MODEL_FILTER:
888
- if user.role == "user":
889
- models["models"] = list(
890
- filter(
891
- lambda model: model["name"]
892
- in app.state.config.MODEL_FILTER_LIST,
893
- models["models"],
894
- )
895
- )
896
 
897
- return {
898
- "data": [
899
  {
900
  "id": model["model"],
901
  "object": "model",
@@ -903,31 +1113,7 @@ async def get_openai_models(
903
  "owned_by": "openai",
904
  }
905
  for model in models["models"]
906
- ],
907
- "object": "list",
908
- }
909
-
910
- else:
911
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
912
- try:
913
- r = requests.request(method="GET", url=f"{url}/api/tags")
914
- r.raise_for_status()
915
-
916
- models = r.json()
917
-
918
- return {
919
- "data": [
920
- {
921
- "id": model["model"],
922
- "object": "model",
923
- "created": int(time.time()),
924
- "owned_by": "openai",
925
- }
926
- for model in models["models"]
927
- ],
928
- "object": "list",
929
- }
930
-
931
  except Exception as e:
932
  log.exception(e)
933
  error_detail = "Open WebUI: Server Connection Error"
@@ -944,6 +1130,23 @@ async def get_openai_models(
944
  detail=error_detail,
945
  )
946
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
947
 
948
  class UrlForm(BaseModel):
949
  url: str
 
13
  from open_webui.apps.webui.models.models import Models
14
  from open_webui.config import (
15
  CORS_ALLOW_ORIGIN,
 
16
  ENABLE_OLLAMA_API,
 
17
  OLLAMA_BASE_URLS,
18
+ OLLAMA_API_CONFIGS,
19
  UPLOAD_DIR,
20
  AppConfig,
21
  )
22
+ from open_webui.env import (
23
+ AIOHTTP_CLIENT_TIMEOUT,
24
+ AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
25
+ )
26
 
27
 
28
  from open_webui.constants import ERROR_MESSAGES
29
+ from open_webui.env import ENV, SRC_LOG_LEVELS
30
  from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
31
  from fastapi.middleware.cors import CORSMiddleware
32
  from fastapi.responses import StreamingResponse
 
43
  apply_model_system_prompt_to_body,
44
  )
45
  from open_webui.utils.utils import get_admin_user, get_verified_user
46
+ from open_webui.utils.access_control import has_access
47
 
48
  log = logging.getLogger(__name__)
49
  log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
50
 
51
+
52
+ app = FastAPI(
53
+ docs_url="/docs" if ENV == "dev" else None,
54
+ openapi_url="/openapi.json" if ENV == "dev" else None,
55
+ redoc_url=None,
56
+ )
57
+
58
  app.add_middleware(
59
  CORSMiddleware,
60
  allow_origins=CORS_ALLOW_ORIGIN,
 
65
 
66
  app.state.config = AppConfig()
67
 
 
 
 
68
  app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
69
  app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
70
+ app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS
71
 
72
 
73
  # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
 
75
  # least connections, or least response time for better resource utilization and performance optimization.
76
 
77
 
 
 
 
 
 
 
 
 
 
 
 
78
  @app.head("/")
79
  @app.get("/")
80
  async def get_status():
81
  return {"status": True}
82
 
83
 
84
+ class ConnectionVerificationForm(BaseModel):
85
+ url: str
86
+ key: Optional[str] = None
87
+
88
+
89
+ @app.post("/verify")
90
+ async def verify_connection(
91
+ form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
92
+ ):
93
+ url = form_data.url
94
+ key = form_data.key
95
+
96
+ headers = {}
97
+ if key:
98
+ headers["Authorization"] = f"Bearer {key}"
99
+
100
+ timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
101
+ async with aiohttp.ClientSession(timeout=timeout) as session:
102
+ try:
103
+ async with session.get(f"{url}/api/version", headers=headers) as r:
104
+ if r.status != 200:
105
+ # Extract response error details if available
106
+ error_detail = f"HTTP Error: {r.status}"
107
+ res = await r.json()
108
+ if "error" in res:
109
+ error_detail = f"External Error: {res['error']}"
110
+ raise Exception(error_detail)
111
+
112
+ response_data = await r.json()
113
+ return response_data
114
+
115
+ except aiohttp.ClientError as e:
116
+ # ClientError covers all aiohttp request issues
117
+ log.exception(f"Client error: {str(e)}")
118
+ # Handle aiohttp-specific connection issues, timeout etc.
119
+ raise HTTPException(
120
+ status_code=500, detail="Open WebUI: Server Connection Error"
121
+ )
122
+ except Exception as e:
123
+ log.exception(f"Unexpected error: {e}")
124
+ # Generic error handler in case parsing JSON or other steps fail
125
+ error_detail = f"Unexpected error: {str(e)}"
126
+ raise HTTPException(status_code=500, detail=error_detail)
127
+
128
+
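The `/verify` route added above lets an admin confirm that an Ollama host is reachable (and that its optional bearer key works) before saving the connection. A minimal usage sketch, assuming the Ollama sub-app is mounted at `/ollama` as in the default `main.py`, and using placeholder values for the Open WebUI base URL and admin token:

```python
import requests

BASE_URL = "http://localhost:8080"   # placeholder Open WebUI address
ADMIN_TOKEN = "sk-admin-token"       # placeholder admin API token

resp = requests.post(
    f"{BASE_URL}/ollama/verify",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    # "key" is optional; set it only if the Ollama host sits behind an auth proxy.
    json={"url": "http://localhost:11434", "key": None},
)
resp.raise_for_status()
print(resp.json())  # on success this echoes Ollama's /api/version payload
```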
129
  @app.get("/config")
130
  async def get_config(user=Depends(get_admin_user)):
131
+ return {
132
+ "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
133
+ "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
134
+ "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
135
+ }
136
 
137
 
138
  class OllamaConfigForm(BaseModel):
139
+ ENABLE_OLLAMA_API: Optional[bool] = None
140
+ OLLAMA_BASE_URLS: list[str]
141
+ OLLAMA_API_CONFIGS: dict
142
 
143
 
144
  @app.post("/config/update")
145
  async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
146
+ app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
147
+ app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
148
 
149
+ app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
 
150
 
151
+ # Remove any extra configs
152
+ base_urls = app.state.config.OLLAMA_BASE_URLS
153
+ for url in list(app.state.config.OLLAMA_API_CONFIGS.keys()):
154
+ if url not in base_urls:
155
+ app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
156
 
157
+ return {
158
+ "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
159
+ "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
160
+ "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
161
+ }
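The handler above persists one optional config dict per Ollama base URL. The schema is not validated here; judging from how the keys are read elsewhere in this file (`enable`, `key`, `prefix_id`, `model_ids`), a plausible `/config/update` payload looks like the following sketch (URLs and key are placeholders):

```python
payload = {
    "ENABLE_OLLAMA_API": True,
    "OLLAMA_BASE_URLS": ["http://localhost:11434", "https://ollama.internal:11434"],
    "OLLAMA_API_CONFIGS": {
        "https://ollama.internal:11434": {
            "enable": True,                 # set False to skip this host entirely
            "key": "secret-token",          # sent as "Authorization: Bearer <key>"
            "prefix_id": "internal",        # models are exposed as "internal.<model>"
            "model_ids": ["llama3.1:8b"],   # optional allow-list of models
        }
    },
}
```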
162
 
 
 
163
 
164
+ async def aiohttp_get(url, key=None):
165
+ timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
 
166
  try:
167
+ headers = {"Authorization": f"Bearer {key}"} if key else {}
168
  async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
169
+ async with session.get(url, headers=headers) as response:
170
  return await response.json()
171
  except Exception as e:
172
  # Handle connection error here
 
192
  session = aiohttp.ClientSession(
193
  trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
194
  )
195
+
196
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
197
+ key = api_config.get("key", None)
198
+
199
+ headers = {"Content-Type": "application/json"}
200
+ if key:
201
+ headers["Authorization"] = f"Bearer {key}"
202
+
203
  r = await session.post(
204
  url,
205
  data=payload,
206
+ headers=headers,
207
  )
208
  r.raise_for_status()
209
 
 
246
  for idx, model_list in enumerate(model_lists):
247
  if model_list is not None:
248
  for model in model_list:
249
+ id = model["model"]
250
+ if id not in merged_models:
251
  model["urls"] = [idx]
252
+ merged_models[id] = model
253
  else:
254
+ merged_models[id]["urls"].append(idx)
255
 
256
  return list(merged_models.values())
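`merge_models_lists` keeps the first copy of each model and records every backend index that serves it, so one entry can later be routed to any of its hosts. A standalone illustration of that behaviour with made-up data (not code from the diff):

```python
lists = [
    [{"model": "llama3.1:8b"}, {"model": "mistral:7b"}],  # backend 0
    None,                                                 # backend 1 disabled or unreachable
    [{"model": "llama3.1:8b"}],                           # backend 2
]

merged = {}
for idx, models in enumerate(lists):
    for m in models or []:
        merged.setdefault(m["model"], {**m, "urls": []})["urls"].append(idx)

print(list(merged.values()))
# [{'model': 'llama3.1:8b', 'urls': [0, 2]}, {'model': 'mistral:7b', 'urls': [0]}]
```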
257
 
258
 
259
  async def get_all_models():
260
  log.info("get_all_models()")
 
261
  if app.state.config.ENABLE_OLLAMA_API:
262
+ tasks = []
263
+ for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS):
264
+ if url not in app.state.config.OLLAMA_API_CONFIGS:
265
+ tasks.append(aiohttp_get(f"{url}/api/tags"))
266
+ else:
267
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
268
+ enable = api_config.get("enable", True)
269
+ key = api_config.get("key", None)
270
+
271
+ if enable:
272
+ tasks.append(aiohttp_get(f"{url}/api/tags", key))
273
+ else:
274
+ tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
275
+
276
  responses = await asyncio.gather(*tasks)
277
 
278
+ for idx, response in enumerate(responses):
279
+ if response:
280
+ url = app.state.config.OLLAMA_BASE_URLS[idx]
281
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
282
+
283
+ prefix_id = api_config.get("prefix_id", None)
284
+ model_ids = api_config.get("model_ids", [])
285
+
286
+ if len(model_ids) != 0 and "models" in response:
287
+ response["models"] = list(
288
+ filter(
289
+ lambda model: model["model"] in model_ids,
290
+ response["models"],
291
+ )
292
+ )
293
+
294
+ if prefix_id:
295
+ for model in response.get("models", []):
296
+ model["model"] = f"{prefix_id}.{model['model']}"
297
+
298
+ log.debug(f"get_all_models:responses() {responses}")
299
+
300
  models = {
301
  "models": merge_models_lists(
302
  map(
303
+ lambda response: response.get("models", []) if response else None,
304
+ responses,
305
  )
306
  )
307
  }
 
309
  else:
310
  models = {"models": []}
311
 
 
 
312
  return models
313
 
314
 
 
317
  async def get_ollama_tags(
318
  url_idx: Optional[int] = None, user=Depends(get_verified_user)
319
  ):
320
+ models = []
321
  if url_idx is None:
322
  models = await get_all_models()
 
 
323
  else:
324
  url = app.state.config.OLLAMA_BASE_URLS[url_idx]
325
 
326
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
327
+ key = api_config.get("key", None)
328
+
329
+ headers = {}
330
+ if key:
331
+ headers["Authorization"] = f"Bearer {key}"
332
+
333
  r = None
334
  try:
335
+ r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers)
336
  r.raise_for_status()
337
 
338
+ models = r.json()
339
  except Exception as e:
340
  log.exception(e)
341
  error_detail = "Open WebUI: Server Connection Error"
 
352
  detail=error_detail,
353
  )
354
 
355
+ if user.role == "user":
356
+ # Filter models based on user access control
357
+ filtered_models = []
358
+ for model in models.get("models", []):
359
+ model_info = Models.get_model_by_id(model["model"])
360
+ if model_info:
361
+ if user.id == model_info.user_id or has_access(
362
+ user.id, type="read", access_control=model_info.access_control
363
+ ):
364
+ filtered_models.append(model)
365
+ models["models"] = filtered_models
366
+
367
+ return models
368
+
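The filter above keeps a model only when the requesting user owns it or `has_access` grants read permission. A simplified, self-contained sketch of that pattern; `ModelInfo`, `can_read`, and the access-control shape below are illustrative stand-ins, not the actual `open_webui.utils.access_control` implementation:

```python
from dataclasses import dataclass, field

@dataclass
class ModelInfo:                       # minimal stand-in for the DB record
    user_id: str
    access_control: dict = field(default_factory=dict)

def can_read(user_id: str, access_control: dict) -> bool:
    # Hypothetical predicate playing the role of has_access(user.id, type="read", ...)
    return user_id in access_control.get("read", {}).get("user_ids", [])

def filter_models_for_user(models: list[dict], user_id: str, registry: dict) -> list[dict]:
    kept = []
    for m in models:
        info = registry.get(m["model"])
        if info and (user_id == info.user_id or can_read(user_id, info.access_control)):
            kept.append(m)
    return kept

registry = {
    "llama3.1:8b": ModelInfo(user_id="alice"),
    "mistral:7b": ModelInfo(user_id="bob", access_control={"read": {"user_ids": ["alice"]}}),
}
models = [{"model": "llama3.1:8b"}, {"model": "mistral:7b"}, {"model": "phi3:mini"}]
print(filter_models_for_user(models, "alice", registry))
# both registered models pass for "alice"; the unregistered one is dropped
```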
369
 
370
  @app.get("/api/version")
371
  @app.get("/api/version/{url_idx}")
 
374
  if url_idx is None:
375
  # returns lowest version
376
  tasks = [
377
+ aiohttp_get(
378
+ f"{url}/api/version",
379
+ app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
380
+ )
381
  for url in app.state.config.OLLAMA_BASE_URLS
382
  ]
383
  responses = await asyncio.gather(*tasks)
 
457
  user=Depends(get_admin_user),
458
  ):
459
  if url_idx is None:
460
+ model_list = await get_all_models()
461
+ models = {model["model"]: model for model in model_list["models"]}
462
+
463
+ if form_data.name in models:
464
+ url_idx = models[form_data.name]["urls"][0]
465
  else:
466
  raise HTTPException(
467
  status_code=400,
 
510
  user=Depends(get_admin_user),
511
  ):
512
  if url_idx is None:
513
+ model_list = await get_all_models()
514
+ models = {model["model"]: model for model in model_list["models"]}
515
+
516
+ if form_data.source in models:
517
+ url_idx = models[form_data.source]["urls"][0]
518
  else:
519
  raise HTTPException(
520
  status_code=400,
 
523
 
524
  url = app.state.config.OLLAMA_BASE_URLS[url_idx]
525
  log.info(f"url: {url}")
526
+
527
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
528
+ key = api_config.get("key", None)
529
+
530
+ headers = {"Content-Type": "application/json"}
531
+ if key:
532
+ headers["Authorization"] = f"Bearer {key}"
533
+
534
  r = requests.request(
535
  method="POST",
536
  url=f"{url}/api/copy",
537
+ headers=headers,
538
  data=form_data.model_dump_json(exclude_none=True).encode(),
539
  )
540
 
 
569
  user=Depends(get_admin_user),
570
  ):
571
  if url_idx is None:
572
+ model_list = await get_all_models()
573
+ models = {model["model"]: model for model in model_list["models"]}
574
+
575
+ if form_data.name in models:
576
+ url_idx = models[form_data.name]["urls"][0]
577
  else:
578
  raise HTTPException(
579
  status_code=400,
 
583
  url = app.state.config.OLLAMA_BASE_URLS[url_idx]
584
  log.info(f"url: {url}")
585
 
586
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
587
+ key = api_config.get("key", None)
588
+
589
+ headers = {"Content-Type": "application/json"}
590
+ if key:
591
+ headers["Authorization"] = f"Bearer {key}"
592
+
593
  r = requests.request(
594
  method="DELETE",
595
  url=f"{url}/api/delete",
 
596
  data=form_data.model_dump_json(exclude_none=True).encode(),
597
+ headers=headers,
598
  )
599
  try:
600
  r.raise_for_status()
 
621
 
622
  @app.post("/api/show")
623
  async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
624
+ model_list = await get_all_models()
625
+ models = {model["model"]: model for model in model_list["models"]}
626
+
627
+ if form_data.name not in models:
628
  raise HTTPException(
629
  status_code=400,
630
  detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
631
  )
632
 
633
+ url_idx = random.choice(models[form_data.name]["urls"])
634
  url = app.state.config.OLLAMA_BASE_URLS[url_idx]
635
  log.info(f"url: {url}")
636
 
637
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
638
+ key = api_config.get("key", None)
639
+
640
+ headers = {"Content-Type": "application/json"}
641
+ if key:
642
+ headers["Authorization"] = f"Bearer {key}"
643
+
644
  r = requests.request(
645
  method="POST",
646
  url=f"{url}/api/show",
647
+ headers=headers,
648
  data=form_data.model_dump_json(exclude_none=True).encode(),
649
  )
650
  try:
 
700
  url_idx: Optional[int] = None,
701
  user=Depends(get_verified_user),
702
  ):
703
+ return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
704
 
705
 
706
+ async def generate_ollama_embeddings(
707
  form_data: GenerateEmbeddingsForm,
708
  url_idx: Optional[int] = None,
709
  ):
710
  log.info(f"generate_ollama_embeddings {form_data}")
711
 
712
  if url_idx is None:
713
+ model_list = await get_all_models()
714
+ models = {model["model"]: model for model in model_list["models"]}
715
+
716
  model = form_data.model
717
 
718
  if ":" not in model:
719
  model = f"{model}:latest"
720
 
721
+ if model in models:
722
+ url_idx = random.choice(models[model]["urls"])
723
  else:
724
  raise HTTPException(
725
  status_code=400,
 
729
  url = app.state.config.OLLAMA_BASE_URLS[url_idx]
730
  log.info(f"url: {url}")
731
 
732
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
733
+ key = api_config.get("key", None)
734
+
735
+ headers = {"Content-Type": "application/json"}
736
+ if key:
737
+ headers["Authorization"] = f"Bearer {key}"
738
+
739
  r = requests.request(
740
  method="POST",
741
  url=f"{url}/api/embeddings",
742
+ headers=headers,
743
  data=form_data.model_dump_json(exclude_none=True).encode(),
744
  )
745
  try:
 
770
  )
771
 
772
 
773
+ async def generate_ollama_batch_embeddings(
774
  form_data: GenerateEmbedForm,
775
  url_idx: Optional[int] = None,
776
  ):
777
  log.info(f"generate_ollama_batch_embeddings {form_data}")
778
 
779
  if url_idx is None:
780
+ model_list = await get_all_models()
781
+ models = {model["model"]: model for model in model_list["models"]}
782
+
783
  model = form_data.model
784
 
785
  if ":" not in model:
786
  model = f"{model}:latest"
787
 
788
+ if model in models:
789
+ url_idx = random.choice(models[model]["urls"])
790
  else:
791
  raise HTTPException(
792
  status_code=400,
 
796
  url = app.state.config.OLLAMA_BASE_URLS[url_idx]
797
  log.info(f"url: {url}")
798
 
799
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
800
+ key = api_config.get("key", None)
801
+
802
+ headers = {"Content-Type": "application/json"}
803
+ if key:
804
+ headers["Authorization"] = f"Bearer {key}"
805
+
806
  r = requests.request(
807
  method="POST",
808
  url=f"{url}/api/embed",
809
+ headers=headers,
810
  data=form_data.model_dump_json(exclude_none=True).encode(),
811
  )
812
  try:
 
856
  user=Depends(get_verified_user),
857
  ):
858
  if url_idx is None:
859
+ model_list = await get_all_models()
860
+ models = {model["model"]: model for model in model_list["models"]}
861
+
862
  model = form_data.model
863
 
864
  if ":" not in model:
865
  model = f"{model}:latest"
866
 
867
+ if model in models:
868
+ url_idx = random.choice(models[model]["urls"])
869
  else:
870
  raise HTTPException(
871
  status_code=400,
 
873
  )
874
 
875
  url = app.state.config.OLLAMA_BASE_URLS[url_idx]
876
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
877
+ prefix_id = api_config.get("prefix_id", None)
878
+ if prefix_id:
879
+ form_data.model = form_data.model.replace(f"{prefix_id}.", "")
880
  log.info(f"url: {url}")
881
 
882
  return await post_streaming_url(
 
900
  keep_alive: Optional[Union[int, str]] = None
901
 
902
 
903
+ async def get_ollama_url(url_idx: Optional[int], model: str):
904
  if url_idx is None:
905
+ model_list = await get_all_models()
906
+ models = {model["model"]: model for model in model_list["models"]}
907
+
908
+ if model not in models:
909
  raise HTTPException(
910
  status_code=400,
911
  detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
912
  )
913
+ url_idx = random.choice(models[model]["urls"])
914
  url = app.state.config.OLLAMA_BASE_URLS[url_idx]
915
  return url
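`get_ollama_url` resolves a model name to one of the backends advertising it, picking uniformly at random among them. A compact illustration of that routing step with made-up data:

```python
import random

OLLAMA_BASE_URLS = ["http://ollama-a:11434", "http://ollama-b:11434"]
models = {"llama3.1:8b": {"model": "llama3.1:8b", "urls": [0, 1]}}

def pick_url(model: str) -> str:
    if model not in models:
        raise ValueError(f"unknown model: {model}")
    return OLLAMA_BASE_URLS[random.choice(models[model]["urls"])]

print(pick_url("llama3.1:8b"))  # either backend, chosen at random per request
```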
916
 
 
928
  if "metadata" in payload:
929
  del payload["metadata"]
930
 
931
+ model_id = payload["model"]
 
 
932
  model_info = Models.get_model_by_id(model_id)
933
 
934
  if model_info:
 
946
  )
947
  payload = apply_model_system_prompt_to_body(params, payload, user)
948
 
949
+ # Check if user has access to the model
950
+ if not bypass_filter and user.role == "user":
951
+ if not (
952
+ user.id == model_info.user_id
953
+ or has_access(
954
+ user.id, type="read", access_control=model_info.access_control
955
+ )
956
+ ):
957
+ raise HTTPException(
958
+ status_code=403,
959
+ detail="Model not found",
960
+ )
961
+ elif not bypass_filter:
962
+ if user.role != "admin":
963
+ raise HTTPException(
964
+ status_code=403,
965
+ detail="Model not found",
966
+ )
967
+
968
  if ":" not in payload["model"]:
969
  payload["model"] = f"{payload['model']}:latest"
970
 
971
+ url = await get_ollama_url(url_idx, payload["model"])
972
  log.info(f"url: {url}")
973
  log.debug(f"generate_chat_completion() - 2.payload = {payload}")
974
 
975
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
976
+ prefix_id = api_config.get("prefix_id", None)
977
+ if prefix_id:
978
+ payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
979
+
980
  return await post_streaming_url(
981
  f"{url}/api/chat",
982
  json.dumps(payload),
 
993
 
994
  class OpenAIChatMessage(BaseModel):
995
  role: str
996
+ content: Union[str, list[OpenAIChatMessageContent]]
997
 
998
  model_config = ConfigDict(extra="allow")
999
 
 
1012
  url_idx: Optional[int] = None,
1013
  user=Depends(get_verified_user),
1014
  ):
1015
+ try:
1016
+ completion_form = OpenAIChatCompletionForm(**form_data)
1017
+ except Exception as e:
1018
+ log.exception(e)
1019
+ raise HTTPException(
1020
+ status_code=400,
1021
+ detail=str(e),
1022
+ )
1023
+
1024
  payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
1025
  if "metadata" in payload:
1026
  del payload["metadata"]
1027
 
1028
  model_id = completion_form.model
1029
+ if ":" not in model_id:
1030
+ model_id = f"{model_id}:latest"
 
 
 
 
 
1031
 
1032
  model_info = Models.get_model_by_id(model_id)
 
1033
  if model_info:
1034
  if model_info.base_model_id:
1035
  payload["model"] = model_info.base_model_id
 
1040
  payload = apply_model_params_to_body_openai(params, payload)
1041
  payload = apply_model_system_prompt_to_body(params, payload, user)
1042
 
1043
+ # Check if user has access to the model
1044
+ if user.role == "user":
1045
+ if not (
1046
+ user.id == model_info.user_id
1047
+ or has_access(
1048
+ user.id, type="read", access_control=model_info.access_control
1049
+ )
1050
+ ):
1051
+ raise HTTPException(
1052
+ status_code=403,
1053
+ detail="Model not found",
1054
+ )
1055
+ else:
1056
+ if user.role != "admin":
1057
+ raise HTTPException(
1058
+ status_code=403,
1059
+ detail="Model not found",
1060
+ )
1061
+
1062
  if ":" not in payload["model"]:
1063
  payload["model"] = f"{payload['model']}:latest"
1064
 
1065
+ url = await get_ollama_url(url_idx, payload["model"])
1066
  log.info(f"url: {url}")
1067
 
1068
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
1069
+ prefix_id = api_config.get("prefix_id", None)
1070
+ if prefix_id:
1071
+ payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
1072
+
1073
  return await post_streaming_url(
1074
  f"{url}/v1/chat/completions",
1075
  json.dumps(payload),
 
1083
  url_idx: Optional[int] = None,
1084
  user=Depends(get_verified_user),
1085
  ):
1086
+
1087
+ models = []
1088
  if url_idx is None:
1089
+ model_list = await get_all_models()
1090
+ models = [
1091
+ {
1092
+ "id": model["model"],
1093
+ "object": "model",
1094
+ "created": int(time.time()),
1095
+ "owned_by": "openai",
1096
+ }
1097
+ for model in model_list["models"]
1098
+ ]
1099
 
1100
+ else:
1101
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
1102
+ try:
1103
+ r = requests.request(method="GET", url=f"{url}/api/tags")
1104
+ r.raise_for_status()
1105
+
1106
+ model_list = r.json()
 
 
1107
 
1108
+ models = [
 
1109
  {
1110
  "id": model["model"],
1111
  "object": "model",
 
1113
  "owned_by": "openai",
1114
  }
1115
  for model in models["models"]
1116
+ ]
 
 
1117
  except Exception as e:
1118
  log.exception(e)
1119
  error_detail = "Open WebUI: Server Connection Error"
 
1130
  detail=error_detail,
1131
  )
1132
 
1133
+ if user.role == "user":
1134
+ # Filter models based on user access control
1135
+ filtered_models = []
1136
+ for model in models:
1137
+ model_info = Models.get_model_by_id(model["id"])
1138
+ if model_info:
1139
+ if user.id == model_info.user_id or has_access(
1140
+ user.id, type="read", access_control=model_info.access_control
1141
+ ):
1142
+ filtered_models.append(model)
1143
+ models = filtered_models
1144
+
1145
+ return {
1146
+ "data": models,
1147
+ "object": "list",
1148
+ }
1149
+
1150
 
1151
  class UrlForm(BaseModel):
1152
  url: str
backend/open_webui/apps/openai/main.py CHANGED
@@ -11,20 +11,20 @@ from open_webui.apps.webui.models.models import Models
11
  from open_webui.config import (
12
  CACHE_DIR,
13
  CORS_ALLOW_ORIGIN,
14
- ENABLE_MODEL_FILTER,
15
  ENABLE_OPENAI_API,
16
- MODEL_FILTER_LIST,
17
  OPENAI_API_BASE_URLS,
18
  OPENAI_API_KEYS,
 
19
  AppConfig,
20
  )
21
  from open_webui.env import (
22
  AIOHTTP_CLIENT_TIMEOUT,
23
  AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
 
24
  )
25
 
26
  from open_webui.constants import ERROR_MESSAGES
27
- from open_webui.env import SRC_LOG_LEVELS
28
  from fastapi import Depends, FastAPI, HTTPException, Request
29
  from fastapi.middleware.cors import CORSMiddleware
30
  from fastapi.responses import FileResponse, StreamingResponse
@@ -37,11 +37,20 @@ from open_webui.utils.payload import (
37
  )
38
 
39
  from open_webui.utils.utils import get_admin_user, get_verified_user
 
 
40
 
41
  log = logging.getLogger(__name__)
42
  log.setLevel(SRC_LOG_LEVELS["OPENAI"])
43
 
44
- app = FastAPI()
 
 
45
  app.add_middleware(
46
  CORSMiddleware,
47
  allow_origins=CORS_ALLOW_ORIGIN,
@@ -52,69 +61,66 @@ app.add_middleware(
52
 
53
  app.state.config = AppConfig()
54
 
55
- app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
56
- app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
57
-
58
  app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
59
  app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
60
  app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
61
-
62
- app.state.MODELS = {}
63
-
64
-
65
- @app.middleware("http")
66
- async def check_url(request: Request, call_next):
67
- if len(app.state.MODELS) == 0:
68
- await get_all_models()
69
-
70
- response = await call_next(request)
71
- return response
72
 
73
 
74
  @app.get("/config")
75
  async def get_config(user=Depends(get_admin_user)):
76
- return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
 
 
 
 
 
77
 
78
 
79
  class OpenAIConfigForm(BaseModel):
80
- enable_openai_api: Optional[bool] = None
 
 
 
81
 
82
 
83
  @app.post("/config/update")
84
  async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
85
- app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
86
- return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
87
-
88
-
89
- class UrlsUpdateForm(BaseModel):
90
- urls: list[str]
91
 
 
 
92
 
93
- class KeysUpdateForm(BaseModel):
94
- keys: list[str]
95
-
96
-
97
- @app.get("/urls")
98
- async def get_openai_urls(user=Depends(get_admin_user)):
99
- return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
100
-
101
-
102
- @app.post("/urls/update")
103
- async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
104
- await get_all_models()
105
- app.state.config.OPENAI_API_BASE_URLS = form_data.urls
106
- return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
107
-
108
 
109
- @app.get("/keys")
110
- async def get_openai_keys(user=Depends(get_admin_user)):
111
- return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
112
 
 
 
 
 
 
113
 
114
- @app.post("/keys/update")
115
- async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
116
- app.state.config.OPENAI_API_KEYS = form_data.keys
117
- return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
 
 
118
 
119
 
120
  @app.post("/audio/speech")
@@ -140,6 +146,11 @@ async def speech(request: Request, user=Depends(get_verified_user)):
140
  if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
141
  headers["HTTP-Referer"] = "https://openwebui.com/"
142
  headers["X-Title"] = "Open WebUI"
 
 
 
 
 
143
  r = None
144
  try:
145
  r = requests.post(
@@ -181,10 +192,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
181
  raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
182
 
183
 
184
- async def fetch_url(url, key):
185
  timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
186
  try:
187
- headers = {"Authorization": f"Bearer {key}"}
188
  async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
189
  async with session.get(url, headers=headers) as response:
190
  return await response.json()
@@ -239,12 +250,8 @@ def merge_models_lists(model_lists):
239
  return merged_list
240
 
241
 
242
- def is_openai_api_disabled():
243
- return not app.state.config.ENABLE_OPENAI_API
244
-
245
-
246
- async def get_all_models_raw() -> list:
247
- if is_openai_api_disabled():
248
  return []
249
 
250
  # Ensure the number of API keys matches the number of API base URLs
@@ -260,33 +267,67 @@ async def get_all_models_raw() -> list:
260
  else:
261
  app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
262
 
263
- tasks = [
264
- fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
265
- for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
266
- ]
 
 
267
 
268
  responses = await asyncio.gather(*tasks)
269
- log.debug(f"get_all_models:responses() {responses}")
270
 
271
- return responses
 
 
 
272
 
 
273
 
274
- @overload
275
- async def get_all_models(raw: Literal[True]) -> list: ...
 
276
 
 
277
 
278
- @overload
279
- async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ...
280
 
281
 
282
- async def get_all_models(raw=False) -> dict[str, list] | list:
283
  log.info("get_all_models()")
284
- if is_openai_api_disabled():
285
- return [] if raw else {"data": []}
286
 
287
- responses = await get_all_models_raw()
288
- if raw:
289
- return responses
 
290
 
291
  def extract_data(response):
292
  if response and "data" in response:
@@ -296,9 +337,7 @@ async def get_all_models(raw=False) -> dict[str, list] | list:
296
  return None
297
 
298
  models = {"data": merge_models_lists(map(extract_data, responses))}
299
-
300
  log.debug(f"models: {models}")
301
- app.state.MODELS = {model["id"]: model for model in models["data"]}
302
 
303
  return models
304
 
@@ -306,18 +345,12 @@ async def get_all_models(raw=False) -> dict[str, list] | list:
306
  @app.get("/models")
307
  @app.get("/models/{url_idx}")
308
  async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
 
 
 
 
309
  if url_idx is None:
310
  models = await get_all_models()
311
- if app.state.config.ENABLE_MODEL_FILTER:
312
- if user.role == "user":
313
- models["data"] = list(
314
- filter(
315
- lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
316
- models["data"],
317
- )
318
- )
319
- return models
320
- return models
321
  else:
322
  url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
323
  key = app.state.config.OPENAI_API_KEYS[url_idx]
@@ -326,56 +359,126 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
326
  headers["Authorization"] = f"Bearer {key}"
327
  headers["Content-Type"] = "application/json"
328
 
 
 
 
 
 
 
329
  r = None
330
 
331
- try:
332
- r = requests.request(method="GET", url=f"{url}/models", headers=headers)
333
- r.raise_for_status()
 
 
334
 
335
- response_data = r.json()
 
 
336
 
337
- if "api.openai.com" in url:
338
- # Filter the response data
339
- response_data["data"] = [
340
- model
341
- for model in response_data["data"]
342
- if not any(
343
- name in model["id"]
344
- for name in [
345
- "babbage",
346
- "dall-e",
347
- "davinci",
348
- "embedding",
349
- "tts",
350
- "whisper",
351
- ]
352
- )
353
- ]
354
 
355
- return response_data
356
- except Exception as e:
357
- log.exception(e)
358
- error_detail = "Open WebUI: Server Connection Error"
359
- if r is not None:
360
- try:
361
- res = r.json()
 
 
362
  if "error" in res:
363
- error_detail = f"External: {res['error']}"
364
- except Exception:
365
- error_detail = f"External: {e}"
 
 
366
 
 
 
 
 
367
  raise HTTPException(
368
- status_code=r.status_code if r else 500,
369
- detail=error_detail,
370
  )
 
 
 
 
 
371
 
372
 
373
  @app.post("/chat/completions")
374
- @app.post("/chat/completions/{url_idx}")
375
  async def generate_chat_completion(
376
  form_data: dict,
377
- url_idx: Optional[int] = None,
378
  user=Depends(get_verified_user),
 
379
  ):
380
  idx = 0
381
  payload = {**form_data}
@@ -386,6 +489,7 @@ async def generate_chat_completion(
386
  model_id = form_data.get("model")
387
  model_info = Models.get_model_by_id(model_id)
388
 
 
389
  if model_info:
390
  if model_info.base_model_id:
391
  payload["model"] = model_info.base_model_id
@@ -394,9 +498,52 @@ async def generate_chat_completion(
394
  payload = apply_model_params_to_body_openai(params, payload)
395
  payload = apply_model_system_prompt_to_body(params, payload, user)
396
 
397
- model = app.state.MODELS[payload.get("model")]
398
- idx = model["urlIdx"]
 
 
399
 
 
400
  if "pipeline" in model and model.get("pipeline"):
401
  payload["user"] = {
402
  "name": user.name,
@@ -407,8 +554,9 @@ async def generate_chat_completion(
407
 
408
  url = app.state.config.OPENAI_API_BASE_URLS[idx]
409
  key = app.state.config.OPENAI_API_KEYS[idx]
410
- is_o1 = payload["model"].lower().startswith("o1-")
411
 
 
 
412
  # Change max_completion_tokens to max_tokens (Backward compatible)
413
  if "api.openai.com" not in url and not is_o1:
414
  if "max_completion_tokens" in payload:
@@ -437,6 +585,11 @@ async def generate_chat_completion(
437
  if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
438
  headers["HTTP-Referer"] = "https://openwebui.com/"
439
  headers["X-Title"] = "Open WebUI"
 
 
 
 
 
440
 
441
  r = None
442
  session = None
@@ -505,6 +658,11 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
505
  headers = {}
506
  headers["Authorization"] = f"Bearer {key}"
507
  headers["Content-Type"] = "application/json"
 
 
 
 
 
508
 
509
  r = None
510
  session = None
 
11
  from open_webui.config import (
12
  CACHE_DIR,
13
  CORS_ALLOW_ORIGIN,
 
14
  ENABLE_OPENAI_API,
 
15
  OPENAI_API_BASE_URLS,
16
  OPENAI_API_KEYS,
17
+ OPENAI_API_CONFIGS,
18
  AppConfig,
19
  )
20
  from open_webui.env import (
21
  AIOHTTP_CLIENT_TIMEOUT,
22
  AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
23
+ ENABLE_FORWARD_USER_INFO_HEADERS,
24
  )
25
 
26
  from open_webui.constants import ERROR_MESSAGES
27
+ from open_webui.env import ENV, SRC_LOG_LEVELS
28
  from fastapi import Depends, FastAPI, HTTPException, Request
29
  from fastapi.middleware.cors import CORSMiddleware
30
  from fastapi.responses import FileResponse, StreamingResponse
 
37
  )
38
 
39
  from open_webui.utils.utils import get_admin_user, get_verified_user
40
+ from open_webui.utils.access_control import has_access
41
+
42
 
43
  log = logging.getLogger(__name__)
44
  log.setLevel(SRC_LOG_LEVELS["OPENAI"])
45
 
46
+
47
+ app = FastAPI(
48
+ docs_url="/docs" if ENV == "dev" else None,
49
+ openapi_url="/openapi.json" if ENV == "dev" else None,
50
+ redoc_url=None,
51
+ )
52
+
53
+
54
  app.add_middleware(
55
  CORSMiddleware,
56
  allow_origins=CORS_ALLOW_ORIGIN,
 
61
 
62
  app.state.config = AppConfig()
63
 
 
 
 
64
  app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
65
  app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
66
  app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
67
+ app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
 
 
 
68
 
69
 
70
  @app.get("/config")
71
  async def get_config(user=Depends(get_admin_user)):
72
+ return {
73
+ "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API,
74
+ "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS,
75
+ "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS,
76
+ "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS,
77
+ }
78
 
79
 
80
  class OpenAIConfigForm(BaseModel):
81
+ ENABLE_OPENAI_API: Optional[bool] = None
82
+ OPENAI_API_BASE_URLS: list[str]
83
+ OPENAI_API_KEYS: list[str]
84
+ OPENAI_API_CONFIGS: dict
85
 
86
 
87
  @app.post("/config/update")
88
  async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
89
+ app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API
 
 
 
 
 
90
 
91
+ app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS
92
+ app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS
93
 
94
+ # Ensure the number of API keys matches the number of API base URLs
95
+ if len(app.state.config.OPENAI_API_KEYS) != len(
96
+ app.state.config.OPENAI_API_BASE_URLS
97
+ ):
98
+ if len(app.state.config.OPENAI_API_KEYS) > len(
99
+ app.state.config.OPENAI_API_BASE_URLS
100
+ ):
101
+ app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
102
+ : len(app.state.config.OPENAI_API_BASE_URLS)
103
+ ]
104
+ else:
105
+ app.state.config.OPENAI_API_KEYS += [""] * (
106
+ len(app.state.config.OPENAI_API_BASE_URLS)
107
+ - len(app.state.config.OPENAI_API_KEYS)
108
+ )
109
 
110
+ app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS
 
 
111
 
112
+ # Remove any extra configs
113
+ base_urls = app.state.config.OPENAI_API_BASE_URLS
114
+ for url in list(app.state.config.OPENAI_API_CONFIGS.keys()):
115
+ if url not in base_urls:
116
+ app.state.config.OPENAI_API_CONFIGS.pop(url, None)
117
 
118
+ return {
119
+ "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API,
120
+ "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS,
121
+ "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS,
122
+ "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS,
123
+ }
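The reconciliation above truncates or pads the key list so it stays index-aligned with the URL list. The same rule as a small standalone helper:

```python
def align_keys(urls: list[str], keys: list[str]) -> list[str]:
    """Truncate or pad keys with empty strings so len(keys) == len(urls)."""
    if len(keys) > len(urls):
        return keys[: len(urls)]
    return keys + [""] * (len(urls) - len(keys))

print(align_keys(["https://api.openai.com/v1", "http://localhost:8000/v1"], ["sk-abc"]))
# -> ['sk-abc', '']
```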
124
 
125
 
126
  @app.post("/audio/speech")
 
146
  if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
147
  headers["HTTP-Referer"] = "https://openwebui.com/"
148
  headers["X-Title"] = "Open WebUI"
149
+ if ENABLE_FORWARD_USER_INFO_HEADERS:
150
+ headers["X-OpenWebUI-User-Name"] = user.name
151
+ headers["X-OpenWebUI-User-Id"] = user.id
152
+ headers["X-OpenWebUI-User-Email"] = user.email
153
+ headers["X-OpenWebUI-User-Role"] = user.role
154
  r = None
155
  try:
156
  r = requests.post(
 
192
  raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
193
 
194
 
195
+ async def aiohttp_get(url, key=None):
196
  timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
197
  try:
198
+ headers = {"Authorization": f"Bearer {key}"} if key else {}
199
  async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
200
  async with session.get(url, headers=headers) as response:
201
  return await response.json()
 
250
  return merged_list
251
 
252
 
253
+ async def get_all_models_responses() -> list:
254
+ if not app.state.config.ENABLE_OPENAI_API:
 
 
 
 
255
  return []
256
 
257
  # Ensure the number of API keys matches the number of API base URLs
 
267
  else:
268
  app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
269
 
270
+ tasks = []
271
+ for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS):
272
+ if url not in app.state.config.OPENAI_API_CONFIGS:
273
+ tasks.append(
274
+ aiohttp_get(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
275
+ )
276
+ else:
277
+ api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {})
278
+
279
+ enable = api_config.get("enable", True)
280
+ model_ids = api_config.get("model_ids", [])
281
+
282
+ if enable:
283
+ if len(model_ids) == 0:
284
+ tasks.append(
285
+ aiohttp_get(
286
+ f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]
287
+ )
288
+ )
289
+ else:
290
+ model_list = {
291
+ "object": "list",
292
+ "data": [
293
+ {
294
+ "id": model_id,
295
+ "name": model_id,
296
+ "owned_by": "openai",
297
+ "openai": {"id": model_id},
298
+ "urlIdx": idx,
299
+ }
300
+ for model_id in model_ids
301
+ ],
302
+ }
303
+
304
+ tasks.append(asyncio.ensure_future(asyncio.sleep(0, model_list)))
305
 
306
  responses = await asyncio.gather(*tasks)
 
307
 
308
+ for idx, response in enumerate(responses):
309
+ if response:
310
+ url = app.state.config.OPENAI_API_BASE_URLS[idx]
311
+ api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {})
312
 
313
+ prefix_id = api_config.get("prefix_id", None)
314
 
315
+ if prefix_id:
316
+ for model in response["data"]:
317
+ model["id"] = f"{prefix_id}.{model['id']}"
318
 
319
+ log.debug(f"get_all_models:responses() {responses}")
320
 
321
+ return responses
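With the per-URL `OPENAI_API_CONFIGS` introduced above, a connection can be disabled, restricted to an explicit `model_ids` list (in which case no `/models` request is made and the list is synthesized locally), or namespaced with a `prefix_id`. A plausible entry, with placeholder URL and ids:

```python
OPENAI_API_CONFIGS = {
    "https://openrouter.ai/api/v1": {
        "enable": True,
        "prefix_id": "openrouter",            # models surface as "openrouter.<id>"
        "model_ids": ["openai/gpt-4o-mini"],  # skip /models and expose only these ids
    }
}
```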
 
322
 
323
 
324
+ async def get_all_models() -> dict[str, list]:
325
  log.info("get_all_models()")
 
 
326
 
327
+ if not app.state.config.ENABLE_OPENAI_API:
328
+ return {"data": []}
329
+
330
+ responses = await get_all_models_responses()
331
 
332
  def extract_data(response):
333
  if response and "data" in response:
 
337
  return None
338
 
339
  models = {"data": merge_models_lists(map(extract_data, responses))}
 
340
  log.debug(f"models: {models}")
 
341
 
342
  return models
343
 
 
345
  @app.get("/models")
346
  @app.get("/models/{url_idx}")
347
  async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
348
+ models = {
349
+ "data": [],
350
+ }
351
+
352
  if url_idx is None:
353
  models = await get_all_models()
 
 
354
  else:
355
  url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
356
  key = app.state.config.OPENAI_API_KEYS[url_idx]
 
359
  headers["Authorization"] = f"Bearer {key}"
360
  headers["Content-Type"] = "application/json"
361
 
362
+ if ENABLE_FORWARD_USER_INFO_HEADERS:
363
+ headers["X-OpenWebUI-User-Name"] = user.name
364
+ headers["X-OpenWebUI-User-Id"] = user.id
365
+ headers["X-OpenWebUI-User-Email"] = user.email
366
+ headers["X-OpenWebUI-User-Role"] = user.role
367
+
368
  r = None
369
 
370
+ timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
371
+ async with aiohttp.ClientSession(timeout=timeout) as session:
372
+ try:
373
+ async with session.get(f"{url}/models", headers=headers) as r:
374
+ if r.status != 200:
375
+ # Extract response error details if available
376
+ error_detail = f"HTTP Error: {r.status}"
377
+ res = await r.json()
378
+ if "error" in res:
379
+ error_detail = f"External Error: {res['error']}"
380
+ raise Exception(error_detail)
381
+
382
+ response_data = await r.json()
383
+
384
+ # Check if we're calling OpenAI API based on the URL
385
+ if "api.openai.com" in url:
386
+ # Filter models according to the specified conditions
387
+ response_data["data"] = [
388
+ model
389
+ for model in response_data.get("data", [])
390
+ if not any(
391
+ name in model["id"]
392
+ for name in [
393
+ "babbage",
394
+ "dall-e",
395
+ "davinci",
396
+ "embedding",
397
+ "tts",
398
+ "whisper",
399
+ ]
400
+ )
401
+ ]
402
 
403
+ models = response_data
404
+ except aiohttp.ClientError as e:
405
+ # ClientError covers all aiohttp request issues
406
+ log.exception(f"Client error: {str(e)}")
407
+ # Handle aiohttp-specific connection issues, timeout etc.
408
+ raise HTTPException(
409
+ status_code=500, detail="Open WebUI: Server Connection Error"
410
+ )
411
+ except Exception as e:
412
+ log.exception(f"Unexpected error: {e}")
413
+ # Generic error handler in case parsing JSON or other steps fail
414
+ error_detail = f"Unexpected error: {str(e)}"
415
+ raise HTTPException(status_code=500, detail=error_detail)
416
+
417
+ if user.role == "user":
418
+ # Filter models based on user access control
419
+ filtered_models = []
420
+ for model in models.get("data", []):
421
+ model_info = Models.get_model_by_id(model["id"])
422
+ if model_info:
423
+ if user.id == model_info.user_id or has_access(
424
+ user.id, type="read", access_control=model_info.access_control
425
+ ):
426
+ filtered_models.append(model)
427
+ models["data"] = filtered_models
428
 
429
+ return models
 
 
430
 
431
+
432
+ class ConnectionVerificationForm(BaseModel):
433
+ url: str
434
+ key: str
435
+
436
+
437
+ @app.post("/verify")
438
+ async def verify_connection(
439
+ form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
440
+ ):
441
+ url = form_data.url
442
+ key = form_data.key
443
+
444
+ headers = {}
445
+ headers["Authorization"] = f"Bearer {key}"
446
+ headers["Content-Type"] = "application/json"
447
+
448
+ timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
449
+ async with aiohttp.ClientSession(timeout=timeout) as session:
450
+ try:
451
+ async with session.get(f"{url}/models", headers=headers) as r:
452
+ if r.status != 200:
453
+ # Extract response error details if available
454
+ error_detail = f"HTTP Error: {r.status}"
455
+ res = await r.json()
456
  if "error" in res:
457
+ error_detail = f"External Error: {res['error']}"
458
+ raise Exception(error_detail)
459
+
460
+ response_data = await r.json()
461
+ return response_data
462
 
463
+ except aiohttp.ClientError as e:
464
+ # ClientError covers all aiohttp requests issues
465
+ log.exception(f"Client error: {str(e)}")
466
+ # Handle aiohttp-specific connection issues, timeout etc.
467
  raise HTTPException(
468
+ status_code=500, detail="Open WebUI: Server Connection Error"
 
469
  )
470
+ except Exception as e:
471
+ log.exception(f"Unexpected error: {e}")
472
+ # Generic error handler in case parsing JSON or other steps fail
473
+ error_detail = f"Unexpected error: {str(e)}"
474
+ raise HTTPException(status_code=500, detail=error_detail)
475
 
476
 
477
  @app.post("/chat/completions")
 
478
  async def generate_chat_completion(
479
  form_data: dict,
 
480
  user=Depends(get_verified_user),
481
+ bypass_filter: Optional[bool] = False,
482
  ):
483
  idx = 0
484
  payload = {**form_data}
 
489
  model_id = form_data.get("model")
490
  model_info = Models.get_model_by_id(model_id)
491
 
492
+ # Check model info and override the payload
493
  if model_info:
494
  if model_info.base_model_id:
495
  payload["model"] = model_info.base_model_id
 
498
  payload = apply_model_params_to_body_openai(params, payload)
499
  payload = apply_model_system_prompt_to_body(params, payload, user)
500
 
501
+ # Check if user has access to the model
502
+ if not bypass_filter and user.role == "user":
503
+ if not (
504
+ user.id == model_info.user_id
505
+ or has_access(
506
+ user.id, type="read", access_control=model_info.access_control
507
+ )
508
+ ):
509
+ raise HTTPException(
510
+ status_code=403,
511
+ detail="Model not found",
512
+ )
513
+ elif not bypass_filter:
514
+ if user.role != "admin":
515
+ raise HTTPException(
516
+ status_code=403,
517
+ detail="Model not found",
518
+ )
519
+
520
+ # Attempt to get urlIdx from the model
521
+ models = await get_all_models()
522
 
523
+ # Find the model from the list
524
+ model = next(
525
+ (model for model in models["data"] if model["id"] == payload.get("model")),
526
+ None,
527
+ )
528
+
529
+ if model:
530
+ idx = model["urlIdx"]
531
+ else:
532
+ raise HTTPException(
533
+ status_code=404,
534
+ detail="Model not found",
535
+ )
536
+
537
+ # Get the API config for the model
538
+ api_config = app.state.config.OPENAI_API_CONFIGS.get(
539
+ app.state.config.OPENAI_API_BASE_URLS[idx], {}
540
+ )
541
+ prefix_id = api_config.get("prefix_id", None)
542
+
543
+ if prefix_id:
544
+ payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
545
+
546
+ # Add user info to the payload if the model is a pipeline
547
  if "pipeline" in model and model.get("pipeline"):
548
  payload["user"] = {
549
  "name": user.name,
 
554
 
555
  url = app.state.config.OPENAI_API_BASE_URLS[idx]
556
  key = app.state.config.OPENAI_API_KEYS[idx]
 
557
 
558
+ # Fix: o1 models do not support the "max_tokens" parameter; they require "max_completion_tokens" instead
559
+ is_o1 = payload["model"].lower().startswith("o1-")
560
  # Change max_completion_tokens to max_tokens (Backward compatible)
561
  if "api.openai.com" not in url and not is_o1:
562
  if "max_completion_tokens" in payload:
 
585
  if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
586
  headers["HTTP-Referer"] = "https://openwebui.com/"
587
  headers["X-Title"] = "Open WebUI"
588
+ if ENABLE_FORWARD_USER_INFO_HEADERS:
589
+ headers["X-OpenWebUI-User-Name"] = user.name
590
+ headers["X-OpenWebUI-User-Id"] = user.id
591
+ headers["X-OpenWebUI-User-Email"] = user.email
592
+ headers["X-OpenWebUI-User-Role"] = user.role
593
 
594
  r = None
595
  session = None
 
658
  headers = {}
659
  headers["Authorization"] = f"Bearer {key}"
660
  headers["Content-Type"] = "application/json"
661
+ if ENABLE_FORWARD_USER_INFO_HEADERS:
662
+ headers["X-OpenWebUI-User-Name"] = user.name
663
+ headers["X-OpenWebUI-User-Id"] = user.id
664
+ headers["X-OpenWebUI-User-Email"] = user.email
665
+ headers["X-OpenWebUI-User-Role"] = user.role
666
 
667
  r = None
668
  session = None
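Several handlers in this file now attach `X-OpenWebUI-User-*` headers when `ENABLE_FORWARD_USER_INFO_HEADERS` is set, so the upstream API can see which user issued the request. A small helper capturing that repeated pattern (a sketch, not part of the diff):

```python
def forward_user_headers(headers: dict, user, enabled: bool) -> dict:
    """Attach identifying headers for the requesting user when forwarding is enabled."""
    if enabled:
        headers.update(
            {
                "X-OpenWebUI-User-Name": user.name,
                "X-OpenWebUI-User-Id": user.id,
                "X-OpenWebUI-User-Email": user.email,
                "X-OpenWebUI-User-Role": user.role,
            }
        )
    return headers
```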
backend/open_webui/apps/retrieval/loaders/main.py CHANGED
@@ -159,7 +159,7 @@ class Loader:
159
  elif file_ext in ["htm", "html"]:
160
  loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
161
  elif file_ext == "md":
162
- loader = UnstructuredMarkdownLoader(file_path)
163
  elif file_content_type == "application/epub+zip":
164
  loader = UnstructuredEPubLoader(file_path)
165
  elif (
 
159
  elif file_ext in ["htm", "html"]:
160
  loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
161
  elif file_ext == "md":
162
+ loader = TextLoader(file_path, autodetect_encoding=True)
163
  elif file_content_type == "application/epub+zip":
164
  loader = UnstructuredEPubLoader(file_path)
165
  elif (
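The change above swaps `UnstructuredMarkdownLoader` for `TextLoader` with encoding autodetection for `.md` files, so markdown is ingested as plain text. A minimal standalone usage sketch, assuming `langchain_community` is installed and `notes.md` is a placeholder path:

```python
from langchain_community.document_loaders import TextLoader

loader = TextLoader("notes.md", autodetect_encoding=True)
docs = loader.load()                 # a single Document holding the raw markdown text
print(docs[0].metadata, len(docs[0].page_content))
```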
backend/open_webui/apps/retrieval/main.py CHANGED
@@ -37,6 +37,7 @@ from open_webui.apps.retrieval.web.serper import search_serper
37
  from open_webui.apps.retrieval.web.serply import search_serply
38
  from open_webui.apps.retrieval.web.serpstack import search_serpstack
39
  from open_webui.apps.retrieval.web.tavily import search_tavily
 
40
 
41
 
42
  from open_webui.apps.retrieval.utils import (
@@ -74,6 +75,8 @@ from open_webui.config import (
74
  RAG_FILE_MAX_SIZE,
75
  RAG_OPENAI_API_BASE_URL,
76
  RAG_OPENAI_API_KEY,
 
 
77
  RAG_RELEVANCE_THRESHOLD,
78
  RAG_RERANKING_MODEL,
79
  RAG_RERANKING_MODEL_AUTO_UPDATE,
@@ -85,6 +88,7 @@ from open_webui.config import (
85
  RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
86
  RAG_WEB_SEARCH_ENGINE,
87
  RAG_WEB_SEARCH_RESULT_COUNT,
 
88
  SEARCHAPI_API_KEY,
89
  SEARCHAPI_ENGINE,
90
  SEARXNG_QUERY_URL,
@@ -93,13 +97,20 @@ from open_webui.config import (
93
  SERPSTACK_API_KEY,
94
  SERPSTACK_HTTPS,
95
  TAVILY_API_KEY,
 
 
96
  TIKA_SERVER_URL,
97
  UPLOAD_DIR,
98
  YOUTUBE_LOADER_LANGUAGE,
 
99
  AppConfig,
100
  )
101
  from open_webui.constants import ERROR_MESSAGES
102
- from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE, DOCKER
 
 
 
 
103
  from open_webui.utils.misc import (
104
  calculate_sha256,
105
  calculate_sha256_string,
@@ -118,7 +129,11 @@ from langchain_core.documents import Document
118
  log = logging.getLogger(__name__)
119
  log.setLevel(SRC_LOG_LEVELS["RAG"])
120
 
121
- app = FastAPI()
 
 
 
 
122
 
123
  app.state.config = AppConfig()
124
 
@@ -150,6 +165,9 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
150
  app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
151
  app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
152
 
 
 
 
153
  app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
154
 
155
  app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
@@ -171,6 +189,10 @@ app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
171
  app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
172
  app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
173
  app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
 
 
 
 
174
  app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
175
  app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
176
 
@@ -182,11 +204,15 @@ def update_embedding_model(
182
  if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
183
  from sentence_transformers import SentenceTransformer
184
 
185
- app.state.sentence_transformer_ef = SentenceTransformer(
186
- get_model_path(embedding_model, auto_update),
187
- device=DEVICE_TYPE,
188
- trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
189
- )
 
 
 
 
190
  else:
191
  app.state.sentence_transformer_ef = None
192
 
@@ -240,8 +266,16 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
240
  app.state.config.RAG_EMBEDDING_ENGINE,
241
  app.state.config.RAG_EMBEDDING_MODEL,
242
  app.state.sentence_transformer_ef,
243
- app.state.config.OPENAI_API_KEY,
244
- app.state.config.OPENAI_API_BASE_URL,
 
 
245
  app.state.config.RAG_EMBEDDING_BATCH_SIZE,
246
  )
247
 
@@ -291,6 +325,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
291
  "url": app.state.config.OPENAI_API_BASE_URL,
292
  "key": app.state.config.OPENAI_API_KEY,
293
  },
 
 
 
 
294
  }
295
 
296
 
@@ -307,8 +345,14 @@ class OpenAIConfigForm(BaseModel):
307
  key: str
308
 
309
 
 
 
 
 
 
310
  class EmbeddingModelUpdateForm(BaseModel):
311
  openai_config: Optional[OpenAIConfigForm] = None
 
312
  embedding_engine: str
313
  embedding_model: str
314
  embedding_batch_size: Optional[int] = 1
@@ -329,6 +373,11 @@ async def update_embedding_config(
329
  if form_data.openai_config is not None:
330
  app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
331
  app.state.config.OPENAI_API_KEY = form_data.openai_config.key
 
 
 
 
 
332
  app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
333
 
334
  update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
@@ -337,8 +386,16 @@ async def update_embedding_config(
337
  app.state.config.RAG_EMBEDDING_ENGINE,
338
  app.state.config.RAG_EMBEDDING_MODEL,
339
  app.state.sentence_transformer_ef,
340
- app.state.config.OPENAI_API_KEY,
341
- app.state.config.OPENAI_API_BASE_URL,
 
 
 
 
 
 
 
 
342
  app.state.config.RAG_EMBEDDING_BATCH_SIZE,
343
  )
344
 
@@ -351,6 +408,10 @@ async def update_embedding_config(
351
  "url": app.state.config.OPENAI_API_BASE_URL,
352
  "key": app.state.config.OPENAI_API_KEY,
353
  },
 
 
 
 
354
  }
355
  except Exception as e:
356
  log.exception(f"Problem updating embedding model: {e}")
@@ -411,7 +472,7 @@ async def get_rag_config(user=Depends(get_admin_user)):
411
  "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
412
  },
413
  "web": {
414
- "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
415
  "search": {
416
  "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
417
  "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
@@ -426,6 +487,9 @@ async def get_rag_config(user=Depends(get_admin_user)):
426
  "tavily_api_key": app.state.config.TAVILY_API_KEY,
427
  "searchapi_api_key": app.state.config.SEARCHAPI_API_KEY,
428
  "seaarchapi_engine": app.state.config.SEARCHAPI_ENGINE,
 
 
 
429
  "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
430
  "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
431
  },
@@ -468,6 +532,9 @@ class WebSearchConfig(BaseModel):
468
  tavily_api_key: Optional[str] = None
469
  searchapi_api_key: Optional[str] = None
470
  searchapi_engine: Optional[str] = None
 
 
 
471
  result_count: Optional[int] = None
472
  concurrent_requests: Optional[int] = None
473
 
@@ -514,6 +581,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
514
 
515
  if form_data.web is not None:
516
  app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
 
517
  form_data.web.web_loader_ssl_verification
518
  )
519
 
@@ -534,6 +602,15 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
534
  app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
535
  app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key
536
  app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine
 
537
  app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
538
  app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
539
  form_data.web.search.concurrent_requests
@@ -560,7 +637,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
560
  "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
561
  },
562
  "web": {
563
- "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
564
  "search": {
565
  "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
566
  "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
@@ -575,6 +652,9 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
575
  "serachapi_api_key": app.state.config.SEARCHAPI_API_KEY,
576
  "searchapi_engine": app.state.config.SEARCHAPI_ENGINE,
577
  "tavily_api_key": app.state.config.TAVILY_API_KEY,
 
 
 
578
  "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
579
  "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
580
  },
@@ -636,6 +716,23 @@ async def update_query_settings(
636
  ####################################
637
 
638
 
 
639
  def save_docs_to_vector_db(
640
  docs,
641
  collection_name,
@@ -644,7 +741,9 @@ def save_docs_to_vector_db(
644
  split: bool = True,
645
  add: bool = False,
646
  ) -> bool:
647
- log.info(f"save_docs_to_vector_db {docs} {collection_name}")
 
 
648
 
649
  # Check if entries with the same hash (metadata.hash) already exist
650
  if metadata and "hash" in metadata:
@@ -726,8 +825,16 @@ def save_docs_to_vector_db(
726
  app.state.config.RAG_EMBEDDING_ENGINE,
727
  app.state.config.RAG_EMBEDDING_MODEL,
728
  app.state.sentence_transformer_ef,
729
- app.state.config.OPENAI_API_KEY,
730
- app.state.config.OPENAI_API_BASE_URL,
 
 
731
  app.state.config.RAG_EMBEDDING_BATCH_SIZE,
732
  )
733
 
@@ -954,7 +1061,7 @@ def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_u
954
 
955
  loader = YoutubeLoader.from_youtube_url(
956
  form_data.url,
957
- add_video_info=True,
958
  language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
959
  translation=app.state.YOUTUBE_LOADER_TRANSLATION,
960
  )
@@ -1132,7 +1239,20 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
1132
  else:
1133
  raise Exception("No SEARCHAPI_API_KEY found in environment variables")
1134
  elif engine == "jina":
1135
- return search_jina(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
 
 
1136
  else:
1137
  raise Exception("No search engine API key found in environment variables")
1138
 
@@ -1162,8 +1282,12 @@ def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
1162
 
1163
  urls = [result.link for result in web_results]
1164
 
1165
- loader = get_web_loader(urls)
1166
- docs = loader.load()
 
 
 
 
1167
 
1168
  save_docs_to_vector_db(docs, collection_name, overwrite=True)
1169
 
 
37
  from open_webui.apps.retrieval.web.serply import search_serply
38
  from open_webui.apps.retrieval.web.serpstack import search_serpstack
39
  from open_webui.apps.retrieval.web.tavily import search_tavily
40
+ from open_webui.apps.retrieval.web.bing import search_bing
41
 
42
 
43
  from open_webui.apps.retrieval.utils import (
 
75
  RAG_FILE_MAX_SIZE,
76
  RAG_OPENAI_API_BASE_URL,
77
  RAG_OPENAI_API_KEY,
78
+ RAG_OLLAMA_BASE_URL,
79
+ RAG_OLLAMA_API_KEY,
80
  RAG_RELEVANCE_THRESHOLD,
81
  RAG_RERANKING_MODEL,
82
  RAG_RERANKING_MODEL_AUTO_UPDATE,
 
88
  RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
89
  RAG_WEB_SEARCH_ENGINE,
90
  RAG_WEB_SEARCH_RESULT_COUNT,
91
+ JINA_API_KEY,
92
  SEARCHAPI_API_KEY,
93
  SEARCHAPI_ENGINE,
94
  SEARXNG_QUERY_URL,
 
97
  SERPSTACK_API_KEY,
98
  SERPSTACK_HTTPS,
99
  TAVILY_API_KEY,
100
+ BING_SEARCH_V7_ENDPOINT,
101
+ BING_SEARCH_V7_SUBSCRIPTION_KEY,
102
  TIKA_SERVER_URL,
103
  UPLOAD_DIR,
104
  YOUTUBE_LOADER_LANGUAGE,
105
+ DEFAULT_LOCALE,
106
  AppConfig,
107
  )
108
  from open_webui.constants import ERROR_MESSAGES
109
+ from open_webui.env import (
110
+ SRC_LOG_LEVELS,
111
+ DEVICE_TYPE,
112
+ DOCKER,
113
+ )
114
  from open_webui.utils.misc import (
115
  calculate_sha256,
116
  calculate_sha256_string,
 
129
  log = logging.getLogger(__name__)
130
  log.setLevel(SRC_LOG_LEVELS["RAG"])
131
 
132
+ app = FastAPI(
133
+ docs_url="/docs" if ENV == "dev" else None,
134
+ openapi_url="/openapi.json" if ENV == "dev" else None,
135
+ redoc_url=None,
136
+ )
137
 
138
  app.state.config = AppConfig()
139
 
 
165
  app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
166
  app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
167
 
168
+ app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
169
+ app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
170
+
171
  app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
172
 
173
  app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
 
189
  app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
190
  app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
191
  app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
192
+ app.state.config.JINA_API_KEY = JINA_API_KEY
193
+ app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
194
+ app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
195
+
196
  app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
197
  app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
198
 
 
204
  if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
205
  from sentence_transformers import SentenceTransformer
206
 
207
+ try:
208
+ app.state.sentence_transformer_ef = SentenceTransformer(
209
+ get_model_path(embedding_model, auto_update),
210
+ device=DEVICE_TYPE,
211
+ trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
212
+ )
213
+ except Exception as e:
214
+ log.debug(f"Error loading SentenceTransformer: {e}")
215
+ app.state.sentence_transformer_ef = None
216
  else:
217
  app.state.sentence_transformer_ef = None
218
 
 
266
  app.state.config.RAG_EMBEDDING_ENGINE,
267
  app.state.config.RAG_EMBEDDING_MODEL,
268
  app.state.sentence_transformer_ef,
269
+ (
270
+ app.state.config.OPENAI_API_BASE_URL
271
+ if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
272
+ else app.state.config.OLLAMA_BASE_URL
273
+ ),
274
+ (
275
+ app.state.config.OPENAI_API_KEY
276
+ if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
277
+ else app.state.config.OLLAMA_API_KEY
278
+ ),
279
  app.state.config.RAG_EMBEDDING_BATCH_SIZE,
280
  )
281
 
 
325
  "url": app.state.config.OPENAI_API_BASE_URL,
326
  "key": app.state.config.OPENAI_API_KEY,
327
  },
328
+ "ollama_config": {
329
+ "url": app.state.config.OLLAMA_BASE_URL,
330
+ "key": app.state.config.OLLAMA_API_KEY,
331
+ },
332
  }
333
 
334
 
 
345
  key: str
346
 
347
 
348
+ class OllamaConfigForm(BaseModel):
349
+ url: str
350
+ key: str
351
+
352
+
353
  class EmbeddingModelUpdateForm(BaseModel):
354
  openai_config: Optional[OpenAIConfigForm] = None
355
+ ollama_config: Optional[OllamaConfigForm] = None
356
  embedding_engine: str
357
  embedding_model: str
358
  embedding_batch_size: Optional[int] = 1
 
373
  if form_data.openai_config is not None:
374
  app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
375
  app.state.config.OPENAI_API_KEY = form_data.openai_config.key
376
+
377
+ if form_data.ollama_config is not None:
378
+ app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url
379
+ app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key
380
+
381
  app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
382
 
383
  update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
 
386
  app.state.config.RAG_EMBEDDING_ENGINE,
387
  app.state.config.RAG_EMBEDDING_MODEL,
388
  app.state.sentence_transformer_ef,
389
+ (
390
+ app.state.config.OPENAI_API_BASE_URL
391
+ if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
392
+ else app.state.config.OLLAMA_BASE_URL
393
+ ),
394
+ (
395
+ app.state.config.OPENAI_API_KEY
396
+ if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
397
+ else app.state.config.OLLAMA_API_KEY
398
+ ),
399
  app.state.config.RAG_EMBEDDING_BATCH_SIZE,
400
  )
401
 
 
408
  "url": app.state.config.OPENAI_API_BASE_URL,
409
  "key": app.state.config.OPENAI_API_KEY,
410
  },
411
+ "ollama_config": {
412
+ "url": app.state.config.OLLAMA_BASE_URL,
413
+ "key": app.state.config.OLLAMA_API_KEY,
414
+ },
415
  }
416
  except Exception as e:
417
  log.exception(f"Problem updating embedding model: {e}")
 
472
  "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
473
  },
474
  "web": {
475
+ "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
476
  "search": {
477
  "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
478
  "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
 
487
  "tavily_api_key": app.state.config.TAVILY_API_KEY,
488
  "searchapi_api_key": app.state.config.SEARCHAPI_API_KEY,
489
  "searchapi_engine": app.state.config.SEARCHAPI_ENGINE,
490
+ "jina_api_key": app.state.config.JINA_API_KEY,
491
+ "bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT,
492
+ "bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
493
  "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
494
  "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
495
  },
 
532
  tavily_api_key: Optional[str] = None
533
  searchapi_api_key: Optional[str] = None
534
  searchapi_engine: Optional[str] = None
535
+ jina_api_key: Optional[str] = None
536
+ bing_search_v7_endpoint: Optional[str] = None
537
+ bing_search_v7_subscription_key: Optional[str] = None
538
  result_count: Optional[int] = None
539
  concurrent_requests: Optional[int] = None
540
 
 
581
 
582
  if form_data.web is not None:
583
  app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
584
+ # Note: when the UI option "Bypass SSL verification for Websites" is enabled, ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION is set to False
585
  form_data.web.web_loader_ssl_verification
586
  )
587
 
 
602
  app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
603
  app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key
604
  app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine
605
+
606
+ app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key
607
+ app.state.config.BING_SEARCH_V7_ENDPOINT = (
608
+ form_data.web.search.bing_search_v7_endpoint
609
+ )
610
+ app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = (
611
+ form_data.web.search.bing_search_v7_subscription_key
612
+ )
613
+
614
  app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
615
  app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
616
  form_data.web.search.concurrent_requests
 
637
  "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
638
  },
639
  "web": {
640
+ "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
641
  "search": {
642
  "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
643
  "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
 
652
  "searchapi_api_key": app.state.config.SEARCHAPI_API_KEY,
653
  "searchapi_engine": app.state.config.SEARCHAPI_ENGINE,
654
  "tavily_api_key": app.state.config.TAVILY_API_KEY,
655
+ "jina_api_key": app.state.config.JINA_API_KEY,
656
+ "bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT,
657
+ "bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
658
  "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
659
  "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
660
  },
 
716
  ####################################
717
 
718
 
719
+ def _get_docs_info(docs: list[Document]) -> str:
720
+ docs_info = set()
721
+
722
+ # Trying to select relevant metadata identifying the document.
723
+ for doc in docs:
724
+ metadata = getattr(doc, "metadata", {})
725
+ doc_name = metadata.get("name", "")
726
+ if not doc_name:
727
+ doc_name = metadata.get("title", "")
728
+ if not doc_name:
729
+ doc_name = metadata.get("source", "")
730
+ if doc_name:
731
+ docs_info.add(doc_name)
732
+
733
+ return ", ".join(docs_info)
734
+
735
+
736
  def save_docs_to_vector_db(
737
  docs,
738
  collection_name,
 
741
  split: bool = True,
742
  add: bool = False,
743
  ) -> bool:
744
+ log.info(
745
+ f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
746
+ )
747
 
748
  # Check if entries with the same hash (metadata.hash) already exist
749
  if metadata and "hash" in metadata:
 
825
  app.state.config.RAG_EMBEDDING_ENGINE,
826
  app.state.config.RAG_EMBEDDING_MODEL,
827
  app.state.sentence_transformer_ef,
828
+ (
829
+ app.state.config.OPENAI_API_BASE_URL
830
+ if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
831
+ else app.state.config.OLLAMA_BASE_URL
832
+ ),
833
+ (
834
+ app.state.config.OPENAI_API_KEY
835
+ if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
836
+ else app.state.config.OLLAMA_API_KEY
837
+ ),
838
  app.state.config.RAG_EMBEDDING_BATCH_SIZE,
839
  )
840
 
 
1061
 
1062
  loader = YoutubeLoader.from_youtube_url(
1063
  form_data.url,
1064
+ add_video_info=False,
1065
  language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
1066
  translation=app.state.YOUTUBE_LOADER_TRANSLATION,
1067
  )
 
1239
  else:
1240
  raise Exception("No SEARCHAPI_API_KEY found in environment variables")
1241
  elif engine == "jina":
1242
+ return search_jina(
1243
+ app.state.config.JINA_API_KEY,
1244
+ query,
1245
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1246
+ )
1247
+ elif engine == "bing":
1248
+ return search_bing(
1249
+ app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
1250
+ app.state.config.BING_SEARCH_V7_ENDPOINT,
1251
+ str(DEFAULT_LOCALE),
1252
+ query,
1253
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1254
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1255
+ )
1256
  else:
1257
  raise Exception("No search engine API key found in environment variables")
1258
 
 
1282
 
1283
  urls = [result.link for result in web_results]
1284
 
1285
+ loader = get_web_loader(
1286
+ urls,
1287
+ verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
1288
+ requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
1289
+ )
1290
+ docs = loader.aload()
1291
 
1292
  save_docs_to_vector_db(docs, collection_name, overwrite=True)
1293
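The engine-dependent URL/key selection passed to `get_embedding_function` now appears three times in this file. A minimal standalone sketch of the same logic, using a stand-in `cfg` object rather than the real `AppConfig`, shows the intent in one place:

```python
# Sketch only: `cfg` stands in for app.state.config; attribute names match those used above.
def resolve_embedding_endpoint(cfg):
    if cfg.RAG_EMBEDDING_ENGINE == "openai":
        return cfg.OPENAI_API_BASE_URL, cfg.OPENAI_API_KEY
    # any other remote engine (currently "ollama") falls back to the Ollama settings
    return cfg.OLLAMA_BASE_URL, cfg.OLLAMA_API_KEY
```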
 
backend/open_webui/apps/retrieval/utils.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import uuid
4
  from typing import Optional, Union
5
 
 
6
  import requests
7
 
8
  from huggingface_hub import snapshot_download
@@ -10,11 +11,6 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
10
  from langchain_community.retrievers import BM25Retriever
11
  from langchain_core.documents import Document
12
 
13
-
14
- from open_webui.apps.ollama.main import (
15
- GenerateEmbedForm,
16
- generate_ollama_batch_embeddings,
17
- )
18
  from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
19
  from open_webui.utils.misc import get_last_user_message
20
 
@@ -76,7 +72,7 @@ def query_doc(
76
  limit=k,
77
  )
78
 
79
- log.info(f"query_doc:result {result}")
80
  return result
81
  except Exception as e:
82
  print(e)
@@ -127,7 +123,10 @@ def query_doc_with_hybrid_search(
127
  "metadatas": [[d.metadata for d in result]],
128
  }
129
 
130
- log.info(f"query_doc_with_hybrid_search:result {result}")
 
 
 
131
  return result
132
  except Exception as e:
133
  raise e
@@ -178,35 +177,34 @@ def merge_and_sort_query_results(
178
 
179
  def query_collection(
180
  collection_names: list[str],
181
- query: str,
182
  embedding_function,
183
  k: int,
184
  ) -> dict:
185
-
186
  results = []
187
- query_embedding = embedding_function(query)
188
-
189
- for collection_name in collection_names:
190
- if collection_name:
191
- try:
192
- result = query_doc(
193
- collection_name=collection_name,
194
- k=k,
195
- query_embedding=query_embedding,
196
- )
197
- if result is not None:
198
- results.append(result.model_dump())
199
- except Exception as e:
200
- log.exception(f"Error when querying the collection: {e}")
201
- else:
202
- pass
203
 
204
  return merge_and_sort_query_results(results, k=k)
205
 
206
 
207
  def query_collection_with_hybrid_search(
208
  collection_names: list[str],
209
- query: str,
210
  embedding_function,
211
  k: int,
212
  reranking_function,
@@ -216,15 +214,16 @@ def query_collection_with_hybrid_search(
216
  error = False
217
  for collection_name in collection_names:
218
  try:
219
- result = query_doc_with_hybrid_search(
220
- collection_name=collection_name,
221
- query=query,
222
- embedding_function=embedding_function,
223
- k=k,
224
- reranking_function=reranking_function,
225
- r=r,
226
- )
227
- results.append(result)
 
228
  except Exception as e:
229
  log.exception(
230
  "Error when querying the collection with " f"hybrid_search: {e}"
@@ -281,8 +280,8 @@ def get_embedding_function(
281
  embedding_engine,
282
  embedding_model,
283
  embedding_function,
284
- openai_key,
285
- openai_url,
286
  embedding_batch_size,
287
  ):
288
  if embedding_engine == "":
@@ -292,8 +291,8 @@ def get_embedding_function(
292
  engine=embedding_engine,
293
  model=embedding_model,
294
  text=query,
295
- key=openai_key if embedding_engine == "openai" else "",
296
- url=openai_url if embedding_engine == "openai" else "",
297
  )
298
 
299
  def generate_multiple(query, func):
@@ -310,15 +309,14 @@ def get_embedding_function(
310
 
311
  def get_rag_context(
312
  files,
313
- messages,
314
  embedding_function,
315
  k,
316
  reranking_function,
317
  r,
318
  hybrid_search,
319
  ):
320
- log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
321
- query = get_last_user_message(messages)
322
 
323
  extracted_collections = []
324
  relevant_contexts = []
@@ -360,7 +358,7 @@ def get_rag_context(
360
  try:
361
  context = query_collection_with_hybrid_search(
362
  collection_names=collection_names,
363
- query=query,
364
  embedding_function=embedding_function,
365
  k=k,
366
  reranking_function=reranking_function,
@@ -375,7 +373,7 @@ def get_rag_context(
375
  if (not hybrid_search) or (context is None):
376
  context = query_collection(
377
  collection_names=collection_names,
378
- query=query,
379
  embedding_function=embedding_function,
380
  k=k,
381
  )
@@ -467,7 +465,7 @@ def get_model_path(model: str, update_model: bool = False):
467
 
468
 
469
  def generate_openai_batch_embeddings(
470
- model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
471
  ) -> Optional[list[list[float]]]:
472
  try:
473
  r = requests.post(
@@ -489,29 +487,50 @@ def generate_openai_batch_embeddings(
489
  return None
490
 
491
 
 
 
492
  def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
 
 
 
493
  if engine == "ollama":
494
  if isinstance(text, list):
495
  embeddings = generate_ollama_batch_embeddings(
496
- GenerateEmbedForm(**{"model": model, "input": text})
497
  )
498
  else:
499
  embeddings = generate_ollama_batch_embeddings(
500
- GenerateEmbedForm(**{"model": model, "input": [text]})
501
  )
502
- return (
503
- embeddings["embeddings"][0]
504
- if isinstance(text, str)
505
- else embeddings["embeddings"]
506
- )
507
  elif engine == "openai":
508
- key = kwargs.get("key", "")
509
- url = kwargs.get("url", "https://api.openai.com/v1")
510
-
511
  if isinstance(text, list):
512
- embeddings = generate_openai_batch_embeddings(model, text, key, url)
513
  else:
514
- embeddings = generate_openai_batch_embeddings(model, [text], key, url)
515
 
516
  return embeddings[0] if isinstance(text, str) else embeddings
517
 
 
3
  import uuid
4
  from typing import Optional, Union
5
 
6
+ import asyncio
7
  import requests
8
 
9
  from huggingface_hub import snapshot_download
 
11
  from langchain_community.retrievers import BM25Retriever
12
  from langchain_core.documents import Document
13
 
 
 
 
 
 
14
  from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
15
  from open_webui.utils.misc import get_last_user_message
16
 
 
72
  limit=k,
73
  )
74
 
75
+ log.info(f"query_doc:result {result.ids} {result.metadatas}")
76
  return result
77
  except Exception as e:
78
  print(e)
 
123
  "metadatas": [[d.metadata for d in result]],
124
  }
125
 
126
+ log.info(
127
+ "query_doc_with_hybrid_search:result "
128
+ + f'{result["metadatas"]} {result["distances"]}'
129
+ )
130
  return result
131
  except Exception as e:
132
  raise e
 
177
 
178
  def query_collection(
179
  collection_names: list[str],
180
+ queries: list[str],
181
  embedding_function,
182
  k: int,
183
  ) -> dict:
 
184
  results = []
185
+ for query in queries:
186
+ query_embedding = embedding_function(query)
187
+ for collection_name in collection_names:
188
+ if collection_name:
189
+ try:
190
+ result = query_doc(
191
+ collection_name=collection_name,
192
+ k=k,
193
+ query_embedding=query_embedding,
194
+ )
195
+ if result is not None:
196
+ results.append(result.model_dump())
197
+ except Exception as e:
198
+ log.exception(f"Error when querying the collection: {e}")
199
+ else:
200
+ pass
201
 
202
  return merge_and_sort_query_results(results, k=k)
203
 
204
 
205
  def query_collection_with_hybrid_search(
206
  collection_names: list[str],
207
+ queries: list[str],
208
  embedding_function,
209
  k: int,
210
  reranking_function,
 
214
  error = False
215
  for collection_name in collection_names:
216
  try:
217
+ for query in queries:
218
+ result = query_doc_with_hybrid_search(
219
+ collection_name=collection_name,
220
+ query=query,
221
+ embedding_function=embedding_function,
222
+ k=k,
223
+ reranking_function=reranking_function,
224
+ r=r,
225
+ )
226
+ results.append(result)
227
  except Exception as e:
228
  log.exception(
229
  "Error when querying the collection with " f"hybrid_search: {e}"
 
280
  embedding_engine,
281
  embedding_model,
282
  embedding_function,
283
+ url,
284
+ key,
285
  embedding_batch_size,
286
  ):
287
  if embedding_engine == "":
 
291
  engine=embedding_engine,
292
  model=embedding_model,
293
  text=query,
294
+ url=url,
295
+ key=key,
296
  )
297
 
298
  def generate_multiple(query, func):
 
309
 
310
  def get_rag_context(
311
  files,
312
+ queries,
313
  embedding_function,
314
  k,
315
  reranking_function,
316
  r,
317
  hybrid_search,
318
  ):
319
+ log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
 
320
 
321
  extracted_collections = []
322
  relevant_contexts = []
 
358
  try:
359
  context = query_collection_with_hybrid_search(
360
  collection_names=collection_names,
361
+ queries=queries,
362
  embedding_function=embedding_function,
363
  k=k,
364
  reranking_function=reranking_function,
 
373
  if (not hybrid_search) or (context is None):
374
  context = query_collection(
375
  collection_names=collection_names,
376
+ queries=queries,
377
  embedding_function=embedding_function,
378
  k=k,
379
  )
 
465
 
466
 
467
  def generate_openai_batch_embeddings(
468
+ model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
469
  ) -> Optional[list[list[float]]]:
470
  try:
471
  r = requests.post(
 
487
  return None
488
 
489
 
490
+ def generate_ollama_batch_embeddings(
491
+ model: str, texts: list[str], url: str, key: str
492
+ ) -> Optional[list[list[float]]]:
493
+ try:
494
+ r = requests.post(
495
+ f"{url}/api/embed",
496
+ headers={
497
+ "Content-Type": "application/json",
498
+ "Authorization": f"Bearer {key}",
499
+ },
500
+ json={"input": texts, "model": model},
501
+ )
502
+ r.raise_for_status()
503
+ data = r.json()
504
+
505
+ print(data)
506
+ if "embeddings" in data:
507
+ return data["embeddings"]
508
+ else:
509
+ raise Exception("Something went wrong :/")
510
+ except Exception as e:
511
+ print(e)
512
+ return None
513
+
514
+
515
  def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
516
+ url = kwargs.get("url", "")
517
+ key = kwargs.get("key", "")
518
+
519
  if engine == "ollama":
520
  if isinstance(text, list):
521
  embeddings = generate_ollama_batch_embeddings(
522
+ **{"model": model, "texts": text, "url": url, "key": key}
523
  )
524
  else:
525
  embeddings = generate_ollama_batch_embeddings(
526
+ **{"model": model, "texts": [text], "url": url, "key": key}
527
  )
528
+ return embeddings[0] if isinstance(text, str) else embeddings
 
 
 
 
529
  elif engine == "openai":
 
 
 
530
  if isinstance(text, list):
531
+ embeddings = generate_openai_batch_embeddings(model, text, url, key)
532
  else:
533
+ embeddings = generate_openai_batch_embeddings(model, [text], url, key)
534
 
535
  return embeddings[0] if isinstance(text, str) else embeddings
536
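For reference, a hedged usage sketch of the refactored helper: the model name and local Ollama URL below are illustrative assumptions, not values taken from this commit.

```python
# Sketch only: assumes a locally running Ollama server and an illustrative model name.
from open_webui.apps.retrieval.utils import generate_embeddings

vectors = generate_embeddings(
    engine="ollama",
    model="nomic-embed-text",          # assumed model name
    text=["first chunk", "second chunk"],
    url="http://localhost:11434",      # assumed local Ollama endpoint
    key="",                            # empty key for an unauthenticated local instance
)
# For engine="openai", pass url="https://api.openai.com/v1" and a real API key instead.
```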
 
backend/open_webui/apps/retrieval/vector/connector.py CHANGED
@@ -8,6 +8,14 @@ elif VECTOR_DB == "qdrant":
8
  from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
9
 
10
  VECTOR_DB_CLIENT = QdrantClient()
 
 
11
  else:
12
  from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
13
 
 
8
  from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
9
 
10
  VECTOR_DB_CLIENT = QdrantClient()
11
+ elif VECTOR_DB == "opensearch":
12
+ from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient
13
+
14
+ VECTOR_DB_CLIENT = OpenSearchClient()
15
+ elif VECTOR_DB == "pgvector":
16
+ from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient
17
+
18
+ VECTOR_DB_CLIENT = PgvectorClient()
19
  else:
20
  from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
21
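The connector picks its backend once at import time, so the `VECTOR_DB` setting has to be in place before the module is imported. A small sketch, assuming the config reads `VECTOR_DB` from an environment variable of the same name:

```python
# Sketch only: VECTOR_DB must be set before open_webui modules are imported.
import os

os.environ["VECTOR_DB"] = "pgvector"  # or "qdrant" / "opensearch"; anything else falls back to Chroma

from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT

print(type(VECTOR_DB_CLIENT).__name__)  # e.g. "PgvectorClient"
```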
 
backend/open_webui/apps/retrieval/vector/dbs/chroma.py CHANGED
@@ -27,7 +27,9 @@ class ChromaClient:
27
  if CHROMA_CLIENT_AUTH_PROVIDER is not None:
28
  settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
29
  if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
30
- settings_dict["chroma_client_auth_credentials"] = CHROMA_CLIENT_AUTH_CREDENTIALS
 
 
31
 
32
  if CHROMA_HTTP_HOST != "":
33
  self.client = chromadb.HttpClient(
 
27
  if CHROMA_CLIENT_AUTH_PROVIDER is not None:
28
  settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
29
  if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
30
+ settings_dict["chroma_client_auth_credentials"] = (
31
+ CHROMA_CLIENT_AUTH_CREDENTIALS
32
+ )
33
 
34
  if CHROMA_HTTP_HOST != "":
35
  self.client = chromadb.HttpClient(
backend/open_webui/apps/retrieval/vector/dbs/opensearch.py ADDED
@@ -0,0 +1,178 @@
 
 
1
+ from opensearchpy import OpenSearch
2
+ from typing import Optional
3
+
4
+ from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
5
+ from open_webui.config import (
6
+ OPENSEARCH_URI,
7
+ OPENSEARCH_SSL,
8
+ OPENSEARCH_CERT_VERIFY,
9
+ OPENSEARCH_USERNAME,
10
+ OPENSEARCH_PASSWORD,
11
+ )
12
+
13
+
14
+ class OpenSearchClient:
15
+ def __init__(self):
16
+ self.index_prefix = "open_webui"
17
+ self.client = OpenSearch(
18
+ hosts=[OPENSEARCH_URI],
19
+ use_ssl=OPENSEARCH_SSL,
20
+ verify_certs=OPENSEARCH_CERT_VERIFY,
21
+ http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
22
+ )
23
+
24
+ def _result_to_get_result(self, result) -> GetResult:
25
+ ids = []
26
+ documents = []
27
+ metadatas = []
28
+
29
+ for hit in result["hits"]["hits"]:
30
+ ids.append(hit["_id"])
31
+ documents.append(hit["_source"].get("text"))
32
+ metadatas.append(hit["_source"].get("metadata"))
33
+
34
+ return GetResult(ids=ids, documents=documents, metadatas=metadatas)
35
+
36
+ def _result_to_search_result(self, result) -> SearchResult:
37
+ ids = []
38
+ distances = []
39
+ documents = []
40
+ metadatas = []
41
+
42
+ for hit in result["hits"]["hits"]:
43
+ ids.append(hit["_id"])
44
+ distances.append(hit["_score"])
45
+ documents.append(hit["_source"].get("text"))
46
+ metadatas.append(hit["_source"].get("metadata"))
47
+
48
+ return SearchResult(
49
+ ids=ids, distances=distances, documents=documents, metadatas=metadatas
50
+ )
51
+
52
+ def _create_index(self, index_name: str, dimension: int):
53
+ body = {
54
+ "mappings": {
55
+ "properties": {
56
+ "id": {"type": "keyword"},
57
+ "vector": {
58
+ "type": "dense_vector",
59
+ "dims": dimension, # Adjust based on your vector dimensions
60
+ "index": True,
61
+ "similarity": "faiss",
62
+ "method": {
63
+ "name": "hnsw",
64
+ "space_type": "ip", # Use inner product to approximate cosine similarity
65
+ "engine": "faiss",
66
+ "ef_construction": 128,
67
+ "m": 16,
68
+ },
69
+ },
70
+ "text": {"type": "text"},
71
+ "metadata": {"type": "object"},
72
+ }
73
+ }
74
+ }
75
+ self.client.indices.create(index=f"{self.index_prefix}_{index_name}", body=body)
76
+
77
+ def _create_batches(self, items: list[VectorItem], batch_size=100):
78
+ for i in range(0, len(items), batch_size):
79
+ yield items[i : i + batch_size]
80
+
81
+ def has_collection(self, index_name: str) -> bool:
82
+ # has_collection here means has index.
83
+ # We are simply adapting to the norms of the other DBs.
84
+ return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}")
85
+
86
+ def delete_collection(self, index_name: str):
87
+ # delete_collection here means delete index.
88
+ # We are simply adapting to the norms of the other DBs.
89
+ self.client.indices.delete(index=f"{self.index_prefix}_{index_name}")
90
+
91
+ def search(
92
+ self, index_name: str, vectors: list[list[float]], limit: int
93
+ ) -> Optional[SearchResult]:
94
+ query = {
95
+ "size": limit,
96
+ "_source": ["text", "metadata"],
97
+ "query": {
98
+ "script_score": {
99
+ "query": {"match_all": {}},
100
+ "script": {
101
+ "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
102
+ "params": {
103
+ "vector": vectors[0]
104
+ }, # Assuming single query vector
105
+ },
106
+ }
107
+ },
108
+ }
109
+
110
+ result = self.client.search(
111
+ index=f"{self.index_prefix}_{index_name}", body=query
112
+ )
113
+
114
+ return self._result_to_search_result(result)
115
+
116
+ def get_or_create_index(self, index_name: str, dimension: int):
117
+ if not self.has_collection(index_name):
118
+ self._create_index(index_name, dimension)
119
+
120
+ def get(self, index_name: str) -> Optional[GetResult]:
121
+ query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
122
+
123
+ result = self.client.search(
124
+ index=f"{self.index_prefix}_{index_name}", body=query
125
+ )
126
+ return self._result_to_get_result(result)
127
+
128
+ def insert(self, index_name: str, items: list[VectorItem]):
129
+ if not self.has_collection(index_name):
130
+ self._create_index(index_name, dimension=len(items[0]["vector"]))
131
+
132
+ for batch in self._create_batches(items):
133
+ actions = [
134
+ {
135
+ "index": {
136
+ "_id": item["id"],
137
+ "_source": {
138
+ "vector": item["vector"],
139
+ "text": item["text"],
140
+ "metadata": item["metadata"],
141
+ },
142
+ }
143
+ }
144
+ for item in batch
145
+ ]
146
+ self.client.bulk(actions)
147
+
148
+ def upsert(self, index_name: str, items: list[VectorItem]):
149
+ if not self.has_collection(index_name):
150
+ self._create_index(index_name, dimension=len(items[0]["vector"]))
151
+
152
+ for batch in self._create_batches(items):
153
+ actions = [
154
+ {
155
+ "index": {
156
+ "_id": item["id"],
157
+ "_source": {
158
+ "vector": item["vector"],
159
+ "text": item["text"],
160
+ "metadata": item["metadata"],
161
+ },
162
+ }
163
+ }
164
+ for item in batch
165
+ ]
166
+ self.client.bulk(actions)
167
+
168
+ def delete(self, index_name: str, ids: list[str]):
169
+ actions = [
170
+ {"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}}
171
+ for id in ids
172
+ ]
173
+ self.client.bulk(body=actions)
174
+
175
+ def reset(self):
176
+ indices = self.client.indices.get(index=f"{self.index_prefix}_*")
177
+ for index in indices:
178
+ self.client.indices.delete(index=index)
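A minimal usage sketch for the client above. It assumes an OpenSearch instance reachable through the `OPENSEARCH_*` settings; the collection name, vector, and metadata are illustrative only.

```python
# Sketch only: requires a reachable OpenSearch configured via OPENSEARCH_URI etc.
from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient

client = OpenSearchClient()

items = [
    {
        "id": "doc-1",
        "vector": [0.1, 0.2, 0.3],               # index dimension is taken from the first item
        "text": "hello world",
        "metadata": {"source": "example.txt"},
    }
]
client.upsert("demo_collection", items)           # creates the index on first use

result = client.search("demo_collection", vectors=[[0.1, 0.2, 0.3]], limit=1)
if result:
    print(result.ids, result.distances)
```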
backend/open_webui/apps/retrieval/vector/dbs/pgvector.py ADDED
@@ -0,0 +1,354 @@
 
 
1
+ from typing import Optional, List, Dict, Any
2
+ from sqlalchemy import (
3
+ cast,
4
+ column,
5
+ create_engine,
6
+ Column,
7
+ Integer,
8
+ select,
9
+ text,
10
+ Text,
11
+ values,
12
+ )
13
+ from sqlalchemy.sql import true
14
+ from sqlalchemy.pool import NullPool
15
+
16
+ from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
17
+ from sqlalchemy.dialects.postgresql import JSONB, array
18
+ from pgvector.sqlalchemy import Vector
19
+ from sqlalchemy.ext.mutable import MutableDict
20
+
21
+ from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
22
+ from open_webui.config import PGVECTOR_DB_URL
23
+
24
+ VECTOR_LENGTH = 1536
25
+ Base = declarative_base()
26
+
27
+
28
+ class DocumentChunk(Base):
29
+ __tablename__ = "document_chunk"
30
+
31
+ id = Column(Text, primary_key=True)
32
+ vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
33
+ collection_name = Column(Text, nullable=False)
34
+ text = Column(Text, nullable=True)
35
+ vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
36
+
37
+
38
+ class PgvectorClient:
39
+ def __init__(self) -> None:
40
+
41
+ # if no pgvector uri, use the existing database connection
42
+ if not PGVECTOR_DB_URL:
43
+ from open_webui.apps.webui.internal.db import Session
44
+
45
+ self.session = Session
46
+ else:
47
+ engine = create_engine(
48
+ PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
49
+ )
50
+ SessionLocal = sessionmaker(
51
+ autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
52
+ )
53
+ self.session = scoped_session(SessionLocal)
54
+
55
+ try:
56
+ # Ensure the pgvector extension is available
57
+ self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
58
+
59
+ # Create the tables if they do not exist
60
+ # Base.metadata.create_all requires a bind (engine or connection)
61
+ # Get the connection from the session
62
+ connection = self.session.connection()
63
+ Base.metadata.create_all(bind=connection)
64
+
65
+ # Create an index on the vector column if it doesn't exist
66
+ self.session.execute(
67
+ text(
68
+ "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
69
+ "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
70
+ )
71
+ )
72
+ self.session.execute(
73
+ text(
74
+ "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
75
+ "ON document_chunk (collection_name);"
76
+ )
77
+ )
78
+ self.session.commit()
79
+ print("Initialization complete.")
80
+ except Exception as e:
81
+ self.session.rollback()
82
+ print(f"Error during initialization: {e}")
83
+ raise
84
+
85
+ def adjust_vector_length(self, vector: List[float]) -> List[float]:
86
+ # Adjust vector to have length VECTOR_LENGTH
87
+ current_length = len(vector)
88
+ if current_length < VECTOR_LENGTH:
89
+ # Pad the vector with zeros
90
+ vector += [0.0] * (VECTOR_LENGTH - current_length)
91
+ elif current_length > VECTOR_LENGTH:
92
+ raise Exception(
93
+ f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
94
+ )
95
+ return vector
96
+
97
+ def insert(self, collection_name: str, items: List[VectorItem]) -> None:
98
+ try:
99
+ new_items = []
100
+ for item in items:
101
+ vector = self.adjust_vector_length(item["vector"])
102
+ new_chunk = DocumentChunk(
103
+ id=item["id"],
104
+ vector=vector,
105
+ collection_name=collection_name,
106
+ text=item["text"],
107
+ vmetadata=item["metadata"],
108
+ )
109
+ new_items.append(new_chunk)
110
+ self.session.bulk_save_objects(new_items)
111
+ self.session.commit()
112
+ print(
113
+ f"Inserted {len(new_items)} items into collection '{collection_name}'."
114
+ )
115
+ except Exception as e:
116
+ self.session.rollback()
117
+ print(f"Error during insert: {e}")
118
+ raise
119
+
120
+ def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
121
+ try:
122
+ for item in items:
123
+ vector = self.adjust_vector_length(item["vector"])
124
+ existing = (
125
+ self.session.query(DocumentChunk)
126
+ .filter(DocumentChunk.id == item["id"])
127
+ .first()
128
+ )
129
+ if existing:
130
+ existing.vector = vector
131
+ existing.text = item["text"]
132
+ existing.vmetadata = item["metadata"]
133
+ existing.collection_name = (
134
+ collection_name # Update collection_name if necessary
135
+ )
136
+ else:
137
+ new_chunk = DocumentChunk(
138
+ id=item["id"],
139
+ vector=vector,
140
+ collection_name=collection_name,
141
+ text=item["text"],
142
+ vmetadata=item["metadata"],
143
+ )
144
+ self.session.add(new_chunk)
145
+ self.session.commit()
146
+ print(f"Upserted {len(items)} items into collection '{collection_name}'.")
147
+ except Exception as e:
148
+ self.session.rollback()
149
+ print(f"Error during upsert: {e}")
150
+ raise
151
+
152
+ def search(
153
+ self,
154
+ collection_name: str,
155
+ vectors: List[List[float]],
156
+ limit: Optional[int] = None,
157
+ ) -> Optional[SearchResult]:
158
+ try:
159
+ if not vectors:
160
+ return None
161
+
162
+ # Adjust query vectors to VECTOR_LENGTH
163
+ vectors = [self.adjust_vector_length(vector) for vector in vectors]
164
+ num_queries = len(vectors)
165
+
166
+ def vector_expr(vector):
167
+ return cast(array(vector), Vector(VECTOR_LENGTH))
168
+
169
+ # Create the values for query vectors
170
+ qid_col = column("qid", Integer)
171
+ q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
172
+ query_vectors = (
173
+ values(qid_col, q_vector_col)
174
+ .data(
175
+ [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
176
+ )
177
+ .alias("query_vectors")
178
+ )
179
+
180
+ # Build the lateral subquery for each query vector
181
+ subq = (
182
+ select(
183
+ DocumentChunk.id,
184
+ DocumentChunk.text,
185
+ DocumentChunk.vmetadata,
186
+ (
187
+ DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
188
+ ).label("distance"),
189
+ )
190
+ .where(DocumentChunk.collection_name == collection_name)
191
+ .order_by(
192
+ (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
193
+ )
194
+ )
195
+ if limit is not None:
196
+ subq = subq.limit(limit)
197
+ subq = subq.lateral("result")
198
+
199
+ # Build the main query by joining query_vectors and the lateral subquery
200
+ stmt = (
201
+ select(
202
+ query_vectors.c.qid,
203
+ subq.c.id,
204
+ subq.c.text,
205
+ subq.c.vmetadata,
206
+ subq.c.distance,
207
+ )
208
+ .select_from(query_vectors)
209
+ .join(subq, true())
210
+ .order_by(query_vectors.c.qid, subq.c.distance)
211
+ )
212
+
213
+ result_proxy = self.session.execute(stmt)
214
+ results = result_proxy.all()
215
+
216
+ ids = [[] for _ in range(num_queries)]
217
+ distances = [[] for _ in range(num_queries)]
218
+ documents = [[] for _ in range(num_queries)]
219
+ metadatas = [[] for _ in range(num_queries)]
220
+
221
+ if not results:
222
+ return SearchResult(
223
+ ids=ids,
224
+ distances=distances,
225
+ documents=documents,
226
+ metadatas=metadatas,
227
+ )
228
+
229
+ for row in results:
230
+ qid = int(row.qid)
231
+ ids[qid].append(row.id)
232
+ distances[qid].append(row.distance)
233
+ documents[qid].append(row.text)
234
+ metadatas[qid].append(row.vmetadata)
235
+
236
+ return SearchResult(
237
+ ids=ids, distances=distances, documents=documents, metadatas=metadatas
238
+ )
239
+ except Exception as e:
240
+ print(f"Error during search: {e}")
241
+ return None
242
+
243
+ def query(
244
+ self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
245
+ ) -> Optional[GetResult]:
246
+ try:
247
+ query = self.session.query(DocumentChunk).filter(
248
+ DocumentChunk.collection_name == collection_name
249
+ )
250
+
251
+ for key, value in filter.items():
252
+ query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
253
+
254
+ if limit is not None:
255
+ query = query.limit(limit)
256
+
257
+ results = query.all()
258
+
259
+ if not results:
260
+ return None
261
+
262
+ ids = [[result.id for result in results]]
263
+ documents = [[result.text for result in results]]
264
+ metadatas = [[result.vmetadata for result in results]]
265
+
266
+ return GetResult(
267
+ ids=ids,
268
+ documents=documents,
269
+ metadatas=metadatas,
270
+ )
271
+ except Exception as e:
272
+ print(f"Error during query: {e}")
273
+ return None
274
+
275
+ def get(
276
+ self, collection_name: str, limit: Optional[int] = None
277
+ ) -> Optional[GetResult]:
278
+ try:
279
+ query = self.session.query(DocumentChunk).filter(
280
+ DocumentChunk.collection_name == collection_name
281
+ )
282
+ if limit is not None:
283
+ query = query.limit(limit)
284
+
285
+ results = query.all()
286
+
287
+ if not results:
288
+ return None
289
+
290
+ ids = [[result.id for result in results]]
291
+ documents = [[result.text for result in results]]
292
+ metadatas = [[result.vmetadata for result in results]]
293
+
294
+ return GetResult(ids=ids, documents=documents, metadatas=metadatas)
295
+ except Exception as e:
296
+ print(f"Error during get: {e}")
297
+ return None
298
+
299
+ def delete(
300
+ self,
301
+ collection_name: str,
302
+ ids: Optional[List[str]] = None,
303
+ filter: Optional[Dict[str, Any]] = None,
304
+ ) -> None:
305
+ try:
306
+ query = self.session.query(DocumentChunk).filter(
307
+ DocumentChunk.collection_name == collection_name
308
+ )
309
+ if ids:
310
+ query = query.filter(DocumentChunk.id.in_(ids))
311
+ if filter:
312
+ for key, value in filter.items():
313
+ query = query.filter(
314
+ DocumentChunk.vmetadata[key].astext == str(value)
315
+ )
316
+ deleted = query.delete(synchronize_session=False)
317
+ self.session.commit()
318
+ print(f"Deleted {deleted} items from collection '{collection_name}'.")
319
+ except Exception as e:
320
+ self.session.rollback()
321
+ print(f"Error during delete: {e}")
322
+ raise
323
+
324
+ def reset(self) -> None:
325
+ try:
326
+ deleted = self.session.query(DocumentChunk).delete()
327
+ self.session.commit()
328
+ print(
329
+ f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
330
+ )
331
+ except Exception as e:
332
+ self.session.rollback()
333
+ print(f"Error during reset: {e}")
334
+ raise
335
+
336
+ def close(self) -> None:
337
+ pass
338
+
339
+ def has_collection(self, collection_name: str) -> bool:
340
+ try:
341
+ exists = (
342
+ self.session.query(DocumentChunk)
343
+ .filter(DocumentChunk.collection_name == collection_name)
344
+ .first()
345
+ is not None
346
+ )
347
+ return exists
348
+ except Exception as e:
349
+ print(f"Error checking collection existence: {e}")
350
+ return False
351
+
352
+ def delete_collection(self, collection_name: str) -> None:
353
+ self.delete(collection_name)
354
+ print(f"Collection '{collection_name}' deleted.")
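A brief usage sketch for the client above. It assumes Postgres with the pgvector extension is reachable (via `PGVECTOR_DB_URL` or the main database connection); vectors shorter than `VECTOR_LENGTH` are zero-padded by `adjust_vector_length`, and the data below is illustrative.

```python
# Sketch only: requires Postgres with the pgvector extension available.
from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient

client = PgvectorClient()

client.upsert(
    "demo_collection",
    [
        {
            "id": "chunk-1",
            "vector": [0.05, 0.10, 0.15],   # padded to VECTOR_LENGTH (1536) internally
            "text": "hello world",
            "metadata": {"source": "example.txt"},
        }
    ],
)

result = client.search("demo_collection", vectors=[[0.05, 0.10, 0.15]], limit=3)
if result:
    print(result.ids[0], result.distances[0])
```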
backend/open_webui/apps/retrieval/vector/dbs/qdrant.py CHANGED
@@ -5,7 +5,7 @@ from qdrant_client.http.models import PointStruct
5
  from qdrant_client.models import models
6
 
7
  from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
8
- from open_webui.config import QDRANT_URI
9
 
10
  NO_LIMIT = 999999999
11
 
@@ -14,7 +14,12 @@ class QdrantClient:
14
  def __init__(self):
15
  self.collection_prefix = "open-webui"
16
  self.QDRANT_URI = QDRANT_URI
17
- self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None
 
 
 
 
 
18
 
19
  def _result_to_get_result(self, points) -> GetResult:
20
  ids = []
 
5
  from qdrant_client.models import models
6
 
7
  from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
8
+ from open_webui.config import QDRANT_URI, QDRANT_API_KEY
9
 
10
  NO_LIMIT = 999999999
11
 
 
14
  def __init__(self):
15
  self.collection_prefix = "open-webui"
16
  self.QDRANT_URI = QDRANT_URI
17
+ self.QDRANT_API_KEY = QDRANT_API_KEY
18
+ self.client = (
19
+ Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
20
+ if self.QDRANT_URI
21
+ else None
22
+ )
23
 
24
  def _result_to_get_result(self, points) -> GetResult:
25
  ids = []
backend/open_webui/apps/retrieval/web/bing.py ADDED
@@ -0,0 +1,73 @@
 
 
1
+ import logging
2
+ import os
3
+ from pprint import pprint
4
+ from typing import Optional
5
+ import requests
6
+ from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
7
+ from open_webui.env import SRC_LOG_LEVELS
8
+ import argparse
9
+
10
+ log = logging.getLogger(__name__)
11
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
12
+ """
13
+ Documentation: https://docs.microsoft.com/en-us/bing/search-apis/bing-web-search/overview
14
+ """
15
+
16
+
17
+ def search_bing(
18
+ subscription_key: str,
19
+ endpoint: str,
20
+ locale: str,
21
+ query: str,
22
+ count: int,
23
+ filter_list: Optional[list[str]] = None,
24
+ ) -> list[SearchResult]:
25
+ mkt = locale
26
+ params = {"q": query, "mkt": mkt, "answerCount": count}
27
+ headers = {"Ocp-Apim-Subscription-Key": subscription_key}
28
+
29
+ try:
30
+ response = requests.get(endpoint, headers=headers, params=params)
31
+ response.raise_for_status()
32
+ json_response = response.json()
33
+ results = json_response.get("webPages", {}).get("value", [])
34
+ if filter_list:
35
+ results = get_filtered_results(results, filter_list)
36
+ return [
37
+ SearchResult(
38
+ link=result["url"],
39
+ title=result.get("name"),
40
+ snippet=result.get("snippet"),
41
+ )
42
+ for result in results
43
+ ]
44
+ except Exception as ex:
45
+ log.error(f"Error: {ex}")
46
+ raise ex
47
+
48
+
49
+ def main():
50
+ parser = argparse.ArgumentParser(description="Search Bing from the command line.")
51
+ parser.add_argument(
52
+ "query",
53
+ type=str,
54
+ default="Top 10 international news today",
55
+ help="The search query.",
56
+ )
57
+ parser.add_argument(
58
+ "--count", type=int, default=10, help="Number of search results to return."
59
+ )
60
+ parser.add_argument(
61
+ "--filter", nargs="*", help="List of filters to apply to the search results."
62
+ )
63
+ parser.add_argument(
64
+ "--locale",
65
+ type=str,
66
+ default="en-US",
67
+ help="The locale to use for the search, maps to market in api",
68
+ )
69
+
70
+ args = parser.parse_args()
71
+
72
+ results = search_bing(os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""), os.environ.get("BING_SEARCH_V7_ENDPOINT", "https://api.bing.microsoft.com/v7.0/search"), args.locale, args.query, args.count, args.filter)
73
+ pprint(results)
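Beyond the CLI helper, the function can be called directly. In the sketch below the subscription key is a placeholder and the endpoint is the publicly documented Bing Web Search v7 URL; adjust both to your deployment.

```python
# Sketch only: key is a placeholder; endpoint is the documented Bing Web Search v7 URL.
from open_webui.apps.retrieval.web.bing import search_bing

results = search_bing(
    subscription_key="<your-bing-subscription-key>",
    endpoint="https://api.bing.microsoft.com/v7.0/search",
    locale="en-US",
    query="open source web UIs for LLMs",
    count=5,
)
for r in results:
    print(r.link, r.title)
```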
backend/open_webui/apps/retrieval/web/jina_search.py CHANGED
@@ -9,7 +9,7 @@ log = logging.getLogger(__name__)
9
  log.setLevel(SRC_LOG_LEVELS["RAG"])
10
 
11
 
12
- def search_jina(query: str, count: int) -> list[SearchResult]:
13
  """
14
  Search using Jina's Search API and return the results as a list of SearchResult objects.
15
  Args:
@@ -20,9 +20,7 @@ def search_jina(query: str, count: int) -> list[SearchResult]:
20
  list[SearchResult]: A list of search results
21
  """
22
  jina_search_endpoint = "https://s.jina.ai/"
23
- headers = {
24
- "Accept": "application/json",
25
- }
26
  url = str(URL(jina_search_endpoint + query))
27
  response = requests.get(url, headers=headers)
28
  response.raise_for_status()
 
9
  log.setLevel(SRC_LOG_LEVELS["RAG"])
10
 
11
 
12
+ def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
13
  """
14
  Search using Jina's Search API and return the results as a list of SearchResult objects.
15
  Args:
 
20
  list[SearchResult]: A list of search results
21
  """
22
  jina_search_endpoint = "https://s.jina.ai/"
23
+ headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"}
 
 
24
  url = str(URL(jina_search_endpoint + query))
25
  response = requests.get(url, headers=headers)
26
  response.raise_for_status()
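With the API key now passed explicitly, a call looks like the sketch below; reading the key from a `JINA_API_KEY` environment variable is an assumption for illustration.

```python
# Sketch only: the key lookup via JINA_API_KEY is illustrative.
import os

from open_webui.apps.retrieval.web.jina_search import search_jina

results = search_jina(
    api_key=os.environ.get("JINA_API_KEY", ""),
    query="retrieval augmented generation",
    count=5,
)
print([r.link for r in results])
```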
backend/open_webui/apps/retrieval/web/testdata/bing.json ADDED
@@ -0,0 +1,58 @@
 
 
1
+ {
2
+ "_type": "SearchResponse",
3
+ "queryContext": {
4
+ "originalQuery": "Top 10 international results"
5
+ },
6
+ "webPages": {
7
+ "webSearchUrl": "https://www.bing.com/search?q=Top+10+international+results",
8
+ "totalEstimatedMatches": 687,
9
+ "value": [
10
+ {
11
+ "id": "https://api.bing.microsoft.com/api/v7/#WebPages.0",
12
+ "name": "2024 Mexican Grand Prix - F1 results and latest standings ... - PlanetF1",
13
+ "url": "https://www.planetf1.com/news/f1-results-2024-mexican-grand-prix-race-standings",
14
+ "datePublished": "2024-10-27T00:00:00.0000000",
15
+ "datePublishedFreshnessText": "1 day ago",
16
+ "isFamilyFriendly": true,
17
+ "displayUrl": "https://www.planetf1.com/news/f1-results-2024-mexican-grand-prix-race-standings",
18
+ "snippet": "Nico Hulkenberg and Pierre Gasly completed the top 10. A full report of the Mexican Grand Prix is available at the bottom of this article. F1 results – 2024 Mexican Grand Prix",
19
+ "dateLastCrawled": "2024-10-28T07:15:00.0000000Z",
20
+ "cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=916492551782&mkt=en-US&setlang=en-US&w=zBsfaAPyF2tUrHFHr_vFFdUm8sng4g34",
21
+ "language": "en",
22
+ "isNavigational": false,
23
+ "noCache": false
24
+ },
25
+ {
26
+ "id": "https://api.bing.microsoft.com/api/v7/#WebPages.1",
27
+ "name": "F1 Results Today: HUGE Verstappen penalties cause major title change",
28
+ "url": "https://www.gpfans.com/en/f1-news/1033512/f1-results-today-mexican-grand-prix-huge-max-verstappen-penalties-cause-major-title-change/",
29
+ "datePublished": "2024-10-27T00:00:00.0000000",
30
+ "datePublishedFreshnessText": "1 day ago",
31
+ "isFamilyFriendly": true,
32
+ "displayUrl": "https://www.gpfans.com/en/f1-news/1033512/f1-results-today-mexican-grand-prix-huge-max...",
33
+ "snippet": "Elsewhere, Mercedes duo Lewis Hamilton and George Russell came home in P4 and P5 respectively. Meanwhile, the surprise package of the day were Haas, with both Kevin Magnussen and Nico Hulkenberg finishing inside the points.. READ MORE: RB star issues apology after red flag CRASH at Mexican GP Mexican Grand Prix 2024 results. 1. Carlos Sainz [Ferrari] 2. Lando Norris [McLaren] - +4.705",
34
+ "dateLastCrawled": "2024-10-28T06:06:00.0000000Z",
35
+ "cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=2840656522642&mkt=en-US&setlang=en-US&w=-Tbkwxnq52jZCvG7l3CtgcwT1vwAjIUD",
36
+ "language": "en",
37
+ "isNavigational": false,
38
+ "noCache": false
39
+ },
40
+ {
41
+ "id": "https://api.bing.microsoft.com/api/v7/#WebPages.2",
42
+ "name": "International Power Rankings: England flying, Kangaroos cruising, Fiji rise",
43
+ "url": "https://www.loverugbyleague.com/post/international-power-rankings-england-flying-kangaroos-cruising-fiji-rise",
44
+ "datePublished": "2024-10-28T00:00:00.0000000",
45
+ "datePublishedFreshnessText": "7 hours ago",
46
+ "isFamilyFriendly": true,
47
+ "displayUrl": "https://www.loverugbyleague.com/post/international-power-rankings-england-flying...",
48
+ "snippet": "LRL RECOMMENDS: England player ratings from first Test against Samoa as omnificent George Williams scores perfect 10. 2. Australia (Men) – SAME. The Kangaroos remain 2nd in our Power Rankings after their 22-10 win against New Zealand in Christchurch on Sunday. As was the case in their win against Tonga last week, Mal Meninga’s side weren ...",
49
+ "dateLastCrawled": "2024-10-28T07:09:00.0000000Z",
50
+ "cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=1535008462672&mkt=en-US&setlang=en-US&w=82ujhH4Kp0iuhCS7wh1xLUFYUeetaVVm",
51
+ "language": "en",
52
+ "isNavigational": false,
53
+ "noCache": false
54
+ }
55
+ ],
56
+ "someResultsRemoved": true
57
+ }
58
+ }
backend/open_webui/apps/socket/main.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import asyncio
2
  import socketio
3
  import logging
 
1
+ # TODO: move socket to webui app
2
+
3
  import asyncio
4
  import socketio
5
  import logging
backend/open_webui/apps/webui/main.py CHANGED
@@ -12,6 +12,7 @@ from open_webui.apps.webui.routers import (
12
  chats,
13
  folders,
14
  configs,
 
15
  files,
16
  functions,
17
  memories,
@@ -34,6 +35,7 @@ from open_webui.config import (
34
  ENABLE_LOGIN_FORM,
35
  ENABLE_MESSAGE_RATING,
36
  ENABLE_SIGNUP,
 
37
  ENABLE_EVALUATION_ARENA_MODELS,
38
  EVALUATION_ARENA_MODELS,
39
  DEFAULT_ARENA_MODEL,
@@ -50,9 +52,22 @@ from open_webui.config import (
50
  WEBHOOK_URL,
51
  WEBUI_AUTH,
52
  WEBUI_BANNERS,
 
 
53
  AppConfig,
54
  )
55
  from open_webui.env import (
 
56
  WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
57
  WEBUI_AUTH_TRUSTED_NAME_HEADER,
58
  )
@@ -72,7 +87,11 @@ from open_webui.utils.payload import (
72
 
73
  from open_webui.utils.tools import get_tools
74
 
75
- app = FastAPI()
 
 
 
 
76
 
77
  log = logging.getLogger(__name__)
78
 
@@ -80,6 +99,8 @@ app.state.config = AppConfig()
80
 
81
  app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
82
  app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
 
 
83
  app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
84
  app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
85
  app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
@@ -92,6 +113,8 @@ app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
92
  app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
93
  app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
94
  app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
 
 
95
  app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
96
  app.state.config.WEBHOOK_URL = WEBHOOK_URL
97
  app.state.config.BANNERS = WEBUI_BANNERS
@@ -111,7 +134,19 @@ app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
111
  app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
112
  app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
113
 
114
- app.state.MODELS = {}
 
 
 
115
  app.state.TOOLS = {}
116
  app.state.FUNCTIONS = {}
117
 
@@ -135,13 +170,15 @@ app.include_router(models.router, prefix="/models", tags=["models"])
135
  app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
136
  app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
137
  app.include_router(tools.router, prefix="/tools", tags=["tools"])
138
- app.include_router(functions.router, prefix="/functions", tags=["functions"])
139
 
140
  app.include_router(memories.router, prefix="/memories", tags=["memories"])
141
- app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
142
-
143
  app.include_router(folders.router, prefix="/folders", tags=["folders"])
 
 
144
  app.include_router(files.router, prefix="/files", tags=["files"])
 
 
 
145
 
146
  app.include_router(utils.router, prefix="/utils", tags=["utils"])
147
 
@@ -336,7 +373,7 @@ def get_function_params(function_module, form_data, user, extra_params=None):
336
  return params
337
 
338
 
339
- async def generate_function_chat_completion(form_data, user):
340
  model_id = form_data.get("model")
341
  model_info = Models.get_model_by_id(model_id)
342
 
@@ -372,6 +409,7 @@ async def generate_function_chat_completion(form_data, user):
372
  "name": user.name,
373
  "role": user.role,
374
  },
 
375
  }
376
  extra_params["__tools__"] = get_tools(
377
  app,
@@ -379,7 +417,7 @@ async def generate_function_chat_completion(form_data, user):
379
  user,
380
  {
381
  **extra_params,
382
- "__model__": app.state.MODELS[form_data["model"]],
383
  "__messages__": form_data["messages"],
384
  "__files__": files,
385
  },
 
12
  chats,
13
  folders,
14
  configs,
15
+ groups,
16
  files,
17
  functions,
18
  memories,
 
35
  ENABLE_LOGIN_FORM,
36
  ENABLE_MESSAGE_RATING,
37
  ENABLE_SIGNUP,
38
+ ENABLE_API_KEY,
39
  ENABLE_EVALUATION_ARENA_MODELS,
40
  EVALUATION_ARENA_MODELS,
41
  DEFAULT_ARENA_MODEL,
 
52
  WEBHOOK_URL,
53
  WEBUI_AUTH,
54
  WEBUI_BANNERS,
55
+ ENABLE_LDAP,
56
+ LDAP_SERVER_LABEL,
57
+ LDAP_SERVER_HOST,
58
+ LDAP_SERVER_PORT,
59
+ LDAP_ATTRIBUTE_FOR_USERNAME,
60
+ LDAP_SEARCH_FILTERS,
61
+ LDAP_SEARCH_BASE,
62
+ LDAP_APP_DN,
63
+ LDAP_APP_PASSWORD,
64
+ LDAP_USE_TLS,
65
+ LDAP_CA_CERT_FILE,
66
+ LDAP_CIPHERS,
67
  AppConfig,
68
  )
69
  from open_webui.env import (
70
+ ENV,
71
  WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
72
  WEBUI_AUTH_TRUSTED_NAME_HEADER,
73
  )
 
87
 
88
  from open_webui.utils.tools import get_tools
89
 
90
+ app = FastAPI(
91
+ docs_url="/docs" if ENV == "dev" else None,
92
+ openapi_url="/openapi.json" if ENV == "dev" else None,
93
+ redoc_url=None,
94
+ )
95
 
96
  log = logging.getLogger(__name__)
97
 
 
99
 
100
  app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
101
  app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
102
+ app.state.config.ENABLE_API_KEY = ENABLE_API_KEY
103
+
104
  app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
105
  app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
106
  app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
 
113
  app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
114
  app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
115
  app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
116
+
117
+
118
  app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
119
  app.state.config.WEBHOOK_URL = WEBHOOK_URL
120
  app.state.config.BANNERS = WEBUI_BANNERS
 
134
  app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
135
  app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
136
 
137
+ app.state.config.ENABLE_LDAP = ENABLE_LDAP
138
+ app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL
139
+ app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST
140
+ app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT
141
+ app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME
142
+ app.state.config.LDAP_APP_DN = LDAP_APP_DN
143
+ app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD
144
+ app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE
145
+ app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS
146
+ app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
147
+ app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
148
+ app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
149
+
150
  app.state.TOOLS = {}
151
  app.state.FUNCTIONS = {}
152
 
 
170
  app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
171
  app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
172
  app.include_router(tools.router, prefix="/tools", tags=["tools"])
 
173
 
174
  app.include_router(memories.router, prefix="/memories", tags=["memories"])
 
 
175
  app.include_router(folders.router, prefix="/folders", tags=["folders"])
176
+
177
+ app.include_router(groups.router, prefix="/groups", tags=["groups"])
178
  app.include_router(files.router, prefix="/files", tags=["files"])
179
+ app.include_router(functions.router, prefix="/functions", tags=["functions"])
180
+ app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
181
+
182
 
183
  app.include_router(utils.router, prefix="/utils", tags=["utils"])
184
 
 
373
  return params
374
 
375
 
376
+ async def generate_function_chat_completion(form_data, user, models: dict = {}):
377
  model_id = form_data.get("model")
378
  model_info = Models.get_model_by_id(model_id)
379
 
 
409
  "name": user.name,
410
  "role": user.role,
411
  },
412
+ "__metadata__": metadata,
413
  }
414
  extra_params["__tools__"] = get_tools(
415
  app,
 
417
  user,
418
  {
419
  **extra_params,
420
+ "__model__": models.get(form_data["model"], None),
421
  "__messages__": form_data["messages"],
422
  "__files__": files,
423
  },
backend/open_webui/apps/webui/models/auths.py CHANGED
@@ -64,6 +64,11 @@ class SigninForm(BaseModel):
64
  password: str
65
 
66
 
 
 
 
 
 
67
  class ProfileImageUrlForm(BaseModel):
68
  profile_image_url: str
69
 
 
64
  password: str
65
 
66
 
67
+ class LdapForm(BaseModel):
68
+ user: str
69
+ password: str
70
+
71
+
72
  class ProfileImageUrlForm(BaseModel):
73
  profile_image_url: str
74
 
backend/open_webui/apps/webui/models/chats.py CHANGED
@@ -203,15 +203,22 @@ class ChatTable:
203
  def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
204
  try:
205
  with get_db() as db:
206
- print("update_shared_chat_by_id")
207
  chat = db.get(Chat, chat_id)
208
- print(chat)
209
- chat.title = chat.title
210
- chat.chat = chat.chat
 
211
  db.commit()
212
- db.refresh(chat)
213
 
214
- return self.get_chat_by_id(chat.share_id)
215
  except Exception:
216
  return None
217
 
 
203
  def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
204
  try:
205
  with get_db() as db:
 
206
  chat = db.get(Chat, chat_id)
207
+ shared_chat = (
208
+ db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first()
209
+ )
210
+
211
+ if shared_chat is None:
212
+ return self.insert_shared_chat_by_chat_id(chat_id)
213
+
214
+ shared_chat.title = chat.title
215
+ shared_chat.chat = chat.chat
216
+
217
+ shared_chat.updated_at = int(time.time())
218
  db.commit()
219
+ db.refresh(shared_chat)
220
 
221
+ return ChatModel.model_validate(shared_chat)
222
  except Exception:
223
  return None
224
 
backend/open_webui/apps/webui/models/groups.py ADDED
@@ -0,0 +1,186 @@
 
 
1
+ import json
2
+ import logging
3
+ import time
4
+ from typing import Optional
5
+ import uuid
6
+
7
+ from open_webui.apps.webui.internal.db import Base, get_db
8
+ from open_webui.env import SRC_LOG_LEVELS
9
+
10
+ from open_webui.apps.webui.models.files import FileMetadataResponse
11
+
12
+
13
+ from pydantic import BaseModel, ConfigDict
14
+ from sqlalchemy import BigInteger, Column, String, Text, JSON, func
15
+
16
+
17
+ log = logging.getLogger(__name__)
18
+ log.setLevel(SRC_LOG_LEVELS["MODELS"])
19
+
20
+ ####################
21
+ # UserGroup DB Schema
22
+ ####################
23
+
24
+
25
+ class Group(Base):
26
+ __tablename__ = "group"
27
+
28
+ id = Column(Text, unique=True, primary_key=True)
29
+ user_id = Column(Text)
30
+
31
+ name = Column(Text)
32
+ description = Column(Text)
33
+
34
+ data = Column(JSON, nullable=True)
35
+ meta = Column(JSON, nullable=True)
36
+
37
+ permissions = Column(JSON, nullable=True)
38
+ user_ids = Column(JSON, nullable=True)
39
+
40
+ created_at = Column(BigInteger)
41
+ updated_at = Column(BigInteger)
42
+
43
+
44
+ class GroupModel(BaseModel):
45
+ model_config = ConfigDict(from_attributes=True)
46
+ id: str
47
+ user_id: str
48
+
49
+ name: str
50
+ description: str
51
+
52
+ data: Optional[dict] = None
53
+ meta: Optional[dict] = None
54
+
55
+ permissions: Optional[dict] = None
56
+ user_ids: list[str] = []
57
+
58
+ created_at: int # timestamp in epoch
59
+ updated_at: int # timestamp in epoch
60
+
61
+
62
+ ####################
63
+ # Forms
64
+ ####################
65
+
66
+
67
+ class GroupResponse(BaseModel):
68
+ id: str
69
+ user_id: str
70
+ name: str
71
+ description: str
72
+ permissions: Optional[dict] = None
73
+ data: Optional[dict] = None
74
+ meta: Optional[dict] = None
75
+ user_ids: list[str] = []
76
+ created_at: int # timestamp in epoch
77
+ updated_at: int # timestamp in epoch
78
+
79
+
80
+ class GroupForm(BaseModel):
81
+ name: str
82
+ description: str
83
+
84
+
85
+ class GroupUpdateForm(GroupForm):
86
+ permissions: Optional[dict] = None
87
+ user_ids: Optional[list[str]] = None
88
+ admin_ids: Optional[list[str]] = None
89
+
90
+
91
+ class GroupTable:
92
+ def insert_new_group(
93
+ self, user_id: str, form_data: GroupForm
94
+ ) -> Optional[GroupModel]:
95
+ with get_db() as db:
96
+ group = GroupModel(
97
+ **{
98
+ **form_data.model_dump(),
99
+ "id": str(uuid.uuid4()),
100
+ "user_id": user_id,
101
+ "created_at": int(time.time()),
102
+ "updated_at": int(time.time()),
103
+ }
104
+ )
105
+
106
+ try:
107
+ result = Group(**group.model_dump())
108
+ db.add(result)
109
+ db.commit()
110
+ db.refresh(result)
111
+ if result:
112
+ return GroupModel.model_validate(result)
113
+ else:
114
+ return None
115
+
116
+ except Exception:
117
+ return None
118
+
119
+ def get_groups(self) -> list[GroupModel]:
120
+ with get_db() as db:
121
+ return [
122
+ GroupModel.model_validate(group)
123
+ for group in db.query(Group).order_by(Group.updated_at.desc()).all()
124
+ ]
125
+
126
+ def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
127
+ with get_db() as db:
128
+ return [
129
+ GroupModel.model_validate(group)
130
+ for group in db.query(Group)
131
+ .filter(
132
+ func.json_array_length(Group.user_ids) > 0
133
+ ) # Ensure array exists
134
+ .filter(
135
+ Group.user_ids.cast(String).like(f'%"{user_id}"%')
136
+ ) # String-based check
137
+ .order_by(Group.updated_at.desc())
138
+ .all()
139
+ ]
140
+
141
+ def get_group_by_id(self, id: str) -> Optional[GroupModel]:
142
+ try:
143
+ with get_db() as db:
144
+ group = db.query(Group).filter_by(id=id).first()
145
+ return GroupModel.model_validate(group) if group else None
146
+ except Exception:
147
+ return None
148
+
149
+ def update_group_by_id(
150
+ self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
151
+ ) -> Optional[GroupModel]:
152
+ try:
153
+ with get_db() as db:
154
+ db.query(Group).filter_by(id=id).update(
155
+ {
156
+ **form_data.model_dump(exclude_none=True),
157
+ "updated_at": int(time.time()),
158
+ }
159
+ )
160
+ db.commit()
161
+ return self.get_group_by_id(id=id)
162
+ except Exception as e:
163
+ log.exception(e)
164
+ return None
165
+
166
+ def delete_group_by_id(self, id: str) -> bool:
167
+ try:
168
+ with get_db() as db:
169
+ db.query(Group).filter_by(id=id).delete()
170
+ db.commit()
171
+ return True
172
+ except Exception:
173
+ return False
174
+
175
+ def delete_all_groups(self) -> bool:
176
+ with get_db() as db:
177
+ try:
178
+ db.query(Group).delete()
179
+ db.commit()
180
+
181
+ return True
182
+ except Exception:
183
+ return False
184
+
185
+
186
+ Groups = GroupTable()
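A minimal usage sketch for the new Groups table, using only methods defined above; the ids and the shape of the permissions dict are illustrative assumptions, not values taken from this commit.

    from open_webui.apps.webui.models.groups import Groups, GroupForm, GroupUpdateForm

    # Create a group owned by an admin (ids below are placeholders).
    group = Groups.insert_new_group(
        user_id="admin-user-id",
        form_data=GroupForm(name="engineering", description="Engineering team"),
    )

    if group:
        # GroupUpdateForm extends GroupForm, so name/description must be re-sent.
        Groups.update_group_by_id(
            group.id,
            GroupUpdateForm(
                name=group.name,
                description=group.description,
                user_ids=["user-1", "user-2"],
                permissions={"workspace": {"models": True}},  # assumed permission keys
            ),
        )

    # Membership lookup is a JSON/string match against the user_ids column.
    print([g.name for g in Groups.get_groups_by_member_id("user-1")])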
backend/open_webui/apps/webui/models/knowledge.py CHANGED
@@ -8,11 +8,13 @@ from open_webui.apps.webui.internal.db import Base, get_db
8
  from open_webui.env import SRC_LOG_LEVELS
9
 
10
  from open_webui.apps.webui.models.files import FileMetadataResponse
 
11
 
12
 
13
  from pydantic import BaseModel, ConfigDict
14
  from sqlalchemy import BigInteger, Column, String, Text, JSON
15
 
 
16
 
17
  log = logging.getLogger(__name__)
18
  log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -34,6 +36,23 @@ class Knowledge(Base):
34
  data = Column(JSON, nullable=True)
35
  meta = Column(JSON, nullable=True)
36
 
 
 
 
 
37
  created_at = Column(BigInteger)
38
  updated_at = Column(BigInteger)
39
 
@@ -50,6 +69,8 @@ class KnowledgeModel(BaseModel):
50
  data: Optional[dict] = None
51
  meta: Optional[dict] = None
52
 
 
 
53
  created_at: int # timestamp in epoch
54
  updated_at: int # timestamp in epoch
55
 
@@ -59,15 +80,15 @@ class KnowledgeModel(BaseModel):
59
  ####################
60
 
61
 
62
- class KnowledgeResponse(BaseModel):
63
- id: str
64
- name: str
65
- description: str
66
- data: Optional[dict] = None
67
- meta: Optional[dict] = None
68
- created_at: int # timestamp in epoch
69
- updated_at: int # timestamp in epoch
70
 
 
 
71
  files: Optional[list[FileMetadataResponse | dict]] = None
72
 
73
 
@@ -75,12 +96,7 @@ class KnowledgeForm(BaseModel):
75
  name: str
76
  description: str
77
  data: Optional[dict] = None
78
-
79
-
80
- class KnowledgeUpdateForm(BaseModel):
81
- name: Optional[str] = None
82
- description: Optional[str] = None
83
- data: Optional[dict] = None
84
 
85
 
86
  class KnowledgeTable:
@@ -110,14 +126,33 @@ class KnowledgeTable:
110
  except Exception:
111
  return None
112
 
113
- def get_knowledge_items(self) -> list[KnowledgeModel]:
114
  with get_db() as db:
115
- return [
116
- KnowledgeModel.model_validate(knowledge)
117
- for knowledge in db.query(Knowledge)
118
- .order_by(Knowledge.updated_at.desc())
119
- .all()
120
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
123
  try:
@@ -128,14 +163,32 @@ class KnowledgeTable:
128
  return None
129
 
130
  def update_knowledge_by_id(
131
- self, id: str, form_data: KnowledgeUpdateForm, overwrite: bool = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  ) -> Optional[KnowledgeModel]:
133
  try:
134
  with get_db() as db:
135
  knowledge = self.get_knowledge_by_id(id=id)
136
  db.query(Knowledge).filter_by(id=id).update(
137
  {
138
- **form_data.model_dump(exclude_none=True),
139
  "updated_at": int(time.time()),
140
  }
141
  )
 
8
  from open_webui.env import SRC_LOG_LEVELS
9
 
10
  from open_webui.apps.webui.models.files import FileMetadataResponse
11
+ from open_webui.apps.webui.models.users import Users, UserResponse
12
 
13
 
14
  from pydantic import BaseModel, ConfigDict
15
  from sqlalchemy import BigInteger, Column, String, Text, JSON
16
 
17
+ from open_webui.utils.access_control import has_access
18
 
19
  log = logging.getLogger(__name__)
20
  log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
36
  data = Column(JSON, nullable=True)
37
  meta = Column(JSON, nullable=True)
38
 
39
+ access_control = Column(JSON, nullable=True) # Controls data access levels.
40
+ # Defines access control rules for this entry.
41
+ # - `None`: Public access, available to all users with the "user" role.
42
+ # - `{}`: Private access, restricted exclusively to the owner.
43
+ # - Custom permissions: Specific access control for reading and writing;
44
+ # Can specify group or user-level restrictions:
45
+ # {
46
+ # "read": {
47
+ # "group_ids": ["group_id1", "group_id2"],
48
+ # "user_ids": ["user_id1", "user_id2"]
49
+ # },
50
+ # "write": {
51
+ # "group_ids": ["group_id1", "group_id2"],
52
+ # "user_ids": ["user_id1", "user_id2"]
53
+ # }
54
+ # }
55
+
56
  created_at = Column(BigInteger)
57
  updated_at = Column(BigInteger)
58
 
 
69
  data: Optional[dict] = None
70
  meta: Optional[dict] = None
71
 
72
+ access_control: Optional[dict] = None
73
+
74
  created_at: int # timestamp in epoch
75
  updated_at: int # timestamp in epoch
76
 
 
80
  ####################
81
 
82
 
83
+ class KnowledgeUserModel(KnowledgeModel):
84
+ user: Optional[UserResponse] = None
85
+
86
+
87
+ class KnowledgeResponse(KnowledgeModel):
88
+ files: Optional[list[FileMetadataResponse | dict]] = None
 
 
89
 
90
+
91
+ class KnowledgeUserResponse(KnowledgeUserModel):
92
  files: Optional[list[FileMetadataResponse | dict]] = None
93
 
94
 
 
96
  name: str
97
  description: str
98
  data: Optional[dict] = None
99
+ access_control: Optional[dict] = None
 
 
 
 
 
100
 
101
 
102
  class KnowledgeTable:
 
126
  except Exception:
127
  return None
128
 
129
+ def get_knowledge_bases(self) -> list[KnowledgeUserModel]:
130
  with get_db() as db:
131
+ knowledge_bases = []
132
+ for knowledge in (
133
+ db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
134
+ ):
135
+ user = Users.get_user_by_id(knowledge.user_id)
136
+ knowledge_bases.append(
137
+ KnowledgeUserModel.model_validate(
138
+ {
139
+ **KnowledgeModel.model_validate(knowledge).model_dump(),
140
+ "user": user.model_dump() if user else None,
141
+ }
142
+ )
143
+ )
144
+ return knowledge_bases
145
+
146
+ def get_knowledge_bases_by_user_id(
147
+ self, user_id: str, permission: str = "write"
148
+ ) -> list[KnowledgeUserModel]:
149
+ knowledge_bases = self.get_knowledge_bases()
150
+ return [
151
+ knowledge_base
152
+ for knowledge_base in knowledge_bases
153
+ if knowledge_base.user_id == user_id
154
+ or has_access(user_id, permission, knowledge_base.access_control)
155
+ ]
156
 
157
  def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
158
  try:
 
163
  return None
164
 
165
  def update_knowledge_by_id(
166
+ self, id: str, form_data: KnowledgeForm, overwrite: bool = False
167
+ ) -> Optional[KnowledgeModel]:
168
+ try:
169
+ with get_db() as db:
170
+ knowledge = self.get_knowledge_by_id(id=id)
171
+ db.query(Knowledge).filter_by(id=id).update(
172
+ {
173
+ **form_data.model_dump(),
174
+ "updated_at": int(time.time()),
175
+ }
176
+ )
177
+ db.commit()
178
+ return self.get_knowledge_by_id(id=id)
179
+ except Exception as e:
180
+ log.exception(e)
181
+ return None
182
+
183
+ def update_knowledge_data_by_id(
184
+ self, id: str, data: dict
185
  ) -> Optional[KnowledgeModel]:
186
  try:
187
  with get_db() as db:
188
  knowledge = self.get_knowledge_by_id(id=id)
189
  db.query(Knowledge).filter_by(id=id).update(
190
  {
191
+ "data": data,
192
  "updated_at": int(time.time()),
193
  }
194
  )
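A short sketch of how the new access_control column interacts with has_access when listing knowledge bases; the group and user ids are placeholders, and the read/write split follows the comment in the column definition above.

    from open_webui.apps.webui.models.knowledge import Knowledges, KnowledgeForm
    from open_webui.utils.access_control import has_access

    # Restrict reads to one group and writes to the owner plus one extra user.
    form = KnowledgeForm(
        name="Team docs",
        description="Internal documentation",
        access_control={
            "read": {"group_ids": ["group-id-1"], "user_ids": []},
            "write": {"group_ids": [], "user_ids": ["user-id-2"]},
        },
    )
    kb = Knowledges.insert_new_knowledge("owner-user-id", form)

    # get_knowledge_bases_by_user_id applies the same check internally:
    # owner OR has_access(user_id, permission, access_control).
    readable = Knowledges.get_knowledge_bases_by_user_id("user-id-2", permission="read")
    can_write = kb is not None and has_access("user-id-2", "write", kb.access_control)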
backend/open_webui/apps/webui/models/models.py CHANGED
@@ -4,8 +4,19 @@ from typing import Optional
4
 
5
  from open_webui.apps.webui.internal.db import Base, JSONField, get_db
6
  from open_webui.env import SRC_LOG_LEVELS
 
 
 
 
7
  from pydantic import BaseModel, ConfigDict
8
- from sqlalchemy import BigInteger, Column, Text
 
 
 
 
 
 
 
9
 
10
  log = logging.getLogger(__name__)
11
  log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -67,6 +78,25 @@ class Model(Base):
67
  Holds a JSON encoded blob of metadata, see `ModelMeta`.
68
  """
69
 
 
 
 
70
  updated_at = Column(BigInteger)
71
  created_at = Column(BigInteger)
72
 
@@ -80,6 +110,9 @@ class ModelModel(BaseModel):
80
  params: ModelParams
81
  meta: ModelMeta
82
 
 
 
 
83
  updated_at: int # timestamp in epoch
84
  created_at: int # timestamp in epoch
85
 
@@ -91,12 +124,12 @@ class ModelModel(BaseModel):
91
  ####################
92
 
93
 
94
- class ModelResponse(BaseModel):
95
- id: str
96
- name: str
97
- meta: ModelMeta
98
- updated_at: int # timestamp in epoch
99
- created_at: int # timestamp in epoch
100
 
101
 
102
  class ModelForm(BaseModel):
@@ -105,6 +138,8 @@ class ModelForm(BaseModel):
105
  name: str
106
  meta: ModelMeta
107
  params: ModelParams
 
 
108
 
109
 
110
  class ModelsTable:
@@ -138,6 +173,39 @@ class ModelsTable:
138
  with get_db() as db:
139
  return [ModelModel.model_validate(model) for model in db.query(Model).all()]
140
 
 
 
 
141
  def get_model_by_id(self, id: str) -> Optional[ModelModel]:
142
  try:
143
  with get_db() as db:
@@ -146,6 +214,23 @@ class ModelsTable:
146
  except Exception:
147
  return None
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
150
  try:
151
  with get_db() as db:
@@ -153,7 +238,7 @@ class ModelsTable:
153
  result = (
154
  db.query(Model)
155
  .filter_by(id=id)
156
- .update(model.model_dump(exclude={"id"}, exclude_none=True))
157
  )
158
  db.commit()
159
 
@@ -175,5 +260,15 @@ class ModelsTable:
175
  except Exception:
176
  return False
177
 
 
 
 
 
 
 
 
 
 
 
178
 
179
  Models = ModelsTable()
 
4
 
5
  from open_webui.apps.webui.internal.db import Base, JSONField, get_db
6
  from open_webui.env import SRC_LOG_LEVELS
7
+
8
+ from open_webui.apps.webui.models.users import Users, UserResponse
9
+
10
+
11
  from pydantic import BaseModel, ConfigDict
12
+
13
+ from sqlalchemy import or_, and_, func
14
+ from sqlalchemy.dialects import postgresql, sqlite
15
+ from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
16
+
17
+
18
+ from open_webui.utils.access_control import has_access
19
+
20
 
21
  log = logging.getLogger(__name__)
22
  log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
78
  Holds a JSON encoded blob of metadata, see `ModelMeta`.
79
  """
80
 
81
+ access_control = Column(JSON, nullable=True) # Controls data access levels.
82
+ # Defines access control rules for this entry.
83
+ # - `None`: Public access, available to all users with the "user" role.
84
+ # - `{}`: Private access, restricted exclusively to the owner.
85
+ # - Custom permissions: Specific access control for reading and writing;
86
+ # Can specify group or user-level restrictions:
87
+ # {
88
+ # "read": {
89
+ # "group_ids": ["group_id1", "group_id2"],
90
+ # "user_ids": ["user_id1", "user_id2"]
91
+ # },
92
+ # "write": {
93
+ # "group_ids": ["group_id1", "group_id2"],
94
+ # "user_ids": ["user_id1", "user_id2"]
95
+ # }
96
+ # }
97
+
98
+ is_active = Column(Boolean, default=True)
99
+
100
  updated_at = Column(BigInteger)
101
  created_at = Column(BigInteger)
102
 
 
110
  params: ModelParams
111
  meta: ModelMeta
112
 
113
+ access_control: Optional[dict] = None
114
+
115
+ is_active: bool
116
  updated_at: int # timestamp in epoch
117
  created_at: int # timestamp in epoch
118
 
 
124
  ####################
125
 
126
 
127
+ class ModelUserResponse(ModelModel):
128
+ user: Optional[UserResponse] = None
129
+
130
+
131
+ class ModelResponse(ModelModel):
132
+ pass
133
 
134
 
135
  class ModelForm(BaseModel):
 
138
  name: str
139
  meta: ModelMeta
140
  params: ModelParams
141
+ access_control: Optional[dict] = None
142
+ is_active: bool = True
143
 
144
 
145
  class ModelsTable:
 
173
  with get_db() as db:
174
  return [ModelModel.model_validate(model) for model in db.query(Model).all()]
175
 
176
+ def get_models(self) -> list[ModelUserResponse]:
177
+ with get_db() as db:
178
+ models = []
179
+ for model in db.query(Model).filter(Model.base_model_id != None).all():
180
+ user = Users.get_user_by_id(model.user_id)
181
+ models.append(
182
+ ModelUserResponse.model_validate(
183
+ {
184
+ **ModelModel.model_validate(model).model_dump(),
185
+ "user": user.model_dump() if user else None,
186
+ }
187
+ )
188
+ )
189
+ return models
190
+
191
+ def get_base_models(self) -> list[ModelModel]:
192
+ with get_db() as db:
193
+ return [
194
+ ModelModel.model_validate(model)
195
+ for model in db.query(Model).filter(Model.base_model_id == None).all()
196
+ ]
197
+
198
+ def get_models_by_user_id(
199
+ self, user_id: str, permission: str = "write"
200
+ ) -> list[ModelUserResponse]:
201
+ models = self.get_models()
202
+ return [
203
+ model
204
+ for model in models
205
+ if model.user_id == user_id
206
+ or has_access(user_id, permission, model.access_control)
207
+ ]
208
+
209
  def get_model_by_id(self, id: str) -> Optional[ModelModel]:
210
  try:
211
  with get_db() as db:
 
214
  except Exception:
215
  return None
216
 
217
+ def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
218
+ with get_db() as db:
219
+ try:
220
+ is_active = db.query(Model).filter_by(id=id).first().is_active
221
+
222
+ db.query(Model).filter_by(id=id).update(
223
+ {
224
+ "is_active": not is_active,
225
+ "updated_at": int(time.time()),
226
+ }
227
+ )
228
+ db.commit()
229
+
230
+ return self.get_model_by_id(id)
231
+ except Exception:
232
+ return None
233
+
234
  def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
235
  try:
236
  with get_db() as db:
 
238
  result = (
239
  db.query(Model)
240
  .filter_by(id=id)
241
+ .update(model.model_dump(exclude={"id"}))
242
  )
243
  db.commit()
244
 
 
260
  except Exception:
261
  return False
262
 
263
+ def delete_all_models(self) -> bool:
264
+ try:
265
+ with get_db() as db:
266
+ db.query(Model).delete()
267
+ db.commit()
268
+
269
+ return True
270
+ except Exception:
271
+ return False
272
+
273
 
274
  Models = ModelsTable()
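A sketch of the workspace-model helpers added above; it assumes a model row with the given id already exists and only exercises methods defined in this file.

    from open_webui.apps.webui.models.models import Models

    # Base models (no base_model_id) vs. workspace models derived from them.
    base_models = Models.get_base_models()
    workspace_models = Models.get_models()  # each entry carries its owning user

    # Filter by ownership or access_control, mirroring the knowledge helpers.
    visible = Models.get_models_by_user_id("user-id-1", permission="read")

    # Flip is_active without rewriting the rest of the record.
    toggled = Models.toggle_model_by_id("my-custom-model")  # id is a placeholder
    if toggled:
        print(toggled.is_active)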
backend/open_webui/apps/webui/models/prompts.py CHANGED
@@ -2,8 +2,12 @@ import time
2
  from typing import Optional
3
 
4
  from open_webui.apps.webui.internal.db import Base, get_db
 
 
5
  from pydantic import BaseModel, ConfigDict
6
- from sqlalchemy import BigInteger, Column, String, Text
 
 
7
 
8
  ####################
9
  # Prompts DB Schema
@@ -19,6 +23,23 @@ class Prompt(Base):
19
  content = Column(Text)
20
  timestamp = Column(BigInteger)
21
 
 
 
 
 
22
 
23
  class PromptModel(BaseModel):
24
  command: str
@@ -27,6 +48,7 @@ class PromptModel(BaseModel):
27
  content: str
28
  timestamp: int # timestamp in epoch
29
 
 
30
  model_config = ConfigDict(from_attributes=True)
31
 
32
 
@@ -35,10 +57,15 @@ class PromptModel(BaseModel):
35
  ####################
36
 
37
 
 
 
 
 
38
  class PromptForm(BaseModel):
39
  command: str
40
  title: str
41
  content: str
 
42
 
43
 
44
  class PromptsTable:
@@ -48,16 +75,14 @@ class PromptsTable:
48
  prompt = PromptModel(
49
  **{
50
  "user_id": user_id,
51
- "command": form_data.command,
52
- "title": form_data.title,
53
- "content": form_data.content,
54
  "timestamp": int(time.time()),
55
  }
56
  )
57
 
58
  try:
59
  with get_db() as db:
60
- result = Prompt(**prompt.dict())
61
  db.add(result)
62
  db.commit()
63
  db.refresh(result)
@@ -76,11 +101,34 @@ class PromptsTable:
76
  except Exception:
77
  return None
78
 
79
- def get_prompts(self) -> list[PromptModel]:
80
  with get_db() as db:
81
- return [
82
- PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
83
- ]
 
 
 
 
 
 
84
 
85
  def update_prompt_by_command(
86
  self, command: str, form_data: PromptForm
@@ -90,6 +138,7 @@ class PromptsTable:
90
  prompt = db.query(Prompt).filter_by(command=command).first()
91
  prompt.title = form_data.title
92
  prompt.content = form_data.content
 
93
  prompt.timestamp = int(time.time())
94
  db.commit()
95
  return PromptModel.model_validate(prompt)
 
2
  from typing import Optional
3
 
4
  from open_webui.apps.webui.internal.db import Base, get_db
5
+ from open_webui.apps.webui.models.users import Users, UserResponse
6
+
7
  from pydantic import BaseModel, ConfigDict
8
+ from sqlalchemy import BigInteger, Column, String, Text, JSON
9
+
10
+ from open_webui.utils.access_control import has_access
11
 
12
  ####################
13
  # Prompts DB Schema
 
23
  content = Column(Text)
24
  timestamp = Column(BigInteger)
25
 
26
+ access_control = Column(JSON, nullable=True) # Controls data access levels.
27
+ # Defines access control rules for this entry.
28
+ # - `None`: Public access, available to all users with the "user" role.
29
+ # - `{}`: Private access, restricted exclusively to the owner.
30
+ # - Custom permissions: Specific access control for reading and writing;
31
+ # Can specify group or user-level restrictions:
32
+ # {
33
+ # "read": {
34
+ # "group_ids": ["group_id1", "group_id2"],
35
+ # "user_ids": ["user_id1", "user_id2"]
36
+ # },
37
+ # "write": {
38
+ # "group_ids": ["group_id1", "group_id2"],
39
+ # "user_ids": ["user_id1", "user_id2"]
40
+ # }
41
+ # }
42
+
43
 
44
  class PromptModel(BaseModel):
45
  command: str
 
48
  content: str
49
  timestamp: int # timestamp in epoch
50
 
51
+ access_control: Optional[dict] = None
52
  model_config = ConfigDict(from_attributes=True)
53
 
54
 
 
57
  ####################
58
 
59
 
60
+ class PromptUserResponse(PromptModel):
61
+ user: Optional[UserResponse] = None
62
+
63
+
64
  class PromptForm(BaseModel):
65
  command: str
66
  title: str
67
  content: str
68
+ access_control: Optional[dict] = None
69
 
70
 
71
  class PromptsTable:
 
75
  prompt = PromptModel(
76
  **{
77
  "user_id": user_id,
78
+ **form_data.model_dump(),
 
 
79
  "timestamp": int(time.time()),
80
  }
81
  )
82
 
83
  try:
84
  with get_db() as db:
85
+ result = Prompt(**prompt.model_dump())
86
  db.add(result)
87
  db.commit()
88
  db.refresh(result)
 
101
  except Exception:
102
  return None
103
 
104
+ def get_prompts(self) -> list[PromptUserResponse]:
105
  with get_db() as db:
106
+ prompts = []
107
+
108
+ for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all():
109
+ user = Users.get_user_by_id(prompt.user_id)
110
+ prompts.append(
111
+ PromptUserResponse.model_validate(
112
+ {
113
+ **PromptModel.model_validate(prompt).model_dump(),
114
+ "user": user.model_dump() if user else None,
115
+ }
116
+ )
117
+ )
118
+
119
+ return prompts
120
+
121
+ def get_prompts_by_user_id(
122
+ self, user_id: str, permission: str = "write"
123
+ ) -> list[PromptUserResponse]:
124
+ prompts = self.get_prompts()
125
+
126
+ return [
127
+ prompt
128
+ for prompt in prompts
129
+ if prompt.user_id == user_id
130
+ or has_access(user_id, permission, prompt.access_control)
131
+ ]
132
 
133
  def update_prompt_by_command(
134
  self, command: str, form_data: PromptForm
 
138
  prompt = db.query(Prompt).filter_by(command=command).first()
139
  prompt.title = form_data.title
140
  prompt.content = form_data.content
141
+ prompt.access_control = form_data.access_control
142
  prompt.timestamp = int(time.time())
143
  db.commit()
144
  return PromptModel.model_validate(prompt)
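The prompt helpers follow the same pattern; a brief sketch, assuming the module-level Prompts table instance that the other model files also export, with a placeholder command, showing that access_control now travels through PromptForm on update and is used when listing.

    from open_webui.apps.webui.models.prompts import Prompts, PromptForm

    # Updating a prompt now also persists its access_control block.
    form = PromptForm(
        command="/summarize",            # placeholder command
        title="Summarize",
        content="Summarize the following text: {{CLIPBOARD}}",
        access_control={},               # {} keeps the prompt private to its owner
    )
    Prompts.update_prompt_by_command("/summarize", form)

    # Listing for a user returns owned prompts plus those shared via access_control.
    shared = Prompts.get_prompts_by_user_id("other-user-id", permission="read")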
backend/open_webui/apps/webui/models/tools.py CHANGED
@@ -3,10 +3,13 @@ import time
3
  from typing import Optional
4
 
5
  from open_webui.apps.webui.internal.db import Base, JSONField, get_db
6
- from open_webui.apps.webui.models.users import Users
7
  from open_webui.env import SRC_LOG_LEVELS
8
  from pydantic import BaseModel, ConfigDict
9
- from sqlalchemy import BigInteger, Column, String, Text
 
 
 
10
 
11
  log = logging.getLogger(__name__)
12
  log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -26,6 +29,24 @@ class Tool(Base):
26
  specs = Column(JSONField)
27
  meta = Column(JSONField)
28
  valves = Column(JSONField)
 
 
 
 
29
  updated_at = Column(BigInteger)
30
  created_at = Column(BigInteger)
31
 
@@ -42,6 +63,8 @@ class ToolModel(BaseModel):
42
  content: str
43
  specs: list[dict]
44
  meta: ToolMeta
 
 
45
  updated_at: int # timestamp in epoch
46
  created_at: int # timestamp in epoch
47
 
@@ -58,15 +81,21 @@ class ToolResponse(BaseModel):
58
  user_id: str
59
  name: str
60
  meta: ToolMeta
 
61
  updated_at: int # timestamp in epoch
62
  created_at: int # timestamp in epoch
63
 
64
 
 
 
 
 
65
  class ToolForm(BaseModel):
66
  id: str
67
  name: str
68
  content: str
69
  meta: ToolMeta
 
70
 
71
 
72
  class ToolValves(BaseModel):
@@ -109,9 +138,32 @@ class ToolsTable:
109
  except Exception:
110
  return None
111
 
112
- def get_tools(self) -> list[ToolModel]:
113
  with get_db() as db:
114
- return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
 
 
 
 
 
115
 
116
  def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
117
  try:
 
3
  from typing import Optional
4
 
5
  from open_webui.apps.webui.internal.db import Base, JSONField, get_db
6
+ from open_webui.apps.webui.models.users import Users, UserResponse
7
  from open_webui.env import SRC_LOG_LEVELS
8
  from pydantic import BaseModel, ConfigDict
9
+ from sqlalchemy import BigInteger, Column, String, Text, JSON
10
+
11
+ from open_webui.utils.access_control import has_access
12
+
13
 
14
  log = logging.getLogger(__name__)
15
  log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
29
  specs = Column(JSONField)
30
  meta = Column(JSONField)
31
  valves = Column(JSONField)
32
+
33
+ access_control = Column(JSON, nullable=True) # Controls data access levels.
34
+ # Defines access control rules for this entry.
35
+ # - `None`: Public access, available to all users with the "user" role.
36
+ # - `{}`: Private access, restricted exclusively to the owner.
37
+ # - Custom permissions: Specific access control for reading and writing;
38
+ # Can specify group or user-level restrictions:
39
+ # {
40
+ # "read": {
41
+ # "group_ids": ["group_id1", "group_id2"],
42
+ # "user_ids": ["user_id1", "user_id2"]
43
+ # },
44
+ # "write": {
45
+ # "group_ids": ["group_id1", "group_id2"],
46
+ # "user_ids": ["user_id1", "user_id2"]
47
+ # }
48
+ # }
49
+
50
  updated_at = Column(BigInteger)
51
  created_at = Column(BigInteger)
52
 
 
63
  content: str
64
  specs: list[dict]
65
  meta: ToolMeta
66
+ access_control: Optional[dict] = None
67
+
68
  updated_at: int # timestamp in epoch
69
  created_at: int # timestamp in epoch
70
 
 
81
  user_id: str
82
  name: str
83
  meta: ToolMeta
84
+ access_control: Optional[dict] = None
85
  updated_at: int # timestamp in epoch
86
  created_at: int # timestamp in epoch
87
 
88
 
89
+ class ToolUserResponse(ToolResponse):
90
+ user: Optional[UserResponse] = None
91
+
92
+
93
  class ToolForm(BaseModel):
94
  id: str
95
  name: str
96
  content: str
97
  meta: ToolMeta
98
+ access_control: Optional[dict] = None
99
 
100
 
101
  class ToolValves(BaseModel):
 
138
  except Exception:
139
  return None
140
 
141
+ def get_tools(self) -> list[ToolUserResponse]:
142
  with get_db() as db:
143
+ tools = []
144
+ for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all():
145
+ user = Users.get_user_by_id(tool.user_id)
146
+ tools.append(
147
+ ToolUserResponse.model_validate(
148
+ {
149
+ **ToolModel.model_validate(tool).model_dump(),
150
+ "user": user.model_dump() if user else None,
151
+ }
152
+ )
153
+ )
154
+ return tools
155
+
156
+ def get_tools_by_user_id(
157
+ self, user_id: str, permission: str = "write"
158
+ ) -> list[ToolUserResponse]:
159
+ tools = self.get_tools()
160
+
161
+ return [
162
+ tool
163
+ for tool in tools
164
+ if tool.user_id == user_id
165
+ or has_access(user_id, permission, tool.access_control)
166
+ ]
167
 
168
  def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
169
  try:
backend/open_webui/apps/webui/models/users.py CHANGED
@@ -62,6 +62,14 @@ class UserModel(BaseModel):
62
  ####################
63
 
64
 
 
 
 
 
 
 
 
 
65
  class UserRoleUpdateForm(BaseModel):
66
  id: str
67
  role: str
 
62
  ####################
63
 
64
 
65
+ class UserResponse(BaseModel):
66
+ id: str
67
+ name: str
68
+ email: str
69
+ role: str
70
+ profile_image_url: str
71
+
72
+
73
  class UserRoleUpdateForm(BaseModel):
74
  id: str
75
  role: str
backend/open_webui/apps/webui/routers/auths.py CHANGED
@@ -2,12 +2,14 @@ import re
2
  import uuid
3
  import time
4
  import datetime
 
5
 
6
  from open_webui.apps.webui.models.auths import (
7
  AddUserForm,
8
  ApiKey,
9
  Auths,
10
  Token,
 
11
  SigninForm,
12
  SigninResponse,
13
  SignupForm,
@@ -16,13 +18,15 @@ from open_webui.apps.webui.models.auths import (
16
  UserResponse,
17
  )
18
  from open_webui.apps.webui.models.users import Users
19
- from open_webui.config import WEBUI_AUTH
20
  from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
21
  from open_webui.env import (
 
22
  WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
23
  WEBUI_AUTH_TRUSTED_NAME_HEADER,
24
  WEBUI_SESSION_COOKIE_SAME_SITE,
25
  WEBUI_SESSION_COOKIE_SECURE,
 
26
  )
27
  from fastapi import APIRouter, Depends, HTTPException, Request, status
28
  from fastapi.responses import Response
@@ -37,10 +41,19 @@ from open_webui.utils.utils import (
37
  get_password_hash,
38
  )
39
  from open_webui.utils.webhook import post_webhook
40
- from typing import Optional
 
 
 
 
 
 
41
 
42
  router = APIRouter()
43
 
 
 
 
44
  ############################
45
  # GetSessionUser
46
  ############################
@@ -48,6 +61,7 @@ router = APIRouter()
48
 
49
  class SessionUserResponse(Token, UserResponse):
50
  expires_at: Optional[int] = None
 
51
 
52
 
53
  @router.get("/", response_model=SessionUserResponse)
@@ -80,6 +94,10 @@ async def get_session_user(
80
  secure=WEBUI_SESSION_COOKIE_SECURE,
81
  )
82
 
 
 
 
 
83
  return {
84
  "token": token,
85
  "token_type": "Bearer",
@@ -89,6 +107,7 @@ async def get_session_user(
89
  "name": user.name,
90
  "role": user.role,
91
  "profile_image_url": user.profile_image_url,
 
92
  }
93
 
94
 
@@ -137,6 +156,140 @@ async def update_password(
137
  raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
138
 
139
 
 
 
 
 
140
  ############################
141
  # SignIn
142
  ############################
@@ -211,6 +364,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
211
  secure=WEBUI_SESSION_COOKIE_SECURE,
212
  )
213
 
 
 
 
 
214
  return {
215
  "token": token,
216
  "token_type": "Bearer",
@@ -220,6 +377,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
220
  "name": user.name,
221
  "role": user.role,
222
  "profile_image_url": user.profile_image_url,
 
223
  }
224
  else:
225
  raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
@@ -260,6 +418,11 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
260
  if Users.get_num_users() == 0
261
  else request.app.state.config.DEFAULT_USER_ROLE
262
  )
 
 
 
 
 
263
  hashed = get_password_hash(form_data.password)
264
  user = Auths.insert_new_auth(
265
  form_data.email.lower(),
@@ -307,6 +470,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
307
  },
308
  )
309
 
 
 
 
 
310
  return {
311
  "token": token,
312
  "token_type": "Bearer",
@@ -316,6 +483,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
316
  "name": user.name,
317
  "role": user.role,
318
  "profile_image_url": user.profile_image_url,
 
319
  }
320
  else:
321
  raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
@@ -413,6 +581,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
413
  return {
414
  "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
415
  "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
 
416
  "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
417
  "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
418
  "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
@@ -423,6 +592,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
423
  class AdminConfig(BaseModel):
424
  SHOW_ADMIN_DETAILS: bool
425
  ENABLE_SIGNUP: bool
 
426
  DEFAULT_USER_ROLE: str
427
  JWT_EXPIRES_IN: str
428
  ENABLE_COMMUNITY_SHARING: bool
@@ -435,6 +605,7 @@ async def update_admin_config(
435
  ):
436
  request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
437
  request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
 
438
 
439
  if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
440
  request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
@@ -453,6 +624,7 @@ async def update_admin_config(
453
  return {
454
  "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
455
  "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
 
456
  "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
457
  "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
458
  "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
@@ -460,6 +632,105 @@ async def update_admin_config(
460
  }
461
 
462
 
 
 
 
463
  ############################
464
  # API Key
465
  ############################
@@ -467,9 +738,16 @@ async def update_admin_config(
467
 
468
  # create api key
469
  @router.post("/api_key", response_model=ApiKey)
470
- async def create_api_key_(user=Depends(get_current_user)):
 
 
 
 
 
 
471
  api_key = create_api_key()
472
  success = Users.update_user_api_key_by_id(user.id, api_key)
 
473
  if success:
474
  return {
475
  "api_key": api_key,
 
2
  import uuid
3
  import time
4
  import datetime
5
+ import logging
6
 
7
  from open_webui.apps.webui.models.auths import (
8
  AddUserForm,
9
  ApiKey,
10
  Auths,
11
  Token,
12
+ LdapForm,
13
  SigninForm,
14
  SigninResponse,
15
  SignupForm,
 
18
  UserResponse,
19
  )
20
  from open_webui.apps.webui.models.users import Users
21
+
22
  from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
23
  from open_webui.env import (
24
+ WEBUI_AUTH,
25
  WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
26
  WEBUI_AUTH_TRUSTED_NAME_HEADER,
27
  WEBUI_SESSION_COOKIE_SAME_SITE,
28
  WEBUI_SESSION_COOKIE_SECURE,
29
+ SRC_LOG_LEVELS,
30
  )
31
  from fastapi import APIRouter, Depends, HTTPException, Request, status
32
  from fastapi.responses import Response
 
41
  get_password_hash,
42
  )
43
  from open_webui.utils.webhook import post_webhook
44
+ from open_webui.utils.access_control import get_permissions
45
+
46
+ from typing import Optional, List
47
+
48
+ from ssl import CERT_REQUIRED, PROTOCOL_TLS
49
+ from ldap3 import Server, Connection, ALL, Tls
50
+ from ldap3.utils.conv import escape_filter_chars
51
 
52
  router = APIRouter()
53
 
54
+ log = logging.getLogger(__name__)
55
+ log.setLevel(SRC_LOG_LEVELS["MAIN"])
56
+
57
  ############################
58
  # GetSessionUser
59
  ############################
 
61
 
62
  class SessionUserResponse(Token, UserResponse):
63
  expires_at: Optional[int] = None
64
+ permissions: Optional[dict] = None
65
 
66
 
67
  @router.get("/", response_model=SessionUserResponse)
 
94
  secure=WEBUI_SESSION_COOKIE_SECURE,
95
  )
96
 
97
+ user_permissions = get_permissions(
98
+ user.id, request.app.state.config.USER_PERMISSIONS
99
+ )
100
+
101
  return {
102
  "token": token,
103
  "token_type": "Bearer",
 
107
  "name": user.name,
108
  "role": user.role,
109
  "profile_image_url": user.profile_image_url,
110
+ "permissions": user_permissions,
111
  }
112
 
113
 
 
156
  raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
157
 
158
 
159
+ ############################
160
+ # LDAP Authentication
161
+ ############################
162
+ @router.post("/ldap", response_model=SigninResponse)
163
+ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
164
+ ENABLE_LDAP = request.app.state.config.ENABLE_LDAP
165
+ LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL
166
+ LDAP_SERVER_HOST = request.app.state.config.LDAP_SERVER_HOST
167
+ LDAP_SERVER_PORT = request.app.state.config.LDAP_SERVER_PORT
168
+ LDAP_ATTRIBUTE_FOR_USERNAME = request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME
169
+ LDAP_SEARCH_BASE = request.app.state.config.LDAP_SEARCH_BASE
170
+ LDAP_SEARCH_FILTERS = request.app.state.config.LDAP_SEARCH_FILTERS
171
+ LDAP_APP_DN = request.app.state.config.LDAP_APP_DN
172
+ LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD
173
+ LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS
174
+ LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE
175
+ LDAP_CIPHERS = (
176
+ request.app.state.config.LDAP_CIPHERS
177
+ if request.app.state.config.LDAP_CIPHERS
178
+ else "ALL"
179
+ )
180
+
181
+ if not ENABLE_LDAP:
182
+ raise HTTPException(400, detail="LDAP authentication is not enabled")
183
+
184
+ try:
185
+ tls = Tls(
186
+ validate=CERT_REQUIRED,
187
+ version=PROTOCOL_TLS,
188
+ ca_certs_file=LDAP_CA_CERT_FILE,
189
+ ciphers=LDAP_CIPHERS,
190
+ )
191
+ except Exception as e:
192
+ log.error(f"An error occurred on TLS: {str(e)}")
193
+ raise HTTPException(400, detail=str(e))
194
+
195
+ try:
196
+ server = Server(
197
+ host=LDAP_SERVER_HOST,
198
+ port=LDAP_SERVER_PORT,
199
+ get_info=ALL,
200
+ use_ssl=LDAP_USE_TLS,
201
+ tls=tls,
202
+ )
203
+ connection_app = Connection(
204
+ server,
205
+ LDAP_APP_DN,
206
+ LDAP_APP_PASSWORD,
207
+ auto_bind="NONE",
208
+ authentication="SIMPLE",
209
+ )
210
+ if not connection_app.bind():
211
+ raise HTTPException(400, detail="Application account bind failed")
212
+
213
+ search_success = connection_app.search(
214
+ search_base=LDAP_SEARCH_BASE,
215
+ search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})",
216
+ attributes=[f"{LDAP_ATTRIBUTE_FOR_USERNAME}", "mail", "cn"],
217
+ )
218
+
219
+ if not search_success:
220
+ raise HTTPException(400, detail="User not found in the LDAP server")
221
+
222
+ entry = connection_app.entries[0]
223
+ username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
224
+ mail = str(entry["mail"])
225
+ cn = str(entry["cn"])
226
+ user_dn = entry.entry_dn
227
+
228
+ if username == form_data.user.lower():
229
+ connection_user = Connection(
230
+ server,
231
+ user_dn,
232
+ form_data.password,
233
+ auto_bind="NONE",
234
+ authentication="SIMPLE",
235
+ )
236
+ if not connection_user.bind():
237
+ raise HTTPException(400, f"Authentication failed for {form_data.user}")
238
+
239
+ user = Users.get_user_by_email(mail)
240
+ if not user:
241
+
242
+ try:
243
+ hashed = get_password_hash(form_data.password)
244
+ user = Auths.insert_new_auth(mail, hashed, cn)
245
+
246
+ if not user:
247
+ raise HTTPException(
248
+ 500, detail=ERROR_MESSAGES.CREATE_USER_ERROR
249
+ )
250
+
251
+ except HTTPException:
252
+ raise
253
+ except Exception as err:
254
+ raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
255
+
256
+ user = Auths.authenticate_user(mail, password=str(form_data.password))
257
+
258
+ if user:
259
+ token = create_token(
260
+ data={"id": user.id},
261
+ expires_delta=parse_duration(
262
+ request.app.state.config.JWT_EXPIRES_IN
263
+ ),
264
+ )
265
+
266
+ # Set the cookie token
267
+ response.set_cookie(
268
+ key="token",
269
+ value=token,
270
+ httponly=True, # Ensures the cookie is not accessible via JavaScript
271
+ )
272
+
273
+ return {
274
+ "token": token,
275
+ "token_type": "Bearer",
276
+ "id": user.id,
277
+ "email": user.email,
278
+ "name": user.name,
279
+ "role": user.role,
280
+ "profile_image_url": user.profile_image_url,
281
+ }
282
+ else:
283
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
284
+ else:
285
+ raise HTTPException(
286
+ 400,
287
+ f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}",
288
+ )
289
+ except Exception as e:
290
+ raise HTTPException(400, detail=str(e))
291
+
292
+
293
  ############################
294
  # SignIn
295
  ############################
 
364
  secure=WEBUI_SESSION_COOKIE_SECURE,
365
  )
366
 
367
+ user_permissions = get_permissions(
368
+ user.id, request.app.state.config.USER_PERMISSIONS
369
+ )
370
+
371
  return {
372
  "token": token,
373
  "token_type": "Bearer",
 
377
  "name": user.name,
378
  "role": user.role,
379
  "profile_image_url": user.profile_image_url,
380
+ "permissions": user_permissions,
381
  }
382
  else:
383
  raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
 
418
  if Users.get_num_users() == 0
419
  else request.app.state.config.DEFAULT_USER_ROLE
420
  )
421
+
422
+ if Users.get_num_users() == 0:
423
+ # Disable signup after the first user is created
424
+ request.app.state.config.ENABLE_SIGNUP = False
425
+
426
  hashed = get_password_hash(form_data.password)
427
  user = Auths.insert_new_auth(
428
  form_data.email.lower(),
 
470
  },
471
  )
472
 
473
+ user_permissions = get_permissions(
474
+ user.id, request.app.state.config.USER_PERMISSIONS
475
+ )
476
+
477
  return {
478
  "token": token,
479
  "token_type": "Bearer",
 
483
  "name": user.name,
484
  "role": user.role,
485
  "profile_image_url": user.profile_image_url,
486
+ "permissions": user_permissions,
487
  }
488
  else:
489
  raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
 
581
  return {
582
  "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
583
  "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
584
+ "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
585
  "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
586
  "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
587
  "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
 
592
  class AdminConfig(BaseModel):
593
  SHOW_ADMIN_DETAILS: bool
594
  ENABLE_SIGNUP: bool
595
+ ENABLE_API_KEY: bool
596
  DEFAULT_USER_ROLE: str
597
  JWT_EXPIRES_IN: str
598
  ENABLE_COMMUNITY_SHARING: bool
 
605
  ):
606
  request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
607
  request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
608
+ request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY
609
 
610
  if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
611
  request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
 
624
  return {
625
  "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
626
  "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
627
+ "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
628
  "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
629
  "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
630
  "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
 
632
  }
633
 
634
 
635
+ class LdapServerConfig(BaseModel):
636
+ label: str
637
+ host: str
638
+ port: Optional[int] = None
639
+ attribute_for_username: str = "uid"
640
+ app_dn: str
641
+ app_dn_password: str
642
+ search_base: str
643
+ search_filters: str = ""
644
+ use_tls: bool = True
645
+ certificate_path: Optional[str] = None
646
+ ciphers: Optional[str] = "ALL"
647
+
648
+
649
+ @router.get("/admin/config/ldap/server", response_model=LdapServerConfig)
650
+ async def get_ldap_server(request: Request, user=Depends(get_admin_user)):
651
+ return {
652
+ "label": request.app.state.config.LDAP_SERVER_LABEL,
653
+ "host": request.app.state.config.LDAP_SERVER_HOST,
654
+ "port": request.app.state.config.LDAP_SERVER_PORT,
655
+ "attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME,
656
+ "app_dn": request.app.state.config.LDAP_APP_DN,
657
+ "app_dn_password": request.app.state.config.LDAP_APP_PASSWORD,
658
+ "search_base": request.app.state.config.LDAP_SEARCH_BASE,
659
+ "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
660
+ "use_tls": request.app.state.config.LDAP_USE_TLS,
661
+ "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
662
+ "ciphers": request.app.state.config.LDAP_CIPHERS,
663
+ }
664
+
665
+
666
+ @router.post("/admin/config/ldap/server")
667
+ async def update_ldap_server(
668
+ request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user)
669
+ ):
670
+ required_fields = [
671
+ "label",
672
+ "host",
673
+ "attribute_for_username",
674
+ "app_dn",
675
+ "app_dn_password",
676
+ "search_base",
677
+ ]
678
+ for key in required_fields:
679
+ value = getattr(form_data, key)
680
+ if not value:
681
+ raise HTTPException(400, detail=f"Required field {key} is empty")
682
+
683
+ if form_data.use_tls and not form_data.certificate_path:
684
+ raise HTTPException(
685
+ 400, detail="TLS is enabled but certificate file path is missing"
686
+ )
687
+
688
+ request.app.state.config.LDAP_SERVER_LABEL = form_data.label
689
+ request.app.state.config.LDAP_SERVER_HOST = form_data.host
690
+ request.app.state.config.LDAP_SERVER_PORT = form_data.port
691
+ request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = (
692
+ form_data.attribute_for_username
693
+ )
694
+ request.app.state.config.LDAP_APP_DN = form_data.app_dn
695
+ request.app.state.config.LDAP_APP_PASSWORD = form_data.app_dn_password
696
+ request.app.state.config.LDAP_SEARCH_BASE = form_data.search_base
697
+ request.app.state.config.LDAP_SEARCH_FILTERS = form_data.search_filters
698
+ request.app.state.config.LDAP_USE_TLS = form_data.use_tls
699
+ request.app.state.config.LDAP_CA_CERT_FILE = form_data.certificate_path
700
+ request.app.state.config.LDAP_CIPHERS = form_data.ciphers
701
+
702
+ return {
703
+ "label": request.app.state.config.LDAP_SERVER_LABEL,
704
+ "host": request.app.state.config.LDAP_SERVER_HOST,
705
+ "port": request.app.state.config.LDAP_SERVER_PORT,
706
+ "attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME,
707
+ "app_dn": request.app.state.config.LDAP_APP_DN,
708
+ "app_dn_password": request.app.state.config.LDAP_APP_PASSWORD,
709
+ "search_base": request.app.state.config.LDAP_SEARCH_BASE,
710
+ "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
711
+ "use_tls": request.app.state.config.LDAP_USE_TLS,
712
+ "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
713
+ "ciphers": request.app.state.config.LDAP_CIPHERS,
714
+ }
715
+
716
+
717
+ @router.get("/admin/config/ldap")
718
+ async def get_ldap_config(request: Request, user=Depends(get_admin_user)):
719
+ return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
720
+
721
+
722
+ class LdapConfigForm(BaseModel):
723
+ enable_ldap: Optional[bool] = None
724
+
725
+
726
+ @router.post("/admin/config/ldap")
727
+ async def update_ldap_config(
728
+ request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user)
729
+ ):
730
+ request.app.state.config.ENABLE_LDAP = form_data.enable_ldap
731
+ return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
732
+
733
+
734
  ############################
735
  # API Key
736
  ############################
 
738
 
739
  # create api key
740
  @router.post("/api_key", response_model=ApiKey)
741
+ async def generate_api_key(request: Request, user=Depends(get_current_user)):
742
+ if not request.app.state.config.ENABLE_API_KEY:
743
+ raise HTTPException(
744
+ status.HTTP_403_FORBIDDEN,
745
+ detail=ERROR_MESSAGES.API_KEY_CREATION_NOT_ALLOWED,
746
+ )
747
+
748
  api_key = create_api_key()
749
  success = Users.update_user_api_key_by_id(user.id, api_key)
750
+
751
  if success:
752
  return {
753
  "api_key": api_key,
backend/open_webui/apps/webui/routers/chats.py CHANGED
@@ -17,7 +17,10 @@ from open_webui.constants import ERROR_MESSAGES
17
  from open_webui.env import SRC_LOG_LEVELS
18
  from fastapi import APIRouter, Depends, HTTPException, Request, status
19
  from pydantic import BaseModel
 
 
20
  from open_webui.utils.utils import get_admin_user, get_verified_user
 
21
 
22
  log = logging.getLogger(__name__)
23
  log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -50,9 +53,10 @@ async def get_session_user_chat_list(
50
 
51
  @router.delete("/", response_model=bool)
52
  async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
53
- if user.role == "user" and not request.app.state.config.USER_PERMISSIONS.get(
54
- "chat", {}
55
- ).get("deletion", {}):
 
56
  raise HTTPException(
57
  status_code=status.HTTP_401_UNAUTHORIZED,
58
  detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
@@ -385,8 +389,8 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
385
 
386
  return result
387
  else:
388
- if not request.app.state.config.USER_PERMISSIONS.get("chat", {}).get(
389
- "deletion", {}
390
  ):
391
  raise HTTPException(
392
  status_code=status.HTTP_401_UNAUTHORIZED,
 
17
  from open_webui.env import SRC_LOG_LEVELS
18
  from fastapi import APIRouter, Depends, HTTPException, Request, status
19
  from pydantic import BaseModel
20
+
21
+
22
  from open_webui.utils.utils import get_admin_user, get_verified_user
23
+ from open_webui.utils.access_control import has_permission
24
 
25
  log = logging.getLogger(__name__)
26
  log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
53
 
54
  @router.delete("/", response_model=bool)
55
  async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
56
+
57
+ if user.role == "user" and not has_permission(
58
+ user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
59
+ ):
60
  raise HTTPException(
61
  status_code=status.HTTP_401_UNAUTHORIZED,
62
  detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
 
389
 
390
  return result
391
  else:
392
+ if not has_permission(
393
+ user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
394
  ):
395
  raise HTTPException(
396
  status_code=status.HTTP_401_UNAUTHORIZED,
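For context, a sketch of how the dotted permission key used above could resolve against a nested permissions dict; the exact shape of USER_PERMISSIONS and the internals of has_permission live in open_webui/utils/access_control.py and are assumed here, not shown in this hunk.

    from open_webui.utils.access_control import has_permission

    # Assumed shape: nested booleans keyed by the dotted path segments.
    default_permissions = {
        "chat": {
            "delete": False,  # users may not delete chats
            "edit": True,
        }
    }

    user_id = "user-id-1"  # placeholder
    if not has_permission(user_id, "chat.delete", default_permissions):
        print("chat deletion is blocked for this user")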
backend/open_webui/apps/webui/routers/groups.py ADDED
@@ -0,0 +1,120 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ from open_webui.apps.webui.models.groups import (
6
+ Groups,
7
+ GroupForm,
8
+ GroupUpdateForm,
9
+ GroupResponse,
10
+ )
11
+
12
+ from open_webui.config import CACHE_DIR
13
+ from open_webui.constants import ERROR_MESSAGES
14
+ from fastapi import APIRouter, Depends, HTTPException, Request, status
15
+ from open_webui.utils.utils import get_admin_user, get_verified_user
16
+
17
+ router = APIRouter()
18
+
19
+ ############################
20
+ # GetGroups
21
+ ############################
22
+
23
+
24
+ @router.get("/", response_model=list[GroupResponse])
25
+ async def get_groups(user=Depends(get_verified_user)):
26
+ if user.role == "admin":
27
+ return Groups.get_groups()
28
+ else:
29
+ return Groups.get_groups_by_member_id(user.id)
30
+
31
+
32
+ ############################
33
+ # CreateNewGroup
34
+ ############################
35
+
36
+
37
+ @router.post("/create", response_model=Optional[GroupResponse])
38
+ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
39
+ try:
40
+ group = Groups.insert_new_group(user.id, form_data)
41
+ if group:
42
+ return group
43
+ else:
44
+ raise HTTPException(
45
+ status_code=status.HTTP_400_BAD_REQUEST,
46
+ detail=ERROR_MESSAGES.DEFAULT("Error creating group"),
47
+ )
48
+ except Exception as e:
49
+ print(e)
50
+ raise HTTPException(
51
+ status_code=status.HTTP_400_BAD_REQUEST,
52
+ detail=ERROR_MESSAGES.DEFAULT(e),
53
+ )
54
+
55
+
56
+ ############################
57
+ # GetGroupById
58
+ ############################
59
+
60
+
61
+ @router.get("/id/{id}", response_model=Optional[GroupResponse])
62
+ async def get_group_by_id(id: str, user=Depends(get_admin_user)):
63
+ group = Groups.get_group_by_id(id)
64
+ if group:
65
+ return group
66
+ else:
67
+ raise HTTPException(
68
+ status_code=status.HTTP_401_UNAUTHORIZED,
69
+ detail=ERROR_MESSAGES.NOT_FOUND,
70
+ )
71
+
72
+
73
+ ############################
74
+ # UpdateGroupById
75
+ ############################
76
+
77
+
78
+ @router.post("/id/{id}/update", response_model=Optional[GroupResponse])
79
+ async def update_group_by_id(
80
+ id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user)
81
+ ):
82
+ try:
83
+ group = Groups.update_group_by_id(id, form_data)
84
+ if group:
85
+ return group
86
+ else:
87
+ raise HTTPException(
88
+ status_code=status.HTTP_400_BAD_REQUEST,
89
+ detail=ERROR_MESSAGES.DEFAULT("Error updating group"),
90
+ )
91
+ except Exception as e:
92
+ print(e)
93
+ raise HTTPException(
94
+ status_code=status.HTTP_400_BAD_REQUEST,
95
+ detail=ERROR_MESSAGES.DEFAULT(e),
96
+ )
97
+
98
+
99
+ ############################
100
+ # DeleteGroupById
101
+ ############################
102
+
103
+
104
+ @router.delete("/id/{id}/delete", response_model=bool)
105
+ async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
106
+ try:
107
+ result = Groups.delete_group_by_id(id)
108
+ if result:
109
+ return result
110
+ else:
111
+ raise HTTPException(
112
+ status_code=status.HTTP_400_BAD_REQUEST,
113
+ detail=ERROR_MESSAGES.DEFAULT("Error deleting group"),
114
+ )
115
+ except Exception as e:
116
+ print(e)
117
+ raise HTTPException(
118
+ status_code=status.HTTP_400_BAD_REQUEST,
119
+ detail=ERROR_MESSAGES.DEFAULT(e),
120
+ )
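A client sketch against the new groups router; the /api/v1/groups base path is an assumption (the diff only shows the router mounted with prefix="/groups"), and the admin token is a placeholder.

    import requests

    BASE = "http://localhost:8080/api/v1/groups"          # assumed external path
    HEADERS = {"Authorization": "Bearer <admin-token>"}   # placeholder token

    # Create, then update membership, then list.
    created = requests.post(
        f"{BASE}/create",
        json={"name": "engineering", "description": "Engineering team"},
        headers=HEADERS,
    ).json()

    requests.post(
        f"{BASE}/id/{created['id']}/update",
        json={
            "name": "engineering",
            "description": "Engineering team",
            "user_ids": ["user-1", "user-2"],
        },
        headers=HEADERS,
    )

    print(requests.get(f"{BASE}/", headers=HEADERS).json())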
backend/open_webui/apps/webui/routers/knowledge.py CHANGED
@@ -1,14 +1,14 @@
1
  import json
2
  from typing import Optional, Union
3
  from pydantic import BaseModel
4
- from fastapi import APIRouter, Depends, HTTPException, status
5
  import logging
6
 
7
  from open_webui.apps.webui.models.knowledge import (
8
  Knowledges,
9
- KnowledgeUpdateForm,
10
  KnowledgeForm,
11
  KnowledgeResponse,
 
12
  )
13
  from open_webui.apps.webui.models.files import Files, FileModel
14
  from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
@@ -17,6 +17,9 @@ from open_webui.apps.retrieval.main import process_file, ProcessFileForm
17
 
18
  from open_webui.constants import ERROR_MESSAGES
19
  from open_webui.utils.utils import get_admin_user, get_verified_user
 
 
 
20
  from open_webui.env import SRC_LOG_LEVELS
21
 
22
 
@@ -26,64 +29,98 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
26
  router = APIRouter()
27
 
28
  ############################
29
- # GetKnowledgeItems
30
  ############################
31
 
32
 
33
- @router.get(
34
- "/", response_model=Optional[Union[list[KnowledgeResponse], KnowledgeResponse]]
35
- )
36
- async def get_knowledge_items(
37
- id: Optional[str] = None, user=Depends(get_verified_user)
38
- ):
39
- if id:
40
- knowledge = Knowledges.get_knowledge_by_id(id=id)
41
 
42
- if knowledge:
43
- return knowledge
44
- else:
45
- raise HTTPException(
46
- status_code=status.HTTP_401_UNAUTHORIZED,
47
- detail=ERROR_MESSAGES.NOT_FOUND,
48
- )
49
  else:
50
- knowledge_bases = []
51
-
52
- for knowledge in Knowledges.get_knowledge_items():
 
 
 
 
 
 
53
 
54
- files = []
55
- if knowledge.data:
56
- files = Files.get_file_metadatas_by_ids(
57
- knowledge.data.get("file_ids", [])
 
58
  )
 
 
 
 
 
 
59
 
60
- # Check if all files exist
61
- if len(files) != len(knowledge.data.get("file_ids", [])):
62
- missing_files = list(
63
- set(knowledge.data.get("file_ids", []))
64
- - set([file.id for file in files])
65
  )
66
- if missing_files:
67
- data = knowledge.data or {}
68
- file_ids = data.get("file_ids", [])
69
 
70
- for missing_file in missing_files:
71
- file_ids.remove(missing_file)
 
 
 
 
72
 
73
- data["file_ids"] = file_ids
74
- Knowledges.update_knowledge_by_id(
75
- id=knowledge.id, form_data=KnowledgeUpdateForm(data=data)
76
- )
77
 
78
- files = Files.get_file_metadatas_by_ids(file_ids)
79
 
80
- knowledge_bases.append(
81
- KnowledgeResponse(
82
- **knowledge.model_dump(),
83
- files=files,
84
- )
 
 
 
 
 
 
 
 
 
 
85
  )
86
- return knowledge_bases
 
 
 
 
 
87
 
88
 
89
  ############################
@@ -92,7 +129,17 @@ async def get_knowledge_items(
92
 
93
 
94
  @router.post("/create", response_model=Optional[KnowledgeResponse])
95
- async def create_new_knowledge(form_data: KnowledgeForm, user=Depends(get_admin_user)):
 
 
 
 
 
 
 
 
 
 
96
  knowledge = Knowledges.insert_new_knowledge(user.id, form_data)
97
 
98
  if knowledge:
@@ -118,13 +165,20 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
118
  knowledge = Knowledges.get_knowledge_by_id(id=id)
119
 
120
  if knowledge:
121
- file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
122
- files = Files.get_files_by_ids(file_ids)
123
 
124
- return KnowledgeFilesResponse(
125
- **knowledge.model_dump(),
126
- files=files,
127
- )
 
 
 
 
 
 
 
 
 
128
  else:
129
  raise HTTPException(
130
  status_code=status.HTTP_401_UNAUTHORIZED,
@@ -140,11 +194,23 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
140
  @router.post("/{id}/update", response_model=Optional[KnowledgeFilesResponse])
141
  async def update_knowledge_by_id(
142
  id: str,
143
- form_data: KnowledgeUpdateForm,
144
- user=Depends(get_admin_user),
145
  ):
146
- knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
 
 
 
 
 
 
 
 
 
 
 
147
 
 
148
  if knowledge:
149
  file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
150
  files = Files.get_files_by_ids(file_ids)
@@ -173,9 +239,22 @@ class KnowledgeFileIdForm(BaseModel):
173
  def add_file_to_knowledge_by_id(
174
  id: str,
175
  form_data: KnowledgeFileIdForm,
176
- user=Depends(get_admin_user),
177
  ):
178
  knowledge = Knowledges.get_knowledge_by_id(id=id)
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  file = Files.get_file_by_id(form_data.file_id)
180
  if not file:
181
  raise HTTPException(
@@ -206,9 +285,7 @@ def add_file_to_knowledge_by_id(
206
  file_ids.append(form_data.file_id)
207
  data["file_ids"] = file_ids
208
 
209
- knowledge = Knowledges.update_knowledge_by_id(
210
- id=id, form_data=KnowledgeUpdateForm(data=data)
211
- )
212
 
213
  if knowledge:
214
  files = Files.get_files_by_ids(file_ids)
@@ -238,9 +315,21 @@ def add_file_to_knowledge_by_id(
238
  def update_file_from_knowledge_by_id(
239
  id: str,
240
  form_data: KnowledgeFileIdForm,
241
- user=Depends(get_admin_user),
242
  ):
243
  knowledge = Knowledges.get_knowledge_by_id(id=id)
 
 
 
 
 
 
 
 
 
 
 
 
244
  file = Files.get_file_by_id(form_data.file_id)
245
  if not file:
246
  raise HTTPException(
@@ -288,9 +377,21 @@ def update_file_from_knowledge_by_id(
288
  def remove_file_from_knowledge_by_id(
289
  id: str,
290
  form_data: KnowledgeFileIdForm,
291
- user=Depends(get_admin_user),
292
  ):
293
  knowledge = Knowledges.get_knowledge_by_id(id=id)
 
 
 
 
 
 
 
 
 
 
 
 
294
  file = Files.get_file_by_id(form_data.file_id)
295
  if not file:
296
  raise HTTPException(
@@ -318,9 +419,7 @@ def remove_file_from_knowledge_by_id(
318
  file_ids.remove(form_data.file_id)
319
  data["file_ids"] = file_ids
320
 
321
- knowledge = Knowledges.update_knowledge_by_id(
322
- id=id, form_data=KnowledgeUpdateForm(data=data)
323
- )
324
 
325
  if knowledge:
326
  files = Files.get_files_by_ids(file_ids)
@@ -347,35 +446,60 @@ def remove_file_from_knowledge_by_id(
347
 
348
 
349
  ############################
350
- # ResetKnowledgeById
351
  ############################
352
 
353
 
354
- @router.post("/{id}/reset", response_model=Optional[KnowledgeResponse])
355
- async def reset_knowledge_by_id(id: str, user=Depends(get_admin_user)):
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  try:
357
  VECTOR_DB_CLIENT.delete_collection(collection_name=id)
358
  except Exception as e:
359
  log.debug(e)
360
  pass
361
-
362
- knowledge = Knowledges.update_knowledge_by_id(
363
- id=id, form_data=KnowledgeUpdateForm(data={"file_ids": []})
364
- )
365
- return knowledge
366
 
367
 
368
  ############################
369
- # DeleteKnowledgeById
370
  ############################
371
 
372
 
373
- @router.delete("/{id}/delete", response_model=bool)
374
- async def delete_knowledge_by_id(id: str, user=Depends(get_admin_user)):
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  try:
376
  VECTOR_DB_CLIENT.delete_collection(collection_name=id)
377
  except Exception as e:
378
  log.debug(e)
379
  pass
380
- result = Knowledges.delete_knowledge_by_id(id=id)
381
- return result
 
 
 
1
  import json
2
  from typing import Optional, Union
3
  from pydantic import BaseModel
4
+ from fastapi import APIRouter, Depends, HTTPException, status, Request
5
  import logging
6
 
7
  from open_webui.apps.webui.models.knowledge import (
8
  Knowledges,
 
9
  KnowledgeForm,
10
  KnowledgeResponse,
11
+ KnowledgeUserResponse,
12
  )
13
  from open_webui.apps.webui.models.files import Files, FileModel
14
  from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
 
17
 
18
  from open_webui.constants import ERROR_MESSAGES
19
  from open_webui.utils.utils import get_admin_user, get_verified_user
20
+ from open_webui.utils.access_control import has_access, has_permission
21
+
22
+
23
  from open_webui.env import SRC_LOG_LEVELS
24
 
25
 
 
29
  router = APIRouter()
30
 
31
  ############################
32
+ # GetKnowledgeBases
33
  ############################
34
 
35
 
36
+ @router.get("/", response_model=list[KnowledgeUserResponse])
37
+ async def get_knowledge(user=Depends(get_verified_user)):
38
+ knowledge_bases = []
 
 
 
 
 
39
 
40
+ if user.role == "admin":
41
+ knowledge_bases = Knowledges.get_knowledge_bases()
 
 
 
 
 
42
  else:
43
+ knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "read")
44
+
45
+ # Get files for each knowledge base
46
+ for knowledge_base in knowledge_bases:
47
+ files = []
48
+ if knowledge_base.data:
49
+ files = Files.get_file_metadatas_by_ids(
50
+ knowledge_base.data.get("file_ids", [])
51
+ )
52
 
53
+ # Check if all files exist
54
+ if len(files) != len(knowledge_base.data.get("file_ids", [])):
55
+ missing_files = list(
56
+ set(knowledge_base.data.get("file_ids", []))
57
+ - set([file.id for file in files])
58
  )
59
+ if missing_files:
60
+ data = knowledge_base.data or {}
61
+ file_ids = data.get("file_ids", [])
62
+
63
+ for missing_file in missing_files:
64
+ file_ids.remove(missing_file)
65
 
66
+ data["file_ids"] = file_ids
67
+ Knowledges.update_knowledge_data_by_id(
68
+ id=knowledge_base.id, data=data
 
 
69
  )
 
 
 
70
 
71
+ files = Files.get_file_metadatas_by_ids(file_ids)
72
+
73
+ knowledge_base = KnowledgeResponse(
74
+ **knowledge_base.model_dump(),
75
+ files=files,
76
+ )
77
 
78
+ return knowledge_bases
 
 
 
79
 
 
80
 
81
+ @router.get("/list", response_model=list[KnowledgeUserResponse])
82
+ async def get_knowledge_list(user=Depends(get_verified_user)):
83
+ knowledge_bases = []
84
+
85
+ if user.role == "admin":
86
+ knowledge_bases = Knowledges.get_knowledge_bases()
87
+ else:
88
+ knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "write")
89
+
90
+ # Get files for each knowledge base
91
+ for knowledge_base in knowledge_bases:
92
+ files = []
93
+ if knowledge_base.data:
94
+ files = Files.get_file_metadatas_by_ids(
95
+ knowledge_base.data.get("file_ids", [])
96
  )
97
+
98
+ # Check if all files exist
99
+ if len(files) != len(knowledge_base.data.get("file_ids", [])):
100
+ missing_files = list(
101
+ set(knowledge_base.data.get("file_ids", []))
102
+ - set([file.id for file in files])
103
+ )
104
+ if missing_files:
105
+ data = knowledge_base.data or {}
106
+ file_ids = data.get("file_ids", [])
107
+
108
+ for missing_file in missing_files:
109
+ file_ids.remove(missing_file)
110
+
111
+ data["file_ids"] = file_ids
112
+ Knowledges.update_knowledge_data_by_id(
113
+ id=knowledge_base.id, data=data
114
+ )
115
+
116
+ files = Files.get_file_metadatas_by_ids(file_ids)
117
+
118
+ knowledge_base = KnowledgeResponse(
119
+ **knowledge_base.model_dump(),
120
+ files=files,
121
+ )
122
+
123
+ return knowledge_bases
124
 
125
 
126
  ############################
 
129
 
130
 
131
  @router.post("/create", response_model=Optional[KnowledgeResponse])
132
+ async def create_new_knowledge(
133
+ request: Request, form_data: KnowledgeForm, user=Depends(get_verified_user)
134
+ ):
135
+ if user.role != "admin" and not has_permission(
136
+ user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS
137
+ ):
138
+ raise HTTPException(
139
+ status_code=status.HTTP_401_UNAUTHORIZED,
140
+ detail=ERROR_MESSAGES.UNAUTHORIZED,
141
+ )
142
+
143
  knowledge = Knowledges.insert_new_knowledge(user.id, form_data)
144
 
145
  if knowledge:
 
165
  knowledge = Knowledges.get_knowledge_by_id(id=id)
166
 
167
  if knowledge:
 
 
168
 
169
+ if (
170
+ user.role == "admin"
171
+ or knowledge.user_id == user.id
172
+ or has_access(user.id, "read", knowledge.access_control)
173
+ ):
174
+
175
+ file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
176
+ files = Files.get_files_by_ids(file_ids)
177
+
178
+ return KnowledgeFilesResponse(
179
+ **knowledge.model_dump(),
180
+ files=files,
181
+ )
182
  else:
183
  raise HTTPException(
184
  status_code=status.HTTP_401_UNAUTHORIZED,
 
194
  @router.post("/{id}/update", response_model=Optional[KnowledgeFilesResponse])
195
  async def update_knowledge_by_id(
196
  id: str,
197
+ form_data: KnowledgeForm,
198
+ user=Depends(get_verified_user),
199
  ):
200
+ knowledge = Knowledges.get_knowledge_by_id(id=id)
201
+ if not knowledge:
202
+ raise HTTPException(
203
+ status_code=status.HTTP_400_BAD_REQUEST,
204
+ detail=ERROR_MESSAGES.NOT_FOUND,
205
+ )
206
+
207
+ if knowledge.user_id != user.id and user.role != "admin":
208
+ raise HTTPException(
209
+ status_code=status.HTTP_400_BAD_REQUEST,
210
+ detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
211
+ )
212
 
213
+ knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
214
  if knowledge:
215
  file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
216
  files = Files.get_files_by_ids(file_ids)
 
239
  def add_file_to_knowledge_by_id(
240
  id: str,
241
  form_data: KnowledgeFileIdForm,
242
+ user=Depends(get_verified_user),
243
  ):
244
  knowledge = Knowledges.get_knowledge_by_id(id=id)
245
+
246
+ if not knowledge:
247
+ raise HTTPException(
248
+ status_code=status.HTTP_400_BAD_REQUEST,
249
+ detail=ERROR_MESSAGES.NOT_FOUND,
250
+ )
251
+
252
+ if knowledge.user_id != user.id and user.role != "admin":
253
+ raise HTTPException(
254
+ status_code=status.HTTP_400_BAD_REQUEST,
255
+ detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
256
+ )
257
+
258
  file = Files.get_file_by_id(form_data.file_id)
259
  if not file:
260
  raise HTTPException(
 
285
  file_ids.append(form_data.file_id)
286
  data["file_ids"] = file_ids
287
 
288
+ knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
 
 
289
 
290
  if knowledge:
291
  files = Files.get_files_by_ids(file_ids)
 
315
  def update_file_from_knowledge_by_id(
316
  id: str,
317
  form_data: KnowledgeFileIdForm,
318
+ user=Depends(get_verified_user),
319
  ):
320
  knowledge = Knowledges.get_knowledge_by_id(id=id)
321
+ if not knowledge:
322
+ raise HTTPException(
323
+ status_code=status.HTTP_400_BAD_REQUEST,
324
+ detail=ERROR_MESSAGES.NOT_FOUND,
325
+ )
326
+
327
+ if knowledge.user_id != user.id and user.role != "admin":
328
+ raise HTTPException(
329
+ status_code=status.HTTP_400_BAD_REQUEST,
330
+ detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
331
+ )
332
+
333
  file = Files.get_file_by_id(form_data.file_id)
334
  if not file:
335
  raise HTTPException(
 
377
  def remove_file_from_knowledge_by_id(
378
  id: str,
379
  form_data: KnowledgeFileIdForm,
380
+ user=Depends(get_verified_user),
381
  ):
382
  knowledge = Knowledges.get_knowledge_by_id(id=id)
383
+ if not knowledge:
384
+ raise HTTPException(
385
+ status_code=status.HTTP_400_BAD_REQUEST,
386
+ detail=ERROR_MESSAGES.NOT_FOUND,
387
+ )
388
+
389
+ if knowledge.user_id != user.id and user.role != "admin":
390
+ raise HTTPException(
391
+ status_code=status.HTTP_400_BAD_REQUEST,
392
+ detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
393
+ )
394
+
395
  file = Files.get_file_by_id(form_data.file_id)
396
  if not file:
397
  raise HTTPException(
 
419
  file_ids.remove(form_data.file_id)
420
  data["file_ids"] = file_ids
421
 
422
+ knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
 
 
423
 
424
  if knowledge:
425
  files = Files.get_files_by_ids(file_ids)
 
446
 
447
 
448
  ############################
449
+ # DeleteKnowledgeById
450
  ############################
451
 
452
 
453
+ @router.delete("/{id}/delete", response_model=bool)
454
+ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
455
+ knowledge = Knowledges.get_knowledge_by_id(id=id)
456
+ if not knowledge:
457
+ raise HTTPException(
458
+ status_code=status.HTTP_400_BAD_REQUEST,
459
+ detail=ERROR_MESSAGES.NOT_FOUND,
460
+ )
461
+
462
+ if knowledge.user_id != user.id and user.role != "admin":
463
+ raise HTTPException(
464
+ status_code=status.HTTP_400_BAD_REQUEST,
465
+ detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
466
+ )
467
+
468
  try:
469
  VECTOR_DB_CLIENT.delete_collection(collection_name=id)
470
  except Exception as e:
471
  log.debug(e)
472
  pass
473
+ result = Knowledges.delete_knowledge_by_id(id=id)
474
+ return result
 
 
 
475
 
476
 
477
  ############################
478
+ # ResetKnowledgeById
479
  ############################
480
 
481
 
482
+ @router.post("/{id}/reset", response_model=Optional[KnowledgeResponse])
483
+ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
484
+ knowledge = Knowledges.get_knowledge_by_id(id=id)
485
+ if not knowledge:
486
+ raise HTTPException(
487
+ status_code=status.HTTP_400_BAD_REQUEST,
488
+ detail=ERROR_MESSAGES.NOT_FOUND,
489
+ )
490
+
491
+ if knowledge.user_id != user.id and user.role != "admin":
492
+ raise HTTPException(
493
+ status_code=status.HTTP_400_BAD_REQUEST,
494
+ detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
495
+ )
496
+
497
  try:
498
  VECTOR_DB_CLIENT.delete_collection(collection_name=id)
499
  except Exception as e:
500
  log.debug(e)
501
  pass
502
+
503
+ knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
504
+
505
+ return knowledge
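
The knowledge router now gates reads and writes on ownership, the admin role, or the shared access_control data via has_access from open_webui.utils.access_control, which this diff does not include. The snippet below is only a simplified stand-in that assumes access_control maps each permission to lists of user and group ids, with None meaning publicly readable:

    from typing import Optional


    def has_access_sketch(
        user_id: str,
        permission: str,
        access_control: Optional[dict],
        user_group_ids: Optional[list] = None,
    ) -> bool:
        # Simplified stand-in for open_webui.utils.access_control.has_access (not shown here).
        # Assumed shape: {"read": {"user_ids": [...], "group_ids": [...]}, "write": {...}}.
        if access_control is None:
            return permission == "read"
        entry = access_control.get(permission, {})
        if user_id in entry.get("user_ids", []):
            return True
        return bool(set(user_group_ids or []) & set(entry.get("group_ids", [])))

The routes above combine a check of this kind with the admin and ownership short-circuit, as in get_knowledge_by_id.
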
backend/open_webui/apps/webui/routers/models.py CHANGED
@@ -4,53 +4,71 @@ from open_webui.apps.webui.models.models import (
4
  ModelForm,
5
  ModelModel,
6
  ModelResponse,
 
7
  Models,
8
  )
9
  from open_webui.constants import ERROR_MESSAGES
10
  from fastapi import APIRouter, Depends, HTTPException, Request, status
 
 
11
  from open_webui.utils.utils import get_admin_user, get_verified_user
 
 
12
 
13
  router = APIRouter()
14
 
 
15
  ###########################
16
- # getModels
17
  ###########################
18
 
19
 
20
- @router.get("/", response_model=list[ModelResponse])
21
  async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
22
- if id:
23
- model = Models.get_model_by_id(id)
24
- if model:
25
- return [model]
26
- else:
27
- raise HTTPException(
28
- status_code=status.HTTP_401_UNAUTHORIZED,
29
- detail=ERROR_MESSAGES.NOT_FOUND,
30
- )
31
  else:
32
- return Models.get_all_models()
 
 
 
 
 
 
 
 
 
 
33
 
34
 
35
  ############################
36
- # AddNewModel
37
  ############################
38
 
39
 
40
- @router.post("/add", response_model=Optional[ModelModel])
41
- async def add_new_model(
42
  request: Request,
43
  form_data: ModelForm,
44
- user=Depends(get_admin_user),
45
  ):
46
- if form_data.id in request.app.state.MODELS:
 
 
 
 
 
 
 
 
 
47
  raise HTTPException(
48
  status_code=status.HTTP_401_UNAUTHORIZED,
49
  detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
50
  )
 
51
  else:
52
  model = Models.insert_new_model(form_data, user.id)
53
-
54
  if model:
55
  return model
56
  else:
@@ -60,37 +78,84 @@ async def add_new_model(
60
  )
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  ############################
64
- # UpdateModelById
65
  ############################
66
 
67
 
68
- @router.post("/update", response_model=Optional[ModelModel])
69
- async def update_model_by_id(
70
- request: Request,
71
- id: str,
72
- form_data: ModelForm,
73
- user=Depends(get_admin_user),
74
- ):
75
  model = Models.get_model_by_id(id)
76
  if model:
77
- model = Models.update_model_by_id(id, form_data)
78
- return model
79
- else:
80
- if form_data.id in request.app.state.MODELS:
81
- model = Models.insert_new_model(form_data, user.id)
 
 
82
  if model:
83
  return model
84
  else:
85
  raise HTTPException(
86
- status_code=status.HTTP_401_UNAUTHORIZED,
87
- detail=ERROR_MESSAGES.DEFAULT(),
88
  )
89
  else:
90
  raise HTTPException(
91
  status_code=status.HTTP_401_UNAUTHORIZED,
92
- detail=ERROR_MESSAGES.DEFAULT(),
93
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
 
96
  ############################
@@ -98,7 +163,26 @@ async def update_model_by_id(
98
  ############################
99
 
100
 
101
- @router.delete("/delete", response_model=bool)
102
- async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  result = Models.delete_model_by_id(id)
104
  return result
 
 
 
 
 
 
 
4
  ModelForm,
5
  ModelModel,
6
  ModelResponse,
7
+ ModelUserResponse,
8
  Models,
9
  )
10
  from open_webui.constants import ERROR_MESSAGES
11
  from fastapi import APIRouter, Depends, HTTPException, Request, status
12
+
13
+
14
  from open_webui.utils.utils import get_admin_user, get_verified_user
15
+ from open_webui.utils.access_control import has_access, has_permission
16
+
17
 
18
  router = APIRouter()
19
 
20
+
21
  ###########################
22
+ # GetModels
23
  ###########################
24
 
25
 
26
+ @router.get("/", response_model=list[ModelUserResponse])
27
  async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
28
+ if user.role == "admin":
29
+ return Models.get_models()
 
 
 
 
 
 
 
30
  else:
31
+ return Models.get_models_by_user_id(user.id)
32
+
33
+
34
+ ###########################
35
+ # GetBaseModels
36
+ ###########################
37
+
38
+
39
+ @router.get("/base", response_model=list[ModelResponse])
40
+ async def get_base_models(user=Depends(get_admin_user)):
41
+ return Models.get_base_models()
42
 
43
 
44
  ############################
45
+ # CreateNewModel
46
  ############################
47
 
48
 
49
+ @router.post("/create", response_model=Optional[ModelModel])
50
+ async def create_new_model(
51
  request: Request,
52
  form_data: ModelForm,
53
+ user=Depends(get_verified_user),
54
  ):
55
+ if user.role != "admin" and not has_permission(
56
+ user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS
57
+ ):
58
+ raise HTTPException(
59
+ status_code=status.HTTP_401_UNAUTHORIZED,
60
+ detail=ERROR_MESSAGES.UNAUTHORIZED,
61
+ )
62
+
63
+ model = Models.get_model_by_id(form_data.id)
64
+ if model:
65
  raise HTTPException(
66
  status_code=status.HTTP_401_UNAUTHORIZED,
67
  detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
68
  )
69
+
70
  else:
71
  model = Models.insert_new_model(form_data, user.id)
 
72
  if model:
73
  return model
74
  else:
 
78
  )
79
 
80
 
81
+ ###########################
82
+ # GetModelById
83
+ ###########################
84
+
85
+
86
+ @router.get("/id/{id}", response_model=Optional[ModelResponse])
87
+ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
88
+ model = Models.get_model_by_id(id)
89
+ if model:
90
+ if (
91
+ user.role == "admin"
92
+ or model.user_id == user.id
93
+ or has_access(user.id, "read", model.access_control)
94
+ ):
95
+ return model
96
+ else:
97
+ raise HTTPException(
98
+ status_code=status.HTTP_401_UNAUTHORIZED,
99
+ detail=ERROR_MESSAGES.NOT_FOUND,
100
+ )
101
+
102
+
103
  ############################
104
+ # ToggleModelById
105
  ############################
106
 
107
 
108
+ @router.post("/id/{id}/toggle", response_model=Optional[ModelResponse])
109
+ async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
 
 
 
 
 
110
  model = Models.get_model_by_id(id)
111
  if model:
112
+ if (
113
+ user.role == "admin"
114
+ or model.user_id == user.id
115
+ or has_access(user.id, "write", model.access_control)
116
+ ):
117
+ model = Models.toggle_model_by_id(id)
118
+
119
  if model:
120
  return model
121
  else:
122
  raise HTTPException(
123
+ status_code=status.HTTP_400_BAD_REQUEST,
124
+ detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
125
  )
126
  else:
127
  raise HTTPException(
128
  status_code=status.HTTP_401_UNAUTHORIZED,
129
+ detail=ERROR_MESSAGES.UNAUTHORIZED,
130
  )
131
+ else:
132
+ raise HTTPException(
133
+ status_code=status.HTTP_401_UNAUTHORIZED,
134
+ detail=ERROR_MESSAGES.NOT_FOUND,
135
+ )
136
+
137
+
138
+ ############################
139
+ # UpdateModelById
140
+ ############################
141
+
142
+
143
+ @router.post("/id/{id}/update", response_model=Optional[ModelModel])
144
+ async def update_model_by_id(
145
+ id: str,
146
+ form_data: ModelForm,
147
+ user=Depends(get_verified_user),
148
+ ):
149
+ model = Models.get_model_by_id(id)
150
+
151
+ if not model:
152
+ raise HTTPException(
153
+ status_code=status.HTTP_401_UNAUTHORIZED,
154
+ detail=ERROR_MESSAGES.NOT_FOUND,
155
+ )
156
+
157
+ model = Models.update_model_by_id(id, form_data)
158
+ return model
159
 
160
 
161
  ############################
 
163
  ############################
164
 
165
 
166
+ @router.delete("/id/{id}/delete", response_model=bool)
167
+ async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
168
+ model = Models.get_model_by_id(id)
169
+ if not model:
170
+ raise HTTPException(
171
+ status_code=status.HTTP_401_UNAUTHORIZED,
172
+ detail=ERROR_MESSAGES.NOT_FOUND,
173
+ )
174
+
175
+ if model.user_id != user.id and user.role != "admin":
176
+ raise HTTPException(
177
+ status_code=status.HTTP_401_UNAUTHORIZED,
178
+ detail=ERROR_MESSAGES.UNAUTHORIZED,
179
+ )
180
+
181
  result = Models.delete_model_by_id(id)
182
  return result
183
+
184
+
185
+ @router.delete("/delete/all", response_model=bool)
186
+ async def delete_all_models(user=Depends(get_admin_user)):
187
+ result = Models.delete_all_models()
188
+ return result
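
The models router replaces the flat /add, /update, and /delete routes with per-id routes plus an activation toggle that requires ownership or write access. A hedged client-side sketch (the mount path, token, and response fields are assumptions, not details confirmed by this diff):

    import requests

    API = "http://localhost:8080/api/v1/models"        # assumed mount point for this router
    HEADERS = {"Authorization": "Bearer <token>"}      # placeholder credential

    # GET / returns the models visible to the caller: everything for admins,
    # otherwise whatever Models.get_models_by_user_id yields for that user.
    models = requests.get(f"{API}/", headers=HEADERS, timeout=30).json()

    if models:
        model_id = models[0]["id"]
        # POST /id/{id}/toggle flips the model on or off after the ownership /
        # "write"-access check shown above.
        toggled = requests.post(f"{API}/id/{model_id}/toggle", headers=HEADERS, timeout=30)
        toggled.raise_for_status()
        print(toggled.json())
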
backend/open_webui/apps/webui/routers/prompts.py CHANGED
@@ -1,9 +1,15 @@
1
  from typing import Optional
2
 
3
- from open_webui.apps.webui.models.prompts import PromptForm, PromptModel, Prompts
 
 
 
 
 
4
  from open_webui.constants import ERROR_MESSAGES
5
- from fastapi import APIRouter, Depends, HTTPException, status
6
  from open_webui.utils.utils import get_admin_user, get_verified_user
 
7
 
8
  router = APIRouter()
9
 
@@ -14,7 +20,22 @@ router = APIRouter()
14
 
15
  @router.get("/", response_model=list[PromptModel])
16
  async def get_prompts(user=Depends(get_verified_user)):
17
- return Prompts.get_prompts()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  ############################
@@ -23,7 +44,17 @@ async def get_prompts(user=Depends(get_verified_user)):
23
 
24
 
25
  @router.post("/create", response_model=Optional[PromptModel])
26
- async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
 
 
 
 
 
 
 
 
 
 
27
  prompt = Prompts.get_prompt_by_command(form_data.command)
28
  if prompt is None:
29
  prompt = Prompts.insert_new_prompt(user.id, form_data)
@@ -50,7 +81,12 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
50
  prompt = Prompts.get_prompt_by_command(f"/{command}")
51
 
52
  if prompt:
53
- return prompt
 
 
 
 
 
54
  else:
55
  raise HTTPException(
56
  status_code=status.HTTP_401_UNAUTHORIZED,
@@ -67,8 +103,21 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
67
  async def update_prompt_by_command(
68
  command: str,
69
  form_data: PromptForm,
70
- user=Depends(get_admin_user),
71
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
73
  if prompt:
74
  return prompt
@@ -85,6 +134,19 @@ async def update_prompt_by_command(
85
 
86
 
87
  @router.delete("/command/{command}/delete", response_model=bool)
88
- async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  result = Prompts.delete_prompt_by_command(f"/{command}")
90
  return result
 
1
  from typing import Optional
2
 
3
+ from open_webui.apps.webui.models.prompts import (
4
+ PromptForm,
5
+ PromptUserResponse,
6
+ PromptModel,
7
+ Prompts,
8
+ )
9
  from open_webui.constants import ERROR_MESSAGES
10
+ from fastapi import APIRouter, Depends, HTTPException, status, Request
11
  from open_webui.utils.utils import get_admin_user, get_verified_user
12
+ from open_webui.utils.access_control import has_access, has_permission
13
 
14
  router = APIRouter()
15
 
 
20
 
21
  @router.get("/", response_model=list[PromptModel])
22
  async def get_prompts(user=Depends(get_verified_user)):
23
+ if user.role == "admin":
24
+ prompts = Prompts.get_prompts()
25
+ else:
26
+ prompts = Prompts.get_prompts_by_user_id(user.id, "read")
27
+
28
+ return prompts
29
+
30
+
31
+ @router.get("/list", response_model=list[PromptUserResponse])
32
+ async def get_prompt_list(user=Depends(get_verified_user)):
33
+ if user.role == "admin":
34
+ prompts = Prompts.get_prompts()
35
+ else:
36
+ prompts = Prompts.get_prompts_by_user_id(user.id, "write")
37
+
38
+ return prompts
39
 
40
 
41
  ############################
 
44
 
45
 
46
  @router.post("/create", response_model=Optional[PromptModel])
47
+ async def create_new_prompt(
48
+ request: Request, form_data: PromptForm, user=Depends(get_verified_user)
49
+ ):
50
+ if user.role != "admin" and not has_permission(
51
+ user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS
52
+ ):
53
+ raise HTTPException(
54
+ status_code=status.HTTP_401_UNAUTHORIZED,
55
+ detail=ERROR_MESSAGES.UNAUTHORIZED,
56
+ )
57
+
58
  prompt = Prompts.get_prompt_by_command(form_data.command)
59
  if prompt is None:
60
  prompt = Prompts.insert_new_prompt(user.id, form_data)
 
81
  prompt = Prompts.get_prompt_by_command(f"/{command}")
82
 
83
  if prompt:
84
+ if (
85
+ user.role == "admin"
86
+ or prompt.user_id == user.id
87
+ or has_access(user.id, "read", prompt.access_control)
88
+ ):
89
+ return prompt
90
  else:
91
  raise HTTPException(
92
  status_code=status.HTTP_401_UNAUTHORIZED,
 
103
  async def update_prompt_by_command(
104
  command: str,
105
  form_data: PromptForm,
106
+ user=Depends(get_verified_user),
107
  ):
108
+ prompt = Prompts.get_prompt_by_command(f"/{command}")
109
+ if not prompt:
110
+ raise HTTPException(
111
+ status_code=status.HTTP_401_UNAUTHORIZED,
112
+ detail=ERROR_MESSAGES.NOT_FOUND,
113
+ )
114
+
115
+ if prompt.user_id != user.id and user.role != "admin":
116
+ raise HTTPException(
117
+ status_code=status.HTTP_401_UNAUTHORIZED,
118
+ detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
119
+ )
120
+
121
  prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
122
  if prompt:
123
  return prompt
 
134
 
135
 
136
  @router.delete("/command/{command}/delete", response_model=bool)
137
+ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)):
138
+ prompt = Prompts.get_prompt_by_command(f"/{command}")
139
+ if not prompt:
140
+ raise HTTPException(
141
+ status_code=status.HTTP_401_UNAUTHORIZED,
142
+ detail=ERROR_MESSAGES.NOT_FOUND,
143
+ )
144
+
145
+ if prompt.user_id != user.id and user.role != "admin":
146
+ raise HTTPException(
147
+ status_code=status.HTTP_401_UNAUTHORIZED,
148
+ detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
149
+ )
150
+
151
  result = Prompts.delete_prompt_by_command(f"/{command}")
152
  return result
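
Prompt creation is now open to non-admins who pass has_permission(user.id, "workspace.prompts", USER_PERMISSIONS), and listing is split into a read view (GET /) and a write view (GET /list). The helper itself is not shown here; the sketch below only illustrates the dotted-key lookup against the nested permissions dict from config.py and ignores the per-user and group handling the real function presumably does:

    def has_permission_sketch(permission_key: str, permissions: dict) -> bool:
        # Simplified stand-in for open_webui.utils.access_control.has_permission; the real
        # helper also receives the user id. Walks a dotted key such as "workspace.prompts"
        # through the nested USER_PERMISSIONS dict, defaulting to False on missing keys.
        node = permissions
        for part in permission_key.split("."):
            if not isinstance(node, dict) or part not in node:
                return False
            node = node[part]
        return bool(node)


    defaults = {
        "workspace": {"models": False, "knowledge": False, "prompts": False, "tools": False},
        "chat": {"file_upload": True, "delete": True, "edit": True, "temporary": True},
    }
    print(has_permission_sketch("workspace.prompts", defaults))  # False until an admin enables it
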
backend/open_webui/apps/webui/routers/tools.py CHANGED
@@ -2,50 +2,82 @@ import os
2
  from pathlib import Path
3
  from typing import Optional
4
 
5
- from open_webui.apps.webui.models.tools import ToolForm, ToolModel, ToolResponse, Tools
6
- from open_webui.apps.webui.utils import load_toolkit_module_by_id, replace_imports
 
 
 
 
 
 
7
  from open_webui.config import CACHE_DIR, DATA_DIR
8
  from open_webui.constants import ERROR_MESSAGES
9
  from fastapi import APIRouter, Depends, HTTPException, Request, status
10
  from open_webui.utils.tools import get_tools_specs
11
  from open_webui.utils.utils import get_admin_user, get_verified_user
 
12
 
13
 
14
  router = APIRouter()
15
 
16
  ############################
17
- # GetToolkits
18
  ############################
19
 
20
 
21
- @router.get("/", response_model=list[ToolResponse])
22
- async def get_toolkits(user=Depends(get_verified_user)):
23
- toolkits = [toolkit for toolkit in Tools.get_tools()]
24
- return toolkits
 
 
 
25
 
26
 
27
  ############################
28
- # ExportToolKits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  ############################
30
 
31
 
32
  @router.get("/export", response_model=list[ToolModel])
33
- async def get_toolkits(user=Depends(get_admin_user)):
34
- toolkits = [toolkit for toolkit in Tools.get_tools()]
35
- return toolkits
36
 
37
 
38
  ############################
39
- # CreateNewToolKit
40
  ############################
41
 
42
 
43
  @router.post("/create", response_model=Optional[ToolResponse])
44
- async def create_new_toolkit(
45
  request: Request,
46
  form_data: ToolForm,
47
- user=Depends(get_admin_user),
48
  ):
 
 
 
 
 
 
 
 
49
  if not form_data.id.isidentifier():
50
  raise HTTPException(
51
  status_code=status.HTTP_400_BAD_REQUEST,
@@ -54,30 +86,30 @@ async def create_new_toolkit(
54
 
55
  form_data.id = form_data.id.lower()
56
 
57
- toolkit = Tools.get_tool_by_id(form_data.id)
58
- if toolkit is None:
59
  try:
60
  form_data.content = replace_imports(form_data.content)
61
- toolkit_module, frontmatter = load_toolkit_module_by_id(
62
  form_data.id, content=form_data.content
63
  )
64
  form_data.meta.manifest = frontmatter
65
 
66
  TOOLS = request.app.state.TOOLS
67
- TOOLS[form_data.id] = toolkit_module
68
 
69
  specs = get_tools_specs(TOOLS[form_data.id])
70
- toolkit = Tools.insert_new_tool(user.id, form_data, specs)
71
 
72
  tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
73
  tool_cache_dir.mkdir(parents=True, exist_ok=True)
74
 
75
- if toolkit:
76
- return toolkit
77
  else:
78
  raise HTTPException(
79
  status_code=status.HTTP_400_BAD_REQUEST,
80
- detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"),
81
  )
82
  except Exception as e:
83
  print(e)
@@ -93,16 +125,21 @@ async def create_new_toolkit(
93
 
94
 
95
  ############################
96
- # GetToolkitById
97
  ############################
98
 
99
 
100
  @router.get("/id/{id}", response_model=Optional[ToolModel])
101
- async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
102
- toolkit = Tools.get_tool_by_id(id)
103
-
104
- if toolkit:
105
- return toolkit
 
 
 
 
 
106
  else:
107
  raise HTTPException(
108
  status_code=status.HTTP_401_UNAUTHORIZED,
@@ -111,26 +148,39 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
111
 
112
 
113
  ############################
114
- # UpdateToolkitById
115
  ############################
116
 
117
 
118
  @router.post("/id/{id}/update", response_model=Optional[ToolModel])
119
- async def update_toolkit_by_id(
120
  request: Request,
121
  id: str,
122
  form_data: ToolForm,
123
- user=Depends(get_admin_user),
124
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  try:
126
  form_data.content = replace_imports(form_data.content)
127
- toolkit_module, frontmatter = load_toolkit_module_by_id(
128
  id, content=form_data.content
129
  )
130
  form_data.meta.manifest = frontmatter
131
 
132
  TOOLS = request.app.state.TOOLS
133
- TOOLS[id] = toolkit_module
134
 
135
  specs = get_tools_specs(TOOLS[id])
136
 
@@ -140,14 +190,14 @@ async def update_toolkit_by_id(
140
  }
141
 
142
  print(updated)
143
- toolkit = Tools.update_tool_by_id(id, updated)
144
 
145
- if toolkit:
146
- return toolkit
147
  else:
148
  raise HTTPException(
149
  status_code=status.HTTP_400_BAD_REQUEST,
150
- detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
151
  )
152
 
153
  except Exception as e:
@@ -158,14 +208,28 @@ async def update_toolkit_by_id(
158
 
159
 
160
  ############################
161
- # DeleteToolkitById
162
  ############################
163
 
164
 
165
  @router.delete("/id/{id}/delete", response_model=bool)
166
- async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
167
- result = Tools.delete_tool_by_id(id)
 
 
 
 
 
 
 
168
 
 
 
 
 
 
 
 
169
  if result:
170
  TOOLS = request.app.state.TOOLS
171
  if id in TOOLS:
@@ -180,9 +244,9 @@ async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin
180
 
181
 
182
  @router.get("/id/{id}/valves", response_model=Optional[dict])
183
- async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
184
- toolkit = Tools.get_tool_by_id(id)
185
- if toolkit:
186
  try:
187
  valves = Tools.get_tool_valves_by_id(id)
188
  return valves
@@ -204,19 +268,19 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
204
 
205
 
206
  @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
207
- async def get_toolkit_valves_spec_by_id(
208
- request: Request, id: str, user=Depends(get_admin_user)
209
  ):
210
- toolkit = Tools.get_tool_by_id(id)
211
- if toolkit:
212
  if id in request.app.state.TOOLS:
213
- toolkit_module = request.app.state.TOOLS[id]
214
  else:
215
- toolkit_module, _ = load_toolkit_module_by_id(id)
216
- request.app.state.TOOLS[id] = toolkit_module
217
 
218
- if hasattr(toolkit_module, "Valves"):
219
- Valves = toolkit_module.Valves
220
  return Valves.schema()
221
  return None
222
  else:
@@ -232,19 +296,19 @@ async def get_toolkit_valves_spec_by_id(
232
 
233
 
234
  @router.post("/id/{id}/valves/update", response_model=Optional[dict])
235
- async def update_toolkit_valves_by_id(
236
- request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
237
  ):
238
- toolkit = Tools.get_tool_by_id(id)
239
- if toolkit:
240
  if id in request.app.state.TOOLS:
241
- toolkit_module = request.app.state.TOOLS[id]
242
  else:
243
- toolkit_module, _ = load_toolkit_module_by_id(id)
244
- request.app.state.TOOLS[id] = toolkit_module
245
 
246
- if hasattr(toolkit_module, "Valves"):
247
- Valves = toolkit_module.Valves
248
 
249
  try:
250
  form_data = {k: v for k, v in form_data.items() if v is not None}
@@ -276,9 +340,9 @@ async def update_toolkit_valves_by_id(
276
 
277
 
278
  @router.get("/id/{id}/valves/user", response_model=Optional[dict])
279
- async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)):
280
- toolkit = Tools.get_tool_by_id(id)
281
- if toolkit:
282
  try:
283
  user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
284
  return user_valves
@@ -295,19 +359,19 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)
295
 
296
 
297
  @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
298
- async def get_toolkit_user_valves_spec_by_id(
299
  request: Request, id: str, user=Depends(get_verified_user)
300
  ):
301
- toolkit = Tools.get_tool_by_id(id)
302
- if toolkit:
303
  if id in request.app.state.TOOLS:
304
- toolkit_module = request.app.state.TOOLS[id]
305
  else:
306
- toolkit_module, _ = load_toolkit_module_by_id(id)
307
- request.app.state.TOOLS[id] = toolkit_module
308
 
309
- if hasattr(toolkit_module, "UserValves"):
310
- UserValves = toolkit_module.UserValves
311
  return UserValves.schema()
312
  return None
313
  else:
@@ -318,20 +382,20 @@ async def get_toolkit_user_valves_spec_by_id(
318
 
319
 
320
  @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
321
- async def update_toolkit_user_valves_by_id(
322
  request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
323
  ):
324
- toolkit = Tools.get_tool_by_id(id)
325
 
326
- if toolkit:
327
  if id in request.app.state.TOOLS:
328
- toolkit_module = request.app.state.TOOLS[id]
329
  else:
330
- toolkit_module, _ = load_toolkit_module_by_id(id)
331
- request.app.state.TOOLS[id] = toolkit_module
332
 
333
- if hasattr(toolkit_module, "UserValves"):
334
- UserValves = toolkit_module.UserValves
335
 
336
  try:
337
  form_data = {k: v for k, v in form_data.items() if v is not None}
 
2
  from pathlib import Path
3
  from typing import Optional
4
 
5
+ from open_webui.apps.webui.models.tools import (
6
+ ToolForm,
7
+ ToolModel,
8
+ ToolResponse,
9
+ ToolUserResponse,
10
+ Tools,
11
+ )
12
+ from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports
13
  from open_webui.config import CACHE_DIR, DATA_DIR
14
  from open_webui.constants import ERROR_MESSAGES
15
  from fastapi import APIRouter, Depends, HTTPException, Request, status
16
  from open_webui.utils.tools import get_tools_specs
17
  from open_webui.utils.utils import get_admin_user, get_verified_user
18
+ from open_webui.utils.access_control import has_access, has_permission
19
 
20
 
21
  router = APIRouter()
22
 
23
  ############################
24
+ # GetTools
25
  ############################
26
 
27
 
28
+ @router.get("/", response_model=list[ToolUserResponse])
29
+ async def get_tools(user=Depends(get_verified_user)):
30
+ if user.role == "admin":
31
+ tools = Tools.get_tools()
32
+ else:
33
+ tools = Tools.get_tools_by_user_id(user.id, "read")
34
+ return tools
35
 
36
 
37
  ############################
38
+ # GetToolList
39
+ ############################
40
+
41
+
42
+ @router.get("/list", response_model=list[ToolUserResponse])
43
+ async def get_tool_list(user=Depends(get_verified_user)):
44
+ if user.role == "admin":
45
+ tools = Tools.get_tools()
46
+ else:
47
+ tools = Tools.get_tools_by_user_id(user.id, "write")
48
+ return tools
49
+
50
+
51
+ ############################
52
+ # ExportTools
53
  ############################
54
 
55
 
56
  @router.get("/export", response_model=list[ToolModel])
57
+ async def export_tools(user=Depends(get_admin_user)):
58
+ tools = Tools.get_tools()
59
+ return tools
60
 
61
 
62
  ############################
63
+ # CreateNewTools
64
  ############################
65
 
66
 
67
  @router.post("/create", response_model=Optional[ToolResponse])
68
+ async def create_new_tools(
69
  request: Request,
70
  form_data: ToolForm,
71
+ user=Depends(get_verified_user),
72
  ):
73
+ if user.role != "admin" and not has_permission(
74
+ user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS
75
+ ):
76
+ raise HTTPException(
77
+ status_code=status.HTTP_401_UNAUTHORIZED,
78
+ detail=ERROR_MESSAGES.UNAUTHORIZED,
79
+ )
80
+
81
  if not form_data.id.isidentifier():
82
  raise HTTPException(
83
  status_code=status.HTTP_400_BAD_REQUEST,
 
86
 
87
  form_data.id = form_data.id.lower()
88
 
89
+ tools = Tools.get_tool_by_id(form_data.id)
90
+ if tools is None:
91
  try:
92
  form_data.content = replace_imports(form_data.content)
93
+ tools_module, frontmatter = load_tools_module_by_id(
94
  form_data.id, content=form_data.content
95
  )
96
  form_data.meta.manifest = frontmatter
97
 
98
  TOOLS = request.app.state.TOOLS
99
+ TOOLS[form_data.id] = tools_module
100
 
101
  specs = get_tools_specs(TOOLS[form_data.id])
102
+ tools = Tools.insert_new_tool(user.id, form_data, specs)
103
 
104
  tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
105
  tool_cache_dir.mkdir(parents=True, exist_ok=True)
106
 
107
+ if tools:
108
+ return tools
109
  else:
110
  raise HTTPException(
111
  status_code=status.HTTP_400_BAD_REQUEST,
112
+ detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
113
  )
114
  except Exception as e:
115
  print(e)
 
125
 
126
 
127
  ############################
128
+ # GetToolsById
129
  ############################
130
 
131
 
132
  @router.get("/id/{id}", response_model=Optional[ToolModel])
133
+ async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
134
+ tools = Tools.get_tool_by_id(id)
135
+
136
+ if tools:
137
+ if (
138
+ user.role == "admin"
139
+ or tools.user_id == user.id
140
+ or has_access(user.id, "read", tools.access_control)
141
+ ):
142
+ return tools
143
  else:
144
  raise HTTPException(
145
  status_code=status.HTTP_401_UNAUTHORIZED,
 
148
 
149
 
150
  ############################
151
+ # UpdateToolsById
152
  ############################
153
 
154
 
155
  @router.post("/id/{id}/update", response_model=Optional[ToolModel])
156
+ async def update_tools_by_id(
157
  request: Request,
158
  id: str,
159
  form_data: ToolForm,
160
+ user=Depends(get_verified_user),
161
  ):
162
+ tools = Tools.get_tool_by_id(id)
163
+ if not tools:
164
+ raise HTTPException(
165
+ status_code=status.HTTP_401_UNAUTHORIZED,
166
+ detail=ERROR_MESSAGES.NOT_FOUND,
167
+ )
168
+
169
+ if tools.user_id != user.id and user.role != "admin":
170
+ raise HTTPException(
171
+ status_code=status.HTTP_401_UNAUTHORIZED,
172
+ detail=ERROR_MESSAGES.UNAUTHORIZED,
173
+ )
174
+
175
  try:
176
  form_data.content = replace_imports(form_data.content)
177
+ tools_module, frontmatter = load_tools_module_by_id(
178
  id, content=form_data.content
179
  )
180
  form_data.meta.manifest = frontmatter
181
 
182
  TOOLS = request.app.state.TOOLS
183
+ TOOLS[id] = tools_module
184
 
185
  specs = get_tools_specs(TOOLS[id])
186
 
 
190
  }
191
 
192
  print(updated)
193
+ tools = Tools.update_tool_by_id(id, updated)
194
 
195
+ if tools:
196
+ return tools
197
  else:
198
  raise HTTPException(
199
  status_code=status.HTTP_400_BAD_REQUEST,
200
+ detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
201
  )
202
 
203
  except Exception as e:
 
208
 
209
 
210
  ############################
211
+ # DeleteToolsById
212
  ############################
213
 
214
 
215
  @router.delete("/id/{id}/delete", response_model=bool)
216
+ async def delete_tools_by_id(
217
+ request: Request, id: str, user=Depends(get_verified_user)
218
+ ):
219
+ tools = Tools.get_tool_by_id(id)
220
+ if not tools:
221
+ raise HTTPException(
222
+ status_code=status.HTTP_401_UNAUTHORIZED,
223
+ detail=ERROR_MESSAGES.NOT_FOUND,
224
+ )
225
 
226
+ if tools.user_id != user.id and user.role != "admin":
227
+ raise HTTPException(
228
+ status_code=status.HTTP_401_UNAUTHORIZED,
229
+ detail=ERROR_MESSAGES.UNAUTHORIZED,
230
+ )
231
+
232
+ result = Tools.delete_tool_by_id(id)
233
  if result:
234
  TOOLS = request.app.state.TOOLS
235
  if id in TOOLS:
 
244
 
245
 
246
  @router.get("/id/{id}/valves", response_model=Optional[dict])
247
+ async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
248
+ tools = Tools.get_tool_by_id(id)
249
+ if tools:
250
  try:
251
  valves = Tools.get_tool_valves_by_id(id)
252
  return valves
 
268
 
269
 
270
  @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
271
+ async def get_tools_valves_spec_by_id(
272
+ request: Request, id: str, user=Depends(get_verified_user)
273
  ):
274
+ tools = Tools.get_tool_by_id(id)
275
+ if tools:
276
  if id in request.app.state.TOOLS:
277
+ tools_module = request.app.state.TOOLS[id]
278
  else:
279
+ tools_module, _ = load_tools_module_by_id(id)
280
+ request.app.state.TOOLS[id] = tools_module
281
 
282
+ if hasattr(tools_module, "Valves"):
283
+ Valves = tools_module.Valves
284
  return Valves.schema()
285
  return None
286
  else:
 
296
 
297
 
298
  @router.post("/id/{id}/valves/update", response_model=Optional[dict])
299
+ async def update_tools_valves_by_id(
300
+ request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
301
  ):
302
+ tools = Tools.get_tool_by_id(id)
303
+ if tools:
304
  if id in request.app.state.TOOLS:
305
+ tools_module = request.app.state.TOOLS[id]
306
  else:
307
+ tools_module, _ = load_tools_module_by_id(id)
308
+ request.app.state.TOOLS[id] = tools_module
309
 
310
+ if hasattr(tools_module, "Valves"):
311
+ Valves = tools_module.Valves
312
 
313
  try:
314
  form_data = {k: v for k, v in form_data.items() if v is not None}
 
340
 
341
 
342
  @router.get("/id/{id}/valves/user", response_model=Optional[dict])
343
+ async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
344
+ tools = Tools.get_tool_by_id(id)
345
+ if tools:
346
  try:
347
  user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
348
  return user_valves
 
359
 
360
 
361
  @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
362
+ async def get_tools_user_valves_spec_by_id(
363
  request: Request, id: str, user=Depends(get_verified_user)
364
  ):
365
+ tools = Tools.get_tool_by_id(id)
366
+ if tools:
367
  if id in request.app.state.TOOLS:
368
+ tools_module = request.app.state.TOOLS[id]
369
  else:
370
+ tools_module, _ = load_tools_module_by_id(id)
371
+ request.app.state.TOOLS[id] = tools_module
372
 
373
+ if hasattr(tools_module, "UserValves"):
374
+ UserValves = tools_module.UserValves
375
  return UserValves.schema()
376
  return None
377
  else:
 
382
 
383
 
384
  @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
385
+ async def update_tools_user_valves_by_id(
386
  request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
387
  ):
388
+ tools = Tools.get_tool_by_id(id)
389
 
390
+ if tools:
391
  if id in request.app.state.TOOLS:
392
+ tools_module = request.app.state.TOOLS[id]
393
  else:
394
+ tools_module, _ = load_tools_module_by_id(id)
395
+ request.app.state.TOOLS[id] = tools_module
396
 
397
+ if hasattr(tools_module, "UserValves"):
398
+ UserValves = tools_module.UserValves
399
 
400
  try:
401
  form_data = {k: v for k, v in form_data.items() if v is not None}
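
The tools router loads a module with load_tools_module_by_id and then reflects over Valves and UserValves to serve their JSON schemas. A minimal tool-module skeleton that would satisfy those hasattr checks might look like the following; the frontmatter keys and class layout are inferred rather than taken from this diff:

    """
    title: Example Search Tool
    description: Illustrative skeleton only; field names are assumptions.
    """
    from pydantic import BaseModel, Field


    class Tools:
        class Valves(BaseModel):
            # Admin-level settings; the router serves Valves.schema() at /id/{id}/valves/spec.
            api_url: str = Field(default="https://example.com", description="Upstream endpoint")

        class UserValves(BaseModel):
            # Per-user settings; served at /id/{id}/valves/user/spec.
            max_results: int = Field(default=5, description="Results returned per call")

        def __init__(self):
            self.valves = self.Valves()

        def search(self, query: str) -> str:
            """Return a placeholder result for the given query."""
            return f"results for {query!r} via {self.valves.api_url}"
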
backend/open_webui/apps/webui/routers/users.py CHANGED
@@ -31,21 +31,58 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
31
  return Users.get_users(skip, limit)
32
 
33
 
 
 
 
 
 
 
 
 
 
 
34
  ############################
35
  # User Permissions
36
  ############################
37
 
38
 
39
- @router.get("/permissions/user")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
41
  return request.app.state.config.USER_PERMISSIONS
42
 
43
 
44
- @router.post("/permissions/user")
45
  async def update_user_permissions(
46
- request: Request, form_data: dict, user=Depends(get_admin_user)
47
  ):
48
- request.app.state.config.USER_PERMISSIONS = form_data
49
  return request.app.state.config.USER_PERMISSIONS
50
 
51
 
 
31
  return Users.get_users(skip, limit)
32
 
33
 
34
+ ############################
35
+ # User Groups
36
+ ############################
37
+
38
+
39
+ @router.get("/groups")
40
+ async def get_user_groups(user=Depends(get_verified_user)):
41
+ return Users.get_user_groups(user.id)
42
+
43
+
44
  ############################
45
  # User Permissions
46
  ############################
47
 
48
 
49
+ @router.get("/permissions")
50
+ async def get_user_permissions(user=Depends(get_verified_user)):
51
+ return Users.get_user_groups(user.id)
52
+
53
+
54
+ ############################
55
+ # User Default Permissions
56
+ ############################
57
+ class WorkspacePermissions(BaseModel):
58
+ models: bool
59
+ knowledge: bool
60
+ prompts: bool
61
+ tools: bool
62
+
63
+
64
+ class ChatPermissions(BaseModel):
65
+ file_upload: bool
66
+ delete: bool
67
+ edit: bool
68
+ temporary: bool
69
+
70
+
71
+ class UserPermissions(BaseModel):
72
+ workspace: WorkspacePermissions
73
+ chat: ChatPermissions
74
+
75
+
76
+ @router.get("/default/permissions")
77
  async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
78
  return request.app.state.config.USER_PERMISSIONS
79
 
80
 
81
+ @router.post("/default/permissions")
82
  async def update_user_permissions(
83
+ request: Request, form_data: UserPermissions, user=Depends(get_admin_user)
84
  ):
85
+ request.app.state.config.USER_PERMISSIONS = form_data.model_dump()
86
  return request.app.state.config.USER_PERMISSIONS
87
 
88
 
backend/open_webui/apps/webui/utils.py CHANGED
@@ -63,7 +63,7 @@ def replace_imports(content):
63
  return content
64
 
65
 
66
- def load_toolkit_module_by_id(toolkit_id, content=None):
67
 
68
  if content is None:
69
  tool = Tools.get_tool_by_id(toolkit_id)
 
63
  return content
64
 
65
 
66
+ def load_tools_module_by_id(toolkit_id, content=None):
67
 
68
  if content is None:
69
  tool = Tools.get_tool_by_id(toolkit_id)
backend/open_webui/config.py CHANGED
@@ -20,6 +20,7 @@ from open_webui.env import (
20
  WEBUI_FAVICON_URL,
21
  WEBUI_NAME,
22
  log,
 
23
  )
24
  from pydantic import BaseModel
25
  from sqlalchemy import JSON, Column, DateTime, Integer, func
@@ -264,6 +265,13 @@ class AppConfig:
264
  # WEBUI_AUTH (Required for security)
265
  ####################################
266
 
 
 
 
 
 
 
 
267
  JWT_EXPIRES_IN = PersistentConfig(
268
  "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
269
  )
@@ -606,6 +614,12 @@ OLLAMA_BASE_URLS = PersistentConfig(
606
  "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
607
  )
608
 
 
 
 
 
 
 
609
  ####################################
610
  # OPENAI_API
611
  ####################################
@@ -646,15 +660,20 @@ OPENAI_API_BASE_URLS = PersistentConfig(
646
  "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
647
  )
648
 
649
- OPENAI_API_KEY = ""
 
 
 
 
650
 
 
 
651
  try:
652
  OPENAI_API_KEY = OPENAI_API_KEYS.value[
653
  OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
654
  ]
655
  except Exception:
656
  pass
657
-
658
  OPENAI_API_BASE_URL = "https://api.openai.com/v1"
659
 
660
  ####################################
@@ -727,12 +746,36 @@ DEFAULT_USER_ROLE = PersistentConfig(
727
  os.getenv("DEFAULT_USER_ROLE", "pending"),
728
  )
729
 
730
- USER_PERMISSIONS_CHAT_DELETION = (
731
- os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
732
  )
733
 
734
- USER_PERMISSIONS_CHAT_EDITING = (
735
- os.environ.get("USER_PERMISSIONS_CHAT_EDITING", "True").lower() == "true"
736
  )
737
 
738
  USER_PERMISSIONS_CHAT_TEMPORARY = (
@@ -741,13 +784,20 @@ USER_PERMISSIONS_CHAT_TEMPORARY = (
741
 
742
  USER_PERMISSIONS = PersistentConfig(
743
  "USER_PERMISSIONS",
744
- "ui.user_permissions",
745
  {
 
 
 
 
 
 
746
  "chat": {
747
- "deletion": USER_PERMISSIONS_CHAT_DELETION,
748
- "editing": USER_PERMISSIONS_CHAT_EDITING,
 
749
  "temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
750
- }
751
  },
752
  )
753
 
@@ -773,18 +823,6 @@ DEFAULT_ARENA_MODEL = {
773
  },
774
  }
775
 
776
- ENABLE_MODEL_FILTER = PersistentConfig(
777
- "ENABLE_MODEL_FILTER",
778
- "model_filter.enable",
779
- os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true",
780
- )
781
- MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
782
- MODEL_FILTER_LIST = PersistentConfig(
783
- "MODEL_FILTER_LIST",
784
- "model_filter.list",
785
- [model.strip() for model in MODEL_FILTER_LIST.split(";")],
786
- )
787
-
788
  WEBHOOK_URL = PersistentConfig(
789
  "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
790
  )
@@ -904,19 +942,55 @@ TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
904
  os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""),
905
  )
906
 
907
- ENABLE_SEARCH_QUERY = PersistentConfig(
908
- "ENABLE_SEARCH_QUERY",
909
- "task.search.enable",
910
- os.environ.get("ENABLE_SEARCH_QUERY", "True").lower() == "true",
 
 
 
 
 
 
 
 
 
 
 
 
 
911
  )
912
 
913
 
914
- SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
915
- "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
916
- "task.search.prompt_template",
917
- os.environ.get("SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", ""),
918
  )
919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
920
 
921
  TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
922
  "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
@@ -956,6 +1030,21 @@ MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
956
 
957
  # Qdrant
958
  QDRANT_URI = os.environ.get("QDRANT_URI", None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
959
 
960
  ####################################
961
  # Information Retrieval (RAG)
@@ -1035,11 +1124,11 @@ RAG_EMBEDDING_MODEL = PersistentConfig(
1035
  log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")
1036
 
1037
  RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
1038
- os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
1039
  )
1040
 
1041
  RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
1042
- os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
1043
  )
1044
 
1045
  RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
@@ -1060,11 +1149,11 @@ if RAG_RERANKING_MODEL.value != "":
1060
  log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
1061
 
1062
  RAG_RERANKING_MODEL_AUTO_UPDATE = (
1063
- os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
1064
  )
1065
 
1066
  RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
1067
- os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
1068
  )
1069
 
1070
 
@@ -1129,6 +1218,19 @@ RAG_OPENAI_API_KEY = PersistentConfig(
1129
  os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
1130
  )
1131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1132
  ENABLE_RAG_LOCAL_WEB_FETCH = (
1133
  os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
1134
  )
@@ -1218,6 +1320,12 @@ TAVILY_API_KEY = PersistentConfig(
1218
  os.getenv("TAVILY_API_KEY", ""),
1219
  )
1220
 
 
 
 
 
 
 
1221
  SEARCHAPI_API_KEY = PersistentConfig(
1222
  "SEARCHAPI_API_KEY",
1223
  "rag.web.search.searchapi_api_key",
@@ -1230,6 +1338,21 @@ SEARCHAPI_ENGINE = PersistentConfig(
1230
  os.getenv("SEARCHAPI_ENGINE", ""),
1231
  )
1232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1233
  RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
1234
  "RAG_WEB_SEARCH_RESULT_COUNT",
1235
  "rag.web.search.result_count",
@@ -1281,7 +1404,7 @@ AUTOMATIC1111_CFG_SCALE = PersistentConfig(
1281
 
1282
 
1283
  AUTOMATIC1111_SAMPLER = PersistentConfig(
1284
- "AUTOMATIC1111_SAMPLERE",
1285
  "image_generation.automatic1111.sampler",
1286
  (
1287
  os.environ.get("AUTOMATIC1111_SAMPLER")
@@ -1550,3 +1673,74 @@ AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT = PersistentConfig(
1550
  "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT", "audio-24khz-160kbitrate-mono-mp3"
1551
  ),
1552
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  WEBUI_FAVICON_URL,
21
  WEBUI_NAME,
22
  log,
23
+ DATABASE_URL,
24
  )
25
  from pydantic import BaseModel
26
  from sqlalchemy import JSON, Column, DateTime, Integer, func
 
265
  # WEBUI_AUTH (Required for security)
266
  ####################################
267
 
268
+ ENABLE_API_KEY = PersistentConfig(
269
+ "ENABLE_API_KEY",
270
+ "auth.api_key.enable",
271
+ os.environ.get("ENABLE_API_KEY", "True").lower() == "true",
272
+ )
273
+
274
+
275
  JWT_EXPIRES_IN = PersistentConfig(
276
  "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
277
  )
 
614
  "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
615
  )
616
 
617
+ OLLAMA_API_CONFIGS = PersistentConfig(
618
+ "OLLAMA_API_CONFIGS",
619
+ "ollama.api_configs",
620
+ {},
621
+ )
622
+
623
  ####################################
624
  # OPENAI_API
625
  ####################################
 
660
  "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
661
  )
662
 
663
+ OPENAI_API_CONFIGS = PersistentConfig(
664
+ "OPENAI_API_CONFIGS",
665
+ "openai.api_configs",
666
+ {},
667
+ )
668
 
669
+ # Get the actual OpenAI API key based on the base URL
670
+ OPENAI_API_KEY = ""
671
  try:
672
  OPENAI_API_KEY = OPENAI_API_KEYS.value[
673
  OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
674
  ]
675
  except Exception:
676
  pass
 
677
  OPENAI_API_BASE_URL = "https://api.openai.com/v1"
678
 
679
  ####################################
 
746
  os.getenv("DEFAULT_USER_ROLE", "pending"),
747
  )
748
 
749
+
750
+ USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = (
751
+ os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower()
752
+ == "true"
753
+ )
754
+
755
+ USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS = (
756
+ os.environ.get("USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS", "False").lower()
757
+ == "true"
758
+ )
759
+
760
+ USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS = (
761
+ os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS", "False").lower()
762
+ == "true"
763
+ )
764
+
765
+ USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS = (
766
+ os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true"
767
+ )
768
+
769
+ USER_PERMISSIONS_CHAT_FILE_UPLOAD = (
770
+ os.environ.get("USER_PERMISSIONS_CHAT_FILE_UPLOAD", "True").lower() == "true"
771
+ )
772
+
773
+ USER_PERMISSIONS_CHAT_DELETE = (
774
+ os.environ.get("USER_PERMISSIONS_CHAT_DELETE", "True").lower() == "true"
775
  )
776
 
777
+ USER_PERMISSIONS_CHAT_EDIT = (
778
+ os.environ.get("USER_PERMISSIONS_CHAT_EDIT", "True").lower() == "true"
779
  )
780
 
781
  USER_PERMISSIONS_CHAT_TEMPORARY = (
 
784
 
785
  USER_PERMISSIONS = PersistentConfig(
786
  "USER_PERMISSIONS",
787
+ "user.permissions",
788
  {
789
+ "workspace": {
790
+ "models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS,
791
+ "knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS,
792
+ "prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS,
793
+ "tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS,
794
+ },
795
  "chat": {
796
+ "file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD,
797
+ "delete": USER_PERMISSIONS_CHAT_DELETE,
798
+ "edit": USER_PERMISSIONS_CHAT_EDIT,
799
  "temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
800
+ },
801
  },
802
  )
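# For reference (a sketch, not part of this change): with none of the
# USER_PERMISSIONS_* variables set, the block above resolves to the nested
# defaults below: workspace access off, chat actions on.
DEFAULT_USER_PERMISSIONS_SKETCH = {
    "workspace": {"models": False, "knowledge": False, "prompts": False, "tools": False},
    "chat": {
        "file_upload": True,
        "delete": True,
        "edit": True,
        "temporary": True,  # assumption: its default is defined outside this hunk
    },
}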
803
 
 
823
  },
824
  }
825
 
826
  WEBHOOK_URL = PersistentConfig(
827
  "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
828
  )
 
942
  os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""),
943
  )
944
 
945
+ ENABLE_TAGS_GENERATION = PersistentConfig(
946
+ "ENABLE_TAGS_GENERATION",
947
+ "task.tags.enable",
948
+ os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true",
949
+ )
950
+
951
+
952
+ ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig(
953
+ "ENABLE_SEARCH_QUERY_GENERATION",
954
+ "task.query.search.enable",
955
+ os.environ.get("ENABLE_SEARCH_QUERY_GENERATION", "True").lower() == "true",
956
+ )
957
+
958
+ ENABLE_RETRIEVAL_QUERY_GENERATION = PersistentConfig(
959
+ "ENABLE_RETRIEVAL_QUERY_GENERATION",
960
+ "task.query.retrieval.enable",
961
+ os.environ.get("ENABLE_RETRIEVAL_QUERY_GENERATION", "True").lower() == "true",
962
  )
963
 
964
 
965
+ QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
966
+ "QUERY_GENERATION_PROMPT_TEMPLATE",
967
+ "task.query.prompt_template",
968
+ os.environ.get("QUERY_GENERATION_PROMPT_TEMPLATE", ""),
969
  )
970
 
971
+ DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task:
972
+ Based on the chat history, determine whether a search is necessary, and if so, generate 1-3 broad search queries to retrieve comprehensive and updated information. If no search is required, return an empty list.
973
+
974
+ ### Guidelines:
975
+ - Respond exclusively with a JSON object.
976
+ - If a search query is needed, return an object like: { "queries": ["query1", "query2"] } where each query is distinct and concise.
977
+ - If no search query is necessary, output should be: { "queries": [] }
978
+ - Default to suggesting a search query to ensure accurate and updated information, unless it is definitively clear no search is required.
979
+ - Be concise, focusing strictly on composing search queries with no additional commentary or text.
980
+ - When in doubt, prefer to suggest a search for comprehensiveness.
981
+ - Today's date is: {{CURRENT_DATE}}
982
+
983
+ ### Output:
984
+ JSON format: {
985
+ "queries": ["query1", "query2"]
986
+ }
987
+
988
+ ### Chat History:
989
+ <chat_history>
990
+ {{MESSAGES:END:6}}
991
+ </chat_history>
992
+ """
993
+
994
 
995
  TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
996
  "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
 
1030
 
1031
  # Qdrant
1032
  QDRANT_URI = os.environ.get("QDRANT_URI", None)
1033
+ QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
1034
+
1035
+ # OpenSearch
1036
+ OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
1037
+ OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", True)
1038
+ OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False)
1039
+ OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
1040
+ OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
1041
+
1042
+ # Pgvector
1043
+ PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", DATABASE_URL)
1044
+ if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):
1045
+ raise ValueError(
1046
+ "Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database."
1047
+ )
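# Illustrative DSNs for the scheme guard above (hypothetical values):
ok_dsn = "postgresql://openwebui:secret@db:5432/openwebui"
bad_dsn = "sqlite:///data/webui.db"
assert ok_dsn.startswith("postgres")       # accepted when VECTOR_DB == "pgvector"
assert not bad_dsn.startswith("postgres")  # would raise the ValueError above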
1048
 
1049
  ####################################
1050
  # Information Retrieval (RAG)
 
1124
  log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")
1125
 
1126
  RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
1127
+ os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true"
1128
  )
1129
 
1130
  RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
1131
+ os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true"
1132
  )
1133
 
1134
  RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
 
1149
  log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
1150
 
1151
  RAG_RERANKING_MODEL_AUTO_UPDATE = (
1152
+ os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
1153
  )
1154
 
1155
  RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
1156
+ os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true"
1157
  )
1158
 
1159
 
 
1218
  os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
1219
  )
1220
 
1221
+ RAG_OLLAMA_BASE_URL = PersistentConfig(
1222
+ "RAG_OLLAMA_BASE_URL",
1223
+ "rag.ollama.url",
1224
+ os.getenv("RAG_OLLAMA_BASE_URL", OLLAMA_BASE_URL),
1225
+ )
1226
+
1227
+ RAG_OLLAMA_API_KEY = PersistentConfig(
1228
+ "RAG_OLLAMA_API_KEY",
1229
+ "rag.ollama.key",
1230
+ os.getenv("RAG_OLLAMA_API_KEY", ""),
1231
+ )
1232
+
1233
+
1234
  ENABLE_RAG_LOCAL_WEB_FETCH = (
1235
  os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
1236
  )
 
1320
  os.getenv("TAVILY_API_KEY", ""),
1321
  )
1322
 
1323
+ JINA_API_KEY = PersistentConfig(
1324
+ "JINA_API_KEY",
1325
+ "rag.web.search.jina_api_key",
1326
+ os.getenv("JINA_API_KEY", ""),
1327
+ )
1328
+
1329
  SEARCHAPI_API_KEY = PersistentConfig(
1330
  "SEARCHAPI_API_KEY",
1331
  "rag.web.search.searchapi_api_key",
 
1338
  os.getenv("SEARCHAPI_ENGINE", ""),
1339
  )
1340
 
1341
+ BING_SEARCH_V7_ENDPOINT = PersistentConfig(
1342
+ "BING_SEARCH_V7_ENDPOINT",
1343
+ "rag.web.search.bing_search_v7_endpoint",
1344
+ os.environ.get(
1345
+ "BING_SEARCH_V7_ENDPOINT", "https://api.bing.microsoft.com/v7.0/search"
1346
+ ),
1347
+ )
1348
+
1349
+ BING_SEARCH_V7_SUBSCRIPTION_KEY = PersistentConfig(
1350
+ "BING_SEARCH_V7_SUBSCRIPTION_KEY",
1351
+ "rag.web.search.bing_search_v7_subscription_key",
1352
+ os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""),
1353
+ )
1354
+
1355
+
1356
  RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
1357
  "RAG_WEB_SEARCH_RESULT_COUNT",
1358
  "rag.web.search.result_count",
 
1404
 
1405
 
1406
  AUTOMATIC1111_SAMPLER = PersistentConfig(
1407
+ "AUTOMATIC1111_SAMPLER",
1408
  "image_generation.automatic1111.sampler",
1409
  (
1410
  os.environ.get("AUTOMATIC1111_SAMPLER")
 
1673
  "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT", "audio-24khz-160kbitrate-mono-mp3"
1674
  ),
1675
  )
1676
+
1677
+
1678
+ ####################################
1679
+ # LDAP
1680
+ ####################################
1681
+
1682
+ ENABLE_LDAP = PersistentConfig(
1683
+ "ENABLE_LDAP",
1684
+ "ldap.enable",
1685
+ os.environ.get("ENABLE_LDAP", "false").lower() == "true",
1686
+ )
1687
+
1688
+ LDAP_SERVER_LABEL = PersistentConfig(
1689
+ "LDAP_SERVER_LABEL",
1690
+ "ldap.server.label",
1691
+ os.environ.get("LDAP_SERVER_LABEL", "LDAP Server"),
1692
+ )
1693
+
1694
+ LDAP_SERVER_HOST = PersistentConfig(
1695
+ "LDAP_SERVER_HOST",
1696
+ "ldap.server.host",
1697
+ os.environ.get("LDAP_SERVER_HOST", "localhost"),
1698
+ )
1699
+
1700
+ LDAP_SERVER_PORT = PersistentConfig(
1701
+ "LDAP_SERVER_PORT",
1702
+ "ldap.server.port",
1703
+ int(os.environ.get("LDAP_SERVER_PORT", "389")),
1704
+ )
1705
+
1706
+ LDAP_ATTRIBUTE_FOR_USERNAME = PersistentConfig(
1707
+ "LDAP_ATTRIBUTE_FOR_USERNAME",
1708
+ "ldap.server.attribute_for_username",
1709
+ os.environ.get("LDAP_ATTRIBUTE_FOR_USERNAME", "uid"),
1710
+ )
1711
+
1712
+ LDAP_APP_DN = PersistentConfig(
1713
+ "LDAP_APP_DN", "ldap.server.app_dn", os.environ.get("LDAP_APP_DN", "")
1714
+ )
1715
+
1716
+ LDAP_APP_PASSWORD = PersistentConfig(
1717
+ "LDAP_APP_PASSWORD",
1718
+ "ldap.server.app_password",
1719
+ os.environ.get("LDAP_APP_PASSWORD", ""),
1720
+ )
1721
+
1722
+ LDAP_SEARCH_BASE = PersistentConfig(
1723
+ "LDAP_SEARCH_BASE", "ldap.server.users_dn", os.environ.get("LDAP_SEARCH_BASE", "")
1724
+ )
1725
+
1726
+ LDAP_SEARCH_FILTERS = PersistentConfig(
1727
+ "LDAP_SEARCH_FILTER",
1728
+ "ldap.server.search_filter",
1729
+ os.environ.get("LDAP_SEARCH_FILTER", ""),
1730
+ )
1731
+
1732
+ LDAP_USE_TLS = PersistentConfig(
1733
+ "LDAP_USE_TLS",
1734
+ "ldap.server.use_tls",
1735
+ os.environ.get("LDAP_USE_TLS", "True").lower() == "true",
1736
+ )
1737
+
1738
+ LDAP_CA_CERT_FILE = PersistentConfig(
1739
+ "LDAP_CA_CERT_FILE",
1740
+ "ldap.server.ca_cert_file",
1741
+ os.environ.get("LDAP_CA_CERT_FILE", ""),
1742
+ )
1743
+
1744
+ LDAP_CIPHERS = PersistentConfig(
1745
+ "LDAP_CIPHERS", "ldap.server.ciphers", os.environ.get("LDAP_CIPHERS", "ALL")
1746
+ )
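A rough sketch of how the LDAP settings above could be exercised with the ldap3 library: bind with the app DN, then look up a user by the configured username attribute. Host, DNs and filter values are placeholders, and this is an assumption about usage rather than the application's actual authentication flow.

from ldap3 import Server, Connection

server = Server("ldap.example.com", port=389, use_ssl=False)  # LDAP_SERVER_HOST / LDAP_SERVER_PORT
conn = Connection(
    server,
    user="cn=app,dc=example,dc=com",  # LDAP_APP_DN
    password="app-password",          # LDAP_APP_PASSWORD
    auto_bind=True,
)
# Look up a user by the attribute configured in LDAP_ATTRIBUTE_FOR_USERNAME.
conn.search(
    "ou=users,dc=example,dc=com",     # LDAP_SEARCH_BASE
    "(uid=jdoe)",                     # combine with LDAP_SEARCH_FILTER as needed
    attributes=["uid", "mail", "cn"],
)
print(conn.entries)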
backend/open_webui/constants.py CHANGED
@@ -62,6 +62,7 @@ class ERROR_MESSAGES(str, Enum):
62
  NOT_FOUND = "We could not find what you're looking for :/"
63
  USER_NOT_FOUND = "We could not find what you're looking for :/"
64
  API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
 
65
 
66
  MALICIOUS = "Unusual activities detected, please try again in a few minutes."
67
 
@@ -75,6 +76,7 @@ class ERROR_MESSAGES(str, Enum):
75
  OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
76
  OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
77
  CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
 
78
 
79
  EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
80
 
 
62
  NOT_FOUND = "We could not find what you're looking for :/"
63
  USER_NOT_FOUND = "We could not find what you're looking for :/"
64
  API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
65
+ API_KEY_NOT_ALLOWED = "Use of API key is not enabled in the environment."
66
 
67
  MALICIOUS = "Unusual activities detected, please try again in a few minutes."
68
 
 
76
  OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
77
  OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
78
  CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
79
+ API_KEY_CREATION_NOT_ALLOWED = "API key creation is not allowed in the environment."
80
 
81
  EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
82
 
backend/open_webui/env.py CHANGED
@@ -1,384 +1,393 @@
1
- import importlib.metadata
2
- import json
3
- import logging
4
- import os
5
- import pkgutil
6
- import sys
7
- import shutil
8
- from pathlib import Path
9
-
10
- import markdown
11
- from bs4 import BeautifulSoup
12
- from open_webui.constants import ERROR_MESSAGES
13
-
14
- ####################################
15
- # Load .env file
16
- ####################################
17
-
18
- OPEN_WEBUI_DIR = Path(__file__).parent # the path containing this file
19
- print(OPEN_WEBUI_DIR)
20
-
21
- BACKEND_DIR = OPEN_WEBUI_DIR.parent # the path containing this file
22
- BASE_DIR = BACKEND_DIR.parent # the path containing the backend/
23
-
24
- print(BACKEND_DIR)
25
- print(BASE_DIR)
26
-
27
- try:
28
- from dotenv import find_dotenv, load_dotenv
29
-
30
- load_dotenv(find_dotenv(str(BASE_DIR / ".env")))
31
- except ImportError:
32
- print("dotenv not installed, skipping...")
33
-
34
- DOCKER = os.environ.get("DOCKER", "False").lower() == "true"
35
-
36
- # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
37
- USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
38
-
39
- if USE_CUDA.lower() == "true":
40
- try:
41
- import torch
42
-
43
- assert torch.cuda.is_available(), "CUDA not available"
44
- DEVICE_TYPE = "cuda"
45
- except Exception as e:
46
- cuda_error = (
47
- "Error when testing CUDA but USE_CUDA_DOCKER is true. "
48
- f"Resetting USE_CUDA_DOCKER to false: {e}"
49
- )
50
- os.environ["USE_CUDA_DOCKER"] = "false"
51
- USE_CUDA = "false"
52
- DEVICE_TYPE = "cpu"
53
- else:
54
- DEVICE_TYPE = "cpu"
55
-
56
-
57
- ####################################
58
- # LOGGING
59
- ####################################
60
-
61
- log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
62
-
63
- GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "ERROR").upper()
64
- if GLOBAL_LOG_LEVEL in log_levels:
65
- logging.basicConfig(stream=sys.stdout, level="ERROR", force=False)
66
- else:
67
- GLOBAL_LOG_LEVEL = "ERROR"
68
-
69
- log = logging.getLogger(__name__)
70
- log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
71
-
72
- if "cuda_error" in locals():
73
- log.exception(cuda_error)
74
-
75
- log_sources = [
76
- "AUDIO",
77
- "COMFYUI",
78
- "CONFIG",
79
- "DB",
80
- "IMAGES",
81
- "MAIN",
82
- "MODELS",
83
- "OLLAMA",
84
- "OPENAI",
85
- "RAG",
86
- "WEBHOOK",
87
- "SOCKET",
88
- ]
89
-
90
- SRC_LOG_LEVELS = {}
91
-
92
- for source in log_sources:
93
- log_env_var = source + "_LOG_LEVEL"
94
- SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
95
- if SRC_LOG_LEVELS[source] not in log_levels:
96
- SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
97
- log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
98
-
99
- log.setLevel(SRC_LOG_LEVELS["CONFIG"])
100
-
101
-
102
- WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
103
- if WEBUI_NAME != "Open WebUI":
104
- WEBUI_NAME += " (Open WebUI)"
105
-
106
- WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")
107
-
108
- WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
109
-
110
-
111
- ####################################
112
- # ENV (dev,test,prod)
113
- ####################################
114
-
115
- ENV = os.environ.get("ENV", "dev")
116
-
117
- FROM_INIT_PY = os.environ.get("FROM_INIT_PY", "False").lower() == "true"
118
-
119
- if FROM_INIT_PY:
120
- PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")}
121
- else:
122
- try:
123
- PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text())
124
- except Exception:
125
- PACKAGE_DATA = {"version": "0.0.0"}
126
-
127
-
128
- VERSION = PACKAGE_DATA["version"]
129
-
130
-
131
- # Function to parse each section
132
- def parse_section(section):
133
- items = []
134
- for li in section.find_all("li"):
135
- # Extract raw HTML string
136
- raw_html = str(li)
137
-
138
- # Extract text without HTML tags
139
- text = li.get_text(separator=" ", strip=True)
140
-
141
- # Split into title and content
142
- parts = text.split(": ", 1)
143
- title = parts[0].strip() if len(parts) > 1 else ""
144
- content = parts[1].strip() if len(parts) > 1 else text
145
-
146
- items.append({"title": title, "content": content, "raw": raw_html})
147
- return items
148
-
149
-
150
- try:
151
- changelog_path = BASE_DIR / "CHANGELOG.md"
152
- with open(str(changelog_path.absolute()), "r", encoding="utf8") as file:
153
- changelog_content = file.read()
154
-
155
- except Exception:
156
- changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
157
-
158
-
159
- # Convert markdown content to HTML
160
- html_content = markdown.markdown(changelog_content)
161
-
162
- # Parse the HTML content
163
- soup = BeautifulSoup(html_content, "html.parser")
164
-
165
- # Initialize JSON structure
166
- changelog_json = {}
167
-
168
- # Iterate over each version
169
- for version in soup.find_all("h2"):
170
- version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets
171
- date = version.get_text().strip().split(" - ")[1]
172
-
173
- version_data = {"date": date}
174
-
175
- # Find the next sibling that is a h3 tag (section title)
176
- current = version.find_next_sibling()
177
-
178
- while current and current.name != "h2":
179
- if current.name == "h3":
180
- section_title = current.get_text().lower() # e.g., "added", "fixed"
181
- section_items = parse_section(current.find_next_sibling("ul"))
182
- version_data[section_title] = section_items
183
-
184
- # Move to the next element
185
- current = current.find_next_sibling()
186
-
187
- changelog_json[version_number] = version_data
188
-
189
-
190
- CHANGELOG = changelog_json
191
-
192
- ####################################
193
- # SAFE_MODE
194
- ####################################
195
-
196
- SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
197
-
198
- ####################################
199
- # WEBUI_BUILD_HASH
200
- ####################################
201
-
202
- WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
203
-
204
- ####################################
205
- # DATA/FRONTEND BUILD DIR
206
- ####################################
207
-
208
- DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
209
-
210
- if FROM_INIT_PY:
211
- NEW_DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")).resolve()
212
- NEW_DATA_DIR.mkdir(parents=True, exist_ok=True)
213
-
214
- # Check if the data directory exists in the package directory
215
- if DATA_DIR.exists() and DATA_DIR != NEW_DATA_DIR:
216
- log.info(f"Moving {DATA_DIR} to {NEW_DATA_DIR}")
217
- for item in DATA_DIR.iterdir():
218
- dest = NEW_DATA_DIR / item.name
219
- if item.is_dir():
220
- shutil.copytree(item, dest, dirs_exist_ok=True)
221
- else:
222
- shutil.copy2(item, dest)
223
-
224
- # Zip the data directory
225
- shutil.make_archive(DATA_DIR.parent / "open_webui_data", "zip", DATA_DIR)
226
-
227
- # Remove the old data directory
228
- shutil.rmtree(DATA_DIR)
229
-
230
- DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data"))
231
-
232
-
233
- STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static"))
234
-
235
- FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts"))
236
-
237
- FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
238
-
239
- if FROM_INIT_PY:
240
- FRONTEND_BUILD_DIR = Path(
241
- os.getenv("FRONTEND_BUILD_DIR", OPEN_WEBUI_DIR / "frontend")
242
- ).resolve()
243
-
244
-
245
- ####################################
246
- # Database
247
- ####################################
248
-
249
- # Check if the file exists
250
- if os.path.exists(f"{DATA_DIR}/ollama.db"):
251
- # Rename the file
252
- os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
253
- log.info("Database migrated from Ollama-WebUI successfully.")
254
- else:
255
- pass
256
-
257
- DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
258
-
259
- # Replace the postgres:// with postgresql://
260
- if "postgres://" in DATABASE_URL:
261
- DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
262
-
263
- DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", 0)
264
-
265
- if DATABASE_POOL_SIZE == "":
266
- DATABASE_POOL_SIZE = 0
267
- else:
268
- try:
269
- DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE)
270
- except Exception:
271
- DATABASE_POOL_SIZE = 0
272
-
273
- DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0)
274
-
275
- if DATABASE_POOL_MAX_OVERFLOW == "":
276
- DATABASE_POOL_MAX_OVERFLOW = 0
277
- else:
278
- try:
279
- DATABASE_POOL_MAX_OVERFLOW = int(DATABASE_POOL_MAX_OVERFLOW)
280
- except Exception:
281
- DATABASE_POOL_MAX_OVERFLOW = 0
282
-
283
- DATABASE_POOL_TIMEOUT = os.environ.get("DATABASE_POOL_TIMEOUT", 30)
284
-
285
- if DATABASE_POOL_TIMEOUT == "":
286
- DATABASE_POOL_TIMEOUT = 30
287
- else:
288
- try:
289
- DATABASE_POOL_TIMEOUT = int(DATABASE_POOL_TIMEOUT)
290
- except Exception:
291
- DATABASE_POOL_TIMEOUT = 30
292
-
293
- DATABASE_POOL_RECYCLE = os.environ.get("DATABASE_POOL_RECYCLE", 3600)
294
-
295
- if DATABASE_POOL_RECYCLE == "":
296
- DATABASE_POOL_RECYCLE = 3600
297
- else:
298
- try:
299
- DATABASE_POOL_RECYCLE = int(DATABASE_POOL_RECYCLE)
300
- except Exception:
301
- DATABASE_POOL_RECYCLE = 3600
302
-
303
- RESET_CONFIG_ON_START = (
304
- os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
305
- )
306
-
307
- ####################################
308
- # REDIS
309
- ####################################
310
-
311
- REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
312
-
313
- ####################################
314
- # WEBUI_AUTH (Required for security)
315
- ####################################
316
-
317
- WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
318
- WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
319
- "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
320
- )
321
- WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
322
-
323
-
324
- ####################################
325
- # WEBUI_SECRET_KEY
326
- ####################################
327
-
328
- WEBUI_SECRET_KEY = os.environ.get(
329
- "WEBUI_SECRET_KEY",
330
- os.environ.get(
331
- "WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t"
332
- ), # DEPRECATED: remove at next major version
333
- )
334
-
335
- WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get(
336
- "WEBUI_SESSION_COOKIE_SAME_SITE",
337
- os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"),
338
- )
339
-
340
- WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
341
- "WEBUI_SESSION_COOKIE_SECURE",
342
- os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true",
343
- )
344
-
345
- if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
346
- raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
347
-
348
- ENABLE_WEBSOCKET_SUPPORT = (
349
- os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
350
- )
351
-
352
- WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
353
-
354
- WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
355
-
356
- AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
357
-
358
- if AIOHTTP_CLIENT_TIMEOUT == "":
359
- AIOHTTP_CLIENT_TIMEOUT = None
360
- else:
361
- try:
362
- AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
363
- except Exception:
364
- AIOHTTP_CLIENT_TIMEOUT = 300
365
-
366
- AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
367
- "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "3"
368
- )
369
-
370
- if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
371
- AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = None
372
- else:
373
- try:
374
- AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = int(
375
- AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
376
- )
377
- except Exception:
378
- AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 3
379
-
380
- ####################################
381
- # OFFLINE_MODE
382
- ####################################
383
-
384
- OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
 
 
1
+ import importlib.metadata
2
+ import json
3
+ import logging
4
+ import os
5
+ import pkgutil
6
+ import sys
7
+ import shutil
8
+ from pathlib import Path
9
+
10
+ import markdown
11
+ from bs4 import BeautifulSoup
12
+ from open_webui.constants import ERROR_MESSAGES
13
+
14
+ ####################################
15
+ # Load .env file
16
+ ####################################
17
+
18
+ OPEN_WEBUI_DIR = Path(__file__).parent # the path containing this file
19
+ print(OPEN_WEBUI_DIR)
20
+
21
+ BACKEND_DIR = OPEN_WEBUI_DIR.parent # the path containing this file
22
+ BASE_DIR = BACKEND_DIR.parent # the path containing the backend/
23
+
24
+ print(BACKEND_DIR)
25
+ print(BASE_DIR)
26
+
27
+ try:
28
+ from dotenv import find_dotenv, load_dotenv
29
+
30
+ load_dotenv(find_dotenv(str(BASE_DIR / ".env")))
31
+ except ImportError:
32
+ print("dotenv not installed, skipping...")
33
+
34
+ DOCKER = os.environ.get("DOCKER", "False").lower() == "true"
35
+
36
+ # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
37
+ USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
38
+
39
+ if USE_CUDA.lower() == "true":
40
+ try:
41
+ import torch
42
+
43
+ assert torch.cuda.is_available(), "CUDA not available"
44
+ DEVICE_TYPE = "cuda"
45
+ except Exception as e:
46
+ cuda_error = (
47
+ "Error when testing CUDA but USE_CUDA_DOCKER is true. "
48
+ f"Resetting USE_CUDA_DOCKER to false: {e}"
49
+ )
50
+ os.environ["USE_CUDA_DOCKER"] = "false"
51
+ USE_CUDA = "false"
52
+ DEVICE_TYPE = "cpu"
53
+ else:
54
+ DEVICE_TYPE = "cpu"
55
+
56
+
57
+ ####################################
58
+ # LOGGING
59
+ ####################################
60
+
61
+ log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
62
+
63
+ GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
64
+ if GLOBAL_LOG_LEVEL in log_levels:
65
+ logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
66
+ else:
67
+ GLOBAL_LOG_LEVEL = "INFO"
68
+
69
+ log = logging.getLogger(__name__)
70
+ log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
71
+
72
+ if "cuda_error" in locals():
73
+ log.exception(cuda_error)
74
+
75
+ log_sources = [
76
+ "AUDIO",
77
+ "COMFYUI",
78
+ "CONFIG",
79
+ "DB",
80
+ "IMAGES",
81
+ "MAIN",
82
+ "MODELS",
83
+ "OLLAMA",
84
+ "OPENAI",
85
+ "RAG",
86
+ "WEBHOOK",
87
+ "SOCKET",
88
+ ]
89
+
90
+ SRC_LOG_LEVELS = {}
91
+
92
+ for source in log_sources:
93
+ log_env_var = source + "_LOG_LEVEL"
94
+ SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
95
+ if SRC_LOG_LEVELS[source] not in log_levels:
96
+ SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
97
+ log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
98
+
99
+ log.setLevel(SRC_LOG_LEVELS["CONFIG"])
100
+
101
+
102
+ WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
103
+ if WEBUI_NAME != "Open WebUI":
104
+ WEBUI_NAME += " (Open WebUI)"
105
+
106
+ WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")
107
+
108
+ WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
109
+
110
+
111
+ ####################################
112
+ # ENV (dev,test,prod)
113
+ ####################################
114
+
115
+ ENV = os.environ.get("ENV", "dev")
116
+
117
+ FROM_INIT_PY = os.environ.get("FROM_INIT_PY", "False").lower() == "true"
118
+
119
+ if FROM_INIT_PY:
120
+ PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")}
121
+ else:
122
+ try:
123
+ PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text())
124
+ except Exception:
125
+ PACKAGE_DATA = {"version": "0.0.0"}
126
+
127
+
128
+ VERSION = PACKAGE_DATA["version"]
129
+
130
+
131
+ # Function to parse each section
132
+ def parse_section(section):
133
+ items = []
134
+ for li in section.find_all("li"):
135
+ # Extract raw HTML string
136
+ raw_html = str(li)
137
+
138
+ # Extract text without HTML tags
139
+ text = li.get_text(separator=" ", strip=True)
140
+
141
+ # Split into title and content
142
+ parts = text.split(": ", 1)
143
+ title = parts[0].strip() if len(parts) > 1 else ""
144
+ content = parts[1].strip() if len(parts) > 1 else text
145
+
146
+ items.append({"title": title, "content": content, "raw": raw_html})
147
+ return items
148
+
149
+
150
+ try:
151
+ changelog_path = BASE_DIR / "CHANGELOG.md"
152
+ with open(str(changelog_path.absolute()), "r", encoding="utf8") as file:
153
+ changelog_content = file.read()
154
+
155
+ except Exception:
156
+ changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
157
+
158
+
159
+ # Convert markdown content to HTML
160
+ html_content = markdown.markdown(changelog_content)
161
+
162
+ # Parse the HTML content
163
+ soup = BeautifulSoup(html_content, "html.parser")
164
+
165
+ # Initialize JSON structure
166
+ changelog_json = {}
167
+
168
+ # Iterate over each version
169
+ for version in soup.find_all("h2"):
170
+ version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets
171
+ date = version.get_text().strip().split(" - ")[1]
172
+
173
+ version_data = {"date": date}
174
+
175
+ # Find the next sibling that is a h3 tag (section title)
176
+ current = version.find_next_sibling()
177
+
178
+ while current and current.name != "h2":
179
+ if current.name == "h3":
180
+ section_title = current.get_text().lower() # e.g., "added", "fixed"
181
+ section_items = parse_section(current.find_next_sibling("ul"))
182
+ version_data[section_title] = section_items
183
+
184
+ # Move to the next element
185
+ current = current.find_next_sibling()
186
+
187
+ changelog_json[version_number] = version_data
188
+
189
+
190
+ CHANGELOG = changelog_json
191
+
192
+ ####################################
193
+ # SAFE_MODE
194
+ ####################################
195
+
196
+ SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
197
+
198
+ ####################################
199
+ # ENABLE_FORWARD_USER_INFO_HEADERS
200
+ ####################################
201
+
202
+ ENABLE_FORWARD_USER_INFO_HEADERS = (
203
+ os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true"
204
+ )
205
+
206
+
207
+ ####################################
208
+ # WEBUI_BUILD_HASH
209
+ ####################################
210
+
211
+ WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
212
+
213
+ ####################################
214
+ # DATA/FRONTEND BUILD DIR
215
+ ####################################
216
+
217
+ DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
218
+
219
+ if FROM_INIT_PY:
220
+ NEW_DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")).resolve()
221
+ NEW_DATA_DIR.mkdir(parents=True, exist_ok=True)
222
+
223
+ # Check if the data directory exists in the package directory
224
+ if DATA_DIR.exists() and DATA_DIR != NEW_DATA_DIR:
225
+ log.info(f"Moving {DATA_DIR} to {NEW_DATA_DIR}")
226
+ for item in DATA_DIR.iterdir():
227
+ dest = NEW_DATA_DIR / item.name
228
+ if item.is_dir():
229
+ shutil.copytree(item, dest, dirs_exist_ok=True)
230
+ else:
231
+ shutil.copy2(item, dest)
232
+
233
+ # Zip the data directory
234
+ shutil.make_archive(DATA_DIR.parent / "open_webui_data", "zip", DATA_DIR)
235
+
236
+ # Remove the old data directory
237
+ shutil.rmtree(DATA_DIR)
238
+
239
+ DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data"))
240
+
241
+
242
+ STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static"))
243
+
244
+ FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts"))
245
+
246
+ FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
247
+
248
+ if FROM_INIT_PY:
249
+ FRONTEND_BUILD_DIR = Path(
250
+ os.getenv("FRONTEND_BUILD_DIR", OPEN_WEBUI_DIR / "frontend")
251
+ ).resolve()
252
+
253
+
254
+ ####################################
255
+ # Database
256
+ ####################################
257
+
258
+ # Check if the file exists
259
+ if os.path.exists(f"{DATA_DIR}/ollama.db"):
260
+ # Rename the file
261
+ os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
262
+ log.info("Database migrated from Ollama-WebUI successfully.")
263
+ else:
264
+ pass
265
+
266
+ DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
267
+
268
+ # Replace the postgres:// with postgresql://
269
+ if "postgres://" in DATABASE_URL:
270
+ DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
271
+
272
+ DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", 0)
273
+
274
+ if DATABASE_POOL_SIZE == "":
275
+ DATABASE_POOL_SIZE = 0
276
+ else:
277
+ try:
278
+ DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE)
279
+ except Exception:
280
+ DATABASE_POOL_SIZE = 0
281
+
282
+ DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0)
283
+
284
+ if DATABASE_POOL_MAX_OVERFLOW == "":
285
+ DATABASE_POOL_MAX_OVERFLOW = 0
286
+ else:
287
+ try:
288
+ DATABASE_POOL_MAX_OVERFLOW = int(DATABASE_POOL_MAX_OVERFLOW)
289
+ except Exception:
290
+ DATABASE_POOL_MAX_OVERFLOW = 0
291
+
292
+ DATABASE_POOL_TIMEOUT = os.environ.get("DATABASE_POOL_TIMEOUT", 30)
293
+
294
+ if DATABASE_POOL_TIMEOUT == "":
295
+ DATABASE_POOL_TIMEOUT = 30
296
+ else:
297
+ try:
298
+ DATABASE_POOL_TIMEOUT = int(DATABASE_POOL_TIMEOUT)
299
+ except Exception:
300
+ DATABASE_POOL_TIMEOUT = 30
301
+
302
+ DATABASE_POOL_RECYCLE = os.environ.get("DATABASE_POOL_RECYCLE", 3600)
303
+
304
+ if DATABASE_POOL_RECYCLE == "":
305
+ DATABASE_POOL_RECYCLE = 3600
306
+ else:
307
+ try:
308
+ DATABASE_POOL_RECYCLE = int(DATABASE_POOL_RECYCLE)
309
+ except Exception:
310
+ DATABASE_POOL_RECYCLE = 3600
311
+
312
+ RESET_CONFIG_ON_START = (
313
+ os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
314
+ )
315
+
316
+ ####################################
317
+ # REDIS
318
+ ####################################
319
+
320
+ REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
321
+
322
+ ####################################
323
+ # WEBUI_AUTH (Required for security)
324
+ ####################################
325
+
326
+ WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
327
+ WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
328
+ "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
329
+ )
330
+ WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
331
+
332
+
333
+ ####################################
334
+ # WEBUI_SECRET_KEY
335
+ ####################################
336
+
337
+ WEBUI_SECRET_KEY = os.environ.get(
338
+ "WEBUI_SECRET_KEY",
339
+ os.environ.get(
340
+ "WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t"
341
+ ), # DEPRECATED: remove at next major version
342
+ )
343
+
344
+ WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get(
345
+ "WEBUI_SESSION_COOKIE_SAME_SITE",
346
+ os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"),
347
+ )
348
+
349
+ WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
350
+ "WEBUI_SESSION_COOKIE_SECURE",
351
+ os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true",
352
+ )
353
+
354
+ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
355
+ raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
356
+
357
+ ENABLE_WEBSOCKET_SUPPORT = (
358
+ os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
359
+ )
360
+
361
+ WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
362
+
363
+ WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
364
+
365
+ AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
366
+
367
+ if AIOHTTP_CLIENT_TIMEOUT == "":
368
+ AIOHTTP_CLIENT_TIMEOUT = None
369
+ else:
370
+ try:
371
+ AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
372
+ except Exception:
373
+ AIOHTTP_CLIENT_TIMEOUT = 300
374
+
375
+ AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
376
+ "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "3"
377
+ )
378
+
379
+ if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
380
+ AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = None
381
+ else:
382
+ try:
383
+ AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = int(
384
+ AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
385
+ )
386
+ except Exception:
387
+ AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 3
388
+
389
+ ####################################
390
+ # OFFLINE_MODE
391
+ ####################################
392
+
393
+ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
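The logging section above resolves one level per source: a valid <SOURCE>_LOG_LEVEL wins, otherwise GLOBAL_LOG_LEVEL applies. A small sketch of that resolution with example values (not defaults):

import os

os.environ["GLOBAL_LOG_LEVEL"] = "INFO"
os.environ["RAG_LOG_LEVEL"] = "DEBUG"

log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
resolved = {}
for source in ["CONFIG", "RAG", "OLLAMA"]:
    level = os.environ.get(f"{source}_LOG_LEVEL", "").upper()
    resolved[source] = level if level in log_levels else os.environ["GLOBAL_LOG_LEVEL"]

print(resolved)  # {'CONFIG': 'INFO', 'RAG': 'DEBUG', 'OLLAMA': 'INFO'}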
backend/open_webui/main.py CHANGED
The diff for this file is too large to render. See raw diff
 
backend/open_webui/migrations/versions/922e7a387820_add_group_table.py ADDED
@@ -0,0 +1,85 @@
 
1
+ """Add group table
2
+
3
+ Revision ID: 922e7a387820
4
+ Revises: 4ace53fd72c8
5
+ Create Date: 2024-11-14 03:00:00.000000
6
+
7
+ """
8
+
9
+ from alembic import op
10
+ import sqlalchemy as sa
11
+
12
+ revision = "922e7a387820"
13
+ down_revision = "4ace53fd72c8"
14
+ branch_labels = None
15
+ depends_on = None
16
+
17
+
18
+ def upgrade():
19
+ op.create_table(
20
+ "group",
21
+ sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
22
+ sa.Column("user_id", sa.Text(), nullable=True),
23
+ sa.Column("name", sa.Text(), nullable=True),
24
+ sa.Column("description", sa.Text(), nullable=True),
25
+ sa.Column("data", sa.JSON(), nullable=True),
26
+ sa.Column("meta", sa.JSON(), nullable=True),
27
+ sa.Column("permissions", sa.JSON(), nullable=True),
28
+ sa.Column("user_ids", sa.JSON(), nullable=True),
29
+ sa.Column("created_at", sa.BigInteger(), nullable=True),
30
+ sa.Column("updated_at", sa.BigInteger(), nullable=True),
31
+ )
32
+
33
+ # Add 'access_control' column to 'model' table
34
+ op.add_column(
35
+ "model",
36
+ sa.Column("access_control", sa.JSON(), nullable=True),
37
+ )
38
+
39
+ # Add 'is_active' column to 'model' table
40
+ op.add_column(
41
+ "model",
42
+ sa.Column(
43
+ "is_active",
44
+ sa.Boolean(),
45
+ nullable=False,
46
+ server_default=sa.sql.expression.true(),
47
+ ),
48
+ )
49
+
50
+ # Add 'access_control' column to 'knowledge' table
51
+ op.add_column(
52
+ "knowledge",
53
+ sa.Column("access_control", sa.JSON(), nullable=True),
54
+ )
55
+
56
+ # Add 'access_control' column to 'prompt' table
57
+ op.add_column(
58
+ "prompt",
59
+ sa.Column("access_control", sa.JSON(), nullable=True),
60
+ )
61
+
62
+ # Add 'access_control' column to 'tools' table
63
+ op.add_column(
64
+ "tool",
65
+ sa.Column("access_control", sa.JSON(), nullable=True),
66
+ )
67
+
68
+
69
+ def downgrade():
70
+ op.drop_table("group")
71
+
72
+ # Drop 'access_control' column from 'model' table
73
+ op.drop_column("model", "access_control")
74
+
75
+ # Drop 'is_active' column from 'model' table
76
+ op.drop_column("model", "is_active")
77
+
78
+ # Drop 'access_control' column from 'knowledge' table
79
+ op.drop_column("knowledge", "access_control")
80
+
81
+ # Drop 'access_control' column from 'prompt' table
82
+ op.drop_column("prompt", "access_control")
83
+
84
+ # Drop 'access_control' column from 'tool' table
85
+ op.drop_column("tool", "access_control")
backend/open_webui/storage/provider.py CHANGED
@@ -51,7 +51,10 @@ class StorageProvider:
51
 
52
  try:
53
  self.s3_client.upload_file(file_path, self.bucket_name, filename)
54
- return open(file_path, "rb").read(), file_path
 
 
 
55
  except ClientError as e:
56
  raise RuntimeError(f"Error uploading file to S3: {e}")
57
 
 
51
 
52
  try:
53
  self.s3_client.upload_file(file_path, self.bucket_name, filename)
54
+ return (
55
+ open(file_path, "rb").read(),
56
+ "s3://" + self.bucket_name + "/" + filename,
57
+ )
58
  except ClientError as e:
59
  raise RuntimeError(f"Error uploading file to S3: {e}")
60
 
backend/open_webui/utils/access_control.py ADDED
@@ -0,0 +1,95 @@
 
1
+ from typing import Optional, Union, List, Dict, Any
2
+ from open_webui.apps.webui.models.groups import Groups
3
+ import json
4
+
5
+
6
+ def get_permissions(
7
+ user_id: str,
8
+ default_permissions: Dict[str, Any],
9
+ ) -> Dict[str, Any]:
10
+ """
11
+ Get all permissions for a user by combining the permissions of all groups the user is a member of.
12
+ If a permission is defined in multiple groups, the most permissive value is used (True > False).
13
+ Permissions are nested in a dict with the permission key as the key and a boolean as the value.
14
+ """
15
+
16
+ def combine_permissions(
17
+ permissions: Dict[str, Any], group_permissions: Dict[str, Any]
18
+ ) -> Dict[str, Any]:
19
+ """Combine permissions from multiple groups by taking the most permissive value."""
20
+ for key, value in group_permissions.items():
21
+ if isinstance(value, dict):
22
+ if key not in permissions:
23
+ permissions[key] = {}
24
+ permissions[key] = combine_permissions(permissions[key], value)
25
+ else:
26
+ if key not in permissions:
27
+ permissions[key] = value
28
+ else:
29
+ permissions[key] = permissions[key] or value
30
+ return permissions
31
+
32
+ user_groups = Groups.get_groups_by_member_id(user_id)
33
+
34
+ # deep copy default permissions to avoid modifying the original dict
35
+ permissions = json.loads(json.dumps(default_permissions))
36
+
37
+ for group in user_groups:
38
+ group_permissions = group.permissions
39
+ permissions = combine_permissions(permissions, group_permissions)
40
+
41
+ return permissions
42
+
43
+
44
+ def has_permission(
45
+ user_id: str,
46
+ permission_key: str,
47
+ default_permissions: Dict[str, bool] = {},
48
+ ) -> bool:
49
+ """
50
+ Check if a user has a specific permission by checking the group permissions
51
+ falling back to default permissions if it is not found in any group.
52
+
53
+ Permission keys can be hierarchical and separated by dots ('.').
54
+ """
55
+
56
+ def get_permission(permissions: Dict[str, bool], keys: List[str]) -> bool:
57
+ """Traverse permissions dict using a list of keys (from dot-split permission_key)."""
58
+ for key in keys:
59
+ if key not in permissions:
60
+ return False # If any part of the hierarchy is missing, deny access
61
+ permissions = permissions[key] # Go one level deeper
62
+
63
+ return bool(permissions) # Return the boolean at the final level
64
+
65
+ permission_hierarchy = permission_key.split(".")
66
+
67
+ # Retrieve user group permissions
68
+ user_groups = Groups.get_groups_by_member_id(user_id)
69
+
70
+ for group in user_groups:
71
+ group_permissions = group.permissions
72
+ if get_permission(group_permissions, permission_hierarchy):
73
+ return True
74
+
75
+ # Check default permissions afterwards if the group permissions don't allow it
76
+ return get_permission(default_permissions, permission_hierarchy)
77
+
78
+
79
+ def has_access(
80
+ user_id: str,
81
+ type: str = "write",
82
+ access_control: Optional[dict] = None,
83
+ ) -> bool:
84
+ if access_control is None:
85
+ return type == "read"
86
+
87
+ user_groups = Groups.get_groups_by_member_id(user_id)
88
+ user_group_ids = [group.id for group in user_groups]
89
+ permission_access = access_control.get(type, {})
90
+ permitted_group_ids = permission_access.get("group_ids", [])
91
+ permitted_user_ids = permission_access.get("user_ids", [])
92
+
93
+ return user_id in permitted_user_ids or any(
94
+ group_id in permitted_group_ids for group_id in user_group_ids
95
+ )
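A brief usage sketch for the helpers above. The user id, group data and permission dict are hypothetical, and the calls assume the webui database is available so the Groups lookups succeed.

default_permissions = {
    "workspace": {"models": False, "knowledge": False},
    "chat": {"delete": True},
}

# Dot-separated hierarchical key; group permissions win when more permissive.
can_share_models = has_permission("user-123", "workspace.models", default_permissions)

# access_control=None means "public read": anyone may read, nobody may write.
has_access("user-123", type="read", access_control=None)   # True
has_access("user-123", type="write", access_control=None)  # False

# Explicit grants list user ids and group ids per access type.
acl = {
    "read": {"group_ids": ["group-abc"], "user_ids": []},
    "write": {"group_ids": [], "user_ids": ["user-123"]},
}
has_access("user-123", type="write", access_control=acl)   # True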
backend/open_webui/utils/pdf_generator.py CHANGED
@@ -54,18 +54,18 @@ class PDFGenerator:
54
  html_content = markdown(content, extensions=["pymdownx.extra"])
55
 
56
  html_message = f"""
57
- <div class="message">
58
- <small> {date_str} </small>
59
- <div>
60
- <h2>
61
- <strong>{role.title()}</strong>
62
- <small class="text-muted">{model}</small>
63
- </h2>
64
- </div>
65
- <div class="markdown-section">
66
- {html_content}
67
- </div>
68
- </div>
69
  """
70
  return html_message
71
 
 
54
  html_content = markdown(content, extensions=["pymdownx.extra"])
55
 
56
  html_message = f"""
57
+ <div> {date_str} </div>
58
+ <div class="message">
59
+ <div>
60
+ <h2>
61
+ <strong>{role.title()}</strong>
62
+ <span style="font-size: 12px; color: #888;">{model}</span>
63
+ </h2>
64
+ </div>
65
+ <pre class="markdown-section">
66
+ {content}
67
+ </pre>
68
+ </div>
69
  """
70
  return html_message
71
 
backend/open_webui/utils/security_headers.py CHANGED
@@ -20,6 +20,7 @@ def set_security_headers() -> Dict[str, str]:
20
  This function reads specific environment variables and uses their values
21
  to set corresponding security headers. The headers that can be set are:
22
  - cache-control
 
23
  - strict-transport-security
24
  - referrer-policy
25
  - x-content-type-options
@@ -38,6 +39,7 @@ def set_security_headers() -> Dict[str, str]:
38
  header_setters = {
39
  "CACHE_CONTROL": set_cache_control,
40
  "HSTS": set_hsts,
 
41
  "REFERRER_POLICY": set_referrer,
42
  "XCONTENT_TYPE": set_xcontent_type,
43
  "XDOWNLOAD_OPTIONS": set_xdownload_options,
@@ -73,6 +75,15 @@ def set_xframe(value: str):
73
  return {"X-Frame-Options": value}
74
 
75
 
76
  # Set Referrer-Policy response header
77
  def set_referrer(value: str):
78
  pattern = r"^(no-referrer|no-referrer-when-downgrade|origin|origin-when-cross-origin|same-origin|strict-origin|strict-origin-when-cross-origin|unsafe-url)$"
 
20
  This function reads specific environment variables and uses their values
21
  to set corresponding security headers. The headers that can be set are:
22
  - cache-control
23
+ - permissions-policy
24
  - strict-transport-security
25
  - referrer-policy
26
  - x-content-type-options
 
39
  header_setters = {
40
  "CACHE_CONTROL": set_cache_control,
41
  "HSTS": set_hsts,
42
+ "PERMISSIONS_POLICY": set_permissions_policy,
43
  "REFERRER_POLICY": set_referrer,
44
  "XCONTENT_TYPE": set_xcontent_type,
45
  "XDOWNLOAD_OPTIONS": set_xdownload_options,
 
75
  return {"X-Frame-Options": value}
76
 
77
 
78
+ # Set Permissions-Policy response header
79
+ def set_permissions_policy(value: str):
80
+ pattern = r"^(?:(accelerometer|autoplay|camera|clipboard-read|clipboard-write|fullscreen|geolocation|gyroscope|magnetometer|microphone|midi|payment|picture-in-picture|sync-xhr|usb|xr-spatial-tracking)=\((self)?\),?)*$"
81
+ match = re.match(pattern, value, re.IGNORECASE)
82
+ if not match:
83
+ value = "none"
84
+ return {"Permissions-Policy": value}
85
+
86
+
87
  # Set Referrer-Policy response header
88
  def set_referrer(value: str):
89
  pattern = r"^(no-referrer|no-referrer-when-downgrade|origin|origin-when-cross-origin|same-origin|strict-origin|strict-origin-when-cross-origin|unsafe-url)$"