# arxiv-RAG / app.py
# Author: jharrison27
# Last commit: bae5df9 "Handle no colbert indexes"
import torch
import transformers
import gradio as gr
from ragatouille import RAGPretrainedModel
import re
from datetime import datetime
import json
import arxiv
from helper import rag_cleaner, get_prompt_text, get_references, get_rag, SaveResponseAndRead, get_md_text_abstract, search_cleaner, get_arxiv_live_search
# Constants
# How many results are requested from the retriever / live arXiv search per query.
RETRIEVE_RESULTS = 20
DEFAULT_LLM_MODEL = 'mistralai/Mistral-7B-Instruct-v0.2'
# Models offered in the UI dropdown. The literal string 'None' is a sentinel
# choice that ask_llm treats as "LLM disabled".
LLM_MODELS = [
    'mistralai/Mixtral-8x7B-Instruct-v0.1',
    DEFAULT_LLM_MODEL,
    'google/gemma-7b-it',
    'None',
]
# Keyword arguments forwarded to the text-generation call. do_sample=False
# requests non-sampled generation, so temperature/top_p are left unset.
GENERATE_KWARGS = dict(
    temperature=None,
    max_new_tokens=512,
    top_p=None,
    do_sample=False,
)
# RAG Model setup: load the local ColBERT index. Semantic search is offered in
# the UI only when the index loads.
try:
    RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
except FileNotFoundError:
    # No index on disk: the app runs with live arXiv search only.
    RAG = None
    semantic_search_available = False
    gr.Warning("Colbert index not found. Semantic search will be unavailable.")
else:
    semantic_search_available = True
    # One-off warm-up query so a broken retriever is reported at startup.
    # A failure here only warns; semantic search stays enabled.
    try:
        gr.Info("Setting up retriever, please wait...")
        rag_initial_output = RAG.search("What is Generative AI in Healthcare?", k=1)
        gr.Info("Retriever working successfully!")
    except Exception as exc:
        gr.Warning(f"Retriever not working: {str(exc)}")
# Header setup
mark_text = '# 🩺🔍 Search Results\n'
header_text = "## Arxiv Paper Summary With QA Retrieval Augmented Generation \n"


def parse_index_date(md_text):
    """Extract the "Index Last Updated" date from README text.

    Looks for a line of the form ``Index Last Updated : YYYY-MM-DD``.

    Args:
        md_text: Full README contents (may be empty).

    Returns:
        The date reformatted as ``'%d %b %Y'`` (e.g. ``'05 Mar 2024'``), or
        ``None`` when the marker is absent or the date is not a real calendar
        date. The original code crashed in both cases, with AttributeError on a
        missing match and ValueError from strptime.
    """
    found = re.search(r'Index Last Updated : (\d{4}-\d{2}-\d{2})', md_text)
    if found is None:
        return None
    try:
        return datetime.strptime(found.group(1), '%Y-%m-%d').strftime('%d %b %Y')
    except ValueError:
        return None


# Label for the semantic-search source; it gains the index date when README
# provides one.
index_info = "Semantic Search"
try:
    with open("README.md", "r", encoding="utf-8") as f:
        mdfile = f.read()
except FileNotFoundError:
    mdfile = ""
formatted_date = parse_index_date(mdfile)
if formatted_date is not None:
    header_text += f'Index Last Updated: {formatted_date}\n'
    index_info = f"Semantic Search - up to {formatted_date}"
# Search sources shown in the UI dropdown: the semantic-index entry is listed
# first, and only when the ColBERT index loaded.
database_choices = ([index_info] if semantic_search_available else []) + ['Arxiv Search - Latest']
# Arxiv API setup: probe live arXiv search once at startup and drop it from the
# UI choices if it returns nothing.
arx_client = arxiv.Client()
is_arxiv_available = True
try:
    check_arxiv_result = get_arxiv_live_search("What is Self Rewarding AI and how can it be used in Multi-Agent Systems?", arx_client, RETRIEVE_RESULTS)
except Exception as exc:
    # A network/API error during the probe must not crash app startup; treat it
    # like an empty result.
    print(f"Arxiv search probe raised an error: {exc}")
    check_arxiv_result = []
if not check_arxiv_result:
    is_arxiv_available = False
    if semantic_search_available:
        print("Arxiv search not working, switching to default search ...")
        database_choices = [index_info]
    else:
        # Without a semantic index, replacing the choices with index_info would
        # only mislabel a source that still routes to arXiv; keep the live
        # search so the user can retry it.
        print("Arxiv search not working and no semantic index is available; keeping live search so it can be retried.")
# Gradio UI setup
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    header = gr.Markdown(header_text)
    with gr.Group():
        search_query = gr.Textbox(label='Search', placeholder='What is Generative AI in Healthcare?')
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row(equal_height=True):
                llm_model = gr.Dropdown(choices=LLM_MODELS, value=DEFAULT_LLM_MODEL, label='LLM Model')
                llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results as context")
                # Default to the first available source. index_info is not among
                # the choices when the ColBERT index is missing, and a value
                # outside the choices is rejected by the Dropdown.
                database_src = gr.Dropdown(choices=database_choices, value=database_choices[0], label='Search Source')
                stream_results = gr.Checkbox(value=True, label="Stream output", visible=False)
    output_text = gr.Textbox(show_label=True, container=True, label='LLM Answer', visible=True)
    # Hidden textbox that carries the prompt from update_with_rag_md to ask_llm.
    # NOTE(review): the name shadows the builtin `input`; kept because the event
    # wiring refers to it.
    input = gr.Textbox(show_label=False, visible=False)
    gr_md = gr.Markdown(mark_text)
def _retrieve_papers(search_query, database_choice):
    """Return ``(results, source_label)`` from the search backend selected in the UI.

    Semantic search is used only when ``database_choice`` is the semantic-index
    label and the ColBERT index loaded. Otherwise live arXiv search is used.

    Raises:
        gr.Error: if live arXiv search raises or returns no results. Raising
            (rather than returning empty strings) stops the chained
            ``.success(ask_llm, ...)`` step from running on an empty prompt.
    """
    if database_choice == index_info and semantic_search_available:
        return get_rag(search_query, RAG, RETRIEVE_RESULTS), 'Semantic Search'
    try:
        rag_out = get_arxiv_live_search(search_query, arx_client, RETRIEVE_RESULTS)
    except Exception as e:
        gr.Warning(f"Arxiv Search not working: {str(e)}")
        rag_out = None
    if not rag_out:
        raise gr.Error("Arxiv search failed. Please try again later.")
    return rag_out, 'Arxiv Search'


def update_with_rag_md(search_query, llm_results_use=5, database_choice=index_info, llm_model_picked=DEFAULT_LLM_MODEL):
    """Retrieve papers for a query and build the results markdown and the LLM prompt.

    Args:
        search_query: The user's query text.
        llm_results_use: How many of the top results are formatted into the LLM
            prompt context; the remaining results are only rendered as markdown.
        database_choice: Label of the selected search source (see
            ``_retrieve_papers``).
        llm_model_picked: Model id forwarded to ``get_prompt_text``.

    Returns:
        ``(markdown, prompt)``: the rendered results (prefixed with
        ``mark_text``) and the prompt text for ``ask_llm``.

    Raises:
        gr.Error: when live arXiv search fails (see ``_retrieve_papers``).
    """
    rag_out, database_to_use = _retrieve_papers(search_query, database_choice)
    md_parts = [mark_text]
    prompt_parts = []
    for i, rag_answer in enumerate(rag_out):
        if i < llm_results_use:
            md_text_paper, prompt_text = get_md_text_abstract(rag_answer, source=database_to_use, return_prompt_formatting=True)
            # Numbered from 1 so the LLM can cite results by position.
            prompt_parts.append(f"{i+1}. {prompt_text}")
        else:
            md_text_paper = get_md_text_abstract(rag_answer, source=database_to_use)
        md_parts.append(md_text_paper)
    prompt = get_prompt_text(search_query, "".join(prompt_parts), llm_model_picked=llm_model_picked)
    return "".join(md_parts), prompt
def ask_llm(prompt, llm_model_picked=DEFAULT_LLM_MODEL, stream_outputs=False):
    """Generate an answer for ``prompt`` and yield text for the answer textbox.

    This function is a generator, so every result must be *yielded*. A
    ``return value`` inside a generator only sets ``StopIteration.value``, which
    the consumer discards. The original ``return model_disabled_text`` /
    ``return output`` therefore never reached the UI in non-streaming mode.

    Args:
        prompt: Prompt text built by ``update_with_rag_md``.
        llm_model_picked: Model id; the string ``'None'`` disables generation.
        stream_outputs: When True, yield the growing answer token by token.

    Yields:
        The answer text. When streaming, each value is the cumulative output
        passed through ``SaveResponseAndRead``.
    """
    model_disabled_text = "LLM Model is disabled"
    if llm_model_picked == 'None':
        if stream_outputs:
            output = ""
            for ch in model_disabled_text:
                output += ch
                yield output
        else:
            yield model_disabled_text
        # Stop here; previously execution fell through and tried to query a
        # model literally named 'None'.
        return
    output = ""
    try:
        # NOTE(review): `InferenceClient` is not imported anywhere in this file;
        # presumably it should be `from huggingface_hub import InferenceClient`.
        # Add that import. Until then the NameError is reported through the
        # warning below instead of escaping the handler.
        client = InferenceClient(llm_model_picked)
        response = client.text_generation(prompt, stream=stream_outputs, details=False, return_full_text=False, **GENERATE_KWARGS)
        if stream_outputs:
            for token in response:
                output += token
                yield SaveResponseAndRead(output)
        else:
            yield response
    except Exception as e:
        gr.Warning(f"LLM Inference failed: {str(e)}")
        # In streaming mode keep whatever partial answer was already shown;
        # otherwise clear the textbox.
        if not stream_outputs:
            yield ""
# Event wiring. Listeners must be registered inside a Blocks context, so
# re-enter `demo`. On submit, retrieve papers and build the prompt; only if that
# step succeeds (no exception), stream the LLM answer.
with demo:
    retrieval_event = search_query.submit(
        update_with_rag_md,
        [search_query, llm_results, database_src, llm_model],
        [gr_md, input],
    )
    retrieval_event.success(ask_llm, [input, llm_model, stream_results], output_text)
demo.queue().launch()