import streamlit as st
from langchain_community.document_loaders import WikipediaLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
# HfApi is used by validate_token below; it was missing from the imports.
from huggingface_hub import login, HfApi, InferenceClient
from sentence_transformers import CrossEncoder
import numpy as np
import random
import string
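
# --- Sidebar configuration: Wikipedia topic and Hugging Face access token ---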
topic = st.sidebar.text_input("Enter the Wikipedia topic:", "Japanese History")

model_name = 'mistralai/Mistral-7B-Instruct-v0.3'
HF_TOKEN = st.sidebar.text_input("Enter your Hugging Face token:", "", type="password")

if 'error_message' not in st.session_state:
    st.session_state.error_message = ""
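
# Validate the token by logging in and fetching the account details.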
def validate_token(token):
    try:
        login(token=token)
        HfApi().whoami()
        return True
    except Exception:
        return False

if HF_TOKEN:
    if validate_token(HF_TOKEN):
        st.session_state.error_message = ""
        st.sidebar.success("Token is valid!")
    else:
        st.session_state.error_message = "Invalid token. Please try again."
        st.sidebar.error(st.session_state.error_message)
elif st.session_state.error_message:
    st.sidebar.error(st.session_state.error_message)

if "history" not in st.session_state:
    st.session_state.history = []
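
# A random collection name gives each rerun a fresh ChromaDB collection;
# ChromaDB accepts names of 3-63 characters.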
def generate_random_string(max_length=60):
    if max_length > 60:
        raise ValueError("The maximum length cannot exceed 60 characters.")
    # Lower bound of 3: ChromaDB rejects collection names shorter than 3 characters.
    length = random.randint(3, max_length)
    characters = string.ascii_letters + string.digits
    return ''.join(random.choice(characters) for _ in range(length))

collection_name = generate_random_string()
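
# Query expansion: ask the chat model for up to five related questions,
# which are embedded alongside the original query to widen retrieval.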
def augment_multiple_query(query):
    client = InferenceClient(model_name, token=HF_TOKEN)
    content = client.chat_completion(
        messages=[
            {
                "role": "system",
                "content": f"""You are a helpful expert in {topic}. Your users are asking questions about {topic}.
                Suggest up to five additional related questions to help them find the information they need for the provided question.
                Suggest only short questions without compound sentences. Suggest a variety of questions that cover different aspects of the topic.
                Make sure they are complete questions and that they are related to the original question."""
            },
            {
                "role": "user",
                "content": query
            }
        ],
        max_tokens=500,
    )
    # Drop blank lines so empty strings are never used as queries.
    return [q for q in content.choices[0].message.content.split("\n") if q.strip()]

def rag_advanced(user_query):
    # Load Wikipedia articles for the chosen topic.
    docs = WikipediaLoader(query=topic).load()
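
    # Two-stage splitting: first on document structure, then re-chunk to a
    # fixed token budget that fits the embedding model.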
    character_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n", ". ", " ", ""], chunk_size=1000, chunk_overlap=0)
    concat_texts = "".join([doc.page_content for doc in docs])
    character_split_texts = character_splitter.split_text(concat_texts)

    token_splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0, tokens_per_chunk=256)
    token_split_texts = [chunk for text in character_split_texts for chunk in token_splitter.split_text(text)]
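
    # Index the chunks in an in-memory ChromaDB collection using
    # sentence-transformers embeddings (all-MiniLM-L6-v2 by default).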
    embedding_function = SentenceTransformerEmbeddingFunction()
    chroma_client = chromadb.Client()
    # get_or_create_collection avoids a "collection already exists" error on reruns.
    chroma_collection = chroma_client.get_or_create_collection(collection_name, embedding_function=embedding_function)

    ids = [str(i) for i in range(len(token_split_texts))]
    chroma_collection.add(ids=ids, documents=token_split_texts)
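
    # Retrieve candidates for the original query plus its expansions, then
    # deduplicate the pooled results.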
    augmented_queries = augment_multiple_query(user_query)
    joint_query = [user_query] + augmented_queries

    results = chroma_collection.query(query_texts=joint_query, n_results=5, include=['documents'])
    retrieved_documents = results['documents']

    unique_documents = list(set(doc for docs in retrieved_documents for doc in docs))
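
    # Re-rank the pooled candidates against the original query with a
    # cross-encoder and keep the five highest-scoring chunks.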
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    pairs = [[user_query, doc] for doc in unique_documents]
    scores = cross_encoder.predict(pairs)

    top_indices = np.argsort(scores)[::-1][:5]
    top_documents = [unique_documents[idx] for idx in top_indices]
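
    # Generate the final answer, streaming tokens from the chat model and
    # grounding it in the re-ranked chunks only.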
    client = InferenceClient(model_name, token=HF_TOKEN)
    response = ""
    for message in client.chat_completion(
        messages=[
            {
                "role": "system",
                "content": f"""You are a helpful expert in {topic}. Your users are asking questions about {topic}.
                You will be shown the user's question and the relevant information from the documents related to {topic}.
                Answer the user's question using only this information."""
            },
            {
                "role": "user",
                "content": f"Question: {user_query}. \n Information: {top_documents}"
            }
        ],
        max_tokens=500,
        stream=True,
    ):
        # Streamed chunks can carry an empty delta; guard against None.
        response += message.choices[0].delta.content or ""

    return response
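
# --- Streamlit UI ---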
st.title("Wikipedia RAG Chatbot")
st.markdown("Choose a topic, and don't forget to enter your Hugging Face 🤗 token!")
st.link_button("Get Token Here", "https://huggingface.co/settings/tokens")

user_input = st.text_input("You: ", "", placeholder="Type your question here...")

# Run the pipeline only once both a question and a token are provided.
if user_input and HF_TOKEN:
    response = rag_advanced(user_input)
    st.session_state.history.append({"user": user_input, "bot": response})

for chat in st.session_state.history:
    st.write(f"You: {chat['user']}")
    st.write(f"Bot: {chat['bot']}")

st.markdown("-----------------")
st.markdown("What is this app?")
st.markdown("""This is a simple RAG application built on the Wikipedia API.
The chat model is Mistral-7B-Instruct-v0.3.
Main libraries: LangChain (text splitting) and ChromaDB (vector store).
The pipeline uses query expansion and re-ranking to improve retrieval quality.
Feel free to check the files or DM me with any questions. Thank you.""")
st.markdown("[Current agenda] Creating a fallback for fuzzy keywords in Wikipedia search")