# Vendored from a Hugging Face Space by berkaygkv54 (commit ee00a52, "llm bugfix").
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnableLambda
from operator import itemgetter
from langchain.output_parsers import PydanticOutputParser
from .output_parser import SongDescriptions
from langchain.llms.base import LLM
class LLMChain:
    """Expand a playlist description into four structured song descriptions.

    Wraps a LangChain ``LLM`` with a prompt pipeline that injects Pydantic
    format instructions, invokes the model, and parses the completion into
    a ``SongDescriptions`` object.
    """

    def __init__(self, llm_model: LLM) -> None:
        """Build the chain around *llm_model*.

        Args:
            llm_model: Any LangChain ``LLM`` used to generate completions.
        """
        self.llm_model = llm_model
        # Parser doubles as the source of the format instructions injected
        # into the prompt and the validator of the model's raw output.
        self.parser = PydanticOutputParser(pydantic_object=SongDescriptions)
        self.full_chain = self._create_llm_chain()

    def _get_output_format(self, _) -> str:
        """Return the parser's format instructions.

        The ignored positional argument exists because ``RunnableLambda``
        always passes the chain input through to the callable.
        """
        return self.parser.get_format_instructions()

    def _create_llm_chain(self):
        """Assemble and return the prompt -> model runnable chain."""
        prompt_response = ChatPromptTemplate.from_messages([
            ("system", "You are an AI assistant, helping the user to turn a music playlist text description into four separate song descriptions that are probably contained in the playlist. Try to be specific with descriptions. Make sure all 4 song descriptions are similar.\n"),
            ("system", "{format_instructions}\n"),
            ("human", "Playlist description: {description}.\n"),
        ])
        full_chain = (
            {
                "format_instructions": RunnableLambda(self._get_output_format),
                "description": itemgetter("description"),
            }
            | prompt_response
            | self.llm_model
        )
        return full_chain

    def process_user_description(self, user_input):
        """Run the chain on *user_input* and parse the model output.

        Args:
            user_input: Free-text playlist description.

        Returns:
            SongDescriptions: Parsed and validated song descriptions.
        """
        output = self.full_chain.invoke({"description": user_input})
        # Plain LLMs return a str, but chat models return a message object;
        # normalize to text first.  Calling .replace directly on the invoke
        # result raised AttributeError for message outputs.
        text = getattr(output, "content", output)
        # Strip stray backslashes the model sometimes emits, which would
        # otherwise break the downstream JSON/Pydantic parsing.
        text = text.replace("\\", "")
        return self.parser.parse(text)