"""Flask API server that stores and searches chat messages in ChromaDB."""

import base64
import gc
import hashlib
import os
import time
import unicodedata
from io import BytesIO
from random import randint

import chromadb
import markdown
import posthog
from chromadb.config import Settings
from colorama import Fore, Style, init as colorama_init
from flask import (
    Flask,
    abort,
    jsonify,
    render_template_string,
    request,
)
from flask_cors import CORS
from sentence_transformers import SentenceTransformer
from werkzeug.middleware.proxy_fix import ProxyFix

colorama_init()

port = 7860
host = "0.0.0.0"
embedding_model = 'sentence-transformers/all-mpnet-base-v2'

print("Initializing ChromaDB")

# disable chromadb telemetry
posthog.capture = lambda *args, **kwargs: None
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
chromadb_embedder = SentenceTransformer(embedding_model)
chromadb_embed_fn = chromadb_embedder.encode

# Flask init
app = Flask(__name__)
CORS(app)  # allow cross-domain requests
app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024

# Wrap the WSGI app so forwarded headers set by upstream proxies are applied
# to the request before the handlers read it.
app.wsgi_app = ProxyFix(
    app.wsgi_app, x_for=2, x_proto=1, x_host=1, x_prefix=1
)

# Keys every entry of the "messages" list must provide; "meta" is optional.
_REQUIRED_MESSAGE_KEYS = ("id", "content", "role", "date")


def get_real_ip():
    """Return the client address of the current request."""
    return request.remote_addr


def _get_json_object():
    """Return the request's JSON body, aborting with 400 unless it is an object.

    ``request.get_json()`` may yield ``None`` or a non-dict value (for
    example a body of ``null`` or a JSON list); key lookups on such a value
    would raise and surface as a 500 rather than a client error.
    """
    data = request.get_json()
    if not isinstance(data, dict):
        abort(400, "JSON object body is required")
    return data


def _require_str(data, key):
    """Return ``data[key]`` if it is a string, otherwise abort with 400."""
    value = data.get(key)
    if not isinstance(value, str):
        abort(400, f'"{key}" is required')
    return value


def _get_chat_collection(chat_id):
    """Return (creating if needed) the collection for this client and chat.

    The collection name is derived from an MD5 digest of the client address
    and the chat id, so the same chat id sent from different addresses maps
    to different collections.
    """
    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{chat_id}'.encode()).hexdigest()
    return chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )


@app.route("/", methods=["GET"])
def index():
    """Serve ./README.md converted from Markdown to HTML."""
    with open("./README.md", "r", encoding="utf8") as f:
        content = f.read()
    # NOTE(review): the generated HTML is evaluated as a Jinja template, so
    # any "{{ ... }}" / "{% ... %}" sequences in the README are interpreted.
    return render_template_string(markdown.markdown(content, extensions=["tables"]))


@app.route("/api/modules", methods=["GET"])
def get_modules():
    """Return the list of modules this server provides."""
    return jsonify({"modules": ['chromadb']})


@app.route("/api/chromadb", methods=["POST"])
def chromadb_add_messages():
    """Upsert the posted messages into the caller's chat collection.

    Expects a JSON object with a string ``chat_id`` and a list ``messages``
    whose entries are objects carrying ``id``, ``content``, ``role`` and
    ``date`` (``meta`` is optional and defaults to an empty string).

    Returns a JSON object ``{"count": <number of messages upserted>}``.
    Aborts with 400 when the body or any message entry is malformed.
    """
    data = _get_json_object()
    chat_id = _require_str(data, "chat_id")

    messages = data.get("messages")
    if not isinstance(messages, list):
        abort(400, '"messages" is required')

    # Reject malformed entries up front instead of failing with a KeyError
    # (HTTP 500) halfway through building the upsert payload.
    for message in messages:
        if not isinstance(message, dict) or any(
            key not in message for key in _REQUIRED_MESSAGE_KEYS
        ):
            abort(
                400,
                '"messages" entries require "id", "content", "role" and "date"',
            )

    collection = _get_chat_collection(chat_id)

    documents = [m["content"] for m in messages]
    ids = [m["id"] for m in messages]
    metadatas = [
        {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
        for m in messages
    ]

    if ids:
        collection.upsert(
            ids=ids,
            documents=documents,
            metadatas=metadatas,
        )

    return jsonify({"count": len(ids)})


@app.route("/api/chromadb/query", methods=["POST"])
def chromadb_query():
    """Return the stored messages closest to the posted query text.

    Expects a JSON object with string ``chat_id`` and ``query``; an optional
    integer ``n_results`` (default 1) caps the number of matches and is
    further limited to the number of items in the collection.

    Returns a JSON list of objects with ``id``, ``date``, ``role``, ``meta``,
    ``content`` and ``distance``; the list is empty when nothing can be
    queried. Aborts with 400 when the body is malformed.
    """
    data = _get_json_object()
    chat_id = _require_str(data, "chat_id")
    query = _require_str(data, "query")

    n_results = data.get("n_results")
    if not isinstance(n_results, int):
        n_results = 1

    collection = _get_chat_collection(chat_id)

    # Never ask for more results than the collection holds.
    n_results = min(collection.count(), n_results)

    messages = []
    if n_results > 0:
        query_result = collection.query(
            query_texts=[query],
            n_results=n_results,
        )

        # Only one query text was sent, so take the first result set.
        documents = query_result["documents"][0]
        ids = query_result["ids"][0]
        metadatas = query_result["metadatas"][0]
        distances = query_result["distances"][0]

        messages = [
            {
                "id": message_id,
                "date": metadata["date"],
                "role": metadata["role"],
                "meta": metadata["meta"],
                "content": document,
                "distance": distance,
            }
            for message_id, metadata, document, distance in zip(
                ids, metadatas, documents, distances
            )
        ]

    return jsonify(messages)


@app.route("/api/chromadb/purge", methods=["POST"])
def chromadb_purge():
    """Delete the stored embeddings of the caller's chat collection.

    Expects a JSON object with a string ``chat_id``. Responds ``'Ok'`` with
    status 200; aborts with 400 when the body is malformed.
    """
    data = _get_json_object()
    chat_id = _require_str(data, "chat_id")

    collection = _get_chat_collection(chat_id)

    deleted = collection.delete()
    print("ChromaDB embeddings deleted", len(deleted))

    return 'Ok', 200


# Start the development server as soon as the module is executed.
app.run(host=host, port=port)