# Install necessary packages
#!pip install streamlit
#!pip install wikipedia
#!pip install langchain
#!pip install langchain_community
#!pip install sentence-transformers
#!pip install chromadb
#!pip install huggingface_hub
#!pip install transformers

import streamlit as st
from langchain_community.document_loaders import WikipediaLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from huggingface_hub import login, InferenceClient, HfApi
from sentence_transformers import CrossEncoder
import numpy as np
import random
import string

# User variables
topic = st.sidebar.text_input("Enter the Wikipedia topic:", "Japanese History")
#query = st.sidebar.text_input("Enter your first query:", "First query")
model_name = 'mistralai/Mistral-7B-Instruct-v0.3'
HF_TOKEN = st.sidebar.text_input("Enter your Hugging Face token:", "", type="password")

# Initialize session state for error message
if 'error_message' not in st.session_state:
    st.session_state.error_message = ""

# Function to validate token
def validate_token(token):
    try:
        # Attempt to log in with the provided token
        login(token=token)
        # Confirm the token actually works by fetching the associated account info
        HfApi().whoami(token=token)
        return True
    except Exception:
        return False

# Validate the token and display appropriate message
if HF_TOKEN:
    if validate_token(HF_TOKEN):
        st.session_state.error_message = ""  # Clear error message if the token is valid
        st.sidebar.success("Token is valid!")
    else:
        st.session_state.error_message = "Invalid token. Please try again."
        st.sidebar.error(st.session_state.error_message)
elif st.session_state.error_message:
    st.sidebar.error(st.session_state.error_message)

# Memory for chat history
if "history" not in st.session_state:
    st.session_state.history = []

# Function to generate a random string for collection name
def generate_random_string(max_length=60):
    if max_length > 60:
        raise ValueError("The maximum length cannot exceed 60 characters.")
    length = random.randint(3, max_length)  # at least 3 chars, since Chroma collection names must be 3-63 characters
    characters = string.ascii_letters + string.digits
    return ''.join(random.choice(characters) for _ in range(length))

collection_name = generate_random_string()

# Function for query expansion
def augment_multiple_query(query):
    client = InferenceClient(model_name, token=HF_TOKEN)
    content = client.chat_completion(
        messages=[
            {
                "role": "system",
                "content": f"""You are a helpful expert in {topic}. Your users are asking questions about {topic}.
                Suggest up to five additional related questions to help them find the information they need for the provided question.
                Suggest only short questions without compound sentences. Suggest a variety of questions that cover different aspects of the topic.
                Make sure they are complete questions, and that they are related to the original question."""
            },
            {
                "role": "user",
                "content": query
            }
        ],
        max_tokens=500,
    )
    # Drop blank lines so only actual questions are passed along as extra queries
    return [line for line in content.choices[0].message.content.split("\n") if line.strip()]

# Function to handle RAG-based question answering
def rag_advanced(user_query):
    # Document Loading
    docs = WikipediaLoader(query=topic).load()

    # Text Splitting
    character_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n", ". ", " ", ""], chunk_size=1000, chunk_overlap=0)
    concat_texts = "".join([doc.page_content for doc in docs])
    character_split_texts = character_splitter.split_text(concat_texts)

    token_splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0, tokens_per_chunk=256)
    token_split_texts = [chunk for text in character_split_texts for chunk in token_splitter.split_text(text)]

    # Embedding and Document Storage
    embedding_function = SentenceTransformerEmbeddingFunction()
    chroma_client = chromadb.Client()
    chroma_collection = chroma_client.create_collection(collection_name, embedding_function=embedding_function)

    ids = [str(i) for i in range(len(token_split_texts))]
    chroma_collection.add(ids=ids, documents=token_split_texts)

    # Document Retrieval
    augmented_queries = augment_multiple_query(user_query)
    joint_query = [user_query] + augmented_queries

    results = chroma_collection.query(query_texts=joint_query, n_results=5, include=['documents', 'embeddings'])
    retrieved_documents = results['documents']

    unique_documents = list(set(doc for docs in retrieved_documents for doc in docs))

    # Re-Ranking
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    pairs = [[user_query, doc] for doc in unique_documents]
    scores = cross_encoder.predict(pairs)

    top_indices = np.argsort(scores)[::-1][:5]
    top_documents = [unique_documents[idx] for idx in top_indices]

    # LLM Reference
    client = InferenceClient(model_name, token=HF_TOKEN)
    response = ""
    for message in client.chat_completion(
        messages=[
            {
                "role": "system",
                "content": f"""You are a helpful expert in {topic}. Your users are asking questions about {topic}.
                You will be shown the user's questions, and the relevant information from the documents related to {topic}.
                Answer the user's question using only this information."""
            },
            {
                "role": "user",
                "content": f"Questions: {user_query}. \n Information: {top_documents}"
            }
        ],
        max_tokens=500,
        stream=True,
    ):
        # Some streamed chunks may carry no text, so guard against None before appending
        if message.choices[0].delta.content is not None:
            response += message.choices[0].delta.content

    return response

# Streamlit UI
st.title("Wikipedia RAG Chatbot")
st.markdown("Choose a topic. Don't forget to put your 🤗 token!")
st.link_button("Get Token Here", "https://huggingface.co/settings/tokens")

# Input box for the user to type their message
user_input = st.text_input("You: ", "", placeholder="Type your question here...")

# Generate a response and update the conversation history
# (skip inputs that were already answered, so Streamlit reruns don't duplicate entries)
if user_input and (not st.session_state.history or st.session_state.history[-1]["user"] != user_input):
    response = rag_advanced(user_input)
    st.session_state.history.append({"user": user_input, "bot": response})

# Display the conversation history
for chat in st.session_state.history:
    st.write(f"You: {chat['user']}")
    st.write(f"Bot: {chat['bot']}")


st.markdown("-----------------")
st.markdown("What is this app?")
st.markdown("""This is a simple RAG application using Wikipedia API.  
The model for chat is Mistral-7B-Instruct-v0.3.  
Main libraries: Langchain (text splitting), Chromadb (vector store)  
This RAG uses query expansion and re-ranking to improve the quality.  
Feel free to check the files or DM me for any questions. Thank you.""")
st.markdown("[Current agenda] Creating fallback for fuzzy keywords for wikipedia search")