import gradio as gr
import os
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
import chromadb
from transformers import AutoTokenizer
import transformers
import torch

# Candidate LLM repos on the Hugging Face Hub, in UI display order.
list_llm = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "HuggingFaceH4/zephyr-7b-beta",
    "meta-llama/Llama-2-7b-chat-hf",
    "microsoft/phi-2",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "mosaicml/mpt-7b-instruct",
    "tiiuae/falcon-7b-instruct",
    "google/flan-t5-xxl",
]
# Short display names (repo basename only) for the model dropdown.
list_llm_simple = [os.path.basename(llm) for llm in list_llm]


def load_doc(list_file_path, chunk_size, chunk_overlap):
    """Load PDF files and split their pages into overlapping text chunks.

    Args:
        list_file_path: Paths of the PDF files to load.
        chunk_size: Target character length of each chunk.
        chunk_overlap: Number of characters shared between adjacent chunks.

    Returns:
        A list of LangChain Document chunks ready for embedding.
    """
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits


def create_db(splits, collection_name):
    """Build an in-memory (ephemeral) Chroma vector store from document splits.

    Uses the default HuggingFaceEmbeddings model; nothing is persisted to disk.
    """
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb


def load_db():
    """Return an empty Chroma store wired to the default HF embedding function."""
    embedding = HuggingFaceEmbeddings()
    vectordb = Chroma(embedding_function=embedding)
    return vectordb


def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Create a ConversationalRetrievalChain over ``vector_db`` using a Hub LLM.

    Args:
        llm_model: Hugging Face Hub repo id of the model to use.
        temperature: Sampling temperature passed to the model.
        max_tokens: Maximum number of new tokens to generate.
        top_k: Top-k sampling parameter.
        vector_db: Vector store used as the retriever backend.

    Returns:
        A ConversationalRetrievalChain with buffer memory and source documents
        enabled.
    """
    # Build the generation kwargs once instead of duplicating the whole
    # HuggingFaceHub(...) call per model; Mixtral additionally requests
    # 8-bit loading to fit in memory.
    model_kwargs = {
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "top_k": top_k,
    }
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        model_kwargs["load_in_8bit"] = True
    llm = HuggingFaceHub(repo_id=llm_model, model_kwargs=model_kwargs)

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True,
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        return_generated_question=False,
    )
    return qa_chain


def initialize_database(list_file_obj, chunk_size, chunk_overlap):
    """Create the vector database from uploaded PDF files.

    Returns:
        (vector_db, collection_name, status_message)
    """
    list_file_path = [x.name for x in list_file_obj if x is not None]
    # Collection name: first file's basename, spaces replaced, capped at 50
    # chars (Chroma collection-name limit).
    collection_name = os.path.basename(list_file_path[0]).replace(" ", "-")[:50]
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    vector_db = create_db(doc_splits, collection_name)
    return vector_db, collection_name, "Complete!"


def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    """Instantiate the QA chain for the model selected in the UI dropdown.

    ``llm_option`` is the index into ``list_llm`` chosen by the user.
    """
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
    return qa_chain, "Complete!"
def format_chat_history(message, chat_history): formatted_chat_history = [] for user_message, bot_message in chat_history: formatted_chat_history.append(f"User: {user_message}") formatted_chat_history.append(f"Assistant: {bot_message}") return formatted_chat_history def conversation(qa_chain, message, history): formatted_chat_history = format_chat_history(message, history) response = qa_chain({"question": message, "chat_history": formatted_chat_history}) response_answer = response["answer"] response_sources = response["source_documents"] response_source1 = response_sources[0].page_content.strip() response_source2 = response_sources[1].page_content.strip() response_source1_page = response_sources[0].metadata["page"] + 1 response_source2_page = response_sources[1].metadata["page"] + 1 new_history = history + [(message, response_answer)] return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page def upload_file(file_obj): list_file_path = [] for idx, file in enumerate(file_obj): file_path = file_obj.name list_file_path.append(file_path) return list_file_path def demo(): with gr.Blocks(theme="base") as demo: vector_db = gr.State() qa_chain = gr.State() collection_name = gr.State() gr.Markdown( """