import datetime import os from langchain.chains import VectorDBQAWithSourcesChain import gradio as gr import langchain from langchain.vectorstores import Weaviate import faiss import pickle from langchain import OpenAI from arxiv import get_paper from ingest_faiss import create_vector_store def get_vectorstore(suffix): index = faiss.read_index(f"{suffix}/docs.index") with open(f"{suffix}/faiss_store.pkl", "rb") as f: store = pickle.load(f) store.index = index return store def set_openai_api_key(api_key, agent): if api_key: os.environ["OPENAI_API_KEY"] = api_key vectorstore = get_vectorstore() qa_chain = VectorDBQAWithSourcesChain.from_llm(llm=OpenAI(temperature=0), vectorstore=vectorstore) os.environ["OPENAI_API_KEY"] = "" return qa_chain def download_paper_and_embed(paper_arxiv_url, api_key): if paper_arxiv_url and api_key: paper_text = get_paper(paper_arxiv_url) if 'abs' in paper_arxiv_url: eprint_url = paper_arxiv_url.replace("https://arxiv.org/abs/", "https://arxiv.org/e-print/") elif 'pdf' in paper_arxiv_url: eprint_url = paper_arxiv_url.replace("https://arxiv.org/pdf/", "https://arxiv.org/e-print/") else: raise ValueError("Invalid arXiv URL") suffix = 'paper-dir/' + eprint_url.replace("https://arxiv.org/e-print/", "") os.environ["OPENAI_API_KEY"] = api_key if not os.path.exists(suffix + "/docs.index"): create_vector_store(suffix, paper_text) vectorstore = get_vectorstore(suffix) qa_chain = VectorDBQAWithSourcesChain.from_llm(llm=OpenAI(temperature=0), vectorstore=vectorstore) return qa_chain chain = None def chat(inp, history, paper_arxiv_url, api_key, agent): global chain if history is None: chain = download_paper_and_embed(paper_arxiv_url, api_key) history = history or [] # if agent is None: # history.append((inp, "Please paste your OpenAI key to use")) # return history, history print("\n==== date/time: " + str(datetime.datetime.now()) + " ====") print("inp: " + inp) history = history or [] agent = chain output = agent({"question": inp}) answer = output["answer"] sources = output["sources"] sources = sources.split(", ") sources = ", ".join([s.title() for s in sources]) history.append((inp, answer)) history.append(("Sources?", sources)) print(history) return history, history block = gr.Blocks(css=".gradio-container {background-color: lightgray}") with block: state = gr.State() agent_state = gr.State() with gr.Row(): gr.Markdown("