# models/llamaCustomV2.py
import os
import time
from llama_index.core import VectorStoreIndex
from llama_index.core.query_pipeline import (
QueryPipeline,
InputComponent,
ArgPackComponent,
)
from llama_index.core.prompts import PromptTemplate
from llama_index.llms.openai import OpenAI
from llama_index.postprocessor.colbert_rerank import ColbertRerank
from typing import Any, Dict, List, Optional
from llama_index.core.bridge.pydantic import Field
from llama_index.core.llms import ChatMessage
from llama_index.core.query_pipeline import CustomQueryComponent
from llama_index.core.schema import NodeWithScore
from llama_index.core.memory import ChatMemoryBuffer
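
# Shared LLM instance, used both to rewrite the user query and to synthesize the final answer.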
llm = OpenAI(
model="gpt-3.5-turbo-0125",
api_key=os.getenv("OPENAI_API_KEY"),
)
# First, we create an input component to capture the user query
input_component = InputComponent()
# Next, we use the LLM to rewrite a user query
rewrite = (
"Please write a query to a semantic search engine using the current conversation.\n"
"\n"
"\n"
"{chat_history_str}"
"\n"
"\n"
"Latest message: {query_str}\n"
'Query:"""\n'
)
rewrite_template = PromptTemplate(rewrite)
# we will retrieve two times, so we need to pack the retrieved nodes into a single list
argpack_component = ArgPackComponent()
# then postprocess/rerank with Colbert
reranker = ColbertRerank(top_n=3)
DEFAULT_CONTEXT_PROMPT = (
    "Here is some context that may be relevant:\n"
    "-----\n"
    "{node_context}\n"
    "-----\n"
    "Please write a response to the following question, using the above context:\n"
    "{query_str}\n"
    "Please format your response in the following way:\n"
    "Your answer here.\n"
    "Reference:\n"
    " Your references here (e.g. page numbers, titles, etc.).\n"
)
class ResponseWithChatHistory(CustomQueryComponent):
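    """Custom query component that synthesizes a chat response from retrieved nodes and prior chat history."""
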
llm: OpenAI = Field(..., description="OpenAI LLM")
system_prompt: Optional[str] = Field(
default=None, description="System prompt to use for the LLM"
)
context_prompt: str = Field(
default=DEFAULT_CONTEXT_PROMPT,
description="Context prompt to use for the LLM",
)
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component inputs during run_component."""
# NOTE: this is OPTIONAL but we show you where to do validation as an example
return input
@property
def _input_keys(self) -> set:
"""Input keys dict."""
# NOTE: These are required inputs. If you have optional inputs please override
# `optional_input_keys_dict`
return {"chat_history", "nodes", "query_str"}
@property
def _output_keys(self) -> set:
return {"response"}
def _prepare_context(
self,
chat_history: List[ChatMessage],
nodes: List[NodeWithScore],
query_str: str,
) -> List[ChatMessage]:
node_context = ""
for idx, node in enumerate(nodes):
node_text = node.get_content(metadata_mode="llm")
node_context += f"Context Chunk {idx}:\n{node_text}\n\n"
formatted_context = self.context_prompt.format(
node_context=node_context, query_str=query_str
)
user_message = ChatMessage(role="user", content=formatted_context)
chat_history.append(user_message)
if self.system_prompt is not None:
chat_history = [
ChatMessage(role="system", content=self.system_prompt)
] + chat_history
return chat_history
def _run_component(self, **kwargs) -> Dict[str, Any]:
"""Run the component."""
chat_history = kwargs["chat_history"]
nodes = kwargs["nodes"]
query_str = kwargs["query_str"]
prepared_context = self._prepare_context(chat_history, nodes, query_str)
        response = self.llm.chat(prepared_context)
return {"response": response}
async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
"""Run the component asynchronously."""
# NOTE: Optional, but async LLM calls are easy to implement
chat_history = kwargs["chat_history"]
nodes = kwargs["nodes"]
query_str = kwargs["query_str"]
prepared_context = self._prepare_context(chat_history, nodes, query_str)
        response = await self.llm.achat(prepared_context)
return {"response": response}
class LlamaCustomV2:
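    """Chat engine that wires query rewriting, retrieval, Colbert reranking, and response synthesis into a QueryPipeline."""
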
response_component = ResponseWithChatHistory(
llm=llm,
system_prompt=(
"You are a Q&A system. You will be provided with the previous chat history, "
"as well as possibly relevant context, to assist in answering a user message."
),
)
def __init__(self, model_name: str, index: VectorStoreIndex):
self.model_name = model_name
self.index = index
self.retriever = index.as_retriever()
        self.chat_mode = "condense_plus_context"  # informational only; the pipeline below implements this flow directly
self.memory = ChatMemoryBuffer.from_defaults()
self.verbose = True
self._build_pipeline()
def _build_pipeline(self):
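        """Assemble the query pipeline: rewrite -> retrieve (twice) -> join -> rerank -> respond."""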
self.pipeline = QueryPipeline(
modules={
"input": input_component,
"rewrite_template": rewrite_template,
"llm": llm,
"rewrite_retriever": self.retriever,
"query_retriever": self.retriever,
"join": argpack_component,
"reranker": reranker,
"response_component": self.response_component,
},
verbose=self.verbose,
)
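
        # Wire the modules into a DAG: the LLM condenses the conversation into a standalone search
        # query, both the rewritten and the original queries are retrieved against, and the joined,
        # reranked nodes are handed to the response component together with the chat history.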
        # run both retrievers -- once with the LLM-rewritten query, once with the original user query
self.pipeline.add_link(
"input", "rewrite_template", src_key="query_str", dest_key="query_str"
)
self.pipeline.add_link(
"input",
"rewrite_template",
src_key="chat_history_str",
dest_key="chat_history_str",
)
self.pipeline.add_link("rewrite_template", "llm")
self.pipeline.add_link("llm", "rewrite_retriever")
self.pipeline.add_link("input", "query_retriever", src_key="query_str")
# each input to the argpack component needs a dest key -- it can be anything
# then, the argpack component will pack all the inputs into a single list
self.pipeline.add_link("rewrite_retriever", "join", dest_key="rewrite_nodes")
self.pipeline.add_link("query_retriever", "join", dest_key="query_nodes")
# reranker needs the packed nodes and the query string
self.pipeline.add_link("join", "reranker", dest_key="nodes")
self.pipeline.add_link(
"input", "reranker", src_key="query_str", dest_key="query_str"
)
# synthesizer needs the reranked nodes and query str
self.pipeline.add_link("reranker", "response_component", dest_key="nodes")
self.pipeline.add_link(
"input", "response_component", src_key="query_str", dest_key="query_str"
)
self.pipeline.add_link(
"input",
"response_component",
src_key="chat_history",
dest_key="chat_history",
)
def get_response(self, query_str: str, chat_history: List[ChatMessage]):
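        """Run the full pipeline for a single user message and update the conversation memory."""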
        # The engine tracks conversation state itself, so the internal memory buffer
        # supersedes the chat_history passed in by the caller.
        chat_history = self.memory.get()
        chat_history_str = "\n".join([str(x) for x in chat_history])
response = self.pipeline.run(
query_str=query_str,
chat_history=chat_history,
            chat_history_str=chat_history_str,
)
user_msg = ChatMessage(role="user", content=query_str)
print("user_msg: ", str(user_msg))
print("response: ", str(response.message))
self.memory.put(user_msg)
self.memory.put(response.message)
return str(response.message)
def get_stream_response(self, query_str: str, chat_history: List[ChatMessage]):
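        """Simulate token streaming by yielding the already-synthesized response word by word."""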
response = self.get_response(query_str=query_str, chat_history=chat_history)
for word in response.split():
yield word + " "
time.sleep(0.05)
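

# The block below is a usage sketch, not part of the original module: it assumes a local "data/"
# directory of documents and an OPENAI_API_KEY set in the environment.
if __name__ == "__main__":
    from llama_index.core import SimpleDirectoryReader

    # Build a simple in-memory index over local documents and drive the chat engine with it.
    documents = SimpleDirectoryReader("data").load_data()
    vector_index = VectorStoreIndex.from_documents(documents)
    bot = LlamaCustomV2(model_name="gpt-3.5-turbo-0125", index=vector_index)

    answer = bot.get_response("What do the documents say about reporting requirements?", chat_history=[])
    print(answer)

    # Streamed variant: yields the same kind of answer word by word.
    for token in bot.get_stream_response("Can you summarize that in one sentence?", chat_history=[]):
        print(token, end="", flush=True)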