Spaces:
Build error
Build error
# app.py | |
import os | |
import re | |
import uuid | |
import gradio as gr | |
import torch | |
import torch.nn.functional as F | |
from dotenv import load_dotenv | |
from typing import List, Tuple, Dict, Any | |
from transformers import AutoTokenizer, AutoModel | |
from openai import OpenAI | |
from langchain_community.document_loaders import UnstructuredFileLoader | |
from langchain_chroma import Chroma | |
from chromadb import Documents, EmbeddingFunction, Embeddings | |
from chromadb.config import Settings | |
import chromadb | |
from utils import load_env_variables, parse_and_route, escape_special_characters | |
from globalvars import API_BASE, intention_prompt, tasks, system_message, metadata_prompt, model_name | |
import spaces | |
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever | |
from langchain_community.document_compressors.jina_rerank import JinaRerank | |
from langchain import hub | |
from langchain.chains.retrieval import create_retrieval_chain | |
from langchain.chains.combine_documents.stuff import create_stuff_documents_chain | |
load_dotenv()

# CUDA / PyTorch environment settings, assigned in this order.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:180",
        "CUDA_LAUNCH_BLOCKING": "1",
        "CUDA_CACHE_DISABLE": "1",
    }
)

# Use the GPU when PyTorch reports one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

hf_token, yi_token = load_env_variables()
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    token=hf_token,
)

# Filled in lazily by load_model().
model = None
def load_model():
    """Return the module-wide embedding model, loading it on the first call.

    The model is fetched with ``AutoModel.from_pretrained`` (``hf_token``,
    ``trust_remote_code=True``) and moved to ``device``. It is cached in the
    global ``model``, so later calls return the same instance.
    """
    global model
    if model is not None:
        return model
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        token=hf_token,
    ).to(device)
    return model


# Load the model eagerly, at import time.
jina_model = load_model()
def clear_cuda_cache():
    """Ask PyTorch to release its cached, currently unused CUDA memory blocks."""
    return torch.cuda.empty_cache()
# Chat-completion client for the Yi API.
client = OpenAI(api_key=yi_token, base_url=API_BASE)

chroma_client = chromadb.Client(Settings())
# get_or_create (rather than create) so that re-running this setup against a
# client that already holds the collection does not raise "already exists".
chroma_collection = chroma_client.get_or_create_collection("all-my-documents")
class JinaEmbeddingFunction(EmbeddingFunction):
    """Embed text with a Jina model whose task is chosen by an LLM.

    For each input text, two chat-completion requests go to
    ``intention_client`` (model ``"yi-large"``):

    * one classifies the text into a key of ``tasks``, falling back to
      ``"DEFAULT"`` for unknown answers;
    * one extracts ``"key": "value"`` metadata.

    The text is then encoded with ``model``, mean-pooled over non-padding
    tokens, and L2-normalized.
    """

    # `"key": "value"` pairs in the LLM's metadata answer. Any whitespace
    # around the colon is accepted.
    _METADATA_PATTERN = re.compile(r'"(\w+)"\s*:\s*"([^"]+)"')

    def __init__(self, model, tokenizer, intention_client):
        self.model = model
        self.tokenizer = tokenizer
        self.intention_client = intention_client

    def __call__(self, input: Documents) -> List[List[float]]:
        """Return one embedding per document, as the chromadb protocol requires.

        chromadb validates that an embedding function returns only the list of
        vectors. Use ``compute_embeddings`` when the per-document metadata is
        also needed.
        """
        return [self.compute_embeddings(doc)[0] for doc in input]

    def _chat(self, system_prompt: str, user_text: str):
        """Send one system+user exchange to ``yi-large`` and return the reply content."""
        completion = self.intention_client.chat.completions.create(
            model="yi-large",
            messages=[
                {"role": "system", "content": escape_special_characters(system_prompt)},
                {"role": "user", "content": user_text},
            ],
        )
        return completion.choices[0].message.content

    def compute_embeddings(self, input_text: str) -> Tuple[List[float], Dict[str, str]]:
        """Embed one text and extract LLM-derived metadata for it.

        Args:
            input_text: Raw text. It is passed through
                ``escape_special_characters`` before being sent to the LLM and
                the tokenizer.

        Returns:
            ``(embedding, metadata)``. ``embedding`` is the normalized
            mean-pooled vector as a list of floats. ``metadata`` is the dict
            produced by ``extract_metadata`` (possibly empty).
        """
        escaped_input_text = escape_special_characters(input_text)

        # Route to a task adapter. Answers not present in `tasks` use "DEFAULT".
        parsed_task = parse_and_route(self._chat(intention_prompt, escaped_input_text))
        task = tasks[parsed_task if parsed_task in tasks else "DEFAULT"]

        metadata = self.extract_metadata(self._chat(metadata_prompt, escaped_input_text))

        encoded_input = self.tokenizer(
            escaped_input_text, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            model_output = self.model(**encoded_input, task=task)
        embeddings = self.mean_pooling(model_output, encoded_input["attention_mask"])
        embeddings = F.normalize(embeddings, p=2, dim=1)
        # A single text yields a batch of one. `.tolist()` (unlike `.numpy()`)
        # also works for dtypes NumPy lacks, such as bfloat16.
        return embeddings[0].cpu().tolist(), metadata

    def extract_metadata(self, metadata_output: str) -> Dict[str, str]:
        """Parse ``"key": "value"`` pairs from the LLM reply into a dict.

        Later duplicates of a key overwrite earlier ones. Returns ``{}`` when
        the reply is empty/None or contains no matching pairs.
        """
        if not metadata_output:
            return {}
        return dict(self._METADATA_PATTERN.findall(metadata_output))

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        """Average token embeddings over positions where ``attention_mask`` is set.

        ``model_output[0]`` is taken as the per-token embeddings. The mask is
        broadcast to that tensor's size, and padded positions are zeroed
        before summing over dim 1. The clamp avoids division by zero for an
        all-padding row.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )
def load_documents(file_path: str, mode: str = "elements"):
    """Load ``file_path`` with Unstructured and return each element's text.

    Args:
        file_path: Path of the file to parse.
        mode: Unstructured loader mode (``"elements"`` by default).

    Returns:
        A list with the ``page_content`` of every loaded document, in order.
    """
    return [element.page_content for element in UnstructuredFileLoader(file_path, mode=mode).load()]
def initialize_chroma(collection_name: str, embedding_function: JinaEmbeddingFunction):
    """Wrap ``collection_name`` on the shared ``chroma_client`` in a LangChain ``Chroma`` store."""
    return Chroma(
        client=chroma_client,
        collection_name=collection_name,
        embedding_function=embedding_function,
    )
def add_documents_to_chroma(documents: list, embedding_function: JinaEmbeddingFunction):
    """Embed each text and store it in ``chroma_collection`` under a random id.

    Args:
        documents: Texts to index. Empty or whitespace-only entries are skipped,
            because there is nothing to embed and each would cost two LLM calls.
        embedding_function: Produces ``(embedding, metadata)`` via
            ``compute_embeddings``.
    """
    for doc in documents:
        if not doc or not doc.strip():
            continue
        embeddings, metadata = embedding_function.compute_embeddings(doc)
        record = {
            # uuid4 is random; uuid1 would embed the host MAC address and a timestamp.
            "ids": [str(uuid.uuid4())],
            "documents": [doc],
            "embeddings": [embeddings],
        }
        # chromadb rejects empty metadata dicts, and the LLM extraction may find
        # nothing, so only attach metadata when there is some.
        if metadata:
            record["metadatas"] = [metadata]
        chroma_collection.add(**record)
def rerank_documents(query: str, documents: List[str]) -> List[str]:
    """Reorder ``documents`` by relevance to ``query`` using the Jina reranker.

    The given candidates themselves are reranked; nothing is re-retrieved from
    a vector store. The result therefore always reflects what the caller
    passed in.

    Args:
        query: The user query.
        documents: Candidate passages, e.g. the texts returned by ``query_chroma``.

    Returns:
        The passages kept by the reranker, most relevant first. How many are
        kept is governed by JinaRerank's own ``top_n`` setting. An empty
        input returns ``[]`` without calling the reranking service.
    """
    if not documents:
        return []
    candidates = list(documents)
    # `rerank` returns dicts holding the index of each kept candidate,
    # ordered by relevance score.
    ranked = JinaRerank().rerank(candidates, query)
    return [candidates[item["index"]] for item in ranked]
def query_chroma(query_text: str, embedding_function: JinaEmbeddingFunction, n_results: int = 5):
    """Embed ``query_text`` and return its nearest neighbours from ``chroma_collection``.

    Args:
        query_text: The search text.
        embedding_function: Supplies the query vector via ``compute_embeddings``.
            The metadata it also returns is not used for querying.
        n_results: Maximum number of neighbours to request (default 5).

    Returns:
        The raw chromadb query result for a single query embedding.
    """
    query_embedding, _unused_metadata = embedding_function.compute_embeddings(query_text)
    return chroma_collection.query(query_embeddings=[query_embedding], n_results=n_results)
def answer_query(message: str, chat_history: List[Tuple[str, str]], system_message: str, max_new_tokens: int, temperature: float, top_p: float):
    """Answer a chat message using retrieved, reranked documents as context.

    This is the ``fn`` of the ``gr.ChatInterface``. Gradio calls it with the
    new message and the chat history, followed by the values of the
    additional inputs (system message, max new tokens, temperature, top-p).

    Args:
        message: The user's question.
        chat_history: Prior turns. Accepted for the ChatInterface calling
            convention; not read or mutated here.
        system_message: System prompt sent to the chat model.
        max_new_tokens: Upper bound on generated tokens (cast to ``int``).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The assistant's reply text, or ``""`` if the model returns no content.
        ``gr.ChatInterface`` expects the reply itself and adds it to the
        displayed history.
    """
    # chromadb returns one list of document strings per query embedding, and
    # exactly one embedding is sent.
    results = query_chroma(message, embedding_function)
    candidates = (results.get("documents") or [[]])[0]

    # Pass the candidate strings through unchanged. Joining and re-splitting on
    # blank lines would break passages that themselves contain blank lines.
    reranked_context = "\n\n".join(rerank_documents(message, candidates))

    response = client.chat.completions.create(
        model="yi-large",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": f"Context: {reranked_context}\n\nHuman: {message}"},
        ],
        # Gradio sliders may deliver floats; the API expects an integer count.
        max_tokens=int(max_new_tokens),
        temperature=temperature,
        top_p=top_p,
    )
    return response.choices[0].message.content or ""
# Initialize clients
intention_client = OpenAI(base_url=API_BASE, api_key=yi_token)
embedding_function = JinaEmbeddingFunction(
    model=jina_model,
    tokenizer=tokenizer,
    intention_client=intention_client,
)
# NOTE: "Jina-embeddings" is a different collection from "all-my-documents",
# which is where add_documents_to_chroma stores uploaded documents.
chroma_db = initialize_chroma(
    embedding_function=embedding_function,
    collection_name="Jina-embeddings",
)
def upload_documents(files):
    """Parse each uploaded file with Unstructured and index its text in Chroma.

    Args:
        files: What the ``gr.File`` component delivers: ``None`` when nothing
            was chosen, otherwise a list of file paths (Gradio 4) or tempfile
            wrappers exposing the path as ``.name`` (Gradio 3).

    Returns:
        A status message for the UI.
    """
    if not files:
        return "No files uploaded."
    for file in files:
        file_path = os.fspath(file) if isinstance(file, (str, os.PathLike)) else file.name
        documents = UnstructuredFileLoader(file_path).load()
        add_documents_to_chroma([doc.page_content for doc in documents], embedding_function)
    return "Documents uploaded and processed successfully!"
def query_documents(query):
    """Retrieve documents for ``query``, rerank them, and join them with blank lines."""
    hits = query_chroma(query, embedding_function)
    candidate_texts = list(hits['documents'][0])
    return "\n\n".join(rerank_documents(query, candidate_texts))
# Gradio UI: an upload tab and a question-answering tab.
with gr.Blocks() as demo:
    # Tab 1: pick one or more files. The button runs upload_documents, and its
    # status string is shown in the Text component created inline as the output.
    with gr.Tab("Upload Documents"):
        document_upload = gr.File(file_count="multiple", file_types=["document"])
        upload_button = gr.Button("Upload and Process")
        upload_button.click(upload_documents, inputs=document_upload, outputs=gr.Text())
    # Tab 2: chat with retrieval-augmented answers, plus a plain retrieval box.
    with gr.Tab("Ask Questions"):
        with gr.Row():
            # The additional inputs are passed to answer_query after
            # (message, history), in this order: system_message,
            # max_new_tokens, temperature, top_p.
            # NOTE(review): depending on the Gradio version, a ChatInterface
            # built inside another Blocks may need `.render()` to appear — confirm.
            chat_interface = gr.ChatInterface(
                answer_query,
                additional_inputs=[
                    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
                    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
                ],
            )
        # Retrieval only: shows the reranked documents for a query without
        # calling the chat model.
        # NOTE(review): the original indentation was lost; these widgets are
        # placed at tab level, after the Row — confirm the intended layout.
        query_input = gr.Textbox(label="Query")
        query_button = gr.Button("Query")
        query_output = gr.Textbox()
        query_button.click(query_documents, inputs=query_input, outputs=query_output)

# Start the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()