# Install necessary packages
# !pip install streamlit
# !pip install wikipedia
# !pip install langchain_community
# !pip install sentence-transformers
# !pip install chromadb
# !pip install huggingface_hub
# !pip install transformers

import random
import string

import chromadb
import numpy as np
import streamlit as st
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from huggingface_hub import HfApi, InferenceClient, login  # noqa: F401  (login kept for compatibility)
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    SentenceTransformersTokenTextSplitter,
)
from langchain_community.document_loaders import WikipediaLoader
from sentence_transformers import CrossEncoder

# User variables
topic = st.sidebar.text_input("Enter the Wikipedia topic:", "Japanese History")
# query = st.sidebar.text_input("Enter your first query:", "First query")
model_name = 'mistralai/Mistral-7B-Instruct-v0.3'
HF_TOKEN = st.sidebar.text_input("Enter your Hugging Face token:", "", type="password")

# Initialize session state for error message
if 'error_message' not in st.session_state:
    st.session_state.error_message = ""


def validate_token(token):
    """Return True when the Hugging Face Hub accepts *token*, else False.

    The token is checked with a ``whoami`` request that carries the token
    explicitly, so nothing is written to the host's credential store.
    Any failure (bad token, network error) is reported as ``False``.
    """
    try:
        HfApi().whoami(token=token)
    except Exception:
        # Deliberate best-effort boundary: the sidebar only needs a yes/no.
        return False
    return True


# Validate the token and display appropriate message
if HF_TOKEN:
    if validate_token(HF_TOKEN):
        # Clear error message if the token is valid
        st.session_state.error_message = ""
        st.sidebar.success("Token is valid!")
    else:
        st.session_state.error_message = "Invalid token. Please try again."
        st.sidebar.error(st.session_state.error_message)
elif st.session_state.error_message:
    st.sidebar.error(st.session_state.error_message)

# Memory for chat history
if "history" not in st.session_state:
    st.session_state.history = []


def generate_random_string(max_length=60):
    """Return a random alphanumeric string usable as a collection name.

    The length is drawn at random up to *max_length*. Whenever
    *max_length* allows it, at least three characters are produced,
    because Chroma rejects collection names shorter than that.

    Raises:
        ValueError: if *max_length* is above 60 or below 1.
    """
    if max_length > 60:
        raise ValueError("The maximum length cannot exceed 60 characters.")
    if max_length < 1:
        raise ValueError("The maximum length must be at least 1 character.")
    shortest = min(3, max_length)
    length = random.randint(shortest, max_length)
    characters = string.ascii_letters + string.digits
    return ''.join(random.choice(characters) for _ in range(length))


collection_name = generate_random_string()


def augment_multiple_query(query):
    """Ask the chat model for related questions to widen retrieval.

    Args:
        query: the user's original question.

    Returns:
        A list of non-empty lines from the model's reply, one suggested
        question per line. Blank lines are dropped so that no empty
        query text reaches the vector store.
    """
    client = InferenceClient(model_name, token=HF_TOKEN)
    system_prompt = (
        f"You are a helpful expert in {topic}. "
        f"Your users are asking questions about {topic}. "
        "Suggest up to five additional related questions to help them find "
        "the information they need for the provided question. "
        "Suggest only short questions without compound sentences. "
        "Suggest a variety of questions that cover different aspects of the topic. "
        "Make sure they are complete questions, and that they are related to "
        "the original question."
    )
    content = client.chat_completion(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ],
        max_tokens=500,
    )
    reply = content.choices[0].message.content or ""
    return [line.strip() for line in reply.splitlines() if line.strip()]


@st.cache_resource
def _build_collection(wiki_topic, _name):
    """Load Wikipedia pages for *wiki_topic*, chunk them and index them.

    Cached per topic so the download and embedding work is not repeated
    for every question. ``_name`` starts with an underscore so Streamlit
    leaves it out of the cache key (it is regenerated on each rerun).

    Raises:
        ValueError: if no text could be loaded for the topic.
    """
    # Document Loading
    docs = WikipediaLoader(query=wiki_topic).load()

    # Text Splitting
    character_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ". ", " ", ""],
        chunk_size=1000,
        chunk_overlap=0,
    )
    concat_texts = "".join(doc.page_content for doc in docs)
    character_split_texts = character_splitter.split_text(concat_texts)
    token_splitter = SentenceTransformersTokenTextSplitter(
        chunk_overlap=0, tokens_per_chunk=256
    )
    token_split_texts = [
        piece
        for chunk in character_split_texts
        for piece in token_splitter.split_text(chunk)
    ]
    if not token_split_texts:
        raise ValueError(f"No Wikipedia content found for topic: {wiki_topic}")

    # Embedding and Document Storage
    embedding_function = SentenceTransformerEmbeddingFunction()
    chroma_client = chromadb.Client()
    chroma_collection = chroma_client.create_collection(
        _name, embedding_function=embedding_function
    )
    ids = [str(i) for i in range(len(token_split_texts))]
    chroma_collection.add(ids=ids, documents=token_split_texts)
    return chroma_collection


@st.cache_resource
def _load_cross_encoder():
    """Load the re-ranking cross-encoder once and reuse it across reruns."""
    return CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')


def rag_advanced(user_query):
    """Answer *user_query* from Wikipedia text about the selected topic.

    Steps: build (or reuse) the vector index for the topic, expand the
    query with model-suggested questions, retrieve candidate chunks,
    re-rank them against the original question with a cross-encoder, and
    stream an answer from the chat model using the five best chunks.

    Args:
        user_query: the question typed by the user.

    Returns:
        The model's answer as a string, or a short notice when no
        Wikipedia content could be loaded for the topic.
    """
    try:
        chroma_collection = _build_collection(topic, _name=collection_name)
    except ValueError as err:
        # Raised by _build_collection when the topic yields no text.
        return str(err)

    # Document Retrieval
    augmented_queries = augment_multiple_query(user_query)
    joint_query = [user_query] + augmented_queries
    results = chroma_collection.query(
        query_texts=joint_query, n_results=5, include=['documents']
    )
    retrieved_documents = results['documents']
    # The same chunk is often returned for several of the queries.
    unique_documents = list(
        {doc for doc_group in retrieved_documents for doc in doc_group}
    )

    # Re-Ranking
    cross_encoder = _load_cross_encoder()
    pairs = [[user_query, doc] for doc in unique_documents]
    scores = cross_encoder.predict(pairs)
    top_indices = np.argsort(scores)[::-1][:5]
    top_documents = [unique_documents[idx] for idx in top_indices]

    # LLM Reference
    client = InferenceClient(model_name, token=HF_TOKEN)
    system_prompt = (
        f"You are a helpful expert in {topic}. "
        f"Your users are asking questions about {topic}. "
        "You will be shown the user's questions, and the relevant information "
        f"from the documents related to {topic}. "
        "Answer the user's question using only this information."
    )
    stream = client.chat_completion(
        messages=[
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": f"Questions: {user_query}. \n Information: {top_documents}",
            },
        ],
        max_tokens=500,
        stream=True,
    )
    pieces = []
    for message in stream:
        if not message.choices:
            continue
        # A streamed chunk may carry no text (content is None).
        piece = message.choices[0].delta.content
        if piece:
            pieces.append(piece)
    return "".join(pieces)


# Streamlit UI
st.title("Wikipedia RAG Chatbot")
st.markdown("Choose a topic. Don't forget to put your 🤗 token!")
st.link_button("Get Token Here", "https://huggingface.co/settings/tokens")

# Input box for the user to type their message
user_input = st.text_input("You: ", "", placeholder="Type your question here...")

# Generate response and update conversation history
if user_input:
    response = rag_advanced(user_input)
    st.session_state.history.append({"user": user_input, "bot": response})

# Display the conversation history
for chat in st.session_state.history:
    st.write(f"You: {chat['user']}")
    st.write(f"Bot: {chat['bot']}")

st.markdown("-----------------")
st.markdown("What is this app?")
st.markdown(
    "This is a simple RAG application using Wikipedia API. "
    "The model for chat is Mistral-7B-Instruct-v0.3. "
    "Main libraries: Langchain (text splitting), Chromadb (vector store) "
    "This RAG uses query expansion and re-ranking to improve the quality. "
    "Feel free to check the files or DM me for any questions. Thank you."
)
st.markdown("[Current agenda] Creating fallback for fuzzy keywords for wikipedia search")