# BanglaRAG / app.py
# himel06's picture
# Update app.py
# 0296ab6 verified
# raw
# history blame
# 3.69 kB
import streamlit as st
import logging
from BanglaRAG.bangla_rag_pipeline import BanglaRAGChain
import warnings
warnings.filterwarnings("ignore")
# Default values used to pre-fill the sidebar widgets built in main().
# Model identifiers shown in the "Chat Model ID" / "Embed Model ID" inputs.
DEFAULT_CHAT_MODEL_ID = "hassanaliemon/bn_rag_llama3-8b"
DEFAULT_EMBED_MODEL_ID = "l3cube-pune/bengali-sentence-similarity-sbert"
# Initial value of the "Number of Documents to Retrieve (k)" slider.
DEFAULT_K = 4
# Initial values of the "Top K", "Top P" and "Temperature" sliders
# (presumably text-generation sampling settings; they are interpreted by
# BanglaRAGChain, which is not visible in this file).
DEFAULT_TOP_K = 2
DEFAULT_TOP_P = 0.6
DEFAULT_TEMPERATURE = 0.6
# Initial values of the "Chunk Size" / "Chunk Overlap" sliders.
# NOTE(review): the unit (characters vs. tokens) is decided by the splitter
# inside BanglaRAGChain -- confirm there.
DEFAULT_CHUNK_SIZE = 500
DEFAULT_CHUNK_OVERLAP = 150
# Initial value of the "Max New Tokens" slider.
DEFAULT_MAX_NEW_TOKENS = 256
# Initial value of the "Offload Directory" input.
DEFAULT_OFFLOAD_DIR = "/tmp"
# Configure root logging: INFO level, "<timestamp> - <level> - <message>" lines.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# Initialize and load the RAG model.
# max_entries=1: every distinct argument combination is a separate cache entry
# holding a fully loaded chain; keeping only the latest one stops sidebar
# tweaks from accumulating several loaded models in memory.
@st.cache_resource(show_spinner=False, max_entries=1)
def load_model(chat_model_id, embed_model_id, text_path, k, top_k, top_p, temperature, chunk_size, chunk_overlap, hf_token, max_new_tokens, quantization, offload_dir):
    """Create a ``BanglaRAGChain`` and load it with the given configuration.

    The result is cached by Streamlit, so a script rerun with the same
    arguments reuses the already loaded chain instead of loading it again.

    Args:
        chat_model_id: Identifier of the chat model.
        embed_model_id: Identifier of the embedding model.
        text_path: Path of the text file handed to the chain.
        k, top_k, top_p, temperature, chunk_size, chunk_overlap,
        max_new_tokens, quantization, offload_dir: Settings forwarded
            unchanged, under the same keyword names, to
            ``BanglaRAGChain.load``; their exact meaning is defined there.
        hf_token: Token forwarded to ``BanglaRAGChain.load``; the caller in
            this file passes ``None``.

    Returns:
        The ``BanglaRAGChain`` instance after ``load`` has been called on it.

    Raises:
        Whatever ``BanglaRAGChain.load`` raises; nothing is caught here.
    """
    # hf_token is deliberately left out of the log line.
    logging.info(
        "Loading RAG chain (chat_model=%s, embed_model=%s, text_path=%s)",
        chat_model_id,
        embed_model_id,
        text_path,
    )
    rag_chain = BanglaRAGChain()
    rag_chain.load(
        chat_model_id=chat_model_id,
        embed_model_id=embed_model_id,
        text_path=text_path,
        k=k,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        hf_token=hf_token,
        max_new_tokens=max_new_tokens,
        quantization=quantization,
        offload_dir=offload_dir,
    )
    return rag_chain
def main():
    """Render the Streamlit page: sidebar settings, model loading and Q&A.

    Flow of one script run:
      1. Draw the sidebar widgets and read the current configuration.
      2. Reject a chunk overlap that is not smaller than the chunk size.
      3. Load (or fetch from cache) the RAG chain via ``load_model``.
      4. When "Generate Answer" is pressed with a non-blank question, call
         the chain and show the answer, plus the retrieved context if the
         "Show Retrieved Context" box is ticked.

    Failures while loading or answering are logged with their traceback and
    reported on the page with ``st.error`` instead of crashing the script.
    """
    st.title("Bangla RAG Chatbot")

    # Sidebar for model configuration
    st.sidebar.header("Model Configuration")
    chat_model_id = st.sidebar.text_input("Chat Model ID", DEFAULT_CHAT_MODEL_ID)
    embed_model_id = st.sidebar.text_input("Embed Model ID", DEFAULT_EMBED_MODEL_ID)
    k = st.sidebar.slider("Number of Documents to Retrieve (k)", 1, 10, DEFAULT_K)
    top_k = st.sidebar.slider("Top K", 1, 10, DEFAULT_TOP_K)
    top_p = st.sidebar.slider("Top P", 0.0, 1.0, DEFAULT_TOP_P)
    temperature = st.sidebar.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
    max_new_tokens = st.sidebar.slider("Max New Tokens", 1, 512, DEFAULT_MAX_NEW_TOKENS)
    chunk_size = st.sidebar.slider("Chunk Size", 100, 1000, DEFAULT_CHUNK_SIZE)
    chunk_overlap = st.sidebar.slider("Chunk Overlap", 0, 500, DEFAULT_CHUNK_OVERLAP)
    text_path = st.sidebar.text_input("Text File Path", "text.txt")
    quantization = st.sidebar.checkbox("Enable Quantization (4-bit)", value=False)
    show_context = st.sidebar.checkbox("Show Retrieved Context", value=False)
    offload_dir = st.sidebar.text_input("Offload Directory", DEFAULT_OFFLOAD_DIR)

    # The two sliders overlap in range (overlap up to 500, size down to 100).
    # An overlap that is not smaller than the chunk size leaves no new text per
    # chunk; presumably the splitter inside BanglaRAGChain rejects it anyway,
    # so fail early with a readable message instead of a heavy failed load.
    if chunk_overlap >= chunk_size:
        st.error("Chunk Overlap must be smaller than Chunk Size.")
        return

    # Load the model with the above configuration. Loading can fail for
    # reasons outside this script (bad path, unknown model id, ...), so report
    # the problem on the page rather than letting the run abort.
    try:
        rag_chain = load_model(
            chat_model_id=chat_model_id,
            embed_model_id=embed_model_id,
            text_path=text_path,
            k=k,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            hf_token=None,  # No Hugging Face API token is supplied by this UI
            max_new_tokens=max_new_tokens,
            quantization=quantization,
            offload_dir=offload_dir,
        )
    except Exception as e:
        logging.exception("Failed to load the RAG chain")
        st.error(f"Couldn't load the model: {e}")
        return

    st.write("### Enter your question:")
    query = st.text_input("আপনার প্রশ্ন")

    if not st.button("Generate Answer"):
        return

    # Treat a whitespace-only input the same as an empty one.
    question = query.strip()
    if not question:
        st.warning("Please enter a query.")
        return

    try:
        answer, context = rag_chain(question)
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the log, show a short
        # message to the user.
        logging.exception("Failed to generate an answer")
        st.error(f"Couldn't generate an answer: {e}")
        return

    st.write(f"**Answer:** {answer}")
    if show_context:
        st.write(f"**Context:** {context}")
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()