# Wiki_RAG / app.py
# Author: Ryu-m0m
# Last commit: 9ac9769 "Update app.py" (verified)
# Install necessary packages
#!pip install streamlit
#!pip install wikipedia
#!pip install langchain_community
#!pip install sentence-transformers
#!pip install chromadb
#!pip install huggingface_hub
#!pip install transformers
import streamlit as st
from langchain_community.document_loaders import WikipediaLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from huggingface_hub import login, InferenceClient
from sentence_transformers import CrossEncoder
import numpy as np
import random
import string
# --- Sidebar inputs supplied by the user ---
# Wikipedia search query whose pages form the knowledge base for answers.
topic = st.sidebar.text_input("Enter the Wikipedia topic:", "Japanese History")
# Chat model called through the Hugging Face Inference API.
model_name = 'mistralai/Mistral-7B-Instruct-v0.3'
# Access token; rendered as a password field so the value is masked.
HF_TOKEN = st.sidebar.text_input("Enter your Hugging Face token:", "", type="password")

# Keep the most recent token-validation error in session state so it
# survives script reruns; start with "no error".
if "error_message" not in st.session_state:
    st.session_state["error_message"] = ""
# Function to validate token
def validate_token(token):
    """Return True if the Hugging Face Hub accepts *token*, else False.

    Side effect: calls ``huggingface_hub.login(token=token)`` before checking.

    Any exception raised while logging in or querying the account (rejected
    token, network failure, ...) is reported as "not valid"; callers only
    need a yes/no answer.
    """
    # The module-level import brings in only ``login`` and ``InferenceClient``.
    # Without this import the bare ``HfApi`` name raised NameError, which the
    # except clause swallowed, so every token was reported as invalid.
    from huggingface_hub import HfApi

    try:
        login(token=token)
        # whoami() makes an authenticated request for this exact token and
        # raises if the Hub rejects it.
        HfApi().whoami(token=token)
    except Exception:
        return False
    return True
# Validate the token and display appropriate message.
# Streamlit reruns this script on every widget interaction, and
# validate_token() logs in and makes a network round-trip. Only re-validate
# when the token text actually changes; the outcome is cached in session state.
if HF_TOKEN:
    if st.session_state.get("validated_token") != HF_TOKEN:
        st.session_state.token_is_valid = validate_token(HF_TOKEN)
        st.session_state.validated_token = HF_TOKEN
    if st.session_state.token_is_valid:
        st.session_state.error_message = ""  # Clear error message if the token is valid
        st.sidebar.success("Token is valid!")
    else:
        st.session_state.error_message = "Invalid token. Please try again."
        st.sidebar.error(st.session_state.error_message)
elif st.session_state.error_message:
    # The token box is empty now; keep showing the last validation failure.
    st.sidebar.error(st.session_state.error_message)

# Memory for chat history (kept in session state so it survives reruns)
if "history" not in st.session_state:
    st.session_state.history = []
# Function to generate a random string for collection name
def generate_random_string(max_length=60, min_length=3):
    """Return a random ASCII-alphanumeric string for use as a Chroma collection name.

    The length is drawn uniformly from ``[min_length, max_length]``.

    Args:
        max_length: Upper bound on the length; must not exceed 60.
        min_length: Lower bound on the length. Defaults to 3 because Chroma
            rejects collection names shorter than 3 characters. With a lower
            bound of 1, lengths 1-2 made ``create_collection`` raise.

    Raises:
        ValueError: if ``max_length`` exceeds 60, or if ``min_length`` is
            below 1 or greater than ``max_length``.
    """
    if max_length > 60:
        raise ValueError("The maximum length cannot exceed 60 characters.")
    if not 1 <= min_length <= max_length:
        raise ValueError(
            f"Invalid length range: min_length={min_length}, max_length={max_length}."
        )
    length = random.randint(min_length, max_length)
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choice(alphabet) for _ in range(length))
# Random name for the Chroma collection that rag_advanced() creates; a new
# name is drawn each time this script executes.
collection_name = generate_random_string(max_length=60)
# Function for query expansion
def augment_multiple_query(query):
    """Ask the chat model for related questions to broaden document retrieval.

    Uses the module-level ``model_name``, ``HF_TOKEN`` and ``topic``.

    Args:
        query: The user's original question.

    Returns:
        The suggested questions as a list of strings, one per non-blank line
        of the model's reply, with surrounding whitespace stripped. The list
        may be empty.
    """
    client = InferenceClient(model_name, token=HF_TOKEN)
    completion = client.chat_completion(
        messages=[
            {
                "role": "system",
                "content": f"""You are a helpful expert in {topic}. Your users are asking questions about {topic}.
Suggest up to five additional related questions to help them find the information they need for the provided question.
Suggest only short questions without compound sentences. Suggest a variety of questions that cover different aspects of the topic.
Make sure they are complete questions, and that they are related to the original question."""
            },
            {
                "role": "user",
                "content": query
            }
        ],
        max_tokens=500,
    )
    # Guard against a missing content field so the split below cannot fail.
    text = completion.choices[0].message.content or ""
    # The reply may contain blank lines between questions; drop them so no
    # empty strings are later sent to the vector store as queries.
    return [line.strip() for line in text.splitlines() if line.strip()]
# Helper: load Wikipedia pages for a topic and split them into token-sized chunks.
def _load_topic_chunks(topic_name):
    """Return ~256-token text chunks built from the Wikipedia pages for *topic_name*."""
    docs = WikipediaLoader(query=topic_name).load()
    # Separate pages with a paragraph break so the last word of one page is not
    # glued to the first word of the next, and the splitter can cut between pages.
    concat_texts = "\n\n".join(doc.page_content for doc in docs)
    character_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ". ", " ", ""], chunk_size=1000, chunk_overlap=0
    )
    token_splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0, tokens_per_chunk=256)
    return [
        token_chunk
        for char_chunk in character_splitter.split_text(concat_texts)
        for token_chunk in token_splitter.split_text(char_chunk)
    ]


# Helper: re-rank candidate documents against the query with a cross-encoder.
def _rerank(query, documents, top_k=5):
    """Return up to *top_k* of *documents*, most relevant to *query* first."""
    if not documents:
        return []
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    scores = cross_encoder.predict([[query, doc] for doc in documents])
    top_indices = np.argsort(scores)[::-1][:top_k]
    return [documents[idx] for idx in top_indices]


# Helper: ask the chat model to answer from the supplied documents only.
def _generate_answer(user_query, documents):
    """Stream the model's answer to *user_query* grounded in *documents*; return it as one string."""
    client = InferenceClient(model_name, token=HF_TOKEN)
    # Pass the passages as plain text; a list repr would show the model quotes,
    # brackets and literal "\n" escapes instead of real line breaks.
    information = "\n\n".join(documents)
    parts = []
    for message in client.chat_completion(
        messages=[
            {
                "role": "system",
                "content": f"""You are a helpful expert in {topic}. Your users are asking questions about {topic}.
You will be shown the user's questions, and the relevant information from the documents related to {topic}.
Answer the user's question using only this information."""
            },
            {
                "role": "user",
                "content": f"Questions: {user_query}. \n Information: {information}"
            }
        ],
        max_tokens=500,
        stream=True,
    ):
        # A streamed chunk may carry no text (delta.content is None); appending
        # None to a str would raise TypeError.
        parts.append(message.choices[0].delta.content or "")
    return "".join(parts)


# Function to handle RAG-based question answering
def rag_advanced(user_query):
    """Answer *user_query* using Wikipedia pages about the module-level ``topic``.

    Pipeline:
      1. Load and chunk the pages.
      2. Index the chunks in a Chroma collection named ``collection_name``.
      3. Retrieve with the query plus model-suggested related queries.
      4. Re-rank the candidates with a cross-encoder.
      5. Ask the chat model to answer using only the top chunks.

    Returns:
        The model's answer, or a short notice when no Wikipedia content was
        found for ``topic``.
    """
    token_split_texts = _load_topic_chunks(topic)
    if not token_split_texts:
        # Chroma's add() raises on an empty batch, so stop here with a message.
        return f"No Wikipedia content could be found for the topic '{topic}'."

    chroma_client = chromadb.Client()
    chroma_collection = chroma_client.create_collection(
        collection_name, embedding_function=SentenceTransformerEmbeddingFunction()
    )
    try:
        chroma_collection.add(
            ids=[str(i) for i in range(len(token_split_texts))],
            documents=token_split_texts,
        )
        joint_query = [user_query] + augment_multiple_query(user_query)
        results = chroma_collection.query(
            query_texts=joint_query,
            n_results=min(5, len(token_split_texts)),
            include=['documents'],  # embeddings were fetched before but never used
        )
    finally:
        # The collection serves only this question. Drop it so collections do
        # not accumulate in the Chroma client across questions, and a later
        # create_collection() with the same name cannot collide.
        chroma_client.delete_collection(collection_name)

    # Deduplicate hits from all queries, keeping first-seen order (deterministic, unlike set()).
    unique_documents = list(dict.fromkeys(doc for docs in results['documents'] for doc in docs))
    top_documents = _rerank(user_query, unique_documents)
    return _generate_answer(user_query, top_documents)
# Streamlit UI
st.title("Wikipedia RAG Chatbot")
st.markdown("Choose a topic. Don't forget to put your 🤗 token!")
st.link_button("Get Token Here", "https://huggingface.co/settings/tokens")
# Input box for the user to type their message
user_input = st.text_input("You: ", "", placeholder="Type your question here...")
# Generate response and update conversation history.
# Streamlit reruns this whole script on every widget interaction, and
# st.text_input keeps its value across reruns. Without remembering the last
# answered (topic, question) pair, every unrelated rerun (e.g. editing the
# sidebar) would run the pipeline again and append a duplicate history entry.
# If rag_advanced raises, the pair is not recorded, so a later rerun retries.
current_request = (topic, user_input)
if user_input and st.session_state.get("last_answered") != current_request:
    response = rag_advanced(user_input)
    st.session_state.history.append({"user": user_input, "bot": response})
    st.session_state.last_answered = current_request
# Display the conversation history
for chat in st.session_state.history:
    st.write(f"You: {chat['user']}")
    st.write(f"Bot: {chat['bot']}")
st.markdown("-----------------")
st.markdown("What is this app?")
st.markdown("""This is a simple RAG application using Wikipedia API.
The model for chat is Mistral-7B-Instruct-v0.3.
Main libraries: Langchain (text splitting), Chromadb (vector store)
This RAG uses query expansion and re-ranking to improve the quality.
Feel free to check the files or DM me for any questions. Thank you.""")
st.markdown("[Current agenda] Creating fallback for fuzzy keywords for wikipedia search")