# Install necessary packages
#!pip install streamlit
#!pip install wikipedia
#!pip install langchain_community
#!pip install sentence-transformers
#!pip install chromadb
#!pip install huggingface_hub
#!pip install transformers
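# Wikipedia RAG chatbot: load Wikipedia articles for a chosen topic, split them
# into chunks, embed them into an in-memory Chroma collection, then answer user
# questions with query expansion, cross-encoder re-ranking, and a Mistral chat model.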
import streamlit as st
from langchain_community.document_loaders import WikipediaLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from huggingface_hub import login, InferenceClient, HfApi
from sentence_transformers import CrossEncoder
import numpy as np
import random
import string
# User variables
topic = st.sidebar.text_input("Enter the Wikipedia topic:", "Japanese History")
#query = st.sidebar.text_input("Enter your first query:", "First query")
model_name = 'mistralai/Mistral-7B-Instruct-v0.3'
HF_TOKEN = st.sidebar.text_input("Enter your Hugging Face token:", "", type="password")
# Initialize session state for error message
if 'error_message' not in st.session_state:
    st.session_state.error_message = ""
# Function to validate token
def validate_token(token):
    try:
        # Attempt to log in with the provided token
        login(token=token)
        # Confirm the token is valid by fetching the account info
        HfApi().whoami(token=token)
        return True
    except Exception:
        return False
# Validate the token and display appropriate message
if HF_TOKEN:
    if validate_token(HF_TOKEN):
        st.session_state.error_message = ""  # Clear error message if the token is valid
        st.sidebar.success("Token is valid!")
    else:
        st.session_state.error_message = "Invalid token. Please try again."
        st.sidebar.error(st.session_state.error_message)
elif st.session_state.error_message:
    st.sidebar.error(st.session_state.error_message)
# Memory for chat history
if "history" not in st.session_state:
st.session_state.history = []
# Function to generate a random string for collection name
def generate_random_string(max_length=60):
    if max_length > 60:
        raise ValueError("The maximum length cannot exceed 60 characters.")
    # Chroma collection names must be 3-63 characters, so never go below 3
    length = random.randint(3, max_length)
    characters = string.ascii_letters + string.digits
    return ''.join(random.choice(characters) for _ in range(length))
collection_name = generate_random_string()
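# A fresh random name keeps each Streamlit rerun from colliding with a
# collection created by an earlier run in the same process.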
# Function for query expansion
def augment_multiple_query(query):
    client = InferenceClient(model_name, token=HF_TOKEN)
    content = client.chat_completion(
        messages=[
            {
                "role": "system",
                "content": f"""You are a helpful expert in {topic}. Your users are asking questions about {topic}.
                Suggest up to five additional related questions to help them find the information they need for the provided question.
                Suggest only short questions without compound sentences. Suggest a variety of questions that cover different aspects of the topic.
                Make sure they are complete questions, and that they are related to the original question."""
            },
            {
                "role": "user",
                "content": query
            }
        ],
        max_tokens=500,
    )
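    # Example (illustrative only): "Who unified Japan?" might expand to lines such as
    # "When did Oda Nobunaga rise to power?"; each line becomes one extra retrieval query.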
    # Split the reply into one question per line, dropping any blank lines
    return [q.strip() for q in content.choices[0].message.content.split("\n") if q.strip()]
# Function to handle RAG-based question answering
def rag_advanced(user_query):
    # Document Loading
    docs = WikipediaLoader(query=topic).load()
    # Text Splitting
    character_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n", ". ", " ", ""], chunk_size=1000, chunk_overlap=0)
    concat_texts = "".join([doc.page_content for doc in docs])
    character_split_texts = character_splitter.split_text(concat_texts)
    token_splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0, tokens_per_chunk=256)
    token_split_texts = [chunk for text in character_split_texts for chunk in token_splitter.split_text(text)]
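    # The two-stage split first bounds chunks by characters, then by tokens, so
    # every chunk fits the embedding model's input window (256 tokens here).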
    # Embedding and Document Storage
    embedding_function = SentenceTransformerEmbeddingFunction()  # defaults to all-MiniLM-L6-v2
    chroma_client = chromadb.Client()
    chroma_collection = chroma_client.create_collection(collection_name, embedding_function=embedding_function)
    ids = [str(i) for i in range(len(token_split_texts))]
    chroma_collection.add(ids=ids, documents=token_split_texts)
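    # Note: the Wikipedia download, splitting, and embedding all run again on every
    # question. Wrapping this setup in a st.cache_resource-cached builder keyed on
    # the topic could avoid the repeated work (left here as a possible optimization).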
    # Document Retrieval
    augmented_queries = augment_multiple_query(user_query)
    joint_query = [user_query] + augmented_queries
    results = chroma_collection.query(query_texts=joint_query, n_results=5, include=['documents', 'embeddings'])
    retrieved_documents = results['documents']
    unique_documents = list(set(doc for docs in retrieved_documents for doc in docs))
    # Re-Ranking
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    pairs = [[user_query, doc] for doc in unique_documents]
    scores = cross_encoder.predict(pairs)
    top_indices = np.argsort(scores)[::-1][:5]
    top_documents = [unique_documents[idx] for idx in top_indices]
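    # The cross-encoder scores each (query, document) pair jointly, which is more
    # accurate than the bi-encoder similarity used for retrieval; only the five
    # highest-scoring chunks are passed on to the LLM.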
    # LLM Generation
    client = InferenceClient(model_name, token=HF_TOKEN)
    response = ""
    for message in client.chat_completion(
        messages=[
            {
                "role": "system",
                "content": f"""You are a helpful expert in {topic}. Your users are asking questions about {topic}.
                You will be shown the user's questions, and the relevant information from the documents related to {topic}.
                Answer the user's question using only this information."""
            },
            {
                "role": "user",
                "content": f"Questions: {user_query}. \n Information: {top_documents}"
            }
        ],
        max_tokens=500,
        stream=True,
    ):
        # Streamed chunks can carry a None delta (e.g. the final message); skip those.
        chunk = message.choices[0].delta.content
        if chunk:
            response += chunk
    return response
# Streamlit UI
st.title("Wikipedia RAG Chatbot")
st.markdown("Choose a topic. Don't forget to put your 🤗 token!")
st.link_button("Get Token Here", "https://huggingface.co/settings/tokens")
# Input box for the user to type their message
user_input = st.text_input("You: ", "", placeholder="Type your question here...")
# Generate response and update conversation history
if user_input:
    response = rag_advanced(user_input)
    st.session_state.history.append({"user": user_input, "bot": response})
# Display the conversation history
for chat in st.session_state.history:
    st.write(f"You: {chat['user']}")
    st.write(f"Bot: {chat['bot']}")
    st.markdown("-----------------")
st.markdown("What is this app?")
st.markdown("""This is a simple RAG application using Wikipedia API.
The model for chat is Mistral-7B-Instruct-v0.3.
Main libraries: Langchain (text splitting), Chromadb (vector store)
This RAG uses query expansion and re-ranking to improve the quality.
Feel free to check the files or DM me for any questions. Thank you.""")
st.markdown("[Current agenda] Creating fallback for fuzzy keywords for wikipedia search") |