Gampanut commited on
Commit
34bb96a
1 Parent(s): d5aa948

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -25
app.py CHANGED
@@ -2,13 +2,13 @@ import gradio as gr
2
  from langchain_groq import ChatGroq
3
  from langchain_community.graphs.networkx_graph import NetworkxEntityGraph
4
  from langchain.chains import GraphQAChain
5
- from langchain_community.document_loaders import TextLoader # Updated import
6
  from langchain.text_splitter import CharacterTextSplitter
7
- from langchain_community.vectorstores import Pinecone # Updated import
8
- from langchain_community.embeddings import HuggingFaceEmbeddings # Correct import
9
  from langchain.schema.runnable import RunnablePassthrough
10
  from langchain.schema.output_parser import StrOutputParser
11
- from langchain_core.prompts import PromptTemplate # Updated import
12
  from langchain_core.documents import Document
13
  from neo4j import GraphDatabase
14
  import networkx as nx
@@ -18,7 +18,6 @@ from datetime import datetime
18
  import gspread
19
  from oauth2client.service_account import ServiceAccountCredentials
20
 
21
- # Install the missing package
22
  os.system("pip install sentence-transformers")
23
  os.system("pip install gspread oauth2client")
24
 
@@ -35,27 +34,28 @@ def store_feedback_in_sheet(feedback, question, rag_response, graphrag_response)
35
  row = [timestamp, question, rag_response, graphrag_response, feedback]
36
  sheet.append_row(row)
37
 
38
- # Function to load data from Google Sheets
39
- def load_data():
40
- data = sheet.get_all_records()
41
- return data[-10:], len(data) # return the last 10 rows and total count
42
-
43
- # Function to add review to Google Sheets
44
- def add_review(question, rag_response, graphrag_response, feedback):
45
- store_feedback_in_sheet(feedback, question, rag_response, graphrag_response)
46
- return None, None # No output needed since we removed the data and count display
47
-
48
- # Initialize the chatbot and other necessary setups
49
  text_path = r"./text_chunks.txt"
50
  loader = TextLoader(text_path, encoding='utf-8')
51
  documents = loader.load()
52
  text_splitter = CharacterTextSplitter(chunk_size=3000, chunk_overlap=4)
53
  docs = text_splitter.split_documents(documents)
54
 
 
 
 
 
 
 
 
 
 
 
55
  embeddings = HuggingFaceEmbeddings()
56
 
57
  from langchain.llms import HuggingFaceHub
58
 
 
59
  repo_id = "meta-llama/Meta-Llama-3-8B"
60
  llm = HuggingFaceHub(
61
  repo_id=repo_id,
@@ -85,9 +85,9 @@ rag_llm = ChatGroq(
85
  )
86
 
87
  template = """
88
- You are a Thai rice assistants. These Human will ask you a questions about Thai Rice.
89
- Answer the question only in Thai languages.
90
- Use following piece of context to answer the question.
91
  If you don't know the answer, just say you don't know.
92
  Keep the answer within 2 sentences and concise.
93
  Context: {context}
@@ -107,25 +107,118 @@ rag_chain = (
107
  | StrOutputParser()
108
  )
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  def get_rag_response(question):
111
  response = rag_chain.invoke(question)
112
  return response
113
 
114
- # Gradio Interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  with gr.Blocks() as demo:
 
 
116
  with gr.Row():
117
  with gr.Column():
118
  question_input = gr.Textbox(label="ถามคำถามเกี่ยวกับข้าว:")
119
  submit_btn = gr.Button("ถาม")
 
 
120
  rag_output = gr.Textbox(label="Model A", interactive=False)
121
  graphrag_output = gr.Textbox(label="Model B", interactive=False)
122
 
 
123
  with gr.Column():
124
- feedback_review = gr.Radio(label="How satisfied are you with the chatbot responses?", choices=["A ดีกว่า", "B ดีกว่า", "แย่ทั้งคู่", "เท่ากัน"])
125
- feedback_comments = gr.Textbox(label="Comments", lines=10, placeholder="Any additional comments?")
126
- feedback_submit = gr.Button(value="Submit Feedback")
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- submit_btn.click(lambda question: (get_rag_response(question), get_rag_response(question)), [question_input], [rag_output, graphrag_output])
129
- feedback_submit.click(add_review, [question_input, rag_output, graphrag_output, feedback_review])
130
 
131
  demo.launch(share=True)
 
2
  from langchain_groq import ChatGroq
3
  from langchain_community.graphs.networkx_graph import NetworkxEntityGraph
4
  from langchain.chains import GraphQAChain
5
+ from langchain_community.document_loaders import TextLoader
6
  from langchain.text_splitter import CharacterTextSplitter
7
+ from langchain_community.vectorstores import Pinecone
8
+ from langchain_community.embeddings import HuggingFaceEmbeddings
9
  from langchain.schema.runnable import RunnablePassthrough
10
  from langchain.schema.output_parser import StrOutputParser
11
+ from langchain import PromptTemplate
12
  from langchain_core.documents import Document
13
  from neo4j import GraphDatabase
14
  import networkx as nx
 
18
  import gspread
19
  from oauth2client.service_account import ServiceAccountCredentials
20
 
 
21
  os.system("pip install sentence-transformers")
22
  os.system("pip install gspread oauth2client")
23
 
 
34
  row = [timestamp, question, rag_response, graphrag_response, feedback]
35
  sheet.append_row(row)
36
 
37
+ # RAG Setup
 
 
 
 
 
 
 
 
 
 
38
  text_path = r"./text_chunks.txt"
39
  loader = TextLoader(text_path, encoding='utf-8')
40
  documents = loader.load()
41
  text_splitter = CharacterTextSplitter(chunk_size=3000, chunk_overlap=4)
42
  docs = text_splitter.split_documents(documents)
43
 
44
+ class CustomTextLoader(TextLoader):
45
+ def __init__(self, file_path: str, encoding: str = 'utf-8'):
46
+ super().__init__(file_path)
47
+ self.encoding = encoding
48
+
49
+ def load(self):
50
+ with open(self.file_path, encoding=self.encoding) as f:
51
+ text = f.read()
52
+ return [Document(page_content=text)]
53
+
54
  embeddings = HuggingFaceEmbeddings()
55
 
56
  from langchain.llms import HuggingFaceHub
57
 
58
+ # Define the repo ID and connect to Mixtral model on Huggingface
59
  repo_id = "meta-llama/Meta-Llama-3-8B"
60
  llm = HuggingFaceHub(
61
  repo_id=repo_id,
 
85
  )
86
 
87
  template = """
88
+ You are a Thai rice assistant. These humans will ask you questions about Thai rice.
89
+ Answer the question only in Thai language.
90
+ Use the following piece of context to answer the question.
91
  If you don't know the answer, just say you don't know.
92
  Keep the answer within 2 sentences and concise.
93
  Context: {context}
 
107
  | StrOutputParser()
108
  )
109
 
110
+ class ChatBot():
111
+ loader = CustomTextLoader(r"./text_chunks.txt", encoding='utf-8')
112
+ documents = loader.load()
113
+
114
+ rag_chain = (
115
+ {"context": docsearch.as_retriever(), "question": RunnablePassthrough()}
116
+ | prompt
117
+ | llm
118
+ | StrOutputParser()
119
+ )
120
+
121
+ graphrag_llm = ChatGroq(
122
+ model="Llama3-8b-8192",
123
+ temperature=0,
124
+ max_tokens=None,
125
+ timeout=None,
126
+ max_retries=5,
127
+ groq_api_key='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'
128
+ )
129
+
130
+ uri = "neo4j+s://46084f1a.databases.neo4j.io"
131
+ user = "neo4j"
132
+ password = "FwnX0ige_QYJk8eEYSXSF0l081mWWGIS7TFg6t8rLZc"
133
+ driver = GraphDatabase.driver(uri, auth=(user, password))
134
+
135
+ def fetch_nodes(tx):
136
+ query = "MATCH (n) RETURN id(n) AS id, labels(n) AS labels"
137
+ result = tx.run(query)
138
+ return result.data()
139
+
140
+ def fetch_relationships(tx):
141
+ query = "MATCH (n)-[r]->(m) RETURN id(n) AS source, id(m) AS target, type(r) AS relation"
142
+ result = tx.run(query)
143
+ return result.data()
144
+
145
+ def populate_networkx_graph():
146
+ G = nx.Graph()
147
+ with driver.session() as session:
148
+ nodes = session.read_transaction(fetch_nodes)
149
+ relationships = session.read_transaction(fetch_relationships)
150
+ for node in nodes:
151
+ G.add_node(node['id'], labels=node['labels'])
152
+ for relationship in relationships:
153
+ G.add_edge(
154
+ relationship['source'],
155
+ relationship['target'],
156
+ relation=relationship['relation']
157
+ )
158
+ return G
159
+
160
+ networkx_graph = populate_networkx_graph()
161
+ graph = NetworkxEntityGraph()
162
+ graph._graph = networkx_graph
163
+
164
+ graphrag_chain = GraphQAChain.from_llm(
165
+ llm=graphrag_llm,
166
+ graph=graph,
167
+ verbose=True
168
+ )
169
+
170
  def get_rag_response(question):
171
  response = rag_chain.invoke(question)
172
  return response
173
 
174
+ def get_graphrag_response(question):
175
+ system_prompt = "You are a Thai rice assistant that gives concise and direct answers. Do not explain the process, just provide the answer, provide the answer only in Thai."
176
+ formatted_question = f"System Prompt: {system_prompt}\n\nQuestion: {question}"
177
+ response = graphrag_chain.run(formatted_question)
178
+ return response
179
+
180
+ def compare_models(question):
181
+ rag_response = get_rag_response(question)
182
+ graphrag_response = get_graphrag_response(question)
183
+ return rag_response, graphrag_response
184
+
185
+ def handle_feedback(feedback, question, rag_response, graphrag_response):
186
+ try:
187
+ store_feedback_in_sheet(feedback, question, rag_response, graphrag_response)
188
+ return "ส่งสำเร็จ!"
189
+ except Exception as e:
190
+ return f"Error: {e}"
191
+
192
  with gr.Blocks() as demo:
193
+ gr.Markdown("## Thai Rice Assistant A/B Testing")
194
+
195
  with gr.Row():
196
  with gr.Column():
197
  question_input = gr.Textbox(label="ถามคำถามเกี่ยวกับข้าว:")
198
  submit_btn = gr.Button("ถาม")
199
+
200
+ with gr.Column():
201
  rag_output = gr.Textbox(label="Model A", interactive=False)
202
  graphrag_output = gr.Textbox(label="Model B", interactive=False)
203
 
204
+ with gr.Row():
205
  with gr.Column():
206
+ choice = gr.Radio(["A ดีกว่า", "B ดีกว่า", "เท่ากัน", "แย่ทั้งคู่"], label="คำตอบไหนดีกว่ากัน?")
207
+ send_feedback_btn = gr.Button("ส่ง")
208
+
209
+ feedback_output = gr.Textbox(label="Feedback Status", interactive=False)
210
+
211
+ def on_submit(question):
212
+ rag_response, graphrag_response = compare_models(question)
213
+ return rag_response, graphrag_response
214
+
215
+ def on_feedback(feedback):
216
+ question = question_input.value
217
+ rag_response = rag_output.value
218
+ graphrag_response = graphrag_output.value
219
+ return handle_feedback(feedback, question, rag_response, graphrag_response)
220
 
221
+ submit_btn.click(on_submit, inputs=[question_input], outputs=[rag_output, graphrag_output])
222
+ send_feedback_btn.click(on_feedback, inputs=[choice], outputs=[feedback_output])
223
 
224
  demo.launch(share=True)