"""Streamlit front-end for the BanglaRAG pipeline: configure the models in
the sidebar, ask a question in Bengali, and get an answer grounded in the
supplied text file."""

import logging
import warnings

import streamlit as st

from BanglaRAG.bangla_rag_pipeline import BanglaRAGChain

# Silence noisy third-party warnings; pipeline progress is logged instead.
warnings.filterwarnings("ignore")
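
# Default settings. The chat and embedding model IDs are Hugging Face Hub
# checkpoints; every value below can be overridden from the sidebar.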
DEFAULT_CHAT_MODEL_ID = "hassanaliemon/bn_rag_llama3-8b"
DEFAULT_EMBED_MODEL_ID = "l3cube-pune/bengali-sentence-similarity-sbert"
DEFAULT_K = 4
DEFAULT_TOP_K = 2
DEFAULT_TOP_P = 0.6
DEFAULT_TEMPERATURE = 0.6
DEFAULT_CHUNK_SIZE = 500
DEFAULT_CHUNK_OVERLAP = 150
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_OFFLOAD_DIR = "/tmp"
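
# Basic logging so model-loading progress is visible in the terminal.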
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
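

# st.cache_resource keeps one loaded pipeline per unique combination of the
# arguments below: repeated queries reuse the cached model, while changing
# any sidebar value triggers a fresh (slow) load.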
@st.cache_resource(show_spinner=False)
def load_model(
    chat_model_id,
    embed_model_id,
    text_path,
    k,
    top_k,
    top_p,
    temperature,
    chunk_size,
    chunk_overlap,
    hf_token,
    max_new_tokens,
    quantization,
    offload_dir,
):
    rag_chain = BanglaRAGChain()
    rag_chain.load(
        chat_model_id=chat_model_id,
        embed_model_id=embed_model_id,
        text_path=text_path,
        k=k,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        hf_token=hf_token,
        max_new_tokens=max_new_tokens,
        quantization=quantization,
        offload_dir=offload_dir,
    )
    return rag_chain


def main():
    st.title("Bangla RAG Chatbot")

    st.sidebar.header("Model Configuration")
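
    # Sidebar controls. Most values feed straight into BanglaRAGChain.load();
    # show_context only toggles whether retrieved passages are displayed.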
    chat_model_id = st.sidebar.text_input("Chat Model ID", DEFAULT_CHAT_MODEL_ID)
    embed_model_id = st.sidebar.text_input("Embed Model ID", DEFAULT_EMBED_MODEL_ID)
    k = st.sidebar.slider("Number of Documents to Retrieve (k)", 1, 10, DEFAULT_K)
    top_k = st.sidebar.slider("Top K", 1, 10, DEFAULT_TOP_K)
    top_p = st.sidebar.slider("Top P", 0.0, 1.0, DEFAULT_TOP_P)
    temperature = st.sidebar.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
    max_new_tokens = st.sidebar.slider("Max New Tokens", 1, 512, DEFAULT_MAX_NEW_TOKENS)
    chunk_size = st.sidebar.slider("Chunk Size", 100, 1000, DEFAULT_CHUNK_SIZE)
    chunk_overlap = st.sidebar.slider("Chunk Overlap", 0, 500, DEFAULT_CHUNK_OVERLAP)
    text_path = st.sidebar.text_input("Text File Path", "text.txt")
    quantization = st.sidebar.checkbox("Enable Quantization (4-bit)", value=False)
    show_context = st.sidebar.checkbox("Show Retrieved Context", value=False)
    offload_dir = st.sidebar.text_input("Offload Directory", DEFAULT_OFFLOAD_DIR)
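
    # Build the pipeline (served from the cache after the first run). hf_token
    # is left as None; a gated Hub model would need a real token here.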
    with st.spinner("Loading models and indexing the text file..."):
        rag_chain = load_model(
            chat_model_id=chat_model_id,
            embed_model_id=embed_model_id,
            text_path=text_path,
            k=k,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            hf_token=None,
            max_new_tokens=max_new_tokens,
            quantization=quantization,
            offload_dir=offload_dir,
        )
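
    # The text input label below is Bengali for "Your question".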
    st.write("### Enter your question:")
    query = st.text_input("আপনার প্রশ্ন")
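
    # On click, run retrieval + generation; the chain returns the answer
    # along with the retrieved context passages.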
    if st.button("Generate Answer"):
        if query:
            try:
                answer, context = rag_chain(query)
                st.write(f"**Answer:** {answer}")
                if show_context:
                    st.write(f"**Context:** {context}")
            except Exception as e:
                st.error(f"Couldn't generate an answer: {e}")
        else:
            st.warning("Please enter a query.")
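

# Streamlit re-executes this script on every interaction; the expensive model
# load above stays fast after the first run thanks to st.cache_resource.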
if __name__ == "__main__":
    main()