anezatra2 commited on
Commit
8433cd7
β€’
1 Parent(s): b28f317

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -60
app.py CHANGED
@@ -1,63 +1,200 @@
1
- import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
- ],
59
  )
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- if __name__ == "__main__":
63
- demo.launch()
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import re
4
+ import random
5
+ import streamlit as st
6
+ from streamlit_chat import message
7
+ # from langchain.embeddings.openai import OpenAIEmbeddings
8
+ from langchain.memory import ConversationSummaryBufferMemory
9
+ from langchain.llms import OpenAI
10
+ from langchain.chains import ConversationalRetrievalChain, ConversationChain
11
+ from langchain import PromptTemplate
12
+ import qdrant_client
13
+ from langchain.embeddings import OpenAIEmbeddings
14
+ from langchain.vectorstores import Qdrant
15
+
16
+ from dotenv import load_dotenv
17
+ load_dotenv(".env")
18
+
19
+
20
+ prompt_template = """Use the following pieces of context to answer the question at the end. If you don't
21
+ know the answer, or similar answer is not in the context, you should say that 'I've searched my database,
22
+ but I couldn't locate the exact information you're looking for. May be you want to be more specific
23
+ in your search. Or checkout similar documents'.
24
+ Answer user greetings and ask them what they i'd like to learn about. You are a bot that teaches users
25
+ about american law codes
26
+
27
+ Context: {context}
28
+ Question: {question}
29
+ Helpful Answer:"""
30
+ QA_PROMPT_ERROR = PromptTemplate(
31
+ template=prompt_template, input_variables=["context", "question"]
32
+ )
33
+
34
+ # Use different logo
35
+ def logo(logo: str = None) -> str:
36
+ logos = [
37
+
38
+ "https://res.cloudinary.com/webmonc/image/upload/v1696515089/3558860_r0hs4y.png"
39
+ ]
40
+ logo = random.choice(logos)
41
+ return logo
42
+
43
+
44
+ memory = ConversationSummaryBufferMemory(
45
+ llm=OpenAI(
46
+ temperature=0),
47
+ max_token_limit=150,
48
+ memory_key='chat_history',
49
+ return_messages=True,
50
+ output_key='answer')
51
+
52
+
53
+ # Streamlit Component
54
+
55
+ st.set_page_config(
56
+ page_title="USA Law Codes",
57
+ # page_icon=":robot:"
58
+ page_icon=":us:"
59
  )
60
 
61
+ st.header("πŸ“‹ ChatBot for Learning About USA Laws")
62
+ # st.title("πŸ‘‹ πŸ“ ChatBot for Learning About American Laws")
63
+ user_city = st.selectbox("Select a City", ("Maricopa", "LAH", "PGC"))
64
+
65
+
66
+ hide_st_style = """
67
+ <style>
68
+ #MainMenu {visibility: hidden;}
69
+ footer {visibility: hidden;}
70
+ header {visibility: hidden;}
71
+ </style>
72
+ """
73
+ st.markdown(hide_st_style, unsafe_allow_html=True)
74
+
75
+
76
+ if 'responses' not in st.session_state:
77
+ st.session_state['responses'] = ["I'm here to assist you!"]
78
+
79
+ if 'requests' not in st.session_state:
80
+ st.session_state['requests'] = []
81
+
82
+ if 'buffer_memory' not in st.session_state:
83
+ st.session_state.buffer_memory = memory
84
+
85
+
86
+ # connect to a Qdrant Cluster
87
+ client = qdrant_client.QdrantClient(
88
+ url=os.getenv("QDRANT_HOST"),
89
+ api_key=os.getenv("QDRANT_API_KEY")
90
+ )
91
+
92
+ embeddings = OpenAIEmbeddings()
93
+
94
+
95
+ # Change Db base on city
96
+ def connect_db(db: str = None) -> str:
97
+ db = user_city
98
+ if user_city == "LAH":
99
+ db = "collection_two" # I.e set a collection/DB name
100
+ elif db == "Maricopa":
101
+ db = "test3"
102
+ elif db == "PGC":
103
+ db = "pgc"
104
+
105
+ vector_store = Qdrant(
106
+ client=client,
107
+ collection_name=db,
108
+ embeddings=embeddings
109
+ )
110
+ return vector_store
111
+
112
+
113
+ def get_urls(doc: str = None) -> "list[str]":
114
+ url_regex = '(http[s]?://?[A-Za-z0-9–_\\.\\-]+\\.[A-Za-z]+/?[A-Za-z0-9$\\–_\\-\\/\\.\\?]*)[\\.)\"]*'
115
+ url = re.findall(url_regex, doc)
116
+ return url
117
+
118
+
119
+ def print_answer_metadata(result: "list[dict]") -> str:
120
+ links = []
121
+ output_answer = ""
122
+ output_answer += result['answer']
123
+ for doc in result['source_documents']:
124
+ link = get_urls(doc.page_content)
125
+ links.extend(link)
126
+ link = "\n".join(links)
127
+
128
+ if links != []:
129
+ output_answer += "\n" + "See also: " + link
130
+
131
+ # print("OUT", output_answer)
132
+ return output_answer
133
+
134
+
135
+ def print_page_content(result: "list[dict]") -> str:
136
+ extracted_string = ""
137
+
138
+ for doc in result['source_documents']:
139
+
140
+ page_content = doc.page_content[:200] + "..."
141
+
142
+ title = doc.page_content[0:35] + "..."
143
+ if page_content and title:
144
+ extracted_string += f"<hr><h4>Document Title:</h4> {title}\n\n\n <h4>Excerpt:</h4>\
145
+ {page_content}\n\n"
146
+ return extracted_string
147
+
148
+
149
+ qa = ConversationalRetrievalChain.from_llm(
150
+ OpenAI(temperature=0),
151
+ connect_db().as_retriever(),
152
+ memory=st.session_state.buffer_memory,
153
+ verbose=True,
154
+ return_source_documents=True,
155
+ combine_docs_chain_kwargs={'prompt': QA_PROMPT_ERROR})
156
+
157
+
158
+ response_container = st.container()
159
+ textcontainer = st.container()
160
+
161
+ details = ''
162
+
163
+ with textcontainer:
164
+ query = st.text_input("You: ", key="input", placeholder="start chat")
165
+ submit = st.button("send")
166
+ if submit:
167
+ res = qa({"question": query})
168
+ response = print_answer_metadata(res)
169
+ details = print_page_content(res)
170
+ st.session_state.requests.append(query)
171
+ st.session_state.responses.append(response)
172
+
173
+
174
+ with response_container:
175
+ if st.session_state['responses']:
176
+ for i in range(len(st.session_state['responses'])):
177
+ message(
178
+ st.session_state['responses'][i],
179
+ key=str(i),
180
+ avatar_style="no-avatar",
181
+ logo=logo(),
182
+ allow_html=True)
183
+ if i < len(st.session_state['requests']):
184
+ message(
185
+ st.session_state["requests"][i],
186
+ is_user=True,
187
+ key=str(i) + '_user',
188
+ allow_html=True
189
+ )
190
+
191
+
192
+ with st.sidebar:
193
+ st.image("https://res.cloudinary.com/webmonc/image/upload/v1696603202/Bot%20Streamlit/law_justice1_yqaqvd.jpg")
194
 
195
+ if details:
196
+ with st.spinner("Processing..."):
197
+ time.sleep(1)
198
+ st.markdown('__Similar Documents__')
199
+ st.markdown(f'''<small>{details}</small>''', unsafe_allow_html=True)
200
+