"""Streamlit front end for the Bangla RAG chatbot.

The sidebar collects the model and retrieval settings, ``load_model`` builds
a ``BanglaRAGChain`` from them, and the main page sends the user's question
to that chain and shows the answer (and, optionally, the returned context).
"""

import logging
import warnings

import streamlit as st

from BanglaRAG.bangla_rag_pipeline import BanglaRAGChain

# Kept after the imports above, as in the original script, so only warnings
# raised from this point on are silenced.
warnings.filterwarnings("ignore")

# Default constants for the script
DEFAULT_CHAT_MODEL_ID = "hassanaliemon/bn_rag_llama3-8b"
DEFAULT_EMBED_MODEL_ID = "l3cube-pune/bengali-sentence-similarity-sbert"
DEFAULT_K = 4
DEFAULT_TOP_K = 2
DEFAULT_TOP_P = 0.6
DEFAULT_TEMPERATURE = 0.6
DEFAULT_CHUNK_SIZE = 500
DEFAULT_CHUNK_OVERLAP = 150
DEFAULT_MAX_NEW_TOKENS = 256

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


# Initialize and load the RAG model
@st.cache_resource(show_spinner=False)
def load_model(
    chat_model_id,
    embed_model_id,
    text_path,
    k,
    top_k,
    top_p,
    temperature,
    chunk_size,
    chunk_overlap,
    hf_token,
    max_new_tokens,
    quantization,
):
    """Create a ``BanglaRAGChain`` and load it with the given settings.

    The function is wrapped in ``st.cache_resource``, so Streamlit reuses the
    returned chain across reruns while every argument stays the same; any
    changed argument results in a fresh ``BanglaRAGChain().load(...)`` call.

    All arguments are forwarded unchanged, by keyword, to
    ``BanglaRAGChain.load``.

    Returns:
        The loaded ``BanglaRAGChain`` instance.

    Raises:
        Whatever ``BanglaRAGChain.load`` raises; nothing is caught here, so a
        failed load is not cached and callers decide how to report it.
    """
    rag_chain = BanglaRAGChain()
    rag_chain.load(
        chat_model_id=chat_model_id,
        embed_model_id=embed_model_id,
        text_path=text_path,
        k=k,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        hf_token=hf_token,
        max_new_tokens=max_new_tokens,
        quantization=quantization,
    )
    return rag_chain


def main():
    """Render the chatbot page: settings, model loading and question answering.

    Problems are reported in the page (``st.error`` / ``st.warning``) and the
    function returns early instead of letting an exception escape.
    """
    st.title("Bangla RAG Chatbot")

    # Sidebar for model configuration
    sidebar = st.sidebar
    sidebar.header("Model Configuration")
    chat_model_id = sidebar.text_input("Chat Model ID", DEFAULT_CHAT_MODEL_ID)
    embed_model_id = sidebar.text_input("Embed Model ID", DEFAULT_EMBED_MODEL_ID)
    k = sidebar.slider("Number of Documents to Retrieve (k)", 1, 10, DEFAULT_K)
    top_k = sidebar.slider("Top K", 1, 10, DEFAULT_TOP_K)
    top_p = sidebar.slider("Top P", 0.0, 1.0, DEFAULT_TOP_P)
    temperature = sidebar.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
    max_new_tokens = sidebar.slider(
        "Max New Tokens", 1, 512, DEFAULT_MAX_NEW_TOKENS
    )
    chunk_size = sidebar.slider("Chunk Size", 100, 1000, DEFAULT_CHUNK_SIZE)
    chunk_overlap = sidebar.slider("Chunk Overlap", 0, 500, DEFAULT_CHUNK_OVERLAP)
    text_path = sidebar.text_input("Text File Path", "text.txt")
    quantization = sidebar.checkbox("Enable Quantization (4-bit)", value=False)
    show_context = sidebar.checkbox("Show Retrieved Context", value=False)

    hf_token = st.text_input("Hugging Face API Token", type="password")

    # The two sliders allow an overlap of up to 500 with a chunk size as low
    # as 100. Reject that combination before starting an expensive load.
    # NOTE(review): assumes the pipeline's text splitter cannot work with an
    # overlap that is not smaller than the chunk size - confirm against
    # BanglaRAGChain.
    if chunk_overlap >= chunk_size:
        st.error("Chunk Overlap must be smaller than Chunk Size.")
        return

    # Load the model with the above configuration. load_model disables
    # Streamlit's built-in cache spinner, so show an explicit one here.
    try:
        with st.spinner("Loading model..."):
            rag_chain = load_model(
                chat_model_id=chat_model_id,
                embed_model_id=embed_model_id,
                text_path=text_path,
                k=k,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                hf_token=hf_token,
                max_new_tokens=max_new_tokens,
                quantization=quantization,
            )
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the log and show a
        # short message in the page instead of a raw stack trace.
        logger.exception("Failed to load the RAG chain")
        st.error(f"Couldn't load the model: {e}")
        return

    st.write("### Enter your question:")
    query = st.text_input("আপনার প্রশ্ন")

    if not st.button("Generate Answer"):
        return

    # Treat a whitespace-only question the same as an empty one.
    question = query.strip()
    if not question:
        st.warning("Please enter a query.")
        return

    try:
        answer, context = rag_chain(question)
    except Exception as e:
        logger.exception("Failed to generate an answer")
        st.error(f"Couldn't generate an answer: {e}")
        return

    st.write(f"**Answer:** {answer}")
    if show_context:
        st.write(f"**Context:** {context}")


if __name__ == "__main__":
    main()