import gradio as gr
from langchain_core.prompts import PromptTemplate
from langchain_community.embeddings import SentenceTransformerEmbeddings

# Set model_kwargs with trust_remote_code=True
embeddings = SentenceTransformerEmbeddings(
    model_name="nomic-ai/nomic-embed-text-v1.5",
    model_kwargs={"trust_remote_code": True}
)

print('Embeddings loaded successfully')
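
# Optional sanity check (commented out): nomic-embed-text-v1.5 should
# produce 768-dimensional vectors by default.
# sample_vec = embeddings.embed_query("fibromyalgia symptoms")
# print(f"Embedding dimension: {len(sample_vec)}")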

from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader

loader = PyPDFLoader("fibromyalgia-information-booklet-july2021.pdf")
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)
vector_store = FAISS.from_documents(docs, embeddings)
retriever = vector_store.as_retriever()

print('Retriever loaded successfully')
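
# Optional smoke test (commented out): confirm the retriever returns chunks
# from the booklet before wiring it into the chain.
# for doc in retriever.invoke("What is fibromyalgia?"):
#     print(doc.page_content[:100], "...")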

from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

base_model_name = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
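
# Llama tokenizers often ship without a pad token; falling back to the EOS
# token is a common workaround that avoids padding errors during generation.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token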
# Load the base model in 4-bit via quantization_config (a separate
# load_in_4bit kwarg would be redundant). bitsandbytes quantization needs a
# GPU, so let accelerate place layers and allow fp32 CPU offload for any
# modules that don't fit.
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",
    trust_remote_code=True,  # Required for some models
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,  # 4-bit quantization, specified within BitsAndBytesConfig
        llm_int8_enable_fp32_cpu_offload=True  # Correct kwarg name for fp32 CPU offload
    )
)
adapter_path = "mohamedalcafory/PubMed_Llama3.1_Based_model"
model.load_adapter(adapter_path)

# Alternatively, load the fine-tuned model directly:
# tokenizer = AutoTokenizer.from_pretrained("mohamedalcafory/PubMed_Llama3.1_Based_model")
# model = AutoModelForCausalLM.from_pretrained("mohamedalcafory/PubMed_Llama3.1_Based_model")
print(f'Model loaded successfully: {base_model_name} + adapter {adapter_path}')
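
# Optional smoke test (commented out): run one raw generation before building
# the pipeline, to verify the quantized model + adapter respond at all.
# inputs = tokenizer("What is fibromyalgia?", return_tensors="pt").to(model.device)
# print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0], skip_special_tokens=True))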

from transformers import pipeline
from langchain_huggingface import HuggingFacePipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,  # Required for temperature/top_p to take effect
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.15,
    return_full_text=False  # Return only the generated continuation, not the prompt
)

llm = HuggingFacePipeline(pipeline=pipe)

# The template must reference the retrieved context; otherwise the RAG chain
# below would fetch documents the model never sees.
prompt = PromptTemplate(
    input_variables=["formatted_context", "query"],
    template=(
        "Use the following context to answer the question.\n\n"
        "Context:\n{formatted_context}\n\n"
        "Question: {query}\n\n"
        "Answer:"
    )
)

# Retrieval step: pull relevant chunks for the incoming query
# (retriever.invoke replaces the deprecated get_relevant_documents)
retrieve_docs = lambda x: retriever.invoke(x["query"])

def format_docs(docs):
    # Check if docs is a list of Document objects or just strings
    if docs and hasattr(docs[0], 'page_content'):
        return "\n\n".join(doc.page_content for doc in docs)
    else:
        return "\n\n".join(str(doc) for doc in docs)

# Create the full RAG chain
rag_chain = (
    RunnablePassthrough.assign(context=retrieve_docs)
    | RunnablePassthrough.assign(
        formatted_context=lambda x: format_docs(x["context"])
    )
    | prompt
    | llm
    | StrOutputParser()
)
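
# Example (commented out): invoke the chain directly, which is exactly what
# the Gradio handler below does.
# print(rag_chain.invoke({"query": "How is fibromyalgia diagnosed?"}))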

def process_query(query):
    try:
        response = rag_chain.invoke({"query": query})
        return response
    except Exception as e:
        return f"An error occurred: {str(e)}"

# Create Gradio interface
demo = gr.Interface(
    fn=process_query,
    inputs=gr.Textbox(label="Your question", lines=2, placeholder="Enter your question here..."),
    outputs=gr.Textbox(label="Response"),
    title="Fibromyalgia Q&A Assistant",
    description="Ask questions and get answers based on the retrieved context.",
    examples=[
        ["How does Physiotherapy work with Fibromyalgia?"],
        ["What are the common treatments for chronic pain?"],
    ]
)

if __name__ == "__main__":
    demo.launch()