import os
import time
from typing import Any, Dict, List, Optional

from llama_index.core import VectorStoreIndex
from llama_index.core.query_pipeline import (
    QueryPipeline,
    InputComponent,
    ArgPackComponent,
    CustomQueryComponent,
)
from llama_index.core.prompts import PromptTemplate
from llama_index.llms.openai import OpenAI
from llama_index.postprocessor.colbert_rerank import ColbertRerank
from llama_index.core.bridge.pydantic import Field
from llama_index.core.llms import ChatMessage
from llama_index.core.schema import NodeWithScore
from llama_index.core.memory import ChatMemoryBuffer

llm = OpenAI(
    model="gpt-3.5-turbo-0125",
    api_key=os.getenv("OPENAI_API_KEY"),
)

# First, we create an input component to capture the user query
input_component = InputComponent()

# Next, we use the LLM to rewrite the user query into a standalone search query
rewrite = (
    "Please write a query to a semantic search engine using the current conversation.\n"
    "\n"
    "\n"
    "{chat_history_str}"
    "\n"
    "\n"
    "Latest message: {query_str}\n"
    'Query:"""\n'
)
rewrite_template = PromptTemplate(rewrite)

# we will retrieve two times, so we need to pack the retrieved nodes into a single list
argpack_component = ArgPackComponent()

# then postprocess/rerank with ColBERT
reranker = ColbertRerank(top_n=3)

DEFAULT_CONTEXT_PROMPT = (
    "Here is some context that may be relevant:\n"
    "-----\n"
    "{node_context}\n"
    "-----\n"
    "Please write a response to the following question, using the above context:\n"
    "{query_str}\n"
    "Please format your response in the following way:\n"
    "Your answer here.\n"
    "Reference:\n"
    "    Your references here (e.g. page numbers, titles, etc.).\n"
)


class ResponseWithChatHistory(CustomQueryComponent):
    llm: OpenAI = Field(..., description="OpenAI LLM")
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt to use for the LLM"
    )
    context_prompt: str = Field(
        default=DEFAULT_CONTEXT_PROMPT,
        description="Context prompt to use for the LLM",
    )

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        # NOTE: this is OPTIONAL, but we show you where to do validation as an example
        return input

    @property
    def _input_keys(self) -> set:
        """Input keys dict."""
        # NOTE: These are required inputs. If you have optional inputs, please override
        # `optional_input_keys_dict`
        return {"chat_history", "nodes", "query_str"}

    @property
    def _output_keys(self) -> set:
        return {"response"}

    def _prepare_context(
        self,
        chat_history: List[ChatMessage],
        nodes: List[NodeWithScore],
        query_str: str,
    ) -> List[ChatMessage]:
        # concatenate the reranked nodes into a single context string
        node_context = ""
        for idx, node in enumerate(nodes):
            node_text = node.get_content(metadata_mode="llm")
            node_context += f"Context Chunk {idx}:\n{node_text}\n\n"

        formatted_context = self.context_prompt.format(
            node_context=node_context, query_str=query_str
        )
        user_message = ChatMessage(role="user", content=formatted_context)

        chat_history.append(user_message)

        if self.system_prompt is not None:
            chat_history = [
                ChatMessage(role="system", content=self.system_prompt)
            ] + chat_history

        return chat_history

    def _run_component(self, **kwargs) -> Dict[str, Any]:
        """Run the component."""
        chat_history = kwargs["chat_history"]
        nodes = kwargs["nodes"]
        query_str = kwargs["query_str"]

        prepared_context = self._prepare_context(chat_history, nodes, query_str)

        response = self.llm.chat(prepared_context)

        return {"response": response}

    async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Run the component asynchronously."""
        # NOTE: Optional, but async LLM calls are easy to implement
        chat_history = kwargs["chat_history"]
        nodes = kwargs["nodes"]
        query_str = kwargs["query_str"]

        prepared_context = self._prepare_context(chat_history, nodes, query_str)

        response = await self.llm.achat(prepared_context)

        return {"response": response}


class LlamaCustomV2:
    response_component = ResponseWithChatHistory(
        llm=llm,
        system_prompt=(
            "You are a Q&A system. You will be provided with the previous chat history, "
            "as well as possibly relevant context, to assist in answering a user message."
        ),
    )

    def __init__(self, model_name: str, index: VectorStoreIndex):
        self.model_name = model_name
        self.index = index
        self.retriever = index.as_retriever()
        self.chat_mode = "condense_plus_context"
        self.memory = ChatMemoryBuffer.from_defaults()
        self.verbose = True
        self._build_pipeline()

    def _build_pipeline(self):
        self.pipeline = QueryPipeline(
            modules={
                "input": input_component,
                "rewrite_template": rewrite_template,
                "llm": llm,
                "rewrite_retriever": self.retriever,
                "query_retriever": self.retriever,
                "join": argpack_component,
                "reranker": reranker,
                "response_component": self.response_component,
            },
            verbose=self.verbose,
        )

        # run both retrievers -- once with the rewritten (hallucinated) query, once with the real query
        self.pipeline.add_link(
            "input", "rewrite_template", src_key="query_str", dest_key="query_str"
        )
        self.pipeline.add_link(
            "input",
            "rewrite_template",
            src_key="chat_history_str",
            dest_key="chat_history_str",
        )
        self.pipeline.add_link("rewrite_template", "llm")
        self.pipeline.add_link("llm", "rewrite_retriever")
        self.pipeline.add_link("input", "query_retriever", src_key="query_str")

        # each input to the argpack component needs a dest key -- it can be anything
        # then, the argpack component will pack all the inputs into a single list
        self.pipeline.add_link("rewrite_retriever", "join", dest_key="rewrite_nodes")
        self.pipeline.add_link("query_retriever", "join", dest_key="query_nodes")

        # reranker needs the packed nodes and the query string
        self.pipeline.add_link("join", "reranker", dest_key="nodes")
        self.pipeline.add_link(
            "input", "reranker", src_key="query_str", dest_key="query_str"
        )

        # synthesizer needs the reranked nodes and query str
        self.pipeline.add_link("reranker", "response_component", dest_key="nodes")
        self.pipeline.add_link(
            "input", "response_component", src_key="query_str", dest_key="query_str"
        )
        self.pipeline.add_link(
            "input",
            "response_component",
            src_key="chat_history",
            dest_key="chat_history",
        )

    def get_response(self, query_str: str, chat_history: List[ChatMessage]):
        # the pipeline's chat history comes from memory, not from the caller
        chat_history = self.memory.get()
        chat_history_str = "\n".join([str(x) for x in chat_history])
        response = self.pipeline.run(
            query_str=query_str,
            chat_history=chat_history,
            chat_history_str=chat_history_str,
        )
        user_msg = ChatMessage(role="user", content=query_str)
        print("user_msg: ", str(user_msg))
        print("response: ", str(response.message))
        self.memory.put(user_msg)
        self.memory.put(response.message)
        return str(response.message)

    def get_stream_response(self, query_str: str, chat_history: List[ChatMessage]):
        response = self.get_response(query_str=query_str, chat_history=chat_history)
        for word in response.split():
            yield word + " "
            time.sleep(0.05)
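

# --- Example usage (illustrative sketch, not part of the original module) ---
# The driver below shows how LlamaCustomV2 might be wired up end to end. The
# "./data" directory, the use of SimpleDirectoryReader, and the example
# questions are assumptions for demonstration only; adapt them to your own
# document source and model configuration.
if __name__ == "__main__":
    from llama_index.core import SimpleDirectoryReader

    # assumes a ./data directory containing a few documents to index
    documents = SimpleDirectoryReader("./data").load_data()
    index = VectorStoreIndex.from_documents(documents)

    chat_engine = LlamaCustomV2(model_name="gpt-3.5-turbo-0125", index=index)

    # non-streaming call: runs the full rewrite -> retrieve -> rerank -> respond pipeline
    answer = chat_engine.get_response("What is this document about?", chat_history=[])
    print(answer)

    # streaming call: yields one whitespace-delimited token at a time
    for token in chat_engine.get_stream_response("Can you summarize it?", chat_history=[]):
        print(token, end="", flush=True)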