Shreyas094 committed on
Commit 0b607fb
1 Parent(s): 7f4043c

Update app.py

Files changed (1)
  1. app.py +207 -243
app.py CHANGED
@@ -1,28 +1,50 @@
  import os
- import json
- import re
- import gradio as gr
- import pandas as pd
- import requests
- import random
- import urllib.parse
- from tempfile import NamedTemporaryFile
- from typing import List, Dict, Optional
- from bs4 import BeautifulSoup
  import logging
- from duckduckgo_search import DDGS
-
- from langchain_community.llms import HuggingFaceHub
- from langchain_community.vectorstores import FAISS
- from langchain_community.document_loaders import PyPDFLoader
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from langchain_core.documents import Document
- from langchain.chains import LLMChain
- from langchain.prompts import PromptTemplate
-
- # Global variables
+ import gradio as gr
+ import spaces
+ from huggingface_hub import hf_hub_download
+ from langchain_community.llms import HuggingFaceHub
+
+ from llama_cpp_agent.providers import LlamaCppPythonProvider
+ from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
+ from llama_cpp_agent.chat_history import BasicChatHistory
+ from llama_cpp_agent.chat_history.messages import Roles
+ from llama_cpp_agent.llm_output_settings import (
+     LlmStructuredOutputSettings,
+     LlmStructuredOutputType,
+ )
+ from llama_cpp_agent.tools import WebSearchTool
+ from llama_cpp_agent.prompt_templates import web_search_system_prompt, research_system_prompt
+ from pydantic import BaseModel, Field
+ from trafilatura import fetch_url, extract
+ import json
+ from datetime import datetime, timezone
+ from typing import List
+
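+ # Module-level cache: respond() below reloads the model only when the
+ # selected model name changes, rather than on every request.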
+ llm = None
+ llm_model = None
+
  huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")

+ examples = [
+     ["latest news about Yann LeCun"],
+     ["Latest news site:github.blog"],
+     ["Where I can find best hotel in Galapagos, Ecuador intitle:hotel"],
+     ["filetype:pdf intitle:python"]
+ ]
+
+ def get_context_by_model(model_name):
+     model_context_limits = {
+         "Mistral-7B-Instruct-v0.3": 32768,
+     }
+     return model_context_limits.get(model_name, None)
+
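+ # Pick the prompt format by model family: Mistral models get their native
+ # [INST] template, everything else falls back to ChatML.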
+ def get_messages_formatter_type(model_name):
+     model_name = model_name.lower()
+     if "mistral" in model_name:
+         return MessagesFormatterType.MISTRAL
+     else:
+         return MessagesFormatterType.CHATML
+
  def get_model(temperature, top_p, repetition_penalty):
      return HuggingFaceHub(
          repo_id="mistralai/Mistral-7B-Instruct-v0.3",
@@ -35,248 +57,190 @@ def get_model(temperature, top_p, repetition_penalty):
          huggingfacehub_api_token=huggingface_token
      )

- def load_document(file: NamedTemporaryFile) -> List[Document]:
-     loader = PyPDFLoader(file.name)
-     return loader.load_and_split()
-
- def update_vectors(files):
-     if not files:
-         return "Please upload at least one PDF file."
-
-     embed = get_embeddings()
-     total_chunks = 0
-
-     all_data = []
-     for file in files:
-         data = load_document(file)
-         all_data.extend(data)
-         total_chunks += len(data)
-
-     if os.path.exists("faiss_database"):
-         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
-         database.add_documents(all_data)
-     else:
-         database = FAISS.from_documents(all_data, embed)
-
-     database.save_local("faiss_database")
-
-     return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files."
-
- def get_embeddings():
-     return HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
-
- def clear_cache():
-     if os.path.exists("faiss_database"):
-         os.remove("faiss_database")
-         return "Cache cleared successfully."
-     else:
-         return "No cache to clear."
-
- def extract_text_from_webpage(html):
-     soup = BeautifulSoup(html, 'html.parser')
-     for script in soup(["script", "style"]):
-         script.extract()
-     text = soup.get_text()
-     lines = (line.strip() for line in text.splitlines())
-     chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
-     text = '\n'.join(chunk for chunk in chunks if chunk)
-     return text
-
- _useragent_list = [
-     "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
-     "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
-     "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
-     "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
-     "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36",
-     "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36",
- ]
-
- def google_search(term, num_results=5, lang="en", timeout=5, safe="active", ssl_verify=None):
-     escaped_term = urllib.parse.quote_plus(term)
-     start = 0
-     all_results = []
-     max_chars_per_page = 8000
-
-     with requests.Session() as session:
-         while start < num_results:
-             try:
-                 user_agent = random.choice(_useragent_list)
-                 headers = {
-                     'User-Agent': user_agent
-                 }
-                 resp = session.get(
-                     url="https://www.google.com/search",
-                     headers=headers,
-                     params={
-                         "q": term,
-                         "num": num_results - start,
-                         "hl": lang,
-                         "start": start,
-                         "safe": safe,
-                     },
-                     timeout=timeout,
-                     verify=ssl_verify,
-                 )
-                 resp.raise_for_status()
-             except requests.exceptions.RequestException as e:
-                 print(f"Error retrieving search results: {e}")
-                 break
-
-             soup = BeautifulSoup(resp.text, "html.parser")
-             result_block = soup.find_all("div", attrs={"class": "g"})
-             if not result_block:
-                 break
-
-             for result in result_block:
-                 link = result.find("a", href=True)
-                 if link:
-                     link = link["href"]
-                     try:
-                         webpage = session.get(link, headers=headers, timeout=timeout)
-                         webpage.raise_for_status()
-                         visible_text = extract_text_from_webpage(webpage.text)
-                         if len(visible_text) > max_chars_per_page:
-                             visible_text = visible_text[:max_chars_per_page] + "..."
-                         all_results.append({"link": link, "text": visible_text})
-                     except requests.exceptions.RequestException as e:
-                         print(f"Error retrieving webpage content: {e}")
-                         all_results.append({"link": link, "text": None})
-                 else:
-                     all_results.append({"link": None, "text": None})
-             start += len(result_block)
-
-     if not all_results:
-         return [{"link": None, "text": "No information found in the web search results."}]
-
-     return all_results
-
- def duckduckgo_search(query, max_results=5):
+ def get_server_time():
+     utc_time = datetime.now(timezone.utc)
+     return utc_time.strftime("%Y-%m-%d %H:%M:%S")
+
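+ # Fetch a page with trafilatura and unwrap the JSON extraction; title and
+ # raw_text are framed with delimiters so multiple pages stay distinguishable
+ # in the agent's context.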
+ def get_website_content_from_url(url: str) -> str:
      try:
-         search = DDGS()
-         results = search.text(query, max_results=max_results)
-         formatted_results = []
-         for result in results:
-             formatted_results.append({
-                 "link": result.get('href', ''),
-                 "text": result.get('title', '') + '. ' + result.get('body', '')
-             })
-         return formatted_results
+         downloaded = fetch_url(url)
+         result = extract(downloaded, include_formatting=True, include_links=True, output_format='json', url=url)
+         if result:
+             result = json.loads(result)
+             return f'=========== Website Title: {result["title"]} ===========\n\n=========== Website URL: {url} ===========\n\n=========== Website Content ===========\n\n{result["raw_text"]}\n\n=========== Website Content End ===========\n\n'
+         else:
+             return ""
      except Exception as e:
-         print(f"Error in DuckDuckGo search: {e}")
-         return [{"link": None, "text": "No information found in the web search results."}]
+         return f"An error occurred: {str(e)}"
+
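+ # Pydantic schema for structured output: the citation step below constrains
+ # the model's reply to this shape so sources come back as a plain list.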
+ class CitingSources(BaseModel):
+     sources: List[str] = Field(
+         ...,
+         description="List of sources to cite. Should be an URL of the source. E.g. GitHub URL, Blogpost URL or Newsletter URL."
+     )
+
+ def write_message_to_user():
+     return "Please write the message to the user."

+ @spaces.GPU(duration=120)
  def respond(
      message,
      history: list[tuple[str, str]],
+     model,
+     system_message,
+     max_tokens,
      temperature,
      top_p,
-     repetition_penalty,
-     max_tokens,
-     search_engine
+     top_k,
+     repeat_penalty,
  ):
-     model = get_model(temperature, top_p, repetition_penalty)
-
-     # Perform web search
-     if search_engine == "Google":
-         search_results = google_search(message)
-     else:
-         search_results = duckduckgo_search(message, max_results=5)
-
-     # Check if we have a FAISS database
-     if os.path.exists("faiss_database"):
-         embed = get_embeddings()
-         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
-         retriever = database.as_retriever()
-         relevant_docs = retriever.get_relevant_documents(message)
-         context_str = "\n".join([doc.page_content for doc in relevant_docs])
-
-         # Use the context in the prompt
-         prompt_template = f"""
-         Answer the question based on the following context and web search results:
-         Context from documents:
-         {context_str}
-
-         Web Search Results:
-         {{search_results}}
-
-         Question: {{message}}
-
-         If the context and web search results don't contain relevant information, state that the information is not available.
-         Provide a concise and direct answer to the question.
-         """
-     else:
-         prompt_template = """
-         Answer the question based on the following web search results:
-         Web Search Results:
-         {search_results}
-
-         Question: {message}
-
-         If the web search results don't contain relevant information, state that the information is not available.
-         Provide a concise and direct answer to the question.
-         """
+     global llm
+     global llm_model
+     chat_template = get_messages_formatter_type(model)
+     if llm is None or llm_model != model:
+         llm = get_model(temperature, top_p, repeat_penalty)
+         llm_model = model
+     provider = LlamaCppPythonProvider(llm)
+     logging.info(f"Loaded chat examples: {chat_template}")
+     search_tool = WebSearchTool(
+         llm_provider=provider,
+         message_formatter_type=chat_template,
+         max_tokens_search_results=12000,
+         max_tokens_per_summary=2048,
+     )

-     prompt = PromptTemplate(
-         input_variables=["search_results", "message"],
-         template=prompt_template
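+     # Two agents share one provider: web_search_agent only picks and runs
+     # the search tool; answer_agent writes the streamed final answer.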
+     web_search_agent = LlamaCppAgent(
+         provider,
+         system_prompt=web_search_system_prompt,
+         predefined_messages_formatter_type=chat_template,
+         debug_output=True,
      )

-     chain = LLMChain(llm=model, prompt=prompt)
+     answer_agent = LlamaCppAgent(
+         provider,
+         system_prompt=research_system_prompt,
+         predefined_messages_formatter_type=chat_template,
+         debug_output=True,
+     )

-     search_results_text = "\n".join([f"- {result['text']}" for result in search_results if result['text']])
-     response = chain.run(search_results=search_results_text, message=message)
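+     # Sampling settings come from the UI sliders; streaming stays off for
+     # the tool-selection call and is re-enabled for the final answer below.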
+     settings = provider.get_provider_default_settings()
+     settings.stream = False
+     settings.temperature = temperature
+     settings.top_k = top_k
+     settings.top_p = top_p
+     settings.max_tokens = max_tokens
+     settings.repeat_penalty = repeat_penalty

-     # Add sources
-     sources = set(result["link"] for result in search_results if result["link"])
-     sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
-     response += sources_section
+     output_settings = LlmStructuredOutputSettings.from_functions(
+         [search_tool.get_tool()]
+     )

-     # Update history and return
-     history.append((message, response))
-     return history
+     messages = BasicChatHistory()
+
+     for msn in history:
+         user = {"role": Roles.user, "content": msn[0]}
+         assistant = {"role": Roles.assistant, "content": msn[1]}
+         messages.add_message(user)
+         messages.add_message(assistant)
+
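+     # This call is constrained to the search tool by the structured-output
+     # settings; result[0]["return_value"] carries the summarized findings.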
+     result = web_search_agent.get_chat_response(
+         message,
+         llm_sampling_settings=settings,
+         structured_output_settings=output_settings,
+         add_message_to_chat_history=False,
+         add_response_to_chat_history=False,
+         print_output=False,
+     )

-     # Gradio interface
-     demo = gr.Blocks()
+     outputs = ""
+
+     settings.stream = True
+     response_text = answer_agent.get_chat_response(
+         f"Write a detailed and complete research document that fulfills the following user request: '{message}', based on the information from the web below.\n\n" +
+         result[0]["return_value"],
+         role=Roles.tool,
+         llm_sampling_settings=settings,
+         chat_history=messages,
+         returns_streaming_generator=True,
+         print_output=False,
+     )
+
+     for text in response_text:
+         outputs += text
+         yield outputs

- with demo:
-     gr.Markdown("# Chat with your PDF documents and Web Search")
-
-     with gr.Row():
-         file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
-         update_button = gr.Button("Upload PDF")
-
-     update_output = gr.Textbox(label="Update Status")
-     update_button.click(update_vectors, inputs=[file_input], outputs=update_output)
-
-     with gr.Row():
-         with gr.Column(scale=2):
-             chatbot = gr.Chatbot(label="Conversation")
-             message_input = gr.Textbox(label="Enter your message")
-             submit_button = gr.Button("Submit")
-         with gr.Column(scale=1):
-             temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
-             top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
-             repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty")
-             max_tokens = gr.Slider(minimum=1, maximum=1000, value=500, step=1, label="Max tokens")
-             search_engine = gr.Dropdown(["DuckDuckGo", "Google"], value="DuckDuckGo", label="Search Engine")
-
-     submit_button.click(
-         respond,
-         inputs=[
-             message_input,
-             gr.State([]),  # Initialize empty history
-             temperature,
-             top_p,
-             repetition_penalty,
-             max_tokens,
-             search_engine
-         ],
-         outputs=[chatbot]
+     output_settings = LlmStructuredOutputSettings.from_pydantic_models(
+         [CitingSources], LlmStructuredOutputType.object_instance
      )
-
-     clear_button = gr.Button("Clear Cache")
-     clear_output = gr.Textbox(label="Cache Status")
-     clear_button.click(clear_cache, inputs=[], outputs=clear_output)

- if __name__ == "__main__":
-     demo.launch()
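+     # Second structured call: the reply is parsed into a CitingSources
+     # instance, so .sources is already a list of URL strings.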
+     citing_sources = answer_agent.get_chat_response(
+         "Cite the sources you used in your response.",
+         role=Roles.tool,
+         llm_sampling_settings=settings,
+         chat_history=messages,
+         returns_streaming_generator=False,
+         structured_output_settings=output_settings,
+         print_output=False,
+     )
+     outputs += "\n\nSources:\n"
+     outputs += "\n".join(citing_sources.sources)
+     yield outputs
+
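+ # gr.ChatInterface passes (message, history) plus each additional input, in
+ # this order, to respond(); the widget order must match the signature.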
+ demo = gr.ChatInterface(
+     respond,
+     additional_inputs=[
+         gr.Dropdown([
+             'Mistral-7B-Instruct-v0.3'
+         ],
+             value="Mistral-7B-Instruct-v0.3",
+             label="Model"
+         ),
+         gr.Textbox(value=web_search_system_prompt, label="System message"),
+         gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens"),
+         gr.Slider(minimum=0.1, maximum=1.0, value=0.45, step=0.1, label="Temperature"),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.95,
+             step=0.05,
+             label="Top-p",
+         ),
+         gr.Slider(
+             minimum=0,
+             maximum=100,
+             value=40,
+             step=1,
+             label="Top-k",
+         ),
+         gr.Slider(
+             minimum=0.0,
+             maximum=2.0,
+             value=1.1,
+             step=0.1,
+             label="Repetition penalty",
+         ),
+     ],
+     theme=gr.themes.Soft(
+         primary_hue="orange",
+         secondary_hue="amber",
+         neutral_hue="gray",
+         font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]).set(
+         body_background_fill_dark="#0c0505",
+         block_background_fill_dark="#0c0505",
+         block_border_width="1px",
+         block_title_background_fill_dark="#1b0f0f",
+         input_background_fill_dark="#140b0b",
+         button_secondary_background_fill_dark="#140b0b",
+         border_color_accent_dark="#1b0f0f",
+         border_color_primary_dark="#1b0f0f",
+         slider_color="#ff911a",
+         button_primary_background_fill="#ff911a",
+         button_primary_background_fill_dark="#ff911a",
+         button_primary_text_color="#f9f9f9",
+         button_primary_text_color_dark="#f9f9f9"
+     ),
+     examples=examples,
+     title="llama.cpp agent",
+ )
+
+ demo.queue().launch()