"""Gradio front end for asking predefined questions about AI papers.

The user picks a paper, picks one of the questions prepared for it, and the
answer is produced live by ``backend.get_answer`` from a pickled, enriched
representation of the paper.
"""

import logging
import os
import pickle
import re

import gradio as gr

from backend import get_answer
# NOTE(review): these names are not referenced directly below; presumably the
# import is kept so that unpickling 'enriched_pdf.pkl' can resolve the classes.
# Confirm before removing.
from pdf_classes import PDFSegment, PDFPage, RichPDFDocument  # noqa: F401

logger = logging.getLogger(__name__)


def load_enriched_pdf(file_path):
    """Unpickle and return the object stored at *file_path*.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the file,
    so *file_path* must only ever point at a trusted, locally produced file.

    Raises:
        OSError: if the file cannot be opened.
        pickle.UnpicklingError: if the file is not a valid pickle.
    """
    with open(file_path, 'rb') as f:
        return pickle.load(f)


# Load the enriched PDF
enriched_pdf = load_enriched_pdf('enriched_pdf.pkl')

# Access API tokens from environment variables
jina_api_token = os.getenv('JINA_API_TOKEN')
gpt4_api_key = os.getenv('GPT4_API_KEY')
pinecone_api_key = os.getenv('PINECONE_API_KEY')

# os.environ only accepts str values, so writing a missing (None) key back
# into it would crash at import time with an unhelpful TypeError. The key is
# already in the environment when it is set, so only report what is missing.
for _env_name, _env_value in (
    ('JINA_API_TOKEN', jina_api_token),
    ('GPT4_API_KEY', gpt4_api_key),
    ('PINECONE_API_KEY', pinecone_api_key),
):
    if not _env_value:
        logger.warning(
            "Environment variable %s is not set; answering questions may fail.",
            _env_name,
        )

# Initialize Pinecone with environment variables
os.environ["PINECONE_ENVIRONMENT"] = "us-east-1"

# Sample data for papers (5 papers for the grid)
papers = [
    {"id": "1", "title": "Attention Is All You Need", "authors": "Vaswani et al.", "year": 2017},
    {"id": "2", "title": "BERT", "authors": "Devlin et al.", "year": 2018},
    {"id": "3", "title": "GPT-3", "authors": "Brown et al.", "year": 2020},
    {"id": "4", "title": "Transformer-XL", "authors": "Dai et al.", "year": 2019},
    {"id": "5", "title": "T5", "authors": "Raffel et al.", "year": 2020},
]

# Questions offered in the dropdown, keyed by paper id.
predefined_questions = {
    '1': [
        'Explain equation one in laymen terms and explain each and every component?',
        'Create list of authors who contributed to the paper in the same order, starting from left to right and go down?',
        'Explain figure two, left to right and also explain the flow of the diagram?',
        'Explain the position-wise Feed forward networks and equation two?',
        'Please summarize the findings from table 1?',
        'Explain the optimizer used and explain equation 3',
        'What is BLUE score for Tranformer model from Table 2?',
        'What does Figure 1 illustrate about the overall architecture of the Transformer model?',
        'How does Figure 2 depict the difference between Scaled Dot-Product Attention and Multi-Head Attention?',
        'Based on Figure 1, how many encoder and decoder layers are used in the Transformer model?',
        'What mathematical formula is shown in Figure 2 for Scaled Dot-Product Attention?',
        'According to Table 1, how does the complexity of Self-Attention compare to Recurrent and Convolutional layers?',
        'What does Table 2 reveal about the BLEU scores and training costs of the Transformer compared to other models?',
        "How does Table 3 visualize the impact of different model variations on the Transformer's performance?",
        'What does Equation 3 in the paper represent, and how is it visually presented?',
        'Can you describe the sinusoidal function used for positional encoding as shown in the equations in Section 3.5?',
        "How does Figure 1 illustrate the flow of information in the Transformer's encoder-decoder structure?",
    ]
}

# NOTE(review): nothing in this file attaches the class/id selectors below to
# a component (no elem_id / elem_classes) -- confirm whether they are still
# needed or whether the matching markup was lost.
css = """
body {
    font-family: Arial, sans-serif;
}
.container {
    max-width: 800px;
    margin: 0 auto;
    padding: 20px;
}
.hero {
    text-align: center;
    margin-bottom: 30px;
}
.paper-grid {
    display: grid;
    grid-template-columns: repeat(5, 1fr);
    gap: 10px;
    margin-bottom: 30px;
}
.paper-tile {
    background-color: white;
    border: 2px solid #ddd;
    border-radius: 8px;
    padding: 10px;
    cursor: pointer;
    transition: all 0.3s;
}
.paper-tile:hover {
    transform: translateY(-5px);
    box-shadow: 0 5px 15px rgba(0,0,0,0.1);
}
.paper-tile.selected {
    border-color: #007bff;
    background-color: #e6f3ff;
}
.paper-tile h3 {
    margin-top: 0;
    font-size: 14px;
}
.paper-tile p {
    margin: 5px 0;
    font-size: 12px;
    color: #666;
}
#chat-area {
    background-color: white;
    border-radius: 8px;
    padding: 20px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
"""

# Compiled once; format_answer runs on every submitted question.
# DOTALL lets display math that spans several lines be matched as one block.
_DISPLAY_MATH_RE = re.compile(r'\\\[(.*?)\\\]', re.DOTALL)
_INLINE_MATH_RE = re.compile(r'\\\((.*?)\\\)')
_DOLLAR_BLOCK_RE = re.compile(r'(\$\$.*?\$\$)', re.DOTALL)


def _paper_label(paper):
    """Return the radio-button label shown for one entry of ``papers``."""
    return f"{paper['title']} ({paper['authors']}, {paper['year']})"


def _paper_id_for_label(label):
    """Map a radio-button label back to its paper id, or None if unknown."""
    return next((p['id'] for p in papers if _paper_label(p) == label), None)


def update_predefined_questions(paper_id):
    """Return a Dropdown update listing the questions for *paper_id*.

    The dropdown is hidden (with no choices) when the paper has no
    predefined questions.
    """
    if paper_id in predefined_questions:
        return gr.Dropdown(choices=predefined_questions[paper_id], visible=True)
    return gr.Dropdown(choices=[], visible=False)


def format_answer(answer):
    """Convert a raw answer into Markdown suitable for ``gr.Markdown``.

    * ``\\[ ... \\]`` becomes ``$$ ... $$`` (may span several lines).
    * ``\\( ... \\)`` becomes ``$ ... $``.
    * ``###`` headers and fully bold lines get a blank line on either side.
    * ``$$ ... $$`` blocks get a line break on either side.

    Args:
        answer: the answer text (a non-empty string).

    Returns:
        The reformatted Markdown string.
    """
    # Convert LaTeX-style math to Markdown-style math
    answer = _DISPLAY_MATH_RE.sub(r'$$\1$$', answer)
    answer = _INLINE_MATH_RE.sub(r'$\1$', answer)

    # Format headers
    formatted_lines = []
    for line in answer.split('\n'):
        is_header = line.startswith('###')
        is_bold_line = line.startswith('**') and line.endswith('**')
        if is_header or is_bold_line:
            formatted_lines.append(f"\n{line}\n")
        else:
            formatted_lines.append(line)

    # Join lines back together
    formatted_answer = '\n'.join(formatted_lines)

    # Add spacing around math blocks. The blocks are delimited by "$$" after
    # the conversion above; matching "\\" pairs instead would split LaTeX row
    # breaks inside matrices and aligned equations.
    formatted_answer = _DOLLAR_BLOCK_RE.sub(r'\n\1\n', formatted_answer)

    return formatted_answer


def update_chat_area(paper_id, predefined_question):
    """Validate the selection and return the Markdown to show as the answer.

    Args:
        paper_id: id of the selected paper, or a falsy value if none.
        predefined_question: the selected question, or a falsy value if none.

    Returns:
        Either a user-facing hint/error message or the formatted answer.
    """
    if not paper_id:
        return "Please select a paper first."

    selected_paper = next((p for p in papers if p['id'] == paper_id), None)
    if not selected_paper:
        return "Invalid paper selection."

    if selected_paper['id'] != '1':
        return "This paper will be supported soon."

    if not predefined_question:
        return "Please select a predefined question."

    # Call the backend function to get the answer. This is the UI boundary:
    # a backend failure is logged and reported to the user with the same
    # message used for an empty answer, instead of escaping the handler.
    try:
        answer = get_answer(predefined_question, enriched_pdf, jina_api_token, gpt4_api_key)
    except Exception:
        logger.exception("get_answer failed for paper %s", paper_id)
        return "Failed to generate an answer. Please try again."

    return format_answer(answer) if answer else "Failed to generate an answer. Please try again."


with gr.Blocks(css=css) as demo:
    gr.HTML('''

AI Paper Q&A

Select a paper and ask questions about it. Questions are pre-generated but answers are generated live.

''')

    # Hidden field holding the id of the currently selected paper.
    paper_id_input = gr.Textbox(visible=False)

    with gr.Row():
        paper_tiles = gr.Radio(
            choices=[_paper_label(p) for p in papers],
            label="Select a paper",
            info="Choose one of the papers to ask questions about."
        )

    predefined_question_dropdown = gr.Dropdown(
        label="Select a predefined question",
        choices=[],
        visible=False
    )

    custom_question_input = gr.Textbox(
        label="Or ask your own question here...",
        value="Will be supported later after adding prompt guard",
        interactive=False
    )

    submit_btn = gr.Button("Submit")

    # By default gr.Markdown only renders "$$...$$"; the single-dollar
    # delimiter is listed as well so the inline math produced by
    # format_answer is rendered too. "$$" comes first so it takes precedence.
    chat_output = gr.Markdown(
        label="Answer",
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
        ],
    )

    def update_chat_area_with_loading(paper_id, predefined_question):
        """Yield a loading message first, then the result of update_chat_area."""
        # Display loading message while processing
        yield "**Generating answer...**"

        # Call the actual function and yield its result
        yield update_chat_area(paper_id, predefined_question)

    # Radio label -> hidden paper id -> question dropdown -> answer.
    paper_tiles.change(
        fn=_paper_id_for_label,
        inputs=[paper_tiles],
        outputs=[paper_id_input]
    )

    paper_id_input.change(
        fn=update_predefined_questions,
        inputs=[paper_id_input],
        outputs=[predefined_question_dropdown]
    )

    submit_btn.click(
        fn=update_chat_area_with_loading,
        inputs=[paper_id_input, predefined_question_dropdown],
        outputs=[chat_output]
    )

if __name__ == '__main__':
    demo.launch()