Shreyas094
committed on
Commit
•
6b3b427
1
Parent(s):
3449685
Update app.py
Browse files
app.py
CHANGED
@@ -67,7 +67,7 @@ def load_document(file: NamedTemporaryFile, parser: str = "llamaparse") -> List[
|
|
67 |
raise ValueError("Invalid parser specified. Use 'pypdf' or 'llamaparse'.")
|
68 |
|
69 |
def get_embeddings():
|
70 |
-
return HuggingFaceEmbeddings(model_name="
|
71 |
|
72 |
# Add this at the beginning of your script, after imports
|
73 |
DOCUMENTS_FILE = "uploaded_documents.json"
|
@@ -271,10 +271,33 @@ def generate_chunked_response(prompt, model, max_tokens=10000, num_calls=3, temp
|
|
271 |
print(f"Final clean response: {final_response[:100]}...")
|
272 |
return final_response
|
273 |
|
274 |
-
|
275 |
-
|
276 |
-
results =
|
277 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
278 |
|
279 |
class CitingSources(BaseModel):
|
280 |
sources: List[str] = Field(
|
@@ -376,7 +399,7 @@ def get_context_for_summary(selected_docs):
|
|
376 |
embed = get_embeddings()
|
377 |
if os.path.exists("faiss_database"):
|
378 |
database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
|
379 |
-
retriever = database.as_retriever(search_kwargs={"k":
|
380 |
|
381 |
# Create a generic query that covers common financial summary topics
|
382 |
generic_query = "financial performance revenue profit assets liabilities cash flow key metrics highlights"
|
@@ -409,36 +432,6 @@ def get_context_for_query(query, selected_docs):
|
|
409 |
else:
|
410 |
return "No documents available to answer the query."
|
411 |
|
412 |
-
def validate_response(initial_response, context, query, model, temperature=0.1):
    """Re-check *initial_response* against *context* and stream a revised answer.

    Builds a validation prompt asking the model to strip hallucinations and
    unverifiable claims from the initial response, then streams progressively
    longer revised responses.

    Yields:
        str: the revised response accumulated so far (streaming generator).
    """
    validation_prompt = f"""Given the following context and initial response to the query "{query}":

Context:
{context}

Initial Response:
{initial_response}

You are an expert assistant tasked with carefully validating the initial response against the provided context. Remove any hallucinations, irrelevant details, or factually incorrect information. Generate a revised response that is accurate and directly supported by the context. If any information cannot be verified from the context, explicitly state that it could not be confirmed. After writing the revised response, provide a list of all sources used.

Revised Response:
"""

    if model == "@cf/meta/llama-3.1-8b-instruct":
        # BUG FIX: this function is a generator (it uses `yield` below), so a
        # plain `return get_response_from_cloudflare(...)` would merely end the
        # generator without emitting any output to callers iterating over it.
        # Delegate to the Cloudflare stream instead so its chunks are yielded.
        yield from get_response_from_cloudflare(prompt=validation_prompt, context="", query="", num_calls=1, temperature=temperature, search_type="validation")
    else:
        client = InferenceClient(model, token=huggingface_token)
        revised_response = ""
        for message in client.chat_completion(
            messages=[{"role": "user", "content": validation_prompt}],
            max_tokens=10000,
            temperature=temperature,
            stream=True,
        ):
            # Guard against keep-alive/empty streaming chunks.
            if message.choices and message.choices[0].delta and message.choices[0].delta.content:
                chunk = message.choices[0].delta.content
                revised_response += chunk
                yield revised_response
|
441 |
-
|
442 |
def get_response_from_cloudflare(prompt, context, query, num_calls=3, temperature=0.2, search_type="pdf"):
|
443 |
headers = {
|
444 |
"Authorization": f"Bearer {API_TOKEN}",
|
@@ -450,19 +443,15 @@ def get_response_from_cloudflare(prompt, context, query, num_calls=3, temperatur
|
|
450 |
instruction = f"""Using the following context from the PDF documents:
|
451 |
{context}
|
452 |
Write a detailed and complete response that answers the following user question: '{query}'"""
|
453 |
-
|
454 |
instruction = f"""Using the following context:
|
455 |
{context}
|
456 |
Write a detailed and complete research document that fulfills the following user request: '{query}'
|
457 |
After writing the document, please provide a list of sources used in your response."""
|
458 |
-
elif search_type == "validation":
|
459 |
-
instruction = prompt # For validation, use the provided prompt directly
|
460 |
-
else:
|
461 |
-
raise ValueError("Invalid search_type")
|
462 |
|
463 |
inputs = [
|
464 |
{"role": "system", "content": instruction},
|
465 |
-
{"role": "user", "content": query
|
466 |
]
|
467 |
|
468 |
payload = {
|
@@ -509,35 +498,30 @@ def create_web_search_vectors(search_results):
|
|
509 |
|
510 |
return FAISS.from_documents(documents, embed)
|
511 |
|
512 |
-
def get_response_with_search(query, model, num_calls=3, temperature=0.
|
513 |
-
|
514 |
-
|
515 |
|
516 |
-
|
517 |
-
|
518 |
-
return
|
519 |
-
|
520 |
-
retriever = web_search_database.as_retriever(search_kwargs={"k": 10})
|
521 |
-
relevant_docs = retriever.get_relevant_documents(query)
|
522 |
|
523 |
-
|
|
|
524 |
|
525 |
prompt = f"""Using the following context from web search results:
|
526 |
{context}
|
527 |
-
|
528 |
-
|
529 |
-
Importantly, only include information that is directly supported by the retrieved content.
|
530 |
-
If any part of the information cannot be verified from the given sources, clearly state that it could not be confirmed."""
|
531 |
|
532 |
-
initial_response = ""
|
533 |
if model == "@cf/meta/llama-3.1-8b-instruct":
|
534 |
# Use Cloudflare API
|
535 |
for response in get_response_from_cloudflare(prompt="", context=context, query=query, num_calls=num_calls, temperature=temperature, search_type="web"):
|
536 |
-
|
537 |
else:
|
538 |
# Use Hugging Face API
|
539 |
client = InferenceClient(model, token=huggingface_token)
|
540 |
|
|
|
541 |
for i in range(num_calls):
|
542 |
for message in client.chat_completion(
|
543 |
messages=[{"role": "user", "content": prompt}],
|
@@ -547,17 +531,14 @@ If any part of the information cannot be verified from the given sources, clearl
|
|
547 |
):
|
548 |
if message.choices and message.choices[0].delta and message.choices[0].delta.content:
|
549 |
chunk = message.choices[0].delta.content
|
550 |
-
|
551 |
-
|
552 |
-
# Validation step
|
553 |
-
for revised_response in validate_response(initial_response, context, query, model, temperature):
|
554 |
-
yield revised_response, "" # Yield streaming revised response without sources
|
555 |
|
556 |
|
557 |
INSTRUCTION_PROMPTS = {
|
558 |
-
"Asset Managers": "
|
559 |
-
"Consumer Finance Companies": "
|
560 |
-
"Mortgage REITs": "
|
561 |
# Add more instruction prompts as needed
|
562 |
}
|
563 |
|
|
|
67 |
raise ValueError("Invalid parser specified. Use 'pypdf' or 'llamaparse'.")
|
68 |
|
69 |
def get_embeddings():
    """Return the HuggingFace sentence-embedding model shared by all FAISS
    vector stores in this app (document index and web-search index)."""
    return HuggingFaceEmbeddings(model_name="sentence-transformers/stsb-roberta-large")
|
71 |
|
72 |
# Add this at the beginning of your script, after imports
|
73 |
DOCUMENTS_FILE = "uploaded_documents.json"
|
|
|
271 |
print(f"Final clean response: {final_response[:100]}...")
|
272 |
return final_response
|
273 |
|
274 |
+
class SimpleDDGSearch:
    """Thin wrapper around DuckDuckGo text search that returns result URLs."""

    def search(self, query: str, num_results: int = 5):
        """Return up to *num_results* result URLs for *query*.

        Note: `DDGS().text()` may return None (or an empty iterable) when the
        search fails or is rate-limited; treat that as "no results" instead of
        letting the comprehension raise a TypeError.
        """
        results = DDGS().text(query, region='wt-wt', safesearch='off', max_results=num_results)
        return [res["href"] for res in results or []]
278 |
+
|
279 |
+
class TrafilaturaWebCrawler:
    """Downloads a web page and extracts its readable text via trafilatura."""

    def get_website_content_from_url(self, url: str) -> str:
        """Return a banner-formatted string containing the page title, URL and
        extracted body text, or a human-readable error message on any failure.
        Never raises: all exceptions are converted into an error string."""
        try:
            raw_html = fetch_url(url)
            if raw_html is None:
                return f"Failed to fetch content from URL: {url}"

            extracted_json = extract(raw_html, output_format='json', include_comments=False, with_metadata=True, url=url)
            if not extracted_json:
                return f"No content extracted from URL: {url}"

            metadata = json.loads(extracted_json)
            title = metadata.get('title', 'No title found')
            content = metadata.get('text', 'No content extracted')

            # Fall back to a plain-text extraction pass when the JSON
            # extraction produced no body text.
            if content == 'No content extracted':
                content = extract(raw_html, include_comments=False)

            return f'=========== Website Title: {title} ===========\n\n=========== Website URL: {url} ===========\n\n=========== Website Content ===========\n\n{content}\n\n=========== Website Content End ===========\n\n'
        except Exception as e:
            return f"An error occurred while processing {url}: {str(e)}"
|
301 |
|
302 |
class CitingSources(BaseModel):
|
303 |
sources: List[str] = Field(
|
|
|
399 |
embed = get_embeddings()
|
400 |
if os.path.exists("faiss_database"):
|
401 |
database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
|
402 |
+
retriever = database.as_retriever(search_kwargs={"k": 5}) # Retrieve top 5 most relevant chunks
|
403 |
|
404 |
# Create a generic query that covers common financial summary topics
|
405 |
generic_query = "financial performance revenue profit assets liabilities cash flow key metrics highlights"
|
|
|
432 |
else:
|
433 |
return "No documents available to answer the query."
|
434 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
435 |
def get_response_from_cloudflare(prompt, context, query, num_calls=3, temperature=0.2, search_type="pdf"):
|
436 |
headers = {
|
437 |
"Authorization": f"Bearer {API_TOKEN}",
|
|
|
443 |
instruction = f"""Using the following context from the PDF documents:
|
444 |
{context}
|
445 |
Write a detailed and complete response that answers the following user question: '{query}'"""
|
446 |
+
else: # web search
|
447 |
instruction = f"""Using the following context:
|
448 |
{context}
|
449 |
Write a detailed and complete research document that fulfills the following user request: '{query}'
|
450 |
After writing the document, please provide a list of sources used in your response."""
|
|
|
|
|
|
|
|
|
451 |
|
452 |
inputs = [
|
453 |
{"role": "system", "content": instruction},
|
454 |
+
{"role": "user", "content": query}
|
455 |
]
|
456 |
|
457 |
payload = {
|
|
|
498 |
|
499 |
return FAISS.from_documents(documents, embed)
|
500 |
|
501 |
+
def get_response_with_search(query, model, num_calls=3, temperature=0.2):
|
502 |
+
searcher = SimpleDDGSearch()
|
503 |
+
search_results = searcher.search(query, num_results=5)
|
504 |
|
505 |
+
crawler = TrafilaturaWebCrawler()
|
506 |
+
context = ""
|
|
|
|
|
|
|
|
|
507 |
|
508 |
+
for url in search_results:
|
509 |
+
context += crawler.get_website_content_from_url(url) + "\n"
|
510 |
|
511 |
prompt = f"""Using the following context from web search results:
|
512 |
{context}
|
513 |
+
Write a detailed and complete research document that fulfills the following user request: '{query}'
|
514 |
+
After writing the document, please provide a list of sources used in your response."""
|
|
|
|
|
515 |
|
|
|
516 |
if model == "@cf/meta/llama-3.1-8b-instruct":
|
517 |
# Use Cloudflare API
|
518 |
for response in get_response_from_cloudflare(prompt="", context=context, query=query, num_calls=num_calls, temperature=temperature, search_type="web"):
|
519 |
+
yield response, "" # Yield streaming response without sources
|
520 |
else:
|
521 |
# Use Hugging Face API
|
522 |
client = InferenceClient(model, token=huggingface_token)
|
523 |
|
524 |
+
main_content = ""
|
525 |
for i in range(num_calls):
|
526 |
for message in client.chat_completion(
|
527 |
messages=[{"role": "user", "content": prompt}],
|
|
|
531 |
):
|
532 |
if message.choices and message.choices[0].delta and message.choices[0].delta.content:
|
533 |
chunk = message.choices[0].delta.content
|
534 |
+
main_content += chunk
|
535 |
+
yield main_content, "" # Yield partial main content without sources
|
|
|
|
|
|
|
536 |
|
537 |
|
538 |
# Per-industry summarization instructions, keyed by company category.
# The selected key's value is used as the prompt when generating a
# financial summary for a document of that category.
INSTRUCTION_PROMPTS = {
    "Asset Managers": "Summarize the key financial metrics, assets under management, and performance highlights for this asset management company.",
    "Consumer Finance Companies": "Provide a summary of the company's loan portfolio, interest income, credit quality, and key operational metrics.",
    "Mortgage REITs": "Summarize the REIT's mortgage-backed securities portfolio, net interest income, book value per share, and dividend yield.",
    # Add more instruction prompts as needed
}
|
544 |
|