himel06 committed on
Commit
24e3a93
1 Parent(s): 8b5bb01

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import logging
3
+ from BanglaRAG.bangla_rag_pipeline import BanglaRAGChain
4
+ import warnings
5
+
6
# Suppress all warnings process-wide.
warnings.filterwarnings("ignore")

# Default constants for the script.
# Model identifiers offered as the initial values of the sidebar text inputs.
DEFAULT_CHAT_MODEL_ID = "hassanaliemon/bn_rag_llama3-8b"
DEFAULT_EMBED_MODEL_ID = "l3cube-pune/bengali-sentence-similarity-sbert"

# Initial value of the "Number of Documents to Retrieve (k)" slider.
DEFAULT_K = 4

# Initial values of the "Top K", "Top P", "Temperature" and "Max New Tokens"
# sliders.
DEFAULT_TOP_K = 2
DEFAULT_TOP_P = 0.6
DEFAULT_TEMPERATURE = 0.6
DEFAULT_MAX_NEW_TOKENS = 256

# Initial values of the "Chunk Size" and "Chunk Overlap" sliders.
DEFAULT_CHUNK_SIZE = 500
DEFAULT_CHUNK_OVERLAP = 150

# Set up logging: INFO and above, formatted as "timestamp - level - message".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
23
+
24
# Initialize and load the RAG model.
# The cache is keyed on every argument, so each distinct sidebar configuration
# produces a new entry. `max_entries=1` makes Streamlit drop the previously
# cached pipeline when a new one is created, instead of keeping one full
# pipeline per configuration ever tried for the lifetime of the server.
@st.cache_resource(show_spinner=False, max_entries=1)
def load_model(
    chat_model_id,
    embed_model_id,
    text_path,
    k,
    top_k,
    top_p,
    temperature,
    chunk_size,
    chunk_overlap,
    hf_token,
    max_new_tokens,
    quantization,
):
    """Create a ``BanglaRAGChain`` and load it with the given configuration.

    Every argument is forwarded unchanged, as a keyword argument of the same
    name, to ``BanglaRAGChain.load``; refer to that method for their meaning.
    The result is cached by Streamlit, so calling this again with identical
    arguments returns the already-loaded chain rather than loading it again.

    Returns:
        The loaded ``BanglaRAGChain`` instance.

    Raises:
        Whatever ``BanglaRAGChain()`` or ``BanglaRAGChain.load`` raises; no
        exception is handled here.
    """
    rag_chain = BanglaRAGChain()
    rag_chain.load(
        chat_model_id=chat_model_id,
        embed_model_id=embed_model_id,
        text_path=text_path,
        k=k,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        hf_token=hf_token,
        max_new_tokens=max_new_tokens,
        quantization=quantization,
    )
    return rag_chain
43
+
44
def main():
    """Render the Bangla RAG chatbot page.

    Draws the sidebar configuration widgets and the token input, loads the
    RAG chain through ``load_model`` with the selected settings, then reads a
    question and shows the answer (and, optionally, the retrieved context)
    when the "Generate Answer" button is pressed.

    Invalid configuration and loading failures are reported on the page with
    ``st.error`` and end the current script run with ``st.stop`` instead of
    surfacing as an unhandled exception.
    """
    st.title("Bangla RAG Chatbot")

    # Sidebar for model configuration
    st.sidebar.header("Model Configuration")

    chat_model_id = st.sidebar.text_input("Chat Model ID", DEFAULT_CHAT_MODEL_ID)
    embed_model_id = st.sidebar.text_input("Embed Model ID", DEFAULT_EMBED_MODEL_ID)
    k = st.sidebar.slider("Number of Documents to Retrieve (k)", 1, 10, DEFAULT_K)
    top_k = st.sidebar.slider("Top K", 1, 10, DEFAULT_TOP_K)
    top_p = st.sidebar.slider("Top P", 0.0, 1.0, DEFAULT_TOP_P)
    temperature = st.sidebar.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
    max_new_tokens = st.sidebar.slider("Max New Tokens", 1, 512, DEFAULT_MAX_NEW_TOKENS)
    chunk_size = st.sidebar.slider("Chunk Size", 100, 1000, DEFAULT_CHUNK_SIZE)
    chunk_overlap = st.sidebar.slider("Chunk Overlap", 0, 500, DEFAULT_CHUNK_OVERLAP)
    text_path = st.sidebar.text_input("Text File Path", "text.txt")
    quantization = st.sidebar.checkbox("Enable Quantization (4-bit)", value=False)
    show_context = st.sidebar.checkbox("Show Retrieved Context", value=False)

    hf_token = st.text_input("Hugging Face API Token", type="password")

    # The two sliders are independent (overlap goes up to 500 while chunk size
    # can be as low as 100), so reject the impossible combination up front: a
    # chunk cannot overlap its neighbour by more than its own length.
    if chunk_overlap > chunk_size:
        st.error("Chunk Overlap must not exceed Chunk Size.")
        st.stop()

    # Load the model with the above configuration. Loading depends on
    # user-supplied values (model IDs, file path, token), so report a failure
    # on the page the same way answer-generation failures are reported below.
    try:
        rag_chain = load_model(
            chat_model_id=chat_model_id,
            embed_model_id=embed_model_id,
            text_path=text_path,
            k=k,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            hf_token=hf_token,
            max_new_tokens=max_new_tokens,
            quantization=quantization,
        )
    except Exception as e:
        logging.exception("Failed to load the RAG pipeline")
        st.error(f"Couldn't load the model: {e}")
        # Nothing below can work without a loaded chain; end this script run.
        st.stop()

    st.write("### Enter your question:")
    # Strip so that a whitespace-only input counts as "no query".
    query = st.text_input("আপনার প্রশ্ন").strip()

    if st.button("Generate Answer"):
        if query:
            try:
                answer, context = rag_chain.get_response(query)
                st.write(f"**Answer:** {answer}")
                if show_context:
                    st.write(f"**Context:** {context}")
            except Exception as e:
                logging.exception("Failed to generate an answer")
                st.error(f"Couldn't generate an answer: {e}")
        else:
            st.warning("Please enter a query.")


if __name__ == "__main__":
    main()