Spaces:
Runtime error
Runtime error
import gradio as gr | |
from langchain.prompts import PromptTemplate | |
from langchain.embeddings import SentenceTransformerEmbeddings | |
# Sentence-embedding model used to index and query the booklet chunks.
# trust_remote_code=True lets sentence-transformers execute the Python
# modelling code published in the model's Hub repository.
# NOTE(review): the nomic-embed model card asks for "search_query: " /
# "search_document: " task prefixes on inputs; none are added here —
# confirm retrieval quality is acceptable without them.
embeddings = SentenceTransformerEmbeddings(
    model_kwargs={"trust_remote_code": True},
    model_name="nomic-ai/nomic-embed-text-v1.5",
)
print('Embeddings loaded successfully')
from langchain_community.vectorstores import FAISS | |
from langchain_text_splitters import RecursiveCharacterTextSplitter | |
from langchain.document_loaders import TextLoader, PyPDFLoader | |
# Build the retrieval index from the fibromyalgia information booklet:
# split the PDF's text into chunks of up to 1000 characters (no overlap
# between neighbouring chunks), embed them with `embeddings`, store them in
# a FAISS vector store and expose it as a LangChain retriever.
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=0, chunk_size=1000)
loader = PyPDFLoader("fibromyalgia-information-booklet-july2021.pdf")
documents = loader.load()
docs = text_splitter.split_documents(documents)
vector_store = FAISS.from_documents(docs, embeddings)
retriever = vector_store.as_retriever()
print('Retriever loaded successfully')
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.runnables import RunnablePassthrough | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Hub repository of the fine-tuned causal LM; shared by tokenizer and model
# so the two can never drift apart.
MODEL_ID = "mohamedalcafory/PubMed_Llama3.1_Based_model"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# torch_dtype="auto" loads the weights in the precision stored in the
# checkpoint instead of up-casting them to float32, which would double the
# memory needed for a half-precision checkpoint.
# NOTE(review): the repo name suggests a Llama-3.1 (likely 8B) model, which
# needs ~32 GB in float32 — plausibly the Space's runtime error; confirm
# against the Space logs.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
print('Model loaded successfully')
from transformers import pipeline | |
from langchain_huggingface import HuggingFacePipeline | |
# Wrap the local model in a transformers text-generation pipeline and expose
# it to LangChain as an LLM.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    # temperature/top_p only take effect when sampling is enabled; with the
    # default greedy decoding transformers ignores them (and warns).
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.15,
    # Return only the newly generated text. By default the pipeline prepends
    # the prompt to the completion, so the UI would echo the question (and
    # any injected context) in front of the answer.
    return_full_text=False,
)
llm = HuggingFacePipeline(pipeline=pipe)
# Identity prompt: the rendered text is exactly the "query" input, with no
# instructions or retrieved context added.
prompt = PromptTemplate(
    template="{query}",
    input_variables=["query"],
)
# Retrieval step of the RAG chain.
def retrieve_docs(x):
    """Return the retriever's documents for ``x["query"]``.

    Args:
        x: Chain input mapping; must contain a ``"query"`` key holding the
            question text.

    Returns:
        Whatever the module-level ``retriever`` returns for the query
        (presumably a list of LangChain ``Document`` objects).
    """
    # `get_relevant_documents` is deprecated in langchain-core in favour of
    # the Runnable `invoke` API and is scheduled for removal.
    return retriever.invoke(x["query"])
# Plain generation chain without retrieval: render the query through
# `prompt`, run the local LLM, and parse the completion into a string.
generator_chain = prompt | llm | StrOutputParser()
def format_docs(docs):
    """Join retrieved documents into one context string.

    Each item may be a LangChain ``Document`` (anything with a
    ``page_content`` attribute), whose text is used, or any other object,
    which is rendered with ``str()``. Items are separated by a blank line.

    Args:
        docs: Iterable of documents and/or plain values, or ``None``.

    Returns:
        The joined text; an empty string when *docs* is empty or ``None``.
    """
    if not docs:
        return ""
    # Decide per item rather than from docs[0] only, so a mixed list neither
    # raises AttributeError nor renders Document objects via their repr.
    return "\n\n".join(
        doc.page_content if hasattr(doc, "page_content") else str(doc)
        for doc in docs
    )
# Prompt for the RAG chain. Unlike `prompt`, which renders only the query,
# it injects the retrieved booklet passages so the answer is grounded in
# them. Previously the chain built `formatted_context` and then dropped it,
# so the model never saw the retrieved text.
rag_prompt = PromptTemplate(
    input_variables=["formatted_context", "query"],
    template=(
        "Use the following context to answer the question. If the answer "
        "is not contained in the context, say that you don't know.\n\n"
        "Context:\n{formatted_context}\n\n"
        "Question: {query}\n"
        "Answer:"
    ),
)

# Full RAG chain: {"query": str} -> add retrieved documents as "context" ->
# add their joined text as "formatted_context" -> render rag_prompt -> LLM
# -> plain string. Extra keys ("context") are ignored by the prompt.
rag_chain = (
    RunnablePassthrough.assign(context=retrieve_docs)
    | RunnablePassthrough.assign(
        formatted_context=lambda x: format_docs(x["context"])
    )
    | rag_prompt
    | llm
    | StrOutputParser()
)
def process_query(query):
    """Answer *query* with the RAG chain; used as the Gradio callback.

    Args:
        query: The user's question as entered in the textbox.

    Returns:
        The model's answer, ``"Please enter a question."`` when *query* is
        ``None``/empty/whitespace, or ``"An error occurred: <message>"`` if
        the chain raises.
    """
    import logging

    if query is None or (isinstance(query, str) and not query.strip()):
        # Don't spend a retrieval plus a full generation on an empty submit.
        return "Please enter a question."
    try:
        return rag_chain.invoke({"query": query})
    except Exception as e:  # UI boundary: report the failure, don't crash.
        # Keep the traceback in the Space logs; the UI only shows the message.
        logging.getLogger(__name__).exception(
            "RAG chain failed for query %r", query
        )
        return f"An error occurred: {e}"
# Gradio UI: one question textbox wired to `process_query`, an answer
# textbox, and two clickable example questions.
demo = gr.Interface(
    title="Fibromyalgia Q&A Assistant",
    description="Ask questions and get answers based on the retrieved context.",
    fn=process_query,
    inputs=gr.Textbox(
        label="Your question",
        placeholder="Enter your question here...",
        lines=2,
    ),
    outputs=gr.Textbox(label="Response"),
    examples=[
        ["How does Physiotherapy work with Fibromyalgia?"],
        ["What are the common treatments for chronic pain?"],
    ],
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()