# BW_RAG / app.py
import os
import streamlit as st
from dotenv import load_dotenv
import itertools
from pinecone import Pinecone
from langchain_community.llms import HuggingFaceHub
from langchain.chains import LLMChain
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import time
# Set up environment; Pinecone is the vector database
cache_dir = None  # Directory for the model cache (None uses the default cache location)
Huggingface_token = st.secrets["HUGGINGFACEHUB_API_TOKEN"]  # Hugging Face API token
pc = Pinecone(api_key=st.secrets["PINECONE_API_KEY"])  # Pinecone API key
index = pc.Index(st.secrets["Index_Name"])  # Pinecone index name

# Initialize the embedding model (it is downloaded to cache_dir if one is set)
embedding_model = "all-mpnet-base-v2"  # See https://www.sbert.net/docs/pretrained_models.html
if cache_dir:
    embedding = SentenceTransformer(embedding_model, cache_folder=cache_dir)
else:
    embedding = SentenceTransformer(embedding_model)
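# Note: the vectors upserted below must match the dimensionality of the Pinecone index;
# all-mpnet-base-v2 produces 768-dimensional embeddings. A minimal sketch of how such an
# index could be created (assuming a serverless index; cloud, region, and metric are
# placeholders to adapt to your own setup):
#
#   from pinecone import ServerlessSpec
#   pc.create_index(
#       name=st.secrets["Index_Name"],
#       dimension=768,          # embedding size of all-mpnet-base-v2
#       metric="cosine",        # assumed similarity metric
#       spec=ServerlessSpec(cloud="aws", region="us-east-1"),
#   )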
# Read the PDF files, split them into chunks, and embed them
def read_doc(file_path):
    file_loader = PyPDFDirectoryLoader(file_path)
    documents = file_loader.load()
    return documents

def chunk_data(docs, chunk_size=300, chunk_overlap=50):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    doc = text_splitter.split_documents(docs)
    return doc
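# Note: RecursiveCharacterTextSplitter measures chunk_size and chunk_overlap in characters
# (its default length_function is len), so chunk_size=300 gives chunks of roughly a short
# paragraph with a 50-character overlap between neighbouring chunks.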
# Helper for batching embeddings before saving them to the database
def chunks(iterable, batch_size=100):
    """A helper function to break an iterable into chunks of size batch_size."""
    it = iter(iterable)
    chunk = tuple(itertools.islice(it, batch_size))
    while chunk:
        yield chunk
        chunk = tuple(itertools.islice(it, batch_size))
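# Example: list(chunks(range(5), batch_size=2)) -> [(0, 1), (2, 3), (4,)]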
# Streamlit interface: file upload
st.title("RAG-Anwendung (RAG Application)")
st.caption("Diese Anwendung kann Ihnen helfen, kostenlos Fragen zu PDF-Dateien zu stellen. (This application can help you ask questions about PDF files for free.)")
uploaded_file = st.file_uploader("Wählen Sie eine PDF-Datei, das Laden kann eine Weile dauern. (Choose a PDF file, loading might take a while.)", type="pdf")
if uploaded_file is not None:
    # Ensure the temp directory exists and is empty
    temp_dir = "tempDir"
    if os.path.exists(temp_dir):
        for file in os.listdir(temp_dir):
            file_path = os.path.join(temp_dir, file)
            if os.path.isfile(file_path):
                os.remove(file_path)
            elif os.path.isdir(file_path):
                os.rmdir(file_path)  # Only removes empty directories
    os.makedirs(temp_dir, exist_ok=True)

    # Save the uploaded file temporarily
    temp_file_path = os.path.join(temp_dir, uploaded_file.name)
    with open(temp_file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    doc = read_doc(temp_dir + "/")
    documents = chunk_data(docs=doc)
    texts = [document.page_content for document in documents]
    pdf_vectors = embedding.encode(texts)
    vector_count = len(documents)
    example_data_generator = map(lambda i: (f'id-{i}', pdf_vectors[i], {"text": texts[i]}), range(vector_count))

    # Update the Pinecone index with new vectors
    for ids_vectors_chunk in chunks(example_data_generator, batch_size=100):  # Iterate through chunks of example data
        index.upsert(vectors=ids_vectors_chunk, namespace='ns1')  # Upsert (update or insert) vectors
        time.sleep(0.05)  # Pause to avoid overwhelming the server
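    # Cleanup below: chunk IDs are reused across uploads (id-0, id-1, ...), so if a previous
    # PDF produced more chunks than the current one (e.g. 120 before vs. 100 now), the
    # leftover IDs id-100 ... id-119 must be removed from the namespace.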
    ns_count = index.describe_index_stats()['namespaces']['ns1']['vector_count']  # Current vector count in namespace 'ns1'
    if vector_count < ns_count:  # Check whether vectors from a previous upload are still present
        ids_to_delete = [f'id-{i}' for i in range(vector_count, ns_count)]  # Generate the list of IDs to delete
        index.delete(ids=ids_to_delete, namespace='ns1')  # Delete the stale vectors
        time.sleep(0.05)  # Pause to avoid overwhelming the server
# Input for the search query
with st.form(key='my_form'):
    sample_query = st.text_input("Stellen Sie eine Frage zu dem PDF: (Ask a question related to the PDF:)")  # User query input
    submit_button = st.form_submit_button(label='Abschicken (Submit)')  # Submit button
if submit_button:
    if uploaded_file is not None and sample_query:  # Check if file is uploaded and query provided
        query_vector = embedding.encode(sample_query).tolist()  # Encode query to vector
        query_search = index.query(vector=query_vector, top_k=5, include_metadata=True, namespace='ns1')  # Search index
        time.sleep(0.1)  # Pause to avoid overwhelming the server
        matched_contents = [match["metadata"]["text"] for match in query_search["matches"]]  # Extract text metadata from results

        # Rerank
        rerank_model = "BAAI/bge-reranker-v2-m3"
        if cache_dir:
            tokenizer = AutoTokenizer.from_pretrained(rerank_model, cache_dir=cache_dir)
            model = AutoModelForSequenceClassification.from_pretrained(rerank_model, cache_dir=cache_dir)
        else:
            tokenizer = AutoTokenizer.from_pretrained(rerank_model)
            model = AutoModelForSequenceClassification.from_pretrained(rerank_model)
        model.eval()

        pairs = [[sample_query, content] for content in matched_contents]
        with torch.no_grad():
            inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=300)
            scores = model(**inputs, return_dict=True).logits.view(-1).float()
        matched_contents = [content for _, content in sorted(zip(scores, matched_contents), key=lambda x: x[0], reverse=True)]
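        # The reranker's logits are raw relevance scores (higher = more relevant); if normalized
        # scores in (0, 1) were needed, torch.sigmoid(scores) could be applied. Only the
        # highest-scoring passage is kept below and used as context for the LLM.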
        matched_contents = matched_contents[0]
        del model
        torch.cuda.empty_cache()

        # Display matched contents after reranking
        st.markdown("### Möglicherweise relevante Abschnitte aus dem PDF (Potentially relevant sections from the PDF):")
        st.write(matched_contents)
        # Get answer
        query_model = "meta-llama/Meta-Llama-3-8B-Instruct"
        llm_huggingface = HuggingFaceHub(repo_id=query_model, model_kwargs={"temperature": 0.7, "max_length": 500})
        prompt_template = PromptTemplate(input_variables=['query', 'context'], template="{query}, Beim Beantworten der Frage bitte mit dem Wort 'Antwort:' beginnen, unter Berücksichtigung des folgenden Kontexts: \n\n{context}")
        prompt = prompt_template.format(query=sample_query, context=matched_contents)
        chain = LLMChain(llm=llm_huggingface, prompt=prompt_template)
        result = chain.run(query=sample_query, context=matched_contents)
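        # The Inference API may echo the prompt in front of the generated text
        # (text-generation endpoints often return prompt + completion), so the prompt and
        # everything before the 'Antwort:' marker are stripped below.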
        # Polish answer
        result = result.replace(prompt, "")
        special_start = "Antwort:"
        start_index = result.find(special_start)
        if start_index != -1:
            result = result[start_index + len(special_start):].lstrip()
        else:
            result = result.lstrip()

        # Display the final answer with a note about limitations
        st.markdown("### Antwort (Answer):")
        st.write(result)
        st.markdown("**Hinweis:** Aufgrund begrenzter Rechenleistung kann das große Sprachmodell möglicherweise keine vollständige Antwort liefern. (Note: Due to limited computational power, the large language model might not be able to provide a complete response.)")