Spaces:
Runtime error
Runtime error
import json
import re
from datetime import datetime

import arxiv
import gradio as gr
import torch
import transformers
from huggingface_hub import InferenceClient
from ragatouille import RAGPretrainedModel

from helper import (
    SaveResponseAndRead,
    get_arxiv_live_search,
    get_md_text_abstract,
    get_prompt_text,
    get_rag,
    get_references,
    rag_cleaner,
    search_cleaner,
)
# Constants
# Retrieval depth: how many candidate papers each search back-end returns.
RETRIEVE_RESULTS = 20

# Model ids selectable in the "LLM Model" dropdown. The literal string 'None'
# is a sentinel that turns the LLM answer step off (checked in ask_llm).
LLM_MODELS = [
    'mistralai/Mixtral-8x7B-Instruct-v0.1',
    'mistralai/Mistral-7B-Instruct-v0.2',
    'google/gemma-7b-it',
    'None',
]
DEFAULT_LLM_MODEL = 'mistralai/Mistral-7B-Instruct-v0.2'

# Extra keyword arguments forwarded to the text-generation request.
# Sampling is off (do_sample=False); temperature and top_p are passed as None.
GENERATE_KWARGS = dict(
    temperature=None,
    max_new_tokens=512,
    top_p=None,
    do_sample=False,
)
# RAG Model setup
# Load the prebuilt ColBERT index, then run one smoke-test query so a broken
# retriever is detected at startup instead of on the first user request.
try:
    RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
    semantic_search_available = True
    try:
        gr.Info("Setting up retriever, please wait...")
        rag_initial_output = RAG.search("What is Generative AI in Healthcare?", k=1)
        gr.Info("Retriever working successfully!")
    except Exception as e:
        # The index loaded but cannot answer a query. Stop offering semantic
        # search so users are not routed to a retriever that fails every time.
        semantic_search_available = False
        gr.Warning(f"Retriever not working: {str(e)}")
except FileNotFoundError:
    RAG = None
    semantic_search_available = False
    gr.Warning("Colbert index not found. Semantic search will be unavailable.")
# Header setup
mark_text = '# 🩺🔍 Search Results\n'
header_text = "## Arxiv Paper Summary With QA Retrieval Augmented Generation \n"


def parse_index_date(md_text):
    """Return the 'Index Last Updated' date found in *md_text* as 'DD Mon YYYY'.

    Looks for a line fragment of the form ``Index Last Updated : YYYY-MM-DD``.

    Args:
        md_text: README contents (may be empty).

    Returns:
        The date reformatted with ``'%d %b %Y'``, or None when no such
        fragment exists or the date is not a valid calendar date.
    """
    match = re.search(r'Index Last Updated : \d{4}-\d{2}-\d{2}', md_text)
    if match is None:
        return None
    date = match.group().split(': ')[1]
    try:
        return datetime.strptime(date, '%Y-%m-%d').strftime('%d %b %Y')
    except ValueError:
        # e.g. "2024-13-40" matches the regex but is not a real date.
        return None


# Fallback label for the semantic-search option. It is refined below only when
# README.md records a usable index date, so a missing or odd README can no
# longer leave index_info undefined.
index_info = "Semantic Search"
try:
    with open("README.md", "r", encoding="utf-8") as f:
        mdfile = f.read()
except (OSError, UnicodeDecodeError):
    # Best effort: the header is cosmetic, so an unreadable README only drops the date.
    mdfile = ""
formatted_date = parse_index_date(mdfile)
if formatted_date is not None:
    header_text += f'Index Last Updated: {formatted_date}\n'
    index_info = f"Semantic Search - up to {formatted_date}"
# Search back-ends offered in the "Search Source" dropdown. The live arXiv
# search is always listed; the semantic (ColBERT) option is put in front of it
# only when the index is usable.
database_choices = ['Arxiv Search - Latest']
if semantic_search_available:
    database_choices.insert(0, index_info)
# Arxiv API setup
arx_client = arxiv.Client()
is_arxiv_available = True
# Probe the live arXiv API once at startup. The probe is a network call, so
# guard it like the per-request search does: an outage must degrade the
# available search sources, not crash the app on import.
try:
    check_arxiv_result = get_arxiv_live_search(
        "What is Self Rewarding AI and how can it be used in Multi-Agent Systems?",
        arx_client,
        RETRIEVE_RESULTS,
    )
except Exception as e:
    print(f"Arxiv search check failed: {e}")
    check_arxiv_result = []
if not check_arxiv_result:
    is_arxiv_available = False
    print("Arxiv search not working, switching to default search ...")
    # NOTE(review): if semantic search is also unavailable, this still offers
    # index_info, and requests then fall back to the (failing) arXiv path.
    database_choices = [index_info]
# Gradio UI setup
# NOTE(review): the container nesting below was reconstructed from a source
# whose indentation was lost; confirm which components belong in the
# Group/Accordion/Row.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    header = gr.Markdown(header_text)
    with gr.Group():
        search_query = gr.Textbox(label='Search', placeholder='What is Generative AI in Healthcare?')
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row(equal_height=True):
                llm_model = gr.Dropdown(choices=LLM_MODELS, value=DEFAULT_LLM_MODEL, label='LLM Model')
                llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results as context")
                # Default to the first offered source. index_info is not among
                # the choices when the semantic index is unavailable, and a
                # Dropdown value outside its choices is invalid.
                database_src = gr.Dropdown(choices=database_choices, value=database_choices[0], label='Search Source')
                stream_results = gr.Checkbox(value=True, label="Stream output", visible=False)
    output_text = gr.Textbox(show_label=True, container=True, label='LLM Answer', visible=True)
    # Hidden textbox that carries the generated prompt from the retrieval step
    # to the LLM step. Its name shadows the builtin `input`; it is kept because
    # the event wiring refers to it.
    input = gr.Textbox(show_label=False, visible=False)
    gr_md = gr.Markdown(mark_text)
def update_with_rag_md(search_query, llm_results_use=5, database_choice=index_info, llm_model_picked=DEFAULT_LLM_MODEL):
    """Retrieve papers for a query and build the results markdown and the LLM prompt.

    Args:
        search_query: Free-text query typed by the user.
        llm_results_use: How many of the top hits are folded into the LLM
            prompt. All hits are still rendered in the markdown listing.
        database_choice: The "Search Source" dropdown value. ``index_info``
            selects the ColBERT semantic index when it is available; anything
            else uses the live arXiv search.
        llm_model_picked: Model id forwarded to ``get_prompt_text``.

    Returns:
        A ``(markdown, prompt)`` tuple for the results pane and the hidden
        prompt textbox.

    Raises:
        gr.Error: When the live arXiv search raises or returns nothing.
            Raising (rather than returning empty strings) stops the chained
            ``.success(ask_llm, ...)`` step, so the LLM is never called with
            an empty prompt.
    """
    if database_choice == index_info and semantic_search_available:
        rag_out = get_rag(search_query, RAG, RETRIEVE_RESULTS)
        database_to_use = 'Semantic Search'
    else:
        try:
            rag_out = get_arxiv_live_search(search_query, arx_client, RETRIEVE_RESULTS)
        except Exception as e:
            raise gr.Error(f"Arxiv search failed: {str(e)}. Please try again later.") from e
        if not rag_out:
            raise gr.Error("Arxiv search failed. Please try again later.")
        database_to_use = 'Arxiv Search'

    md_parts = [mark_text]
    prompt_parts = []
    for i, rag_answer in enumerate(rag_out):
        if i < llm_results_use:
            # The top hits are also formatted for the prompt, numbered from 1.
            md_text_paper, prompt_text = get_md_text_abstract(rag_answer, source=database_to_use, return_prompt_formatting=True)
            prompt_parts.append(f"{i+1}. {prompt_text}")
        else:
            md_text_paper = get_md_text_abstract(rag_answer, source=database_to_use)
        md_parts.append(md_text_paper)
    prompt = get_prompt_text(search_query, "".join(prompt_parts), llm_model_picked=llm_model_picked)
    return "".join(md_parts), prompt
def ask_llm(prompt, llm_model_picked=DEFAULT_LLM_MODEL, stream_outputs=False):
    """Generate an answer for *prompt* with the selected Hugging Face model.

    This function is a generator: every result, streamed or not, must be
    *yielded*. A ``return value`` inside a generator is discarded by the
    consumer, which is why the earlier non-streaming paths showed nothing.

    Args:
        prompt: The formatted prompt produced by the retrieval step.
        llm_model_picked: Model id passed to ``InferenceClient``. The string
            'None' disables generation.
        stream_outputs: When True, yield the growing answer token by token.
            Otherwise, yield the complete answer once.

    Yields:
        The partial or complete answer text. Yields "" when the prompt is
        empty or inference fails.
    """
    model_disabled_text = "LLM Model is disabled"
    if llm_model_picked == 'None':
        if stream_outputs:
            # Reveal the notice character by character to mimic streaming.
            output = ""
            for out in model_disabled_text:
                output += out
                yield output
        else:
            yield model_disabled_text
        # Stop here. Previously the streaming path fell through and queried
        # the inference API with the model id 'None'.
        return
    if not prompt:
        # Nothing to answer (e.g. the retrieval step produced no prompt), so
        # skip the inference call entirely.
        yield ""
        return
    client = InferenceClient(llm_model_picked)
    output = ""
    try:
        response = client.text_generation(prompt, stream=stream_outputs, details=False, return_full_text=False, **GENERATE_KWARGS)
        if stream_outputs:
            for token in response:
                output += token
                yield SaveResponseAndRead(output)
        else:
            yield response
    except Exception as e:
        gr.Warning(f"LLM Inference failed: {str(e)}")
        yield ""
# Event listeners must be registered inside a Blocks context. Re-entering
# `demo` makes that explicit regardless of where the layout block ends.
with demo:
    # Step 1 retrieves papers, renders them into gr_md, and writes the LLM
    # prompt into the hidden `input` textbox. Step 2 (ask_llm) runs only if
    # step 1 finished without raising, and sends its answer to output_text.
    search_query.submit(
        update_with_rag_md,
        [search_query, llm_results, database_src, llm_model],
        [gr_md, input],
    ).success(ask_llm, [input, llm_model, stream_results], output_text)
# Enable the request queue and start the web server.
demo.queue().launch()