# pdf_qa/llm.py
# Last commit 786f732 by mobinln: "feat: remove cache, add context expander"
import os
import pathlib

import streamlit as st
from huggingface_hub import hf_hub_download
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.llms import LlamaCpp
from langchain_core.globals import set_debug
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
# LangChain's global debug mode prints the inputs and outputs of every chain
# and LLM run to the console -- including the full prompt with the retrieved
# document text. It stays on by default (as before); set PDF_QA_DEBUG to one
# of the values below to turn it off.
_DEBUG_DISABLED_VALUES = {"0", "false", "no", "off"}
set_debug(
    os.environ.get("PDF_QA_DEBUG", "1").strip().lower() not in _DEBUG_DISABLED_VALUES
)
@st.cache_resource()
def load_llm(repo_id, filename, *, n_threads=4, n_ctx=8000, max_tokens=128):
    """Download a model file from the Hugging Face Hub and load it with LlamaCpp.

    The result is cached with ``st.cache_resource``, so calls with the same
    arguments reuse the already-loaded model instead of loading it again.

    Args:
        repo_id: Hugging Face Hub repository that holds the model.
        filename: Name of the llama.cpp-compatible model file in that repo.
        n_threads: CPU threads used for generation and for batch (prompt)
            processing.
        n_ctx: Context window size, in tokens.
        max_tokens: Maximum number of tokens generated per call.

    Returns:
        A configured ``LlamaCpp`` LLM instance.
    """
    # Models are stored in ./models, relative to the current working directory.
    models_folder = pathlib.Path("models")
    models_folder.mkdir(exist_ok=True)
    model_path = hf_hub_download(
        repo_id=repo_id, filename=filename, local_dir=models_folder
    )
    # repo_id/filename are intentionally not forwarded to LlamaCpp: they are
    # not LlamaCpp/llama.cpp options, and the model is already resolved to a
    # local file path above.
    llm = LlamaCpp(
        model_path=model_path,
        verbose=False,
        use_mmap=True,
        # Ask the OS to keep the model resident in RAM (no swapping).
        use_mlock=True,
        n_threads=n_threads,
        n_threads_batch=n_threads,
        n_ctx=n_ctx,
        max_tokens=max_tokens,
    )
    print(f"{repo_id} loaded successfully. ✅")
    return llm
# Run one retrieval-augmented QA turn (not streamed: returns the full chain result).
def _escape_braces(text):
    """Escape ``{``/``}`` so ``text`` is taken literally by an f-string prompt template."""
    return text.replace("{", "{{").replace("}", "}}")


def _chatml_turn(role, content):
    """Wrap ``content`` in a single ChatML turn for ``role``."""
    return f"<|im_start|>{role}\n{content}<|im_end|>\n"


def response_generator(llm, messages, question, retriever):
    """Answer ``question`` with retrieval-augmented generation.

    Builds a ChatML prompt from a system instruction (the retrieved documents
    are substituted for ``{context}``), the prior conversation in ``messages``
    and the current question, then runs it through a LangChain retrieval chain.

    Args:
        llm: LangChain LLM that generates the answer (``load_llm`` returns a
            ``LlamaCpp`` instance).
        messages: Conversation so far, as dicts with ``"role"`` and
            ``"content"`` keys. Only ``"user"`` and ``"assistant"`` entries are
            rendered; other roles are skipped. If the last entry is the current
            question as a user message, it is not repeated.
        question: The question to answer; also used as the retrieval query.
        retriever: LangChain retriever that supplies the context documents.

    Returns:
        The dict returned by the retrieval chain's ``invoke`` (for
        ``create_retrieval_chain`` this carries ``"input"``, ``"context"`` --
        the retrieved documents -- and ``"answer"``).
    """
    system_prompt = (
        "You are an AI assistant specializing in question-answering tasks. "
        "Utilize the provided context and past conversation to answer "
        "the current question. If the answer is unknown, clearly state that you "
        "don't know. Keep responses concise and direct."
        "\n\n"
        "Context: {context}"
    )

    history = list(messages)
    # The caller may already have appended the current question to the chat
    # history; drop it here because it is rendered below through {input}.
    if (
        history
        and history[-1]["role"] == "user"
        and history[-1]["content"] == question
    ):
        history = history[:-1]

    parts = [_chatml_turn("system", system_prompt)]
    for message in history:
        role = message["role"]
        if role in ("user", "assistant"):
            # Escaped so braces in chat text cannot break template formatting
            # or inject template variables such as {context}.
            parts.append(_chatml_turn(role, _escape_braces(message["content"])))
    parts.append(_chatml_turn("user", "{input}"))
    # Leave the assistant turn open so the model continues with its answer.
    parts.append("<|im_start|>assistant\n")

    # LlamaCpp is a text-completion LLM: a ChatPromptTemplate would be flattened
    # with "System: "/"Human: "/"AI: " prefixes in front of the ChatML tags, so
    # the ChatML prompt is assembled as one plain string template instead.
    prompt = PromptTemplate.from_template("".join(parts))

    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(retriever, question_answer_chain)
    return rag_chain.invoke({"input": question})