Spaces:
Sleeping
Sleeping
import re
from operator import itemgetter

from langchain.llms.base import LLM
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnableLambda

from .output_parser import SongDescriptions
class LLMChain:
    """Turn a free-text playlist description into structured song descriptions.

    The runnable built by ``_create_llm_chain`` renders the user's description
    together with the ``SongDescriptions`` format instructions into a chat
    prompt and sends it to ``llm_model``. ``process_user_description`` then
    cleans the model's reply and parses it with a ``PydanticOutputParser``.
    """

    # Characters that may legally follow a backslash inside a JSON string.
    _JSON_ESCAPES = frozenset('"\\/bfnrtu')
    # A backslash together with the (optional) character right after it.
    _BACKSLASH_RE = re.compile(r"\\(.?)", re.DOTALL)

    def __init__(self, llm_model: "LLM") -> None:
        """Store the model, create the output parser and assemble the chain.

        Args:
            llm_model: LangChain model that receives the rendered prompt.
        """
        self.llm_model = llm_model
        self.parser = PydanticOutputParser(pydantic_object=SongDescriptions)
        self.full_chain = self._create_llm_chain()

    def _get_output_format(self, _):
        """Return the parser's format instructions; the chain input is ignored."""
        return self.parser.get_format_instructions()

    def _create_llm_chain(self):
        """Build the runnable pipeline: input mapping -> chat prompt -> model."""
        prompt_response = ChatPromptTemplate.from_messages([
            (
                "system",
                "You are an AI assistant, helping the user to turn a music playlist "
                "text description into four separate song descriptions that are "
                "probably contained in the playlist. Try to be specific with "
                "descriptions. Make sure all 4 song descriptions are similar.\n",
            ),
            ("system", "{format_instructions}\n"),
            ("human", "Playlist description: {description}.\n"),
        ])
        # The leading dict maps the {"description": ...} input onto the two
        # variables the prompt template expects.
        return (
            {
                "format_instructions": RunnableLambda(self._get_output_format),
                "description": itemgetter("description"),
            }
            | prompt_response
            | self.llm_model
        )

    @classmethod
    def _strip_invalid_escapes(cls, text: str) -> str:
        r"""Drop backslashes that do not start a valid JSON escape sequence.

        Models sometimes emit Markdown-style escapes such as ``\_`` inside their
        JSON, which a JSON parser rejects. Valid escapes (``\"``, ``\\``,
        ``\n``, ...) are kept, so quoted text in the reply still parses. A lone
        trailing backslash is removed.
        """
        def _replace(match):
            escaped = match.group(1)
            if escaped in cls._JSON_ESCAPES:
                return match.group(0)  # legitimate escape: keep it intact
            return escaped  # invalid escape: keep only the escaped character

        return cls._BACKSLASH_RE.sub(_replace, text)

    def process_user_description(self, user_input):
        """Run the chain on a playlist description and parse the result.

        Args:
            user_input: Free-text playlist description from the user.

        Returns:
            Whatever ``self.parser.parse`` produces for the cleaned reply
            (a ``SongDescriptions`` instance for the parser built in __init__).
        """
        reply = self.full_chain.invoke({"description": user_input})
        # Plain LLMs return a str; chat models return a message whose text is
        # in ``.content``.
        text = getattr(reply, "content", reply)
        return self.parser.parse(self._strip_invalid_escapes(text))