Spaces:
Runtime error
Runtime error
File size: 6,172 Bytes
3701fee 3dbe475 3701fee bae5df9 3701fee 3dbe475 3701fee 3dbe475 bae5df9 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 bae5df9 3dbe475 bae5df9 3701fee 3dbe475 bae5df9 3dbe475 bae5df9 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import json
import re
from datetime import datetime

import torch
import transformers
import gradio as gr
from ragatouille import RAGPretrainedModel
import arxiv
from huggingface_hub import InferenceClient

from helper import rag_cleaner, get_prompt_text, get_references, get_rag, SaveResponseAndRead, get_md_text_abstract, search_cleaner, get_arxiv_live_search
# Constants
# How many candidate papers each search backend is asked to return.
RETRIEVE_RESULTS = 20

# Model ids offered in the "LLM Model" dropdown. The literal string 'None'
# is a sentinel checked by ask_llm to skip generation entirely.
LLM_MODELS = [
    'mistralai/Mixtral-8x7B-Instruct-v0.1',
    'mistralai/Mistral-7B-Instruct-v0.2',
    'google/gemma-7b-it',
    'None',
]
DEFAULT_LLM_MODEL = LLM_MODELS[1]

# Extra keyword arguments forwarded to the text-generation call. Sampling is
# switched off, so temperature/top_p are deliberately left unset (None).
GENERATE_KWARGS = dict(
    temperature=None,
    max_new_tokens=512,
    top_p=None,
    do_sample=False,
)
# RAG Model setup
# Load the prebuilt ColBERT index and decide whether semantic search can be
# offered in the UI (exposed via the module-level `semantic_search_available`).
try:
    RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
except FileNotFoundError:
    # No index on disk: the app runs with live arXiv search only.
    RAG = None
    semantic_search_available = False
    gr.Warning("Colbert index not found. Semantic search will be unavailable.")
else:
    # Smoke-test the loaded index with a single query. If it already fails
    # here, user queries routed to it would most likely fail the same way, so
    # stop offering semantic search instead of only warning about it.
    try:
        gr.Info("Setting up retriever, please wait...")
        rag_initial_output = RAG.search("What is Generative AI in Healthcare?", k=1)
    except Exception as e:
        semantic_search_available = False
        gr.Warning(f"Retriever not working: {str(e)}")
    else:
        semantic_search_available = True
        gr.Info("Retriever working successfully!")
# Header setup
mark_text = '# 🩺🔍 Search Results\n'
header_text = "## Arxiv Paper Summary With QA Retrieval Augmented Generation \n"

# Label of the semantic-search source in the dropdown. It is upgraded to carry
# the index date when README.md contains a valid "Index Last Updated : YYYY-MM-DD"
# line; otherwise this undated fallback is used.
index_info = "Semantic Search"
date_pattern = r'Index Last Updated : \d{4}-\d{2}-\d{2}'
try:
    # Explicit UTF-8: the README may contain non-ASCII (e.g. emoji) and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open("README.md", "r", encoding="utf-8") as f:
        mdfile = f.read()
except FileNotFoundError:
    mdfile = ""

match = re.search(date_pattern, mdfile)
# re.search returns None when the README has no date line; calling .group()
# on it would crash the whole app at startup.
if match is not None:
    date = match.group().split(': ')[1]
    try:
        formatted_date = datetime.strptime(date, '%Y-%m-%d').strftime('%d %b %Y')
    except ValueError:
        # Digits match the pattern but are not a real calendar date
        # (e.g. 2024-13-45): keep the undated label.
        pass
    else:
        header_text += f'Index Last Updated: {formatted_date}\n'
        index_info = f"Semantic Search - up to {formatted_date}"

# Search sources offered in the UI; semantic search only when its index loaded.
if semantic_search_available:
    database_choices = [index_info, 'Arxiv Search - Latest']
else:
    database_choices = ['Arxiv Search - Latest']
# Arxiv API setup
arx_client = arxiv.Client()
is_arxiv_available = True
# Probe the live arXiv API once at startup. A network/API exception here must
# not take the whole app down, so it is treated exactly like an empty result.
try:
    check_arxiv_result = get_arxiv_live_search("What is Self Rewarding AI and how can it be used in Multi-Agent Systems?", arx_client, RETRIEVE_RESULTS)
except Exception as e:
    print(f"Arxiv search check raised an error: {e}")
    check_arxiv_result = []
if len(check_arxiv_result) == 0:
    is_arxiv_available = False
    print("Arxiv search not working, switching to default search ...")
    # Only the semantic-search source is offered from here on.
    database_choices = [index_info]
# Gradio UI setup


def update_with_rag_md(search_query, llm_results_use=5, database_choice=index_info, llm_model_picked=DEFAULT_LLM_MODEL):
    """Run the selected search, render the results, and build the LLM prompt.

    Args:
        search_query: Free-text query typed by the user.
        llm_results_use: How many of the top results are packed into the LLM
            prompt as numbered context; the remaining results are only
            rendered as markdown.
        database_choice: Label selected in the "Search Source" dropdown. The
            ColBERT index is used only when this equals ``index_info`` and
            semantic search is available; anything else uses live arXiv search.
        llm_model_picked: Model id, forwarded to ``get_prompt_text``.

    Returns:
        ``(markdown, prompt)``. When live arXiv search fails or finds nothing,
        a warning toast is shown and ``("", "")`` is returned; ``ask_llm``
        treats the empty prompt as "nothing to answer".
    """
    if database_choice == index_info and semantic_search_available:
        rag_out = get_rag(search_query, RAG, RETRIEVE_RESULTS)
        database_to_use = 'Semantic Search'
    else:
        database_to_use = 'Arxiv Search'
        try:
            rag_out = get_arxiv_live_search(search_query, arx_client, RETRIEVE_RESULTS)
        except Exception as e:
            gr.Warning(f"Arxiv Search not working: {str(e)}")
            rag_out = []
        if len(rag_out) == 0:
            gr.Warning("Arxiv search failed. Please try again later.")
            return "", ""

    md_parts = [mark_text]
    prompt_parts = []
    for i, rag_answer in enumerate(rag_out):
        if i < llm_results_use:
            md_text_paper, prompt_text = get_md_text_abstract(rag_answer, source=database_to_use, return_prompt_formatting=True)
            prompt_parts.append(f"{i+1}. {prompt_text}")
        else:
            md_text_paper = get_md_text_abstract(rag_answer, source=database_to_use)
        md_parts.append(md_text_paper)
    prompt = get_prompt_text(search_query, "".join(prompt_parts), llm_model_picked=llm_model_picked)
    return "".join(md_parts), prompt


def ask_llm(prompt, llm_model_picked=DEFAULT_LLM_MODEL, stream_outputs=False):
    """Answer ``prompt`` with the selected hosted model, yielding text for Gradio.

    This function is a generator on every path. Gradio only displays values
    *yielded* by a generator callback, and a ``return <value>`` inside a
    generator is discarded, so the non-streaming paths must yield too.

    Args:
        prompt: Prompt built by ``update_with_rag_md``; ``""`` means the search
            step failed and there is nothing to answer.
        llm_model_picked: Model id passed to ``InferenceClient``; the string
            ``'None'`` disables generation.
        stream_outputs: When true, yield the growing answer token by token;
            otherwise yield the full answer once.

    Yields:
        The (partial) answer text. Streamed chunks are passed through
        ``SaveResponseAndRead``. If inference raises, a warning toast is shown
        and nothing further is yielded.
    """
    model_disabled_text = "LLM Model is disabled"
    if llm_model_picked == 'None':
        if stream_outputs:
            # Replay the notice character by character to mimic token streaming.
            output = ""
            for out in model_disabled_text:
                output += out
                yield output
        else:
            yield model_disabled_text
        # Stop here: without this return the code fell through and queried a
        # model literally named 'None'.
        return

    if not prompt:
        # Search failed upstream; clear the answer box instead of sending an
        # empty prompt to the model.
        yield ""
        return

    client = InferenceClient(llm_model_picked)
    output = ""
    try:
        response = client.text_generation(prompt, stream=stream_outputs, details=False, return_full_text=False, **GENERATE_KWARGS)
        if stream_outputs:
            for token in response:
                output += token
                yield SaveResponseAndRead(output)
        else:
            yield response
    except Exception as e:
        gr.Warning(f"LLM Inference failed: {str(e)}")


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    header = gr.Markdown(header_text)
    with gr.Group():
        search_query = gr.Textbox(label='Search', placeholder='What is Generative AI in Healthcare?')
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row(equal_height=True):
                llm_model = gr.Dropdown(choices=LLM_MODELS, value=DEFAULT_LLM_MODEL, label='LLM Model')
                llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results as context")
                # Default to the first offered source so the value is always one
                # of the choices (index_info is absent when semantic search is
                # unavailable).
                database_src = gr.Dropdown(choices=database_choices, value=database_choices[0], label='Search Source')
                stream_results = gr.Checkbox(value=True, label="Stream output", visible=False)
    output_text = gr.Textbox(show_label=True, container=True, label='LLM Answer', visible=True)
    # Hidden carrier for the prompt produced by update_with_rag_md; named so it
    # does not shadow the builtin input().
    prompt_box = gr.Textbox(show_label=False, visible=False)
    gr_md = gr.Markdown(mark_text)

    # Search first; only if that handler completes without raising, stream the
    # LLM answer for the prompt it produced.
    search_query.submit(
        update_with_rag_md,
        [search_query, llm_results, database_src, llm_model],
        [gr_md, prompt_box],
    ).success(ask_llm, [prompt_box, llm_model, stream_results], output_text)

demo.queue().launch()