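# Spaces: Sleeping
# NOTE(review): the line above looks like page-status text captured together
# with the source rather than program code; kept here as a comment.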
from flask import ( | |
Flask, | |
jsonify, | |
request, | |
render_template_string, | |
abort, | |
) | |
from flask_cors import CORS | |
import unicodedata | |
import markdown | |
import time | |
import os | |
import gc | |
import base64 | |
from io import BytesIO | |
from random import randint | |
import hashlib | |
from colorama import Fore, Style, init as colorama_init | |
import chromadb | |
import posthog | |
from chromadb.config import Settings | |
from sentence_transformers import SentenceTransformer | |
from werkzeug.middleware.proxy_fix import ProxyFix | |
# --- Runtime configuration -------------------------------------------------
# Address and port the server binds to when started at the bottom of the file.
port = 7860
host = "0.0.0.0"

# SentenceTransformer checkpoint used to embed every stored / queried message.
embedding_model = "sentence-transformers/all-mpnet-base-v2"

# Initialise colorama before anything is printed.
colorama_init()

print("Initializing ChromaDB")

# disable chromadb telemetry
# Belt and braces: posthog.capture is replaced with a no-op AND the client is
# created with anonymized_telemetry=False.
posthog.capture = lambda *args, **kwargs: None
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))

# The model is loaded once at import time; its bound ``encode`` method is what
# the request handlers hand to ChromaDB as the collection embedding function.
chromadb_embedder = SentenceTransformer(embedding_model)
chromadb_embed_fn = chromadb_embedder.encode
# Flask init
app = Flask(__name__)

# allow cross-domain requests
CORS(app)

# Reject request bodies larger than 100 MiB.
app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024

# Wrap the WSGI app so forwarded headers set by upstream proxies are honoured.
# The counts say how many proxy hops to trust per header (two for
# X-Forwarded-For, one each for proto / host / prefix).
app.wsgi_app = ProxyFix(
    app.wsgi_app,
    x_for=2,
    x_proto=1,
    x_host=1,
    x_prefix=1,
)
def get_real_ip():
    """Return the client address of the current request.

    The value is whatever Flask exposes as ``request.remote_addr``. The
    handlers in this file combine it with the chat id to derive a
    collection name.
    """
    client_address = request.remote_addr
    return client_address
def index():
    """Serve ``./README.md`` converted from Markdown to HTML.

    The file is read as UTF-8, converted with the ``tables`` Markdown
    extension enabled, and the resulting HTML is rendered through Flask's
    template engine.
    """
    with open("./README.md", "r", encoding="utf8") as readme_file:
        readme_markdown = readme_file.read()

    readme_html = markdown.markdown(readme_markdown, extensions=["tables"])
    # NOTE(review): the HTML is treated as a Jinja template, so any template
    # syntax present in the README would be evaluated, not shown literally.
    return render_template_string(readme_html)
def get_modules():
    """Return the list of modules this server provides as a JSON object.

    The response body has the single key ``"modules"``; ``chromadb`` is the
    only entry.
    """
    payload = {"modules": ["chromadb"]}
    return jsonify(payload)
def chromadb_add_messages():
    """Upsert a batch of chat messages into the caller's ChromaDB collection.

    Expected JSON body (an object):
        chat_id  -- str, identifies the chat.
        messages -- list of objects, each with "id", "content", "role" and
                    "date"; "meta" is optional and stored as "" when absent.

    The collection is named ``chat-<md5 of "<client ip>-<chat_id>">``, so the
    same chat id sent from a different client address maps to a different
    collection.

    Returns:
        JSON ``{"count": <number of messages received>}``.

    Aborts with 400 when the body is not a JSON object, when "chat_id" or
    "messages" is missing or of the wrong type, or when a message entry lacks
    one of the required keys.
    """
    data = request.get_json()
    # A body such as ``null`` or ``42`` parses fine but is not subscriptable by
    # key; reject it here instead of failing with a TypeError further down.
    if not isinstance(data, dict):
        abort(400, "request body must be a JSON object")
    if not isinstance(data.get("chat_id"), str):
        abort(400, '"chat_id" is required')
    if not isinstance(data.get("messages"), list):
        abort(400, '"messages" is required')

    messages = data["messages"]
    required_keys = ("id", "content", "role", "date")
    for message in messages:
        # Validate up front so a malformed entry yields a 400, not a KeyError.
        if not isinstance(message, dict) or any(
            key not in message for key in required_keys
        ):
            abort(
                400,
                'each message requires "id", "content", "role" and "date"',
            )

    ip = get_real_ip()
    # Hashing gives a fixed-length, hex-only collection name regardless of
    # what characters the chat id contains.
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    documents = [message["content"] for message in messages]
    ids = [message["id"] for message in messages]
    metadatas = [
        {
            "role": message["role"],
            "date": message["date"],
            "meta": message.get("meta", ""),
        }
        for message in messages
    ]

    # Skip the upsert call entirely for an empty batch.
    if ids:
        collection.upsert(
            ids=ids,
            documents=documents,
            metadatas=metadatas,
        )

    return jsonify({"count": len(ids)})
def chromadb_query():
    """Return the stored messages closest to a query text for one chat.

    Expected JSON body (an object):
        chat_id   -- str, identifies the chat.
        query     -- str, the text to search for.
        n_results -- optional int, maximum number of hits; defaults to 1 when
                     missing or not an int. It is capped at the number of
                     items currently in the collection.

    The collection is named ``chat-<md5 of "<client ip>-<chat_id>">``.

    Returns:
        JSON list of objects with the keys "id", "date", "role", "meta",
        "content" and "distance". The list is empty when the effective
        ``n_results`` is not positive (for example an empty collection).

    Aborts with 400 when the body is not a JSON object or when "chat_id" or
    "query" is missing or not a string.
    """
    data = request.get_json()
    # A body such as ``null`` or ``42`` parses fine but is not subscriptable by
    # key; reject it here instead of failing with a TypeError further down.
    if not isinstance(data, dict):
        abort(400, "request body must be a JSON object")
    if not isinstance(data.get("chat_id"), str):
        abort(400, '"chat_id" is required')
    if not isinstance(data.get("query"), str):
        abort(400, '"query" is required')

    requested = data.get("n_results")
    n_results = requested if isinstance(requested, int) else 1

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    # Never ask for more hits than the collection holds.
    n_results = min(collection.count(), n_results)

    messages = []
    if n_results > 0:
        query_result = collection.query(
            query_texts=[data["query"]],
            n_results=n_results,
        )
        # One query text was sent, so only the first row of each field is used.
        documents = query_result["documents"][0]
        ids = query_result["ids"][0]
        metadatas = query_result["metadatas"][0]
        distances = query_result["distances"][0]

        messages = [
            {
                "id": message_id,
                "date": metadata["date"],
                "role": metadata["role"],
                "meta": metadata["meta"],
                "content": document,
                "distance": distance,
            }
            for message_id, metadata, document, distance in zip(
                ids, metadatas, documents, distances
            )
        ]

    return jsonify(messages)
def chromadb_purge():
    """Delete the stored embeddings of one chat.

    Expected JSON body (an object):
        chat_id -- str, identifies the chat.

    The collection is named ``chat-<md5 of "<client ip>-<chat_id>">``; it is
    created if it does not exist yet, then ``delete()`` is called on it with
    no filter arguments.

    Returns:
        The plain-text body ``Ok`` with status 200.

    Aborts with 400 when the body is not a JSON object or when "chat_id" is
    missing or not a string.
    """
    data = request.get_json()
    # A body such as ``null`` or ``42`` parses fine but is not subscriptable by
    # key; reject it here instead of failing with a TypeError further down.
    if not isinstance(data, dict):
        abort(400, "request body must be a JSON object")
    if not isinstance(data.get("chat_id"), str):
        abort(400, '"chat_id" is required')

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    # Take the size before deleting instead of measuring delete()'s return
    # value, so the log line does not depend on what delete() returns.
    # NOTE(review): confirm against the pinned chromadb version that an
    # argument-less delete() still clears the whole collection.
    count = collection.count()
    collection.delete()
    print("ChromaDB embeddings deleted", count)

    return 'Ok', 200
# Start Flask's built-in server only when this file is executed as a script,
# so importing the module (e.g. from a WSGI server or a test) does not block
# on a running server.
# NOTE(review): no @app.route registrations are visible in this view of the
# file; confirm the handlers above are registered before the server starts.
if __name__ == "__main__":
    app.run(host=host, port=port)